@@ -25,12 +25,15 @@ use arrow::datatypes::{
2525} ;
2626use datafusion_common:: { Result , ScalarValue , exec_err} ;
2727use datafusion_expr:: interval_arithmetic:: Interval ;
28+ use datafusion_expr:: preimage:: PreimageResult ;
29+ use datafusion_expr:: simplify:: SimplifyContext ;
2830use datafusion_expr:: sort_properties:: { ExprProperties , SortProperties } ;
2931use datafusion_expr:: {
30- Coercion , ColumnarValue , Documentation , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
31- TypeSignature , TypeSignatureClass , Volatility ,
32+ Coercion , ColumnarValue , Documentation , Expr , ScalarFunctionArgs , ScalarUDFImpl ,
33+ Signature , TypeSignature , TypeSignatureClass , Volatility ,
3234} ;
3335use datafusion_macros:: user_doc;
36+ use num_traits:: { CheckedAdd , Float , One } ;
3437
3538use super :: decimal:: { apply_decimal_op, floor_decimal_value} ;
3639
@@ -200,7 +203,243 @@ impl ScalarUDFImpl for FloorFunc {
200203 Interval :: make_unbounded ( & data_type)
201204 }
202205
206+ /// Compute the preimage for floor function.
207+ ///
208+ /// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1`
209+ /// because floor(x) = N for all x in [N, N+1).
210+ ///
211+ /// This enables predicate pushdown optimizations, transforming:
212+ /// `floor(col) = 100` into `col >= 100 AND col < 101`
213+ fn preimage (
214+ & self ,
215+ args : & [ Expr ] ,
216+ lit_expr : & Expr ,
217+ _info : & SimplifyContext ,
218+ ) -> Result < PreimageResult > {
219+ // floor takes exactly one argument
220+ if args. len ( ) != 1 {
221+ return Ok ( PreimageResult :: None ) ;
222+ }
223+
224+ let arg = args[ 0 ] . clone ( ) ;
225+
226+ // Extract the literal value being compared to
227+ let Expr :: Literal ( lit_value, _) = lit_expr else {
228+ return Ok ( PreimageResult :: None ) ;
229+ } ;
230+
231+ // Compute lower bound (N) and upper bound (N + 1) using helper functions
232+ let Some ( ( lower, upper) ) = ( match lit_value {
233+ // Floating-point types
234+ ScalarValue :: Float64 ( Some ( n) ) => float_preimage_bounds ( * n) . map ( |( lo, hi) | {
235+ (
236+ ScalarValue :: Float64 ( Some ( lo) ) ,
237+ ScalarValue :: Float64 ( Some ( hi) ) ,
238+ )
239+ } ) ,
240+ ScalarValue :: Float32 ( Some ( n) ) => float_preimage_bounds ( * n) . map ( |( lo, hi) | {
241+ (
242+ ScalarValue :: Float32 ( Some ( lo) ) ,
243+ ScalarValue :: Float32 ( Some ( hi) ) ,
244+ )
245+ } ) ,
246+
247+ // Integer types
248+ ScalarValue :: Int8 ( Some ( n) ) => int_preimage_bounds ( * n) . map ( |( lo, hi) | {
249+ ( ScalarValue :: Int8 ( Some ( lo) ) , ScalarValue :: Int8 ( Some ( hi) ) )
250+ } ) ,
251+ ScalarValue :: Int16 ( Some ( n) ) => int_preimage_bounds ( * n) . map ( |( lo, hi) | {
252+ ( ScalarValue :: Int16 ( Some ( lo) ) , ScalarValue :: Int16 ( Some ( hi) ) )
253+ } ) ,
254+ ScalarValue :: Int32 ( Some ( n) ) => int_preimage_bounds ( * n) . map ( |( lo, hi) | {
255+ ( ScalarValue :: Int32 ( Some ( lo) ) , ScalarValue :: Int32 ( Some ( hi) ) )
256+ } ) ,
257+ ScalarValue :: Int64 ( Some ( n) ) => int_preimage_bounds ( * n) . map ( |( lo, hi) | {
258+ ( ScalarValue :: Int64 ( Some ( lo) ) , ScalarValue :: Int64 ( Some ( hi) ) )
259+ } ) ,
260+
261+ // Unsupported types
262+ _ => None ,
263+ } ) else {
264+ return Ok ( PreimageResult :: None ) ;
265+ } ;
266+
267+ Ok ( PreimageResult :: Range {
268+ expr : arg,
269+ interval : Box :: new ( Interval :: try_new ( lower, upper) ?) ,
270+ } )
271+ }
272+
/// Returns the documentation for this function by delegating to `self.doc()`.
///
/// NOTE(review): `doc()` is not defined in the visible part of this file; it is
/// presumably generated by the `user_doc` attribute macro imported above — confirm.
203273 fn documentation ( & self ) -> Option < & Documentation > {
204274 self . doc ( )
205275 }
206276}
277+
278+ // ============ Helper functions for preimage bounds ============
279+
280+ /// Compute preimage bounds for floor function on floating-point types.
281+ /// For floor(x) = n, the preimage is [n, n+1).
282+ /// Returns None if the value is non-finite or would lose precision.
283+ fn float_preimage_bounds < F : Float > ( n : F ) -> Option < ( F , F ) > {
284+ let one = F :: one ( ) ;
285+ // Check for non-finite values (infinity, NaN) or precision loss at extreme values
286+ if !n. is_finite ( ) || n + one <= n {
287+ return None ;
288+ }
289+ Some ( ( n, n + one) )
290+ }
291+
292+ /// Compute preimage bounds for floor function on integer types.
293+ /// For floor(x) = n, the preimage is [n, n+1).
294+ /// Returns None if adding 1 would overflow.
295+ fn int_preimage_bounds < I : CheckedAdd + One + Copy > ( n : I ) -> Option < ( I , I ) > {
296+ let upper = n. checked_add ( & I :: one ( ) ) ?;
297+ Some ( ( n, upper) )
298+ }
299+
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::col;

    /// Evaluates the preimage of `floor(x)` compared against `literal`,
    /// using a default simplify context and the single argument `col("x")`.
    fn preimage_of(literal: &ScalarValue) -> PreimageResult {
        let lit_expr = Expr::Literal(literal.clone(), None);
        let context = SimplifyContext::default();
        FloorFunc::new()
            .preimage(&[col("x")], &lit_expr, &context)
            .unwrap()
    }

    /// Asserts that `input` produces a `Range` over `col("x")` with the
    /// given lower and upper bounds.
    fn assert_preimage_range(
        input: ScalarValue,
        expected_lower: ScalarValue,
        expected_upper: ScalarValue,
    ) {
        let PreimageResult::Range { expr, interval } = preimage_of(&input) else {
            panic!("Expected Range, got None for input {:?}", input)
        };
        assert_eq!(expr, col("x"));
        assert_eq!(interval.lower().clone(), expected_lower);
        assert_eq!(interval.upper().clone(), expected_upper);
    }

    /// Asserts that `input` produces no preimage.
    fn assert_preimage_none(input: ScalarValue) {
        let outcome = preimage_of(&input);
        assert!(
            matches!(outcome, PreimageResult::None),
            "Expected None for input {:?}",
            input
        );
    }

    #[test]
    fn test_floor_preimage_valid_cases() {
        // (literal, expected lower bound, expected upper bound)
        let cases = [
            // Float64
            (
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(101.0)),
            ),
            // Float32
            (
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(51.0)),
            ),
            // Int64
            (
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(43)),
            ),
            // Int32
            (
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(101)),
            ),
            // Negative values
            (
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-4.0)),
            ),
            // Zero
            (
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(1.0)),
            ),
        ];

        for (input, lower, upper) in cases {
            assert_preimage_range(input, lower, upper);
        }
    }

    #[test]
    fn test_floor_preimage_integer_overflow() {
        // Every integer type at its MAX value has no representable N + 1.
        let at_max = [
            ScalarValue::Int64(Some(i64::MAX)),
            ScalarValue::Int32(Some(i32::MAX)),
            ScalarValue::Int16(Some(i16::MAX)),
            ScalarValue::Int8(Some(i8::MAX)),
        ];
        for input in at_max {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_float_edge_cases() {
        let edge_cases = [
            // Float64
            ScalarValue::Float64(Some(f64::INFINITY)),
            ScalarValue::Float64(Some(f64::NEG_INFINITY)),
            ScalarValue::Float64(Some(f64::NAN)),
            ScalarValue::Float64(Some(f64::MAX)), // precision loss
            // Float32
            ScalarValue::Float32(Some(f32::INFINITY)),
            ScalarValue::Float32(Some(f32::NEG_INFINITY)),
            ScalarValue::Float32(Some(f32::NAN)),
            ScalarValue::Float32(Some(f32::MAX)), // precision loss
        ];
        for input in edge_cases {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_null_values() {
        let nulls = [
            ScalarValue::Float64(None),
            ScalarValue::Float32(None),
            ScalarValue::Int64(None),
        ];
        for input in nulls {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_invalid_inputs() {
        let floor_func = FloorFunc::new();
        let info = SimplifyContext::default();
        let lit = Expr::Literal(ScalarValue::Float64(Some(100.0)), None);

        // Non-literal comparison value
        let non_literal = floor_func.preimage(&[col("x")], &col("y"), &info).unwrap();
        assert!(
            matches!(non_literal, PreimageResult::None),
            "Expected None for non-literal"
        );

        // Wrong argument count (too many)
        let too_many = floor_func
            .preimage(&[col("x"), col("y")], &lit, &info)
            .unwrap();
        assert!(
            matches!(too_many, PreimageResult::None),
            "Expected None for wrong arg count"
        );

        // Wrong argument count (zero)
        let no_args = floor_func.preimage(&[], &lit, &info).unwrap();
        assert!(
            matches!(no_args, PreimageResult::None),
            "Expected None for zero args"
        );
    }
}
0 commit comments