/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// Kernel is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase, they see typed kernels, analogous to a
// typed function pointer. TypedKernels express their argument types via
// template parameters like so:
//
//  TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//  void(int*, int);
//
// And for a const memory region:
//
//  TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//  void(const int*, int)
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'ied from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety that arrays
// of minimum sizes are passed when those minimum sizes are contractually
// expected by the kernel.
//
// For user-defined types whose definitions are appropriately shared between the
// host code doing the launching and the kernel code being launched, the user
// defined types are similarly permitted to be expressed as residing in device
// memory:
//
//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types will also be
// able to be passed by value; for example, it is a common idiom to specify a
// bunch of options simultaneously with a structure:
//
//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//  void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full, it
// will be typedef'd by automatically generated code; for example, see
// stream_executor::executor_sample::VecReduceAddKernel.
68 69 #ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 70 #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 71 72 #include <array> 73 #include <memory> 74 #include <tuple> 75 #include <type_traits> 76 #include <vector> 77 78 #include "absl/strings/string_view.h" 79 #include "tensorflow/core/platform/logging.h" 80 #include "tensorflow/stream_executor/device_memory.h" 81 #include "tensorflow/stream_executor/kernel_cache_config.h" 82 #include "tensorflow/stream_executor/lib/array_slice.h" 83 #include "tensorflow/stream_executor/platform/port.h" 84 85 namespace stream_executor { 86 87 class DeviceMemoryBase; 88 template <typename ElemT> 89 class DeviceMemory; 90 class StreamExecutor; 91 92 namespace internal { 93 class KernelInterface; 94 } // namespace internal 95 96 // KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as 97 // registers allocated, shared memory used, etc. 98 // Not all platforms support reporting of all information, so each accessor 99 // returns false if the associated field is not populated in the underlying 100 // platform. 101 class KernelMetadata { 102 public: KernelMetadata()103 KernelMetadata() 104 : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {} 105 106 // Returns the number of registers used per thread executing this kernel. 107 bool registers_per_thread(int *registers_per_thread) const; 108 109 // Sets the number of registers used per thread executing this kernel. 110 void set_registers_per_thread(int registers_per_thread); 111 112 // Returns the amount of [static] shared memory used per block executing this 113 // kernel. Note that dynamic shared memory allocations are not (and can not) 114 // be reported here (since they're not specified until kernel launch time). 115 bool shared_memory_bytes(int *shared_memory_bytes) const; 116 117 // Sets the amount of [static] shared memory used per block executing this 118 // kernel. 
119 void set_shared_memory_bytes(int shared_memory_bytes); 120 121 private: 122 // Holds the value returned by registers_per_thread above. 123 bool has_registers_per_thread_; 124 int registers_per_thread_; 125 126 // Holds the value returned by shared_memory_bytes above. 127 bool has_shared_memory_bytes_; 128 int64 shared_memory_bytes_; 129 }; 130 131 // A data-parallel kernel (code entity) for launching via the StreamExecutor, 132 // analogous to a void* device function pointer. See TypedKernel for the typed 133 // variant. 134 // 135 // Thread-compatible. 136 class KernelBase { 137 public: 138 KernelBase(KernelBase &&from); 139 140 // Constructs an "empty" (not-yet-loaded) kernel instance. 141 // 142 // parent is the StreamExecutor that will be responsible for loading the 143 // implementation of this kernel. It must not be null. 144 explicit KernelBase(StreamExecutor *parent); 145 146 // Test-only constructor that can take a mock KernelInterface implementation. 147 KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation); 148 149 // Releases resources associated with the kernel instance (i.e. 150 // platform-specific implementation). 151 ~KernelBase(); 152 153 // Returns the number of parameters that this kernel accepts. (Arity refers to 154 // nullary, unary, ...). 155 unsigned Arity() const; 156 157 // Returns the StreamExecutor that represents the platform this kernel 158 // executes upon. parent()159 StreamExecutor *parent() const { return parent_; } 160 161 // Returns a const pointer to the (opaque) platform-dependent implementation. implementation()162 const internal::KernelInterface *implementation() const { 163 return implementation_.get(); 164 } 165 166 // Returns a non-const pointer to the (opaque) platform-dependent 167 // implementation. 
implementation()168 internal::KernelInterface *implementation() { return implementation_.get(); } 169 set_metadata(const KernelMetadata & metadata)170 void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; } 171 metadata()172 const KernelMetadata &metadata() const { return metadata_; } 173 174 // Sets the preferred cache configuration for a kernel. This is just a 175 // suggestion to the runtime, and may not be honored during execution. 176 void SetPreferredCacheConfig(KernelCacheConfig config); 177 178 // Gets the preferred cache configuration for a kernel. 179 KernelCacheConfig GetPreferredCacheConfig() const; 180 181 void set_name(absl::string_view name); name()182 const std::string &name() const { return name_; } demangled_name()183 const std::string &demangled_name() const { return demangled_name_; } 184 185 private: 186 // The StreamExecutor that loads this kernel object. 187 StreamExecutor *parent_; 188 189 // Implementation delegated to for platform-specific functionality. 190 std::unique_ptr<internal::KernelInterface> implementation_; 191 192 std::string name_; 193 std::string demangled_name_; 194 195 KernelMetadata metadata_; 196 197 SE_DISALLOW_COPY_AND_ASSIGN(KernelBase); 198 }; 199 200 // Whether T is a DeviceMemory-family pointer. 201 template <typename T> 202 struct IsDeviceMemoryPointer { 203 static constexpr bool value = false; 204 }; 205 206 template <typename U> 207 struct IsDeviceMemoryPointer<DeviceMemory<U> *> { 208 static constexpr bool value = true; 209 }; 210 211 template <> 212 struct IsDeviceMemoryPointer<DeviceMemoryBase *> { 213 static constexpr bool value = true; 214 }; 215 216 // Whether T is a DeviceMemory-family value-like thing (which includes a 217 // reference). This trait is useful because we pack values in the same manner as 218 // references. 
219 template <typename T> 220 struct IsDeviceMemoryValueLike { 221 static constexpr bool value = false; 222 }; 223 224 template <typename U> 225 struct IsDeviceMemoryValueLike<DeviceMemory<U> &> { 226 static constexpr bool value = true; 227 }; 228 229 // We need to treat SharedDeviceMemory types differently than other DeviceMemory 230 // types (since they maintain no allocations), hence these specializations. 231 template <typename U> 232 struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> { 233 static constexpr bool value = false; 234 }; 235 236 template <> 237 struct IsDeviceMemoryValueLike<DeviceMemoryBase &> { 238 static constexpr bool value = true; 239 }; 240 241 template <typename U> 242 struct IsDeviceMemoryValueLike<DeviceMemory<U>> { 243 static constexpr bool value = true; 244 }; 245 246 template <typename U> 247 struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> { 248 static constexpr bool value = false; 249 }; 250 251 template <> 252 struct IsDeviceMemoryValueLike<DeviceMemoryBase> { 253 static constexpr bool value = true; 254 }; 255 256 template <typename U> 257 struct IsSharedDeviceMemory { 258 static constexpr bool value = false; 259 }; 260 261 template <typename U> 262 struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> { 263 static constexpr bool value = true; 264 }; 265 266 template <typename U> 267 struct IsSharedDeviceMemory<SharedDeviceMemory<U>> { 268 static constexpr bool value = true; 269 }; 270 271 // Basic data about a kernel argument. 272 struct KernelArg { 273 bool is_shared; 274 const void *address; 275 size_t size; 276 }; 277 278 // An iterator for traversing all the arguments of a KernelArgsArray. 
279 class KernelArgIterator { 280 public: 281 KernelArgIterator(int number_of_argument_addresses, 282 int number_of_shared_memory_arguments, 283 const void *const *arg_addresses_data, 284 const size_t *arg_sizes_data, 285 const size_t *shmem_bytes_data, 286 const size_t *shmem_indices_data) 287 : arg_index_(0), 288 number_of_arguments_(number_of_argument_addresses + 289 number_of_shared_memory_arguments), 290 arg_address_iter_(arg_addresses_data), 291 arg_size_iter_(arg_sizes_data), 292 shmem_bytes_iter_(shmem_bytes_data), 293 shmem_indices_iter_(shmem_indices_data), 294 shmem_indices_end_(shmem_indices_data + 295 number_of_shared_memory_arguments) {} 296 297 // Returns true if another argument is present in the iterator. 298 bool has_next() { return arg_index_ < number_of_arguments_; } 299 300 // Returns the next argument in the iterator. 301 // 302 // Returns a default-constructed KernelArg if there is no next argument. 303 KernelArg next() { 304 KernelArg result = {}; 305 if (!has_next()) { 306 return result; 307 } else if ((shmem_indices_iter_ != shmem_indices_end_) && 308 (arg_index_ == *shmem_indices_iter_)) { 309 result.is_shared = true; 310 result.address = nullptr; 311 result.size = *shmem_bytes_iter_; 312 ++shmem_indices_iter_; 313 ++shmem_bytes_iter_; 314 } else { 315 result.is_shared = false; 316 result.address = *arg_address_iter_; 317 result.size = *arg_size_iter_; 318 ++arg_address_iter_; 319 ++arg_size_iter_; 320 } 321 ++arg_index_; 322 return result; 323 } 324 325 private: 326 size_t arg_index_; 327 size_t number_of_arguments_; 328 const void *const *arg_address_iter_; 329 const size_t *arg_size_iter_; 330 const size_t *shmem_bytes_iter_; 331 const size_t *shmem_indices_iter_; 332 const size_t *const shmem_indices_end_; 333 }; 334 335 // Base class for KernelArgsArray. 336 // 337 // Supports all the getter methods that do not depend on the compile-time number 338 // of arguments template parameter. 
339 // 340 // This class exists as a way to pass kernel arguments to 341 // StreamExecutorInterface::Launch. That Launch method is virtual, so it can't 342 // be templated to accept any KernelArgsArray type, therefore a reference to 343 // this base type is passed instead. 344 // 345 // Performance is not a concern here because each of these methods will be 346 // called at most once per kernel launch. Past performance concerns with 347 // KernelArgsArray have been in reference to the argument packing routines which 348 // are called once per kernel argument. Those packing routines are now handled 349 // by the templated KernelArgsArray subclass of this class where they can take 350 // advantage of compile-time knowledge of the number of arguments in order to be 351 // very efficient. 352 class KernelArgsArrayBase { 353 public: 354 virtual ~KernelArgsArrayBase() = default; 355 356 // Gets the number of arguments added so far, including shared memory 357 // arguments. 358 virtual size_t number_of_arguments() const = 0; 359 360 // Gets the total number of shared memory bytes added so far. 361 virtual uint64 number_of_shared_bytes() const = 0; 362 363 // Gets the list of argument addresses. 364 virtual port::ArraySlice<const void *> argument_addresses() const = 0; 365 366 // Gets an iterator to the arguments in the array. 367 virtual KernelArgIterator arg_iterator() const = 0; 368 }; 369 370 // A list of arguments for a kernel call. 371 // 372 // The template parameter kNumArgs is the maximum number of arguments which can 373 // be stored in the list. 374 // 375 // Contains a list of addresses for non-shared-memory arguments and a list of 376 // sizes for shared-memory arguments. Since the shared-memory arguments may be 377 // interspersed with the non-shared-memory arguments, it also stores a list of 378 // the indices at which the shared-memory arguments appeared. 
379 // 380 // For example, if the argument address list contains {a, b, c, d, e}, the 381 // shared-memory arguments list contains the sizes of {A, B, C}, and the 382 // shared-memory indices list contains {0, 3, 5}, then the original list of 383 // arguments was {A, a, b, B, c, C, d, e}. 384 // 385 // This way of storing the arguments makes CUDA kernel calls efficient because 386 // they only require the argument address list and the total number of shared 387 // bytes, but it also makes it possible for OpenCL kernel calls because they 388 // depend on the location of each shared-memory argument and its size. 389 // 390 // Note that the code for adding arguments has been identified as a performance 391 // hotspot in some real-world applications so this structure has been optimized 392 // for the performance of argument adding. 393 template <size_t kNumArgs> 394 class KernelArgsArray : public KernelArgsArrayBase { 395 public: 396 static constexpr int kMaxGenericArgSize = 8; 397 398 // Adds an argument to the list. 399 template <typename T> 400 void add_argument(const T &arg) { 401 static_assert(sizeof(T) <= kMaxGenericArgSize, 402 "Please adjust kMaxGenericArgSize"); 403 static_assert(std::is_pod<T>::value, "Only pod types supported!"); 404 char *generic_arg_storage = 405 &generic_arguments_[number_of_generic_arguments_++ * 406 kMaxGenericArgSize]; 407 408 CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0); 409 std::memcpy(generic_arg_storage, &arg, sizeof(T)); 410 411 argument_addresses_[number_of_argument_addresses_] = generic_arg_storage; 412 argument_sizes_[number_of_argument_addresses_] = sizeof(arg); 413 ++number_of_argument_addresses_; 414 } 415 416 // Adds a device memory argument to the list. 
417 void add_device_memory_argument(const DeviceMemoryBase &arg) { 418 const void **copy_ptr = 419 &device_memory_opaque_pointers_[number_of_argument_addresses_]; 420 *copy_ptr = arg.opaque(); 421 argument_addresses_[number_of_argument_addresses_] = copy_ptr; 422 argument_sizes_[number_of_argument_addresses_] = sizeof(void *); 423 ++number_of_argument_addresses_; 424 } 425 426 // Adds a shared memory argument to the list. 427 // 428 // The only significant information about a shared argument is its size, so 429 // that is the only parameter in this function. 430 void add_shared_bytes(size_t number_of_bytes) { 431 shared_memory_indices_[number_of_shared_memory_arguments_] = 432 number_of_argument_addresses_ + number_of_shared_memory_arguments_; 433 shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes; 434 ++number_of_shared_memory_arguments_; 435 total_shared_memory_bytes_ += number_of_bytes; 436 } 437 438 // Gets the number of arguments added so far, including shared memory 439 // arguments. 440 size_t number_of_arguments() const override { 441 return number_of_argument_addresses_ + number_of_shared_memory_arguments_; 442 } 443 444 // Gets the total number of shared memory bytes added so far. 445 uint64 number_of_shared_bytes() const override { 446 return total_shared_memory_bytes_; 447 } 448 449 // Gets the list of argument addresses. 450 port::ArraySlice<const void *> argument_addresses() const override { 451 return port::ArraySlice<const void *>(argument_addresses_.data(), 452 number_of_argument_addresses_); 453 } 454 455 // Gets an iterator to the arguments in the array. 
456 KernelArgIterator arg_iterator() const override { 457 return KernelArgIterator( 458 number_of_argument_addresses_, number_of_shared_memory_arguments_, 459 argument_addresses_.data(), argument_sizes_.data(), 460 shared_memory_bytes_.data(), shared_memory_indices_.data()); 461 } 462 463 private: 464 // A place to store copies of opaque pointers from device memory arguments. 465 std::array<const void *, kNumArgs> device_memory_opaque_pointers_; 466 467 // Addresses for non-shared-memory arguments. 468 std::array<const void *, kNumArgs> argument_addresses_; 469 470 // Storage for arguments of templated type. 471 alignas(kMaxGenericArgSize) 472 std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_; 473 474 // Sizes for non-shared-memory arguments. 475 std::array<size_t, kNumArgs> argument_sizes_; 476 477 // Size in bytes for each shared memory argument. 478 std::array<size_t, kNumArgs> shared_memory_bytes_; 479 480 // Indices in the arguments array for shared memory arguments. 481 std::array<size_t, kNumArgs> shared_memory_indices_; 482 483 // Total of all shared memory sizes. 484 size_t total_shared_memory_bytes_ = 0; 485 486 // Number of significant entries in argument_addresses_ and argument_sizes_. 487 size_t number_of_argument_addresses_ = 0; 488 489 // Number of significant entries in shared_memory_bytes_ and 490 // shared_memory_indices_. 491 size_t number_of_shared_memory_arguments_ = 0; 492 493 // The number of generic arguments that have been added to generic_arguments_. 494 size_t number_of_generic_arguments_ = 0; 495 }; 496 497 // Typed variant of KernelBase, like a typed device function pointer. See the 498 // file comment for details and example usage. 499 // 500 // This class contains template metaprogramming magic to type check the 501 // parameters passed to a kernel launch are acceptable, and subsequently pack 502 // them into a form which can be used by the StreamExecutorInterface 503 // implementation. (i.e. 
CUDA and OpenCL both bind void*s with associated 504 // sizes as kernel arguments.) 505 // 506 // Thread-compatible. 507 template <typename... Params> 508 class TypedKernel : public KernelBase { 509 public: 510 static constexpr size_t kNumberOfParameters = sizeof...(Params); 511 512 // Delegates to KernelBase::KernelBase(), see that constructor. 513 explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {} 514 515 // Test-only constructor that can take a mock KernelInterface implementation. 516 // Takes ownership of implementation, it should not be null. 517 TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation) 518 : KernelBase(parent, implementation) {} 519 520 private: 521 // Stream needs access to the specific parameter-packing functionality that 522 // the TypedKernel provides for its corresponding type signature (and no other 523 // type signatures). 524 friend class Stream; 525 526 // This is the main entry point into the magic. Packs the parameters (which 527 // must type check against the class template) into the args and sizes 528 // arrays. 529 // 530 // Const refs are taken as parameters on all of the handlers to avoid 531 // implicit type promotion of integers. 532 // 533 // WARNING: as a performance optimization this method may store pointers to 534 // some of the input parameters in the kernel args structure, so any params 535 // passed into this method must live at least as long as the kernel args 536 // structure. 537 void PackParams(KernelArgsArray<kNumberOfParameters> *args, 538 Params &... params) const { 539 PackOneParamFromList(args, params...); 540 } 541 542 template <typename T, typename... RestOfParams> 543 void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args, 544 const T &arg, const RestOfParams &... rest) const { 545 PackOneParam(args, arg); 546 PackOneParamFromList(args, rest...); 547 } 548 549 // Base case for variadic template expansion - nothing to do! 
550 void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args) const {} 551 552 // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array. 553 // The enable_if<> is for excluding DeviceMemoryBase args, which have a 554 // separate implementation below. 555 template <typename T> 556 void PackOneParam( 557 KernelArgsArray<kNumberOfParameters> *args, const T &arg, 558 typename std::enable_if<!IsDeviceMemoryValueLike<T>::value && 559 !IsDeviceMemoryPointer<T>::value && 560 !IsSharedDeviceMemory<T>::value>::type * = 561 nullptr) const { 562 static_assert(!std::is_pointer<T>::value, 563 "cannot pass raw pointer to the device"); 564 static_assert(!std::is_convertible<T, DeviceMemoryBase>::value, 565 "cannot pass device memory as a normal value"); 566 args->add_argument(arg); 567 } 568 569 // DeviceMemoryBase family reference override. 570 template <typename T> 571 void PackOneParam( 572 KernelArgsArray<kNumberOfParameters> *args, const T &arg, 573 typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * = 574 nullptr) const { 575 args->add_device_memory_argument(arg); 576 } 577 578 // DeviceMemoryBase family pointer override. 579 template <typename T> 580 void PackOneParam( 581 KernelArgsArray<kNumberOfParameters> *args, T arg, 582 typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * = 583 nullptr) const { 584 DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg); 585 args->add_device_memory_argument(*ptr); 586 } 587 588 // Dynamic shared device memory has a size, but no associated allocation on 589 // the host; internally, the device will allocate storage. 
590 template <typename T> 591 void PackOneParam( 592 KernelArgsArray<kNumberOfParameters> *args, T arg, 593 typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * = 594 nullptr) const { 595 args->add_shared_bytes(arg.size()); 596 } 597 598 SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel); 599 }; 600 601 // Template metaprogramming helper type that helps us produce better error 602 // messages at compile time when the are mismatches between the parameter 603 // type list and the argument type list. 604 template <typename ParamTuple, typename ArgTuple> 605 struct KernelInvocationChecker { 606 // Whether the parameter tuple and argument tuple match in length. 607 static constexpr bool kLengthMatches = 608 std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value; 609 610 // The (matching) length of the parameters and arguments type lists. 611 static constexpr int kTupleLength = 612 static_cast<int>(std::tuple_size<ArgTuple>::value); 613 614 // Helper trait to say whether the parameter wants a DeviceMemory-reference 615 // compatible type. This is for inexact type matches, so that it doesn't have 616 // to be precisely a const DeviceMemory<T>&, but can also be a value that 617 // represents the same. 618 template <typename ParamType, typename ArgType> 619 struct IsCompatibleDeviceMemoryRef { 620 static constexpr bool value = false; 621 }; 622 623 // See type trait definition above. 624 template <typename U> 625 struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> { 626 static constexpr bool value = true; 627 }; 628 629 // See type trait definition above. 630 template <typename U> 631 struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &, 632 SharedDeviceMemory<U>> { 633 static constexpr bool value = true; 634 }; 635 636 // Returns whether ParamT and ArgT are compatible for data parallel kernel 637 // parameter packing without any assert functionality. 
638 template <typename ParamT, typename ArgT> 639 static constexpr bool CompatibleNoAssert() { 640 return std::is_same<typename std::remove_const<ParamT>::type, 641 ArgT>::value || 642 IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value; 643 } 644 645 // Checks whether ParamT and ArgT are compatible for data parallel kernel 646 // parameter packing. kArgumentNumber is unused, it just for error display. 647 // 648 // NOTE: if you encounter an error here, you can see the mismatch by looking 649 // at the end of the last error message, which will be of the form: 650 // 651 // ...::Compatible<const stream_executor::DeviceMemory<OneThing> &, 652 // stream_executor::DeviceMemory<AnotherThing>, true, 653 // 0>' 654 // requested here 655 // 656 // This means that the 0th argument you passed to the kernel invocation should 657 // have been DeviceMemory<OneThing> but was observed to be 658 // DeviceMemory<AnotherThing>. 659 template <typename ParamT, typename ArgT, bool kShouldStaticAssert, 660 int kArgumentNumber> 661 static constexpr bool Compatible() { 662 static_assert( 663 kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true, 664 "parameter type (LHS) is not compatible with argument type (RHS)"); 665 return CompatibleNoAssert<ParamT, ArgT>(); 666 } 667 668 // Checks the parameter/argument match at kArgumentNumber for an out of bounds 669 // argument number. 670 // 671 // This is the base case: we've run out of argument to check, so we're all 672 // good. 673 template <int kArgumentNumber, bool kShouldStaticAssert> 674 static constexpr bool CheckParam( 675 typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) { 676 return true; 677 } 678 679 // Checks the parameter/argument match at kArgumentNumber. 680 // kShouldStaticAssert determines whether to assert out on a mismatch, or just 681 // yield the constexpr boolean value. 
682 template <int kArgumentNumber, bool kShouldStaticAssert> 683 static constexpr bool CheckParam( 684 typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) { 685 typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type 686 ParamT; 687 typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT; 688 return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() && 689 CheckParam<kArgumentNumber - 1, kShouldStaticAssert>(); 690 } 691 692 // Checks the parameters/arguments for match, but doesn't static assert out. 693 // This is useful for testing/inspecting whether a set of parameters match in 694 // things like tests. 695 static constexpr bool CheckAllNoStaticAssert() { 696 return kLengthMatches && CheckParam<kTupleLength - 1, false>(); 697 } 698 699 // Checks the parameters and static asserts out with a helpful error message 700 // (and useful template parameters in the instantiation stack) if there is an 701 // error. 702 static constexpr bool CheckAllStaticAssert() { 703 static_assert(kLengthMatches, 704 "argument length mismatched against typed kernel parameters"); 705 return kLengthMatches && CheckParam<kTupleLength - 1, true>(); 706 } 707 }; 708 709 // This is a convenience type for checking whether a typed kernel matches 710 // against a type list. 711 template <typename KernelT, typename... Params> 712 struct KernelParamsOk { 713 static constexpr bool kResult = false; 714 }; 715 716 // See above. 717 template <typename... Params, typename... Args> 718 struct KernelParamsOk<TypedKernel<Params...>, Args...> { 719 static constexpr bool kResult = KernelInvocationChecker< 720 std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert(); 721 }; 722 723 } // namespace stream_executor 724 725 #endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 726