Skip to content

Commit 2bd4986

Browse files
committed
record data type
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
1 parent f5c0935 commit 2bd4986

File tree

1 file changed

+23
-33
lines changed
  • datafusion/functions-aggregate/src

1 file changed

+23
-33
lines changed

datafusion/functions-aggregate/src/count.rs

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
// under the License.
1717

1818
use ahash::RandomState;
19-
use arrow::array::NullArray;
2019
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
2120
use datafusion_common::stats::Precision;
2221
use datafusion_expr::expr::WindowFunction;
@@ -212,7 +211,9 @@ impl AggregateUDFImpl for Count {
212211
format_state_name(args.name, "count distinct"),
213212
// See COMMENTS.md to understand why nullable is set to true
214213
Field::new_list_field(args.input_types[0].clone(), true),
215-
false,
214+
// For group count distinct accumulator, null list item stands for an
215+
// empty value set (i.e., all NULL value so far for that group).
216+
true,
216217
)])
217218
} else {
218219
Ok(vec![Field::new(
@@ -360,7 +361,9 @@ impl AggregateUDFImpl for Count {
360361
) -> Result<Box<dyn GroupsAccumulator>> {
361362
// instantiate specialized accumulator
362363
if args.is_distinct {
363-
Ok(Box::new(DistinctCountGroupsAccumulator::new()))
364+
Ok(Box::new(DistinctCountGroupsAccumulator::new(
365+
args.exprs[0].data_type(args.schema)?,
366+
)))
364367
} else {
365368
Ok(Box::new(CountGroupsAccumulator::new()))
366369
}
@@ -759,15 +762,19 @@ impl Accumulator for DistinctCountAccumulator {
759762
}
760763

761764
/// GroupsAccumulator for COUNT DISTINCT operations
762-
#[derive(Debug, Default)]
765+
#[derive(Debug)]
763766
pub struct DistinctCountGroupsAccumulator {
764767
/// One HashSet per group to track distinct values
765768
distinct_sets: Vec<HashSet<ScalarValue, RandomState>>,
769+
data_type: DataType,
766770
}
767771

768772
impl DistinctCountGroupsAccumulator {
769-
pub fn new() -> Self {
770-
Self::default()
773+
pub fn new(data_type: DataType) -> Self {
774+
Self {
775+
distinct_sets: vec![],
776+
data_type,
777+
}
771778
}
772779

773780
fn ensure_sets(&mut self, total_num_groups: usize) {
@@ -850,30 +857,13 @@ impl GroupsAccumulator for DistinctCountGroupsAccumulator {
850857
}
851858

852859
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
853-
// Convert counts to Int64Array
854-
let counts = match emit_to {
855-
EmitTo::All => {
856-
let counts: Vec<i64> = self
857-
.distinct_sets
858-
.iter()
859-
.map(|set| set.len() as i64)
860-
.collect();
861-
self.distinct_sets.clear();
862-
counts
863-
}
864-
EmitTo::First(n) => {
865-
let counts: Vec<i64> = self
866-
.distinct_sets
867-
.iter()
868-
.take(n)
869-
.map(|set| set.len() as i64)
870-
.collect();
871-
self.distinct_sets = self.distinct_sets.split_off(n);
872-
counts
873-
}
874-
};
860+
let distinct_sets: Vec<HashSet<ScalarValue, RandomState>> =
861+
emit_to.take_needed(&mut self.distinct_sets);
875862

876-
// COUNT DISTINCT never returns nulls
863+
let counts = distinct_sets
864+
.iter()
865+
.map(|set| set.len() as i64)
866+
.collect::<Vec<_>>();
877867
Ok(Arc::new(Int64Array::from(counts)))
878868
}
879869

@@ -929,14 +919,14 @@ impl GroupsAccumulator for DistinctCountGroupsAccumulator {
929919
})
930920
.peekable();
931921
let data_array: ArrayRef = if value_iter.peek().is_none() {
932-
Arc::new(NullArray::new(0)) as _
922+
arrow::array::new_empty_array(&self.data_type) as _
933923
} else {
934924
Arc::new(ScalarValue::iter_to_array(value_iter)?) as _
935925
};
936926
let offset_buffer = OffsetBuffer::new(ScalarBuffer::from(offsets));
937927

938928
let list_array = ListArray::new(
939-
Arc::new(Field::new_list_field(data_array.data_type().clone(), true)),
929+
Arc::new(Field::new_list_field(self.data_type.clone(), true)),
940930
offset_buffer,
941931
data_array,
942932
None,
@@ -1021,7 +1011,7 @@ mod tests {
10211011

10221012
#[test]
10231013
fn test_distinct_count_groups_basic() -> Result<()> {
1024-
let mut accumulator = DistinctCountGroupsAccumulator::new();
1014+
let mut accumulator = DistinctCountGroupsAccumulator::new(DataType::Int32);
10251015
let values = vec![Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 1])) as ArrayRef];
10261016

10271017
// 3 groups
@@ -1043,7 +1033,7 @@ mod tests {
10431033

10441034
#[test]
10451035
fn test_distinct_count_groups_with_filter() -> Result<()> {
1046-
let mut accumulator = DistinctCountGroupsAccumulator::new();
1036+
let mut accumulator = DistinctCountGroupsAccumulator::new(DataType::Utf8);
10471037
let values = vec![
10481038
Arc::new(StringArray::from(vec!["a", "b", "a", "c", "b", "d"])) as ArrayRef,
10491039
];

0 commit comments

Comments
 (0)