datafusion_functions_aggregate/
covariance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`CovarianceSample`] and [`CovariancePopulation`]: sample and population covariance aggregations.
19
20use std::fmt::Debug;
21use std::mem::size_of_val;
22
23use arrow::{
24    array::{ArrayRef, Float64Array, UInt64Array},
25    compute::kernels::cast,
26    datatypes::{DataType, Field},
27};
28
29use datafusion_common::{
30    downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
31    ScalarValue,
32};
33use datafusion_expr::{
34    function::{AccumulatorArgs, StateFieldsArgs},
35    type_coercion::aggregates::NUMERICS,
36    utils::format_state_name,
37    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
38};
39use datafusion_functions_aggregate_common::stats::StatsType;
40use datafusion_macros::user_doc;
41
// Declares the `covar_samp(y, x)` expression helper and the `covar_samp_udaf`
// accessor for `CovarianceSample`.
// NOTE(review): the exact items generated depend on `make_udaf_expr_and_func!`,
// which is defined outside this file — confirm against the macro definition.
make_udaf_expr_and_func!(
    CovarianceSample,
    covar_samp,
    y x,
    "Computes the sample covariance.",
    covar_samp_udaf
);
49
// Declares the `covar_pop(y, x)` expression helper and the `covar_pop_udaf`
// accessor for `CovariancePopulation`.
// NOTE(review): the exact items generated depend on `make_udaf_expr_and_func!`,
// which is defined outside this file — confirm against the macro definition.
make_udaf_expr_and_func!(
    CovariancePopulation,
    covar_pop,
    y x,
    "Computes the population covariance.",
    covar_pop_udaf
);
57
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the sample covariance of a set of number pairs.",
    syntax_example = "covar_samp(expression1, expression2)",
    sql_example = r#"```sql
> SELECT covar_samp(column1, column2) FROM table_name;
+-----------------------------------+
| covar_samp(column1, column2)      |
+-----------------------------------+
| 8.25                              |
+-----------------------------------+
```"#,
    standard_argument(name = "expression1", prefix = "First"),
    standard_argument(name = "expression2", prefix = "Second")
)]
/// Aggregate UDF implementing `covar_samp`: the sample covariance of a set of
/// number pairs.
pub struct CovarianceSample {
    /// Accepted argument types: two arguments drawn from `NUMERICS`.
    signature: Signature,
    /// Alternative names for the function (`covar`).
    aliases: Vec<String>,
}
77
78impl Debug for CovarianceSample {
79    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
80        f.debug_struct("CovarianceSample")
81            .field("name", &self.name())
82            .field("signature", &self.signature)
83            .finish()
84    }
85}
86
87impl Default for CovarianceSample {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl CovarianceSample {
94    pub fn new() -> Self {
95        Self {
96            aliases: vec![String::from("covar")],
97            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
98        }
99    }
100}
101
102impl AggregateUDFImpl for CovarianceSample {
103    fn as_any(&self) -> &dyn std::any::Any {
104        self
105    }
106
107    fn name(&self) -> &str {
108        "covar_samp"
109    }
110
111    fn signature(&self) -> &Signature {
112        &self.signature
113    }
114
115    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
116        if !arg_types[0].is_numeric() {
117            return plan_err!("Covariance requires numeric input types");
118        }
119
120        Ok(DataType::Float64)
121    }
122
123    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
124        let name = args.name;
125        Ok(vec![
126            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
127            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
128            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
129            Field::new(
130                format_state_name(name, "algo_const"),
131                DataType::Float64,
132                true,
133            ),
134        ])
135    }
136
137    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
138        Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
139    }
140
141    fn aliases(&self) -> &[String] {
142        &self.aliases
143    }
144
145    fn documentation(&self) -> Option<&Documentation> {
146        self.doc()
147    }
148}
149
150#[user_doc(
151    doc_section(label = "Statistical Functions"),
152    description = "Returns the sample covariance of a set of number pairs.",
153    syntax_example = "covar_samp(expression1, expression2)",
154    sql_example = r#"```sql
155> SELECT covar_samp(column1, column2) FROM table_name;
156+-----------------------------------+
157| covar_samp(column1, column2)      |
158+-----------------------------------+
159| 8.25                              |
160+-----------------------------------+
161```"#,
162    standard_argument(name = "expression1", prefix = "First"),
163    standard_argument(name = "expression2", prefix = "Second")
164)]
165pub struct CovariancePopulation {
166    signature: Signature,
167}
168
169impl Debug for CovariancePopulation {
170    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
171        f.debug_struct("CovariancePopulation")
172            .field("name", &self.name())
173            .field("signature", &self.signature)
174            .finish()
175    }
176}
177
178impl Default for CovariancePopulation {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184impl CovariancePopulation {
185    pub fn new() -> Self {
186        Self {
187            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
188        }
189    }
190}
191
192impl AggregateUDFImpl for CovariancePopulation {
193    fn as_any(&self) -> &dyn std::any::Any {
194        self
195    }
196
197    fn name(&self) -> &str {
198        "covar_pop"
199    }
200
201    fn signature(&self) -> &Signature {
202        &self.signature
203    }
204
205    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
206        if !arg_types[0].is_numeric() {
207            return plan_err!("Covariance requires numeric input types");
208        }
209
210        Ok(DataType::Float64)
211    }
212
213    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
214        let name = args.name;
215        Ok(vec![
216            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
217            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
218            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
219            Field::new(
220                format_state_name(name, "algo_const"),
221                DataType::Float64,
222                true,
223            ),
224        ])
225    }
226
227    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
228        Ok(Box::new(CovarianceAccumulator::try_new(
229            StatsType::Population,
230        )?))
231    }
232
233    fn documentation(&self) -> Option<&Documentation> {
234        self.doc()
235    }
236}
237
/// An accumulator to compute covariance.
///
/// The algorithm used is an online implementation and is numerically stable. It is derived from
/// the following paper for calculating variance:
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
///
/// The algorithm has been analyzed here:
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
///
/// Covariance is not covered in the original paper, but the computation here is based on the same
/// idea; as a result the algorithm is online, parallelizable and numerically stable.
#[derive(Debug)]
pub struct CovarianceAccumulator {
    /// Running co-moment maintained by the update/retract/merge steps;
    /// `evaluate` divides it by the (population or sample) count.
    algo_const: f64,
    /// Running mean of the first input column.
    mean1: f64,
    /// Running mean of the second input column.
    mean2: f64,
    /// Number of rows accumulated so far in which both inputs were non-null.
    count: u64,
    /// Selects whether `evaluate` produces the sample or population statistic.
    stats_type: StatsType,
}
259
260impl CovarianceAccumulator {
261    /// Creates a new `CovarianceAccumulator`
262    pub fn try_new(s_type: StatsType) -> Result<Self> {
263        Ok(Self {
264            algo_const: 0_f64,
265            mean1: 0_f64,
266            mean2: 0_f64,
267            count: 0_u64,
268            stats_type: s_type,
269        })
270    }
271
272    pub fn get_count(&self) -> u64 {
273        self.count
274    }
275
276    pub fn get_mean1(&self) -> f64 {
277        self.mean1
278    }
279
280    pub fn get_mean2(&self) -> f64 {
281        self.mean2
282    }
283
284    pub fn get_algo_const(&self) -> f64 {
285        self.algo_const
286    }
287}
288
289impl Accumulator for CovarianceAccumulator {
290    fn state(&mut self) -> Result<Vec<ScalarValue>> {
291        Ok(vec![
292            ScalarValue::from(self.count),
293            ScalarValue::from(self.mean1),
294            ScalarValue::from(self.mean2),
295            ScalarValue::from(self.algo_const),
296        ])
297    }
298
299    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
300        let values1 = &cast(&values[0], &DataType::Float64)?;
301        let values2 = &cast(&values[1], &DataType::Float64)?;
302
303        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
304        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
305
306        for i in 0..values1.len() {
307            let value1 = if values1.is_valid(i) {
308                arr1.next()
309            } else {
310                None
311            };
312            let value2 = if values2.is_valid(i) {
313                arr2.next()
314            } else {
315                None
316            };
317
318            if value1.is_none() || value2.is_none() {
319                continue;
320            }
321
322            let value1 = unwrap_or_internal_err!(value1);
323            let value2 = unwrap_or_internal_err!(value2);
324            let new_count = self.count + 1;
325            let delta1 = value1 - self.mean1;
326            let new_mean1 = delta1 / new_count as f64 + self.mean1;
327            let delta2 = value2 - self.mean2;
328            let new_mean2 = delta2 / new_count as f64 + self.mean2;
329            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
330
331            self.count += 1;
332            self.mean1 = new_mean1;
333            self.mean2 = new_mean2;
334            self.algo_const = new_c;
335        }
336
337        Ok(())
338    }
339
340    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
341        let values1 = &cast(&values[0], &DataType::Float64)?;
342        let values2 = &cast(&values[1], &DataType::Float64)?;
343        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
344        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
345
346        for i in 0..values1.len() {
347            let value1 = if values1.is_valid(i) {
348                arr1.next()
349            } else {
350                None
351            };
352            let value2 = if values2.is_valid(i) {
353                arr2.next()
354            } else {
355                None
356            };
357
358            if value1.is_none() || value2.is_none() {
359                continue;
360            }
361
362            let value1 = unwrap_or_internal_err!(value1);
363            let value2 = unwrap_or_internal_err!(value2);
364
365            let new_count = self.count - 1;
366            let delta1 = self.mean1 - value1;
367            let new_mean1 = delta1 / new_count as f64 + self.mean1;
368            let delta2 = self.mean2 - value2;
369            let new_mean2 = delta2 / new_count as f64 + self.mean2;
370            let new_c = self.algo_const - delta1 * (new_mean2 - value2);
371
372            self.count -= 1;
373            self.mean1 = new_mean1;
374            self.mean2 = new_mean2;
375            self.algo_const = new_c;
376        }
377
378        Ok(())
379    }
380
381    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
382        let counts = downcast_value!(states[0], UInt64Array);
383        let means1 = downcast_value!(states[1], Float64Array);
384        let means2 = downcast_value!(states[2], Float64Array);
385        let cs = downcast_value!(states[3], Float64Array);
386
387        for i in 0..counts.len() {
388            let c = counts.value(i);
389            if c == 0_u64 {
390                continue;
391            }
392            let new_count = self.count + c;
393            let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
394                + means1.value(i) * c as f64 / new_count as f64;
395            let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
396                + means2.value(i) * c as f64 / new_count as f64;
397            let delta1 = self.mean1 - means1.value(i);
398            let delta2 = self.mean2 - means2.value(i);
399            let new_c = self.algo_const
400                + cs.value(i)
401                + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
402
403            self.count = new_count;
404            self.mean1 = new_mean1;
405            self.mean2 = new_mean2;
406            self.algo_const = new_c;
407        }
408        Ok(())
409    }
410
411    fn evaluate(&mut self) -> Result<ScalarValue> {
412        let count = match self.stats_type {
413            StatsType::Population => self.count,
414            StatsType::Sample => {
415                if self.count > 0 {
416                    self.count - 1
417                } else {
418                    self.count
419                }
420            }
421        };
422
423        if count == 0 {
424            Ok(ScalarValue::Float64(None))
425        } else {
426            Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
427        }
428    }
429
430    fn size(&self) -> usize {
431        size_of_val(self)
432    }
433}