In spite of the success of deep learning, we know relatively little about the many possible solutions to which a trained network can converge. During training, a network generally converges to a local minimum of its loss function: a point in weight space where moving in any direction increases the loss. Our research explores why some local minima outperform others when a trained network is evaluated on a held-out test set.

Recent work on local minima has shed light on the loss landscape of a neural network. A loss landscape is the geometry of the network's loss function, or how the loss responds when the network's weight values are perturbed. Work in this area has investigated the connectivity of solutions in the loss landscape and the geometry of particularly high-accuracy local minima.

Researchers have shown that networks lying in a shallow basin of the loss landscape (a local minimum where small changes to the weights barely change the loss) generalize better than networks lying in a steep canyon (a local minimum where small changes to the weights change the loss dramatically). This means we can produce more accurate models by adjusting our training methods to find a network that lies at the center of a shallow basin. For example, Stochastic Weight Averaging first finds a collection of models near the edge of a shallow basin in the loss landscape. It then averages the weights from this collection of models to find a more accurate solution closer to the center of the basin. Another method, Sharpness-Aware Minimization, adds a regularization term to the optimization that is specifically designed to find wide, flat regions of the training loss.

In our research, Learning Neural Network Subspaces, we take a different approach to leveraging properties of the loss landscape to train more accurate models. Prior research shows that local minima of the loss landscape are connected by low-loss curves. Inspired by Collegial Ensembles, we train a subspace of neural networks designed to contain multiple networks with low loss. Traditionally, an individual model is trained to find a local minimum of the loss landscape. Instead, we train a line in weight space whose endpoints are parameterized by two sets of neural network weights, $w_1$ and $w_2$, as seen in Figure 1. All points on the line exhibit low loss and correspond to network weights that yield high accuracy. Our method is conceptually similar to ensembling (training multiple networks and averaging their predictions), but it can produce a prediction with only a single forward pass.

Figure 1: Standard training adjusts a single set of network weights to find a point at a local minimum of the loss landscape. Our method uses multiple sets of network weights to find a subspace of low loss.

Our approach has three major benefits over standard training. First, the subspace can be sampled to find more accurate solutions. For example, the center of the line in Figure 1 is more accurate than the traditionally trained solution, so we can produce more accurate deep learning models with a very simple change to our training procedure. Second, the approach yields models with improved calibration: the model's outputs better represent the likelihood that its predictions are correct, which is useful for sensitive or fault-tolerant applications. Third, the approach yields solutions that are more robust to noise, producing models more resilient to image corruption and mislabeled training data than standard models.

In the next sections, we provide an overview of our training method and show experimental results that evaluate its effectiveness in terms of improved accuracy, calibration, and robustness.

Training a Subspace of Networks

Our method starts by parameterizing a subspace of neural networks: a region in weight space where each point represents the weights of a neural network. We can represent a neural network with a vector $w$ that contains all of its weights; this vector represents a single point in weight space. Now, suppose we have two networks with weights $w_1$ and $w_2$ (that is, two points in weight space). The set of all points on the line from $w_1$ to $w_2$ is a one-dimensional subspace of networks, and each point on the line represents a different set of network weights. In mathematical terms, this line corresponds to the set of points $\{\alpha w_1 + (1-\alpha) w_2 : \alpha \in [0,1]\}$. We can also consider higher-dimensional subspaces by using shapes with more endpoints. With three endpoints, we can define a triangular subspace consisting of all points in the interior of the triangle. In the general case, using $m$ endpoints results in an $(m-1)$-dimensional simplex, which includes lines and triangles as special cases. We can also learn more exotic shapes such as Bezier curves.
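
To make the line parameterization concrete, here is a minimal sketch in PyTorch-style Python of interpolating between two sets of weights stored as state-dict-style dictionaries of tensors. The helper name `interpolate_weights` and the variables `w1` and `w2` are illustrative, not taken from the paper's released code.

```python
import torch

def interpolate_weights(w1, w2, alpha):
    """Return the point alpha * w1 + (1 - alpha) * w2 on the line between
    two sets of network weights, each given as a dict of tensors."""
    return {name: alpha * w1[name] + (1.0 - alpha) * w2[name] for name in w1}

# Example usage: load the interpolated weights into a model for evaluation.
# w1 and w2 are state dicts of two models with identical architecture.
# model.load_state_dict(interpolate_weights(w1, w2, alpha=0.5))
```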

So far, we’ve discussed how to mathematically define a subspace. But how do we find a subspace that contains high-accuracy neural networks? Generally, if two models are trained independently, the line that connects them in weight space doesn't contain high-accuracy solutions. This is where our special training method comes in:

  1. Create $m$ sets of network weights. For the case of a line, $m = 2$. Let $w_i$ refer to the $i^{th}$ set of weights for $i = 1, \ldots, m$.
  2. Choose a network from the interior of the simplex defined by the endpoints $w_i$. For the case of a line, this corresponds to choosing some $\alpha \in [0,1]$ and computing the weights of a new network, $w^* = \alpha w_1 + (1-\alpha) w_2$. For $m > 2$, we instead sample from the interior of a higher-dimensional simplex and construct $w^*$ as a weighted sum of the endpoints $w_i$ of the simplex.
  3. Perform a standard neural network forward pass using this network with weights $w^*$.
  4. Perform a standard neural network backward pass, backpropagating gradients to each of the $m$ endpoints $w_i$ from which $w^*$ was constructed. (A minimal code sketch of this procedure follows the list.)
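
Below is a minimal sketch of this procedure for the line case ($m = 2$) in PyTorch-style Python. It is illustrative only, not the paper's implementation: names like `training_step`, `endpoint_1`, and `endpoint_2` are hypothetical, the architecture is a placeholder, and details such as learning-rate schedules or any regularization on the endpoints are omitted. It uses `torch.func.functional_call` (available in recent PyTorch versions) to run the template module with interpolated weights so that gradients reach both endpoints.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call

# Template architecture (placeholder) and two sets of endpoint weights.
# In practice the two endpoints would typically be initialized independently.
model = torch.nn.Linear(32, 10)
endpoint_1 = {k: v.detach().clone().requires_grad_(True)
              for k, v in model.named_parameters()}
endpoint_2 = {k: (v.detach().clone() + 0.01 * torch.randn_like(v)).requires_grad_(True)
              for k, v in model.named_parameters()}
optimizer = torch.optim.SGD(list(endpoint_1.values()) + list(endpoint_2.values()), lr=0.1)

def training_step(x, y):
    # Step 2: sample a point on the line between the two endpoints.
    alpha = torch.rand(()).item()
    w_star = {k: alpha * endpoint_1[k] + (1.0 - alpha) * endpoint_2[k]
              for k in endpoint_1}
    # Step 3: standard forward pass with the interpolated weights.
    logits = functional_call(model, w_star, (x,))
    loss = F.cross_entropy(logits, y)
    # Step 4: standard backward pass; gradients flow back to both endpoints.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# x, y = next(iter(train_loader))   # hypothetical data loader
# training_step(x, y)
```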

Our method encourages all points inside the region defined by the endpoints $w_i$ to correspond to high-accuracy solutions. Let’s look at a plot that verifies this. Figure 2 shows the test accuracy of cResNet20 models on the CIFAR10 dataset, which contains small images from 10 categories. We show the case of training a neural network subspace with $m = 3$ endpoints. The triangular region defined by these endpoints contains high-accuracy solutions; outside this region, accuracy diminishes.

Figure 2: The loss landscape of our learned subspace on the CIFAR10 dataset. The triangular region contains high-accuracy network weights.

That covers our discussion of what it means to train a subspace. Fundamentally, the discovery that we can train regions of weight space that contain accurate solutions furthers our understanding of the structure of the loss landscape. Training a subspace also has practical benefits for model accuracy, calibration, and robustness, which we examine next.

Improved Model Accuracy

First, we consider the case of learning a one-dimensional subspace and compare the result to standard training. Figure 3 shows the performance of our method when training a line or a Bezier curve. When training a line, we consider two cases. The first is the standard case described above. The second is slightly modified: we sample a separate parameter $\alpha$ for each layer’s weights during training but still use a single $\alpha$ when testing the model. We call this method the layer-wise line. We compare the line, the layer-wise line, and the Bezier curve to the results obtained by standard training.
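
As a rough illustration of the layer-wise variant, the sketch below draws an independent $\alpha$ for each parameter tensor rather than a single $\alpha$ for the whole network. It reuses the hypothetical endpoint dictionaries from the earlier training sketch; the resulting `w_star` could replace the single-alpha interpolation inside `training_step`.

```python
import torch

def sample_layerwise_weights(endpoint_1, endpoint_2):
    """Interpolate each parameter tensor with its own independently sampled alpha."""
    w_star = {}
    for name in endpoint_1:
        alpha = torch.rand(()).item()   # one alpha per parameter tensor
        w_star[name] = alpha * endpoint_1[name] + (1.0 - alpha) * endpoint_2[name]
    return w_star
```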

Figure 3 shows that our method produces solutions that exceed the accuracy of a network trained in the standard way. Training a subspace and taking the model at the center of the subspace allows us to exceed the accuracy of an independently trained model. This is remarkable because the procedure used to learn a subspace identifies a more accurate model than standard training does, even if we discard the rest of the subspace and retain only its midpoint after training. Thus, we can train more accurate models with no additional computational cost at inference time.

Figure 3: Accuracy of our method on the CIFAR10 dataset at different points along our neural network subspace, comparing standard training with the line, layer-wise line, and Bezier curve variants.

What happens if we move beyond this one-dimensional subspace and experiment with a higher-dimensional simplex? Can we further improve accuracy? We investigate this in Figure 4, which shows the accuracy of the model at the center of our subspace for $m > 1$. We compare our results to Stochastic Weight Averaging (SWA), which averages several models that converged near the same minimum in an effort to find a model closer to the center of a shallow basin in the loss landscape.
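
For $m > 2$, step 2 of the training procedure needs a random point in the interior of an $(m-1)$-dimensional simplex, and the midpoint used at test time is simply the average of the endpoints. One standard way to sample uniformly from a simplex is to draw coefficients from a symmetric Dirichlet distribution; the sketch below uses the same hypothetical endpoint dictionaries as before, and the paper's exact sampling scheme may differ.

```python
import torch

def sample_simplex_weights(endpoints):
    """Sample a random convex combination of m endpoint weight dictionaries.

    Coefficients drawn from Dirichlet(1, ..., 1) are uniform over the
    (m-1)-dimensional simplex.
    """
    m = len(endpoints)
    coeffs = torch.distributions.Dirichlet(torch.ones(m)).sample()
    return {name: sum(c * ep[name] for c, ep in zip(coeffs, endpoints))
            for name in endpoints[0]}

def simplex_midpoint(endpoints):
    """The center of the simplex: the average of all m endpoints."""
    m = len(endpoints)
    return {name: sum(ep[name] for ep in endpoints) / m for name in endpoints[0]}
```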

In some cases, we can improve accuracy by using larger subspaces, but most of the benefit is achieved early on. This means we don’t need to worry about training an enormous subspace to find the most accurate model (a few endpoints suffice) or unduly increasing computational burden by maintaining too many copies of network weights.

Figure 4: Performance on the CIFAR10 dataset of the model at the center of a simplex as the number of endpoints $m$ increases, comparing standard training, our simplex midpoint and layer-wise simplex midpoint, and SWA with a high constant or cyclic learning rate.

Improved Model Calibration

Next, we investigate the calibration of our model: how well the model’s outputs reflect the probability that its predictions are correct. To better understand the notion of calibration, consider that classification models typically produce an output score for each possible class. The scores are all between 0 and 1 and sum to 1, so they are often interpreted as the probabilities that the image belongs to each class. This is appealing for applications in which we care about how certain the model is. For example, if a sensitive application only wants to take action when the model is certain of its prediction, it can examine the model’s output to estimate the probability that the model is correct.

How can we tell whether these confidences truly reflect the probability that the model is correct? Expected Calibration Error measures this. To compute this metric, we evaluate the network’s outputs on the test set and bucket the predicted confidences into several bins to build a histogram. Then, for each bin, we measure the difference between the average confidence in that bin and the fraction of correct predictions in that bin.
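
Here is a minimal sketch of this computation, a simplified equal-width-bin version of expected calibration error; the inputs `confidences`, `predictions`, and `labels` are hypothetical one-dimensional tensors.

```python
import torch

def expected_calibration_error(confidences, predictions, labels, n_bins=15):
    """Simplified ECE: average |accuracy - confidence| over equal-width bins,
    weighted by the fraction of samples falling in each bin."""
    bin_edges = torch.linspace(0.0, 1.0, n_bins + 1)
    correct = predictions.eq(labels).float()
    ece = torch.zeros(())
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            bin_weight = in_bin.float().mean()
            bin_acc = correct[in_bin].mean()
            bin_conf = confidences[in_bin].mean()
            ece += bin_weight * (bin_acc - bin_conf).abs()
    return ece.item()

# confidences, predictions = probs.max(dim=1)   # probs from a softmax over logits
# ece = expected_calibration_error(confidences, predictions, labels)
```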

Figure 5 shows the expected calibration error of our model compared to baseline methods. The dropout baseline corresponds to training with Dropout, which randomly zeroes out a portion of the network’s activations during training and can improve calibration. Label Smoothing corresponds to softening the training labels when optimizing a network to reduce overconfidence and overfitting. Our method combined with label smoothing provides the lowest calibration error, so our model’s outputs more closely track the probability that its predictions are correct.
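
For reference, here is a minimal sketch of label smoothing as it is commonly implemented (not necessarily the exact formulation used in our experiments): each one-hot target is mixed with a uniform distribution over the classes.

```python
import torch
import torch.nn.functional as F

def label_smoothing_loss(logits, labels, epsilon=0.1):
    """Cross-entropy against smoothed targets:
    (1 - epsilon) * one_hot(label) + epsilon / num_classes."""
    num_classes = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    one_hot = F.one_hot(labels, num_classes).float()
    smoothed = (1.0 - epsilon) * one_hot + epsilon / num_classes
    return -(smoothed * log_probs).sum(dim=-1).mean()

# Recent PyTorch versions expose the same behavior directly:
# F.cross_entropy(logits, labels, label_smoothing=0.1)
```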

Figure 5: Expected calibration error of our models on the CIFAR10 dataset compared to baselines, including standard training, Dropout, label smoothing, SWA with a high constant learning rate, and our simplex midpoint variants (with and without label smoothing).

Improved Model Robustness

Next, we investigate the robustness of our model. Ideally, a model performs well on images even if they suffer from a reasonable amount of corruption. Robustness helps models perform better in the real world, where unexpected events like a smudged camera lens affect the quality of the model's inputs.

The TinyImageNet-C dataset provides corrupted versions of images from the TinyImageNet dataset, a downscaled subset of ImageNet. In our evaluation, we train our models on clean images, then test them on TinyImageNet-C under the “snow” corruption at varying severity levels. Figure 6 shows our results. We find that our model outperforms standard training.

Figure 6: TinyImageNet-C performance of our models compared to baselines (standard training and a standard ensemble of two networks) under the “snow” corruption, including our line and layer-wise line variants, each with and without ensembling. We find that our model performs well even at high levels of image corruption, indicating good robustness.

Figure 7 shows our model’s robustness to a different type of corruption: inaccurate training labels. Robustness to label noise is useful because many real-world datasets suffer from this problem. During training, we changed 20 percent of the training labels to random, incorrect values; the test set labels remained unchanged, and no other part of the training formulation was altered. We found that our layer-wise training method outperformed baseline methods for robustness against corrupted labels.
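
As a sketch of how this kind of label corruption can be simulated (an illustration, not our exact experimental code), the hypothetical helper below flips a chosen fraction of labels to a different, random class.

```python
import torch

def corrupt_labels(labels, num_classes, fraction=0.2, seed=0):
    """Replace a random `fraction` of labels with a different, random class."""
    g = torch.Generator().manual_seed(seed)
    labels = labels.clone()
    n = labels.numel()
    idx = torch.randperm(n, generator=g)[: int(fraction * n)]
    # Shift by a random offset in [1, num_classes - 1] so the new label differs.
    offsets = torch.randint(1, num_classes, (idx.numel(),), generator=g)
    labels[idx] = (labels[idx] + offsets) % num_classes
    return labels
```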

Figure 7: Robustness to corruption of training labels, comparing standard training (with and without optimal early stopping), Dropout, label smoothing, SWA with a high constant learning rate, and our simplex midpoint variants. Our layer-wise method achieves the strongest accuracy, outperforming other methods for improving robustness.

Conclusion

We presented a method for training subspaces of neural networks: regions in weight space containing high-accuracy solutions. Our simple and effective training method can be used to obtain a model that outperforms various baseline approaches. First, our model is more accurate: it is more likely to make correct predictions. Second, our model is better calibrated: its confidence better reflects the probability that its predictions are correct. Finally, our model is more robust: it performs better when exposed to image and labeling corruption or noise. These advances help us produce better and more reliable machine learning models for users. We look forward to continued exploration of the loss landscape and to discovering new ways to leverage its properties to improve neural networks.

Acknowledgements

We acknowledge the contributions of Mitchell Wortsman, Maxwell Horton, Carlos Guestrin, Ali Farhadi, and Mohammad Rastegari.

References

Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255, 2009. doi: 10.1109/CVPR.2009.5206848. [link].

Foret, P., Kleiner, A., Mobahi, H., and Neyshabur, B. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of dnns. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, NIPS’18, pp. 8803–8812, Red Hook, NY, USA, 2018. Curran Associates Inc.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. Proceedings of the International Conference on Learning Representations, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization, 2019.

Krizhevsky, A. Learning multiple layers of features from tiny images. Technical report, 2009.

Krizhevsky, A., Sutskever, I., and Hinton, G. E. ImageNet classification with deep convolutional neural networks. In Pereira, F., Burges, C. J. C., Bottou, L., and Weinberger, K. Q. (eds.), Advances in Neural Information Processing Systems, volume 25. Curran Associates, Inc., 2012.

Li, H., Xu, Z., Taylor, G., Studer, C., and Goldstein, T. Visualizing the loss landscape of neural nets. In Bengio, S., Wallach, H., Larochelle, H., Grauman, K., Cesa-Bianchi, N., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018. [link].

Littwin, E., Myara, B., Sabah, S., Susskind, J., Zhai, S., and Golan, O. Collegial ensembles. 2020. [link].

Liu, Y. and Zhang, J. Deep Learning in Machine Translation, pp. 147–183. Springer Singapore, Singapore, 2018. ISBN 978-981-10-5209-5. doi: 10.1007/978-981-10-5209-5_6.

Müller, R., Kornblith, S., and Hinton, G. When does label smoothing help?, 2020.

Nixon, J., Dusenberry, M., Jerfel, G., Zhang, L., and Tran, D. Measuring calibration in deep learning, 2020. [link].

Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research, 15(1):1929–1958, 2014.
