1use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::{Array, RecordBatch};
24use arrow::compute::{filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{
27 array::{
28 ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
29 Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
30 },
31 datatypes::{DataType, Field, Schema},
32};
33use datafusion_common::{
34 downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
35 Result, ScalarValue,
36};
37use datafusion_expr::expr::{AggregateFunction, Sort};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
40use datafusion_expr::utils::format_state_name;
41use datafusion_expr::{
42 Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature,
43 TypeSignature, Volatility,
44};
45use datafusion_functions_aggregate_common::tdigest::{
46 TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
47};
48use datafusion_macros::user_doc;
49use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
50
// Invokes the crate's `create_func!` macro (defined outside this file) for
// `ApproxPercentileCont`. Presumably this generates `approx_percentile_cont_udaf`,
// the zero-argument accessor that `approx_percentile_cont` passes to
// `AggregateFunction::new_udf` — confirm against the macro definition.
create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
52
53pub fn approx_percentile_cont(
55 order_by: Sort,
56 percentile: Expr,
57 centroids: Option<Expr>,
58) -> Expr {
59 let expr = order_by.expr.clone();
60
61 let args = if let Some(centroids) = centroids {
62 vec![expr, percentile, centroids]
63 } else {
64 vec![expr, percentile]
65 };
66
67 Expr::AggregateFunction(AggregateFunction::new_udf(
68 approx_percentile_cont_udaf(),
69 args,
70 false,
71 None,
72 Some(vec![order_by]),
73 None,
74 ))
75}
76
// Aggregate UDF type behind the SQL function `approx_percentile_cont`.
//
// The `user_doc` attribute carries the user-facing documentation (section,
// description, syntax, SQL example and argument descriptions). It presumably
// also generates the `doc()` method that `documentation()` returns — confirm
// against the `datafusion_macros::user_doc` implementation.
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+-----------------------------------------------------------------------+
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
+-----------------------------------------------------------------------+
| 65.0 |
+-----------------------------------------------------------------------+
```"#,
    standard_argument(name = "expression",),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    ),
    argument(
        name = "centroids",
        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
    )
)]
pub struct ApproxPercentileCont {
    // Accepted argument-type combinations; built in `new()` and exposed via `signature()`.
    signature: Signature,
}
102
103impl Debug for ApproxPercentileCont {
104 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
105 f.debug_struct("ApproxPercentileCont")
106 .field("name", &self.name())
107 .field("signature", &self.signature)
108 .finish()
109 }
110}
111
112impl Default for ApproxPercentileCont {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl ApproxPercentileCont {
119 pub fn new() -> Self {
121 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
122 for num in NUMERICS {
124 variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
125 for int in INTEGERS {
127 variants.push(TypeSignature::Exact(vec![
128 num.clone(),
129 DataType::Float64,
130 int.clone(),
131 ]))
132 }
133 }
134 Self {
135 signature: Signature::one_of(variants, Volatility::Immutable),
136 }
137 }
138
139 pub(crate) fn create_accumulator(
140 &self,
141 args: AccumulatorArgs,
142 ) -> Result<ApproxPercentileAccumulator> {
143 let percentile = validate_input_percentile_expr(&args.exprs[1])?;
144
145 let is_descending = args
146 .ordering_req
147 .first()
148 .map(|sort_expr| sort_expr.options.descending)
149 .unwrap_or(false);
150
151 let percentile = if is_descending {
152 1.0 - percentile
153 } else {
154 percentile
155 };
156
157 let tdigest_max_size = if args.exprs.len() == 3 {
158 Some(validate_input_max_size_expr(&args.exprs[2])?)
159 } else {
160 None
161 };
162
163 let data_type = args.exprs[0].data_type(args.schema)?;
164 let accumulator: ApproxPercentileAccumulator = match data_type {
165 t @ (DataType::UInt8
166 | DataType::UInt16
167 | DataType::UInt32
168 | DataType::UInt64
169 | DataType::Int8
170 | DataType::Int16
171 | DataType::Int32
172 | DataType::Int64
173 | DataType::Float32
174 | DataType::Float64) => {
175 if let Some(max_size) = tdigest_max_size {
176 ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
177 }else{
178 ApproxPercentileAccumulator::new(percentile, t)
179
180 }
181 }
182 other => {
183 return not_impl_err!(
184 "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
185 )
186 }
187 };
188
189 Ok(accumulator)
190 }
191}
192
193fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
194 let empty_schema = Arc::new(Schema::empty());
195 let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
196 if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
197 Ok(s)
198 } else {
199 internal_err!("Didn't expect ColumnarValue::Array")
200 }
201}
202
203fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
204 let percentile = match get_scalar_value(expr)
205 .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
206 ScalarValue::Float32(Some(value)) => {
207 value as f64
208 }
209 ScalarValue::Float64(Some(value)) => {
210 value
211 }
212 sv => {
213 return not_impl_err!(
214 "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
215 sv.data_type()
216 )
217 }
218 };
219
220 if !(0.0..=1.0).contains(&percentile) {
222 return plan_err!(
223 "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
224 );
225 }
226 Ok(percentile)
227}
228
229fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
230 let max_size = match get_scalar_value(expr)
231 .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
232 ScalarValue::UInt8(Some(q)) => q as usize,
233 ScalarValue::UInt16(Some(q)) => q as usize,
234 ScalarValue::UInt32(Some(q)) => q as usize,
235 ScalarValue::UInt64(Some(q)) => q as usize,
236 ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
237 ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
238 ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
239 ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
240 sv => {
241 return not_impl_err!(
242 "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
243 sv.data_type()
244 )
245 },
246 };
247
248 Ok(max_size)
249}
250
251impl AggregateUDFImpl for ApproxPercentileCont {
252 fn as_any(&self) -> &dyn Any {
253 self
254 }
255
256 #[allow(rustdoc::private_intra_doc_links)]
257 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
260 Ok(vec![
261 Field::new(
262 format_state_name(args.name, "max_size"),
263 DataType::UInt64,
264 false,
265 ),
266 Field::new(
267 format_state_name(args.name, "sum"),
268 DataType::Float64,
269 false,
270 ),
271 Field::new(
272 format_state_name(args.name, "count"),
273 DataType::UInt64,
274 false,
275 ),
276 Field::new(
277 format_state_name(args.name, "max"),
278 DataType::Float64,
279 false,
280 ),
281 Field::new(
282 format_state_name(args.name, "min"),
283 DataType::Float64,
284 false,
285 ),
286 Field::new_list(
287 format_state_name(args.name, "centroids"),
288 Field::new_list_field(DataType::Float64, true),
289 false,
290 ),
291 ]
292 .into_iter()
293 .map(Arc::new)
294 .collect())
295 }
296
297 fn name(&self) -> &str {
298 "approx_percentile_cont"
299 }
300
301 fn signature(&self) -> &Signature {
302 &self.signature
303 }
304
305 #[inline]
306 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
307 Ok(Box::new(self.create_accumulator(acc_args)?))
308 }
309
310 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
311 if !arg_types[0].is_numeric() {
312 return plan_err!("approx_percentile_cont requires numeric input types");
313 }
314 if arg_types.len() == 3 && !arg_types[2].is_integer() {
315 return plan_err!(
316 "approx_percentile_cont requires integer max_size input types"
317 );
318 }
319 Ok(arg_types[0].clone())
320 }
321
322 fn supports_null_handling_clause(&self) -> bool {
323 false
324 }
325
326 fn is_ordered_set_aggregate(&self) -> bool {
327 true
328 }
329
330 fn documentation(&self) -> Option<&Documentation> {
331 self.doc()
332 }
333}
334
/// Accumulator that tracks a t-digest of the input values and reports one
/// quantile estimate from it.
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
    /// Running t-digest; replaced on every update or merge.
    digest: TDigest,
    /// Quantile passed to `TDigest::estimate_quantile` when evaluating.
    percentile: f64,
    /// Type of the scalar produced by `evaluate` (the estimate is cast to it).
    return_type: DataType,
}
341
342impl ApproxPercentileAccumulator {
343 pub fn new(percentile: f64, return_type: DataType) -> Self {
344 Self {
345 digest: TDigest::new(DEFAULT_MAX_SIZE),
346 percentile,
347 return_type,
348 }
349 }
350
351 pub fn new_with_max_size(
352 percentile: f64,
353 return_type: DataType,
354 max_size: usize,
355 ) -> Self {
356 Self {
357 digest: TDigest::new(max_size),
358 percentile,
359 return_type,
360 }
361 }
362
363 pub fn merge_digests(&mut self, digests: &[TDigest]) {
365 let digests = digests.iter().chain(std::iter::once(&self.digest));
366 self.digest = TDigest::merge_digests(digests)
367 }
368
369 pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
371 match values.data_type() {
372 DataType::Float64 => {
373 let array = downcast_value!(values, Float64Array);
374 Ok(array
375 .values()
376 .iter()
377 .filter_map(|v| v.try_as_f64().transpose())
378 .collect::<Result<Vec<_>>>()?)
379 }
380 DataType::Float32 => {
381 let array = downcast_value!(values, Float32Array);
382 Ok(array
383 .values()
384 .iter()
385 .filter_map(|v| v.try_as_f64().transpose())
386 .collect::<Result<Vec<_>>>()?)
387 }
388 DataType::Int64 => {
389 let array = downcast_value!(values, Int64Array);
390 Ok(array
391 .values()
392 .iter()
393 .filter_map(|v| v.try_as_f64().transpose())
394 .collect::<Result<Vec<_>>>()?)
395 }
396 DataType::Int32 => {
397 let array = downcast_value!(values, Int32Array);
398 Ok(array
399 .values()
400 .iter()
401 .filter_map(|v| v.try_as_f64().transpose())
402 .collect::<Result<Vec<_>>>()?)
403 }
404 DataType::Int16 => {
405 let array = downcast_value!(values, Int16Array);
406 Ok(array
407 .values()
408 .iter()
409 .filter_map(|v| v.try_as_f64().transpose())
410 .collect::<Result<Vec<_>>>()?)
411 }
412 DataType::Int8 => {
413 let array = downcast_value!(values, Int8Array);
414 Ok(array
415 .values()
416 .iter()
417 .filter_map(|v| v.try_as_f64().transpose())
418 .collect::<Result<Vec<_>>>()?)
419 }
420 DataType::UInt64 => {
421 let array = downcast_value!(values, UInt64Array);
422 Ok(array
423 .values()
424 .iter()
425 .filter_map(|v| v.try_as_f64().transpose())
426 .collect::<Result<Vec<_>>>()?)
427 }
428 DataType::UInt32 => {
429 let array = downcast_value!(values, UInt32Array);
430 Ok(array
431 .values()
432 .iter()
433 .filter_map(|v| v.try_as_f64().transpose())
434 .collect::<Result<Vec<_>>>()?)
435 }
436 DataType::UInt16 => {
437 let array = downcast_value!(values, UInt16Array);
438 Ok(array
439 .values()
440 .iter()
441 .filter_map(|v| v.try_as_f64().transpose())
442 .collect::<Result<Vec<_>>>()?)
443 }
444 DataType::UInt8 => {
445 let array = downcast_value!(values, UInt8Array);
446 Ok(array
447 .values()
448 .iter()
449 .filter_map(|v| v.try_as_f64().transpose())
450 .collect::<Result<Vec<_>>>()?)
451 }
452 e => internal_err!(
453 "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
454 ),
455 }
456 }
457}
458
459impl Accumulator for ApproxPercentileAccumulator {
460 fn state(&mut self) -> Result<Vec<ScalarValue>> {
461 Ok(self.digest.to_scalar_state().into_iter().collect())
462 }
463
464 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
465 let mut values = Arc::clone(&values[0]);
467 if values.nulls().is_some() {
468 values = filter(&values, &is_not_null(&values)?)?;
469 }
470 let sorted_values = &arrow::compute::sort(&values, None)?;
471 let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
472 self.digest = self.digest.merge_sorted_f64(&sorted_values);
473 Ok(())
474 }
475
476 fn evaluate(&mut self) -> Result<ScalarValue> {
477 if self.digest.count() == 0 {
478 return ScalarValue::try_from(self.return_type.clone());
479 }
480 let q = self.digest.estimate_quantile(self.percentile);
481
482 Ok(match &self.return_type {
485 DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
486 DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
487 DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
488 DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
489 DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
490 DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
491 DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
492 DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
493 DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
494 DataType::Float64 => ScalarValue::Float64(Some(q)),
495 v => unreachable!("unexpected return type {:?}", v),
496 })
497 }
498
499 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
500 if states.is_empty() {
501 return Ok(());
502 }
503
504 let states = (0..states[0].len())
505 .map(|index| {
506 states
507 .iter()
508 .map(|array| ScalarValue::try_from_array(array, index))
509 .collect::<Result<Vec<_>>>()
510 .map(|state| TDigest::from_scalar_state(&state))
511 })
512 .collect::<Result<Vec<_>>>()?;
513
514 self.merge_digests(&states);
515
516 Ok(())
517 }
518
519 fn size(&self) -> usize {
520 size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
521 + self.return_type.size()
522 - size_of_val(&self.return_type)
523 }
524}
525
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging two pre-combined digests of 50 000 samples each into one
    /// accumulator must add up their counts.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // Fifty digests, each fed the values 1.0..=1000.0.
        let digests: Vec<TDigest> = (0..50)
            .map(|_| {
                let values: Vec<_> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);

        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}
557}