
openai[minor]: Update OpenAI with Azure Specific Code #5323

Merged (13 commits) on May 14, 2024
langchain/package.json (1 addition, 1 deletion)
@@ -654,7 +654,7 @@
"node-llama-cpp": "2.7.3",
"notion-to-md": "^3.1.0",
"officeparser": "^4.0.4",
"openai": "^4.32.1",
"openai": "^4.41.1",
"pdf-parse": "1.1.1",
"peggy": "^3.0.2",
"playwright": "^1.32.1",
langchain/src/experimental/openai_assistant/index.ts (6 additions, 3 deletions)
@@ -83,7 +83,8 @@ export class OpenAIAssistantRunnable<
tools: formattedTools,
model,
file_ids: fileIds,
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);

return new this({
client: oaiClient,
@@ -130,7 +131,8 @@
role: "user",
file_ids: input.file_ids,
metadata: input.messagesMetadata,
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
run = await this._createRun(input);
} else {
// Submitting tool outputs to an existing run, outside the AgentExecutor
@@ -189,7 +191,8 @@
instructions,
model,
file_ids: fileIds,
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
}

private async _parseStepsInput(input: RunInput): Promise<RunInput> {
libs/langchain-openai/package.json (2 additions, 1 deletion)
@@ -41,11 +41,12 @@
"dependencies": {

Hey there! 👋 I noticed that the "openai" dependency version was bumped and a new dev dependency, "@azure/identity", was added in package.json. This is just a heads up for the maintainers to review the dependency changes. Keep up the great work! 🚀 (A hedged sketch of how @azure/identity could supply a token provider follows this file's diff.)

"@langchain/core": "~0.2.0-rc.0",
"js-tiktoken": "^1.0.7",
"openai": "^4.32.1",
"openai": "^4.41.1",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
"devDependencies": {
"@azure/identity": "^4.2.0",
"@jest/globals": "^29.5.0",
"@langchain/scripts": "~0.0",
"@swc/core": "^1.3.90",
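For context on the new "@azure/identity" dev dependency: a minimal sketch, not part of this PR, of how it could produce the azureADTokenProvider value used elsewhere in this diff. DefaultAzureCredential, getBearerTokenProvider, and the Cognitive Services scope are standard @azure/identity usage; wiring the provider into the LangChain classes is shown in later sketches.

import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity";

// Build a () => Promise<string> token provider, matching the azureADTokenProvider
// field type added in chat_models.ts. The scope below is the standard Azure
// Cognitive Services scope used for Azure OpenAI.
const credential = new DefaultAzureCredential();
export const azureADTokenProvider = getBearerTokenProvider(
  credential,
  "https://cognitiveservices.azure.com/.default"
);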
libs/langchain-openai/src/azure/chat_models.ts (54 additions, 8 deletions)
@@ -1,10 +1,12 @@
import { type ClientOptions } from "openai";
import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai";
import { type BaseChatModelParams } from "@langchain/core/language_models/chat_models";
import { ChatOpenAI } from "../chat_models.js";
import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js";
import {
AzureOpenAIInput,
LegacyOpenAIInput,
OpenAIChatInput,
OpenAICoreRequestOptions,
} from "../types.js";

export class AzureChatOpenAI extends ChatOpenAI {
@@ -31,15 +33,8 @@ export class AzureChatOpenAI extends ChatOpenAI {
configuration?: ClientOptions & LegacyOpenAIInput;
}
) {
// assume the base URL does not contain "openai" nor "deployments" prefix
let basePath = fields?.openAIBasePath ?? "";
if (!basePath.endsWith("/")) basePath += "/";
if (!basePath.endsWith("openai/deployments"))
basePath += "openai/deployments";

const newFields = fields ? { ...fields } : fields;
if (newFields) {
newFields.azureOpenAIBasePath = basePath;
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName;
newFields.azureOpenAIApiKey = newFields.openAIApiKey;
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion;
@@ -48,6 +43,57 @@ export class AzureChatOpenAI extends ChatOpenAI {
super(newFields);
}

Hey there! I've reviewed the code changes and it looks like a new client setup using the AzureOpenAIClient has been added in the _getClientOptions method. I've flagged this for your review to ensure it aligns with the project's requirements. Let me know if you have any questions or need further clarification. (A hedged usage sketch follows this file's diff.)


protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
if (!this.client) {
const openAIEndpointConfig: OpenAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
azureOpenAIApiKey: this.azureOpenAIApiKey,
azureOpenAIBasePath: this.azureOpenAIBasePath,
baseURL: this.clientConfig.baseURL,
};

const endpoint = getEndpoint(openAIEndpointConfig);

const params = {
...this.clientConfig,
baseURL: endpoint,
timeout: this.timeout,
maxRetries: 0,
};

if (!this.azureADTokenProvider) {
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey;
}

if (!params.baseURL) {
delete params.baseURL;
}

this.client = new AzureOpenAIClient({
apiVersion: this.azureOpenAIApiVersion,
azureADTokenProvider: this.azureADTokenProvider,
deployment: this.azureOpenAIApiDeploymentName,
...params,
});
}
const requestOptions = {
...this.clientConfig,
...options,
} as OpenAICoreRequestOptions;
if (this.azureOpenAIApiKey) {
requestOptions.headers = {
"api-key": this.azureOpenAIApiKey,
...requestOptions.headers,
};
requestOptions.query = {
"api-version": this.azureOpenAIApiVersion,
...requestOptions.query,
};
}
return requestOptions;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
toJSON(): any {
const json = super.toJSON() as unknown;
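Following up on the inline comment about the AzureOpenAIClient wiring in _getClientOptions: a hedged usage sketch, not taken from this PR's tests or docs. Field names come from AzureOpenAIInput, azureADTokenProvider is assumed to be added to that type by this PR, AzureChatOpenAI is assumed to be exported from the package entry point, and all resource, deployment, and version values are placeholders.

import { AzureChatOpenAI } from "@langchain/openai";
import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity";

// API-key auth: _getClientOptions above attaches the "api-key" header and the
// "api-version" query parameter when azureOpenAIApiKey is set.
const chatWithKey = new AzureChatOpenAI({
  azureOpenAIApiKey: process.env.AZURE_OPENAI_API_KEY,
  azureOpenAIApiInstanceName: "my-instance",        // placeholder
  azureOpenAIApiDeploymentName: "gpt-4-deployment", // placeholder
  azureOpenAIApiVersion: "2024-02-01",              // placeholder
});

// Azure AD auth: when azureADTokenProvider is set, no apiKey is passed to the
// AzureOpenAIClient (see the `if (!this.azureADTokenProvider)` branch above).
const chatWithAad = new AzureChatOpenAI({
  azureADTokenProvider: getBearerTokenProvider(
    new DefaultAzureCredential(),
    "https://cognitiveservices.azure.com/.default"
  ),
  azureOpenAIApiInstanceName: "my-instance",
  azureOpenAIApiDeploymentName: "gpt-4-deployment",
  azureOpenAIApiVersion: "2024-02-01",
});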
libs/langchain-openai/src/azure/embeddings.ts (new file, 98 additions)
@@ -0,0 +1,98 @@
import {
type ClientOptions,
AzureOpenAI as AzureOpenAIClient,
OpenAI as OpenAIClient,
} from "openai";
import { OpenAIEmbeddings, OpenAIEmbeddingsParams } from "../embeddings.js";
import {
AzureOpenAIInput,
OpenAICoreRequestOptions,
LegacyOpenAIInput,
} from "../types.js";
import { getEndpoint, OpenAIEndpointConfig } from "../utils/azure.js";
import { wrapOpenAIClientError } from "../utils/openai.js";

export class AzureOpenAIEmbeddings extends OpenAIEmbeddings {
constructor(
fields?: Partial<OpenAIEmbeddingsParams> &
Partial<AzureOpenAIInput> & {
verbose?: boolean;
/** The OpenAI API key to use. */
apiKey?: string;
configuration?: ClientOptions;
deploymentName?: string;
openAIApiVersion?: string;
},
configuration?: ClientOptions & LegacyOpenAIInput
) {
const newFields = { ...fields };
if (Object.entries(newFields).length) {
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName;
newFields.azureOpenAIApiKey = newFields.apiKey;
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion;
}

super(newFields, configuration);
}

protected async embeddingWithRetry(
request: OpenAIClient.EmbeddingCreateParams
) {
if (!this.client) {
const openAIEndpointConfig: OpenAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
azureOpenAIApiKey: this.azureOpenAIApiKey,
azureOpenAIBasePath: this.azureOpenAIBasePath,
baseURL: this.clientConfig.baseURL,
};

const endpoint = getEndpoint(openAIEndpointConfig);

const params = {
...this.clientConfig,
baseURL: endpoint,
timeout: this.timeout,
maxRetries: 0,
};

if (!this.azureADTokenProvider) {
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey;
}

if (!params.baseURL) {
delete params.baseURL;
}

this.client = new AzureOpenAIClient({
apiVersion: this.azureOpenAIApiVersion,
azureADTokenProvider: this.azureADTokenProvider,
deployment: this.azureOpenAIApiDeploymentName,
...params,
});
}
const requestOptions: OpenAICoreRequestOptions = {};
if (this.azureOpenAIApiKey) {
requestOptions.headers = {
"api-key": this.azureOpenAIApiKey,
...requestOptions.headers,
};
requestOptions.query = {
"api-version": this.azureOpenAIApiVersion,
...requestOptions.query,
};
}
return this.caller.call(async () => {
try {
const res = await this.client.embeddings.create(
request,
requestOptions
);
return res;
} catch (e) {
const error = wrapOpenAIClientError(e);
throw error;
}
});
}
}
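A hedged usage sketch for the new AzureOpenAIEmbeddings class, assuming it is exported from the package entry point. The constructor above maps deploymentName, apiKey, and openAIApiVersion onto the corresponding azureOpenAI* fields; the names and version below are placeholders.

import { AzureOpenAIEmbeddings } from "@langchain/openai";

const embeddings = new AzureOpenAIEmbeddings({
  apiKey: process.env.AZURE_OPENAI_API_KEY,  // mapped to azureOpenAIApiKey
  deploymentName: "text-embedding-ada-002",  // mapped to azureOpenAIApiDeploymentName
  openAIApiVersion: "2024-02-01",            // mapped to azureOpenAIApiVersion
  azureOpenAIApiInstanceName: "my-instance", // placeholder resource name
});

// embeddingWithRetry above routes the request through the AzureOpenAIClient,
// adding the "api-key" header and "api-version" query when an API key is used.
const vectors = await embeddings.embedDocuments(["hello", "world"]);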
libs/langchain-openai/src/azure/llms.ts (54 additions, 8 deletions)
@@ -1,9 +1,11 @@
import { type ClientOptions } from "openai";
import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai";
import { type BaseLLMParams } from "@langchain/core/language_models/llms";
import { OpenAI } from "../llms.js";
import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js";
import type {
OpenAIInput,
AzureOpenAIInput,
OpenAICoreRequestOptions,
LegacyOpenAIInput,
} from "../types.js";

@@ -27,15 +29,8 @@ export class AzureOpenAI extends OpenAI {
configuration?: ClientOptions & LegacyOpenAIInput;
}
) {
// assume the base URL does not contain "openai" nor "deployments" prefix
let basePath = fields?.openAIBasePath ?? "";
if (!basePath.endsWith("/")) basePath += "/";
Collaborator


Actually, this is breaking, isn't it?

I don't think this is heavily used yet and there are no examples in our docs, so I'm ok with a change now.

if (!basePath.endsWith("openai/deployments"))
basePath += "openai/deployments";

const newFields = fields ? { ...fields } : fields;
if (newFields) {
newFields.azureOpenAIBasePath = basePath;
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName;
newFields.azureOpenAIApiKey = newFields.openAIApiKey;
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion;
@@ -44,6 +39,57 @@ export class AzureOpenAI extends OpenAI {
super(newFields);
}

protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
if (!this.client) {
const openAIEndpointConfig: OpenAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
azureOpenAIApiKey: this.azureOpenAIApiKey,
azureOpenAIBasePath: this.azureOpenAIBasePath,
baseURL: this.clientConfig.baseURL,
};

const endpoint = getEndpoint(openAIEndpointConfig);

const params = {
...this.clientConfig,
baseURL: endpoint,
timeout: this.timeout,
maxRetries: 0,
};

if (!this.azureADTokenProvider) {
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey;
}

if (!params.baseURL) {
delete params.baseURL;
}

this.client = new AzureOpenAIClient({
apiVersion: this.azureOpenAIApiVersion,
azureADTokenProvider: this.azureADTokenProvider,
...params,
});
}

const requestOptions = {
...this.clientConfig,
...options,
} as OpenAICoreRequestOptions;
if (this.azureOpenAIApiKey) {
requestOptions.headers = {
"api-key": this.azureOpenAIApiKey,
...requestOptions.headers,
};
requestOptions.query = {
"api-version": this.azureOpenAIApiVersion,
...requestOptions.query,
};
}
return requestOptions;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
toJSON(): any {
const json = super.toJSON() as unknown;
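On the "this is breaking, isn't it?" comment above: a hedged illustration of the behavioral difference implied by the removed lines. The old constructor appended "openai/deployments" to openAIBasePath and copied the result into azureOpenAIBasePath; with those lines gone, a caller who relied on that rewrite would likely need to supply the base path explicitly. Field names come from AzureOpenAIInput; the URL, deployment name, and version are placeholders.

import { AzureOpenAI } from "@langchain/openai";

// Previously, openAIBasePath: "https://my-instance.openai.azure.com/" was
// rewritten internally to ".../openai/deployments" and stored as
// azureOpenAIBasePath. After this change, the caller sets the full path itself:
const llm = new AzureOpenAI({
  azureOpenAIApiKey: process.env.AZURE_OPENAI_API_KEY,
  azureOpenAIBasePath: "https://my-instance.openai.azure.com/openai/deployments",
  azureOpenAIApiDeploymentName: "gpt-35-turbo-instruct", // placeholder
  azureOpenAIApiVersion: "2024-02-01",                   // placeholder
});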
libs/langchain-openai/src/chat_models.ts (12 additions, 6 deletions)
@@ -378,6 +378,8 @@ export class ChatOpenAI<

azureOpenAIApiKey?: string;

azureADTokenProvider?: () => Promise<string>;

azureOpenAIApiInstanceName?: string;

azureOpenAIApiDeploymentName?: string;
Expand All @@ -386,9 +388,9 @@ export class ChatOpenAI<

organization?: string;

private client: OpenAIClient;
protected client: OpenAIClient;

private clientConfig: ClientOptions;
protected clientConfig: ClientOptions;

constructor(
fields?: Partial<OpenAIChatInput> &
@@ -411,8 +413,12 @@
fields?.azureOpenAIApiKey ??
getEnvironmentVariable("AZURE_OPENAI_API_KEY");

if (!this.azureOpenAIApiKey && !this.apiKey) {
throw new Error("OpenAI or Azure OpenAI API key not found");
this.azureADTokenProvider = fields?.azureADTokenProvider ?? undefined;

if (!this.azureOpenAIApiKey && !this.apiKey && !this.azureADTokenProvider) {
throw new Error(
"OpenAI or Azure OpenAI API key or Token Provider not found"
);
}

this.azureOpenAIApiInstanceName =
@@ -455,7 +461,7 @@

this.streaming = fields?.streaming ?? false;

if (this.azureOpenAIApiKey) {
if (this.azureOpenAIApiKey || this.azureADTokenProvider) {
if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) {
throw new Error("Azure OpenAI API instance name not found");
}
@@ -898,7 +904,7 @@ export class ChatOpenAI<
});
}

private _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
if (!this.client) {
const openAIEndpointConfig: OpenAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
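A small sketch, not from the PR, of the constructor checks changed above: at least one of an OpenAI API key, an Azure OpenAI API key, or an azureADTokenProvider must be supplied, and Azure credentials additionally require an instance name or base path. azureADTokenProvider is assumed to be added to AzureOpenAIInput by this PR; all values are placeholders.

import { ChatOpenAI } from "@langchain/openai";

// OK: a plain OpenAI API key.
const openai = new ChatOpenAI({ openAIApiKey: "sk-placeholder" });

// OK after this PR: an Azure AD token provider with no API key at all, plus the
// instance and deployment details required by the Azure branch above.
const azureAd = new ChatOpenAI({
  azureADTokenProvider: async () => "placeholder-token",
  azureOpenAIApiInstanceName: "my-instance",
  azureOpenAIApiDeploymentName: "my-deployment",
  azureOpenAIApiVersion: "2024-02-01",
});

// Throws "OpenAI or Azure OpenAI API key or Token Provider not found" when none
// of the three credentials is present (and no API-key environment variables are set).
const invalid = new ChatOpenAI({});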