Skip to content
This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Require authentication for all money-costing API endpoints #89

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 11 additions & 23 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,29 +162,6 @@ typings/
# Local History for Visual Studio Code
.history/


# Provided default Pycharm Run/Debug Configurations should be tracked by git
# In case of local modifications made by Pycharm, use update-index command
# for each changed file, like this:
# git update-index --assume-unchanged .idea/chat_all_the_docs.iml
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff:
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/dictionaries

# Sensitive or high-churn files:
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.xml
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml

# Gradle:
.idea/**/gradle.xml
.idea/**/libraries
Expand Down Expand Up @@ -338,3 +315,14 @@ delphic/media/*

### Models for Question Answering
cache/*

# https://github.com/cookiecutter/cookiecutter-django/blob/de8759fdbd45ac288b97e050073a5d09f50029db/.gitignore#L211
# Even though the project might be opened and edited
# in any of the JetBrains IDEs, it makes no sense whatsoever
# to 'run' anything within it since any particular cookiecutter
# is declarative by nature.
.idea/

### Local configuration files
/.envs/.local
/frontend/.frontend
4 changes: 2 additions & 2 deletions compose/local/django/celery/worker/start
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ set -o errexit
set -o nounset


#exec watchfiles celery.__main__.main --args '-A config.celery_app worker -l INFO'
exec celery -A config.celery_app worker -l INFO
exec watchfiles --filter python celery.__main__.main --args '-A config.celery_app worker -l INFO'
#exec celery -A config.celery_app worker -l INFO
28 changes: 10 additions & 18 deletions config/api/auth/api_key.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
import logging

from asgiref.sync import sync_to_async
from ninja.security import APIKeyHeader
from rest_framework_api_key.models import APIKey
from django.http import HttpRequest
from ninja_extra.security import AsyncAPIKeyHeader
from rest_framework_api_key.models import AbstractAPIKey, APIKey

logger = logging.getLogger(__name__)


class NinjaApiKeyAuth(APIKeyHeader):
param_name = "AUTHORIZATION"
class NinjaApiKeyAuth(AsyncAPIKeyHeader):
param_name = "Authorization"

# def authenticate(self, request, key):
# print(f"API KEY authenticatE: {key}")
# try:
# api_key = APIKey.objects.get_from_key(key)
# print("Success")
# return api_key
# except Exception as e:
# logger.warning(f"INVALID KEY! - Error: {e}")
async def authenticate(self, request, key):
print(f"API KEY authenticatE: {key}")
async def authenticate(self, request: HttpRequest, key) -> "AbstractAPIKey":
try:
# Use the asynchronous ORM to get the API key
api_key = await sync_to_async(APIKey.objects.get_from_key)(key)
print(f"Success - api_key: {api_key.name}")
return api_key
return await sync_to_async(APIKey.objects.get_from_key)(key)
except APIKey.DoesNotExist:
pass
except Exception as e:
logger.warning(f"INVALID KEY! - Error: {e}")
logger.warning(f"NinjaApiKeyAuth: invalid key: {type(e)} {e}")
42 changes: 27 additions & 15 deletions config/api/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import logging

from asgiref.sync import sync_to_async
from django.conf import settings
from django.core.files.base import ContentFile
from django.db.models import Q
from django.http import HttpRequest
from ninja import File, Form, Router
from ninja.files import UploadedFile
from ninja_extra import NinjaExtraAPI
from ninja_jwt.authentication import AsyncJWTAuth
from ninja_jwt.controller import NinjaJWTDefaultController
from rest_framework_api_key.models import APIKey

from delphic.indexes.models import Collection, Document
from delphic.tasks import create_index
Expand All @@ -19,18 +24,21 @@
CollectionStatusEnum,
)

logger = logging.getLogger(__name__)

collections_router = Router()

api = NinjaExtraAPI(
title="GREMLIN Engine NLP Microservice",
description="Chat-All-The-Docs is a LLM document model orchestration engine that makes it easy to vectorize a "
"collection of documents and then build and serve discrete Llama-Index models for each collection to "
"create bespoke, highly-targed chattable knowledge bases.",
title="Delphic LLM Microservice",
description="""Delphic is a LLM document model orchestration engine that makes it
easy to vectorize a collection of documents and then build and serve discrete
Llama-Index models for each collection to create bespoke, highly-targed chattable
knowledge bases.""",
version="b0.9.0",
auth=None if settings.OPEN_ACCESS_MODE else NinjaApiKeyAuth(),
)

api.add_router("/collections", collections_router)
auth = None if settings.OPEN_ACCESS_MODE else [NinjaApiKeyAuth(), AsyncJWTAuth()]
api.add_router("/collections", collections_router, auth=auth)
api.register_controllers(NinjaJWTDefaultController)


Expand All @@ -52,9 +60,11 @@ async def create_collection(
description: str = Form(...),
files: list[UploadedFile] = File(...),
):
key = None if getattr(request, "auth", None) is None else request.auth
if key is not None:
key = await key
key = None

if api_key := getattr(request, "auth", None):
if isinstance(key, APIKey):
key = api_key

collection_instance = Collection(
api_key=key,
Expand Down Expand Up @@ -107,15 +117,17 @@ def query_collection_view(request: HttpRequest, query_input: CollectionQueryInpu
@collections_router.get(
"/available",
response=list[CollectionModelSchema],
summary="Get a list of all of the collections " "created with my api_key",
summary="Get a list of all of the collections created with my api_key",
)
async def get_my_collections_view(request: HttpRequest):
key = None if getattr(request, "auth", None) is None else request.auth
if key is not None:
key = await key
print(f"API KEY: {key}")
collections_filt = Q(api_key=None)

if key := getattr(request, "auth", None):
if isinstance(key, APIKey):
logger.debug(f"API key: {key.prefix}")
collections_filt |= Q(api_key=key)

collections = Collection.objects.filter(api_key=key)
collections = Collection.objects.filter(collections_filt)

return [
{
Expand Down
4 changes: 3 additions & 1 deletion config/api/websockets/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ async def receive(self, text_data):

{query_str}
"""
response = self.index.query(modified_query_str)

query_engine = self.index.as_query_engine()
response = query_engine.query(modified_query_str)

# Format the response as markdown
markdown_response = f"## Response\n\n{response}\n\n"
Expand Down
26 changes: 16 additions & 10 deletions delphic/tasks/index_tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import tempfile
Expand All @@ -8,7 +9,7 @@
from django.core.files import File
from langchain import OpenAI
from llama_index import (
GPTSimpleVectorIndex,
GPTVectorStoreIndex,
LLMPredictor,
ServiceContext,
download_loader,
Expand All @@ -23,11 +24,11 @@
@celery_app.task
def create_index(collection_id):
"""
Celery task to create a GPTSimpleVectorIndex for a given Collection object.
Celery task to create a GPTVectorStoreIndex for a given Collection object.

This task takes the ID of a Collection object, retrieves it from the
database along with its related documents, and saves the document files
to a temporary directory. Then, it creates a GPTSimpleVectorIndex using
to a temporary directory. Then, it creates a GPTVectorStoreIndex using
the provided code and saves the index to the Comparison.model FileField.

Args:
Expand Down Expand Up @@ -60,15 +61,18 @@ def create_index(collection_id):
with temp_file_path.open("wb") as f:
f.write(file_data)

# Create the GPTSimpleVectorIndex
SimpleDirectoryReader = download_loader("SimpleDirectoryReader")
# Create the GPTVectorStoreIndex
try:
SimpleDirectoryReader = download_loader("SimpleDirectoryReader")
except Exception as e:
logger.error(f"Error downloading SimpleDirectoryReader: {e}")
raise

loader = SimpleDirectoryReader(
tempdir_path, recursive=True, exclude_hidden=False
)
documents = loader.load_data()
# index = GPTSimpleVectorIndex(documents)

# documents = SimpleDirectoryReader(str(tempdir_path)).load_data()
llm_predictor = LLMPredictor(
llm=OpenAI(
temperature=0,
Expand All @@ -81,11 +85,11 @@ def create_index(collection_id):
)

# build index
index = GPTSimpleVectorIndex.from_documents(
index = GPTVectorStoreIndex.from_documents(
documents, service_context=service_context
)

index_str = index.save_to_string()
index_str = json.dumps(index.storage_context.to_dict())

# Save the index_str to the Comparison.model FileField
with tempfile.NamedTemporaryFile(delete=False) as f:
Expand All @@ -105,7 +109,9 @@ def create_index(collection_id):
return True

except Exception as e:
logger.error(f"Error creating index for collection {collection_id}: {e}")
logger.error(
f"{type(e).__name__} creating index for collection {collection_id}: {e}"
)
collection.status = CollectionStatus.ERROR
collection.save()

Expand Down
32 changes: 12 additions & 20 deletions delphic/utils/collections.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import logging
import textwrap
from pathlib import Path

from django.conf import settings
from langchain import OpenAI
from llama_index import GPTSimpleVectorIndex, LLMPredictor, ServiceContext
from llama_index import StorageContext, load_index_from_storage
from llama_index.indices.base import BaseIndex

from delphic.indexes.models import Collection

Expand All @@ -27,22 +28,22 @@ def format_source(source):
return formatted_source


async def load_collection_model(collection_id: str | int) -> GPTSimpleVectorIndex:
async def load_collection_model(collection_id: str | int) -> "BaseIndex":
"""
Load the Collection model from cache or the database, and return the index.

Args:
collection_id (Union[str, int]): The ID of the Collection model instance.

Returns:
GPTSimpleVectorIndex: The loaded index.
VectorStoreIndex: The loaded index.

This function performs the following steps:
1. Retrieve the Collection object with the given collection_id.
2. Check if a JSON file with the name '/cache/model_{collection_id}.json' exists.
3. If the JSON file doesn't exist, load the JSON from the Collection.model FileField and save it to
3. If the JSON file doesn't exist, load the JSON from the `Collection.model` FileField and save it to
'/cache/model_{collection_id}.json'.
4. Call GPTSimpleVectorIndex.load_from_disk with the cache_file_path.
4. Call VectorStoreIndex.load_from_disk with the cache_file_path.
"""
# Retrieve the Collection object
collection = await Collection.objects.aget(id=collection_id)
Expand All @@ -61,21 +62,12 @@ async def load_collection_model(collection_id: str | int) -> GPTSimpleVectorInde
with cache_file_path.open("w+", encoding="utf-8") as cache_file:
cache_file.write(model_file.read().decode("utf-8"))

# define LLM
logger.info(
f"load_collection_model() - Setup service context with tokens {settings.MAX_TOKENS} and "
f"model {settings.MODEL_NAME}"
)
llm_predictor = LLMPredictor(
llm=OpenAI(temperature=0, model_name="text-davinci-003", max_tokens=512)
)
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)

# Call GPTSimpleVectorIndex.load_from_disk
# Call VectorStoreIndex.load_from_disk
logger.info("load_collection_model() - Load llama index")
index = GPTSimpleVectorIndex.load_from_disk(
cache_file_path, service_context=service_context
)
with cache_file_path.open("r") as cache_file:
storage_context = StorageContext.from_dict(json.load(cache_file))
index = load_index_from_storage(storage_context)

logger.info(
"load_collection_model() - Llamaindex loaded and ready for query..."
)
Expand Down
8 changes: 4 additions & 4 deletions frontend/src/api/collections.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export const createCollection = async (

const headers: Record<string, string> = {
"Content-Type": "multipart/form-data",
Authorization: authToken,
"Authorization": `Bearer ${authToken}`,
};

return axios.post<CollectionModelSchema>(
Expand All @@ -40,7 +40,7 @@ export const queryCollection = async (
authToken: string
): Promise<AxiosResponse<CollectionQueryOutput>> => {
const headers: Record<string, string> = {
Authorization: authToken,
"Authorization": `Bearer ${authToken}`,
};

return axios.post<CollectionQueryOutput>(
Expand All @@ -57,7 +57,7 @@ export const getMyCollections = async (
REACT_APP_API_ROOT_URL + "/api/collections/available",
{
headers: {
Authorization: authToken,
"Authorization": `Bearer ${authToken}`,
},
}
);
Expand Down Expand Up @@ -85,7 +85,7 @@ export const addFileToCollection = async (
{
headers: {
"Content-Type": "multipart/form-data",
Authorization: authToken,
"Authorization": `Bearer ${authToken}`,
},
}
);
Expand Down
6 changes: 3 additions & 3 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ django-redis==5.2.0 # https://github.com/jazzband/django-redis

# API-Related
# ------------------------------------------------------------------------------
django-ninja==0.21.0 # https://github.com/vitalik/django-ninja
django-ninja==0.22.2 # https://github.com/vitalik/django-ninja
djangorestframework==3.14.0 # https://github.com/encode/django-rest-framework
djangorestframework-api-key==2.* # https://github.com/florimondmanca/djangorestframework-api-key
django-ninja-jwt==5.2.5 # https://github.com/eadwinCode/django-ninja-jwt
django-ninja-extra
django-ninja-extra==0.19.1
django-cors-headers==3.14.0
# Websockets
# ------------------------------------------------------------------------------
Expand All @@ -32,6 +32,6 @@ channels_redis

# NLP-Related
# ------------------------------------------------------------------------------
llama_index==0.5.25 # https://github.com/jerryjliu/llama_index
llama_index==0.6.38.post1 # https://github.com/jerryjliu/llama_index
PyPDF2==3.* # https://pypdf2.readthedocs.io/en/latest/
docx2txt==0.8