diff --git a/.changeset/yellow-grapes-explain.md b/.changeset/yellow-grapes-explain.md
new file mode 100644
index 000000000..b12fdb9e4
--- /dev/null
+++ b/.changeset/yellow-grapes-explain.md
@@ -0,0 +1,6 @@
+---
+"llamaindex": patch
+"llamaindex-loader-example": patch
+---
+
+feat: support `OpenaiAssistantAgent`
diff --git a/.npmrc b/.npmrc
index 5da1bc1c6..7f5ffd319 100644
--- a/.npmrc
+++ b/.npmrc
@@ -2,3 +2,4 @@ auto-install-peers = true
 enable-pre-post-scripts = true
 prefer-workspace-packages = true
 save-workspace-protocol = true
+link-workspace-packages = true
diff --git a/examples/agent/openai-assistant-agent.ts b/examples/agent/openai-assistant-agent.ts
new file mode 100644
index 000000000..004ae8a13
--- /dev/null
+++ b/examples/agent/openai-assistant-agent.ts
@@ -0,0 +1,13 @@
+import { OpenaiAssistantAgent } from 'llamaindex/agent/openai'
+
+const assistantId = process.env.ASSISTANT_ID
+
+if (!assistantId) {
+  throw new Error('ASSISTANT_ID is required for openai')
+}
+
+const agent = new OpenaiAssistantAgent({
+  assistantId
+})
+
+await agent.run('What\'s the weather today?')
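The example above always starts a fresh thread. `OpenaiAssistantAgentParams` (introduced later in this diff) also accepts an optional `threadId`, so a conversation can be resumed. A minimal sketch under that assumption; `THREAD_ID` is a hypothetical environment variable, not part of this change:

import { OpenaiAssistantAgent } from 'llamaindex/agent/openai'

// Hypothetical: resume an existing assistant thread instead of creating one.
// When threadId is undefined, the agent falls back to threads.create().
const agent = new OpenaiAssistantAgent({
  assistantId: process.env.ASSISTANT_ID!,
  threadId: process.env.THREAD_ID
})

await agent.run('Continue where we left off.')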
diff --git a/examples/package.json b/examples/package.json
index 3403f7b13..b06aa3e34 100644
--- a/examples/package.json
+++ b/examples/package.json
@@ -12,7 +12,7 @@
     "commander": "^11.1.0",
     "dotenv": "^16.4.5",
     "js-tiktoken": "^1.0.11",
-    "llamaindex": "*",
+    "llamaindex": "latest",
     "mongodb": "^6.5.0",
     "pathe": "^1.1.2"
   },
diff --git a/examples/readers/package.json b/examples/readers/package.json
index 8202eb59c..fd5e7aee7 100644
--- a/examples/readers/package.json
+++ b/examples/readers/package.json
@@ -12,7 +12,7 @@
     "start:llamaparse": "node --loader ts-node/esm ./src/llamaparse.ts"
   },
   "dependencies": {
-    "llamaindex": "*"
+    "llamaindex": "latest"
   },
   "devDependencies": {
     "@types/node": "^20.12.7",
diff --git a/packages/core/src/agent/index.ts b/packages/core/src/agent/index.ts
index 2fc557f4b..b09f6db6a 100644
--- a/packages/core/src/agent/index.ts
+++ b/packages/core/src/agent/index.ts
@@ -7,6 +7,8 @@ export {
   OpenAIAgent,
   OpenAIAgentWorker,
   type OpenAIAgentParams,
+  OpenaiAssistantAgent,
+  type OpenaiAssistantAgentParams
 } from "./openai.js";
 export {
   ReACTAgent,
diff --git a/packages/core/src/agent/openai.ts b/packages/core/src/agent/openai.ts
index c0d59b076..6c787ecf0 100644
--- a/packages/core/src/agent/openai.ts
+++ b/packages/core/src/agent/openai.ts
@@ -1,194 +1,2 @@
-import { pipeline } from "@llamaindex/env";
-import { Settings } from "../Settings.js";
-import type {
-  ChatResponseChunk,
-  ToolCall,
-  ToolCallLLMMessageOptions,
-} from "../llm/index.js";
-import { OpenAI } from "../llm/open_ai.js";
-import { ObjectRetriever } from "../objects/index.js";
-import type { BaseToolWithCall } from "../types.js";
-import {
-  AgentRunner,
-  AgentWorker,
-  type AgentParamsBase,
-  type TaskHandler,
-} from "./base.js";
-import { callTool } from "./utils.js";
-
-type OpenAIParamsBase = AgentParamsBase<OpenAI>;
-
-type OpenAIParamsWithTools = OpenAIParamsBase & {
-  tools: BaseToolWithCall[];
-};
-
-type OpenAIParamsWithToolRetriever = OpenAIParamsBase & {
-  toolRetriever: ObjectRetriever;
-};
-
-export type OpenAIAgentParams =
-  | OpenAIParamsWithTools
-  | OpenAIParamsWithToolRetriever;
-
-export class OpenAIAgentWorker extends AgentWorker<OpenAI> {
-  taskHandler = OpenAIAgent.taskHandler;
-}
-
-export class OpenAIAgent extends AgentRunner<OpenAI> {
-  constructor(params: OpenAIAgentParams) {
-    super({
-      llm:
-        params.llm ?? Settings.llm instanceof OpenAI
-          ? (Settings.llm as OpenAI)
-          : new OpenAI(),
-      chatHistory: params.chatHistory ?? [],
-      runner: new OpenAIAgentWorker(),
-      systemPrompt: params.systemPrompt ?? null,
-      tools:
-        "tools" in params
-          ? params.tools
-          : params.toolRetriever.retrieve.bind(params.toolRetriever),
-    });
-  }
-
-  createStore = AgentRunner.defaultCreateStore;
-
-  static taskHandler: TaskHandler<OpenAI> = async (step) => {
-    const { input } = step;
-    const { llm, stream, getTools } = step.context;
-    if (input) {
-      step.context.store.messages = [...step.context.store.messages, input];
-    }
-    const lastMessage = step.context.store.messages.at(-1)!.content;
-    const tools = await getTools(lastMessage);
-    const response = await llm.chat({
-      // @ts-expect-error
-      stream,
-      tools,
-      messages: [...step.context.store.messages],
-    });
-    if (!stream) {
-      step.context.store.messages = [
-        ...step.context.store.messages,
-        response.message,
-      ];
-      const options = response.message.options ?? {};
-      if ("toolCall" in options) {
-        const { toolCall } = options;
-        const targetTool = tools.find(
-          (tool) => tool.metadata.name === toolCall.name,
-        );
-        const toolOutput = await callTool(targetTool, toolCall);
-        step.context.store.toolOutputs.push(toolOutput);
-        return {
-          taskStep: step,
-          output: {
-            raw: response.raw,
-            message: {
-              content: toolOutput.output,
-              role: "user",
-              options: {
-                toolResult: {
-                  result: toolOutput.output,
-                  isError: toolOutput.isError,
-                  id: toolCall.id,
-                },
-              },
-            },
-          },
-          isLast: false,
-        };
-      } else {
-        return {
-          taskStep: step,
-          output: response,
-          isLast: true,
-        };
-      }
-    } else {
-      const responseChunkStream = new ReadableStream<
-        ChatResponseChunk<ToolCallLLMMessageOptions>
-      >({
-        async start(controller) {
-          for await (const chunk of response) {
-            controller.enqueue(chunk);
-          }
-          controller.close();
-        },
-      });
-      const [pipStream, finalStream] = responseChunkStream.tee();
-      const reader = pipStream.getReader();
-      const { value } = await reader.read();
-      reader.releaseLock();
-      if (value === undefined) {
-        throw new Error(
-          "first chunk value is undefined, this should not happen",
-        );
-      }
-      // check if first chunk has tool calls, if so, this is a function call
-      // otherwise, it's a regular message
-      const hasToolCall = !!(value.options && "toolCall" in value.options);
-
-      if (hasToolCall) {
-        // you need to consume the response to get the full toolCalls
-        const toolCalls = await pipeline(
-          pipStream,
-          async (
-            iter: AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>,
-          ) => {
-            const toolCalls = new Map<string, ToolCall>();
-            for await (const chunk of iter) {
-              if (chunk.options && "toolCall" in chunk.options) {
-                const toolCall = chunk.options.toolCall;
-                toolCalls.set(toolCall.id, toolCall);
-              }
-            }
-            return [...toolCalls.values()];
-          },
-        );
-        for (const toolCall of toolCalls) {
-          const targetTool = tools.find(
-            (tool) => tool.metadata.name === toolCall.name,
-          );
-          step.context.store.messages = [
-            ...step.context.store.messages,
-            {
-              role: "assistant" as const,
-              content: "",
-              options: {
-                toolCall,
-              },
-            },
-          ];
-          const toolOutput = await callTool(targetTool, toolCall);
-          step.context.store.messages = [
-            ...step.context.store.messages,
-            {
-              role: "user" as const,
-              content: toolOutput.output,
-              options: {
-                toolResult: {
-                  result: toolOutput.output,
-                  isError: toolOutput.isError,
-                  id: toolCall.id,
-                },
-              },
-            },
-          ];
-          step.context.store.toolOutputs.push(toolOutput);
-        }
-        return {
-          taskStep: step,
-          output: null,
-          isLast: false,
-        };
-      } else {
-        return {
-          taskStep: step,
-          output: finalStream,
-          isLast: true,
-        };
-      }
-    }
-  };
-}
+export { OpenAIAgent, OpenAIAgentWorker, type OpenAIAgentParams } from './openai/openai-agent.js'
+export { OpenaiAssistantAgent, type OpenaiAssistantAgentParams } from './openai/openai-assistant-agent.js'
\ No newline at end of file
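Because `agent/openai.ts` is now a pure re-export barrel, existing deep imports are unaffected; both agents resolve from the old module path:

// Old and new agents are importable from the unchanged subpath.
import {
  OpenAIAgent,
  OpenaiAssistantAgent,
  type OpenAIAgentParams,
  type OpenaiAssistantAgentParams
} from 'llamaindex/agent/openai'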
diff --git a/packages/core/src/agent/openai/openai-agent.ts b/packages/core/src/agent/openai/openai-agent.ts
new file mode 100644
index 000000000..e12719e7e
--- /dev/null
+++ b/packages/core/src/agent/openai/openai-agent.ts
@@ -0,0 +1,194 @@
+import { pipeline } from "@llamaindex/env";
+import { Settings } from "../../Settings.js";
+import type {
+  ChatResponseChunk,
+  ToolCall,
+  ToolCallLLMMessageOptions,
+} from "../../llm/index.js";
+import { OpenAI } from "../../llm/open_ai.js";
+import { ObjectRetriever } from "../../objects/index.js";
+import type { BaseToolWithCall } from "../../types.js";
+import {
+  AgentRunner,
+  AgentWorker,
+  type AgentParamsBase,
+  type TaskHandler,
+} from "../base.js";
+import { callTool } from "../utils.js";
+
+type OpenAIParamsBase = AgentParamsBase<OpenAI>;
+
+type OpenAIParamsWithTools = OpenAIParamsBase & {
+  tools: BaseToolWithCall[];
+};
+
+type OpenAIParamsWithToolRetriever = OpenAIParamsBase & {
+  toolRetriever: ObjectRetriever;
+};
+
+export type OpenAIAgentParams =
+  | OpenAIParamsWithTools
+  | OpenAIParamsWithToolRetriever;
+
+export class OpenAIAgentWorker extends AgentWorker<OpenAI> {
+  taskHandler = OpenAIAgent.taskHandler;
+}
+
+export class OpenAIAgent extends AgentRunner<OpenAI> {
+  constructor(params: OpenAIAgentParams) {
+    super({
+      llm:
+        // parens: `??` binds before `?:`; without them params.llm is ignored
+        params.llm ??
+        (Settings.llm instanceof OpenAI ? (Settings.llm as OpenAI) : new OpenAI()),
+      chatHistory: params.chatHistory ?? [],
+      runner: new OpenAIAgentWorker(),
+      systemPrompt: params.systemPrompt ?? null,
+      tools:
+        "tools" in params
+          ? params.tools
+          : params.toolRetriever.retrieve.bind(params.toolRetriever),
+    });
+  }
+
+  createStore = AgentRunner.defaultCreateStore;
+
+  static taskHandler: TaskHandler<OpenAI> = async (step) => {
+    const { input } = step;
+    const { llm, stream, getTools } = step.context;
+    if (input) {
+      step.context.store.messages = [...step.context.store.messages, input];
+    }
+    const lastMessage = step.context.store.messages.at(-1)!.content;
+    const tools = await getTools(lastMessage);
+    const response = await llm.chat({
+      // @ts-expect-error
+      stream,
+      tools,
+      messages: [...step.context.store.messages],
+    });
+    if (!stream) {
+      step.context.store.messages = [
+        ...step.context.store.messages,
+        response.message,
+      ];
+      const options = response.message.options ?? {};
+      if ("toolCall" in options) {
+        const { toolCall } = options;
+        const targetTool = tools.find(
+          (tool) => tool.metadata.name === toolCall.name,
+        );
+        const toolOutput = await callTool(targetTool, toolCall);
+        step.context.store.toolOutputs.push(toolOutput);
+        return {
+          taskStep: step,
+          output: {
+            raw: response.raw,
+            message: {
+              content: toolOutput.output,
+              role: "user",
+              options: {
+                toolResult: {
+                  result: toolOutput.output,
+                  isError: toolOutput.isError,
+                  id: toolCall.id,
+                },
+              },
+            },
+          },
+          isLast: false,
+        };
+      } else {
+        return {
+          taskStep: step,
+          output: response,
+          isLast: true,
+        };
+      }
+    } else {
+      const responseChunkStream = new ReadableStream<
+        ChatResponseChunk<ToolCallLLMMessageOptions>
+      >({
+        async start(controller) {
+          for await (const chunk of response) {
+            controller.enqueue(chunk);
+          }
+          controller.close();
+        },
+      });
+      const [pipStream, finalStream] = responseChunkStream.tee();
+      const reader = pipStream.getReader();
+      const { value } = await reader.read();
+      reader.releaseLock();
+      if (value === undefined) {
+        throw new Error(
+          "first chunk value is undefined, this should not happen",
+        );
+      }
+      // check if first chunk has tool calls, if so, this is a function call
+      // otherwise, it's a regular message
+      const hasToolCall = !!(value.options && "toolCall" in value.options);
+
+      if (hasToolCall) {
+        // you need to consume the response to get the full toolCalls
+        const toolCalls = await pipeline(
+          pipStream,
+          async (
+            iter: AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>,
+          ) => {
+            const toolCalls = new Map<string, ToolCall>();
+            for await (const chunk of iter) {
+              if (chunk.options && "toolCall" in chunk.options) {
+                const toolCall = chunk.options.toolCall;
+                toolCalls.set(toolCall.id, toolCall);
+              }
+            }
+            return [...toolCalls.values()];
+          },
+        );
+        for (const toolCall of toolCalls) {
+          const targetTool = tools.find(
+            (tool) => tool.metadata.name === toolCall.name,
+          );
+          step.context.store.messages = [
+            ...step.context.store.messages,
+            {
+              role: "assistant" as const,
+              content: "",
+              options: {
+                toolCall,
+              },
+            },
+          ];
+          const toolOutput = await callTool(targetTool, toolCall);
+          step.context.store.messages = [
+            ...step.context.store.messages,
+            {
+              role: "user" as const,
+              content: toolOutput.output,
+              options: {
+                toolResult: {
+                  result: toolOutput.output,
+                  isError: toolOutput.isError,
+                  id: toolCall.id,
+                },
+              },
+            },
+          ];
+          step.context.store.toolOutputs.push(toolOutput);
+        }
+        return {
+          taskStep: step,
+          output: null,
+          isLast: false,
+        };
+      } else {
+        return {
+          taskStep: step,
+          output: finalStream,
+          isLast: true,
+        };
+      }
+    }
+  };
+}
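In the streaming branch above, tool calls arrive as partial chunks, so the handler tees the stream and folds the chunks into a `Map` keyed by tool-call id before executing anything. A standalone sketch of that fold, using an illustrative chunk shape (the real `ChatResponseChunk` carries more fields):

// Later chunks for the same id overwrite earlier partials, so the Map ends
// up holding one complete tool call per id - the same dedup taskHandler uses.
type ToolCallLike = { id: string; name: string; input: string };
type ChunkLike = { options?: { toolCall?: ToolCallLike } };

async function collectToolCalls(chunks: AsyncIterable<ChunkLike>) {
  const toolCalls = new Map<string, ToolCallLike>();
  for await (const chunk of chunks) {
    const toolCall = chunk.options?.toolCall;
    if (toolCall) toolCalls.set(toolCall.id, toolCall);
  }
  return [...toolCalls.values()];
}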
diff --git a/packages/core/src/agent/openai/openai-assistant-agent.ts b/packages/core/src/agent/openai/openai-assistant-agent.ts
new file mode 100644
index 000000000..8d7b8ed5c
--- /dev/null
+++ b/packages/core/src/agent/openai/openai-assistant-agent.ts
@@ -0,0 +1,65 @@
+import { OpenAI } from 'openai'
+import { OpenAI as OpenAIClass } from '../../llm/open_ai.js'
+import { getEnv } from '@llamaindex/env'
+import type { Thread } from 'openai/resources/beta/threads/threads'
+import type { OpenAIAgentParams } from '../openai.js'
+import { Settings } from '../../Settings.js'
+
+export type OpenaiAssistantAgentParams = OpenAIAgentParams & {
+  assistantId: string
+  threadId?: string
+}
+
+export class OpenaiAssistantAgent {
+  readonly #assistantId: string
+  readonly #threadPromise: Promise<Thread>
+  readonly defaultModel: string;
+
+  #client: OpenAI
+
+  constructor (
+    params: OpenaiAssistantAgentParams
+  ) {
+    const { assistantId, llm } = params
+    this.#assistantId = assistantId
+    if (llm) {
+      this.#client = llm.session.openai
+      this.defaultModel = llm.model;
+    } else {
+      // retrieve the client from defaults
+      if (Settings.llm instanceof OpenAIClass) {
+        this.#client = Settings.llm.session.openai
+        this.defaultModel = Settings.llm.model
+      } else {
+        this.#client = new OpenAI({
+          apiKey: getEnv('OPENAI_API_KEY')
+        })
+        this.defaultModel = 'gpt-4-turbo'
+      }
+    }
+    if (!params.threadId) {
+      this.#threadPromise = this.#client.beta.threads.create()
+    } else {
+      this.#threadPromise = this.#client.beta.threads.retrieve(params.threadId)
+    }
+  }
+
+  public async run (
+    instructions: string | null = null
+  ) {
+    const thread = await this.#threadPromise
+    const threadId = thread.id;
+
+    const stream = this.#client.beta.threads.runs.stream(threadId, {
+      assistant_id: this.#assistantId,
+      instructions
+    })
+    const currentRun = stream.currentRun()
+    if (!currentRun) {
+      throw new TypeError('No current run')
+    }
+    for await (const streamEvent of stream) {
+      // todo: wrap into StepOutput
+    }
+  }
+}
\ No newline at end of file
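`run` currently drains the event stream without surfacing anything (the `todo` above). One way the loop could collect assistant text, assuming the `AssistantStreamEvent` shapes from the `openai` SDK's beta threads API - a sketch, not what this change implements:

// Accumulate text deltas emitted while the run executes; `stream` is the
// AssistantStream returned by client.beta.threads.runs.stream(...) above.
let answer = ''
for await (const streamEvent of stream) {
  if (streamEvent.event === 'thread.message.delta') {
    for (const part of streamEvent.data.delta.content ?? []) {
      if (part.type === 'text' && part.text?.value) {
        answer += part.text.value
      }
    }
  }
}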
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index df4ddf454..684b35b2c 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -147,8 +147,8 @@ importers:
         specifier: ^1.0.11
         version: 1.0.11
       llamaindex:
-        specifier: '*'
-        version: 0.2.10(encoding@0.1.13)(node-fetch@2.7.0(encoding@0.1.13))(readable-stream@4.5.2)(typescript@5.4.5)
+        specifier: latest
+        version: link:../packages/core
       mongodb:
         specifier: ^6.5.0
         version: 6.5.0
@@ -172,8 +172,8 @@ importers:
   examples/readers:
     dependencies:
       llamaindex:
-        specifier: '*'
-        version: 0.2.10(encoding@0.1.13)(node-fetch@2.7.0(encoding@0.1.13))(readable-stream@4.5.2)(typescript@5.4.3)
+        specifier: latest
+        version: link:../../packages/core
     devDependencies:
       '@types/node':
         specifier: ^20.12.7
@@ -1852,13 +1852,6 @@ packages:
       node-fetch:
         optional: true
 
-  '@llamaindex/env@0.0.7':
-    resolution: {integrity: sha512-6j7eGXhSDspz33FzdWJRTbGlXa3osYP/aP9dm10Z7JCxaxyQZmGIWL149HNkGgV4lxiPmGPx7YWjBBj9nRdo2w==}
-    peerDependencies:
-      '@aws-crypto/sha256-js': ^5.2.0
-      pathe: ^1.1.2
-      readable-stream: ^4.5.2
-
   '@manypkg/find-root@1.1.0':
     resolution: {integrity: sha512-mki5uBvhHzO8kYYix/WRy2WX8S3B5wdVSc9D6KcU5lQNglP2yt58/VfLuAK49glRXChosY8ap2oJ1qgma3GUVA==}
 
@@ -5581,10 +5574,6 @@ packages:
     resolution: {integrity: sha512-ovJXBXkKGfq+CwmKTjluEqFi3p4h8xvkxGQQAQan22YCgef4KZ1mKGjzfGh6PL6AW5Csw0QiQPNuQyH+6Xk3hA==}
     engines: {node: '>=18.0.0'}
 
-  llamaindex@0.2.10:
-    resolution: {integrity: sha512-GXO/H4k6iF0dQStg1kOTYYm0pnMbD1gM8LwRKEPOeC/mY+Q2pyIyDB22cPc8nOTf+ah3rbLiXOxTORTUmC1xKA==}
-    engines: {node: '>=18.0.0'}
-
   load-yaml-file@0.2.0:
     resolution: {integrity: sha512-OfCBkGEw4nN6JLtgRidPX6QxjBQGQf72q3si2uvqyFEMbycSFFHwAZeXx6cJgFM9wmLrf9zBwCP3Ivqa+LLZPw==}
     engines: {node: '>=6'}
@@ -10812,15 +10801,6 @@ snapshots:
     optionalDependencies:
       node-fetch: 2.7.0(encoding@0.1.13)
 
-  '@llamaindex/env@0.0.7(@aws-crypto/sha256-js@5.2.0)(pathe@1.1.2)(readable-stream@4.5.2)':
-    dependencies:
-      '@aws-crypto/sha256-js': 5.2.0
-      '@types/lodash': 4.17.0
-      '@types/node': 20.12.7
-      lodash: 4.17.21
-      pathe: 1.1.2
-      readable-stream: 4.5.2
-
   '@manypkg/find-root@1.1.0':
     dependencies:
       '@babel/runtime': 7.24.4
@@ -11000,13 +10980,6 @@ snapshots:
 
   '@protobufjs/utf8@1.1.0': {}
 
-  '@qdrant/js-client-rest@1.8.2(typescript@5.4.3)':
-    dependencies:
-      '@qdrant/openapi-typescript-fetch': 1.2.6
-      '@sevinf/maybe': 0.5.0
-      typescript: 5.4.3
-      undici: 5.28.4
-
   '@qdrant/js-client-rest@1.8.2(typescript@5.4.5)':
     dependencies:
       '@qdrant/openapi-typescript-fetch': 1.2.6
@@ -15097,122 +15070,6 @@ snapshots:
       rfdc: 1.3.1
       wrap-ansi: 9.0.0
 
-  llamaindex@0.2.10(encoding@0.1.13)(node-fetch@2.7.0(encoding@0.1.13))(readable-stream@4.5.2)(typescript@5.4.3):
-    dependencies:
-      '@anthropic-ai/sdk': 0.20.6(encoding@0.1.13)
-      '@aws-crypto/sha256-js': 5.2.0
-      '@datastax/astra-db-ts': 1.0.1
-      '@grpc/grpc-js': 1.10.6
-      '@llamaindex/cloud': 0.0.5(node-fetch@2.7.0(encoding@0.1.13))
-      '@llamaindex/env': 0.0.7(@aws-crypto/sha256-js@5.2.0)(pathe@1.1.2)(readable-stream@4.5.2)
-      '@mistralai/mistralai': 0.1.3(encoding@0.1.13)
-      '@notionhq/client': 2.2.15(encoding@0.1.13)
-      '@pinecone-database/pinecone': 2.2.0
-      '@qdrant/js-client-rest': 1.8.2(typescript@5.4.3)
-      '@types/lodash': 4.17.0
-      '@types/node': 20.12.7
-      '@types/papaparse': 5.3.14
-      '@types/pg': 8.11.5
-      '@xenova/transformers': 2.17.1
-      '@zilliz/milvus2-sdk-node': 2.4.1
-      ajv: 8.12.0
-      assemblyai: 4.4.1
-      chromadb: 1.7.3(cohere-ai@7.9.5(encoding@0.1.13))(encoding@0.1.13)(openai@4.38.0(encoding@0.1.13))
-      cohere-ai: 7.9.5(encoding@0.1.13)
-      js-tiktoken: 1.0.11
-      lodash: 4.17.21
-      magic-bytes.js: 1.10.0
-      mammoth: 1.7.1
-      md-utils-ts: 2.0.0
-      mongodb: 6.5.0
-      notion-md-crawler: 0.0.2(encoding@0.1.13)
-      openai: 4.38.0(encoding@0.1.13)
-      papaparse: 5.4.1
-      pathe: 1.1.2
-      pdf2json: 3.0.5
-      pg: 8.11.5
-      pgvector: 0.1.8
-      portkey-ai: 0.1.16
-      rake-modified: 1.0.8
-      string-strip-html: 13.4.8
-      wikipedia: 2.1.2
-      wink-nlp: 1.14.3
-    transitivePeerDependencies:
-      - '@aws-sdk/credential-providers'
-      - '@google/generative-ai'
-      - '@mongodb-js/zstd'
-      - bufferutil
-      - debug
-      - encoding
-      - gcp-metadata
-      - kerberos
-      - mongodb-client-encryption
-      - node-fetch
-      - pg-native
-      - readable-stream
-      - snappy
-      - socks
-      - typescript
-      - utf-8-validate
-
-  llamaindex@0.2.10(encoding@0.1.13)(node-fetch@2.7.0(encoding@0.1.13))(readable-stream@4.5.2)(typescript@5.4.5):
-    dependencies:
-      '@anthropic-ai/sdk': 0.20.6(encoding@0.1.13)
-      '@aws-crypto/sha256-js': 5.2.0
-      '@datastax/astra-db-ts': 1.0.1
-      '@grpc/grpc-js': 1.10.6
-      '@llamaindex/cloud': 0.0.5(node-fetch@2.7.0(encoding@0.1.13))
-      '@llamaindex/env': 0.0.7(@aws-crypto/sha256-js@5.2.0)(pathe@1.1.2)(readable-stream@4.5.2)
-      '@mistralai/mistralai': 0.1.3(encoding@0.1.13)
-      '@notionhq/client': 2.2.15(encoding@0.1.13)
-      '@pinecone-database/pinecone': 2.2.0
-      '@qdrant/js-client-rest': 1.8.2(typescript@5.4.5)
-      '@types/lodash': 4.17.0
-      '@types/node': 20.12.7
-      '@types/papaparse': 5.3.14
-      '@types/pg': 8.11.5
-      '@xenova/transformers': 2.17.1
-      '@zilliz/milvus2-sdk-node': 2.4.1
-      ajv: 8.12.0
-      assemblyai: 4.4.1
-      chromadb: 1.7.3(cohere-ai@7.9.5(encoding@0.1.13))(encoding@0.1.13)(openai@4.38.0(encoding@0.1.13))
-      cohere-ai: 7.9.5(encoding@0.1.13)
-      js-tiktoken: 1.0.11
-      lodash: 4.17.21
-      magic-bytes.js: 1.10.0
-      mammoth: 1.7.1
-      md-utils-ts: 2.0.0
-      mongodb: 6.5.0
-      notion-md-crawler: 0.0.2(encoding@0.1.13)
-      openai: 4.38.0(encoding@0.1.13)
-      papaparse: 5.4.1
-      pathe: 1.1.2
-      pdf2json: 3.0.5
-      pg: 8.11.5
-      pgvector: 0.1.8
-      portkey-ai: 0.1.16
-      rake-modified: 1.0.8
-      string-strip-html: 13.4.8
-      wikipedia: 2.1.2
-      wink-nlp: 1.14.3
-    transitivePeerDependencies:
-      - '@aws-sdk/credential-providers'
-      - '@google/generative-ai'
-      - '@mongodb-js/zstd'
-      - bufferutil
-      - debug
-      - encoding
-      - gcp-metadata
-      - kerberos
-      - mongodb-client-encryption
-      - node-fetch
-      - pg-native
-      - readable-stream
-      - snappy
-      - socks
-      - typescript
-      - utf-8-validate
-
   load-yaml-file@0.2.0:
     dependencies:
       graceful-fs: 4.2.11