
Revisit Batch Normalization: New Understanding and Refinement via Composition Optimization

Xiangru Lian, Department of Computer Science, University of Rochester, [email protected]
Ji Liu, Department of Computer Science, University of Rochester, [email protected]

Abstract

Batch Normalization (BN) has been used extensively in deep learning to achieve a faster training process and better resulting models. However, whether BN works strongly depends on how the batches are constructed during training, and it may not converge to a desired solution if the statistics on the batch are not close to the statistics over the whole dataset. In this paper, we try to understand BN from an optimization perspective by providing an explicit objective function associated with BN. This explicit objective function reveals that: 1) BN, rather than being a new optimization algorithm or trick, is creating a different objective function instead of the one in our common sense; and 2) why BN may not work well in some scenarios. We then propose a refinement of BN based on the compositional optimization technique, called Full Normalization (FN), to alleviate the issues of BN when the batches are not constructed ideally. The convergence analysis and an empirical study of FN are also included in this paper.


1 Introduction

Batch Normalization (BN) [Ioffe and Szegedy, 2015] has been used extensively in deep learning [Szegedy et al., 2016, He et al., 2016, Silver et al., 2017, Huang et al., 2017, Hubara et al., 2017] to accelerate the training process. During training, a BN layer normalizes its input by the mean and variance computed within a mini-batch. Many state-of-the-art models are based on BN, such as ResNet [He et al., 2016, Xie et al., 2017] and Inception [Szegedy et al., 2016, 2017]. It is often believed that BN can mitigate the exploding/vanishing gradient problem [Cooijmans et al., 2016] or reduce internal covariate shift [Ioffe and Szegedy, 2015]. Therefore, BN has become a standard tool that is implemented in almost all deep learning solvers, such as TensorFlow [Abadi et al., 2015], MXNet [Chen et al., 2015], and PyTorch [Paszke et al., 2017].

Despite the great success of BN in quite a few scenarios, people with rich experience with BN may also notice some issues. Some examples are provided below.

BN fails/overfits when the mini-batch size is 1, as shown in Figure 1. We construct a simple network with 3 layers for a classification task. It fails to learn a reasonable classifier on a dataset with only 3 samples, as seen in Figure 1.

Figure 1: (Fail when batch size is 1) Given a dataset comprised of three samples, [0, 0, 0] with label 0, [1, 1, 1] with label 1 and [2, 2, 2] with label 2, use the following simple network including one batch normalization layer, where the numbers in the parentheses are the dimensions of the input and output of a layer: linear layer (3 → 3) ⇒ batch normalization ⇒ relu ⇒ linear layer (3 → 3) ⇒ nll loss. Train with batch size 1, and test on the same dataset. The test loss increases while the training loss decreases.
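For concreteness, the following PyTorch sketch is our reconstruction of the Figure 1 setup (the learning rate and epoch count are assumed, not taken from the paper). Since torch.nn.BatchNorm1d rejects training batches with a single element, the batch normalization step is written out by hand, following Algorithm 1 in Section 3 (without the optional affine part).

```python
# Reconstruction of the Figure 1 experiment: 3 samples, batch size 1,
# linear -> batch normalization -> relu -> linear, NLL loss.
import torch
import torch.nn.functional as F

X = torch.tensor([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
y = torch.tensor([0, 1, 2])

lin1, lin2 = torch.nn.Linear(3, 3), torch.nn.Linear(3, 3)
mu, nu = torch.zeros(3), torch.ones(3)          # running mean / variance (Algorithm 1)
alpha, eps = 0.1, 1e-5                           # assumed averaging constant and epsilon
opt = torch.optim.SGD(list(lin1.parameters()) + list(lin2.parameters()), lr=0.1)

def forward(x, training):
    global mu, nu
    z = lin1(x)
    if training:                                 # normalize with the batch statistics
        m, v = z.mean(0), z.var(0, unbiased=False)
        mu = (1 - alpha) * mu + alpha * m.detach()
        nu = (1 - alpha) * nu + alpha * v.detach()
        z = (z - m) / torch.sqrt(v + eps)
    else:                                        # inference: use the running statistics
        z = (z - mu) / torch.sqrt(nu + eps)
    return F.log_softmax(lin2(torch.relu(z)), dim=1)

for epoch in range(100):
    for i in range(3):                           # batch size 1: one sample per step
        opt.zero_grad()
        F.nll_loss(forward(X[i:i + 1], True), y[i:i + 1]).backward()
        opt.step()
    with torch.no_grad():
        test_loss = F.nll_loss(forward(X, False), y)
    # per Figure 1, the test loss increases while the training loss decreases
```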




Figure 2: (Sensitive to the size of mini-batch) The test accuracy for ResNet18 on the CIFAR10 dataset trained for 10 epochs with different batch sizes. The smaller the batch size, the worse the convergence result with BN layers in ResNet.

Figure 3: (Fail if data have large variation) The training and test accuracy for a simple logistic regression problem on a synthesized dataset with mini-batch size 20. We synthesize 10,000 samples, where each sample is a vector of 1000 elements, and there are 1000 classes. Each sample is generated from a zero vector by first randomly assigning a class to the sample (say the i-th class), then setting the i-th element of the vector to a random number from 0 to 50, and finally adding noise (generated from a standard normal distribution) to each element. We generate 100 test samples in the same way. A logistic regression classifier should be able to classify this dataset easily. However, if we add a BN layer to the classifier, the model no longer converges.

BN's solution is sensitive to the mini-batch size, as shown in Figure 2. The test conducted in Figure 2 uses ResNet18 on the CIFAR10 dataset. When the batch size changes, the neural network trained with BN tends to converge to a different solution. Indeed, this observation is true even for solving convex optimization problems. It can also be observed that the smaller the batch size, the worse the performance of the solution.

BN fails if data have large variation, as shown in Figure 3. BN breaks convergence on simple convex logistic regression problems if the variance of the dataset is large. Figure 3 shows the first 20 epochs of such training on a synthesized dataset. This phenomenon also often appears when using distributed training algorithms where each worker only has its local dataset and the local datasets are very different. A sketch reproducing this setup is given below.
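The following PyTorch sketch is our reconstruction of the Figure 3 experiment. The dataset sizes, dimensionality, and mini-batch size of 20 are taken from the caption; the optimizer, learning rate, epoch count, and the placement of the BN layer in front of the linear layer are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_data(n, dim=1000, n_classes=1000, seed=0):
    g = torch.Generator().manual_seed(seed)
    y = torch.randint(n_classes, (n,), generator=g)
    x = torch.zeros(n, dim)
    x[torch.arange(n), y] = 50 * torch.rand(n, generator=g)   # class-carrying entry in [0, 50)
    x += torch.randn(n, dim, generator=g)                     # standard normal noise on every element
    return x, y

x_train, y_train = make_data(10_000)
x_test, y_test = make_data(100, seed=1)

plain = nn.Linear(1000, 1000)                                      # multiclass logistic regression
with_bn = nn.Sequential(nn.BatchNorm1d(1000), nn.Linear(1000, 1000))  # the same model with a BN layer

for model in (plain, with_bn):
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    for epoch in range(20):
        model.train()
        for b in torch.randperm(len(x_train)).split(20):           # mini-batch size 20
            opt.zero_grad()
            F.cross_entropy(model(x_train[b]), y_train[b]).backward()
            opt.step()
        model.eval()
        with torch.no_grad():
            acc = (model(x_test).argmax(dim=1) == y_test).float().mean()
        # per Figure 3, `plain` learns the task while `with_bn` fails to converge
```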
Therefore, these observations may remind people to ask some fundamental questions:

1. Does BN always converge and/or where does it converge to?
2. Using BN to train the model, why does severe over-fitting sometimes happen?
3. Is BN a trustable "optimization" algorithm?

In this paper, we aim at understanding the BN algorithm from a rigorous optimization perspective to show that

1. BN always converges, but it solves neither the objective in our common sense nor the optimization originally motivated.
2. The result of BN heavily depends on how the batches are constructed, and it can overfit predefined batches. BN treats training and inference differently, which makes the situation worse.
3. BN is not always trustable, especially when the batches are not constructed with randomly selected samples or the batch size is small.

Besides these, we also provide the explicit form of the original objective BN aims to optimize (but does not in fact), and propose a Multilayer Compositional Stochastic Gradient Descent (MCSGD) algorithm based on the compositional optimization technique to solve the original objective. We prove the convergence of the proposed MCSGD algorithm and empirically study its effectiveness in refining the state-of-the-art BN algorithm.

2 Related work

We first review traditional normalization techniques. LeCun et al. [1998] showed that normalizing the input dataset makes training faster. In deep neural networks, normalization was used before the invention of BN, for example Local Response Normalization (LRN) [Lyu and Simoncelli, 2008, Jarrett et al., 2009, Krizhevsky et al., 2012], which computes the statistics for the neighborhoods around each pixel.

We then review batch normalization techniques. Ioffe and Szegedy [2015] proposed the Batch Normalization (BN) algorithm, which performs normalization along the batch dimension. It is more global than LRN and can be applied in the middle of a neural network. Since during inference there is no "batch", BN uses the running average of the statistics collected during training to perform inference, which introduces an unwanted discrepancy between training and testing. While this paper tries to study why BN fails in some scenarios and how to fix it, Santurkar et al. [2018], Bjorck et al. [2018], and Kohler et al. [2018] provided new understandings of why BN accelerates training by analyzing simplified neural networks or the magnitude of gradients. Ioffe [2017] proposed Batch Renormalization (BR), which introduces two extra parameters to reduce the drift of the estimation of the mean and variance. However, both BN and BR are based on the assumption that the statistics on a mini-batch approximate the statistics on the whole dataset.

Next we review instance based normalization, which normalizes the input sample-wise instead of using the batch information. This includes Layer Normalization [Ba et al., 2016], Instance Normalization [Ulyanov et al., 2016], Weight Normalization [Salimans and Kingma, 2016] and the recently proposed Group Normalization [Wu and He, 2018]. They all normalize on a single-sample basis and are less global than BN. They avoid computing statistics on a batch, which works well for some vision tasks where the outputs of layers can have hundreds of channels and the statistics of a single sample are sufficient. However, we still prefer BN when single-sample statistics are not enough and more global statistics are needed to increase accuracy.

Finally we review compositional optimization, on which our refinement of BN is based. Wang and Liu [2016] proposed the compositional optimization algorithm, which optimizes the nested expectation problem min E f(E g(·)); the convergence rate was later improved in Wang et al. [2016]. The convergence of compositional optimization on nonsmooth regularized problems was shown in Huo et al. [2017], and a variance-reduced variant solving strongly convex problems was analyzed in Lian et al. [2016].

3 Review of the BN algorithm

In this section we review the Batch Normalization (BN) algorithm [Ioffe and Szegedy, 2015]. BN is usually implemented as an additional layer in neural networks. In each iteration of the training process, the BN layer normalizes its input using the mean and variance of each channel of the input batch, so that its output has zero mean and unit variance. The mean and variance on the input batch are expected to be similar to the mean and variance over the whole dataset. In the inference process, the layer normalizes each channel of its input using the saved mean and variance, which are the running averages of the mean and variance calculated during training. This is described in Algorithm 1.¹ With BN, the input at the BN layer is normalized so that the next layer in the network accepts inputs that are easier to train on. In practice it has been observed in many applications that the speed of training is improved by using BN.

¹ A linear transformation is often added after applying BN to compensate for the normalization.

Algorithm 1 Batch Normalization Layer

Training
Require: input batch $B_{\mathrm{in}}$; estimated mean $\mu$ and variance $\nu$; averaging constant $\alpha$.
1: Update the running estimates
$$ \mu \leftarrow (1-\alpha)\cdot\mu + \alpha\cdot\mathrm{mean}(B_{\mathrm{in}}), \qquad \nu \leftarrow (1-\alpha)\cdot\nu + \alpha\cdot\mathrm{var}(B_{\mathrm{in}}), $$
where $\mathrm{mean}(B_{\mathrm{in}})$ and $\mathrm{var}(B_{\mathrm{in}})$ calculate the mean and variance of $B_{\mathrm{in}}$ respectively.
2: Output
$$ B_{\mathrm{out}} \leftarrow \frac{B_{\mathrm{in}} - \mathrm{mean}(B_{\mathrm{in}})}{\sqrt{\mathrm{var}(B_{\mathrm{in}}) + \epsilon}}, $$
where $\epsilon$ is a small constant for numerical stability.

Inference
Require: input $B_{\mathrm{in}}$; estimated mean $\mu$ and variance $\nu$; a small constant $\epsilon$ for numerical stability.
1: Output
$$ B_{\mathrm{out}} \leftarrow \frac{B_{\mathrm{in}} - \mu}{\sqrt{\nu + \epsilon}}. $$
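For readers who prefer code, a direct NumPy transcription of Algorithm 1 (without the optional affine transformation mentioned in the footnote) is given below; the default values of α and ε are assumptions.

```python
import numpy as np

class BNLayer:
    def __init__(self, dim, alpha=0.1, eps=1e-5):
        self.mu = np.zeros(dim)     # running estimate of the mean
        self.nu = np.ones(dim)      # running estimate of the variance
        self.alpha, self.eps = alpha, eps

    def train_step(self, B_in):     # B_in has shape (batch, dim)
        m, v = B_in.mean(axis=0), B_in.var(axis=0)
        self.mu = (1 - self.alpha) * self.mu + self.alpha * m   # running averages (Step 1)
        self.nu = (1 - self.alpha) * self.nu + self.alpha * v
        return (B_in - m) / np.sqrt(v + self.eps)               # normalize with batch statistics (Step 2)

    def inference(self, B_in):
        return (B_in - self.mu) / np.sqrt(self.nu + self.eps)   # normalize with running statistics
```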
4 Training objective of BN

The original paper proposing the BN algorithm [Ioffe and Szegedy, 2015] does not provide the explicit optimization objective that BN targets. Therefore, many people may naturally think that BN is an optimization trick that accelerates the training process but still solves the original objective in our common sense. Unfortunately, this is not true! The actual objective BN solves is different from the objective in our common sense, and it is nontrivial. In this section, we derive the objective function that is actually optimized when BN layers are used. Then we discuss how this objective differs from the objective without BN layers and when BN's objective is less desirable.

Rigorous mathematical description of the BN layer   To define the objective of BN in a precise way, we need to define the normalization operator $f^{B,\sigma}_W$ that maps a function to a function, associated with a mini-batch $B$, an activation function $\sigma$, and parameters $W$. Let $g(\cdot)$ be a function $g(\cdot) = [g_1(\cdot)\; g_2(\cdot)\; \cdots\; g_n(\cdot)\; 1]^\top$, where $g_1(\cdot), \ldots, g_n(\cdot)$ are functions mapping a vector to a number. The operator $f^{B,\sigma}_W$ is defined by

$$ f^{B,\sigma}_W(g)(\cdot) := \sigma\left( W \begin{bmatrix} \dfrac{g_1(\cdot) - \mathrm{mean}(g_1, B)}{\sqrt{\mathrm{var}(g_1, B)}} \\ \dfrac{g_2(\cdot) - \mathrm{mean}(g_2, B)}{\sqrt{\mathrm{var}(g_2, B)}} \\ \vdots \\ \dfrac{g_n(\cdot) - \mathrm{mean}(g_n, B)}{\sqrt{\mathrm{var}(g_n, B)}} \\ 1 \end{bmatrix} \right), \qquad (1) $$

(The argument of $f^{B,\sigma}_W$ is a function $g$, and $f^{B,\sigma}_W(g)$ is another function.) Here $\mathrm{mean}(r, B)$ is defined by $\mathrm{mean}(r, B) := \frac{1}{|B|}\sum_{b\in B} r(b)$, and $\mathrm{var}(r, B)$ is defined by $\mathrm{var}(r, B) := \mathrm{mean}(r^2, B) - \mathrm{mean}(r, B)^2$. Note that the first argument of $\mathrm{mean}(\cdot,\cdot)$ and $\mathrm{var}(\cdot,\cdot)$ is a function, and the second argument is a set of samples.

An $m$-layer neural network with BN can then be represented by the function
$$ F^{B}_{\{W_i\}_{i=1}^m}(\cdot) := f^{B,\sigma_m}_{W_m}\big(f^{B,\sigma_{m-1}}_{W_{m-1}}(\cdots(f^{B,\sigma_1}_{W_1}(I)))\big)(\cdot) = f^{B,\sigma_m}_{W_m} \circ f^{B,\sigma_{m-1}}_{W_{m-1}} \circ \cdots \circ f^{B,\sigma_1}_{W_1}(I)(\cdot), $$
where $I(\cdot)$ is the identity mapping from a vector to itself, that is, $I(x) = x$.

BN's objective is very different from the objective in our common sense   In the same way, we can represent a fully connected network in our common sense (without BN) by $F_{\{W_i\}_{i=1}^m}(\cdot) := \bar f^{\sigma_m}_{W_m} \circ \bar f^{\sigma_{m-1}}_{W_{m-1}} \circ \cdots \circ \bar f^{\sigma_1}_{W_1}(I)(\cdot)$, where the operator $\bar f^{\sigma}_W$ is defined by
$$ \bar f^{\sigma}_W(g)(\cdot) := \sigma(W g(\cdot)). \qquad (2) $$
Besides the difference in the operator definitions, the ultimate objectives are also different. Given the training dataset $D$, the objective without BN (or the objective in our common sense) is defined by
$$ \min_{\{W_j\}_{j=1}^m} \; \frac{1}{|D|} \sum_{(x,y)\in D} l\big(F_{\{W_j\}_{j=1}^m}(x), y\big), \qquad (3) $$
where $l(\cdot,\cdot)$ is a predefined loss function, while the objective with BN (or the objective BN is actually solving) is
$$ \text{(BN)} \qquad \min_{\{W_j\}_{j=1}^m} \; \frac{1}{|\mathcal{B}|} \sum_{B\in\mathcal{B}} \frac{1}{|B|} \sum_{(x,y)\in B} l\big(F^{B}_{\{W_j\}_{j=1}^m}(x), y\big), \qquad (4) $$
where $\mathcal{B}$ is the set of batches. Therefore, the objective of BN could in general be very different from the objective in our common sense.

BN could be very sensitive to the sampling strategy and the minibatch size   The set $\mathcal{B}$ has different forms depending on how the mini-batches are defined. For example,

• One can choose $b$ as the size of the minibatch and form $\mathcal{B}$ by $\mathcal{B} := \{B \subset D : |B| = b\}$. Choosing $\mathcal{B}$ in this way implicitly assumes that all nodes can access the same dataset, which may not be true in practice.
• If the data are distributed ($D_i$ is the local dataset on the $i$-th worker, disjoint from the others, satisfying $D = D_1 \cup \cdots \cup D_n$), a typical $\mathcal{B}$ is defined by $\mathcal{B} = \mathcal{B}_1 \cup \mathcal{B}_2 \cup \cdots \cup \mathcal{B}_n$ with $\mathcal{B}_i := \{B \subset D_i : |B| = b\}$.
• When the mini-batch is chosen to be the whole dataset, $\mathcal{B}$ contains only one element, $D$.

After figuring out the implicit objective BN optimizes in (4), it is not difficult to make the following key observations:

• The BN objectives vary a lot when different sampling strategies are applied. This explains why the convergent solution can be very different when we change the sampling strategy.
• For the same sampling strategy, BN's objective also varies if the size of the minibatch changes. This explains why BN can be sensitive to the batch size.

These observations make it hard to appropriately choose parameters for BN, since BN does not optimize the objective in our common sense and can be sensitive to the sampling strategy and the size of the minibatch. A toy numerical illustration of this dependence is sketched below.
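As a toy illustration (ours, not from the paper) that the value of objective (4) depends on the batch construction $\mathcal{B}$, consider a single "layer" $w\,x$ followed by batch normalization and a squared loss standing in for $l$: the same $w$ yields different objective values under different partitions of the same dataset.

```python
import numpy as np

def bn_objective(w, data, labels, batch_set, eps=1e-5):
    """Objective (4) for a toy one-layer model: w*x is normalized with the
    statistics of the batch it appears in, then compared to the label."""
    value = 0.0
    for idx in batch_set:
        z = w * data[idx]                                  # pre-normalization outputs on batch B
        z_hat = (z - z.mean()) / np.sqrt(z.var() + eps)    # normalize with the batch statistics
        value += np.mean((z_hat - labels[idx]) ** 2)       # inner average over (x, y) in B
    return value / len(batch_set)                          # outer average over B in the batch set

rng = np.random.RandomState(0)
data, labels, w = rng.randn(8), rng.randn(8), 1.5

partition_a = [np.arange(0, 4), np.arange(4, 8)]           # two ways to split the same 8 samples
partition_b = [np.arange(0, 8, 2), np.arange(1, 8, 2)]     # into batches of size 4

print(bn_objective(w, data, labels, partition_a))
print(bn_objective(w, data, labels, partition_b))          # a different value for the same w
```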
BN's objective (4) with $\mathcal{B} = \{D\}$ has the same optimal value as the original objective (3)   When the batch size equals the size of the whole dataset in (4), the objective becomes
$$ \text{(FN)} \qquad \min_{\{W_j\}_{j=1}^m} \; \frac{1}{|D|} \sum_{(x,y)\in D} l\big(F^{D}_{\{W_j\}_{j=1}^m}(x), y\big), \qquad (5) $$
which differs from the original objective (3) only in the first argument of $l$. Note that the only difference between (1) and (2) when $B$ is the constant $D$ is a linear transformation, which can be absorbed into $W$:
$$ f^{D,\sigma}_W(g)(\cdot) = \sigma\left( W \begin{bmatrix} \dfrac{g_1(\cdot) - \mathrm{mean}(g_1, D)}{\sqrt{\mathrm{var}(g_1, D)}} \\ \vdots \\ \dfrac{g_n(\cdot) - \mathrm{mean}(g_n, D)}{\sqrt{\mathrm{var}(g_n, D)}} \\ 1 \end{bmatrix} \right) = \sigma\left( \underbrace{W \begin{bmatrix} \frac{1}{\sqrt{\mathrm{var}(g_1, D)}} & & & \frac{-\mathrm{mean}(g_1, D)}{\sqrt{\mathrm{var}(g_1, D)}} \\ & \ddots & & \vdots \\ & & \frac{1}{\sqrt{\mathrm{var}(g_n, D)}} & \frac{-\mathrm{mean}(g_n, D)}{\sqrt{\mathrm{var}(g_n, D)}} \\ & & & 1 \end{bmatrix}}_{=:W'} \begin{bmatrix} g_1(\cdot) \\ g_2(\cdot) \\ \vdots \\ g_n(\cdot) \\ 1 \end{bmatrix} \right) = \sigma\left( W' \, [g_1(\cdot)\; g_2(\cdot)\; \cdots\; g_n(\cdot)\; 1]^\top \right). $$

If we use this $W'$ as our new $W$, (3) and (4) have the same form, and the two objectives have the same optimal value. This means that when $\mathcal{B} = \{D\}$, adding BN does not hurt the expressiveness of the network. However, since each layer's input has been normalized, the condition number of the objective is in general reduced, and the network is thus easier to train.

5 Solving the full normalization objective (5) via compositional optimization

The BN objective contains the normalization operator, which mostly reduces the condition number of the objective and can therefore accelerate the training process. While the FN formulation in (5) provides a more stable and trustable way to define the BN formulation, it is very challenging to solve in practice. If we follow the standard SGD procedure used for BN to solve (5), every single iteration needs to involve the whole dataset since $B = D$, which is almost impossible in practice. The key difficulty in solving (5) lies in the "expectation" in each layer, which is very expensive to compute even if we just need a stochastic gradient for the standard SGD method. To solve (5) efficiently, we follow the spirit of compositional optimization to develop a new algorithm, namely Multilayer Compositional Stochastic Gradient Descent (MCSGD, Algorithm 2). The proposed algorithm does not have any requirement on the size of the minibatch. The key idea is to estimate the expectation from the current samples and the historical record.

5.1 Formulation

To propose our solution to (5), let us first define the objective in a more general but neater way. When $\mathcal{B} = \{D\}$, (4) is a special case of the following general objective:
$$ \min_w \; f(w) := \mathbb{E}_{\xi}\big[ F_1^{w_1,D} \circ F_2^{w_2,D} \circ \cdots \circ F_n^{w_n,D} \circ I(\xi) \big], \qquad (6) $$
where $\xi$ represents a sample in the dataset $D$ (for example, it can be the pair $(x, y)$ in (4)), $w_i$ represents the parameters of the $i$-th layer, and $w$ collects all parameters of the layers: $w := (w_1, \ldots, w_n)$. Each $F_i$ is an operator of the form
$$ F_i^{w_i,D}(g)(\cdot) := F_i\big(w_i;\, g(\cdot);\, \mathbb{E}_{\xi\in D}\, e_i(g(\xi))\big), $$
where $e_i$ is used to compute statistics over the dataset. For example, with $e_i(x) = [x, x^2]$, the mean and variance of the layer's input over the dataset can be calculated, and $F_i$ can use that information, for example, to perform normalization.

There exist compositional optimization algorithms [Wang and Liu, 2016, Wang et al., 2016] for solving (6) when $n = 2$, but for $n > 2$ we still do not have a good algorithm. We follow the same spirit to extend the compositional optimization algorithms to the general optimization problem (6), as shown in Algorithm 2. See Section 6 for an implementation in the normalization setting. To simplify notation, given $w = (w_1, \ldots, w_n)$ we define
$$ e_i(w; \xi) := e_i\big( F_{i+1}^{w_{i+1},D} \circ F_{i+2}^{w_{i+2},D} \circ \cdots \circ F_n^{w_n,D} \circ I(\xi) \big). $$
If we already have an estimation, say $\hat e_i$, of $\mathbb{E}_{\xi\in D}\, e_i(w; \xi)$, we define the operator
$$ \hat F_i^{w_i, \hat e_i}(g)(\cdot) := F_i\big(w_i;\, g(\cdot);\, \hat e_i\big). $$
Given $w = (w_1, \ldots, w_n)$ and $\hat e = (\hat e_1, \ldots, \hat e_n)$, we define
$$ \hat e_i(w; \xi; \hat e) := e_i\big( \hat F_{i+1}^{w_{i+1}, \hat e_{i+1}} \circ \hat F_{i+2}^{w_{i+2}, \hat e_{i+2}} \circ \cdots \circ \hat F_n^{w_n, \hat e_n} \circ I(\xi) \big). $$

Algorithm 2 Multilayer Compositional Stochastic Gradient Descent (MCSGD)

Require: learning rates $\{\gamma_k\}_{k=0}^K$, approximation rates $\{\alpha_k\}_{k=0}^K$, dataset $D$, initial point $w^{(0)} := (w_1^{(0)}, \ldots, w_n^{(0)})$, and initial estimations $\hat e^{(0)} := (\hat e_1^{(0)}, \ldots, \hat e_n^{(0)})$.
1: for $k = 0, 1, 2, \ldots, K$ do
2:   For each $i$, randomly select a sample $\xi_k$ and estimate $\mathbb{E}_{\xi}[e_i(w^{(k)}; \xi)]$ by
     $\hat e_i^{(k+1)} \leftarrow (1 - \alpha_k)\, \hat e_i^{(k)} + \alpha_k\, \hat e_i(w^{(k)}; \xi_k; \hat e^{(k)})$. ▷ This step is borrowed from compositional optimization to estimate the statistics across the whole dataset.
3:   Ask the oracle $O$, using $w^{(k)}$ and the estimated means $\hat e^{(k+1)}$, for the approximated gradient $g^{(k)}$ at $w^{(k)}$. ▷ See Remark 1 for a discussion of the oracle.
4:   $w^{(k+1)} \leftarrow w^{(k)} - \gamma_k\, g^{(k)}$.
5: end for
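To make the control flow of Algorithm 2 concrete, here is a schematic Python skeleton; `estimate_e` (computing $\hat e_i(w;\xi;\hat e)$ for every layer) and `oracle` (the gradient oracle of Remark 1) are abstract placeholders of ours, not part of the paper.

```python
import random

def mcsgd(w, e_hat, dataset, estimate_e, oracle, gammas, alphas):
    """Schematic skeleton of Algorithm 2 (MCSGD)."""
    for gamma, alpha in zip(gammas, alphas):
        xi = random.choice(dataset)                        # one random sample per step
        # Step 2: running estimates of E_xi[e_i(w; xi)] for every layer i
        e_hat = [(1 - alpha) * e + alpha * est
                 for e, est in zip(e_hat, estimate_e(w, xi, e_hat))]
        # Step 3: approximate gradient from the oracle, using the current estimates
        g = oracle(w, e_hat, xi)
        # Step 4: SGD-style parameter update
        w = [wi - gamma * gi for wi, gi in zip(w, g)]
    return w, e_hat
```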
Remark 1 (MCSGD oracle). In MCSGD, the gradient oracle takes the current parameters of the model and the current estimation of each $\mathbb{E}_{\xi}[e_i(w^{(k)}; \xi)]$, and outputs an estimation of a stochastic gradient. For example, for an objective like (6), the derivative of the loss function with respect to the $i$-th layer's parameters is
$$ \partial_{w_i} f(w) = \partial_{w_i}\Big(\mathbb{E}_{\xi}\big[F_1^{w_1,D} \circ F_2^{w_2,D} \circ \cdots \circ F_n^{w_n,D} \circ I(\xi)\big]\Big) = \mathbb{E}_{\xi}\Big[ \big[\partial_x(F_1^{w_1,D}(x)(\xi))\big]_{x=F_2^{w_2,D}\circ\cdots\circ F_n^{w_n,D}\circ I} \cdot \big[\partial_x(F_2^{w_2,D}(x)(\xi))\big]_{x=F_3^{w_3,D}\circ\cdots\circ F_n^{w_n,D}\circ I} \cdots \big[\partial_x(F_{i-1}^{w_{i-1},D}(x)(\xi))\big]_{x=F_i^{w_i,D}\circ\cdots\circ F_n^{w_n,D}\circ I} \cdot \big[\partial_{w_i}(F_i^{w_i,D}(x)(\xi))\big]_{x=F_{i+1}^{w_{i+1},D}\circ\cdots\circ F_n^{w_n,D}\circ I} \Big], \qquad (7) $$
where for any $j \in [1, \ldots, i-1]$:
$$ \big[\partial_x F_j^{w_j,D}(x)(\xi)\big]_{x=F_{j+1}^{w_{j+1},D}\circ\cdots\circ F_n^{w_n,D}\circ I} = \partial_x F_j(w_j; x; y) + \partial_y F_j(w_j; x; y)\cdot \mathbb{E}_{\xi'\sim D}\, D_j(\xi'), $$
with
$$ x = F_{j+1}^{w_{j+1},D}\circ\cdots\circ F_n^{w_n,D}\circ I(\xi), \qquad y = \mathbb{E}_{\xi}[e_j(w;\xi)], \qquad D_j(\xi') = \big[\partial_z e_j(z)\big]_{z=F_{j+1}^{w_{j+1},D}\circ\cdots\circ F_n^{w_n,D}\circ I(\xi')}. $$
The oracle samples a $\xi'$ for each layer, calculates this derivative with $y$ set to an estimated value $\hat e_j$, and returns the stochastic gradient. The expectation of this stochastic gradient equals (7) up to an error caused by the difference between the estimated $\hat e_j$ and the true expectation $\mathbb{E}_{\xi}[e_j(w;\xi)]$. The more accurate the estimation is, the closer the expectation of this stochastic gradient is to the true gradient (7). In practice, we often use the same $\xi'$ for each layer to save computational resources; this introduces some bias in the expectation of the estimated stochastic gradient. One such implementation can be found in Algorithm 3.

5.2 Theoretical analysis

In this section we analyze Algorithm 2 to show its convergence when applied to (6) (a generic version of the FN formulation in (5)). The detailed proof is provided in the supplementary material. We use the assumptions listed in Assumption 1 for the analysis.

Assumption 1.
1. The gradients $g^{(k)}$ are bounded: $\|g^{(k)}\| \le G$, $\forall k$, for some constant $G$.
2. The variances of all $e_i$ and $\hat e_i$ are bounded:
$$ \mathbb{E}_{\xi}\|e_i(w;\xi) - \mathbb{E}_{\xi} e_i(w;\xi)\|^2 \le \sigma^2, \quad \forall i, w, \qquad \mathbb{E}_{\xi}\|\hat e_i(w;\xi;\hat e) - \mathbb{E}_{\xi} \hat e_i(w;\xi;\hat e)\|^2 \le \sigma^2, \quad \forall i, w, \hat e, $$
for some constant $\sigma$.
3. The error of the approximated gradient $\mathbb{E}[g^{(k)}]$ is proportional to the approximation error of $\mathbb{E}[e_i]$:
$$ \|\mathbb{E}_{\xi_k}[g^{(k)}] - \nabla f(w^{(k)})\|^2 \le L_g \sum_{i=1}^n \|\hat e_i^{(k+1)} - \mathbb{E}_{\xi_k}[e_i(w^{(k)};\xi_k)]\|^2, \quad \forall k, $$
for some constant $L_g$.
4. All functions and their first order derivatives are Lipschitzian with Lipschitz constant $L$.
5. The minimum of the objective $f(w)$ is finite.
6. $\gamma_k$ and $\alpha_k$ are monotonically decreasing, with $\gamma_k = O(k^{-\gamma})$ and $\alpha_k = O(k^{-a})$ for some constants $\gamma > a > 0$.
Under the given assumptions, the approximation errors vanish, as shown in Lemma 1.

Lemma 1 (Approximation error). Choose the learning rate $\gamma_k$ and the approximation rate $\alpha_k$ in Algorithm 2 in the form defined in Assumption 1.6 with parameters $\gamma$ and $a$. Under Assumption 1, the sequence generated by Algorithm 2 satisfies
$$ \mathbb{E}\|\hat e_i^{(k+1)} - \mathbb{E}_{\xi} e_i(w^{(k)};\xi)\|^2 \le E\big(k^{-2\gamma+2a+\varepsilon} + k^{-a+\varepsilon}\big), \quad \forall i, \forall k, \qquad (8) $$
$$ \mathbb{E}\|\hat e_i^{(k+1)} - \mathbb{E}_{\xi} e_i(w^{(k)};\xi)\|^2 \le \Big(1 - \frac{\alpha_k}{2}\Big)\,\mathbb{E}\|\hat e_i^{(k)} - \mathbb{E}_{\xi} e_i(w^{(k-1)};\xi)\|^2 + C\big(k^{-2\gamma+a+\varepsilon} + k^{-2a+\varepsilon}\big), \quad \forall i, \forall k, \qquad (9) $$
for any $\varepsilon$ satisfying $1 - a > \varepsilon > 0$, where $E$ and $C$ are two constants independent of $k$.

It can then be shown that Algorithm 2 converges on (6), as stated in Theorem 2 and Corollary 3. It is worth noting that the convergence rate in Corollary 3 is slower than the $\frac{1}{\sqrt{K}}$ convergence rate of SGD without any normalization. This is due to the estimation error. If the estimation error is small (for example, when the samples in a batch are randomly selected and the batch size is large), the convergence will be fast.

Theorem 2 (Convergence). Choose the learning rate $\gamma_k$ and the approximation rate $\alpha_k$ in Algorithm 2 in the form defined in Assumption 1.6 with parameters $\gamma$ and $a$ satisfying $a < 2\gamma - 1$, $a < 1/2$, and $\frac{\gamma_k L_g}{\alpha_{k+1}} \le \frac{1}{2}$. Under Assumption 1, for any integer $K \ge 1$ the sequence generated by Algorithm 2 satisfies
$$ \frac{\sum_{k=0}^{K} \gamma_k\, \mathbb{E}\|\partial f(w^{(k)})\|^2}{\sum_{k=0}^{K} \gamma_k} \le \frac{H}{\sum_{k=0}^{K} \gamma_k}, $$
where $H$ is a constant independent of $K$.

We next specify the choice of $\gamma$ and $a$ to give a clearer convergence rate for our MCSGD algorithm in Corollary 3.

Corollary 3. Choose the learning rate $\gamma_k$ and the approximation rate $\alpha_k$ in Algorithm 2 in the form defined in Assumption 1.6, more specifically, $\gamma_k = \frac{1}{2L_g}(k+2)^{-4/5}$ and $\alpha_k = (k+1)^{-2/5}$. Under Assumption 1, for any integer $K \ge 1$ the sequence generated by Algorithm 2 satisfies
$$ \frac{\sum_{k=0}^{K} \mathbb{E}\|\partial f(w^{(k)})\|^2}{K+2} \le \frac{H}{(K+2)^{1/5}}, $$

where $H$ is a constant independent of $K$.
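For concreteness, the schedules chosen in Corollary 3 can be written out as code (a small helper of ours; the constant $L_g$ comes from Assumption 1.3 and is problem dependent):

```python
def gamma_k(k, L_g):
    return (k + 2) ** (-4 / 5) / (2 * L_g)   # gamma_k = (1 / 2L_g) (k + 2)^(-4/5)

def alpha_k(k):
    return (k + 1) ** (-2 / 5)               # alpha_k = (k + 1)^(-2/5)
```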

6 Experiments

test loss ness of our MCSGD algorithm for solving the FN for- training loss 0.5 0.5 mulation. We consider two settings. The first one in 0.0 0.0 0 200 400 600 800 0 2 4 6 8 10 Section 6.1 shows BN’s convergence highly depends on iteration epoch the size of batches, while FN is more robust to different Figure 4: The training and testing error for the given model batch sizes. The second one in Section 6.2 shows that trained on MNIST with batch size 64, learning rate 0.01 FN is more robust to different construction of mini- and momentum 0.5 using BN or FN for the case where all batches. samples in a batch are of a single label. The approximation k −0.4 We use a simple neural network as the testing network rate α in FN is ( 20 +1) where k is the iteration number. whose architecture is shown in Figure 6. Steps 2 and 3 in Algorithm 2 are implemented in Algorithm 3. 2.0 BN 3.0 BN FN FN The forward pass essentially performs Step 2 of Al- 2.5 1.5 gorithm 2, which estimates the mean and variance of 2.0 layer inputs over the dataset. The backward pass es- testing loss training loss 1.0 1.5 sentially performs Step 3 of Algorithm 2, which gives 1.0 0 50 100 0 50 100 the approximated gradient based on current network epoch epoch parameters and the estimations. Note that as discussed Figure 5: The training and testing error for the given in Remark 1, for efficiency, in each iteration of this model trained on CIFAR10 with batch size 64, learning implementation we are using the same samples to do rate 0.01 and momentum 0.9 using BN and FN for the case estimation in all normalization layers, which saves com- where all samples in a batch are of no more than 3 labels. putational cost but brings additional bias. The learning rate is decreased by a factor of 5 for every k −0.3 20 epochs. The approximation rate α in FN is ( 5 + 1) where k is the iteration number. 6.1 Dependence on batch size

Algorithm 3 Full Normalization (FN) layer

Forward pass
Require: learning rate $\gamma$, approximation rate $\alpha$, input $B_{\mathrm{in}} \in \mathbb{R}^{b\times d}$, mean estimation $\mu$, and mean-of-square estimation $\nu$. ▷ $b$ is the batch size and $d$ is the dimension of each sample.
1: if training then
$$ \mu \leftarrow (1-\alpha)\mu + \alpha\,\mathrm{mean}(B_{\mathrm{in}}), \qquad \nu \leftarrow (1-\alpha)\nu + \alpha\,\mathrm{mean\_of\_square}(B_{\mathrm{in}}). $$
2: end if
3: return the layer output
$$ B_{\mathrm{output}} \leftarrow \frac{B_{\mathrm{in}} - \mu}{\sqrt{\max\{\nu - \mu^2, \epsilon\}}}, $$
where $\epsilon$ is a small constant for numerical stability.

Backward pass
Require: mean estimation $\mu$, mean-of-square estimation $\nu$, the gradient at the output $g_{\mathrm{out}}$, and the input $B_{\mathrm{in}} \in \mathbb{R}^{b\times d}$.
1: Define
$$ f(\mu, \nu, B_{\mathrm{in}}) := \frac{B_{\mathrm{in}} - \mu}{\sqrt{\max\{\nu - \mu^2, \epsilon\}}}. $$
2: return the gradient at the input
$$ g_{\mathrm{in}} \leftarrow g_{\mathrm{out}} \cdot \left( \partial_{B_{\mathrm{in}}} f + \frac{\partial_{\mu} f + 2 B_{\mathrm{in}}\,\partial_{\nu} f}{bd} \right), $$
where $\partial_a f$ is the derivative of $f$ with respect to $a$ at $(\mu, \nu, B_{\mathrm{in}})$.
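A NumPy transcription of Algorithm 3 may look as follows. The running estimates µ (mean) and ν (mean of squares) play the role of Step 2 of Algorithm 2; the partial derivatives of f(µ, ν, B_in) in the backward rule are written out explicitly and reflect our own reading of the algorithm (the default ε value is an assumption), not code released with the paper.

```python
import numpy as np

class FNLayer:
    def __init__(self, dim, alpha, eps=1e-5):
        self.mu = np.zeros(dim)      # running estimate of E[x]
        self.nu = np.ones(dim)       # running estimate of E[x^2]
        self.alpha, self.eps = alpha, eps

    def forward(self, B_in, training=True):
        if training:                 # Step 2 of Algorithm 2: update the estimates
            self.mu = (1 - self.alpha) * self.mu + self.alpha * B_in.mean(axis=0)
            self.nu = (1 - self.alpha) * self.nu + self.alpha * (B_in ** 2).mean(axis=0)
        self._B_in = B_in
        return (B_in - self.mu) / np.sqrt(np.maximum(self.nu - self.mu ** 2, self.eps))

    def backward(self, g_out):
        B_in, mu, nu, eps = self._B_in, self.mu, self.nu, self.eps
        b, d = B_in.shape
        r = np.maximum(nu - mu ** 2, eps) ** -0.5
        active = (nu - mu ** 2 > eps).astype(B_in.dtype)   # where the max selects nu - mu^2
        df_dB = r                                          # d f / d B_in (elementwise)
        df_dmu = -r + (B_in - mu) * mu * r ** 3 * active   # d f / d mu
        df_dnu = -0.5 * (B_in - mu) * r ** 3 * active      # d f / d nu
        return g_out * (df_dB + (df_dmu + 2 * B_in * df_dnu) / (b * d))
```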

6.1 Dependence on batch size

MNIST is used as the testing dataset, with the modification that each sample is multiplied by a random number in (−2.5, 2.5) to make the samples more diverse. In this case, as shown in Figure 9, with batch size 1 and batch size 16 the convergent points of BN are different, while the convergent results of FN are much closer for the different batch sizes. Therefore, FN is more robust to the size of the mini-batch.

6.2 Dependence on batch construction

We study two cases: the shuffled case and the unshuffled case. In the shuffled case, the samples in a batch are randomly sampled, and we expect the performance of FN to match BN's. In the unshuffled case, each batch contains samples from only a few categories, so the mean and variance are very different from batch to batch. In this case we observe that FN outperforms BN.

Unshuffled case   We show that FN has advantages over BN when the data variation is large among mini-batches. In the unshuffled setting, we do not shuffle the dataset, which means that the samples in the same mini-batch mostly share the same label. The comparison uses two datasets (MNIST and CIFAR10):

• On MNIST, the batch size is chosen to be 64 and each batch only contains a single label. The convergence results are shown in Figure 4. We can observe from these figures that, for the training loss, BN and FN converge equally fast (BN might be slightly faster). However, since the statistics on any batch are very different from the whole dataset, the estimated mean and variance in BN are very different from the true mean and variance on the whole dataset, resulting in a very high test loss for BN.

• On CIFAR10, we observe similar results, as shown in Figure 5. In this case, we restrict the number of labels in every batch to be no more than 3; thus BN's performance on CIFAR10 is slightly better than on MNIST. We see that the convergence efficiency of both methods is still comparable in terms of the training loss. However, the testing error for BN is still far behind that of FN.
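For reference, the "unshuffled" single-label batches of the MNIST experiment can be built, for example, by sorting the dataset by label; the sketch below is an assumed implementation detail, not taken from the paper.

```python
import torch

def single_label_batches(images, labels, batch_size=64):
    """Yield batches that contain a single label (except at class boundaries)."""
    order = torch.argsort(labels)                   # group samples of the same class together
    images, labels = images[order], labels[order]
    for start in range(0, len(labels), batch_size):
        yield images[start:start + batch_size], labels[start:start + batch_size]
```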

Figure 6: The simple neural network used in the experiments. The “Normalization” between layers can be BN or FN. The activation function is relu and the loss function is negative log likelihood.
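A hedged sketch of the Figure 6 network follows: the exact layer widths are only shown in the figure image and are not recoverable from the text, so the sizes below are placeholders. Only the structure described in the caption (linear → normalization → relu blocks with a negative log likelihood loss, where the normalization can be BN or FN) is taken from the paper.

```python
import torch.nn as nn

def build_model(normalization, widths=(784, 128, 128, 10)):
    """Build the experiment network; `normalization` is a layer constructor
    (e.g. nn.BatchNorm1d, or an FN layer module) and `widths` are placeholders."""
    layers = []
    for d_in, d_out in zip(widths[:-1], widths[1:-1]):
        layers += [nn.Linear(d_in, d_out), normalization(d_out), nn.ReLU()]
    layers += [nn.Linear(widths[-2], widths[-1]), nn.LogSoftmax(dim=1)]  # pair with nll_loss
    return nn.Sequential(*layers)
```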

Figure 7: The training and testing error for the given model trained on MNIST with batch size 64, learning rate 0.01 and momentum 0.5, using BN or FN, for the case where the samples in a batch are randomly selected. The approximation rate $\alpha$ in FN is $(\frac{k}{20}+1)^{-0.4}$, where $k$ is the iteration number.

Shuffled case   In the shuffled case, where the samples in each batch are selected randomly, we expect BN to have performance similar to FN, since the statistics in a batch are close to the statistics on the whole dataset. The results for MNIST are shown in Figure 7 and the results for CIFAR10 are shown in Figure 8. We observe that the convergence curves of BN and FN match well for both the training and the testing loss.

Figure 8: The training and testing error for the given model trained on CIFAR10 with batch size 64, learning rate 0.01 and momentum 0.9, using BN or FN, for the case where the samples in a batch are randomly selected. The learning rate is decreased by a factor of 5 every 20 epochs. The approximation rate $\alpha$ in FN is $(\frac{k}{20}+1)^{-0.2}$, where $k$ is the iteration number.

7 Conclusion

We provide a new understanding of BN from an optimization perspective by rigorously defining the optimization objective of BN. BN essentially optimizes an objective different from the one in our common sense. The objective implicitly targeted by BN depends on the sampling strategy as well as the minibatch size, which explains why BN becomes unstable and sensitive in some scenarios. The stablest objective of BN (called the FN formulation) uses the full dataset as the mini-batch, but such a formulation is very challenging to solve. To solve the FN objective, we follow the spirit of compositional optimization and develop the MCSGD algorithm to solve it efficiently.

Figure 9: The testing error for the given model trained on MNIST, where each sample is multiplied by a random number in (−2.5, 2.5). (a) With BN layers. (b) With FN layers. The optimization algorithm is SGD with learning rate 0.01 and momentum 0.5. FN gives more consistent results under different batch sizes. The approximation rate $\alpha$ in FN is $(\frac{k}{20}+1)^{-0.4}$, where $k$ is the iteration number.

Acknowledgment   Xiangru Lian and Ji Liu are in part supported by NSF CCF1718513, an IBM faculty award, and a NEC fellowship.

References

M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, G. S. Corrado, A. Davis, J. Dean, M. Devin, S. Ghemawat, I. Goodfellow, A. Harp, G. Irving, M. Isard, Y. Jia, R. Jozefowicz, L. Kaiser, M. Kudlur, J. Levenberg, D. Mané, R. Monga, S. Moore, D. Murray, C. Olah, M. Schuster, J. Shlens, B. Steiner, I. Sutskever, K. Talwar, P. Tucker, V. Vanhoucke, V. Vasudevan, F. Viégas, O. Vinyals, P. Warden, M. Wattenberg, M. Wicke, Y. Yu, and X. Zheng. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. URL https://www.tensorflow.org/. Software available from tensorflow.org.

J. L. Ba, J. R. Kiros, and G. E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.

N. Bjorck, C. P. Gomes, B. Selman, and K. Q. Weinberger. Understanding batch normalization. In Advances in Neural Information Processing Systems, pages 7705–7716, 2018.

T. Chen, M. Li, Y. Li, M. Lin, N. Wang, M. Wang, T. Xiao, B. Xu, C. Zhang, and Z. Zhang. MXNet: A flexible and efficient machine learning library for heterogeneous distributed systems. arXiv preprint arXiv:1512.01274, 2015.

T. Cooijmans, N. Ballas, C. Laurent, Ç. Gülçehre, and A. Courville. Recurrent batch normalization. arXiv preprint arXiv:1603.09025, 2016.

K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.

G. Huang, Z. Liu, L. Van Der Maaten, and K. Q. Weinberger. Densely connected convolutional networks. In CVPR, volume 1, page 3, 2017.

I. Hubara, M. Courbariaux, D. Soudry, R. El-Yaniv, and Y. Bengio. Quantized neural networks: Training neural networks with low precision weights and activations. Journal of Machine Learning Research, 18:187–1, 2017.

Z. Huo, B. Gu, and H. Huang. Accelerated method for stochastic composition optimization with nonsmooth regularization. arXiv preprint arXiv:1711.03937, 2017.

S. Ioffe. Batch renormalization: Towards reducing minibatch dependence in batch-normalized models. In Advances in Neural Information Processing Systems, pages 1942–1950, 2017.

S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167, 2015.

K. Jarrett, K. Kavukcuoglu, M. Ranzato, et al. What is the best multi-stage architecture for object recognition? In Proceedings of the 2009 IEEE 12th International Conference on Computer Vision, pages 2146–2153, Kyoto, Japan, 2009.

J. Kohler, H. Daneshmand, A. Lucchi, M. Zhou, K. Neymeyr, and T. Hofmann. Towards a theoretical understanding of batch normalization. arXiv preprint arXiv:1805.10694, 2018.

A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, pages 1097–1105, 2012.

Y. LeCun, L. Bottou, G. B. Orr, and K.-R. Müller. Efficient backprop. In Neural Networks: Tricks of the Trade, pages 9–50. Springer, 1998.

X. Lian, M. Wang, and J. Liu. Finite-sum composition optimization via variance reduced gradient descent. arXiv preprint arXiv:1610.04674, 2016.

S. Lyu and E. P. Simoncelli. Nonlinear image representation using divisive normalization. In Computer Vision and Pattern Recognition, 2008. CVPR 2008. IEEE Conference on, pages 1–8. IEEE, 2008.

A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, and A. Lerer. Automatic differentiation in PyTorch. In NIPS-W, 2017.

T. Salimans and D. P. Kingma. Weight normalization: A simple reparameterization to accelerate training of deep neural networks. In Advances in Neural Information Processing Systems, pages 901–909, 2016.

S. Santurkar, D. Tsipras, A. Ilyas, and A. Madry. How does batch normalization help optimization? In Advances in Neural Information Processing Systems, pages 2488–2498, 2018.

D. Silver, J. Schrittwieser, K. Simonyan, I. Antonoglou, A. Huang, A. Guez, T. Hubert, L. Baker, M. Lai, A. Bolton, et al. Mastering the game of Go without human knowledge. Nature, 550(7676):354, 2017.

C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, and Z. Wojna. Rethinking the inception architecture for computer vision. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2818–2826, 2016.

C. Szegedy, S. Ioffe, V. Vanhoucke, and A. A. Alemi. Inception-v4, Inception-ResNet and the impact of residual connections on learning. In AAAI, volume 4, page 12, 2017.

D. Ulyanov, A. Vedaldi, and V. Lempitsky. Instance normalization: The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022, 2016.

M. Wang and J. Liu. A stochastic compositional gradient method using Markov samples. In Proceedings of the 2016 Winter Simulation Conference, pages 702–713. IEEE Press, 2016.

M. Wang, J. Liu, and E. Fang. Accelerating stochastic composition optimization. In Advances in Neural Information Processing Systems, pages 1714–1722, 2016.

Y. Wu and K. He. Group normalization. arXiv preprint arXiv:1803.08494, 2018.

S. Xie, R. Girshick, P. Dollár, Z. Tu, and K. He. Aggregated residual transformations for deep neural networks. In Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on, pages 5987–5995. IEEE, 2017.