diff --git a/CHANGELOG.md b/CHANGELOG.md index 74e3dc5895..e7deed5df0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,21 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Nothing is unreleased! +## [1.0.400] - 2023-03-06 + +### Added + +- OpenAI integration + +### Fixed + +- Langchain final answer streaming should work again +- Elements with public URLs should be correctly persisted by the data layer + +### Changed + +- Enforce UTC DateTimes + ## [1.0.300] - 2023-02-19 ### Added diff --git a/backend/chainlit/element.py b/backend/chainlit/element.py index 06fb90dbcc..73351549ac 100644 --- a/backend/chainlit/element.py +++ b/backend/chainlit/element.py @@ -123,14 +123,14 @@ def from_dict(self, _dict: FileDict): ) async def _create(self) -> bool: - if (self.persisted or self.url) and not self.updatable: + if self.persisted and not self.updatable: return True if data_layer := get_data_layer(): try: asyncio.create_task(data_layer.create_element(self)) except Exception as e: logger.error(f"Failed to create element: {str(e)}") - if not self.chainlit_key or self.updatable: + if not self.url and (not self.chainlit_key or self.updatable): file_dict = await context.session.persist_file( name=self.name, path=self.path, @@ -203,6 +203,7 @@ class Text(Element): class Pdf(Element): """Useful to send a pdf to the UI.""" + mime: str = "application/pdf" page: Optional[int] = None type: ClassVar[ElementType] = "pdf" diff --git a/backend/chainlit/emitter.py b/backend/chainlit/emitter.py index 6cde4b5e2b..6d0cd38dbf 100644 --- a/backend/chainlit/emitter.py +++ b/backend/chainlit/emitter.py @@ -1,6 +1,5 @@ import asyncio import uuid -from datetime import datetime from typing import Any, Dict, List, Literal, Optional, Union, cast from chainlit.data import get_data_layer @@ -18,6 +17,7 @@ UIMessagePayload, ) from chainlit.user import PersistedUser +from literalai.helper import utc_now from socketio.exceptions import TimeoutError @@ -196,7 +196,7 @@ async def process_user_message(self, payload: UIMessagePayload): message = Message.from_dict(step_dict) # Overwrite the created_at timestamp with the current time - message.created_at = datetime.utcnow().isoformat() + message.created_at = utc_now() asyncio.create_task(message._create()) diff --git a/backend/chainlit/haystack/callbacks.py b/backend/chainlit/haystack/callbacks.py index 9776a7a9bc..e2232e3dc2 100644 --- a/backend/chainlit/haystack/callbacks.py +++ b/backend/chainlit/haystack/callbacks.py @@ -1,13 +1,14 @@ -from datetime import datetime -from typing import Any, Generic, List, Optional, TypeVar import re +from typing import Any, Generic, List, Optional, TypeVar from chainlit.context import context from chainlit.step import Step from chainlit.sync import run_sync -from chainlit import Message from haystack.agents import Agent, Tool from haystack.agents.agent_step import AgentStep +from literalai.helper import utc_now + +from chainlit import Message T = TypeVar("T") @@ -36,7 +37,12 @@ class HaystackAgentCallbackHandler: stack: Stack[Step] last_step: Optional[Step] - def __init__(self, agent: Agent, stream_final_answer: bool = False, stream_final_answer_agent_name: str = 'Agent'): + def __init__( + self, + agent: Agent, + stream_final_answer: bool = False, + stream_final_answer_agent_name: str = "Agent", + ): agent.callback_manager.on_agent_start += self.on_agent_start agent.callback_manager.on_agent_step += self.on_agent_step agent.callback_manager.on_agent_finish += self.on_agent_finish @@ -56,14 +62,16 @@ def on_agent_start(self, **kwargs: Any) -> None: self.stack = Stack[Step]() if self.stream_final_answer: - self.final_stream = Message(author=self.stream_final_answer_agent_name, content="") + self.final_stream = Message( + author=self.stream_final_answer_agent_name, content="" + ) self.last_tokens: List[str] = [] self.answer_reached = False root_message = context.session.root_message parent_id = root_message.id if root_message else None run_step = Step(name=self.agent_name, type="run", parent_id=parent_id) - run_step.start = datetime.utcnow().isoformat() + run_step.start = utc_now() run_step.input = kwargs run_sync(run_step.send()) @@ -73,7 +81,7 @@ def on_agent_start(self, **kwargs: Any) -> None: def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None: if self.last_step: run_step = self.last_step - run_step.end = datetime.utcnow().isoformat() + run_step.end = utc_now() run_step.output = agent_step.prompt_node_response run_sync(run_step.update()) @@ -85,7 +93,7 @@ def on_agent_step(self, agent_step: AgentStep, **kwargs: Any) -> None: # If token streaming is disabled if self.last_step.output == "": self.last_step.output = agent_step.prompt_node_response - self.last_step.end = datetime.utcnow().isoformat() + self.last_step.end = utc_now() run_sync(self.last_step.update()) if not agent_step.is_last(): @@ -101,12 +109,16 @@ def on_new_token(self, token, **kwargs: Any) -> None: else: self.last_tokens.append(token) - last_tokens_concat = ''.join(self.last_tokens) - final_answer_match = re.search(self.final_answer_pattern, last_tokens_concat) + last_tokens_concat = "".join(self.last_tokens) + final_answer_match = re.search( + self.final_answer_pattern, last_tokens_concat + ) if final_answer_match: self.answer_reached = True - run_sync(self.final_stream.stream_token(final_answer_match.group(1))) + run_sync( + self.final_stream.stream_token(final_answer_match.group(1)) + ) run_sync(self.stack.peek().stream_token(token)) @@ -115,7 +127,7 @@ def on_tool_start(self, tool_input: str, tool: Tool, **kwargs: Any) -> None: parent_id = self.stack.items[0].id if self.stack.items[0] else None tool_step = Step(name=tool.name, type="tool", parent_id=parent_id) tool_step.input = tool_input - tool_step.start = datetime.utcnow().isoformat() + tool_step.start = utc_now() self.stack.push(tool_step) def on_tool_finish( @@ -128,7 +140,7 @@ def on_tool_finish( # Tool finished, send step with tool_result tool_step = self.stack.pop() tool_step.output = tool_result - tool_step.end = datetime.utcnow().isoformat() + tool_step.end = utc_now() run_sync(tool_step.update()) def on_tool_error(self, exception: Exception, tool: Tool, **kwargs: Any) -> None: @@ -136,5 +148,5 @@ def on_tool_error(self, exception: Exception, tool: Tool, **kwargs: Any) -> None error_step = self.stack.pop() error_step.is_error = True error_step.output = str(exception) - error_step.end = datetime.utcnow().isoformat() + error_step.end = utc_now() run_sync(error_step.update()) diff --git a/backend/chainlit/langchain/callbacks.py b/backend/chainlit/langchain/callbacks.py index da544e66f7..81c6fb52c9 100644 --- a/backend/chainlit/langchain/callbacks.py +++ b/backend/chainlit/langchain/callbacks.py @@ -1,6 +1,5 @@ import json import time -from datetime import datetime from typing import Any, Dict, List, Optional, TypedDict, Union from uuid import UUID @@ -13,6 +12,7 @@ from langchain.schema.output import ChatGenerationChunk, GenerationChunk from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from literalai import ChatGeneration, CompletionGeneration, GenerationMessage +from literalai.helper import utc_now from literalai.step import TrueStepType DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"] @@ -180,6 +180,10 @@ def _convert_message( content="", ) + if literal_uuid := message.additional_kwargs.get("uuid"): + msg["uuid"] = literal_uuid + msg["templated"] = True + if name := getattr(message, "name", None): msg["name"] = name @@ -353,6 +357,18 @@ def on_llm_new_token( if start["tt_first_token"] is None: start["tt_first_token"] = (time.time() - start["start"]) * 1000 + if self.stream_final_answer: + self._append_to_last_tokens(token) + + if self.answer_reached: + if not self.final_stream: + self.final_stream = Message(content="") + self._run_sync(self.final_stream.send()) + self._run_sync(self.final_stream.stream_token(token)) + self.has_streamed_final_answer = True + else: + self.answer_reached = self._check_if_answer_reached() + return super().on_llm_new_token( token, chunk=chunk, @@ -458,7 +474,7 @@ def _start_trace(self, run: Run) -> None: parent_id=parent_id, disable_feedback=disable_feedback, ) - step.start = datetime.utcnow().isoformat() + step.start = utc_now() step.input = run.inputs self.steps[str(run.id)] = step @@ -505,6 +521,15 @@ def _on_run_update(self, run: Run) -> None: ], message_completion=message_completion, ) + + # find first message with prompt_id + for m in chat_start["input_messages"]: + if m.additional_kwargs.get("prompt_id"): + current_step.generation.prompt_id = m.additional_kwargs["prompt_id"] + if custom_variables := m.additional_kwargs.get("variables"): + current_step.generation.variables = custom_variables + break + current_step.language = "json" current_step.output = json.dumps(message_completion) else: @@ -529,7 +554,7 @@ def _on_run_update(self, run: Run) -> None: current_step.output = completion if current_step: - current_step.end = datetime.utcnow().isoformat() + current_step.end = utc_now() self._run_sync(current_step.update()) if self.final_stream and self.has_streamed_final_answer: @@ -547,7 +572,7 @@ def _on_run_update(self, run: Run) -> None: if current_step: current_step.output = output - current_step.end = datetime.utcnow().isoformat() + current_step.end = utc_now() self._run_sync(current_step.update()) def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any): @@ -556,7 +581,7 @@ def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any): if current_step := self.steps.get(str(run_id), None): current_step.is_error = True current_step.output = str(error) - current_step.end = datetime.utcnow().isoformat() + current_step.end = utc_now() self._run_sync(current_step.update()) on_llm_error = _on_error diff --git a/backend/chainlit/llama_index/callbacks.py b/backend/chainlit/llama_index/callbacks.py index d5f6ba6115..12735f7a1e 100644 --- a/backend/chainlit/llama_index/callbacks.py +++ b/backend/chainlit/llama_index/callbacks.py @@ -1,10 +1,10 @@ -from datetime import datetime from typing import Any, Dict, List, Optional from chainlit.context import context_var from chainlit.element import Text from chainlit.step import Step, StepType from literalai import ChatGeneration, CompletionGeneration, GenerationMessage +from literalai.helper import utc_now from llama_index.callbacks import TokenCountingHandler from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.llms.base import ChatMessage, ChatResponse, CompletionResponse @@ -87,7 +87,7 @@ def on_event_start( disable_feedback=False, ) self.steps[event_id] = step - step.start = datetime.utcnow().isoformat() + step.start = utc_now() step.input = payload or {} self.context.loop.create_task(step.send()) return event_id @@ -107,7 +107,7 @@ def on_event_end( self._restore_context() - step.end = datetime.utcnow().isoformat() + step.end = utc_now() if event_type == CBEventType.RETRIEVE: sources = payload.get(EventPayload.NODES) diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index dbf026b210..c9f4430a71 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -3,7 +3,6 @@ import time import uuid from abc import ABC -from datetime import datetime from typing import Dict, List, Optional, Union, cast from chainlit.action import Action @@ -23,6 +22,7 @@ FileDict, ) from literalai import BaseGeneration +from literalai.helper import utc_now from literalai.step import MessageStepType @@ -149,7 +149,7 @@ async def _create(self): async def send(self): if not self.created_at: - self.created_at = datetime.utcnow().isoformat() + self.created_at = utc_now() if self.content is None: self.content = "" @@ -367,7 +367,7 @@ async def send(self) -> Union[StepDict, None]: """ trace_event("send_ask_user") if not self.created_at: - self.created_at = datetime.utcnow().isoformat() + self.created_at = utc_now() if config.code.author_rename: self.author = await config.code.author_rename(self.author) @@ -439,7 +439,7 @@ async def send(self) -> Union[List[AskFileResponse], None]: trace_event("send_ask_file") if not self.created_at: - self.created_at = datetime.utcnow().isoformat() + self.created_at = utc_now() if self.streaming: self.streaming = False @@ -512,7 +512,7 @@ async def send(self) -> Union[AskActionResponse, None]: trace_event("send_ask_action") if not self.created_at: - self.created_at = datetime.utcnow().isoformat() + self.created_at = utc_now() if self.streaming: self.streaming = False diff --git a/backend/chainlit/openai/__init__.py b/backend/chainlit/openai/__init__.py index 8f2cac1c03..9757cd1a19 100644 --- a/backend/chainlit/openai/__init__.py +++ b/backend/chainlit/openai/__init__.py @@ -1,10 +1,11 @@ -from typing import TYPE_CHECKING, Union +from typing import Union from chainlit.context import get_context from chainlit.step import Step from chainlit.sync import run_sync from chainlit.utils import check_module_version from literalai import ChatGeneration, CompletionGeneration +from literalai.helper import timestamp_utc def instrument_openai(): @@ -34,12 +35,14 @@ async def on_new_generation( step.generation = generation # Convert start/end time from seconds to milliseconds step.start = ( - timing.get("start") * 1000 + timestamp_utc(timing.get("start")) if timing.get("start", None) is not None else None ) step.end = ( - timing.get("end") * 1000 if timing.get("end", None) is not None else None + timestamp_utc(timing.get("end")) + if timing.get("end", None) is not None + else None ) if isinstance(generation, ChatGeneration): diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 8987033b01..773bd1434a 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -2,7 +2,6 @@ import json import time import uuid -from datetime import datetime from typing import Any, Dict, Literal from chainlit.action import Action diff --git a/backend/chainlit/step.py b/backend/chainlit/step.py index edf330f77b..9ce9752304 100644 --- a/backend/chainlit/step.py +++ b/backend/chainlit/step.py @@ -3,7 +3,6 @@ import json import time import uuid -from datetime import datetime from functools import wraps from typing import Callable, Dict, List, Optional, TypedDict, Union @@ -15,6 +14,7 @@ from chainlit.telemetry import trace_event from chainlit.types import FeedbackDict from literalai import BaseGeneration +from literalai.helper import utc_now from literalai.step import StepType, TrueStepType @@ -177,7 +177,7 @@ def __init__( self.generation = None self.elements = elements or [] - self.created_at = datetime.utcnow().isoformat() + self.created_at = utc_now() self.start = None self.end = None @@ -372,7 +372,7 @@ def __call__(self, func): # Handle Context Manager Protocol async def __aenter__(self): - self.start = datetime.utcnow().isoformat() + self.start = utc_now() previous_steps = local_steps.get() or [] parent_step = previous_steps[-1] if previous_steps else None @@ -387,7 +387,7 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - self.end = datetime.utcnow().isoformat() + self.end = utc_now() if self in context.active_steps: context.active_steps.remove(self) @@ -400,7 +400,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.update() def __enter__(self): - self.start = datetime.utcnow().isoformat() + self.start = utc_now() previous_steps = local_steps.get() or [] parent_step = previous_steps[-1] if previous_steps else None @@ -417,7 +417,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - self.end = datetime.utcnow().isoformat() + self.end = utc_now() if self in context.active_steps: context.active_steps.remove(self) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 9ea0474268..44226a9656 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "chainlit" -version = "1.0.301" +version = "1.0.400" keywords = ['LLM', 'Agents', 'gen ai', 'chat ui', 'chatbot ui', 'openai', 'copilot', 'langchain', 'conversational ai'] description = "Build Conversational AI." authors = ["Chainlit"] @@ -23,7 +23,7 @@ chainlit = 'chainlit.cli:cli' [tool.poetry.dependencies] python = ">=3.8.1,<4.0.0" httpx = ">=0.23.0" -literalai = "0.0.204" +literalai = "0.0.300" dataclasses_json = "^0.5.7" fastapi = ">=0.100" # Starlette >= 0.33.0 breaks socketio (alway 404) diff --git a/cypress/e2e/data_layer/main.py b/cypress/e2e/data_layer/main.py index 82efde08f8..9518adbb51 100644 --- a/cypress/e2e/data_layer/main.py +++ b/cypress/e2e/data_layer/main.py @@ -1,12 +1,12 @@ -from datetime import datetime from typing import List, Optional import chainlit.data as cl_data from chainlit.step import StepDict +from literalai.helper import utc_now import chainlit as cl -now = datetime.utcnow().isoformat() +now = utc_now() create_step_counter = 0