Title: | Curvature explains loss of plasticity |
Authors: | Alex Lewandowski, Haruto Tanaka, Dale Schuurmans, Marlos C. Machado |
Link: | https://arxiv.org/pdf/2312.00246.pdf |
What
- This paper presents empirical evidence that neural networks lose curvature directions during training in a continual learning setting, resulting in a loss of plasticity. The loss of curvature (defined in the section below) is measured by the empirical effective rank of an approximation to the Hessian of the loss with respect to the parameters.
- The study also provides simple counter-examples demonstrating that previous explanations of loss of plasticity are inconsistent with when such a loss actually does or does not occur.
- Finally, the authors develop a simple regularizer that keeps the distribution of the parameters close to the initialization distribution, while also ensuring that the curvature is not significantly reduced.
Why
- To achieve the ambitious goals of AI, we need to develop agents that can learn continually and adapt to changes in the data distribution without external intervention.
- While there are many existing, empirically motivated explanations of why loss of plasticity occurs in neural networks, it is not clear what the root cause is.
How
TL;DR: The paper establishes an empirical relationship between loss of curvature and loss of plasticity, one that holds consistently across situations where existing explanations break down. Based on this relationship, a simple regularizer is developed: it preserves curvature (measured by the effective rank of an approximation to the loss Hessian) by penalizing the distribution of the parameters when it drifts too far from the distribution of the parameters at initialization. This regularizer allows a more generous departure from the initialization than existing regularization methods, such as weight decay.
The continual learning setting
- The learning algorithm operates on mini-batches of data of size $n$.
- At fixed periodic times, i.e., after every $T$ successive updates, the distribution generating the observations and targets is changed; each such distribution defines a task.
- A task does not change in difficulty, so an increase in end-of-task error over time reflects a loss of plasticity rather than harder tasks.
- The error is measured at the end of each task and averaged across all observations in that task (a minimal sketch of this protocol follows this list):
$$\mathcal{J}_k(\theta_{t_k}) = \mathbb{E}_{(x, y) \sim \mathcal{D}_k}\left[\ell\big(f(x; \theta_{t_k}), y\big)\right],$$
where $\mathcal{D}_k$ is the generating distribution of inputs and targets for task $k$, $\ell$ is some loss (e.g., a classification loss), and $\theta_{t_k}$ are the parameters at the end of task $k$, after $t_k$ updates in total so far.
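To make the setting concrete, here is a minimal sketch of such a continual-learning loop; the random linear task generator, network architecture, and hyperparameters are illustrative placeholders rather than the paper's experimental setup.

```python
# A minimal sketch of the continual-learning protocol described above; the task
# generator, network, and hyperparameters are illustrative placeholders.
import torch
import torch.nn as nn

def make_task(seed: int):
    """Each task is a fresh random linear mapping from inputs to targets."""
    gen = torch.Generator().manual_seed(seed)
    weights = torch.randn(10, 1, generator=gen)
    def sample(batch_size: int = 32):
        x = torch.randn(batch_size, 10, generator=gen)
        return x, x @ weights
    return sample

net = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
updates_per_task = 200  # the generating distribution changes after this many updates

for task_id in range(5):
    sample = make_task(task_id)
    for _ in range(updates_per_task):
        x, y = sample()
        loss = nn.functional.mse_loss(net(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Error for task k: measured at the end of the task, averaged over observations
    # drawn from that task's distribution.
    with torch.no_grad():
        x_eval, y_eval = sample(1024)
        task_error = nn.functional.mse_loss(net(x_eval), y_eval).item()
    print(f"task {task_id}: end-of-task error {task_error:.4f}")
```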
Measuring curvature to establish an empirical relationship to plasticity
- The definition of curvature of the optimization objective used in this paper is the effective rank of the Hessian of the loss with respect to the parameters $\theta_t$ of the neural network, evaluated on the training data $X$.
- The Hessian $H(\theta_t)$ is the derivative of the loss gradient with respect to the parameters $\theta_t$ (the matrix of second derivatives), and it is a function of the training data $X$ too.
- The effective rank of a matrix is the number of leading basis vectors, ordered by decreasing singular value, needed to capture 99% of the matrix, i.e., 99% of the sum of its singular values (see the first sketch after this list).
- The lower the effective rank, the fewer basis vectors are needed to represent most of the matrix; for the Hessian, this means fewer directions of significant curvature.
- Importantly, the effective rank can change without the parameters changing, but with the distribution of the training data (i.e., when the task in the continual learning setting is varied).
- The actual Hessian is not computed here because of its high computational overhead, which grows quadratically with the number of weights of the neural network.
- Instead, the authors approximate it with an outer-product form (see the second sketch after this list):
$$H(\theta_t) \approx J^\top J,$$
where $J = \nabla_{\theta} f(X; \theta_t) \in \mathbb{R}^{n \times d}$ is the Jacobian of the network outputs $f$ on the mini-batch $X$, $n$ is the mini-batch size, and $d$ is the number of parameters.
- As $J^\top J$ is a Gram matrix, it is computationally advantageous to calculate its effective rank by computing the small $n \times n$ matrix $J J^\top$ instead of the $d \times d$ matrix $J^\top J$, since both share the same nonzero eigenvalues and therefore $\operatorname{erank}(J^\top J) = \operatorname{erank}(J J^\top)$.
- This definition of curvature was used in a suite of experiments to show that, in contrast with previous explanations of loss of plasticity, the link from loss of curvature to loss of plasticity holds consistently.
- For the other explanations, the authors could find simple counter-examples showing that they are at times in conflict with the observed behavior and cannot explain loss of plasticity on their own.
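As a concrete reading of the effective-rank definition above, the sketch below assumes the "99% of the sum of the singular values" interpretation; the exact threshold and normalization used in the paper may differ.

```python
# Hedged sketch: effective rank as the number of leading singular directions
# needed to account for 99% of the total singular-value mass (assumed reading).
import torch

def effective_rank(matrix: torch.Tensor, threshold: float = 0.99) -> int:
    singular_values = torch.linalg.svdvals(matrix)  # returned in descending order
    cumulative = torch.cumsum(singular_values, dim=0) / singular_values.sum()
    # Count how many leading directions fall short of the threshold, then add one.
    return int((cumulative < threshold).sum().item()) + 1

# A nearly rank-one matrix has a small effective rank; a random one does not.
nearly_rank_one = torch.outer(torch.randn(50), torch.randn(50)) + 1e-3 * torch.randn(50, 50)
print(effective_rank(nearly_rank_one))      # small: the top direction dominates
print(effective_rank(torch.randn(50, 50)))  # close to 50
```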
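And a sketch of the Gram-matrix shortcut, assuming the Hessian is approximated by the outer product $J^\top J$ of a per-example Jacobian $J$; the network, mini-batch, and the helper `curvature_effective_rank` are illustrative.

```python
# Hedged sketch: approximate the curvature's effective rank via the small n x n
# Gram matrix J J^T instead of the d x d matrix J^T J (they share nonzero eigenvalues).
import torch
import torch.nn as nn

def curvature_effective_rank(model: nn.Module, inputs: torch.Tensor,
                             threshold: float = 0.99) -> int:
    params = [p for p in model.parameters() if p.requires_grad]
    rows = []
    for i in range(inputs.shape[0]):
        # Row i of J: gradient of the (summed) output on example i w.r.t. all parameters.
        out = model(inputs[i:i + 1]).sum()
        grads = torch.autograd.grad(out, params)
        rows.append(torch.cat([g.flatten() for g in grads]))
    jacobian = torch.stack(rows)                  # shape (n, d), with n << d
    gram = jacobian @ jacobian.T                  # shape (n, n), cheap to eigendecompose
    eigvals = torch.linalg.eigvalsh(gram).clamp(min=0.0).flip(0)  # descending, nonnegative
    cumulative = torch.cumsum(eigvals, dim=0) / eigvals.sum()
    return int((cumulative < threshold).sum().item()) + 1

net = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
print(curvature_effective_rank(net, torch.randn(32, 10)))
```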
Curvature is preserved by regularizing the layerwise Wasserstein difference between the parameters at time $t$ and the parameters at initialization (time $0$); a hedged sketch follows below:
$$\mathcal{W}(\theta_t, \theta_0) = \sum_{l=1}^{L} \sum_{i} \Big(\operatorname{sort}\big(\theta_t^{(l)}\big)_i - \operatorname{sort}\big(\theta_0^{(l)}\big)_i\Big)^2,$$
where $\theta_t^{(l)}$ is the flattened vector of parameters in layer $l$ at time $t$, $\operatorname{sort}(\cdot)$ orders its entries by value, and $i$ indexes the individual (sorted) parameters. Because only sorted values are compared, individual parameters can move freely as long as the overall distribution stays close to the initialization distribution, a more generous constraint than weight decay.
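The following is a minimal sketch of such a regularizer, assuming the layerwise distance is computed between sorted, flattened parameter values; the function name `wasserstein_init_regularizer`, the network, and the coefficient 0.01 are illustrative choices, not the paper's exact implementation.

```python
# Hedged sketch of a distributional regularizer toward the initialization:
# compare sorted parameter values per layer (assumed form; details may differ).
import torch
import torch.nn as nn

def wasserstein_init_regularizer(model: nn.Module, init_params) -> torch.Tensor:
    """Sum over layers of the mean squared difference between the sorted current
    parameters and the sorted parameters saved at initialization."""
    penalty = torch.zeros(())
    for current, initial in zip(model.parameters(), init_params):
        current_sorted, _ = torch.sort(current.flatten())
        initial_sorted, _ = torch.sort(initial.flatten())
        penalty = penalty + ((current_sorted - initial_sorted) ** 2).mean()
    return penalty

net = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
theta_0 = [p.detach().clone() for p in net.parameters()]  # snapshot at initialization

# During training, add the penalty to the task loss with an illustrative coefficient.
x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = nn.functional.mse_loss(net(x), y) + 0.01 * wasserstein_init_regularizer(net, theta_0)
loss.backward()
```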