Support SQLAlchemy for custom data layer (#836)
- adds a custom, direct-to-database data layer using SQLAlchemy, with support for a wide range of SQL dialects
- configures ADLS or S3 as the blob storage provider
- duplicates `PageInfo` and `PaginatedResponse` from the Literal SDK into backend/chainlit/types.py and updates the typing
hayescode committed Apr 15, 2024
1 parent bd4c6f1 commit 2cd87ec
Showing 7 changed files with 663 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
@@ -18,6 +18,6 @@ jobs:
- name: Install Poetry
run: pip install poetry
- name: Install dependencies
run: poetry install --with tests --with mypy
run: poetry install --with tests --with mypy --with custom-data
- name: Run Mypy
run: poetry run mypy chainlit/
20 changes: 16 additions & 4 deletions backend/chainlit/data/__init__.py
@@ -2,16 +2,16 @@
import json
import os
from collections import deque
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast, Protocol, Any

import aiofiles
from chainlit.config import config
from chainlit.context import context
from chainlit.logger import logger
from chainlit.session import WebsocketSession
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter, PageInfo, PaginatedResponse
from chainlit.user import PersistedUser, User
from literalai import Attachment, PageInfo, PaginatedResponse, Score as LiteralScore, Step as LiteralStep
from literalai import Attachment, PaginatedResponse as LiteralPaginatedResponse, Score as LiteralScore, Step as LiteralStep
from literalai.filter import threads_filters as LiteralThreadsFilters
from literalai.step import StepDict as LiteralStepDict

@@ -411,12 +411,20 @@ async def list_threads(
}
)

return await self.client.api.list_threads(
literal_response: LiteralPaginatedResponse = await self.client.api.list_threads(
first=pagination.first,
after=pagination.cursor,
filters=literal_filters,
order_by={"column": "createdAt", "direction": "DESC"},
)
return PaginatedResponse(
pageInfo=PageInfo(
hasNextPage=literal_response.pageInfo.hasNextPage,
startCursor=literal_response.pageInfo.startCursor,
endCursor=literal_response.pageInfo.endCursor
),
data=literal_response.data,
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
thread = await self.client.api.get_thread(id=thread_id)
@@ -462,6 +470,10 @@ async def update_thread(
tags=tags,
)

class BaseStorageClient(Protocol):
"""Base class for non-text data persistence like Azure Data Lake, S3, Google Storage, etc."""
async def upload_file(self, object_key: str, data: Union[bytes, str], mime: str = 'application/octet-stream', overwrite: bool = True) -> Dict[str, Any]:
pass

if api_key := os.environ.get("LITERAL_API_KEY"):
# support legacy LITERAL_SERVER variable as fallback
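The `BaseStorageClient` protocol above only requires an async `upload_file` method, so storage providers beyond the bundled ADLS and S3 clients can be swapped in. A minimal sketch, assuming a hypothetical local-disk client that is not part of this commit:

import os
from typing import Any, Dict, Union

class LocalStorageClient:
    """Hypothetical BaseStorageClient implementation that writes files to local disk."""
    def __init__(self, root_dir: str):
        self.root_dir = root_dir

    async def upload_file(self, object_key: str, data: Union[bytes, str], mime: str = 'application/octet-stream', overwrite: bool = True) -> Dict[str, Any]:
        path = os.path.join(self.root_dir, object_key)
        if not overwrite and os.path.exists(path):
            return {}
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        # Write bytes or text depending on what the caller passed in
        with open(path, "wb" if isinstance(data, bytes) else "w") as f:
            f.write(data)
        return {"object_key": object_key, "url": f"file://{os.path.abspath(path)}"}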
494 changes: 494 additions & 0 deletions backend/chainlit/data/sql_alchemy.py

Large diffs are not rendered by default.
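The new sql_alchemy.py (not rendered above) implements the data layer directly against a database through SQLAlchemy's async engine, which is what makes it dialect-agnostic. A minimal sketch of that pattern, illustrative only: the table, column, and method names below are assumptions, not the actual file contents.

from typing import Any, Dict, List

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

class MinimalSQLLayer:
    def __init__(self, conninfo: str):
        # conninfo is a SQLAlchemy URL, e.g. "postgresql+asyncpg://user:pass@host/db"
        self.engine = create_async_engine(conninfo)

    async def get_user_threads(self, user_id: str) -> List[Dict[str, Any]]:
        # Parameterized text query executed on the async engine
        query = text('SELECT * FROM threads WHERE "userId" = :user_id')
        async with self.engine.connect() as conn:
            result = await conn.execute(query, {"user_id": user_id})
            return [dict(row) for row in result.mappings().all()]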

58 changes: 58 additions & 0 deletions backend/chainlit/data/storage_clients.py
@@ -0,0 +1,58 @@
from chainlit.data import BaseStorageClient
from chainlit.logger import logger
from typing import TYPE_CHECKING, Optional, Dict, Union, Any
from azure.storage.filedatalake import DataLakeServiceClient, FileSystemClient, DataLakeFileClient, ContentSettings
import boto3 # type: ignore

if TYPE_CHECKING:
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential

class AzureStorageClient(BaseStorageClient):
"""
Class to enable Azure Data Lake Storage (ADLS) Gen2
params:
account_url: "https://<your_account>.dfs.core.windows.net"
credential: Access credential (AzureKeyCredential)
sas_token: Optionally include a SAS token to append to URLs
"""
def __init__(self, account_url: str, container: str, credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]], sas_token: Optional[str] = None):
try:
self.data_lake_client = DataLakeServiceClient(account_url=account_url, credential=credential)
self.container_client: FileSystemClient = self.data_lake_client.get_file_system_client(file_system=container)
self.sas_token = sas_token
logger.info("AzureStorageClient initialized")
except Exception as e:
logger.warn(f"AzureStorageClient initialization error: {e}")

async def upload_file(self, object_key: str, data: Union[bytes, str], mime: str = 'application/octet-stream', overwrite: bool = True) -> Dict[str, Any]:
try:
file_client: DataLakeFileClient = self.container_client.get_file_client(object_key)
content_settings = ContentSettings(content_type=mime)
file_client.upload_data(data, overwrite=overwrite, content_settings=content_settings)
url = f"{file_client.url}{self.sas_token}" if self.sas_token else file_client.url
return {"object_key": object_key, "url": url}
except Exception as e:
logger.warn(f"AzureStorageClient, upload_file error: {e}")
return {}

class S3StorageClient(BaseStorageClient):
"""
Class to enable Amazon S3 storage provider
"""
def __init__(self, bucket: str):
try:
self.bucket = bucket
self.client = boto3.client("s3")
logger.info("S3StorageClient initialized")
except Exception as e:
logger.warn(f"S3StorageClient initialization error: {e}")

async def upload_file(self, object_key: str, data: Union[bytes, str], mime: str = 'application/octet-stream', overwrite: bool = True) -> Dict[str, Any]:
try:
self.client.put_object(Bucket=self.bucket, Key=object_key, Body=data, ContentType=mime)
url = f"https://{self.bucket}.s3.amazonaws.com/{object_key}"
return {"object_key": object_key, "url": url}
except Exception as e:
logger.warn(f"S3StorageClient, upload_file error: {e}")
return {}
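Both clients share the same call shape: the data layer hands them an object key, raw bytes or text, and a MIME type, and gets back a dict containing the persisted URL, or an empty dict on failure. A hedged usage sketch; the bucket name and key scheme here are assumptions, not taken from this commit:

from chainlit.data.storage_clients import S3StorageClient

async def persist_element(thread_id: str, element_id: str, content: bytes) -> str:
    client = S3StorageClient(bucket="my-chainlit-elements")  # hypothetical bucket
    result = await client.upload_file(
        object_key=f"threads/{thread_id}/files/{element_id}",  # assumed key scheme
        data=content,
        mime="image/png",
    )
    # upload_file returns {} on failure, so guard before reading the URL
    return result.get("url", "")

# run with: asyncio.run(persist_element("thread-1", "element-1", b"<png bytes>"))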
54 changes: 53 additions & 1 deletion backend/chainlit/types.py
@@ -1,5 +1,5 @@
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypedDict, Union, Generic, TypeVar, Protocol, Any

if TYPE_CHECKING:
from chainlit.element import ElementDict
@@ -37,6 +37,58 @@ class ThreadFilter(BaseModel):
userId: Optional[str] = None
search: Optional[str] = None

@dataclass
class PageInfo:
hasNextPage: bool
startCursor: Optional[str]
endCursor: Optional[str]

def to_dict(self):
return {
"hasNextPage": self.hasNextPage,
"startCursor": self.startCursor,
"endCursor": self.endCursor,
}

@classmethod
def from_dict(cls, page_info_dict: Dict) -> "PageInfo":
hasNextPage = page_info_dict.get("hasNextPage", False)
startCursor = page_info_dict.get("startCursor", None)
endCursor = page_info_dict.get("endCursor", None)
return cls(
hasNextPage=hasNextPage, startCursor=startCursor, endCursor=endCursor
)

T = TypeVar("T", covariant=True)

class HasFromDict(Protocol[T]):
@classmethod
def from_dict(cls, obj_dict: Any) -> T:
raise NotImplementedError()

@dataclass
class PaginatedResponse(Generic[T]):
pageInfo: PageInfo
data: List[T]

def to_dict(self):
return {
"pageInfo": self.pageInfo.to_dict(),
"data": [
(d.to_dict() if hasattr(d, "to_dict") and callable(d.to_dict) else d)
for d in self.data
],
}

@classmethod
def from_dict(
cls, paginated_response_dict: Dict, the_class: HasFromDict[T]
) -> "PaginatedResponse[T]":
pageInfo = PageInfo.from_dict(paginated_response_dict.get("pageInfo", {}))

data = [the_class.from_dict(d) for d in paginated_response_dict.get("data", [])]

return cls(pageInfo=pageInfo, data=data)

@dataclass
class FileSpec(DataClassJsonMixin):
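The duplicated pagination types round-trip through plain dicts: `PaginatedResponse.from_dict` accepts any class exposing a `from_dict` classmethod (the `HasFromDict` protocol). A small sketch with a hypothetical `Item` element type:

from dataclasses import dataclass
from typing import Dict

from chainlit.types import PageInfo, PaginatedResponse

@dataclass
class Item:
    id: str

    def to_dict(self) -> Dict:
        return {"id": self.id}

    @classmethod
    def from_dict(cls, d: Dict) -> "Item":
        return cls(id=d["id"])

page = PaginatedResponse(
    pageInfo=PageInfo(hasNextPage=True, startCursor="a", endCursor="b"),
    data=[Item(id="a"), Item(id="b")],
)

# to_dict/from_dict are symmetric, so a round trip restores an equal object
assert PaginatedResponse.from_dict(page.to_dict(), Item) == page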
10 changes: 10 additions & 0 deletions backend/pyproject.toml
@@ -92,6 +92,16 @@ module = [
]
ignore_missing_imports = true

[tool.poetry.group.custom-data]
optional = true

[tool.poetry.group.custom-data.dependencies]
asyncpg = "^0.29.0"
SQLAlchemy = "^2.0.28"
boto3 = "^1.34.73"
azure-identity = "^1.14.1"
azure-storage-file-datalake = "^12.14.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
31 changes: 31 additions & 0 deletions cypress/e2e/custom_data_layer/sql_alchemy.py
@@ -0,0 +1,31 @@
from typing import List, Optional

import chainlit.data as cl_data
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.data.storage_clients import AzureStorageClient
from literalai.helper import utc_now

import chainlit as cl

storage_client = AzureStorageClient(account_url="<your_account_url>", container="<your_container>")

cl_data._data_layer = SQLAlchemyDataLayer(conninfo="<your conninfo>", storage_provider=storage_client)


@cl.on_chat_start
async def main():
await cl.Message("Hello, send me a message!", disable_feedback=True).send()


@cl.on_message
async def handle_message():
await cl.sleep(2)
await cl.Message("Ok!").send()


@cl.password_auth_callback
def auth_callback(username: str, password: str) -> Optional[cl.User]:
if (username, password) == ("admin", "admin"):
return cl.User(identifier="admin")
else:
return None
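For S3 instead of ADLS, the wiring is the same apart from the storage client; a hypothetical variant of the setup above:

import chainlit.data as cl_data
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.data.storage_clients import S3StorageClient

storage_client = S3StorageClient(bucket="<your_bucket>")

cl_data._data_layer = SQLAlchemyDataLayer(
    conninfo="<your conninfo>", storage_provider=storage_client
)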
