1use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::{Array, RecordBatch};
24use arrow::compute::{filter, is_not_null};
25use arrow::{
26 array::{
27 ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
28 Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
29 },
30 datatypes::{DataType, Field, Schema},
31};
32
33use datafusion_common::{
34 downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
35 Result, ScalarValue,
36};
37use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
38use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
39use datafusion_expr::utils::format_state_name;
40use datafusion_expr::{
41 Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature,
42 TypeSignature, Volatility,
43};
44use datafusion_functions_aggregate_common::tdigest::{
45 TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
46};
47use datafusion_macros::user_doc;
48use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
49
// Invokes `create_func!` with the UDAF type and the accessor name
// `approx_percentile_cont_udaf`, which `approx_percentile_cont` below calls.
// NOTE(review): the macro is defined outside this file — presumably it
// generates that accessor function; confirm against the macro definition.
create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
51
52pub fn approx_percentile_cont(
54 expression: Expr,
55 percentile: Expr,
56 centroids: Option<Expr>,
57) -> Expr {
58 let args = if let Some(centroids) = centroids {
59 vec![expression, percentile, centroids]
60 } else {
61 vec![expression, percentile]
62 };
63 approx_percentile_cont_udaf().call(args)
64}
65
/// Aggregate UDF registered under the name `approx_percentile_cont`.
///
/// The user-facing documentation (description, syntax, SQL example and
/// argument descriptions) is declared through the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont(expression, percentile, centroids)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
+-------------------------------------------------+
| approx_percentile_cont(column_name, 0.75, 100) |
+-------------------------------------------------+
| 65.0 |
+-------------------------------------------------+
```"#,
    standard_argument(name = "expression",),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    ),
    argument(
        name = "centroids",
        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
    )
)]
pub struct ApproxPercentileCont {
    // Accepted argument-type combinations; built in `ApproxPercentileCont::new`.
    signature: Signature,
}
91
92impl Debug for ApproxPercentileCont {
93 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
94 f.debug_struct("ApproxPercentileCont")
95 .field("name", &self.name())
96 .field("signature", &self.signature)
97 .finish()
98 }
99}
100
101impl Default for ApproxPercentileCont {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl ApproxPercentileCont {
108 pub fn new() -> Self {
110 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
111 for num in NUMERICS {
113 variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
114 for int in INTEGERS {
116 variants.push(TypeSignature::Exact(vec![
117 num.clone(),
118 DataType::Float64,
119 int.clone(),
120 ]))
121 }
122 }
123 Self {
124 signature: Signature::one_of(variants, Volatility::Immutable),
125 }
126 }
127
128 pub(crate) fn create_accumulator(
129 &self,
130 args: AccumulatorArgs,
131 ) -> Result<ApproxPercentileAccumulator> {
132 let percentile = validate_input_percentile_expr(&args.exprs[1])?;
133 let tdigest_max_size = if args.exprs.len() == 3 {
134 Some(validate_input_max_size_expr(&args.exprs[2])?)
135 } else {
136 None
137 };
138
139 let data_type = args.exprs[0].data_type(args.schema)?;
140 let accumulator: ApproxPercentileAccumulator = match data_type {
141 t @ (DataType::UInt8
142 | DataType::UInt16
143 | DataType::UInt32
144 | DataType::UInt64
145 | DataType::Int8
146 | DataType::Int16
147 | DataType::Int32
148 | DataType::Int64
149 | DataType::Float32
150 | DataType::Float64) => {
151 if let Some(max_size) = tdigest_max_size {
152 ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
153 }else{
154 ApproxPercentileAccumulator::new(percentile, t)
155
156 }
157 }
158 other => {
159 return not_impl_err!(
160 "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
161 )
162 }
163 };
164
165 Ok(accumulator)
166 }
167}
168
169fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
170 let empty_schema = Arc::new(Schema::empty());
171 let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
172 if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
173 Ok(s)
174 } else {
175 internal_err!("Didn't expect ColumnarValue::Array")
176 }
177}
178
179fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
180 let percentile = match get_scalar_value(expr)
181 .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
182 ScalarValue::Float32(Some(value)) => {
183 value as f64
184 }
185 ScalarValue::Float64(Some(value)) => {
186 value
187 }
188 sv => {
189 return not_impl_err!(
190 "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
191 sv.data_type()
192 )
193 }
194 };
195
196 if !(0.0..=1.0).contains(&percentile) {
198 return plan_err!(
199 "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
200 );
201 }
202 Ok(percentile)
203}
204
205fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
206 let max_size = match get_scalar_value(expr)
207 .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
208 ScalarValue::UInt8(Some(q)) => q as usize,
209 ScalarValue::UInt16(Some(q)) => q as usize,
210 ScalarValue::UInt32(Some(q)) => q as usize,
211 ScalarValue::UInt64(Some(q)) => q as usize,
212 ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
213 ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
214 ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
215 ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
216 sv => {
217 return not_impl_err!(
218 "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
219 sv.data_type()
220 )
221 },
222 };
223
224 Ok(max_size)
225}
226
227impl AggregateUDFImpl for ApproxPercentileCont {
228 fn as_any(&self) -> &dyn Any {
229 self
230 }
231
232 #[allow(rustdoc::private_intra_doc_links)]
233 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
236 Ok(vec![
237 Field::new(
238 format_state_name(args.name, "max_size"),
239 DataType::UInt64,
240 false,
241 ),
242 Field::new(
243 format_state_name(args.name, "sum"),
244 DataType::Float64,
245 false,
246 ),
247 Field::new(
248 format_state_name(args.name, "count"),
249 DataType::UInt64,
250 false,
251 ),
252 Field::new(
253 format_state_name(args.name, "max"),
254 DataType::Float64,
255 false,
256 ),
257 Field::new(
258 format_state_name(args.name, "min"),
259 DataType::Float64,
260 false,
261 ),
262 Field::new_list(
263 format_state_name(args.name, "centroids"),
264 Field::new_list_field(DataType::Float64, true),
265 false,
266 ),
267 ])
268 }
269
270 fn name(&self) -> &str {
271 "approx_percentile_cont"
272 }
273
274 fn signature(&self) -> &Signature {
275 &self.signature
276 }
277
278 #[inline]
279 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
280 Ok(Box::new(self.create_accumulator(acc_args)?))
281 }
282
283 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
284 if !arg_types[0].is_numeric() {
285 return plan_err!("approx_percentile_cont requires numeric input types");
286 }
287 if arg_types.len() == 3 && !arg_types[2].is_integer() {
288 return plan_err!(
289 "approx_percentile_cont requires integer max_size input types"
290 );
291 }
292 Ok(arg_types[0].clone())
293 }
294
295 fn documentation(&self) -> Option<&Documentation> {
296 self.doc()
297 }
298}
299
/// Accumulator backing `approx_percentile_cont`: a t-digest plus the
/// requested percentile and the type the final estimate is cast to.
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
    // Running t-digest; replaced by `update_batch` and `merge_digests`.
    digest: TDigest,
    // Quantile passed to `TDigest::estimate_quantile` in `evaluate`.
    percentile: f64,
    // Output type: `evaluate` casts the estimated quantile to this type.
    return_type: DataType,
}
306
307impl ApproxPercentileAccumulator {
308 pub fn new(percentile: f64, return_type: DataType) -> Self {
309 Self {
310 digest: TDigest::new(DEFAULT_MAX_SIZE),
311 percentile,
312 return_type,
313 }
314 }
315
316 pub fn new_with_max_size(
317 percentile: f64,
318 return_type: DataType,
319 max_size: usize,
320 ) -> Self {
321 Self {
322 digest: TDigest::new(max_size),
323 percentile,
324 return_type,
325 }
326 }
327
328 pub fn merge_digests(&mut self, digests: &[TDigest]) {
330 let digests = digests.iter().chain(std::iter::once(&self.digest));
331 self.digest = TDigest::merge_digests(digests)
332 }
333
334 pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
336 match values.data_type() {
337 DataType::Float64 => {
338 let array = downcast_value!(values, Float64Array);
339 Ok(array
340 .values()
341 .iter()
342 .filter_map(|v| v.try_as_f64().transpose())
343 .collect::<Result<Vec<_>>>()?)
344 }
345 DataType::Float32 => {
346 let array = downcast_value!(values, Float32Array);
347 Ok(array
348 .values()
349 .iter()
350 .filter_map(|v| v.try_as_f64().transpose())
351 .collect::<Result<Vec<_>>>()?)
352 }
353 DataType::Int64 => {
354 let array = downcast_value!(values, Int64Array);
355 Ok(array
356 .values()
357 .iter()
358 .filter_map(|v| v.try_as_f64().transpose())
359 .collect::<Result<Vec<_>>>()?)
360 }
361 DataType::Int32 => {
362 let array = downcast_value!(values, Int32Array);
363 Ok(array
364 .values()
365 .iter()
366 .filter_map(|v| v.try_as_f64().transpose())
367 .collect::<Result<Vec<_>>>()?)
368 }
369 DataType::Int16 => {
370 let array = downcast_value!(values, Int16Array);
371 Ok(array
372 .values()
373 .iter()
374 .filter_map(|v| v.try_as_f64().transpose())
375 .collect::<Result<Vec<_>>>()?)
376 }
377 DataType::Int8 => {
378 let array = downcast_value!(values, Int8Array);
379 Ok(array
380 .values()
381 .iter()
382 .filter_map(|v| v.try_as_f64().transpose())
383 .collect::<Result<Vec<_>>>()?)
384 }
385 DataType::UInt64 => {
386 let array = downcast_value!(values, UInt64Array);
387 Ok(array
388 .values()
389 .iter()
390 .filter_map(|v| v.try_as_f64().transpose())
391 .collect::<Result<Vec<_>>>()?)
392 }
393 DataType::UInt32 => {
394 let array = downcast_value!(values, UInt32Array);
395 Ok(array
396 .values()
397 .iter()
398 .filter_map(|v| v.try_as_f64().transpose())
399 .collect::<Result<Vec<_>>>()?)
400 }
401 DataType::UInt16 => {
402 let array = downcast_value!(values, UInt16Array);
403 Ok(array
404 .values()
405 .iter()
406 .filter_map(|v| v.try_as_f64().transpose())
407 .collect::<Result<Vec<_>>>()?)
408 }
409 DataType::UInt8 => {
410 let array = downcast_value!(values, UInt8Array);
411 Ok(array
412 .values()
413 .iter()
414 .filter_map(|v| v.try_as_f64().transpose())
415 .collect::<Result<Vec<_>>>()?)
416 }
417 e => internal_err!(
418 "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
419 ),
420 }
421 }
422}
423
424impl Accumulator for ApproxPercentileAccumulator {
425 fn state(&mut self) -> Result<Vec<ScalarValue>> {
426 Ok(self.digest.to_scalar_state().into_iter().collect())
427 }
428
429 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
430 let mut values = Arc::clone(&values[0]);
432 if values.nulls().is_some() {
433 values = filter(&values, &is_not_null(&values)?)?;
434 }
435 let sorted_values = &arrow::compute::sort(&values, None)?;
436 let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
437 self.digest = self.digest.merge_sorted_f64(&sorted_values);
438 Ok(())
439 }
440
441 fn evaluate(&mut self) -> Result<ScalarValue> {
442 if self.digest.count() == 0 {
443 return ScalarValue::try_from(self.return_type.clone());
444 }
445 let q = self.digest.estimate_quantile(self.percentile);
446
447 Ok(match &self.return_type {
450 DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
451 DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
452 DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
453 DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
454 DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
455 DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
456 DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
457 DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
458 DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
459 DataType::Float64 => ScalarValue::Float64(Some(q)),
460 v => unreachable!("unexpected return type {:?}", v),
461 })
462 }
463
464 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
465 if states.is_empty() {
466 return Ok(());
467 }
468
469 let states = (0..states[0].len())
470 .map(|index| {
471 states
472 .iter()
473 .map(|array| ScalarValue::try_from_array(array, index))
474 .collect::<Result<Vec<_>>>()
475 .map(|state| TDigest::from_scalar_state(&state))
476 })
477 .collect::<Result<Vec<_>>>()?;
478
479 self.merge_digests(&states);
480
481 Ok(())
482 }
483
484 fn size(&self) -> usize {
485 size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
486 + self.return_type.size()
487 - size_of_val(&self.return_type)
488 }
489}
490
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging digests into an accumulator must add up their counts.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // 50 digests, each built from the values 1..=1000.
        let digests: Vec<TDigest> = (1..=50)
            .map(|_| {
                let values: Vec<_> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        let first_merge = TDigest::merge_digests(&digests);
        let second_merge = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        // Each merged digest carries 50 * 1000 values.
        accumulator.merge_digests(&[first_merge]);
        assert_eq!(accumulator.digest.count(), 50_000);
        accumulator.merge_digests(&[second_merge]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}