ENH: use restful client for generate/chat in cmdline #371

Open · wants to merge 4 commits into base: main
198 changes: 68 additions & 130 deletions xinference/deploy/cmdline.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import configparser
import logging
import os
@@ -24,7 +23,6 @@

from .. import __version__
from ..client import (
Client,
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulClient,
@@ -36,7 +34,6 @@
XINFERENCE_DEFAULT_LOCAL_HOST,
XINFERENCE_ENV_ENDPOINT,
)
from ..isolation import Isolation
from ..types import ChatCompletionMessage

try:
@@ -353,68 +350,39 @@ def model_generate(
stream: bool,
):
endpoint = get_endpoint(endpoint)
if stream:
# TODO: when stream=True, RestfulClient cannot generate words one by one.
# So use Client in temporary. The implementation needs to be changed to
# RestfulClient in the future.
async def generate_internal():
while True:
# the prompt will be written to stdout.
# https://docs.python.org/3.10/library/functions.html#input
prompt = input("Prompt: ")
if prompt == "":
break
print(f"Completion: {prompt}", end="", file=sys.stdout)
async for chunk in model.generate(
prompt=prompt,
generate_config={"stream": stream, "max_tokens": max_tokens},
):
choice = chunk["choices"][0]
if "text" not in choice:
continue
else:
print(choice["text"], end="", flush=True, file=sys.stdout)
print("\n", file=sys.stdout)

client = Client(endpoint=endpoint)
model = client.get_model(model_uid=model_uid)

loop = asyncio.get_event_loop()
coro = generate_internal()

if loop.is_running():
isolation = Isolation(asyncio.new_event_loop(), threaded=True)
isolation.start()
isolation.call(coro)
client = RESTfulClient(base_url=endpoint)
model = client.get_model(model_uid=model_uid)
if not isinstance(model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)):
raise ValueError(f"model {model_uid} has no generate method")

while True:
# the prompt will be written to stdout.
# https://docs.python.org/3.10/library/functions.html#input
prompt = input("Prompt: ")
if prompt.lower() == "exit" or prompt.lower() == "quit":
break
print(f"Completion: {prompt}", end="", file=sys.stdout)

if stream:
iter = model.generate(
prompt=prompt,
generate_config={"stream": stream, "max_tokens": max_tokens},
)
assert not isinstance(iter, dict)
for chunk in iter:
choice = chunk["choices"][0]
if "text" not in choice:
continue
else:
print(choice["text"], end="", flush=True, file=sys.stdout)
else:
task = loop.create_task(coro)
try:
loop.run_until_complete(task)
except KeyboardInterrupt:
task.cancel()
loop.run_until_complete(task)
# avoid displaying exception-unhandled warnings
task.exception()
else:
restful_client = RESTfulClient(base_url=endpoint)
restful_model = restful_client.get_model(model_uid=model_uid)
if not isinstance(
restful_model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)
):
raise ValueError(f"model {model_uid} has no generate method")

while True:
prompt = input("User: ")
if prompt == "":
break
print(f"Assistant: {prompt}", end="", file=sys.stdout)
response = restful_model.generate(
response = model.generate(
prompt=prompt,
generate_config={"stream": stream, "max_tokens": max_tokens},
)
if not isinstance(response, dict):
raise ValueError("generate result is not valid")
print(f"{response['choices'][0]['text']}\n", file=sys.stdout)
assert isinstance(response, dict)
print(f"{response['choices'][0]['text']}", file=sys.stdout)
print("\n", file=sys.stdout)


@cli.command("chat")
@@ -434,82 +402,52 @@ def model_chat(
):
# TODO: chat model roles may not be user and assistant.
endpoint = get_endpoint(endpoint)
client = RESTfulClient(base_url=endpoint)
model = client.get_model(model_uid=model_uid)
if not isinstance(
model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
):
raise ValueError(f"model {model_uid} has no chat method")

chat_history: "List[ChatCompletionMessage]" = []
if stream:
# TODO: when stream=True, RestfulClient cannot generate words one by one.
# So use Client in temporary. The implementation needs to be changed to
# RestfulClient in the future.
async def chat_internal():
while True:
# the prompt will be written to stdout.
# https://docs.python.org/3.10/library/functions.html#input
prompt = input("User: ")
if prompt == "":
break
chat_history.append(ChatCompletionMessage(role="user", content=prompt))
print("Assistant: ", end="", file=sys.stdout)
response_content = ""
async for chunk in model.chat(
prompt=prompt,
chat_history=chat_history,
generate_config={"stream": stream, "max_tokens": max_tokens},
):
delta = chunk["choices"][0]["delta"]
if "content" not in delta:
continue
else:
response_content += delta["content"]
print(delta["content"], end="", flush=True, file=sys.stdout)
print("\n", file=sys.stdout)
chat_history.append(
ChatCompletionMessage(role="assistant", content=response_content)
)

client = Client(endpoint=endpoint)
model = client.get_model(model_uid=model_uid)

loop = asyncio.get_event_loop()
coro = chat_internal()

if loop.is_running():
isolation = Isolation(asyncio.new_event_loop(), threaded=True)
isolation.start()
isolation.call(coro)
while True:
# the prompt will be written to stdout.
# https://docs.python.org/3.10/library/functions.html#input
prompt = input("User: ")
if prompt == "":
break
chat_history.append(ChatCompletionMessage(role="user", content=prompt))
print("Assistant: ", end="", file=sys.stdout)

response_content = ""
if stream:
iter = model.chat(
prompt=prompt,
chat_history=chat_history,
generate_config={"stream": stream, "max_tokens": max_tokens},
)
assert not isinstance(iter, dict)
for chunk in iter:
delta = chunk["choices"][0]["delta"]
if "content" not in delta:
continue
else:
response_content += delta["content"]
print(delta["content"], end="", flush=True, file=sys.stdout)
else:
task = loop.create_task(coro)
try:
loop.run_until_complete(task)
except KeyboardInterrupt:
task.cancel()
loop.run_until_complete(task)
# avoid displaying exception-unhandled warnings
task.exception()
else:
restful_client = RESTfulClient(base_url=endpoint)
restful_model = restful_client.get_model(model_uid=model_uid)
if not isinstance(
restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
):
raise ValueError(f"model {model_uid} has no chat method")

while True:
prompt = input("User: ")
if prompt == "":
break
chat_history.append(ChatCompletionMessage(role="user", content=prompt))
print("Assistant: ", end="", file=sys.stdout)
response = restful_model.chat(
response = model.chat(
prompt=prompt,
chat_history=chat_history,
generate_config={"stream": stream, "max_tokens": max_tokens},
)
if not isinstance(response, dict):
raise ValueError("chat result is not valid")
assert isinstance(response, dict)
response_content = response["choices"][0]["message"]["content"]
print(f"{response_content}\n", file=sys.stdout)
chat_history.append(
ChatCompletionMessage(role="assistant", content=response_content)
)
print(f"{response_content}", file=sys.stdout)

chat_history.append(
ChatCompletionMessage(role="assistant", content=response_content)
)
print("\n", file=sys.stdout)


if __name__ == "__main__":
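With these changes both the `generate` and `chat` subcommands talk to the server exclusively through the RESTful client, whether or not streaming is requested. For reviewers who want to try the same code path outside the CLI, here is a minimal sketch built from the calls used in this diff; the endpoint URL and model UID are placeholders, and it assumes a running Xinference server with the model already launched:

import sys
from typing import List

from xinference.client import (
    RESTfulChatModelHandle,
    RESTfulClient,
    RESTfulGenerateModelHandle,
)
from xinference.types import ChatCompletionMessage

endpoint = "http://127.0.0.1:9997"  # placeholder endpoint
model_uid = "my-model-uid"  # placeholder model UID

client = RESTfulClient(base_url=endpoint)
model = client.get_model(model_uid=model_uid)
if not isinstance(model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)):
    raise ValueError(f"model {model_uid} has no generate method")

# Streaming completion: with stream=True the handle yields chunks whose
# "choices" entries carry partial text, mirroring the new CLI loop.
for chunk in model.generate(
    prompt="Hello",
    generate_config={"stream": True, "max_tokens": 64},
):
    choice = chunk["choices"][0]
    if "text" in choice:
        print(choice["text"], end="", flush=True, file=sys.stdout)
print()

# Non-streaming chat: the handle returns a dict carrying the assistant
# message, which the CLI appends back into the running chat history.
if isinstance(model, RESTfulChatModelHandle):
    chat_history: List[ChatCompletionMessage] = []
    prompt = "Hi there"
    chat_history.append(ChatCompletionMessage(role="user", content=prompt))
    response = model.chat(
        prompt=prompt,
        chat_history=chat_history,
        generate_config={"stream": False, "max_tokens": 64},
    )
    content = response["choices"][0]["message"]["content"]
    chat_history.append(ChatCompletionMessage(role="assistant", content=content))
    print(content)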
4 changes: 2 additions & 2 deletions xinference/deploy/test/test_cmdline.py
@@ -18,7 +18,7 @@
import pytest
from click.testing import CliRunner

from ...client import Client
from ...client import RESTfulClient
from ..cmdline import (
list_model_registrations,
model_chat,
@@ -59,7 +59,7 @@ def test_cmdline(setup, stream):
"""
# If the `model_launch` command is used to launch the model, CI will fail,
# so launch it through the client as a temporary workaround.
client = Client(endpoint)
client = RESTfulClient(endpoint)
model_uid = client.launch_model(
model_name="orca", model_size_in_billions=3, quantization="q4_0"
)
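The test still launches the model through the client rather than through the `model_launch` command, only switching from the actor `Client` to `RESTfulClient`. A hedged sketch of such a launch-and-invoke flow is shown below; the `launch_model` arguments come from the diff, while the CliRunner options and the prompt input are illustrative assumptions rather than the exact test code:

from click.testing import CliRunner

from xinference.client import RESTfulClient
from xinference.deploy.cmdline import model_generate

endpoint = "http://127.0.0.1:9997"  # placeholder endpoint of a running server

# Launch a small model through the RESTful client, as the test does,
# because launching via the `model_launch` command fails on CI.
client = RESTfulClient(endpoint)
model_uid = client.launch_model(
    model_name="orca", model_size_in_billions=3, quantization="q4_0"
)

# Drive the generate command the way a user would; the option names and the
# "Hello"/"exit" input below are assumptions for illustration only.
runner = CliRunner()
result = runner.invoke(
    model_generate,
    ["--endpoint", endpoint, "--model-uid", model_uid, "--max-tokens", "16"],
    input="Hello\nexit\n",
)
assert result.exit_code == 0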