很久很久以前,當計算機還不普及的時候,對一個複雜函數的求導一定是眾多學者的噩夢之一。試想,也許面對一個三元多項式函數時,你還可以遊刃有餘地對每一個變量求導數,但當函數被拓展到成千上萬元、成千上萬項的時候,你還有信心求的出來導嗎?於是,隨著計算機的普及,自動求導算法的提出幫助眾多學者從大型函數的求導中解放了出來。
深度學習的反向傳播也是一個典型的自動求導過程,而作為 pytorch 魔法力量的核心 Autograd,一定也被很多人好奇過其實現的原理。在本篇文章中,我將向大家解釋一個最為簡單的自動求道例子,希望可以抛磚引玉,啟發大家。
自動求道的原理#
每一個人在大一學習高等數學時,一定都聽說過一個名詞:鏈式法則。鏈式法則的數學表示如下:
可以看到, 是包含了 的函數, 是包含了 的函數,那麼為了求到 對 的導數, 我們可以先求 對 的導數 , 然後求 對 的導數 ,最後把兩個導數相乘就算出了我們想要的結果。
有人說,反向傳播就是鏈式法則的另外一個花哨的名字。確實,這其實就是自動求導的核心。鏈式法則告訴我們,面對一個複雜的函數,我們若是可以把它拆分為一節一節簡單的函數原子複合起來的結果,那我們對每一個簡單的函數原子求導,最後把導數乘起來就可以得到複雜函數的結果。那麼,深度學習最基礎的神經元, 是不是就是一個簡單的原子?
所以,自動求導和思路也就清晰了。我們通過前向傳播記錄每一個變量到最終函數的路徑,然後我們沿著這條路徑從函數返回,就可以得到函數對該變量的導數了。
利用二叉樹設計一個自動求導機#
在本部分我們將利用二叉樹設計一個最簡單的自動求導機。我們假設對一個簡單的函數求導:
將算式轉換為後綴表達式#
上述函數為一個中綴表達式。中綴表達式是一種方便人類認知的表達式,但其不方便計算機讀取。為了讓計算機得到這個函數的解析樹,我們需要先使用調度場算法將函數從中綴表達式轉換為後綴表達式。
將後綴表達式轉換為解析樹#
將後綴表達式轉換為解析樹就是一個非常簡單的過程了,我們只需要對後綴表達式作一次遍歷即可。遍歷的規則如下:
- 當讀取到的不為運算符時,將讀取的字符作為解析樹的一個節點存入棧內。
- 當讀取到的為運算符時,將讀取的字符作為解析樹的一個節點,取出棧頂的兩個節點作為其的左右子節點,然後存入棧內。
例如我們有一個表達式為 ,他所對應的後綴表達式為:
3 4 2 * 1 5 − 2 3 ^ ^ / +
於是我們構建解析樹的過程如下:
輸入 | 棧 |
---|---|
3 | (3) |
4 | (3),(4) |
2 | (3),(4),(2) |
x | (3),(4 x 2) |
1 | (3),(4 x 2),(1) |
5 | (3),(4 x 2),(1),(5) |
- | (3),(4 x 2),(1 - 5) |
2 | (3),(4 x 2),(1 - 5),(2) |
3 | (3),(4 x 2),(1 - 5),(2),(3) |
^ | (3),(4 x 2),(1 - 5),(2 ^ 3) |
^ | (3),(4 x 2),((1 - 5) ^ (2 ^ 3)) |
/ | (3),((4 x 2) / ((1 - 5) ^ (2 ^ 3))) |
+ | (3 + (4 x 2) / ((1 - 5) ^ (2 ^ 3))) |
代碼大概是這樣的
struct ParseTree{
string data; // 存的數據
int result = -1; // 子樹計算的結果
int diff = -1; // 自動求導值
struct ParseTree* left = nullptr;
struct ParseTree* right = nullptr;
};
ParseTree* parse(){
for (int I = 0; I < o_vector.size(); I++){
if (isOperator(o_vector[I])) {
ParseTree* s1 = parse_stack.top();
parse_stack.pop();
ParseTree* s2 = parse_stack.top();
parse_stack.pop();
ParseTree* parseNode = new ParseTree();
parseNode->data = o_vector[I];
parseNode->left = s1;
parseNode->right = s2;
parse_stack.push(parseNode);
} else {
ParseTree* parseLeaf = new ParseTree();
parseLeaf->data = o_vector[I];
parse_stack.push(parseLeaf);
}
}
return parse_stack.top();
}
正向傳播#
我們得到的解析樹之後,根節點就相當於是函數本身,內節點相當於是每一個計算符號,葉節點相當於每一個變量。正向傳播的過程,實際就是從葉節點開始,經過每一個內節點最後計算到根節點的過程。由於二叉樹的特點,每一個節點只有兩個字節點,也就是說每一次內節點只會涉及一個符號,兩個變量。所以整個正向傳播過程就是從根節點的一個遞歸過程。
int calculate(ParseTree* head) {
if (head->left != nullptr || head->right != nullptr) { // head is a node
head->result =node_calculate((head->data)[0], calculate(head->left), calculate(head->right)); // node_calculate() is a function to calculate the result
return head->result;
}
return stoi(head->data);
}
反向傳播#
我們正向傳播後,每一個內節點,儘管其可能包含的是一個運算符號,它也存儲了以它為根節點的子樹的運算結果。所以當對它的父節點作運算時,可以將它看作一個數字。所以方向傳播的過程就變成了從根節點開始的,每一次算字節點的導數,最後導數累加的過程。這個過程希望讀者可以自己來實現代碼。
當讀者實現完這個部分的代碼後,一段完整的自動求導代碼也就完成了。