Skip to content

Commit

Permalink
Improve and Fix: response parser, invalid response and others (#522)
Browse files Browse the repository at this point in the history
* chore: minor updates

* Add: send live inference time to frontend

* add: timeout in settings

* Improve: response parsing, add temperature to models

* patches

close #510, #507, #502, #468
  • Loading branch information
ARajgor committed May 2, 2024
1 parent 75df1c6 commit cdfb782
Show file tree
Hide file tree
Showing 22 changed files with 245 additions and 180 deletions.
29 changes: 0 additions & 29 deletions app.dockerfile

This file was deleted.

4 changes: 1 addition & 3 deletions devika.py
Expand Up @@ -186,9 +186,7 @@ def real_time_logs():
@route_logger(logger)
def set_settings():
data = request.json
print("Data: ", data)
config.config.update(data)
config.save_config()
config.update_config(data)
return jsonify({"message": "Settings updated"})


Expand Down
3 changes: 3 additions & 0 deletions sample.config.toml
Expand Up @@ -26,3 +26,6 @@ OPENAI = "https://api.openai.com/v1"
[LOGGING]
LOG_REST_API = "true"
LOG_PROMPTS = "false"

[TIMEOUT]
INFERENCE = 60
13 changes: 2 additions & 11 deletions src/agents/action/action.py
Expand Up @@ -2,7 +2,7 @@

from jinja2 import Environment, BaseLoader

from src.services.utils import retry_wrapper
from src.services.utils import retry_wrapper, validate_responses
from src.config import Config
from src.llm import LLM

Expand All @@ -24,17 +24,8 @@ def render(
conversation=conversation
)

@validate_responses
def validate_response(self, response: str):
response = response.strip().replace("```json", "```")

if response.startswith("```") and response.endswith("```"):
response = response[3:-3].strip()

try:
response = json.loads(response)
except Exception as _:
return False

if "response" not in response and "action" not in response:
return False
else:
Expand Down
13 changes: 2 additions & 11 deletions src/agents/answer/answer.py
Expand Up @@ -2,7 +2,7 @@

from jinja2 import Environment, BaseLoader

from src.services.utils import retry_wrapper
from src.services.utils import retry_wrapper, validate_responses
from src.config import Config
from src.llm import LLM

Expand All @@ -25,17 +25,8 @@ def render(
code_markdown=code_markdown
)

@validate_responses
def validate_response(self, response: str):
response = response.strip().replace("```json", "```")

if response.startswith("```") and response.endswith("```"):
response = response[3:-3].strip()

try:
response = json.loads(response)
except Exception as _:
return False

if "response" not in response:
return False
else:
Expand Down
13 changes: 2 additions & 11 deletions src/agents/decision/decision.py
Expand Up @@ -2,7 +2,7 @@

from jinja2 import Environment, BaseLoader

from src.services.utils import retry_wrapper
from src.services.utils import retry_wrapper, validate_responses
from src.llm import LLM

PROMPT = open("src/agents/decision/prompt.jinja2").read().strip()
Expand All @@ -16,17 +16,8 @@ def render(self, prompt: str) -> str:
template = env.from_string(PROMPT)
return template.render(prompt=prompt)

@validate_responses
def validate_response(self, response: str):
response = response.strip().replace("```json", "```")

if response.startswith("```") and response.endswith("```"):
response = response[3:-3].strip()

try:
response = json.loads(response)
except Exception as _:
return False

for item in response:
if "function" not in item or "args" not in item or "reply" not in item:
return False
Expand Down
17 changes: 4 additions & 13 deletions src/agents/internal_monologue/internal_monologue.py
Expand Up @@ -3,7 +3,7 @@
from jinja2 import Environment, BaseLoader

from src.llm import LLM
from src.services.utils import retry_wrapper
from src.services.utils import retry_wrapper, validate_responses

PROMPT = open("src/agents/internal_monologue/prompt.jinja2").read().strip()

Expand All @@ -16,19 +16,10 @@ def render(self, current_prompt: str) -> str:
template = env.from_string(PROMPT)
return template.render(current_prompt=current_prompt)

@validate_responses
def validate_response(self, response: str):
response = response.strip().replace("```json", "```")

if response.startswith("```") and response.endswith("```"):
response = response[3:-3].strip()

try:
response = json.loads(response)
except Exception as _:
return False

response = {k.replace("\\", ""): v for k, v in response.items()}

print('-------------------> ', response)
print("####", type(response))
if "internal_monologue" not in response:
return False
else:
Expand Down
25 changes: 10 additions & 15 deletions src/agents/researcher/prompt.jinja2
Expand Up @@ -11,18 +11,20 @@ Only respond in the following JSON format:

```
{
"queries": [
"<QUERY 1>",
"<QUERY 2>"
],
"ask_user": "<ASK INPUT FROM USER>"
"queries": ["<QUERY 1>", "<QUERY 2>", "<QUERY 3>", ... ],
"ask_user": "<ASK INPUT FROM USER IF REQUIRED, OTHERWISE LEAVE EMPTY STRING>"
}
```
Example =>
```
{
"queries": ["How to do Bing Search via API in Python", "Claude API Documentation Python"],
"ask_user": "Can you please provide API Keys for Claude, OpenAI, and Firebase?"
}
```

Keywords for Search Query: {{ contextual_keywords }}

Example "queries": ["How to do Bing Search via API in Python", "Claude API Documentation Python"]
Example "ask_user": "Can you please provide API Keys for Claude, OpenAI, and Firebase?"

Rules:
- Only search for a maximum of 3 queries.
Expand All @@ -33,13 +35,6 @@ Rules:
- Do not search for basic queries, only search for advanced and specific queries. You are allowed to leave the "queries" field empty if no search queries are needed for the step.
- DO NOT EVER SEARCH FOR BASIC QUERIES. ONLY SEARCH FOR ADVANCED QUERIES.
- YOU ARE ALLOWED TO LEAVE THE "queries" FIELD EMPTY IF NO SEARCH QUERIES ARE NEEDED FOR THE STEP.

Remember to only make search queries for resources that might require external information (like Documentation or a Blog or an Article). If the information is already known to you or commonly known, there is no need to search for it.

The `queries` key and the `ask_user` key can be empty list and string respectively if no search queries or user input are needed for the step. Try to keep the number of search queries to a minimum to save context window. One query per subject.

Only search for documentation or articles that are relevant to the task at hand. Do not search for general information.

Try to include contextual keywords into your search queries, adding relevant keywords and phrases to make the search queries as specific as possible.
- you only have to return one JSON object with the queries and ask_user fields. You can't return multiple JSON objects.

Only the provided JSON response format is accepted. Any other response format will be rejected.
13 changes: 2 additions & 11 deletions src/agents/researcher/researcher.py
Expand Up @@ -4,7 +4,7 @@
from jinja2 import Environment, BaseLoader

from src.llm import LLM
from src.services.utils import retry_wrapper
from src.services.utils import retry_wrapper, validate_responses
from src.browser.search import BingSearch

PROMPT = open("src/agents/researcher/prompt.jinja2").read().strip()
Expand All @@ -23,17 +23,8 @@ def render(self, step_by_step_plan: str, contextual_keywords: str) -> str:
contextual_keywords=contextual_keywords
)

@validate_responses
def validate_response(self, response: str) -> dict | bool:
response = response.strip().replace("```json", "```")

if response.startswith("```") and response.endswith("```"):
response = response[3:-3].strip()
try:
response = json.loads(response)
except Exception as _:
return False

response = {k.replace("\\", ""): v for k, v in response.items()}

if "queries" not in response and "ask_user" not in response:
return False
Expand Down
30 changes: 4 additions & 26 deletions src/agents/runner/runner.py
Expand Up @@ -10,7 +10,7 @@
from src.llm import LLM
from src.state import AgentState
from src.project import ProjectManager
from src.services.utils import retry_wrapper
from src.services.utils import retry_wrapper, validate_responses

PROMPT = open("src/agents/runner/prompt.jinja2", "r").read().strip()
RERUNNER_PROMPT = open("src/agents/runner/rerunner.jinja2", "r").read().strip()
Expand Down Expand Up @@ -52,37 +52,15 @@ def render_rerunner(
error=error
)

@validate_responses
def validate_response(self, response: str):
response = response.strip().replace("```json", "```")

if response.startswith("```") and response.endswith("```"):
response = response[3:-3].strip()

try:
response = json.loads(response)
except Exception as _:
return False

if "commands" not in response:
return False
else:
return response["commands"]


@validate_responses
def validate_rerunner_response(self, response: str):
response = response.strip().replace("```json", "```")

if response.startswith("```") and response.endswith("```"):
response = response[3:-3].strip()

print(response)

try:
response = json.loads(response)
except Exception as _:
return False

print(response)

if "action" not in response and "response" not in response:
return False
else:
Expand Down
42 changes: 18 additions & 24 deletions src/config.py
Expand Up @@ -104,6 +104,9 @@ def get_logging_rest_api(self):

def get_logging_prompts(self):
return self.config["LOGGING"]["LOG_PROMPTS"] == "true"

def get_timeout_inference(self):
return self.config["TIMEOUT"]["INFERENCE"]

def set_bing_api_key(self, key):
self.config["API_KEYS"]["BING"] = key
Expand Down Expand Up @@ -157,30 +160,6 @@ def set_netlify_api_key(self, key):
self.config["API_KEYS"]["NETLIFY"] = key
self.save_config()

def set_sqlite_db(self, db):
self.config["STORAGE"]["SQLITE_DB"] = db
self.save_config()

def set_screenshots_dir(self, dir):
self.config["STORAGE"]["SCREENSHOTS_DIR"] = dir
self.save_config()

def set_pdfs_dir(self, dir):
self.config["STORAGE"]["PDFS_DIR"] = dir
self.save_config()

def set_projects_dir(self, dir):
self.config["STORAGE"]["PROJECTS_DIR"] = dir
self.save_config()

def set_logs_dir(self, dir):
self.config["STORAGE"]["LOGS_DIR"] = dir
self.save_config()

def set_repos_dir(self, dir):
self.config["STORAGE"]["REPOS_DIR"] = dir
self.save_config()

def set_logging_rest_api(self, value):
self.config["LOGGING"]["LOG_REST_API"] = "true" if value else "false"
self.save_config()
Expand All @@ -189,6 +168,21 @@ def set_logging_prompts(self, value):
self.config["LOGGING"]["LOG_PROMPTS"] = "true" if value else "false"
self.save_config()

def set_timeout_inference(self, value):
self.config["TIMEOUT"]["INFERENCE"] = value
self.save_config()

def save_config(self):
with open("config.toml", "w") as f:
toml.dump(self.config, f)

def update_config(self, data):
for key, value in data.items():
if key in self.config:
with open("config.toml", "r+") as f:
config = toml.load(f)
for sub_key, sub_value in value.items():
self.config[key][sub_key] = sub_value
config[key][sub_key] = sub_value
f.seek(0)
toml.dump(config, f)
1 change: 1 addition & 0 deletions src/llm/claude_client.py
Expand Up @@ -20,6 +20,7 @@ def inference(self, model_id: str, prompt: str) -> str:
}
],
model=model_id,
temperature=0
)

return message.content[0].text
3 changes: 2 additions & 1 deletion src/llm/gemini_client.py
Expand Up @@ -10,7 +10,8 @@ def __init__(self):
genai.configure(api_key=api_key)

def inference(self, model_id: str, prompt: str) -> str:
model = genai.GenerativeModel(model_id)
config = genai.GenerationConfig(temperature=0)
model = genai.GenerativeModel(model_id, generation_config=config)
# Set safety settings for the request
safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
Expand Down
1 change: 1 addition & 0 deletions src/llm/groq_client.py
Expand Up @@ -18,6 +18,7 @@ def inference(self, model_id: str, prompt: str) -> str:
}
],
model=model_id,
temperature=0
)

return chat_completion.choices[0].message.content

0 comments on commit cdfb782

Please sign in to comment.