community[minor]: Improve Azure Cosmos DB vector store support #5197

Merged
merged 11 commits into from
May 13, 2024
@@ -25,15 +25,14 @@ const store = await AzureCosmosDBVectorStore.fromDocuments(
{
databaseName: "langchain",
collectionName: "documents",
+ indexOptions: {
+ numLists: 100,
+ dimensions: 1536,
+ similarity: AzureCosmosDBSimilarityType.COS,
+ },
}
);

- // Create the index
- const numLists = 100;
- const dimensions = 1536;
- const similarity = AzureCosmosDBSimilarityType.COS;
- await store.createIndex(numLists, dimensions, similarity);
-
// Performs a similarity search
const resultDocuments = await store.similaritySearch(
"What did the president say about Ketanji Brown Jackson?"
108 changes: 87 additions & 21 deletions libs/langchain-community/src/vectorstores/azure_cosmosdb.ts
@@ -4,13 +4,14 @@ import {
Document as MongoDBDocument,
MongoClient,
Db,
+ Filter,
} from "mongodb";
import type { EmbeddingsInterface } from "@langchain/core/embeddings";
import {
MaxMarginalRelevanceSearchOptions,
VectorStore,
} from "@langchain/core/vectorstores";
- import { Document } from "@langchain/core/documents";
+ import { Document, DocumentInterface } from "@langchain/core/documents";
import { maximalMarginalRelevance } from "@langchain/core/utils/math";
import { getEnvironmentVariable } from "@langchain/core/utils/env";

@@ -28,6 +29,26 @@ export const AzureCosmosDBSimilarityType = {
export type AzureCosmosDBSimilarityType =
(typeof AzureCosmosDBSimilarityType)[keyof typeof AzureCosmosDBSimilarityType];

+ /** Azure Cosmos DB Index Options. */
+ export type AzureCosmosDBIndexOptions = {
+ /** Skips automatic index creation. */
+ readonly skipCreate?: boolean;
+ /** Number of clusters that the inverted file (IVF) index uses to group the vector data. */
+ readonly numLists?: number;
+ /** Number of dimensions for vector similarity. */
+ readonly dimensions?: number;
+ /** Similarity metric to use with the IVF index. */
+ readonly similarity?: AzureCosmosDBSimilarityType;
+ };
+
+ /** Azure Cosmos DB Delete Parameters. */
+ export type AzureCosmosDBDeleteParams = {
+ /** List of IDs for the documents to be removed. */
+ readonly ids?: string | string[];
+ /** MongoDB filter object or list of IDs for the documents to be removed. */
+ readonly filter?: Filter<MongoDBDocument>;
+ };
+
/**
* Configuration options for the `AzureCosmosDBVectorStore` constructor.
*/
@@ -39,6 +60,7 @@ export interface AzureCosmosDBConfig {
readonly indexName?: string;
readonly textKey?: string;
readonly embeddingKey?: string;
+ readonly indexOptions?: AzureCosmosDBIndexOptions;
}
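
A minimal sketch of how the new `indexOptions` could be resolved into concrete index parameters, assuming the defaults visible in this diff (`numLists` of 100, cosine similarity, and `dimensions` inferred by embedding a short text). The helper `resolveIndexOptions`, the `QueryEmbedder` interface, the fake embedder, and the `"COS"` placeholder string are hypothetical and only stand in for the library's own types.

```typescript
// Local mirror of the option shape from the diff, so the sketch is standalone.
type IndexOptions = {
  readonly skipCreate?: boolean;
  readonly numLists?: number;
  readonly dimensions?: number;
  readonly similarity?: string;
};

// Minimal embeddings shape; only embedQuery is needed here.
interface QueryEmbedder {
  embedQuery(text: string): Promise<number[]>;
}

// Hypothetical helper (not part of the library): fills in a value for
// every option that was left undefined.
async function resolveIndexOptions(
  options: IndexOptions,
  embeddings: QueryEmbedder
): Promise<{ numLists: number; dimensions: number; similarity: string }> {
  const numLists = options.numLists ?? 100;
  // "COS" is a placeholder for AzureCosmosDBSimilarityType.COS.
  const similarity = options.similarity ?? "COS";
  // When dimensions is omitted, embed a short text and use its length.
  const dimensions =
    options.dimensions ?? (await embeddings.embedQuery("test")).length;
  return { numLists, dimensions, similarity };
}

// Fake embedder producing 4-dimensional vectors, for illustration only.
const fakeEmbeddings: QueryEmbedder = {
  embedQuery: async () => [0.1, 0.2, 0.3, 0.4],
};
```

With the fake embedder, `resolveIndexOptions({ numLists: 1 }, fakeEmbeddings)` resolves `dimensions` to 4, while an explicit `dimensions` value is passed through without calling the embedder.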

/**
@@ -60,6 +82,8 @@ export class AzureCosmosDBVectorStore extends VectorStore {
};
}

+ private connectPromise: Promise<void>;
+
private readonly initPromise: Promise<void>;

private readonly client: MongoClient | undefined;
@@ -74,6 +98,8 @@ export class AzureCosmosDBVectorStore extends VectorStore {

readonly embeddingKey: string;

+ private readonly indexOptions: AzureCosmosDBIndexOptions;
+
_vectorstoreType(): string {
return "azure_cosmosdb";
}
@@ -105,6 +131,7 @@ export class AzureCosmosDBVectorStore extends VectorStore {
this.indexName = dbConfig.indexName ?? "vectorSearchIndex";
this.textKey = dbConfig.textKey ?? "textContent";
this.embeddingKey = dbConfig.embeddingKey ?? "vectorContent";
+ this.indexOptions = dbConfig.indexOptions ?? {};

// Start initialization, but don't wait for it to finish here
this.initPromise = this.init(client, databaseName, collectionName).catch(
@@ -169,7 +196,9 @@ export class AzureCosmosDBVectorStore extends VectorStore {
* Using a numLists value of 1 is akin to performing brute-force search,
* which has limited performance
* @param dimensions Number of dimensions for vector similarity.
- * The maximum number of supported dimensions is 2000
+ * The maximum number of supported dimensions is 2000.
+ * If no number is provided, it will be determined automatically by
+ * embedding a short text.
* @param similarity Similarity metric to use with the IVF index.
* Possible options are:
* - CosmosDBSimilarityType.COS (cosine distance)
@@ -179,10 +208,17 @@ export class AzureCosmosDBVectorStore extends VectorStore {
*/
async createIndex(
numLists = 100,
- dimensions = 1536,
+ dimensions: number | undefined = undefined,
similarity: AzureCosmosDBSimilarityType = AzureCosmosDBSimilarityType.COS
): Promise<void> {
- await this.initPromise;
+ await this.connectPromise;
Collaborator

If this is hypothetically called before connectPromise is initialized, could we throw instead?

Collaborator

Ah never mind, I see below. Could we always just rely on this.initPromise instead?

Contributor Author

Since init calls createIndex, we need two different promises there; otherwise there's a deadlock:

  • connectPromise => create DB + collections clients
  • initPromise => connectPromise + createIndex

Contributor Author

For the throw question, the issue is that the promise is created in the constructor and there's no way for the user to wait for it or know when the connect task is done.
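
The two-promise arrangement described in this thread can be sketched in isolation. This is only an illustration under stated assumptions: `TwoPhaseStore`, its `events` log, and `addDocument` are hypothetical names, and the connection is simulated rather than made against a real database.

```typescript
// Hypothetical stand-in for the store; no real database is contacted.
class TwoPhaseStore {
  readonly events: string[] = [];

  private readonly skipCreate: boolean;

  // Resolves once the (simulated) client and collection are ready.
  private connectPromise!: Promise<void>;

  // Resolves once connecting AND index creation have both finished.
  readonly initPromise: Promise<void>;

  constructor(skipCreate = false) {
    this.skipCreate = skipCreate;
    // Start initialization, but don't wait for it to finish here.
    this.initPromise = this.init();
  }

  private async init(): Promise<void> {
    // Assigned synchronously, before init() reaches its first await.
    this.connectPromise = (async () => {
      await Promise.resolve(); // simulated client.connect()
      this.events.push("connected");
    })();
    if (!this.skipCreate) {
      await this.createIndex();
    }
  }

  // Awaits only connectPromise. Awaiting initPromise here would never
  // settle when called from init(), because init() would wait on itself.
  async createIndex(): Promise<void> {
    await this.connectPromise;
    this.events.push("index created");
  }

  // Other public operations wait for the full initialization.
  async addDocument(text: string): Promise<void> {
    await this.initPromise;
    this.events.push(`added ${text}`);
  }
}
```

Calling `addDocument` right after construction still runs after both the connection and the index creation, because it awaits `initPromise`, which in turn covers `connectPromise`.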


+ let vectorLength = dimensions;
+
+ if (vectorLength === undefined) {
+ const queryEmbedding = await this.embeddings.embedQuery("test");
+ vectorLength = queryEmbedding.length;
+ }

const createIndexCommands = {
createIndexes: this.collection.collectionName,
@@ -194,7 +230,7 @@ export class AzureCosmosDBVectorStore extends VectorStore {
kind: "vector-ivf",
numLists,
similarity,
- dimensions,
+ dimensions: vectorLength,
},
},
],
@@ -205,19 +241,33 @@ export class AzureCosmosDBVectorStore extends VectorStore {

/**
* Removes specified documents from the AzureCosmosDBVectorStore.
- * @param ids IDs of the documents to be removed. If no IDs are specified,
- * all documents will be removed.
+ * If no IDs or filter are specified, all documents will be removed.
+ * @param params Parameters for the delete operation.
* @returns A promise that resolves when the documents have been removed.
*/
- async delete(ids?: string[]): Promise<void> {
+ async delete(
+ params: AzureCosmosDBDeleteParams | string[] = {}
+ ): Promise<void> {
await this.initPromise;

- if (ids) {
- const objectIds = ids.map((id) => new ObjectId(id));
- await this.collection.deleteMany({ _id: { $in: objectIds } });
+ let ids: string | string[] | undefined;
+ let filter: AzureCosmosDBDeleteParams["filter"];
+ if (Array.isArray(params)) {
+ ids = params;
} else {
- await this.collection.deleteMany({});
+ ids = params.ids;
+ filter = params.filter;
}
+ const idsArray = Array.isArray(ids) ? ids : [ids];
+ const deleteIds = ids && idsArray.length > 0 ? idsArray : undefined;
+ let deleteFilter = filter ?? {};
+
+ if (deleteIds) {
+ const objectIds = deleteIds.map((id) => new ObjectId(id));
+ deleteFilter = { _id: { $in: objectIds }, ...deleteFilter };
+ }
+
+ await this.collection.deleteMany(deleteFilter);
}
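
The way the new `delete` combines IDs and a filter can be sketched as a pure function. This is an illustrative sketch, not the library code: `buildDeleteFilter` and `DeleteParams` are hypothetical names, and plain strings stand in for MongoDB `ObjectId` values so the example runs without the driver.

```typescript
// Hypothetical parameter shape mirroring AzureCosmosDBDeleteParams,
// with plain strings standing in for MongoDB ObjectId values.
type DeleteParams = {
  readonly ids?: string | string[];
  readonly filter?: Record<string, unknown>;
};

function buildDeleteFilter(
  params: DeleteParams | string[] = {}
): Record<string, unknown> {
  // A bare array is treated as a list of IDs (the previous signature).
  const ids = Array.isArray(params) ? params : params.ids;
  const filter = Array.isArray(params) ? undefined : params.filter;

  // Normalize a single ID to an array; a missing value means "no IDs".
  const idsArray = !ids ? [] : Array.isArray(ids) ? ids : [ids];

  // The ID clause comes first and the filter keys are spread last,
  // following the order used in the diff.
  return idsArray.length > 0
    ? { _id: { $in: idsArray }, ...(filter ?? {}) }
    : { ...(filter ?? {}) };
}
```

Called with no arguments, the function returns `{}`, which `deleteMany` treats as matching every document. Because the filter is spread after the ID clause, a filter that itself contains an `_id` key would replace the ID clause rather than combine with it.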

/**
@@ -236,27 +286,31 @@ export class AzureCosmosDBVectorStore extends VectorStore {
* Method for adding vectors to the AzureCosmosDBVectorStore.
* @param vectors Vectors to be added.
* @param documents Corresponding documents to be added.
- * @returns A promise that resolves when the vectors and documents have been added.
+ * @returns A promise that resolves to the added documents IDs.
*/
- async addVectors(vectors: number[][], documents: Document[]): Promise<void> {
+ async addVectors(
+ vectors: number[][],
+ documents: DocumentInterface[]
+ ): Promise<string[]> {
const docs = vectors.map((embedding, idx) => ({
[this.textKey]: documents[idx].pageContent,
[this.embeddingKey]: embedding,
...documents[idx].metadata,
}));
await this.initPromise;
- await this.collection.insertMany(docs);
+ const result = await this.collection.insertMany(docs);
+ return Object.values(result.insertedIds).map((id) => String(id));
}

/**
* Method for adding documents to the AzureCosmosDBVectorStore. It first converts
* the documents to texts and then adds them as vectors.
* @param documents The documents to add.
- * @returns A promise that resolves when the documents have been added.
+ * @returns A promise that resolves to the added documents IDs.
*/
- async addDocuments(documents: Document[]): Promise<void> {
+ async addDocuments(documents: DocumentInterface[]): Promise<string[]> {
const texts = documents.map(({ pageContent }) => pageContent);
- await this.addVectors(
+ return this.addVectors(
await this.embeddings.embedDocuments(texts),
documents
);
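
The conversion of inserted IDs to strings used by the new `addVectors` return value can be sketched on its own. The result object below is a hand-written stand-in for a driver response, not real output; `InsertManyResultLike`, `insertedIdsToStrings`, and the ID values are hypothetical.

```typescript
// Shape of the relevant part of an insertMany result: a map from the
// position in the input array to the generated _id.
type InsertManyResultLike = {
  insertedIds: { [index: number]: { toString(): string } };
};

// Collects the generated IDs in input order and stringifies each one.
function insertedIdsToStrings(result: InsertManyResultLike): string[] {
  return Object.values(result.insertedIds).map((id) => String(id));
}

// Hand-written stand-in for a driver response; the ID values are made up.
const fakeResult: InsertManyResultLike = {
  insertedIds: {
    0: { toString: () => "id-0" },
    1: { toString: () => "id-1" },
  },
};
```

Returning strings lets callers pass the IDs straight back to `delete`, as the new tests in this PR do with `ids.slice(0, 1)`.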
@@ -355,9 +409,21 @@ export class AzureCosmosDBVectorStore extends VectorStore {
databaseName: string,
collectionName: string
): Promise<void> {
- await client.connect();
- this.database = client.db(databaseName);
- this.collection = this.database.collection(collectionName);
+ this.connectPromise = (async () => {
+ await client.connect();
+ this.database = client.db(databaseName);
+ this.collection = this.database.collection(collectionName);
+ })();
+
+ // Unless skipCreate is set, create the index
+ // This operation is no-op if the index already exists
+ if (!this.indexOptions.skipCreate) {
+ await this.createIndex(
+ this.indexOptions.numLists,
+ this.indexOptions.dimensions,
+ this.indexOptions.similarity
+ );
+ }
}

/**
@@ -48,13 +48,18 @@ describe.skip("AzureCosmosDBVectorStore", () => {
process.env.AZURE_COSMOSDB_CONNECTION_STRING!

Hey there! I've reviewed the code and noticed that the recent changes explicitly access an environment variable using process.env. I've flagged this for your review to ensure it aligns with our environment variable handling practices. Let me know if you have any questions or need further clarification.

);
await client.connect();
- const collection = client.db(DATABASE_NAME).collection(COLLECTION_NAME);
+ const db = client.db(DATABASE_NAME);
+ const collection = await db.createCollection(COLLECTION_NAME);

// Make sure the database is empty
await collection.deleteMany({});

// Delete any existing index
- await collection.dropIndex(INDEX_NAME);
+ try {
+ await collection.dropIndex(INDEX_NAME);
+ } catch {
+ // Ignore error if the index does not exist
+ }

await client.close();
});
@@ -64,6 +69,9 @@ describe.skip("AzureCosmosDBVectorStore", () => {
databaseName: DATABASE_NAME,
collectionName: COLLECTION_NAME,
indexName: INDEX_NAME,
+ indexOptions: {
+ numLists: 1,
+ },
});

expect(vectorStore).toBeDefined();
@@ -75,9 +83,6 @@ describe.skip("AzureCosmosDBVectorStore", () => {
{ pageContent: "The house is open", metadata: { d: 1, e: 2 } },
]);

- // Make sure the index is created
- await vectorStore.createIndex(1);
-
const results: Document[] = await vectorStore.similaritySearch(
"sandwich",
1
@@ -110,12 +115,12 @@ describe.skip("AzureCosmosDBVectorStore", () => {
databaseName: DATABASE_NAME,
collectionName: COLLECTION_NAME,
indexName: INDEX_NAME,
+ indexOptions: {
+ numLists: 1,
+ },
}
);

- // Make sure the index is created
- await vectorStore.createIndex(1);
-
const output = await vectorStore.maxMarginalRelevanceSearch("foo", {
k: 10,
fetchK: 20,
@@ -160,4 +165,90 @@ describe.skip("AzureCosmosDBVectorStore", () => {

await vectorStore.close();
});
+
+ test("deletes documents by id", async () => {
+ const vectorStore = new AzureCosmosDBVectorStore(new OpenAIEmbeddings(), {
+ databaseName: DATABASE_NAME,
+ collectionName: COLLECTION_NAME,
+ indexName: INDEX_NAME,
+ indexOptions: {
+ numLists: 1,
+ },
+ });
+
+ const ids = await vectorStore.addDocuments([
+ { pageContent: "This book is about politics", metadata: { a: 1 } },
+ {
+ pageContent: "The is the house of parliament",
+ metadata: { d: 1, e: 2 },
+ },
+ ]);
+
+ // Delete document matching specified ids
+ await vectorStore.delete({ ids: ids.slice(0, 1) });
+
+ const results = await vectorStore.similaritySearch("politics", 10);
+
+ expect(results.length).toEqual(1);
+ expect(results[0].pageContent).toEqual("The is the house of parliament");
+
+ await vectorStore.close();
+ });
+
+ test("deletes documents by filter", async () => {
+ const vectorStore = new AzureCosmosDBVectorStore(new OpenAIEmbeddings(), {
+ databaseName: DATABASE_NAME,
+ collectionName: COLLECTION_NAME,
+ indexName: INDEX_NAME,
+ indexOptions: {
+ numLists: 1,
+ },
+ });
+
+ await vectorStore.addDocuments([
+ { pageContent: "This book is about politics", metadata: { a: 1 } },
+ {
+ pageContent: "The is the house of parliament",
+ metadata: { d: 1, e: 2 },
+ },
+ ]);
+
+ // Delete document matching the filter
+ await vectorStore.delete({ filter: { a: 1 } });
+
+ const results = await vectorStore.similaritySearch("politics", 10);
+
+ expect(results.length).toEqual(1);
+ expect(results[0].pageContent).toEqual("The is the house of parliament");
+
+ await vectorStore.close();
+ });
+
+ test("deletes all documents", async () => {
+ const vectorStore = new AzureCosmosDBVectorStore(new OpenAIEmbeddings(), {
+ databaseName: DATABASE_NAME,
+ collectionName: COLLECTION_NAME,
+ indexName: INDEX_NAME,
+ indexOptions: {
+ numLists: 1,
+ },
+ });
+
+ await vectorStore.addDocuments([
+ { pageContent: "This book is about politics", metadata: { a: 1 } },
+ {
+ pageContent: "The is the house of parliament",
+ metadata: { d: 1, e: 2 },
+ },
+ ]);
+
+ // Delete all documents
+ await vectorStore.delete();
+
+ const results = await vectorStore.similaritySearch("politics", 10);
+
+ expect(results.length).toEqual(0);
+
+ await vectorStore.close();
+ });
});