Перейти к основному содержимому

Метод обратного распространения ошибки

Все градиентные методы настройки нейросетей основаны вычислении градиента L(w)\nabla L(w) от функции потерь. Для эффективного обучения нейросетей, вычисление градиента должно производиться

  • точно

  • и вычислительно эффективно.

Можно использовать различные подходы для его вычисления, описанные далее.

Дифференцирование напрямую

Искать производную напрямую можно двумя способами:

  1. Можно вычислять градиент вручную и программно реализовывать расчёт найденных производных.

  2. Можно доверить вычисление градиента библиотекам символьного дифференцирования.

Оба метода дадут точное значение градиента. Второй метод предпочтительнее, поскольку избавит нас от ошибок вычисления производных. Однако недостатками подходов является сильное разрастание формул, по которым будет вычисляться градиент.

Рассмотрим нахождение производной по ww от функции A(w)B(w)C(w)A(w)B(w)C(w). Производная будет:

[A(w)B(w)C(w)]=A(w)B(w)C(w)+A(w)B(w)C(w)+A(w)B(w)C(w)\left[A(w)B(w)C(w)\right]'=A'(w)B(w)C(w)+A(w)B'(w)C(w)+A(w)B(w)C'(w)

в которой требуется многократное повторное перевычисление функций A(w),B(w),C(w)A(w),B(w),C(w), что неэффективно.

Если подобные операции возникают на каждом слое нейросети, то это приведёт к экспоненциальному разрастанию объёма вычислений.

Численная аппроксимация

Градиент можно находить, используя численное приближение производных:

L(w)=[L(w)w1L(w)w2L(w)wK][L(w+Δ1w)L(w)εL(w+Δ2w)L(w)εL(w+ΔKw)L(w)ε]\nabla L(w)=\begin{bmatrix} \frac{\partial L(w)}{\partial w_1} \\ \frac{\partial L(w)}{\partial w_2} \\ \vdots \\ \frac{\partial L(w)}{\partial w_K} \\ \end{bmatrix} \approx \begin{bmatrix} \frac{L(w+\Delta_1 w)-L(w)}{\varepsilon} \\ \frac{L(w+\Delta_2 w)-L(w)}{\varepsilon} \\ \vdots \\ \frac{L(w+\Delta_K w)-L(w)}{\varepsilon} \\ \end{bmatrix}

где Δiw=[w1,...wi1,wi+ε,wi+1,...wK]\Delta_i w = [w_1,...w_{i-1},w_i+\varepsilon,w_{i+1},...w_K], а ε\varepsilon - некоторая малая константа. Этот подход избавлен от ошибок дифференцирования, но даст лишь приближённое значение производных. Также он вычислительно неэффективен, поскольку требует K+1K+1 прохода по нейросети (для K+1K+1 значений весов, учитывая расчёт для немодифицированных значений).

Поскольку стоимость одного прохода вперёд по сети имеет порядок O(K)O(K), то общая сложность вычисления всего градиента - O(K2)O(K^2), что много для современных нейросетей, имеющих миллионы настраиваемых параметров.

Автоматическое дифференцирование

Вычислить производные можно точно и эффективно как по объёму вычислений, так и по памяти, используя метод автоматического дифференцирования (automatic differentiation). Существуют библиотеки, эффективно реализующие этот метод используя процессор и видеокарты, такие как pytorch, tensorflow и JAX.

Задача метода

Цель автоматического дифференцирования - не вывести общую функциональную формулу производной, а уметь быстро вычислять её значение в заданной точке, используя программный код для построения нейросетевого прогноза.

Расчёт функции потерь L(w)\mathcal{L}(w) сопряжено с вычислением суперпозиции большого числа математических преобразований, вызванных как расчётом нейросетевого прогноза, так и вычислением самой функции потерь от него.

Автоматическое дифференцирование основано на формуле расчёта производной сложной функции

L(w)=A(B(w))\mathcal{L}(w)=A(B(w)) L(w)w=A(B)BBw,\frac{\partial \mathcal{L}(w)}{\partial w}=\frac{\partial A(B)}{\partial B}\frac{\partial B}{\partial w},

Для функции

L(w)=A(B1(w),B2(w),...BM(w))\mathcal{L}(w)=A(B_1(w),B_2(w),...B_M(w))

она будет

L(w)w=AB1B1w+AB2B2w+...+ABMBMw\frac{\partial \mathcal{L}(w)}{\partial w}=\frac{\partial A}{\partial B_1}\frac{\partial B_1}{\partial w}+\frac{\partial A}{\partial B_2}\frac{\partial B_2}{\partial w}+...+\frac{\partial A}{\partial B_M}\frac{\partial B_M}{\partial w}

Поскольку нейросети включают в себя суперпозицию большого числа преобразований, необходимо вычислять производную от большого числа вложенных функций

Например для

L(w)=A(B(C(D(w))))\mathcal{L}(w)=A(B(C(D(w))))

формула дифференцирования сложной функции тогда будет:

L(w)w=A(B)BB(C)CC(D)DD(w)w.\frac{\partial \mathcal{L}(w)}{\partial w}=\frac{\partial A(B)}{\partial B}\frac{\partial B(C)}{\partial C}\frac{\partial C(D)}{\partial D}\frac{D(w)}{\partial w}.

Делать это можно двумя способами:

  • в режиме прохода вперёд (forward-mode).

  • в режиме прохода назад (backward-mode).

В обоих подходах вычисление функции представляется в виде графа вычислений (computatonal graph), в котором узлами являются промежуточные переменные, необходимые для расчёта, а связи указывают, какие промежуточные переменные от каких зависят. Каждая переменная в графе вычислений - это некоторая элементарная операция от ранее посчитанных переменных, какая как сумма, разность, вычисление стандартной функции и т.д.

На самом деле операция может быть и не элементарной, а любой, но по которой мы можем просто вычислить производную, зависящую от

Рассмотрим простейшую нейросеть - линейную регрессию, настраиваемую с L2L_2 регуляризацией:

y^=wTxL(w)=(wTxy)2+λwTw\begin{align*} \hat{y} &= w^T x \\ \mathcal{L}(w) &= (w^T x-y)^2+\lambda w^T w \\ \end{align*}

Тогда вычисление L(w)\mathcal{L}(w) можно представить в виде следующего вычислительного графа:

Foward-mode

Рассмотрим эффективный расчёт градиента функции потерь в методом foward-mode.

Он основан только на проходе вперёд (forward pass) по графу, при котором вычисляются не только значения промежуточных переменных графа, но и производные каждой промежуточной переменной по вектору весов сети ww. В каждую найденную производную подставляются значения промежуточных переменных графа, от которых она зависит, в результате чего сразу получаем промежуточное значение производной. Значение производной по финальной переменной и есть целевой градиент.

Пример использования

Рассмотрим вычисление потерь

L(w)=(wTxy)2+λwTw\mathcal{L}(w) = (w^T x-y)^2+\lambda w^T w

Построим по ней вычислительный граф:

Пусть нам нужно вычислить wL(w)\nabla_w \mathcal{L}(w) при

y=0,x=[1,2],w=[3,4],λ=5.y=0, x=[1,2], w=[3,4], \lambda=5.

При проходе вперёд по графу вычислений получаем следующие значения промежуточных переменных:

a=3+8=11b=ay=11c=b2=121d=9+16=25e=λd=125L=c+e=121+125=246\begin{align*} a &= 3+8=11 \\ b &= a-y=11 \\ c &= b^{2}=121 \\ d &= 9+16=25 \\ e &= \lambda d=125 \\ \mathcal{L} &=c+e=121+125=246 \end{align*}

Но одновременно при проходе вперёд вычисляются и производные по целевой переменной (вектору весов ww):

aw=x=[1,2]\frac{\partial a}{\partial w}=x=[1,2] dw=2w=[6,8]\frac{\partial d}{\partial w}=2w=[6,8] bw=b(a)w=baaw=1[1,2]=[1,2]\frac{\partial b}{\partial w}=\frac{\partial b(a)}{\partial w}=\frac{\partial b}{\partial a}\frac{\partial a}{\partial w}=1\cdot[1,2]=[1,2] ew=e(d)w=eddw=λ[6,8]=[30,40]\frac{\partial e}{\partial w}=\frac{\partial e(d)}{\partial w}=\frac{\partial e}{\partial d}\frac{\partial d}{\partial w}=\lambda\cdot[6,8]=[30,40] cw=c(b)w=cbbw=2b[1,2]=22[1,2]=[22,44]\frac{\partial c}{\partial w}=\frac{\partial c(b)}{\partial w}=\frac{\partial c}{\partial b}\frac{\partial b}{\partial w}=2b\cdot[1,2]=22\cdot[1,2]=[22,44] Lw=L(c,e)w=Lccw+Leew=1[22,44]+1[30,40]=[52,84]\frac{\partial\mathcal{L}}{\partial w}=\frac{\partial\mathcal{L}(c,e)}{\partial w}=\frac{\partial\mathcal{L}}{\partial c}\frac{\partial c}{\partial w}+\frac{\partial\mathcal{L}}{\partial e}\frac{\partial e}{\partial w}=1\cdot[22,44]+1\cdot[30,40]=[52,84]

Градиент равен 2-мерному вектору, поскольку представляет вектор из частных производных

[Lw1,Lw2]\left[ \frac{\partial\mathcal{L}}{\partial w_{1}},\frac{\partial\mathcal{L}}{\partial w_{2}}\right ]

Backward-mode (обратное распространение ошибки)

Теперь рассмотрим другой способ эффективного расчёта градиента потерь по весам нейросети - backward-mode.

Метод обратного распространения ошибки

Этот метод также называется методом обратного распространения ошибки (backpropagation, backprop) и является основным методом для настройки весов нейросети.

Этот метод состоит из двух шагов:

  1. Проход вперёд (forward pass), на котором итеративно слева-направо вычисляются (и запоминаются!) значения всех промежуточных переменных графа вычислений. Производные по весам при этом не вычисляются.

  2. Проход назад (backward pass), на котором итеративно справа-налево вычисляются производные текущих переменных графа от более ранних переменных (от которых зависят текущие), как функции от переменных графа вычислений. Далее подставляются значения этих переменных, в результате чего производные вычисляются (и используются) как числа.

Пример использования

Рассмотрим снова вычисление потерь

L(w)=(wTxy)2+λwTw\mathcal{L}(w) = (w^T x-y)^2+\lambda w^T w

Построим по ней вычислительный граф:

Пусть, как и раньше, нам нужно вычислить wL(w)\nabla_w \mathcal{L}(w) при

y=0,x=[1,2],w=[3,4],λ=5.y=0, x=[1,2], w=[3,4], \lambda=5.

При проходе вперёд (forward pass) по графу вычислений получаем следующие значения промежуточных переменных:

a=3+8=11b=ay=11c=b2=121d=9+16=25e=λd=125L=c+e=121+125=246\begin{align*} a &= 3+8=11 \\ b &= a-y=11 \\ c &= b^{2}=121 \\ d &= 9+16=25 \\ e &= \lambda d=125 \\ \mathcal{L} &=c+e=121+125=246 \end{align*}

Эти переменные запоминаются для использования при проходе назад (backward pass), в которых рекуррентно вычисляются производные итоговых потерь по промежуточным переменным графа от самых последних назад к первым (справа-налево).

Покажем, как осуществляется проход назад:

Lc=1,Le=1\frac{\partial\mathcal{L}}{\partial c}=1,\quad\frac{\partial\mathcal{L}}{\partial e}=1 Lb=L(c)b=Lccb=12b=22\frac{\partial\mathcal{L}}{\partial b}=\frac{\partial\mathcal{L}(c)}{\partial b}=\frac{\partial\mathcal{L}}{\partial c}\frac{\partial c}{\partial b}=1\cdot2b=22 Ld=L(e)d=Leed=1λ=5\frac{\partial\mathcal{L}}{\partial d}=\frac{\partial\mathcal{L}(e)}{\partial d}=\frac{\partial\mathcal{L}}{\partial e}\frac{\partial e}{\partial d}=1\cdot\lambda=5 La=L(b)a=Lbba=221=22\frac{\partial\mathcal{L}}{\partial a}=\frac{\partial\mathcal{L}(b)}{\partial a}=\frac{\partial\mathcal{L}}{\partial b}\frac{\partial b}{\partial a}=22\cdot1=22 Lw=L(a,d)w=Laaw+Lddw=22x+52w=22[1,2]+10[3,4]=[22,44]+[30,40]=[52,84]\begin{align*} \frac{\partial\mathcal{L}}{\partial w}&=\frac{\partial\mathcal{L}(a,d)}{\partial w} \\ &=\frac{\partial\mathcal{L}}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial\mathcal{L}}{\partial d}\frac{\partial d}{\partial w} \\ &=22\cdot x+5\cdot2w \\ &=22\cdot[1,2]+10\cdot[3,4] \\ &=[22,44]+[30,40]=[52,84] \end{align*}

Как видим, получили тот же результат, что и при режиме forward-mode.

Отличия forward-mode и backward-mode

Тем не менее, у forward-mode и backward-mode есть вычислительные отличия:

Преимуществом forward-mode является то, что и расчёт промежуточных переменных, и расчёт производных производится всего за один проход (forward pass). Это снижает расходуемый объём памяти, поскольку значения промежуточных переменных достаточно помнить только до их последнего использования, а дальше можно высвобождать память для других вычислений.

Расчёт градиента через backward-mode состоит из двух проходов:

  1. прохода вперёд, во время которого нужно запомнить значения всех промежуточных переменных;

  2. прохода назад, на котором вычисляются производные потерь по промежуточным переменным.

Таким образом, перед проходом назад нужно держать в памяти значения всех промежуточных переменных вычислительного графа, вследствие чего расходы на память у backward-mode увеличиваются.

Если рассматривать расходы на вычисления, то для сужающегося вычислительного графа, как на рисунке ниже, эффективнее backward-mode:

Мы это видели в примере, где для forward-mode все операции были векторными (по сути, forward-mode требует отдельных вычислений для каждого веса), в то время как в backward-mode большая часть операций были скалярными, и лишь в конце появились векторы.

Для расширяющегося же графа вычислений, как на рисунке ниже, вычислительно эффективнее forward-mode:

Итоговая рекомендация

На практике обычно работают с архитектурами с сужающимся графом вычислений, которые в конце упираются в единственную скалярную функцию потерь. Поэтому при настройке нейросетей используют backward-mode (он же метод обратного распространения ошибки).

Тем не менее, бывают случаи, когда forward-mode полезен. Рассмотрим автокодировщик, минимизирующий среднеквадратичные потери на каждом выходе:

Пусть нам требуется посчитать матрицу производных {x^ixj}ij\left\{\frac{\partial \hat{x}_i}{\partial x_j}\right\}_{ij}.

По формуле производной сложной функции имеем:

x^x=x^(e)x=x^eex\frac{\partial\widehat{x}}{\partial x}=\frac{\partial\widehat{x}(e)}{\partial x}=\frac{\partial\widehat{x}}{\partial e}\frac{\partial e}{\partial x}

Тогда производные ex\frac{\partial e}{\partial x} эффективнее посчитать через backward-mode, а x^e\frac{\partial\widehat{x}}{\partial e} - через forward mode.