Перейти к основному содержимому

Дистилляция знаний

Базовая идея

При использовании моделей глубокого обучения важна не только точность моделей, но и их вычислительная простота, обеспечивающая более быстрое построение прогнозов. Упрощение модели позволяет её применять на более простых вычислительных устройствах, таких как мобильный телефон.

Дистилляция знаний (knowledge distillation) - процесс настройки простой модели (студент, student model) воспроизводить поведение сложной уже настроенной точной модели (учитель, teacher model). Впервые технология была развёрнуто описана в [1].

В базовом варианте дистилляция знаний для задачи классификации состоит из следующих шагов:

  1. Обучить сложную модель g(x)g(x) на размеченном тренировочном датасете. Далее сложная модель зафиксирована и не меняется.

  2. Обучить простую модель f(x)f(x) восстанавливать вероятности классов сложной модели на трансферном датасете (transfer set).

Трансферный датасет может совпадать с тренировочным, на котором обучалась сложная модель, может составлять его подмножество, а может составлять совсем другую выборку объектов.

Мотивация

Как правило, трансферный датасет мал, из-за чего возникают неоднозначности с обобщающей способностью - экстраполировать информацию об ограниченных наблюдениях можно по-разному. При этом известно, что сложная модель точна, т.е. обладает хорошей обобщающей способностью.

Поэтому, чтобы расширить объем получаемой информации с каждого объекта, простой модели даётся не просто информация о корректном классе, а предоставляется полная информация о вероятностях каждого из CC классов по мнению сложной модели. Таким образом, с каждого объекта простая модель получает в CC раз больше информации и учится быстрее.

Кроме того, пытаясь воспроизвести рейтинги всех классов сложной модели, простая модель обучается воспроизводить высокую обобщающую способность сложной модели.

Например, при классификации изображений,

  • в стандартном обучении простая модель видит только изображение и верный класс, например, "собака". Переобучившись, она легко может научиться её путать с классом "медведь" и "лось" из-за цветовой схожести.

  • в дистилляции знаний простая модель видит, что "собака" обладает максимальным рейтингом, но также высоким рейтингом обладают классы "волк" и "шакал", а класс "медведь" и "лось" обладают, наоборот, низким рейтингом. Т.е. простая модель учится правильной обобщающей способности сложной модели.

Настройка простой модели

В стандартной процедуре обучения [1] простая модель выдаёт рейтинги f1(x),...fC(x)f_1(x),...f_C(x), которые трансформируются в вероятности классов через SoftMax-преобразование:

p1f=ef1(x)/Tc=1Cefc(x/T),p2f=ef2(x)/Tc=1Cefc(x)/T,    pCf=efC(x)/Tc=1Cefc(x)/T,\begin{align*} & p^f_1 = \frac{e^{f_1(x)/T}}{\sum_{c=1}^C e^{f_c(x/T)}}, \\ & p^f_2 = \frac{e^{f_2(x)/T}}{\sum_{c=1}^C e^{f_c(x)/T}}, \\ & \cdots \; \cdots \; \cdots\\ & p^f_C = \frac{e^{f_C(x)/T}}{\sum_{c=1}^C e^{f_c(x)/T}}, \\ \end{align*}

где T>0T>0 - гиперпараметр температуры, обычно выбираемый равным единице. Чем он выше, тем выходные вероятности получаются более сглаженные и близкие к равномерному распределению.

Настройка модели производится, используя стандартную кросс-энтропию:

Ls(f,y)=c=1Cyclnpcf(x),yc=I{y=c}(1)\mathcal{L}_s(f,y)=-\sum_{c=1}^C y_c\ln p^f_c(x),\quad y_c=\mathbb{I}\{y=c\} \tag{1}

В дистилляции знаний есть уже настроенная сложная модель g()g(\cdot), выдающая вероятности классов [p1g,...,pCg][p^g_1, ..., p^g_C]. Для простой модели эти вероятности представляют собой мягкую разметку (soft labels), к которой нужно приближать собственные прогнозы, используя ту же кросс-энтропию в качестве потерь:

Lkd(f,g)=c=1Cpcglnpcf(x)(2)\mathcal{L}_{kd}(f,g)=-\sum_{c=1}^C p^g_c\ln p^f_c(x) \tag{2}

Используя (2) можно настраивать простую модель на неразмеченном трансферном датасете. Если же трансферный датасет содержит метки истинных классов, то настроить простую модель можно точнее, используя взвешенную сумму потерь (1) и (2):

L(f,g,y)=Ls(f,y)+λT2Lkd(f,g),\mathcal{L}(f,g,y) = \mathcal{L}_s(f,y) + \lambda T^2 \mathcal{L}_{kd}(f,g),

где λ>0\lambda>0 - гиперпараметр, отвечающий за силу дистилляции знаний. Потери (1) вычисляются с T=1T=1, а потери (2) рекомендуется вычислять с большим TT для сложной и простой модели, поскольку иначе сложная модель будет часто выдавать слишком сконцентрированное распределение вокруг истинного класса, что затруднит дистилляцию знаний. Поскольку градиент по весам потерь (2) убывает по закону 1/T21/T^2 при возрастании TT, то чтобы выровнять эффект потерь (1) и (2) вторые потери рекомендуется явно домножать на T2T^2.

Другие варианты применения

Выше мы рассмотрели перенос знаний сложной модели только с её последнего слоя, этот подход называется дистилляцией, основанной на откликах (responce based distillation). Но обучать простую модель воспроизводить карту признаков (активации нейронов) с промежуточного слоя. Для этого обычно при настройке простой модели используется регуляризатор в виде квадрата L2L_2-расхождения между картой признаков простой модели fK(x)f_K(x) и сложной модели gK(x)g_K(x) на слое KK:

L(f,g,y)=Ls(f,y)+λfK(x)gK(x)2\mathcal{L}(f,g,y) = \mathcal{L}_s(f,y)+\lambda ||f_K(x)-g_K(x)||^2

Этот подход называется дистилляция, основанная на признаках (feature based distillation, [2]).

Также существует дистилляция, основанная на взаимодействии признаков (relation based distillation, [2]), при которой простая модель учится воспроизводить такую же взаимосвязь между представлениями с разных слоёв ii и jj, что и сложная модель. Пусть Φ(u,v)\Phi(u,v) - функция взаимодействия между слоями. Например, это может быть матрица попарных корреляций между каждым признаком слоя ii с каждым признаком слоя jj. Тогда простая модель настраивается, используя следующую функцию потерь:

L(f,g,y)=Ls(f,y)+λΦ(fi(x),fj(x))Φ(gi(x),gj(x))2\mathcal{L}(f,g,y) = \mathcal{L}_s(f,y)+\lambda ||\Phi(f_i(x),f_j(x))-\Phi(g_i(x),g_j(x))||^2

Онлайн-дистилляция

В традиционной дистилляции знаний вначале настраивается сложная модель, затем она фиксируется и используется только для генерации целевых значений простой модели. Простая модель при этом может испытывать сложности при воспроизведении ответов сложной модели из-за недостаточной выразительности. В онлайн-дистилляции (online distillation, [2]) сложная и простая модели обучаются параллельно, при этом

  • простая модель использует ответы последней версии сложной модели;

  • простая модель посылает величину собственной ошибки аппроксимации сложной модели;

  • сложная модель учитывает эту ошибку при собственном обучении, адаптируя свою настройку таким образом, чтобы простой модели было проще воспроизводить её поведение.

Дистилляция со многими учителями

Простую нейросеть можно обучить качественнее, если использовать не одну, а сразу несколько сложных сетей. Тогда можно будет собрать экспертизу с разных моделей, использующих разные подходы и принять более взвешенное решение.

Обычно это достигается такой настройкой простой модели, чтобы она хорошо приближала усреднённые отклики набора сложных моделей. Однако лучшего качества удаётся достичь [3], если заставить простую модель одновременно приближать прогноз каждой сложной модели в отдельности.

Используя эти подходы, хорошо настроенная простая модель может начать работать даже точнее, чем каждая из сложных.

Дистилляция, совмещённая с упрощением сложной модели

Существует большая разница в сложности между сложной и простой моделью, из-за чего простой модели может быть сложно перенять некоторые паттерны поведения сложной модели. В статье [4] удаётся обучить простую модель лучшего качества за счёт предварительного упрощения сложной модели, используя обрезку моделей (network pruning).

Взаимная дистилляция

Располагая KK серверами, можно распараллелить настройку сложной модели на больших данных. Для этого

  1. Грубо настраивается KK моделей на подвыборках обучающего датасета.

  2. Производится донастройка каждой модели, при которой периодически модели пересылают между серверами свои текущие веса. Донастройка осуществляется таким образом, чтобы каждая из моделей не только выдавала точные прогнозы на рассматриваемых объектах, но и чтобы эти прогнозы не слишком отличались от прогнозов последних известных версий других моделей.

В этом подходе нет явного деления на простую и сложную модели, все модели считаются равноправными, а сам подход называется взаимная дистилляция (codistillation, [5]).

Литература

  1. Hinton G., Vinyals O., Dean J. Distilling the knowledge in a neural network //arXiv preprint arXiv:1503.02531. – 2015.
  2. Gou J. et al. Knowledge distillation: A survey //International Journal of Computer Vision. – 2021. – Т. 129. – №. 6. – С. 1789-1819.
  3. Zuchniak K. Multi-teacher knowledge distillation as an effective method for compressing ensembles of neural networks //arXiv preprint arXiv:2302.07215. – 2023.
  4. Park J., No A. Prune your model before distill it //European Conference on Computer Vision. – Cham : Springer Nature Switzerland, 2022. – С. 120-136.
  5. Anil R. et al. Large scale distributed neural network training through online distillation //arXiv preprint arXiv:1804.03235. – 2018.