Skip to content

Commit

Permalink
migrate to literal score (#851)
Browse files Browse the repository at this point in the history
* migrate to literal score

* fix tests

* enhance langchain llm step display

* correctly display new lines

* move new line cleaning to backend

* fix thread dict

* bump sdk version

* fix data layer test

* fix casing
  • Loading branch information
willydouhard committed Mar 30, 2024
1 parent 31dcd15 commit 942944e
Show file tree
Hide file tree
Showing 21 changed files with 308 additions and 99 deletions.
141 changes: 91 additions & 50 deletions backend/chainlit/data/__init__.py
Expand Up @@ -2,7 +2,7 @@
import json
import os
from collections import deque
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast

import aiofiles
from chainlit.config import config
Expand All @@ -11,13 +11,11 @@
from chainlit.session import WebsocketSession
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter
from chainlit.user import PersistedUser, User, UserDict
from literalai import Attachment
from literalai import Feedback as ClientFeedback
from literalai import PageInfo, PaginatedResponse
from literalai import Step as ClientStep
from literalai.step import StepDict as ClientStepDict
from literalai.thread import NumberListFilter, StringFilter, StringListFilter
from literalai.thread import ThreadFilter as ClientThreadFilter
from literalai import Attachment, PageInfo, PaginatedResponse
from literalai import Score as LiteralScore
from literalai import Step as LiteralStep
from literalai.filter import threads_filters as LiteralThreadsFilters
from literalai.step import StepDict as LiteralStepDict

if TYPE_CHECKING:
from chainlit.element import Element, ElementDict
Expand Down Expand Up @@ -57,6 +55,12 @@ async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
async def create_user(self, user: "User") -> Optional["PersistedUser"]:
pass

async def delete_feedback(
self,
feedback_id: str,
) -> bool:
return True

async def upsert_feedback(
self,
feedback: Feedback,
Expand Down Expand Up @@ -98,7 +102,8 @@ async def list_threads(
self, pagination: "Pagination", filters: "ThreadFilter"
) -> "PaginatedResponse[ThreadDict]":
return PaginatedResponse(
data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)
data=[],
pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None),
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
Expand Down Expand Up @@ -146,33 +151,46 @@ def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict":
"threadId": attachment.thread_id,
}

def feedback_to_feedback_dict(
self, feedback: Optional[ClientFeedback]
def score_to_feedback_dict(
self, score: Optional[LiteralScore]
) -> "Optional[FeedbackDict]":
if not feedback:
if not score:
return None
return {
"id": feedback.id or "",
"forId": feedback.step_id or "",
"value": feedback.value or 0, # type: ignore
"comment": feedback.comment,
"strategy": "BINARY",
"id": score.id or "",
"forId": score.step_id or "",
"value": cast(Literal[0, 1], score.value),
"comment": score.comment,
}

def step_to_step_dict(self, step: ClientStep) -> "StepDict":
def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
metadata = step.metadata or {}
input = (step.input or {}).get("content") or (
json.dumps(step.input) if step.input and step.input != {} else ""
)
output = (step.output or {}).get("content") or (
json.dumps(step.output) if step.output and step.output != {} else ""
)

user_feedback = (
next(
(
s
for s in step.scores
if s.type == "HUMAN" and s.name == "user-feedback"
),
None,
)
if step.scores
else None
)

return {
"createdAt": step.created_at,
"id": step.id or "",
"threadId": step.thread_id or "",
"parentId": step.parent_id,
"feedback": self.feedback_to_feedback_dict(step.feedback),
"feedback": self.score_to_feedback_dict(user_feedback),
"start": step.start_time,
"end": step.end_time,
"type": step.type or "undefined",
Expand All @@ -186,7 +204,6 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict":
"language": metadata.get("language"),
"isError": metadata.get("isError", False),
"waitForAnswer": metadata.get("waitForAnswer", False),
"feedback": self.feedback_to_feedback_dict(step.feedback),
}

async def get_user(self, identifier: str) -> Optional[PersistedUser]:
Expand Down Expand Up @@ -215,26 +232,37 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:
createdAt=_user.created_at or "",
)

async def delete_feedback(
self,
feedback_id: str,
):
if feedback_id:
await self.client.api.delete_score(
id=feedback_id,
)
return True
return False

async def upsert_feedback(
self,
feedback: Feedback,
):
if feedback.id:
await self.client.api.update_feedback(
await self.client.api.update_score(
id=feedback.id,
update_params={
"comment": feedback.comment,
"strategy": feedback.strategy,
"value": feedback.value,
},
)
return feedback.id
else:
created = await self.client.api.create_feedback(
created = await self.client.api.create_score(
step_id=feedback.forId,
value=feedback.value,
comment=feedback.comment,
strategy=feedback.strategy,
name="user-feedback",
type="HUMAN",
)
return created.id or ""

Expand Down Expand Up @@ -307,7 +335,7 @@ async def create_step(self, step_dict: "StepDict"):
"showInput": step_dict.get("showInput"),
}

step: ClientStepDict = {
step: LiteralStepDict = {
"createdAt": step_dict.get("createdAt"),
"startTime": step_dict.get("start"),
"endTime": step_dict.get("end"),
Expand Down Expand Up @@ -338,33 +366,54 @@ async def get_thread_author(self, thread_id: str) -> str:
thread = await self.get_thread(thread_id)
if not thread:
return ""
user = thread.get("user")
if not user:
user_identifier = thread.get("userIdentifier")
if not user_identifier:
return ""
return user.get("identifier") or ""

return user_identifier

async def delete_thread(self, thread_id: str):
await self.client.api.delete_thread(id=thread_id)

async def list_threads(
self, pagination: "Pagination", filters: "ThreadFilter"
) -> "PaginatedResponse[ThreadDict]":
if not filters.userIdentifier:
raise ValueError("userIdentifier is required")
if not filters.userId:
raise ValueError("userId is required")

literal_filters: LiteralThreadsFilters = [
{
"field": "participantId",
"operator": "eq",
"value": filters.userId,
}
]

client_filters = ClientThreadFilter(
participantsIdentifier=StringListFilter(
operator="in", value=[filters.userIdentifier]
),
)
if filters.search:
client_filters.search = StringFilter(operator="ilike", value=filters.search)
if filters.feedback:
client_filters.feedbacksValue = NumberListFilter(
operator="in", value=[filters.feedback]
literal_filters.append(
{
"field": "stepOutput",
"operator": "ilike",
"value": filters.search,
"path": "content",
}
)

if filters.feedback is not None:
literal_filters.append(
{
"field": "scoreValue",
"operator": "eq",
"value": filters.feedback,
"path": "user-feedback",
}
)

return await self.client.api.list_threads(
first=pagination.first, after=pagination.cursor, filters=client_filters
first=pagination.first,
after=pagination.cursor,
filters=literal_filters,
order_by={"column": "createdAt", "direction": "DESC"},
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
Expand All @@ -383,23 +432,15 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
step.generation = None
steps.append(self.step_to_step_dict(step))

user = None # type: Optional["UserDict"]

if thread.user:
user = {
"id": thread.user.id or "",
"identifier": thread.user.identifier or "",
"metadata": thread.user.metadata,
}

return {
"createdAt": thread.created_at or "",
"id": thread.id,
"name": thread.name or None,
"steps": steps,
"elements": elements,
"metadata": thread.metadata,
"user": user,
"userId": thread.participant_id,
"userIdentifier": thread.participant_identifier,
"tags": thread.tags,
}

Expand Down
3 changes: 1 addition & 2 deletions backend/chainlit/langchain/callbacks.py
Expand Up @@ -533,8 +533,7 @@ def _on_run_update(self, run: Run) -> None:
break

current_step.language = "json"
current_step.output = json.dumps(message_completion)
completion = message_completion.get("content", "")
current_step.output = json.dumps(message_completion, indent=4, ensure_ascii=False)
else:
completion_start = self.completion_generations[str(run.id)]
completion = generation.get("text", "")
Expand Down
24 changes: 23 additions & 1 deletion backend/chainlit/server.py
Expand Up @@ -41,6 +41,7 @@
GetThreadsRequest,
Theme,
UpdateFeedbackRequest,
DeleteFeedbackRequest,
)
from chainlit.user import PersistedUser, User
from fastapi import (
Expand Down Expand Up @@ -551,6 +552,24 @@ async def update_feedback(

return JSONResponse(content={"success": True, "feedbackId": feedback_id})

@app.delete("/feedback")
async def delete_feedback(
request: Request,
payload: DeleteFeedbackRequest,
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
):
"""Delete a feedback."""

data_layer = get_data_layer()

if not data_layer:
raise HTTPException(status_code=400, detail="Data persistence is not enabled")

feedback_id = payload.feedbackId

await data_layer.delete_feedback(feedback_id)
return JSONResponse(content={"success": True})


@app.post("/project/threads")
async def get_user_threads(
Expand All @@ -566,7 +585,10 @@ async def get_user_threads(
if not data_layer:
raise HTTPException(status_code=400, detail="Data persistence is not enabled")

payload.filter.userIdentifier = current_user.identifier
if not isinstance(current_user, PersistedUser):
raise HTTPException(status_code=400, detail="User not persisted")

payload.filter.userId = current_user.id

res = await data_layer.list_threads(payload.pagination, payload.filter)
return JSONResponse(content=res.to_dict())
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/socket.py
Expand Up @@ -42,7 +42,7 @@ async def resume_thread(session: WebsocketSession):
if not thread:
return

author = thread.get("user").get("identifier") if thread["user"] else None
author = thread.get("userIdentifier")
user_is_author = author == session.user.identifier

if user_is_author:
Expand Down
4 changes: 2 additions & 2 deletions backend/chainlit/step.py
Expand Up @@ -194,13 +194,13 @@ def _process_content(self, content, set_language=False):
if set_language:
self.language = "json"
except TypeError:
processed_content = str(content)
processed_content = str(content).replace("\\n", "\n")
if set_language:
self.language = "text"
elif isinstance(content, str):
processed_content = content
else:
processed_content = str(content)
processed_content = str(content).replace("\\n", "\n")
if set_language:
self.language = "text"
return processed_content
Expand Down

0 comments on commit 942944e

Please sign in to comment.