Soptq

Soptq

Probably a full-stack, mainly focusing on Distributed System / Consensus / Privacy-preserving Tech etc. Decentralization is a trend, privacy must be protected.
twitter
github
bilibili

分布式情景下的 Batch Normalization 处理

Batch Normalization (BN)#

Batch Normalization 的 motivation 大概是这样的。在深度学习过程中,每一层网络都会对输入的数据做一次映射。于是随着网络的不断加深,前面 $n$ 层的网络对输入数据的影响是累加的。所以对于 $n+1$ 层来说,每一次输入的数据都是数据集的原始数据通过 $n$ 次的映射后的值。所以非常有可能对于 $n+1$ 层来说,每一次输入的数据不是一个分布的,所以每一次第 $n+1$ 层都会去重新学习对于当前分布的参数。这样网络的学习就会被拖慢,甚至导致无法训练。

另外,当我们使用饱和激活函数 (Saturated Activation Function) 时,例如sigmoidtanh激活函数,很容易使得模型训练陷入梯度饱和区 (Saturated Regime)。体现就是随着模型的训练,$x$ 会越来越大,导致 $w^{T} x + b$ 越来越大,于是激活函数的梯度接近于零,收敛速度变慢。

为了解决这两个问题,Batch Normalization 被引入用来对 $x$ 做归一化。通过 BN 层后数据分布会被归一到 $N (1, 0)$,这样子对于每一层网络来说,输入的参数都是来自于一个分布的,梯度饱和区也不容易陷入,整个模型训练速度就会快很多。

BN 为什么在分布式训练中不 work?#

在 BN 的实现中,程序需要知道当前数据的平均值和方差才能将其归一到 $N (1, 0)$ 分布,而测试数据集的分布我们是不知道的。所以我们需要用训练数据集的分布数据取估计测试数据集的分布数据。

在训练的时候,BN 层使用 Batch Statistic 计算当前 Batch 数据的真实平均值和方差,从而非常准确的将输入数据归一到 $N (1, 0)$ 分布。同时,由于我们通常假设测试数据集和训练数据集是同分布的 (IID),所以我们会将本次 Batch 的平均值和方差以 EMA 更新(滑动更新)的方式用来估计全部训练数据集的分布(因为这样更省内存)。

但是在分布式的时候,每一个节点的 BN 层会分别计算它所分配的数据的加权平均值和方差。很容易的想到,把每一个小数据集的方差求加权平均,显然是不一定等价于整个数据集的方差的(特别是在联邦学习的情景下)。于是最终训练的模型在使用的时候,测试数据在 BN 层就无法被正确归一化,最终导致性能显著下降。

怎么解决呢?#

最粗暴的方法:直接删掉 BN 层。

嗯,确实是最粗暴的,可是这样必将极大地影响模型的最终效果和训练时间。有没有其他方法呢?这里为读者提供几个链接,可以去深入了解一下:

  1. https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
  2. https://arxiv.org/abs/1803.08494
加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。