Adding Claude 3 API support (#3340)

---------

Co-authored-by: Jiang, Fengyi <[email protected]>
FJiangArthur and Jiang, Fengyi committed Apr 16, 2024
1 parent 3902f6d commit 448e795
Showing 4 changed files with 146 additions and 1 deletion.
11 changes: 11 additions & 0 deletions configs/model_config.py.example
@@ -125,6 +125,17 @@ ONLINE_LLM_MODEL = {
"api_key": "",
"provider": "GeminiWorker",
}
# Claude API : https://www.anthropic.com/api
# Available models:
# Claude 3 Opus: claude-3-opus-20240229
# Claude 3 Sonnet claude-3-sonnet-20240229
# Claude 3 Haiku claude-3-haiku-20240307
"claude-api": {
"api_key": "",
"version": "2023-06-01",
"model_name":"claude-3-opus-20240229",
"provider": "ClaudeWorker",
}

}

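The values entered in the "claude-api" block above can be sanity-checked against the Messages API directly, before any worker is started. A minimal standalone sketch, not part of this commit, assuming the httpx package is installed and the placeholder key below is replaced with a real one:

import httpx

api_key = "sk-ant-..."  # placeholder: use the same value as "api_key" above
resp = httpx.post(
    "https://api.anthropic.com/v1/messages",
    headers={
        "x-api-key": api_key,
        "anthropic-version": "2023-06-01",   # same as "version" above
        "content-type": "application/json",
    },
    json={
        "model": "claude-3-opus-20240229",   # same as "model_name" above
        "max_tokens": 64,
        "messages": [{"role": "user", "content": "Hello, Claude"}],
    },
    timeout=60.0,
)
resp.raise_for_status()
print(resp.json()["content"][0]["text"])  # first text block of the reply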
3 changes: 3 additions & 0 deletions configs/server_config.py.example
@@ -127,6 +127,9 @@ FSCHAT_MODEL_WORKERS = {
"gemini-api": {
"port": 21010,
},
"claude-api": {
"port": 21011,
},
}

FSCHAT_CONTROLLER = {
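The port added here has to agree with the address the Claude worker is actually served on; a hypothetical illustration only (the real wiring is handled by the project's startup flow, and a standalone run is shown at the bottom of server/model_workers/claude.py below):

# Hypothetical values for illustration.
CLAUDE_WORKER_PORT = 21011  # must equal FSCHAT_MODEL_WORKERS["claude-api"]["port"]
worker_addr = f"http://127.0.0.1:{CLAUDE_WORKER_PORT}"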
3 changes: 2 additions & 1 deletion server/model_workers/__init__.py
@@ -8,4 +8,5 @@
from .baichuan import BaiChuanWorker
from .azure import AzureWorker
from .tiangong import TianGongWorker
from .gemini import GeminiWorker
from .claude import ClaudeWorker
130 changes: 130 additions & 0 deletions server/model_workers/claude.py
@@ -0,0 +1,130 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose
import uvicorn


class ClaudeWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            controller_addr: str = None,
            worker_addr: str = None,
            model_names: List[str] = ["claude-api"],
            version: str = "2023-06-01",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 1024)
        super().__init__(**kwargs)
        self.version = version

    def create_claude_messages(self, params: ApiChatParams) -> Dict:
        claude_msg = {
            "model": params.model_name,
            "max_tokens": params.context_len,
            "messages": [],
        }

        for msg in params.messages:
            role = msg["role"]
            content = msg["content"]
            # The Claude Messages API only accepts "user" and "assistant" roles,
            # so system messages are skipped rather than forwarded.
            if role == "system":
                continue
            claude_msg["messages"].append({"role": role, "content": content})

        return claude_msg

    def do_chat(self, params: ApiChatParams) -> Dict:
        data = self.create_claude_messages(params)
        data["stream"] = True  # request a server-sent-event (SSE) stream
        url = "https://api.anthropic.com/v1/messages"
        headers = {
            'anthropic-version': self.version,
            'anthropic-beta': 'messages-2023-12-15',
            'Content-Type': 'application/json',
            'x-api-key': params.api_key,
        }
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')
            logger.info(f'{self.__class__.__name__}:data: {data}')

        text = ""
        timeout = httpx.Timeout(60.0)
        client = get_httpx_client(timeout=timeout)
        with client.stream("POST", url, headers=headers, json=data) as response:
            if response.status_code != 200:
                response.read()
                logger.error(f"Failed to get response: {response.text}")
                yield {
                    "error_code": response.status_code,
                    "text": "Failed to communicate with Claude API.",
                }
                return

            # SSE payload lines look like 'data: {...}'; 'event:' lines and blanks are skipped.
            # (Sample data lines are shown after this file's diff.)
            for line in response.iter_lines():
                line = line.strip()
                if not line or not line.startswith("data:"):
                    continue
                json_string = line[len("data:"):].strip()
                try:
                    event_data = json.loads(json_string)
                except json.JSONDecodeError as e:
                    logger.error(f"Failed to decode JSON: {e}; invalid JSON string: {json_string}")
                    continue

                event_type = event_data.get("type")
                if event_type == "content_block_delta":
                    # Stream the accumulated text back as each delta arrives.
                    text += event_data.get("delta", {}).get("text", "")
                    yield {"error_code": 0, "text": text}
                elif event_type == "message_stop":
                    # Message is complete.
                    break

    def get_embeddings(self, params):
        # The Claude API does not currently offer an embeddings endpoint; left as a stub.
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # Follows the same pattern as the repo's other API workers: empty history, user/assistant roles.
        return Conversation(
            name=self.model_names[0],
            system_message="You are Claude, a helpful, respectful, and honest assistant.",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.base_model_worker import app

    worker = ClaudeWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21011",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21011)
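For reference, the stream parsed by do_chat above is Anthropic's server-sent-event format. The snippet below is illustrative only (not part of this commit): it feeds abbreviated sample events, based on the publicly documented event shapes, through the same parsing logic.

import json

sample_stream = [
    'event: message_start',
    'data: {"type": "message_start", "message": {"role": "assistant", "content": []}}',
    'event: content_block_delta',
    'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}',
    'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " there"}}',
    'event: message_stop',
    'data: {"type": "message_stop"}',
]

text = ""
for line in sample_stream:
    if not line.startswith("data:"):
        continue  # only 'data:' lines carry JSON payloads
    event = json.loads(line[len("data:"):])
    if event.get("type") == "content_block_delta":
        text += event.get("delta", {}).get("text", "")

print(text)  # -> Hello there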
