-
Notifications
You must be signed in to change notification settings - Fork 1
/
vectordb.py
118 lines (97 loc) · 3.08 KB
/
vectordb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import lancedb
import pyarrow as pa
import json
# Module-level list holding two empty-string entries.
# NOTE(review): presumably placeholders for embedding model names — confirm.
embedding_models = [""] * 2
class LanceDBAssistant:
    """Convenience wrapper around one LanceDB table.

    The table has three columns: a fixed-size float32 ``vector``, an ``item``
    string (``search`` decodes it as JSON) and a string ``id``.
    """

    def __init__(self, dirpath, filename, n: int = 384):
        """Store the database location and table name, and build the schema.

        Args:
            dirpath: Location passed to ``lancedb.connect``.
            filename: Name of the table this assistant operates on.
            n: Length of the fixed-size ``vector`` column.
        """
        self.dirpath = dirpath
        self.filename = filename
        self.db = None  # opened lazily by connect()
        self.create_schema(n)

    def create_schema(self, n: int = 384):
        """Build the pyarrow schema and store it on ``self.schema``.

        Args:
            n: Length of the fixed-size ``vector`` column.
        """
        self.schema = pa.schema([
            pa.field("vector", pa.list_(pa.float32(), n)),
            pa.field("item", pa.string()),
            pa.field("id", pa.string()),
        ])

    def connect(self):
        """Open the database connection if it has not been opened yet."""
        if self.db is None:
            self.db = lancedb.connect(self.dirpath)

    def create(self):
        """Create the table from ``self.schema``, replacing any existing one.

        Returns:
            The newly created table.
        """
        self.connect()
        return self.db.create_table(
            self.filename, schema=self.schema, mode="overwrite"
        )

    @staticmethod
    def _id_filter(id):
        """Return an ``id = '...'`` filter with single quotes escaped.

        Doubling ``'`` keeps an id that contains quotes from terminating the
        string literal and altering the filter expression.
        """
        escaped = str(id).replace("'", "''")
        return f"id = '{escaped}'"

    def _open_existing(self):
        """Return the table if it exists, else ``None``; errors propagate."""
        self.connect()
        if self.filename in self.db.table_names():
            return self.db.open_table(self.filename)
        return None

    def open(self):
        """Open the table on a best-effort basis.

        Returns:
            The table, or ``None`` when it does not exist or could not be
            opened (the failure is printed rather than raised).
        """
        try:
            return self._open_existing()
        except Exception as exc:
            # Report the real failure instead of hiding it behind a generic
            # message; callers still receive None as before.
            print(f"Could not open table {self.filename!r}: {exc}")
            return None

    def add(self, data):
        """Append rows to the table, creating the table first if it is missing.

        Args:
            data: Rows to append, in a form accepted by the table's ``add``
                (the original author assumed a ``pyarrow.Table``).

        Returns:
            The result of ``head()`` on the table after the insert.
        """
        # Use the non-swallowing lookup: if the lookup itself fails we must
        # not fall through to create(), whose mode="overwrite" would wipe an
        # existing table.
        table = self._open_existing()
        if table is None:
            table = self.create()
        # Do not pass mode="overwrite" here: it caused a bug where every add
        # replaced the entire table contents.
        table.add(data=data)
        return table.head()

    def search(self, vector, limit: int = 5):
        """Run a vector search and decode each hit's ``item`` from JSON.

        Args:
            vector: Query vector.
            limit: Maximum number of hits to return.

        Returns:
            A list of dicts with keys ``id``, ``item`` (JSON-decoded) and
            ``_distance``; empty when the table is unavailable.
        """
        table = self.open()
        # Explicit None check: do not rely on the table object's truthiness.
        if table is None:
            return []
        rows = table.search(vector).select(['id', 'item']).limit(limit).to_list()
        return [
            {
                'id': row['id'],
                'item': json.loads(row['item']),
                '_distance': row['_distance'],
            }
            for row in rows
        ]

    def list_tables(self):
        """Return the names of all tables in the database."""
        self.connect()
        return self.db.table_names()

    def delete_table(self, filename=None):
        """Drop a table; a missing table is ignored.

        Args:
            filename: Table to drop. Defaults to this assistant's own table.
        """
        self.connect()
        name = self.filename if filename is None else filename
        return self.db.drop_table(name, ignore_missing=True)

    def get_by_id(self, id):
        """Return the first row whose ``id`` equals *id*, or ``None``.

        Only the ``id`` column is selected, so the returned row does not carry
        the ``item`` or ``vector`` values.
        """
        table = self.open()
        if table is None:
            return None
        matches = (
            table.search()
            .where(self._id_filter(id), prefilter=True)
            .select(['id'])
            .to_list()
        )
        for row in matches:
            if row['id'] == id:
                return row
        return None

    def update(self, id, item):
        """Set the ``item`` column of every row whose ``id`` equals *id*.

        Does nothing when the table is unavailable.
        """
        table = self.open()
        if table is not None:
            table.update(where=self._id_filter(id), values={"item": item})
# Example usage:
# dirpath = "tmp/sample-lancedb"
# filename = "my_table2"
# assistant = LanceDBAssistant(dirpath, filename, n=2)  # n must match the vector length
# # Create a new table
# assistant.create_schema(2)  # optional: __init__ already builds the schema
# table = assistant.create()
# # Add new data ("item" must be a JSON string because search() json-decodes it)
# data = [{"vector": [1.3, 1.4], "item": json.dumps("fizz"), "id": "1"},
#         {"vector": [9.5, 56.2], "item": json.dumps("buzz"), "id": "2"}]
# assistant.add(data)
# # Search by vector
# vector = [1.3, 1.4] # Your search vector
# results = assistant.search(vector)
# # List all tables
# tables = assistant.list_tables()
# print(results)
# Delete the table
# assistant.delete_table(filename)