/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// Kernel is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase; they see typed kernels, analogous to a
// typed function pointer. TypedKernels express their argument types via
// template parameters like so:
//
//  TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//  void(int*, int);
//
// And for a const memory region:
//
//  TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//  void(const int*, int)
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'd from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety, verifying
// that arrays of at least the minimum sizes contractually expected by the
// kernel are passed.
//
// For user-defined types whose definitions are appropriately shared between the
// host code doing the launching and the kernel code being launched, such types
// are similarly permitted to be expressed as residing in device memory:
//
//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types can also be
// passed by value; for example, it is a common idiom to specify a bunch of
// options simultaneously with a structure:
//
//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//  void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full; it
// will be typedef'd by automatically generated code; for example, see
// stream_executor::executor_sample::VecReduceAddKernel.

#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_

#include <array>
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace stream_executor {

class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;
class StreamExecutor;

namespace internal {
class KernelInterface;
}  // namespace internal

// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
// registers allocated, shared memory used, etc.
// Not all platforms support reporting of all information, so each accessor
// returns false if the associated field is not populated in the underlying
// platform.
class KernelMetadata {
 public:
  KernelMetadata()
      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}

  // Returns the number of registers used per thread executing this kernel.
  bool registers_per_thread(int *registers_per_thread) const;

  // Sets the number of registers used per thread executing this kernel.
  void set_registers_per_thread(int registers_per_thread);

  // Returns the amount of [static] shared memory used per block executing this
  // kernel. Note that dynamic shared memory allocations are not (and cannot be)
  // reported here, since they're not specified until kernel launch time.
  bool shared_memory_bytes(int *shared_memory_bytes) const;

  // Sets the amount of [static] shared memory used per block executing this
  // kernel.
  void set_shared_memory_bytes(int shared_memory_bytes);

 private:
  // Holds the value returned by registers_per_thread above.
  bool has_registers_per_thread_;
  int registers_per_thread_;

  // Holds the value returned by shared_memory_bytes above.
  bool has_shared_memory_bytes_;
  int64 shared_memory_bytes_;
};
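
// Example (sketch): querying metadata from a loaded kernel via
// KernelBase::metadata() (declared below). The "kernel" variable is
// illustrative only; not all platforms populate every field, so the accessors
// may return false.
//
//   int regs = 0;
//   int smem = 0;
//   if (kernel.metadata().registers_per_thread(&regs)) {
//     VLOG(1) << "registers per thread: " << regs;
//   }
//   if (kernel.metadata().shared_memory_bytes(&smem)) {
//     VLOG(1) << "static shared memory per block: " << smem << " bytes";
//   }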

// A data-parallel kernel (code entity) for launching via the StreamExecutor,
// analogous to a void* device function pointer. See TypedKernel for the typed
// variant.
//
// Thread-compatible.
class KernelBase {
 public:
  KernelBase(KernelBase &&from);

  // Constructs an "empty" (not-yet-loaded) kernel instance.
  //
  // parent is the StreamExecutor that will be responsible for loading the
  // implementation of this kernel. It must not be null.
  explicit KernelBase(StreamExecutor *parent);

  // Test-only constructor that can take a mock KernelInterface implementation.
  KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);

  // Releases resources associated with the kernel instance (i.e.
  // platform-specific implementation).
  ~KernelBase();

  // Returns the number of parameters that this kernel accepts. (Arity refers to
  // nullary, unary, ...).
  unsigned Arity() const;

  // Returns the StreamExecutor that represents the platform this kernel
  // executes upon.
  StreamExecutor *parent() const { return parent_; }

  // Returns a const pointer to the (opaque) platform-dependent implementation.
  const internal::KernelInterface *implementation() const {
    return implementation_.get();
  }

  // Returns a non-const pointer to the (opaque) platform-dependent
  // implementation.
  internal::KernelInterface *implementation() { return implementation_.get(); }

  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }

  const KernelMetadata &metadata() const { return metadata_; }

  // Sets the preferred cache configuration for a kernel. This is just a
  // suggestion to the runtime, and may not be honored during execution.
  void SetPreferredCacheConfig(KernelCacheConfig config);

  // Gets the preferred cache configuration for a kernel.
  KernelCacheConfig GetPreferredCacheConfig() const;

  void set_name(absl::string_view name);
  const std::string &name() const { return name_; }
  const std::string &demangled_name() const { return demangled_name_; }

 private:
  // The StreamExecutor that loads this kernel object.
  StreamExecutor *parent_;

  // Implementation delegated to for platform-specific functionality.
  std::unique_ptr<internal::KernelInterface> implementation_;

  std::string name_;
  std::string demangled_name_;

  KernelMetadata metadata_;

  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
};

// Whether T is a DeviceMemory-family pointer.
template <typename T>
struct IsDeviceMemoryPointer {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
  static constexpr bool value = true;
};

template <>
struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
  static constexpr bool value = true;
};

// Whether T is a DeviceMemory-family value-like thing (which includes a
// reference). This trait is useful because we pack values in the same manner as
// references.
template <typename T>
struct IsDeviceMemoryValueLike {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
  static constexpr bool value = true;
};

// We need to treat SharedDeviceMemory types differently than other DeviceMemory
// types (since they maintain no allocations), hence these specializations.
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory {
  static constexpr bool value = false;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
  static constexpr bool value = true;
};
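
// Example (sketch): the traits above are usable directly in compile-time
// checks; the assertions below follow from the specializations above.
//
//   static_assert(IsDeviceMemoryValueLike<DeviceMemory<float> &>::value, "");
//   static_assert(IsDeviceMemoryPointer<DeviceMemory<float> *>::value, "");
//   static_assert(!IsSharedDeviceMemory<DeviceMemory<float>>::value, "");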

// Basic data about a kernel argument.
struct KernelArg {
  bool is_shared;
  const void *address;
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray.
class KernelArgIterator {
 public:
  KernelArgIterator(int number_of_argument_addresses,
                    int number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        number_of_arguments_(number_of_argument_addresses +
                             number_of_shared_memory_arguments),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator.
  bool has_next() { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator.
  //
  // Returns a default-constructed KernelArg if there is no next argument.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    } else if ((shmem_indices_iter_ != shmem_indices_end_) &&
               (arg_index_ == *shmem_indices_iter_)) {
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  size_t arg_index_;
  size_t number_of_arguments_;
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  const size_t *const shmem_indices_end_;
};
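
// Example (sketch): walking the arguments held by a KernelArgsArrayBase
// (declared below). The "args" variable is illustrative only.
//
//   KernelArgIterator it = args.arg_iterator();
//   while (it.has_next()) {
//     KernelArg arg = it.next();
//     if (arg.is_shared) {
//       VLOG(1) << "shared memory argument: " << arg.size << " bytes";
//     } else {
//       VLOG(1) << "argument at " << arg.address << ": " << arg.size
//               << " bytes";
//     }
//   }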

// Base class for KernelArgsArray.
//
// Supports all the getter methods that do not depend on the compile-time number
// of arguments template parameter.
//
// This class exists as a way to pass kernel arguments to
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
// be templated to accept any KernelArgsArray type; therefore, a reference to
// this base type is passed instead.
//
// Performance is not a concern here because each of these methods will be
// called at most once per kernel launch. Past performance concerns with
// KernelArgsArray have been in reference to the argument packing routines,
// which are called once per kernel argument. Those packing routines are now
// handled by the templated KernelArgsArray subclass of this class, where they
// can take advantage of compile-time knowledge of the number of arguments in
// order to be very efficient.
class KernelArgsArrayBase {
 public:
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far.
  virtual uint64 number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses.
  virtual port::ArraySlice<const void *> argument_addresses() const = 0;

  // Gets an iterator to the arguments in the array.
  virtual KernelArgIterator arg_iterator() const = 0;
};

// A list of arguments for a kernel call.
//
// The template parameter kNumArgs is the maximum number of arguments which can
// be stored in the list.
//
// Contains a list of addresses for non-shared-memory arguments and a list of
// sizes for shared-memory arguments. Since the shared-memory arguments may be
// interspersed with the non-shared-memory arguments, it also stores a list of
// the indices at which the shared-memory arguments appeared.
//
// For example, if the argument address list contains {a, b, c, d, e}, the
// shared-memory arguments list contains the sizes of {A, B, C}, and the
// shared-memory indices list contains {0, 3, 5}, then the original list of
// arguments was {A, a, b, B, c, C, d, e}.
//
// This way of storing the arguments makes CUDA kernel calls efficient because
// they only require the argument address list and the total number of shared
// bytes, but it also supports OpenCL kernel calls, which depend on the location
// and size of each shared-memory argument.
//
// Note that the code for adding arguments has been identified as a performance
// hotspot in some real-world applications, so this structure has been optimized
// for the performance of argument adding.
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
 public:
  static constexpr int kMaxGenericArgSize = 8;

  // Adds an argument to the list.
  template <typename T>
  void add_argument(const T &arg) {
    static_assert(sizeof(T) <= kMaxGenericArgSize,
                  "Please adjust kMaxGenericArgSize");
    static_assert(std::is_pod<T>::value, "Only pod types supported!");
    char *generic_arg_storage =
        &generic_arguments_[number_of_generic_arguments_++ *
                            kMaxGenericArgSize];

    CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0);
    std::memcpy(generic_arg_storage, &arg, sizeof(T));

    argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
    argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
    ++number_of_argument_addresses_;
  }

  // Adds a device memory argument to the list.
  void add_device_memory_argument(const DeviceMemoryBase &arg) {
    const void **copy_ptr =
        &device_memory_opaque_pointers_[number_of_argument_addresses_];
    *copy_ptr = arg.opaque();
    argument_addresses_[number_of_argument_addresses_] = copy_ptr;
    argument_sizes_[number_of_argument_addresses_] = sizeof(void *);
    ++number_of_argument_addresses_;
  }

  // Adds a shared memory argument to the list.
  //
  // The only significant information about a shared argument is its size, so
  // that is the only parameter in this function.
  void add_shared_bytes(size_t number_of_bytes) {
    shared_memory_indices_[number_of_shared_memory_arguments_] =
        number_of_argument_addresses_ + number_of_shared_memory_arguments_;
    shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
    ++number_of_shared_memory_arguments_;
    total_shared_memory_bytes_ += number_of_bytes;
  }

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  size_t number_of_arguments() const override {
    return number_of_argument_addresses_ + number_of_shared_memory_arguments_;
  }

  // Gets the total number of shared memory bytes added so far.
  uint64 number_of_shared_bytes() const override {
    return total_shared_memory_bytes_;
  }

  // Gets the list of argument addresses.
  port::ArraySlice<const void *> argument_addresses() const override {
    return port::ArraySlice<const void *>(argument_addresses_.data(),
                                          number_of_argument_addresses_);
  }

  // Gets an iterator to the arguments in the array.
  KernelArgIterator arg_iterator() const override {
    return KernelArgIterator(
        number_of_argument_addresses_, number_of_shared_memory_arguments_,
        argument_addresses_.data(), argument_sizes_.data(),
        shared_memory_bytes_.data(), shared_memory_indices_.data());
  }

 private:
  // A place to store copies of opaque pointers from device memory arguments.
  std::array<const void *, kNumArgs> device_memory_opaque_pointers_;

  // Addresses for non-shared-memory arguments.
  std::array<const void *, kNumArgs> argument_addresses_;

  // Storage for arguments of templated type.
  alignas(kMaxGenericArgSize)
      std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_;

  // Sizes for non-shared-memory arguments.
  std::array<size_t, kNumArgs> argument_sizes_;

  // Size in bytes for each shared memory argument.
  std::array<size_t, kNumArgs> shared_memory_bytes_;

  // Indices in the arguments array for shared memory arguments.
  std::array<size_t, kNumArgs> shared_memory_indices_;

  // Total of all shared memory sizes.
  size_t total_shared_memory_bytes_ = 0;

  // Number of significant entries in argument_addresses_ and argument_sizes_.
  size_t number_of_argument_addresses_ = 0;

  // Number of significant entries in shared_memory_bytes_ and
  // shared_memory_indices_.
  size_t number_of_shared_memory_arguments_ = 0;

  // The number of generic arguments that have been added to generic_arguments_.
  size_t number_of_generic_arguments_ = 0;
};
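
// Example (sketch): manually packing arguments for a launch with two regular
// parameters and one dynamic shared memory request. Launches normally go
// through TypedKernel/Stream, which perform this packing automatically; the
// variable names below are illustrative only.
//
//   DeviceMemory<float> data = ...;  // device allocation obtained elsewhere
//   int count = 128;                 // POD argument passed by value
//   KernelArgsArray<3> args;
//   args.add_device_memory_argument(data);
//   args.add_argument(count);
//   args.add_shared_bytes(1024);
//   // args.number_of_arguments() == 3; args.number_of_shared_bytes() == 1024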

// Typed variant of KernelBase, like a typed device function pointer. See the
// file comment for details and example usage.
//
// This class contains template metaprogramming magic to check at compile time
// that the parameters passed to a kernel launch are acceptable, and
// subsequently pack them into a form which can be used by the
// StreamExecutorInterface implementation. (i.e., CUDA and OpenCL both bind
// void*s with associated sizes as kernel arguments.)
//
// Thread-compatible.
template <typename... Params>
class TypedKernel : public KernelBase {
 public:
  static constexpr size_t kNumberOfParameters = sizeof...(Params);

  // Delegates to KernelBase::KernelBase(); see that constructor.
  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}

  // Test-only constructor that can take a mock KernelInterface implementation.
  // Takes ownership of implementation; it should not be null.
  TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
      : KernelBase(parent, implementation) {}

 private:
  // Stream needs access to the specific parameter-packing functionality that
  // the TypedKernel provides for its corresponding type signature (and no other
  // type signatures).
  friend class Stream;

  // This is the main entry point into the magic. Packs the parameters (which
  // must type check against the class template) into the args and sizes
  // arrays.
  //
  // Const refs are taken as parameters on all of the handlers to avoid
  // implicit type promotion of integers.
  //
  // WARNING: as a performance optimization this method may store pointers to
  // some of the input parameters in the kernel args structure, so any params
  // passed into this method must live at least as long as the kernel args
  // structure.
  void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                  Params &... params) const {
    PackOneParamFromList(args, params...);
  }

  template <typename T, typename... RestOfParams>
  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args,
                            const T &arg, const RestOfParams &... rest) const {
    PackOneParam(args, arg);
    PackOneParamFromList(args, rest...);
  }

  // Base case for variadic template expansion - nothing to do!
  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args) const {}

  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
  // separate implementation below.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
                              !IsDeviceMemoryPointer<T>::value &&
                              !IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    static_assert(!std::is_pointer<T>::value,
                  "cannot pass raw pointer to the device");
    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
                  "cannot pass device memory as a normal value");
    args->add_argument(arg);
  }

  // DeviceMemoryBase family reference override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
          nullptr) const {
    args->add_device_memory_argument(arg);
  }

  // DeviceMemoryBase family pointer override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
          nullptr) const {
    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
    args->add_device_memory_argument(*ptr);
  }

  // Dynamic shared device memory has a size, but no associated allocation on
  // the host; internally, the device will allocate storage.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    args->add_shared_bytes(arg.size());
  }

  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};
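
// Example (sketch): declaring a typed kernel handle for a device function with
// signature void(const float *in, float *out, int n). The "executor" variable
// is illustrative only; in practice the alias is usually produced by generated
// code (see the file comment).
//
//   using VecScaleKernel = TypedKernel<const DeviceMemory<float> &,
//                                      DeviceMemory<float> *, int>;
//   VecScaleKernel kernel(executor);
//   static_assert(VecScaleKernel::kNumberOfParameters == 3, "");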

// Template metaprogramming helper type that helps us produce better error
// messages at compile time when there are mismatches between the parameter
// type list and the argument type list.
template <typename ParamTuple, typename ArgTuple>
struct KernelInvocationChecker {
  // Whether the parameter tuple and argument tuple match in length.
  static constexpr bool kLengthMatches =
      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;

  // The (matching) length of the parameters and arguments type lists.
  static constexpr int kTupleLength =
      static_cast<int>(std::tuple_size<ArgTuple>::value);

  // Helper trait to say whether the parameter wants a DeviceMemory-reference
  // compatible type. This is for inexact type matches, so that it doesn't have
  // to be precisely a const DeviceMemory<T>&, but can also be a value that
  // represents the same.
  template <typename ParamType, typename ArgType>
  struct IsCompatibleDeviceMemoryRef {
    static constexpr bool value = false;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
                                     SharedDeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // Returns whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing without any assert functionality.
  template <typename ParamT, typename ArgT>
  static constexpr bool CompatibleNoAssert() {
    return std::is_same<typename std::remove_const<ParamT>::type,
                        ArgT>::value ||
           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
  }

  // Checks whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing. kArgumentNumber is unused; it is just for error display.
  //
  // NOTE: if you encounter an error here, you can see the mismatch by looking
  // at the end of the last error message, which will be of the form:
  //
  //    ...::Compatible<const stream_executor::DeviceMemory<OneThing> &,
  //                    stream_executor::DeviceMemory<AnotherThing>, true,
  //                    0>'
  //    requested here
  //
  // This means that the 0th argument you passed to the kernel invocation should
  // have been DeviceMemory<OneThing> but was observed to be
  // DeviceMemory<AnotherThing>.
  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
            int kArgumentNumber>
  static constexpr bool Compatible() {
    static_assert(
        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
        "parameter type (LHS) is not compatible with argument type (RHS)");
    return CompatibleNoAssert<ParamT, ArgT>();
  }

  // Checks the parameter/argument match at kArgumentNumber for an out-of-bounds
  // argument number.
  //
  // This is the base case: we've run out of arguments to check, so we're all
  // good.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
    return true;
  }

  // Checks the parameter/argument match at kArgumentNumber.
  // kShouldStaticAssert determines whether to assert out on a mismatch, or just
  // yield the constexpr boolean value.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
        ParamT;
    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
  }

  // Checks the parameters/arguments for a match, but doesn't static assert out.
  // This is useful for inspecting whether a set of parameters matches, e.g., in
  // tests.
  static constexpr bool CheckAllNoStaticAssert() {
    return kLengthMatches && CheckParam<kTupleLength - 1, false>();
  }

  // Checks the parameters and static asserts out with a helpful error message
  // (and useful template parameters in the instantiation stack) if there is an
  // error.
  static constexpr bool CheckAllStaticAssert() {
    static_assert(kLengthMatches,
                  "argument length mismatched against typed kernel parameters");
    return kLengthMatches && CheckParam<kTupleLength - 1, true>();
  }
};
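
// Example (sketch): checking a parameter type list against an argument type
// list at compile time without triggering a static_assert on mismatch:
//
//   static_assert(
//       KernelInvocationChecker<
//           std::tuple<const DeviceMemory<float> &, int>,
//           std::tuple<DeviceMemory<float>, int>>::CheckAllNoStaticAssert(),
//       "types should be compatible");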

// This is a convenience type for checking whether a typed kernel matches
// against a type list.
template <typename KernelT, typename... Params>
struct KernelParamsOk {
  static constexpr bool kResult = false;
};

// See above.
template <typename... Params, typename... Args>
struct KernelParamsOk<TypedKernel<Params...>, Args...> {
  static constexpr bool kResult = KernelInvocationChecker<
      std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
};
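
// Example (sketch): verifying that an argument type list is acceptable for a
// typed kernel. The alias below is illustrative only.
//
//   using OffsetKernel = TypedKernel<const DeviceMemory<int> &, int>;
//   static_assert(
//       KernelParamsOk<OffsetKernel, DeviceMemory<int>, int>::kResult,
//       "arguments should match the kernel's parameter types");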

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_