diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py
index 8e61f96628..093245d31c 100644
--- a/backend/chainlit/__init__.py
+++ b/backend/chainlit/__init__.py
@@ -18,6 +18,7 @@
         AsyncLangchainCallbackHandler,
     )
     from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler
+    from chainlit.openai import instrument_openai
 
 import chainlit.input_widget as input_widget
 from chainlit.action import Action
@@ -310,6 +311,7 @@ def acall(self):
         "AsyncLangchainCallbackHandler": "chainlit.langchain.callbacks",
         "LlamaIndexCallbackHandler": "chainlit.llama_index.callbacks",
         "HaystackAgentCallbackHandler": "chainlit.haystack.callbacks",
+        "instrument_openai": "chainlit.openai",
     }
 )
 
@@ -362,6 +364,7 @@ def acall(self):
     "AsyncLangchainCallbackHandler",
     "LlamaIndexCallbackHandler",
     "HaystackAgentCallbackHandler",
+    "instrument_openai",
 ]
 
 
diff --git a/backend/chainlit/openai/__init__.py b/backend/chainlit/openai/__init__.py
new file mode 100644
index 0000000000..8f2cac1c03
--- /dev/null
+++ b/backend/chainlit/openai/__init__.py
@@ -0,0 +1,59 @@
+from typing import TYPE_CHECKING, Union
+
+from chainlit.context import get_context
+from chainlit.step import Step
+from chainlit.sync import run_sync
+from chainlit.utils import check_module_version
+from literalai import ChatGeneration, CompletionGeneration
+
+
+def instrument_openai():
+    if not check_module_version("openai", "1.0.0"):
+        raise ValueError(
+            "Expected OpenAI version >= 1.0.0. Run `pip install openai --upgrade`"
+        )
+
+    from literalai.instrumentation.openai import instrument_openai
+
+    async def on_new_generation(
+        generation: Union["ChatGeneration", "CompletionGeneration"], timing
+    ):
+        context = get_context()
+
+        parent_id = None
+        if context.current_step:
+            parent_id = context.current_step.id
+        elif context.session.root_message:
+            parent_id = context.session.root_message.id
+
+        step = Step(
+            name=generation.model if generation.model else generation.provider,
+            type="llm",
+            parent_id=parent_id,
+        )
+        step.generation = generation
+        # Convert start/end time from seconds to milliseconds
+        step.start = (
+            timing.get("start") * 1000
+            if timing.get("start", None) is not None
+            else None
+        )
+        step.end = (
+            timing.get("end") * 1000 if timing.get("end", None) is not None else None
+        )
+
+        if isinstance(generation, ChatGeneration):
+            step.input = generation.messages
+            step.output = generation.message_completion  # type: ignore
+        else:
+            step.input = generation.prompt
+            step.output = generation.completion
+
+        await step.send()
+
+    def on_new_generation_sync(
+        generation: Union["ChatGeneration", "CompletionGeneration"], timing
+    ):
+        run_sync(on_new_generation(generation, timing))
+
+    instrument_openai(None, on_new_generation_sync)
diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py
index 5d9ad938a0..8987033b01 100644
--- a/backend/chainlit/socket.py
+++ b/backend/chainlit/socket.py
@@ -1,5 +1,6 @@
 import asyncio
 import json
+import time
 import uuid
 from datetime import datetime
 from typing import Any, Dict, Literal
@@ -149,7 +150,6 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
         thread_id=environ.get("HTTP_X_CHAINLIT_THREAD_ID"),
     )
 
-    trace_event("connection_successful")
     return True
 
 
@@ -237,6 +237,8 @@ async def process_message(session: WebsocketSession, payload: UIMessagePayload):
         message = await context.emitter.process_user_message(payload)
 
         if config.code.on_message:
+            # Sleep 1ms to make sure any children step starts after the message step start
+            time.sleep(0.001)
             await config.code.on_message(message)
     except InterruptedError:
         pass