Quick introduction to RLHF for fine-tuning LLMs to better match human preferences

Published November 28, 2024

What


  • Reinforcement learning from human feedback (RLHF) is a fine-tuning technique 1 to align LLM outputs to human preferences 2.

  • RLHF consists of four main steps 3:

    1. LLM pre-training $\rightarrow$ Base model
    2. Supervised fine-tuning (SFT) $\rightarrow$ Instruct-tuned model (IM)
    3. Reward model training $\rightarrow$ Reward model (RM)
    4. LLM policy optimization $\rightarrow$ Policy model (PM)
    • Steps 2-4 can be iterated continuously: the current best policy model is used to get a better instruct-tuned model, which is used to train a better reward model, which in turn is used to improve the policy model, and so on.

Why


  • RLHF has been useful in reducing toxic, biased, and harmful responses by using the preference signal from well-intentioned and unbiased human feedback 4.

How


  • To explain how it works, let’s go through each step in more detail, noting the input, output, model, and data used at each step.

  • LLM pre-training: Typically, we take an existing LLM that has been trained on a large corpus of Internet-scale data to predict the next token given an input token sequence.

  • Supervised fine-tuning: We fine-tune the base LLM on prompt-answer pairs using standard supervised learning.
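
As a rough sketch, this step can look like the following (assuming a Hugging Face-style causal LM; the model name, toy data, and hyper-parameters are placeholders, not the setup from the referenced papers):

```python
# Illustrative SFT sketch (placeholder model and data; loss-masking details omitted).
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # stand-in for the pre-trained base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Toy prompt-answer pairs; in practice this is a curated instruction dataset.
pairs = [
    ("Summarize: The cat sat on the mat.", "A cat sat on a mat."),
    ("Translate to French: Hello.", "Bonjour."),
]

def collate(batch):
    # Concatenate prompt and answer and train with standard next-token prediction.
    texts = [p + " " + a + tokenizer.eos_token for p, a in batch]
    enc = tokenizer(texts, return_tensors="pt", padding=True)
    enc["labels"] = enc["input_ids"].clone()  # for simplicity, prompt/padding tokens are not masked
    return enc

loader = DataLoader(pairs, batch_size=2, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
for batch in loader:
    loss = model(**batch).loss  # causal-LM cross-entropy
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```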

  • RM training: The outcome of this step is a model that takes in a concatenated sequence of prompt and answer tokens and outputs a scalar reward.

    • The starting model is the LLM from the previous SFT step, with the final layer that maps hidden states to the vocabulary removed and replaced by a linear layer that predicts the scalar reward 5.
    • The dataset is a set of tuples containing: prompt, preferred answer, rejected answer, reward for preferred answer, reward for rejected answer 6.
      • The model is trained to assign a higher reward to the answer preferred by the human labeler than to the rejected answer.
      • At a high level, the loss is: $$ -\log \sigma\big(\text{model}(\text{prompt}, \text{answer}_{\text{prefer}}) - \text{model}(\text{prompt}, \text{answer}_{\text{reject}})\big) $$
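
A minimal sketch of the reward model and its pairwise loss, assuming a PyTorch backbone with its vocabulary head replaced by a scalar head (the names and the toy preference tuple below are illustrative):

```python
# Illustrative reward-model sketch: LM backbone + linear head that outputs one scalar.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class RewardModel(nn.Module):
    def __init__(self, backbone_name="gpt2"):  # in practice, initialized from the SFT model
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_name)  # transformer without LM head
        self.reward_head = nn.Linear(self.backbone.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        hidden = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        last_idx = attention_mask.sum(dim=1) - 1                   # index of last non-pad token
        last_hidden = hidden[torch.arange(hidden.size(0)), last_idx]
        return self.reward_head(last_hidden).squeeze(-1)           # one scalar per (prompt, answer)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
rm = RewardModel()

# One preference tuple: (prompt, preferred answer, rejected answer).
prompt = "Explain RLHF in one sentence."
chosen = "RLHF fine-tunes an LLM so its outputs better match human preferences."
rejected = "No."

enc_c = tokenizer(prompt + " " + chosen, return_tensors="pt")
enc_r = tokenizer(prompt + " " + rejected, return_tensors="pt")
r_chosen = rm(enc_c["input_ids"], enc_c["attention_mask"])
r_rejected = rm(enc_r["input_ids"], enc_r["attention_mask"])

# Pairwise ranking loss: -log sigmoid(r_preferred - r_rejected).
loss = -F.logsigmoid(r_chosen - r_rejected).mean()
loss.backward()
```
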
  • LLM policy optimization: In this step, the LLM itself is fine-tuned using RL (a simplified sketch follows this list).

    • Specifically, the problem of predicting the next token given a sequence of input tokens (prompt) is cast as an RL problem as follows:
      • Agent: LLM model from step 2
      • State: Current prompt plus previously predicted answer tokens
        • For example, this is a trajectory experienced by the LLM agent:
          • $s_0$ = current prompt
          • $a_0$ = next token predicted by the LLM
          • $s_1 = (s_0, a_0)$
          • $a_1$ = next token predicted by the LLM
          • $s_2 = (s_1, a_1)$
      • Action: next token from vocabulary
      • Reward: scalar reward from RM
      • Environment: provides the initial user prompt and, at each step, appends the newly predicted answer token to the state
    • The policy takes in the state (the user prompt and the answer tokens generated so far) and produces the next token in the sequence.
      • It’s trained using the PPO algorithm 7:
        • PPO samples a set of trajectories for the given user prompt, generating different answers (e.g., by sampling with a non-zero temperature from the LLM).
        • The initial policy is frozen and used as a reference to compute the KL-divergence between it and the optimizing policy.
          • This is done to avoid reward hacking where the LLM might return useless answers just to maximize the expected reward (i.e., human preference).
        • The value function, which estimates the expected return from a given token sequence, is implemented by adding a separate prediction head to the LLM policy.
          • This value function is used to compute the advantage function in PPO 8.
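
To make the casting above concrete, here is a heavily simplified PPO-style update. It assumes a policy LM, a frozen reference copy of it, a separate value head (e.g., an nn.Linear over the policy's hidden states), a reward model like the one sketched earlier, and a tokenizer; the helper names and hyper-parameters are placeholders, and a real implementation (see the PPO references) adds GAE, batching, and caching of rollout-time log-probs:

```python
# Heavily simplified PPO-style update for the LLM policy (illustrative, not a faithful implementation).
import torch
import torch.nn.functional as F

def ppo_step(policy, value_head, ref_policy, reward_model, tokenizer, prompt,
             kl_coef=0.1, clip_eps=0.2, max_new_tokens=32):
    # Roll out one trajectory: state = prompt + tokens generated so far, action = next token.
    enc = tokenizer(prompt, return_tensors="pt")
    seq = policy.generate(**enc, do_sample=True, temperature=1.0,
                          max_new_tokens=max_new_tokens)                # [1, prompt_len + T]
    prompt_len = enc["input_ids"].shape[1]

    def action_logprobs(model, seq):
        # Log-probability of each generated (answer) token given its prefix.
        logits = model(seq).logits[:, :-1]
        logps = F.log_softmax(logits, dim=-1)
        taken = seq[:, 1:].unsqueeze(-1)
        return logps.gather(-1, taken).squeeze(-1)[:, prompt_len - 1:]  # [1, T]

    logp_new = action_logprobs(policy, seq)
    with torch.no_grad():
        logp_old = logp_new.detach()        # stand-in: real PPO caches these at rollout time
        logp_ref = action_logprobs(ref_policy, seq)
        # Scalar preference reward from the RM on the full (prompt, answer) sequence.
        final_reward = reward_model(seq, torch.ones_like(seq))

    # Per-token reward: KL penalty against the frozen reference, plus the RM reward at the last token.
    rewards = -kl_coef * (logp_old - logp_ref)
    rewards[:, -1] += final_reward

    # Value estimates from a separate head on the policy's hidden states (no GAE, discount = 1).
    hidden = policy(seq, output_hidden_states=True).hidden_states[-1][:, prompt_len - 1:-1]
    values = value_head(hidden).squeeze(-1)                             # [1, T]
    returns = torch.flip(torch.cumsum(torch.flip(rewards, [1]), dim=1), [1])
    advantages = returns - values.detach()

    # Clipped PPO objective plus a value-function loss.
    ratio = torch.exp(logp_new - logp_old)
    policy_loss = -torch.min(ratio * advantages,
                             torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages).mean()
    value_loss = F.mse_loss(values, returns)
    return policy_loss + 0.5 * value_loss
```

In practice, RLHF libraries handle the KL penalty, clipping, and advantage estimation, so this loop is rarely written by hand.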

Footnotes


  1. To the best of my knowledge, RLHF hasn’t been used in practice for end-to-end training of an LLM from scratch. ↩︎

  2. RLHF has also been used to match human preferences for similarity and diversity in images (cf. Ding, L., Zhang, J., Clune, J., Spector, L., & Lehman, J. (2023). Quality diversity through human feedback. arXiv preprint arXiv:2310.12103.). ↩︎

  3. Ouyang, L., Wu, J., Jiang, X., Almeida, D., Wainwright, C., Mishkin, P., … & Lowe, R. (2022). Training language models to follow instructions with human feedback. Advances in neural information processing systems, 35, 27730-27744. ↩︎

  4. I suppose one could train an LLM without RLHF to avoid these drawbacks by assembling an extremely well-curated and constrained dataset so that it is very unlikely to generate such ill-intentioned responses. ↩︎

  5. The reward model can also be a completely different model than the current LLM, as long as it respects the input/output API. ↩︎

  6. There can be more than two answers generated by the LLM for the same input prompt, compared for human preference. For simplicity, I have only mentioned two here. ↩︎

  7. https://spinningup.openai.com/en/latest/algorithms/ppo.html#quick-facts ↩︎

  8. https://spinningup.openai.com/en/latest/algorithms/ppo.html#id6 ↩︎