Backpropagation(BP) 倒傳遞法 #1 工作原理與說明

本篇會介紹在機器學習(machine learning)與深度學習(deep learning)領域裡很流行的倒傳遞法(Back Propagation/ Backpropagation, BP)的精髓:梯度下降法(Gradient Descent)、連鎖率(Chain Rule)

你想要知道該如何以Python實作BP並應用於優化層類神經網路可以讀這篇:Backpropagation(BP) 倒傳遞法 #2 貓貓分類器-2層類神經網路;你想要知道該如何優化多層類神經網路可以讀這篇:Backpropagation(BP) 倒傳遞法 #3 貓貓分類器-N層類神經網路

倒傳遞法(Backpropagation),這是一個很多學者在同一個年代都有發表過的最佳化演算法,其中包括鼎鼎大名的 Rumelhart 與 Hinton 在1986發表的『Learning representations by back-propagation errors』與更早幾年歸納這個方法的Webors於1974所發表的博士學位論文也有提到。BP是一種可大致分為正向傳遞(Forward pass)與反向傳遞(Backward pass),其中又結合梯度下降法(Gradient Descent)和微積分中的連鎖率(Chain Rule)而成的最佳化演算法。

那什麼是梯度下降呢?
就直接從梯度下降法切入Backpropagation吧!

梯度下降法

梯度下降法的基礎概念

贊助廣告

假設最佳化目標:對成本函數最小化

而梯度下降法就從參數初始位置朝向最陡的下坡方向前進並更新參數位置,獲得更新後的參數最終可以帶來降低成本的效果。那獲得坡度資訊的方法就是使用導函數(精確地來說應該是偏微分)。微積分應該有學到,對函數求得一階導函數可以獲得斜率函數,梯度下降法就是運用這個特性來優化成本函數。

 

梯度下降法的數學基礎

假設最佳化目標:對成本函數($J(w)$)最小化
先只用一個參數來看梯度下降法。可以注意到,這邊設定成本函數 $J$ 擁有一個輸入值 $w$,就是要解釋微分帶來的效果。
若成本函數$J$是一個拋物線,如下圖(1):

Gradient Descent cost function J(w)
圖(1):成本函數$J$

若要找到能產出最小成本的$w$,就是必須要不斷的改變$w$帶入$J$來嘗試。
但是電腦沒有上帝視角,所以我們透過微分讓電腦知道應該要朝哪個方向來方法更新 $w$,因此產生了下面的更新公式(1)。從公式(1)可以發現有一個未曾看過的$\alpha$,這是學習速率(learning rate),用來控制學習步伐的參數,數值通常是介於$0$到$1$之間。

$w = w – \alpha \frac{dJ(w)}{dw}$ $(1)$

公式(1)為我們帶來使用梯度更新$w$的概念,所以我們可以用下圖(2)來理解。可以發現紫色的三角形就是我們每次計算出來微分值$\frac{dJ(w)}{dw}$,而這張圖是建立在當微分值$\frac{dJ(w)}{dw}$大於$0$的情況。

Gradient Descent J(w) when positive slope圖(2):梯度下降示意圖(當斜率為大於$0$)

如果$w$的初始值在比較接近原點的地方呢?可以用左半邊的線段斜率是負數來幫助理解,所以微分值就會小於$0$,但是稍微計算一下就知道就如果是$\frac{dJ(w)}{dw}< 0$帶入更新公式(1)計算的結果會讓$w$數值變大,也就是會讓$w$趨近於$J(w)$較小的方向!(如圖(3)左半部所示)

Gradient Descent J(w) when negative slope
圖(3):梯度下降示意圖(當斜率小於$0$)

 

梯度下降法

假設優化目標:對成本函數 $J(w,b)$ 最小化

上面已經介紹過一個參數的成本函數,但是在真實應用上成本函數中不會只有一個參數
因為要做微分的目標不只一個,此時就必須要使用到偏微分的方法,才能知道參數$w$、$b$分別對於成本函數$J$的影響。
偏微分記號我們以$\partial$表示。更新公式則修改成如下公式(2)和公式(3):

$w=w-\alpha \frac{\partial J(w,b)}{\partial w}$ $(2)$
$b=b-\alpha \frac{\partial J(w,b)}{\partial b}$ $(3)$

梯度下降法,就是利用公式(2)、(3)的概念對成本函數進行優化,經過若干迭代之後就可以得到優化後的參數$w$和$b$。

 

連鎖率

連鎖率是Backward pass的精髓,一定要懂!
下方圖(3)為成本函數$J(a, b, c)$的計算圖(computation graph),我們可以藉由這樣的圖來理解成本函數的計算過程,以及連鎖率。

首先,順著計算圖的流程分別先計算$u$、$v$最後算出成本函數$J$可以獲得成本$J$。(其實整個計算成本的過程就是Forward pass)

Computation graph for backward pass
圖(4):簡易計算圖

倘若欲優化成本函數$J$,就勢必要優化參數$a$、$b$、$c$。
所以要透過上述的梯度下降法優化$J$,就要倒著計算圖的順序來找出這幾個參數對$J$的偏微分值,如此一來便可知道這三個參數改變一點點的話,對於成本函數的影響是多少
因此我們得從$J$回推到參數$a$、$b$、$c$,倒著計算圖的順序可以先看到$J=3v$,用偏微分可計算出若給$v$一點點變動,能夠影響多少$J$的變動量為:$\frac{\partial J}{\partial v}$。

接續著來看$v = a + u$,分別用偏微分計算更改$a$一點點,會$v$對產生多少影響:$\frac{\partial v}{\partial a}$,想當然$u$對$v$的影響就是:$\frac{\partial v}{\partial u}$。

再來就是$u=bc$,依照上述偏微分的計算方式可以得知$b$對$u$的影響就是$\frac{\partial u}{\partial b}$,而$c$對$u$的影響則是$\frac{\partial u}{\partial c}$。

現在順著計算圖的方向來看誰會影響誰
$a$會影響$v$,而$v$又會影響$J$,我們可以這樣表示:$a\rightarrow v\rightarrow J$
$b\rightarrow u\rightarrow v\rightarrow J$
$c\rightarrow u\rightarrow v\rightarrow J$

既然優化成本函數時就必須要計算各個參數對成本函數的影響量($\frac{\partial J}{\partial a}$、$\frac{\partial J}{\partial b}$、$\frac{\partial J}{\partial c}$),那我們就可以透過這個上述『誰影響誰』的路徑來計算:

$\frac{\partial J}{\partial a}=\frac{\partial v}{\partial a}\frac{\partial J}{\partial v}$ $(4)$
$\frac{\partial J}{\partial b}=\frac{\partial u}{\partial b}\frac{\partial v}{\partial u}\frac{\partial J}{\partial v}$ $(5)$
$\frac{\partial J}{\partial c}=\frac{\partial u}{\partial c}\frac{\partial v}{\partial u}\frac{\partial J}{\partial v}$ $(6)$

像這樣子的影響鏈就是微積分這門學問裡的連鎖率(Chain Rule)
依照圖(4)計算圖的流程,我們是可以算出各參數分別對成本函數偏微分值分別是多少:$\frac{\partial J}{\partial a}=3$、$\frac{\partial J}{\partial b}=6$、$\frac{\partial J}{\partial c}=9$
最終獲得$\frac{\partial J}{\partial a}$、$\frac{\partial J}{\partial b}$、$\frac{\partial J}{\partial c}$,就可以用來更新參數了!
(其實,計算$\frac{\partial J}{\partial a}$、$\frac{\partial J}{\partial b}$、$\frac{\partial J}{\partial c}$的過程就是Backward pass)

 

References
  1. Andrew Ng – Neural Networks & Deep Learning in Coursera
  2. (paper) Learning representations by back-propagation errors

 

Andy Wang

站在巨人的肩膀上仍須戰戰兢兢!

One thought on “Backpropagation(BP) 倒傳遞法 #1 工作原理與說明

發表迴響

這個網站採用 Akismet 服務減少垃圾留言。進一步了解 Akismet 如何處理網站訪客的留言資料