Skip to content

Commit 58e4eea

Browse files
ZixuanJiangGoogle-ML-Automation
authored andcommitted
[XLA:SPMD] Fix a bug in PartitionGatherTrivialSlicedOperandDimensions.
Previously, we use `indices.CloneWithNewHlo(filter)` to get the partitioned filter, which is wrong since `indices` and `filter` have different datatypes (S32 vs. PRED). With this change, we create `PartitionedHlo` for `filter` explicitly. PiperOrigin-RevId: 700147197
1 parent 402bf17 commit 58e4eea

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

xla/service/spmd/gather_scatter_handler.cc

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -530,18 +530,23 @@ absl::StatusOr<HloInstruction*> PartitionGatherTrivialSlicedOperandDimensions(
530530
pshape, output_grouped.sharding, batch_dims,
531531
slice_sizes, visitor, allow_recursive));
532532
// Mask out invalid results.
533-
auto filter = b->AddInstruction(HloInstruction::CreateCompare(
534-
ShapeUtil::ChangeElementType(indices.hlo()->shape(), PRED),
535-
indices.hlo(), indices_min, ComparisonDirection::kLt));
536-
filter = b->AddInstruction(HloInstruction::CreateBinary(
537-
filter->shape(), HloOpcode::kOr, filter,
538-
b->AddInstruction(HloInstruction::CreateCompare(
539-
ShapeUtil::ChangeElementType(indices.hlo()->shape(), PRED),
540-
indices.hlo(), indices_max, ComparisonDirection::kGt))));
533+
const Shape filter_shape =
534+
ShapeUtil::ChangeElementType(indices.hlo()->shape(), PRED);
535+
const Shape filter_base_shape =
536+
ShapeUtil::ChangeElementType(indices.base_shape(), PRED);
537+
HloInstruction* compare_lt = b->AddInstruction(
538+
HloInstruction::CreateCompare(filter_shape, indices.hlo(), indices_min,
539+
ComparisonDirection::kLt));
540+
HloInstruction* compare_gt = b->AddInstruction(
541+
HloInstruction::CreateCompare(filter_shape, indices.hlo(), indices_max,
542+
ComparisonDirection::kGt));
543+
HloInstruction* filter = b->AddInstruction(HloInstruction::CreateBinary(
544+
filter_shape, HloOpcode::kOr, compare_lt, compare_gt));
545+
filter->set_sharding(indices.hlo()->sharding());
541546
// Make sure that filter is of the same shape on the index pass-through
542547
// dimensions as the partitioned gather output, since we will need to filter
543548
// the gather output later.
544-
PartitionedHlo pfilter = indices.CloneWithNewHlo(filter);
549+
PartitionedHlo pfilter(filter, filter_base_shape, indices.state());
545550
pfilter = pfilter.Reshard(hlo_sharding_util::GatherIndexShardingFromOutput(
546551
hlo_sharding_util::UngroupSharding(output_grouped), gather));
547552
filter = pfilter.hlo();

0 commit comments

Comments
 (0)