Soptq

Soptq

Probably a full-stack, mainly focusing on Distributed System / Consensus / Privacy-preserving Tech etc. Decentralization is a trend, privacy must be protected.
twitter
github
bilibili

简易自动求导机

很久很久以前,当计算机还不普及的时候,对一个复杂函数的求导一定是众多学者的噩梦之一。试想,也许面对一个三元多项式函数时,你还可以游刃有余地对每一个变量求导数,但当函数被拓展到成千上万元、成千上万项的时候,你还有信心求的出来导吗?于是,随着计算机的普及,自动求导算法的提出帮助众多学者从大型函数的求导中解放了出来。

深度学习的反向传播也是一个典型的自动求导过程,而作为 pytorch 魔法力量的核心 Autograd,一定也被很多人好奇过其实现的原理。在本篇文章中,我将向大家解释一个最为简单的自动求道例子,希望可以抛砖引玉,启发大家。

自动求道的原理#

每一个人在大一学习高等数学时,一定都听说过一个名词:链式法则。链式法则的数学表示如下:

dudx=dudydydx\frac {du} {dx}=\frac {du} {dy}\cdot \frac {dy} {dx}

可以看到, uu 是包含了 yy 的函数, yy 是包含了 xx 的函数,那么为了求到 uuxx 的导数, 我们可以先求 uuyy 的导数 dudy\frac{du}{dy}, 然后求 yyxx 的导数 dydx\frac{dy}{dx} ,最后把两个导数相乘就算出了我们想要的结果。

有人说,反向传播就是链式法则的另外一个花哨的名字。确实,这其实就是自动求导的核心。链式法则告诉我们,面对一个复杂的函数,我们若是可以把它拆分为一节一节简单的函数原子复合起来的结果,那我们对每一个简单的函数原子求导,最后把导数乘起来就可以得到复杂函数的结果。那么,深度学习最基础的神经元, y=wx+by = wx + b 是不是就是一个简单的原子?

所以,自动求导和思路也就清晰了。我们通过前向传播记录每一个变量到最终函数的路径,然后我们沿着这条路径从函数返回,就可以得到函数对该变量的导数了。

自动求导路径

利用二叉树设计一个自动求导机#

在本部分我们将利用二叉树设计一个最简单的自动求导机。我们假设对一个简单的函数求导:

z=3x2+5x+lgyz=3x^{2}+5x+\lg y

将算式转换为后缀表达式#

上述函数为一个中缀表达式。中缀表达式是一种方便人类认知的表达式,但其不方便计算机读取。为了让计算机得到这个函数的解析树,我们需要先使用调度场算法将函数从中缀表达式转换为后缀表达式。

将后缀表达式转换为解析树#

将后缀表达式转换为解析树就是一个非常简单的过程了,我们只需要对后缀表达式作一次遍历即可。遍历的规则如下:

  • 当读取到的不为运算符时,将读取的字符作为解析树的一个节点存入栈内。
  • 当读取到的为运算符时,将读取的字符作为解析树的一个节点,取出栈顶的两个节点作为其的左右子节点,然后存入栈内。

例如我们有一个表达式为 3+4×2(15)233+4\times \frac {2} {\left( 1-5\right) ^{2^{3}}} ,他所对应的后缀表达式为:

3 4 2 * 1 5 − 2 3 ^ ^ / + 

于是我们构建解析树的过程如下:

输入
3(3)
4(3),(4)
2(3),(4),(2)
x(3),(4 x 2)
1(3),(4 x 2),(1)
5(3),(4 x 2),(1),(5)
-(3),(4 x 2),(1 - 5)
2(3),(4 x 2),(1 - 5),(2)
3(3),(4 x 2),(1 - 5),(2),(3)
^(3),(4 x 2),(1 - 5),(2 ^ 3)
^(3),(4 x 2),((1 - 5) ^ (2 ^ 3))
/(3),((4 x 2) / ((1 - 5) ^ (2 ^ 3)))
+(3 + (4 x 2) / ((1 - 5) ^ (2 ^ 3)))

代码大概是这样的

struct ParseTree{
	string data; // 存的数据
	int result = -1; // 子树计算的结果
	int diff = -1; // 自动求导值
	struct ParseTree* left = nullptr;
	struct ParseTree* right = nullptr;
};

ParseTree* parse(){
	for (int I = 0; I < o_vector.size(); I++){
		if (isOperator(o_vector[I])) {
			ParseTree* s1 = parse_stack.top();
			parse_stack.pop();
			ParseTree* s2 = parse_stack.top();
			parse_stack.pop();
			ParseTree* parseNode = new ParseTree();
			parseNode->data = o_vector[I];
			parseNode->left = s1;
			parseNode->right = s2;
			parse_stack.push(parseNode);
		} else {
			ParseTree* parseLeaf = new ParseTree();
			parseLeaf->data = o_vector[I];
			parse_stack.push(parseLeaf);
		}
	}
	return parse_stack.top();
}

正向传播#

我们得到的解析树之后,根节点就相当于是函数本身,内节点相当于是每一个计算符号,叶节点相当于每一个变量。正向传播的过程,实际就是从叶节点开始,经过每一个内节点最后计算到根节点的过程。由于二叉树的特点,每一个节点只有两个字节点,也就是说每一次内节点只会涉及一个符号,两个变量。所以整个正向传播过程就是从根节点的一个递归过程。

int calculate(ParseTree* head) {
	if (head->left != nullptr || head->right != nullptr) { // head is a node
		head->result =node_calculate((head->data)[0], calculate(head->left), calculate(head->right)); // node_calculate() is a function to calculate the result
		return head->result;
	}
	return stoi(head->data);
}

反向传播#

我们正向传播后,每一个内节点,尽管其可能包含的是一个运算符号,它也存储了以它为根节点的子树的运算结果。所以当对它的父节点作运算时,可以将它看作一个数字。所以方向传播的过程就变成了从根节点开始的,每一次算字节点的导数,最后导数累加的过程。这个过程希望读者可以自己来实现代码。

当读者实现完这个部分的代码后,一段完整的自动求导代码也就完成了。

加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。