datafusion_functions_aggregate/
regr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution
19
20use arrow::array::Float64Array;
21use arrow::{
22    array::{ArrayRef, UInt64Array},
23    compute::cast,
24    datatypes::DataType,
25    datatypes::Field,
26};
27use datafusion_common::{
28    downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result,
29    ScalarValue,
30};
31use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
32use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
33use datafusion_expr::type_coercion::aggregates::NUMERICS;
34use datafusion_expr::utils::format_state_name;
35use datafusion_expr::{
36    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
37};
38use std::any::Any;
39use std::fmt::Debug;
40use std::mem::size_of_val;
41use std::sync::LazyLock;
42
/// Generates the public API for one `regr_*` aggregate function.
///
/// * `$EXPR_FN` - identifier of the generated expression-builder function. Its
///   stringified form is also passed to [`Regr::new`] as `func_name`, which is
///   what `Regr::name` returns, so it doubles as the function's name.
/// * `$AGGREGATE_UDF_FN` - identifier handed to both helper macros; presumably
///   the accessor returning the UDAF instance (confirm in `create_func!`).
/// * `$REGR_TYPE` - the [`RegrType`] variant this aggregate computes.
///
/// The item generation itself is delegated to `make_udaf_expr!` and
/// `create_func!`, which are defined elsewhere.
macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        // `expr_y expr_x` presumably names the builder's two arguments (Y first,
        // then X) - NOTE(review): confirm against `make_udaf_expr!`.
        make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
        create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
    }
}
49
// One expression builder / UDAF pair per regression statistic. The first
// identifier of each invocation is stringified and passed to `Regr::new`, so it
// is also the name reported by `Regr::name` (e.g. "regr_slope").
make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
59
60pub struct Regr {
61    signature: Signature,
62    regr_type: RegrType,
63    func_name: &'static str,
64}
65
66impl Debug for Regr {
67    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
68        f.debug_struct("regr")
69            .field("name", &self.name())
70            .field("signature", &self.signature)
71            .finish()
72    }
73}
74
75impl Regr {
76    pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
77        Self {
78            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
79            regr_type,
80            func_name,
81        }
82    }
83}
84
/// Which statistic a [`Regr`] aggregate computes from the `(Y, X)` pairs of its
/// input in which both values are non-null.
///
/// `Count` yields a `UInt64`; every other variant yields a `Float64`.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
// Variant names mirror the SQL function suffixes (`regr_sxx`, `regr_syy`,
// `regr_sxy`), hence the upper-case acronyms.
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
    /// Variant for `regr_slope` aggregate expression.
    ///
    /// Returns the slope of the linear regression line for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_slope(Y, X)` returns the slope (k in Y = k*X + b) using minimal
    /// RSS (Residual Sum of Squares) fitting.
    Slope,
    /// Variant for `regr_intercept` aggregate expression.
    ///
    /// Returns the intercept of the linear regression line for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_intercept(Y, X)` returns the intercept (b in Y = k*X + b) using minimal
    /// RSS fitting.
    Intercept,
    /// Variant for `regr_count` aggregate expression.
    ///
    /// Returns the number of input rows for which both expressions are not null.
    /// Given input column Y and X: `regr_count(Y, X)` returns the count of non-null pairs.
    Count,
    /// Variant for `regr_r2` aggregate expression.
    ///
    /// Returns the coefficient of determination (R-squared value) of the linear regression line for non-null pairs in aggregate columns.
    /// The R-squared value represents the proportion of variance in Y that is predictable from X.
    R2,
    /// Variant for `regr_avgx` aggregate expression.
    ///
    /// Returns the average of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_avgx(Y, X)` returns the average of X values.
    AvgX,
    /// Variant for `regr_avgy` aggregate expression.
    ///
    /// Returns the average of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_avgy(Y, X)` returns the average of Y values.
    AvgY,
    /// Variant for `regr_sxx` aggregate expression.
    ///
    /// Returns the sum of squared deviations of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_sxx(Y, X)` returns the sum of squares of deviations of X from its mean.
    SXX,
    /// Variant for `regr_syy` aggregate expression.
    ///
    /// Returns the sum of squared deviations of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_syy(Y, X)` returns the sum of squares of deviations of Y from its mean.
    SYY,
    /// Variant for `regr_sxy` aggregate expression.
    ///
    /// Returns the sum of products of deviations for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_sxy(Y, X)` returns the sum of products of the deviations of Y and X from their respective means.
    SXY,
}
127
128impl RegrType {
129    /// return the documentation for the `RegrType`
130    fn documentation(&self) -> Option<&Documentation> {
131        get_regr_docs().get(self)
132    }
133}
134
135static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
136    let mut hash_map = HashMap::new();
137    hash_map.insert(
138            RegrType::Slope,
139            Documentation::builder(
140                DOC_SECTION_STATISTICAL,
141                    "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
142                    Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
143
144                "regr_slope(expression_y, expression_x)")
145                .with_standard_argument("expression_y", Some("Dependent variable"))
146                .with_standard_argument("expression_x", Some("Independent variable"))
147                .build()
148        );
149
150    hash_map.insert(
151            RegrType::Intercept,
152            Documentation::builder(
153                DOC_SECTION_STATISTICAL,
154                    "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
155                    this function returns b.",
156
157                "regr_intercept(expression_y, expression_x)")
158                .with_standard_argument("expression_y", Some("Dependent variable"))
159                .with_standard_argument("expression_x", Some("Independent variable"))
160                .build()
161        );
162
163    hash_map.insert(
164        RegrType::Count,
165        Documentation::builder(
166            DOC_SECTION_STATISTICAL,
167            "Counts the number of non-null paired data points.",
168            "regr_count(expression_y, expression_x)",
169        )
170        .with_standard_argument("expression_y", Some("Dependent variable"))
171        .with_standard_argument("expression_x", Some("Independent variable"))
172        .build(),
173    );
174
175    hash_map.insert(
176            RegrType::R2,
177            Documentation::builder(
178                DOC_SECTION_STATISTICAL,
179                    "Computes the square of the correlation coefficient between the independent and dependent variables.",
180
181                "regr_r2(expression_y, expression_x)")
182                .with_standard_argument("expression_y", Some("Dependent variable"))
183                .with_standard_argument("expression_x", Some("Independent variable"))
184                .build()
185        );
186
187    hash_map.insert(
188            RegrType::AvgX,
189            Documentation::builder(
190                DOC_SECTION_STATISTICAL,
191                    "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
192
193                "regr_avgx(expression_y, expression_x)")
194                .with_standard_argument("expression_y", Some("Dependent variable"))
195                .with_standard_argument("expression_x", Some("Independent variable"))
196                .build()
197        );
198
199    hash_map.insert(
200            RegrType::AvgY,
201            Documentation::builder(
202                DOC_SECTION_STATISTICAL,
203                    "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
204
205                "regr_avgy(expression_y, expression_x)")
206                .with_standard_argument("expression_y", Some("Dependent variable"))
207                .with_standard_argument("expression_x", Some("Independent variable"))
208                .build()
209        );
210
211    hash_map.insert(
212        RegrType::SXX,
213        Documentation::builder(
214            DOC_SECTION_STATISTICAL,
215            "Computes the sum of squares of the independent variable.",
216            "regr_sxx(expression_y, expression_x)",
217        )
218        .with_standard_argument("expression_y", Some("Dependent variable"))
219        .with_standard_argument("expression_x", Some("Independent variable"))
220        .build(),
221    );
222
223    hash_map.insert(
224        RegrType::SYY,
225        Documentation::builder(
226            DOC_SECTION_STATISTICAL,
227            "Computes the sum of squares of the dependent variable.",
228            "regr_syy(expression_y, expression_x)",
229        )
230        .with_standard_argument("expression_y", Some("Dependent variable"))
231        .with_standard_argument("expression_x", Some("Independent variable"))
232        .build(),
233    );
234
235    hash_map.insert(
236        RegrType::SXY,
237        Documentation::builder(
238            DOC_SECTION_STATISTICAL,
239            "Computes the sum of products of paired data points.",
240            "regr_sxy(expression_y, expression_x)",
241        )
242        .with_standard_argument("expression_y", Some("Dependent variable"))
243        .with_standard_argument("expression_x", Some("Independent variable"))
244        .build(),
245    );
246    hash_map
247});
248fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
249    &DOCUMENTATION
250}
251
252impl AggregateUDFImpl for Regr {
253    fn as_any(&self) -> &dyn Any {
254        self
255    }
256
257    fn name(&self) -> &str {
258        self.func_name
259    }
260
261    fn signature(&self) -> &Signature {
262        &self.signature
263    }
264
265    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
266        if !arg_types[0].is_numeric() {
267            return plan_err!("Covariance requires numeric input types");
268        }
269
270        if matches!(self.regr_type, RegrType::Count) {
271            Ok(DataType::UInt64)
272        } else {
273            Ok(DataType::Float64)
274        }
275    }
276
277    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
278        Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
279    }
280
281    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
282        Ok(vec![
283            Field::new(
284                format_state_name(args.name, "count"),
285                DataType::UInt64,
286                true,
287            ),
288            Field::new(
289                format_state_name(args.name, "mean_x"),
290                DataType::Float64,
291                true,
292            ),
293            Field::new(
294                format_state_name(args.name, "mean_y"),
295                DataType::Float64,
296                true,
297            ),
298            Field::new(
299                format_state_name(args.name, "m2_x"),
300                DataType::Float64,
301                true,
302            ),
303            Field::new(
304                format_state_name(args.name, "m2_y"),
305                DataType::Float64,
306                true,
307            ),
308            Field::new(
309                format_state_name(args.name, "algo_const"),
310                DataType::Float64,
311                true,
312            ),
313        ])
314    }
315
316    fn documentation(&self) -> Option<&Documentation> {
317        self.regr_type.documentation()
318    }
319}
320
/// `RegrAccumulator` is used to compute linear regression aggregate functions
/// by maintaining statistics needed to compute them in an online fashion.
///
/// This struct uses Welford's online algorithm for calculating variance and covariance:
/// <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>
///
/// Only rows in which both `x` and `y` are non-null contribute to the statistics.
///
/// Given the statistics, the following aggregate functions can be calculated:
///
/// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as:
///   cov_pop(x, y) / var_pop(x).
///   It represents the expected change in Y for a one-unit change in X.
///
/// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as:
///   mean_y - (regr_slope(y, x) * mean_x).
///   It represents the expected value of Y when X is 0.
///
/// - `regr_count(y, x)`: Count of the non-null (both x and y) input rows.
///
/// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as:
///   (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)).
///   It provides a measure of how well the model's predictions match the observed data.
///
/// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x.
///
/// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y.
///
/// - `regr_sxx(y, x)`: Sum of squared deviations of X from its mean, calculated as:
///   m2_x.
///
/// - `regr_syy(y, x)`: Sum of squared deviations of Y from its mean, calculated as:
///   m2_y.
///
/// - `regr_sxy(y, x)`: Sum of products of the deviations of X and Y from their
///   means, calculated as: algo_const.
///
/// Here's how the statistics maintained in this struct are calculated:
/// - `cov_pop(x, y)`: algo_const / count.
/// - `var_pop(x)`: m2_x / count.
/// - `var_pop(y)`: m2_y / count.
#[derive(Debug)]
pub struct RegrAccumulator {
    /// Number of rows accumulated in which both `x` and `y` were non-null.
    count: u64,
    /// Running mean of the accumulated `x` values.
    mean_x: f64,
    /// Running mean of the accumulated `y` values.
    mean_y: f64,
    /// Running sum of squared deviations of `x` from `mean_x`.
    m2_x: f64,
    /// Running sum of squared deviations of `y` from `mean_y`.
    m2_y: f64,
    /// Running co-moment: sum of `(x - mean_x) * (y - mean_y)`.
    algo_const: f64,
    /// Which statistic `evaluate` reports.
    regr_type: RegrType,
}
370
371impl RegrAccumulator {
372    /// Creates a new `RegrAccumulator`
373    pub fn try_new(regr_type: &RegrType) -> Result<Self> {
374        Ok(Self {
375            count: 0_u64,
376            mean_x: 0_f64,
377            mean_y: 0_f64,
378            m2_x: 0_f64,
379            m2_y: 0_f64,
380            algo_const: 0_f64,
381            regr_type: regr_type.clone(),
382        })
383    }
384}
385
impl Accumulator for RegrAccumulator {
    /// Serializes the running statistics for partial aggregation.
    ///
    /// The order (count, mean_x, mean_y, m2_x, m2_y, algo_const) is the layout
    /// that `merge_batch` reads back by index.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean_x),
            ScalarValue::from(self.mean_y),
            ScalarValue::from(self.m2_x),
            ScalarValue::from(self.m2_y),
            ScalarValue::from(self.algo_const),
        ])
    }

    /// Folds a batch of rows into the running statistics with Welford's update.
    ///
    /// `values[0]` holds Y (dependent) and `values[1]` holds X (independent).
    /// Both are cast to `Float64` first. Rows in which either value is NULL are
    /// skipped.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // regr_slope(Y, X) calculates k in y = k*x + b
        let values_y = &cast(&values[0], &DataType::Float64)?;
        let values_x = &cast(&values[1], &DataType::Float64)?;

        // `flatten()` drops NULLs, so each iterator yields only the valid values
        // of its column. `next()` is called below only for valid rows, which
        // keeps each iterator aligned with row `i`.
        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();

        for i in 0..values_y.len() {
            // skip rows where either x or y is NULL
            let value_y = if values_y.is_valid(i) {
                arr_y.next()
            } else {
                None
            };
            let value_x = if values_x.is_valid(i) {
                arr_x.next()
            } else {
                None
            };
            if value_y.is_none() || value_x.is_none() {
                continue;
            }

            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
            // Both values are `Some` here (checked above). The macro presumably
            // turns a violated invariant into an internal error rather than a panic.
            let value_y = unwrap_or_internal_err!(value_y);
            let value_x = unwrap_or_internal_err!(value_x);

            // Welford step: `delta_*` are taken against the old means (before the
            // means advance), `delta_*_2` against the updated means.
            self.count += 1;
            let delta_x = value_x - self.mean_x;
            let delta_y = value_y - self.mean_y;
            self.mean_x += delta_x / self.count as f64;
            self.mean_y += delta_y / self.count as f64;
            let delta_x_2 = value_x - self.mean_x;
            let delta_y_2 = value_y - self.mean_y;
            self.m2_x += delta_x * delta_x_2;
            self.m2_y += delta_y * delta_y_2;
            // Co-moment update: (x - old mean_x) * (y - new mean_y).
            self.algo_const += delta_x * (value_y - self.mean_y);
        }

        Ok(())
    }

    /// Signals that `retract_batch` can remove previously added rows.
    fn supports_retract_batch(&self) -> bool {
        true
    }

    /// Removes a batch of previously added rows from the running statistics.
    ///
    /// The input layout and NULL handling are the same as in `update_batch`.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values_y = &cast(&values[0], &DataType::Float64)?;
        let values_x = &cast(&values[1], &DataType::Float64)?;

        // Same alignment trick as in `update_batch`: the flattened iterators only
        // advance on valid rows.
        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();

        for i in 0..values_y.len() {
            // skip rows where either x or y is NULL
            let value_y = if values_y.is_valid(i) {
                arr_y.next()
            } else {
                None
            };
            let value_x = if values_x.is_valid(i) {
                arr_x.next()
            } else {
                None
            };
            if value_y.is_none() || value_x.is_none() {
                continue;
            }

            // Remove this pair's contribution to the states (inverse Welford step).
            let value_y = unwrap_or_internal_err!(value_y);
            let value_x = unwrap_or_internal_err!(value_x);

            if self.count > 1 {
                // `count` is decremented first, so dividing by it gives the mean
                // of the remaining rows: new_mean = old_mean - (v - old_mean) / (n - 1).
                self.count -= 1;
                let delta_x = value_x - self.mean_x;
                let delta_y = value_y - self.mean_y;
                self.mean_x -= delta_x / self.count as f64;
                self.mean_y -= delta_y / self.count as f64;
                let delta_x_2 = value_x - self.mean_x;
                let delta_y_2 = value_y - self.mean_y;
                self.m2_x -= delta_x * delta_x_2;
                self.m2_y -= delta_y * delta_y_2;
                self.algo_const -= delta_x * (value_y - self.mean_y);
            } else {
                // Removing the last (or a non-existent) row: reset to the empty
                // state instead of dividing by a zero count.
                self.count = 0;
                self.mean_x = 0.0;
                self.m2_x = 0.0;
                self.m2_y = 0.0;
                self.mean_y = 0.0;
                self.algo_const = 0.0;
            }
        }

        Ok(())
    }

    /// Combines partial states (as produced by `state`) into this accumulator.
    ///
    /// `states` columns are, in order: count (`UInt64`), then mean_x, mean_y,
    /// m2_x, m2_y, algo_const (`Float64`). Rows whose count is 0 carry no data
    /// and are skipped.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let count_arr = downcast_value!(states[0], UInt64Array);
        let mean_x_arr = downcast_value!(states[1], Float64Array);
        let mean_y_arr = downcast_value!(states[2], Float64Array);
        let m2_x_arr = downcast_value!(states[3], Float64Array);
        let m2_y_arr = downcast_value!(states[4], Float64Array);
        let algo_const_arr = downcast_value!(states[5], Float64Array);

        for i in 0..count_arr.len() {
            let count_b = count_arr.value(i);
            if count_b == 0_u64 {
                continue;
            }
            let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
                self.count,
                self.mean_x,
                self.mean_y,
                self.m2_x,
                self.m2_y,
                self.algo_const,
            );
            let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
                count_b,
                mean_x_arr.value(i),
                mean_y_arr.value(i),
                m2_x_arr.value(i),
                m2_y_arr.value(i),
                algo_const_arr.value(i),
            );

            // Assuming two different batches of input have calculated the states:
            // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a}
            // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b}
            // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab,
            // m2_y_ab, algo_const_ab}
            //
            // Reference for the algorithm to merge states:
            // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
            let count_ab = count_a + count_b;
            let (count_a, count_b) = (count_a as f64, count_b as f64);
            let d_x = mean_x_b - mean_x_a;
            let d_y = mean_y_b - mean_y_a;
            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
            let m2_x_ab =
                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
            let m2_y_ab =
                m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
            let algo_const_ab = algo_const_a
                + algo_const_b
                + d_x * d_y * count_a * count_b / count_ab as f64;

            self.count = count_ab;
            self.mean_x = mean_x_ab;
            self.mean_y = mean_y_ab;
            self.m2_x = m2_x_ab;
            self.m2_y = m2_y_ab;
            self.algo_const = algo_const_ab;
        }
        Ok(())
    }

    /// Produces the final value of the statistic selected by `regr_type`.
    ///
    /// `regr_count` always returns the row count as `UInt64` (0 when empty).
    /// Every other variant returns `Float64`, or NULL when the statistic is
    /// undefined:
    /// - slope / intercept / R2: fewer than two rows or a zero variance;
    /// - averages and sums of squares/products: no rows.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        // These divide by `count`. When it is 0 the results are NaN, but every
        // branch below that reads them returns NULL in that case.
        let cov_pop_x_y = self.algo_const / self.count as f64;
        let var_pop_x = self.m2_x / self.count as f64;
        let var_pop_y = self.m2_y / self.count as f64;

        // SQL NULLIF-style helper: NULL when `cond` holds, otherwise `stat`.
        let nullif_or_stat = |cond: bool, stat: f64| {
            if cond {
                Ok(ScalarValue::Float64(None))
            } else {
                Ok(ScalarValue::Float64(Some(stat)))
            }
        };

        match self.regr_type {
            RegrType::Slope => {
                // Only 0/1 point or slope is infinite
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
            }
            RegrType::Intercept => {
                let slope = cov_pop_x_y / var_pop_x;
                // Only 0/1 point or slope is infinite
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
            }
            RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
            RegrType::R2 => {
                // Only 0/1 point or all x(or y) is the same
                // NOTE(review): the SQL standard / PostgreSQL appear to return 1.0
                // (not NULL) when only var_pop(y) is zero; confirm NULL is intended.
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
                nullif_or_stat(
                    nullif_cond,
                    (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
                )
            }
            RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
            RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
            RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
            RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
            RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
        }
    }

    /// In-memory size of the accumulator. All fields are plain scalars (plus a
    /// fieldless enum), so the inline size is the whole footprint.
    fn size(&self) -> usize {
        size_of_val(self)
    }
}