Implement createReActAgent (#169)
hinthornw committed May 19, 2024
1 parent 1f5f60d commit 7190937
Showing 7 changed files with 524 additions and 41 deletions.
2 changes: 1 addition & 1 deletion langgraph/package.json
@@ -43,7 +43,7 @@
"devDependencies": {
"@jest/globals": "^29.5.0",
"@langchain/community": "^0.0.43",
"@langchain/openai": "^0.0.23",
"@langchain/openai": "latest",
"@langchain/scripts": "^0.0.13",
"@swc/core": "^1.3.90",
"@swc/jest": "^0.2.29",
2 changes: 2 additions & 0 deletions langgraph/src/prebuilt/index.ts
@@ -6,6 +6,8 @@ export {
type FunctionCallingExecutorState,
createFunctionCallingExecutor,
} from "./chat_agent_executor.js";
+export { type AgentState, createReactAgent } from "./react_agent_executor.js";
+
export {
type ToolExecutorArgs,
type ToolInvocationInterface,
170 changes: 170 additions & 0 deletions langgraph/src/prebuilt/react_agent_executor.ts
@@ -0,0 +1,170 @@
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import {
BaseMessage,
BaseMessageChunk,
isAIMessage,
SystemMessage,
} from "@langchain/core/messages";
import {
Runnable,
RunnableInterface,
RunnableLambda,
} from "@langchain/core/runnables";
import { DynamicTool, StructuredTool } from "@langchain/core/tools";

import {
BaseLanguageModelCallOptions,
BaseLanguageModelInput,
} from "@langchain/core/language_models/base";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { BaseCheckpointSaver } from "../checkpoint/base.js";
import { END, START, StateGraph } from "../graph/index.js";
import { MessagesState } from "../graph/message.js";
import { CompiledStateGraph, StateGraphArgs } from "../graph/state.js";
import { All } from "../pregel/types.js";
import { ToolNode } from "./tool_node.js";

export interface AgentState {
messages: BaseMessage[];
// TODO: This won't be set until we
// implement managed values in LangGraphJS
// Will be useful for inserting a message on
// graph recursion end
// is_last_step: boolean;
}

export type N = typeof START | "agent" | "tools";

/**
* Creates a StateGraph agent that relies on a chat model utilizing tool calling.
* @param model The chat model that can utilize OpenAI-style function calling.
* @param tools A list of tools or a ToolNode.
* @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
* Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
* @param checkpointSaver An optional checkpoint saver to persist the agent's state.
* @param interruptBefore An optional list of node names to interrupt before running.
* @param interruptAfter An optional list of node names to interrupt after running.
* @returns A compiled agent as a LangChain Runnable.
*/
export function createReactAgent(
model: BaseChatModel,
tools: ToolNode<MessagesState> | StructuredTool[],
messageModifier?:
| SystemMessage
| string
| ((messages: BaseMessage[]) => BaseMessage[])
| Runnable,
checkpointSaver?: BaseCheckpointSaver,
interruptBefore?: N[] | All,
interruptAfter?: N[] | All
): CompiledStateGraph<
AgentState,
Partial<AgentState>,
typeof START | "agent" | "tools"
> {
const schema: StateGraphArgs<AgentState>["channels"] = {
messages: {
value: (left: BaseMessage[], right: BaseMessage[]) => left.concat(right),
default: () => [],
},
};

let toolClasses: (StructuredTool | DynamicTool)[];
if (!Array.isArray(tools)) {
toolClasses = tools.tools;
} else {
toolClasses = tools;
}
if (!("bindTools" in model) || typeof model.bindTools !== "function") {
throw new Error(`Model ${model} must define bindTools method.`);
}
const modelWithTools = model.bindTools(toolClasses);
const modelRunnable = _createModelWrapper(modelWithTools, messageModifier);

const shouldContinue = (state: AgentState) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];
if (
isAIMessage(lastMessage) &&
(!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)
) {
return END;
} else {
return "continue";
}
};

const callModel = async (state: AgentState) => {
const { messages } = state;
// TODO: Auto-promote streaming.
return { messages: [await modelRunnable.invoke(messages)] };
};

const workflow = new StateGraph<AgentState>({
channels: schema,
})
.addNode(
"agent",
new RunnableLambda({ func: callModel }).withConfig({ runName: "agent" })
)
.addNode("tools", new ToolNode<AgentState>(toolClasses))
.addEdge(START, "agent")
.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
end: END,
})
.addEdge("tools", "agent");

return workflow.compile({
checkpointer: checkpointSaver,
interruptBefore,
interruptAfter,
});
}

function _createModelWrapper(
modelWithTools: RunnableInterface<
BaseLanguageModelInput,
BaseMessageChunk,
BaseLanguageModelCallOptions
>,
messageModifier?:
| SystemMessage
| string
| ((messages: BaseMessage[]) => BaseMessage[])
| Runnable
) {
if (!messageModifier) {
return modelWithTools;
}
const endict = new RunnableLambda({
func: (messages: BaseMessage[]) => ({ messages }),
});
if (typeof messageModifier === "string") {
const systemMessage = new SystemMessage(messageModifier);
const prompt = ChatPromptTemplate.fromMessages([
systemMessage,
["placeholder", "{messages}"],
]);
return endict.pipe(prompt).pipe(modelWithTools);
}
if (typeof messageModifier === "function") {
const lambda = new RunnableLambda({ func: messageModifier }).withConfig({
runName: "message_modifier",
});
return lambda.pipe(modelWithTools);
}
if (Runnable.isRunnable(messageModifier)) {
return messageModifier.pipe(modelWithTools);
}
if (messageModifier._getType() === "system") {
const prompt = ChatPromptTemplate.fromMessages([
messageModifier,
["placeholder", "{messages}"],
]);
return endict.pipe(prompt).pipe(modelWithTools);
}
throw new Error(
`Unsupported message modifier type: ${typeof messageModifier}`
);
}
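For orientation, here is a minimal usage sketch of the new `createReactAgent` entrypoint. It assumes the prebuilt exports are reachable at `@langchain/langgraph/prebuilt` and that an `OPENAI_API_KEY` is set in the environment; the `get_weather` tool and the prompt string are purely illustrative, not part of this commit.

```ts
import { ChatOpenAI } from "@langchain/openai";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { HumanMessage } from "@langchain/core/messages";
import { z } from "zod";
import { createReactAgent } from "@langchain/langgraph/prebuilt";

// Hypothetical tool, for illustration only.
const getWeather = new DynamicStructuredTool({
  name: "get_weather",
  description: "Get the current weather for a city.",
  schema: z.object({ city: z.string().describe("The city to look up") }),
  func: async ({ city }) => `It is always sunny in ${city}.`,
});

const model = new ChatOpenAI({ temperature: 0 });

// A string messageModifier becomes a SystemMessage prepended to the
// conversation (see _createModelWrapper above).
const agent = createReactAgent(
  model,
  [getWeather],
  "You are a concise weather assistant."
);

const result = await agent.invoke({
  messages: [new HumanMessage("What's the weather in SF?")],
});
console.log(result.messages[result.messages.length - 1].content);
```

Note that in this commit the arguments are positional: model first, then tools, then the optional message modifier, checkpointer, and interrupt lists.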
17 changes: 10 additions & 7 deletions langgraph/src/prebuilt/tool_node.ts
@@ -1,24 +1,27 @@
import { BaseMessage, ToolMessage, AIMessage } from "@langchain/core/messages";
import { RunnableConfig } from "@langchain/core/runnables";
-import { Tool } from "@langchain/core/tools";
+import { StructuredTool } from "@langchain/core/tools";
import { RunnableCallable } from "../utils.js";
import { END } from "../graph/graph.js";
import { MessagesState } from "../graph/message.js";

-export class ToolNode extends RunnableCallable<
-  BaseMessage[] | MessagesState,
-  BaseMessage[] | MessagesState
-> {
+export class ToolNode<
+  T extends BaseMessage[] | MessagesState
+> extends RunnableCallable<T, T> {
/**
A node that runs the tools requested in the last AIMessage. It can be used
either in StateGraph with a "messages" key or in MessageGraph. If multiple
tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
*/

-  tools: Tool[];
+  tools: StructuredTool[];

-  constructor(tools: Tool[], name: string = "tools", tags: string[] = []) {
+  constructor(
+    tools: StructuredTool[],
+    name: string = "tools",
+    tags: string[] = []
+  ) {
super({ name, tags, func: (input, config) => this.run(input, config) });
this.tools = tools;
}
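Because `ToolNode` is now generic over its input/output shape, one class type-checks against both a bare message list (MessageGraph) and a `{ messages }` state object (StateGraph). A small sketch under the same assumptions as above; the `echo` tool and the import path are illustrative:

```ts
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { DynamicTool } from "@langchain/core/tools";
import { ToolNode } from "@langchain/langgraph/prebuilt";

// Hypothetical tool, for illustration only.
const echo = new DynamicTool({
  name: "echo",
  description: "Echo the input back to the caller.",
  func: async (input: string) => `echo: ${input}`,
});

// Typed for a MessageGraph-style state: a plain list of messages.
const toolNode = new ToolNode<BaseMessage[]>([echo]);

// The node runs the tool calls found on the last AIMessage and
// returns one ToolMessage per call.
const result = await toolNode.invoke([
  new AIMessage({
    content: "",
    tool_calls: [{ name: "echo", args: { input: "hi" }, id: "call_1" }],
  }),
]);
```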
108 changes: 93 additions & 15 deletions langgraph/src/tests/prebuilt.int.test.ts
@@ -4,8 +4,10 @@ import { it, beforeAll, describe, expect } from "@jest/globals";
import { Tool } from "@langchain/core/tools";
import { ChatOpenAI } from "@langchain/openai";
import { BaseMessage, HumanMessage } from "@langchain/core/messages";
-import { END } from "../index.js";
-import { createFunctionCallingExecutor } from "../prebuilt/index.js";
+import {
+  createReactAgent,
+  createFunctionCallingExecutor,
+} from "../prebuilt/index.js";

// Tracing slows down the tests
beforeAll(() => {
@@ -43,7 +45,6 @@ describe("createFunctionCallingExecutor", () => {
messages: [new HumanMessage("What's the weather like in SF?")],
});

-    console.log(response);
// It needs at least one human message, one AI and one function message.
expect(response.messages.length > 3).toBe(true);
const firstFunctionMessage = (response.messages as Array<BaseMessage>).find(
@@ -78,26 +79,103 @@
tools,
});

-    const stream = await functionsAgentExecutor.stream({
-      messages: [new HumanMessage("What's the weather like in SF?")],
-    });
+    const stream = await functionsAgentExecutor.stream(
+      {
+        messages: [new HumanMessage("What's the weather like in SF?")],
+      },
+      { streamMode: "values" }
+    );
const fullResponse = [];
for await (const item of stream) {
-      console.log(item);
-      console.log("-----\n");
fullResponse.push(item);
}

-    // Needs at least 3 llm calls, plus one `__end__` call.
-    expect(fullResponse.length >= 4).toBe(true);
-
-    const endMessage = fullResponse[fullResponse.length - 1];
-    expect(END in endMessage).toBe(true);
-    expect(endMessage[END].messages.length > 0).toBe(true);
-
-    const functionCall = endMessage[END].messages.find(
+    // human -> agent -> action -> agent
+    expect(fullResponse.length).toEqual(4);
+
+    const endState = fullResponse[fullResponse.length - 1];
+    // 1 human, 2 llm calls, 1 function call.
+    expect(endState.messages.length).toEqual(4);
+    const functionCall = endState.messages.find(
(message: BaseMessage) => message._getType() === "function"
);
expect(functionCall.content).toBe(weatherResponse);
});
});

describe("createReactAgent", () => {
it("can call a tool", async () => {
const weatherResponse = `Not too cold, not too hot 😎`;
const model = new ChatOpenAI();
class SanFranciscoWeatherTool extends Tool {
name = "current_weather";

description = "Get the current weather report for San Francisco, CA";

constructor() {
super();
}

async _call(_: string): Promise<string> {
return weatherResponse;
}
}
const tools = [new SanFranciscoWeatherTool()];

const reactAgent = createReactAgent(model, tools);

const response = await reactAgent.invoke({
messages: [new HumanMessage("What's the weather like in SF?")],
});

// It needs at least one human message and one AI message.
expect(response.messages.length > 1).toBe(true);
const lastMessage = response.messages[response.messages.length - 1];
expect(lastMessage._getType()).toBe("ai");
expect(lastMessage.content.toLowerCase()).toContain("not too cold");
});

it("can stream a tool call", async () => {
const weatherResponse = `Not too cold, not too hot 😎`;
const model = new ChatOpenAI({
streaming: true,
});
class SanFranciscoWeatherTool extends Tool {
name = "current_weather";

description = "Get the current weather report for San Francisco, CA";

constructor() {
super();
}

async _call(_: string): Promise<string> {
return weatherResponse;
}
}
const tools = [new SanFranciscoWeatherTool()];

const reactAgent = createReactAgent(model, tools);

const stream = await reactAgent.stream(
{
messages: [new HumanMessage("What's the weather like in SF?")],
},
{ streamMode: "values" }
);
const fullResponse = [];
for await (const item of stream) {
fullResponse.push(item);
}

// human -> agent -> action -> agent
expect(fullResponse.length).toEqual(4);
const endState = fullResponse[fullResponse.length - 1];
// 1 human, 2 ai, 1 tool.
expect(endState.messages.length).toEqual(4);

const lastMessage = endState.messages[endState.messages.length - 1];
expect(lastMessage._getType()).toBe("ai");
expect(lastMessage.content.toLowerCase()).toContain("not too cold");
});
});