Перейти к основному содержимому

Контрастное обучение

Определение

Контрастное обучение (contrastive learning [1]) настраивает отображение объекта x\mathbf{x} в признаковое представление или эмбеддинг (embedding) f(x)RKf(\mathbf{x})\in \mathbb{R}^K таким образом, чтобы похожие объекты были близки, а непохожие - далеки друг от друга в пространстве эмбеддингов. При этом понятие похожести/непохожести определяется решаемой задачей.

Преобразование f(x)f(\mathbf{x}), отображающее объект в его эмбеддинг, называется сиамской сетью (siamese network [1], [2]), поскольку при настройке точная копия этой сети будет применяться не к одному, а сразу к нескольким объектам.

При этом размерность пространства эмбеддингов KK обычно невелика и составляет несколько сотен признаков.

Примеры использования

Ниже приводятся примеры задач, решаемых с помощью контрастного обучения, и то, какие объекты считаются похожими и непохожими в каждой задаче:

NЗадачаПохожие объектыНепохожие объекты
1классификацияпринадлежат одному классупринадлежат разным классам
2проверка подписи человека по её отсканированной версииподписи одного и того же человекаподписи разных людей
3обнаружение перефразированияперефразирования одной и той же фразыперефразирования разных фраз
4обнаружение одинаковых изображенийпреобразования одного и того же изображения (поворот, обрезка, добавление шума, изменение цветов)разные изображения

В задачах 1, 2, 3 используется разметка, поэтому это задача обучения с учителем, известная как supervised contrasitve learning. Задача 4 - это задача обучения без учителя, поскольку внешняя разметка в ней не используется. Эта задача известна как instance discrimination.

Функция потерь контрастного обучения штрафует эмбеддинги похожих объектов за похожесть, а эмбеддинги непохожих объектов штрафует за непохожесть.

Ниже приведён пример эмбеддингов, построенных по задаче классификации рукописных цифр на датасете MNIST [3] и извлечённых с промежуточных слоёв обычной классификационной сети (слева) и с финальных слоёв сиамской сети f(x)f(x) (справа):

embeddings_comparison.png

Похожими объектами выступают сканы одной и той же цифры, а непохожими - сканы разных цифр. Отвечающие эмбеддингам цифры обозначены цветом.

Как видим, эмбеддинги сиамской сети сильнее раздвинуты для объектов разных классов по сравнению с эмбеддингами обычной классификационной сети, что следует из функции потерь контрастного обучения, которые сближают эмбеддинги похожих объектов (объектов одного класса) и раздвигают эмбеддинги непохожих (разных классов).

После того, как сиамская сеть настроена, можно решать конечную задачу метрическими методами в пространстве эмбеддингов, такими как метод ближайших центроидов или метод K ближайших соседей, поскольку эмбеддинги объектов разных классов будут далеко, а одного класса - близко.

Также можно использовать слои обученной сиамской сети для извлечения признаков для другой задачи (transfer learning).

Обработка объектов разных типов

Также контрастное обучение для обработки объектов разных типов. В этом случае каждый тип преобразуется в эмбеддинг своей нейросетью:

ЗадачаПохожие объектыНепохожие объекты
ранжирование, информационный поискпоисковый запрос и соответствующие ему документыпоисковый запрос и нерелевантные документы
построение текстового описания к изображениюизображение и соответствующее ему описаниеизображение и не соответствующие ему описания
рекомендательные системыпользователь и товары, которые ему понравилисьпользователь и товары, которые ему не понравились или по которым нет данных

Например, для рекомендательной системы будут настраиваться две нейросети f()f(\cdot) и g()g(\cdot):

f(u):пользовательэмбеддингg(i):товарэмбеддинг\begin{aligned} f(\mathbf{u}): \text{пользователь} \to \text{эмбеддинг} \\ g(\mathbf{i}): \text{товар} \to \text{эмбеддинг} \end{aligned}

Различные преобразования f()f(\cdot) и g()g(\cdot) будут отображать объекты в общее пространство эмбеддингов, сохраняя свойство, что эмбеддинги соответствующих друг другу объектов должны быть близки, а не соответствующих - далеки друг от друга.

Далее пользователю можно рекомендовать те товары, эмбеддинги которых близки к эмбеддингу пользователя.

Аналогично можно реализовать информационный поиск, отображая в начале поисковой выдачи те документы, эмбеддинги которых максимально похожи на эмбеддинг поискового запроса.

Настройка сиамской сети

Рассмотрим для простоты обозначений обработку объектов одного типа с помощью сиамской сети, отображающей объект в его эмбеддинг:

e=f(x)\mathbf{e}=f(\mathbf{x})

Существуют три основные функции потерь для их настройки.

Попарные потери

При обучении с использованием попарных потерь (pairwise contrastive loss [4]) сэмплируются пары объектов xi,xj\mathbf{x}_i,\mathbf{x}_j, и происходит минимизация

L(xi,xj)={ρ(f(xi),f(xj))2,если xi,xj похожиmax{0,αρ(f(xi),f(xj))}2,если xi,xj не похожи\mathcal{L}(\mathbf{x}_i,\mathbf{x}_j)= \begin{cases} \rho(f(\mathbf{x}_i),f(\mathbf{x}_j))^{2}, & \text{если } \mathbf{x}_i,\mathbf{x}_j \text{ похожи} \\ \max\left\{ 0,\alpha-\rho(f(\mathbf{x}_i),f(\mathbf{x}_j))\right\}^{2}, & \text{если } \mathbf{x}_i,\mathbf{x}_j \text{ не похожи} \\ \end{cases}

В качестве ρ(xi,xj)\rho(\mathbf{x}_i,\mathbf{x}_j) обычно берётся Евклидово расстояние.

Гиперпараметр α>0\alpha>0 управляет минимальным расстоянием между непохожими объектами, при котором не будет штрафа. Его можно выбрать равным единице.

Поскольку сэмплируются всевозможные пары объектов, то число уникальных сэмплов будет O(N2)O(N^2), где NN - число объектов в обучающей выборке.

Тройные потери

При обучении, используя тройные потери (triplet loss [5]), сэмплируются тройки объектов:

  • x\mathbf{x} - опорный объект (anchor);

  • x+\mathbf{x}^{+} - похожий на x\mathbf{x} (положительный, positive);

  • x\mathbf{x}^{-} - не похожий на x\mathbf{x} (отрицательный, negative).

Сами тройные потери штрафуют ситуацию, когда эмбеддинги опорного и положительного объекта становятся слишком непохожими относительно различий эмбеддингов опорного и отрицательного объекта:

L(x,x+,x)=max{ρ(f(x),f(x+))2ρ(f(x),f(x))2+α;0}\mathcal{L}(\mathbf{x},\mathbf{x}^{+},\mathbf{x}^{-})=\max\left\{ \rho(f(\mathbf{x}),f(\mathbf{x}^+))^{2}-\rho(f(\mathbf{x}),f(\mathbf{x}^-))^{2}+\alpha;0\right\}

Поскольку сэмплируются тройки объектов, то число уникальных сэмплов будет иметь порядок O(N3)O(N^3).

Вероятностные потери

При обучении, используя вероятностные потери (InfoNCE loss, NCE=noise constrastive estimation, [6],[7]), сэмплируются

  • x\mathbf{x} - опорный объект (anchor);

  • x+\mathbf{x}^{+} - похожий на xx (положительный, positive);

  • x1,...xM\mathbf{x}_1,...\mathbf{x}_M - набор непохожих на x\mathbf{x} объектов (отрицательные, negatives).

Сами вероятностные потери имеют вид:

L(x,x+,x1,...xM)=lnesim(f(x),f(x+))esim(f(x),f(x+))+m=1Mesim(f(x),f(xm)),\mathcal{L}(\mathbf{x},\mathbf{x}^+,\mathbf{x}^-_1,...\mathbf{x}^-_M)=-\ln\frac{e^{\text{sim}(f(\mathbf{x}),f(\mathbf{x}^+))}}{e^{\text{sim}(f(\mathbf{x}),f(\mathbf{x}^+))}+\sum_{m=1}^M e^{\text{sim}(f(\mathbf{x}),f(\mathbf{x}^-_m))}},

где sim(e,e)\text{sim}(e,e') - косинусная мера близости:

sim(e,e)=eTeee\text{sim}(e,e')=\frac{\mathbf{e}^T \mathbf{e}'}{||\mathbf{e}||\cdot||\mathbf{e}'||}

За счёт сэмплирования не одного, а целого набора из MM непохожих объектов число уникальных сэмплов будет иметь порядок больше O(NM+2)O(N^{M+2}).

Указанные функции потерь можно применять и для объектов разных типов. В этом случае в формулах нужно подставлять эмбеддинги объектов, полученных из разных нейросетей, обрабатывающих соответствующие типы данных.

Важность числа уникальных сэмплов

При классической настройке обычной сети по обучающей выборке из NN объектов существует всего NN уникальных сэмплов, по которым считается функция потерь. При контрастном оценивании число уникальных сэмплов гораздо больше, что позволяет настраивать сиамскую сеть точнее даже для небольших выборок!

Даже если есть всего один пример класса, то, сравнивая его с каждым из других объектов, уже получим N1N-1 уникальных сэмплов, поэтому контрастное обучение может эффективно использоваться в обучении всего по нескольким примерам класса (few-shot learning [8]).

Генерация обучающих примеров

Не ограничивая общности, рассмотрим задачу классификации, решаемую методом контрастного обучения. При генерации обучающих примеров можно сэмплировать равномерно

– по объектам;

– по классам (типам непохожести между объектами).

В первом случае будут оптимизироваться микроусреднённые меры качества (на объектах), а во втором - макроусреднённые меры качества (на классах). Обычно требуется хорошо отделять каждый из классов, поэтому чаще используется второй подход.

Обучение расстояний (metric learning)

Приёмы контрастного обучения можно использовать и для обучения расстояний (metric learning [9]), когда функция расстояния в метрических методах прогнозирования параметризуется, а параметры подбираются так, чтобы лучше решать конечную задачу. В случае задачи классификации можно подбирать такую меру близости (или расстояния) между объектами, чтобы объекты одного класса оказывались близки, а объекты разных классов - далеки друг от друга.


Детальнее о контрастном обучении и его приложениях можно прочитать в [10]. С библиотеками и последними статьями по теме можно ознакомиться в [11].

Литература

  1. Bromley J. et al. Signature verification using a" siamese" time delay neural network //Advances in neural information processing systems. – 1993. – Т. 6.

  2. Wikipedia: Siamese neural network.

  3. LeCun Y. The MNIST database of handwritten digits. – 1998.

  4. Hadsell R., Chopra S., LeCun Y. Dimensionality reduction by learning an invariant mapping //2006 IEEE computer society conference on computer vision and pattern recognition (CVPR'06). – IEEE, 2006. – Т. 2. – С. 1735-1742.

  5. Schroff F., Kalenichenko D., Philbin J. Facenet: A unified embedding for face recognition and clustering //Proceedings of the IEEE conference on computer vision and pattern recognition. – 2015. – С. 815-823.

  6. Gutmann M., Hyvärinen A. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models //Proceedings of the thirteenth international conference on artificial intelligence and statistics. – JMLR Workshop and Conference Proceedings, 2010. – С. 297-304.

  7. Oord A., Li Y., Vinyals O. Representation learning with contrastive predictive coding //arXiv preprint arXiv:1807.03748. – 2018.

  8. ibm.com: What is few-shot learning?

  9. Kulis B. et al. Metric learning: A survey //Foundations and Trends® in Machine Learning. – 2013. – Т. 5. – №. 4. – С. 287-364.

  10. Jaiswal A. et al. A survey on contrastive self-supervised learning //Technologies. – 2020. – Т. 9. – №. 1. – С. 2.

  11. paperswithcode.com: Contrastive Learning.