datafusion_expr_common/
accumulator.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! The accumulator module contains the trait definition for aggregate function accumulators.
19
20use arrow::array::ArrayRef;
21use datafusion_common::{internal_err, Result, ScalarValue};
22use std::fmt::Debug;
23
24/// Tracks an aggregate function's state.
25///
26/// `Accumulator`s are stateful objects that hold the aggregation state of a single group. They
27/// aggregate values from multiple rows together into a final output aggregate.
28///
29/// `GroupsAccumulator` is an additional, more performant (but also more complex) API
30/// that manages state for multiple groups at once.
31///
32/// An accumulator knows how to:
33/// * update its state from inputs via [`update_batch`]
34///
35/// * compute the final value from its internal state via [`evaluate`]
36///
37/// * retract an update to its state from given inputs via
38///   [`retract_batch`] (when used as a window aggregate [window
39///   function])
40///
41/// * convert its internal state to a vector of aggregate values via
42///   [`state`] and combine the state from multiple accumulators
43///   via [`merge_batch`], as part of efficient multi-phase grouping.
44///
45/// [`update_batch`]: Self::update_batch
46/// [`retract_batch`]: Self::retract_batch
47/// [`state`]: Self::state
48/// [`evaluate`]: Self::evaluate
49/// [`merge_batch`]: Self::merge_batch
50/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
51pub trait Accumulator: Send + Sync + Debug {
52    /// Updates the accumulator's state from its input.
53    ///
54    /// `values` contains the arguments to this aggregate function.
55    ///
56    /// For example, the `SUM` accumulator maintains a running sum,
57    /// and `update_batch` adds each of the input values to the
58    /// running sum.
59    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
60
61    /// Returns the final aggregate value, consuming the internal state.
62    ///
63    /// For example, the `SUM` accumulator maintains a running sum,
64    /// and `evaluate` will produce that running sum as its output.
65    ///
66    /// This function should not be called twice, otherwise it will
67    /// result in potentially non-deterministic behavior.
68    ///
69    /// This function gets `&mut self` to allow for the accumulator to build
70    /// arrow-compatible internal state that can be returned without copying
71    /// when possible (for example distinct strings).
72    fn evaluate(&mut self) -> Result<ScalarValue>;
73
74    /// Returns the allocated size required for this accumulator, in
75    /// bytes, including `Self`.
76    ///
77    /// This value is used to calculate the memory used during
78    /// execution so DataFusion can stay within its allotted limit.
79    ///
80    /// "Allocated" means that for internal containers such as `Vec`,
81    /// the `capacity` should be used not the `len`.
82    fn size(&self) -> usize;
83
84    /// Returns the intermediate state of the accumulator, consuming the
85    /// intermediate state.
86    ///
87    /// This function should not be called twice, otherwise it will
88    /// result in potentially non-deterministic behavior.
89    ///
90    /// This function gets `&mut self` to allow for the accumulator to build
91    /// arrow-compatible internal state that can be returned without copying
92    /// when possible (for example distinct strings).
93    ///
94    /// Intermediate state is used for "multi-phase" grouping in
95    /// DataFusion, where an aggregate is computed in parallel with
96    /// multiple `Accumulator` instances, as described below:
97    ///
98    /// # Multi-Phase Grouping
99    ///
100    /// ```text
101    ///                               ▲
102    ///                               │                   evaluate() is called to
103    ///                               │                   produce the final aggregate
104    ///                               │                   value per group
105    ///                               │
106    ///                  ┌─────────────────────────┐
107    ///                  │GroupBy                  │
108    ///                  │(AggregateMode::Final)   │      state() is called for each
109    ///                  │                         │      group and the resulting
110    ///                  └─────────────────────────┘      RecordBatches passed to the
111    ///                                                   Final GroupBy via merge_batch()
112    ///                               ▲
113    ///                               │
114    ///              ┌────────────────┴───────────────┐
115    ///              │                                │
116    ///              │                                │
117    /// ┌─────────────────────────┐      ┌─────────────────────────┐
118    /// │        GroupBy          │      │        GroupBy          │
119    /// │(AggregateMode::Partial) │      │(AggregateMode::Partial) │
120    /// └─────────────────────────┘      └─────────────────────────┘
121    ///              ▲                                ▲
122    ///              │                                │    update_batch() is called for
123    ///              │                                │    each input RecordBatch
124    ///         .─────────.                      .─────────.
125    ///      ,─'           '─.                ,─'           '─.
126    ///     ;      Input      :              ;      Input      :
127    ///     :   Partition 0   ;              :   Partition 1   ;
128    ///      ╲               ╱                ╲               ╱
129    ///       '─.         ,─'                  '─.         ,─'
130    ///          `───────'                        `───────'
131    /// ```
132    ///
133    /// The partial state is serialized as `Arrays` and then combined
134    /// with other partial states from different instances of this
135    /// Accumulator (that ran on different partitions, for example).
136    ///
137    /// The state can be and often is a different type than the output
138    /// type of the [`Accumulator`] and needs different merge
139    /// operations (for example, the partial state for `COUNT` needs
140    /// to be summed together).
141    ///
142    /// Some accumulators can return multiple values for their
143    /// intermediate states. For example, the average accumulator
144    /// tracks `sum` and `n`, and this function should return a vector
145    /// of two values, sum and n.
146    ///
147    /// Note that [`ScalarValue::List`] can be used to pass multiple
148    /// values if the number of intermediate values is not known at
149    /// planning time (e.g. for `MEDIAN`).
150    ///
151    /// # Multi-phase repartitioned Grouping
152    ///
153    /// Many multi-phase grouping plans contain a Repartition operation
154    /// as well as shown below:
155    ///
156    /// ```text
157    ///                ▲                          ▲
158    ///                │                          │
159    ///                │                          │
160    ///                │                          │
161    ///                │                          │
162    ///                │                          │
163    ///    ┌───────────────────────┐  ┌───────────────────────┐       4. Each AggregateMode::Final
164    ///    │GroupBy                │  │GroupBy                │       GroupBy has an entry for its
165    ///    │(AggregateMode::Final) │  │(AggregateMode::Final) │       subset of groups (in this case
166    ///    │                       │  │                       │       that means half the entries)
167    ///    └───────────────────────┘  └───────────────────────┘
168    ///                ▲                          ▲
169    ///                │                          │
170    ///                └─────────────┬────────────┘
171    ///                              │
172    ///                              │
173    ///                              │
174    ///                 ┌─────────────────────────┐                   3. Repartitioning by hash(group
175    ///                 │       Repartition       │                   keys) ensures that each distinct
176    ///                 │         HASH(x)         │                   group key now appears in exactly
177    ///                 └─────────────────────────┘                   one partition
178    ///                              ▲
179    ///                              │
180    ///              ┌───────────────┴─────────────┐
181    ///              │                             │
182    ///              │                             │
183    /// ┌─────────────────────────┐  ┌──────────────────────────┐     2. Each AggregateMode::Partial
184    /// │        GroupBy          │  │       GroupBy            │     GroupBy has an entry for *all*
185    /// │(AggregateMode::Partial) │  │ (AggregateMode::Partial) │     the groups
186    /// └─────────────────────────┘  └──────────────────────────┘
187    ///              ▲                             ▲
188    ///              │                             │
189    ///              │                             │
190    ///         .─────────.                   .─────────.
191    ///      ,─'           '─.             ,─'           '─.
192    ///     ;      Input      :           ;      Input      :         1. Since input data is
193    ///     :   Partition 0   ;           :   Partition 1   ;         arbitrarily or RoundRobin
194    ///      ╲               ╱             ╲               ╱          distributed, each partition
195    ///       '─.         ,─'               '─.         ,─'           likely has all distinct
196    ///          `───────'                     `───────'              group keys
197    /// ```
198    ///
199    /// This structure is used so that the `AggregateMode::Partial` accumulators
200    /// reduce the cardinality of the input as soon as possible. Typically,
201    /// each partial accumulator sees all groups in the input as the group keys
202    /// are evenly distributed across the input.
203    ///
204    /// The final output is computed by repartitioning the result of
205    /// [`Self::state`] from each Partial aggregate by `hash(group keys)` so
206    /// that each distinct group key appears in exactly one of the
207    /// `AggregateMode::Final` GroupBy nodes. The outputs of the final nodes are
208    /// then unioned together to produce the overall final output.
209    ///
210    /// Here is an example that shows the distribution of groups in the
211    /// different phases:
212    ///
213    /// ```text
214    ///               ┌─────┐                ┌─────┐
215    ///               │  1  │                │  3  │
216    ///               ├─────┤                ├─────┤
217    ///               │  2  │                │  4  │                After repartitioning by
218    ///               └─────┘                └─────┘                hash(group keys), each distinct
219    ///               ┌─────┐                ┌─────┐                group key now appears in exactly
220    ///               │  1  │                │  3  │                one partition
221    ///               ├─────┤                ├─────┤
222    ///               │  2  │                │  4  │
223    ///               └─────┘                └─────┘
224    ///
225    ///
226    /// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─
227    ///
228    ///               ┌─────┐                ┌─────┐
229    ///               │  2  │                │  2  │
230    ///               ├─────┤                ├─────┤
231    ///               │  1  │                │  2  │
232    ///               ├─────┤                ├─────┤
233    ///               │  3  │                │  3  │
234    ///               ├─────┤                ├─────┤
235    ///               │  4  │                │  1  │
236    ///               └─────┘                └─────┘                Input data is arbitrarily or
237    ///                 ...                    ...                  RoundRobin distributed, each
238    ///               ┌─────┐                ┌─────┐                partition likely has all
239    ///               │  1  │                │  4  │                distinct group keys
240    ///               ├─────┤                ├─────┤
241    ///               │  4  │                │  3  │
242    ///               ├─────┤                ├─────┤
243    ///               │  1  │                │  1  │
244    ///               ├─────┤                ├─────┤
245    ///               │  4  │                │  3  │
246    ///               └─────┘                └─────┘
247    ///
248    ///           group values           group values
249    ///           in partition 0         in partition 1
250    /// ```
251    fn state(&mut self) -> Result<Vec<ScalarValue>>;
252
253    /// Updates the accumulator's state from an `Array` containing one
254    /// or more intermediate values.
255    ///
256    /// For some aggregates (such as `SUM`), `merge_batch` is the same
257    /// as `update_batch`, but for some aggregates (such as `COUNT`)
258    /// the operations differ. See [`Self::state`] for more details on how
259    /// state is used and merged.
260    ///
261    /// The `states` array passed was formed by concatenating the
262    /// results of calling [`Self::state`] on zero or more other
263    /// `Accumulator` instances.
264    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;
265
266    /// Retracts (removes) an update (caused by the given inputs) to the
267    /// accumulator's state.
268    ///
269    /// This is the inverse operation of [`Self::update_batch`] and is used
270    /// to incrementally calculate window aggregates where the `OVER`
271    /// clause defines a bounded window.
272    ///
273    /// # Example
274    ///
275    /// For example, given the following input partition
276    ///
277    /// ```text
278    ///                     │      current      │
279    ///                            window
280    ///                     │                   │
281    ///                ┌────┬────┬────┬────┬────┬────┬────┬────┬────┐
282    ///     Input      │ A  │ B  │ C  │ D  │ E  │ F  │ G  │ H  │ I  │
283    ///   partition    └────┴────┴────┴────┼────┴────┴────┴────┼────┘
284    ///
285    ///                                    │         next      │
286    ///                                             window
287    /// ```
288    ///
289    /// First, [`Self::evaluate`] will be called to produce the output
290    /// for the current window.
291    ///
292    /// Then, to advance to the next window:
293    ///
294    /// First, [`Self::retract_batch`] will be called with the values
295    /// that are leaving the window, `[B, C, D]` and then
296    /// [`Self::update_batch`] will be called with the values that are
297    /// entering the window, `[F, G, H]`.
298    fn retract_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
299        // TODO add retract for all accumulators; until then this default implementation returns an internal error
300        internal_err!(
301            "Retract should be implemented for aggregate functions when used with custom window frame queries"
302        )
303    }
304
305    /// Does the accumulator support incrementally updating its value
306    /// by *removing* values.
307    ///
308    /// If this function returns true, [`Self::retract_batch`] will be
309    /// called for sliding window functions such as queries with an
310    /// `OVER (ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING)` clause. Defaults to `false`.
311    fn supports_retract_batch(&self) -> bool {
312        false
313    }
314}