[Refactor] refactor prompt viewer & output reference #327

Draft · wants to merge 1 commit into base: main
142 changes: 68 additions & 74 deletions tools/prompt_viewer.py
@@ -4,8 +4,7 @@

from mmengine.config import Config, ConfigDict

from opencompass.openicl.icl_inferencer import (CLPInferencer, GenInferencer,
PPLInferencer)
from opencompass.openicl.icl_inferencer import GenInferencer, PPLInferencer
from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
from opencompass.utils import (Menu, build_dataset_from_cfg,
build_model_from_cfg, dataset_abbr_from_cfg,
@@ -45,6 +44,24 @@ def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
return dataset2cfg


def get_prompt(ice_idx, prompt_func, max_seq_len, model):
prompt = prompt_func(ice_idx)
num_token = model.get_token_len_from_template(prompt)
[Review comment from a Collaborator] Doesn't work when no model is provided

if max_seq_len is None:
print(f'Number of tokens: {num_token}')
return prompt
while len(ice_idx) > 0 and num_token > max_seq_len:
num_ice = len(ice_idx)
old_num_token = num_token
ice_idx = ice_idx[:-1]
prompt = prompt_func(ice_idx)
num_token = model.get_token_len_from_template(prompt)
print(f'Truncating ice {num_ice} -> {num_ice - 1}',
f'Number of tokens: {old_num_token} -> {num_token}')
print(f'Number of tokens: {num_token}')
return prompt
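
As the Collaborator's review comment above points out, get_prompt calls model.get_token_len_from_template before it ever checks whether a model was supplied, so it fails when model is None. A minimal guard, sketched here by the editor rather than taken from this PR, could look like:

def get_prompt(ice_idx, prompt_func, max_seq_len, model):
    """Drop trailing in-context examples until the prompt fits max_seq_len."""
    prompt = prompt_func(ice_idx)
    if model is None:
        # No model means no tokenizer, so skip token counting and truncation.
        return prompt
    num_token = model.get_token_len_from_template(prompt)
    if max_seq_len is None:
        print(f'Number of tokens: {num_token}')
        return prompt
    while len(ice_idx) > 0 and num_token > max_seq_len:
        num_ice, old_num_token = len(ice_idx), num_token
        # Drop the last in-context example and re-render the prompt.
        ice_idx = ice_idx[:-1]
        prompt = prompt_func(ice_idx)
        num_token = model.get_token_len_from_template(prompt)
        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
              f'Number of tokens: {old_num_token} -> {num_token}')
    print(f'Number of tokens: {num_token}')
    return prompt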


def print_prompts(model_cfg, dataset_cfg, count=1):
# TODO: A really dirty method that copies code from PPLInferencer and
# GenInferencer. In the future, the prompt extraction code should be
@@ -84,92 +101,70 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
'Only PPLInferencer and GenInferencer are supported'

ice = retriever.generate_ice(ice_idx_list[0], ice_template=ice_template)
print('=' * 100)
print('Full in-context example:')
print('-' * 100)
print(ice)
print('=' * 100)
for idx in range(min(count, len(ice_idx_list))):
print('=' * 100)
print(f'Data Item #{idx}:')
if infer_cfg.inferencer.type == PPLInferencer:
labels = retriever.get_labels(ice_template=ice_template,
prompt_template=prompt_template)
ice = [
retriever.generate_ice(ice_idx_list[_idx],
ice_template=ice_template)
for _idx in range(len(ice_idx_list))
]
print('-' * 100)
print('ICE Template:')
print('-' * 100)
print(ice[0])
print('-' * 100)
for label in labels:
prompt = retriever.generate_label_prompt(
idx,
ice[idx],
label,
ice_template=ice_template,
prompt_template=prompt_template,
remain_sep=None)
if max_seq_len is not None:
prompt_token_num = model.get_token_len_from_template(
prompt)
while len(ice_idx_list[idx]
) > 0 and prompt_token_num > max_seq_len:
num_ice = len(ice_idx_list[idx])
print(f'Truncating ice {num_ice} -> {num_ice - 1}',
f'Number of tokens: {prompt_token_num} -> ...')
ice_idx_list[idx] = ice_idx_list[idx][:-1]
ice[idx] = retriever.generate_ice(
ice_idx_list[idx], ice_template=ice_template)
prompt = retriever.generate_label_prompt(
idx,
ice[idx],
label,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = model.get_token_len_from_template(
prompt)
print(f'Number of tokens: {prompt_token_num}')
if model is not None:
prompt = model.parse_template(prompt, mode='ppl')
print('-' * 100)
print(f'Label: {label}')
print('Sample prompt:')
print('-' * 100)
print(prompt)
print('-' * 100)
elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
ice_idx = ice_idx_list[idx]
ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task(
idx,
ice,
gen_field_replace_token=infer_cfg.inferencer.get(
'gen_field_replace_token', ''),
ice_template=ice_template,
prompt_template=prompt_template)
if max_seq_len is not None:
prompt_token_num = model.get_token_len_from_template(prompt)
while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
num_ice = len(ice_idx)
print(f'Truncating ice {num_ice} -> {num_ice - 1}',
f'Number of tokens: {prompt_token_num} -> ...')
ice_idx = ice_idx[:-1]

def prompt_func(ice_idx):
ice = retriever.generate_ice(ice_idx,
ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task(
return retriever.generate_label_prompt(
idx,
ice,
gen_field_replace_token=infer_cfg.inferencer.get(
'gen_field_replace_token', ''),
label,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = model.get_token_len_from_template(
prompt)
print(f'Number of tokens: {prompt_token_num}')
prompt_template=prompt_template,
remain_sep=None)

print('-' * 100)
print(f'Label: {label}')
prompt = get_prompt(ice_idx_list[idx], prompt_func,
max_seq_len, model)
if model is not None:
prompt = model.parse_template(prompt, mode='ppl')
print('Prompt:')
print('-' * 61)
print(prompt)
print('-' * 100)
elif infer_cfg.inferencer.type == GenInferencer:

def prompt_func(ice_idx):
ice = retriever.generate_ice(ice_idx,
ice_template=ice_template)
return retriever.generate_prompt_for_generate_task(
idx,
ice,
gen_field_replace_token=infer_cfg.inferencer.get(
'gen_field_replace_token', ''),
ice_template=ice_template,
prompt_template=prompt_template)

prompt = get_prompt(ice_idx_list[idx], prompt_func, max_seq_len,
model)
if model is not None:
prompt = model.parse_template(prompt, mode='gen')
print('-' * 100)
print('Sample prompt:')
print('-' * 100)
print('Prompt:')
print('-' * 61)
print(prompt)
print('-' * 100)
else:
raise NotImplementedError

reference = dataset.test[idx][dataset.reader.output_column]
print('Reference:')
print('-' * 61)
print(reference)


def main():
@@ -215,7 +210,6 @@ def main():
print('=' * 64, '[BEGIN]', '=' * 64)
print(f'[MODEL]: {model_abbr}')
print(f'[DATASET]: {dataset_abbr}')
print('---')
print_prompts(model_cfg, dataset_cfg, args.count)
print('=' * 65, '[END]', '=' * 65)
print()
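
For orientation, the refactor gives both inferencer branches the same shape: build a prompt_func closure over the retriever, then delegate length-aware truncation to get_prompt. A condensed, hypothetical rendering of the GenInferencer branch follows (print_one and its parameter list are the editor's names, not part of the PR; the retriever, templates, and config objects are those from the diff):

def print_one(idx, ice_idx_list, retriever, infer_cfg, ice_template,
              prompt_template, max_seq_len, model):
    # Closure capturing the retriever and templates; get_prompt only ever
    # sees "fewer examples in, prompt out".
    def prompt_func(ice_idx):
        ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
        return retriever.generate_prompt_for_generate_task(
            idx,
            ice,
            gen_field_replace_token=infer_cfg.inferencer.get(
                'gen_field_replace_token', ''),
            ice_template=ice_template,
            prompt_template=prompt_template)

    prompt = get_prompt(ice_idx_list[idx], prompt_func, max_seq_len, model)
    if model is not None:
        prompt = model.parse_template(prompt, mode='gen')
    print(prompt)

The PPLInferencer branch differs only in that its prompt_func wraps retriever.generate_label_prompt and the whole sequence is repeated once per label.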