@@ -259,10 +259,15 @@ static void ValidateFastReduceRK(const gsl::span<const int64_t>& fast_shape, con
259259}
260260
// Precondition check for the KRK fast-reduce path.
// Enforces (via ORT_ENFORCE) that `fast_shape` has exactly three entries and
// that fast_shape[0] * fast_shape[2] equals the element count reported by
// output.Shape().Size().
// Presumably fast_shape is the collapsed (kept, reduced, kept) shape -- TODO confirm.
// The removed/added pair of lines below differs only in the message text
// ("two" -> "three" dimensions); the checked condition is unchanged.
261261static void ValidateFastReduceKRK (const gsl::span<const int64_t >& fast_shape, const Tensor& output) {
262- ORT_ENFORCE (fast_shape.size () == 3 , " Only works on matrices with two dimensions." );
262+ ORT_ENFORCE (fast_shape.size () == 3 , " Only works on matrices with three dimensions." );
// The output must hold one element per (first, last) index pair.
263263 ORT_ENFORCE (fast_shape[0 ] * fast_shape[2 ] == output.Shape ().Size (), " Output size mismatch." );
264264}
265265
// Precondition check for the RKR fast-reduce path.
// Enforces (via ORT_ENFORCE) that `fast_shape` has exactly three entries and
// that the middle entry, fast_shape[1], equals the element count reported by
// output.Shape().Size().
// Presumably fast_shape is the collapsed (reduced, kept, reduced) shape -- TODO confirm.
// NOTE(review): the two kRKR call sites visible in this diff gate the fast path
// on different entries (fast_shape[1] in CommonFastReduceSwitch, fast_shape[0]
// in ReduceSum<T>::Impl) -- confirm which one is intended.
266+ static void ValidateFastReduceRKR (const gsl::span<const int64_t >& fast_shape, const Tensor& output) {
267+ ORT_ENFORCE (fast_shape.size () == 3 , " Only works on matrices with three dimensions." );
// The output must hold one element per index of the middle dimension.
268+ ORT_ENFORCE (fast_shape[1 ] == output.Shape ().Size (), " Output size mismatch." );
269+ }
270+
// Base-class version of the KR fast-reduce entry point.
// All four parameters are unnamed and unused; the body only calls
// ValidateMustBeOverloaded().
// Presumably that helper reports that a concrete aggregator must supply its own
// FastReduceKR -- its definition is not visible here, TODO confirm.
266271void ReduceAggregatorBase::FastReduceKR (const Tensor&, const gsl::span<const int64_t >&, Tensor&, concurrency::ThreadPool*) {
267272 ValidateMustBeOverloaded ();
268273}
@@ -272,6 +277,9 @@ void ReduceAggregatorBase::FastReduceRK(const Tensor&, const gsl::span<const int
// Base-class version of the KRK fast-reduce entry point.
// All four parameters are unnamed and unused; the body only calls
// ValidateMustBeOverloaded().
// Presumably that helper reports that a concrete aggregator must supply its own
// FastReduceKRK -- its definition is not visible here, TODO confirm.
272277void ReduceAggregatorBase::FastReduceKRK (const Tensor&, const gsl::span<const int64_t >&, Tensor&, concurrency::ThreadPool*) {
273278 ValidateMustBeOverloaded ();
274279}
// Base-class version of the RKR fast-reduce entry point (added by this change).
// Same signature as the neighbouring FastReduceKRK; all parameters are unnamed
// and unused, and the body only calls ValidateMustBeOverloaded().
// Presumably that helper reports that a concrete aggregator must supply its own
// FastReduceRKR -- its definition is not visible here, TODO confirm.
280+ void ReduceAggregatorBase::FastReduceRKR (const Tensor&, const gsl::span<const int64_t >&, Tensor&, concurrency::ThreadPool*) {
281+ ValidateMustBeOverloaded ();
282+ }
275283
276284void NoTransposePrepareForReduce (const TensorShape& new_input_shape,
277285 gsl::span<const int64_t > reduced_axes,
@@ -624,8 +632,8 @@ FastReduceKind OptimizeShapeForFastReduce(gsl::span<const int64_t> input_shape,
624632 if (fast_shape.size () == 2 ) {
625633 return reduce[0 ] ? FastReduceKind::kRK : FastReduceKind::kKR ;
626634 }
627- if (fast_shape.size () == 3 && !reduce[ 0 ] ) {
628- return FastReduceKind::kKRK ;
635+ if (fast_shape.size () == 3 ) {
636+ return reduce[ 0 ] ? FastReduceKind:: kRKR : FastReduceKind::kKRK ;
629637 }
630638 return FastReduceKind::kNone ;
631639}
@@ -671,7 +679,8 @@ bool CommonFastReduceSwitch(OpKernelContext* ctx,
671679 FastReduceKind which_fast_reduce,
672680 fast_reduce_fct* case_kr,
673681 fast_reduce_fct* case_rk,
674- fast_reduce_fct* case_krk) {
682+ fast_reduce_fct* case_krk,
683+ fast_reduce_fct* case_rkr) {
675684 TensorShapeVector axes;
676685 const Tensor* input = ctx->Input <Tensor>(0 );
677686 auto reduced_dims = input->Shape ().GetDims ();
@@ -715,6 +724,14 @@ bool CommonFastReduceSwitch(OpKernelContext* ctx,
715724 } else {
716725 break ;
717726 }
727+ case FastReduceKind::kRKR :
728+ ValidateFastReduceRKR (fast_shape, *output);
729+ if (fast_shape[1 ] >= std::max (2 , concurrency::ThreadPool::DegreeOfParallelism (ctx->GetOperatorThreadPool ()))) {
730+ case_rkr (*input, fast_shape, *output, ctx->GetOperatorThreadPool ());
731+ return true ;
732+ } else {
733+ break ;
734+ }
718735 case FastReduceKind::kR :
719736 case FastReduceKind::kK :
720737 case FastReduceKind::kNone :
@@ -738,7 +755,8 @@ bool CommonFastReduce(OpKernelContext* ctx,
738755 TensorShapeVector& fast_axes) {
739756 return CommonFastReduceSwitch (ctx, axes_, keepdims_, noop_with_empty_axes,
740757 fast_kind, fast_shape, output_shape, fast_axes,
741- AGG::WhichFastReduce (), &AGG::FastReduceKR, &AGG::FastReduceRK, &AGG::FastReduceKRK);
758+ AGG::WhichFastReduce (), &AGG::FastReduceKR, &AGG::FastReduceRK,
759+ &AGG::FastReduceKRK, &AGG::FastReduceRKR);
742760}
743761
744762static void ValidateKeepDims (const TensorShape& shape, int64_t keepdims) {
@@ -925,6 +943,14 @@ std::unique_ptr<Tensor> ReduceSum<T>::Impl(const Tensor& input, gsl::span<const
925943 } else {
926944 break ;
927945 }
946+ case FastReduceKind::kRKR :
947+ ValidateFastReduceRKR (fast_shape, *output);
948+ if (fast_shape[0 ] >= std::max (2 , concurrency::ThreadPool::DegreeOfParallelism (tp))) {
949+ ReduceAggregatorSum<T>::FastReduceRKR (input, fast_shape, *output, tp);
950+ return output;
951+ } else {
952+ break ;
953+ }
928954 case FastReduceKind::kR :
929955 case FastReduceKind::kK :
930956 case FastReduceKind::kNone :
0 commit comments