Skip to content

Commit 6f0640a

Browse files
authored
Optimize ReduceSum, ReduceMean, ReduceMin, ReduceMax (microsoft#10280)
* Optimize ReduceSum, ReduceMean, ReduceMin, ReduceMax
* Improve ReduceMax, ReduceMin
* Faster, smaller
* Replace std::vector by gsl::span for shapes
* Fix merging issues
1 parent df841ee commit 6f0640a

File tree

3 files changed

+520
-56
lines changed

3 files changed

+520
-56
lines changed

onnxruntime/core/providers/cpu/reduction/reduction_ops.cc

Lines changed: 31 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -259,10 +259,15 @@ static void ValidateFastReduceRK(const gsl::span<const int64_t>& fast_shape, con
259259
}
260260

261261
static void ValidateFastReduceKRK(const gsl::span<const int64_t>& fast_shape, const Tensor& output) {
262-
ORT_ENFORCE(fast_shape.size() == 3, "Only works on matrices with two dimensions.");
262+
ORT_ENFORCE(fast_shape.size() == 3, "Only works on matrices with three dimensions.");
263263
ORT_ENFORCE(fast_shape[0] * fast_shape[2] == output.Shape().Size(), "Output size mismatch.");
264264
}
265265

266+
static void ValidateFastReduceRKR(const gsl::span<const int64_t>& fast_shape, const Tensor& output) {
267+
ORT_ENFORCE(fast_shape.size() == 3, "Only works on matrices with three dimensions.");
268+
ORT_ENFORCE(fast_shape[1] == output.Shape().Size(), "Output size mismatch.");
269+
}
270+
266271
void ReduceAggregatorBase::FastReduceKR(const Tensor&, const gsl::span<const int64_t>&, Tensor&, concurrency::ThreadPool*) {
267272
ValidateMustBeOverloaded();
268273
}
@@ -272,6 +277,9 @@ void ReduceAggregatorBase::FastReduceRK(const Tensor&, const gsl::span<const int
272277
void ReduceAggregatorBase::FastReduceKRK(const Tensor&, const gsl::span<const int64_t>&, Tensor&, concurrency::ThreadPool*) {
273278
ValidateMustBeOverloaded();
274279
}
280+
void ReduceAggregatorBase::FastReduceRKR(const Tensor&, const gsl::span<const int64_t>&, Tensor&, concurrency::ThreadPool*) {
281+
ValidateMustBeOverloaded();
282+
}
275283

276284
void NoTransposePrepareForReduce(const TensorShape& new_input_shape,
277285
gsl::span<const int64_t> reduced_axes,
@@ -624,8 +632,8 @@ FastReduceKind OptimizeShapeForFastReduce(gsl::span<const int64_t> input_shape,
624632
if (fast_shape.size() == 2) {
625633
return reduce[0] ? FastReduceKind::kRK : FastReduceKind::kKR;
626634
}
627-
if (fast_shape.size() == 3 && !reduce[0]) {
628-
return FastReduceKind::kKRK;
635+
if (fast_shape.size() == 3) {
636+
return reduce[0] ? FastReduceKind::kRKR : FastReduceKind::kKRK;
629637
}
630638
return FastReduceKind::kNone;
631639
}
@@ -671,7 +679,8 @@ bool CommonFastReduceSwitch(OpKernelContext* ctx,
671679
FastReduceKind which_fast_reduce,
672680
fast_reduce_fct* case_kr,
673681
fast_reduce_fct* case_rk,
674-
fast_reduce_fct* case_krk) {
682+
fast_reduce_fct* case_krk,
683+
fast_reduce_fct* case_rkr) {
675684
TensorShapeVector axes;
676685
const Tensor* input = ctx->Input<Tensor>(0);
677686
auto reduced_dims = input->Shape().GetDims();
@@ -715,6 +724,14 @@ bool CommonFastReduceSwitch(OpKernelContext* ctx,
715724
} else {
716725
break;
717726
}
727+
case FastReduceKind::kRKR:
728+
ValidateFastReduceRKR(fast_shape, *output);
729+
if (fast_shape[1] >= std::max(2, concurrency::ThreadPool::DegreeOfParallelism(ctx->GetOperatorThreadPool()))) {
730+
case_rkr(*input, fast_shape, *output, ctx->GetOperatorThreadPool());
731+
return true;
732+
} else {
733+
break;
734+
}
718735
case FastReduceKind::kR:
719736
case FastReduceKind::kK:
720737
case FastReduceKind::kNone:
@@ -738,7 +755,8 @@ bool CommonFastReduce(OpKernelContext* ctx,
738755
TensorShapeVector& fast_axes) {
739756
return CommonFastReduceSwitch(ctx, axes_, keepdims_, noop_with_empty_axes,
740757
fast_kind, fast_shape, output_shape, fast_axes,
741-
AGG::WhichFastReduce(), &AGG::FastReduceKR, &AGG::FastReduceRK, &AGG::FastReduceKRK);
758+
AGG::WhichFastReduce(), &AGG::FastReduceKR, &AGG::FastReduceRK,
759+
&AGG::FastReduceKRK, &AGG::FastReduceRKR);
742760
}
743761

744762
static void ValidateKeepDims(const TensorShape& shape, int64_t keepdims) {
@@ -925,6 +943,14 @@ std::unique_ptr<Tensor> ReduceSum<T>::Impl(const Tensor& input, gsl::span<const
925943
} else {
926944
break;
927945
}
946+
case FastReduceKind::kRKR:
947+
ValidateFastReduceRKR(fast_shape, *output);
948+
if (fast_shape[0] >= std::max(2, concurrency::ThreadPool::DegreeOfParallelism(tp))) {
949+
ReduceAggregatorSum<T>::FastReduceRKR(input, fast_shape, *output, tp);
950+
return output;
951+
} else {
952+
break;
953+
}
928954
case FastReduceKind::kR:
929955
case FastReduceKind::kK:
930956
case FastReduceKind::kNone:

0 commit comments

Comments (0)