Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #3 from jina-ai/ws-support
Browse files Browse the repository at this point in the history
feat(streaming): enable websocket endpoint
  • Loading branch information
deepankarm committed Apr 4, 2023
2 parents b86b3d3 + d6b2984 commit c0af0de
Show file tree
Hide file tree
Showing 14 changed files with 653 additions and 78 deletions.
44 changes: 44 additions & 0 deletions examples/rest/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from langchain import LLMChain, OpenAI, SerpAPIWrapper
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent

from lcserve import serving


@serving
def ask(input: str) -> str:
    """Answer *input* using a pirate-voiced zero-shot agent with web search.

    Builds a fresh SerpAPI-backed agent on every call and returns the agent's
    final answer as a string.

    NOTE: the parameter name ``input`` shadows the builtin, but it is part of
    the served interface, so it is kept as-is.
    """
    serp = SerpAPIWrapper()
    search_tool = Tool(
        name="Search",
        func=serp.run,
        description="useful for when you need to answer questions about current events",
    )
    tools = [search_tool]

    prefix = """Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:"""
    suffix = """Begin! Remember to speak as a pirate when giving your final answer. Use lots of "Args"
Question: {input}
{agent_scratchpad}"""

    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "agent_scratchpad"],
    )

    # Mirrors the original example: show the rendered prompt template on each call.
    print(prompt.template)

    chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
    agent = ZeroShotAgent(
        llm_chain=chain,
        allowed_tools=[tool.name for tool in tools],
    )

    executor = AgentExecutor.from_agent_and_tools(
        agent=agent, tools=tools, verbose=True
    )

    return executor.run(input)


if __name__ == "__main__":
    # Local smoke test: invoke the served function directly.
    answer = ask("What is the capital of France?")
    print(answer)
33 changes: 33 additions & 0 deletions examples/rest/app_before.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from langchain import LLMChain, OpenAI, SerpAPIWrapper
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent

# Plain (non-served) version of the pirate-agent example: builds the toolchain
# at module level and runs a single hard-coded question when the file is
# executed. Compare with app.py, which wraps the same logic in `@serving`.

# SerpAPI-backed web search tool.
# NOTE(review): presumably requires SERPAPI_API_KEY in the environment — confirm.
search = SerpAPIWrapper()
tools = [
    Tool(
        name="Search",
        func=search.run,
        description="useful for when you need to answer questions about current events",
    )
]
# Prompt scaffolding: the prefix sets the persona, the suffix carries the
# question and the agent's scratchpad placeholder.
prefix = """Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:"""
suffix = """Begin! Remember to speak as a pirate when giving your final answer. Use lots of "Args"
Question: {input}
{agent_scratchpad}"""

prompt = ZeroShotAgent.create_prompt(
    tools,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "agent_scratchpad"],
)

# temperature=0 keeps the example deterministic-ish for repeated runs.
llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)

agent_executor = AgentExecutor.from_agent_and_tools(
    agent=agent, tools=tools, verbose=True
)

# Runs immediately on import/execution; verbose=True prints the agent trace.
agent_executor.run("What is the capital of France?")
2 changes: 2 additions & 0 deletions examples/rest/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
openai
google-search-results
116 changes: 116 additions & 0 deletions examples/websockets/hitl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
## Human in the Loop

This directory contains the following 3 files:

#### `hitl.py`

1. Defines a function `hitl` decorated with `@serving` with `websocket=True`.

https://github.com/jina-ai/langchain-serve/blob/d0726b8e730f4646fd9e54561a7648f9c3b1af60/examples/websockets/hitl/hitl.py#L11-L12

2. Accepts `streaming_handler` from kwargs and passes it to `ChatOpenAI` and `OpenAI` callback managers. This handler is responsible to stream the response to the client.

https://github.com/jina-ai/langchain-serve/blob/9f793f4311007f6cb775e9ac19f89694eb97b80d/examples/websockets/hitl/hitl.py#L19-L22

https://github.com/jina-ai/langchain-serve/blob/9f793f4311007f6cb775e9ac19f89694eb97b80d/examples/websockets/hitl/hitl.py#L27-L30

3. Returns `agent.run` output which is a `str`.

https://github.com/jina-ai/langchain-serve/blob/9f793f4311007f6cb775e9ac19f89694eb97b80d/examples/websockets/hitl/hitl.py#L43


---

#### `requirements.txt`

Contains the dependencies for the `hitl` endpoint.

---


#### `hitl_client.py`

A simple client that does the following:

1. Connects to the websocket server and sends the following to the `hitl` endpoint.

```json
{
"question": "${question}",
"envs": {
"OPENAI_API_KEY": "${OPENAI_API_KEY}"
}
}
```

https://github.com/jina-ai/langchain-serve/blob/fe9401618fa1635b17c5a117eea0463e79f85805/examples/websockets/hitl/hitl_client.py#L24-L29

2. Listens to the stream of responses and prints it to the console

https://github.com/jina-ai/langchain-serve/blob/fe9401618fa1635b17c5a117eea0463e79f85805/examples/websockets/hitl/hitl_client.py#L31-L39

3. When it receives a response in the following format, it displays the prompt to the user, waits for the user to type an answer (this is how the human is brought into the loop), and then sends that answer back to the server.

```json
{
"prompt": "$prompt"
}
```

https://github.com/jina-ai/langchain-serve/blob/fe9401618fa1635b17c5a117eea0463e79f85805/examples/websockets/hitl/hitl_client.py#L42-L44

4. Finally, the client is disconnected from the server automatically when the `hitl` function is done executing.

---


### Example run on localhost

```bash
python hitl_client.py
```

```text
Connected to ws://localhost:8080/hitl.
I don't know Eric Zhu's birthday, so I need to ask a human.
Action: Human
Action Input: "Do you know Eric Zhu and his birthday?"
Yes
Great, now I can ask for Eric Zhu's birthday.
Action: Human
Action Input: "What is Eric Zhu's birthday?"
29th Feb
I need to make sure this is a valid date.
Action: Calculator
Action Input: Check if 29th Feb is a valid date
```

```python
import datetime

try:
datetime.datetime(2020, 2, 29)
print("Valid date")
except ValueError:
print("Invalid date")
```

```text
I now have a valid birth date, but I need to know the year for Eric's age.
Action: Human
Action Input: "Do you know Eric Zhu's birth year?"
1990
Now I can calculate Eric Zhu's age.
Action: Calculator
Action Input: Current year minus 1990
```

```python
import datetime
print(datetime.datetime.now().year - 1990)
```

```text
I now know Eric Zhu's age.
Final Answer: Eric Zhu's birthday is February 29th, 1990 and he is currently 33 years old.Eric Zhu's birthday is February 29th, 1990 and he is currently 33 years old.%
```
43 changes: 43 additions & 0 deletions examples/websockets/hitl/hitl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os

from langchain.agents import initialize_agent, load_tools
from langchain.callbacks.base import CallbackManager
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI

from lcserve import serving


@serving(websocket=True)
def hitl(question: str, **kwargs) -> str:
    """Human-in-the-loop agent served over a websocket.

    Streams LLM tokens to the connected client through the
    ``streaming_handler`` injected by the serving layer, and can ask the
    human (the client) questions via the "human" tool.
    """
    # Injected into kwargs by lc-serve for websocket endpoints; used to push
    # tokens to the client as they are generated.
    handler = kwargs.get('streaming_handler')

    chat_llm = ChatOpenAI(
        temperature=0.0,
        verbose=True,
        streaming=True,  # stream tokens instead of returning one final blob
        callback_manager=CallbackManager([handler]),
    )
    math_llm = OpenAI(
        temperature=0.0,
        verbose=True,
        streaming=True,  # stream tokens instead of returning one final blob
        callback_manager=CallbackManager([handler]),
    )

    # "human" lets the agent ask the user questions; "llm-math" needs its own LLM.
    tools = load_tools(["human", "llm-math"], llm=math_llm)

    agent = initialize_agent(
        tools,
        chat_llm,
        agent="zero-shot-react-description",
        verbose=True,
    )
    return agent.run(question)
63 changes: 63 additions & 0 deletions examples/websockets/hitl/hitl_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import asyncio
import os
from typing import Dict, Optional

import aiohttp
from pydantic import BaseModel, ValidationError


class Response(BaseModel):
    """One streamed payload from the server for a chunk of agent output."""

    result: str  # partial or final text produced by the agent
    error: str  # error message — presumably empty when there is no error; confirm
    stdout: str  # NOTE(review): looks like captured server-side stdout — verify


class HumanPrompt(BaseModel):
    """Server message asking the human client to answer a question."""

    prompt: str  # question text to display to the local user


async def hitl_client(url: str, name: str, question: str, envs: Optional[Dict] = None):
    """Drive a human-in-the-loop session against a websocket endpoint.

    Connects to ``{url}/{name}``, sends ``question`` plus optional ``envs``
    (e.g. API keys forwarded to the server), then:

    * prints streamed ``Response`` chunks as they arrive,
    * answers ``HumanPrompt`` messages by asking the local user via ``input()``,
    * closes when the server sends ``'close cmd'`` or drops the connection.

    Args:
        url: Base websocket URL, e.g. ``ws://localhost:8080``.
        name: Endpoint name to connect to.
        question: Question forwarded to the agent.
        envs: Optional environment variables for the server side. Defaults to
            ``None`` instead of ``{}`` to avoid the mutable-default-argument
            pitfall; a falsy value is sent as an empty mapping.
    """
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f'{url}/{name}') as ws:
            print(f'Connected to {url}/{name}.')

            await ws.send_json(
                {
                    "question": question,
                    "envs": envs if envs else {},
                }
            )

            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if msg.data == 'close cmd':
                        await ws.close()
                        break
                    # First, try to interpret the payload as a streamed chunk.
                    try:
                        response = Response.parse_raw(msg.data)
                        # flush=True so partial chunks (printed with end='',
                        # i.e. no newline) appear immediately on the console.
                        print(response.result, end='', flush=True)
                    except ValidationError:
                        # Not a chunk — maybe the agent is asking the human.
                        try:
                            prompt = HumanPrompt.parse_raw(msg.data)
                            answer = input(prompt.prompt + '\n')
                            await ws.send_str(answer)
                        except ValidationError:
                            print(f'Unknown message: {msg.data}')

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    print('ws connection closed with exception %s' % ws.exception())
                else:
                    print(msg)


# Example driver: connect to a deployed lc-serve websocket endpoint and run one
# human-in-the-loop question. Requires OPENAI_API_KEY in the local environment
# (raises KeyError otherwise); it is forwarded to the server via `envs`.
asyncio.run(
    hitl_client(
        url='wss://langchain-1da55ad36a-websocket.wolf.jina.ai',
        name='hitl',
        question='What is Eric Zhu\'s birthday?',
        envs={
            'OPENAI_API_KEY': os.environ['OPENAI_API_KEY'],
        },
    )
)
1 change: 1 addition & 0 deletions examples/websockets/hitl/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
openai
5 changes: 3 additions & 2 deletions lcserve/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ async def serve_on_jcloud(
from .backend.playground.utils.helper import get_random_tag

tag = get_random_tag()
gateway_id_wo_tag = push_app_to_hubble(module, tag=tag, verbose=verbose)
gateway_id_wo_tag, websocket = push_app_to_hubble(module, tag=tag, verbose=verbose)
app_id, endpoint = await deploy_app_on_jcloud(
flow_dict=get_flow_dict(
module,
module=module,
jcloud=True,
port=8080,
name=name,
app_id=app_id,
gateway_id=gateway_id_wo_tag + ':' + tag,
websocket=websocket,
),
app_id=app_id,
verbose=verbose,
Expand Down
35 changes: 24 additions & 11 deletions lcserve/backend/decorators.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
from functools import wraps
from typing import Callable


def serving(_func=None, *, websocket: bool = False):
    """Mark a function as a servable endpoint.

    Supports both bare and parameterized use::

        @serving                      # REST endpoint
        def f(...): ...

        @serving(websocket=True)      # websocket (streaming) endpoint
        def g(...): ...

    The wrapped function behaves exactly like the original; serving metadata
    (name, docstring, params) is attached as ``__ws_serving__`` when
    ``websocket`` is true, otherwise as ``__serving__``.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        metadata = {
            'name': func.__name__,
            'doc': func.__doc__,
            'params': {
                # Websocket endpoints receive callback handlers so they can
                # stream output back to the client.
                'include_callback_handlers': websocket,
            },
        }
        attr = '__ws_serving__' if websocket else '__serving__'
        setattr(wrapper, attr, metadata)
        return wrapper

    # Bare `@serving` passes the function directly; `@serving(...)` does not.
    return decorator if _func is None else decorator(_func)

0 comments on commit c0af0de

Please sign in to comment.