Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate to literal score #851

Merged
merged 10 commits into from Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
141 changes: 91 additions & 50 deletions backend/chainlit/data/__init__.py
Expand Up @@ -2,7 +2,7 @@
import json
import os
from collections import deque
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast

import aiofiles
from chainlit.config import config
Expand All @@ -11,13 +11,11 @@
from chainlit.session import WebsocketSession
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter
from chainlit.user import PersistedUser, User, UserDict
from literalai import Attachment
from literalai import Feedback as ClientFeedback
from literalai import PageInfo, PaginatedResponse
from literalai import Step as ClientStep
from literalai.step import StepDict as ClientStepDict
from literalai.thread import NumberListFilter, StringFilter, StringListFilter
from literalai.thread import ThreadFilter as ClientThreadFilter
from literalai import Attachment, PageInfo, PaginatedResponse
from literalai import Score as LiteralScore
from literalai import Step as LiteralStep
from literalai.filter import threads_filters as LiteralThreadsFilters
from literalai.step import StepDict as LiteralStepDict

if TYPE_CHECKING:
from chainlit.element import Element, ElementDict
Expand Down Expand Up @@ -57,6 +55,12 @@ async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
async def create_user(self, user: "User") -> Optional["PersistedUser"]:
pass

async def delete_feedback(
self,
feedback_id: str,
) -> bool:
return True

async def upsert_feedback(
self,
feedback: Feedback,
Expand Down Expand Up @@ -98,7 +102,8 @@ async def list_threads(
self, pagination: "Pagination", filters: "ThreadFilter"
) -> "PaginatedResponse[ThreadDict]":
return PaginatedResponse(
data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)
data=[],
pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None),
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
Expand Down Expand Up @@ -146,33 +151,46 @@ def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict":
"threadId": attachment.thread_id,
}

def feedback_to_feedback_dict(
self, feedback: Optional[ClientFeedback]
def score_to_feedback_dict(
self, score: Optional[LiteralScore]
) -> "Optional[FeedbackDict]":
if not feedback:
if not score:
return None
return {
"id": feedback.id or "",
"forId": feedback.step_id or "",
"value": feedback.value or 0, # type: ignore
"comment": feedback.comment,
"strategy": "BINARY",
"id": score.id or "",
"forId": score.step_id or "",
"value": cast(Literal[0, 1], score.value),
"comment": score.comment,
}

def step_to_step_dict(self, step: ClientStep) -> "StepDict":
def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
metadata = step.metadata or {}
input = (step.input or {}).get("content") or (
json.dumps(step.input) if step.input and step.input != {} else ""
)
output = (step.output or {}).get("content") or (
json.dumps(step.output) if step.output and step.output != {} else ""
)

user_feedback = (
next(
(
s
for s in step.scores
if s.type == "HUMAN" and s.name == "user-feedback"
),
None,
)
if step.scores
else None
)

return {
"createdAt": step.created_at,
"id": step.id or "",
"threadId": step.thread_id or "",
"parentId": step.parent_id,
"feedback": self.feedback_to_feedback_dict(step.feedback),
"feedback": self.score_to_feedback_dict(user_feedback),
"start": step.start_time,
"end": step.end_time,
"type": step.type or "undefined",
Expand All @@ -186,7 +204,6 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict":
"language": metadata.get("language"),
"isError": metadata.get("isError", False),
"waitForAnswer": metadata.get("waitForAnswer", False),
"feedback": self.feedback_to_feedback_dict(step.feedback),
}

async def get_user(self, identifier: str) -> Optional[PersistedUser]:
Expand Down Expand Up @@ -215,26 +232,37 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:
createdAt=_user.created_at or "",
)

async def delete_feedback(
self,
feedback_id: str,
):
if feedback_id:
await self.client.api.delete_score(
id=feedback_id,
)
return True
return False

async def upsert_feedback(
self,
feedback: Feedback,
):
if feedback.id:
await self.client.api.update_feedback(
await self.client.api.update_score(
id=feedback.id,
update_params={
"comment": feedback.comment,
"strategy": feedback.strategy,
"value": feedback.value,
},
)
return feedback.id
else:
created = await self.client.api.create_feedback(
created = await self.client.api.create_score(
step_id=feedback.forId,
value=feedback.value,
comment=feedback.comment,
strategy=feedback.strategy,
name="user-feedback",
type="HUMAN",
)
return created.id or ""

Expand Down Expand Up @@ -307,7 +335,7 @@ async def create_step(self, step_dict: "StepDict"):
"showInput": step_dict.get("showInput"),
}

step: ClientStepDict = {
step: LiteralStepDict = {
"createdAt": step_dict.get("createdAt"),
"startTime": step_dict.get("start"),
"endTime": step_dict.get("end"),
Expand Down Expand Up @@ -338,33 +366,54 @@ async def get_thread_author(self, thread_id: str) -> str:
thread = await self.get_thread(thread_id)
if not thread:
return ""
user = thread.get("user")
if not user:
user_identifier = thread.get("userIdentifier")
if not user_identifier:
return ""
return user.get("identifier") or ""

return user_identifier

async def delete_thread(self, thread_id: str):
await self.client.api.delete_thread(id=thread_id)

async def list_threads(
self, pagination: "Pagination", filters: "ThreadFilter"
) -> "PaginatedResponse[ThreadDict]":
if not filters.userIdentifier:
raise ValueError("userIdentifier is required")
if not filters.userId:
raise ValueError("userId is required")

literal_filters: LiteralThreadsFilters = [
{
"field": "participantId",
"operator": "eq",
"value": filters.userId,
}
]

client_filters = ClientThreadFilter(
participantsIdentifier=StringListFilter(
operator="in", value=[filters.userIdentifier]
),
)
if filters.search:
client_filters.search = StringFilter(operator="ilike", value=filters.search)
if filters.feedback:
client_filters.feedbacksValue = NumberListFilter(
operator="in", value=[filters.feedback]
literal_filters.append(
{
"field": "stepOutput",
"operator": "ilike",
"value": filters.search,
"path": "content",
}
)

if filters.feedback is not None:
literal_filters.append(
{
"field": "scoreValue",
"operator": "eq",
"value": filters.feedback,
"path": "user-feedback",
}
)

return await self.client.api.list_threads(
first=pagination.first, after=pagination.cursor, filters=client_filters
first=pagination.first,
after=pagination.cursor,
filters=literal_filters,
order_by={"column": "createdAt", "direction": "DESC"},
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
Expand All @@ -383,23 +432,15 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
step.generation = None
steps.append(self.step_to_step_dict(step))

user = None # type: Optional["UserDict"]

if thread.user:
user = {
"id": thread.user.id or "",
"identifier": thread.user.identifier or "",
"metadata": thread.user.metadata,
}

return {
"createdAt": thread.created_at or "",
"id": thread.id,
"name": thread.name or None,
"steps": steps,
"elements": elements,
"metadata": thread.metadata,
"user": user,
"userId": thread.participant_id,
"userIdentifier": thread.participant_identifier,
"tags": thread.tags,
}

Expand Down
3 changes: 1 addition & 2 deletions backend/chainlit/langchain/callbacks.py
Expand Up @@ -533,8 +533,7 @@ def _on_run_update(self, run: Run) -> None:
break

current_step.language = "json"
current_step.output = json.dumps(message_completion)
completion = message_completion.get("content", "")
current_step.output = json.dumps(message_completion, indent=4, ensure_ascii=False)
else:
completion_start = self.completion_generations[str(run.id)]
completion = generation.get("text", "")
Expand Down
24 changes: 23 additions & 1 deletion backend/chainlit/server.py
Expand Up @@ -41,6 +41,7 @@
GetThreadsRequest,
Theme,
UpdateFeedbackRequest,
DeleteFeedbackRequest,
)
from chainlit.user import PersistedUser, User
from fastapi import (
Expand Down Expand Up @@ -551,6 +552,24 @@ async def update_feedback(

return JSONResponse(content={"success": True, "feedbackId": feedback_id})

@app.delete("/feedback")
async def delete_feedback(
request: Request,
payload: DeleteFeedbackRequest,
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
):
"""Delete a feedback."""

data_layer = get_data_layer()

if not data_layer:
raise HTTPException(status_code=400, detail="Data persistence is not enabled")

feedback_id = payload.feedbackId

await data_layer.delete_feedback(feedback_id)
return JSONResponse(content={"success": True})


@app.post("/project/threads")
async def get_user_threads(
Expand All @@ -566,7 +585,10 @@ async def get_user_threads(
if not data_layer:
raise HTTPException(status_code=400, detail="Data persistence is not enabled")

payload.filter.userIdentifier = current_user.identifier
if not isinstance(current_user, PersistedUser):
raise HTTPException(status_code=400, detail="User not persisted")

payload.filter.userId = current_user.id

res = await data_layer.list_threads(payload.pagination, payload.filter)
return JSONResponse(content=res.to_dict())
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/socket.py
Expand Up @@ -42,7 +42,7 @@ async def resume_thread(session: WebsocketSession):
if not thread:
return

author = thread.get("user").get("identifier") if thread["user"] else None
author = thread.get("userIdentifier")
user_is_author = author == session.user.identifier

if user_is_author:
Expand Down
4 changes: 2 additions & 2 deletions backend/chainlit/step.py
Expand Up @@ -194,13 +194,13 @@ def _process_content(self, content, set_language=False):
if set_language:
self.language = "json"
except TypeError:
processed_content = str(content)
processed_content = str(content).replace("\\n", "\n")
if set_language:
self.language = "text"
elif isinstance(content, str):
processed_content = content
else:
processed_content = str(content)
processed_content = str(content).replace("\\n", "\n")
if set_language:
self.language = "text"
return processed_content
Expand Down