Skip to content

Commit da46e28

Browse files
ezhulenevcopybara-github
authored andcommitted
[stream_executor] Optimize KernelArgsPackedArray for storing only device memory pointers
name old cpu/op new cpu/op delta BM_PackDeviceMemoryArgs/4 22.3ns ± 9% 18.1ns ± 2% -19.00% (p=0.000 n=59+46) BM_PackDeviceMemoryArgs/8 31.3ns ± 3% 32.4ns ± 2% +3.69% (p=0.000 n=50+50) BM_PackDeviceMemoryArgs/32 80.5ns ± 4% 74.3ns ± 2% -7.81% (p=0.000 n=57+43) BM_PackDeviceMemoryArgs/64 157ns ± 3% 142ns ± 4% -9.63% (p=0.000 n=51+53) BM_PackDeviceMemoryArgs/128 364ns ± 5% 323ns ± 3% -11.31% (p=0.000 n=60+44) BM_PackDeviceMemoryArgs/256 581ns ± 3% 546ns ± 5% -6.08% (p=0.000 n=49+58) BM_PackDeviceMemoryArgs/512 1.14µs ± 2% 1.06µs ± 4% -6.75% (p=0.000 n=48+54) BM_PackDeviceMemoryArgs/1024 2.38µs ± 3% 2.13µs ± 5% -10.39% (p=0.000 n=51+60) PiperOrigin-RevId: 580365888
1 parent 30e5335 commit da46e28

File tree

3 files changed

+145
-69
lines changed

3 files changed

+145
-69
lines changed

xla/stream_executor/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ xla_cc_test(
648648
deps = [
649649
":device_memory",
650650
":stream_executor",
651+
"//xla/stream_executor/host:host_platform",
651652
"@tsl//tsl/lib/core:status_test_util",
652653
"@tsl//tsl/platform:status",
653654
"@tsl//tsl/platform:test",

xla/stream_executor/kernel.h

Lines changed: 83 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ limitations under the License.
7070
#define XLA_STREAM_EXECUTOR_KERNEL_H_
7171

7272
#include <array>
73+
#include <cassert>
7374
#include <cstddef>
7475
#include <cstdint>
7576
#include <cstring>
@@ -325,59 +326,79 @@ class KernelArgsPackedArrayBase : public KernelArgsArrayBase {
325326
Kind kind() const final { return Kind::kPackedArray; }
326327
};
327328

328-
// A list of arguments for a kernel call.
329-
//
330-
// The template parameter kNumArgs is the maximum number of arguments which can
331-
// be stored in the list.
332-
//
333-
// Contains a list of addresses for non-shared-memory arguments and a list of
334-
// sizes for shared-memory arguments. Since the shared-memory arguments may be
335-
// interspersed with the non-shared-memory arguments, it also stores a list of
336-
// the indices at which the shared-memory arguments appeared.
337-
//
338-
// For example, if the argument address list contains {a, b, c, d, e}, the
339-
// shared-memory arguments list contains the sizes of {A, B, C}, and the
340-
// shared-memory indices list contains {0, 3, 5}, then the original list of
341-
// arguments was {A, a, b, B, c, C, d, e}.
342-
//
343-
// This way of storing the arguments makes CUDA kernel calls efficient because
344-
// they only require the argument address list and the total number of shared
345-
// bytes, but it also makes it possible for OpenCL kernel calls because they
346-
// depend on the location of each shared-memory argument and its size.
329+
//===----------------------------------------------------------------------===//
330+
// Kernel arguments packing for device memory and POD args.
331+
//===----------------------------------------------------------------------===//
332+
333+
// KernelArgsPackedArray is optimized for packing DeviceMemoryBase pointers
334+
// and POD arguments (i.e. scalars) when the number and type of arguments are
335+
// not known at compile time.
336+
337+
namespace internal {
338+
339+
// An empty storage for packing just the device memory arguments, that are
340+
// stored directly in the `KernelArgsPackedArray`.
341+
class EmptyArgs {};
342+
343+
// A storage for POD generic arguments that are smaller than `size` and require
344+
// alignment smaller or equal to `alignment`.
345+
template <size_t capacity, size_t size = 8,
346+
size_t alignment = alignof(std::max_align_t)>
347+
class PodArgs {
348+
protected:
349+
template <typename T>
350+
const std::byte *add_pod_argument(const T &arg) {
351+
static_assert(
352+
std::is_pod_v<T> && sizeof(T) <= size & alignof(T) <= alignment,
353+
"Type is not compatible with POD arguments storage");
354+
355+
assert(num_args_ < capacity && "pod args overflow");
356+
std::byte *arg_storage = args_storage_[num_args_++].storage;
357+
std::memcpy(arg_storage, &arg, sizeof(T));
358+
359+
return arg_storage;
360+
}
361+
362+
private:
363+
struct Arg {
364+
alignas(alignment) std::byte storage[size];
365+
};
366+
367+
size_t num_args_ = 0;
368+
std::array<Arg, capacity> args_storage_;
369+
};
370+
371+
template <typename ArgsStorage>
372+
static constexpr bool is_pod_args_v = false;
373+
374+
template <size_t capacity, size_t size, size_t alignment>
375+
static constexpr bool is_pod_args_v<PodArgs<capacity, size, alignment>> = true;
376+
377+
} // namespace internal
378+
379+
// An array of arguments for a kernel call.
347380
//
348-
// Note that the code for adding arguments has been identified as a performance
349-
// hotspot in some real-world applications so this structure has been optimized
350-
// for the performance of argument adding.
351-
template <size_t kNumArgs>
352-
class KernelArgsPackedArray : public KernelArgsPackedArrayBase {
381+
// The template parameter `num_args` is the maximum number of arguments which
382+
// can be stored in the array.
383+
template <size_t num_args, typename ArgsStorage = internal::PodArgs<num_args>>
384+
class KernelArgsPackedArray : public KernelArgsPackedArrayBase, ArgsStorage {
353385
public:
354-
static constexpr int kMaxGenericArgSize = 8;
355-
356386
KernelArgsPackedArray() = default;
357387

358388
// KernelArgsPackedArray is not copyable or movable because argument addresses
359389
// point to inline storage that can't be moved.
360390
KernelArgsPackedArray(const KernelArgsPackedArray &) = delete;
361391
KernelArgsPackedArray &operator=(const KernelArgsPackedArray &) = delete;
362392

363-
// Do not allow casting into concrete packed array type.
364-
static bool classof(const KernelArgsArrayBase *args) { return false; }
365-
366393
// Adds an argument to the list.
367394
template <typename T>
368395
void add_argument(const T &arg) {
369-
static_assert(sizeof(T) <= kMaxGenericArgSize,
370-
"Please adjust kMaxGenericArgSize");
371-
static_assert(std::is_pod_v<T>, "Only pod types supported!");
372-
char *generic_arg_storage =
373-
&generic_arguments_[number_of_generic_arguments_++ *
374-
kMaxGenericArgSize];
375-
376-
CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0);
377-
std::memcpy(generic_arg_storage, &arg, sizeof(T));
378-
379-
argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
380-
++number_of_argument_addresses_;
396+
if constexpr (internal::is_pod_args_v<ArgsStorage>) {
397+
argument_addresses_[number_of_argument_addresses_++] =
398+
ArgsStorage::add_pod_argument(arg);
399+
} else {
400+
static_assert(false, "Arguments storage is not supported");
401+
}
381402
}
382403

383404
// Adds a device memory argument to the list.
@@ -399,54 +420,49 @@ class KernelArgsPackedArray : public KernelArgsPackedArrayBase {
399420

400421
// Gets the number of arguments added so far, including shared memory
401422
// arguments.
402-
size_t number_of_arguments() const override {
423+
size_t number_of_arguments() const final {
403424
return number_of_argument_addresses_ + (total_shared_memory_bytes_ > 0);
404425
}
405426

406427
// Gets the total number of shared memory bytes added so far.
407-
uint64_t number_of_shared_bytes() const override {
428+
uint64_t number_of_shared_bytes() const final {
408429
return total_shared_memory_bytes_;
409430
}
410431

411432
// Gets the list of argument addresses.
412-
absl::Span<const void *const> argument_addresses() const override {
433+
absl::Span<const void *const> argument_addresses() const final {
413434
return absl::Span<const void *const>(argument_addresses_.data(),
414435
number_of_argument_addresses_);
415436
}
416437

417438
private:
418439
// A place to store copies of opaque pointers from device memory arguments.
419-
std::array<const void *, kNumArgs> device_memory_opaque_pointers_;
440+
std::array<const void *, num_args> device_memory_opaque_pointers_;
420441

421442
// Addresses for non-shared-memory arguments.
422-
std::array<const void *, kNumArgs> argument_addresses_;
423-
424-
// Storage for arguments of templated type.
425-
alignas(std::max_align_t)
426-
std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_;
443+
std::array<const void *, num_args> argument_addresses_;
427444

428445
// Total of all shared memory sizes.
429446
size_t total_shared_memory_bytes_ = 0;
430447

431-
// Number of significant entries in argument_addresses_ and argument_sizes_.
448+
// Number of significant entries in argument_addresses_.
432449
size_t number_of_argument_addresses_ = 0;
433-
434-
// The number of generic arguments that have been added to generic_arguments_.
435-
size_t number_of_generic_arguments_ = 0;
436450
};
437451

452+
namespace internal {
438453
template <int n>
439454
std::unique_ptr<KernelArgsPackedArrayBase> PackKernelArgs(
440455
absl::Span<const DeviceMemoryBase> args, uint32_t shared_mem_bytes) {
441-
auto kernel_args = std::make_unique<KernelArgsPackedArray<n>>();
456+
auto packed = std::make_unique<KernelArgsPackedArray<n, EmptyArgs>>();
442457
for (const DeviceMemoryBase &buf : args) {
443-
kernel_args->add_device_memory_argument(buf);
458+
packed->add_device_memory_argument(buf);
444459
}
445460
if (shared_mem_bytes > 0) {
446-
kernel_args->add_shared_bytes(shared_mem_bytes);
461+
packed->add_shared_bytes(shared_mem_bytes);
447462
}
448-
return kernel_args;
463+
return packed;
449464
}
465+
} // namespace internal
450466

451467
inline tsl::StatusOr<std::unique_ptr<KernelArgsPackedArrayBase>> PackKernelArgs(
452468
absl::Span<const DeviceMemoryBase> args, uint32_t shared_mem_bytes) {
@@ -461,22 +477,22 @@ inline tsl::StatusOr<std::unique_ptr<KernelArgsPackedArrayBase>> PackKernelArgs(
461477
// Specialize kernel arguments array for small sizes to allocate a smaller
462478
// chunk of memory and hopefully hit a small allocations cache.
463479
if (args.size() <= 4) {
464-
return PackKernelArgs<8>(args, shared_mem_bytes);
480+
return internal::PackKernelArgs<4>(args, shared_mem_bytes);
465481
} else if (args.size() <= 8) {
466-
return PackKernelArgs<8>(args, shared_mem_bytes);
482+
return internal::PackKernelArgs<8>(args, shared_mem_bytes);
467483
} else if (args.size() <= 16) {
468-
return PackKernelArgs<16>(args, shared_mem_bytes);
484+
return internal::PackKernelArgs<16>(args, shared_mem_bytes);
469485
} else if (args.size() <= 32) {
470-
return PackKernelArgs<32>(args, shared_mem_bytes);
486+
return internal::PackKernelArgs<32>(args, shared_mem_bytes);
471487
} else if (args.size() <= 64) {
472-
return PackKernelArgs<64>(args, shared_mem_bytes);
488+
return internal::PackKernelArgs<64>(args, shared_mem_bytes);
473489
} else if (args.size() <= 256) {
474-
return PackKernelArgs<256>(args, shared_mem_bytes);
490+
return internal::PackKernelArgs<256>(args, shared_mem_bytes);
475491
} else if (args.size() <= 512) {
476-
return PackKernelArgs<512>(args, shared_mem_bytes);
492+
return internal::PackKernelArgs<512>(args, shared_mem_bytes);
477493
}
478494

479-
return PackKernelArgs<kKernelArgsLimit>(args, shared_mem_bytes);
495+
return internal::PackKernelArgs<kKernelArgsLimit>(args, shared_mem_bytes);
480496
}
481497

482498
inline tsl::StatusOr<std::unique_ptr<KernelArgsPackedArrayBase>> PackKernelArgs(
@@ -709,8 +725,7 @@ struct KernelParamsOk<TypedKernel<Params...>, Args...> {
709725
};
710726

711727
// Packs the given arguments into a KernelArgsArray with compile-time type
712-
// checks. se::Stream::ThenLaunch does this too except it ignores the shared
713-
// memory size in `kernel`.
728+
// checks.
714729
template <typename... Params, typename... Args>
715730
std::unique_ptr<KernelArgsPackedArrayBase> PackKernelArgs(
716731
const TypedKernel<Params...> &kernel, const Args &...args) {

xla/stream_executor/kernel_test.cc

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,74 @@ limitations under the License.
1515

1616
#include "xla/stream_executor/kernel.h"
1717

18+
#include <cstdint>
19+
#include <memory>
1820
#include <vector>
1921

2022
#include "xla/stream_executor/device_memory.h"
23+
#include "xla/stream_executor/stream_executor.h"
24+
#include "tsl/platform/test.h"
2125
#include "tsl/platform/test_benchmark.h"
2226

2327
namespace stream_executor {
2428

25-
// TODO(ezhulenev): Add tests for packing custom arguments.
29+
static std::unique_ptr<StreamExecutor> NewStreamExecutor() {
30+
Platform* platform = MultiPlatformManager::PlatformWithName("Host").value();
31+
StreamExecutorConfig config(/*ordinal=*/0);
32+
return platform->GetUncachedExecutor(config).value();
33+
}
34+
35+
TEST(KernelTest, PackDeviceMemoryArguments) {
36+
auto executor = NewStreamExecutor();
37+
38+
DeviceMemoryBase a(reinterpret_cast<void*>(0x12345678));
39+
DeviceMemoryBase b(reinterpret_cast<void*>(0x87654321));
40+
41+
auto args = PackKernelArgs({a, b}, 0).value();
42+
ASSERT_EQ(args->number_of_arguments(), 2);
43+
44+
auto packed = args->argument_addresses();
45+
const void* ptr0 = *reinterpret_cast<const void* const*>(packed[0]);
46+
const void* ptr1 = *reinterpret_cast<const void* const*>(packed[1]);
47+
48+
ASSERT_EQ(ptr0, a.opaque());
49+
ASSERT_EQ(ptr1, b.opaque());
50+
}
51+
52+
TEST(KernelTest, PackPodArguments) {
53+
auto args = std::make_unique<KernelArgsPackedArray<4>>();
54+
args->add_argument(1);
55+
args->add_argument(2.0f);
56+
args->add_argument(3.0);
57+
58+
ASSERT_EQ(args->number_of_arguments(), 3);
59+
60+
auto packed = args->argument_addresses();
61+
int32_t i32 = *reinterpret_cast<const int32_t*>(packed[0]);
62+
float f32 = *reinterpret_cast<const float*>(packed[1]);
63+
double f64 = *reinterpret_cast<const double*>(packed[2]);
64+
65+
ASSERT_EQ(i32, 1);
66+
ASSERT_EQ(f32, 2.0f);
67+
ASSERT_EQ(f64, 3.0);
68+
}
69+
70+
TEST(KernelTest, PackTypedKernelArguments) {
71+
auto executor = NewStreamExecutor();
72+
TypedKernel<int32_t, float, double> kernel(executor.get());
73+
74+
auto args = PackKernelArgs(kernel, 1, 2.0f, 3.0);
75+
ASSERT_EQ(args->number_of_arguments(), 3);
76+
77+
auto packed = args->argument_addresses();
78+
int32_t i32 = *reinterpret_cast<const int32_t*>(packed[0]);
79+
float f32 = *reinterpret_cast<const float*>(packed[1]);
80+
double f64 = *reinterpret_cast<const double*>(packed[2]);
81+
82+
ASSERT_EQ(i32, 1);
83+
ASSERT_EQ(f32, 2.0f);
84+
ASSERT_EQ(f64, 3.0);
85+
}
2686

2787
//===----------------------------------------------------------------------===//
2888
// Performance benchmarks below

0 commit comments

Comments
 (0)