datafusion_functions_aggregate/
covariance.rs1use std::fmt::Debug;
21use std::mem::size_of_val;
22
23use arrow::{
24 array::{ArrayRef, Float64Array, UInt64Array},
25 compute::kernels::cast,
26 datatypes::{DataType, Field},
27};
28
29use datafusion_common::{
30 downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
31 ScalarValue,
32};
33use datafusion_expr::{
34 function::{AccumulatorArgs, StateFieldsArgs},
35 type_coercion::aggregates::NUMERICS,
36 utils::format_state_name,
37 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
38};
39use datafusion_functions_aggregate_common::stats::StatsType;
40use datafusion_macros::user_doc;
41
// Invokes `make_udaf_expr_and_func!` for `CovarianceSample`.
// NOTE(review): the macro is defined outside this file; it presumably generates
// a `covar_samp(y, x)` expression builder (with the string below as its doc
// text) and a `covar_samp_udaf` accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    CovarianceSample,
    covar_samp,
    y x,
    "Computes the sample covariance.",
    covar_samp_udaf
);
49
// Invokes `make_udaf_expr_and_func!` for `CovariancePopulation`.
// NOTE(review): the macro is defined outside this file; it presumably generates
// a `covar_pop(y, x)` expression builder (with the string below as its doc
// text) and a `covar_pop_udaf` accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    CovariancePopulation,
    covar_pop,
    y x,
    "Computes the population covariance.",
    covar_pop_udaf
);
57
/// Aggregate UDF implementation registered under the name `covar_samp`
/// (alias `covar`).
///
/// It accepts two numeric arguments and produces a `Float64` result computed
/// by a [`CovarianceAccumulator`] configured with `StatsType::Sample`.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the sample covariance of a set of number pairs.",
    syntax_example = "covar_samp(expression1, expression2)",
    sql_example = r#"```sql
> SELECT covar_samp(column1, column2) FROM table_name;
+-----------------------------------+
| covar_samp(column1, column2) |
+-----------------------------------+
| 8.25 |
+-----------------------------------+
```"#,
    standard_argument(name = "expression1", prefix = "First"),
    standard_argument(name = "expression2", prefix = "Second")
)]
pub struct CovarianceSample {
    // Accepted argument types: exactly two arguments drawn from `NUMERICS`.
    signature: Signature,
    // Alternative names under which the function can be invoked.
    aliases: Vec<String>,
}
77
78impl Debug for CovarianceSample {
79 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
80 f.debug_struct("CovarianceSample")
81 .field("name", &self.name())
82 .field("signature", &self.signature)
83 .finish()
84 }
85}
86
87impl Default for CovarianceSample {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl CovarianceSample {
94 pub fn new() -> Self {
95 Self {
96 aliases: vec![String::from("covar")],
97 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
98 }
99 }
100}
101
102impl AggregateUDFImpl for CovarianceSample {
103 fn as_any(&self) -> &dyn std::any::Any {
104 self
105 }
106
107 fn name(&self) -> &str {
108 "covar_samp"
109 }
110
111 fn signature(&self) -> &Signature {
112 &self.signature
113 }
114
115 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
116 if !arg_types[0].is_numeric() {
117 return plan_err!("Covariance requires numeric input types");
118 }
119
120 Ok(DataType::Float64)
121 }
122
123 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
124 let name = args.name;
125 Ok(vec![
126 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
127 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
128 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
129 Field::new(
130 format_state_name(name, "algo_const"),
131 DataType::Float64,
132 true,
133 ),
134 ])
135 }
136
137 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
138 Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
139 }
140
141 fn aliases(&self) -> &[String] {
142 &self.aliases
143 }
144
145 fn documentation(&self) -> Option<&Documentation> {
146 self.doc()
147 }
148}
149
150#[user_doc(
151 doc_section(label = "Statistical Functions"),
152 description = "Returns the sample covariance of a set of number pairs.",
153 syntax_example = "covar_samp(expression1, expression2)",
154 sql_example = r#"```sql
155> SELECT covar_samp(column1, column2) FROM table_name;
156+-----------------------------------+
157| covar_samp(column1, column2) |
158+-----------------------------------+
159| 8.25 |
160+-----------------------------------+
161```"#,
162 standard_argument(name = "expression1", prefix = "First"),
163 standard_argument(name = "expression2", prefix = "Second")
164)]
165pub struct CovariancePopulation {
166 signature: Signature,
167}
168
169impl Debug for CovariancePopulation {
170 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
171 f.debug_struct("CovariancePopulation")
172 .field("name", &self.name())
173 .field("signature", &self.signature)
174 .finish()
175 }
176}
177
178impl Default for CovariancePopulation {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
184impl CovariancePopulation {
185 pub fn new() -> Self {
186 Self {
187 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
188 }
189 }
190}
191
192impl AggregateUDFImpl for CovariancePopulation {
193 fn as_any(&self) -> &dyn std::any::Any {
194 self
195 }
196
197 fn name(&self) -> &str {
198 "covar_pop"
199 }
200
201 fn signature(&self) -> &Signature {
202 &self.signature
203 }
204
205 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
206 if !arg_types[0].is_numeric() {
207 return plan_err!("Covariance requires numeric input types");
208 }
209
210 Ok(DataType::Float64)
211 }
212
213 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
214 let name = args.name;
215 Ok(vec![
216 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
217 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
218 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
219 Field::new(
220 format_state_name(name, "algo_const"),
221 DataType::Float64,
222 true,
223 ),
224 ])
225 }
226
227 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
228 Ok(Box::new(CovarianceAccumulator::try_new(
229 StatsType::Population,
230 )?))
231 }
232
233 fn documentation(&self) -> Option<&Documentation> {
234 self.doc()
235 }
236}
237
/// Running state used to compute the covariance of two numeric columns.
///
/// The state is updated one `(value1, value2)` pair at a time; rows where
/// either value is null are skipped. `evaluate` divides `algo_const` by
/// `count` (population) or `count - 1` (sample).
#[derive(Debug)]
pub struct CovarianceAccumulator {
    // Accumulated sum of `delta1 * (value2 - new_mean2)` over all pairs seen.
    algo_const: f64,
    // Running mean of the first input column.
    mean1: f64,
    // Running mean of the second input column.
    mean2: f64,
    // Number of non-null pairs currently accumulated.
    count: u64,
    // Selects the divisor used by `evaluate` (sample vs. population).
    stats_type: StatsType,
}
259
260impl CovarianceAccumulator {
261 pub fn try_new(s_type: StatsType) -> Result<Self> {
263 Ok(Self {
264 algo_const: 0_f64,
265 mean1: 0_f64,
266 mean2: 0_f64,
267 count: 0_u64,
268 stats_type: s_type,
269 })
270 }
271
272 pub fn get_count(&self) -> u64 {
273 self.count
274 }
275
276 pub fn get_mean1(&self) -> f64 {
277 self.mean1
278 }
279
280 pub fn get_mean2(&self) -> f64 {
281 self.mean2
282 }
283
284 pub fn get_algo_const(&self) -> f64 {
285 self.algo_const
286 }
287}
288
289impl Accumulator for CovarianceAccumulator {
290 fn state(&mut self) -> Result<Vec<ScalarValue>> {
291 Ok(vec![
292 ScalarValue::from(self.count),
293 ScalarValue::from(self.mean1),
294 ScalarValue::from(self.mean2),
295 ScalarValue::from(self.algo_const),
296 ])
297 }
298
299 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
300 let values1 = &cast(&values[0], &DataType::Float64)?;
301 let values2 = &cast(&values[1], &DataType::Float64)?;
302
303 let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
304 let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
305
306 for i in 0..values1.len() {
307 let value1 = if values1.is_valid(i) {
308 arr1.next()
309 } else {
310 None
311 };
312 let value2 = if values2.is_valid(i) {
313 arr2.next()
314 } else {
315 None
316 };
317
318 if value1.is_none() || value2.is_none() {
319 continue;
320 }
321
322 let value1 = unwrap_or_internal_err!(value1);
323 let value2 = unwrap_or_internal_err!(value2);
324 let new_count = self.count + 1;
325 let delta1 = value1 - self.mean1;
326 let new_mean1 = delta1 / new_count as f64 + self.mean1;
327 let delta2 = value2 - self.mean2;
328 let new_mean2 = delta2 / new_count as f64 + self.mean2;
329 let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
330
331 self.count += 1;
332 self.mean1 = new_mean1;
333 self.mean2 = new_mean2;
334 self.algo_const = new_c;
335 }
336
337 Ok(())
338 }
339
340 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
341 let values1 = &cast(&values[0], &DataType::Float64)?;
342 let values2 = &cast(&values[1], &DataType::Float64)?;
343 let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
344 let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
345
346 for i in 0..values1.len() {
347 let value1 = if values1.is_valid(i) {
348 arr1.next()
349 } else {
350 None
351 };
352 let value2 = if values2.is_valid(i) {
353 arr2.next()
354 } else {
355 None
356 };
357
358 if value1.is_none() || value2.is_none() {
359 continue;
360 }
361
362 let value1 = unwrap_or_internal_err!(value1);
363 let value2 = unwrap_or_internal_err!(value2);
364
365 let new_count = self.count - 1;
366 let delta1 = self.mean1 - value1;
367 let new_mean1 = delta1 / new_count as f64 + self.mean1;
368 let delta2 = self.mean2 - value2;
369 let new_mean2 = delta2 / new_count as f64 + self.mean2;
370 let new_c = self.algo_const - delta1 * (new_mean2 - value2);
371
372 self.count -= 1;
373 self.mean1 = new_mean1;
374 self.mean2 = new_mean2;
375 self.algo_const = new_c;
376 }
377
378 Ok(())
379 }
380
381 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
382 let counts = downcast_value!(states[0], UInt64Array);
383 let means1 = downcast_value!(states[1], Float64Array);
384 let means2 = downcast_value!(states[2], Float64Array);
385 let cs = downcast_value!(states[3], Float64Array);
386
387 for i in 0..counts.len() {
388 let c = counts.value(i);
389 if c == 0_u64 {
390 continue;
391 }
392 let new_count = self.count + c;
393 let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
394 + means1.value(i) * c as f64 / new_count as f64;
395 let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
396 + means2.value(i) * c as f64 / new_count as f64;
397 let delta1 = self.mean1 - means1.value(i);
398 let delta2 = self.mean2 - means2.value(i);
399 let new_c = self.algo_const
400 + cs.value(i)
401 + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
402
403 self.count = new_count;
404 self.mean1 = new_mean1;
405 self.mean2 = new_mean2;
406 self.algo_const = new_c;
407 }
408 Ok(())
409 }
410
411 fn evaluate(&mut self) -> Result<ScalarValue> {
412 let count = match self.stats_type {
413 StatsType::Population => self.count,
414 StatsType::Sample => {
415 if self.count > 0 {
416 self.count - 1
417 } else {
418 self.count
419 }
420 }
421 };
422
423 if count == 0 {
424 Ok(ScalarValue::Float64(None))
425 } else {
426 Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
427 }
428 }
429
430 fn size(&self) -> usize {
431 size_of_val(self)
432 }
433}