@@ -70,6 +70,7 @@ limitations under the License.
70
70
#define XLA_STREAM_EXECUTOR_KERNEL_H_
71
71
72
72
#include < array>
73
+ #include < cassert>
73
74
#include < cstddef>
74
75
#include < cstdint>
75
76
#include < cstring>
@@ -325,59 +326,79 @@ class KernelArgsPackedArrayBase : public KernelArgsArrayBase {
325
326
Kind kind () const final { return Kind::kPackedArray ; }
326
327
};
327
328
328
- // A list of arguments for a kernel call.
329
- //
330
- // The template parameter kNumArgs is the maximum number of arguments which can
331
- // be stored in the list.
332
- //
333
- // Contains a list of addresses for non-shared-memory arguments and a list of
334
- // sizes for shared-memory arguments. Since the shared-memory arguments may be
335
- // interspersed with the non-shared-memory arguments, it also stores a list of
336
- // the indices at which the shared-memory arguments appeared.
337
- //
338
- // For example, if the argument address list contains {a, b, c, d, e}, the
339
- // shared-memory arguments list contains the sizes of {A, B, C}, and the
340
- // shared-memory indices list contains {0, 3, 5}, then the original list of
341
- // arguments was {A, a, b, B, c, C, d, e}.
342
- //
343
- // This way of storing the arguments makes CUDA kernel calls efficient because
344
- // they only require the argument address list and the total number of shared
345
- // bytes, but it also makes it possible for OpenCL kernel calls because they
346
- // depend on the location of each shared-memory argument and its size.
329
+ // ===----------------------------------------------------------------------===//
330
+ // Kernel arguments packing for device memory and POD args.
331
+ // ===----------------------------------------------------------------------===//
332
+
333
+ // KernelArgsPackedArray is optimized for packing DeviceMemoryBase pointers
334
+ // and POD arguments (i.e. scalars) when the number and type of arguments are
335
+ // not known at compile time.
336
+
337
+ namespace internal {
338
+
339
+ // An empty storage for packing just the device memory arguments, that are
340
+ // stored directly in the `KernelArgsPackedArray`.
341
+ class EmptyArgs {};
342
+
343
+ // A storage for POD generic arguments that are smaller than `size` and require
344
+ // alignment smaller or equal to `alignment`.
345
+ template <size_t capacity, size_t size = 8 ,
346
+ size_t alignment = alignof (std::max_align_t )>
347
+ class PodArgs {
348
+ protected:
349
+ template <typename T>
350
+ const std::byte *add_pod_argument (const T &arg) {
351
+ static_assert (
352
+ std::is_pod_v<T> && sizeof (T) <= size & alignof (T) <= alignment,
353
+ " Type is not compatible with POD arguments storage" );
354
+
355
+ assert (num_args_ < capacity && " pod args overflow" );
356
+ std::byte *arg_storage = args_storage_[num_args_++].storage ;
357
+ std::memcpy (arg_storage, &arg, sizeof (T));
358
+
359
+ return arg_storage;
360
+ }
361
+
362
+ private:
363
+ struct Arg {
364
+ alignas (alignment) std::byte storage[size];
365
+ };
366
+
367
+ size_t num_args_ = 0 ;
368
+ std::array<Arg, capacity> args_storage_;
369
+ };
370
+
371
+ template <typename ArgsStorage>
372
+ static constexpr bool is_pod_args_v = false ;
373
+
374
+ template <size_t capacity, size_t size, size_t alignment>
375
+ static constexpr bool is_pod_args_v<PodArgs<capacity, size, alignment>> = true ;
376
+
377
+ } // namespace internal
378
+
379
+ // An array of arguments for a kernel call.
347
380
//
348
- // Note that the code for adding arguments has been identified as a performance
349
- // hotspot in some real-world applications so this structure has been optimized
350
- // for the performance of argument adding.
351
- template <size_t kNumArgs >
352
- class KernelArgsPackedArray : public KernelArgsPackedArrayBase {
381
+ // The template parameter `num_args` is the maximum number of arguments which
382
+ // can be stored in the array.
383
+ template <size_t num_args, typename ArgsStorage = internal::PodArgs<num_args>>
384
+ class KernelArgsPackedArray : public KernelArgsPackedArrayBase , ArgsStorage {
353
385
public:
354
- static constexpr int kMaxGenericArgSize = 8 ;
355
-
356
386
KernelArgsPackedArray () = default ;
357
387
358
388
// KernelArgsPackedArray is not copyable or movable because argument addresses
359
389
// point to inline storage that can't be moved.
360
390
KernelArgsPackedArray (const KernelArgsPackedArray &) = delete ;
361
391
KernelArgsPackedArray &operator =(const KernelArgsPackedArray &) = delete ;
362
392
363
- // Do not allow casting into concrete packed array type.
364
- static bool classof (const KernelArgsArrayBase *args) { return false ; }
365
-
366
393
// Adds an argument to the list.
367
394
template <typename T>
368
395
void add_argument (const T &arg) {
369
- static_assert (sizeof (T) <= kMaxGenericArgSize ,
370
- " Please adjust kMaxGenericArgSize" );
371
- static_assert (std::is_pod_v<T>, " Only pod types supported!" );
372
- char *generic_arg_storage =
373
- &generic_arguments_[number_of_generic_arguments_++ *
374
- kMaxGenericArgSize ];
375
-
376
- CHECK_EQ (reinterpret_cast <uintptr_t >(generic_arg_storage) % alignof (T), 0 );
377
- std::memcpy (generic_arg_storage, &arg, sizeof (T));
378
-
379
- argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
380
- ++number_of_argument_addresses_;
396
+ if constexpr (internal::is_pod_args_v<ArgsStorage>) {
397
+ argument_addresses_[number_of_argument_addresses_++] =
398
+ ArgsStorage::add_pod_argument (arg);
399
+ } else {
400
+ static_assert (false , " Arguments storage is not supported" );
401
+ }
381
402
}
382
403
383
404
// Adds a device memory argument to the list.
@@ -399,54 +420,49 @@ class KernelArgsPackedArray : public KernelArgsPackedArrayBase {
399
420
400
421
// Gets the number of arguments added so far, including shared memory
401
422
// arguments.
402
- size_t number_of_arguments () const override {
423
+ size_t number_of_arguments () const final {
403
424
return number_of_argument_addresses_ + (total_shared_memory_bytes_ > 0 );
404
425
}
405
426
406
427
// Gets the total number of shared memory bytes added so far.
407
- uint64_t number_of_shared_bytes () const override {
428
+ uint64_t number_of_shared_bytes () const final {
408
429
return total_shared_memory_bytes_;
409
430
}
410
431
411
432
// Gets the list of argument addresses.
412
- absl::Span<const void *const > argument_addresses () const override {
433
+ absl::Span<const void *const > argument_addresses () const final {
413
434
return absl::Span<const void *const >(argument_addresses_.data (),
414
435
number_of_argument_addresses_);
415
436
}
416
437
417
438
private:
418
439
// A place to store copies of opaque pointers from device memory arguments.
419
- std::array<const void *, kNumArgs > device_memory_opaque_pointers_;
440
+ std::array<const void *, num_args > device_memory_opaque_pointers_;
420
441
421
442
// Addresses for non-shared-memory arguments.
422
- std::array<const void *, kNumArgs > argument_addresses_;
423
-
424
- // Storage for arguments of templated type.
425
- alignas (std::max_align_t )
426
- std::array<char , kNumArgs * kMaxGenericArgSize > generic_arguments_;
443
+ std::array<const void *, num_args> argument_addresses_;
427
444
428
445
// Total of all shared memory sizes.
429
446
size_t total_shared_memory_bytes_ = 0 ;
430
447
431
- // Number of significant entries in argument_addresses_ and argument_sizes_ .
448
+ // Number of significant entries in argument_addresses_.
432
449
size_t number_of_argument_addresses_ = 0 ;
433
-
434
- // The number of generic arguments that have been added to generic_arguments_.
435
- size_t number_of_generic_arguments_ = 0 ;
436
450
};
437
451
452
+ namespace internal {
438
453
template <int n>
439
454
std::unique_ptr<KernelArgsPackedArrayBase> PackKernelArgs (
440
455
absl::Span<const DeviceMemoryBase> args, uint32_t shared_mem_bytes) {
441
- auto kernel_args = std::make_unique<KernelArgsPackedArray<n>>();
456
+ auto packed = std::make_unique<KernelArgsPackedArray<n, EmptyArgs >>();
442
457
for (const DeviceMemoryBase &buf : args) {
443
- kernel_args ->add_device_memory_argument (buf);
458
+ packed ->add_device_memory_argument (buf);
444
459
}
445
460
if (shared_mem_bytes > 0 ) {
446
- kernel_args ->add_shared_bytes (shared_mem_bytes);
461
+ packed ->add_shared_bytes (shared_mem_bytes);
447
462
}
448
- return kernel_args ;
463
+ return packed ;
449
464
}
465
+ } // namespace internal
450
466
451
467
inline tsl::StatusOr<std::unique_ptr<KernelArgsPackedArrayBase>> PackKernelArgs (
452
468
absl::Span<const DeviceMemoryBase> args, uint32_t shared_mem_bytes) {
@@ -461,22 +477,22 @@ inline tsl::StatusOr<std::unique_ptr<KernelArgsPackedArrayBase>> PackKernelArgs(
461
477
// Specialize kernel arguments array for small sizes to allocate a smaller
462
478
// chunk of memory and hopefully hit a small allocations cache.
463
479
if (args.size () <= 4 ) {
464
- return PackKernelArgs<8 >(args, shared_mem_bytes);
480
+ return internal:: PackKernelArgs<4 >(args, shared_mem_bytes);
465
481
} else if (args.size () <= 8 ) {
466
- return PackKernelArgs<8 >(args, shared_mem_bytes);
482
+ return internal:: PackKernelArgs<8 >(args, shared_mem_bytes);
467
483
} else if (args.size () <= 16 ) {
468
- return PackKernelArgs<16 >(args, shared_mem_bytes);
484
+ return internal:: PackKernelArgs<16 >(args, shared_mem_bytes);
469
485
} else if (args.size () <= 32 ) {
470
- return PackKernelArgs<32 >(args, shared_mem_bytes);
486
+ return internal:: PackKernelArgs<32 >(args, shared_mem_bytes);
471
487
} else if (args.size () <= 64 ) {
472
- return PackKernelArgs<64 >(args, shared_mem_bytes);
488
+ return internal:: PackKernelArgs<64 >(args, shared_mem_bytes);
473
489
} else if (args.size () <= 256 ) {
474
- return PackKernelArgs<256 >(args, shared_mem_bytes);
490
+ return internal:: PackKernelArgs<256 >(args, shared_mem_bytes);
475
491
} else if (args.size () <= 512 ) {
476
- return PackKernelArgs<512 >(args, shared_mem_bytes);
492
+ return internal:: PackKernelArgs<512 >(args, shared_mem_bytes);
477
493
}
478
494
479
- return PackKernelArgs<kKernelArgsLimit >(args, shared_mem_bytes);
495
+ return internal:: PackKernelArgs<kKernelArgsLimit >(args, shared_mem_bytes);
480
496
}
481
497
482
498
inline tsl::StatusOr<std::unique_ptr<KernelArgsPackedArrayBase>> PackKernelArgs (
@@ -709,8 +725,7 @@ struct KernelParamsOk<TypedKernel<Params...>, Args...> {
709
725
};
710
726
711
727
// Packs the given arguments into a KernelArgsArray with compile-time type
712
- // checks. se::Stream::ThenLaunch does this too except it ignores the shared
713
- // memory size in `kernel`.
728
+ // checks.
714
729
template <typename ... Params, typename ... Args>
715
730
std::unique_ptr<KernelArgsPackedArrayBase> PackKernelArgs (
716
731
const TypedKernel<Params...> &kernel, const Args &...args) {
0 commit comments