The Deep Optimization Cookbook
Get those models to converge with one weird trick.
At the end of class yesterday, Konpat Preechakul argued that neural nets aren’t that easy to optimize. And he’s right! Neural net optimization still needs a lot of tricks, and if you try to code these things up from scratch, you’ll run into roadblocks. In 2015, I tried to implement AlexNet on CIFAR-10 in the newly released TensorFlow, and even with its high-level language, getting all of the fine details of initialization and normalization correct was a headache. Part of the issue was that no one had implemented Alex’s “Local Response Normalization” function. Part of the issue was the parameter list from Alex’s repository:
Ugh. Having a GitHub repo that you can just slightly modify makes a huge difference when you’re diving in and trying to get neural nets to work on the problems you actually care about.
Without these repos, you have to learn how to initialize the weights correctly from heuristic guidelines. And there are still issues with backpropagation that vary with your chosen architecture. You might get “dead ReLUs,” where units sit entirely in the part of the activation that equals zero and stop passing any gradient. You can get vanishing gradients, where the product of a long chain of matrices shrinks to zero. Long chains of multiplication can also lead to exploding gradients. You add widgets to get around these, like gradient clipping, batch norm, and resnets.
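To make the “long chains of matrices” point concrete, here is a toy sketch, not a real network’s backward pass: it pushes a unit vector through a stack of random linear layers. The function `product_norm` and its parameters `gain`, `depth`, and `width` are hypothetical names for illustration. Each layer’s Gaussian entries have variance `gain**2 / width`, so each multiplication rescales the squared norm by about `gain**2` on average. Over many layers, the norm collapses when `gain < 1` and blows up when `gain > 1`.

```python
import numpy as np

def product_norm(gain, depth=50, width=64, seed=0):
    """Norm of a unit vector after `depth` random linear layers (toy model).

    Each layer is a width x width Gaussian matrix with entries of variance
    gain**2 / width, so on average each multiplication scales the squared
    norm by gain**2.
    """
    rng = np.random.default_rng(seed)
    v = np.ones(width) / np.sqrt(width)  # unit-norm starting vector
    for _ in range(depth):
        W = rng.normal(scale=gain / np.sqrt(width), size=(width, width))
        v = W @ v
    return np.linalg.norm(v)

for gain in (0.5, 1.0, 2.0):
    print(gain, product_norm(gain))
```

With 50 layers, a gain of 0.5 shrinks the norm by roughly $0.5^{50}$ and a gain of 2 grows it by roughly $2^{50}$. Only a gain near 1 keeps the signal at a usable scale, which is one way to read what careful initialization schemes are after.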
But these inventions all aim to make optimization easier. And I think there’s a unifying optimization motivation behind all of these that also explains why larger models are preferable.
Tell me if you buy this…
All we care about in machine learning is getting training loss small. We care about the predictions of the model getting close to the labels of the training data. We do not care about the parameters of the model at all. The parameters don’t mean anything. Let us not forget that machine learning is nonparametric prediction. So what happens when we study the convergence of the predictions rather than of the weights?
Let’s consider a generic nonlinear least squares problem

$$\min_{w}\ \tfrac{1}{2}\,\|F(w) - y\|^2.$$
If we run gradient descent with step size $s$, our iteration is

$$w_{t+1} = w_t - s\,J_t^\top\big(F(w_t) - y\big).$$
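The gradient of $\tfrac{1}{2}\|F(w) - y\|^2$ is $J^\top(F(w) - y)$, and that is the only fact the iteration relies on. Here is a minimal sketch that checks it numerically. It assumes a hypothetical toy model, not anything from the post: $F_i(w) = \tanh(x_i^\top w)$, with one prediction per data point $x_i$ (a row of `X`). The helpers `F`, `jacobian`, `loss`, `gradient`, and `gd_step` are illustrative names.

```python
import numpy as np

# Hypothetical toy model: F_i(w) = tanh(x_i . w), one prediction per row of X.
def F(w, X):
    return np.tanh(X @ w)

def jacobian(w, X):
    # dF_i/dw_j = (1 - tanh(x_i . w)^2) * X[i, j]; shape (n, d)
    return (1.0 - np.tanh(X @ w) ** 2)[:, None] * X

def loss(w, X, y):
    r = F(w, X) - y
    return 0.5 * r @ r

def gradient(w, X, y):
    # Gradient of 0.5 * ||F(w) - y||^2 is J^T (F(w) - y).
    return jacobian(w, X).T @ (F(w, X) - y)

def gd_step(w, X, y, s):
    # One gradient descent step with step size s.
    return w - s * gradient(w, X, y)

rng = np.random.default_rng(0)
n, d = 5, 20
X = rng.normal(size=(n, d)) / np.sqrt(d)
y = rng.uniform(-0.5, 0.5, size=n)
w = rng.normal(size=d)

# Central finite differences give an independent check of the gradient formula.
eps = 1e-6
fd = np.array([(loss(w + eps * e, X, y) - loss(w - eps * e, X, y)) / (2 * eps)
               for e in np.eye(d)])
print(np.max(np.abs(fd - gradient(w, X, y))))
```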
Here $J_t$ is the Jacobian of $F$ at $w_t$. In neural nets, each component of $F$ is the prediction for one of the data points. The corresponding coordinate in $y$ is the label of that data point.
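With this in mind, let $p_t$ denote the vector of predictions of $y$ at iteration $t$. That is, $p_t = F(w_t)$. We can use Taylor’s theorem to track how the predictions change over time:

$$p_{t+1} = p_t + J_t\,(w_{t+1} - w_t) + \epsilon_t,$$

where the remainder $\epsilon_t$ is second order in the step $w_{t+1} - w_t$. If we combine this formula with the definition of gradient descent, we find

$$p_{t+1} - y = \big(I - s\,J_t J_t^\top\big)(p_t - y) + \epsilon_t.$$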
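The combined formula says one gradient step moves the residual by $p_{t+1} - y = (I - s J_t J_t^\top)(p_t - y) + \epsilon_t$, where $\epsilon_t$ is a Taylor remainder of order $s^2$. The sketch below checks this on the same hypothetical toy model, $F_i(w) = \tanh(x_i^\top w)$. It computes $\epsilon_t$ exactly for a step size $s$ and for $s/2$. If the remainder really is second order, halving $s$ should cut it by about a factor of four. The helper `remainder` is an illustrative name.

```python
import numpy as np

# Hypothetical toy model: F_i(w) = tanh(x_i . w).
def F(w, X):
    return np.tanh(X @ w)

def jacobian(w, X):
    return (1.0 - np.tanh(X @ w) ** 2)[:, None] * X

def remainder(w, X, y, s):
    """eps_t = (p_{t+1} - y) - (I - s J J^T)(p_t - y) for one GD step of size s."""
    J = jacobian(w, X)
    p = F(w, X)
    w_next = w - s * J.T @ (p - y)                        # gradient descent step
    linearized = (np.eye(len(y)) - s * J @ J.T) @ (p - y)  # first-order prediction
    return (F(w_next, X) - y) - linearized

rng = np.random.default_rng(1)
n, d = 5, 20
X = rng.normal(size=(n, d)) / np.sqrt(d)
y = rng.uniform(-0.5, 0.5, size=n)
w = rng.normal(size=d)  # nonzero w, so the second-order term of tanh is active

r1 = np.linalg.norm(remainder(w, X, y, 1e-3))
r2 = np.linalg.norm(remainder(w, X, y, 5e-4))
print(r1 / r2)  # close to 4, consistent with a remainder of order s^2
```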
This final formula says that as long as the matrix $I - s\,J_t J_t^\top$ is contractive (every eigenvalue of $s\,J_t J_t^\top$ lies strictly between 0 and 2) and the Taylor remainder stays small, the predictions should get closer to the labels. If you run long enough, the predictions converge to the labels. I have no idea where the weights go. But if you can keep the Jacobian matrices well conditioned, you’ll get your training error to zero.
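Here is a small numerical illustration of the contraction argument. It is a sketch on a hypothetical toy problem, not a neural net. There are $n = 5$ data points and $d = 20$ parameters, with predictions $F_i(w) = \tanh(x_i^\top w)$. The labels are drawn from $(-0.5, 0.5)$, inside the range of tanh. Because $d > n$, zero training error is attainable. At every step the code records the residual norm $\|p_t - y\|$ and the spectral norm of $I - s J_t J_t^\top$. The function `run_gd` and the values `s=0.3` and `steps=3000` are illustrative choices.

```python
import numpy as np

# Hypothetical toy problem: F_i(w) = tanh(x_i . w), n = 5 points, d = 20 parameters.
def F(w, X):
    return np.tanh(X @ w)

def jacobian(w, X):
    return (1.0 - np.tanh(X @ w) ** 2)[:, None] * X

def run_gd(X, y, s=0.3, steps=3000):
    w = np.zeros(X.shape[1])
    residuals, factors = [], []
    for _ in range(steps):
        J = jacobian(w, X)
        r = F(w, X) - y                       # p_t - y
        M = np.eye(len(y)) - s * J @ J.T      # I - s J_t J_t^T
        factors.append(np.linalg.norm(M, 2))  # spectral norm; < 1 means contractive
        residuals.append(np.linalg.norm(r))
        w = w - s * J.T @ r                   # gradient descent step
    return w, np.array(residuals), np.array(factors)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 20)) / np.sqrt(20)
y = rng.uniform(-0.5, 0.5, size=5)
w, residuals, factors = run_gd(X, y)
print(residuals[0], residuals[-1], factors.max())
```

The residual shrinks toward zero while every recorded contraction factor stays below one. Nothing in the code tracks where `w` ends up. It only needs $J_t J_t^\top$ to stay nonsingular and not too large relative to $1/s$.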
Prediction convergence suggests many of the tricks of the neural trade. In an ideal world, you’d like the Jacobian to have rank $n$. For this, you need a model with at least $n$ parameters, since $J_t$ has one column per parameter and its rank can’t exceed that count. You also want to make sure the Jacobian doesn’t become degenerate as you move. So you’ll have to design clever initializations. You’ll need tricks to battle vanishing gradients. Almost all of the tricks I discussed above, be they resnets or gradient clipping, can be interpreted as ways of preventing degenerate Jacobians. I think writing this out is helpful.
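The parameter-count requirement follows from the shapes. $J_t$ is $n \times d$, so $J_t J_t^\top$ has rank at most $\min(n, d)$. When $d < n$, it has a zero eigenvalue. Then $I - s J_t J_t^\top$ has an eigenvalue equal to 1 and cannot shrink the residual in that direction. The sketch below uses a random Gaussian matrix as a hypothetical stand-in for a generic Jacobian. The function name `smallest_eigenvalue_of_JJt` is illustrative.

```python
import numpy as np

def smallest_eigenvalue_of_JJt(n, d, seed=0):
    """Smallest eigenvalue of J J^T for a random n x d stand-in Jacobian."""
    rng = np.random.default_rng(seed)
    J = rng.normal(size=(n, d))            # hypothetical generic Jacobian
    return np.linalg.eigvalsh(J @ J.T).min()

n = 10
for d in (5, 10, 40):
    print(d, smallest_eigenvalue_of_JJt(n, d))
```

With $d = 5 < n$ the smallest eigenvalue is zero up to rounding. With $d = 40$ it is comfortably positive, and the extra parameters are what keep it away from zero.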
For optimization researchers, it’s clarifying to capture the practical intuition of neural net engineers. If you have an overparameterized model and keep the Jacobian away from degeneracy, gradient descent on nonlinear least squares will force your predictions to converge to your labels. That doesn’t feel deeply profound, but only recently has it become conventional wisdom. The most important lesson I have learned in the last ten years is that we shouldn’t fear overparameterized models. In hindsight, this is blindingly obvious. But it wasn’t obvious to anyone in 2007. It wasn’t obvious to me before 2016. And it’s worth expanding on this argument now and trying to figure out why it wasn’t obvious sooner.
A brief postscript that I couldn’t figure out how to work cleanly into the post.
I learned this convergence argument from reading papers about Neural Tangent Kernels. I think I first saw it in this paper led by Simon Du. If you initialize your weights at random, $J_0 J_0^\top$ is some random matrix. It is an $n \times n$ positive semidefinite matrix, and positive definite whenever $J_0$ has rank $n$. Where have we seen those kinds of matrices before? Oh, look, it’s a kernel matrix. Lo and behold, this is the “neural tangent kernel” that people love to wax philosophical about. For better or for worse, I don’t think the neural tangent kernel tells us all that much about what neural nets do. You can move very far away from this initialization during training. But NTKs are valuable because they let you make the convergence argument I sketch here rigorous. Is that useful? You tell me.
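To see that $J_0 J_0^\top$ really is a kernel matrix, here is a sketch for a hypothetical two-layer network: $f(x) = a^\top \tanh(Wx)/\sqrt{m}$, with $W$ of size $m \times d$ and $a$ of length $m$. The $1/\sqrt{m}$ scaling is a common choice in NTK papers, assumed here rather than taken from the post. The code builds $J_0$ at a random initialization, forms $K = J_0 J_0^\top$, and compares it against the same kernel written as an explicit sum over hidden units. The function `jacobian_at_init` is an illustrative name.

```python
import numpy as np

# Hypothetical two-layer net: f(x) = a . tanh(W x) / sqrt(m).
def jacobian_at_init(X, W, a):
    """Row i is the gradient of f(x_i) w.r.t. all parameters (W flattened, then a)."""
    m = len(a)
    H = np.tanh(X @ W.T)                                    # (n, m) hidden activations
    dH = 1.0 - H ** 2                                       # tanh' at each hidden unit
    # df/dW[k, j] = a_k * tanh'(w_k . x) * x_j / sqrt(m)
    dW = (a * dH)[:, :, None] * X[:, None, :] / np.sqrt(m)  # (n, m, d)
    da = H / np.sqrt(m)                                     # df/da_k, shape (n, m)
    return np.concatenate([dW.reshape(len(X), -1), da], axis=1)

rng = np.random.default_rng(0)
n, d, m = 6, 3, 200
X = rng.normal(size=(n, d))
W = rng.normal(size=(m, d))
a = rng.normal(size=m)

J0 = jacobian_at_init(X, W, a)
K = J0 @ J0.T  # empirical neural tangent kernel, n x n

# The same kernel as an explicit sum over hidden units k:
# K_ij = (1/m) sum_k [a_k^2 tanh'(w_k.x_i) tanh'(w_k.x_j) (x_i.x_j) + tanh(w_k.x_i) tanh(w_k.x_j)]
H = np.tanh(X @ W.T)
dH = 1.0 - H ** 2
K_explicit = ((a ** 2 * dH) @ dH.T * (X @ X.T) + H @ H.T) / m
print(np.linalg.eigvalsh(K).min())
```

Since $K$ is a Gram matrix of gradient vectors, it is symmetric and positive semidefinite by construction. The explicit formula shows it is an average of per-unit kernels evaluated at the random initialization.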