Soptq

Soptq

Probably a full-stack, mainly focusing on Distributed System / Consensus / Privacy-preserving Tech etc. Decentralization is a trend, privacy must be protected.
twitter
github
bilibili

簡易自動求導機

很久很久以前,當計算機還不普及的時候,對一個複雜函數的求導一定是眾多學者的噩夢之一。試想,也許面對一個三元多項式函數時,你還可以遊刃有餘地對每一個變量求導數,但當函數被拓展到成千上萬元、成千上萬項的時候,你還有信心求的出來導嗎?於是,隨著計算機的普及,自動求導算法的提出幫助眾多學者從大型函數的求導中解放了出來。

深度學習的反向傳播也是一個典型的自動求導過程,而作為 pytorch 魔法力量的核心 Autograd,一定也被很多人好奇過其實現的原理。在本篇文章中,我將向大家解釋一個最為簡單的自動求道例子,希望可以抛磚引玉,啟發大家。

自動求道的原理#

每一個人在大一學習高等數學時,一定都聽說過一個名詞:鏈式法則。鏈式法則的數學表示如下:

dudx=dudydydx\frac {du} {dx}=\frac {du} {dy}\cdot \frac {dy} {dx}

可以看到, uu 是包含了 yy 的函數, yy 是包含了 xx 的函數,那麼為了求到 uuxx 的導數, 我們可以先求 uuyy 的導數 dudy\frac{du}{dy}, 然後求 yyxx 的導數 dydx\frac{dy}{dx} ,最後把兩個導數相乘就算出了我們想要的結果。

有人說,反向傳播就是鏈式法則的另外一個花哨的名字。確實,這其實就是自動求導的核心。鏈式法則告訴我們,面對一個複雜的函數,我們若是可以把它拆分為一節一節簡單的函數原子複合起來的結果,那我們對每一個簡單的函數原子求導,最後把導數乘起來就可以得到複雜函數的結果。那麼,深度學習最基礎的神經元, y=wx+by = wx + b 是不是就是一個簡單的原子?

所以,自動求導和思路也就清晰了。我們通過前向傳播記錄每一個變量到最終函數的路徑,然後我們沿著這條路徑從函數返回,就可以得到函數對該變量的導數了。

自動求導路徑

利用二叉樹設計一個自動求導機#

在本部分我們將利用二叉樹設計一個最簡單的自動求導機。我們假設對一個簡單的函數求導:

z=3x2+5x+lgyz=3x^{2}+5x+\lg y

將算式轉換為後綴表達式#

上述函數為一個中綴表達式。中綴表達式是一種方便人類認知的表達式,但其不方便計算機讀取。為了讓計算機得到這個函數的解析樹,我們需要先使用調度場算法將函數從中綴表達式轉換為後綴表達式。

將後綴表達式轉換為解析樹#

將後綴表達式轉換為解析樹就是一個非常簡單的過程了,我們只需要對後綴表達式作一次遍歷即可。遍歷的規則如下:

  • 當讀取到的不為運算符時,將讀取的字符作為解析樹的一個節點存入棧內。
  • 當讀取到的為運算符時,將讀取的字符作為解析樹的一個節點,取出棧頂的兩個節點作為其的左右子節點,然後存入棧內。

例如我們有一個表達式為 3+4×2(15)233+4\times \frac {2} {\left( 1-5\right) ^{2^{3}}} ,他所對應的後綴表達式為:

3 4 2 * 1 5 − 2 3 ^ ^ / + 

於是我們構建解析樹的過程如下:

輸入
3(3)
4(3),(4)
2(3),(4),(2)
x(3),(4 x 2)
1(3),(4 x 2),(1)
5(3),(4 x 2),(1),(5)
-(3),(4 x 2),(1 - 5)
2(3),(4 x 2),(1 - 5),(2)
3(3),(4 x 2),(1 - 5),(2),(3)
^(3),(4 x 2),(1 - 5),(2 ^ 3)
^(3),(4 x 2),((1 - 5) ^ (2 ^ 3))
/(3),((4 x 2) / ((1 - 5) ^ (2 ^ 3)))
+(3 + (4 x 2) / ((1 - 5) ^ (2 ^ 3)))

代碼大概是這樣的

struct ParseTree{
	string data; // 存的數據
	int result = -1; // 子樹計算的結果
	int diff = -1; // 自動求導值
	struct ParseTree* left = nullptr;
	struct ParseTree* right = nullptr;
};

ParseTree* parse(){
	for (int I = 0; I < o_vector.size(); I++){
		if (isOperator(o_vector[I])) {
			ParseTree* s1 = parse_stack.top();
			parse_stack.pop();
			ParseTree* s2 = parse_stack.top();
			parse_stack.pop();
			ParseTree* parseNode = new ParseTree();
			parseNode->data = o_vector[I];
			parseNode->left = s1;
			parseNode->right = s2;
			parse_stack.push(parseNode);
		} else {
			ParseTree* parseLeaf = new ParseTree();
			parseLeaf->data = o_vector[I];
			parse_stack.push(parseLeaf);
		}
	}
	return parse_stack.top();
}

正向傳播#

我們得到的解析樹之後,根節點就相當於是函數本身,內節點相當於是每一個計算符號,葉節點相當於每一個變量。正向傳播的過程,實際就是從葉節點開始,經過每一個內節點最後計算到根節點的過程。由於二叉樹的特點,每一個節點只有兩個字節點,也就是說每一次內節點只會涉及一個符號,兩個變量。所以整個正向傳播過程就是從根節點的一個遞歸過程。

int calculate(ParseTree* head) {
	if (head->left != nullptr || head->right != nullptr) { // head is a node
		head->result =node_calculate((head->data)[0], calculate(head->left), calculate(head->right)); // node_calculate() is a function to calculate the result
		return head->result;
	}
	return stoi(head->data);
}

反向傳播#

我們正向傳播後,每一個內節點,儘管其可能包含的是一個運算符號,它也存儲了以它為根節點的子樹的運算結果。所以當對它的父節點作運算時,可以將它看作一個數字。所以方向傳播的過程就變成了從根節點開始的,每一次算字節點的導數,最後導數累加的過程。這個過程希望讀者可以自己來實現代碼。

當讀者實現完這個部分的代碼後,一段完整的自動求導代碼也就完成了。

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。