@@ -9333,10 +9333,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEnd) {
9333
9333
/* alternate_memory_size=*/ 1024 ,
9334
9334
loop_start_idx, &optimizer));
9335
9335
9336
- LOG (INFO) << " Running Optimize" ;
9337
9336
optimizer->Optimize ();
9338
-
9339
- LOG (INFO) << " Running MSA" ;
9340
9337
TF_ASSERT_OK_AND_ASSIGN (auto preset_assignments,
9341
9338
RunMsa (module .get (), /* alternate_memory_size=*/ 1024 ));
9342
9339
@@ -9394,7 +9391,6 @@ ENTRY entry {
9394
9391
9395
9392
TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (hlo_str));
9396
9393
9397
- LOG (INFO) << " Running MSA" ;
9398
9394
TF_ASSERT_OK_AND_ASSIGN (auto preset_assignments,
9399
9395
RunMsa (module .get (), /* alternate_memory_size=*/ 512 ));
9400
9396
@@ -9427,5 +9423,108 @@ ENTRY entry {
9427
9423
EXPECT_EQ (prefetch_distance (next_copy_done), prefetch_distance (copy_done));
9428
9424
}
9429
9425
9426
// End-to-end test on an HLO module whose outer while loop body (`while_body`)
// itself contains three while loops (`prev_while`, `current_while`,
// `next_while`), each with its own condition and body computation. The test
// parses the module and runs RunMsa with an alternate memory size of 512; it
// makes no further checks on the returned preset assignments.
// NOTE(review): the test name suggests this is a regression test for a bug
// triggered by while loops nested inside the loop body — confirm against the
// change description.
+ TEST_F (MemoryBoundLoopOptimizerTest, OptimizerEndToEndNestedWhileLoopBug) {
9427
+ absl::string_view hlo_str = R"(
9428
+ HloModule module, is_scheduled=true
9429
+
9430
+ prev_while_cond {
9431
+ prev_while_cond_param = (f32[1,4], pred[]) parameter(0)
9432
+ ROOT p = pred[] get-tuple-element(prev_while_cond_param), index=1
9433
+ }
9434
+
9435
+ prev_while_body {
9436
+ prev_while_body_param = (f32[1,4], pred[]) parameter(0)
9437
+ prev_while_body_gte = f32[1,4] get-tuple-element(prev_while_body_param), index=0
9438
+ prev_while_body_pred = pred[] get-tuple-element(prev_while_body_param), index=1
9439
+ prev_while_body_op = f32[1,4] negate(prev_while_body_gte)
9440
+ ROOT prev_while_body_root = (f32[1,4], pred[]) tuple(prev_while_body_op, prev_while_body_pred)
9441
+ }
9442
+
9443
+ current_while_cond {
9444
+ current_while_cond_param = (f32[1,4], pred[]) parameter(0)
9445
+ ROOT p = pred[] get-tuple-element(current_while_cond_param), index=1
9446
+ }
9447
+
9448
+ current_while_body {
9449
+ current_while_body_param = (f32[1,4], pred[]) parameter(0)
9450
+ current_while_body_gte = f32[1,4] get-tuple-element(current_while_body_param), index=0
9451
+ current_while_body_pred = pred[] get-tuple-element(current_while_body_param), index=1
9452
+ current_while_body_op = f32[1,4] negate(current_while_body_gte)
9453
+ ROOT current_while_body_root = (f32[1,4], pred[]) tuple(current_while_body_op, current_while_body_pred)
9454
+ }
9455
+
9456
+ next_while_cond {
9457
+ next_while_cond_param = (f32[1,4], pred[]) parameter(0)
9458
+ ROOT p = pred[] get-tuple-element(next_while_cond_param), index=1
9459
+ }
9460
+
9461
+ next_while_body {
9462
+ next_while_body_param = (f32[1,4], pred[]) parameter(0)
9463
+ next_while_body_gte = f32[1,4] get-tuple-element(next_while_body_param), index=0
9464
+ next_while_body_pred = pred[] get-tuple-element(next_while_body_param), index=1
9465
+ next_while_body_op = f32[1,4] negate(next_while_body_gte)
9466
+ ROOT next_while_body_root = (f32[1,4], pred[]) tuple(next_while_body_op, next_while_body_pred)
9467
+ }
9468
+
9469
+ while_cond {
9470
+ while_cond_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0)
9471
+ ROOT p = pred[] get-tuple-element(while_cond_param), index=6
9472
+ }
9473
+
9474
+ while_body {
9475
+ while_body_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0)
9476
+ prev_param0 = f32[1,4] get-tuple-element(while_body_param), index=0
9477
+ param0 = f32[1,4] get-tuple-element(while_body_param), index=1
9478
+ next_param0 = f32[1,4] get-tuple-element(while_body_param), index=2
9479
+ prev_prev_op3 = f32[1,4] get-tuple-element(while_body_param), index=3
9480
+ prev_prev_op4 = f32[1,4] get-tuple-element(while_body_param), index=4
9481
+ while_pred = pred[] get-tuple-element(while_body_param), index=6
9482
+ prev_op0 = f32[1,4] add(f32[1,4] prev_prev_op3, f32[1,4] prev_prev_op4)
9483
+ prev_op1 = f32[1,4] add(f32[1,4] prev_prev_op4, f32[1,4] prev_op0)
9484
+ prev_op2 = f32[1,4] add(f32[1,4] prev_op0, f32[1,4] prev_op1)
9485
+ prev_op3 = f32[1,4] add(f32[1,4] prev_op1, f32[1,4] prev_op2)
9486
+ prev_tuple = (f32[1,4], pred[]) tuple(prev_op3, while_pred)
9487
+ prev_while = (f32[1,4], pred[]) while(prev_tuple), condition=prev_while_cond, body=prev_while_body
9488
+ prev_gte = f32[1,4] get-tuple-element(prev_while), index=0
9489
+ prev_op4 = f32[1,4] multiply(f32[1,4] prev_param0, f32[1,4] prev_gte)
9490
+ op0 = f32[1,4] add(f32[1,4] prev_op3, f32[1,4] prev_op4)
9491
+ op1 = f32[1,4] add(f32[1,4] prev_op4, f32[1,4] op0)
9492
+ op2 = f32[1,4] add(f32[1,4] op0, f32[1,4] op1)
9493
+ op3 = f32[1,4] add(f32[1,4] op1, f32[1,4] op2)
9494
+ current_tuple = (f32[1,4], pred[]) tuple(op3, while_pred)
9495
+ current_while = (f32[1,4], pred[]) while(current_tuple), condition=current_while_cond, body=current_while_body
9496
+ current_gte = f32[1,4] get-tuple-element(current_while), index=0
9497
+ op4 = f32[1,4] multiply(f32[1,4] param0, f32[1,4] current_gte)
9498
+ next_op0 = f32[1,4] add(f32[1,4] op3, f32[1,4] op4)
9499
+ next_op1 = f32[1,4] add(f32[1,4] op4, f32[1,4] next_op0)
9500
+ next_op2 = f32[1,4] add(f32[1,4] next_op0, f32[1,4] next_op1)
9501
+ next_op3 = f32[1,4] add(f32[1,4] next_op1, f32[1,4] next_op2)
9502
+ next_tuple = (f32[1,4], pred[]) tuple(next_op3, while_pred)
9503
+ next_while = (f32[1,4], pred[]) while(next_tuple), condition=next_while_cond, body=next_while_body
9504
+ next_gte = f32[1,4] get-tuple-element(next_while), index=0
9505
+ next_op4 = f32[1,4] multiply(f32[1,4] next_param0, f32[1,4] next_gte)
9506
+ ROOT root = tuple(prev_param0, param0, next_param0, prev_prev_op3, prev_prev_op4, next_op4, while_pred)
9507
+ }
9508
+
9509
+ ENTRY entry {
9510
+ p0 = f32[1,4] parameter(0)
9511
+ p1 = f32[1,4] parameter(1)
9512
+ p2 = f32[1,4] parameter(2)
9513
+ p3 = f32[1,4] parameter(3)
9514
+ p4 = f32[1,4] parameter(4)
9515
+ p5 = pred[] parameter(5)
9516
+ copy = f32[1,4] copy(p4)
9517
+ tuple = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) tuple(p0, p1, p2, p3, p4, copy, p5)
9518
+ while = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) while(tuple), condition=while_cond, body=while_body
9519
+ ROOT root = f32[1,4] get-tuple-element(while), index=5
9520
+ }
9521
+ )" ;
9522
+
9523
// Parse the HLO text above into a module.
// NOTE(review): TF_ASSERT_OK_AND_ASSIGN presumably aborts the test on a
// non-OK status — confirm against the macro definition.
+ TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (hlo_str));
9524
+
9525
// Run RunMsa on the parsed module with alternate_memory_size=512. The
// resulting `preset_assignments` is not inspected; the test only requires
// that this step succeeds.
+ TF_ASSERT_OK_AND_ASSIGN (auto preset_assignments,
9526
+ RunMsa (module .get (), /* alternate_memory_size=*/ 512 ));
9527
+ }
9528
+
9430
9529
} // namespace
9431
9530
} // namespace xla
0 commit comments