Перейти к основному содержимому

Контрастное обучение

Определение

Контрастное обучение (contrastive learning, [1],[2],[3]) настраивает отображение объекта x\mathbf{x} в признаковое представление или эмбеддинг (embedding) f(x)RKf(\mathbf{x})\in \mathbb{R}^K таким образом, чтобы "похожие" объекты были близки, а "непохожие" - далеки друг от друга в пространстве эмбеддингов.

Преобразование f(x)f(\mathbf{x}), отображающее объект в его эмбеддинг, называется сиамской сетью (siamese network), поскольку при настройке точная копия этой сети будет применяться не к одному, а сразу к нескольким объектам.

При этом размерность пространства эмбеддингов KK обычно невелика и составляет несколько сотен признаков.

Примеры использования

Ниже приводятся примеры задач, решаемых контрастным обучением, и то, какие объекты считаются похожими и непохожими в каждой задаче:

NЗадачаПохожие объектыНепохожие объекты
1классификацияпринадлежат одному классупринадлежат разным классам
2проверка подписи человека по её сканыподписи одного и того же человекаподписи разных людей
3обнаружение перефразированияперефразирования одной и той же фразыперефразирования разных фраз
4обнаружение одинаковых изображенийпреобразования одного и того же изображения (поворот, обрезка, добавление шума, изменение цветов)разные изображения

В задаче 1,2,3 используется разметка, поэтому это задача обучения с учителем, известная как supervised contrasitve learning. Задача 4 - это задача обучения без учителя, поскольку внешняя разметка в ней не используется. Эта задача известна как instance discrimination.

Настройка сиамских сетей производится по группам объектов, причем различаются ситуации, когда объекты похожи, а когда нет, поскольку это будет влиять на то, какая функция потерь к ним будет применяться.

Ниже приведён пример эмбеддингов, получаемых в задаче классификации рукописных цифр на датасете MNIST, полученные с промежуточных слоёв обычной классификационной сети (слева) и с финальных слоёв сиамской сети f(x)f(x) (справа).

embeddings_comparison.png

Как видим, эмбеддинги сиамской сети сильнее раздвинуты для объектов разных классов по сравнению с эмбеддингами обычной классификационной сети, что следует из функции потерь контрастного обучения, которые сближают эмбеддинги похожих объектов (объектов одного класса) и раздвигают эмбеддинги непохожих (разных классов).

После того, как сиамская сеть настроена, можно решать конечную задачу. Если эта классификация, то можно

  • инициализировать классификационную сеть первыми слоями сиамской сети (особенно обученной решать задачу instance discrimination);

  • решать классификацию в пространстве эмбеддингов. Поскольку эмбеддинги уже хорошо разделяют классы, то можно использовать метод ближайших центроидов или (лучше, но более ресурсоёмко) метод K ближайших соседей.

Обработка объектов разных типов

Также применяется контрастное обучение для объектов разных типов с разными преобразованиями для каждого типа:

ЗадачаПохожие объектыНепохожие объекты
ранжирование, информационный поискпоисковый запрос и соответствующие ему документыпоисковый запрос и нерелевантные документы
построение текстового описания к изображениюизображение и соответствующие ему описаниеизображение и не соответствующие ему описания
рекомендательные системыпользователь и товары, которые ему понравилисьпользователь и товары, которые ему не понравились или по которым нет данных

Например, для рекомендательной системы будут настраиваться две нейросети f()f(\cdot) и g()g(\cdot):

f(u):пользовательэмбеддингg(i):товарэмбеддинг\begin{aligned} f(\mathbf{u}): \text{пользователь} \to \text{эмбеддинг} \\ g(\mathbf{i}): \text{товар} \to \text{эмбеддинг} \end{aligned}

Различные преобразования f()f(\cdot) и g()g(\cdot) будут отображать объекты в общее пространство эмбеддингов, сохраняя свойство, что эмбеддинги соответствующих друг другу объектов должны быть близки, а не соответствующих - далеки друг от друга.

Настройка сиамской сети

Рассмотрим случай использования одной сиамской сети f()f(\cdot).

Существуют три основных функции потерь для их настройки.

Попарные потери

При обучении, используя попарные потери (pairwise contrastive loss, spring loss), сэмплируются пары объектов xi,xj\mathbf{x}_i,\mathbf{x}_j, и происходит минимизация

L(xi,xj)={ρ(xi,xj)2,если xi,xj похожиmax{0,αρ(xi,xj)}2,если xi,xj не похожи\mathcal{L}(\mathbf{x}_i,\mathbf{x}_j)= \begin{cases} \rho(\mathbf{x}_i,\mathbf{x}_j)^{2}, & \text{если } \mathbf{x}_i,\mathbf{x}_j \text{ похожи} \\ \max\left\{ 0,\alpha-\rho(\mathbf{x}_i,\mathbf{x}_j)\right\}^{2}, & \text{если } \mathbf{x}_i,\mathbf{x}_j \text{ не похожи} \\ \end{cases}

В качестве ρ(xi,xj)\rho(\mathbf{x}_i,\mathbf{x}_j) обычно берётся Евклидово расстояние.

Гиперпараметр α>0\alpha>0 управляет минимальным расстоянием между непохожими объектами, при котором не будет штрафа. Его можно выбрать равным единице.

Поскольку сэмплируются всевозможные пары объектов, то число уникальных сэмплов будет O(N2)O(N^2), где NN - число объектов в обучающей выборке.

Тройные потери

При обучении, используя тройные потери (triplet loss) сэмплируются тройки объектов:

  • x\mathbf{x} - опорный объект (anchor),

  • x+\mathbf{x}^{+} - похожий на x\mathbf{x} (положительный, positive),

  • x\mathbf{x}^{-} - не похожий на x\mathbf{x} (отрицательный, negative).

Сами тройные потери штрафуют ситуацию, когда эмбеддинги опорного и положительного объекта становятся слишком непохожими относительно различий эмбеддингов опорного и отрицательного объекта:

L(x,x+,x)=max{ρ(x,x+)2ρ(x,x)2+α;0}\mathcal{L}(\mathbf{x},\mathbf{x}^{+},\mathbf{x}^{-})=\max\left\{ \rho(\mathbf{x},\mathbf{x}^+)^{2}-\rho(\mathbf{x},\mathbf{x}^-)^{2}+\alpha;0\right\}

Поскольку сэмплируются тройки объектов, то число уникальных сэмплов будет иметь порядок O(N3)O(N^3).

Вероятностные потери

При обучении, используя вероятностные потери (InfoNCE loss, NCE=noise constrastive estimation, [1],[2]), сэмплируются

  • x\mathbf{x} - опорный объект (anchor)

  • x+\mathbf{x}^{+} - похожий на xx (положительный, positive)

  • x1,...xM\mathbf{x}_1,...\mathbf{x}_M - набор непохожих на x\mathbf{x} объектов (отрицательные, negatives)

L(x,x+,x1,...xM)=lnesim(x,x+)esim(x,x+)+m=1Mesim(x,xm)\mathcal{L}(\mathbf{x},\mathbf{x}^+,\mathbf{x}^-_1,...\mathbf{x}^-_M)=-\ln\frac{e^{\text{sim}(\mathbf{x},\mathbf{x}^+)}}{e^{\text{sim}(\mathbf{x},\mathbf{x}^+)}+\sum_{m=1}^M e^{\text{sim}(\mathbf{x},\mathbf{x}^-_m)}}

где sim(x,x)\text{sim}(x,x') - косинусная мера близости:

sim(x,x)=xTxxx\text{sim}(x,x')=\frac{\mathbf{x}^T \mathbf{x}'}{||\mathbf{x}||\cdot||\mathbf{x}'||}

За счёт сэмплирования не одного, а целого набора из MM непохожих объектов число уникальных сэмплов будет иметь порядок больше O(NM+2)O(N^{M+2}).

Важность числа уникальных сэмплов

При классической настройке обычной сети по обучающей выборке из NN объектов существует всего NN уникальных сэмплов, по которым считается функция потерь. Как мы видели, при контрастном оценивании число уникальных сэмплов гораздо больше, что позволяет их качественно настраивать, используя более маленькие обучающие выборки.

Даже если есть всего один пример класса, то сравнивая его с каждым из других объектов уже получим N1N-1 уникальных сэмплов, поэтому контрастное обучение может эффективно использоваться в обучении few-shot learning, когда доступны всего несколько примеров класса.

Генерация обучающих примеров

Не ограничивая общности, рассмотрим, задачу классификации, решаемую контрастным оцениванием. При генерации обучающих примеров можно сэмплировать равномерно

– по объектам

– по классам (типам непохожести между объектами).

При этом в первом случае будет максимизироваться микроусреднённые меры качества на объектах, а во втором - макроусреднённые меры на классах.

Обучение расстояний (metric learning)

Приёмы контрастного обучения можно использовать и для обучения расстояний (metric learning), когда функция расстояния в метрических методах прогнозирования параметризуется, а параметры подбираются так, чтобы лучше решать конечную задачу. В случае задачи классификации можно подбирать такую меру близости (или расстояния) между объектами, чтобы объекты одного класса оказывались близки, а объекты разных классов - далеки друг от друга.

Литература

  1. Gutmann M., Hyvärinen A. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models //Proceedings of the thirteenth international conference on artificial intelligence and statistics. – JMLR Workshop and Conference Proceedings, 2010. – С. 297-304.

  2. Oord A., Li Y., Vinyals O. Representation learning with contrastive predictive coding //arXiv preprint arXiv:1807.03748. – 2018.

  3. Chen T. et al. A simple framework for contrastive learning of visual representations //International conference on machine learning. – PMLR, 2020. – С. 1597-1607.