Feat/openai n #5087

Closed
wants to merge 6 commits into from
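This PR appears to add two related features to ChatOpenAI: an `n` call option on `ChatOpenAICallOptions`, and a `preferSingleRequests` batch option. When `preferSingleRequests` is set and every input passed to `batch()` resolves to the same prompt, the inputs are collapsed into a single API request with `n` set to the batch size, instead of one request per input. A minimal usage sketch, modeled on the integration tests added in this diff (the model settings are illustrative):

import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage } from "@langchain/core/messages";

const model = new ChatOpenAI({
  modelName: "gpt-3.5-turbo",
  maxTokens: 10,
});

// Identical prompts + preferSingleRequests: expected to issue a single API call
// with n = 3 and return three separate generations.
const res = await model.batch(
  [
    [new HumanMessage("Hello!")],
    [new HumanMessage("Hello!")],
    [new HumanMessage("Hello!")],
  ],
  undefined,
  { preferSingleRequests: true }
);
console.log(res.length); // 3

If the prompts differ, or the option is omitted, `batch()` falls back to the default per-input requests, which is what the second integration test below checks.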
239 changes: 192 additions & 47 deletions libs/langchain-openai/src/chat_models.ts
@@ -12,6 +12,7 @@ import {
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
BaseMessageChunk,
OpenAIToolCall,
isAIMessage,
} from "@langchain/core/messages";
@@ -38,9 +39,11 @@ import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { z } from "zod";
import {
Runnable,
RunnableBatchOptions,
RunnableInterface,
RunnablePassthrough,
RunnableSequence,
getCallbackManagerForConfig,
} from "@langchain/core/runnables";
import {
JsonOutputParser,
@@ -253,6 +256,11 @@ export interface ChatOpenAICallOptions
promptIndex?: number;
response_format?: { type: "json_object" };
seed?: number;
n?: number;
}

export interface ChatOpenAIBatchOptions extends RunnableBatchOptions {
preferSingleRequests?: boolean;
}

/**
@@ -522,7 +530,7 @@ export class ChatOpenAI<
max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens,
logprobs: this.logprobs,
top_logprobs: this.topLogprobs,
n: this.n,
n: options?.n ?? this.n,
logit_bias: this.logitBias,
stop: options?.stop ?? this.stopSequences,
user: this.user,
@@ -627,6 +635,123 @@ export class ChatOpenAI<
return this._identifyingParams();
}

async batch(
inputs: BaseLanguageModelInput[],
options?: Partial<CallOptions> | Partial<CallOptions>[],
batchOptions?: ChatOpenAIBatchOptions & { returnExceptions?: false }
): Promise<BaseMessageChunk[]>;

async batch(
inputs: BaseLanguageModelInput[],
options?: Partial<CallOptions> | Partial<CallOptions>[],
batchOptions?: ChatOpenAIBatchOptions & { returnExceptions: true }
): Promise<(BaseMessageChunk | Error)[]>;

async batch(
inputs: BaseLanguageModelInput[],
options?: Partial<CallOptions> | Partial<CallOptions>[],
batchOptions?: ChatOpenAIBatchOptions
): Promise<(BaseMessageChunk | Error)[]>;

async batch(
inputs: BaseLanguageModelInput[],
options?: Partial<CallOptions> | Partial<CallOptions>[],
batchOptions?: ChatOpenAIBatchOptions
): Promise<(BaseMessageChunk | Error)[]> {
if (!batchOptions?.preferSingleRequests || Array.isArray(options)) {
return super.batch(inputs, options, batchOptions);
}
const promptValues = inputs.map((i) =>
BaseChatModel._convertInputToPromptValue(i)
);
const promptValueStrings = promptValues.map((p) => p.toString());
if (!promptValueStrings.every((p) => p === promptValueStrings[0])) {
return super.batch(inputs, options, batchOptions);
}
const maxConcurrency =
options?.maxConcurrency ?? batchOptions?.maxConcurrency ?? inputs.length;
const batchCount = Math.ceil(inputs.length / maxConcurrency);
const results: (BaseMessageChunk | Error)[] = [];
for (let i = 0; i < batchCount; i += 1) {
const concurrency = Math.min(
inputs.length - results.length,
maxConcurrency
);
const optionsList = this._getOptionsList(options ?? {}, concurrency);
const callbackManagers = await Promise.all(
optionsList.map(getCallbackManagerForConfig)
);
const runManagers = await Promise.all(
callbackManagers.map(async (callbackManager, j) => {
const handleStartRes = await callbackManager?.handleChatModelStart(
this.toJSON(),
[promptValues[0].toChatMessages()],
optionsList[j].runId,
undefined,
undefined,
undefined,
undefined,
optionsList[j].runName ?? this.getName()
);
delete optionsList[j].runId;
return handleStartRes;
})
);
try {
const { generations, llmOutput } = await this._generateNonStreaming(
promptValues[0].toChatMessages(),
{ ...options, n: concurrency } as CallOptions
);
results.push(
...generations.map(
(generation) => generation.message as BaseMessageChunk
)
);
await Promise.all(
  runManagers.map((subRunManagers, j) =>
    Promise.all(
      (subRunManagers ?? []).map((runManager) => {
        if (j === 0) {
          return runManager.handleLLMEnd({
            generations: [generations],
            llmOutput,
          });
        }
        // Only the first run reports real token usage so it is not double-counted.
        return runManager.handleLLMEnd({
          generations: [generations],
          llmOutput: {
            ...llmOutput,
            tokenUsage: {
              promptTokens: 0,
              completionTokens: 0,
              totalTokens: 0,
            },
          },
        });
      })
    )
  )
);
} catch (e) {
await Promise.all(
runManagers.map((subRunManagers) =>
Promise.all(
(subRunManagers ?? []).map((runManager) =>
runManager?.handleLLMError(e)
)
)
)
);
if (batchOptions?.returnExceptions) {
for (let j = 0; j < concurrency; j += 1) {
results.push(e as Error);
}
} else {
throw e;
}
}
}
return results;
}

/** @ignore */
async _generate(
messages: BaseMessage[],
@@ -635,8 +760,6 @@
): Promise<ChatResult> {
const tokenUsage: TokenUsage = {};
const params = this.invocationParams(options);
const messagesMapped: OpenAICompletionParam[] =
convertMessagesToOpenAIParams(messages);

if (params.stream) {
const stream = this._streamResponseChunks(messages, options, runManager);
@@ -677,56 +800,67 @@ export class ChatOpenAI<
tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage;
return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
} else {
const data = await this.completionWithRetry(
{
...params,
stream: false,
messages: messagesMapped,
},
{
signal: options?.signal,
...options?.options,
}
);
const {
completion_tokens: completionTokens,
prompt_tokens: promptTokens,
total_tokens: totalTokens,
} = data?.usage ?? {};

if (completionTokens) {
tokenUsage.completionTokens =
(tokenUsage.completionTokens ?? 0) + completionTokens;
}
return this._generateNonStreaming(messages, options);
}
}

private async _generateNonStreaming(
  messages: BaseMessage[],
  options: this["ParsedCallOptions"]
) {
  const tokenUsage: TokenUsage = {};
  const params = this.invocationParams(options);
  const messagesMapped: OpenAICompletionParam[] =
    convertMessagesToOpenAIParams(messages);
  const data = await this.completionWithRetry(
    {
      ...params,
      stream: false,
      messages: messagesMapped,
    },
    {
      signal: options?.signal,
      ...options?.options,
    }
  );
  const {
    completion_tokens: completionTokens,
    prompt_tokens: promptTokens,
    total_tokens: totalTokens,
  } = data?.usage ?? {};

  if (completionTokens) {
    tokenUsage.completionTokens =
      (tokenUsage.completionTokens ?? 0) + completionTokens;
  }

  if (promptTokens) {
    tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens;
  }

  if (totalTokens) {
    tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
  }

  const generations: ChatGeneration[] = [];
  for (const part of data?.choices ?? []) {
    const text = part.message?.content ?? "";
    const generation: ChatGeneration = {
      text,
      message: openAIResponseToChatMessage(
        part.message ?? { role: "assistant" }
      ),
    };
    generation.generationInfo = {
      ...(part.finish_reason ? { finish_reason: part.finish_reason } : {}),
      ...(part.logprobs ? { logprobs: part.logprobs } : {}),
    };
    generations.push(generation);
  }
  return {
    generations,
    llmOutput: { tokenUsage },
  };
}

/**
@@ -1151,3 +1285,14 @@ function isStructuredOutputMethodParams(
"object"
);
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function _coerceToDict(value: any, defaultKey: string) {
return value &&
!Array.isArray(value) &&
// eslint-disable-next-line no-instanceof/no-instanceof
!(value instanceof Date) &&
typeof value === "object"
? value
: { [defaultKey]: value };
}
41 changes: 41 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -99,6 +99,47 @@ test("Test ChatOpenAI tokenUsage with a batch", async () => {
expect(tokenUsage.promptTokens).toBeGreaterThan(0);
});

test("Test ChatOpenAI batch sends prompt to API with n option if prompts are the same", async () => {
const model = new ChatOpenAI({
temperature: 2,
modelName: "gpt-3.5-turbo",
maxTokens: 10,
});

const completionSpy = jest.spyOn(model, "completionWithRetry");
const res = await model.batch(
[
[new HumanMessage("Hello!")],
[new HumanMessage("Hello!")],
[new HumanMessage("Hello!")],
],
undefined,
{ preferSingleRequests: true }
);
console.log(res);
expect(res).toHaveLength(3);

expect(completionSpy).toHaveBeenCalledTimes(1);
});

test("Test ChatOpenAI batch sends prompt to API in separate requests if prompts are different", async () => {
const model = new ChatOpenAI({
temperature: 2,
modelName: "gpt-3.5-turbo",
maxTokens: 10,
});

const completionSpy = jest.spyOn(model, "completionWithRetry");
const res = await model.batch([
[new HumanMessage("Hello!")],
[new HumanMessage("Hi")],
]);
console.log(res);
expect(res).toHaveLength(2);

expect(completionSpy).toHaveBeenCalledTimes(2);
});

test("Test ChatOpenAI in streaming mode", async () => {
let nrNewTokens = 0;
let streamedCompletion = "";