Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 317 additions & 4 deletions datafusion/functions/src/math/floor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray};
use arrow::compute::{DecimalCast, rescale_decimal};
use arrow::datatypes::{
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type,
Float64Type,
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, DecimalType, Float32Type, Float64Type,
};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::interval_arithmetic::Interval;
Expand Down Expand Up @@ -230,8 +231,6 @@ impl ScalarUDFImpl for FloorFunc {

// Compute lower bound (N) and upper bound (N + 1) using helper functions
let Some((lower, upper)) = (match lit_value {
// Decimal types should be supported and tracked in
// https://github.com/apache/datafusion/issues/20080
// Floating-point types
ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
(
Expand Down Expand Up @@ -260,6 +259,48 @@ impl ScalarUDFImpl for FloorFunc {
(ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi)))
}),

// Decimal types
ScalarValue::Decimal32(Some(n), precision, scale) => {
decimal_preimage_bounds::<Decimal32Type>(*n, *precision, *scale).map(
Comment thread
devanshu0987 marked this conversation as resolved.
Outdated
|(lo, hi)| {
(
ScalarValue::Decimal32(Some(lo), *precision, *scale),
ScalarValue::Decimal32(Some(hi), *precision, *scale),
)
},
)
}
ScalarValue::Decimal64(Some(n), precision, scale) => {
decimal_preimage_bounds::<Decimal64Type>(*n, *precision, *scale).map(
|(lo, hi)| {
(
ScalarValue::Decimal64(Some(lo), *precision, *scale),
ScalarValue::Decimal64(Some(hi), *precision, *scale),
)
},
)
}
ScalarValue::Decimal128(Some(n), precision, scale) => {
decimal_preimage_bounds::<Decimal128Type>(*n, *precision, *scale).map(
|(lo, hi)| {
(
ScalarValue::Decimal128(Some(lo), *precision, *scale),
ScalarValue::Decimal128(Some(hi), *precision, *scale),
)
},
)
}
ScalarValue::Decimal256(Some(n), precision, scale) => {
decimal_preimage_bounds::<Decimal256Type>(*n, *precision, *scale).map(
|(lo, hi)| {
(
ScalarValue::Decimal256(Some(lo), *precision, *scale),
ScalarValue::Decimal256(Some(hi), *precision, *scale),
)
},
)
}

// Unsupported types
_ => None,
}) else {
Expand Down Expand Up @@ -310,9 +351,45 @@ fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
Some((n, upper))
}

/// Compute preimage bounds for the floor function on decimal types.
///
/// For `floor(x) = n` the preimage is the half-open interval `[n, n + 1)`. Both bounds are
/// returned as raw (unscaled) integers at the same `precision`/`scale` as `value`.
///
/// # Arguments
/// * `value` - raw (unscaled) integer of the literal `n`
/// * `precision` - decimal precision of the literal
/// * `scale` - decimal scale of the literal
///
/// # Returns
/// `Some((lower, upper))` where `lower == value` and `upper == value + 1` (with `1` expressed
/// at `scale`). Returns `None` if:
/// - `scale` is negative (`n + 1` cannot be expressed at that scale)
/// - `1` cannot be rescaled to the target precision/scale
/// - the value has a fractional part (floor always returns integers)
/// - adding 1 would overflow
fn decimal_preimage_bounds<D: DecimalType>(
    value: D::Native,
    precision: u8,
    scale: i8,
) -> Option<(D::Native, D::Native)>
where
    D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem<Output = D::Native>,
{
    // Arrow supports negative scales. At a negative scale the smallest representable
    // step is 10^|scale|, so the integer 1 is below the type's resolution: it cannot be
    // rescaled to a usable unit step and the exclusive upper bound `n + 1` is not
    // representable. Skip the rewrite instead of producing a degenerate or wrong range.
    if scale < 0 {
        return None;
    }

    // Use rescale_decimal to compute "1" at the target scale (avoids a manual pow):
    // convert integer 1 (precision = 1, scale = 0) to the literal's precision/scale.
    let one_scaled: D::Native =
        rescale_decimal::<D, D>(D::Native::ONE, 1, 0, precision, scale)?;

    // floor always returns an integer, so a literal with a fractional part has no
    // solution. `value % one_scaled != 0` means a fractional part exists; at scale 0
    // every value is already an integer, so the check is skipped.
    if scale > 0 && value % one_scaled != D::Native::ZERO {
        return None;
    }

    // Compute the exclusive upper bound with checked addition so that an overflow of
    // the native integer type yields `None` rather than a wrapped value.
    let upper = value.add_checked(one_scaled).ok()?;

    Some((value, upper))
}

#[cfg(test)]
mod tests {
use super::*;
use arrow_buffer::i256;
use datafusion_expr::col;

/// Helper to test valid preimage cases that should return a Range
Expand Down Expand Up @@ -463,4 +540,240 @@ mod tests {
"Expected None for zero args"
);
}

// ============ Decimal Tests (mirrors float/int tests) ============

#[test]
fn test_floor_preimage_decimal_valid_cases() {
    // Each table entry is (raw literal, raw exclusive upper bound). The expected
    // lower bound is always the literal itself: floor(x) = n -> x in [n, n + 1).

    // ===== Decimal32 =====
    // scale=2: 100.00, 50.00, -5.00 and 0.00 (raw = value * 100)
    for (raw, next) in [(10000, 10100), (5000, 5100), (-500, -400), (0, 100)] {
        assert_preimage_range(
            ScalarValue::Decimal32(Some(raw), 9, 2),
            ScalarValue::Decimal32(Some(raw), 9, 2),
            ScalarValue::Decimal32(Some(next), 9, 2),
        );
    }

    // scale=0 (pure integer): 42 -> [42, 43)
    assert_preimage_range(
        ScalarValue::Decimal32(Some(42), 9, 0),
        ScalarValue::Decimal32(Some(42), 9, 0),
        ScalarValue::Decimal32(Some(43), 9, 0),
    );

    // ===== Decimal64 =====
    // scale=2: 100.00, -5.00 and 0.00
    for (raw, next) in [(10000, 10100), (-500, -400), (0, 100)] {
        assert_preimage_range(
            ScalarValue::Decimal64(Some(raw), 18, 2),
            ScalarValue::Decimal64(Some(raw), 18, 2),
            ScalarValue::Decimal64(Some(next), 18, 2),
        );
    }

    // ===== Decimal128 =====
    // scale=2: 100.00, -5.00 and 0.00
    for (raw, next) in [(10000, 10100), (-500, -400), (0, 100)] {
        assert_preimage_range(
            ScalarValue::Decimal128(Some(raw), 38, 2),
            ScalarValue::Decimal128(Some(raw), 38, 2),
            ScalarValue::Decimal128(Some(next), 38, 2),
        );
    }

    // ===== Decimal256 =====
    // scale=2: 100.00, -5.00 and 0.00
    for (raw, next) in [(10000, 10100), (-500, -400), (0, 100)] {
        assert_preimage_range(
            ScalarValue::Decimal256(Some(i256::from(raw)), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(raw)), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(next)), 76, 2),
        );
    }
}

#[test]
fn test_floor_preimage_decimal_non_integer() {
    // floor never produces a fractional result, so an equation such as
    // floor(x) = 1.30 has no solution and preimage must report None.
    // All literals use scale=2, i.e. raw = value * 100.

    // Decimal32: 1.30, -2.50, 3.70, 0.01
    for raw in [130, -250, 370, 1] {
        assert_preimage_none(ScalarValue::Decimal32(Some(raw), 9, 2));
    }

    // Decimal64: 1.30, -2.50
    for raw in [130, -250] {
        assert_preimage_none(ScalarValue::Decimal64(Some(raw), 18, 2));
    }

    // Decimal128: 1.30, -2.50
    for raw in [130, -250] {
        assert_preimage_none(ScalarValue::Decimal128(Some(raw), 38, 2));
    }

    // Decimal256: 1.30, -2.50
    for raw in [130, -250] {
        assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(raw)), 76, 2));
    }
}

#[test]
fn test_floor_preimage_decimal_overflow() {
    // Every literal here is integer-valued at its scale, so the fractional-part check
    // passes and the only reason for `None` is that computing the upper bound
    // (`value + 1` expressed at the literal's scale) overflows the native integer type.
    //
    // For scale=2 the step is 100, so the literal is the largest multiple of 100 that
    // fits in the native type (`MAX - MAX % 100`); adding 100 to it must overflow.
    // A value such as `MAX - 50` would be rejected earlier for having a fractional
    // part and would never reach the overflow check.
    //
    // NOTE(review): these raw values have more digits than the declared precision
    // allows; they are used only to reach the native-integer overflow path.

    // Decimal32
    assert_preimage_none(ScalarValue::Decimal32(
        Some(i32::MAX - i32::MAX % 100),
        9,
        2,
    ));
    // For scale=0 the step is 1, so i32::MAX itself overflows
    assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 9, 0));

    // Decimal64
    assert_preimage_none(ScalarValue::Decimal64(
        Some(i64::MAX - i64::MAX % 100),
        18,
        2,
    ));
    assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 18, 0));

    // Decimal128
    assert_preimage_none(ScalarValue::Decimal128(
        Some(i128::MAX - i128::MAX % 100),
        38,
        2,
    ));
    assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX), 38, 0));

    // Decimal256
    assert_preimage_none(ScalarValue::Decimal256(
        Some(i256::MAX.wrapping_sub(i256::MAX % i256::from(100))),
        76,
        2,
    ));
    assert_preimage_none(ScalarValue::Decimal256(Some(i256::MAX), 76, 0));
}

#[test]
fn test_floor_preimage_decimal_edge_cases() {
    // All literals use scale=2, so the step to the exclusive upper bound is 100.
    // "Near max" literals are the native MAX minus 100, rounded toward zero to a
    // multiple of 100, so adding the step still fits. "Near min" literals are the
    // native MIN rounded toward zero to a multiple of 100; adding to them is safe.

    // ===== Decimal32 =====
    let near_max_32 = (i32::MAX - 100) / 100 * 100;
    let near_min_32 = i32::MIN / 100 * 100;
    for raw in [near_max_32, near_min_32] {
        assert_preimage_range(
            ScalarValue::Decimal32(Some(raw), 9, 2),
            ScalarValue::Decimal32(Some(raw), 9, 2),
            ScalarValue::Decimal32(Some(raw + 100), 9, 2),
        );
    }

    // ===== Decimal64 =====
    let near_max_64 = (i64::MAX - 100) / 100 * 100;
    let near_min_64 = i64::MIN / 100 * 100;
    for raw in [near_max_64, near_min_64] {
        assert_preimage_range(
            ScalarValue::Decimal64(Some(raw), 18, 2),
            ScalarValue::Decimal64(Some(raw), 18, 2),
            ScalarValue::Decimal64(Some(raw + 100), 18, 2),
        );
    }

    // ===== Decimal128 =====
    let near_max_128 = (i128::MAX - 100) / 100 * 100;
    let near_min_128 = i128::MIN / 100 * 100;
    for raw in [near_max_128, near_min_128] {
        assert_preimage_range(
            ScalarValue::Decimal128(Some(raw), 38, 2),
            ScalarValue::Decimal128(Some(raw), 38, 2),
            ScalarValue::Decimal128(Some(raw + 100), 38, 2),
        );
    }

    // ===== Decimal256 =====
    // i256::MAX is huge, so a large positive and a large negative i64-sized
    // magnitude are used instead of values next to the type limits.
    for magnitude in [1_000_000_000_000i64, -1_000_000_000_000i64] {
        let raw = i256::from(magnitude);
        assert_preimage_range(
            ScalarValue::Decimal256(Some(raw), 76, 2),
            ScalarValue::Decimal256(Some(raw), 76, 2),
            ScalarValue::Decimal256(Some(raw.wrapping_add(i256::from(100))), 76, 2),
        );
    }
}

#[test]
fn test_floor_preimage_decimal_null() {
    // A NULL literal of any decimal width carries no value to build bounds from,
    // so preimage is expected to report None for each of them.
    let null_literals = [
        ScalarValue::Decimal32(None, 9, 2),
        ScalarValue::Decimal64(None, 18, 2),
        ScalarValue::Decimal128(None, 38, 2),
        ScalarValue::Decimal256(None, 76, 2),
    ];
    for literal in null_literals {
        assert_preimage_none(literal);
    }
}
}
Loading