|
| 1 | +/* Copyright 2025 The OpenXLA Authors. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +#include "xla/pjrt/abstract_tracked_device_buffer.h" |
| 17 | + |
| 18 | +#include <memory> |
| 19 | +#include <utility> |
| 20 | + |
| 21 | +#include "absl/base/thread_annotations.h" |
| 22 | +#include "absl/log/check.h" |
| 23 | +#include "absl/status/status.h" |
| 24 | +#include "absl/status/statusor.h" |
| 25 | +#include "absl/synchronization/mutex.h" |
| 26 | +#include "tsl/profiler/lib/traceme.h" |
| 27 | + |
| 28 | +namespace xla { |
| 29 | + |
| 30 | +CommonPjRtBuffer::CommonPjRtBuffer( |
| 31 | + std::unique_ptr<AbstractTrackedDeviceBuffer> device_buffer) |
| 32 | + : device_buffer_(std::move(device_buffer)) { |
| 33 | + for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { |
| 34 | + holds_[i] = 0; |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +CommonPjRtBuffer::~CommonPjRtBuffer() { |
| 39 | + for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { |
| 40 | + CHECK_EQ(holds_[i], 0) << "Non-zero type " << i << " hold on destruction."; |
| 41 | + } |
| 42 | +} |
| 43 | + |
| 44 | +void CommonPjRtBuffer::WaitForOutstandingUsageHolds() { |
| 45 | + auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| 46 | + return holds_[ScopedHold::kUsage] == 0; |
| 47 | + }; |
| 48 | + mu_.Await(absl::Condition(¬_in_usage_hold)); |
| 49 | +} |
| 50 | + |
| 51 | +void CommonPjRtBuffer::WaitForOutstandingDonationHold() { |
| 52 | + auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| 53 | + return holds_[ScopedHold::kDonation] == 0; |
| 54 | + }; |
| 55 | + mu_.Await(absl::Condition(¬_in_donation_hold)); |
| 56 | +} |
| 57 | + |
| 58 | +absl::StatusOr<AbstractTrackedDeviceBuffer*> |
| 59 | +CommonPjRtBuffer::GetBufferForUsageOrExternalHoldLocked(ScopedHold::Type type) { |
| 60 | + // All callers should have called WaitForOutstandingDonationHold(). |
| 61 | + CHECK_EQ(holds_[ScopedHold::kDonation], 0); |
| 62 | + if (device_buffer_ == nullptr) { |
| 63 | + return absl::InvalidArgumentError("Buffer has been deleted or donated."); |
| 64 | + } else { |
| 65 | + ++holds_[type]; |
| 66 | + } |
| 67 | + return device_buffer_.get(); |
| 68 | +} |
| 69 | + |
| 70 | +absl::StatusOr<std::unique_ptr<AbstractTrackedDeviceBuffer>> |
| 71 | +CommonPjRtBuffer::GetBufferForDonationHoldLocked() { |
| 72 | + // All callers should have called WaitForOutstandingDonationHold(). |
| 73 | + CHECK_EQ(holds_[ScopedHold::kDonation], 0); |
| 74 | + if (device_buffer_ == nullptr) { |
| 75 | + return absl::InvalidArgumentError("Donation requested for invalid buffer"); |
| 76 | + } |
| 77 | + if (holds_[ScopedHold::kExternalReference] > 0) { |
| 78 | + return absl::InvalidArgumentError( |
| 79 | + "Donation requested for buffer with external reference"); |
| 80 | + } |
| 81 | + // First add the donation hold. |
| 82 | + ++holds_[ScopedHold::kDonation]; |
| 83 | + // Then wait for any usage holds to be dropped or converted. No new usage |
| 84 | + // holds can be added until we drop the donation hold so this wait will |
| 85 | + // complete eventually. |
| 86 | + WaitForOutstandingUsageHolds(); |
| 87 | + // Because we added a donation hold, nobody could release the buffer while |
| 88 | + // we were waiting. |
| 89 | + CHECK(device_buffer_ != nullptr); |
| 90 | + return std::move(device_buffer_); |
| 91 | +} |
| 92 | + |
| 93 | +void CommonPjRtBuffer::AcquireHoldLocked(ScopedHold* hold) { |
| 94 | + if (hold->type() == ScopedHold::kDonation) { |
| 95 | + hold->AcquireDonation(GetBufferForDonationHoldLocked()); |
| 96 | + return; |
| 97 | + } |
| 98 | + |
| 99 | + hold->AcquireUsageOrExternalReference( |
| 100 | + GetBufferForUsageOrExternalHoldLocked(hold->type())); |
| 101 | +} |
| 102 | + |
| 103 | +void CommonPjRtBuffer::DropUsageOrExternalHold( |
| 104 | + ScopedHold::Type type, AbstractTrackedDeviceBuffer* buffer) { |
| 105 | + absl::MutexLock lock(&mu_); |
| 106 | + CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr); |
| 107 | + CHECK_GT(holds_[type], 0); |
| 108 | + --holds_[type]; |
| 109 | +} |
| 110 | + |
| 111 | +void CommonPjRtBuffer::DropDonationHold( |
| 112 | + std::unique_ptr<AbstractTrackedDeviceBuffer> buffer) { |
| 113 | + absl::MutexLock lock(&mu_); |
| 114 | + CHECK_EQ(device_buffer_.get(), nullptr); |
| 115 | + device_buffer_ = std::move(buffer); |
| 116 | + CHECK_GT(holds_[ScopedHold::kDonation], 0); |
| 117 | + --holds_[ScopedHold::kDonation]; |
| 118 | + CHECK_EQ(holds_[ScopedHold::kDonation], 0); |
| 119 | + CHECK_EQ(holds_[ScopedHold::kUsage], 0); |
| 120 | + CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); |
| 121 | +} |
| 122 | + |
| 123 | +absl::Status CommonPjRtBuffer::ScopedHold::status() const { |
| 124 | + // Lazily create absl::Status values only when they are requested. |
| 125 | + switch (state_) { |
| 126 | + case kUninitialized: |
| 127 | + return absl::InvalidArgumentError("Buffer has not been initialized"); |
| 128 | + case kValid: |
| 129 | + return absl::OkStatus(); |
| 130 | + case kMoved: |
| 131 | + return absl::InvalidArgumentError("Buffer has been moved."); |
| 132 | + case kConverted: |
| 133 | + return absl::InvalidArgumentError("Buffer has been converted"); |
| 134 | + case kReleased: |
| 135 | + return absl::InvalidArgumentError("Buffer has been released"); |
| 136 | + case kDonated: |
| 137 | + return absl::InvalidArgumentError("Buffer has been donated"); |
| 138 | + case kError: |
| 139 | + return status_; |
| 140 | + default: |
| 141 | + CHECK(false) << "Unexpected state value " << state_; |
| 142 | + } |
| 143 | +} |
| 144 | + |
| 145 | +void CommonPjRtBuffer::ScopedHold::DropHold() { |
| 146 | + if (ok()) { |
| 147 | + if (type_ == kDonation) { |
| 148 | + parent_->DropDonationHold(std::move(buffer_)); |
| 149 | + } else { |
| 150 | + parent_->DropUsageOrExternalHold(type_, buffer_ptr_); |
| 151 | + } |
| 152 | + } |
| 153 | +} |
| 154 | + |
| 155 | +CommonPjRtBuffer::ScopedHold::~ScopedHold() { DropHold(); } |
| 156 | + |
| 157 | +CommonPjRtBuffer::ScopedHold::ScopedHold(ScopedHold&& other) |
| 158 | + : parent_(other.parent_), |
| 159 | + type_(other.type_), |
| 160 | + state_(other.state_), |
| 161 | + status_(std::move(other.status_)), |
| 162 | + buffer_ptr_(other.buffer_ptr_), |
| 163 | + buffer_(std::move(other.buffer_)) { |
| 164 | + // Preserve the invariant that status is invalid if buffer == nullptr. |
| 165 | + other.SetState(kMoved); |
| 166 | +} |
| 167 | + |
| 168 | +void CommonPjRtBuffer::ScopedHold::AcquireDonation( |
| 169 | + absl::StatusOr<std::unique_ptr<AbstractTrackedDeviceBuffer>> buffer_or) { |
| 170 | + CHECK(!ok()); |
| 171 | + if (buffer_or.ok()) { |
| 172 | + buffer_ = std::move(buffer_or).value(); |
| 173 | + buffer_ptr_ = buffer_.get(); |
| 174 | + SetState(kValid); |
| 175 | + } else { |
| 176 | + status_ = std::move(buffer_or).status(); |
| 177 | + buffer_ = nullptr; |
| 178 | + buffer_ptr_ = nullptr; |
| 179 | + SetState(kError); |
| 180 | + } |
| 181 | + // Check the invariant holds. |
| 182 | + CHECK(!ok() || buffer_ptr_ != nullptr); |
| 183 | +} |
| 184 | + |
| 185 | +void CommonPjRtBuffer::ScopedHold::AcquireUsageOrExternalReference( |
| 186 | + absl::StatusOr<AbstractTrackedDeviceBuffer*> buffer_or) { |
| 187 | + CHECK(!ok()); |
| 188 | + if (buffer_or.ok()) { |
| 189 | + buffer_.reset(); |
| 190 | + buffer_ptr_ = buffer_or.value(); |
| 191 | + SetState(kValid); |
| 192 | + } else { |
| 193 | + status_ = std::move(buffer_or).status(); |
| 194 | + buffer_.reset(); |
| 195 | + buffer_ = nullptr; |
| 196 | + SetState(kError); |
| 197 | + } |
| 198 | + // Check the invariant holds. |
| 199 | + CHECK(!ok() || buffer_ptr_ != nullptr); |
| 200 | +} |
| 201 | + |
| 202 | +void CommonPjRtBuffer::ScopedHold::ConfirmDonation() { |
| 203 | + CHECK(ok()); |
| 204 | + CHECK_EQ(type(), kDonation); |
| 205 | + parent()->ConfirmDonation(buffer()); |
| 206 | + SetState(kDonated); |
| 207 | +} |
| 208 | + |
| 209 | +void CommonPjRtBuffer::ConfirmDonation( |
| 210 | + AbstractTrackedDeviceBuffer* device_buffer) { |
| 211 | + absl::MutexLock lock(&mu_); |
| 212 | + CHECK_EQ(holds_[ScopedHold::kUsage], 0); |
| 213 | + CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); |
| 214 | + CHECK_EQ(holds_[ScopedHold::kDonation], 1); |
| 215 | + holds_[ScopedHold::kDonation] = 0; |
| 216 | + device_buffer->ConfirmDonation(); |
| 217 | +} |
| 218 | + |
| 219 | +std::unique_ptr<AbstractTrackedDeviceBuffer> CommonPjRtBuffer::ReleaseBuffer() { |
| 220 | + absl::MutexLock lock(&mu_); |
| 221 | + { |
| 222 | + tsl::profiler::TraceMe t1("Wait for donation holds"); |
| 223 | + // We first wait for a donation hold to complete if there is one in |
| 224 | + // progress. If the donation succeeds via ConfirmDonation() then it will |
| 225 | + // set device_buffer_ to nullptr before returning to this thread. |
| 226 | + WaitForOutstandingDonationHold(); |
| 227 | + } |
| 228 | + if (device_buffer_ == nullptr) { |
| 229 | + // Buffer has been deleted. |
| 230 | + return nullptr; |
| 231 | + } |
| 232 | + // Return device_buffer_ by move which also sets it to nullptr, so |
| 233 | + // that no other thread can add a hold while we are in |
| 234 | + // WaitForOutstandingUsageHolds() below. |
| 235 | + auto buffer = std::move(device_buffer_); |
| 236 | + |
| 237 | + tsl::profiler::TraceMe t2("Wait for usage holds"); |
| 238 | + WaitForOutstandingUsageHolds(); |
| 239 | + return buffer; |
| 240 | +} |
| 241 | + |
| 242 | +bool CommonPjRtBuffer::IsDeleted() { |
| 243 | + absl::MutexLock lock(&mu_); |
| 244 | + return device_buffer_ == nullptr; |
| 245 | +} |
| 246 | + |
| 247 | +} // namespace xla |
0 commit comments