datafusion_functions_aggregate/
approx_percentile_cont.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::{Array, RecordBatch};
24use arrow::compute::{filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{
27    array::{
28        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
29        Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
30    },
31    datatypes::{DataType, Field, Schema},
32};
33use datafusion_common::{
34    downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
35    Result, ScalarValue,
36};
37use datafusion_expr::expr::{AggregateFunction, Sort};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
40use datafusion_expr::utils::format_state_name;
41use datafusion_expr::{
42    Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature,
43    TypeSignature, Volatility,
44};
45use datafusion_functions_aggregate_common::tdigest::{
46    TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
47};
48use datafusion_macros::user_doc;
49use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
50
// NOTE(review): `create_func!` is defined outside this file; it presumably
// generates the `approx_percentile_cont_udaf()` accessor that
// `approx_percentile_cont` below calls — confirm against the macro definition.
create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
52
53/// Computes the approximate percentile continuous of a set of numbers
54pub fn approx_percentile_cont(
55    order_by: Sort,
56    percentile: Expr,
57    centroids: Option<Expr>,
58) -> Expr {
59    let expr = order_by.expr.clone();
60
61    let args = if let Some(centroids) = centroids {
62        vec![expr, percentile, centroids]
63    } else {
64        vec![expr, percentile]
65    };
66
67    Expr::AggregateFunction(AggregateFunction::new_udf(
68        approx_percentile_cont_udaf(),
69        args,
70        false,
71        None,
72        Some(vec![order_by]),
73        None,
74    ))
75}
76
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+-----------------------------------------------------------------------+
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
+-----------------------------------------------------------------------+
| 65.0                                                                  |
+-----------------------------------------------------------------------+
```"#,
    standard_argument(name = "expression",),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    ),
    argument(
        name = "centroids",
        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
    )
)]
/// The `approx_percentile_cont` aggregate function, backed by a [`TDigest`].
pub struct ApproxPercentileCont {
    /// Accepted argument type combinations; built in [`ApproxPercentileCont::new`].
    signature: Signature,
}
102
103impl Debug for ApproxPercentileCont {
104    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
105        f.debug_struct("ApproxPercentileCont")
106            .field("name", &self.name())
107            .field("signature", &self.signature)
108            .finish()
109    }
110}
111
impl Default for ApproxPercentileCont {
    /// Equivalent to [`ApproxPercentileCont::new`].
    fn default() -> Self {
        Self::new()
    }
}
117
118impl ApproxPercentileCont {
119    /// Create a new [`ApproxPercentileCont`] aggregate function.
120    pub fn new() -> Self {
121        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
122        // Accept any numeric value paired with a float64 percentile
123        for num in NUMERICS {
124            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
125            // Additionally accept an integer number of centroids for T-Digest
126            for int in INTEGERS {
127                variants.push(TypeSignature::Exact(vec![
128                    num.clone(),
129                    DataType::Float64,
130                    int.clone(),
131                ]))
132            }
133        }
134        Self {
135            signature: Signature::one_of(variants, Volatility::Immutable),
136        }
137    }
138
139    pub(crate) fn create_accumulator(
140        &self,
141        args: AccumulatorArgs,
142    ) -> Result<ApproxPercentileAccumulator> {
143        let percentile = validate_input_percentile_expr(&args.exprs[1])?;
144
145        let is_descending = args
146            .ordering_req
147            .first()
148            .map(|sort_expr| sort_expr.options.descending)
149            .unwrap_or(false);
150
151        let percentile = if is_descending {
152            1.0 - percentile
153        } else {
154            percentile
155        };
156
157        let tdigest_max_size = if args.exprs.len() == 3 {
158            Some(validate_input_max_size_expr(&args.exprs[2])?)
159        } else {
160            None
161        };
162
163        let data_type = args.exprs[0].data_type(args.schema)?;
164        let accumulator: ApproxPercentileAccumulator = match data_type {
165            t @ (DataType::UInt8
166            | DataType::UInt16
167            | DataType::UInt32
168            | DataType::UInt64
169            | DataType::Int8
170            | DataType::Int16
171            | DataType::Int32
172            | DataType::Int64
173            | DataType::Float32
174            | DataType::Float64) => {
175                if let Some(max_size) = tdigest_max_size {
176                    ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
177                }else{
178                    ApproxPercentileAccumulator::new(percentile, t)
179
180                }
181            }
182            other => {
183                return not_impl_err!(
184                    "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
185                )
186            }
187        };
188
189        Ok(accumulator)
190    }
191}
192
193fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
194    let empty_schema = Arc::new(Schema::empty());
195    let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
196    if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
197        Ok(s)
198    } else {
199        internal_err!("Didn't expect ColumnarValue::Array")
200    }
201}
202
203fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
204    let percentile = match get_scalar_value(expr)
205        .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
206        ScalarValue::Float32(Some(value)) => {
207            value as f64
208        }
209        ScalarValue::Float64(Some(value)) => {
210            value
211        }
212        sv => {
213            return not_impl_err!(
214                "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
215                sv.data_type()
216            )
217        }
218    };
219
220    // Ensure the percentile is between 0 and 1.
221    if !(0.0..=1.0).contains(&percentile) {
222        return plan_err!(
223            "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
224        );
225    }
226    Ok(percentile)
227}
228
229fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
230    let max_size = match get_scalar_value(expr)
231        .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
232        ScalarValue::UInt8(Some(q)) => q as usize,
233        ScalarValue::UInt16(Some(q)) => q as usize,
234        ScalarValue::UInt32(Some(q)) => q as usize,
235        ScalarValue::UInt64(Some(q)) => q as usize,
236        ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
237        ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
238        ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
239        ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
240        sv => {
241            return not_impl_err!(
242                "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
243                sv.data_type()
244            )
245        },
246    };
247
248    Ok(max_size)
249}
250
251impl AggregateUDFImpl for ApproxPercentileCont {
252    fn as_any(&self) -> &dyn Any {
253        self
254    }
255
256    #[allow(rustdoc::private_intra_doc_links)]
257    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
258    /// state.
259    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
260        Ok(vec![
261            Field::new(
262                format_state_name(args.name, "max_size"),
263                DataType::UInt64,
264                false,
265            ),
266            Field::new(
267                format_state_name(args.name, "sum"),
268                DataType::Float64,
269                false,
270            ),
271            Field::new(
272                format_state_name(args.name, "count"),
273                DataType::UInt64,
274                false,
275            ),
276            Field::new(
277                format_state_name(args.name, "max"),
278                DataType::Float64,
279                false,
280            ),
281            Field::new(
282                format_state_name(args.name, "min"),
283                DataType::Float64,
284                false,
285            ),
286            Field::new_list(
287                format_state_name(args.name, "centroids"),
288                Field::new_list_field(DataType::Float64, true),
289                false,
290            ),
291        ]
292        .into_iter()
293        .map(Arc::new)
294        .collect())
295    }
296
297    fn name(&self) -> &str {
298        "approx_percentile_cont"
299    }
300
301    fn signature(&self) -> &Signature {
302        &self.signature
303    }
304
305    #[inline]
306    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
307        Ok(Box::new(self.create_accumulator(acc_args)?))
308    }
309
310    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
311        if !arg_types[0].is_numeric() {
312            return plan_err!("approx_percentile_cont requires numeric input types");
313        }
314        if arg_types.len() == 3 && !arg_types[2].is_integer() {
315            return plan_err!(
316                "approx_percentile_cont requires integer max_size input types"
317            );
318        }
319        Ok(arg_types[0].clone())
320    }
321
322    fn supports_null_handling_clause(&self) -> bool {
323        false
324    }
325
326    fn is_ordered_set_aggregate(&self) -> bool {
327        true
328    }
329
330    fn documentation(&self) -> Option<&Documentation> {
331        self.doc()
332    }
333}
334
/// Accumulator that feeds input values into a [`TDigest`] and estimates a
/// single quantile from it.
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
    /// Digest holding the summary of all values seen so far.
    digest: TDigest,
    /// Quantile passed to `TDigest::estimate_quantile` when evaluating.
    percentile: f64,
    /// Data type of the scalar produced by `evaluate`.
    return_type: DataType,
}
341
342impl ApproxPercentileAccumulator {
343    pub fn new(percentile: f64, return_type: DataType) -> Self {
344        Self {
345            digest: TDigest::new(DEFAULT_MAX_SIZE),
346            percentile,
347            return_type,
348        }
349    }
350
351    pub fn new_with_max_size(
352        percentile: f64,
353        return_type: DataType,
354        max_size: usize,
355    ) -> Self {
356        Self {
357            digest: TDigest::new(max_size),
358            percentile,
359            return_type,
360        }
361    }
362
363    // public for approx_percentile_cont_with_weight
364    pub fn merge_digests(&mut self, digests: &[TDigest]) {
365        let digests = digests.iter().chain(std::iter::once(&self.digest));
366        self.digest = TDigest::merge_digests(digests)
367    }
368
369    // public for approx_percentile_cont_with_weight
370    pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
371        match values.data_type() {
372            DataType::Float64 => {
373                let array = downcast_value!(values, Float64Array);
374                Ok(array
375                    .values()
376                    .iter()
377                    .filter_map(|v| v.try_as_f64().transpose())
378                    .collect::<Result<Vec<_>>>()?)
379            }
380            DataType::Float32 => {
381                let array = downcast_value!(values, Float32Array);
382                Ok(array
383                    .values()
384                    .iter()
385                    .filter_map(|v| v.try_as_f64().transpose())
386                    .collect::<Result<Vec<_>>>()?)
387            }
388            DataType::Int64 => {
389                let array = downcast_value!(values, Int64Array);
390                Ok(array
391                    .values()
392                    .iter()
393                    .filter_map(|v| v.try_as_f64().transpose())
394                    .collect::<Result<Vec<_>>>()?)
395            }
396            DataType::Int32 => {
397                let array = downcast_value!(values, Int32Array);
398                Ok(array
399                    .values()
400                    .iter()
401                    .filter_map(|v| v.try_as_f64().transpose())
402                    .collect::<Result<Vec<_>>>()?)
403            }
404            DataType::Int16 => {
405                let array = downcast_value!(values, Int16Array);
406                Ok(array
407                    .values()
408                    .iter()
409                    .filter_map(|v| v.try_as_f64().transpose())
410                    .collect::<Result<Vec<_>>>()?)
411            }
412            DataType::Int8 => {
413                let array = downcast_value!(values, Int8Array);
414                Ok(array
415                    .values()
416                    .iter()
417                    .filter_map(|v| v.try_as_f64().transpose())
418                    .collect::<Result<Vec<_>>>()?)
419            }
420            DataType::UInt64 => {
421                let array = downcast_value!(values, UInt64Array);
422                Ok(array
423                    .values()
424                    .iter()
425                    .filter_map(|v| v.try_as_f64().transpose())
426                    .collect::<Result<Vec<_>>>()?)
427            }
428            DataType::UInt32 => {
429                let array = downcast_value!(values, UInt32Array);
430                Ok(array
431                    .values()
432                    .iter()
433                    .filter_map(|v| v.try_as_f64().transpose())
434                    .collect::<Result<Vec<_>>>()?)
435            }
436            DataType::UInt16 => {
437                let array = downcast_value!(values, UInt16Array);
438                Ok(array
439                    .values()
440                    .iter()
441                    .filter_map(|v| v.try_as_f64().transpose())
442                    .collect::<Result<Vec<_>>>()?)
443            }
444            DataType::UInt8 => {
445                let array = downcast_value!(values, UInt8Array);
446                Ok(array
447                    .values()
448                    .iter()
449                    .filter_map(|v| v.try_as_f64().transpose())
450                    .collect::<Result<Vec<_>>>()?)
451            }
452            e => internal_err!(
453                "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
454            ),
455        }
456    }
457}
458
459impl Accumulator for ApproxPercentileAccumulator {
460    fn state(&mut self) -> Result<Vec<ScalarValue>> {
461        Ok(self.digest.to_scalar_state().into_iter().collect())
462    }
463
464    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
465        // Remove any nulls before computing the percentile
466        let mut values = Arc::clone(&values[0]);
467        if values.nulls().is_some() {
468            values = filter(&values, &is_not_null(&values)?)?;
469        }
470        let sorted_values = &arrow::compute::sort(&values, None)?;
471        let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
472        self.digest = self.digest.merge_sorted_f64(&sorted_values);
473        Ok(())
474    }
475
476    fn evaluate(&mut self) -> Result<ScalarValue> {
477        if self.digest.count() == 0 {
478            return ScalarValue::try_from(self.return_type.clone());
479        }
480        let q = self.digest.estimate_quantile(self.percentile);
481
482        // These acceptable return types MUST match the validation in
483        // ApproxPercentile::create_accumulator.
484        Ok(match &self.return_type {
485            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
486            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
487            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
488            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
489            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
490            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
491            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
492            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
493            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
494            DataType::Float64 => ScalarValue::Float64(Some(q)),
495            v => unreachable!("unexpected return type {:?}", v),
496        })
497    }
498
499    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
500        if states.is_empty() {
501            return Ok(());
502        }
503
504        let states = (0..states[0].len())
505            .map(|index| {
506                states
507                    .iter()
508                    .map(|array| ScalarValue::try_from_array(array, index))
509                    .collect::<Result<Vec<_>>>()
510                    .map(|state| TDigest::from_scalar_state(&state))
511            })
512            .collect::<Result<Vec<_>>>()?;
513
514        self.merge_digests(&states);
515
516        Ok(())
517    }
518
519    fn size(&self) -> usize {
520        size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
521            + self.return_type.size()
522            - size_of_val(&self.return_type)
523    }
524}
525
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging digests into an accumulator must add up their counts.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // 50 digests, each built from the values 1 to 1_000
        // (50_000 values in total)
        let digests: Vec<TDigest> = (0..50)
            .map(|_| {
                let values: Vec<f64> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);

        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}
557}