Skip to content

Commit f71cb9d

Browse files
berkinilbeyicopybara-github
authored andcommitted
[XLA] Verify that async computations are trivial (contain only a root and parameter instructions).
PiperOrigin-RevId: 566509445
1 parent cbd752f commit f71cb9d

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

xla/service/hlo_verifier.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2240,6 +2240,18 @@ Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
22402240
return OkStatus();
22412241
}
22422242

2243+
// Checks that the asynchronous computation only has a root and parameter
2244+
// instructions.
2245+
Status VerifyAsyncComputation(const HloComputation* async_computation) {
2246+
if (!async_computation->CanExpandIntoSingleInstruction()) {
2247+
return FailedPrecondition(
2248+
"Asynchronous computation %s expected to contain only the root and "
2249+
"parameter instructions.",
2250+
async_computation->name());
2251+
}
2252+
return OkStatus();
2253+
}
2254+
22432255
// Checks that AllReduce instructions in the module are either all layout
22442256
// constrained or all unconstrained.
22452257
Status VerifyLayoutConstrainedAllReduce(const HloModule& module) {
@@ -2849,6 +2861,9 @@ StatusOr<bool> HloVerifier::Run(
28492861
for (auto* computation : module->computations(execution_threads)) {
28502862
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
28512863
TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
2864+
if (computation->IsAsyncComputation()) {
2865+
TF_RETURN_IF_ERROR(VerifyAsyncComputation(computation));
2866+
}
28522867
}
28532868

28542869
TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));

xla/service/hlo_verifier_test.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,33 @@ TEST_F(HloVerifierTest, AsyncUpdateWrongType) {
12191219
"async-update expects the shape of operand and output to match"));
12201220
}
12211221

1222+
TEST_F(HloVerifierTest, AsyncOpComputationNotTrivial) {
1223+
const char* const hlo_string = R"(
1224+
HloModule Module
1225+
1226+
async_computation {
1227+
p = f32[2,3] parameter(0)
1228+
copy = f32[2,3] copy(p)
1229+
ROOT custom-call = f32[3,2] custom-call(copy), custom_call_target="foo"
1230+
}
1231+
1232+
ENTRY AsyncStartAndAsyncDone {
1233+
p0 = f32[2,3] parameter(0)
1234+
async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
1235+
ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
1236+
}
1237+
)";
1238+
TF_ASSERT_OK_AND_ASSIGN(auto module,
1239+
ParseAndReturnUnverifiedModule(hlo_string));
1240+
1241+
auto status = verifier().Run(module.get()).status();
1242+
ASSERT_FALSE(status.ok());
1243+
EXPECT_THAT(
1244+
status.message(),
1245+
HasSubstr(
1246+
"expected to contain only the root and parameter instructions"));
1247+
}
1248+
12221249
TEST_F(HloVerifierTestLayoutSensitive, AsyncDoneWrongGroupId) {
12231250
const char* const hlo_string = R"(
12241251
HloModule Module

0 commit comments

Comments
 (0)