datafusion_functions_aggregate/
correlation.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`Correlation`]: correlation sample aggregations.
19
20use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{
26    downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder,
27    UInt64Array,
28};
29use arrow::compute::{and, filter, is_not_null, kernels::cast};
30use arrow::datatypes::{Float64Type, UInt64Type};
31use arrow::{
32    array::ArrayRef,
33    datatypes::{DataType, Field},
34};
35use datafusion_expr::{EmitTo, GroupsAccumulator};
36use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
37use log::debug;
38
39use crate::covariance::CovarianceAccumulator;
40use crate::stddev::StddevAccumulator;
41use datafusion_common::{plan_err, Result, ScalarValue};
42use datafusion_expr::{
43    function::{AccumulatorArgs, StateFieldsArgs},
44    type_coercion::aggregates::NUMERICS,
45    utils::format_state_name,
46    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
47};
48use datafusion_functions_aggregate_common::stats::StatsType;
49use datafusion_macros::user_doc;
50
// Registers the `Correlation` UDAF through the crate's `make_udaf_expr_and_func!`
// macro. The arguments are, in order: the implementing type, the name `corr`,
// the argument names `y x`, a description string, and the name `corr_udaf`.
// NOTE(review): presumably this generates a `corr(y, x)` expression builder and a
// `corr_udaf()` accessor -- confirm against the macro definition, which is not
// part of this file.
make_udaf_expr_and_func!(
    Correlation,
    corr,
    y x,
    "Correlation between two numeric values.",
    corr_udaf
);
58
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the coefficient of correlation between two numeric values.",
    syntax_example = "corr(expression1, expression2)",
    sql_example = r#"```sql
> SELECT corr(column1, column2) FROM table_name;
+--------------------------------+
| corr(column1, column2)         |
+--------------------------------+
| 0.85                           |
+--------------------------------+
```"#,
    standard_argument(name = "expression1", prefix = "First"),
    standard_argument(name = "expression2", prefix = "Second")
)]
#[derive(Debug)]
/// `corr` aggregate function: the coefficient of correlation between two
/// numeric expressions.
///
/// The user-facing documentation (description, syntax and SQL example) is
/// supplied through the `user_doc` attribute above.
pub struct Correlation {
    /// Accepted argument signature; built in [`Correlation::new`] as two
    /// arguments drawn uniformly from `NUMERICS`.
    signature: Signature,
}
78
79impl Default for Correlation {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl Correlation {
86    /// Create a new COVAR_POP aggregate function
87    pub fn new() -> Self {
88        Self {
89            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
90        }
91    }
92}
93
94impl AggregateUDFImpl for Correlation {
95    /// Return a reference to Any that can be used for downcasting
96    fn as_any(&self) -> &dyn Any {
97        self
98    }
99
100    fn name(&self) -> &str {
101        "corr"
102    }
103
104    fn signature(&self) -> &Signature {
105        &self.signature
106    }
107
108    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
109        if !arg_types[0].is_numeric() {
110            return plan_err!("Correlation requires numeric input types");
111        }
112
113        Ok(DataType::Float64)
114    }
115
116    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
117        Ok(Box::new(CorrelationAccumulator::try_new()?))
118    }
119
120    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
121        let name = args.name;
122        Ok(vec![
123            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
124            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
125            Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
126            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
127            Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
128            Field::new(
129                format_state_name(name, "algo_const"),
130                DataType::Float64,
131                true,
132            ),
133        ])
134    }
135
136    fn documentation(&self) -> Option<&Documentation> {
137        self.doc()
138    }
139
140    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
141        true
142    }
143
144    fn create_groups_accumulator(
145        &self,
146        _args: AccumulatorArgs,
147    ) -> Result<Box<dyn GroupsAccumulator>> {
148        debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`");
149        Ok(Box::new(CorrelationGroupsAccumulator::new()))
150    }
151}
152
/// An accumulator to compute correlation
///
/// It delegates to one covariance accumulator and two standard-deviation
/// accumulators; the result is the covariance divided by both standard
/// deviations.
#[derive(Debug)]
pub struct CorrelationAccumulator {
    /// Covariance of the two input columns.
    covar: CovarianceAccumulator,
    /// Standard deviation of the first input column.
    stddev1: StddevAccumulator,
    /// Standard deviation of the second input column.
    stddev2: StddevAccumulator,
}
160
161impl CorrelationAccumulator {
162    /// Creates a new `CorrelationAccumulator`
163    pub fn try_new() -> Result<Self> {
164        Ok(Self {
165            covar: CovarianceAccumulator::try_new(StatsType::Population)?,
166            stddev1: StddevAccumulator::try_new(StatsType::Population)?,
167            stddev2: StddevAccumulator::try_new(StatsType::Population)?,
168        })
169    }
170}
171
172impl Accumulator for CorrelationAccumulator {
173    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
174        // TODO: null input skipping logic duplicated across Correlation
175        // and its children accumulators.
176        // This could be simplified by splitting up input filtering and
177        // calculation logic in children accumulators, and calling only
178        // calculation part from Correlation
179        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
180            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
181            let values1 = filter(&values[0], &mask)?;
182            let values2 = filter(&values[1], &mask)?;
183
184            vec![values1, values2]
185        } else {
186            values.to_vec()
187        };
188
189        self.covar.update_batch(&values)?;
190        self.stddev1.update_batch(&values[0..1])?;
191        self.stddev2.update_batch(&values[1..2])?;
192        Ok(())
193    }
194
195    fn evaluate(&mut self) -> Result<ScalarValue> {
196        let covar = self.covar.evaluate()?;
197        let stddev1 = self.stddev1.evaluate()?;
198        let stddev2 = self.stddev2.evaluate()?;
199
200        if let ScalarValue::Float64(Some(c)) = covar {
201            if let ScalarValue::Float64(Some(s1)) = stddev1 {
202                if let ScalarValue::Float64(Some(s2)) = stddev2 {
203                    if s1 == 0_f64 || s2 == 0_f64 {
204                        return Ok(ScalarValue::Float64(Some(0_f64)));
205                    } else {
206                        return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
207                    }
208                }
209            }
210        }
211
212        Ok(ScalarValue::Float64(None))
213    }
214
215    fn size(&self) -> usize {
216        size_of_val(self) - size_of_val(&self.covar) + self.covar.size()
217            - size_of_val(&self.stddev1)
218            + self.stddev1.size()
219            - size_of_val(&self.stddev2)
220            + self.stddev2.size()
221    }
222
223    fn state(&mut self) -> Result<Vec<ScalarValue>> {
224        Ok(vec![
225            ScalarValue::from(self.covar.get_count()),
226            ScalarValue::from(self.covar.get_mean1()),
227            ScalarValue::from(self.stddev1.get_m2()),
228            ScalarValue::from(self.covar.get_mean2()),
229            ScalarValue::from(self.stddev2.get_m2()),
230            ScalarValue::from(self.covar.get_algo_const()),
231        ])
232    }
233
234    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
235        let states_c = [
236            Arc::clone(&states[0]),
237            Arc::clone(&states[1]),
238            Arc::clone(&states[3]),
239            Arc::clone(&states[5]),
240        ];
241        let states_s1 = [
242            Arc::clone(&states[0]),
243            Arc::clone(&states[1]),
244            Arc::clone(&states[2]),
245        ];
246        let states_s2 = [
247            Arc::clone(&states[0]),
248            Arc::clone(&states[3]),
249            Arc::clone(&states[4]),
250        ];
251
252        self.covar.merge_batch(&states_c)?;
253        self.stddev1.merge_batch(&states_s1)?;
254        self.stddev2.merge_batch(&states_s2)?;
255        Ok(())
256    }
257
258    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
259        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
260            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
261            let values1 = filter(&values[0], &mask)?;
262            let values2 = filter(&values[1], &mask)?;
263
264            vec![values1, values2]
265        } else {
266            values.to_vec()
267        };
268
269        self.covar.retract_batch(&values)?;
270        self.stddev1.retract_batch(&values[0..1])?;
271        self.stddev2.retract_batch(&values[1..2])?;
272        Ok(())
273    }
274}
275
/// Group-wise state for `corr`: one running count and five running sums per
/// group, each stored in a vector indexed by group index.
#[derive(Default)]
pub struct CorrelationGroupsAccumulator {
    // Number of elements for each group
    // This is also used to track nulls: if a group has 0 valid values accumulated,
    // final aggregation result will be null.
    count: Vec<u64>,
    // Sum of x values for each group
    sum_x: Vec<f64>,
    // Sum of y values for each group
    sum_y: Vec<f64>,
    // Sum of x*y for each group
    sum_xy: Vec<f64>,
    // Sum of x^2 for each group
    sum_xx: Vec<f64>,
    // Sum of y^2 for each group
    sum_yy: Vec<f64>,
}
293
294impl CorrelationGroupsAccumulator {
295    pub fn new() -> Self {
296        Default::default()
297    }
298}
299
300/// Specialized version of `accumulate_multiple` for correlation's merge_batch
301///
302/// Note: Arrays in `state_arrays` should not have null values, because they are all
303/// intermediate states created within the accumulator, instead of inputs from
304/// outside.
305fn accumulate_correlation_states(
306    group_indices: &[usize],
307    state_arrays: (
308        &UInt64Array,  // count
309        &Float64Array, // sum_x
310        &Float64Array, // sum_y
311        &Float64Array, // sum_xy
312        &Float64Array, // sum_xx
313        &Float64Array, // sum_yy
314    ),
315    mut value_fn: impl FnMut(usize, u64, &[f64]),
316) {
317    let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
318
319    assert_eq!(counts.null_count(), 0);
320    assert_eq!(sum_x.null_count(), 0);
321    assert_eq!(sum_y.null_count(), 0);
322    assert_eq!(sum_xy.null_count(), 0);
323    assert_eq!(sum_xx.null_count(), 0);
324    assert_eq!(sum_yy.null_count(), 0);
325
326    let counts_values = counts.values().as_ref();
327    let sum_x_values = sum_x.values().as_ref();
328    let sum_y_values = sum_y.values().as_ref();
329    let sum_xy_values = sum_xy.values().as_ref();
330    let sum_xx_values = sum_xx.values().as_ref();
331    let sum_yy_values = sum_yy.values().as_ref();
332
333    for (idx, &group_idx) in group_indices.iter().enumerate() {
334        let row = [
335            sum_x_values[idx],
336            sum_y_values[idx],
337            sum_xy_values[idx],
338            sum_xx_values[idx],
339            sum_yy_values[idx],
340        ];
341        value_fn(group_idx, counts_values[idx], &row);
342    }
343}
344
345/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson correlation coefficient
346/// between two numeric columns.
347///
348/// Online algorithm for correlation:
349///
350/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2))
351/// where:
352/// n = number of observations
353/// sum_x = sum of x values
354/// sum_y = sum of y values  
355/// sum_xy = sum of (x * y)
356/// sum_xx = sum of x^2 values
357/// sum_yy = sum of y^2 values
358///
359/// Reference: <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#For_a_sample>
360impl GroupsAccumulator for CorrelationGroupsAccumulator {
361    fn update_batch(
362        &mut self,
363        values: &[ArrayRef],
364        group_indices: &[usize],
365        opt_filter: Option<&BooleanArray>,
366        total_num_groups: usize,
367    ) -> Result<()> {
368        self.count.resize(total_num_groups, 0);
369        self.sum_x.resize(total_num_groups, 0.0);
370        self.sum_y.resize(total_num_groups, 0.0);
371        self.sum_xy.resize(total_num_groups, 0.0);
372        self.sum_xx.resize(total_num_groups, 0.0);
373        self.sum_yy.resize(total_num_groups, 0.0);
374
375        let array_x = &cast(&values[0], &DataType::Float64)?;
376        let array_x = downcast_array::<Float64Array>(array_x);
377        let array_y = &cast(&values[1], &DataType::Float64)?;
378        let array_y = downcast_array::<Float64Array>(array_y);
379
380        accumulate_multiple(
381            group_indices,
382            &[&array_x, &array_y],
383            opt_filter,
384            |group_index, batch_index, columns| {
385                let x = columns[0].value(batch_index);
386                let y = columns[1].value(batch_index);
387                self.count[group_index] += 1;
388                self.sum_x[group_index] += x;
389                self.sum_y[group_index] += y;
390                self.sum_xy[group_index] += x * y;
391                self.sum_xx[group_index] += x * x;
392                self.sum_yy[group_index] += y * y;
393            },
394        );
395
396        Ok(())
397    }
398
399    fn merge_batch(
400        &mut self,
401        values: &[ArrayRef],
402        group_indices: &[usize],
403        opt_filter: Option<&BooleanArray>,
404        total_num_groups: usize,
405    ) -> Result<()> {
406        // Resize vectors to accommodate total number of groups
407        self.count.resize(total_num_groups, 0);
408        self.sum_x.resize(total_num_groups, 0.0);
409        self.sum_y.resize(total_num_groups, 0.0);
410        self.sum_xy.resize(total_num_groups, 0.0);
411        self.sum_xx.resize(total_num_groups, 0.0);
412        self.sum_yy.resize(total_num_groups, 0.0);
413
414        // Extract arrays from input values
415        let partial_counts = values[0].as_primitive::<UInt64Type>();
416        let partial_sum_x = values[1].as_primitive::<Float64Type>();
417        let partial_sum_y = values[2].as_primitive::<Float64Type>();
418        let partial_sum_xy = values[3].as_primitive::<Float64Type>();
419        let partial_sum_xx = values[4].as_primitive::<Float64Type>();
420        let partial_sum_yy = values[5].as_primitive::<Float64Type>();
421
422        assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage");
423
424        accumulate_correlation_states(
425            group_indices,
426            (
427                partial_counts,
428                partial_sum_x,
429                partial_sum_y,
430                partial_sum_xy,
431                partial_sum_xx,
432                partial_sum_yy,
433            ),
434            |group_index, count, values| {
435                self.count[group_index] += count;
436                self.sum_x[group_index] += values[0];
437                self.sum_y[group_index] += values[1];
438                self.sum_xy[group_index] += values[2];
439                self.sum_xx[group_index] += values[3];
440                self.sum_yy[group_index] += values[4];
441            },
442        );
443
444        Ok(())
445    }
446
447    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
448        let n = match emit_to {
449            EmitTo::All => self.count.len(),
450            EmitTo::First(n) => n,
451        };
452
453        let mut values = Vec::with_capacity(n);
454        let mut nulls = NullBufferBuilder::new(n);
455
456        // Notes for `Null` handling:
457        // - If the `count` state of a group is 0, no valid records are accumulated
458        //   for this group, so the aggregation result is `Null`.
459        // - Correlation can't be calculated when a group only has 1 record, or when
460        //   the `denominator` state is 0. In these cases, the final aggregation
461        //   result should be `Null` (according to PostgreSQL's behavior).
462        //
463        // TODO: Old datafusion implementation returns 0.0 for these invalid cases.
464        // Update this to match PostgreSQL's behavior.
465        for i in 0..n {
466            if self.count[i] < 2 {
467                // TODO: Evaluate as `Null` (see notes above)
468                values.push(0.0);
469                nulls.append_null();
470                continue;
471            }
472
473            let count = self.count[i];
474            let sum_x = self.sum_x[i];
475            let sum_y = self.sum_y[i];
476            let sum_xy = self.sum_xy[i];
477            let sum_xx = self.sum_xx[i];
478            let sum_yy = self.sum_yy[i];
479
480            let mean_x = sum_x / count as f64;
481            let mean_y = sum_y / count as f64;
482
483            let numerator = sum_xy - sum_x * mean_y;
484            let denominator =
485                ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
486
487            if denominator == 0.0 {
488                // TODO: Evaluate as `Null` (see notes above)
489                values.push(0.0);
490                nulls.append_null();
491            } else {
492                values.push(numerator / denominator);
493                nulls.append_non_null();
494            }
495        }
496
497        Ok(Arc::new(Float64Array::new(values.into(), nulls.finish())))
498    }
499
500    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
501        let n = match emit_to {
502            EmitTo::All => self.count.len(),
503            EmitTo::First(n) => n,
504        };
505
506        Ok(vec![
507            Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
508            Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
509            Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
510            Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
511            Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
512            Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
513        ])
514    }
515
516    fn size(&self) -> usize {
517        size_of_val(&self.count)
518            + size_of_val(&self.sum_x)
519            + size_of_val(&self.sum_y)
520            + size_of_val(&self.sum_xy)
521            + size_of_val(&self.sum_xx)
522            + size_of_val(&self.sum_yy)
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, UInt64Array};

    /// Checks that `accumulate_correlation_states` visits every row in order
    /// with the matching group index, count and sums, and that it panics when
    /// a state array contains a null.
    #[test]
    fn test_accumulate_correlation_states() {
        // Test data
        // Four partial-state rows that alternate between group 0 and group 1.
        let group_indices = vec![0, 1, 0, 1];
        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        // Record every callback invocation so the order can be asserted.
        let mut accumulated = vec![];
        accumulate_correlation_states(
            &group_indices,
            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
            |group_idx, count, values| {
                accumulated.push((group_idx, count, values.to_vec()));
            },
        );

        // One entry per input row: (group index, count,
        // [sum_x, sum_y, sum_xy, sum_xx, sum_yy]).
        let expected = vec![
            (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
            (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
            (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
            (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
        ];
        assert_eq!(accumulated, expected);

        // Test that function panics with null values
        // Only `counts` carries a null (second row); the other arrays are valid.
        let counts = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        // The null-count assertion inside the function is expected to fire.
        let result = std::panic::catch_unwind(|| {
            accumulate_correlation_states(
                &group_indices,
                (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
                |_, _, _| {},
            )
        });
        assert!(result.is_err());
    }
}