Skip to content

Commit 68ee4f7

Browse files
berkinilbeyicopybara-github
authored andcommitted
[XLA] Disallow instructions from different parents when identifying loops for memory-bound loop optimization.
PiperOrigin-RevId: 535724784
1 parent dd04208 commit 68ee4f7

File tree

2 files changed

+109
-4
lines changed

2 files changed

+109
-4
lines changed

xla/service/memory_space_assignment.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,6 +3230,12 @@ void AlternateMemoryBestFitHeap::IdentifyAndOptimizeMemoryBoundLoops() {
32303230
// Found the start of the loop.
32313231
loop_start_idx = i;
32323232
}
3233+
if (inst->parent() != instruction_sequence[loop_start_idx]->parent()) {
3234+
VLOG(3) << "Mismatch (computation) at " << i << ": "
3235+
<< inst->parent()->name() << " vs "
3236+
<< instruction_sequence[loop_start_idx]->parent()->name();
3237+
break;
3238+
}
32333239
operand_distances.push_back({});
32343240
if (ignore_op(inst) || fingerprint_it == fingerprint_map_.end()) {
32353241
continue;

xla/service/memory_space_assignment_test.cc

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9333,10 +9333,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEnd) {
93339333
/*alternate_memory_size=*/1024,
93349334
loop_start_idx, &optimizer));
93359335

9336-
LOG(INFO) << "Running Optimize";
93379336
optimizer->Optimize();
9338-
9339-
LOG(INFO) << "Running MSA";
93409337
TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments,
93419338
RunMsa(module.get(), /*alternate_memory_size=*/1024));
93429339

@@ -9394,7 +9391,6 @@ ENTRY entry {
93949391

93959392
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str));
93969393

9397-
LOG(INFO) << "Running MSA";
93989394
TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments,
93999395
RunMsa(module.get(), /*alternate_memory_size=*/512));
94009396

@@ -9427,5 +9423,108 @@ ENTRY entry {
94279423
EXPECT_EQ(prefetch_distance(next_copy_done), prefetch_distance(copy_done));
94289424
}
94299425

9426+
TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndNestedWhileLoopBug) {
9427+
absl::string_view hlo_str = R"(
9428+
HloModule module, is_scheduled=true
9429+
9430+
prev_while_cond {
9431+
prev_while_cond_param = (f32[1,4], pred[]) parameter(0)
9432+
ROOT p = pred[] get-tuple-element(prev_while_cond_param), index=1
9433+
}
9434+
9435+
prev_while_body {
9436+
prev_while_body_param = (f32[1,4], pred[]) parameter(0)
9437+
prev_while_body_gte = f32[1,4] get-tuple-element(prev_while_body_param), index=0
9438+
prev_while_body_pred = pred[] get-tuple-element(prev_while_body_param), index=1
9439+
prev_while_body_op = f32[1,4] negate(prev_while_body_gte)
9440+
ROOT prev_while_body_root = (f32[1,4], pred[]) tuple(prev_while_body_op, prev_while_body_pred)
9441+
}
9442+
9443+
current_while_cond {
9444+
current_while_cond_param = (f32[1,4], pred[]) parameter(0)
9445+
ROOT p = pred[] get-tuple-element(current_while_cond_param), index=1
9446+
}
9447+
9448+
current_while_body {
9449+
current_while_body_param = (f32[1,4], pred[]) parameter(0)
9450+
current_while_body_gte = f32[1,4] get-tuple-element(current_while_body_param), index=0
9451+
current_while_body_pred = pred[] get-tuple-element(current_while_body_param), index=1
9452+
current_while_body_op = f32[1,4] negate(current_while_body_gte)
9453+
ROOT current_while_body_root = (f32[1,4], pred[]) tuple(current_while_body_op, current_while_body_pred)
9454+
}
9455+
9456+
next_while_cond {
9457+
next_while_cond_param = (f32[1,4], pred[]) parameter(0)
9458+
ROOT p = pred[] get-tuple-element(next_while_cond_param), index=1
9459+
}
9460+
9461+
next_while_body {
9462+
next_while_body_param = (f32[1,4], pred[]) parameter(0)
9463+
next_while_body_gte = f32[1,4] get-tuple-element(next_while_body_param), index=0
9464+
next_while_body_pred = pred[] get-tuple-element(next_while_body_param), index=1
9465+
next_while_body_op = f32[1,4] negate(next_while_body_gte)
9466+
ROOT next_while_body_root = (f32[1,4], pred[]) tuple(next_while_body_op, next_while_body_pred)
9467+
}
9468+
9469+
while_cond {
9470+
while_cond_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0)
9471+
ROOT p = pred[] get-tuple-element(while_cond_param), index=6
9472+
}
9473+
9474+
while_body {
9475+
while_body_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0)
9476+
prev_param0 = f32[1,4] get-tuple-element(while_body_param), index=0
9477+
param0 = f32[1,4] get-tuple-element(while_body_param), index=1
9478+
next_param0 = f32[1,4] get-tuple-element(while_body_param), index=2
9479+
prev_prev_op3 = f32[1,4] get-tuple-element(while_body_param), index=3
9480+
prev_prev_op4 = f32[1,4] get-tuple-element(while_body_param), index=4
9481+
while_pred = pred[] get-tuple-element(while_body_param), index=6
9482+
prev_op0 = f32[1,4] add(f32[1,4] prev_prev_op3, f32[1,4] prev_prev_op4)
9483+
prev_op1 = f32[1,4] add(f32[1,4] prev_prev_op4, f32[1,4] prev_op0)
9484+
prev_op2 = f32[1,4] add(f32[1,4] prev_op0, f32[1,4] prev_op1)
9485+
prev_op3 = f32[1,4] add(f32[1,4] prev_op1, f32[1,4] prev_op2)
9486+
prev_tuple = (f32[1,4], pred[]) tuple(prev_op3, while_pred)
9487+
prev_while = (f32[1,4], pred[]) while(prev_tuple), condition=prev_while_cond, body=prev_while_body
9488+
prev_gte = f32[1,4] get-tuple-element(prev_while), index=0
9489+
prev_op4 = f32[1,4] multiply(f32[1,4] prev_param0, f32[1,4] prev_gte)
9490+
op0 = f32[1,4] add(f32[1,4] prev_op3, f32[1,4] prev_op4)
9491+
op1 = f32[1,4] add(f32[1,4] prev_op4, f32[1,4] op0)
9492+
op2 = f32[1,4] add(f32[1,4] op0, f32[1,4] op1)
9493+
op3 = f32[1,4] add(f32[1,4] op1, f32[1,4] op2)
9494+
current_tuple = (f32[1,4], pred[]) tuple(op3, while_pred)
9495+
current_while = (f32[1,4], pred[]) while(current_tuple), condition=current_while_cond, body=current_while_body
9496+
current_gte = f32[1,4] get-tuple-element(current_while), index=0
9497+
op4 = f32[1,4] multiply(f32[1,4] param0, f32[1,4] current_gte)
9498+
next_op0 = f32[1,4] add(f32[1,4] op3, f32[1,4] op4)
9499+
next_op1 = f32[1,4] add(f32[1,4] op4, f32[1,4] next_op0)
9500+
next_op2 = f32[1,4] add(f32[1,4] next_op0, f32[1,4] next_op1)
9501+
next_op3 = f32[1,4] add(f32[1,4] next_op1, f32[1,4] next_op2)
9502+
next_tuple = (f32[1,4], pred[]) tuple(next_op3, while_pred)
9503+
next_while = (f32[1,4], pred[]) while(next_tuple), condition=next_while_cond, body=next_while_body
9504+
next_gte = f32[1,4] get-tuple-element(next_while), index=0
9505+
next_op4 = f32[1,4] multiply(f32[1,4] next_param0, f32[1,4] next_gte)
9506+
ROOT root = tuple(prev_param0, param0, next_param0, prev_prev_op3, prev_prev_op4, next_op4, while_pred)
9507+
}
9508+
9509+
ENTRY entry {
9510+
p0 = f32[1,4] parameter(0)
9511+
p1 = f32[1,4] parameter(1)
9512+
p2 = f32[1,4] parameter(2)
9513+
p3 = f32[1,4] parameter(3)
9514+
p4 = f32[1,4] parameter(4)
9515+
p5 = pred[] parameter(5)
9516+
copy = f32[1,4] copy(p4)
9517+
tuple = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) tuple(p0, p1, p2, p3, p4, copy, p5)
9518+
while = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) while(tuple), condition=while_cond, body=while_body
9519+
ROOT root = f32[1,4] get-tuple-element(while), index=5
9520+
}
9521+
)";
9522+
9523+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str));
9524+
9525+
TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments,
9526+
RunMsa(module.get(), /*alternate_memory_size=*/512));
9527+
}
9528+
94309529
} // namespace
94319530
} // namespace xla

0 commit comments

Comments
 (0)