/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// Kernel is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase; they see typed kernels, analogous to
// typed function pointers. TypedKernels express their argument types via
// template parameters like so:
//
//  TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//  void(int*, int);
//
// And for a const memory region:
//
//  TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//  void(const int*, int)
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'd from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety that
// arrays of minimum sizes are passed when those minimum sizes are
// contractually expected by the kernel.
//
// For user-defined types whose definitions are appropriately shared between
// the host code doing the launching and the kernel code being launched, such
// user-defined types may similarly be expressed as residing in device memory:
//
//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types will also be
// able to be passed by value; for example, it is a common idiom to specify a
// bunch of options simultaneously with a structure:
//
//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//  void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full;
// it will be typedef'd by automatically generated code, for example, see
// stream_executor::executor_sample::VecReduceAddKernel.
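//
// As a further, purely illustrative example, a kernel with the data parallel
// signature
//
//  void(const float*, float*, int);
//
// could be given a host-side type such as:
//
//  using AddVectorsKernel =
//      TypedKernel<const DeviceMemory<float>&, DeviceMemory<float>*, int>;
//
// (AddVectorsKernel is a hypothetical alias used only for illustration.)
// Launches of typed kernels are performed through Stream, which uses the
// parameter-packing machinery defined in this header; the launch API itself
// is declared elsewhere.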

#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_

#include <array>
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace stream_executor {

class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;
class StreamExecutor;

namespace internal {
class KernelInterface;
}  // namespace internal

// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such
// as registers allocated, shared memory used, etc.
// Not all platforms support reporting of all information, so each accessor
// returns false if the associated field is not populated in the underlying
// platform.
class KernelMetadata {
 public:
  KernelMetadata()
      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}

  // Returns the number of registers used per thread executing this kernel.
  bool registers_per_thread(int *registers_per_thread) const;

  // Sets the number of registers used per thread executing this kernel.
  void set_registers_per_thread(int registers_per_thread);

  // Returns the amount of [static] shared memory used per block executing
  // this kernel. Note that dynamic shared memory allocations are not (and
  // can not) be reported here (since they're not specified until kernel
  // launch time).
  bool shared_memory_bytes(int *shared_memory_bytes) const;

  // Sets the amount of [static] shared memory used per block executing this
  // kernel.
  void set_shared_memory_bytes(int shared_memory_bytes);

 private:
  // Holds the value returned by registers_per_thread above.
  bool has_registers_per_thread_;
  int registers_per_thread_;

  // Holds the value returned by shared_memory_bytes above.
  bool has_shared_memory_bytes_;
  int64 shared_memory_bytes_;
};

// A data-parallel kernel (code entity) for launching via the StreamExecutor,
// analogous to a void* device function pointer. See TypedKernel for the typed
// variant.
//
// Thread-compatible.
class KernelBase {
 public:
  KernelBase(KernelBase &&from);

  // Constructs an "empty" (not-yet-loaded) kernel instance.
  //
  // parent is the StreamExecutor that will be responsible for loading the
  // implementation of this kernel. It must not be null.
  explicit KernelBase(StreamExecutor *parent);

  // Test-only constructor that can take a mock KernelInterface implementation.
  KernelBase(StreamExecutor *parent,
             internal::KernelInterface *implementation);

  // Releases resources associated with the kernel instance (i.e.
  // platform-specific implementation).
  ~KernelBase();

  // Returns the number of parameters that this kernel accepts. (Arity refers
  // to nullary, unary, ...).
  unsigned Arity() const;

  // Returns the StreamExecutor that represents the platform this kernel
  // executes upon.
  StreamExecutor *parent() const { return parent_; }

  // Returns a const pointer to the (opaque) platform-dependent implementation.
  const internal::KernelInterface *implementation() const {
    return implementation_.get();
  }

  // Returns a non-const pointer to the (opaque) platform-dependent
  // implementation.
  internal::KernelInterface *implementation() { return implementation_.get(); }

  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }

  const KernelMetadata &metadata() const { return metadata_; }

  // Sets the preferred cache configuration for a kernel. This is just a
  // suggestion to the runtime, and may not be honored during execution.
  void SetPreferredCacheConfig(KernelCacheConfig config);

  // Gets the preferred cache configuration for a kernel.
  KernelCacheConfig GetPreferredCacheConfig() const;

  void set_name(absl::string_view name);
  const string &name() const { return name_; }
  const string &demangled_name() const { return demangled_name_; }

 private:
  // The StreamExecutor that loads this kernel object.
  StreamExecutor *parent_;

  // Implementation delegated to for platform-specific functionality.
  std::unique_ptr<internal::KernelInterface> implementation_;

  string name_;
  string demangled_name_;

  KernelMetadata metadata_;

  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
};

// Whether T is a DeviceMemory-family pointer.
template <typename T>
struct IsDeviceMemoryPointer {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
  static constexpr bool value = true;
};

template <>
struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
  static constexpr bool value = true;
};

// Whether T is a DeviceMemory-family value-like thing (which includes a
// reference). This trait is useful because we pack values in the same manner
// as references.
template <typename T>
struct IsDeviceMemoryValueLike {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
  static constexpr bool value = true;
};

// We need to treat SharedDeviceMemory types differently than other
// DeviceMemory types (since they maintain no allocations), hence these
// specializations.
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory {
  static constexpr bool value = false;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
  static constexpr bool value = true;
};
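// The checks below are purely illustrative and behavior-neutral: they show
// how the traits above classify a few representative types, using int as an
// arbitrary element type. They can be removed without affecting anything at
// runtime.
static_assert(IsDeviceMemoryPointer<DeviceMemory<int> *>::value,
              "DeviceMemory<T>* is a DeviceMemory-family pointer");
static_assert(!IsDeviceMemoryPointer<int *>::value,
              "a raw pointer is not a DeviceMemory-family pointer");
static_assert(IsDeviceMemoryValueLike<DeviceMemory<int>>::value,
              "a DeviceMemory<T> value is DeviceMemory-value-like");
static_assert(!IsDeviceMemoryValueLike<SharedDeviceMemory<int>>::value,
              "SharedDeviceMemory<T> is excluded from the value-like trait");
static_assert(IsSharedDeviceMemory<SharedDeviceMemory<int>>::value,
              "SharedDeviceMemory<T> is detected by IsSharedDeviceMemory");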
// Basic data about a kernel argument.
struct KernelArg {
  bool is_shared;
  const void *address;
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray.
class KernelArgIterator {
 public:
  KernelArgIterator(int number_of_argument_addresses,
                    int number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        number_of_arguments_(number_of_argument_addresses +
                             number_of_shared_memory_arguments),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator.
  bool has_next() { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator.
  //
  // Returns a default-constructed KernelArg if there is no next argument.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    } else if ((shmem_indices_iter_ != shmem_indices_end_) &&
               (arg_index_ == *shmem_indices_iter_)) {
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  size_t arg_index_;
  size_t number_of_arguments_;
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  const size_t *const shmem_indices_end_;
};

// Base class for KernelArgsArray.
//
// Supports all the getter methods that do not depend on the compile-time
// number-of-arguments template parameter.
//
// This class exists as a way to pass kernel arguments to
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
// be templated to accept any KernelArgsArray type; instead, a reference to
// this base type is passed.
//
// Performance is not a concern here because each of these methods will be
// called at most once per kernel launch. Past performance concerns with
// KernelArgsArray have been in reference to the argument packing routines,
// which are called once per kernel argument. Those packing routines are now
// handled by the templated KernelArgsArray subclass of this class, where they
// can take advantage of compile-time knowledge of the number of arguments in
// order to be very efficient.
class KernelArgsArrayBase {
 public:
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far.
  virtual uint64 number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses.
  virtual port::ArraySlice<const void *> argument_addresses() const = 0;

  // Gets an iterator to the arguments in the array.
  virtual KernelArgIterator arg_iterator() const = 0;
};

// A list of arguments for a kernel call.
//
// The template parameter kNumArgs is the maximum number of arguments which
// can be stored in the list.
//
// Contains a list of addresses for non-shared-memory arguments and a list of
// sizes for shared-memory arguments. Since the shared-memory arguments may be
// interspersed with the non-shared-memory arguments, it also stores a list of
// the indices at which the shared-memory arguments appeared.
//
// For example, if the argument address list contains {a, b, c, d, e}, the
// shared-memory arguments list contains the sizes of {A, B, C}, and the
// shared-memory indices list contains {0, 3, 5}, then the original list of
// arguments was {A, a, b, B, c, C, d, e}.
//
// This way of storing the arguments makes CUDA kernel calls efficient,
// because they only require the argument address list and the total number
// of shared bytes; it also supports OpenCL kernel calls, which depend on the
// location and size of each shared-memory argument.
//
// Note that the code for adding arguments has been identified as a
// performance hotspot in some real-world applications, so this structure has
// been optimized for the performance of argument adding.
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
 public:
  explicit KernelArgsArray()
      : total_shared_memory_bytes_(0),
        number_of_argument_addresses_(0),
        number_of_shared_memory_arguments_(0) {}

  // Adds an argument to the list.
  //
  // Note that the address of the argument is stored, so the input must not go
  // out of scope before this instance does.
  template <typename T>
  void add_argument(const T &arg) {
    argument_addresses_[number_of_argument_addresses_] =
        static_cast<const void *>(&arg);
    argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
    ++number_of_argument_addresses_;
  }

  // Adds a device memory argument to the list.
  void add_device_memory_argument(const DeviceMemoryBase &arg) {
    const void **copy_ptr =
        &device_memory_opaque_pointers_[number_of_argument_addresses_];
    *copy_ptr = arg.opaque();
    argument_addresses_[number_of_argument_addresses_] = copy_ptr;
    argument_sizes_[number_of_argument_addresses_] = sizeof(void *);
    ++number_of_argument_addresses_;
  }

  // Adds a shared memory argument to the list.
  //
  // The only significant information about a shared argument is its size, so
  // that is the only parameter in this function.
  void add_shared_bytes(size_t number_of_bytes) {
    shared_memory_indices_[number_of_shared_memory_arguments_] =
        number_of_argument_addresses_ + number_of_shared_memory_arguments_;
    shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
    ++number_of_shared_memory_arguments_;
    total_shared_memory_bytes_ += number_of_bytes;
  }

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  size_t number_of_arguments() const override {
    return number_of_argument_addresses_ + number_of_shared_memory_arguments_;
  }

  // Gets the total number of shared memory bytes added so far.
  uint64 number_of_shared_bytes() const override {
    return total_shared_memory_bytes_;
  }

  // Gets the list of argument addresses.
  port::ArraySlice<const void *> argument_addresses() const override {
    return port::ArraySlice<const void *>(argument_addresses_.data(),
                                          number_of_argument_addresses_);
  }

  // Gets an iterator to the arguments in the array.
  KernelArgIterator arg_iterator() const override {
    return KernelArgIterator(
        number_of_argument_addresses_, number_of_shared_memory_arguments_,
        argument_addresses_.data(), argument_sizes_.data(),
        shared_memory_bytes_.data(), shared_memory_indices_.data());
  }

 private:
  // A place to store copies of opaque pointers from device memory arguments.
  std::array<const void *, kNumArgs> device_memory_opaque_pointers_;

  // Addresses for non-shared-memory arguments.
  std::array<const void *, kNumArgs> argument_addresses_;

  // Sizes for non-shared-memory arguments.
  std::array<size_t, kNumArgs> argument_sizes_;

  // Size in bytes for each shared memory argument.
  std::array<size_t, kNumArgs> shared_memory_bytes_;

  // Indices in the arguments array for shared memory arguments.
  std::array<size_t, kNumArgs> shared_memory_indices_;

  // Total of all shared memory sizes.
  size_t total_shared_memory_bytes_;

  // Number of significant entries in argument_addresses_ and argument_sizes_.
  size_t number_of_argument_addresses_;

  // Number of significant entries in shared_memory_bytes_ and
  // shared_memory_indices_.
  size_t number_of_shared_memory_arguments_;
};

// Typed variant of KernelBase, like a typed device function pointer. See the
// file comment for details and example usage.
//
// This class contains template metaprogramming magic to type check that the
// parameters passed to a kernel launch are acceptable, and subsequently pack
// them into a form which can be used by the StreamExecutorInterface
// implementation. (i.e. CUDA and OpenCL both bind void*s with associated
// sizes as kernel arguments.)
//
// Thread-compatible.
template <typename... Params>
class TypedKernel : public KernelBase {
 public:
  static constexpr size_t kNumberOfParameters = sizeof...(Params);

  // Delegates to KernelBase::KernelBase(); see that constructor.
  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}

  // Test-only constructor that can take a mock KernelInterface
  // implementation. Takes ownership of implementation; it must not be null.
  TypedKernel(StreamExecutor *parent,
              internal::KernelInterface *implementation)
      : KernelBase(parent, implementation) {}

 private:
  // Stream needs access to the specific parameter-packing functionality that
  // the TypedKernel provides for its corresponding type signature (and no
  // other type signatures).
  friend class Stream;

  // This is the main entry point into the magic. Packs the parameters (which
  // must type check against the class template) into the args and sizes
  // arrays.
  //
  // Const refs are taken as parameters on all of the handlers to avoid
  // implicit type promotion of integers.
  //
  // WARNING: as a performance optimization this method may store pointers to
  // some of the input parameters in the kernel args structure, so any params
  // passed into this method must live at least as long as the kernel args
  // structure.
  void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                  Params &... params) const {
    PackOneParam(args, params...);
  }

  template <typename T, typename... RestOfParams>
  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg,
                    const RestOfParams &... rest) const {
    PackOneParam(args, arg);
    PackOneParam(args, rest...);
  }

  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
  // separate implementation below.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
                              !IsDeviceMemoryPointer<T>::value &&
                              !IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    static_assert(!std::is_pointer<T>::value,
                  "cannot pass raw pointer to the device");
    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
                  "cannot pass device memory as a normal value");
    args->add_argument(arg);
  }

  // DeviceMemoryBase family reference override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
          nullptr) const {
    args->add_device_memory_argument(arg);
  }

  // DeviceMemoryBase family pointer override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
          nullptr) const {
    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
    args->add_device_memory_argument(*ptr);
  }

  // Dynamic shared device memory has a size, but no associated allocation on
  // the host; internally, the device will allocate storage.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    args->add_shared_bytes(arg.size());
  }

  // Base case for variadic template expansion - nothing to do!
  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {}

  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};

// Template metaprogramming helper type that helps us produce better error
// messages at compile time when there are mismatches between the parameter
// type list and the argument type list.
template <typename ParamTuple, typename ArgTuple>
struct KernelInvocationChecker {
  // Whether the parameter tuple and argument tuple match in length.
  static constexpr bool kLengthMatches =
      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;

  // The (matching) length of the parameters and arguments type lists.
  static constexpr int kTupleLength =
      static_cast<int>(std::tuple_size<ArgTuple>::value);

  // Helper trait to say whether the parameter wants a DeviceMemory-reference
  // compatible type. This is for inexact type matches, so that it doesn't
  // have to be precisely a const DeviceMemory<T>&, but can also be a value
  // that represents the same.
  template <typename ParamType, typename ArgType>
  struct IsCompatibleDeviceMemoryRef {
    static constexpr bool value = false;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &,
                                     DeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
                                     SharedDeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // Returns whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing without any assert functionality.
  template <typename ParamT, typename ArgT>
  static constexpr bool CompatibleNoAssert() {
    return std::is_same<typename std::remove_const<ParamT>::type,
                        ArgT>::value ||
           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
  }

  // Checks whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing. kArgumentNumber is unused; it is just for error
  // display.
  //
  // NOTE: if you encounter an error here, you can see the mismatch by looking
  // at the end of the last error message, which will be of the form:
  //
  //    ...::Compatible<const stream_executor::DeviceMemory<OneThing> &,
  //                    stream_executor::DeviceMemory<AnotherThing>, true,
  //                    0>'
  //    requested here
  //
  // This means that the 0th argument you passed to the kernel invocation
  // should have been DeviceMemory<OneThing> but was observed to be
  // DeviceMemory<AnotherThing>.
  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
            int kArgumentNumber>
  static constexpr bool Compatible() {
    static_assert(
        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
        "parameter type (LHS) is not compatible with argument type (RHS)");
    return CompatibleNoAssert<ParamT, ArgT>();
  }

  // Checks the parameter/argument match at kArgumentNumber for an out of
  // bounds argument number.
  //
  // This is the base case: we've run out of arguments to check, so we're all
  // good.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
    return true;
  }

  // Checks the parameter/argument match at kArgumentNumber.
  // kShouldStaticAssert determines whether to assert out on a mismatch, or
  // just yield the constexpr boolean value.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<(kArgumentNumber >= 0)>::type *dummy = nullptr) {
    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
        ParamT;
    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
  }

  // Checks the parameters/arguments for a match, but doesn't static assert
  // out. This is useful for testing/inspecting whether a set of parameters
  // match in things like tests.
  static constexpr bool CheckAllNoStaticAssert() {
    return kLengthMatches && CheckParam<kTupleLength - 1, false>();
  }

  // Checks the parameters and static asserts out with a helpful error message
  // (and useful template parameters in the instantiation stack) if there is
  // an error.
  static constexpr bool CheckAllStaticAssert() {
    static_assert(kLengthMatches,
                  "argument length mismatched against typed kernel parameters");
    return kLengthMatches && CheckParam<kTupleLength - 1, true>();
  }
};

// This is a convenience type for checking whether a typed kernel matches
// against a type list.
template <typename KernelT, typename... Params>
struct KernelParamsOk {
  static constexpr bool kResult = false;
};

// See above.
template <typename... Params, typename... Args>
struct KernelParamsOk<TypedKernel<Params...>, Args...> {
  static constexpr bool kResult = KernelInvocationChecker<
      std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
};

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_