Skip to content

Commit

Permalink
[Feature] Add lmdeploy tis python backend model (#1014)
Browse files Browse the repository at this point in the history
* add lmdeploy tis python backend model

* fix pr check

* update
  • Loading branch information
ispobock committed Apr 23, 2024
1 parent 8fe7b27 commit 81d0e4d
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 0 deletions.
41 changes: 41 additions & 0 deletions configs/eval_internlm_chat_lmdeploy_tis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from mmengine.config import read_base
from opencompass.models.lmdeploy_tis import LmdeployTisModel

with read_base():
    # choose a list of datasets
    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    from .datasets.race.race_gen_69ee4f import race_datasets
    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer

# Concatenate the value of every name in scope that ends with `_datasets`
# into one flat list. The generator form is deliberate: it binds no extra
# top-level names (a plain `for` loop would leave `k` and `v` behind).
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

# Chat meta template: each round is a HUMAN turn followed by a BOT turn,
# both wrapped in <|im_start|> / <|im_end|> markers. The BOT turn is the
# one flagged with `generate`.
meta_template = {
    'round': [
        {
            'role': 'HUMAN',
            'begin': '<|im_start|>user\n',
            'end': '<|im_end|>\n',
        },
        {
            'role': 'BOT',
            'begin': '<|im_start|>assistant\n',
            'end': '<|im_end|>\n',
            'generate': True,
        },
    ],
    'eos_token_id': 92542,
}

# Single model entry: an LMDeploy Triton Inference Server endpoint reached
# at `tis_addr`, evaluated with the chat meta template defined in this file.
models = [
    {
        'type': LmdeployTisModel,
        'abbr': 'internlm-chat-20b-lmdeploy-tis',
        'path': 'internlm/internlm-chat-20b',
        'tis_addr': '0.0.0.0:33337',
        'max_out_len': 100,
        'max_seq_len': 2048,
        'batch_size': 8,
        'meta_template': meta_template,
        'run_cfg': {'num_gpus': 1, 'num_procs': 1},
        'end_str': '<|im_end|>',
    },
]
1 change: 1 addition & 0 deletions opencompass/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .lightllm_api import LightllmAPI # noqa: F401
from .llama2 import Llama2, Llama2Chat # noqa: F401, F403
from .lmdeploy_pytorch import LmdeployPytorchModel # noqa: F401
from .lmdeploy_tis import LmdeployTisModel # noqa: F401
from .minimax_api import MiniMax # noqa: F401
from .mistral_api import Mistral # noqa: F401
from .mixtral import Mixtral # noqa: F401
Expand Down
200 changes: 200 additions & 0 deletions opencompass/models/lmdeploy_tis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from queue import Queue
from typing import Dict, List, Optional, Union

import numpy as np

from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


def valid_str(string, coding='utf-8'):
    """Strip replacement-character bytes from ``string``.

    The text is encoded with ``coding``, every occurrence of the byte
    sequence ``EF BF BD`` (the UTF-8 encoding of U+FFFD) is removed, and the
    remainder is decoded with ``coding`` again, ignoring undecodable bytes.
    """
    replacement_bytes = b'\xef\xbf\xbd'
    encoded = bytes(string, coding)
    cleaned = encoded.replace(replacement_bytes, b'')
    return cleaned.decode(encoding=coding, errors='ignore')


def prepare_tensor(name, input_tensor):
    """Build a grpcclient ``InferInput`` named ``name`` from a numpy array.

    The input's shape and Triton dtype are derived from ``input_tensor``,
    whose contents are then attached to the returned object.
    """
    import tritonclient.grpc as grpcclient
    from tritonclient.utils import np_to_triton_dtype

    shape = list(input_tensor.shape)
    triton_dtype = np_to_triton_dtype(input_tensor.dtype)
    infer_input = grpcclient.InferInput(name, shape, triton_dtype)
    infer_input.set_data_from_numpy(input_tensor)
    return infer_input


def stream_callback(que, result, error):
    """Push the ``(result, error)`` pair handed over by the triton client
    onto ``que``."""
    outcome = (result, error)
    que.put(outcome)


class LmdeployTisModel(BaseModel):
    """Model wrapper for LMDeploy Python Backend Triton Inference Server gRPC
    API.

    Args:
        path (str): Model name or path. It is forwarded to the base class
            and used to build the LMDeploy ``Tokenizer`` that backs
            :meth:`get_token_len`.
        tis_addr (str): The address (ip:port format) of turbomind's
            triton inference server.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not
            exceed this value. Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions. When it contains an
            ``eos_token_id`` key, that value is stored on the instance.
        end_str (str, optional): If set, every response is cut at the first
            occurrence of this string. Defaults to None.
    """

    is_api: bool = True

    def __init__(self,
                 path: str,
                 tis_addr: str = '0.0.0.0:33337',
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 end_str: Optional[str] = None):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)
        # Imported lazily so the module can be imported without lmdeploy.
        from lmdeploy.tokenizer import Tokenizer

        self.logger = get_logger()
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']
        self.tis_addr = tis_addr
        self.tokenizer = Tokenizer(path)
        self.end_str = end_str

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
        temperature: float = 1.0,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Each input is sent to the server from its own worker thread.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
                Defaults to 512.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 1.0.

        Returns:
            List[str]: A list of generated strings, in input order.
        """
        num_inputs = len(inputs)
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * num_inputs,
                             [temperature] * num_inputs,
                             [self.end_str] * num_inputs))
        return results

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        ``token_bucket`` is not set anywhere in this class, so when the
        attribute is absent the call returns ``None`` immediately instead of
        raising ``AttributeError``.
        """
        token_bucket = getattr(self, 'token_bucket', None)
        if token_bucket is None:
            return None
        return token_bucket.get_token()

    def get_token_len(self, prompt: str) -> int:
        """Return the number of tokens the tokenizer produces for
        ``prompt``."""
        input_ids = self.tokenizer.encode(prompt)
        return len(input_ids)

    def _call_triton_server(self, prompt, tis_addr, session_id,
                            request_output_len, temperature, res_que):
        """Send one streaming inference request and queue its responses.

        Args:
            prompt (str): The prompt text; sent UTF-8 encoded.
            tis_addr (str): Address of the triton inference server.
            session_id (int): Value passed as the request's sequence id.
            request_output_len (int): Value of the ``max_tokens`` input.
            temperature (float): Value of the ``temperature`` input.
            res_que (Queue): Receives ``(result, error)`` tuples from the
                stream callback, followed by a ``None`` sentinel.
        """
        import tritonclient.grpc as grpcclient

        with grpcclient.InferenceServerClient(tis_addr) as client:
            inputs = [
                prepare_tensor('prompt',
                               np.array([prompt.encode()], dtype=np.object_)),
                prepare_tensor('max_tokens',
                               np.array([request_output_len], dtype=np.int32)),
                # np.float64 is what the removed `np.float_` alias pointed
                # to; spelling it out keeps this working on NumPy >= 2.0.
                prepare_tensor('temperature',
                               np.array([temperature], dtype=np.float64)),
                prepare_tensor('top_p', np.array([1.0], dtype=np.float64)),
                prepare_tensor('top_k', np.array([1], dtype=np.int32)),
                prepare_tensor('ignore_eos', np.array([False],
                                                      dtype=np.bool_)),
                prepare_tensor('stream', np.array([True], dtype=np.bool_)),
            ]

            # async_stream
            client.start_stream(partial(stream_callback, res_que))
            client.async_stream_infer('lmdeploy_model',
                                      inputs,
                                      sequence_id=session_id,
                                      sequence_start=True,
                                      sequence_end=True)

        # The sentinel is queued only after the client context has exited.
        # NOTE(review): this relies on the client's close draining the
        # stream before returning -- confirm against the tritonclient docs.
        res_que.put(None)
        return

    def _process_result(self, que):
        """Drain ``que`` up to the ``None`` sentinel and return the joined
        text.

        Items that carry an error are logged and skipped; the ``response``
        output of every other item is decoded and appended.
        """
        pieces = []
        while True:
            item = que.get()
            if item is None:
                return ''.join(pieces)
            result, err = item
            if err is not None:
                self.logger.error(err)
                continue
            pieces.append(result.as_numpy('response').item().decode())

    def _generate(self,
                  prompt: PromptType,
                  max_out_len: int,
                  temperature: float,
                  end_str: Optional[str] = None) -> str:
        """Generate a result for a single input.

        Args:
            prompt (PromptType): The prompt; only plain strings are
                supported by this backend.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.
            end_str (str, optional): If set, the response is cut at the
                first occurrence of this string.

        Returns:
            str: The generated string.
        """
        assert type(
            prompt
        ) is str, 'We only support string for LMDeploy Python Backend TIS API'

        res_que = Queue()

        # The calling thread's ident doubles as the session id.
        # `current_thread` replaces the deprecated `currentThread` alias.
        self._call_triton_server(prompt=prompt,
                                 tis_addr=self.tis_addr,
                                 session_id=threading.current_thread().ident,
                                 request_output_len=max_out_len,
                                 temperature=temperature,
                                 res_que=res_que)
        text = self._process_result(res_que)
        response = valid_str(text)
        if end_str:
            response = response.split(end_str)[0]
        return response

0 comments on commit 81d0e4d

Please sign in to comment.