Long, long ago, when computers were not yet widespread, finding the derivative of a complicated function was a nightmare for many scholars. Just imagine: you might easily differentiate a polynomial in three variables term by term, but when the function expands to thousands or tens of thousands of terms, would you still have the confidence to do it by hand? With the popularization of computers, automatic differentiation algorithms have freed many scholars from differentiating such enormous functions.
Backpropagation in deep learning is a typical application of automatic differentiation, and as Autograd, the core of PyTorch's magic, many people must have been curious about how it is implemented. In this article, I will walk through a simple example of automatic differentiation, hoping it offers some inspiration.
Principle of Automatic Differentiation#
When everyone was studying advanced mathematics in their freshman year, they must have heard of a term: the chain rule. The mathematical representation of the chain rule is as follows:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$$

As you can see, $y$ is a function of $u$, and $u$ is a function of $x$. So, in order to find the derivative of $y$ with respect to $x$, we can first find the derivative of $y$ with respect to $u$, then find the derivative of $u$ with respect to $x$, and finally multiply the two derivatives to get the desired result.
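For a concrete instance of the rule (an illustrative example of my own choosing, not from the original), let $y = \sin(u)$ with $u = x^2$:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} = \cos(u) \cdot 2x = 2x\cos(x^2)$$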
Some people say that backpropagation is just another fancy name for the chain rule. Indeed, this is the core of automatic differentiation. The chain rule tells us that when faced with a complex function, if we can break it down into simple atomic functions, we can differentiate each atomic function separately and multiply the derivatives to obtain the derivative of the complex function. And isn't the most basic neuron in deep learning, $y = wx + b$, exactly such a simple atomic function?
Therefore, the idea of automatic differentiation becomes clear: we record the path from each variable to the final function during forward propagation, and then walk back along this path from the function to obtain the derivative of the function with respect to that variable.
Designing an Automatic Differentiation Machine Using Binary Trees#
In this section, we will use a binary tree to design the simplest automatic differentiation machine. Let's assume we want to differentiate a simple function:

$$3 + 4 \times 2 \div (1 - 5)^{2^3}$$
Converting the Equation to Postfix Notation#
The function above is written in infix notation. Infix notation is convenient for human cognition, but not convenient for computers to parse. In order for the computer to build a parse tree for this function, we first need to convert it from infix notation to postfix notation (Reverse Polish notation) using the Shunting Yard algorithm.
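The article leaves the conversion to the Shunting Yard algorithm without showing code, so here is a minimal sketch of one way to implement it, assuming the expression has already been split into tokens; the helper names `to_postfix` and `precedence` are mine, not from the original:

```cpp
#include <stack>
#include <string>
#include <vector>

// Operator precedence: higher binds tighter.
int precedence(const std::string& op) {
    if (op == "^") return 3;
    if (op == "*" || op == "/") return 2;
    if (op == "+" || op == "-") return 1;
    return 0;
}

bool isOperator(const std::string& tok) {
    return tok == "+" || tok == "-" || tok == "*" || tok == "/" || tok == "^";
}

// Convert a tokenized infix expression to postfix notation.
std::vector<std::string> to_postfix(const std::vector<std::string>& tokens) {
    std::vector<std::string> output;
    std::stack<std::string> ops;
    for (const std::string& tok : tokens) {
        if (isOperator(tok)) {
            // Pop operators of higher precedence, or of equal precedence
            // when the current operator is left-associative ("^" is not).
            while (!ops.empty() && isOperator(ops.top()) &&
                   (precedence(ops.top()) > precedence(tok) ||
                    (precedence(ops.top()) == precedence(tok) && tok != "^"))) {
                output.push_back(ops.top());
                ops.pop();
            }
            ops.push(tok);
        } else if (tok == "(") {
            ops.push(tok);
        } else if (tok == ")") {
            while (!ops.empty() && ops.top() != "(") {
                output.push_back(ops.top());
                ops.pop();
            }
            if (!ops.empty()) ops.pop();  // discard the "("
        } else {
            output.push_back(tok);  // operand goes straight to the output
        }
    }
    while (!ops.empty()) {  // flush the remaining operators
        output.push_back(ops.top());
        ops.pop();
    }
    return output;
}
```

Running `to_postfix` on the tokens of our example expression yields exactly the postfix sequence used in the next section.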
Converting Postfix Notation to Parse Tree#
Converting postfix notation to a parse tree is a very simple process. We only need to traverse the postfix notation once, maintaining a stack of subtrees. The rules for traversal are as follows:
- When a non-operator (an operand) is read, it is stored as a node and pushed onto the stack.
- When an operator is read, it is stored as a node, the top two nodes are popped from the stack to become its right and left children respectively (the first node popped is the right child, since it was the second operand pushed), and the new node is pushed back onto the stack.
For example, if we have the expression $3 + 4 \times 2 \div (1 - 5)^{2^3}$, its corresponding postfix notation is:
3 4 2 * 1 5 - 2 3 ^ ^ / +
So the process of constructing the parse tree is as follows:
| Input | Stack |
|---|---|
| 3 | (3) |
| 4 | (3), (4) |
| 2 | (3), (4), (2) |
| * | (3), (4 * 2) |
| 1 | (3), (4 * 2), (1) |
| 5 | (3), (4 * 2), (1), (5) |
| - | (3), (4 * 2), (1 - 5) |
| 2 | (3), (4 * 2), (1 - 5), (2) |
| 3 | (3), (4 * 2), (1 - 5), (2), (3) |
| ^ | (3), (4 * 2), (1 - 5), (2 ^ 3) |
| ^ | (3), (4 * 2), ((1 - 5) ^ (2 ^ 3)) |
| / | (3), ((4 * 2) / ((1 - 5) ^ (2 ^ 3))) |
| + | (3 + (4 * 2) / ((1 - 5) ^ (2 ^ 3))) |
The code would look something like this:
```cpp
#include <stack>
#include <string>
#include <vector>

struct ParseTree {
    std::string data;                // the token: an operand or an operator
    double result = 0;               // the value of the subtree rooted here
    double diff = 0;                 // the derivative accumulated here
    ParseTree* left = nullptr;
    ParseTree* right = nullptr;
};

std::vector<std::string> o_vector;   // the expression in postfix notation
std::stack<ParseTree*> parse_stack;  // stack of partially built subtrees

ParseTree* parse() {
    for (size_t i = 0; i < o_vector.size(); i++) {
        if (isOperator(o_vector[i])) { // isOperator() as defined above
            // The first node popped is the right child: it was the
            // second operand pushed, so it sits on top of the stack.
            ParseTree* s1 = parse_stack.top();
            parse_stack.pop();
            ParseTree* s2 = parse_stack.top();
            parse_stack.pop();
            ParseTree* parseNode = new ParseTree();
            parseNode->data = o_vector[i];
            parseNode->left = s2;
            parseNode->right = s1;
            parse_stack.push(parseNode);
        } else {
            ParseTree* parseLeaf = new ParseTree();
            parseLeaf->data = o_vector[i];
            parse_stack.push(parseLeaf);
        }
    }
    return parse_stack.top();
}
```
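To tie the pieces together, a hypothetical driver (mine, not the original author's) could feed the postfix tokens from the table above into `parse()`:

```cpp
int main() {
    // Postfix form of 3 + 4 * 2 / (1 - 5) ^ 2 ^ 3, as derived above.
    o_vector = {"3", "4", "2", "*", "1", "5", "-", "2", "3", "^", "^", "/", "+"};
    ParseTree* root = parse();
    // root->data is now "+", the operator applied last.
    return 0;
}
```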
Forward Propagation#
After obtaining the parse tree, the root node is equivalent to the function itself, the internal nodes represent the operators, and the leaf nodes represent the variables and constants. Forward propagation then computes values from the leaf nodes up to the root node. Because of the structure of the parse tree, each internal node has exactly two children, which means each internal node involves only one operator and two operands. The whole forward pass can therefore be written as a simple recursion starting from the root node.
```cpp
double calculate(ParseTree* head) {
    if (head->left != nullptr && head->right != nullptr) { // internal node: an operator
        // node_calculate() applies the operator to the two child results
        head->result = node_calculate((head->data)[0],
                                      calculate(head->left),
                                      calculate(head->right));
        return head->result;
    }
    head->result = std::stod(head->data); // leaf node: parse the literal
    return head->result;
}
```
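`node_calculate()` is referenced but never shown in the original. A minimal sketch, assuming the operator arrives as a single character (declare it above `calculate()` so the call compiles), might be:

```cpp
#include <cmath>
#include <stdexcept>

// Apply a single binary operator to two already-computed values.
double node_calculate(char op, double lhs, double rhs) {
    switch (op) {
        case '+': return lhs + rhs;
        case '-': return lhs - rhs;
        case '*': return lhs * rhs;
        case '/': return lhs / rhs;
        case '^': return std::pow(lhs, rhs);
        default:  throw std::runtime_error("unknown operator");
    }
}
```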
Backward Propagation#
After forward propagation, every internal node, even though it holds an operator, also stores the value of the subtree rooted at it, so when computing derivatives at its parent it can be treated as a plain number. Backward propagation therefore becomes a process that starts from the root node, pushes the derivative down to each leaf node, and accumulates the derivatives there. I hope the reader can implement this part of the code on their own; a possible sketch follows.
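For reference, here is one possible sketch of the backward pass; this is my own rendering of the idea, not the original author's code. Each node receives the derivative of the root with respect to itself (the seed at the root is $1$) and distributes it to its children according to its operator's local derivatives:

```cpp
#include <cmath>

// Propagate `seed` = d(root)/d(this node) down the tree,
// accumulating it into the diff field of every node.
void differentiate(ParseTree* head, double seed) {
    head->diff += seed;
    if (head->left == nullptr || head->right == nullptr) return; // leaf
    double a = head->left->result;   // value of the left operand
    double b = head->right->result;  // value of the right operand
    switch ((head->data)[0]) {
        case '+':
            differentiate(head->left, seed);
            differentiate(head->right, seed);
            break;
        case '-':
            differentiate(head->left, seed);
            differentiate(head->right, -seed);
            break;
        case '*':
            differentiate(head->left, seed * b);
            differentiate(head->right, seed * a);
            break;
        case '/':
            differentiate(head->left, seed / b);
            differentiate(head->right, -seed * a / (b * b));
            break;
        case '^':
            // d(a^b)/da = b * a^(b-1); d(a^b)/db = a^b * ln(a) (only for a > 0)
            differentiate(head->left, seed * b * std::pow(a, b - 1));
            differentiate(head->right, seed * std::pow(a, b) * std::log(a));
            break;
    }
}
```

Calling `calculate(root)` and then `differentiate(root, 1.0)` leaves the derivative of the whole expression with respect to each leaf in that leaf's `diff` field.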
After implementing this part of the code, the automatic differentiation machine is complete.