1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <optional>
21 #include <utility>
22 
23 #include "absl/base/thread_annotations.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "third_party/eigen3/Eigen/Core"
29 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
30 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.h"
31 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.h"
32 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.h"
33 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h"
34 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h"
35 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_timer.h"
36 #include "tensorflow/compiler/xla/stream_executor/cuda/cudnn_version.h"
37 #include "tensorflow/compiler/xla/stream_executor/dnn.h"
38 #include "tensorflow/compiler/xla/stream_executor/lib/env.h"
39 #include "tensorflow/compiler/xla/stream_executor/lib/error.h"
40 #include "tensorflow/compiler/xla/stream_executor/lib/initialize.h"
41 #include "tensorflow/compiler/xla/stream_executor/lib/mathutil.h"
42 #include "tensorflow/compiler/xla/stream_executor/lib/threadpool.h"
43 #include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
44 #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
45 #include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h"
46 #include "tensorflow/compiler/xla/stream_executor/stream.h"
47 #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
48 #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/platform/tensor_float_32_utils.h"
51 #include "tensorflow/core/util/determinism.h"
52 #include "tensorflow/core/util/env_var.h"
53 #include "tensorflow/core/util/use_cudnn.h"
54 // clang-format off
55 #include "third_party/gpus/cudnn/cudnn.h"
56 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
57 #include "third_party/cudnn_frontend/include/cudnn_frontend.h"
58 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
59 #include "absl/strings/string_view.h"
60 // clang-format on
61 
62 #pragma clang diagnostic push
63 
64 // Make sure that the Eigen::half forward declaration in dnn.h matches the
65 // declaration in Eigen.
66 #pragma clang diagnostic warning "-Wmismatched-tags"
67 
68 namespace stream_executor {
69 namespace gpu {
70 
71 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
72 
73 namespace {
74 
75 static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher");
76 
77 // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
78 #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
79 
80 // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
81 // function with a non-successful port::Status.
82 #define RETURN_IF_CUDNN_ERROR(expr)                                     \
83   do {                                                                  \
84     cudnnStatus_t _status = (expr);                                     \
85     if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {            \
86       std::ostringstream oss;                                           \
87       oss << CudnnStatusToString(_status) << "\nin " << __FILE__ << "(" \
88           << __LINE__ << "): '" << #expr << "'";                        \
89       return port::Status(port::error::UNKNOWN, oss.str());             \
90     }                                                                   \
91   } while (false)
92 
93 #define RETURN_MSG_IF_CUDNN_ERROR(expr)                                 \
94   do {                                                                  \
95     cudnnStatus_t _status = (expr).get_status();                        \
96     if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {            \
97       std::ostringstream oss;                                           \
98       oss << CudnnStatusToString(_status) << "\nin " << __FILE__ << "(" \
99           << __LINE__ << "): '" << #expr << "' " << (expr).get_error(); \
100       return port::Status(port::error::UNKNOWN, oss.str());             \
101     }                                                                   \
102   } while (false)
103 
104 #define RETURN_FALSE_IF_CUDNN_ERROR(expr)                                \
105   do {                                                                   \
106     if (!SE_PREDICT_TRUE((expr).get_status() == CUDNN_STATUS_SUCCESS)) { \
107       return false;                                                      \
108     }                                                                    \
109   } while (false)
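// Illustrative use of the macros above, assuming a hypothetical helper that
// returns port::Status:
//
//   port::Status SetCudnnStream(cudnnHandle_t handle, cudaStream_t stream) {
//     RETURN_IF_CUDNN_ERROR(cudnnSetStream(handle, stream));
//     return ::tensorflow::OkStatus();
//   }
//
// Any failing call is converted into a port::Status carrying the cuDNN status
// string plus the file, line, and expression text.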
110 
111 // Converts (via narrowing) a value of type WideT to type NarrowT, and checks
112 // that the value is unchanged by the conversion.
113 template <typename WideT, typename NarrowT>
114 NarrowT CheckedNarrowing(const WideT& wide) {
115   NarrowT narrow = wide;
116   CHECK_EQ(narrow, wide)
117       << "checked narrowing failed; values not equal post-conversion";
118   return narrow;
119 }
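// For example, CheckedNarrowing<int64_t, int>(5) returns 5, whereas a value
// that does not fit in int (e.g. int64_t{1} << 40) would trip the CHECK above.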
120 
121 std::string CudnnStatusToString(cudnnStatus_t status) {
122   switch (status) {
123     case CUDNN_STATUS_SUCCESS:
124       return "CUDNN_STATUS_SUCCESS";
125     case CUDNN_STATUS_NOT_INITIALIZED:
126       return "CUDNN_STATUS_NOT_INITIALIZED";
127     case CUDNN_STATUS_ALLOC_FAILED:
128       return "CUDNN_STATUS_ALLOC_FAILED";
129     case CUDNN_STATUS_BAD_PARAM:
130       return "CUDNN_STATUS_BAD_PARAM";
131     case CUDNN_STATUS_INTERNAL_ERROR:
132       return "CUDNN_STATUS_INTERNAL_ERROR";
133     case CUDNN_STATUS_INVALID_VALUE:
134       return "CUDNN_STATUS_INVALID_VALUE";
135     case CUDNN_STATUS_ARCH_MISMATCH:
136       return "CUDNN_STATUS_ARCH_MISMATCH";
137     case CUDNN_STATUS_MAPPING_ERROR:
138       return "CUDNN_STATUS_MAPPING_ERROR";
139     case CUDNN_STATUS_EXECUTION_FAILED:
140       return "CUDNN_STATUS_EXECUTION_FAILED";
141     case CUDNN_STATUS_NOT_SUPPORTED:
142       return "CUDNN_STATUS_NOT_SUPPORTED";
143     case CUDNN_STATUS_LICENSE_ERROR:
144       return "CUDNN_STATUS_LICENSE_ERROR";
145     case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
146       return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
147     case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
148       return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
149     case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
150       return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
151     default:
152       return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
153                           ">");
154   }
155 }
156 
157 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
158 //
159 // See CudnnAccess::GetHandle() for details.
160 class CudnnHandle {
161  public:
162   // Takes ownership of the executor context and the lock to access cuDNN
163   // using handle.
164   CudnnHandle(gpu::ScopedActivateExecutorContext context,
165               std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
166       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
167 
168   // Returns the cuDNN handle. Pass it directly to cuDNN APIs; don't keep a
169   // copy.
170   cudnnHandle_t handle() const { return handle_; }
171 
172  private:
173   gpu::ScopedActivateExecutorContext context_;
174   std::unique_ptr<absl::MutexLock> lock_;
175   cudnnHandle_t handle_;  // Not owned.
176 };
177 
178 }  // namespace
179 
180 // Wraps a cuDNN handle and provides access to it through CudnnHandle
181 // instances, which also locks a mutex, acquires the CUDA context, and sets
182 // the stream that cuDNN should use to enqueue any work.
183 //
184 // Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
185 class CudnnAccess {
186  public:
187   // Takes ownership of the handle.
188   explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
189 
190   ~CudnnAccess() {
191     absl::MutexLock lock(&mutex_);
192     cudnnDestroy(handle_);
193   }
194 
195   // Creates a CudnnHandle instance for stream.
196   //
197   // cuDNN API calls using the same handle instance need to be serialized
198   // across threads. This is guaranteed by CudnnHandle instances locking the
199   // mutex owned by this class.
200   //
201   // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
202   // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
203   // to use the provided stream.
204   //
205   // The stream argument may be null, which translates to the legacy default
206   // stream. See
207   // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
208   // The legacy default stream synchronizes with all other streams, so it is
209   // a bad idea (performance-wise) to call any cuDNN APIs that
210   // enqueue work in the stream.
211   CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
212     auto lock = std::make_unique<absl::MutexLock>(&mutex_);
213     mutex_.AssertHeld();
214     gpu::ScopedActivateExecutorContext context(executor);
215     CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
216     if (!current_stream_ || cu_stream != *current_stream_) {
217       current_stream_ = cu_stream;
218       const auto status = cudnnSetStream(handle_, cu_stream);
219       CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
220     }
221     return CudnnHandle(std::move(context), std::move(lock), handle_);
222   }
223 
224   void NotifyStreamDestroyed(Stream* stream) {
225     CUstream cu_stream = AsGpuStreamValue(stream);
226     absl::MutexLock lock(&mutex_);
227     if (current_stream_ && cu_stream == *current_stream_) {
228       current_stream_.reset();
229     }
230   }
231 
232  private:
233   // Guards current_stream_ and the enqueueing of cuDNN operations via the
234   // handle_ below.
235   absl::Mutex mutex_;
236 
237   // If set, indicates the stream currently active on handle_, to avoid the
238   // overhead of re-setting the same stream unnecessarily.
239   std::optional<CUstream> current_stream_ ABSL_GUARDED_BY(mutex_);
240 
241   // cuDNN library handle.
242   cudnnHandle_t handle_ ABSL_GUARDED_BY(mutex_);  // Owned.
243 };
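// Minimal usage sketch (hypothetical call site): a CudnnHandle is acquired for
// the duration of a single cuDNN call and released when it leaves scope, e.g.
//
//   auto cudnn = cudnn_->GetHandle(parent_, stream);
//   RETURN_IF_CUDNN_ERROR(cudnnSomeHandleTakingApi(cudnn.handle(), ...));
//
// where cudnnSomeHandleTakingApi stands in for any cuDNN API that takes a
// cudnnHandle_t as its first argument.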
244 
245 namespace {
246 
247 // A helper function to return the internal compute type for
248 // RNNs in cudnn.
249 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
250 
251 cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
252   cudnnConvolutionFwdAlgo_t algo =
253       cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
254   switch (algo) {
255     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
256     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
257     case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
258     case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
259     case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
260     case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
261     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
262     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
263       return algo;
264     default:
265       LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
266                  << algorithm.algo_id();
267   }
268 }
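// For example, a dnn::AlgorithmDesc whose algo_id() is 1 maps to
// CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM (the ids are simply the
// cuDNN enum values); ids outside the listed set are fatal by design.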
269 
270 cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
271     dnn::AlgorithmDesc algorithm) {
272   cudnnConvolutionBwdDataAlgo_t algo =
273       cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
274   switch (algo) {
275     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
276     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
277     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
278     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
279     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
280     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
281       return algo;
282     default:
283       LOG(FATAL)
284           << "Unsupported Cudnn convolution backward algorithm for data: "
285           << algorithm.algo_id();
286   }
287 }
288 
289 cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
290     dnn::AlgorithmDesc algorithm) {
291   cudnnConvolutionBwdFilterAlgo_t algo =
292       cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
293   switch (algo) {
294     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
295     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
296     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
297     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
298     // Based on cudnn.h, the following is not implemented.
299     // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
300     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
301     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
302       return algo;
303     default:
304       LOG(FATAL)
305           << "Unsupported Cudnn convolution backward algorithm for filter: "
306           << algorithm.algo_id();
307   }
308 }
309 
310 port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
311   int value;
312   RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
313   return value;
314 }
315 
316 cudnnRNNAlgo_t ToCudnnRNNAlgo(std::optional<dnn::AlgorithmDesc> algorithm) {
317   if (!algorithm.has_value()) {
318     return CUDNN_RNN_ALGO_STANDARD;
319   }
320   cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
321   switch (algo) {
322     case CUDNN_RNN_ALGO_STANDARD:
323     case CUDNN_RNN_ALGO_PERSIST_STATIC:
324     case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
325       return algo;
326     default:
327       LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
328   }
329 }
330 
331 port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
332   TF_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
333   TF_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
334   TF_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
335   return ::tensorflow::OkStatus();
336 }
337 
338 enum class PreloadCudnnType { ConvFwd, ConvBwdFilter, ConvBwdData, Rnn };
339 
340 // Preloads the cuDNN sub-libraries for cuDNN 8.0.4+ so that their load time
341 // isn't counted as part of autotuning.
342 void PreloadCudnnSubLibs(PreloadCudnnType type) {
343 #if CUDNN_VERSION >= 8004
344   switch (type) {
345     case PreloadCudnnType::ConvBwdFilter:
346     case PreloadCudnnType::ConvBwdData: {
347       cudnnOpsTrainVersionCheck();
348       cudnnCnnTrainVersionCheck();
349       [[clang::fallthrough]];
350     }
351     case PreloadCudnnType::ConvFwd: {
352       cudnnOpsInferVersionCheck();
353       cudnnCnnInferVersionCheck();
354       break;
355     }
356     case PreloadCudnnType::Rnn: {
357       cudnnOpsInferVersionCheck();
358       cudnnAdvInferVersionCheck();
359       cudnnOpsTrainVersionCheck();
360       cudnnAdvTrainVersionCheck();
361       break;
362     }
363   }
364 #endif  // CUDNN_VERSION >= 8004
365 }
366 
367 void PreloadCudnnSubLibsHelper(dnn::ConvolutionKind kind) {
368   switch (kind) {
369     case dnn::ConvolutionKind::FORWARD: {
370       PreloadCudnnSubLibs(PreloadCudnnType::ConvFwd);
371       break;
372     }
373     case dnn::ConvolutionKind::BACKWARD_DATA: {
374       PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdData);
375       break;
376     }
377     case dnn::ConvolutionKind::BACKWARD_FILTER: {
378       PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdFilter);
379       break;
380     }
381     default: {
382       LOG(WARNING) << "Unsupported dnn::ConvolutionKind: "
383                    << static_cast<int>(kind) << " for cuDNN preload.";
384       break;
385     }
386   }
387 }
388 
389 }  // namespace
390 
391 CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}
392 
393 port::Status CudnnSupport::Init() {
394   ScopedActivateExecutorContext context(parent_);
395 
396   // Peek at the last error to give more information in cases of errors.
397   cudaError_t cerr = cudaPeekAtLastError();
398   if (cerr != cudaSuccess) {
399     LOG(WARNING) << "There was an error before creating cudnn handle: "
400                  << cudaGetErrorName(cerr) << " : " << cudaGetErrorString(cerr);
401   }
402 
403   cudnnHandle_t cudnn_handle = nullptr;
404   const auto status = cudnnCreate(&cudnn_handle);
405   if (status == CUDNN_STATUS_SUCCESS) {
406     CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
407 
408     CudnnVersion loaded_version;
409     TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
410     if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
411       const std::string error = absl::StrCat(
412           "Loaded runtime CuDNN library: ", loaded_version.ToString(),
413           " but source was compiled with: ", source_version.ToString(),
414           ".  CuDNN library needs to have matching major version and equal or "
415           "higher minor version. If using a binary install, upgrade your CuDNN "
416           "library.  If building from sources, make sure the library loaded at "
417           "runtime is compatible with the version specified during compile "
418           "configuration.");
419       LOG(ERROR) << error;
420       cudnnDestroy(cudnn_handle);
421       return port::Status(port::error::INTERNAL, error);
422     }
423 
424     cudnn_.reset(new CudnnAccess(cudnn_handle));
425 
426     LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion();
427     return ::tensorflow::OkStatus();
428   }
429 
430   CHECK_EQ(cudnn_handle, nullptr);
431   LOG(ERROR) << "Could not create cudnn handle: "
432              << CudnnStatusToString(status);
433   if (status == CUDNN_STATUS_NOT_INITIALIZED) {
434     auto result = gpu::Diagnostician::FindKernelDriverVersion();
435     if (!result.ok()) {
436       LOG(ERROR) << "Error retrieving driver version: "
437                  << cuda::DriverVersionStatusToString(result);
438     } else {
439       const auto& version = result.ValueOrDie();
440       LOG(ERROR) << "Possibly insufficient driver version: "
441                  << cuda::DriverVersionToString(version);
442     }
443   }
444 
445   return port::Status(port::error::INTERNAL,
446                       absl::StrCat("cudnn library could not create a handle: ",
447                                    CudnnStatusToString(status)));
448 }
449 
450 void CudnnSupport::NotifyStreamDestroyed(Stream* stream) /* override */ {
451   cudnn_->NotifyStreamDestroyed(stream);
452 }
453 
454 port::StatusOr<perftools::gputools::dnn::VersionInfo>
455 CudnnSupport::GetVersion() {
456   CudnnVersion version;
457   TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
458   return perftools::gputools::dnn::VersionInfo(
459       version.major_version, version.minor_version, version.patch_level);
460 }
461 
462 namespace {
463 
464 // Deleter functors for cuDNN types that need to be deleted.
465 struct TensorDescriptorDeleter {
466   void operator()(cudnnTensorDescriptor_t descriptor) const {
467     CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
468   }
469 };
470 struct RNNDataDescriptorDeleter {
471   void operator()(cudnnRNNDataDescriptor_t descriptor) const {
472     CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
473   }
474 };
475 struct FilterDescriptorDeleter {
476   void operator()(cudnnFilterDescriptor_t descriptor) const {
477     CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
478   }
479 };
480 struct ConvolutionDescriptorDeleter {
481   void operator()(cudnnConvolutionDescriptor_t descriptor) const {
482     CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
483   }
484 };
485 struct PoolingDescriptorDeleter {
486   void operator()(cudnnPoolingDescriptor_t descriptor) const {
487     CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
488   }
489 };
490 struct LrnDescriptorDeleter {
491   void operator()(cudnnLRNDescriptor_t descriptor) const {
492     CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
493   }
494 };
495 
496 struct ActivationDescriptorDeleter {
497   void operator()(cudnnActivationDescriptor_t descriptor) const {
498     CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
499   }
500 };
501 struct DropoutDescriptorDeleter {
502   void operator()(cudnnDropoutDescriptor_t descriptor) const {
503     CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
504   }
505 };
506 struct RnnDescriptorDeleter {
507   void operator()(cudnnRNNDescriptor_t descriptor) const {
508     CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
509   }
510 };
511 struct PersistentRnnPlanDeleter {
512   void operator()(cudnnPersistentRNNPlan_t plan) const {
513     CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
514   }
515 };
516 #if CUDNN_VERSION >= 7603
517 struct CtcLossDescriptorDeleter {
518   void operator()(cudnnCTCLossDescriptor_t descriptor) const {
519     CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
520   }
521 };
522 #endif
523 
524 // RAII wrappers for cuDNN types.
525 using TensorDescriptor =
526     std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
527 using RNNDataDescriptor =
528     std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
529 using FilterDescriptor =
530     std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
531 using ConvolutionDescriptor =
532     std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
533 using PoolingDescriptor =
534     std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
535 using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
536 using ActivationDescriptor =
537     std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
538 using DropoutDescriptor =
539     std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
540 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
541 using PersistentRnnPlan =
542     std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
543 #if CUDNN_VERSION >= 7603
544 using CtcLossDescriptor =
545     std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
546 #endif
547 
548 // Factory methods for cuDNN types.
549 TensorDescriptor CreateTensorDescriptor() {
550   cudnnTensorDescriptor_t result;
551   CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
552   return TensorDescriptor(result);
553 }
554 RNNDataDescriptor CreateRNNDataDescriptor() {
555   cudnnRNNDataDescriptor_t result;
556   CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
557   return RNNDataDescriptor(result);
558 }
559 FilterDescriptor CreateFilterDescriptor() {
560   cudnnFilterDescriptor_t result;
561   CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
562   return FilterDescriptor(result);
563 }
564 ConvolutionDescriptor CreateConvolutionDescriptor() {
565   cudnnConvolutionDescriptor_t result;
566   CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
567   return ConvolutionDescriptor(result);
568 }
569 PoolingDescriptor CreatePoolingDescriptor() {
570   cudnnPoolingDescriptor_t result;
571   CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
572   return PoolingDescriptor(result);
573 }
574 LrnDescriptor CreateLrnDescriptor() {
575   cudnnLRNDescriptor_t result;
576   CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
577   return LrnDescriptor(result);
578 }
579 ActivationDescriptor CreateActivationDescriptor() {
580   cudnnActivationDescriptor_t result;
581   CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
582   return ActivationDescriptor(result);
583 }
584 DropoutDescriptor CreateDropoutDescriptor() {
585   cudnnDropoutDescriptor_t result;
586   CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
587   return DropoutDescriptor(result);
588 }
589 RnnDescriptor CreateRnnDescriptor() {
590   cudnnRNNDescriptor_t result;
591   CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
592   return RnnDescriptor(result);
593 }
594 #if CUDNN_VERSION >= 7603
595 CtcLossDescriptor CreateCtcLossDescriptor() {
596   cudnnCTCLossDescriptor_t result;
597   CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
598   return CtcLossDescriptor(result);
599 }
600 #endif
601 
602 port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
603     cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
604   cudnnPersistentRNNPlan_t result;
605   RETURN_IF_CUDNN_ERROR(
606       cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
607   return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
608 }
609 
610 // Turns a BatchDescriptor structure into a cudnn tensor handle within a
611 // scope.
612 class CudnnTensorDescriptor {
613  public:
614   CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
615                         cudnnDataType_t elem_type)
616       : handle_(CreateTensorDescriptor()) {
617     switch (batch_descriptor.layout()) {
618       case dnn::DataLayout::kBatchYXDepth:
619       case dnn::DataLayout::kBatchDepthYX: {
620         const int nd = batch_descriptor.ndims() + 2;
621         // cuDNN requires the strides and dims to be ordered as BDYX.
622         std::vector<int64_t> strides64 =
623             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
624         std::vector<int64_t> dims64 =
625             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
626 
627         // cuDNN requires arrays of ints.
628         std::vector<int> strides(nd);
629         std::vector<int> dims(nd);
630         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
631                        &CheckedNarrowing<int64_t, int>);
632         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
633                        &CheckedNarrowing<int64_t, int>);
634         CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
635                                                   dims.data(), strides.data()))
636             << "batch_descriptor: " << batch_descriptor.ToString();
637         break;
638       }
639       case dnn::DataLayout::kBatchDepthYX4:
640       case dnn::DataLayout::kBatchDepthYX32: {
641         auto expected_elem_ty =
642             batch_descriptor.layout() == dnn::DataLayout::kBatchDepthYX4
643                 ? CUDNN_DATA_INT8x4
644                 : CUDNN_DATA_INT8x32;
645         CHECK_EQ(elem_type, expected_elem_ty);
646         CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
647             handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
648             batch_descriptor.count(), batch_descriptor.feature_map_count(),
649             batch_descriptor.height(), batch_descriptor.width()))
650             << "batch_descriptor: " << batch_descriptor.ToString();
651         break;
652       }
653       default:
654         LOG(FATAL) << "Unsupported tensor format "
655                    << DataLayoutString(batch_descriptor.layout());
656         break;
657     }
658   }
659 
660   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
661 
662  private:
663   TensorDescriptor handle_;
664 };
665 
666 // Turns a FilterDescriptor structure into a cudnn filter handle within a
667 // scope.
668 class CudnnFilterDescriptor {
669  public:
670   CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
671                         cudnnDataType_t elem_type)
672       : handle_(CreateFilterDescriptor()) {
673     // TODO(b/23032134): Even if the filter layout is not supported,
674     // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
675     // it does not take layout as an input. Maybe force cuDNN by giving wrong
676     // inputs intentionally?
677     cudnnTensorFormat_t format;
678     switch (filter_descriptor.layout()) {
679       case dnn::FilterLayout::kOutputInputYX:
680         format = CUDNN_TENSOR_NCHW;
681         break;
682       case dnn::FilterLayout::kOutputYXInput:
683         format = CUDNN_TENSOR_NHWC;
684         break;
685       case dnn::FilterLayout::kOutputInputYX4:
686       case dnn::FilterLayout::kOutputInputYX32: {
687         auto expected_elem_ty =
688             filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX4
689                 ? CUDNN_DATA_INT8x4
690                 : CUDNN_DATA_INT8x32;
691         CHECK_EQ(elem_type, expected_elem_ty);
692         format = CUDNN_TENSOR_NCHW_VECT_C;
693         break;
694       }
695       default:
696         LOG(FATAL) << "Unsupported filter format "
697                    << FilterLayoutString(filter_descriptor.layout());
698         break;
699     }
700 
701     std::vector<int> dims(2 + filter_descriptor.ndims());
702     dims[0] = filter_descriptor.output_feature_map_count();
703     dims[1] = filter_descriptor.input_feature_map_count();
704     absl::Span<const int64_t> spatial_dims =
705         filter_descriptor.input_filter_dims();
706     std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
707 
708     CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
709                                               dims.size(), dims.data()));
710   }
711 
712   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
713 
714  private:
715   FilterDescriptor handle_;  // Owned.
716 };
717 
718 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
719 // The errata sheet (JSON format) for marking the cudnn engines that might be
720 // buggy. For example, we don't want the engine 999 of forward convolution:
721 // R"({ "version" : 1,
722 //      "rules"   : [
723 //        { "rule_id"             : "ConvFwd_eng999",
724 //          "operation"           : "ConvFwd",
725 //          "engine"              : 999,
726 //          "knob"                : [],
727 //          "cudnn_version_start" : 8000,
728 //          "cudnn_version_end"   : -1
729 //        }
730 // ]})"
731 // We skip a non-existing eng999 in the static filter as a placeholder.
732 // Additionally, users can specify an additional errata JSON file via
733 // CUDNN_ERRATA_JSON_FILE at runtime.
734 // We are also excluding two flavors of ConvFwd_eng42 due to b/234183340.
735 const json* CudnnExecutionPlanEngineFilterStatic() {
736   static absl::string_view filter_str = R"({
737       "version" : 1,
738         "rules"   : [
739           { "rule_id"             : "ConvFwd_eng999",
740             "operation"           : "ConvFwd",
741             "engine"              : 999,
742             "knob"                : [],
743             "cudnn_version_start" : 8000,
744             "cudnn_version_end"   : -1
745           },
746           { "rule_id"             : "ConvFwd_eng42_k2=2_k4=3_k5=0_k6=0_k7=0",
747             "operation"           : "ConvFwd",
748             "engine"              : 42,
749             "knob"                :
750             {
751                                     "k2" : "2",
752                                     "k4" : "3",
753                                     "k5" : "0",
754                                     "k6" : "0",
755                                     "k7" : "0"
756             },
757             "cudnn_version_start" : 8000,
758             "cudnn_version_end"   : -1
759           },
760           { "rule_id"             : "ConvFwd_eng42_k2=1_k4=3_k5=1_k6=0_k7=0",
761             "operation"           : "ConvFwd",
762             "engine"              : 42,
763             "knob"                :
764             {
765                                     "k2" : "1",
766                                     "k4" : "3",
767                                     "k5" : "1",
768                                     "k6" : "0",
769                                     "k7" : "0"
770             },
771             "cudnn_version_start" : 8000,
772             "cudnn_version_end"   : -1
773           }
774       ]})";
775   static const json* json_handle = new json(json::parse(filter_str));
776   return json_handle;
777 }
778 
779 const json* CudnnExecutionPlanEngineFilterRuntime() {
780   static const json* json_handle = []() -> const json* {
781     json j;
782     if (cudnn_frontend::load_from_config(j, "")) {
783       return new json(j);
784     }
785     return nullptr;
786   }();
787   return json_handle;
788 }
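// For illustration: per the comment above, pointing CUDNN_ERRATA_JSON_FILE at
// a JSON file that follows the same schema as the static filter lets users
// exclude additional engines at runtime without rebuilding.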
789 
790 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
791 
792 // A helper function to decide whether to use
793 // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
794 // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
795 // and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
796 // reason we set it to false by default is that this mode may use scaled
797 // atomic integer reduction that may cause a numerical overflow for certain
798 // input data ranges.
799 // TODO(yangzihao): Use autotune to choose between this mode and
800 // CUDNN_BATCHNORM_SPATIAL mode.
801 bool BatchnormSpatialPersistentEnabled() {
802   static bool is_enabled = [] {
803     bool is_enabled = false;
804     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
805         "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
806         /*default_val=*/false, &is_enabled));
807     return is_enabled;
808   }();
809   return is_enabled;
810 }
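// For example, running with TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1 in the
// environment opts in to the persistent batchnorm mode described above.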
811 
812 bool RequireCudnnDeterminism() {
813   static bool require_cudnn_determinism = [] {
814     // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var.
815     bool cudnn_deterministic = false;
816     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
817                                                /*default_val=*/false,
818                                                &cudnn_deterministic));
819     return cudnn_deterministic;
820   }();
821   return tensorflow::OpDeterminismRequired() || require_cudnn_determinism;
822 }
823 
824 // A helper function to decide whether to force the default conv algorithm.
825 bool ConvUseDefaultAlgorithm() {
826   static bool use_default = [] {
827     bool use_default = false;
828     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_USE_DEFAULT_CONV_ALGO",
829                                                /*default_val=*/false,
830                                                &use_default));
831     return use_default;
832   }();
833   return use_default;
834 }
835 
836 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
837 // within a scope.
838 class CudnnConvolutionDescriptor {
839  public:
840   CudnnConvolutionDescriptor(
841       const dnn::ConvolutionDescriptor& convolution_descriptor,
842       cudnnDataType_t data_type)
843       : handle_(CreateConvolutionDescriptor()) {
844     absl::Span<const int64_t> strides64 = convolution_descriptor.strides();
845     absl::Span<const int64_t> padding64 = convolution_descriptor.padding();
846     absl::Span<const int64_t> dilations64 = convolution_descriptor.dilations();
847     CHECK_NE(convolution_descriptor.pad_alignment(),
848              dnn::PadAlignment::kTensorFlowPadding)
849         << "TensorFlow padding alignment is not supported.";
850 
851     // cuDNN requires arrays of ints.
852     std::vector<int> strides(convolution_descriptor.ndims());
853     std::vector<int> padding(convolution_descriptor.ndims());
854     std::vector<int> dilations(convolution_descriptor.ndims());
855     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
856                    &CheckedNarrowing<int64_t, int>);
857     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
858                    &CheckedNarrowing<int64_t, int>);
859     // TODO(yangzihao): Test with negative dilation to make sure that cudnn
860     // doesn't crash.
861     std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
862                    &CheckedNarrowing<int64_t, int>);
863 
864     CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
865         handle_.get(), convolution_descriptor.ndims(), padding.data(),
866         strides.data(), dilations.data(),
867         convolution_descriptor.convolution_not_crosscorr()
868             ? CUDNN_CONVOLUTION
869             : CUDNN_CROSS_CORRELATION,
870         data_type));
871 
872 #if CUDNN_MAJOR >= 7
873     VLOG(2) << "Requesting grouped convolution: "
874             << convolution_descriptor.group_count();
875     CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
876         handle_.get(), convolution_descriptor.group_count()));
877 #else
878     CHECK_EQ(convolution_descriptor.group_count(), 1)
879         << "Requested grouped convolution for cuDNN version < 7";
880 #endif
881   }
882 
883   void set_use_tensor_op_math(bool use_tensor_op_math) {
884     cudnnMathType_t math_type =
885 #if CUDNN_VERSION >= 8000
886         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH);
887 #else
888         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
889 #endif
890     CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
891   }
892 
893   cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
894 
895  private:
896   ConvolutionDescriptor handle_;  // Owned.
897 };
898 
899 // A helper function to query whether a CudnnConvolutionDescriptor has
900 // tensor_op_math set.
901 static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
902   cudnnMathType_t math_type;
903   CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
904 #if CUDNN_VERSION >= 8000
905   return math_type != CUDNN_FMA_MATH;
906 #else
907   return math_type == CUDNN_TENSOR_OP_MATH;
908 #endif
909 }
910 
911 static bool TensorOpMathAvailable(
912     CudaComputeCapability cuda_compute_capability) {
913   return cuda_compute_capability.IsAtLeast(7);
914 }
915 
916 static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) {
917   if (!TensorOpMathAvailable(stream->GetCudaComputeCapability())) {
918     return false;
919   }
920   if (input_type == dnn::DataType::kFloat) {
921 #if CUDNN_VERSION < 8000
922     return false;
923 #else
924     if (!tensorflow::tensor_float_32_execution_enabled()) {
925       return false;
926     }
927 #endif
928   }
929   return true;
930 }
931 
932 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
933 // within a scope.
934 class CudnnPoolingDescriptor {
935  public:
936   explicit CudnnPoolingDescriptor(
937       const dnn::PoolingDescriptor& pooling_descriptor)
938       : handle_(CreatePoolingDescriptor()) {
939     absl::Span<const int64_t> strides64 = pooling_descriptor.strides();
940     absl::Span<const int64_t> padding64 = pooling_descriptor.padding();
941     absl::Span<const int64_t> shape64 = pooling_descriptor.window();
942 
943     const int nd = pooling_descriptor.ndims();
944     std::vector<int> shape(nd);
945     std::vector<int> padding(nd);
946     std::vector<int> strides(nd);
947     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
948                    &CheckedNarrowing<int64_t, int>);
949     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
950                    &CheckedNarrowing<int64_t, int>);
951     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
952                    &CheckedNarrowing<int64_t, int>);
953     bool propagate_nans = pooling_descriptor.propagate_nans();
954     const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
955                                             ? CUDNN_POOLING_MAX_DETERMINISTIC
956                                             : CUDNN_POOLING_MAX;
957     CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
958         handle_.get(),
959         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
960              ? cudnn_max_pooling_mode
961              : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
962         propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
963         shape.data(), padding.data(), strides.data()));
964   }
965 
966   cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
967 
968  private:
969   PoolingDescriptor handle_;  // Owned.
970 
971   SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
972 };
973 
974 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
975 class CudnnNormalizeDescriptor {
976  public:
977   explicit CudnnNormalizeDescriptor(
978       const dnn::NormalizeDescriptor& normalize_descriptor)
979       : handle_(CreateLrnDescriptor()) {
980     // The range specifies that the indices in the closed range
981     // [i - range, i + range] should be included in the normalization for index
982     // i. The lrnN value is the total number of elements in the range, so
983     // lrnN = 2*range + 1.
984     unsigned lrnN = 2 * normalize_descriptor.range() + 1;
985 
986     // Note that SE defines the normalization operation as
987     //
988     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
989     //
990     // but cuDNN defines it as
991     //
992     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
993     //
994     // i.e. there is a factor of n difference between the meaning of the alphas
995     // in the two contexts. The cuDNN alpha is n times the SE alpha.
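    // For example, with range = 2 (so lrnN = 5) and an SE alpha of 1e-4, the
    // lrnAlpha passed to cuDNN below is 5 * 1e-4 = 5e-4.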
996     double lrnAlpha = lrnN * normalize_descriptor.alpha();
997 
998     double lrnBeta = normalize_descriptor.beta();
999     double lrnK = normalize_descriptor.bias();
1000     CHECK_CUDNN_OK(
1001         cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
1002   }
1003 
1004   cudnnLRNDescriptor_t handle() const { return handle_.get(); }
1005 
1006  private:
1007   LrnDescriptor handle_;  // Owned.
1008 
1009   SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
1010 };
1011 
1012 // Turns a ActivationDescriptor structure into a cudnn activation
1013 // descriptor handle within a scope.
1014 class CudnnActivationDescriptor {
1015  public:
1016   CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
1017                             cudnnNanPropagation_t nan_propagation,
1018                             double value_max)
1019       : handle_(CreateActivationDescriptor()) {
1020     double relu_ceiling = 0.0;
1021     cudnnActivationMode_t mode;
1022     switch (activation_mode) {
1023       case dnn::ActivationMode::kNone:
1024         mode = CUDNN_ACTIVATION_IDENTITY;
1025         break;
1026       case dnn::ActivationMode::kRelu6:
1027         relu_ceiling = 6.0;
1028         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
1029         break;
1030       case dnn::ActivationMode::kReluX:
1031         relu_ceiling = value_max;
1032         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
1033         break;
1034       case dnn::ActivationMode::kRelu:
1035         mode = CUDNN_ACTIVATION_RELU;
1036         break;
1037       case dnn::ActivationMode::kSigmoid:
1038         mode = CUDNN_ACTIVATION_SIGMOID;
1039         break;
1040       case dnn::ActivationMode::kTanh:
1041         mode = CUDNN_ACTIVATION_TANH;
1042         break;
1043       default:
1044         LOG(FATAL) << "unrecognized activation mode: "
1045                    << static_cast<int>(activation_mode);
1046     }
1047 
1048     CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
1049                                                 nan_propagation, relu_ceiling));
1050   }
1051 
1052   cudnnActivationDescriptor_t handle() const { return handle_.get(); }
1053 
1054  private:
1055   ActivationDescriptor handle_;  // Owned.
1056 };
1057 
1058 cudnnDataType_t ToCudnnDataType(
1059     dnn::DataType data_type,
1060     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1061   switch (data_type) {
1062     case dnn::DataType::kFloat:
1063       return CUDNN_DATA_FLOAT;
1064     case dnn::DataType::kDouble:
1065       return CUDNN_DATA_DOUBLE;
1066     case dnn::DataType::kHalf:
1067       return CUDNN_DATA_HALF;
1068     case dnn::DataType::kInt8:
1069       switch (data_layout) {
1070         case dnn::DataLayout::kBatchDepthYX4:
1071           return CUDNN_DATA_INT8x4;
1072         case dnn::DataLayout::kBatchDepthYX32:
1073           return CUDNN_DATA_INT8x32;
1074         default:
1075           return CUDNN_DATA_INT8;
1076       }
1077     case dnn::DataType::kInt32:
1078       return CUDNN_DATA_INT32;
1079 #if CUDNN_VERSION >= 8200
1080     case dnn::DataType::kBF16:
1081       return CUDNN_DATA_BFLOAT16;
1082 #endif
1083     default:
1084       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1085   }
1086 }
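// For example, ToCudnnDataType(dnn::DataType::kInt8,
// dnn::DataLayout::kBatchDepthYX4) returns CUDNN_DATA_INT8x4, while kInt8 with
// the default kBatchDepthYX layout returns plain CUDNN_DATA_INT8.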
1087 
1088 cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
1089                                 dnn::FilterLayout filter_layout) {
1090   if (data_type == dnn::DataType::kInt8 &&
1091       filter_layout == dnn::FilterLayout::kOutputInputYX4) {
1092     return CUDNN_DATA_INT8x4;
1093   }
1094   if (data_type == dnn::DataType::kInt8 &&
1095       filter_layout == dnn::FilterLayout::kOutputInputYX32) {
1096     return CUDNN_DATA_INT8x32;
1097   }
1098   return ToCudnnDataType(data_type);
1099 }
1100 
1101 template <typename T>
1102 cudnnDataType_t GetCudnnDataType(
1103     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1104   return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
1105 }
1106 
1107 template <typename T>
1108 cudnnDataType_t GetCudnnDataType(dnn::FilterLayout filter_layout) {
1109   return ToCudnnDataType(dnn::ToDataType<T>::value, filter_layout);
1110 }
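// For example, GetCudnnDataType<Eigen::half>() resolves to CUDNN_DATA_HALF,
// assuming dnn::ToDataType<Eigen::half>::value is dnn::DataType::kHalf as in
// the rest of the StreamExecutor DNN API.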
1111 
1112 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
1113   switch (input_mode) {
1114     case dnn::RnnInputMode::kRnnLinearSkip:
1115     case dnn::RnnInputMode::kRnnSkipInput:
1116       return static_cast<cudnnRNNInputMode_t>(input_mode);
1117     default:
1118       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
1119   }
1120 }
1121 
1122 cudnnDirectionMode_t ToCudnnRnnDirectionMode(
1123     dnn::RnnDirectionMode direction_mode) {
1124   switch (direction_mode) {
1125     case dnn::RnnDirectionMode::kRnnUnidirectional:
1126     case dnn::RnnDirectionMode::kRnnBidirectional:
1127       return static_cast<cudnnDirectionMode_t>(direction_mode);
1128     default:
1129       LOG(FATAL) << "Invalid RNN direction mode: "
1130                  << static_cast<int>(direction_mode);
1131   }
1132 }
1133 
1134 cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
1135   switch (rnn_mode) {
1136     case dnn::RnnMode::kRnnRelu:
1137     case dnn::RnnMode::kRnnTanh:
1138     case dnn::RnnMode::kRnnLstm:
1139     case dnn::RnnMode::kRnnGru:
1140       return static_cast<cudnnRNNMode_t>(rnn_mode);
1141     default:
1142       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1143   }
1144 }
1145 
1146 int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
1147   switch (data_type) {
1148     case CUDNN_DATA_FLOAT:
1149       return sizeof(float);
1150     case CUDNN_DATA_DOUBLE:
1151       return sizeof(double);
1152     case CUDNN_DATA_HALF:
1153       return sizeof(Eigen::half);
1154     default:
1155       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1156   }
1157 }
1158 
1159 class CudnnDropoutDescriptor {
1160   explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
1161       : handle_(std::move(handle)) {}
1162 
1163  public:
1164   CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
1165 
1166   static port::StatusOr<CudnnDropoutDescriptor> Create(
1167       const CudnnHandle& cudnn, float dropout, uint64_t seed,
1168       ScratchAllocator* state_allocator) {
1169     DropoutDescriptor handle = CreateDropoutDescriptor();
1170 
1171     if (dropout == 0.0f) {
1172       // Return 'empty' dropout descriptor.
1173       return CudnnDropoutDescriptor(std::move(handle));
1174     }
1175 
1176     DeviceMemory<uint8> state_memory;
1177     if (state_allocator) {
1178       size_t state_sizes_in_bytes = 0;
1179       RETURN_IF_CUDNN_ERROR(
1180           cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
1181       TF_ASSIGN_OR_RETURN(state_memory,
1182                           state_allocator->AllocateBytes(state_sizes_in_bytes));
1183     }
1184     RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
1185         handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
1186         state_memory.size(), seed));
1187 
1188     return CudnnDropoutDescriptor(std::move(handle));
1189   }
1190 
1191   cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
1192 
1193  private:
1194   DropoutDescriptor handle_;  // Owned.
1195   SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
1196 };
1197 
1198 class CudnnRnnParamsDescriptor {
1199   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1200 
1201   CudnnRnnParamsDescriptor(FilterDescriptor handle,
1202                            int64_t params_size_in_bytes, ParamsRegions weights,
1203                            ParamsRegions biases)
1204       : handle_(std::move(handle)),
1205         params_size_in_bytes_(params_size_in_bytes),
1206         weights_(std::move(weights)),
1207         biases_(std::move(biases)) {}
1208 
1209  public:
1210   CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
1211 
1212   static port::StatusOr<CudnnRnnParamsDescriptor> Create(
1213       const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1214       cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1215       cudnnDirectionMode_t direction_mode, int num_layers);
1216 
1217   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
1218   int64_t params_size_in_bytes() const { return params_size_in_bytes_; }
1219   ParamsRegions params_weights() const { return weights_; }
1220   ParamsRegions params_biases() const { return biases_; }
1221 
1222  private:
1223   FilterDescriptor handle_;
1224   int64_t params_size_in_bytes_;
1225   ParamsRegions weights_;
1226   ParamsRegions biases_;
1227   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
1228 };
1229 
1230 }  // namespace
1231 
1232 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
1233   CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1234                      PersistentRnnPlan rnn_plan, int num_layers,
1235                      int hidden_size, int input_size, int cell_size,
1236                      int batch_size, cudnnRNNInputMode_t input_mode,
1237                      cudnnDirectionMode_t direction_mode,
1238                      cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1239                      cudnnDataType_t compute_type,
1240                      const dnn::AlgorithmConfig& algorithm_config,
1241                      CudnnDropoutDescriptor dropout_desc,
1242                      CudnnRnnParamsDescriptor params_desc)
1243       : rnn_desc_(std::move(rnn_desc)),
1244         rnn_plan_(std::move(rnn_plan)),
1245         num_layers_(num_layers),
1246         hidden_size_(hidden_size),
1247         input_size_(input_size),
1248         cell_size_(cell_size),
1249         batch_size_(batch_size),
1250         rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1251         input_mode_(input_mode),
1252         direction_mode_(direction_mode),
1253         rnn_mode_(rnn_mode),
1254         data_type_(data_type),
1255         compute_type_(compute_type),
1256         algorithm_config_(algorithm_config),
1257         dropout_desc_(std::move(dropout_desc)),
1258         params_desc_(std::move(params_desc)) {}
1259 
1260  public:
1261   CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1262 
1263   static port::StatusOr<CudnnRnnDescriptor> Create(
1264       const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1265       int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
1266       cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1267       cudnnDataType_t data_type, cudnnDataType_t compute_type,
1268       const dnn::AlgorithmConfig& algorithm_config, float dropout,
1269       uint64_t seed, ScratchAllocator* state_allocator, bool use_padded_io) {
1270     TF_ASSIGN_OR_RETURN(
1271         CudnnDropoutDescriptor dropout_desc,
1272         CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1273 
1274     gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1275     cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1276 
1277     // TODO: allow the user to choose an algorithm.
1278     auto proj_size = hidden_size;
1279     hidden_size = std::max(hidden_size, cell_size);
1280 
1281     // Require explicit algorithm config to enable tensor cores. Some configs
1282     // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1283     // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1284     // We can only reasonably expect the user to handle the subsequent failure
1285     // in profile mode, which is run with algorithms returned from
1286     // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1287     // use tensor ops). CuDNN 7.2.1 fixed this issue.
1288     // TODO(csigg): The minimal supported cuDNN version is now 7.3; clean up.
1289     bool allow_tensor_ops = data_type == CUDNN_DATA_HALF;
1290     if (data_type == CUDNN_DATA_FLOAT)
1291       allow_tensor_ops = tensorflow::tensor_float_32_execution_enabled();
1292     bool use_tensor_ops =
1293         algorithm_config.algorithm().has_value()
1294             ? algorithm_config.algorithm()->tensor_ops_enabled()
1295             : allow_tensor_ops;
1296     if (use_tensor_ops && !allow_tensor_ops) {
1297       return port::Status(port::error::INVALID_ARGUMENT,
1298                           "Algo requests disallowed tensor op evaluation.");
1299     }
1300 
1301 #if CUDNN_VERSION >= 8000
1302     cudnnMathType_t math_type =
1303         use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
1304 #else
1305     cudnnMathType_t math_type =
1306         use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
1307 #endif
1308 
1309 #if CUDNN_VERSION >= 8000
1310     cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
1311     uint32_t aux_flags = 0;
1312     if (use_padded_io) aux_flags |= CUDNN_RNN_PADDED_IO_ENABLED;
1313     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v8(
1314         /*rnnDesc=*/rnn_desc.get(), /*algo=*/rnn_algo, /*cellMode=*/rnn_mode,
1315         /*biasMode=*/bias_mode, /*dirMode=*/direction_mode,
1316         /*inputMode=*/input_mode,
1317         /*dataType=*/data_type, /*mathPrec=*/compute_type,
1318         /*mathType=*/math_type,
1319         /*inputSize=*/input_size,
1320         /*hiddenSize=*/hidden_size, /*projSize=*/proj_size,
1321         /*numLayers=*/num_layers,
1322         /*dropoutDesc=*/dropout_desc.handle(),
1323         /*auxFlags=*/aux_flags));
1324 #else
1325     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1326         cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1327         /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers,
1328         /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
1329         /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1330         /*dataType=*/compute_type));
1331     CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1332 
1333     if (proj_size < hidden_size) {
1334       RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
1335           cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1336           /*recProjSize=*/proj_size, /*outProjSize=*/0));
1337     }
1338 
1339     // TODO: For now, the cudnnRNN**Ex APIs are only used to process padded
1340     // inputs. If they are later also used to process full-length arrays, we
1341     // will need to distinguish when to set the padding mode.
1342     if (use_padded_io) {
1343       RETURN_IF_CUDNN_ERROR(
1344           cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1345     }
1346 #endif
1347 
1348     port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1349     PersistentRnnPlan rnn_plan;
1350     if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1351       CHECK_GE(batch_size, 0);
1352       rnn_plan_wrapper =
1353           CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1354       if (!rnn_plan_wrapper.ok()) {
1355         return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1356       } else {
1357         rnn_plan = std::move(rnn_plan_wrapper).value();
1358         RETURN_IF_CUDNN_ERROR(
1359             cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1360       }
1361     }
1362 
1363     // Create the params handle.
1364     // TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and
1365     // cudnnRNNForward***Ex are removed from the codebase, since the new API
1366     // doesn't need param descriptors any more.
1367     TF_ASSIGN_OR_RETURN(auto params_desc,
1368                         CudnnRnnParamsDescriptor::Create(
1369                             cudnn, input_size, data_type, rnn_desc.get(),
1370                             rnn_mode, direction_mode, num_layers));
1371 
1372     return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1373                               num_layers, hidden_size, input_size, cell_size,
1374                               batch_size, input_mode, direction_mode, rnn_mode,
1375                               data_type, compute_type, algorithm_config,
1376                               std::move(dropout_desc), std::move(params_desc));
1377   }
1378 
1379   cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
1380   int num_layers() const { return num_layers_; }
1381   int hidden_size() const { return hidden_size_; }
1382   int input_size() const { return input_size_; }
1383   int cell_size() const { return cell_size_; }
1384   int batch_size() const { return batch_size_; }
1385   cudnnRNNInputMode_t input_mode() const { return input_mode_; }
1386   cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
1387   cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
1388   cudnnDataType_t data_type() const { return data_type_; }
1389   cudnnDataType_t compute_type() const { return compute_type_; }
1390   const dnn::AlgorithmConfig& algorithm_config() const {
1391     return algorithm_config_;
1392   }
1393   int64_t ParamsSizeInBytes() const override {
1394     return params_desc_.params_size_in_bytes();
1395   }
1396   cudnnFilterDescriptor_t params_handle() const {
1397     return params_desc_.handle();
1398   }
1399   ParamsRegions ParamsWeightRegions() const override {
1400     return params_desc_.params_weights();
1401   }
1402   ParamsRegions ParamsBiasRegions() const override {
1403     return params_desc_.params_biases();
1404   }
1405 
1406  private:
1407   gpu::RnnDescriptor rnn_desc_;
1408   PersistentRnnPlan rnn_plan_;
1409   int num_layers_;
1410   int hidden_size_;
1411   int input_size_;
1412   // cell_size_ is the size of cell state, which will be different from
1413   // hidden_size_ if the projection is used.
1414   int cell_size_;
1415   // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1416   // algorithm.
1417   int batch_size_;
1418   cudnnRNNAlgo_t rnn_algo_;
1419   cudnnRNNInputMode_t input_mode_;
1420   cudnnDirectionMode_t direction_mode_;
1421   cudnnRNNMode_t rnn_mode_;
1422   cudnnDataType_t data_type_;
1423   cudnnDataType_t compute_type_;
1424   dnn::AlgorithmConfig algorithm_config_;
1425   CudnnDropoutDescriptor dropout_desc_;
1426   CudnnRnnParamsDescriptor params_desc_;
1427   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1428 };
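// Example (illustrative only): a hypothetical caller might build an LSTM
// descriptor through the factory above roughly as follows; the concrete sizes,
// enum choices, and the `cudnn`/`state_allocator` objects are assumptions made
// for the sake of the sketch.
//
//   auto rnn_desc_or = CudnnRnnDescriptor::Create(
//       cudnn, /*num_layers=*/2, /*hidden_size=*/128, /*input_size=*/64,
//       /*cell_size=*/128, /*batch_size=*/32, CUDNN_LINEAR_INPUT,
//       CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT,
//       dnn::AlgorithmConfig(), /*dropout=*/0.0f, /*seed=*/0, state_allocator,
//       /*use_padded_io=*/false);
//   if (!rnn_desc_or.ok()) return rnn_desc_or.status();
//   CudnnRnnDescriptor rnn_desc = std::move(rnn_desc_or).value();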
1429 
1430 #if CUDNN_VERSION >= 7603
1431 class CudnnCtcLossDescriptor {
1432  public:
1433   explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
1434       : handle_(CreateCtcLossDescriptor()) {
1435     CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
1436         /*ctcLossDesc=*/handle_.get(),
1437         /*compType=*/data_type,
1438         /*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
1439         /*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
1440   }
1441 
1442   cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
1443 
1444  private:
1445   CtcLossDescriptor handle_;  // Owned
1446 
1447   SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
1448 };
1449 #else
1450 // Dummy class used when built against cuDNN versions older than 7.6.3.
1451 class CudnnCtcLossDescriptor {
1452  public:
1453   CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
1454 };
1455 #endif
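// Example (illustrative only): when built against cuDNN >= 7.6.3, the wrapper
// above configures softmax-normalized CTC loss; a hypothetical caller would
// construct it with the compute type and pass handle() to the CTC loss calls:
//
//   CudnnCtcLossDescriptor ctc_loss_desc(CUDNN_DATA_FLOAT);
//   cudnnCTCLossDescriptor_t raw_desc = ctc_loss_desc.handle();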
1456 
1457 namespace {
1458 
1459 // Checks whether the LSTM projection is used. If so, the additional weight
1460 // matrix (the projection matrix) is appended to 'weights'; otherwise, nothing
1461 // is done.
1462 port::Status CheckAndFetchProjectionWeights(
1463     const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
1464     const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
1465     const FilterDescriptor& region_desc_handle,
1466     dnn::RnnDescriptor::ParamsRegions* weights) {
1467   int hidden_size_v;
1468   int num_layers_v;
1469   cudnnDropoutDescriptor_t dropout_desc;
1470   cudnnRNNInputMode_t input_mode;
1471   cudnnDirectionMode_t direction;
1472   cudnnRNNMode_t mode;
1473   cudnnRNNAlgo_t algo;
1474   cudnnDataType_t data_type;
1475 #if CUDNN_VERSION >= 8000
1476   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
1477       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1478       /*hiddenSize=*/&hidden_size_v,
1479       /*numLayers=*/&num_layers_v,
1480       /*dropoutDesc=*/&dropout_desc,
1481       /*inputMode=*/&input_mode,
1482       /*direction=*/&direction,
1483       /*mode=*/&mode,
1484       /*algo=*/&algo,
1485       /*mathPrec=*/&data_type));
1486 #else
1487   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
1488       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1489       /*hiddenSize=*/&hidden_size_v,
1490       /*numLayers=*/&num_layers_v,
1491       /*dropoutDesc=*/&dropout_desc,
1492       /*inputMode=*/&input_mode,
1493       /*direction=*/&direction,
1494       /*mode=*/&mode,
1495       /*algo=*/&algo,
1496       /*mathPrec=*/&data_type));
1497 #endif
1498   int rec_proj_size_v;
1499   int out_proj_size_v;
1500   RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
1501       /*handle=*/cudnn.handle(),
1502       /*rnnDesc=*/rnn_desc,
1503       /*recProjSize*/ &rec_proj_size_v,
1504       /*outProjSize*/ &out_proj_size_v));
1505   if (rec_proj_size_v != hidden_size_v) {
1506     void* offset = nullptr;
1507     int region_id = 8;
1508     RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
1509         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1510         /*layer=*/layer, /*xDesc=*/input_desc.get(),
1511         /*wDesc=*/filter_desc.get(),
1512         /*w=*/nullptr, /*linLayerID=*/region_id,
1513         /*linLayerMatDesc=*/region_desc_handle.get(),
1514         /*linLayerMat or linLayerBias=*/&offset));
1515     int dims[] = {1, 1, 1};
1516     cudnnDataType_t data_type;
1517     cudnnTensorFormat_t tensor_format;
1518     int n_dims;
1519     RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1520         /*filterDesc=*/region_desc_handle.get(),
1521         /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1522         /*dataType=*/&data_type, /*format=*/&tensor_format,
1523         /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1524     int64_t size =
1525         dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1526     dnn::RnnDescriptor::ParamsRegion region = {
1527         reinterpret_cast<int64_t>(offset), size};
1528     weights->push_back(region);
1529   }
1530   return ::tensorflow::OkStatus();
1531 }
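// Worked example (illustrative only): if cudnnGetFilterNdDescriptor reports the
// projection matrix as dims = {1, 1, 1024} with CUDNN_DATA_FLOAT, the size
// computed above is 1 * 1 * 1024 * 4 = 4096 bytes, and the recorded region
// pairs that size with the byte offset returned in 'offset'.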
1532 
1533 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1534     const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1535     cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1536     cudnnDirectionMode_t direction_mode, int num_layers) {
1537   // Query the params size.
1538   TensorDescriptor input_desc = CreateTensorDescriptor();
1539   int tensor_dims[] = {1, input_size, 1};
1540   int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1541   RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1542       /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1543       /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1544       /*dimA=*/tensor_dims,
1545       /*strideA=*/strides));
1546 
1547   size_t params_size = 0;
1548   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1549       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1550       /*xDesc=*/input_desc.get(), /*sizeInBytes=*/&params_size,
1551       /*dataType=*/data_type));
1552   int64_t params_size_in_bytes = static_cast<int64_t>(params_size);
1553 
1554   FilterDescriptor filter_desc = CreateFilterDescriptor();
1555   int64_t filter_dim0 =
1556       params_size_in_bytes / CudnnDataTypeToByteSize(data_type);
1557   int filter_dims[] = {static_cast<int>(filter_dim0), 1, 1};
1558   RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1559       /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1560       /*format=*/CUDNN_TENSOR_NCHW,
1561       /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1562       /*filterDimA=*/filter_dims));
1563 
1564   // Fetch the weight and bias regions within the params buffer.
1565   int region_count_per_layer = [&] {
1566     switch (rnn_mode) {
1567       case CUDNN_RNN_RELU:
1568       case CUDNN_RNN_TANH:
1569         return 2;
1570       case CUDNN_LSTM:
1571         return 8;
1572       case CUDNN_GRU:
1573         return 6;
1574       default:
1575         LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1576         return 0;
1577     }
1578   }();
1579 
1580   FilterDescriptor region_desc_handle = CreateFilterDescriptor();
1581   const int layer_count =
1582       direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1583 
1584   ParamsRegions weights;
1585   ParamsRegions biases;
1586 
1587   for (int layer = 0; layer < layer_count; layer++) {
1588     for (int region = 0; region < region_count_per_layer; region++) {
1589       for (int type = 0; type < 2; type++) {
1590         void* offset = nullptr;
1591         RETURN_IF_CUDNN_ERROR(
1592             type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1593                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1594                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1595                             /*wDesc=*/filter_desc.get(),
1596                             /*w=*/nullptr, /*linLayerID=*/region,
1597                             /*linLayerMatDesc=*/region_desc_handle.get(),
1598                             /*linLayerMat or linLayerBias=*/&offset)
1599                       : cudnnGetRNNLinLayerBiasParams(
1600                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1601                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1602                             /*wDesc=*/filter_desc.get(),
1603                             /*w=*/nullptr, /*linLayerID=*/region,
1604                             /*linLayerMatDesc=*/region_desc_handle.get(),
1605                             /*linLayerMat or linLayerBias=*/&offset));
1606         int dims[] = {1, 1, 1};
1607         cudnnDataType_t data_type;
1608         cudnnTensorFormat_t tensor_format;
1609         int n_dims;
1610         RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1611             /*filterDesc=*/region_desc_handle.get(),
1612             /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1613             /*dataType=*/&data_type, /*format=*/&tensor_format,
1614             /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1615         int64_t size =
1616             dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1617         dnn::RnnDescriptor::ParamsRegion region = {
1618             reinterpret_cast<int64_t>(offset), size};
1619         (type == 0 ? weights : biases).push_back(region);
1620       }
1621     }
1622     TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
1623         cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
1624         &weights));
1625   }
1626 
1627   return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1628                                   weights, biases);
1629 }
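// Worked example (illustrative only): for a hypothetical 2-layer bidirectional
// LSTM, layer_count = 2 * 2 = 4 and region_count_per_layer = 8, so the loops
// above record 4 * 8 = 32 weight regions and 32 bias regions (plus one extra
// projection-weight region per pseudo-layer if the LSTM projection is in use).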
1630 
1631 }  // namespace
1632 
1633 class CudnnRnnSequenceTensorDescriptor
1634     : public dnn::RnnSequenceTensorDescriptor {
1635   CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1636                                    int batch_size, int data_size,
1637                                    cudnnDataType_t data_type,
1638                                    RNNDataDescriptor data_handle,
1639                                    TensorDescriptor handle)
1640       : max_seq_length_(max_seq_length),
1641         batch_size_(batch_size),
1642         data_size_(data_size),
1643         data_type_(data_type),
1644         handle_(std::move(handle)),
1645         rnn_data_handle_(std::move(data_handle)),
1646         handles_(max_seq_length, handle_.get()) {}
1647 
1648  public:
1649   CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1650       default;
1651 
1652   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1653       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1654       cudnnDataType_t data_type) {
1655     if (max_seq_length <= 0) {
1656       return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1657     }
1658     int dims[] = {batch_size, data_size, 1};
1659     int strides[] = {dims[1] * dims[2], dims[2], 1};
1660     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1661     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1662         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1663         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1664         /*strideA=*/strides));
1665     return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1666                                             data_size, data_type, nullptr,
1667                                             std::move(tensor_desc));
1668   }
1669 
1670   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1671       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1672       const absl::Span<const int>& seq_lengths, bool time_major,
1673       cudnnDataType_t data_type) {
1674     if (max_seq_length <= 0) {
1675       return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1676     }
1677     int dims[] = {batch_size, data_size, 1};
1678     int strides[] = {dims[1] * dims[2], dims[2], 1};
1679     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1680     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1681         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1682         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1683         /*strideA=*/strides));
1684     const int* seq_lengths_array = seq_lengths.data();
1685     RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1686     float padding_fill = 0.0f;
1687     cudnnRNNDataLayout_t layout;
1688     if (time_major) {
1689       layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1690     } else {
1691       layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1692     }
1693     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1694         /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1695         /*layout=*/layout,
1696         /*maxSeqLength=*/max_seq_length,
1697         /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1698         /*seqLengthArray=*/seq_lengths_array,
1699         /*paddingFill*/ (void*)&padding_fill));
1700     return CudnnRnnSequenceTensorDescriptor(
1701         parent, max_seq_length, batch_size, data_size, data_type,
1702         std::move(data_desc), std::move(tensor_desc));
1703   }
1704 
1705   const cudnnTensorDescriptor_t* handles() const { return handles_.data(); }
1706   const cudnnRNNDataDescriptor_t data_handle() const {
1707     return rnn_data_handle_.get();
1708   }
1709 
1710   int max_seq_length() const { return max_seq_length_; }
1711   int batch_size() const { return batch_size_; }
1712   int data_size() const { return data_size_; }
1713   bool is_var_seq_lengths() const { return rnn_data_handle_ != nullptr; }
1714 
1715  private:
1716   int max_seq_length_;
1717   int batch_size_;
1718   int data_size_;
1719   cudnnDataType_t data_type_;
1720   TensorDescriptor handle_;
1721   RNNDataDescriptor rnn_data_handle_;
1722   std::vector<cudnnTensorDescriptor_t> handles_;  // Copies of handle_.
1723   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1724 };
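// Example (illustrative only): a hypothetical caller describing a batch-major
// batch of 8 sequences padded to length 20 with 32 features per step; the
// lengths below are arbitrary values chosen for the sketch.
//
//   std::vector<int> seq_lengths = {20, 18, 17, 15, 12, 9, 9, 5};
//   auto seq_desc_or = CudnnRnnSequenceTensorDescriptor::Create(
//       parent, /*max_seq_length=*/20, /*batch_size=*/8, /*data_size=*/32,
//       absl::MakeConstSpan(seq_lengths), /*time_major=*/false,
//       CUDNN_DATA_FLOAT);
//   // On success, is_var_seq_lengths() returns true for this descriptor.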
1725 
1726 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1727  public:
1728   CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1729                                 int batch_size, int data_size,
1730                                 cudnnDataType_t data_type)
1731       : handle_(CreateTensorDescriptor()),
1732         num_layers_(num_layers),
1733         batch_size_(batch_size),
1734         data_size_(data_size),
1735         data_type_(data_type) {
1736     int dims[] = {num_layers, batch_size, data_size};
1737     int strides[] = {dims[1] * dims[2], dims[2], 1};
1738     CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1739         /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1740         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1741         /*strideA=*/strides));
1742   }
1743 
1744   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1745 
1746   int num_layers() const { return num_layers_; }
1747   int batch_size() const { return batch_size_; }
1748   int data_size() const { return data_size_; }
1749 
1750  private:
1751   TensorDescriptor handle_;
1752   int num_layers_;
1753   int batch_size_;
1754   int data_size_;
1755   cudnnDataType_t data_type_;
1756   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1757 };
1758 
1759 namespace {
1760 
1761 struct RnnModelDims {
1762   int num_layers = 0;
1763   int batch_size = 0;
1764   int max_seq_length = 0;
1765   int hidden_size = 0;
1766   int input_size = 0;
1767   int cell_size = 0;
1768   int dir_count = 0;
1769 };
1770 
1771 template <class T>
1772 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1773     const CudnnRnnDescriptor& rnn_desc,
1774     const CudnnRnnSequenceTensorDescriptor& input_desc,
1775     const DeviceMemory<T>& input_data,
1776     const CudnnRnnStateTensorDescriptor& input_h_desc,
1777     const DeviceMemory<T>& input_h_data,
1778     const CudnnRnnStateTensorDescriptor& input_c_desc,
1779     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1780     const CudnnRnnSequenceTensorDescriptor& output_desc,
1781     const DeviceMemory<T>& output_data,
1782     const CudnnRnnStateTensorDescriptor& output_h_desc,
1783     const DeviceMemory<T>& output_h_data,
1784     const CudnnRnnStateTensorDescriptor& output_c_desc,
1785     const DeviceMemory<T>& output_c_data) {
1786   // extract model parameters
1787   RnnModelDims model_dims;
1788   model_dims.num_layers = rnn_desc.num_layers();
1789   model_dims.batch_size = input_desc.batch_size();
1790   model_dims.max_seq_length = input_desc.max_seq_length();
1791   model_dims.hidden_size = rnn_desc.hidden_size();
1792   model_dims.input_size = input_desc.data_size();
1793   model_dims.cell_size = rnn_desc.cell_size();
1794   model_dims.dir_count =
1795       (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1796 
1797   // check parameters
1798   if (!(input_h_desc.num_layers() ==
1799             model_dims.num_layers * model_dims.dir_count &&
1800         input_h_desc.batch_size() == model_dims.batch_size &&
1801         input_h_desc.data_size() == model_dims.hidden_size)) {
1802     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1803   }
1804   // The LSTM projection will be used if input_h_desc.data_size() <
1805   // input_c_desc.data_size()
1806   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1807         input_h_desc.batch_size() == input_c_desc.batch_size() &&
1808         input_h_desc.data_size() <= input_c_desc.data_size())) {
1809     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1810   }
1811   if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1812         output_desc.batch_size() == model_dims.batch_size &&
1813         output_desc.data_size() ==
1814             model_dims.hidden_size * model_dims.dir_count)) {
1815     return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1816   }
1817   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1818         input_h_desc.batch_size() == output_h_desc.batch_size() &&
1819         input_h_desc.data_size() == output_h_desc.data_size())) {
1820     return port::Status(port::error::INVALID_ARGUMENT,
1821                         "Invalid output_h shape");
1822   }
1823   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1824         input_h_desc.batch_size() == output_c_desc.batch_size() &&
1825         input_h_desc.data_size() <= output_c_desc.data_size())) {
1826     return port::Status(port::error::INVALID_ARGUMENT,
1827                         "Invalid output_c shape");
1828   }
1829 
1830   return model_dims;
1831 }
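// Worked example (illustrative only): for a hypothetical 2-layer bidirectional
// LSTM with hidden_size 128 and no projection, the checks above require
//   input_h:  num_layers = 2 * 2 = 4, data_size = 128
//   input_c:  matching layer and batch counts, data_size >= 128
//   output:   data_size = hidden_size * dir_count = 128 * 2 = 256
// with output_h and output_c shaped like input_h and input_c respectively.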
1832 
1833 port::Status CheckRNNParameterSize(
1834     const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1835     const CudnnRnnSequenceTensorDescriptor& input_desc) {
1836   size_t params_size_in_bytes = 0;
1837 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1838   RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize(
1839       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1840       /*sizeInBytes=*/&params_size_in_bytes));
1841 #else
1842   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1843       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1844       /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
1845       /*dataType=*/rnn_desc.data_type()));
1846 #endif
1847   if (static_cast<int64_t>(params_size_in_bytes) !=
1848       rnn_desc.ParamsSizeInBytes()) {
1849     return port::Status(port::error::INVALID_ARGUMENT,
1850                         "Mismatching RNN parameter size");
1851   }
1852   return ::tensorflow::OkStatus();
1853 }
1854 
1855 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1856     Stream* stream, const CudnnHandle& cudnn,
1857     const CudnnRnnDescriptor& rnn_desc,
1858     const CudnnRnnSequenceTensorDescriptor& input_desc,
1859     ScratchAllocator* workspace_allocator) {
1860   // Query the workspace size.
1861   size_t workspace_size_in_bytes = 0;
1862   RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1863       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1864       /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1865       /*sizeInBytes=*/&workspace_size_in_bytes));
1866   // Allocate the workspace.
1867   if (workspace_size_in_bytes == 0) {
1868     return DeviceMemory<uint8>();
1869   }
1870   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1871 }
1872 
1873 #if CUDNN_VERSION >= 7402
1874 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
1875     Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1876     const cudnnBatchNormOps_t& bn_ops,
1877     const cudnnActivationDescriptor_t& activation_desc,
1878     const CudnnTensorDescriptor& x_descriptor,
1879     const CudnnTensorDescriptor& scale_offset_descriptor,
1880     ScratchAllocator* workspace_allocator) {
1881   // Query the workspace size.
1882   size_t workspace_size_in_bytes = 0;
1883   RETURN_IF_CUDNN_ERROR(
1884       cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
1885           /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1886           /*xDesc=*/x_descriptor.handle(), /*zDesc=*/x_descriptor.handle(),
1887           /*yDesc=*/x_descriptor.handle(),
1888           /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
1889           /*activationDesc=*/activation_desc,
1890           /*sizeInBytes=*/&workspace_size_in_bytes));
1891   // Allocate the workspace.
1892   if (workspace_size_in_bytes == 0) {
1893     return DeviceMemory<uint8>();
1894   }
1895   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1896 }
1897 
1898 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
1899     Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1900     const cudnnBatchNormOps_t& bn_ops,
1901     const cudnnActivationDescriptor_t& activation_desc,
1902     const CudnnTensorDescriptor& x_descriptor,
1903     const CudnnTensorDescriptor& scale_offset_descriptor,
1904     ScratchAllocator* workspace_allocator) {
1905   // Query the workspace size.
1906   size_t workspace_size_in_bytes = 0;
1907   RETURN_IF_CUDNN_ERROR(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
1908       /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1909       /*xDesc=*/x_descriptor.handle(),
1910       /*yDesc=*/x_descriptor.handle(),
1911       /*dyDesc=*/x_descriptor.handle(),
1912       /*dzDesc=*/x_descriptor.handle(),
1913       /*dxDesc=*/x_descriptor.handle(),
1914       /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
1915       /*activationDesc=*/activation_desc,
1916       /*sizeInBytes=*/&workspace_size_in_bytes));
1917   // Allocate the workspace.
1918   if (workspace_size_in_bytes == 0) {
1919     return DeviceMemory<uint8>();
1920   }
1921   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1922 }
1923 
1924 #endif
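// Example (illustrative only): both helpers above return an empty
// DeviceMemory<uint8> when cuDNN reports a zero-byte requirement, so a
// hypothetical caller can forward the result unconditionally; the argument
// names below are placeholders, not values defined in this file.
//
//   TF_ASSIGN_OR_RETURN(
//       DeviceMemory<uint8> workspace,
//       CreateBatchNormForwardWorkspace(stream, cudnn, mode, bn_ops,
//                                       activation_desc, x_descriptor,
//                                       scale_offset_descriptor,
//                                       workspace_allocator));
//   // workspace.opaque() is null and workspace.size() is 0 if none is needed.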
1925 
1926 }  // namespace
1927 
1928 template <class T>
1929 port::Status CudnnSupport::DoRnnForwardImpl(
1930     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1931     const CudnnRnnSequenceTensorDescriptor& input_desc,
1932     const DeviceMemory<T>& input_data,
1933     const DeviceMemory<int>& seq_lengths_data,
1934     const CudnnRnnStateTensorDescriptor& input_h_desc,
1935     const DeviceMemory<T>& input_h_data,
1936     const CudnnRnnStateTensorDescriptor& input_c_desc,
1937     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1938     const CudnnRnnSequenceTensorDescriptor& output_desc,
1939     DeviceMemory<T>* output_data,
1940     const CudnnRnnStateTensorDescriptor& output_h_desc,
1941     DeviceMemory<T>* output_h_data,
1942     const CudnnRnnStateTensorDescriptor& output_c_desc,
1943     DeviceMemory<T>* output_c_data, bool is_training,
1944     ScratchAllocator* reserve_space_allocator,
1945     ScratchAllocator* workspace_allocator,
1946     dnn::ProfileResult* output_profile_result) {
1947   TF_ASSIGN_OR_RETURN(
1948       RnnModelDims model_dims,
1949       ExtractAndCheckRnnForward(
1950           rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1951           input_c_desc, input_c_data, params, output_desc, *output_data,
1952           output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1953 
1954   auto cudnn = cudnn_->GetHandle(parent_, stream);
1955 
1956   TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1957 
1958   // In cuDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex APIs were
1959   // deprecated. Instead, we use cudnnRNNForward, which requires the
1960   // sequence_lengths parameter. For more info, see
1961   // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802.
1962 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1963   if (input_desc.is_var_seq_lengths()) {
1964     DeviceMemory<uint8> workspace;
1965     DeviceMemory<uint8> reserve_space;
1966     cudnnForwardMode_t rnn_fwd_mode;
1967     if (is_training) {
1968       rnn_fwd_mode = CUDNN_FWD_MODE_TRAINING;
1969     } else {
1970       rnn_fwd_mode = CUDNN_FWD_MODE_INFERENCE;
1971     }
1972     size_t reserve_space_size_in_bytes = 0;
1973     size_t workspace_size_in_bytes = 0;
1974     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
1975         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1976         /*fMode=*/rnn_fwd_mode, /*xDesc=*/input_desc.data_handle(),
1977         /*workSpaceSize=*/&workspace_size_in_bytes,
1978         /*reserveSpaceSize=*/&reserve_space_size_in_bytes));
1979 
1980     if (workspace_size_in_bytes > 0) {
1981       TF_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
1982                                          workspace_size_in_bytes));
1983     }
1984     if (reserve_space_size_in_bytes > 0) {
1985       TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1986                                              reserve_space_size_in_bytes));
1987     }
1988 
1989     std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1990     const bool is_profiling = output_profile_result != nullptr;
1991     if (is_profiling) {
1992       timer.reset(new GpuTimer(parent_));
1993       // The start and stop of the timer should be as close to the cuDNN call
1994       // as possible. Other threads may still issue work onto this stream, so
1995       // multiple profiling measurements may be needed.
1996       if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1997         return port::Status(port::error::INTERNAL, "Failed to start timer");
1998       }
1999     }
2000 
2001     RETURN_IF_CUDNN_ERROR(cudnnRNNForward(
2002         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2003         /*fwdMode=*/rnn_fwd_mode,
2004         /*devSeqLengths=*/
2005         reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2006         /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2007         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data->opaque(),
2008         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2009         /*hy=*/output_h_data->opaque(),
2010         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2011         /*cy=*/output_c_data->opaque(),
2012         /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2013         /*weightSpace=*/params.opaque(),
2014         /*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(),
2015         /*reserveSpaceSizeInBytes=*/reserve_space.size(),
2016         /*reserveSpace=*/reserve_space.opaque()));
2017 
2018     if (is_profiling) {
2019       if (!timer->Stop(AsGpuStream(stream))) {
2020         return port::Status(port::error::INTERNAL, "Failed to stop timer");
2021       }
2022       auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2023       output_profile_result->set_algorithm(algo_desc);
2024       output_profile_result->set_elapsed_time_in_ms(
2025           timer->GetElapsedMilliseconds());
2026     }
2027     return port::Status::OK();
2028   }
2029 #endif
2030   TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
2031                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
2032                                          workspace_allocator));
2033 
2034   // query the reserve space size
2035   // allocate the reserve space
2036   DeviceMemory<uint8> reserve_space;
2037   if (is_training) {
2038     size_t reserve_space_size_in_bytes = 0;
2039     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
2040         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2041         /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
2042         /*sizeInBytes=*/&reserve_space_size_in_bytes));
2043 
2044     if (reserve_space_size_in_bytes > 0) {
2045       TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
2046                                              reserve_space_size_in_bytes));
2047     }
2048   }
2049 
2050   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2051   const bool is_profiling = output_profile_result != nullptr;
2052   if (is_profiling) {
2053     timer.reset(new GpuTimer(parent_));
2054     // The start and stop of the timer should be as close to the cuDNN call
2055     // as possible. Other threads may still issue work onto this stream, so
2056     // multiple profiling measurements may be needed.
2057     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2058       return port::Status(port::error::INTERNAL, "Failed to start timer");
2059     }
2060   }
2061 
2062   if (!is_training) {
2063     if (input_desc.is_var_seq_lengths()) {
2064       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
2065           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2066           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2067           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2068           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2069           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2070           /*yDesc=*/output_desc.data_handle(),
2071           /*y=*/output_data->opaque(),
2072           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
2073           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
2074           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
2075           nullptr,
2076           /*workspace=*/workspace.opaque(),
2077           /*workSpaceSizeInBytes=*/workspace.size()));
2078     } else {
2079       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
2080           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2081           /*seqLength=*/model_dims.max_seq_length,
2082           /*xDesc=*/input_desc.handles(),
2083           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2084           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
2085           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
2086           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
2087           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
2088           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
2089           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
2090           /*workSpaceSizeInBytes=*/workspace.size()));
2091     }
2092   } else {
2093     if (input_desc.is_var_seq_lengths()) {
2094       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
2095           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2096           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2097           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2098           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2099           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2100           /*yDesc=*/output_desc.data_handle(),
2101           /*y=*/output_data->opaque(),
2102           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
2103           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
2104           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
2105           nullptr,
2106           /*workspace=*/workspace.opaque(),
2107           /*workSpaceSizeInBytes=*/workspace.size(),
2108           /*reserveSpace=*/reserve_space.opaque(),
2109           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
2110     } else {
2111       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
2112           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2113           /*seqLength=*/model_dims.max_seq_length,
2114           /*xDesc=*/input_desc.handles(),
2115           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2116           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
2117           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
2118           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
2119           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
2120           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
2121           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
2122           /*workSpaceSizeInBytes=*/workspace.size(),
2123           /*reserveSpace=*/reserve_space.opaque(),
2124           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
2125     }
2126   }
2127 
2128   if (is_profiling) {
2129     if (!timer->Stop(AsGpuStream(stream))) {
2130       return port::Status(port::error::INTERNAL, "Failed to stop timer");
2131     }
2132     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2133     output_profile_result->set_algorithm(algo_desc);
2134     output_profile_result->set_elapsed_time_in_ms(
2135         timer->GetElapsedMilliseconds());
2136   }
2137 
2138   return ::tensorflow::OkStatus();
2139 }
2140 
2141 template <class T>
2142 port::Status CudnnSupport::DoRnnBackwardImpl(
2143     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
2144     const CudnnRnnSequenceTensorDescriptor& input_desc,
2145     const DeviceMemory<T>& input_data,
2146     const DeviceMemory<int>& seq_lengths_data,
2147     const CudnnRnnStateTensorDescriptor& input_h_desc,
2148     const DeviceMemory<T>& input_h_data,
2149     const CudnnRnnStateTensorDescriptor& input_c_desc,
2150     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2151     const CudnnRnnSequenceTensorDescriptor& output_desc,
2152     const DeviceMemory<T>& output_data,
2153     const CudnnRnnStateTensorDescriptor& output_h_desc,
2154     const DeviceMemory<T>& output_h_data,
2155     const CudnnRnnStateTensorDescriptor& output_c_desc,
2156     const DeviceMemory<T>& output_c_data,
2157     const DeviceMemory<T>& output_backprop_data,
2158     const DeviceMemory<T>& output_h_backprop_data,
2159     const DeviceMemory<T>& output_c_backprop_data,
2160     DeviceMemory<T>* input_backprop_data,
2161     DeviceMemory<T>* input_h_backprop_data,
2162     DeviceMemory<T>* input_c_backprop_data,
2163     DeviceMemory<T>* params_backprop_data,
2164     DeviceMemory<uint8>* reserve_space_data,
2165     ScratchAllocator* workspace_allocator,
2166     dnn::ProfileResult* output_profile_result) {
2167   TF_ASSIGN_OR_RETURN(
2168       RnnModelDims model_dims,
2169       ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
2170                                 input_h_data, input_c_desc, input_c_data,
2171                                 params, output_desc, output_data, output_h_desc,
2172                                 output_h_data, output_c_desc, output_c_data));
2173 
2174   auto cudnn = cudnn_->GetHandle(parent_, stream);
2175 
2176   TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
2177 
2178   // In cuDNN v8.0, the cudnnRNNBackward*** and cudnnRNNBackward***Ex APIs
2179   // were deprecated. Instead, we use the corresponding _v8 calls, which
2180   // require the sequence_lengths parameter. For more info, see
2181   // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802.
2182 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
2183   if (input_desc.is_var_seq_lengths()) {
2184     DeviceMemory<uint8> workspace;
2185     size_t workspace_size_in_bytes = 0;
2186     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
2187         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2188         /*fMode=*/CUDNN_FWD_MODE_TRAINING, /*xDesc=*/input_desc.data_handle(),
2189         /*workSpaceSize=*/&workspace_size_in_bytes,
2190         /*reserveSpaceSize=*/NULL));
2191     if (workspace_size_in_bytes > 0) {
2192       TF_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
2193                                          workspace_size_in_bytes));
2194     }
2195 
2196     std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2197     const bool is_profiling = output_profile_result != nullptr;
2198     if (is_profiling) {
2199       timer.reset(new GpuTimer(parent_));
2200       // The start and stop of the timer should be as close to the cuDNN call
2201       // as possible. Other threads may still issue work onto this stream, so
2202       // multiple profiling measurements may be needed.
2203       if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2204         return port::Status(port::error::INTERNAL, "Failed to start timer");
2205       }
2206     }
2207 
2208     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8(
2209         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2210         /*devSeqLengths=*/
2211         reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2212         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
2213         /*dy=*/output_backprop_data.opaque(),
2214         /*xDesc=*/input_desc.data_handle(),
2215         /*dx=*/input_backprop_data->opaque(),
2216         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2217         /*dhy=*/output_h_backprop_data.opaque(),
2218         /*dhx=*/input_h_backprop_data->opaque(),
2219         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2220         /*dcy=*/output_c_backprop_data.opaque(),
2221         /*dcx=*/input_c_backprop_data->opaque(),
2222         /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2223         /*weightSpace=*/params.opaque(),
2224         /*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(),
2225         /*reserveSpaceSize=*/reserve_space_data->size(),
2226         /*reserveSpace=*/reserve_space_data->opaque()));
2227 
2228     if (params_backprop_data != nullptr) {
2229       // Clear the dw to zeros.
2230       stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2231       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8(
2232           /*handle=*/cudnn.handle(),
2233           /*rnnDesc=*/rnn_desc.handle(),
2234           /*addGrad=*/CUDNN_WGRAD_MODE_ADD,
2235           /*devSeqLengths=*/
2236           reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2237           /*xDesc=*/input_desc.data_handle(),
2238           /*x=*/input_data.opaque(),
2239           /*hDesc=*/input_h_desc.handle(),
2240           /*hx=*/input_h_data.opaque(),
2241           /*yDesc=*/output_desc.data_handle(),
2242           /*y=*/output_data.opaque(),
2243           /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2244           /*dweightSpace=*/params_backprop_data->opaque(),
2245           /*workSpaceSize=*/workspace.size(),
2246           /*workSpace=*/workspace.opaque(),
2247           /*reserveSpaceSize=*/reserve_space_data->size(),
2248           /*reserveSpace=*/reserve_space_data->opaque()));
2249     }
2250 
2251     if (is_profiling) {
2252       if (!timer->Stop(AsGpuStream(stream))) {
2253         return port::Status(port::error::INTERNAL, "Failed to stop timer");
2254       }
2255       auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2256       output_profile_result->set_algorithm(algo_desc);
2257       output_profile_result->set_elapsed_time_in_ms(
2258           timer->GetElapsedMilliseconds());
2259     }
2260     return port::Status::OK();
2261   }
2262 #endif
2263   TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
2264                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
2265                                          workspace_allocator));
2266 
2267   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2268   const bool is_profiling = output_profile_result != nullptr;
2269   if (is_profiling) {
2270     timer.reset(new GpuTimer(parent_));
2271     // The start and stop of the timer should be as close to the cuDNN call
2272     // as possible. Other threads may still issue work onto this stream, so
2273     // multiple profiling measurements may be needed.
2274     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2275       return port::Status(port::error::INTERNAL, "Failed to start timer");
2276     }
2277   }
2278 
2279   if (input_desc.is_var_seq_lengths()) {
2280     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
2281         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2282         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
2283         /*dyDesc=*/output_desc.data_handle(),
2284         /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
2285         /*dhyDesc=*/output_h_desc.handle(),
2286         /*dhy=*/output_h_backprop_data.opaque(),
2287         /*dcyDesc=*/output_c_desc.handle(),
2288         /*dcy=*/output_c_backprop_data.opaque(),
2289         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2290         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2291         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2292         /*dxDesc=*/input_desc.data_handle(),
2293         /*dx=*/input_backprop_data->opaque(),
2294         /*dhxDesc=*/input_h_desc.handle(),
2295         /*dhx=*/input_h_backprop_data->opaque(),
2296         /*dcxDesc=*/input_c_desc.handle(),
2297         /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
2298         /*workspace=*/workspace.opaque(),
2299         /*workSpaceSizeInBytes=*/workspace.size(),
2300         /*reserveSpace=*/reserve_space_data->opaque(),
2301         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2302   } else {
2303     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
2304         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2305         /*seqLength=*/model_dims.max_seq_length,
2306         /*yDesc=*/output_desc.handles(),
2307         /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
2308         /*dy=*/output_backprop_data.opaque(),
2309         /*dhyDesc=*/output_h_desc.handle(),
2310         /*dhy=*/output_h_backprop_data.opaque(),
2311         /*dcyDesc=*/output_c_desc.handle(),
2312         /*dcy=*/output_c_backprop_data.opaque(),
2313         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2314         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2315         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2316         /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
2317         /*dhxDesc=*/input_h_desc.handle(),
2318         /*dhx=*/input_h_backprop_data->opaque(),
2319         /*dcxDesc=*/input_c_desc.handle(),
2320         /*dcx=*/input_c_backprop_data->opaque(),
2321         /*workspace=*/workspace.opaque(),
2322         /*workSpaceSizeInBytes=*/workspace.size(),
2323         /*reserveSpace=*/reserve_space_data->opaque(),
2324         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2325   }
2326 
2327   if (params_backprop_data != nullptr) {
2328     // Clear the dw to zeros.
2329     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2330     if (input_desc.is_var_seq_lengths()) {
2331       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
2332           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2333           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2334           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2335           /*yDesc=*/output_desc.data_handle(),
2336           /*y=*/output_data.opaque(),
2337           /*workspace=*/workspace.opaque(),
2338           /*workSpaceSizeInBytes=*/workspace.size(),
2339           /*dwDesc=*/rnn_desc.params_handle(),
2340           /*dw=*/params_backprop_data->opaque(),
2341           /*reserveSpace=*/reserve_space_data->opaque(),
2342           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2343     } else {
2344       // Make the backward weights call.
2345       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
2346           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2347           /*seqLength=*/model_dims.max_seq_length,
2348           /*xDesc=*/input_desc.handles(),
2349           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2350           /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
2351           /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
2352           /*workSpaceSizeInBytes=*/workspace.size(),
2353           /*dwDesc=*/rnn_desc.params_handle(),
2354           /*dw=*/params_backprop_data->opaque(),
2355           /*reserveSpace=*/reserve_space_data->opaque(),
2356           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2357     }
2358   }
2359 
2360   if (is_profiling) {
2361     if (!timer->Stop(AsGpuStream(stream))) {
2362       return port::Status(port::error::INTERNAL, "Failed to stop timer");
2363     }
2364     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2365     output_profile_result->set_algorithm(algo_desc);
2366     output_profile_result->set_elapsed_time_in_ms(
2367         timer->GetElapsedMilliseconds());
2368   }
2369 
2370   return ::tensorflow::OkStatus();
2371 }
2372 
2373 port::Status CudnnSupport::DoCtcLossImpl(
2374     Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
2375     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2376     absl::Span<const int> labels_lengths_data,
2377     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2378     const CudnnRnnStateTensorDescriptor& grads_desc,
2379     DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
2380     DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2381   auto cudnn = cudnn_->GetHandle(parent_, stream);
2382 
2383   int kNumTimestamps = probs_desc.num_layers();
2384   int kBatchSize = probs_desc.batch_size();
2385   int kNumLabels = probs_desc.data_size();
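  // total_size is the element count implied by the probs dimensions above; it
  // is currently unused, and the void-cast below only suppresses the
  // unused-variable warning.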
2386   int total_size = kNumLabels * kNumTimestamps * kBatchSize;
2387   (void)total_size;
2388 
2389 #if CUDNN_VERSION >= 7603
2390   cudnnCTCLossAlgo_t ctc_loss_algo =
2391       static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
2392   RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
2393       /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
2394       /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
2395       /*labelLengths=*/labels_lengths_data.data(),
2396       /*inputLengths=*/input_lengths_data.data(),
2397       /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
2398       /*gradients=*/grads_data.opaque(),
2399       /*algo=*/ctc_loss_algo,
2400       /*ctcLossDesc=*/ctc_loss_desc.handle(),
2401       /*workspace=*/scratch_memory.opaque(),
2402       /*workSpaceSizeInBytes=*/scratch_memory.size()));
2403 #else
2404   return port::Status(port::error::INVALID_ARGUMENT,
2405                       "No supported cudnnCTCLoss when "
2406                       "CUDNN_VERSION < 7.6.3");
2407 #endif
2408 
2409   return ::tensorflow::OkStatus();
2410 }
2411 
2412 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2413 CudnnSupport::createRnnDescriptor(
2414     int num_layers, int hidden_size, int input_size, int cell_size,
2415     int batch_size, dnn::RnnInputMode input_mode,
2416     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2417     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2418     float dropout, uint64_t seed, ScratchAllocator* state_allocator,
2419     bool use_padded_io) {
2420   // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
2421   // not enqueueing anything into a stream, we pass in the null stream.
2422   auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
2423   TF_ASSIGN_OR_RETURN(
2424       CudnnRnnDescriptor rnn_desc,
2425       CudnnRnnDescriptor::Create(
2426           cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
2427           ToCudnnRnnInputMode(input_mode),
2428           ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
2429           ToCudnnDataType(data_type), GetRnnComputeType(data_type),
2430           algorithm_config, dropout, seed, state_allocator, use_padded_io));
2431   return std::unique_ptr<dnn::RnnDescriptor>(
2432       new CudnnRnnDescriptor(std::move(rnn_desc)));
2433 }
2434 
2435 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2436 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
2437                                                 int batch_size, int data_size,
2438                                                 dnn::DataType data_type) {
2439   TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2440                       CudnnRnnSequenceTensorDescriptor::Create(
2441                           parent_, max_seq_length, batch_size, data_size,
2442                           ToCudnnDataType(data_type)));
2443   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2444       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2445 }
2446 
2447 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2448 CudnnSupport::createRnnSequenceTensorDescriptor(
2449     int max_seq_length, int batch_size, int data_size,
2450     const absl::Span<const int>& seq_lengths, bool time_major,
2451     dnn::DataType data_type) {
2452   TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2453                       CudnnRnnSequenceTensorDescriptor::Create(
2454                           parent_, max_seq_length, batch_size, data_size,
2455                           seq_lengths, time_major, ToCudnnDataType(data_type)));
2456   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2457       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2458 }
2459 
2460 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2461 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2462                                              int data_size,
2463                                              dnn::DataType data_type) {
2464   return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
2465       new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
2466                                         data_size, ToCudnnDataType(data_type)));
2467 }
2468 
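// The DoRnnForward/DoRnnBackward overloads below differ only in the element
// type (Eigen::half, float, double): each downcasts the generic dnn
// descriptors to their cuDNN-backed counterparts and forwards to the
// templated DoRnnForwardImpl/DoRnnBackwardImpl.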
2469 bool CudnnSupport::DoRnnForward(
2470     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2471     const dnn::RnnSequenceTensorDescriptor& input_desc,
2472     const DeviceMemory<Eigen::half>& input_data,
2473     const DeviceMemory<int>& seq_lengths_data,
2474     const dnn::RnnStateTensorDescriptor& input_h_desc,
2475     const DeviceMemory<Eigen::half>& input_h_data,
2476     const dnn::RnnStateTensorDescriptor& input_c_desc,
2477     const DeviceMemory<Eigen::half>& input_c_data,
2478     const DeviceMemory<Eigen::half>& params,
2479     const dnn::RnnSequenceTensorDescriptor& output_desc,
2480     DeviceMemory<Eigen::half>* output_data,
2481     const dnn::RnnStateTensorDescriptor& output_h_desc,
2482     DeviceMemory<Eigen::half>* output_h_data,
2483     const dnn::RnnStateTensorDescriptor& output_c_desc,
2484     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2485     ScratchAllocator* reserve_space_allocator,
2486     ScratchAllocator* workspace_allocator,
2487     dnn::ProfileResult* output_profile_result) {
2488   const CudnnRnnDescriptor& cudnn_rnn_desc =
2489       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2490   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2491       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2492   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2493       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2494   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2495       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2496   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2497       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2498   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2499       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2500   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2501       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2502   return IsStatusOk(
2503       DoRnnForwardImpl<Eigen::half>(
2504           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2505           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2506           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2507           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2508           output_c_data, is_training, reserve_space_allocator,
2509           workspace_allocator, output_profile_result),
2510       /*report_error=*/!output_profile_result);
2511 }
2512 
2513 bool CudnnSupport::DoRnnForward(
2514     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2515     const dnn::RnnSequenceTensorDescriptor& input_desc,
2516     const DeviceMemory<float>& input_data,
2517     const DeviceMemory<int>& seq_lengths_data,
2518     const dnn::RnnStateTensorDescriptor& input_h_desc,
2519     const DeviceMemory<float>& input_h_data,
2520     const dnn::RnnStateTensorDescriptor& input_c_desc,
2521     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2522     const dnn::RnnSequenceTensorDescriptor& output_desc,
2523     DeviceMemory<float>* output_data,
2524     const dnn::RnnStateTensorDescriptor& output_h_desc,
2525     DeviceMemory<float>* output_h_data,
2526     const dnn::RnnStateTensorDescriptor& output_c_desc,
2527     DeviceMemory<float>* output_c_data, bool is_training,
2528     ScratchAllocator* reserve_space_allocator,
2529     ScratchAllocator* workspace_allocator,
2530     dnn::ProfileResult* output_profile_result) {
2531   const CudnnRnnDescriptor& cudnn_rnn_desc =
2532       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2533   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2534       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2535   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2536       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2537   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2538       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2539   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2540       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2541   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2542       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2543   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2544       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2545   return IsStatusOk(
2546       DoRnnForwardImpl<float>(
2547           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2548           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2549           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2550           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2551           output_c_data, is_training, reserve_space_allocator,
2552           workspace_allocator, output_profile_result),
2553       /*report_error=*/!output_profile_result);
2554 }
2555 
2556 bool CudnnSupport::DoRnnForward(
2557     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2558     const dnn::RnnSequenceTensorDescriptor& input_desc,
2559     const DeviceMemory<double>& input_data,
2560     const DeviceMemory<int>& seq_lengths_data,
2561     const dnn::RnnStateTensorDescriptor& input_h_desc,
2562     const DeviceMemory<double>& input_h_data,
2563     const dnn::RnnStateTensorDescriptor& input_c_desc,
2564     const DeviceMemory<double>& input_c_data,
2565     const DeviceMemory<double>& params,
2566     const dnn::RnnSequenceTensorDescriptor& output_desc,
2567     DeviceMemory<double>* output_data,
2568     const dnn::RnnStateTensorDescriptor& output_h_desc,
2569     DeviceMemory<double>* output_h_data,
2570     const dnn::RnnStateTensorDescriptor& output_c_desc,
2571     DeviceMemory<double>* output_c_data, bool is_training,
2572     ScratchAllocator* reserve_space_allocator,
2573     ScratchAllocator* workspace_allocator,
2574     dnn::ProfileResult* output_profile_result) {
2575   const CudnnRnnDescriptor& cudnn_rnn_desc =
2576       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2577   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2578       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2579   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2580       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2581   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2582       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2583   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2584       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2585   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2586       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2587   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2588       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2589   return IsStatusOk(
2590       DoRnnForwardImpl<double>(
2591           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2592           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2593           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2594           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2595           output_c_data, is_training, reserve_space_allocator,
2596           workspace_allocator, output_profile_result),
2597       /*report_error=*/!output_profile_result);
2598 }
2599 
2600 bool CudnnSupport::DoRnnBackward(
2601     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2602     const dnn::RnnSequenceTensorDescriptor& input_desc,
2603     const DeviceMemory<Eigen::half>& input_data,
2604     const DeviceMemory<int>& seq_lengths_data,
2605     const dnn::RnnStateTensorDescriptor& input_h_desc,
2606     const DeviceMemory<Eigen::half>& input_h_data,
2607     const dnn::RnnStateTensorDescriptor& input_c_desc,
2608     const DeviceMemory<Eigen::half>& input_c_data,
2609     const DeviceMemory<Eigen::half>& params,
2610     const dnn::RnnSequenceTensorDescriptor& output_desc,
2611     const DeviceMemory<Eigen::half>& output_data,
2612     const dnn::RnnStateTensorDescriptor& output_h_desc,
2613     const DeviceMemory<Eigen::half>& output_h_data,
2614     const dnn::RnnStateTensorDescriptor& output_c_desc,
2615     const DeviceMemory<Eigen::half>& output_c_data,
2616     const DeviceMemory<Eigen::half>& output_backprop_data,
2617     const DeviceMemory<Eigen::half>& output_h_backprop_data,
2618     const DeviceMemory<Eigen::half>& output_c_backprop_data,
2619     DeviceMemory<Eigen::half>* input_backprop_data,
2620     DeviceMemory<Eigen::half>* input_h_backprop_data,
2621     DeviceMemory<Eigen::half>* input_c_backprop_data,
2622     DeviceMemory<Eigen::half>* params_backprop_data,
2623     DeviceMemory<uint8>* reserve_space_data,
2624     ScratchAllocator* workspace_allocator,
2625     dnn::ProfileResult* output_profile_result) {
2626   const CudnnRnnDescriptor& cudnn_rnn_desc =
2627       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2628   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2629       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2630   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2631       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2632   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2633       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2634   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2635       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2636   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2637       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2638   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2639       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2640   return IsStatusOk(
2641       DoRnnBackwardImpl<Eigen::half>(
2642           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2643           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2644           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2645           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2646           output_c_data, output_backprop_data, output_h_backprop_data,
2647           output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2648           input_c_backprop_data, params_backprop_data, reserve_space_data,
2649           workspace_allocator, output_profile_result),
2650       /*report_error=*/!output_profile_result);
2651 }
2652 
2653 bool CudnnSupport::DoRnnBackward(
2654     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2655     const dnn::RnnSequenceTensorDescriptor& input_desc,
2656     const DeviceMemory<float>& input_data,
2657     const DeviceMemory<int>& seq_lengths_data,
2658     const dnn::RnnStateTensorDescriptor& input_h_desc,
2659     const DeviceMemory<float>& input_h_data,
2660     const dnn::RnnStateTensorDescriptor& input_c_desc,
2661     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2662     const dnn::RnnSequenceTensorDescriptor& output_desc,
2663     const DeviceMemory<float>& output_data,
2664     const dnn::RnnStateTensorDescriptor& output_h_desc,
2665     const DeviceMemory<float>& output_h_data,
2666     const dnn::RnnStateTensorDescriptor& output_c_desc,
2667     const DeviceMemory<float>& output_c_data,
2668     const DeviceMemory<float>& output_backprop_data,
2669     const DeviceMemory<float>& output_h_backprop_data,
2670     const DeviceMemory<float>& output_c_backprop_data,
2671     DeviceMemory<float>* input_backprop_data,
2672     DeviceMemory<float>* input_h_backprop_data,
2673     DeviceMemory<float>* input_c_backprop_data,
2674     DeviceMemory<float>* params_backprop_data,
2675     DeviceMemory<uint8>* reserve_space_data,
2676     ScratchAllocator* workspace_allocator,
2677     dnn::ProfileResult* output_profile_result) {
2678   const CudnnRnnDescriptor& cudnn_rnn_desc =
2679       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2680   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2681       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2682   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2683       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2684   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2685       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2686   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2687       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2688   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2689       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2690   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2691       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2692   return IsStatusOk(
2693       DoRnnBackwardImpl<float>(
2694           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2695           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2696           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2697           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2698           output_c_data, output_backprop_data, output_h_backprop_data,
2699           output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2700           input_c_backprop_data, params_backprop_data, reserve_space_data,
2701           workspace_allocator, output_profile_result),
2702       /*report_error=*/!output_profile_result);
2703 }
2704 
2705 bool CudnnSupport::DoRnnBackward(
2706     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2707     const dnn::RnnSequenceTensorDescriptor& input_desc,
2708     const DeviceMemory<double>& input_data,
2709     const DeviceMemory<int>& seq_lengths_data,
2710     const dnn::RnnStateTensorDescriptor& input_h_desc,
2711     const DeviceMemory<double>& input_h_data,
2712     const dnn::RnnStateTensorDescriptor& input_c_desc,
2713     const DeviceMemory<double>& input_c_data,
2714     const DeviceMemory<double>& params,
2715     const dnn::RnnSequenceTensorDescriptor& output_desc,
2716     const DeviceMemory<double>& output_data,
2717     const dnn::RnnStateTensorDescriptor& output_h_desc,
2718     const DeviceMemory<double>& output_h_data,
2719     const dnn::RnnStateTensorDescriptor& output_c_desc,
2720     const DeviceMemory<double>& output_c_data,
2721     const DeviceMemory<double>& output_backprop_data,
2722     const DeviceMemory<double>& output_h_backprop_data,
2723     const DeviceMemory<double>& output_c_backprop_data,
2724     DeviceMemory<double>* input_backprop_data,
2725     DeviceMemory<double>* input_h_backprop_data,
2726     DeviceMemory<double>* input_c_backprop_data,
2727     DeviceMemory<double>* params_backprop_data,
2728     DeviceMemory<uint8>* reserve_space_data,
2729     ScratchAllocator* workspace_allocator,
2730     dnn::ProfileResult* output_profile_result) {
2731   const CudnnRnnDescriptor& cudnn_rnn_desc =
2732       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2733   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2734       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2735   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2736       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2737   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2738       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2739   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2740       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2741   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2742       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2743   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2744       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2745   return IsStatusOk(
2746       DoRnnBackwardImpl<double>(
2747           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2748           seq_lengths_data, cudnn_input_h_desc, input_h_data,
2749           cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2750           output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2751           output_c_data, output_backprop_data, output_h_backprop_data,
2752           output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2753           input_c_backprop_data, params_backprop_data, reserve_space_data,
2754           workspace_allocator, output_profile_result),
2755       /*report_error=*/!output_profile_result);
2756 }
2757 
2758 namespace {
2759 
2760 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2761 // and backward filter.
2762 
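// Each of the three GetCudnnConvolution*Algo helpers below asks cuDNN for up
// to five candidate algorithms (via the cudnnGet*Algorithm_v7 APIs on
// cuDNN 8+) and returns the first candidate that succeeded, fits within the
// workspace limit, and is not a WINOGRAD_NONFUSED algorithm.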
2763 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2764     const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2765     const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2766     const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2767     size_t memory_limit_bytes) {
2768 #if CUDNN_VERSION >= 8000
2769   const int num_requested_algos = 5;
2770   int num_returned_algos = 0;
2771   cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos];
2772 
2773   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
2774       cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2775       output_nd.handle(), num_requested_algos, &num_returned_algos,
2776       perf_results));
2777 
2778   size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2779   for (int r = 0; r < num_returned_algos; r++) {
2780     if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2781         perf_results[r].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2782         perf_results[r].memory <= mem_limit) {
2783       return perf_results[r].algo;
2784     }
2785   }
2786   return port::Status(port::error::INTERNAL,
2787                       "cudnnGetConvolutionForwardAlgorithm_v7 returned "
2788                       "no suitable algorithms. This could be a cudnn bug.");
2789 #else
2790   cudnnConvolutionFwdPreference_t preference =
2791       specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2792                               : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2793   cudnnConvolutionFwdAlgo_t algo_to_use;
2794   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2795       cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2796       output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2797   return algo_to_use;
2798 #endif
2799 }
2800 
2801 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2802 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2803                                     const CudnnTensorDescriptor& input_nd,
2804                                     const CudnnFilterDescriptor& filter,
2805                                     const CudnnConvolutionDescriptor& conv,
2806                                     const CudnnTensorDescriptor& output_nd,
2807                                     bool specify_workspace_limit,
2808                                     size_t memory_limit_bytes) {
2809 #if CUDNN_VERSION >= 8000
2810   const int num_requested_algos = 5;
2811   int num_returned_algos = 0;
2812   cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos];
2813 
2814   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(
2815       cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2816       input_nd.handle(), num_requested_algos, &num_returned_algos,
2817       perf_results));
2818 
2819   size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2820   for (int r = 0; r < num_returned_algos; r++) {
2821     if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2822         perf_results[r].algo !=
2823             CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED &&
2824         perf_results[r].memory <= mem_limit) {
2825       return perf_results[r].algo;
2826     }
2827   }
2828   return port::Status(port::error::INTERNAL,
2829                       "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned "
2830                       "no suitable algorithms. This could be a cudnn bug.");
2831 #else
2832   cudnnConvolutionBwdDataPreference_t preference =
2833       specify_workspace_limit
2834           ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2835           : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2836   cudnnConvolutionBwdDataAlgo_t algo_to_use;
2837   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2838       cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2839       input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2840   return algo_to_use;
2841 #endif
2842 }
2843 
2844 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2845 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2846                                       const CudnnTensorDescriptor& input_nd,
2847                                       const CudnnFilterDescriptor& filter,
2848                                       const CudnnConvolutionDescriptor& conv,
2849                                       const CudnnTensorDescriptor& output_nd,
2850                                       bool specify_workspace_limit,
2851                                       size_t memory_limit_bytes) {
2852 #if CUDNN_VERSION >= 8000
2853   const int num_requested_algos = 5;
2854   int num_returned_algos = 0;
2855   cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos];
2856 
2857   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
2858       cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2859       filter.handle(), num_requested_algos, &num_returned_algos, perf_results));
2860 
2861   size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2862   for (int r = 0; r < num_returned_algos; r++) {
2863     if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2864         perf_results[r].algo !=
2865             CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED &&
2866         perf_results[r].memory <= mem_limit) {
2867       return perf_results[r].algo;
2868     }
2869   }
2870   return port::Status(port::error::INTERNAL,
2871                       "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned "
2872                       "no suitable algorithms. This could be a cudnn bug.");
2873 #else
2874   cudnnConvolutionBwdFilterPreference_t preference =
2875       specify_workspace_limit
2876           ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2877           : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2878   cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2879   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2880       cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2881       filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2882   return algo_to_use;
2883 #endif
2884 }
2885 
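// The three Allocate*Workspace helpers below share the same flow: use the
// workspace size recorded in the AlgorithmDesc when available, otherwise
// query cuDNN for the required size, return an empty buffer when no scratch
// is needed, and only then require a scratch allocator.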
2886 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2887     Stream* stream, const CudnnHandle& cudnn,
2888     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2889     const CudnnConvolutionDescriptor& conv,
2890     const CudnnTensorDescriptor& output_nd,
2891     const dnn::AlgorithmDesc& algorithm_desc,
2892     ScratchAllocator* scratch_allocator) {
2893   if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2894     return port::Status(
2895         port::error::INTERNAL,
2896         "Mismatch between cudnn conv and algorithm descriptors.");
2897   }
2898 
2899   // Query the size of the workspace and allocate it.
2900   size_t size_in_bytes;
2901   if (algorithm_desc.workspace_size()) {
2902     size_in_bytes = *algorithm_desc.workspace_size();
2903   } else {
2904     RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2905         cudnn.handle(),
2906         /*xDesc=*/input_nd.handle(),
2907         /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2908         /*yDesc=*/output_nd.handle(),
2909         /*algo=*/ToConvForwardAlgo(algorithm_desc),
2910         /*sizeInBytes=*/&size_in_bytes));
2911   }
2912 
2913   int64_t size_in_bytes_int64_t = size_in_bytes;
2914 
2915   if (TF_PREDICT_FALSE(size_in_bytes_int64_t < 0)) {
2916     return port::Status(
2917         port::error::INTERNAL,
2918         "cudnnGetConvolutionForwardWorkspaceSize() returned "
2919         "negative sizeInBytes value. This could be a cudnn bug.");
2920   }
2921 
2922   if (size_in_bytes_int64_t == 0) {
2923     return DeviceMemory<uint8>();
2924   }
2925 
2926   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2927     return port::Status(port::error::INVALID_ARGUMENT,
2928                         "No scratch allocator provided");
2929   }
2930 
2931   return scratch_allocator->AllocateBytes(size_in_bytes);
2932 }
2933 
2934 port::StatusOr<DeviceMemory<uint8>>
2935 AllocateCudnnConvolutionBackwardDataWorkspace(
2936     Stream* stream, const CudnnHandle& cudnn,
2937     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2938     const CudnnConvolutionDescriptor& conv,
2939     const CudnnTensorDescriptor& output_nd,
2940     const dnn::AlgorithmDesc& algorithm_desc,
2941     ScratchAllocator* scratch_allocator) {
2942   if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2943     return port::Status(
2944         port::error::INTERNAL,
2945         "Mismatch between cudnn conv and algorithm descriptors.");
2946   }
2947 
2948   // Query the size of the workspace and allocate it.
2949   size_t size_in_bytes;
2950   if (algorithm_desc.workspace_size()) {
2951     size_in_bytes = *algorithm_desc.workspace_size();
2952   } else {
2953     RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2954         cudnn.handle(),
2955         /*wDesc=*/filter.handle(),
2956         /*dyDesc=*/output_nd.handle(),
2957         /*convDesc=*/conv.handle(),
2958         /*dxDesc=*/input_nd.handle(),
2959         /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2960         /*sizeInBytes=*/&size_in_bytes));
2961   }
2962 
2963   int64_t size_in_bytes_int64_t = size_in_bytes;
2964 
2965   if (TF_PREDICT_FALSE(size_in_bytes_int64_t < 0)) {
2966     return port::Status(
2967         port::error::INTERNAL,
2968         "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2969         "negative sizeInBytes value. This could be a cudnn bug.");
2970   }
2971 
2972   if (size_in_bytes_int64_t == 0) {
2973     return DeviceMemory<uint8>();
2974   }
2975 
2976   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2977     return port::Status(port::error::INVALID_ARGUMENT,
2978                         "No scratch allocator provided");
2979   }
2980 
2981   return scratch_allocator->AllocateBytes(size_in_bytes);
2982 }
2983 
2984 port::StatusOr<DeviceMemory<uint8>>
2985 AllocateCudnnConvolutionBackwardFilterWorkspace(
2986     Stream* stream, const CudnnHandle& cudnn,
2987     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2988     const CudnnConvolutionDescriptor& conv,
2989     const CudnnTensorDescriptor& output_nd,
2990     const dnn::AlgorithmDesc& algorithm_desc,
2991     ScratchAllocator* scratch_allocator) {
2992   if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2993     return port::Status(
2994         port::error::INTERNAL,
2995         "Mismatch between cudnn conv and algorithm descriptors.");
2996   }
2997 
2998   // Query the size of the workspace and allocate it.
2999   size_t size_in_bytes;
3000   if (algorithm_desc.workspace_size()) {
3001     size_in_bytes = *algorithm_desc.workspace_size();
3002   } else {
3003     RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
3004         cudnn.handle(),
3005         /*xDesc=*/input_nd.handle(),
3006         /*dyDesc=*/output_nd.handle(),
3007         /*convDesc=*/conv.handle(),
3008         /*gradDesc=*/filter.handle(),
3009         /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
3010         /*sizeInBytes=*/&size_in_bytes));
3011   }
3012 
3013   int64_t size_in_bytes_int64_t = size_in_bytes;
3014 
3015   if (TF_PREDICT_FALSE(size_in_bytes_int64_t < 0)) {
3016     return port::Status(
3017         port::error::INTERNAL,
3018         "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
3019         "negative sizeInBytes value. This could be a cudnn bug.");
3020   }
3021 
3022   if (size_in_bytes_int64_t == 0) {
3023     return DeviceMemory<uint8>();
3024   }
3025 
3026   if (TF_PREDICT_FALSE(!scratch_allocator)) {
3027     return port::Status(port::error::INVALID_ARGUMENT,
3028                         "No scratch allocator provided");
3029   }
3030 
3031   return scratch_allocator->AllocateBytes(size_in_bytes);
3032 }
3033 
3034 port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
3035                                   std::optional<dnn::AlgorithmDesc> desc) {
3036   bool use_tensor_ops;
3037   if (desc.has_value()) {
3038     use_tensor_ops = desc->tensor_ops_enabled();
3039     if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) {
3040       return port::Status(port::error::INVALID_ARGUMENT,
3041                           "Algo requests disabled tensor op evaluation.");
3042     }
3043   } else {
3044     use_tensor_ops = IsTensorMathEnabled(stream, type);
3045   }
3046   return use_tensor_ops;
3047 }
3048 
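// Forward declarations for helpers defined further down in this file.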
3049 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
3050 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type);
3051 
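// The three GetCudnnConvolution*Algorithm wrappers below resolve an
// AlgorithmDesc for a convolution: they honor an explicitly configured
// algorithm when present, otherwise consult cuDNN's heuristics, then try to
// allocate the required workspace and fall back to the configured
// "no scratch" algorithm if that allocation fails.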
3052 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
3053     Stream* stream, const CudnnHandle& cudnn,
3054     const dnn::AlgorithmConfig& algorithm_config,
3055     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3056     dnn::DataType element_type,
3057     const dnn::ConvolutionDescriptor& convolution_descriptor,
3058     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3059     DeviceMemory<uint8>* scratch) {
3060   std::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3061 
3062   CudnnConvolutionDescriptor conv(
3063       convolution_descriptor,
3064       ToCudnnDataType(GetConvAccumulatorType(element_type)));
3065   bool use_tensor_ops;
3066   TF_ASSIGN_OR_RETURN(use_tensor_ops,
3067                       UseTensorOps(stream, element_type, algo_desc));
3068   conv.set_use_tensor_op_math(use_tensor_ops);
3069 
3070   if (!algo_desc.has_value()) {
3071     // Pick fastest algorithm within memory limit according to cuDNN's
3072     // heuristics.
3073     bool specify_workspace_limit = scratch_allocator != nullptr;
3074     auto memory_limit_bytes =
3075         specify_workspace_limit
3076             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0})
3077             : int64_t{0};
3078     TF_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
3079                         GetCudnnConvolutionForwardAlgo(
3080                             cudnn, input_nd, filter, conv, output_nd,
3081                             specify_workspace_limit, memory_limit_bytes));
3082     algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3083   }
3084 
3085   const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
3086       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3087       scratch_allocator);
3088 
3089   if (scratch_or.ok()) {
3090     *scratch = scratch_or.ValueOrDie();
3091     return *algo_desc;
3092   }
3093 
3094   algo_desc = algorithm_config.algorithm_no_scratch();
3095 
3096   // Failed to allocate workspace for the first algorithm; fall back to the
3097   // no_scratch algorithm.
3098   if (!algo_desc.has_value()) {
3099     return port::Status(
3100         scratch_or.status().code(),
3101         absl::StrCat("The primary convolution algorithm failed, ",
3102                      "while a secondary algorithm is not provided. ",
3103                      "Returned status: ", scratch_or.status().ToString()));
3104   }
3105 
3106   TF_ASSIGN_OR_RETURN(use_tensor_ops,
3107                       UseTensorOps(stream, element_type, algo_desc));
3108   conv.set_use_tensor_op_math(use_tensor_ops);
3109   TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
3110                                     stream, cudnn, input_nd, filter, conv,
3111                                     output_nd, *algo_desc, scratch_allocator));
3112   return *algo_desc;
3113 }
3114 
3115 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
3116     Stream* stream, const CudnnHandle& cudnn,
3117     const dnn::AlgorithmConfig& algorithm_config,
3118     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3119     dnn::DataType element_type,
3120     const dnn::ConvolutionDescriptor& convolution_descriptor,
3121     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3122     DeviceMemory<uint8>* scratch) {
3123   std::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3124   CudnnConvolutionDescriptor conv(
3125       convolution_descriptor,
3126       ToCudnnDataType(GetConvAccumulatorType(element_type)));
3127   bool use_tensor_ops;
3128   TF_ASSIGN_OR_RETURN(use_tensor_ops,
3129                       UseTensorOps(stream, element_type, algo_desc));
3130   conv.set_use_tensor_op_math(use_tensor_ops);
3131 
3132   if (!algo_desc.has_value()) {
3133     // Pick fastest algorithm within memory limit according to cuDNN's
3134     // heuristics.
3135     bool specify_workspace_limit = scratch_allocator != nullptr;
3136     auto memory_limit_bytes =
3137         specify_workspace_limit
3138             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0})
3139             : int64_t{0};
3140     TF_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
3141                         GetCudnnConvolutionBackwardDataAlgo(
3142                             cudnn, input_nd, filter, conv, output_nd,
3143                             specify_workspace_limit, memory_limit_bytes));
3144     algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3145   }
3146 
3147   const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
3148       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3149       scratch_allocator);
3150 
3151   if (scratch_or.ok()) {
3152     *scratch = scratch_or.ValueOrDie();
3153     return *algo_desc;
3154   }
3155 
3156   algo_desc = algorithm_config.algorithm_no_scratch();
3157 
3158   // Failed to allocate workspace for the first algorithm; fall back to the
3159   // no_scratch algorithm.
3160   if (!algo_desc.has_value()) {
3161     return port::Status(
3162         port::error::INVALID_ARGUMENT,
3163         "The primary convolution algorithm failed memory allocation, "
3164         "while a secondary algorithm is not provided.");
3165   }
3166 
3167   TF_ASSIGN_OR_RETURN(use_tensor_ops,
3168                       UseTensorOps(stream, element_type, algo_desc));
3169   conv.set_use_tensor_op_math(use_tensor_ops);
3170   TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
3171                                     stream, cudnn, input_nd, filter, conv,
3172                                     output_nd, *algo_desc, scratch_allocator));
3173   return *algo_desc;
3174 }
3175 
3176 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
3177     Stream* stream, const CudnnHandle& cudnn,
3178     const dnn::AlgorithmConfig& algorithm_config,
3179     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3180     dnn::DataType element_type,
3181     const dnn::ConvolutionDescriptor& convolution_descriptor,
3182     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3183     DeviceMemory<uint8>* scratch) {
3184   std::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3185   CudnnConvolutionDescriptor conv(
3186       convolution_descriptor,
3187       ToCudnnDataType(GetConvAccumulatorType(element_type)));
3188   bool use_tensor_ops;
3189   TF_ASSIGN_OR_RETURN(use_tensor_ops,
3190                       UseTensorOps(stream, element_type, algo_desc));
3191   conv.set_use_tensor_op_math(use_tensor_ops);
3192 
3193   if (!algo_desc.has_value()) {
3194     // Pick fastest algorithm within memory limit according to cuDNN's
3195     // heuristics.
3196     bool specify_workspace_limit = scratch_allocator != nullptr;
3197     auto memory_limit_bytes =
3198         specify_workspace_limit
3199             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0})
3200             : int64_t{0};
3201     TF_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
3202                         GetCudnnConvolutionBackwardFilterAlgo(
3203                             cudnn, input_nd, filter, conv, output_nd,
3204                             specify_workspace_limit, memory_limit_bytes));
3205     algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3206   }
3207 
3208   port::StatusOr<DeviceMemory<uint8>> scratch_or =
3209       AllocateCudnnConvolutionBackwardFilterWorkspace(
3210           stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3211           scratch_allocator);
3212 
3213   if (scratch_or.ok()) {
3214     *scratch = scratch_or.ValueOrDie();
3215     return *algo_desc;
3216   }
3217 
3218   algo_desc = algorithm_config.algorithm_no_scratch();
3219 
3220   // Failed to allocate workspace for the first algorithm; fall back to the
3221   // no_scratch algorithm.
3222   if (!algo_desc.has_value()) {
3223     return port::Status(
3224         port::error::INVALID_ARGUMENT,
3225         absl::StrCat(
3226             "The primary convolution algorithm failed memory allocation, "
3227             "while a secondary algorithm is not provided. Actual error: ",
3228             scratch_or.status().ToString()));
3229   }
3230 
3231   TF_ASSIGN_OR_RETURN(use_tensor_ops,
3232                       UseTensorOps(stream, element_type, algo_desc));
3233   conv.set_use_tensor_op_math(use_tensor_ops);
3234   TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
3235                                     stream, cudnn, input_nd, filter, conv,
3236                                     output_nd, *algo_desc, scratch_allocator));
3237   return *algo_desc;
3238 }
3239 
3240 // A helper class to set env-vars and choose options for cudnn-related
3241 // algorithms.
3242 template <typename EnvVar>
3243 class CudnnEnvVar {
3244  public:
3245   static bool IsEnabled() {
3246     static bool is_enabled = IsEnabledImpl();
3247     return is_enabled;
3248   }
3249 
3250  private:
3251   static bool IsEnabledImpl() {
3252     const char* tf_env_var_val = getenv(EnvVar::kName);
3253     if (tf_env_var_val != nullptr) {
3254       absl::string_view tf_env_var_val_str(tf_env_var_val);
3255       if (tf_env_var_val_str == "0") {
3256         return false;
3257       }
3258       return true;
3259     }
3260     return EnvVar::kDefaultFlag;
3261   }
3262 };
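// Typical usage, as in GetRnnComputeType and GetConvActivationType below:
//   if (CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()) { ... }
// The flag is read from the environment once and cached in a function-local
// static, so later changes to the env-var have no effect within the process.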
3263 
3264 // A helper struct to decide whether to enable the FFT_TILING algorithms for
3265 // forward convolution. It was disabled for cuDNN < 7 due to memory corruption
3266 // caused by some shapes with this algorithm; it is now on by default and can
3267 // be explicitly disabled through the env-var "TF_ENABLE_FFT_TILING_FORWARD=0".
3268 struct FftTilingForward {
3269   static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
3270   static constexpr bool kDefaultFlag = true;
3271 };
3272 
3273 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
3274 // By default they are turned on; users can explicitly disable them through
3275 // the env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
3276 // https://github.com/tensorflow/tensorflow/pull/4901
3277 // For cuDNN v8.1, when this env-var is turned off, both the winograd and
3278 // winograd-non-fused engines are ruled out.
3279 struct WinogradNonfused {
3280   static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
3281   // NVIDIA fixed the winograd-nonfused bug in cuDNN v>=7. For older versions,
3282   // we have a workaround.
3283   static constexpr bool kDefaultFlag = true;
3284 };
3285 
3286 // A helper struct to decide whether to use FP32 as the internal compute type
3287 // for convolution when the input data type is FP16. By default it is turned
3288 // on; users can explicitly disable it (i.e. choose FP16 as the internal
3289 // compute type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
3290 struct ConvDoFP32ComputationFP16Input {
3291   static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
3292   // Using FP16 as the internal compute type for convolution when the input data
3293   // type is FP16 is only supported on architectures with true fp16 support
3294   // (compute capability 5.3 and 6.0). Setting this to false on an unsupported
3295   // architecture will cause internal errors.
3296   static constexpr bool kDefaultFlag = true;
3297 };
3298 
3299 // A helper struct to decide whether to use FP32 as the internal compute type
3300 // for RNNs when the input data type is FP16. It is on by default for
3301 // cuDNN >= 7.5 (see kDefaultFlag below); users can explicitly control it
3302 // through the env-var TF_FP16_RNN_USE_FP32_COMPUTE.
3303 // After the TODO below is fixed, users should almost always use the fp32
3304 // compute type for training; using fp16 may give suboptimal accuracy due to
3305 // loss of precision.
3306 struct RnnDoFP32ComputationFP16Input {
3307   static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
3308   // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
3309   // Before cudnn 7.1.4 RNN are always done in fp32, no matter what math
3310   // precision is set.
3311   // Set it temporarily to false so that no error is raised when using fp16
3312   // inputs with fp32 math precision.
3313   //
3314   // cuDNN == 7.5.0 is verified to have this fixed.
3315   static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500;
3316 };
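// How these flags are consumed: IsEnabledImpl treats the literal value "0" as
// "disabled", any other set value as "enabled", and falls back to kDefaultFlag
// when the variable is unset. For example, launching a process with
// TF_FP16_RNN_USE_FP32_COMPUTE=0 forces fp16 compute for fp16 RNN inputs.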
3317 
3318 namespace {
3319 
3320 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3321 
3322 bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config,
3323                          bool disable_winograd, bool disable_nondeterminism,
3324                          bool disable_tensor_core) {
3325   bool ret = cudnn_frontend::hasNumericalNote<
3326       CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(engine_config);
3327 
3328   if (disable_winograd) {
3329     ret |= cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_WINOGRAD>(
3330         engine_config);
3331   }
3332 
3333   if (disable_nondeterminism) {
3334     ret |=
3335         cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(
3336             engine_config);
3337   }
3338 
3339   if (disable_tensor_core) {
3340     ret |= cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(
3341         engine_config);
3342   }
3343 
3344   return ret;
3345 }
3346 
3347 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3348 
3349 }  // namespace
3350 
3351 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
3352   switch (data_type) {
3353     case dnn::DataType::kFloat:
3354       return CUDNN_DATA_FLOAT;
3355     case dnn::DataType::kDouble:
3356       return CUDNN_DATA_DOUBLE;
3357     case dnn::DataType::kHalf:
3358       if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
3359         return CUDNN_DATA_FLOAT;
3360       } else {
3361         return CUDNN_DATA_HALF;
3362       }
3363     default:
3364       LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
3365   }
3366 }
3367 
3368 dnn::DataType GetConvActivationType(dnn::DataType data_type) {
3369   switch (data_type) {
3370     case dnn::DataType::kFloat:
3371     case dnn::DataType::kDouble:
3372       return data_type;
3373     // TODO(awpr): it's not clear whether float-precision activations on
3374     // half-precision convs are supported; double-check.
3375     case dnn::DataType::kHalf:
3376       return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3377                  ? dnn::DataType::kFloat
3378                  : dnn::DataType::kHalf;
3379     case dnn::DataType::kInt8:
3380     case dnn::DataType::kInt32:  // TODO(awpr): does int32 do blending in float?
3381       return dnn::DataType::kFloat;
3382 #if CUDNN_VERSION >= 8200
3383     // TODO(awpr): as with kHalf, this is not clear.
3384     case dnn::DataType::kBF16:
3385       return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3386                  ? dnn::DataType::kFloat
3387                  : dnn::DataType::kBF16;
3388 #endif
3389     default:
3390       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
3391   }
3392 }
3393 
3394 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
3395   switch (data_type) {
3396     case dnn::DataType::kFloat:
3397     case dnn::DataType::kDouble:
3398       return data_type;
3399     case dnn::DataType::kHalf:
3400       return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3401                  ? dnn::DataType::kFloat
3402                  : dnn::DataType::kHalf;
3403     case dnn::DataType::kInt8:
3404     case dnn::DataType::kInt32:
3405       return dnn::DataType::kInt32;
3406 #if CUDNN_VERSION >= 8200
3407     case dnn::DataType::kBF16:
3408       return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3409                  ? dnn::DataType::kFloat
3410                  : dnn::DataType::kBF16;
3411 #endif
3412     default:
3413       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
3414   }
3415 }
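// Summary of the accumulator mapping above: fp32/fp64 accumulate in their own
// type; fp16 (and bf16 with cuDNN >= 8.2) accumulate in fp32 unless
// TF_FP16_CONV_USE_FP32_COMPUTE=0; int8 and int32 accumulate in int32.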
3416 
3417 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3418 
3419 namespace {
3420 cudnnBackendHeurMode_t GetCudnnFrontendHeurMode() {
3421 #if CUDNN_VERSION >= 8300
3422   return CUDNN_HEUR_MODE_B;
3423 #else
3424   return CUDNN_HEUR_MODE_INSTANT;
3425 #endif  // CUDNN_VERSION >= 8300
3426 }
3427 
3428 cudnnBackendDescriptorType_t GetCudnnConvolutionType(
3429     dnn::ConvolutionKind kind) {
3430   cudnnBackendDescriptorType_t conv_mode;
3431   switch (kind) {
3432     case dnn::ConvolutionKind::FORWARD: {
3433       conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
3434       break;
3435     }
3436     case dnn::ConvolutionKind::BACKWARD_DATA: {
3437       conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR;
3438       break;
3439     }
3440     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3441       conv_mode =
3442           CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR;
3443       break;
3444     }
3445     default:
3446       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3447       break;
3448   }
3449   return conv_mode;
3450 }
3451 
3452 // cuDNN only supports vectorization over the channel dimension (e.g., int8x4
3453 // or int8x32).
3454 std::tuple<int, int> GetTensorVectorSizeAndDim(
3455     const dnn::BatchDescriptor& tensor, dnn::DataType element_type) {
3456   int vector_size = 1;
3457   int vector_dim = -1;
3458   if (element_type == dnn::DataType::kInt8) {
3459     if (tensor.layout() == dnn::DataLayout::kBatchDepthYX4) {
3460       vector_size = 4;
3461       vector_dim = 1;
3462     } else if (tensor.layout() == dnn::DataLayout::kBatchDepthYX32) {
3463       vector_size = 32;
3464       vector_dim = 1;
3465     }
3466   }
3467   return std::make_tuple(vector_size, vector_dim);
3468 }
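// For example, an int8 tensor with layout kBatchDepthYX4 yields {4, 1}: four
// elements packed along dimension 1 (the channel dimension). Every other
// type/layout combination yields {1, -1}, i.e. no vectorization.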
3469 
3470 std::tuple<int, int> GetTensorVectorSizeAndDim(
3471     const dnn::FilterDescriptor& filter, dnn::DataType element_type) {
3472   int vector_size = 1;
3473   int vector_dim = -1;
3474   if (element_type == dnn::DataType::kInt8) {
3475     if (filter.layout() == dnn::FilterLayout::kOutputInputYX4) {
3476       vector_size = 4;
3477       vector_dim = 1;
3478     } else if (filter.layout() == dnn::FilterLayout::kOutputInputYX32) {
3479       vector_size = 32;
3480       vector_dim = 1;
3481     }
3482   }
3483   return std::make_tuple(vector_size, vector_dim);
3484 }
3485 
3486 port::StatusOr<cudnn_frontend::Tensor> CreateCudnnTensor(
3487     absl::Span<const int64_t> dims, absl::Span<const int64_t> strides,
3488     int64_t uid, dnn::DataType dtype, int64_t vec_count, int64_t vec_dim,
3489     bool is_virtual = false) {
3490   auto tensor = cudnn_frontend::TensorBuilder()
3491                     .setDim(dims.size(), dims.data())
3492                     .setStrides(strides.size(), strides.data())
3493                     .setId(uid)
3494                     .setAlignment(32)
3495                     .setDataType(ToCudnnDataType(dtype))
3496                     .setVectorCountAndDimension(vec_count, vec_dim)
3497                     .setVirtual(is_virtual)
3498                     .build();
3499   RETURN_MSG_IF_CUDNN_ERROR(tensor);
3500   return tensor;
3501 }
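// Illustrative (hypothetical) call: a packed 4-D NCHW fp32 tensor with dims
// {n, c, h, w}, strides {c * h * w, h * w, w, 1}, and uid 'x' could be built as
//   CreateCudnnTensor({n, c, h, w}, {c * h * w, h * w, w, 1}, 'x',
//                     dnn::DataType::kFloat, /*vec_count=*/1, /*vec_dim=*/-1);
// The real callers below derive dims and strides from the dnn descriptors.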
3502 
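// Builds a single-operation cuDNN frontend graph for the given convolution
// kind: tensors 'x' (input), 'w' (filter), and 'y' (output) connected by one
// convolution (or cross-correlation) node.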
3503 port::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
3504 GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type,
3505                        dnn::DataType output_type,
3506                        const dnn::BatchDescriptor& input_descriptor,
3507                        const dnn::FilterDescriptor& filter_descriptor,
3508                        const dnn::BatchDescriptor& output_descriptor,
3509                        const dnn::ConvolutionDescriptor& convolution_descriptor,
3510                        CudnnHandle& cudnn) {
3511   PreloadCudnnSubLibsHelper(kind);
3512 
3513   cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3514 
3515   // x tensor.
3516   int vector_size, vector_dim;
3517   std::tie(vector_size, vector_dim) =
3518       GetTensorVectorSizeAndDim(input_descriptor, input_type);
3519   std::vector<int64_t> input_dims = input_descriptor.vectorized_dims(
3520       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3521   std::vector<int64_t> input_strides = input_descriptor.vectorized_strides(
3522       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3523 
3524   if (vector_size == 32) {
3525     return port::InternalError(
3526         "cuDNN frontend doesn't support Tx32 at the moment.");
3527   }
3528 
3529   TF_ASSIGN_OR_RETURN(auto tensor_x,
3530                       CreateCudnnTensor(input_dims, input_strides, 'x',
3531                                         input_type, vector_size, vector_dim));
3532 
3533   // y tensor.
3534   std::tie(vector_size, vector_dim) =
3535       GetTensorVectorSizeAndDim(output_descriptor, output_type);
3536   std::vector<int64_t> output_dims = output_descriptor.vectorized_dims(
3537       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3538   std::vector<int64_t> output_strides = output_descriptor.vectorized_strides(
3539       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3540 
3541   TF_ASSIGN_OR_RETURN(auto tensor_y,
3542                       CreateCudnnTensor(output_dims, output_strides, 'y',
3543                                         output_type, vector_size, vector_dim));
3544 
3545   // w tensor.
3546   std::tie(vector_size, vector_dim) =
3547       GetTensorVectorSizeAndDim(filter_descriptor, input_type);
3548   std::vector<int64_t> filter_dims = filter_descriptor.vectorized_dims(
3549       dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3550   std::vector<int64_t> filter_strides = filter_descriptor.vectorized_strides(
3551       dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3552 
3553   TF_ASSIGN_OR_RETURN(auto tensor_w,
3554                       CreateCudnnTensor(filter_dims, filter_strides, 'w',
3555                                         input_type, vector_size, vector_dim));
3556 
3557   // conv_desc.
3558   auto mode = convolution_descriptor.convolution_not_crosscorr()
3559                   ? CUDNN_CONVOLUTION
3560                   : CUDNN_CROSS_CORRELATION;
3561 
3562   int conv_dim = convolution_descriptor.ndims();
3563 
3564   auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(input_type));
3565   CHECK_NE(convolution_descriptor.pad_alignment(),
3566            dnn::PadAlignment::kTensorFlowPadding)
3567       << "TensorFlow padding alignment is not supported.";
3568 
3569   auto conv_desc =
3570       cudnn_frontend::ConvDescBuilder()
3571           .setComputePrecision(accumulator_type)
3572           .setMathMode(mode)
3573           .setNDims(conv_dim)
3574           .setStrides(conv_dim, convolution_descriptor.strides().data())
3575           .setPrePadding(conv_dim, convolution_descriptor.padding().data())
3576           .setPostPadding(conv_dim, convolution_descriptor.padding().data())
3577           .setDilation(conv_dim, convolution_descriptor.dilations().data())
3578           .build();
3579   RETURN_MSG_IF_CUDNN_ERROR(conv_desc);
3580 
3581   double alpha = 1.0;
3582   double beta = 0.0;
3583 
3584   // CUDNN Operation
3585   auto op = cudnn_frontend::OperationBuilder(conv_mode)
3586                 .setxDesc(tensor_x)
3587                 .setyDesc(tensor_y)
3588                 .setwDesc(tensor_w)
3589                 .setcDesc(conv_desc)
3590                 .setAlpha(alpha)
3591                 .setBeta(beta)
3592                 .build();
3593   RETURN_MSG_IF_CUDNN_ERROR(op);
3594 
3595   // CUDNN OperationGraph
3596   std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
3597   auto opGraph = cudnn_frontend::OperationGraphBuilder()
3598                      .setHandle(cudnn.handle())
3599                      .setOperationGraph(ops.size(), ops.data())
3600                      .build();
3601   RETURN_MSG_IF_CUDNN_ERROR(opGraph);
3602 
3603   VLOG(4) << "\nTensor_x: " << tensor_x.describe()
3604           << "\nTensor_y: " << tensor_y.describe()
3605           << "\nTensor_w: " << tensor_w.describe()
3606           << "\nConv: " << conv_desc.describe() << "\nOp: " << op.describe()
3607           << "\nOpGraph: " << opGraph.describe();
3608 
3609   return std::unique_ptr<cudnn_frontend::OperationGraph>(
3610       new cudnn_frontend::OperationGraph(std::move(opGraph)));
3611 }
3612 
3613 bool SideInputNeeded(dnn::ActivationMode activation_mode, double conv_scale,
3614                      double side_input_scale) {
3615   // cuDNN uses precompiled kernels to perform Conv-Add-BiasAdd-Act when the
3616   // activation is Relu or Identity, and this requires the "side_input" for the
3617   // Add. For other activations, cuDNN uses runtime-compiled kernels; in that
3618   // case we need to drop the Add node and use the Conv-BiasAdd-Act pattern to
3619   // trigger the correct cuDNN path.
3620   // TODO(kaixih@nvidia): We should remove this workaround once cuDNN fixes it.
3621   bool check_activation = activation_mode == dnn::ActivationMode::kNone ||
3622                           activation_mode == dnn::ActivationMode::kRelu;
3623   bool check_scale = conv_scale != 1.0 || side_input_scale != 0.0;
3624   return check_activation || check_scale;
3625 }
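// Examples of the rule above: kRelu or kNone always returns true (keep the Add
// and its side input); kElu with conv_scale == 1.0 and side_input_scale == 0.0
// returns false, so the Add node is dropped by the caller.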
3626 
3627 port::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
3628 GetCudnnFusedOperationGraph(
3629     dnn::ConvolutionKind kind, dnn::DataType input_type,
3630     dnn::DataType bias_type, dnn::DataType output_type, double alpha,
3631     double alpha2, double leakyrelu_alpha,
3632     const dnn::BatchDescriptor& input_descriptor,
3633     const dnn::FilterDescriptor& filter_descriptor,
3634     dnn::BatchDescriptor bias_descriptor,
3635     const dnn::BatchDescriptor& output_descriptor,
3636     const dnn::ConvolutionDescriptor& convolution_descriptor,
3637     const dnn::ActivationMode activation_mode, CudnnHandle& cudnn) {
3638   PreloadCudnnSubLibsHelper(kind);
3639 
3640   cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3641   dnn::DataType accumulator_type = GetConvAccumulatorType(input_type);
3642   dnn::DataType activation_type = GetConvActivationType(input_type);
3643 
3644   // CUDNN fused operation supports the pattern in the form of
3645   // Conv + Add + BiasAdd + Act. Therefore, we need to build a graph of the
3646   // four ops with their input/output tensor edges:
3647   // Conv   : input: tensor_x, tensor_w;    output: tensor_conv (virtual)
3648   // Add    : input: tensor_conv, tensor_z; output: tensor_add (virtual)
3649   // BiasAdd: input: tensor_add, tensor_b;  output: tensor_bias (virtual)
3650   // Act    : input: tensor_bias;           output: tensor_y
3651   int vector_size, vector_dim;
3652   std::tie(vector_size, vector_dim) =
3653       GetTensorVectorSizeAndDim(input_descriptor, input_type);
3654   std::vector<int64_t> input_dims = input_descriptor.vectorized_dims(
3655       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3656   std::vector<int64_t> input_strides = input_descriptor.vectorized_strides(
3657       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3658 
3659   if (vector_size == 32) {
3660     return port::InternalError(
3661         "cuDNN frontend doesn't support Tx32 at the moment.");
3662   }
3663 
3664   TF_ASSIGN_OR_RETURN(auto tensor_x,
3665                       CreateCudnnTensor(input_dims, input_strides, 'x',
3666                                         input_type, vector_size, vector_dim));
3667 
3668   std::tie(vector_size, vector_dim) =
3669       GetTensorVectorSizeAndDim(output_descriptor, output_type);
3670   std::vector<int64_t> output_dims = output_descriptor.vectorized_dims(
3671       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3672   std::vector<int64_t> output_strides = output_descriptor.vectorized_strides(
3673       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3674   TF_ASSIGN_OR_RETURN(auto tensor_y,
3675                       CreateCudnnTensor(output_dims, output_strides, 'y',
3676                                         output_type, vector_size, vector_dim));
3677 
3678   TF_ASSIGN_OR_RETURN(auto tensor_z,
3679                       CreateCudnnTensor(output_dims, output_strides, 'z',
3680                                         output_type, vector_size, vector_dim));
3681 
3682   std::tie(vector_size, vector_dim) =
3683       GetTensorVectorSizeAndDim(filter_descriptor, input_type);
3684   std::vector<int64_t> filter_dims = filter_descriptor.vectorized_dims(
3685       dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3686   std::vector<int64_t> filter_strides = filter_descriptor.vectorized_strides(
3687       dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3688   TF_ASSIGN_OR_RETURN(auto tensor_w,
3689                       CreateCudnnTensor(filter_dims, filter_strides, 'w',
3690                                         input_type, vector_size, vector_dim));
3691 
3692   // For the purposes of the cudnn graph, say that the bias tensor has the same
3693   // layout as the output tensor.  It doesn't actually matter, because bias is a
3694   // 1D array.  But we need to get the correct vectorization; otherwise the
3695   // cudnn graph API rejects this tensor, even though vectorized float tensors
3696   // aren't even a thing in cuDNN.
3697   bias_descriptor.set_layout(output_descriptor.layout());
3698 
3699   // Even more unnecessarily subtle: since vectorized float tensors don't exist,
3700   // `GetTensorVectorSizeAndDim` ignores vectorized layouts for floating-point types,
3701   // so we have to ask it for vector sizes as if the type were `input_type`, as
3702   // opposed to `bias_type`.  For non-int8 types, these are the same anyway, so
3703   // this only affects int8 convolutions.
3704   std::tie(vector_size, vector_dim) =
3705       GetTensorVectorSizeAndDim(bias_descriptor, input_type);
3706   std::vector<int64_t> bias_dims = bias_descriptor.vectorized_dims(
3707       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3708   std::vector<int64_t> bias_strides = bias_descriptor.vectorized_strides(
3709       dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3710   TF_ASSIGN_OR_RETURN(auto tensor_b,
3711                       CreateCudnnTensor(bias_dims, bias_strides, 'b', bias_type,
3712                                         vector_size, vector_dim));
3713 
3714   std::tie(vector_size, vector_dim) =
3715       GetTensorVectorSizeAndDim(output_descriptor, output_type);
3716   TF_ASSIGN_OR_RETURN(
3717       auto tensor_conv,
3718       CreateCudnnTensor(output_dims, output_strides, 'C', accumulator_type,
3719                         vector_size, vector_dim, /*is_virtual=*/true));
3720 
3721   TF_ASSIGN_OR_RETURN(
3722       auto tensor_add,
3723       CreateCudnnTensor(output_dims, output_strides, 'A', activation_type,
3724                         vector_size, vector_dim, /*is_virtual=*/true));
3725 
3726   TF_ASSIGN_OR_RETURN(
3727       auto tensor_bias,
3728       CreateCudnnTensor(output_dims, output_strides, 'B', activation_type,
3729                         vector_size, vector_dim, /*is_virtual=*/true));
3730 
3731   // conv_desc.
3732   auto mode = convolution_descriptor.convolution_not_crosscorr()
3733                   ? CUDNN_CONVOLUTION
3734                   : CUDNN_CROSS_CORRELATION;
3735 
3736   int conv_dim = convolution_descriptor.ndims();
3737 
3738   CHECK_NE(convolution_descriptor.pad_alignment(),
3739            dnn::PadAlignment::kTensorFlowPadding)
3740       << "TensorFlow padding alignment is not supported.";
3741 
3742   cudnnDataType_t cudnn_convolution_type = ToCudnnDataType(accumulator_type);
3743   cudnnDataType_t cudnn_activation_type = ToCudnnDataType(activation_type);
3744   auto conv_desc =
3745       cudnn_frontend::ConvDescBuilder()
3746           .setComputePrecision(cudnn_convolution_type)
3747           .setMathMode(mode)
3748           .setNDims(conv_dim)
3749           .setStrides(conv_dim, convolution_descriptor.strides().data())
3750           .setPrePadding(conv_dim, convolution_descriptor.padding().data())
3751           .setPostPadding(conv_dim, convolution_descriptor.padding().data())
3752           .setDilation(conv_dim, convolution_descriptor.dilations().data())
3753           .build();
3754   RETURN_MSG_IF_CUDNN_ERROR(conv_desc);
3755 
3756   // CUDNN Operation
3757   auto conv_op = cudnn_frontend::OperationBuilder(conv_mode)
3758                      .setxDesc(tensor_x)
3759                      .setyDesc(tensor_conv)
3760                      .setwDesc(tensor_w)
3761                      .setcDesc(conv_desc)
3762                      .setAlpha(1.0f)
3763                      .setBeta(0.0f)
3764                      .build();
3765   RETURN_MSG_IF_CUDNN_ERROR(conv_op);
3766 
3767   // CUDNN OperationGraph
3768   absl::InlinedVector<cudnn_frontend::Operation const*, 4> ops = {&conv_op};
3769 
3770   bool need_add_op = SideInputNeeded(activation_mode, alpha, alpha2);
3771 
3772   std::optional<cudnn_frontend::PointWiseDesc_v8> add_desc;
3773   std::optional<cudnn_frontend::Operation_v8> add_op;
3774   if (need_add_op) {
3775     add_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3776                          .setMode(CUDNN_POINTWISE_ADD)
3777                          .setMathPrecision(cudnn_activation_type)
3778                          .build());
3779     RETURN_MSG_IF_CUDNN_ERROR(*add_desc);
3780     add_op.emplace(cudnn_frontend::OperationBuilder(
3781                        CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3782                        .setxDesc(conv_op.getOutputTensor())
3783                        .setbDesc(tensor_z)
3784                        .setyDesc(tensor_add)
3785                        .setpwDesc(*add_desc)
3786                        .setAlpha(alpha)
3787                        .setAlpha2(alpha2)
3788                        .build());
3789     RETURN_MSG_IF_CUDNN_ERROR(*add_op);
3790     ops.push_back(&*add_op);
3791   }
3792 
3793   auto bias_add_desc = cudnn_frontend::PointWiseDescBuilder()
3794                            .setMode(CUDNN_POINTWISE_ADD)
3795                            .setMathPrecision(cudnn_activation_type)
3796                            .build();
3797 
3798   // If the activation is the identity function, then the bias-add is the last
3799   // op, and it writes to the output, tensor_y.  Otherwise, it writes to the
3800   // "virtual tensor" (temp buffer) tensor_bias, to which we apply the
3801   // activation.
3802   auto& bias_out_desc =
3803       activation_mode == dnn::ActivationMode::kNone ? tensor_y : tensor_bias;
3804   auto& bias_in_desc = need_add_op ? tensor_add : tensor_conv;
3805   auto bias_add_op = cudnn_frontend::OperationBuilder(
3806                          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3807                          .setxDesc(bias_in_desc)
3808                          .setbDesc(tensor_b)
3809                          .setyDesc(bias_out_desc)
3810                          .setpwDesc(bias_add_desc)
3811                          .build();
3812   RETURN_MSG_IF_CUDNN_ERROR(bias_add_op);
3813   ops.push_back(&bias_add_op);
3814 
3815   std::optional<cudnn_frontend::PointWiseDesc_v8> act_desc;
3816   switch (activation_mode) {
3817     case dnn::ActivationMode::kNone:
3818       break;
3819     case dnn::ActivationMode::kRelu:
3820       act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3821                            .setMode(CUDNN_POINTWISE_RELU_FWD)
3822                            .setMathPrecision(cudnn_activation_type)
3823                            .build());
3824       RETURN_MSG_IF_CUDNN_ERROR(*act_desc);
3825 
3826       break;
3827     case dnn::ActivationMode::kRelu6:
3828       act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3829                            .setMode(CUDNN_POINTWISE_RELU_FWD)
3830                            .setReluUpperClip(6.0)
3831                            .setMathPrecision(cudnn_activation_type)
3832                            .build());
3833       RETURN_MSG_IF_CUDNN_ERROR(*act_desc);
3834       break;
3835     case dnn::ActivationMode::kElu:
3836       act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3837                            .setMode(CUDNN_POINTWISE_ELU_FWD)
3838                            .setMathPrecision(cudnn_activation_type)
3839                            .build());
3840       RETURN_MSG_IF_CUDNN_ERROR(*act_desc);
3841       break;
3842     case dnn::ActivationMode::kLeakyRelu:
3843       act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3844                            .setMode(CUDNN_POINTWISE_RELU_FWD)
3845                            .setReluLowerClipSlope(leakyrelu_alpha)
3846                            .setMathPrecision(cudnn_activation_type)
3847                            .build());
3848       RETURN_MSG_IF_CUDNN_ERROR(*act_desc);
3849       break;
3850     default:
3851       return port::InternalError(
3852           absl::StrCat("Unimplemented activation mode ",
3853                        dnn::ActivationModeString(activation_mode)));
3854   }
3855 
3856   std::optional<cudnn_frontend::Operation_v8> act_op;
3857   if (activation_mode != dnn::ActivationMode::kNone) {
3858     act_op.emplace(cudnn_frontend::OperationBuilder(
3859                        CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3860                        .setxDesc(bias_add_op.getOutputTensor())
3861                        .setyDesc(tensor_y)
3862                        .setpwDesc(*act_desc)
3863                        .build());
3864     RETURN_MSG_IF_CUDNN_ERROR(*act_op);
3865     ops.push_back(&*act_op);
3866   }
3867 
3868   auto op_graph = cudnn_frontend::OperationGraphBuilder()
3869                       .setHandle(cudnn.handle())
3870                       .setOperationGraph(ops.size(), ops.data())
3871                       .build();
3872   RETURN_MSG_IF_CUDNN_ERROR(op_graph);
3873 
3874   VLOG(4) << "\nTensor_x: " << tensor_x.describe()
3875           << "\nTensor_y: " << tensor_y.describe()
3876           << "\nTensor_z: " << tensor_z.describe()
3877           << "\nTensor_w: " << tensor_w.describe()
3878           << "\nTensor_b: " << tensor_b.describe()
3879           << "\nTensor_conv: " << tensor_conv.describe()
3880           << "\nTensor_add: " << tensor_add.describe()
3881           << "\nTensor_bias: " << tensor_bias.describe()
3882           << "\nConv: " << conv_desc.describe() << "\nAdd: "
3883           << (add_desc.has_value() ? add_desc->describe() : "(skipped)")
3884           << "\nBiasAdd: " << bias_add_desc.describe()  //
3885           << "\nAct: "
3886           << (act_desc.has_value() ? act_desc->describe() : "(identity)")
3887           << "\nConvOp: " << conv_op.describe() << "\nAddOp: "
3888           << (add_op.has_value() ? add_op->describe() : "(skipped)")
3889           << "\nBiasAddOp: " << bias_add_op.describe()  //
3890           << "\nActOp: "
3891           << (act_op.has_value() ? act_op->describe() : "(identity)")
3892           << "\nOpGraph: " << op_graph.describe();
3893 
3894   return std::unique_ptr<cudnn_frontend::OperationGraph>(
3895       new cudnn_frontend::OperationGraph(std::move(op_graph)));
3896 }
3897 
3898 }  // namespace
3899 
3900 static port::StatusOr<cudnn_frontend::ExecutionPlan> RebuildExecutionPlan(
3901     const CudnnHandle& cudnn, const dnn::AlgorithmDesc& desc,
3902     const cudnn_frontend::OperationGraph& op_graph) {
3903   if (!desc.is_cudnn_frontend()) {
3904     return port::InternalError(
3905         "Got legacy cuDNN algorithm enum in RebuildExecutionPlan.");
3906   }
3907 
3908   // Errors encountered when building a cuDNN operation graph are surfaced in an
3909   // unprecedented and innovative way: they're written into a field of the
3910   // contained engine object, but then clobbered by the object's move
3911   // constructor which makes more cuDNN API calls and encounters further errors.
3912   // The only way to get the actual errors is to peek at them via the returned
3913   // rvalue reference before actually moving the object to finish its
3914   // initialization.
3915   cudnn_frontend::EngineBuilder engine_builder;
3916   engine_builder.setOperationGraph(op_graph).setGlobalEngineIdx(desc.algo_id());
3917   auto&& unmoved = engine_builder.build();
3918   RETURN_MSG_IF_CUDNN_ERROR(unmoved);
3919   cudnn_frontend::Engine engine = std::move(unmoved);
3920   RETURN_MSG_IF_CUDNN_ERROR(engine);
3921 
3922   // Miscellaneous compiler bugs and linker issues conspired to make it
3923   // impossible for AlgorithmDesc to just give us a map initially.  Get the
3924   // vector of tuning knobs and build the map locally.
3925   auto tuning_knobs_vec = desc.TuningKnobs();
3926   absl::flat_hash_map<int64_t, int64_t> tuning_knobs;
3927   tuning_knobs.reserve(tuning_knobs_vec.size());
3928   for (const auto& pair : tuning_knobs_vec) {
3929     tuning_knobs[pair.first] = pair.second;
3930   }
3931 
3932   for (auto& knob : engine.getSupportedKnobs()) {
3933     const auto it = tuning_knobs.find(static_cast<int64_t>(knob.getKnobType()));
3934     if (it != tuning_knobs.end()) {
3935       knob.setChoice(it->second);
3936     }
3937   }
3938 
3939   auto engine_config =
3940       cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
3941   RETURN_MSG_IF_CUDNN_ERROR(engine_config);
3942 
3943   auto plan = cudnn_frontend::ExecutionPlanBuilder()
3944                   .setHandle(cudnn.handle())
3945                   .setEngineConfig(engine_config)
3946                   .build();
3947   RETURN_MSG_IF_CUDNN_ERROR(plan);
3948 
3949   return {std::move(plan)};
3950 }
3951 
3952 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3953 
3954 }  // namespace
3955 
3956 port::Status CudnnSupport::DoPrepareForConvolution(
3957     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3958     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3959     const dnn::FilterDescriptor& filter_descriptor,
3960     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3961     DeviceMemoryBase output_data,
3962     const dnn::ConvolutionDescriptor& convolution_descriptor,
3963     const dnn::AlgorithmConfig& algorithm_config,
3964     ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
3965     DeviceMemory<uint8>* scratch_memory) {
3966   CudnnTensorDescriptor input_nd(
3967       input_descriptor,
3968       ToCudnnDataType(element_type, input_descriptor.layout()));
3969   CudnnFilterDescriptor filter_nd(
3970       filter_descriptor,
3971       ToCudnnDataType(element_type, filter_descriptor.layout()));
3972   CudnnTensorDescriptor output_nd(
3973       output_descriptor,
3974       ToCudnnDataType(element_type, output_descriptor.layout()));
3975 
3976   auto cudnn = cudnn_->GetHandle(parent_, stream);
3977 
3978   switch (kind) {
3979     case dnn::ConvolutionKind::FORWARD: {
3980       TF_ASSIGN_OR_RETURN(*algorithm_desc,
3981                           GetCudnnConvolutionForwardAlgorithm(
3982                               stream, cudnn, algorithm_config, input_nd,
3983                               filter_nd, element_type, convolution_descriptor,
3984                               output_nd, scratch_allocator, scratch_memory));
3985       break;
3986     }
3987     case dnn::ConvolutionKind::BACKWARD_DATA: {
3988       TF_ASSIGN_OR_RETURN(*algorithm_desc,
3989                           GetCudnnConvolutionBackwardDataAlgorithm(
3990                               stream, cudnn, algorithm_config, input_nd,
3991                               filter_nd, element_type, convolution_descriptor,
3992                               output_nd, scratch_allocator, scratch_memory));
3993       break;
3994     }
3995     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3996       TF_ASSIGN_OR_RETURN(*algorithm_desc,
3997                           GetCudnnConvolutionBackwardFilterAlgorithm(
3998                               stream, cudnn, algorithm_config, input_nd,
3999                               filter_nd, element_type, convolution_descriptor,
4000                               output_nd, scratch_allocator, scratch_memory));
4001       break;
4002     }
4003     default:
4004       return port::InternalError(
4005           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
4006   }
4007 
4008   return ::tensorflow::OkStatus();
4009 }
4010 
4011 class CudnnLegacyConvRunner : public dnn::ConvRunner {
4012  public:
4013   // Queries the workspace size and constructs a 'CudnnLegacyConvRunner'.
4014   static port::StatusOr<CudnnLegacyConvRunner> Create(
4015       GpuExecutor* parent, Stream* stream, CudnnAccess* cudnn,
4016       const dnn::AlgorithmDesc& algo, dnn::DataType input_type,
4017       dnn::DataType output_type, dnn::ConvolutionKind kind,
4018       CudnnTensorDescriptor input_nd, CudnnTensorDescriptor output_nd,
4019       CudnnFilterDescriptor filter, CudnnConvolutionDescriptor conv) {
4020     size_t workspace_size;
4021     if (algo.workspace_size()) {
4022       workspace_size = *algo.workspace_size();
4023     } else {
4024       // For old AlgorithmProtos loaded from serialized autotune maps and for
4025       // AlgorithmDescs constructed by manually specifying an algorithm ID, we
4026       // need to compute the workspace size here.
4027       auto handle = cudnn->GetHandle(parent, stream);
4028 
4029       switch (kind) {
4030         case dnn::ConvolutionKind::FORWARD:
4031           RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
4032               handle.handle(),
4033               /*xDesc=*/input_nd.handle(),
4034               /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
4035               /*yDesc=*/output_nd.handle(),
4036               /*algo=*/ToConvForwardAlgo(algo),
4037               /*sizeInBytes=*/&workspace_size));
4038           break;
4039         case dnn::ConvolutionKind::BACKWARD_FILTER:
4040           RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
4041               handle.handle(),
4042               /*xDesc=*/input_nd.handle(),
4043               /*dyDesc=*/output_nd.handle(),
4044               /*convDesc=*/conv.handle(),
4045               /*gradDesc=*/filter.handle(),
4046               /*algo=*/ToConvBackwardFilterAlgo(algo),
4047               /*sizeInBytes=*/&workspace_size));
4048           break;
4049         case dnn::ConvolutionKind::BACKWARD_DATA:
4050           RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
4051               handle.handle(),
4052               /*wDesc=*/filter.handle(),
4053               /*dyDesc=*/output_nd.handle(),
4054               /*convDesc=*/conv.handle(),
4055               /*dxDesc=*/input_nd.handle(),
4056               /*algo=*/ToConvBackwardDataAlgo(algo),
4057               /*sizeInBytes=*/&workspace_size));
4058           break;
4059         default:
4060           return port::InternalError(
4061               "Invalid ConvolutionKind for CudnnLegacyConvRunner.");
4062       }
4063     }
4064 
4065     return {{parent, cudnn, algo.algo_id(), algo.tensor_ops_enabled(),
4066              workspace_size, input_type, output_type, kind, std::move(input_nd),
4067              std::move(output_nd), std::move(filter), std::move(conv)}};
4068   }
4069 
4070   std::string ToString() const override {
4071     return MakeAlgorithmDesc().ToString();
4072   }
4073 
4074   size_t GetWorkspaceSize() const override { return workspace_size_; }
4075 
4076   port::StatusOr<dnn::AlgorithmDesc> ToAlgorithmDesc() const override {
4077     return MakeAlgorithmDesc();
4078   }
4079 
4080   port::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
4081                           DeviceMemoryBase scratch_memory,
4082                           DeviceMemoryBase input_data,
4083                           DeviceMemoryBase filter_data,
4084                           DeviceMemoryBase output_data) const override {
4085     auto algo = MakeAlgorithmDesc();
4086 
4087     // Check that the current stream supports tensor ops if they're requested.
4088     TF_RETURN_IF_ERROR(UseTensorOps(stream, input_type_, algo).status());
4089 
4090     if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
4091         stream->parent()->implementation()) {
4092       return port::InternalError(
4093           "CudnnLegacyConvRunner cached across multiple StreamExecutors.");
4094     }
4095 
4096     auto cudnn = cudnn_->GetHandle(parent_, stream);
4097     // Alpha is the scaling factor for input.
4098     float falpha = 1.0;
4099     double dalpha = 1.0;
4100     void* alpha = input_type_ == dnn::DataType::kDouble
4101                       ? static_cast<void*>(&dalpha)
4102                       : static_cast<void*>(&falpha);
4103     // Beta is the scaling factor for output.
4104     float fbeta = 0.0;
4105     double dbeta = 0.0;
4106     void* beta = input_type_ == dnn::DataType::kDouble
4107                      ? static_cast<void*>(&dbeta)
4108                      : static_cast<void*>(&fbeta);
4109 
4110     const bool is_profiling = profile_result != nullptr;
4111 
4112     std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4113     if (is_profiling) {
4114       timer.reset(new GpuTimer(parent_));  // NOLINT
4115       // The start and stop of the timer should be as close to the cuDNN call
4116       // as possible. Other threads may still issue work onto this stream, so
4117       // multiple profiling measurements may be needed for a stable result.
4118       if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4119         return port::Status(port::error::INTERNAL, "Failed to start timer");
4120       }
4121     }
4122 
4123     const auto get_fwd_bugs = [&]() -> port::Status {
4124 #if CUDNN_VERSION < 8000
4125       if (algo_id_ == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
4126           ToCudnnDataType(input_type_) == CUDNN_DATA_INT8 &&
4127           ToCudnnDataType(output_type_) == CUDNN_DATA_FLOAT) {
4128         return port::Status(
4129             port::error::FAILED_PRECONDITION,
4130             "This configuration potentially produces incorrect results.");
4131       }
4132 #else
4133       (void)output_type_;  // To stop clang-tidy saying it's unused.
4134 #endif
4135       return ::tensorflow::OkStatus();
4136     };
4137 
4138     auto get_bwd_data_bugs = [&]() -> port::Status {
4139       return ::tensorflow::OkStatus();
4140     };
4141 
4142     const auto get_bwd_filter_bugs = [&]() -> port::Status {
4143       return ::tensorflow::OkStatus();
4144     };
4145 
4146     switch (kind_) {
4147       case dnn::ConvolutionKind::FORWARD: {
4148         TF_RETURN_IF_ERROR(get_fwd_bugs());
4149         RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
4150             cudnn.handle(),
4151             /*alpha=*/alpha, /*srcDesc=*/input_nd_.handle(),
4152             /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_.handle(),
4153             /*filterData=*/filter_data.opaque(), /*convDesc=*/conv_.handle(),
4154             /*algo=*/ToConvForwardAlgo(algo),
4155             /*workSpace=*/scratch_memory.opaque(),
4156             /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
4157             /*yDesc=*/output_nd_.handle(), /*y=*/output_data.opaque()));
4158         break;
4159       }
4160       case dnn::ConvolutionKind::BACKWARD_DATA: {
4161         TF_RETURN_IF_ERROR(get_bwd_data_bugs());
4162         RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
4163             cudnn.handle(),
4164             /*alpha=*/alpha,
4165             /*wDesc=*/filter_.handle(),
4166             /*w=*/filter_data.opaque(),
4167             /*dyDesc=*/output_nd_.handle(),
4168             /*dy=*/output_data.opaque(),
4169             /*convDesc=*/conv_.handle(),
4170             /*algo=*/ToConvBackwardDataAlgo(algo),
4171             /*workSpace=*/scratch_memory.opaque(),
4172             /*workSpaceSizeInBytes=*/scratch_memory.size(),
4173             /*beta=*/beta,
4174             /*dxDesc=*/input_nd_.handle(),
4175             /*dx=*/input_data.opaque()));
4176         break;
4177       }
4178       case dnn::ConvolutionKind::BACKWARD_FILTER: {
4179         TF_RETURN_IF_ERROR(get_bwd_filter_bugs());
4180         RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
4181             cudnn.handle(),
4182             /*alpha=*/alpha,
4183             /*srcDesc=*/input_nd_.handle(),
4184             /*srcData=*/input_data.opaque(),
4185             /*diffDesc=*/output_nd_.handle(),
4186             /*diffData=*/output_data.opaque(),
4187             /*convDesc=*/conv_.handle(),
4188             /*algo=*/ToConvBackwardFilterAlgo(algo),
4189             /*workSpace=*/scratch_memory.opaque(),
4190             /*workSpaceSizeInBytes=*/scratch_memory.size(),
4191             /*beta=*/beta,
4192             /*gradDesc=*/filter_.handle(),
4193             /*dw=*/filter_data.opaque()));
4194         break;
4195       }
4196       default:
4197         return port::InternalError(absl::StrCat("Unexpected convolution kind ",
4198                                                 static_cast<int>(kind_)));
4199     }
4200 
4201     if (is_profiling) {
4202       if (!timer->Stop(AsGpuStream(stream))) {
4203         return port::Status(port::error::INTERNAL, "Failed to stop timer");
4204       }
4205       profile_result->set_algorithm(algo);
4206       profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
4207       profile_result->set_scratch_size(scratch_memory.size());
4208     }
4209 
4210     return ::tensorflow::OkStatus();
4211   }
4212 
4213  private:
4214   // Private to prevent passing in the wrong workspace_size.
4215   CudnnLegacyConvRunner(GpuExecutor* parent, CudnnAccess* cudnn,
4216                         int64_t algo_id, bool tensor_ops_enabled,
4217                         size_t workspace_size, dnn::DataType input_type,
4218                         dnn::DataType output_type, dnn::ConvolutionKind kind,
4219                         CudnnTensorDescriptor input_nd,
4220                         CudnnTensorDescriptor output_nd,
4221                         CudnnFilterDescriptor filter,
4222                         CudnnConvolutionDescriptor conv)
4223       : parent_(parent),
4224         cudnn_(cudnn),
4225         algo_id_(algo_id),
4226         tensor_ops_enabled_(tensor_ops_enabled),
4227         workspace_size_(workspace_size),
4228         kind_(kind),
4229         input_type_(input_type),
4230         output_type_(output_type),
4231         input_nd_(std::move(input_nd)),
4232         output_nd_(std::move(output_nd)),
4233         filter_(std::move(filter)),
4234         conv_(std::move(conv)) {}
4235 
4236   // Internal form of ToAlgorithmDesc without the StatusOr.
4237   dnn::AlgorithmDesc MakeAlgorithmDesc() const {
4238     return {algo_id_, tensor_ops_enabled_, workspace_size_};
4239   }
4240 
4241   GpuExecutor* parent_;
4242   CudnnAccess* cudnn_;
4243   int64_t algo_id_;
4244   bool tensor_ops_enabled_;
4245   size_t workspace_size_;
4246   dnn::ConvolutionKind kind_;
4247   dnn::DataType input_type_;
4248   dnn::DataType output_type_;
4249 
4250   CudnnTensorDescriptor input_nd_;
4251   CudnnTensorDescriptor output_nd_;
4252   CudnnFilterDescriptor filter_;
4253   CudnnConvolutionDescriptor conv_;
4254 };
4255 
4256 port::Status CudnnSupport::DoConvolve(
4257     dnn::ConvolutionKind kind, dnn::DataType element_type,
4258     dnn::DataType output_type, Stream* stream,
4259     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
4260     const dnn::FilterDescriptor& filter_descriptor,
4261     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
4262     DeviceMemoryBase output_data,
4263     const dnn::ConvolutionDescriptor& convolution_descriptor,
4264     dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
4265     dnn::ProfileResult* profile_result) {
4266   cudnnDataType_t cudnn_type =
4267       ToCudnnDataType(element_type, input_descriptor.layout());
4268   CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
4269   CudnnTensorDescriptor output_nd(
4270       output_descriptor,
4271       ToCudnnDataType(output_type, output_descriptor.layout()));
4272   CudnnFilterDescriptor filter_nd(
4273       filter_descriptor,
4274       ToCudnnDataType(element_type, filter_descriptor.layout()));
4275 
4276   auto accumulator_type = GetConvAccumulatorType(element_type);
4277   CudnnConvolutionDescriptor conv(convolution_descriptor,
4278                                   ToCudnnDataType(accumulator_type));
4279   TF_ASSIGN_OR_RETURN(bool use_tensor_ops,
4280                       UseTensorOps(stream, element_type, algorithm_desc));
4281   conv.set_use_tensor_op_math(use_tensor_ops);
4282 
4283   TF_ASSIGN_OR_RETURN(
4284       auto runner,
4285       CudnnLegacyConvRunner::Create(
4286           parent_, stream, cudnn_.get(), algorithm_desc, element_type,
4287           output_type, kind, std::move(input_nd), std::move(output_nd),
4288           std::move(filter_nd), std::move(conv)));
4289   return runner(stream, profile_result, scratch_memory, input_data, filter_data,
4290                 output_data);
4291 }
4292 
4293 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4294 
4295 struct BackendDescriptorDeleter {
4296   void operator()(cudnnBackendDescriptor_t desc) {
4297     cudnnBackendDestroyDescriptor(desc);
4298   }
4299 };
4300 
4301 using BackendDescriptor = std::unique_ptr<void, BackendDescriptorDeleter>;
4302 
4303 port::StatusOr<BackendDescriptor> CreateBackendDesc(
4304     cudnnBackendDescriptorType_t type) {
4305   void* result;
4306   RETURN_IF_CUDNN_ERROR(cudnnBackendCreateDescriptor(type, &result));
4307   return BackendDescriptor(result);
4308 }
4309 
4310 // Get the values of a CUDNN_TYPE_BACKEND_DESCRIPTOR attribute as a vector.
4311 //
4312 // This is fetching the entirety of a single sequence-valued attribute, as
4313 // opposed to a sequence of multiple attributes.  The distinction is a bit
4314 // meaningless, but this is the presentation the cuDNN docs use, so it may as
4315 // well be consistent.
4316 port::StatusOr<std::vector<BackendDescriptor>> GetDescriptorAttribute(
4317     cudnnBackendDescriptor_t desc, cudnnBackendAttributeName_t name,
4318     cudnnBackendDescriptorType_t type) {
4319   int64_t n;
4320   RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute(
4321       desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, 0, &n, nullptr));
4322 
4323   std::vector<BackendDescriptor> result(n);
4324   for (int i = 0; i < n; ++i) {
4325     TF_ASSIGN_OR_RETURN(result[i], CreateBackendDesc(type));
4326   }
4327 
4328   std::vector<cudnnBackendDescriptor_t> raw_ptrs;
4329   raw_ptrs.reserve(result.size());
4330   absl::c_transform(result, std::back_inserter(raw_ptrs),
4331                     [](const BackendDescriptor& ptr) { return ptr.get(); });
4332 
4333   // This API evidently does a deep copy of the descriptors into the pointers in
4334   // the output array, rather than writing pointers to the descriptors into the
4335   // output array.  So it populates the memory behind each BackendDescriptor
4336   // in `result`, rather than overwriting the contents of raw_ptrs.
4337   RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute(
4338       desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, n, &n, raw_ptrs.data()));
4339 
4340   return result;
4341 }
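// For instance, ExecutionPlanToAlgorithmDesc below uses this to pull the
// CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG descriptors out of an execution plan
// and then the CUDNN_ATTR_ENGINECFG_ENGINE descriptor out of that config.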
4342 
4343 // Extract the engine ID and tuning knobs from the ExecutionPlan, and return
4344 // them in the form of an AlgorithmDesc for use with RebuildExecutionPlan.
4345 port::StatusOr<dnn::AlgorithmDesc> ExecutionPlanToAlgorithmDesc(
4346     const cudnn_frontend::ExecutionPlan& plan, size_t workspace_size) {
4347   TF_ASSIGN_OR_RETURN(
4348       auto engine_cfgs,
4349       GetDescriptorAttribute(plan.get_raw_desc(),
4350                              CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG,
4351                              CUDNN_BACKEND_ENGINECFG_DESCRIPTOR));
4352   if (engine_cfgs.size() != 1) {
4353     return port::InternalError(
4354         "CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG had more than one element.");
4355   }
4356 
4357   TF_ASSIGN_OR_RETURN(
4358       auto engines,
4359       GetDescriptorAttribute(engine_cfgs[0].get(), CUDNN_ATTR_ENGINECFG_ENGINE,
4360                              CUDNN_BACKEND_ENGINE_DESCRIPTOR));
4361   if (engines.size() != 1) {
4362     return port::InternalError(
4363         "CUDNN_ATTR_ENGINECFG_ENGINE had more than one element.");
4364   }
4365 
4366   int64_t n;
4367   int64_t engine_id;
4368   RETURN_IF_CUDNN_ERROR(
4369       cudnnBackendGetAttribute(engines[0].get(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX,
4370                                CUDNN_TYPE_INT64, 1, &n, &engine_id));
4371 
4372   // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query the
4373   // number of elements in the attribute by using an output limit value of 0
4374   // just returns 0; the only way to find out how many there are is to
4375   // pre-allocate space for every existing knob type (as an upper bound on the
4376   // number of knob choices a config can have), and then look back at how many
4377   // were filled.
4378   std::vector<BackendDescriptor> knobs(CUDNN_KNOB_TYPE_COUNTS);
4379   for (int i = 0; i < knobs.size(); ++i) {
4380     TF_ASSIGN_OR_RETURN(
4381         knobs[i], CreateBackendDesc(CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR));
4382   }
4383   std::vector<cudnnBackendDescriptor_t> raw_knob_ptrs;
4384   raw_knob_ptrs.reserve(knobs.size());
4385   absl::c_transform(knobs, std::back_inserter(raw_knob_ptrs),
4386                     [](const BackendDescriptor& ptr) { return ptr.get(); });
4387   RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute(
4388       engine_cfgs[0].get(), CUDNN_ATTR_ENGINECFG_KNOB_CHOICES,
4389       CUDNN_TYPE_BACKEND_DESCRIPTOR, raw_knob_ptrs.size(), &n,
4390       raw_knob_ptrs.data()));
4391   knobs.resize(n);
4392 
4393   absl::flat_hash_map<int64_t, int64_t> tuning_knobs;
4394   for (const auto& knob : knobs) {
4395     cudnnBackendKnobType_t knob_type;
4396     int64_t knob_value;
4397 
4398     RETURN_IF_CUDNN_ERROR(
4399         cudnnBackendGetAttribute(knob.get(), CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE,
4400                                  CUDNN_TYPE_KNOB_TYPE, 1, &n, &knob_type));
4401     if (n != 1) {
4402       return port::InternalError(
4403           absl::StrCat("Knob should have exactly one KNOB_TYPE; had ", n));
4404     }
4405 
4406     RETURN_IF_CUDNN_ERROR(
4407         cudnnBackendGetAttribute(knob.get(), CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE,
4408                                  CUDNN_TYPE_INT64, 1, &n, &knob_value));
4409     if (n != 1) {
4410       return port::InternalError(
4411           absl::StrCat("Knob should have exactly one KNOB_VALUE; had ", n));
4412     }
4413 
4414     auto emplaced = tuning_knobs.try_emplace(knob_type, knob_value).second;
4415     if (!emplaced) {
4416       return port::InternalError(absl::StrFormat(
4417           "cuDNN gave multiple knob values for the same knob type.\n"
4418           "  KNOB_TYPE: %d\n"
4419           "  new KNOB_VALUE: %d\n"
4420           "  old KNOB_VALUE: %d",
4421           knob_type, knob_value, tuning_knobs.at(knob_type)));
4422     }
4423   }
4424 
4425   std::vector<std::pair<int64_t, int64_t>> tuning_knobs_vec;
4426   tuning_knobs_vec.reserve(tuning_knobs.size());
4427   absl::c_copy(tuning_knobs, std::back_inserter(tuning_knobs_vec));
4428 
4429   return dnn::AlgorithmDesc(engine_id, tuning_knobs_vec, workspace_size);
4430 }
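// For example (hypothetical values, added for illustration): a plan whose
// engine has global index 7 and a single tuning knob {type=2, value=3} maps to
// dnn::AlgorithmDesc(/*engine_id=*/7, {{2, 3}}, workspace_size). A caller
// (e.g. autotuning) can record that tuple and later hand it to
// RebuildExecutionPlan (see ConvolveRunnerFromDesc below) to reconstruct an
// equivalent plan.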
4431 
4432 template <typename Sig>
4433 class CudnnExecutionPlanRunner;
4434 // An OpRunner implemented by an ExecutionPlan.
4435 //
4436 // This is the class holding the implementation of ToString, GetWorkspaceSize,
4437 // and operator() for use by the cudnn frontend op runners.
4438 template <typename... Args>
4439 class CudnnExecutionPlanRunner<void(Args...)>
4440     : public dnn::OpRunner<void(Args...)> {
4441  public:
4442   std::string ToString() const override { return plan_.getTag(); }
4443 
4444   size_t GetWorkspaceSize() const override { return workspace_size_; }
4445 
4446   port::StatusOr<dnn::AlgorithmDesc> ToAlgorithmDesc() const override {
4447     return ExecutionPlanToAlgorithmDesc(plan_, workspace_size_);
4448   }
4449 
4450   port::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
4451                           DeviceMemoryBase scratch_memory,
4452                           Args... inputs) const override {
4453     if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
4454         stream->parent()->implementation()) {
4455       return port::InternalError(
4456           "CudnnExecutionPlanRunner cached across multiple StreamExecutors.");
4457     }
4458 
4459     auto cudnn = cudnn_->GetHandle(parent_, stream);
4460 
4461     size_t workspace_size = plan_.getWorkspaceSize();
4462     RETURN_MSG_IF_CUDNN_ERROR(plan_);
4463 
4464     std::array<void*, sizeof...(Args)> data_ptrs = {inputs.opaque()...};
4465 
4466     absl::InlinedVector<int64_t, sizeof...(Args)> data_uids_vec(
4467         data_uids_.cbegin(), data_uids_.cend());
4468     absl::InlinedVector<void*, sizeof...(Args)> data_ptrs_vec(
4469         data_ptrs.cbegin(), data_ptrs.cend());
4470     // We use need_side_input_ to determine whether the side input 'z' from
4471     // {'x', 'w', 'z', 'b', 'y'} is needed by the conv-<add>-bias-act patterns.
4472     if (sizeof...(Args) == 5 && !need_side_input_) {
4473       data_uids_vec.erase(data_uids_vec.begin() + 2);
4474       data_ptrs_vec.erase(data_ptrs_vec.begin() + 2);
4475     }
4476 
4477     auto variantPack =
4478         cudnn_frontend::VariantPackBuilder()
4479             .setWorkspacePointer(scratch_memory.opaque())
4480             .setDataPointers(data_ptrs_vec.size(), data_ptrs_vec.data())
4481             .setUids(data_uids_vec.size(), data_uids_vec.data())
4482             .build();
4483     RETURN_MSG_IF_CUDNN_ERROR(variantPack);
4484 
4485     VLOG(4) << "\nDo cudnn execution plan with plan tag: " << plan_.getTag()
4486             << "\nWorkspace size in bytes: " << workspace_size
4487             << "\nVariantPack: " << variantPack.describe();
4488 
4489     const bool is_profiling = profile_result != nullptr;
4490 
4491     std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4492     if (is_profiling) {
4493       timer.reset(new GpuTimer(parent_));  // NOLINT
4494       // The start and stop of the timer should be as close to the Cudnn call as
4495       // possible. It is still possible for other threads to issue work onto
4496       // this stream, so a clean measurement may require multiple profiling runs.
4497       if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4498         return port::Status(port::error::INTERNAL, "Failed to start timer");
4499       }
4500     }
4501 
4502     cudnnStatus_t status = cudnnBackendExecute(
4503         cudnn.handle(), plan_.get_raw_desc(), variantPack.get_raw_desc());
4504     RETURN_IF_CUDNN_ERROR(status);
4505 
4506     if (is_profiling) {
4507       if (!timer->Stop(AsGpuStream(stream))) {
4508         return port::Status(port::error::INTERNAL, "Failed to stop timer");
4509       }
4510       TF_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc());
4511       profile_result->set_algorithm(desc);
4512       profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
4513       profile_result->set_scratch_size(scratch_memory.size());
4514 
4515       VLOG(4) << "cudnn op with plan " << plan_.getTag()
4516               << ", workspace_size=" << workspace_size << " -> "
4517               << CudnnStatusToString(status) << " in "
4518               << timer->GetElapsedMilliseconds() << "ms";
4519     }
4520 
4521     return port::Status::OK();
4522   }
4523 
4524   static port::StatusOr<CudnnExecutionPlanRunner> Create(
4525       GpuExecutor* parent, CudnnAccess* cudnn,
4526       cudnn_frontend::ExecutionPlan plan, absl::Span<const int64_t> uids,
4527       bool need_side_input) {
4528     auto workspace_size = static_cast<uint64_t>(plan.getWorkspaceSize());
4529     RETURN_MSG_IF_CUDNN_ERROR(plan);
4530     return {{parent, cudnn, std::move(plan), workspace_size, uids,
4531              need_side_input}};
4532   }
4533 
4534  private:
4535   CudnnExecutionPlanRunner(GpuExecutor* parent, CudnnAccess* cudnn,
4536                            cudnn_frontend::ExecutionPlan plan,
4537                            size_t workspace_size,
4538                            absl::Span<const int64_t> uids, bool need_side_input)
4539       : parent_(parent),
4540         cudnn_(cudnn),
4541         plan_(std::move(plan)),
4542         workspace_size_(workspace_size),
4543         data_uids_(uids.begin(), uids.end()),
4544         need_side_input_(need_side_input) {}
4545   GpuExecutor* parent_;
4546   CudnnAccess* cudnn_;
4547   cudnn_frontend::ExecutionPlan plan_;
4548   size_t workspace_size_;
4549   absl::InlinedVector<int64_t, sizeof...(Args)> data_uids_;
4550   bool need_side_input_;
4551 };
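// Illustrative use of the runner above (a minimal sketch added for exposition;
// 'parent', 'cudnn', 'plan', 'stream', 'scratch', and the x/w/y buffers are
// hypothetical):
//
//   TF_ASSIGN_OR_RETURN(auto runner,
//                       CudnnExecutionPlanRunner<dnn::ConvSignature>::Create(
//                           parent, cudnn, std::move(plan), {'x', 'w', 'y'},
//                           /*need_side_input=*/false));
//   // Allocate at least runner.GetWorkspaceSize() bytes of scratch, then:
//   TF_RETURN_IF_ERROR(runner(stream, /*profile_result=*/nullptr, scratch,
//                             x_mem, w_mem, y_mem));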
4552 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4553 
4554 // Utility for dealing with CUDA's type-erased scaling parameters, where some
4555 // sets of parameters expect a void* pointing at a float while others expect it
4556 // to point at a double.
4557 //
4558 // This is rather ugly, but its purpose is to quarantine the corresponding
4559 // ugliness that already exists in the CUDA API.
4560 class ScalingParam {
4561  public:
4562   explicit ScalingParam(double value) : as_double_(value), as_float_(value) {}
4563 
4564   // Return a pointer to the appropriate representation type for the given
4565   // element type.
4566   //
4567   // See
4568   // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
4569   // for more info; the behavior for int8 result tensors is not described there,
4570   // but is maintained from the existing behavior (namely, using a float scaling
4571   // parameter).
4572   void* ToVoidPointer(dnn::DataType element_type) {
4573     if (element_type == dnn::DataType::kDouble) {
4574       return &as_double_;
4575     } else {
4576       return &as_float_;
4577     }
4578   }
4579 
4580  private:
4581   double as_double_;
4582   float as_float_;
4583 };
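// Illustrative use (an editorial sketch; 'alpha' is hypothetical): the object
// must stay alive for the duration of the cuDNN call, because ToVoidPointer
// returns a pointer into the ScalingParam itself.
//
//   ScalingParam alpha(1.0);
//   void* p = alpha.ToVoidPointer(dnn::DataType::kFloat);   // -> float 1.0f
//   void* q = alpha.ToVoidPointer(dnn::DataType::kDouble);  // -> double 1.0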
4584 
4585 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4586 namespace {
4587 
4588 template <typename Sig>
4589 port::Status CreateOpRunners(
4590     Stream* stream, CudnnHandle& cudnn, GpuExecutor* gpu_executor,
4591     CudnnAccess* cudnn_access,
4592     std::unique_ptr<cudnn_frontend::OperationGraph> op_graph,
4593     dnn::ConvolutionKind kind, dnn::DataType input_type,
4594     absl::Span<const int64_t> input_uids, bool use_fallback,
4595     std::vector<std::unique_ptr<const dnn::OpRunner<Sig>>>* out_runners,
4596     bool need_side_input) {
4597   cudnn_frontend::EngineConfigList filtered_configs;
4598   auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
4599     return GenericEngineFilter(
4600         engine_config,
4601         /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
4602         /*disable_nondeterminism*/ RequireCudnnDeterminism(),
4603         /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
4604   };
4605 
4606   if (!use_fallback) {
4607     // In theory, the heuristics mode returned by GetCudnnFrontendHeurMode()
4608     // is supposed to fall back to HEUR_MODE_INSTANT if it can't find any
4609     // working engines, but a known cuDNN issue prevents that for runtime
4610     // compiled fusion engines, so we do the fallback manually here.
4611     // TODO(kaixih@nvidia): remove this once cuDNN fixes it.
4612     cudnn_frontend::EngineHeuristics heuristics = [&] {
4613       auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
4614                             .setOperationGraph(*op_graph)
4615                             .setHeurMode(GetCudnnFrontendHeurMode())
4616                             .build();
4617       if (heuristics.get_status() == CUDNN_STATUS_SUCCESS) return heuristics;
4618       return cudnn_frontend::EngineHeuristicsBuilder()
4619           .setOperationGraph(*op_graph)
4620           .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
4621           .build();
4622     }();
4623     RETURN_MSG_IF_CUDNN_ERROR(heuristics);
4624 
4625     // cuDNN frontend sneakily puts error messages on the object and returns
4626     // partially-initialized results when there's an error; make sure to check
4627     // them.
4628     int64_t engine_count = heuristics.getEngineConfigCount();
4629     RETURN_MSG_IF_CUDNN_ERROR(heuristics);
4630     auto& heuristics_configs = heuristics.getEngineConfig(engine_count);
4631     RETURN_MSG_IF_CUDNN_ERROR(heuristics);
4632     VLOG(4) << "\nHeuristics engine configs size: "
4633             << heuristics_configs.size();
4634 
4635     cudnn_frontend::filter(heuristics_configs, filtered_configs,
4636                            generic_filter_fn);
4637   } else {
4638     auto fallback = cudnn_frontend::EngineFallbackListBuilder()
4639                         .setOperationGraph(*op_graph)
4640                         .setOperation(GetCudnnConvolutionType(kind))
4641                         .build();
4642     RETURN_MSG_IF_CUDNN_ERROR(fallback);
4643 
4644     auto& fallback_configs = fallback.getFallbackList();
4645     VLOG(4) << "\nFallback engine configs size: " << fallback_configs.size();
4646 
4647     cudnn_frontend::filter(fallback_configs, filtered_configs,
4648                            generic_filter_fn);
4649   }
4650   VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
4651 
4652   auto fn = []() { return true; };
4653   auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
4654   auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
4655 
4656   out_runners->clear();
4657   for (int i = 0; i < filtered_configs.size(); i++) {
4658     auto plan = cudnn_frontend::ExecutionPlanBuilder()
4659                     .setHandle(cudnn.handle())
4660                     .setEngineConfig(filtered_configs[i], op_graph->getTag())
4661                     .build();
4662     if (plan.get_status() != CUDNN_STATUS_SUCCESS) {
4663       continue;
4664     }
4665 
4666     if (maybe_json_handle_static &&
4667         cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
4668                                      cudnn.handle(), fn)) {
4669       VLOG(4) << "Exclude engine (static): " << plan.getTag();
4670       continue;
4671     }
4672     if (maybe_json_handle_runtime &&
4673         cudnn_frontend::check_errata(*maybe_json_handle_runtime, plan.getTag(),
4674                                      cudnn.handle(), fn)) {
4675       VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
4676       continue;
4677     }
4678 
4679     auto runner_or = CudnnExecutionPlanRunner<Sig>::Create(
4680         gpu_executor, cudnn_access, std::move(plan), input_uids,
4681         need_side_input);
4682     if (!runner_or.ok()) {
4683       // Note this can happen if cuDNN Frontend gives us partially-initialized
4684       // ExecutionPlans because its error handling is broken in non-exception
4685       // builds; those were meant to be filtered out earlier inside cuDNN
4686       // Frontend, but instead they get filtered out here.
4687       VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed "
4688                  "getting its workspace size): "
4689               << runner_or.status().ToString();
4690       continue;
4691     }
4692 
4693     out_runners->push_back(std::make_unique<CudnnExecutionPlanRunner<Sig>>(
4694         std::move(runner_or).value()));
4695 
4696     // We will use the first working plan when determinism is required.
4697     if (RequireCudnnDeterminism()) {
4698       break;
4699     }
4700   }
4701 
4702   VLOG(4) << "\nReturned execution plans size: " << out_runners->size();
4703 
4704   return port::Status::OK();
4705 }
4706 
4707 }  // namespace
4708 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4709 
4710 port::Status CudnnSupport::GetConvolveRunners(
4711     bool use_cudnn_frontend, dnn::ConvolutionKind kind,
4712     dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
4713     const dnn::BatchDescriptor& input_descriptor,
4714     DeviceMemoryBase /*input_data*/,
4715     const dnn::FilterDescriptor& filter_descriptor,
4716     DeviceMemoryBase /*filter_data*/,
4717     const dnn::BatchDescriptor& output_descriptor,
4718     DeviceMemoryBase /*output_data*/,
4719     const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback,
4720     ScratchAllocator* /*scratch_allocator*/,
4721     std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans) {
4722   // All current versions of the frontend API lack support for Tx32
4723   // convolutions.
4724   const bool is_unsupported_x32 =
4725       input_descriptor.layout() == dnn::kBatchDepthYX32;
4726 
4727   // cuDNN frontend support became sufficiently stable to use in 8.1.
4728   // TODO(awpr): remove this condition once support for cuDNN 8.0 is dropped.
4729   const bool is_pre_frontend_cudnn = CUDNN_VERSION < 8100;
4730 
4731   const bool actually_use_cudnn_frontend =
4732       use_cudnn_frontend && !is_pre_frontend_cudnn && !is_unsupported_x32;
4733 
4734   if (use_cudnn_frontend && !actually_use_cudnn_frontend) {
4735     // This will happen once per unique conv configuration/shape that gets
4736     // affected (and not, for example, on every conv launch).  Confusion over
4737     // whether this has happened or not has repeatedly wasted a lot of time
4738     // debugging, so be sure it shows up in the logs.
4739     LOG(INFO) << "Disabling cuDNN frontend for the following convolution:\n"
4740               << "  input: " << input_descriptor.ToString() << "\n"
4741               << "  filter: " << filter_descriptor.ToString() << "\n"
4742               << "  " << convolution_descriptor.ToString() << "\n"
4743               << "  ... because "
4744               << (is_unsupported_x32
4745                       ? "Tx32 convolutions are unsupported."
4746                       : "the current cuDNN version does not support it.");
4747   }
4748 
4749   if (!actually_use_cudnn_frontend) {
4750     auto cuda_compute_capability = stream->GetCudaComputeCapability();
4751     std::vector<dnn::AlgorithmDesc> algorithms;
4752     bool got_algos = false;
4753     switch (kind) {
4754       default:
4755         return port::InternalError(absl::StrFormat(
4756             "Unknown ConvolutionKind for unfused conv: %d", kind));
4757       case dnn::ConvolutionKind::FORWARD:
4758         got_algos = GetConvolveAlgorithms(cuda_compute_capability, &algorithms);
4759         break;
4760       case dnn::ConvolutionKind::BACKWARD_FILTER:
4761         got_algos = GetConvolveBackwardFilterAlgorithms(cuda_compute_capability,
4762                                                         &algorithms);
4763         break;
4764       case dnn::ConvolutionKind::BACKWARD_DATA:
4765         got_algos = GetConvolveBackwardDataAlgorithms(cuda_compute_capability,
4766                                                       &algorithms);
4767         break;
4768     }
4769     if (!got_algos) {
4770       return port::Status(
4771           port::error::UNKNOWN,
4772           absl::StrFormat("Listing algorithms failed for kind %d", kind));
4773     }
4774 
4775     for (const auto& algo : algorithms) {
4776       auto runner_or = ConvolveRunnerFromDesc(
4777           stream, algo, kind, input_type, output_type, input_descriptor,
4778           filter_descriptor, output_descriptor, convolution_descriptor);
4779       if (!runner_or.ok()) {
4780         // Failures here can result from trying to query the workspace size for
4781         // algorithms that aren't supported for the present configuration.  This
4782         // means we'll now return only supported algorithms, unlike the
4783         // predecessor 'GetConvolveAlgorithms', which returned all existing
4784         // algorithms regardless of any particular configuration.
4785         //
4786         // TODO(awpr): can we arrange for the expected errors here to have a
4787         // particular error code (e.g. UNIMPLEMENTED or INVALID_ARGUMENT) and
4788         // log errors for anything unexpected?
4789         continue;
4790       }
4791       out_exec_plans->push_back(std::move(runner_or).value());
4792     }
4793 
4794     return ::tensorflow::OkStatus();
4795   }
4796 
4797 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4798   auto cudnn = cudnn_->GetHandle(parent_, stream);
4799   TF_ASSIGN_OR_RETURN(
4800       auto op_graph,
4801       GetCudnnOperationGraph(kind, input_type, output_type, input_descriptor,
4802                              filter_descriptor, output_descriptor,
4803                              convolution_descriptor, cudnn));
4804 
4805   return CreateOpRunners<dnn::ConvSignature>(
4806       stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind,
4807       input_type, {'x', 'w', 'y'}, use_fallback, out_exec_plans,
4808       /*need_side_input=*/false);
4809 #else
4810   return port::UnimplementedError(
4811       "Cudnn execution plans are only supported with Cudnn >= 8.1.");
4812 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4813 }
4814 
4815 port::StatusOr<std::unique_ptr<const dnn::ConvRunner>>
4816 CudnnSupport::ConvolveRunnerFromDesc(
4817     Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
4818     dnn::ConvolutionKind kind, dnn::DataType input_type,
4819     dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
4820     const dnn::FilterDescriptor& filter_descriptor,
4821     const dnn::BatchDescriptor& output_descriptor,
4822     const dnn::ConvolutionDescriptor& convolution_descriptor) {
4823   if (!algorithm_desc.is_cudnn_frontend()) {
4824     CudnnConvolutionDescriptor conv(
4825         convolution_descriptor,
4826         ToCudnnDataType(GetConvAccumulatorType(input_type)));
4827     conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
4828 
4829     TF_ASSIGN_OR_RETURN(
4830         auto runner,
4831         CudnnLegacyConvRunner::Create(
4832             parent_, stream, cudnn_.get(), algorithm_desc, input_type,
4833             output_type, kind,
4834             /* input_nd = */
4835             CudnnTensorDescriptor(
4836                 input_descriptor,
4837                 ToCudnnDataType(input_type, input_descriptor.layout())),
4838             /* output_nd = */
4839             CudnnTensorDescriptor(
4840                 output_descriptor,
4841                 ToCudnnDataType(output_type, output_descriptor.layout())),
4842             /* filter = */
4843             CudnnFilterDescriptor(
4844                 filter_descriptor,
4845                 ToCudnnDataType(input_type, filter_descriptor.layout())),
4846             std::move(conv)));
4847 
4848     return {std::make_unique<CudnnLegacyConvRunner>(std::move(runner))};
4849   }
4850 
4851 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4852   auto cudnn = cudnn_->GetHandle(parent_, stream);
4853 
4854   TF_ASSIGN_OR_RETURN(
4855       auto op_graph,
4856       GetCudnnOperationGraph(kind, input_type, output_type, input_descriptor,
4857                              filter_descriptor, output_descriptor,
4858                              convolution_descriptor, cudnn));
4859 
4860   TF_ASSIGN_OR_RETURN(auto execution_plan,
4861                       RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph));
4862 
4863   TF_ASSIGN_OR_RETURN(
4864       auto runner,
4865       CudnnExecutionPlanRunner<dnn::ConvSignature>::Create(
4866           parent_, cudnn_.get(), std::move(execution_plan), {'x', 'w', 'y'},
4867           /*need_side_input=*/false));
4868   return {std::make_unique<CudnnExecutionPlanRunner<dnn::ConvSignature>>(
4869       std::move(runner))};
4870 #else
4871   return port::UnimplementedError(
4872       "Cudnn execution plans are only supported with Cudnn >= 8.1.");
4873 #endif
4874 }
4875 
4876 class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner {
4877  public:
4878   // Queries the workspace size and constructs a 'CudnnLegacyFusedConvRunner'.
4879   static port::StatusOr<CudnnLegacyFusedConvRunner> Create(
4880       GpuExecutor* parent, Stream* stream, CudnnAccess* cudnn,
4881       const dnn::AlgorithmDesc& algo, dnn::DataType input_type,
4882       double conv_scale, double side_input_scale,
4883       CudnnTensorDescriptor input_nd, CudnnTensorDescriptor output_nd,
4884       CudnnFilterDescriptor filter, CudnnTensorDescriptor bias_nd,
4885       CudnnConvolutionDescriptor conv,
4886       CudnnActivationDescriptor activation_desc) {
4887     size_t workspace_size;
4888     if (algo.workspace_size()) {
4889       workspace_size = *algo.workspace_size();
4890     } else {
4891       auto handle = cudnn->GetHandle(parent, stream);
4892 
4893       RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
4894           handle.handle(),
4895           /*xDesc=*/input_nd.handle(),
4896           /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
4897           /*yDesc=*/output_nd.handle(),
4898           /*algo=*/ToConvForwardAlgo(algo),
4899           /*sizeInBytes=*/&workspace_size));
4900     }
4901 
4902     return {{parent, cudnn, algo.algo_id(), algo.tensor_ops_enabled(),
4903              workspace_size, input_type, conv_scale, side_input_scale,
4904              std::move(input_nd), std::move(output_nd), std::move(filter),
4905              std::move(bias_nd), std::move(conv), std::move(activation_desc)}};
4906   }
4907 
4908   std::string ToString() const override {
4909     return MakeAlgorithmDesc().ToString();
4910   }
4911 
4912   uint64_t GetWorkspaceSize() const override { return workspace_size_; }
4913 
4914   port::StatusOr<dnn::AlgorithmDesc> ToAlgorithmDesc() const override {
4915     return MakeAlgorithmDesc();
4916   }
4917 
4918   port::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
4919                           DeviceMemoryBase scratch_memory,
4920                           DeviceMemoryBase input_data,
4921                           DeviceMemoryBase filter_data,
4922                           DeviceMemoryBase side_input_data,
4923                           DeviceMemoryBase bias_data,
4924                           DeviceMemoryBase output_data) const override {
4925     if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
4926         stream->parent()->implementation()) {
4927       return port::InternalError(
4928           "CudnnLegacyFusedConvRunner cached across multiple StreamExecutors.");
4929     }
4930 
4931     auto algo = MakeAlgorithmDesc();
4932 
4933     std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4934     if (profile_result) {
4935       timer.reset(new GpuTimer(parent_));  // NOLINT
4936       // The start and stop of the timer should be as close to the Cudnn call as
4937       // possible. It is still possible for other threads to issue work onto
4938       // this stream, so a clean measurement may require multiple profiling runs.
4939       if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4940         return port::Status(port::error::INTERNAL, "Failed to start timer");
4941       }
4942     }
4943     auto side_input_data_ptr = (side_input_scale_ == 0)
4944                                    ? output_data.opaque()
4945                                    : side_input_data.opaque();
4946 
4947     auto cudnn = cudnn_->GetHandle(parent_, stream);
4948 
4949     VLOG(2) << "\nconv_scale = " << conv_scale_
4950             << "\nconv_input_nd.handle() = " << input_nd_.handle()
4951             << "\nconv_input_data.opaque() = " << input_data.opaque()
4952             << "\nfilter.handle() = " << filter_.handle()
4953             << "\nfilter_data.opaque() = " << filter_data.opaque()
4954             << "\nconv.handle() = " << conv_.handle() << "\nalgo = " << algo_id_
4955             << ", tensor_ops_enabled=" << tensor_ops_enabled_
4956             << "\nscratch.opaque() = " << scratch_memory.opaque()
4957             << "\nscratch.size() = " << scratch_memory.size()
4958             << "\nside_input_scale = " << side_input_scale_
4959             << "\noutput_nd.handle() = " << output_nd_.handle()
4960             << "\nside_input_data_ptr = " << side_input_data_ptr
4961             << "\nbias_nd.handle() = " << bias_nd_.handle()
4962             << "\nbiases.opaque() = " << bias_data.opaque()
4963             << "\nactivation_desc.handle() = " << activation_desc_.handle()
4964             << "\noutput_nd.handle() = " << output_nd_.handle()
4965             << "\noutput_data.opaque() = " << output_data.opaque();
4966 
4967     if (IsTensorMathOpSet(conv_) != tensor_ops_enabled_) {
4968       return port::Status(port::error::FAILED_PRECONDITION,
4969                           "Tensor op math type in dnn::AlgorithmDesc does not "
4970                           "match that of the CudnnConvolutionDescriptor");
4971     }
4972 
4973     // N.B. the scaling parameters alpha1 and alpha2 are pointers to
4974     // temporaries; this API doesn't persist the pointers beyond its own stack
4975     // frame.
4976     auto status = cudnnConvolutionBiasActivationForward(
4977         cudnn.handle(),
4978         /*alpha1=*/ScalingParam(conv_scale_).ToVoidPointer(input_type_),
4979         /*xDesc=*/input_nd_.handle(), /*x=*/input_data.opaque(),
4980         /*wDesc=*/filter_.handle(), /*w=*/filter_data.opaque(),
4981         /*convDesc=*/conv_.handle(), ToConvForwardAlgo(algo),
4982         /*workSpace=*/scratch_memory.opaque(),
4983         /*workSpaceSizeInBytes=*/scratch_memory.size(),
4984         /*alpha2=*/ScalingParam(side_input_scale_).ToVoidPointer(input_type_),
4985         /*zDesc=*/output_nd_.handle(), /*z=*/side_input_data_ptr,
4986         /*biasDesc=*/bias_nd_.handle(), /*bias=*/bias_data.opaque(),
4987         /*activationDesc=*/activation_desc_.handle(),
4988         /*yDesc=*/output_nd_.handle(), /*y=*/output_data.opaque());
4989     if (status != CUDNN_STATUS_SUCCESS || !profile_result) {
4990       VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo)
4991               << ", workspace_size=" << scratch_memory.size() << " -> "
4992               << CudnnStatusToString(status);
4993     }
4994     RETURN_IF_CUDNN_ERROR(status);
4995 
4996     if (profile_result) {
4997       if (!timer->Stop(AsGpuStream(stream))) {
4998         return port::Status(port::error::INTERNAL, "Failed to stop timer");
4999       }
5000       profile_result->set_algorithm(algo);
5001       profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
5002       profile_result->set_scratch_size(scratch_memory.size());
5003       VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo)
5004               << ", tensor_ops_enabled=" << tensor_ops_enabled_
5005               << ", workspace_size=" << scratch_memory.size() << " -> "
5006               << CudnnStatusToString(status) << " in "
5007               << timer->GetElapsedMilliseconds() << "ms";
5008     }
5009 
5010     return ::tensorflow::OkStatus();
5011   }
5012 
5013  private:
5014   // Private to prevent passing in the wrong workspace_size.
5015   CudnnLegacyFusedConvRunner(GpuExecutor* parent, CudnnAccess* cudnn,
5016                              int64_t algo_id, bool tensor_ops_enabled,
5017                              size_t workspace_size, dnn::DataType input_type,
5018                              double conv_scale, double side_input_scale,
5019                              CudnnTensorDescriptor input_nd,
5020                              CudnnTensorDescriptor output_nd,
5021                              CudnnFilterDescriptor filter,
5022                              CudnnTensorDescriptor bias_nd,
5023                              CudnnConvolutionDescriptor conv,
5024                              CudnnActivationDescriptor activation_desc)
5025       : parent_(parent),
5026         cudnn_(cudnn),
5027         algo_id_(algo_id),
5028         tensor_ops_enabled_(tensor_ops_enabled),
5029         workspace_size_(workspace_size),
5030         input_type_(input_type),
5031         conv_scale_(conv_scale),
5032         side_input_scale_(side_input_scale),
5033         input_nd_(std::move(input_nd)),
5034         output_nd_(std::move(output_nd)),
5035         filter_(std::move(filter)),
5036         bias_nd_(std::move(bias_nd)),
5037         conv_(std::move(conv)),
5038         activation_desc_(std::move(activation_desc)) {}
5039 
5040   // Internal form of ToAlgorithmDesc without the StatusOr.
5041   dnn::AlgorithmDesc MakeAlgorithmDesc() const {
5042     return {algo_id_, tensor_ops_enabled_, workspace_size_};
5043   }
5044 
5045   GpuExecutor* parent_;
5046   CudnnAccess* cudnn_;
5047   int64_t algo_id_;
5048   bool tensor_ops_enabled_;
5049   size_t workspace_size_;
5050   dnn::DataType input_type_;
5051   double conv_scale_, side_input_scale_;
5052 
5053   CudnnTensorDescriptor input_nd_;
5054   CudnnTensorDescriptor output_nd_;
5055   CudnnFilterDescriptor filter_;
5056   CudnnTensorDescriptor bias_nd_;
5057   CudnnConvolutionDescriptor conv_;
5058   CudnnActivationDescriptor activation_desc_;
5059 };
5060 
5061 port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
5062 CudnnSupport::FusedConvolveRunnerFromDesc(
5063     Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
5064     dnn::ConvolutionKind kind, dnn::DataType input_type,
5065     dnn::DataType bias_type, dnn::DataType output_type, double conv_scale,
5066     double side_input_scale, double leakyrelu_alpha,
5067     const dnn::BatchDescriptor& input_descriptor,
5068     const dnn::FilterDescriptor& filter_descriptor,
5069     const dnn::BatchDescriptor& bias_descriptor,
5070     const dnn::BatchDescriptor& output_descriptor,
5071     const dnn::ConvolutionDescriptor& convolution_descriptor,
5072     dnn::ActivationMode activation_mode) {
5073   if (!algorithm_desc.is_cudnn_frontend()) {
5074     CudnnTensorDescriptor conv_input_nd(
5075         input_descriptor,
5076         ToCudnnDataType(input_type, input_descriptor.layout()));
5077     CudnnTensorDescriptor output_nd(
5078         output_descriptor,
5079         ToCudnnDataType(output_type, input_descriptor.layout()));
5080     CudnnFilterDescriptor filter(
5081         filter_descriptor,
5082         ToCudnnDataType(input_type, filter_descriptor.layout()));
5083     CudnnTensorDescriptor bias_nd(bias_descriptor, ToCudnnDataType(bias_type));
5084 
5085     CudnnConvolutionDescriptor conv(
5086         convolution_descriptor,
5087         ToCudnnDataType(GetConvAccumulatorType(input_type)));
5088     conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
5089 
5090     // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for
5091     // activation descriptor. Note that this will change the nan propagation
5092     // behavior from separate conv, bias, and relu (which by default is
5093     // CUDNN_PROPAGATE_NAN).
5094     //
5095     // TODO(awpr): reevaluate this for newer cuDNN versions.
5096     CudnnActivationDescriptor activation_desc(activation_mode,
5097                                               CUDNN_NOT_PROPAGATE_NAN,
5098                                               output_descriptor.value_max());
5099 
5100     TF_ASSIGN_OR_RETURN(
5101         auto runner,
5102         CudnnLegacyFusedConvRunner::Create(
5103             parent_, stream, cudnn_.get(), algorithm_desc, input_type,
5104             conv_scale, side_input_scale, std::move(conv_input_nd),
5105             std::move(output_nd), std::move(filter), std::move(bias_nd),
5106             std::move(conv), std::move(activation_desc)));
5107     return {std::make_unique<CudnnLegacyFusedConvRunner>(std::move(runner))};
5108   }
5109 
5110 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5111   auto cudnn = cudnn_->GetHandle(parent_, stream);
5112 
5113   TF_ASSIGN_OR_RETURN(auto op_graph,
5114                       GetCudnnFusedOperationGraph(
5115                           kind, input_type, bias_type, output_type, conv_scale,
5116                           side_input_scale, leakyrelu_alpha, input_descriptor,
5117                           filter_descriptor, bias_descriptor, output_descriptor,
5118                           convolution_descriptor, activation_mode, cudnn));
5119 
5120   TF_ASSIGN_OR_RETURN(auto execution_plan,
5121                       RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph));
5122 
5123   bool need_side_input =
5124       SideInputNeeded(activation_mode, conv_scale, side_input_scale);
5125   TF_ASSIGN_OR_RETURN(auto runner,
5126                       CudnnExecutionPlanRunner<dnn::FusedConvSignature>::Create(
5127                           parent_, cudnn_.get(), std::move(execution_plan),
5128                           {'x', 'w', 'z', 'b', 'y'}, need_side_input));
5129   return {std::make_unique<CudnnExecutionPlanRunner<dnn::FusedConvSignature>>(
5130       std::move(runner))};
5131 #else
5132   return port::UnimplementedError(
5133       "Cudnn execution plans are only supported with Cudnn >= 8.1.");
5134 #endif
5135 }
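// Note on need_side_input (editorial cross-reference): when SideInputNeeded()
// reports false, CudnnExecutionPlanRunner::operator() drops the 'z' entry
// (index 2 of {'x', 'w', 'z', 'b', 'y'}) from both its uid and data-pointer
// lists before building the VariantPack, so only four operands reach
// cudnnBackendExecute.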
5136 
5137 port::Status CudnnSupport::GetFusedConvolveRunners(
5138     bool use_cudnn_frontend, dnn::ConvolutionKind kind,
5139     dnn::DataType input_type, dnn::DataType bias_type,
5140     dnn::DataType output_type, double conv_scale, double side_input_scale,
5141     double leakyrelu_alpha, Stream* stream,
5142     const dnn::BatchDescriptor& input_descriptor,
5143     const dnn::FilterDescriptor& filter_descriptor,
5144     const dnn::BatchDescriptor& bias_descriptor,
5145     const dnn::BatchDescriptor& output_descriptor,
5146     const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback,
5147     const dnn::ActivationMode activation_mode,
5148     std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans) {
5149   // Fused convolutions with identity activations are broken in that they
5150   // implicitly do ReLU on some engines, and we can't reliably detect which
5151   // ones.
5152   const bool is_broken_identity_fused_conv =
5153 #if CUDNN_VERSION < 8205
5154       activation_mode == dnn::ActivationMode::kNone;
5155 #else
5156       false;
5157 #endif
5158 
5159   // All current versions of the frontend API lack support for Tx32
5160   // convolutions.
5161   const bool is_unsupported_x32 =
5162       input_descriptor.layout() == dnn::kBatchDepthYX32;
5163 
5164   // cuDNN frontend support became sufficiently stable to use in 8.1.
5165   // TODO(awpr): remove this condition once support for cuDNN 8.0 is dropped.
5166   const bool is_pre_frontend_cudnn = CUDNN_VERSION < 8100;
5167 
5168   const bool actually_use_cudnn_frontend =
5169       use_cudnn_frontend && !is_pre_frontend_cudnn &&
5170       !is_broken_identity_fused_conv && !is_unsupported_x32;
5171 
5172   if (use_cudnn_frontend && !actually_use_cudnn_frontend) {
5173     const char* reason = "the current cuDNN version does not support it.";
5174     if (is_unsupported_x32) {
5175       reason = "Tx32 convolutions are unsupported.";
5176     } else if (is_broken_identity_fused_conv) {
5177       reason = "it uses an identity activation.";
5178     }
5179 
5180     // This will happen once per unique conv configuration/shape that gets
5181     // affected (and not, for example, on every conv launch).  Confusion over
5182     // whether this has happened or not has repeatedly wasted a lot of time
5183     // debugging, so be sure it shows up in the logs.
5184     LOG(INFO) << "Disabling cuDNN frontend for the following convolution:\n"
5185               << "  input: " << input_descriptor.ToString() << "\n"
5186               << "  filter: " << filter_descriptor.ToString() << "\n"
5187               << "  " << convolution_descriptor.ToString() << "\n"
5188               << "  ... because " << reason;
5189   }
5190 
5191   if (input_type == dnn::DataType::kInt8 &&
5192       !stream->GetCudaComputeCapability().IsAtLeast(6, 1)) {
5193     return port::UnimplementedError(
5194         "cudnnConvolutionBiasActivationForward() for int8 is only supported "
5195         "on GPUs with compute capability 6.1 or later.");
5196   }
5197 
5198   if (input_type == dnn::DataType::kInt8 &&
5199       output_type == dnn::DataType::kFloat &&
5200       (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8200)) {
5201     return port::UnimplementedError(
5202         "int8 -> float fused conv is disabled for this cuDNN version. See "
5203         "go/nvbugs/3326122");
5204   }
5205 
5206   if (activation_mode != dnn::ActivationMode::kRelu &&
5207       activation_mode != dnn::ActivationMode::kRelu6 &&
5208       activation_mode != dnn::ActivationMode::kElu &&
5209       activation_mode != dnn::ActivationMode::kLeakyRelu &&
5210       activation_mode != dnn::ActivationMode::kNone) {
5211     return port::Status(port::error::INVALID_ARGUMENT,
5212                         "CuDNN fusion only supports activations of "
5213                         "{Relu, Relu6, Elu, <None>}.");
5214   }
5215 
5216   if (!actually_use_cudnn_frontend) {
5217     std::vector<dnn::AlgorithmDesc> algorithms;
5218 
5219     auto cuda_compute_capability = stream->GetCudaComputeCapability();
5220     if (!GetConvolveAlgorithms(cuda_compute_capability, &algorithms)) {
5221       return port::Status(port::error::UNKNOWN,
5222                           "Listing fused convolve algorithms failed.");
5223     }
5224 
5225     for (const auto& algo : algorithms) {
5226       // Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM is supported
5227       // for identity activation, other algs seem to quietly do Relu. See
5228       // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
5229       if (activation_mode == dnn::ActivationMode::kNone &&
5230           algo.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
5231         continue;
5232       }
5233       auto runner_or = FusedConvolveRunnerFromDesc(
5234           stream, algo, kind, input_type, bias_type, output_type, conv_scale,
5235           side_input_scale, leakyrelu_alpha, input_descriptor,
5236           filter_descriptor, bias_descriptor, output_descriptor,
5237           convolution_descriptor, activation_mode);
5238       if (!runner_or.ok()) {
5239         // See the corresponding error handling in
5240         // CudnnSupport::GetConvolveRunners: this filters out algorithms that
5241         // don't support this conv.
5242         continue;
5243       }
5244       out_exec_plans->push_back(std::move(runner_or).value());
5245     }
5246     return ::tensorflow::OkStatus();
5247   }
5248 
5249 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5250   auto cudnn = cudnn_->GetHandle(parent_, stream);
5251   auto op_graph_status = GetCudnnFusedOperationGraph(
5252       kind, input_type, bias_type, output_type, conv_scale, side_input_scale,
5253       leakyrelu_alpha, input_descriptor, filter_descriptor, bias_descriptor,
5254       output_descriptor, convolution_descriptor, activation_mode, cudnn);
5255   if (!op_graph_status.status().ok()) {
5256     return port::Status(port::error::INTERNAL,
5257                         absl::StrCat("Cudnn graph failed to build: ",
5258                                      op_graph_status.status().ToString()));
5259   }
5260   auto op_graph = std::move(op_graph_status).value();
5261 
5262   bool need_side_input =
5263       SideInputNeeded(activation_mode, conv_scale, side_input_scale);
5264   return CreateOpRunners<dnn::FusedConvSignature>(
5265       stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind,
5266       input_type, {'x', 'w', 'z', 'b', 'y'}, use_fallback, out_exec_plans,
5267       need_side_input);
5268 #else
5269   return port::UnimplementedError(
5270       "Cudnn execution plans are only supported with Cudnn >= 8.1.");
5271 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5272 }
5273 
5274 bool CudnnSupport::GetConvolveAlgorithms(
5275     CudaComputeCapability cuda_compute_capability,
5276     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
5277   PreloadCudnnSubLibs(PreloadCudnnType::ConvFwd);
5278 
5279   bool tensor_op_math_available =
5280       TensorOpMathAvailable(cuda_compute_capability);
5281   out_algorithms->clear();
5282 
5283   std::vector<dnn::AlgorithmDesc::Index> algo_types;
5284   if (ConvUseDefaultAlgorithm()) {
5285     // Force a fallback algorithm.
5286     algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM};
5287   } else {
5288     algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
5289                   CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
5290                   CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
5291                   CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
5292                   CUDNN_CONVOLUTION_FWD_ALGO_FFT,
5293                   CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD};
5294     if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
5295       algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
5296     }
5297     if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
5298       algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
5299     }
5300   }
5301 
5302   // The algorithms are intentionally ordered for deterministic operation
5303   for (auto i : algo_types) {
5304     if (tensor_op_math_available) {
5305       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
5306     }
5307     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
5308   }
5309 
5310   return true;
5311 }
5312 
5313 bool CudnnSupport::GetRnnAlgorithms(
5314     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
5315   PreloadCudnnSubLibs(PreloadCudnnType::Rnn);
5316 
5317   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
5318       // clang-format off
5319     CUDNN_RNN_ALGO_STANDARD,
5320     CUDNN_RNN_ALGO_PERSIST_STATIC,
5321     CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
5322       // clang-format on
5323   };
5324 
5325   out_algorithms->clear();
5326   for (auto i : algo_types) {
5327     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
5328     out_algorithms->push_back({i, /*use_tensor_ops=*/true});
5329   }
5330   return true;
5331 }
5332 
5333 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
5334     CudaComputeCapability cuda_compute_capability,
5335     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
5336   PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdData);
5337 
5338   bool tensor_op_math_available =
5339       TensorOpMathAvailable(cuda_compute_capability);
5340   out_algorithms->clear();
5341 
5342   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
5343       // clang-format off
5344     CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
5345     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
5346     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
5347     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
5348       // clang-format on
5349   };
5350   if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
5351     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
5352   }
5353   if (!RequireCudnnDeterminism()) {
5354     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
5355   }
5356 
5357   // The algorithms are intentionally ordered for deterministic operation
5358   for (auto i : algo_types) {
5359     if (tensor_op_math_available) {
5360       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
5361     }
5362     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
5363   }
5364 
5365   return true;
5366 }
5367 
5368 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
5369     CudaComputeCapability cuda_compute_capability,
5370     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
5371   PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdFilter);
5372 
5373   bool tensor_op_math_available =
5374       TensorOpMathAvailable(cuda_compute_capability);
5375   out_algorithms->clear();
5376 
5377   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
5378       // clang-format off
5379       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
5380       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
5381       // Based on cudnn.h, the following is not implemented.
5382       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
5383 
5384       // Produces incorrect results for some shapes. Disabled for now, see
5385       // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
5386       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
5387       // clang-format on
5388   };
5389   if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
5390     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
5391   }
5392   if (!RequireCudnnDeterminism()) {
5393     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
5394     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
5395   }
5396 
5397   // The algorithms are intentionally ordered for deterministic operation
5398   for (auto i : algo_types) {
5399     if (tensor_op_math_available) {
5400       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
5401     }
5402     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
5403   }
5404 
5405   return true;
5406 }
5407 
5408 bool CudnnSupport::DoBatchNormalizationForward(
5409     Stream* stream, const DeviceMemory<float>& x,
5410     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
5411     const DeviceMemory<float>& estimated_mean,
5412     const DeviceMemory<float>& estimated_variance,
5413     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
5414     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5415     const double exponential_average_factor,
5416     dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
5417     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
5418     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
5419     bool is_training, ScratchAllocator* reserve_space_allocator,
5420     ScratchAllocator* workspace_allocator) {
5421   return IsStatusOk(
5422       DoBatchNormalizationForwardImpl<float, float>(
5423           stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
5424           offset, estimated_mean, estimated_variance, side_input, x_desc,
5425           scale_offset_desc, epsilon, exponential_average_factor,
5426           activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
5427           is_training, reserve_space_allocator, workspace_allocator),
5428       /*report_error=*/true);
5429 }
5430 
5431 bool CudnnSupport::DoBatchNormalizationForward(
5432     Stream* stream, const DeviceMemory<Eigen::half>& x,
5433     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
5434     const DeviceMemory<float>& estimated_mean,
5435     const DeviceMemory<float>& estimated_variance,
5436     const DeviceMemory<Eigen::half>& side_input,
5437     const dnn::BatchDescriptor& x_desc,
5438     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5439     const double exponential_average_factor,
5440     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
5441     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
5442     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
5443     bool is_training, ScratchAllocator* reserve_space_allocator,
5444     ScratchAllocator* workspace_allocator) {
5445   return IsStatusOk(
5446       DoBatchNormalizationForwardImpl<Eigen::half, float>(
5447           stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
5448           estimated_mean, estimated_variance, side_input, x_desc,
5449           scale_offset_desc, epsilon, exponential_average_factor,
5450           activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
5451           is_training, reserve_space_allocator, workspace_allocator),
5452       /*report_error=*/true);
5453 }
5454 
5455 template <class T, class U>
5456 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
5457     Stream* stream, dnn::DataType input_data_type,
5458     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
5459     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
5460     const DeviceMemory<U>& estimated_mean,
5461     const DeviceMemory<U>& estimated_variance,
5462     const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
5463     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5464     const double exponential_average_factor,
5465     dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
5466     DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
5467     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
5468     bool is_training, ScratchAllocator* reserve_space_allocator,
5469     ScratchAllocator* workspace_allocator) {
5470   CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
5471   CudnnTensorDescriptor scale_offset_descriptor(
5472       scale_offset_desc, ToCudnnDataType(scale_data_type));
5473   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
5474   if (BatchnormSpatialPersistentEnabled() && is_training) {
5475     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
5476   }
5477   float one = 1.0;
5478   float zero = 0.0;
5479   auto cudnn = cudnn_->GetHandle(parent_, stream);
5480 
5481   DeviceMemory<uint8> workspace;
5482   DeviceMemory<uint8> reserve_space;
5483 
5484 #if CUDNN_VERSION >= 7402
5485   const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t {
5486     if (side_input.is_null()) {
5487       return activation_mode == dnn::ActivationMode::kNone
5488                  ? CUDNN_BATCHNORM_OPS_BN
5489                  : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
5490     } else {
5491       return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
5492     }
5493   };
5494   const cudnnBatchNormOps_t bn_ops = get_bn_ops();
5495 
5496   // We use NaN propagation to be consistent with CudnnSupport::DoActivate(...).
5497   CudnnActivationDescriptor activation_desc(
5498       activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
5499 
5500   if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
5501     TF_ASSIGN_OR_RETURN(
5502         workspace,
5503         CreateBatchNormForwardWorkspace(
5504             stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
5505             scale_offset_descriptor, workspace_allocator));
5506     if (is_training) {
5507       size_t reserve_space_size_in_bytes = 0;
5508       RETURN_IF_CUDNN_ERROR(
5509           cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
5510               /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
5511               /*activationDesc=*/activation_desc.handle(),
5512               /*xDesc=*/x_descriptor.handle(),
5513               /*sizeInBytes=*/&reserve_space_size_in_bytes));
5514       TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
5515                                              reserve_space_size_in_bytes));
5516     }
5517   }
5518 #endif
5519 
5520   auto check_no_side_input_or_activation = [&]() -> port::Status {
5521     if (activation_mode != dnn::ActivationMode::kNone ||
5522         !side_input.is_null()) {
5523       return port::Status(
5524           port::error::INTERNAL,
5525           absl::StrCat(
5526               "Side input and activation are not supported by cuDNN version: ",
5527               CUDNN_VERSION));
5528     } else {
5529       return ::tensorflow::OkStatus();
5530     }
5531   };
5532 
5533   if (is_training) {
5534     CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
5535         << "batch_mean and batch_var must both be null or both be non-null";
5536 
5537     void* batch_mean_opaque;
5538     void* batch_var_opaque;
5539     if (!batch_mean->is_null() && !batch_var->is_null()) {
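      // cuDNN updates the running statistics as
      //   running = (1 - factor) * running + factor * batch,
      // so with an averaging factor of 1.0 the previous values are nominally
      // scaled by zero. Since 0 * NaN is still NaN, the buffers are zeroed
      // first, presumably so uninitialized contents cannot leak into the result.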
5540       if (exponential_average_factor == 1.0) {
5541         stream->ThenMemZero(batch_mean, batch_mean->size());
5542         stream->ThenMemZero(batch_var, batch_var->size());
5543       }
5544       batch_mean_opaque = batch_mean->opaque();
5545       batch_var_opaque = batch_var->opaque();
5546     } else {
5547       batch_mean_opaque = nullptr;
5548       batch_var_opaque = nullptr;
5549     }
5550 
5551     bool called = false;
5552 #if CUDNN_VERSION >= 7402
5553     if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
5554       called = true;
5555       RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx(
5556           /*handle=*/cudnn.handle(),
5557           /*mode=*/mode,
5558           /*bnOps=*/bn_ops,
5559           /*alpha=*/&one,
5560           /*beta=*/&zero,
5561           /*xDesc=*/x_descriptor.handle(),
5562           /*xData=*/x.opaque(),
5563           /*zDesc=*/x_descriptor.handle(),
5564           /*zData=*/side_input.opaque(),
5565           /*yDesc=*/x_descriptor.handle(),
5566           /*yData=*/y->opaque(),
5567           /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
5568           /*bnScale=*/scale.opaque(),
5569           /*bnBias=*/offset.opaque(),
5570           /*exponentialAverageFactor=*/exponential_average_factor,
5571           /*resultRunningMean=*/batch_mean_opaque,
5572           /*resultRunningVariance=*/batch_var_opaque,
5573           /*epsilon=*/epsilon,
5574           /*resultSaveMean=*/saved_mean->opaque(),
5575           /*resultSaveInvVariance=*/saved_inv_var->opaque(),
5576           /*activationDesc=*/activation_desc.handle(),
5577           /*workspace=*/workspace.opaque(),
5578           /*workSpaceSizeInBytes=*/workspace.size(),
5579           /*reserveSpace=*/reserve_space.opaque(),
5580           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
5581     }
5582 #endif
5583     if (!called) {
5584       TF_RETURN_IF_ERROR(check_no_side_input_or_activation());
5585       RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
5586           cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
5587           x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
5588           scale.opaque(), offset.opaque(), exponential_average_factor,
5589           batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
5590           saved_inv_var->opaque()));
5591     }
5592   } else {
5593     const void* maybe_inv_var = estimated_variance.opaque();
5594     TF_RETURN_IF_ERROR(check_no_side_input_or_activation());
5595     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
5596         cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
5597         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
5598         scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
5599         epsilon));
5600   }
5601   return ::tensorflow::OkStatus();
5602 }
5603 
5604 bool CudnnSupport::DoBatchNormalizationBackward(
5605     Stream* stream, const DeviceMemory<float>& y_backprop,
5606     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
5607     const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
5608     const DeviceMemory<float>& inv_var, const DeviceMemory<float>& y,
5609     const dnn::BatchDescriptor& x_desc,
5610     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5611     dnn::ActivationMode activation_mode, DeviceMemory<float>* x_backprop,
5612     DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
5613     DeviceMemory<float>* side_input_backprop,
5614     DeviceMemory<uint8>* reserve_space_data,
5615     ScratchAllocator* workspace_allocator) {
5616   return IsStatusOk(
5617       DoBatchNormalizationBackwardImpl(
5618           stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, x, scale,
5619           offset, mean, inv_var, y, x_desc, scale_offset_desc, epsilon,
5620           activation_mode, x_backprop, scale_backprop, offset_backprop,
5621           side_input_backprop, reserve_space_data, workspace_allocator),
5622       /*report_error=*/true);
5623 }
5624 
5625 bool CudnnSupport::DoBatchNormalizationBackward(
5626     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
5627     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
5628     const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
5629     const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::half>& y,
5630     const dnn::BatchDescriptor& x_desc,
5631     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5632     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* x_backprop,
5633     DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
5634     DeviceMemory<Eigen::half>* side_input_backprop,
5635     DeviceMemory<uint8>* reserve_space_data,
5636     ScratchAllocator* workspace_allocator) {
5637   return IsStatusOk(
5638       DoBatchNormalizationBackwardImpl(
5639           stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, x, scale,
5640           offset, mean, inv_var, y, x_desc, scale_offset_desc, epsilon,
5641           activation_mode, x_backprop, scale_backprop, offset_backprop,
5642           side_input_backprop, reserve_space_data, workspace_allocator),
5643       /*report_error=*/true);
5644 }
5645 
5646 template <class T, class U>
5647 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
5648     Stream* stream, int cudnn_input_type, int cudnn_scale_type,
5649     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
5650     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
5651     const DeviceMemory<U>& mean, const DeviceMemory<U>& inv_var,
5652     const DeviceMemory<T>& y, const dnn::BatchDescriptor& x_desc,
5653     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5654     dnn::ActivationMode activation_mode, DeviceMemory<T>* x_backprop,
5655     DeviceMemory<U>* scale_backprop, DeviceMemory<U>* offset_backprop,
5656     DeviceMemory<T>* side_input_backprop,
5657     DeviceMemory<uint8>* reserve_space_data,
5658     ScratchAllocator* workspace_allocator) {
5659   CudnnTensorDescriptor x_descriptor(
5660       x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
5661   CudnnTensorDescriptor scale_offset_descriptor(
5662       scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
5663   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
5664   if (BatchnormSpatialPersistentEnabled()) {
5665     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
5666   }
5667   float one = 1.0;
5668   float zero = 0.0;
5669 
5670   auto cudnn = cudnn_->GetHandle(parent_, stream);
5671 
5672   bool called = false;
5673 #if CUDNN_VERSION >= 7402
5674   if (reserve_space_data != nullptr && workspace_allocator != nullptr) {
5675     called = true;
5676     const cudnnBatchNormOps_t bn_ops = [&]() {
5677       if (side_input_backprop->is_null()) {
5678         return activation_mode == dnn::ActivationMode::kNone
5679                    ? CUDNN_BATCHNORM_OPS_BN
5680                    : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
5681       } else {
5682         return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
5683       }
5684     }();
5685 
5686     // We use NaN propagation to be consistent with
5687     // CudnnSupport::DoActivate(...).
5688     CudnnActivationDescriptor activation_desc(
5689         activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
5690 
5691     TF_ASSIGN_OR_RETURN(
5692         DeviceMemory<uint8> workspace,
5693         CreateBatchNormBackwardWorkspace(
5694             stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
5695             scale_offset_descriptor, workspace_allocator));
5696     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx(
5697         /*handle=*/cudnn.handle(),
5698         /*mode=*/mode,
5699         /*bnOps=*/bn_ops,
5700         /*alphaDataDiff=*/&one,
5701         /*betaDataDiff=*/&zero,
5702         /*alphaParamDiff=*/&one,
5703         /*betaParamDiff=*/&zero,
5704         /*xDesc=*/x_descriptor.handle(),
5705         /*xData=*/x.opaque(),
5706         /*yDesc=*/x_descriptor.handle(),
5707         /*yData=*/y.opaque(),
5708         /*dyDesc=*/x_descriptor.handle(),
5709         /*dyData=*/y_backprop.opaque(),
5710         /*dzDesc=*/x_descriptor.handle(),
5711         /*dzData=*/side_input_backprop->opaque(),
5712         /*dxDesc=*/x_descriptor.handle(),
5713         /*dxData=*/x_backprop->opaque(),
5714         /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
5715         /*bnScaleData=*/scale.opaque(),
5716         /*bnBiasData=*/offset.opaque(),
5717         /*dBnScaleData=*/scale_backprop->opaque(),
5718         /*dBnBiasData=*/offset_backprop->opaque(),
5719         /*epsilon=*/epsilon,
5720         /*savedMean=*/mean.opaque(),
5721         /*savedInvVariance=*/inv_var.opaque(),
5722         /*activationDesc=*/activation_desc.handle(),
5723         /*workspace=*/workspace.opaque(),
5724         /*workSpaceSizeInBytes=*/workspace.size(),
5725         /*reserveSpace=*/reserve_space_data->opaque(),
5726         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
5727   }
5728 #endif
5729   auto check_no_side_input_or_activation = [&]() -> port::Status {
5730     if (activation_mode != dnn::ActivationMode::kNone ||
5731         !side_input_backprop->is_null()) {
5732       return port::InternalError(absl::StrCat(
5733           "Side input and activation are not supported by cuDNN version: ",
5734           CUDNN_VERSION));
5735     } else {
5736       return ::tensorflow::OkStatus();
5737     }
5738   };
5739 
5740   if (!called && check_no_side_input_or_activation().ok()) {
5741     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
5742         cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
5743         x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
5744         x_descriptor.handle(), x_backprop->opaque(),
5745         scale_offset_descriptor.handle(), scale.opaque(),
5746         scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
5747         mean.opaque(), inv_var.opaque()));
5748   }
5749 
5750   return ::tensorflow::OkStatus();
5751 }
5752 
5753 port::Status CudnnSupport::DoFusedConvolve(
5754     Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
5755     dnn::DataType bias_type, dnn::DataType output_type,
5756     const dnn::BatchDescriptor& conv_input_descriptor,
5757     DeviceMemoryBase conv_input_data, double conv_scale,
5758     const dnn::FilterDescriptor& filter_descriptor,
5759     DeviceMemoryBase filter_data,
5760     const dnn::ConvolutionDescriptor& convolution_descriptor,
5761     DeviceMemoryBase side_input_data, double side_input_scale,
5762     const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
5763     dnn::ActivationMode activation_mode,
5764     const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
5765     ScratchAllocator* scratch_allocator,
5766     const dnn::AlgorithmConfig& algorithm_config,
5767     dnn::ProfileResult* output_profile_result) {
5768   if (input_type == dnn::DataType::kInt8 &&
5769       !stream->GetCudaComputeCapability().IsAtLeast(6, 1)) {
5770     return port::UnimplementedError(
5771         "cudnnConvolutionBiasActivationForward() for int8 is only supported "
5772         "on GPUs with compute capability 6.1 or later.");
5773   }
5774 
5775   if (input_type == dnn::DataType::kInt8 &&
5776       output_type == dnn::DataType::kFloat &&
5777       (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8200)) {
5778     return port::UnimplementedError(
5779         "int8 -> float fused conv is disabled for this cuDNN version. See "
5780         "go/nvbugs/3326122");
5781   }
5782 
5783   if (activation_mode != dnn::ActivationMode::kRelu &&
5784       activation_mode != dnn::ActivationMode::kNone) {
5785     return port::Status(port::error::INVALID_ARGUMENT,
5786                         "cudnnConvolutionBiasActivationForward() only supports "
5787                         "Relu or None activation.");
5788   }
5789 
5790   CudnnTensorDescriptor conv_input_nd(
5791       conv_input_descriptor,
5792       ToCudnnDataType(input_type, conv_input_descriptor.layout()));
5793   CudnnTensorDescriptor output_nd(
5794       output_descriptor,
5795       ToCudnnDataType(output_type, conv_input_descriptor.layout()));
5796   CudnnFilterDescriptor filter(
5797       filter_descriptor,
5798       ToCudnnDataType(input_type, filter_descriptor.layout()));
5799   CudnnTensorDescriptor bias_nd(bias_descriptor, ToCudnnDataType(bias_type));
5800 
5801   DeviceMemory<uint8> scratch;
5802   dnn::AlgorithmDesc algo_desc;
5803   {
5804     auto cudnn = cudnn_->GetHandle(parent_, stream);
5805     TF_ASSIGN_OR_RETURN(
5806         algo_desc,
5807         GetCudnnConvolutionForwardAlgorithm(
5808             stream, cudnn, algorithm_config, conv_input_nd, filter, input_type,
5809             convolution_descriptor, output_nd, scratch_allocator, &scratch));
5810   }  // Explicitly release cuDNN handle.
5811 
5812   CudnnConvolutionDescriptor conv(
5813       convolution_descriptor,
5814       ToCudnnDataType(GetConvAccumulatorType(input_type)));
5815   conv.set_use_tensor_op_math(algo_desc.tensor_ops_enabled());
5816 
5817   // cuDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for
5818   // the activation descriptor. Note that this changes the NaN propagation
5819   // behavior relative to separate conv, bias, and relu (which default to
5820   // CUDNN_PROPAGATE_NAN).
5821   CudnnActivationDescriptor activation_desc(
5822       activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
5823 
5824   TF_ASSIGN_OR_RETURN(
5825       auto runner,
5826       CudnnLegacyFusedConvRunner::Create(
5827           parent_, stream, cudnn_.get(), std::move(algo_desc), input_type,
5828           conv_scale, side_input_scale, std::move(conv_input_nd),
5829           std::move(output_nd), std::move(filter), std::move(bias_nd),
5830           std::move(conv), std::move(activation_desc)));
5831 
5832   return runner(stream, output_profile_result, scratch, conv_input_data,
5833                 filter_data, side_input_data, biases, output_data);
5834 }
5835 
5836 port::Status CudnnSupport::DoPrepareForCtcLoss(
5837     Stream* stream, dnn::DataType element_type,
5838     const dnn::RnnStateTensorDescriptor& probs_desc,
5839     const dnn::RnnStateTensorDescriptor& grads_desc,
5840     absl::Span<const int> labels_data,
5841     absl::Span<const int> labels_lengths_data,
5842     absl::Span<const int> input_lengths_data,
5843     ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
5844     int* ctc_loss_algo_id) {
5845   auto cudnn = cudnn_->GetHandle(parent_, stream);
5846   // Query the workspace size.
5847   size_t workspace_size_in_bytes = 0;
5848 #if CUDNN_VERSION >= 7603
5849   CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
5850   const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
5851       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
5852   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
5853       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
5854 
5855   // Query the workspace size with `algo` and pick that algorithm if the
5856   // query succeeds. The non-deterministic algorithm is tried first, so it is
5857   // preferred whenever determinism is not required.
5858   auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
5859                                         : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
5860   cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
5861       /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
5862       /*gradientsDesc=*/cudnn_grads_desc.handle(),
5863       /*labels=*/labels_data.data(),
5864       /*labelLengths=*/labels_lengths_data.data(),
5865       /*inputLengths=*/input_lengths_data.data(),
5866       /*algo=*/algo,
5867       /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
5868       /*sizeInBytes=*/&workspace_size_in_bytes);
5869   if (RequireCudnnDeterminism()) {
5870     RETURN_IF_CUDNN_ERROR(status);
5871   }
5872 
5873   if (status != CUDNN_STATUS_SUCCESS) {
5874     algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
5875     RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
5876         /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
5877         /*gradientsDesc=*/cudnn_grads_desc.handle(),
5878         /*labels=*/labels_data.data(),
5879         /*labelLengths=*/labels_lengths_data.data(),
5880         /*inputLengths=*/input_lengths_data.data(),
5881         /*algo=*/algo,
5882         /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
5883         /*sizeInBytes=*/&workspace_size_in_bytes));
5884   }
5885   *ctc_loss_algo_id = algo;
5886 #else
5887   return port::Status(port::error::INVALID_ARGUMENT,
5888                       "No supported cudnnGetCTCLossWorkspaceSize when "
5889                       "CUDNN_VERSION < 7.6.3");
5890 #endif
5891   // Allocate the workspace.
5892   if (workspace_size_in_bytes == 0) {
5893     *scratch_memory = DeviceMemory<uint8>();
5894     return ::tensorflow::OkStatus();
5895   }
5896   const auto scratch_or =
5897       scratch_allocator->AllocateBytes(workspace_size_in_bytes);
5898   if (scratch_or.ok()) {
5899     *scratch_memory = scratch_or.ValueOrDie();
5900     return ::tensorflow::OkStatus();
5901   }
5902   return port::InternalError(
5903       "Failed to allocate scratch memory for the CuDNN CTC Loss");
5904 }
5905 
5906 port::Status CudnnSupport::DoCtcLoss(
5907     Stream* stream, dnn::DataType element_type,
5908     const dnn::RnnStateTensorDescriptor& probs_desc,
5909     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
5910     absl::Span<const int> labels_lengths_data,
5911     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
5912     const dnn::RnnStateTensorDescriptor& grads_desc,
5913     DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
5914     int ctc_loss_algo_id) {
5915   // The cuDNN CTC loss currently only supports the float data type.
5916   if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
5917     return port::Status(port::error::INVALID_ARGUMENT,
5918                         "CudnnCtcLossDescriptor is supported only when the "
5919                         "CUDNN_VERSION >= 7.6.3 and DataType is float");
5920   }
5921   CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
5922   const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
5923       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
5924   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
5925       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
5926   return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
5927                        labels_lengths_data, input_lengths_data, costs_data,
5928                        cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
5929                        scratch_memory, ctc_loss_algo_id);
5930 }
5931 
5932 bool CudnnSupport::DoTransformTensor(Stream* stream,
5933                                      const dnn::BatchDescriptor& input_desc,
5934                                      dnn::DataType input_type,
5935                                      const DeviceMemoryBase& input_data,
5936                                      const dnn::BatchDescriptor& output_desc,
5937                                      dnn::DataType output_type, float scale,
5938                                      DeviceMemoryBase* output_data) {
5939   float beta = 0.0f;
5940   CudnnTensorDescriptor input_tensor_desc(
5941       input_desc, ToCudnnDataType(input_type, input_desc.layout()));
5942   CudnnTensorDescriptor output_tensor_desc(
5943       output_desc, ToCudnnDataType(output_type, output_desc.layout()));
5944   auto cudnn = cudnn_->GetHandle(parent_, stream);
5945   const auto status = [&] {
5946     RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
5947         cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
5948         &beta, output_tensor_desc.handle(), output_data->opaque()));
5949     return ::tensorflow::OkStatus();
5950   }();
5951   return IsStatusOk(status, /*report_error=*/true);
5952 }
5953 
5954 bool CudnnSupport::DoMatMul(Stream* stream,
5955                             const DeviceMemory<float>& input_data,
5956                             const DeviceMemory<float>& weights,
5957                             const dnn::BatchDescriptor& input_dimensions,
5958                             const dnn::BatchDescriptor& output_dimensions,
5959                             DeviceMemory<float>* output_data) {
5960   if (input_dimensions.count() != output_dimensions.count()) {
5961     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
5962     return false;
5963   }
5964 
5965   // We do not permute the input or output; instead, we just
5966   // reinterpret the layout. We are working with row-major matrices
5967   // and the rows of the input and output correspond to batch, so
5968   // batch has to be outermost in both the input and output.
5969   //
5970   // By adding transposes to the BLAS gemm call we could perhaps make
5971   // the kYXDepthBatch layout work as well, but there has been no need
5972   // for that so far.
5973   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5974       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5975     LOG(ERROR) << "Unsupported MatMul input layout.";
5976     return false;
5977   }
5978   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5979       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5980     LOG(ERROR) << "Unsupported MatMul output layout.";
5981     return false;
5982   }
5983 
5984   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
5985     // This is a fast path that also supports the kBatchYXDepth layout.
5986 
5987     // The matrices here are in row-major format while BLAS expects
5988     // column-major, i.e. our matrices are transposed as far as BLAS
5989     // is concerned. So we need to compute output^T =
5990     // input^T*weights^T. There is no parameter for transposing the
5991     // output in BLAS gemm, but instead we can transpose both sides of
5992     // the equality to see that this is equivalent to
5993     // output=weights*input. So we only need to swap the order of
5994     // weights and input in the matrix product to correct for the
5995     // row-major versus column-major difference.
5996     const int64_t m = output_dimensions.NodesAcrossFeatureMaps();
5997     const int64_t n = input_dimensions.count();
5998     const int64_t k = input_dimensions.NodesAcrossFeatureMaps();
5999     if (!stream
6000              ->ThenBlasGemm(blas::Transpose::kNoTranspose,
6001                             blas::Transpose::kNoTranspose, m, n, k, weights, m,
6002                             input_data, k, output_data, m,
6003                             blas::kDefaultComputePrecision)
6004              .ok()) {
6005       return false;
6006     }
6007   } else {
6008     // This is a slower and more complex path that supports output
6009     // width() * height() > 1, though it only supports the
6010     // kBatchYXDepth layout. Does support kBatchDepthYX if output
6011     // feature_map_count() == 1, as then there is no difference
6012     // between the two layouts.
6013     //
6014     // The operation here is the same as above, except that we have to
6015     // do the matrix multiplication for each (y,x) output coordinate
6016     // separately. We then interpret weights as containing K = width()
6017     // * height() different matrices, which we all multiply onto the
6018     // matrix from input_data, yielding K matrix products. We then
6019     // combine these together into one matrix by concatenating all the
6020     // first rows of these matrices, then all the second rows and so
6021     // on. We can do this with a batched matrix multiplication, where
6022     // the result is written to a different submatrix of the output
6023     // for each matrix multiplication.
6024     //
6025     // The reason that we only support the kBatchYXDepth output layout
6026     // is that we have to do something in the depth for each (y,x)
6027     // coordinate. The kBatchYXDepth layout has the depth information
6028     // for each point (y,x) in contiguous memory while the
6029     // kBatchDepthYX layout does not.
6030     //
6031     // TODO(broune): Consider a special case for when output depth ==
6032     // 1, as then possibly this could all be done as one matrix
6033     // multiplication instead of a batched one, which should be
6034     // faster. Another possibility would be to add a weights layout
6035     // parameter and then support kBatchDepthYX for a different
6036     // weights layout.
6037     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
6038         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
6039           output_dimensions.feature_map_count() == 1)) {
6040       LOG(ERROR) << "Unsupported MatMul output layout.";
6041       return false;
6042     }
6043 
6044     const float alpha = 1.0f;  // Take the matrix product without scaling it.
6045     const float beta = 0.0f;   // Ignore the original values in output_data.
6046     const uint64_t m = output_dimensions.feature_map_count();
6047     const uint64_t n = input_dimensions.count();
6048     const uint64_t k = input_dimensions.NodesAcrossFeatureMaps();
6049     const int lda = m;
6050     const int ldb = k;
6051     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
6052     const int batch_count = output_dimensions.NodesPerFeatureMap();
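    // Each of the batch_count = width() * height() GEMMs produces an [m x n]
    // column-major result (m = feature_map_count, n = batch). Writing result i
    // with ldc = NodesAcrossFeatureMaps() at an output offset of
    // i * feature_map_count() interleaves the per-(y,x) products so that the
    // final buffer ends up in the kBatchYXDepth layout described above.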
6053 
6054     std::vector<DeviceMemory<float>> a(batch_count);
6055     std::vector<DeviceMemory<float>> b(batch_count);
6056     std::vector<DeviceMemory<float>> c(batch_count);
6057     for (int i = 0; i < batch_count; ++i) {
6058       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
6059                                  output_dimensions.feature_map_count();
6060       a[i] = DeviceMemory<float>::MakeFromByteSize(
6061           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
6062               weights_offset,
6063           weights.ElementCount() - weights_offset);
6064 
6065       b[i] = input_data;
6066 
6067       const int output_offset = i * output_dimensions.feature_map_count();
6068       c[i] = DeviceMemory<float>::MakeFromByteSize(
6069           const_cast<float*>(
6070               reinterpret_cast<const float*>(output_data->opaque())) +
6071               output_offset,
6072           output_data->ElementCount() - output_offset);
6073     }
6074     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
6075       std::vector<DeviceMemory<float>*> ptrs;
6076       ptrs.reserve(v.size());
6077       for (auto& mem : v) {
6078         ptrs.push_back(&mem);
6079       }
6080       return ptrs;
6081     };
6082 
6083     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
6084                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
6085                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
6086                                 ldc, batch_count);
6087   }
6088 
6089   return stream->ok();
6090 }
6091 
6092 bool CudnnSupport::DoBiasAdd(Stream* stream,
6093                              const DeviceMemory<float>& input_data,
6094                              const DeviceMemory<float>& biases,
6095                              const dnn::BatchDescriptor& dimensions,
6096                              DeviceMemory<float>* output_data) {
6097   CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
6098 
6099   dnn::BatchDescriptor bias_dimensions;
6100   bias_dimensions.set_count(1)
6101       .set_feature_map_count(dimensions.feature_map_count())
6102       .set_height(1)
6103       .set_width(1)
6104       .set_layout(dnn::DataLayout::kBatchYXDepth);
6105   CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
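  // The bias is described as a 1 x C x 1 x 1 tensor; cudnnAddTensor broadcasts
  // dimensions of size one, so each bias value is added across the batch and
  // every spatial position of its feature map.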
6106 
6107   // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
6108   // output_data before doing the addition, unless the input and
6109   // output are at the same address.
6110   if (input_data.opaque() != output_data->opaque()) {
6111     stream->ThenMemcpy(output_data, input_data,
6112                        dimensions.ElementCount() * sizeof(float));
6113     if (!stream->ok()) {
6114       LOG(ERROR)
6115           << "stream " << stream
6116           << " could not enqueue a tensor copy as part of bias addition.";
6117       return false;
6118     }
6119   }
6120 
6121   const float alpha = 1.0f;
6122   const float beta = 1.0f;
6123 
6124   auto cudnn = cudnn_->GetHandle(parent_, stream);
6125 
6126   const auto status = [&] {
6127     RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
6128         cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
6129         &beta, input_descriptor.handle(), output_data->opaque()));
6130     return ::tensorflow::OkStatus();
6131   }();
6132   return IsStatusOk(status, /*report_error=*/true);
6133 }
6134 
6135 bool CudnnSupport::DoActivate(Stream* stream,
6136                               dnn::ActivationMode activation_mode,
6137                               const dnn::BatchDescriptor& dimensions,
6138                               const DeviceMemory<float>& input_data,
6139                               DeviceMemory<float>* output_data,
6140                               uint64_t options) {
6141   CudnnActivationDescriptor activation_desc(
6142       activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
6143 
6144   CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
6145   // Alpha is the input scaling factor.
6146   float alpha = 1.0;
6147   // Beta is the output scaling factor.
6148   float beta = 0.0;
6149 
6150   auto cudnn = cudnn_->GetHandle(parent_, stream);
6151   const auto status = [&] {
6152     RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
6153         cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
6154         input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
6155     return ::tensorflow::OkStatus();
6156   }();
6157   return IsStatusOk(status, /*report_error=*/true);
6158 }
6159 
6160 namespace {
6161 
6162 // The cuDNN legacy API only supports int32 indexing and can handle at most
6163 // 2^31-1 elements. For pooling operations, we split a big tensor along the
6164 // batch axis into multiple smaller tensors when possible and then call the
6165 // cuDNN API on each split sequentially.
6166 struct PoolingSplitsSpec {
6167   int64_t num_batches;
6168   int64_t input_offset_in_bytes;
6169   int64_t output_offset_in_bytes;
6170 };
6171 
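// Illustrative example (numbers are hypothetical): for a float input with 4096
// batches of 2^20 elements each, max_batches_per_split below is
// (2^31 - 1) / 2^20 = 2047, so the tensor is processed as three splits of
// 2047, 2047 and 2 batches.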
6172 port::StatusOr<std::vector<PoolingSplitsSpec>> GetTensorSplits(
6173     const dnn::BatchDescriptor& input_descriptor,
6174     const dnn::BatchDescriptor& output_descriptor, dnn::DataType element_type) {
6175   std::vector<PoolingSplitsSpec> out;
6176   if (element_type == dnn::DataType::kInt8) {
6177     out.push_back({input_descriptor.count(), 0, 0});
6178     return out;
6179   }
6180 
6181   cudnnDataType_t cudnn_input_type =
6182       ToCudnnDataType(element_type, input_descriptor.layout());
6183   cudnnDataType_t cudnn_output_type =
6184       ToCudnnDataType(element_type, output_descriptor.layout());
6185 
6186   std::vector<int64_t> dims64 =
6187       input_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
6188 
6189   int64_t num_batches = input_descriptor.count();
6190   int64_t elements_per_batch_input = input_descriptor.NodesAcrossFeatureMaps();
6191   int64_t elements_per_batch_output =
6192       output_descriptor.NodesAcrossFeatureMaps();
6193 
6194   int64_t max_batches_per_split =
6195       std::numeric_limits<int>::max() / elements_per_batch_input;
6196 
6197   if (max_batches_per_split == 0) {
6198     return port::Status(
6199         port::error::INTERNAL,
6200         absl::StrCat(
6201             "Tensor has too many elements for int32 indexing: batches=",
6202             num_batches, " elements_per_batch=", elements_per_batch_input,
6203             "."));
6204   }
6205 
6206   int64_t processed_batches = 0;
6207   while (processed_batches < num_batches) {
6208     int64_t num_batches_per_split =
6209         std::min(max_batches_per_split, num_batches - processed_batches);
6210     int64_t offset_input = processed_batches * elements_per_batch_input *
6211                            CudnnDataTypeToByteSize(cudnn_input_type);
6212     int64_t offset_output = processed_batches * elements_per_batch_output *
6213                             CudnnDataTypeToByteSize(cudnn_output_type);
6214     out.push_back({num_batches_per_split, offset_input, offset_output});
6215     processed_batches += num_batches_per_split;
6216   }
6217   return out;
6218 }
6219 }  // namespace
6220 
6221 port::Status CudnnSupport::DoPoolForward(
6222     dnn::DataType element_type, Stream* stream,
6223     const dnn::PoolingDescriptor& pooling_dimensions,
6224     const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
6225     const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data,
6226     ScratchAllocator* workspace_allocator) {
6227   // Alpha is the scaling factor for input.
6228   const float alpha_f = 1.0f;
6229   const double alpha_d = 1.0;
6230   const void* alpha = element_type == dnn::DataType::kDouble
6231                           ? static_cast<const void*>(&alpha_d)
6232                           : static_cast<const void*>(&alpha_f);
6233   // Beta is the scaling factor for output.
6234   const float beta_f = 0.0f;
6235   const double beta_d = 0.0;
6236   const void* beta = element_type == dnn::DataType::kDouble
6237                          ? static_cast<const void*>(&beta_d)
6238                          : static_cast<const void*>(&beta_f);
6239 
6240   cudnnDataType_t cudnn_input_type =
6241       ToCudnnDataType(element_type, input_dimensions.layout());
6242   cudnnDataType_t cudnn_output_type =
6243       ToCudnnDataType(element_type, output_dimensions.layout());
6244   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
6245   auto cudnn = cudnn_->GetHandle(parent_, stream);
6246 
6247   auto cudnn_launcher = [&](CudnnTensorDescriptor& src_desc,
6248                             CudnnTensorDescriptor& dest_desc,
6249                             const void* input_ptr, void* output_ptr) {
6250     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
6251         cudnn.handle(), pooling_desc.handle(), alpha, src_desc.handle(),
6252         input_ptr, beta, dest_desc.handle(), output_ptr));
6253     return ::tensorflow::OkStatus();
6254   };
6255 
6256   auto splits_or =
6257       GetTensorSplits(input_dimensions, output_dimensions, element_type);
6258   if (!splits_or.ok()) {
6259     return port::Status(port::error::INTERNAL, "Cudnn pooling failed to split");
6260   }
6261   auto splits = std::move(splits_or.ValueOrDie());
6262 
6263   dnn::BatchDescriptor input_split = input_dimensions;
6264   dnn::BatchDescriptor output_split = output_dimensions;
6265   for (int i = 0; i < splits.size(); i++) {
6266     // It is safe to cap the batch dimension, since it is the leading
6267     // dimension and will have no effect on the computation of strides in both
6268     // kBatchYXDepth and kBatchDepthYX formats.
6269     input_split.set_count(splits[i].num_batches);
6270     output_split.set_count(splits[i].num_batches);
6271     CudnnTensorDescriptor src_desc(input_split, cudnn_input_type);
6272     CudnnTensorDescriptor dest_desc(output_split, cudnn_output_type);
6273 
6274     void* input_data_ptr = static_cast<char*>(input_data.opaque()) +
6275                            splits[i].input_offset_in_bytes;
6276     void* output_data_ptr = static_cast<char*>(output_data.opaque()) +
6277                             splits[i].output_offset_in_bytes;
6278     const auto status =
6279         cudnn_launcher(src_desc, dest_desc, input_data_ptr, output_data_ptr);
6280     if (!IsStatusOk(status, /*report_error=*/true)) {
6281       return status;
6282     }
6283   }
6284   return ::tensorflow::OkStatus();
6285 }
6286 
6287 port::Status CudnnSupport::DoPoolBackward(
6288     dnn::DataType element_type, Stream* stream,
6289     const dnn::PoolingDescriptor& pooling_dimensions,
6290     const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
6291     const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data,
6292     DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data,
6293     ScratchAllocator* workspace_allocator) {
6294   // Alpha is the scaling factor for input.
6295   const float alpha_f = 1.0f;
6296   const double alpha_d = 1.0;
6297   const void* alpha = element_type == dnn::DataType::kDouble
6298                           ? static_cast<const void*>(&alpha_d)
6299                           : static_cast<const void*>(&alpha_f);
6300   // Beta is the scaling factor for output.
6301   const float beta_f = 0.0f;
6302   const double beta_d = 0.0;
6303   const void* beta = element_type == dnn::DataType::kDouble
6304                          ? static_cast<const void*>(&beta_d)
6305                          : static_cast<const void*>(&beta_f);
6306 
6307   cudnnDataType_t cudnn_input_type =
6308       ToCudnnDataType(element_type, input_dimensions.layout());
6309   cudnnDataType_t cudnn_output_type =
6310       ToCudnnDataType(element_type, output_dimensions.layout());
6311   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
6312   auto cudnn = cudnn_->GetHandle(parent_, stream);
6313 
6314   auto cudnn_launcher = [&](CudnnTensorDescriptor& src_desc,
6315                             CudnnTensorDescriptor& dest_desc,
6316                             const void* output_ptr, const void* input_diff_ptr,
6317                             const void* input_ptr, void* output_diff_ptr) {
6318     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
6319         cudnn.handle(), pooling_desc.handle(), alpha, dest_desc.handle(),
6320         output_ptr, dest_desc.handle(), input_diff_ptr, src_desc.handle(),
6321         input_ptr, beta, src_desc.handle(), output_diff_ptr));
6322     return ::tensorflow::OkStatus();
6323   };
6324 
6325   auto splits_or =
6326       GetTensorSplits(input_dimensions, output_dimensions, element_type);
6327   if (!splits_or.ok()) {
6328     return port::Status(port::error::INTERNAL, "Cudnn pooling failed to split");
6329   }
6330   auto splits = std::move(splits_or.ValueOrDie());
6331 
6332   dnn::BatchDescriptor input_split = input_dimensions;
6333   dnn::BatchDescriptor output_split = output_dimensions;
6334   for (int i = 0; i < splits.size(); i++) {
6335     // It is safe to cap the batch dimension, since it is the leading
6336     // dimension and will have no effect on the computation of strides in both
6337     // kBatchYXDepth and kBatchDepthYX formats.
6338     input_split.set_count(splits[i].num_batches);
6339     output_split.set_count(splits[i].num_batches);
6340     CudnnTensorDescriptor src_desc(input_split, cudnn_input_type);
6341     CudnnTensorDescriptor dest_desc(output_split, cudnn_output_type);
6342 
6343     void* output_data_ptr = static_cast<char*>(output_data.opaque()) +
6344                             splits[i].output_offset_in_bytes;
6345     void* input_diff_data_ptr = static_cast<char*>(input_diff_data.opaque()) +
6346                                 splits[i].output_offset_in_bytes;
6347     void* input_data_ptr = static_cast<char*>(input_data.opaque()) +
6348                            splits[i].input_offset_in_bytes;
6349     void* output_diff_data_ptr = static_cast<char*>(output_diff_data.opaque()) +
6350                                  splits[i].input_offset_in_bytes;
6351     const auto status = cudnn_launcher(src_desc, dest_desc, output_data_ptr,
6352                                        input_diff_data_ptr, input_data_ptr,
6353                                        output_diff_data_ptr);
6354     if (!IsStatusOk(status, /*report_error=*/true)) {
6355       return status;
6356     }
6357   }
6358   return ::tensorflow::OkStatus();
6359 }
6360 
6361 bool CudnnSupport::DoNormalizeWithDimensions(
6362     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
6363     const dnn::BatchDescriptor& dimensions,
6364     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
6365   // Check for unsupported modes.
6366   if (normalize_descriptor.wrap_around()) {
6367     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
6368     return false;
6369   }
6370   if (normalize_descriptor.segment_size()) {
6371     LOG(ERROR) << "CUDA LRN does not support segmentation";
6372     return false;
6373   }
6374 
6375   CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
6376   CudnnNormalizeDescriptor normalize(normalize_descriptor);
6377 
6378   // Alpha is the scaling factor for input.
6379   float alpha = 1.0f;
6380   // Beta is the scaling factor for output.
6381   float beta = 0.0f;
6382 
6383   auto cudnn = cudnn_->GetHandle(parent_, stream);
6384 
6385   // Launch the normalization.
6386   const auto status = [&] {
6387     RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
6388         cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
6389         &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
6390         output_data->opaque()));
6391     return ::tensorflow::OkStatus();
6392   }();
6393   return IsStatusOk(status, /*report_error=*/true);
6394 }
6395 
6396 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
6397     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
6398     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
6399     const DeviceMemory<float>& normalized_data,
6400     const DeviceMemory<float>& normalized_variable_gradient,
6401     DeviceMemory<float>* raw_variable_gradient,
6402     ScratchAllocator* workspace_allocator) {
6403   // Check for unsupported modes.
6404   if (normalize_descriptor.wrap_around()) {
6405     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
6406     return false;
6407   }
6408   if (normalize_descriptor.segment_size()) {
6409     LOG(ERROR) << "CUDA LRN does not support segmentation";
6410     return false;
6411   }
6412 
6413   CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
6414   CudnnNormalizeDescriptor normalize(normalize_descriptor);
6415 
6416   float alpha = 1.0f;
6417   float beta = 0.0f;
6418 
6419   auto cudnn = cudnn_->GetHandle(parent_, stream);
6420   const auto status = [&] {
6421     RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
6422         cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
6423         &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
6424         normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
6425         &beta, dims.handle(), raw_variable_gradient->opaque()));
6426     return ::tensorflow::OkStatus();
6427   }();
6428   return IsStatusOk(status, /*report_error=*/true);
6429 }
6430 
6431 bool CudnnSupport::DoDepthConcatenate(Stream* stream,
6432                                       BatchDescriptorSlice input_dimensions,
6433                                       DeviceMemorySlice<float> input_data,
6434                                       DeviceMemory<float>* output_data) {
6435   CHECK_EQ(input_dimensions.size(), input_data.size());
6436 
6437   for (const auto& dimensions : input_dimensions) {
6438     if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
6439       LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
6440                     "supports the kBatchDepthYX layout.";
6441       return false;
6442     }
6443   }
6444 
6445   if (input_dimensions.empty()) {
6446     return true;  // Nothing to do.
6447   }
6448 
6449   dnn::BatchDescriptor output_dimensions =
6450       dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
6451 
6452   const int64_t area = output_dimensions.width() * output_dimensions.height();
6453   const auto index = [area](int64_t batch, int64_t depth, int64_t yx,
6454                             int64_t max_depth) {
6455     return (batch * max_depth + depth) * area + yx;
6456   };
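  // index() maps a (batch, depth, y*x) coordinate to a flat offset in the
  // kBatchDepthYX layout used by both the inputs and the concatenated output.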
6457 
6458   std::vector<float> output_host(output_dimensions.ElementCount());
6459   std::vector<float> tmp;
6460   int64_t depth_sum = 0;
6461   for (size_t i = 0; i < input_data.size(); ++i) {
6462     const auto& dimensions = input_dimensions[i];
6463     tmp.resize(dimensions.ElementCount());
6464     stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
6465     port::Status block_status = stream->BlockHostUntilDone();
6466     if (!block_status.ok()) {
6467       LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
6468       return false;
6469     }
6470 
6471     for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) {
6472       for (int64_t yx = 0; yx < area; ++yx) {
6473         for (int64_t depth = 0; depth < dimensions.feature_map_count();
6474              ++depth) {
6475           LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
6476                     << yx << ' ' << depth;
6477           output_host[index(batch, depth + depth_sum, yx,
6478                             output_dimensions.feature_map_count())] =
6479               tmp[index(batch, depth, yx, dimensions.feature_map_count())];
6480         }
6481       }
6482     }
6483     depth_sum += dimensions.feature_map_count();
6484   }
6485   stream->ThenMemcpyH2D<float>(output_host, output_data);
6486   return true;
6487 }
6488 
6489 bool CudnnSupport::DoElementwiseOperate(Stream*, dnn::ElementwiseOperation,
6490                                         BatchDescriptorSlice,
6491                                         DeviceMemorySlice<float>,
6492                                         const dnn::BatchDescriptor&,
6493                                         DeviceMemory<float>*) {
6494   LOG(FATAL) << "not yet implemented";  // TODO(leary)
6495   return false;
6496 }
6497 
6498 bool CudnnSupport::DoXYPad(Stream* stream,
6499                            const dnn::BatchDescriptor& dimensions,
6500                            const DeviceMemory<float>& input_data,
6501                            int64_t left_pad, int64_t right_pad, int64_t top_pad,
6502                            int64_t bottom_pad,
6503                            DeviceMemory<float>* output_data) {
6504   LOG(FATAL) << "not yet implemented";  // TODO(leary)
6505   return false;
6506 }
6507 
bool CudnnSupport::DoXYSlice(Stream* stream,
                             const dnn::BatchDescriptor& dimensions,
                             const DeviceMemory<float>& input_data,
                             int64_t left_trim, int64_t right_trim,
                             int64_t top_trim, int64_t bottom_trim,
                             DeviceMemory<float>* output_data) {
  LOG(FATAL) << "not yet implemented";  // TODO(leary)
  return false;
}

bool CudnnSupport::DoMemcpyD2HQuantized(
    Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
    dnn::QuantizedActivationMode mode, void* host_dst, int64_t size) {
  LOG(ERROR) << "quantized memcpy not supported by cuDNN";
  return false;
}

bool CudnnSupport::DoMemcpyH2DQuantized(
    Stream* stream, const void* host_src, int64_t size,
    dnn::QuantizedActivationMode mode,
    DeviceMemory<float>* gpu_unquantized_dst) {
  LOG(ERROR) << "quantized memcpy not supported by cuDNN";
  return false;
}

bool CudnnSupport::DeriveOutputBatchDescriptor(
    const dnn::BatchDescriptor& batch_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    dnn::BatchDescriptor* output_batch_descriptor) {
  CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
  CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
  CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);

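  // Query cuDNN for the forward-convolution output shape; dims is filled in
  // BDYX order (batch, features, then spatial dimensions outermost-first).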
  int dn = batch_descriptor.ndims() + 2;
  std::vector<int> dims(dn);  // in BDYX
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
        conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
    output_batch_descriptor->set_count(dims[0])
        .set_feature_map_count(dims[1])
        .set_layout(batch_descriptor.layout());

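    // Spatial entries are read in reverse so that DimIndex 0 refers to the
    // innermost (X) dimension.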
    for (int i = 0; i < batch_descriptor.ndims(); i++) {
      output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
                                               dims.rbegin()[i]);
    }
    return ::tensorflow::OkStatus();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

}  // namespace gpu

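// Registers CudnnSupport as the DNN plugin for the CUDA platform and makes it
// the default DNN factory, so StreamExecutors created for CUDA devices use
// cuDNN.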
void initialize_cudnn() {
  port::Status status =
      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
          cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
          [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
            gpu::GpuExecutor* cuda_executor =
                dynamic_cast<gpu::GpuExecutor*>(parent);
            if (cuda_executor == nullptr) {
              LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
                         << "support library with a non-CUDA StreamExecutor";
              return nullptr;
            }

            gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
            if (!dnn->Init().ok()) {
              // Note: Init() will log a more specific error.
              delete dnn;
              return nullptr;
            }
            return dnn;
          });

  if (!status.ok()) {
    LOG(ERROR) << "Unable to register cuDNN factory: "
               << status.error_message();
  }

  PluginRegistry::Instance()->SetDefaultFactory(
      cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
}

}  // namespace stream_executor

#pragma clang diagnostic pop

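// Run the factory registration above at module load time.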
REGISTER_MODULE_INITIALIZER(register_cudnn,
                            { stream_executor::initialize_cudnn(); });