Skip to content

Commit 43977da

Browse files
devanshu0987Devanshu
andauthored
Add Decimal support for floor preimage (#20099)
## Which issue does this PR close? - Closes #20080 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? - Decimal support - SLT Tests for Floor preimage ## Are these changes tested? - Unit Tests - SLT Tests <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? No <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Devanshu <devanshu@codapayments.com>
1 parent 796c7d1 commit 43977da

2 files changed

Lines changed: 563 additions & 55 deletions

File tree

datafusion/functions/src/math/floor.rs

Lines changed: 278 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ use std::any::Any;
1919
use std::sync::Arc;
2020

2121
use arrow::array::{ArrayRef, AsArray};
22+
use arrow::compute::{DecimalCast, rescale_decimal};
2223
use arrow::datatypes::{
23-
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type,
24-
Float64Type,
24+
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
25+
Decimal256Type, DecimalType, Float32Type, Float64Type,
2526
};
2627
use datafusion_common::{Result, ScalarValue, exec_err};
2728
use datafusion_expr::interval_arithmetic::Interval;
@@ -77,6 +78,42 @@ impl FloorFunc {
7778
}
7879
}
7980

81+
// ============ Macro for preimage bounds ============
82+
/// Generates the code to call the appropriate bounds function and wrap results.
83+
macro_rules! preimage_bounds {
84+
// Float types: call float_preimage_bounds and wrap in ScalarValue
85+
(float: $variant:ident, $value:expr) => {
86+
float_preimage_bounds($value).map(|(lo, hi)| {
87+
(
88+
ScalarValue::$variant(Some(lo)),
89+
ScalarValue::$variant(Some(hi)),
90+
)
91+
})
92+
};
93+
94+
// Integer types: call int_preimage_bounds and wrap in ScalarValue
95+
(int: $variant:ident, $value:expr) => {
96+
int_preimage_bounds($value).map(|(lo, hi)| {
97+
(
98+
ScalarValue::$variant(Some(lo)),
99+
ScalarValue::$variant(Some(hi)),
100+
)
101+
})
102+
};
103+
104+
// Decimal types: call decimal_preimage_bounds with precision/scale and wrap in ScalarValue
105+
(decimal: $variant:ident, $decimal_type:ty, $value:expr, $precision:expr, $scale:expr) => {
106+
decimal_preimage_bounds::<$decimal_type>($value, $precision, $scale).map(
107+
|(lo, hi)| {
108+
(
109+
ScalarValue::$variant(Some(lo), $precision, $scale),
110+
ScalarValue::$variant(Some(hi), $precision, $scale),
111+
)
112+
},
113+
)
114+
};
115+
}
116+
80117
impl ScalarUDFImpl for FloorFunc {
81118
fn as_any(&self) -> &dyn Any {
82119
self
@@ -216,10 +253,8 @@ impl ScalarUDFImpl for FloorFunc {
216253
lit_expr: &Expr,
217254
_info: &SimplifyContext,
218255
) -> Result<PreimageResult> {
219-
// floor takes exactly one argument
220-
if args.len() != 1 {
221-
return Ok(PreimageResult::None);
222-
}
256+
// floor takes exactly one argument and we do not expect to reach here with multiple arguments.
257+
debug_assert!(args.len() == 1, "floor() takes exactly one argument");
223258

224259
let arg = args[0].clone();
225260

@@ -230,35 +265,34 @@ impl ScalarUDFImpl for FloorFunc {
230265

231266
// Compute lower bound (N) and upper bound (N + 1) using helper functions
232267
let Some((lower, upper)) = (match lit_value {
233-
// Decimal types should be supported and tracked in
234-
// https://github.com/apache/datafusion/issues/20080
235268
// Floating-point types
236-
ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
237-
(
238-
ScalarValue::Float64(Some(lo)),
239-
ScalarValue::Float64(Some(hi)),
240-
)
241-
}),
242-
ScalarValue::Float32(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
243-
(
244-
ScalarValue::Float32(Some(lo)),
245-
ScalarValue::Float32(Some(hi)),
246-
)
247-
}),
248-
249-
// Integer types
250-
ScalarValue::Int8(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
251-
(ScalarValue::Int8(Some(lo)), ScalarValue::Int8(Some(hi)))
252-
}),
253-
ScalarValue::Int16(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
254-
(ScalarValue::Int16(Some(lo)), ScalarValue::Int16(Some(hi)))
255-
}),
256-
ScalarValue::Int32(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
257-
(ScalarValue::Int32(Some(lo)), ScalarValue::Int32(Some(hi)))
258-
}),
259-
ScalarValue::Int64(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
260-
(ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi)))
261-
}),
269+
ScalarValue::Float64(Some(n)) => preimage_bounds!(float: Float64, *n),
270+
ScalarValue::Float32(Some(n)) => preimage_bounds!(float: Float32, *n),
271+
272+
// Integer types (not reachable from SQL/SLT: floor() only accepts Float64/Float32/Decimal,
273+
// so the RHS literal is always coerced to one of those before preimage runs; kept for
274+
// programmatic use and unit tests)
275+
ScalarValue::Int8(Some(n)) => preimage_bounds!(int: Int8, *n),
276+
ScalarValue::Int16(Some(n)) => preimage_bounds!(int: Int16, *n),
277+
ScalarValue::Int32(Some(n)) => preimage_bounds!(int: Int32, *n),
278+
ScalarValue::Int64(Some(n)) => preimage_bounds!(int: Int64, *n),
279+
280+
// Decimal types
281+
// DECIMAL(precision, scale) where precision ≤ 38 -> Decimal128(precision, scale)
282+
// DECIMAL(precision, scale) where precision > 38 -> Decimal256(precision, scale)
283+
// Decimal32 and Decimal64 are unreachable from SQL/SLT.
284+
ScalarValue::Decimal32(Some(n), precision, scale) => {
285+
preimage_bounds!(decimal: Decimal32, Decimal32Type, *n, *precision, *scale)
286+
}
287+
ScalarValue::Decimal64(Some(n), precision, scale) => {
288+
preimage_bounds!(decimal: Decimal64, Decimal64Type, *n, *precision, *scale)
289+
}
290+
ScalarValue::Decimal128(Some(n), precision, scale) => {
291+
preimage_bounds!(decimal: Decimal128, Decimal128Type, *n, *precision, *scale)
292+
}
293+
ScalarValue::Decimal256(Some(n), precision, scale) => {
294+
preimage_bounds!(decimal: Decimal256, Decimal256Type, *n, *precision, *scale)
295+
}
262296

263297
// Unsupported types
264298
_ => None,
@@ -310,9 +344,49 @@ fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
310344
Some((n, upper))
311345
}
312346

347+
/// Compute preimage bounds for floor function on decimal types.
348+
/// For floor(x) = n, the preimage is [n, n+1).
349+
/// Returns None if:
350+
/// - The value has a fractional part (floor always returns integers)
351+
/// - Adding 1 would overflow
352+
fn decimal_preimage_bounds<D: DecimalType>(
353+
value: D::Native,
354+
precision: u8,
355+
scale: i8,
356+
) -> Option<(D::Native, D::Native)>
357+
where
358+
D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem<Output = D::Native>,
359+
{
360+
// Use rescale_decimal to compute "1" at target scale (avoids manual pow)
361+
// Convert integer 1 (scale=0) to the target scale
362+
let one_scaled: D::Native = rescale_decimal::<D, D>(
363+
D::Native::ONE, // value = 1
364+
1, // input_precision = 1
365+
0, // input_scale = 0 (integer)
366+
precision, // output_precision
367+
scale, // output_scale
368+
)?;
369+
370+
// floor always returns an integer, so if value has a fractional part, there's no solution
371+
// Check: value % one_scaled != 0 means fractional part exists
372+
if scale > 0 && value % one_scaled != D::Native::ZERO {
373+
return None;
374+
}
375+
376+
// Compute upper bound using checked addition
377+
// Before preimage stage, the internal i128/i256(value) is validated based on the precision and scale.
378+
// MAX_DECIMAL128_FOR_EACH_PRECISION and MAX_DECIMAL256_FOR_EACH_PRECISION are used to validate the internal i128/i256.
379+
// Any invalid i128/i256 will not reach here.
380+
// Therefore, the add_checked will always succeed if tested via SQL/SLT path.
381+
let upper = value.add_checked(one_scaled).ok()?;
382+
383+
Some((value, upper))
384+
}
385+
313386
#[cfg(test)]
314387
mod tests {
315388
use super::*;
389+
use arrow_buffer::i256;
316390
use datafusion_expr::col;
317391

318392
/// Helper to test valid preimage cases that should return a Range
@@ -434,33 +508,182 @@ mod tests {
434508
assert_preimage_none(ScalarValue::Int64(None));
435509
}
436510

511+
// ============ Decimal32 Tests (mirrors float/int tests) ============
512+
437513
#[test]
438-
fn test_floor_preimage_invalid_inputs() {
439-
let floor_func = FloorFunc::new();
440-
let info = SimplifyContext::default();
514+
fn test_floor_preimage_decimal_valid_cases() {
515+
// ===== Decimal32 =====
516+
// Positive integer decimal: 100.00 (scale=2, so raw=10000)
517+
// floor(x) = 100.00 -> x in [100.00, 101.00)
518+
assert_preimage_range(
519+
ScalarValue::Decimal32(Some(10000), 9, 2),
520+
ScalarValue::Decimal32(Some(10000), 9, 2), // 100.00
521+
ScalarValue::Decimal32(Some(10100), 9, 2), // 101.00
522+
);
441523

442-
// Non-literal comparison value
443-
let result = floor_func.preimage(&[col("x")], &col("y"), &info).unwrap();
444-
assert!(
445-
matches!(result, PreimageResult::None),
446-
"Expected None for non-literal"
524+
// Smaller positive: 50.00
525+
assert_preimage_range(
526+
ScalarValue::Decimal32(Some(5000), 9, 2),
527+
ScalarValue::Decimal32(Some(5000), 9, 2), // 50.00
528+
ScalarValue::Decimal32(Some(5100), 9, 2), // 51.00
447529
);
448530

449-
// Wrong argument count (too many)
450-
let lit = Expr::Literal(ScalarValue::Float64(Some(100.0)), None);
451-
let result = floor_func
452-
.preimage(&[col("x"), col("y")], &lit, &info)
453-
.unwrap();
454-
assert!(
455-
matches!(result, PreimageResult::None),
456-
"Expected None for wrong arg count"
531+
// Negative integer decimal: -5.00
532+
assert_preimage_range(
533+
ScalarValue::Decimal32(Some(-500), 9, 2),
534+
ScalarValue::Decimal32(Some(-500), 9, 2), // -5.00
535+
ScalarValue::Decimal32(Some(-400), 9, 2), // -4.00
457536
);
458537

459-
// Wrong argument count (zero)
460-
let result = floor_func.preimage(&[], &lit, &info).unwrap();
461-
assert!(
462-
matches!(result, PreimageResult::None),
463-
"Expected None for zero args"
538+
// Zero: 0.00
539+
assert_preimage_range(
540+
ScalarValue::Decimal32(Some(0), 9, 2),
541+
ScalarValue::Decimal32(Some(0), 9, 2), // 0.00
542+
ScalarValue::Decimal32(Some(100), 9, 2), // 1.00
464543
);
544+
545+
// Scale 0 (pure integer): 42
546+
assert_preimage_range(
547+
ScalarValue::Decimal32(Some(42), 9, 0),
548+
ScalarValue::Decimal32(Some(42), 9, 0),
549+
ScalarValue::Decimal32(Some(43), 9, 0),
550+
);
551+
552+
// ===== Decimal64 =====
553+
assert_preimage_range(
554+
ScalarValue::Decimal64(Some(10000), 18, 2),
555+
ScalarValue::Decimal64(Some(10000), 18, 2), // 100.00
556+
ScalarValue::Decimal64(Some(10100), 18, 2), // 101.00
557+
);
558+
559+
// Negative
560+
assert_preimage_range(
561+
ScalarValue::Decimal64(Some(-500), 18, 2),
562+
ScalarValue::Decimal64(Some(-500), 18, 2), // -5.00
563+
ScalarValue::Decimal64(Some(-400), 18, 2), // -4.00
564+
);
565+
566+
// Zero
567+
assert_preimage_range(
568+
ScalarValue::Decimal64(Some(0), 18, 2),
569+
ScalarValue::Decimal64(Some(0), 18, 2),
570+
ScalarValue::Decimal64(Some(100), 18, 2),
571+
);
572+
573+
// ===== Decimal128 =====
574+
assert_preimage_range(
575+
ScalarValue::Decimal128(Some(10000), 38, 2),
576+
ScalarValue::Decimal128(Some(10000), 38, 2), // 100.00
577+
ScalarValue::Decimal128(Some(10100), 38, 2), // 101.00
578+
);
579+
580+
// Negative
581+
assert_preimage_range(
582+
ScalarValue::Decimal128(Some(-500), 38, 2),
583+
ScalarValue::Decimal128(Some(-500), 38, 2), // -5.00
584+
ScalarValue::Decimal128(Some(-400), 38, 2), // -4.00
585+
);
586+
587+
// Zero
588+
assert_preimage_range(
589+
ScalarValue::Decimal128(Some(0), 38, 2),
590+
ScalarValue::Decimal128(Some(0), 38, 2),
591+
ScalarValue::Decimal128(Some(100), 38, 2),
592+
);
593+
594+
// ===== Decimal256 =====
595+
assert_preimage_range(
596+
ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2),
597+
ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), // 100.00
598+
ScalarValue::Decimal256(Some(i256::from(10100)), 76, 2), // 101.00
599+
);
600+
601+
// Negative
602+
assert_preimage_range(
603+
ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2),
604+
ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), // -5.00
605+
ScalarValue::Decimal256(Some(i256::from(-400)), 76, 2), // -4.00
606+
);
607+
608+
// Zero
609+
assert_preimage_range(
610+
ScalarValue::Decimal256(Some(i256::ZERO), 76, 2),
611+
ScalarValue::Decimal256(Some(i256::ZERO), 76, 2),
612+
ScalarValue::Decimal256(Some(i256::from(100)), 76, 2),
613+
);
614+
}
615+
616+
#[test]
617+
fn test_floor_preimage_decimal_non_integer() {
618+
// floor(x) = 1.30 has NO SOLUTION because floor always returns an integer
619+
// Therefore preimage should return None for non-integer decimals
620+
621+
// Decimal32
622+
assert_preimage_none(ScalarValue::Decimal32(Some(130), 9, 2)); // 1.30
623+
assert_preimage_none(ScalarValue::Decimal32(Some(-250), 9, 2)); // -2.50
624+
assert_preimage_none(ScalarValue::Decimal32(Some(370), 9, 2)); // 3.70
625+
assert_preimage_none(ScalarValue::Decimal32(Some(1), 9, 2)); // 0.01
626+
627+
// Decimal64
628+
assert_preimage_none(ScalarValue::Decimal64(Some(130), 18, 2)); // 1.30
629+
assert_preimage_none(ScalarValue::Decimal64(Some(-250), 18, 2)); // -2.50
630+
631+
// Decimal128
632+
assert_preimage_none(ScalarValue::Decimal128(Some(130), 38, 2)); // 1.30
633+
assert_preimage_none(ScalarValue::Decimal128(Some(-250), 38, 2)); // -2.50
634+
635+
// Decimal256
636+
assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(130)), 76, 2)); // 1.30
637+
assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2)); // -2.50
638+
639+
// Decimal32: i32::MAX - 50
640+
// This return None because the value is not an integer, not because it is out of range.
641+
assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 10, 2));
642+
643+
// Decimal64: i64::MAX - 50
644+
// This return None because the value is not an integer, not because it is out of range.
645+
assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 19, 2));
646+
}
647+
648+
#[test]
649+
fn test_floor_preimage_decimal_overflow() {
650+
// Test near MAX where adding scale_factor would overflow
651+
652+
// Decimal32: i32::MAX
653+
assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 10, 0));
654+
655+
// Decimal64: i64::MAX
656+
assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 19, 0));
657+
}
658+
659+
#[test]
660+
fn test_floor_preimage_decimal_edge_cases() {
661+
// ===== Decimal32 =====
662+
// Large value that doesn't overflow
663+
// Decimal(9,2) max value is 9,999,999.99 (stored as 999,999,999)
664+
// Use a large value that fits Decimal(9,2) and is divisible by 100
665+
let safe_max_aligned_32 = 999_999_900; // 9,999,999.00
666+
assert_preimage_range(
667+
ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2),
668+
ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2),
669+
ScalarValue::Decimal32(Some(safe_max_aligned_32 + 100), 9, 2),
670+
);
671+
672+
// Negative edge: use a large negative value that fits Decimal(9,2)
673+
// Decimal(9,2) min value is -9,999,999.99 (stored as -999,999,999)
674+
let min_aligned_32 = -999_999_900; // -9,999,999.00
675+
assert_preimage_range(
676+
ScalarValue::Decimal32(Some(min_aligned_32), 9, 2),
677+
ScalarValue::Decimal32(Some(min_aligned_32), 9, 2),
678+
ScalarValue::Decimal32(Some(min_aligned_32 + 100), 9, 2),
679+
);
680+
}
681+
682+
#[test]
683+
fn test_floor_preimage_decimal_null() {
684+
assert_preimage_none(ScalarValue::Decimal32(None, 9, 2));
685+
assert_preimage_none(ScalarValue::Decimal64(None, 18, 2));
686+
assert_preimage_none(ScalarValue::Decimal128(None, 38, 2));
687+
assert_preimage_none(ScalarValue::Decimal256(None, 76, 2));
465688
}
466689
}

0 commit comments

Comments
 (0)