-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
22 lines (20 loc) · 822 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Sums ``numel()`` over every parameter yielded by
    ``model.named_parameters()`` and, separately, over those whose
    ``requires_grad`` is true, then prints both totals and the trainable
    share as a percentage.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs; each parameter must provide
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A model with no parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} "
        f"|| trainable %: {trainable_pct}"
    )
def generate_prompt(data_point):
    """Format a question/answer record as a two-turn chat prompt.

    Args:
        data_point: Mapping with ``"question"`` and ``"answer"`` entries;
            both are rendered with ``str()`` formatting.

    Returns:
        The text ``<human>: <question>``, a newline, then
        ``<assistant>: <answer>``, with any trailing whitespace removed.
    """
    turns = (
        f"<human>: {data_point['question']}",
        f"<assistant>: {data_point['answer']}",
    )
    # The prompt always begins with "<", so stripping only ever removes
    # trailing whitespace (e.g. a newline at the end of the answer).
    return "\n".join(turns).rstrip()
def generate_and_tokenize_prompt(data_point, tokenizer):
    """Build the chat prompt for ``data_point`` and tokenize it.

    Args:
        data_point: Record passed unchanged to ``generate_prompt``.
        tokenizer: Callable tokenizer, invoked on the prompt text with
            ``padding=True`` and ``truncation=True`` (presumably a Hugging
            Face tokenizer -- confirm against the caller).

    Returns:
        Whatever ``tokenizer`` returns for the prompt text.
    """
    prompt_text = generate_prompt(data_point)
    # NOTE(review): for a single un-batched string, padding=True presumably
    # pads only to the longest sequence in the batch, i.e. adds no padding --
    # confirm whether fixed-length padding was intended.
    return tokenizer(prompt_text, padding=True, truncation=True)