Curvature explains loss of plasticity

Published March 29, 2024
Title: Curvature explains loss of plasticity
Authors: Alex Lewandowski, Haruto Tanaka, Dale Schuurmans, Marlos C. Machado
Link: https://arxiv.org/pdf/2312.00246.pdf

What


  • This paper presents empirical evidence that neural networks lose curvature directions during training in a continual learning setting, and that this loss of curvature accounts for the loss of plasticity. Loss of curvature (defined in the section below) is measured by the empirical effective rank of an approximation to the Hessian of the loss with respect to the parameters.
  • The study also provides simple counter-examples demonstrating that previous explanations of loss of plasticity are inconsistent with the situations in which such a loss actually occurs.
  • Finally, the authors develop a simple regularizer that keeps the parameter distribution close to the initialization distribution while ensuring that curvature is not significantly reduced.

Why


  • To achieve the ambitious goal of artificial intelligence, we need agents that can learn continually and adapt to changes in the data distribution without external intervention.
  • While there are many existing, empirically motivated explanations for why loss of plasticity occurs in neural networks, it is not clear what the root cause is.

How


TL;DR: The paper establishes an empirical relationship between loss of curvature and loss of plasticity, and shows that this relationship holds consistently across situations where existing explanations fail. Based on this curvature-plasticity relationship, a simple regularizer is developed: it preserves curvature by penalizing the distribution of the parameters when it drifts too far from the distribution of the parameters at initialization. This allows a more generous departure from the initialization than existing regularization methods, such as weight decay, which pull the parameters toward a fixed point.

The continual learning setting

  • The learning algorithm operates on mini-batches of data of size $M$.
  • At fixed periodic times, after $U$ successive updates, the distribution generating the observations and targets changes (a new task begins).
    • A task does not change in difficulty.
  • The error is measured at the end of each task $K$ and averaged across all observations in that task:

$$J(\theta_{nU, K}) \stackrel{.}{=} \mathbb{E}_{p_K}\big[ l\big(f_{\theta_{nU, K}}(x), y\big) \big]$$

where $p_K$ is the generating distribution of inputs and targets for task $K$, $l(\cdot)$ is some loss (e.g., a classification loss), and $\theta_{nU, K}$ are the parameters at the end of task $K$, after $nU$ updates in total so far.
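The protocol above can be sketched in a few lines. This is a minimal, hypothetical instantiation with a linear model and synthetic random linear targets; the model, learning rate, and number of tasks are illustrative assumptions, not the paper's actual experimental setup.

```python
import numpy as np

rng = np.random.default_rng(0)
d, U, M = 5, 200, 8            # param dim, updates per task, mini-batch size

def make_task(seed):
    """Each task K draws a fresh random linear target (fixed difficulty)."""
    return np.random.default_rng(seed).standard_normal(d)

def sample_batch(w_true, n):
    X = rng.standard_normal((n, d))
    return X, X @ w_true

w = np.zeros(d)                # theta, carried across tasks (never reset)
errors = []                    # J(theta_{nU, K}) at the end of each task K
for K in range(3):
    w_true = make_task(K)
    for _ in range(U):         # U mini-batch SGD updates per task
        X, y = sample_batch(w_true, M)
        w -= 0.01 * (2.0 / M) * X.T @ (X @ w - y)
    X, y = sample_batch(w_true, 1000)   # fresh samples from p_K
    errors.append(float(np.mean((X @ w - y) ** 2)))
```

The key property of the setting is visible here: the parameters are never reset between tasks, so any plasticity loss accumulated on earlier tasks shows up as higher end-of-task error on later ones.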

Measuring curvature to establish an empirical relationship to plasticity

  • The curvature of the optimization objective is defined in this paper as the effective rank of the Hessian with respect to the parameters $\theta$ of the neural network, evaluated on the training data $\mathcal{D}$.
    • The Hessian is the matrix of second derivatives of the loss with respect to $\theta$, and it is a function of the training data too.
    • The effective rank of a matrix is the number of leading basis vectors, ordered by decreasing singular value, needed to capture 99% of its spectrum.
      • The lower the effective rank, the fewer directions carry most of the curvature.
      • Importantly, the effective rank can change without the parameters $\theta$ changing, because the distribution of the training data changes (e.g., when the task is switched in a continual learning setting).

  • The actual Hessian is not computed here because the computational overhead grows quadratically with the number of weights of the neural network!
  • Instead, the authors approximate it with the sum of outer products of per-example gradients:

$$\mathbf{H} \approx \mathbf{\hat{H}} \stackrel{.}{=} \sum_{i=1}^{M} g_i g_i^{\top} = \mathbf{G} \mathbf{G}^{\top}$$

where $g_i \stackrel{.}{=} \nabla_{\theta} J(\theta, x_i, y_i)$, $\mathbf{G} \in \mathbb{R}^{d \times M}$, and $M \ll d$.

  • As $\mathbf{\hat{H}}$ is a Gram matrix, it is computationally advantageous to calculate its effective rank from $\mathbf{G}^{\top} \mathbf{G} \in \mathbb{R}^{M \times M}$ instead of $\mathbf{G} \mathbf{G}^{\top} \in \mathbb{R}^{d \times d}$, since $\text{rank}(\mathbf{G} \mathbf{G}^{\top}) = \text{rank}(\mathbf{G}^{\top} \mathbf{G})$ and the two matrices share the same nonzero eigenvalues.
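The Gram trick and the 99% effective-rank computation can be sketched as follows. The per-example gradient matrix `G` is filled with random values here as a stand-in; the `effective_rank` helper and the 0.99 threshold follow the definition given above.

```python
import numpy as np

def effective_rank(eigenvalues, threshold=0.99):
    """Smallest k such that the top-k eigenvalues (by magnitude)
    account for `threshold` of the total spectrum mass."""
    s = np.sort(eigenvalues)[::-1]          # descending order
    cum = np.cumsum(s) / np.sum(s)
    return int(np.searchsorted(cum, threshold) + 1)

# Stand-in for per-example gradients: G is d x M with M << d.
rng = np.random.default_rng(0)
d, M = 10_000, 32
G = rng.standard_normal((d, M))

# Gram trick: the M x M matrix G^T G has the same nonzero
# eigenvalues as the d x d matrix G G^T, so we never form
# the d x d approximation explicitly.
gram = G.T @ G                               # M x M, cheap
eigvals = np.clip(np.linalg.eigvalsh(gram), 0.0, None)
erank = effective_rank(eigvals)              # between 1 and M
```

A sharply decaying spectrum (most curvature in a few directions) gives a small effective rank; a flat spectrum gives an effective rank near $M$.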

  • Using this definition of curvature on a suite of tests, the authors show that, in contrast with previous explanations of loss of plasticity, the link between loss of curvature and loss of plasticity holds consistently.

  • For the other explanations, the authors found simple counter-examples showing that they are at times contradictory and cannot explain loss of plasticity on their own.

Curvature is preserved by regularizing the layerwise $L_2$-difference of the parameters at time $t$ and the parameters at initialization time $t_0$

$$\mathcal{W}_2^2(p^{(l, 0)}, p^{(l, t)}) = \sum_{i=1}^{d} \Big( \bar{\theta}^{(l, t)}_{(i)} - \bar{\theta}^{(l, 0)}_{(i)} \Big)^2$$

where $\bar{\theta}^{(l, t)}_{(i)}$ is the $i$-th smallest entry of the flattened parameters of layer $l$ at time $t$; the subscript $(i)$ indexes the order statistics (sorted values), which is what makes this the Wasserstein-2 distance between the two empirical parameter distributions rather than a plain $L_2$ distance to the initial point.
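A minimal sketch of this penalty, assuming the sorted-coordinates form of the 1-D Wasserstein-2 distance that the order-statistic notation $\bar{\theta}_{(i)}$ suggests (the function name and example are illustrative, not from the paper):

```python
import numpy as np

def w2_layer_penalty(theta_t, theta_0):
    """Squared Wasserstein-2 distance between the empirical
    distributions of one layer's parameters now and at init.
    For 1-D empirical distributions this is the sum of squared
    differences of the sorted (order-statistic) values."""
    a = np.sort(np.ravel(theta_t))
    b = np.sort(np.ravel(theta_0))
    return float(np.sum((a - b) ** 2))

# A permutation of the initial weights leaves the distribution,
# and hence this penalty, unchanged -- unlike an elementwise L2
# distance to the initialization, which would be positive.
rng = np.random.default_rng(0)
theta_0 = rng.standard_normal(100)
perm = rng.permutation(theta_0)
w2 = w2_layer_penalty(perm, theta_0)       # distribution unchanged
l2 = float(np.sum((perm - theta_0) ** 2))  # elementwise distance
```

This permutation invariance illustrates why the penalty permits a more generous departure from the initialization than weight decay: parameters may move freely as long as their layerwise distribution stays close to the initialization distribution.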