Дистилляция знаний
Базовая идея
При использовании моделей глубокого обучения важна не только точность моделей, но и их вычислительная простота, обеспечивающая более быстрое построение прогнозов. Упрощение модели позволяет её применять на более простых вычислительных устройствах, таких как мобильный телефон.
Дистилляция знаний (knowledge distillation) - процесс настройки простой модели (студент, student model) воспроизводить поведение сложной уже настроенной точной модели (учитель, teacher model). Впервые технология была развёрнуто описана в [1].
В базовом варианте дистилляция знаний для задачи классификации состоит из следующих шагов:
-
Обучить сложную модель на размеченном тренировочном датасете. Далее сложная модель зафиксирована и не меняется.
-
Обучить простую модель восстанавливать вероятности классов сложной модели на трансферном датасете (transfer set).
Трансферный датасет может совпадать с тренировочным, на котором обучалась сложная модель, может составлять его подмножество, а может составлять совсем другую выборку объектов.
Мотивация
Как правило, трансферный датасет мал, из-за чего возникают неоднозначности с обобщающей способностью - экстраполировать информацию об ограниченных наблюдениях можно по-разному. При этом известно, что сложная модель точна, т.е. обладает хорошей обобщающей способностью.
Поэтому, чтобы расширить объем получаемой информации с каждого объекта, простой модели даётся не просто информация о корректном классе, а предоставляется полная информация о вероятностях каждого из классов по мнению сложной модели. Таким образом, с каждого объекта простая модель получает в раз больше информации и учится быстрее.
Кроме того, пытаясь воспроизвести рейтинги всех классов сложной модели, простая модель обучается воспроизводить высокую обобщающую способность сложной модели.
Например, при классификации изображений,
-
в стандартном обучении простая модель видит только изображение и верный класс, например, "собака". Переобучившись, она легко может научиться её путать с классом "медведь" и "лось" из-за цветовой схожести.
-
в дистилляции знаний простая модель видит, что "собака" обладает максимальным рейтингом, но также высоким рейтингом обладают классы "волк" и "шакал", а класс "медведь" и "лось" обладают, наоборот, низким рейтингом. Т.е. простая модель учится правильной обобщающей способности сложной модели.
Настройка простой модели
В стандартной процедуре обучения [1] простая модель выдаёт рейтинги , которые трансформируются в вероятности классов через SoftMax-преобразование:
где - гиперпараметр температуры, обычно выбираемый равным единице. Чем он выше, тем выходные вероятности получаются более сглаженные и близкие к равномерному распределению.
Настройка модели производится, используя стандартную кросс-энтропию:
В дистилляции знаний есть уже настроенная сложная модель , выдающая вероятности классов . Для простой модели эти вероятности представляют собой мягкую разметку (soft labels), к которой нужно приближать собственные прогнозы, используя ту же кросс-энтропию в качестве потерь:
Используя (2) можно настраивать простую модель на неразмеченном трансферном датасете. Если же трансферный датасет содержит метки истинных классов, то настроить простую модель можно точнее, используя взвешенную сумму потерь (1) и (2):
где - гиперпараметр, отвечающий за силу дистилляции знаний. Потери (1) вычисляются с , а потери (2) рекомендуется вычислять с большим для сложной и простой модели, поскольку иначе сложная модель будет часто выдавать слишком сконцентрированное распределение вокруг истинного класса, что затруднит дистилляцию знаний. Поскольку градиент по весам потерь (2) убывает по закону при возрастании , то чтобы выровнять эффект потерь (1) и (2) вторые потери рекомендуется явно домножать на .