Skip to content

Commit ca805fb

Browse files
committed
also for normal accumulator
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
1 parent 86390ef commit ca805fb

File tree

1 file changed

+38
-135
lines changed
  • datafusion/functions-aggregate/src

1 file changed

+38
-135
lines changed

datafusion/functions-aggregate/src/count.rs

Lines changed: 38 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use arrow::array::{
2323
use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
2424
use datafusion_common::hash_utils::create_hashes;
2525
use datafusion_common::stats::Precision;
26+
use datafusion_common::utils::SingleRowListArrayBuilder;
2627
use datafusion_expr::expr::WindowFunction;
2728
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
2829
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
@@ -236,8 +237,6 @@ impl AggregateUDFImpl for Count {
236237
}
237238

238239
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
239-
panic!("should not create normal accumulator");
240-
241240
if !acc_args.is_distinct {
242241
return Ok(Box::new(CountAccumulator::new()));
243242
}
@@ -246,116 +245,11 @@ impl AggregateUDFImpl for Count {
246245
return not_impl_err!("COUNT DISTINCT with multiple arguments");
247246
}
248247

249-
let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
250-
Ok(match data_type {
251-
// try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
252-
DataType::Int8 => Box::new(
253-
PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
254-
),
255-
DataType::Int16 => Box::new(
256-
PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
257-
),
258-
DataType::Int32 => Box::new(
259-
PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
260-
),
261-
DataType::Int64 => Box::new(
262-
PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
263-
),
264-
DataType::UInt8 => Box::new(
265-
PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
266-
),
267-
DataType::UInt16 => Box::new(
268-
PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
269-
),
270-
DataType::UInt32 => Box::new(
271-
PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
272-
),
273-
DataType::UInt64 => Box::new(
274-
PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
275-
),
276-
DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
277-
Decimal128Type,
278-
>::new(data_type)),
279-
DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
280-
Decimal256Type,
281-
>::new(data_type)),
282-
283-
DataType::Date32 => Box::new(
284-
PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
285-
),
286-
DataType::Date64 => Box::new(
287-
PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
288-
),
289-
DataType::Time32(TimeUnit::Millisecond) => Box::new(
290-
PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
291-
data_type,
292-
),
293-
),
294-
DataType::Time32(TimeUnit::Second) => Box::new(
295-
PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
296-
),
297-
DataType::Time64(TimeUnit::Microsecond) => Box::new(
298-
PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
299-
data_type,
300-
),
301-
),
302-
DataType::Time64(TimeUnit::Nanosecond) => Box::new(
303-
PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
304-
),
305-
DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
306-
PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
307-
data_type,
308-
),
309-
),
310-
DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
311-
PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
312-
data_type,
313-
),
314-
),
315-
DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
316-
PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
317-
data_type,
318-
),
319-
),
320-
DataType::Timestamp(TimeUnit::Second, _) => Box::new(
321-
PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
322-
),
323-
324-
DataType::Float16 => {
325-
Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
326-
}
327-
DataType::Float32 => {
328-
Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
329-
}
330-
DataType::Float64 => {
331-
Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
332-
}
333-
334-
DataType::Utf8 => {
335-
Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
336-
}
337-
DataType::Utf8View => {
338-
Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
339-
}
340-
DataType::LargeUtf8 => {
341-
Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
342-
}
343-
DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
344-
OutputType::Binary,
345-
)),
346-
DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
347-
OutputType::BinaryView,
348-
)),
349-
DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
350-
OutputType::Binary,
351-
)),
352-
353-
// Use the generic accumulator based on `ScalarValue` for all other types
354-
_ => Box::new(DistinctCountAccumulator {
355-
values: HashSet::default(),
356-
state_data_type: data_type.clone(),
357-
}),
358-
})
248+
Ok(Box::new(DistinctCountAccumulator {
249+
values: HashSet::default(),
250+
random_state: RandomState::with_seeds(1, 2, 3, 4),
251+
batch_hashes: vec![],
252+
}))
359253
}
360254

361255
fn aliases(&self) -> &[String] {
@@ -681,8 +575,9 @@ fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
681575
/// [`BytesDistinctCountAccumulator`]
682576
#[derive(Debug)]
683577
struct DistinctCountAccumulator {
684-
values: HashSet<ScalarValue, RandomState>,
685-
state_data_type: DataType,
578+
values: HashSet<u64, RandomState>,
579+
random_state: RandomState,
580+
batch_hashes: Vec<u64>,
686581
}
687582

688583
impl DistinctCountAccumulator {
@@ -691,12 +586,12 @@ impl DistinctCountAccumulator {
691586
// not suitable for variable length values like strings or complex types
692587
fn fixed_size(&self) -> usize {
693588
size_of_val(self)
694-
+ (size_of::<ScalarValue>() * self.values.capacity())
589+
+ (size_of::<u64>() * self.values.capacity())
695590
+ self
696591
.values
697592
.iter()
698593
.next()
699-
.map(|vals| ScalarValue::size(vals) - size_of_val(vals))
594+
.map(|vals| 8 - size_of_val(vals))
700595
.unwrap_or(0)
701596
+ size_of::<DataType>()
702597
}
@@ -705,11 +600,11 @@ impl DistinctCountAccumulator {
705600
// method is expensive
706601
fn full_size(&self) -> usize {
707602
size_of_val(self)
708-
+ (size_of::<ScalarValue>() * self.values.capacity())
603+
+ (size_of::<u64>() * self.values.capacity())
709604
+ self
710605
.values
711606
.iter()
712-
.map(|vals| ScalarValue::size(vals) - size_of_val(vals))
607+
.map(|vals| 8 - size_of_val(vals))
713608
.sum::<usize>()
714609
+ size_of::<DataType>()
715610
}
@@ -718,10 +613,10 @@ impl DistinctCountAccumulator {
718613
impl Accumulator for DistinctCountAccumulator {
719614
/// Returns the distinct values seen so far as (one element) ListArray.
720615
fn state(&mut self) -> Result<Vec<ScalarValue>> {
721-
let scalars = self.values.iter().cloned().collect::<Vec<_>>();
722-
let arr =
723-
ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
724-
Ok(vec![ScalarValue::List(arr)])
616+
let values = self.values.iter().cloned().collect::<Vec<_>>();
617+
let arr = Arc::new(UInt64Array::from(values)) as _;
618+
let list_scalar = SingleRowListArrayBuilder::new(arr).build_list_scalar();
619+
Ok(vec![list_scalar])
725620
}
726621

727622
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
@@ -734,13 +629,21 @@ impl Accumulator for DistinctCountAccumulator {
734629
return Ok(());
735630
}
736631

737-
(0..arr.len()).try_for_each(|index| {
738-
if !arr.is_null(index) {
739-
let scalar = ScalarValue::try_from_array(arr, index)?;
740-
self.values.insert(scalar);
741-
}
742-
Ok(())
743-
})
632+
// (0..arr.len()).try_for_each(|index| {
633+
// if !arr.is_null(index) {
634+
// let scalar = ScalarValue::try_from_array(arr, index)?;
635+
// self.values.insert(scalar);
636+
// }
637+
// Ok(())
638+
// })
639+
self.batch_hashes.clear();
640+
self.batch_hashes.resize(arr.len(), 0);
641+
let hashes =
642+
create_hashes(&[arr.clone()], &self.random_state, &mut self.batch_hashes)?;
643+
for hash in hashes.as_slice() {
644+
self.values.insert(*hash);
645+
}
646+
Ok(())
744647
}
745648

746649
/// Merges multiple sets of distinct values into the current set.
@@ -761,7 +664,11 @@ impl Accumulator for DistinctCountAccumulator {
761664
"Intermediate results of COUNT DISTINCT should always be non null"
762665
);
763666
};
764-
self.update_batch(&[inner_array])?;
667+
// self.update_batch(&[inner_array])?;
668+
let hash_array = inner_array.as_any().downcast_ref::<UInt64Array>().unwrap();
669+
for i in 0..hash_array.len() {
670+
self.values.insert(hash_array.value(i));
671+
}
765672
}
766673
Ok(())
767674
}
@@ -771,11 +678,7 @@ impl Accumulator for DistinctCountAccumulator {
771678
}
772679

773680
fn size(&self) -> usize {
774-
match &self.state_data_type {
775-
DataType::Boolean | DataType::Null => self.fixed_size(),
776-
d if d.is_primitive() => self.fixed_size(),
777-
_ => self.full_size(),
778-
}
681+
self.fixed_size()
779682
}
780683
}
781684

0 commit comments

Comments
 (0)