From 59398777b7f04048ddd8961a47924a6a58c23be0 Mon Sep 17 00:00:00 2001 From: Josh Hayes <35790761+hayescode@users.noreply.github.com> Date: Wed, 10 Apr 2024 12:15:16 -0500 Subject: [PATCH] Update sql_alchemy.py --- backend/chainlit/data/sql_alchemy.py | 159 +++++++++++++-------------- 1 file changed, 76 insertions(+), 83 deletions(-) diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py index 1f1c752f3b..bb355f4d16 100644 --- a/backend/chainlit/data/sql_alchemy.py +++ b/backend/chainlit/data/sql_alchemy.py @@ -65,19 +65,16 @@ async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str return None async def get_current_timestamp(self) -> str: - return datetime.now(timezone.utc).astimezone().isoformat() + return datetime.now().isoformat() + "Z" def clean_result(self, obj): """Recursively change UUID -> str and serialize dictionaries""" if isinstance(obj, dict): - for k, v in obj.items(): - obj[k] = self.clean_result(v) + return {k: self.clean_result(v) for k, v in obj.items()} elif isinstance(obj, list): return [self.clean_result(item) for item in obj] elif isinstance(obj, uuid.UUID): return str(obj) - elif isinstance(obj, dict): - return json.dumps(obj) return obj ###### User ###### @@ -119,31 +116,20 @@ async def get_thread_author(self, thread_id: str) -> str: if isinstance(result, list) and result[0]: author_identifier = result[0].get('userIdentifier') if author_identifier is not None: + print(f'Author found: {author_identifier}') return author_identifier raise ValueError (f"Author not found for thread_id {thread_id}") async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}") - thread_user_identifier = await self.get_thread_author(thread_id=thread_id) - if not thread_user_identifier: - return None - user = await self.get_user(thread_user_identifier) - if not user: - return None + user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(thread_id=thread_id) + if user_threads: + return user_threads[0] else: - user_id = user.id - user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_id=user_id) - if not user_threads: return None - for thread in user_threads: - if thread['id'] == thread_id: - return thread - return None async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None): logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}") - if not user_id: - return if context.session.user is not None: user_identifier = context.session.user.identifier else: @@ -234,6 +220,7 @@ async def create_step(self, step_dict: 'StepDict'): raise ValueError("No authenticated user in context") step_dict['showInput'] = str(step_dict.get('showInput', '')).lower() if 'showInput' in step_dict else None parameters = {key: value for key, value in step_dict.items() if value is not None and not (isinstance(value, dict) and not value)} + parameters['metadata'] = json.dumps(step_dict.get('metadata', {})) columns = ', '.join(f'"{key}"' for key in parameters.keys()) values = ', '.join(f':{key}' for key in parameters.keys()) updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id') @@ -356,8 +343,8 @@ async def delete_element(self, element_id: str): async def delete_user_session(self, id: str) -> bool: return False # Not sure why documentation wants this - async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]]: - """Fetch all user threads for fast retrieval, up to self.user_thread_limit""" + async def get_all_user_threads(self, user_id: Optional[str] = None, thread_id: Optional[str] = None) -> Optional[List[ThreadDict]]: + """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided.""" logger.info(f"SQLAlchemy: get_all_user_threads") user_threads_query = """ SELECT @@ -369,16 +356,17 @@ async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]] "tags" AS thread_tags, "metadata" AS thread_metadata FROM threads - WHERE "userId" = :user_id + WHERE "userId" = :user_id OR "id" = :thread_id ORDER BY "createdAt" DESC LIMIT :limit """ - user_threads = await self.execute_sql(query=user_threads_query, parameters={"user_id": user_id, "limit": self.user_thread_limit}) + user_threads = await self.execute_sql(query=user_threads_query, parameters={"user_id": user_id, "limit": self.user_thread_limit, "thread_id": thread_id}) if not isinstance(user_threads, list): return None if not user_threads: return [] - thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')" + else: + thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')" steps_feedbacks_query = f""" SELECT @@ -392,6 +380,7 @@ async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]] s."waitForAnswer" AS step_waitforanswer, s."isError" AS step_iserror, s."metadata" AS step_metadata, + s."tags" AS step_tags, s."input" AS step_input, s."output" AS step_output, s."createdAt" AS step_createdat, @@ -432,72 +421,76 @@ async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]] thread_dicts = {} for thread in user_threads: thread_id = thread['thread_id'] - thread_dicts[thread_id] = ThreadDict( - id=thread_id, - createdAt=thread['thread_createdat'], - name=thread['thread_name'], - userId=thread['user_id'], - userIdentifier=thread['user_identifier'], - tags=thread['thread_tags'], - metadata=thread['thread_metadata'], - steps=[], - elements=[] - ) + if thread_id is not None: + thread_dicts[thread_id] = ThreadDict( + id=thread_id, + createdAt=thread['thread_createdat'], + name=thread['thread_name'], + userId=thread['user_id'], + userIdentifier=thread['user_identifier'], + tags=thread['thread_tags'], + metadata=thread['thread_metadata'], + steps=[], + elements=[] + ) # Process steps_feedbacks to populate the steps in the corresponding ThreadDict if isinstance(steps_feedbacks, list): for step_feedback in steps_feedbacks: thread_id = step_feedback['step_threadid'] - feedback = None - if step_feedback['feedback_value'] is not None: - feedback = FeedbackDict( - forId=step_feedback['step_id'], - id=step_feedback.get('feedback_id'), - value=step_feedback['feedback_value'], - comment=step_feedback.get('feedback_comment') + if thread_id is not None: + feedback = None + if step_feedback['feedback_value'] is not None: + feedback = FeedbackDict( + forId=step_feedback['step_id'], + id=step_feedback.get('feedback_id'), + value=step_feedback['feedback_value'], + comment=step_feedback.get('feedback_comment') + ) + step_dict = StepDict( + id=step_feedback['step_id'], + name=step_feedback['step_name'], + type=step_feedback['step_type'], + threadId=thread_id, + parentId=step_feedback.get('step_parentid'), + disableFeedback=step_feedback.get('step_disablefeedback', False), + streaming=step_feedback.get('step_streaming', False), + waitForAnswer=step_feedback.get('step_waitforanswer'), + isError=step_feedback.get('step_iserror'), + metadata=step_feedback['step_metadata'] if step_feedback.get('step_metadata') is not None else {}, + tags=step_feedback.get('step_tags'), + input=step_feedback.get('step_input', '') if step_feedback['step_showinput'] else '', + output=step_feedback.get('step_output', ''), + createdAt=step_feedback.get('step_createdat'), + start=step_feedback.get('step_start'), + end=step_feedback.get('step_end'), + generation=step_feedback.get('step_generation'), + showInput=step_feedback.get('step_showinput'), + language=step_feedback.get('step_language'), + indent=step_feedback.get('step_indent'), + feedback=feedback ) - step_dict = StepDict( - id=step_feedback['step_id'], - name=step_feedback['step_name'], - type=step_feedback['step_type'], - threadId=thread_id, - parentId=step_feedback.get('step_parentid'), - disableFeedback=step_feedback.get('step_disablefeedback', False), - streaming=step_feedback.get('step_streaming', False), - waitForAnswer=step_feedback.get('step_waitforanswer'), - isError=step_feedback.get('step_iserror'), - metadata=step_feedback.get('step_metadata', {}), - input=step_feedback.get('step_input', '') if step_feedback['step_showinput'] else None, - output=step_feedback.get('step_output', ''), - createdAt=step_feedback.get('step_createdat'), - start=step_feedback.get('step_start'), - end=step_feedback.get('step_end'), - generation=step_feedback.get('step_generation'), - showInput=step_feedback.get('step_showinput'), - language=step_feedback.get('step_language'), - indent=step_feedback.get('step_indent'), - feedback=feedback - ) - # Append the step to the steps list of the corresponding ThreadDict - thread_dicts[thread_id]['steps'].append(step_dict) + # Append the step to the steps list of the corresponding ThreadDict + thread_dicts[thread_id]['steps'].append(step_dict) if isinstance(elements, list): for element in elements: thread_id = element['element_threadid'] - element_dict = ElementDict( - id=element['element_id'], - threadId=thread_id, - type=element['element_type'], - chainlitKey=element.get('element_chainlitkey'), - url=element.get('element_url'), - objectKey=element.get('element_objectkey'), - name=element['element_name'], - display=element['element_display'], - size=element.get('element_size'), - language=element.get('element_language'), - page=element.get('element_page'), - forId=element.get('element_forid'), - mime=element.get('element_mime'), - ) - thread_dicts[thread_id]['elements'].append(element_dict) # type: ignore + if thread_id is not None: + element_dict = ElementDict( + id=element['element_id'], + threadId=thread_id, + type=element['element_type'], + chainlitKey=element.get('element_chainlitkey'), + url=element.get('element_url'), + objectKey=element.get('element_objectkey'), + name=element['element_name'], + display=element['element_display'], + size=element.get('element_size'), + language=element.get('element_language'), + page=element.get('element_page'), + forId=element.get('element_forid'), + mime=element.get('element_mime'), + ) + thread_dicts[thread_id]['elements'].append(element_dict) # type: ignore return list(thread_dicts.values())