1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <utility>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "third_party/eigen3/Eigen/Core"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/platform/tensor_float_32_utils.h"
28 #include "tensorflow/core/util/determinism.h"
29 #include "tensorflow/core/util/env_var.h"
30 #include "tensorflow/core/util/use_cudnn.h"
31 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
32 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
33 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
34 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
35 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
36 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
37 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
38 #include "tensorflow/stream_executor/cuda/cudnn_version.h"
39 #include "tensorflow/stream_executor/dnn.h"
40 #include "tensorflow/stream_executor/lib/env.h"
41 #include "tensorflow/stream_executor/lib/error.h"
42 #include "tensorflow/stream_executor/lib/initialize.h"
43 #include "tensorflow/stream_executor/lib/mathutil.h"
44 #include "tensorflow/stream_executor/lib/threadpool.h"
45 #include "tensorflow/stream_executor/platform/logging.h"
46 #include "tensorflow/stream_executor/plugin_registry.h"
47 #include "tensorflow/stream_executor/scratch_allocator.h"
48 #include "tensorflow/stream_executor/stream.h"
49 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
50 // clang-format off
51 #include "third_party/gpus/cudnn/cudnn.h"
52 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
53 #include "third_party/cudnn_frontend/include/cudnn_frontend.h"
54 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
55 #include "absl/strings/string_view.h"
56 // clang-format on
57 
58 #pragma clang diagnostic push
59 
60 // Make sure that Eigen::half forward declaration in dnn.h matches the
61 // declaration in Eigen.
62 #pragma clang diagnostic warning "-Wmismatched-tags"
63 
64 namespace stream_executor {
65 namespace gpu {
66 
67 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
68 
69 namespace {
70 
71 static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher");
72 
73 // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
74 #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
75 
76 // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
77 // function with a non-successful port::Status.
78 #define RETURN_IF_CUDNN_ERROR(expr)                                      \
79   do {                                                                   \
80     cudnnStatus_t _status = (expr);                                      \
81     if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {             \
82       std::ostringstream oss;                                            \
83       oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
84           << "): '" << #expr << "'";                                     \
85       return port::Status(port::error::UNKNOWN, oss.str());              \
86     }                                                                    \
87   } while (false)
88 
89 #define RETURN_MSG_IF_CUDNN_ERROR(expr)                                  \
90   do {                                                                   \
91     cudnnStatus_t _status = (expr).get_status();                         \
92     if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {             \
93       std::ostringstream oss;                                            \
94       oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
95           << "): '" << #expr << "' " << (expr).get_error();              \
96       return port::Status(port::error::UNKNOWN, oss.str());              \
97     }                                                                    \
98   } while (false)
99 
100 #define RETURN_FALSE_IF_CUDNN_ERROR(expr)                                \
101   do {                                                                   \
102     if (!SE_PREDICT_TRUE((expr).get_status() == CUDNN_STATUS_SUCCESS)) { \
103       return false;                                                      \
104     }                                                                    \
105   } while (false)
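// Illustrative sketch of how the macros above are typically used. The function
// and descriptor names here are hypothetical, not part of this file:
//
//   port::Status SetExampleTensorDescriptor(cudnnTensorDescriptor_t desc) {
//     // On failure, returns a port::Status carrying the cuDNN status string.
//     RETURN_IF_CUDNN_ERROR(cudnnSetTensor4dDescriptor(
//         desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
//         /*n=*/1, /*c=*/1, /*h=*/1, /*w=*/1));
//     return port::Status::OK();
//   }
//
// CHECK_CUDNN_OK, by contrast, aborts the process on failure and is reserved
// for calls whose failure indicates a programming error.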
106 
107 // Converts (via narrowing) a WideT value to a NarrowT value, and checks
108 // that the value is unchanged by the conversion.
109 template <typename WideT, typename NarrowT>
110 NarrowT CheckedNarrowing(const WideT& wide) {
111   NarrowT narrow = wide;
112   CHECK_EQ(narrow, wide)
113       << "checked narrowing failed; values not equal post-conversion";
114   return narrow;
115 }
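// For example (sketch): CheckedNarrowing<int64, int>(int64{42}) returns 42,
// whereas a value that does not fit in an int would trip the CHECK_EQ above.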
116 
117 std::string ToString(cudnnStatus_t status) {
118   switch (status) {
119     case CUDNN_STATUS_SUCCESS:
120       return "CUDNN_STATUS_SUCCESS";
121     case CUDNN_STATUS_NOT_INITIALIZED:
122       return "CUDNN_STATUS_NOT_INITIALIZED";
123     case CUDNN_STATUS_ALLOC_FAILED:
124       return "CUDNN_STATUS_ALLOC_FAILED";
125     case CUDNN_STATUS_BAD_PARAM:
126       return "CUDNN_STATUS_BAD_PARAM";
127     case CUDNN_STATUS_INTERNAL_ERROR:
128       return "CUDNN_STATUS_INTERNAL_ERROR";
129     case CUDNN_STATUS_INVALID_VALUE:
130       return "CUDNN_STATUS_INVALID_VALUE";
131     case CUDNN_STATUS_ARCH_MISMATCH:
132       return "CUDNN_STATUS_ARCH_MISMATCH";
133     case CUDNN_STATUS_MAPPING_ERROR:
134       return "CUDNN_STATUS_MAPPING_ERROR";
135     case CUDNN_STATUS_EXECUTION_FAILED:
136       return "CUDNN_STATUS_EXECUTION_FAILED";
137     case CUDNN_STATUS_NOT_SUPPORTED:
138       return "CUDNN_STATUS_NOT_SUPPORTED";
139     case CUDNN_STATUS_LICENSE_ERROR:
140       return "CUDNN_STATUS_LICENSE_ERROR";
141     case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
142       return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
143     case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
144       return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
145     case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
146       return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
147     default:
148       return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
149                           ">");
150   }
151 }
152 
153 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
154 //
155 // See CudnnAccess::GetHandle() for details.
156 class CudnnHandle {
157  public:
158   // Takes ownership of the executor context and of the lock used to access
159   // cuDNN via the handle.
160   CudnnHandle(gpu::ScopedActivateExecutorContext context,
161               std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
162       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
163 
164   // Returns the cuDNN handle. Pass it directly to cuDNN APIs; do not keep
165   // a copy.
166   cudnnHandle_t handle() const { return handle_; }
167 
168  private:
169   gpu::ScopedActivateExecutorContext context_;
170   std::unique_ptr<absl::MutexLock> lock_;
171   cudnnHandle_t handle_;  // Not owned.
172 };
173 
174 }  // namespace
175 
176 // Wraps a cuDNN handle and provides access to it through CudnnHandle
177 // instances, which also locks a mutex, acquires the CUDA context, and sets
178 // the stream that cuDNN should use to enqueue any work.
179 //
180 // Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
181 class CudnnAccess {
182  public:
183   // Takes ownership of the handle.
184   explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
185 
186   ~CudnnAccess() {
187     absl::MutexLock lock(&mutex_);
188     cudnnDestroy(handle_);
189   }
190 
191   // Creates a CudnnHandle instance for stream.
192   //
193   // cuDNN API calls using the same handle instance need to be serialized
194   // across threads. This is guaranteed by CudnnHandle instances locking the
195   // mutex owned by this class.
196   //
197   // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
198   // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
199   // to use the provided stream.
200   //
201   // The stream argument may be null, which translates to the legacy default
202   // stream. See
203   // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
204   // The legacy default stream synchronizes with all other streams, so it is
205   // a bad idea (performance-wise) to call any cuDNN APIs that enqueue work
206   // in that stream.
207   CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
208     auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
209     mutex_.AssertHeld();
210     gpu::ScopedActivateExecutorContext context(executor);
211     CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
212     const auto status = cudnnSetStream(handle_, cu_stream);
213     CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
214     return CudnnHandle(std::move(context), std::move(lock), handle_);
215   }
216 
217  private:
218   // Guards the enqueueing of cuDNN operations via the handle_ below.
219   absl::Mutex mutex_;
220 
221   // cuDNN library handle.
222   cudnnHandle_t handle_ TF_GUARDED_BY(mutex_);  // Owned.
223 };
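// A minimal usage sketch, assuming a GpuExecutor* 'executor' and a Stream*
// 'stream' are in scope (hypothetical call site, not from this file):
//
//   auto cudnn = cudnn_->GetHandle(executor, stream);
//   size_t state_size = 0;
//   RETURN_IF_CUDNN_ERROR(
//       cudnnDropoutGetStatesSize(cudnn.handle(), &state_size));
//   // The mutex is released and the CUDA context restored when 'cudnn' goes
//   // out of scope.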
224 
225 namespace {
226 
227 // A helper function to return the internal compute type for
228 // RNNs in cudnn.
229 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
230 
231 cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
232   cudnnConvolutionFwdAlgo_t algo =
233       cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
234   switch (algo) {
235     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
236     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
237     case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
238     case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
239     case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
240     case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
241     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
242     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
243       return algo;
244     default:
245       LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
246                  << algorithm.algo_id();
247   }
248 }
249 
250 cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
251     dnn::AlgorithmDesc algorithm) {
252   cudnnConvolutionBwdDataAlgo_t algo =
253       cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
254   switch (algo) {
255     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
256     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
257     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
258     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
259     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
260     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
261       return algo;
262     default:
263       LOG(FATAL)
264           << "Unsupported Cudnn convolution backward algorithm for data: "
265           << algorithm.algo_id();
266   }
267 }
268 
269 cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
270     dnn::AlgorithmDesc algorithm) {
271   cudnnConvolutionBwdFilterAlgo_t algo =
272       cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
273   switch (algo) {
274     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
275     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
276     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
277     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
278     // Based on cudnn.h, the following is not implemented.
279     // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
280     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
281     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
282       return algo;
283     default:
284       LOG(FATAL)
285           << "Unsupported Cudnn convolution backward algorithm for filter: "
286           << algorithm.algo_id();
287   }
288 }
289 
290 port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
291   int value;
292   RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
293   return value;
294 }
295 
296 cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional<dnn::AlgorithmDesc> algorithm) {
297   if (!algorithm.has_value()) {
298     return CUDNN_RNN_ALGO_STANDARD;
299   }
300   cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
301   switch (algo) {
302     case CUDNN_RNN_ALGO_STANDARD:
303     case CUDNN_RNN_ALGO_PERSIST_STATIC:
304     case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
305       return algo;
306     default:
307       LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
308   }
309 }
310 
311 port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
312   SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
313   SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
314   SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
315   return port::Status::OK();
316 }
317 
318 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
319 void PreloadCudnnLibrary(cudnnStatus_t (*version_check_fn)(),
320                          absl::string_view sub_library) {
321   cudnnStatus_t status = version_check_fn();
322   if (status != CUDNN_STATUS_SUCCESS) {
323     VLOG(1) << "Could not pre-initialize cuDNN sub-library " << sub_library
324             << ".  Error: " << cudnnGetErrorString(status) << ".";
325   }
326 }
327 #endif
328 
329 }  // namespace
330 
331 CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}
332 
333 port::Status CudnnSupport::Init() {
334   ScopedActivateExecutorContext context(parent_);
335 
336   // Peek at the last error to give more information in cases of errors.
337   cudaError_t cerr = cudaPeekAtLastError();
338   if (cerr != cudaSuccess) {
339     LOG(WARNING) << "There was an error before creating cudnn handle: "
340                  << cudaGetErrorName(cerr) << " : " << cudaGetErrorString(cerr);
341   }
342 
343   cudnnHandle_t cudnn_handle = nullptr;
344   const auto status = cudnnCreate(&cudnn_handle);
345   if (status == CUDNN_STATUS_SUCCESS) {
346     CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
347 
348     CudnnVersion loaded_version;
349     TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
350     if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
351       const std::string error = absl::StrCat(
352           "Loaded runtime CuDNN library: ", loaded_version.ToString(),
353           " but source was compiled with: ", source_version.ToString(),
354           ".  CuDNN library needs to have matching major version and equal or "
355           "higher minor version. If using a binary install, upgrade your CuDNN "
356           "library.  If building from sources, make sure the library loaded at "
357           "runtime is compatible with the version specified during compile "
358           "configuration.");
359       LOG(ERROR) << error;
360       cudnnDestroy(cudnn_handle);
361       return port::Status(port::error::INTERNAL, error);
362     }
363 
364     cudnn_.reset(new CudnnAccess(cudnn_handle));
365 
366     LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion();
367     return port::Status::OK();
368   }
369 
370   CHECK_EQ(cudnn_handle, nullptr);
371   LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
372   if (status == CUDNN_STATUS_NOT_INITIALIZED) {
373     auto result = gpu::Diagnostician::FindKernelDriverVersion();
374     if (!result.ok()) {
375       LOG(ERROR) << "Error retrieving driver version: "
376                  << cuda::DriverVersionStatusToString(result);
377     } else {
378       const auto& version = result.ValueOrDie();
379       LOG(ERROR) << "Possibly insufficient driver version: "
380                  << cuda::DriverVersionToString(version);
381     }
382   }
383 
384   return port::Status(port::error::INTERNAL,
385                       absl::StrCat("cudnn library could not create a handle: ",
386                                    ToString(status)));
387 }
388 
389 port::StatusOr<perftools::gputools::dnn::VersionInfo>
390 CudnnSupport::GetVersion() {
391   CudnnVersion version;
392   TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
393   return perftools::gputools::dnn::VersionInfo(
394       version.major_version, version.minor_version, version.patch_level);
395 }
396 
397 namespace {
398 
399 // Deleter functors for cuDNN types that need to be deleted.
400 struct TensorDescriptorDeleter {
401   void operator()(cudnnTensorDescriptor_t descriptor) const {
402     CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
403   }
404 };
405 struct RNNDataDescriptorDeleter {
406   void operator()(cudnnRNNDataDescriptor_t descriptor) const {
407     CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
408   }
409 };
410 struct FilterDescriptorDeleter {
411   void operator()(cudnnFilterDescriptor_t descriptor) const {
412     CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
413   }
414 };
415 struct ConvolutionDescriptorDeleter {
416   void operator()(cudnnConvolutionDescriptor_t descriptor) const {
417     CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
418   }
419 };
420 struct PoolingDescriptorDeleter {
421   void operator()(cudnnPoolingDescriptor_t descriptor) const {
422     CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
423   }
424 };
425 struct LrnDescriptorDeleter {
426   void operator()(cudnnLRNDescriptor_t descriptor) const {
427     CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
428   }
429 };
430 
431 struct ActivationDescriptorDeleter {
432   void operator()(cudnnActivationDescriptor_t descriptor) const {
433     CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
434   }
435 };
436 struct DropoutDescriptorDeleter {
437   void operator()(cudnnDropoutDescriptor_t descriptor) const {
438     CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
439   }
440 };
441 struct RnnDescriptorDeleter {
442   void operator()(cudnnRNNDescriptor_t descriptor) const {
443     CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
444   }
445 };
446 struct PersistentRnnPlanDeleter {
447   void operator()(cudnnPersistentRNNPlan_t plan) const {
448     CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
449   }
450 };
451 #if CUDNN_VERSION >= 7603
452 struct CtcLossDescriptorDeleter {
453   void operator()(cudnnCTCLossDescriptor_t descriptor) const {
454     CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
455   }
456 };
457 #endif
458 
459 // RAII wrappers for cuDNN types.
460 using TensorDescriptor =
461     std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
462 using RNNDataDescriptor =
463     std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
464 using FilterDescriptor =
465     std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
466 using ConvolutionDescriptor =
467     std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
468 using PoolingDescriptor =
469     std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
470 using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
471 using ActivationDescriptor =
472     std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
473 using DropoutDescriptor =
474     std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
475 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
476 using PersistentRnnPlan =
477     std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
478 #if CUDNN_VERSION >= 7603
479 using CtcLossDescriptor =
480     std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
481 #endif
482 
483 // Factory methods for cuDNN types.
484 TensorDescriptor CreateTensorDescriptor() {
485   cudnnTensorDescriptor_t result;
486   CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
487   return TensorDescriptor(result);
488 }
489 RNNDataDescriptor CreateRNNDataDescriptor() {
490   cudnnRNNDataDescriptor_t result;
491   CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
492   return RNNDataDescriptor(result);
493 }
494 FilterDescriptor CreateFilterDescriptor() {
495   cudnnFilterDescriptor_t result;
496   CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
497   return FilterDescriptor(result);
498 }
499 ConvolutionDescriptor CreateConvolutionDescriptor() {
500   cudnnConvolutionDescriptor_t result;
501   CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
502   return ConvolutionDescriptor(result);
503 }
504 PoolingDescriptor CreatePoolingDescriptor() {
505   cudnnPoolingDescriptor_t result;
506   CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
507   return PoolingDescriptor(result);
508 }
509 LrnDescriptor CreateLrnDescriptor() {
510   cudnnLRNDescriptor_t result;
511   CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
512   return LrnDescriptor(result);
513 }
514 ActivationDescriptor CreateActivationDescriptor() {
515   cudnnActivationDescriptor_t result;
516   CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
517   return ActivationDescriptor(result);
518 }
519 DropoutDescriptor CreateDropoutDescriptor() {
520   cudnnDropoutDescriptor_t result;
521   CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
522   return DropoutDescriptor(result);
523 }
524 RnnDescriptor CreateRnnDescriptor() {
525   cudnnRNNDescriptor_t result;
526   CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
527   return RnnDescriptor(result);
528 }
529 #if CUDNN_VERSION >= 7603
530 CtcLossDescriptor CreateCtcLossDescriptor() {
531   cudnnCTCLossDescriptor_t result;
532   CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
533   return CtcLossDescriptor(result);
534 }
535 #endif
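// Illustrative sketch of the RAII pattern established above (the local
// variable is hypothetical, not from this file):
//
//   TensorDescriptor desc = CreateTensorDescriptor();
//   // ... configure via cudnnSetTensorNdDescriptor(desc.get(), ...) ...
//   // cudnnDestroyTensorDescriptor runs automatically through
//   // TensorDescriptorDeleter when 'desc' goes out of scope.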
536 
537 port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
538     cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
539   cudnnPersistentRNNPlan_t result;
540   RETURN_IF_CUDNN_ERROR(
541       cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
542   return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
543 }
544 
545 // Turns a BatchDescriptor structure into a cudnn tensor handle within a
546 // scope.
547 class CudnnTensorDescriptor {
548  public:
549   CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
550                         cudnnDataType_t elem_type)
551       : handle_(CreateTensorDescriptor()) {
552     switch (batch_descriptor.layout()) {
553       case dnn::DataLayout::kBatchYXDepth:
554       case dnn::DataLayout::kBatchDepthYX: {
555         const int nd = batch_descriptor.ndims() + 2;
556         // cuDNN requires the strides and dims to be ordered as BDYX.
557         std::vector<int64> strides64 =
558             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
559         std::vector<int64> dims64 =
560             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
561 
562         // cuDNN requires arrays of ints.
563         std::vector<int> strides(nd);
564         std::vector<int> dims(nd);
565         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
566                        &CheckedNarrowing<int64, int>);
567         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
568                        &CheckedNarrowing<int64, int>);
569         CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
570                                                   dims.data(), strides.data()))
571             << "batch_descriptor: " << batch_descriptor.ToString();
572         break;
573       }
574       case dnn::DataLayout::kBatchDepthYX4:
575       case dnn::DataLayout::kBatchDepthYX32: {
576         auto expected_elem_ty =
577             batch_descriptor.layout() == dnn::DataLayout::kBatchDepthYX4
578                 ? CUDNN_DATA_INT8x4
579                 : CUDNN_DATA_INT8x32;
580         CHECK_EQ(elem_type, expected_elem_ty);
581         CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
582             handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
583             batch_descriptor.count(), batch_descriptor.feature_map_count(),
584             batch_descriptor.height(), batch_descriptor.width()))
585             << "batch_descriptor: " << batch_descriptor.ToString();
586         break;
587       }
588       default:
589         LOG(FATAL) << "Unsupported tensor format "
590                    << DataLayoutString(batch_descriptor.layout());
591         break;
592     }
593   }
594 
595   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
596 
597  private:
598   TensorDescriptor handle_;
599 
600   SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
601 };
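// Hedged usage sketch (a hypothetical float NCHW batch, not from this file):
//
//   dnn::BatchDescriptor batch;
//   batch.set_count(32);
//   batch.set_feature_map_count(64);
//   batch.set_height(7);
//   batch.set_width(7);
//   batch.set_layout(dnn::DataLayout::kBatchDepthYX);
//   CudnnTensorDescriptor tensor(batch, CUDNN_DATA_FLOAT);
//   // tensor.handle() can now be passed wherever cuDNN expects a
//   // cudnnTensorDescriptor_t.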
602 
603 // Turns a FilterDescriptor structure into a cudnn filter handle within a
604 // scope.
605 class CudnnFilterDescriptor {
606  public:
607   CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
608                         cudnnDataType_t elem_type)
609       : handle_(CreateFilterDescriptor()) {
610     // TODO(b/23032134): Even if the filter layout is not supported,
611     // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
612     // it does not take layout as an input. Maybe force cuDNN by giving wrong
613     // inputs intentionally?
614     cudnnTensorFormat_t format;
615     switch (filter_descriptor.layout()) {
616       case dnn::FilterLayout::kOutputInputYX:
617         format = CUDNN_TENSOR_NCHW;
618         break;
619       case dnn::FilterLayout::kOutputYXInput:
620         format = CUDNN_TENSOR_NHWC;
621         break;
622       case dnn::FilterLayout::kOutputInputYX4:
623       case dnn::FilterLayout::kOutputInputYX32: {
624         auto expected_elem_ty =
625             filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX4
626                 ? CUDNN_DATA_INT8x4
627                 : CUDNN_DATA_INT8x32;
628         CHECK_EQ(elem_type, expected_elem_ty);
629         format = CUDNN_TENSOR_NCHW_VECT_C;
630         break;
631       }
632       default:
633         LOG(FATAL) << "Unsupported filter format "
634                    << FilterLayoutString(filter_descriptor.layout());
635         break;
636     }
637 
638     std::vector<int> dims(2 + filter_descriptor.ndims());
639     dims[0] = filter_descriptor.output_feature_map_count();
640     dims[1] = filter_descriptor.input_feature_map_count();
641     absl::Span<const int64> spatial_dims =
642         filter_descriptor.input_filter_dims();
643     std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
644 
645     CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
646                                               dims.size(), dims.data()));
647   }
648 
649   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
650 
651  private:
652   FilterDescriptor handle_;  // Owned.
653 
654   SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
655 };
656 
657 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
658 // The errata sheet (JSON format) for marking the cudnn engines that might be
659 // buggy. For example, we don't want the engine 999 of forward convolution:
660 // R"({ "version" : 1,
661 //      "rules"   : [
662 //        { "rule_id"             : "ConvFwd_eng999",
663 //          "operation"           : "ConvFwd",
664 //          "engine"              : 999,
665 //          "knob"                : [],
666 //          "cudnn_version_start" : 8000,
667 //          "cudnn_version_end"   : -1
668 //        }
669 // ]})"
670 // We skip the eng0 engines in the static filter; they are too slow. Additionally,
671 // users can specify an additional errata JSON file via
672 // CUDNN_ERRATA_JSON_FILE at runtime.
673 const json* CudnnExecutionPlanEngineFilterStatic() {
674   static absl::string_view filter_str = R"({
675       "version" : 1,
676         "rules"   : [
677           { "rule_id"             : "ConvFwd_eng0",
678             "operation"           : "ConvFwd",
679             "engine"              : 0,
680             "knob"                : [],
681             "cudnn_version_start" : 8000,
682             "cudnn_version_end"   : -1
683           },
684           { "rule_id"             : "ConvBwdData_eng0",
685             "operation"           : "ConvBwdData",
686             "engine"              : 0,
687             "knob"                : [],
688             "cudnn_version_start" : 8000,
689             "cudnn_version_end"   : -1
690           },
691           { "rule_id"             : "ConvBwdFilter_eng0",
692             "operation"           : "ConvBwdFilter",
693             "engine"              : 0,
694             "knob"                : [],
695             "cudnn_version_start" : 8000,
696             "cudnn_version_end"   : -1
697           }
698       ]})";
699   static const json* json_handle = new json(json::parse(filter_str));
700   return json_handle;
701 }
702 
703 const json* CudnnExecutionPlanEngineFilterRuntime() {
704   static const json* json_handle = []() -> const json* {
705     json j;
706     if (cudnn_frontend::load_from_config(j, "")) {
707       return new json(j);
708     }
709     return nullptr;
710   }();
711   return json_handle;
712 }
713 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
714 
715 // A helper function to decide whether to use
716 // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
717 // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
718 // and CUDNN_DATA_HALF data types on compute capability 6.0 or higher. The
719 // reason we set it to false by default is that this mode may use scaled
720 // atomic integer reduction, which may cause numerical overflow for certain
721 // input data ranges.
722 // TODO(yangzihao): Use autotune to choose between this mode and
723 // CUDNN_BATCHNORM_SPATIAL mode.
724 bool BatchnormSpatialPersistentEnabled() {
725   static bool is_enabled = [] {
726     bool is_enabled = false;
727     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
728         "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
729         /*default_val=*/false, &is_enabled));
730     return is_enabled;
731   }();
732   return is_enabled;
733 }
734 
735 bool RequireCudnnDeterminism() {
736   static bool require_cudnn_determinism = [] {
737     // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var.
738     bool cudnn_deterministic = false;
739     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
740                                                /*default_val=*/false,
741                                                &cudnn_deterministic));
742     return cudnn_deterministic;
743   }();
744   return tensorflow::OpDeterminismRequired() || require_cudnn_determinism;
745 }
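// For example (sketch): running with TF_CUDNN_DETERMINISTIC=1 in the
// environment, or after op determinism has been requested through TensorFlow,
// makes RequireCudnnDeterminism() return true, which in turn selects
// deterministic choices such as CUDNN_POOLING_MAX_DETERMINISTIC below.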
746 
747 // A helper function to decide whether to force the default conv algorithm.
748 bool ConvUseDefaultAlgorithm() {
749   static bool use_default = [] {
750     bool use_default = false;
751     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_USE_DEFAULT_CONV_ALGO",
752                                                /*default_val=*/false,
753                                                &use_default));
754     return use_default;
755   }();
756   return use_default;
757 }
758 
759 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
760 // within a scope.
761 class CudnnConvolutionDescriptor {
762  public:
763   CudnnConvolutionDescriptor(
764       const dnn::ConvolutionDescriptor& convolution_descriptor,
765       cudnnDataType_t data_type)
766       : handle_(CreateConvolutionDescriptor()) {
767     absl::Span<const int64> strides64 = convolution_descriptor.strides();
768     absl::Span<const int64> padding64 = convolution_descriptor.padding();
769     absl::Span<const int64> dilations64 = convolution_descriptor.dilations();
770     CHECK_NE(convolution_descriptor.pad_alignment(),
771              dnn::PadAlignment::kTensorFlowPadding)
772         << "TensorFlow padding alignment is not supported.";
773 
774     // cuDNN requires arrays of ints.
775     std::vector<int> strides(convolution_descriptor.ndims());
776     std::vector<int> padding(convolution_descriptor.ndims());
777     std::vector<int> dilations(convolution_descriptor.ndims());
778     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
779                    &CheckedNarrowing<int64, int>);
780     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
781                    &CheckedNarrowing<int64, int>);
782     // TODO(yangzihao): Test with negative dilation to make sure that cudnn
783     // doesn't crash.
784     std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
785                    &CheckedNarrowing<int64, int>);
786 
787     CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
788         handle_.get(), convolution_descriptor.ndims(), padding.data(),
789         strides.data(), dilations.data(),
790         convolution_descriptor.convolution_not_crosscorr()
791             ? CUDNN_CONVOLUTION
792             : CUDNN_CROSS_CORRELATION,
793         data_type));
794 
795 #if CUDNN_MAJOR >= 7
796     VLOG(2) << "Requesting grouped convolution: "
797             << convolution_descriptor.group_count();
798     CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
799         handle_.get(), convolution_descriptor.group_count()));
800 #else
801     CHECK_EQ(convolution_descriptor.group_count(), 1)
802         << "Requested grouped convolution for cuDNN version < 7";
803 #endif
804   }
805 
806   void set_use_tensor_op_math(bool use_tensor_op_math) {
807     cudnnMathType_t math_type =
808 #if CUDNN_VERSION >= 8000
809         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH);
810 #else
811         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
812 #endif
813     CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
814   }
815 
816   cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
817 
818  private:
819   ConvolutionDescriptor handle_;  // Owned.
820 
821   SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
822 };
823 
824 // A helper function to query whether tensor op math is set on a
825 // CudnnConvolutionDescriptor.
826 static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
827   cudnnMathType_t math_type;
828   CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
829 #if CUDNN_VERSION >= 8000
830   return math_type != CUDNN_FMA_MATH;
831 #else
832   return math_type == CUDNN_TENSOR_OP_MATH;
833 #endif
834 }
835 
836 static bool TensorOpMathAvailable(
837     CudaComputeCapability cuda_compute_capability) {
838   return cuda_compute_capability.IsAtLeast(7);
839 }
840 
841 static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) {
842   if (!TensorOpMathAvailable(stream->GetCudaComputeCapability())) {
843     return false;
844   }
845   if (input_type == dnn::DataType::kFloat) {
846 #if CUDNN_VERSION < 8000
847     return false;
848 #else
849     if (!tensorflow::tensor_float_32_execution_enabled()) {
850       return false;
851     }
852 #endif
853   }
854   return true;
855 }
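// For example (sketch): with cuDNN 8 or newer on compute capability 8.0,
// IsTensorMathEnabled() for dnn::DataType::kFloat returns true only while
// TensorFloat-32 execution is enabled
// (tensorflow::tensor_float_32_execution_enabled()); kHalf inputs are
// unaffected by that setting.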
856 
857 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
858 // within a scope.
859 class CudnnPoolingDescriptor {
860  public:
861   explicit CudnnPoolingDescriptor(
862       const dnn::PoolingDescriptor& pooling_descriptor)
863       : handle_(CreatePoolingDescriptor()) {
864     absl::Span<const int64> strides64 = pooling_descriptor.strides();
865     absl::Span<const int64> padding64 = pooling_descriptor.padding();
866     absl::Span<const int64> shape64 = pooling_descriptor.window();
867 
868     const int nd = pooling_descriptor.ndims();
869     std::vector<int> shape(nd);
870     std::vector<int> padding(nd);
871     std::vector<int> strides(nd);
872     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
873                    &CheckedNarrowing<int64, int>);
874     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
875                    &CheckedNarrowing<int64, int>);
876     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
877                    &CheckedNarrowing<int64, int>);
878     bool propagate_nans = pooling_descriptor.propagate_nans();
879     const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
880                                             ? CUDNN_POOLING_MAX_DETERMINISTIC
881                                             : CUDNN_POOLING_MAX;
882     CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
883         handle_.get(),
884         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
885              ? cudnn_max_pooling_mode
886              : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
887         propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
888         shape.data(), padding.data(), strides.data()));
889   }
890 
891   cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
892 
893  private:
894   PoolingDescriptor handle_;  // Owned.
895 
896   SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
897 };
898 
899 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
900 class CudnnNormalizeDescriptor {
901  public:
902   explicit CudnnNormalizeDescriptor(
903       const dnn::NormalizeDescriptor& normalize_descriptor)
904       : handle_(CreateLrnDescriptor()) {
905     // The range specifies that the indices in the closed range
906     // [i - range, i + range] should be included in the normalization for index
907     // i. The lrnN value is the total number of elements in the range, so
908     // lrnN = 2*range + 1.
909     unsigned lrnN = 2 * normalize_descriptor.range() + 1;
910 
911     // Note that SE defines the normalization operation as
912     //
913     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
914     //
915     // but cuDNN defines it as
916     //
917     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
918     //
919     // i.e. there is a factor of n difference between the meaning of the alphas
920     // in the two contexts. The cuDNN alpha is n times the SE alpha.
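    //
    // For example (illustrative numbers): with range = 2, lrnN = 5, so an SE
    // alpha of 1e-4 corresponds to a cuDNN lrnAlpha of 5e-4.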
921     double lrnAlpha = lrnN * normalize_descriptor.alpha();
922 
923     double lrnBeta = normalize_descriptor.beta();
924     double lrnK = normalize_descriptor.bias();
925     CHECK_CUDNN_OK(
926         cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
927   }
928 
929   cudnnLRNDescriptor_t handle() const { return handle_.get(); }
930 
931  private:
932   LrnDescriptor handle_;  // Owned.
933 
934   SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
935 };
936 
937 // Turns an ActivationDescriptor structure into a cudnn activation
938 // descriptor handle within a scope.
939 class CudnnActivationDescriptor {
940  public:
941   CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
942                             cudnnNanPropagation_t nan_propagation,
943                             double value_max)
944       : handle_(CreateActivationDescriptor()) {
945     double relu_ceiling = 0.0;
946     cudnnActivationMode_t mode;
947     switch (activation_mode) {
948       case dnn::ActivationMode::kNone:
949         mode = CUDNN_ACTIVATION_IDENTITY;
950         break;
951       case dnn::ActivationMode::kRelu6:
952         relu_ceiling = 6.0;
953         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
954         break;
955       case dnn::ActivationMode::kReluX:
956         relu_ceiling = value_max;
957         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
958         break;
959       case dnn::ActivationMode::kRelu:
960         mode = CUDNN_ACTIVATION_RELU;
961         break;
962       case dnn::ActivationMode::kSigmoid:
963         mode = CUDNN_ACTIVATION_SIGMOID;
964         break;
965       case dnn::ActivationMode::kTanh:
966         mode = CUDNN_ACTIVATION_TANH;
967         break;
968       default:
969         LOG(FATAL) << "unrecognized activation mode: "
970                    << static_cast<int>(activation_mode);
971     }
972 
973     CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
974                                                 nan_propagation, relu_ceiling));
975   }
976 
977   cudnnActivationDescriptor_t handle() const { return handle_.get(); }
978 
979  private:
980   ActivationDescriptor handle_;  // Owned.
981 
982   SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
983 };
984 
985 cudnnDataType_t ToCudnnDataType(
986     dnn::DataType data_type,
987     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
988   switch (data_type) {
989     case dnn::DataType::kFloat:
990       return CUDNN_DATA_FLOAT;
991     case dnn::DataType::kDouble:
992       return CUDNN_DATA_DOUBLE;
993     case dnn::DataType::kHalf:
994       return CUDNN_DATA_HALF;
995     case dnn::DataType::kInt8:
996       switch (data_layout) {
997         case dnn::DataLayout::kBatchDepthYX4:
998           return CUDNN_DATA_INT8x4;
999         case dnn::DataLayout::kBatchDepthYX32:
1000           return CUDNN_DATA_INT8x32;
1001         default:
1002           return CUDNN_DATA_INT8;
1003       }
1004     case dnn::DataType::kInt32:
1005       return CUDNN_DATA_INT32;
1006 #if CUDNN_VERSION >= 8200
1007     case dnn::DataType::kBF16:
1008       return CUDNN_DATA_BFLOAT16;
1009 #endif
1010     default:
1011       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1012   }
1013 }
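// For example (sketch): ToCudnnDataType(dnn::DataType::kInt8,
// dnn::DataLayout::kBatchDepthYX4) yields CUDNN_DATA_INT8x4, while the same
// type with the default kBatchDepthYX layout yields CUDNN_DATA_INT8.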
1014 
1015 cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
1016                                 dnn::FilterLayout filter_layout) {
1017   if (data_type == dnn::DataType::kInt8 &&
1018       filter_layout == dnn::FilterLayout::kOutputInputYX4) {
1019     return CUDNN_DATA_INT8x4;
1020   }
1021   if (data_type == dnn::DataType::kInt8 &&
1022       filter_layout == dnn::FilterLayout::kOutputInputYX32) {
1023     return CUDNN_DATA_INT8x32;
1024   }
1025   return ToCudnnDataType(data_type);
1026 }
1027 
1028 template <typename T>
1029 cudnnDataType_t GetCudnnDataType(
1030     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1031   return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
1032 }
1033 
1034 template <typename T>
1035 cudnnDataType_t GetCudnnDataType(dnn::FilterLayout filter_layout) {
1036   return ToCudnnDataType(dnn::ToDataType<T>::value, filter_layout);
1037 }
1038 
1039 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
1040   switch (input_mode) {
1041     case dnn::RnnInputMode::kRnnLinearSkip:
1042     case dnn::RnnInputMode::kRnnSkipInput:
1043       return static_cast<cudnnRNNInputMode_t>(input_mode);
1044     default:
1045       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
1046   }
1047 }
1048 
1049 cudnnDirectionMode_t ToCudnnRnnDirectionMode(
1050     dnn::RnnDirectionMode direction_mode) {
1051   switch (direction_mode) {
1052     case dnn::RnnDirectionMode::kRnnUnidirectional:
1053     case dnn::RnnDirectionMode::kRnnBidirectional:
1054       return static_cast<cudnnDirectionMode_t>(direction_mode);
1055     default:
1056       LOG(FATAL) << "Invalid RNN direction mode: "
1057                  << static_cast<int>(direction_mode);
1058   }
1059 }
1060 
1061 cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
1062   switch (rnn_mode) {
1063     case dnn::RnnMode::kRnnRelu:
1064     case dnn::RnnMode::kRnnTanh:
1065     case dnn::RnnMode::kRnnLstm:
1066     case dnn::RnnMode::kRnnGru:
1067       return static_cast<cudnnRNNMode_t>(rnn_mode);
1068     default:
1069       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1070   }
1071 }
1072 
1073 int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
1074   switch (data_type) {
1075     case CUDNN_DATA_FLOAT:
1076       return sizeof(float);
1077     case CUDNN_DATA_DOUBLE:
1078       return sizeof(double);
1079     case CUDNN_DATA_HALF:
1080       return sizeof(Eigen::half);
1081     default:
1082       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1083   }
1084 }
1085 
1086 class CudnnDropoutDescriptor {
1087   explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
1088       : handle_(std::move(handle)) {}
1089 
1090  public:
1091   CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
1092 
1093   static port::StatusOr<CudnnDropoutDescriptor> Create(
1094       const CudnnHandle& cudnn, float dropout, uint64 seed,
1095       ScratchAllocator* state_allocator) {
1096     DropoutDescriptor handle = CreateDropoutDescriptor();
1097 
1098     if (dropout == 0.0f) {
1099       // Return 'empty' dropout descriptor.
1100       return CudnnDropoutDescriptor(std::move(handle));
1101     }
1102 
1103     DeviceMemory<uint8> state_memory;
1104     if (state_allocator) {
1105       size_t state_sizes_in_bytes = 0;
1106       RETURN_IF_CUDNN_ERROR(
1107           cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
1108       SE_ASSIGN_OR_RETURN(state_memory,
1109                           state_allocator->AllocateBytes(state_sizes_in_bytes));
1110     }
1111     RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
1112         handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
1113         state_memory.size(), seed));
1114 
1115     return CudnnDropoutDescriptor(std::move(handle));
1116   }
1117 
1118   cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
1119 
1120  private:
1121   DropoutDescriptor handle_;  // Owned.
1122   SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
1123 };
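// Hedged usage sketch (assuming 'cudnn' is a CudnnHandle and 'allocator' a
// ScratchAllocator*, neither taken from this file):
//
//   SE_ASSIGN_OR_RETURN(
//       CudnnDropoutDescriptor dropout,
//       CudnnDropoutDescriptor::Create(cudnn, /*dropout=*/0.5f, /*seed=*/42,
//                                      allocator));
//   // dropout.handle() is later supplied when building the RNN descriptor.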
1124 
1125 class CudnnRnnParamsDescriptor {
1126   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1127 
1128   CudnnRnnParamsDescriptor(FilterDescriptor handle,
1129                            int64_t params_size_in_bytes, ParamsRegions weights,
1130                            ParamsRegions biases)
1131       : handle_(std::move(handle)),
1132         params_size_in_bytes_(params_size_in_bytes),
1133         weights_(std::move(weights)),
1134         biases_(std::move(biases)) {}
1135 
1136  public:
1137   CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
1138 
1139   static port::StatusOr<CudnnRnnParamsDescriptor> Create(
1140       const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1141       cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1142       cudnnDirectionMode_t direction_mode, int num_layers);
1143 
1144   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
1145   int64 params_size_in_bytes() const { return params_size_in_bytes_; }
1146   ParamsRegions params_weights() const { return weights_; }
1147   ParamsRegions params_biases() const { return biases_; }
1148 
1149  private:
1150   FilterDescriptor handle_;
1151   int64 params_size_in_bytes_;
1152   ParamsRegions weights_;
1153   ParamsRegions biases_;
1154   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
1155 };
1156 
1157 }  // namespace
1158 
1159 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
1160   CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1161                      PersistentRnnPlan rnn_plan, int num_layers,
1162                      int hidden_size, int input_size, int cell_size,
1163                      int batch_size, cudnnRNNInputMode_t input_mode,
1164                      cudnnDirectionMode_t direction_mode,
1165                      cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1166                      cudnnDataType_t compute_type,
1167                      const dnn::AlgorithmConfig& algorithm_config,
1168                      CudnnDropoutDescriptor dropout_desc,
1169                      CudnnRnnParamsDescriptor params_desc)
1170       : rnn_desc_(std::move(rnn_desc)),
1171         rnn_plan_(std::move(rnn_plan)),
1172         num_layers_(num_layers),
1173         hidden_size_(hidden_size),
1174         input_size_(input_size),
1175         cell_size_(cell_size),
1176         batch_size_(batch_size),
1177         rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1178         input_mode_(input_mode),
1179         direction_mode_(direction_mode),
1180         rnn_mode_(rnn_mode),
1181         data_type_(data_type),
1182         compute_type_(compute_type),
1183         algorithm_config_(algorithm_config),
1184         dropout_desc_(std::move(dropout_desc)),
1185         params_desc_(std::move(params_desc)) {}
1186 
1187  public:
1188   CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1189 
1190   static port::StatusOr<CudnnRnnDescriptor> Create(
1191       const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1192       int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
1193       cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1194       cudnnDataType_t data_type, cudnnDataType_t compute_type,
1195       const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1196       ScratchAllocator* state_allocator, bool use_padded_io) {
1197     SE_ASSIGN_OR_RETURN(
1198         CudnnDropoutDescriptor dropout_desc,
1199         CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1200 
1201     gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1202     cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1203 
1204     // TODO: allow the user to choose an algorithm.
1205     auto proj_size = hidden_size;
1206     hidden_size = std::max(hidden_size, cell_size);
1207 
1208     // Require explicit algorithm config to enable tensor cores. Some configs
1209     // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1210     // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1211     // We can only reasonably expect the user to handle the subsequent failure
1212     // in profile mode, which is run with algorithms returned from
1213     // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1214     // use tensor ops). CuDNN 7.2.1 fixed this issue.
1215     // TODO(csigg): Minimal supported cuDNN version is 7.3; clean up.
1216     bool allow_tensor_ops = data_type == CUDNN_DATA_HALF;
1217     if (data_type == CUDNN_DATA_FLOAT)
1218       allow_tensor_ops = tensorflow::tensor_float_32_execution_enabled();
1219     bool use_tensor_ops =
1220         algorithm_config.algorithm().has_value()
1221             ? algorithm_config.algorithm()->tensor_ops_enabled()
1222             : allow_tensor_ops;
1223     if (use_tensor_ops && !allow_tensor_ops) {
1224       return port::Status(port::error::INVALID_ARGUMENT,
1225                           "Algo requests disallowed tensor op evaluation.");
1226     }
1227 
1228 #if CUDNN_VERSION >= 8000
1229     cudnnMathType_t math_type =
1230         use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
1231 #else
1232     cudnnMathType_t math_type =
1233         use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
1234 #endif
1235 
1236 #if CUDNN_VERSION >= 8000
1237     cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
1238     uint32_t aux_flags = 0;
1239     if (use_padded_io) aux_flags |= CUDNN_RNN_PADDED_IO_ENABLED;
1240     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v8(
1241         /*rnnDesc=*/rnn_desc.get(), /*algo=*/rnn_algo, /*cellMode=*/rnn_mode,
1242         /*biasMode=*/bias_mode, /*dirMode=*/direction_mode,
1243         /*inputMode=*/input_mode,
1244         /*dataType=*/data_type, /*mathPrec=*/compute_type,
1245         /*mathType=*/math_type,
1246         /*inputSize=*/input_size,
1247         /*hiddenSize=*/hidden_size, /*projSize=*/proj_size,
1248         /*numLayers=*/num_layers,
1249         /*dropoutDesc=*/dropout_desc.handle(),
1250         /*auxFlags=*/aux_flags));
1251 #else
1252     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1253         cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1254         /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers,
1255         /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
1256         /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1257         /*dataType=*/compute_type));
1258     CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1259 
1260     if (proj_size < hidden_size) {
1261       RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
1262           cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1263           /*recProjSize=*/proj_size, /*outProjSize=*/0));
1264     }
1265 
1266     // TODO: For now, the cudnnRNN**Ex APIs are only used to process padded
1267     // inputs. If they are ever used for full-length (unpadded) sequences as
1268     // well, we will need to distinguish when to set the padding mode.
1269     if (use_padded_io) {
1270       RETURN_IF_CUDNN_ERROR(
1271           cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1272     }
1273 #endif
1274 
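    // CUDNN_RNN_ALGO_PERSIST_DYNAMIC compiles a persistent plan for a fixed
    // batch size, so a plan is built here whenever that algorithm is selected.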
1275     port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1276     PersistentRnnPlan rnn_plan;
1277     if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1278       CHECK_GE(batch_size, 0);
1279       rnn_plan_wrapper =
1280           CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1281       if (!rnn_plan_wrapper.ok()) {
1282         return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1283       } else {
1284         rnn_plan = rnn_plan_wrapper.ConsumeValueOrDie();
1285         RETURN_IF_CUDNN_ERROR(
1286             cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1287       }
1288     }
1289 
1290     // Create the params handle.
1291     // TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and
1292     // cudnnRNNForward***Ex are removed from the codebase, since the new API
1293     // doesn't need param descriptors any more.
1294     SE_ASSIGN_OR_RETURN(auto params_desc,
1295                         CudnnRnnParamsDescriptor::Create(
1296                             cudnn, input_size, data_type, rnn_desc.get(),
1297                             rnn_mode, direction_mode, num_layers));
1298 
1299     return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1300                               num_layers, hidden_size, input_size, cell_size,
1301                               batch_size, input_mode, direction_mode, rnn_mode,
1302                               data_type, compute_type, algorithm_config,
1303                               std::move(dropout_desc), std::move(params_desc));
1304   }
1305 
1306   cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
1307   int num_layers() const { return num_layers_; }
1308   int hidden_size() const { return hidden_size_; }
1309   int input_size() const { return input_size_; }
1310   int cell_size() const { return cell_size_; }
1311   int batch_size() const { return batch_size_; }
1312   cudnnRNNInputMode_t input_mode() const { return input_mode_; }
1313   cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
1314   cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
1315   cudnnDataType_t data_type() const { return data_type_; }
1316   cudnnDataType_t compute_type() const { return compute_type_; }
1317   const dnn::AlgorithmConfig& algorithm_config() const {
1318     return algorithm_config_;
1319   }
1320   int64 ParamsSizeInBytes() const override {
1321     return params_desc_.params_size_in_bytes();
1322   }
1323   cudnnFilterDescriptor_t params_handle() const {
1324     return params_desc_.handle();
1325   }
1326   ParamsRegions ParamsWeightRegions() const override {
1327     return params_desc_.params_weights();
1328   }
1329   ParamsRegions ParamsBiasRegions() const override {
1330     return params_desc_.params_biases();
1331   }
1332 
1333  private:
1334   gpu::RnnDescriptor rnn_desc_;
1335   PersistentRnnPlan rnn_plan_;
1336   int num_layers_;
1337   int hidden_size_;
1338   int input_size_;
1339   // cell_size_ is the size of cell state, which will be different from
1340   // hidden_size_ if the projection is used.
1341   int cell_size_;
1342   // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1343   // algorithm.
1344   int batch_size_;
1345   cudnnRNNAlgo_t rnn_algo_;
1346   cudnnRNNInputMode_t input_mode_;
1347   cudnnDirectionMode_t direction_mode_;
1348   cudnnRNNMode_t rnn_mode_;
1349   cudnnDataType_t data_type_;
1350   cudnnDataType_t compute_type_;
1351   dnn::AlgorithmConfig algorithm_config_;
1352   CudnnDropoutDescriptor dropout_desc_;
1353   CudnnRnnParamsDescriptor params_desc_;
1354   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1355 };
1356 
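// RAII wrapper around cudnnCTCLossDescriptor_t. Softmax normalization and
// non-propagation of NaN gradients are configured at construction time.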
1357 #if CUDNN_VERSION >= 7603
1358 class CudnnCtcLossDescriptor {
1359  public:
1360   explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
1361       : handle_(CreateCtcLossDescriptor()) {
1362     CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
1363         /*ctcLossDesc=*/handle_.get(),
1364         /*compType=*/data_type,
1365         /*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
1366         /*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
1367   }
1368 
1369   cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
1370 
1371  private:
1372   CtcLossDescriptor handle_;  // Owned
1373 
1374   SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
1375 };
1376 #else
1377 // Dummy placeholder used when cuDNN is too old to provide CTC loss support.
1378 class CudnnCtcLossDescriptor {
1379  public:
1380   CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
1381 };
1382 #endif
1383 
1384 namespace {
1385 
1386 // Checks whether the LSTM projection is used. If so, the additional weight
1387 // matrix (the projection matrix) is appended to 'weights'. Otherwise, nothing
1388 // is done.
1389 port::Status CheckAndFetchProjectionWeights(
1390     const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
1391     const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
1392     const FilterDescriptor& region_desc_handle,
1393     dnn::RnnDescriptor::ParamsRegions* weights) {
1394   int hidden_size_v;
1395   int num_layers_v;
1396   cudnnDropoutDescriptor_t dropout_desc;
1397   cudnnRNNInputMode_t input_mode;
1398   cudnnDirectionMode_t direction;
1399   cudnnRNNMode_t mode;
1400   cudnnRNNAlgo_t algo;
1401   cudnnDataType_t data_type;
1402 #if CUDNN_VERSION >= 8000
1403   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
1404       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1405       /*hiddenSize=*/&hidden_size_v,
1406       /*numLayers=*/&num_layers_v,
1407       /*dropoutDesc=*/&dropout_desc,
1408       /*inputMode=*/&input_mode,
1409       /*direction=*/&direction,
1410       /*mode=*/&mode,
1411       /*algo=*/&algo,
1412       /*mathPrec=*/&data_type));
1413 #else
1414   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
1415       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1416       /*hiddenSize=*/&hidden_size_v,
1417       /*numLayers=*/&num_layers_v,
1418       /*dropoutDesc=*/&dropout_desc,
1419       /*inputMode=*/&input_mode,
1420       /*direction=*/&direction,
1421       /*mode=*/&mode,
1422       /*algo=*/&algo,
1423       /*mathPrec=*/&data_type));
1424 #endif
1425   int rec_proj_size_v;
1426   int out_proj_size_v;
1427   RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
1428       /*handle=*/cudnn.handle(),
1429       /*rnnDesc=*/rnn_desc,
1430       /*recProjSize*/ &rec_proj_size_v,
1431       /*outProjSize*/ &out_proj_size_v));
1432   if (rec_proj_size_v != hidden_size_v) {
1433     void* offset = nullptr;
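    // For LSTMs, linLayerID values 0-7 address the gate weight matrices; ID 8
    // (used below) addresses the recurrent projection matrix when present.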
1434     int region_id = 8;
1435     RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
1436         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1437         /*layer=*/layer, /*xDesc=*/input_desc.get(),
1438         /*wDesc=*/filter_desc.get(),
1439         /*w=*/nullptr, /*linLayerID=*/region_id,
1440         /*linLayerMatDesc=*/region_desc_handle.get(),
1441         /*linLayerMat or linLayerBias=*/&offset));
1442     int dims[] = {1, 1, 1};
1443     cudnnDataType_t data_type;
1444     cudnnTensorFormat_t tensor_format;
1445     int n_dims;
1446     RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1447         /*filterDesc=*/region_desc_handle.get(),
1448         /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1449         /*dataType=*/&data_type, /*format=*/&tensor_format,
1450         /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1451     int64_t size =
1452         dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1453     dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset),
1454                                                size};
1455     weights->push_back(region);
1456   }
1457   return port::Status::OK();
1458 }
1459 
1460 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1461     const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1462     cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1463     cudnnDirectionMode_t direction_mode, int num_layers) {
1464   // Query the params size.
1465   TensorDescriptor input_desc = CreateTensorDescriptor();
1466   int tensor_dims[] = {1, input_size, 1};
1467   int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1468   RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1469       /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1470       /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1471       /*dimA=*/tensor_dims,
1472       /*strideA=*/strides));
1473 
1474   size_t params_size = 0;
1475   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1476       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1477       /*xDesc=*/input_desc.get(), /*sizeInBytes=*/&params_size,
1478       /*dataType=*/data_type));
1479   int64_t params_size_in_bytes = static_cast<int64>(params_size);
1480 
1481   FilterDescriptor filter_desc = CreateFilterDescriptor();
1482   int64_t filter_dim0 =
1483       params_size_in_bytes / CudnnDataTypeToByteSize(data_type);
1484   int filter_dims[] = {static_cast<int>(filter_dim0), 1, 1};
1485   RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1486       /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1487       /*format=*/CUDNN_TENSOR_NCHW,
1488       /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1489       /*filterDimA=*/filter_dims));
1490 
1491   // Enumerate the weight and bias regions laid out in the params buffer.
1492   int region_count_per_layer = [&] {
1493     switch (rnn_mode) {
1494       case CUDNN_RNN_RELU:
1495       case CUDNN_RNN_TANH:
1496         return 2;
1497       case CUDNN_LSTM:
1498         return 8;
1499       case CUDNN_GRU:
1500         return 6;
1501       default:
1502         LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1503         return 0;
1504     }
1505   }();
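  // Per pseudo-layer, vanilla RNNs have 2 weight matrices (input and
  // recurrent), LSTMs have 8 (4 gates x 2), and GRUs have 6 (3 gates x 2).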
1506 
1507   FilterDescriptor region_desc_handle = CreateFilterDescriptor();
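  // Bidirectional RNNs expose two pseudo-layers (forward and backward) per
  // logical layer, so the enumeration below visits both directions.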
1508   const int layer_count =
1509       direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1510 
1511   ParamsRegions weights;
1512   ParamsRegions biases;
1513 
1514   for (int layer = 0; layer < layer_count; layer++) {
1515     for (int region = 0; region < region_count_per_layer; region++) {
1516       for (int type = 0; type < 2; type++) {
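        // type == 0 queries a weight matrix region, type == 1 a bias region.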
1517         void* offset = nullptr;
1518         RETURN_IF_CUDNN_ERROR(
1519             type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1520                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1521                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1522                             /*wDesc=*/filter_desc.get(),
1523                             /*w=*/nullptr, /*linLayerID=*/region,
1524                             /*linLayerMatDesc=*/region_desc_handle.get(),
1525                             /*linLayerMat or linLayerBias=*/&offset)
1526                       : cudnnGetRNNLinLayerBiasParams(
1527                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1528                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1529                             /*wDesc=*/filter_desc.get(),
1530                             /*w=*/nullptr, /*linLayerID=*/region,
1531                             /*linLayerMatDesc=*/region_desc_handle.get(),
1532                             /*linLayerMat or linLayerBias=*/&offset));
1533         int dims[] = {1, 1, 1};
1534         cudnnDataType_t data_type;
1535         cudnnTensorFormat_t tensor_format;
1536         int n_dims;
1537         RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1538             /*filterDesc=*/region_desc_handle.get(),
1539             /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1540             /*dataType=*/&data_type, /*format=*/&tensor_format,
1541             /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1542         int64_t size =
1543             dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1544         dnn::RnnDescriptor::ParamsRegion region = {
1545             reinterpret_cast<int64>(offset), size};
1546         (type == 0 ? weights : biases).push_back(region);
1547       }
1548     }
1549     TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
1550         cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
1551         &weights));
1552   }
1553 
1554   return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1555                                   weights, biases);
1556 }
1557 
1558 }  // namespace
1559 
1560 class CudnnRnnSequenceTensorDescriptor
1561     : public dnn::RnnSequenceTensorDescriptor {
1562   CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1563                                    int batch_size, int data_size,
1564                                    cudnnDataType_t data_type,
1565                                    RNNDataDescriptor data_handle,
1566                                    TensorDescriptor handle)
1567       : max_seq_length_(max_seq_length),
1568         batch_size_(batch_size),
1569         data_size_(data_size),
1570         data_type_(data_type),
1571         handle_(std::move(handle)),
1572         rnn_data_handle_(std::move(data_handle)),
1573         handles_(max_seq_length, handle_.get()) {}
1574 
1575  public:
1576   CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1577       default;
1578 
1579   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1580       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1581       cudnnDataType_t data_type) {
1582     if (max_seq_length <= 0) {
1583       return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1584     }
1585     int dims[] = {batch_size, data_size, 1};
1586     int strides[] = {dims[1] * dims[2], dims[2], 1};
1587     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1588     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1589         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1590         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1591         /*strideA=*/strides));
1592     return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1593                                             data_size, data_type, nullptr,
1594                                             std::move(tensor_desc));
1595   }
1596 
1597   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1598       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1599       const absl::Span<const int>& seq_lengths, bool time_major,
1600       cudnnDataType_t data_type) {
1601     if (max_seq_length <= 0) {
1602       return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1603     }
1604     int dims[] = {batch_size, data_size, 1};
1605     int strides[] = {dims[1] * dims[2], dims[2], 1};
1606     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1607     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1608         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1609         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1610         /*strideA=*/strides));
1611     const int* seq_lengths_array = seq_lengths.data();
1612     RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1613     float padding_fill = 0.0f;
1614     cudnnRNNDataLayout_t layout;
1615     if (time_major) {
1616       layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1617     } else {
1618       layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1619     }
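    // The *_UNPACKED layouts let cuDNN consume padded variable-length batches
    // directly; padded steps are filled with padding_fill (0.0f).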
1620     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1621         /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1622         /*layout=*/layout,
1623         /*maxSeqLength=*/max_seq_length,
1624         /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1625         /*seqLengthArray=*/seq_lengths_array,
1626         /*paddingFill*/ (void*)&padding_fill));
1627     return CudnnRnnSequenceTensorDescriptor(
1628         parent, max_seq_length, batch_size, data_size, data_type,
1629         std::move(data_desc), std::move(tensor_desc));
1630   }
1631 
1632   const cudnnTensorDescriptor_t* handles() const { return handles_.data(); }
1633   const cudnnRNNDataDescriptor_t data_handle() const {
1634     return rnn_data_handle_.get();
1635   }
1636 
1637   int max_seq_length() const { return max_seq_length_; }
1638   int batch_size() const { return batch_size_; }
1639   int data_size() const { return data_size_; }
1640   bool is_var_seq_lengths() const { return rnn_data_handle_ != nullptr; }
1641 
1642  private:
1643   int max_seq_length_;
1644   int batch_size_;
1645   int data_size_;
1646   cudnnDataType_t data_type_;
1647   TensorDescriptor handle_;
1648   RNNDataDescriptor rnn_data_handle_;
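  // The legacy per-time-step cudnnRNNForward* APIs expect an array of tensor
  // descriptors, so handle_ is replicated max_seq_length times below.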
1649   std::vector<cudnnTensorDescriptor_t> handles_;  // Copies of handle_.
1650   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1651 };
1652 
1653 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1654  public:
1655   CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1656                                 int batch_size, int data_size,
1657                                 cudnnDataType_t data_type)
1658       : handle_(CreateTensorDescriptor()),
1659         num_layers_(num_layers),
1660         batch_size_(batch_size),
1661         data_size_(data_size),
1662         data_type_(data_type) {
1663     int dims[] = {num_layers, batch_size, data_size};
1664     int strides[] = {dims[1] * dims[2], dims[2], 1};
1665     CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1666         /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1667         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1668         /*strideA=*/strides));
1669   }
1670 
1671   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1672 
1673   int num_layers() const { return num_layers_; }
1674   int batch_size() const { return batch_size_; }
1675   int data_size() const { return data_size_; }
1676 
1677  private:
1678   TensorDescriptor handle_;
1679   int num_layers_;
1680   int batch_size_;
1681   int data_size_;
1682   cudnnDataType_t data_type_;
1683   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1684 };
1685 
1686 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1687 class CudnnConvolveExecutionPlan : public dnn::ConvolveExecutionPlan {
1688  public:
1689   CudnnConvolveExecutionPlan(cudnn_frontend::ExecutionPlan plan)
1690       : plan_(std::move(plan)) {}
1691   std::string getTag() override { return plan_.getTag(); };
1692   void* get_raw_desc() override { return plan_.get_raw_desc(); }
1693   int64_t getWorkspaceSize() override { return plan_.getWorkspaceSize(); }
1694 
1695  private:
1696   cudnn_frontend::ExecutionPlan plan_;
1697   SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolveExecutionPlan);
1698 };
1699 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1700 
1701 namespace {
1702 
1703 struct RnnModelDims {
1704   int num_layers = 0;
1705   int batch_size = 0;
1706   int max_seq_length = 0;
1707   int hidden_size = 0;
1708   int input_size = 0;
1709   int cell_size = 0;
1710   int dir_count = 0;
1711 };
1712 
1713 template <class T>
1714 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1715     const CudnnRnnDescriptor& rnn_desc,
1716     const CudnnRnnSequenceTensorDescriptor& input_desc,
1717     const DeviceMemory<T>& input_data,
1718     const CudnnRnnStateTensorDescriptor& input_h_desc,
1719     const DeviceMemory<T>& input_h_data,
1720     const CudnnRnnStateTensorDescriptor& input_c_desc,
1721     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1722     const CudnnRnnSequenceTensorDescriptor& output_desc,
1723     const DeviceMemory<T>& output_data,
1724     const CudnnRnnStateTensorDescriptor& output_h_desc,
1725     const DeviceMemory<T>& output_h_data,
1726     const CudnnRnnStateTensorDescriptor& output_c_desc,
1727     const DeviceMemory<T>& output_c_data) {
1728   // extract model parameters
1729   RnnModelDims model_dims;
1730   model_dims.num_layers = rnn_desc.num_layers();
1731   model_dims.batch_size = input_desc.batch_size();
1732   model_dims.max_seq_length = input_desc.max_seq_length();
1733   model_dims.hidden_size = rnn_desc.hidden_size();
1734   model_dims.input_size = input_desc.data_size();
1735   model_dims.cell_size = rnn_desc.cell_size();
1736   model_dims.dir_count =
1737       (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1738 
1739   // check parameters
1740   if (!(input_h_desc.num_layers() ==
1741             model_dims.num_layers * model_dims.dir_count &&
1742         input_h_desc.batch_size() == model_dims.batch_size &&
1743         input_h_desc.data_size() == model_dims.hidden_size)) {
1744     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1745   }
1746   // The LSTM projection will be used if input_h_desc.data_size() <
1747   // input_c_desc.data_size()
1748   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1749         input_h_desc.batch_size() == input_c_desc.batch_size() &&
1750         input_h_desc.data_size() <= input_c_desc.data_size())) {
1751     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1752   }
1753   if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1754         output_desc.batch_size() == model_dims.batch_size &&
1755         output_desc.data_size() ==
1756             model_dims.hidden_size * model_dims.dir_count)) {
1757     return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1758   }
1759   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1760         input_h_desc.batch_size() == output_h_desc.batch_size() &&
1761         input_h_desc.data_size() == output_h_desc.data_size())) {
1762     return port::Status(port::error::INVALID_ARGUMENT,
1763                         "Invalid output_h shape");
1764   }
1765   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1766         input_h_desc.batch_size() == output_c_desc.batch_size() &&
1767         input_h_desc.data_size() <= output_c_desc.data_size())) {
1768     return port::Status(port::error::INVALID_ARGUMENT,
1769                         "Invalid output_c shape");
1770   }
1771 
1772   return model_dims;
1773 }
1774 
1775 port::Status CheckRNNParameterSize(
1776     const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1777     const CudnnRnnSequenceTensorDescriptor& input_desc) {
1778   size_t params_size_in_bytes = 0;
1779 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1780   RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize(
1781       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1782       /*sizeInBytes=*/&params_size_in_bytes));
1783 #else
1784   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1785       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1786       /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
1787       /*dataType=*/rnn_desc.data_type()));
1788 #endif
1789   if (static_cast<int64>(params_size_in_bytes) !=
1790       rnn_desc.ParamsSizeInBytes()) {
1791     return port::Status(port::error::INVALID_ARGUMENT,
1792                         "Mismatching RNN parameter size");
1793   }
1794   return port::Status::OK();
1795 }
1796 
1797 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1798     Stream* stream, const CudnnHandle& cudnn,
1799     const CudnnRnnDescriptor& rnn_desc,
1800     const CudnnRnnSequenceTensorDescriptor& input_desc,
1801     ScratchAllocator* workspace_allocator) {
1802   // Query the workspace size.
1803   size_t workspace_size_in_bytes = 0;
1804   RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1805       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1806       /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1807       /*sizeInBytes=*/&workspace_size_in_bytes));
1808   // Allocate the workspace.
1809   if (workspace_size_in_bytes == 0) {
1810     return DeviceMemory<uint8>();
1811   }
1812   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1813 }
1814 
1815 #if CUDNN_VERSION >= 7402
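// The extended batch-normalization workspace queries used below are only
// available in newer cuDNN releases, hence the version guard above.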
1816 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
1817     Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1818     const cudnnBatchNormOps_t& bn_ops,
1819     const cudnnActivationDescriptor_t& activation_desc,
1820     const CudnnTensorDescriptor& x_descriptor,
1821     const CudnnTensorDescriptor& scale_offset_descriptor,
1822     ScratchAllocator* workspace_allocator) {
1823   // Query the workspace size.
1824   size_t workspace_size_in_bytes = 0;
1825   RETURN_IF_CUDNN_ERROR(
1826       cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
1827           /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1828           /*xDesc=*/x_descriptor.handle(), /*zDesc=*/x_descriptor.handle(),
1829           /*yDesc=*/x_descriptor.handle(),
1830           /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
1831           /*activationDesc=*/activation_desc,
1832           /*sizeInBytes=*/&workspace_size_in_bytes));
1833   // Allocate the workspace.
1834   if (workspace_size_in_bytes == 0) {
1835     return DeviceMemory<uint8>();
1836   }
1837   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1838 }
1839 
1840 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
1841     Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1842     const cudnnBatchNormOps_t& bn_ops,
1843     const CudnnTensorDescriptor& x_descriptor,
1844     const CudnnTensorDescriptor& scale_offset_descriptor,
1845     ScratchAllocator* workspace_allocator) {
1846   // Query the workspace size.
1847   size_t workspace_size_in_bytes = 0;
1848   RETURN_IF_CUDNN_ERROR(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
1849       /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1850       /*xDesc=*/x_descriptor.handle(),
1851       /*yDesc=*/x_descriptor.handle(),
1852       /*dyDesc=*/x_descriptor.handle(),
1853       /*dzDesc=*/nullptr,
1854       /*dxDesc=*/x_descriptor.handle(),
1855       /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
1856       /*activationDesc=*/nullptr, /*sizeInBytes=*/&workspace_size_in_bytes));
1857   // Allocate the workspace.
1858   if (workspace_size_in_bytes == 0) {
1859     return DeviceMemory<uint8>();
1860   }
1861   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1862 }
1863 
1864 #endif
1865 
1866 }  // namespace
1867 
1868 template <class T>
1869 port::Status CudnnSupport::DoRnnForwardImpl(
1870     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1871     const CudnnRnnSequenceTensorDescriptor& input_desc,
1872     const DeviceMemory<T>& input_data,
1873     const DeviceMemory<int>& seq_lengths_data,
1874     const CudnnRnnStateTensorDescriptor& input_h_desc,
1875     const DeviceMemory<T>& input_h_data,
1876     const CudnnRnnStateTensorDescriptor& input_c_desc,
1877     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1878     const CudnnRnnSequenceTensorDescriptor& output_desc,
1879     DeviceMemory<T>* output_data,
1880     const CudnnRnnStateTensorDescriptor& output_h_desc,
1881     DeviceMemory<T>* output_h_data,
1882     const CudnnRnnStateTensorDescriptor& output_c_desc,
1883     DeviceMemory<T>* output_c_data, bool is_training,
1884     ScratchAllocator* reserve_space_allocator,
1885     ScratchAllocator* workspace_allocator,
1886     dnn::ProfileResult* output_profile_result) {
1887   SE_ASSIGN_OR_RETURN(
1888       RnnModelDims model_dims,
1889       ExtractAndCheckRnnForward(
1890           rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1891           input_c_desc, input_c_data, params, output_desc, *output_data,
1892           output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1893 
1894   auto cudnn = cudnn_->GetHandle(parent_, stream);
1895 
1896   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1897 
1898   // In cuDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex APIs were
1899   // deprecated. Instead, we use cudnnRNNForward, which requires the
1900   // sequence_lengths parameter. For more info, see
1901   // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802.
1902 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1903   if (input_desc.is_var_seq_lengths()) {
1904     DeviceMemory<uint8> workspace;
1905     DeviceMemory<uint8> reserve_space;
1906     cudnnForwardMode_t rnn_fwd_mode;
1907     if (is_training) {
1908       rnn_fwd_mode = CUDNN_FWD_MODE_TRAINING;
1909     } else {
1910       rnn_fwd_mode = CUDNN_FWD_MODE_INFERENCE;
1911     }
1912     size_t reserve_space_size_in_bytes = 0;
1913     size_t workspace_size_in_bytes = 0;
1914     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
1915         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1916         /*fMode=*/rnn_fwd_mode, /*xDesc=*/input_desc.data_handle(),
1917         /*workSpaceSize=*/&workspace_size_in_bytes,
1918         /*reserveSpaceSize=*/&reserve_space_size_in_bytes));
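    // cudnnGetRNNTempSpaceSizes reports both the scratch workspace and, in
    // training mode, the reserve space that must persist until the backward pass.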
1919 
1920     if (workspace_size_in_bytes > 0) {
1921       SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
1922                                          workspace_size_in_bytes));
1923     }
1924     if (reserve_space_size_in_bytes > 0) {
1925       SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1926                                              reserve_space_size_in_bytes));
1927     }
1928 
1929     std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1930     const bool is_profiling = output_profile_result != nullptr;
1931     if (is_profiling) {
1932       timer.reset(new GpuTimer(parent_));
1933       // The timer should be started and stopped as close to the cuDNN call as
1934       // possible; other threads can still issue work onto this stream, so a
1935       // single measurement may be inflated and several profiling runs may be needed.
1936       if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1937         return port::Status(port::error::INTERNAL, "Failed to start timer");
1938       }
1939     }
1940 
1941     RETURN_IF_CUDNN_ERROR(cudnnRNNForward(
1942         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1943         /*fwdMode=*/rnn_fwd_mode,
1944         /*devSeqLengths=*/
1945         reinterpret_cast<const int*>(seq_lengths_data.opaque()),
1946         /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1947         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data->opaque(),
1948         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1949         /*hy=*/output_h_data->opaque(),
1950         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1951         /*cy=*/output_c_data->opaque(),
1952         /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
1953         /*weightSpace=*/params.opaque(),
1954         /*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(),
1955         /*reserveSpaceSizeInBytes=*/reserve_space.size(),
1956         /*reserveSpace=*/reserve_space.opaque()));
1957 
1958     if (is_profiling) {
1959       if (!timer->Stop(AsGpuStream(stream))) {
1960         return port::Status(port::error::INTERNAL, "Failed to stop timer");
1961       }
1962       auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1963       output_profile_result->set_algorithm(algo_desc);
1964       output_profile_result->set_elapsed_time_in_ms(
1965           timer->GetElapsedMilliseconds());
1966     }
1967     return port::Status::OK();
1968   }
1969 #endif
1970   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1971                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1972                                          workspace_allocator))
1973 
1974   // Query the reserve space size and allocate the reserve space, which is
1975   // only needed when training.
1976   DeviceMemory<uint8> reserve_space;
1977   if (is_training) {
1978     size_t reserve_space_size_in_bytes = 0;
1979     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
1980         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1981         /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
1982         /*sizeInBytes=*/&reserve_space_size_in_bytes));
1983 
1984     if (reserve_space_size_in_bytes > 0) {
1985       SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1986                                              reserve_space_size_in_bytes));
1987     }
1988   }
1989 
1990   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1991   const bool is_profiling = output_profile_result != nullptr;
1992   if (is_profiling) {
1993     timer.reset(new GpuTimer(parent_));
1994     // The timer should be started and stopped as close to the cuDNN call as
1995     // possible; other threads can still issue work onto this stream, so a
1996     // single measurement may be inflated and several profiling runs may be needed.
1997     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1998       return port::Status(port::error::INTERNAL, "Failed to start timer");
1999     }
2000   }
2001 
2002   if (!is_training) {
2003     if (input_desc.is_var_seq_lengths()) {
2004       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
2005           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2006           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2007           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2008           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2009           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2010           /*yDesc=*/output_desc.data_handle(),
2011           /*y=*/output_data->opaque(),
2012           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
2013           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
2014           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
2015           nullptr,
2016           /*workspace=*/workspace.opaque(),
2017           /*workSpaceSizeInBytes=*/workspace.size()));
2018     } else {
2019       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
2020           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2021           /*seqLength=*/model_dims.max_seq_length,
2022           /*xDesc=*/input_desc.handles(),
2023           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2024           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
2025           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
2026           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
2027           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
2028           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
2029           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
2030           /*workSpaceSizeInBytes=*/workspace.size()));
2031     }
2032   } else {
2033     if (input_desc.is_var_seq_lengths()) {
2034       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
2035           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2036           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2037           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2038           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2039           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2040           /*yDesc=*/output_desc.data_handle(),
2041           /*y=*/output_data->opaque(),
2042           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
2043           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
2044           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
2045           nullptr,
2046           /*workspace=*/workspace.opaque(),
2047           /*workSpaceSizeInBytes=*/workspace.size(),
2048           /*reserveSpace=*/reserve_space.opaque(),
2049           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
2050     } else {
2051       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
2052           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2053           /*seqLength=*/model_dims.max_seq_length,
2054           /*xDesc=*/input_desc.handles(),
2055           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2056           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
2057           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
2058           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
2059           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
2060           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
2061           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
2062           /*workSpaceSizeInBytes=*/workspace.size(),
2063           /*reserveSpace=*/reserve_space.opaque(),
2064           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
2065     }
2066   }
2067 
2068   if (is_profiling) {
2069     if (!timer->Stop(AsGpuStream(stream))) {
2070       return port::Status(port::error::INTERNAL, "Failed to stop timer");
2071     }
2072     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2073     output_profile_result->set_algorithm(algo_desc);
2074     output_profile_result->set_elapsed_time_in_ms(
2075         timer->GetElapsedMilliseconds());
2076   }
2077 
2078   return port::Status::OK();
2079 }
2080 
2081 template <class T>
2082 port::Status CudnnSupport::DoRnnBackwardImpl(
2083     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
2084     const CudnnRnnSequenceTensorDescriptor& input_desc,
2085     const DeviceMemory<T>& input_data,
2086     const DeviceMemory<int>& seq_lengths_data,
2087     const CudnnRnnStateTensorDescriptor& input_h_desc,
2088     const DeviceMemory<T>& input_h_data,
2089     const CudnnRnnStateTensorDescriptor& input_c_desc,
2090     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2091     const CudnnRnnSequenceTensorDescriptor& output_desc,
2092     const DeviceMemory<T>& output_data,
2093     const CudnnRnnStateTensorDescriptor& output_h_desc,
2094     const DeviceMemory<T>& output_h_data,
2095     const CudnnRnnStateTensorDescriptor& output_c_desc,
2096     const DeviceMemory<T>& output_c_data,
2097     const DeviceMemory<T>& output_backprop_data,
2098     const DeviceMemory<T>& output_h_backprop_data,
2099     const DeviceMemory<T>& output_c_backprop_data,
2100     DeviceMemory<T>* input_backprop_data,
2101     DeviceMemory<T>* input_h_backprop_data,
2102     DeviceMemory<T>* input_c_backprop_data,
2103     DeviceMemory<T>* params_backprop_data,
2104     DeviceMemory<uint8>* reserve_space_data,
2105     ScratchAllocator* workspace_allocator,
2106     dnn::ProfileResult* output_profile_result) {
2107   SE_ASSIGN_OR_RETURN(
2108       RnnModelDims model_dims,
2109       ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
2110                                 input_h_data, input_c_desc, input_c_data,
2111                                 params, output_desc, output_data, output_h_desc,
2112                                 output_h_data, output_c_desc, output_c_data));
2113 
2114   auto cudnn = cudnn_->GetHandle(parent_, stream);
2115 
2116   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
2117 
2118   // In cuDNN v8.0, the cudnnRNNBackward*** and cudnnRNNBackward***Ex APIs were
2119   // likewise deprecated in favor of cudnnRNNBackwardData_v8 and
2120   // cudnnRNNBackwardWeights_v8, which require the sequence_lengths parameter.
2121   // See https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802.
2122 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
2123   if (input_desc.is_var_seq_lengths()) {
2124     DeviceMemory<uint8> workspace;
2125     size_t workspace_size_in_bytes = 0;
2126     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
2127         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2128         /*fMode=*/CUDNN_FWD_MODE_TRAINING, /*xDesc=*/input_desc.data_handle(),
2129         /*workSpaceSize=*/&workspace_size_in_bytes,
2130         /*reserveSpaceSize=*/NULL));
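    // Only the workspace size is needed here; the reserve space was allocated
    // during the forward pass and is passed in via reserve_space_data.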
2131     if (workspace_size_in_bytes > 0) {
2132       SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
2133                                          workspace_size_in_bytes));
2134     }
2135 
2136     std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2137     const bool is_profiling = output_profile_result != nullptr;
2138     if (is_profiling) {
2139       timer.reset(new GpuTimer(parent_));
2140       // The timer should be started and stopped as close to the cuDNN call as
2141       // possible; other threads can still issue work onto this stream, so a
2142       // single measurement may be inflated and several profiling runs may be needed.
2143       if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2144         return port::Status(port::error::INTERNAL, "Failed to start timer");
2145       }
2146     }
2147 
2148     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8(
2149         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2150         /*devSeqLengths=*/
2151         reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2152         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
2153         /*dy=*/output_backprop_data.opaque(),
2154         /*xDesc=*/input_desc.data_handle(),
2155         /*dx=*/input_backprop_data->opaque(),
2156         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2157         /*dhy=*/output_h_backprop_data.opaque(),
2158         /*dhx=*/input_h_backprop_data->opaque(),
2159         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2160         /*dcy=*/output_c_backprop_data.opaque(),
2161         /*dcx=*/input_c_backprop_data->opaque(),
2162         /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2163         /*weightSpace=*/params.opaque(),
2164         /*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(),
2165         /*reserveSpaceSize=*/reserve_space_data->size(),
2166         /*reserveSpace=*/reserve_space_data->opaque()));
2167 
2168     if (params_backprop_data != nullptr) {
2169       // Clear the dw to zeros.
2170       stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
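      // CUDNN_WGRAD_MODE_ADD accumulates into dweightSpace, so the buffer must
      // start from zero to yield the plain weight gradients.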
2171       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8(
2172           /*handle=*/cudnn.handle(),
2173           /*rnnDesc=*/rnn_desc.handle(),
2174           /*addGrad=*/CUDNN_WGRAD_MODE_ADD,
2175           /*devSeqLengths=*/
2176           reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2177           /*xDesc=*/input_desc.data_handle(),
2178           /*x=*/input_data.opaque(),
2179           /*hDesc=*/input_h_desc.handle(),
2180           /*hx=*/input_h_data.opaque(),
2181           /*yDesc=*/output_desc.data_handle(),
2182           /*y=*/output_data.opaque(),
2183           /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2184           /*dweightSpace=*/params_backprop_data->opaque(),
2185           /*workSpaceSize=*/workspace.size(),
2186           /*workSpace=*/workspace.opaque(),
2187           /*reserveSpaceSize=*/reserve_space_data->size(),
2188           /*reserveSpace=*/reserve_space_data->opaque()));
2189     }
2190 
2191     if (is_profiling) {
2192       if (!timer->Stop(AsGpuStream(stream))) {
2193         return port::Status(port::error::INTERNAL, "Failed to stop timer");
2194       }
2195       auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2196       output_profile_result->set_algorithm(algo_desc);
2197       output_profile_result->set_elapsed_time_in_ms(
2198           timer->GetElapsedMilliseconds());
2199     }
2200     return port::Status::OK();
2201   }
2202 #endif
2203   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
2204                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
2205                                          workspace_allocator));
2206 
2207   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2208   const bool is_profiling = output_profile_result != nullptr;
2209   if (is_profiling) {
2210     timer.reset(new GpuTimer(parent_));
2211     // The timer should be started and stopped as close to the cuDNN call as
2212     // possible; other threads can still issue work onto this stream, so a
2213     // single measurement may be inflated and several profiling runs may be needed.
2214     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2215       return port::Status(port::error::INTERNAL, "Failed to start timer");
2216     }
2217   }
2218 
2219   if (input_desc.is_var_seq_lengths()) {
2220     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
2221         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2222         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
2223         /*dyDesc=*/output_desc.data_handle(),
2224         /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
2225         /*dhyDesc=*/output_h_desc.handle(),
2226         /*dhy=*/output_h_backprop_data.opaque(),
2227         /*dcyDesc=*/output_c_desc.handle(),
2228         /*dcy=*/output_c_backprop_data.opaque(),
2229         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2230         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2231         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2232         /*dxDesc=*/input_desc.data_handle(),
2233         /*dx=*/input_backprop_data->opaque(),
2234         /*dhxDesc=*/input_h_desc.handle(),
2235         /*dhx=*/input_h_backprop_data->opaque(),
2236         /*dcxDesc=*/input_c_desc.handle(),
2237         /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
2238         /*workspace=*/workspace.opaque(),
2239         /*workSpaceSizeInBytes=*/workspace.size(),
2240         /*reserveSpace=*/reserve_space_data->opaque(),
2241         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2242   } else {
2243     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
2244         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2245         /*seqLength=*/model_dims.max_seq_length,
2246         /*yDesc=*/output_desc.handles(),
2247         /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
2248         /*dy=*/output_backprop_data.opaque(),
2249         /*dhyDesc=*/output_h_desc.handle(),
2250         /*dhy=*/output_h_backprop_data.opaque(),
2251         /*dcyDesc=*/output_c_desc.handle(),
2252         /*dcy=*/output_c_backprop_data.opaque(),
2253         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2254         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2255         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2256         /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
2257         /*dhxDesc=*/input_h_desc.handle(),
2258         /*dhx=*/input_h_backprop_data->opaque(),
2259         /*dcxDesc=*/input_c_desc.handle(),
2260         /*dcx=*/input_c_backprop_data->opaque(),
2261         /*workspace=*/workspace.opaque(),
2262         /*workSpaceSizeInBytes=*/workspace.size(),
2263         /*reserveSpace=*/reserve_space_data->opaque(),
2264         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2265   }
2266 
2267   if (params_backprop_data != nullptr) {
2268     // Clear the dw to zeros.
2269     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
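    // The legacy cudnnRNNBackwardWeights* calls also accumulate into dw, which
    // is why the buffer is zeroed first.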
2270     if (input_desc.is_var_seq_lengths()) {
2271       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
2272           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2273           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2274           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2275           /*yDesc=*/output_desc.data_handle(),
2276           /*y=*/output_data.opaque(),
2277           /*workspace=*/workspace.opaque(),
2278           /*workSpaceSizeInBytes=*/workspace.size(),
2279           /*dwDesc=*/rnn_desc.params_handle(),
2280           /*dw=*/params_backprop_data->opaque(),
2281           /*reserveSpace=*/reserve_space_data->opaque(),
2282           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2283     } else {
2284       // make the backward weight call
2285       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
2286           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2287           /*seqLength=*/model_dims.max_seq_length,
2288           /*xDesc=*/input_desc.handles(),
2289           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2290           /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
2291           /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
2292           /*workSpaceSizeInBytes=*/workspace.size(),
2293           /*dwDesc=*/rnn_desc.params_handle(),
2294           /*dw=*/params_backprop_data->opaque(),
2295           /*reserveSpace=*/reserve_space_data->opaque(),
2296           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2297     }
2298   }
2299 
2300   if (is_profiling) {
2301     if (!timer->Stop(AsGpuStream(stream))) {
2302       return port::Status(port::error::INTERNAL, "Failed to stop timer");
2303     }
2304     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2305     output_profile_result->set_algorithm(algo_desc);
2306     output_profile_result->set_elapsed_time_in_ms(
2307         timer->GetElapsedMilliseconds());
2308   }
2309 
2310   return port::Status::OK();
2311 }
2312 
2313 port::Status CudnnSupport::DoCtcLossImpl(
2314     Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
2315     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2316     absl::Span<const int> labels_lengths_data,
2317     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2318     const CudnnRnnStateTensorDescriptor& grads_desc,
2319     DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
2320     DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2321   auto cudnn = cudnn_->GetHandle(parent_, stream);
2322 
2323   int kNumTimestamps = probs_desc.num_layers();
2324   int kBatchSize = probs_desc.batch_size();
2325   int kNumLabels = probs_desc.data_size();
2326   int total_size = kNumLabels * kNumTimestamps * kBatchSize;
2327   (void)total_size;
2328 
2329 #if CUDNN_VERSION >= 7603
2330   cudnnCTCLossAlgo_t ctc_loss_algo =
2331       static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
2332   RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
2333       /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
2334       /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
2335       /*labelLengths=*/labels_lengths_data.data(),
2336       /*inputLengths=*/input_lengths_data.data(),
2337       /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
2338       /*gradients=*/grads_data.opaque(),
2339       /*algo=*/ctc_loss_algo,
2340       /*ctcLossDesc=*/ctc_loss_desc.handle(),
2341       /*workspace=*/scratch_memory.opaque(),
2342       /*workSpaceSizeInBytes=*/scratch_memory.size()));
2343 #else
2344   return port::Status(port::error::INVALID_ARGUMENT,
2345                       "cudnnCTCLoss is not supported when "
2346                       "CUDNN_VERSION < 7.6.3");
2347 #endif
2348 
2349   return port::Status::OK();
2350 }
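// Note on the implementation above: ctc_loss_algo_id is reinterpreted directly
// as a cudnnCTCLossAlgo_t (in cuDNN this selects between the deterministic and
// non-deterministic CTC algorithms), and builds against cuDNN older than 7.6.3
// return INVALID_ARGUMENT without ever calling cudnnCTCLoss.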
2351 
2352 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2353 CudnnSupport::createRnnDescriptor(
2354     int num_layers, int hidden_size, int input_size, int cell_size,
2355     int batch_size, dnn::RnnInputMode input_mode,
2356     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2357     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2358     float dropout, uint64 seed, ScratchAllocator* state_allocator,
2359     bool use_padded_io) {
2360   // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
2361   // not enqueueing anything into a stream, we pass in the null stream.
2362   auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
2363   SE_ASSIGN_OR_RETURN(
2364       CudnnRnnDescriptor rnn_desc,
2365       CudnnRnnDescriptor::Create(
2366           cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
2367           ToCudnnRnnInputMode(input_mode),
2368           ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
2369           ToCudnnDataType(data_type), GetRnnComputeType(data_type),
2370           algorithm_config, dropout, seed, state_allocator, use_padded_io));
2371   return std::unique_ptr<dnn::RnnDescriptor>(
2372       new CudnnRnnDescriptor(std::move(rnn_desc)));
2373 }
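// Illustrative usage sketch only: the sizes and enum values below are
// placeholders, the enum spellings are assumed from stream_executor/dnn.h, and
// `dnn_support` stands for a CudnnSupport* obtained elsewhere.
//
//   auto rnn_desc_or = dnn_support->createRnnDescriptor(
//       /*num_layers=*/1, /*hidden_size=*/128, /*input_size=*/64,
//       /*cell_size=*/0, /*batch_size=*/32, dnn::RnnInputMode::kRnnLinearSkip,
//       dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
//       dnn::DataType::kFloat, dnn::AlgorithmConfig(), /*dropout=*/0.0f,
//       /*seed=*/0, /*state_allocator=*/nullptr, /*use_padded_io=*/false);
//   if (rnn_desc_or.ok()) { /* use rnn_desc_or.ValueOrDie() */ }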
2374 
2375 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2376 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
2377                                                 int batch_size, int data_size,
2378                                                 dnn::DataType data_type) {
2379   SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2380                       CudnnRnnSequenceTensorDescriptor::Create(
2381                           parent_, max_seq_length, batch_size, data_size,
2382                           ToCudnnDataType(data_type)));
2383   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2384       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2385 }
2386 
2387 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2388 CudnnSupport::createRnnSequenceTensorDescriptor(
2389     int max_seq_length, int batch_size, int data_size,
2390     const absl::Span<const int>& seq_lengths, bool time_major,
2391     dnn::DataType data_type) {
2392   SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2393                       CudnnRnnSequenceTensorDescriptor::Create(
2394                           parent_, max_seq_length, batch_size, data_size,
2395                           seq_lengths, time_major, ToCudnnDataType(data_type)));
2396   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2397       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2398 }
2399 
2400 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2401 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2402                                              int data_size,
2403                                              dnn::DataType data_type) {
2404   return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
2405       new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
2406                                         data_size, ToCudnnDataType(data_type)));
2407 }
2408 
2409 bool CudnnSupport::DoRnnForward(
2410     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2411     const dnn::RnnSequenceTensorDescriptor& input_desc,
2412     const DeviceMemory<Eigen::half>& input_data,
2413     const DeviceMemory<int>& seq_lengths_data,
2414     const dnn::RnnStateTensorDescriptor& input_h_desc,
2415     const DeviceMemory<Eigen::half>& input_h_data,
2416     const dnn::RnnStateTensorDescriptor& input_c_desc,
2417     const DeviceMemory<Eigen::half>& input_c_data,
2418     const DeviceMemory<Eigen::half>& params,
2419     const dnn::RnnSequenceTensorDescriptor& output_desc,
2420     DeviceMemory<Eigen::half>* output_data,
2421     const dnn::RnnStateTensorDescriptor& output_h_desc,
2422     DeviceMemory<Eigen::half>* output_h_data,
2423     const dnn::RnnStateTensorDescriptor& output_c_desc,
2424     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2425     ScratchAllocator* reserve_space_allocator,
2426     ScratchAllocator* workspace_allocator,
2427     dnn::ProfileResult* output_profile_result) {
2428   const CudnnRnnDescriptor& cudnn_rnn_desc =
2429       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2430   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2431       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2432   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2433       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2434   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2435       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2436   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2437       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2438   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2439       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2440   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2441       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2442   return IsStatusOk(
2443       DoRnnForwardImpl<Eigen::half>(
2444           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2445           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2446           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2447           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2448           output_c_data, is_training, reserve_space_allocator,
2449           workspace_allocator, output_profile_result),
2450       /*report_error=*/!output_profile_result);
2451 }
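// The remaining DoRnnForward/DoRnnBackward overloads follow the same pattern
// as the one above: each downcasts the dnn::* descriptor interfaces to their
// concrete Cudnn* implementations and dispatches to the templated
// DoRnnForwardImpl/DoRnnBackwardImpl, and errors are only reported when the
// call is not being profiled (report_error = !output_profile_result).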
2452 
2453 bool CudnnSupport::DoRnnForward(
2454     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2455     const dnn::RnnSequenceTensorDescriptor& input_desc,
2456     const DeviceMemory<float>& input_data,
2457     const DeviceMemory<int>& seq_lengths_data,
2458     const dnn::RnnStateTensorDescriptor& input_h_desc,
2459     const DeviceMemory<float>& input_h_data,
2460     const dnn::RnnStateTensorDescriptor& input_c_desc,
2461     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2462     const dnn::RnnSequenceTensorDescriptor& output_desc,
2463     DeviceMemory<float>* output_data,
2464     const dnn::RnnStateTensorDescriptor& output_h_desc,
2465     DeviceMemory<float>* output_h_data,
2466     const dnn::RnnStateTensorDescriptor& output_c_desc,
2467     DeviceMemory<float>* output_c_data, bool is_training,
2468     ScratchAllocator* reserve_space_allocator,
2469     ScratchAllocator* workspace_allocator,
2470     dnn::ProfileResult* output_profile_result) {
2471   const CudnnRnnDescriptor& cudnn_rnn_desc =
2472       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2473   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2474       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2475   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2476       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2477   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2478       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2479   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2480       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2481   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2482       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2483   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2484       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2485   return IsStatusOk(
2486       DoRnnForwardImpl<float>(
2487           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2488           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2489           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2490           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2491           output_c_data, is_training, reserve_space_allocator,
2492           workspace_allocator, output_profile_result),
2493       /*report_error=*/!output_profile_result);
2494 }
2495 
2496 bool CudnnSupport::DoRnnForward(
2497     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2498     const dnn::RnnSequenceTensorDescriptor& input_desc,
2499     const DeviceMemory<double>& input_data,
2500     const DeviceMemory<int>& seq_lengths_data,
2501     const dnn::RnnStateTensorDescriptor& input_h_desc,
2502     const DeviceMemory<double>& input_h_data,
2503     const dnn::RnnStateTensorDescriptor& input_c_desc,
2504     const DeviceMemory<double>& input_c_data,
2505     const DeviceMemory<double>& params,
2506     const dnn::RnnSequenceTensorDescriptor& output_desc,
2507     DeviceMemory<double>* output_data,
2508     const dnn::RnnStateTensorDescriptor& output_h_desc,
2509     DeviceMemory<double>* output_h_data,
2510     const dnn::RnnStateTensorDescriptor& output_c_desc,
2511     DeviceMemory<double>* output_c_data, bool is_training,
2512     ScratchAllocator* reserve_space_allocator,
2513     ScratchAllocator* workspace_allocator,
2514     dnn::ProfileResult* output_profile_result) {
2515   const CudnnRnnDescriptor& cudnn_rnn_desc =
2516       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2517   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2518       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2519   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2520       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2521   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2522       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2523   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2524       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2525   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2526       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2527   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2528       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2529   return IsStatusOk(
2530       DoRnnForwardImpl<double>(
2531           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2532           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2533           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2534           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2535           output_c_data, is_training, reserve_space_allocator,
2536           workspace_allocator, output_profile_result),
2537       /*report_error=*/!output_profile_result);
2538 }
2539 
2540 bool CudnnSupport::DoRnnBackward(
2541     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2542     const dnn::RnnSequenceTensorDescriptor& input_desc,
2543     const DeviceMemory<Eigen::half>& input_data,
2544     const DeviceMemory<int>& seq_lengths_data,
2545     const dnn::RnnStateTensorDescriptor& input_h_desc,
2546     const DeviceMemory<Eigen::half>& input_h_data,
2547     const dnn::RnnStateTensorDescriptor& input_c_desc,
2548     const DeviceMemory<Eigen::half>& input_c_data,
2549     const DeviceMemory<Eigen::half>& params,
2550     const dnn::RnnSequenceTensorDescriptor& output_desc,
2551     const DeviceMemory<Eigen::half>& output_data,
2552     const dnn::RnnStateTensorDescriptor& output_h_desc,
2553     const DeviceMemory<Eigen::half>& output_h_data,
2554     const dnn::RnnStateTensorDescriptor& output_c_desc,
2555     const DeviceMemory<Eigen::half>& output_c_data,
2556     const DeviceMemory<Eigen::half>& output_backprop_data,
2557     const DeviceMemory<Eigen::half>& output_h_backprop_data,
2558     const DeviceMemory<Eigen::half>& output_c_backprop_data,
2559     DeviceMemory<Eigen::half>* input_backprop_data,
2560     DeviceMemory<Eigen::half>* input_h_backprop_data,
2561     DeviceMemory<Eigen::half>* input_c_backprop_data,
2562     DeviceMemory<Eigen::half>* params_backprop_data,
2563     DeviceMemory<uint8>* reserve_space_data,
2564     ScratchAllocator* workspace_allocator,
2565     dnn::ProfileResult* output_profile_result) {
2566   const CudnnRnnDescriptor& cudnn_rnn_desc =
2567       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2568   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2569       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2570   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2571       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2572   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2573       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2574   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2575       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2576   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2577       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2578   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2579       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2580   return IsStatusOk(
2581       DoRnnBackwardImpl<Eigen::half>(
2582           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2583           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2584           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2585           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2586           output_c_data, output_backprop_data, output_h_backprop_data,
2587           output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2588           input_c_backprop_data, params_backprop_data, reserve_space_data,
2589           workspace_allocator, output_profile_result),
2590       /*report_error=*/!output_profile_result);
2591 }
2592 
2593 bool CudnnSupport::DoRnnBackward(
2594     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2595     const dnn::RnnSequenceTensorDescriptor& input_desc,
2596     const DeviceMemory<float>& input_data,
2597     const DeviceMemory<int>& seq_lengths_data,
2598     const dnn::RnnStateTensorDescriptor& input_h_desc,
2599     const DeviceMemory<float>& input_h_data,
2600     const dnn::RnnStateTensorDescriptor& input_c_desc,
2601     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2602     const dnn::RnnSequenceTensorDescriptor& output_desc,
2603     const DeviceMemory<float>& output_data,
2604     const dnn::RnnStateTensorDescriptor& output_h_desc,
2605     const DeviceMemory<float>& output_h_data,
2606     const dnn::RnnStateTensorDescriptor& output_c_desc,
2607     const DeviceMemory<float>& output_c_data,
2608     const DeviceMemory<float>& output_backprop_data,
2609     const DeviceMemory<float>& output_h_backprop_data,
2610     const DeviceMemory<float>& output_c_backprop_data,
2611     DeviceMemory<float>* input_backprop_data,
2612     DeviceMemory<float>* input_h_backprop_data,
2613     DeviceMemory<float>* input_c_backprop_data,
2614     DeviceMemory<float>* params_backprop_data,
2615     DeviceMemory<uint8>* reserve_space_data,
2616     ScratchAllocator* workspace_allocator,
2617     dnn::ProfileResult* output_profile_result) {
2618   const CudnnRnnDescriptor& cudnn_rnn_desc =
2619       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2620   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2621       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2622   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2623       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2624   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2625       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2626   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2627       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2628   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2629       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2630   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2631       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2632   return IsStatusOk(
2633       DoRnnBackwardImpl<float>(
2634           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2635           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2636           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2637           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2638           output_c_data, output_backprop_data, output_h_backprop_data,
2639           output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2640           input_c_backprop_data, params_backprop_data, reserve_space_data,
2641           workspace_allocator, output_profile_result),
2642       /*report_error=*/!output_profile_result);
2643 }
2644 
2645 bool CudnnSupport::DoRnnBackward(
2646     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2647     const dnn::RnnSequenceTensorDescriptor& input_desc,
2648     const DeviceMemory<double>& input_data,
2649     const DeviceMemory<int>& seq_lengths_data,
2650     const dnn::RnnStateTensorDescriptor& input_h_desc,
2651     const DeviceMemory<double>& input_h_data,
2652     const dnn::RnnStateTensorDescriptor& input_c_desc,
2653     const DeviceMemory<double>& input_c_data,
2654     const DeviceMemory<double>& params,
2655     const dnn::RnnSequenceTensorDescriptor& output_desc,
2656     const DeviceMemory<double>& output_data,
2657     const dnn::RnnStateTensorDescriptor& output_h_desc,
2658     const DeviceMemory<double>& output_h_data,
2659     const dnn::RnnStateTensorDescriptor& output_c_desc,
2660     const DeviceMemory<double>& output_c_data,
2661     const DeviceMemory<double>& output_backprop_data,
2662     const DeviceMemory<double>& output_h_backprop_data,
2663     const DeviceMemory<double>& output_c_backprop_data,
2664     DeviceMemory<double>* input_backprop_data,
2665     DeviceMemory<double>* input_h_backprop_data,
2666     DeviceMemory<double>* input_c_backprop_data,
2667     DeviceMemory<double>* params_backprop_data,
2668     DeviceMemory<uint8>* reserve_space_data,
2669     ScratchAllocator* workspace_allocator,
2670     dnn::ProfileResult* output_profile_result) {
2671   const CudnnRnnDescriptor& cudnn_rnn_desc =
2672       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2673   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2674       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2675   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2676       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2677   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2678       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2679   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2680       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2681   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2682       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2683   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2684       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2685   return IsStatusOk(
2686       DoRnnBackwardImpl<double>(
2687           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2688           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2689           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2690           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2691           output_c_data, output_backprop_data, output_h_backprop_data,
2692           output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2693           input_c_backprop_data, params_backprop_data, reserve_space_data,
2694           workspace_allocator, output_profile_result),
2695       /*report_error=*/!output_profile_result);
2696 }
2697 
2698 namespace {
2699 
2700 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2701 // and backward filter.
2702 
2703 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2704     const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2705     const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2706     const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2707     size_t memory_limit_bytes) {
2708 #if CUDNN_VERSION >= 8000
2709   const int num_requested_algos = 5;
2710   int num_returned_algos = 0;
2711   cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos];
2712 
2713   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
2714       cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2715       output_nd.handle(), num_requested_algos, &num_returned_algos,
2716       perf_results));
2717 
2718   size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2719   for (int r = 0; r < num_returned_algos; r++) {
2720     if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2721         perf_results[r].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2722         perf_results[r].memory <= mem_limit) {
2723       return perf_results[r].algo;
2724     }
2725   }
2726   return port::Status(port::error::INTERNAL,
2727                       "cudnnGetConvolutionForwardAlgorithm_v7 returned "
2728                       "no suitable algorithms. This could be a cudnn bug.");
2729 #else
2730   cudnnConvolutionFwdPreference_t preference =
2731       specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2732                               : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2733   cudnnConvolutionFwdAlgo_t algo_to_use;
2734   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2735       cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2736       output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2737   return algo_to_use;
2738 #endif
2739 }
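// With cuDNN 8.0+, the three helpers in this group (forward, backward-data,
// backward-filter) scan the heuristics list returned by the *_v7 query and
// pick the first candidate that succeeded, is not the WINOGRAD_NONFUSED
// algorithm, and fits within the workspace limit; older cuDNN versions fall
// back to the preference-based Get*Algorithm query.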
2740 
2741 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2742 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2743                                     const CudnnTensorDescriptor& input_nd,
2744                                     const CudnnFilterDescriptor& filter,
2745                                     const CudnnConvolutionDescriptor& conv,
2746                                     const CudnnTensorDescriptor& output_nd,
2747                                     bool specify_workspace_limit,
2748                                     size_t memory_limit_bytes) {
2749 #if CUDNN_VERSION >= 8000
2750   const int num_requested_algos = 5;
2751   int num_returned_algos = 0;
2752   cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos];
2753 
2754   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(
2755       cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2756       input_nd.handle(), num_requested_algos, &num_returned_algos,
2757       perf_results));
2758 
2759   size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2760   for (int r = 0; r < num_returned_algos; r++) {
2761     if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2762         perf_results[r].algo !=
2763             CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED &&
2764         perf_results[r].memory <= mem_limit) {
2765       return perf_results[r].algo;
2766     }
2767   }
2768   return port::Status(port::error::INTERNAL,
2769                       "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned "
2770                       "no suitable algorithms. This could be a cudnn bug.");
2771 #else
2772   cudnnConvolutionBwdDataPreference_t preference =
2773       specify_workspace_limit
2774           ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2775           : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2776   cudnnConvolutionBwdDataAlgo_t algo_to_use;
2777   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2778       cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2779       input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2780   return algo_to_use;
2781 #endif
2782 }
2783 
2784 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2785 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2786                                       const CudnnTensorDescriptor& input_nd,
2787                                       const CudnnFilterDescriptor& filter,
2788                                       const CudnnConvolutionDescriptor& conv,
2789                                       const CudnnTensorDescriptor& output_nd,
2790                                       bool specify_workspace_limit,
2791                                       size_t memory_limit_bytes) {
2792 #if CUDNN_VERSION >= 8000
2793   const int num_requested_algos = 5;
2794   int num_returned_algos = 0;
2795   cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos];
2796 
2797   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
2798       cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2799       filter.handle(), num_requested_algos, &num_returned_algos, perf_results));
2800 
2801   size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2802   for (int r = 0; r < num_returned_algos; r++) {
2803     if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2804         perf_results[r].algo !=
2805             CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED &&
2806         perf_results[r].memory <= mem_limit) {
2807       return perf_results[r].algo;
2808     }
2809   }
2810   return port::Status(port::error::INTERNAL,
2811                       "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned "
2812                       "no suitable algorithms. This could be a cudnn bug.");
2813 #else
2814   cudnnConvolutionBwdFilterPreference_t preference =
2815       specify_workspace_limit
2816           ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2817           : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2818   cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2819   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2820       cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2821       filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2822   return algo_to_use;
2823 #endif
2824 }
2825 
2826 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2827     Stream* stream, const CudnnHandle& cudnn,
2828     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2829     const CudnnConvolutionDescriptor& conv,
2830     const CudnnTensorDescriptor& output_nd,
2831     const dnn::AlgorithmDesc& algorithm_desc,
2832     ScratchAllocator* scratch_allocator) {
2833   if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2834     return port::Status(
2835         port::error::INTERNAL,
2836         "Mismatch between cudnn conv and algorithm descriptors.");
2837   }
2838 
2839   // Query the size of the workspace and allocate it.
2840   size_t size_in_bytes;
2841   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2842       cudnn.handle(),
2843       /*xDesc=*/input_nd.handle(),
2844       /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2845       /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
2846       /*sizeInBytes=*/&size_in_bytes));
2847 
2848   int64_t size_in_bytes_int64 = size_in_bytes;
2849 
2850   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2851     return port::Status(
2852         port::error::INTERNAL,
2853         "cudnnGetConvolutionForwardWorkspaceSize() returned "
2854         "negative sizeInBytes value. This could be a cudnn bug.");
2855   }
2856 
2857   if (size_in_bytes_int64 == 0) {
2858     return DeviceMemory<uint8>();
2859   }
2860 
2861   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2862     return port::Status(port::error::INVALID_ARGUMENT,
2863                         "No scratch allocator provided");
2864   }
2865 
2866   return scratch_allocator->AllocateBytes(size_in_bytes);
2867 }
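// As with the two workspace helpers below: a zero-byte workspace requirement
// yields an empty DeviceMemory<uint8> without consulting the allocator, and a
// missing scratch allocator is only an error when the chosen algorithm
// actually needs scratch space.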
2868 
2869 port::StatusOr<DeviceMemory<uint8>>
2870 AllocateCudnnConvolutionBackwardDataWorkspace(
2871     Stream* stream, const CudnnHandle& cudnn,
2872     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2873     const CudnnConvolutionDescriptor& conv,
2874     const CudnnTensorDescriptor& output_nd,
2875     const dnn::AlgorithmDesc& algorithm_desc,
2876     ScratchAllocator* scratch_allocator) {
2877   if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2878     return port::Status(
2879         port::error::INTERNAL,
2880         "Mismatch between cudnn conv and algorithm descriptors.");
2881   }
2882 
2883   // Query the size of the workspace and allocate it.
2884   size_t size_in_bytes;
2885   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2886       cudnn.handle(),
2887       /*wDesc=*/filter.handle(),
2888       /*dyDesc=*/output_nd.handle(),
2889       /*convDesc=*/conv.handle(),
2890       /*dxDesc=*/input_nd.handle(),
2891       /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2892       /*sizeInBytes=*/&size_in_bytes));
2893 
2894   int64_t size_in_bytes_int64 = size_in_bytes;
2895 
2896   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2897     return port::Status(
2898         port::error::INTERNAL,
2899         "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2900         "negative sizeInBytes value. This could be a cudnn bug.");
2901   }
2902 
2903   if (size_in_bytes_int64 == 0) {
2904     return DeviceMemory<uint8>();
2905   }
2906 
2907   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2908     return port::Status(port::error::INVALID_ARGUMENT,
2909                         "No scratch allocator provided");
2910   }
2911 
2912   return scratch_allocator->AllocateBytes(size_in_bytes);
2913 }
2914 
2915 port::StatusOr<DeviceMemory<uint8>>
2916 AllocateCudnnConvolutionBackwardFilterWorkspace(
2917     Stream* stream, const CudnnHandle& cudnn,
2918     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2919     const CudnnConvolutionDescriptor& conv,
2920     const CudnnTensorDescriptor& output_nd,
2921     const dnn::AlgorithmDesc& algorithm_desc,
2922     ScratchAllocator* scratch_allocator) {
2923   if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2924     return port::Status(
2925         port::error::INTERNAL,
2926         "Mismatch between cudnn conv and algorithm descriptors.");
2927   }
2928 
2929   // Query the size of the workspace and allocate it.
2930   size_t size_in_bytes;
2931   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
2932       cudnn.handle(),
2933       /*xDesc=*/input_nd.handle(),
2934       /*dyDesc=*/output_nd.handle(),
2935       /*convDesc=*/conv.handle(),
2936       /*gradDesc=*/filter.handle(),
2937       /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2938       /*sizeInBytes=*/&size_in_bytes));
2939 
2940   int64_t size_in_bytes_int64 = size_in_bytes;
2941 
2942   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2943     return port::Status(
2944         port::error::INTERNAL,
2945         "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
2946         "negative sizeInBytes value. This could be a cudnn bug.");
2947   }
2948 
2949   if (size_in_bytes_int64 == 0) {
2950     return DeviceMemory<uint8>();
2951   }
2952 
2953   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2954     return port::Status(port::error::INVALID_ARGUMENT,
2955                         "No scratch allocator provided");
2956   }
2957 
2958   return scratch_allocator->AllocateBytes(size_in_bytes);
2959 }
2960 
2961 port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
2962                                   absl::optional<dnn::AlgorithmDesc> desc) {
2963   bool use_tensor_ops;
2964   if (desc.has_value()) {
2965     use_tensor_ops = desc->tensor_ops_enabled();
2966     if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) {
2967       return port::Status(port::error::INVALID_ARGUMENT,
2968                           "Algo requests disabled tensor op evaluation.");
2969     }
2970   } else {
2971     use_tensor_ops = IsTensorMathEnabled(stream, type);
2972   }
2973   return use_tensor_ops;
2974 }
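// If the caller pins an algorithm that requests tensor ops while tensor math
// is disabled for this stream/data type, the request is rejected with
// INVALID_ARGUMENT rather than silently downgraded; when no explicit algorithm
// is given, the stream's tensor-math setting decides.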
2975 
2976 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
2977 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type);
2978 
2979 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
2980     Stream* stream, const CudnnHandle& cudnn,
2981     const dnn::AlgorithmConfig& algorithm_config,
2982     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2983     dnn::DataType element_type,
2984     const dnn::ConvolutionDescriptor& convolution_descriptor,
2985     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2986     DeviceMemory<uint8>* scratch) {
2987   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2988 
2989   CudnnConvolutionDescriptor conv(
2990       convolution_descriptor,
2991       ToCudnnDataType(GetConvAccumulatorType(element_type)));
2992   bool use_tensor_ops;
2993   SE_ASSIGN_OR_RETURN(use_tensor_ops,
2994                       UseTensorOps(stream, element_type, algo_desc));
2995   conv.set_use_tensor_op_math(use_tensor_ops);
2996 
2997   if (!algo_desc.has_value()) {
2998     // Pick fastest algorithm within memory limit according to cuDNN's
2999     // heuristics.
3000     bool specify_workspace_limit = scratch_allocator != nullptr;
3001     auto memory_limit_bytes =
3002         specify_workspace_limit
3003             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
3004             : int64{0};
3005     SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
3006                         GetCudnnConvolutionForwardAlgo(
3007                             cudnn, input_nd, filter, conv, output_nd,
3008                             specify_workspace_limit, memory_limit_bytes));
3009     algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3010   }
3011 
3012   const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
3013       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3014       scratch_allocator);
3015 
3016   if (scratch_or.ok()) {
3017     *scratch = scratch_or.ValueOrDie();
3018     return *algo_desc;
3019   }
3020 
3021   algo_desc = algorithm_config.algorithm_no_scratch();
3022 
3023   // Failed to allocate workspace for the first algorithm, fall back to the
3024   // no_scratch algorithm.
3025   if (!algo_desc.has_value()) {
3026     return port::Status(
3027         scratch_or.status().code(),
3028         absl::StrCat("The primary convolution algorithm failed, ",
3029                      "while a secondary algorithm is not provided. ",
3030                      "Returned status: ", scratch_or.status().ToString()));
3031   }
3032 
3033   SE_ASSIGN_OR_RETURN(use_tensor_ops,
3034                       UseTensorOps(stream, element_type, algo_desc));
3035   conv.set_use_tensor_op_math(use_tensor_ops);
3036   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
3037                                     stream, cudnn, input_nd, filter, conv,
3038                                     output_nd, *algo_desc, scratch_allocator));
3039   return *algo_desc;
3040 }
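// Rough control flow shared by the algorithm-selection helpers in this group
// (forward here, backward-data and backward-filter below):
//   1. Use the configured algorithm, or ask cuDNN's heuristics for one that
//      fits the scratch allocator's memory limit.
//   2. Try to allocate its workspace.
//   3. On allocation failure, retry with algorithm_no_scratch(), re-deriving
//      the tensor-op math setting for that algorithm.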
3041 
3042 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
3043     Stream* stream, const CudnnHandle& cudnn,
3044     const dnn::AlgorithmConfig& algorithm_config,
3045     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3046     dnn::DataType element_type,
3047     const dnn::ConvolutionDescriptor& convolution_descriptor,
3048     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3049     DeviceMemory<uint8>* scratch) {
3050   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3051   CudnnConvolutionDescriptor conv(
3052       convolution_descriptor,
3053       ToCudnnDataType(GetConvAccumulatorType(element_type)));
3054   bool use_tensor_ops;
3055   SE_ASSIGN_OR_RETURN(use_tensor_ops,
3056                       UseTensorOps(stream, element_type, algo_desc));
3057   conv.set_use_tensor_op_math(use_tensor_ops);
3058 
3059   if (!algo_desc.has_value()) {
3060     // Pick fastest algorithm within memory limit according to cuDNN's
3061     // heuristics.
3062     bool specify_workspace_limit = scratch_allocator != nullptr;
3063     auto memory_limit_bytes =
3064         specify_workspace_limit
3065             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
3066             : int64{0};
3067     SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
3068                         GetCudnnConvolutionBackwardDataAlgo(
3069                             cudnn, input_nd, filter, conv, output_nd,
3070                             specify_workspace_limit, memory_limit_bytes));
3071     algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3072   }
3073 
3074   const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
3075       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3076       scratch_allocator);
3077 
3078   if (scratch_or.ok()) {
3079     *scratch = scratch_or.ValueOrDie();
3080     return *algo_desc;
3081   }
3082 
3083   algo_desc = algorithm_config.algorithm_no_scratch();
3084 
3085   // Failed to allocate workspace for the first algorithm, fall back to the
3086   // no_scratch algorithm.
3087   if (!algo_desc.has_value()) {
3088     return port::Status(
3089         port::error::INVALID_ARGUMENT,
3090         "The primary convolution algorithm failed memory allocation, "
3091         "while a secondary algorithm is not provided.");
3092   }
3093 
3094   SE_ASSIGN_OR_RETURN(use_tensor_ops,
3095                       UseTensorOps(stream, element_type, algo_desc));
3096   conv.set_use_tensor_op_math(use_tensor_ops);
3097   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
3098                                     stream, cudnn, input_nd, filter, conv,
3099                                     output_nd, *algo_desc, scratch_allocator));
3100   return *algo_desc;
3101 }
3102 
3103 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
3104     Stream* stream, const CudnnHandle& cudnn,
3105     const dnn::AlgorithmConfig& algorithm_config,
3106     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3107     dnn::DataType element_type,
3108     const dnn::ConvolutionDescriptor& convolution_descriptor,
3109     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3110     DeviceMemory<uint8>* scratch) {
3111   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3112   CudnnConvolutionDescriptor conv(
3113       convolution_descriptor,
3114       ToCudnnDataType(GetConvAccumulatorType(element_type)));
3115   bool use_tensor_ops;
3116   SE_ASSIGN_OR_RETURN(use_tensor_ops,
3117                       UseTensorOps(stream, element_type, algo_desc));
3118   conv.set_use_tensor_op_math(use_tensor_ops);
3119 
3120   if (!algo_desc.has_value()) {
3121     // Pick fastest algorithm within memory limit according to cuDNN's
3122     // heuristics.
3123     bool specify_workspace_limit = scratch_allocator != nullptr;
3124     auto memory_limit_bytes =
3125         specify_workspace_limit
3126             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
3127             : int64{0};
3128     SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
3129                         GetCudnnConvolutionBackwardFilterAlgo(
3130                             cudnn, input_nd, filter, conv, output_nd,
3131                             specify_workspace_limit, memory_limit_bytes));
3132     algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3133   }
3134 
3135   port::StatusOr<DeviceMemory<uint8>> scratch_or =
3136       AllocateCudnnConvolutionBackwardFilterWorkspace(
3137           stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3138           scratch_allocator);
3139 
3140   if (scratch_or.ok()) {
3141     *scratch = scratch_or.ValueOrDie();
3142     return *algo_desc;
3143   }
3144 
3145   algo_desc = algorithm_config.algorithm_no_scratch();
3146 
3147   // Failed to allocate workspace for the first algorithm, fall back to the
3148   // no_scratch algorithm.
3149   if (!algo_desc.has_value()) {
3150     return port::Status(
3151         port::error::INVALID_ARGUMENT,
3152         absl::StrCat(
3153             "The primary convolution algorithm failed memory allocation, "
3154             "while a secondary algorithm is not provided. Actual error: ",
3155             scratch_or.status().ToString()));
3156   }
3157 
3158   SE_ASSIGN_OR_RETURN(use_tensor_ops,
3159                       UseTensorOps(stream, element_type, algo_desc));
3160   conv.set_use_tensor_op_math(use_tensor_ops);
3161   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
3162                                     stream, cudnn, input_nd, filter, conv,
3163                                     output_nd, *algo_desc, scratch_allocator));
3164   return *algo_desc;
3165 }
3166 
3167 // A helper class to set env-vars and choose options for cudnn-related
3168 // algorithms.
3169 template <typename EnvVar>
3170 class CudnnEnvVar {
3171  public:
3172   static bool IsEnabled() {
3173     static bool is_enabled = IsEnabledImpl();
3174     return is_enabled;
3175   }
3176 
3177  private:
3178   static bool IsEnabledImpl() {
3179     const char* tf_env_var_val = getenv(EnvVar::kName);
3180     if (tf_env_var_val != nullptr) {
3181       absl::string_view tf_env_var_val_str(tf_env_var_val);
3182       if (tf_env_var_val_str == "0") {
3183         return false;
3184       }
3185       return true;
3186     }
3187     return EnvVar::kDefaultFlag;
3188   }
3189 };
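// Each policy struct below supplies a kName env-var and a kDefaultFlag, and is
// queried as, e.g.,
//   bool winograd_ok = CudnnEnvVar<WinogradNonfused>::IsEnabled();
// The env-var lookup happens once; the result is cached for the process
// lifetime.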
3190 
3191 // A helper struct to decide whether to enable the FFT_TILING algorithms for
3192 // forward convolution. It is disabled for cuDNN < 7 due to memory corruption
3193 // caused by some shapes with this algorithm. Users can explicitly enable the
3194 // algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
3195 struct FftTilingForward {
3196   static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
3197   static constexpr bool kDefaultFlag = true;
3198 };
3199 
3200 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
3201 // They are enabled by default; users can explicitly disable them through the
3202 // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
3203 // https://github.com/tensorflow/tensorflow/pull/4901
3204 // For CUDNN v8.1, when this env-var is turned off, both the winograd and
3205 // winograd-non-fused engines will be ruled out.
3206 struct WinogradNonfused {
3207   static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
3208   // NVIDIA fixed the Winograd non-fused bug in cuDNN v>=7. For older versions,
3209   // we have a workaround.
3210   static constexpr bool kDefaultFlag = true;
3211 };
3212 
3213 // A helper struct to decide whether to use FP32 as the internal compute type
3214 // for convolution when the input data type is FP16. It is enabled by default;
3215 // users can explicitly disable it (i.e. use FP16 as the internal compute
3216 // type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
3217 struct ConvDoFP32ComputationFP16Input {
3218   static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
3219   // Using FP16 as the internal compute type for convolution when the input data
3220   // type is FP16 is only supported on architectures with true fp16 support
3221   // (compute capability 5.3 and 6.0). Setting this to false on an unsupported
3222   // architecture will cause internal errors.
3223   static constexpr bool kDefaultFlag = true;
3224 };
3225 
3226 // A helper struct to decide whether to use FP32 as the internal compute type
3227 // for rnn when the input data type is FP16. At present it is turned off,
3228 // users can explicitly control them through an env-var
3229 // TF_FP16_RNN_USE_FP32_COMPUTE.
3230 // After the TODO below is fixed, users should almost always use fp32 compute
3231 // type for training. Using fp16 might suffer suboptimal accuracy due to loss
3232 // in precision.
3233 struct RnnDoFP32ComputationFP16Input {
3234   static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
3235   // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
3236   // Before cudnn 7.1.4, RNNs are always done in fp32, no matter what math
3237   // precision is set.
3238   // Set it temporarily to false so that no error is raised when using fp16
3239   // inputs with fp32 math precision.
3240   //
3241   // cuDNN == 7.5.0 is verified to have this fixed.
3242   static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500;
3243 };
3244 
3245 namespace {
3246 
3247 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3248 
3249 bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config,
3250                          bool disable_winograd, bool disable_nondeterminism,
3251                          bool disable_tensor_core) {
3252   bool ret = cudnn_frontend::hasNumericalNote<
3253       CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(engine_config);
3254 
3255   if (disable_winograd) {
3256     ret |= cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_WINOGRAD>(
3257         engine_config);
3258   }
3259 
3260   if (disable_nondeterminism) {
3261     ret |=
3262         cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(
3263             engine_config);
3264   }
3265 
3266   if (disable_tensor_core) {
3267     ret |= cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(
3268         engine_config);
3269   }
3270 
3271   return ret;
3272 }
3273 
3274 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3275 
3276 }  // namespace
3277 
3278 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
3279   switch (data_type) {
3280     case dnn::DataType::kFloat:
3281       return CUDNN_DATA_FLOAT;
3282     case dnn::DataType::kDouble:
3283       return CUDNN_DATA_DOUBLE;
3284     case dnn::DataType::kHalf:
3285       if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
3286         return CUDNN_DATA_FLOAT;
3287       } else {
3288         return CUDNN_DATA_HALF;
3289       }
3290     default:
3291       LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
3292   }
3293 }
3294 
3295 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
3296   switch (data_type) {
3297     case dnn::DataType::kFloat:
3298     case dnn::DataType::kDouble:
3299       return data_type;
3300     case dnn::DataType::kHalf:
3301       return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3302                  ? dnn::DataType::kFloat
3303                  : dnn::DataType::kHalf;
3304     case dnn::DataType::kInt8:
3305     case dnn::DataType::kInt32:
3306       return dnn::DataType::kInt32;
3307 #if CUDNN_VERSION >= 8200
3308     case dnn::DataType::kBF16:
3309       return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3310                  ? dnn::DataType::kFloat
3311                  : dnn::DataType::kBF16;
3312 #endif
3313     default:
3314       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
3315   }
3316 }
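// Summary of the mapping above (added comment): fp16 (and, for cuDNN >= 8.2,
// bf16) convolutions accumulate in fp32 unless TF_FP16_CONV_USE_FP32_COMPUTE
// is set to "0"; fp32/fp64 accumulate in their own precision; int8 and int32
// always accumulate in int32.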
3317 
3318 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3319 cudnnBackendDescriptorType_t GetCudnnConvolutionType(
3320     dnn::ConvolutionKind kind) {
3321   cudnnBackendDescriptorType_t conv_mode;
3322   switch (kind) {
3323     case dnn::ConvolutionKind::FORWARD: {
3324       conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
3325       break;
3326     }
3327     case dnn::ConvolutionKind::BACKWARD_DATA: {
3328       conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR;
3329       break;
3330     }
3331     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3332       conv_mode =
3333           CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR;
3334       break;
3335     }
3336     default:
3337       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3338       break;
3339   }
3340   return conv_mode;
3341 }
3342 
3343 // Cudnn only supports vectorization over the channel dimension (e.g., int8x4,
3344 // or int8x32).
3345 std::tuple<int, int> GetTensorVectorSizeAndDim(
3346     const dnn::BatchDescriptor& tensor, dnn::DataType element_type) {
3347   int vector_size = 1;
3348   int vector_dim = -1;
3349   if (element_type == dnn::DataType::kInt8) {
3350     if (tensor.layout() == dnn::DataLayout::kBatchDepthYX4) {
3351       vector_size = 4;
3352       vector_dim = 1;
3353     } else if (tensor.layout() == dnn::DataLayout::kBatchDepthYX32) {
3354       vector_size = 32;
3355       vector_dim = 1;
3356     }
3357   }
3358   return std::make_tuple(vector_size, vector_dim);
3359 }
3360 
3361 std::tuple<int, int> GetTensorVectorSizeAndDim(
3362     const dnn::FilterDescriptor& filter, dnn::DataType element_type) {
3363   int vector_size = 1;
3364   int vector_dim = -1;
3365   if (element_type == dnn::DataType::kInt8) {
3366     if (filter.layout() == dnn::FilterLayout::kOutputInputYX4) {
3367       vector_size = 4;
3368       vector_dim = 1;
3369     } else if (filter.layout() == dnn::FilterLayout::kOutputInputYX32) {
3370       vector_size = 32;
3371       vector_dim = 1;
3372     }
3373   }
3374   return std::make_tuple(vector_size, vector_dim);
3375 }
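// Example of the convention above (added comment): an int8 tensor or filter in
// a *4 layout (e.g. kBatchDepthYX4, i.e. NCHW_VECT_C with 4-wide channel
// vectors) yields {vector_size = 4, vector_dim = 1}, the channel dimension;
// non-vectorized layouts yield {1, -1}.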
3376 
3377 port::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
3378 GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType element_type,
3379                        Stream* stream,
3380                        const dnn::BatchDescriptor& input_descriptor,
3381                        const dnn::FilterDescriptor& filter_descriptor,
3382                        const dnn::BatchDescriptor& output_descriptor,
3383                        const dnn::ConvolutionDescriptor& convolution_descriptor,
3384                        CudnnHandle& cudnn) {
3385   cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3386   cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
3387 
3388   // x tensor.
3389   int vector_size, vector_dim;
3390   std::tie(vector_size, vector_dim) =
3391       GetTensorVectorSizeAndDim(input_descriptor, element_type);
3392   std::vector<int64> input_dims = input_descriptor.vectorized_dims(
3393       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3394   std::vector<int64> input_strides = input_descriptor.vectorized_strides(
3395       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3396 
3397   if (vector_size == 32) {
3398     return port::InternalError(
3399         "cuDNN frontend doesn't support int8x32 at the "
3400         "moment.");
3401   }
3402 
3403   auto tensor_x = cudnn_frontend::TensorBuilder()
3404                       .setDim(input_dims.size(), input_dims.data())
3405                       .setStrides(input_dims.size(), input_strides.data())
3406                       .setId('x')
3407                       .setAlignment(32)
3408                       .setDataType(cudnn_type)
3409                       .setVectorCountAndDimension(vector_size, vector_dim)
3410                       .build();
3411   RETURN_MSG_IF_CUDNN_ERROR(tensor_x);
3412 
3413   // y tensor.
3414   std::tie(vector_size, vector_dim) =
3415       GetTensorVectorSizeAndDim(output_descriptor, element_type);
3416   std::vector<int64> output_dims = output_descriptor.vectorized_dims(
3417       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3418   std::vector<int64> output_strides = output_descriptor.vectorized_strides(
3419       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3420 
3421   auto tensor_y = cudnn_frontend::TensorBuilder()
3422                       .setDim(output_dims.size(), output_dims.data())
3423                       .setStrides(output_dims.size(), output_strides.data())
3424                       .setId('y')
3425                       .setAlignment(32)
3426                       .setDataType(cudnn_type)
3427                       .setVectorCountAndDimension(vector_size, vector_dim)
3428                       .build();
3429   RETURN_MSG_IF_CUDNN_ERROR(tensor_y);
3430 
3431   // w tensor.
3432   std::tie(vector_size, vector_dim) =
3433       GetTensorVectorSizeAndDim(filter_descriptor, element_type);
3434   std::vector<int64> filter_dims = filter_descriptor.vectorized_dims(
3435       dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3436   std::vector<int64> filter_strides = filter_descriptor.vectorized_strides(
3437       dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3438 
3439   auto tensor_w = cudnn_frontend::TensorBuilder()
3440                       .setDim(filter_dims.size(), filter_dims.data())
3441                       .setStrides(filter_dims.size(), filter_strides.data())
3442                       .setId('w')
3443                       .setAlignment(32)
3444                       .setDataType(cudnn_type)
3445                       .setVectorCountAndDimension(vector_size, vector_dim)
3446                       .build();
3447   RETURN_MSG_IF_CUDNN_ERROR(tensor_w);
3448 
3449   // conv_desc.
3450   auto mode = convolution_descriptor.convolution_not_crosscorr()
3451                   ? CUDNN_CONVOLUTION
3452                   : CUDNN_CROSS_CORRELATION;
3453 
3454   int conv_dim = convolution_descriptor.ndims();
3455 
3456   auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(element_type));
3457   CHECK_NE(convolution_descriptor.pad_alignment(),
3458            dnn::PadAlignment::kTensorFlowPadding)
3459       << "TensorFlow padding alignment is not supported.";
3460 
3461   auto conv_desc =
3462       cudnn_frontend::ConvDescBuilder()
3463           .setComputePrecision(accumulator_type)
3464           .setMathMode(mode)
3465           .setNDims(conv_dim)
3466           .setStrides(conv_dim, convolution_descriptor.strides().data())
3467           .setPrePadding(conv_dim, convolution_descriptor.padding().data())
3468           .setPostPadding(conv_dim, convolution_descriptor.padding().data())
3469           .setDilation(conv_dim, convolution_descriptor.dilations().data())
3470           .build();
3471   RETURN_MSG_IF_CUDNN_ERROR(conv_desc);
3472 
3473   double alpha = 1.0;
3474   double beta = 0.0;
3475 
3476   // CUDNN Operation
3477   auto op = cudnn_frontend::OperationBuilder(conv_mode)
3478                 .setxDesc(tensor_x)
3479                 .setyDesc(tensor_y)
3480                 .setwDesc(tensor_w)
3481                 .setcDesc(conv_desc)
3482                 .setAlpha(alpha)
3483                 .setBeta(beta)
3484                 .build();
3485   RETURN_MSG_IF_CUDNN_ERROR(op);
3486 
3487   // CUDNN OperationGraph
3488   std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
3489   auto opGraph = cudnn_frontend::OperationGraphBuilder()
3490                      .setHandle(cudnn.handle())
3491                      .setOperationGraph(ops.size(), ops.data())
3492                      .build();
3493   RETURN_MSG_IF_CUDNN_ERROR(opGraph);
3494 
3495   VLOG(4) << "\nTensor_x: " << tensor_x.describe()
3496           << "\nTensor_y: " << tensor_y.describe()
3497           << "\nTensor_w: " << tensor_w.describe()
3498           << "\nConv: " << conv_desc.describe() << "\nOp: " << op.describe()
3499           << "\nOpGraph: " << opGraph.describe();
3500 
3501   return std::unique_ptr<cudnn_frontend::OperationGraph>(
3502       new cudnn_frontend::OperationGraph(std::move(opGraph)));
3503 }
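// Note (added comment): the tensor UIDs chosen above ('x', 'y', 'w') must match
// the UIDs used when device pointers are bound into the variant pack at
// execution time (see the uids array in DoConvolveWithExecutionPlan below).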
3504 
3505 port::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
3506 GetCudnnFusedOperationGraph(
3507     dnn::ConvolutionKind kind, dnn::DataType element_type, double alpha,
3508     double alpha2, Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3509     const dnn::FilterDescriptor& filter_descriptor,
3510     dnn::BatchDescriptor bias_descriptor,
3511     const dnn::BatchDescriptor& output_descriptor,
3512     const dnn::ConvolutionDescriptor& convolution_descriptor,
3513     const dnn::ActivationMode activation_mode, CudnnHandle& cudnn) {
3514   cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3515   cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
3516 
3517   // CUDNN fused operation supports the pattern in the form of
3518   // Conv + Add + BiasAdd + Act. Therefore, we need to build a graph of the
3519   // four ops with their input/output tensor edges:
3520   // Conv   : input: tensor_x, tensor_w;    output: tensor_conv (virtual)
3521   // Add    : input: tensor_conv, tensor_z; output: tensor_add (virtual)
3522   // BiasAdd: input: tensor_add, tensor_b;  output: tensor_bias (virtual)
3523   // Act    : input: tensor_bias;           output: tensor_y
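  // Informally, the intended computation is (added sketch; alpha/alpha2 carry
  // the caller's conv_input_scale/side_input_scale):
  //   tensor_y = act(alpha * conv(tensor_x, tensor_w) + alpha2 * tensor_z
  //                  + tensor_b)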
3524   int vector_size, vector_dim;
3525   std::tie(vector_size, vector_dim) =
3526       GetTensorVectorSizeAndDim(input_descriptor, element_type);
3527   std::vector<int64> input_dims = input_descriptor.vectorized_dims(
3528       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3529   std::vector<int64> input_strides = input_descriptor.vectorized_strides(
3530       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3531 
3532   if (vector_size == 32) {
3533     return port::InternalError(
3534         "cuDNN frontend doesn't support int8x32 at the "
3535         "moment.");
3536   }
3537 
3538   auto tensor_x = cudnn_frontend::TensorBuilder()
3539                       .setDim(input_dims.size(), input_dims.data())
3540                       .setStrides(input_dims.size(), input_strides.data())
3541                       .setId('x')
3542                       .setAlignment(32)
3543                       .setDataType(cudnn_type)
3544                       .setVectorCountAndDimension(vector_size, vector_dim)
3545                       .build();
3546   RETURN_MSG_IF_CUDNN_ERROR(tensor_x);
3547 
3548   std::tie(vector_size, vector_dim) =
3549       GetTensorVectorSizeAndDim(output_descriptor, element_type);
3550   std::vector<int64> output_dims = output_descriptor.vectorized_dims(
3551       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3552   std::vector<int64> output_strides = output_descriptor.vectorized_strides(
3553       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3554   auto tensor_y = cudnn_frontend::TensorBuilder()
3555                       .setDim(output_dims.size(), output_dims.data())
3556                       .setStrides(output_dims.size(), output_strides.data())
3557                       .setId('y')
3558                       .setAlignment(32)
3559                       .setDataType(cudnn_type)
3560                       .setVectorCountAndDimension(vector_size, vector_dim)
3561                       .build();
3562   RETURN_MSG_IF_CUDNN_ERROR(tensor_y);
3563 
3564   auto tensor_z = cudnn_frontend::TensorBuilder()
3565                       .setDim(output_dims.size(), &output_dims[0])
3566                       .setStrides(output_dims.size(), &output_strides[0])
3567                       .setId('z')
3568                       .setAlignment(32)
3569                       .setDataType(cudnn_type)
3570                       .setVectorCountAndDimension(vector_size, vector_dim)
3571                       .build();
3572   RETURN_MSG_IF_CUDNN_ERROR(tensor_z);
3573 
3574   std::tie(vector_size, vector_dim) =
3575       GetTensorVectorSizeAndDim(filter_descriptor, element_type);
3576   std::vector<int64> filter_dims = filter_descriptor.vectorized_dims(
3577       dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3578   std::vector<int64> filter_strides = filter_descriptor.vectorized_strides(
3579       dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3580   auto tensor_w = cudnn_frontend::TensorBuilder()
3581                       .setDim(filter_dims.size(), filter_dims.data())
3582                       .setStrides(filter_dims.size(), filter_strides.data())
3583                       .setId('w')
3584                       .setAlignment(32)
3585                       .setDataType(cudnn_type)
3586                       .setVectorCountAndDimension(vector_size, vector_dim)
3587                       .build();
3588   RETURN_MSG_IF_CUDNN_ERROR(tensor_w);
3589 
3590   // For the purposes of the cudnn graph, say that the bias tensor has the same
3591   // layout as the output tensor.  It doesn't actually matter, because bias is a
3592   // 1D array.  But we need to get the correct vectorization, otherwise the
3593   // cudnn graph API rejects this tensor.
3594   bias_descriptor.set_layout(output_descriptor.layout());
3595 
3596   std::tie(vector_size, vector_dim) =
3597       GetTensorVectorSizeAndDim(bias_descriptor, element_type);
3598   std::vector<int64> bias_dims = bias_descriptor.vectorized_dims(
3599       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3600   std::vector<int64> bias_strides = bias_descriptor.vectorized_strides(
3601       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3602   auto tensor_b = cudnn_frontend::TensorBuilder()
3603                       .setDim(bias_dims.size(), bias_dims.data())
3604                       .setStrides(bias_dims.size(), bias_strides.data())
3605                       .setId('b')
3606                       .setAlignment(32)
3607                       .setDataType(cudnn_type)
3608                       .setVectorCountAndDimension(vector_size, vector_dim)
3609                       .build();
3610   RETURN_MSG_IF_CUDNN_ERROR(tensor_b);
3611 
3612   std::tie(vector_size, vector_dim) =
3613       GetTensorVectorSizeAndDim(output_descriptor, element_type);
3614   auto tensor_conv = cudnn_frontend::TensorBuilder()
3615                          .setDim(output_dims.size(), &output_dims[0])
3616                          .setStrides(output_dims.size(), &output_strides[0])
3617                          .setVirtual()
3618                          .setId('C')
3619                          .setAlignment(32)
3620                          .setDataType(cudnn_type)
3621                          .setVectorCountAndDimension(vector_size, vector_dim)
3622                          .build();
3623   RETURN_MSG_IF_CUDNN_ERROR(tensor_conv);
3624 
3625   auto tensor_add = cudnn_frontend::TensorBuilder()
3626                         .setDim(output_dims.size(), &output_dims[0])
3627                         .setStrides(output_dims.size(), &output_strides[0])
3628                         .setVirtual()
3629                         .setId('A')
3630                         .setAlignment(32)
3631                         .setDataType(cudnn_type)
3632                         .setVectorCountAndDimension(vector_size, vector_dim)
3633                         .build();
3634   RETURN_MSG_IF_CUDNN_ERROR(tensor_add);
3635 
3636   auto tensor_bias = cudnn_frontend::TensorBuilder()
3637                          .setDim(output_dims.size(), &output_dims[0])
3638                          .setStrides(output_dims.size(), &output_strides[0])
3639                          .setVirtual()
3640                          .setId('B')
3641                          .setAlignment(32)
3642                          .setDataType(cudnn_type)
3643                          .setVectorCountAndDimension(vector_size, vector_dim)
3644                          .build();
3645   RETURN_MSG_IF_CUDNN_ERROR(tensor_bias);
3646 
3647   // conv_desc.
3648   auto mode = convolution_descriptor.convolution_not_crosscorr()
3649                   ? CUDNN_CONVOLUTION
3650                   : CUDNN_CROSS_CORRELATION;
3651 
3652   int conv_dim = convolution_descriptor.ndims();
3653 
3654   auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(element_type));
3655   CHECK_NE(convolution_descriptor.pad_alignment(),
3656            dnn::PadAlignment::kTensorFlowPadding)
3657       << "TensorFlow padding alignment is not supported.";
3658 
3659   auto conv_desc =
3660       cudnn_frontend::ConvDescBuilder()
3661           .setComputePrecision(accumulator_type)
3662           .setMathMode(mode)
3663           .setNDims(conv_dim)
3664           .setStrides(conv_dim, convolution_descriptor.strides().data())
3665           .setPrePadding(conv_dim, convolution_descriptor.padding().data())
3666           .setPostPadding(conv_dim, convolution_descriptor.padding().data())
3667           .setDilation(conv_dim, convolution_descriptor.dilations().data())
3668           .build();
3669   RETURN_MSG_IF_CUDNN_ERROR(conv_desc);
3670 
3671   // Beta is the scaling factor for output.
3672   double beta = 0.0;
3673 
3674   // CUDNN Operation
3675   auto conv_op = cudnn_frontend::OperationBuilder(conv_mode)
3676                      .setxDesc(tensor_x)
3677                      .setyDesc(tensor_conv)
3678                      .setwDesc(tensor_w)
3679                      .setcDesc(conv_desc)
3680                      .setAlpha(alpha)
3681                      .setBeta(beta)
3682                      .build();
3683   RETURN_MSG_IF_CUDNN_ERROR(conv_op);
3684 
3685   auto add_desc = cudnn_frontend::PointWiseDescBuilder()
3686                       .setMode(CUDNN_POINTWISE_ADD)
3687                       .setMathPrecision(cudnn_type)
3688                       .build();
3689   auto add_op = cudnn_frontend::OperationBuilder(
3690                     CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3691                     .setxDesc(conv_op.getOutputTensor())
3692                     .setbDesc(tensor_z)
3693                     .setyDesc(tensor_add)
3694                     .setpwDesc(add_desc)
3695                     .setAlpha(alpha)
3696                     .setAlpha2(alpha2)
3697                     .build();
3698   RETURN_MSG_IF_CUDNN_ERROR(add_op);
3699 
3700   auto bias_add_desc = cudnn_frontend::PointWiseDescBuilder()
3701                            .setMode(CUDNN_POINTWISE_ADD)
3702                            .setMathPrecision(cudnn_type)
3703                            .build();
3704 
3705   // If the activation is the identity function, then the bias-add is the last
3706   // op, and it writes to the output, tensor_y.  Otherwise, it writes to the
3707   // "virtual tensor" (temp buffer) tensor_bias, to which we apply the
3708   // activation.
3709   auto& bias_out_desc =
3710       activation_mode == dnn::ActivationMode::kNone ? tensor_y : tensor_bias;
3711   auto bias_add_op = cudnn_frontend::OperationBuilder(
3712                          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3713                          .setxDesc(add_op.getOutputTensor())
3714                          .setbDesc(tensor_b)
3715                          .setyDesc(bias_out_desc)
3716                          .setpwDesc(bias_add_desc)
3717                          .build();
3718   RETURN_MSG_IF_CUDNN_ERROR(bias_add_op);
3719 
3720   // CUDNN OperationGraph
3721   absl::InlinedVector<cudnn_frontend::Operation const*, 4> ops = {
3722       &conv_op, &add_op, &bias_add_op};
3723 
3724   absl::optional<cudnn_frontend::PointWiseDesc_v8> act_desc;
3725   absl::optional<cudnn_frontend::Operation_v8> act_op;
3726   switch (activation_mode) {
3727     case dnn::ActivationMode::kNone:
3728       break;
3729     case dnn::ActivationMode::kRelu:
3730       act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3731                            .setMode(CUDNN_POINTWISE_RELU_FWD)
3732                            .setMathPrecision(cudnn_type)
3733                            .build());
3734       act_op.emplace(cudnn_frontend::OperationBuilder(
3735                          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3736                          .setxDesc(bias_add_op.getOutputTensor())
3737                          .setyDesc(tensor_y)
3738                          .setpwDesc(*act_desc)
3739                          .build());
3740       RETURN_MSG_IF_CUDNN_ERROR(*act_op);
3741       ops.push_back(&*act_op);
3742       break;
3743     default:
3744       return port::InternalError(
3745           absl::StrCat("Unimplemented activation mode ",
3746                        dnn::ActivationModeString(activation_mode)));
3747   }
3748 
3749   auto op_graph = cudnn_frontend::OperationGraphBuilder()
3750                       .setHandle(cudnn.handle())
3751                       .setOperationGraph(ops.size(), ops.data())
3752                       .build();
3753   RETURN_MSG_IF_CUDNN_ERROR(op_graph);
3754 
3755   VLOG(4) << "\nTensor_x: " << tensor_x.describe()
3756           << "\nTensor_y: " << tensor_y.describe()
3757           << "\nTensor_z: " << tensor_z.describe()
3758           << "\nTensor_w: " << tensor_w.describe()
3759           << "\nTensor_b: " << tensor_b.describe()
3760           << "\nTensor_conv: " << tensor_conv.describe()
3761           << "\nTensor_add: " << tensor_add.describe()
3762           << "\nTensor_bias: " << tensor_bias.describe()
3763           << "\nConv: " << conv_desc.describe()
3764           << "\nAdd: " << add_desc.describe()
3765           << "\nBiasAdd: " << bias_add_desc.describe()  //
3766           << "\nAct: "
3767           << (act_desc.has_value() ? act_desc->describe() : "(identity)")
3768           << "\nConvOp: " << conv_op.describe()
3769           << "\nAddOp: " << add_op.describe()
3770           << "\nBiasAddOp: " << bias_add_op.describe()  //
3771           << "\nActOp: "
3772           << (act_op.has_value() ? act_op->describe() : "(identity)")
3773           << "\nOpGraph: " << op_graph.describe();
3774 
3775   return std::unique_ptr<cudnn_frontend::OperationGraph>(
3776       new cudnn_frontend::OperationGraph(std::move(op_graph)));
3777 }
3778 
3779 port::StatusOr<std::unique_ptr<cudnn_frontend::ExecutionPlan>>
3780 GetFirstWorkingExecutionPlan(
3781     Stream* stream, dnn::DataType element_type,
3782     const std::unique_ptr<cudnn_frontend::OperationGraph>& op_graph,
3783     dnn::ConvolutionKind kind, CudnnHandle& cudnn,
3784     ScratchAllocator* scratch_allocator) {
3785   auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
3786                         .setOperationGraph(*op_graph)
3787                         .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
3788                         .build();
3789   RETURN_MSG_IF_CUDNN_ERROR(heuristics);
3790 
3791   cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3792   auto fallback = cudnn_frontend::EngineFallbackListBuilder()
3793                       .setOperationGraph(*op_graph)
3794                       .setOperation(conv_mode)
3795                       .build();
3796   RETURN_MSG_IF_CUDNN_ERROR(fallback);
3797 
3798   // cuDNN frontend sneakily puts error messages on the object and returns
3799   // partially-initialized results when there's an error; make sure to check
3800   // them.
3801   int64_t engine_count = heuristics.getEngineConfigCount();
3802   RETURN_MSG_IF_CUDNN_ERROR(heuristics);
3803   auto& engine_config = heuristics.getEngineConfig(engine_count);
3804   RETURN_MSG_IF_CUDNN_ERROR(heuristics);
3805 
3806   auto& fallback_list = fallback.getFallbackList();
3807 
3808   cudnn_frontend::EngineConfigList filtered_configs;
3809   auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
3810     return GenericEngineFilter(
3811         engine_config,
3812         /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
3813         /*disable_nondeterminism*/ RequireCudnnDeterminism(),
3814         /*disable_tensor_core*/ !IsTensorMathEnabled(stream, element_type));
3815   };
3816 
3817   cudnn_frontend::filter(engine_config, filtered_configs, generic_filter_fn);
3818   cudnn_frontend::filter(fallback_list, filtered_configs, generic_filter_fn);
3819 
3820   auto fn = []() { return true; };
3821   auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
3822   auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
3823 
3824   for (int i = 0; i < filtered_configs.size(); i++) {
3825     auto plan = cudnn_frontend::ExecutionPlanBuilder()
3826                     .setHandle(cudnn.handle())
3827                     .setEngineConfig(filtered_configs[i], op_graph->getTag())
3828                     .build();
3829     if (plan.get_status() == CUDNN_STATUS_SUCCESS) {
3830       if (maybe_json_handle_static &&
3831           cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
3832                                        cudnn.handle(), fn)) {
3833         VLOG(4) << "Exclude engine (static): " << plan.getTag();
3834         continue;
3835       }
3836       if (maybe_json_handle_runtime &&
3837           cudnn_frontend::check_errata(*maybe_json_handle_runtime,
3838                                        plan.getTag(), cudnn.handle(), fn)) {
3839         VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
3840         continue;
3841       }
3842 
3843       bool specify_workspace_limit = scratch_allocator != nullptr;
3844       auto memory_limit_bytes =
3845           specify_workspace_limit
3846               ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
3847               : int64{0};
3848       int64_t workspace_size = plan.getWorkspaceSize();
3849       if (workspace_size <= memory_limit_bytes) {
3850         return std::unique_ptr<cudnn_frontend::ExecutionPlan>(
3851             new cudnn_frontend::ExecutionPlan(std::move(plan)));
3852       }
3853     }
3854   }
3855   return port::Status(port::error::UNKNOWN,
3856                       "CUDNN failed to get a working"
3857                       " plan.");
3858 }
3859 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3860 
3861 }  // namespace
3862 
3863 port::Status CudnnSupport::DoPrepareForConvolution(
3864     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3865     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3866     const dnn::FilterDescriptor& filter_descriptor,
3867     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3868     DeviceMemoryBase output_data,
3869     const dnn::ConvolutionDescriptor& convolution_descriptor,
3870     const dnn::AlgorithmConfig& algorithm_config,
3871     ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
3872     DeviceMemory<uint8>* scratch_memory) {
3873   CudnnTensorDescriptor input_nd(
3874       input_descriptor,
3875       ToCudnnDataType(element_type, input_descriptor.layout()));
3876   CudnnFilterDescriptor filter_nd(
3877       filter_descriptor,
3878       ToCudnnDataType(element_type, filter_descriptor.layout()));
3879   CudnnTensorDescriptor output_nd(
3880       output_descriptor,
3881       ToCudnnDataType(element_type, output_descriptor.layout()));
3882 
3883   auto cudnn = cudnn_->GetHandle(parent_, stream);
3884 
3885   switch (kind) {
3886     case dnn::ConvolutionKind::FORWARD: {
3887       SE_ASSIGN_OR_RETURN(*algorithm_desc,
3888                           GetCudnnConvolutionForwardAlgorithm(
3889                               stream, cudnn, algorithm_config, input_nd,
3890                               filter_nd, element_type, convolution_descriptor,
3891                               output_nd, scratch_allocator, scratch_memory));
3892       break;
3893     }
3894     case dnn::ConvolutionKind::BACKWARD_DATA: {
3895       SE_ASSIGN_OR_RETURN(*algorithm_desc,
3896                           GetCudnnConvolutionBackwardDataAlgorithm(
3897                               stream, cudnn, algorithm_config, input_nd,
3898                               filter_nd, element_type, convolution_descriptor,
3899                               output_nd, scratch_allocator, scratch_memory));
3900       break;
3901     }
3902     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3903       SE_ASSIGN_OR_RETURN(*algorithm_desc,
3904                           GetCudnnConvolutionBackwardFilterAlgorithm(
3905                               stream, cudnn, algorithm_config, input_nd,
3906                               filter_nd, element_type, convolution_descriptor,
3907                               output_nd, scratch_allocator, scratch_memory));
3908       break;
3909     }
3910     default:
3911       return port::InternalError(
3912           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
3913   }
3914 
3915   return port::Status::OK();
3916 }
3917 
3918 port::Status CudnnSupport::DoConvolve(
3919     dnn::ConvolutionKind kind, dnn::DataType element_type,
3920     dnn::DataType output_type, Stream* stream,
3921     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3922     const dnn::FilterDescriptor& filter_descriptor,
3923     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3924     DeviceMemoryBase output_data,
3925     const dnn::ConvolutionDescriptor& convolution_descriptor,
3926     dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
3927     dnn::ProfileResult* output_profile_result) {
3928   cudnnDataType_t cudnn_type =
3929       ToCudnnDataType(element_type, input_descriptor.layout());
3930 
3931   CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
3932   CudnnTensorDescriptor output_nd(
3933       output_descriptor,
3934       ToCudnnDataType(output_type, output_descriptor.layout()));
3935   CudnnFilterDescriptor filter_nd(
3936       filter_descriptor,
3937       ToCudnnDataType(element_type, filter_descriptor.layout()));
3938 
3939   auto accumulator_type = GetConvAccumulatorType(element_type);
3940   CudnnConvolutionDescriptor conv(convolution_descriptor,
3941                                   ToCudnnDataType(accumulator_type));
3942   SE_ASSIGN_OR_RETURN(bool use_tensor_ops,
3943                       UseTensorOps(stream, element_type, algorithm_desc));
3944   conv.set_use_tensor_op_math(use_tensor_ops);
3945 
3946   auto cudnn = cudnn_->GetHandle(parent_, stream);
3947   // Alpha is the scaling factor for input.
3948   float falpha = 1.0;
3949   double dalpha = 1.0;
3950   void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
3951                                                 : static_cast<void*>(&falpha);
3952   // Beta is the scaling factor for output.
3953   float fbeta = 0.0;
3954   double dbeta = 0.0;
3955   void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
3956                                                : static_cast<void*>(&fbeta);
3957 
3958   const bool is_profiling = output_profile_result != nullptr;
3959 
3960   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3961   if (is_profiling) {
3962     timer.reset(new GpuTimer(parent_));  // NOLINT
3963     // The start and stop of the timer should be as close to the cuDNN call as
3964     // possible. Other threads may still issue work onto this stream, so it may
3965     // take multiple profiling measurements to obtain a reliable timing.
3966     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3967       return port::Status(port::error::INTERNAL, "Failed to start timer");
3968     }
3969   }
3970 
3971   const auto get_fwd_bugs = [&]() -> port::Status {
3972     if (CUDNN_VERSION < 8000) {
3973       if (algorithm_desc.algo_id() ==
3974               CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
3975           ToCudnnDataType(element_type) == CUDNN_DATA_INT8 &&
3976           ToCudnnDataType(output_type) == CUDNN_DATA_FLOAT) {
3977         return port::Status(
3978             port::error::FAILED_PRECONDITION,
3979             "This configuration potentially produces incorrect results.");
3980       }
3981     }
3982     return port::Status::OK();
3983   };
3984 
3985   auto get_bwd_data_bugs = [&]() -> port::Status { return port::Status::OK(); };
3986 
3987   const auto get_bwd_filter_bugs = [&]() -> port::Status {
3988     return port::Status::OK();
3989   };
3990 
3991   switch (kind) {
3992     case dnn::ConvolutionKind::FORWARD: {
3993       SE_RETURN_IF_ERROR(get_fwd_bugs());
3994       RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
3995           cudnn.handle(),
3996           /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
3997           /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_nd.handle(),
3998           /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
3999           /*algo=*/ToConvForwardAlgo(algorithm_desc),
4000           /*workSpace=*/scratch_memory.opaque(),
4001           /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
4002           /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque()));
4003       break;
4004     }
4005     case dnn::ConvolutionKind::BACKWARD_DATA: {
4006       SE_RETURN_IF_ERROR(get_bwd_data_bugs());
4007       RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
4008           cudnn.handle(),
4009           /*alpha=*/alpha,
4010           /*wDesc=*/filter_nd.handle(),
4011           /*w=*/filter_data.opaque(),
4012           /*dyDesc=*/output_nd.handle(),
4013           /*dy=*/output_data.opaque(),
4014           /*convDesc=*/conv.handle(),
4015           /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
4016           /*workSpace=*/scratch_memory.opaque(),
4017           /*workSpaceSizeInBytes=*/scratch_memory.size(),
4018           /*beta=*/beta,
4019           /*dxDesc=*/input_nd.handle(),
4020           /*dx=*/input_data.opaque()));
4021       break;
4022     }
4023     case dnn::ConvolutionKind::BACKWARD_FILTER: {
4024       SE_RETURN_IF_ERROR(get_bwd_filter_bugs());
4025       RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
4026           cudnn.handle(),
4027           /*alpha=*/alpha,
4028           /*srcDesc=*/input_nd.handle(),
4029           /*srcData=*/input_data.opaque(),
4030           /*diffDesc=*/output_nd.handle(),
4031           /*diffData=*/output_data.opaque(),
4032           /*convDesc=*/conv.handle(),
4033           /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
4034           /*workSpace=*/scratch_memory.opaque(),
4035           /*workSpaceSizeInBytes=*/scratch_memory.size(),
4036           /*beta=*/beta,
4037           /*gradDesc=*/filter_nd.handle(),
4038           /*dw=*/filter_data.opaque()));
4039       break;
4040     }
4041     default:
4042       return port::InternalError(
4043           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
4044   }
4045 
4046   if (is_profiling) {
4047     if (!timer->Stop(AsGpuStream(stream))) {
4048       return port::Status(port::error::INTERNAL, "Failed to stop timer");
4049     }
4050     output_profile_result->set_algorithm(algorithm_desc);
4051     output_profile_result->set_elapsed_time_in_ms(
4052         timer->GetElapsedMilliseconds());
4053     output_profile_result->set_scratch_size(scratch_memory.size());
4054   }
4055 
4056   return port::Status::OK();
4057 }
4058 
4059 port::Status CudnnSupport::DoConvolveWithExecutionPlan(
4060     dnn::ConvolutionKind kind, dnn::DataType element_type,
4061     dnn::DataType output_type, Stream* stream,
4062     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
4063     const dnn::FilterDescriptor& filter_descriptor,
4064     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
4065     DeviceMemoryBase output_data,
4066     const dnn::ConvolutionDescriptor& convolution_descriptor,
4067     const dnn::AlgorithmConfig& plan_config,
4068     ScratchAllocator* scratch_allocator,
4069     dnn::ProfileResult* output_profile_result) {
4070 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4071   auto cudnn = cudnn_->GetHandle(parent_, stream);
4072 
4073   absl::optional<dnn::AlgorithmDesc> plan_or = plan_config.algorithm();
4074   absl::optional<dnn::AlgorithmDesc> plan_no_scratch_or =
4075       plan_config.algorithm_no_scratch();
4076 
4077   std::unique_ptr<cudnn_frontend::ExecutionPlan> current_plan;
4078   if (!plan_or.has_value()) {
4079     SE_ASSIGN_OR_RETURN(
4080         std::unique_ptr<cudnn_frontend::OperationGraph> op_graph,
4081         GetCudnnOperationGraph(kind, element_type, stream, input_descriptor,
4082                                filter_descriptor, output_descriptor,
4083                                convolution_descriptor, cudnn));
4084 
4085     SE_ASSIGN_OR_RETURN(current_plan, GetFirstWorkingExecutionPlan(
4086                                           stream, element_type, op_graph, kind,
4087                                           cudnn, scratch_allocator));
4088   }
4089 
4090   size_t workspace_size = 0;
4091   cudnnBackendDescriptor_t plan_desc;
4092   std::string exec_plan_id = "unknown";
4093   if (current_plan) {
4094     exec_plan_id = current_plan->getTag();
4095     workspace_size = current_plan->getWorkspaceSize();
4096     plan_desc = current_plan->get_raw_desc();
4097   } else {
4098     exec_plan_id = plan_or->exec_plan_id();
4099     auto workspace_size_or = plan_config.scratch_size();
4100     if (workspace_size_or.has_value()) {
4101       workspace_size = *workspace_size_or;
4102     }
4103     plan_desc = plan_or->exec_plan_desc();
4104   }
4105   dnn::AlgorithmDesc selected_plan_(exec_plan_id, plan_desc);
4106 
4107   DeviceMemory<uint8> scratch_memory;
4108   if (workspace_size > 0) {
4109     auto scratch_or = scratch_allocator->AllocateBytes(workspace_size);
4110     if (scratch_or.ok()) {
4111       scratch_memory = scratch_or.ValueOrDie();
4112     } else if (plan_no_scratch_or.has_value()) {
4113       selected_plan_ = {plan_no_scratch_or->exec_plan_id(),
4114                         plan_no_scratch_or->exec_plan_desc()};
4115     } else {
4116       return port::Status(port::error::UNKNOWN,
4117                           "CUDNN failed to allocate the scratch space for the "
4118                           "plan or to find a working no-scratch plan.");
4119     }
4120   }
4121 
4122   void* data_ptrs[] = {input_data.opaque(), output_data.opaque(),
4123                        filter_data.opaque()};
4124   int64_t uids[] = {'x', 'y', 'w'};
4125   auto variantPack = cudnn_frontend::VariantPackBuilder()
4126                          .setWorkspacePointer(scratch_memory.opaque())
4127                          .setDataPointers(3, data_ptrs)
4128                          .setUids(3, uids)
4129                          .build();
4130   RETURN_MSG_IF_CUDNN_ERROR(variantPack);
4131 
4132   VLOG(4) << "\nDo convolution with plan tag: " << selected_plan_.exec_plan_id()
4133           << "\nWorkspace size in bytes: " << workspace_size
4134           << "\nVariantPack: " << variantPack.describe();
4135 
4136   const bool is_profiling = output_profile_result != nullptr;
4137 
4138   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4139   if (is_profiling) {
4140     timer.reset(new GpuTimer(parent_));  // NOLINT
4141     // The start and stop of the timer should be as close to the cuDNN call as
4142     // possible. Other threads may still issue work onto this stream, so it may
4143     // take multiple profiling measurements to obtain a reliable timing.
4144     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4145       return port::Status(port::error::INTERNAL, "Failed to start timer");
4146     }
4147   }
4148 
4149   cudnnStatus_t status =
4150       cudnnBackendExecute(cudnn.handle(), selected_plan_.exec_plan_desc(),
4151                           variantPack.get_raw_desc());
4152   RETURN_IF_CUDNN_ERROR(status);
4153 
4154   if (is_profiling) {
4155     if (!timer->Stop(AsGpuStream(stream))) {
4156       return port::Status(port::error::INTERNAL, "Failed to stop timer");
4157     }
4158     output_profile_result->set_algorithm(selected_plan_);
4159     output_profile_result->set_elapsed_time_in_ms(
4160         timer->GetElapsedMilliseconds());
4161     output_profile_result->set_scratch_size(scratch_memory.size());
4162   }
4163 
4164   return port::Status::OK();
4165 #else
4166   return port::InternalError(
4167       "To use CuDNN frontend APIs, CuDNN v8.1 or later "
4168       "is required.");
4169 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4170 }
4171 
4172 // Utility for dealing with CUDA's type-erased scaling parameters, where some
4173 // sets of parameters expect a void* pointing at a float while others expect it
4174 // to point at a double.
4175 //
4176 // This is rather ugly, but its purpose is to quarantine the corresponding
4177 // ugliness that already exists in the CUDA API.
4178 class ScalingParam {
4179  public:
4180   explicit ScalingParam(double value) : as_double_(value), as_float_(value) {}
4181 
4182   // Return a pointer to the appropriate representation type for the given
4183   // element type.
4184   //
4185   // See
4186   // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
4187   // for more info; the behavior for int8 result tensors is not described there,
4188   // but is maintained from the existing behavior (namely, using a float scaling
4189   // parameter).
4190   void* ToVoidPointer(dnn::DataType element_type) {
4191     if (element_type == dnn::DataType::kDouble) {
4192       return &as_double_;
4193     } else {
4194       return &as_float_;
4195     }
4196   }
4197 
4198  private:
4199   double as_double_;
4200   float as_float_;
4201 };
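// Usage sketch (added comment, hypothetical call site):
//
//   ScalingParam alpha(1.0), beta(0.0);
//   void* alpha_ptr = alpha.ToVoidPointer(element_type);
//   void* beta_ptr = beta.ToVoidPointer(element_type);
//
// For dnn::DataType::kDouble the returned pointer aliases a double; for every
// other element type it aliases a float, matching cudnn's type-erased
// alpha/beta convention.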
4202 
4203 bool CudnnSupport::GetConvolveExecutionPlans(
4204     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
4205     const dnn::BatchDescriptor& input_descriptor,
4206     const dnn::FilterDescriptor& filter_descriptor,
4207     const dnn::BatchDescriptor& output_descriptor,
4208     const dnn::ConvolutionDescriptor& convolution_descriptor,
4209     std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>* out_exec_plans) {
4210 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4211   auto cudnn = cudnn_->GetHandle(parent_, stream);
4212   auto op_graph_status = GetCudnnOperationGraph(
4213       kind, element_type, stream, input_descriptor, filter_descriptor,
4214       output_descriptor, convolution_descriptor, cudnn);
4215   if (!op_graph_status.status().ok()) {
4216     return false;
4217   }
4218   auto op_graph = op_graph_status.ConsumeValueOrDie();
4219 
4220   auto heur = cudnn_frontend::EngineHeuristicsBuilder()
4221                   .setOperationGraph(*op_graph)
4222                   .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
4223                   .build();
4224   RETURN_FALSE_IF_CUDNN_ERROR(heur);
4225 
4226   auto fallback = cudnn_frontend::EngineFallbackListBuilder()
4227                       .setOperationGraph(*op_graph)
4228                       .setOperation(GetCudnnConvolutionType(kind))
4229                       .build();
4230   RETURN_FALSE_IF_CUDNN_ERROR(fallback);
4231 
4232   // cuDNN frontend sneakily puts error messages on the object and returns
4233   // partially-initialized results when there's an error; make sure to check
4234   // them.
4235   int64_t engine_count = heur.getEngineConfigCount();
4236   RETURN_FALSE_IF_CUDNN_ERROR(heur);
4237   auto& heur_configs = heur.getEngineConfig(engine_count);
4238   RETURN_FALSE_IF_CUDNN_ERROR(heur);
4239 
4240   auto& fallback_configs = fallback.getFallbackList();
4241 
4242   VLOG(4) << "\nHeuristics engine configs size: " << heur_configs.size()
4243           << "\nFallback engine configs size: " << fallback_configs.size();
4244 
4245   cudnn_frontend::EngineConfigList filtered_configs;
4246   auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
4247     return GenericEngineFilter(
4248         engine_config,
4249         /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
4250         /*disable_nondeterminism*/ RequireCudnnDeterminism(),
4251         /*disable_tensor_core*/ !IsTensorMathEnabled(stream, element_type));
4252   };
4253 
4254   cudnn_frontend::filter(heur_configs, filtered_configs, generic_filter_fn);
4255   cudnn_frontend::filter(fallback_configs, filtered_configs, generic_filter_fn);
4256 
4257   auto fn = []() { return true; };
4258   auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
4259   auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
4260 
4261   VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
4262 
4263   out_exec_plans->clear();
4264   for (int i = 0; i < filtered_configs.size(); i++) {
4265     auto plan = cudnn_frontend::ExecutionPlanBuilder()
4266                     .setHandle(cudnn.handle())
4267                     .setEngineConfig(filtered_configs[i], op_graph->getTag())
4268                     .build();
4269     if (plan.get_status() == CUDNN_STATUS_SUCCESS) {
4270       if (maybe_json_handle_static &&
4271           cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
4272                                        cudnn.handle(), fn)) {
4273         VLOG(4) << "Exclude engine (static): " << plan.getTag();
4274         continue;
4275       }
4276       if (maybe_json_handle_runtime &&
4277           cudnn_frontend::check_errata(*maybe_json_handle_runtime,
4278                                        plan.getTag(), cudnn.handle(), fn)) {
4279         VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
4280         continue;
4281       }
4282 
4283       out_exec_plans->push_back(std::unique_ptr<dnn::ConvolveExecutionPlan>(
4284           new CudnnConvolveExecutionPlan(std::move(plan))));
4285       // We will use the first working plan when determinism is required.
4286       if (RequireCudnnDeterminism()) {
4287         break;
4288       }
4289     }
4290   }
4291 
4292   VLOG(4) << "\nReturned execution plans size: " << out_exec_plans->size();
4293 
4294   return true;
4295 #else
4296   return false;
4297 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4298 }
4299 
4300 port::Status CudnnSupport::GetFusedConvolveExecutionPlans(
4301     dnn::ConvolutionKind kind, dnn::DataType element_type,
4302     double conv_input_scale, double side_input_scale, Stream* stream,
4303     const dnn::BatchDescriptor& input_descriptor,
4304     const dnn::FilterDescriptor& filter_descriptor,
4305     const dnn::BatchDescriptor& bias_descriptor,
4306     const dnn::BatchDescriptor& output_descriptor,
4307     const dnn::ConvolutionDescriptor& convolution_descriptor,
4308     const dnn::ActivationMode activation_mode,
4309     std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>* out_exec_plans) {
4310 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4311   auto cudnn = cudnn_->GetHandle(parent_, stream);
4312   auto op_graph_status = GetCudnnFusedOperationGraph(
4313       kind, element_type, conv_input_scale, side_input_scale, stream,
4314       input_descriptor, filter_descriptor, bias_descriptor, output_descriptor,
4315       convolution_descriptor, activation_mode, cudnn);
4316   if (!op_graph_status.status().ok()) {
4317     return port::Status(port::error::INTERNAL,
4318                         absl::StrCat("Cudnn graph failed to build: ",
4319                                      op_graph_status.status().ToString()));
4320   }
4321   auto op_graph = op_graph_status.ConsumeValueOrDie();
4322 
4323   auto heur = cudnn_frontend::EngineHeuristicsBuilder()
4324                   .setOperationGraph(*op_graph)
4325                   .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
4326                   .build();
4327   RETURN_MSG_IF_CUDNN_ERROR(heur);
4328 
4329   auto fallback =
4330       cudnn_frontend::EngineFallbackListBuilder()
4331           .setOperationGraph(*op_graph)
4332           .setOperation(GetCudnnConvolutionType(dnn::ConvolutionKind::FORWARD))
4333           .build();
4334   RETURN_MSG_IF_CUDNN_ERROR(fallback);
4335 
4336   // cuDNN frontend sneakily puts error messages on the object and returns
4337   // partially-initialized results when there's an error; make sure to check
4338   // them.
4339   int64_t engine_count = heur.getEngineConfigCount();
4340   RETURN_MSG_IF_CUDNN_ERROR(heur);
4341   auto& heur_configs = heur.getEngineConfig(engine_count);
4342   RETURN_MSG_IF_CUDNN_ERROR(heur);
4343 
4344   auto& fallback_configs = fallback.getFallbackList();
4345 
4346   VLOG(4) << "\nHeuristics engine configs size: " << heur_configs.size()
4347           << "\nFallback engine configs size: " << fallback_configs.size();
4348 
4349   cudnn_frontend::EngineConfigList filtered_configs;
4350   auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
4351     return GenericEngineFilter(
4352         engine_config,
4353         /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
4354         /*disable_nondeterminism*/ RequireCudnnDeterminism(),
4355         /*disable_tensor_core*/ !IsTensorMathEnabled(stream, element_type));
4356   };
4357 
4358   cudnn_frontend::filter(heur_configs, filtered_configs, generic_filter_fn);
4359   cudnn_frontend::filter(fallback_configs, filtered_configs, generic_filter_fn);
4360 
4361   auto fn = []() { return true; };
4362   auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
4363   auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
4364 
4365   VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
4366 
4367   out_exec_plans->clear();
4368   for (int i = 0; i < filtered_configs.size(); i++) {
4369     auto plan = cudnn_frontend::ExecutionPlanBuilder()
4370                     .setHandle(cudnn.handle())
4371                     .setEngineConfig(filtered_configs[i], op_graph->getTag())
4372                     .build();
4373     if (plan.get_status() == CUDNN_STATUS_SUCCESS) {
4374       if (maybe_json_handle_static &&
4375           cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
4376                                        cudnn.handle(), fn)) {
4377         VLOG(4) << "Exclude engine (static): " << plan.getTag();
4378         continue;
4379       }
4380       if (maybe_json_handle_runtime &&
4381           cudnn_frontend::check_errata(*maybe_json_handle_runtime,
4382                                        plan.getTag(), cudnn.handle(), fn)) {
4383         VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
4384         continue;
4385       }
4386 
4387       out_exec_plans->push_back(std::unique_ptr<dnn::ConvolveExecutionPlan>(
4388           new CudnnConvolveExecutionPlan(std::move(plan))));
4389       // We will use the first working plan when determinism is required.
4390       if (RequireCudnnDeterminism()) {
4391         break;
4392       }
4393     }
4394   }
4395 
4396   VLOG(4) << "\nReturned execution plans size: " << out_exec_plans->size();
4397 
4398   return port::Status::OK();
4399 #else
4400   return port::UnimplementedError(
4401       "Cudnn execution plans are only supported with Cudnn >= 8.1.");
4402 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4403 }
4404 
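// Illustrative caller-side sketch of the function above (not part of this
// file; the descriptors, the stream, and the DnnSupport pointer `dnn` are
// assumed to be set up by the caller):
//
//   std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>> plans;
//   SE_RETURN_IF_ERROR(dnn->GetFusedConvolveExecutionPlans(
//       dnn::ConvolutionKind::FORWARD, dnn::DataType::kFloat,
//       /*conv_input_scale=*/1.0, /*side_input_scale=*/0.0, stream,
//       input_desc, filter_desc, bias_desc, output_desc, conv_desc,
//       dnn::ActivationMode::kRelu, &plans));
//   // An autotuner would then profile each returned plan and keep the fastest.
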
4405 bool CudnnSupport::GetConvolveAlgorithms(
4406     CudaComputeCapability cuda_compute_capability,
4407     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
4408   // Preload sub libs for cudnn 8.0.4+
4409 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
4410   cudnnOpsInferVersionCheck();
4411   cudnnCnnInferVersionCheck();
4412 #endif
4413   bool tensor_op_math_available =
4414       TensorOpMathAvailable(cuda_compute_capability);
4415   out_algorithms->clear();
4416 
4417   std::vector<dnn::AlgorithmDesc::Index> algo_types;
4418   if (ConvUseDefaultAlgorithm()) {
4419     // Force a fallback algorithm.
4420     algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM};
4421   } else {
4422     algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
4423                   CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
4424                   CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
4425                   CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
4426                   CUDNN_CONVOLUTION_FWD_ALGO_FFT,
4427                   CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD};
4428     if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
4429       algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
4430     }
4431     if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
4432       algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
4433     }
4434   }
4435 
4436   // The algorithms are intentionally ordered for deterministic operation.
4437   for (auto i : algo_types) {
4438     if (tensor_op_math_available) {
4439       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
4440     }
4441     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
4442   }
4443 
4444   return true;
4445 }
4446 
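// Illustrative sketch of consuming the algorithm list (assumes `dnn` points at
// this DnnSupport implementation and `stream` is a valid Stream*):
//
//   std::vector<dnn::AlgorithmDesc> algos;
//   if (dnn->GetConvolveAlgorithms(stream->GetCudaComputeCapability(),
//                                  &algos)) {
//     for (const dnn::AlgorithmDesc& algo : algos) {
//       // Profile `algo` (tensor-op and non-tensor-op variants are both
//       // listed) and keep the fastest one that fits the workspace budget.
//     }
//   }
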
4447 bool CudnnSupport::GetRnnAlgorithms(
4448     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
4449   // Preload sub libs for cudnn 8.0.4+
4450 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
4451   cudnnOpsInferVersionCheck();
4452   cudnnOpsTrainVersionCheck();
4453   cudnnAdvInferVersionCheck();
4454   cudnnAdvTrainVersionCheck();
4455 #endif
4456   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
4457       // clang-format off
4458     CUDNN_RNN_ALGO_STANDARD,
4459     CUDNN_RNN_ALGO_PERSIST_STATIC,
4460     CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
4461       // clang-format on
4462   };
4463 
4464   out_algorithms->clear();
4465   for (auto i : algo_types) {
4466     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
4467     out_algorithms->push_back({i, /*use_tensor_ops=*/true});
4468   }
4469   return true;
4470 }
4471 
4472 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
4473     CudaComputeCapability cuda_compute_capability,
4474     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
4475   // Preload sub libs for cudnn 8.0.4+
4476 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
4477   cudnnOpsInferVersionCheck();
4478   cudnnOpsTrainVersionCheck();
4479   cudnnCnnInferVersionCheck();
4480   cudnnCnnTrainVersionCheck();
4481 #endif
4482   bool tensor_op_math_available =
4483       TensorOpMathAvailable(cuda_compute_capability);
4484   out_algorithms->clear();
4485 
4486   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
4487       // clang-format off
4488     CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
4489     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
4490     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
4491     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
4492       // clang-format on
4493   };
4494   if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
4495     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
4496   }
4497   if (!RequireCudnnDeterminism()) {
4498     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
4499   }
4500 
4501   // The algorithms are intentionally ordered for deterministic operation.
4502   for (auto i : algo_types) {
4503     if (tensor_op_math_available) {
4504       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
4505     }
4506     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
4507   }
4508 
4509   return true;
4510 }
4511 
4512 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
4513     CudaComputeCapability cuda_compute_capability,
4514     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
4515   // Preload sub libs for cudnn 8.0.4+
4516 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
4517   cudnnOpsInferVersionCheck();
4518   cudnnOpsTrainVersionCheck();
4519   cudnnCnnInferVersionCheck();
4520   cudnnCnnTrainVersionCheck();
4521 #endif
4522   bool tensor_op_math_available =
4523       TensorOpMathAvailable(cuda_compute_capability);
4524   out_algorithms->clear();
4525 
4526   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
4527       // clang-format off
4528       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
4529       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
4530       // Based on cudnn.h, the following is not implemented.
4531       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
4532 
4533       // Produces incorrect results for some shapes. Disabled for now, see
4534       // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
4535       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
4536       // clang-format on
4537   };
4538   if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
4539     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
4540   }
4541   if (!RequireCudnnDeterminism()) {
4542     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
4543     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
4544   }
4545 
4546   // The algorithms are intentionally ordered for deterministic operation.
4547   for (auto i : algo_types) {
4548     if (tensor_op_math_available) {
4549       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
4550     }
4551     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
4552   }
4553 
4554   return true;
4555 }
4556 
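// Note on the two backward lists above: as with the forward list, each
// algorithm index is emitted both with and without tensor ops when tensor
// math is available, and the non-deterministic algorithms (ALGO_0 here, plus
// ALGO_3 for backward-filter) are appended only when RequireCudnnDeterminism()
// is false.
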
4557 bool CudnnSupport::DoBatchNormalizationForward(
4558     Stream* stream, const DeviceMemory<float>& x,
4559     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
4560     const DeviceMemory<float>& estimated_mean,
4561     const DeviceMemory<float>& estimated_variance,
4562     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
4563     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4564     const double exponential_average_factor,
4565     dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
4566     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
4567     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
4568     bool is_training, ScratchAllocator* reserve_space_allocator,
4569     ScratchAllocator* workspace_allocator) {
4570   return IsStatusOk(
4571       DoBatchNormalizationForwardImpl<float, float>(
4572           stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
4573           offset, estimated_mean, estimated_variance, side_input, x_desc,
4574           scale_offset_desc, epsilon, exponential_average_factor,
4575           activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
4576           is_training, reserve_space_allocator, workspace_allocator),
4577       /*report_error=*/true);
4578 }
4579 
4580 bool CudnnSupport::DoBatchNormalizationForward(
4581     Stream* stream, const DeviceMemory<Eigen::half>& x,
4582     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
4583     const DeviceMemory<float>& estimated_mean,
4584     const DeviceMemory<float>& estimated_variance,
4585     const DeviceMemory<Eigen::half>& side_input,
4586     const dnn::BatchDescriptor& x_desc,
4587     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4588     const double exponential_average_factor,
4589     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
4590     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
4591     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
4592     bool is_training, ScratchAllocator* reserve_space_allocator,
4593     ScratchAllocator* workspace_allocator) {
4594   return IsStatusOk(
4595       DoBatchNormalizationForwardImpl<Eigen::half, float>(
4596           stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
4597           estimated_mean, estimated_variance, side_input, x_desc,
4598           scale_offset_desc, epsilon, exponential_average_factor,
4599           activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
4600           is_training, reserve_space_allocator, workspace_allocator),
4601       /*report_error=*/true);
4602 }
4603 
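// Shared implementation for the float and Eigen::half overloads above. When
// both a reserve-space and a workspace allocator are provided (and cuDNN is at
// least 7.4.2), the "Ex" batch-norm path is used, which supports a fused side
// input and activation; otherwise the legacy API is used and any side input or
// activation is rejected.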
4604 template <class T, class U>
4605 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
4606     Stream* stream, dnn::DataType input_data_type,
4607     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
4608     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
4609     const DeviceMemory<U>& estimated_mean,
4610     const DeviceMemory<U>& estimated_variance,
4611     const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
4612     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4613     const double exponential_average_factor,
4614     dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
4615     DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
4616     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
4617     bool is_training, ScratchAllocator* reserve_space_allocator,
4618     ScratchAllocator* workspace_allocator) {
4619   CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
4620   CudnnTensorDescriptor scale_offset_descriptor(
4621       scale_offset_desc, ToCudnnDataType(scale_data_type));
4622   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
4623   if (BatchnormSpatialPersistentEnabled() && is_training) {
4624     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
4625   }
4626   float one = 1.0;
4627   float zero = 0.0;
4628   auto cudnn = cudnn_->GetHandle(parent_, stream);
4629 
4630   DeviceMemory<uint8> workspace;
4631   DeviceMemory<uint8> reserve_space;
4632 
4633 #if CUDNN_VERSION >= 7402
4634   const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t {
4635     if (side_input.is_null()) {
4636       return activation_mode == dnn::ActivationMode::kNone
4637                  ? CUDNN_BATCHNORM_OPS_BN
4638                  : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
4639     } else {
4640       return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
4641     }
4642   };
4643   const cudnnBatchNormOps_t bn_ops = get_bn_ops();
4644 
4645   // We use NaN propagation to be consistent with CudnnSupport::DoActivate(...).
4646   CudnnActivationDescriptor activation_desc(
4647       activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
4648 
4649   if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
4650     SE_ASSIGN_OR_RETURN(
4651         workspace,
4652         CreateBatchNormForwardWorkspace(
4653             stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
4654             scale_offset_descriptor, workspace_allocator))
4655     if (is_training) {
4656       size_t reserve_space_size_in_bytes = 0;
4657       RETURN_IF_CUDNN_ERROR(
4658           cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
4659               /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
4660               /*activationDesc=*/activation_desc.handle(),
4661               /*xDesc=*/x_descriptor.handle(),
4662               /*sizeInBytes=*/&reserve_space_size_in_bytes));
4663       SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
4664                                              reserve_space_size_in_bytes));
4665     }
4666   }
4667 #endif
4668 
4669   auto check_no_side_input_or_activation = [&]() -> port::Status {
4670     if (activation_mode != dnn::ActivationMode::kNone ||
4671         !side_input.is_null()) {
4672       return port::Status(
4673           port::error::INTERNAL,
4674           absl::StrCat(
4675               "Side input and activation are not supported by cuDNN version: ",
4676               CUDNN_VERSION));
4677     } else {
4678       return port::Status::OK();
4679     }
4680   };
4681 
4682   if (is_training) {
4683     CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
4684         << "batch_mean and batch_var must both be null or both be non-null";
4685 
4686     void* batch_mean_opaque;
4687     void* batch_var_opaque;
4688     if (!batch_mean->is_null() && !batch_var->is_null()) {
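      // cuDNN updates the running statistics as
      //   new = (1 - factor) * old + factor * batch,
      // so even with a factor of 1.0 the old values are read; zeroing them
      // first avoids folding in garbage (e.g. NaN) from uninitialized buffers.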
4689       if (exponential_average_factor == 1.0) {
4690         stream->ThenMemZero(batch_mean, batch_mean->size());
4691         stream->ThenMemZero(batch_var, batch_var->size());
4692       }
4693       batch_mean_opaque = batch_mean->opaque();
4694       batch_var_opaque = batch_var->opaque();
4695     } else {
4696       batch_mean_opaque = nullptr;
4697       batch_var_opaque = nullptr;
4698     }
4699 
4700     bool called = false;
4701 #if CUDNN_VERSION >= 7402
4702     if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
4703       called = true;
4704       RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx(
4705           /*handle=*/cudnn.handle(),
4706           /*mode=*/mode,
4707           /*bnOps=*/bn_ops,
4708           /*alpha=*/&one,
4709           /*beta=*/&zero,
4710           /*xDesc=*/x_descriptor.handle(),
4711           /*xData=*/x.opaque(),
4712           /*zDesc=*/x_descriptor.handle(),
4713           /*zData=*/side_input.opaque(),
4714           /*yDesc=*/x_descriptor.handle(),
4715           /*yData=*/y->opaque(),
4716           /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
4717           /*bnScale=*/scale.opaque(),
4718           /*bnBias=*/offset.opaque(),
4719           /*exponentialAverageFactor=*/exponential_average_factor,
4720           /*resultRunningMean=*/batch_mean_opaque,
4721           /*resultRunningVariance=*/batch_var_opaque,
4722           /*epsilon=*/epsilon,
4723           /*resultSaveMean=*/saved_mean->opaque(),
4724           /*resultSaveInvVariance=*/saved_inv_var->opaque(),
4725           /*activationDesc=*/activation_desc.handle(),
4726           /*workspace=*/workspace.opaque(),
4727           /*workSpaceSizeInBytes=*/workspace.size(),
4728           /*reserveSpace=*/reserve_space.opaque(),
4729           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
4730     }
4731 #endif
4732     if (!called) {
4733       SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
4734       RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
4735           cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
4736           x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
4737           scale.opaque(), offset.opaque(), exponential_average_factor,
4738           batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
4739           saved_inv_var->opaque()));
4740     }
4741   } else {
4742     const void* maybe_inv_var = estimated_variance.opaque();
4743     SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
4744     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
4745         cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
4746         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
4747         scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
4748         epsilon));
4749   }
4750   return port::Status::OK();
4751 }
4752 
4753 bool CudnnSupport::DoBatchNormalizationBackward(
4754     Stream* stream, const DeviceMemory<float>& y_backprop,
4755     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
4756     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
4757     const dnn::BatchDescriptor& x_desc,
4758     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4759     DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
4760     DeviceMemory<float>* offset_backprop,
4761     DeviceMemory<uint8>* reserve_space_data,
4762     ScratchAllocator* workspace_allocator) {
4763   return IsStatusOk(DoBatchNormalizationBackwardImpl(
4764                         stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
4765                         x, scale, mean, inv_var, x_desc, scale_offset_desc,
4766                         epsilon, x_backprop, scale_backprop, offset_backprop,
4767                         reserve_space_data, workspace_allocator),
4768                     /*report_error=*/true);
4769 }
4770 
4771 bool CudnnSupport::DoBatchNormalizationBackward(
4772     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
4773     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
4774     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
4775     const dnn::BatchDescriptor& x_desc,
4776     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4777     DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
4778     DeviceMemory<float>* offset_backprop,
4779     DeviceMemory<uint8>* reserve_space_data,
4780     ScratchAllocator* workspace_allocator) {
4781   return IsStatusOk(DoBatchNormalizationBackwardImpl(
4782                         stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
4783                         x, scale, mean, inv_var, x_desc, scale_offset_desc,
4784                         epsilon, x_backprop, scale_backprop, offset_backprop,
4785                         reserve_space_data, workspace_allocator),
4786                     /*report_error=*/true);
4787 }
4788 
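// Shared implementation for the float and Eigen::half backward overloads
// above. When reserve-space data and a workspace allocator are provided (and
// cuDNN is at least 7.4.2), cudnnBatchNormalizationBackwardEx is used;
// otherwise the legacy cudnnBatchNormalizationBackward path is taken.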
4789 template <class T, class U>
4790 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
4791     Stream* stream, int cudnn_input_type, int cudnn_scale_type,
4792     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
4793     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
4794     const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
4795     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4796     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
4797     DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
4798     ScratchAllocator* workspace_allocator) {
4799   CudnnTensorDescriptor x_descriptor(
4800       x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
4801   CudnnTensorDescriptor scale_offset_descriptor(
4802       scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
4803   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
4804   if (BatchnormSpatialPersistentEnabled()) {
4805     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
4806   }
4807   float one = 1.0;
4808   float zero = 0.0;
4809 
4810   auto cudnn = cudnn_->GetHandle(parent_, stream);
4811 
4812   bool called = false;
4813 #if CUDNN_VERSION >= 7402
4814   if (reserve_space_data != nullptr && workspace_allocator != nullptr) {
4815     called = true;
4816     const cudnnBatchNormOps_t bn_ops = CUDNN_BATCHNORM_OPS_BN;
4817     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
4818                         CreateBatchNormBackwardWorkspace(
4819                             stream, cudnn, mode, bn_ops, x_descriptor,
4820                             scale_offset_descriptor, workspace_allocator))
4821     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx(
4822         /*handle=*/cudnn.handle(),
4823         /*mode=*/mode,
4824         /*bnOps=*/bn_ops,
4825         /*alphaDataDiff=*/&one,
4826         /*betaDataDiff=*/&zero,
4827         /*alphaParamDiff=*/&one,
4828         /*betaParamDiff=*/&zero,
4829         /*xDesc=*/x_descriptor.handle(),
4830         /*xData=*/x.opaque(),
4831         /*yDesc=*/nullptr,
4832         /*yData=*/nullptr,
4833         /*dyDesc=*/x_descriptor.handle(),
4834         /*dyData=*/y_backprop.opaque(),
4835         /*dzDesc=*/nullptr,
4836         /*dzData=*/nullptr,
4837         /*dxDesc=*/x_descriptor.handle(),
4838         /*dxData=*/x_backprop->opaque(),
4839         /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
4840         /*bnScaleData=*/scale.opaque(),
4841         /*bnBiasData=*/nullptr,
4842         /*dBnScaleData=*/scale_backprop->opaque(),
4843         /*dBnBiasData=*/offset_backprop->opaque(),
4844         /*epsilon=*/epsilon,
4845         /*savedMean=*/mean.opaque(),
4846         /*savedInvVariance=*/inv_var.opaque(),
4847         /*activationDesc=*/nullptr,
4848         /*workspace=*/workspace.opaque(),
4849         /*workSpaceSizeInBytes=*/workspace.size(),
4850         /*reserveSpace=*/reserve_space_data->opaque(),
4851         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
4852   }
4853 #endif
4854   if (!called) {
4855     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
4856         cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
4857         x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
4858         x_descriptor.handle(), x_backprop->opaque(),
4859         scale_offset_descriptor.handle(), scale.opaque(),
4860         scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
4861         mean.opaque(), inv_var.opaque()));
4862   }
4863 
4864   return port::Status::OK();
4865 }
4866 
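// DoFusedConvolve below wraps cudnnConvolutionBiasActivationForward, which
// (per the cuDNN documentation) computes roughly
//   y = act(alpha1 * conv(x, w) + alpha2 * z + bias),
// where alpha1/alpha2 correspond to conv_input_scale/side_input_scale here and
// z is the side input (aliased to the output buffer when side_input_scale is 0).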
4867 port::Status CudnnSupport::DoFusedConvolve(
4868     Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
4869     dnn::DataType bias_type, dnn::DataType output_type,
4870     const dnn::BatchDescriptor& conv_input_descriptor,
4871     DeviceMemoryBase conv_input_data, double conv_input_scale,
4872     const dnn::FilterDescriptor& filter_descriptor,
4873     DeviceMemoryBase filter_data,
4874     const dnn::ConvolutionDescriptor& convolution_descriptor,
4875     DeviceMemoryBase side_input_data, double side_input_scale,
4876     const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
4877     dnn::ActivationMode activation_mode,
4878     const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
4879     ScratchAllocator* scratch_allocator,
4880     const dnn::AlgorithmConfig& algorithm_config,
4881     dnn::ProfileResult* output_profile_result) {
4882   if (input_type == dnn::DataType::kInt8 &&
4883       !stream->GetCudaComputeCapability().IsAtLeast(6, 1)) {
4884     return port::UnimplementedError(
4885         "cudnnConvolutionBiasActivationForward() for int8 is only supported on "
4886         "GPUs with compute capability 6.1 or later.");
4887   }
4888 
4889   if (input_type == dnn::DataType::kInt8 &&
4890       output_type == dnn::DataType::kFloat &&
4891       (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8200)) {
4892     return port::UnimplementedError(
4893         "int8 -> float fused conv is disabled for this cuDNN version. See "
4894         "go/nvbugs/3326122");
4895   }
4896 
4897   if (activation_mode != dnn::ActivationMode::kRelu &&
4898       activation_mode != dnn::ActivationMode::kNone) {
4899     return port::Status(port::error::INVALID_ARGUMENT,
4900                         "cudnnConvolutionBiasActivationForward() only supports "
4901                         "Relu or None activation.");
4902   }
4903 
4904   CudnnTensorDescriptor conv_input_nd(
4905       conv_input_descriptor,
4906       ToCudnnDataType(input_type, conv_input_descriptor.layout()));
4907   CudnnTensorDescriptor output_nd(
4908       output_descriptor,
4909       ToCudnnDataType(output_type, conv_input_descriptor.layout()));
4910   CudnnFilterDescriptor filter(
4911       filter_descriptor,
4912       ToCudnnDataType(input_type, filter_descriptor.layout()));
4913   CudnnTensorDescriptor bias_nd(bias_descriptor, ToCudnnDataType(bias_type));
4914 
4915   auto cudnn = cudnn_->GetHandle(parent_, stream);
4916 
4917   const bool is_profiling = output_profile_result != nullptr;
4918 
4919   DeviceMemory<uint8> scratch;
4920   SE_ASSIGN_OR_RETURN(
4921       dnn::AlgorithmDesc algo_desc,
4922       GetCudnnConvolutionForwardAlgorithm(
4923           stream, cudnn, algorithm_config, conv_input_nd, filter, input_type,
4924           convolution_descriptor, output_nd, scratch_allocator, &scratch));
4925 
4926   CudnnConvolutionDescriptor conv(
4927       convolution_descriptor,
4928       ToCudnnDataType(GetConvAccumulatorType(input_type)));
4929   conv.set_use_tensor_op_math(algo_desc.tensor_ops_enabled());
4930 
4931   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4932   if (is_profiling) {
4933     timer.reset(new GpuTimer(parent_));  // NOLINT
4934     // The start and stop of the timer should be as close to the cuDNN call
4935     // as possible. Other threads may still enqueue work onto this stream, so
4936     // multiple profiling measurements may be needed.
4937     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4938       return port::Status(port::error::INTERNAL, "Failed to start timer");
4939     }
4940   }
4941   // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
4942   // activation descriptor. Note that this changes the NaN propagation
4943   // behavior relative to separate conv, bias, and relu ops (which default to
4944   // CUDNN_PROPAGATE_NAN).
4945   CudnnActivationDescriptor activation_desc(
4946       activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
4947   auto side_input_data_ptr =
4948       (side_input_scale == 0) ? output_data.opaque() : side_input_data.opaque();
4949 
4950   VLOG(2) << "\nconv_input_scale = " << conv_input_scale
4951           << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
4952           << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
4953           << "\nfilter.handle() = " << filter.handle()
4954           << "\nfilter_data.opaque() = " << filter_data.opaque()
4955           << "\nconv.handle() = " << conv.handle()
4956           << "\nalgo = " << algo_desc.algo_id()
4957           << ", tensor_ops_enabled=" << algo_desc.tensor_ops_enabled()
4958           << "\nscratch.opaque() = " << scratch.opaque()
4959           << "\nscratch.size() = " << scratch.size()
4960           << "\nside_input_scale = " << side_input_scale
4961           << "\noutput_nd.handle() = " << output_nd.handle()
4962           << "\nside_input_data_ptr = " << side_input_data_ptr
4963           << "\nbias_nd.handle() = " << bias_nd.handle()
4964           << "\nbiases.opaque() = " << biases.opaque()
4965           << "\nactivation_desc.handle() = " << activation_desc.handle()
4966           << "\noutput_nd.handle() = " << output_nd.handle()
4967           << "\noutput_data.opaque() = " << output_data.opaque();
4968 
4969   if (IsTensorMathOpSet(conv) != algo_desc.tensor_ops_enabled()) {
4970     return port::Status(port::error::FAILED_PRECONDITION,
4971                         "Tensor op math type in dnn::AlgorithmDesc does not "
4972                         "match that of the CudnnConvolutionDescriptor");
4973   }
4974 
4975   // N.B. the scaling parameters alpha1 and alpha2 are pointers to temporaries;
4976   // this API doesn't persist the pointers beyond its own stack frame.
4977   auto status = cudnnConvolutionBiasActivationForward(
4978       cudnn.handle(),
4979       /*alpha1=*/ScalingParam(conv_input_scale).ToVoidPointer(input_type),
4980       /*xDesc=*/conv_input_nd.handle(), /*x=*/conv_input_data.opaque(),
4981       /*wDesc=*/filter.handle(), /*w=*/filter_data.opaque(),
4982       /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
4983       /*workSpace=*/scratch.opaque(),
4984       /*workSpaceSizeInBytes=*/scratch.size(),
4985       /*alpha2=*/ScalingParam(side_input_scale).ToVoidPointer(input_type),
4986       /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
4987       /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
4988       /*activationDesc=*/activation_desc.handle(),
4989       /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque());
4990   if (status != CUDNN_STATUS_SUCCESS || !is_profiling) {
4991     VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo_desc)
4992             << ", workspace_size=" << scratch.size() << " -> "
4993             << ToString(status);
4994   }
4995   RETURN_IF_CUDNN_ERROR(status);
4996 
4997   if (is_profiling) {
4998     if (!timer->Stop(AsGpuStream(stream))) {
4999       return port::Status(port::error::INTERNAL, "Failed to stop timer");
5000     }
5001     output_profile_result->set_algorithm(algo_desc);
5002     output_profile_result->set_elapsed_time_in_ms(
5003         timer->GetElapsedMilliseconds());
5004     output_profile_result->set_scratch_size(scratch.size());
5005     VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo_desc)
5006             << ", tensor_ops_enabled=" << algo_desc.tensor_ops_enabled()
5007             << ", workspace_size=" << scratch.size() << " -> "
5008             << ToString(status) << " in " << timer->GetElapsedMilliseconds()
5009             << "ms";
5010   }
5011 
5012   return port::Status::OK();
5013 }
5014 
5015 port::Status CudnnSupport::DoFusedConvolveWithExecutionPlan(
5016     Stream* stream, dnn::DataType element_type,
5017     const dnn::BatchDescriptor& conv_input_descriptor,
5018     DeviceMemoryBase conv_input_data, double conv_input_scale,
5019     const dnn::FilterDescriptor& filter_descriptor,
5020     DeviceMemoryBase filter_data,
5021     const dnn::ConvolutionDescriptor& convolution_descriptor,
5022     DeviceMemoryBase side_input_data, double side_input_scale,
5023     const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
5024     dnn::ActivationMode activation_mode,
5025     const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
5026     ScratchAllocator* scratch_allocator,
5027     const dnn::AlgorithmConfig& algorithm_config,
5028     dnn::ProfileResult* output_profile_result) {
5029 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5030   auto cudnn = cudnn_->GetHandle(parent_, stream);
5031 
5032   absl::optional<dnn::AlgorithmDesc> plan_or = algorithm_config.algorithm();
5033   absl::optional<dnn::AlgorithmDesc> plan_no_scratch_or =
5034       algorithm_config.algorithm_no_scratch();
5035 
5036   std::unique_ptr<cudnn_frontend::ExecutionPlan> current_plan;
5037   if (!plan_or.has_value()) {
5038     SE_ASSIGN_OR_RETURN(
5039         std::unique_ptr<cudnn_frontend::OperationGraph> op_graph,
5040         GetCudnnFusedOperationGraph(
5041             dnn::ConvolutionKind::FORWARD, element_type, conv_input_scale,
5042             side_input_scale, stream, conv_input_descriptor, filter_descriptor,
5043             bias_descriptor, output_descriptor, convolution_descriptor,
5044             activation_mode, cudnn));
5045 
5046     SE_ASSIGN_OR_RETURN(current_plan, GetFirstWorkingExecutionPlan(
5047                                           stream, element_type, op_graph,
5048                                           dnn::ConvolutionKind::FORWARD, cudnn,
5049                                           scratch_allocator));
5050   }
5051 
5052   size_t workspace_size = 0;
5053   cudnnBackendDescriptor_t plan_desc;
5054   std::string exec_plan_id = "unknown";
5055   if (current_plan) {
5056     exec_plan_id = current_plan->getTag();
5057     workspace_size = current_plan->getWorkspaceSize();
5058     plan_desc = current_plan->get_raw_desc();
5059   } else {
5060     exec_plan_id = plan_or->exec_plan_id();
5061     auto workspace_size_or = algorithm_config.scratch_size();
5062     if (workspace_size_or.has_value()) {
5063       workspace_size = *workspace_size_or;
5064     }
5065     plan_desc = plan_or->exec_plan_desc();
5066   }
5067   dnn::AlgorithmDesc selected_plan(exec_plan_id, plan_desc);
5068 
5069   DeviceMemory<uint8> scratch_memory;
5070   if (workspace_size > 0) {
5071     auto scratch_or = scratch_allocator->AllocateBytes(workspace_size);
5072     if (scratch_or.ok()) {
5073       scratch_memory = scratch_or.ValueOrDie();
5074     } else if (plan_no_scratch_or.has_value()) {
5075       selected_plan = {plan_no_scratch_or->exec_plan_id(),
5076                        plan_no_scratch_or->exec_plan_desc()};
5077     } else {
5078       return port::Status(port::error::UNKNOWN,
5079                           "CUDNN failed to allocate the scratch space for the "
5080                           "plan or to find a working no-scratch plan.");
5081     }
5082   }
5083 
5084   void* data_ptrs[] = {
5085       conv_input_data.opaque(), output_data.opaque(), filter_data.opaque(),
5086       side_input_data.opaque(), biases.opaque(),
5087   };
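  // These UIDs identify the tensors in the variant pack and are expected to
  // match the UIDs assigned to the corresponding tensors when the fused
  // operation graph was built (see GetCudnnFusedOperationGraph); the order
  // parallels data_ptrs.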
5088   int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};
5089   auto variantPack = cudnn_frontend::VariantPackBuilder()
5090                          .setWorkspacePointer(scratch_memory.opaque())
5091                          .setDataPointers(5, data_ptrs)
5092                          .setUids(5, uids)
5093                          .build();
5094   RETURN_MSG_IF_CUDNN_ERROR(variantPack);
5095 
5096   VLOG(4) << "\nDo fused convolution with plan tag: "
5097           << selected_plan.exec_plan_id()
5098           << "\nWorkspace size in bytes: " << workspace_size
5099           << "\nVariantPack: " << variantPack.describe();
5100 
5101   const bool is_profiling = output_profile_result != nullptr;
5102 
5103   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
5104   if (is_profiling) {
5105     timer.reset(new GpuTimer(parent_));  // NOLINT
5106     // The start and stop of the timer should be as close to the cuDNN call
5107     // as possible. Other threads may still enqueue work onto this stream, so
5108     // multiple profiling measurements may be needed.
5109     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
5110       return port::Status(port::error::INTERNAL, "Failed to start timer");
5111     }
5112   }
5113 
5114   cudnnStatus_t status =
5115       cudnnBackendExecute(cudnn.handle(), selected_plan.exec_plan_desc(),
5116                           variantPack.get_raw_desc());
5117   if (status != CUDNN_STATUS_SUCCESS || !is_profiling) {
5118     VLOG(4) << "conv with plan " << selected_plan.exec_plan_id()
5119             << ", workspace_size=" << workspace_size << " -> "
5120             << ToString(status);
5121   }
5122   RETURN_IF_CUDNN_ERROR(status);
5123 
5124   if (is_profiling) {
5125     if (!timer->Stop(AsGpuStream(stream))) {
5126       return port::Status(port::error::INTERNAL, "Failed to stop timer");
5127     }
5128     output_profile_result->set_algorithm(selected_plan);
5129     output_profile_result->set_elapsed_time_in_ms(
5130         timer->GetElapsedMilliseconds());
5131     output_profile_result->set_scratch_size(scratch_memory.size());
5132     VLOG(4) << "conv with plan " << selected_plan.exec_plan_id()
5133             << ", workspace_size=" << workspace_size << " -> "
5134             << ToString(status) << " in " << timer->GetElapsedMilliseconds()
5135             << "ms";
5136   }
5137 
5138   return port::Status::OK();
5139 #else
5140   return port::InternalError(
5141       "To use CuDNN frontend APIs, CuDNN v8.1 or later "
5142       "is required.");
5143 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5144 }
5145 
5146 port::Status CudnnSupport::DoPrepareForCtcLoss(
5147     Stream* stream, dnn::DataType element_type,
5148     const dnn::RnnStateTensorDescriptor& probs_desc,
5149     const dnn::RnnStateTensorDescriptor& grads_desc,
5150     absl::Span<const int> labels_data,
5151     absl::Span<const int> labels_lengths_data,
5152     absl::Span<const int> input_lengths_data,
5153     ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
5154     int* ctc_loss_algo_id) {
5155   auto cudnn = cudnn_->GetHandle(parent_, stream);
5156   // Query the workspace size.
5157   size_t workspace_size_in_bytes = 0;
5158 #if CUDNN_VERSION >= 7603
5159   CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
5160   const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
5161       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
5162   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
5163       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
5164 
5165   // Try running with `algo`; if that succeeds, pick it. The non-deterministic
5166   // algorithm is tried first and is thus preferred when determinism is not
5167   // required.
5168   auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
5169                                         : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
5170   cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
5171       /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
5172       /*gradientsDesc=*/cudnn_grads_desc.handle(),
5173       /*labels=*/labels_data.data(),
5174       /*labelLengths=*/labels_lengths_data.data(),
5175       /*inputLengths=*/input_lengths_data.data(),
5176       /*algo=*/algo,
5177       /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
5178       /*sizeInBytes=*/&workspace_size_in_bytes);
5179   if (RequireCudnnDeterminism()) {
5180     RETURN_IF_CUDNN_ERROR(status);
5181   }
5182 
5183   if (status != CUDNN_STATUS_SUCCESS) {
5184     algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
5185     RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
5186         /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
5187         /*gradientsDesc=*/cudnn_grads_desc.handle(),
5188         /*labels=*/labels_data.data(),
5189         /*labelLengths=*/labels_lengths_data.data(),
5190         /*inputLengths=*/input_lengths_data.data(),
5191         /*algo=*/algo,
5192         /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
5193         /*sizeInBytes=*/&workspace_size_in_bytes));
5194   }
5195   *ctc_loss_algo_id = algo;
5196 #else
5197   return port::Status(port::error::INVALID_ARGUMENT,
5198                       "No supported cudnnGetCTCLossWorkspaceSize when "
5199                       "CUDNN_VERSION < 7.6.3");
5200 #endif
5201   // Allocate the workspace.
5202   if (workspace_size_in_bytes == 0) {
5203     *scratch_memory = DeviceMemory<uint8>();
5204     return port::Status::OK();
5205   }
5206   const auto scratch_or =
5207       scratch_allocator->AllocateBytes(workspace_size_in_bytes);
5208   if (scratch_or.ok()) {
5209     *scratch_memory = scratch_or.ValueOrDie();
5210     return port::Status::OK();
5211   }
5212   return port::InternalError(
5213       "Failed to allocate scratch memory for the CuDNN CTC Loss");
5214 }
5215 
5216 port::Status CudnnSupport::DoCtcLoss(
5217     Stream* stream, dnn::DataType element_type,
5218     const dnn::RnnStateTensorDescriptor& probs_desc,
5219     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
5220     absl::Span<const int> labels_lengths_data,
5221     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
5222     const dnn::RnnStateTensorDescriptor& grads_desc,
5223     DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
5224     int ctc_loss_algo_id) {
5225   // Current cuDNN CTC Loss only supports the float datatype
5226   if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
5227     return port::Status(port::error::INVALID_ARGUMENT,
5228                         "CudnnCtcLossDescriptor is supported only when the "
5229                         "CUDNN_VERSION >= 7.6.3 and DataType is float");
5230   }
5231   CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
5232   const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
5233       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
5234   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
5235       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
5236   return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
5237                        labels_lengths_data, input_lengths_data, costs_data,
5238                        cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
5239                        scratch_memory, ctc_loss_algo_id);
5240 }
5241 
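// Illustrative two-phase CTC-loss usage of the two functions above (not part
// of this file; the descriptors, spans, and scratch allocator are assumed to
// be set up by the caller):
//
//   DeviceMemory<uint8> scratch;
//   int ctc_algo_id = 0;
//   SE_RETURN_IF_ERROR(dnn->DoPrepareForCtcLoss(
//       stream, dnn::DataType::kFloat, probs_desc, grads_desc, labels,
//       label_lengths, input_lengths, scratch_allocator, &scratch,
//       &ctc_algo_id));
//   SE_RETURN_IF_ERROR(dnn->DoCtcLoss(
//       stream, dnn::DataType::kFloat, probs_desc, probs_data, labels,
//       label_lengths, input_lengths, costs_data, grads_desc, grads_data,
//       scratch, ctc_algo_id));
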
5242 bool CudnnSupport::DoTransformTensor(Stream* stream,
5243                                      const dnn::BatchDescriptor& input_desc,
5244                                      dnn::DataType input_type,
5245                                      const DeviceMemoryBase& input_data,
5246                                      const dnn::BatchDescriptor& output_desc,
5247                                      dnn::DataType output_type, float scale,
5248                                      DeviceMemoryBase* output_data) {
5249   float beta = 0.0f;
5250   CudnnTensorDescriptor input_tensor_desc(
5251       input_desc, ToCudnnDataType(input_type, input_desc.layout()));
5252   CudnnTensorDescriptor output_tensor_desc(
5253       output_desc, ToCudnnDataType(output_type, output_desc.layout()));
5254   auto cudnn = cudnn_->GetHandle(parent_, stream);
5255   const auto status = [&] {
5256     RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
5257         cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
5258         &beta, output_tensor_desc.handle(), output_data->opaque()));
5259     return port::Status::OK();
5260   }();
5261   return IsStatusOk(status, /*report_error=*/true);
5262 }
5263 
5264 bool CudnnSupport::DoMatMul(Stream* stream,
5265                             const DeviceMemory<float>& input_data,
5266                             const DeviceMemory<float>& weights,
5267                             const dnn::BatchDescriptor& input_dimensions,
5268                             const dnn::BatchDescriptor& output_dimensions,
5269                             DeviceMemory<float>* output_data) {
5270   if (input_dimensions.count() != output_dimensions.count()) {
5271     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
5272     return false;
5273   }
5274 
5275   // We do not permute the input or output, instead we just
5276   // reinterpret the layout. We are working with row-major matrices
5277   // and the rows of the input and output correspond to batch, so
5278   // batch has to be outermost in both the input and output.
5279   //
5280   // By adding transposes to the BLAS gemm call we could perhaps make
5281   // the kYXDepthBatch layout work as well, but there has been no need
5282   // for that so far.
5283   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5284       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5285     LOG(ERROR) << "Unsupported MatMul input layout.";
5286     return false;
5287   }
5288   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5289       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5290     LOG(ERROR) << "Unsupported MatMul output layout.";
5291     return false;
5292   }
5293 
5294   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
5295     // This is a fast path that also supports the kBatchYXDepth layout.
5296 
5297     // The matrices here are in row-major format while BLAS expects
5298     // column-major, i.e. our matrices are transposed as far as BLAS
5299     // is concerned. So we need to compute output^T =
5300     // input^T*weights^T. There is no parameter for transposing the
5301     // output in BLAS gemm, but instead we can transpose both sides of
5302     // the equality to see that this is equivalent to
5303     // output=weights*input. So we only need to swap the order of
5304     // weights and input in the matrix product to correct for the
5305     // row-major versus column-major difference.
5306     const int64_t m = output_dimensions.NodesAcrossFeatureMaps();
5307     const int64_t n = input_dimensions.count();
5308     const int64_t k = input_dimensions.NodesAcrossFeatureMaps();
5309     if (!stream
5310              ->ThenBlasGemm(blas::Transpose::kNoTranspose,
5311                             blas::Transpose::kNoTranspose, m, n, k, weights, m,
5312                             input_data, k, output_data, m)
5313              .ok()) {
5314       return false;
5315     }
5316   } else {
5317     // This is a slower and more complex path that supports output
5318     // width() * height() > 1, though it only supports the
5319     // kBatchYXDepth layout. Does support kBatchDepthYX if output
5320     // feature_map_count() == 1, as then there is no difference
5321     // between the two layouts.
5322     //
5323     // The operation here is the same as above, except that we have to
5324     // do the matrix multiplication for each (y,x) output coordinate
5325     // separately. We then interpret weights as containing K = width()
5326     // * height() different matrices, which we all multiply onto the
5327     // matrix from input_data, yielding K matrix products. We then
5328     // combine these together into one matrix by concatenating all the
5329     // first rows of these matrices, then all the second rows and so
5330     // on. We can do this with a batched matrix multiplication, where
5331     // the result is written to a different submatrix of the output
5332     // for each matrix multiplication.
5333     //
5334     // The reason that we only support the kBatchYXDepth output layout
5335     // is that we have to do something in the depth for each (y,x)
5336     // coordinate. The kBatchYXDepth layout has the depth information
5337     // for each point (y,x) in contiguous memory while the
5338     // kBatchDepthYX layout does not.
5339     //
5340     // TODO(broune): Consider a special case for when output depth ==
5341     // 1, as then possibly this could all be done as one matrix
5342     // multiplication instead of a batched one, which should be
5343     // faster. Another possibility would be to add a weights layout
5344     // parameter and then support kBatchDepthYX for a different
5345     // weights layout.
5346     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5347         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
5348           output_dimensions.feature_map_count() == 1)) {
5349       LOG(ERROR) << "Unsupported MatMul output layout.";
5350       return false;
5351     }
5352 
5353     const float alpha = 1.0f;  // Take the matrix product without scaling it.
5354     const float beta = 0.0f;   // Ignore the original values in output_data.
5355     const uint64 m = output_dimensions.feature_map_count();
5356     const uint64 n = input_dimensions.count();
5357     const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
5358     const int lda = m;
5359     const int ldb = k;
5360     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
5361     const int batch_count = output_dimensions.NodesPerFeatureMap();
5362 
5363     std::vector<DeviceMemory<float>> a(batch_count);
5364     std::vector<DeviceMemory<float>> b(batch_count);
5365     std::vector<DeviceMemory<float>> c(batch_count);
5366     for (int i = 0; i < batch_count; ++i) {
5367       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
5368                                  output_dimensions.feature_map_count();
5369       a[i] = DeviceMemory<float>::MakeFromByteSize(
5370           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
5371               weights_offset,
5372           weights.ElementCount() - weights_offset);
5373 
5374       b[i] = input_data;
5375 
5376       const int output_offset = i * output_dimensions.feature_map_count();
5377       c[i] = DeviceMemory<float>::MakeFromByteSize(
5378           const_cast<float*>(
5379               reinterpret_cast<const float*>(output_data->opaque())) +
5380               output_offset,
5381           output_data->ElementCount() - output_offset);
5382     }
5383     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
5384       std::vector<DeviceMemory<float>*> ptrs;
5385       ptrs.reserve(v.size());
5386       for (auto& mem : v) {
5387         ptrs.push_back(&mem);
5388       }
5389       return ptrs;
5390     };
5391 
5392     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
5393                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
5394                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
5395                                 ldc, batch_count);
5396   }
5397 
5398   return stream->ok();
5399 }
5400 
5401 bool CudnnSupport::DoBiasAdd(Stream* stream,
5402                              const DeviceMemory<float>& input_data,
5403                              const DeviceMemory<float>& biases,
5404                              const dnn::BatchDescriptor& dimensions,
5405                              DeviceMemory<float>* output_data) {
5406   CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
5407 
5408   dnn::BatchDescriptor bias_dimensions;
5409   bias_dimensions.set_count(1)
5410       .set_feature_map_count(dimensions.feature_map_count())
5411       .set_height(1)
5412       .set_width(1)
5413       .set_layout(dnn::DataLayout::kBatchYXDepth);
5414   CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
5415 
5416   // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
5417   // output_data before doing the addition, unless the input and
5418   // output are at the same address.
5419   if (input_data.opaque() != output_data->opaque()) {
5420     stream->ThenMemcpy(output_data, input_data,
5421                        dimensions.ElementCount() * sizeof(float));
5422     if (!stream->ok()) {
5423       LOG(ERROR)
5424           << "stream " << stream
5425           << " could not enqueue a tensor copy as part of bias addition.";
5426       return false;
5427     }
5428   }
5429 
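       // cudnnAddTensor computes C = alpha * A + beta * C in place on C, so
       // with alpha == 1 and beta == 1 the broadcast bias is accumulated onto
       // the copy of input_data already present in output_data.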
5430   const float alpha = 1.0f;
5431   const float beta = 1.0f;
5432 
5433   auto cudnn = cudnn_->GetHandle(parent_, stream);
5434 
5435   const auto status = [&] {
5436     RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
5437         cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
5438         &beta, input_descriptor.handle(), output_data->opaque()));
5439     return port::Status::OK();
5440   }();
5441   return IsStatusOk(status, /*report_error=*/true);
5442 }
5443 
5444 bool CudnnSupport::DoActivate(Stream* stream,
5445                               dnn::ActivationMode activation_mode,
5446                               const dnn::BatchDescriptor& dimensions,
5447                               const DeviceMemory<float>& input_data,
5448                               DeviceMemory<float>* output_data,
5449                               uint64 options) {
5450   CudnnActivationDescriptor activation_desc(
5451       activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
5452 
5453   CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
5454   // Alpha is the input scaling factor.
5455   float alpha = 1.0;
5456   // Beta is the output scaling factor.
5457   float beta = 0.0;
5458 
5459   auto cudnn = cudnn_->GetHandle(parent_, stream);
5460   const auto status = [&] {
5461     RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
5462         cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
5463         input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
5464     return port::Status::OK();
5465   }();
5466   return IsStatusOk(status, /*report_error=*/true);
5467 }
5468 
5469 bool CudnnSupport::DoPoolForward(
5470     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5471     const dnn::BatchDescriptor& input_dimensions,
5472     const DeviceMemory<double>& input_data,
5473     const dnn::BatchDescriptor& output_dimensions,
5474     DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
5475   // Alpha is the scaling factor for input.
5476   double alpha = 1.0;
5477   // Beta is the scaling factor for output.
5478   double beta = 0.0;
5479 
5480   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
5481   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
5482   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5483 
5484   auto cudnn = cudnn_->GetHandle(parent_, stream);
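       // cudnnPoolingForward computes dest = alpha * pool(src) + beta * dest,
       // so alpha == 1 and beta == 0 simply overwrite the destination.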
5485   const auto status = [&] {
5486     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
5487         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
5488         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
5489     return port::Status::OK();
5490   }();
5491   return IsStatusOk(status, /*report_error=*/true);
5492 }
5493 
5494 bool CudnnSupport::DoPoolForward(
5495     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5496     const dnn::BatchDescriptor& input_dimensions,
5497     const DeviceMemory<float>& input_data,
5498     const dnn::BatchDescriptor& output_dimensions,
5499     DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
5500   // Alpha is the scaling factor for input.
5501   float alpha = 1.0;
5502   // Beta is the scaling factor for output.
5503   float beta = 0.0;
5504 
5505   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
5506   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
5507   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5508 
5509   auto cudnn = cudnn_->GetHandle(parent_, stream);
5510   const auto status = [&] {
5511     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
5512         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
5513         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
5514     return port::Status::OK();
5515   }();
5516   return IsStatusOk(status, /*report_error=*/true);
5517 }
5518 
5519 bool CudnnSupport::DoPoolForward(
5520     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5521     const dnn::BatchDescriptor& input_dimensions,
5522     const DeviceMemory<Eigen::half>& input_data,
5523     const dnn::BatchDescriptor& output_dimensions,
5524     DeviceMemory<Eigen::half>* output_data,
5525     ScratchAllocator* workspace_allocator) {
5526   // Alpha is the scaling factor for input.
5527   float alpha = 1.0;
5528   // Beta is the scaling factor for output.
5529   float beta = 0.0;
5530 
5531   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
5532   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
5533   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5534   auto cudnn = cudnn_->GetHandle(parent_, stream);
5535   const auto status = [&] {
5536     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
5537         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
5538         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
5539     return port::Status::OK();
5540   }();
5541   return IsStatusOk(status, /*report_error=*/true);
5542 }
5543 
5544 bool CudnnSupport::DoPoolForward(
5545     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5546     const dnn::BatchDescriptor& input_dimensions,
5547     const DeviceMemory<int8>& input_data,
5548     const dnn::BatchDescriptor& output_dimensions,
5549     DeviceMemory<int8>* output_data, ScratchAllocator* workspace_allocator) {
5550   // Alpha is the scaling factor for input.
5551   float alpha = 1.0;
5552   // Beta is the scaling factor for output.
5553   float beta = 0.0;
5554 
5555   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_INT8);
5556   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_INT8);
5557   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5558 
5559   auto cudnn = cudnn_->GetHandle(parent_, stream);
5560   const auto status = [&] {
5561     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
5562         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
5563         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
5564     return port::Status::OK();
5565   }();
5566   return IsStatusOk(status, /*report_error=*/true);
5567 }
5568 
5569 bool CudnnSupport::DoPoolBackward(
5570     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5571     const dnn::BatchDescriptor& input_dimensions,
5572     const DeviceMemory<double>& input_data,
5573     const dnn::BatchDescriptor& output_dimensions,
5574     const DeviceMemory<double>& output_data,
5575     const DeviceMemory<double>& input_diff_data,
5576     DeviceMemory<double>* output_diff_data,
5577     ScratchAllocator* workspace_allocator) {
5578   // Alpha is the scaling factor for input.
5579   double alpha = 1.0;
5580   // Beta is the scaling factor for output.
5581   double beta = 0.0;
5582 
5583   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
5584   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
5585   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5586 
5587   auto cudnn = cudnn_->GetHandle(parent_, stream);
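       // cudnnPoolingBackward takes (yDesc, y, dyDesc, dy, xDesc, x, dxDesc,
       // dx): output_data / input_diff_data are the forward output and its
       // gradient, and input_data / output_diff_data are the forward input
       // and the gradient being produced.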
5588   const auto status = [&] {
5589     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
5590         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
5591         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
5592         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
5593         output_diff_data->opaque()));
5594     return port::Status::OK();
5595   }();
5596   return IsStatusOk(status, /*report_error=*/true);
5597 }
5598 
5599 bool CudnnSupport::DoPoolBackward(
5600     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5601     const dnn::BatchDescriptor& input_dimensions,
5602     const DeviceMemory<float>& input_data,
5603     const dnn::BatchDescriptor& output_dimensions,
5604     const DeviceMemory<float>& output_data,
5605     const DeviceMemory<float>& input_diff_data,
5606     DeviceMemory<float>* output_diff_data,
5607     ScratchAllocator* workspace_allocator) {
5608   // Alpha is the scaling factor for input.
5609   float alpha = 1.0;
5610   // Beta is the scaling factor for output.
5611   float beta = 0.0;
5612 
5613   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
5614   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
5615   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5616 
5617   auto cudnn = cudnn_->GetHandle(parent_, stream);
5618   const auto status = [&] {
5619     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
5620         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
5621         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
5622         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
5623         output_diff_data->opaque()));
5624     return port::Status::OK();
5625   }();
5626   return IsStatusOk(status, /*report_error=*/true);
5627 }
5628 
5629 bool CudnnSupport::DoPoolBackward(
5630     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5631     const dnn::BatchDescriptor& input_dimensions,
5632     const DeviceMemory<Eigen::half>& input_data,
5633     const dnn::BatchDescriptor& output_dimensions,
5634     const DeviceMemory<Eigen::half>& output_data,
5635     const DeviceMemory<Eigen::half>& input_diff_data,
5636     DeviceMemory<Eigen::half>* output_diff_data,
5637     ScratchAllocator* workspace_allocator) {
5638   // Alpha is the scaling factor for input.
5639   float alpha = 1.0;
5640   // Beta is the scaling factor for output.
5641   float beta = 0.0;
5642 
5643   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
5644   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
5645   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5646 
5647   auto cudnn = cudnn_->GetHandle(parent_, stream);
5648   const auto status = [&] {
5649     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
5650         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
5651         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
5652         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
5653         output_diff_data->opaque()));
5654     return port::Status::OK();
5655   }();
5656   return IsStatusOk(status, /*report_error=*/true);
5657 }
5658 
5659 bool CudnnSupport::DoNormalizeWithDimensions(
5660     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
5661     const dnn::BatchDescriptor& dimensions,
5662     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
5663   // Check for unsupported modes.
5664   if (normalize_descriptor.wrap_around()) {
5665     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
5666     return false;
5667   }
5668   if (normalize_descriptor.segment_size()) {
5669     LOG(ERROR) << "CUDA LRN does not support segmentation";
5670     return false;
5671   }
5672 
5673   CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
5674   CudnnNormalizeDescriptor normalize(normalize_descriptor);
5675 
5676   // Alpha is the scaling factor for input.
5677   float alpha = 1.0f;
5678   // Beta is the scaling factor for output.
5679   float beta = 0.0f;
5680 
5681   auto cudnn = cudnn_->GetHandle(parent_, stream);
5682 
5683   // Launch the normalization.
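       // CUDNN_LRN_CROSS_CHANNEL_DIM1 performs the normalization across the
       // channel (feature map) dimension of the tensor.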
5684   const auto status = [&] {
5685     RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
5686         cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
5687         &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
5688         output_data->opaque()));
5689     return port::Status::OK();
5690   }();
5691   return IsStatusOk(status, /*report_error=*/true);
5692 }
5693 
5694 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
5695     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
5696     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
5697     const DeviceMemory<float>& normalized_data,
5698     const DeviceMemory<float>& normalized_variable_gradient,
5699     DeviceMemory<float>* raw_variable_gradient,
5700     ScratchAllocator* workspace_allocator) {
5701   // Check for unsupported modes.
5702   if (normalize_descriptor.wrap_around()) {
5703     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
5704     return false;
5705   }
5706   if (normalize_descriptor.segment_size()) {
5707     LOG(ERROR) << "CUDA LRN does not support segmentation";
5708     return false;
5709   }
5710 
5711   CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
5712   CudnnNormalizeDescriptor normalize(normalize_descriptor);
5713 
5714   float alpha = 1.0f;
5715   float beta = 0.0f;
5716 
5717   auto cudnn = cudnn_->GetHandle(parent_, stream);
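       // cudnnLRNCrossChannelBackward takes (yDesc, y, dyDesc, dy, xDesc, x,
       // dxDesc, dx): normalized_data is the forward output y,
       // normalized_variable_gradient is dy, raw_data is the forward input x,
       // and raw_variable_gradient receives dx.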
5718   const auto status = [&] {
5719     RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
5720         cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
5721         &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
5722         normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
5723         &beta, dims.handle(), raw_variable_gradient->opaque()));
5724     return port::Status::OK();
5725   }();
5726   return IsStatusOk(status, /*report_error=*/true);
5727 }
5728 
5729 bool CudnnSupport::DoDepthConcatenate(
5730     Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
5731     port::ArraySlice<const DeviceMemory<float>*> input_data,
5732     DeviceMemory<float>* output_data) {
5733   CHECK_EQ(input_dimensions.size(), input_data.size());
5734 
5735   for (const auto& dimensions : input_dimensions) {
5736     if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5737       LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
5738                     "supports the kBatchDepthYX layout.";
5739       return false;
5740     }
5741   }
5742 
5743   if (input_dimensions.empty()) {
5744     return true;  // Nothing to do.
5745   }
5746 
5747   dnn::BatchDescriptor output_dimensions =
5748       dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
5749 
5750   const int64_t area = output_dimensions.width() * output_dimensions.height();
5751   const auto index = [area](int64_t batch, int64_t depth, int64_t yx,
5752                             int64_t max_depth) {
5753     return (batch * max_depth + depth) * area + yx;
5754   };
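       // For example (hypothetical shapes): concatenating two inputs with
       // depths 2 and 3 gives an output depth of 5, and element
       // (batch, depth, yx) of the second input lands at host offset
       // index(batch, depth + 2, yx, /*max_depth=*/5) in output_host.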
5755 
5756   std::vector<float> output_host(output_dimensions.ElementCount());
5757   std::vector<float> tmp;
5758   int64_t depth_sum = 0;
5759   for (size_t i = 0; i < input_data.size(); ++i) {
5760     const auto& dimensions = input_dimensions[i];
5761     tmp.resize(dimensions.ElementCount());
5762     stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
5763     port::Status block_status = stream->BlockHostUntilDone();
5764     if (!block_status.ok()) {
5765       LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
5766       return false;
5767     }
5768 
5769     for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) {
5770       for (int64_t yx = 0; yx < area; ++yx) {
5771         for (int64_t depth = 0; depth < dimensions.feature_map_count();
5772              ++depth) {
5773           VLOG(2) << output_dimensions.ElementCount() << ' ' << batch << ' '
5774                   << yx << ' ' << depth;
5775           output_host[index(batch, depth + depth_sum, yx,
5776                             output_dimensions.feature_map_count())] =
5777               tmp[index(batch, depth, yx, dimensions.feature_map_count())];
5778         }
5779       }
5780     }
5781     depth_sum += dimensions.feature_map_count();
5782   }
5783   stream->ThenMemcpyH2D<float>(output_host, output_data);
5784   return true;
5785 }
5786 
5787 bool CudnnSupport::DoElementwiseOperate(
5788     Stream* stream, dnn::ElementwiseOperation operation,
5789     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
5790     port::ArraySlice<const DeviceMemory<float>*> input_data,
5791     const dnn::BatchDescriptor& output_dimensions,
5792     DeviceMemory<float>* output_data) {
5793   LOG(FATAL) << "not yet implemented";  // TODO(leary)
5794   return false;
5795 }
5796 
5797 bool CudnnSupport::DoXYPad(Stream* stream,
5798                            const dnn::BatchDescriptor& dimensions,
5799                            const DeviceMemory<float>& input_data,
5800                            int64_t left_pad, int64_t right_pad, int64_t top_pad,
5801                            int64_t bottom_pad,
5802                            DeviceMemory<float>* output_data) {
5803   LOG(FATAL) << "not yet implemented";  // TODO(leary)
5804   return false;
5805 }
5806 
5807 bool CudnnSupport::DoXYSlice(Stream* stream,
5808                              const dnn::BatchDescriptor& dimensions,
5809                              const DeviceMemory<float>& input_data,
5810                              int64_t left_trim, int64_t right_trim,
5811                              int64_t top_trim, int64_t bottom_trim,
5812                              DeviceMemory<float>* output_data) {
5813   LOG(FATAL) << "not yet implemented";  // TODO(leary)
5814   return false;
5815 }
5816 
5817 bool CudnnSupport::DoMemcpyD2HQuantized(
5818     Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
5819     dnn::QuantizedActivationMode mode, void* host_dst, int64_t size) {
5820   LOG(ERROR) << "quantized memcpy not supported by cuDNN";
5821   return false;
5822 }
5823 
5824 bool CudnnSupport::DoMemcpyH2DQuantized(
5825     Stream* stream, const void* host_src, int64_t size,
5826     dnn::QuantizedActivationMode mode,
5827     DeviceMemory<float>* gpu_unquantized_dst) {
5828   LOG(ERROR) << "quantized memcpy not supported by cuDNN";
5829   return false;
5830 }
5831 
5832 bool CudnnSupport::DeriveOutputBatchDescriptor(
5833     const dnn::BatchDescriptor& batch_descriptor,
5834     const dnn::FilterDescriptor& filter_descriptor,
5835     const dnn::ConvolutionDescriptor& convolution_descriptor,
5836     dnn::BatchDescriptor* output_batch_descriptor) {
5837   CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
5838   CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
5839   CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
5840 
5841   int dn = batch_descriptor.ndims() + 2;
5842   std::vector<int> dims(dn);  // in BDYX
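       // cudnnGetConvolutionNdForwardOutputDim returns the dims as
       // {N, C, spatial...}; the spatial entries are read back in reverse
       // below because dnn::DimIndex counts from the innermost dimension (X)
       // outward.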
5843   const auto status = [&] {
5844     RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
5845         conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
5846     output_batch_descriptor->set_count(dims[0])
5847         .set_feature_map_count(dims[1])
5848         .set_layout(batch_descriptor.layout());
5849 
5850     for (int i = 0; i < batch_descriptor.ndims(); i++) {
5851       output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
5852                                                dims.rbegin()[i]);
5853     }
5854     return port::Status::OK();
5855   }();
5856   return IsStatusOk(status, /*report_error=*/true);
5857 }
5858 
5859 }  // namespace gpu
5860 
5861 void initialize_cudnn() {
5862   port::Status status =
5863       PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
5864           cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
5865           [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
5866             gpu::GpuExecutor* cuda_executor =
5867                 dynamic_cast<gpu::GpuExecutor*>(parent);
5868             if (cuda_executor == nullptr) {
5869               LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
5870                          << "support library with a non-CUDA StreamExecutor";
5871               return nullptr;
5872             }
5873 
5874             gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
5875             if (!dnn->Init().ok()) {
5876               // Note: Init() will log a more specific error.
5877               delete dnn;
5878               return nullptr;
5879             }
5880             return dnn;
5881           });
5882 
5883   if (!status.ok()) {
5884     LOG(ERROR) << "Unable to register cuDNN factory: "
5885                << status.error_message();
5886   }
5887 
5888   PluginRegistry::Instance()->SetDefaultFactory(
5889       cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
5890 }
5891 
5892 }  // namespace stream_executor
5893 
5894 #pragma clang diagnostic pop
5895 
5896 REGISTER_MODULE_INITIALIZER(register_cudnn,
5897                             { stream_executor::initialize_cudnn(); });
5898