diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 549f25c1c..9d7627935 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,6 +43,23 @@ jobs: - name: Type check if: ${{ github.event_name == 'pull_request' }} run: git diff --name-only --diff-filter=AM "origin/$GITHUB_BASE_REF" -z -- '*.py{,i}' | xargs -0 --no-run-if-empty hatch run dev:typing + tests: + runs-on: ubuntu-latest + if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python-version: ['3.8', '3.9', '3.10', '3.11'] + name: tests (${{ matrix.os }}-${{ matrix.python-version }}) + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Setup CI + uses: ./.github/actions/setup-repo + - name: Run tests + run: hatch run test concurrency: group: ci-${{ github.event.pull_request.number || github.sha }} cancel-in-progress: true diff --git a/pyproject.toml b/pyproject.toml index 441858860..267d2a255 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ # tabulate for CLI with CJK support # >=0.9.0 for some bug fixes "tabulate[widechars]>=0.9.0", + "typing_extensions", ] description = 'OpenLLM: REST/gRPC API server for running any open Large-Language Model - StableLM, Llama, Alpaca, Dolly, Flan-T5, Custom' dynamic = ["version"] @@ -101,9 +102,16 @@ dependencies = [ "pytest-mock", "pytest-randomly", "pytest-rerunfailures", + # NOTE: To run all hooks "pre-commit", + # NOTE: Using under ./tools/update-optional-dependencies.py "tomlkit", + # NOTE: Using under ./tools/update-readme.py "markdown-it-py", + # NOTE: Tests strategies with Hypothesis + "hypothesis", + # NOTE: snapshot testing + "syrupy", ] [tool.hatch.envs.default.scripts] cov = ["test-cov", "cov-report"] diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py index c23c67dd7..53aea9979 100644 --- a/src/openllm/__init__.py +++ b/src/openllm/__init__.py @@ -30,6 +30,11 @@ from . 
import utils as utils from .__about__ import __version__ as __version__ from .exceptions import MissingDependencyError +import logging as _ + +if utils.DEBUG: + _.basicConfig(level=_.NOTSET) + _import_structure = { "_llm": ["LLM", "Runner"], diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index 448edd99b..56605d0ad 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -21,10 +21,21 @@ ```python class FlanT5Config(openllm.LLMConfig): + __config__ = { + "url": "https://huggingface.co/docs/transformers/model_doc/flan-t5", + "default_id": "google/flan-t5-large", + "model_ids": [ + "google/flan-t5-small", + "google/flan-t5-base", + "google/flan-t5-large", + "google/flan-t5-xl", + "google/flan-t5-xxl", + ], + } class GenerationConfig: - temperature: float = 0.75 - max_new_tokens: int = 3000 + temperature: float = 0.9 + max_new_tokens: int = 2048 top_k: int = 50 top_p: float = 0.4 repetition_penalty = 1.0 @@ -37,9 +48,11 @@ class GenerationConfig: """ from __future__ import annotations +import functools import inspect import logging import os +import sys import typing as t from operator import itemgetter @@ -52,8 +65,18 @@ class GenerationConfig: import openllm -from .exceptions import GpuNotAvailableError, OpenLLMException -from .utils import LazyType, ModelEnv, bentoml_cattr, dantic, first_not_none, lenient_issubclass +from .exceptions import ForbiddenAttributeError, GpuNotAvailableError, OpenLLMException +from .utils import DEBUG, LazyType, bentoml_cattr, dantic, first_not_none, lenient_issubclass + +if hasattr(t, "Required"): + from typing import Required +else: + from typing_extensions import Required + +if hasattr(t, "NotRequired"): + from typing import NotRequired +else: + from typing_extensions import NotRequired _T = t.TypeVar("_T") @@ -129,7 +152,7 @@ def attrs_to_options( typ: type[t.Any] | None = None, suffix_generation: bool = False, ) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]: - # TODO: support parsing nested attrs class + # TODO: support parsing nested attrs class and Union envvar = field.metadata["env"] dasherized = inflection.dasherize(name) underscored = inflection.underscore(name) @@ -429,6 +452,10 @@ def _populate_value_from_env_var( _sentinel = object() +def _field_env_key(model_name: str, key: str, suffix: str | None = None) -> str: + return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key]))) + + def _has_own_attribute(cls: type[t.Any], attrib_name: t.Any): """ Check whether *cls* defines *attrib_name* (and doesn't just inherit it). @@ -480,192 +507,200 @@ def _is_class_var(annot: str | t.Any) -> bool: return annot.startswith(_classvar_prefixes) -def _add_method_dunders(cls: type[t.Any], method: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: +def _add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T: """ Add __module__ and __qualname__ to a *method* if possible. """ try: - method.__module__ = cls.__module__ + method_or_cls.__module__ = cls.__module__ except AttributeError: pass try: - method.__qualname__ = ".".join((cls.__qualname__, method.__name__)) + method_or_cls.__qualname__ = ".".join((cls.__qualname__, method_or_cls.__name__)) except AttributeError: pass try: - method.__doc__ = "Method generated by attrs for class " f"{cls.__qualname__}." + method_or_cls.__doc__ = ( + _overwrite_doc or "Method or class generated by LLMConfig for class " f"{cls.__qualname__}." 
+ ) except AttributeError: pass - return method + return method_or_cls # cached it here to save one lookup per assignment _object_getattribute = object.__getattribute__ -def _make_evolve_kwds(attr_name: str) -> str: - evolve_kwds = { - "metadata": f"dict(env=__env_{attr_name}, description=__description_{attr_name})", - "default": f"__extract_{attr_name}", - } - return ", ".join(f"{k}={v}" for k, v in evolve_kwds.items()) - - -def _safe_getattribute_get(attr_name: str, value_var: t.Any) -> list[str]: - return [ - f"__fallback = _getenv(__env_{attr_name}, {value_var})", - "try:", - f" __extract_{attr_name} = _getenv(__env_{attr_name}, _getattr('{attr_name}'))", - "except AttributeError:", - f" __extract_{attr_name} = __fallback", - f"transformed.append(__attr_{attr_name}.evolve({_make_evolve_kwds(attr_name)}))", - ] - - -def _make_auto_generation_env(cls: type[LLMConfig], model_name: str) -> FieldTransformers[type[LLMConfig]]: - # Circumvent the descriptor protocol to save one lookup per assignment. - globs: dict[str, t.Any] = { - "_getenv": os.environ.get, - "_cached_getattribute_get": _object_getattribute.__get__, - "_cached_gen": getattr(cls, "GenerationConfig", _sentinel), - } - lines: list[str] = ["_getattr = _cached_getattribute_get(_cached_gen)", "transformed = []"] - for f in attr.fields(GenerationConfig): - attr_name = f.name - globs.update( - { - f"__env_{attr_name}": f"OPENLLM_{model_name.upper()}_GENERATION_{attr_name.upper()}", - f"__attr_{attr_name}": f, - f"__description_{attr_name}": f.metadata.get("description", "(not provided)"), - } - ) - lines.extend(_safe_getattribute_get(attr_name, f.default)) - lines.append("return transformed") - - script = "def auto_env(_, fields):\n %s\n" % "\n ".join(lines) if lines else "pass" - return _make_method("auto_env", script, _generate_unique_filename(cls, "auto_env"), globs) +class ModelSettings(t.TypedDict, total=False): + """ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__. + Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig. + """ + # NOTE: These required fields should be at the top, as it will be kw_only + default_id: Required[str] + model_ids: Required[ListStr] -# NOTE: This is the ModelSettings where we can control the behaviour of the LLM. -# refers to the __openllm_*__ docstring inside LLMConfig for more information. -class ModelSettings(t.TypedDict, total=False): - # NOTE: meta + # meta url: str requires_gpu: bool trust_remote_code: bool - requirements: t.Optional[t.List[str]] + requirements: t.Optional[ListStr] - # NOTE: naming convention, only name_type is needed + # naming convention, only name_type is needed to infer from the class # as the three below it can be determined automatically name_type: t.Literal["dasherize", "lowercase"] - model_name: str - start_name: str - env: openllm.utils.ModelEnv + model_name: NotRequired[str] + start_name: NotRequired[str] + env: NotRequired[openllm.utils.ModelEnv] - # NOTE: serving configuration + # serving configuration timeout: int workers_per_resource: t.Union[int, float] - # NOTE: use t.Required once we drop 3.8 support - default_id: str - model_ids: list[str] - - # NOTE: the target generation_config class to be used. + # the target generation_config class to be used. 
generation_class: t.Type[GenerationConfig] -def _gen_default_settings(cls: type[LLMConfig]) -> ModelSettings: - """Generate the default ModelConfig and delete __config__ in LLMConfig - if defined inplace.""" +_ModelSettings: type[attr.AttrsInstance] = _add_method_dunders( + type("__internal__", (ModelSettings,), {"__module__": "openllm._configuration"}), + attr.make_class( + "ModelSettings", + { + k: dantic.Field( + kw_only=False if t.get_origin(ann) is not Required else True, + auto_default=True, + use_default_converter=False, + type=ann, + metadata={ + "target": f"__openllm_{k}__", + "required": False if t.get_origin(ann) is NotRequired else t.get_origin(ann) is Required, + }, + description=f"ModelSettings field for {k}.", + ) + for k, ann in t.get_type_hints(ModelSettings).items() + }, + bases=(DictStrAny,), + slots=True, + weakref_slot=True, + collect_by_mro=True, + ), + _overwrite_doc="Internal attrs representation of ModelSettings.", +) + - _internal_config = t.cast(ModelSettings, getattr(cls, "__config__", {})) - default_id = _internal_config.get("default_id", None) - if default_id is None: - raise RuntimeError("'default_id' is required under '__config__'.") - model_ids = _internal_config.get("model_ids", None) - if model_ids is None: - raise RuntimeError("'model_ids' is required under '__config__'.") +def structure_settings(cl_: type[LLMConfig], cls: type[t.Any]): + if not lenient_issubclass(cl_, LLMConfig): + raise RuntimeError(f"Given LLMConfig must be a subclass type of 'LLMConfig', got '{cl_}' instead.") + settings = cl_.__config__ - def _first_not_null(key: str, default: _T) -> _T: - return first_not_none(_internal_config.get(key), default=default) + if settings is None: + raise RuntimeError("Given LLMConfig must have '__config__' defined.") - llm_config_striped = cls.__name__.replace("Config", "") + required = [i.name for i in attr.fields(cls) if i.metadata.get("required", False)] + if any(k not in settings for k in required): + raise ValueError(f"The following keys are required under '__config__': {required}") + if not settings["default_id"] or not settings["model_ids"]: + raise ValueError("Make sure that either 'default_id', 'model_ids' are not emptied under '__config__'.") - name_type: t.Literal["dasherize", "lowercase"] = _first_not_null("name_type", "dasherize") + if any(k in settings for k in ("env", "start_name", "model_name")): + raise ValueError("The following keys are not allowed under '__config__': env, start_name, model_name") - if name_type == "dasherize": - default_model_name = inflection.underscore(llm_config_striped) - default_start_name = inflection.dasherize(default_model_name) - else: - default_model_name = llm_config_striped.lower() - default_start_name = default_model_name - - model_name = _first_not_null("model_name", default_model_name) - - return ModelSettings( - name_type=name_type, - model_name=model_name, - default_id=default_id, - model_ids=model_ids, - start_name=_first_not_null("start_name", default_start_name), - url=_first_not_null("url", "(not provided)"), - requires_gpu=_first_not_null("requires_gpu", False), - trust_remote_code=_first_not_null("trust_remote_code", False), - requirements=_first_not_null("requirements", ListStr()), - env=_first_not_null("env", openllm.utils.ModelEnv(model_name)), - timeout=_first_not_null("timeout", 3600), - workers_per_resource=_first_not_null("workers_per_resource", 1), - generation_class=attr.make_class( - cls.__name__.replace("Config", "GenerationConfig"), + if "generation_class" in settings: + raise 
ValueError( + "'generation_class' shouldn't be defined in '__config__', rather defining " + f"all required attributes under '{cl_}.GenerationConfig' when defining the class." + ) + + _cl_name = cl_.__name__.replace("Config", "") + name_type = first_not_none(settings.get("name_type"), "dasherize") + model_name = inflection.underscore(_cl_name) if name_type == "dasherize" else _cl_name.lower() + start_name = inflection.dasherize(model_name) if name_type == "dasherize" else model_name + partialed = functools.partial(_field_env_key, model_name=model_name, suffix="generation") + + def auto_env_transformers(_: t.Any, fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]: + _has_own_gen = _has_own_attribute(cl_, "GenerationConfig") + return [ + f.evolve( + default=_populate_value_from_env_var( + partialed(key=f.name), + fallback=getattr(cl_.GenerationConfig, f.name, f.default) if _has_own_gen else f.default, + ), + metadata={"env": partialed(key=f.name), "description": f.metadata.get("description", "(not provided)")}, + converter=None, + ) + for f in fields + ] + + _target: DictStrAny = { + "default_id": settings["default_id"], + "model_ids": settings["model_ids"], + "url": settings.get("url", ""), + "requires_gpu": settings.get("requires_gpu", False), + "trust_remote_code": settings.get("trust_remote_code", False), + "requirements": settings.get("requirements", None), + "name_type": name_type, + "model_name": model_name, + "start_name": start_name, + "env": openllm.utils.ModelEnv(model_name), + "timeout": settings.get("timeout", 3600), + "workers_per_resource": settings.get("workers_per_resource", 1), + "generation_class": attr.make_class( + f"{_cl_name}GenerationConfig", [], bases=(GenerationConfig,), - frozen=True, slots=True, + weakref_slot=True, + frozen=False, repr=True, - cache_hash=True, - field_transformer=_make_auto_generation_env(cls, model_name), + field_transformer=auto_env_transformers, ), - ) + } + + return cls(**_target) + + +bentoml_cattr.register_structure_hook(_ModelSettings, structure_settings) def _generate_unique_filename(cls: type[t.Any], func_name: str): return f"" -def _setattr_class(attr_name: str, value_var: t.Any): +def _setattr_class(attr_name: str, value_var: t.Any, add_dunder: bool = False): """ Use the builtin setattr to set *attr_name* to *value_var*. We can't use the cached object.__setattr__ since we are setting attributes to a class. """ + if add_dunder: + return f"setattr(cls, '{attr_name}', __add_dunder(cls, {value_var}))" return f"setattr(cls, '{attr_name}', {value_var})" -@t.overload -def _make_assignment_script(cls: type[LLMConfig], attributes: ModelSettings) -> t.Callable[..., None]: - ... - - -@t.overload -def _make_assignment_script(cls: type[LLMConfig], attributes: dict[str, t.Any]) -> t.Callable[..., None]: - ... 
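
For readers following the refactor, the helpers above (`_field_env_key`, `structure_settings`, and `auto_env_transformers`) establish the naming contract between a config class and its environment variables. Below is a standalone sketch of that scheme, assuming a hypothetical `FlanT5Config` class; the names and values are illustrative only:

```python
from __future__ import annotations

import inflection


def field_env_key(model_name: str, key: str, suffix: str | None = None) -> str:
    # Mirrors _field_env_key above: upper-case the non-empty parts and join them with "_".
    return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))


# With name_type="dasherize", the "Config" suffix is stripped and the rest underscored/dasherized:
model_name = inflection.underscore("FlanT5")   # "flan_t5"
start_name = inflection.dasherize(model_name)  # "flan-t5"

# Generation fields resolve their defaults from env vars carrying the "generation" suffix:
assert field_env_key(model_name, "temperature", suffix="generation") == "OPENLLM_FLAN_T5_GENERATION_TEMPERATURE"
# Other subclass fields use the same scheme without the suffix ("field1" is an arbitrary example):
assert field_env_key(model_name, "field1") == "OPENLLM_FLAN_T5_FIELD1"
```
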
+_dunder_add = {"generation_class"} -def _make_assignment_script(cls: type[LLMConfig], attributes: t.Any) -> t.Callable[..., None]: +def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance) -> t.Callable[..., None]: """Generate the assignment script with prefix attributes __openllm___""" - args: list[str] = [] - globs: dict[str, t.Any] = {"cls": cls, "attr_dict": attributes} - annotations: dict[str, t.Any] = {"return": None} + args: ListStr = [] + globs: DictStrAny = { + "cls": cls, + "_cached_attribute": attributes, + "_cached_getattribute_get": _object_getattribute.__get__, + "__add_dunder": _add_method_dunders, + } + annotations: DictStrAny = {"return": None} - lines: list[str] = [] - for attr_name in attributes: - arg_name = f"__openllm_{inflection.underscore(attr_name)}__" - args.append(f"{attr_name}=attr_dict['{attr_name}']") - lines.append(_setattr_class(arg_name, attr_name)) - annotations[attr_name] = type(attributes[attr_name]) + lines: ListStr = ["_getattr = _cached_getattribute_get(_cached_attribute)"] + for attr_name, field in attr.fields_dict(attributes.__class__).items(): + arg_name = field.metadata.get("target", f"__openllm_{inflection.underscore(attr_name)}__") + args.append(f"{attr_name}=getattr(_cached_attribute, '{attr_name}')") + lines.append(_setattr_class(arg_name, attr_name, add_dunder=attr_name in _dunder_add)) + annotations[attr_name] = field.type script = "def __assign_attr(cls, %s):\n %s\n" % (", ".join(args), "\n ".join(lines) if lines else "pass") assign_method = _make_method( @@ -676,13 +711,16 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: t.Any) -> t.Callab ) assign_method.__annotations__ = annotations + if DEBUG: + logger.info("Generated script:\n%s", script) + return assign_method -_object_setattr = object.__setattr__ +_reserved_namespace = {"__config__", "GenerationConfig"} -@attr.define +@attr.define(slots=True) class LLMConfig: """ ``openllm.LLMConfig`` is somewhat a hybrid combination between the performance of `attrs` with the @@ -705,7 +743,6 @@ class LLMConfig: ```python class FlanT5Config(openllm.LLMConfig): - class GenerationConfig: temperature: float = 0.75 max_new_tokens: int = 3000 @@ -716,16 +753,29 @@ class GenerationConfig: By doing so, openllm.LLMConfig will create a compatible GenerationConfig attrs class that can be converted to ``transformers.GenerationConfig``. These attribute can be accessed via ``LLMConfig.generation_config``. - By default, all LLMConfig has a __config__ that contains a default value. If any LLM requires customization, - provide a ``__config__`` under the class declaration: + By default, all LLMConfig must provide a __config__ with 'default_id' and 'model_ids'. + + All other fields are optional, and will be use default value if not set. ```python class FalconConfig(openllm.LLMConfig): - __config__ = {"trust_remote_code": True, "default_timeout": 3600000} + __config__ = { + "name_type": "lowercase", + "trust_remote_code": True, + "requires_gpu": True, + "timeout": 3600000, + "url": "https://falconllm.tii.ae/", + "requirements": ["einops", "xformers", "safetensors"], + # NOTE: The below are always required + "default_id": "tiiuae/falcon-7b", + "model_ids": [ + "tiiuae/falcon-7b", + "tiiuae/falcon-40b", + "tiiuae/falcon-7b-instruct", + "tiiuae/falcon-40b-instruct", + ], + } ``` - - Note that ``model_name``, ``start_name``, and ``env`` is optional under ``__config__``. If set, then OpenLLM - will respect that option for start and other components within the library. 
""" Field = dantic.Field @@ -761,7 +811,7 @@ class GenerationConfig: """ # NOTE: Internal attributes that should only be used by OpenLLM. Users usually shouldn't - # concern any of these. + # concern any of these. These are here for pyright not to complain. def __attrs_init__(self, **attrs: t.Any): """Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract.""" @@ -771,16 +821,17 @@ def __attrs_init__(self, **attrs: t.Any): __attrs_attrs__ will be handled dynamically by __init_subclass__. """ - __openllm_attrs__: tuple[str, ...] = tuple() - """Internal attribute tracking to store converted LLMConfig attributes to correct attrs""" - - __openllm_hints__: dict[str, t.Any] = Field(None, init=False) + __openllm_hints__: DictStrAny = Field(None, init=False) """An internal cache of resolved types for this LLMConfig.""" __openllm_accepted_keys__: set[str] = Field(None, init=False) """The accepted keys for this LLMConfig.""" - # NOTE: The following will be populated from __config__ + __openllm_extras__: DictStrAny = Field(None, init=False) + """Extra metadata for this LLMConfig.""" + + # NOTE: The following will be populated from __config__ and also + # considered to be public API. __openllm_url__: str = Field(None, init=False) """The resolved url for this LLMConfig.""" @@ -790,7 +841,7 @@ def __attrs_init__(self, **attrs: t.Any): __openllm_trust_remote_code__: bool = False """Whether to always trust remote code""" - __openllm_requirements__: list[str] | None = None + __openllm_requirements__: ListStr | None = None """The default PyPI requirements needed to run this given LLM. By default, we will depend on bentoml, torch, transformers.""" @@ -820,7 +871,7 @@ def __attrs_init__(self, **attrs: t.Any): """Return the default model to use when using 'openllm start '. This could be one of the keys in 'self.model_ids' or custom users model.""" - __openllm_model_ids__: list[str] = Field(None) + __openllm_model_ids__: ListStr = Field(None) """A list of supported pretrained models tag for this given runnable. For example: @@ -833,30 +884,38 @@ def __attrs_init__(self, **attrs: t.Any): to create the generation_config argument that can be used throughout the lifecycle.""" def __init_subclass__(cls): - # NOTE: auto assignment attributes generated from __config__ - _make_assignment_script(cls, _gen_default_settings(cls))(cls) + """The purpose of this __init_subclass__ is that we want all subclass of LLMConfig + to adhere to the attrs contract, and have pydantic-like interface. This means we will + construct all fields and metadata and hack into how attrs use some of the 'magic' construction + to generate the fields. - # NOTE: Since we want to enable a pydantic-like experience - # this means we will have to hide the attr abstraction, and generate - # all of the Field from __init_subclass__ - anns = _get_annotations(cls) - cd = cls.__dict__ + It also does a few more extra features: It also generate all __openllm_*__ config from + ModelSettings (derived from __config__) to the class. + """ + if not cls.__name__.endswith("Config"): + logger.warning("LLMConfig subclass should end with 'Config'. 
Updating to %sConfig", cls.__name__) + cls.__name__ = f"{cls.__name__}Config" - def field_env_key(key: str) -> str: - return f"OPENLLM_{cls.__openllm_model_name__.upper()}_{key.upper()}" + # NOTE: auto assignment attributes generated from __config__ + _make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettings))(cls) + # process a fields under cls.__dict__ and auto convert them with dantic.Field + cd = cls.__dict__ + anns = _get_annotations(cls) + partialed = functools.partial(_field_env_key, model_name=cls.__openllm_model_name__) def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]: return [ a.evolve( - default=_populate_value_from_env_var(a.name, transform=field_env_key, fallback=a.default), + default=_populate_value_from_env_var(partialed(key=a.name), fallback=a.default), metadata={ - "env": a.metadata.get("env", field_env_key(a.name)), + "env": a.metadata.get("env", partialed(key=a.name)), "description": a.metadata.get("description", "(not provided)"), }, ) for a in attrs ] + # _CountingAttr is the underlying representation of attr.field ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)} these: dict[str, _CountingAttr[t.Any]] = {} annotated_names: set[str] = set() @@ -867,31 +926,29 @@ def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> l val = cd.get(attr_name, attr.NOTHING) if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val): if val is attr.NOTHING: - val = cls.Field(env=field_env_key(attr_name)) + val = cls.Field(env=partialed(key=attr_name)) else: - val = cls.Field(default=val, env=field_env_key(attr_name)) + val = cls.Field(default=val, env=partialed(key=attr_name)) these[attr_name] = val unannotated = ca_names - annotated_names - if len(unannotated) > 0: missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter) raise openllm.exceptions.MissingAnnotationAttributeError( f"The following field doesn't have a type annotation: {missing_annotated}" ) - # __openllm_attrs__ is a tracking tuple[attr.Attribute[t.Any]] - # that we construct ourself. - cls.__openllm_attrs__ = tuple(these.keys()) - cls.__openllm_accepted_keys__ = set(cls.__openllm_attrs__) | set( - attr.fields_dict(cls.__openllm_generation_class__) - ) - cls.__openllm_hints__ = {**t.get_type_hints(cls), **t.get_type_hints(cls.__openllm_generation_class__)} + cls.__openllm_accepted_keys__ = set(these.keys()) | { + a.name for a in attr.fields(cls.__openllm_generation_class__) + } + # 'generation_config' is a special fields that wraps the GenerationConfig class + # which is handled in _make_assignment_script these["generation_config"] = cls.Field( default=cls.__openllm_generation_class__(), description=inspect.cleandoc(cls.__openllm_generation_class__.__doc__ or ""), ) + # Generate the base __attrs_attrs__ transformation here. 
attrs, base_attrs, base_attr_map = _transform_attrs( cls, # the current class these, # the parsed attributes we previous did @@ -900,20 +957,15 @@ def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> l True, # collect_by_mro field_transformer=auto_config_env, ) - - _cls_dict = dict(cls.__dict__) + _weakref_slot = True # slots = True _base_names = {a.name for a in base_attrs} _attr_names = tuple(a.name for a in attrs) - _slots = True - _weakref_slot = True - _delete_attribs = not bool(these) _has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) _has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) - # the protocol for attrs-decorated class cls.__attrs_attrs__ = attrs - - # generate a __attrs_init__ for the subclass + # generate a __attrs_init__ for the subclass, since we will + # implement a custom __init__ cls.__attrs_init__ = _add_method_dunders( cls, _make_init( @@ -923,7 +975,7 @@ def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> l _has_post_init, # post_init False, # frozen True, # slots - True, # cache_hash + False, # cache_hash base_attr_map, # base_attr_map False, # is_exc (check if it is exception) None, # cls_on_setattr (essentially attr.setters) @@ -932,10 +984,9 @@ def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> l ) # __repr__ function with the updated fields. cls.__repr__ = _add_method_dunders(cls, _make_repr(cls.__attrs_attrs__, None, cls)) - # Traverse the MRO to collect existing slots # and check for an existing __weakref__. - existing_slots: dict[str, t.Any] = dict() + existing_slots: DictStrAny = dict() weakref_inherited = False for base_cls in cls.__mro__[1:-1]: if base_cls.__dict__.get("__weakref__", None) is not None: @@ -958,29 +1009,44 @@ def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> l # As their descriptors may be overridden by a child class, # we collect them here and update the class dict reused_slots = {slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names} - slot_names = [name for name in slot_names if name not in reused_slots] - setattr(cls, "__slots__", tuple(slot_names)) + # __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots + slot_names = [name for name in slot_names if name not in reused_slots] + ["__openllm_extras__"] + cls.__slots__ = tuple(slot_names) + # Finally, resolve the types + if getattr(cls, "__attrs_types_resolved__", None) != cls: + # NOTE: We will try to resolve type here, and cached it for faster use + # It will be about 15-20ms slower comparing not resolving types. + globs: DictStrAny = {"t": t, "typing": t, "Constraint": Constraint} + if cls.__module__ in sys.modules: + globs.update(sys.modules[cls.__module__].__dict__) + attr.resolve_types(cls.__openllm_generation_class__, globalns=globs) + + cls = attr.resolve_types(cls, globalns=globs) + # the hint cache for easier access + cls.__openllm_hints__ = { + f.name: f.type for ite in map(attr.fields, (cls, cls.__openllm_generation_class__)) for f in ite + } + + def __setattr__(self, attr: str, value: t.Any): + if attr in _reserved_namespace: + raise ForbiddenAttributeError( + f"{attr} should not be set during runtime " + f"as these value will be reflected during runtime. " + f"Instead, you can create a custom LLM subclass {self.__class__.__name__}." 
+ ) + + super().__setattr__(attr, value) def __init__( self, *, - generation_config: dict[str, t.Any] | None = None, - __openllm_extras__: dict[str, t.Any] | None = None, + generation_config: DictStrAny | None = None, + __openllm_extras__: DictStrAny | None = None, **attrs: t.Any, ): - # create a copy of the list of keys as cache + # create a copy of the keys as cache _cached_keys = tuple(attrs.keys()) - self.__openllm_extras__ = first_not_none(__openllm_extras__, default={}) - config_merger.merge( - self.__openllm_extras__, {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__} - ) - - for k in _cached_keys: - if k in self.__openllm_extras__ or attrs.get(k) is None: - del attrs[k] - _cached_keys = tuple(k for k in _cached_keys if k in attrs) - _generation_cl_dict = attr.fields_dict(self.__openllm_generation_class__) if generation_config is None: generation_config = {k: v for k, v in attrs.items() if k in _generation_cl_dict} @@ -988,30 +1054,44 @@ def __init__( generation_keys = {k for k in attrs if k in _generation_cl_dict} if len(generation_keys) > 0: logger.warning( - "When 'generation_config' is passed, \ - the following keys are ignored and won't be used: %s. If you wish to use those values, \ - pass it into 'generation_config'.", + "Both 'generation_config' and keys for 'generation_config' are passed." + " The following keys in 'generation_config' will be overriden be keywords argument: %s", ", ".join(generation_keys), ) - for k in _cached_keys: - if k in generation_keys: - del attrs[k] - _cached_keys = tuple(k for k in _cached_keys if k in attrs) + config_merger.merge(generation_config, {k: v for k, v in attrs.items() if k in generation_keys}) for k in _cached_keys: - if k in generation_config: + if k in generation_config or attrs.get(k) is None: del attrs[k] + _cached_keys = tuple(k for k in _cached_keys if k in attrs) + + self.__openllm_extras__ = first_not_none(__openllm_extras__, default={}) + config_merger.merge( + self.__openllm_extras__, {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__} + ) + + for k in _cached_keys: + if k in self.__openllm_extras__: + del attrs[k] + _cached_keys = tuple(k for k in _cached_keys if k in attrs) + + if DEBUG: + logger.info( + "Creating %s with the following attributes: %s, generation_config=%s", + self.__class__.__name__, + _cached_keys, + generation_config, + ) # The rest of attrs should only be the attributes to be passed to __attrs_init__ self.__attrs_init__(generation_config=self.__openllm_generation_class__(**generation_config), **attrs) - def __getattr__(self, item: str) -> t.Any: - if hasattr(self.generation_config, item): - return getattr(self.generation_config, item) - elif item in self.__openllm_extras__: - return self.__openllm_extras__[item] - else: - return super().__getattribute__(item) + def __getattribute__(self, item: str) -> t.Any: + if item in _reserved_namespace: + raise ForbiddenAttributeError( + f"'{item}' is a reserved namespace for {self.__class__} and should not be access nor modified." 
+ ) + return _object_getattribute.__get__(self)(item) @classmethod def check_if_gpu_is_available(cls, implementation: str | None = None, force: bool = False): @@ -1055,7 +1135,7 @@ def model_construct_env(cls, **attrs: t.Any) -> t.Self: """ attrs = {k: v for k, v in attrs.items() if v is not None} - model_config = ModelEnv(cls.__openllm_model_name__).model_config + model_config = cls.__openllm_env__.model_config env_json_string = os.environ.get(model_config, None) @@ -1064,9 +1144,10 @@ def model_construct_env(cls, **attrs: t.Any) -> t.Self: config_from_env = orjson.loads(env_json_string) except orjson.JSONDecodeError as e: raise RuntimeError(f"Failed to parse '{model_config}' as valid JSON string.") from e - ncls = bentoml_cattr.structure(config_from_env, cls) else: - ncls = cls() + config_from_env = {} + + env_struct = bentoml_cattr.structure(config_from_env, cls) if "generation_config" in attrs: generation_config = attrs.pop("generation_config") @@ -1074,16 +1155,19 @@ def model_construct_env(cls, **attrs: t.Any) -> t.Self: raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") else: generation_config = { - k: v for k, v in attrs.items() if k in attr.fields_dict(ncls.__openllm_generation_class__) + k: v for k, v in attrs.items() if k in attr.fields_dict(env_struct.__openllm_generation_class__) } - attrs = {k: v for k, v in attrs.items() if k not in generation_config} - return attr.evolve(ncls, generation_config=generation_config, **attrs) + for k in tuple(attrs.keys()): + if k in generation_config: + del attrs[k] + + return attr.evolve(env_struct, generation_config=generation_config, **attrs) - def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, dict[str, t.Any]]: + def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]: """Parse given click attributes into a LLMConfig and return the remaining click attributes.""" - llm_config_attrs: dict[str, t.Any] = {"generation_config": {}} - key_to_remove: list[str] = [] + llm_config_attrs: DictStrAny = {"generation_config": {}} + key_to_remove: ListStr = [] for k, v in attrs.items(): if k.startswith(f"{self.__openllm_model_name__}_generation_"): @@ -1096,14 +1180,14 @@ def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, dict[str, t.A return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove} @t.overload - def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> dict[str, t.Any]: + def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny: ... @t.overload def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig: ... 
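
A rough usage sketch of the two conversion paths touched here, assuming a locally defined `FlanT5Config` like the docstring example earlier in this file; the env-var name for the serialized config is taken programmatically from `__openllm_env__.model_config`, and everything else is illustrative:

```python
import os

import orjson

import openllm


class FlanT5Config(openllm.LLMConfig):
    __config__ = {"default_id": "google/flan-t5-large", "model_ids": ["google/flan-t5-large"]}

    class GenerationConfig:
        temperature: float = 0.75
        max_new_tokens: int = 3000


config = FlanT5Config()
# Either wrap generation_config into a transformers.GenerationConfig, or dump it to a plain dict.
as_obj = config.to_generation_config()
as_dict = config.to_generation_config(return_as_dict=True)

# model_construct_env structures whatever JSON is stored in the model's config env var,
# then lets explicit keyword arguments take precedence for generation fields.
os.environ[FlanT5Config.__openllm_env__.model_config] = orjson.dumps(
    {"generation_config": {"temperature": 0.2}}
).decode()
merged = FlanT5Config.model_construct_env(max_new_tokens=256)
```
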
- def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | dict[str, t.Any]: + def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny: config = transformers.GenerationConfig(**bentoml_cattr.unstructure(self.generation_config)) return config.to_dict() if return_as_dict else config @@ -1134,11 +1218,13 @@ def to_click_options(cls, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f) f = optgroup.group(f"{cls.__openllm_generation_class__.__name__} generation options")(f) - if len(cls.__openllm_attrs__) == 0: + if len(cls.__openllm_accepted_keys__.difference(set(attr.fields_dict(cls.__openllm_generation_class__)))) == 0: # NOTE: in this case, the function is already a ClickFunctionWrapper # hence the casting return f + # We pop out 'generation_config' as it is a attribute that we don't + # need to expose to CLI. for name, field in attr.fields_dict(cls).items(): ty = cls.__openllm_hints__.get(name) if t.get_origin(ty) is t.Union or name == "generation_config": @@ -1155,7 +1241,7 @@ def to_click_options(cls, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: ) -def structure_llm_config(data: dict[str, t.Any], cls: type[LLMConfig]) -> LLMConfig: +def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig: """ Structure a dictionary to a LLMConfig object. @@ -1168,8 +1254,8 @@ def structure_llm_config(data: dict[str, t.Any], cls: type[LLMConfig]) -> LLMCon if not LazyType(DictStrAny).isinstance(data): raise RuntimeError(f"Expected a dictionary, but got {type(data)}") - cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_attrs__} generation_cls_fields = attr.fields_dict(cls.__openllm_generation_class__) + cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_accepted_keys__ and k not in generation_cls_fields} if "generation_config" in data: generation_config = data.pop("generation_config") if not LazyType(DictStrAny).isinstance(generation_config): diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index fba95b406..2061aa445 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -472,6 +472,8 @@ def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any: However, if `model_id` is a path, this argument is recomended to include. """ + load_in_mha = attrs.pop("load_in_mha", False) + if llm_config is not None: logger.debug("Using given 'llm_config=(%s)' to initialize LLM", llm_config) self.config = llm_config @@ -494,6 +496,10 @@ def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any: if self.__openllm_post_init__: self.llm_post_init() + # finally: we allow users to overwrite the load_in_mha defined by the LLM subclass. 
+ if load_in_mha: + self.load_in_mha = load_in_mha + def __setattr__(self, attr: str, value: t.Any): if attr in _reserved_namespace: raise ForbiddenAttributeError( diff --git a/src/openllm/_package.py b/src/openllm/_package.py index 334507ba5..00f936d0f 100644 --- a/src/openllm/_package.py +++ b/src/openllm/_package.py @@ -143,6 +143,7 @@ def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_ env={ llm.config.__openllm_env__.framework: llm.config.__openllm_env__.get_framework_env(), "OPENLLM_MODEL": llm.config.__openllm_model_name__, + "OPENLLM_MODEL_ID": llm._model_id, "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_CONFIG_OPTIONS": _bentoml_config_options, }, @@ -165,8 +166,10 @@ def build(model_name: str, *, __cli__: bool = False, **attrs: t.Any) -> tuple[be overwrite_existing_bento = attrs.pop("_overwrite_existing_bento", False) current_model_envvar = os.environ.pop("OPENLLM_MODEL", None) + current_model_id_envvar = os.environ.pop("OPENLLM_MODEL_ID", None) _previously_built = False workers = attrs.pop("_workers", None) + model_id: str = attrs.pop("model_id", None) llm_config = openllm.AutoConfig.for_model(model_name) @@ -176,14 +179,15 @@ def build(model_name: str, *, __cli__: bool = False, **attrs: t.Any) -> tuple[be # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path try: os.environ["OPENLLM_MODEL"] = inflection.underscore(model_name) + os.environ["OPENLLM_MODEL_ID"] = model_id to_use_framework = llm_config.__openllm_env__.get_framework_env() if to_use_framework == "flax": - llm = openllm.AutoFlaxLLM.for_model(model_name, llm_config=llm_config, **attrs) + llm = openllm.AutoFlaxLLM.for_model(model_name, model_id=model_id, llm_config=llm_config, **attrs) elif to_use_framework == "tf": - llm = openllm.AutoTFLLM.for_model(model_name, llm_config=llm_config, **attrs) + llm = openllm.AutoTFLLM.for_model(model_name, model_id=model_id, llm_config=llm_config, **attrs) else: - llm = openllm.AutoLLM.for_model(model_name, llm_config=llm_config, **attrs) + llm = openllm.AutoLLM.for_model(model_name, model_id=model_id, llm_config=llm_config, **attrs) labels = dict(llm.identifying_params) labels.update({"_type": llm.llm_type, "_framework": to_use_framework}) @@ -226,6 +230,9 @@ def build(model_name: str, *, __cli__: bool = False, **attrs: t.Any) -> tuple[be raise finally: del os.environ["OPENLLM_MODEL"] + del os.environ["OPENLLM_MODEL_ID"] # restore original OPENLLM_MODEL envvar if set. 
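
The build path above temporarily exports `OPENLLM_MODEL` and `OPENLLM_MODEL_ID` so the generated `_service.py` can read them at import time, then restores whatever was set before. A minimal sketch of that save/set/restore pattern, factored into a hypothetical context manager (the real code inlines the same bookkeeping with try/finally):

```python
from __future__ import annotations

import contextlib
import os
import typing as t


@contextlib.contextmanager
def scoped_env(**values: str) -> t.Iterator[None]:
    # Remember the previous values (None if unset), export the new ones,
    # and always restore the originals on exit.
    saved = {k: os.environ.pop(k, None) for k in values}
    os.environ.update(values)
    try:
        yield
    finally:
        for key, old in saved.items():
            os.environ.pop(key, None)
            if old is not None:
                os.environ[key] = old


with scoped_env(OPENLLM_MODEL="flan_t5", OPENLLM_MODEL_ID="google/flan-t5-large"):
    ...  # the bento build imports _service.py here, which reads both variables
```
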
if current_model_envvar is not None: os.environ["OPENLLM_MODEL"] = current_model_envvar + if current_model_id_envvar is not None: + os.environ["OPENLLM_MODEL_ID"] = current_model_id_envvar diff --git a/src/openllm/_service.py b/src/openllm/_service.py index 9eea9b0f8..98f5f46f0 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -31,9 +31,10 @@ import openllm model = os.environ.get("OPENLLM_MODEL", "{__model_name__}") # openllm: model name +model_id = os.environ.get("OPENLLM_MODEL_ID", "{__model_id__}") # openllm: model id llm_config = openllm.AutoConfig.for_model(model) -runner = openllm.Runner(model, llm_config=llm_config) +runner = openllm.Runner(model, model_id=model_id, llm_config=llm_config) svc = bentoml.Service(name=f"llm-{llm_config.__openllm_start_name__}-service", runners=[runner]) diff --git a/src/openllm/cli.py b/src/openllm/cli.py index 7d7c107ab..36aa5628f 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -41,6 +41,8 @@ import openllm +from .utils import DEBUG, LazyType, ModelEnv, analytics, bentoml_cattr, first_not_none + if t.TYPE_CHECKING: from ._types import ClickFunctionWrapper, F, P @@ -138,7 +140,7 @@ def start_model_command( @group.command(**command_attrs) def noop() -> openllm.LLMConfig: _echo("No GPU available, therefore this command is disabled", fg="red") - openllm.utils.analytics.track_start_init(llm_config) + analytics.track_start_init(llm_config) return llm_config return noop @@ -182,10 +184,8 @@ def model_start( fg="yellow", ) - workers_per_resource = openllm.utils.first_not_none( - workers, default=llm.config.__openllm_workers_per_resource__ - ) - server_timeout = openllm.utils.first_not_none(server_timeout, default=llm.config.__openllm_timeout__) + workers_per_resource = first_not_none(workers, default=llm.config.__openllm_workers_per_resource__) + server_timeout = first_not_none(server_timeout, default=llm.config.__openllm_timeout__) num_workers = int(1 / workers_per_resource) if num_workers > 1: @@ -207,6 +207,7 @@ def model_start( # NOTE: This is to set current configuration _bentoml_config_options = start_env.pop("BENTOML_CONFIG_OPTIONS", "") _bentoml_config_options_opts = [ + "tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{llm.config.__openllm_start_name__}-runner".traffic.timeout={llm.config.__openllm_timeout__}', f'runners."llm-{llm.config.__openllm_start_name__}-runner".workers_per_resource={workers_per_resource}', @@ -229,6 +230,7 @@ def model_start( llm.config.__openllm_env__.framework: llm.config.__openllm_env__.get_framework_env(), llm.config.__openllm_env__.model_config: llm.config.model_dump_json().decode(), "OPENLLM_MODEL": model_name, + "OPENLLM_MODEL_ID": llm._model_id, "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_CONFIG_OPTIONS": _bentoml_config_options, "BENTOML_HOME": os.environ.get("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), @@ -243,7 +245,7 @@ def model_start( server = server_cls("_service.py:svc", **server_attrs) try: - openllm.utils.analytics.track_start_init(llm.config) + analytics.track_start_init(llm.config) server.start(env=start_env, text=True, blocking=True) except Exception as err: _echo(f"Error caught while starting LLM Server:\n{err}", fg="red") @@ -273,8 +275,6 @@ def common_params(f: F[P, t.Any]) -> ClickFunctionWrapper[..., t.Any]: from bentoml._internal.configuration import DEBUG_ENV_VAR, QUIET_ENV_VAR, set_debug_mode - from .utils import analytics - @click.option("-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, 
help="Suppress all output.") @click.option( "--debug", "--verbose", envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help="Print out debug logs." @@ -308,8 +308,6 @@ def usage_tracking( """This is not supposed to be used with unprocessed click function. This should be used a the last currying from common_params -> usage_tracking -> exception_handling """ - from .utils import analytics - command_name = attrs.get("name", func.__name__) @functools.wraps(func) @@ -532,7 +530,7 @@ def parse_device_callback(_: click.Context, params: click.Parameter, value: tupl if value is None: return value - if not openllm.utils.LazyType(TupleStrAny).isinstance(value): + if not LazyType(TupleStrAny).isinstance(value): raise RuntimeError(f"{params} only accept multiple values.") parsed: tuple[str, ...] = tuple() for v in value: @@ -552,14 +550,12 @@ def _start( **attrs: t.Any, ): """Python API to start a LLM server.""" - from . import utils - _serve_grpc = attrs.pop("_serve_grpc", False) - ModelEnv = utils.ModelEnv(model_name) + _ModelEnv = ModelEnv(model_name) if framework is not None: - os.environ[ModelEnv.framework] = framework + os.environ[_ModelEnv.framework] = framework start_model_command(model_name, t.cast(OpenLLMCommandGroup, cli), _serve_grpc=_serve_grpc)( standalone_mode=False, **attrs ) @@ -581,7 +577,7 @@ def _start( ) -def model_id_option(factory: t.Any, model_env: openllm.utils.ModelEnv | None = None): +def model_id_option(factory: t.Any, model_env: ModelEnv | None = None): envvar = None if model_env is not None: envvar = model_env.model_id @@ -698,7 +694,10 @@ def build(model_name: str, model_id: str | None, overwrite: bool, output: Output + "* Push to BentoCloud with `bentoml push`:\n" + f" $ bentoml push {bento.tag}\n" + "* Containerize your Bento with `bentoml containerize`:\n" - + f" $ bentoml containerize {bento.tag}", + + f" $ bentoml containerize {bento.tag}\n" + + " Tip: To enable additional BentoML feature for 'containerize', " + + "use '--enable-features=FEATURE[,FEATURE]' " + + "[see 'bentoml containerize -h' for more advanced usage]\n", fg="blue", ) elif output == "json": @@ -732,28 +731,36 @@ def models(output: OutputLiteral, show_available: bool): else: failed_initialized: list[tuple[str, Exception]] = [] - json_data: dict[str, dict[t.Literal["model_id", "description", "runtime_impl"], t.Any]] = {} + json_data: dict[ + str, dict[t.Literal["model_id", "url", "installation", "requires_gpu", "runtime_impl"], t.Any] + ] = {} + + # NOTE: Keep a sync list with ./tools/update-optional-dependencies.py + extras = ["chatglm", "falcon", "flan-t5", "starcoder"] converted: list[str] = [] for m in models: - try: - model = openllm.AutoLLM.for_model(m) - docs = inspect.cleandoc(model.config.__doc__ or "(No description)") - runtime_impl: tuple[t.Literal["pt", "flax", "tf"], ...] = tuple() - if model.config.__openllm_model_name__ in openllm.MODEL_MAPPING_NAMES: - runtime_impl += ("pt",) - if model.config.__openllm_model_name__ in openllm.MODEL_FLAX_MAPPING_NAMES: - runtime_impl += ("flax",) - if model.config.__openllm_model_name__ in openllm.MODEL_TF_MAPPING_NAMES: - runtime_impl += ("tf",) - json_data[m] = { - "model_id": model.config.__openllm_model_ids__, - "description": docs, - "runtime_impl": runtime_impl, - } - converted.extend([convert_transformers_model_name(i) for i in model.config.__openllm_model_ids__]) - except Exception as err: - failed_initialized.append((m, err)) + config = openllm.AutoConfig.for_model(m) + runtime_impl: tuple[t.Literal["pt", "flax", "tf"], ...] 
= tuple() + if config.__openllm_model_name__ in openllm.MODEL_MAPPING_NAMES: + runtime_impl += ("pt",) + if config.__openllm_model_name__ in openllm.MODEL_FLAX_MAPPING_NAMES: + runtime_impl += ("flax",) + if config.__openllm_model_name__ in openllm.MODEL_TF_MAPPING_NAMES: + runtime_impl += ("tf",) + json_data[m] = { + "model_id": config.__openllm_model_ids__, + "url": config.__openllm_url__, + "requires_gpu": config.__openllm_requires_gpu__, + "runtime_impl": runtime_impl, + "installation": "pip install openllm" if m not in extras else f'pip install "openllm[{m}]"', + } + converted.extend([convert_transformers_model_name(i) for i in config.__openllm_model_ids__]) + if DEBUG: + try: + openllm.AutoLLM.for_model(m, llm_config=config) + except Exception as err: + failed_initialized.append((m, err)) ids_in_local_store = None if show_available: @@ -764,10 +771,30 @@ def models(output: OutputLiteral, show_available: bool): tabulate.PRESERVE_WHITESPACE = True - data: list[str | tuple[str, str, list[str], tuple[t.Literal["pt", "flax", "tf"], ...]]] = [] + data: list[ + str | tuple[str, str, list[str], str, t.LiteralString, tuple[t.Literal["pt", "flax", "tf"], ...]] + ] = [] for m, v in json_data.items(): - data.extend([(m, v["description"], v["model_id"], v["runtime_impl"])]) - column_widths = [int(COLUMNS / 6), int(COLUMNS / 2), int(COLUMNS / 3), int(COLUMNS / 9)] + data.extend( + [ + ( + m, + v["url"], + v["model_id"], + v["installation"], + "✅" if v["requires_gpu"] else "❌", + v["runtime_impl"], + ) + ] + ) + column_widths = [ + int(COLUMNS / 6), + int(COLUMNS / 6), + int(COLUMNS / 3), + int(COLUMNS / 6), + int(COLUMNS / 6), + int(COLUMNS / 9), + ] if len(data) == 0 and len(failed_initialized) > 0: _echo("Exception found while parsing models:\n", fg="yellow") @@ -779,7 +806,7 @@ def models(output: OutputLiteral, show_available: bool): table = tabulate.tabulate( data, tablefmt="fancy_grid", - headers=["LLM", "Description", "Models Id", "Runtime"], + headers=["LLM", "URL", "Models Id", "Installation", "GPU Only", "Runtime"], maxcolwidths=column_widths, ) @@ -790,7 +817,7 @@ def models(output: OutputLiteral, show_available: bool): ) _echo(formatted_table, fg="white") - if len(failed_initialized) > 0: + if DEBUG and len(failed_initialized) > 0: _echo("\nThe following models are supported but failed to initialize:\n") for m, err in failed_initialized: _echo(f"- {m}: ", fg="blue", nl=False) @@ -805,7 +832,7 @@ def models(output: OutputLiteral, show_available: bool): dumped: dict[str, t.Any] = json_data if show_available: assert ids_in_local_store - dumped["local"] = [openllm.utils.bentoml_cattr.unstructure(i.tag) for i in ids_in_local_store] + dumped["local"] = [bentoml_cattr.unstructure(i.tag) for i in ids_in_local_store] _echo( orjson.dumps( dumped, diff --git a/src/openllm/models/flan_t5/modeling_flan_t5.py b/src/openllm/models/flan_t5/modeling_flan_t5.py index dfd25cc41..d2162e33e 100644 --- a/src/openllm/models/flan_t5/modeling_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flan_t5.py @@ -22,6 +22,7 @@ if t.TYPE_CHECKING: import torch + import transformers # noqa else: torch = openllm.utils.LazyLoader("torch", globals(), "torch") diff --git a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py index 7cf468198..2c014ccab 100644 --- a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py @@ -20,6 +20,9 @@ from ..._prompt import default_formatter from .configuration_flan_t5 import 
DEFAULT_PROMPT_TEMPLATE +if t.TYPE_CHECKING: + import transformers # noqa + class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]): __openllm_internal__ = True diff --git a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py index d1c974611..a21a6506a 100644 --- a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py @@ -20,6 +20,9 @@ from ..._prompt import default_formatter from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE +if t.TYPE_CHECKING: + import transformers # noqa + class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]): __openllm_internal__ = True diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py index f4520f468..5ba0202fb 100644 --- a/src/openllm/models/stablelm/modeling_stablelm.py +++ b/src/openllm/models/stablelm/modeling_stablelm.py @@ -23,6 +23,9 @@ from ..._prompt import default_formatter from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT +if t.TYPE_CHECKING: + import transformers # noqa + class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py index 867ff4066..ccfd35578 100644 --- a/src/openllm/utils/__init__.py +++ b/src/openllm/utils/__init__.py @@ -19,6 +19,7 @@ import sys import types +import os import typing as t from bentoml._internal.types import LazyType @@ -61,6 +62,8 @@ def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.An raise +DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get("OPENLLMDEVDEBUG"))) + _import_structure = { "analytics": [], "codegen": [], @@ -103,6 +106,7 @@ def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.An "pkg": pkg, "LazyModule": LazyModule, "LazyType": LazyType, + "DEBUG": DEBUG, "LazyLoader": LazyLoader, "bentoml_cattr": bentoml_cattr, "copy_file_to_fs_folder": copy_file_to_fs_folder, diff --git a/src/openllm/utils/codegen.py b/src/openllm/utils/codegen.py index 25aab09a3..41bc9e1a0 100644 --- a/src/openllm/utils/codegen.py +++ b/src/openllm/utils/codegen.py @@ -34,21 +34,17 @@ from fs.base import FS class ModifyNodeProtocol(t.Protocol): - @t.overload - def __call__(self, node: Node, model_name: str) -> None: - ... - - @t.overload - def __call__(self, node: Node, *args: t.Any, **attrs: t.Any) -> None: + def __call__(self, node: Node, model_name: str, formatter: type[ModelNameFormatter]) -> None: ... 
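
The codegen changes that follow add a second comment marker, `# openllm: model id`, next to the existing `# openllm: model name`, so both placeholders in the `_service.py` template are rewritten at build time. Below is a rough illustration of the substitution those formatters perform, using `string.Formatter` on a plain string; the real code walks black's syntax tree (`Node`/`Leaf`) and only rewrites leaves guarded by the markers, and the class names here are illustrative:

```python
import string


class PlaceholderFormatter(string.Formatter):
    """Substitutes a single {__keyword__} placeholder and leaves other text untouched."""

    keyword = "__model_name__"

    def __init__(self, value: str):
        self.value = value

    def get_value(self, key, args, kwargs):
        # Only the configured keyword is replaced; anything else is re-emitted verbatim.
        return self.value if key == self.keyword else "{" + str(key) + "}"


class ModelIdPlaceholderFormatter(PlaceholderFormatter):
    keyword = "__model_id__"


template = 'model_id = os.environ.get("OPENLLM_MODEL_ID", "{__model_id__}")  # openllm: model id'
print(ModelIdPlaceholderFormatter("google/flan-t5-large").vformat(template, (), {}))
# model_id = os.environ.get("OPENLLM_MODEL_ID", "google/flan-t5-large")  # openllm: model id
```
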
logger = logging.getLogger(__name__) OPENLLM_MODEL_NAME = {"# openllm: model name"} +OPENLLM_MODEL_ID = {"# openllm: model id"} -class ModelFormatter(string.Formatter): +class ModelNameFormatter(string.Formatter): model_keyword: t.LiteralString = "__model_name__" def __init__(self, model_name: str): @@ -69,6 +65,10 @@ def is_correct_leaf(self, leaf: Leaf): return leaf.type == token.STRING and self.can_format(leaf.value) +class ModelIdFormatter(ModelNameFormatter): + model_keyword: t.LiteralString = "__model_id__" + + def recurse_modify_node(node: Node | Leaf, node_type: int, callables: ModifyNodeProtocol, *args: t.Any) -> Node | None: if isinstance(node, Node) and node.type == node_type: callables(node, *args) @@ -76,12 +76,12 @@ def recurse_modify_node(node: Node | Leaf, node_type: int, callables: ModifyNode recurse_modify_node(child, node_type, callables, *args) -def modify_node_with_comments(node: Node, model_name: str): +def modify_node_with_comments(node: Node, model_name: str, formatter: type[ModelNameFormatter]): """ Modify the node with comments '# openllm: model name' and replace the formatted value with the actual model name. """ - _formatter = ModelFormatter(model_name) + _formatter = formatter(model_name) for children in node.children: if isinstance(children, Leaf) and _formatter.is_correct_leaf(children): children.value = _formatter.vformat(children.value) @@ -118,7 +118,14 @@ def _parse_service_file(src_contents: str, model_name: str, mode: Mode) -> str: for comment in list_comments(leaf.prefix, is_endmarker=False): if comment.value in OPENLLM_MODEL_NAME: assert leaf.prev_sibling is not None, "'# openllm: model name' line must not be modified." - recurse_modify_node(leaf.prev_sibling, syms.arglist, modify_node_with_comments, model_name) + recurse_modify_node( + leaf.prev_sibling, syms.arglist, modify_node_with_comments, model_name, ModelNameFormatter + ) + if comment.value in OPENLLM_MODEL_ID: + assert leaf.prev_sibling is not None, "'# openllm: model id' line must not be modified." + recurse_modify_node( + leaf.prev_sibling, syms.arglist, modify_node_with_comments, model_name, ModelIdFormatter + ) # NOTE: The below is the same as black.format_str dst_blocks: list[LinesBlock] = [] diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py index 76e28dc6e..8b62391e0 100644 --- a/src/openllm/utils/dantic.py +++ b/src/openllm/utils/dantic.py @@ -53,6 +53,8 @@ def Field( validator: _ValidatorType[_T] | None = None, description: str | None = None, env: str | None = None, + auto_default: bool = False, + use_default_converter: bool = True, **attrs: t.Any, ): """A decorator that extends attr.field with additional arguments, which provides the same @@ -63,7 +65,14 @@ def Field( Args: ge: Greater than or equal to. Defaults to None. - docs: the documentation for the field. Defaults to None. + description: the documentation for the field. Defaults to None. + env: the environment variable to read from. Defaults to None. + auto_default: a bool indicating whether to use the default value as the environment. + Defaults to False. If set to True, the behaviour of this Field will also depends + on kw_only. If kw_only=True, the this field will become 'Required' and the default + value is omitted. If kw_only=False, then the default value will be used as before. + use_default_converter: a bool indicating whether to use the default converter. Defaults + to True. If set to False, then the default converter will not be used. 
**kwargs: The rest of the arguments are passed to attr.field """ metadata = attrs.pop("metadata", {}) @@ -74,7 +83,9 @@ def Field( metadata["env"] = env piped: list[_ValidatorType[t.Any]] = [] - converter = attrs.pop("converter", functools.partial(_default_converter, env=env)) + converter = attrs.pop("converter", None) + if use_default_converter: + converter = functools.partial(_default_converter, env=env) if ge is not None: piped.append(attr.validators.ge(ge)) @@ -99,6 +110,10 @@ def Field( else: attrs["default"] = default + kw_only = attrs.pop("kw_only", False) + if auto_default and kw_only: + attrs.pop("default") + return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs) diff --git a/tests/__init__.py b/tests/__init__.py index 3a2faba50..1e3029e66 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,3 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + +from hypothesis import HealthCheck, settings + +settings.register_profile("CI", settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None) + +if "CI" in os.environ: + settings.load_profile("CI") diff --git a/tests/conftest.py b/tests/_strategies/__init__.py similarity index 100% rename from tests/conftest.py rename to tests/_strategies/__init__.py diff --git a/tests/_strategies/_configuration.py b/tests/_strategies/_configuration.py new file mode 100644 index 000000000..a8b6c02a4 --- /dev/null +++ b/tests/_strategies/_configuration.py @@ -0,0 +1,76 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import typing as t + +from hypothesis import strategies as st + +import openllm +from openllm._configuration import ModelSettings + +logger = logging.getLogger(__name__) + +env_strats = st.sampled_from([openllm.utils.ModelEnv(model_name) for model_name in openllm.CONFIG_MAPPING.keys()]) + + +@st.composite +def model_settings(draw: st.DrawFn): + """Strategy for generating ModelSettings objects.""" + kwargs: dict[str, t.Any] = dict( + default_id=st.text(min_size=1), + model_ids=st.lists(st.text(), min_size=1), + url=st.text(), + requires_gpu=st.booleans(), + trust_remote_code=st.booleans(), + requirements=st.none() | st.lists(st.text(), min_size=1), + name_type=st.sampled_from(["dasherize", "lowercase"]), + timeout=st.integers(min_value=3600), + workers_per_resource=st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)), + ) + return draw(st.builds(ModelSettings, **kwargs)) + + +def make_llm_config( + cls_name: str, + dunder_config: dict[str, t.Any] | ModelSettings, + fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None, + generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] 
| None = None, +) -> type[openllm.LLMConfig]: + globs: dict[str, t.Any] = {"openllm": openllm} + _config_args: list[str] = [] + lines: list[str] = [f"class {cls_name}Config(openllm.LLMConfig):"] + for attr, value in dunder_config.items(): + _config_args.append(f'"{attr}": __attr_{attr}') + globs[f"_{cls_name}Config__attr_{attr}"] = value + lines.append(f' __config__ = {{ {", ".join(_config_args)} }}') + if fields is not None: + for field, type_, default in fields: + lines.append(f" {field}: {type_} = {repr(default)}") + if generation_fields is not None: + generation_lines = ["class GenerationConfig:"] + for field, default in generation_fields: + generation_lines.append(f" {field} = {repr(default)}") + lines.extend(map(lambda line: " " + line, generation_lines)) + + script = "\n".join(lines) + + if openllm.utils.DEBUG: + logger.info("Generated class %s:\n%s", cls_name, script) + + eval(compile(script, "name", "exec"), globs) + + return globs[f"{cls_name}Config"] diff --git a/tests/models/flan_t5/__init__.py b/tests/models/flan_t5/__init__.py new file mode 100644 index 000000000..3a2faba50 --- /dev/null +++ b/tests/models/flan_t5/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/models/flan_t5/test_modeling_flan_t5.py b/tests/models/flan_t5/test_modeling_flan_t5.py new file mode 100644 index 000000000..b10536c7f --- /dev/null +++ b/tests/models/flan_t5/test_modeling_flan_t5.py @@ -0,0 +1,43 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +import openllm + + +@pytest.fixture +def qa_prompt() -> str: + return ( + "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" + ) + + +@pytest.fixture +def flan_t5_id() -> str: + return "google/flan-t5-small" + + +def test_small_flan(qa_prompt: str, flan_t5_id: str): + llm = openllm.AutoLLM.for_model("flan-t5", model_id=flan_t5_id) + generate = llm(qa_prompt) + assert generate + + +def test_small_runner_flan(qa_prompt: str, flan_t5_id: str): + llm = openllm.Runner("flan-t5", model_id=flan_t5_id, init_local=True) + generate = llm(qa_prompt) + assert generate diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 000000000..39af2549b --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,152 @@ +# Copyright 2023 BentoML Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All configuration-related tests for openllm.LLMConfig. This will include testing +for ModelEnv construction and parsing environment variables.""" +from __future__ import annotations + +import logging + +import pytest +from hypothesis import assume, given +from hypothesis import strategies as st + +import openllm +from openllm._configuration import GenerationConfig, ModelSettings, _field_env_key +from openllm.utils import DEBUG + +from ._strategies._configuration import make_llm_config, model_settings + +logger = logging.getLogger(__name__) + + +def test_missing_default(): + with pytest.raises(ValueError, match="The following keys are required*"): + make_llm_config("MissingDefaultId", {"name_type": "lowercase", "requirements": ["bentoml"]}) + + with pytest.raises(ValueError, match="The following keys are required*"): + make_llm_config("MissingModelId", {"default_id": "huggingface/t5-tiny-testing", "requirements": ["bentoml"]}) + + +def test_forbidden_access(): + cl_ = make_llm_config( + "ForbiddenAccess", + { + "default_id": "huggingface/t5-tiny-testing", + "model_ids": ["huggingface/t5-tiny-testing", "bentoml/t5-tiny-testing"], + "requirements": ["bentoml"], + }, + ) + + assert pytest.raises( + openllm.exceptions.ForbiddenAttributeError, + cl_.__getattribute__, + cl_(), + "__config__", + ) + assert pytest.raises( + openllm.exceptions.ForbiddenAttributeError, + cl_.__getattribute__, + cl_(), + "GenerationConfig", + ) + + assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig) + + +@given(model_settings()) +def test_class_normal_gen(gen_settings: ModelSettings): + assume(gen_settings["default_id"] and gen_settings["model_ids"]) + cl_: type[openllm.LLMConfig] = make_llm_config("NotFullLLM", gen_settings) + assert issubclass(cl_, openllm.LLMConfig) + for key in gen_settings: + assert object.__getattribute__(cl_, f"__openllm_{key}__") == gen_settings.__getitem__(key) + + +@given(model_settings(), st.integers()) +def test_simple_struct_dump(gen_settings: ModelSettings, field1: int): + cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),)) + assert cl_().model_dump()["field1"] == field1 + + +@given( + model_settings(), + st.integers(max_value=283473), + st.floats(min_value=0.0, max_value=1.0), + st.integers(max_value=283473), + st.floats(min_value=0.0, max_value=1.0), +) +def test_complex_struct_dump( + gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float +): + cl_ = make_llm_config( + "ComplexLLM", + gen_settings, + fields=(("field1", "float", field1),), + generation_fields=(("temperature", temperature),), + ) + sent = cl_() + assert ( + sent.model_dump()["field1"] == field1 and sent.model_dump()["generation_config"]["temperature"] == temperature + ) + assert ( + sent.model_dump(flatten=True)["field1"] == field1 + and sent.model_dump(flatten=True)["temperature"] == temperature + ) + + passed = cl_(field1=input_field1, 
temperature=input_temperature) + assert ( + passed.model_dump()["field1"] == input_field1 + and passed.model_dump()["generation_config"]["temperature"] == input_temperature + ) + assert ( + passed.model_dump(flatten=True)["field1"] == input_field1 + and passed.model_dump(flatten=True)["temperature"] == input_temperature + ) + + pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1) + assert ( + pas_nested.model_dump()["field1"] == input_field1 + and pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature + ) + + +def test_struct_envvar(monkeypatch: pytest.MonkeyPatch): + class EnvLLM(openllm.LLMConfig): + __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]} + field1: int = 2 + + class GenerationConfig: + temperature: float = 0.8 + + f1_env = _field_env_key(EnvLLM.__openllm_model_name__, "field1") + temperature_env = _field_env_key(EnvLLM.__openllm_model_name__, "temperature", suffix="generation") + + if DEBUG: + logger.info(f"Env keys: {f1_env}, {temperature_env}") + + with monkeypatch.context() as m: + m.setenv(f1_env, "4") + m.setenv(temperature_env, "0.2") + sent = EnvLLM() + assert sent.field1 == 4 + assert sent.generation_config.temperature == 0.8 + + # NOTE: Expected behaviour of model_construct_env: generation-config values passed by the user take precedence over envvar, while the top-level field is still resolved from its envvar. + with monkeypatch.context() as m: + m.setenv(f1_env, "4") + m.setenv(temperature_env, "0.2") + sent = EnvLLM.model_construct_env(field1=20, temperature=0.4) + assert sent.field1 == 4 + assert sent.generation_config.temperature == 0.4 diff --git a/typings/attr/__init__.pyi b/typings/attr/__init__.pyi index 80424eec6..e81d52ad0 100644 --- a/typings/attr/__init__.pyi +++ b/typings/attr/__init__.pyi @@ -57,6 +57,9 @@ _OnSetAttrArgType = Union[_OnSetAttrType, List[_OnSetAttrType], setters._NoOpTyp _FieldTransformer = Callable[[type, List["Attribute[Any]"]], List["Attribute[Any]"]] _ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]] +class ReprProtocol(Protocol): + def __call__(__self, self: Any) -> str: ... + class AttrsInstance(AttrsInstance_, Protocol): ... _A = TypeVar("_A", bound=AttrsInstance) @@ -503,7 +506,7 @@ def _make_init( attrs_init: bool, ) -> Callable[_P, Any]: ... def _make_method(name: str, script: str, filename: str, globs: dict[str, Any]) -> Callable[..., Any]: ... -def _make_repr(attrs: tuple[Attribute[Any]], ns: str | None, cls: AttrsInstance) -> Callable[[AttrsInstance], str]: ... +def _make_repr(attrs: tuple[Attribute[Any]], ns: str | None, cls: AttrsInstance) -> ReprProtocol: ... def _transform_attrs( cls: type[AttrsInstance], these: dict[str, _CountingAttr[_T]] | None,