datafusion_functions_aggregate/
count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use ahash::RandomState;
19use datafusion_common::stats::Precision;
20use datafusion_expr::expr::WindowFunction;
21use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
22use datafusion_macros::user_doc;
23use datafusion_physical_expr::expressions;
24use std::collections::HashSet;
25use std::fmt::Debug;
26use std::mem::{size_of, size_of_val};
27use std::ops::BitAnd;
28use std::sync::Arc;
29
30use arrow::{
31    array::{ArrayRef, AsArray},
32    compute,
33    datatypes::{
34        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
35        Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
36        Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
37        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
38        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
39        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
40    },
41};
42
43use arrow::{
44    array::{Array, BooleanArray, Int64Array, PrimitiveArray},
45    buffer::BooleanBuffer,
46};
47use datafusion_common::{
48    downcast_value, internal_err, not_impl_err, Result, ScalarValue,
49};
50use datafusion_expr::function::StateFieldsArgs;
51use datafusion_expr::{
52    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
53    Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
54};
55use datafusion_expr::{
56    Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition,
57};
58use datafusion_functions_aggregate_common::aggregate::count_distinct::{
59    BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
60    PrimitiveDistinctCountAccumulator,
61};
62use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
63use datafusion_physical_expr_common::binary_map::OutputType;
64
65use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
// Invokes the crate's UDAF helper macro for `Count`. The `count(..)`
// expression builder and the `count_udaf()` accessor used further down in this
// file are presumably generated here (the macro definition is not visible in
// this file -- confirm against the macro). The string literal is the
// description passed to the macro.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
73
74pub fn count_distinct(expr: Expr) -> Expr {
75    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
76        count_udaf(),
77        vec![expr],
78        true,
79        None,
80        None,
81        None,
82    ))
83}
84
85/// Creates aggregation to count all rows.
86///
87/// In SQL this is `SELECT COUNT(*) ... `
88///
89/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
90/// aliased to a column named `"count(*)"` for backward compatibility.
91///
92/// Example
93/// ```
94/// # use datafusion_functions_aggregate::count::count_all;
95/// # use datafusion_expr::col;
96/// // create `count(*)` expression
97/// let expr = count_all();
98/// assert_eq!(expr.schema_name().to_string(), "count(*)");
99/// // if you need to refer to this column, use the `schema_name` function
100/// let expr = col(expr.schema_name().to_string());
101/// ```
102pub fn count_all() -> Expr {
103    count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)")
104}
105
106/// Creates window aggregation to count all rows.
107///
108/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
109///
110/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
111///
112/// Example
113/// ```
114/// # use datafusion_functions_aggregate::count::count_all_window;
115/// # use datafusion_expr::col;
116/// // create `count(*)` OVER ... window function expression
117/// let expr = count_all_window();
118/// assert_eq!(
119///   expr.schema_name().to_string(),
120///   "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
121/// );
122/// // if you need to refer to this column, use the `schema_name` function
123/// let expr = col(expr.schema_name().to_string());
124/// ```
125pub fn count_all_window() -> Expr {
126    Expr::WindowFunction(WindowFunction::new(
127        WindowFunctionDefinition::AggregateUDF(count_udaf()),
128        vec![Expr::Literal(COUNT_STAR_EXPANSION)],
129    ))
130}
131
/// The `count` aggregate function.
///
/// Implements `AggregateUDFImpl` (see the impl later in this file); the
/// user-facing documentation is supplied through the `user_doc` attribute
/// below and returned from `documentation()` via `self.doc()`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name)     |
+-----------------------+
| 100                   |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*)         |
+------------------+
| 120              |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct Count {
    // Accepted argument signature; populated in `Count::new`.
    signature: Signature,
}
156
157impl Debug for Count {
158    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
159        f.debug_struct("Count")
160            .field("name", &self.name())
161            .field("signature", &self.signature)
162            .finish()
163    }
164}
165
166impl Default for Count {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172impl Count {
173    pub fn new() -> Self {
174        Self {
175            signature: Signature::one_of(
176                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
177                Volatility::Immutable,
178            ),
179        }
180    }
181}
182
183impl AggregateUDFImpl for Count {
184    fn as_any(&self) -> &dyn std::any::Any {
185        self
186    }
187
188    fn name(&self) -> &str {
189        "count"
190    }
191
192    fn signature(&self) -> &Signature {
193        &self.signature
194    }
195
196    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
197        Ok(DataType::Int64)
198    }
199
200    fn is_nullable(&self) -> bool {
201        false
202    }
203
204    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
205        if args.is_distinct {
206            Ok(vec![Field::new_list(
207                format_state_name(args.name, "count distinct"),
208                // See COMMENTS.md to understand why nullable is set to true
209                Field::new_list_field(args.input_types[0].clone(), true),
210                false,
211            )])
212        } else {
213            Ok(vec![Field::new(
214                format_state_name(args.name, "count"),
215                DataType::Int64,
216                false,
217            )])
218        }
219    }
220
221    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
222        if !acc_args.is_distinct {
223            return Ok(Box::new(CountAccumulator::new()));
224        }
225
226        if acc_args.exprs.len() > 1 {
227            return not_impl_err!("COUNT DISTINCT with multiple arguments");
228        }
229
230        let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
231        Ok(match data_type {
232            // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
233            DataType::Int8 => Box::new(
234                PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
235            ),
236            DataType::Int16 => Box::new(
237                PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
238            ),
239            DataType::Int32 => Box::new(
240                PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
241            ),
242            DataType::Int64 => Box::new(
243                PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
244            ),
245            DataType::UInt8 => Box::new(
246                PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
247            ),
248            DataType::UInt16 => Box::new(
249                PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
250            ),
251            DataType::UInt32 => Box::new(
252                PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
253            ),
254            DataType::UInt64 => Box::new(
255                PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
256            ),
257            DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
258                Decimal128Type,
259            >::new(data_type)),
260            DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
261                Decimal256Type,
262            >::new(data_type)),
263
264            DataType::Date32 => Box::new(
265                PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
266            ),
267            DataType::Date64 => Box::new(
268                PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
269            ),
270            DataType::Time32(TimeUnit::Millisecond) => Box::new(
271                PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
272                    data_type,
273                ),
274            ),
275            DataType::Time32(TimeUnit::Second) => Box::new(
276                PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
277            ),
278            DataType::Time64(TimeUnit::Microsecond) => Box::new(
279                PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
280                    data_type,
281                ),
282            ),
283            DataType::Time64(TimeUnit::Nanosecond) => Box::new(
284                PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
285            ),
286            DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
287                PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
288                    data_type,
289                ),
290            ),
291            DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
292                PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
293                    data_type,
294                ),
295            ),
296            DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
297                PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
298                    data_type,
299                ),
300            ),
301            DataType::Timestamp(TimeUnit::Second, _) => Box::new(
302                PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
303            ),
304
305            DataType::Float16 => {
306                Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
307            }
308            DataType::Float32 => {
309                Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
310            }
311            DataType::Float64 => {
312                Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
313            }
314
315            DataType::Utf8 => {
316                Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
317            }
318            DataType::Utf8View => {
319                Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
320            }
321            DataType::LargeUtf8 => {
322                Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
323            }
324            DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
325                OutputType::Binary,
326            )),
327            DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
328                OutputType::BinaryView,
329            )),
330            DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
331                OutputType::Binary,
332            )),
333
334            // Use the generic accumulator based on `ScalarValue` for all other types
335            _ => Box::new(DistinctCountAccumulator {
336                values: HashSet::default(),
337                state_data_type: data_type.clone(),
338            }),
339        })
340    }
341
342    fn aliases(&self) -> &[String] {
343        &[]
344    }
345
346    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
347        // groups accumulator only supports `COUNT(c1)`, not
348        // `COUNT(c1, c2)`, etc
349        if args.is_distinct {
350            return false;
351        }
352        args.exprs.len() == 1
353    }
354
355    fn create_groups_accumulator(
356        &self,
357        _args: AccumulatorArgs,
358    ) -> Result<Box<dyn GroupsAccumulator>> {
359        // instantiate specialized accumulator
360        Ok(Box::new(CountGroupsAccumulator::new()))
361    }
362
363    fn reverse_expr(&self) -> ReversedUDAF {
364        ReversedUDAF::Identical
365    }
366
367    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
368        Ok(ScalarValue::Int64(Some(0)))
369    }
370
371    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
372        if statistics_args.is_distinct {
373            return None;
374        }
375        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
376            if statistics_args.exprs.len() == 1 {
377                // TODO optimize with exprs other than Column
378                if let Some(col_expr) = statistics_args.exprs[0]
379                    .as_any()
380                    .downcast_ref::<expressions::Column>()
381                {
382                    let current_val = &statistics_args.statistics.column_statistics
383                        [col_expr.index()]
384                    .null_count;
385                    if let &Precision::Exact(val) = current_val {
386                        return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
387                    }
388                } else if let Some(lit_expr) = statistics_args.exprs[0]
389                    .as_any()
390                    .downcast_ref::<expressions::Literal>()
391                {
392                    if lit_expr.value() == &COUNT_STAR_EXPANSION {
393                        return Some(ScalarValue::Int64(Some(num_rows as i64)));
394                    }
395                }
396            }
397        }
398        None
399    }
400
401    fn documentation(&self) -> Option<&Documentation> {
402        self.doc()
403    }
404
405    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
406        // `COUNT` is monotonically increasing as it always increases or stays
407        // the same as new values are seen.
408        SetMonotonicity::Increasing
409    }
410}
411
/// Accumulator for non-distinct `COUNT`: keeps a single running total.
#[derive(Debug)]
struct CountAccumulator {
    /// Running total, adjusted by the batch update / retract / merge methods.
    count: i64,
}

impl CountAccumulator {
    /// Creates an accumulator whose total starts at zero.
    pub fn new() -> Self {
        let count = 0;
        Self { count }
    }
}
423
424impl Accumulator for CountAccumulator {
425    fn state(&mut self) -> Result<Vec<ScalarValue>> {
426        Ok(vec![ScalarValue::Int64(Some(self.count))])
427    }
428
429    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
430        let array = &values[0];
431        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
432        Ok(())
433    }
434
435    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
436        let array = &values[0];
437        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
438        Ok(())
439    }
440
441    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
442        let counts = downcast_value!(states[0], Int64Array);
443        let delta = &compute::sum(counts);
444        if let Some(d) = delta {
445            self.count += *d;
446        }
447        Ok(())
448    }
449
450    fn evaluate(&mut self) -> Result<ScalarValue> {
451        Ok(ScalarValue::Int64(Some(self.count)))
452    }
453
454    fn supports_retract_batch(&self) -> bool {
455        true
456    }
457
458    fn size(&self) -> usize {
459        size_of_val(self)
460    }
461}
462
/// Per-group accumulator for non-distinct `COUNT`.
///
/// Keeps one `i64` counter per group index.
///
/// Unlike most other accumulators, COUNT never produces NULLs. If no
/// non-null values are seen in any group the output is 0. Thus, this
/// accumulator has no additional null or seen filter tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index.
    ///
    /// Note this is an i64 and not a u64 (or usize) because the
    /// output type of count is `DataType::Int64`. Thus by using `i64`
    /// for the counts, the output [`Int64Array`] can be created
    /// without copy.
    counts: Vec<i64>,
}

impl CountGroupsAccumulator {
    /// Creates an accumulator that has not seen any group yet.
    pub fn new() -> Self {
        Self { counts: Vec::new() }
    }
}
485
486impl GroupsAccumulator for CountGroupsAccumulator {
487    fn update_batch(
488        &mut self,
489        values: &[ArrayRef],
490        group_indices: &[usize],
491        opt_filter: Option<&BooleanArray>,
492        total_num_groups: usize,
493    ) -> Result<()> {
494        assert_eq!(values.len(), 1, "single argument to update_batch");
495        let values = &values[0];
496
497        // Add one to each group's counter for each non null, non
498        // filtered value
499        self.counts.resize(total_num_groups, 0);
500        accumulate_indices(
501            group_indices,
502            values.logical_nulls().as_ref(),
503            opt_filter,
504            |group_index| {
505                self.counts[group_index] += 1;
506            },
507        );
508
509        Ok(())
510    }
511
512    fn merge_batch(
513        &mut self,
514        values: &[ArrayRef],
515        group_indices: &[usize],
516        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
517        _opt_filter: Option<&BooleanArray>,
518        total_num_groups: usize,
519    ) -> Result<()> {
520        assert_eq!(values.len(), 1, "one argument to merge_batch");
521        // first batch is counts, second is partial sums
522        let partial_counts = values[0].as_primitive::<Int64Type>();
523
524        // intermediate counts are always created as non null
525        assert_eq!(partial_counts.null_count(), 0);
526        let partial_counts = partial_counts.values();
527
528        // Adds the counts with the partial counts
529        self.counts.resize(total_num_groups, 0);
530        group_indices.iter().zip(partial_counts.iter()).for_each(
531            |(&group_index, partial_count)| {
532                self.counts[group_index] += partial_count;
533            },
534        );
535
536        Ok(())
537    }
538
539    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
540        let counts = emit_to.take_needed(&mut self.counts);
541
542        // Count is always non null (null inputs just don't contribute to the overall values)
543        let nulls = None;
544        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
545
546        Ok(Arc::new(array))
547    }
548
549    // return arrays for counts
550    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
551        let counts = emit_to.take_needed(&mut self.counts);
552        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
553        Ok(vec![Arc::new(counts) as ArrayRef])
554    }
555
556    /// Converts an input batch directly to a state batch
557    ///
558    /// The state of `COUNT` is always a single Int64Array:
559    /// * `1` (for non-null, non filtered values)
560    /// * `0` (for null values)
561    fn convert_to_state(
562        &self,
563        values: &[ArrayRef],
564        opt_filter: Option<&BooleanArray>,
565    ) -> Result<Vec<ArrayRef>> {
566        let values = &values[0];
567
568        let state_array = match (values.logical_nulls(), opt_filter) {
569            (None, None) => {
570                // In case there is no nulls in input and no filter, returning array of 1
571                Arc::new(Int64Array::from_value(1, values.len()))
572            }
573            (Some(nulls), None) => {
574                // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
575                // of input array to Int64
576                let nulls = BooleanArray::new(nulls.into_inner(), None);
577                compute::cast(&nulls, &DataType::Int64)?
578            }
579            (None, Some(filter)) => {
580                // If there is only filter
581                // - applying filter null mask to filter values by bitand filter values and nulls buffers
582                //   (using buffers guarantees absence of nulls in result)
583                // - casting result of bitand to Int64 array
584                let (filter_values, filter_nulls) = filter.clone().into_parts();
585
586                let state_buf = match filter_nulls {
587                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
588                    None => filter_values,
589                };
590
591                let boolean_state = BooleanArray::new(state_buf, None);
592                compute::cast(&boolean_state, &DataType::Int64)?
593            }
594            (Some(nulls), Some(filter)) => {
595                // For both input nulls and filter
596                // - applying filter null mask to filter values by bitand filter values and nulls buffers
597                //   (using buffers guarantees absence of nulls in result)
598                // - applying values null mask to filter buffer by another bitand on filter result and
599                //   nulls from input values
600                // - casting result to Int64 array
601                let (filter_values, filter_nulls) = filter.clone().into_parts();
602
603                let filter_buf = match filter_nulls {
604                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
605                    None => filter_values,
606                };
607                let state_buf = &filter_buf & nulls.inner();
608
609                let boolean_state = BooleanArray::new(state_buf, None);
610                compute::cast(&boolean_state, &DataType::Int64)?
611            }
612        };
613
614        Ok(vec![state_array])
615    }
616
617    fn supports_convert_to_state(&self) -> bool {
618        true
619    }
620
621    fn size(&self) -> usize {
622        self.counts.capacity() * size_of::<usize>()
623    }
624}
625
626/// count null values for multiple columns
627/// for each row if one column value is null, then null_count + 1
628fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
629    if values.len() > 1 {
630        let result_bool_buf: Option<BooleanBuffer> = values
631            .iter()
632            .map(|a| a.logical_nulls())
633            .fold(None, |acc, b| match (acc, b) {
634                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
635                (Some(acc), None) => Some(acc),
636                (None, Some(b)) => Some(b.into_inner()),
637                _ => None,
638            });
639        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
640    } else {
641        values[0]
642            .logical_nulls()
643            .map_or(0, |nulls| nulls.null_count())
644    }
645}
646
647/// General purpose distinct accumulator that works for any DataType by using
648/// [`ScalarValue`].
649///
650/// It stores intermediate results as a `ListArray`
651///
652/// Note that many types have specialized accumulators that are (much)
653/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
654/// [`BytesDistinctCountAccumulator`]
655#[derive(Debug)]
656struct DistinctCountAccumulator {
657    values: HashSet<ScalarValue, RandomState>,
658    state_data_type: DataType,
659}
660
661impl DistinctCountAccumulator {
662    // calculating the size for fixed length values, taking first batch size *
663    // number of batches This method is faster than .full_size(), however it is
664    // not suitable for variable length values like strings or complex types
665    fn fixed_size(&self) -> usize {
666        size_of_val(self)
667            + (size_of::<ScalarValue>() * self.values.capacity())
668            + self
669                .values
670                .iter()
671                .next()
672                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
673                .unwrap_or(0)
674            + size_of::<DataType>()
675    }
676
677    // calculates the size as accurately as possible. Note that calling this
678    // method is expensive
679    fn full_size(&self) -> usize {
680        size_of_val(self)
681            + (size_of::<ScalarValue>() * self.values.capacity())
682            + self
683                .values
684                .iter()
685                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
686                .sum::<usize>()
687            + size_of::<DataType>()
688    }
689}
690
691impl Accumulator for DistinctCountAccumulator {
692    /// Returns the distinct values seen so far as (one element) ListArray.
693    fn state(&mut self) -> Result<Vec<ScalarValue>> {
694        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
695        let arr =
696            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
697        Ok(vec![ScalarValue::List(arr)])
698    }
699
700    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
701        if values.is_empty() {
702            return Ok(());
703        }
704
705        let arr = &values[0];
706        if arr.data_type() == &DataType::Null {
707            return Ok(());
708        }
709
710        (0..arr.len()).try_for_each(|index| {
711            if !arr.is_null(index) {
712                let scalar = ScalarValue::try_from_array(arr, index)?;
713                self.values.insert(scalar);
714            }
715            Ok(())
716        })
717    }
718
719    /// Merges multiple sets of distinct values into the current set.
720    ///
721    /// The input to this function is a `ListArray` with **multiple** rows,
722    /// where each row contains the values from a partial aggregate's phase (e.g.
723    /// the result of calling `Self::state` on multiple accumulators).
724    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
725        if states.is_empty() {
726            return Ok(());
727        }
728        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
729        let array = &states[0];
730        let list_array = array.as_list::<i32>();
731        for inner_array in list_array.iter() {
732            let Some(inner_array) = inner_array else {
733                return internal_err!(
734                    "Intermediate results of COUNT DISTINCT should always be non null"
735                );
736            };
737            self.update_batch(&[inner_array])?;
738        }
739        Ok(())
740    }
741
742    fn evaluate(&mut self) -> Result<ScalarValue> {
743        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
744    }
745
746    fn size(&self) -> usize {
747        match &self.state_data_type {
748            DataType::Boolean | DataType::Null => self.fixed_size(),
749            d if d.is_primitive() => self.fixed_size(),
750            _ => self.full_size(),
751        }
752    }
753}
754
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::NullArray;

    /// A batch made up only of nulls must not change the count.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let all_nulls: ArrayRef = Arc::new(NullArray::new(10));

        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[all_nulls])?;

        let counted = accumulator.evaluate()?;
        assert_eq!(counted, ScalarValue::Int64(Some(0)));
        Ok(())
    }
}