This repository has been archived by the owner on Oct 19, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 136
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from jina-ai/ws-support
feat(streaming): enable websocket endpoint
- Loading branch information
Showing
14 changed files
with
653 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from langchain import LLMChain, OpenAI, SerpAPIWrapper | ||
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent | ||
|
||
from lcserve import serving | ||
|
||
|
||
@serving | ||
def ask(input: str) -> str: | ||
search = SerpAPIWrapper() | ||
tools = [ | ||
Tool( | ||
name="Search", | ||
func=search.run, | ||
description="useful for when you need to answer questions about current events", | ||
) | ||
] | ||
prefix = """Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:""" | ||
suffix = """Begin! Remember to speak as a pirate when giving your final answer. Use lots of "Args" | ||
Question: {input} | ||
{agent_scratchpad}""" | ||
|
||
prompt = ZeroShotAgent.create_prompt( | ||
tools, | ||
prefix=prefix, | ||
suffix=suffix, | ||
input_variables=["input", "agent_scratchpad"], | ||
) | ||
|
||
print(prompt.template) | ||
|
||
llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt) | ||
tool_names = [tool.name for tool in tools] | ||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) | ||
|
||
agent_executor = AgentExecutor.from_agent_and_tools( | ||
agent=agent, tools=tools, verbose=True | ||
) | ||
|
||
return agent_executor.run(input) | ||
|
||
|
||
if __name__ == "__main__": | ||
print(ask("What is the capital of France?")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from langchain import LLMChain, OpenAI, SerpAPIWrapper | ||
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent | ||
|
||
search = SerpAPIWrapper() | ||
tools = [ | ||
Tool( | ||
name="Search", | ||
func=search.run, | ||
description="useful for when you need to answer questions about current events", | ||
) | ||
] | ||
prefix = """Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:""" | ||
suffix = """Begin! Remember to speak as a pirate when giving your final answer. Use lots of "Args" | ||
Question: {input} | ||
{agent_scratchpad}""" | ||
|
||
prompt = ZeroShotAgent.create_prompt( | ||
tools, | ||
prefix=prefix, | ||
suffix=suffix, | ||
input_variables=["input", "agent_scratchpad"], | ||
) | ||
|
||
llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt) | ||
tool_names = [tool.name for tool in tools] | ||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) | ||
|
||
agent_executor = AgentExecutor.from_agent_and_tools( | ||
agent=agent, tools=tools, verbose=True | ||
) | ||
|
||
agent_executor.run("What is the capital of France?") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
openai | ||
google-search-results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
## Human in the Loop | ||
|
||
This directory contains 3 files | ||
|
||
#### `hitl.py` | ||
|
||
1. Defines a function `hitl` decorated with `@serving` with `websocket=True`. | ||
|
||
https://github.com/jina-ai/langchain-serve/blob/d0726b8e730f4646fd9e54561a7648f9c3b1af60/examples/websockets/hitl/hitl.py#L11-L12 | ||
|
||
2. Accepts `streaming_handler` from kwargs and passes it to `ChatOpenAI` and `OpenAI` callback managers. This handler is responsible to stream the response to the client. | ||
|
||
https://github.com/jina-ai/langchain-serve/blob/9f793f4311007f6cb775e9ac19f89694eb97b80d/examples/websockets/hitl/hitl.py#L19-L22 | ||
|
||
https://github.com/jina-ai/langchain-serve/blob/9f793f4311007f6cb775e9ac19f89694eb97b80d/examples/websockets/hitl/hitl.py#L27-L30 | ||
|
||
3. Returns `agent.run` output which is a `str`. | ||
|
||
https://github.com/jina-ai/langchain-serve/blob/9f793f4311007f6cb775e9ac19f89694eb97b80d/examples/websockets/hitl/hitl.py#L43 | ||
|
||
|
||
--- | ||
|
||
#### `requirements.txt` | ||
|
||
Contains the dependencies for the `hitl` endpoint. | ||
|
||
--- | ||
|
||
|
||
#### `hitl_client.py` | ||
|
||
A simple client | ||
|
||
1. Connects to the websocket server and sends the following to the `hitl` endpoint. | ||
|
||
```json | ||
{ | ||
"question": "${question}", | ||
"envs": { | ||
"OPENAI_API_KEY": "${OPENAI_API_KEY}" | ||
} | ||
} | ||
``` | ||
|
||
https://github.com/jina-ai/langchain-serve/blob/fe9401618fa1635b17c5a117eea0463e79f85805/examples/websockets/hitl/hitl_client.py#L24-L29 | ||
|
||
2. Listens to the stream of responses and prints it to the console | ||
|
||
https://github.com/jina-ai/langchain-serve/blob/fe9401618fa1635b17c5a117eea0463e79f85805/examples/websockets/hitl/hitl_client.py#L31-L39 | ||
|
||
3. When it receives a response in the following format, it asks the prompt to the user using the client and waits for the user to input the answer. (This is how human is brought into the loop). Next, this answer is then sent to the server. | ||
|
||
```json | ||
{ | ||
"prompt": "$prompt" | ||
} | ||
``` | ||
|
||
https://github.com/jina-ai/langchain-serve/blob/fe9401618fa1635b17c5a117eea0463e79f85805/examples/websockets/hitl/hitl_client.py#L42-L44 | ||
|
||
4. Finally, the client is disconnected from the server automatically when the `hitl` function is done executing. | ||
|
||
--- | ||
|
||
|
||
### Example run on localhost | ||
|
||
```bash | ||
python hitl_client.py | ||
``` | ||
|
||
```text | ||
Connected to ws://localhost:8080/hitl. | ||
I don't know Eric Zhu's birthday, so I need to ask a human. | ||
Action: Human | ||
Action Input: "Do you know Eric Zhu and his birthday?" | ||
Yes | ||
Great, now I can ask for Eric Zhu's birthday. | ||
Action: Human | ||
Action Input: "What is Eric Zhu's birthday?" | ||
29th Feb | ||
I need to make sure this is a valid date. | ||
Action: Calculator | ||
Action Input: Check if 29th Feb is a valid date | ||
``` | ||
|
||
```python | ||
import datetime | ||
|
||
try: | ||
datetime.datetime(2020, 2, 29) | ||
print("Valid date") | ||
except ValueError: | ||
print("Invalid date") | ||
``` | ||
|
||
```text | ||
I now have a valid birth date, but I need to know the year for Eric's age. | ||
Action: Human | ||
Action Input: "Do you know Eric Zhu's birth year?" | ||
1990 | ||
Now I can calculate Eric Zhu's age. | ||
Action: Calculator | ||
Action Input: Current year minus 1990 | ||
``` | ||
|
||
```python | ||
import datetime | ||
print(datetime.datetime.now().year - 1990) | ||
``` | ||
|
||
```text | ||
I now know Eric Zhu's age. | ||
Final Answer: Eric Zhu's birthday is February 29th, 1990 and he is currently 33 years old.Eric Zhu's birthday is February 29th, 1990 and he is currently 33 years old.% | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import os | ||
|
||
from langchain.agents import initialize_agent, load_tools | ||
from langchain.callbacks.base import CallbackManager | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.llms import OpenAI | ||
|
||
from lcserve import serving | ||
|
||
|
||
@serving(websocket=True) | ||
def hitl(question: str, **kwargs) -> str: | ||
# Get the `streaming_handler` from `kwargs`. This is used to stream data to the client. | ||
streaming_handler = kwargs.get('streaming_handler') | ||
|
||
llm = ChatOpenAI( | ||
temperature=0.0, | ||
verbose=True, | ||
streaming=True, # Pass `streaming=True` to make sure the client receives the data. | ||
callback_manager=CallbackManager( | ||
[streaming_handler] | ||
), # Pass the callback handler | ||
) | ||
math_llm = OpenAI( | ||
temperature=0.0, | ||
verbose=True, | ||
streaming=True, # Pass `streaming=True` to make sure the client receives the data. | ||
callback_manager=CallbackManager( | ||
[streaming_handler] | ||
), # Pass the callback handler | ||
) | ||
tools = load_tools( | ||
["human", "llm-math"], | ||
llm=math_llm, | ||
) | ||
|
||
agent_chain = initialize_agent( | ||
tools, | ||
llm, | ||
agent="zero-shot-react-description", | ||
verbose=True, | ||
) | ||
return agent_chain.run(question) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import asyncio | ||
import os | ||
from typing import Dict | ||
|
||
import aiohttp | ||
from pydantic import BaseModel, ValidationError | ||
|
||
|
||
class Response(BaseModel): | ||
result: str | ||
error: str | ||
stdout: str | ||
|
||
|
||
class HumanPrompt(BaseModel): | ||
prompt: str | ||
|
||
|
||
async def hitl_client(url: str, name: str, question: str, envs: Dict = {}): | ||
async with aiohttp.ClientSession() as session: | ||
async with session.ws_connect(f'{url}/{name}') as ws: | ||
print(f'Connected to {url}/{name}.') | ||
|
||
await ws.send_json( | ||
{ | ||
"question": question, | ||
"envs": envs if envs else {}, | ||
} | ||
) | ||
|
||
async for msg in ws: | ||
if msg.type == aiohttp.WSMsgType.TEXT: | ||
if msg.data == 'close cmd': | ||
await ws.close() | ||
break | ||
else: | ||
try: | ||
response = Response.parse_raw(msg.data) | ||
print(response.result, end='') | ||
except ValidationError: | ||
try: | ||
prompt = HumanPrompt.parse_raw(msg.data) | ||
answer = input(prompt.prompt + '\n') | ||
await ws.send_str(answer) | ||
except ValidationError: | ||
print(f'Unknown message: {msg.data}') | ||
|
||
elif msg.type == aiohttp.WSMsgType.ERROR: | ||
print('ws connection closed with exception %s' % ws.exception()) | ||
else: | ||
print(msg) | ||
|
||
|
||
asyncio.run( | ||
hitl_client( | ||
url='wss://langchain-1da55ad36a-websocket.wolf.jina.ai', | ||
name='hitl', | ||
question='What is Eric Zhu\'s birthday?', | ||
envs={ | ||
'OPENAI_API_KEY': os.environ['OPENAI_API_KEY'], | ||
}, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
openai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,28 @@ | ||
from functools import wraps | ||
from typing import Callable | ||
|
||
|
||
def serving(func: Callable): | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
return func(*args, **kwargs) | ||
def serving(_func=None, *, websocket: bool = False): | ||
def decorator(func): | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
return func(*args, **kwargs) | ||
|
||
wrapper.__serving__ = { | ||
'name': func.__name__, | ||
'doc': func.__doc__, | ||
'params': {}, | ||
} | ||
return wrapper | ||
_args = { | ||
'name': func.__name__, | ||
'doc': func.__doc__, | ||
'params': { | ||
'include_callback_handlers': websocket, | ||
# If websocket is True, pass the callback handlers to the client. | ||
}, | ||
} | ||
if websocket: | ||
wrapper.__ws_serving__ = _args | ||
else: | ||
wrapper.__serving__ = _args | ||
|
||
return wrapper | ||
|
||
if _func is None: | ||
return decorator | ||
else: | ||
return decorator(_func) |
Oops, something went wrong.