Soptq

Soptq

Probably a full-stack, mainly focusing on Distributed System / Consensus / Privacy-preserving Tech etc. Decentralization is a trend, privacy must be protected.
twitter
github
bilibili

分布式情境下的 Batch Normalization 處理

批次正規化(BN)#

批次正規化(BN)的動機大致如下。在深度學習過程中,每一層網路都會對輸入的數據進行一次映射。隨著網路的不斷加深,前面 $n$ 層的網路對輸入數據的影響是累加的。因此對於第 $n+1$ 層來說,每一次輸入的數據都是數據集的原始數據通過 $n$ 次映射後的值。因此很有可能對於第 $n+1$ 層來說,每一次輸入的數據不是一個分布的,所以每一次第 $n+1$ 層都會重新學習對於當前分布的參數。這樣網路的學習就會被拖慢,甚至導致無法訓練。

此外,當我們使用飽和激活函數(飽和激活函數)時,例如 sigmoidtanh 激活函數,很容易使得模型訓練陷入梯度飽和區(飽和區)。體現就是隨著模型的訓練,$x$ 會越來越大,導致 $w^{T} x + b$ 越來越大,於是激活函數的梯度接近於零,收斂速度變慢。

為了解決這兩個問題,引入了批次正規化(Batch Normalization)來對 $x$ 做歸一化。通過 BN 層後,數據分布會被歸一化到 $N (1, 0)$,這樣對於每一層網路來說,輸入的參數都是來自於一個分布的,梯度飽和區也不容易陷入,整個模型訓練速度就會快很多。

BN 在分佈式訓練中為什麼不起作用?#

在 BN 的實現中,程式需要知道當前數據的平均值和方差才能將其歸一化到 $N (1, 0)$ 分佈,而測試數據集的分佈我們是不知道的。所以我們需要用訓練數據集的分佈數據來估計測試數據集的分佈數據。

在訓練的時候,BN 層使用批次統計量計算當前批次數據的真實平均值和方差,從而非常準確地將輸入數據歸一化到 $N (1, 0)$ 分佈。同時,由於我們通常假設測試數據集和訓練數據集是同分佈的(IID),所以我們會將本次批次的平均值和方差以 EMA 更新(滑動更新)的方式用來估計全部訓練數據集的分佈(因為這樣更省內存)。

但是在分佈式的時候,每一個節點的 BN 層會分別計算它所分配的數據的加權平均值和方差。很容易想到,將每一個小數據集的方差求加權平均,顯然不一定等價於整個數據集的方差(特別是在聯邦學習的情境下)。因此最終訓練的模型在使用的時候,測試數據在 BN 層就無法被正確歸一化,最終導致性能顯著下降。

怎麼解決呢?#

最粗暴的方法:直接刪掉 BN 層。

嗯,確實是最粗暴的,但這樣必將極大地影響模型的最終效果和訓練時間。有沒有其他方法呢?這裡為讀者提供幾個鏈接,可以去深入了解一下:

  1. https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
  2. https://arxiv.org/abs/1803.08494
載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。