Soptq

Soptq

Probably a full-stack, mainly focusing on Distributed System / Consensus / Privacy-preserving Tech etc. Decentralization is a trend, privacy must be protected.
twitter
github
bilibili

簡単な自動微分機

很久很久以前、コンピュータが一般的ではなかった頃、複雑な関数の微分は多くの学者にとって悪夢の一つでした。例えば、三次多項式関数に直面した場合、各変数の微分を余裕を持って行うことができるかもしれませんが、関数が数千万項、数万項に拡張された場合、それでも微分を求める自信がありますか?そのため、コンピュータの普及とともに、自動微分アルゴリズムの提案により、多くの学者が大規模な関数の微分から解放されました。

ディープラーニングのバックプロパゲーションも典型的な自動微分プロセスであり、pytorch の魔法の力である Autograd も、その実装原理について多くの人々の好奇心を引いたことでしょう。この記事では、最もシンプルな自動微分の例を説明し、皆さんに刺激となることを願っています。

自動微分の原理#

誰もが大学の初等数学を学ぶ際に、連鎖律という用語を聞いたことがあるでしょう。連鎖律の数学的表現は次のようになります:

dudx=dudydydx\frac {du} {dx}=\frac {du} {dy}\cdot \frac {dy} {dx}

ここで、 uuyy を含む関数であり、 yyxx を含む関数です。したがって、 uuxx に関して微分するためには、まず uuyy に関して微分する dudy\frac{du}{dy} を求め、次に yyxx に関して微分する dydx\frac{dy}{dx} を求め、最後にこれらの微分を掛け合わせることで求めたい結果が得られます。

誰かが言った、「バックプロパゲーションは連鎖律の別の派手な名前だ」と。実際、これが自動微分の核心です。連鎖律は、複雑な関数に直面した場合、それを単純な関数の原子に分解し、各単純な関数の原子に対して微分を求め、最後にそれらの微分を掛け合わせることで複雑な関数の結果を得ることができると教えてくれます。では、ディープラーニングの基本的なニューロン、 y=wx+by = wx + b は単純な原子ではないでしょうか?

したがって、自動微分とそのアイデアは明確になりました。前方伝播では、各変数から最終関数へのパスを記録し、そのパスに沿って関数から戻ることで、関数に対する変数の微分を得ることができます。

自動微分のパス

二分木を使用して自動微分マシンを設計する#

このセクションでは、二分木を使用して最もシンプルな自動微分マシンを設計します。単純な関数の微分を求めることを想定しています:

z=3x2+5x+lgyz=3x^{2}+5x+\lg y

式を逆ポーランド記法に変換する#

上記の関数は中置記法です。中置記法は人間にとって理解しやすい表記方法ですが、コンピュータが読み取りにくいです。この関数の解析木を得るために、まず逆ポーランド記法を使用して関数を中置記法から逆ポーランド記法に変換する必要があります。

逆ポーランド記法を解析木に変換する#

逆ポーランド記法を解析木に変換することは非常に簡単なプロセスです。逆ポーランド記法を一度トラバースするだけで済みます。トラバースのルールは次のとおりです:

  • 読み取った文字が演算子でない場合、読み取った文字を解析木のノードとしてスタックに保存します。
  • 読み取った文字が演算子の場合、読み取った文字を解析木のノードとし、スタックのトップの 2 つのノードをその左右の子ノードとして取り出し、スタックに保存します。

例えば、式 3+4×2(15)233+4\times \frac {2} {\left( 1-5\right) ^{2^{3}}} が与えられた場合、それに対応する逆ポーランド記法は次のようになります:

したがって、解析木の構築プロセスは次のようになります:

入力スタック
3(3)
4(3),(4)
2(3),(4),(2)
x(3),(4 x 2)
1(3),(4 x 2),(1)
5(3),(4 x 2),(1),(5)
-(3),(4 x 2),(1 - 5)
2(3),(4 x 2),(1 - 5),(2)
3(3),(4 x 2),(1 - 5),(2),(3)
^(3),(4 x 2),(1 - 5),(2 ^ 3)
^(3),(4 x 2),((1 - 5) ^ (2 ^ 3))
/(3),((4 x 2) / ((1 - 5) ^ (2 ^ 3)))
+(3 + (4 x 2) / ((1 - 5) ^ (2 ^ 3)))

コードは以下のようになります。

struct ParseTree{
	string data; // データを保存
	int result = -1; // サブツリーの計算結果
	int diff = -1; // 自動微分値
	struct ParseTree* left = nullptr;
	struct ParseTree* right = nullptr;
};

ParseTree* parse(){
	for (int I = 0; I < o_vector.size(); I++){
		if (isOperator(o_vector[I])) {
			ParseTree* s1 = parse_stack.top();
			parse_stack.pop();
			ParseTree* s2 = parse_stack.top();
			parse_stack.pop();
			ParseTree* parseNode = new ParseTree();
			parseNode->data = o_vector[I];
			parseNode->left = s1;
			parseNode->right = s2;
			parse_stack.push(parseNode);
		} else {
			ParseTree* parseLeaf = new ParseTree();
			parseLeaf->data = o_vector[I];
			parse_stack.push(parseLeaf);
		}
	}
	return parse_stack.top();
}

順方向伝播#

解析木が得られたら、根ノードは関数自体を表し、内部ノードは各演算子、葉ノードは各変数を表します。順方向伝播のプロセスは、葉ノードから始まり、各内部ノードを経由して最終的に根ノードまで計算するプロセスです。二分木の特性により、各ノードは 2 つの子ノードしか持たないため、各内部ノードは 1 つの演算子と 2 つの変数に関わるだけです。したがって、順方向伝播のプロセス全体は根ノードから始まる再帰プロセスです。

int calculate(ParseTree* head) {
	if (head->left != nullptr || head->right != nullptr) { // head is a node
		head->result =node_calculate((head->data)[0], calculate(head->left), calculate(head->right)); // node_calculate()は結果を計算する関数です
		return head->result;
	}
	return stoi(head->data);
}

逆方向伝播#

順方向伝播が終わると、各内部ノードは演算子を含むかもしれませんが、そのノードを根とする部分木の計算結果も保持しています。したがって、その親ノードで演算を行う際には、そのノードを数字として扱うことができます。したがって、逆方向伝播のプロセスは、根ノードから始まり、各葉ノードの微分を計算し、最後に微分を累積するプロセスになります。このプロセスは、読者自身でコードを実装してみてください。

この部分のコードを実装した後、完全な自動微分コードが完成します。

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。