From 5cd26b24e33a08f84c9ac79b6fa8b5d24fc50b9b Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 17 May 2024 22:52:54 -0700 Subject: [PATCH] [Docs] Dynamically return directly from a tool node (#164) --- docs/docs/how-tos/index.md | 6 +- docs/mkdocs.yml | 1 + .../dynamically-returning-directly.ipynb | 567 ++++++++++++++++++ 3 files changed, 571 insertions(+), 3 deletions(-) create mode 100644 examples/how-tos/dynamically-returning-directly.ipynb diff --git a/docs/docs/how-tos/index.md b/docs/docs/how-tos/index.md index 7550896f..ab87fe38 100644 --- a/docs/docs/how-tos/index.md +++ b/docs/docs/how-tos/index.md @@ -1,12 +1,11 @@ # How-to guides -Welcome to the LangGraphJS How-to Guides! These guides provide practical, step-by-step instructions for accomplishing key tasks in LangGraphJS. +Welcome to the LangGraphJS How-to Guides! These guides provide practical, step-by-step instructions for accomplishing key tasks in LangGraphJS. ## In progress 🚧 This section is currently in progress. More updates to come! 🚧 - ## Core The core guides show how to address common needs when building a out AI workflows, with special focus placed on [ReAct](https://arxiv.org/abs/2210.03629)-style agents with [tool calling](https://js.langchain.com/v0.2/docs/how_to/tool_calling/). @@ -26,4 +25,5 @@ How to apply common design patterns in your workflows: The following examples are useful especially if you are used to LangChain's AgentExecutor configurations. -- [Force calling a tool first](force-calling-a-tool-first.ipynb): Define a fixed workflow before ceding control to the ReAct agent \ No newline at end of file +- [Force calling a tool first](force-calling-a-tool-first.ipynb): Define a fixed workflow before ceding control to the ReAct agent +- [Dynamic direct return](dynamically-returning-directly.ipynb): Let the LLM decide whether the graph should finish after a tool is run or whether the LLM should be able to review the output and keep going diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index e92c0325..b75ca129 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -94,6 +94,7 @@ nav: - "how-tos/subgraph.ipynb" - "how-tos/human-in-the-loop.ipynb" - "how-tos/force-calling-a-tool-first.ipynb" + - "how-tos/dynamically-returning-directly.ipynb" - "Conceptual Guides": - "concepts/index.md" - "Reference": diff --git a/examples/how-tos/dynamically-returning-directly.ipynb b/examples/how-tos/dynamically-returning-directly.ipynb new file mode 100644 index 00000000..d5febaa6 --- /dev/null +++ b/examples/how-tos/dynamically-returning-directly.ipynb @@ -0,0 +1,567 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7161c5f2", + "metadata": {}, + "source": [ + "# Dynamically Returning Directly\n", + "\n", + "A typical ReAct loop follows user -> assistant -> tool -> assistant ..., ->\n", + "user. In some cases, you don't need to call the LLM after the tool completes,\n", + "the user can view the results directly themselves.\n", + "\n", + "In this example we will build a conversational ReAct agent where the LLM can\n", + "optionally decide to return the result of a tool call as the final answer. This\n", + "is useful in cases where you have tools that can sometimes generate responses\n", + "that are acceptable as final answers, and you want to use the LLM to determine\n", + "when that is the case\n", + "\n", + "## Setup\n", + "\n", + "First we need to install the required packages:\n", + "\n", + "```bash\n", + "yarn add @langchain/langgraph @langchain/openai\n", + "```\n", + "\n", + "Next, we need to set API keys for OpenAI (the LLM we will use). Optionally, we\n", + "can set API key for [LangSmith tracing](https://smith.langchain.com/), which\n", + "will give us best-in-class observability." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4eb03bd6", + "metadata": {}, + "outputs": [], + "source": [ + "// Deno.env.set(\"OPENAI_API_KEY\", \"sk_...\");\n", + "\n", + "// Optional, add tracing in LangSmith\n", + "// Deno.env.set(\"LANGCHAIN_API_KEY\", \"ls__...\");\n", + "Deno.env.set(\"LANGCHAIN_CALLBACKS_BACKGROUND\", \"true\");\n", + "Deno.env.set(\"LANGCHAIN_TRACING_V2\", \"true\");\n", + "Deno.env.set(\"LANGCHAIN_PROJECT\", \"Direct Return: LangGraphJS\");\n" + ] + }, + { + "cell_type": "markdown", + "id": "57310858", + "metadata": {}, + "source": [ + "## Set up the tools\n", + "\n", + "We will first define the tools we want to use. For this simple example, we will\n", + "use a simple placeholder \"search engine\". However, it is really easy to create\n", + "your own tools - see documentation\n", + "[here](https://js.langchain.com/docs/modules/agents/tools/dynamic) on how to do\n", + "that.\n", + "\n", + "To add a 'return_direct' option, we will create a custom zod schema to use\n", + "**instead of** the schema that would be automatically inferred by the tool." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "481c95ac", + "metadata": {}, + "outputs": [], + "source": [ + "import { DynamicStructuredTool } from \"@langchain/core/tools\";\n", + "import { z } from \"zod\";\n", + "\n", + "const SearchTool = z.object({\n", + " query: z.string().describe(\"query to look up online\"),\n", + " // **IMPORTANT** We are adding an **extra** field here\n", + " // that isn't used directly by the tool - it's used by our\n", + " // graph instead to determine whether or not to return the\n", + " // result directly to the user\n", + " return_direct: z\n", + " .boolean()\n", + " .describe(\n", + " \"Whether or the result of this should be returned directly to the user without you seeing what it is\",\n", + " )\n", + " .default(false),\n", + "});\n", + "\n", + "const searchTool = new DynamicStructuredTool({\n", + " name: \"search\",\n", + " description: \"Call to surf the web.\",\n", + " // We are overriding the default schema here to\n", + " // add an extra field\n", + " schema: SearchTool,\n", + " func: async ({ query }: { query: string }) => {\n", + " // This is a placeholder for the actual implementation\n", + " // Don't let the LLM know this though 😊\n", + " return \"It's sunny in San Francisco, but you better look out if you're a Gemini 😈.\";\n", + " },\n", + "});\n", + "\n", + "const tools = [searchTool];" + ] + }, + { + "cell_type": "markdown", + "id": "5b0d34fd", + "metadata": {}, + "source": [ + "We can now wrap these tools in a simple ToolExecutor.\\\n", + "This is a real simple class that takes in a ToolInvocation and calls that tool,\n", + "returning the output. A ToolInvocation is any type with `tool` and `toolInput`\n", + "attribute." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "250415e4", + "metadata": {}, + "outputs": [], + "source": [ + "import { ToolNode } from \"@langchain/langgraph/prebuilt\";\n", + "\n", + "const toolNode = new ToolNode(tools);\n" + ] + }, + { + "cell_type": "markdown", + "id": "abf3e729", + "metadata": {}, + "source": [ + "## Set up the model\n", + "\n", + "Now we need to load the chat model we want to use.\\\n", + "Importantly, this should satisfy two criteria:\n", + "\n", + "1. It should work with messages. We will represent all agent state in the form\n", + " of messages, so it needs to be able to work well with them.\n", + "2. It should support\n", + " [tool calling](https://js.langchain.com/v0.2/docs/concepts/#functiontool-calling).\n", + "\n", + "Note: these model requirements are not requirements for using LangGraph - they\n", + "are just requirements for this one example." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "2c24d018", + "metadata": {}, + "outputs": [], + "source": [ + "import { ChatAnthropic } from \"@langchain/anthropic\";\n", + "\n", + "const model = new ChatAnthropic({\n", + " temperature: 0,\n", + " model: \"claude-3-haiku-20240307\",\n", + "});\n", + "// This formats the tools as json schema for the model API.\n", + "// The model then uses this like a system prompt.\n", + "const boundModel = model.bindTools(tools);\n" + ] + }, + { + "cell_type": "markdown", + "id": "644169c4", + "metadata": {}, + "source": [ + "## Define the agent state\n", + "\n", + "The main type of graph in `langgraph` is the\n", + "[StateGraph](https://langchain-ai.github.io/langgraphjs/reference/classes/index.StateGraph.html).\n", + "\n", + "This graph is parameterized by a state object that it passes around to each\n", + "node. Each node then returns operations to update that state. These operations\n", + "can either SET specific attributes on the state (e.g. overwrite the existing\n", + "values) or ADD to the existing attribute. Whether to set or add is denoted in\n", + "the state object you construct the graph with.\n", + "\n", + "For this example, the state we will track will just be a list of messages. We\n", + "want each node to just add messages to that list. Therefore, we will define the\n", + "state as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "24454123", + "metadata": {}, + "outputs": [], + "source": [ + "interface AgentStateBase {\n", + " messages: Array;\n", + "}\n", + "\n", + "interface AgentState extends AgentStateBase {}\n", + "\n", + "const agentState = {\n", + " messages: {\n", + " value: (x: BaseMessage[], y: BaseMessage[]) => x.concat(y),\n", + " default: () => [],\n", + " },\n", + "};\n" + ] + }, + { + "cell_type": "markdown", + "id": "abae32d9", + "metadata": {}, + "source": [ + "## Define the nodes\n", + "\n", + "We now need to define a few different nodes in our graph. In `langgraph`, a node\n", + "can be either a function or a\n", + "[runnable](https://js.langchain.com/docs/expression_language/). There are two\n", + "main nodes we need for this:\n", + "\n", + "1. The agent: responsible for deciding what (if any) actions to take.\n", + "2. A function to invoke tools: if the agent decides to take an action, this node\n", + " will then execute that action.\n", + "\n", + "We will also need to define some edges. Some of these edges may be conditional.\n", + "The reason they are conditional is that based on the output of a node, one of\n", + "several paths may be taken. The path that is taken is not known until that node\n", + "is run (the LLM decides).\n", + "\n", + "1. Conditional Edge: after the agent is called, we should either: a. If the\n", + " agent said to take an action, then the function to invoke tools should be\n", + " called b. If the agent said that it was finished, then it should finish\n", + "2. Normal Edge: after the tools are invoked, it should always go back to the\n", + " agent to decide what to do next\n", + "\n", + "Let's define the nodes, as well as a function to decide how what conditional\n", + "edge to take." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "23a8b9c6", + "metadata": {}, + "outputs": [], + "source": [ + "import type { RunnableConfig } from \"@langchain/core/runnables\";\n", + "import { END } from \"@langchain/langgraph\";\n", + "\n", + "// Define the function that determines whether to continue or not\n", + "const shouldContinue = (state: AgentState) => {\n", + " const { messages } = state;\n", + " const lastMessage = messages[messages.length - 1];\n", + " // If there is no function call, then we finish\n", + " if (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) {\n", + " return END;\n", + " } // Otherwise if there is, we check if it's suppose to return direct\n", + " else {\n", + " const args = lastMessage.tool_calls[0].args;\n", + " if (args?.return_direct) {\n", + " return \"final\";\n", + " } else {\n", + " return \"tools\";\n", + " }\n", + " }\n", + "};\n", + "\n", + "// Define the function that calls the model\n", + "const callModel = async (state: AgentState, config: RunnableConfig) => {\n", + " const messages = state.messages;\n", + " const response = await boundModel.invoke(messages, config);\n", + " // We return an object, because this will get added to the existing list\n", + " return { messages: [response] };\n", + "};" + ] + }, + { + "cell_type": "markdown", + "id": "e84040dd", + "metadata": {}, + "source": [ + "## Define the graph\n", + "\n", + "We can now put it all together and define the graph!\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "05203811", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "StateGraph {\n", + " nodes: {\n", + " agent: RunnableLambda {\n", + " lc_serializable: \u001b[33mfalse\u001b[39m,\n", + " lc_kwargs: { func: \u001b[36m[AsyncFunction: callModel]\u001b[39m },\n", + " lc_runnable: \u001b[33mtrue\u001b[39m,\n", + " name: \u001b[90mundefined\u001b[39m,\n", + " lc_namespace: [ \u001b[32m\"langchain_core\"\u001b[39m, \u001b[32m\"runnables\"\u001b[39m ],\n", + " func: \u001b[36m[AsyncFunction: callModel]\u001b[39m\n", + " },\n", + " tools: ToolNode {\n", + " lc_serializable: \u001b[33mfalse\u001b[39m,\n", + " lc_kwargs: {},\n", + " lc_runnable: \u001b[33mtrue\u001b[39m,\n", + " name: \u001b[32m\"tools\"\u001b[39m,\n", + " lc_namespace: [ \u001b[32m\"langgraph\"\u001b[39m ],\n", + " func: \u001b[36m[Function: func]\u001b[39m,\n", + " tags: \u001b[90mundefined\u001b[39m,\n", + " config: { tags: [] },\n", + " trace: \u001b[33mtrue\u001b[39m,\n", + " recurse: \u001b[33mtrue\u001b[39m,\n", + " tools: [\n", + " DynamicStructuredTool {\n", + " lc_serializable: \u001b[33mfalse\u001b[39m,\n", + " lc_kwargs: \u001b[36m[Object]\u001b[39m,\n", + " lc_runnable: \u001b[33mtrue\u001b[39m,\n", + " name: \u001b[32m\"search\"\u001b[39m,\n", + " verbose: \u001b[33mfalse\u001b[39m,\n", + " callbacks: \u001b[90mundefined\u001b[39m,\n", + " tags: [],\n", + " metadata: {},\n", + " returnDirect: \u001b[33mfalse\u001b[39m,\n", + " description: \u001b[32m\"Call to surf the web.\"\u001b[39m,\n", + " func: \u001b[36m[AsyncFunction: func]\u001b[39m,\n", + " schema: \u001b[36m[ZodObject]\u001b[39m\n", + " }\n", + " ]\n", + " },\n", + " final: ToolNode {\n", + " lc_serializable: \u001b[33mfalse\u001b[39m,\n", + " lc_kwargs: {},\n", + " lc_runnable: \u001b[33mtrue\u001b[39m,\n", + " name: \u001b[32m\"tools\"\u001b[39m,\n", + " lc_namespace: [ \u001b[32m\"langgraph\"\u001b[39m ],\n", + " func: \u001b[36m[Function: func]\u001b[39m,\n", + " tags: \u001b[90mundefined\u001b[39m,\n", + " config: { tags: [] },\n", + " trace: \u001b[33mtrue\u001b[39m,\n", + " recurse: \u001b[33mtrue\u001b[39m,\n", + " tools: [\n", + " DynamicStructuredTool {\n", + " lc_serializable: \u001b[33mfalse\u001b[39m,\n", + " lc_kwargs: \u001b[36m[Object]\u001b[39m,\n", + " lc_runnable: \u001b[33mtrue\u001b[39m,\n", + " name: \u001b[32m\"search\"\u001b[39m,\n", + " verbose: \u001b[33mfalse\u001b[39m,\n", + " callbacks: \u001b[90mundefined\u001b[39m,\n", + " tags: [],\n", + " metadata: {},\n", + " returnDirect: \u001b[33mfalse\u001b[39m,\n", + " description: \u001b[32m\"Call to surf the web.\"\u001b[39m,\n", + " func: \u001b[36m[AsyncFunction: func]\u001b[39m,\n", + " schema: \u001b[36m[ZodObject]\u001b[39m\n", + " }\n", + " ]\n", + " }\n", + " },\n", + " edges: Set(3) {\n", + " [ \u001b[32m\"__start__\"\u001b[39m, \u001b[32m\"agent\"\u001b[39m ],\n", + " [ \u001b[32m\"tools\"\u001b[39m, \u001b[32m\"agent\"\u001b[39m ],\n", + " [ \u001b[32m\"final\"\u001b[39m, \u001b[32m\"__end__\"\u001b[39m ]\n", + " },\n", + " branches: {\n", + " agent: {\n", + " shouldContinue: Branch {\n", + " condition: \u001b[36m[Function: shouldContinue]\u001b[39m,\n", + " ends: \u001b[90mundefined\u001b[39m,\n", + " then: \u001b[90mundefined\u001b[39m\n", + " }\n", + " }\n", + " },\n", + " entryPoint: \u001b[90mundefined\u001b[39m,\n", + " compiled: \u001b[33mtrue\u001b[39m,\n", + " supportMultipleEdges: \u001b[33mtrue\u001b[39m,\n", + " channels: {\n", + " messages: BinaryOperatorAggregate {\n", + " lc_graph_name: \u001b[32m\"BinaryOperatorAggregate\"\u001b[39m,\n", + " value: [],\n", + " operator: \u001b[36m[Function: value]\u001b[39m,\n", + " initialValueFactory: \u001b[36m[Function: default]\u001b[39m\n", + " }\n", + " },\n", + " waitingEdges: Set(0) {}\n", + "}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import { END, START, StateGraph } from \"@langchain/langgraph\";\n", + "\n", + "// Define a new graph\n", + "const workflow = new StateGraph({\n", + " channels: agentState,\n", + "});\n", + "\n", + "// Define the two nodes we will cycle between\n", + "workflow.addNode(\"agent\", callModel);\n", + "\n", + "// Note the \"action\" and \"final\" nodes are identical!\n", + "workflow.addNode(\"tools\", toolNode);\n", + "workflow.addNode(\"final\", toolNode);\n", + "\n", + "// Set the entrypoint as `agent`\n", + "workflow.addEdge(START, \"agent\");\n", + "\n", + "// We now add a conditional edge\n", + "workflow.addConditionalEdges(\n", + " // First, we define the start node. We use `agent`.\n", + " \"agent\",\n", + " // Next, we pass in the function that will determine which node is called next.\n", + " shouldContinue,\n", + ");\n", + "\n", + "// We now add a normal edge from `tools` to `agent`.\n", + "workflow.addEdge(\"tools\", \"agent\");\n", + "workflow.addEdge(\"final\", END);\n", + "\n", + "// Finally, we compile it!\n", + "const app = workflow.compile();" + ] + }, + { + "cell_type": "markdown", + "id": "ef7f65bd", + "metadata": {}, + "source": [ + "## Use it!\n", + "\n", + "We can now use it! This now exposes the\n", + "[same interface](https://js.langchain.com/docs/expression_language/) as all\n", + "other LangChain runnables." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "de5f4864", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[human]: what is the weather in sf\n", + "-----\n", + "\n", + "[ai]: [object Object] \n", + "Tools: \n", + "- search({\"query\":\"weather in sf\",\"return_direct\":true})\n", + "-----\n", + "\n", + "[tool]: It's sunny in San Francisco, but you better look out if you're a Gemini 😈.\n", + "-----\n", + "\n" + ] + } + ], + "source": [ + "import { AIMessage, BaseMessage, HumanMessage } from \"@langchain/core/messages\";\n", + "\n", + "const prettyPrint = (message: BaseMessage) => {\n", + " let txt = `[${message._getType()}]: ${message.content}`;\n", + " if (\n", + " (message._getType() === \"ai\" &&\n", + " (message as AIMessage)?.tool_calls?.length) ||\n", + " 0 > 0\n", + " ) {\n", + " const tool_calls = (message as AIMessage)?.tool_calls\n", + " ?.map((tc) => `- ${tc.name}(${JSON.stringify(tc.args)})`)\n", + " .join(\"\\n\");\n", + " txt += ` \\nTools: \\n${tool_calls}`;\n", + " }\n", + " console.log(txt);\n", + "};\n", + "\n", + "const inputs = { messages: [new HumanMessage(\"what is the weather in sf\")] };\n", + "for await (const output of await app.stream(inputs, { streamMode: \"values\" })) {\n", + " const lastMessage = output.messages[output.messages.length - 1];\n", + " prettyPrint(lastMessage);\n", + " console.log(\"-----\\n\");\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "986f8cfe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[human]: what is the weather in sf? return this result directly by setting return_direct = True\n", + "-----\n", + "\n", + "[ai]: [object Object] \n", + "Tools: \n", + "- search({\"query\":\"weather in sf\",\"return_direct\":true})\n", + "-----\n", + "\n", + "[tool]: It's sunny in San Francisco, but you better look out if you're a Gemini 😈.\n", + "-----\n", + "\n" + ] + } + ], + "source": [ + "import { HumanMessage } from \"@langchain/core/messages\";\n", + "\n", + "const inputs = {\n", + " messages: [\n", + " new HumanMessage(\n", + " \"what is the weather in sf? return this result directly by setting return_direct = True\",\n", + " ),\n", + " ],\n", + "};\n", + "for await (const output of await app.stream(inputs, { streamMode: \"values\" })) {\n", + " const lastMessage = output.messages[output.messages.length - 1];\n", + " prettyPrint(lastMessage);\n", + " console.log(\"-----\\n\");\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "51fa73e6", + "metadata": {}, + "source": [ + "Done! The graph **stopped** after running the `tools` node!" + ] + } + ], + "metadata": { + "jupytext": { + "text_representation": { + "extension": ".py", + "format_name": "percent", + "format_version": "1.3", + "jupytext_version": "1.16.1" + } + }, + "kernelspec": { + "display_name": "Deno", + "language": "typescript", + "name": "deno" + }, + "language_info": { + "file_extension": ".ts", + "mimetype": "text/x.typescript", + "name": "typescript", + "nb_converter": "script", + "pygments_lexer": "typescript", + "version": "5.4.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}