Skip to content

Commit 5c4c771

Browse files
author
Devanshu
committed
Implement preimage for floor function to enable predicate pushdown
This adds a `preimage` implementation for the `floor()` function that transforms `floor(x) = N` into `x >= N AND x < N+1`. This enables statistics-based predicate pushdown for queries using floor(). For example, a query like: SELECT * FROM t WHERE floor(price) = 100 Is rewritten to: SELECT * FROM t WHERE price >= 100 AND price < 101 This allows the query engine to leverage min/max statistics from Parquet row groups, significantly reducing the amount of data scanned. Benchmarks on the ClickBench hits dataset show: - 80% file pruning (89 out of 111 files skipped) - 70x fewer rows scanned (1.4M vs 100M)
1 parent a77e5a5 commit 5c4c771

1 file changed

Lines changed: 241 additions & 2 deletions

File tree

datafusion/functions/src/math/floor.rs

Lines changed: 241 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ use arrow::datatypes::{
2525
};
2626
use datafusion_common::{Result, ScalarValue, exec_err};
2727
use datafusion_expr::interval_arithmetic::Interval;
28+
use datafusion_expr::preimage::PreimageResult;
29+
use datafusion_expr::simplify::SimplifyContext;
2830
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
2931
use datafusion_expr::{
30-
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31-
TypeSignature, TypeSignatureClass, Volatility,
32+
Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl,
33+
Signature, TypeSignature, TypeSignatureClass, Volatility,
3234
};
3335
use datafusion_macros::user_doc;
36+
use num_traits::{CheckedAdd, Float, One};
3437

3538
use super::decimal::{apply_decimal_op, floor_decimal_value};
3639

@@ -200,7 +203,243 @@ impl ScalarUDFImpl for FloorFunc {
200203
Interval::make_unbounded(&data_type)
201204
}
202205

206+
/// Compute the preimage for floor function.
207+
///
208+
/// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1`
209+
/// because floor(x) = N for all x in [N, N+1).
210+
///
211+
/// This enables predicate pushdown optimizations, transforming:
212+
/// `floor(col) = 100` into `col >= 100 AND col < 101`
213+
fn preimage(
214+
&self,
215+
args: &[Expr],
216+
lit_expr: &Expr,
217+
_info: &SimplifyContext,
218+
) -> Result<PreimageResult> {
219+
// floor takes exactly one argument
220+
if args.len() != 1 {
221+
return Ok(PreimageResult::None);
222+
}
223+
224+
let arg = args[0].clone();
225+
226+
// Extract the literal value being compared to
227+
let Expr::Literal(lit_value, _) = lit_expr else {
228+
return Ok(PreimageResult::None);
229+
};
230+
231+
// Compute lower bound (N) and upper bound (N + 1) using helper functions
232+
let Some((lower, upper)) = (match lit_value {
233+
// Floating-point types
234+
ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
235+
(
236+
ScalarValue::Float64(Some(lo)),
237+
ScalarValue::Float64(Some(hi)),
238+
)
239+
}),
240+
ScalarValue::Float32(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
241+
(
242+
ScalarValue::Float32(Some(lo)),
243+
ScalarValue::Float32(Some(hi)),
244+
)
245+
}),
246+
247+
// Integer types
248+
ScalarValue::Int8(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
249+
(ScalarValue::Int8(Some(lo)), ScalarValue::Int8(Some(hi)))
250+
}),
251+
ScalarValue::Int16(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
252+
(ScalarValue::Int16(Some(lo)), ScalarValue::Int16(Some(hi)))
253+
}),
254+
ScalarValue::Int32(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
255+
(ScalarValue::Int32(Some(lo)), ScalarValue::Int32(Some(hi)))
256+
}),
257+
ScalarValue::Int64(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
258+
(ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi)))
259+
}),
260+
261+
// Unsupported types
262+
_ => None,
263+
}) else {
264+
return Ok(PreimageResult::None);
265+
};
266+
267+
Ok(PreimageResult::Range {
268+
expr: arg,
269+
interval: Box::new(Interval::try_new(lower, upper)?),
270+
})
271+
}
272+
203273
fn documentation(&self) -> Option<&Documentation> {
204274
self.doc()
205275
}
206276
}
277+
278+
// ============ Helper functions for preimage bounds ============
279+
280+
/// Compute preimage bounds for floor function on floating-point types.
281+
/// For floor(x) = n, the preimage is [n, n+1).
282+
/// Returns None if the value is non-finite or would lose precision.
283+
fn float_preimage_bounds<F: Float>(n: F) -> Option<(F, F)> {
284+
let one = F::one();
285+
// Check for non-finite values (infinity, NaN) or precision loss at extreme values
286+
if !n.is_finite() || n + one <= n {
287+
return None;
288+
}
289+
Some((n, n + one))
290+
}
291+
292+
/// Compute preimage bounds for floor function on integer types.
293+
/// For floor(x) = n, the preimage is [n, n+1).
294+
/// Returns None if adding 1 would overflow.
295+
fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
296+
let upper = n.checked_add(&I::one())?;
297+
Some((n, upper))
298+
}
299+
300+
#[cfg(test)]
301+
mod tests {
302+
use super::*;
303+
use datafusion_expr::col;
304+
305+
/// Helper to test valid preimage cases that should return a Range
306+
fn assert_preimage_range(
307+
input: ScalarValue,
308+
expected_lower: ScalarValue,
309+
expected_upper: ScalarValue,
310+
) {
311+
let floor_func = FloorFunc::new();
312+
let args = vec![col("x")];
313+
let lit_expr = Expr::Literal(input.clone(), None);
314+
let info = SimplifyContext::default();
315+
316+
let result = floor_func.preimage(&args, &lit_expr, &info).unwrap();
317+
318+
match result {
319+
PreimageResult::Range { expr, interval } => {
320+
assert_eq!(expr, col("x"));
321+
assert_eq!(interval.lower().clone(), expected_lower);
322+
assert_eq!(interval.upper().clone(), expected_upper);
323+
}
324+
PreimageResult::None => {
325+
panic!("Expected Range, got None for input {:?}", input)
326+
}
327+
}
328+
}
329+
330+
/// Helper to test cases that should return None
331+
fn assert_preimage_none(input: ScalarValue) {
332+
let floor_func = FloorFunc::new();
333+
let args = vec![col("x")];
334+
let lit_expr = Expr::Literal(input.clone(), None);
335+
let info = SimplifyContext::default();
336+
337+
let result = floor_func.preimage(&args, &lit_expr, &info).unwrap();
338+
assert!(
339+
matches!(result, PreimageResult::None),
340+
"Expected None for input {:?}",
341+
input
342+
);
343+
}
344+
345+
#[test]
346+
fn test_floor_preimage_valid_cases() {
347+
// Float64
348+
assert_preimage_range(
349+
ScalarValue::Float64(Some(100.0)),
350+
ScalarValue::Float64(Some(100.0)),
351+
ScalarValue::Float64(Some(101.0)),
352+
);
353+
// Float32
354+
assert_preimage_range(
355+
ScalarValue::Float32(Some(50.0)),
356+
ScalarValue::Float32(Some(50.0)),
357+
ScalarValue::Float32(Some(51.0)),
358+
);
359+
// Int64
360+
assert_preimage_range(
361+
ScalarValue::Int64(Some(42)),
362+
ScalarValue::Int64(Some(42)),
363+
ScalarValue::Int64(Some(43)),
364+
);
365+
// Int32
366+
assert_preimage_range(
367+
ScalarValue::Int32(Some(100)),
368+
ScalarValue::Int32(Some(100)),
369+
ScalarValue::Int32(Some(101)),
370+
);
371+
// Negative values
372+
assert_preimage_range(
373+
ScalarValue::Float64(Some(-5.0)),
374+
ScalarValue::Float64(Some(-5.0)),
375+
ScalarValue::Float64(Some(-4.0)),
376+
);
377+
// Zero
378+
assert_preimage_range(
379+
ScalarValue::Float64(Some(0.0)),
380+
ScalarValue::Float64(Some(0.0)),
381+
ScalarValue::Float64(Some(1.0)),
382+
);
383+
}
384+
385+
#[test]
386+
fn test_floor_preimage_integer_overflow() {
387+
// All integer types at MAX value should return None
388+
assert_preimage_none(ScalarValue::Int64(Some(i64::MAX)));
389+
assert_preimage_none(ScalarValue::Int32(Some(i32::MAX)));
390+
assert_preimage_none(ScalarValue::Int16(Some(i16::MAX)));
391+
assert_preimage_none(ScalarValue::Int8(Some(i8::MAX)));
392+
}
393+
394+
#[test]
395+
fn test_floor_preimage_float_edge_cases() {
396+
// Float64 edge cases
397+
assert_preimage_none(ScalarValue::Float64(Some(f64::INFINITY)));
398+
assert_preimage_none(ScalarValue::Float64(Some(f64::NEG_INFINITY)));
399+
assert_preimage_none(ScalarValue::Float64(Some(f64::NAN)));
400+
assert_preimage_none(ScalarValue::Float64(Some(f64::MAX))); // precision loss
401+
402+
// Float32 edge cases
403+
assert_preimage_none(ScalarValue::Float32(Some(f32::INFINITY)));
404+
assert_preimage_none(ScalarValue::Float32(Some(f32::NEG_INFINITY)));
405+
assert_preimage_none(ScalarValue::Float32(Some(f32::NAN)));
406+
assert_preimage_none(ScalarValue::Float32(Some(f32::MAX))); // precision loss
407+
}
408+
409+
#[test]
410+
fn test_floor_preimage_null_values() {
411+
assert_preimage_none(ScalarValue::Float64(None));
412+
assert_preimage_none(ScalarValue::Float32(None));
413+
assert_preimage_none(ScalarValue::Int64(None));
414+
}
415+
416+
#[test]
417+
fn test_floor_preimage_invalid_inputs() {
418+
let floor_func = FloorFunc::new();
419+
let info = SimplifyContext::default();
420+
421+
// Non-literal comparison value
422+
let result = floor_func.preimage(&[col("x")], &col("y"), &info).unwrap();
423+
assert!(
424+
matches!(result, PreimageResult::None),
425+
"Expected None for non-literal"
426+
);
427+
428+
// Wrong argument count (too many)
429+
let lit = Expr::Literal(ScalarValue::Float64(Some(100.0)), None);
430+
let result = floor_func
431+
.preimage(&[col("x"), col("y")], &lit, &info)
432+
.unwrap();
433+
assert!(
434+
matches!(result, PreimageResult::None),
435+
"Expected None for wrong arg count"
436+
);
437+
438+
// Wrong argument count (zero)
439+
let result = floor_func.preimage(&[], &lit, &info).unwrap();
440+
assert!(
441+
matches!(result, PreimageResult::None),
442+
"Expected None for zero args"
443+
);
444+
}
445+
}

0 commit comments

Comments
 (0)