Skip to content

Commit 1bb87a1

Browse files
jreiffers authored and copybara-github committed
HloFusionAnalysis: Don't peek outside the fusion boundary.
Currently, we use the FindNonTrivialHero overload that doesn't take a boundary function. This isn't always correct. PiperOrigin-RevId: 572576780
1 parent cbf1b8e commit 1bb87a1

File tree

6 files changed

+133
-11
lines changed

6 files changed

+133
-11
lines changed

xla/service/gpu/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3300,6 +3300,20 @@ cc_library(
33003300
],
33013301
)
33023302

3303+
xla_cc_test(
3304+
name = "hlo_fusion_analysis_test",
3305+
srcs = ["hlo_fusion_analysis_test.cc"],
3306+
deps = [
3307+
":backend_configs_cc",
3308+
":gpu_device_info_for_tests",
3309+
":hlo_fusion_analysis",
3310+
":hlo_traversal",
3311+
"//xla/tests:hlo_test_base",
3312+
"//xla/tests:xla_internal_test_main",
3313+
"@tsl//tsl/platform:statusor",
3314+
],
3315+
)
3316+
33033317
cc_library(
33043318
name = "gpu_performance_model",
33053319
srcs = ["gpu_performance_model.cc"],

xla/service/gpu/hlo_fusion_analysis.cc

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
262262
std::vector<const HloInstruction*> heroes;
263263
heroes.reserve(hlo_roots.size());
264264
for (auto* root : hlo_roots) {
265-
heroes.push_back(&FindNonTrivialHero(*root));
265+
heroes.push_back(&FindNonTrivialHero(*root, boundary_fn));
266266
}
267267

268268
std::vector<const HloInstruction*> fusion_arguments;
@@ -306,28 +306,26 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind()
306306
return EmitterFusionKind::kTriton;
307307
}
308308
#endif
309-
const auto& roots = fusion_roots();
310-
311-
if (absl::c_any_of(roots, [](const HloInstruction* root) {
312-
return IsRealReductionHero(*root, FindNonTrivialHero(*root));
313-
})) {
314-
return EmitterFusionKind::kReduction;
309+
for (auto [root, hero] : llvm::zip(fusion_roots_, fusion_heroes_)) {
310+
if (IsRealReductionHero(*root, *hero)) {
311+
return EmitterFusionKind::kReduction;
312+
}
315313
}
316314

317315
// We expect that the last dimension is swapped with a different dimension.
318316
if (HasConsistentTransposeHeros() && tiled_transpose_->permutation[2] != 2) {
319317
return EmitterFusionKind::kTranspose;
320318
}
321319

322-
if (roots.size() > 1) {
323-
if (IsInputFusibleNonStridedSlices(roots) &&
324-
AllSliceInputsAreCompatible(roots)) {
320+
if (fusion_roots_.size() > 1) {
321+
if (IsInputFusibleNonStridedSlices(fusion_roots_) &&
322+
AllSliceInputsAreCompatible(fusion_roots_)) {
325323
return EmitterFusionKind::kInputSlices;
326324
}
327325
return EmitterFusionKind::kLoop;
328326
}
329327

330-
if (roots[0]->opcode() == HloOpcode::kScatter) {
328+
if (fusion_roots_[0]->opcode() == HloOpcode::kScatter) {
331329
return EmitterFusionKind::kScatter;
332330
}
333331

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "xla/service/gpu/hlo_fusion_analysis.h"
16+
17+
#include "xla/service/gpu/backend_configs.pb.h"
18+
#include "xla/service/gpu/gpu_device_info_for_tests.h"
19+
#include "xla/service/gpu/hlo_traversal.h"
20+
#include "xla/tests/hlo_test_base.h"
21+
#include "tsl/platform/statusor.h"
22+
23+
namespace xla::gpu {
24+
namespace {
25+
26+
class HloFusionAnalysisTest : public HloTestBase {};
27+
28+
TEST_F(HloFusionAnalysisTest, DoesNotPeekOutsideBoundary) {
29+
auto module = ParseAndReturnVerifiedModule(R"(
30+
HloModule test_module
31+
32+
add {
33+
p0 = f32[] parameter(0)
34+
p1 = f32[] parameter(1)
35+
ROOT add = f32[] add(p0, p1)
36+
}
37+
38+
ENTRY main {
39+
%p0 = f32[1024] parameter(0)
40+
%p1 = f32[] parameter(1)
41+
%reduce = f32[] reduce(%p0, %p1), dimensions={0}, to_apply=add
42+
ROOT %bitcast = s32[] bitcast(%reduce)
43+
})")
44+
.value();
45+
46+
auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
47+
48+
auto* root = module->entry_computation()->root_instruction();
49+
TF_ASSERT_OK_AND_ASSIGN(
50+
auto analysis, HloFusionAnalysis::Create(
51+
FusionBackendConfig::default_instance(), {root},
52+
MakeSingleInstructionFusion(*root), &device_info));
53+
EXPECT_EQ(analysis.GetEmitterFusionKind(),
54+
HloFusionAnalysis::EmitterFusionKind::kLoop);
55+
56+
TF_ASSERT_OK_AND_ASSIGN(
57+
auto analysis_fused,
58+
HloFusionAnalysis::Create(FusionBackendConfig::default_instance(), {root},
59+
DefaultFusionBoundaryFn, &device_info));
60+
EXPECT_EQ(analysis_fused.GetEmitterFusionKind(),
61+
HloFusionAnalysis::EmitterFusionKind::kReduction);
62+
}
63+
64+
} // namespace
65+
} // namespace xla::gpu

xla/service/gpu/hlo_traversal.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ FusionBoundaryFn MakeProducerConsumerFusion(
6666
};
6767
}
6868

69+
FusionBoundaryFn MakeSingleInstructionFusion(const HloInstruction& root) {
70+
if (root.opcode() == HloOpcode::kFusion) {
71+
return DefaultFusionBoundaryFn;
72+
}
73+
return [](const HloInstruction&, const HloInstruction&) { return true; };
74+
}
75+
6976
void HloBfsConsumersFirstTraversal(
7077
absl::Span<const HloInstruction* const> roots,
7178
const FusionBoundaryFn& boundary,

xla/service/gpu/hlo_traversal.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ bool DefaultFusionBoundaryFn(const HloInstruction& producer,
4646
FusionBoundaryFn MakeProducerConsumerFusion(
4747
const HloInstruction& fused_producer, const HloInstruction& fused_consumer);
4848

49+
// Creates a fusion boundary function for a fusion consisting only of `root`. If
50+
// `root` is a fusion, the result is the same as `DefaultFusionBoundaryFn`. If
51+
// `root` is the root of a fusion, the result is just that root, not the entire
52+
// computation.
53+
FusionBoundaryFn MakeSingleInstructionFusion(const HloInstruction& root);
54+
4955
// Visit the HLO nodes starting from `roots` in BFS order (consumers before
5056
// producers). Each node will be visited exactly once. The graph is not
5157
// traversed along edges for which `boundary` returns true.

xla/service/gpu/hlo_traversal_test.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,38 @@ TEST_F(HloTraversalTest, FuseFusionConsumerAndProducer) {
311311
EXPECT_THAT(params, ElementsAre("negate", "p0"));
312312
}
313313

314+
TEST_F(HloTraversalTest, SingleInstructionFusionOfFusion) {
315+
auto module = ParseAndReturnVerifiedModule(kTwoFusions).value();
316+
auto* fusion =
317+
module->entry_computation()->GetInstructionWithName("fusion.1");
318+
319+
auto boundary = MakeSingleInstructionFusion(*fusion);
320+
std::vector<std::string> nodes;
321+
HloBfsConsumersFirstTraversal({fusion}, boundary,
322+
[&](const HloInstruction& node) {
323+
nodes.emplace_back(node.name());
324+
return TraversalResult::kVisitOperands;
325+
});
326+
327+
EXPECT_THAT(nodes,
328+
ElementsAre("fusion.1", "reduce.1", "mul", "p0.1", "p1.1"));
329+
}
330+
331+
TEST_F(HloTraversalTest, SingleInstructionFusionOfInstruction) {
332+
auto module = ParseAndReturnVerifiedModule(kTwoFusions).value();
333+
auto* negate = module->entry_computation()->GetInstructionWithName("negate");
334+
335+
auto boundary = MakeSingleInstructionFusion(*negate);
336+
std::vector<std::string> nodes;
337+
HloBfsConsumersFirstTraversal({negate}, boundary,
338+
[&](const HloInstruction& node) {
339+
nodes.emplace_back(node.name());
340+
return TraversalResult::kVisitOperands;
341+
});
342+
343+
EXPECT_THAT(nodes, ElementsAre("negate"));
344+
}
345+
314346
} // namespace
315347
} // namespace gpu
316348
} // namespace xla

0 commit comments

Comments
 (0)