Skip to content

Commit

Permalink
Merge pull request #207 from YiVal/add_template_instruct_for_opro
Browse files Browse the repository at this point in the history
add template instruct for opro
  • Loading branch information
uni-zhuan committed Nov 1, 2023
2 parents d3cfb6f + dd425f5 commit 5a5ec21
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
33 changes: 28 additions & 5 deletions src/yival/enhancers/optimize_by_prompt_enhancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""

import copy
import itertools
import json
import logging
from typing import Dict, List, Tuple
Expand All @@ -43,7 +44,13 @@
)
from ..schemas.model_configs import Request
from .base_combination_enhancer import BaseCombinationEnhancer
from .utils import construct_output_format, format_input_from_dict, scratch_variations_from_str
from .utils import (
construct_output_format,
construct_template_restrict,
format_input_from_dict,
scratch_template_vars,
scratch_variations_from_str,
)

rate_limiter = RateLimiter(60 / 60)

Expand Down Expand Up @@ -105,7 +112,7 @@ def construct_solution_score_pairs(
def construct_opro_full_prompt(
cache: List[Tuple[Dict, Dict]], head_meta_instruction: str,
optimation_task_format: str | None, end_meta_instruction: str,
enhance_var: List[str]
enhance_var: List[str], template_vars: List[str] | None
) -> str:
"""
Construct full opro prompt , which has a format as follow
Expand All @@ -120,6 +127,8 @@ def construct_opro_full_prompt(
if optimation_task_format:
full_prompt += (optimation_task_format + '\n')
full_prompt += (end_meta_instruction + '\n')
if template_vars:
full_prompt += construct_template_restrict(template_vars)
full_prompt += construct_output_format(enhance_var)

return full_prompt
Expand Down Expand Up @@ -211,6 +220,15 @@ def enhance(
assert set(self.config.enhance_var).issubset(set(best_combo.keys()))

variations_now = best_combo
template_vars = [
scratch_template_vars(str_value)
for str_value in best_combo.values()
]
if template_vars:
template_vars = list(itertools.chain(*template_vars)) #type:ignore
else:
template_vars = None #type:ignore

logging.info(f"[INFO][opro] first variations is {variations_now}")

lite_experiment_runner = LiteExperimentRunner(
Expand Down Expand Up @@ -243,11 +261,16 @@ def enhance(
cache.append((best_combo, score))

opro_prompt = construct_opro_full_prompt(
cache, self.config.head_meta_instruction,
cache,
self.config.head_meta_instruction,
self.config.optimation_task_format,
self.config.end_meta_instruction, self.config.enhance_var
self.config.end_meta_instruction,
self.config.enhance_var,
template_vars #type:ignore
)

print(f"[INFO] opro_prompt: {opro_prompt}")

gen_variations = self.fetch_next_variations(opro_prompt)

if not gen_variations:
Expand Down Expand Up @@ -289,4 +312,4 @@ def enhance(
BaseCombinationEnhancer.register_enhancer(
"optimize_by_prompt_enhancer", OptimizeByPromptEnhancer,
OptimizeByPromptEnhancerConfig
)
)
30 changes: 30 additions & 0 deletions src/yival/enhancers/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Dict, List


Expand Down Expand Up @@ -79,6 +80,35 @@ def construct_output_format(variations: List[str]) -> str:
return prompt


def construct_template_restrict(template_vars: List[str]) -> str:
"""
Restrict llm to output variations in format template
e.g. template_vars: ['user_info']
retrict_prompt:
Please follow python's template formatting for replies and make sure your output conforms to the format of python string.
* Use {user_info} instead of user_info
"""

prompt = "Please follow python's template formatting for replies and make sure your output conforms to the format of python string.\n"
for var in template_vars:
prompt += f"* Use {{{var}}} instead of {var}\n"
return prompt + '\n'


def scratch_template_vars(prompt: str) -> List[str]:
"""
scratch template vars from given prompt.
e.g. prompt: write a short discord welcome message based on the following info\n user_info: {user_info}\n channel_type: {channel_type}
response: ['user_info', 'channel_type']
"""
return re.findall(r"\{(\w+)\}", prompt)


if __name__ == "__main__":
input = """
This is the generated new output
Expand Down

0 comments on commit 5a5ec21

Please sign in to comment.