/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/cuda/cuda_dnn.h"

#include <functional>
#include <memory>
#include <utility>

#include "absl/strings/str_cat.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
#include "tensorflow/stream_executor/cuda/cudnn_version.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/mathutil.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
// clang-format off
#include "third_party/gpus/cudnn/cudnn.h"
#include "absl/strings/string_view.h"
// clang-format on

#pragma clang diagnostic push

// Make sure that Eigen::half forward declaration in dnn.h matches the
// declaration in Eigen.
#pragma clang diagnostic warning "-Wmismatched-tags"

namespace stream_executor {
namespace gpu {

PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);

namespace {

static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher");

// Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
#define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)

// If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
// function with a non-successful port::Status.
#define RETURN_IF_CUDNN_ERROR(expr)                                      \
  do {                                                                   \
    cudnnStatus_t _status = expr;                                        \
    if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {             \
      std::ostringstream oss;                                            \
      oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
          << "): '" << #expr << "'";                                     \
      return port::Status(port::error::UNKNOWN, oss.str().c_str());      \
    }                                                                    \
  } while (false)
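
// Illustrative usage of the two macros above: use CHECK_CUDNN_OK for calls
// whose failure should abort the process, and RETURN_IF_CUDNN_ERROR for calls
// whose failure should be propagated as a port::Status, e.g.
//
//   CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
//   RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));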

// Converts (via narrowing) a type T value to a type U, and checks that the
// value has no value change due to the conversion.
template <typename WideT, typename NarrowT>
NarrowT CheckedNarrowing(const WideT& wide) {
  NarrowT narrow = wide;
  CHECK_EQ(narrow, wide)
      << "checked narrowing failed; values not equal post-conversion";
  return narrow;
}
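
// For example, CheckedNarrowing<int64, int>(3) returns 3, while passing a
// value that does not fit in the narrow type trips the CHECK_EQ above.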

string ToString(cudnnStatus_t status) {
  switch (status) {
    case CUDNN_STATUS_SUCCESS:
      return "CUDNN_STATUS_SUCCESS";
    case CUDNN_STATUS_NOT_INITIALIZED:
      return "CUDNN_STATUS_NOT_INITIALIZED";
    case CUDNN_STATUS_ALLOC_FAILED:
      return "CUDNN_STATUS_ALLOC_FAILED";
    case CUDNN_STATUS_BAD_PARAM:
      return "CUDNN_STATUS_BAD_PARAM";
    case CUDNN_STATUS_INTERNAL_ERROR:
      return "CUDNN_STATUS_INTERNAL_ERROR";
    case CUDNN_STATUS_INVALID_VALUE:
      return "CUDNN_STATUS_INVALID_VALUE";
    case CUDNN_STATUS_ARCH_MISMATCH:
      return "CUDNN_STATUS_ARCH_MISMATCH";
    case CUDNN_STATUS_MAPPING_ERROR:
      return "CUDNN_STATUS_MAPPING_ERROR";
    case CUDNN_STATUS_EXECUTION_FAILED:
      return "CUDNN_STATUS_EXECUTION_FAILED";
    case CUDNN_STATUS_NOT_SUPPORTED:
      return "CUDNN_STATUS_NOT_SUPPORTED";
    case CUDNN_STATUS_LICENSE_ERROR:
      return "CUDNN_STATUS_LICENSE_ERROR";
    case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
      return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
#if CUDNN_VERSION >= 7000
    case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
      return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
    case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
      return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
#endif
    default:
      return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
                          ">");
  }
}

// RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
//
// See CudnnAccess::GetHandle() for details.
class CudnnHandle {
 public:
  // Takes ownership of the executor context and the lock to access cuDNN
  // using handle.
  CudnnHandle(gpu::ScopedActivateExecutorContext context,
              std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
      : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}

  // Returns the cuDNN handle. Pass it directly to cuDNN APIs; don't keep a
  // copy.
  cudnnHandle_t handle() const { return handle_; }

 private:
  gpu::ScopedActivateExecutorContext context_;
  std::unique_ptr<absl::MutexLock> lock_;
  cudnnHandle_t handle_;  // Not owned.
};

}  // namespace

// Wraps a cuDNN handle and provides access to it through CudnnHandle
// instances; each CudnnHandle also locks a mutex, acquires the CUDA context,
// and sets the stream that cuDNN should use to enqueue any work.
//
// Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
class CudnnAccess {
 public:
  // Takes ownership of the handle.
  explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}

  ~CudnnAccess() {
    absl::MutexLock lock(&mutex_);
    cudnnDestroy(handle_);
  }

  // Creates a CudnnHandle instance for stream.
  //
  // cuDNN API calls using the same handle instance need to be serialized
  // across threads. This is guaranteed by CudnnHandle instances locking the
  // mutex owned by this class.
  //
  // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
  // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
  // to use the provided stream.
  //
  // The stream argument may be null, which translates to the legacy default
  // stream. See
  // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
  // The legacy default stream synchronizes with all other streams and it is
  // therefore a bad idea (performance wise) to call any cuDNN APIs that
  // enqueue work in the stream.
  CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
    auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
    mutex_.AssertHeld();
    gpu::ScopedActivateExecutorContext context(executor);
    CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
    const auto status = cudnnSetStream(handle_, cu_stream);
    CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
    return CudnnHandle(std::move(context), std::move(lock), handle_);
  }

 private:
  // Guards the enqueueing of cuDNN operations via the handle_ below.
  absl::Mutex mutex_;

  // cuDNN library handle.
  cudnnHandle_t handle_ GUARDED_BY(mutex_);  // Owned.
};
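
// Typical calling pattern elsewhere in this file (illustrative sketch;
// cudnnSomeApiCall is a placeholder):
//
//   auto cudnn = cudnn_->GetHandle(parent_, stream);
//   RETURN_IF_CUDNN_ERROR(cudnnSomeApiCall(cudnn.handle(), ...));
//
// Here cudnn_ is the CudnnAccess instance owned by CudnnSupport, and the
// returned CudnnHandle keeps the mutex locked and the CUDA context active
// for the duration of the call sequence.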

namespace {

// A helper function to return the internal compute type for
// RNNs in cudnn.
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);

cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
  cudnnConvolutionFwdAlgo_t algo =
      cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
  switch (algo) {
    case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
    case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
    case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
    case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
    case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
    case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
    case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
    case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
      return algo;
    default:
      LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
                 << algorithm.algo_id();
  }
}

cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
    dnn::AlgorithmDesc algorithm) {
  cudnnConvolutionBwdDataAlgo_t algo =
      cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
  switch (algo) {
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
      return algo;
    default:
      LOG(FATAL)
          << "Unsupported Cudnn convolution backward algorithm for data: "
          << algorithm.algo_id();
  }
}

cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
    dnn::AlgorithmDesc algorithm) {
  cudnnConvolutionBwdFilterAlgo_t algo =
      cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
  switch (algo) {
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
    // Based on cudnn.h, the following is not implemented.
    // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
      return algo;
    // Produces incorrect results for some shapes. Disabled for now, see
    // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
    // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
    default:
      LOG(FATAL)
          << "Unsupported Cudnn convolution backward algorithm for filter: "
          << algorithm.algo_id();
  }
}

port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
  int value;
  RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
  return value;
}

cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional<dnn::AlgorithmDesc> algorithm) {
  if (!algorithm.has_value()) {
    return CUDNN_RNN_ALGO_STANDARD;
  }
  cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
  switch (algo) {
    case CUDNN_RNN_ALGO_STANDARD:
    case CUDNN_RNN_ALGO_PERSIST_STATIC:
    case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
      return algo;
    default:
      LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
  }
}

port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
  SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
  SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
  SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
  return port::Status::OK();
}

}  // namespace

CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}

port::Status CudnnSupport::Init() {
  ScopedActivateExecutorContext context(parent_);
  cudnnHandle_t cudnn_handle = nullptr;
  const auto status = cudnnCreate(&cudnn_handle);
  if (status == CUDNN_STATUS_SUCCESS) {
    CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);

    CudnnVersion loaded_version;
    TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
    if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
      const string error = absl::StrCat(
          "Loaded runtime CuDNN library: ", loaded_version.ToString(),
          " but source was compiled with: ", source_version.ToString(),
          ".  CuDNN library major and minor version needs to match or have "
          "higher minor version in case of CuDNN 7.0 or later version. If "
          "using a binary install, upgrade your CuDNN library.  If building "
          "from sources, make sure the library loaded at runtime is "
          "compatible "
          "with the version specified during compile configuration.");
      LOG(ERROR) << error;
      cudnnDestroy(cudnn_handle);
      return port::Status(port::error::INTERNAL, error);
    }

    cudnn_.reset(new CudnnAccess(cudnn_handle));
    return port::Status::OK();
  }

  CHECK_EQ(cudnn_handle, nullptr);
  LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
  if (status == CUDNN_STATUS_NOT_INITIALIZED) {
    auto result = gpu::Diagnostician::FindKernelDriverVersion();
    if (!result.ok()) {
      LOG(ERROR) << "Error retrieving driver version: "
                 << cuda::DriverVersionStatusToString(result);
    } else {
      const auto& version = result.ValueOrDie();
      LOG(ERROR) << "Possibly insufficient driver version: "
                 << cuda::DriverVersionToString(version);
    }
  }

  return port::Status(port::error::INTERNAL,
                      absl::StrCat("cudnn library could not create a handle: ",
                                   ToString(status)));
}

port::StatusOr<perftools::gputools::dnn::VersionInfo>
CudnnSupport::GetVersion() {
  CudnnVersion version;
  TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
  return perftools::gputools::dnn::VersionInfo(
      version.major_version, version.minor_version, version.patch_level);
}

namespace {

// Deleter functors for cuDNN types that need to be deleted.
struct TensorDescriptorDeleter {
  void operator()(cudnnTensorDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
  }
};
#if CUDNN_VERSION >= 7201
struct RNNDataDescriptorDeleter {
  void operator()(cudnnRNNDataDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
  }
};
#endif
struct FilterDescriptorDeleter {
  void operator()(cudnnFilterDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
  }
};
struct ConvolutionDescriptorDeleter {
  void operator()(cudnnConvolutionDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
  }
};
struct PoolingDescriptorDeleter {
  void operator()(cudnnPoolingDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
  }
};
struct LrnDescriptorDeleter {
  void operator()(cudnnLRNDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
  }
};

struct ActivationDescriptorDeleter {
  void operator()(cudnnActivationDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
  }
};
struct DropoutDescriptorDeleter {
  void operator()(cudnnDropoutDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
  }
};
struct RnnDescriptorDeleter {
  void operator()(cudnnRNNDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
  }
};
struct PersistentRnnPlanDeleter {
  void operator()(cudnnPersistentRNNPlan_t plan) const {
    CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
  }
};
#if CUDNN_VERSION >= 7603
struct CtcLossDescriptorDeleter {
  void operator()(cudnnCTCLossDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
  }
};
#endif

// RAII wrappers for cuDNN types.
using TensorDescriptor =
    std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
#if CUDNN_VERSION >= 7201
using RNNDataDescriptor =
    std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
#endif
using FilterDescriptor =
    std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
using ConvolutionDescriptor =
    std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
using PoolingDescriptor =
    std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
using ActivationDescriptor =
    std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
using DropoutDescriptor =
    std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
using PersistentRnnPlan =
    std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
#if CUDNN_VERSION >= 7603
using CtcLossDescriptor =
    std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
#endif

// Factory methods for cuDNN types.
TensorDescriptor CreateTensorDescriptor() {
  cudnnTensorDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
  return TensorDescriptor(result);
}
#if CUDNN_VERSION >= 7201
RNNDataDescriptor CreateRNNDataDescriptor() {
  cudnnRNNDataDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
  return RNNDataDescriptor(result);
}
#endif
FilterDescriptor CreateFilterDescriptor() {
  cudnnFilterDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
  return FilterDescriptor(result);
}
ConvolutionDescriptor CreateConvolutionDescriptor() {
  cudnnConvolutionDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
  return ConvolutionDescriptor(result);
}
PoolingDescriptor CreatePoolingDescriptor() {
  cudnnPoolingDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
  return PoolingDescriptor(result);
}
LrnDescriptor CreateLrnDescriptor() {
  cudnnLRNDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
  return LrnDescriptor(result);
}
ActivationDescriptor CreateActivationDescriptor() {
  cudnnActivationDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
  return ActivationDescriptor(result);
}
DropoutDescriptor CreateDropoutDescriptor() {
  cudnnDropoutDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
  return DropoutDescriptor(result);
}
RnnDescriptor CreateRnnDescriptor() {
  cudnnRNNDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
  return RnnDescriptor(result);
}
#if CUDNN_VERSION >= 7603
CtcLossDescriptor CreateCtcLossDescriptor() {
  cudnnCTCLossDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
  return CtcLossDescriptor(result);
}
#endif

port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
    cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
  cudnnPersistentRNNPlan_t result;
  RETURN_IF_CUDNN_ERROR(
      cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
  return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
}

// Turns a BatchDescriptor structure into a cudnn tensor handle within a
// scope.
class CudnnTensorDescriptor {
 public:
  CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
                        cudnnDataType_t elem_type)
      : handle_(CreateTensorDescriptor()) {
    switch (batch_descriptor.layout()) {
      case dnn::DataLayout::kBatchYXDepth:
      case dnn::DataLayout::kBatchDepthYX: {
        const int nd = batch_descriptor.ndims() + 2;
        // cuDNN requires the strides and dims to be ordered as BDYX.
        std::vector<int64> strides64 =
            batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
        std::vector<int64> dims64 =
            batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);

        // cuDNN requires arrays of ints.
        std::vector<int> strides(nd);
        std::vector<int> dims(nd);
        std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
                       &CheckedNarrowing<int64, int>);
        std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
                       &CheckedNarrowing<int64, int>);
        CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
                                                  dims.data(), strides.data()))
            << "batch_descriptor: " << batch_descriptor.ToString();
      } break;
      case dnn::DataLayout::kBatchDepthYX4: {
        CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
            handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
            batch_descriptor.count(), batch_descriptor.feature_map_count(),
            batch_descriptor.height(), batch_descriptor.width()))
            << "batch_descriptor: " << batch_descriptor.ToString();
      } break;
      default:
        LOG(FATAL) << "Unsupported tensor format "
                   << DataLayoutString(batch_descriptor.layout());
        break;
    }
  }

  cudnnTensorDescriptor_t handle() const { return handle_.get(); }

 private:
  TensorDescriptor handle_;

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
};
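
// Illustrative example: for a fully packed kBatchDepthYX (i.e. NCHW) batch of
// shape {N, C, H, W}, the cudnnSetTensorNdDescriptor call above receives
// dims = {N, C, H, W} and strides = {C*H*W, H*W, W, 1}.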

// Turns a FilterDescriptor structure into a cudnn filter handle within a
// scope.
class CudnnFilterDescriptor {
 public:
  CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
                        cudnnDataType_t elem_type)
      : handle_(CreateFilterDescriptor()) {
    // TODO(b/23032134): Even if the filter layout is not supported,
    // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
    // it does not take layout as an input. Maybe force cuDNN by giving wrong
    // inputs intentionally?
    cudnnTensorFormat_t format;
    switch (filter_descriptor.layout()) {
      case dnn::FilterLayout::kOutputInputYX:
        format = CUDNN_TENSOR_NCHW;
        break;
      case dnn::FilterLayout::kOutputYXInput:
        format = CUDNN_TENSOR_NHWC;
        break;
      case dnn::FilterLayout::kOutputInputYX4:
        format = CUDNN_TENSOR_NCHW_VECT_C;
        break;
      default:
        LOG(FATAL) << "Unsupported filter format "
                   << FilterLayoutString(filter_descriptor.layout());
        break;
    }

    std::vector<int> dims(2 + filter_descriptor.ndims());
    dims[0] = filter_descriptor.output_feature_map_count();
    dims[1] = filter_descriptor.input_feature_map_count();
    absl::Span<const int64> spatial_dims =
        filter_descriptor.input_filter_dims();
    std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);

    CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
                                              dims.size(), dims.data()));
  }

  cudnnFilterDescriptor_t handle() const { return handle_.get(); }

 private:
  FilterDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
};

// A helper function to decide whether to enable the TENSOR_OP_MATH math type
// for convolutions.
bool TensorOpMathEnabled() {
  static bool is_enabled = [] {
    bool is_disabled = false;
    TF_CHECK_OK(
        tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
                                       /*default_val=*/false, &is_disabled));
    return !is_disabled;
  }();
  return is_enabled;
}
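
// For example, running with TF_DISABLE_CUDNN_TENSOR_OP_MATH=1 in the
// environment makes this return false, in which case
// CudnnConvolutionDescriptor::set_use_tensor_op_math below never calls
// cudnnSetConvolutionMathType and the descriptor keeps cuDNN's default math
// mode.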

// A helper function to decide whether to enable the TENSOR_OP_MATH math type
// for RNNs.
bool RnnTensorOpMathEnabled() {
  static bool is_enabled = [] {
    bool is_disabled = false;
    TF_CHECK_OK(
        tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH",
                                       /*default_val=*/false, &is_disabled));
    return !is_disabled;
  }();
  return is_enabled;
}

// A helper function to decide whether to use
// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
// and CUDNN_DATA_HALF data types on GPUs with compute capability 6.0 or
// higher. The reason we set it to false by default is that this mode may use
// scaled atomic integer reduction, which may cause a numerical overflow for
// certain input data ranges.
// TODO(yangzihao): Use autotune to choose between this mode and
// CUDNN_BATCHNORM_SPATIAL mode.
bool BatchnormSpatialPersistentEnabled() {
  static bool is_enabled = [] {
    bool is_enabled = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
        "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
        /*default_val=*/false, &is_enabled));
    return is_enabled;
  }();
  return is_enabled;
}

// The following function allows deterministic ops to be implemented relatively
// quickly using environment variables. It is intended to be temporary. The
// longer-term intention is to enable deterministic ops via tf.config and
// appropriate plumbing. See the discussion on PR 34951 for more information:
// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
// This function and associated comment are replicated in the following three
// places:
//   1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
//   2. tensorflow/core/kernels/gpu_utils.cc
//   3. tensorflow/stream_executor/cuda/cuda_dnn.cc
// When implementing the plumbing, you should also search for the use of
// TF_DETERMINISTIC_OPS on its own.
// TODO(duncanriach): move to an API that uses tf.config and implement the
//                    first phase of plumbing.
bool RequireCudnnDeterminism() {
  static bool require_cudnn_determinism = [] {
    bool deterministic_ops = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
                                               /*default_val=*/false,
                                               &deterministic_ops));
    bool cudnn_deterministic = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
                                               /*default_val=*/false,
                                               &cudnn_deterministic));
    return deterministic_ops || cudnn_deterministic;
  }();
  return require_cudnn_determinism;
}
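
// For example, setting either TF_DETERMINISTIC_OPS=1 or
// TF_CUDNN_DETERMINISTIC=1 in the environment makes this return true, which
// in turn selects CUDNN_POOLING_MAX_DETERMINISTIC for max pooling in
// CudnnPoolingDescriptor below.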

std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
  int cc_major, cc_minor;
  stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                   &cc_minor);
  return std::make_tuple(cc_major, cc_minor);
}

// Turns a ConvolutionDescriptor structure into a cudnn convolution handle
// within a scope.
class CudnnConvolutionDescriptor {
 public:
  CudnnConvolutionDescriptor(
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      cudnnDataType_t data_type)
      : handle_(CreateConvolutionDescriptor()) {
    absl::Span<const int64> strides64 = convolution_descriptor.strides();
    absl::Span<const int64> padding64 = convolution_descriptor.padding();
    absl::Span<const int64> dilations64 = convolution_descriptor.dilations();
    CHECK_NE(convolution_descriptor.pad_alignment(),
             dnn::PadAlignment::kTensorFlowPadding)
        << "TensorFlow padding alignment is not supported.";

    // cuDNN requires arrays of ints.
    std::vector<int> strides(convolution_descriptor.ndims());
    std::vector<int> padding(convolution_descriptor.ndims());
    std::vector<int> dilations(convolution_descriptor.ndims());
    std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
                   &CheckedNarrowing<int64, int>);
    std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
                   &CheckedNarrowing<int64, int>);
    // TODO(yangzihao): Test with negative dilation to make sure that cudnn
    // doesn't crash.
    std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
                   &CheckedNarrowing<int64, int>);

    CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
        handle_.get(), convolution_descriptor.ndims(), padding.data(),
        strides.data(), dilations.data(),
        convolution_descriptor.convolution_not_crosscorr()
            ? CUDNN_CONVOLUTION
            : CUDNN_CROSS_CORRELATION,
        data_type));

    // NOTE(benbarsdell): This only applies if tensor op math is enabled
    //                      and algo selection is set to Default.
    this->set_use_tensor_op_math(true);

#if CUDNN_MAJOR >= 7
    VLOG(2) << "Requesting grouped convolution: "
            << convolution_descriptor.group_count();
    CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
        handle_.get(), convolution_descriptor.group_count()));
#else
    CHECK_EQ(convolution_descriptor.group_count(), 1)
        << "Requested grouped convolution for cuDNN version < 7";
#endif
  }

  void set_use_tensor_op_math(bool use_tensor_op_math) const {
#if CUDNN_VERSION >= 7000
    cudnnMathType_t math_type =
        (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
    if (TensorOpMathEnabled()) {
      CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
    }
#endif
  }

  cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }

 private:
  ConvolutionDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
};

// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
// within a scope.
class CudnnPoolingDescriptor {
 public:
  explicit CudnnPoolingDescriptor(
      const dnn::PoolingDescriptor& pooling_descriptor)
      : handle_(CreatePoolingDescriptor()) {
    absl::Span<const int64> strides64 = pooling_descriptor.strides();
    absl::Span<const int64> padding64 = pooling_descriptor.padding();
    absl::Span<const int64> shape64 = pooling_descriptor.window();

    const int nd = pooling_descriptor.ndims();
    std::vector<int> shape(nd);
    std::vector<int> padding(nd);
    std::vector<int> strides(nd);
    std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
                   &CheckedNarrowing<int64, int>);
    std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
                   &CheckedNarrowing<int64, int>);
    std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
                   &CheckedNarrowing<int64, int>);
    bool propagate_nans = pooling_descriptor.propagate_nans();
    const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
                                            ? CUDNN_POOLING_MAX_DETERMINISTIC
                                            : CUDNN_POOLING_MAX;
    CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
        handle_.get(),
        (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
             ? cudnn_max_pooling_mode
             : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
        propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
        shape.data(), padding.data(), strides.data()));
  }

  cudnnPoolingDescriptor_t handle() const { return handle_.get(); }

 private:
  PoolingDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
};

// Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
class CudnnNormalizeDescriptor {
 public:
  explicit CudnnNormalizeDescriptor(
      const dnn::NormalizeDescriptor& normalize_descriptor)
      : handle_(CreateLrnDescriptor()) {
    // The range specifies that the indices in the closed range
    // [i - range, i + range] should be included in the normalization for index
    // i. The lrnN value is the total number of elements in the range, so
    // lrnN = 2*range + 1.
    unsigned lrnN = 2 * normalize_descriptor.range() + 1;

    // Note that SE defines the normalization operation as
    //
    //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
    //
    // but cuDNN defines it as
    //
    //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
    //
    // i.e. there is a factor of n difference between the meaning of the alphas
    // in the two contexts. The cuDNN alpha is n times the SE alpha.
    double lrnAlpha = lrnN * normalize_descriptor.alpha();

    double lrnBeta = normalize_descriptor.beta();
    double lrnK = normalize_descriptor.bias();
    CHECK_CUDNN_OK(
        cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
  }

  cudnnLRNDescriptor_t handle() const { return handle_.get(); }

 private:
  LrnDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
};
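
// Worked example (illustrative): for range = 2, the window spans 5 elements,
// so lrnN = 5 and the alpha passed to cuDNN above is
// 5 * normalize_descriptor.alpha(), compensating for cuDNN's division by n.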
835 
836 // Turns a ActivationDescriptor structure into a cudnn activation
837 // descriptor handle within a scope.
838 class CudnnActivationDescriptor {
839  public:
CudnnActivationDescriptor(dnn::ActivationMode activation_mode,cudnnNanPropagation_t nan_propagation,double value_max)840   CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
841                             cudnnNanPropagation_t nan_propagation,
842                             double value_max)
843       : handle_(CreateActivationDescriptor()) {
844     double relu_ceiling = 0.0;
845     cudnnActivationMode_t mode;
846     switch (activation_mode) {
847 #if CUDNN_VERSION >= 7100
848       case dnn::ActivationMode::kNone:
849         mode = CUDNN_ACTIVATION_IDENTITY;
850         break;
851 #endif
852       case dnn::ActivationMode::kRelu6:
853         relu_ceiling = 6.0;
854         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
855         break;
856       case dnn::ActivationMode::kReluX:
857         relu_ceiling = value_max;
858         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
859         break;
860       case dnn::ActivationMode::kRelu:
861         mode = CUDNN_ACTIVATION_RELU;
862         break;
863       case dnn::ActivationMode::kSigmoid:
864         mode = CUDNN_ACTIVATION_SIGMOID;
865         break;
866       case dnn::ActivationMode::kTanh:
867         mode = CUDNN_ACTIVATION_TANH;
868         break;
869       default:
870         LOG(FATAL) << "unrecognized activation mode: "
871                    << static_cast<int>(activation_mode);
872     }
873 
874     CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
875                                                 nan_propagation, relu_ceiling));
876   }
877 
handle() const878   cudnnActivationDescriptor_t handle() const { return handle_.get(); }
879 
880  private:
881   ActivationDescriptor handle_;  // Owned.
882 
883   SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
884 };
885 
ToCudnnDataType(dnn::DataType data_type,dnn::DataLayout data_layout=dnn::DataLayout::kBatchDepthYX)886 cudnnDataType_t ToCudnnDataType(
887     dnn::DataType data_type,
888     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
889   switch (data_type) {
890     case dnn::DataType::kFloat:
891       return CUDNN_DATA_FLOAT;
892     case dnn::DataType::kDouble:
893       return CUDNN_DATA_DOUBLE;
894     case dnn::DataType::kHalf:
895       return CUDNN_DATA_HALF;
896     case dnn::DataType::kInt8:
897       return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
898                                                             : CUDNN_DATA_INT8;
899     case dnn::DataType::kInt32:
900       return CUDNN_DATA_INT32;
901     default:
902       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
903   }
904 }
905 
ToCudnnDataType(dnn::DataType data_type,dnn::FilterLayout filter_layout)906 cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
907                                 dnn::FilterLayout filter_layout) {
908   if (data_type == dnn::DataType::kInt8 &&
909       filter_layout == dnn::FilterLayout::kOutputInputYX4) {
910     return CUDNN_DATA_INT8x4;
911   }
912   return ToCudnnDataType(data_type);
913 }
914 
915 template <typename T>
GetCudnnDataType(dnn::DataLayout data_layout=dnn::DataLayout::kBatchDepthYX)916 cudnnDataType_t GetCudnnDataType(
917     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
918   return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
919 }
920 
ToCudnnRnnInputMode(dnn::RnnInputMode input_mode)921 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
922   switch (input_mode) {
923     case dnn::RnnInputMode::kRnnLinearSkip:
924     case dnn::RnnInputMode::kRnnSkipInput:
925       return static_cast<cudnnRNNInputMode_t>(input_mode);
926     default:
927       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
928   }
929 }
930 
ToCudnnRnnDirectionMode(dnn::RnnDirectionMode direction_mode)931 cudnnDirectionMode_t ToCudnnRnnDirectionMode(
932     dnn::RnnDirectionMode direction_mode) {
933   switch (direction_mode) {
934     case dnn::RnnDirectionMode::kRnnUnidirectional:
935     case dnn::RnnDirectionMode::kRnnBidirectional:
936       return static_cast<cudnnDirectionMode_t>(direction_mode);
937     default:
938       LOG(FATAL) << "Invalid RNN direction mode: "
939                  << static_cast<int>(direction_mode);
940   }
941 }
942 
ToCudnnRnnMode(dnn::RnnMode rnn_mode)943 cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
944   switch (rnn_mode) {
945     case dnn::RnnMode::kRnnRelu:
946     case dnn::RnnMode::kRnnTanh:
947     case dnn::RnnMode::kRnnLstm:
948     case dnn::RnnMode::kRnnGru:
949       return static_cast<cudnnRNNMode_t>(rnn_mode);
950     default:
951       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
952   }
953 }
954 
CudnnDataTypeToByteSize(cudnnDataType_t data_type)955 int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
956   switch (data_type) {
957     case CUDNN_DATA_FLOAT:
958       return sizeof(float);
959     case CUDNN_DATA_DOUBLE:
960       return sizeof(double);
961     case CUDNN_DATA_HALF:
962       return sizeof(Eigen::half);
963     default:
964       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
965   }
966 }
967 
968 class CudnnDropoutDescriptor {
CudnnDropoutDescriptor(DropoutDescriptor handle)969   explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
970       : handle_(std::move(handle)) {}
971 
972  public:
973   CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
974 
Create(const CudnnHandle & cudnn,float dropout,uint64 seed,ScratchAllocator * state_allocator)975   static port::StatusOr<CudnnDropoutDescriptor> Create(
976       const CudnnHandle& cudnn, float dropout, uint64 seed,
977       ScratchAllocator* state_allocator) {
978     DropoutDescriptor handle = CreateDropoutDescriptor();
979 
980     if (dropout == 0.0f) {
981       // Return 'empty' dropout descriptor.
982       return CudnnDropoutDescriptor(std::move(handle));
983     }
984 
985     DeviceMemory<uint8> state_memory;
986     if (state_allocator) {
987       size_t state_sizes_in_bytes = 0;
988       RETURN_IF_CUDNN_ERROR(
989           cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
990       SE_ASSIGN_OR_RETURN(state_memory,
991                           state_allocator->AllocateBytes(state_sizes_in_bytes));
992     }
993     RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
994         handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
995         state_memory.size(), seed));
996 
997     return CudnnDropoutDescriptor(std::move(handle));
998   }
999 
handle() const1000   cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
1001 
1002  private:
1003   DropoutDescriptor handle_;  // Owned.
1004   SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
1005 };
1006 
1007 class CudnnRnnParamsDescriptor {
1008   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1009 
CudnnRnnParamsDescriptor(FilterDescriptor handle,int64 params_size_in_bytes,ParamsRegions weights,ParamsRegions biases)1010   CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes,
1011                            ParamsRegions weights, ParamsRegions biases)
1012       : handle_(std::move(handle)),
1013         params_size_in_bytes_(params_size_in_bytes),
1014         weights_(std::move(weights)),
1015         biases_(std::move(biases)) {}
1016 
1017  public:
1018   CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
1019 
1020   static port::StatusOr<CudnnRnnParamsDescriptor> Create(
1021       const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1022       cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1023       cudnnDirectionMode_t direction_mode, int num_layers);
1024 
handle() const1025   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
params_size_in_bytes() const1026   int64 params_size_in_bytes() const { return params_size_in_bytes_; }
params_weights() const1027   ParamsRegions params_weights() const { return weights_; }
params_biases() const1028   ParamsRegions params_biases() const { return biases_; }
1029 
1030  private:
1031   FilterDescriptor handle_;
1032   int64 params_size_in_bytes_;
1033   ParamsRegions weights_;
1034   ParamsRegions biases_;
1035   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
1036 };
1037 
1038 }  // namespace
1039 
1040 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
CudnnRnnDescriptor(const CudnnHandle & cudnn,gpu::RnnDescriptor rnn_desc,PersistentRnnPlan rnn_plan,int num_layers,int hidden_size,int input_size,int cell_size,int batch_size,cudnnRNNInputMode_t input_mode,cudnnDirectionMode_t direction_mode,cudnnRNNMode_t rnn_mode,cudnnDataType_t data_type,cudnnDataType_t compute_type,const dnn::AlgorithmConfig & algorithm_config,CudnnDropoutDescriptor dropout_desc,CudnnRnnParamsDescriptor params_desc)1041   CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1042                      PersistentRnnPlan rnn_plan, int num_layers,
1043                      int hidden_size, int input_size, int cell_size,
1044                      int batch_size, cudnnRNNInputMode_t input_mode,
1045                      cudnnDirectionMode_t direction_mode,
1046                      cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1047                      cudnnDataType_t compute_type,
1048                      const dnn::AlgorithmConfig& algorithm_config,
1049                      CudnnDropoutDescriptor dropout_desc,
1050                      CudnnRnnParamsDescriptor params_desc)
1051       : rnn_desc_(std::move(rnn_desc)),
1052         rnn_plan_(std::move(rnn_plan)),
1053         num_layers_(num_layers),
1054         hidden_size_(hidden_size),
1055         input_size_(input_size),
1056         cell_size_(cell_size),
1057         batch_size_(batch_size),
1058         rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1059         input_mode_(input_mode),
1060         direction_mode_(direction_mode),
1061         rnn_mode_(rnn_mode),
1062         data_type_(data_type),
1063         compute_type_(compute_type),
1064         algorithm_config_(algorithm_config),
1065         dropout_desc_(std::move(dropout_desc)),
1066         params_desc_(std::move(params_desc)) {}
1067 
1068  public:
1069   CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1070 
Create(const CudnnHandle & cudnn,int num_layers,int hidden_size,int input_size,int cell_size,int batch_size,cudnnRNNInputMode_t input_mode,cudnnDirectionMode_t direction_mode,cudnnRNNMode_t rnn_mode,cudnnDataType_t data_type,cudnnDataType_t compute_type,const dnn::AlgorithmConfig & algorithm_config,float dropout,uint64 seed,ScratchAllocator * state_allocator,bool use_padded_io)1071   static port::StatusOr<CudnnRnnDescriptor> Create(
1072       const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1073       int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
1074       cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1075       cudnnDataType_t data_type, cudnnDataType_t compute_type,
1076       const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1077       ScratchAllocator* state_allocator, bool use_padded_io) {
1078     SE_ASSIGN_OR_RETURN(
1079         CudnnDropoutDescriptor dropout_desc,
1080         CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1081 
1082     gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1083     cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1084 
1085     // TODO: allow the user to choose an algorithm.
1086     int unified_size = hidden_size;
1087     bool use_projection = cell_size != 0 && hidden_size < cell_size;
1088     if (use_projection) {
1089       unified_size = cell_size;
1090     }
1091     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1092         cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1093         /*hiddenSize=*/unified_size, /*numLayers=*/num_layers,
1094         /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
1095         /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1096         /*dataType=*/compute_type));
1097     if (use_projection) {
1098 #if CUDNN_VERSION >= 7101
1099       RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
1100           cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1101           /*recProjSize=*/hidden_size, /*outProjSize=*/0));
1102 #else
1103       return port::Status(port::error::INVALID_ARGUMENT,
1104                           "No supported cudnnSetRNNProjectionLayers when "
1105                           "CUDNN_VERSION < 7.1.1");
1106 #endif
1107     }
1108 
1109     // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
1110     // But in the future if these APIs are used to process full length arrays,
1111     // we need to distinguish when to set it.
1112 #if CUDNN_VERSION >= 7201
1113     if (use_padded_io) {
1114       RETURN_IF_CUDNN_ERROR(
1115           cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1116     }
1117 #endif
1118 
1119     port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1120     PersistentRnnPlan rnn_plan;
1121     if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1122       CHECK_GE(batch_size, 0);
1123       rnn_plan_wrapper =
1124           CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1125       if (!rnn_plan_wrapper.ok()) {
1126         return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1127       } else {
1128         rnn_plan = rnn_plan_wrapper.ConsumeValueOrDie();
1129         RETURN_IF_CUDNN_ERROR(
1130             cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1131       }
1132     }
1133 
1134     // Create the params handle.
1135     SE_ASSIGN_OR_RETURN(auto params_desc,
1136                         CudnnRnnParamsDescriptor::Create(
1137                             cudnn, input_size, data_type, rnn_desc.get(),
1138                             rnn_mode, direction_mode, num_layers));
1139 
1140 #if CUDNN_VERSION >= 7000
1141     // Require explicit algorithm config to enable tensor cores. Some configs
1142     // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1143     // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1144     // We can only reasonably expect the user to handle the subsequent failure
1145     // in profile mode, which is run with algorithms returned from
1146     // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1147     // use tensor ops). CuDNN 7.2.1 fixed this issue
1148     if (RnnTensorOpMathEnabled()) {
1149       cudnnMathType_t math_type;
1150       if (algorithm_config.algorithm().has_value()) {
1151         math_type = algorithm_config.algorithm()->tensor_ops_enabled()
1152                         ? CUDNN_TENSOR_OP_MATH
1153                         : CUDNN_DEFAULT_MATH;
1154       } else {
1155 #if CUDNN_VERSION >= 7201
1156         math_type = CUDNN_TENSOR_OP_MATH;
1157 #else
1158         math_type = CUDNN_DEFAULT_MATH;
1159 #endif  // CUDNN_VERSION >= 7201
1160       }
1161       CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1162     }
1163 #endif  // CUDNN_VERSION >= 7000
1164 
1165     return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1166                               num_layers, hidden_size, input_size, cell_size,
1167                               batch_size, input_mode, direction_mode, rnn_mode,
1168                               data_type, compute_type, algorithm_config,
1169                               std::move(dropout_desc), std::move(params_desc));
1170   }
1171 
handle() const1172   cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
num_layers() const1173   int num_layers() const { return num_layers_; }
hidden_size() const1174   int hidden_size() const { return hidden_size_; }
input_size() const1175   int input_size() const { return input_size_; }
cell_size() const1176   int cell_size() const { return cell_size_; }
batch_size() const1177   int batch_size() const { return batch_size_; }
input_mode() const1178   cudnnRNNInputMode_t input_mode() const { return input_mode_; }
direction_mode() const1179   cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
rnn_mode() const1180   cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
data_type() const1181   cudnnDataType_t data_type() const { return data_type_; }
compute_type() const1182   cudnnDataType_t compute_type() const { return compute_type_; }
algorithm_config() const1183   const dnn::AlgorithmConfig& algorithm_config() const {
1184     return algorithm_config_;
1185   }
ParamsSizeInBytes() const1186   int64 ParamsSizeInBytes() const override {
1187     return params_desc_.params_size_in_bytes();
1188   }
params_handle() const1189   cudnnFilterDescriptor_t params_handle() const {
1190     return params_desc_.handle();
1191   }
ParamsWeightRegions() const1192   ParamsRegions ParamsWeightRegions() const override {
1193     return params_desc_.params_weights();
1194   }
ParamsBiasRegions() const1195   ParamsRegions ParamsBiasRegions() const override {
1196     return params_desc_.params_biases();
1197   }
1198 
1199  private:
1200   gpu::RnnDescriptor rnn_desc_;
1201   PersistentRnnPlan rnn_plan_;
1202   int num_layers_;
1203   int hidden_size_;
1204   int input_size_;
1205   // cell_size_ is the size of cell state, which will be different from
1206   // hidden_size_ if the projection is used.
1207   int cell_size_;
1208   // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1209   // algorithm.
1210   int batch_size_;
1211   cudnnRNNAlgo_t rnn_algo_;
1212   cudnnRNNInputMode_t input_mode_;
1213   cudnnDirectionMode_t direction_mode_;
1214   cudnnRNNMode_t rnn_mode_;
1215   cudnnDataType_t data_type_;
1216   cudnnDataType_t compute_type_;
1217   dnn::AlgorithmConfig algorithm_config_;
1218   CudnnDropoutDescriptor dropout_desc_;
1219   CudnnRnnParamsDescriptor params_desc_;
1220   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1221 };
1222 
1223 #if CUDNN_VERSION >= 7603
1224 class CudnnCtcLossDescriptor {
1225  public:
CudnnCtcLossDescriptor(cudnnDataType_t data_type)1226   explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
1227       : handle_(CreateCtcLossDescriptor()) {
1228     CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
1229         /*ctcLossDesc=*/handle_.get(),
1230         /*compType=*/data_type,
1231         /*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
1232         /*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
1233   }
1234 
1235   cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
1236 
1237  private:
1238   CtcLossDescriptor handle_;  // Owned
1239 
1240   SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
1241 };
1242 #else
1243 // Placeholder used when the cuDNN version does not provide the CTC loss API.
1244 class CudnnCtcLossDescriptor {
1245  public:
1246   CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
1247 };
1248 #endif
1249 
1250 namespace {
1251 
1252 // Checks whether the LSTM projection is in use. If so, the additional weight
1253 // matrix (the projection matrix) is queried from cuDNN and appended to
1254 // 'weights'. Otherwise, nothing is done.
1255 port::Status CheckAndFetchProjectionWeights(
1256     const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
1257     const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
1258     const FilterDescriptor& region_desc_handle,
1259     dnn::RnnDescriptor::ParamsRegions* weights) {
1260 #if CUDNN_VERSION >= 7101
1261   int hidden_size_v;
1262   int num_layers_v;
1263   cudnnDropoutDescriptor_t dropout_desc;
1264   cudnnRNNInputMode_t input_mode;
1265   cudnnDirectionMode_t direction;
1266   cudnnRNNMode_t mode;
1267   cudnnRNNAlgo_t algo;
1268   cudnnDataType_t data_type;
1269   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
1270       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1271       /*hiddenSize=*/&hidden_size_v,
1272       /*numLayers=*/&num_layers_v,
1273       /*dropoutDesc=*/&dropout_desc,
1274       /*inputMode=*/&input_mode,
1275       /*direction=*/&direction,
1276       /*mode=*/&mode,
1277       /*algo=*/&algo,
1278       /*dataType=*/&data_type));
1279   int rec_proj_size_v;
1280   int out_proj_size_v;
1281   RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
1282       /*handle=*/cudnn.handle(),
1283       /*rnnDesc=*/rnn_desc,
1284       /*recProjSize*/ &rec_proj_size_v,
1285       /*outProjSize*/ &out_proj_size_v));
1286   if (rec_proj_size_v != hidden_size_v) {
1287     void* offset = nullptr;
1288     int region_id = 8;
1289     RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
1290         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1291         /*layer=*/layer, /*xDesc=*/input_desc.get(),
1292         /*wDesc=*/filter_desc.get(),
1293         /*w=*/nullptr, /*linLayerID=*/region_id,
1294         /*linLayerMatDesc=*/region_desc_handle.get(),
1295         /*linLayerMat or linLayerBias=*/&offset));
1296     int dims[] = {1, 1, 1};
1297     cudnnDataType_t data_type;
1298     cudnnTensorFormat_t tensor_format;
1299     int n_dims;
1300     RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1301         /*filterDesc=*/region_desc_handle.get(),
1302         /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1303         /*dataType=*/&data_type, /*format=*/&tensor_format,
1304         /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1305     int64 size =
1306         dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1307     dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset),
1308                                                size};
1309     weights->push_back(region);
1310   }
1311 #endif  // CUDNN_VERSION >= 7101
1312   return port::Status::OK();
1313 }
1314 
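// Builds the params (filter) descriptor for an RNN: queries cuDNN for the
// total parameter size, wraps it in a 1-D filter descriptor, and records the
// offset and size of every per-layer weight and bias region (plus any
// projection weights) inside the packed params buffer.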
1315 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1316     const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1317     cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1318     cudnnDirectionMode_t direction_mode, int num_layers) {
1319   // Query the params size.
1320   TensorDescriptor input_desc = CreateTensorDescriptor();
1321   int tensor_dims[] = {1, input_size, 1};
1322   int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1323   RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1324       /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1325       /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1326       /*dimA=*/tensor_dims,
1327       /*strideA=*/strides));
1328 
1329   size_t params_size = 0;
1330   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1331       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1332       /*xDesc=*/input_desc.get(), /*sizeInBytes=*/&params_size,
1333       /*dataType=*/data_type));
1334   int64 params_size_in_bytes = static_cast<int64>(params_size);
1335 
1336   FilterDescriptor filter_desc = CreateFilterDescriptor();
1337   int filter_dims[] = {static_cast<int>(params_size_in_bytes), 1, 1};
1338   RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1339       /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1340       /*format=*/CUDNN_TENSOR_NCHW,
1341       /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1342       /*filterDimA=*/filter_dims));
1343 
1344   // Enumerate the weight and bias regions laid out in the params buffer.
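  // The region counts below follow the cuDNN linear-layer numbering: each gate
  // contributes one input matrix and one recurrent matrix, so an LSTM (4
  // gates) exposes 8 regions per layer, a GRU (3 gates) 6, and a vanilla RNN 2.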
1345   int region_count_per_layer = [&] {
1346     switch (rnn_mode) {
1347       case CUDNN_RNN_RELU:
1348       case CUDNN_RNN_TANH:
1349         return 2;
1350       case CUDNN_LSTM:
1351         return 8;
1352       case CUDNN_GRU:
1353         return 6;
1354       default:
1355         LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1356         return 0;
1357     }
1358   }();
1359 
1360   FilterDescriptor region_desc_handle = CreateFilterDescriptor();
1361   const int layer_count =
1362       direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1363 
1364   ParamsRegions weights;
1365   ParamsRegions biases;
1366 
1367   for (int layer = 0; layer < layer_count; layer++) {
1368     for (int region = 0; region < region_count_per_layer; region++) {
1369       for (int type = 0; type < 2; type++) {
1370         void* offset = nullptr;
1371         RETURN_IF_CUDNN_ERROR(
1372             type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1373                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1374                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1375                             /*wDesc=*/filter_desc.get(),
1376                             /*w=*/nullptr, /*linLayerID=*/region,
1377                             /*linLayerMatDesc=*/region_desc_handle.get(),
1378                             /*linLayerMat or linLayerBias=*/&offset)
1379                       : cudnnGetRNNLinLayerBiasParams(
1380                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1381                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1382                             /*wDesc=*/filter_desc.get(),
1383                             /*w=*/nullptr, /*linLayerID=*/region,
1384                             /*linLayerMatDesc=*/region_desc_handle.get(),
1385                             /*linLayerMat or linLayerBias=*/&offset));
1386         int dims[] = {1, 1, 1};
1387         cudnnDataType_t data_type;
1388         cudnnTensorFormat_t tensor_format;
1389         int n_dims;
1390         RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1391             /*filterDesc=*/region_desc_handle.get(),
1392             /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1393             /*dataType=*/&data_type, /*format=*/&tensor_format,
1394             /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1395         int64 size =
1396             dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1397         dnn::RnnDescriptor::ParamsRegion region = {
1398             reinterpret_cast<int64>(offset), size};
1399         (type == 0 ? weights : biases).push_back(region);
1400       }
1401     }
1402     TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
1403         cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
1404         &weights));
1405   }
1406 
1407   return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1408                                   weights, biases);
1409 }
1410 
1411 }  // namespace
1412 
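// Describes the sequence (x/y) tensors of an RNN. The same per-timestep tensor
// descriptor is replicated max_seq_length times for the classic cuDNN RNN API;
// when variable sequence lengths are requested and cuDNN >= 7.2.1 is
// available, a cudnnRNNDataDescriptor_t is also set up and
// is_var_seq_lengths() returns true.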
1413 class CudnnRnnSequenceTensorDescriptor
1414     : public dnn::RnnSequenceTensorDescriptor {
1415   CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1416                                    int batch_size, int data_size,
1417                                    cudnnDataType_t data_type,
1418 #if CUDNN_VERSION >= 7201
1419                                    RNNDataDescriptor data_handle,
1420 #endif
1421                                    TensorDescriptor handle)
1422       : max_seq_length_(max_seq_length),
1423         batch_size_(batch_size),
1424         data_size_(data_size),
1425         data_type_(data_type),
1426         handle_(std::move(handle)),
1427 #if CUDNN_VERSION >= 7201
1428         rnn_data_handle_(std::move(data_handle)),
1429 #endif
1430         handles_(max_seq_length, handle_.get()) {
1431   }
1432 
1433  public:
1434   CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1435       default;
1436 
1437   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1438       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1439       cudnnDataType_t data_type) {
1440     CHECK_GT(max_seq_length, 0);
1441     int dims[] = {batch_size, data_size, 1};
1442     int strides[] = {dims[1] * dims[2], dims[2], 1};
1443     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1444     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1445         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1446         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1447         /*strideA=*/strides));
1448     return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1449                                             data_size, data_type,
1450 #if CUDNN_VERSION >= 7201
1451                                             nullptr,
1452 #endif
1453                                             std::move(tensor_desc));
1454   }
1455 
1456   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1457       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1458       const absl::Span<const int>& seq_lengths, bool time_major,
1459       cudnnDataType_t data_type) {
1460 #if CUDNN_VERSION >= 7201
1461     CHECK_GT(max_seq_length, 0);
1462     int dims[] = {batch_size, data_size, 1};
1463     int strides[] = {dims[1] * dims[2], dims[2], 1};
1464     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1465     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1466         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1467         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1468         /*strideA=*/strides));
1469     const int* seq_lengths_array = seq_lengths.data();
1470     RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1471     float padding_fill = 0.0f;
1472     cudnnRNNDataLayout_t layout;
1473     if (time_major) {
1474       layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1475     } else {
1476       layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1477     }
1478     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1479         /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1480         /*layout=*/layout,
1481         /*maxSeqLength=*/max_seq_length,
1482         /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1483         /*seqLengthArray=*/seq_lengths_array,
1484         /*paddingFill*/ (void*)&padding_fill));
1485     return CudnnRnnSequenceTensorDescriptor(
1486         parent, max_seq_length, batch_size, data_size, data_type,
1487         std::move(data_desc), std::move(tensor_desc));
1488 #else
1489     return port::Status(port::error::INVALID_ARGUMENT,
1490                         "cudnnSetRNNDataDescriptor is not supported when "
1491                         "CUDNN_VERSION < 7.2.1");
1492 #endif
1493   }
1494 
1495   const cudnnTensorDescriptor_t* handles() const { return handles_.data(); }
1496 #if CUDNN_VERSION >= 7201
1497   const cudnnRNNDataDescriptor_t data_handle() const {
1498     return rnn_data_handle_.get();
1499   }
1500 #endif
1501 
1502   int max_seq_length() const { return max_seq_length_; }
1503   int batch_size() const { return batch_size_; }
1504   int data_size() const { return data_size_; }
1505   bool is_var_seq_lengths() const {
1506 #if CUDNN_VERSION >= 7201
1507     return rnn_data_handle_ != nullptr;
1508 #else
1509     return false;
1510 #endif
1511   }
1512 
1513  private:
1514   int max_seq_length_;
1515   int batch_size_;
1516   int data_size_;
1517   cudnnDataType_t data_type_;
1518   TensorDescriptor handle_;
1519 #if CUDNN_VERSION >= 7201
1520   RNNDataDescriptor rnn_data_handle_;
1521 #endif
1522   std::vector<cudnnTensorDescriptor_t> handles_;  // Copies of handle_.
1523   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1524 };
1525 
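// Describes the hidden/cell state (h/c) tensors of an RNN as a 3-D tensor of
// shape {num_layers, batch_size, data_size} with fully packed strides.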
1526 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1527  public:
1528   CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1529                                 int batch_size, int data_size,
1530                                 cudnnDataType_t data_type)
1531       : handle_(CreateTensorDescriptor()),
1532         num_layers_(num_layers),
1533         batch_size_(batch_size),
1534         data_size_(data_size),
1535         data_type_(data_type) {
1536     int dims[] = {num_layers, batch_size, data_size};
1537     int strides[] = {dims[1] * dims[2], dims[2], 1};
1538     CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1539         /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1540         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1541         /*strideA=*/strides));
1542   }
1543 
1544   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1545 
1546   int num_layers() const { return num_layers_; }
1547   int batch_size() const { return batch_size_; }
1548   int data_size() const { return data_size_; }
1549 
1550  private:
1551   TensorDescriptor handle_;
1552   int num_layers_;
1553   int batch_size_;
1554   int data_size_;
1555   cudnnDataType_t data_type_;
1556   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1557 };
1558 
1559 namespace {
1560 
1561 struct RnnModelDims {
1562   int num_layers = 0;
1563   int batch_size = 0;
1564   int max_seq_length = 0;
1565   int hidden_size = 0;
1566   int input_size = 0;
1567   int cell_size = 0;
1568   int dir_count = 0;
1569 };
1570 
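// Collects the model dimensions implied by the RNN and tensor descriptors and
// checks that the hidden/cell state and output descriptors are mutually
// consistent (including the relaxed check that allows a projected LSTM, whose
// cell size may exceed its hidden size).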
1571 template <class T>
1572 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1573     const CudnnRnnDescriptor& rnn_desc,
1574     const CudnnRnnSequenceTensorDescriptor& input_desc,
1575     const DeviceMemory<T>& input_data,
1576     const CudnnRnnStateTensorDescriptor& input_h_desc,
1577     const DeviceMemory<T>& input_h_data,
1578     const CudnnRnnStateTensorDescriptor& input_c_desc,
1579     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1580     const CudnnRnnSequenceTensorDescriptor& output_desc,
1581     const DeviceMemory<T>& output_data,
1582     const CudnnRnnStateTensorDescriptor& output_h_desc,
1583     const DeviceMemory<T>& output_h_data,
1584     const CudnnRnnStateTensorDescriptor& output_c_desc,
1585     const DeviceMemory<T>& output_c_data) {
1586   // extract model parameters
1587   RnnModelDims model_dims;
1588   model_dims.num_layers = rnn_desc.num_layers();
1589   model_dims.batch_size = input_desc.batch_size();
1590   model_dims.max_seq_length = input_desc.max_seq_length();
1591   model_dims.hidden_size = rnn_desc.hidden_size();
1592   model_dims.input_size = input_desc.data_size();
1593   model_dims.cell_size = rnn_desc.cell_size();
1594   model_dims.dir_count =
1595       (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1596 
1597   // check parameters
1598   if (!(input_h_desc.num_layers() ==
1599             model_dims.num_layers * model_dims.dir_count &&
1600         input_h_desc.batch_size() == model_dims.batch_size &&
1601         input_h_desc.data_size() == model_dims.hidden_size)) {
1602     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1603   }
1604   // The LSTM projection will be used if input_h_desc.data_size() <
1605   // input_c_desc.data_size()
1606   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1607         input_h_desc.batch_size() == input_c_desc.batch_size() &&
1608         input_h_desc.data_size() <= input_c_desc.data_size())) {
1609     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1610   }
1611   if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1612         output_desc.batch_size() == model_dims.batch_size &&
1613         output_desc.data_size() ==
1614             model_dims.hidden_size * model_dims.dir_count)) {
1615     return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1616   }
1617   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1618         input_h_desc.batch_size() == output_h_desc.batch_size() &&
1619         input_h_desc.data_size() == output_h_desc.data_size())) {
1620     return port::Status(port::error::INVALID_ARGUMENT,
1621                         "Invalid output_h shape");
1622   }
1623   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1624         input_h_desc.batch_size() == output_c_desc.batch_size() &&
1625         input_h_desc.data_size() <= output_c_desc.data_size())) {
1626     return port::Status(port::error::INVALID_ARGUMENT,
1627                         "Invalid output_c shape");
1628   }
1629 
1630   return model_dims;
1631 }
1632 
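// Verifies that the parameter size cuDNN reports for this RNN and input
// descriptor matches the size recorded in the CudnnRnnDescriptor.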
1633 port::Status CheckRNNParameterSize(
1634     const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1635     const CudnnRnnSequenceTensorDescriptor& input_desc) {
1636   size_t params_size_in_bytes = 0;
1637   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1638       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1639       /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
1640       /*dataType=*/rnn_desc.data_type()));
1641   if (static_cast<int64>(params_size_in_bytes) !=
1642       rnn_desc.ParamsSizeInBytes()) {
1643     return port::Status(port::error::INVALID_ARGUMENT,
1644                         "Mismatching RNN parameter size");
1645   }
1646   return port::Status::OK();
1647 }
1648 
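// Queries the workspace size required by the RNN call and allocates it from
// 'workspace_allocator'; returns an empty DeviceMemory when no workspace is
// needed.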
1649 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1650     Stream* stream, const CudnnHandle& cudnn,
1651     const CudnnRnnDescriptor& rnn_desc,
1652     const CudnnRnnSequenceTensorDescriptor& input_desc,
1653     ScratchAllocator* workspace_allocator) {
1654   // Query the workspace size.
1655   size_t workspace_size_in_bytes = 0;
1656   RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1657       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1658       /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1659       /*sizeInBytes=*/&workspace_size_in_bytes));
1660   // Allocate the workspace.
1661   if (workspace_size_in_bytes == 0) {
1662     return DeviceMemory<uint8>();
1663   }
1664   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1665 }
1666 
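// Workspace helpers for the fused batch-normalization Ex API (cuDNN >= 7.4.2):
// they query the scratch size needed by the forward-training and backward
// calls, respectively, and allocate it from the given scratch allocator.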
1667 #if CUDNN_VERSION >= 7402
1668 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
1669     Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1670     const cudnnBatchNormOps_t& bn_ops,
1671     const cudnnActivationDescriptor_t& activation_desc,
1672     const CudnnTensorDescriptor& x_descriptor,
1673     const CudnnTensorDescriptor& scale_offset_descriptor,
1674     ScratchAllocator* workspace_allocator) {
1675   // Query the workspace size.
1676   size_t workspace_size_in_bytes = 0;
1677   RETURN_IF_CUDNN_ERROR(
1678       cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
1679           /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1680           /*xDesc=*/x_descriptor.handle(), /*zDesc=*/x_descriptor.handle(),
1681           /*yDesc=*/x_descriptor.handle(),
1682           /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
1683           /*activationDesc=*/activation_desc,
1684           /*sizeInBytes=*/&workspace_size_in_bytes));
1685   // Allocate the workspace.
1686   if (workspace_size_in_bytes == 0) {
1687     return DeviceMemory<uint8>();
1688   }
1689   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1690 }
1691 
1692 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
1693     Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1694     const cudnnBatchNormOps_t& bn_ops,
1695     const CudnnTensorDescriptor& x_descriptor,
1696     const CudnnTensorDescriptor& scale_offset_descriptor,
1697     ScratchAllocator* workspace_allocator) {
1698   // Query the workspace size.
1699   size_t workspace_size_in_bytes = 0;
1700   RETURN_IF_CUDNN_ERROR(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
1701       /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1702       /*xDesc=*/x_descriptor.handle(),
1703       /*yDesc=*/x_descriptor.handle(),
1704       /*dyDesc=*/x_descriptor.handle(),
1705       /*dzDesc=*/nullptr,
1706       /*dxDesc=*/x_descriptor.handle(),
1707       /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
1708       /*activationDesc=*/nullptr, /*sizeInBytes=*/&workspace_size_in_bytes));
1709   // Allocate the workspace.
1710   if (workspace_size_in_bytes == 0) {
1711     return DeviceMemory<uint8>();
1712   }
1713   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1714 }
1715 
1716 #endif
1717 
1718 }  // namespace
1719 
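// Shared implementation behind the typed DoRnnForward overloads. It validates
// the descriptor shapes, sizes the workspace (and, when training, the reserve
// space), optionally wraps the call in a GPU timer for profiling, and then
// dispatches to either the classic RNN forward entry points or the *Ex
// variants used for variable sequence lengths (cuDNN >= 7.2.1).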
1720 template <class T>
1721 port::Status CudnnSupport::DoRnnForwardImpl(
1722     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1723     const CudnnRnnSequenceTensorDescriptor& input_desc,
1724     const DeviceMemory<T>& input_data,
1725     const CudnnRnnStateTensorDescriptor& input_h_desc,
1726     const DeviceMemory<T>& input_h_data,
1727     const CudnnRnnStateTensorDescriptor& input_c_desc,
1728     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1729     const CudnnRnnSequenceTensorDescriptor& output_desc,
1730     DeviceMemory<T>* output_data,
1731     const CudnnRnnStateTensorDescriptor& output_h_desc,
1732     DeviceMemory<T>* output_h_data,
1733     const CudnnRnnStateTensorDescriptor& output_c_desc,
1734     DeviceMemory<T>* output_c_data, bool is_training,
1735     ScratchAllocator* reserve_space_allocator,
1736     ScratchAllocator* workspace_allocator,
1737     dnn::ProfileResult* output_profile_result) {
1738   SE_ASSIGN_OR_RETURN(
1739       RnnModelDims model_dims,
1740       ExtractAndCheckRnnForward(
1741           rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1742           input_c_desc, input_c_data, params, output_desc, *output_data,
1743           output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1744 
1745   auto cudnn = cudnn_->GetHandle(parent_, stream);
1746 
1747   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1748   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1749                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1750                                          workspace_allocator));
1751 
1752   // Query the reserve space size and allocate it. The reserve space is only
1753   // needed for training, where it holds intermediates for the backward pass.
1754   DeviceMemory<uint8> reserve_space;
1755   if (is_training) {
1756     size_t reserve_space_size_in_bytes = 0;
1757     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
1758         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1759         /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
1760         /*sizeInBytes=*/&reserve_space_size_in_bytes));
1761 
1762     if (reserve_space_size_in_bytes > 0) {
1763       SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1764                                              reserve_space_size_in_bytes));
1765     }
1766   }
1767 
1768   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1769   const bool is_profiling = output_profile_result != nullptr;
1770   if (is_profiling) {
1771     timer.reset(new GpuTimer(parent_));
1772     // The start and stop of the timer should be as close to the cuDNN call
1773     // as possible. Other threads may still enqueue work onto this stream, so
1774     // a single measurement can be noisy and multiple runs may be needed.
1775     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1776       return port::Status(port::error::INTERNAL, "Failed to start timer");
1777     }
1778   }
1779 
1780   if (!is_training) {
1781     if (input_desc.is_var_seq_lengths()) {
1782 #if CUDNN_VERSION >= 7201
1783       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
1784           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1785           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1786           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1787           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1788           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1789           /*yDesc=*/output_desc.data_handle(),
1790           /*y=*/output_data->opaque(),
1791           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1792           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1793           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1794           nullptr,
1795           /*workspace=*/workspace.opaque(),
1796           /*workSpaceSizeInBytes=*/workspace.size()));
1797 #else
1798       return port::Status(port::error::INVALID_ARGUMENT,
1799                           "cudnnRNNForwardInferenceEx is not supported when "
1800                           "CUDNN_VERSION < 7.2.1");
1801 #endif
1802     } else {
1803       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
1804           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1805           /*seqLength=*/model_dims.max_seq_length,
1806           /*xDesc=*/input_desc.handles(),
1807           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1808           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1809           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1810           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1811           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1812           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1813           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1814           /*workSpaceSizeInBytes=*/workspace.size()));
1815     }
1816   } else {
1817     if (input_desc.is_var_seq_lengths()) {
1818 #if CUDNN_VERSION >= 7201
1819       // cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED);
1820       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
1821           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1822           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1823           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1824           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1825           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1826           /*yDesc=*/output_desc.data_handle(),
1827           /*y=*/output_data->opaque(),
1828           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1829           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1830           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1831           nullptr,
1832           /*workspace=*/workspace.opaque(),
1833           /*workSpaceSizeInBytes=*/workspace.size(),
1834           /*reserveSpace=*/reserve_space.opaque(),
1835           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1836 #else
1837       return port::Status(port::error::INVALID_ARGUMENT,
1838                           "cudnnRNNForwardTrainingEx is not supported when "
1839                           "CUDNN_VERSION < 7.2.1");
1840 #endif
1841     } else {
1842       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
1843           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1844           /*seqLength=*/model_dims.max_seq_length,
1845           /*xDesc=*/input_desc.handles(),
1846           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1847           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1848           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1849           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1850           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1851           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1852           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1853           /*workSpaceSizeInBytes=*/workspace.size(),
1854           /*reserveSpace=*/reserve_space.opaque(),
1855           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1856     }
1857   }
1858 
1859   if (is_profiling) {
1860     if (!timer->Stop(AsGpuStream(stream))) {
1861       return port::Status(port::error::INTERNAL, "Failed to stop timer");
1862     }
1863     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1864     output_profile_result->set_algorithm(algo_desc);
1865     output_profile_result->set_elapsed_time_in_ms(
1866         timer->GetElapsedMilliseconds());
1867   }
1868 
1869   return port::Status::OK();
1870 }
1871 
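// Shared implementation behind the typed DoRnnBackward overloads: it reruns
// the same shape checks as the forward path, computes the data gradients, and,
// if a params gradient buffer is supplied, zeroes it and accumulates the
// weight gradients into it.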
1872 template <class T>
1873 port::Status CudnnSupport::DoRnnBackwardImpl(
1874     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1875     const CudnnRnnSequenceTensorDescriptor& input_desc,
1876     const DeviceMemory<T>& input_data,
1877     const CudnnRnnStateTensorDescriptor& input_h_desc,
1878     const DeviceMemory<T>& input_h_data,
1879     const CudnnRnnStateTensorDescriptor& input_c_desc,
1880     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1881     const CudnnRnnSequenceTensorDescriptor& output_desc,
1882     const DeviceMemory<T>& output_data,
1883     const CudnnRnnStateTensorDescriptor& output_h_desc,
1884     const DeviceMemory<T>& output_h_data,
1885     const CudnnRnnStateTensorDescriptor& output_c_desc,
1886     const DeviceMemory<T>& output_c_data,
1887     const DeviceMemory<T>& output_backprop_data,
1888     const DeviceMemory<T>& output_h_backprop_data,
1889     const DeviceMemory<T>& output_c_backprop_data,
1890     DeviceMemory<T>* input_backprop_data,
1891     DeviceMemory<T>* input_h_backprop_data,
1892     DeviceMemory<T>* input_c_backprop_data,
1893     DeviceMemory<T>* params_backprop_data,
1894     DeviceMemory<uint8>* reserve_space_data,
1895     ScratchAllocator* workspace_allocator,
1896     dnn::ProfileResult* output_profile_result) {
1897   SE_ASSIGN_OR_RETURN(
1898       RnnModelDims model_dims,
1899       ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
1900                                 input_h_data, input_c_desc, input_c_data,
1901                                 params, output_desc, output_data, output_h_desc,
1902                                 output_h_data, output_c_desc, output_c_data));
1903 
1904   auto cudnn = cudnn_->GetHandle(parent_, stream);
1905 
1906   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1907   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1908                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1909                                          workspace_allocator));
1910 
1911   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1912   const bool is_profiling = output_profile_result != nullptr;
1913   if (is_profiling) {
1914     timer.reset(new GpuTimer(parent_));
1915     // The start and stop of the timer should be as close to the cuDNN call
1916     // as possible. Other threads may still enqueue work onto this stream, so
1917     // a single measurement can be noisy and multiple runs may be needed.
1918     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1919       return port::Status(port::error::INTERNAL, "Failed to start timer");
1920     }
1921   }
1922 
1923   if (input_desc.is_var_seq_lengths()) {
1924 #if CUDNN_VERSION >= 7201
1925     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
1926         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1927         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
1928         /*dyDesc=*/output_desc.data_handle(),
1929         /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
1930         /*dhyDesc=*/output_h_desc.handle(),
1931         /*dhy=*/output_h_backprop_data.opaque(),
1932         /*dcyDesc=*/output_c_desc.handle(),
1933         /*dcy=*/output_c_backprop_data.opaque(),
1934         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1935         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1936         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1937         /*dxDesc=*/input_desc.data_handle(),
1938         /*dx=*/input_backprop_data->opaque(),
1939         /*dhxDesc=*/input_h_desc.handle(),
1940         /*dhx=*/input_h_backprop_data->opaque(),
1941         /*dcxDesc=*/input_c_desc.handle(),
1942         /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
1943         /*workspace=*/workspace.opaque(),
1944         /*workSpaceSizeInBytes=*/workspace.size(),
1945         /*reserveSpace=*/reserve_space_data->opaque(),
1946         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1947 #else
1948     return port::Status(port::error::INVALID_ARGUMENT,
1949                         "cudnnRNNBackwardDataEx is not supported when "
1950                         "CUDNN_VERSION < 7.2.1");
1951 #endif
1952   } else {
1953     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
1954         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1955         /*seqLength=*/model_dims.max_seq_length,
1956         /*yDesc=*/output_desc.handles(),
1957         /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
1958         /*dy=*/output_backprop_data.opaque(),
1959         /*dhyDesc=*/output_h_desc.handle(),
1960         /*dhy=*/output_h_backprop_data.opaque(),
1961         /*dcyDesc=*/output_c_desc.handle(),
1962         /*dcy=*/output_c_backprop_data.opaque(),
1963         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1964         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1965         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1966         /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
1967         /*dhxDesc=*/input_h_desc.handle(),
1968         /*dhx=*/input_h_backprop_data->opaque(),
1969         /*dcxDesc=*/input_c_desc.handle(),
1970         /*dcx=*/input_c_backprop_data->opaque(),
1971         /*workspace=*/workspace.opaque(),
1972         /*workSpaceSizeInBytes=*/workspace.size(),
1973         /*reserveSpace=*/reserve_space_data->opaque(),
1974         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1975   }
1976 
1977   if (params_backprop_data != nullptr) {
1978     // Zero out dw first, since cudnnRNNBackwardWeights accumulates into it.
1979     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
1980     if (input_desc.is_var_seq_lengths()) {
1981 #if CUDNN_VERSION >= 7201
1982       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
1983           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1984           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1985           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1986           /*yDesc=*/output_desc.data_handle(),
1987           /*y=*/output_data.opaque(),
1988           /*workspace=*/workspace.opaque(),
1989           /*workSpaceSizeInBytes=*/workspace.size(),
1990           /*dwDesc=*/rnn_desc.params_handle(),
1991           /*dw=*/params_backprop_data->opaque(),
1992           /*reserveSpace=*/reserve_space_data->opaque(),
1993           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1994 #else
1995       return port::Status(port::error::INVALID_ARGUMENT,
1996                           "cudnnRNNBackwardWeightsEx is not supported when "
1997                           "CUDNN_VERSION < 7.2.1");
1998 #endif
1999     } else {
2000       // Make the backward-weights call.
2001       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
2002           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2003           /*seqLength=*/model_dims.max_seq_length,
2004           /*xDesc=*/input_desc.handles(),
2005           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2006           /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
2007           /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
2008           /*workSpaceSizeInBytes=*/workspace.size(),
2009           /*dwDesc=*/rnn_desc.params_handle(),
2010           /*dw=*/params_backprop_data->opaque(),
2011           /*reserveSpace=*/reserve_space_data->opaque(),
2012           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2013     }
2014   }
2015 
2016   if (is_profiling) {
2017     if (!timer->Stop(AsGpuStream(stream))) {
2018       return port::Status(port::error::INTERNAL, "Failed to stop timer");
2019     }
2020     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2021     output_profile_result->set_algorithm(algo_desc);
2022     output_profile_result->set_elapsed_time_in_ms(
2023         timer->GetElapsedMilliseconds());
2024   }
2025 
2026   return port::Status::OK();
2027 }
2028 
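// Runs cudnnCTCLoss (requires cuDNN >= 7.6.3) with the non-deterministic
// algorithm, writing per-example costs and gradients into the provided
// buffers.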
2029 port::Status CudnnSupport::DoCtcLossImpl(
2030     Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
2031     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2032     absl::Span<const int> labels_lengths_data,
2033     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2034     const CudnnRnnStateTensorDescriptor& grads_desc,
2035     DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
2036     DeviceMemory<uint8> scratch_memory) {
2037   auto cudnn = cudnn_->GetHandle(parent_, stream);
2038 
2039   int kNumTimestamps = probs_desc.num_layers();
2040   int kBatchSize = probs_desc.batch_size();
2041   int kNumLabels = probs_desc.data_size();
2042   int total_size = kNumLabels * kNumTimestamps * kBatchSize;
2043   (void)total_size;
2044 
2045 #if CUDNN_VERSION >= 7603
2046   RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
2047       /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
2048       /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
2049       /*labelLengths=*/labels_lengths_data.data(),
2050       /*inputLengths=*/input_lengths_data.data(),
2051       /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
2052       /*gradients=*/grads_data.opaque(),
2053       /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
2054       /*ctcLossDesc=*/ctc_loss_desc.handle(),
2055       /*workspace=*/scratch_memory.opaque(),
2056       /*workSpaceSizeInBytes=*/scratch_memory.size()));
2057 #else
2058   return port::Status(port::error::INVALID_ARGUMENT,
2059                       "cudnnCTCLoss is not supported when "
2060                       "CUDNN_VERSION < 7.6.3");
2061 #endif
2062 
2063   return port::Status::OK();
2064 }
2065 
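// The factory methods below convert the generic dnn:: enums and descriptors
// into their cuDNN counterparts and hand back opaque dnn:: wrapper objects.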
2066 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2067 CudnnSupport::createRnnDescriptor(
2068     int num_layers, int hidden_size, int input_size, int cell_size,
2069     int batch_size, dnn::RnnInputMode input_mode,
2070     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2071     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2072     float dropout, uint64 seed, ScratchAllocator* state_allocator,
2073     bool use_padded_io) {
2074   // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
2075   // not enqueueing anything into a stream, we pass in the null stream.
2076   auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
2077   SE_ASSIGN_OR_RETURN(
2078       CudnnRnnDescriptor rnn_desc,
2079       CudnnRnnDescriptor::Create(
2080           cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
2081           ToCudnnRnnInputMode(input_mode),
2082           ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
2083           ToCudnnDataType(data_type), GetRnnComputeType(data_type),
2084           algorithm_config, dropout, seed, state_allocator, use_padded_io));
2085   return std::unique_ptr<dnn::RnnDescriptor>(
2086       new CudnnRnnDescriptor(std::move(rnn_desc)));
2087 }
2088 
2089 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2090 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
2091                                                 int batch_size, int data_size,
2092                                                 dnn::DataType data_type) {
2093   SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2094                       CudnnRnnSequenceTensorDescriptor::Create(
2095                           parent_, max_seq_length, batch_size, data_size,
2096                           ToCudnnDataType(data_type)));
2097   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2098       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2099 }
2100 
2101 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2102 CudnnSupport::createRnnSequenceTensorDescriptor(
2103     int max_seq_length, int batch_size, int data_size,
2104     const absl::Span<const int>& seq_lengths, bool time_major,
2105     dnn::DataType data_type) {
2106   SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2107                       CudnnRnnSequenceTensorDescriptor::Create(
2108                           parent_, max_seq_length, batch_size, data_size,
2109                           seq_lengths, time_major, ToCudnnDataType(data_type)));
2110   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2111       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2112 }
2113 
2114 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2115 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2116                                              int data_size,
2117                                              dnn::DataType data_type) {
2118   return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
2119       new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
2120                                         data_size, ToCudnnDataType(data_type)));
2121 }
2122 
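// The DoRnnForward/DoRnnBackward overloads below (Eigen::half, float, double)
// downcast the generic dnn:: descriptors to their cuDNN implementations and
// forward to the shared Impl functions; when profiling is requested, a failed
// run is not reported as an error.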
2123 bool CudnnSupport::DoRnnForward(
2124     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2125     const dnn::RnnSequenceTensorDescriptor& input_desc,
2126     const DeviceMemory<Eigen::half>& input_data,
2127     const dnn::RnnStateTensorDescriptor& input_h_desc,
2128     const DeviceMemory<Eigen::half>& input_h_data,
2129     const dnn::RnnStateTensorDescriptor& input_c_desc,
2130     const DeviceMemory<Eigen::half>& input_c_data,
2131     const DeviceMemory<Eigen::half>& params,
2132     const dnn::RnnSequenceTensorDescriptor& output_desc,
2133     DeviceMemory<Eigen::half>* output_data,
2134     const dnn::RnnStateTensorDescriptor& output_h_desc,
2135     DeviceMemory<Eigen::half>* output_h_data,
2136     const dnn::RnnStateTensorDescriptor& output_c_desc,
2137     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2138     ScratchAllocator* reserve_space_allocator,
2139     ScratchAllocator* workspace_allocator,
2140     dnn::ProfileResult* output_profile_result) {
2141   const CudnnRnnDescriptor& cudnn_rnn_desc =
2142       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2143   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2144       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2145   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2146       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2147   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2148       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2149   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2150       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2151   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2152       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2153   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2154       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2155   return IsStatusOk(
2156       DoRnnForwardImpl<Eigen::half>(
2157           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2158           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2159           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2160           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2161           reserve_space_allocator, workspace_allocator, output_profile_result),
2162       /*report_error=*/!output_profile_result);
2163 }
2164 
2165 bool CudnnSupport::DoRnnForward(
2166     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2167     const dnn::RnnSequenceTensorDescriptor& input_desc,
2168     const DeviceMemory<float>& input_data,
2169     const dnn::RnnStateTensorDescriptor& input_h_desc,
2170     const DeviceMemory<float>& input_h_data,
2171     const dnn::RnnStateTensorDescriptor& input_c_desc,
2172     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2173     const dnn::RnnSequenceTensorDescriptor& output_desc,
2174     DeviceMemory<float>* output_data,
2175     const dnn::RnnStateTensorDescriptor& output_h_desc,
2176     DeviceMemory<float>* output_h_data,
2177     const dnn::RnnStateTensorDescriptor& output_c_desc,
2178     DeviceMemory<float>* output_c_data, bool is_training,
2179     ScratchAllocator* reserve_space_allocator,
2180     ScratchAllocator* workspace_allocator,
2181     dnn::ProfileResult* output_profile_result) {
2182   const CudnnRnnDescriptor& cudnn_rnn_desc =
2183       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2184   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2185       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2186   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2187       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2188   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2189       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2190   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2191       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2192   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2193       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2194   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2195       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2196   return IsStatusOk(
2197       DoRnnForwardImpl<float>(
2198           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2199           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2200           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2201           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2202           reserve_space_allocator, workspace_allocator, output_profile_result),
2203       /*report_error=*/!output_profile_result);
2204 }
2205 
2206 bool CudnnSupport::DoRnnForward(
2207     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2208     const dnn::RnnSequenceTensorDescriptor& input_desc,
2209     const DeviceMemory<double>& input_data,
2210     const dnn::RnnStateTensorDescriptor& input_h_desc,
2211     const DeviceMemory<double>& input_h_data,
2212     const dnn::RnnStateTensorDescriptor& input_c_desc,
2213     const DeviceMemory<double>& input_c_data,
2214     const DeviceMemory<double>& params,
2215     const dnn::RnnSequenceTensorDescriptor& output_desc,
2216     DeviceMemory<double>* output_data,
2217     const dnn::RnnStateTensorDescriptor& output_h_desc,
2218     DeviceMemory<double>* output_h_data,
2219     const dnn::RnnStateTensorDescriptor& output_c_desc,
2220     DeviceMemory<double>* output_c_data, bool is_training,
2221     ScratchAllocator* reserve_space_allocator,
2222     ScratchAllocator* workspace_allocator,
2223     dnn::ProfileResult* output_profile_result) {
2224   const CudnnRnnDescriptor& cudnn_rnn_desc =
2225       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2226   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2227       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2228   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2229       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2230   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2231       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2232   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2233       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2234   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2235       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2236   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2237       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2238   return IsStatusOk(
2239       DoRnnForwardImpl<double>(
2240           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2241           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2242           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2243           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2244           reserve_space_allocator, workspace_allocator, output_profile_result),
2245       /*report_error=*/!output_profile_result);
2246 }
2247 
2248 bool CudnnSupport::DoRnnBackward(
2249     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2250     const dnn::RnnSequenceTensorDescriptor& input_desc,
2251     const DeviceMemory<Eigen::half>& input_data,
2252     const dnn::RnnStateTensorDescriptor& input_h_desc,
2253     const DeviceMemory<Eigen::half>& input_h_data,
2254     const dnn::RnnStateTensorDescriptor& input_c_desc,
2255     const DeviceMemory<Eigen::half>& input_c_data,
2256     const DeviceMemory<Eigen::half>& params,
2257     const dnn::RnnSequenceTensorDescriptor& output_desc,
2258     const DeviceMemory<Eigen::half>& output_data,
2259     const dnn::RnnStateTensorDescriptor& output_h_desc,
2260     const DeviceMemory<Eigen::half>& output_h_data,
2261     const dnn::RnnStateTensorDescriptor& output_c_desc,
2262     const DeviceMemory<Eigen::half>& output_c_data,
2263     const DeviceMemory<Eigen::half>& output_backprop_data,
2264     const DeviceMemory<Eigen::half>& output_h_backprop_data,
2265     const DeviceMemory<Eigen::half>& output_c_backprop_data,
2266     DeviceMemory<Eigen::half>* input_backprop_data,
2267     DeviceMemory<Eigen::half>* input_h_backprop_data,
2268     DeviceMemory<Eigen::half>* input_c_backprop_data,
2269     DeviceMemory<Eigen::half>* params_backprop_data,
2270     DeviceMemory<uint8>* reserve_space_data,
2271     ScratchAllocator* workspace_allocator,
2272     dnn::ProfileResult* output_profile_result) {
2273   const CudnnRnnDescriptor& cudnn_rnn_desc =
2274       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2275   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2276       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2277   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2278       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2279   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2280       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2281   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2282       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2283   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2284       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2285   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2286       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2287   return IsStatusOk(
2288       DoRnnBackwardImpl<Eigen::half>(
2289           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2290           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2291           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2292           output_h_data, cudnn_output_c_desc, output_c_data,
2293           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2294           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2295           params_backprop_data, reserve_space_data, workspace_allocator,
2296           output_profile_result),
2297       /*report_error=*/!output_profile_result);
2298 }
2299 
2300 bool CudnnSupport::DoRnnBackward(
2301     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2302     const dnn::RnnSequenceTensorDescriptor& input_desc,
2303     const DeviceMemory<float>& input_data,
2304     const dnn::RnnStateTensorDescriptor& input_h_desc,
2305     const DeviceMemory<float>& input_h_data,
2306     const dnn::RnnStateTensorDescriptor& input_c_desc,
2307     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2308     const dnn::RnnSequenceTensorDescriptor& output_desc,
2309     const DeviceMemory<float>& output_data,
2310     const dnn::RnnStateTensorDescriptor& output_h_desc,
2311     const DeviceMemory<float>& output_h_data,
2312     const dnn::RnnStateTensorDescriptor& output_c_desc,
2313     const DeviceMemory<float>& output_c_data,
2314     const DeviceMemory<float>& output_backprop_data,
2315     const DeviceMemory<float>& output_h_backprop_data,
2316     const DeviceMemory<float>& output_c_backprop_data,
2317     DeviceMemory<float>* input_backprop_data,
2318     DeviceMemory<float>* input_h_backprop_data,
2319     DeviceMemory<float>* input_c_backprop_data,
2320     DeviceMemory<float>* params_backprop_data,
2321     DeviceMemory<uint8>* reserve_space_data,
2322     ScratchAllocator* workspace_allocator,
2323     dnn::ProfileResult* output_profile_result) {
2324   const CudnnRnnDescriptor& cudnn_rnn_desc =
2325       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2326   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2327       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2328   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2329       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2330   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2331       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2332   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2333       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2334   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2335       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2336   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2337       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2338   return IsStatusOk(
2339       DoRnnBackwardImpl<float>(
2340           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2341           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2342           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2343           output_h_data, cudnn_output_c_desc, output_c_data,
2344           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2345           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2346           params_backprop_data, reserve_space_data, workspace_allocator,
2347           output_profile_result),
2348       /*report_error=*/!output_profile_result);
2349 }
2350 
2351 bool CudnnSupport::DoRnnBackward(
2352     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2353     const dnn::RnnSequenceTensorDescriptor& input_desc,
2354     const DeviceMemory<double>& input_data,
2355     const dnn::RnnStateTensorDescriptor& input_h_desc,
2356     const DeviceMemory<double>& input_h_data,
2357     const dnn::RnnStateTensorDescriptor& input_c_desc,
2358     const DeviceMemory<double>& input_c_data,
2359     const DeviceMemory<double>& params,
2360     const dnn::RnnSequenceTensorDescriptor& output_desc,
2361     const DeviceMemory<double>& output_data,
2362     const dnn::RnnStateTensorDescriptor& output_h_desc,
2363     const DeviceMemory<double>& output_h_data,
2364     const dnn::RnnStateTensorDescriptor& output_c_desc,
2365     const DeviceMemory<double>& output_c_data,
2366     const DeviceMemory<double>& output_backprop_data,
2367     const DeviceMemory<double>& output_h_backprop_data,
2368     const DeviceMemory<double>& output_c_backprop_data,
2369     DeviceMemory<double>* input_backprop_data,
2370     DeviceMemory<double>* input_h_backprop_data,
2371     DeviceMemory<double>* input_c_backprop_data,
2372     DeviceMemory<double>* params_backprop_data,
2373     DeviceMemory<uint8>* reserve_space_data,
2374     ScratchAllocator* workspace_allocator,
2375     dnn::ProfileResult* output_profile_result) {
2376   const CudnnRnnDescriptor& cudnn_rnn_desc =
2377       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2378   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2379       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2380   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2381       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2382   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2383       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2384   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2385       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2386   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2387       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2388   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2389       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2390   return IsStatusOk(
2391       DoRnnBackwardImpl<double>(
2392           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2393           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2394           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2395           output_h_data, cudnn_output_c_desc, output_c_data,
2396           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2397           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2398           params_backprop_data, reserve_space_data, workspace_allocator,
2399           output_profile_result),
2400       /*report_error=*/!output_profile_result);
2401 }
2402 
2403 namespace {
2404 
2405 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2406 // and backward filter.
2407 
2408 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2409     const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2410     const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2411     const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2412     size_t memory_limit_bytes) {
2413   cudnnConvolutionFwdPreference_t preference =
2414       specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2415                               : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2416   cudnnConvolutionFwdAlgo_t algo_to_use;
2417   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2418       cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2419       output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2420   return algo_to_use;
2421 }
2422 
2423 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2424 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2425                                     const CudnnTensorDescriptor& input_nd,
2426                                     const CudnnFilterDescriptor& filter,
2427                                     const CudnnConvolutionDescriptor& conv,
2428                                     const CudnnTensorDescriptor& output_nd,
2429                                     bool specify_workspace_limit,
2430                                     size_t memory_limit_bytes) {
2431   cudnnConvolutionBwdDataPreference_t preference =
2432       specify_workspace_limit
2433           ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2434           : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2435   cudnnConvolutionBwdDataAlgo_t algo_to_use;
2436   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2437       cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2438       input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2439   return algo_to_use;
2440 }
2441 
2442 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2443 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2444                                       const CudnnTensorDescriptor& input_nd,
2445                                       const CudnnFilterDescriptor& filter,
2446                                       const CudnnConvolutionDescriptor& conv,
2447                                       const CudnnTensorDescriptor& output_nd,
2448                                       bool specify_workspace_limit,
2449                                       size_t memory_limit_bytes) {
2450   cudnnConvolutionBwdFilterPreference_t preference =
2451       specify_workspace_limit
2452           ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2453           : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2454   cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2455   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2456       cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2457       filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2458   return algo_to_use;
2459 }
2460 
2461 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2462     Stream* stream, const CudnnHandle& cudnn,
2463     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2464     const CudnnConvolutionDescriptor& conv,
2465     const CudnnTensorDescriptor& output_nd,
2466     const dnn::AlgorithmDesc& algorithm_desc,
2467     ScratchAllocator* scratch_allocator) {
2468   // TODO(csigg): This has side effects on the convolution descriptor. It is
2469   // functionally correct because the convolution is run with the algorithm of
2470   // the last call to this function, but should be fixed anyway.
2471   conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2472 
2473   // Query the size of the workspace and allocate it.
2474   size_t size_in_bytes;
2475   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2476       cudnn.handle(),
2477       /*xDesc=*/input_nd.handle(),
2478       /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2479       /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
2480       /*sizeInBytes=*/&size_in_bytes));
2481 
2482   int64 size_in_bytes_int64 = size_in_bytes;
2483 
2484   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2485     return port::Status(
2486         port::error::INTERNAL,
2487         "cudnnGetConvolutionForwardWorkspaceSize() returned "
2488         "negative sizeInBytes value. This could be a cudnn bug.");
2489   }
2490 
2491   if (size_in_bytes_int64 == 0) {
2492     return DeviceMemory<uint8>();
2493   }
2494 
2495   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2496     return port::Status(port::error::INVALID_ARGUMENT,
2497                         "No scratch allocator provided");
2498   }
2499 
2500   return scratch_allocator->AllocateBytes(size_in_bytes);
2501 }
2502 
2503 port::StatusOr<DeviceMemory<uint8>>
2504 AllocateCudnnConvolutionBackwardDataWorkspace(
2505     Stream* stream, const CudnnHandle& cudnn,
2506     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2507     const CudnnConvolutionDescriptor& conv,
2508     const CudnnTensorDescriptor& output_nd,
2509     const dnn::AlgorithmDesc& algorithm_desc,
2510     ScratchAllocator* scratch_allocator) {
2511   // TODO(csigg): This has side effects on the convolution descriptor. It is
2512   // functionally correct because the convolution is run with the algorithm of
2513   // the last call to this function, but should be fixed anyway.
2514   conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2515 
2516   // Query the size of the workspace and allocate it.
2517   size_t size_in_bytes;
2518   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2519       cudnn.handle(),
2520       /*wDesc=*/filter.handle(),
2521       /*dyDesc=*/output_nd.handle(),
2522       /*convDesc=*/conv.handle(),
2523       /*dxDesc=*/input_nd.handle(),
2524       /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2525       /*sizeInBytes=*/&size_in_bytes));
2526 
2527   int64 size_in_bytes_int64 = size_in_bytes;
2528 
2529   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2530     return port::Status(
2531         port::error::INTERNAL,
2532         "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2533         "negative sizeInBytes value. This could be a cudnn bug.");
2534   }
2535 
2536   if (size_in_bytes_int64 == 0) {
2537     return DeviceMemory<uint8>();
2538   }
2539 
2540   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2541     return port::Status(port::error::INVALID_ARGUMENT,
2542                         "No scratch allocator provided");
2543   }
2544 
2545   return scratch_allocator->AllocateBytes(size_in_bytes);
2546 }
2547 
2548 port::StatusOr<DeviceMemory<uint8>>
2549 AllocateCudnnConvolutionBackwardFilterWorkspace(
2550     Stream* stream, const CudnnHandle& cudnn,
2551     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2552     const CudnnConvolutionDescriptor& conv,
2553     const CudnnTensorDescriptor& output_nd,
2554     const dnn::AlgorithmDesc& algorithm_desc,
2555     ScratchAllocator* scratch_allocator) {
2556   // TODO(csigg): This has side effects on the convolution descriptor. It is
2557   // functionally correct because the convolution is run with the algorithm of
2558   // the last call to this function, but should be fixed anyway.
2559   conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2560 
2561   // Query the size of the workspace and allocate it.
2562   size_t size_in_bytes;
2563   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
2564       cudnn.handle(),
2565       /*xDesc=*/input_nd.handle(),
2566       /*dyDesc=*/output_nd.handle(),
2567       /*convDesc=*/conv.handle(),
2568       /*gradDesc=*/filter.handle(),
2569       /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2570       /*sizeInBytes=*/&size_in_bytes));
2571 
2572   int64 size_in_bytes_int64 = size_in_bytes;
2573 
2574   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2575     return port::Status(
2576         port::error::INTERNAL,
2577         "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
2578         "negative sizeInBytes value. This could be a cudnn bug.");
2579   }
2580 
2581   if (size_in_bytes_int64 == 0) {
2582     return DeviceMemory<uint8>();
2583   }
2584 
2585   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2586     return port::Status(port::error::INVALID_ARGUMENT,
2587                         "No scratch allocator provided");
2588   }
2589 
2590   return scratch_allocator->AllocateBytes(size_in_bytes);
2591 }
2592 
2593 static bool TensorOpMathAvailable(int cc_major) {
2594   return cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled();
2595 }
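// For example: on a Volta-class GPU (compute capability 7.x) built against
// cuDNN 7 or later, this returns true unless TensorOpMathEnabled() has been
// switched off.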
2596 
2597 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
2598     Stream* stream, const CudnnHandle& cudnn,
2599     const dnn::AlgorithmConfig& algorithm_config,
2600     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2601     const CudnnConvolutionDescriptor& conv,
2602     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2603     DeviceMemory<uint8>* scratch) {
2604   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2605   if (!algo_desc.has_value()) {
2606     // Pick fastest algorithm within memory limit according to cuDNN's
2607     // heuristics.
2608     bool specify_workspace_limit = scratch_allocator != nullptr;
2609     auto memory_limit_bytes =
2610         specify_workspace_limit
2611             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
2612             : 0ll;
2613     SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
2614                         GetCudnnConvolutionForwardAlgo(
2615                             cudnn, input_nd, filter, conv, output_nd,
2616                             specify_workspace_limit, memory_limit_bytes));
2617     int cc_major, cc_minor;
2618     std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
2619     algo_desc = dnn::AlgorithmDesc(
2620         algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
2621   }
2622 
2623   const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
2624       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2625       scratch_allocator);
2626 
2627   if (scratch_or.ok()) {
2628     *scratch = scratch_or.ValueOrDie();
2629     return *algo_desc;
2630   }
2631 
2632   algo_desc = algorithm_config.algorithm_no_scratch();
2633 
2634   // We failed to allocate workspace for the first algorithm, so fall back to
2635   // the no_scratch algorithm.
2636   if (!algo_desc.has_value()) {
2637     return port::Status(
2638         port::error::INVALID_ARGUMENT,
2639         absl::StrCat("The primary convolution algorithm failed, ",
2640                      "while a secondary algorithm is not provided. ",
2641                      "Returned status: ", scratch_or.status().ToString()));
2642   }
2643 
2644   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
2645                                     stream, cudnn, input_nd, filter, conv,
2646                                     output_nd, *algo_desc, scratch_allocator));
2647   return *algo_desc;
2648 }
2649 
2650 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
2651     Stream* stream, const CudnnHandle& cudnn,
2652     const dnn::AlgorithmConfig& algorithm_config,
2653     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2654     const CudnnConvolutionDescriptor& conv,
2655     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2656     DeviceMemory<uint8>* scratch) {
2657   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2658   if (!algo_desc.has_value()) {
2659     // Pick fastest algorithm within memory limit according to cuDNN's
2660     // heuristics.
2661     bool specify_workspace_limit = scratch_allocator != nullptr;
2662     auto memory_limit_bytes =
2663         specify_workspace_limit
2664             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
2665             : 0ll;
2666     SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
2667                         GetCudnnConvolutionBackwardDataAlgo(
2668                             cudnn, input_nd, filter, conv, output_nd,
2669                             specify_workspace_limit, memory_limit_bytes));
2670     int cc_major, cc_minor;
2671     std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
2672     algo_desc = dnn::AlgorithmDesc(
2673         algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
2674   }
2675 
2676   const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
2677       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2678       scratch_allocator);
2679 
2680   if (scratch_or.ok()) {
2681     *scratch = scratch_or.ValueOrDie();
2682     return *algo_desc;
2683   }
2684 
2685   algo_desc = algorithm_config.algorithm_no_scratch();
2686 
2687   // We failed to allocate workspace for the first algorithm, so fall back to
2688   // the no_scratch algorithm.
2689   if (!algo_desc.has_value()) {
2690     return port::Status(
2691         port::error::INVALID_ARGUMENT,
2692         "The primary convolution algorithm failed memory allocation, "
2693         "while a secondary algorithm is not provided.");
2694   }
2695 
2696   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
2697                                     stream, cudnn, input_nd, filter, conv,
2698                                     output_nd, *algo_desc, scratch_allocator));
2699   return *algo_desc;
2700 }
2701 
2702 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
2703     Stream* stream, const CudnnHandle& cudnn,
2704     const dnn::AlgorithmConfig& algorithm_config,
2705     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2706     const CudnnConvolutionDescriptor& conv,
2707     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2708     DeviceMemory<uint8>* scratch) {
2709   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2710   if (!algo_desc.has_value()) {
2711     // Pick fastest algorithm within memory limit according to cuDNN's
2712     // heuristics.
2713     bool specify_workspace_limit = scratch_allocator != nullptr;
2714     auto memory_limit_bytes =
2715         specify_workspace_limit
2716             ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
2717             : 0ll;
2718     SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
2719                         GetCudnnConvolutionBackwardFilterAlgo(
2720                             cudnn, input_nd, filter, conv, output_nd,
2721                             specify_workspace_limit, memory_limit_bytes));
2722     int cc_major, cc_minor;
2723     std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
2724     algo_desc = dnn::AlgorithmDesc(
2725         algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
2726   }
2727 
2728   auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
2729       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2730       scratch_allocator);
2731 
2732   if (scratch_or.ok()) {
2733     *scratch = scratch_or.ValueOrDie();
2734     return *algo_desc;
2735   }
2736 
2737   algo_desc = algorithm_config.algorithm_no_scratch();
2738 
2739   // We failed to allocate workspace for the first algorithm, so fall back to
2740   // the no_scratch algorithm.
2741   if (!algo_desc.has_value()) {
2742     return port::Status(
2743         port::error::INVALID_ARGUMENT,
2744         "The primary convolution algorithm failed memory allocation, "
2745         "while a secondary algorithm is not provided.");
2746   }
2747 
2748   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
2749                                     stream, cudnn, input_nd, filter, conv,
2750                                     output_nd, *algo_desc, scratch_allocator));
2751   return *algo_desc;
2752 }
2753 
2754 // A helper class to read env-vars and choose options for cudnn-related
2755 // algorithms.
2756 template <typename EnvVar>
2757 class CudnnEnvVar {
2758  public:
2759   static bool IsEnabled() {
2760     static bool is_enabled = IsEnabledImpl();
2761     return is_enabled;
2762   }
2763 
2764  private:
2765   static bool IsEnabledImpl() {
2766     const char* tf_env_var_val = getenv(EnvVar::kName);
2767     if (tf_env_var_val != nullptr) {
2768       absl::string_view tf_env_var_val_str(tf_env_var_val);
2769       if (tf_env_var_val_str == "0") {
2770         return false;
2771       }
2772       return true;
2773     }
2774     return EnvVar::kDefaultFlag;
2775   }
2776 };
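// Usage sketch: each policy struct below (e.g. FftTilingForward) supplies a
// kName env-var and a kDefaultFlag, and callers query the cached result:
//
//   if (CudnnEnvVar<FftTilingForward>::IsEnabled()) { ... }
//
// Setting the env-var to "0" disables the option, any other value enables it,
// and leaving it unset falls back to kDefaultFlag.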
2777 
2778 // A helper struct to decide whether to enable the FFT_TILING algorithms for
2779 // forward convolution. It is disabled for cuDNN < 7 due to memory corruption
2780 // caused by some shapes with this algorithm. Users can explicitly enable the
2781 // algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
2782 struct FftTilingForward {
2783   static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
2784   static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000;
2785 };
2786 
2787 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
2788 // By default they are enabled; users can explicitly disable them through the
2789 // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
2790 // https://github.com/tensorflow/tensorflow/pull/4901
2791 struct WinogradNonfused {
2792   static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
2793   // NVIDIA has fixed the winograd nonfused bug for cuDNN v>=7. For older
2794   // versions, we have a workaround.
2795   static constexpr bool kDefaultFlag = true;
2796 };
2797 
2798 // A helper struct to decide whether to use FP32 as the internal compute type
2799 // for convolution when the input data type is FP16. By default it is enabled;
2800 // users can explicitly disable it (i.e. choose FP16 as the internal compute
2801 // type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
2802 struct ConvDoFP32ComputationFP16Input {
2803   static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
2804   // Using FP16 as the internal compute type for convolution when the input data
2805   // type is FP16 is only supported on architectures with true fp16 support
2806   // (compute capability 5.3 and 6.0). Setting this to false on an unsupported
2807   // architecture will cause internal errors.
2808   static constexpr bool kDefaultFlag = true;
2809 };
2810 
2811 // A helper struct to decide whether to use FP32 as the internal compute type
2812 // for RNNs when the input data type is FP16. By default it follows
2813 // kDefaultFlag below (off for cuDNN < 7.5); users can explicitly control it
2814 // through the env-var TF_FP16_RNN_USE_FP32_COMPUTE.
2815 // After the TODO below is fixed, users should almost always use the fp32
2816 // compute type for training; using fp16 may yield suboptimal accuracy due to
2817 // the loss in precision.
2818 struct RnnDoFP32ComputationFP16Input {
2819   static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
2820   // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
2821   // Before cuDNN 7.1.4, RNNs are always run in fp32, no matter what math
2822   // precision is set.
2823   // Temporarily set it to false so that no error is raised when using fp16
2824   // inputs with fp32 math precision.
2825   //
2826   // cuDNN 7.5.0 is verified to have this fixed.
2827   static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500;
2828 };
2829 
2830 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
2831   switch (data_type) {
2832     case dnn::DataType::kFloat:
2833       return CUDNN_DATA_FLOAT;
2834     case dnn::DataType::kDouble:
2835       return CUDNN_DATA_DOUBLE;
2836     case dnn::DataType::kHalf:
2837       if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
2838         return CUDNN_DATA_FLOAT;
2839       } else {
2840         return CUDNN_DATA_HALF;
2841       }
2842     default:
2843       LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
2844   }
2845 }
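// For example: with fp16 RNN inputs, the compute type is fp32 whenever
// TF_FP16_RNN_USE_FP32_COMPUTE is set to any value other than "0", or is
// unset with CUDNN_VERSION >= 7500; otherwise the computation stays in fp16.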
2846 
2847 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
2848   switch (data_type) {
2849     case dnn::DataType::kFloat:
2850     case dnn::DataType::kDouble:
2851       return data_type;
2852     case dnn::DataType::kHalf:
2853       return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
2854                  ? dnn::DataType::kFloat
2855                  : dnn::DataType::kHalf;
2856     case dnn::DataType::kInt8:
2857     case dnn::DataType::kInt32:
2858       return dnn::DataType::kInt32;
2859     default:
2860       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
2861   }
2862 }
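// For example: fp16 inputs accumulate in fp32 unless
// TF_FP16_CONV_USE_FP32_COMPUTE=0 is set, while int8 and int32 inputs always
// accumulate in int32.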
2863 
2864 // Determines whether we can safely perform a winograd non-fused convolution for
2865 // the given input and output shapes.  This works around b/68264959, an integer
2866 // overflow in cuDNNv5 and cuDNNv6.
2867 #if CUDNN_VERSION >= 7000
2868 bool ShouldIncludeWinogradNonfusedAlgo(const dnn::BatchDescriptor&,
2869                                        const dnn::BatchDescriptor&) {
2870   return true;
2871 }
2872 #else
2873 bool ShouldIncludeWinogradNonfusedAlgo(
2874     const dnn::BatchDescriptor& input_desc,
2875     const dnn::BatchDescriptor& output_desc) {
2876   int64 batch = input_desc.count();
2877   int64 in_depths = input_desc.feature_map_count();
2878   int64 in_rows = input_desc.height();
2879   int64 in_cols = input_desc.ndims() == 1 ? 1 : input_desc.width();
2880   int64 out_depths = output_desc.feature_map_count();
2881 
2882   int64 total_size = port::MathUtil::CeilOfRatio(batch, int64{16}) *
2883                      std::max(in_depths, out_depths) * in_cols * in_rows *
2884                      sizeof(float);
2885 
2886   const int64 threshold = 1L << 31;
2887   return total_size < threshold;
2888 }
2889 #endif
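// Illustrative sizing for the pre-cuDNN-7 check above: with batch = 128,
// in_depths = 256, out_depths = 512 and a 56x56 input, the estimate is
//   ceil(128 / 16) * max(256, 512) * 56 * 56 * sizeof(float)
//     = 8 * 512 * 3136 * 4 = 51,380,224 bytes,
// which is below the 2^31-byte threshold, so the algorithm is included.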
2890 
2891 }  // namespace
2892 
2893 port::Status CudnnSupport::DoPrepareForConvolution(
2894     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2895     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2896     const dnn::FilterDescriptor& filter_descriptor,
2897     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2898     DeviceMemoryBase output_data,
2899     const dnn::ConvolutionDescriptor& convolution_descriptor,
2900     const dnn::AlgorithmConfig& algorithm_config,
2901     ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2902     DeviceMemory<uint8>* scratch_memory) {
2903   CudnnTensorDescriptor input_nd(
2904       input_descriptor,
2905       ToCudnnDataType(element_type, input_descriptor.layout()));
2906   CudnnFilterDescriptor filter_nd(
2907       filter_descriptor,
2908       ToCudnnDataType(element_type, filter_descriptor.layout()));
2909   CudnnTensorDescriptor output_nd(
2910       output_descriptor,
2911       ToCudnnDataType(element_type, output_descriptor.layout()));
2912   CudnnConvolutionDescriptor conv(
2913       convolution_descriptor,
2914       ToCudnnDataType(GetConvAccumulatorType(element_type)));
2915 
2916   auto cudnn = cudnn_->GetHandle(parent_, stream);
2917 
2918   switch (kind) {
2919     case dnn::ConvolutionKind::FORWARD: {
2920       SE_ASSIGN_OR_RETURN(
2921           *algorithm_desc,
2922           GetCudnnConvolutionForwardAlgorithm(
2923               stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2924               output_nd, scratch_allocator, scratch_memory));
2925       break;
2926     }
2927     case dnn::ConvolutionKind::BACKWARD_DATA: {
2928       SE_ASSIGN_OR_RETURN(
2929           *algorithm_desc,
2930           GetCudnnConvolutionBackwardDataAlgorithm(
2931               stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2932               output_nd, scratch_allocator, scratch_memory));
2933       break;
2934     }
2935     case dnn::ConvolutionKind::BACKWARD_FILTER: {
2936       SE_ASSIGN_OR_RETURN(
2937           *algorithm_desc,
2938           GetCudnnConvolutionBackwardFilterAlgorithm(
2939               stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2940               output_nd, scratch_allocator, scratch_memory));
2941       break;
2942     }
2943     default:
2944       return port::InternalError(
2945           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
2946   }
2947 
2948   return port::Status::OK();
2949 }
2950 
2951 port::Status CudnnSupport::DoConvolve(
2952     dnn::ConvolutionKind kind, dnn::DataType element_type,
2953     dnn::DataType output_type, Stream* stream,
2954     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2955     const dnn::FilterDescriptor& filter_descriptor,
2956     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2957     DeviceMemoryBase output_data,
2958     const dnn::ConvolutionDescriptor& convolution_descriptor,
2959     dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
2960     dnn::ProfileResult* output_profile_result) {
2961   cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
2962   CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
2963   CudnnTensorDescriptor output_nd(output_descriptor,
2964                                   ToCudnnDataType(output_type));
2965   CudnnFilterDescriptor filter_nd(filter_descriptor, cudnn_type);
2966   auto accumulator_type = GetConvAccumulatorType(element_type);
2967   CudnnConvolutionDescriptor conv(convolution_descriptor,
2968                                   ToCudnnDataType(accumulator_type));
2969   // Set the use_tensor_op_math option to the correct value.
2970   conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2971 
2972   auto cudnn = cudnn_->GetHandle(parent_, stream);
2973   // Alpha is the scaling factor for input.
2974   float falpha = 1.0;
2975   double dalpha = 1.0;
2976   void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
2977                                                 : static_cast<void*>(&falpha);
2978   // Beta is the scaling factor for output.
2979   float fbeta = 0.0;
2980   double dbeta = 0.0;
2981   void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
2982                                                : static_cast<void*>(&fbeta);
2983 
2984   const bool is_profiling = output_profile_result != nullptr;
2985 
2986   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2987   if (is_profiling) {
2988     timer.reset(new GpuTimer(parent_));  // NOLINT
2989     // The start and stop of the timer should be as close to the cuDNN call as
2990     // possible. Other threads may still issue work onto this stream, so an
2991     // accurate result may require multiple profiling measurements.
2992     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2993       return port::Status(port::error::INTERNAL, "Failed to start timer");
2994     }
2995   }
2996 
2997   const auto get_fwd_bugs = [&]() -> port::Status {
2998     // Report an error if we might be hitting a cuDNN bug that accesses illegal
2999     // memory. See nvbugs/2138754, b/80018418.
3000     if (CUDNN_VERSION < 7300) {
3001       if (algorithm_desc.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) {
3002         return port::Status::OK();
3003       }
3004       if (input_descriptor.ndims() < 3) {
3005         return port::Status::OK();
3006       }
3007       // Checks that a * b * 4608 stays within 2^31 (range provided by NVIDIA).
3008       const auto check_sizes = [](size_t a, size_t b) {
3009         if ((a * b * 4608 - 1) >> 31 == 0) {
3010           return port::Status::OK();
3011         }
3012         return port::Status(
3013             port::error::FAILED_PRECONDITION,
3014             "This configuration potentially accesses illegal memory.");
3015       };
3016       SE_RETURN_IF_ERROR(check_sizes(input_descriptor.feature_map_count(),
3017                                      output_descriptor.feature_map_count()));
3018       SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
3019                                      input_descriptor.feature_map_count()));
3020       SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
3021                                      output_descriptor.feature_map_count()));
3022       return port::Status::OK();
3023     }
3024     if (algorithm_desc.algo_id() ==
3025             CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3026         !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
3027                                            output_descriptor)) {
3028       return port::Status(
3029           port::error::FAILED_PRECONDITION,
3030           "This configuration has potential integer overflow in "
3031           "cuDNNv5 and cuDNNv6. See b/68264959.");
3032     }
3033     if (CUDNN_VERSION < 8000) {
3034       if (algorithm_desc.algo_id() ==
3035               CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
3036           ToCudnnDataType(element_type) == CUDNN_DATA_INT8 &&
3037           ToCudnnDataType(output_type) == CUDNN_DATA_FLOAT) {
3038         return port::Status(
3039             port::error::FAILED_PRECONDITION,
3040             "This configuration potentially produces incorrect results.");
3041       }
3042     }
3043     return port::Status::OK();
3044   };
3045 
3046   auto get_bwd_data_bugs = [&]() -> port::Status {
3047     if (algorithm_desc.algo_id() ==
3048             CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3049         !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
3050                                            output_descriptor)) {
3051       return port::Status(
3052           port::error::FAILED_PRECONDITION,
3053           "This configuration has potential integer overflow in "
3054           "cuDNNv5 and cuDNNv6. See b/68264959.");
3055     }
3056 
3057     // cuDNN 7.1.4 has a bug if the workspace of the following convolution is
3058     // not zero-initialized; see nvbugs/2254619.
3059     if (CUDNN_VERSION >= 7000 && CUDNN_VERSION < 7300 &&
3060         algorithm_desc.algo_id() == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
3061         cudnn_type == CUDNN_DATA_HALF && algorithm_desc.tensor_ops_enabled() &&
3062         input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
3063         filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX &&
3064         output_descriptor.layout() == dnn::DataLayout::kBatchDepthYX &&
3065         (convolution_descriptor.vertical_filter_stride() > 1 ||
3066          convolution_descriptor.horizontal_filter_stride() > 1)) {
3067       stream->ThenMemZero(&scratch_memory, scratch_memory.size());
3068     }
3069     return port::Status::OK();
3070   };
3071 
3072   const auto get_bwd_filter_bugs = [&]() -> port::Status {
3073     // Report an error if we might be hitting a cuDNN bug that produces
3074     // incorrect results. See nvbugs/2072856
3075     if (CUDNN_VERSION < 7300) {
3076       SE_RETURN_IF_ERROR([&] {
3077         if (algorithm_desc.algo_id() !=
3078             CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) {
3079           return port::Status::OK();
3080         }
3081         if (output_descriptor.height() > 1 && output_descriptor.width() > 1) {
3082           return port::Status::OK();
3083         }
3084         int convolution_size = output_descriptor.height() > 1
3085                                    ? filter_descriptor.input_filter_height()
3086                                    : filter_descriptor.input_filter_width();
3087         if (convolution_size <= 32) {
3088           return port::Status::OK();
3089         }
3090         cudnnConvolutionMode_t convolution_mode;
3091         cudnnDataType_t compute_type;
3092         RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdDescriptor(
3093             conv.handle(), 0, nullptr, nullptr, nullptr, nullptr,
3094             &convolution_mode, &compute_type));
3095         if (convolution_mode != CUDNN_CONVOLUTION) {
3096           return port::Status::OK();
3097         }
3098         return port::Status(
3099             port::error::FAILED_PRECONDITION,
3100             "This configuration potentially produces incorrect results.");
3101       }());
3102     }
3103 
3104     if (algorithm_desc.algo_id() ==
3105             CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3106         !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
3107                                            output_descriptor)) {
3108       return port::Status(
3109           port::error::FAILED_PRECONDITION,
3110           "This configuration has potential integer overflow in "
3111           "cuDNNv5 and cuDNNv6. See b/68264959.");
3112     }
3113 
3114     // Zero out the result buffer for strided conv backward filter for NHWC
3115     // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer
3116     // is not zeroed.
3117     //
3118     // The wrong results caused by this bug are very flaky; it can take up to
3119     // 20 runs to produce a mismatch.
3120     //
3121     // See nvbugs/2379553.
3122     if (CUDNN_VERSION >= 7100 && CUDNN_VERSION < 7300 &&
3123         algorithm_desc.algo_id() == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 &&
3124         cudnn_type == CUDNN_DATA_HALF &&
3125         input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
3126         filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
3127         output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
3128         (convolution_descriptor.vertical_filter_stride() > 1 ||
3129          convolution_descriptor.horizontal_filter_stride() > 1)) {
3130       stream->ThenMemZero(&filter_data, filter_data.size());
3131     }
3132     return port::Status::OK();
3133   };
3134 
3135   switch (kind) {
3136     case dnn::ConvolutionKind::FORWARD: {
3137       SE_RETURN_IF_ERROR(get_fwd_bugs());
3138       RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
3139           cudnn.handle(),
3140           /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
3141           /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_nd.handle(),
3142           /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
3143           /*algo=*/ToConvForwardAlgo(algorithm_desc),
3144           /*workSpace=*/scratch_memory.opaque(),
3145           /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
3146           /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque()));
3147       break;
3148     }
3149     case dnn::ConvolutionKind::BACKWARD_DATA: {
3150       SE_RETURN_IF_ERROR(get_bwd_data_bugs());
3151       RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
3152           cudnn.handle(),
3153           /*alpha=*/alpha,
3154           /*wDesc=*/filter_nd.handle(),
3155           /*w=*/filter_data.opaque(),
3156           /*dyDesc=*/output_nd.handle(),
3157           /*dy=*/output_data.opaque(),
3158           /*convDesc=*/conv.handle(),
3159           /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
3160           /*workSpace=*/scratch_memory.opaque(),
3161           /*workSpaceSizeInBytes=*/scratch_memory.size(),
3162           /*beta=*/beta,
3163           /*dxDesc=*/input_nd.handle(),
3164           /*dx=*/input_data.opaque()));
3165       break;
3166     }
3167     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3168       SE_RETURN_IF_ERROR(get_bwd_filter_bugs());
3169       RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
3170           cudnn.handle(),
3171           /*alpha=*/alpha,
3172           /*srcDesc=*/input_nd.handle(),
3173           /*srcData=*/input_data.opaque(),
3174           /*diffDesc=*/output_nd.handle(),
3175           /*diffData=*/output_data.opaque(),
3176           /*convDesc=*/conv.handle(),
3177           /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
3178           /*workSpace=*/scratch_memory.opaque(),
3179           /*workSpaceSizeInBytes=*/scratch_memory.size(),
3180           /*beta=*/beta,
3181           /*gradDesc=*/filter_nd.handle(),
3182           /*dw=*/filter_data.opaque()));
3183       break;
3184     }
3185     default:
3186       return port::InternalError(
3187           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
3188   }
3189 
3190   if (is_profiling) {
3191     if (!timer->Stop(AsGpuStream(stream))) {
3192       return port::Status(port::error::INTERNAL, "Failed to stop timer");
3193     }
3194     output_profile_result->set_algorithm(algorithm_desc);
3195     output_profile_result->set_elapsed_time_in_ms(
3196         timer->GetElapsedMilliseconds());
3197     output_profile_result->set_scratch_size(scratch_memory.size());
3198   }
3199 
3200   return port::Status::OK();
3201 }
3202 
3203 // A helper function to query if a CudnnConvolutionDescriptor has tensor_op_math
3204 // set
3205 static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
3206   cudnnMathType_t math_type;
3207   CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
3208   return math_type == CUDNN_TENSOR_OP_MATH;
3209 }
3210 
3211 template <typename ElementType, typename BiasType, typename ScaleType,
3212           typename OutputType>
3213 port::Status CudnnSupport::DoFusedConvolveImpl(
3214     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3215     const DeviceMemory<ElementType>& conv_input_data,
3216     ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
3217     const DeviceMemory<ElementType>& filter_data,
3218     const dnn::ConvolutionDescriptor& convolution_descriptor,
3219     const DeviceMemory<OutputType>& side_input_data, ScaleType side_input_scale,
3220     const dnn::BatchDescriptor& bias_descriptor,
3221     const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
3222     const dnn::BatchDescriptor& output_descriptor,
3223     DeviceMemory<OutputType>* output_data, dnn::DataType accumulator_type,
3224     ScratchAllocator* scratch_allocator,
3225     const dnn::AlgorithmConfig& algorithm_config,
3226     dnn::ProfileResult* output_profile_result) {
3227   if (activation_mode != dnn::ActivationMode::kRelu &&
3228       activation_mode != dnn::ActivationMode::kNone) {
3229     return port::Status(port::error::INVALID_ARGUMENT,
3230                         "cudnnConvolutionBiasActivationForward() only supports "
3231                         "Relu or None activation.");
3232   }
3233 
3234   CudnnTensorDescriptor conv_input_nd(
3235       conv_input_descriptor,
3236       GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
3237   CudnnTensorDescriptor output_nd(
3238       output_descriptor,
3239       GetCudnnDataType<OutputType>(conv_input_descriptor.layout()));
3240   CudnnFilterDescriptor filter(
3241       filter_descriptor,
3242       GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
3243   CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
3244   CudnnConvolutionDescriptor conv(convolution_descriptor,
3245                                   ToCudnnDataType(accumulator_type));
3246 
3247   auto cudnn = cudnn_->GetHandle(parent_, stream);
3248 
3249   const bool is_profiling = output_profile_result != nullptr;
3250 
3251   DeviceMemory<uint8> scratch;
3252   SE_ASSIGN_OR_RETURN(
3253       dnn::AlgorithmDesc algo_desc,
3254       GetCudnnConvolutionForwardAlgorithm(
3255           stream, cudnn, algorithm_config, conv_input_nd, filter, conv,
3256           output_nd, scratch_allocator, &scratch));
3257 
3258   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3259   if (is_profiling) {
3260     timer.reset(new GpuTimer(parent_));  // NOLINT
3261     // The start and stop of the timer should be as close to the cuDNN call as
3262     // possible. Other threads may still issue work onto this stream, so the
3263     // measurement can be noisy and may require multiple profiling runs.
3264     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3265       return port::Status(port::error::INTERNAL, "Failed to start timer");
3266     }
3267   }
3268   // cuDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
3269   // activation descriptor. Note that this changes the NaN propagation behavior
3270   // relative to running separate conv, bias, and relu ops (which default to
3271   // CUDNN_PROPAGATE_NAN).
3272   CudnnActivationDescriptor activation_desc(
3273       activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
3274   auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
3275                                                      : side_input_data.opaque();
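  // When side_input_scale == 0 the side input term contributes nothing, but the
  // fused call below still expects a usable z pointer, so output_data is passed
  // as a (presumably harmless) placeholder.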
3276 
3277   VLOG(2) << "\nconv_input_scale = " << conv_input_scale
3278           << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
3279           << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
3280           << "\nfilter.handle() = " << filter.handle()
3281           << "\nfilter_data.opaque() = " << filter_data.opaque()
3282           << "\nconv.handle() = " << conv.handle()
3283           << "\nalgo = " << algo_desc.algo_id()
3284           << "\nscratch.opaque() = " << scratch.opaque()
3285           << "\nscratch.size() = " << scratch.size()
3286           << "\nside_input_scale = " << side_input_scale
3287           << "\noutput_nd.handle() = " << output_nd.handle()
3288           << "\nside_input_data_ptr = " << side_input_data_ptr
3289           << "\nbias_nd.handle() = " << bias_nd.handle()
3290           << "\nbiases.opaque() = " << biases.opaque()
3291           << "\nactivation_desc.handle() = " << activation_desc.handle()
3292           << "\noutput_nd.handle() = " << output_nd.handle()
3293           << "\noutput_data->opaque() = " << output_data->opaque();
3294 
3295   if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3296       !ShouldIncludeWinogradNonfusedAlgo(conv_input_descriptor,
3297                                          output_descriptor)) {
3298     return port::Status(port::error::FAILED_PRECONDITION,
3299                         "This configuration has potential integer overflow in "
3300                         "cuDNNv5 and cuDNNv6. See around b/68264959.");
3301   }
3302   if (IsTensorMathOpSet(conv) != algo_desc.tensor_ops_enabled()) {
3303     return port::Status(port::error::FAILED_PRECONDITION,
3304                         "Tensor op math type in dnn::AlgorithmDesc does not "
3305                         "match that of the CudnnConvolutionDescriptor");
3306   }
3307 
3308   RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
3309       cudnn.handle(),
3310       /*alpha1=*/&conv_input_scale,
3311       /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
3312       /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
3313       /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
3314       /*workSpace=*/scratch.opaque(),
3315       /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
3316       /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
3317       /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
3318       /*activationDesc=*/activation_desc.handle(),
3319       /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
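  // Per the cuDNN documentation, the fused call above computes roughly
  //   y = activation(alpha1 * conv(x, w) + alpha2 * z + bias),
  // where z is the side input selected above.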
3320 
3321   if (is_profiling) {
3322     if (!timer->Stop(AsGpuStream(stream))) {
3323       return port::Status(port::error::INTERNAL, "Failed to stop timer");
3324     }
3325     output_profile_result->set_algorithm(algo_desc);
3326     output_profile_result->set_elapsed_time_in_ms(
3327         timer->GetElapsedMilliseconds());
3328     output_profile_result->set_scratch_size(scratch.size());
3329   }
3330 
3331   return port::Status::OK();
3332 }
3333 
3334 bool CudnnSupport::GetConvolveAlgorithms(
3335     bool with_winograd_nonfused, int cc_major, int cc_minor,
3336     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3337   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3338   out_algorithms->clear();
3339 
3340   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3341       // clang-format off
3342     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
3343     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
3344     CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
3345     CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
3346     CUDNN_CONVOLUTION_FWD_ALGO_FFT,
3347     CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
3348       // clang-format on
3349   };
3350   if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
3351     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
3352   }
3353   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3354     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
3355   }
3356 
3357   // The algorithms are intentionally ordered for deterministic operation
3358   for (auto i : algo_types) {
3359     if (tensor_op_math_available) {
3360       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3361     }
3362     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3363   }
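  // For each algorithm index i, the loop above appends the tensor-op variant
  // first (when tensor op math is available) and then the regular variant, so
  // callers always see the same, deterministic ordering.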
3364 
3365   return true;
3366 }
3367 
3368 bool CudnnSupport::GetRnnAlgorithms(
3369     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3370   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3371       // clang-format off
3372     CUDNN_RNN_ALGO_STANDARD,
3373     CUDNN_RNN_ALGO_PERSIST_STATIC,
3374     CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
3375       // clang-format on
3376   };
3377 
3378   out_algorithms->clear();
3379   for (auto i : algo_types) {
3380     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3381 #if CUDNN_VERSION >= 7100
3382     if (RnnTensorOpMathEnabled()) {
3383       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3384     }
3385 #endif
3386   }
3387   return true;
3388 }
3389 
3390 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
3391     bool with_winograd_nonfused, int cc_major, int cc_minor,
3392     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3393   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3394   out_algorithms->clear();
3395 
3396   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3397       // clang-format off
3398     CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
3399     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
3400     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
3401     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
3402       // clang-format on
3403   };
3404   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3405     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
3406   }
3407   if (!RequireCudnnDeterminism()) {
3408     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
3409   }
3410 
3411   // The algorithms are intentionally ordered for deterministic operation
3412   for (auto i : algo_types) {
3413     if (tensor_op_math_available) {
3414       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3415     }
3416     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3417   }
3418 
3419   return true;
3420 }
3421 
3422 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
3423     bool with_winograd_nonfused, int cc_major, int cc_minor,
3424     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3425   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3426   out_algorithms->clear();
3427 
3428   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3429       // clang-format off
3430       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
3431       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
3432       // Based on cudnn.h, the following is not implemented.
3433       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
3434 
3435       // Produces incorrect results for some shapes. Disabled for now, see
3436       // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
3437       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
3438       // clang-format on
3439   };
3440   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3441     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
3442   }
3443   if (!RequireCudnnDeterminism()) {
3444     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
3445     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
3446   }
3447 
3448   // The algorithms are intentionally ordered for deterministic operation
3449   for (auto i : algo_types) {
3450     if (tensor_op_math_available) {
3451       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3452     }
3453     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3454   }
3455 
3456   return true;
3457 }
3458 
3459 bool CudnnSupport::DoBatchNormalizationForward(
3460     Stream* stream, const DeviceMemory<float>& x,
3461     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3462     const DeviceMemory<float>& estimated_mean,
3463     const DeviceMemory<float>& estimated_variance,
3464     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3465     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3466     const double exponential_average_factor,
3467     dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
3468     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3469     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3470     bool is_training, ScratchAllocator* reserve_space_allocator,
3471     ScratchAllocator* workspace_allocator,
3472     std::function<const DeviceMemory<float>&()> var_to_inv_var,
3473     std::function<void()> inv_var_to_var) {
3474   return IsStatusOk(
3475       DoBatchNormalizationForwardImpl<float, float>(
3476           stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
3477           offset, estimated_mean, estimated_variance, side_input, x_desc,
3478           scale_offset_desc, epsilon, exponential_average_factor,
3479           activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3480           is_training, reserve_space_allocator, workspace_allocator,
3481           std::move(var_to_inv_var), std::move(inv_var_to_var)),
3482       /*report_error=*/true);
3483 }
3484 
3485 bool CudnnSupport::DoBatchNormalizationForward(
3486     Stream* stream, const DeviceMemory<Eigen::half>& x,
3487     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3488     const DeviceMemory<float>& estimated_mean,
3489     const DeviceMemory<float>& estimated_variance,
3490     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3491     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3492     const double exponential_average_factor,
3493     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
3494     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3495     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3496     bool is_training, ScratchAllocator* reserve_space_allocator,
3497     ScratchAllocator* workspace_allocator,
3498     std::function<const DeviceMemory<float>&()> var_to_inv_var,
3499     std::function<void()> inv_var_to_var) {
3500   return IsStatusOk(
3501       DoBatchNormalizationForwardImpl<Eigen::half, float>(
3502           stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3503           estimated_mean, estimated_variance, side_input, x_desc,
3504           scale_offset_desc, epsilon, exponential_average_factor,
3505           activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3506           is_training, reserve_space_allocator, workspace_allocator,
3507           std::move(var_to_inv_var), std::move(inv_var_to_var)),
3508       /*report_error=*/true);
3509 }
3510 
3511 template <class T, class U>
3512 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
3513     Stream* stream, dnn::DataType input_data_type,
3514     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3515     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3516     const DeviceMemory<U>& estimated_mean,
3517     const DeviceMemory<U>& estimated_variance,
3518     const DeviceMemory<U>& side_input, const dnn::BatchDescriptor& x_desc,
3519     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3520     const double exponential_average_factor,
3521     dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
3522     DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3523     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3524     bool is_training, ScratchAllocator* reserve_space_allocator,
3525     ScratchAllocator* workspace_allocator,
3526     std::function<const DeviceMemory<U>&()> var_to_inv_var,
3527     std::function<void()> inv_var_to_var) {
3528   CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
3529   CudnnTensorDescriptor scale_offset_descriptor(
3530       scale_offset_desc, ToCudnnDataType(scale_data_type));
3531   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3532 #if CUDNN_VERSION >= 7000
3533   if (BatchnormSpatialPersistentEnabled() && is_training) {
3534     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3535   }
3536 #endif
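  // The persistent spatial variant is an opt-in fast path (gated by
  // BatchnormSpatialPersistentEnabled) and, in this routine, is only ever
  // selected for training.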
3537   float one = 1.0;
3538   float zero = 0.0;
3539   auto cudnn = cudnn_->GetHandle(parent_, stream);
3540 
3541   DeviceMemory<uint8> workspace;
3542   DeviceMemory<uint8> reserve_space;
3543 
3544 #if CUDNN_VERSION >= 7402
3545   const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t {
3546     if (side_input.is_null()) {
3547       return activation_mode == dnn::ActivationMode::kNone
3548                  ? CUDNN_BATCHNORM_OPS_BN
3549                  : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
3550     } else {
3551       return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
3552     }
3553   };
3554   const cudnnBatchNormOps_t bn_ops = get_bn_ops();
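  // In short: no side input and no activation selects plain BN, no side input
  // with an activation selects fused BN + activation, and any side input
  // selects fused BN + add + activation.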
3555 
3556   // We use NaN propagation to be consistent with CudnnSupport::DoActivate(...).
3557   CudnnActivationDescriptor activation_desc(
3558       activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
3559 
3560   if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
3561     SE_ASSIGN_OR_RETURN(
3562         workspace,
3563         CreateBatchNormForwardWorkspace(
3564             stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
3565             scale_offset_descriptor, workspace_allocator))
3566     if (is_training) {
3567       size_t reserve_space_size_in_bytes = 0;
3568       RETURN_IF_CUDNN_ERROR(
3569           cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
3570               /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
3571               /*activationDesc=*/activation_desc.handle(),
3572               /*xDesc=*/x_descriptor.handle(),
3573               /*sizeInBytes=*/&reserve_space_size_in_bytes));
3574       SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
3575                                              reserve_space_size_in_bytes));
3576     }
3577   }
3578 #endif
3579 
3580   auto check_no_side_input_or_activation = [&]() -> port::Status {
3581     if (activation_mode != dnn::ActivationMode::kNone ||
3582         !side_input.is_null()) {
3583       return port::Status(
3584           port::error::INTERNAL,
3585           absl::StrCat(
3586               "Side input and activation are not supported by cuDNN version: ",
3587               CUDNN_VERSION));
3588     } else {
3589       return port::Status::OK();
3590     }
3591   };
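  // The non-Ex batch norm entry points used on the fallback paths below have no
  // notion of a fused side input or activation, hence this check.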
3592 
3593   if (is_training) {
3594     CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
3595         << "batch_mean and batch_var must both be null or both be non-null";
3596 
3597     void* batch_mean_opaque;
3598     void* batch_var_opaque;
3599     if (!batch_mean->is_null() && !batch_var->is_null()) {
3600       stream->ThenMemZero(batch_mean, batch_mean->size());
3601       stream->ThenMemZero(batch_var, batch_var->size());
3602       batch_mean_opaque = batch_mean->opaque();
3603       batch_var_opaque = batch_var->opaque();
3604     } else {
3605       batch_mean_opaque = nullptr;
3606       batch_var_opaque = nullptr;
3607     }
3608 
3609     bool called = false;
3610 #if CUDNN_VERSION >= 7402
3611     if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
3612       called = true;
3613       RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx(
3614           /*handle=*/cudnn.handle(),
3615           /*mode=*/mode,
3616           /*bnOps=*/bn_ops,
3617           /*alpha=*/&one,
3618           /*beta=*/&zero,
3619           /*xDesc=*/x_descriptor.handle(),
3620           /*xData=*/x.opaque(),
3621           /*zDesc=*/x_descriptor.handle(),
3622           /*zData=*/side_input.opaque(),
3623           /*yDesc=*/x_descriptor.handle(),
3624           /*yData=*/y->opaque(),
3625           /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
3626           /*bnScale=*/scale.opaque(),
3627           /*bnBias=*/offset.opaque(),
3628           /*exponentialAverageFactor=*/exponential_average_factor,
3629           /*resultRunningMean=*/batch_mean_opaque,
3630           /*resultRunningVariance=*/batch_var_opaque,
3631           /*epsilon=*/epsilon,
3632           /*resultSaveMean=*/saved_mean->opaque(),
3633           /*resultSaveInvVariance=*/saved_inv_var->opaque(),
3634           /*activationDesc=*/activation_desc.handle(),
3635           /*workspace=*/workspace.opaque(),
3636           /*workSpaceSizeInBytes=*/workspace.size(),
3637           /*reserveSpace=*/reserve_space.opaque(),
3638           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
3639     }
3640 #endif
3641     if (!called) {
3642       SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
3643       RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
3644           cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3645           x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3646           scale.opaque(), offset.opaque(), exponential_average_factor,
3647           batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
3648           saved_inv_var->opaque()));
3649     }
3650   } else {
3651     const void* maybe_inv_var = estimated_variance.opaque();
3652     SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
3653     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
3654         cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3655         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3656         scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
3657         epsilon));
3658   }
3659   return port::Status::OK();
3660 }
3661 
3662 bool CudnnSupport::DoBatchNormalizationBackward(
3663     Stream* stream, const DeviceMemory<float>& y_backprop,
3664     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3665     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3666     const dnn::BatchDescriptor& x_desc,
3667     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3668     DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3669     DeviceMemory<float>* offset_backprop,
3670     DeviceMemory<uint8>* reserve_space_data,
3671     ScratchAllocator* workspace_allocator) {
3672   return IsStatusOk(DoBatchNormalizationBackwardImpl(
3673                         stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
3674                         x, scale, mean, inv_var, x_desc, scale_offset_desc,
3675                         epsilon, x_backprop, scale_backprop, offset_backprop,
3676                         reserve_space_data, workspace_allocator),
3677                     /*report_error=*/true);
3678 }
3679 
3680 bool CudnnSupport::DoBatchNormalizationBackward(
3681     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3682     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3683     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3684     const dnn::BatchDescriptor& x_desc,
3685     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3686     DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3687     DeviceMemory<float>* offset_backprop,
3688     DeviceMemory<uint8>* reserve_space_data,
3689     ScratchAllocator* workspace_allocator) {
3690   return IsStatusOk(DoBatchNormalizationBackwardImpl(
3691                         stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
3692                         x, scale, mean, inv_var, x_desc, scale_offset_desc,
3693                         epsilon, x_backprop, scale_backprop, offset_backprop,
3694                         reserve_space_data, workspace_allocator),
3695                     /*report_error=*/true);
3696 }
3697 
3698 template <class T, class U>
3699 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
3700     Stream* stream, int cudnn_input_type, int cudnn_scale_type,
3701     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3702     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3703     const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
3704     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3705     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3706     DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
3707     ScratchAllocator* workspace_allocator) {
3708   CudnnTensorDescriptor x_descriptor(
3709       x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
3710   CudnnTensorDescriptor scale_offset_descriptor(
3711       scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
3712   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3713 #if CUDNN_VERSION >= 7000
3714   if (BatchnormSpatialPersistentEnabled()) {
3715     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3716   }
3717 #endif
3718   float one = 1.0;
3719   float zero = 0.0;
3720 
3721   auto cudnn = cudnn_->GetHandle(parent_, stream);
3722 
3723   bool called = false;
3724 #if CUDNN_VERSION >= 7402
3725   if (reserve_space_data != nullptr && workspace_allocator != nullptr) {
3726     called = true;
3727     const cudnnBatchNormOps_t bn_ops = CUDNN_BATCHNORM_OPS_BN;
3728     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
3729                         CreateBatchNormBackwardWorkspace(
3730                             stream, cudnn, mode, bn_ops, x_descriptor,
3731                             scale_offset_descriptor, workspace_allocator))
3732     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx(
3733         /*handle=*/cudnn.handle(),
3734         /*mode=*/mode,
3735         /*bnOps=*/bn_ops,
3736         /*alphaDataDiff=*/&one,
3737         /*betaDataDiff=*/&zero,
3738         /*alphaParamDiff=*/&one,
3739         /*betaParamDiff=*/&zero,
3740         /*xDesc=*/x_descriptor.handle(),
3741         /*xData=*/x.opaque(),
3742         /*yDesc=*/nullptr,
3743         /*yData=*/nullptr,
3744         /*dyDesc=*/x_descriptor.handle(),
3745         /*dyData=*/y_backprop.opaque(),
3746         /*dzDesc=*/nullptr,
3747         /*dzData=*/nullptr,
3748         /*dxDesc=*/x_descriptor.handle(),
3749         /*dxData=*/x_backprop->opaque(),
3750         /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
3751         /*bnScaleData=*/scale.opaque(),
3752         /*bnBiasData=*/nullptr,
3753         /*dBnScaleData=*/scale_backprop->opaque(),
3754         /*dBnBiasData=*/offset_backprop->opaque(),
3755         /*epsilon=*/epsilon,
3756         /*savedMean=*/mean.opaque(),
3757         /*savedInvVariance=*/inv_var.opaque(),
3758         /*activationDesc=*/nullptr,
3759         /*workspace=*/workspace.opaque(),
3760         /*workSpaceSizeInBytes=*/workspace.size(),
3761         /*reserveSpace=*/reserve_space_data->opaque(),
3762         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
3763   }
3764 #endif
3765   if (!called) {
3766     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
3767         cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3768         x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3769         x_descriptor.handle(), x_backprop->opaque(),
3770         scale_offset_descriptor.handle(), scale.opaque(),
3771         scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3772         mean.opaque(), inv_var.opaque()));
3773   }
3774 
3775   return port::Status::OK();
3776 }
3777 
3778 bool CudnnSupport::DoFusedConvolve(
3779     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3780     const DeviceMemory<double>& conv_input_data, double conv_input_scale,
3781     const dnn::FilterDescriptor& filter_descriptor,
3782     const DeviceMemory<double>& filter_data,
3783     const dnn::ConvolutionDescriptor& convolution_descriptor,
3784     const DeviceMemory<double>& side_input_data, double side_input_scale,
3785     const dnn::BatchDescriptor& bias_descriptor,
3786     const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
3787     const dnn::BatchDescriptor& output_descriptor,
3788     DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
3789     const dnn::AlgorithmConfig& algorithm_config,
3790     dnn::ProfileResult* output_profile_result) {
3791   return IsStatusOk(
3792       DoFusedConvolveImpl(
3793           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3794           filter_descriptor, filter_data, convolution_descriptor,
3795           side_input_data, side_input_scale, bias_descriptor, biases,
3796           activation_mode, output_descriptor, output_data,
3797           GetConvAccumulatorType(dnn::DataType::kDouble), scratch_allocator,
3798           algorithm_config, output_profile_result),
3799       /*report_error=*/!output_profile_result);
3800 }
3801 
3802 bool CudnnSupport::DoFusedConvolve(
3803     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3804     const DeviceMemory<float>& conv_input_data, float conv_input_scale,
3805     const dnn::FilterDescriptor& filter_descriptor,
3806     const DeviceMemory<float>& filter_data,
3807     const dnn::ConvolutionDescriptor& convolution_descriptor,
3808     const DeviceMemory<float>& side_input_data, float side_input_scale,
3809     const dnn::BatchDescriptor& bias_descriptor,
3810     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3811     const dnn::BatchDescriptor& output_descriptor,
3812     DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3813     const dnn::AlgorithmConfig& algorithm_config,
3814     dnn::ProfileResult* output_profile_result) {
3815   return IsStatusOk(
3816       DoFusedConvolveImpl(
3817           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3818           filter_descriptor, filter_data, convolution_descriptor,
3819           side_input_data, side_input_scale, bias_descriptor, biases,
3820           activation_mode, output_descriptor, output_data,
3821           GetConvAccumulatorType(dnn::DataType::kFloat), scratch_allocator,
3822           algorithm_config, output_profile_result),
3823       /*report_error=*/!output_profile_result);
3824 }
3825 
3826 bool CudnnSupport::DoFusedConvolve(
3827     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3828     const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
3829     const dnn::FilterDescriptor& filter_descriptor,
3830     const DeviceMemory<Eigen::half>& filter_data,
3831     const dnn::ConvolutionDescriptor& convolution_descriptor,
3832     const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
3833     const dnn::BatchDescriptor& bias_descriptor,
3834     const DeviceMemory<Eigen::half>& biases,
3835     dnn::ActivationMode activation_mode,
3836     const dnn::BatchDescriptor& output_descriptor,
3837     DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3838     const dnn::AlgorithmConfig& algorithm_config,
3839     dnn::ProfileResult* output_profile_result) {
3840   return IsStatusOk(
3841       DoFusedConvolveImpl(
3842           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3843           filter_descriptor, filter_data, convolution_descriptor,
3844           side_input_data, side_input_scale, bias_descriptor, biases,
3845           activation_mode, output_descriptor, output_data,
3846           GetConvAccumulatorType(dnn::DataType::kHalf), scratch_allocator,
3847           algorithm_config, output_profile_result),
3848       /*report_error=*/!output_profile_result);
3849 }
3850 
3851 bool CudnnSupport::DoFusedConvolve(
3852     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3853     const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3854     const dnn::FilterDescriptor& filter_descriptor,
3855     const DeviceMemory<int8>& filter_data,
3856     const dnn::ConvolutionDescriptor& convolution_descriptor,
3857     const DeviceMemory<int8>& side_input_data, float side_input_scale,
3858     const dnn::BatchDescriptor& bias_descriptor,
3859     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3860     const dnn::BatchDescriptor& output_descriptor,
3861     DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
3862     const dnn::AlgorithmConfig& algorithm_config,
3863     dnn::ProfileResult* output_profile_result) {
3864   int cc_major, cc_minor;
3865   std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
3866 
3867   if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3868     LOG(WARNING) << "cudnnConvolutionBiasActivationForward() for int8 is only "
3869                     "supported on GPUs with compute capability 6.1 or later.";
3870     return false;
3871   }
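  // int8 fused convolutions depend on the GPU's integer dot-product
  // instructions (dp4a), which is presumably why compute capability 6.1 or
  // later is required here.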
3872 
3873   return IsStatusOk(
3874       DoFusedConvolveImpl(
3875           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3876           filter_descriptor, filter_data, convolution_descriptor,
3877           side_input_data, side_input_scale, bias_descriptor, biases,
3878           activation_mode, output_descriptor, output_data,
3879           GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3880           algorithm_config, output_profile_result),
3881       /*report_error=*/!output_profile_result);
3882 }
3883 
3884 bool CudnnSupport::DoFusedConvolve(
3885     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3886     const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3887     const dnn::FilterDescriptor& filter_descriptor,
3888     const DeviceMemory<int8>& filter_data,
3889     const dnn::ConvolutionDescriptor& convolution_descriptor,
3890     const DeviceMemory<float>& side_input_data, float side_input_scale,
3891     const dnn::BatchDescriptor& bias_descriptor,
3892     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3893     const dnn::BatchDescriptor& output_descriptor,
3894     DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3895     const dnn::AlgorithmConfig& algorithm_config,
3896     dnn::ProfileResult* output_profile_result) {
3897   int cc_major, cc_minor;
3898   stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
3899                                                                    &cc_minor);
3900   if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3901     LOG(WARNING) << "cudnnConvolutionBiasActivationForward() for int8 is only "
3902                     "supported on GPUs with compute capability 6.1 or later.";
3903     return false;
3904   }
3905 
3906   return IsStatusOk(
3907       DoFusedConvolveImpl(
3908           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3909           filter_descriptor, filter_data, convolution_descriptor,
3910           side_input_data, side_input_scale, bias_descriptor, biases,
3911           activation_mode, output_descriptor, output_data,
3912           GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3913           algorithm_config, output_profile_result),
3914       /*report_error=*/!output_profile_result);
3915 }
3916 
3917 port::Status CudnnSupport::DoPrepareForCtcLoss(
3918     Stream* stream, dnn::DataType element_type,
3919     const dnn::RnnStateTensorDescriptor& probs_desc,
3920     const dnn::RnnStateTensorDescriptor& grads_desc,
3921     absl::Span<const int> labels_data,
3922     absl::Span<const int> labels_lengths_data,
3923     absl::Span<const int> input_lengths_data,
3924     ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
3925   auto cudnn = cudnn_->GetHandle(parent_, stream);
3926   // Query the workspace size.
3927   size_t workspace_size_in_bytes = 0;
3928 #if CUDNN_VERSION >= 7603
3929   CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
3930   const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
3931       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
3932   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
3933       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
3934   RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
3935       /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
3936       /*gradientsDesc=*/cudnn_grads_desc.handle(),
3937       /*labels=*/labels_data.data(),
3938       /*labelLengths=*/labels_lengths_data.data(),
3939       /*inputLengths=*/input_lengths_data.data(),
3940       /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
3941       /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
3942       /*sizeInBytes=*/&workspace_size_in_bytes));
3943 #else
3944   return port::Status(port::error::INVALID_ARGUMENT,
3945                       "cudnnGetCTCLossWorkspaceSize is not supported when "
3946                       "CUDNN_VERSION < 7.6.3");
3947 #endif
3948   // Allocate the workspace.
3949   if (workspace_size_in_bytes == 0) {
3950     *scratch_memory = DeviceMemory<uint8>();
3951     return port::Status::OK();
3952   }
3953   const auto scratch_or =
3954       scratch_allocator->AllocateBytes(workspace_size_in_bytes);
3955   if (scratch_or.ok()) {
3956     *scratch_memory = scratch_or.ValueOrDie();
3957     return port::Status::OK();
3958   }
3959   return port::InternalError(
3960       "Failed to allocate scratch memory for the CuDNN CTC Loss");
3961 }
3962 
3963 port::Status CudnnSupport::DoCtcLoss(
3964     Stream* stream, dnn::DataType element_type,
3965     const dnn::RnnStateTensorDescriptor& probs_desc,
3966     const DeviceMemoryBase probs_data,
3967 
3968     absl::Span<const int> labels_data,
3969     absl::Span<const int> labels_lengths_data,
3970     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
3971     const dnn::RnnStateTensorDescriptor& grads_desc,
3972     DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
3973   // Current cuDNN CTC Loss only supports the float datatype
3974   if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
3975     return port::Status(port::error::INVALID_ARGUMENT,
3976                         "CudnnCtcLossDescriptor is supported only when the "
3977                         "CUDNN_VERSION >= 7.6.3 and DataType is float");
3978   }
3979   CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
3980   const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
3981       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
3982   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
3983       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
3984   return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
3985                        labels_lengths_data, input_lengths_data, costs_data,
3986                        cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
3987                        scratch_memory);
3988 }
3989 
3990 bool CudnnSupport::DoTransformTensor(Stream* stream,
3991                                      const dnn::BatchDescriptor& input_desc,
3992                                      dnn::DataType input_type,
3993                                      const DeviceMemoryBase& input_data,
3994                                      const dnn::BatchDescriptor& output_desc,
3995                                      dnn::DataType output_type, float scale,
3996                                      DeviceMemoryBase* output_data) {
3997   float beta = 0.0f;
3998   CudnnTensorDescriptor input_tensor_desc(
3999       input_desc, ToCudnnDataType(input_type, input_desc.layout()));
4000   CudnnTensorDescriptor output_tensor_desc(
4001       output_desc, ToCudnnDataType(output_type, output_desc.layout()));
4002   auto cudnn = cudnn_->GetHandle(parent_, stream);
4003   const auto status = [&] {
4004     RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
4005         cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
4006         &beta, output_tensor_desc.handle(), output_data->opaque()));
4007     return port::Status::OK();
4008   }();
4009   return IsStatusOk(status, /*report_error=*/true);
4010 }
4011 
4012 template <class T>
4013 port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
4014     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4015     const DeviceMemory<T>& input_data,
4016     const dnn::BatchDescriptor& bias_descriptor,
4017     DeviceMemory<T>* backward_bias_data) {
4018   cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
4019   CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
4020   CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
4021 
4022   // Alpha is the scaling factor for input.
4023   float alpha = 1.0;
4024   // Beta is the scaling factor for output.
4025   float beta = 0.0;
4026 
4027   auto cudnn = cudnn_->GetHandle(parent_, stream);
4028   RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias(
4029       cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
4030       bias_nd.handle(), backward_bias_data->opaque()));
4031   return port::Status::OK();
4032 }
4033 
4034 bool CudnnSupport::DoConvolveBackwardBias(
4035     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4036     const DeviceMemory<double>& input_data,
4037     const dnn::BatchDescriptor& bias_descriptor,
4038     DeviceMemory<double>* backward_bias_data) {
4039   return IsStatusOk(
4040       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4041                                  bias_descriptor, backward_bias_data),
4042       /*report_error=*/true);
4043 }
4044 
4045 bool CudnnSupport::DoConvolveBackwardBias(
4046     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4047     const DeviceMemory<float>& input_data,
4048     const dnn::BatchDescriptor& bias_descriptor,
4049     DeviceMemory<float>* backward_bias_data) {
4050   return IsStatusOk(
4051       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4052                                  bias_descriptor, backward_bias_data),
4053       /*report_error=*/true);
4054 }
4055 
4056 bool CudnnSupport::DoConvolveBackwardBias(
4057     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4058     const DeviceMemory<Eigen::half>& input_data,
4059     const dnn::BatchDescriptor& bias_descriptor,
4060     DeviceMemory<Eigen::half>* backward_bias_data) {
4061   return IsStatusOk(
4062       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4063                                  bias_descriptor, backward_bias_data),
4064       /*report_error=*/true);
4065 }
4066 
4067 bool CudnnSupport::DoMatMul(Stream* stream,
4068                             const DeviceMemory<float>& input_data,
4069                             const DeviceMemory<float>& weights,
4070                             const dnn::BatchDescriptor& input_dimensions,
4071                             const dnn::BatchDescriptor& output_dimensions,
4072                             DeviceMemory<float>* output_data) {
4073   if (input_dimensions.count() != output_dimensions.count()) {
4074     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
4075     return false;
4076   }
4077 
4078   // We do not permute the input or output, instead we just
4079   // reinterpret the layout. We are working with row-major matrices
4080   // and the rows of the input and output correspond to batch, so
4081   // batch has to be outermost in both the input and output.
4082   //
4083   // By adding transposes to the BLAS gemm call we could perhaps make
4084   // the kYXDepthBatch layout work as well, but there has been no need
4085   // for that so far.
4086   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4087       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4088     LOG(ERROR) << "Unsupported MatMul input layout.";
4089     return false;
4090   }
4091   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4092       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4093     LOG(ERROR) << "Unsupported MatMul output layout.";
4094     return false;
4095   }
4096 
4097   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
4098     // This is a fast path that also supports the kBatchYXDepth layout.
4099 
4100     // The matrices here are in row-major format while BLAS expects
4101     // column-major, i.e. our matrices are transposed as far as BLAS
4102     // is concerned. So we need to compute output^T =
4103     // input^T*weights^T. There is no parameter for transposing the
4104     // output in BLAS gemm, but instead we can transpose both sides of
4105     // the equality to see that this is equivalent to
4106     // output=weights*input. So we only need to swap the order of
4107     // weights and input in the matrix product to correct for the
4108     // row-major versus column-major difference.
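    // Illustration (hypothetical sizes): for a batch of n = 32 examples with
    // k = 1024 input nodes and m = 10 output nodes, the single GEMM below
    // multiplies the m x k view of 'weights' by the k x n view of
    // 'input_data' (input^T) and writes the m x n result, i.e. output^T.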
4109     const float alpha = 1.0f;  // Take the matrix product without scaling it.
4110     const float beta = 0.0f;   // Ignore the original values in output_data.
4111     const int64 m = output_dimensions.NodesAcrossFeatureMaps();
4112     const int64 n = input_dimensions.count();
4113     const int64 k = input_dimensions.NodesAcrossFeatureMaps();
4114     stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
4115                          blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
4116                          m, input_data, k, beta, output_data, m);
4117   } else {
4118     // This is a slower and more complex path that supports output
4119     // width() * height() > 1, though it only supports the
4120     // kBatchYXDepth layout. Does support kBatchDepthYX if output
4121     // feature_map_count() == 1, as then there is no difference
4122     // between the two layouts.
4123     //
4124     // The operation here is the same as above, except that we have to
4125     // do the matrix multiplication for each (y,x) output coordinate
4126     // separately. We then interpret weights as containing K = width()
4127     // * height() different matrices, which we all multiply onto the
4128     // matrix from input_data, yielding K matrix products. We then
4129     // combine these together into one matrix by concatenating all the
4130     // first rows of these matrices, then all the second rows and so
4131     // on. We can do this with a batched matrix multiplication, where
4132     // the result is written to a different submatrix of the output
4133     // for each matrix multiplication.
4134     //
4135     // The reason that we only support the kBatchYXDepth output layout
4136     // is that we have to do something in the depth for each (y,x)
4137     // coordinate. The kBatchYXDepth layout has the depth information
4138     // for each point (y,x) in contiguous memory while the
4139     // kBatchDepthYX layout does not.
4140     //
4141     // TODO(broune): Consider a special case for when output depth ==
4142     // 1, as then possibly this could all be done as one matrix
4143     // multiplication instead of a batched one, which should be
4144     // faster. Another possibility would be to add a weights layout
4145     // parameter and then support kBatchDepthYX for a different
4146     // weights layout.
4147     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4148         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
4149           output_dimensions.feature_map_count() == 1)) {
4150       LOG(ERROR) << "Unsupported MatMul output layout.";
4151       return false;
4152     }
4153 
4154     const float alpha = 1.0f;  // Take the matrix product without scaling it.
4155     const float beta = 0.0f;   // Ignore the original values in output_data.
4156     const uint64 m = output_dimensions.feature_map_count();
4157     const uint64 n = input_dimensions.count();
4158     const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
4159     const int lda = m;
4160     const int ldb = k;
4161     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
4162     const int batch_count = output_dimensions.NodesPerFeatureMap();
4163 
4164     std::vector<DeviceMemory<float>> a(batch_count);
4165     std::vector<DeviceMemory<float>> b(batch_count);
4166     std::vector<DeviceMemory<float>> c(batch_count);
4167     for (int i = 0; i < batch_count; ++i) {
4168       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
4169                                  output_dimensions.feature_map_count();
4170       a[i] = DeviceMemory<float>::MakeFromByteSize(
4171           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
4172               weights_offset,
4173           (weights.ElementCount() - weights_offset) * sizeof(float));
4174 
4175       b[i] = input_data;
4176 
4177       const int output_offset = i * output_dimensions.feature_map_count();
4178       c[i] = DeviceMemory<float>::MakeFromByteSize(
4179           const_cast<float*>(
4180               reinterpret_cast<const float*>(output_data->opaque())) +
4181               output_offset,
4182           (output_data->ElementCount() - output_offset) * sizeof(float));
4183     }
4184     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
4185       std::vector<DeviceMemory<float>*> ptrs;
4186       ptrs.reserve(v.size());
4187       for (auto& mem : v) {
4188         ptrs.push_back(&mem);
4189       }
4190       return ptrs;
4191     };
4192 
4193     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
4194                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
4195                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
4196                                 ldc, batch_count);
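    // Illustrative sketch (not part of the library): with hypothetical
    // sizes input NodesAcrossFeatureMaps() == 4, output
    // feature_map_count() == 3 and output width()*height() == 2, we get
    // batch_count == 2, ldc == 6, and
    //   weights_offset == i * 4 * 3  ->  0, 12
    //   output_offset  == i * 3      ->  0, 3
    // so gemm i multiplies the i-th 3x4 slab of weights by the shared
    // input matrix and writes its 3 result rows at row offset
    // output_offset of the output, i.e. the 3 depth values of output
    // pixel yx == i for every example (ldc == 6 is the stride between
    // consecutive examples: one example's full width*height*depth output).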
4197   }
4198 
4199   return stream->ok();
4200 }
4201 
4202 bool CudnnSupport::DoBiasAdd(Stream* stream,
4203                              const DeviceMemory<float>& input_data,
4204                              const DeviceMemory<float>& biases,
4205                              const dnn::BatchDescriptor& dimensions,
4206                              DeviceMemory<float>* output_data) {
4207   CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
4208 
4209   dnn::BatchDescriptor bias_dimensions;
4210   bias_dimensions.set_count(1)
4211       .set_feature_map_count(dimensions.feature_map_count())
4212       .set_height(1)
4213       .set_width(1)
4214       .set_layout(dnn::DataLayout::kBatchYXDepth);
4215   CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
4216 
4217   // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
4218   // output_data before doing the addition, unless the input and
4219   // output are at the same address.
4220   if (input_data.opaque() != output_data->opaque()) {
4221     stream->ThenMemcpy(output_data, input_data,
4222                        dimensions.ElementCount() * sizeof(float));
4223     if (!stream->ok()) {
4224       LOG(ERROR)
4225           << "stream " << stream
4226           << " could not enqueue a tensor copy as part of bias addition.";
4227       return false;
4228     }
4229   }
4230 
4231   const float alpha = 1.0f;
4232   const float beta = 1.0f;
4233 
4234   auto cudnn = cudnn_->GetHandle(parent_, stream);
4235 
4236   const auto status = [&] {
4237     RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
4238         cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
4239         &beta, input_descriptor.handle(), output_data->opaque()));
4240     return port::Status::OK();
4241   }();
4242   return IsStatusOk(status, /*report_error=*/true);
4243 }
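// Illustrative sketch (not part of the library): for a kBatchYXDepth
// float tensor, the cudnnAddTensor call above (alpha == 1, beta == 1,
// bias shaped 1 x feature_map_count x 1 x 1) amounts to the per-channel
// broadcast below, assuming out[] already holds a copy of input_data:
//
//   for (int64 nIdx = 0; nIdx < count; ++nIdx)
//     for (int64 y = 0; y < height; ++y)
//       for (int64 x = 0; x < width; ++x)
//         for (int64 c = 0; c < depth; ++c)
//           out[((nIdx * height + y) * width + x) * depth + c] += biases[c];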
4244 
4245 bool CudnnSupport::DoActivate(Stream* stream,
4246                               dnn::ActivationMode activation_mode,
4247                               const dnn::BatchDescriptor& dimensions,
4248                               const DeviceMemory<float>& input_data,
4249                               DeviceMemory<float>* output_data,
4250                               uint64 options) {
4251   CudnnActivationDescriptor activation_desc(
4252       activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
4253 
4254   CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
4255   // Alpha is the input scaling factor.
4256   float alpha = 1.0;
4257   // Beta is the output scaling factor.
4258   float beta = 0.0;
4259 
4260   auto cudnn = cudnn_->GetHandle(parent_, stream);
4261   const auto status = [&] {
4262     RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
4263         cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
4264         input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
4265     return port::Status::OK();
4266   }();
4267   return IsStatusOk(status, /*report_error=*/true);
4268 }
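// Note on the alpha/beta convention used with cuDNN in this file: cuDNN
// blends each result into its destination as
//   dst = alpha * op(src) + beta * dst,
// so the alpha == 1, beta == 0 pair above simply overwrites output_data
// with the activation of input_data, while DoBiasAdd above uses beta == 1
// to accumulate the bias into the copied input.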
4269 
4270 bool CudnnSupport::DoPoolForward(
4271     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4272     const dnn::BatchDescriptor& input_dimensions,
4273     const DeviceMemory<double>& input_data,
4274     const dnn::BatchDescriptor& output_dimensions,
4275     DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
4276   // Alpha is the scaling factor for input.
4277   double alpha = 1.0;
4278   // Beta is the scaling factor for output.
4279   double beta = 0.0;
4280 
4281   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
4282   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
4283   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4284 
4285   auto cudnn = cudnn_->GetHandle(parent_, stream);
4286   const auto status = [&] {
4287     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4288         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4289         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4290     return port::Status::OK();
4291   }();
4292   return IsStatusOk(status, /*report_error=*/true);
4293 }
4294 
4295 bool CudnnSupport::DoPoolForward(
4296     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4297     const dnn::BatchDescriptor& input_dimensions,
4298     const DeviceMemory<float>& input_data,
4299     const dnn::BatchDescriptor& output_dimensions,
4300     DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
4301   // Alpha is the scaling factor for input.
4302   float alpha = 1.0;
4303   // Beta is the scaling factor for output.
4304   float beta = 0.0;
4305 
4306   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
4307   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
4308   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4309 
4310   auto cudnn = cudnn_->GetHandle(parent_, stream);
4311   const auto status = [&] {
4312     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4313         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4314         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4315     return port::Status::OK();
4316   }();
4317   return IsStatusOk(status, /*report_error=*/true);
4318 }
4319 
4320 bool CudnnSupport::DoPoolForward(
4321     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4322     const dnn::BatchDescriptor& input_dimensions,
4323     const DeviceMemory<Eigen::half>& input_data,
4324     const dnn::BatchDescriptor& output_dimensions,
4325     DeviceMemory<Eigen::half>* output_data,
4326     ScratchAllocator* workspace_allocator) {
4327   // Alpha is the scaling factor for input.
4328   float alpha = 1.0;
4329   // Beta is the scaling factor for output.
4330   float beta = 0.0;
4331 
4332   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
4333   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
4334   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4335   auto cudnn = cudnn_->GetHandle(parent_, stream);
4336   const auto status = [&] {
4337     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4338         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4339         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4340     return port::Status::OK();
4341   }();
4342   return IsStatusOk(status, /*report_error=*/true);
4343 }
4344 
4345 bool CudnnSupport::DoPoolForward(
4346     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4347     const dnn::BatchDescriptor& input_dimensions,
4348     const DeviceMemory<int8>& input_data,
4349     const dnn::BatchDescriptor& output_dimensions,
4350     DeviceMemory<int8>* output_data, ScratchAllocator* workspace_allocator) {
4351   // Alpha is the scaling factor for input.
4352   float alpha = 1.0;
4353   // Beta is the scaling factor for output.
4354   float beta = 0.0;
4355 
4356   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_INT8);
4357   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_INT8);
4358   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4359 
4360   auto cudnn = cudnn_->GetHandle(parent_, stream);
4361   const auto status = [&] {
4362     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4363         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4364         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4365     return port::Status::OK();
4366   }();
4367   return IsStatusOk(status, /*report_error=*/true);
4368 }
4369 
4370 bool CudnnSupport::DoPoolBackward(
4371     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4372     const dnn::BatchDescriptor& input_dimensions,
4373     const DeviceMemory<double>& input_data,
4374     const dnn::BatchDescriptor& output_dimensions,
4375     const DeviceMemory<double>& output_data,
4376     const DeviceMemory<double>& input_diff_data,
4377     DeviceMemory<double>* output_diff_data,
4378     ScratchAllocator* workspace_allocator) {
4379   // Alpha is the scaling factor for input.
4380   double alpha = 1.0;
4381   // Beta is the scaling factor for output.
4382   double beta = 0.0;
4383 
4384   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
4385   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
4386   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4387 
4388   auto cudnn = cudnn_->GetHandle(parent_, stream);
4389   const auto status = [&] {
4390     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4391         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4392         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4393         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4394         output_diff_data->opaque()));
4395     return port::Status::OK();
4396   }();
4397   return IsStatusOk(status, /*report_error=*/true);
4398 }
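// Note on the cudnnPoolingBackward argument mapping used here and in the
// other DoPoolBackward overloads: output_data is y (the forward pooling
// output), input_diff_data is dy (the incoming gradient w.r.t. that
// output), input_data is x (the forward input), and output_diff_data
// receives dx (the gradient w.r.t. the input).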
4399 
4400 bool CudnnSupport::DoPoolBackward(
4401     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4402     const dnn::BatchDescriptor& input_dimensions,
4403     const DeviceMemory<float>& input_data,
4404     const dnn::BatchDescriptor& output_dimensions,
4405     const DeviceMemory<float>& output_data,
4406     const DeviceMemory<float>& input_diff_data,
4407     DeviceMemory<float>* output_diff_data,
4408     ScratchAllocator* workspace_allocator) {
4409   // Alpha is the scaling factor for input.
4410   float alpha = 1.0;
4411   // Beta is the scaling factor for output.
4412   float beta = 0.0;
4413 
4414   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
4415   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
4416   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4417 
4418   auto cudnn = cudnn_->GetHandle(parent_, stream);
4419   const auto status = [&] {
4420     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4421         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4422         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4423         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4424         output_diff_data->opaque()));
4425     return port::Status::OK();
4426   }();
4427   return IsStatusOk(status, /*report_error=*/true);
4428 }
4429 
4430 bool CudnnSupport::DoPoolBackward(
4431     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4432     const dnn::BatchDescriptor& input_dimensions,
4433     const DeviceMemory<Eigen::half>& input_data,
4434     const dnn::BatchDescriptor& output_dimensions,
4435     const DeviceMemory<Eigen::half>& output_data,
4436     const DeviceMemory<Eigen::half>& input_diff_data,
4437     DeviceMemory<Eigen::half>* output_diff_data,
4438     ScratchAllocator* workspace_allocator) {
4439   // Alpha is the scaling factor for input.
4440   float alpha = 1.0;
4441   // Beta is the scaling factor for output.
4442   float beta = 0.0;
4443 
4444   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
4445   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
4446   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4447 
4448   auto cudnn = cudnn_->GetHandle(parent_, stream);
4449   const auto status = [&] {
4450     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4451         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4452         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4453         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4454         output_diff_data->opaque()));
4455     return port::Status::OK();
4456   }();
4457   return IsStatusOk(status, /*report_error=*/true);
4458 }
4459 
4460 bool CudnnSupport::DoNormalizeWithDimensions(
4461     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4462     const dnn::BatchDescriptor& dimensions,
4463     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
4464   // Check for unsupported modes.
4465   if (normalize_descriptor.wrap_around()) {
4466     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
4467     return false;
4468   }
4469   if (normalize_descriptor.segment_size()) {
4470     LOG(ERROR) << "CUDA LRN does not support segmentation";
4471     return false;
4472   }
4473 
4474   CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
4475   CudnnNormalizeDescriptor normalize(normalize_descriptor);
4476 
4477   // Alpha is the scaling factor for input.
4478   float alpha = 1.0f;
4479   // Beta is the scaling factor for output.
4480   float beta = 0.0f;
4481 
4482   auto cudnn = cudnn_->GetHandle(parent_, stream);
4483 
4484   // Launch the normalization.
4485   const auto status = [&] {
4486     RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
4487         cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
4488         &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
4489         output_data->opaque()));
4490     return port::Status::OK();
4491   }();
4492   return IsStatusOk(status, /*report_error=*/true);
4493 }
4494 
4495 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
4496     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4497     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
4498     const DeviceMemory<float>& normalized_data,
4499     const DeviceMemory<float>& normalized_variable_gradient,
4500     DeviceMemory<float>* raw_variable_gradient,
4501     ScratchAllocator* workspace_allocator) {
4502   // Check for unsupported modes.
4503   if (normalize_descriptor.wrap_around()) {
4504     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
4505     return false;
4506   }
4507   if (normalize_descriptor.segment_size()) {
4508     LOG(ERROR) << "CUDA LRN does not support segmentation";
4509     return false;
4510   }
4511 
4512   CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
4513   CudnnNormalizeDescriptor normalize(normalize_descriptor);
4514 
4515   float alpha = 1.0f;
4516   float beta = 0.0f;
4517 
4518   auto cudnn = cudnn_->GetHandle(parent_, stream);
4519   const auto status = [&] {
4520     RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
4521         cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
4522         &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
4523         normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
4524         &beta, dims.handle(), raw_variable_gradient->opaque()));
4525     return port::Status::OK();
4526   }();
4527   return IsStatusOk(status, /*report_error=*/true);
4528 }
4529 
4530 bool CudnnSupport::DoDepthConcatenate(
4531     Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4532     port::ArraySlice<const DeviceMemory<float>*> input_data,
4533     DeviceMemory<float>* output_data) {
4534   CHECK_EQ(input_dimensions.size(), input_data.size());
4535 
4536   for (const auto& dimensions : input_dimensions) {
4537     if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4538       LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
4539                     "supports the kBatchDepthYX layout.";
4540       return false;
4541     }
4542   }
4543 
4544   if (input_dimensions.empty()) {
4545     return true;  // Nothing to do.
4546   }
4547 
4548   dnn::BatchDescriptor output_dimensions =
4549       dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
4550 
4551   const int64 area = output_dimensions.width() * output_dimensions.height();
4552   const auto index = [area](int64 batch, int64 depth, int64 yx,
4553                             int64 max_depth) {
4554     return (batch * max_depth + depth) * area + yx;
4555   };
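  // Illustrative sketch (not part of the library): with hypothetical
  // output dimensions of 2x2 spatial size (area == 4) and
  // feature_map_count() == 5,
  //   index(/*batch=*/1, /*depth=*/3, /*yx=*/2, /*max_depth=*/5)
  //     == (1 * 5 + 3) * 4 + 2 == 34,
  // i.e. spatial position (y, x) == (1, 0) of feature map 3 in the
  // second example, matching the kBatchDepthYX layout required above.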
4556 
4557   std::vector<float> output_host(output_dimensions.ElementCount());
4558   std::vector<float> tmp;
4559   int64 depth_sum = 0;
4560   for (size_t i = 0; i < input_data.size(); ++i) {
4561     const auto& dimensions = input_dimensions[i];
4562     tmp.resize(dimensions.ElementCount());
4563     stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
4564     port::Status block_status = stream->BlockHostUntilDone();
4565     if (!block_status.ok()) {
4566       LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
4567       return false;
4568     }
4569 
4570     for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
4571       for (int64 yx = 0; yx < area; ++yx) {
4572         for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
4573           LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
4574                     << yx << ' ' << depth;
4575           output_host[index(batch, depth + depth_sum, yx,
4576                             output_dimensions.feature_map_count())] =
4577               tmp[index(batch, depth, yx, dimensions.feature_map_count())];
4578         }
4579       }
4580     }
4581     depth_sum += dimensions.feature_map_count();
4582   }
4583   stream->ThenMemcpyH2D<float>(output_host, output_data);
4584   return true;
4585 }
4586 
4587 bool CudnnSupport::DoElementwiseOperate(
4588     Stream* stream, dnn::ElementwiseOperation operation,
4589     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4590     port::ArraySlice<const DeviceMemory<float>*> input_data,
4591     const dnn::BatchDescriptor& output_dimensions,
4592     DeviceMemory<float>* output_data) {
4593   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4594   return false;
4595 }
4596 
4597 bool CudnnSupport::DoXYPad(Stream* stream,
4598                            const dnn::BatchDescriptor& dimensions,
4599                            const DeviceMemory<float>& input_data,
4600                            int64 left_pad, int64 right_pad, int64 top_pad,
4601                            int64 bottom_pad, DeviceMemory<float>* output_data) {
4602   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4603   return false;
4604 }
4605 
4606 bool CudnnSupport::DoXYSlice(Stream* stream,
4607                              const dnn::BatchDescriptor& dimensions,
4608                              const DeviceMemory<float>& input_data,
4609                              int64 left_trim, int64 right_trim, int64 top_trim,
4610                              int64 bottom_trim,
4611                              DeviceMemory<float>* output_data) {
4612   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4613   return false;
4614 }
4615 
4616 bool CudnnSupport::DoMemcpyD2HQuantized(
4617     Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
4618     dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
4619   LOG(ERROR) << "quantized memcpy not supported by cuDNN";
4620   return false;
4621 }
4622 
4623 bool CudnnSupport::DoMemcpyH2DQuantized(
4624     Stream* stream, const void* host_src, int64 size,
4625     dnn::QuantizedActivationMode mode,
4626     DeviceMemory<float>* gpu_unquantized_dst) {
4627   LOG(ERROR) << "quantized memcpy not supported by cuDNN";
4628   return false;
4629 }
4630 
4631 bool CudnnSupport::DeriveOutputBatchDescriptor(
4632     const dnn::BatchDescriptor& batch_descriptor,
4633     const dnn::FilterDescriptor& filter_descriptor,
4634     const dnn::ConvolutionDescriptor& convolution_descriptor,
4635     dnn::BatchDescriptor* output_batch_descriptor) {
4636   CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
4637   CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
4638   CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
4639 
4640   int dn = batch_descriptor.ndims() + 2;
4641   std::vector<int> dims(dn);  // in BDYX
4642   const auto status = [&] {
4643     RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
4644         conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
4645     output_batch_descriptor->set_count(dims[0])
4646         .set_feature_map_count(dims[1])
4647         .set_layout(batch_descriptor.layout());
4648 
4649     for (int i = 0; i < batch_descriptor.ndims(); i++) {
4650       output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
4651                                                dims.rbegin()[i]);
4652     }
4653     return port::Status::OK();
4654   }();
4655   return IsStatusOk(status, /*report_error=*/true);
4656 }
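// Illustrative note (not part of the library): for a hypothetical 2-D
// convolution, cudnnGetConvolutionNdForwardOutputDim fills dims as
// {N, C, H, W} ("BDYX" above). set_spatial_dim takes the innermost
// spatial dimension first (dnn::DimIndex::X == 0), so the loop reads the
// spatial entries back to front:
//   i == 0 -> dims.rbegin()[0] == W  (DimIndex::X)
//   i == 1 -> dims.rbegin()[1] == H  (DimIndex::Y)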
4657 
4658 }  // namespace gpu
4659 
4660 void initialize_cudnn() {
4661   port::Status status =
4662       PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
4663           cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
4664           [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
4665             gpu::GpuExecutor* cuda_executor =
4666                 dynamic_cast<gpu::GpuExecutor*>(parent);
4667             if (cuda_executor == nullptr) {
4668               LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
4669                          << "support library with a non-CUDA StreamExecutor";
4670               return nullptr;
4671             }
4672 
4673             gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
4674             if (!dnn->Init().ok()) {
4675               // Note: Init() will log a more specific error.
4676               delete dnn;
4677               return nullptr;
4678             }
4679             return dnn;
4680           });
4681 
4682   if (!status.ok()) {
4683     LOG(ERROR) << "Unable to register cuDNN factory: "
4684                << status.error_message();
4685   }
4686 
4687   PluginRegistry::Instance()->SetDefaultFactory(
4688       cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
4689 }
4690 
4691 }  // namespace stream_executor
4692 
4693 #pragma clang diagnostic pop
4694 
4695 REGISTER_MODULE_INITIALIZER(register_cudnn,
4696                             { stream_executor::initialize_cudnn(); });
4697