feat: Integrate Azure OpenAI API. (new) #499
base: master
Review comment: please move the config code to the camel/configs/ folder

@@ -0,0 +1,162 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import Any, Dict, List, Optional, Union

from openai import AzureOpenAI, Stream

from camel.configs import (
    OPENAI_API_PARAMS_WITH_FUNCTIONS,
    AZURE_OPENAI_API_BACKEND_PARAMS,
)
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
from camel.utils import (
    BaseTokenCounter,
    OpenAITokenCounter,
    azure_openai_api_key_required,
)

class AzureOpenAIModel(BaseModelBackend):
    r"""Azure OpenAI API in a unified BaseModelBackend interface."""

    def __init__(
        self,
        model_type: ModelType,
        model_config_dict: Dict[str, Any],
        backend_config_dict: Dict[str, Any] = {},
    ) -> None:
        r"""Constructor for Azure OpenAI backend.

        Args:
            model_type (ModelType): Model for which a backend is created at
                Azure, one of the GPT_* series.
            model_config_dict (Dict[str, Any]): A dictionary that will
                be fed into openai.ChatCompletion.create().
            backend_config_dict (Dict[str, Any]): A dictionary that contains
                backend configs like model_type, deployment_name,
                endpoint, api_version, etc. (default: {})
        """
        super().__init__(model_type, model_config_dict)

        self.backend_config_dict = backend_config_dict

        self.model_type = backend_config_dict.get(
            "model_type", os.environ.get("AZURE_MODEL_TYPE", None))

Review comments on lines +57 to +58:
- "again, model type is not necessary for azure"
- "Explained my concerns above."
- "how about let user set"
||
self.deployment_name = backend_config_dict.get( | ||
"deployment_name", os.environ.get("AZURE_DEPLOYMENT_NAME", None)) | ||
self.azure_endpoint = backend_config_dict.get( | ||
"azure_endpoint", os.environ.get("AZURE_ENDPOINT", None)) | ||
self.api_version = backend_config_dict.get( | ||
"api_version", | ||
os.environ.get("AZURE_API_VERSION", "2023-10-01-preview"), | ||
) | ||
|
||
try: | ||
assert self.model_type is not None | ||
except AssertionError: | ||
raise ValueError("Azure model type is not provided.") | ||
try: | ||
assert self.deployment_name is not None | ||
except AssertionError: | ||
raise ValueError("Azure model deployment name is not provided.") | ||
try: | ||
assert self.api_version is not None | ||
except AssertionError: | ||
raise ValueError("Azure API version is not provided.") | ||
|
||
if isinstance(self.model_type, str): | ||
self.model_type = ModelType[self.model_type.upper()] | ||
|
||
self._client = AzureOpenAI( | ||
timeout=60, | ||
max_retries=3, | ||
api_version=self.api_version, | ||
azure_endpoint=self.azure_endpoint, | ||
) | ||
self._token_counter: Optional[BaseTokenCounter] = None | ||
|
||
    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            self._token_counter = OpenAITokenCounter(self.model_type)
        return self._token_counter

    @azure_openai_api_key_required
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of Azure OpenAI chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        response = self._client.chat.completions.create(
            messages=messages,
            model=self.deployment_name,
            **self.model_config_dict,
        )
        return response

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to Azure OpenAI API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to Azure OpenAI API.
        """
        for param in self.model_config_dict:
            if param not in OPENAI_API_PARAMS_WITH_FUNCTIONS:
                raise ValueError(f"Unexpected argument `{param}` is "
                                 "input into OpenAI model backend.")

    def check_backend_config(self):
        r"""Check whether the backend configuration contains any
        unexpected arguments to Azure OpenAI API.

        Raises:
            ValueError: If the backend configuration dictionary contains any
                unexpected arguments to Azure OpenAI API.
        """
        for param in self.backend_config_dict:
            if param not in AZURE_OPENAI_API_BACKEND_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` for "
                    "Azure OpenAI API backend."
                )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode,
        which sends partial results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get("stream", False)
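The constructor in the diff above resolves each backend setting from `backend_config_dict` first, then falls back to an environment variable. A minimal standalone sketch of that precedence logic, with a hypothetical `resolve_setting` helper (no Azure client is involved):

```python
import os


def resolve_setting(config, key, env_var, default=None):
    # Explicit config value wins; otherwise fall back to the environment
    # variable, then to an optional default -- the precedence the diff uses.
    return config.get(key, os.environ.get(env_var, default))


os.environ["AZURE_DEPLOYMENT_NAME"] = "my-gpt-deployment"
os.environ.pop("AZURE_ENDPOINT", None)
config = {"api_version": "2023-10-01-preview"}

print(resolve_setting(config, "api_version", "AZURE_API_VERSION"))
# from the config dict
print(resolve_setting(config, "deployment_name", "AZURE_DEPLOYMENT_NAME"))
# from the environment
print(resolve_setting(config, "azure_endpoint", "AZURE_ENDPOINT"))
# neither is set, so the default (None)
```

This mirrors why the reviewers argue over which settings must be supplied: anything without a sensible default ends up `None` and has to be validated afterwards.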
@@ -71,6 +71,31 @@ def wrapper(self, *args, **kwargs):
    return cast(F, wrapper)

def azure_openai_api_key_required(func: F) -> F:
    r"""Decorator that checks if the Azure OpenAI API key is available in the
    environment variables.

    Args:
        func (callable): The function to be wrapped.

    Returns:
        callable: The decorated function.

    Raises:
        ValueError: If the Azure OpenAI API key is not found in the
            environment variables.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if 'AZURE_OPENAI_API_KEY' in os.environ:
            return func(self, *args, **kwargs)
        else:
            raise ValueError('Azure OpenAI API key not found.')

    return cast(F, wrapper)


Review comments on lines +74 to +98:
- "For Azure OpenAI, I think not only the API key is required; the endpoint and API version also need to be checked."
- "I don't think we need to check these variables here since they will be checked when initializing the AzureOpenAIModel instance. When running the commands, we should only check AZURE_OPENAI_API_KEY."
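The pattern under review — gating a method on the presence of an environment variable — can be reproduced standalone. The sketch below uses a hypothetical generalized `env_var_required` factory, not the camel implementation, to show how the same check could cover any required variable:

```python
import os
from functools import wraps


def env_var_required(var_name):
    """Hypothetical decorator factory: fail fast if `var_name` is unset."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if var_name in os.environ:
                return func(*args, **kwargs)
            raise ValueError(f"{var_name} not found.")
        return wrapper
    return decorator


@env_var_required("AZURE_OPENAI_API_KEY")
def run_inference():
    # Stand-in for a real model call; only runs when the key is present.
    return "ok"


os.environ["AZURE_OPENAI_API_KEY"] = "dummy-key-for-demo"
print(run_inference())  # -> ok
```

A factory like this would also let the endpoint and API version be checked with the same decorator, which is one way to reconcile the two review positions above.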

def print_text_animated(text, delay: float = 0.02, end: str = ""):
    r"""Prints the given text with an animated effect.
@@ -0,0 +1,75 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from colorama import Fore

from camel.societies import RolePlaying
from camel.utils import print_text_animated

def main(model_type=None, chat_turn_limit=50) -> None:
    task_prompt = "Develop a trading bot for the stock market"
    role_play_session = RolePlaying(
        assistant_role_name="Python Programmer",
        assistant_agent_kwargs=dict(model_type=model_type),
        user_role_name="Stock Trader",
        user_agent_kwargs=dict(model_type=model_type),
        task_prompt=task_prompt,
        with_task_specify=True,
        task_specify_agent_kwargs=dict(model_type=model_type),
    )

    print(
        Fore.GREEN +
        f"AI Assistant sys message:\n{role_play_session.assistant_sys_msg}\n")
    print(Fore.BLUE +
          f"AI User sys message:\n{role_play_session.user_sys_msg}\n")

    print(Fore.YELLOW + f"Original task prompt:\n{task_prompt}\n")
    print(
        Fore.CYAN +
        f"Specified task prompt:\n{role_play_session.specified_task_prompt}\n")
    print(Fore.RED + f"Final task prompt:\n{role_play_session.task_prompt}\n")

    n = 0
    input_msg = role_play_session.init_chat()
    while n < chat_turn_limit:
        n += 1
        assistant_response, user_response = role_play_session.step(input_msg)

        if assistant_response.terminated:
            print(Fore.GREEN +
                  ("AI Assistant terminated. Reason: "
                   f"{assistant_response.info['termination_reasons']}."))
            break
        if user_response.terminated:
            print(Fore.GREEN +
                  ("AI User terminated. "
                   f"Reason: {user_response.info['termination_reasons']}."))
            break

        print_text_animated(Fore.BLUE +
                            f"AI User:\n\n{user_response.msg.content}\n")
        print_text_animated(Fore.GREEN + "AI Assistant:\n\n"
                            f"{assistant_response.msg.content}\n")

        if "CAMEL_TASK_DONE" in user_response.msg.content:
            break

        input_msg = assistant_response.msg


if __name__ == "__main__":
    from camel.types import ModelType

    main(model_type=ModelType.AZURE)
Review comment: Seems for Azure, AZURE_MODEL_TYPE is not a must, but AZURE_API_VERSION is required; could you please check this?

Reply: Currently the model type is used by the token counter. I still think we should keep this variable, for two reasons. First, the Azure OpenAI API is basically an OpenAI API, and I want future functionality that needs the model type, developed on the OpenAI model class, to be usable directly by this class. Second, the deployed model type is a very important attribute for the Azure OpenAI API, so I want it to be specified when we create this class.

Reply: Hey @L4zyy, I got your point: ModelType is required for counting tokens. We can pass ModelType as a parameter rather than using export to set it in the environment, just like we did for the OpenAI service. But AZURE_API_VERSION is indeed required when we call the Azure service. Here in the README, if you don't ask the user to set AZURE_API_VERSION, it will always be the default value set in your code, 2023-10-01-preview, which shouldn't be the case.