@@ -24,6 +24,7 @@ use std::sync::Arc;
2424use crate::error::{DataFusionError, Result};
2525use crate::physical_plan::groups_accumulator::GroupsAccumulator;
2626use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
27+ use crate::physical_plan::groups_accumulator_prim_op::PrimitiveGroupsAccumulator;
2728use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2829use crate::scalar::ScalarValue;
2930use arrow::compute;
@@ -49,6 +50,7 @@ use smallvec::SmallVec;
4950pub struct Sum {
5051 name: String,
5152 data_type: DataType,
53+ input_data_type: DataType,
5254 expr: Arc<dyn PhysicalExpr>,
5355 nullable: bool,
5456}
@@ -80,11 +82,16 @@ impl Sum {
8082 expr: Arc<dyn PhysicalExpr>,
8183 name: impl Into<String>,
8284 data_type: DataType,
85+ input_data_type: &DataType,
8386 ) -> Self {
87+ // Note: in the actual caller data_type = sum_return_type(input_data_type), so we don't
88+ // strictly need both type params. We keep all four params anyway, to deliberately break
89+ // symmetry with other accumulators and with 3-param callers such as generic_test_op!.
8490 Self {
8591 name: name.into(),
8692 expr,
8793 data_type,
94+ input_data_type: input_data_type.clone(),
8895 nullable: true,
8996 }
9097 }
@@ -127,12 +134,147 @@ impl AggregateExpr for Sum {
127134 fn create_groups_accumulator(
128135 &self,
129136 ) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
130- let data_type = self.data_type.clone();
131- Ok(Some(Box::new(
132- GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(move || {
133- SumAccumulator::try_new(&data_type)
134- }),
135- )))
137+ use arrow::datatypes::ArrowPrimitiveType;
138+
139+ macro_rules! make_accumulator {
140+ ($T:ty, $U:ty) => {
141+ Box::new(PrimitiveGroupsAccumulator::<$T, $U, _, _>::new(
142+ &<$T as ArrowPrimitiveType>::DATA_TYPE,
143+ |x: &mut <$T as ArrowPrimitiveType>::Native,
144+ y: <$U as ArrowPrimitiveType>::Native| {
145+ *x = *x + (y as <$T as ArrowPrimitiveType>::Native);
146+ },
147+ |x: &mut <$T as ArrowPrimitiveType>::Native,
148+ y: <$T as ArrowPrimitiveType>::Native| {
149+ *x = *x + y;
150+ },
151+ ))
152+ };
153+ }
154+
155+ // Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
156+ // the current datafusion Sum accumulator implementation using native +. (That native +
157+ // specifically is the one in the expressions *x = *x + ... above.)
158+ Ok(Some(match (&self.data_type, &self.input_data_type) {
159+ (DataType::Int64, DataType::Int64) => make_accumulator!(
160+ arrow::datatypes::Int64Type,
161+ arrow::datatypes::Int64Type
162+ ),
163+ (DataType::Int64, DataType::Int32) => make_accumulator!(
164+ arrow::datatypes::Int64Type,
165+ arrow::datatypes::Int32Type
166+ ),
167+ (DataType::Int64, DataType::Int16) => make_accumulator!(
168+ arrow::datatypes::Int64Type,
169+ arrow::datatypes::Int16Type
170+ ),
171+ (DataType::Int64, DataType::Int8) => {
172+ make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int8Type)
173+ }
174+
175+ (DataType::Int96, DataType::Int96) => make_accumulator!(
176+ arrow::datatypes::Int96Type,
177+ arrow::datatypes::Int96Type
178+ ),
179+
180+ (DataType::Int64Decimal(0), DataType::Int64Decimal(0)) => make_accumulator!(
181+ arrow::datatypes::Int64Decimal0Type,
182+ arrow::datatypes::Int64Decimal0Type
183+ ),
184+ (DataType::Int64Decimal(1), DataType::Int64Decimal(1)) => make_accumulator!(
185+ arrow::datatypes::Int64Decimal1Type,
186+ arrow::datatypes::Int64Decimal1Type
187+ ),
188+ (DataType::Int64Decimal(2), DataType::Int64Decimal(2)) => make_accumulator!(
189+ arrow::datatypes::Int64Decimal2Type,
190+ arrow::datatypes::Int64Decimal2Type
191+ ),
192+ (DataType::Int64Decimal(3), DataType::Int64Decimal(3)) => make_accumulator!(
193+ arrow::datatypes::Int64Decimal3Type,
194+ arrow::datatypes::Int64Decimal3Type
195+ ),
196+ (DataType::Int64Decimal(4), DataType::Int64Decimal(4)) => make_accumulator!(
197+ arrow::datatypes::Int64Decimal4Type,
198+ arrow::datatypes::Int64Decimal4Type
199+ ),
200+ (DataType::Int64Decimal(5), DataType::Int64Decimal(5)) => make_accumulator!(
201+ arrow::datatypes::Int64Decimal5Type,
202+ arrow::datatypes::Int64Decimal5Type
203+ ),
204+ (DataType::Int64Decimal(10), DataType::Int64Decimal(10)) => {
205+ make_accumulator!(
206+ arrow::datatypes::Int64Decimal10Type,
207+ arrow::datatypes::Int64Decimal10Type
208+ )
209+ }
210+
211+ (DataType::Int96Decimal(0), DataType::Int96Decimal(0)) => make_accumulator!(
212+ arrow::datatypes::Int96Decimal0Type,
213+ arrow::datatypes::Int96Decimal0Type
214+ ),
215+ (DataType::Int96Decimal(1), DataType::Int96Decimal(1)) => make_accumulator!(
216+ arrow::datatypes::Int96Decimal1Type,
217+ arrow::datatypes::Int96Decimal1Type
218+ ),
219+ (DataType::Int96Decimal(2), DataType::Int96Decimal(2)) => make_accumulator!(
220+ arrow::datatypes::Int96Decimal2Type,
221+ arrow::datatypes::Int96Decimal2Type
222+ ),
223+ (DataType::Int96Decimal(3), DataType::Int96Decimal(3)) => make_accumulator!(
224+ arrow::datatypes::Int96Decimal3Type,
225+ arrow::datatypes::Int96Decimal3Type
226+ ),
227+ (DataType::Int96Decimal(4), DataType::Int96Decimal(4)) => make_accumulator!(
228+ arrow::datatypes::Int96Decimal4Type,
229+ arrow::datatypes::Int96Decimal4Type
230+ ),
231+ (DataType::Int96Decimal(5), DataType::Int96Decimal(5)) => make_accumulator!(
232+ arrow::datatypes::Int96Decimal5Type,
233+ arrow::datatypes::Int96Decimal5Type
234+ ),
235+ (DataType::Int96Decimal(10), DataType::Int96Decimal(10)) => {
236+ make_accumulator!(
237+ arrow::datatypes::Int96Decimal10Type,
238+ arrow::datatypes::Int96Decimal10Type
239+ )
240+ }
241+
242+ (DataType::UInt64, DataType::UInt64) => make_accumulator!(
243+ arrow::datatypes::UInt64Type,
244+ arrow::datatypes::UInt64Type
245+ ),
246+ (DataType::UInt64, DataType::UInt32) => make_accumulator!(
247+ arrow::datatypes::UInt64Type,
248+ arrow::datatypes::UInt32Type
249+ ),
250+ (DataType::UInt64, DataType::UInt16) => make_accumulator!(
251+ arrow::datatypes::UInt64Type,
252+ arrow::datatypes::UInt16Type
253+ ),
254+ (DataType::UInt64, DataType::UInt8) => make_accumulator!(
255+ arrow::datatypes::UInt64Type,
256+ arrow::datatypes::UInt8Type
257+ ),
258+
259+ (DataType::Float32, DataType::Float32) => make_accumulator!(
260+ arrow::datatypes::Float32Type,
261+ arrow::datatypes::Float32Type
262+ ),
263+ (DataType::Float64, DataType::Float64) => make_accumulator!(
264+ arrow::datatypes::Float64Type,
265+ arrow::datatypes::Float64Type
266+ ),
267+
268+ _ => {
269+ // This case should never be reached because we've handled all sum_return_type
270+ // arg_type values. Nonetheless:
271+ let data_type = self.data_type.clone();
272+
273+ Box::new(GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(
274+ move || SumAccumulator::try_new(&data_type),
275+ ))
276+ }
277+ }))
136278 }
137279
138280 fn name(&self) -> &str {
@@ -416,13 +558,27 @@ mod tests {
416558 use arrow::datatypes::*;
417559 use arrow::record_batch::RecordBatch;
418560
561+ // A wrapper to make Sum::new, which now has an input_type argument, work with
562+ // generic_test_op!.
563+ struct SumTestStandin;
564+ impl SumTestStandin {
565+ fn new(
566+ expr: Arc<dyn PhysicalExpr>,
567+ name: impl Into<String>,
568+ data_type: DataType,
569+ ) -> Sum {
570+ Sum::new(expr, name, data_type.clone(), &data_type)
571+ }
572+ }
573+
419574 #[test]
420575 fn sum_i32() -> Result<()> {
421576 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
577+
422578 generic_test_op!(
423579 a,
424580 DataType::Int32,
425- Sum ,
581+ SumTestStandin ,
426582 ScalarValue::from(15i64),
427583 DataType::Int64
428584 )
@@ -440,7 +596,7 @@ mod tests {
440596 generic_test_op!(
441597 a,
442598 DataType::Int32,
443- Sum ,
599+ SumTestStandin ,
444600 ScalarValue::from(13i64),
445601 DataType::Int64
446602 )
@@ -452,7 +608,7 @@ mod tests {
452608 generic_test_op!(
453609 a,
454610 DataType::Int32,
455- Sum ,
611+ SumTestStandin ,
456612 ScalarValue::Int64(None),
457613 DataType::Int64
458614 )
@@ -465,7 +621,7 @@ mod tests {
465621 generic_test_op!(
466622 a,
467623 DataType::UInt32,
468- Sum ,
624+ SumTestStandin ,
469625 ScalarValue::from(15u64),
470626 DataType::UInt64
471627 )
@@ -478,7 +634,7 @@ mod tests {
478634 generic_test_op!(
479635 a,
480636 DataType::Float32,
481- Sum ,
637+ SumTestStandin ,
482638 ScalarValue::from(15_f32),
483639 DataType::Float32
484640 )
@@ -491,7 +647,7 @@ mod tests {
491647 generic_test_op!(
492648 a,
493649 DataType::Float64,
494- Sum ,
650+ SumTestStandin ,
495651 ScalarValue::from(15_f64),
496652 DataType::Float64
497653 )
0 commit comments