-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add Decimal support for floor preimage #20099
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
5ff1cd2
77e3e4c
d32651d
1830ace
fccac54
ccd28cf
b771696
7419625
48aebcb
3197819
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,9 +19,10 @@ use std::any::Any; | |
| use std::sync::Arc; | ||
|
|
||
| use arrow::array::{ArrayRef, AsArray}; | ||
| use arrow::compute::{DecimalCast, rescale_decimal}; | ||
| use arrow::datatypes::{ | ||
| DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, | ||
| Float64Type, | ||
| ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, | ||
| Decimal256Type, DecimalType, Float32Type, Float64Type, | ||
| }; | ||
| use datafusion_common::{Result, ScalarValue, exec_err}; | ||
| use datafusion_expr::interval_arithmetic::Interval; | ||
|
|
@@ -230,8 +231,6 @@ impl ScalarUDFImpl for FloorFunc { | |
|
|
||
| // Compute lower bound (N) and upper bound (N + 1) using helper functions | ||
| let Some((lower, upper)) = (match lit_value { | ||
| // Decimal types should be supported and tracked in | ||
| // https://github.com/apache/datafusion/issues/20080 | ||
| // Floating-point types | ||
| ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| { | ||
| ( | ||
|
|
@@ -260,6 +259,48 @@ impl ScalarUDFImpl for FloorFunc { | |
| (ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi))) | ||
| }), | ||
|
|
||
| // Decimal types | ||
| ScalarValue::Decimal32(Some(n), precision, scale) => { | ||
| decimal_preimage_bounds::<Decimal32Type>(*n, *precision, *scale).map( | ||
| |(lo, hi)| { | ||
| ( | ||
| ScalarValue::Decimal32(Some(lo), *precision, *scale), | ||
| ScalarValue::Decimal32(Some(hi), *precision, *scale), | ||
| ) | ||
| }, | ||
| ) | ||
| } | ||
| ScalarValue::Decimal64(Some(n), precision, scale) => { | ||
| decimal_preimage_bounds::<Decimal64Type>(*n, *precision, *scale).map( | ||
| |(lo, hi)| { | ||
| ( | ||
| ScalarValue::Decimal64(Some(lo), *precision, *scale), | ||
| ScalarValue::Decimal64(Some(hi), *precision, *scale), | ||
| ) | ||
| }, | ||
| ) | ||
| } | ||
| ScalarValue::Decimal128(Some(n), precision, scale) => { | ||
| decimal_preimage_bounds::<Decimal128Type>(*n, *precision, *scale).map( | ||
| |(lo, hi)| { | ||
| ( | ||
| ScalarValue::Decimal128(Some(lo), *precision, *scale), | ||
| ScalarValue::Decimal128(Some(hi), *precision, *scale), | ||
| ) | ||
| }, | ||
| ) | ||
| } | ||
| ScalarValue::Decimal256(Some(n), precision, scale) => { | ||
| decimal_preimage_bounds::<Decimal256Type>(*n, *precision, *scale).map( | ||
| |(lo, hi)| { | ||
| ( | ||
| ScalarValue::Decimal256(Some(lo), *precision, *scale), | ||
| ScalarValue::Decimal256(Some(hi), *precision, *scale), | ||
| ) | ||
| }, | ||
| ) | ||
| } | ||
|
|
||
| // Unsupported types | ||
| _ => None, | ||
| }) else { | ||
|
|
@@ -310,9 +351,45 @@ fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> { | |
| Some((n, upper)) | ||
| } | ||
|
|
||
| /// Compute preimage bounds for floor function on decimal types. | ||
| /// For floor(x) = n, the preimage is [n, n+1). | ||
| /// Returns None if: | ||
| /// - `1` cannot be rescaled to the target precision/scale (`rescale_decimal` returns None) | ||
| /// - The value has a fractional part (floor always returns integers) | ||
| /// - Adding 1 would overflow | ||
| fn decimal_preimage_bounds<D: DecimalType>( | ||
| value: D::Native, | ||
| precision: u8, | ||
| scale: i8, | ||
| ) -> Option<(D::Native, D::Native)> | ||
| where | ||
| D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem<Output = D::Native>, | ||
| { | ||
| // Use rescale_decimal to compute "1" at target scale (avoids manual pow) | ||
| // Convert integer 1 (scale=0) to the target scale | ||
| let one_scaled: D::Native = rescale_decimal::<D, D>( | ||
| D::Native::ONE, // value = 1 | ||
| 1, // input_precision = 1 | ||
| 0, // input_scale = 0 (integer) | ||
| precision, // output_precision | ||
| scale, // output_scale | ||
| )?; | ||
|
|
||
| // floor always returns an integer, so if value has a fractional part, there's no solution | ||
| // Check: value % one_scaled != 0 means fractional part exists | ||
| if scale > 0 && value % one_scaled != D::Native::ZERO { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What about negative scales?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi, the logic here comes from the idea that floor(x) will always return a float data type, but its value will actually be an integer. The same logic is carried over here for Decimal.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code review comment in the last PR: #20059 (comment) |
||
| return None; | ||
| } | ||
|
|
||
| // Compute upper bound using checked addition | ||
| let upper = value.add_checked(one_scaled).ok()?; | ||
|
|
||
| Some((value, upper)) | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use arrow_buffer::i256; | ||
| use datafusion_expr::col; | ||
|
|
||
| /// Helper to test valid preimage cases that should return a Range | ||
|
|
@@ -463,4 +540,240 @@ mod tests { | |
| "Expected None for zero args" | ||
| ); | ||
| } | ||
|
|
||
| // ============ Decimal Tests (mirrors float/int tests) ============ | ||
|
|
||
| #[test] | ||
| fn test_floor_preimage_decimal_valid_cases() { | ||
| // ===== Decimal32 ===== | ||
| // Positive integer decimal: 100.00 (scale=2, so raw=10000) | ||
| // floor(x) = 100.00 -> x in [100.00, 101.00) | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal32(Some(10000), 9, 2), | ||
| ScalarValue::Decimal32(Some(10000), 9, 2), // 100.00 | ||
| ScalarValue::Decimal32(Some(10100), 9, 2), // 101.00 | ||
| ); | ||
|
|
||
| // Smaller positive: 50.00 | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal32(Some(5000), 9, 2), | ||
| ScalarValue::Decimal32(Some(5000), 9, 2), // 50.00 | ||
| ScalarValue::Decimal32(Some(5100), 9, 2), // 51.00 | ||
| ); | ||
|
|
||
| // Negative integer decimal: -5.00 | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal32(Some(-500), 9, 2), | ||
| ScalarValue::Decimal32(Some(-500), 9, 2), // -5.00 | ||
| ScalarValue::Decimal32(Some(-400), 9, 2), // -4.00 | ||
| ); | ||
|
|
||
| // Zero: 0.00 | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal32(Some(0), 9, 2), | ||
| ScalarValue::Decimal32(Some(0), 9, 2), // 0.00 | ||
| ScalarValue::Decimal32(Some(100), 9, 2), // 1.00 | ||
| ); | ||
|
|
||
| // Scale 0 (pure integer): 42 | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal32(Some(42), 9, 0), | ||
| ScalarValue::Decimal32(Some(42), 9, 0), | ||
| ScalarValue::Decimal32(Some(43), 9, 0), | ||
| ); | ||
|
|
||
| // ===== Decimal64 ===== | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal64(Some(10000), 18, 2), | ||
| ScalarValue::Decimal64(Some(10000), 18, 2), // 100.00 | ||
| ScalarValue::Decimal64(Some(10100), 18, 2), // 101.00 | ||
| ); | ||
|
|
||
| // Negative | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal64(Some(-500), 18, 2), | ||
| ScalarValue::Decimal64(Some(-500), 18, 2), // -5.00 | ||
| ScalarValue::Decimal64(Some(-400), 18, 2), // -4.00 | ||
| ); | ||
|
|
||
| // Zero | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal64(Some(0), 18, 2), | ||
| ScalarValue::Decimal64(Some(0), 18, 2), | ||
| ScalarValue::Decimal64(Some(100), 18, 2), | ||
| ); | ||
|
|
||
| // ===== Decimal128 ===== | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal128(Some(10000), 38, 2), | ||
| ScalarValue::Decimal128(Some(10000), 38, 2), // 100.00 | ||
| ScalarValue::Decimal128(Some(10100), 38, 2), // 101.00 | ||
| ); | ||
|
|
||
| // Negative | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal128(Some(-500), 38, 2), | ||
| ScalarValue::Decimal128(Some(-500), 38, 2), // -5.00 | ||
| ScalarValue::Decimal128(Some(-400), 38, 2), // -4.00 | ||
| ); | ||
|
|
||
| // Zero | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal128(Some(0), 38, 2), | ||
| ScalarValue::Decimal128(Some(0), 38, 2), | ||
| ScalarValue::Decimal128(Some(100), 38, 2), | ||
| ); | ||
|
|
||
| // ===== Decimal256 ===== | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), | ||
| ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), // 100.00 | ||
| ScalarValue::Decimal256(Some(i256::from(10100)), 76, 2), // 101.00 | ||
| ); | ||
|
|
||
| // Negative | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), | ||
| ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), // -5.00 | ||
| ScalarValue::Decimal256(Some(i256::from(-400)), 76, 2), // -4.00 | ||
| ); | ||
|
|
||
| // Zero | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), | ||
| ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), | ||
| ScalarValue::Decimal256(Some(i256::from(100)), 76, 2), | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_floor_preimage_decimal_non_integer() { | ||
| // floor(x) = 1.30 has NO SOLUTION because floor always returns an integer | ||
| // Therefore preimage should return None for non-integer decimals | ||
|
|
||
| // Decimal32 | ||
| assert_preimage_none(ScalarValue::Decimal32(Some(130), 9, 2)); // 1.30 | ||
| assert_preimage_none(ScalarValue::Decimal32(Some(-250), 9, 2)); // -2.50 | ||
| assert_preimage_none(ScalarValue::Decimal32(Some(370), 9, 2)); // 3.70 | ||
| assert_preimage_none(ScalarValue::Decimal32(Some(1), 9, 2)); // 0.01 | ||
|
|
||
| // Decimal64 | ||
| assert_preimage_none(ScalarValue::Decimal64(Some(130), 18, 2)); // 1.30 | ||
| assert_preimage_none(ScalarValue::Decimal64(Some(-250), 18, 2)); // -2.50 | ||
|
|
||
| // Decimal128 | ||
| assert_preimage_none(ScalarValue::Decimal128(Some(130), 38, 2)); // 1.30 | ||
| assert_preimage_none(ScalarValue::Decimal128(Some(-250), 38, 2)); // -2.50 | ||
|
|
||
| // Decimal256 | ||
| assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(130)), 76, 2)); // 1.30 | ||
| assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2)); // -2.50 | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_floor_preimage_decimal_overflow() { | ||
| // Test near MAX where adding scale_factor would overflow | ||
|
|
||
| // Decimal32: i32::MAX | ||
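| // For scale=2, we add 100, so (i32::MAX - 50) + 100 would overflow. Note: i32::MAX - 50 is not a multiple of 100, so the fractional-part check already returns None before the checked addition is reached | ||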
| assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 9, 2)); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What exactly is the idea of this test?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi, you are right.
Thanks! Let me fix this.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed a bunch of repetitive tests and moved the tests around. Now it should make sense. Please let me know. |
||
| // For scale=0, we add 1, so i32::MAX would overflow | ||
| assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 9, 0)); | ||
|
|
||
| // Decimal64: i64::MAX | ||
| assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 18, 2)); | ||
| assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 18, 0)); | ||
|
|
||
| // Decimal128: i128::MAX | ||
| assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX - 50), 38, 2)); | ||
| assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX), 38, 0)); | ||
|
|
||
| // Decimal256: i256::MAX | ||
| assert_preimage_none(ScalarValue::Decimal256( | ||
| Some(i256::MAX.wrapping_sub(i256::from(50))), | ||
| 76, | ||
| 2, | ||
| )); | ||
| assert_preimage_none(ScalarValue::Decimal256(Some(i256::MAX), 76, 0)); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_floor_preimage_decimal_edge_cases() { | ||
| // ===== Decimal32 ===== | ||
| // Large value that doesn't overflow | ||
| // i32::MAX = 2147483647, with scale=2, max safe is around i32::MAX - 100 | ||
| let safe_max_32 = i32::MAX - 100; | ||
| // Make it divisible by 100 for scale=2 | ||
| let safe_max_aligned_32 = (safe_max_32 / 100) * 100; | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), | ||
| ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), | ||
| ScalarValue::Decimal32(Some(safe_max_aligned_32 + 100), 9, 2), | ||
| ); | ||
|
|
||
| // Negative edge: i32::MIN should work since we're adding (not subtracting) | ||
| let min_aligned_32 = (i32::MIN / 100) * 100; | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), | ||
| ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), | ||
| ScalarValue::Decimal32(Some(min_aligned_32 + 100), 9, 2), | ||
| ); | ||
|
|
||
| // ===== Decimal64 ===== | ||
| let safe_max_64 = i64::MAX - 100; | ||
| let safe_max_aligned_64 = (safe_max_64 / 100) * 100; | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal64(Some(safe_max_aligned_64), 18, 2), | ||
| ScalarValue::Decimal64(Some(safe_max_aligned_64), 18, 2), | ||
| ScalarValue::Decimal64(Some(safe_max_aligned_64 + 100), 18, 2), | ||
| ); | ||
|
|
||
| let min_aligned_64 = (i64::MIN / 100) * 100; | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal64(Some(min_aligned_64), 18, 2), | ||
| ScalarValue::Decimal64(Some(min_aligned_64), 18, 2), | ||
| ScalarValue::Decimal64(Some(min_aligned_64 + 100), 18, 2), | ||
| ); | ||
|
|
||
| // ===== Decimal128 ===== | ||
| let safe_max_128 = i128::MAX - 100; | ||
| let safe_max_aligned_128 = (safe_max_128 / 100) * 100; | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal128(Some(safe_max_aligned_128), 38, 2), | ||
| ScalarValue::Decimal128(Some(safe_max_aligned_128), 38, 2), | ||
| ScalarValue::Decimal128(Some(safe_max_aligned_128 + 100), 38, 2), | ||
| ); | ||
|
|
||
| let min_aligned_128 = (i128::MIN / 100) * 100; | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal128(Some(min_aligned_128), 38, 2), | ||
| ScalarValue::Decimal128(Some(min_aligned_128), 38, 2), | ||
| ScalarValue::Decimal128(Some(min_aligned_128 + 100), 38, 2), | ||
| ); | ||
|
|
||
| // ===== Decimal256 ===== | ||
| // For i256, we use smaller values since MAX is huge | ||
| let large_256 = i256::from(1_000_000_000_000i64); | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal256(Some(large_256), 76, 2), | ||
| ScalarValue::Decimal256(Some(large_256), 76, 2), | ||
| ScalarValue::Decimal256(Some(large_256.wrapping_add(i256::from(100))), 76, 2), | ||
| ); | ||
|
|
||
| // Negative i256 | ||
| let neg_256 = i256::from(-1_000_000_000_000i64); | ||
| assert_preimage_range( | ||
| ScalarValue::Decimal256(Some(neg_256), 76, 2), | ||
| ScalarValue::Decimal256(Some(neg_256), 76, 2), | ||
| ScalarValue::Decimal256(Some(neg_256.wrapping_add(i256::from(100))), 76, 2), | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_floor_preimage_decimal_null() { | ||
| assert_preimage_none(ScalarValue::Decimal32(None, 9, 2)); | ||
| assert_preimage_none(ScalarValue::Decimal64(None, 18, 2)); | ||
| assert_preimage_none(ScalarValue::Decimal128(None, 38, 2)); | ||
| assert_preimage_none(ScalarValue::Decimal256(None, 76, 2)); | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.