diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index bf52c9fd71..d8210d9ff0 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -2,7 +2,7 @@ import json import os from collections import deque -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast import aiofiles from chainlit.config import config @@ -11,13 +11,11 @@ from chainlit.session import WebsocketSession from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter from chainlit.user import PersistedUser, User, UserDict -from literalai import Attachment -from literalai import Feedback as ClientFeedback -from literalai import PageInfo, PaginatedResponse -from literalai import Step as ClientStep -from literalai.step import StepDict as ClientStepDict -from literalai.thread import NumberListFilter, StringFilter, StringListFilter -from literalai.thread import ThreadFilter as ClientThreadFilter +from literalai import Attachment, PageInfo, PaginatedResponse +from literalai import Score as LiteralScore +from literalai import Step as LiteralStep +from literalai.filter import threads_filters as LiteralThreadsFilters +from literalai.step import StepDict as LiteralStepDict if TYPE_CHECKING: from chainlit.element import Element, ElementDict @@ -57,6 +55,12 @@ async def get_user(self, identifier: str) -> Optional["PersistedUser"]: async def create_user(self, user: "User") -> Optional["PersistedUser"]: pass + async def delete_feedback( + self, + feedback_id: str, + ) -> bool: + return True + async def upsert_feedback( self, feedback: Feedback, @@ -98,7 +102,8 @@ async def list_threads( self, pagination: "Pagination", filters: "ThreadFilter" ) -> "PaginatedResponse[ThreadDict]": return PaginatedResponse( - data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None) + data=[], + pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None), ) async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": @@ -146,20 +151,19 @@ def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": "threadId": attachment.thread_id, } - def feedback_to_feedback_dict( - self, feedback: Optional[ClientFeedback] + def score_to_feedback_dict( + self, score: Optional[LiteralScore] ) -> "Optional[FeedbackDict]": - if not feedback: + if not score: return None return { - "id": feedback.id or "", - "forId": feedback.step_id or "", - "value": feedback.value or 0, # type: ignore - "comment": feedback.comment, - "strategy": "BINARY", + "id": score.id or "", + "forId": score.step_id or "", + "value": cast(Literal[0, 1], score.value), + "comment": score.comment, } - def step_to_step_dict(self, step: ClientStep) -> "StepDict": + def step_to_step_dict(self, step: LiteralStep) -> "StepDict": metadata = step.metadata or {} input = (step.input or {}).get("content") or ( json.dumps(step.input) if step.input and step.input != {} else "" @@ -167,12 +171,26 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict": output = (step.output or {}).get("content") or ( json.dumps(step.output) if step.output and step.output != {} else "" ) + + user_feedback = ( + next( + ( + s + for s in step.scores + if s.type == "HUMAN" and s.name == "user-feedback" + ), + None, + ) + if step.scores + else None + ) + return { "createdAt": step.created_at, "id": step.id or "", "threadId": step.thread_id or "", "parentId": step.parent_id, - "feedback": self.feedback_to_feedback_dict(step.feedback), + "feedback": self.score_to_feedback_dict(user_feedback), "start": step.start_time, "end": step.end_time, "type": step.type or "undefined", @@ -186,7 +204,6 @@ def step_to_step_dict(self, step: ClientStep) -> "StepDict": "language": metadata.get("language"), "isError": metadata.get("isError", False), "waitForAnswer": metadata.get("waitForAnswer", False), - "feedback": self.feedback_to_feedback_dict(step.feedback), } async def get_user(self, identifier: str) -> Optional[PersistedUser]: @@ -215,26 +232,37 @@ async def create_user(self, user: User) -> Optional[PersistedUser]: createdAt=_user.created_at or "", ) + async def delete_feedback( + self, + feedback_id: str, + ): + if feedback_id: + await self.client.api.delete_score( + id=feedback_id, + ) + return True + return False + async def upsert_feedback( self, feedback: Feedback, ): if feedback.id: - await self.client.api.update_feedback( + await self.client.api.update_score( id=feedback.id, update_params={ "comment": feedback.comment, - "strategy": feedback.strategy, "value": feedback.value, }, ) return feedback.id else: - created = await self.client.api.create_feedback( + created = await self.client.api.create_score( step_id=feedback.forId, value=feedback.value, comment=feedback.comment, - strategy=feedback.strategy, + name="user-feedback", + type="HUMAN", ) return created.id or "" @@ -307,7 +335,7 @@ async def create_step(self, step_dict: "StepDict"): "showInput": step_dict.get("showInput"), } - step: ClientStepDict = { + step: LiteralStepDict = { "createdAt": step_dict.get("createdAt"), "startTime": step_dict.get("start"), "endTime": step_dict.get("end"), @@ -338,10 +366,11 @@ async def get_thread_author(self, thread_id: str) -> str: thread = await self.get_thread(thread_id) if not thread: return "" - user = thread.get("user") - if not user: + user_identifier = thread.get("userIdentifier") + if not user_identifier: return "" - return user.get("identifier") or "" + + return user_identifier async def delete_thread(self, thread_id: str): await self.client.api.delete_thread(id=thread_id) @@ -349,22 +378,42 @@ async def delete_thread(self, thread_id: str): async def list_threads( self, pagination: "Pagination", filters: "ThreadFilter" ) -> "PaginatedResponse[ThreadDict]": - if not filters.userIdentifier: - raise ValueError("userIdentifier is required") + if not filters.userId: + raise ValueError("userId is required") + + literal_filters: LiteralThreadsFilters = [ + { + "field": "participantId", + "operator": "eq", + "value": filters.userId, + } + ] - client_filters = ClientThreadFilter( - participantsIdentifier=StringListFilter( - operator="in", value=[filters.userIdentifier] - ), - ) if filters.search: - client_filters.search = StringFilter(operator="ilike", value=filters.search) - if filters.feedback: - client_filters.feedbacksValue = NumberListFilter( - operator="in", value=[filters.feedback] + literal_filters.append( + { + "field": "stepOutput", + "operator": "ilike", + "value": filters.search, + "path": "content", + } + ) + + if filters.feedback is not None: + literal_filters.append( + { + "field": "scoreValue", + "operator": "eq", + "value": filters.feedback, + "path": "user-feedback", + } ) + return await self.client.api.list_threads( - first=pagination.first, after=pagination.cursor, filters=client_filters + first=pagination.first, + after=pagination.cursor, + filters=literal_filters, + order_by={"column": "createdAt", "direction": "DESC"}, ) async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": @@ -383,15 +432,6 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": step.generation = None steps.append(self.step_to_step_dict(step)) - user = None # type: Optional["UserDict"] - - if thread.user: - user = { - "id": thread.user.id or "", - "identifier": thread.user.identifier or "", - "metadata": thread.user.metadata, - } - return { "createdAt": thread.created_at or "", "id": thread.id, @@ -399,7 +439,8 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": "steps": steps, "elements": elements, "metadata": thread.metadata, - "user": user, + "userId": thread.participant_id, + "userIdentifier": thread.participant_identifier, "tags": thread.tags, } diff --git a/backend/chainlit/langchain/callbacks.py b/backend/chainlit/langchain/callbacks.py index 8c6d16049c..0876444929 100644 --- a/backend/chainlit/langchain/callbacks.py +++ b/backend/chainlit/langchain/callbacks.py @@ -533,8 +533,7 @@ def _on_run_update(self, run: Run) -> None: break current_step.language = "json" - current_step.output = json.dumps(message_completion) - completion = message_completion.get("content", "") + current_step.output = json.dumps(message_completion, indent=4, ensure_ascii=False) else: completion_start = self.completion_generations[str(run.id)] completion = generation.get("text", "") diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index d11c317831..1bc93863a4 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -41,6 +41,7 @@ GetThreadsRequest, Theme, UpdateFeedbackRequest, + DeleteFeedbackRequest, ) from chainlit.user import PersistedUser, User from fastapi import ( @@ -551,6 +552,24 @@ async def update_feedback( return JSONResponse(content={"success": True, "feedbackId": feedback_id}) +@app.delete("/feedback") +async def delete_feedback( + request: Request, + payload: DeleteFeedbackRequest, + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], +): + """Delete a feedback.""" + + data_layer = get_data_layer() + + if not data_layer: + raise HTTPException(status_code=400, detail="Data persistence is not enabled") + + feedback_id = payload.feedbackId + + await data_layer.delete_feedback(feedback_id) + return JSONResponse(content={"success": True}) + @app.post("/project/threads") async def get_user_threads( @@ -566,7 +585,10 @@ async def get_user_threads( if not data_layer: raise HTTPException(status_code=400, detail="Data persistence is not enabled") - payload.filter.userIdentifier = current_user.identifier + if not isinstance(current_user, PersistedUser): + raise HTTPException(status_code=400, detail="User not persisted") + + payload.filter.userId = current_user.id res = await data_layer.list_threads(payload.pagination, payload.filter) return JSONResponse(content=res.to_dict()) diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 773bd1434a..2217849032 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -42,7 +42,7 @@ async def resume_thread(session: WebsocketSession): if not thread: return - author = thread.get("user").get("identifier") if thread["user"] else None + author = thread.get("userIdentifier") user_is_author = author == session.user.identifier if user_is_author: diff --git a/backend/chainlit/step.py b/backend/chainlit/step.py index 9ce9752304..46c9a263c3 100644 --- a/backend/chainlit/step.py +++ b/backend/chainlit/step.py @@ -194,13 +194,13 @@ def _process_content(self, content, set_language=False): if set_language: self.language = "json" except TypeError: - processed_content = str(content) + processed_content = str(content).replace("\\n", "\n") if set_language: self.language = "text" elif isinstance(content, str): processed_content = content else: - processed_content = str(content) + processed_content = str(content).replace("\\n", "\n") if set_language: self.language = "text" return processed_content diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index 28058bff28..56427778d7 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -3,7 +3,6 @@ if TYPE_CHECKING: from chainlit.element import ElementDict - from chainlit.user import UserDict from chainlit.step import StepDict from dataclasses_json import DataClassJsonMixin @@ -20,7 +19,8 @@ class ThreadDict(TypedDict): id: str createdAt: str name: Optional[str] - user: Optional["UserDict"] + userId: Optional[str] + userIdentifier: Optional[str] tags: Optional[List[str]] metadata: Optional[Dict] steps: List["StepDict"] @@ -33,8 +33,8 @@ class Pagination(BaseModel): class ThreadFilter(BaseModel): - feedback: Optional[Literal[-1, 0, 1]] = None - userIdentifier: Optional[str] = None + feedback: Optional[Literal[0, 1]] = None + userId: Optional[str] = None search: Optional[str] = None @@ -123,6 +123,10 @@ class DeleteThreadRequest(BaseModel): threadId: str +class DeleteFeedbackRequest(BaseModel): + feedbackId: str + + class GetThreadsRequest(BaseModel): pagination: Pagination filter: ThreadFilter @@ -146,16 +150,16 @@ class ChatProfile(DataClassJsonMixin): class FeedbackDict(TypedDict): - value: Literal[-1, 0, 1] - strategy: FeedbackStrategy + forId: str + id: Optional[str] + value: Literal[0, 1] comment: Optional[str] @dataclass class Feedback: forId: str - value: Literal[-1, 0, 1] - strategy: FeedbackStrategy = "BINARY" + value: Literal[0, 1] id: Optional[str] = None comment: Optional[str] = None diff --git a/backend/pyproject.toml b/backend/pyproject.toml index a0a6ba0c36..9109d683ed 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "chainlit" -version = "1.0.401" +version = "1.0.500" keywords = ['LLM', 'Agents', 'gen ai', 'chat ui', 'chatbot ui', 'openai', 'copilot', 'langchain', 'conversational ai'] description = "Build Conversational AI." authors = ["Chainlit"] @@ -23,7 +23,7 @@ chainlit = 'chainlit.cli:cli' [tool.poetry.dependencies] python = ">=3.8.1,<4.0.0" httpx = ">=0.23.0" -literalai = "0.0.300" +literalai = "0.0.401" dataclasses_json = "^0.5.7" fastapi = ">=0.100" # Starlette >= 0.33.0 breaks socketio (alway 404) diff --git a/cypress/e2e/data_layer/main.py b/cypress/e2e/data_layer/main.py index 9518adbb51..057edb4bb2 100644 --- a/cypress/e2e/data_layer/main.py +++ b/cypress/e2e/data_layer/main.py @@ -10,14 +10,14 @@ create_step_counter = 0 -user_dict = {"id": "test", "createdAt": now, "identifier": "admin"} thread_history = [ { "id": "test1", "name": "thread 1", "createdAt": now, - "user": user_dict, + "userId": "test", + "userIdentifier": "admin", "steps": [ { "id": "test1", @@ -38,7 +38,8 @@ { "id": "test2", "createdAt": now, - "user": user_dict, + "userId": "test", + "userIdentifier": "admin", "name": "thread 2", "steps": [ { @@ -81,7 +82,9 @@ async def list_threads( ) -> cl_data.PaginatedResponse[cl_data.ThreadDict]: return cl_data.PaginatedResponse( data=[t for t in thread_history if t["id"] not in deleted_thread_ids], - pageInfo=cl_data.PageInfo(hasNextPage=False, endCursor=None), + pageInfo=cl_data.PageInfo( + hasNextPage=False, startCursor=None, endCursor=None + ), ) async def get_thread(self, thread_id: str): diff --git a/frontend/src/components/molecules/Code.tsx b/frontend/src/components/molecules/Code.tsx index d37c6cb079..58e73707b3 100644 --- a/frontend/src/components/molecules/Code.tsx +++ b/frontend/src/components/molecules/Code.tsx @@ -99,6 +99,7 @@ const Code = ({ children, ...props }: any) => { alignItems: 'center', borderTopLeftRadius: '4px', borderTopRightRadius: '4px', + color: 'text.secondary', background: isDarkMode ? grey[900] : grey[200] }} > diff --git a/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx b/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx index 608cb20d08..cde0681837 100644 --- a/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx +++ b/frontend/src/components/molecules/messages/components/FeedbackButtons.tsx @@ -1,6 +1,7 @@ import { MessageContext } from 'contexts/MessageContext'; import { useContext, useState } from 'react'; import { useMemo } from 'react'; +import { useRecoilValue } from 'recoil'; import StickyNote2Outlined from '@mui/icons-material/StickyNote2Outlined'; import ThumbDownAlt from '@mui/icons-material/ThumbDownAlt'; @@ -11,6 +12,8 @@ import IconButton from '@mui/material/IconButton'; import Stack from '@mui/material/Stack'; import Tooltip from '@mui/material/Tooltip'; +import { firstUserInteraction, useChatSession } from '@chainlit/react-client'; + import Dialog from 'components/atoms/Dialog'; import { AccentButton } from 'components/atoms/buttons/AccentButton'; import { TextInput } from 'components/atoms/inputs'; @@ -24,18 +27,31 @@ interface Props { } const FeedbackButtons = ({ message }: Props) => { - const { onFeedbackUpdated } = useContext(MessageContext); + const { onFeedbackUpdated, onFeedbackDeleted } = useContext(MessageContext); const [showFeedbackDialog, setShowFeedbackDialog] = useState(); const [commentInput, setCommentInput] = useState(); + const firstInteraction = useRecoilValue(firstUserInteraction); + const { idToResume } = useChatSession(); - const [feedback, setFeedback] = useState(message.feedback?.value || 0); + const [feedback, setFeedback] = useState(message.feedback?.value); const [comment, setComment] = useState(message.feedback?.comment); - const DownIcon = feedback === -1 ? ThumbDownAlt : ThumbDownAltOutlined; + const DownIcon = feedback === 0 ? ThumbDownAlt : ThumbDownAltOutlined; const UpIcon = feedback === 1 ? ThumbUpAlt : ThumbUpAltOutlined; - const handleFeedbackChanged = (feedback: number, comment?: string) => { - onFeedbackUpdated && + const handleFeedbackChanged = (feedback?: number, comment?: string) => { + if (feedback === undefined) { + if (onFeedbackDeleted && message.feedback?.id) { + onFeedbackDeleted( + message, + () => { + setFeedback(undefined); + setComment(undefined); + }, + message.feedback.id + ); + } + } else if (onFeedbackUpdated) { onFeedbackUpdated( message, () => { @@ -43,23 +59,27 @@ const FeedbackButtons = ({ message }: Props) => { setComment(comment); }, { - ...(message.feedback || { strategy: 'BINARY' }), + ...(message.feedback || {}), forId: message.id, value: feedback, comment } ); + } }; - const handleFeedbackClick = (status: number) => { - if (feedback === status) { - handleFeedbackChanged(0); + const handleFeedbackClick = (nextValue: number) => { + if (feedback === nextValue) { + handleFeedbackChanged(undefined); } else { - setShowFeedbackDialog(status); + setShowFeedbackDialog(nextValue); } }; - const disabled = !!message.streaming; + const isPersisted = firstInteraction || idToResume; + const isStreaming = !!message.streaming; + + const disabled = isStreaming || !isPersisted; const buttons = useMemo(() => { const iconSx = { @@ -92,7 +112,7 @@ const FeedbackButtons = ({ message }: Props) => { disabled={disabled} className={`negative-feedback-${feedback === -1 ? 'on' : 'off'}`} onClick={() => { - handleFeedbackClick(-1); + handleFeedbackClick(0); }} > @@ -138,11 +158,11 @@ const FeedbackButtons = ({ message }: Props) => { onClose={() => { setShowFeedbackDialog(undefined); }} - open={!!showFeedbackDialog} + open={showFeedbackDialog !== undefined} title={ - {showFeedbackDialog === -1 ? : } - Provide additional feedback + {showFeedbackDialog === 0 ? : } + Add a comment } content={ diff --git a/frontend/src/components/organisms/chat/Messages/container.tsx b/frontend/src/components/organisms/chat/Messages/container.tsx index a58f5b1617..9c937a3e7a 100644 --- a/frontend/src/components/organisms/chat/Messages/container.tsx +++ b/frontend/src/components/organisms/chat/Messages/container.tsx @@ -36,6 +36,11 @@ interface Props { onSuccess: () => void, feedback: IFeedback ) => void; + onFeedbackDeleted: ( + message: IStep, + onSuccess: () => void, + feedback: string + ) => void; callAction?: (action: IAction) => void; setAutoScroll?: (autoScroll: boolean) => void; } @@ -50,6 +55,7 @@ const MessageContainer = memo( elements, messages, onFeedbackUpdated, + onFeedbackDeleted, callAction, setAutoScroll }: Props) => { @@ -164,6 +170,7 @@ const MessageContainer = memo( onElementRefClick, onError, onFeedbackUpdated, + onFeedbackDeleted, onPlaygroundButtonClick }; }, [ diff --git a/frontend/src/components/organisms/chat/Messages/index.tsx b/frontend/src/components/organisms/chat/Messages/index.tsx index 7ff1d8ec28..0ac12ea380 100644 --- a/frontend/src/components/organisms/chat/Messages/index.tsx +++ b/frontend/src/components/organisms/chat/Messages/index.tsx @@ -106,6 +106,34 @@ const Messages = ({ [] ); + const onFeedbackDeleted = useCallback( + async (message: IStep, onSuccess: () => void, feedbackId: string) => { + try { + toast.promise(apiClient.deleteFeedback(feedbackId, accessToken), { + loading: t('components.organisms.chat.Messages.index.updating'), + success: () => { + setMessages((prev) => + updateMessageById(prev, message.id, { + ...message, + feedback: undefined + }) + ); + onSuccess(); + return t( + 'components.organisms.chat.Messages.index.feedbackUpdated' + ); + }, + error: (err) => { + return {err.message}; + } + }); + } catch (err) { + console.log(err); + } + }, + [] + ); + return !idToResume && !messages.length && projectSettings?.ui.show_readme_as_default ? ( @@ -125,6 +153,7 @@ const Messages = ({ messages={messages} autoScroll={autoScroll} onFeedbackUpdated={onFeedbackUpdated} + onFeedbackDeleted={onFeedbackDeleted} callAction={callActionWithToast} setAutoScroll={setAutoScroll} /> diff --git a/frontend/src/components/organisms/threadHistory/Thread.tsx b/frontend/src/components/organisms/threadHistory/Thread.tsx index 8063224245..2f113a5397 100644 --- a/frontend/src/components/organisms/threadHistory/Thread.tsx +++ b/frontend/src/components/organisms/threadHistory/Thread.tsx @@ -1,4 +1,5 @@ import { useCallback, useEffect, useState } from 'react'; +import { useTranslation } from 'react-i18next'; import { Link } from 'react-router-dom'; import { useRecoilValue } from 'recoil'; import { toast } from 'sonner'; @@ -31,6 +32,7 @@ const Thread = ({ thread, error, isLoading }: Props) => { const accessToken = useRecoilValue(accessTokenState); const [steps, setSteps] = useState([]); const apiClient = useRecoilValue(apiClientState); + const { t } = useTranslation(); useEffect(() => { if (!thread) return; @@ -72,6 +74,40 @@ const Thread = ({ thread, error, isLoading }: Props) => { [] ); + const onFeedbackDeleted = useCallback( + async (message: IStep, onSuccess: () => void, feedbackId: string) => { + try { + toast.promise(apiClient.deleteFeedback(feedbackId, accessToken), { + loading: t('components.organisms.chat.Messages.index.updating'), + success: () => { + setSteps((prev) => + prev.map((step) => { + if (step.id === message.id) { + return { + ...step, + feedback: undefined + }; + } + return step; + }) + ); + + onSuccess(); + return t( + 'components.organisms.chat.Messages.index.feedbackUpdated' + ); + }, + error: (err) => { + return {err.message}; + } + }); + } catch (err) { + console.log(err); + } + }, + [] + ); + if (isLoading) { return ( <> @@ -150,6 +186,7 @@ const Thread = ({ thread, error, isLoading }: Props) => { actions={actions} elements={(elements || []) as IMessageElement[]} onFeedbackUpdated={onFeedbackUpdated} + onFeedbackDeleted={onFeedbackDeleted} messages={messages} autoScroll={true} /> diff --git a/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx b/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx index 53f0bb020a..e6ce544f8a 100644 --- a/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx +++ b/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx @@ -13,10 +13,9 @@ import Stack from '@mui/material/Stack'; import { threadsFiltersState } from 'state/threads'; -export enum FEEDBACKS { - ALL = 0, +export enum Feedback { POSITIVE = 1, - NEGATIVE = -1 + NEGATIVE = 0 } export default function FeedbackSelect() { @@ -25,12 +24,12 @@ export default function FeedbackSelect() { const { t } = useTranslation(); - const handleChange = (feedback: number) => { + const handleChange = (feedback?: number) => { setFilters((prev) => ({ ...prev, feedback })); setAnchorEl(null); }; - const renderMenuItem = (label: string, feedback: number) => { + const renderMenuItem = (label: string, feedback?: number) => { return ( handleChange(feedback)} @@ -53,9 +52,9 @@ export default function FeedbackSelect() { const sx = { width: 16, height: 16 }; switch (filters.feedback) { - case FEEDBACKS.POSITIVE: + case Feedback.POSITIVE: return ; - case FEEDBACKS.NEGATIVE: + case Feedback.NEGATIVE: return ; default: return ; @@ -102,19 +101,19 @@ export default function FeedbackSelect() { t( 'components.organisms.threadHistory.sidebar.filters.FeedbackSelect.feedbackAll' ), - FEEDBACKS.ALL + undefined )} {renderMenuItem( t( 'components.organisms.threadHistory.sidebar.filters.FeedbackSelect.feedbackPositive' ), - FEEDBACKS.POSITIVE + Feedback.POSITIVE )} {renderMenuItem( t( 'components.organisms.threadHistory.sidebar.filters.FeedbackSelect.feedbackNegative' ), - FEEDBACKS.NEGATIVE + Feedback.NEGATIVE )} diff --git a/frontend/src/types/messageContext.ts b/frontend/src/types/messageContext.ts index 34f272a4d0..a5d436a6f5 100644 --- a/frontend/src/types/messageContext.ts +++ b/frontend/src/types/messageContext.ts @@ -30,6 +30,11 @@ interface IMessageContext { onSuccess: () => void, feedback: IFeedback ) => void; + onFeedbackDeleted?: ( + message: IStep, + onSuccess: () => void, + feedbackId: string + ) => void; onError: (error: string) => void; } diff --git a/frontend/tests/message.spec.tsx b/frontend/tests/message.spec.tsx index c3169b860e..c326ca000f 100644 --- a/frontend/tests/message.spec.tsx +++ b/frontend/tests/message.spec.tsx @@ -24,7 +24,8 @@ describe('Message', () => { name: 'bar', createdAt: '12/12/2002', start: '12/12/2002', - end: '12/12/2002' + end: '12/12/2002', + disableFeedback: true } ], waitForAnswer: false, diff --git a/libs/copilot/src/chat/messages/container.tsx b/libs/copilot/src/chat/messages/container.tsx index 391d850bf0..e73dd03afa 100644 --- a/libs/copilot/src/chat/messages/container.tsx +++ b/libs/copilot/src/chat/messages/container.tsx @@ -33,6 +33,11 @@ interface Props { onSuccess: () => void, feedback: IFeedback ) => void; + onFeedbackDeleted: ( + message: IStep, + onSuccess: () => void, + feedbackId: string + ) => void; callAction?: (action: IAction) => void; setAutoScroll?: (autoScroll: boolean) => void; } @@ -47,6 +52,7 @@ const MessageContainer = memo( elements, messages, onFeedbackUpdated, + onFeedbackDeleted, callAction, setAutoScroll }: Props) => { @@ -115,6 +121,7 @@ const MessageContainer = memo( onElementRefClick, onError, onFeedbackUpdated, + onFeedbackDeleted, onPlaygroundButtonClick }; }, [ diff --git a/libs/copilot/src/chat/messages/index.tsx b/libs/copilot/src/chat/messages/index.tsx index a8806d05c2..af3f5ac367 100644 --- a/libs/copilot/src/chat/messages/index.tsx +++ b/libs/copilot/src/chat/messages/index.tsx @@ -96,6 +96,32 @@ const Messages = ({ [] ); + const onFeedbackDeleted = useCallback( + async (message: IStep, onSuccess: () => void, feedbackId: string) => { + try { + toast.promise(apiClient.deleteFeedback(feedbackId, accessToken), { + loading: 'Updating', + success: (res) => { + setMessages((prev) => + updateMessageById(prev, message.id, { + ...message, + feedback: undefined + }) + ); + onSuccess(); + return 'Feedback updated!'; + }, + error: (err) => { + return {err.message}; + } + }); + } catch (err) { + console.log(err); + } + }, + [] + ); + const showWelcomeScreen = !idToResume && !messages.length && @@ -122,6 +148,7 @@ const Messages = ({ messages={messages} autoScroll={autoScroll} onFeedbackUpdated={onFeedbackUpdated} + onFeedbackDeleted={onFeedbackDeleted} callAction={callActionWithToast} setAutoScroll={setAutoScroll} /> diff --git a/libs/react-client/src/api/index.tsx b/libs/react-client/src/api/index.tsx index b6f7aeb30a..f542d11a16 100644 --- a/libs/react-client/src/api/index.tsx +++ b/libs/react-client/src/api/index.tsx @@ -213,6 +213,14 @@ export class ChainlitAPI extends APIBase { return res.json(); } + async deleteFeedback( + feedbackId: string, + accessToken?: string + ): Promise<{ success: boolean }> { + const res = await this.delete(`/feedback`, { feedbackId }, accessToken); + return res.json(); + } + async listThreads( pagination: IPagination, filter: IThreadFilters, diff --git a/libs/react-client/src/types/feedback.ts b/libs/react-client/src/types/feedback.ts index 7c8b764875..625a8df4de 100644 --- a/libs/react-client/src/types/feedback.ts +++ b/libs/react-client/src/types/feedback.ts @@ -2,6 +2,5 @@ export interface IFeedback { id?: string; forId?: string; comment?: string; - strategy: 'BINARY'; value: number; } diff --git a/libs/react-client/src/types/thread.ts b/libs/react-client/src/types/thread.ts index 8dbdfc73e5..fe0cc2bbc2 100644 --- a/libs/react-client/src/types/thread.ts +++ b/libs/react-client/src/types/thread.ts @@ -1,12 +1,12 @@ import { IElement } from './element'; import { IStep } from './step'; -import { IUser } from './user'; export interface IThread { id: string; createdAt: number | string; name?: string; - user?: IUser; + userId?: string; + userIdentifier?: string; metadata?: Record; steps: IStep[]; elements?: IElement[];