Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Claude 3 API support #3340

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions configs/model_config.py.example
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ ONLINE_LLM_MODEL = {
"api_key": "",
"provider": "GeminiWorker",
    },
# Claude API : https://www.anthropic.com/api
# Available models:
# Claude 3 Opus: claude-3-opus-20240229
    # Claude 3 Sonnet: claude-3-sonnet-20240229
    # Claude 3 Haiku: claude-3-haiku-20240307
    "claude-api": {
        "api_key": "",
        "version": "2023-06-01",
        "model_name": "claude-3-opus-20240229",
        "provider": "ClaudeWorker",
    },

}

Expand Down
3 changes: 3 additions & 0 deletions configs/server_config.py.example
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ FSCHAT_MODEL_WORKERS = {
"gemini-api": {
"port": 21010,
},
"claude-api": {
"port": 21011,
},
}

FSCHAT_CONTROLLER = {
Expand Down
3 changes: 2 additions & 1 deletion server/model_workers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from .baichuan import BaiChuanWorker
from .azure import AzureWorker
from .tiangong import TianGongWorker
from .gemini import GeminiWorker
from .gemini import GeminiWorker
from .claude import ClaudeWorker
130 changes: 130 additions & 0 deletions server/model_workers/claude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose
import uvicorn


class ClaudeWorker(ApiModelWorker):
    """Fastchat API worker that proxies chat requests to the Anthropic
    Claude Messages API (https://docs.anthropic.com/en/api/messages).
    """

    def __init__(
        self,
        *,
        controller_addr: str = None,
        worker_addr: str = None,
        model_names: List[str] = ["claude-api"],
        version: str = "2023-06-01",
        **kwargs,
    ):
        kwargs.update(model_names=model_names,
                      controller_addr=controller_addr,
                      worker_addr=worker_addr)
        kwargs.setdefault("context_len", 1024)
        super().__init__(**kwargs)
        # Anthropic API version string, sent as the `anthropic-version` header.
        self.version = version

    def create_claude_messages(self, params: ApiChatParams) -> Dict:
        """Build the JSON request body for the Messages API.

        System-role messages are moved into the top-level ``system`` field,
        because the Messages API rejects ``role: "system"`` inside
        ``messages``; user/assistant turns are passed through unchanged
        (Claude uses the role name "assistant", not "model").
        """
        data = {
            "model": params.model_name,
            "max_tokens": params.context_len,
            # We consume the response incrementally, so request SSE streaming.
            "stream": True,
            "messages": [],
        }
        system_parts = []
        for msg in params.messages:
            role, content = msg["role"], msg["content"]
            if role == "system":
                system_parts.append(content)
            else:
                data["messages"].append({"role": role, "content": content})
        if system_parts:
            data["system"] = "\n".join(system_parts)
        return data

    def do_chat(self, params: ApiChatParams) -> Dict:
        """Stream a completion from Claude, yielding cumulative text.

        Yields dicts of the form ``{"error_code": 0, "text": <text so far>}``
        after each content delta, or an error dict on HTTP/API failure.
        """
        data = self.create_claude_messages(params)
        url = "https://api.anthropic.com/v1/messages"
        headers = {
            # Use the configured version instead of a hard-coded constant.
            'anthropic-version': self.version,
            'anthropic-beta': 'messages-2023-12-15',
            'Content-Type': 'application/json',
            'x-api-key': params.api_key,
        }
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')
            logger.info(f'{self.__class__.__name__}:data: {data}')

        text = ""
        timeout = httpx.Timeout(60.0)
        # Create the client once (the old code created it twice, discarding
        # the timeout) and close it when done.
        with get_httpx_client(timeout=timeout) as client:
            with client.stream("POST", url, headers=headers, json=data) as response:
                if response.status_code != 200:
                    # Must read the body before accessing .text on a stream.
                    response.read()
                    logger.error(f"Failed to get response: {response.text}")
                    yield {
                        "error_code": response.status_code,
                        "text": "Failed to communicate with Claude API.",
                    }
                    return
                for line in response.iter_lines():
                    line = line.strip()
                    # SSE frames arrive as "event: <type>" / "data: {json}";
                    # only the data lines carry a JSON payload.
                    if not line or not line.startswith("data:"):
                        continue
                    payload = line[len("data:"):].strip()
                    try:
                        event = json.loads(payload)
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to decode JSON: {e}; payload: {payload!r}")
                        continue
                    event_type = event.get("type")
                    if event_type == "content_block_delta":
                        text += event.get("delta", {}).get("text", "")
                        yield {"error_code": 0, "text": text}
                    elif event_type == "error":
                        err = event.get("error", {})
                        logger.error(f"Claude API error event: {err}")
                        yield {
                            "error_code": 500,
                            "text": err.get("message", "Claude API error."),
                        }
                        return
                    # Other event types (message_start, ping, message_stop,
                    # content_block_start/stop, ...) are bookkeeping only and
                    # must NOT be treated as failures as the old code did.

    def get_embeddings(self, params):
        # Embeddings are not provided by the Claude Messages API; stub only.
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: List[Dict[str, str]] = None, model_path: str = None) -> Conversation:
        """Return a default fastchat Conversation template for Claude."""
        if conv_template is None:
            conv_template = [
                {"role": "user", "content": "Hello there."},
                {"role": "assistant", "content": "Hi, I'm Claude. How can I help you?"},
                {"role": "user", "content": "Can you explain LLMs in plain English?"}
            ]
        return Conversation(
            name=self.model_names[0],
            system_message="You are Claude, a helpful, respectful, and honest assistant.",
            messages=conv_template,
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":

    from server.utils import MakeFastAPIOffline
    from fastchat.serve.base_model_worker import app

    # Standalone entry point: register a ClaudeWorker with a locally running
    # fastchat controller and serve it on port 21011 (matching the
    # "claude-api" entry in server_config.py.example).
    worker = ClaudeWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21011",
    )
    # fastchat's base worker app resolves the active worker through this
    # module attribute, so install ours before starting the server.
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21011)