7 Comments

This is not a good conceptual model for neural net optimization. I am 100% positive that if you take anything approaching a real neural net and train it on anything approaching a real dataset using gradient descent at anything approaching a real learning rate, the optimization dynamics will quickly become non-contractive. Assuming square loss, the spectral norm of the NTK JJ' will grow until reaching the value at which the map stops being contractive, and will then stop growing. The dynamics during this phase of training will not be locally linear over any timescale, no matter how short.

But you shouldn’t believe a word that I (or anyone else) says — you should run this experiment for yourself on your own net and data, so that you can see with your own eyes that what I am saying is true.

author

How are you 100% positive having never run an experiment?


To be clear, I’ve spent my entire PhD waking up and spending the whole day running variants of this specific experiment on different networks and different datasets. At no point have I ever observed any outcome other than the one I have predicted you will observe. Furthermore, many other people have run this experiment on their own networks and datasets and all of them have had the same experience.

I would be happy to make a monetary bet that you will observe the same outcome as well, on YOUR OWN network and dataset. (Where we apply the “reasonable ML researcher standard” to the three “anything approaching a real” qualifications in my claim. You can subsample the data to make gradient descent training practical.)

author

What does "Assuming square loss, the spectral norm of the NTK JJ’ will grow until reaching the specific value at which the map is not contractive, and will then stop growing." mean? You mean you find eigenvalues equal to 1?


If you train using deterministic (full-batch) gradient descent with the square loss, the spectral norm of JJ' will increase until reaching 2/s, where s is the learning rate. After reaching 2/s, the spectral norm of JJ' will flatline at 2/s or oscillate tightly around this value. Equivalently, the spectral norm of the GD state transition matrix [I - s J J'] will flatline at 1 or oscillate around this value. (To first order, a GD step on the square loss maps the residual vector f - y to [I - s J J'](f - y), which is why this matrix is exactly the one that determines whether the dynamics are contractive.)
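
For concreteness, here's a minimal sketch of the kind of experiment I mean (the tiny MLP, random data, and hyperparameters below are placeholders, not my actual setup): train with full-batch GD on the square loss and print the top eigenvalue of JJ' as training proceeds, alongside the threshold 2/s.

```python
# Sketch: track the top NTK eigenvalue during full-batch gradient descent.
# Assumptions: scalar-output MLP, random data, square loss summed over examples.
import torch

torch.manual_seed(0)
n, d, s = 32, 5, 0.01                      # n examples, input dim d, learning rate s
X, y = torch.randn(n, d), torch.randn(n)

model = torch.nn.Sequential(
    torch.nn.Linear(d, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1)
)
params = list(model.parameters())

def top_ntk_eig():
    # Row i of J is the gradient of the network output f(x_i) w.r.t. the parameters.
    rows = []
    for i in range(n):
        out = model(X[i:i + 1]).squeeze()
        grads = torch.autograd.grad(out, params)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    J = torch.stack(rows)                              # n x p
    return torch.linalg.eigvalsh(J @ J.T)[-1].item()   # top eigenvalue of the NTK JJ'

for step in range(5000):
    model.zero_grad()
    loss = 0.5 * ((model(X).squeeze() - y) ** 2).sum()  # square loss
    loss.backward()
    with torch.no_grad():
        for p in params:
            p -= s * p.grad                              # deterministic GD step
    if step % 250 == 0:
        print(f"step {step:5d}  loss {loss.item():.4f}  "
              f"top eig of JJ': {top_ntk_eig():.1f}  (2/s = {2 / s:.1f})")
```

On a toy problem this small the loss may converge before the threshold is reached, which is exactly why my claim is qualified with "anything approaching a real" net, dataset, and learning rate; but this same measurement, run on your own setup, is the experiment I'm proposing.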

There are 115 figures in my paper here which exhibit an analogous phenomenon: https://arxiv.org/abs/2103.00065 (this is in terms of the Hessian rather than the NTK).

If you want to see a version of this phenomenon for the NTK matrix, you could look at Figures 1 and 2 here: https://arxiv.org/abs/2207.12678.

Anyway, I already have code to compute the top eigenvalue of the NTK, so if you give me PyTorch code for a neural net and a training dataset, I would be happy to verify my claim!

author

How do you compute the top eigenvalue of the NTK? Can you show me this code?

Sep 22, 2023

The naive way is to explicitly form the NTK by computing the pairwise inner products between the network gradients on every pair of training data points. This is practical if the dataset is small enough.

A more efficient way is to compute the eigenvalues of the Gauss-Newton matrix J'J (which has the same nonzero eigenvalues as JJ') using the Lanczos method, which only requires matrix-vector products. I have some horrible grad student research code here https://pastebin.com/zfndXTKg, but be warned that it really is horrible and also depends on a library called BackPACK (https://backpack.pt/) to compute gradients in a batched manner; a dependency-light sketch of the same idea is below.

So I would probably recommend either just computing the NTK matrix explicitly, or getting students in your ML class to run the experiment.
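
If you do want the Lanczos route without my code or the BackPACK dependency, here is roughly what it looks like using torch.func for the (J'J)-vector products and scipy's Lanczos; the tiny model and data below are stand-ins, and the function names are my own:

```python
# Sketch: top eigenvalue of the Gauss-Newton matrix J'J via Lanczos (scipy eigsh),
# using only matrix-vector products. Model and data are placeholders.
import scipy.sparse.linalg as sla
import torch
from torch.func import functional_call, jvp, vjp

torch.manual_seed(0)
X = torch.randn(32, 5)
model = torch.nn.Sequential(torch.nn.Linear(5, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))

theta = {k: v.detach() for k, v in model.named_parameters()}
names = list(theta.keys())
sizes = [v.numel() for v in theta.values()]
shapes = [v.shape for v in theta.values()]
p_total = sum(sizes)

def f(params):
    # Network outputs on the whole training set, as a function of the parameters.
    return functional_call(model, params, (X,)).squeeze(-1)

def unflatten(vec):
    chunks = torch.split(torch.as_tensor(vec, dtype=torch.float32), sizes)
    return {k: c.reshape(sh) for k, c, sh in zip(names, chunks, shapes)}

def gauss_newton_vp(v):
    # (J'J) v: first Jv via forward-mode AD (jvp), then J'(Jv) via reverse mode (vjp).
    tangent = unflatten(v)
    _, Jv = jvp(f, (theta,), (tangent,))
    _, pullback = vjp(f, theta)
    JtJv = pullback(Jv)[0]
    return torch.cat([JtJv[k].reshape(-1) for k in names]).detach().numpy()

op = sla.LinearOperator((p_total, p_total), matvec=gauss_newton_vp)
top_eig = sla.eigsh(op, k=1, which="LA", return_eigenvectors=False)[0]
print("top eigenvalue of J'J (= top eigenvalue of JJ'):", top_eig)
```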

---

Edit: actually, it looks like PyTorch now has built-in tools for computing both the explicit NTK and NTK-vector products: https://pytorch.org/functorch/stable/notebooks/neural_tangent_kernels.html. So I would recommend using this instead. If you can do NTK-vector products, then you can use scipy's built-in Lanczos algorithm to get the top eigenvalues: https://github.com/locuslab/edge-of-stability/blob/github/src/utilities.py#L106
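
For what it's worth, the explicit-NTK route with that torch.func API looks roughly like this (a sketch following the pattern in the linked notebook; the tiny model and data are placeholders):

```python
# Sketch: empirical NTK via torch.func per-example Jacobians (vmap over jacrev),
# following the pattern in the PyTorch NTK notebook. Model and data are placeholders.
import torch
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)
X = torch.randn(32, 5)
model = torch.nn.Sequential(torch.nn.Linear(5, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))
params = {k: v.detach() for k, v in model.named_parameters()}

def f_single(theta, x):
    # Network output for a single example (add a batch dim around the call).
    return functional_call(model, theta, (x.unsqueeze(0),)).squeeze(0)

# Per-example Jacobians w.r.t. each parameter tensor: values have shape [n, out, *param_shape].
jac = vmap(jacrev(f_single), (None, 0))(params, X)

# Contract the parameter dimensions block by block to get the NTK, shape [n, out, n, out].
n, out = X.shape[0], 1
ntk = torch.zeros(n, out, n, out)
for j in jac.values():
    j = j.flatten(2)                              # [n, out, p_block]
    ntk += torch.einsum("iap,jbp->iajb", j, j)

ntk = ntk.reshape(n * out, n * out)               # scalar output here, so just [n, n]
print("top NTK eigenvalue:", torch.linalg.eigvalsh(ntk)[-1].item())
```

And once you can do NTK-vector products instead (as in the notebook), the same scipy eigsh / LinearOperator trick as above gets you the top eigenvalues without ever materializing the full matrix.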
