Skip to content

Commit

Permalink
Update sql_alchemy.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hayescode committed Apr 10, 2024
1 parent 7de42e6 commit 5939877
Showing 1 changed file with 76 additions and 83 deletions.
159 changes: 76 additions & 83 deletions backend/chainlit/data/sql_alchemy.py
Expand Up @@ -65,19 +65,16 @@ async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str
return None

async def get_current_timestamp(self) -> str:
return datetime.now(timezone.utc).astimezone().isoformat()
return datetime.now().isoformat() + "Z"

def clean_result(self, obj):
"""Recursively change UUID -> str and serialize dictionaries"""
if isinstance(obj, dict):
for k, v in obj.items():
obj[k] = self.clean_result(v)
return {k: self.clean_result(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self.clean_result(item) for item in obj]
elif isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, dict):
return json.dumps(obj)
return obj

###### User ######
Expand Down Expand Up @@ -119,31 +116,20 @@ async def get_thread_author(self, thread_id: str) -> str:
if isinstance(result, list) and result[0]:
author_identifier = result[0].get('userIdentifier')
if author_identifier is not None:
print(f'Author found: {author_identifier}')
return author_identifier
raise ValueError (f"Author not found for thread_id {thread_id}")

async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
thread_user_identifier = await self.get_thread_author(thread_id=thread_id)
if not thread_user_identifier:
return None
user = await self.get_user(thread_user_identifier)
if not user:
return None
user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(thread_id=thread_id)
if user_threads:
return user_threads[0]
else:
user_id = user.id
user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_id=user_id)
if not user_threads:
return None
for thread in user_threads:
if thread['id'] == thread_id:
return thread
return None

async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None):
logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
if not user_id:
return
if context.session.user is not None:
user_identifier = context.session.user.identifier
else:
Expand Down Expand Up @@ -234,6 +220,7 @@ async def create_step(self, step_dict: 'StepDict'):
raise ValueError("No authenticated user in context")
step_dict['showInput'] = str(step_dict.get('showInput', '')).lower() if 'showInput' in step_dict else None
parameters = {key: value for key, value in step_dict.items() if value is not None and not (isinstance(value, dict) and not value)}
parameters['metadata'] = json.dumps(step_dict.get('metadata', {}))
columns = ', '.join(f'"{key}"' for key in parameters.keys())
values = ', '.join(f':{key}' for key in parameters.keys())
updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
Expand Down Expand Up @@ -356,8 +343,8 @@ async def delete_element(self, element_id: str):
async def delete_user_session(self, id: str) -> bool:
return False # Not sure why documentation wants this

async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]]:
"""Fetch all user threads for fast retrieval, up to self.user_thread_limit"""
async def get_all_user_threads(self, user_id: Optional[str] = None, thread_id: Optional[str] = None) -> Optional[List[ThreadDict]]:
"""Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
logger.info(f"SQLAlchemy: get_all_user_threads")
user_threads_query = """
SELECT
Expand All @@ -369,16 +356,17 @@ async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]]
"tags" AS thread_tags,
"metadata" AS thread_metadata
FROM threads
WHERE "userId" = :user_id
WHERE "userId" = :user_id OR "id" = :thread_id
ORDER BY "createdAt" DESC
LIMIT :limit
"""
user_threads = await self.execute_sql(query=user_threads_query, parameters={"user_id": user_id, "limit": self.user_thread_limit})
user_threads = await self.execute_sql(query=user_threads_query, parameters={"user_id": user_id, "limit": self.user_thread_limit, "thread_id": thread_id})
if not isinstance(user_threads, list):
return None
if not user_threads:
return []
thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')"
else:
thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')"

steps_feedbacks_query = f"""
SELECT
Expand All @@ -392,6 +380,7 @@ async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]]
s."waitForAnswer" AS step_waitforanswer,
s."isError" AS step_iserror,
s."metadata" AS step_metadata,
s."tags" AS step_tags,
s."input" AS step_input,
s."output" AS step_output,
s."createdAt" AS step_createdat,
Expand Down Expand Up @@ -432,72 +421,76 @@ async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]]
thread_dicts = {}
for thread in user_threads:
thread_id = thread['thread_id']
thread_dicts[thread_id] = ThreadDict(
id=thread_id,
createdAt=thread['thread_createdat'],
name=thread['thread_name'],
userId=thread['user_id'],
userIdentifier=thread['user_identifier'],
tags=thread['thread_tags'],
metadata=thread['thread_metadata'],
steps=[],
elements=[]
)
if thread_id is not None:
thread_dicts[thread_id] = ThreadDict(
id=thread_id,
createdAt=thread['thread_createdat'],
name=thread['thread_name'],
userId=thread['user_id'],
userIdentifier=thread['user_identifier'],
tags=thread['thread_tags'],
metadata=thread['thread_metadata'],
steps=[],
elements=[]
)
# Process steps_feedbacks to populate the steps in the corresponding ThreadDict
if isinstance(steps_feedbacks, list):
for step_feedback in steps_feedbacks:
thread_id = step_feedback['step_threadid']
feedback = None
if step_feedback['feedback_value'] is not None:
feedback = FeedbackDict(
forId=step_feedback['step_id'],
id=step_feedback.get('feedback_id'),
value=step_feedback['feedback_value'],
comment=step_feedback.get('feedback_comment')
if thread_id is not None:
feedback = None
if step_feedback['feedback_value'] is not None:
feedback = FeedbackDict(
forId=step_feedback['step_id'],
id=step_feedback.get('feedback_id'),
value=step_feedback['feedback_value'],
comment=step_feedback.get('feedback_comment')
)
step_dict = StepDict(
id=step_feedback['step_id'],
name=step_feedback['step_name'],
type=step_feedback['step_type'],
threadId=thread_id,
parentId=step_feedback.get('step_parentid'),
disableFeedback=step_feedback.get('step_disablefeedback', False),
streaming=step_feedback.get('step_streaming', False),
waitForAnswer=step_feedback.get('step_waitforanswer'),
isError=step_feedback.get('step_iserror'),
metadata=step_feedback['step_metadata'] if step_feedback.get('step_metadata') is not None else {},
tags=step_feedback.get('step_tags'),
input=step_feedback.get('step_input', '') if step_feedback['step_showinput'] else '',
output=step_feedback.get('step_output', ''),
createdAt=step_feedback.get('step_createdat'),
start=step_feedback.get('step_start'),
end=step_feedback.get('step_end'),
generation=step_feedback.get('step_generation'),
showInput=step_feedback.get('step_showinput'),
language=step_feedback.get('step_language'),
indent=step_feedback.get('step_indent'),
feedback=feedback
)
step_dict = StepDict(
id=step_feedback['step_id'],
name=step_feedback['step_name'],
type=step_feedback['step_type'],
threadId=thread_id,
parentId=step_feedback.get('step_parentid'),
disableFeedback=step_feedback.get('step_disablefeedback', False),
streaming=step_feedback.get('step_streaming', False),
waitForAnswer=step_feedback.get('step_waitforanswer'),
isError=step_feedback.get('step_iserror'),
metadata=step_feedback.get('step_metadata', {}),
input=step_feedback.get('step_input', '') if step_feedback['step_showinput'] else None,
output=step_feedback.get('step_output', ''),
createdAt=step_feedback.get('step_createdat'),
start=step_feedback.get('step_start'),
end=step_feedback.get('step_end'),
generation=step_feedback.get('step_generation'),
showInput=step_feedback.get('step_showinput'),
language=step_feedback.get('step_language'),
indent=step_feedback.get('step_indent'),
feedback=feedback
)
# Append the step to the steps list of the corresponding ThreadDict
thread_dicts[thread_id]['steps'].append(step_dict)
# Append the step to the steps list of the corresponding ThreadDict
thread_dicts[thread_id]['steps'].append(step_dict)

if isinstance(elements, list):
for element in elements:
thread_id = element['element_threadid']
element_dict = ElementDict(
id=element['element_id'],
threadId=thread_id,
type=element['element_type'],
chainlitKey=element.get('element_chainlitkey'),
url=element.get('element_url'),
objectKey=element.get('element_objectkey'),
name=element['element_name'],
display=element['element_display'],
size=element.get('element_size'),
language=element.get('element_language'),
page=element.get('element_page'),
forId=element.get('element_forid'),
mime=element.get('element_mime'),
)
thread_dicts[thread_id]['elements'].append(element_dict) # type: ignore
if thread_id is not None:
element_dict = ElementDict(
id=element['element_id'],
threadId=thread_id,
type=element['element_type'],
chainlitKey=element.get('element_chainlitkey'),
url=element.get('element_url'),
objectKey=element.get('element_objectkey'),
name=element['element_name'],
display=element['element_display'],
size=element.get('element_size'),
language=element.get('element_language'),
page=element.get('element_page'),
forId=element.get('element_forid'),
mime=element.get('element_mime'),
)
thread_dicts[thread_id]['elements'].append(element_dict) # type: ignore

return list(thread_dicts.values())

0 comments on commit 5939877

Please sign in to comment.