prepare release #794

Merged
merged 4 commits on Mar 6, 2024
15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,21 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

Nothing is unreleased!

## [1.0.400] - 2024-03-06

### Added

- OpenAI integration

### Fixed

- Langchain final answer streaming should work again
- Elements with public URLs should be correctly persisted by the data layer

### Changed

- Enforce UTC DateTimes

## [1.0.300] - 2024-02-19

### Added
5 changes: 3 additions & 2 deletions backend/chainlit/element.py
@@ -123,14 +123,14 @@ def from_dict(self, _dict: FileDict):
)

async def _create(self) -> bool:
- if (self.persisted or self.url) and not self.updatable:
+ if self.persisted and not self.updatable:
return True
if data_layer := get_data_layer():
try:
asyncio.create_task(data_layer.create_element(self))
except Exception as e:
logger.error(f"Failed to create element: {str(e)}")
- if not self.chainlit_key or self.updatable:
+ if not self.url and (not self.chainlit_key or self.updatable):
file_dict = await context.session.persist_file(
name=self.name,
path=self.path,
@@ -203,6 +203,7 @@ class Text(Element):
class Pdf(Element):
"""Useful to send a pdf to the UI."""

+ mime: str = "application/pdf"
page: Optional[int] = None
type: ClassVar[ElementType] = "pdf"

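This is the fix for "Elements with public URLs should be correctly persisted by the data layer" from the changelog: previously an element with a `url` returned early from `_create` and never reached `data_layer.create_element`; now only already-persisted elements short-circuit, and URL-backed elements merely skip the `persist_file` upload. A usage sketch under those semantics (the element names, URL, and path below are placeholders for illustration):

```python
import chainlit as cl

@cl.on_message
async def main(message: cl.Message):
    # URL-backed element: no local payload, so persist_file is skipped,
    # but after this fix it is still recorded via data_layer.create_element.
    remote = cl.Image(name="logo", url="https://example.com/logo.png", display="inline")

    # Path-backed element: uploaded through persist_file as before.
    local = cl.Image(name="chart", path="./chart.png", display="inline")

    await cl.Message(content="Two elements", elements=[remote, local]).send()
```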
4 changes: 2 additions & 2 deletions backend/chainlit/emitter.py
@@ -1,6 +1,5 @@
import asyncio
import uuid
- from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union, cast

from chainlit.data import get_data_layer
@@ -18,6 +17,7 @@
UIMessagePayload,
)
from chainlit.user import PersistedUser
+ from literalai.helper import utc_now
from socketio.exceptions import TimeoutError


@@ -196,7 +196,7 @@ async def process_user_message(self, payload: UIMessagePayload):

message = Message.from_dict(step_dict)
# Overwrite the created_at timestamp with the current time
- message.created_at = datetime.utcnow().isoformat()
+ message.created_at = utc_now()

asyncio.create_task(message._create())

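Every `datetime.utcnow().isoformat()` in this PR becomes `utc_now()` from the literalai helper, which is what the changelog's "Enforce UTC DateTimes" refers to. The actual helper ships with the literalai package; assuming it produces a timezone-aware ISO-8601 string, a minimal equivalent would be:

```python
from datetime import datetime, timezone

def utc_now() -> str:
    # datetime.utcnow() returns a naive datetime with no tzinfo attached;
    # datetime.now(timezone.utc) carries an explicit UTC offset, e.g.
    # "2024-03-06T12:34:56.789012+00:00".
    return datetime.now(timezone.utc).isoformat()
```

Centralizing the call in one helper also means the timestamp format can be changed in a single place instead of at every call site.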
40 changes: 26 additions & 14 deletions backend/chainlit/haystack/callbacks.py
@@ -1,13 +1,14 @@
- from datetime import datetime
- from typing import Any, Generic, List, Optional, TypeVar
import re
+ from typing import Any, Generic, List, Optional, TypeVar

from chainlit.context import context
from chainlit.step import Step
from chainlit.sync import run_sync
- from chainlit import Message
from haystack.agents import Agent, Tool
from haystack.agents.agent_step import AgentStep
+ from literalai.helper import utc_now

+ from chainlit import Message

T = TypeVar("T")

@@ -36,7 +37,12 @@ class HaystackAgentCallbackHandler:
stack: Stack[Step]
last_step: Optional[Step]

- def __init__(self, agent: Agent, stream_final_answer: bool = False, stream_final_answer_agent_name: str = 'Agent'):
+ def __init__(
+     self,
+     agent: Agent,
+     stream_final_answer: bool = False,
+     stream_final_answer_agent_name: str = "Agent",
+ ):
agent.callback_manager.on_agent_start += self.on_agent_start
agent.callback_manager.on_agent_step += self.on_agent_step
agent.callback_manager.on_agent_finish += self.on_agent_finish
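The handler's `stack: Stack[Step]` attribute is declared above, but the `Stack` class itself lies outside this diff. A minimal generic stack consistent with the `push`/`pop`/`peek`/`items` calls used below might look like this (a hypothetical reconstruction, not the actual chainlit code):

```python
from typing import Generic, List, TypeVar

T = TypeVar("T")

class Stack(Generic[T]):
    """LIFO container mirroring the usage in HaystackAgentCallbackHandler."""

    def __init__(self) -> None:
        self.items: List[T] = []

    def push(self, item: T) -> None:
        self.items.append(item)

    def pop(self) -> T:
        return self.items.pop()

    def peek(self) -> T:
        return self.items[-1]
```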
@@ -56,14 +62,16 @@ def on_agent_start(self, **kwargs: Any) -> None:
self.stack = Stack[Step]()

if self.stream_final_answer:
- self.final_stream = Message(author=self.stream_final_answer_agent_name, content="")
+ self.final_stream = Message(
+     author=self.stream_final_answer_agent_name, content=""
+ )
self.last_tokens: List[str] = []
self.answer_reached = False

root_message = context.session.root_message
parent_id = root_message.id if root_message else None
run_step = Step(name=self.agent_name, type="run", parent_id=parent_id)
- run_step.start = datetime.utcnow().isoformat()
+ run_step.start = utc_now()
run_step.input = kwargs

run_sync(run_step.send())
@@ -73,7 +81,7 @@ def on_agent_start(self, **kwargs: Any) -> None:
def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None:
if self.last_step:
run_step = self.last_step
- run_step.end = datetime.utcnow().isoformat()
+ run_step.end = utc_now()
run_step.output = agent_step.prompt_node_response
run_sync(run_step.update())

@@ -85,7 +93,7 @@ def on_agent_step(self, agent_step: AgentStep, **kwargs: Any) -> None:
# If token streaming is disabled
if self.last_step.output == "":
self.last_step.output = agent_step.prompt_node_response
- self.last_step.end = datetime.utcnow().isoformat()
+ self.last_step.end = utc_now()
run_sync(self.last_step.update())

if not agent_step.is_last():
@@ -101,12 +109,16 @@ def on_new_token(self, token, **kwargs: Any) -> None:
else:
self.last_tokens.append(token)

- last_tokens_concat = ''.join(self.last_tokens)
- final_answer_match = re.search(self.final_answer_pattern, last_tokens_concat)
+ last_tokens_concat = "".join(self.last_tokens)
+ final_answer_match = re.search(
+     self.final_answer_pattern, last_tokens_concat
+ )

if final_answer_match:
self.answer_reached = True
- run_sync(self.final_stream.stream_token(final_answer_match.group(1)))
+ run_sync(
+     self.final_stream.stream_token(final_answer_match.group(1))
+ )

run_sync(self.stack.peek().stream_token(token))

@@ -115,7 +127,7 @@ def on_tool_start(self, tool_input: str, tool: Tool, **kwargs: Any) -> None:
parent_id = self.stack.items[0].id if self.stack.items[0] else None
tool_step = Step(name=tool.name, type="tool", parent_id=parent_id)
tool_step.input = tool_input
- tool_step.start = datetime.utcnow().isoformat()
+ tool_step.start = utc_now()
self.stack.push(tool_step)

def on_tool_finish(
@@ -128,13 +140,13 @@ def on_tool_finish(
# Tool finished, send step with tool_result
tool_step = self.stack.pop()
tool_step.output = tool_result
- tool_step.end = datetime.utcnow().isoformat()
+ tool_step.end = utc_now()
run_sync(tool_step.update())

def on_tool_error(self, exception: Exception, tool: Tool, **kwargs: Any) -> None:
# Tool error, send error message
error_step = self.stack.pop()
error_step.is_error = True
error_step.output = str(exception)
- error_step.end = datetime.utcnow().isoformat()
+ error_step.end = utc_now()
run_sync(error_step.update())
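The final-answer streaming in `on_new_token` above buffers tokens and re-runs a regex over their concatenation until the answer prefix appears, then streams the capture group and every subsequent token. A standalone sketch of the technique; the pattern below is an assumption, since the handler's `final_answer_pattern` is defined elsewhere in the class:

```python
import re
from typing import List

# Assumed pattern: capture everything after "Final Answer:"; the real
# handler's final_answer_pattern is set outside this diff.
FINAL_ANSWER_PATTERN = r"Final Answer\s*:\s*(.*)"

def stream_final_answer(tokens: List[str]) -> None:
    last_tokens: List[str] = []
    answer_reached = False
    for token in tokens:
        if answer_reached:
            print(token, end="")  # stand-in for final_stream.stream_token
            continue
        last_tokens.append(token)
        match = re.search(FINAL_ANSWER_PATTERN, "".join(last_tokens))
        if match:
            answer_reached = True
            print(match.group(1), end="")  # emit whatever followed the prefix

stream_final_answer(["I think", " the ", "Final", " Answer", ": ", "42", "."])
```

Matching against the running concatenation matters because the prefix can be split across token boundaries, e.g. "Final" and " Answer:" arriving as separate tokens.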
35 changes: 30 additions & 5 deletions backend/chainlit/langchain/callbacks.py
@@ -1,6 +1,5 @@
import json
import time
- from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union
from uuid import UUID

@@ -13,6 +12,7 @@
- from langchain.schema.output import ChatGenerationChunk, GenerationChunk
+ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
+ from literalai.helper import utc_now
from literalai.step import TrueStepType

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
@@ -180,6 +180,10 @@ def _convert_message(
content="",
)

+ if literal_uuid := message.additional_kwargs.get("uuid"):
+     msg["uuid"] = literal_uuid
+     msg["templated"] = True

if name := getattr(message, "name", None):
msg["name"] = name

@@ -353,6 +357,18 @@ def on_llm_new_token(
if start["tt_first_token"] is None:
start["tt_first_token"] = (time.time() - start["start"]) * 1000

+ if self.stream_final_answer:
+     self._append_to_last_tokens(token)
+
+     if self.answer_reached:
+         if not self.final_stream:
+             self.final_stream = Message(content="")
+             self._run_sync(self.final_stream.send())
+         self._run_sync(self.final_stream.stream_token(token))
+         self.has_streamed_final_answer = True
+     else:
+         self.answer_reached = self._check_if_answer_reached()

return super().on_llm_new_token(
token,
chunk=chunk,
@@ -458,7 +474,7 @@ def _start_trace(self, run: Run) -> None:
parent_id=parent_id,
disable_feedback=disable_feedback,
)
- step.start = datetime.utcnow().isoformat()
+ step.start = utc_now()
step.input = run.inputs

self.steps[str(run.id)] = step
@@ -505,6 +521,15 @@ def _on_run_update(self, run: Run) -> None:
],
message_completion=message_completion,
)

+ # find first message with prompt_id
+ for m in chat_start["input_messages"]:
+     if m.additional_kwargs.get("prompt_id"):
+         current_step.generation.prompt_id = m.additional_kwargs["prompt_id"]
+         if custom_variables := m.additional_kwargs.get("variables"):
+             current_step.generation.variables = custom_variables
+         break

current_step.language = "json"
current_step.output = json.dumps(message_completion)
else:
@@ -529,7 +554,7 @@ def _on_run_update(self, run: Run) -> None:
current_step.output = completion

if current_step:
- current_step.end = datetime.utcnow().isoformat()
+ current_step.end = utc_now()
self._run_sync(current_step.update())

if self.final_stream and self.has_streamed_final_answer:
Expand All @@ -547,7 +572,7 @@ def _on_run_update(self, run: Run) -> None:

if current_step:
current_step.output = output
- current_step.end = datetime.utcnow().isoformat()
+ current_step.end = utc_now()
self._run_sync(current_step.update())

def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
Expand All @@ -556,7 +581,7 @@ def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
if current_step := self.steps.get(str(run_id), None):
current_step.is_error = True
current_step.output = str(error)
- current_step.end = datetime.utcnow().isoformat()
+ current_step.end = utc_now()
self._run_sync(current_step.update())

on_llm_error = _on_error
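Unlike the Haystack handler's regex, this LangChain callback detects the final answer by comparing recent tokens against `DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]` (defined at the top of the file); once `_check_if_answer_reached` fires, `on_llm_new_token` streams everything that follows into a fresh `Message`. A sliding-window sketch of that check, assuming simple whitespace stripping (the real helper may normalize tokens differently):

```python
from typing import List

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]

class PrefixDetector:
    # Keeps a window of the most recent tokens and reports when it
    # matches the answer prefix. Hypothetical sketch, not the actual
    # chainlit implementation.
    def __init__(self, prefix_tokens: List[str] = DEFAULT_ANSWER_PREFIX_TOKENS):
        self.prefix_tokens = prefix_tokens
        self.last_tokens: List[str] = []

    def append(self, token: str) -> None:
        self.last_tokens.append(token.strip())
        if len(self.last_tokens) > len(self.prefix_tokens):
            self.last_tokens.pop(0)

    def answer_reached(self) -> bool:
        return self.last_tokens == self.prefix_tokens

detector = PrefixDetector()
for tok in ["The", "Final", "Answer", ":", "42"]:
    detector.append(tok)
    if detector.answer_reached():
        print("stream every following token into the final answer message")
```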
6 changes: 3 additions & 3 deletions backend/chainlit/llama_index/callbacks.py
@@ -1,10 +1,10 @@
- from datetime import datetime
from typing import Any, Dict, List, Optional

from chainlit.context import context_var
from chainlit.element import Text
from chainlit.step import Step, StepType
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
+ from literalai.helper import utc_now
from llama_index.callbacks import TokenCountingHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.llms.base import ChatMessage, ChatResponse, CompletionResponse
@@ -87,7 +87,7 @@ def on_event_start(
disable_feedback=False,
)
self.steps[event_id] = step
- step.start = datetime.utcnow().isoformat()
+ step.start = utc_now()
step.input = payload or {}
self.context.loop.create_task(step.send())
return event_id
@@ -107,7 +107,7 @@ def on_event_end(

self._restore_context()

- step.end = datetime.utcnow().isoformat()
+ step.end = utc_now()

if event_type == CBEventType.RETRIEVE:
sources = payload.get(EventPayload.NODES)
10 changes: 5 additions & 5 deletions backend/chainlit/message.py
@@ -3,7 +3,6 @@
import time
import uuid
from abc import ABC
- from datetime import datetime
from typing import Dict, List, Optional, Union, cast

from chainlit.action import Action
@@ -23,6 +22,7 @@
FileDict,
)
from literalai import BaseGeneration
+ from literalai.helper import utc_now
from literalai.step import MessageStepType


@@ -149,7 +149,7 @@ async def _create(self):

async def send(self):
if not self.created_at:
- self.created_at = datetime.utcnow().isoformat()
+ self.created_at = utc_now()
if self.content is None:
self.content = ""

@@ -367,7 +367,7 @@ async def send(self) -> Union[StepDict, None]:
"""
trace_event("send_ask_user")
if not self.created_at:
- self.created_at = datetime.utcnow().isoformat()
+ self.created_at = utc_now()

if config.code.author_rename:
self.author = await config.code.author_rename(self.author)
@@ -439,7 +439,7 @@ async def send(self) -> Union[List[AskFileResponse], None]:
trace_event("send_ask_file")

if not self.created_at:
- self.created_at = datetime.utcnow().isoformat()
+ self.created_at = utc_now()

if self.streaming:
self.streaming = False
@@ -512,7 +512,7 @@ async def send(self) -> Union[AskActionResponse, None]:
trace_event("send_ask_action")

if not self.created_at:
- self.created_at = datetime.utcnow().isoformat()
+ self.created_at = utc_now()

if self.streaming:
self.streaming = False
9 changes: 6 additions & 3 deletions backend/chainlit/openai/__init__.py
@@ -1,10 +1,11 @@
- from typing import TYPE_CHECKING, Union
+ from typing import Union

from chainlit.context import get_context
from chainlit.step import Step
from chainlit.sync import run_sync
from chainlit.utils import check_module_version
from literalai import ChatGeneration, CompletionGeneration
+ from literalai.helper import timestamp_utc


def instrument_openai():
@@ -34,12 +35,14 @@ async def on_new_generation(
step.generation = generation
# Convert start/end time from seconds to milliseconds
step.start = (
timing.get("start") * 1000
timestamp_utc(timing.get("start"))
if timing.get("start", None) is not None
else None
)
step.end = (
timing.get("end") * 1000 if timing.get("end", None) is not None else None
timestamp_utc(timing.get("end"))
if timing.get("end", None) is not None
else None
)

if isinstance(generation, ChatGeneration):
Expand Down