Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple metadata keys on RedisVectorStoreFilterType #5015 #5028

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions libs/langchain-redis/src/tests/vectorstores.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/* eslint-disable no-promise-executor-return */

import { RedisClientType, createClient } from "redis";
import { SchemaFieldTypes } from "redis";
import { v4 as uuidv4 } from "uuid";
import { test, expect } from "@jest/globals";
import { faker } from "@faker-js/faker";
Expand All @@ -21,6 +22,9 @@ describe("RedisVectorStore", () => {
redisClient: client as RedisClientType,
indexName: "test-index",
keyPrefix: "test:",
metadataSchema: {
["foo"]: SchemaFieldTypes.TEXT,
}
});
});

Expand Down Expand Up @@ -66,9 +70,7 @@ describe("RedisVectorStore", () => {
]);

// If the filter wasn't working, we'd get all 3 documents back
const results = await vectorStore.similaritySearch(pageContent, 3, [
`${uuid}`,
]);
const results = await vectorStore.similaritySearch(pageContent, 3, `@foo:(${uuid})`);

expect(results).toEqual([
new Document({ metadata: { foo: uuid }, pageContent }),
Expand Down
46 changes: 12 additions & 34 deletions libs/langchain-redis/src/tests/vectorstores.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { jest, test, expect, describe } from "@jest/globals";
import { FakeEmbeddings } from "@langchain/core/utils/testing";

import { RedisVectorStore } from "../vectorstores.js";
import { SchemaFieldTypes } from "redis";

const createRedisClientMockup = () => {
const hSetMock = jest.fn();
Expand Down Expand Up @@ -34,6 +35,9 @@ test("RedisVectorStore with external keys", async () => {
const store = new RedisVectorStore(embeddings, {
redisClient: client as any,
indexName: "documents",
metadataSchema: {
["a"]: SchemaFieldTypes.NUMERIC,
}
});

expect(store).toBeDefined();
Expand All @@ -44,7 +48,6 @@ test("RedisVectorStore with external keys", async () => {
pageContent: "hello",
metadata: {
a: 1,
b: { nested: [1, { a: 4 }] },
},
},
],
Expand All @@ -55,7 +58,7 @@ test("RedisVectorStore with external keys", async () => {
expect(client.hSet).toHaveBeenCalledWith("id1", {
content_vector: Buffer.from(new Float32Array([0.1, 0.2, 0.3, 0.4]).buffer),
content: "hello",
metadata: `{\\"a\\"\\:1,\\"b\\"\\:{\\"nested\\"\\:[1,{\\"a\\"\\:4}]}}`,
a: 1,
});

const results = await store.similaritySearch("goodbye", 1);
Expand All @@ -70,6 +73,9 @@ test("RedisVectorStore with generated keys", async () => {
const store = new RedisVectorStore(embeddings, {
redisClient: client as any,
indexName: "documents",
metadataSchema: {
["a"]: SchemaFieldTypes.NUMERIC,
}
});

expect(store).toBeDefined();
Expand All @@ -90,46 +96,18 @@ test("RedisVectorStore with filters", async () => {
const store = new RedisVectorStore(embeddings, {
redisClient: client as any,
indexName: "documents",
});

expect(store).toBeDefined();

await store.similaritySearch("hello", 1, ["a", "b", "c"]);

expect(client.ft.search).toHaveBeenCalledWith(
"documents",
"@metadata:(a|b|c) => [KNN 1 @content_vector $vector AS vector_score]",
{
PARAMS: {
vector: Buffer.from(new Float32Array([0.1, 0.2, 0.3, 0.4]).buffer),
},
RETURN: ["metadata", "content", "vector_score"],
SORTBY: "vector_score",
DIALECT: 2,
LIMIT: {
from: 0,
size: 1,
},
metadataSchema: {
["metadata"]: SchemaFieldTypes.TEXT
}
);
});

test("RedisVectorStore with raw filter", async () => {
const client = createRedisClientMockup();
const embeddings = new FakeEmbeddings();

const store = new RedisVectorStore(embeddings, {
redisClient: client as any,
indexName: "documents",
});

expect(store).toBeDefined();

await store.similaritySearch("hello", 1, "a b c");
await store.similaritySearch("hello", 1, "@metadata:(a|b|c)");

expect(client.ft.search).toHaveBeenCalledWith(
"documents",
"@metadata:(a b c) => [KNN 1 @content_vector $vector AS vector_score]",
"@metadata:(a|b|c) => [KNN 1 @content_vector $vector AS vector_score]",
{
PARAMS: {
vector: Buffer.from(new Float32Array([0.1, 0.2, 0.3, 0.4]).buffer),
Expand Down
76 changes: 43 additions & 33 deletions libs/langchain-redis/src/vectorstores.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ export interface RedisVectorStoreConfig {
createIndexOptions?: Omit<RedisVectorStoreIndexOptions, "PREFIX">; // PREFIX must be set with keyPrefix
keyPrefix?: string;
contentKey?: string;
metadataKey?: string;
vectorKey?: string;
metadataSchema?: RediSearchSchema;
filter?: RedisVectorStoreFilterType;
}

Expand All @@ -88,12 +88,9 @@ export interface RedisAddOptions {
}

/**
* Type for the filter used in the RedisVectorStore. It is an array of
* strings.
* If a string is passed instead of an array the value is used directly, this
* allows custom filters to be passed.
* Type for the filter used in the RedisVectorStore. It is a raw RediSearch filter expression, such as `@field:{value1}`.
*/
export type RedisVectorStoreFilterType = string[] | string;
export type RedisVectorStoreFilterType = string;

/**
* Class representing a RedisVectorStore. It extends the VectorStore class
Expand All @@ -117,10 +114,10 @@ export class RedisVectorStore extends VectorStore {

contentKey: string;

metadataKey: string;

vectorKey: string;

metadataSchema: RediSearchSchema;

filter?: RedisVectorStoreFilterType;

_vectorstoreType(): string {
Expand All @@ -141,8 +138,8 @@ export class RedisVectorStore extends VectorStore {
};
this.keyPrefix = _dbConfig.keyPrefix ?? `doc:${this.indexName}:`;
this.contentKey = _dbConfig.contentKey ?? "content";
this.metadataKey = _dbConfig.metadataKey ?? "metadata";
this.vectorKey = _dbConfig.vectorKey ?? "content_vector";
this.metadataSchema = _dbConfig.metadataSchema ?? {};
this.filter = _dbConfig.filter;
this.createIndexOptions = {
ON: "HASH",
Expand Down Expand Up @@ -185,6 +182,7 @@ export class RedisVectorStore extends VectorStore {
if (!vectors.length || !vectors[0].length) {
throw new Error("No vectors provided");
}

// check if the index exists and create it if it doesn't
await this.createIndex(vectors[0].length);

Expand All @@ -202,12 +200,19 @@ export class RedisVectorStore extends VectorStore {
? documents[idx].metadata
: {};

multi.hSet(key, {
[this.vectorKey]: this.getFloat32Buffer(vector),
[this.contentKey]: documents[idx].pageContent,
[this.metadataKey]: this.escapeSpecialChars(JSON.stringify(metadata)),
var t = {
[this.vectorKey]: this.getFloat32Buffer(vector),
[this.contentKey]: documents[idx].pageContent,
};

Object.keys(this.metadataSchema).forEach((key) => {
if(metadata[key]) {
t[key] = (Array.isArray(metadata[key])) ? this.escapeSpecialChars(metadata[key].map((val: any) => val.toString()).join(",")) : this.escapeSpecialChars(metadata[key].toString());
}
});

multi.hSet(key, t);

// write batch
if (idx % batchSize === 0) {
await multi.exec();
Expand Down Expand Up @@ -250,11 +255,20 @@ export class RedisVectorStore extends VectorStore {
result.push([
new Document({
pageContent: (document[this.contentKey] ?? "") as string,
metadata: JSON.parse(
this.unEscapeSpecialChars(
(document.metadata ?? "{}") as string
)
),
metadata: Object.keys(this.metadataSchema).reduce((acc: any, key) => {
const str: string = this.unEscapeSpecialChars((document[key] || "") as string);
switch(this.metadataSchema[key]) {
case SchemaFieldTypes.NUMERIC:
acc[key] = parseFloat(str);
break;
case SchemaFieldTypes.TAG:
acc[key] = str.split(",");
break;
default:
acc[key] = str;
}
return acc;
}, {}),
}),
Number(document.vector_score),
]);
Expand Down Expand Up @@ -361,9 +375,12 @@ export class RedisVectorStore extends VectorStore {
...this.indexOptions,
},
[this.contentKey]: SchemaFieldTypes.TEXT,
[this.metadataKey]: SchemaFieldTypes.TEXT,
};

Object.keys(this.metadataSchema).forEach((key) => {
schema[key] = this.metadataSchema[key];
});

await this.redisClient.ft.create(
this.indexName,
schema,
Expand Down Expand Up @@ -407,16 +424,16 @@ export class RedisVectorStore extends VectorStore {
): [string, SearchOptions] {
const vectorScoreField = "vector_score";

let hybridFields = "*";
let hybridFields: string;
// if a filter is set, modify the hybrid query
if (filter && filter.length) {
// `filter` is a list of strings, then it's applied using the OR operator in the metadata key
// for example: filter = ['foo', 'bar'] => this will filter all metadata containing either 'foo' OR 'bar'
hybridFields = `@${this.metadataKey}:(${this.prepareFilter(filter)})`;
if (typeof filter === "string") {
hybridFields = `${filter}`;
} else {
hybridFields = "*";
}

const baseQuery = `${hybridFields} => [KNN ${k} @${this.vectorKey} $vector AS ${vectorScoreField}]`;
const returnFields = [this.metadataKey, this.contentKey, vectorScoreField];
const returnFields = [...Object.keys(this.metadataSchema), this.contentKey, vectorScoreField];

const options: SearchOptions = {
PARAMS: {
Expand All @@ -434,13 +451,6 @@ export class RedisVectorStore extends VectorStore {
return [baseQuery, options];
}

private prepareFilter(filter: RedisVectorStoreFilterType) {
if (Array.isArray(filter)) {
return filter.map(this.escapeSpecialChars).join("|");
}
return filter;
}

/**
* Escapes all '-', ':', and '"' characters.
* RediSearch considers these all as special characters, so we need
Expand Down Expand Up @@ -480,4 +490,4 @@ export class RedisVectorStore extends VectorStore {
private getFloat32Buffer(vector: number[]) {
return Buffer.from(new Float32Array(vector).buffer);
}
}
}