1use ahash::RandomState;
19use datafusion_common::stats::Precision;
20use datafusion_expr::expr::WindowFunction;
21use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
22use datafusion_macros::user_doc;
23use datafusion_physical_expr::expressions;
24use std::collections::HashSet;
25use std::fmt::Debug;
26use std::mem::{size_of, size_of_val};
27use std::ops::BitAnd;
28use std::sync::Arc;
29
30use arrow::{
31 array::{ArrayRef, AsArray},
32 compute,
33 datatypes::{
34 DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
35 Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
36 Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
37 Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
38 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
39 UInt16Type, UInt32Type, UInt64Type, UInt8Type,
40 },
41};
42
43use arrow::{
44 array::{Array, BooleanArray, Int64Array, PrimitiveArray},
45 buffer::BooleanBuffer,
46};
47use datafusion_common::{
48 downcast_value, internal_err, not_impl_err, Result, ScalarValue,
49};
50use datafusion_expr::function::StateFieldsArgs;
51use datafusion_expr::{
52 function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
53 Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
54};
55use datafusion_expr::{
56 Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition,
57};
58use datafusion_functions_aggregate_common::aggregate::count_distinct::{
59 BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
60 PrimitiveDistinctCountAccumulator,
61};
62use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
63use datafusion_physical_expr_common::binary_map::OutputType;
64
65use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
// Declares the `count(expr)` expression helper and the `count_udaf()` accessor
// for the `Count` UDAF; both names are used by the builder functions below.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
73
74pub fn count_distinct(expr: Expr) -> Expr {
75 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
76 count_udaf(),
77 vec![expr],
78 true,
79 None,
80 None,
81 None,
82 ))
83}
84
85pub fn count_all() -> Expr {
103 count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)")
104}
105
106pub fn count_all_window() -> Expr {
126 Expr::WindowFunction(WindowFunction::new(
127 WindowFunctionDefinition::AggregateUDF(count_udaf()),
128 vec![Expr::Literal(COUNT_STAR_EXPANSION)],
129 ))
130}
131
/// The `count` aggregate UDF.
///
/// Counts the non-null values of its argument. `count(*)` is expressed by
/// counting the `COUNT_STAR_EXPANSION` literal (see `count_all`).
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name) |
+-----------------------+
| 100 |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*) |
+------------------+
| 120 |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct Count {
    // Accepts any number of arguments or none at all (built in `Count::new`).
    signature: Signature,
}
156
157impl Debug for Count {
158 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
159 f.debug_struct("Count")
160 .field("name", &self.name())
161 .field("signature", &self.signature)
162 .finish()
163 }
164}
165
166impl Default for Count {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172impl Count {
173 pub fn new() -> Self {
174 Self {
175 signature: Signature::one_of(
176 vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
177 Volatility::Immutable,
178 ),
179 }
180 }
181}
182
183impl AggregateUDFImpl for Count {
184 fn as_any(&self) -> &dyn std::any::Any {
185 self
186 }
187
188 fn name(&self) -> &str {
189 "count"
190 }
191
192 fn signature(&self) -> &Signature {
193 &self.signature
194 }
195
196 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
197 Ok(DataType::Int64)
198 }
199
200 fn is_nullable(&self) -> bool {
201 false
202 }
203
204 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
205 if args.is_distinct {
206 Ok(vec![Field::new_list(
207 format_state_name(args.name, "count distinct"),
208 Field::new_list_field(args.input_types[0].clone(), true),
210 false,
211 )])
212 } else {
213 Ok(vec![Field::new(
214 format_state_name(args.name, "count"),
215 DataType::Int64,
216 false,
217 )])
218 }
219 }
220
221 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
222 if !acc_args.is_distinct {
223 return Ok(Box::new(CountAccumulator::new()));
224 }
225
226 if acc_args.exprs.len() > 1 {
227 return not_impl_err!("COUNT DISTINCT with multiple arguments");
228 }
229
230 let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
231 Ok(match data_type {
232 DataType::Int8 => Box::new(
234 PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
235 ),
236 DataType::Int16 => Box::new(
237 PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
238 ),
239 DataType::Int32 => Box::new(
240 PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
241 ),
242 DataType::Int64 => Box::new(
243 PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
244 ),
245 DataType::UInt8 => Box::new(
246 PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
247 ),
248 DataType::UInt16 => Box::new(
249 PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
250 ),
251 DataType::UInt32 => Box::new(
252 PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
253 ),
254 DataType::UInt64 => Box::new(
255 PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
256 ),
257 DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
258 Decimal128Type,
259 >::new(data_type)),
260 DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
261 Decimal256Type,
262 >::new(data_type)),
263
264 DataType::Date32 => Box::new(
265 PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
266 ),
267 DataType::Date64 => Box::new(
268 PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
269 ),
270 DataType::Time32(TimeUnit::Millisecond) => Box::new(
271 PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
272 data_type,
273 ),
274 ),
275 DataType::Time32(TimeUnit::Second) => Box::new(
276 PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
277 ),
278 DataType::Time64(TimeUnit::Microsecond) => Box::new(
279 PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
280 data_type,
281 ),
282 ),
283 DataType::Time64(TimeUnit::Nanosecond) => Box::new(
284 PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
285 ),
286 DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
287 PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
288 data_type,
289 ),
290 ),
291 DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
292 PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
293 data_type,
294 ),
295 ),
296 DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
297 PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
298 data_type,
299 ),
300 ),
301 DataType::Timestamp(TimeUnit::Second, _) => Box::new(
302 PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
303 ),
304
305 DataType::Float16 => {
306 Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
307 }
308 DataType::Float32 => {
309 Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
310 }
311 DataType::Float64 => {
312 Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
313 }
314
315 DataType::Utf8 => {
316 Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
317 }
318 DataType::Utf8View => {
319 Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
320 }
321 DataType::LargeUtf8 => {
322 Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
323 }
324 DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
325 OutputType::Binary,
326 )),
327 DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
328 OutputType::BinaryView,
329 )),
330 DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
331 OutputType::Binary,
332 )),
333
334 _ => Box::new(DistinctCountAccumulator {
336 values: HashSet::default(),
337 state_data_type: data_type.clone(),
338 }),
339 })
340 }
341
342 fn aliases(&self) -> &[String] {
343 &[]
344 }
345
346 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
347 if args.is_distinct {
350 return false;
351 }
352 args.exprs.len() == 1
353 }
354
355 fn create_groups_accumulator(
356 &self,
357 _args: AccumulatorArgs,
358 ) -> Result<Box<dyn GroupsAccumulator>> {
359 Ok(Box::new(CountGroupsAccumulator::new()))
361 }
362
363 fn reverse_expr(&self) -> ReversedUDAF {
364 ReversedUDAF::Identical
365 }
366
367 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
368 Ok(ScalarValue::Int64(Some(0)))
369 }
370
371 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
372 if statistics_args.is_distinct {
373 return None;
374 }
375 if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
376 if statistics_args.exprs.len() == 1 {
377 if let Some(col_expr) = statistics_args.exprs[0]
379 .as_any()
380 .downcast_ref::<expressions::Column>()
381 {
382 let current_val = &statistics_args.statistics.column_statistics
383 [col_expr.index()]
384 .null_count;
385 if let &Precision::Exact(val) = current_val {
386 return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
387 }
388 } else if let Some(lit_expr) = statistics_args.exprs[0]
389 .as_any()
390 .downcast_ref::<expressions::Literal>()
391 {
392 if lit_expr.value() == &COUNT_STAR_EXPANSION {
393 return Some(ScalarValue::Int64(Some(num_rows as i64)));
394 }
395 }
396 }
397 }
398 None
399 }
400
401 fn documentation(&self) -> Option<&Documentation> {
402 self.doc()
403 }
404
405 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
406 SetMonotonicity::Increasing
409 }
410}
411
/// Row accumulator for non-distinct `COUNT`.
///
/// Holds the running count. Values are added by `update_batch`/`merge_batch`
/// and subtracted by `retract_batch`.
#[derive(Debug)]
struct CountAccumulator {
    // Number of rows counted so far.
    count: i64,
}
416
417impl CountAccumulator {
418 pub fn new() -> Self {
420 Self { count: 0 }
421 }
422}
423
424impl Accumulator for CountAccumulator {
425 fn state(&mut self) -> Result<Vec<ScalarValue>> {
426 Ok(vec![ScalarValue::Int64(Some(self.count))])
427 }
428
429 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
430 let array = &values[0];
431 self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
432 Ok(())
433 }
434
435 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
436 let array = &values[0];
437 self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
438 Ok(())
439 }
440
441 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
442 let counts = downcast_value!(states[0], Int64Array);
443 let delta = &compute::sum(counts);
444 if let Some(d) = delta {
445 self.count += *d;
446 }
447 Ok(())
448 }
449
450 fn evaluate(&mut self) -> Result<ScalarValue> {
451 Ok(ScalarValue::Int64(Some(self.count)))
452 }
453
454 fn supports_retract_batch(&self) -> bool {
455 true
456 }
457
458 fn size(&self) -> usize {
459 size_of_val(self)
460 }
461}
462
/// Groups accumulator for non-distinct, single-argument `COUNT`.
///
/// `counts[g]` holds the running count for group index `g`. The vector is
/// grown to `total_num_groups` on every update/merge.
#[derive(Debug)]
struct CountGroupsAccumulator {
    // One running count per group index.
    counts: Vec<i64>,
}
479
480impl CountGroupsAccumulator {
481 pub fn new() -> Self {
482 Self { counts: vec![] }
483 }
484}
485
486impl GroupsAccumulator for CountGroupsAccumulator {
487 fn update_batch(
488 &mut self,
489 values: &[ArrayRef],
490 group_indices: &[usize],
491 opt_filter: Option<&BooleanArray>,
492 total_num_groups: usize,
493 ) -> Result<()> {
494 assert_eq!(values.len(), 1, "single argument to update_batch");
495 let values = &values[0];
496
497 self.counts.resize(total_num_groups, 0);
500 accumulate_indices(
501 group_indices,
502 values.logical_nulls().as_ref(),
503 opt_filter,
504 |group_index| {
505 self.counts[group_index] += 1;
506 },
507 );
508
509 Ok(())
510 }
511
512 fn merge_batch(
513 &mut self,
514 values: &[ArrayRef],
515 group_indices: &[usize],
516 _opt_filter: Option<&BooleanArray>,
518 total_num_groups: usize,
519 ) -> Result<()> {
520 assert_eq!(values.len(), 1, "one argument to merge_batch");
521 let partial_counts = values[0].as_primitive::<Int64Type>();
523
524 assert_eq!(partial_counts.null_count(), 0);
526 let partial_counts = partial_counts.values();
527
528 self.counts.resize(total_num_groups, 0);
530 group_indices.iter().zip(partial_counts.iter()).for_each(
531 |(&group_index, partial_count)| {
532 self.counts[group_index] += partial_count;
533 },
534 );
535
536 Ok(())
537 }
538
539 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
540 let counts = emit_to.take_needed(&mut self.counts);
541
542 let nulls = None;
544 let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
545
546 Ok(Arc::new(array))
547 }
548
549 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
551 let counts = emit_to.take_needed(&mut self.counts);
552 let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
554 }
555
556 fn convert_to_state(
562 &self,
563 values: &[ArrayRef],
564 opt_filter: Option<&BooleanArray>,
565 ) -> Result<Vec<ArrayRef>> {
566 let values = &values[0];
567
568 let state_array = match (values.logical_nulls(), opt_filter) {
569 (None, None) => {
570 Arc::new(Int64Array::from_value(1, values.len()))
572 }
573 (Some(nulls), None) => {
574 let nulls = BooleanArray::new(nulls.into_inner(), None);
577 compute::cast(&nulls, &DataType::Int64)?
578 }
579 (None, Some(filter)) => {
580 let (filter_values, filter_nulls) = filter.clone().into_parts();
585
586 let state_buf = match filter_nulls {
587 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
588 None => filter_values,
589 };
590
591 let boolean_state = BooleanArray::new(state_buf, None);
592 compute::cast(&boolean_state, &DataType::Int64)?
593 }
594 (Some(nulls), Some(filter)) => {
595 let (filter_values, filter_nulls) = filter.clone().into_parts();
602
603 let filter_buf = match filter_nulls {
604 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
605 None => filter_values,
606 };
607 let state_buf = &filter_buf & nulls.inner();
608
609 let boolean_state = BooleanArray::new(state_buf, None);
610 compute::cast(&boolean_state, &DataType::Int64)?
611 }
612 };
613
614 Ok(vec![state_array])
615 }
616
617 fn supports_convert_to_state(&self) -> bool {
618 true
619 }
620
621 fn size(&self) -> usize {
622 self.counts.capacity() * size_of::<usize>()
623 }
624}
625
626fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
629 if values.len() > 1 {
630 let result_bool_buf: Option<BooleanBuffer> = values
631 .iter()
632 .map(|a| a.logical_nulls())
633 .fold(None, |acc, b| match (acc, b) {
634 (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
635 (Some(acc), None) => Some(acc),
636 (None, Some(b)) => Some(b.into_inner()),
637 _ => None,
638 });
639 result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
640 } else {
641 values[0]
642 .logical_nulls()
643 .map_or(0, |nulls| nulls.null_count())
644 }
645}
646
/// Generic fallback accumulator for `COUNT(DISTINCT x)`.
///
/// Used for argument types without a specialized distinct accumulator. It
/// stores each distinct non-null value seen as a `ScalarValue` in a hash set.
#[derive(Debug)]
struct DistinctCountAccumulator {
    // Distinct non-null values seen so far.
    values: HashSet<ScalarValue, RandomState>,
    // Argument data type, used as the item type of the list state.
    state_data_type: DataType,
}
660
661impl DistinctCountAccumulator {
662 fn fixed_size(&self) -> usize {
666 size_of_val(self)
667 + (size_of::<ScalarValue>() * self.values.capacity())
668 + self
669 .values
670 .iter()
671 .next()
672 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
673 .unwrap_or(0)
674 + size_of::<DataType>()
675 }
676
677 fn full_size(&self) -> usize {
680 size_of_val(self)
681 + (size_of::<ScalarValue>() * self.values.capacity())
682 + self
683 .values
684 .iter()
685 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
686 .sum::<usize>()
687 + size_of::<DataType>()
688 }
689}
690
691impl Accumulator for DistinctCountAccumulator {
692 fn state(&mut self) -> Result<Vec<ScalarValue>> {
694 let scalars = self.values.iter().cloned().collect::<Vec<_>>();
695 let arr =
696 ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
697 Ok(vec![ScalarValue::List(arr)])
698 }
699
700 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
701 if values.is_empty() {
702 return Ok(());
703 }
704
705 let arr = &values[0];
706 if arr.data_type() == &DataType::Null {
707 return Ok(());
708 }
709
710 (0..arr.len()).try_for_each(|index| {
711 if !arr.is_null(index) {
712 let scalar = ScalarValue::try_from_array(arr, index)?;
713 self.values.insert(scalar);
714 }
715 Ok(())
716 })
717 }
718
719 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
725 if states.is_empty() {
726 return Ok(());
727 }
728 assert_eq!(states.len(), 1, "array_agg states must be singleton!");
729 let array = &states[0];
730 let list_array = array.as_list::<i32>();
731 for inner_array in list_array.iter() {
732 let Some(inner_array) = inner_array else {
733 return internal_err!(
734 "Intermediate results of COUNT DISTINCT should always be non null"
735 );
736 };
737 self.update_batch(&[inner_array])?;
738 }
739 Ok(())
740 }
741
742 fn evaluate(&mut self) -> Result<ScalarValue> {
743 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
744 }
745
746 fn size(&self) -> usize {
747 match &self.state_data_type {
748 DataType::Boolean | DataType::Null => self.fixed_size(),
749 d if d.is_primitive() => self.fixed_size(),
750 _ => self.full_size(),
751 }
752 }
753}
754
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, NullArray};

    /// A column that is entirely null contributes nothing to the count.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    /// With several arguments, only rows where every argument is non-null
    /// are counted.
    #[test]
    fn count_accumulator_multiple_columns() -> Result<()> {
        let a: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)]));
        let b: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[a, b])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    /// Retracting a batch undoes a previous update of the same batch.
    #[test]
    fn count_accumulator_retract() -> Result<()> {
        let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::clone(&values)])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
        accumulator.retract_batch(&[values])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    /// `convert_to_state` yields 1 only for non-null rows selected by the
    /// filter; null filter entries count as not selected.
    #[test]
    fn count_groups_convert_to_state_with_nulls_and_filter() -> Result<()> {
        let values: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(3),
            Some(4),
            Some(5),
        ]));
        let filter = BooleanArray::from(vec![
            Some(true),
            Some(true),
            None,
            Some(false),
            Some(true),
        ]);
        let accumulator = CountGroupsAccumulator::new();
        let state = accumulator.convert_to_state(&[values], Some(&filter))?;
        let counts = state[0].as_primitive::<Int64Type>();
        assert_eq!(counts.values().to_vec(), vec![1, 0, 0, 0, 1]);
        Ok(())
    }
}