NeurIPS 2020

Batch normalization provably avoids ranks collapse for randomly initialised deep networks


Review 1

Summary and Contributions: This paper studies the effect of adding batch normalization (BN) layers on the rank of the data matrix as it passes through the network. The authors prove theoretically that, in a network with BN, the rank of the data matrix does not collapse to one, as it does when the network has no BN layers. The theoretical results are accompanied by experiments that highlight the phenomenon in networks trained on and applied to both real and simulated data.
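The phenomenon summarized above can be illustrated with a minimal NumPy sketch. This is an illustration under stated assumptions, not the authors' code or their exact setting. It assumes a deep linear network with i.i.d. Gaussian weights of variance 1/d and a simplified BN that only divides each feature by its root-mean-square over the batch (no mean centering, no learned scale or shift). It also uses a numerical rank with a relative singular-value threshold of 1e-3 as a stand-in for the paper's rank measure. The helper names `numerical_rank`, `rms_batch_norm`, and `final_rank` are hypothetical.

```python
import numpy as np


def numerical_rank(h, rel_tol=1e-3):
    """Count singular values larger than rel_tol times the largest one.

    Assumed stand-in for the rank measure used in the paper.
    """
    s = np.linalg.svd(h, compute_uv=False)
    return int(np.sum(s > rel_tol * s[0]))


def rms_batch_norm(h, eps=1e-12):
    """Simplified BN (assumption): rescale each feature (row) to unit
    root-mean-square over the batch (columns). No mean centering and
    no learned scale or shift."""
    rms = np.sqrt(np.mean(h ** 2, axis=1, keepdims=True))
    return h / (rms + eps)


def final_rank(depth, d, n, use_bn, seed=0):
    """Numerical rank of the d x n hidden representation after `depth`
    randomly initialised linear layers, with or without the simplified BN."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((d, n))  # random input batch: d features, n samples
    for _ in range(depth):
        w = rng.standard_normal((d, d)) / np.sqrt(d)  # Gaussian init, variance 1/d
        h = w @ h
        if use_bn:
            h = rms_batch_norm(h)
        else:
            # Rescale by the Frobenius norm only to avoid under/overflow;
            # a global scalar factor does not change the rank.
            h = h / np.linalg.norm(h)
    return numerical_rank(h)


if __name__ == "__main__":
    for use_bn in (False, True):
        label = "with BN" if use_bn else "without BN"
        print(label, final_rank(depth=2000, d=16, n=16, use_bn=use_bn))
```

The example deliberately uses a very deep stack (2000 layers) at width 16. This makes the collapse without BN clear-cut under the strict threshold, because the gap between the top two singular values of a product of random matrices grows only slowly with depth. In this setup the run without BN should end at numerical rank one, while the BN run should stay above one. The exact BN rank depends on the seed, the width, and the threshold.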

Strengths: This paper is well motivated and grounded. While I did not check the proofs in full, the application of Markov chain theory to analyzing the representations learned by deep networks is an interesting approach and might be applicable beyond the study of BN. The paper has an easy-to-follow narrative that highlights and explains the phenomenon. The authors also consider how initialization interacts with batch normalization and the rank of the representations.

Weaknesses: While there are not many technical or experimental weaknesses in this paper, I wonder whether rank-preserving transformations are important in other learning models, such as linear models or kernel machines. It may be that this phenomenon is exclusive to deep networks and not relevant to other models. Another issue is that, in binary classification, one could still perform the task after rank collapse, as long as the relevant discriminative signal is captured by the principal direction onto which the data collapses. I would like to know whether the authors agree or disagree with this hypothesis. Finally, I do not think the authors address the overparameterized case where d > N in each layer. In that case, the controlling factor for the rank of the representations is N rather than \Omega(\sqrt{d}). How do the authors think their results adapt to this situation?

Correctness: Yes, the theory and experiments seem reasonable.

Clarity: Yes, the paper is very well presented and easy to read.

Relation to Prior Work: Yes, the authors discuss other possible explanations for the effectiveness of batch normalization and point out how their approach might fill in some gaps.

Reproducibility: Yes

Additional Feedback: Update after reading reviews and author response: The authors have addressed all my questions convincingly in their response. I hope they include their responses to these questions in the final version of the paper. I vote for acceptance.


Review 2

Summary and Contributions: The aim of the paper is to explore and evaluate the importance of batch normalization in deep neural networks initialized with random weights. This is an important question, since most DNNs are trained with this method.

Strengths:
- The paper provides useful insight into the relationship between batch normalization and the stability of the rank across layers.
- Interestingly, the authors prove that batch normalization avoids rank collapse for linear NNs, while the importance of the rank for gradient-based learning is evaluated experimentally.
- The random initialization setting is widely used and relevant to the scientific community.

Weaknesses:
- The proposed solution aspires to explain an important behavior of DNNs. Unfortunately, it does not provide a constructive approach for building new DNNs that rely on or exploit batch normalization. From this perspective, the discussion paragraph is interesting, although further in-depth analysis would be required.

Correctness: The employed methods are correct.

Clarity: The paper is generally well written.

Relation to Prior Work: The paper is well contextualized in the related literature.

Reproducibility: Yes

Additional Feedback: The overall impression of the paper is positive, even though it is not fully clear how broad the impact of the proposed work is.


Review 3

Summary and Contributions: This paper shows that BatchNorm can effectively prevent rank collapse in DNNs, both theoretically and empirically. The theoretical result holds for linear MLPs, but is also shown to be true empirically on a variety of commonly used neural nets trained on standard benchmark datasets. Rank collapse can cause gradients to become independent of the data, thus interfering with learning. In addition, the authors propose a novel pre-training method for SGD to avoid rank collapse in vanilla neural nets.

Strengths:
- The theoretical claims are sound and the experiments are extensive.
- The result is significant. There has been a line of work on understanding batch normalization. This paper points out some insufficiencies of previous work and develops new theories for the efficacy of BatchNorm. Moreover, the theoretical results are supported by experiments on standard deep neural nets.
- The contribution is novel. The pre-training step for SGD is an interesting technique to use when BatchNorm isn’t the most efficient solution for avoiding rank collapse.

Weaknesses:
- The theoretical result seems to hold only when \gamma is small.

Correctness: Yes.

Clarity: Yes.

Relation to Prior Work: Yes.

Reproducibility: Yes

Additional Feedback:
- In the warmup analysis (Lemma 6 in the appendix), \gamma is taken to be at most 1/(8Bd). How does this result apply to the overparameterized regime, when d is large? Similarly, in Theorem 14, \gamma needs to be sufficiently small, meaning that the skip connection strength is high. Does this hold in practice?
- In Appendix I, Figure 11, pre-trained SGD seems to overfit with 30 hidden layers. Does this suggest that, even though pre-trained SGD accelerates training, it might not generalize as well when the net is not very deep?
- To what extent can pre-training replace BatchNorm? Are there experiments with pre-training in architectures like VGG?