Skip to content

Commit 6f4416d

Browse files
committed
feat: enhance VectorChord with improved type safety and search functionality 🎉✨
- Updated `quantization_type` to use `Literal` for better type validation.
- Refactored the search methods to streamline query generation and filtering.
- Added support for dynamic WHERE clauses in search queries. 🔍
1 parent ccb4f7c commit 6f4416d

File tree

2 files changed

+42
-43
lines changed

2 files changed

+42
-43
lines changed

vectordb_bench/backend/clients/vectorchord/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import abstractmethod
2-
from typing import LiteralString, TypedDict
2+
from typing import Literal, LiteralString, TypedDict
33

44
from pydantic import BaseModel, SecretStr
55

@@ -64,7 +64,7 @@ class VectorChordIndexConfig(BaseModel, DBCaseConfig):
6464
metric_type: MetricType | None = None
6565
create_index_before_load: bool = False
6666
create_index_after_load: bool = True
67-
quantization_type: str = "vector" # vector, halfvec, rabitq8, rabitq4
67+
quantization_type: Literal["vector", "halfvec", "rabitq8", "rabitq4"] = "vector"
6868

6969
def parse_metric(self) -> str:
7070
ops = _METRIC_OPS.get(self.quantization_type, _METRIC_OPS["vector"])

vectordb_bench/backend/clients/vectorchord/vectorchord.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pgvector.psycopg import register_vector
1111
from psycopg import Connection, Cursor, sql
1212

13+
from ...filter import Filter, FilterOp
1314
from ..api import VectorDB
1415
from .config import VectorChordConfigDict, VectorChordIndexConfig
1516

@@ -19,11 +20,17 @@
1920
class VectorChord(VectorDB):
2021
"""Use psycopg instructions"""
2122

23+
thread_safe: bool = False
24+
supported_filter_types: list[FilterOp] = [
25+
FilterOp.NonFilter,
26+
FilterOp.NumGE,
27+
]
28+
2229
conn: psycopg.Connection[Any] | None = None
2330
cursor: psycopg.Cursor[Any] | None = None
2431

25-
_unfiltered_search: sql.Composed
26-
_filtered_search: sql.Composed
32+
_search: sql.Composed
33+
where_clause: str = ""
2734

2835
def __init__(
2936
self,
@@ -79,11 +86,11 @@ def __init__(
7986
@staticmethod
8087
def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
8188
conn = psycopg.connect(**kwargs)
82-
conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vchord CASCADE")
89+
cursor = conn.cursor()
90+
cursor.execute("CREATE EXTENSION IF NOT EXISTS vchord CASCADE")
8391
conn.commit()
8492
register_vector(conn)
8593
conn.autocommit = False
86-
cursor = conn.cursor()
8794

8895
assert conn is not None, "Connection is not initialized"
8996
assert cursor is not None, "Cursor is not initialized"
@@ -101,35 +108,12 @@ def init(self) -> Generator[None, None, None]:
101108
for setting_name, setting_val in session_options.items():
102109
command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
103110
setting_name=sql.Identifier(setting_name),
104-
setting_val=sql.Identifier(str(setting_val)),
111+
setting_val=sql.Literal(str(setting_val)),
105112
)
106113
log.debug(command.as_string(self.cursor))
107114
self.cursor.execute(command)
108115
self.conn.commit()
109116

110-
# Search query cast type: rabitq8/rabitq4 queries still accept ::vector input
111-
cast_type = "vector"
112-
113-
self._filtered_search = sql.Composed(
114-
[
115-
sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format(
116-
sql.Identifier(self.table_name),
117-
),
118-
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
119-
sql.SQL(f" %s::{cast_type} LIMIT %s::int"),
120-
],
121-
)
122-
123-
self._unfiltered_search = sql.Composed(
124-
[
125-
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
126-
sql.Identifier(self.table_name),
127-
),
128-
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
129-
sql.SQL(f" %s::{cast_type} LIMIT %s::int"),
130-
],
131-
)
132-
133117
try:
134118
yield
135119
finally:
@@ -228,7 +212,7 @@ def _create_index(self):
228212
else:
229213
with_clause = sql.SQL(";")
230214

231-
full_sql = (index_create_sql + with_clause).join(" ")
215+
full_sql = index_create_sql + sql.SQL(" ") + with_clause
232216
log.debug(full_sql.as_string(self.cursor))
233217
self.cursor.execute(full_sql)
234218
self.conn.commit()
@@ -299,26 +283,41 @@ def insert_embeddings(
299283
log.warning(f"Failed to insert data into vectorchord table ({self.table_name}), error: {e}")
300284
return 0, e
301285

286+
def _generate_search_query(self) -> sql.Composed:
    """Assemble the nearest-neighbour SELECT statement for the current filter.

    The statement selects ``id`` from ``public.<table_name>``, splices in
    whatever text ``self.where_clause`` currently holds, and orders by the
    distance between ``embedding`` and the query vector using the operator
    reported by ``self.case_config.search_param()["metric_fun_op"]``.

    Two ``%s`` placeholders are left unbound, in this order: the query
    vector and the row limit.

    Returns:
        The composed statement, ready to be passed to ``cursor.execute``.
    """
    # The query vector is always cast to ``vector``: rabitq8/rabitq4 queries
    # still accept ::vector input.
    cast_type = "vector"

    select_part = sql.SQL("SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ").format(
        table_name=sql.Identifier(self.table_name),
        where_clause=sql.SQL(self.where_clause),
    )
    operator_part = sql.SQL(self.case_config.search_param()["metric_fun_op"])
    tail_part = sql.SQL(f" %s::{cast_type} LIMIT %s::int")

    return sql.Composed([select_part, operator_part, tail_part])
299+
300+
def prepare_filter(self, filters: Filter):
    """Translate ``filters`` into a WHERE clause and rebuild the search query.

    ``FilterOp.NonFilter`` clears the clause; ``FilterOp.NumGE`` restricts
    results to rows whose ``self._primary_field`` is greater than or equal to
    ``filters.int_value``. After the clause is stored in ``self.where_clause``
    the statement in ``self._search`` is regenerated to include it.

    Args:
        filters: The filter to apply to subsequent searches.

    Raises:
        ValueError: If ``filters.type`` is neither ``NonFilter`` nor ``NumGE``.
            In that case ``self.where_clause`` and ``self._search`` are left
            untouched.
    """
    filter_kind = filters.type

    if filter_kind == FilterOp.NumGE:
        # NOTE(review): ``_primary_field`` is defined outside this view;
        # presumably it names the id column -- confirm.
        clause = f"WHERE {self._primary_field} >= {filters.int_value}"
    elif filter_kind == FilterOp.NonFilter:
        clause = ""
    else:
        msg = f"Not support Filter for VectorChord - {filters}"
        raise ValueError(msg)

    self.where_clause = clause
    # The cached statement embeds the clause text, so it must be rebuilt
    # whenever the clause changes.
    self._search = self._generate_search_query()
310+
302311
def search_embedding(
303312
self,
304313
query: list[float],
305314
k: int = 100,
306-
filters: dict | None = None,
307315
timeout: int | None = None,
316+
**kwargs: Any,
308317
) -> list[int]:
309318
assert self.conn is not None, "Connection is not initialized"
310319
assert self.cursor is not None, "Cursor is not initialized"
311320

312321
q = np.asarray(query)
313-
if filters:
314-
gt = filters.get("id")
315-
result = self.cursor.execute(
316-
self._filtered_search,
317-
(gt, q, k),
318-
prepare=True,
319-
binary=True,
320-
)
321-
else:
322-
result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
323-
322+
result = self.cursor.execute(self._search, (q, k), prepare=True, binary=True)
324323
return [int(i[0]) for i in result.fetchall()]

0 commit comments

Comments
 (0)