NeurIPS 2020

The Pitfalls of Simplicity Bias in Neural Networks


Meta Review

The work studies the inductive bias implicitly implemented by DNNs trained with SGD. To do so, the authors introduce several synthetic toy and image-based datasets in which the notion of a "simple feature" is made precise. By training DNNs on these datasets and analyzing the resulting models, the authors establish several curious observations: (Observation 1) when there is *one* simple feature and *many* less simple features, and each feature alone is predictive enough to allow 100% test accuracy, DNNs end up using *only* the simple feature; (Observation 2) moreover, when there is *one* simple feature and *many* less simple features such that the simplest feature allows only 90% test accuracy while each of the less simple features allows 100% test accuracy, DNNs still end up using only the simplest feature (and thus reach only 90% test accuracy). These conclusions seem to hold across a range of model architectures and training scenarios. In addition, the authors provide a theoretical result (Theorem 1) proving this simplicity bias for 1-hidden-layer DNNs trained with hinge loss on one of their specific synthetic distributions.

Overall, the reviewers agree that the topic of this paper is timely. The implicit bias of DNNs is one of the crucial, and still poorly understood, ingredients of the "Deep Learning Phenomena" field. I believe that any step towards advancing our understanding of this core problem will be very useful in the future, and I think the paper does take such steps: (a) many reviewers found the CIFAR/MNIST experiments interesting, and (b) the authors explicitly demonstrate that, for their specific problems, DNNs indeed exhibit strong biases, and that these biases are quite surprising (see, for instance, Observation 2). That said, the reviewers raised several concerns. I require the authors to address them and revise the draft accordingly.
The following list covers the main concerns; the authors should also refer to the detailed reviews to address the remaining minor ones.

(1) In many places, the writing of the paper is overly bold. The authors repeatedly use strong words such as "explain", and in some places they speculate about implications for practical settings. The authors should clarify where the results are context-specific. They should also avoid speculating about robustness: they seem to claim that simplicity bias is the main reason behind non-robustness, but this claim is not supported at all. The authors should consider moving Section 4.3 to the "Discussions" section and clearly emphasize that the argument there is speculative.

(2) The authors should clearly state that for any given bias one can construct a problem where this bias hurts; this is essentially the argument of Rev#3. It is therefore dangerous to conclude that "the bias hurts generalization of DNNs in natural tasks" from the fact that "the bias hurts DNNs in specially designed synthetic problems". The authors have no evidence that the implicit bias of DNNs (whatever it is) hurts their generalization when they are trained on, say, natural image datasets. Moreover, for natural images, the concept of a "feature" in "simple feature" is not clear: is it a pixel? A patch? The output of a filter?

(3) Rev#2 raised several good questions. First, the authors seem to blame mainly SGD for the simplicity bias, but the DNN architecture and the loss function may also play a role. Second, it would be valuable to discuss the influence of the initialization scheme on the observed simplicity bias, specifically the role of the initialization scale (i.e., the magnitude of the weights at the beginning of training). It is possible that with a larger scale, DNNs would pick up the "complex" features.
In addition, I ask the authors to clarify the role of $d$ in Theorem 1 and, more generally, throughout the paper. (a) The proof of Theorem 1 assumes at some point (see, for instance, Lemma F3) that $d > \exp((8c/\eta)^2)$. Even with the unrealistically large learning rate $\eta = 1$, this gives $d > \exp((8 \cdot 10)^2)$. Why is this *extremely strong* assumption not mentioned in the main part of the paper? (b) Moreover, all experiments in the paper seem to use a *large number* of features (50). So, in my view, the empirical and theoretical evidence of this paper supports the claim that "when there are *many* complex features and *one* simple feature, a DNN picks up the simple one", with the emphasis on *many*. What happens if we have only 2 features, i.e. $d = 2$? In this case, Theorem 1 becomes vacuous. What do the empirical results show? I ask the authors to explicitly address this question in the revision.

I feel that all the requirements listed above are realistic, and I believe that addressing them will significantly improve the presentation of the paper. I am inclined to trust the authors to implement these modifications. Apart from these issues, I feel this paper deserves to be published.
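The arithmetic below shows the scale of the dimension assumption in point (a). It takes $c = 10$, which is the value implied by the substitution $(8 \cdot 10)^2$ at $\eta = 1$. That constant is an assumption carried over from the text above and has not been re-derived from Lemma F3. The second line asks what the same condition would demand for the 50 features used in the experiments.

$$
\frac{8c}{\eta} = \frac{8 \cdot 10}{1} = 80
\quad\Longrightarrow\quad
d > \exp\!\left(80^{2}\right) = e^{6400} = 10^{6400/\ln 10} \approx 10^{2779.5}
$$

$$
d = 50 > \exp\!\left(\left(\frac{8c}{\eta}\right)^{2}\right)
\iff \left(\frac{8c}{\eta}\right)^{2} < \ln 50 \approx 3.91
\iff \frac{8c}{\eta} < 1.98
$$

With $c = 10$, the second condition would require $\eta > 40$. The exponent also scales as $1/\eta^2$, so any learning rate below 1 makes the required $d$ even larger. For example, $\eta = 0.1$ gives $d > e^{640000}$.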