Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate to literal score #851

Merged
merged 10 commits into from Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
98 changes: 59 additions & 39 deletions backend/chainlit/data/__init__.py
Expand Up @@ -2,7 +2,7 @@
import json
import os
from collections import deque
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Literal, cast

import aiofiles
from chainlit.config import config
Expand All @@ -11,13 +11,9 @@
from chainlit.session import WebsocketSession
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter
from chainlit.user import PersistedUser, User, UserDict
from literalai import Attachment
from literalai import Feedback as ClientFeedback
from literalai import PageInfo, PaginatedResponse
from literalai import Step as ClientStep
from literalai.step import StepDict as ClientStepDict
from literalai.thread import NumberListFilter, StringFilter, StringListFilter
from literalai.thread import ThreadFilter as ClientThreadFilter
from literalai import Score as LiteralScore, PageInfo, PaginatedResponse, Attachment, Step as LiteralStep
from literalai.step import StepDict as LiteralStepDict
from literalai.filter import threads_filters as LiteralThreadsFilters

if TYPE_CHECKING:
from chainlit.element import Element, ElementDict
Expand Down Expand Up @@ -57,6 +53,13 @@ async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
async def create_user(self, user: "User") -> Optional["PersistedUser"]:
pass

async def delete_feedback(
self,
feedback_id: str,
) -> bool:
return True


async def upsert_feedback(
self,
feedback: Feedback,
Expand Down Expand Up @@ -98,7 +101,7 @@ async def list_threads(
self, pagination: "Pagination", filters: "ThreadFilter"
) -> "PaginatedResponse[ThreadDict]":
return PaginatedResponse(
data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)
data=[], pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None)
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
Expand Down Expand Up @@ -146,33 +149,35 @@ def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict":
"threadId": attachment.thread_id,
}

def feedback_to_feedback_dict(
self, feedback: Optional[ClientFeedback]
def score_to_feedback_dict(
self, score: Optional[LiteralScore]
) -> "Optional[FeedbackDict]":
if not feedback:
if not score:
return None
return {
"id": feedback.id or "",
"forId": feedback.step_id or "",
"value": feedback.value or 0, # type: ignore
"comment": feedback.comment,
"strategy": "BINARY",
"id": score.id or "",
"forId": score.step_id or "",
"value": cast(Literal[0, 1], score.value),
"comment": score.comment,
}

def step_to_step_dict(self, step: ClientStep) -> "StepDict":
def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
metadata = step.metadata or {}
input = (step.input or {}).get("content") or (
json.dumps(step.input) if step.input and step.input != {} else ""
)
output = (step.output or {}).get("content") or (
json.dumps(step.output) if step.output and step.output != {} else ""
)

user_feedback = next((s for s in step.scores if s.type == "HUMAN" and s.name == "user-feedback"), None) if step.scores else None

return {
"createdAt": step.created_at,
"id": step.id or "",
"threadId": step.thread_id or "",
"parentId": step.parent_id,
"feedback": self.feedback_to_feedback_dict(step.feedback),
"feedback": self.score_to_feedback_dict(user_feedback),
"start": step.start_time,
"end": step.end_time,
"type": step.type or "undefined",
Expand All @@ -186,7 +191,6 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict":
"language": metadata.get("language"),
"isError": metadata.get("isError", False),
"waitForAnswer": metadata.get("waitForAnswer", False),
"feedback": self.feedback_to_feedback_dict(step.feedback),
}

async def get_user(self, identifier: str) -> Optional[PersistedUser]:
Expand Down Expand Up @@ -215,26 +219,38 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:
createdAt=_user.created_at or "",
)

async def delete_feedback(
self,
feedback_id: str,
):
if feedback_id:
await self.client.api.delete_score(
id=feedback_id,
)
return True
return False


async def upsert_feedback(
self,
feedback: Feedback,
):
if feedback.id:
await self.client.api.update_feedback(
await self.client.api.update_score(
id=feedback.id,
update_params={
"comment": feedback.comment,
"strategy": feedback.strategy,
"value": feedback.value,
},
)
return feedback.id
else:
created = await self.client.api.create_feedback(
created = await self.client.api.create_score(
step_id=feedback.forId,
value=feedback.value,
comment=feedback.comment,
strategy=feedback.strategy,
name="user-feedback",
type="HUMAN",
)
return created.id or ""

Expand Down Expand Up @@ -307,7 +323,7 @@ async def create_step(self, step_dict: "StepDict"):
"showInput": step_dict.get("showInput"),
}

step: ClientStepDict = {
step: LiteralStepDict = {
"createdAt": step_dict.get("createdAt"),
"startTime": step_dict.get("start"),
"endTime": step_dict.get("end"),
Expand Down Expand Up @@ -349,22 +365,26 @@ async def delete_thread(self, thread_id: str):
async def list_threads(
self, pagination: "Pagination", filters: "ThreadFilter"
) -> "PaginatedResponse[ThreadDict]":
if not filters.userIdentifier:
raise ValueError("userIdentifier is required")

client_filters = ClientThreadFilter(
participantsIdentifier=StringListFilter(
operator="in", value=[filters.userIdentifier]
),
)
if not filters.userId:
raise ValueError("userId is required")

literal_filters: LiteralThreadsFilters = [
{
"field": "participantId",
"operator": "eq",
"value": filters.userId,
}
]

if filters.search:
client_filters.search = StringFilter(operator="ilike", value=filters.search)
if filters.feedback:
client_filters.feedbacksValue = NumberListFilter(
operator="in", value=[filters.feedback]
)
literal_filters.append({"field": "stepOutput", "operator": "ilike", "value": filters.search, "path": "content"})


if filters.feedback is not None:
literal_filters.append({"field": "scoreValue", "operator": "eq", "value": filters.feedback, "path": "user-feedback"})

return await self.client.api.list_threads(
first=pagination.first, after=pagination.cursor, filters=client_filters
first=pagination.first, after=pagination.cursor, filters=literal_filters, order_by={"column": "createdAt", "direction": "DESC"}
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
Expand Down
24 changes: 23 additions & 1 deletion backend/chainlit/server.py
Expand Up @@ -41,6 +41,7 @@
GetThreadsRequest,
Theme,
UpdateFeedbackRequest,
DeleteFeedbackRequest,
)
from chainlit.user import PersistedUser, User
from fastapi import (
Expand Down Expand Up @@ -551,6 +552,24 @@ async def update_feedback(

return JSONResponse(content={"success": True, "feedbackId": feedback_id})

@app.delete("/feedback")
async def delete_feedback(
request: Request,
payload: DeleteFeedbackRequest,
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
):
"""Delete a feedback."""

data_layer = get_data_layer()

if not data_layer:
raise HTTPException(status_code=400, detail="Data persistence is not enabled")

feedback_id = payload.feedbackId

await data_layer.delete_feedback(feedback_id)
return JSONResponse(content={"success": True})


@app.post("/project/threads")
async def get_user_threads(
Expand All @@ -566,7 +585,10 @@ async def get_user_threads(
if not data_layer:
raise HTTPException(status_code=400, detail="Data persistence is not enabled")

payload.filter.userIdentifier = current_user.identifier
if not isinstance(current_user, PersistedUser):
raise HTTPException(status_code=400, detail="User not persisted")

payload.filter.userId = current_user.id

res = await data_layer.list_threads(payload.pagination, payload.filter)
return JSONResponse(content=res.to_dict())
Expand Down
15 changes: 9 additions & 6 deletions backend/chainlit/types.py
Expand Up @@ -33,8 +33,8 @@ class Pagination(BaseModel):


class ThreadFilter(BaseModel):
feedback: Optional[Literal[-1, 0, 1]] = None
userIdentifier: Optional[str] = None
feedback: Optional[Literal[0, 1]] = None
userId: Optional[str] = None
Comment on lines +36 to +37
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@willydouhard why is -1 being removed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-1 becomes 0 and we introduce the concept of "feedback deletion" to replace the current 0 purpose. Having 3 values was a workaround to avoid to handle feedback deletion

search: Optional[str] = None


Expand Down Expand Up @@ -122,6 +122,9 @@ def is_chat(self):
class DeleteThreadRequest(BaseModel):
threadId: str

class DeleteFeedbackRequest(BaseModel):
feedbackId: str


class GetThreadsRequest(BaseModel):
pagination: Pagination
Expand All @@ -146,16 +149,16 @@ class ChatProfile(DataClassJsonMixin):


class FeedbackDict(TypedDict):
value: Literal[-1, 0, 1]
strategy: FeedbackStrategy
forId: str
id: Optional[str]
value: Literal[0, 1]
comment: Optional[str]


@dataclass
class Feedback:
forId: str
value: Literal[-1, 0, 1]
strategy: FeedbackStrategy = "BINARY"
value: Literal[0, 1]
id: Optional[str] = None
comment: Optional[str] = None
Comment on lines 152 to 164
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@willydouhard why are these changing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chainlit is only supporting binary feedback so this was never used throughout the project.


Expand Down
4 changes: 2 additions & 2 deletions backend/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "chainlit"
version = "1.0.401"
version = "1.0.500"
keywords = ['LLM', 'Agents', 'gen ai', 'chat ui', 'chatbot ui', 'openai', 'copilot', 'langchain', 'conversational ai']
description = "Build Conversational AI."
authors = ["Chainlit"]
Expand All @@ -23,7 +23,7 @@ chainlit = 'chainlit.cli:cli'
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0.0"
httpx = ">=0.23.0"
literalai = "0.0.300"
literalai = "0.0.400"
dataclasses_json = "^0.5.7"
fastapi = ">=0.100"
# Starlette >= 0.33.0 breaks socketio (alway 404)
Expand Down
2 changes: 1 addition & 1 deletion cypress/e2e/data_layer/main.py
Expand Up @@ -81,7 +81,7 @@ async def list_threads(
) -> cl_data.PaginatedResponse[cl_data.ThreadDict]:
return cl_data.PaginatedResponse(
data=[t for t in thread_history if t["id"] not in deleted_thread_ids],
pageInfo=cl_data.PageInfo(hasNextPage=False, endCursor=None),
pageInfo=cl_data.PageInfo(hasNextPage=False, startCursor=None, endCursor=None),
)

async def get_thread(self, thread_id: str):
Expand Down