forked from zilliztech/VectorDBBench
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
262 lines (209 loc) · 7.8 KB
/
api.py
File metadata and controls
262 lines (209 loc) · 7.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from abc import ABC, abstractmethod
from contextlib import contextmanager
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, model_validator

from vectordb_bench.backend.filter import Filter, FilterOp
class MetricType(StrEnum):
    """Distance / similarity metric identifiers passed to backend clients.

    Values are the literal strings the backends expect, so members must
    not be renamed or revalued casually.
    """

    L2 = "L2"            # Euclidean (L2) distance
    COSINE = "COSINE"    # cosine similarity
    IP = "IP"            # inner product
    DP = "DP"            # dot product — kept distinct from IP; presumably a vendor-specific spelling, confirm at call sites
    HAMMING = "HAMMING"  # Hamming distance (binary vectors)
    JACCARD = "JACCARD"  # Jaccard distance (binary/set vectors)
class IndexType(StrEnum):
    """Vector index type identifiers across all supported databases.

    Values are the literal strings sent to each backend (note the mixed
    casing: some backends expect lowercase names).
    """

    HNSW = "HNSW"
    HNSW_SQ = "HNSW_SQ"
    HNSW_BQ = "HNSW_BQ"
    HNSW_PQ = "HNSW_PQ"
    HNSW_PRQ = "HNSW_PRQ"
    DISKANN = "DISKANN"
    # NOTE(review): same value as DISKANN, so this member is an enum *alias* —
    # IndexType.STREAMING_DISKANN is IndexType.DISKANN, and it does not appear
    # as a separate member when iterating the enum. Confirm this is intended.
    STREAMING_DISKANN = "DISKANN"
    IVFFlat = "IVF_FLAT"
    IVFPQ = "IVF_PQ"
    IVFBQ = "IVF_BQ"
    IVFSQ8 = "IVF_SQ8"
    IVF_RABITQ = "IVF_RABITQ"
    Flat = "FLAT"
    AUTOINDEX = "AUTOINDEX"
    # Elasticsearch-family index names (lowercase by backend convention).
    ES_HNSW = "hnsw"
    ES_HNSW_INT8 = "int8_hnsw"
    ES_HNSW_INT4 = "int4_hnsw"
    ES_HNSW_BBQ = "bbq_hnsw"
    TES_VSEARCH = "vsearch"
    ES_IVFFlat = "ivfflat"
    # GPU-accelerated index types.
    GPU_IVF_FLAT = "GPU_IVF_FLAT"
    GPU_BRUTE_FORCE = "GPU_BRUTE_FORCE"
    GPU_IVF_PQ = "GPU_IVF_PQ"
    GPU_CAGRA = "GPU_CAGRA"
    SCANN = "scann"
    VCHORDRQ = "vchordrq"
    VCHORDG = "vchordg"
    SCANN_MILVUS = "SCANN_MILVUS"
    Hologres_HGraph = "HGraph"
    Hologres_Graph = "Graph"
    NONE = "NONE"  # no index / brute-force fallback — TODO confirm semantics with callers
class SQType(StrEnum):
    """Scalar-quantization precision identifiers.

    Presumably consumed by SQ-capable index configs (e.g. HNSW_SQ) —
    confirm against the case-config classes that reference it.
    """

    SQ4U = "SQ4U"  # 4-bit unsigned scalar quantization
    SQ6 = "SQ6"    # 6-bit scalar quantization
    SQ8 = "SQ8"    # 8-bit scalar quantization
    BF16 = "BF16"  # bfloat16
    FP16 = "FP16"  # float16
    FP32 = "FP32"  # float32 (no quantization)
class DBConfig(ABC, BaseModel):
    """DBConfig contains the connection info of vector database

    Args:
        db_label(str): label to distinguish different types of DB of the same database.
            MilvusConfig.db_label = 2c8g
            MilvusConfig.db_label = 16c64g
            ZillizCloudConfig.db_label = 1cu-perf
    """

    db_label: str = ""
    version: str = ""
    note: str = ""

    @staticmethod
    def common_short_configs() -> list[str]:
        """
        short input, such as `db_label`, `version`
        """
        return ["version", "db_label"]

    @staticmethod
    def common_long_configs() -> list[str]:
        """
        long input, such as `note`
        """
        return ["note"]

    @abstractmethod
    def to_dict(self) -> dict:
        """Serialize the connection info for the backend client."""
        raise NotImplementedError

    @model_validator(mode="before")
    @classmethod
    def not_empty_field(cls, data: Any) -> Any:
        """Reject empty-string values for all non-common fields.

        Runs before field parsing. Non-dict payloads are returned
        unchanged and left for pydantic's normal validation to handle.

        Raises:
            ValueError: if any required (non-common) field is an empty string.
        """
        # Fix: the original annotated this `any`, which is the builtin
        # function, not a type — `typing.Any` is the correct annotation.
        if not isinstance(data, dict):
            return data
        # The descriptive common fields (db_label/version/note) may stay empty.
        skip = set(cls.common_short_configs()) | set(cls.common_long_configs())
        for field_name, v in data.items():
            if field_name in skip:
                continue
            if isinstance(v, str) and not v:
                raise ValueError("Empty string!")
        return data
class DBCaseConfig(ABC):
    """Case specific vector database configs, usually used for index params like HNSW"""

    @abstractmethod
    def index_param(self) -> dict:
        """Return the parameters used when building the index."""
        raise NotImplementedError

    @abstractmethod
    def search_param(self) -> dict:
        """Return the parameters used at query time."""
        raise NotImplementedError
class EmptyDBCaseConfig(BaseModel, DBCaseConfig):
    """EmptyDBCaseConfig will be used if the vector database has no case specific configs"""

    # Placeholder field; this config intentionally carries no real settings.
    null: str | None = None

    def search_param(self) -> dict:
        # Nothing to configure at query time.
        return {}

    def index_param(self) -> dict:
        # Nothing to configure at index-build time.
        return {}
class VectorDB(ABC):
    """Each VectorDB will be __init__ once for one case, the object will be copied into multiple processes.

    In each process, the benchmark cases ensure VectorDB.init() calls before any other methods operations

    insert_embeddings, search_embedding, and, optimize will be timed for each call.

    Examples:
        >>> milvus = Milvus()
        >>> with milvus.init():
        >>>     milvus.insert_embeddings()
        >>>     milvus.search_embedding()
    """

    "The filtering types supported by the VectorDB Client, default only non-filter"
    supported_filter_types: list[FilterOp] = [FilterOp.NonFilter]

    # Display/registry name of the client; subclasses are expected to set it.
    name: str = ""

    # Whether the client can share a single connection across threads.
    # If False, concurrent runners will deep-copy the instance and call
    # init() per thread instead of sharing the parent connection.
    thread_safe: bool = True

    @classmethod
    def filter_supported(cls, filters: Filter) -> bool:
        """Ensure that the filters are supported before testing filtering cases."""
        return filters.type in cls.supported_filter_types

    def prepare_filter(self, filters: Filter):
        """The vector database is allowed to pre-prepare different filter conditions
        to reduce redundancy during the testing process.
        (All search tests in a case use consistent filtering conditions.)

        Default is a no-op; subclasses override when pre-building helps.
        """
        return

    @abstractmethod
    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig | None,
        collection_name: str,
        drop_old: bool = False,
        **kwargs,
    ) -> None:
        """Initialize wrapper around the vector database client.

        Please drop the existing collection if drop_old is True. And create collection
        if collection not in the Vector Database

        Args:
            dim(int): the dimension of the dataset
            db_config(dict): configs to establish connections with the vector database
            db_case_config(DBCaseConfig | None): case specific configs for indexing and searching
            collection_name(str): name of the collection to create or reuse
            drop_old(bool): whether to drop the existing collection of the dataset.
        """
        raise NotImplementedError

    @abstractmethod
    @contextmanager
    def init(self) -> None:
        """create and destroy connections to database.

        Why contextmanager:

            In multiprocessing search tasks, vectordbbench might init
            totally hundreds of thousands of connections with DB server.

            Too many connections may drain local FDs or server connection resources.

        If the DB client doesn't have `close()` method, just set the object to None.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
        """
        raise NotImplementedError

    def need_normalize_cosine(self) -> bool:
        """Whether this database needs the dataset normalized to support COSINE"""
        # Default False; subclasses that only support IP/L2 override this.
        return False

    @abstractmethod
    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        labels_data: list[str] | None = None,
        **kwargs,
    ) -> tuple[int, Exception]:
        """Insert the embeddings to the vector database. The default number of embeddings for
        each insert_embeddings is 5000.

        Args:
            embeddings(list[list[float]]): list of embedding to add to the vector database.
            metadata(list[int]): metadata associated with the embeddings, for filtering.
            labels_data(list[str] | None): optional label values for label-filter cases.
            **kwargs(Any): vector database specific parameters.

        Returns:
            tuple[int, Exception]: inserted data count, and the exception if insertion failed
                (presumably None on success — confirm against implementations).
        """
        raise NotImplementedError

    @abstractmethod
    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
    ) -> list[int]:
        """Get k most similar embeddings to query vector.

        Filtering conditions, when needed, are expected to be set up
        beforehand via prepare_filter().

        Args:
            query(list[float]): query embedding to look up documents similar to.
            k(int): Number of most similar embeddings to return. Defaults to 100.

        Returns:
            list[int]: list of k most similar embeddings IDs to the query embedding.
        """
        raise NotImplementedError

    @abstractmethod
    def optimize(self, data_size: int | None = None):
        """optimize will be called between insertion and search in performance cases.

        Should be blocked until the vectorDB is ready to be tested on
        heavy performance cases.

        Time(insert the dataset) + Time(optimize) will be recorded as "load_duration" metric

        Optimize's execution time is limited, the limited time is based on cases.
        """
        raise NotImplementedError