Feat: Extend available LLMs #44

Draft · wants to merge 6 commits into main
Changes from 5 commits
2 changes: 2 additions & 0 deletions backend/app/app/core/config.py
@@ -19,6 +19,8 @@ class Settings(BaseSettings):
OPENAI_API_KEY: str
OPENAI_ORGANIZATION: Optional[str] = None
OPENAI_API_BASE: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
DATABASE_USER: str
DATABASE_PASSWORD: str
DATABASE_HOST: str
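Both new keys are Optional and default to None, so existing OpenAI-only deployments need no new environment variables. A minimal sketch of that behaviour, using a hypothetical stand-in class and assuming the project's Settings reads values from the environment via pydantic BaseSettings:

from typing import Optional
from pydantic import BaseSettings  # sketch; the project may import BaseSettings from pydantic_settings instead

class ProviderKeys(BaseSettings):
    # Hypothetical stand-in for the two new fields on app.core.config.Settings
    ANTHROPIC_API_KEY: Optional[str] = None
    GOOGLE_API_KEY: Optional[str] = None

keys = ProviderKeys()
print(keys.ANTHROPIC_API_KEY)  # None unless ANTHROPIC_API_KEY is set in the environment
print(keys.GOOGLE_API_KEY)     # None unless GOOGLE_API_KEY is set in the environment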
4 changes: 4 additions & 0 deletions backend/app/app/schemas/tool_schema.py
@@ -10,6 +10,10 @@
"gpt-3.5-turbo",
"azure-4-32k",
"azure-3.5",
"claude-3-opus",
"claude-3-sonnet",
"claude-3-haiku",
"gemini-1.0-pro",
]


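Since LLMType is a typing.Literal, the four new names become valid values anywhere the type is used. A small hypothetical illustration (not part of the diff):

from app.schemas.tool_schema import LLMType

def describe_model(llm: LLMType) -> str:
    # Static type checkers now accept the new literals as valid LLMType values.
    return f"configured model: {llm}"

print(describe_model("claude-3-sonnet"))  # OK
print(describe_model("gemini-1.0-pro"))   # OK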
42 changes: 42 additions & 0 deletions backend/app/app/services/chat_agent/helpers/llm.py
@@ -8,6 +8,12 @@
import tiktoken
from langchain.base_language import BaseLanguageModel
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import (
ChatGoogleGenerativeAI,
HarmBlockThreshold,
HarmCategory,
)

from app.core.config import settings
from app.schemas.tool_schema import LLMType
@@ -58,6 +64,41 @@ def get_llm(
openai_api_key=api_key if api_key is not None else settings.OPENAI_API_KEY,
streaming=True,
)
case "claude-3-opus":
return ChatAnthropic(
temperature=0,
model_name="claude-3-opus-20240229",
anthropic_api_key=settings.ANTHROPIC_API_KEY,
streaming=True,
)
case "claude-3-sonnet":
return ChatAnthropic(
temperature=0,
model_name="claude-3-sonnet-20240229",
anthropic_api_key=settings.ANTHROPIC_API_KEY,
streaming=True,
)
case "claude-3-haiku":
return ChatAnthropic(
temperature=0,
model_name="claude-3-haiku-20240307",
anthropic_api_key=settings.ANTHROPIC_API_KEY,
streaming=True,
)
case "gemini-1.0-pro":
return ChatGoogleGenerativeAI(
temperature=0,
model="gemini-1.0-pro-latest",
google_api_key=settings.GOOGLE_API_KEY,
streaming=True,
convert_system_message_to_human=True,
safety_settings={
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
},
)
# If an exact match is not confirmed, this last case will be used if provided
case _:
logger.warning(f"LLM {llm} not found, using default LLM")
@@ -68,3 +109,4 @@ def get_llm(
openai_api_key=settings.OPENAI_API_KEY,
streaming=True,
)
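A hypothetical smoke test for the new branches (not part of the diff; assumes ANTHROPIC_API_KEY and GOOGLE_API_KEY are set and the app package is importable):

from app.services.chat_agent.helpers.llm import get_llm

llm = get_llm("claude-3-haiku")     # returns a streaming ChatAnthropic instance
reply = llm.invoke("Reply with a single word.")
print(reply.content)

gemini = get_llm("gemini-1.0-pro")  # returns ChatGoogleGenerativeAI with the relaxed safety settings above
print(gemini.invoke("Reply with a single word.").content)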

4 changes: 3 additions & 1 deletion backend/app/app/services/chat_agent/meta_agent.py
@@ -34,8 +34,10 @@ def get_conv_token_buffer_memory(
ConversationTokenBufferMemory: The ConversationTokenBufferMemory object.
"""
agent_config = get_agent_config()

# We use gpt-4 in ConversationTokenBufferMemory to standardize tokenization
llm = get_llm(
agent_config.common.llm,
'gpt-4',
Comment from the PR author:

@kaikun213 Forcing it to use gpt-4's tokenizer here through tiktoken, because otherwise langchain tries to import Claude's tokenizer from transformers (and I don't think we want to add that to the dependencies). Does that work?

api_key=api_key,
)
chat_history = ChatMessageHistory()
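For illustration of why this works (not part of the diff): ConversationTokenBufferMemory trims history by asking the passed llm for token counts, so pinning a tiktoken-backed gpt-4 model keeps tokenization local even when the agent itself runs on Claude or Gemini. A minimal sketch, assuming langchain's standard memory API:

from langchain.memory import ConversationTokenBufferMemory

from app.services.chat_agent.helpers.llm import get_llm

# The memory calls llm.get_num_tokens_from_messages(...) when trimming,
# which for ChatOpenAI goes through tiktoken rather than transformers.
memory = ConversationTokenBufferMemory(llm=get_llm("gpt-4"), max_token_limit=2000)
memory.save_context({"input": "hi"}, {"output": "hello"})
print(memory.load_memory_variables({}))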