Skip to content

Commit

Permalink
Fix Json API Request for Claude
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiang, Fengyi committed Mar 14, 2024
1 parent 373864f commit bfadfb9
Showing 1 changed file with 39 additions and 23 deletions.
62 changes: 39 additions & 23 deletions server/model_workers/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose
import uvicorn


class ClaudeWorker(ApiModelWorker):
Expand All @@ -13,12 +14,12 @@ def __init__(
*,
controller_addr: str = None,
worker_addr: str = None,
model_name: str = ["claude-api"],
model_names: List[str] = ["claude-api"],
version: str = "2023-06-01",

**kwargs,
):
kwargs.update(model_name=model_name, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 1024)
super().__init__(**kwargs)
self.version = version
Expand Down Expand Up @@ -47,9 +48,10 @@ def do_chat(self, params: ApiChatParams) -> Dict:
data = self.create_claude_messages(params)
url = "https://api.anthropic.com/v1/messages"
headers = {
'anthropic-version': '2023-06-01',
'anthropic-beta': 'messages-2023-12-15',
'Content-Type': 'application/json',
'x-api-key': params.api_key,
'anthropic-version': '2023-06-01'
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:url: {url}')
Expand All @@ -60,38 +62,52 @@ def do_chat(self, params: ApiChatParams) -> Dict:
json_string = ""
timeout = httpx.Timeout(60.0)
client = get_httpx_client(timeout=timeout)
with client.post(url, headers=headers, json=data) as response:
if response.status_code == 200:
resp = response.json()
if 'messages' in resp:
for message in resp['messages']:
if 'content' in message:
text += message['content']
yield {
"error_code": 0,
"text": text
}
else:
logger.error(f"Failed to get response: {response.text}")
yield {
"error_code": response.status_code,
"text": "Failed to communicate with Claude API."
}
client = get_httpx_client()
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
line = line.strip()
if not line:
continue
json_string += line

try:
event_data = json.loads(line)
event_type = event_data.get("type")
if event_type == "content_block_delta":
delta_text = event_data.get("delta", {}).get("text", "")
text += delta_text
elif event_type == "message_stop":
# Message is complete, yield the result
yield {
"error_code": 0,
"text": text
}
text = ""
else:
logger.error(f"Failed to get response: {response.text}")
yield {
"error_code": response.status_code,
"text": "Failed to communicate with Claude API."
}

except json.JSONDecodeError as e:
print("Failed to decode JSON:", e)
print("Invalid JSON string:", json_string)

def get_embeddings(self, params):
# Stub: embedding retrieval is not implemented; implement it here if necessary.
# For now this only prints a marker string and the received params (debug
# output) and has no return statement, so callers get None.
print("embedding")
print(params)

def make_conv_template(self, conv_template: List[Dict[str, str]] = None) -> Conversation:
def make_conv_template(self, conv_template: List[Dict[str, str]] = None, model_path: str = None) -> Conversation:
if conv_template is None:
conv_template = [
{"role": "user", "content": "Hello there."},
{"role": "assistant", "content": "Hi, I'm Claude. How can I help you?"},
{"role": "user", "content": "Can you explain LLMs in plain English?"}
]
return Conversation(
name=self.model_name,
name=self.model_names[0],
system_message="You are Claude, a helpful, respectful, and honest assistant.",
messages=conv_template,
roles=["user", "assistant"],
Expand All @@ -101,7 +117,7 @@ def make_conv_template(self, conv_template: List[Dict[str, str]] = None) -> Conv


if __name__ == "__main__":
import uvicorn

from server.utils import MakeFastAPIOffline
from fastchat.serve.base_model_worker import app

Expand Down

0 comments on commit bfadfb9

Please sign in to comment.