Adds Online DPO #1605

Closed
wants to merge 14 commits
209 changes: 209 additions & 0 deletions examples/scripts/dpo_online.py
@@ -0,0 +1,209 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/dpo_online.py \
--dataset_name=trl-internal-testing/hh-rlhf-trl-style \
--dataset_num_proc=4 \
--model_name_or_path=Qwen/Qwen1.5-0.5B-Chat \
--per_device_train_batch_size 4 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="scratch/dpo_anthropic_hh" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--sanity_check

# peft:
python examples/scripts/dpo_online.py \
--dataset_name=trl-internal-testing/hh-rlhf-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""

import logging
import os
from contextlib import nullcontext

TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser

if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"

from rich.console import Console
from rich.logging import RichHandler

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from trl.trainer import OnlineDPOTrainer
from trl import (
DPOConfig,
DPOTrainer,
ModelConfig,
RichProgressCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)

from trl.trainer import WinRateCallback, MockJudge, PairRMJudge

if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)


if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()

    # Force use of our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()

###################
# Model & Tokenizer
###################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
peft_config = get_peft_config(model_config)
if peft_config is None:
model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
else:
model_ref = None
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.bos_token is None:
tokenizer.bos_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
if args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]

################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)

################
# Dataset
################
ds = load_dataset(args.dataset_name)
if args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(1_000))


for key in ds:
ds[key] = ds[key].remove_columns(["chosen", "rejected", "score_chosen", "score_rejected"])

def process(row):
row["prompt"] = tokenizer.apply_chat_template(row["messages"], tokenize=False, add_generation_prompt=True)
return row

ds = ds.map(
process,
num_proc=training_args.dataset_num_proc,
load_from_cache_file=False,
)
train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

################
# Training
################
with init_context:
trainer = OnlineDPOTrainer(
model,
model_ref,
annotator_cls=PairRMJudge,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
# prompts_ds = load_dataset(args.dataset_name, split="test[:32]")
# prompts_ds = prompts_ds.map(
# lambda x: {
# "prompt": tokenizer.apply_chat_template(x["chosen"][:-1], tokenize=False, add_generation_prompt=True)
# }
# )
# win_rate_callback = WinRateCallback(
# prompts=prompts_ds["prompt"],
# judge=judge,
# generation_config=GenerationConfig(
# temperature=0.9,
# do_sample=True,
# num_return_sequences=1,
# pad_token_id=tokenizer.pad_token_id,
# eos_token_id=tokenizer.eos_token_id,
# max_new_tokens=512,
# ),
# trainer=trainer,
# )
# trainer.add_callback(win_rate_callback)

trainer.train()

with save_context:
trainer.save_model(training_args.output_dir)
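
For reference, the process map above only adds a prompt column rendered from the messages column with the tokenizer's chat template. A minimal, self-contained sketch of that step in isolation (the model name comes from the usage example in the docstring; the toy row is illustrative):

from transformers import AutoTokenizer

# Render a messages list into a generation-ready prompt string, as the
# process function above does (toy row; model taken from the docstring).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
row = {"messages": [{"role": "user", "content": "What is online DPO?"}]}
prompt = tokenizer.apply_chat_template(row["messages"], tokenize=False, add_generation_prompt=True)
print(prompt)  # ends with the assistant generation marker, ready for sampling
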
1 change: 1 addition & 0 deletions setup.py
@@ -75,6 +75,7 @@
"deepspeed": ["deepspeed>=0.9.5"],
"benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"],
"quantization": ["bitsandbytes<=0.41.1"],
"llm_judge": ["openai>=1.23.2", "huggingface_hub>=0.22.2", "llm-blender>=0.0.2"],
}
EXTRAS["dev"] = []
for reqs in EXTRAS.values():
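
The new llm_judge extra groups the optional dependencies used by the judges added in this PR (the OpenAI client, the Hugging Face Hub client, and LLM-Blender). A hedged sketch of guarding those imports at runtime, assuming the module names openai, huggingface_hub, and llm_blender; the helper name is illustrative and not part of this PR:

import importlib.util

def require_llm_judge_deps() -> None:
    # Illustrative helper (not part of this PR): fail early with an actionable
    # message when the optional llm_judge dependencies are missing.
    missing = [
        pkg
        for pkg in ("openai", "huggingface_hub", "llm_blender")
        if importlib.util.find_spec(pkg) is None
    ]
    if missing:
        raise ImportError(f"Missing optional packages {missing}; install them with pip install trl[llm_judge].")
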
1 change: 1 addition & 0 deletions trl/__init__.py
@@ -117,6 +117,7 @@
RewardTrainer,
SFTConfig,
SFTTrainer,
OnlineDPOTrainer,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
1 change: 1 addition & 0 deletions trl/commands/scripts
5 changes: 5 additions & 0 deletions trl/trainer/__init__.py
@@ -31,6 +31,7 @@
],
"dpo_config": ["DPOConfig"],
"dpo_trainer": ["DPOTrainer"],
"odpo_trainer": ["OnlineDPOTrainer"],
"cpo_config": ["CPOConfig"],
"cpo_trainer": ["CPOTrainer"],
"iterative_sft_trainer": ["IterativeSFTTrainer"],
@@ -47,6 +48,8 @@
"sft_trainer": ["SFTTrainer"],
"base": ["BaseTrainer"],
"ddpo_config": ["DDPOConfig"],
"callbacks": ["WinRateCallback"],
"judges": ["MockJudge", "PairRMJudge", "OpenAIJudge"],
}

try:
@@ -92,6 +95,8 @@
from .reward_trainer import RewardTrainer, compute_accuracy
from .sft_config import SFTConfig
from .sft_trainer import SFTTrainer
from .odpo_trainer import OnlineDPOTrainer
    from .judges import PairRMJudge, MockJudge, OpenAIJudge
    from .callbacks import WinRateCallback

try:
if not is_diffusers_available():
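
The judges exported here are consumed by WinRateCallback (callbacks.py below) through a single method, judge_batch(prompts, completions), which returns the index of the preferred completion for each prompt (0 meaning the first candidate). A minimal sketch of that assumed interface with a hypothetical random judge, useful as a sanity check; the class is illustrative and not part of this PR:

import random
from typing import List

class RandomJudge:
    """Hypothetical judge showing the interface WinRateCallback relies on."""

    def judge_batch(self, prompts: List[str], completions: List[List[str]]) -> List[int]:
        # A real judge (PairRMJudge, OpenAIJudge) scores each pair of completions;
        # choosing at random should yield a win rate of roughly 0.5.
        return [random.randint(0, len(pair) - 1) for pair in completions]
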
123 changes: 123 additions & 0 deletions trl/trainer/callbacks.py
@@ -0,0 +1,123 @@
from typing import List

from accelerate.utils import gather_object
from datasets import Dataset
from tqdm import tqdm
from transformers import (
GenerationConfig,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
is_wandb_available,
)

from ..models.utils import unwrap_model_for_generation


if is_wandb_available():
import wandb


class WinRateCallback(TrainerCallback):
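    """Estimate the policy's win rate against the initial model during training.

    On train begin, completions are generated once from the (still untrained)
    model for a fixed set of prompts and stored as reference completions; on
    each evaluation, fresh completions are generated for the same prompts and
    the judge picks the better of each (policy, reference) pair. The logged
    win rate is the fraction of prompts for which the judge returns index 0,
    i.e. prefers the policy completion.
    """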
def __init__(
self,
prompts: List[str],
generation_config: GenerationConfig,
judge,
trainer,
):
self.prompts = prompts
self.generation_config = generation_config
self.completions = []
self.judge = judge
self.ref_completions = []
self.trainer = trainer

def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
model = self.trainer.model_wrapped
tokenizer = kwargs["tokenizer"]
accelerator = self.trainer.accelerator

with accelerator.split_between_processes(self.prompts, apply_padding=True) as prompts:
# local_dataset = Dataset.from_dict(prompts)

with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
unwrapped_model.eval()
for prompt in tqdm(prompts, desc="Generating ref completions for win rate"):
tokenized_prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
generation = unwrapped_model.generate(
**tokenized_prompt,
generation_config=self.generation_config,
)
padded_prompt_length = tokenized_prompt.input_ids.shape[1]
generation = generation[:, padded_prompt_length:]
text_generations = tokenizer.batch_decode(generation, skip_special_tokens=True)

ref_response = text_generations[0]
self.ref_completions.append(ref_response)
unwrapped_model.train()

def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
model = self.trainer.model_wrapped
tokenizer = kwargs["tokenizer"]
accelerator = self.trainer.accelerator

with accelerator.split_between_processes(self.prompts, apply_padding=True) as prompts:
annotation_batch = {"prompts": prompts, "completions": []}

with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
unwrapped_model.eval()
for idx, prompt in enumerate(tqdm(prompts, desc="Generating completions for win rate")):
tokenized_prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
generations = unwrapped_model.generate(
**tokenized_prompt,
generation_config=self.generation_config,
)
padded_prompt_length = tokenized_prompt.input_ids.shape[1]
generations = generations[:, padded_prompt_length:]
text_generations = tokenizer.batch_decode(generations, skip_special_tokens=True)

response0 = text_generations[0]
response1 = self.ref_completions[idx]

annotation_batch["completions"].append([response0, response1])
unwrapped_model.train()
            # TODO: rerun with the order of the responses swapped and average the two results
results_dict = self.judge.judge_batch(annotation_batch["prompts"], annotation_batch["completions"])
results_dict = Dataset.from_dict(
{
"results": results_dict,
"prompts": annotation_batch["prompts"],
"completions": annotation_batch["completions"],
}
) # maybe just map the original dataset for logging
results_dict = gather_object(results_dict)

# Logging
if accelerator.is_main_process:
dataset_len = len(self.prompts)
results_dataset = Dataset.from_list(results_dict).select(range(dataset_len))

win_rate = sum([r == 0 for r in results_dataset["results"]]) / len(results_dataset)
self.trainer.log({"win_rate": win_rate})

if is_wandb_available():
wandb.log({"eval_win_rate": win_rate, "train/global_step": state.global_step})
prompts = results_dataset["prompts"]
policy = [c[0] for c in results_dataset["completions"]]
ref = [c[1] for c in results_dataset["completions"]]
chosen_indices = results_dataset["results"]
self.trainer.log(
{
"winrate_generations": wandb.Table(
columns=["Prompt", "Policy", "Ref Model", "Chosen index"],
rows=[ # TODO replace with zip unpacking
[prompt, pol, ref, index]
for prompt, pol, ref, index in zip(prompts, policy, ref, chosen_indices)
],
)
}
)
# pop Table otherwise it is included in the history which cannot be pickled and causes an error
self.trainer.state.log_history.pop()