1010from pgvector .psycopg import register_vector
1111from psycopg import Connection , Cursor , sql
1212
13+ from ...filter import Filter , FilterOp
1314from ..api import VectorDB
1415from .config import VectorChordConfigDict , VectorChordIndexConfig
1516
1920class VectorChord (VectorDB ):
2021 """Use psycopg instructions"""
2122
23+ thread_safe : bool = False
24+ supported_filter_types : list [FilterOp ] = [
25+ FilterOp .NonFilter ,
26+ FilterOp .NumGE ,
27+ ]
28+
2229 conn : psycopg .Connection [Any ] | None = None
2330 cursor : psycopg .Cursor [Any ] | None = None
2431
25- _unfiltered_search : sql .Composed
26- _filtered_search : sql . Composed
32+ _search : sql .Composed
33+ where_clause : str = ""
2734
2835 def __init__ (
2936 self ,
@@ -79,11 +86,11 @@ def __init__(
7986 @staticmethod
8087 def _create_connection (** kwargs ) -> tuple [Connection , Cursor ]:
8188 conn = psycopg .connect (** kwargs )
82- conn .cursor ().execute ("CREATE EXTENSION IF NOT EXISTS vchord CASCADE" )
89+ cursor = conn .cursor ()
90+ cursor .execute ("CREATE EXTENSION IF NOT EXISTS vchord CASCADE" )
8391 conn .commit ()
8492 register_vector (conn )
8593 conn .autocommit = False
86- cursor = conn .cursor ()
8794
8895 assert conn is not None , "Connection is not initialized"
8996 assert cursor is not None , "Cursor is not initialized"
@@ -101,35 +108,12 @@ def init(self) -> Generator[None, None, None]:
101108 for setting_name , setting_val in session_options .items ():
102109 command = sql .SQL ("SET {setting_name} " + "= {setting_val};" ).format (
103110 setting_name = sql .Identifier (setting_name ),
104- setting_val = sql .Identifier (str (setting_val )),
111+ setting_val = sql .Literal (str (setting_val )),
105112 )
106113 log .debug (command .as_string (self .cursor ))
107114 self .cursor .execute (command )
108115 self .conn .commit ()
109116
110- # Search query cast type: rabitq8/rabitq4 queries still accept ::vector input
111- cast_type = "vector"
112-
113- self ._filtered_search = sql .Composed (
114- [
115- sql .SQL ("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding " ).format (
116- sql .Identifier (self .table_name ),
117- ),
118- sql .SQL (self .case_config .search_param ()["metric_fun_op" ]),
119- sql .SQL (f" %s::{ cast_type } LIMIT %s::int" ),
120- ],
121- )
122-
123- self ._unfiltered_search = sql .Composed (
124- [
125- sql .SQL ("SELECT id FROM public.{} ORDER BY embedding " ).format (
126- sql .Identifier (self .table_name ),
127- ),
128- sql .SQL (self .case_config .search_param ()["metric_fun_op" ]),
129- sql .SQL (f" %s::{ cast_type } LIMIT %s::int" ),
130- ],
131- )
132-
133117 try :
134118 yield
135119 finally :
@@ -228,7 +212,7 @@ def _create_index(self):
228212 else :
229213 with_clause = sql .SQL (";" )
230214
231- full_sql = ( index_create_sql + with_clause ). join (" " )
215+ full_sql = index_create_sql + sql . SQL (" " ) + with_clause
232216 log .debug (full_sql .as_string (self .cursor ))
233217 self .cursor .execute (full_sql )
234218 self .conn .commit ()
@@ -299,26 +283,41 @@ def insert_embeddings(
299283 log .warning (f"Failed to insert data into vectorchord table ({ self .table_name } ), error: { e } " )
300284 return 0 , e
301285
286+ def _generate_search_query (self ) -> sql .Composed :
287+ # Search query cast type: rabitq8/rabitq4 queries still accept ::vector input
288+ cast_type = "vector"
289+ return sql .Composed (
290+ [
291+ sql .SQL ("SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding " ).format (
292+ table_name = sql .Identifier (self .table_name ),
293+ where_clause = sql .SQL (self .where_clause ),
294+ ),
295+ sql .SQL (self .case_config .search_param ()["metric_fun_op" ]),
296+ sql .SQL (f" %s::{ cast_type } LIMIT %s::int" ),
297+ ],
298+ )
299+
300+ def prepare_filter (self , filters : Filter ):
301+ if filters .type == FilterOp .NonFilter :
302+ self .where_clause = ""
303+ elif filters .type == FilterOp .NumGE :
304+ self .where_clause = f"WHERE { self ._primary_field } >= { filters .int_value } "
305+ else :
306+ msg = f"Not support Filter for VectorChord - { filters } "
307+ raise ValueError (msg )
308+
309+ self ._search = self ._generate_search_query ()
310+
302311 def search_embedding (
303312 self ,
304313 query : list [float ],
305314 k : int = 100 ,
306- filters : dict | None = None ,
307315 timeout : int | None = None ,
316+ ** kwargs : Any ,
308317 ) -> list [int ]:
309318 assert self .conn is not None , "Connection is not initialized"
310319 assert self .cursor is not None , "Cursor is not initialized"
311320
312321 q = np .asarray (query )
313- if filters :
314- gt = filters .get ("id" )
315- result = self .cursor .execute (
316- self ._filtered_search ,
317- (gt , q , k ),
318- prepare = True ,
319- binary = True ,
320- )
321- else :
322- result = self .cursor .execute (self ._unfiltered_search , (q , k ), prepare = True , binary = True )
323-
322+ result = self .cursor .execute (self ._search , (q , k ), prepare = True , binary = True )
324323 return [int (i [0 ]) for i in result .fetchall ()]
0 commit comments