-
Notifications
You must be signed in to change notification settings - Fork 299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add tidb vector store #656
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,322 @@ | ||
import type mysql from "mysql2/promise"; | ||
import type { PoolOptions, RowDataPacket } from "mysql2/promise"; | ||
|
||
import type { | ||
VectorStore, | ||
VectorStoreQuery, | ||
VectorStoreQueryResult, | ||
} from "./types.js"; | ||
|
||
import type { GenericFileSystem } from "@llamaindex/env"; | ||
import type { BaseNode, Metadata } from "../../Node.js"; | ||
import { Document, MetadataMode } from "../../Node.js"; | ||
|
||
export const TIDB_VECTOR_TABLE = "llamaindex_embedding"; | ||
|
||
interface DocumentEmbedding extends RowDataPacket { | ||
id: string; | ||
document: string; | ||
metadata: Metadata; | ||
embeddings: number[]; | ||
score: number; | ||
} | ||
|
||
/** | ||
* Provides support for writing and querying vector data in TiDB. | ||
*/ | ||
export class TiDBVectorStore implements VectorStore { | ||
storesText: boolean = true; | ||
|
||
private namespace: string = ""; | ||
private tableName: string = TIDB_VECTOR_TABLE; | ||
private poolOptions: PoolOptions = {}; | ||
private dimensions: number = 1536; | ||
|
||
private db?: mysql.Pool; | ||
|
||
/** | ||
* Constructs a new instance of the TiDBVectorStore | ||
* | ||
* @param {object} config - The configuration settings for the instance. | ||
* @param {string} config.tableName - The name of the table (optional). Defaults to TIDB_VECTOR_TABLE. | ||
* @param {number} config.dimensions - The dimensions of the embedding model. | ||
* @param {string} config.poolOptions - The pool options for the TiDB connection. | ||
*/ | ||
constructor(config?: { | ||
namespace?: string; | ||
tableName?: string; | ||
dimensions?: number; | ||
poolOptions?: PoolOptions; | ||
client?: mysql.Pool; | ||
}) { | ||
this.tableName = config?.tableName ?? TIDB_VECTOR_TABLE; | ||
this.namespace = config?.namespace ?? ""; | ||
this.dimensions = config?.dimensions ?? 1536; | ||
this.poolOptions = config?.poolOptions ?? {}; | ||
Comment on lines
+52
to
+55
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can use |
||
if (config?.client) { | ||
this.db = config.client; | ||
} | ||
} | ||
|
||
/** | ||
* Setter for the namespace property. | ||
* Using a namespace allows for simple segregation of vector data, | ||
* e.g. by user, source, or access-level. | ||
* Leave/set blank to ignore the namespace value when querying. | ||
* @param namespace Name for the namespace. | ||
*/ | ||
setNamespace(namespace: string) { | ||
const name = this.formatNamespace(namespace); | ||
if (name.length > 64 || name.length == 0) { | ||
throw new Error( | ||
"Invalid namespace: " + name + ", must be 1-64 characters length.", | ||
); | ||
} | ||
this.namespace = name; | ||
} | ||
|
||
/** | ||
* Getter for the namespace property. | ||
* Using a namespace allows for simple segregation of vector data, | ||
* e.g. by user, source, or access-level. | ||
* Leave/set blank to ignore the namespace value when querying. | ||
* @returns The currently-set namespace value. Default is empty string. | ||
*/ | ||
getNamespace(): string { | ||
return this.namespace; | ||
} | ||
|
||
private async getDb(): Promise<mysql.Pool> { | ||
if (!this.db) { | ||
try { | ||
const { createPool } = await import("mysql2/promise"); | ||
// Create DB connection | ||
// Read connection params from env - see comment block above | ||
const db = createPool(this.poolOptions); | ||
|
||
// Check schema, table(s), index(es) | ||
await this.checkSchema(db); | ||
|
||
// All good? Keep the connection reference | ||
this.db = db; | ||
} catch (err: any) { | ||
console.error(err); | ||
return Promise.reject(err); | ||
} | ||
} | ||
|
||
return Promise.resolve(this.db); | ||
Mini256 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
private async checkSchema(db: mysql.Pool) { | ||
const tbl = `CREATE TABLE IF NOT EXISTS ${this.tableName}( | ||
namespace VARCHAR(64), | ||
id BINARY(16), | ||
external_id VARCHAR(100), | ||
document TEXT, | ||
metadata JSON, | ||
embeddings VECTOR(${this.dimensions}) NOT NULL COMMENT 'hnsw(distance=cosine)', | ||
PRIMARY KEY (namespace, id), | ||
) | ||
PARTITION BY LIST COLUMNS (namespace) ( | ||
PARTITION p_default DEFAULT | ||
);`; | ||
await db.query(tbl); | ||
|
||
if (this.namespace.length == 0) { | ||
const partition = `ALTER TABLE ${this.tableName} ADD PARTITION IF NOT EXISTS (PARTITION p_${this.namespace} VALUES IN ('${this.namespace}'));`; | ||
await db.query(partition); | ||
} | ||
|
||
return db; | ||
} | ||
|
||
/** | ||
* Connects to the database specified in environment vars. | ||
* This method also checks and creates the vector extension, | ||
* the destination table and indexes if not found. | ||
* @returns A connection to the database, or the error encountered while connecting/setting up. | ||
*/ | ||
client() { | ||
return this.getDb(); | ||
} | ||
|
||
/** | ||
* Delete all vector records for the specified namespace. | ||
* @returns The result of the delete query. | ||
*/ | ||
async clearNamespace() { | ||
const sql: string = `DELETE FROM ${this.tableName} WHERE namespace = $1`; | ||
const db = await this.getDb(); | ||
return db.query(sql, [this.namespace]); | ||
} | ||
|
||
private getDataToInsert(embeddingResults: BaseNode<Metadata>[]) { | ||
const result = []; | ||
for (let index = 0; index < embeddingResults.length; index++) { | ||
const row = embeddingResults[index]; | ||
|
||
const id: any = row.id_.length ? row.id_ : null; | ||
const externalId = ""; | ||
const document = row.getContent(MetadataMode.EMBED); | ||
const meta = row.metadata || {}; | ||
meta.create_date = new Date(); | ||
const embeddings = this.vectorToSQL(row.getEmbedding()); | ||
|
||
const params = [ | ||
this.namespace, | ||
id, | ||
externalId, | ||
document, | ||
meta, | ||
embeddings, | ||
]; | ||
|
||
result.push(params); | ||
} | ||
return result; | ||
} | ||
|
||
/** | ||
* Adds vector record(s) to the table. | ||
* @param embeddingResults The Nodes to be inserted, optionally including metadata tuples. | ||
* @returns A list of zero or more id values for the created records. | ||
*/ | ||
async add(embeddingResults: BaseNode<Metadata>[]): Promise<string[]> { | ||
if (embeddingResults.length == 0) { | ||
console.debug("Empty list sent to TiDBVectorStore::add"); | ||
return Promise.resolve([]); | ||
} | ||
|
||
const sql: string = `INSERT INTO ${this.tableName} (namespace, id, external_id, document, metadata, embeddings) VALUES ($1, $2, $3, $4, $5, $6)`; | ||
const db = await this.getDb(); | ||
const data = this.getDataToInsert(embeddingResults); | ||
|
||
const ret: string[] = []; | ||
for (let index = 0; index < data.length; index++) { | ||
const params = data[index]; | ||
try { | ||
const [rows] = await db.query<DocumentEmbedding[]>(sql, params); | ||
if (rows.length) { | ||
const id = rows[0].id as string; | ||
ret.push(id); | ||
} | ||
} catch (err) { | ||
const msg = `${err}`; | ||
console.log(msg, err); | ||
} | ||
} | ||
|
||
return Promise.resolve(ret); | ||
} | ||
|
||
/** | ||
* Deletes a single record from the database by id. | ||
* @param refDocId Unique identifier for the record to delete. | ||
* @param deleteKwargs Required by VectorStore interface. Currently ignored. | ||
* @returns Promise that resolves if the delete query did not throw an error. | ||
*/ | ||
async delete(refDocId: string, deleteKwargs?: any): Promise<void> { | ||
const namespaceCriteria = this.namespace.length ? "AND namespace = $2" : ""; | ||
const sql: string = `DELETE FROM ${this.tableName} WHERE id = $1 ${namespaceCriteria}`; | ||
const db = await this.getDb(); | ||
const params = this.namespace.length | ||
? [refDocId, this.namespace] | ||
: [refDocId]; | ||
await db.query(sql, params); | ||
return Promise.resolve(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no-op |
||
} | ||
|
||
/** | ||
* Query the vector store for the closest matching data to the query embeddings | ||
* @param query The VectorStoreQuery to be used | ||
* @param options Required by VectorStore interface. Currently ignored. | ||
* @returns Zero or more Document instances with data from the vector store. | ||
*/ | ||
async query( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're using metadataDictToNode and nodeToMetadata from "./utils.js" to de/serialize the node in the metadata. You can reference the other vector store implementations to have a look into this. |
||
query: VectorStoreQuery, | ||
options?: any, | ||
): Promise<VectorStoreQueryResult> { | ||
const embedding = this.vectorToSQL(query.queryEmbedding); | ||
const max = query.similarityTopK ?? 2; | ||
const whereClauses = options.namespace ? ["namespace = $2"] : []; | ||
|
||
const params: Array<string | number> = options.namespace | ||
? [embedding, options.namespace] | ||
: [embedding]; | ||
|
||
query.filters?.filters.forEach((filter, index) => { | ||
const paramIndex = params.length + 1; | ||
whereClauses.push(`metadata->>'${filter.key}' = $${paramIndex}`); | ||
params.push(filter.value); | ||
}); | ||
|
||
const where = | ||
whereClauses.length > 0 ? `WHERE ${whereClauses.join(" AND ")}` : ""; | ||
|
||
const sql = `SELECT | ||
v.*, | ||
VEC_COSINE_DISTINCE(embeddings, $1) AS score | ||
FROM ${this.tableName} v | ||
${where} | ||
ORDER BY score | ||
LIMIT ${max} | ||
`; | ||
|
||
const db = await this.getDb(); | ||
const [rows] = await db.query<DocumentEmbedding[]>(sql, params); | ||
const nodes = rows.map((row) => { | ||
return new Document({ | ||
id_: row.id, | ||
text: row.document, | ||
metadata: row.metadata, | ||
embedding: row.embeddings, | ||
}); | ||
}); | ||
|
||
const ret = { | ||
nodes: nodes, | ||
similarities: rows.map((row) => row.score), | ||
ids: rows.map((row) => row.id), | ||
}; | ||
|
||
return Promise.resolve(ret); | ||
} | ||
|
||
/** | ||
* Required by VectorStore interface. Currently ignored. | ||
* @param persistPath | ||
* @param fs | ||
* @returns Resolved Promise. | ||
*/ | ||
persist( | ||
persistPath: string, | ||
fs?: GenericFileSystem | undefined, | ||
): Promise<void> { | ||
return Promise.resolve(); | ||
} | ||
|
||
/** | ||
* Converts a vector to a SQL string. | ||
* @param vector The vector to convert. | ||
*/ | ||
vectorToSQL(vector?: number[]): string { | ||
return "[" + vector?.join(",") + "]"; | ||
} | ||
|
||
/** | ||
* Formats a namespace string to a valid SQL identifier. | ||
* @param namespace | ||
*/ | ||
formatNamespace(namespace: string): string { | ||
namespace = namespace.toLowerCase(); | ||
|
||
// Replace non-alphanumeric characters with underscores. | ||
namespace = namespace.replace(/\W+/g, "_"); | ||
|
||
// Remove leading/trailing underscores. | ||
namespace = namespace.replace(/^_+|_+$/g, ""); | ||
|
||
return namespace; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did you check if the data generated is compatible with the python version (see https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/vector_stores/llama-index-vector-stores-tidbvector/llama_index/vector_stores/tidbvector/base.py) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will make some adjustments to be consistent with the Python version.