ArXivAbstractGithub RepoRRHF?기존의 RLHF(PPO)?RRHF 동작 구조 with Code데이터 샘플RRHF Trainer전체코드RRHF Trainer, Line-by-Line실제 데이터로, 입력부터 차근차근.1. 입력에 들어갈 inputs : Token Encoded된 데이터셋2. 모델에 입력해 나오는 Logits3. Log Softmax된 logits4. Logit Label : 학습용 Logits5. Score : Length Penalty가 부여된, 모델이 생성한 토큰의 Logits에 따른 생성 점수6. RRHF Loss: Ranking에 기반한 Loss7. SFT Loss: 점수가 가장 높은 예제와 ↔ 모델의 Logits과의 Loss8. Loss: RRHF Loss + SFT Loss나머지 학습은?Reference
ArXiv
Abstract
Reinforcement Learning from Human Feedback (RLHF) facilitates the alignment of large language models with human preferences, significantly enhancing the quality of interactions between humans and these models. InstructGPT implements RLHF through several stages, including Supervised Fine-Tuning (SFT), reward model training, and Proximal Policy Optimization (PPO). PPO, however, is sensitive to hyperparameters and requires a minimum of four models in its standard implementation, which makes it hard to train. In contrast, we propose a novel learning paradigm called RRHF, which scores responses generated by different sampling policies and learns to align them with human preferences through ranking loss. RRHF can efficiently align language model output probabilities with human preferences as robust as fine-tuning and it only needs 1 to 2 models during tuning. In addition, RRHF can be considered an extension of SFT and reward models while being simpler than PPO in terms of coding, model counts, and hyperparameters. The entire alignment process can be accomplished within a single RRHF training session. We evaluate RRHF using LLaMA and Alpaca on Helpful and Harmless data, demonstrating performance comparable to PPO.
→ PPO 할 때는 Reward용 모델을 새로 만들어야 하지만 RRHF에서는 Ranking Loss를 대신 사용해서 PPO처럼 복잡하지 않고서도 RLHF같은 효과를 줄 수 있다고 함
- Wombat이라는 LLAMA 기반 챗봇도 함께 릴리즈
Github Repo
RRHF?
RRHF = Rank Response to align Human Feedback
- 기존의 OpenAI, TRL등에서 PPO를 통해 Reward Model / Reference Model / Value Model 구조를 추가로 만들어 Model의 Response에 대해 강화학습할 수 있는 환경을 만들어주었지만, 단순 Language Model 외에도 Proximator로 동작할 타 모델을 상당히 많이 만들어줘야 하는 문제가 있다.
기존의 RLHF(PPO)?
→ 너무 복잡함. Scaling 이슈가 있음.
- 위와 같이 PPO에서는 주어진 Prompts → LM이 생성 → 사람이 검수 → Scoring
- Score ~ 생성된 결과물 Pair를 기준으로 → P(Score | 주어진 생성물)를 만드는 Reward Model을 학습시킨다.
- 이때, 이걸 강화학습에 대응시키면..
- Policy = “Initial LM이 Prompt를 받아서 → Text 생성하는 것”
- Action Space = LM이 가진 Vocab의 수(=생성할 수 있는 토큰..)
- Observation Space = Input token sequence = 토큰(vocab)^문장길이(seq len)
- Reward Function = Preference Model + Policy Shift constraint
- Preference model은 단순 값 하나(Scalar value)를 반환
- Policy Shift constraint는 KL divergence를 Scaling해서 제한
- Original LM이 생성한 텍스트와 너무 멀어지지 않도록 제한하는 KLD
- 단, 이때 RL Policy Model이 생성한 Text와 Prob을 쓰는게 아니라..
- RL Policy가 text를 생성하고
- 해당 text를 기존 base LM에다가 다시 넣어서 → 그 Prob을 가져다 쓴다.
- Update Rule = PPO를 통해서 현재 Batch에서 Reward를 최대화하는 방향으로 학습
- PPO는 on-policy RL이라서 현재 Prompt-Generated Text pair에 대해서만 학습
- Error in subsequent estimation — Text 생성이라고 치면 초창기에 생기는 생성 에러부분으로 인해서 생기는 err propagation을 쉽게 잡기 어려움
- Policy update 자체의 이슈 — PPO를 할 때 기존 policy 대비 너무 멀어지지 않도록 학습하지만 + 동시에 더 잘 하도록 학습되어야 하는, 학습 목표 자체의 충돌
- On policy라는 Sample Inefficiency — 현재 Batch에 들어온 샘플만 사용하기 때문에 가장 최근에 사용한 (그리고 현재 들어온) 데이터에만 최적화, 이로 인해 겹치지 않는 많이 떨어진 샘플이 들어오면 → Loss 차이가 크게 날 수 있음
- 이건 onpolicy라서 생기는건 아니지만… RL로 할때 lr 등에 굉장히 민감함.
KL Divergence를 Text 생성에서 Huggingface 모델로 간단히 계산한 예제 코드
from transformers import AutoTokenizer, AutoModelForCausalLM from torch.distributions import Categorical import torch.nn.functional as F import torch tokenizer = AutoTokenizer.from_pretrained("beomi/KoRWKV-6B") # KoRWKV-6B를 기본 PLM으로 간주 model_pretrained = AutoModelForCausalLM.from_pretrained("beomi/KoRWKV-6B") # KoAlpaca로 파튜한 모델을 Policy등으로 학습된 Model이라고 간주 model_policy = AutoModelForCausalLM.from_pretrained("beomi/KoAlpaca-KoRWKV-6B") def generate_text(model, prompt, max_length=50): input_ids = tokenizer.encode(prompt, return_tensors='pt') output = model.generate(input_ids, max_length=max_length, temperature=1, do_sample=True) return tokenizer.decode(output[0]) prompt = "안녕하세요, " # 텍스트 생성하고.. generated_pretrained = generate_text(model_pretrained, prompt) generated_policy = generate_text(model_policy, prompt) # KLD에서는 하나는 Log prob이고 하나는 일반 prob으로 계산한다. # KL(P || Q) = Σ P(x) * log(P(x) / Q(x)) def get_log_probs(model, text): input_ids = tokenizer.encode(text, return_tensors='pt') with torch.no_grad(): outputs = model(input_ids) logits = outputs.logits return F.log_softmax(logits[:, :-1], dim=-1) def get_probs(model, text): input_ids = tokenizer.encode(text, return_tensors='pt') with torch.no_grad(): outputs = model(input_ids) logits = outputs.logits return F.softmax(logits[:, :-1], dim=-1) log_probs_policy = get_log_probs(model_policy, generated_policy) # Note that we're using the same text as for the policy model probs_pretrained = get_probs(model_pretrained, generated_policy) # P(x), Q(x), log(P(x)), log(Q(x)) log_probs_policy = get_log_probs(model_policy, generated_policy) probs_pretrained = get_probs(model_pretrained, generated_pretrained) # Flatten the tensor into 2D for KL Div calculation log_probs_policy_2d = log_probs_policy.view(-1, log_probs_policy.size(-1)) probs_pretrained_2d = probs_pretrained.view(-1, probs_pretrained.size(-1)) # Calculate KL Divergence kldiv_loss = torch.nn.KLDivLoss(reduction='batchmean') r_KL = kldiv_loss(log_probs_policy_2d, probs_pretrained_2d) print(f'KL Divergence: {r_KL.item()}')
On-policy RL이라서 생기는 Loss fluctuation issue
RRHF 동작 구조 with Code
데이터 샘플
{ "query":" \n\nHuman: What is the correct order to add ingredients when making a taco?\n\nAssistant:", "responses":[ " You should add the meat or vegetables first, then the lettuce, cheese, tomatoes, onions, and other toppings.", " The correct order to add ingredients when making a taco is to start with the tortilla, then add the meat or vegetable filling, followed by the toppings such as cheese, salsa, avocado, sour cream, and lettuce.", " You should add the meat or vegetables first, then the lettuce, cheese, tomatoes, onions, and other toppings. Start with the meat or vegetables, then the lettuce, then the cheese, tomatoes, onions, and other toppings.", " The correct order to add ingredients when making a taco is to start with a corn or wheat tortilla, then add your choice of meat, vegetables, cheese, beans, salsa, guacamole, sour cream, and other toppings.", " The correct order is to add the meat, cheese, and sauce first, and then the vegetables last. You can think of the question this way: \u201cWhat ingredients do you add first when making a taco?\u201d The answer to this is: \u201cThe meat, cheese, and sauce\u201d", " Which kind of taco? Do you want something more traditional or a fusion?" ], "scores":[ -1.427734375, -1.03515625, -1.28125, -1.1923828125, -1.6572265625, -2.400390625 ] }
- 위와 같이
query
,responses
, 그리고 각 response에 해당하는scores
가 있다.
RRHF Trainer
전체코드
RRHF Trainer, Line-by-Line
아래는 RRHF 공식 코드에 구현된
RRHFTrainer
클래스- RRHF Loss를 구하는 부분이 들어있다.
- 위 전체 코드의 가장 아래부분인
compute_loss
부분부터 살펴보면, 최종 학습을 위한loss
가sft_loss
와rrhf_loss
의 합으로 이루어진 것을 볼 수 있다.
logits = model(input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask'))[0] logits = F.log_softmax(logits, dim=-1) logit_label = self.gather_logits_labels(logits, inputs.get("labels")) scores = self.get_score(logit_label, inputs.get("labels")) rrhf_loss = self.rrhf_loss(scores, inputs.get("idxs"), inputs.get("scores")) sft_loss = self.sft_loss(logit_label, inputs.get("idxs"), inputs.get("scores")) loss = self.args.rrhf_weight * rrhf_loss + sft_loss
logits
는 Policy Model()에 해당하는, RRHF 학습되고 있는 모델에 주어진 prompt를 입력했을 때 나오는 토큰에 대한 logits value- 이때 Loss 계산을 위해서 사용하는
logit_label
은 몇 가지 단계를 통해 생성된다.
# logit_label = self.gather_logits_labels(logits, inputs.get("labels")) def gather_logits_labels(self, logits, labels): # 만약 아래와 같은 값(batch size = 2)으로 들어왔다면... # logits = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]) # labels = torch.tensor([[1, 2, -100, 0], [3, -100, 0, 1]]) mask = (labels != -100).long() # Step 1 # mask == tensor([[1, 1, 0, 1], [1, 0, 1, 1]]) new_logits = logits.clone() # Step 2 labels[labels == -100] = 0 # Step 3 # labels == tensor([[1, 2, 0, 0], [3, 0, 0, 1]]) output = torch.gather(new_logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) # Step 4 # output == tensor([[0.2, 0.3, 0.1, 0.1], [0.8, 0.5, 0.5, 0.6]]) output = output * mask # Step 5 # output == tensor([[0.2, 0.3, 0.0, 0.1], [0.8, 0.0, 0.5, 0.6]]) return output
- label 값이
-100
, 즉 loss를 계산하지 않도록 처리된 부분(보통은 Prompt)을 마스킹하고
- Tensor in-place 업데이트를 막기 위해 복제하고
- 앞서서 label이
-100
인 부분을 0으로 처리한다. ← 4번째 단계에서 logits indexing에 쓰이는데 Negative값은 오류가 나기 때문.
torch.gather
로 label과 합치고
- Mask 처리된 부분을 0로 바꿔버린다.
- RRHF 학습중인 모델 가 학습할
score
는 아래와 같이 계산된다.
def get_score(self, logit_label, labels): mask = (labels != -100).float() # Step 1 length = mask.sum(-1) # Step 2 scores = logit_label.sum(-1) / (length ** self.args.length_penalty) # Step 3 return scores
- 앞서와 같이 mask 부분을 찾고
- mask된 부분을 제외한 토큰 갯수를 찾고
- 모델이 생성한
logit_label
과, 정답labels
각각에 대해 masking된 부분을 제외하고 Length Penalty(만약length_penalty
가 1보다 작으면 짧은 답변 선호, 1보다 크면 긴 답변 선호)를 부여한다.
- 이렇게 나온 값은 대략
배치사이즈
크기의 텐서 ex)torch.tensor([-0.5, -0.3, -0.4, -0.8, -0.2, -0.1])
(배치사이즈 6)
- RRHF Loss 계산하기 🌟
def rrhf_loss(self, scores, idxs, rw_scores): # 아래와 같은 예제를 쓴다고 가정 # scores = torch.tensor([-0.5, -0.3, -0.4, -0.8, -0.2, -0.1]) # rw_scores = torch.tensor([-0.84, -0.85, -0.57, -0.62, -0.49, -0.87]) diff = scores.unsqueeze(0) - scores.unsqueeze(-1) # b * b # Step 1 rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1) # b * b # Step 2 aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0] # Step 3 return -diff[aval].sum() # Step 4
- diff, rw_diff를 예제로 계산하면…
diff = tensor([[ 0.0000, 0.2000, 0.1000, -0.3000, 0.3000, 0.4000], [-0.2000, 0.0000, -0.1000, -0.5000, 0.1000, 0.2000], [-0.1000, 0.1000, 0.0000, -0.4000, 0.2000, 0.3000], [ 0.3000, 0.5000, 0.4000, 0.0000, 0.6000, 0.7000], [-0.3000, -0.1000, -0.2000, -0.6000, 0.0000, 0.1000], [-0.4000, -0.2000, -0.3000, -0.7000, -0.1000, 0.0000]]) rw_diff = tensor([[ 0.0000, -0.0100, 0.2700, 0.2200, 0.3500, -0.0300], [ 0.0100, 0.0000, 0.2800, 0.2300, 0.3600, -0.0200], [-0.2700, -0.2800, 0.0000, -0.0500, 0.0800, -0.3000], [-0.2200, -0.2300, 0.0500, 0.0000, 0.1300, -0.2500], [-0.3500, -0.3600, -0.0800, -0.1300, 0.0000, -0.3800], [ 0.0300, 0.0200, 0.3000, 0.2500, 0.3800, 0.0000]])
- aval을 Bitwise and로 계산하면…
aval = tensor([False, False, False, True, False, False])
- 최종적으로 나오는 값은…
-diff[aval].sum() ==> tensor(-2.5000)
- SFT Loss 계산하기 — 일반적 Finetune과 동일
def sft_loss(self, logit_label, idxs, rw_scores): max_idx = torch.argmax(rw_scores) return -logit_label[max_idx].mean()
rw_scores
값이 가장 높은, 가장 좋은 점수의 예시를 고르고실제 데이터로, 입력부터 차근차근.
def compute_loss(self, model, inputs, return_outputs=False): if self.args.only_use_provide: inputs['input_ids'] = inputs['input_ids'][-2:] inputs['attention_mask'] = inputs['attention_mask'][-2:] inputs['labels'] = inputs['labels'][-2:] inputs["idxs"] = inputs["idxs"][:,-2:] inputs["scores"] = inputs["scores"][:,-2:] if self.args.only_use_sample: inputs['input_ids'] = inputs['input_ids'][:-2] inputs['attention_mask'] = inputs['attention_mask'][:-2] inputs['labels'] = inputs['labels'][:-2] inputs["idxs"] = inputs["idxs"][:,:-2] inputs["scores"] = inputs["scores"][:,:-2] print(inputs) # 1 logits = model(input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask'))[0] # (batch * cand) * L * V print('logits:', logits) # 2 logits = F.log_softmax(logits, dim=-1) print('logits:', logits) # 3 logit_label = self.gather_logits_labels(logits, inputs.get("labels")) print('logit_label:', logit_label) # 4 scores = self.get_score(logit_label, inputs.get("labels")) print('scores:', scores) # 5 rrhf_loss = self.rrhf_loss(scores, inputs.get("idxs"), inputs.get("scores")) print('rrhf_loss:',rrhf_loss) # 6 sft_loss = self.sft_loss(logit_label, inputs.get("idxs"), inputs.get("scores")) print('sft_loss:',sft_loss) # 7 loss = self.args.rrhf_weight * rrhf_loss + sft_loss print(loss) # 8 return (loss, scores) if return_outputs else loss
1. 입력에 들어갈 inputs
: Token Encoded된 데이터셋
실제 숫자로 나오는 데이터
{'input_ids': tensor([[ 224, 202, 202, 43, 21601, 29, 13302, 376, 1356, 702, 11818, 7993, 12641, 1035, 4798, 38072, 370, 1183, 5018, 7281, 2297, 20789, 656, 427, 68, 697, 34, 202, 202, 6811, 1324, 2303, 29, 11113, 6084, 4798, 702, 2200, 376, 2078, 224, 995, 1089, 1417, 86, 9712, 15, 15760, 702, 3852, 3144, 88, 645, 15, 484, 486, 72, 450, 15, 1035, 3630, 82, 700, 15, 2006, 12990, 15, 1442, 6625, 1035, 1141, 11839, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 224, 202, 202, 43, 21601, 29, 13302, 376, 1356, 702, 11818, 7993, 12641, 1035, 4798, 38072, 370, 1183, 5018, 7281, 2297, 20789, 656, 427, 68, 697, 34, 202, 202, 6811, 1324, 2303, 29, 2398, 11818, 7993, 12641, 1035, 4798, 38072, 370, 1183, 5018, 7281, 2297, 20789, 656, 427, 68, 697, 1356, 1035, 8088, 2452, 702, 427, 1087, 17987, 15, 15760, 4798, 702, 2200, 376, 2078, 224, 995, 1089, 1417, 13056, 559, 15, 8884, 600, 2425, 702, 1035, 1141, 11839, 20451, 2749, 484, 486, 72, 450, 15, 23816, 2630, 15, 23769, 1485, 19585, 15, 683, 1711, 484, 2856, 15, 1442, 3852, 3144, 88, 645, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], ...], device='cuda:0'), 'attention_mask': tensor([[ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], [ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], ...], device='cuda:0'), 'labels': tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 11113, 6084, 4798, 702, 2200, 376, 2078, 224, 995, 1089, 1417, 86, 9712, 15, 15760, 702, 3852, 3144, 88, 645, 15, 484, 486, 72, 450, 15, 1035, 3630, 82, 700, 15, 2006, 12990, 15, 1442, 6625, 1035, 1141, 11839, 17, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 2398, 11818, 7993, 12641, 1035, 4798, 38072, 370, 1183, 5018, 7281, 2297, 20789, 656, 427, 68, 697, 1356, 1035, 8088, 2452, 702, 427, 1087, 17987, 15, 15760, 4798, 702, 2200, 376, 2078, 224, 995, 1089, 1417, 13056, 559, 15, 8884, 600, 2425, 702, 1035, 1141, 11839, 20451, 2749, 484, 486, 72, 450, 15, 23816, 2630, 15, 23769, 1485, 19585, 15, 683, 1711, 484, 2856, 15, 1442, 3852, 3144, 88, 645, 17, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], ...], device='cuda:0'), 'idxs': tensor([[0, 0, ...]], device='cuda:0'), 'scores': tensor([[-1.4277, -1.0352, ...]], device='cuda:0')}
input_ids
: tokenizer로 인코딩된 토큰.
attention_mask
: Attention 처리에 마스킹 유무.- padding 토큰을 제외한 모든 부분에는 (당연히) 모델이 해당 토큰들을 봐야 하니까 True.
- padding 부분은 보면 안되니까 False.
labels
: Loss 계산에 포함할지 유무가 포함된 + 토큰이 인코딩된 상태- prompt 부분을 제외하고 ‘정답’으로서 생성되어야 하는(= 학습 데이터의
response
) 부분을 남긴다. - 나머지 부분은 Loss 계산에 포함하지 않게 하기 위해
-100
으로 마스킹
scores
: 학습 데이터 json에 포함된 Score → 높은(0에 가까운) 스코어일수록 사람이 선호 🆙
2. 모델에 입력해 나오는 Logits
실제 숫자로 나오는 데이터
logits: tensor([[[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 10.0625, 3.7188, -4.0625, ..., 8.3750, 5.5938, 6.2812], [ 10.0625, 3.6562, -4.0312, ..., 8.3125, 5.6250, 6.1250], [ 10.1875, 3.6406, -3.9844, ..., 8.4375, 5.7500, 6.2188]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 9.8750, 3.8594, -4.2500, ..., 8.8125, 6.0312, 6.8438], [ 9.9375, 3.8594, -4.3750, ..., 8.6250, 5.8125, 6.6250], [ 10.0000, 3.7969, -4.4688, ..., 8.4375, 5.7500, 6.4062]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 10.1875, 3.3281, -4.3125, ..., 8.2500, 5.5312, 6.1562], [ 10.1875, 3.4688, -4.3125, ..., 8.2500, 5.5312, 6.0625], [ 10.3750, 3.5469, -4.2812, ..., 8.2500, 5.5625, 6.0938]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 9.9375, 3.8438, -3.9688, ..., 8.5625, 5.5938, 6.5938], [ 10.0000, 3.8906, -3.9375, ..., 8.7500, 5.7188, 6.5938], [ 10.0625, 3.8594, -3.9844, ..., 8.7500, 5.7188, 6.5938]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 6.1562, -3.5156, -12.2500, ..., -1.4766, -2.9219, -4.3750], [ 10.7500, -2.1562, -9.0000, ..., -0.7461, 0.5195, -1.6641], [ 9.8125, 3.0156, -4.8438, ..., 7.5625, 5.6250, 5.7812]], [[ 9.3750, 1.8516, -3.6562, ..., 2.5000, 1.9297, 1.9375], [ 10.5625, 1.2344, -4.1562, ..., 4.2188, 2.5156, 4.2500], [ 12.0625, -0.2393, -5.1875, ..., 4.3125, 2.5156, 2.1250], ..., [ 9.6875, 3.4688, -3.9688, ..., 8.2500, 5.3125, 6.5625], [ 9.6875, 3.2969, -4.0000, ..., 8.1250, 5.3125, 6.4062], [ 9.6875, 3.2031, -4.1250, ..., 8.0625, 5.3125, 6.3438]]], device='cuda:0', grad_fn=<ToCopyBackward0>)
- 6개의 후보 텍스트가 있어서 → 각 텍스트 별로 Logits 값이 나온다.
3. Log Softmax된 logits
실제 숫자로 나오는 데이터
logits: tensor([[[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -7.9139, -14.2576, -22.0389, ..., -9.6014, -12.3826, -11.6951], [ -7.8303, -14.2365, -21.9240, ..., -9.5803, -12.2678, -11.7678], [ -7.7656, -14.3125, -21.9375, ..., -9.5156, -12.2031, -11.7344]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -8.1194, -14.1351, -22.2444, ..., -9.1819, -11.9632, -11.1507], [ -7.9777, -14.0558, -22.2902, ..., -9.2902, -12.1027, -11.2902], [ -7.8487, -14.0518, -22.3174, ..., -9.4112, -12.0987, -11.4424]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -7.6713, -14.5307, -22.1713, ..., -9.6088, -12.3276, -11.7026], [ -7.6189, -14.3377, -22.1189, ..., -9.5564, -12.2752, -11.7439], [ -7.4675, -14.2957, -22.1238, ..., -9.5925, -12.2800, -11.7488]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -8.0349, -14.1286, -21.9411, ..., -9.4099, -12.3786, -11.3786], [ -8.0179, -14.1272, -21.9554, ..., -9.2679, -12.2991, -11.4241], [ -7.9908, -14.1939, -22.0377, ..., -9.3033, -12.3345, -11.4595]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [-10.5943, -20.2662, -29.0006, ..., -18.2271, -19.6725, -21.1256], [ -5.4727, -18.3790, -25.2227, ..., -16.9688, -15.7032, -17.8868], [ -7.7776, -14.5745, -22.4338, ..., -10.0276, -11.9651, -11.8088]], [[ -8.7143, -16.2377, -21.7456, ..., -15.5893, -16.1596, -16.1518], [ -5.1288, -14.4569, -19.8475, ..., -11.4725, -13.1757, -11.4413], [ -3.4711, -15.7729, -20.7211, ..., -11.2211, -13.0180, -13.4086], ..., [ -8.1588, -14.3775, -21.8150, ..., -9.5963, -12.5338, -11.2838], [ -8.0131, -14.4038, -21.7006, ..., -9.5756, -12.3881, -11.2944], [ -7.9537, -14.4381, -21.7662, ..., -9.5787, -12.3287, -11.2974]]], device='cuda:0', grad_fn=<LogSoftmaxBackward0>)
4. Logit Label : 학습용 Logits
실제 숫자로 나오는 데이터
logit_label: tensor([[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -5.2047e+00, -3.9633e+00, -3.7154e+00, -1.8586e+00, -5.5417e+00, -1.0325e-01, -4.3098e+00, -3.9197e+00, -5.6314e-02, -9.2349e-01, -1.0923e-02, -2.9228e-01, -4.9035e+00, -2.0921e+00, -2.3459e+00, -2.2275e+00, -6.2254e+00, -7.2684e-01, -8.3274e-02, -2.0360e-02, -2.3622e+00, -3.6149e+00, -2.3155e+00, -9.3095e-02, -2.5647e-03, -3.4374e-01, -2.2443e+00, -2.8266e-01, -6.7356e-03, -1.5269e+00, -4.5461e-01, -1.2591e+00, -8.9856e-01, -2.6449e-01, -2.6092e+00, -4.6270e+00, -4.6410e+00, -5.7265e-01, -1.8393e-03, -2.1693e+00, -7.1166e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00], [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -3.8922e+00, -2.9277e+00, -4.4169e-03, -3.0753e-01, -8.3347e-01, -5.7375e-01, -7.9791e-01, -6.8701e-04, -3.7425e-03, -6.5682e-05, -1.5995e-01, -1.4910e-01, -1.9286e-04, -8.9794e-02, -2.7459e-01, -1.0095e-02, -3.0750e-02, -3.6543e+00, -1.4377e+00, -5.2084e+00, -1.6428e+00, -2.2102e+00, -4.0789e+00, -7.7777e+00, -7.5434e-02, -3.8790e+00, -3.1241e+00, -1.8450e+00, -1.2490e+00, -4.8884e+00, -7.3326e-02, -4.0398e+00, -3.6256e+00, -7.0264e-02, -7.8034e-01, -1.1456e-02, -9.6471e+00, -1.0769e-01, -3.0114e+00, -3.9529e+00, -1.6534e-02, -7.6546e-03, -1.1023e+00, -5.5995e+00, -2.8106e+00, -7.8820e-03, -5.8266e+00, -9.8396e-03, -3.8590e+00, -1.4775e+00, -2.3561e-02, -1.1159e-03, -1.5604e+00, -3.8725e+00, -2.6680e+00, -6.5971e-01, -5.5761e+00, -8.8965e-03, -2.6756e-02, -7.2358e-01, -3.7405e+00, -7.8345e-01, -1.6047e-01, -6.9308e-04, -8.4214e-01, -2.6555e+00, -4.9179e+00, -2.9730e-01, -1.6792e-02, -6.2096e-03, -6.8873e-01, -6.0952e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00], ...], device='cuda:0', grad_fn=<MulBackward0>)
- Prompt 부분과 Pad 부분은 0으로 처리된다.
- 나머지 부분(생성된 토큰 부분)에 대해서는
gather_logits_labels
함수를 통해 Logit Label을 계산 - 해당 Sequence의 각 Token에 대해서, 해당 토큰에 대해 모델이 갖는 확률%값 = Logit값
5. Score : Length Penalty가 부여된, 모델이 생성한 토큰의 Logits에 따른 생성 점수
- 실제 숫자로 나오는 데이터
scores: tensor([-0.7102, -1.0950, -1.0060, -1.0923, -1.4811, -0.5766], device='cuda:0', grad_fn=<DivBackward0>)
- 아래 코드를 통해서, 각 답변의 Sequence Length Penalty가 부여된 후의 Score
- 단, masking된 Prompt와 Pad 부분은 Length Penalty의 length 계산 길이에 포함되지 않는다.
이런식으로 1 된것만 포함.
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
def get_score(self, logit_label, labels): mask = (labels != -100).float() length = mask.sum(-1) ## <- Masking 된 부분은 Length Penalty에 포함 X scores = logit_label.sum(-1) / (length ** self.args.length_penalty) return scores
6. RRHF Loss: Ranking에 기반한 Loss
- 실제 숫자로 나오는 데이터
rrhf_loss: tensor(3.7421, device='cuda:0', grad_fn=<NegBackward0>)
- 함수
def rrhf_loss(self, scores, idxs, rw_scores): print('scores: ', scores) print('idxs: ', idxs) print('rw_scores: ',rw_scores) diff = scores.unsqueeze(0) - scores.unsqueeze(-1) # b * b print('diff', diff) rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1) # b * b print('rw_diff', rw_diff) aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0] print('aval', aval) r = -diff[aval].sum() print('r', r) return r
diff = scores.unsqueeze(0) - scores.unsqueeze(-1)
과정 자체가 순서를 매기는 과정- 가로 세로로 자기자신을 배치하고 → 순위가 맞는지를 검증
rw_diff
는 정답의 순서를 의미- 즉, 이제 두개의 순서가 틀린(0이하 0이상으로 갈리는 부분) 경우가 틀린 Ranking임
- 그 부분을 Bitwise and로 추출(하나는 >0, 하나는 <0임)
- 해당 부분(loss를 계산할 부분)을.. 기존
diff
, 즉 Model 에 대해서 생긴 Loss로 대응시켜서 - 해당 부분만 sum하고 음수로 치환해주면 끝(최종 loss는 양수니까)
- 예시 수치
# score -- \pi 모델이 생성한 Score = \pi 모델이 보고 있는 랭킹 scores: tensor([-1.8340, -1.5368, -2.0713, -1.2388, -2.1106, -0.4438], device='cuda:0', grad_fn=<DivBackward0>) idxs: tensor([[0, 0, 0, 0, 0, 0]], device='cuda:0') # rw_scores -- RLHF, 사용자들이 준 랭킹 rw_scores: tensor([[-0.6064, -0.5854, -0.5791, -0.7280, -2.1777, -2.5430]], device='cuda:0') diff tensor([[ 0.0000, 0.2972, -0.2374, 0.5952, -0.2766, 1.3902], [-0.2972, 0.0000, -0.5345, 0.2980, -0.5738, 1.0930], [ 0.2374, 0.5345, 0.0000, 0.8326, -0.0393, 1.6276], [-0.5952, -0.2980, -0.8326, 0.0000, -0.8718, 0.7950], [ 0.2766, 0.5738, 0.0393, 0.8718, 0.0000, 1.6668], [-1.3902, -1.0930, -1.6276, -0.7950, -1.6668, 0.0000]], device='cuda:0', grad_fn=<SubBackward0>) rw_diff tensor([[[ 0.0000, 0.0210, 0.0273, -0.1216, -1.5713, -1.9365], [-0.0210, 0.0000, 0.0063, -0.1426, -1.5923, -1.9575], [-0.0273, -0.0063, 0.0000, -0.1489, -1.5986, -1.9639], [ 0.1216, 0.1426, 0.1489, 0.0000, -1.4497, -1.8149], [ 1.5713, 1.5923, 1.5986, 1.4497, 0.0000, -0.3652], [ 1.9365, 1.9575, 1.9639, 1.8149, 0.3652, 0.0000]]], device='cuda:0') ## 이땐, `rw_diff > 0` 이면서 동시에 `diff < 0`인 경우를 True로 간주. aval tensor([[False, False, True, False, False, False], [False, False, True, False, False, False], [False, False, False, False, False, False], [ True, True, True, False, False, False], [False, False, False, False, False, False], [ True, True, True, True, True, False]], device='cuda:0') ## 위 Matrix에서 `True`로 나오는 부분에 대해서 -> 기존 `diff` 부분만 추려서 -> Loss로 계산 r tensor(9.0703, device='cuda:0', grad_fn=<NegBackward0>)
7. SFT Loss: 점수가 가장 높은 예제와 ↔ 모델의 Logits과의 Loss
- 실제 숫자로 나오는 데이터
sft_loss: tensor(1.0950, device='cuda:0', grad_fn=<NegBackward0>)
rw_score
예제를 닮도록 학습8. Loss: RRHF Loss + SFT Loss
- 실제 숫자로 나오는 데이터
- 3.7421 + 1.0950 = 4.8371 단순 합
tensor(4.8371, device='cuda:0', grad_fn=<AddBackward0>)
- 혹은 아래와 같이
rrhf_weight
를 조절할 수도 있음
loss = self.args.rrhf_weight * rrhf_loss + sft_loss
나머지 학습은?
- 그냥 HF 일반 모델 학습 하듯이 학습하면 된다.
- 5.8B~7B모델은 A100 80G 1대(Adafactor optim기준)로 학습 가능
- 주의: Sequence Length를 너무 작게 잡으면 → prompt 부분만으로 seq len이 채워져서 loss = 0으로 학습된다.(= 학습 데이터를 날리는 셈)