-
Notifications
You must be signed in to change notification settings - Fork 11
/
train_reward_model.py
410 lines (362 loc) · 15 KB
/
train_reward_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
#!python
# -*- coding: utf-8 -*-
# @author: Kun
import os
import torch
import evaluate
import numpy as np
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
AutoModel,
HfArgumentParser,
PreTrainedTokenizerBase,
Trainer,
TrainingArguments,
)
from transformers.utils import PaddingStrategy
from transformers import LlamaForSequenceClassification, LlamaConfig, LlamaTokenizer
from transformers import AutoModelForSeq2SeqLM
from reward_model import RewardModel
# Special-token strings registered on LLaMA/Vicuna tokenizers via add_special_tokens.
# NOTE(review): the stock LLaMA SentencePiece vocabulary conventionally uses "<s>" for BOS
# and "<unk>" for UNK, but here EOS, BOS and UNK are all mapped to "</s>". Confirm this is
# intended before changing it, since it alters tokenization.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = DEFAULT_BOS_TOKEN = DEFAULT_UNK_TOKEN = "</s>"
# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    Command-line options for reward-model training.

    These arguments vary depending on how many GPUs you have, what their capacity and
    features are, and what size model you want to train. The class is parsed with
    HfArgumentParser, which builds each CLI flag's type from the field annotation, so the
    annotations below must match the values users pass on the command line.
    """

    # NOTE: accepted on the CLI but not forwarded to TrainingArguments in this script.
    local_rank: Optional[int] = field(
        default=-1, metadata={"help": "Used for multi-gpu"})
    resume_from_checkpoint: Optional[bool] = field(
        default=False,
        metadata={"help": "If you want to resume training where it left off."},
    )
    deepspeed: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
        },
    )
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=1)
    learning_rate: Optional[float] = field(default=2e-5)
    # float, not int: with an `int` annotation the parser rejected values such as 0.01.
    weight_decay: Optional[float] = field(default=0.001)
    model_name: Optional[str] = field(
        default="decapoda-research/llama-7b-hf",
        metadata={
            "help": "The model that you want to train from the Hugging Face hub or local."
        },
    )
    # NOTE: not currently forwarded to TrainingArguments (mixed precision is left off).
    bf16: Optional[bool] = field(
        default=True,
        metadata={
            "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
        },
    )
    num_train_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "The number of training epochs for the reward model."},
    )
    # A value <= 0 means "use the whole split".
    train_subset: Optional[int] = field(
        default=100000,
        metadata={"help": "The size of the subset of the training data to use"},
    )
    # A value <= 0 means "use the whole split".
    eval_subset: Optional[int] = field(
        default=50000,
        metadata={"help": "The size of the subset of the eval data to use"},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables gradient checkpointing."},
    )
    optim: Optional[str] = field(
        default="adamw_hf",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: Optional[str] = field(
        default="linear",
        metadata={"help": "The lr scheduler"},
    )
# Parse the command line into ScriptArguments.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


def _take_subset(dataset, subset_size):
    """Return the first ``subset_size`` rows of ``dataset``.

    ``None`` or a non-positive ``subset_size`` keeps the whole dataset. The size is clamped
    to ``len(dataset)`` because ``Dataset.select`` raises ``IndexError`` for indices past
    the end, and the default subset sizes (100000 / 50000) easily exceed a small local
    dataset.
    """
    if subset_size is None or subset_size <= 0:
        return dataset
    return dataset.select(range(min(subset_size, len(dataset))))


dataset_name = "./datasets/"
print("dataset_name: ", dataset_name)

# Load the dataset for tuning the reward model.
train_dataset = _take_subset(
    load_dataset(dataset_name, split="train"), script_args.train_subset
)

# NOTE(review): evaluation also reads the "train" split, so the eval rows (the first
# `eval_subset` rows) overlap the training rows and eval accuracy is not a held-out
# estimate. Point this at a separate split once the dataset provides one.
eval_dataset = _take_subset(
    load_dataset(dataset_name, split="train"), script_args.eval_subset
)
# Training arguments. These must be built before the model is loaded when DeepSpeed is used.
model_name_split = script_args.model_name.rsplit("/", 1)[-1]
output_name = (
    f"reward_model_{model_name_split}__{script_args.train_subset}_{script_args.learning_rate}"
)

training_args = TrainingArguments(
    # Checkpoints are written under a directory named after model, subset and LR.
    output_dir=output_name,
    # Optimisation hyper-parameters, taken straight from the CLI.
    learning_rate=script_args.learning_rate,
    weight_decay=script_args.weight_decay,
    optim=script_args.optim,
    lr_scheduler_type=script_args.lr_scheduler_type,
    num_train_epochs=script_args.num_train_epochs,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    gradient_checkpointing=script_args.gradient_checkpointing,
    deepspeed=script_args.deepspeed,
    # Evaluate, checkpoint and log on a fixed step cadence; keep the two newest checkpoints.
    evaluation_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    logging_strategy="steps",
    logging_steps=10,
    # The dataset columns (input_ids_j, ...) do not match the model's forward arguments
    # (chosen_input_ids, ...), so automatic column pruning must stay off; there are no labels.
    remove_unused_columns=False,
    label_names=[],
    # Mixed precision (bf16/fp16) is deliberately not enabled: the original author noted
    # that fp16=True ran out of CUDA memory. script_args.bf16 and script_args.local_rank
    # are therefore not forwarded here.
)
# Load the tokenizer and config that match the backbone family.
_is_llama_family = any(
    tag in script_args.model_name for tag in ("llama", "vicuna", "Vicuna")
)

if _is_llama_family:
    tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)
    config = LlamaConfig.from_pretrained(script_args.model_name)
else:
    # ChatGLM and every other backbone go through the Auto* classes with remote code enabled.
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
    config = AutoConfig.from_pretrained(script_args.model_name, trust_remote_code=True)
print("tokenizer: ", type(tokenizer))

if _is_llama_family:
    # LLaMA-style tokenizers get explicit special tokens.
    tokenizer.add_special_tokens(
        {
            "eos_token": DEFAULT_EOS_TOKEN,
            "bos_token": DEFAULT_BOS_TOKEN,
            "unk_token": DEFAULT_UNK_TOKEN,
            "pad_token": DEFAULT_PAD_TOKEN,
        }
    )
else:
    # Tokenizers without an official pad token (e.g. GPT-2) reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

# Under DDP (WORLD_SIZE != 1) pin the whole model to this process's local GPU;
# otherwise let accelerate spread it automatically.
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else "auto"
print("device_map: ", device_map)
# Load the backbone in 8-bit with fp16 compute dtype and num_labels=1.
_backbone_kwargs = dict(
    num_labels=1,
    torch_dtype=torch.float16,
    load_in_8bit=True,
    device_map=device_map,
)

if any(tag in script_args.model_name for tag in ("llama", "vicuna", "Vicuna")):
    model = LlamaForSequenceClassification.from_pretrained(
        script_args.model_name, **_backbone_kwargs
    )
else:
    # ChatGLM loads through AutoModelForSeq2SeqLM; everything else through the
    # sequence-classification auto class. Both need remote code enabled.
    _backbone_cls = (
        AutoModelForSeq2SeqLM
        if "chatglm" in script_args.model_name
        else AutoModelForSequenceClassification
    )
    model = _backbone_cls.from_pretrained(
        script_args.model_name, trust_remote_code=True, **_backbone_kwargs
    )

print("model: ", type(model))
model = prepare_model_for_int8_training(model)
print("model: ", type(model))
# Wrap the backbone with LoRA adapters (rank 8, alpha 16, dropout 0.05, no bias training).
_lora_hparams = {"r": 8, "lora_alpha": 16, "lora_dropout": 0.05}
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    bias="none",
    **_lora_hparams,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Pad with EOS for every backbone. For LLaMA/Vicuna this overrides the "[PAD]" token
# registered on the tokenizer earlier; presumably this keeps pad ids inside the backbone's
# existing vocabulary, since the embedding matrix is never resized in this script.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# Disable the KV cache whenever gradient checkpointing is enabled.
model.config.use_cache = not script_args.gradient_checkpointing
num_proc = 24  # worker processes for datasets.map; raise it if you have more CPU cores.
original_columns = train_dataset.column_names  # raw columns, dropped after tokenization

# Wrap the PEFT backbone in the project's RewardModel.
# NOTE(review): `model.transformer` exists on ChatGLM-style backbones; LLaMA sequence
# classifiers expose their decoder as `.model` instead, so this presumably only works for
# ChatGLM-like --model_name values. Confirm before using other backbones.
reward_model = RewardModel(model.config, model.transformer, tokenizer)
print(reward_model)

# Freeze the first 70% of the backbone's layers. requires_grad_ is applied recursively to
# each layer's parameters, so any LoRA adapters inside those layers are frozen as well.
_freeze_ratio = 0.7
layers = reward_model.transformer.layers
num_layers = len(layers)
num_frozen = int(_freeze_ratio * num_layers)
for frozen_layer in layers[:num_frozen]:
    frozen_layer.requires_grad_(False)

print("Finished loading model and tokenizer")
# Turn each row into a (j, k) pair of tokenized "Question/Answer" prompts, where j is the
# preferred response and k the other one.
def preprocess_function(examples):
    """Tokenize paired completions for reward-model training (batched ``datasets.map``).

    Reads the batched columns "user_input", "completion_a" and "completion_b".
    completion_a is treated as the preferred (j) response and completion_b as the
    other (k) one. NOTE(review): confirm that ordering against the dataset.

    Returns a dict of lists: input_ids_j, attention_mask_j, input_ids_k, attention_mask_k,
    one entry per input row. Sequences are truncated by the tokenizer's own limit.
    """
    new_examples = {
        key: []
        for key in ("input_ids_j", "attention_mask_j", "input_ids_k", "attention_mask_k")
    }
    rows = zip(examples["user_input"], examples["completion_a"], examples["completion_b"])
    for question, response_j, response_k in rows:
        for ids_key, mask_key, response in (
            ("input_ids_j", "attention_mask_j", response_j),
            ("input_ids_k", "attention_mask_k", response_k),
        ):
            prompt = "Question: " + question + "\n\nAnswer: " + response
            encoded = tokenizer(prompt, truncation=True)
            new_examples[ids_key].append(encoded["input_ids"])
            new_examples[mask_key].append(encoded["attention_mask"])
    return new_examples
# Tokenize both splits into (j, k) pairs, then keep only pairs whose two sides each fit in
# 512 tokens (the same limit the data collator is configured with).
def _both_sides_fit(example):
    """True when neither tokenized side of the pair exceeds 512 tokens."""
    return len(example["input_ids_j"]) <= 512 and len(example["input_ids_k"]) <= 512


def _tokenize_pairs(label, dataset):
    """Run preprocess_function over ``dataset``, drop over-long pairs, log before/after sizes."""
    print(f"{label}: ", len(dataset))
    dataset = dataset.map(
        preprocess_function, batched=True, num_proc=num_proc, remove_columns=original_columns
    )
    dataset = dataset.filter(_both_sides_fit)
    print(f"{label}: ", len(dataset))
    return dataset


train_dataset = _tokenize_pairs("train_dataset", train_dataset)
eval_dataset = _tokenize_pairs("eval_dataset", eval_dataset)
# Data collator that batches (j, k) pairs, padding each side separately.
@dataclass
class RewardDataCollatorWithPadding:
    """Batch tokenized (j, k) pairs for RewardTrainer.

    Each feature must carry ``input_ids_j``/``attention_mask_j`` and
    ``input_ids_k``/``attention_mask_k``. The two sides are padded independently with
    ``tokenizer.pad`` using the options below, so with ``padding=True`` the j and k
    tensors may end up with different sequence lengths.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def _pad_side(
        self, features: List[Dict[str, Any]], ids_key: str, mask_key: str
    ) -> Dict[str, Any]:
        """Pad one side of every pair (selected by its id/mask keys) into a single batch."""
        unpadded = [
            {"input_ids": feature[ids_key], "attention_mask": feature[mask_key]}
            for feature in features
        ]
        return self.tokenizer.pad(
            unpadded,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        padded_j = self._pad_side(features, "input_ids_j", "attention_mask_j")
        padded_k = self._pad_side(features, "input_ids_k", "attention_mask_k")
        return {
            "input_ids_j": padded_j["input_ids"],
            "attention_mask_j": padded_j["attention_mask"],
            "input_ids_k": padded_k["input_ids"],
            "attention_mask_k": padded_k["attention_mask"],
            # Lets Trainer compute the loss during evaluation even though label_names is empty.
            "return_loss": True,
        }
# Define the metric that we'll use for validation.
# Hugging Face `evaluate` accuracy metric, loaded once at module import
# (evaluate.load may need to download the metric script on first use).
accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Return pairwise ranking accuracy for the reward model.

    Args:
        eval_pred: ``(predictions, label_ids)`` as passed by Trainer. ``predictions``
            holds two entries, the rewards for the preferred (j) and the rejected (k)
            sequences, either as a 2-tuple of arrays or as an array stacked on axis 0.
            ``label_ids`` is unused.

    Returns:
        ``{"accuracy": fraction of pairs with rewards_j > rewards_k}``. Ties count as
        incorrect (the previous argmax formulation silently counted them as correct).
        NaN when there are no pairs.
    """
    predictions, _ = eval_pred
    rewards_j, rewards_k = predictions
    # Flatten so (N,), (N, 1) and similar reward shapes are all compared element-wise.
    rewards_j = np.asarray(rewards_j, dtype=np.float64).reshape(-1)
    rewards_k = np.asarray(rewards_k, dtype=np.float64).reshape(-1)
    if rewards_j.size == 0:
        return {"accuracy": float("nan")}
    return {"accuracy": float(np.mean(rewards_j > rewards_k))}
class RewardTrainer(Trainer):
    """Trainer that optimises the InstructGPT pairwise ranking loss.

    loss = mean(-log(sigmoid(r_j - r_k))), where r_j / r_k are the rewards the model
    assigns to the preferred / rejected sequences (https://arxiv.org/abs/2203.02155).
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the pairwise reward loss, optionally with both reward tensors.

        Args:
            model: reward model, called once per side with ``chosen_*`` / ``rejected_*``
                keyword arguments; its outputs are indexed by "chosen_reward" and
                "reject_reward".
            inputs: batch produced by RewardDataCollatorWithPadding.
            return_outputs: when True, also return ``{"rewards_j": ..., "rewards_k": ...}``.
            num_items_in_batch: passed by newer ``transformers`` Trainer releases; accepted
                for compatibility and unused, because the loss is already a per-pair mean.
        """
        rewards_j = model(
            chosen_input_ids=inputs["input_ids_j"],
            chosen_attention_mask=inputs["attention_mask_j"],
        )["chosen_reward"]
        rewards_k = model(
            rejected_input_ids=inputs["input_ids_k"],
            rejected_attention_mask=inputs["attention_mask_k"],
        )["reject_reward"]
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss
# Train the reward model.
data_collator = RewardDataCollatorWithPadding(
    tokenizer=tokenizer, max_length=512, pad_to_multiple_of=8
)
trainer = RewardTrainer(
    model=reward_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)
model.config.use_cache = False
trainer.train(script_args.resume_from_checkpoint)

print("Saving last checkpoint of the model")
# NOTE(review): `model` is the PEFT-wrapped backbone, not `reward_model`, so this saves the
# backbone/adapter weights only. Any parameters that live solely on RewardModel (e.g. a
# value head, if it defines one) are presumably not included — confirm.
model.save_pretrained(output_name)