openai[patch]: ChatOpenAI.batch function #5016

Open · wants to merge 5 commits into main
27 changes: 26 additions & 1 deletion libs/langchain-openai/src/chat_models.ts
@@ -12,6 +12,7 @@ import {
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
BaseMessageChunk,
OpenAIToolCall,
isAIMessage,
} from "@langchain/core/messages";
@@ -253,6 +254,7 @@ export interface ChatOpenAICallOptions
promptIndex?: number;
response_format?: { type: "json_object" };
seed?: number;
n?: number;
}

/**
@@ -522,7 +524,7 @@ export class ChatOpenAI<
max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens,
logprobs: this.logprobs,
top_logprobs: this.topLogprobs,
- n: this.n,
+ n: options?.n ?? this.n,
logit_bias: this.logitBias,
stop: options?.stop ?? this.stopSequences,
user: this.user,
@@ -627,6 +629,29 @@
return this._identifyingParams();
}

async batch(
inputs: BaseLanguageModelInput[],
options?: CallOptions
): Promise<BaseMessageChunk[]> {
const promptValues = inputs.map((i) =>
BaseChatModel._convertInputToPromptValue(i)
);

const promptValueStrings = promptValues.map((p) => p.toString());
if (promptValueStrings.every((p) => p === promptValueStrings[0])) {
const result = await this.generatePrompt(
[promptValues[0]],
{ ...options, n: inputs.length } as CallOptions,
Collaborator:
We should probably upper bound this - I can handle it!

Collaborator:
Does this have the same output as just sending n requests? Or will it pick the top n candidates?

Collaborator:
Hey so chatted with the Python folks - this would change the tracing behavior for folks, and they have some concerns about overall behavior changing since it's a black box on OpenAI's end.

Could we table it for now? Sorry for the thrash - you can always wrap a .generate() call in a lambda.
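
(For reference, a minimal sketch of the suggested workaround: wrapping a .generate() call in a RunnableLambda so that one request with n set returns several samples of the same prompt. The helper name and prompt are illustrative, not part of this PR.)

```typescript
import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage } from "@langchain/core/messages";
import { RunnableLambda } from "@langchain/core/runnables";

// One API call with `n` set returns several independent samples of the same
// prompt, so the prompt's input tokens are only billed once.
const model = new ChatOpenAI({ modelName: "gpt-3.5-turbo", n: 4 });

const sampleMany = RunnableLambda.from(async (prompt: string) => {
  const result = await model.generate([[new HumanMessage(prompt)]]);
  // generations[0] holds the n candidate completions for the single prompt.
  return result.generations[0].map((g) => g.text);
});

// const samples = await sampleMany.invoke("Suggest a name for a bakery.");
// samples.length === 4
```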

Contributor Author (@davidfant, Apr 12, 2024):
> Does this have the same output as just sending n requests? Or will it pick the top n candidates?

Yes, this makes OpenAI create n independent results for the same prompt. best_of would return the top candidates based on log probs (see https://platform.openai.com/docs/api-reference/chat/create).

> Hey so chatted with the Python folks - this would change the tracing behavior for folks, and they have some concerns about overall behavior changing since it's a black box on OpenAI's end.
>
> Could we table it for now? Sorry for the thrash - you can always wrap a .generate() call in a lambda.

Ok. FWIW here are my 2c:

  • I don't really get the point with "concerns about overall behavior". The samples are generated independently, with the benefit of only paying for input tokens once.
  • Pricing-wise the difference is huge, especially for use cases with lots of input and limited output. For us, we have lots of input tokens and not so many output tokens (relatively speaking), so not using n would not be great
  • IMO the tracing behavior is changed for the better, at least in terms of how this is visualized in LS
  • The goal with adding this to ChatOpenAI.batch (rather than hackily accomplishing the same thing with generate) is to avoid having lots of different logic for how to do requests depending on what model provider is used. Basically I've abstracted out model in my runnables so that they're given model: BaseChatModel, which lets me easily configure what model to use from one place.

If this still isn't a change that makes sense on your end, I'll just apply a patch locally for now.
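
(For context, a rough sketch of how the proposed override would be exercised if merged as written: identical inputs passed to .batch() collapse into a single request with n = inputs.length, while differing inputs fall back to the default batch behavior. The prompts below are illustrative.)

```typescript
import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage } from "@langchain/core/messages";

const model = new ChatOpenAI({ modelName: "gpt-3.5-turbo", temperature: 1 });

async function main() {
  // Three identical prompts collapse into one request with n = 3 (per this PR),
  // so the prompt's input tokens are billed once instead of three times.
  const results = await model.batch([
    [new HumanMessage("Name a color.")],
    [new HumanMessage("Name a color.")],
    [new HumanMessage("Name a color.")],
  ]);
  console.log(results.map((m) => m.content)); // three independent completions
}

// main();
```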

Contributor:

OpenAI supporting n completions is a very high value feature, because input tokens are priced only once. If you make n separate requests you eat the input token costs n times. This is an amazing aspect of the OpenAI pricing model, which many other providers don't support (for example Anthropic). I believe making it easy for users to benefit from this, even if they don't know about it, is a great value add LangChain can provide.

OpenAI supports the best_of option, which has interplay with n.

> Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.

Users can also do this themselves now that chat completions return logprobs. It's a common pattern in my workflows to increase temperature for higher generation variance and then utilize the logprobs, or simply do self-consistency voting (https://arxiv.org/abs/2203.11171). The OpenAI pricing model has great synergy with these techniques, since you only pay extra for your generations.

I would almost argue that this feature of the API enables quality improving techniques where they would otherwise be cost prohibitive, and think leaning in and making these as easy to use as possible is of immense value.
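
(To illustrate the technique mentioned above, a hedged sketch of self-consistency voting over n samples drawn in a single request; the function name and prompt are illustrative only.)

```typescript
import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage } from "@langchain/core/messages";

// Illustrative sketch: sample n answers at a higher temperature in one request,
// then take a majority vote over the sampled answers (self-consistency).
async function selfConsistentAnswer(question: string, n = 5): Promise<string> {
  const model = new ChatOpenAI({ modelName: "gpt-3.5-turbo", temperature: 1, n });
  const result = await model.generate([[new HumanMessage(question)]]);
  const answers = result.generations[0].map((g) => g.text.trim());

  // Count identical answers and return the most frequent one.
  const counts = new Map<string, number>();
  for (const a of answers) counts.set(a, (counts.get(a) ?? 0) + 1);
  return [...counts.entries()].sort((a, b) => b[1] - a[1])[0][0];
}
```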

Collaborator:
I'll figure it out on our end and get this merged. Thanks for weighing in!

options?.callbacks
);
// TODO: Remove cast after figuring out inheritance
const chatGenerations = result.generations[0] as ChatGeneration[];
return chatGenerations.map((g) => g.message as BaseMessageChunk);
} else {
return super.batch(inputs, options);
}
}

/** @ignore */
async _generate(
messages: BaseMessage[],
Expand Down
36 changes: 36 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -99,6 +99,42 @@ test("Test ChatOpenAI tokenUsage with a batch", async () => {
expect(tokenUsage.promptTokens).toBeGreaterThan(0);
});

test("Test ChatOpenAI batch sends prompt to API with n option if prompts are the same", async () => {
const model = new ChatOpenAI({
temperature: 2,
modelName: "gpt-3.5-turbo",
maxTokens: 10,
});

const generatePromptSpy = jest.spyOn(model, "generatePrompt");
const res = await model.batch([
[new HumanMessage("Hello!")],
[new HumanMessage("Hello!")],
]);
console.log(res);
expect(res).toHaveLength(2);

expect(generatePromptSpy).toHaveBeenCalledTimes(1);
});

test("Test ChatOpenAI batch sends prompt to API in separate requests if prompts are different", async () => {
const model = new ChatOpenAI({
temperature: 2,
modelName: "gpt-3.5-turbo",
maxTokens: 10,
});

const generatePromptSpy = jest.spyOn(model, "generatePrompt");
const res = await model.batch([
[new HumanMessage("Hello!")],
[new HumanMessage("Hi")],
]);
console.log(res);
expect(res).toHaveLength(2);

expect(generatePromptSpy).toHaveBeenCalledTimes(2);
});

test("Test ChatOpenAI in streaming mode", async () => {
let nrNewTokens = 0;
let streamedCompletion = "";