Soptq

Soptq

Probably a full-stack, mainly focusing on Distributed System / Consensus / Privacy-preserving Tech etc. Decentralization is a trend, privacy must be protected.
twitter
github
bilibili

分散シナリオでのバッチ正規化処理

バッチ正規化(BN)#

バッチ正規化(BN)の動機は次のようなものです。ディープラーニングのプロセスでは、各層のネットワークは入力データに対してマッピングを行います。したがって、ネットワークが深くなるにつれて、最初の $n$ 層のネットワークは入力データに対する影響を累積します。したがって、$n+1$ 層にとって、各入力データはデータセットの元のデータが $n$ 回のマッピングを経た値です。したがって、$n+1$ 層にとって、各入力データは分布ではない可能性が非常に高いため、$n+1$ 層は現在の分布に対して再学習を行います。これにより、ネットワークの学習が遅くなり、トレーニングができなくなる可能性があります。

さらに、飽和活性化関数(Saturated Activation Function)を使用する場合、例えばsigmoidtanh活性化関数を使用する場合、モデルのトレーニングが勾配飽和領域(Saturated Regime)に陥りやすくなります。これは、モデルのトレーニングに従って $x$ が大きくなり、$w^{T} x + b$ も大きくなるため、活性化関数の勾配がゼロに近づき、収束速度が遅くなることを示しています。

これらの 2 つの問題を解決するために、バッチ正規化は $x$ を正規化するために導入されました。BN 層を通過することで、データの分布は $N (1, 0)$ に正規化されます。したがって、各層のネットワークにとって、入力パラメータは分布から取得されるため、勾配飽和領域に陥りにくくなり、モデル全体のトレーニング速度が大幅に向上します。

BN はなぜ分散トレーニングでは機能しないのか?#

BN の実装では、プログラムは現在のデータの平均値と分散を知る必要があります。これにより、データを $N (1, 0)$ 分布に正規化することができます。一方、テストデータセットの分布はわかりません。したがって、テストデータセットの分布を推定するために、トレーニングデータセットの分布データを使用する必要があります。

トレーニング時には、BN 層はバッチ統計を使用して現在のバッチデータの真の平均値と分散を計算し、入力データを $N (1, 0)$ 分布に正確に正規化します。同時に、通常、テストデータセットとトレーニングデータセットは同じ分布(IID)であると仮定するため、バッチの平均値と分散を EMA(指数移動平均)の更新方法で使用して、トレーニングデータセット全体の分布を推定します(これにより、メモリの使用量が少なくなります)。

しかし、分散トレーニングでは、各ノードの BN 層はそれぞれ割り当てられたデータの加重平均値と分散を計算します。明らかに、各小さなデータセットの分散の加重平均を計算することは、データセット全体の分散とは必ずしも等しくないことが容易に想像できます(特にフェデレーテッドラーニングの場合)。したがって、最終的にトレーニングされたモデルは、BN 層でテストデータが正しく正規化されないため、性能が著しく低下します。

どう解決するか?#

最も直接的な方法は、BN 層を削除することです。

はい、これは最も直接的な方法ですが、モデルの最終的な効果とトレーニング時間に大きな影響を与えることになります。他の方法はありますか?以下のリンクを提供して、詳細を調べることができます:

  1. https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
  2. https://arxiv.org/abs/1803.08494
読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。