Skip to content

Commit 022bef7

Browse files
pschuhGoogle-ML-Automation
authored andcommitted
Add a common AbstractTrackedDeviceBuffer type which can be used by
a AbstractLocalPjRtBuffer to allow a unified implementation of donation logic. PiperOrigin-RevId: 744888717
1 parent 8710ae9 commit 022bef7

7 files changed

+573
-447
lines changed

xla/pjrt/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,28 @@ xla_cc_test(
8484
],
8585
)
8686

87+
cc_library(
88+
name = "abstract_tracked_device_buffer",
89+
srcs = ["abstract_tracked_device_buffer.cc"],
90+
hdrs = ["abstract_tracked_device_buffer.h"],
91+
deps = [
92+
":pjrt_client",
93+
":pjrt_future",
94+
"@com_google_absl//absl/base:core_headers",
95+
"@com_google_absl//absl/log:check",
96+
"@com_google_absl//absl/status",
97+
"@com_google_absl//absl/status:statusor",
98+
"@com_google_absl//absl/synchronization",
99+
"@tsl//tsl/profiler/lib:traceme",
100+
],
101+
)
102+
87103
cc_library(
88104
name = "tracked_device_buffer",
89105
srcs = ["tracked_device_buffer.cc"],
90106
hdrs = ["tracked_device_buffer.h"],
91107
deps = [
108+
":abstract_tracked_device_buffer",
92109
":event_pool",
93110
":pjrt_client",
94111
":pjrt_common",
@@ -504,6 +521,7 @@ cc_library(
504521
hdrs = ["pjrt_stream_executor_client.h"],
505522
visibility = internal_visibility(["//xla:friends"]),
506523
deps = [
524+
":abstract_tracked_device_buffer",
507525
":event_pool",
508526
":host_callback",
509527
":host_memory_spaces",
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/pjrt/abstract_tracked_device_buffer.h"
17+
18+
#include <memory>
19+
#include <utility>
20+
21+
#include "absl/base/thread_annotations.h"
22+
#include "absl/log/check.h"
23+
#include "absl/status/status.h"
24+
#include "absl/status/statusor.h"
25+
#include "absl/synchronization/mutex.h"
26+
#include "tsl/profiler/lib/traceme.h"
27+
28+
namespace xla {
29+
30+
CommonPjRtBuffer::CommonPjRtBuffer(
31+
std::unique_ptr<AbstractTrackedDeviceBuffer> device_buffer)
32+
: device_buffer_(std::move(device_buffer)) {
33+
for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
34+
holds_[i] = 0;
35+
}
36+
}
37+
38+
CommonPjRtBuffer::~CommonPjRtBuffer() {
39+
for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
40+
CHECK_EQ(holds_[i], 0) << "Non-zero type " << i << " hold on destruction.";
41+
}
42+
}
43+
44+
void CommonPjRtBuffer::WaitForOutstandingUsageHolds() {
45+
auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
46+
return holds_[ScopedHold::kUsage] == 0;
47+
};
48+
mu_.Await(absl::Condition(&not_in_usage_hold));
49+
}
50+
51+
void CommonPjRtBuffer::WaitForOutstandingDonationHold() {
52+
auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
53+
return holds_[ScopedHold::kDonation] == 0;
54+
};
55+
mu_.Await(absl::Condition(&not_in_donation_hold));
56+
}
57+
58+
absl::StatusOr<AbstractTrackedDeviceBuffer*>
59+
CommonPjRtBuffer::GetBufferForUsageOrExternalHoldLocked(ScopedHold::Type type) {
60+
// All callers should have called WaitForOutstandingDonationHold().
61+
CHECK_EQ(holds_[ScopedHold::kDonation], 0);
62+
if (device_buffer_ == nullptr) {
63+
return absl::InvalidArgumentError("Buffer has been deleted or donated.");
64+
} else {
65+
++holds_[type];
66+
}
67+
return device_buffer_.get();
68+
}
69+
70+
absl::StatusOr<std::unique_ptr<AbstractTrackedDeviceBuffer>>
71+
CommonPjRtBuffer::GetBufferForDonationHoldLocked() {
72+
// All callers should have called WaitForOutstandingDonationHold().
73+
CHECK_EQ(holds_[ScopedHold::kDonation], 0);
74+
if (device_buffer_ == nullptr) {
75+
return absl::InvalidArgumentError("Donation requested for invalid buffer");
76+
}
77+
if (holds_[ScopedHold::kExternalReference] > 0) {
78+
return absl::InvalidArgumentError(
79+
"Donation requested for buffer with external reference");
80+
}
81+
// First add the donation hold.
82+
++holds_[ScopedHold::kDonation];
83+
// Then wait for any usage holds to be dropped or converted. No new usage
84+
// holds can be added until we drop the donation hold so this wait will
85+
// complete eventually.
86+
WaitForOutstandingUsageHolds();
87+
// Because we added a donation hold, nobody could release the buffer while
88+
// we were waiting.
89+
CHECK(device_buffer_ != nullptr);
90+
return std::move(device_buffer_);
91+
}
92+
93+
void CommonPjRtBuffer::AcquireHoldLocked(ScopedHold* hold) {
94+
if (hold->type() == ScopedHold::kDonation) {
95+
hold->AcquireDonation(GetBufferForDonationHoldLocked());
96+
return;
97+
}
98+
99+
hold->AcquireUsageOrExternalReference(
100+
GetBufferForUsageOrExternalHoldLocked(hold->type()));
101+
}
102+
103+
void CommonPjRtBuffer::DropUsageOrExternalHold(
104+
ScopedHold::Type type, AbstractTrackedDeviceBuffer* buffer) {
105+
absl::MutexLock lock(&mu_);
106+
CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
107+
CHECK_GT(holds_[type], 0);
108+
--holds_[type];
109+
}
110+
111+
void CommonPjRtBuffer::DropDonationHold(
112+
std::unique_ptr<AbstractTrackedDeviceBuffer> buffer) {
113+
absl::MutexLock lock(&mu_);
114+
CHECK_EQ(device_buffer_.get(), nullptr);
115+
device_buffer_ = std::move(buffer);
116+
CHECK_GT(holds_[ScopedHold::kDonation], 0);
117+
--holds_[ScopedHold::kDonation];
118+
CHECK_EQ(holds_[ScopedHold::kDonation], 0);
119+
CHECK_EQ(holds_[ScopedHold::kUsage], 0);
120+
CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
121+
}
122+
123+
absl::Status CommonPjRtBuffer::ScopedHold::status() const {
124+
// Lazily create absl::Status values only when they are requested.
125+
switch (state_) {
126+
case kUninitialized:
127+
return absl::InvalidArgumentError("Buffer has not been initialized");
128+
case kValid:
129+
return absl::OkStatus();
130+
case kMoved:
131+
return absl::InvalidArgumentError("Buffer has been moved.");
132+
case kConverted:
133+
return absl::InvalidArgumentError("Buffer has been converted");
134+
case kReleased:
135+
return absl::InvalidArgumentError("Buffer has been released");
136+
case kDonated:
137+
return absl::InvalidArgumentError("Buffer has been donated");
138+
case kError:
139+
return status_;
140+
default:
141+
CHECK(false) << "Unexpected state value " << state_;
142+
}
143+
}
144+
145+
void CommonPjRtBuffer::ScopedHold::DropHold() {
146+
if (ok()) {
147+
if (type_ == kDonation) {
148+
parent_->DropDonationHold(std::move(buffer_));
149+
} else {
150+
parent_->DropUsageOrExternalHold(type_, buffer_ptr_);
151+
}
152+
}
153+
}
154+
155+
CommonPjRtBuffer::ScopedHold::~ScopedHold() { DropHold(); }
156+
157+
CommonPjRtBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
158+
: parent_(other.parent_),
159+
type_(other.type_),
160+
state_(other.state_),
161+
status_(std::move(other.status_)),
162+
buffer_ptr_(other.buffer_ptr_),
163+
buffer_(std::move(other.buffer_)) {
164+
// Preserve the invariant that status is invalid if buffer == nullptr.
165+
other.SetState(kMoved);
166+
}
167+
168+
void CommonPjRtBuffer::ScopedHold::AcquireDonation(
169+
absl::StatusOr<std::unique_ptr<AbstractTrackedDeviceBuffer>> buffer_or) {
170+
CHECK(!ok());
171+
if (buffer_or.ok()) {
172+
buffer_ = std::move(buffer_or).value();
173+
buffer_ptr_ = buffer_.get();
174+
SetState(kValid);
175+
} else {
176+
status_ = std::move(buffer_or).status();
177+
buffer_ = nullptr;
178+
buffer_ptr_ = nullptr;
179+
SetState(kError);
180+
}
181+
// Check the invariant holds.
182+
CHECK(!ok() || buffer_ptr_ != nullptr);
183+
}
184+
185+
void CommonPjRtBuffer::ScopedHold::AcquireUsageOrExternalReference(
186+
absl::StatusOr<AbstractTrackedDeviceBuffer*> buffer_or) {
187+
CHECK(!ok());
188+
if (buffer_or.ok()) {
189+
buffer_.reset();
190+
buffer_ptr_ = buffer_or.value();
191+
SetState(kValid);
192+
} else {
193+
status_ = std::move(buffer_or).status();
194+
buffer_.reset();
195+
buffer_ = nullptr;
196+
SetState(kError);
197+
}
198+
// Check the invariant holds.
199+
CHECK(!ok() || buffer_ptr_ != nullptr);
200+
}
201+
202+
void CommonPjRtBuffer::ScopedHold::ConfirmDonation() {
203+
CHECK(ok());
204+
CHECK_EQ(type(), kDonation);
205+
parent()->ConfirmDonation(buffer());
206+
SetState(kDonated);
207+
}
208+
209+
void CommonPjRtBuffer::ConfirmDonation(
210+
AbstractTrackedDeviceBuffer* device_buffer) {
211+
absl::MutexLock lock(&mu_);
212+
CHECK_EQ(holds_[ScopedHold::kUsage], 0);
213+
CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
214+
CHECK_EQ(holds_[ScopedHold::kDonation], 1);
215+
holds_[ScopedHold::kDonation] = 0;
216+
device_buffer->ConfirmDonation();
217+
}
218+
219+
std::unique_ptr<AbstractTrackedDeviceBuffer> CommonPjRtBuffer::ReleaseBuffer() {
220+
absl::MutexLock lock(&mu_);
221+
{
222+
tsl::profiler::TraceMe t1("Wait for donation holds");
223+
// We first wait for a donation hold to complete if there is one in
224+
// progress. If the donation succeeds via ConfirmDonation() then it will
225+
// set device_buffer_ to nullptr before returning to this thread.
226+
WaitForOutstandingDonationHold();
227+
}
228+
if (device_buffer_ == nullptr) {
229+
// Buffer has been deleted.
230+
return nullptr;
231+
}
232+
// Return device_buffer_ by move which also sets it to nullptr, so
233+
// that no other thread can add a hold while we are in
234+
// WaitForOutstandingUsageHolds() below.
235+
auto buffer = std::move(device_buffer_);
236+
237+
tsl::profiler::TraceMe t2("Wait for usage holds");
238+
WaitForOutstandingUsageHolds();
239+
return buffer;
240+
}
241+
242+
bool CommonPjRtBuffer::IsDeleted() {
243+
absl::MutexLock lock(&mu_);
244+
return device_buffer_ == nullptr;
245+
}
246+
247+
} // namespace xla

0 commit comments

Comments
 (0)