Skip to content

Commit

Permalink
remove context check except for db writes
Browse files Browse the repository at this point in the history
  • Loading branch information
hayescode committed Apr 10, 2024
1 parent 67bdff4 commit 7de42e6
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,25 +113,25 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:
###### Threads ######
async def get_thread_author(self, thread_id: str) -> str:
logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
query = """SELECT u.* FROM threads t JOIN users u ON t."user_id" = u."id" WHERE t."id" = :id"""
query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
parameters = {"id": thread_id}
result = await self.execute_sql(query=query, parameters=parameters)
if isinstance(result, list) and result[0]:
author_identifier = result[0].get('identifier')
author_identifier = result[0].get('userIdentifier')
if author_identifier is not None:
return author_identifier
raise ValueError(f"Author not found for thread_id {thread_id}")
raise ValueError (f"Author not found for thread_id {thread_id}")

async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
if isinstance(context.session.user, PersistedUser):
user_id = context.session.user.id
thread_user_identifier = await self.get_thread_author(thread_id=thread_id)
if not thread_user_identifier:
return None
user = await self.get_user(thread_user_identifier)
if not user:
return None
else:
raise ValueError("User not found in session context or is not a PersistedUser")
user_id = user.id
user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_id=user_id)
if not user_threads:
return None
Expand All @@ -142,8 +142,8 @@ async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:

async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None):
logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
if not user_id:
return
if context.session.user is not None:
user_identifier = context.session.user.identifier
else:
Expand Down Expand Up @@ -171,16 +171,19 @@ async def update_thread(self, thread_id: str, name: Optional[str] = None, user_i

async def delete_thread(self, thread_id: str):
logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
query = """DELETE FROM threads WHERE "id" = :id"""
# Delete feedbacks/elements/steps/thread
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)"""
elements_query = """DELETE FROM elements WHERE "threadId" = :id"""
steps_query = """DELETE FROM steps WHERE "threadId" = :id"""
thread_query = """DELETE FROM threads WHERE "id" = :id"""
parameters = {"id": thread_id}
await self.execute_sql(query=query, parameters=parameters)
await self.execute_sql(query=feedbacks_query, parameters=parameters)
await self.execute_sql(query=elements_query, parameters=parameters)
await self.execute_sql(query=steps_query, parameters=parameters)
await self.execute_sql(query=thread_query, parameters=parameters)

async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse:
logger.info(f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
if not filters.userId:
raise ValueError("userId is required")
all_user_threads: List[ThreadDict] = await self.get_all_user_threads(user_id=filters.userId) or []
Expand Down Expand Up @@ -250,11 +253,14 @@ async def update_step(self, step_dict: 'StepDict'):
@queue_until_user_message()
async def delete_step(self, step_id: str):
logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
query = """DELETE FROM steps WHERE "id" = :id"""
# Delete feedbacks/elements/steps
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id"""
elements_query = """DELETE FROM elements WHERE "forId" = :id"""
steps_query = """DELETE FROM steps WHERE "forId" = :id"""
parameters = {"id": step_id}
await self.execute_sql(query=query, parameters=parameters)
await self.execute_sql(query=feedbacks_query, parameters=parameters)
await self.execute_sql(query=elements_query, parameters=parameters)
await self.execute_sql(query=steps_query, parameters=parameters)

###### Feedback ######
async def upsert_feedback(self, feedback: Feedback) -> str:
Expand All @@ -279,8 +285,6 @@ async def upsert_feedback(self, feedback: Feedback) -> str:

async def delete_feedback(self, feedback_id: str) -> bool:
logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
parameters = {"feedback_id": feedback_id}
await self.execute_sql(query=query, parameters=parameters)
Expand Down Expand Up @@ -345,8 +349,6 @@ async def create_element(self, element: 'Element'):
@queue_until_user_message()
async def delete_element(self, element_id: str):
logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
query = """DELETE FROM elements WHERE "id" = :id"""
parameters = {"id": element_id}
await self.execute_sql(query=query, parameters=parameters)
Expand All @@ -357,8 +359,6 @@ async def delete_user_session(self, id: str) -> bool:
async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]]:
"""Fetch all user threads for fast retrieval, up to self.user_thread_limit"""
logger.info(f"SQLAlchemy: get_all_user_threads")
if not getattr(context.session.user, 'id', None):
raise ValueError("No authenticated user in context")
user_threads_query = """
SELECT
"id" AS thread_id,
Expand All @@ -376,9 +376,9 @@ async def get_all_user_threads(self, user_id: str) -> Optional[List[ThreadDict]]
user_threads = await self.execute_sql(query=user_threads_query, parameters={"user_id": user_id, "limit": self.user_thread_limit})
if not isinstance(user_threads, list):
return None
thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')"
if not thread_ids:
if not user_threads:
return []
thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')"

steps_feedbacks_query = f"""
SELECT
Expand Down

0 comments on commit 7de42e6

Please sign in to comment.