Skip to content

Commit

Permalink
Add cookbook to run Outlines with BentoML and BentoCloud
Browse files Browse the repository at this point in the history
  • Loading branch information
larme authored and rlouf committed Apr 30, 2024
1 parent 1eb9a04 commit 9188eff
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 0 deletions.
221 changes: 221 additions & 0 deletions docs/cookbook/deploy-using-bentoml.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Run Outlines using BentoML

[BentoML](https://github.com/bentoml/BentoML) is an open-source model serving library for building performant and scalable AI applications with Python. It comes with tools that you need for serving optimization, model packaging, and production deployment.

In this guide, we will show you how to use BentoML to run programs written with Outlines on GPU locally and in [BentoCloud](https://www.bentoml.com/), an AI Inference Platform for enterprise AI teams. The example source code in this guide is also available in the [examples/bentoml/](https://github.com/outlines-dev/outlines/blob/main/examples/bentoml/) directory.

## Import a model

First we need to download an LLM (Mistral-7B-v0.1 in this example and you can use any other LLM) and import the model into BentoML's [Model Store](https://docs.bentoml.com/en/latest/guides/model-store.html). Let's install BentoML and other dependencies from PyPi (preferably in a virtual environment):

```bash
pip install -r requirements.txt
```

Then save the code snippet below as `import_model.py` and run `python import_model.py`.

**Note**: You need to accept related conditions on [Hugging Face](https://huggingface.co/mistralai/Mistral-7B-v0.1) first to gain access to Mistral-7B-v0.1.

```python
import bentoml

MODEL_ID = "mistralai/Mistral-7B-v0.1"
BENTO_MODEL_TAG = MODEL_ID.lower().replace("/", "--")

def import_model(model_id, bento_model_tag):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
)

with bentoml.models.create(bento_model_tag) as bento_model_ref:
tokenizer.save_pretrained(bento_model_ref.path)
model.save_pretrained(bento_model_ref.path)


if __name__ == "__main__":
import_model(MODEL_ID, BENTO_MODEL_TAG)
```

You can verify the download is successful by running:

```bash
$ bentoml models list

Tag Module Size Creation Time
mistralai--mistral-7b-v0.1:m7lmf5ac2cmubnnz 13.49 GiB 2024-04-25 06:52:39
## Define a BentoML Service

As the model is ready, we can define a [BentoML Service](https://docs.bentoml.com/en/latest/guides/services.html) to wrap the capabilities of the model.

We will run the JSON-structured generation example [in the README](https://github.com/outlines-dev/outlines?tab=readme-ov-file#efficient-json-generation-following-a-json-schema), with the following schema:


```python
DEFAULT_SCHEMA = """{
"title": "Character",
"type": "object",
"properties": {
"name": {
"title": "Name",
"maxLength": 10,
"type": "string"
},
"age": {
"title": "Age",
"type": "integer"
},
"armor": {"$ref": "#/definitions/Armor"},
"weapon": {"$ref": "#/definitions/Weapon"},
"strength": {
"title": "Strength",
"type": "integer"
}
},
"required": ["name", "age", "armor", "weapon", "strength"],
"definitions": {
"Armor": {
"title": "Armor",
"description": "An enumeration.",
"enum": ["leather", "chainmail", "plate"],
"type": "string"
},
"Weapon": {
"title": "Weapon",
"description": "An enumeration.",
"enum": ["sword", "axe", "mace", "spear", "bow", "crossbow"],
"type": "string"
}
}
}"""
```

First, we need to define a BentoML service by decorating an ordinary class (`Outlines` here) with `@bentoml.service` decorator. We pass to this decorator some configuration and GPU on which we want this service to run in BentoCloud (here an L4 with 24GB memory):

```python
import typing as t
import bentoml
from import_model import BENTO_MODEL_TAG
@bentoml.service(
traffic={
"timeout": 300,
},
resources={
"gpu": 1,
"gpu_type": "nvidia-l4",
},
)
class Outlines:
bento_model_ref = bentoml.models.get(BENTO_MODEL_TAG)
def __init__(self) -> None:
import outlines
import torch
self.model = outlines.models.transformers(
self.bento_model_ref.path,
device="cuda",
model_kwargs={"torch_dtype": torch.float16},
)
...
```

We then need to define an HTTP endpoint using `@bentoml.api` to decorate the method `generate` of `Outlines` class:

```python
...
@bentoml.api
async def generate(
self,
prompt: str = "Give me a character description.",
json_schema: t.Optional[str] = DEFAULT_SCHEMA,
) -> t.Dict[str, t.Any]:
import outlines
generator = outlines.generate.json(self.model, json_schema)
character = generator(prompt)
return character
```

Here `@bentoml.api` decorator defines `generate` as an HTTP endpoint that accepts a JSON request body with two fields: `prompt` and `json_schema` (optional, which allows HTTP clients to provide their own JSON schema). The type hints in the function signature will be used to validate incoming JSON requests. You can define as many HTTP endpoints as you want by using `@bentoml.api` to decorate other methods of `Outlines` class.

Now you can save the above code to `service.py` (or use [this implementation](https://github.com/outlines-dev/outlines/blob/main/examples/bentoml/)), and run the code using the BentoML CLI.

## Run locally for testing and debugging

Then you can run a server locally by:

```bash
bentoml serve .
```

The server is now active at <http://localhost:3000>. You can interact with it using the Swagger UI or in other different ways:

<details>

<summary>CURL</summary>

```bash
curl -X 'POST' \
'http://localhost:3000/generate' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"prompt": "Give me a character description."
}'
```

</details>

<details>

<summary>Python client</summary>

```python
import bentoml
with bentoml.SyncHTTPClient("http://localhost:3000") as client:
response = client.generate(
prompt="Give me a character description"
)
print(response)
```

</details>

Expected output:

```bash
{
"name": "Aura",
"age": 15,
"armor": "plate",
"weapon": "sword",
"strength": 20
}
## Deploy to BentoCloud
After the Service is ready, you can deploy it to [BentoCloud](https://docs.bentoml.com/en/latest/bentocloud/get-started.html) for better management and scalability. [Sign up](https://cloud.bentoml.com/signup) if you haven't got a BentoCloud account.
Make sure you have [logged in to BentoCloud](https://docs.bentoml.com/en/latest/bentocloud/how-tos/manage-access-token.html), then run the following command to deploy it.
```bash
bentoml deploy .
```
Once the application is up and running on BentoCloud, you can access it via the exposed URL.
**Note**: For custom deployment in your own infrastructure, use [BentoML to generate an OCI-compliant image](https://docs.bentoml.com/en/latest/guides/containerization.html).
5 changes: 5 additions & 0 deletions examples/bentoml/.bentoignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__pycache__/
*.py[cod]
*$py.class
.ipynb_checkpoints
venv/
9 changes: 9 additions & 0 deletions examples/bentoml/bentofile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
service: "service:Outlines"
labels:
owner: bentoml-team
stage: demo
include:
- "*.py"
python:
requirements_txt: "./requirements.txt"
lock_packages: false
24 changes: 24 additions & 0 deletions examples/bentoml/import_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import bentoml

MODEL_ID = "mistralai/Mistral-7B-v0.1"
BENTO_MODEL_TAG = MODEL_ID.lower().replace("/", "--")


def import_model(model_id, bento_model_tag):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
)

with bentoml.models.create(bento_model_tag) as bento_model_ref:
tokenizer.save_pretrained(bento_model_ref.path)
model.save_pretrained(bento_model_ref.path)


if __name__ == "__main__":
import_model(MODEL_ID, BENTO_MODEL_TAG)
5 changes: 5 additions & 0 deletions examples/bentoml/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bentoml>=1.2.11
outlines==0.0.37
transformers==4.38.2
datasets==2.18.0
accelerate==0.27.2
78 changes: 78 additions & 0 deletions examples/bentoml/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import typing as t

import bentoml
from import_model import BENTO_MODEL_TAG

DEFAULT_SCHEMA = """{
"title": "Character",
"type": "object",
"properties": {
"name": {
"title": "Name",
"maxLength": 10,
"type": "string"
},
"age": {
"title": "Age",
"type": "integer"
},
"armor": {"$ref": "#/definitions/Armor"},
"weapon": {"$ref": "#/definitions/Weapon"},
"strength": {
"title": "Strength",
"type": "integer"
}
},
"required": ["name", "age", "armor", "weapon", "strength"],
"definitions": {
"Armor": {
"title": "Armor",
"description": "An enumeration.",
"enum": ["leather", "chainmail", "plate"],
"type": "string"
},
"Weapon": {
"title": "Weapon",
"description": "An enumeration.",
"enum": ["sword", "axe", "mace", "spear", "bow", "crossbow"],
"type": "string"
}
}
}"""


@bentoml.service(
traffic={
"timeout": 300,
},
resources={
"gpu": 1,
"gpu_type": "nvidia-l4",
},
)
class Outlines:
bento_model_ref = bentoml.models.get(BENTO_MODEL_TAG)

def __init__(self) -> None:
import torch

import outlines

self.model = outlines.models.transformers(
self.bento_model_ref.path,
device="cuda",
model_kwargs={"torch_dtype": torch.float16},
)

@bentoml.api
async def generate(
self,
prompt: str = "Give me a character description.",
json_schema: t.Optional[str] = DEFAULT_SCHEMA,
) -> t.Dict[str, t.Any]:
import outlines

generator = outlines.generate.json(self.model, json_schema)
character = generator(prompt)

return character
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ nav:
- Summarize a document: cookbook/chain_of_density.md
- Playing chess: cookbook/models_playing_chess.md
- Run on the cloud:
- BentoML: cookbook/deploy-using-bentoml.md
- Modal: cookbook/deploy-using-modal.md
- Docs:
- reference/index.md
Expand Down

0 comments on commit 9188eff

Please sign in to comment.