Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor

Published June 1, 2023
Title: Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor
Authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine
Link: https://proceedings.mlr.press/v80/haarnoja18b/haarnoja18b.pdf

What


Off-policy actor-critic deep RL algorithm with a stochastic actor for continuous state and action spaces that maximizes the expected reward plus the expected entropy of the target policy.

Why


Existing model-free deep RL algorithms have very high sample complexity and brittle convergence properties, so they need extensive hyperparameter tuning for every new domain.

How


TL;DR: The Soft Actor-Critic (SAC) algorithm is an off-policy algorithm that optimizes a stochastic policy. It concurrently learns a policy, a state-value function, and two Q-functions, is designed for continuous state and action spaces, and maximizes the expected return together with the entropy of the policy to encourage exploration.

SAC maximizes the expected cumulative reward plus the entropy of the policy. This problem is sometimes referred to as Entropy-Regularized Reinforcement Learning (RL)[1], and it differs from the standard RL problem, which maximizes only the expected cumulative reward.

The new RL objective is:

$$ \begin{align} J(\pi) \stackrel{.}{=} \mathbb{E}_{\pi} \Big[ \sum_{t} \Big( R_{t+1} + \beta \mathcal{H} \big( \pi( \cdot | S_t ) \big) \Big) \Big], ~~~~ A_t \sim \pi( \cdot | S_t ), \end{align} $$

where

  • $\mathcal{H} \big( \pi( \cdot | S_t ) \big) \stackrel{.}{=} \mathbb{E}_{a \sim \pi( \cdot | S_t )}\big[ -\log \pi( a | S_t ) \big]$ is the entropy of the policy $\pi$ in state $S_t$.

  • $\beta$[2] is the trade-off coefficient between the policy entropy term and the reward[3].
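A minimal sketch of the objective above for a single finite trajectory. It assumes a discrete action space so that the entropy is a finite sum (SAC itself targets continuous actions, where the entropy is an integral); the function names are hypothetical.

```python
import math

def entropy(probs):
    """Entropy H(pi(.|s)) = -sum_a pi(a|s) log pi(a|s) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def entropy_regularized_return(rewards, policy_probs, beta):
    """Sum over t of R_{t+1} + beta * H(pi(.|S_t)) along one trajectory.

    rewards[t] is R_{t+1}; policy_probs[t] is pi(.|S_t) given as a list of probabilities.
    """
    return sum(r + beta * entropy(p) for r, p in zip(rewards, policy_probs))

rewards = [1.0, 0.0, 1.0]
policy_probs = [[0.5, 0.5], [0.9, 0.1], [1.0, 0.0]]

print(entropy_regularized_return(rewards, policy_probs, beta=0.0))  # 2.0, the plain return
print(entropy_regularized_return(rewards, policy_probs, beta=0.2))  # larger: entropy bonus added
```

With `beta=0.0` the entropy bonus vanishes and the standard return is recovered; a deterministic step such as `[1.0, 0.0]` contributes no bonus at all.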

Soft Approximate Policy Iteration (s-API)

Instead of running soft policy evaluation and soft policy improvement to convergence, as exact soft policy iteration would, each of the two steps is approximated by a small number of stochastic gradient steps.

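The state-value function, the two action-value functions, and the policy are represented by function approximators with parameters $\mathbf{w}, \bm{\theta}_1, \bm{\theta}_2, \bm{\theta}_\pi$, respectively. Where appropriate, the following equations suppress the dependence of these quantities on their parameters. As in the paper, the entropy coefficient $\beta$ is also dropped from the equations below, since it can be absorbed into the reward by scaling the reward by $\beta^{-1}$.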

Policy Evaluation

In policy evaluation, the state-value and action-value functions are updated.

State-value function

$$ \begin{align} J_{\hat{v}}(\mathbf{w}) &\stackrel{.}{=} \mathbb{E}_{s \sim \mathcal{D}} \Bigg[ \frac{1}{2} \Big( \hat{v}(s, \mathbf{w}) - \mathbb{E}_{\tilde{a} \sim \pi( \cdot | s, \bm{\theta}_\pi) } [ \hat{q}(s, \tilde{a}) - \log \pi(\tilde{a} | s, \bm{\theta}_\pi) ] \Big)^2 \Bigg], \end{align} $$

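where $\mathcal{D} \stackrel{.}{=} \lbrace (s_n, a_n, r_n, s^\prime_n) \rbrace^N_{n=1}$ is the replay buffer, and the tilde in $\tilde{a}$ denotes that the action is sampled from the current policy rather than taken from the replay buffer. The action-value estimate $\hat{q}$ is the minimum of the two action-value functions, which the paper uses to mitigate the positive bias of a single estimate: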

$$ \begin{align} \hat{q}(s, a) \stackrel{.}{=} \min_{ i \in \lbrace 1, 2 \rbrace } \hat{q}(s, a, \bm{\theta}_i), ~~~~ \forall s \in \mathcal{S}, a \in \mathcal{A} \end{align} $$
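The regression target inside Equation 2 can be estimated by sampling actions from the current policy and taking the clipped double-Q minimum. The sketch below assumes the two Q-values and the log-probability of each sampled action have already been computed; the function names are hypothetical.

```python
def soft_value_target(q1_values, q2_values, log_probs):
    """Monte Carlo estimate of E_{a~pi}[min_i q_i(s, a) - log pi(a|s)] for one state s.

    Each list holds one entry per action a_tilde sampled from the current policy
    (not from the replay buffer).
    """
    terms = [min(q1, q2) - lp for q1, q2, lp in zip(q1_values, q2_values, log_probs)]
    return sum(terms) / len(terms)

def value_loss(v_s, target):
    """Single-state term of Equation 2: 0.5 * (v_hat(s) - target)^2."""
    return 0.5 * (v_s - target) ** 2

target = soft_value_target([1.0, 2.0], [1.5, 1.0], [-0.5, -1.0])
print(target)  # (1.0 + 0.5) + (1.0 + 1.0) = 3.5 over 2 samples -> 1.75
```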

A single-sample estimate of the gradient of Equation 2, with $s \sim \mathcal{D}$ and $\tilde{a} \sim \pi( \cdot | s, \bm{\theta}_\pi)$, is:

$$ \begin{align} \hat{\nabla}_{\mathbf{w}} J_{\hat{v}}(\mathbf{w}) \stackrel{.}{=} \nabla_{\mathbf{w}} \hat{v}(s, \mathbf{w}) \Big( \hat{v}(s, \mathbf{w}) - \hat{q}(s, \tilde{a}) + \log \pi(\tilde{a} | s, \bm{\theta}_\pi) \Big). \end{align} $$
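To see that the gradient of Equation 2 is just the gradient of a squared error, here is a sketch of repeated stochastic gradient steps for a linear value function $\hat{v}(s, \mathbf{w}) = \mathbf{w}^\top \phi(s)$. The linear form, the feature vector `phi_s`, and the function names are assumptions for illustration (the paper uses neural networks); the target $\hat{q}(s, \tilde{a}) - \log \pi(\tilde{a} | s)$ is held constant.

```python
def v_hat(features, w):
    """Linear state-value approximation v_hat(s, w) = w . phi(s)."""
    return sum(wi * xi for wi, xi in zip(w, features))

def value_sgd_step(w, features, q_min, log_prob, step_size):
    """One step w <- w - alpha * grad_w v_hat * (v_hat - q_min + log_prob).

    For a linear v_hat, grad_w v_hat(s, w) is simply the feature vector phi(s).
    """
    error = v_hat(features, w) - q_min + log_prob
    return [wi - step_size * xi * error for wi, xi in zip(w, features)]

phi_s = [1.0, 0.5]
w = [0.0, 0.0]
q_min, log_prob = 2.0, -0.3  # regression target q_min - log_prob = 2.3
for _ in range(200):
    w = value_sgd_step(w, phi_s, q_min, log_prob, step_size=0.1)
print(v_hat(phi_s, w))  # approaches 2.3
```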

Action-value function

$$ \begin{align} J_{\hat{q}}(\bm{\theta}) \stackrel{.}{=} \mathbb{E}_{(s, a, r, s^\prime) \sim \mathcal{D}} \Bigg[ \frac{1}{2} \Big( \hat{q}(s, a, \bm{\theta}) - \big( r(s, a) + \gamma \hat{v}(s^\prime, \mathbf{w^{\text{target}}}) \big) \Big)^2 \Bigg], \end{align} $$

which is minimized separately for each of the two action-value functions, $\bm{\theta} \in \lbrace \bm{\theta}_1, \bm{\theta}_2 \rbrace$.

A single-sample estimate of the gradient of Equation 5, for a transition $(s, a, r, s^\prime) \sim \mathcal{D}$, is:

$$ \begin{align} \hat{\nabla}_{\bm{\theta}} J_{\hat{q}}(\bm{\theta}) \stackrel{.}{=} \nabla_{\bm{\theta}} \hat{q}(s, a, \bm{\theta}) \Big( \hat{q}(s, a, \bm{\theta}) - r(s, a) - \gamma \hat{v}(s^\prime, \mathbf{w^{\text{target}}}) \Big). \end{align} $$

  • $\mathbf{w^{\text{target}}}$ are the parameters of the target state-value function[4].
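The bootstrapped target of Equation 5 and the slow update of $\mathbf{w^{\text{target}}}$ can be sketched as below, assuming parameters are plain lists of floats. The function names are hypothetical; `gamma = 0.99` and `tau = 0.005` follow the paper's hyperparameter table, and the update rule is the exponential moving average (Polyak averaging) of the current parameters.

```python
def q_target(reward, v_next_target, gamma=0.99):
    """Bootstrapped target r(s, a) + gamma * v_hat(s', w_target) of Equation 5."""
    return reward + gamma * v_next_target

def q_loss(q_sa, target):
    """Squared-error term 0.5 * (q_hat(s, a, theta) - target)^2, used for each theta_i."""
    return 0.5 * (q_sa - target) ** 2

def polyak_update(w_target, w, tau=0.005):
    """Exponential moving average w_target <- tau * w + (1 - tau) * w_target."""
    return [tau * wi + (1.0 - tau) * wti for wi, wti in zip(w, w_target)]

w, w_target = [1.0, -2.0], [0.0, 0.0]
for _ in range(1000):
    w_target = polyak_update(w_target, w)
print(w_target)  # each entry is about 99.3% of the way from 0 to the matching entry of w
```

A small `tau` makes the target change slowly, so the regression target in Equation 5 does not chase the value network it is used to train.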

Policy Improvement

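In the policy improvement step, the policy is updated using the action-value functions learned in the evaluation step.

Specifically, the paper minimizes the KL divergence between the policy and the Boltzmann distribution induced by the action-value function, normalized by the partition function $Z(s, \bm{\theta})$: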

$$ \begin{align} J_{\pi}(\bm{\theta_{\pi}}) \stackrel{.}{=} \mathbb{E}_{s \sim \mathcal{D}} \Bigg[ \mathcal{KL} \Big( \pi(\cdot | s, \bm{\theta_{\pi}}) \Big\Vert \frac{\exp \big( \hat{q}(s, \cdot, \bm{\theta}) \big)}{Z(s, \bm{\theta})} \Big) \Bigg] \end{align} $$
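For intuition, assume a discrete action set (an illustration only; SAC's actions are continuous). Then the Boltzmann distribution in Equation 7 is a softmax over the Q-values, and the KL divergence is zero exactly when the policy equals that softmax. The function names below are hypothetical.

```python
import math

def boltzmann(q_values):
    """exp(q(s, a)) / Z(s) over a discrete action set, computed stably."""
    m = max(q_values)
    exps = [math.exp(q - m) for q in q_values]
    z = sum(exps)
    return [e / z for e in exps]

def kl(p, q):
    """KL(p || q) = sum_a p(a) * log(p(a) / q(a)) for discrete distributions."""
    return sum(pa * math.log(pa / qa) for pa, qa in zip(p, q) if pa > 0.0)

q_values = [1.0, 2.0, 0.5]
target = boltzmann(q_values)
uniform = [1 / 3] * 3
greedy_ish = [0.05, 0.9, 0.05]

print(kl(target, target))      # 0.0: the Boltzmann policy is the minimizer
print(kl(uniform, target))     # positive: too random
print(kl(greedy_ish, target))  # positive: too greedy is penalized as well
```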

To minimize Equation 7 with standard backpropagation, the paper has to handle the fact that the KL divergence is an expectation over actions drawn from the policy itself: the sampling distribution depends on $\bm{\theta}_\pi$, and gradients cannot flow through the sampling step. The reparametrization trick removes this dependence by writing the action as a deterministic function of the state, the policy parameters, and independent noise, which moves the stochasticity away from the learnable parameters. This does not work for every distribution, but it does for Gaussians. The partition function $Z(s, \bm{\theta})$ does not depend on $\bm{\theta}_\pi$, so it does not affect the gradient.
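A minimal sketch of the reparametrization trick, assuming a one-dimensional Gaussian and the toy objective $\mathbb{E}[a^2]$, whose exact gradients are $2\mu$ and $2\sigma$. Writing $a = \mu + \sigma \xi$ lets the derivative pass through each sample, and averaging these pathwise derivatives gives an unbiased gradient estimate. The function name is hypothetical.

```python
import random

def pathwise_grads(mu, sigma, n_samples, seed=0):
    """Reparametrized estimate of d/dmu and d/dsigma of E[a^2], a ~ N(mu, sigma^2).

    With a = mu + sigma * xi and xi ~ N(0, 1): da/dmu = 1 and da/dsigma = xi,
    so d(a^2)/dmu = 2a and d(a^2)/dsigma = 2a * xi for each sample.
    """
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(n_samples):
        xi = rng.gauss(0.0, 1.0)  # the noise does not depend on (mu, sigma)
        a = mu + sigma * xi
        g_mu += 2.0 * a
        g_sigma += 2.0 * a * xi
    return g_mu / n_samples, g_sigma / n_samples

g_mu, g_sigma = pathwise_grads(mu=1.5, sigma=0.5, n_samples=100_000)
print(g_mu, g_sigma)  # close to the exact values 2 * 1.5 = 3.0 and 2 * 0.5 = 1.0
```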

The policy network outputs the mean and standard deviation of a Gaussian whose samples are squashed by $\tanh$, so that actions are bounded in $(-1, 1)$:

$$ a \stackrel{.}{=} \tanh \Big( \mu_{\bm{\theta}_\pi}(s) + \sigma_{\bm{\theta}_\pi}(s) \odot \xi \Big) \stackrel{.}{=} f(\xi, s, \bm{\theta}_\pi), ~~~~ \xi \sim \mathcal{N}(\bm{0}, \mathbf{I}) $$
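Because the action is squashed by $\tanh$, evaluating $\log \pi(a | s, \bm{\theta}_\pi)$ needs a change-of-variables correction, $\log \pi(a | s) = \log \mathcal{N}(u; \mu, \sigma^2) - \sum_i \log\big(1 - \tanh^2(u_i)\big)$ with $u = \mu + \sigma \odot \xi$, as described in the paper's appendix on enforcing action bounds. The sketch assumes a diagonal Gaussian whose mean and standard deviation were already produced by the policy network, and uses the identity $\log(1 - \tanh^2 u) = 2(\log 2 - u - \mathrm{softplus}(-2u))$ for numerical stability. Function names are hypothetical.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def sample_squashed_gaussian(mu, sigma, xi):
    """Return a = tanh(mu + sigma * xi) and log pi(a|s) for a diagonal Gaussian.

    mu, sigma, xi are equal-length lists; xi is standard normal noise supplied by
    the caller, so the action is a deterministic function f(xi, s, theta_pi).
    """
    action, log_prob = [], 0.0
    for m, s, x in zip(mu, sigma, xi):
        u = m + s * x
        action.append(math.tanh(u))
        # Gaussian log-density of u, using (u - m) / s = x
        log_prob += -0.5 * x * x - math.log(s) - 0.5 * math.log(2.0 * math.pi)
        # Change of variables for a = tanh(u): subtract log(1 - tanh(u)^2)
        log_prob -= 2.0 * (math.log(2.0) - u - softplus(-2.0 * u))
    return action, log_prob

a, logp = sample_squashed_gaussian(mu=[0.0, 1.0], sigma=[1.0, 0.5], xi=[0.0, 3.0])
print(a, logp)  # every action component lies in (-1, 1)
```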

With this reparametrization, and dropping the constant $\log Z(s, \bm{\theta})$, Equation 7 becomes:

$$ \begin{align} J_{\pi}(\bm{\theta_{\pi}}) \stackrel{.}{=} \mathbb{E}_{s \sim \mathcal{D}, \xi \sim \mathcal{N}} \Big[ \log \pi \big( f(\xi, s, \bm{\theta_{\pi}}) \,|\, s, \bm{\theta_{\pi}} \big) - \hat{q} \big( s, f(\xi, s, \bm{\theta_{\pi}}), \bm{\theta} \big) \Big] \end{align} $$
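As a sanity check of Equation 8, consider a toy setting that is purely an assumption for illustration: a single state, an unsquashed one-dimensional Gaussian policy $a = \mu + \sigma \xi$, and $\hat{q}(s, a) = -(a - 2)^2$. The objective then has the closed form $(\mu - 2)^2 + \sigma^2 - \log \sigma - \tfrac{1}{2}\log(2\pi e)$, which a Monte Carlo average over $\xi$ should match. Function names are hypothetical.

```python
import math
import random

def q_hat(a):
    """Toy action-value function with its maximum at a = 2."""
    return -(a - 2.0) ** 2

def log_pi(a, mu, sigma):
    """Log-density of N(mu, sigma^2) at a."""
    z = (a - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

def policy_objective_mc(mu, sigma, n_samples=100_000, seed=0):
    """Monte Carlo estimate of E_xi[log pi(f(xi)) - q_hat(f(xi))] with f(xi) = mu + sigma * xi."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        a = mu + sigma * rng.gauss(0.0, 1.0)
        total += log_pi(a, mu, sigma) - q_hat(a)
    return total / n_samples

def policy_objective_exact(mu, sigma):
    """Closed form of the same expectation for this toy setting."""
    return (mu - 2.0) ** 2 + sigma ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi * math.e)

print(policy_objective_mc(0.5, 0.8), policy_objective_exact(0.5, 0.8))
```

Minimizing the closed form moves $\mu$ to the maximizer of $\hat{q}$, while the $-\log \sigma$ entropy term keeps $\sigma$ from collapsing to zero (here the optimum is $\sigma = 1/\sqrt{2}$).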

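A single-sample estimate of the gradient of Equation 8, with $s \sim \mathcal{D}$ and $\xi \sim \mathcal{N}(\bm{0}, \mathbf{I})$, is:

$$ \begin{align} \hat{\nabla}_{\bm{\theta_{\pi}}} J_{\pi}(\bm{\theta_{\pi}}) \stackrel{.}{=} \nabla_{\bm{\theta_{\pi}}} \log \pi(a | s, \bm{\theta_{\pi}}) + \big( \nabla_a \log \pi(a | s, \bm{\theta_{\pi}}) - \nabla_a \hat{q}(s, a, \bm{\theta}) \big) \nabla_{\bm{\theta_{\pi}}} f(\xi, s, \bm{\theta_{\pi}}), ~~~~ a = f(\xi, s, \bm{\theta_{\pi}}) \end{align} $$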
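The three terms of the policy gradient can be checked numerically in a toy setting that is an assumption for illustration: a single state, an unsquashed one-dimensional Gaussian policy with $\bm{\theta}_\pi = (\mu, \sigma)$ and $f(\xi) = \mu + \sigma \xi$, and $\hat{q}(s, a) = -(a - 2)^2$. For a fixed noise sample $\xi$, the terms are computed explicitly (the first one holds $a$ fixed) and compared with a finite-difference derivative of the per-sample objective. Function names are hypothetical.

```python
import math

def per_sample_objective(mu, sigma, xi):
    """log pi(f(xi) | mu, sigma) - q_hat(f(xi)) with f(xi) = mu + sigma * xi, q_hat(a) = -(a - 2)^2."""
    a = mu + sigma * xi
    log_pi = -0.5 * ((a - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    return log_pi + (a - 2.0) ** 2

def reparam_policy_grad(mu, sigma, xi):
    """Gradient w.r.t. (mu, sigma) assembled term by term from the policy-gradient formula."""
    a = mu + sigma * xi
    # Term 1: gradient of log pi w.r.t. the parameters, with the action a held fixed
    dlogpi_dmu = (a - mu) / sigma ** 2
    dlogpi_dsigma = (a - mu) ** 2 / sigma ** 3 - 1.0 / sigma
    # Term 2: grad_a log pi(a | s) - grad_a q_hat(s, a)
    dlogpi_da = -(a - mu) / sigma ** 2
    dq_da = -2.0 * (a - 2.0)
    # Term 3: gradient of f(xi, s, theta_pi) = mu + sigma * xi w.r.t. (mu, sigma)
    df_dmu, df_dsigma = 1.0, xi
    return (dlogpi_dmu + (dlogpi_da - dq_da) * df_dmu,
            dlogpi_dsigma + (dlogpi_da - dq_da) * df_dsigma)

def finite_difference_grad(mu, sigma, xi, eps=1e-6):
    """Central differences of the per-sample objective with xi held fixed."""
    d_mu = (per_sample_objective(mu + eps, sigma, xi) - per_sample_objective(mu - eps, sigma, xi)) / (2 * eps)
    d_sigma = (per_sample_objective(mu, sigma + eps, xi) - per_sample_objective(mu, sigma - eps, xi)) / (2 * eps)
    return d_mu, d_sigma

print(reparam_policy_grad(0.5, 0.8, 1.3))
print(finite_difference_grad(0.5, 0.8, 1.3))
```

In this example the first term and the $\nabla_a \log \pi$ term cancel for $\mu$, so the mean is driven only by $-\nabla_a \hat{q}$; that cancellation is specific to the Gaussian location parameter.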

Pseudocode

The following is a pseudocode specification of SAC in Python-like syntax.

Soft Actor-Critic for estimating $\pi_{\bm{\theta}_\pi} \approx \pi_*$

## Inputs
# Input: differentiable state-value function approximation v_hat(s | w)
v_hat = v_hat.init(w)
# Input: target state-value function approximation v_hat_bar(s | w_target)
v_hat_bar = v_hat.copy()
# Input: differentiable soft state-action function approximation q_hat_1(s, a | theta_q_1)
q_hat_1 = q_hat_1.init(theta_q_1)
# Input: differentiable soft state-action function approximation q_hat_2(s, a | theta_q_2)
q_hat_2 = q_hat_2.init(theta_q_2)
# Input: differentiable policy parameterization pi(a | s, theta_pi)
pi = pi.init(theta_pi)

## Parameters
# where appropriate the values are from the paper
# Stepsizes (3e-4 for all networks in the paper)
alpha_v_hat, alpha_q_hat, alpha_pi = 3e-4, 3e-4, 3e-4
# entropy coefficient
beta = 0.01
# Target smoothing coefficient for Polyak averaging (0.005 in the paper)
tau = 0.005
# Discount factor
gamma = 0.99

## Initialize
B = ReplayBuffer(size=1_000_000)
s = env.reset()

## Define
# double q-function minimum state-action function approximation q_hat(s, a)
q_hat = lambda s, a: min(q_hat_1(s, a), q_hat_2(s, a))

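done = False
while not done:
    # Act with a sample from the current stochastic policy and store the transition
    a = pi.sample(s, theta_pi)
    s_next, r, done = env.step(a)
    B.append((s, a, r, s_next))
    s = s_next

    # Sample a minibatch (256 transitions in the paper); each loss is a minibatch mean
    s_b, a_b, r_b, s_next_b = B.sample(size=256)

    # Update state-value function (Equation 2); a_tilde is sampled from the
    # current policy, not taken from the replay buffer
    a_tilde = pi.sample(s_b, theta_pi)
    v_target = stop_gradient(q_hat(s_b, a_tilde) - beta * log_pi(a_tilde, s_b, theta_pi))
    J_v_hat = mean(0.5 * (v_hat(s_b, w) - v_target)**2)
    w = w - alpha_v_hat * grad(J_v_hat, w)

    # Update both state-action functions (Equation 5) towards the target-network bootstrap
    q_target = stop_gradient(r_b + gamma * v_hat_bar(s_next_b, w_target))
    J_q_1 = mean(0.5 * (q_hat_1(s_b, a_b, theta_q_1) - q_target)**2)
    J_q_2 = mean(0.5 * (q_hat_2(s_b, a_b, theta_q_2) - q_target)**2)
    theta_q_1 = theta_q_1 - alpha_q_hat * grad(J_q_1, theta_q_1)
    theta_q_2 = theta_q_2 - alpha_q_hat * grad(J_q_2, theta_q_2)

    # Update policy parameters (Equation 8) with reparametrized actions a = f(xi, s, theta_pi)
    xi = normal(0, 1, size=shape(a_b))
    a_rep = f(xi, s_b, theta_pi)
    J_theta_pi = mean(beta * log_pi(a_rep, s_b, theta_pi) - q_hat(s_b, a_rep))
    theta_pi = theta_pi - alpha_pi * grad(J_theta_pi, theta_pi)

    # Update target state-value function parameters (Polyak averaging)
    w_target = tau * w + (1 - tau) * w_target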
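The pseudocode relies on a `ReplayBuffer` with `append` and `sample`. A minimal sketch of such a buffer, assuming transitions are stored as `(s, a, r, s_next)` tuples, the oldest transition is overwritten once the buffer is full, and minibatches are drawn uniformly without replacement. The class and its interface are hypothetical, not a specific library's API.

```python
import random

class ReplayBuffer:
    """Fixed-size FIFO store of (s, a, r, s_next) transitions with uniform sampling."""

    def __init__(self, size, seed=None):
        self.size = size
        self._storage = []
        self._next = 0  # slot to overwrite once the buffer is full
        self._rng = random.Random(seed)

    def __len__(self):
        return len(self._storage)

    def append(self, transition):
        if len(self._storage) < self.size:
            self._storage.append(transition)
        else:
            self._storage[self._next] = transition  # overwrite the oldest transition
        self._next = (self._next + 1) % self.size

    def sample(self, size):
        """Draw `size` distinct transitions uniformly and return them column-wise."""
        batch = self._rng.sample(self._storage, size)
        s, a, r, s_next = zip(*batch)
        return list(s), list(a), list(r), list(s_next)

buffer = ReplayBuffer(size=3, seed=0)
for t in range(5):
    buffer.append((t, 0.1 * t, 1.0, t + 1))
s_b, a_b, r_b, s_next_b = buffer.sample(size=2)
print(len(buffer))  # 3; only the transitions from t = 2, 3, 4 remain
```

Sampling from a large buffer (10^6 transitions in the paper) is what makes SAC off-policy: each update reuses transitions generated by older versions of the policy.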

Thoughts


  • I’m skeptical about the experimental setup, as some of the results don’t agree with the TD3 paper (Fujimoto et al., 2018), which was developed concurrently with this work.
  • To my understanding, SAC as developed above is not a true policy gradient method, since the policy is not updated with the policy gradient theorem but by backpropagating through the action-value function.
    • It would be interesting to know how SAC performs if the policy were optimized with the likelihood-ratio estimator from the policy gradient theorem instead of backpropagating through the action-value function.

References


  1. https://spinningup.openai.com/en/latest/algorithms/sac.html#id6

  2. Also known as the temperature parameter in the literature.

  3. In the limit $\beta \rightarrow 0$, the standard RL objective is recovered.

  4. The target is used to make the value-function updates stable. Its parameters can be updated with an exponential moving average (cf. Polyak averaging) of the state-value function parameters $\mathbf{w}$, or copied from $\mathbf{w}$ periodically.