Prelude
This paper pill covers recent work by Zhu et al. [Zhu23F] and puts it into a short historical/algorithmic context. The advantage-induced policy alignment (APA) algorithm proposed therein is the latest installment in a long line of extensions of simple methods in reinforcement learning (RL), which was started by the reward weighted regression algorithm (derived from an expectation-maximization framework by Peters and Schaal [Pet07R]). This pill contains more references than usual, because I believe the context and ideas surrounding the algorithms mentioned here may be of general interest.
APA not only holds the promise of improving the state-of-the-art for aligning language models (LM) to human preferences, but could also be of general use in RL. What I find particularly fascinating about the works mentioned in this pill (and about the modern history of RL in general), is how powerful algorithms arise from small modifications to previously published (and often simple) methods that proved to work well in restricted settings. It leaves me wondering what other gems are still hidden in the literature, waiting to be rediscovered as state-of-the-art algorithms.
As a motivation for practitioners, we proceed with the experimental results of APA [Zhu23F], before diving deeper into the theoretical background and derivations.
Experimental Results
The APA algorithm follows the same training routines as typical reinforcement learning from human feedback (RLHF) methods (see e.g. Ziegler et al. [Zie20F]), but with a different loss. Thus, the authors can make a direct comparison of APA to RLHF with PPO on an already trained and fixed reward model. They also include a comparison to another loss derived from advantage weighted regression (AWR), which to my knowledge is the first time this method is applied to language modeling. Since one of the main novelties of the original AWR paper was the extension of the AWR loss to off-policy learning, the poor performance of AWR in the on-policy setting is maybe not too surprising. Thus, the main practically relevant comparison is between APA and PPO.
Figure 1 from [Zhu23F]. Comparison of the performance of three methods on the Helpfulness and Harmlessness dataset. Top: The x-axis represents the total steps, which are proportional to the amount of data used in the training procedure. The y-axis is the reward evaluated by the same reward model. Bottom: The x-axis represents the total steps. The y-axis is the KL divergence between the trained model and the initial model.
We can see that APA outperforms PPO in terms of reward while maintaining comparable control of the KL divergence to the non-finetuned model.
For the largest of the tested models (which is still small in comparison to LLMs in use today), APA was able to achieve higher reward at the cost of a higher KL divergence. Moreover, looking at the graphs, there seems to be a tendency to deviate more strongly from the initial policy as the model size increases. On the other hand, for smaller models, APA is unambiguously better than PPO.
The practical conclusion from these experiments is unclear to me at this moment. One would at least need to perform an evaluation on standard downstream tasks to see whether the performance of the model has actually improved. Unfortunately, the authors provide no such evaluation. This does not mean that APA is not a promising method, and I hope a future evaluation with larger models and a variety of downstream tasks will shed more light on it.
From Weighted Regression to APA
Notation
We use standard notation for RL: $s$ and $a$ denote states and actions, $\pi(a \mid s)$ a policy, $\pi_\theta$ a policy with parameters $\theta$, $J(\pi)$ the expected discounted return of $\pi$, $A^{\pi}(s, a)$ its advantage function, $d_{\pi}$ the state distribution induced by $\pi$, and $D_{\mathrm{KL}}$ the Kullback-Leibler divergence. The policy used to collect data in iteration $k$ is denoted by $\pi_k$ and the corresponding dataset by $\mathcal{D}_k$.
Core Ideas of Weighted Regression
RL algorithms from the weighted regression family are generally based on the following idea: given some policy $\pi_k$ that was used to sample data, construct an improved target policy in closed form as a reweighted version of $\pi_k$, and then project this target back onto the space of parameterized policies by a regression on the sampled data.
A particularly clean manifestation of this scheme can be found in the AWR paper [Pen19A]. It goes roughly as follows:
1. Start with the first-order approximation of the expected improvement of a policy $\pi$ over the sampling policy $\pi_k$:
$$J(\pi) - J(\pi_k) \approx \mathbb{E}_{s \sim d_{\pi_k}}\, \mathbb{E}_{a \sim \pi(\cdot \mid s)}\!\left[ A^{\pi_k}(s, a) \right], \tag{1}$$
where $J(\pi)$ is the expectation of the discounted sum of future rewards (i.e. the objective one wants to maximize), and $A^{\pi_k}$ is the advantage function of $\pi_k$. (This expression is the starting point of several policy gradient algorithms, in particular trust region policy optimization and proximal policy optimization. A derivation can be found e.g. in our blog post on this topic.)
2. Restrict the solution space to policies that are close to $\pi_k$ in terms of KL divergence, resulting in the Lagrangian
$$\mathcal{L}(\pi, \beta) = \mathbb{E}_{s \sim d_{\pi_k}}\, \mathbb{E}_{a \sim \pi(\cdot \mid s)}\!\left[ A^{\pi_k}(s, a) \right] + \beta\left( \epsilon - \mathbb{E}_{s \sim d_{\pi_k}}\!\left[ D_{\mathrm{KL}}\!\left( \pi(\cdot \mid s) \,\big\|\, \pi_k(\cdot \mid s) \right) \right] \right), \tag{2}$$
where $\beta$ is the Lagrange multiplier and $\epsilon$ some small constant.
3. The analytic solution of the above problem is
$$\tilde{\pi}(a \mid s) = \frac{1}{Z(s)}\, \pi_k(a \mid s)\, \exp\!\left( \frac{1}{\beta} A^{\pi_k}(s, a) \right), \tag{3}$$
where $Z(s)$ is a state-dependent normalization factor (the partition function). As noted above, this is a weighted version of $\pi_k$.
4. Find an improved set of parameters $\theta_{k+1}$ by minimizing the difference between $\pi_\theta$ and $\tilde{\pi}$ on the dataset $\mathcal{D}_k$ according to some distance function $d$ for probability distributions:
$$\theta_{k+1} = \arg\min_\theta\; \mathbb{E}_{s \sim \mathcal{D}_k}\!\left[ d\!\left( \tilde{\pi}(\cdot \mid s), \pi_\theta(\cdot \mid s) \right) \right]. \tag{4}$$
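To make steps 3 and 4 concrete, here is a minimal NumPy sketch of one weighted-regression iteration for a tabular softmax policy; the state/action dimensions, the advantage values, the temperature $\beta$ and the learning rate are illustrative choices of mine, not taken from [Pen19A] or [Zhu23F].

```python
# One weighted-regression iteration (Equations 3 and 4) for a tabular softmax policy.
import numpy as np

rng = np.random.default_rng(0)
n_states, n_actions, beta = 4, 3, 1.0

def softmax(logits):
    z = np.exp(logits - logits.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)

# Sampling policy pi_k and its advantage estimates (stand-ins for learned quantities).
logits_k = rng.normal(size=(n_states, n_actions))
pi_k = softmax(logits_k)
advantages = rng.normal(size=(n_states, n_actions))

# Equation 3: reweight pi_k by exponentiated advantages; Z(s) is the row-wise sum.
unnormalized = pi_k * np.exp(advantages / beta)
pi_tilde = unnormalized / unnormalized.sum(axis=1, keepdims=True)

# Equation 4: project pi_tilde back onto the parameterized family by gradient
# descent on the cross-entropy (equivalently, the KL divergence) w.r.t. the logits.
theta = logits_k.copy()
for _ in range(500):
    grad = softmax(theta) - pi_tilde  # gradient of -sum_a pi_tilde(a|s) log pi_theta(a|s)
    theta -= 0.5 * grad

print(np.abs(softmax(theta) - pi_tilde).max())  # should be close to zero
```

In this toy setting the projection recovers $\tilde{\pi}$ exactly; with function approximation and finite samples, the projection step is what turns the scheme into a (weighted) regression problem.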
Weighted Regression in Practice
There are multiple variations of this basic scheme. In the original AWR paper, the authors use the KL divergence as the distance function and make the assumption that the partition function $Z(s)$ is approximately constant, so that it can be dropped from the objective.
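Concretely, with the KL divergence as $d$ and the partition function dropped, the projection step of Equation 4 reduces (up to constants) to an advantage-weighted maximum-likelihood objective; my rendering of the resulting AWR update in the notation above is
$$\theta_{k+1} = \arg\max_\theta\; \mathbb{E}_{(s, a) \sim \mathcal{D}_k}\!\left[ \log \pi_\theta(a \mid s)\, \exp\!\left( \frac{1}{\beta} A^{\pi_k}(s, a) \right) \right],$$
i.e. a supervised regression onto previously sampled actions with exponentiated advantages as sample weights, which is what gives the method its name.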
Early works by Peters, Schaal, and Neumann [Pet07R, Neu08F] arrived at similar schemes by applying the expectation-maximization (EM) framework to a lower bound on the expected reward (see also Strupl et al. [Str22R] for a nice theoretical analysis and extension of the original reward weighted regression). The well-known REPS algorithm [Pet10R] and follow-up work make use of the analytic solution to an objective similar to Equation 2. The widely used trust region and proximal policy optimization algorithms [Sch15T, Sch17P] directly aim at optimizing Equation 2 or a variation thereof, but without using the analytic solution. Instead, they rely on policy gradient methods and practical implementation details (see e.g. our blog post Natural, Trust Region and Proximal Policy Optimization for a detailed discussion).
The APA Algorithm
The APA algorithm is the newest addition to the family of weighted regression algorithms and, to my knowledge, the first to tackle language modeling. It makes the following modifications to the basic scheme outlined above:
- For language modeling, one wants the fine-tuned model to be close to the original pre-trained model $\pi_{\mathrm{init}}$ in terms of KL divergence, i.e. $\pi_k$ in the constraint is replaced with $\pi_{\mathrm{init}}$. Therefore, the constraint in Equation 2 is adjusted accordingly, resulting in the optimal policy
$$\tilde{\pi}(a \mid s) = \frac{1}{Z(s)}\, \pi_{\mathrm{init}}(a \mid s)\, \exp\!\left( \frac{1}{\lambda} A^{\pi_k}(s, a) \right).$$
- The authors argue that, for LMs, one can safely assume $Z(s) \approx 1$. (Their intuition behind this is that, to first order in $1/\lambda$, the partition function is $Z(s) \approx 1 + \frac{1}{\lambda}\, \mathbb{E}_{a \sim \pi_{\mathrm{init}}(\cdot \mid s)}\!\left[ A^{\pi_k}(s, a) \right]$. Assuming that the tuned policy is close enough to the initial policy for the advantages to be similar, one gets $\mathbb{E}_{a \sim \pi_{\mathrm{init}}}\!\left[ A^{\pi_k}(s, a) \right] \approx \mathbb{E}_{a \sim \pi_k}\!\left[ A^{\pi_k}(s, a) \right] = 0$ and therefore $Z(s) \approx 1$. During training they notice that the loss decreases by very little, suggesting that the tuned policy is indeed very close to $\pi_{\mathrm{init}}$. On a personal note, while I don't necessarily see the validity of these intermediate approximations (the value of $\lambda$ used in the experiments is not particularly large), the final justification is that the resulting algorithm works well.)
- As the distance function for the projection step in Equation 4, they use the squared error between the logits (log-probabilities) of $\tilde{\pi}$ and $\pi_\theta$, weighted by previously sampled data, i.e.
$$d\!\left( \tilde{\pi}(\cdot \mid s), \pi_\theta(\cdot \mid s) \right) = \mathbb{E}_{a \sim \pi_k(\cdot \mid s)}\!\left[ \left( \log \pi_\theta(a \mid s) - \log \tilde{\pi}(a \mid s) \right)^2 \right].$$
These design choices result in the APA loss:
$$\mathcal{L}_{\mathrm{APA}}(\theta) = \mathbb{E}_{(s, a) \sim \mathcal{D}_k}\!\left[ \left( \log \frac{\pi_\theta(a \mid s)}{\pi_{\mathrm{init}}(a \mid s)} - \frac{1}{\lambda} A^{\pi_k}(s, a) \right)^2 \right]. \tag{5}$$
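To make Equation 5 concrete, here is a minimal PyTorch-style sketch of how the loss could be computed from per-token log-probabilities; the tensor shapes, the masking convention and the default value of lam are assumptions of mine and not taken from the authors' implementation.

```python
# Sketch of the APA loss (Equation 5) on per-token log-probabilities.
import torch

def apa_loss(
    logp_theta: torch.Tensor,  # log pi_theta(a_t | s_t), shape (batch, seq_len)
    logp_init: torch.Tensor,   # log pi_init(a_t | s_t), same shape
    advantages: torch.Tensor,  # advantage estimates A^{pi_k}(s_t, a_t), same shape
    mask: torch.Tensor,        # 1.0 for generated tokens, 0.0 for prompt/padding
    lam: float = 1.0,          # KL coefficient lambda (illustrative default)
) -> torch.Tensor:
    # Squared error between the log-ratio log(pi_theta / pi_init) and A / lambda,
    # averaged over the tokens selected by the mask.
    residual = (logp_theta - logp_init) - advantages / lam
    return (residual.pow(2) * mask).sum() / mask.sum()
```

Note that minimizing this loss pulls $\log \pi_\theta$ towards $\log \pi_{\mathrm{init}} + A^{\pi_k}/\lambda$, i.e. towards the (unnormalized) target policy from the first modification above.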
When applying APA for training LMs, Zhu et al. use the weighted regression scheme in its on-policy version, i.e. they do not reuse samples from previous iterations. Thus, the method can be directly compared to other on-policy algorithms, differing from them only in the loss function Equation 5.
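Schematically, the resulting training procedure is the standard on-policy loop with the APA loss plugged in. The sketch below is my own rendering, with the pipeline components passed in as hypothetical callables rather than taken from the trlx-based implementation.

```python
# Schematic on-policy training loop for APA: every iteration uses fresh rollouts
# from the current policy and discards them afterwards (no replay buffer).
from typing import Callable

def train_apa(
    policy,                        # the model being fine-tuned (pi_theta)
    init_policy,                   # frozen copy of the pre-trained model (pi_init)
    sample_rollouts: Callable,     # generates completions with the current policy
    estimate_advantages: Callable, # scores rollouts (reward model + value estimates)
    apa_update: Callable,          # gradient step(s) on the APA loss (Equation 5)
    num_iterations: int,
):
    for _ in range(num_iterations):
        batch = sample_rollouts(policy)        # on-policy data
        batch = estimate_advantages(batch)     # attach A^{pi_k}(s, a)
        apa_update(policy, init_policy, batch) # minimize Equation 5 on this batch only
    return policy
```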
Discussion
From the theoretical point of view, the new advantage-induced policy alignment algorithm resulted from a small modification of advantage-weighted regression, and as such is not a major breakthrough.
On the practical side, however, this modification seems to result in a significant improvement in performance. Therefore, this work may be important for practical purposes, as well as for rejuvenating interest in the family of weighted regression algorithms.
I found it interesting to follow the discussions that led to the rejection of the AWR paper, as well as of popular follow-up works like AWAC, from the conferences they were submitted to; see the discussions on OpenReview here and here. The main reasons for rejection were that the new algorithms were neither sufficiently different from existing weighted regression approaches nor more performant than state-of-the-art algorithms. We see, however, that simple and efficient ideas, while perhaps not immediately beating the state of the art, can still be highly influential and eventually lead to significant improvements in performance. Many now-famous papers in machine learning are based on such small theoretical modifications, so I believe it is fair to say that the authors of AWR were somewhat unlucky with the timing of their submission.
Software-wise, APA is built on top of the modern trlx library for combining transformers with RL, and as such it could be a useful starting point for custom RLHF algorithms that follow good software engineering practices. It is refreshing to see research code that is going in the direction of being reusable and extensible. The code is freely available on GitHub under the MIT license.