@@ -70,6 +70,7 @@ limitations under the License.
70
70
#define XLA_STREAM_EXECUTOR_KERNEL_H_
71
71
72
72
#include < array>
73
+ #include < cstddef>
73
74
#include < cstdint>
74
75
#include < cstring>
75
76
#include < memory>
@@ -376,7 +377,6 @@ class KernelArgsPackedArray : public KernelArgsPackedArrayBase {
376
377
std::memcpy (generic_arg_storage, &arg, sizeof (T));
377
378
378
379
argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
379
- argument_sizes_[number_of_argument_addresses_] = sizeof (arg);
380
380
++number_of_argument_addresses_;
381
381
}
382
382
@@ -386,7 +386,6 @@ class KernelArgsPackedArray : public KernelArgsPackedArrayBase {
386
386
&device_memory_opaque_pointers_[number_of_argument_addresses_];
387
387
*copy_ptr = arg.opaque ();
388
388
argument_addresses_[number_of_argument_addresses_] = copy_ptr;
389
- argument_sizes_[number_of_argument_addresses_] = sizeof (void *);
390
389
++number_of_argument_addresses_;
391
390
}
392
391
@@ -395,17 +394,13 @@ class KernelArgsPackedArray : public KernelArgsPackedArrayBase {
395
394
// The only significant information about a shared argument is its size, so
396
395
// that is the only parameter in this function.
397
396
void add_shared_bytes (size_t number_of_bytes) {
398
- shared_memory_indices_[number_of_shared_memory_arguments_] =
399
- number_of_argument_addresses_ + number_of_shared_memory_arguments_;
400
- shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
401
- ++number_of_shared_memory_arguments_;
402
397
total_shared_memory_bytes_ += number_of_bytes;
403
398
}
404
399
405
400
// Gets the number of arguments added so far, including shared memory
406
401
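// arguments. All shared memory together counts as a single argument, and only when a nonzero number of shared bytes has been added.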
407
402
size_t number_of_arguments () const override {
408
- return number_of_argument_addresses_ + number_of_shared_memory_arguments_ ;
403
+ return number_of_argument_addresses_ + (total_shared_memory_bytes_ > 0 ) ;
409
404
}
410
405
411
406
// Gets the total number of shared memory bytes added so far.
@@ -427,28 +422,15 @@ class KernelArgsPackedArray : public KernelArgsPackedArrayBase {
427
422
std::array<const void *, kNumArgs > argument_addresses_;
428
423
429
424
// Storage for arguments of templated type.
430
- alignas (kMaxGenericArgSize )
425
+ alignas (std:: max_align_t )
431
426
std::array<char , kNumArgs * kMaxGenericArgSize > generic_arguments_;
432
427
433
- // Sizes for non-shared-memory arguments.
434
- std::array<size_t , kNumArgs > argument_sizes_;
435
-
436
- // Size in bytes for each shared memory argument.
437
- std::array<size_t , kNumArgs > shared_memory_bytes_;
438
-
439
- // Indices in the arguments array for shared memory arguments.
440
- std::array<size_t , kNumArgs > shared_memory_indices_;
441
-
442
428
// Total of all shared memory sizes.
443
429
size_t total_shared_memory_bytes_ = 0 ;
444
430
445
431
// Number of significant entries in argument_addresses_.
446
432
size_t number_of_argument_addresses_ = 0 ;
447
433
448
- // Number of significant entries in shared_memory_bytes_ and
449
- // shared_memory_indices_.
450
- size_t number_of_shared_memory_arguments_ = 0 ;
451
-
452
434
// The number of generic arguments that have been added to generic_arguments_.
453
435
size_t number_of_generic_arguments_ = 0 ;
454
436
};
0 commit comments