Integrate pytorch poc python api #490

Open · wants to merge 11 commits into main

Changes from 8 commits
42 changes: 42 additions & 0 deletions configs/eval_internlm_chat_pytorch_poc.py
@@ -0,0 +1,42 @@
from mmengine.config import read_base
from opencompass.models import PytorchModel


with read_base():
    # choose a list of datasets
    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_6dc406 import WSC_datasets
    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    from .datasets.race.race_gen_69ee4f import race_datasets
    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])


meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
        dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ],
    eos_token_id=103028)

models = [
    dict(
        type=PytorchModel,
        abbr='internlm-chat-20b-pytorch-poc',
        # path = '/mnt/140/InternLM/internlm-chat-7b',
        path='/mnt/140/InternLM/20B/internlm-20b-chat',
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        concurrency=8,
        meta_template=meta_template,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
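Assuming the standard OpenCompass entry point, this config would presumably be launched as python run.py configs/eval_internlm_chat_pytorch_poc.py, with the internal model path above replaced by a locally available checkpoint.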
41 changes: 41 additions & 0 deletions configs/eval_internlm_chat_pytorch_poc_w8a8.py
@@ -0,0 +1,41 @@
from mmengine.config import read_base
from opencompass.models import PytorchModel


with read_base():
    # choose a list of datasets
    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_6dc406 import WSC_datasets
    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    from .datasets.race.race_gen_69ee4f import race_datasets
    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])


meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
        dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ],
    eos_token_id=103028)

models = [
    dict(
        type=PytorchModel,
        abbr='internlm-chat-7b-pytorch-poc-w8a8',
        path='/nvme/caoweihan/projects/lmdeploy/work_dir',
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        concurrency=8,
        meta_template=meta_template,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
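Here w8a8 presumably denotes an 8-bit weight / 8-bit activation quantized variant of the model, so the path above points at a quantization output directory rather than an original checkpoint.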
120 changes: 120 additions & 0 deletions configs/eval_llama2_chat_pytorch_poc.py
@@ -0,0 +1,120 @@
from mmengine.config import read_base
from opencompass.models import PytorchModel


with read_base():
    # choose a list of datasets
    # from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    # from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
    # from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_6dc406 import WSC_datasets
    # from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    # from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
    # from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    # from .datasets.race.race_gen_69ee4f import race_datasets
    # from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])


meta_template = dict(
    round=[
        dict(role="HUMAN", begin='[INST] ', end=' [/INST] '),
        dict(role="BOT", begin="", end='', generate=True),
    ],
)

# config for internlm-chat-7b
# models = [
# dict(
# type=TurboMindModel,
# abbr='internlm-chat-7b-turbomind',
# path="./turbomind",
# max_out_len=100,
# max_seq_len=2048,
# batch_size=32,
# concurrency=32,
# meta_template=meta_template,
# run_cfg=dict(num_gpus=1, num_procs=1),
# )
# ]

# config for internlm-chat-7b-w4 model
# models = [
# dict(
# type=TurboMindModel,
# abbr='internlm-chat-7b-w4-turbomind',
# path="./turbomind",
# max_out_len=100,
# max_seq_len=2048,
# batch_size=32,
# concurrency=32,
# meta_template=meta_template,
# run_cfg=dict(num_gpus=1, num_procs=1),
# )
# ]

# config for internlm-chat-7b-w4kv8 model
# models = [
# dict(
# type=TurboMindModel,
# abbr='internlm-chat-7b-w4kv8-turbomind',
# path="./turbomind",
# max_out_len=100,
# max_seq_len=2048,
# batch_size=32,
# concurrency=32,
# meta_template=meta_template,
# run_cfg=dict(num_gpus=1, num_procs=1),
# )
# ]

# config for internlm-chat-20b
# models = [
# dict(
# type=TurboMindModel,
# abbr='internlm-chat-20b-turbomind',
# path="./turbomind",
# max_out_len=100,
# max_seq_len=2048,
# batch_size=8,
# concurrency=8,
# meta_template=meta_template,
# run_cfg=dict(num_gpus=1, num_procs=1),
# )
# ]
Collaborator: Also remove the useless configs.


# config for llama2-chat-7b model
models = [
    dict(
        type=PytorchModel,
        abbr='llama2-chat-7b-pytorch-poc',
        path="/mnt/142/gaojianfei/quantization/smooth_llama_chat_absmax",
Collaborator: Please avoid the internal specific path.

        # path='/mnt/140/InternLM/20B/internlm-20b-chat',
        max_out_len=100,
        max_seq_len=2048,
        batch_size=1,
        concurrency=1,
        meta_template=meta_template,
        run_cfg=dict(num_gpus=1, num_procs=1),
        # stop_words=[103027, 103028],
        # w8a8=True
    )
]

# config for internlm-chat-20b-w4kv8 model
# models = [
# dict(
# type=TurboMindModel,
# abbr='internlm-chat-20b-w4kv8-turbomind',
# path="./turbomind",
# max_out_len=100,
# max_seq_len=2048,
# batch_size=16,
# concurrency=16,
# meta_template=meta_template,
# run_cfg=dict(num_gpus=1, num_procs=1),
# )
# ]
1 change: 1 addition & 0 deletions opencompass/models/__init__.py
@@ -7,3 +7,4 @@
from .intern_model import InternLM # noqa: F401, F403
from .llama2 import Llama2, Llama2Chat # noqa: F401, F403
from .openai_api import OpenAI # noqa: F401
from .pytorch_poc import PytorchModel
153 changes: 153 additions & 0 deletions opencompass/models/pytorch_poc.py
@@ -0,0 +1,153 @@
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

from lmdeploy.pytorch_poc import engine as tm
from lmdeploy.pytorch_poc.messages import SamplingParam
from transformers import AutoTokenizer

from opencompass.models.base import BaseModel
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


def valid_str(string, coding='utf-8'):
"""decode text according to its encoding type."""
invalid_chars = [b'\xef\xbf\xbd']
bstr = bytes(string, coding)
for invalid_char in invalid_chars:
bstr = bstr.replace(invalid_char, b'')
ret = bstr.decode(encoding=coding, errors='ignore')
return ret


class PytorchModel(BaseModel):
Collaborator: How about PytorchTurbomindModel?

Author: Sorry for the ambiguity. With TurboMind, lmdeploy combines C++ and Python to run inference, whereas this PyTorch proof-of-concept takes a more streamlined approach and performs inference purely in Python. For using TurboMind with OpenCompass, please refer to PR #484.

"""Model wrapper for TurboMind Python API.

Args:
path (str): path of the turbomind model
max_seq_len (int): The maximum allowed sequence length of a model.
Note that the length of prompt + generated tokens shall not exceed
this value. Defaults to 2048.
meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or
wrapping of any meta instructions.
"""

    def __init__(
        self,
        path: str,
        concurrency: int = 8,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
    ):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)
        self.logger = get_logger()
        self.tokenizer = AutoTokenizer.from_pretrained(path,
                                                       trust_remote_code=True)
        tm_model = tm.Engine(path)
        self.generators = [
            tm_model.create_instance() for i in range(concurrency)
        ]
        self.generator_ids = [i + 1 for i in range(concurrency)]

    def generate(
        self,
        inputs: List[str],
        max_out_len: int = 512,
        temperature: float = 1.0,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of prompts.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 1.0.

        Returns:
            List[str]: A list of generated strings.
        """
        assert isinstance(
            inputs, List), f'List(str) is expected, but got {type(inputs)}'

        # split inputs into batches of at most `concurrency` prompts
        batch_size = len(self.generators)
        batch_inputs = [
            inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)
        ]

        results = []
        for batch_input in batch_inputs:
            with ThreadPoolExecutor() as executor:
                _results = list(
                    executor.map(self._generate,
                                 self.generators[:len(batch_input)],
                                 self.generator_ids[:len(batch_input)],
                                 batch_input, [max_out_len] * len(batch_input),
                                 [temperature] * len(batch_input)))
            results += _results
        return results

    def get_token_len(self, prompt: str) -> int:
        input_ids = self.tokenizer.encode(prompt)
        return len(input_ids)

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def _generate(self, generator, session_id, prompt: PromptType,
                  max_out_len: int, temperature: float) -> str:
        """Generate results given a single prompt.

        Args:
            prompt (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string.
        """
        assert type(
            prompt) is str, 'We only support string for the PyTorch PoC API'
        input_ids = self.tokenizer.encode(prompt)
        sampling_param = SamplingParam(top_k=40,
                                       top_p=0.8,
                                       temperature=temperature,
                                       repetition_penalty=1.0,
                                       ignore_eos=False,
                                       random_seed=random.getrandbits(64),
                                       stop_words=[self.eos_token_id])
        response_size = 0

        for outputs in generator.stream_infer(
                session_id=session_id,
                # input_ids=input_ids,
                prompt_token_ids=input_ids,
                request_output_len=max_out_len,
                step=0,
                sampling_param=sampling_param):
            status, res, tokens = outputs
            # decode the full response so far and track how much of it is new
            response_all = self.tokenizer.decode(res)
            response_cur = response_all[response_size:]
            response_all = valid_str(response_all)
            response_size += len(response_cur)
        if hasattr(generator, 'end'):
            generator.end(session_id)
        return response_all
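For reference, here is a minimal standalone usage sketch of the PytorchModel wrapper added above. The checkpoint path and prompt are placeholders, the meta_template simply mirrors the InternLM-chat config earlier in this PR, and the sketch assumes lmdeploy's pytorch_poc engine and the checkpoint's tokenizer are available locally:

from opencompass.models import PytorchModel

# Meta template copied from the InternLM-chat config above; the class passes
# self.eos_token_id to SamplingParam as a stop word in _generate.
meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
        dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ],
    eos_token_id=103028)

# Placeholder path: any InternLM-chat checkpoint readable by AutoTokenizer
# and the lmdeploy PyTorch PoC engine.
model = PytorchModel(
    path='/path/to/internlm-chat-7b',
    concurrency=2,
    max_seq_len=2048,
    meta_template=meta_template,
)

# generate() splits the prompts into concurrency-sized batches and fans each
# batch out over the engine instances with a ThreadPoolExecutor.
outputs = model.generate(['Hello, who are you?'], max_out_len=64)
print(outputs[0])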