Перейти к основному содержимому

Контрастное обучение

Определение

Контрастное обучение (contrastive learning, [1],[2],[3]) настраивает отображение объекта x\mathbf{x} в признаковое представление или эмбеддинг (embedding) f(x)RKf(\mathbf{x})\in \mathbb{R}^K таким образом, чтобы "похожие" объекты были близки, а "непохожие" - далеки друг от друга в пространстве эмбеддингов.

Преобразование f(x)f(\mathbf{x}), отображающее объект в его эмбеддинг, называется сиамской сетью (siamese network), поскольку при настройке точная копия этой сети будет применяться не к одному, а сразу к нескольким объектам.

При этом размерность пространства эмбеддингов KK обычно невелика и составляет несколько сотен признаков.

Примеры использования

Ниже приводятся примеры задач, решаемых с помощью контрастного обучения, и то, какие объекты считаются похожими и непохожими в каждой задаче:

NЗадачаПохожие объектыНепохожие объекты
1классификацияпринадлежат одному классупринадлежат разным классам
2проверка подписи человека по её отсканированной версииподписи одного и того же человекаподписи разных людей
3обнаружение перефразированияперефразирования одной и той же фразыперефразирования разных фраз
4обнаружение одинаковых изображенийпреобразования одного и того же изображения (поворот, обрезка, добавление шума, изменение цветов)разные изображения

В задаче 1,2,3 используется разметка, поэтому это задача обучения с учителем, известная как supervised contrasitve learning. Задача 4 - это задача обучения без учителя, поскольку внешняя разметка в ней не используется. Эта задача известна как instance discrimination.

Настройка сиамских сетей производится по группам объектов, причем различаются ситуации, когда объекты похожи, а когда нет, поскольку это будет влиять на то, какая функция потерь к ним будет применяться.

Ниже приведён пример эмбеддингов, построенных по задаче классификации рукописных цифр на датасете MNIST и извлечённых с промежуточных слоёв обычной классификационной сети (слева) и с финальных слоёв сиамской сети f(x)f(x) (справа).

embeddings_comparison.png

Как видим, эмбеддинги сиамской сети сильнее раздвинуты для объектов разных классов по сравнению с эмбеддингами обычной классификационной сети, что следует из функции потерь контрастного обучения, которые сближают эмбеддинги похожих объектов (объектов одного класса) и раздвигают эмбеддинги непохожих (разных классов).

После того, как сиамская сеть настроена, можно решать конечную задачу. Если это задача классификации, то можно

  • инициализировать классификационную сеть первыми слоями сиамской сети (особенно обученной решать задачу instance discrimination);

  • решать классификацию в пространстве эмбеддингов. Поскольку эмбеддинги уже хорошо разделяют классы, то можно использовать метод ближайших центроидов или метод K ближайших соседей. Последний метод работает точнее, но требует больше вычислительных ресурсов.

Обработка объектов разных типов

Также применяется контрастное обучение для объектов разных типов с разными преобразованиями для каждого типа:

ЗадачаПохожие объектыНепохожие объекты
ранжирование, информационный поискпоисковый запрос и соответствующие ему документыпоисковый запрос и нерелевантные документы
построение текстового описания к изображениюизображение и соответствующее ему описаниеизображение и не соответствующие ему описания
рекомендательные системыпользователь и товары, которые ему понравилисьпользователь и товары, которые ему не понравились или по которым нет данных

Например, для рекомендательной системы будут настраиваться две нейросети f()f(\cdot) и g()g(\cdot):

f(u):пользовательэмбеддингg(i):товарэмбеддинг\begin{aligned} f(\mathbf{u}): \text{пользователь} \to \text{эмбеддинг} \\ g(\mathbf{i}): \text{товар} \to \text{эмбеддинг} \end{aligned}

Различные преобразования f()f(\cdot) и g()g(\cdot) будут отображать объекты в общее пространство эмбеддингов, сохраняя свойство, что эмбеддинги соответствующих друг другу объектов должны быть близки, а не соответствующих - далеки друг от друга.

Настройка сиамской сети

Рассмотрим случай использования одной сиамской сети f()f(\cdot).

Существуют три основные функции потерь для их настройки.

Попарные потери

При обучении с использованием попарных потерь (pairwise contrastive loss, spring loss) сэмплируются пары объектов xi,xj\mathbf{x}_i,\mathbf{x}_j, и происходит минимизация

L(xi,xj)={ρ(xi,xj)2,если xi,xj похожиmax{0,αρ(xi,xj)}2,если xi,xj не похожи\mathcal{L}(\mathbf{x}_i,\mathbf{x}_j)= \begin{cases} \rho(\mathbf{x}_i,\mathbf{x}_j)^{2}, & \text{если } \mathbf{x}_i,\mathbf{x}_j \text{ похожи} \\ \max\left\{ 0,\alpha-\rho(\mathbf{x}_i,\mathbf{x}_j)\right\}^{2}, & \text{если } \mathbf{x}_i,\mathbf{x}_j \text{ не похожи} \\ \end{cases}

В качестве ρ(xi,xj)\rho(\mathbf{x}_i,\mathbf{x}_j) обычно берётся Евклидово расстояние.

Гиперпараметр α>0\alpha>0 управляет минимальным расстоянием между непохожими объектами, при котором не будет штрафа. Его можно выбрать равным единице.

Поскольку сэмплируются всевозможные пары объектов, то число уникальных сэмплов будет O(N2)O(N^2), где NN - число объектов в обучающей выборке.

Тройные потери

При обучении, используя тройные потери (triplet loss), сэмплируются тройки объектов:

  • x\mathbf{x} - опорный объект (anchor);

  • x+\mathbf{x}^{+} - похожий на x\mathbf{x} (положительный, positive);

  • x\mathbf{x}^{-} - не похожий на x\mathbf{x} (отрицательный, negative).

Сами тройные потери штрафуют ситуацию, когда эмбеддинги опорного и положительного объекта становятся слишком непохожими относительно различий эмбеддингов опорного и отрицательного объекта:

L(x,x+,x)=max{ρ(x,x+)2ρ(x,x)2+α;0}\mathcal{L}(\mathbf{x},\mathbf{x}^{+},\mathbf{x}^{-})=\max\left\{ \rho(\mathbf{x},\mathbf{x}^+)^{2}-\rho(\mathbf{x},\mathbf{x}^-)^{2}+\alpha;0\right\}

Поскольку сэмплируются тройки объектов, то число уникальных сэмплов будет иметь порядок O(N3)O(N^3).

Вероятностные потери

При обучении, используя вероятностные потери (InfoNCE loss, NCE=noise constrastive estimation, [1],[2]), сэмплируются

  • x\mathbf{x} - опорный объект (anchor);

  • x+\mathbf{x}^{+} - похожий на xx (положительный, positive);

  • x1,...xM\mathbf{x}_1,...\mathbf{x}_M - набор непохожих на x\mathbf{x} объектов (отрицательные, negatives).

Сами вероятностные потери имеют вид:

L(x,x+,x1,...xM)=lnesim(x,x+)esim(x,x+)+m=1Mesim(x,xm),\mathcal{L}(\mathbf{x},\mathbf{x}^+,\mathbf{x}^-_1,...\mathbf{x}^-_M)=-\ln\frac{e^{\text{sim}(\mathbf{x},\mathbf{x}^+)}}{e^{\text{sim}(\mathbf{x},\mathbf{x}^+)}+\sum_{m=1}^M e^{\text{sim}(\mathbf{x},\mathbf{x}^-_m)}},

где sim(x,x)\text{sim}(x,x') - косинусная мера близости:

sim(x,x)=xTxxx\text{sim}(x,x')=\frac{\mathbf{x}^T \mathbf{x}'}{||\mathbf{x}||\cdot||\mathbf{x}'||}

За счёт сэмплирования не одного, а целого набора из MM непохожих объектов число уникальных сэмплов будет иметь порядок больше O(NM+2)O(N^{M+2}).

Важность числа уникальных сэмплов

При классической настройке обычной сети по обучающей выборке из NN объектов существует всего NN уникальных сэмплов, по которым считается функция потерь. Как мы видели, при контрастном оценивании число уникальных сэмплов гораздо больше, что позволяет их качественно настраивать, используя более компактные обучающие выборки, состоящие из меньшего числа объектов.

Даже если есть всего один пример класса, то, сравнивая его с каждым из других объектов, уже получим N1N-1 уникальных сэмплов, поэтому контрастное обучение может эффективно использоваться в обучении всего по нескольким примерам класса (few-shot learning).

Генерация обучающих примеров

Не ограничивая общности, рассмотрим задачу классификации, решаемую методом контрастного обучения. При генерации обучающих примеров можно сэмплировать равномерно

– по объектам;

– по классам (типам непохожести между объектами).

В первом случае будут оптимизироваться микроусреднённые меры качества на объектах, а во втором - макроусреднённые меры на классах.

Обучение расстояний (metric learning)

Приёмы контрастного обучения можно использовать и для обучения расстояний (metric learning), когда функция расстояния в метрических методах прогнозирования параметризуется, а параметры подбираются так, чтобы лучше решать конечную задачу. В случае задачи классификации можно подбирать такую меру близости (или расстояния) между объектами, чтобы объекты одного класса оказывались близки, а объекты разных классов - далеки друг от друга.

Литература

  1. Gutmann M., Hyvärinen A. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models //Proceedings of the thirteenth international conference on artificial intelligence and statistics. – JMLR Workshop and Conference Proceedings, 2010. – С. 297-304.

  2. Oord A., Li Y., Vinyals O. Representation learning with contrastive predictive coding //arXiv preprint arXiv:1807.03748. – 2018.

  3. Chen T. et al. A simple framework for contrastive learning of visual representations //International conference on machine learning. – PMLR, 2020. – С. 1597-1607.