-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add Pose and YaRN #8327
Open
whf313
wants to merge
8
commits into
PaddlePaddle:develop
Choose a base branch
from
whf313:develop
base: develop
Could not load branches
Branch not found: {{ refName }}
Could not load tags
Nothing to show
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
add Pose and YaRN #8327
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
import copy | ||
import random | ||
import argparse | ||
from dataclasses import dataclass, field | ||
from typing import Optional, Dict, Sequence, List | ||
from pathlib import Path | ||
|
||
|
||
import numpy as np | ||
from paddlenlp.datasets import load_dataset | ||
from paddlenlp.datasets.dataset import MapDataset | ||
from tqdm import tqdm | ||
|
||
from argument import ( | ||
DataArgument, | ||
TrainingArguments, | ||
) | ||
from paddlenlp.utils.log import logger | ||
import paddle | ||
from paddlenlp.transformers import ( | ||
AutoConfig, | ||
AutoModelForCausalLM, | ||
AutoTokenizer, | ||
LlamaTokenizer, | ||
) | ||
from utils import init_chat_template | ||
|
||
|
||
@dataclass | ||
class DataArguments(DataArgument): | ||
dataset_name: str = field(default="scrolls-gov_report") | ||
input_field: str = field(default="text") | ||
|
||
@dataclass | ||
class TrainingArguments(TrainingArguments): | ||
optim: str = field(default="adamw") | ||
model_max_position_embeddings: int = field( | ||
default=2048, | ||
metadata={"help": "Maximum position embeddings."}, | ||
) | ||
min_input_tokens: int = field(default=500) | ||
max_input_tokens: int = field(default=1000) | ||
sliding_window_step: int = field(default=256) | ||
window_length_list: List[int] = field(default_factory=lambda: []) | ||
rope_scaling_type: Optional[str] = field(default=None) | ||
rope_scaling_factor: float = field(default=1.0) | ||
|
||
|
||
def compute_perplexity( | ||
encodings, model, tokenizer, add_start_token: bool = True, max_length=None, sliding_window_step=256, truncate=False, aggressive_memory=False | ||
): | ||
|
||
if add_start_token: | ||
assert ( | ||
tokenizer.bos_token is not None | ||
), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False" | ||
max_tokenized_len = max_length - 1 | ||
else: | ||
max_tokenized_len = max_length | ||
|
||
encoded_texts = encodings["input_ids"] | ||
attn_masks = encodings["attention_mask"] | ||
|
||
if max_length and truncate: | ||
encoded_texts = [x[0:max_tokenized_len] for x in encoded_texts] | ||
attn_masks = [x[0:max_tokenized_len] for x in attn_masks] | ||
sliding_window_step = max_tokenized_len | ||
|
||
pbar = tqdm(total=len(encoded_texts)) | ||
nlls = [] | ||
total_nll = paddle.to_tensor(0,dtype="float64") | ||
total_token_cnt = 0 | ||
for encoding_index in range(0, len(encoded_texts)): | ||
labels = paddle.to_tensor(encoded_texts[encoding_index:encoding_index+1]) | ||
|
||
seq_len = labels.shape[1] | ||
|
||
prev_end_loc = 0 | ||
for begin_loc in range(0, seq_len, sliding_window_step): | ||
|
||
end_loc = min(begin_loc + max_tokenized_len, seq_len) | ||
trg_len = end_loc - prev_end_loc | ||
input_ids = labels[:, begin_loc:end_loc] | ||
|
||
if add_start_token: | ||
bos_tokens_tensor = paddle.to_tensor( | ||
[[tokenizer.bos_token_id]] * input_ids.shape[0]) | ||
input_ids = paddle.concat( | ||
[bos_tokens_tensor, input_ids], axis=1) | ||
|
||
|
||
target_ids = input_ids.clone() | ||
target_ids[:, :-trg_len] = -100 | ||
|
||
# Revised for paddle shift | ||
input_ids = input_ids[..., :-1] | ||
target_ids = target_ids[..., 1:] | ||
|
||
with paddle.no_grad(): | ||
outputs = model(input_ids, labels=target_ids) | ||
neg_log_likelihood = outputs[0] | ||
total_nll += neg_log_likelihood * trg_len | ||
total_token_cnt += trg_len | ||
|
||
nlls.append(neg_log_likelihood) | ||
|
||
ppl = float(paddle.exp(total_nll / total_token_cnt).cpu()) | ||
pbar.set_postfix(ppl=ppl) | ||
|
||
prev_end_loc = end_loc | ||
if end_loc == seq_len: | ||
break | ||
|
||
pbar.update(1) | ||
|
||
ppl = float(paddle.exp(total_nll / total_token_cnt).cpu()) | ||
return {"mean_perplexity": ppl} | ||
|
||
|
||
def main(): | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--min_input_tokens", type=int, default=500) | ||
parser.add_argument("--max_input_tokens", type=int, default=1000) | ||
parser.add_argument("--step", type=int, default=500) | ||
parser.add_argument("--eval_nums", type=int, default=50) | ||
parser.add_argument("--batch_size", type=int, default=1) | ||
parser.add_argument("--sliding_window_step", type=int, default=256) | ||
parser.add_argument('--window_length_list', type=int, nargs='+', default=[]) | ||
parser.add_argument("--truncate", action="store_true", default=False) | ||
parser.add_argument("--model_max_position_embeddings", type=int, default=2048) | ||
parser.add_argument("--rope_scaling_factor", type=float, default=1.0) | ||
parser.add_argument("--rope_scaling_type", type=str, default=None) | ||
parser.add_argument("--input_field", type=str, default="text") | ||
parser.add_argument("--model_name", type=str, default="llama-7b") | ||
parser.add_argument("--model_name_or_path", type=str, default="/home/v-daweizhu/teamdrive/model/llama-7b") | ||
parser.add_argument("--dataset_name", type=str, default="scrolls-gov_report") | ||
parser.add_argument("--dataset_name_or_path", type=str, default="") | ||
parser.add_argument("--gpu_device", type=int, default=0) | ||
args = parser.parse_args() | ||
|
||
model_name_or_path = args.model_name_or_path | ||
paddle.device.set_device('gpu:3') | ||
print(f"Model loaded on {paddle.device.get_device()}") | ||
|
||
|
||
scaled_max_position_embeddings=int(args.model_max_position_embeddings * args.rope_scaling_factor) | ||
|
||
# Load Model | ||
model_config = AutoConfig.from_pretrained(args.model_name_or_path) | ||
model_config.max_position_embeddings = scaled_max_position_embeddings | ||
|
||
# RoPE interpolation | ||
if args.rope_scaling_type is not None: | ||
model_config.rope_scaling={"type": args.rope_scaling_type, "factor": args.rope_scaling_factor} | ||
if args.rope_scaling_type == "yarn": | ||
model_config.rope_scaling["original_max_position_embeddings"] = args.model_max_position_embeddings | ||
|
||
model = AutoModelForCausalLM.from_pretrained( | ||
args.model_name_or_path, | ||
config=model_config, | ||
) | ||
|
||
# Load tokenizer | ||
tokenizer = AutoTokenizer.from_pretrained( | ||
args.model_name_or_path, | ||
padding_side="right", | ||
use_fast=True | ||
) | ||
|
||
if isinstance(tokenizer, LlamaTokenizer): | ||
tokenizer.pad_token_id = tokenizer.eos_token_id | ||
tokenizer.unk_token_id = tokenizer.eos_token_id | ||
|
||
if "scrolls" in args.dataset_name: | ||
args.input_field = "input" | ||
elif "pile" in args.dataset_name: | ||
args.input_field = "text" | ||
elif "proof" in args.dataset_name: | ||
args.input_field = "text" | ||
|
||
input_texts = load_dataset("json", data_files=args.dataset_name_or_path, splits="train") | ||
|
||
def tokenize_filter(example): | ||
tokenized = tokenizer( | ||
example[args.input_field], | ||
add_special_tokens=False, | ||
padding=True, | ||
truncation=True, | ||
max_length=args.max_input_tokens - 1, # leave room for <BOS> token to be added | ||
return_attention_mask=True, | ||
return_dict=True | ||
) | ||
tokenized["tokenized_len"] = len(tokenized["input_ids"]) | ||
return tokenized | ||
|
||
def tokenize_batch(examples): | ||
input_texts = {"input_ids": [], "position_ids":[], "attention_mask":[]} | ||
for example in examples: | ||
input_texts["input_ids"].append(example["input_ids"]) | ||
input_texts["position_ids"].append(example["position_ids"]) | ||
input_texts["attention_mask"].append(example["attention_mask"]) | ||
|
||
return input_texts | ||
|
||
input_texts = input_texts.map(fn=tokenize_filter, lazy=False) | ||
if args.min_input_tokens: | ||
input_texts = input_texts.filter( | ||
lambda x: x["tokenized_len"] >= 0) | ||
|
||
if args.eval_nums: | ||
input_texts = input_texts[:args.eval_nums] | ||
|
||
input_texts = MapDataset(input_texts) | ||
input_texts = input_texts.map(fn=tokenize_batch, batched=True, lazy=False) | ||
|
||
context_window_size = args.window_length_list | ||
|
||
for ctx_size in context_window_size: | ||
# if args.truncate is True, we calucate the ppl on the whole input text | ||
# otherwise, we calucate the ppl with sliding window | ||
ppl = compute_perplexity(encodings=input_texts, model=model, tokenizer=tokenizer, add_start_token=True, max_length=ctx_size, sliding_window_step=args.sliding_window_step, truncate=True)["mean_perplexity"] | ||
print(f"model: {args.model_name_or_path}; context window size: {ctx_size}; ppl: {ppl}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
model_prefix="facebook/llama-7b" | ||
paddle_prefix=/home/whf/PaddleNLP_old/llm/checkpoints/pose_new_ckpts/yarn/0422_bs8/checkpoint-6000 | ||
data_prefix=/home/whf/PoSE/PoSE-Datasets | ||
|
||
|
||
for factor in 8 | ||
do | ||
python eval_ppl.py \ | ||
--model_name_or_path ${model_prefix} \ | ||
--model_name llama-7b-2k-$((factor*2))k-yarn \ | ||
--gpu_device 0 \ | ||
--rope_scaling_type ntk \ | ||
--rope_scaling_factor ${factor} \ | ||
--model_max_position_embeddings 2048 \ | ||
--max_input_tokens 16384 \ | ||
--min_input_tokens 16384 \ | ||
--window_length_list 2048 4096 8192 16384 \ | ||
--truncate \ | ||
--batch_size 1 \ | ||
--dataset_name scrolls-gov_report \ | ||
--dataset_name_or_path ${data_prefix}/scrolls/gov_report/test_long.jsonl | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -678,4 +678,4 @@ def compute_metrics_do_generation(eval_preds): | |
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
main() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和longlora的同学沟通一下,只实现一个通用的eval_ppl.py即可