Skip to content

Commit 57ceba0

Browse files
berkinilbeyicopybara-github
authored andcommitted
[XLA] Use max size of buffer in heap simulator and buffer assignment and make size DCHECKs into CHECKs
PiperOrigin-RevId: 564421004
1 parent 2b693f7 commit 57ceba0

File tree

6 files changed

+127
-12
lines changed

6 files changed

+127
-12
lines changed

xla/service/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,9 +1704,12 @@ xla_cc_test(
17041704
name = "heap_simulator_test",
17051705
srcs = ["heap_simulator_test.cc"],
17061706
deps = [
1707+
":async_op_canonicalizer",
17071708
":buffer_value",
17081709
":heap_simulator",
1710+
":hlo_dce",
17091711
":hlo_ordering",
1712+
":hlo_parser",
17101713
":hlo_value",
17111714
":tuple_points_to_analysis",
17121715
"//xla:literal",

xla/service/buffer_assignment.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,9 @@ class BufferAssignment {
551551
BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
552552

553553
int64_t HloBufferSize(const HloBuffer& buffer) {
554-
int64_t result = buffer_size_(*buffer.values()[0]);
554+
int64_t result = 0;
555555
for (const HloValue* value : buffer.values()) {
556-
DCHECK_EQ(result, buffer_size_(*value));
556+
result = std::max(result, buffer_size_(*value));
557557
}
558558
return result;
559559
}

xla/service/buffer_assignment_test.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2999,6 +2999,45 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] {
29992999
get_slice("negate_5", {}) == get_slice("negate_1", {}));
30003000
}
30013001

3002+
TEST_F(BufferAssignmentTest, AsyncCallImplicitSharding) {
3003+
std::string hlo_string = R"(
3004+
HloModule module, is_scheduled=true
3005+
3006+
called_computation {
3007+
param0 = f32[4] parameter(0)
3008+
constant = f32[1] constant(1)
3009+
dynamic-update-slice = f32[4] dynamic-update-slice(param0, constant, constant)
3010+
ROOT negate = f32[4] negate(dynamic-update-slice)
3011+
}
3012+
3013+
ENTRY entry {
3014+
p0 = f32[8] parameter(0)
3015+
call-start = ((f32[8]), f32[8], s32[]) call-start(p0), async_execution_thread="foo", to_apply=called_computation
3016+
ROOT call-done = f32[8] call-done(call-start), async_execution_thread="foo", to_apply=called_computation
3017+
}
3018+
)";
3019+
3020+
TF_ASSERT_OK_AND_ASSIGN(auto module,
3021+
ParseAndReturnUnverifiedModule(hlo_string));
3022+
AsyncOpCanonicalizer canonicalizer;
3023+
TF_ASSERT_OK(canonicalizer.Run(module.get()).status());
3024+
HloDCE dce;
3025+
TF_ASSERT_OK(dce.Run(module.get()).status());
3026+
3027+
auto buffers = RunBufferAssignmentWithSequentialOrdering(module.get());
3028+
3029+
LOG(INFO) << buffers->ToString();
3030+
3031+
auto get_slice = [&](std::string_view hlo_name, const ShapeIndex& index) {
3032+
return buffers
3033+
->GetUniqueSlice(FindInstruction(module.get(), hlo_name), index)
3034+
.value();
3035+
};
3036+
3037+
EXPECT_EQ(get_slice("p0", {}).size(), 32);
3038+
EXPECT_EQ(get_slice("dynamic-update-slice", {}).size(), 32);
3039+
}
3040+
30023041
TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
30033042
absl::string_view module_str = R"(
30043043
HloModule test_module

xla/service/heap_simulator.cc

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,17 @@ Status HeapSimulator::RunComputation(
257257

258258
VLOG(1) << "Program time" << hlo_live_range->schedule_end_time();
259259

260+
// Populate buffer sizes with the maximum size of the constituent HloValues.
261+
for (const HloBuffer& buffer : alias_analysis.buffers()) {
262+
int64_t size = 0;
263+
for (const HloValue* value : buffer.values()) {
264+
size = std::max(size, size_fn_(*value));
265+
}
266+
for (const HloValue* value : buffer.values()) {
267+
buffer_sizes_[value] = size;
268+
}
269+
}
270+
260271
// Go through each step in the program and replay each buffer define and free
261272
// events.
262273
for (int64_t i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
@@ -406,7 +417,7 @@ void HeapSimulator::Alloc(const HloValue* buffer,
406417
<< "Alloc called on freed buffer: " << *buffer;
407418

408419
allocated_buffers_.insert(buffer);
409-
const int64_t size = size_fn_(*buffer);
420+
const int64_t size = GetBufferSize(buffer);
410421
algorithm_->Alloc(buffer, size);
411422
no_fragmentation_stats_->Alloc(buffer, size);
412423
FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
@@ -419,7 +430,7 @@ void HeapSimulator::Alloc(const HloValue* buffer,
419430
// causes Free to be called on the underlying algorithm.
420431
void HeapSimulator::Free(const HloValue* buffer,
421432
const HloInstruction* instruction) {
422-
const int64_t size = size_fn_(*buffer);
433+
const int64_t size = GetBufferSize(buffer);
423434
algorithm_->Free(buffer, size);
424435
no_fragmentation_stats_->Free(buffer, size);
425436
FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
@@ -432,12 +443,18 @@ void HeapSimulator::Free(const HloValue* buffer,
432443
// SharedGroup.
433444
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
434445
const HloInstruction* instruction) {
435-
algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
436-
no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
446+
algorithm_->ShareWith(buffer, shared, GetBufferSize(shared));
447+
no_fragmentation_stats_->ShareWith(buffer, shared, GetBufferSize(shared));
437448
FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
438449
shared);
439450
}
440451

452+
int64_t HeapSimulator::GetBufferSize(const HloValue* buffer) const {
453+
auto it = buffer_sizes_.find(buffer);
454+
CHECK(it != buffer_sizes_.end());
455+
return it->second;
456+
}
457+
441458
HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
442459
Result<HloValue> result = algorithm_->Finish();
443460

@@ -591,7 +608,7 @@ void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
591608

592609
auto emplace_result = buffer_intervals_.emplace(
593610
buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
594-
DCHECK(emplace_result.second);
611+
CHECK(emplace_result.second);
595612
++current_time_;
596613
}
597614

@@ -603,11 +620,11 @@ void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
603620
result_.chunk_map.emplace(buffer, Chunk::FromOffsetSize(0, 0));
604621
return;
605622
}
606-
DCHECK_NE(buffer_intervals_.count(share_with), 0);
623+
CHECK_NE(buffer_intervals_.count(share_with), 0);
607624
buffer_intervals_[share_with].colocations.push_back(buffer);
608625
auto emplace_result = buffer_intervals_.emplace(
609626
buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
610-
DCHECK(emplace_result.second);
627+
CHECK(emplace_result.second);
611628
++current_time_;
612629
}
613630

@@ -638,9 +655,9 @@ void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
638655
return;
639656
}
640657
BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
641-
DCHECK_EQ(buffer_interval.buffer, buffer);
642-
DCHECK_EQ(buffer_interval.size, size);
643-
DCHECK_EQ(buffer_interval.end, -1);
658+
CHECK_EQ(buffer_interval.buffer, buffer);
659+
CHECK_EQ(buffer_interval.size, size);
660+
CHECK_EQ(buffer_interval.end, -1);
644661
if (buffer_interval.end != -1) {
645662
return;
646663
}

xla/service/heap_simulator.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ class HeapSimulator {
223223
void ShareBuffer(const HloValue* buffer, const HloValue* shared,
224224
const HloInstruction* instruction);
225225

226+
// Returns the size of the HloValue, which is the max size of the HloValues
227+
// that are part of the HloBuffer.
228+
int64_t GetBufferSize(const HloValue* buffer) const;
229+
226230
// Returns true if:
227231
// Two buffers belong to the same shared group.
228232
// Eight of the buffer has no shared group assigned.
@@ -253,6 +257,8 @@ class HeapSimulator {
253257
absl::flat_hash_set<const HloValue*> allocated_buffers_;
254258
absl::flat_hash_set<const HloValue*> freed_buffers_;
255259

260+
absl::flat_hash_map<const HloValue*, int64_t> buffer_sizes_;
261+
256262
// Debugging information filled in while the heap simulator runs.
257263
HeapSimulatorTrace debug_trace_;
258264
};

xla/service/heap_simulator_test.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ limitations under the License.
2929
#include "xla/hlo/ir/hlo_instruction.h"
3030
#include "xla/hlo/ir/hlo_module.h"
3131
#include "xla/literal.h"
32+
#include "xla/service/async_op_canonicalizer.h"
3233
#include "xla/service/buffer_value.h"
34+
#include "xla/service/hlo_dce.h"
3335
#include "xla/service/hlo_ordering.h"
36+
#include "xla/service/hlo_parser.h"
3437
#include "xla/service/hlo_value.h"
3538
#include "xla/service/tuple_points_to_analysis.h"
3639
#include "xla/status_macros.h"
@@ -940,6 +943,53 @@ TEST_F(HeapSimulatorTest, WholeModule) {
940943
});
941944
}
942945

946+
TEST_F(HeapSimulatorTest, AsyncCallImplicitSharding) {
947+
std::string hlo_string = R"(
948+
HloModule module, is_scheduled=true
949+
950+
called_computation {
951+
param0 = f32[4] parameter(0)
952+
constant = f32[1] constant(1)
953+
dynamic-update-slice = f32[4] dynamic-update-slice(param0, constant, constant)
954+
ROOT negate = f32[4] negate(dynamic-update-slice)
955+
}
956+
957+
ENTRY entry {
958+
p0 = f32[8] parameter(0)
959+
call-start = ((f32[8]), f32[8], s32[]) call-start(p0), async_execution_thread="foo", to_apply=called_computation
960+
ROOT call-done = f32[8] call-done(call-start), async_execution_thread="foo", to_apply=called_computation
961+
}
962+
)";
963+
964+
TF_ASSERT_OK_AND_ASSIGN(auto module,
965+
ParseAndReturnUnverifiedModule(hlo_string));
966+
AsyncOpCanonicalizer canonicalizer;
967+
TF_ASSERT_OK(canonicalizer.Run(module.get()).status());
968+
HloDCE dce;
969+
TF_ASSERT_OK(dce.Run(module.get()).status());
970+
TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis,
971+
HloAliasAnalysis::Run(module.get()));
972+
auto size_fn = [](const BufferValue& buffer) -> int64_t {
973+
const Shape& shape = buffer.shape();
974+
if (!shape.IsArray()) {
975+
return 0;
976+
}
977+
return ShapeUtil::ByteSizeOf(shape);
978+
};
979+
auto algorithm = std::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
980+
/*alignment=*/1);
981+
982+
HeapSimulator::Result<HloValue> result =
983+
HeapSimulator::Run(std::move(algorithm), *module, module->schedule(),
984+
*alias_analysis, size_fn)
985+
.value();
986+
for (const auto& [value, chunk] : result.heap_results[0].chunk_map) {
987+
if (value->instruction()->name() == "dynamic-update-slice") {
988+
EXPECT_EQ(chunk.size, 32);
989+
}
990+
}
991+
}
992+
943993
// Base class for heap algorithm tests.
944994
class HeapAlgorithmTestBase : public ::testing::Test {
945995
protected:

0 commit comments

Comments
 (0)