Skip to content

Commit 2fb811a

Browse files
zhli1142015facebook-github-bot
authored andcommitted
Add collect_set Spark aggregation function (facebookincubator#10038)
Summary: Doc: https://docs.databricks.com/en/sql/language-manual/functions/collect_set.html Code: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala#L39C16-L39C23 https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala#L147C12-L147C22 There are 3 semantic difference from `set_agg`: 1. Null values are excluded. ``` import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ val jsonStr = """{"txn":null}""" val jsonStr1 = """{"txn":null}""" val jsonStr2 = """{"txn":null}""" val jsonStr3 = """{"txn":null}""" val jsonSchema = StructType(Seq(StructField("txn",LongType,true))) val df = spark.read.schema(jsonSchema).json(Seq(jsonStr, jsonStr1, jsonStr2, jsonStr3).toDS) df.select(collect_set($"txn")).show +----------------+ |collect_set(txn) | +----------------+ | [] | +----------------+ ``` 2. Nested Nulls are allowed. ``` import org.apache.spark.sql.types._ val jsonStr = """{"txn":{"appId":"txnId","version":0,"lastUpdated":null}}""" val jsonStr1 = """{"txn":{"appId":"txnId","version":1,"lastUpdated":1}}""" val jsonStr2 = """{"txn":{"appId":"txnId","version":0,"lastUpdated":null}}""" val jsonStr3 = """{"txn":{"appId":"txnId","version":1,"lastUpdated":null}}""" val jsonSchema = StructType(Seq(StructField("txn", StructType(Seq(StructField("appId",StringType,true),StructField("lastUpdated",LongType,true),StructField("version",LongType,true))),true))) val df = spark.read.schema(jsonSchema).json(Seq(jsonStr, jsonStr1, jsonStr2, jsonStr3).toDS) df.select(collect_set(col("txn"))).show(false) +---------------------------------------------------+ |collect_set(txn) | +---------------------------------------------------+ |[{txnId, 1, 1}, {txnId, null, 0}, {txnId, null, 1}] | +---------------------------------------------------+ ``` 3. Map type is not allowed. Changes: 1. Move `SetBaseAggregate` and `SetAggAggregate` to lib folder. 2. Register `SetAggAggregate` with `ignoreNulls=true` (1), `throwOnNestedNulls=false` (2) for complex types. 3. Not register map type for this function. Pull Request resolved: facebookincubator#10038 Reviewed By: kagamiori Differential Revision: D60386641 Pulled By: Yuhta fbshipit-source-id: eb791b9a2471cf4841d14f2bc64073e18f78d675
1 parent 73ca922 commit 2fb811a

File tree

9 files changed

+676
-305
lines changed

9 files changed

+676
-305
lines changed

velox/docs/functions/spark/aggregate.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,29 @@ General Aggregate Functions
5858
Returns an array created from the input ``x`` elements. Ignores null
5959
inputs, and returns an empty array when all inputs are null.
6060

61+
.. spark:function:: collect_set(x) -> array<[same as x]>
62+
63+
Returns an array consisting of all unique values from the input ``x`` elements.
64+
Null values are excluded, and returns an empty array when all inputs are null.
65+
66+
Example::
67+
68+
SELECT collect_set(i)
69+
FROM (
70+
VALUES
71+
(1),
72+
(null)
73+
) AS t(i);
74+
-- ARRAY[1]
75+
76+
SELECT collect_set(elements)
77+
FROM (
78+
VALUES
79+
ARRAY[1, 2],
80+
ARRAY[1, null]
81+
) AS t(elements);
82+
-- ARRAY[ARRAY[1, 2], ARRAY[1, null]]
83+
6184
.. spark:function:: first(x) -> x
6285
6386
Returns the first value of `x`.
Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "velox/exec/Aggregate.h"
19+
#include "velox/exec/SetAccumulator.h"
20+
#include "velox/functions/lib/CheckNestedNulls.h"
21+
22+
namespace facebook::velox::functions::aggregate {
23+
24+
template <typename T, bool ignoreNulls = false>
25+
class SetBaseAggregate : public exec::Aggregate {
26+
public:
27+
explicit SetBaseAggregate(const TypePtr& resultType)
28+
: exec::Aggregate(resultType) {}
29+
30+
using AccumulatorType = velox::aggregate::prestosql::SetAccumulator<T>;
31+
32+
int32_t accumulatorFixedWidthSize() const override {
33+
return sizeof(AccumulatorType);
34+
}
35+
36+
bool isFixedSize() const override {
37+
return false;
38+
}
39+
40+
void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
41+
override {
42+
auto arrayVector = (*result)->as<ArrayVector>();
43+
arrayVector->resize(numGroups);
44+
45+
auto* rawOffsets = arrayVector->offsets()->asMutable<vector_size_t>();
46+
auto* rawSizes = arrayVector->sizes()->asMutable<vector_size_t>();
47+
48+
vector_size_t numValues = 0;
49+
uint64_t* rawNulls = getRawNulls(arrayVector);
50+
for (auto i = 0; i < numGroups; ++i) {
51+
auto* group = groups[i];
52+
if (isNull(group)) {
53+
arrayVector->setNull(i, true);
54+
} else {
55+
clearNull(rawNulls, i);
56+
57+
const auto size = value(group)->size();
58+
59+
rawOffsets[i] = numValues;
60+
rawSizes[i] = size;
61+
62+
numValues += size;
63+
}
64+
}
65+
66+
if constexpr (std::is_same_v<T, ComplexType>) {
67+
auto values = arrayVector->elements();
68+
values->resize(numValues);
69+
70+
vector_size_t offset = 0;
71+
for (auto i = 0; i < numGroups; ++i) {
72+
auto* group = groups[i];
73+
if (!isNull(group)) {
74+
offset += value(group)->extractValues(*values, offset);
75+
}
76+
}
77+
} else {
78+
auto values = arrayVector->elements()->as<FlatVector<T>>();
79+
values->resize(numValues);
80+
81+
vector_size_t offset = 0;
82+
for (auto i = 0; i < numGroups; ++i) {
83+
auto* group = groups[i];
84+
if (!isNull(group)) {
85+
offset += value(group)->extractValues(*values, offset);
86+
}
87+
}
88+
}
89+
}
90+
91+
void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
92+
override {
93+
return extractValues(groups, numGroups, result);
94+
}
95+
96+
void addIntermediateResults(
97+
char** groups,
98+
const SelectivityVector& rows,
99+
const std::vector<VectorPtr>& args,
100+
bool /*mayPushdown*/) override {
101+
addIntermediateResultsInt(groups, rows, args, false);
102+
}
103+
104+
void addSingleGroupIntermediateResults(
105+
char* group,
106+
const SelectivityVector& rows,
107+
const std::vector<VectorPtr>& args,
108+
bool /*mayPushdown*/) override {
109+
addSingleGroupIntermediateResultsInt(group, rows, args, false);
110+
}
111+
112+
protected:
113+
inline AccumulatorType* value(char* group) {
114+
return reinterpret_cast<AccumulatorType*>(group + Aggregate::offset_);
115+
}
116+
117+
void addIntermediateResultsInt(
118+
char** groups,
119+
const SelectivityVector& rows,
120+
const std::vector<VectorPtr>& args,
121+
bool clearNullForAllInputs) {
122+
decoded_.decode(*args[0], rows);
123+
124+
auto baseArray = decoded_.base()->template as<ArrayVector>();
125+
decodedElements_.decode(*baseArray->elements());
126+
127+
rows.applyToSelected([&](vector_size_t i) {
128+
if (decoded_.isNullAt(i)) {
129+
if (clearNullForAllInputs) {
130+
clearNull(groups[i]);
131+
}
132+
return;
133+
}
134+
135+
auto* group = groups[i];
136+
clearNull(group);
137+
138+
auto tracker = trackRowSize(group);
139+
140+
auto decodedIndex = decoded_.index(i);
141+
if constexpr (ignoreNulls) {
142+
value(group)->addNonNullValues(
143+
*baseArray, decodedIndex, decodedElements_, allocator_);
144+
} else {
145+
value(group)->addValues(
146+
*baseArray, decodedIndex, decodedElements_, allocator_);
147+
}
148+
});
149+
}
150+
151+
void addSingleGroupIntermediateResultsInt(
152+
char* group,
153+
const SelectivityVector& rows,
154+
const std::vector<VectorPtr>& args,
155+
bool clearNullForAllInputs) {
156+
decoded_.decode(*args[0], rows);
157+
158+
auto baseArray = decoded_.base()->template as<ArrayVector>();
159+
160+
decodedElements_.decode(*baseArray->elements());
161+
162+
auto* accumulator = value(group);
163+
164+
auto tracker = trackRowSize(group);
165+
rows.applyToSelected([&](vector_size_t i) {
166+
if (decoded_.isNullAt(i)) {
167+
if (clearNullForAllInputs) {
168+
clearNull(group);
169+
}
170+
return;
171+
}
172+
173+
clearNull(group);
174+
175+
auto decodedIndex = decoded_.index(i);
176+
if constexpr (ignoreNulls) {
177+
accumulator->addNonNullValues(
178+
*baseArray, decodedIndex, decodedElements_, allocator_);
179+
} else {
180+
accumulator->addValues(
181+
*baseArray, decodedIndex, decodedElements_, allocator_);
182+
}
183+
});
184+
}
185+
186+
void initializeNewGroupsInternal(
187+
char** groups,
188+
folly::Range<const vector_size_t*> indices) override {
189+
const auto& type = resultType()->childAt(0);
190+
exec::Aggregate::setAllNulls(groups, indices);
191+
for (auto i : indices) {
192+
new (groups[i] + offset_) AccumulatorType(type, allocator_);
193+
}
194+
}
195+
196+
void destroyInternal(folly::Range<char**> groups) override {
197+
for (auto* group : groups) {
198+
if (isInitialized(group) && !isNull(group)) {
199+
value(group)->free(*allocator_);
200+
}
201+
}
202+
}
203+
204+
DecodedVector decoded_;
205+
DecodedVector decodedElements_;
206+
};
207+
208+
template <typename T, bool ignoreNulls = false>
209+
class SetAggAggregate : public SetBaseAggregate<T, ignoreNulls> {
210+
public:
211+
explicit SetAggAggregate(
212+
const TypePtr& resultType,
213+
const bool throwOnNestedNulls = false)
214+
: SetBaseAggregate<T, ignoreNulls>(resultType),
215+
throwOnNestedNulls_(throwOnNestedNulls) {}
216+
217+
using Base = SetBaseAggregate<T, ignoreNulls>;
218+
219+
bool supportsToIntermediate() const override {
220+
return true;
221+
}
222+
223+
void toIntermediate(
224+
const SelectivityVector& rows,
225+
std::vector<VectorPtr>& args,
226+
VectorPtr& result) const override {
227+
const auto& elements = args[0];
228+
229+
if (throwOnNestedNulls_) {
230+
DecodedVector decodedElements(*elements, rows);
231+
auto indices = decodedElements.indices();
232+
rows.applyToSelected([&](vector_size_t i) {
233+
velox::functions::checkNestedNulls(
234+
decodedElements, indices, i, throwOnNestedNulls_);
235+
});
236+
}
237+
238+
const auto numRows = rows.size();
239+
240+
// Convert input to a single-entry array.
241+
242+
// Set nulls for rows not present in 'rows'.
243+
auto* pool = Base::allocator_->pool();
244+
BufferPtr nulls = allocateNulls(numRows, pool);
245+
memcpy(
246+
nulls->asMutable<uint64_t>(),
247+
rows.asRange().bits(),
248+
bits::nbytes(numRows));
249+
250+
// Set offsets to 0, 1, 2, 3...
251+
BufferPtr offsets = allocateOffsets(numRows, pool);
252+
auto* rawOffsets = offsets->asMutable<vector_size_t>();
253+
std::iota(rawOffsets, rawOffsets + numRows, 0);
254+
255+
// Set sizes to 1.
256+
BufferPtr sizes = allocateSizes(numRows, pool);
257+
auto* rawSizes = sizes->asMutable<vector_size_t>();
258+
std::fill(rawSizes, rawSizes + numRows, 1);
259+
260+
result = std::make_shared<ArrayVector>(
261+
pool,
262+
ARRAY(elements->type()),
263+
nulls,
264+
numRows,
265+
offsets,
266+
sizes,
267+
BaseVector::loadedVectorShared(elements));
268+
}
269+
270+
void addRawInput(
271+
char** groups,
272+
const SelectivityVector& rows,
273+
const std::vector<VectorPtr>& args,
274+
bool /*mayPushdown*/) override {
275+
Base::decoded_.decode(*args[0], rows);
276+
auto indices = Base::decoded_.indices();
277+
rows.applyToSelected([&](vector_size_t i) {
278+
auto* group = groups[i];
279+
Base::clearNull(group);
280+
281+
if (throwOnNestedNulls_) {
282+
velox::functions::checkNestedNulls(
283+
Base::decoded_, indices, i, throwOnNestedNulls_);
284+
}
285+
286+
auto tracker = Base::trackRowSize(group);
287+
if constexpr (ignoreNulls) {
288+
Base::value(group)->addNonNullValue(
289+
Base::decoded_, i, Base::allocator_);
290+
} else {
291+
Base::value(group)->addValue(Base::decoded_, i, Base::allocator_);
292+
}
293+
});
294+
}
295+
296+
void addSingleGroupRawInput(
297+
char* group,
298+
const SelectivityVector& rows,
299+
const std::vector<VectorPtr>& args,
300+
bool /*mayPushdown*/) override {
301+
Base::decoded_.decode(*args[0], rows);
302+
303+
Base::clearNull(group);
304+
auto* accumulator = Base::value(group);
305+
306+
auto tracker = Base::trackRowSize(group);
307+
auto indices = Base::decoded_.indices();
308+
rows.applyToSelected([&](vector_size_t i) {
309+
if (throwOnNestedNulls_) {
310+
velox::functions::checkNestedNulls(
311+
Base::decoded_, indices, i, throwOnNestedNulls_);
312+
}
313+
314+
if constexpr (ignoreNulls) {
315+
accumulator->addNonNullValue(Base::decoded_, i, Base::allocator_);
316+
} else {
317+
accumulator->addValue(Base::decoded_, i, Base::allocator_);
318+
}
319+
});
320+
}
321+
322+
private:
323+
const bool throwOnNestedNulls_;
324+
};
325+
326+
} // namespace facebook::velox::functions::aggregate

0 commit comments

Comments
 (0)