Перейти к основному содержимому

Жадный и лучевой поиск

Введение

Рассмотрим, как можно улучшить генерацию текста как последовательности слов с помощью языковой модели (language model), основанной на рекуррентной сети. Хотя стоит помнить, что предложенный метод улучшения применим

  • к любым языковым моделям, например, трансформерам;

  • к генерации любых дискретных объектов (ДНК как последовательности нуклеотидов; сессии как последовательности действий пользователя на сайте и т.д.).

Сеть генерирует текст последовательно слово за словом среди VV уникальных слов словаря. На каждом шаге сеть выдаёт VV рейтингов каждого отдельного слова (или символа при посимвольной генерации):

r^=[r^1,r^2,...r^V]=f(xt,ht)\hat{r}=[\hat{r}_1,\hat{r}_2,...\hat{r}_V]=f(x_t,h_t)

Эти рейтинги преобразуются в вероятности слов, используя SoftMax преобразование:

SoftMaxτ(r^1,...r^V)=1ier^i/τ(er^1/τer^2/τer^V/τ),\text{SoftMax}_{\tau}\left(\widehat{r}_{1},...\widehat{r}_{V}\right)=\frac{1}{\sum_{i}e^{\widehat{r}_{i}/\tau}}\cdot\left(\begin{array}{c} e^{\widehat{r}_{1}/\tau}\\ e^{\widehat{r}_{2}/\tau}\\ \cdots\\ e^{\widehat{r}_{V}/\tau} \end{array}\right), τ>0гиперпараметр температуры\tau>0 - \text{гиперпараметр температуры}

Нас интересует генерация последовательности слов, обладающая максимальным рейтингом (который является некоторой функцией от модельного правдоподобия сгенерированной последовательности слов).

Вначале рассмотрим работу более простого алгоритма жадного поиска (greedy search), а затем опишем работу более продвинутого лучевого поиска (beam search), осуществляющего более полный перебор.

Алгоритм жадного поиска

Простейший подход генерации слов или символов текста - генерировать каждый раз следующее слово (или символ), дающее наивысший рейтинг последовательности. Этот подход называется жадным поиском (greedy search). Визуализируем его.

Пример работы

Пусть, для простоты, генерация происходит не на уровне слов, а на уровне букв, причем в словаре только 2 буквы: "A" и "M". Нам требуется сгенерировать слово из 4-х букв словаря, обладающего максимальным рейтингом S()S(\cdot):

argmaxc1c2c3c4S(c1c2c3c4)?\arg\max_{c_1c_2c_3c_4} S(c_1c_2c_3c_4) - ?

В качестве рейтинга используется некоторое преобразование от модельной вероятности пронаблюдать именно такую цепочку символов в последовательности:

P(c1c2c3c4)=P(c1)P(c2c1)P(c3c1c2)P(c4c1c2c3)P(c_1c_2c_3c_4)=P(c_1)P(c_2|c_1)P(c_3|c_1c_2)P(c_4|c_1c_2c_3)

Процесс итеративного построения слова можно представить в виде дерева, каждый узел которого отвечает некоторому варианту выбранного слова с приписанным ему рейтингом.

Нулевой ярус дерева отвечает пустому слову, первый - слову из одной буквы ("A" или "M"), следующий - слову из 2х букв ("AA", "AM", "MA", "MM") и т.д.

Поскольку нам нужно составить наилучшее слово из 4-х букв, нам требуется найти узел 4-го яруса дерева, обладающий максимальным рейтингом.

Ниже представлен пример дерева выбора с рейтингами узлов:

Генерация жадного поиска стартует из корня дерева (отвечающего пустому слову) и анализируются рейтинги слов из одной буквы:

Выбирается слово "A", обладающее максимальным рейтингом 20.

Далее для префикса "А" анализируются продолжения "AA" и "AM".

Выбирается "AM", как обладающее максимальным рейтингом.

Далее снова анализируются возможные продолжения "AA": "AAA" и "AAM":

Выбирается префикс "AAA":

Для "AAA" снова анализируются продолжения: "AAAA" и "AAAM":

Выбирается слово "AAAA", и на этом генерация 4-х буквенного слова завершается:

Полученное слово "AAAA" обладает рейтингом 48.

Алгоритм лучевого поиска

Поскольку каждый раз алгоритм смотрит только на один шаг вперёд, то сгенерированная последовательность как целое может оказаться неоптимальной, то есть обладающей недостаточно высоким рейтингом S(c1c2...cT)S(c_1c_2...c_T) по сравнению с альтернативными вариантами.

Для повышения качества сгенерированных последовательностей используется лучевой поиск (beam search), суть которого состоит в том, что выбирается каждый раз не одна гипотеза, а набор из K лучших гипотез, где KK - гиперпараметр метода. Получив KK итоговых последовательностей можно среди них выбрать ту, которая обладает максимальным рейтингом, что обеспечит более полный перебор вариантов.

Пример работы

Опишем алгоритм визуально для примера выше (генерация 4-х буквенного слова из букв "A" и "M"), когда K=2K=2, т.е. параллельно дорабатываются 2 лучших гипотезы.

Генерация лучевого поиска стартует из корня дерева (отвечающего пустому слову) и анализируются рейтинги слов из одной буквы:

В отличие от жадного поиска, лучевой поиск идёт одновременно по 2-м маршрутам, используя гипотезы "A" и "M":

Для каждой гипотезы анализируются их всевозможные расширения:

Наилучшими оказались гипотезы "AA" и "AM", префикс "M" отбрасывается как обеспечивающий расширение с меньшим рейтингом:

Снова анализируются продолжения каждой из 2-х гипотез:

Лучшими продолжениями оказываются "AAA" и "AMA":

Анализируются продолжения этих гипотез:

Выбираются две лучших - "AAAA" и "AMAM" с рейтингами 48 и 60:

Итоговой генерацией будет слово "AMAM", обладающую максимальным рейтингом среди отобранных на предыдущем шаге.

За счёт расширения пространства поиска нам удалось сгенерировать слово с более высоким рейтингом (60), чем пользуясь жадным поиском (48)!

Анализ лучевого поиска

Обратим внимание, что лучевой поиск всё же не обеспечивает полный перебор. Из-за этого мы упустили наилучшую генерацию слова "MAMA" с рейтингом 80:

При K=1K=1 лучевой поиск сводится к жадному алгоритму. Чем гиперпараметр KK выше, тем шире пространство перебора, и тем больше шансов найти не локально, а глобально оптимальное решение.

При KVTK\ge V^T, где VV- объём словаря, а TT - длина генерируемой последовательности, лучевой поиск сведётся к полному перебору (full search).

Случайная генерация

Если в языковом моделировании выбирать слова, обладающие максимальной вероятностью, то это часто будет приводить к генерации очень правильных, но слишком однообразных текстов, склонных к зацикливанию.

Чтобы этого не происходило, слова рекомендуется сэмлировать из распределения, предсказанного моделью, причём гиперпараметр температуры τ>0\tau>0 в SoftMax преобразовании управляет контрастностью выходных вероятностей.

Как именно?

При τ0\tau\to 0 сэмплирование по-прежнему сводится к выбору самого вероятного слова и генерирует текст максимально правильным, но однообразным. Увеличение τ\tau повышает разнообразие ценой уменьшения согласованности слов в тексте (или букв при посимвольной генерации).

Таким образом, стохастическая генерация последовательности слов (или других дискретных элементов) в рекуррентной сети происходит пошагово слово за словом:

  1. генерируется w1pτ(w)w_1\sim p_\tau(w);

  2. генерируется w2pτ(ww1)w_2\sim p_\tau(w|w_1);

  3. генерируется w3pτ(ww1w2)w_3\sim p_\tau(w|w_1w_2);

  4. генерируется w4pτ(ww1w2w3)w_4\sim p_\tau(w|w_1w_2w_3);

и т.д.

При этом, чтобы уменьшить риск генерации случайного слова, которое будет совсем выпадать из контекста, слова генерируют не из всего списка слов, а из подмножества максимально вероятных. Для этого есть два подхода:

  • Top-K sampling позволяет сэмплировать только KK слов, обладающих максимальной вероятностью.

  • Nucleus sampling: вместо задания KK задаётся пороговая вероятность PP. Слова генерируются не из всего списка, а из подмножества самых вероятных слов. Это подмножество формируется следующим образом: слова сортируются по убыванию их вероятности, и в допустимое включается минимальное число топ-K самых вероятных слов так, что их суммарная вероятность выше PP. Таким образом Nucleus sampling представляет собой разновидность top-K sampling с динамически изменяемым параметром K, способным более адаптивно подстраиваться под контекст.

Генерация слов продолжается, пока не будет сгенерирован специальный токен [EOS], означающий конец последовательности либо пока вероятность конца генерации (предсказываемая отдельным выходом сети) не станет больше порога.

Рейтинг последовательности

При генерации последовательности слов w1w2...wTw_1w_2...w_T можно по-разному оценивать её качество или рейтинг. Можно было бы оценивать как логарифм модельной вероятности сгенерированной цепочки:

S(w1w2...wT)=logP(w1w2...wT)=logP(w1)p(w2w1)...P(wTw1w2...wT1)=t=1TlogP(wtw1w2...wt1)\begin{align*} S(w_1w_2...w_T) &= \log P(w_1w_2...w_T) \\ &=\log P(w_1)p(w_2|w_1)...P(w_T|w_1w_2...w_{T-1}) \\ &=\sum_{t=1}^T \log P(w_t|w_1w_2...w_{t-1}) \\ \end{align*}

однако такой способ будет поощрять более короткие последовательности как приводящие к меньшему числу произведений малых вероятностей. Чтобы простимулировать метод генерировать и более длинные варианты тоже, рейтинг нормируют на длину выходной последовательности по формуле

S(w1w2...wT)=1Tαt=1TlogP(wtw1w2...wt1),\begin{align*} S(w_1w_2...w_T) = \frac{1}{T^\alpha}\sum_{t=1}^T \log P(w_t|w_1w_2...w_{t-1}), \end{align*}

где α0.75\alpha\sim 0.75 - гиперпараметр, управляющий тем, насколько длинные выходные последовательности для нас более предпочтительны.

Как именно?

При уменьшении α\alpha рейтинг длинных последовательностей будет получаться выше, и они будут выбираться чаще.

В расчёт рейтинга также можно добавлять другие условия, измеряющие то, насколько сгенерированный текст:

  • разнообразен (содержит много уникальных nn-грамм (т.е. nn подряд идущих слов или символов);

  • естественен (содержит длинные последовательности слов, которые реально встречаются в текстах).

Литература

  1. http://karpathy.github.io/2015/05/21/rnn-effectiveness/