datafusion_functions_aggregate/
covariance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`CovarianceSample`] and [`CovariancePopulation`]: sample and population covariance aggregations.
19
20use arrow::datatypes::FieldRef;
21use arrow::{
22    array::{ArrayRef, Float64Array, UInt64Array},
23    compute::kernels::cast,
24    datatypes::{DataType, Field},
25};
26use datafusion_common::{
27    downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
28    ScalarValue,
29};
30use datafusion_expr::{
31    function::{AccumulatorArgs, StateFieldsArgs},
32    type_coercion::aggregates::NUMERICS,
33    utils::format_state_name,
34    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38use std::fmt::Debug;
39use std::mem::size_of_val;
40use std::sync::Arc;
41
// Wires `CovarianceSample` into the expression API via `make_udaf_expr_and_func!`
// (defined elsewhere). Judging by its arguments, this produces a `covar_samp`
// expression builder taking `y` and `x`, plus a `covar_samp_udaf` accessor —
// confirm against the macro definition.
make_udaf_expr_and_func!(
    CovarianceSample,
    covar_samp,
    y x,
    "Computes the sample covariance.",
    covar_samp_udaf
);

// Same wiring for `CovariancePopulation`: a `covar_pop` builder taking `y` and
// `x`, and a `covar_pop_udaf` accessor (see the macro definition).
make_udaf_expr_and_func!(
    CovariancePopulation,
    covar_pop,
    y x,
    "Computes the population covariance.",
    covar_pop_udaf
);
57
// `covar_samp` aggregate UDF: the sample covariance (co-moment divided by
// `n - 1`) of two numeric columns. Also registered under the alias `covar`.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the sample covariance of a set of number pairs.",
    syntax_example = "covar_samp(expression1, expression2)",
    sql_example = r#"```sql
> SELECT covar_samp(column1, column2) FROM table_name;
+-----------------------------------+
| covar_samp(column1, column2)      |
+-----------------------------------+
| 8.25                              |
+-----------------------------------+
```"#,
    standard_argument(name = "expression1", prefix = "First"),
    standard_argument(name = "expression2", prefix = "Second")
)]
pub struct CovarianceSample {
    // Accepts exactly two arguments, each drawn from `NUMERICS`.
    signature: Signature,
    // Alternative SQL names for this function.
    aliases: Vec<String>,
}
77
78impl Debug for CovarianceSample {
79    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
80        f.debug_struct("CovarianceSample")
81            .field("name", &self.name())
82            .field("signature", &self.signature)
83            .finish()
84    }
85}
86
87impl Default for CovarianceSample {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl CovarianceSample {
94    pub fn new() -> Self {
95        Self {
96            aliases: vec![String::from("covar")],
97            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
98        }
99    }
100}
101
102impl AggregateUDFImpl for CovarianceSample {
103    fn as_any(&self) -> &dyn std::any::Any {
104        self
105    }
106
107    fn name(&self) -> &str {
108        "covar_samp"
109    }
110
111    fn signature(&self) -> &Signature {
112        &self.signature
113    }
114
115    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
116        if !arg_types[0].is_numeric() {
117            return plan_err!("Covariance requires numeric input types");
118        }
119
120        Ok(DataType::Float64)
121    }
122
123    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
124        let name = args.name;
125        Ok(vec![
126            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
127            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
128            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
129            Field::new(
130                format_state_name(name, "algo_const"),
131                DataType::Float64,
132                true,
133            ),
134        ]
135        .into_iter()
136        .map(Arc::new)
137        .collect())
138    }
139
140    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
141        Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
142    }
143
144    fn aliases(&self) -> &[String] {
145        &self.aliases
146    }
147
148    fn documentation(&self) -> Option<&Documentation> {
149        self.doc()
150    }
151}
152
153#[user_doc(
154    doc_section(label = "Statistical Functions"),
155    description = "Returns the sample covariance of a set of number pairs.",
156    syntax_example = "covar_samp(expression1, expression2)",
157    sql_example = r#"```sql
158> SELECT covar_samp(column1, column2) FROM table_name;
159+-----------------------------------+
160| covar_samp(column1, column2)      |
161+-----------------------------------+
162| 8.25                              |
163+-----------------------------------+
164```"#,
165    standard_argument(name = "expression1", prefix = "First"),
166    standard_argument(name = "expression2", prefix = "Second")
167)]
168pub struct CovariancePopulation {
169    signature: Signature,
170}
171
172impl Debug for CovariancePopulation {
173    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
174        f.debug_struct("CovariancePopulation")
175            .field("name", &self.name())
176            .field("signature", &self.signature)
177            .finish()
178    }
179}
180
181impl Default for CovariancePopulation {
182    fn default() -> Self {
183        Self::new()
184    }
185}
186
187impl CovariancePopulation {
188    pub fn new() -> Self {
189        Self {
190            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
191        }
192    }
193}
194
195impl AggregateUDFImpl for CovariancePopulation {
196    fn as_any(&self) -> &dyn std::any::Any {
197        self
198    }
199
200    fn name(&self) -> &str {
201        "covar_pop"
202    }
203
204    fn signature(&self) -> &Signature {
205        &self.signature
206    }
207
208    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
209        if !arg_types[0].is_numeric() {
210            return plan_err!("Covariance requires numeric input types");
211        }
212
213        Ok(DataType::Float64)
214    }
215
216    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
217        let name = args.name;
218        Ok(vec![
219            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
220            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
221            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
222            Field::new(
223                format_state_name(name, "algo_const"),
224                DataType::Float64,
225                true,
226            ),
227        ]
228        .into_iter()
229        .map(Arc::new)
230        .collect())
231    }
232
233    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
234        Ok(Box::new(CovarianceAccumulator::try_new(
235            StatsType::Population,
236        )?))
237    }
238
239    fn documentation(&self) -> Option<&Documentation> {
240        self.doc()
241    }
242}
243
/// An accumulator that computes the sample or population covariance of two
/// numeric inputs.
///
/// The algorithm is an online, numerically stable extension of Welford's method
/// for computing variance, described in:
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
///
/// The algorithm has been analyzed here:
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
///
/// Covariance is not covered by the original paper, but the same idea applies.
/// The accumulator keeps running means and a running co-moment that are updated
/// one row at a time (online). Rows can be removed again (`retract_batch`), and
/// partial states computed in parallel can be combined (`merge_batch`).
#[derive(Debug)]
pub struct CovarianceAccumulator {
    /// Running co-moment: the sum over accumulated rows of
    /// `(x - mean1) * (y - mean2)`. `evaluate` divides it by `n` or `n - 1`.
    algo_const: f64,
    /// Running mean of the first input (`values[0]`).
    mean1: f64,
    /// Running mean of the second input (`values[1]`).
    mean2: f64,
    /// Number of rows reflected in the state (rows where either input is null
    /// are skipped by `update_batch`).
    count: u64,
    /// Selects the divisor used by `evaluate`: `n` (population) or `n - 1` (sample).
    stats_type: StatsType,
}
265
266impl CovarianceAccumulator {
267    /// Creates a new `CovarianceAccumulator`
268    pub fn try_new(s_type: StatsType) -> Result<Self> {
269        Ok(Self {
270            algo_const: 0_f64,
271            mean1: 0_f64,
272            mean2: 0_f64,
273            count: 0_u64,
274            stats_type: s_type,
275        })
276    }
277
278    pub fn get_count(&self) -> u64 {
279        self.count
280    }
281
282    pub fn get_mean1(&self) -> f64 {
283        self.mean1
284    }
285
286    pub fn get_mean2(&self) -> f64 {
287        self.mean2
288    }
289
290    pub fn get_algo_const(&self) -> f64 {
291        self.algo_const
292    }
293}
294
295impl Accumulator for CovarianceAccumulator {
296    fn state(&mut self) -> Result<Vec<ScalarValue>> {
297        Ok(vec![
298            ScalarValue::from(self.count),
299            ScalarValue::from(self.mean1),
300            ScalarValue::from(self.mean2),
301            ScalarValue::from(self.algo_const),
302        ])
303    }
304
305    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
306        let values1 = &cast(&values[0], &DataType::Float64)?;
307        let values2 = &cast(&values[1], &DataType::Float64)?;
308
309        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
310        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
311
312        for i in 0..values1.len() {
313            let value1 = if values1.is_valid(i) {
314                arr1.next()
315            } else {
316                None
317            };
318            let value2 = if values2.is_valid(i) {
319                arr2.next()
320            } else {
321                None
322            };
323
324            if value1.is_none() || value2.is_none() {
325                continue;
326            }
327
328            let value1 = unwrap_or_internal_err!(value1);
329            let value2 = unwrap_or_internal_err!(value2);
330            let new_count = self.count + 1;
331            let delta1 = value1 - self.mean1;
332            let new_mean1 = delta1 / new_count as f64 + self.mean1;
333            let delta2 = value2 - self.mean2;
334            let new_mean2 = delta2 / new_count as f64 + self.mean2;
335            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
336
337            self.count += 1;
338            self.mean1 = new_mean1;
339            self.mean2 = new_mean2;
340            self.algo_const = new_c;
341        }
342
343        Ok(())
344    }
345
346    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
347        let values1 = &cast(&values[0], &DataType::Float64)?;
348        let values2 = &cast(&values[1], &DataType::Float64)?;
349        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
350        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
351
352        for i in 0..values1.len() {
353            let value1 = if values1.is_valid(i) {
354                arr1.next()
355            } else {
356                None
357            };
358            let value2 = if values2.is_valid(i) {
359                arr2.next()
360            } else {
361                None
362            };
363
364            if value1.is_none() || value2.is_none() {
365                continue;
366            }
367
368            let value1 = unwrap_or_internal_err!(value1);
369            let value2 = unwrap_or_internal_err!(value2);
370
371            let new_count = self.count - 1;
372            let delta1 = self.mean1 - value1;
373            let new_mean1 = delta1 / new_count as f64 + self.mean1;
374            let delta2 = self.mean2 - value2;
375            let new_mean2 = delta2 / new_count as f64 + self.mean2;
376            let new_c = self.algo_const - delta1 * (new_mean2 - value2);
377
378            self.count -= 1;
379            self.mean1 = new_mean1;
380            self.mean2 = new_mean2;
381            self.algo_const = new_c;
382        }
383
384        Ok(())
385    }
386
387    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
388        let counts = downcast_value!(states[0], UInt64Array);
389        let means1 = downcast_value!(states[1], Float64Array);
390        let means2 = downcast_value!(states[2], Float64Array);
391        let cs = downcast_value!(states[3], Float64Array);
392
393        for i in 0..counts.len() {
394            let c = counts.value(i);
395            if c == 0_u64 {
396                continue;
397            }
398            let new_count = self.count + c;
399            let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
400                + means1.value(i) * c as f64 / new_count as f64;
401            let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
402                + means2.value(i) * c as f64 / new_count as f64;
403            let delta1 = self.mean1 - means1.value(i);
404            let delta2 = self.mean2 - means2.value(i);
405            let new_c = self.algo_const
406                + cs.value(i)
407                + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
408
409            self.count = new_count;
410            self.mean1 = new_mean1;
411            self.mean2 = new_mean2;
412            self.algo_const = new_c;
413        }
414        Ok(())
415    }
416
417    fn evaluate(&mut self) -> Result<ScalarValue> {
418        let count = match self.stats_type {
419            StatsType::Population => self.count,
420            StatsType::Sample => {
421                if self.count > 0 {
422                    self.count - 1
423                } else {
424                    self.count
425                }
426            }
427        };
428
429        if count == 0 {
430            Ok(ScalarValue::Float64(None))
431        } else {
432            Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
433        }
434    }
435
436    fn size(&self) -> usize {
437        size_of_val(self)
438    }
439}