Skip to content

Commit 4032129

Browse files
authored
feat: date_add and date_sub functions (apache#910)
* date_add test case. * Add DateAdd to proto and QueryPlanSerde. Next up is the native side. * Add DateAdd in planner.rs that generates a Literal for right child. Need to confirm if any other type of expression can occur here. * Minor refactor. * Change test predicate to actually select some rows. * Switch to scalar UDF implementation for date_add. * Docs and minor refactor. * Add a new test to explicitly cover array scenario. * cargo clippy fixes * Fix Scala 2.13. * New approved plans for q72 due to date_add. * Address first round of feedback. * Add date_sub and tests. * Fix error message to be more general. * Update error message for Spark 4.0+ * Support Int8 and Int16 for days.
1 parent acd5aac commit 4032129

16 files changed

Lines changed: 1380 additions & 1411 deletions

File tree

native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ use datafusion_comet_spark_expr::scalar_funcs::hash_expressions::{
2020
spark_sha224, spark_sha256, spark_sha384, spark_sha512,
2121
};
2222
use datafusion_comet_spark_expr::scalar_funcs::{
23-
spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
24-
spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex, spark_unscaled_value,
25-
spark_xxhash64, SparkChrFunc,
23+
spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
24+
spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round,
25+
spark_unhex, spark_unscaled_value, spark_xxhash64, SparkChrFunc,
2626
};
2727
use datafusion_common::{DataFusionError, Result as DataFusionResult};
2828
use datafusion_expr::registry::FunctionRegistry;
@@ -121,6 +121,14 @@ pub fn create_comet_physical_fun(
121121
let func = Arc::new(spark_sha512);
122122
make_comet_scalar_udf!("sha512", func, without data_type)
123123
}
124+
"date_add" => {
125+
let func = Arc::new(spark_date_add);
126+
make_comet_scalar_udf!("date_add", func, without data_type)
127+
}
128+
"date_sub" => {
129+
let func = Arc::new(spark_date_sub);
130+
make_comet_scalar_udf!("date_sub", func, without data_type)
131+
}
124132
_ => registry.udf(fun_name).map_err(|e| {
125133
DataFusionError::Execution(format!(
126134
"Function {fun_name} not found in the registry: {e}",

native/spark-expr/src/scalar_funcs.rs

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,20 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::compute::kernels::numeric::{add, sub};
19+
use arrow::datatypes::IntervalDayTime;
1820
use arrow::{
1921
array::{
2022
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
2123
Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
2224
},
2325
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
2426
};
25-
use arrow_array::builder::GenericStringBuilder;
26-
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
27-
use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
27+
use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder};
28+
use arrow_array::types::{Int16Type, Int32Type, Int8Type};
29+
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array};
30+
use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
31+
use datafusion::physical_expr_common::datum;
2832
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
2933
use datafusion_common::{
3034
cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
@@ -547,3 +551,76 @@ pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionEr
547551
},
548552
}
549553
}
554+
555+
/// Expands to code that applies `$op` between the date operand `$start` and a *scalar*
/// day count `$days`, by wrapping the day count in a zero-millisecond `IntervalDayTime`
/// scalar that Arrow's date arithmetic kernels understand.
macro_rules! scalar_date_arithmetic {
    ($start:expr, $days:expr, $op:expr) => {{
        let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(
            IntervalDayTime::new(*$days as i32, 0),
        )));
        datum::apply($start, &interval_cv, $op)
    }};
}
562+
/// Expands to a loop that copies a primitive day-count array (element type `$intType`)
/// into `$interval_builder` as zero-millisecond `IntervalDayTime` values, preserving nulls.
macro_rules! array_date_arithmetic {
    ($days:expr, $interval_builder:expr, $intType:ty) => {{
        for day in $days.as_primitive::<$intType>() {
            match day {
                Some(d) => $interval_builder.append_value(IntervalDayTime::new(d as i32, 0)),
                None => $interval_builder.append_null(),
            }
        }
    }};
}
573+
574+
/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second
575+
/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the
576+
/// second argument and use DataFusion's interface to apply Arrow's operators.
577+
fn spark_date_arithmetic(
578+
args: &[ColumnarValue],
579+
op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
580+
) -> Result<ColumnarValue, DataFusionError> {
581+
let start = &args[0];
582+
match &args[1] {
583+
ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
584+
scalar_date_arithmetic!(start, days, op)
585+
}
586+
ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
587+
scalar_date_arithmetic!(start, days, op)
588+
}
589+
ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
590+
scalar_date_arithmetic!(start, days, op)
591+
}
592+
ColumnarValue::Array(days) => {
593+
let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
594+
match days.data_type() {
595+
DataType::Int8 => {
596+
array_date_arithmetic!(days, interval_builder, Int8Type)
597+
}
598+
DataType::Int16 => {
599+
array_date_arithmetic!(days, interval_builder, Int16Type)
600+
}
601+
DataType::Int32 => {
602+
array_date_arithmetic!(days, interval_builder, Int32Type)
603+
}
604+
_ => {
605+
return Err(DataFusionError::Internal(format!(
606+
"Unsupported data types {:?} for date arithmetic.",
607+
args,
608+
)))
609+
}
610+
}
611+
let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
612+
datum::apply(start, &interval_cv, op)
613+
}
614+
_ => Err(DataFusionError::Internal(format!(
615+
"Unsupported data types {:?} for date arithmetic.",
616+
args,
617+
))),
618+
}
619+
}
620+
pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
621+
spark_date_arithmetic(args, add)
622+
}
623+
624+
pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
625+
spark_date_arithmetic(args, sub)
626+
}

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,6 +1492,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
14921492
None
14931493
}
14941494

1495+
case DateAdd(left, right) =>
1496+
val leftExpr = exprToProtoInternal(left, inputs)
1497+
val rightExpr = exprToProtoInternal(right, inputs)
1498+
val optExpr = scalarExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr)
1499+
optExprWithInfo(optExpr, expr, left, right)
1500+
1501+
case DateSub(left, right) =>
1502+
val leftExpr = exprToProtoInternal(left, inputs)
1503+
val rightExpr = exprToProtoInternal(right, inputs)
1504+
val optExpr = scalarExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr)
1505+
optExprWithInfo(optExpr, expr, left, right)
1506+
14951507
case TruncDate(child, format) =>
14961508
val childExpr = exprToProtoInternal(child, inputs)
14971509
val formatExpr = exprToProtoInternal(format, inputs)

0 commit comments

Comments
 (0)