66//
77// Of course, you can have a look at the tantivy's built-in collectors
88// such as the `CountCollector` for more examples.
9-
10- // ---
11- // Importing tantivy...
9+ use std:: fmt:: Debug ;
1210use std:: marker:: PhantomData ;
13- use std:: sync:: Arc ;
1411
15- use columnar:: { ColumnValues , DynamicColumn , HasAssociatedColumnType } ;
12+ use columnar:: { BytesColumn , Column , DynamicColumn , HasAssociatedColumnType } ;
1613
1714use crate :: collector:: { Collector , SegmentCollector } ;
1815use crate :: schema:: Field ;
19- use crate :: { Score , SegmentReader , TantivyError } ;
16+ use crate :: { DocId , Score , SegmentReader , TantivyError } ;
2017
2118/// The `FilterCollector` filters docs using a fast field value and a predicate.
22- /// Only the documents for which the predicate returned "true" will be passed on to the next
23- /// collector.
19+ ///
20+ /// Only the documents containing at least one value for which the predicate returns `true`
21+ /// will be passed on to the next collector.
22+ ///
23+ /// In other words,
24+ /// - documents with no values are filtered out.
25+ /// - documents with several values are accepted if at least one value matches the predicate.
26+ ///
2427///
2528/// ```rust
2629/// use tantivy::collector::{TopDocs, FilterCollector};
2730/// use tantivy::query::QueryParser;
28- /// use tantivy::schema::{Schema, TEXT, INDEXED, FAST};
31+ /// use tantivy::schema::{Schema, TEXT, FAST};
2932/// use tantivy::{doc, DocAddress, Index};
3033///
3134/// # fn main() -> tantivy::Result<()> {
3235/// let mut schema_builder = Schema::builder();
3336/// let title = schema_builder.add_text_field("title", TEXT);
34- /// let price = schema_builder.add_u64_field("price", INDEXED | FAST);
37+ /// let price = schema_builder.add_u64_field("price", FAST);
3538/// let schema = schema_builder.build();
3639/// let index = Index::create_in_ram(schema);
3740///
@@ -47,20 +50,24 @@ use crate::{Score, SegmentReader, TantivyError};
4750///
4851/// let query_parser = QueryParser::for_index(&index, vec![title]);
4952/// let query = query_parser.parse_query("diary")?;
50- /// let no_filter_collector = FilterCollector::new(price, & |value: u64| value > 20_120u64, TopDocs::with_limit(2));
53+ /// let no_filter_collector = FilterCollector::new(price, |value: u64| value > 20_120u64, TopDocs::with_limit(2));
5154/// let top_docs = searcher.search(&query, &no_filter_collector)?;
5255///
5356/// assert_eq!(top_docs.len(), 1);
5457/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
5558///
56- /// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(price, & |value| value < 5u64, TopDocs::with_limit(2));
59+ /// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(price, |value| value < 5u64, TopDocs::with_limit(2));
5760/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
5861///
5962/// assert_eq!(filtered_top_docs.len(), 0);
6063/// # Ok(())
6164/// # }
6265/// ```
63- pub struct FilterCollector < TCollector , TPredicate , TPredicateValue : Default >
66+ ///
/// Note that this is limited to fast fields which implement the
/// [`FastValue`][crate::fastfield::FastValue] trait, e.g. `u64` but not `&[u8]`.
/// To filter based on a bytes fast field, use a [`BytesFilterCollector`] instead.
70+ pub struct FilterCollector < TCollector , TPredicate , TPredicateValue >
6471where TPredicate : ' static + Clone
6572{
6673 field : Field ,
@@ -69,19 +76,15 @@ where TPredicate: 'static + Clone
6976 t_predicate_value : PhantomData < TPredicateValue > ,
7077}
7178
72- impl < TCollector , TPredicate , TPredicateValue : Default >
79+ impl < TCollector , TPredicate , TPredicateValue >
7380 FilterCollector < TCollector , TPredicate , TPredicateValue >
7481where
7582 TCollector : Collector + Send + Sync ,
7683 TPredicate : Fn ( TPredicateValue ) -> bool + Send + Sync + Clone ,
7784{
78- /// Create a new FilterCollector.
79- pub fn new (
80- field : Field ,
81- predicate : TPredicate ,
82- collector : TCollector ,
83- ) -> FilterCollector < TCollector , TPredicate , TPredicateValue > {
84- FilterCollector {
85+ /// Create a new `FilterCollector`.
86+ pub fn new ( field : Field , predicate : TPredicate , collector : TCollector ) -> Self {
87+ Self {
8588 field,
8689 predicate,
8790 collector,
@@ -90,16 +93,14 @@ where
9093 }
9194}
9295
93- impl < TCollector , TPredicate , TPredicateValue : Default > Collector
96+ impl < TCollector , TPredicate , TPredicateValue > Collector
9497 for FilterCollector < TCollector , TPredicate , TPredicateValue >
9598where
9699 TCollector : Collector + Send + Sync ,
97100 TPredicate : ' static + Fn ( TPredicateValue ) -> bool + Send + Sync + Clone ,
98101 TPredicateValue : HasAssociatedColumnType ,
99102 DynamicColumn : Into < Option < columnar:: Column < TPredicateValue > > > ,
100103{
101- // That's the type of our result.
102- // Our standard deviation will be a float.
103104 type Fruit = TCollector :: Fruit ;
104105
105106 type Child = FilterSegmentCollector < TCollector :: Child , TPredicate , TPredicateValue > ;
@@ -108,7 +109,7 @@ where
108109 & self ,
109110 segment_local_id : u32 ,
110111 segment_reader : & SegmentReader ,
111- ) -> crate :: Result < FilterSegmentCollector < TCollector :: Child , TPredicate , TPredicateValue > > {
112+ ) -> crate :: Result < Self :: Child > {
112113 let schema = segment_reader. schema ( ) ;
113114 let field_entry = schema. get_field_entry ( self . field ) ;
114115 if !field_entry. is_fast ( ) {
@@ -118,16 +119,16 @@ where
118119 ) ) ) ;
119120 }
120121
121- let fast_field_reader = segment_reader
122+ let column_opt = segment_reader
122123 . fast_fields ( )
123- . column_first_or_default ( schema . get_field_name ( self . field ) ) ?;
124+ . column_opt ( field_entry . name ( ) ) ?;
124125
125126 let segment_collector = self
126127 . collector
127128 . for_segment ( segment_local_id, segment_reader) ?;
128129
129130 Ok ( FilterSegmentCollector {
130- fast_field_reader ,
131+ column_opt ,
131132 segment_collector,
132133 predicate : self . predicate . clone ( ) ,
133134 t_predicate_value : PhantomData ,
@@ -146,35 +147,208 @@ where
146147 }
147148}
148149
149- pub struct FilterSegmentCollector < TSegmentCollector , TPredicate , TPredicateValue >
150- where
151- TPredicate : ' static ,
152- DynamicColumn : Into < Option < columnar:: Column < TPredicateValue > > > ,
153- {
154- fast_field_reader : Arc < dyn ColumnValues < TPredicateValue > > ,
150+ pub struct FilterSegmentCollector < TSegmentCollector , TPredicate , TPredicateValue > {
151+ column_opt : Option < Column < TPredicateValue > > ,
155152 segment_collector : TSegmentCollector ,
156153 predicate : TPredicate ,
157154 t_predicate_value : PhantomData < TPredicateValue > ,
158155}
159156
157+ impl < TSegmentCollector , TPredicate , TPredicateValue >
158+ FilterSegmentCollector < TSegmentCollector , TPredicate , TPredicateValue >
159+ where
160+ TPredicateValue : PartialOrd + Copy + Debug + Send + Sync + ' static ,
161+ TPredicate : ' static + Fn ( TPredicateValue ) -> bool + Send + Sync ,
162+ {
163+ #[ inline]
164+ fn accept_document ( & self , doc_id : DocId ) -> bool {
165+ if let Some ( column) = & self . column_opt {
166+ for val in column. values_for_doc ( doc_id) {
167+ if ( self . predicate ) ( val) {
168+ return true ;
169+ }
170+ }
171+ }
172+ false
173+ }
174+ }
175+
160176impl < TSegmentCollector , TPredicate , TPredicateValue > SegmentCollector
161177 for FilterSegmentCollector < TSegmentCollector , TPredicate , TPredicateValue >
162178where
163179 TSegmentCollector : SegmentCollector ,
164180 TPredicateValue : HasAssociatedColumnType ,
165- TPredicate : ' static + Fn ( TPredicateValue ) -> bool + Send + Sync ,
166- DynamicColumn : Into < Option < columnar:: Column < TPredicateValue > > > ,
181+ TPredicate : ' static + Fn ( TPredicateValue ) -> bool + Send + Sync , /* DynamicColumn: Into<Option<columnar::Column<TPredicateValue>>> */
182+ {
183+ type Fruit = TSegmentCollector :: Fruit ;
184+
185+ fn collect ( & mut self , doc : u32 , score : Score ) {
186+ if self . accept_document ( doc) {
187+ self . segment_collector . collect ( doc, score) ;
188+ }
189+ }
190+
191+ fn harvest ( self ) -> TSegmentCollector :: Fruit {
192+ self . segment_collector . harvest ( )
193+ }
194+ }
195+
196+ /// A variant of the [`FilterCollector`] specialized for bytes fast fields, i.e.
197+ /// it transparently wraps an inner [`Collector`] but filters documents
198+ /// based on the result of applying the predicate to the bytes fast field.
199+ ///
200+ /// A document is accepted if and only if the predicate returns `true` for at least one value.
201+ ///
202+ /// In other words,
203+ /// - documents with no values are filtered out.
204+ /// - documents with several values are accepted if at least one value matches the predicate.
205+ ///
206+ /// ```rust
207+ /// use tantivy::collector::{TopDocs, BytesFilterCollector};
208+ /// use tantivy::query::QueryParser;
209+ /// use tantivy::schema::{Schema, TEXT, FAST};
210+ /// use tantivy::{doc, DocAddress, Index};
211+ ///
212+ /// # fn main() -> tantivy::Result<()> {
213+ /// let mut schema_builder = Schema::builder();
214+ /// let title = schema_builder.add_text_field("title", TEXT);
215+ /// let barcode = schema_builder.add_bytes_field("barcode", FAST);
216+ /// let schema = schema_builder.build();
217+ /// let index = Index::create_in_ram(schema);
218+ ///
219+ /// let mut index_writer = index.writer_with_num_threads(1, 10_000_000)?;
220+ /// index_writer.add_document(doc!(title => "The Name of the Wind", barcode => &b"010101"[..]))?;
221+ /// index_writer.add_document(doc!(title => "The Diary of Muadib", barcode => &b"110011"[..]))?;
222+ /// index_writer.add_document(doc!(title => "A Dairy Cow", barcode => &b"110111"[..]))?;
223+ /// index_writer.add_document(doc!(title => "The Diary of a Young Girl", barcode => &b"011101"[..]))?;
224+ /// index_writer.add_document(doc!(title => "Bridget Jones's Diary"))?;
225+ /// index_writer.commit()?;
226+ ///
227+ /// let reader = index.reader()?;
228+ /// let searcher = reader.searcher();
229+ ///
230+ /// let query_parser = QueryParser::for_index(&index, vec![title]);
231+ /// let query = query_parser.parse_query("diary")?;
232+ /// let filter_collector = BytesFilterCollector::new(barcode, |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
233+ /// let top_docs = searcher.search(&query, &filter_collector)?;
234+ ///
235+ /// assert_eq!(top_docs.len(), 1);
236+ /// assert_eq!(top_docs[0].1, DocAddress::new(0, 3));
237+ /// # Ok(())
238+ /// # }
239+ /// ```
240+ pub struct BytesFilterCollector < TCollector , TPredicate >
241+ where TPredicate : ' static + Clone
242+ {
243+ field : Field ,
244+ collector : TCollector ,
245+ predicate : TPredicate ,
246+ }
247+
248+ impl < TCollector , TPredicate > BytesFilterCollector < TCollector , TPredicate >
249+ where
250+ TCollector : Collector + Send + Sync ,
251+ TPredicate : Fn ( & [ u8 ] ) -> bool + Send + Sync + Clone ,
252+ {
253+ /// Create a new `BytesFilterCollector`.
254+ pub fn new ( field : Field , predicate : TPredicate , collector : TCollector ) -> Self {
255+ Self {
256+ field,
257+ predicate,
258+ collector,
259+ }
260+ }
261+ }
262+
263+ impl < TCollector , TPredicate > Collector for BytesFilterCollector < TCollector , TPredicate >
264+ where
265+ TCollector : Collector + Send + Sync ,
266+ TPredicate : ' static + Fn ( & [ u8 ] ) -> bool + Send + Sync + Clone ,
267+ {
268+ type Fruit = TCollector :: Fruit ;
269+
270+ type Child = BytesFilterSegmentCollector < TCollector :: Child , TPredicate > ;
271+
272+ fn for_segment (
273+ & self ,
274+ segment_local_id : u32 ,
275+ segment_reader : & SegmentReader ,
276+ ) -> crate :: Result < Self :: Child > {
277+ let schema = segment_reader. schema ( ) ;
278+ let field_name = schema. get_field_name ( self . field ) ;
279+
280+ let column_opt = segment_reader. fast_fields ( ) . bytes ( field_name) ?;
281+
282+ let segment_collector = self
283+ . collector
284+ . for_segment ( segment_local_id, segment_reader) ?;
285+
286+ Ok ( BytesFilterSegmentCollector {
287+ column_opt,
288+ segment_collector,
289+ predicate : self . predicate . clone ( ) ,
290+ buffer : Vec :: new ( ) ,
291+ } )
292+ }
293+
294+ fn requires_scoring ( & self ) -> bool {
295+ self . collector . requires_scoring ( )
296+ }
297+
298+ fn merge_fruits (
299+ & self ,
300+ segment_fruits : Vec < <TCollector :: Child as SegmentCollector >:: Fruit > ,
301+ ) -> crate :: Result < TCollector :: Fruit > {
302+ self . collector . merge_fruits ( segment_fruits)
303+ }
304+ }
305+
306+ pub struct BytesFilterSegmentCollector < TSegmentCollector , TPredicate >
307+ where TPredicate : ' static
308+ {
309+ column_opt : Option < BytesColumn > ,
310+ segment_collector : TSegmentCollector ,
311+ predicate : TPredicate ,
312+ buffer : Vec < u8 > ,
313+ }
314+
315+ impl < TSegmentCollector , TPredicate > BytesFilterSegmentCollector < TSegmentCollector , TPredicate >
316+ where
317+ TSegmentCollector : SegmentCollector ,
318+ TPredicate : ' static + Fn ( & [ u8 ] ) -> bool + Send + Sync ,
319+ {
320+ #[ inline]
321+ fn accept_document ( & mut self , doc_id : DocId ) -> bool {
322+ if let Some ( column) = & self . column_opt {
323+ for ord in column. term_ords ( doc_id) {
324+ self . buffer . clear ( ) ;
325+
326+ let found = column. ord_to_bytes ( ord, & mut self . buffer ) . unwrap_or ( false ) ;
327+
328+ if found && ( self . predicate ) ( & self . buffer ) {
329+ return true ;
330+ }
331+ }
332+ }
333+ false
334+ }
335+ }
336+
337+ impl < TSegmentCollector , TPredicate > SegmentCollector
338+ for BytesFilterSegmentCollector < TSegmentCollector , TPredicate >
339+ where
340+ TSegmentCollector : SegmentCollector ,
341+ TPredicate : ' static + Fn ( & [ u8 ] ) -> bool + Send + Sync ,
167342{
168343 type Fruit = TSegmentCollector :: Fruit ;
169344
170345 fn collect ( & mut self , doc : u32 , score : Score ) {
171- let value = self . fast_field_reader . get_val ( doc) ;
172- if ( self . predicate ) ( value) {
173- self . segment_collector . collect ( doc, score)
346+ if self . accept_document ( doc) {
347+ self . segment_collector . collect ( doc, score) ;
174348 }
175349 }
176350
177- fn harvest ( self ) -> < TSegmentCollector as SegmentCollector > :: Fruit {
351+ fn harvest ( self ) -> TSegmentCollector :: Fruit {
178352 self . segment_collector . harvest ( )
179353 }
180354}
0 commit comments