Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SQLAlchemy for custom data layer #836

Merged
merged 42 commits into from Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
1a4c361
Create sql_alchemy.py
hayescode Mar 21, 2024
bec6e5d
Update pyproject.toml
hayescode Mar 21, 2024
354ddc2
Create sql_alchemy.py
hayescode Mar 21, 2024
775ed88
Merge branch 'Chainlit:main' into list_thread-update-and-imports
hayescode Mar 25, 2024
21b2d31
Merge branch 'Chainlit:main' into main
hayescode Mar 25, 2024
610a0f4
Create sql_alchemy.py
hayescode Mar 25, 2024
e11d00d
Delete backend/chainlit/sql_alchemy.py
hayescode Mar 25, 2024
3e576c7
Merge pull request #2 from hayescode/list_thread-update-and-imports
hayescode Mar 25, 2024
b99d05d
Delete backend/chainlit/sql_alchemy.py
hayescode Mar 25, 2024
59b779d
Update types.py
hayescode Mar 25, 2024
7d87846
Update BaseDataLayer to use types from chainlit
hayescode Mar 25, 2024
b93c6ec
Update __init__.py
hayescode Mar 25, 2024
f244196
Update sql_alchemy.py
hayescode Mar 25, 2024
055a906
persist elements with url, sql select typo fixes
hayescode Mar 26, 2024
2e386b5
Merge branch 'Chainlit:main' into main
hayescode Mar 28, 2024
03f748b
Update pyproject.toml
hayescode Mar 28, 2024
0bd0aa6
Update __init__.py
hayescode Mar 28, 2024
b80209d
Update __init__.py
hayescode Mar 28, 2024
f7cc829
Create storage_clients.py
hayescode Mar 28, 2024
56056d0
Update sql_alchemy.py
hayescode Mar 28, 2024
2346bd4
Update sql_alchemy.py
hayescode Mar 28, 2024
f0a71ac
Update storage_clients.py
hayescode Mar 28, 2024
8711ec4
Update __init__.py
hayescode Mar 28, 2024
4a42df9
Merge branch 'main' into main
hayescode Apr 2, 2024
44312e5
Chainlit 1.0.500 compatability (#3)
hayescode Apr 2, 2024
3e36bff
Merge branch 'Chainlit:main' into main
hayescode Apr 2, 2024
52b07bf
Update types.py
hayescode Apr 2, 2024
cc0e307
Merge branch 'Chainlit:main' into main
hayescode Apr 4, 2024
1030522
Update pyproject.toml to make custom-data optional
hayescode Apr 4, 2024
dde5e0b
Update pyproject.toml
hayescode Apr 4, 2024
7729488
Update mypy.yaml
hayescode Apr 4, 2024
4b49360
Update storage_clients.py
hayescode Apr 4, 2024
3da7d81
Update sql_alchemy.py
hayescode Apr 4, 2024
a03da85
Update storage_clients.py
hayescode Apr 4, 2024
27ba523
Update __init__.py
hayescode Apr 9, 2024
f4a3928
Merge branch 'main' into main
hayescode Apr 9, 2024
958c9cf
add context.session.user check
hayescode Apr 9, 2024
1fb4379
Update sql_alchemy.py
hayescode Apr 9, 2024
67bdff4
Merge branch 'Chainlit:main' into main
hayescode Apr 10, 2024
7de42e6
remove context check except for db writes
hayescode Apr 10, 2024
5939877
Update sql_alchemy.py
hayescode Apr 10, 2024
5d0cb18
Update sql_alchemy.py
hayescode Apr 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
Expand Up @@ -18,6 +18,6 @@ jobs:
- name: Install Poetry
run: pip install poetry
- name: Install dependencies
run: poetry install --with tests --with mypy
run: poetry install --with tests --with mypy --with custom-data
- name: Run Mypy
run: poetry run mypy chainlit/
20 changes: 16 additions & 4 deletions backend/chainlit/data/__init__.py
Expand Up @@ -2,16 +2,16 @@
import json
import os
from collections import deque
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast, Protocol, Any

import aiofiles
from chainlit.config import config
from chainlit.context import context
from chainlit.logger import logger
from chainlit.session import WebsocketSession
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter, PageInfo, PaginatedResponse
from chainlit.user import PersistedUser, User
from literalai import Attachment, PageInfo, PaginatedResponse, Score as LiteralScore, Step as LiteralStep
from literalai import Attachment, PaginatedResponse as LiteralPaginatedResponse, Score as LiteralScore, Step as LiteralStep
from literalai.filter import threads_filters as LiteralThreadsFilters
from literalai.step import StepDict as LiteralStepDict

Expand Down Expand Up @@ -411,12 +411,20 @@ async def list_threads(
}
)

return await self.client.api.list_threads(
literal_response: LiteralPaginatedResponse = await self.client.api.list_threads(
first=pagination.first,
after=pagination.cursor,
filters=literal_filters,
order_by={"column": "createdAt", "direction": "DESC"},
)
return PaginatedResponse(
pageInfo=PageInfo(
hasNextPage=literal_response.pageInfo.hasNextPage,
startCursor=literal_response.pageInfo.startCursor,
endCursor=literal_response.pageInfo.endCursor
),
data=literal_response.data,
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
thread = await self.client.api.get_thread(id=thread_id)
Expand Down Expand Up @@ -462,6 +470,10 @@ async def update_thread(
tags=tags,
)

class BaseStorageClient(Protocol):
    """Interface for non-text data persistence (Azure Data Lake, S3, Google Storage, etc.)."""

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = 'application/octet-stream',
        overwrite: bool = True,
    ) -> Dict[str, Any]:
        """Store ``data`` under ``object_key`` and return a dict describing the upload.

        The protocol itself provides no behavior; implementations supply it.
        """
        ...

if api_key := os.environ.get("LITERAL_API_KEY"):
# support legacy LITERAL_SERVER variable as fallback
Expand Down
494 changes: 494 additions & 0 deletions backend/chainlit/data/sql_alchemy.py

Large diffs are not rendered by default.

58 changes: 58 additions & 0 deletions backend/chainlit/data/storage_clients.py
@@ -0,0 +1,58 @@
from chainlit.data import BaseStorageClient
from chainlit.logger import logger
from typing import TYPE_CHECKING, Optional, Dict, Union, Any
from azure.storage.filedatalake import DataLakeServiceClient, FileSystemClient, DataLakeFileClient, ContentSettings
import boto3 # type: ignore

if TYPE_CHECKING:
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential

class AzureStorageClient(BaseStorageClient):
    """Storage client backed by Azure Data Lake Storage (ADLS) Gen2.

    Args:
        account_url: Data Lake endpoint, e.g.
            "https://<your_account>.dfs.core.windows.net".
        container: Name of the file system (container) that receives uploads.
        credential: Access credential, forwarded unchanged to
            ``DataLakeServiceClient``. Defaults to ``None`` so the argument may
            be omitted.
        sas_token: Optional SAS token appended verbatim to every returned URL.
            No separator is inserted, so include any leading "?" in the token.
    """

    def __init__(
        self,
        account_url: str,
        container: str,
        credential: Optional[
            Union[
                str,
                Dict[str, str],
                "AzureNamedKeyCredential",
                "AzureSasCredential",
                "TokenCredential",
            ]
        ] = None,
        sas_token: Optional[str] = None,
    ):
        # Initialization is best-effort: a failure is logged, not raised, so a
        # failed client has no attributes set and later uploads return {}.
        try:
            self.data_lake_client = DataLakeServiceClient(account_url=account_url, credential=credential)
            self.container_client: FileSystemClient = self.data_lake_client.get_file_system_client(file_system=container)
            self.sas_token = sas_token
            logger.info("AzureStorageClient initialized")
        except Exception as e:
            logger.warning(f"AzureStorageClient initialization error: {e}")

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = 'application/octet-stream',
        overwrite: bool = True,
    ) -> Dict[str, Any]:
        """Upload ``data`` to ``object_key`` inside the configured container.

        Returns:
            ``{"object_key": ..., "url": ...}`` on success, where the URL has
            ``sas_token`` appended when one was configured; ``{}`` if any
            exception occurred (the error is logged, not raised).
        """
        try:
            file_client: DataLakeFileClient = self.container_client.get_file_client(object_key)
            content_settings = ContentSettings(content_type=mime)
            # NOTE(review): this is the synchronous SDK client and the call is
            # not awaited, so the event loop is blocked while the upload runs.
            file_client.upload_data(data, overwrite=overwrite, content_settings=content_settings)
            url = f"{file_client.url}{self.sas_token}" if self.sas_token else file_client.url
            return {"object_key": object_key, "url": url}
        except Exception as e:
            logger.warning(f"AzureStorageClient, upload_file error: {e}")
            return {}

class S3StorageClient(BaseStorageClient):
    """Storage client backed by Amazon S3.

    Args:
        bucket: Name of the bucket that receives uploads. The boto3 client is
            created with ``boto3.client("s3")`` and no explicit settings.
    """

    def __init__(self, bucket: str):
        # Initialization is best-effort: a failure is logged, not raised.
        try:
            self.bucket = bucket
            self.client = boto3.client("s3")
            logger.info("S3StorageClient initialized")
        except Exception as e:
            logger.warning(f"S3StorageClient initialization error: {e}")

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = 'application/octet-stream',
        overwrite: bool = True,
    ) -> Dict[str, Any]:
        """Put ``data`` into the bucket under ``object_key``.

        ``overwrite`` is accepted for signature parity with the storage
        interface but is not consulted: ``put_object`` is called either way.

        Returns:
            ``{"object_key": ..., "url": ...}`` on success, with the URL built
            as ``https://<bucket>.s3.amazonaws.com/<object_key>``; ``{}`` if any
            exception occurred (the error is logged, not raised).
        """
        try:
            # NOTE(review): boto3's put_object is a synchronous call and is not
            # awaited, so the event loop is blocked while the upload runs.
            self.client.put_object(Bucket=self.bucket, Key=object_key, Body=data, ContentType=mime)
            url = f"https://{self.bucket}.s3.amazonaws.com/{object_key}"
            return {"object_key": object_key, "url": url}
        except Exception as e:
            logger.warning(f"S3StorageClient, upload_file error: {e}")
            return {}
54 changes: 53 additions & 1 deletion backend/chainlit/types.py
@@ -1,5 +1,5 @@
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypedDict, Union, Generic, TypeVar, Protocol, Any

if TYPE_CHECKING:
from chainlit.element import ElementDict
Expand Down Expand Up @@ -37,6 +37,58 @@ class ThreadFilter(BaseModel):
userId: Optional[str] = None
search: Optional[str] = None

@dataclass
class PageInfo:
    """Cursor information that accompanies one page of results."""

    hasNextPage: bool
    startCursor: Optional[str]
    endCursor: Optional[str]

    def to_dict(self):
        """Return the three fields as a plain dict keyed by field name."""
        page_info = {}
        page_info["hasNextPage"] = self.hasNextPage
        page_info["startCursor"] = self.startCursor
        page_info["endCursor"] = self.endCursor
        return page_info

    @classmethod
    def from_dict(cls, page_info_dict: Dict) -> "PageInfo":
        """Build a PageInfo from a dict; absent keys become False / None."""
        return cls(
            hasNextPage=page_info_dict.get("hasNextPage", False),
            startCursor=page_info_dict.get("startCursor", None),
            endCursor=page_info_dict.get("endCursor", None),
        )

T = TypeVar("T", covariant=True)


class HasFromDict(Protocol[T]):
    """Structural type for classes that expose a ``from_dict`` constructor."""

    @classmethod
    def from_dict(cls, obj_dict: Any) -> T:
        """Build a ``T`` from ``obj_dict``; the protocol's own body only raises."""
        raise NotImplementedError()

@dataclass
class PaginatedResponse(Generic[T]):
    """One page of items together with its ``PageInfo``."""

    pageInfo: PageInfo
    data: List[T]

    def to_dict(self):
        """Return a plain-dict form of the page.

        Items exposing a callable ``to_dict`` are converted through it; any
        other item is placed in the output unchanged.
        """
        page_info = self.pageInfo.to_dict()
        items = []
        for entry in self.data:
            if hasattr(entry, "to_dict") and callable(entry.to_dict):
                items.append(entry.to_dict())
            else:
                items.append(entry)
        return {"pageInfo": page_info, "data": items}

    @classmethod
    def from_dict(
        cls, paginated_response_dict: Dict, the_class: HasFromDict[T]
    ) -> "PaginatedResponse[T]":
        """Build a page from a dict, converting each item with ``the_class.from_dict``.

        A missing "pageInfo" is treated as ``{}`` and a missing "data" as ``[]``.
        """
        page_info = PageInfo.from_dict(paginated_response_dict.get("pageInfo", {}))

        items = []
        for raw in paginated_response_dict.get("data", []):
            items.append(the_class.from_dict(raw))

        return cls(pageInfo=page_info, data=items)

@dataclass
class FileSpec(DataClassJsonMixin):
Expand Down
10 changes: 10 additions & 0 deletions backend/pyproject.toml
Expand Up @@ -92,6 +92,16 @@ module = [
]
ignore_missing_imports = true

[tool.poetry.group.custom-data]
optional = true

[tool.poetry.group.custom-data.dependencies]
asyncpg = "^0.29.0"
SQLAlchemy = "^2.0.28"
boto3 = "^1.34.73"
azure-identity = "^1.14.1"
azure-storage-file-datalake = "^12.14.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
31 changes: 31 additions & 0 deletions cypress/e2e/custom_data_layer/sql_alchemy.py
@@ -0,0 +1,31 @@
from typing import List, Optional

import chainlit.data as cl_data
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.data.storage_clients import AzureStorageClient
from literalai.helper import utc_now

import chainlit as cl

# Placeholder values: replace them with real settings before running.
# `credential` is passed explicitly because AzureStorageClient's constructor
# declares it without a default value.
storage_client = AzureStorageClient(
    account_url="<your_account_url>",
    container="<your_container>",
    credential="<your_credential>",
)

cl_data._data_layer = SQLAlchemyDataLayer(
    conninfo="<your conninfo>", storage_provider=storage_client
)


@cl.on_chat_start
async def main():
    """Send the greeting message; registered via ``cl.on_chat_start``."""
    greeting = cl.Message("Hello, send me a message!", disable_feedback=True)
    await greeting.send()


@cl.on_message
async def handle_message():
    """Wait via ``cl.sleep(2)``, then send the fixed "Ok!" message."""
    await cl.sleep(2)
    reply = cl.Message("Ok!")
    await reply.send()


@cl.password_auth_callback
def auth_callback(username: str, password: str) -> Optional[cl.User]:
    """Accept only the hard-coded admin/admin pair; return None for anything else."""
    if username != "admin" or password != "admin":
        return None
    return cl.User(identifier="admin")