datafusion_functions_aggregate/
median.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::cmp::Ordering;
19use std::fmt::{Debug, Formatter};
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24    downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
25    PrimitiveBuilder,
26};
27use arrow::buffer::{OffsetBuffer, ScalarBuffer};
28use arrow::{
29    array::{ArrayRef, AsArray},
30    datatypes::{
31        DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
32        Float64Type,
33    },
34};
35
36use arrow::array::Array;
37use arrow::array::ArrowNativeTypeOp;
38use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};
39
40use datafusion_common::{
41    internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
42};
43use datafusion_expr::function::StateFieldsArgs;
44use datafusion_expr::{
45    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
46    Documentation, Signature, Volatility,
47};
48use datafusion_expr::{EmitTo, GroupsAccumulator};
49use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
50use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
51use datafusion_functions_aggregate_common::utils::Hashable;
52use datafusion_macros::user_doc;
53
// Declares the `median` expression builder and the `median_udaf` accessor for the
// `Median` UDAF. The macro itself is defined elsewhere in this crate; presumably the
// builder wraps `expression` in a call to this aggregate (verify at the macro site).
make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);
61
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the median value in the specified column.",
    syntax_example = "median(expression)",
    sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name)   |
+----------------------+
| 45.5                 |
+----------------------+
```"#,
    standard_argument(name = "expression", prefix = "The")
)]
/// MEDIAN aggregate expression. If using the non-distinct variation, then this uses a
/// lot of memory because all values need to be stored in memory before a result can be
/// computed. If an approximation is sufficient then APPROX_MEDIAN provides a much more
/// efficient solution.
///
/// If using the distinct variation, the memory usage will be similarly high if the
/// cardinality is high as it stores all distinct values in memory before computing the
/// result, but if cardinality is low then memory usage will also be lower.
///
/// The result has the same data type as the input. For an even number of values the
/// two middle values are averaged in that type, so integer and decimal medians are
/// truncated (e.g. the median of `1, 2` over an integer column is `1`).
pub struct Median {
    /// Accepts exactly one numeric argument; see [`Median::new`].
    signature: Signature,
}
87
88impl Debug for Median {
89    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
90        f.debug_struct("Median")
91            .field("name", &self.name())
92            .field("signature", &self.signature)
93            .finish()
94    }
95}
96
97impl Default for Median {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl Median {
104    pub fn new() -> Self {
105        Self {
106            signature: Signature::numeric(1, Volatility::Immutable),
107        }
108    }
109}
110
111impl AggregateUDFImpl for Median {
112    fn as_any(&self) -> &dyn std::any::Any {
113        self
114    }
115
116    fn name(&self) -> &str {
117        "median"
118    }
119
120    fn signature(&self) -> &Signature {
121        &self.signature
122    }
123
124    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
125        Ok(arg_types[0].clone())
126    }
127
128    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
129        //Intermediate state is a list of the elements we have collected so far
130        let field = Field::new_list_field(args.input_types[0].clone(), true);
131        let state_name = if args.is_distinct {
132            "distinct_median"
133        } else {
134            "median"
135        };
136
137        Ok(vec![Field::new(
138            format_state_name(args.name, state_name),
139            DataType::List(Arc::new(field)),
140            true,
141        )])
142    }
143
144    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
145        macro_rules! helper {
146            ($t:ty, $dt:expr) => {
147                if acc_args.is_distinct {
148                    Ok(Box::new(DistinctMedianAccumulator::<$t> {
149                        data_type: $dt.clone(),
150                        distinct_values: HashSet::new(),
151                    }))
152                } else {
153                    Ok(Box::new(MedianAccumulator::<$t> {
154                        data_type: $dt.clone(),
155                        all_values: vec![],
156                    }))
157                }
158            };
159        }
160
161        let dt = acc_args.exprs[0].data_type(acc_args.schema)?;
162        downcast_integer! {
163            dt => (helper, dt),
164            DataType::Float16 => helper!(Float16Type, dt),
165            DataType::Float32 => helper!(Float32Type, dt),
166            DataType::Float64 => helper!(Float64Type, dt),
167            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
168            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
169            _ => Err(DataFusionError::NotImplemented(format!(
170                "MedianAccumulator not supported for {} with {}",
171                acc_args.name,
172                dt,
173            ))),
174        }
175    }
176
177    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
178        !args.is_distinct
179    }
180
181    fn create_groups_accumulator(
182        &self,
183        args: AccumulatorArgs,
184    ) -> Result<Box<dyn GroupsAccumulator>> {
185        let num_args = args.exprs.len();
186        if num_args != 1 {
187            return internal_err!(
188                "median should only have 1 arg, but found num args:{}",
189                args.exprs.len()
190            );
191        }
192
193        let dt = args.exprs[0].data_type(args.schema)?;
194
195        macro_rules! helper {
196            ($t:ty, $dt:expr) => {
197                Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
198            };
199        }
200
201        downcast_integer! {
202            dt => (helper, dt),
203            DataType::Float16 => helper!(Float16Type, dt),
204            DataType::Float32 => helper!(Float32Type, dt),
205            DataType::Float64 => helper!(Float64Type, dt),
206            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
207            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
208            _ => Err(DataFusionError::NotImplemented(format!(
209                "MedianGroupsAccumulator not supported for {} with {}",
210                args.name,
211                dt,
212            ))),
213        }
214    }
215
216    fn aliases(&self) -> &[String] {
217        &[]
218    }
219
220    fn documentation(&self) -> Option<&Documentation> {
221        self.doc()
222    }
223}
224
225/// The median accumulator accumulates the raw input values
226/// as `ScalarValue`s
227///
228/// The intermediate state is represented as a List of scalar values updated by
229/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
230/// in the final evaluation step so that we avoid expensive conversions and
231/// allocations during `update_batch`.
232struct MedianAccumulator<T: ArrowNumericType> {
233    data_type: DataType,
234    all_values: Vec<T::Native>,
235}
236
237impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
238    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
239        write!(f, "MedianAccumulator({})", self.data_type)
240    }
241}
242
243impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
244    fn state(&mut self) -> Result<Vec<ScalarValue>> {
245        // Convert `all_values` to `ListArray` and return a single List ScalarValue
246
247        // Build offsets
248        let offsets =
249            OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
250
251        // Build inner array
252        let values_array = PrimitiveArray::<T>::new(
253            ScalarBuffer::from(std::mem::take(&mut self.all_values)),
254            None,
255        )
256        .with_data_type(self.data_type.clone());
257
258        // Build the result list array
259        let list_array = ListArray::new(
260            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
261            offsets,
262            Arc::new(values_array),
263            None,
264        );
265
266        Ok(vec![ScalarValue::List(Arc::new(list_array))])
267    }
268
269    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
270        let values = values[0].as_primitive::<T>();
271        self.all_values.reserve(values.len() - values.null_count());
272        self.all_values.extend(values.iter().flatten());
273        Ok(())
274    }
275
276    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
277        let array = states[0].as_list::<i32>();
278        for v in array.iter().flatten() {
279            self.update_batch(&[v])?
280        }
281        Ok(())
282    }
283
284    fn evaluate(&mut self) -> Result<ScalarValue> {
285        let d = std::mem::take(&mut self.all_values);
286        let median = calculate_median::<T>(d);
287        ScalarValue::new_primitive::<T>(median, &self.data_type)
288    }
289
290    fn size(&self) -> usize {
291        size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
292    }
293}
294
295/// The median groups accumulator accumulates the raw input values
296///
297/// For calculating the accurate medians of groups, we need to store all values
298/// of groups before final evaluation.
299/// So values in each group will be stored in a `Vec<T>`, and the total group values
300/// will be actually organized as a `Vec<Vec<T>>`.
301///
302#[derive(Debug)]
303struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
304    data_type: DataType,
305    group_values: Vec<Vec<T::Native>>,
306}
307
308impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
309    pub fn new(data_type: DataType) -> Self {
310        Self {
311            data_type,
312            group_values: Vec::new(),
313        }
314    }
315}
316
317impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
318    fn update_batch(
319        &mut self,
320        values: &[ArrayRef],
321        group_indices: &[usize],
322        opt_filter: Option<&BooleanArray>,
323        total_num_groups: usize,
324    ) -> Result<()> {
325        assert_eq!(values.len(), 1, "single argument to update_batch");
326        let values = values[0].as_primitive::<T>();
327
328        // Push the `not nulls + not filtered` row into its group
329        self.group_values.resize(total_num_groups, Vec::new());
330        accumulate(
331            group_indices,
332            values,
333            opt_filter,
334            |group_index, new_value| {
335                self.group_values[group_index].push(new_value);
336            },
337        );
338
339        Ok(())
340    }
341
342    fn merge_batch(
343        &mut self,
344        values: &[ArrayRef],
345        group_indices: &[usize],
346        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
347        _opt_filter: Option<&BooleanArray>,
348        total_num_groups: usize,
349    ) -> Result<()> {
350        assert_eq!(values.len(), 1, "one argument to merge_batch");
351
352        // The merged values should be organized like as a `ListArray` which is nullable
353        // (input with nulls usually generated from `convert_to_state`), but `inner array` of
354        // `ListArray`  is `non-nullable`.
355        //
356        // Following is the possible and impossible input `values`:
357        //
358        // # Possible values
359        // ```text
360        //   group 0: [1, 2, 3]
361        //   group 1: null (list array is nullable)
362        //   group 2: [6, 7, 8]
363        //   ...
364        //   group n: [...]
365        // ```
366        //
367        // # Impossible values
368        // ```text
369        //   group x: [1, 2, null] (values in list array is non-nullable)
370        // ```
371        //
372        let input_group_values = values[0].as_list::<i32>();
373
374        // Ensure group values big enough
375        self.group_values.resize(total_num_groups, Vec::new());
376
377        // Extend values to related groups
378        // TODO: avoid using iterator of the `ListArray`, this will lead to
379        // many calls of `slice` of its ``inner array`, and `slice` is not
380        // so efficient(due to the calculation of `null_count` for each `slice`).
381        group_indices
382            .iter()
383            .zip(input_group_values.iter())
384            .for_each(|(&group_index, values_opt)| {
385                if let Some(values) = values_opt {
386                    let values = values.as_primitive::<T>();
387                    self.group_values[group_index].extend(values.values().iter());
388                }
389            });
390
391        Ok(())
392    }
393
394    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
395        // Emit values
396        let emit_group_values = emit_to.take_needed(&mut self.group_values);
397
398        // Build offsets
399        let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
400        offsets.push(0);
401        let mut cur_len = 0_i32;
402        for group_value in &emit_group_values {
403            cur_len += group_value.len() as i32;
404            offsets.push(cur_len);
405        }
406        // TODO: maybe we can use `OffsetBuffer::new_unchecked` like what in `convert_to_state`,
407        // but safety should be considered more carefully here(and I am not sure if it can get
408        // performance improvement when we introduce checks to keep the safety...).
409        //
410        // Can see more details in:
411        // https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
412        //
413        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
414
415        // Build inner array
416        let flatten_group_values =
417            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
418        let group_values_array =
419            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
420                .with_data_type(self.data_type.clone());
421
422        // Build the result list array
423        let result_list_array = ListArray::new(
424            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
425            offsets,
426            Arc::new(group_values_array),
427            None,
428        );
429
430        Ok(vec![Arc::new(result_list_array)])
431    }
432
433    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
434        // Emit values
435        let emit_group_values = emit_to.take_needed(&mut self.group_values);
436
437        // Calculate median for each group
438        let mut evaluate_result_builder =
439            PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
440        for values in emit_group_values {
441            let median = calculate_median::<T>(values);
442            evaluate_result_builder.append_option(median);
443        }
444
445        Ok(Arc::new(evaluate_result_builder.finish()))
446    }
447
448    fn convert_to_state(
449        &self,
450        values: &[ArrayRef],
451        opt_filter: Option<&BooleanArray>,
452    ) -> Result<Vec<ArrayRef>> {
453        assert_eq!(values.len(), 1, "one argument to merge_batch");
454
455        let input_array = values[0].as_primitive::<T>();
456
457        // Directly convert the input array to states, each row will be
458        // seen as a respective group.
459        // For detail, the `input_array` will be converted to a `ListArray`.
460        // And if row is `not null + not filtered`, it will be converted to a list
461        // with only one element; otherwise, this row in `ListArray` will be set
462        // to null.
463
464        // Reuse values buffer in `input_array` to build `values` in `ListArray`
465        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
466            .with_data_type(self.data_type.clone());
467
468        // `offsets` in `ListArray`, each row as a list element
469        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
470            internal_datafusion_err!(
471                "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
472            )
473        })?;
474        let offsets = (0..=offset_end).collect::<Vec<_>>();
475        // Safety: all checks in `OffsetBuffer::new` are ensured to pass
476        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
477
478        // `nulls` for converted `ListArray`
479        let nulls = filtered_null_mask(opt_filter, input_array);
480
481        let converted_list_array = ListArray::new(
482            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
483            offsets,
484            Arc::new(values),
485            nulls,
486        );
487
488        Ok(vec![Arc::new(converted_list_array)])
489    }
490
491    fn supports_convert_to_state(&self) -> bool {
492        true
493    }
494
495    fn size(&self) -> usize {
496        self.group_values
497            .iter()
498            .map(|values| values.capacity() * size_of::<T>())
499            .sum::<usize>()
500            // account for size of self.grou_values too
501            + self.group_values.capacity() * size_of::<Vec<T>>()
502    }
503}
504
505/// The distinct median accumulator accumulates the raw input values
506/// as `ScalarValue`s
507///
508/// The intermediate state is represented as a List of scalar values updated by
509/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
510/// in the final evaluation step so that we avoid expensive conversions and
511/// allocations during `update_batch`.
512struct DistinctMedianAccumulator<T: ArrowNumericType> {
513    data_type: DataType,
514    distinct_values: HashSet<Hashable<T::Native>>,
515}
516
517impl<T: ArrowNumericType> Debug for DistinctMedianAccumulator<T> {
518    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
519        write!(f, "DistinctMedianAccumulator({})", self.data_type)
520    }
521}
522
523impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
524    fn state(&mut self) -> Result<Vec<ScalarValue>> {
525        let all_values = self
526            .distinct_values
527            .iter()
528            .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
529            .collect::<Result<Vec<_>>>()?;
530
531        let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
532        Ok(vec![ScalarValue::List(arr)])
533    }
534
535    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
536        if values.is_empty() {
537            return Ok(());
538        }
539
540        let array = values[0].as_primitive::<T>();
541        match array.nulls().filter(|x| x.null_count() > 0) {
542            Some(n) => {
543                for idx in n.valid_indices() {
544                    self.distinct_values.insert(Hashable(array.value(idx)));
545                }
546            }
547            None => array.values().iter().for_each(|x| {
548                self.distinct_values.insert(Hashable(*x));
549            }),
550        }
551        Ok(())
552    }
553
554    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
555        let array = states[0].as_list::<i32>();
556        for v in array.iter().flatten() {
557            self.update_batch(&[v])?
558        }
559        Ok(())
560    }
561
562    fn evaluate(&mut self) -> Result<ScalarValue> {
563        let d = std::mem::take(&mut self.distinct_values)
564            .into_iter()
565            .map(|v| v.0)
566            .collect::<Vec<_>>();
567        let median = calculate_median::<T>(d);
568        ScalarValue::new_primitive::<T>(median, &self.data_type)
569    }
570
571    fn size(&self) -> usize {
572        size_of_val(self) + self.distinct_values.capacity() * size_of::<T::Native>()
573    }
574}
575
576/// Get maximum entry in the slice,
577fn slice_max<T>(array: &[T::Native]) -> T::Native
578where
579    T: ArrowPrimitiveType,
580    T::Native: PartialOrd, // Ensure the type supports PartialOrd for comparison
581{
582    // Make sure that, array is not empty.
583    debug_assert!(!array.is_empty());
584    // `.unwrap()` is safe here as the array is supposed to be non-empty
585    *array
586        .iter()
587        .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
588        .unwrap()
589}
590
591fn calculate_median<T: ArrowNumericType>(
592    mut values: Vec<T::Native>,
593) -> Option<T::Native> {
594    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
595
596    let len = values.len();
597    if len == 0 {
598        None
599    } else if len % 2 == 0 {
600        let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
601        // Get the maximum of the low (left side after bi-partitioning)
602        let left_max = slice_max::<T>(low);
603        let median = left_max
604            .add_wrapping(*high)
605            .div_wrapping(T::Native::usize_as(2));
606        Some(median)
607    } else {
608        let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
609        Some(*median)
610    }
611}