Soptq

Soptq

Probably a full-stack, mainly focusing on Distributed System / Consensus / Privacy-preserving Tech etc. Decentralization is a trend, privacy must be protected.
twitter
github
bilibili

反向伝播アルゴリズムの数学的導出

本文はCSDN 反向伝播アルゴリズム(プロセスと公式の導出)を深く参考にしています。

基本的な定義#

単純なニューラルネットワークの例

上記の図に示されている単純なニューラルネットワークでは、layer 1は入力層、layer 2は隠れ層、layer 3は出力層です。以下にいくつかの変数名の意味を説明します:

名前意味
bilb_{i}^{l}ll 層の ii 番目のニューロンのバイアス
wjilw_{ji}^{l}l1l-1 層の ii 番目のニューロンが第 ll 層の jj 番目のニューロンに接続される重み
zilz_{i}^{l}ll 層の ii 番目のニューロンの入力
aila_{i}^{l}ll 層の ii 番目のニューロンの出力
σ\sigma活性化関数

上記の定義から、次のことがわかります:

zjl=iwjilail1+bjlz_{j}^{l} = \sum_{i}w_{ji}^{l}a_{i}^{l-1} + b_{j}^{l}

ajl=σzjl=σ(iwjilail1+bjl)a_{j}^{l} = \sigma z_{j}^{l} = \sigma \left( \sum_{i}w_{ji}^{l}a_{i}^{l-1} + b_{j}^{l} \right)

損失関数を二乗コスト関数(Quadratic Cost Function)とします:

J=12nxy(x)aL(x)2J = \frac{1}{2n} \sum_{x} \lvert \lvert y(x) - a^{L}(x) \rvert \rvert ^ {2}

ここで、xxは入力サンプルを表し、y(x)y(x)は実際の分類を表し、aL(x)a^{L}(x)は予測された分類を表し、LLはネットワークの最大層数を表します。入力サンプルが 1 つだけの場合、損失関数 JJは次のようになります:

J=12xy(x)aL(x)2J = \frac{1}{2} \sum_{x} \lvert \lvert y(x) - a^{L}(x) \rvert \rvert ^ {2}

最後に、第 ll 層の ii 番目のニューロンで生成されるエラーを次のように定義します:

δilJzil\delta_{i}^{l} \equiv \frac{\partial{J}}{\partial{z_{i}^{l}}}

公式の導出#

損失関数に対する最後の層のニューラルネットワークのエラーは次のようになります:

δiL=JziL=JaiLaiLziL=J(aiL)σ(ziL)\begin{aligned}\delta_{i}^{L} &= \frac{\partial{J}}{\partial{z_{i}^{L}}}\\&=\frac{\partial{J}}{\partial{a_{i}^{L}}} \cdot \frac{\partial{a_{i}^{L}}}{\partial{z_{i}^{L}}}\\&=\nabla J(a_{i}^{L}) \sigma^{'}(z_{i}^{L})\end{aligned}

δL=J(aL)σ(zL)\delta^{L} = \nabla J(a^{L}) \odot \sigma^{'}(z^{L})

損失関数に対する jj 層目のネットワークのエラーは次のようになります:

δjl=Jzjl=Jajlajlzjl=iJzil+1zil+1ajlajlzjl=iδil+1wijl+1ajl+bil+1ajlσ(zjl)=iδil+1wijl+1σ(zjl)\begin{aligned}\delta_{j}^{l} &= \frac{\partial{J}}{\partial{z_{j}^{l}}} \\ &= \frac{\partial{J}}{\partial{a_{j}^{l}}} \cdot \frac{\partial{a_{j}^{l}}}{\partial{z_{j}^{l}}} \\ &= \sum_{i} \frac{\partial{J}}{\partial{z_{i}^{l+1}}} \cdot \frac{\partial{z_{i}^{l+1}}}{\partial{a_{j}^{l}}} \cdot \frac{\partial{a_{j}^{l}}}{\partial{z_{j}^{l}}} \\ &= \sum_{i} \delta_{i}^{l+1} \cdot \frac{\partial{w_{ij}^{l+1}a_{j}^{l} + b_{i}^{l+1}}}{\partial{a_{j}^{l}}} \cdot \sigma^{'}(z_{j}^{l}) \\ &=\sum_{i} \delta_{i}^{l+1} \cdot w_{ij}^{l+1} \cdot \sigma^{'}(z_{j}^{l}) \end{aligned}

δl=((wl+1)Tδl+1)σ(zl)\delta^{l} = \left( \left( w^{l+1} \right)^{T} \delta^{l+1} \right) \odot \sigma^{'}(z^{l})

したがって、損失関数を使用して重みの勾配を計算することができます:

Jwjil=Jzjlzjlwjil=δjl(wjilail1+bjl)wjil=δjlail1\begin{aligned} \frac{\partial{J}}{\partial{w_{ji}^{l}}} &= \frac{\partial{J}}{\partial{z_{j}^{l}}} \cdot \frac{\partial{z_{j}^{l}}}{\partial{w_{ji}^{l}}} \\ &= \delta_{j}^{l} \cdot \frac{\partial{\left( w_{ji}^{l}a_{i}^{l-1} + b_{j}^{l} \right)}}{\partial{w_{ji}^{l}}} \\ &= \delta_{j}^{l} \cdot a_{i}^{l-1} \end{aligned}

Jwjil=δjlail1\frac{\partial{J}}{\partial{w_{ji}^{l}}} = \delta_{j}^{l} \cdot a_{i}^{l-1}

最後に、損失関数を使用してバイアスの勾配を計算します:

Jbjl=Jzjlzjlbjl=δjlwjilail1+bjlbjl=δjl\begin{aligned} \frac{\partial{J}}{\partial{b_{j}^{l}}} &= \frac{\partial{J}}{\partial{z_{j}^{l}}} \cdot \frac{\partial{z_{j}^{l}}}{\partial{b_{j}^{l}}} \\ &= \delta_{j}^{l} \cdot \frac{\partial{w_{ji}^{l} a_{i}^{l-1} + b_{j}^{l}}}{\partial{b_{j}^{l}}} \\ &=\delta_{j}^{l} \end{aligned}

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。