Direct Preference Optimization Paper Review


Published June 28, 2023

ArXiv / Github


While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper, we leverage a mapping between reward functions and optimal policies to show that this constrained reward maximization problem can be optimized exactly with a single stage of policy training, essentially solving a classification problem on the human preference data. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for fitting a reward model, sampling from the LM during fine-tuning, or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds RLHF’s ability to control sentiment of generations and improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
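  • Conventional RLHF is complex.
  • Using the mapping between reward functions and optimal policies, the reward-maximization problem can be replaced by a single-stage classification problem on the preference data.
  • The new method is called DPO: Direct Preference Optimization.
    • No reward model training
      • PPO-based RLHF trains a separate reward model.
      • Methods such as RRHF may also train a reward model.
    • No sampling from the LM during fine-tuning
      • Both PPO and RRHF sample outputs for each instruction.
        • That said, RRHF also works with generations from models other than the policy being trained.
    • Little hyperparameter tuning
      • RLHF is highly sensitive to hyperparameters such as the learning rate, so a method that is insensitive to them makes training much easier.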
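The "single classification problem" claim rests on a short derivation. Below is my summary of the paper's argument.

Notation:
- r is the reward.
- π_ref is the reference (SFT) policy and π_θ is the policy being trained.
- β weights the KL penalty.
- σ is the logistic function.
- y_w and y_l are the preferred and dispreferred responses.

The derivation has four steps:
1. State the KL-constrained RLHF objective.
2. Write its closed-form optimum.
3. Solve that optimum for r.
4. Substitute into the Bradley-Terry preference model.

```latex
\begin{aligned}
&\max_{\pi}\; \mathbb{E}_{x\sim\mathcal{D},\,y\sim\pi(\cdot\mid x)}\big[r(x,y)\big]
  - \beta\,\mathbb{D}_{\mathrm{KL}}\big[\pi(y\mid x)\,\|\,\pi_{\mathrm{ref}}(y\mid x)\big] \\
&\pi_r(y\mid x) = \frac{1}{Z(x)}\,\pi_{\mathrm{ref}}(y\mid x)\exp\!\Big(\frac{1}{\beta}r(x,y)\Big),
  \qquad Z(x) = \sum_{y}\pi_{\mathrm{ref}}(y\mid x)\exp\!\Big(\frac{1}{\beta}r(x,y)\Big) \\
&r(x,y) = \beta\log\frac{\pi_r(y\mid x)}{\pi_{\mathrm{ref}}(y\mid x)} + \beta\log Z(x) \\
&p^*(y_w \succ y_l \mid x) = \sigma\big(r(x,y_w) - r(x,y_l)\big) \\
&\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}}) =
  -\mathbb{E}_{(x,y_w,y_l)\sim\mathcal{D}}\Big[\log\sigma\Big(
  \beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)}
  - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\Big)\Big]
\end{aligned}
```

The term β log Z(x) is identical for y_w and y_l, so it cancels in the difference. What remains depends only on π_θ and π_ref, which is why no reward model or sampling is needed. Minimizing the last line is a logistic (binary classification) loss over preference pairs.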


DPO: learning human preferences without RL
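As shown on the left of the figure above, RLHF with PPO runs in two steps. First, a reward model is fit to the human preference data by maximum likelihood (MLE). Then that reward model is used to further train the language model with RL.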
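The two RLHF stages described above each reduce to a simple per-example quantity. Below is a minimal pure-Python sketch; it assumes scalar rewards and sequence log-probabilities are already computed, and all function names are hypothetical.

- **Reward model stage:** fit with the Bradley-Terry negative log-likelihood, -log σ(r_w - r_l).
- **RL stage:** maximize the reward minus a β-weighted log-ratio between the policy and the reference model. The log-ratio serves as a per-sample estimate of the KL penalty.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def reward_model_nll(r_chosen: float, r_rejected: float) -> float:
    """Bradley-Terry negative log-likelihood for one preference pair:
    -log sigmoid(r(x, y_w) - r(x, y_l))."""
    return -log_sigmoid(r_chosen - r_rejected)


def kl_penalized_reward(r: float, policy_logp: float, ref_logp: float, beta: float) -> float:
    """Per-sample objective of the RL stage:
    r(x, y) - beta * (log pi(y|x) - log pi_ref(y|x))."""
    return r - beta * (policy_logp - ref_logp)


# Equal rewards mean the reward model is 50/50 on the pair, so the loss is log 2.
print(reward_model_nll(1.0, 1.0))
# The policy drifted 1 nat above the reference on this sample, costing beta * 1.
print(kl_penalized_reward(r=2.0, policy_logp=-5.0, ref_logp=-6.0, beta=0.1))
```

DPO keeps the same Bradley-Terry form as `reward_model_nll`. It simply plugs in an implicit reward defined by the policy itself, so the second stage disappears.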
The DPO training pipeline consists of two main stages.
  1. Fine-tune the base LM with SFT to obtain the SFT model.
  2. Learn the preferences on top of the SFT model.
Stage 1 is carried out exactly like ordinary Alpaca-style instruction tuning.
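Here is a minimal sketch of how the two stages connect, under two assumptions.

- **Stage 1:** it is assumed to follow the common Alpaca convention of training only on response tokens, with prompt tokens masked out of the loss.
- **Stage 2:** the paper sets π_ref to the SFT model. Starting the trainable policy from the same weights is the usual setup and is also assumed here.

`sft_loss` and `init_dpo_models` are hypothetical helpers. Model weights are faked as a plain dict.

```python
import copy
from typing import Dict, List, Tuple


def sft_loss(token_logps: List[float], response_mask: List[bool]) -> float:
    """Mean negative log-likelihood over response tokens only.

    token_logps[t] is log p(token_t | tokens_<t) under the model being
    fine-tuned; response_mask[t] is False for prompt tokens, which are
    excluded from the loss (Alpaca-style masking, assumed here).
    """
    picked = [lp for lp, keep in zip(token_logps, response_mask) if keep]
    if not picked:
        raise ValueError("example has no response tokens")
    return -sum(picked) / len(picked)


def init_dpo_models(sft_params: Dict[str, List[float]]) -> Tuple[Dict[str, List[float]], Dict[str, List[float]]]:
    """Stage 2 setup: the trainable policy and the frozen reference both
    start as independent copies of the SFT weights."""
    policy = copy.deepcopy(sft_params)     # pi_theta, updated by the DPO loss
    reference = copy.deepcopy(sft_params)  # pi_ref, never updated
    return policy, reference


# Toy usage: a two-token prompt followed by a two-token response.
print(sft_loss([-1.0, -2.0, -0.5, -1.5], [False, False, True, True]))  # 1.0
```

Deep copies matter here. Updating the policy must not silently move the reference, because the DPO loss measures the policy against that fixed anchor.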

DPO Loss

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps: torch.FloatTensor,
             policy_rejected_logps: torch.FloatTensor,
             reference_chosen_logps: torch.FloatTensor,
             reference_rejected_logps: torch.FloatTensor,
             beta: float,
             reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the DPO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
        beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
        reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the DPO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
    """
    # log pi(y_w|x) - log pi(y_l|x), under the policy and under the reference
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    if reference_free:
        ref_logratios = 0

    # -log sigmoid(beta * (policy log-ratio - reference log-ratio)), per example
    logits = pi_logratios - ref_logratios
    losses = -F.logsigmoid(beta * logits)

    # Implicit rewards beta * log(pi / pi_ref), detached: used for logging only
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards
```
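To get a feel for the numbers, here is a scalar, pure-Python restatement of the same computation for a single preference pair. It is a sketch written for illustration, not part of the DPO code; `dpo_loss_scalar` and `dpo_grad_weight` are hypothetical names.

`dpo_grad_weight` computes σ(r_l - r_w), where r_w and r_l are the chosen and rejected rewards. In the paper's gradient analysis, this factor scales each pair's update, so pairs the implicit reward currently ranks wrongly get pushed hardest.

```python
import math
from typing import Tuple


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z)), the scalar analogue of F.logsigmoid."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss_scalar(policy_chosen_logp: float, policy_rejected_logp: float,
                    ref_chosen_logp: float, ref_rejected_logp: float,
                    beta: float, reference_free: bool = False) -> Tuple[float, float, float]:
    """Single-example version of dpo_loss: returns (loss, chosen_reward, rejected_reward)."""
    pi_logratio = policy_chosen_logp - policy_rejected_logp
    ref_logratio = 0.0 if reference_free else ref_chosen_logp - ref_rejected_logp
    logits = pi_logratio - ref_logratio
    loss = -log_sigmoid(beta * logits)
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    return loss, chosen_reward, rejected_reward


def dpo_grad_weight(chosen_reward: float, rejected_reward: float) -> float:
    """sigma(r_l - r_w): close to 1 when the implicit reward ranks the pair
    the wrong way, close to 0 when it already ranks it correctly by a wide margin."""
    return math.exp(log_sigmoid(rejected_reward - chosen_reward))


# Relative to the reference, the policy raised log pi(y_w|x) by 1 and lowered
# log pi(y_l|x) by 1, so beta * logits = 0.1 * 2 = 0.2.
loss, r_w, r_l = dpo_loss_scalar(-10.0, -12.0, -11.0, -11.0, beta=0.1)
print(loss, r_w, r_l)
print(dpo_grad_weight(r_w, r_l))
```

The loss is exactly -log σ(r_w - r_l) on the implicit rewards, which is the Bradley-Terry form with the reward model replaced by β log(π_θ/π_ref). When the policy equals the reference, every pair starts at loss log 2.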
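The DPO loss is computed from four log-probabilities: policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, and reference_rejected_logps. These are the log-probabilities of the chosen and the rejected response under the policy model and under the reference model.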
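Each of those four values is a sequence-level log-probability, log π(y|x) = Σ_t log π(y_t | x, y_<t). Below is a minimal pure-Python sketch of computing it from per-position logits. It makes two assumptions:

- **Masking:** prompt and padding positions carry the label -100, the same default as PyTorch's `ignore_index`.
- **Alignment:** logits at position t predict the token at t + 1.

`sequence_logp` and `IGNORE_INDEX` are hypothetical names.

```python
import math
from typing import List

IGNORE_INDEX = -100  # assumed label value for prompt and padding positions


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def sequence_logp(logits: List[List[float]], labels: List[int]) -> float:
    """log pi(y | x) = sum_t log pi(y_t | x, y_<t), over response tokens only.

    logits[t] is the next-token distribution after reading position t,
    so it is scored against labels[t + 1] (a shift by one).
    """
    total = 0.0
    for row, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue  # prompt/padding tokens do not count toward log pi(y|x)
        total += log_softmax(row)[target]
    return total


# Uniform logits over a 2-token vocabulary: each response token has p = 0.5.
uniform = [[0.0, 0.0]] * 4
print(sequence_logp(uniform, [IGNORE_INDEX, IGNORE_INDEX, 1, 0]))  # 2 * log(0.5)
```

The per-token terms are summed, not averaged, because the DPO formula uses the full-sequence probability π(y|x). Running this for (policy, reference) × (chosen, rejected) yields the four inputs to `dpo_loss`.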
```python
def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True):
    """Compute the SFT or DPO loss and other metrics for the given batch of inputs."""

    metrics = {}
    train_test = 'train' if train else 'eval'

    if == 'dpo':
        policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch)
        with torch.no_grad():
            reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch)
        losses, chosen_rewards, rejected_rewards = dpo_loss(
            policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps,
            beta=loss_config.beta, reference_free=loss_config.reference_free)

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size)
        rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size)
        reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size)

        metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist()

        policy_rejected_logps = all_gather_if_needed(policy_rejected_logps.detach(), self.rank, self.world_size)
        metrics[f'logps_{train_test}/rejected'] = policy_rejected_logps.cpu().numpy().tolist()
```
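As I understand it, `concatenated_forward` stacks the chosen and rejected sequences into one batch so the model runs a single forward pass, then splits the outputs back into the two halves. The pure-Python sketch below illustrates that idea along with the logged reward metrics. All three function names are hypothetical, and padding is shown on token ids only; in a real setup, padded positions would also need masked labels.

```python
from typing import Dict, List, Tuple


def concatenate_pairs(chosen: List[List[int]], rejected: List[List[int]], pad_id: int) -> List[List[int]]:
    """Stack all chosen sequences, then all rejected ones, right-padded to a
    common length, so a single forward pass scores both sides of every pair."""
    rows = chosen + rejected
    max_len = max(len(r) for r in rows)
    return [r + [pad_id] * (max_len - len(r)) for r in rows]


def split_halves(values: List[float]) -> Tuple[List[float], List[float]]:
    """Undo the stacking on per-sequence outputs: (chosen part, rejected part)."""
    half = len(values) // 2
    return values[:half], values[half:]


def reward_metrics(chosen_rewards: List[float], rejected_rewards: List[float]) -> Dict[str, float]:
    """Batch averages of the per-example accuracies and margins that
    get_batch_metrics logs under rewards_*/accuracies and rewards_*/margins."""
    pairs = list(zip(chosen_rewards, rejected_rewards))
    accuracy = sum(1.0 for c, r in pairs if c > r) / len(pairs)
    margin = sum(c - r for c, r in pairs) / len(pairs)
    return {"accuracy": accuracy, "margin": margin}


batch = concatenate_pairs([[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]], pad_id=0)
print(batch)  # [[1, 2, 3, 0], [4, 0, 0, 0], [5, 6, 0, 0], [7, 8, 9, 10]]
chosen_r, rejected_r = split_halves([0.3, -0.1, 0.1, 0.2])
print(reward_metrics(chosen_r, rejected_r))
```

Reward accuracy is the fraction of pairs where β log(π_θ/π_ref) already ranks the chosen response above the rejected one. A rising accuracy and margin during training suggest the implicit reward is learning the preference data.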