PRELUDE

Code for Aligning LLM Agents by Learning Latent Preference from User Edits.

Installation

  1. This project is developed in Python 3.6. We recommend using Conda to set up a virtual environment.

  2. Install the required dependencies.

    pip install -r requirements.txt
    
  3. Install PyTorch from http://pytorch.org/.
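Putting the steps above together, a typical setup might look like the following (the environment name `prelude` is illustrative, and the exact PyTorch install command depends on your platform and CUDA version; see pytorch.org):

```shell
# Create and activate a Conda environment (name is illustrative)
conda create -n prelude python=3.6
conda activate prelude

# Install the project dependencies
pip install -r requirements.txt

# Install PyTorch (use the command for your platform from pytorch.org)
pip install torch
```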

Implementation of PRELUDE Framework

The PRELUDE implementation contains the following main concepts: task, user, and agent.

Task

Task is the class encapsulating the following:

  1. Access to the dataset, which is a sequence of $(x_t, f^\star_t)$ pairs of (context, true user preference)
  2. The main task prompt (used to generate $y_t$ given $x_t$ and, optionally, $f_t$):
def get_task_prompt(self, input: str, preference: Optional[str] = None) -> str:
    ...
  3. The user evaluation prompts (used to generate $y'_t$):
def get_edit_prompts(self, input: str, output: str, preference: str) -> Tuple[str, str]:
    ...
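As a sketch, a task implementing these two prompt methods could look like the following. The class name `SummarizationTask` and the prompt wording here are illustrative assumptions, not the repository's actual implementation:

```python
from typing import Optional, Tuple


class SummarizationTask:
    """Illustrative task; the repo's base class and prompts differ in detail."""

    def get_task_prompt(self, input: str, preference: Optional[str] = None) -> str:
        # Generate y_t from x_t, optionally conditioning on the learned preference f_t
        prompt = f"Summarize the following document:\n\n{input}"
        if preference is not None:
            prompt += f"\n\nFollow this writing preference: {preference}"
        return prompt

    def get_edit_prompts(self, input: str, output: str, preference: str) -> Tuple[str, str]:
        # Prompts used to simulate the user's edit y'_t of the agent output y_t
        system = f"You are a user who edits text to match this preference: {preference}"
        user = f"Document:\n{input}\n\nDraft summary:\n{output}\n\nEdit the draft."
        return system, user
```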

Two different tasks are currently implemented: content summarization and email writing.

Task specifics can be controlled using TaskConfig, which allows you to:

  1. Change the number of examples
  2. Choose random seed
  3. Specify data source
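A minimal sketch of such a config as a dataclass follows; the field names and defaults here are assumptions for illustration, so check the repository's actual TaskConfig definition:

```python
from dataclasses import dataclass


@dataclass
class TaskConfig:
    # Illustrative fields; the repository's definition may differ
    num_examples: int = 100             # number of (x_t, f*_t) examples to use
    seed: int = 42                      # random seed for sampling
    data_source: str = "cnn_dailymail"  # dataset identifier (illustrative)
```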

User

User encapsulates access to the task and to the LLM resource used for simulating user responses. Initialization requires a TaskConfig and a UserConfig (which specifies the LLM model name).
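Conceptually, the user wraps a task and an LLM and produces edited outputs. The sketch below stubs out the LLM call; the constructor signature, the `edit` method, and the model name are all assumptions for illustration, not the repository's actual interface:

```python
class User:
    """Illustrative user simulator; the repo's class wires in a real LLM client."""

    def __init__(self, task, model_name: str):
        self.task = task
        self.model_name = model_name

    def edit(self, input: str, output: str, preference: str) -> str:
        # Build the edit prompts from the task, then query the LLM (stubbed here)
        system, user_msg = self.task.get_edit_prompts(input, output, preference)
        return self._call_llm(system, user_msg)

    def _call_llm(self, system: str, user_msg: str) -> str:
        # Placeholder: a real implementation would query the model self.model_name
        return user_msg
```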

Agent

Agents are the classes responsible for accomplishing the tasks, encapsulating access to the LLM and implementing the learning algorithms.

Reproduce Our Experiments

All agents mentioned in our paper are located in the agent folder.

(INSTRUCTIONS TO BE ADDED)

Implement Your Own Agent

Every agent should inherit from the base Agent class and implement the following methods:

  1. def complete(self, text) -> LLMOutput - the task completion method, returning an LLMOutput object containing the output text and (optionally) debug token information
  2. def learn(self, message, correction: Correction) -> Dict - the learning method, taking the context text and a pair of (agent completion, user edits) as inputs; it returns a dictionary of metrics to be logged
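As a sketch, a minimal dummy agent satisfying this interface might look like the following. The LLMOutput and Correction definitions below are simplified stand-ins for the repository's actual types, and the metric name is illustrative:

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class LLMOutput:
    # Simplified stand-in for the repo's LLMOutput type
    text: str
    debug_tokens: Optional[list] = None


@dataclass
class Correction:
    # Simplified stand-in: a pair of (agent completion, user edits)
    completion: str
    edits: str


class EchoAgent:
    """Dummy agent that echoes its input and learns nothing, for illustration."""

    def complete(self, text) -> LLMOutput:
        # A real agent would call an LLM, conditioning on its learned preference
        return LLMOutput(text=text)

    def learn(self, message, correction: Correction) -> Dict:
        # A real agent would update its latent preference estimate from the edits;
        # here we just log a crude length-difference metric (illustrative)
        edit_distance = abs(len(correction.edits) - len(correction.completion))
        return {"edit_distance": edit_distance}
```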

Please check the notebook example of a dummy agent implementation and an end-to-end experiment run here.