diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
index ad408ca17..aab79c965 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
@@ -1,11 +1,21 @@
+import json
 from typing import Any, Dict, Optional, List, TYPE_CHECKING
+from transformers import (
+    Pipeline,
+    pipeline,
+)
+
 from aiconfig import ParameterizedModelParser, InferenceOptions
 from aiconfig.callback import CallbackEvent
-import torch
-from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
-
-from transformers import pipeline, Pipeline
-
+from aiconfig.schema import (
+    Attachment,
+    ExecuteResult,
+    Output,
+    OutputDataWithValue,
+    Prompt,
+)
+
+# Type-checking-only import to avoid a circular dependency
 if TYPE_CHECKING:
     from aiconfig import AIConfigRuntime
 
@@ -93,10 +103,11 @@ async def deserialize(
         await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))
 
         # Build Completion data
-        completion_params = self.get_model_settings(prompt, aiconfig)
+        model_settings = self.get_model_settings(prompt, aiconfig)
+        completion_params = refine_completion_params(model_settings)
 
+        # Add image inputs
         inputs = validate_and_retrieve_image_from_attachments(prompt)
-
         completion_params["inputs"] = inputs
 
         await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
@@ -110,24 +121,92 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
                 {"prompt": prompt, "options": options, "parameters": parameters},
             )
         )
-        model_name = aiconfig.get_model_name(prompt)
-
-        self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)
-        captioner = self.pipelines[model_name]
         completion_data = await self.deserialize(prompt, aiconfig, parameters)
         inputs = completion_data.pop("inputs")
-        model = completion_data.pop("model")
-        response = captioner(inputs, **completion_data)
-        output = ExecuteResult(output_type="execute_result", data=response, metadata={})
+        model_name: str | None = aiconfig.get_model_name(prompt)
+        if isinstance(model_name, str) and model_name not in self.pipelines:
+            self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)
+        captioner = self.pipelines[model_name]
+
+        outputs: List[Output] = []
+        response: List[Any] = captioner(inputs, **completion_data)
+        for count, result in enumerate(response):
+            output: Output = construct_regular_output(result, count)
+            outputs.append(output)
 
-        prompt.outputs = [output]
-        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
+        prompt.outputs = outputs
+        await aiconfig.callback_manager.run_callbacks(
+            CallbackEvent(
+                "on_run_complete",
+                __name__,
+                {"result": prompt.outputs},
+            )
+        )
         return prompt.outputs
 
-    def get_output_text(self, response: dict[str, Any]) -> str:
-        raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer")
+    def get_output_text(
+        self,
+        prompt: Prompt,
+        aiconfig: "AIConfigRuntime",
+        output: Optional[Output] = None,
+    ) -> str:
+        if output is None:
+            output = aiconfig.get_latest_output(prompt)
+
+        if output is None:
+            return ""
+
+        # TODO (rossdanlm): Handle multiple outputs in list
+        # https://github.com/lastmile-ai/aiconfig/issues/467
+        if output.output_type == "execute_result":
+            output_data = output.data
+            if isinstance(output_data, str):
+                return output_data
+            if isinstance(output_data, OutputDataWithValue):
+                if isinstance(output_data.value, str):
+                    return output_data.value
+                # HuggingFace image-to-text does not support function
+                # calls, so we shouldn't get here, but just being safe
+                return json.dumps(output_data.value, indent=2)
+        return ""
+
+
+def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Refines the completion params for the HF image-to-text API. Removes any unsupported params.
+    The supported keys were found by looking at the HF ImageToTextPipeline.__call__ method.
+    """
+    supported_keys = {
+        "max_new_tokens",
+        "timeout",
+    }
+
+    completion_data = {}
+    for key in model_settings:
+        if key.lower() in supported_keys:
+            completion_data[key.lower()] = model_settings[key]
+
+    return completion_data
+
+# Helper methods
+def construct_regular_output(result: List[Dict[str, str]], execution_count: int) -> Output:
+    """
+    Construct regular output per response result, without streaming enabled
+    """
+    output = ExecuteResult(
+        **{
+            "output_type": "execute_result",
+            # The result is always a list: we haven't found a way to return
+            # multiple sequences from the image-to-text pipeline, so take
+            # the first generated text
+            "data": result[0]["generated_text"],
+            "execution_count": execution_count,
+            "metadata": {},
+        }
+    )
+    return output
 
 
 def validate_attachment_type_is_image(attachment: Attachment):
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
index c6218c82e..4da5d7037 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
@@ -153,7 +153,7 @@ def __init__(self):
         config.register_model_parser(parser)
         """
         super().__init__()
-        self.generators : dict[str, Pipeline]= {}
+        self.generators: dict[str, Pipeline] = {}
 
     def id(self) -> str:
         """
@@ -217,14 +217,14 @@ async def deserialize(
         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)
         completion_data = refine_chat_completion_params(model_settings)
-
+        # Add resolved prompt
        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
         completion_data["prompt"] = resolved_prompt
         return completion_data
 
     async def run_inference(
-        self, prompt: Prompt, aiconfig : "AIConfigRuntime", options : InferenceOptions, parameters: Dict[str, Any]
+        self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]
     ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
@@ -239,8 +239,8 @@ async def run_inference(
         """
         completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
         completion_data["text_inputs"] = completion_data.pop("prompt", None)
-
-        model_name : str = aiconfig.get_model_name(prompt)
+
+        model_name: str | None = aiconfig.get_model_name(prompt)
         if isinstance(model_name, str) and model_name not in self.generators:
             self.generators[model_name] = pipeline('text-generation', model=model_name)
         generator = self.generators[model_name]
@@ -255,10 +255,10 @@ async def run_inference(
             streamer = TextIteratorStreamer(tokenizer)
             completion_data["streamer"] = streamer
 
-        outputs : List[Output] = []
+        outputs: List[Output] = []
         output = None
         if not should_stream:
-            response : List[Any] = generator(**completion_data)
+            response: List[Any] = generator(**completion_data)
             for count, result in enumerate(response):
                 output = construct_regular_output(result, count)
                 outputs.append(output)
@@ -267,7 +267,7 @@ async def run_inference(
                 raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1")
             if not streamer:
                 raise ValueError("Stream option is selected but streamer is not initialized")
-
+
             # For streaming, cannot call `generator` directly otherwise response will be blocking
             thread = threading.Thread(target=generator, kwargs=completion_data)
             thread.start()
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
index 32b90b908..2b3b61358 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
@@ -258,12 +258,10 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
             streamer = TextIteratorStreamer(tokenizer)
             completion_data["streamer"] = streamer
 
-        outputs: List[Output] = []
-        output = None
-
         def _summarize():
             return summarizer(inputs, **completion_data)
 
+        outputs: List[Output] = []
         if not should_stream:
             response: List[Any] = _summarize()
             for count, result in enumerate(response):
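
Note: the behavior of the two new image_2_text helpers can be sanity-checked without loading a model. A minimal sketch (the settings values below are hypothetical; the helpers are the ones added in this diff):

    # Hypothetical model settings: refine_completion_params drops unsupported
    # keys and lower-cases supported ones before they reach the pipeline.
    from aiconfig_extension_hugging_face.local_inference.image_2_text import (
        construct_regular_output,
        refine_completion_params,
    )

    settings = {"model": "Salesforce/blip-image-captioning-base", "MAX_NEW_TOKENS": 50, "quality": "high"}
    assert refine_completion_params(settings) == {"max_new_tokens": 50}

    # The image-to-text pipeline returns a list of generated sequences per
    # input image; construct_regular_output keeps the first generated_text.
    result = [{"generated_text": "a photo of a cat on a couch"}]
    output = construct_regular_output(result, execution_count=0)
    assert output.data == "a photo of a cat on a couch"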
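
Note: the streaming branch in text_generation depends on running the pipeline in a worker thread so the TextIteratorStreamer can be consumed while tokens are still being generated. A standalone sketch of that pattern ("gpt2" is only an illustrative model choice):

    import threading
    from transformers import AutoTokenizer, TextIteratorStreamer, pipeline

    generator = pipeline("text-generation", model="gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tokenizer)

    # Calling generator(...) directly would block until generation finishes;
    # running it on a thread lets us iterate the streamer as tokens arrive.
    thread = threading.Thread(
        target=generator,
        kwargs={"text_inputs": "Hello, world", "streamer": streamer, "max_new_tokens": 20},
    )
    thread.start()
    for chunk in streamer:
        print(chunk, end="")
    thread.join()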