Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dpo add #8407

Open
wants to merge 40 commits into
base: develop
Choose a base branch
from
Open

Dpo add #8407

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0e6af02
add jinja template
Southpika Apr 3, 2024
518b9ed
add in cfg
Southpika Mar 12, 2024
e50325f
add system
Southpika Apr 7, 2024
c6ac050
add notes
Southpika Apr 7, 2024
f1bc935
fix apply chat template
Southpika Apr 7, 2024
89a5188
add generation flag
Southpika Apr 7, 2024
f35e5b3
update jinja ut
Southpika Apr 7, 2024
c0cfbc7
fix error
Southpika Apr 7, 2024
167057b
add syntax error check
Southpika Apr 7, 2024
d98369e
fix syntax error
Southpika Apr 7, 2024
263a888
refresh codev
Southpika Apr 8, 2024
9677f25
refresh codecov
Southpika Apr 9, 2024
7f5d261
Merge branch 'chat_template' of https://github.com/Southpika/PaddleNL…
wtmlon Apr 12, 2024
393f87f
update special token map in render
Southpika Apr 12, 2024
a418566
add dpo data process
wtmlon Apr 12, 2024
2124df7
update save
Southpika Apr 16, 2024
9adf87b
Merge branch 'chat_template' of https://github.com/Southpika/PaddleNL…
wtmlon Apr 16, 2024
680d0ab
zero padding data stream
wtmlon Apr 16, 2024
2cea89f
complete dpo zero padding
wtmlon Apr 17, 2024
abe9a4f
update
wtmlon Apr 17, 2024
e654ade
add unittest
wtmlon Apr 18, 2024
daf79ef
fix template
wtmlon Apr 18, 2024
0b7e0e7
bug fix
wtmlon Apr 19, 2024
9a1c4c1
base model support
wtmlon Apr 22, 2024
8e4f626
support llama qwen dpo
wtmlon Apr 24, 2024
4c24566
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
wtmlon Apr 24, 2024
24d83fd
bugfix
wtmlon Apr 28, 2024
9cd8293
add efficient token benchmark
wtmlon Apr 29, 2024
fd68dfe
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
wtmlon Apr 29, 2024
7e82d70
remove pp code
wtmlon Apr 30, 2024
7b5fe66
update
wtmlon Apr 30, 2024
73fc5ad
remove import
wtmlon Apr 30, 2024
6363bc3
remove offload
wtmlon Apr 30, 2024
71a4735
update
wtmlon Apr 30, 2024
0e87fec
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
wtmlon May 8, 2024
966a4bd
fix lint
wtmlon May 8, 2024
1572b27
benchmark dp support
wtmlon May 8, 2024
6fe2a78
fix dpo
lugimzzz May 9, 2024
b819478
add dpo
lugimzzz May 9, 2024
54de043
add dpo
lugimzzz May 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 8 additions & 8 deletions llm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
return tokenized_source, labels


def convert_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
if tokenizer.chat_template is not None:
return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens)
return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding)

tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
if is_test:
Expand All @@ -183,21 +183,21 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens
features = {"input_ids": input_ids, "labels": labels}
if "position_ids" in tokenized_source:
features["position_ids"] = list(range(seq_length))
if intokens:
if zero_padding:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

return features


def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
"""convert multi-rounds conversation example

Args:
example (dict): the source of example
tokenizer (PretrainedTokenizer): the instance of tokenizer
data_args (DataArgument): data argument for data preprocessing
is_test (bool, optional): whether is testing stage. Defaults to True.
intokens (bool, optional): whether use in_tokens. Defaults to False.
zero_padding (bool, optional): whether to use zero padding. Defaults to False.

Returns:
dict[str, np.ndarray]: the features of example
Expand All @@ -216,7 +216,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, i

seq_length = len(input_ids)
features = {"input_ids": input_ids, "labels": labels}
if intokens:
if zero_padding:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

if "position_ids" in rounds_inputs:
Expand All @@ -226,7 +226,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, i
return rounds_inputs


def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intokens=False):
def convert_example_chatglm(example, tokenizer, data_args, is_test=True, zero_padding=False):
if tokenizer.chat_template is not None:
# chatglm only support single-round finetune
example = convert_multi_rounds_to_single_round(example, tokenizer)
Expand All @@ -249,7 +249,7 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intoken
"labels": labels,
}

if intokens:
if zero_padding:
seq_length = len(input_ids)
# attention_mask
attention_mask = np.tri(seq_length, seq_length, dtype=bool)
Expand Down
223 changes: 223 additions & 0 deletions llm/dpo_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def preprocess_dpo_example(data, tokenizer, data_args, model_args):
    """Convert a raw DPO example into trainable features.

    Args:
        data (dict): raw example with keys:
            ``src`` (str | list[str]): user prompts, one per round.
            ``tgt`` (str | list[str]): assistant replies for every round but
                the last (``len(src) == len(tgt) + 1`` must hold).
            ``response`` (list[str]): exactly two candidate answers.
            ``sort`` (list): two distinct scores; the higher-scored response
                is treated as "chosen", the other as "rejected".
        tokenizer: tokenizer instance; when ``chat_template`` is set the
            conversation is rendered via ``encode_chat_inputs``, otherwise
            plain ``encode`` calls are used.
        data_args: provides the ``max_seq_len`` and ``max_prompt_len`` budgets.
        model_args: ``use_attn_mask_start_row_indices`` selects the compact
            per-token row-index attention representation over a dense mask.

    Returns:
        dict | None: features (``input_ids``, ``position_ids``,
        ``chosen_labels``, ``rejected_labels``, ``response_indexs`` plus one
        attention field), or ``None`` when the example is malformed.
    """
    # 1. Check data format; malformed examples are silently dropped.
    if isinstance(data["src"], str):
        data["src"] = [data["src"]]
    if isinstance(data["tgt"], str):
        data["tgt"] = [data["tgt"]]
    if len(data["src"]) != len(data["tgt"]) + 1:
        return None
    if (
        (len(data["response"]) != 2)
        or (len(data["response"]) != len(data["sort"]))
        or data["sort"][0] == data["sort"][1]
    ):
        return None
    # The response with the higher "sort" score is the preferred (chosen) one.
    if data["sort"][0] > data["sort"][1]:
        chosen = data["response"][0]
        rejected = data["response"][1]
    else:
        chosen = data["response"][1]
        rejected = data["response"][0]

    # 2. Build per-round [source, target] pairs. The final round pairs the
    #    last prompt with the *chosen* response; the rejected variant is
    #    derived from it afterwards.
    chosen_encode_tokens = []
    for idx in range(len(data["src"])):
        if idx < len(data["tgt"]):
            if tokenizer.chat_template is not None:
                # Chat-template path keeps raw strings; encoded in bulk below.
                chosen_encode_tokens.append(
                    [
                        data["src"][idx].strip(),
                        data["tgt"][idx].strip(),
                    ]
                )
            else:
                chosen_encode_tokens.append(
                    [
                        tokenizer.encode(data["src"][idx].strip(), add_special_tokens=True)["input_ids"],
                        tokenizer.encode(data["tgt"][idx].strip(), add_special_tokens=False)["input_ids"]
                        + [tokenizer.eos_token_id],
                    ]
                )
        else:
            if tokenizer.chat_template is not None:
                chosen_encode_tokens.append(
                    [
                        data["src"][idx].strip(),
                        chosen.strip(),
                    ]
                )
            else:
                chosen_encode_tokens.append(
                    [
                        tokenizer.encode(data["src"][idx].strip(), add_special_tokens=True)["input_ids"],
                        tokenizer.encode(chosen.strip(), add_special_tokens=False)["input_ids"]
                        + [tokenizer.eos_token_id],
                    ]
                )

    if tokenizer.chat_template is not None:
        chat_input_list = chosen_encode_tokens
        chosen_encode_tokens = tokenizer.encode_chat_inputs(chat_input_list)["conversations"]
        # Swap the final answer for the rejected one and re-render to get the
        # rejected token ids (the shared prompt rounds stay identical).
        chat_input_list[-1][-1] = rejected.strip()
        rejected_encode_tokens = tokenizer.encode_chat_inputs(chat_input_list)["conversations"]

        # Post-process sequence: tokenization & truncation. The last prompt
        # token is moved onto the front of each response so labels line up as
        # a next-token-prediction shift.
        tokens_prompt = chosen_encode_tokens[-1][0][:-1]
        tokens_chosen = chosen_encode_tokens[-1][0][-1:] + chosen_encode_tokens[-1][-1][:-1]
        tokens_rejected = chosen_encode_tokens[-1][0][-1:] + rejected_encode_tokens[-1][-1][:-1]
    else:
        tokens_prompt = chosen_encode_tokens[-1][0][:-1]
        tokens_chosen = (
            chosen_encode_tokens[-1][0][-1:] + tokenizer.encode(chosen.strip(), add_special_tokens=False)["input_ids"]
        )
        tokens_rejected = (
            chosen_encode_tokens[-1][0][-1:]
            + tokenizer.encode(rejected.strip(), add_special_tokens=False)["input_ids"]
        )

    if len(tokens_prompt) + len(tokens_chosen) + len(tokens_rejected) > data_args.max_seq_len:
        # Truncate the prompt first (keep its tail, which is closest to the
        # responses), then split the remaining budget between the responses.
        tokens_prompt = tokens_prompt[-data_args.max_prompt_len :]
        if (len(tokens_prompt) + len(tokens_chosen) + len(tokens_rejected)) > data_args.max_seq_len:
            max_response_len = data_args.max_seq_len - len(tokens_prompt)
            # Truncate chosen/rejected proportionally to their lengths.
            max_chosen_len = int(len(tokens_chosen) / (len(tokens_chosen) + len(tokens_rejected)) * max_response_len)
            max_rejected_len = max_response_len - max_chosen_len
            tokens_chosen = tokens_chosen[:max_chosen_len]
            tokens_rejected = tokens_rejected[:max_rejected_len]

    cur_len = len(tokens_prompt) + len(tokens_chosen) + len(tokens_rejected)
    turn_index = len(chosen_encode_tokens) - 2

    # Prepend as many earlier dialog rounds as still fit in max_seq_len,
    # walking backwards from the most recent round.
    while turn_index >= 0:
        tokens_src = chosen_encode_tokens[turn_index][0]
        tokens_target = chosen_encode_tokens[turn_index][1]
        turn_index -= 1

        if len(tokens_src) + len(tokens_target) > data_args.max_seq_len - cur_len:
            break
        tokens_prompt = tokens_src + tokens_target + tokens_prompt
        cur_len += len(tokens_src) + len(tokens_target)

    # Layout: [prompt | chosen | rejected] packed into one sequence.
    input_ids = tokens_prompt + tokens_chosen + tokens_rejected

    prompt_len = len(tokens_prompt)
    chosen_len = len(tokens_chosen)
    rejected_len = len(tokens_rejected)
    seq_len = len(input_ids)

    # Position ids restart after the prompt for each response so both
    # responses see identical positions relative to the shared prompt.
    position_ids = (
        list(range(prompt_len))  # prompt
        + list(range(prompt_len, prompt_len + chosen_len))  # chosen
        + list(range(prompt_len, prompt_len + rejected_len))  # rejected
    )
    # Labels are the inputs shifted left by one with eos closing the
    # response; zeros mask every position belonging to the other segment.
    # NOTE(review): eos is appended even when the response was truncated
    # above — presumably intentional so every response ends; confirm.
    chosen_labels = [0] * prompt_len + tokens_chosen[1:] + [tokenizer.eos_token_id] + [0] * rejected_len
    rejected_labels = [0] * prompt_len + [0] * chosen_len + tokens_rejected[1:] + [tokenizer.eos_token_id]

    # response index: [chosen_start, rejected_start, rejected_end + 1]
    response_indexs = [prompt_len, prompt_len + chosen_len, seq_len]
    output_dict = {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "chosen_labels": chosen_labels,
        "rejected_labels": rejected_labels,
        "response_indexs": response_indexs,
    }

    # Attention: either compact row indices (each token attends up to its
    # index's row) or a dense lower-triangular mask where the rejected
    # segment is blocked from attending to the chosen segment.
    if model_args.use_attn_mask_start_row_indices:
        output_dict["attn_mask_start_row_indices"] = (
            [seq_len] * prompt_len + [prompt_len + chosen_len] * chosen_len + [seq_len] * rejected_len
        )

    else:
        attention_mask = np.tri(seq_len, seq_len, dtype=bool)
        attention_mask[(prompt_len + chosen_len) :, prompt_len : (prompt_len + chosen_len)] = False
        output_dict["attention_mask"] = attention_mask

    return output_dict


def dpo_collate_fn(batch, max_seq_len=None):
    """Pad a batch of preprocessed DPO sequences and stack them into arrays.

    Args:
        batch (list[dict]): sequences carrying ``input_ids``,
            ``position_ids``, ``chosen_labels``, ``rejected_labels``,
            ``response_indexs`` (a list of ``[chosen_start, rejected_start,
            rejected_end + 1]`` triples) plus either
            ``attn_mask_start_row_indices`` or ``attention_mask``.
            NOTE(review): the ``np.pad`` call below pads three axes, so each
            example's ``attention_mask`` must be 3-D (e.g. ``(1, seq, seq)``)
            — confirm the upstream zero-padding stream produces that shape.
        max_seq_len (int): target padded length. Required; the ``None``
            default exists only so callers can bind it by keyword.

    Returns:
        dict[str, np.ndarray]: batched tensors; each ``response_indexs`` row
        is ``[batch_index, chosen_start, rejected_start, rejected_end + 1]``.

    Raises:
        ValueError: if ``max_seq_len`` is missing, or no attention
            representation is present in the first sequence.
    """
    if max_seq_len is None:
        raise ValueError("max_seq_len is None.")

    input_dict = {
        "input_ids": [],
        "position_ids": [],
        "chosen_labels": [],
        "rejected_labels": [],
        "response_indexs": [],
    }
    # Probe the first sequence to decide which attention representation the
    # whole batch uses (compact row indices take precedence).
    sequence = batch[0]
    if "attn_mask_start_row_indices" in sequence:
        input_dict["attn_mask_start_row_indices"] = []
        use_attn_mask_start_row_indices = True
    elif "attention_mask" in sequence:
        input_dict["attention_mask"] = []
        use_attn_mask_start_row_indices = False
    else:
        raise ValueError("attention_mask and attn_mask_start_row_indices are both None.")

    for i, sequence in enumerate(batch):
        # Right-pad every field to max_seq_len.
        difference = max_seq_len - len(sequence["input_ids"])

        input_dict["input_ids"].append(sequence["input_ids"] + [0] * difference)
        input_dict["position_ids"].append(sequence["position_ids"] + [0] * difference)
        input_dict["chosen_labels"].append(sequence["chosen_labels"] + [0] * difference)
        input_dict["rejected_labels"].append(sequence["rejected_labels"] + [0] * difference)
        if use_attn_mask_start_row_indices:
            # Padding repeats the last row index; the extra list nesting
            # yields shape (1, max_seq_len) per example — presumably for
            # broadcasting over attention heads; confirm against the model.
            input_dict["attn_mask_start_row_indices"].append(
                [sequence["attn_mask_start_row_indices"] + [sequence["attn_mask_start_row_indices"][-1]] * difference]
            )
        else:
            # Grow the mask's last two axes; padded positions attend nowhere.
            input_dict["attention_mask"].append(
                np.pad(
                    sequence["attention_mask"],
                    pad_width=((0, 0), (0, difference), (0, difference)),
                    mode="constant",
                    constant_values=False,
                )
            )

        # Flatten per-sequence response triples, prefixing the batch index.
        for ri in sequence["response_indexs"]:
            input_dict["response_indexs"].append(
                [
                    i,  # bs
                    ri[0],  # chosen_response_start_index
                    ri[1],  # rejected_response_start_index
                    ri[2],  # rejected_response_end_index + 1
                ]
            )

    for key in input_dict:
        if key == "attention_mask":
            input_dict[key] = np.array(input_dict[key], dtype=bool)
        elif key == "attn_mask_start_row_indices":
            input_dict[key] = np.array(input_dict[key], dtype=np.int32)
        else:
            input_dict[key] = np.array(input_dict[key])

    return input_dict