RRHF 논문 & 코드리뷰: Rank Responses to Align Language Models with Human Feedback without tears
RRHF 논문 & 코드리뷰: Rank Responses to Align Language Models with Human Feedback without tears

RRHF 논문 & 코드리뷰: Rank Responses to Align Language Models with Human Feedback without tears

Tags
NLP
논문리뷰
RLHF
Published
Published June 23, 2023

ArXiv

Abstract

Reinforcement Learning from Human Feedback (RLHF) facilitates the alignment of large language models with human preferences, significantly enhancing the quality of interactions between humans and these models. InstructGPT implements RLHF through several stages, including Supervised Fine-Tuning (SFT), reward model training, and Proximal Policy Optimization (PPO). PPO, however, is sensitive to hyperparameters and requires a minimum of four models in its standard implementation, which makes it hard to train. In contrast, we propose a novel learning paradigm called RRHF, which scores responses generated by different sampling policies and learns to align them with human preferences through ranking loss. RRHF can efficiently align language model output probabilities with human preferences as robust as fine-tuning and it only needs 1 to 2 models during tuning. In addition, RRHF can be considered an extension of SFT and reward models while being simpler than PPO in terms of coding, model counts, and hyperparameters. The entire alignment process can be accomplished within a single RRHF training session. We evaluate RRHF using LLaMA and Alpaca on Helpful and Harmless data, demonstrating performance comparable to PPO.
→ PPO 할 때는 Reward용 모델을 새로 만들어야 하지만 RRHF에서는 Ranking Loss를 대신 사용해서 PPO처럼 복잡하지 않고서도 RLHF같은 효과를 줄 수 있다고 함
  • Wombat이라는 LLAMA 기반 챗봇도 함께 릴리즈

Github Repo

RRHF?

왼쪽이 PPO, 우측이 RRHF
PPO가 LM 외에 3종류의 더 많은 모델을 강화학습 Proxy로 사용하지만, RRHF에서는 단일 Reward Model만 사용
왼쪽이 PPO, 우측이 RRHF PPO가 LM 외에 3종류의 더 많은 모델을 강화학습 Proxy로 사용하지만, RRHF에서는 단일 Reward Model만 사용
RRHF = Rank Response to align Human Feedback
  • 기존의 OpenAI, TRL등에서 PPO를 통해 Reward Model / Reference Model / Value Model 구조를 추가로 만들어 Model의 Response에 대해 강화학습할 수 있는 환경을 만들어주었지만, 단순 Language Model 외에도 Proximator로 동작할 타 모델을 상당히 많이 만들어줘야 하는 문제가 있다.

기존의 RLHF(PPO)?

→ 너무 복잡함. Scaling 이슈가 있음.
RLHF-PPO를 설명하는 삽화 이미지. Language Model이 생성한 결과물을 기반으로 사람이 Scoring을 한뒤, 그 Score를 맞추도록(=사람의 선호도를 학습한) 일종의 회귀/분류형 Score Model(=Reward Model)을 학습한다.
RLHF-PPO를 설명하는 삽화 이미지. Language Model이 생성한 결과물을 기반으로 사람이 Scoring을 한뒤, 그 Score를 맞추도록(=사람의 선호도를 학습한) 일종의 회귀/분류형 Score Model(=Reward Model)을 학습한다.
  • 위와 같이 PPO에서는 주어진 Prompts → LM이 생성 → 사람이 검수 → Scoring
  • Score ~ 생성된 결과물 Pair를 기준으로 → P(Score | 주어진 생성물)를 만드는 Reward Model을 학습시킨다.
notion image
  • 이때, 이걸 강화학습에 대응시키면..
    • Policy = “Initial LM이 Prompt를 받아서 → Text 생성하는 것”
    • Action Space = LM이 가진 Vocab의 수(=생성할 수 있는 토큰..)
    • Observation Space = Input token sequence = 토큰(vocab)^문장길이(seq len)
    • Reward Function = Preference Model + Policy Shift constraint
      • Preference model은 단순 값 하나(Scalar value)를 반환
      • Policy Shift constraint는 KL divergence를 Scaling해서 제한
        • Original LM이 생성한 텍스트와 너무 멀어지지 않도록 제한하는 KLD
        • KL Divergence를 Text 생성에서 Huggingface 모델로 간단히 계산한 예제 코드
          • 단, 이때 RL Policy Model이 생성한 Text와 Prob을 쓰는게 아니라..
            • RL Policy가 text를 생성하고
            • 해당 text를 기존 base LM에다가 다시 넣어서 → 그 Prob을 가져다 쓴다.
          from transformers import AutoTokenizer, AutoModelForCausalLM from torch.distributions import Categorical import torch.nn.functional as F import torch tokenizer = AutoTokenizer.from_pretrained("beomi/KoRWKV-6B") # KoRWKV-6B를 기본 PLM으로 간주 model_pretrained = AutoModelForCausalLM.from_pretrained("beomi/KoRWKV-6B") # KoAlpaca로 파튜한 모델을 Policy등으로 학습된 Model이라고 간주 model_policy = AutoModelForCausalLM.from_pretrained("beomi/KoAlpaca-KoRWKV-6B") def generate_text(model, prompt, max_length=50): input_ids = tokenizer.encode(prompt, return_tensors='pt') output = model.generate(input_ids, max_length=max_length, temperature=1, do_sample=True) return tokenizer.decode(output[0]) prompt = "안녕하세요, " # 텍스트 생성하고.. generated_pretrained = generate_text(model_pretrained, prompt) generated_policy = generate_text(model_policy, prompt) # KLD에서는 하나는 Log prob이고 하나는 일반 prob으로 계산한다. # KL(P || Q) = Σ P(x) * log(P(x) / Q(x)) def get_log_probs(model, text): input_ids = tokenizer.encode(text, return_tensors='pt') with torch.no_grad(): outputs = model(input_ids) logits = outputs.logits return F.log_softmax(logits[:, :-1], dim=-1) def get_probs(model, text): input_ids = tokenizer.encode(text, return_tensors='pt') with torch.no_grad(): outputs = model(input_ids) logits = outputs.logits return F.softmax(logits[:, :-1], dim=-1) log_probs_policy = get_log_probs(model_policy, generated_policy) # Note that we're using the same text as for the policy model probs_pretrained = get_probs(model_pretrained, generated_policy) # P(x), Q(x), log(P(x)), log(Q(x)) log_probs_policy = get_log_probs(model_policy, generated_policy) probs_pretrained = get_probs(model_pretrained, generated_pretrained) # Flatten the tensor into 2D for KL Div calculation log_probs_policy_2d = log_probs_policy.view(-1, log_probs_policy.size(-1)) probs_pretrained_2d = probs_pretrained.view(-1, probs_pretrained.size(-1)) # Calculate KL Divergence kldiv_loss = torch.nn.KLDivLoss(reduction='batchmean') r_KL = kldiv_loss(log_probs_policy_2d, probs_pretrained_2d) print(f'KL Divergence: {r_KL.item()}')
    • Update Rule = PPO를 통해서 현재 Batch에서 Reward를 최대화하는 방향으로 학습
      • PPO는 on-policy RL이라서 현재 Prompt-Generated Text pair에 대해서만 학습
        • On-policy RL이라서 생기는 Loss fluctuation issue
          • Error in subsequent estimation — Text 생성이라고 치면 초창기에 생기는 생성 에러부분으로 인해서 생기는 err propagation을 쉽게 잡기 어려움
          • Policy update 자체의 이슈 — PPO를 할 때 기존 policy 대비 너무 멀어지지 않도록 학습하지만 + 동시에 더 잘 하도록 학습되어야 하는, 학습 목표 자체의 충돌
          • On policy라는 Sample Inefficiency — 현재 Batch에 들어온 샘플만 사용하기 때문에 가장 최근에 사용한 (그리고 현재 들어온) 데이터에만 최적화, 이로 인해 겹치지 않는 많이 떨어진 샘플이 들어오면 → Loss 차이가 크게 날 수 있음
          • 이건 onpolicy라서 생기는건 아니지만… RL로 할때 lr 등에 굉장히 민감함.

RRHF 동작 구조 with Code

데이터 샘플

{ "query":" \n\nHuman: What is the correct order to add ingredients when making a taco?\n\nAssistant:", "responses":[ " You should add the meat or vegetables first, then the lettuce, cheese, tomatoes, onions, and other toppings.", " The correct order to add ingredients when making a taco is to start with the tortilla, then add the meat or vegetable filling, followed by the toppings such as cheese, salsa, avocado, sour cream, and lettuce.", " You should add the meat or vegetables first, then the lettuce, cheese, tomatoes, onions, and other toppings. Start with the meat or vegetables, then the lettuce, then the cheese, tomatoes, onions, and other toppings.", " The correct order to add ingredients when making a taco is to start with a corn or wheat tortilla, then add your choice of meat, vegetables, cheese, beans, salsa, guacamole, sour cream, and other toppings.", " The correct order is to add the meat, cheese, and sauce first, and then the vegetables last. You can think of the question this way: \u201cWhat ingredients do you add first when making a taco?\u201d The answer to this is: \u201cThe meat, cheese, and sauce\u201d", " Which kind of taco? Do you want something more traditional or a fusion?" ], "scores":[ -1.427734375, -1.03515625, -1.28125, -1.1923828125, -1.6572265625, -2.400390625 ] }
  • 위와 같이 query, responses, 그리고 각 response에 해당하는 scores가 있다.
 

RRHF Trainer

전체코드

RRHF Trainer, Line-by-Line

아래는 RRHF 공식 코드에 구현된 RRHFTrainer 클래스
  • RRHF Loss를 구하는 부분이 들어있다.
  • 위 전체 코드의 가장 아래부분인 compute_loss 부분부터 살펴보면, 최종 학습을 위한 losssft_lossrrhf_loss 의 합으로 이루어진 것을 볼 수 있다.
    • logits = model(input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask'))[0] logits = F.log_softmax(logits, dim=-1) logit_label = self.gather_logits_labels(logits, inputs.get("labels")) scores = self.get_score(logit_label, inputs.get("labels")) rrhf_loss = self.rrhf_loss(scores, inputs.get("idxs"), inputs.get("scores")) sft_loss = self.sft_loss(logit_label, inputs.get("idxs"), inputs.get("scores")) loss = self.args.rrhf_weight * rrhf_loss + sft_loss
    • 위 부분에서 logits는 Policy Model()에 해당하는, RRHF 학습되고 있는 모델에 주어진 prompt를 입력했을 때 나오는 토큰에 대한 logits value
  • 이때 Loss 계산을 위해서 사용하는 logit_label은 몇 가지 단계를 통해 생성된다.
    • # logit_label = self.gather_logits_labels(logits, inputs.get("labels")) def gather_logits_labels(self, logits, labels): # 만약 아래와 같은 값(batch size = 2)으로 들어왔다면... # logits = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]) # labels = torch.tensor([[1, 2, -100, 0], [3, -100, 0, 1]]) mask = (labels != -100).long() # Step 1 # mask == tensor([[1, 1, 0, 1], [1, 0, 1, 1]]) new_logits = logits.clone() # Step 2 labels[labels == -100] = 0 # Step 3 # labels == tensor([[1, 2, 0, 0], [3, 0, 0, 1]]) output = torch.gather(new_logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) # Step 4 # output == tensor([[0.2, 0.3, 0.1, 0.1], [0.8, 0.5, 0.5, 0.6]]) output = output * mask # Step 5 # output == tensor([[0.2, 0.3, 0.0, 0.1], [0.8, 0.0, 0.5, 0.6]]) return output
      1. label 값이 -100, 즉 loss를 계산하지 않도록 처리된 부분(보통은 Prompt)을 마스킹하고
      1. Tensor in-place 업데이트를 막기 위해 복제하고
      1. 앞서서 label이 -100 인 부분을 0으로 처리한다. ← 4번째 단계에서 logits indexing에 쓰이는데 Negative값은 오류가 나기 때문.
      1. torch.gather 로 label과 합치고
      1. Mask 처리된 부분을 0로 바꿔버린다.
  • RRHF 학습중인 모델 가 학습할 score 는 아래와 같이 계산된다.
    • def get_score(self, logit_label, labels): mask = (labels != -100).float() # Step 1 length = mask.sum(-1) # Step 2 scores = logit_label.sum(-1) / (length ** self.args.length_penalty) # Step 3 return scores
      1. 앞서와 같이 mask 부분을 찾고
      1. mask된 부분을 제외한 토큰 갯수를 찾고
      1. 모델이 생성한 logit_label 과, 정답 labels 각각에 대해 masking된 부분을 제외하고 Length Penalty(만약 length_penalty가 1보다 작으면 짧은 답변 선호, 1보다 크면 긴 답변 선호)를 부여한다.
      1. 이렇게 나온 값은 대략 배치사이즈 크기의 텐서 ex) torch.tensor([-0.5, -0.3, -0.4, -0.8, -0.2, -0.1]) (배치사이즈 6)
  • RRHF Loss 계산하기 🌟
    • def rrhf_loss(self, scores, idxs, rw_scores): # 아래와 같은 예제를 쓴다고 가정 # scores = torch.tensor([-0.5, -0.3, -0.4, -0.8, -0.2, -0.1]) # rw_scores = torch.tensor([-0.84, -0.85, -0.57, -0.62, -0.49, -0.87]) diff = scores.unsqueeze(0) - scores.unsqueeze(-1) # b * b # Step 1 rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1) # b * b # Step 2 aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0] # Step 3 return -diff[aval].sum() # Step 4
      1. diff, rw_diff를 예제로 계산하면…
        1. diff = tensor([[ 0.0000, 0.2000, 0.1000, -0.3000, 0.3000, 0.4000], [-0.2000, 0.0000, -0.1000, -0.5000, 0.1000, 0.2000], [-0.1000, 0.1000, 0.0000, -0.4000, 0.2000, 0.3000], [ 0.3000, 0.5000, 0.4000, 0.0000, 0.6000, 0.7000], [-0.3000, -0.1000, -0.2000, -0.6000, 0.0000, 0.1000], [-0.4000, -0.2000, -0.3000, -0.7000, -0.1000, 0.0000]]) rw_diff = tensor([[ 0.0000, -0.0100, 0.2700, 0.2200, 0.3500, -0.0300], [ 0.0100, 0.0000, 0.2800, 0.2300, 0.3600, -0.0200], [-0.2700, -0.2800, 0.0000, -0.0500, 0.0800, -0.3000], [-0.2200, -0.2300, 0.0500, 0.0000, 0.1300, -0.2500], [-0.3500, -0.3600, -0.0800, -0.1300, 0.0000, -0.3800], [ 0.0300, 0.0200, 0.3000, 0.2500, 0.3800, 0.0000]])
      1. aval을 Bitwise and로 계산하면…
        1. aval = tensor([False, False, False, True, False, False])
      1. 최종적으로 나오는 값은…
        1. -diff[aval].sum() ==> tensor(-2.5000)
  • SFT Loss 계산하기 — 일반적 Finetune과 동일
    • def sft_loss(self, logit_label, idxs, rw_scores): max_idx = torch.argmax(rw_scores) return -logit_label[max_idx].mean()
    • 일반적인 finetune loss와 동일하다고 하지만, 훨씬 심플한 방식
    • rw_scores 값이 가장 높은, 가장 좋은 점수의 예시를 고르고
    • 해당 예시를 파인튜닝중인 모델 에서 나온 logits에 대해서 mean을 취하는 방식.
    • 어쨌든 해당 예시와 최대한 비슷해지도록 모델을 유도하는 Loss.
    • Cross Entropy와 같은 느낌이라기 보다는, 오히려 Penalty에 가까운 접근방식.
 

실제 데이터로, 입력부터 차근차근.

def compute_loss(self, model, inputs, return_outputs=False): if self.args.only_use_provide: inputs['input_ids'] = inputs['input_ids'][-2:] inputs['attention_mask'] = inputs['attention_mask'][-2:] inputs['labels'] = inputs['labels'][-2:] inputs["idxs"] = inputs["idxs"][:,-2:] inputs["scores"] = inputs["scores"][:,-2:] if self.args.only_use_sample: inputs['input_ids'] = inputs['input_ids'][:-2] inputs['attention_mask'] = inputs['attention_mask'][:-2] inputs['labels'] = inputs['labels'][:-2] inputs["idxs"] = inputs["idxs"][:,:-2] inputs["scores"] = inputs["scores"][:,:-2] print(inputs) # 1 logits = model(input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask'))[0] # (batch * cand) * L * V print('logits:', logits) # 2 logits = F.log_softmax(logits, dim=-1) print('logits:', logits) # 3 logit_label = self.gather_logits_labels(logits, inputs.get("labels")) print('logit_label:', logit_label) # 4 scores = self.get_score(logit_label, inputs.get("labels")) print('scores:', scores) # 5 rrhf_loss = self.rrhf_loss(scores, inputs.get("idxs"), inputs.get("scores")) print('rrhf_loss:',rrhf_loss) # 6 sft_loss = self.sft_loss(logit_label, inputs.get("idxs"), inputs.get("scores")) print('sft_loss:',sft_loss) # 7 loss = self.args.rrhf_weight * rrhf_loss + sft_loss print(loss) # 8 return (loss, scores) if return_outputs else loss
 

1. 입력에 들어갈 inputs : Token Encoded된 데이터셋

실제 숫자로 나오는 데이터
{'input_ids': tensor([[ 224, 202, 202, 43, 21601, 29, 13302, 376, 1356, 702, 11818, 7993, 12641, 1035, 4798, 38072, 370, 1183, 5018, 7281, 2297, 20789, 656, 427, 68, 697, 34, 202, 202, 6811, 1324, 2303, 29, 11113, 6084, 4798, 702, 2200, 376, 2078, 224, 995, 1089, 1417, 86, 9712, 15, 15760, 702, 3852, 3144, 88, 645, 15, 484, 486, 72, 450, 15, 1035, 3630, 82, 700, 15, 2006, 12990, 15, 1442, 6625, 1035, 1141, 11839, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 224, 202, 202, 43, 21601, 29, 13302, 376, 1356, 702, 11818, 7993, 12641, 1035, 4798, 38072, 370, 1183, 5018, 7281, 2297, 20789, 656, 427, 68, 697, 34, 202, 202, 6811, 1324, 2303, 29, 2398, 11818, 7993, 12641, 1035, 4798, 38072, 370, 1183, 5018, 7281, 2297, 20789, 656, 427, 68, 697, 1356, 1035, 8088, 2452, 702, 427, 1087, 17987, 15, 15760, 4798, 702, 2200, 376, 2078, 224, 995, 1089, 1417, 13056, 559, 15, 8884, 600, 2425, 702, 1035, 1141, 11839, 20451, 2749, 484, 486, 72, 450, 15, 23816, 2630, 15, 23769, 1485, 19585, 15, 683, 1711, 484, 2856, 15, 1442, 3852, 3144, 88, 645, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], ...], device='cuda:0'), 'attention_mask': tensor([[ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], [ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], ...], device='cuda:0'), 'labels': tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 11113, 6084, 4798, 702, 2200, 376, 2078, 224, 995, 1089, 1417, 86, 9712, 15, 15760, 702, 3852, 3144, 88, 645, 15, 484, 486, 72, 450, 15, 1035, 3630, 82, 700, 15, 2006, 12990, 15, 1442, 6625, 1035, 1141, 11839, 17, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 2398, 11818, 7993, 12641, 1035, 4798, 38072, 370, 1183, 5018, 7281, 2297, 20789, 656, 427, 68, 697, 1356, 1035, 8088, 2452, 702, 427, 1087, 17987, 15, 15760, 4798, 702, 2200, 376, 2078, 224, 995, 1089, 1417, 13056, 559, 15, 8884, 600, 2425, 702, 1035, 1141, 11839, 20451, 2749, 484, 486, 72, 450, 15, 23816, 2630, 15, 23769, 1485, 19585, 15, 683, 1711, 484, 2856, 15, 1442, 3852, 3144, 88, 645, 17, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], ...], device='cuda:0'), 'idxs': tensor([[0, 0, ...]], device='cuda:0'), 'scores': tensor([[-1.4277, -1.0352, ...]], device='cuda:0')}
 
  • input_ids : tokenizer로 인코딩된 토큰.
  • attention_mask : Attention 처리에 마스킹 유무.
    • padding 토큰을 제외한 모든 부분에는 (당연히) 모델이 해당 토큰들을 봐야 하니까 True.
    • padding 부분은 보면 안되니까 False.
  • labels : Loss 계산에 포함할지 유무가 포함된 + 토큰이 인코딩된 상태
    • prompt 부분을 제외하고 ‘정답’으로서 생성되어야 하는(= 학습 데이터의 response) 부분을 남긴다.
    • 나머지 부분은 Loss 계산에 포함하지 않게 하기 위해 -100 으로 마스킹
  • scores : 학습 데이터 json에 포함된 Score → 높은(0에 가까운) 스코어일수록 사람이 선호 🆙
 

2. 모델에 입력해 나오는 Logits

실제 숫자로 나오는 데이터
logits: tensor([[[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 10.0625, 3.7188, -4.0625, ..., 8.3750, 5.5938, 6.2812], [ 10.0625, 3.6562, -4.0312, ..., 8.3125, 5.6250, 6.1250], [ 10.1875, 3.6406, -3.9844, ..., 8.4375, 5.7500, 6.2188]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 9.8750, 3.8594, -4.2500, ..., 8.8125, 6.0312, 6.8438], [ 9.9375, 3.8594, -4.3750, ..., 8.6250, 5.8125, 6.6250], [ 10.0000, 3.7969, -4.4688, ..., 8.4375, 5.7500, 6.4062]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 10.1875, 3.3281, -4.3125, ..., 8.2500, 5.5312, 6.1562], [ 10.1875, 3.4688, -4.3125, ..., 8.2500, 5.5312, 6.0625], [ 10.3750, 3.5469, -4.2812, ..., 8.2500, 5.5625, 6.0938]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 9.9375, 3.8438, -3.9688, ..., 8.5625, 5.5938, 6.5938], [ 10.0000, 3.8906, -3.9375, ..., 8.7500, 5.7188, 6.5938], [ 10.0625, 3.8594, -3.9844, ..., 8.7500, 5.7188, 6.5938]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 6.1562, -3.5156, -12.2500, ..., -1.4766, -2.9219, -4.3750], [ 10.7500, -2.1562, -9.0000, ..., -0.7461, 0.5195, -1.6641], [ 9.8125, 3.0156, -4.8438, ..., 7.5625, 5.6250, 5.7812]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 9.6875, 3.4688, -3.9688, ..., 8.2500, 5.3125, 6.5625], [ 9.6875, 3.2969, -4.0000, ..., 8.1250, 5.3125, 6.4062], [ 9.6875, 3.2031, -4.1250, ..., 8.0625, 5.3125, 6.3438]]], device='cuda:0', grad_fn=<ToCopyBackward0>)
 
  • 6개의 후보 텍스트가 있어서 → 각 텍스트 별로 Logits 값이 나온다.
 

3. Log Softmax된 logits

실제 숫자로 나오는 데이터
logits: tensor([[[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -7.9139, -14.2576, -22.0389, ..., -9.6014, -12.3826, -11.6951], [ -7.8303, -14.2365, -21.9240, ..., -9.5803, -12.2678, -11.7678], [ -7.7656, -14.3125, -21.9375, ..., -9.5156, -12.2031, -11.7344]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -8.1194, -14.1351, -22.2444, ..., -9.1819, -11.9632, -11.1507], [ -7.9777, -14.0558, -22.2902, ..., -9.2902, -12.1027, -11.2902], [ -7.8487, -14.0518, -22.3174, ..., -9.4112, -12.0987, -11.4424]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -7.6713, -14.5307, -22.1713, ..., -9.6088, -12.3276, -11.7026], [ -7.6189, -14.3377, -22.1189, ..., -9.5564, -12.2752, -11.7439], [ -7.4675, -14.2957, -22.1238, ..., -9.5925, -12.2800, -11.7488]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -8.0349, -14.1286, -21.9411, ..., -9.4099, -12.3786, -11.3786], [ -8.0179, -14.1272, -21.9554, ..., -9.2679, -12.2991, -11.4241], [ -7.9908, -14.1939, -22.0377, ..., -9.3033, -12.3345, -11.4595]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [-10.5943, -20.2662, -29.0006, ..., -18.2271, -19.6725, -21.1256], [ -5.4727, -18.3790, -25.2227, ..., -16.9688, -15.7032, -17.8868], [ -7.7776, -14.5745, -22.4338, ..., -10.0276, -11.9651, -11.8088]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -8.1588, -14.3775, -21.8150, ..., -9.5963, -12.5338, -11.2838], [ -8.0131, -14.4038, -21.7006, ..., -9.5756, -12.3881, -11.2944], [ -7.9537, -14.4381, -21.7662, ..., -9.5787, -12.3287, -11.2974]]], device='cuda:0', grad_fn=<LogSoftmaxBackward0>)
 

4. Logit Label : 학습용 Logits

실제 숫자로 나오는 데이터
logit_label: tensor([[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -5.2047e+00, -3.9633e+00, -3.7154e+00, -1.8586e+00, -5.5417e+00, -1.0325e-01, -4.3098e+00, -3.9197e+00, -5.6314e-02, -9.2349e-01, -1.0923e-02, -2.9228e-01, -4.9035e+00, -2.0921e+00, -2.3459e+00, -2.2275e+00, -6.2254e+00, -7.2684e-01, -8.3274e-02, -2.0360e-02, -2.3622e+00, -3.6149e+00, -2.3155e+00, -9.3095e-02, -2.5647e-03, -3.4374e-01, -2.2443e+00, -2.8266e-01, -6.7356e-03, -1.5269e+00, -4.5461e-01, -1.2591e+00, -8.9856e-01, -2.6449e-01, -2.6092e+00, -4.6270e+00, -4.6410e+00, -5.7265e-01, -1.8393e-03, -2.1693e+00, -7.1166e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00], [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -3.8922e+00, -2.9277e+00, -4.4169e-03, -3.0753e-01, -8.3347e-01, -5.7375e-01, -7.9791e-01, -6.8701e-04, -3.7425e-03, -6.5682e-05, -1.5995e-01, -1.4910e-01, -1.9286e-04, -8.9794e-02, -2.7459e-01, -1.0095e-02, -3.0750e-02, -3.6543e+00, -1.4377e+00, -5.2084e+00, -1.6428e+00, -2.2102e+00, -4.0789e+00, -7.7777e+00, -7.5434e-02, -3.8790e+00, -3.1241e+00, -1.8450e+00, -1.2490e+00, -4.8884e+00, -7.3326e-02, -4.0398e+00, -3.6256e+00, -7.0264e-02, -7.8034e-01, -1.1456e-02, -9.6471e+00, -1.0769e-01, -3.0114e+00, -3.9529e+00, -1.6534e-02, -7.6546e-03, -1.1023e+00, -5.5995e+00, -2.8106e+00, -7.8820e-03, -5.8266e+00, -9.8396e-03, -3.8590e+00, -1.4775e+00, -2.3561e-02, -1.1159e-03, -1.5604e+00, -3.8725e+00, -2.6680e+00, -6.5971e-01, -5.5761e+00, -8.8965e-03, -2.6756e-02, -7.2358e-01, -3.7405e+00, -7.8345e-01, -1.6047e-01, -6.9308e-04, -8.4214e-01, -2.6555e+00, -4.9179e+00, -2.9730e-01, -1.6792e-02, -6.2096e-03, -6.8873e-01, -6.0952e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00], ...], device='cuda:0', grad_fn=<MulBackward0>)
 
  • Prompt 부분과 Pad 부분은 0으로 처리된다.
  • 나머지 부분(생성된 토큰 부분)에 대해서는 gather_logits_labels 함수를 통해 Logit Label을 계산
    • 해당 Sequence의 각 Token에 대해서, 해당 토큰에 대해 모델이 갖는 확률%값 = Logit값
 

5. Score : Length Penalty가 부여된, 모델이 생성한 토큰의 Logits에 따른 생성 점수

  • 실제 숫자로 나오는 데이터
    • scores: tensor([-0.7102, -1.0950, -1.0060, -1.0923, -1.4811, -0.5766], device='cuda:0', grad_fn=<DivBackward0>)
  • 아래 코드를 통해서, 각 답변의 Sequence Length Penalty가 부여된 후의 Score
    • 단, masking된 Prompt와 Pad 부분은 Length Penalty의 length 계산 길이에 포함되지 않는다.
      • 이런식으로 1 된것만 포함.
        tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
      def get_score(self, logit_label, labels): mask = (labels != -100).float() length = mask.sum(-1) ## <- Masking 된 부분은 Length Penalty에 포함 X scores = logit_label.sum(-1) / (length ** self.args.length_penalty) return scores
 

6. RRHF Loss: Ranking에 기반한 Loss

  • 실제 숫자로 나오는 데이터
    • rrhf_loss: tensor(3.7421, device='cuda:0', grad_fn=<NegBackward0>)
  • 함수
    • def rrhf_loss(self, scores, idxs, rw_scores): print('scores: ', scores) print('idxs: ', idxs) print('rw_scores: ',rw_scores) diff = scores.unsqueeze(0) - scores.unsqueeze(-1) # b * b print('diff', diff) rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1) # b * b print('rw_diff', rw_diff) aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0] print('aval', aval) r = -diff[aval].sum() print('r', r) return r
    • diff = scores.unsqueeze(0) - scores.unsqueeze(-1) 과정 자체가 순서를 매기는 과정
      • 가로 세로로 자기자신을 배치하고 → 순위가 맞는지를 검증
      • rw_diff 는 정답의 순서를 의미
      • 즉, 이제 두개의 순서가 틀린(0이하 0이상으로 갈리는 부분) 경우가 틀린 Ranking임
        • 그 부분을 Bitwise and로 추출(하나는 >0, 하나는 <0임)
      • 해당 부분(loss를 계산할 부분)을.. 기존 diff, 즉 Model 에 대해서 생긴 Loss로 대응시켜서
      • 해당 부분만 sum하고 음수로 치환해주면 끝(최종 loss는 양수니까)
  • 예시 수치
    • # score -- \pi 모델이 생성한 Score = \pi 모델이 보고 있는 랭킹 scores: tensor([-1.8340, -1.5368, -2.0713, -1.2388, -2.1106, -0.4438], device='cuda:0', grad_fn=<DivBackward0>) idxs: tensor([[0, 0, 0, 0, 0, 0]], device='cuda:0') # rw_scores -- RLHF, 사용자들이 준 랭킹 rw_scores: tensor([[-0.6064, -0.5854, -0.5791, -0.7280, -2.1777, -2.5430]], device='cuda:0') diff tensor([[ 0.0000, 0.2972, -0.2374, 0.5952, -0.2766, 1.3902], [-0.2972, 0.0000, -0.5345, 0.2980, -0.5738, 1.0930], [ 0.2374, 0.5345, 0.0000, 0.8326, -0.0393, 1.6276], [-0.5952, -0.2980, -0.8326, 0.0000, -0.8718, 0.7950], [ 0.2766, 0.5738, 0.0393, 0.8718, 0.0000, 1.6668], [-1.3902, -1.0930, -1.6276, -0.7950, -1.6668, 0.0000]], device='cuda:0', grad_fn=<SubBackward0>) rw_diff tensor([[[ 0.0000, 0.0210, 0.0273, -0.1216, -1.5713, -1.9365], [-0.0210, 0.0000, 0.0063, -0.1426, -1.5923, -1.9575], [-0.0273, -0.0063, 0.0000, -0.1489, -1.5986, -1.9639], [ 0.1216, 0.1426, 0.1489, 0.0000, -1.4497, -1.8149], [ 1.5713, 1.5923, 1.5986, 1.4497, 0.0000, -0.3652], [ 1.9365, 1.9575, 1.9639, 1.8149, 0.3652, 0.0000]]], device='cuda:0') ## 이땐, `rw_diff > 0` 이면서 동시에 `diff < 0`인 경우를 True로 간주. aval tensor([[False, False, True, False, False, False], [False, False, True, False, False, False], [False, False, False, False, False, False], [ True, True, True, False, False, False], [False, False, False, False, False, False], [ True, True, True, True, True, False]], device='cuda:0') ## 위 Matrix에서 `True`로 나오는 부분에 대해서 -> 기존 `diff` 부분만 추려서 -> Loss로 계산 r tensor(9.0703, device='cuda:0', grad_fn=<NegBackward0>)
 

7. SFT Loss: 점수가 가장 높은 예제와 ↔ 모델의 Logits과의 Loss

  • 실제 숫자로 나오는 데이터
    • sft_loss: tensor(1.0950, device='cuda:0', grad_fn=<NegBackward0>)
    • argmax로 가장 높은 rw_score 예제를 닮도록 학습
 

8. Loss: RRHF Loss + SFT Loss

  • 실제 숫자로 나오는 데이터
    • 3.7421 + 1.0950 = 4.8371 단순 합
    • tensor(4.8371, device='cuda:0', grad_fn=<AddBackward0>)
  • 혹은 아래와 같이 rrhf_weight 를 조절할 수도 있음
    • loss = self.args.rrhf_weight * rrhf_loss + sft_loss
 

나머지 학습은?

  • 그냥 HF 일반 모델 학습 하듯이 학습하면 된다.
  • 5.8B~7B모델은 A100 80G 1대(Adafactor optim기준)로 학습 가능
  • 주의: Sequence Length를 너무 작게 잡으면 → prompt 부분만으로 seq len이 채워져서 loss = 0으로 학습된다.(= 학습 데이터를 날리는 셈)
 

Reference