Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use query bundle #702

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 7 additions & 5 deletions packages/core/src/Retriever.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import type { NodeWithScore } from "./Node.js";
import type { ServiceContext } from "./ServiceContext.js";
import type { MessageContent } from "./llm/index.js";
import type { QueryBundle } from "./types.js";

export type RetrieveParams = {
query: string;
preFilters?: unknown;
export type RetrieveParams<Filters = unknown> = {
query: QueryBundle | MessageContent;
preFilters?: Filters;
};

/**
* Retrievers retrieve the nodes that most closely match our query in similarity.
*/
export interface BaseRetriever {
retrieve(params: RetrieveParams): Promise<NodeWithScore[]>;
export interface BaseRetriever<Filters = unknown> {
retrieve(params: RetrieveParams<Filters>): Promise<NodeWithScore[]>;

// to be deprecated soon
serviceContext?: ServiceContext;
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/agent/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type {
StreamingAgentChatResponse,
} from "../engines/chat/index.js";

import { toQueryBundle } from "../internal/utils.js";
import type { BaseMemory } from "../memory/types.js";
import type { QueryEngineParamsNonStreaming } from "../types.js";

Expand Down Expand Up @@ -62,7 +63,7 @@ export abstract class BaseAgent implements BaseChatEngine, BaseQueryEngine {
): Promise<AgentChatResponse | StreamingAgentChatResponse> {
// Handle non-streaming query
const agentResponse = await this.chat({
message: params.query,
message: toQueryBundle(params.query).query,
chatHistory: [],
});

Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/callbacks/CallbackManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import type {
LLMStartEvent,
LLMStreamEvent,
LLMToolCallEvent,
MessageContent,
} from "../llm/types.js";

export class LlamaIndexCustomEvent<T = any> extends CustomEvent<T> {
Expand Down Expand Up @@ -88,7 +89,7 @@ export interface StreamCallbackResponse {
}

export interface RetrievalCallbackResponse {
query: string;
query: MessageContent;
nodes: NodeWithScore[];
}

Expand Down
6 changes: 4 additions & 2 deletions packages/core/src/cloud/LlamaCloudRetriever.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { ObjectType, jsonToNode } from "../Node.js";
import type { BaseRetriever, RetrieveParams } from "../Retriever.js";
import { Settings } from "../Settings.js";
import { wrapEventCaller } from "../internal/context/EventCaller.js";
import { toQueryBundle } from "../internal/utils.js";
import { extractText } from "../llm/utils.js";
import type { ClientParams, CloudConstructorParams } from "./types.js";
import { DEFAULT_PROJECT_NAME } from "./types.js";
import { getClient } from "./utils.js";
Expand Down Expand Up @@ -70,14 +72,14 @@ export class LlamaCloudRetriever implements BaseRetriever {
await this.getClient()
).pipeline.runSearch(pipelines[0].id, {
...this.retrieveParams,
query,
query: extractText(toQueryBundle(query).query),
searchFilters: preFilters as Record<string, unknown[]>,
});

const nodes = this.resultNodesToNodeWithScore(results.retrievalNodes);

Settings.callbackManager.dispatchEvent("retrieve", {
query,
query: extractText(toQueryBundle(query).query),
nodes,
});

Expand Down
8 changes: 6 additions & 2 deletions packages/core/src/engines/chat/ContextChatEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ import type { ContextSystemPrompt } from "../../Prompt.js";
import { Response } from "../../Response.js";
import type { BaseRetriever } from "../../Retriever.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatMessage, ChatResponseChunk, LLM } from "../../llm/index.js";
import type {
ChatMessage,
ChatResponseChunk,
LLM,
MessageContent,
} from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
import type { MessageContent } from "../../llm/types.js";
import {
extractText,
streamConverter,
Expand Down
3 changes: 1 addition & 2 deletions packages/core/src/engines/chat/types.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import type { ChatHistory } from "../../ChatHistory.js";
import type { NodeWithScore } from "../../Node.js";
import type { Response } from "../../Response.js";
import type { ChatMessage } from "../../llm/index.js";
import type { MessageContent } from "../../llm/types.js";
import type { ChatMessage, MessageContent } from "../../llm/index.js";
import type { ToolOutput } from "../../tools/types.js";

/**
Expand Down
40 changes: 17 additions & 23 deletions packages/core/src/engines/query/RetrieverQueryEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,35 @@ import type { NodeWithScore } from "../../Node.js";
import type { Response } from "../../Response.js";
import type { BaseRetriever } from "../../Retriever.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import { toQueryBundle } from "../../internal/utils.js";
import type { MessageContent } from "../../llm/index.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import { PromptMixin } from "../../prompts/Mixin.js";
import type { BaseSynthesizer } from "../../synthesizers/index.js";
import { ResponseSynthesizer } from "../../synthesizers/index.js";
import type {
BaseQueryEngine,
QueryBundle,
QueryEngineParamsNonStreaming,
QueryEngineParamsStreaming,
} from "../../types.js";

/**
* A query engine that uses a retriever to query an index and then synthesizes the response.
*/
export class RetrieverQueryEngine
export class RetrieverQueryEngine<Filters = unknown>
extends PromptMixin
implements BaseQueryEngine
{
retriever: BaseRetriever;
responseSynthesizer: BaseSynthesizer;
nodePostprocessors: BaseNodePostprocessor[];
preFilters?: unknown;

constructor(
retriever: BaseRetriever,
responseSynthesizer?: BaseSynthesizer,
preFilters?: unknown,
nodePostprocessors?: BaseNodePostprocessor[],
public retriever: BaseRetriever<Filters>,
public responseSynthesizer: BaseSynthesizer = new ResponseSynthesizer({
serviceContext: retriever.serviceContext,
}),
public preFilters?: Filters,
public nodePostprocessors: BaseNodePostprocessor[] = [],
) {
super();

this.retriever = retriever;
this.responseSynthesizer =
responseSynthesizer ||
new ResponseSynthesizer({
serviceContext: retriever.serviceContext,
});
this.preFilters = preFilters;
this.nodePostprocessors = nodePostprocessors || [];
}

_getPromptModules() {
Expand All @@ -48,7 +39,10 @@ export class RetrieverQueryEngine
};
}

private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) {
private async applyNodePostprocessors(
nodes: NodeWithScore[],
query: QueryBundle | MessageContent,
) {
let nodesWithScore = nodes;

for (const postprocessor of this.nodePostprocessors) {
Expand All @@ -61,9 +55,9 @@ export class RetrieverQueryEngine
return nodesWithScore;
}

private async retrieve(query: string) {
private async retrieve(query: QueryBundle) {
const nodes = await this.retriever.retrieve({
query,
...query,
preFilters: this.preFilters,
});

Expand All @@ -77,7 +71,7 @@ export class RetrieverQueryEngine
params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming,
): Promise<Response | AsyncIterable<Response>> {
const { query, stream } = params;
const nodesWithScore = await this.retrieve(query);
const nodesWithScore = await this.retrieve(toQueryBundle(query));
if (stream) {
return this.responseSynthesizer.synthesize({
query,
Expand Down
9 changes: 5 additions & 4 deletions packages/core/src/engines/query/RouterQueryEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { NodeWithScore } from "../../Node.js";
import { Response } from "../../Response.js";
import type { ServiceContext } from "../../ServiceContext.js";
import { llmFromSettingsOrContext } from "../../Settings.js";
import { toQueryBundle } from "../../internal/utils.js";
import { PromptMixin } from "../../prompts/index.js";
import type { BaseSelector } from "../../selectors/index.js";
import { LLMSingleSelector } from "../../selectors/index.js";
Expand Down Expand Up @@ -44,7 +45,7 @@ async function combineResponses(
}

const summary = await summarizer.getResponse({
query: queryBundle.queryStr,
query: queryBundle.query,
textChunks: responseStrs,
});

Expand Down Expand Up @@ -115,7 +116,7 @@ export class RouterQueryEngine extends PromptMixin implements BaseQueryEngine {
): Promise<Response | AsyncIterable<Response>> {
const { query, stream } = params;

const response = await this.queryRoute({ queryStr: query });
const response = await this.queryRoute(toQueryBundle(query));

if (stream) {
throw new Error("Streaming is not supported yet.");
Expand All @@ -140,7 +141,7 @@ export class RouterQueryEngine extends PromptMixin implements BaseQueryEngine {
const selectedQueryEngine = this.queryEngines[engineInd.index];
responses.push(
await selectedQueryEngine.query({
query: queryBundle.queryStr,
query: queryBundle.query,
}),
);
}
Expand Down Expand Up @@ -177,7 +178,7 @@ export class RouterQueryEngine extends PromptMixin implements BaseQueryEngine {
}

const finalResponse = await selectedQueryEngine.query({
query: queryBundle.queryStr,
query: queryBundle.query,
});

// add selected result
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/engines/query/SubQuestionQueryEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import type {
} from "../../types.js";

import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import { toQueryBundle } from "../../internal/utils.js";
import type { BaseQuestionGenerator, SubQuestion } from "./types.js";

/**
Expand Down Expand Up @@ -84,7 +85,10 @@ export class SubQuestionQueryEngine
params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming,
): Promise<Response | AsyncIterable<Response>> {
const { query, stream } = params;
const subQuestions = await this.questionGen.generate(this.metadatas, query);
const subQuestions = await this.questionGen.generate(
this.metadatas,
toQueryBundle(query).query,
);

const subQNodes = await Promise.all(
subQuestions.map((subQ) => this.querySubQ(subQ)),
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/engines/query/types.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import type { MessageContent } from "../../llm/index.js";
import type { ToolMetadata } from "../../types.js";

/**
* QuestionGenerators generate new questions for the LLM using tools and a user query.
*/
export interface BaseQuestionGenerator {
generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
generate(
tools: ToolMetadata[],
query: MessageContent,
): Promise<SubQuestion[]>;
}

export interface SubQuestion {
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/indices/keyword/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ import {
} from "./utils.js";

import { llmFromSettingsOrContext } from "../../Settings.js";
import { toQueryBundle } from "../../internal/utils.js";
import type { LLM } from "../../llm/types.js";
import { extractText } from "../../llm/utils.js";

export interface KeywordIndexOptions {
nodes?: BaseNode[];
Expand Down Expand Up @@ -85,7 +87,9 @@ abstract class BaseKeywordTableRetriever implements BaseRetriever {
abstract getKeywords(query: string): Promise<string[]>;

async retrieve({ query }: RetrieveParams): Promise<NodeWithScore[]> {
const keywords = await this.getKeywords(query);
const keywords = await this.getKeywords(
extractText(toQueryBundle(query).query),
);
const chunkIndicesCount: { [key: string]: number } = {};
const filteredKeywords = keywords.filter((keyword) =>
this.indexStruct.table.has(keyword),
Expand Down
12 changes: 8 additions & 4 deletions packages/core/src/indices/summary/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import {
} from "../../Settings.js";
import { RetrieverQueryEngine } from "../../engines/query/index.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import { toQueryBundle } from "../../internal/utils.js";
import { extractText } from "../../llm/utils.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import type { StorageContext } from "../../storage/StorageContext.js";
import { storageContextFromDefaults } from "../../storage/StorageContext.js";
Expand Down Expand Up @@ -297,7 +299,7 @@ export class SummaryIndexRetriever implements BaseRetriever {
}));

Settings.callbackManager.dispatchEvent("retrieve", {
query,
query: toQueryBundle(query).query,
nodes: result,
});

Expand Down Expand Up @@ -343,13 +345,15 @@ export class SummaryIndexLLMRetriever implements BaseRetriever {
const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch);

const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
const input = { context: fmtBatchStr, query: query };

const llm = llmFromSettingsOrContext(this.serviceContext);

const rawResponse = (
await llm.complete({
prompt: this.choiceSelectPrompt(input),
prompt: this.choiceSelectPrompt({
context: fmtBatchStr,
query: extractText(toQueryBundle(query).query),
}),
})
).text;

Expand All @@ -372,7 +376,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever {
}

Settings.callbackManager.dispatchEvent("retrieve", {
query,
query: toQueryBundle(query).query,
nodes: results,
});

Expand Down