datafusion_functions_aggregate/
approx_percentile_cont.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::{Array, RecordBatch};
24use arrow::compute::{filter, is_not_null};
25use arrow::{
26    array::{
27        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
28        Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
29    },
30    datatypes::{DataType, Field, Schema},
31};
32
33use datafusion_common::{
34    downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
35    Result, ScalarValue,
36};
37use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
38use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
39use datafusion_expr::utils::format_state_name;
40use datafusion_expr::{
41    Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature,
42    TypeSignature, Volatility,
43};
44use datafusion_functions_aggregate_common::tdigest::{
45    TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
46};
47use datafusion_macros::user_doc;
48use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
49
// NOTE(review): `create_func!` is defined elsewhere in the crate; presumably it
// generates the `approx_percentile_cont_udaf()` accessor that
// `approx_percentile_cont` calls in this file -- confirm against the macro definition.
create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
51
52/// Computes the approximate percentile continuous of a set of numbers
53pub fn approx_percentile_cont(
54    expression: Expr,
55    percentile: Expr,
56    centroids: Option<Expr>,
57) -> Expr {
58    let args = if let Some(centroids) = centroids {
59        vec![expression, percentile, centroids]
60    } else {
61        vec![expression, percentile]
62    };
63    approx_percentile_cont_udaf().call(args)
64}
65
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont(expression, percentile, centroids)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
+-------------------------------------------------+
| approx_percentile_cont(column_name, 0.75, 100)  |
+-------------------------------------------------+
| 65.0                                            |
+-------------------------------------------------+
```"#,
    standard_argument(name = "expression",),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    ),
    argument(
        name = "centroids",
        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
    )
)]
pub struct ApproxPercentileCont {
    /// Accepted argument type combinations; populated by `ApproxPercentileCont::new`.
    signature: Signature,
}
91
92impl Debug for ApproxPercentileCont {
93    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
94        f.debug_struct("ApproxPercentileCont")
95            .field("name", &self.name())
96            .field("signature", &self.signature)
97            .finish()
98    }
99}
100
101impl Default for ApproxPercentileCont {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl ApproxPercentileCont {
108    /// Create a new [`ApproxPercentileCont`] aggregate function.
109    pub fn new() -> Self {
110        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
111        // Accept any numeric value paired with a float64 percentile
112        for num in NUMERICS {
113            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
114            // Additionally accept an integer number of centroids for T-Digest
115            for int in INTEGERS {
116                variants.push(TypeSignature::Exact(vec![
117                    num.clone(),
118                    DataType::Float64,
119                    int.clone(),
120                ]))
121            }
122        }
123        Self {
124            signature: Signature::one_of(variants, Volatility::Immutable),
125        }
126    }
127
128    pub(crate) fn create_accumulator(
129        &self,
130        args: AccumulatorArgs,
131    ) -> Result<ApproxPercentileAccumulator> {
132        let percentile = validate_input_percentile_expr(&args.exprs[1])?;
133        let tdigest_max_size = if args.exprs.len() == 3 {
134            Some(validate_input_max_size_expr(&args.exprs[2])?)
135        } else {
136            None
137        };
138
139        let data_type = args.exprs[0].data_type(args.schema)?;
140        let accumulator: ApproxPercentileAccumulator = match data_type {
141            t @ (DataType::UInt8
142            | DataType::UInt16
143            | DataType::UInt32
144            | DataType::UInt64
145            | DataType::Int8
146            | DataType::Int16
147            | DataType::Int32
148            | DataType::Int64
149            | DataType::Float32
150            | DataType::Float64) => {
151                if let Some(max_size) = tdigest_max_size {
152                    ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
153                }else{
154                    ApproxPercentileAccumulator::new(percentile, t)
155
156                }
157            }
158            other => {
159                return not_impl_err!(
160                    "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
161                )
162            }
163        };
164
165        Ok(accumulator)
166    }
167}
168
169fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
170    let empty_schema = Arc::new(Schema::empty());
171    let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
172    if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
173        Ok(s)
174    } else {
175        internal_err!("Didn't expect ColumnarValue::Array")
176    }
177}
178
179fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
180    let percentile = match get_scalar_value(expr)
181        .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
182        ScalarValue::Float32(Some(value)) => {
183            value as f64
184        }
185        ScalarValue::Float64(Some(value)) => {
186            value
187        }
188        sv => {
189            return not_impl_err!(
190                "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
191                sv.data_type()
192            )
193        }
194    };
195
196    // Ensure the percentile is between 0 and 1.
197    if !(0.0..=1.0).contains(&percentile) {
198        return plan_err!(
199            "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
200        );
201    }
202    Ok(percentile)
203}
204
205fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
206    let max_size = match get_scalar_value(expr)
207        .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
208        ScalarValue::UInt8(Some(q)) => q as usize,
209        ScalarValue::UInt16(Some(q)) => q as usize,
210        ScalarValue::UInt32(Some(q)) => q as usize,
211        ScalarValue::UInt64(Some(q)) => q as usize,
212        ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
213        ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
214        ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
215        ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
216        sv => {
217            return not_impl_err!(
218                "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
219                sv.data_type()
220            )
221        },
222    };
223
224    Ok(max_size)
225}
226
227impl AggregateUDFImpl for ApproxPercentileCont {
228    fn as_any(&self) -> &dyn Any {
229        self
230    }
231
232    #[allow(rustdoc::private_intra_doc_links)]
233    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
234    /// state.
235    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
236        Ok(vec![
237            Field::new(
238                format_state_name(args.name, "max_size"),
239                DataType::UInt64,
240                false,
241            ),
242            Field::new(
243                format_state_name(args.name, "sum"),
244                DataType::Float64,
245                false,
246            ),
247            Field::new(
248                format_state_name(args.name, "count"),
249                DataType::UInt64,
250                false,
251            ),
252            Field::new(
253                format_state_name(args.name, "max"),
254                DataType::Float64,
255                false,
256            ),
257            Field::new(
258                format_state_name(args.name, "min"),
259                DataType::Float64,
260                false,
261            ),
262            Field::new_list(
263                format_state_name(args.name, "centroids"),
264                Field::new_list_field(DataType::Float64, true),
265                false,
266            ),
267        ])
268    }
269
270    fn name(&self) -> &str {
271        "approx_percentile_cont"
272    }
273
274    fn signature(&self) -> &Signature {
275        &self.signature
276    }
277
278    #[inline]
279    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
280        Ok(Box::new(self.create_accumulator(acc_args)?))
281    }
282
283    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
284        if !arg_types[0].is_numeric() {
285            return plan_err!("approx_percentile_cont requires numeric input types");
286        }
287        if arg_types.len() == 3 && !arg_types[2].is_integer() {
288            return plan_err!(
289                "approx_percentile_cont requires integer max_size input types"
290            );
291        }
292        Ok(arg_types[0].clone())
293    }
294
295    fn documentation(&self) -> Option<&Documentation> {
296        self.doc()
297    }
298}
299
/// Accumulator that feeds input values into a [`TDigest`] and evaluates to the
/// estimated quantile at `percentile`.
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
    /// Digest holding everything accumulated so far; replaced on every update or merge.
    digest: TDigest,
    /// Quantile passed to `TDigest::estimate_quantile` when evaluating.
    percentile: f64,
    /// Data type of the scalar produced by `evaluate`.
    return_type: DataType,
}
306
307impl ApproxPercentileAccumulator {
308    pub fn new(percentile: f64, return_type: DataType) -> Self {
309        Self {
310            digest: TDigest::new(DEFAULT_MAX_SIZE),
311            percentile,
312            return_type,
313        }
314    }
315
316    pub fn new_with_max_size(
317        percentile: f64,
318        return_type: DataType,
319        max_size: usize,
320    ) -> Self {
321        Self {
322            digest: TDigest::new(max_size),
323            percentile,
324            return_type,
325        }
326    }
327
328    // public for approx_percentile_cont_with_weight
329    pub fn merge_digests(&mut self, digests: &[TDigest]) {
330        let digests = digests.iter().chain(std::iter::once(&self.digest));
331        self.digest = TDigest::merge_digests(digests)
332    }
333
334    // public for approx_percentile_cont_with_weight
335    pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
336        match values.data_type() {
337            DataType::Float64 => {
338                let array = downcast_value!(values, Float64Array);
339                Ok(array
340                    .values()
341                    .iter()
342                    .filter_map(|v| v.try_as_f64().transpose())
343                    .collect::<Result<Vec<_>>>()?)
344            }
345            DataType::Float32 => {
346                let array = downcast_value!(values, Float32Array);
347                Ok(array
348                    .values()
349                    .iter()
350                    .filter_map(|v| v.try_as_f64().transpose())
351                    .collect::<Result<Vec<_>>>()?)
352            }
353            DataType::Int64 => {
354                let array = downcast_value!(values, Int64Array);
355                Ok(array
356                    .values()
357                    .iter()
358                    .filter_map(|v| v.try_as_f64().transpose())
359                    .collect::<Result<Vec<_>>>()?)
360            }
361            DataType::Int32 => {
362                let array = downcast_value!(values, Int32Array);
363                Ok(array
364                    .values()
365                    .iter()
366                    .filter_map(|v| v.try_as_f64().transpose())
367                    .collect::<Result<Vec<_>>>()?)
368            }
369            DataType::Int16 => {
370                let array = downcast_value!(values, Int16Array);
371                Ok(array
372                    .values()
373                    .iter()
374                    .filter_map(|v| v.try_as_f64().transpose())
375                    .collect::<Result<Vec<_>>>()?)
376            }
377            DataType::Int8 => {
378                let array = downcast_value!(values, Int8Array);
379                Ok(array
380                    .values()
381                    .iter()
382                    .filter_map(|v| v.try_as_f64().transpose())
383                    .collect::<Result<Vec<_>>>()?)
384            }
385            DataType::UInt64 => {
386                let array = downcast_value!(values, UInt64Array);
387                Ok(array
388                    .values()
389                    .iter()
390                    .filter_map(|v| v.try_as_f64().transpose())
391                    .collect::<Result<Vec<_>>>()?)
392            }
393            DataType::UInt32 => {
394                let array = downcast_value!(values, UInt32Array);
395                Ok(array
396                    .values()
397                    .iter()
398                    .filter_map(|v| v.try_as_f64().transpose())
399                    .collect::<Result<Vec<_>>>()?)
400            }
401            DataType::UInt16 => {
402                let array = downcast_value!(values, UInt16Array);
403                Ok(array
404                    .values()
405                    .iter()
406                    .filter_map(|v| v.try_as_f64().transpose())
407                    .collect::<Result<Vec<_>>>()?)
408            }
409            DataType::UInt8 => {
410                let array = downcast_value!(values, UInt8Array);
411                Ok(array
412                    .values()
413                    .iter()
414                    .filter_map(|v| v.try_as_f64().transpose())
415                    .collect::<Result<Vec<_>>>()?)
416            }
417            e => internal_err!(
418                "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
419            ),
420        }
421    }
422}
423
424impl Accumulator for ApproxPercentileAccumulator {
425    fn state(&mut self) -> Result<Vec<ScalarValue>> {
426        Ok(self.digest.to_scalar_state().into_iter().collect())
427    }
428
429    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
430        // Remove any nulls before computing the percentile
431        let mut values = Arc::clone(&values[0]);
432        if values.nulls().is_some() {
433            values = filter(&values, &is_not_null(&values)?)?;
434        }
435        let sorted_values = &arrow::compute::sort(&values, None)?;
436        let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
437        self.digest = self.digest.merge_sorted_f64(&sorted_values);
438        Ok(())
439    }
440
441    fn evaluate(&mut self) -> Result<ScalarValue> {
442        if self.digest.count() == 0 {
443            return ScalarValue::try_from(self.return_type.clone());
444        }
445        let q = self.digest.estimate_quantile(self.percentile);
446
447        // These acceptable return types MUST match the validation in
448        // ApproxPercentile::create_accumulator.
449        Ok(match &self.return_type {
450            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
451            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
452            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
453            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
454            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
455            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
456            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
457            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
458            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
459            DataType::Float64 => ScalarValue::Float64(Some(q)),
460            v => unreachable!("unexpected return type {:?}", v),
461        })
462    }
463
464    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
465        if states.is_empty() {
466            return Ok(());
467        }
468
469        let states = (0..states[0].len())
470            .map(|index| {
471                states
472                    .iter()
473                    .map(|array| ScalarValue::try_from_array(array, index))
474                    .collect::<Result<Vec<_>>>()
475                    .map(|state| TDigest::from_scalar_state(&state))
476            })
477            .collect::<Result<Vec<_>>>()?;
478
479        self.merge_digests(&states);
480
481        Ok(())
482    }
483
484    fn size(&self) -> usize {
485        size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
486            + self.return_type.size()
487            - size_of_val(&self.return_type)
488    }
489}
490
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging digests into an accumulator must add up their counts.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // 50 digests, each built from the 1_000 values 1..=1_000
        // (50_000 values in total)
        let digests: Vec<TDigest> = (0..50)
            .map(|_| {
                let values: Vec<f64> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        // Two independent merges of the same inputs, 50_000 values each.
        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);
        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}