Reinitializing the moments of the optimizer during learning (Jax)

Published May 27, 2023

Reinitializing the weights of a neural network has been shown to be effective for continual learning (cf. CBP algorithm 1). In this short guide, I show how to update the moments of a momentum-based optimizer in Jax.

Here I reinitialize all components of the two moments, not just those corresponding to low-utility hidden units, as CBP does.

Let’s consider a simple regression problem with multivariate inputs in three dimensions and scalar targets. We minimize the squared error using stochastic momentum-based gradient descent. Specifically, we use the well-established Adam optimizer to update the parameters of a linear layer, processing a single example at each time step, for a total of 1000 steps.

At each time step, after updating the parameters of the linear layer, we reinitialize the two moments of Adam to zero.

import jax
import optax
from flax import linen as nn
from flax import traverse_util
from flax.core import freeze
from flax.training import train_state
from flax.training.train_state import TrainState
from jax import numpy as jnp
from jax import random

Initialize hyper-parameters:

n_samples = 1000  # number of training examples (one gradient step per example)
x_dim = 3         # input dimensionality
y_dim = 1         # target dimensionality (scalar regression)

seed = 0          # PRNG seed
step_size = 1e-3  # Adam learning rate

Define linear model

# Linear layer mapping x_dim-dimensional inputs to y_dim outputs.
model = nn.Dense(features=y_dim)

Initialize model

# Derive keys from the seed: `key` draws a dummy input, `subkey` is carried
# forward to the next split.
key, subkey = random.split(jax.random.PRNGKey(seed))
x = jax.random.normal(key, (x_dim, ))  # dummy input, used only for shape inference

key, subkey = random.split(subkey)
# NOTE(review): the ground-truth cell below rebuilds PRNGKey(seed) from the
# same seed, so its key stream overlaps with the one used here — confirm the
# overlap is intentional.
params = model.init(key, x)  # kernel of shape (x_dim, y_dim) and bias of shape (y_dim,)
params
    FrozenDict({
        params: {
            kernel: Array([[-0.48033935],
                [ 0.07632174],
                [ 1.1726483 ]], dtype=float32),
            bias: Array([0.], dtype=float32),
        },
    })

Generate random ground truth W and b

# Re-seed the PRNG. NOTE(review): this restarts the stream from the same seed
# used for model initialization above, so keys overlap across the two cells.
key = random.PRNGKey(seed)
key, subkey = random.split(key)
W = random.normal(key, (x_dim, y_dim))  # ground-truth weights
b = random.normal(subkey, (y_dim, ))    # ground-truth bias
# Same pytree layout as model.init's output, for direct comparison.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with some noise.
# NOTE(review): `key` was already consumed by random.normal for W above;
# splitting it again here is PRNG key reuse — `subkey` (or a fresh split)
# should be used instead. Left unchanged to keep the recorded outputs valid.
key_sample, key_noise = random.split(key)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples,
                    W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
x_samples.shape, y_samples.shape
    ((1000, 3), (1000, 1))

 

@jax.jit
def train_step(state, x, y):
    """Perform one optimizer step on a single example.

    Computes the squared-error loss of the model on (x, y), differentiates
    it with respect to the parameters, and applies the optimizer update.
    Returns the updated training state and the loss value.
    """

    def squared_error(p):
        residual = y - state.apply_fn(p, x)
        return jnp.inner(residual, residual) / 2.0

    loss, grads = jax.value_and_grad(squared_error)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss
# Adam as an optax gradient transformation; its state holds the step count
# and the two moments (mu, nu), all initialized to zero below.
tx = optax.adam(learning_rate=step_size)
opt_state = tx.init(params)
opt_state
    (ScaleByAdamState(count=Array(0, dtype=int32), mu=FrozenDict({
        params: {
            bias: Array([0.], dtype=float32),
            kernel: Array([[0.],
                    [0.],
                    [0.]], dtype=float32),
        },
    }), nu=FrozenDict({
        params: {
            bias: Array([0.], dtype=float32),
            kernel: Array([[0.],
                    [0.],
                    [0.]], dtype=float32),
        },
    })),
    EmptyState())

 

Initialize training state (model and optimizer params)

# Initialize the training state (model apply fn, parameters, optimizer).
# TrainState.create is the canonical constructor: it calls tx.init(params)
# itself and starts the step counter at the integer 0. The previous direct
# construction passed step=jnp.array(0.) — a float step counter — and
# duplicated the tx.init call done above for display purposes.
state = TrainState.create(apply_fn=model.apply,
                          params=params,
                          tx=tx)

Train on a per-example basis and re-initialize moments at each time-step

losses = []
for i in range(n_samples):
    # One Adam step on the i-th example.
    state, loss_val = train_step(state, x_samples[i], y_samples[i])

    # Reset Adam's first and second moments (mu, nu) to zero.
    # jax.tree_util.tree_map recurses through the nested (Frozen)Dict
    # directly and returns the same container type, so the original
    # flatten -> tree_map -> unflatten -> freeze round-trip is unnecessary.
    adam_state = state.opt_state[0]  # ScaleByAdamState(count, mu, nu)
    new_mu = jax.tree_util.tree_map(jnp.zeros_like, adam_state.mu)
    new_nu = jax.tree_util.tree_map(jnp.zeros_like, adam_state.nu)

    # Rebuild the optimizer state tuple with the zeroed moments; the
    # remaining entries (e.g. the trailing EmptyState) are kept as-is.
    new_opt_state = (adam_state._replace(mu=new_mu,
                                         nu=new_nu), ) + state.opt_state[1:]
    state = state.replace(opt_state=new_opt_state)

    # book-keeping
    losses.append(loss_val)

    # Log the loss averaged over the last 100 examples (slice the Python
    # list first instead of materializing the full history as an array).
    if i > 0 and i % 100 == 0:
        print(
            'Step: {}:'.format(i), 'Avg. loss (last 100): {:.4f}'.format(
                jnp.mean(jnp.array(losses[max(i - 100, 0):i]))))
    Step: 100: Avg. loss (last 100): 2.4251
    Step: 200: Avg. loss (last 100): 2.6208
    Step: 300: Avg. loss (last 100): 2.2898
    Step: 400: Avg. loss (last 100): 1.7616
    Step: 500: Avg. loss (last 100): 1.6727
    Step: 600: Avg. loss (last 100): 1.0890
    Step: 700: Avg. loss (last 100): 1.1949
    Step: 800: Avg. loss (last 100): 0.9610
    Step: 900: Avg. loss (last 100): 0.6888

 

References


  1. Dohare, S., Sutton, R. S., & Mahmood, A. R. (2021). Continual backprop: Stochastic gradient descent with persistent randomness. arXiv preprint arXiv:2108.06325. ↩︎