FEAT: Support TensorRT-LLM backend #646

Draft · wants to merge 4 commits into main
13 changes: 13 additions & 0 deletions xinference/model/llm/trtllm/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
356 changes: 356 additions & 0 deletions xinference/model/llm/trtllm/core.py
@@ -0,0 +1,356 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import time
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, TypedDict, Union

import numpy as np

from ....types import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionMessage,
    Completion,
    CompletionChoice,
    CompletionChunk,
    CompletionUsage,
)
from ..core import LLM
from ..llm_family import BUILTIN_LLM_FAMILIES
from ..utils import ChatModelMixin

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from torch import Tensor

try:
    import tensorrt_llm  # noqa: F401
    import torch
    from tensorrt_llm.quantization import QuantMode
    from tensorrt_llm.runtime import ModelConfig, SamplingConfig

    TRTLLM_INSTALLED = True
except ImportError:
    TRTLLM_INSTALLED = False


class TRTModelConfig(TypedDict, total=False):
    tokenizer_dir: str


class TRTGenerateConfig(TypedDict, total=False):
    max_tokens: int
    end_id: int
    pad_id: int
    num_beams: int
    temperature: float
    top_k: int
    top_p: float
    length_penalty: float
    repetition_penalty: float
    min_length: int
    presence_penalty: float
    use_beam_hyps: bool

    stream: bool
    stream_interval: int


MODEL_SPECIAL_TOKENS = {
    "llama-2": {"EOS_TOKEN": 2, "PAD_TOKEN": 2},
    "llama-2-chat": {"EOS_TOKEN": 2, "PAD_TOKEN": 2},
}
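# For llama-2 models the tokenizer's EOS token id is 2; the same id is
# reused as the pad token id here.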
MODEL_NAME_TO_FAMILY = dict(
    (family.model_name, family) for family in BUILTIN_LLM_FAMILIES
)


def read_config(config_path: Path):
    with open(config_path, "r") as f:
        config = json.load(f)
    use_gpt_attention_plugin = config["plugin_config"]["gpt_attention_plugin"]
    remove_input_padding = config["plugin_config"]["remove_input_padding"]
    dtype = config["builder_config"]["precision"]
    tp_size = config["builder_config"]["tensor_parallel"]
    pp_size = config["builder_config"]["pipeline_parallel"]
    world_size = tp_size * pp_size
    assert (
        world_size == tensorrt_llm.mpi_world_size()
    ), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"
    num_heads = config["builder_config"]["num_heads"] // tp_size
    hidden_size = config["builder_config"]["hidden_size"] // tp_size
    vocab_size = config["builder_config"]["vocab_size"]
    num_layers = config["builder_config"]["num_layers"]
    num_kv_heads = config["builder_config"].get("num_kv_heads", num_heads)
    paged_kv_cache = config["plugin_config"]["paged_kv_cache"]
    tokens_per_block = config["plugin_config"]["tokens_per_block"]
    quant_mode = QuantMode(config["builder_config"]["quant_mode"])
    if config["builder_config"].get("multi_query_mode", False):
        tensorrt_llm.logger.warning(
            "`multi_query_mode` config is deprecated. Please rebuild the engine."
        )
        num_kv_heads = 1
    num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
    use_custom_all_reduce = config["plugin_config"].get("use_custom_all_reduce", False)

    model_config = ModelConfig(
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        num_layers=num_layers,
        gpt_attention_plugin=use_gpt_attention_plugin,
        paged_kv_cache=paged_kv_cache,
        tokens_per_block=tokens_per_block,
        remove_input_padding=remove_input_padding,
        dtype=dtype,
        quant_mode=quant_mode,
        use_custom_all_reduce=use_custom_all_reduce,
    )

    return model_config, tp_size, pp_size, dtype


def get_engine_name(model, dtype, tp_size, pp_size, rank):
    if pp_size == 1:
        return "{}_{}_tp{}_rank{}.engine".format(model, dtype, tp_size, rank)
    return "{}_{}_tp{}_pp{}_rank{}.engine".format(model, dtype, tp_size, pp_size, rank)


class TRTModel(LLM):
    def __init__(
        self,
        model_uid: str,
        model_name: str,
        model_path: str,
        tokenizer_path: str,
    ):
        if model_name not in MODEL_SPECIAL_TOKENS:
            raise ValueError(
                f"Model name must be one of the following: {list(MODEL_SPECIAL_TOKENS)}"
            )
        self._model_uid: str = model_uid
        self._model_name: str = model_name
        self._model_path: str = model_path
        self._tokenizer_path: str = tokenizer_path
        self._model_config: Optional["ModelConfig"] = None
        self._decoder: Any = None
        self._tokenizer: Any = None

    def load(self):
        try:
            import tensorrt_llm
        except ImportError:
            error_message = "Failed to import module 'tensorrt_llm'"
            installation_guide = ["Please make sure 'tensorrt_llm' is installed. "]

            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
        try:
            from transformers import PreTrainedTokenizerFast
        except ImportError:
            error_message = "Failed to import module 'transformers'"
            installation_guide = [
                "Please make sure 'transformers' is installed. ",
                "You can install it by `pip install transformers`\n",
            ]
            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

        self._tokenizer = PreTrainedTokenizerFast.from_pretrained(self._tokenizer_path)
        engine_dir = Path(self._model_path)
        config_path = engine_dir / "config.json"
        model_config, tp_size, pp_size, dtype = read_config(config_path)
        logger.info(
            f"Loading {self._model_uid} with the following model config: {model_config}"
        )
        # TODO: support multiple GPUs
        runtime_mapping = tensorrt_llm.Mapping(1, 0, tp_size=tp_size, pp_size=pp_size)
        engine_name = get_engine_name("llama", dtype, tp_size, pp_size, 0)
        serialize_path = engine_dir / engine_name
        with open(serialize_path, "rb") as f:
            engine_buffer = f.read()
        self._model_config = model_config
        self._decoder = tensorrt_llm.runtime.GenerationSession(
            model_config, engine_buffer, runtime_mapping
        )
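    # The engine directory is expected to contain `config.json` plus the
    # serialized engine whose file name follows `get_engine_name`, e.g. a
    # single-GPU float16 build would be `llama_float16_tp1_rank0.engine`.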

    def _sanitize_generate_config(
        self,
        generate_config: Optional[Dict] = None,
    ) -> TRTGenerateConfig:
        if not generate_config:
            generate_config = {}

        sanitized = TRTGenerateConfig()
        default_eos_token = MODEL_SPECIAL_TOKENS[self._model_name]["EOS_TOKEN"]
        default_pad_token = MODEL_SPECIAL_TOKENS[self._model_name]["PAD_TOKEN"]
        sanitized.setdefault("end_id", generate_config.get("end_id", default_eos_token))
        sanitized.setdefault("pad_id", generate_config.get("pad_id", default_pad_token))

        sanitized.setdefault("max_tokens", generate_config.get("max_tokens", 512))
        sanitized.setdefault("num_beams", generate_config.get("num_beams", 1))
        sanitized.setdefault("temperature", generate_config.get("temperature", 1.0))
        sanitized.setdefault("top_k", generate_config.get("top_k", 1))
        sanitized.setdefault("top_p", generate_config.get("top_p", 0.0))
        sanitized.setdefault(
            "length_penalty", generate_config.get("length_penalty", 1.0)
        )
        sanitized.setdefault(
            "repetition_penalty", generate_config.get("repetition_penalty", 1.0)
        )
        sanitized.setdefault("min_length", generate_config.get("min_length", 1))
        sanitized.setdefault(
            "presence_penalty", generate_config.get("presence_penalty", 0.0)
        )
        sanitized.setdefault(
            "use_beam_hyps", generate_config.get("use_beam_hyps", True)
        )
        # Default to non-streaming so `stream` is always a real bool.
        sanitized.setdefault("stream", generate_config.get("stream", False))
        sanitized.setdefault(
            "stream_interval", generate_config.get("stream_interval", 5)
        )
        return sanitized
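    # Example: for "llama-2-chat", _sanitize_generate_config({"temperature": 0.8})
    # yields end_id=2, pad_id=2, max_tokens=512, num_beams=1, temperature=0.8,
    # top_k=1, top_p=0.0, stream=False, stream_interval=5, plus the remaining
    # defaults above.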

    def _gen_completion_chunk(
        self, out_ids: "Tensor", num_beams: int, out_start: int, out_end: int
    ):
        choices = []
        for beam in range(num_beams):
            ids = out_ids[0][beam][out_start:out_end].tolist()
            out_text = self._tokenizer.decode(ids)
            completion_choice = CompletionChoice(
                text=out_text, index=beam, logprobs=None, finish_reason=None
            )
            choices.append(completion_choice)
        completion_chunk = CompletionChunk(
            id=str(uuid.uuid1()),
            object="text_completion",
            created=int(time.time()),
            model=self._model_uid,
            choices=choices,
        )
        return completion_chunk

    def generate(
        self, prompt: str, generate_config: Optional[Dict] = None
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        if generate_config is None:
            generate_config = dict()
        sanitized_generate_config = self._sanitize_generate_config(generate_config)
        max_tokens = sanitized_generate_config.pop("max_tokens")
        stream = sanitized_generate_config.pop("stream")
        stream_interval = sanitized_generate_config.pop("stream_interval")
        num_beams = sanitized_generate_config.pop("num_beams")
        sampling_config = SamplingConfig(**sanitized_generate_config)

        input_tokens = [self._tokenizer.encode(prompt, add_special_tokens=False)]
        input_lengths = torch.tensor(
            [len(x) for x in input_tokens], dtype=torch.int32, device="cuda"
        )
        if self._model_config.remove_input_padding:
            input_ids = np.concatenate(input_tokens)
            input_ids = torch.tensor(
                input_ids, dtype=torch.int32, device="cuda"
            ).unsqueeze(0)
        else:
            input_ids = torch.nested.to_padded_tensor(
                torch.nested.nested_tensor(input_tokens, dtype=torch.int32),
                # SamplingConfig is not subscriptable; use attribute access.
                sampling_config.end_id,
            ).cuda()

        max_input_length = torch.max(input_lengths).item()
        self._decoder.setup(
            input_lengths.size(0),
            max_input_length,
            max_tokens,
            num_beams,
        )

        output_gen_ids = self._decoder.decode(
            input_ids, input_lengths, sampling_config, streaming=stream
        )
        out_start = len(input_lengths)
        if stream:
            # The streaming logic lives in an inner generator: a bare `yield`
            # here would turn `generate` itself into a generator function and
            # break the non-stream branch, which must return a Completion.
            def _stream_chunks() -> Iterator[CompletionChunk]:
                start = out_start
                i = 0
                out_ids = None
                for out_ids in output_gen_ids:
                    i += 1
                    if not i % stream_interval:
                        # TODO: use async to decode and detokenize
                        out_end = len(input_lengths) + i
                        yield self._gen_completion_chunk(
                            out_ids, num_beams, start, out_end
                        )
                        start = out_end
                # `out_ids` is a tensor, so compare with None instead of
                # relying on its ambiguous truth value.
                if out_ids is not None and i % stream_interval:
                    out_end = len(input_lengths) + i
                    yield self._gen_completion_chunk(out_ids, num_beams, start, out_end)

            return _stream_chunks()
        else:
            completion = self._gen_completion_chunk(
                output_gen_ids, num_beams, out_start, len(output_gen_ids)
            )
            choices = completion["choices"]
            completion_tokens = 0
            for beam in range(num_beams):
                completion_tokens += int(
                    (
                        output_gen_ids[0][beam] == sanitized_generate_config["end_id"]
                    ).nonzero(as_tuple=True)[0][0]
                )
            usage = CompletionUsage(
                prompt_tokens=len(input_lengths),
                completion_tokens=completion_tokens,
                total_tokens=len(input_lengths) + completion_tokens,
            )
            return Completion(
                id=str(uuid.uuid1()),
                object="text_completion",
                created=int(time.time()),
                model=self._model_uid,
                choices=choices,
                usage=usage,
            )


class TRTChatModel(TRTModel, ChatModelMixin):
    def chat(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
        generate_config: Optional[Dict] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        model_family = MODEL_NAME_TO_FAMILY[self._model_name]
        assert model_family.prompt_style is not None
        prompt_style = model_family.prompt_style.copy()
        if system_prompt:
            prompt_style.system_prompt = system_prompt
        chat_history = chat_history or []
        full_prompt = self.get_prompt(prompt, chat_history, prompt_style)
        if not generate_config:
            generate_config = dict()
        stream = generate_config.get("stream", None)
        if stream:
            it = self.generate(full_prompt, generate_config)
            assert isinstance(it, Iterator)
            return self._to_chat_completion_chunks(it)
        else:
            c = self.generate(full_prompt, generate_config)
            assert not isinstance(c, Iterator)
            return self._to_chat_completion(c)
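
A minimal usage sketch for reviewers (the uid and paths are hypothetical; assumes a prebuilt llama-2-chat TensorRT-LLM engine directory and its HF tokenizer):

from xinference.model.llm.trtllm.core import TRTChatModel

model = TRTChatModel(
    model_uid="my-llama-2-chat",             # hypothetical uid
    model_name="llama-2-chat",               # must be a key of MODEL_SPECIAL_TOKENS
    model_path="/path/to/engine_dir",        # contains config.json + *.engine
    tokenizer_path="/path/to/hf_tokenizer",  # passed to PreTrainedTokenizerFast
)
model.load()
completion = model.chat("What is TensorRT-LLM?", generate_config={"max_tokens": 64})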