Перейти к основному содержимому

Контрастное обучение

Определение

Контрастное обучение (contrastive learning, [1],[2],[3]) настраивает отображение объектов xx в признаковое представление или эмбеддинг (embedding) объектов f(x)RKf(x)\in \mathbb{R}^K таким образом, чтобы "похожие" объекты были близки, а "непохожие - далеки друг от друга в пространстве эмбеддингов.

Преобразование f(x), отображающее объект в его эмбеддинг, называется сиамской сетью (siamese network).

При этом размерность пространства эмбеддингов KK обычно невелика и составляет несколько сотен признаков.

Примеры использования

ЗадачаПохожие объектыНепохожие объекты
классификацияпринадлежат одному классупринадлежат разным классам
проверка подписисканы подписей одного и того же человекасканы подписей разных людей
обнаружение перефразированияперефразирования одной и той же фразыперефразирования разных фраз
обнаружение одинаковых изображенийпреобразования одного и того же изображения (поворот, обрезка, добавление шума, изменение цветов)разные изображения

В задаче 1,2,3 используется разметка, поэтому это задача обучения с учителем, известная как supervised contrasitve learning. В задаче 4 разметка не используется - объекты генерируются по самим себе. Эта задача известна как instance discrimination.

При настройке похожие объекты сэмплируются целенаправленно, чтобы быть похожими, а непохожие - сэмплируются случайно. Поскольку число объектов в выборке, как правило, велико, то почти во всех случаях это будет приводить к генерации действительно разных не соответствующих друг другу объектов.

Ниже приведён пример эмбеддингов, получаемых в задаче классификации рукописных цифр на датасете MNIST, полученные с промежуточных слоёв обычной классификационной сети (слева) и с финальных слоёв сиамской сети f(x)f(x) (справа).

embeddings_comparison.png

Как видим, эмбеддинги сиамской сети сильнее раздвинуты для объектов разных классов по сравнению с эмбеддингами обычной классификационной сети.

После того, как сиамская сеть настроена, можно решать конечную задачу. Если эта классификация, то можно

  • инициализировать классификационную сеть первыми слоями сиамской сети (особенно для instance discrimination)

  • решать классификацию в пространстве эмбеддингов. Поскольку эмбеддинги уже хорошо разделяют классы, то можно использовать метод ближайших центроидов или (лучше, но более ресурсоёмко) метод K ближайших соседей.

Обработка объектов разных типов

Также применяется контрастное обучение для объектов разных типов с разными преобразованиями для каждого типа:

ЗадачаПохожие объектыНепохожие объекты
ранжирование, информационный поискпоисковый запрос и соответствующие ему документыпоисковый запрос и нерелевантные документы
построение текстового описания к изображениюизображение и соответствующие ему описаниеизображение и не соответствующие ему описания
рекомендательные системыпользователь и товары, которые ему понравилисьпользователь и товары, которые ему не понравились или по которым нет данных

Например, для рекомендательной системы будут настраиваться две нейросети f()f(\cdot) и g()g(\cdot):

f(u):пользовательэмбеддингg(i):товарэмбеддинг\begin{align*} f(u): \text{пользователь} \to \text{эмбеддинг} \\ g(i): \text{товар} \to \text{эмбеддинг} \end{align*}

Но преобразования f()f(\cdot) и g()g(\cdot) будут отображать объекты в общее пространство эмбеддингов, сохраняя свойство, что эмбеддинги соответствующих друг другу объектов будут близки, а не соответствующих - далеки.

Настройка сиамской сети

Рассмотрим случай использования одной сиамской сети f()f(\cdot).

Существуют три основных функции потерь для их настройки.

Попарные потери

При обучении, используя попарные потери (pairwise contrastive loss, spring loss), сэмплируются пары объектов xi,xjx_i,x_j, и происходит минимизация

L(xi,xj)={ρ(xi,xj)2,если xi,xj похожиmax{0,αρ(xi,xj)}2,если xi,xj похожи\mathcal{L}(x_i,x_j)= \begin{cases} \rho(x_i,x_j)^{2}, & \text{если } x_i,x_j \text{ похожи} \\ \max\left\{ 0,\alpha-\rho(x_i,x_j)\right\}^{2}, & \text{если } x_i,x_j \text{ похожи} \\ \end{cases}

В качестве ρ(xi,xj)\rho(x_i,x_j) обычно берётся Евклидово расстояние.

Гиперпараметр α>0\alpha>0 управляет минимальным расстоянием между непохожими объектами, при котором не будет штрафа. Его можно выбрать равным единице.

Поскольку сэмплируются всевозможные пары объектов, то число уникальных сэмплов будет O(N2)O(N^2), где NN - число объектов в обучающей выборке.

Тройные потери

При обучении, используя тройные потери (triplet loss) сэмплируются тройки объектов:

  • xx - опорный объект (anchor)

  • x+x^{+} - похожий на xx (positive)

  • xx^{-} - не похожий на xx (negative)

L(x,x+,x)=max{ρ(x,x+)2ρ(x,x)2+α;0}\mathcal{L}(x,x^{+},x^{-})=\max\left\{ \rho(x,x^+)^{2}-\rho(x,x^-)^{2}+\alpha;0\right\}

Поскольку сэмплируются тройки объектов, то число уникальных сэмплов будет иметь порядок O(N3)O(N^3).

Вероятностные потери

При обучении, используя вероятностные потери (InfoNCE loss, NCE=noise constrastive estimation, [1],[2]), сэмплируются

  • xx - опорный объект (anchor)

  • x+x^{+} - похожий на xx (positive)

  • x1,...xMx_1,...x_M - набор непохожих на xx объектов (negative)

L(x,x+,x1,...xM)=lnesim(x,x+)esim(x,x+)+m=1Mesim(x,xm)\mathcal{L}(x,x^+,x^-_1,...x^-_M)=-\ln\frac{e^{sim(x,x^+)}}{e^{sim(x,x^+)}+\sum_{m=1}^M e^{sim(x,x^-_m)}}

где sim(x,x)sim(x,x') - косинусная мера близости:

sim(x,x)=xTxxxsim(x,x')=\frac{x^T x'}{||x||\cdot||x'||}

За счёт сэмплирования не одного, а целого набора из MM непохожих объектов число уникальных сэмплов будет иметь порядок больше O(NM+2)O(N^{M+2}).

Важность числа уникальных сэмплов

При классической настройке обычной сети по обучающей выборке из NN объектов существует всего NN уникальных сэмплов, по которым считается функция потерь. Как мы видели, при контрастном оценивании число уникальных сэмплов гораздо больше, что позволяет их качественно настраивать, используя более маленькие датасеты.

Даже если есть всего один пример класса, то сравнивая его с каждым из других объектов получим уже N1N-1 уникальных сэмплов, поэтому контрастное обучение может эффективно использоваться в обучении few-shot learning, когда доступны всего несколько примеров класса.

Генерация сэмплов

Не ограничивая общности, рассмотрим, задачу классификации, решаемую контрастным оцениванием. При генерации сэмплов можно сэмплировать равномерно

– по объектам

– по классам.

При этом в первом случае будет максимизироваться микроусреднённые меры качества на объектах, а во втором - макроусреднённые меры на классах.

Обучение расстояний (metric learning)

Приёмы контрастного обучения можно использовать и для обучения расстояний (metric learning), когда функция расстояния в метрических методах прогнозирования параметризуется, а параметры подбираются так, чтобы лучше решать конечную задачу. В случае задачи классификации можно подбирать такую меру близости (или расстояния) между объектами, чтобы объекты одного класса оказывались близки, а объекты разных классов - далеки друг от друга.

Литература

  1. Gutmann M., Hyvärinen A. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models //Proceedings of the thirteenth international conference on artificial intelligence and statistics. – JMLR Workshop and Conference Proceedings, 2010. – С. 297-304.

  2. Oord A., Li Y., Vinyals O. Representation learning with contrastive predictive coding //arXiv preprint arXiv:1807.03748. – 2018.

  3. Chen T. et al. A simple framework for contrastive learning of visual representations //International conference on machine learning. – PMLR, 2020. – С. 1597-1607.