1616// under the License.
1717
1818use ahash:: RandomState ;
19- use arrow:: array:: NullArray ;
2019use arrow:: buffer:: { OffsetBuffer , ScalarBuffer } ;
2120use datafusion_common:: stats:: Precision ;
2221use datafusion_expr:: expr:: WindowFunction ;
@@ -212,7 +211,9 @@ impl AggregateUDFImpl for Count {
212211 format_state_name( args. name, "count distinct" ) ,
213212 // See COMMENTS.md to understand why nullable is set to true
214213 Field :: new_list_field( args. input_types[ 0 ] . clone( ) , true ) ,
215- false ,
214+ // For group count distinct accumulator, null list item stands for an
215+ // empty value set (i.e., all NULL value so far for that group).
216+ true ,
216217 ) ] )
217218 } else {
218219 Ok ( vec ! [ Field :: new(
@@ -360,7 +361,9 @@ impl AggregateUDFImpl for Count {
360361 ) -> Result < Box < dyn GroupsAccumulator > > {
361362 // instantiate specialized accumulator
362363 if args. is_distinct {
363- Ok ( Box :: new ( DistinctCountGroupsAccumulator :: new ( ) ) )
364+ Ok ( Box :: new ( DistinctCountGroupsAccumulator :: new (
365+ args. exprs [ 0 ] . data_type ( args. schema ) ?,
366+ ) ) )
364367 } else {
365368 Ok ( Box :: new ( CountGroupsAccumulator :: new ( ) ) )
366369 }
@@ -759,15 +762,19 @@ impl Accumulator for DistinctCountAccumulator {
759762}
760763
761764/// GroupsAccumulator for COUNT DISTINCT operations
762- #[ derive( Debug , Default ) ]
765+ #[ derive( Debug ) ]
763766pub struct DistinctCountGroupsAccumulator {
764767 /// One HashSet per group to track distinct values
765768 distinct_sets : Vec < HashSet < ScalarValue , RandomState > > ,
769+ data_type : DataType ,
766770}
767771
768772impl DistinctCountGroupsAccumulator {
769- pub fn new ( ) -> Self {
770- Self :: default ( )
773+ pub fn new ( data_type : DataType ) -> Self {
774+ Self {
775+ distinct_sets : vec ! [ ] ,
776+ data_type,
777+ }
771778 }
772779
773780 fn ensure_sets ( & mut self , total_num_groups : usize ) {
@@ -850,30 +857,13 @@ impl GroupsAccumulator for DistinctCountGroupsAccumulator {
850857 }
851858
852859 fn evaluate ( & mut self , emit_to : EmitTo ) -> Result < ArrayRef > {
853- // Convert counts to Int64Array
854- let counts = match emit_to {
855- EmitTo :: All => {
856- let counts: Vec < i64 > = self
857- . distinct_sets
858- . iter ( )
859- . map ( |set| set. len ( ) as i64 )
860- . collect ( ) ;
861- self . distinct_sets . clear ( ) ;
862- counts
863- }
864- EmitTo :: First ( n) => {
865- let counts: Vec < i64 > = self
866- . distinct_sets
867- . iter ( )
868- . take ( n)
869- . map ( |set| set. len ( ) as i64 )
870- . collect ( ) ;
871- self . distinct_sets = self . distinct_sets . split_off ( n) ;
872- counts
873- }
874- } ;
860+ let distinct_sets: Vec < HashSet < ScalarValue , RandomState > > =
861+ emit_to. take_needed ( & mut self . distinct_sets ) ;
875862
876- // COUNT DISTINCT never returns nulls
863+ let counts = distinct_sets
864+ . iter ( )
865+ . map ( |set| set. len ( ) as i64 )
866+ . collect :: < Vec < _ > > ( ) ;
877867 Ok ( Arc :: new ( Int64Array :: from ( counts) ) )
878868 }
879869
@@ -929,14 +919,14 @@ impl GroupsAccumulator for DistinctCountGroupsAccumulator {
929919 } )
930920 . peekable ( ) ;
931921 let data_array: ArrayRef = if value_iter. peek ( ) . is_none ( ) {
932- Arc :: new ( NullArray :: new ( 0 ) ) as _
922+ arrow :: array :: new_empty_array ( & self . data_type ) as _
933923 } else {
934924 Arc :: new ( ScalarValue :: iter_to_array ( value_iter) ?) as _
935925 } ;
936926 let offset_buffer = OffsetBuffer :: new ( ScalarBuffer :: from ( offsets) ) ;
937927
938928 let list_array = ListArray :: new (
939- Arc :: new ( Field :: new_list_field ( data_array . data_type ( ) . clone ( ) , true ) ) ,
929+ Arc :: new ( Field :: new_list_field ( self . data_type . clone ( ) , true ) ) ,
940930 offset_buffer,
941931 data_array,
942932 None ,
@@ -1021,7 +1011,7 @@ mod tests {
10211011
10221012 #[ test]
10231013 fn test_distinct_count_groups_basic ( ) -> Result < ( ) > {
1024- let mut accumulator = DistinctCountGroupsAccumulator :: new ( ) ;
1014+ let mut accumulator = DistinctCountGroupsAccumulator :: new ( DataType :: Int32 ) ;
10251015 let values = vec ! [ Arc :: new( Int32Array :: from( vec![ 1 , 2 , 1 , 3 , 2 , 1 ] ) ) as ArrayRef ] ;
10261016
10271017 // 3 groups
@@ -1043,7 +1033,7 @@ mod tests {
10431033
10441034 #[ test]
10451035 fn test_distinct_count_groups_with_filter ( ) -> Result < ( ) > {
1046- let mut accumulator = DistinctCountGroupsAccumulator :: new ( ) ;
1036+ let mut accumulator = DistinctCountGroupsAccumulator :: new ( DataType :: Utf8 ) ;
10471037 let values = vec ! [
10481038 Arc :: new( StringArray :: from( vec![ "a" , "b" , "a" , "c" , "b" , "d" ] ) ) as ArrayRef ,
10491039 ] ;
0 commit comments