add PoSE and YaRN #8327

Open. Wants to merge 8 commits into base: develop.
45 changes: 45 additions & 0 deletions llm/data.py
@@ -267,3 +267,48 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intoken
features["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)

return features


def get_example_pose(example, tokenizer, scaled_max_position_embeddings, model_max_position_embeddings):
    source = example["text"]
    tokenized_source = tokenizer(
        source,
        max_length=scaled_max_position_embeddings,
        truncation=True,
        truncation_side="left",
        add_special_tokens=True,
    )

    ids = tokenized_source["input_ids"]
    len_chunk = min(len(ids), model_max_position_embeddings)
    if len_chunk <= model_max_position_embeddings:
        len_chunk -= 1  # reserve one token so the shifted labels below stay in range
    len_input = len(ids)

    import random

    # PoSE: sample two chunks that together span len_chunk tokens, then skew the
    # position ids of the second chunk so the model sees positions up to the
    # scaled context length while attending to only len_chunk tokens.
    lt1 = 0  # chunk1 start pos
    rt1 = random.randint(1, (len_chunk + 1) // 2)  # chunk1 end pos

    rt2 = random.randint(lt1 + len_chunk, len_input - 1)  # chunk2 end pos
    lt2 = rt2 - (len_chunk - (rt1 - lt1))  # chunk2 start pos
    chunked_ids = ids[lt1:rt1] + ids[lt2:rt2]
    labels = ids[lt1 + 1 : rt1 + 1] + ids[lt2 + 1 : rt2 + 1]  # Revised: next-token labels per chunk

    pos_ids = range(len(chunked_ids))
    # corrected position-id formula: chunk1 keeps its original positions, chunk2 is shifted to start at lt2
    pos_ids = [x + lt1 if i < rt1 - lt1 else x + (lt2 - (rt1 - lt1)) for i, x in enumerate(pos_ids)]

    features = {"input_ids": chunked_ids, "labels": labels, "position_ids": pos_ids}  # Revised

    return features
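
To make the chunk-and-skew behaviour easier to review, here is a minimal sketch of how get_example_pose could be exercised. The StubTokenizer, the import path, and the toy lengths are assumptions for illustration only; they are not part of this PR or of PaddleNLP's API.

```python
import random

from data import get_example_pose  # assumed importable from llm/data.py


class StubTokenizer:
    """Stand-in tokenizer: splits on whitespace and maps each word to a dummy id."""

    def __call__(self, text, max_length=None, truncation=True, truncation_side="left", add_special_tokens=True):
        ids = list(range(len(text.split())))
        return {"input_ids": ids[:max_length]}


random.seed(0)
example = {"text": " ".join(f"tok{i}" for i in range(32))}
features = get_example_pose(
    example,
    StubTokenizer(),
    scaled_max_position_embeddings=32,  # extended context length being trained towards
    model_max_position_embeddings=8,    # original context window
)
# input_ids are two contiguous chunks of the source, while position_ids jump
# ahead for the second chunk (e.g. something like [0, 1, 20, 21, 22, 23, 24]),
# which is the PoSE trick for covering long positions with short inputs.
print(features["input_ids"])
print(features["position_ids"])
```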


def test_preprocess_function(example, tokenizer, inference_length):

    source = example["text"]
    tokenized_source = tokenizer(source, padding=False, truncation=True, max_length=inference_length, return_dict=False)
    input_ids = tokenized_source["input_ids"]
    input_ids, labels = input_ids[:-1], input_ids[1:]
    position_ids = list(range(len(input_ids)))
    features = {"input_ids": input_ids, "position_ids": position_ids, "labels": labels}

    return features
227 changes: 227 additions & 0 deletions llm/eval_ppl.py
@@ -0,0 +1,227 @@
import copy
import random
import argparse
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence, List
from pathlib import Path


import numpy as np
from paddlenlp.datasets import load_dataset
from paddlenlp.datasets.dataset import MapDataset
from tqdm import tqdm

from argument import (
    DataArgument,
    TrainingArguments,
)
from paddlenlp.utils.log import logger
import paddle
from paddlenlp.transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
)
from utils import init_chat_template


@dataclass
class DataArguments(DataArgument):
    dataset_name: str = field(default="scrolls-gov_report")
    input_field: str = field(default="text")


@dataclass
class TrainingArguments(TrainingArguments):
    optim: str = field(default="adamw")
    model_max_position_embeddings: int = field(
        default=2048,
        metadata={"help": "Maximum position embeddings."},
    )
    min_input_tokens: int = field(default=500)
    max_input_tokens: int = field(default=1000)
    sliding_window_step: int = field(default=256)
    window_length_list: List[int] = field(default_factory=lambda: [])
    rope_scaling_type: Optional[str] = field(default=None)
    rope_scaling_factor: float = field(default=1.0)


Review comment (Contributor): please coordinate with the LongLoRA contributors so that only a single, generic eval_ppl.py is implemented.

def compute_perplexity(
    encodings, model, tokenizer, add_start_token: bool = True, max_length=None, sliding_window_step=256, truncate=False, aggressive_memory=False
):

    if add_start_token:
        assert (
            tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    if max_length and truncate:
        encoded_texts = [x[0:max_tokenized_len] for x in encoded_texts]
        attn_masks = [x[0:max_tokenized_len] for x in attn_masks]
        sliding_window_step = max_tokenized_len

    pbar = tqdm(total=len(encoded_texts))
    nlls = []
    total_nll = paddle.to_tensor(0, dtype="float64")
    total_token_cnt = 0
    for encoding_index in range(0, len(encoded_texts)):
        labels = paddle.to_tensor(encoded_texts[encoding_index : encoding_index + 1])

        seq_len = labels.shape[1]

        prev_end_loc = 0
        for begin_loc in range(0, seq_len, sliding_window_step):

            end_loc = min(begin_loc + max_tokenized_len, seq_len)
            trg_len = end_loc - prev_end_loc
            input_ids = labels[:, begin_loc:end_loc]

            if add_start_token:
                bos_tokens_tensor = paddle.to_tensor([[tokenizer.bos_token_id]] * input_ids.shape[0])
                input_ids = paddle.concat([bos_tokens_tensor, input_ids], axis=1)

            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100

            # Revised for paddle shift
            input_ids = input_ids[..., :-1]
            target_ids = target_ids[..., 1:]

            with paddle.no_grad():
                outputs = model(input_ids, labels=target_ids)
                neg_log_likelihood = outputs[0]

            total_nll += neg_log_likelihood * trg_len
            total_token_cnt += trg_len

            nlls.append(neg_log_likelihood)

            ppl = float(paddle.exp(total_nll / total_token_cnt).cpu())
            pbar.set_postfix(ppl=ppl)

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        pbar.update(1)

    ppl = float(paddle.exp(total_nll / total_token_cnt).cpu())
    return {"mean_perplexity": ppl}


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--min_input_tokens", type=int, default=500)
    parser.add_argument("--max_input_tokens", type=int, default=1000)
    parser.add_argument("--step", type=int, default=500)
    parser.add_argument("--eval_nums", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--sliding_window_step", type=int, default=256)
    parser.add_argument("--window_length_list", type=int, nargs="+", default=[])
    parser.add_argument("--truncate", action="store_true", default=False)
    parser.add_argument("--model_max_position_embeddings", type=int, default=2048)
    parser.add_argument("--rope_scaling_factor", type=float, default=1.0)
    parser.add_argument("--rope_scaling_type", type=str, default=None)
    parser.add_argument("--input_field", type=str, default="text")
    parser.add_argument("--model_name", type=str, default="llama-7b")
    parser.add_argument("--model_name_or_path", type=str, default="/home/v-daweizhu/teamdrive/model/llama-7b")
    parser.add_argument("--dataset_name", type=str, default="scrolls-gov_report")
    parser.add_argument("--dataset_name_or_path", type=str, default="")
    parser.add_argument("--gpu_device", type=int, default=0)
    args = parser.parse_args()

    model_name_or_path = args.model_name_or_path
    paddle.device.set_device(f"gpu:{args.gpu_device}")  # honor --gpu_device instead of a hard-coded device
    print(f"Model loaded on {paddle.device.get_device()}")

    scaled_max_position_embeddings = int(args.model_max_position_embeddings * args.rope_scaling_factor)

    # Load Model
    model_config = AutoConfig.from_pretrained(args.model_name_or_path)
    model_config.max_position_embeddings = scaled_max_position_embeddings

    # RoPE interpolation
    if args.rope_scaling_type is not None:
        model_config.rope_scaling = {"type": args.rope_scaling_type, "factor": args.rope_scaling_factor}
        if args.rope_scaling_type == "yarn":
            model_config.rope_scaling["original_max_position_embeddings"] = args.model_max_position_embeddings
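            # Note (assumption): YaRN needs the pre-scaling context length in addition
            # to the scaling factor, so the original window is recorded here; the
            # "linear"/"ntk" paths above only set "type" and "factor".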

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        config=model_config,
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        padding_side="right",
        use_fast=True,
    )

    if isinstance(tokenizer, LlamaTokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.unk_token_id = tokenizer.eos_token_id

    if "scrolls" in args.dataset_name:
        args.input_field = "input"
    elif "pile" in args.dataset_name:
        args.input_field = "text"
    elif "proof" in args.dataset_name:
        args.input_field = "text"

    input_texts = load_dataset("json", data_files=args.dataset_name_or_path, splits="train")

    def tokenize_filter(example):
        tokenized = tokenizer(
            example[args.input_field],
            add_special_tokens=False,
            padding=True,
            truncation=True,
            max_length=args.max_input_tokens - 1,  # leave room for <BOS> token to be added
            return_attention_mask=True,
            return_dict=True,
        )
        tokenized["tokenized_len"] = len(tokenized["input_ids"])
        return tokenized

    def tokenize_batch(examples):
        input_texts = {"input_ids": [], "position_ids": [], "attention_mask": []}
        for example in examples:
            input_texts["input_ids"].append(example["input_ids"])
            # tokenize_filter does not request position ids, so fall back to a
            # contiguous range when the field is missing
            input_texts["position_ids"].append(example.get("position_ids", list(range(len(example["input_ids"])))))
            input_texts["attention_mask"].append(example["attention_mask"])

        return input_texts

    input_texts = input_texts.map(fn=tokenize_filter, lazy=False)
    if args.min_input_tokens:
        # keep only examples that are long enough for the requested evaluation length
        input_texts = input_texts.filter(lambda x: x["tokenized_len"] >= args.min_input_tokens)

    if args.eval_nums:
        input_texts = input_texts[: args.eval_nums]

    input_texts = MapDataset(input_texts)
    input_texts = input_texts.map(fn=tokenize_batch, batched=True, lazy=False)

    context_window_size = args.window_length_list

    for ctx_size in context_window_size:
        # if args.truncate is True, we calculate the ppl on the whole (truncated) input text;
        # otherwise, we calculate the ppl with a sliding window
        ppl = compute_perplexity(
            encodings=input_texts,
            model=model,
            tokenizer=tokenizer,
            add_start_token=True,
            max_length=ctx_size,
            sliding_window_step=args.sliding_window_step,
            truncate=args.truncate,
        )["mean_perplexity"]
        print(f"model: {args.model_name_or_path}; context window size: {ctx_size}; ppl: {ppl}")

if __name__ == "__main__":
main()
22 changes: 22 additions & 0 deletions llm/eval_ppl.sh
@@ -0,0 +1,22 @@
model_prefix="facebook/llama-7b"
paddle_prefix=/home/whf/PaddleNLP_old/llm/checkpoints/pose_new_ckpts/yarn/0422_bs8/checkpoint-6000
data_prefix=/home/whf/PoSE/PoSE-Datasets


for factor in 8
do
    python eval_ppl.py \
        --model_name_or_path ${model_prefix} \
        --model_name llama-7b-2k-$((factor*2))k-yarn \
        --gpu_device 0 \
        --rope_scaling_type ntk \
        --rope_scaling_factor ${factor} \
        --model_max_position_embeddings 2048 \
        --max_input_tokens 16384 \
        --min_input_tokens 16384 \
        --window_length_list 2048 4096 8192 16384 \
        --truncate \
        --batch_size 1 \
        --dataset_name scrolls-gov_report \
        --dataset_name_or_path ${data_prefix}/scrolls/gov_report/test_long.jsonl
done
2 changes: 1 addition & 1 deletion llm/finetune_generation.py
@@ -678,4 +678,4 @@ def compute_metrics_do_generation(eval_preds):


if __name__ == "__main__":
main()
main()