
Commit d327f14

MattiaMolon and r-dh authored
feat: allow for multiple metadata filters on list fields (#179)
* feat: allow for multiple field filters when a field is a list
* fix: embedding serialization for llama model
* fix: isolate RAGLite config for metadata filtering tests
* fix: allow litellm to select allowed parameters per model and reintroduce temperature=0
* fix: few-shot prompt engineering
* fix: completion thread safety
* fix: restore original self-query prompt wording with few-shot examples
* fix: preserve insertion order in self-query metadata deduplication
* test: add tests for PostgreSQL OR conditions

---------

Co-authored-by: r-dh <remydheygere@gmail.com>
1 parent 54f4f80 commit d327f14

13 files changed

Lines changed: 442 additions & 77 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -78,3 +78,6 @@ uv.lock
 
 # VS Code
 .vscode/
+
+# evals
+evals/

README.md

Lines changed: 9 additions & 0 deletions
@@ -201,6 +201,7 @@ insert_documents(documents, config=my_config)
 
 > [!TIP]
 > 📝 Documents can include metadata by passing keyword arguments to `Document.from_text()` or `Document.from_path()`. This metadata can later be used for filtering during retrieval.
+> For list values, metadata is stored as-is (e.g. `domain=["open", "music"]`).
 
 You may also want to expand the document metadata before insertion:
 
@@ -308,6 +309,14 @@ chunk_ids_hybrid, _ = hybrid_search(
     user_prompt, num_results=20, metadata_filter={"topic": "physics"}, config=my_config
 ) # Filter results to only include chunks from documents with topic="physics" (works with any search method)
 
+# Multi-value filter in one field uses OR semantics:
+chunk_ids_or, _ = hybrid_search(
+    user_prompt,
+    num_results=20,
+    metadata_filter={"domain": ["open", "music"]},
+    config=my_config,
+) # Returns chunks where domain includes "open" OR "music".
+
 # Retrieve chunks
 from raglite import retrieve_chunks
 
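Read together, the two README additions describe the end-to-end flow: attach list-valued metadata at insertion time, then filter with OR semantics within a field and AND semantics across fields. A minimal sketch of that flow, assuming the `Document.from_text`, `insert_documents`, and `hybrid_search` APIs shown in the README and an existing `my_config`:

from raglite import Document, hybrid_search, insert_documents

# Metadata is passed as keyword arguments; list values are stored as-is.
documents = [
    Document.from_text("An open collection of music recordings.", domain=["open", "music"]),
    Document.from_text("A proprietary physics corpus.", domain=["closed"], topic="physics"),
]
insert_documents(documents, config=my_config)

# OR within a field, AND across fields: matches documents whose domain
# includes "open" OR "music", AND whose topic is "physics".
chunk_ids, _ = hybrid_search(
    "Which corpora can I reuse?",
    num_results=20,
    metadata_filter={"domain": ["open", "music"], "topic": "physics"},
    config=my_config,
)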

src/raglite/_database.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@
 MetadataJSON = JSON().with_variant(JSONB(), "postgresql")
 
 
-def _adapt_metadata(metadata: Any) -> dict[str, MetadataValue | list[MetadataValue]]:
+def _adapt_metadata(metadata: Any) -> dict[str, list[MetadataValue]]:
     """Adapt metadata to the format expected by the database."""
     if not metadata:
         return {}
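The narrowed return type (`dict[str, list[MetadataValue]]`) implies that every metadata value is normalized to a list before it reaches the database. A rough, standalone sketch of that normalization (illustrative only, not the committed body of `_adapt_metadata`):

from typing import Any

def adapt_metadata_sketch(metadata: dict[str, Any] | None) -> dict[str, list[Any]]:
    # Wrap scalar values in single-element lists; keep list values as-is.
    if not metadata:
        return {}
    return {key: value if isinstance(value, list) else [value] for key, value in metadata.items()}

print(adapt_metadata_sketch({"topic": "physics", "domain": ["open", "music"]}))
# {'topic': ['physics'], 'domain': ['open', 'music']}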

src/raglite/_delete.py

Lines changed: 10 additions & 15 deletions
@@ -1,13 +1,11 @@
 """Delete documents from the database."""
 
-import json
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Any, Literal
 
 from filelock import FileLock
-from sqlalchemy import delete, func, text, update
-from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy import delete, text, update
 from sqlalchemy.engine import make_url
 from sqlalchemy.orm import load_only
 from sqlalchemy.orm.attributes import flag_modified
@@ -21,10 +19,10 @@
     Eval,
     IndexMetadata,
     Metadata,
-    _adapt_metadata,
     create_database_engine,
 )
 from raglite._insert import _aggregate_metadata_from_documents
+from raglite._metadata_filter import build_metadata_filter_condition
 from raglite._typing import DocumentId
 
 
@@ -49,17 +47,14 @@ def _get_documents_with_metadata(
     metadata_filter: dict[str, Any], session: Session
 ) -> list[DocumentId]:
     """Get document IDs matching a metadata filter."""
-    metadata_filter = _adapt_metadata(metadata_filter)
-
-    # Determine the filter condition based on the database engine
-    if session.get_bind().dialect.name == "postgresql":
-        condition = col(Document.metadata_).cast(JSONB).op("@>")(metadata_filter)  # type: ignore[attr-defined]
-    else:
-        condition = func.json_contains(
-            col(Document.metadata_), func.json(json.dumps(metadata_filter))
-        )
-
-    statement = select(Document.id).where(condition)
+    condition = build_metadata_filter_condition(
+        Document.metadata_,
+        metadata_filter,
+        dialect=session.get_bind().dialect.name,
+    )
+    statement = select(Document.id)
+    if condition is not None:
+        statement = statement.where(condition)
 
     return list(session.exec(statement).all())
 
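The refactor routes all dialect handling through the new helper, so callers only build the condition and attach it when it is not None. A hypothetical usage sketch (it assumes `Document` is importable from `raglite._database`, as it is used in `_delete.py`, and compiles the statement for PostgreSQL without executing it):

from sqlalchemy.dialects import postgresql
from sqlmodel import select

from raglite._database import Document
from raglite._metadata_filter import build_metadata_filter_condition

condition = build_metadata_filter_condition(
    Document.metadata_,
    {"domain": ["open", "music"], "topic": "physics"},
    dialect="postgresql",
)
statement = select(Document.id)
if condition is not None:  # None means "no filter", so no WHERE clause is added
    statement = statement.where(condition)
print(statement.compile(dialect=postgresql.dialect()))  # inspect the generated SQL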

src/raglite/_embed.py

Lines changed: 7 additions & 2 deletions
@@ -1,6 +1,7 @@
 """String embedder."""
 
 from functools import partial
+from threading import Lock
 from typing import Literal
 
 import numpy as np
@@ -12,6 +13,8 @@
 from raglite._litellm import LlamaCppPythonLLM
 from raglite._typing import FloatMatrix, IntVector
 
+LLAMA_EMBED_LOCK = Lock()
+
 
 def embed_strings_with_late_chunking(  # noqa: C901,PLR0915
     sentences: list[str], *, config: RAGLiteConfig | None = None
@@ -116,7 +119,8 @@ def _create_segment(
         # Get the token embeddings of the entire segment, including preamble and content.
         segment_start_index, content_start_index, segment_end_index = segment
         segment_sentences = sentences[segment_start_index:segment_end_index]
-        segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
+        with LLAMA_EMBED_LOCK:
+            segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
         # Split the segment embeddings into embedding matrices per sentence using the largest
         # remainder method.
         segment_tokens = num_tokens[segment_start_index:segment_end_index]
@@ -151,7 +155,8 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
         embedder = LlamaCppPythonLLM.llm(
             config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
         )
-        embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
+        with LLAMA_EMBED_LOCK:
+            embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
     else:
         # Use LiteLLM's API to embed the batch of strings.
         response = embedding(config.embedder, string_batch)
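The new module-level `LLAMA_EMBED_LOCK` serializes calls into the shared llama-cpp-python embedder, which this change treats as unsafe to call from multiple threads at once. A generic illustration of that pattern (the `embed` function below is a hypothetical stand-in, not RAGLite's embedder):

from concurrent.futures import ThreadPoolExecutor
from threading import Lock

EMBED_LOCK = Lock()

def embed(text: str) -> list[float]:
    # Stand-in for a non-thread-safe embedding call.
    return [float(len(text))]

def embed_thread_safe(text: str) -> list[float]:
    with EMBED_LOCK:  # only one thread calls the embedder at a time
        return embed(text)

with ThreadPoolExecutor(max_workers=4) as pool:
    print(list(pool.map(embed_thread_safe, ["alpha", "beta", "gamma"])))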

src/raglite/_litellm.py

Lines changed: 9 additions & 5 deletions
@@ -8,6 +8,7 @@
 from collections.abc import AsyncIterator, Callable, Iterator
 from functools import cache
 from io import StringIO
+from threading import Lock
 from typing import Any, ClassVar, cast
 
 import httpx
@@ -35,6 +36,7 @@
 
 # Reduce the logging level for LiteLLM, flashrank, and httpx.
 litellm.suppress_debug_info = True
+litellm.drop_params = True  # Drop unsupported parameters for models like GPT-5
 os.environ["LITELLM_LOG"] = "WARNING"
 logging.getLogger("LiteLLM").setLevel(logging.WARNING)
 logging.getLogger("flashrank").setLevel(logging.WARNING)
@@ -62,8 +64,9 @@ class LlamaCppPythonLLM(CustomLLM):
     ```
     """
 
-    # Create a lock to prevent concurrent access to llama-cpp-python models.
+    # Create locks to prevent concurrent access to llama-cpp-python models.
     streaming_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
+    completion_lock: ClassVar[Lock] = Lock()
 
     # The set of supported OpenAI parameters is the intersection of [1] and [2]. Not included:
     # max_completion_tokens, stream_options, n, user, logprobs, top_logprobs, extra_headers.
@@ -198,10 +201,11 @@ def completion(  # noqa: PLR0913
         llm = self.llm(model)
         llama_cpp_python_params = self._translate_openai_params(optional_params)
         llama_cpp_python_params = self._add_recommended_model_params(model, llama_cpp_python_params)
-        response = cast(
-            "llama_types.CreateChatCompletionResponse",
-            llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
-        )
+        with LlamaCppPythonLLM.completion_lock:
+            response = cast(
+                "llama_types.CreateChatCompletionResponse",
+                llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
+            )
         litellm_model_response: ModelResponse = convert_to_model_response_object(
             response_object=response,
             model_response_object=model_response,
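Setting `litellm.drop_params = True` makes LiteLLM drop request parameters that a given model does not accept instead of raising, which is what lets the commit reintroduce `temperature=0` across models. A hedged sketch of the effect (assumes a configured API key; the model name is just an example):

import litellm

litellm.drop_params = True  # unsupported parameters are dropped instead of raising
response = litellm.completion(
    model="gpt-4o-mini",  # hypothetical example model
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0,  # silently dropped if the selected model rejects it
)
print(response.choices[0].message.content)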

src/raglite/_metadata_filter.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+"""Helpers to build metadata filter conditions with consistent semantics."""
+
+import json
+from collections.abc import Mapping
+from typing import Any
+
+from sqlalchemy import and_, false, or_
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlmodel import col, func
+
+from raglite._database import _adapt_metadata
+from raglite._typing import MetadataFilter, MetadataValue
+
+
+def build_metadata_filter_condition(
+    metadata_column: Any,
+    metadata_filter: MetadataFilter | None,
+    *,
+    dialect: str,
+) -> Any:
+    """Build a SQLAlchemy condition for metadata filtering.
+
+    A list of values within the same field uses OR semantics.
+    Different fields are combined with AND semantics.
+    """
+    normalized_metadata_filter = _adapt_metadata(metadata_filter)
+    if not normalized_metadata_filter:
+        return None
+
+    field_conditions: list[Any] = []
+    for metadata_name, metadata_values in normalized_metadata_filter.items():
+        if not metadata_values:
+            return false()  # empty filters are considered unsatisfiable
+
+        value_conditions: list[Any] = []
+        for metadata_value in metadata_values:
+            single_value_filter = {metadata_name: [metadata_value]}
+            if dialect == "postgresql":
+                value_conditions.append(
+                    col(metadata_column).cast(JSONB).op("@>")(single_value_filter)  # type: ignore[attr-defined]
+                )
+            elif dialect == "duckdb":
+                value_conditions.append(
+                    func.json_contains(
+                        col(metadata_column), func.json(json.dumps(single_value_filter))
+                    )
+                )
+            else:
+                error_message = f"Unsupported dialect: {dialect}."
+                raise ValueError(error_message)
+        field_conditions.append(or_(*value_conditions))  # combine values for the same field with OR
+    return and_(*field_conditions)  # combine different fields with AND
+
+
+def build_metadata_filter_sql(
+    metadata_filter: Mapping[str, list[MetadataValue] | MetadataValue] | None,
+    *,
+    dialect: str,
+) -> tuple[str, dict[str, str]]:
+    """Build SQL fragment and bound parameters for metadata filtering.
+
+    A list of values within the same field uses OR semantics.
+    Different fields are combined with AND semantics.
+
+    Returns
+    -------
+    sql_fragment : str
+        A SQL fragment to be included in the WHERE clause, with placeholders for parameters.
+    parameters : dict
+        A dictionary of parameter names and their corresponding JSON string values to be used in
+        the query execution
+    """
+    normalized_metadata_filter = _adapt_metadata(metadata_filter)
+    if not normalized_metadata_filter:
+        return "", {}
+
+    field_sql_conditions: list[str] = []
+    parameters: dict[str, str] = {}
+    parameter_index = 0
+
+    for metadata_name, metadata_values in normalized_metadata_filter.items():
+        if not metadata_values:
+            return " AND 1=0", {}
+
+        value_sql_conditions: list[str] = []
+        for metadata_value in metadata_values:
+            parameter_name = f"metadata_filter_{parameter_index}"
+            parameter_index += 1
+            single_value_filter = json.dumps({metadata_name: [metadata_value]})
+            parameters[parameter_name] = single_value_filter
+
+            if dialect == "postgresql":
+                value_sql_conditions.append(f"metadata::jsonb @> :{parameter_name}")
+            elif dialect == "duckdb":
+                value_sql_conditions.append(f"json_contains(metadata, JSON(:{parameter_name}))")
+            else:
+                error_message = f"Unsupported dialect: {dialect}."
+                raise ValueError(error_message)
+        field_sql_conditions.append(f"({' OR '.join(value_sql_conditions)})")
+    return f" AND {' AND '.join(field_sql_conditions)}", parameters

0 commit comments
