datafusion_functions_aggregate/
covariance.rs1use arrow::datatypes::FieldRef;
21use arrow::{
22 array::{ArrayRef, Float64Array, UInt64Array},
23 compute::kernels::cast,
24 datatypes::{DataType, Field},
25};
26use datafusion_common::{
27 downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
28 ScalarValue,
29};
30use datafusion_expr::{
31 function::{AccumulatorArgs, StateFieldsArgs},
32 type_coercion::aggregates::NUMERICS,
33 utils::format_state_name,
34 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38use std::fmt::Debug;
39use std::mem::size_of_val;
40use std::sync::Arc;
41
42make_udaf_expr_and_func!(
43 CovarianceSample,
44 covar_samp,
45 y x,
46 "Computes the sample covariance.",
47 covar_samp_udaf
48);
49
50make_udaf_expr_and_func!(
51 CovariancePopulation,
52 covar_pop,
53 y x,
54 "Computes the population covariance.",
55 covar_pop_udaf
56);
57
58#[user_doc(
59 doc_section(label = "Statistical Functions"),
60 description = "Returns the sample covariance of a set of number pairs.",
61 syntax_example = "covar_samp(expression1, expression2)",
62 sql_example = r#"```sql
63> SELECT covar_samp(column1, column2) FROM table_name;
64+-----------------------------------+
65| covar_samp(column1, column2) |
66+-----------------------------------+
67| 8.25 |
68+-----------------------------------+
69```"#,
70 standard_argument(name = "expression1", prefix = "First"),
71 standard_argument(name = "expression2", prefix = "Second")
72)]
73pub struct CovarianceSample {
74 signature: Signature,
75 aliases: Vec<String>,
76}
77
78impl Debug for CovarianceSample {
79 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
80 f.debug_struct("CovarianceSample")
81 .field("name", &self.name())
82 .field("signature", &self.signature)
83 .finish()
84 }
85}
86
87impl Default for CovarianceSample {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl CovarianceSample {
94 pub fn new() -> Self {
95 Self {
96 aliases: vec![String::from("covar")],
97 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
98 }
99 }
100}
101
102impl AggregateUDFImpl for CovarianceSample {
103 fn as_any(&self) -> &dyn std::any::Any {
104 self
105 }
106
107 fn name(&self) -> &str {
108 "covar_samp"
109 }
110
111 fn signature(&self) -> &Signature {
112 &self.signature
113 }
114
115 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
116 if !arg_types[0].is_numeric() {
117 return plan_err!("Covariance requires numeric input types");
118 }
119
120 Ok(DataType::Float64)
121 }
122
123 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
124 let name = args.name;
125 Ok(vec![
126 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
127 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
128 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
129 Field::new(
130 format_state_name(name, "algo_const"),
131 DataType::Float64,
132 true,
133 ),
134 ]
135 .into_iter()
136 .map(Arc::new)
137 .collect())
138 }
139
140 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
141 Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
142 }
143
144 fn aliases(&self) -> &[String] {
145 &self.aliases
146 }
147
148 fn documentation(&self) -> Option<&Documentation> {
149 self.doc()
150 }
151}
152
153#[user_doc(
154 doc_section(label = "Statistical Functions"),
155 description = "Returns the sample covariance of a set of number pairs.",
156 syntax_example = "covar_samp(expression1, expression2)",
157 sql_example = r#"```sql
158> SELECT covar_samp(column1, column2) FROM table_name;
159+-----------------------------------+
160| covar_samp(column1, column2) |
161+-----------------------------------+
162| 8.25 |
163+-----------------------------------+
164```"#,
165 standard_argument(name = "expression1", prefix = "First"),
166 standard_argument(name = "expression2", prefix = "Second")
167)]
168pub struct CovariancePopulation {
169 signature: Signature,
170}
171
172impl Debug for CovariancePopulation {
173 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
174 f.debug_struct("CovariancePopulation")
175 .field("name", &self.name())
176 .field("signature", &self.signature)
177 .finish()
178 }
179}
180
181impl Default for CovariancePopulation {
182 fn default() -> Self {
183 Self::new()
184 }
185}
186
187impl CovariancePopulation {
188 pub fn new() -> Self {
189 Self {
190 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
191 }
192 }
193}
194
195impl AggregateUDFImpl for CovariancePopulation {
196 fn as_any(&self) -> &dyn std::any::Any {
197 self
198 }
199
200 fn name(&self) -> &str {
201 "covar_pop"
202 }
203
204 fn signature(&self) -> &Signature {
205 &self.signature
206 }
207
208 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
209 if !arg_types[0].is_numeric() {
210 return plan_err!("Covariance requires numeric input types");
211 }
212
213 Ok(DataType::Float64)
214 }
215
216 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
217 let name = args.name;
218 Ok(vec![
219 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
220 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
221 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
222 Field::new(
223 format_state_name(name, "algo_const"),
224 DataType::Float64,
225 true,
226 ),
227 ]
228 .into_iter()
229 .map(Arc::new)
230 .collect())
231 }
232
233 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
234 Ok(Box::new(CovarianceAccumulator::try_new(
235 StatsType::Population,
236 )?))
237 }
238
239 fn documentation(&self) -> Option<&Documentation> {
240 self.doc()
241 }
242}
243
244#[derive(Debug)]
258pub struct CovarianceAccumulator {
259 algo_const: f64,
260 mean1: f64,
261 mean2: f64,
262 count: u64,
263 stats_type: StatsType,
264}
265
266impl CovarianceAccumulator {
267 pub fn try_new(s_type: StatsType) -> Result<Self> {
269 Ok(Self {
270 algo_const: 0_f64,
271 mean1: 0_f64,
272 mean2: 0_f64,
273 count: 0_u64,
274 stats_type: s_type,
275 })
276 }
277
278 pub fn get_count(&self) -> u64 {
279 self.count
280 }
281
282 pub fn get_mean1(&self) -> f64 {
283 self.mean1
284 }
285
286 pub fn get_mean2(&self) -> f64 {
287 self.mean2
288 }
289
290 pub fn get_algo_const(&self) -> f64 {
291 self.algo_const
292 }
293}
294
295impl Accumulator for CovarianceAccumulator {
296 fn state(&mut self) -> Result<Vec<ScalarValue>> {
297 Ok(vec![
298 ScalarValue::from(self.count),
299 ScalarValue::from(self.mean1),
300 ScalarValue::from(self.mean2),
301 ScalarValue::from(self.algo_const),
302 ])
303 }
304
305 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
306 let values1 = &cast(&values[0], &DataType::Float64)?;
307 let values2 = &cast(&values[1], &DataType::Float64)?;
308
309 let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
310 let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
311
312 for i in 0..values1.len() {
313 let value1 = if values1.is_valid(i) {
314 arr1.next()
315 } else {
316 None
317 };
318 let value2 = if values2.is_valid(i) {
319 arr2.next()
320 } else {
321 None
322 };
323
324 if value1.is_none() || value2.is_none() {
325 continue;
326 }
327
328 let value1 = unwrap_or_internal_err!(value1);
329 let value2 = unwrap_or_internal_err!(value2);
330 let new_count = self.count + 1;
331 let delta1 = value1 - self.mean1;
332 let new_mean1 = delta1 / new_count as f64 + self.mean1;
333 let delta2 = value2 - self.mean2;
334 let new_mean2 = delta2 / new_count as f64 + self.mean2;
335 let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
336
337 self.count += 1;
338 self.mean1 = new_mean1;
339 self.mean2 = new_mean2;
340 self.algo_const = new_c;
341 }
342
343 Ok(())
344 }
345
346 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
347 let values1 = &cast(&values[0], &DataType::Float64)?;
348 let values2 = &cast(&values[1], &DataType::Float64)?;
349 let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
350 let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
351
352 for i in 0..values1.len() {
353 let value1 = if values1.is_valid(i) {
354 arr1.next()
355 } else {
356 None
357 };
358 let value2 = if values2.is_valid(i) {
359 arr2.next()
360 } else {
361 None
362 };
363
364 if value1.is_none() || value2.is_none() {
365 continue;
366 }
367
368 let value1 = unwrap_or_internal_err!(value1);
369 let value2 = unwrap_or_internal_err!(value2);
370
371 let new_count = self.count - 1;
372 let delta1 = self.mean1 - value1;
373 let new_mean1 = delta1 / new_count as f64 + self.mean1;
374 let delta2 = self.mean2 - value2;
375 let new_mean2 = delta2 / new_count as f64 + self.mean2;
376 let new_c = self.algo_const - delta1 * (new_mean2 - value2);
377
378 self.count -= 1;
379 self.mean1 = new_mean1;
380 self.mean2 = new_mean2;
381 self.algo_const = new_c;
382 }
383
384 Ok(())
385 }
386
387 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
388 let counts = downcast_value!(states[0], UInt64Array);
389 let means1 = downcast_value!(states[1], Float64Array);
390 let means2 = downcast_value!(states[2], Float64Array);
391 let cs = downcast_value!(states[3], Float64Array);
392
393 for i in 0..counts.len() {
394 let c = counts.value(i);
395 if c == 0_u64 {
396 continue;
397 }
398 let new_count = self.count + c;
399 let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
400 + means1.value(i) * c as f64 / new_count as f64;
401 let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
402 + means2.value(i) * c as f64 / new_count as f64;
403 let delta1 = self.mean1 - means1.value(i);
404 let delta2 = self.mean2 - means2.value(i);
405 let new_c = self.algo_const
406 + cs.value(i)
407 + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
408
409 self.count = new_count;
410 self.mean1 = new_mean1;
411 self.mean2 = new_mean2;
412 self.algo_const = new_c;
413 }
414 Ok(())
415 }
416
417 fn evaluate(&mut self) -> Result<ScalarValue> {
418 let count = match self.stats_type {
419 StatsType::Population => self.count,
420 StatsType::Sample => {
421 if self.count > 0 {
422 self.count - 1
423 } else {
424 self.count
425 }
426 }
427 };
428
429 if count == 0 {
430 Ok(ScalarValue::Float64(None))
431 } else {
432 Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
433 }
434 }
435
436 fn size(&self) -> usize {
437 size_of_val(self)
438 }
439}