@@ -151,8 +151,11 @@ fn instantiate_static_filter(
151151 DataType :: UInt16 => Ok ( Arc :: new ( UInt16StaticFilter :: try_new ( & in_array) ?) ) ,
152152 DataType :: UInt32 => Ok ( Arc :: new ( UInt32StaticFilter :: try_new ( & in_array) ?) ) ,
153153 DataType :: UInt64 => Ok ( Arc :: new ( UInt64StaticFilter :: try_new ( & in_array) ?) ) ,
154+ // Float primitive types (use ordered wrappers for Hash/Eq)
155+ DataType :: Float32 => Ok ( Arc :: new ( Float32StaticFilter :: try_new ( & in_array) ?) ) ,
156+ DataType :: Float64 => Ok ( Arc :: new ( Float64StaticFilter :: try_new ( & in_array) ?) ) ,
154157 _ => {
155- /* fall through to generic implementation for unsupported types (Float32/Float64, Struct, etc.) */
158+ /* fall through to generic implementation for unsupported types (Struct, etc.) */
156159 Ok ( Arc :: new ( ArrayStaticFilter :: try_new ( in_array) ?) )
157160 }
158161 }
@@ -210,6 +213,56 @@ impl ArrayStaticFilter {
210213 }
211214}
212215
/// Wrapper for `f32` that implements `Hash` and `Eq` by comparing raw IEEE 754
/// bit patterns.
///
/// Equality agrees with `f32::total_cmp`: two values are equal only when their
/// bits are identical. Consequently a NaN equals another NaN with the same sign
/// and payload (unlike `f32`'s own `==`), NaNs with different bit patterns are
/// distinct, and `0.0` is distinct from `-0.0`.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat32(f32);

impl Hash for OrderedFloat32 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash the bit pattern as a single integer. Values that compare equal
        // have identical bits, so they always hash identically, as `Eq` requires.
        self.0.to_bits().hash(state);
    }
}

impl PartialEq for OrderedFloat32 {
    fn eq(&self, other: &Self) -> bool {
        // Same result as `total_cmp(..).is_eq()`: the total order only reports
        // `Equal` for identical bit patterns.
        self.0.to_bits() == other.0.to_bits()
    }
}

impl Eq for OrderedFloat32 {}

impl From<f32> for OrderedFloat32 {
    fn from(v: f32) -> Self {
        Self(v)
    }
}
240+
/// Wrapper for `f64` that implements `Hash` and `Eq` by comparing raw IEEE 754
/// bit patterns.
///
/// Equality agrees with `f64::total_cmp`: two values are equal only when their
/// bits are identical. Consequently a NaN equals another NaN with the same sign
/// and payload (unlike `f64`'s own `==`), NaNs with different bit patterns are
/// distinct, and `0.0` is distinct from `-0.0`.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat64(f64);

impl Hash for OrderedFloat64 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash the bit pattern as a single integer. Values that compare equal
        // have identical bits, so they always hash identically, as `Eq` requires.
        self.0.to_bits().hash(state);
    }
}

impl PartialEq for OrderedFloat64 {
    fn eq(&self, other: &Self) -> bool {
        // Same result as `total_cmp(..).is_eq()`: the total order only reports
        // `Equal` for identical bit patterns.
        self.0.to_bits() == other.0.to_bits()
    }
}

impl Eq for OrderedFloat64 {}

impl From<f64> for OrderedFloat64 {
    fn from(v: f64) -> Self {
        Self(v)
    }
}
265+
213266// Macro to generate specialized StaticFilter implementations for primitive types
214267macro_rules! primitive_static_filter {
215268 ( $Name: ident, $ArrowType: ty) => {
@@ -334,7 +387,6 @@ macro_rules! primitive_static_filter {
334387}
335388
336389// Generate specialized filters for all integer primitive types
337- // Note: Float32 and Float64 are excluded because they don't implement Hash/Eq due to NaN
338390primitive_static_filter ! ( Int8StaticFilter , Int8Type ) ;
339391primitive_static_filter ! ( Int16StaticFilter , Int16Type ) ;
340392primitive_static_filter ! ( Int32StaticFilter , Int32Type ) ;
@@ -344,6 +396,119 @@ primitive_static_filter!(UInt16StaticFilter, UInt16Type);
344396primitive_static_filter ! ( UInt32StaticFilter , UInt32Type ) ;
345397primitive_static_filter ! ( UInt64StaticFilter , UInt64Type ) ;
346398
// Generates a specialized `StaticFilter` for a floating-point Arrow type.
//
// Raw floats implement neither `Hash` nor `Eq`, so every value is converted to
// `$OrderedType` before it is stored in, or looked up from, the hash set.
macro_rules! float_static_filter {
    ($Name:ident, $ArrowType:ty, $OrderedType:ty) => {
        struct $Name {
            null_count: usize,
            values: HashSet<$OrderedType>,
        }

        impl $Name {
            // Builds the filter from the IN-list array: records how many list
            // entries are null and collects the non-null ones into the set.
            fn try_new(in_array: &ArrayRef) -> Result<Self> {
                let typed = in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| {
                    exec_datafusion_err!(
                        "Failed to downcast an array to a '{}' array",
                        stringify!($ArrowType)
                    )
                })?;

                let null_count = typed.null_count();
                let mut values = HashSet::with_capacity(typed.len());
                values.extend(typed.iter().flatten().map(<$OrderedType>::from));

                Ok(Self { null_count, values })
            }
        }

        impl StaticFilter for $Name {
            fn null_count(&self) -> usize {
                self.null_count
            }

            fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
                // Dictionary input: evaluate the dictionary's values once, then
                // gather the per-value answers through the keys.
                downcast_dictionary_array! {
                    v => {
                        let per_value = self.contains(v.values().as_ref(), negated)?;
                        let gathered = take(&per_value, v.keys(), None)?;
                        return Ok(downcast_array(gathered.as_ref()))
                    }
                    _ => {}
                }

                let needles = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| {
                    exec_datafusion_err!(
                        "Failed to downcast an array to a '{}' array",
                        stringify!($ArrowType)
                    )
                })?;
                let needle_values = needles.values();

                // Raw set membership for every slot. Null slots are probed too;
                // whatever they produce is hidden by the validity mask below.
                let in_set = BooleanBuffer::collect_bool(needle_values.len(), |i| {
                    self.values.contains(&<$OrderedType>::from(needle_values[i]))
                });

                // Only consult the needle null mask when it actually marks nulls.
                let needle_nulls = needles.nulls().filter(|_| needles.null_count() > 0);

                // A result slot is null when the needle is null, or when the
                // needle is absent from the set and the list itself held a null.
                // The second rule makes "found in set" the validity bit,
                // independent of `negated`.
                let validity = if self.null_count > 0 {
                    let valid = match needle_nulls {
                        Some(nulls) => nulls.inner() & &in_set,
                        None => in_set.clone(),
                    };
                    Some(NullBuffer::new(valid))
                } else {
                    needle_nulls.cloned()
                };

                // NOT IN reports the complement of membership.
                let matches = if negated { !&in_set } else { in_set };

                Ok(BooleanArray::new(matches, validity))
            }
        }
    };
}
507+
508+ // Generate specialized filters for float types using ordered wrappers
509+ float_static_filter ! ( Float32StaticFilter , Float32Type , OrderedFloat32 ) ;
510+ float_static_filter ! ( Float64StaticFilter , Float64Type , OrderedFloat64 ) ;
511+
347512/// Evaluates the list of expressions into an array, flattening any dictionaries
348513fn evaluate_list (
349514 list : & [ Arc < dyn PhysicalExpr > ] ,
0 commit comments