Skip to content

Commit 2f2c60d

Browse files
committed
add float implementations
1 parent a430ef8 commit 2f2c60d

1 file changed

Lines changed: 167 additions & 2 deletions

File tree

  • datafusion/physical-expr/src/expressions

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 167 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,11 @@ fn instantiate_static_filter(
151151
DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)),
152152
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
153153
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
154+
// Float primitive types (use ordered wrappers for Hash/Eq)
155+
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
156+
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
154157
_ => {
155-
/* fall through to generic implementation for unsupported types (Float32/Float64, Struct, etc.) */
158+
/* fall through to generic implementation for unsupported types (Struct, etc.) */
156159
Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?))
157160
}
158161
}
@@ -210,6 +213,56 @@ impl ArrayStaticFilter {
210213
}
211214
}
212215

216+
/// Wrapper for f32 that implements Hash and Eq using IEEE 754 total ordering.
217+
/// This treats NaN values as equal to each other (using total_cmp).
218+
#[derive(Clone, Copy)]
219+
struct OrderedFloat32(f32);
220+
221+
impl Hash for OrderedFloat32 {
222+
fn hash<H: Hasher>(&self, state: &mut H) {
223+
self.0.to_ne_bytes().hash(state);
224+
}
225+
}
226+
227+
impl PartialEq for OrderedFloat32 {
228+
fn eq(&self, other: &Self) -> bool {
229+
self.0.total_cmp(&other.0).is_eq()
230+
}
231+
}
232+
233+
impl Eq for OrderedFloat32 {}
234+
235+
impl From<f32> for OrderedFloat32 {
236+
fn from(v: f32) -> Self {
237+
Self(v)
238+
}
239+
}
240+
241+
/// Wrapper for f64 that implements Hash and Eq using IEEE 754 total ordering.
242+
/// This treats NaN values as equal to each other (using total_cmp).
243+
#[derive(Clone, Copy)]
244+
struct OrderedFloat64(f64);
245+
246+
impl Hash for OrderedFloat64 {
247+
fn hash<H: Hasher>(&self, state: &mut H) {
248+
self.0.to_ne_bytes().hash(state);
249+
}
250+
}
251+
252+
impl PartialEq for OrderedFloat64 {
253+
fn eq(&self, other: &Self) -> bool {
254+
self.0.total_cmp(&other.0).is_eq()
255+
}
256+
}
257+
258+
impl Eq for OrderedFloat64 {}
259+
260+
impl From<f64> for OrderedFloat64 {
261+
fn from(v: f64) -> Self {
262+
Self(v)
263+
}
264+
}
265+
213266
// Macro to generate specialized StaticFilter implementations for primitive types
214267
macro_rules! primitive_static_filter {
215268
($Name:ident, $ArrowType:ty) => {
@@ -334,7 +387,6 @@ macro_rules! primitive_static_filter {
334387
}
335388

336389
// Generate specialized filters for all integer primitive types
337-
// Note: Float32 and Float64 are excluded because they don't implement Hash/Eq due to NaN
338390
primitive_static_filter!(Int8StaticFilter, Int8Type);
339391
primitive_static_filter!(Int16StaticFilter, Int16Type);
340392
primitive_static_filter!(Int32StaticFilter, Int32Type);
@@ -344,6 +396,119 @@ primitive_static_filter!(UInt16StaticFilter, UInt16Type);
344396
primitive_static_filter!(UInt32StaticFilter, UInt32Type);
345397
primitive_static_filter!(UInt64StaticFilter, UInt64Type);
346398

399+
// Macro to generate specialized StaticFilter implementations for float types
400+
// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics
401+
macro_rules! float_static_filter {
402+
($Name:ident, $ArrowType:ty, $OrderedType:ty) => {
403+
struct $Name {
404+
null_count: usize,
405+
values: HashSet<$OrderedType>,
406+
}
407+
408+
impl $Name {
409+
fn try_new(in_array: &ArrayRef) -> Result<Self> {
410+
let in_array = in_array
411+
.as_primitive_opt::<$ArrowType>()
412+
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
413+
414+
let mut values = HashSet::with_capacity(in_array.len());
415+
let null_count = in_array.null_count();
416+
417+
for v in in_array.iter().flatten() {
418+
values.insert(<$OrderedType>::from(v));
419+
}
420+
421+
Ok(Self { null_count, values })
422+
}
423+
}
424+
425+
impl StaticFilter for $Name {
426+
fn null_count(&self) -> usize {
427+
self.null_count
428+
}
429+
430+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
431+
// Handle dictionary arrays by recursing on the values
432+
downcast_dictionary_array! {
433+
v => {
434+
let values_contains = self.contains(v.values().as_ref(), negated)?;
435+
let result = take(&values_contains, v.keys(), None)?;
436+
return Ok(downcast_array(result.as_ref()))
437+
}
438+
_ => {}
439+
}
440+
441+
let v = v
442+
.as_primitive_opt::<$ArrowType>()
443+
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
444+
445+
let haystack_has_nulls = self.null_count > 0;
446+
447+
let needle_values = v.values();
448+
let needle_nulls = v.nulls();
449+
let needle_has_nulls = v.null_count() > 0;
450+
451+
// Compute the "contains" result using collect_bool (fast batched approach)
452+
// This ignores nulls - we handle them separately
453+
let contains_buffer = if negated {
454+
BooleanBuffer::collect_bool(needle_values.len(), |i| {
455+
!self.values.contains(&<$OrderedType>::from(needle_values[i]))
456+
})
457+
} else {
458+
BooleanBuffer::collect_bool(needle_values.len(), |i| {
459+
self.values.contains(&<$OrderedType>::from(needle_values[i]))
460+
})
461+
};
462+
463+
// Compute the null mask
464+
// Output is null when:
465+
// 1. needle value is null, OR
466+
// 2. needle value is not in set AND haystack has nulls
467+
let result_nulls = match (needle_has_nulls, haystack_has_nulls) {
468+
(false, false) => {
469+
// No nulls anywhere
470+
None
471+
}
472+
(true, false) => {
473+
// Only needle has nulls - just use needle's null mask
474+
needle_nulls.cloned()
475+
}
476+
(false, true) => {
477+
// Only haystack has nulls - null where not-in-set
478+
let validity = if negated {
479+
!&contains_buffer
480+
} else {
481+
contains_buffer.clone()
482+
};
483+
Some(NullBuffer::new(validity))
484+
}
485+
(true, true) => {
486+
// Both have nulls - combine needle nulls with haystack-induced nulls
487+
let needle_validity = needle_nulls.map(|n| n.inner().clone())
488+
.unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len()));
489+
490+
let haystack_validity = if negated {
491+
!&contains_buffer
492+
} else {
493+
contains_buffer.clone()
494+
};
495+
496+
// Combined validity: valid only where both are valid
497+
let combined_validity = &needle_validity & &haystack_validity;
498+
Some(NullBuffer::new(combined_validity))
499+
}
500+
};
501+
502+
Ok(BooleanArray::new(contains_buffer, result_nulls))
503+
}
504+
}
505+
};
506+
}
507+
508+
// Generate specialized filters for float types using ordered wrappers
509+
float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32);
510+
float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64);
511+
347512
/// Evaluates the list of expressions into an array, flattening any dictionaries
348513
fn evaluate_list(
349514
list: &[Arc<dyn PhysicalExpr>],

0 commit comments

Comments
 (0)