Support custom specified models for HF model parsers #863

Merged · 1 commit · Jan 10, 2024

@@ -6,7 +6,9 @@
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser
from .local_inference.util import get_hf_model

UTILS = [get_hf_model]

LOCAL_INFERENCE_CLASSES = [
"HuggingFaceAutomaticSpeechRecognitionTransformer",
@@ -18,4 +20,4 @@
"HuggingFaceTextTranslationTransformer",
]
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES + UTILS
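
The `get_hf_model` helper imported above lives in `local_inference/util.py`, which is not shown in this diff. A minimal sketch of what it plausibly does, inferred only from the call sites later in the PR (`get_hf_model(aiconfig, prompt, self)` returning an optional model name, with callers falling back to a `"__default__"` cache key when it returns `None`); the actual implementation may differ:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from aiconfig import AIConfigRuntime
    from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
    from aiconfig.schema import Prompt


def get_hf_model(aiconfig: "AIConfigRuntime", prompt: "Prompt", parser: "ParameterizedModelParser") -> Optional[str]:
    """Resolve which Hugging Face checkpoint to load for a prompt.

    Prefers an explicit `model` entry in the prompt's model settings, then the
    model name registered in the aiconfig. Returns None when the resolved name
    is just the parser's id, so `pipeline(...)` falls back to the task default.
    """
    model_settings = parser.get_model_settings(prompt, aiconfig)  # hypothetical lookup order
    explicit_model = model_settings.get("model")
    registered_name = aiconfig.get_model_name(prompt)

    hf_model = explicit_model if explicit_model is not None else registered_name
    if hf_model == parser.id():
        # The prompt is bound only to the parser, not to a concrete checkpoint.
        return None
    return hf_model
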
@@ -2,11 +2,11 @@

import torch
from transformers import pipeline, Pipeline
from aiconfig_extension_hugging_face.local_inference.util import get_hf_model
from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment


if TYPE_CHECKING:
from aiconfig import AIConfigRuntime

@@ -96,15 +96,16 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio

model_settings = self.get_model_settings(prompt, aiconfig)
[pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings)
model_name = aiconfig.get_model_name(prompt)
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"

if isinstance(model_name, str) and model_name not in self.pipelines:
if key not in self.pipelines:
device = self._get_device()
if pipeline_creation_data.get("device", None) is None:
pipeline_creation_data["device"] = device
self.pipelines[model_name] = pipeline(task="automatic-speech-recognition", **pipeline_creation_data)
self.pipelines[key] = pipeline(task="automatic-speech-recognition", model=model_name, **pipeline_creation_data)

asr_pipeline = self.pipelines[model_name]
asr_pipeline = self.pipelines[key]
completion_data = await self.deserialize(prompt, aiconfig, parameters)

response = asr_pipeline(**completion_data)
@@ -234,8 +235,8 @@ def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -
Note: This doesn't support base pipeline params like `num_workers`
TODO: Figure out how to find which params are supported.

TODO: Distinguish pipeline creation and refine completion
https://github.com/lastmile-ai/aiconfig/issues/825
TODO: Distinguish pipeline creation and refine completion
https://github.com/lastmile-ai/aiconfig/issues/825
https://github.com/lastmile-ai/aiconfig/issues/824
"""

@@ -8,6 +8,8 @@
pipeline,
)

from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
from aiconfig.schema import (
@@ -108,7 +110,7 @@ async def deserialize(
model_settings = self.get_model_settings(prompt, aiconfig)
completion_params = refine_completion_params(model_settings)

#Add image inputs
# Add image inputs
inputs = validate_and_retrieve_images_from_attachments(prompt)
completion_params["inputs"] = inputs

@@ -127,10 +129,12 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
completion_data = await self.deserialize(prompt, aiconfig, parameters)
inputs = completion_data.pop("inputs")

model_name: str | None = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.pipelines:
self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)
captioner = self.pipelines[model_name]
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"

Member:
nit: double underscores are usually associated with dunder methods or special built-in values in Python. Would prefer a different value here to avoid any confusion.

Contributor:
can make a little enum class or something, like

class Sentinel(Enum):
    DEFAULT = "DEFAULT"

Then check for that value wherever needed and supply the real default value in a specific use case.
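
For illustration, a self-contained version of that suggestion; the `get_or_create_pipeline` helper and the enum member name are illustrative, not part of the PR:

from enum import Enum
from typing import Dict, Optional, Union

from transformers import Pipeline, pipeline


class Sentinel(Enum):
    DEFAULT = "DEFAULT"


def get_or_create_pipeline(
    pipelines: Dict[Union[str, Sentinel], Pipeline],
    model_name: Optional[str],
    task: str,
) -> Pipeline:
    # Use the sentinel instead of the magic "__default__" string as the cache key.
    key: Union[str, Sentinel] = model_name if model_name is not None else Sentinel.DEFAULT
    if key not in pipelines:
        # Passing model=None still lets transformers pick the task's default checkpoint.
        pipelines[key] = pipeline(task=task, model=model_name)
    return pipelines[key]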


if key not in self.pipelines:
self.pipelines[key] = pipeline(task="image-to-text", model=model_name)
captioner = self.pipelines[key]

outputs: List[Output] = []
response: List[Any] = captioner(inputs, **completion_data)
@@ -189,6 +193,7 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:

return completion_data


# Helper methods
def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
"""
@@ -198,7 +203,7 @@ def construct_regular_output(result: Dict[str, str], execution_count: int) -> Ou
**{
"output_type": "execute_result",
# For some reason result is always in list format we haven't found
# a way of being able to return multiple sequences from the image
# a way of being able to return multiple sequences from the image
# to text pipeline
"data": result[0]["generated_text"],
"execution_count": execution_count,
@@ -251,7 +256,7 @@ def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[
# vs. uri. This will be fixed once we have standardized inputs
# See https://github.com/lastmile-ai/aiconfig/issues/829
if len(input_data) > 10000:
pil_image : Image = Image.open(BytesIO(base64.b64decode(input_data)))
pil_image: Image = Image.open(BytesIO(base64.b64decode(input_data)))
images.append(pil_image)
else:
images.append(input_data)
@@ -11,6 +11,7 @@
from PIL import Image
from transformers import Pipeline

from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
@@ -124,11 +125,12 @@ def refine_image_completion_params(unfiltered_completion_params: Dict[str, Any])
return completion_params


class ImageData():
class ImageData:
"""
Helper class to store each image response data as fields instead
Helper class to store each image response data as fields instead
of separate arrays. See `_refine_responses` for more details
"""

image: Image.Image
nsfw_content_detected: bool

@@ -289,17 +291,17 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
"""
print(pipeline_building_disclaimer_message)

model_name: str = aiconfig.get_model_name(prompt)
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"

# TODO (rossdanlm): Figure out a way to save model and re-use checkpoint
# Otherwise right now a lot of these models are taking 5 mins to load with 50
# num_inference_steps (default value). See here for more details:
# https://huggingface.co/docs/diffusers/using-diffusers/loading#checkpoint-variants
if isinstance(model_name, str) and model_name not in self.generators:
if key not in self.generators:
device = self._get_device()
self.generators[model_name] = AutoPipelineForText2Image.from_pretrained(pretrained_model_or_path=model_name, **pipeline_creation_data).to(
device
)
generator = self.generators[model_name]
self.generators[key] = AutoPipelineForText2Image.from_pretrained(pretrained_model_or_path=model_name, **pipeline_creation_data).to(device)
generator = self.generators[key]

disclaimer_long_response_print_message = """\n
Calling image generation. This can take a long time, (up to SEVERAL MINUTES depending
@@ -370,21 +372,22 @@ def _get_device(self) -> str:
return "mps"
return "cpu"


def _refine_responses(
response_images: List[Image.Image],
response_images: List[Image.Image],
nsfw_content_detected: List[bool],
) -> List[ImageData]:
"""
Helper function for taking the separate response data lists (`images` and
`nsfw_content_detected`) from StableDiffusionPipelineOutput or
`nsfw_content_detected`) from StableDiffusionPipelineOutput or
StableDiffusionXLPipelineOutput and merging this data into a single array
containing ImageData which stores information at the image-level. This
makes processing later easier since all the data we need is stored in a
containing ImageData which stores information at the image-level. This
makes processing later easier since all the data we need is stored in a
single object, so we don't need to compare two separate lists

Args:
response_images List[Image.Image]: List of images
nsfw_content_detected List[bool]: List of whether the image at that
nsfw_content_detected List[bool]: List of whether the image at that
corresponding index from `response_images` has detected that it
contains nsfw_content. It is possible for this list to be empty

@@ -396,8 +399,5 @@ def _refine_responses(
# Use zip.longest because nsfw_content_detected can be empty
itertools.zip_longest(response_images, nsfw_content_detected)
)
image_data_objects: List[ImageData] = [
ImageData(image=image, nsfw_content_detected=has_nsfw)
for (image, has_nsfw) in merged_responses
]
image_data_objects: List[ImageData] = [ImageData(image=image, nsfw_content_detected=has_nsfw) for (image, has_nsfw) in merged_responses]
return image_data_objects
@@ -3,10 +3,12 @@
import io
import json
import numpy as np
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import Pipeline, pipeline
from scipy.io.wavfile import write as write_wav

from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
@@ -24,7 +26,7 @@

# Step 1: define Helpers
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
# These are from the transformers Github repo:
# These are from the transformers Github repo:
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2534
supported_keys = {
"torch_dtype",
@@ -63,7 +65,7 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict

def refine_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
# Note: There seems to be no public API docs on what completion
# params are supported for text to speech:
# params are supported for text to speech:
# https://huggingface.co/docs/transformers/tasks/text-to-speech#inference
# The only one mentioned is `forward_params` which can contain `speaker_embeddings`
supported_keys = {}
@@ -194,10 +196,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
model_settings = self.get_model_settings(prompt, aiconfig)
[pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings)

model_name: str = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.synthesizers:
self.synthesizers[model_name] = pipeline("text-to-speech", model_name)
synthesizer = self.synthesizers[model_name]
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"
if key not in self.synthesizers:
self.synthesizers[key] = pipeline("text-to-speech", model=model_name)
synthesizer = self.synthesizers[key]

completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
inputs = completion_data.pop("prompt", None)
@@ -9,6 +9,8 @@
TextIteratorStreamer,
)

from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
@@ -79,7 +81,7 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
"encoder_no_repeat_ngram_size",
"decoder_start_token_id",
"num_assistant_tokens",
"num_assistant_tokens_schedule"
"num_assistant_tokens_schedule",
}

completion_data = {}
@@ -104,6 +106,7 @@ def construct_regular_output(result: Dict[str, str], execution_count: int) -> Ou
)
return output


def construct_stream_output(
streamer: TextIteratorStreamer,
options: InferenceOptions,
@@ -119,13 +122,13 @@

"""
output = ExecuteResult(
**{
"output_type": "execute_result",
"data": "", # We update this below
"execution_count": 0, #Multiple outputs are not supported for streaming
"metadata": {},
}
)
**{
"output_type": "execute_result",
"data": "", # We update this below
"execution_count": 0, # Multiple outputs are not supported for streaming
"metadata": {},
}
)
accumulated_message = ""
for new_text in streamer:
if isinstance(new_text, str):
@@ -152,7 +155,7 @@ def __init__(self):
config.register_model_parser(parser)
"""
super().__init__()
self.generators: dict[str, Pipeline]= {}
self.generators: dict[str, Pipeline] = {}

def id(self) -> str:
"""
@@ -190,9 +193,7 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)
return [prompt]

@@ -217,14 +218,12 @@ async def deserialize(
model_settings = self.get_model_settings(prompt, aiconfig)
completion_data = refine_completion_params(model_settings)

#Add resolved prompt
# Add resolved prompt
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
completion_data["prompt"] = resolved_prompt
return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.
@@ -239,16 +238,15 @@ async def run_inference(
completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
completion_data["text_inputs"] = completion_data.pop("prompt", None)

model_name: str | None = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.generators:
self.generators[model_name] = pipeline('text-generation', model=model_name)
generator = self.generators[model_name]
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"
if key not in self.generators:
self.generators[key] = pipeline("text-generation", model=model_name)
generator = self.generators[key]

# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
streamer = None
should_stream = (options.stream if options else False) and (
not "stream" in completion_data or completion_data.get("stream") != False
)
should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
if should_stream:
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)
@@ -9,6 +9,8 @@
TextIteratorStreamer,
)

from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
@@ -242,16 +244,15 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
inputs = completion_data.pop("prompt", None)

model_name: str = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.summarizers:
self.summarizers[model_name] = pipeline("summarization", model=model_name)
summarizer = self.summarizers[model_name]
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"

Contributor:

same here

if key not in self.summarizers:
self.summarizers[key] = pipeline("summarization", model=model_name)
summarizer = self.summarizers[key]

# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
streamer = None
should_stream = (options.stream if options else False) and (
not "stream" in completion_data or completion_data.get("stream") != False
)
should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)

Contributor:

Not for this diff, but this stream logic should be cleaned up.

Contributor Author:

+1

Contributor:

I have a bootcamp issue to track this in #861.
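
Not part of this diff, but one possible shape for that cleanup (a sketch assuming `options.stream` and an optional boolean `stream` entry in the completion params, as in the current code; the helper name is illustrative):

from typing import Any, Dict, Optional

from aiconfig.model_parser import InferenceOptions


def should_stream_output(options: Optional[InferenceOptions], completion_data: Dict[str, Any]) -> bool:
    """Stream only when the runtime options request it and the prompt config does not disable it."""
    stream_requested = bool(options.stream) if options else False
    stream_allowed = completion_data.get("stream", True) is not False
    return stream_requested and stream_allowed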

if should_stream:
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)