Skip to content

Commit f577aae

Browse files
jreiffers and copybara-github
authored and committed
Priority fusion: fix analysis of reduction epilogues.
The current code sometimes doesn't detect that a fusion can be emitted using the reduction emitter.

PiperOrigin-RevId: 580430642
1 parent 3f64597 commit f577aae

File tree

5 files changed

+62
-5
lines changed

5 files changed

+62
-5
lines changed

xla/service/gpu/hlo_traversal.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ FusionBoundaryFn MakeProducerConsumerFusion(
5959
// producer.
6060
return &fused_producer != &producer;
6161
}
62+
if (&producer == &fused_consumer) {
63+
return true;
64+
}
6265

6366
// Otherwise, fall back to the default; we're already in the fused
6467
// producer.

xla/service/gpu/ir_emission_utils.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -723,15 +723,25 @@ std::optional<TransposeDescription> GetDescriptionForTiledTransposeEmitter(
723723
return std::nullopt;
724724
}
725725

726-
bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count) {
726+
bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count,
727+
FusionBoundaryFn boundary) {
727728
// Number of operands should be in range [1, allowed_operand_count].
728729
if (instr->operand_count() == 0 ||
729730
instr->operand_count() > allowed_operand_count) {
730731
return false;
731732
}
732733

733734
// Intermediate `instr` can't have multiple users.
734-
if (instr->user_count() > 1) {
735+
// If we have a boundary function, only consider users within the
736+
// boundary. This isn't really correct, since the real users aren't
737+
// necessarily the instruction's users at this point.
738+
// TODO(jreiffers): Figure out the point of this check.
739+
int64_t num_users =
740+
boundary ? absl::c_count_if(
741+
instr->users(),
742+
[&](const auto* user) { return !boundary(*instr, *user); })
743+
: instr->user_count();
744+
if (num_users > 1) {
735745
return false;
736746
}
737747

@@ -780,7 +790,8 @@ const HloInstruction& FindNonTrivialHero(
780790
auto preds = FindPredecessors(*node, is_boundary);
781791
return preds.size() == 1 ? preds.front() : nullptr;
782792
}
783-
return IsIntermediate(node) && !is_boundary(*node->operand(0), *node)
793+
return IsIntermediate(node, 1, is_boundary) &&
794+
!is_boundary(*node->operand(0), *node)
784795
? node->operand(0)
785796
: nullptr;
786797
};

xla/service/gpu/ir_emission_utils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "xla/hlo/ir/hlo_instruction.h"
2626
#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
2727
#include "xla/service/buffer_assignment.h"
28+
#include "xla/service/gpu/hlo_traversal.h"
2829

2930
namespace xla {
3031
namespace gpu {
@@ -193,7 +194,8 @@ std::optional<TransposeDescription> FindTiledLogicalTranspose(
193194
std::optional<TransposeDescription> GetDescriptionForTiledTransposeEmitter(
194195
const HloInstruction& root, const HloInstruction& hero);
195196

196-
bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count = 1);
197+
bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count = 1,
198+
FusionBoundaryFn boundary = nullptr);
197199

198200
// Log the given module if the VLOG level is >= level.
199201
void VLogModule(int level, const llvm::Module& module);

xla/service/gpu/model/gpu_performance_model.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes(
444444
absl::Duration exec_time_fused = absl::ZeroDuration();
445445
absl::Duration producer_output_read_time_unfused = absl::ZeroDuration();
446446
for (const HloInstruction* fused_consumer : fused_consumers) {
447+
VLOG(8) << "Consumer: " << fused_consumer->name();
447448
float utilization_by_this_consumer = cost_analysis->operand_utilization(
448449
*fused_consumer, fused_consumer->operand_index(producer));
449450
total_producer_utilization += utilization_by_this_consumer;
@@ -478,6 +479,9 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes(
478479
absl::Duration input_access_time_by_this_consumer = ProducerInputAccessTime(
479480
cost_analysis, *device_info, launch_dimensions_fused.num_blocks(),
480481
producer, analysis_fused, config, fused_consumer);
482+
VLOG(10) << " Compute time by consumer: " << compute_time_by_this_consumer;
483+
VLOG(10) << " Input access time by consumer: "
484+
<< input_access_time_by_this_consumer;
481485

482486
exec_time_fused += std::max(compute_time_by_this_consumer,
483487
input_access_time_by_this_consumer);
@@ -486,11 +490,14 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes(
486490
utilization_by_this_consumer);
487491
int64_t n_bytes_net = std::min(producer_data.bytes_written, n_bytes_total);
488492

489-
producer_output_read_time_unfused += ReadTime(
493+
auto read_time_unfused = ReadTime(
490494
*device_info, launch_dimensions_unfused.num_blocks(), n_bytes_net,
491495
n_bytes_total, fused_consumer->shape().element_type(),
492496
/*coalesced=*/!TransposesMinorDimension(fused_consumer),
493497
config.first_read_from_dram);
498+
499+
VLOG(10) << " Read time unfused: " << read_time_unfused;
500+
producer_output_read_time_unfused += read_time_unfused;
494501
}
495502

496503
absl::Duration time_unfused =

xla/service/gpu/priority_fusion_test.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,5 +549,39 @@ CHECK: ROOT {{.*}} reduce(
549549
)");
550550
}
551551

552+
TEST_F(PriorityFusionTest, FuseReductionEpilogueWithMultipleUsers) {
553+
// Regression test that verifies we correctly fuse the `log` into the reduce.
554+
constexpr absl::string_view kHlo = R"(
555+
HloModule test_module
556+
557+
add {
558+
x = f32[] parameter(0)
559+
y = f32[] parameter(1)
560+
ROOT add = f32[] add(x, y)
561+
}
562+
563+
fused_computation {
564+
p0 = f32[64,16384]{1,0} parameter(0)
565+
c0 = f32[] constant(0)
566+
ROOT reduce.858 = f32[64]{0} reduce(p0, c0), dimensions={1}, to_apply=add
567+
}
568+
569+
ENTRY main {
570+
p0 = f32[64,16384]{1,0} parameter(0)
571+
fusion = f32[64]{0} fusion(p0), kind=kInput, calls=fused_computation
572+
log = f32[64]{0} log(fusion)
573+
negate = f32[64]{0} custom-call(log), custom_call_target="negate"
574+
ROOT add = f32[64]{0} add(negate, log)
575+
}
576+
)";
577+
578+
RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
579+
CHECK: ENTRY
580+
CHECK: %[[PARAM:.*]] = {{.*}} parameter(0)
581+
CHECK: %[[FUSION:.*]] = {{.*}} fusion(%[[PARAM]])
582+
CHECK: custom-call(%[[FUSION]])
583+
)");
584+
}
585+
552586
} // namespace gpu
553587
} // namespace xla

0 commit comments

Comments
 (0)