Batch Normalization (BN)
The motivation behind Batch Normalization (BN) is as follows. In a deep network, each layer applies a mapping to its input, and as the network gets deeper, the effects of these mappings accumulate. The input to layer n+1 is therefore the original data transformed by the previous n layers. Because the parameters of those n layers change at every training step, the distribution of the inputs that layer n+1 sees keeps shifting, so layer n+1 has to keep re-adapting its parameters to the current distribution. This slows down training and may even make the network impossible to train.
In addition, with saturating activation functions such as sigmoid and tanh, training can easily get stuck in the saturated regime. As training proceeds, the magnitude of the pre-activation w^T x + b can grow large; the activation function then operates in its flat region, its gradient approaches zero, and convergence slows down.
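To make the saturation effect concrete, the short Python sketch below evaluates the sigmoid's derivative, s(z)(1 - s(z)), at increasingly large pre-activations z = w^T x + b. The helper names `sigmoid` and `sigmoid_grad` are illustrative, not taken from any library.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic sigmoid (used here only for non-negative z)."""
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z: float) -> float:
    """Derivative of the sigmoid: s(z) * (1 - s(z))."""
    s = sigmoid(z)
    return s * (1.0 - s)


# As the pre-activation z = w^T x + b grows, the gradient shrinks toward zero.
for z in [0.0, 2.0, 5.0, 10.0]:
    print(f"z = {z:5.1f}  sigmoid'(z) = {sigmoid_grad(z):.2e}")
```

The gradient falls from 0.25 at z = 0 to roughly 4.5e-5 at z = 10, so a unit whose pre-activation has drifted that far receives almost no learning signal.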
To address both problems, Batch Normalization normalizes the input x of a layer. After the BN layer, each feature has zero mean and unit variance over the batch, i.e. it is normalized toward N(0, 1) (BN then applies a learnable scale and shift). As a result, each layer receives inputs from a much more stable distribution, activations are less likely to drift into the saturated regime, and the whole model trains considerably faster.
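The following is a minimal NumPy sketch of the training-time BN transform for a (batch, features) input, assuming the usual formulation: subtract the per-feature batch mean, divide by the square root of the batch variance plus a small constant `eps` for numerical stability, then apply a learnable scale `gamma` and shift `beta`. The function name `batch_norm_train` and the value `eps=1e-5` are illustrative assumptions.

```python
import numpy as np


def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Normalize a (batch, features) array with the statistics of the current batch.

    Each feature column is shifted to zero mean and scaled to unit variance,
    then the learnable scale `gamma` and shift `beta` are applied.
    """
    mean = x.mean(axis=0)  # per-feature batch mean
    var = x.var(axis=0)    # per-feature (biased) batch variance
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta, mean, var


rng = np.random.default_rng(0)
# A batch whose features are far from N(0, 1): mean 50, standard deviation 10.
x = rng.normal(loc=50.0, scale=10.0, size=(256, 4))
gamma, beta = np.ones(4), np.zeros(4)
y, _, _ = batch_norm_train(x, gamma, beta)
print(y.mean(axis=0).round(6))  # close to 0
print(y.var(axis=0).round(4))   # close to 1
```

With `gamma = 1` and `beta = 0` the output is (almost exactly) zero-mean and unit-variance per feature; during training the network is free to learn other values of `gamma` and `beta` if a different scale suits the next layer better.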
Why doesn't BN work in distributed training?
To normalize data toward N(0, 1), BN needs the mean and variance of that data. At test time, however, we do not know the distribution of the test dataset, so we have to estimate it from the statistics of the training dataset.
During training, the BN layer computes the actual mean and variance of the current batch (the batch statistics) and uses them to normalize that batch toward N(0, 1). At the same time, since we usually assume that the training and test data are independent and identically distributed (IID), the BN layer feeds each batch's mean and variance into an exponential moving average (EMA) to estimate the statistics of the entire training set. An EMA is used because it saves memory: it stores only one running mean and one running variance rather than the statistics of every batch. These running estimates are what the BN layer uses at test time.
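The sketch below illustrates the two modes under stated assumptions: a hypothetical `RunningBatchNorm` class for (batch, features) inputs, a `momentum` of 0.1 that weights the newest batch in the EMA, and the biased batch variance used both for normalization and for the running estimate (some frameworks store the unbiased estimate instead). In training mode it normalizes with batch statistics and updates the EMA; in evaluation mode it normalizes with the stored running statistics.

```python
import numpy as np


class RunningBatchNorm:
    """Minimal 1-D batch norm that tracks running statistics with an EMA.

    Update rule: running = (1 - momentum) * running + momentum * batch_stat.
    """

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def __call__(self, x, training):
        if training:
            mean, var = x.mean(axis=0), x.var(axis=0)
            m = self.momentum
            # EMA update: only two vectors are kept, not every batch's statistics.
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # At test time, use the estimates accumulated during training.
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


rng = np.random.default_rng(0)
bn = RunningBatchNorm(num_features=3)
for _ in range(200):
    bn(rng.normal(loc=5.0, scale=2.0, size=(64, 3)), training=True)
print(bn.running_mean.round(2))  # approaches 5
print(bn.running_var.round(2))   # approaches 4
```

After 200 training batches drawn from N(5, 2^2), the running mean and variance settle near 5 and 4. A larger `momentum` tracks recent batches more closely but gives noisier estimates.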
In distributed training, however, each node's BN layer computes its batch statistics, and hence its moving averages, only from the data assigned to that node. Averaging the variances of several subsets is generally not equal to the variance of the combined dataset, because it ignores the spread between the subsets' means. This gap becomes especially large in federated learning, where each node's data can follow a different distribution. As a result, when the trained model is deployed, the BN layer cannot properly normalize the test data, which can lead to a significant drop in performance.
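A small numerical sketch of this aggregation problem, assuming two nodes with equally sized, non-IID shards (one centered at -3, the other at +3, both with unit variance). By the law of total variance, the variance of the pooled data equals the average of the within-node variances plus the variance of the node means; simply averaging the per-node variances drops the second term.

```python
import numpy as np

rng = np.random.default_rng(0)
# Two nodes holding non-IID shards of equal size: different means, same spread.
node_a = rng.normal(loc=-3.0, scale=1.0, size=10_000)
node_b = rng.normal(loc=3.0, scale=1.0, size=10_000)

# Naive aggregation: average the per-node variances (equal weights, equal sizes).
avg_of_local_vars = (node_a.var() + node_b.var()) / 2

# True variance of the pooled data.
global_var = np.concatenate([node_a, node_b]).var()

# Law of total variance: pooled = mean of within-node variances + variance of node means.
node_means = np.array([node_a.mean(), node_b.mean()])
between_node_var = node_means.var()

print(f"average of local variances: {avg_of_local_vars:.3f}")
print(f"global variance:            {global_var:.3f}")
print(f"within + between:           {avg_of_local_vars + between_node_var:.3f}")
```

Here the averaged local variance is about 1 while the pooled variance is about 10, so normalization statistics built only from per-node variances would badly underestimate the spread of the full dataset.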
How to solve this?
The most straightforward method is to simply remove the BN layer.
That is indeed the simplest fix, but it can significantly hurt the model's final performance and lengthen training. Are there other methods? Here are a few links for readers to explore further: