1use arrow::array::Float64Array;
21use arrow::{
22 array::{ArrayRef, UInt64Array},
23 compute::cast,
24 datatypes::DataType,
25 datatypes::Field,
26};
27use datafusion_common::{
28 downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result,
29 ScalarValue,
30};
31use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
32use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
33use datafusion_expr::type_coercion::aggregates::NUMERICS;
34use datafusion_expr::utils::format_state_name;
35use datafusion_expr::{
36 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
37};
38use std::any::Any;
39use std::fmt::Debug;
40use std::mem::size_of_val;
41use std::sync::LazyLock;
42
/// Declares the public entry points for one `regr_*` aggregate.
///
/// Expands to a `make_udaf_expr!` call followed by a `create_func!` call (both
/// macros are defined outside this file).
///
/// * `$expr_fn` - identifier of the expression function; its stringified form
///   is also passed to [`Regr::new`] as the function name.
/// * `$udaf_fn` - identifier handed to both helper macros for the UDF accessor.
/// * `$regr_type` - the [`RegrType`] expression selecting the statistic.
macro_rules! make_regr_udaf_expr_and_func {
    ($expr_fn:ident, $udaf_fn:ident, $regr_type:expr) => {
        make_udaf_expr!(
            $expr_fn,
            expr_y expr_x,
            concat!(
                "Compute a linear regression of type [",
                stringify!($regr_type),
                "]"
            ),
            $udaf_fn
        );
        create_func!(
            $expr_fn,
            $udaf_fn,
            Regr::new($regr_type, stringify!($expr_fn))
        );
    }
}
49
// One instantiation per supported regression statistic. Each line passes the
// expression-function name, the UDAF accessor name and the matching `RegrType`
// variant to `make_regr_udaf_expr_and_func!`.
make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
59
/// Aggregate UDF implementation shared by every `regr_*` function.
///
/// A single type backs all of them; `regr_type` decides which statistic an
/// instance reports and `func_name` is the name it is exposed under.
pub struct Regr {
    /// Argument signature built in [`Regr::new`].
    signature: Signature,
    /// Statistic this instance computes.
    regr_type: RegrType,
    /// Name returned from `AggregateUDFImpl::name`.
    func_name: &'static str,
}
65
66impl Debug for Regr {
67 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
68 f.debug_struct("regr")
69 .field("name", &self.name())
70 .field("signature", &self.signature)
71 .finish()
72 }
73}
74
75impl Regr {
76 pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
77 Self {
78 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
79 regr_type,
80 func_name,
81 }
82 }
83}
84
/// Selects which statistic of the linear regression `y = k*x + b` is reported.
///
/// Every variant is evaluated from the same accumulator state; see
/// `RegrAccumulator::evaluate` for the exact formulas and null conditions.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
// `SXX`, `SYY` and `SXY` follow the `regr_sxx` / `regr_syy` / `regr_sxy`
// function names, so the acronym lint is silenced for this enum.
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
    /// Slope `k`, evaluated as `cov_pop(x, y) / var_pop(x)`.
    /// Null when at most one pair was seen or `var_pop(x)` is zero.
    Slope,
    /// Intercept `b`, evaluated as `mean_y - slope * mean_x`.
    /// Null under the same conditions as [`RegrType::Slope`].
    Intercept,
    /// Number of pairs in which both values are non-null; the only variant
    /// whose result is `UInt64` rather than `Float64`.
    Count,
    /// Square of the correlation coefficient, evaluated as
    /// `cov_pop(x, y)^2 / (var_pop(x) * var_pop(y))`.
    /// Null when at most one pair was seen or either variance is zero.
    R2,
    /// Mean of the independent variable `x` over the non-null pairs.
    /// Null when no pair was seen.
    AvgX,
    /// Mean of the dependent variable `y` over the non-null pairs.
    /// Null when no pair was seen.
    AvgY,
    /// Sum of squared deviations of `x` from its mean.
    /// Null when no pair was seen.
    SXX,
    /// Sum of squared deviations of `y` from its mean.
    /// Null when no pair was seen.
    SYY,
    /// Sum of products of the `x` and `y` deviations from their means.
    /// Null when no pair was seen.
    SXY,
}
127
128impl RegrType {
129 fn documentation(&self) -> Option<&Documentation> {
131 get_regr_docs().get(self)
132 }
133}
134
135static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
136 let mut hash_map = HashMap::new();
137 hash_map.insert(
138 RegrType::Slope,
139 Documentation::builder(
140 DOC_SECTION_STATISTICAL,
141 "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
142 Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
143
144 "regr_slope(expression_y, expression_x)")
145 .with_standard_argument("expression_y", Some("Dependent variable"))
146 .with_standard_argument("expression_x", Some("Independent variable"))
147 .build()
148 );
149
150 hash_map.insert(
151 RegrType::Intercept,
152 Documentation::builder(
153 DOC_SECTION_STATISTICAL,
154 "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
155 this function returns b.",
156
157 "regr_intercept(expression_y, expression_x)")
158 .with_standard_argument("expression_y", Some("Dependent variable"))
159 .with_standard_argument("expression_x", Some("Independent variable"))
160 .build()
161 );
162
163 hash_map.insert(
164 RegrType::Count,
165 Documentation::builder(
166 DOC_SECTION_STATISTICAL,
167 "Counts the number of non-null paired data points.",
168 "regr_count(expression_y, expression_x)",
169 )
170 .with_standard_argument("expression_y", Some("Dependent variable"))
171 .with_standard_argument("expression_x", Some("Independent variable"))
172 .build(),
173 );
174
175 hash_map.insert(
176 RegrType::R2,
177 Documentation::builder(
178 DOC_SECTION_STATISTICAL,
179 "Computes the square of the correlation coefficient between the independent and dependent variables.",
180
181 "regr_r2(expression_y, expression_x)")
182 .with_standard_argument("expression_y", Some("Dependent variable"))
183 .with_standard_argument("expression_x", Some("Independent variable"))
184 .build()
185 );
186
187 hash_map.insert(
188 RegrType::AvgX,
189 Documentation::builder(
190 DOC_SECTION_STATISTICAL,
191 "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
192
193 "regr_avgx(expression_y, expression_x)")
194 .with_standard_argument("expression_y", Some("Dependent variable"))
195 .with_standard_argument("expression_x", Some("Independent variable"))
196 .build()
197 );
198
199 hash_map.insert(
200 RegrType::AvgY,
201 Documentation::builder(
202 DOC_SECTION_STATISTICAL,
203 "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
204
205 "regr_avgy(expression_y, expression_x)")
206 .with_standard_argument("expression_y", Some("Dependent variable"))
207 .with_standard_argument("expression_x", Some("Independent variable"))
208 .build()
209 );
210
211 hash_map.insert(
212 RegrType::SXX,
213 Documentation::builder(
214 DOC_SECTION_STATISTICAL,
215 "Computes the sum of squares of the independent variable.",
216 "regr_sxx(expression_y, expression_x)",
217 )
218 .with_standard_argument("expression_y", Some("Dependent variable"))
219 .with_standard_argument("expression_x", Some("Independent variable"))
220 .build(),
221 );
222
223 hash_map.insert(
224 RegrType::SYY,
225 Documentation::builder(
226 DOC_SECTION_STATISTICAL,
227 "Computes the sum of squares of the dependent variable.",
228 "regr_syy(expression_y, expression_x)",
229 )
230 .with_standard_argument("expression_y", Some("Dependent variable"))
231 .with_standard_argument("expression_x", Some("Independent variable"))
232 .build(),
233 );
234
235 hash_map.insert(
236 RegrType::SXY,
237 Documentation::builder(
238 DOC_SECTION_STATISTICAL,
239 "Computes the sum of products of paired data points.",
240 "regr_sxy(expression_y, expression_x)",
241 )
242 .with_standard_argument("expression_y", Some("Dependent variable"))
243 .with_standard_argument("expression_x", Some("Independent variable"))
244 .build(),
245 );
246 hash_map
247});
248fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
249 &DOCUMENTATION
250}
251
252impl AggregateUDFImpl for Regr {
253 fn as_any(&self) -> &dyn Any {
254 self
255 }
256
257 fn name(&self) -> &str {
258 self.func_name
259 }
260
261 fn signature(&self) -> &Signature {
262 &self.signature
263 }
264
265 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
266 if !arg_types[0].is_numeric() {
267 return plan_err!("Covariance requires numeric input types");
268 }
269
270 if matches!(self.regr_type, RegrType::Count) {
271 Ok(DataType::UInt64)
272 } else {
273 Ok(DataType::Float64)
274 }
275 }
276
277 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
278 Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
279 }
280
281 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
282 Ok(vec![
283 Field::new(
284 format_state_name(args.name, "count"),
285 DataType::UInt64,
286 true,
287 ),
288 Field::new(
289 format_state_name(args.name, "mean_x"),
290 DataType::Float64,
291 true,
292 ),
293 Field::new(
294 format_state_name(args.name, "mean_y"),
295 DataType::Float64,
296 true,
297 ),
298 Field::new(
299 format_state_name(args.name, "m2_x"),
300 DataType::Float64,
301 true,
302 ),
303 Field::new(
304 format_state_name(args.name, "m2_y"),
305 DataType::Float64,
306 true,
307 ),
308 Field::new(
309 format_state_name(args.name, "algo_const"),
310 DataType::Float64,
311 true,
312 ),
313 ])
314 }
315
316 fn documentation(&self) -> Option<&Documentation> {
317 self.regr_type.documentation()
318 }
319}
320
/// Running state behind every `regr_*` aggregate.
///
/// The count, both means, both sums of squared deviations and the co-moment
/// are maintained incrementally, so any [`RegrType`] can be evaluated from the
/// same state.
#[derive(Debug)]
pub struct RegrAccumulator {
    /// Number of pairs accumulated; rows with a null on either side are skipped.
    count: u64,
    /// Running mean of the independent variable `x`.
    mean_x: f64,
    /// Running mean of the dependent variable `y`.
    mean_y: f64,
    /// Running sum of squared deviations of `x` from its mean (reported as SXX).
    m2_x: f64,
    /// Running sum of squared deviations of `y` from its mean (reported as SYY).
    m2_y: f64,
    /// Running sum of products of the `x` and `y` deviations (reported as SXY;
    /// divided by `count` it gives the population covariance).
    algo_const: f64,
    /// Statistic returned by `evaluate`.
    regr_type: RegrType,
}
370
371impl RegrAccumulator {
372 pub fn try_new(regr_type: &RegrType) -> Result<Self> {
374 Ok(Self {
375 count: 0_u64,
376 mean_x: 0_f64,
377 mean_y: 0_f64,
378 m2_x: 0_f64,
379 m2_y: 0_f64,
380 algo_const: 0_f64,
381 regr_type: regr_type.clone(),
382 })
383 }
384}
385
386impl Accumulator for RegrAccumulator {
387 fn state(&mut self) -> Result<Vec<ScalarValue>> {
388 Ok(vec![
389 ScalarValue::from(self.count),
390 ScalarValue::from(self.mean_x),
391 ScalarValue::from(self.mean_y),
392 ScalarValue::from(self.m2_x),
393 ScalarValue::from(self.m2_y),
394 ScalarValue::from(self.algo_const),
395 ])
396 }
397
398 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
399 let values_y = &cast(&values[0], &DataType::Float64)?;
401 let values_x = &cast(&values[1], &DataType::Float64)?;
402
403 let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
404 let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
405
406 for i in 0..values_y.len() {
407 let value_y = if values_y.is_valid(i) {
409 arr_y.next()
410 } else {
411 None
412 };
413 let value_x = if values_x.is_valid(i) {
414 arr_x.next()
415 } else {
416 None
417 };
418 if value_y.is_none() || value_x.is_none() {
419 continue;
420 }
421
422 let value_y = unwrap_or_internal_err!(value_y);
424 let value_x = unwrap_or_internal_err!(value_x);
425
426 self.count += 1;
427 let delta_x = value_x - self.mean_x;
428 let delta_y = value_y - self.mean_y;
429 self.mean_x += delta_x / self.count as f64;
430 self.mean_y += delta_y / self.count as f64;
431 let delta_x_2 = value_x - self.mean_x;
432 let delta_y_2 = value_y - self.mean_y;
433 self.m2_x += delta_x * delta_x_2;
434 self.m2_y += delta_y * delta_y_2;
435 self.algo_const += delta_x * (value_y - self.mean_y);
436 }
437
438 Ok(())
439 }
440
441 fn supports_retract_batch(&self) -> bool {
442 true
443 }
444
445 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
446 let values_y = &cast(&values[0], &DataType::Float64)?;
447 let values_x = &cast(&values[1], &DataType::Float64)?;
448
449 let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
450 let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
451
452 for i in 0..values_y.len() {
453 let value_y = if values_y.is_valid(i) {
455 arr_y.next()
456 } else {
457 None
458 };
459 let value_x = if values_x.is_valid(i) {
460 arr_x.next()
461 } else {
462 None
463 };
464 if value_y.is_none() || value_x.is_none() {
465 continue;
466 }
467
468 let value_y = unwrap_or_internal_err!(value_y);
470 let value_x = unwrap_or_internal_err!(value_x);
471
472 if self.count > 1 {
473 self.count -= 1;
474 let delta_x = value_x - self.mean_x;
475 let delta_y = value_y - self.mean_y;
476 self.mean_x -= delta_x / self.count as f64;
477 self.mean_y -= delta_y / self.count as f64;
478 let delta_x_2 = value_x - self.mean_x;
479 let delta_y_2 = value_y - self.mean_y;
480 self.m2_x -= delta_x * delta_x_2;
481 self.m2_y -= delta_y * delta_y_2;
482 self.algo_const -= delta_x * (value_y - self.mean_y);
483 } else {
484 self.count = 0;
485 self.mean_x = 0.0;
486 self.m2_x = 0.0;
487 self.m2_y = 0.0;
488 self.mean_y = 0.0;
489 self.algo_const = 0.0;
490 }
491 }
492
493 Ok(())
494 }
495
496 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
497 let count_arr = downcast_value!(states[0], UInt64Array);
498 let mean_x_arr = downcast_value!(states[1], Float64Array);
499 let mean_y_arr = downcast_value!(states[2], Float64Array);
500 let m2_x_arr = downcast_value!(states[3], Float64Array);
501 let m2_y_arr = downcast_value!(states[4], Float64Array);
502 let algo_const_arr = downcast_value!(states[5], Float64Array);
503
504 for i in 0..count_arr.len() {
505 let count_b = count_arr.value(i);
506 if count_b == 0_u64 {
507 continue;
508 }
509 let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
510 self.count,
511 self.mean_x,
512 self.mean_y,
513 self.m2_x,
514 self.m2_y,
515 self.algo_const,
516 );
517 let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
518 count_b,
519 mean_x_arr.value(i),
520 mean_y_arr.value(i),
521 m2_x_arr.value(i),
522 m2_y_arr.value(i),
523 algo_const_arr.value(i),
524 );
525
526 let count_ab = count_a + count_b;
535 let (count_a, count_b) = (count_a as f64, count_b as f64);
536 let d_x = mean_x_b - mean_x_a;
537 let d_y = mean_y_b - mean_y_a;
538 let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
539 let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
540 let m2_x_ab =
541 m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
542 let m2_y_ab =
543 m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
544 let algo_const_ab = algo_const_a
545 + algo_const_b
546 + d_x * d_y * count_a * count_b / count_ab as f64;
547
548 self.count = count_ab;
549 self.mean_x = mean_x_ab;
550 self.mean_y = mean_y_ab;
551 self.m2_x = m2_x_ab;
552 self.m2_y = m2_y_ab;
553 self.algo_const = algo_const_ab;
554 }
555 Ok(())
556 }
557
558 fn evaluate(&mut self) -> Result<ScalarValue> {
559 let cov_pop_x_y = self.algo_const / self.count as f64;
560 let var_pop_x = self.m2_x / self.count as f64;
561 let var_pop_y = self.m2_y / self.count as f64;
562
563 let nullif_or_stat = |cond: bool, stat: f64| {
564 if cond {
565 Ok(ScalarValue::Float64(None))
566 } else {
567 Ok(ScalarValue::Float64(Some(stat)))
568 }
569 };
570
571 match self.regr_type {
572 RegrType::Slope => {
573 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
575 nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
576 }
577 RegrType::Intercept => {
578 let slope = cov_pop_x_y / var_pop_x;
579 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
581 nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
582 }
583 RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
584 RegrType::R2 => {
585 let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
587 nullif_or_stat(
588 nullif_cond,
589 (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
590 )
591 }
592 RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
593 RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
594 RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
595 RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
596 RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
597 }
598 }
599
600 fn size(&self) -> usize {
601 size_of_val(self)
602 }
603}