1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <utility>
21 
22 #include "absl/strings/str_cat.h"
23 #include "third_party/eigen3/Eigen/Core"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/strings/stringprintf.h"
26 #include "tensorflow/core/util/env_var.h"
27 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
28 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
29 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
30 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
31 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
32 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
33 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
34 #include "tensorflow/stream_executor/cuda/cudnn_version.h"
35 #include "tensorflow/stream_executor/dnn.h"
36 #include "tensorflow/stream_executor/lib/env.h"
37 #include "tensorflow/stream_executor/lib/error.h"
38 #include "tensorflow/stream_executor/lib/initialize.h"
39 #include "tensorflow/stream_executor/lib/mathutil.h"
40 #include "tensorflow/stream_executor/lib/threadpool.h"
41 #include "tensorflow/stream_executor/platform/logging.h"
42 #include "tensorflow/stream_executor/plugin_registry.h"
43 #include "tensorflow/stream_executor/scratch_allocator.h"
44 #include "tensorflow/stream_executor/stream.h"
45 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
46 // clang-format off
47 #include "cuda/include/cudnn.h"
48 #include "absl/strings/string_view.h"
49 // clang-format on
50 
51 #pragma clang diagnostic push
52 
53 // Make sure that Eigen::half forward declaration in dnn.h matches the
54 // declaration in Eigen.
55 #pragma clang diagnostic warning "-Wmismatched-tags"
56 
57 namespace stream_executor {
58 namespace gpu {
59 
60 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
61 
62 namespace {
63 
64 static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher");
65 
66 // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
67 #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
68 
69 // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
70 // function with a non-successful port::Status.
71 #define RETURN_IF_CUDNN_ERROR(expr)                                      \
72   do {                                                                   \
73     cudnnStatus_t _status = expr;                                        \
74     if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {             \
75       std::ostringstream oss;                                            \
76       oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
77           << "): '" << #expr << "'";                                     \
78       return port::Status(port::error::UNKNOWN, oss.str().c_str());      \
79     }                                                                    \
80   } while (false)
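// Illustrative usage sketch: a hypothetical helper (not part of this file)
// showing how a function returning port::Status wraps a raw cuDNN call with
// the macro above. GetCudnnProperty() further down shows the same pattern as
// it actually exists in this file.
//
//   port::Status SetCudnnStream(cudnnHandle_t handle, cudaStream_t stream) {
//     RETURN_IF_CUDNN_ERROR(cudnnSetStream(handle, stream));
//     return port::Status::OK();
//   }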
81 
82 // Converts (via narrowing) a WideT value to a NarrowT value, and checks that
83 // the value is unchanged by the conversion.
84 template <typename WideT, typename NarrowT>
85 NarrowT CheckedNarrowing(const WideT& wide) {
86   NarrowT narrow = wide;
87   CHECK_EQ(narrow, wide)
88       << "checked narrowing failed; values not equal post-conversion";
89   return narrow;
90 }
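// Worked example: CheckedNarrowing<int64, int>(2) returns 2, while narrowing a
// value that does not fit in the destination type (e.g. int64{1} << 40 to int)
// would fail the CHECK above.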
91 
92 string ToString(cudnnStatus_t status) {
93   switch (status) {
94     case CUDNN_STATUS_SUCCESS:
95       return "CUDNN_STATUS_SUCCESS";
96     case CUDNN_STATUS_NOT_INITIALIZED:
97       return "CUDNN_STATUS_NOT_INITIALIZED";
98     case CUDNN_STATUS_ALLOC_FAILED:
99       return "CUDNN_STATUS_ALLOC_FAILED";
100     case CUDNN_STATUS_BAD_PARAM:
101       return "CUDNN_STATUS_BAD_PARAM";
102     case CUDNN_STATUS_INTERNAL_ERROR:
103       return "CUDNN_STATUS_INTERNAL_ERROR";
104     case CUDNN_STATUS_INVALID_VALUE:
105       return "CUDNN_STATUS_INVALID_VALUE";
106     case CUDNN_STATUS_ARCH_MISMATCH:
107       return "CUDNN_STATUS_ARCH_MISMATCH";
108     case CUDNN_STATUS_MAPPING_ERROR:
109       return "CUDNN_STATUS_MAPPING_ERROR";
110     case CUDNN_STATUS_EXECUTION_FAILED:
111       return "CUDNN_STATUS_EXECUTION_FAILED";
112     case CUDNN_STATUS_NOT_SUPPORTED:
113       return "CUDNN_STATUS_NOT_SUPPORTED";
114     case CUDNN_STATUS_LICENSE_ERROR:
115       return "CUDNN_STATUS_LICENSE_ERROR";
116     case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
117       return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
118 #if CUDNN_VERSION >= 7000
119     case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
120       return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
121     case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
122       return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
123 #endif
124     default:
125       return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
126                           ">");
127   }
128 }
129 
130 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
131 //
132 // See CudnnAccess::GetHandle() for details.
133 class CudnnHandle {
134  public:
135   // Takes ownership of the executor context and the lock to access cuDNN
136   // using handle.
137   CudnnHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock,
138               cudnnHandle_t handle)
139       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
140 
141   // Returns the cuDNN handle. Pass it directly to cuDNN APIs; do not keep a
142   // copy.
143   cudnnHandle_t handle() const { return handle_; }
144 
145  private:
146   gpu::ScopedActivateExecutorContext context_;
147   mutex_lock lock_;
148   cudnnHandle_t handle_;  // Not owned.
149 };
150 
151 }  // namespace
152 
153 // Wraps a cuDNN handle and provides access to it through CudnnHandle
154 // instances, which also lock a mutex, acquire the CUDA context, and set
155 // the stream that cuDNN should use to enqueue any work.
156 //
157 // Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
158 class CudnnAccess {
159  public:
160   // Takes ownership of the handle.
161   explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
162 
163   ~CudnnAccess() {
164     mutex_lock lock(mutex_);
165     cudnnDestroy(handle_);
166   }
167 
168   // Creates a CudnnHandle instance for stream.
169   //
170   // cuDNN API calls using the same handle instance need to be serialized
171   // across threads. This is guaranteed by CudnnHandle instances locking the
172   // mutex owned by this class.
173   //
174   // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
175   // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
176   // to use the provided stream.
177   //
178   // The stream argument may be null, which translates to the legacy default
179   // stream. See
180   // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
181   // The legacy default stream synchronizes with all other streams, so it is
182   // a bad idea (performance-wise) to call any cuDNN APIs that enqueue work
183   // in it.
184   CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
185     mutex_lock lock(mutex_);
186     gpu::ScopedActivateExecutorContext context(executor);
187     CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
188     const auto status = cudnnSetStream(handle_, cu_stream);
189     CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
190     return CudnnHandle(std::move(context), std::move(lock), handle_);
191   }
192 
193  private:
194   // Guards the enqueueing of cuDNN operations via the handle_ below.
195   mutex mutex_;
196 
197   // cuDNN library handle.
198   cudnnHandle_t handle_ GUARDED_BY(mutex_);  // Owned.
199 };
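// Illustrative usage sketch with hypothetical 'executor' and 'stream'
// variables (not from this file): callers obtain a scoped handle and make
// cuDNN calls while the lock is held, e.g.
//
//   auto cudnn = cudnn_->GetHandle(executor, stream);
//   size_t state_size = 0;
//   RETURN_IF_CUDNN_ERROR(
//       cudnnDropoutGetStatesSize(cudnn.handle(), &state_size));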
200 
201 namespace {
202 
203 // A helper function to return the internal compute type for
204 // RNNs in cudnn.
205 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
206 
207 cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
208   cudnnConvolutionFwdAlgo_t algo =
209       cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
210   switch (algo) {
211     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
212     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
213     case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
214     case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
215     case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
216     case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
217     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
218     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
219       return algo;
220     default:
221       LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
222                  << algorithm.algo_id();
223   }
224 }
225 
226 cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
227     dnn::AlgorithmDesc algorithm) {
228   cudnnConvolutionBwdDataAlgo_t algo =
229       cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
230   switch (algo) {
231     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
232     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
233     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
234     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
235     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
236     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
237       return algo;
238     default:
239       LOG(FATAL)
240           << "Unsupported Cudnn convolution backward algorithm for data: "
241           << algorithm.algo_id();
242   }
243 }
244 
245 cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
246     dnn::AlgorithmDesc algorithm) {
247   cudnnConvolutionBwdFilterAlgo_t algo =
248       cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
249   switch (algo) {
250     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
251     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
252     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
253     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
254     // Based on cudnn.h, the following is not implemented.
255     // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
256     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
257       return algo;
258     // Produces incorrect results for some shapes. Disabled for now, see
259     // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
260     // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
261     default:
262       LOG(FATAL)
263           << "Unsupported Cudnn convolution backward algorithm for filter: "
264           << algorithm.algo_id();
265   }
266 }
267 
268 port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
269   int value;
270   RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
271   return value;
272 }
273 
274 cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional<dnn::AlgorithmDesc> algorithm) {
275   if (!algorithm.has_value()) {
276     return CUDNN_RNN_ALGO_STANDARD;
277   }
278   cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
279   switch (algo) {
280     case CUDNN_RNN_ALGO_STANDARD:
281     case CUDNN_RNN_ALGO_PERSIST_STATIC:
282     case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
283       return algo;
284     default:
285       LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
286   }
287 }
288 
289 port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
290   SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
291   SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
292   SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
293   return port::Status::OK();
294 }
295 
296 }  // namespace
297 
298 CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}
299 
300 port::Status CudnnSupport::Init() {
301   ScopedActivateExecutorContext context(parent_);
302   cudnnHandle_t cudnn_handle = nullptr;
303   const auto status = cudnnCreate(&cudnn_handle);
304   if (status == CUDNN_STATUS_SUCCESS) {
305     CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
306 
307     CudnnVersion loaded_version;
308     TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
309     if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
310       const tensorflow::string error = absl::StrCat(
311           "Loaded runtime CuDNN library: ", loaded_version.ToString(),
312           " but source was compiled with: ", source_version.ToString(),
313           ".  CuDNN library major and minor version needs to match or have "
314           "higher minor version in case of CuDNN 7.0 or later version. If "
315           "using a binary install, upgrade your CuDNN library.  If building "
316           "from sources, make sure the library loaded at runtime is "
317           "compatible "
318           "with the version specified during compile configuration.");
319       LOG(ERROR) << error;
320       cudnnDestroy(cudnn_handle);
321       return port::Status(port::error::INTERNAL, error);
322     }
323 
324     cudnn_.reset(new CudnnAccess(cudnn_handle));
325     return port::Status::OK();
326   }
327 
328   CHECK_EQ(cudnn_handle, nullptr);
329   LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
330   if (status == CUDNN_STATUS_NOT_INITIALIZED) {
331     auto result = gpu::Diagnostician::FindKernelDriverVersion();
332     if (!result.ok()) {
333       LOG(ERROR) << "Error retrieving driver version: "
334                  << cuda::DriverVersionStatusToString(result);
335     } else {
336       const auto& version = result.ValueOrDie();
337       LOG(ERROR) << "Possibly insufficient driver version: "
338                  << cuda::DriverVersionToString(version);
339     }
340   }
341 
342   return port::Status(port::error::INTERNAL,
343                       absl::StrCat("cudnn library could not create a handle: ",
344                                    ToString(status)));
345 }
346 
347 port::StatusOr<perftools::gputools::dnn::VersionInfo>
348 CudnnSupport::GetVersion() {
349   CudnnVersion version;
350   TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
351   return perftools::gputools::dnn::VersionInfo(
352       version.major_version, version.minor_version, version.patch_level);
353 }
354 
355 namespace {
356 
357 // Deleter functors for cuDNN types that need to be deleted.
358 struct TensorDescriptorDeleter {
359   void operator()(cudnnTensorDescriptor_t descriptor) const {
360     CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
361   }
362 };
363 #if CUDNN_VERSION >= 7201
364 struct RNNDataDescriptorDeleter {
365   void operator()(cudnnRNNDataDescriptor_t descriptor) const {
366     CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
367   }
368 };
369 #endif
370 struct FilterDescriptorDeleter {
371   void operator()(cudnnFilterDescriptor_t descriptor) const {
372     CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
373   }
374 };
375 struct ConvolutionDescriptorDeleter {
376   void operator()(cudnnConvolutionDescriptor_t descriptor) const {
377     CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
378   }
379 };
380 struct PoolingDescriptorDeleter {
381   void operator()(cudnnPoolingDescriptor_t descriptor) const {
382     CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
383   }
384 };
385 struct LrnDescriptorDeleter {
386   void operator()(cudnnLRNDescriptor_t descriptor) const {
387     CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
388   }
389 };
390 
391 struct ActivationDescriptorDeleter {
392   void operator()(cudnnActivationDescriptor_t descriptor) const {
393     CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
394   }
395 };
396 struct DropoutDescriptorDeleter {
397   void operator()(cudnnDropoutDescriptor_t descriptor) const {
398     CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
399   }
400 };
401 struct RnnDescriptorDeleter {
402   void operator()(cudnnRNNDescriptor_t descriptor) const {
403     CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
404   }
405 };
406 struct PersistentRnnPlanDeleter {
407   void operator()(cudnnPersistentRNNPlan_t plan) const {
408     CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
409   }
410 };
411 
412 // RAII wrappers for cuDNN types.
413 using TensorDescriptor =
414     std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
415 #if CUDNN_VERSION >= 7201
416 using RNNDataDescriptor =
417     std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
418 #endif
419 using FilterDescriptor =
420     std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
421 using ConvolutionDescriptor =
422     std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
423 using PoolingDescriptor =
424     std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
425 using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
426 using ActivationDescriptor =
427     std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
428 using DropoutDescriptor =
429     std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
430 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
431 using PersistentRnnPlan =
432     std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
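// Lifetime sketch: each alias above is a std::unique_ptr with a custom
// deleter, so a descriptor is destroyed automatically when it goes out of
// scope, e.g.
//
//   {
//     TensorDescriptor desc = CreateTensorDescriptor();
//     // ... pass desc.get() to cuDNN calls ...
//   }  // cudnnDestroyTensorDescriptor(desc.get()) runs here via the deleter.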
433 
434 // Factory methods for cuDNN types.
435 TensorDescriptor CreateTensorDescriptor() {
436   cudnnTensorDescriptor_t result;
437   CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
438   return TensorDescriptor(result);
439 }
440 #if CUDNN_VERSION >= 7201
441 RNNDataDescriptor CreateRNNDataDescriptor() {
442   cudnnRNNDataDescriptor_t result;
443   CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
444   return RNNDataDescriptor(result);
445 }
446 #endif
447 FilterDescriptor CreateFilterDescriptor() {
448   cudnnFilterDescriptor_t result;
449   CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
450   return FilterDescriptor(result);
451 }
452 ConvolutionDescriptor CreateConvolutionDescriptor() {
453   cudnnConvolutionDescriptor_t result;
454   CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
455   return ConvolutionDescriptor(result);
456 }
457 PoolingDescriptor CreatePoolingDescriptor() {
458   cudnnPoolingDescriptor_t result;
459   CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
460   return PoolingDescriptor(result);
461 }
462 LrnDescriptor CreateLrnDescriptor() {
463   cudnnLRNDescriptor_t result;
464   CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
465   return LrnDescriptor(result);
466 }
467 ActivationDescriptor CreateActivationDescriptor() {
468   cudnnActivationDescriptor_t result;
469   CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
470   return ActivationDescriptor(result);
471 }
472 DropoutDescriptor CreateDropoutDescriptor() {
473   cudnnDropoutDescriptor_t result;
474   CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
475   return DropoutDescriptor(result);
476 }
477 RnnDescriptor CreateRnnDescriptor() {
478   cudnnRNNDescriptor_t result;
479   CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
480   return RnnDescriptor(result);
481 }
482 
483 port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
484     cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
485   cudnnPersistentRNNPlan_t result;
486   RETURN_IF_CUDNN_ERROR(
487       cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
488   return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
489 }
490 
491 // Turns a BatchDescriptor structure into a cudnn tensor handle within a
492 // scope.
493 class CudnnTensorDescriptor {
494  public:
495   CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
496                         cudnnDataType_t elem_type)
497       : handle_(CreateTensorDescriptor()) {
498     switch (batch_descriptor.layout()) {
499       case dnn::DataLayout::kBatchYXDepth:
500       case dnn::DataLayout::kBatchDepthYX: {
501         const int nd = batch_descriptor.ndims() + 2;
502         // cuDNN requires the strides and dims to be ordered as BDYX.
503         std::vector<int64> strides64 =
504             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
505         std::vector<int64> dims64 =
506             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
507 
508         // cuDNN requires arrays of ints.
509         std::vector<int> strides(nd);
510         std::vector<int> dims(nd);
511         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
512                        &CheckedNarrowing<int64, int>);
513         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
514                        &CheckedNarrowing<int64, int>);
515         CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
516                                                   dims.data(), strides.data()))
517             << "batch_descriptor: " << batch_descriptor.ToString();
518       } break;
519       case dnn::DataLayout::kBatchDepthYX4: {
520         CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
521             handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
522             batch_descriptor.count(), batch_descriptor.feature_map_count(),
523             batch_descriptor.height(), batch_descriptor.width()))
524             << "batch_descriptor: " << batch_descriptor.ToString();
525       } break;
526       default:
527         LOG(FATAL) << "Unsupported tensor format "
528                    << DataLayoutString(batch_descriptor.layout());
529         break;
530     }
531   }
532 
533   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
534 
535  private:
536   TensorDescriptor handle_;
537 
538   SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
539 };
540 
541 // Turns a FilterDescriptor structure into a cudnn filter handle within a
542 // scope.
543 class CudnnFilterDescriptor {
544  public:
545   CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
546                         cudnnDataType_t elem_type)
547       : handle_(CreateFilterDescriptor()) {
548     // TODO(b/23032134): Even if the filter layout is not supported,
549     // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
550     // it does not take layout as an input. Maybe force cuDNN by giving wrong
551     // inputs intentionally?
552     cudnnTensorFormat_t format;
553     switch (filter_descriptor.layout()) {
554       case dnn::FilterLayout::kOutputInputYX:
555         format = CUDNN_TENSOR_NCHW;
556         break;
557       case dnn::FilterLayout::kOutputYXInput:
558         format = CUDNN_TENSOR_NHWC;
559         break;
560       case dnn::FilterLayout::kOutputInputYX4:
561         format = CUDNN_TENSOR_NCHW_VECT_C;
562         break;
563       default:
564         LOG(FATAL) << "Unsupported filter format "
565                    << FilterLayoutString(filter_descriptor.layout());
566         break;
567     }
568 
569     std::vector<int> dims(2 + filter_descriptor.ndims());
570     dims[0] = filter_descriptor.output_feature_map_count();
571     dims[1] = filter_descriptor.input_feature_map_count();
572     absl::Span<const int64> spatial_dims =
573         filter_descriptor.input_filter_dims();
574     std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
575 
576     CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
577                                               dims.size(), dims.data()));
578   }
579 
580   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
581 
582  private:
583   FilterDescriptor handle_;  // Owned.
584 
585   SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
586 };
587 
588 // A helper function to decide whether to enable the TENSOR_OP_MATH math type
589 bool TensorOpMathEnabled() {
590   static bool is_enabled = [] {
591     bool is_disabled = false;
592     TF_CHECK_OK(
593         tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
594                                        /*default_val=*/false, &is_disabled));
595     return !is_disabled;
596   }();
597   return is_enabled;
598 }
599 
600 // A helper function to decide whether to enable the TENSOR_OP_MATH math type
601 // for RNNs.
602 bool RnnTensorOpMathEnabled() {
603   static bool is_enabled = [] {
604     bool is_disabled = false;
605     TF_CHECK_OK(
606         tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH",
607                                        /*default_val=*/false, &is_disabled));
608     return !is_disabled;
609   }();
610   return is_enabled;
611 }
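// Note: both toggles above are evaluated once and cached in a static, so the
// corresponding environment variables must be set before these functions are
// first called, e.g. (hypothetical call site):
//
//   setenv("TF_DISABLE_CUDNN_TENSOR_OP_MATH", "1", /*overwrite=*/1);
//   setenv("TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH", "1", /*overwrite=*/1);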
612 
613 // A helper function to decide whether to use
614 // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster for
615 // some tasks because an optimized path may be selected for the CUDNN_DATA_FLOAT
616 // and CUDNN_DATA_HALF data types on compute capability 6.0 or higher. The
617 // reason we set it to false by default is that this mode may use a scaled
618 // atomic integer reduction that can cause numerical overflow for certain
619 // input data ranges.
620 // TODO(yangzihao): Use autotune to choose between this mode and
621 // CUDNN_BATCHNORM_SPATIAL mode.
622 bool BatchnormSpatialPersistentEnabled() {
623   static bool is_enabled = [] {
624     bool is_enabled = false;
625     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
626         "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
627         /*default_val=*/false, &is_enabled));
628     return is_enabled;
629   }();
630   return is_enabled;
631 }
632 
633 // A helper function to decide whether to enable deterministic functionality.
634 bool RequireDeterminism() {
635   static bool is_enabled = [] {
636     bool is_enabled = false;
637     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
638                                                /*default_val=*/false,
639                                                &is_enabled));
640     return is_enabled;
641   }();
642   return is_enabled;
643 }
644 
645 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
646 // within a scope.
647 class CudnnConvolutionDescriptor {
648  public:
649   CudnnConvolutionDescriptor(
650       const dnn::ConvolutionDescriptor& convolution_descriptor,
651       cudnnDataType_t data_type)
652       : handle_(CreateConvolutionDescriptor()) {
653     absl::Span<const int64> strides64 = convolution_descriptor.strides();
654     absl::Span<const int64> padding64 = convolution_descriptor.padding();
655     absl::Span<const int64> dilations64 = convolution_descriptor.dilations();
656     CHECK_NE(convolution_descriptor.pad_alignment(),
657              dnn::PadAlignment::kTensorFlowPadding)
658         << "TensorFlow padding alignment is not supported.";
659 
660     // cuDNN requires arrays of ints.
661     std::vector<int> strides(convolution_descriptor.ndims());
662     std::vector<int> padding(convolution_descriptor.ndims());
663     std::vector<int> dilations(convolution_descriptor.ndims());
664     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
665                    &CheckedNarrowing<int64, int>);
666     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
667                    &CheckedNarrowing<int64, int>);
668     // TODO(yangzihao): Test with negative dilation to make sure that cudnn
669     // doesn't crash.
670     std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
671                    &CheckedNarrowing<int64, int>);
672 
673     CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
674         handle_.get(), convolution_descriptor.ndims(), padding.data(),
675         strides.data(), dilations.data(),
676         convolution_descriptor.convolution_not_crosscorr()
677             ? CUDNN_CONVOLUTION
678             : CUDNN_CROSS_CORRELATION,
679         data_type));
680 
681     // NOTE(benbarsdell): This only applies if tensor op math is enabled
682     //                      and algo selection is set to Default.
683     this->set_use_tensor_op_math(true);
684 
685 #if CUDNN_MAJOR >= 7
686     VLOG(2) << "Requesting grouped convolution: "
687             << convolution_descriptor.group_count();
688     CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
689         handle_.get(), convolution_descriptor.group_count()));
690 #else
691     CHECK_EQ(convolution_descriptor.group_count(), 1)
692         << "Requested grouped convolution for cuDNN version < 7";
693 #endif
694   }
695 
696   void set_use_tensor_op_math(bool use_tensor_op_math) const {
697 #if CUDNN_VERSION >= 7000
698     cudnnMathType_t math_type =
699         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
700     if (TensorOpMathEnabled()) {
701       CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
702     }
703 #endif
704   }
705 
706   cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
707 
708  private:
709   ConvolutionDescriptor handle_;  // Owned.
710 
711   SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
712 };
713 
714 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
715 // within a scope.
716 class CudnnPoolingDescriptor {
717  public:
718   explicit CudnnPoolingDescriptor(
719       const dnn::PoolingDescriptor& pooling_descriptor)
720       : handle_(CreatePoolingDescriptor()) {
721     absl::Span<const int64> strides64 = pooling_descriptor.strides();
722     absl::Span<const int64> padding64 = pooling_descriptor.padding();
723     absl::Span<const int64> shape64 = pooling_descriptor.window();
724 
725     const int nd = pooling_descriptor.ndims();
726     std::vector<int> shape(nd);
727     std::vector<int> padding(nd);
728     std::vector<int> strides(nd);
729     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
730                    &CheckedNarrowing<int64, int>);
731     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
732                    &CheckedNarrowing<int64, int>);
733     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
734                    &CheckedNarrowing<int64, int>);
735     bool propagate_nans = pooling_descriptor.propagate_nans();
736     const auto cudnn_max_pooling_mode = RequireDeterminism()
737                                             ? CUDNN_POOLING_MAX_DETERMINISTIC
738                                             : CUDNN_POOLING_MAX;
739     CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
740         handle_.get(),
741         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
742              ? cudnn_max_pooling_mode
743              : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
744         propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
745         shape.data(), padding.data(), strides.data()));
746   }
747 
748   cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
749 
750  private:
751   PoolingDescriptor handle_;  // Owned.
752 
753   SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
754 };
755 
756 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
757 class CudnnNormalizeDescriptor {
758  public:
759   explicit CudnnNormalizeDescriptor(
760       const dnn::NormalizeDescriptor& normalize_descriptor)
761       : handle_(CreateLrnDescriptor()) {
762     // The range specifies that the indices in the closed range
763     // [i - range, i + range] should be included in the normalization for index
764     // i. The lrnN value is the total number of elements in the range, so
765     // lrnN = 2*range + 1.
766     unsigned lrnN = 2 * normalize_descriptor.range() + 1;
767 
768     // Note that SE defines the normalization operation as
769     //
770     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
771     //
772     // but cuDNN defines it as
773     //
774     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
775     //
776     // i.e. there is a factor of n difference between the meaning of the alphas
777     // in the two contexts. The cuDNN alpha is n times the SE alpha.
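    // For example, range = 2 gives lrnN = 2 * 2 + 1 = 5, so an SE alpha of 0.1
    // is passed to cuDNN as lrnAlpha = 5 * 0.1 = 0.5.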
778     double lrnAlpha = lrnN * normalize_descriptor.alpha();
779 
780     double lrnBeta = normalize_descriptor.beta();
781     double lrnK = normalize_descriptor.bias();
782     CHECK_CUDNN_OK(
783         cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
784   }
785 
786   cudnnLRNDescriptor_t handle() const { return handle_.get(); }
787 
788  private:
789   LrnDescriptor handle_;  // Owned.
790 
791   SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
792 };
793 
794 // Turns an ActivationDescriptor structure into a cudnn activation
795 // descriptor handle within a scope.
796 class CudnnActivationDescriptor {
797  public:
798   CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
799                             cudnnNanPropagation_t nan_propagation,
800                             double value_max)
801       : handle_(CreateActivationDescriptor()) {
802     double relu_ceiling = 0.0;
803     cudnnActivationMode_t mode;
804     switch (activation_mode) {
805 #if CUDNN_VERSION >= 7100
806       case dnn::ActivationMode::kNone:
807         mode = CUDNN_ACTIVATION_IDENTITY;
808         break;
809 #endif
810       case dnn::ActivationMode::kRelu6:
811         relu_ceiling = 6.0;
812         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
813         break;
814       case dnn::ActivationMode::kReluX:
815         relu_ceiling = value_max;
816         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
817         break;
818       case dnn::ActivationMode::kRelu:
819         mode = CUDNN_ACTIVATION_RELU;
820         break;
821       case dnn::ActivationMode::kSigmoid:
822         mode = CUDNN_ACTIVATION_SIGMOID;
823         break;
824       case dnn::ActivationMode::kTanh:
825         mode = CUDNN_ACTIVATION_TANH;
826         break;
827       default:
828         LOG(FATAL) << "unrecognized activation mode: "
829                    << static_cast<int>(activation_mode);
830     }
831 
832     CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
833                                                 nan_propagation, relu_ceiling));
834   }
835 
836   cudnnActivationDescriptor_t handle() const { return handle_.get(); }
837 
838  private:
839   ActivationDescriptor handle_;  // Owned.
840 
841   SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
842 };
843 
844 cudnnDataType_t ToCudnnDataType(
845     dnn::DataType data_type,
846     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
847   switch (data_type) {
848     case dnn::DataType::kFloat:
849       return CUDNN_DATA_FLOAT;
850     case dnn::DataType::kDouble:
851       return CUDNN_DATA_DOUBLE;
852     case dnn::DataType::kHalf:
853       return CUDNN_DATA_HALF;
854     case dnn::DataType::kInt8:
855       return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
856                                                             : CUDNN_DATA_INT8;
857     case dnn::DataType::kInt32:
858       return CUDNN_DATA_INT32;
859     default:
860       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
861   }
862 }
863 
864 cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
865                                 dnn::FilterLayout filter_layout) {
866   if (data_type == dnn::DataType::kInt8 &&
867       filter_layout == dnn::FilterLayout::kOutputInputYX4) {
868     return CUDNN_DATA_INT8x4;
869   }
870   return ToCudnnDataType(data_type);
871 }
872 
873 template <typename T>
874 cudnnDataType_t GetCudnnDataType(
875     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
876   return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
877 }
878 
879 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
880   switch (input_mode) {
881     case dnn::RnnInputMode::kRnnLinearSkip:
882     case dnn::RnnInputMode::kRnnSkipInput:
883       return static_cast<cudnnRNNInputMode_t>(input_mode);
884     default:
885       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
886   }
887 }
888 
889 cudnnDirectionMode_t ToCudnnRnnDirectionMode(
890     dnn::RnnDirectionMode direction_mode) {
891   switch (direction_mode) {
892     case dnn::RnnDirectionMode::kRnnUnidirectional:
893     case dnn::RnnDirectionMode::kRnnBidirectional:
894       return static_cast<cudnnDirectionMode_t>(direction_mode);
895     default:
896       LOG(FATAL) << "Invalid RNN direction mode: "
897                  << static_cast<int>(direction_mode);
898   }
899 }
900 
901 cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
902   switch (rnn_mode) {
903     case dnn::RnnMode::kRnnRelu:
904     case dnn::RnnMode::kRnnTanh:
905     case dnn::RnnMode::kRnnLstm:
906     case dnn::RnnMode::kRnnGru:
907       return static_cast<cudnnRNNMode_t>(rnn_mode);
908     default:
909       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
910   }
911 }
912 
913 int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
914   switch (data_type) {
915     case CUDNN_DATA_FLOAT:
916       return sizeof(float);
917     case CUDNN_DATA_DOUBLE:
918       return sizeof(double);
919     case CUDNN_DATA_HALF:
920       return sizeof(Eigen::half);
921     default:
922       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
923   }
924 }
925 
926 class CudnnDropoutDescriptor {
927   explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
928       : handle_(std::move(handle)) {}
929 
930  public:
931   CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
932 
933   static port::StatusOr<CudnnDropoutDescriptor> Create(
934       const CudnnHandle& cudnn, float dropout, uint64 seed,
935       ScratchAllocator* state_allocator) {
936     DropoutDescriptor handle = CreateDropoutDescriptor();
937 
938     if (dropout == 0.0f) {
939       // Return 'empty' dropout descriptor.
940       return CudnnDropoutDescriptor(std::move(handle));
941     }
942 
943     DeviceMemory<uint8> state_memory;
944     if (state_allocator) {
945       size_t state_sizes_in_bytes = 0;
946       RETURN_IF_CUDNN_ERROR(
947           cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
948       SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes(
949                                             nullptr, state_sizes_in_bytes));
950     }
951     RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
952         handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
953         state_memory.size(), seed));
954 
955     return CudnnDropoutDescriptor(std::move(handle));
956   }
957 
958   cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
959 
960  private:
961   DropoutDescriptor handle_;  // Owned.
962   SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
963 };
964 
965 class CudnnRnnParamsDescriptor {
966   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
967 
968   CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes,
969                            ParamsRegions weights, ParamsRegions biases)
970       : handle_(std::move(handle)),
971         params_size_in_bytes_(params_size_in_bytes),
972         weights_(std::move(weights)),
973         biases_(std::move(biases)) {}
974 
975  public:
976   CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
977 
978   static port::StatusOr<CudnnRnnParamsDescriptor> Create(
979       const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
980       cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
981       cudnnDirectionMode_t direction_mode, int num_layers);
982 
983   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
984   int64 params_size_in_bytes() const { return params_size_in_bytes_; }
985   ParamsRegions params_weights() const {
986     return weights_;
987   }
988   ParamsRegions params_biases() const {
989     return biases_;
990   }
991 
992  private:
993   FilterDescriptor handle_;
994   int64 params_size_in_bytes_;
995   ParamsRegions weights_;
996   ParamsRegions biases_;
997   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
998 };
999 
1000 }  // namespace
1001 
1002 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
1003   CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1004                      PersistentRnnPlan rnn_plan, int num_layers,
1005                      int hidden_size, int input_size, int batch_size,
1006                      cudnnRNNInputMode_t input_mode,
1007                      cudnnDirectionMode_t direction_mode,
1008                      cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1009                      cudnnDataType_t compute_type,
1010                      const dnn::AlgorithmConfig& algorithm_config,
1011                      CudnnDropoutDescriptor dropout_desc,
1012                      CudnnRnnParamsDescriptor params_desc)
1013       : rnn_desc_(std::move(rnn_desc)),
1014         rnn_plan_(std::move(rnn_plan)),
1015         num_layers_(num_layers),
1016         hidden_size_(hidden_size),
1017         input_size_(input_size),
1018         batch_size_(batch_size),
1019         rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1020         input_mode_(input_mode),
1021         direction_mode_(direction_mode),
1022         rnn_mode_(rnn_mode),
1023         data_type_(data_type),
1024         compute_type_(compute_type),
1025         algorithm_config_(algorithm_config),
1026         dropout_desc_(std::move(dropout_desc)),
1027         params_desc_(std::move(params_desc)) {}
1028 
1029  public:
1030   CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1031 
1032   static port::StatusOr<CudnnRnnDescriptor> Create(
1033       const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1034       int batch_size, cudnnRNNInputMode_t input_mode,
1035       cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1036       cudnnDataType_t data_type, cudnnDataType_t compute_type,
1037       const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1038       ScratchAllocator* state_allocator) {
1039     SE_ASSIGN_OR_RETURN(
1040         CudnnDropoutDescriptor dropout_desc,
1041         CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1042 
1043     gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1044     cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1045 
1046     // TODO: allow the user to choose an algorithm.
1047     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1048         cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size,
1049         /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(),
1050         /*inputMode=*/input_mode, /*direction=*/direction_mode,
1051         /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1052         /*dataType=*/compute_type));
1053 
1054     // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
1055     // But in the future if these APIs are used to process full length arrays,
1056     // we need to distinguish when to set it.
1057 #if CUDNN_VERSION >= 7201
1058     RETURN_IF_CUDNN_ERROR(
1059         cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1060 #endif
1061 
1062     port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1063     PersistentRnnPlan rnn_plan;
1064     if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1065       CHECK_GE(batch_size, 0);
1066       rnn_plan_wrapper =
1067           CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1068       if (!rnn_plan_wrapper.ok()) {
1069         return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1070       } else {
1071         rnn_plan = rnn_plan_wrapper.ConsumeValueOrDie();
1072         RETURN_IF_CUDNN_ERROR(
1073             cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1074       }
1075     }
1076 
1077     // Create the params handle.
1078     SE_ASSIGN_OR_RETURN(auto params_desc,
1079                         CudnnRnnParamsDescriptor::Create(
1080                             cudnn, input_size, data_type, rnn_desc.get(),
1081                             rnn_mode, direction_mode, num_layers));
1082 
1083 #if CUDNN_VERSION >= 7000
1084     // Require explicit algorithm config to enable tensor cores. Some configs
1085     // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1086     // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1087     // We can only reasonably expect the user to handle the subsequent failure
1088     // in profile mode, which is run with algorithms returned from
1089     // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1090     // use tensor ops).
1091     if (RnnTensorOpMathEnabled() && algorithm_config.algorithm().has_value()) {
1092       cudnnMathType_t math_type =
1093           algorithm_config.algorithm()->tensor_ops_enabled()
1094               ? CUDNN_TENSOR_OP_MATH
1095               : CUDNN_DEFAULT_MATH;
1096       CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1097     }
1098 #endif
1099 
1100     return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1101                               num_layers, hidden_size, input_size, batch_size,
1102                               input_mode, direction_mode, rnn_mode, data_type,
1103                               compute_type, algorithm_config,
1104                               std::move(dropout_desc), std::move(params_desc));
1105   }
1106 
1107   cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
1108   int num_layers() const { return num_layers_; }
1109   int hidden_size() const { return hidden_size_; }
1110   int input_size() const { return input_size_; }
1111   int batch_size() const { return batch_size_; }
1112   cudnnRNNInputMode_t input_mode() const { return input_mode_; }
1113   cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
1114   cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
1115   cudnnDataType_t data_type() const { return data_type_; }
1116   cudnnDataType_t compute_type() const { return compute_type_; }
1117   const dnn::AlgorithmConfig& algorithm_config() const {
1118     return algorithm_config_;
1119   }
1120   int64 ParamsSizeInBytes() const override {
1121     return params_desc_.params_size_in_bytes();
1122   }
1123   cudnnFilterDescriptor_t params_handle() const {
1124     return params_desc_.handle();
1125   }
1126   ParamsRegions ParamsWeightRegions() const override {
1127     return params_desc_.params_weights();
1128   }
1129   ParamsRegions ParamsBiasRegions() const override {
1130     return params_desc_.params_biases();
1131   }
1132 
1133  private:
1134   gpu::RnnDescriptor rnn_desc_;
1135   PersistentRnnPlan rnn_plan_;
1136   int num_layers_;
1137   int hidden_size_;
1138   int input_size_;
1139   // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1140   // algorithm.
1141   int batch_size_;
1142   cudnnRNNAlgo_t rnn_algo_;
1143   cudnnRNNInputMode_t input_mode_;
1144   cudnnDirectionMode_t direction_mode_;
1145   cudnnRNNMode_t rnn_mode_;
1146   cudnnDataType_t data_type_;
1147   cudnnDataType_t compute_type_;
1148   dnn::AlgorithmConfig algorithm_config_;
1149   CudnnDropoutDescriptor dropout_desc_;
1150   CudnnRnnParamsDescriptor params_desc_;
1151   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1152 };
1153 
1154 namespace {
1155 
1156 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1157     const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1158     cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1159     cudnnDirectionMode_t direction_mode, int num_layers) {
1160   // Query the params size.
1161   TensorDescriptor input_desc = CreateTensorDescriptor();
1162   int tensor_dims[] = {1, input_size, 1};
1163   int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1164   RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1165       /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1166       /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1167       /*dimA=*/tensor_dims,
1168       /*strideA=*/strides));
1169 
1170   size_t params_size = 0;
1171   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1172       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1173       /*xDesc=*/input_desc.get(), /*sizeInBytes=*/&params_size,
1174       /*dataType=*/data_type));
1175   int64 params_size_in_bytes = static_cast<int64>(params_size);
1176 
1177   FilterDescriptor filter_desc = CreateFilterDescriptor();
1178   int filter_dims[] = {static_cast<int>(params_size_in_bytes), 1, 1};
1179   RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1180       /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1181       /*format=*/CUDNN_TENSOR_NCHW,
1182       /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1183       /*filterDimA=*/filter_dims));
1184 
1185   // Enumerate the weight and bias regions within the params buffer.
1186   int region_count_per_layer = [&] {
1187     switch (rnn_mode) {
1188       case CUDNN_RNN_RELU:
1189       case CUDNN_RNN_TANH:
1190         return 2;
1191       case CUDNN_LSTM:
1192         return 8;
1193       case CUDNN_GRU:
1194         return 6;
1195       default:
1196         LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1197         return 0;
1198     }
1199   }();
1200 
1201   FilterDescriptor region_desc_handle = CreateFilterDescriptor();
1202   const int layer_count =
1203       direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1204 
1205   ParamsRegions weights;
1206   ParamsRegions biases;
1207 
1208   for (int layer = 0; layer < layer_count; layer++) {
1209     for (int region = 0; region < region_count_per_layer; region++) {
1210       for (int type = 0; type < 2; type++) {
1211         void* offset = nullptr;
1212         RETURN_IF_CUDNN_ERROR(
1213             type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1214                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1215                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1216                             /*wDesc=*/filter_desc.get(),
1217                             /*w=*/nullptr, /*linLayerID=*/region,
1218                             /*linLayerMatDesc=*/region_desc_handle.get(),
1219                             /*linLayerMat or linLayerBias=*/&offset)
1220                       : cudnnGetRNNLinLayerBiasParams(
1221                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1222                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1223                             /*wDesc=*/filter_desc.get(),
1224                             /*w=*/nullptr, /*linLayerID=*/region,
1225                             /*linLayerMatDesc=*/region_desc_handle.get(),
1226                             /*linLayerMat or linLayerBias=*/&offset));
1227         int dims[] = {1, 1, 1};
1228         cudnnDataType_t data_type;
1229         cudnnTensorFormat_t tensor_format;
1230         int n_dims;
1231         RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1232             /*filterDesc=*/region_desc_handle.get(),
1233             /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1234             /*dataType=*/&data_type, /*format=*/&tensor_format,
1235             /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1236         int64 size =
1237             dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1238         dnn::RnnDescriptor::ParamsRegion region = {
1239             reinterpret_cast<int64>(offset), size};
1240         (type == 0 ? weights : biases).push_back(region);
1241       }
1242     }
1243   }
1244 
1245   return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1246                                   weights, biases);
1247 }
1248 
1249 }  // namespace
1250 
1251 class CudnnRnnSequenceTensorDescriptor
1252     : public dnn::RnnSequenceTensorDescriptor {
1253   CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1254                                    int batch_size, int data_size,
1255                                    cudnnDataType_t data_type,
1256 #if CUDNN_VERSION >= 7201
1257                                    RNNDataDescriptor data_handle,
1258 #endif
1259                                    TensorDescriptor handle)
1260       : max_seq_length_(max_seq_length),
1261         batch_size_(batch_size),
1262         data_size_(data_size),
1263         data_type_(data_type),
1264         handle_(std::move(handle)),
1265 #if CUDNN_VERSION >= 7201
1266         rnn_data_handle_(std::move(data_handle)),
1267 #endif
1268         handles_(max_seq_length, handle_.get()) {
1269   }
1270 
1271  public:
1272   CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1273       default;
1274 
1275   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1276       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1277       cudnnDataType_t data_type) {
1278     CHECK_GT(max_seq_length, 0);
1279     int dims[] = {batch_size, data_size, 1};
1280     int strides[] = {dims[1] * dims[2], dims[2], 1};
1281     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1282     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1283         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1284         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1285         /*strideA=*/strides));
1286     return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1287                                             data_size, data_type,
1288 #if CUDNN_VERSION >= 7201
1289                                             nullptr,
1290 #endif
1291                                             std::move(tensor_desc));
1292   }
1293 
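  // Variable-sequence-length variant (cuDNN 7.2.1+). In addition to the plain
  // tensor descriptor it builds a cudnnRNNDataDescriptor for the padded
  // ("unpacked") layout, where each batch entry carries its own length:
  // time_major selects SEQ_MAJOR_UNPACKED versus BATCH_MAJOR_UNPACKED, and
  // padded time steps are filled with zeros via paddingFill.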
1294   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1295       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1296       const absl::Span<const int>& seq_lengths, bool time_major,
1297       cudnnDataType_t data_type) {
1298 #if CUDNN_VERSION >= 7201
1299     CHECK_GT(max_seq_length, 0);
1300     int dims[] = {batch_size, data_size, 1};
1301     int strides[] = {dims[1] * dims[2], dims[2], 1};
1302     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1303     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1304         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1305         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1306         /*strideA=*/strides));
1307     const int* seq_lengths_array = seq_lengths.data();
1308     RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1309     float padding_fill = 0.0f;
1310     cudnnRNNDataLayout_t layout;
1311     if (time_major) {
1312       layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1313     } else {
1314       layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1315     }
1316     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1317         /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1318         /*layout=*/layout,
1319         /*maxSeqLength=*/max_seq_length,
1320         /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1321         /*seqLengthArray=*/seq_lengths_array,
1322         /*paddingFill*/ (void*)&padding_fill));
1323     return CudnnRnnSequenceTensorDescriptor(
1324         parent, max_seq_length, batch_size, data_size, data_type,
1325         std::move(data_desc), std::move(tensor_desc));
1326 #else
1327     return port::Status(port::error::INVALID_ARGUMENT,
1328                         "cudnnSetRNNDataDescriptor is not supported when "
1329                         "CUDNN_VERSION < 7.2.1");
1330 #endif
1331   }
1332 
1333   const cudnnTensorDescriptor_t* handles() const {
1334     return handles_.data();
1335   }
1336 #if CUDNN_VERSION >= 7201
1337   const cudnnRNNDataDescriptor_t data_handle() const {
1338     return rnn_data_handle_.get();
1339   }
1340 #endif
1341 
1342   int max_seq_length() const { return max_seq_length_; }
1343   int batch_size() const { return batch_size_; }
1344   int data_size() const { return data_size_; }
1345   bool is_var_seq_lengths() const {
1346 #if CUDNN_VERSION >= 7201
1347     return rnn_data_handle_ != nullptr;
1348 #else
1349     return false;
1350 #endif
1351   }
1352 
1353  private:
1354   int max_seq_length_;
1355   int batch_size_;
1356   int data_size_;
1357   cudnnDataType_t data_type_;
1358   TensorDescriptor handle_;
1359 #if CUDNN_VERSION >= 7201
1360   RNNDataDescriptor rnn_data_handle_;
1361 #endif
1362   std::vector<cudnnTensorDescriptor_t> handles_;  // Copies of handle_.
1363   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1364 };
1365 
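// Describes the hidden/cell state tensors (hx/cx/hy/cy) as a 3-D tensor of
// shape {num_layers, batch_size, data_size}. Callers are expected to pass a
// num_layers value that already includes the direction factor; this is what
// ExtractAndCheckRnnForward() below checks against
// model_dims.num_layers * model_dims.dir_count.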
1366 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1367  public:
1368   CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1369                                 int batch_size, int data_size,
1370                                 cudnnDataType_t data_type)
1371       : handle_(CreateTensorDescriptor()),
1372         num_layers_(num_layers),
1373         batch_size_(batch_size),
1374         data_size_(data_size),
1375         data_type_(data_type) {
1376     int dims[] = {num_layers, batch_size, data_size};
1377     int strides[] = {dims[1] * dims[2], dims[2], 1};
1378     CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1379         /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1380         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1381         /*strideA=*/strides));
1382   }
1383 
1384   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1385 
1386   int num_layers() const { return num_layers_; }
1387   int batch_size() const { return batch_size_; }
1388   int data_size() const { return data_size_; }
1389 
1390  private:
1391   TensorDescriptor handle_;
1392   int num_layers_;
1393   int batch_size_;
1394   int data_size_;
1395   cudnnDataType_t data_type_;
1396   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1397 };
1398 
1399 namespace {
1400 
1401 struct RnnModelDims {
1402   int num_layers = 0;
1403   int batch_size = 0;
1404   int max_seq_length = 0;
1405   int hidden_size = 0;
1406   int input_size = 0;
1407   int dir_count = 0;
1408 };
1409 
1410 template <class T>
1411 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1412     const CudnnRnnDescriptor& rnn_desc,
1413     const CudnnRnnSequenceTensorDescriptor& input_desc,
1414     const DeviceMemory<T>& input_data,
1415     const CudnnRnnStateTensorDescriptor& input_h_desc,
1416     const DeviceMemory<T>& input_h_data,
1417     const CudnnRnnStateTensorDescriptor& input_c_desc,
1418     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1419     const CudnnRnnSequenceTensorDescriptor& output_desc,
1420     const DeviceMemory<T>& output_data,
1421     const CudnnRnnStateTensorDescriptor& output_h_desc,
1422     const DeviceMemory<T>& output_h_data,
1423     const CudnnRnnStateTensorDescriptor& output_c_desc,
1424     const DeviceMemory<T>& output_c_data) {
1425   // Extract the model dimensions from the RNN and input descriptors.
1426   RnnModelDims model_dims;
1427   model_dims.num_layers = rnn_desc.num_layers();
1428   model_dims.batch_size = input_desc.batch_size();
1429   model_dims.max_seq_length = input_desc.max_seq_length();
1430   model_dims.hidden_size = rnn_desc.hidden_size();
1431   model_dims.input_size = input_desc.data_size();
1432   model_dims.dir_count =
1433       (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1434 
1435   // Check the remaining descriptors against these dimensions.
1436   if (!(input_h_desc.num_layers() ==
1437             model_dims.num_layers * model_dims.dir_count &&
1438         input_h_desc.batch_size() == model_dims.batch_size &&
1439         input_h_desc.data_size() == model_dims.hidden_size)) {
1440     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1441   }
1442   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1443         input_h_desc.batch_size() == input_c_desc.batch_size() &&
1444         input_h_desc.data_size() == input_c_desc.data_size())) {
1445     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1446   }
1447   if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1448         output_desc.batch_size() == model_dims.batch_size &&
1449         output_desc.data_size() ==
1450             model_dims.hidden_size * model_dims.dir_count)) {
1451     return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1452   }
1453   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1454         input_h_desc.batch_size() == output_h_desc.batch_size() &&
1455         input_h_desc.data_size() == output_h_desc.data_size())) {
1456     return port::Status(port::error::INVALID_ARGUMENT,
1457                         "Invalid output_h shape");
1458   }
1459   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1460         input_h_desc.batch_size() == output_c_desc.batch_size() &&
1461         input_h_desc.data_size() == output_c_desc.data_size())) {
1462     return port::Status(port::error::INVALID_ARGUMENT,
1463                         "Invalid output_c shape");
1464   }
1465 
1466   return model_dims;
1467 }
1468 
1469 port::Status CheckRNNParameterSize(
1470     const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1471     const CudnnRnnSequenceTensorDescriptor& input_desc) {
1472   size_t params_size_in_bytes = 0;
1473   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1474       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1475       /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
1476       /*dataType=*/rnn_desc.data_type()));
1477   if (static_cast<int64>(params_size_in_bytes) !=
1478       rnn_desc.ParamsSizeInBytes()) {
1479     return port::Status(port::error::INVALID_ARGUMENT,
1480                         "Mismatching RNN parameter size");
1481   }
1482   return port::Status::OK();
1483 }
1484 
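// Allocates the per-call scratch workspace for an RNN forward/backward call.
// Unlike the training reserve space allocated in DoRnnForwardImpl, this
// workspace only needs to live for the duration of the call that uses it.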
1485 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1486     Stream* stream, const CudnnHandle& cudnn,
1487     const CudnnRnnDescriptor& rnn_desc,
1488     const CudnnRnnSequenceTensorDescriptor& input_desc,
1489     ScratchAllocator* workspace_allocator) {
1490   // Query the workspace size.
1491   size_t workspace_size_in_bytes = 0;
1492   RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1493       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1494       /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1495       /*sizeInBytes=*/&workspace_size_in_bytes));
1496   // Allocate the workspace.
1497   if (workspace_size_in_bytes == 0) {
1498     return DeviceMemory<uint8>();
1499   }
1500   return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
1501 }
1502 
1503 }  // namespace
1504 
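// Forward-pass driver. The flow is: validate the descriptor shapes
// (ExtractAndCheckRnnForward), cross-check the parameter buffer size, allocate
// the scratch workspace and, when training, the reserve space, then dispatch
// to either the cudnnRNNForward{Inference,Training}Ex entry points (variable
// sequence lengths, cuDNN 7.2.1+) or the per-time-step descriptor variants,
// optionally wrapping the call in a GpuTimer for profiling.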
1505 template <class T>
1506 port::Status CudnnSupport::DoRnnForwardImpl(
1507     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1508     const CudnnRnnSequenceTensorDescriptor& input_desc,
1509     const DeviceMemory<T>& input_data,
1510     const CudnnRnnStateTensorDescriptor& input_h_desc,
1511     const DeviceMemory<T>& input_h_data,
1512     const CudnnRnnStateTensorDescriptor& input_c_desc,
1513     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1514     const CudnnRnnSequenceTensorDescriptor& output_desc,
1515     DeviceMemory<T>* output_data,
1516     const CudnnRnnStateTensorDescriptor& output_h_desc,
1517     DeviceMemory<T>* output_h_data,
1518     const CudnnRnnStateTensorDescriptor& output_c_desc,
1519     DeviceMemory<T>* output_c_data, bool is_training,
1520     ScratchAllocator* reserve_space_allocator,
1521     ScratchAllocator* workspace_allocator,
1522     dnn::ProfileResult* output_profile_result) {
1523   SE_ASSIGN_OR_RETURN(
1524       RnnModelDims model_dims,
1525       ExtractAndCheckRnnForward(
1526           rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1527           input_c_desc, input_c_data, params, output_desc, *output_data,
1528           output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1529 
1530   auto cudnn = cudnn_->GetHandle(parent_, stream);
1531 
1532   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1533   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1534                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1535                                          workspace_allocator));
1536 
1537   // Query and allocate the reserve space when training; it must be kept
1538   // alive until the matching backward pass consumes it.
1539   DeviceMemory<uint8> reserve_space;
1540   if (is_training) {
1541     size_t reserve_space_size_in_bytes = 0;
1542     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
1543         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1544         /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
1545         /*sizeInBytes=*/&reserve_space_size_in_bytes));
1546 
1547     if (reserve_space_size_in_bytes > 0) {
1548       SE_ASSIGN_OR_RETURN(reserve_space,
1549                           reserve_space_allocator->AllocateBytes(
1550                               stream, reserve_space_size_in_bytes));
1551     }
1552   }
1553 
1554   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1555   const bool is_profiling = output_profile_result != nullptr;
1556   if (is_profiling) {
1557     timer.reset(new GpuTimer(parent_));
1558     // The start and stop of the timer should be as close to the cuDNN call as
1559     // possible. Other threads can still issue work onto this stream, however,
1560     // so multiple profiling measurements may be needed.
1561     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1562       return port::Status(port::error::INTERNAL, "Failed to start timer");
1563     }
1564   }
1565 
1566   if (!is_training) {
1567     if (input_desc.is_var_seq_lengths()) {
1568 #if CUDNN_VERSION >= 7201
1569       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
1570           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1571           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1572           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1573           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1574           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1575           /*yDesc=*/output_desc.data_handle(),
1576           /*y=*/output_data->opaque(),
1577           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1578           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1579           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1580           nullptr,
1581           /*workspace=*/workspace.opaque(),
1582           /*workSpaceSizeInBytes=*/workspace.size()));
1583 #else
1584       return port::Status(port::error::INVALID_ARGUMENT,
1585                           "cudnnRNNForwardInferenceEx is not supported when "
1586                           "CUDNN_VERSION < 7.2.1");
1587 #endif
1588     } else {
1589       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
1590           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1591           /*seqLength=*/model_dims.max_seq_length,
1592           /*xDesc=*/input_desc.handles(),
1593           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1594           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1595           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1596           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1597           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1598           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1599           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1600           /*workSpaceSizeInBytes=*/workspace.size()));
1601     }
1602   } else {
1603     if (input_desc.is_var_seq_lengths()) {
1604 #if CUDNN_VERSION >= 7201
1605       // cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED);
1606       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
1607           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1608           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1609           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1610           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1611           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1612           /*yDesc=*/output_desc.data_handle(),
1613           /*y=*/output_data->opaque(),
1614           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1615           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1616           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1617           nullptr,
1618           /*workspace=*/workspace.opaque(),
1619           /*workSpaceSizeInBytes=*/workspace.size(),
1620           /*reserveSpace=*/reserve_space.opaque(),
1621           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1622 #else
1623       return port::Status(port::error::INVALID_ARGUMENT,
1624                           "cudnnRNNForwardTrainingEx is not supported when "
1625                           "CUDNN_VERSION < 7.2.1");
1626 #endif
1627     } else {
1628       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
1629           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1630           /*seqLength=*/model_dims.max_seq_length,
1631           /*xDesc=*/input_desc.handles(),
1632           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1633           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1634           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1635           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1636           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1637           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1638           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1639           /*workSpaceSizeInBytes=*/workspace.size(),
1640           /*reserveSpace=*/reserve_space.opaque(),
1641           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1642     }
1643   }
1644 
1645   if (is_profiling) {
1646     if (!timer->Stop(AsGpuStream(stream))) {
1647       return port::Status(port::error::INTERNAL, "Failed to stop timer");
1648     }
1649     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1650     output_profile_result->set_algorithm(algo_desc);
1651     output_profile_result->set_elapsed_time_in_ms(
1652         timer->GetElapsedMilliseconds());
1653   }
1654 
1655   return port::Status::OK();
1656 }
1657 
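// Backward-pass driver. Shape validation reuses ExtractAndCheckRnnForward,
// since the same descriptors are involved. The reserve space produced by the
// training forward pass must be passed back in unchanged; it is consumed by
// both the backward-data and the backward-weights calls below.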
1658 template <class T>
1659 port::Status CudnnSupport::DoRnnBackwardImpl(
1660     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1661     const CudnnRnnSequenceTensorDescriptor& input_desc,
1662     const DeviceMemory<T>& input_data,
1663     const CudnnRnnStateTensorDescriptor& input_h_desc,
1664     const DeviceMemory<T>& input_h_data,
1665     const CudnnRnnStateTensorDescriptor& input_c_desc,
1666     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1667     const CudnnRnnSequenceTensorDescriptor& output_desc,
1668     const DeviceMemory<T>& output_data,
1669     const CudnnRnnStateTensorDescriptor& output_h_desc,
1670     const DeviceMemory<T>& output_h_data,
1671     const CudnnRnnStateTensorDescriptor& output_c_desc,
1672     const DeviceMemory<T>& output_c_data,
1673     const DeviceMemory<T>& output_backprop_data,
1674     const DeviceMemory<T>& output_h_backprop_data,
1675     const DeviceMemory<T>& output_c_backprop_data,
1676     DeviceMemory<T>* input_backprop_data,
1677     DeviceMemory<T>* input_h_backprop_data,
1678     DeviceMemory<T>* input_c_backprop_data,
1679     DeviceMemory<T>* params_backprop_data,
1680     DeviceMemory<uint8>* reserve_space_data,
1681     ScratchAllocator* workspace_allocator,
1682     dnn::ProfileResult* output_profile_result) {
1683   SE_ASSIGN_OR_RETURN(
1684       RnnModelDims model_dims,
1685       ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
1686                                 input_h_data, input_c_desc, input_c_data,
1687                                 params, output_desc, output_data, output_h_desc,
1688                                 output_h_data, output_c_desc, output_c_data));
1689 
1690   auto cudnn = cudnn_->GetHandle(parent_, stream);
1691 
1692   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1693   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1694                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1695                                          workspace_allocator));
1696 
1697   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1698   const bool is_profiling = output_profile_result != nullptr;
1699   if (is_profiling) {
1700     timer.reset(new GpuTimer(parent_));
1701     // The start and stop of the timer should be as close to the cuDNN call as
1702     // possible. Other threads can still issue work onto this stream, however,
1703     // so multiple profiling measurements may be needed.
1704     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1705       return port::Status(port::error::INTERNAL, "Failed to start timer");
1706     }
1707   }
1708 
1709   if (input_desc.is_var_seq_lengths()) {
1710 #if CUDNN_VERSION >= 7201
1711     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
1712         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1713         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
1714         /*dyDesc=*/output_desc.data_handle(),
1715         /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
1716         /*dhyDesc=*/output_h_desc.handle(),
1717         /*dhy=*/output_h_backprop_data.opaque(),
1718         /*dcyDesc=*/output_c_desc.handle(),
1719         /*dcy=*/output_c_backprop_data.opaque(),
1720         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1721         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1722         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1723         /*dxDesc=*/input_desc.data_handle(),
1724         /*dx=*/input_backprop_data->opaque(),
1725         /*dhxDesc=*/input_h_desc.handle(),
1726         /*dhx=*/input_h_backprop_data->opaque(),
1727         /*dcxDesc=*/input_c_desc.handle(),
1728         /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
1729         /*workspace=*/workspace.opaque(),
1730         /*workSpaceSizeInBytes=*/workspace.size(),
1731         /*reserveSpace=*/reserve_space_data->opaque(),
1732         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1733 #else
1734     return port::Status(port::error::INVALID_ARGUMENT,
1735                         "cudnnRNNBackwardDataEx is not supported when "
1736                         "CUDNN_VERSION < 7.2.1");
1737 #endif
1738   } else {
1739     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
1740         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1741         /*seqLength=*/model_dims.max_seq_length,
1742         /*yDesc=*/output_desc.handles(),
1743         /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
1744         /*dy=*/output_backprop_data.opaque(),
1745         /*dhyDesc=*/output_h_desc.handle(),
1746         /*dhy=*/output_h_backprop_data.opaque(),
1747         /*dcyDesc=*/output_c_desc.handle(),
1748         /*dcy=*/output_c_backprop_data.opaque(),
1749         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1750         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1751         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1752         /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
1753         /*dhxDesc=*/input_h_desc.handle(),
1754         /*dhx=*/input_h_backprop_data->opaque(),
1755         /*dcxDesc=*/input_c_desc.handle(),
1756         /*dcx=*/input_c_backprop_data->opaque(),
1757         /*workspace=*/workspace.opaque(),
1758         /*workSpaceSizeInBytes=*/workspace.size(),
1759         /*reserveSpace=*/reserve_space_data->opaque(),
1760         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1761   }
1762 
1763   if (params_backprop_data != nullptr) {
1764     // Zero the weight gradients; cudnnRNNBackwardWeights accumulates into dw.
1765     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
1766     if (input_desc.is_var_seq_lengths()) {
1767 #if CUDNN_VERSION >= 7201
1768       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
1769           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1770           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1771           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1772           /*yDesc=*/output_desc.data_handle(),
1773           /*y=*/output_data.opaque(),
1774           /*workspace=*/workspace.opaque(),
1775           /*workSpaceSizeInBytes=*/workspace.size(),
1776           /*dwDesc=*/rnn_desc.params_handle(),
1777           /*dw=*/params_backprop_data->opaque(),
1778           /*reserveSpace=*/reserve_space_data->opaque(),
1779           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1780 #else
1781       return port::Status(port::error::INVALID_ARGUMENT,
1782                           "cudnnRNNBackwardWeightsEx is not supported when "
1783                           "CUDNN_VERSION < 7.2.1");
1784 #endif
1785     } else {
1786       // Make the backward-weights call.
1787       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
1788           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1789           /*seqLength=*/model_dims.max_seq_length,
1790           /*xDesc=*/input_desc.handles(),
1791           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1792           /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
1793           /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
1794           /*workSpaceSizeInBytes=*/workspace.size(),
1795           /*dwDesc=*/rnn_desc.params_handle(),
1796           /*dw=*/params_backprop_data->opaque(),
1797           /*reserveSpace=*/reserve_space_data->opaque(),
1798           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1799     }
1800   }
1801 
1802   if (is_profiling) {
1803     if (!timer->Stop(AsGpuStream(stream))) {
1804       return port::Status(port::error::INTERNAL, "Failed to stop timer");
1805     }
1806     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1807     output_profile_result->set_algorithm(algo_desc);
1808     output_profile_result->set_elapsed_time_in_ms(
1809         timer->GetElapsedMilliseconds());
1810   }
1811 
1812   return port::Status::OK();
1813 }
1814 
1815 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
1816 CudnnSupport::createRnnDescriptor(
1817     int num_layers, int hidden_size, int input_size, int batch_size,
1818     dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
1819     dnn::RnnMode rnn_mode, dnn::DataType data_type,
1820     const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1821     ScratchAllocator* state_allocator) {
1822   // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
1823   // not enqueueing anything into a stream, we pass in the null stream.
1824   auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
1825   SE_ASSIGN_OR_RETURN(
1826       CudnnRnnDescriptor rnn_desc,
1827       CudnnRnnDescriptor::Create(
1828           cudnn, num_layers, hidden_size, input_size, batch_size,
1829           ToCudnnRnnInputMode(input_mode),
1830           ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
1831           ToCudnnDataType(data_type), GetRnnComputeType(data_type),
1832           algorithm_config, dropout, seed, state_allocator));
1833   return std::unique_ptr<dnn::RnnDescriptor>(
1834       new CudnnRnnDescriptor(std::move(rnn_desc)));
1835 }
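// Rough usage sketch of the RNN entry points in this file (comments only, not
// compiled; the variable names are placeholders rather than real APIs):
//
//   rnn_desc    = createRnnDescriptor(...);              // model shape/config
//   input_desc  = createRnnSequenceTensorDescriptor(...);
//   state_desc  = createRnnStateTensorDescriptor(...);   // h/c tensors
//   DoRnnForward(stream, *rnn_desc, *input_desc, input_data, *state_desc, ...);
//
// The exact argument lists are those of the overloads below.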
1836 
1837 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
1838 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
1839                                                 int batch_size, int data_size,
1840                                                 dnn::DataType data_type) {
1841   SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
1842                       CudnnRnnSequenceTensorDescriptor::Create(
1843                           parent_, max_seq_length, batch_size, data_size,
1844                           ToCudnnDataType(data_type)));
1845   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
1846       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
1847 }
1848 
1849 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
1850 CudnnSupport::createRnnSequenceTensorDescriptor(
1851     int max_seq_length, int batch_size, int data_size,
1852     const absl::Span<const int>& seq_lengths, bool time_major,
1853     dnn::DataType data_type) {
1854   SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
1855                       CudnnRnnSequenceTensorDescriptor::Create(
1856                           parent_, max_seq_length, batch_size, data_size,
1857                           seq_lengths, time_major, ToCudnnDataType(data_type)));
1858   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
1859       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
1860 }
1861 
1862 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
1863 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
1864                                              int data_size,
1865                                              dnn::DataType data_type) {
1866   return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
1867       new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
1868                                         data_size, ToCudnnDataType(data_type)));
1869 }
1870 
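// The DoRnnForward/DoRnnBackward overloads below (Eigen::half, float, double)
// simply downcast the dnn::* interface descriptors to their concrete Cudnn*
// types and delegate to the templated DoRnnForwardImpl/DoRnnBackwardImpl.
// Errors are reported through IsStatusOk and are not logged when the caller is
// only profiling (output_profile_result != nullptr).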
1871 bool CudnnSupport::DoRnnForward(
1872     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
1873     const dnn::RnnSequenceTensorDescriptor& input_desc,
1874     const DeviceMemory<Eigen::half>& input_data,
1875     const dnn::RnnStateTensorDescriptor& input_h_desc,
1876     const DeviceMemory<Eigen::half>& input_h_data,
1877     const dnn::RnnStateTensorDescriptor& input_c_desc,
1878     const DeviceMemory<Eigen::half>& input_c_data,
1879     const DeviceMemory<Eigen::half>& params,
1880     const dnn::RnnSequenceTensorDescriptor& output_desc,
1881     DeviceMemory<Eigen::half>* output_data,
1882     const dnn::RnnStateTensorDescriptor& output_h_desc,
1883     DeviceMemory<Eigen::half>* output_h_data,
1884     const dnn::RnnStateTensorDescriptor& output_c_desc,
1885     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
1886     ScratchAllocator* reserve_space_allocator,
1887     ScratchAllocator* workspace_allocator,
1888     dnn::ProfileResult* output_profile_result) {
1889   const CudnnRnnDescriptor& cudnn_rnn_desc =
1890       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
1891   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
1892       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
1893   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
1894       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
1895   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
1896       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
1897   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
1898       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
1899   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
1900       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
1901   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
1902       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
1903   return IsStatusOk(
1904       DoRnnForwardImpl<Eigen::half>(
1905           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
1906           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
1907           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
1908           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
1909           reserve_space_allocator, workspace_allocator, output_profile_result),
1910       /*report_error=*/!output_profile_result);
1911 }
1912 
1913 bool CudnnSupport::DoRnnForward(
1914     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
1915     const dnn::RnnSequenceTensorDescriptor& input_desc,
1916     const DeviceMemory<float>& input_data,
1917     const dnn::RnnStateTensorDescriptor& input_h_desc,
1918     const DeviceMemory<float>& input_h_data,
1919     const dnn::RnnStateTensorDescriptor& input_c_desc,
1920     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
1921     const dnn::RnnSequenceTensorDescriptor& output_desc,
1922     DeviceMemory<float>* output_data,
1923     const dnn::RnnStateTensorDescriptor& output_h_desc,
1924     DeviceMemory<float>* output_h_data,
1925     const dnn::RnnStateTensorDescriptor& output_c_desc,
1926     DeviceMemory<float>* output_c_data, bool is_training,
1927     ScratchAllocator* reserve_space_allocator,
1928     ScratchAllocator* workspace_allocator,
1929     dnn::ProfileResult* output_profile_result) {
1930   const CudnnRnnDescriptor& cudnn_rnn_desc =
1931       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
1932   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
1933       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
1934   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
1935       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
1936   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
1937       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
1938   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
1939       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
1940   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
1941       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
1942   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
1943       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
1944   return IsStatusOk(
1945       DoRnnForwardImpl<float>(
1946           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
1947           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
1948           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
1949           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
1950           reserve_space_allocator, workspace_allocator, output_profile_result),
1951       /*report_error=*/!output_profile_result);
1952 }
1953 
1954 bool CudnnSupport::DoRnnForward(
1955     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
1956     const dnn::RnnSequenceTensorDescriptor& input_desc,
1957     const DeviceMemory<double>& input_data,
1958     const dnn::RnnStateTensorDescriptor& input_h_desc,
1959     const DeviceMemory<double>& input_h_data,
1960     const dnn::RnnStateTensorDescriptor& input_c_desc,
1961     const DeviceMemory<double>& input_c_data,
1962     const DeviceMemory<double>& params,
1963     const dnn::RnnSequenceTensorDescriptor& output_desc,
1964     DeviceMemory<double>* output_data,
1965     const dnn::RnnStateTensorDescriptor& output_h_desc,
1966     DeviceMemory<double>* output_h_data,
1967     const dnn::RnnStateTensorDescriptor& output_c_desc,
1968     DeviceMemory<double>* output_c_data, bool is_training,
1969     ScratchAllocator* reserve_space_allocator,
1970     ScratchAllocator* workspace_allocator,
1971     dnn::ProfileResult* output_profile_result) {
1972   const CudnnRnnDescriptor& cudnn_rnn_desc =
1973       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
1974   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
1975       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
1976   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
1977       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
1978   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
1979       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
1980   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
1981       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
1982   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
1983       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
1984   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
1985       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
1986   return IsStatusOk(
1987       DoRnnForwardImpl<double>(
1988           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
1989           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
1990           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
1991           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
1992           reserve_space_allocator, workspace_allocator, output_profile_result),
1993       /*report_error=*/!output_profile_result);
1994 }
1995 
1996 bool CudnnSupport::DoRnnBackward(
1997     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
1998     const dnn::RnnSequenceTensorDescriptor& input_desc,
1999     const DeviceMemory<Eigen::half>& input_data,
2000     const dnn::RnnStateTensorDescriptor& input_h_desc,
2001     const DeviceMemory<Eigen::half>& input_h_data,
2002     const dnn::RnnStateTensorDescriptor& input_c_desc,
2003     const DeviceMemory<Eigen::half>& input_c_data,
2004     const DeviceMemory<Eigen::half>& params,
2005     const dnn::RnnSequenceTensorDescriptor& output_desc,
2006     const DeviceMemory<Eigen::half>& output_data,
2007     const dnn::RnnStateTensorDescriptor& output_h_desc,
2008     const DeviceMemory<Eigen::half>& output_h_data,
2009     const dnn::RnnStateTensorDescriptor& output_c_desc,
2010     const DeviceMemory<Eigen::half>& output_c_data,
2011     const DeviceMemory<Eigen::half>& output_backprop_data,
2012     const DeviceMemory<Eigen::half>& output_h_backprop_data,
2013     const DeviceMemory<Eigen::half>& output_c_backprop_data,
2014     DeviceMemory<Eigen::half>* input_backprop_data,
2015     DeviceMemory<Eigen::half>* input_h_backprop_data,
2016     DeviceMemory<Eigen::half>* input_c_backprop_data,
2017     DeviceMemory<Eigen::half>* params_backprop_data,
2018     DeviceMemory<uint8>* reserve_space_data,
2019     ScratchAllocator* workspace_allocator,
2020     dnn::ProfileResult* output_profile_result) {
2021   const CudnnRnnDescriptor& cudnn_rnn_desc =
2022       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2023   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2024       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2025   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2026       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2027   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2028       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2029   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2030       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2031   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2032       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2033   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2034       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2035   return IsStatusOk(
2036       DoRnnBackwardImpl<Eigen::half>(
2037           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2038           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2039           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2040           output_h_data, cudnn_output_c_desc, output_c_data,
2041           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2042           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2043           params_backprop_data, reserve_space_data, workspace_allocator,
2044           output_profile_result),
2045       /*report_error=*/!output_profile_result);
2046 }
2047 
2048 bool CudnnSupport::DoRnnBackward(
2049     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2050     const dnn::RnnSequenceTensorDescriptor& input_desc,
2051     const DeviceMemory<float>& input_data,
2052     const dnn::RnnStateTensorDescriptor& input_h_desc,
2053     const DeviceMemory<float>& input_h_data,
2054     const dnn::RnnStateTensorDescriptor& input_c_desc,
2055     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2056     const dnn::RnnSequenceTensorDescriptor& output_desc,
2057     const DeviceMemory<float>& output_data,
2058     const dnn::RnnStateTensorDescriptor& output_h_desc,
2059     const DeviceMemory<float>& output_h_data,
2060     const dnn::RnnStateTensorDescriptor& output_c_desc,
2061     const DeviceMemory<float>& output_c_data,
2062     const DeviceMemory<float>& output_backprop_data,
2063     const DeviceMemory<float>& output_h_backprop_data,
2064     const DeviceMemory<float>& output_c_backprop_data,
2065     DeviceMemory<float>* input_backprop_data,
2066     DeviceMemory<float>* input_h_backprop_data,
2067     DeviceMemory<float>* input_c_backprop_data,
2068     DeviceMemory<float>* params_backprop_data,
2069     DeviceMemory<uint8>* reserve_space_data,
2070     ScratchAllocator* workspace_allocator,
2071     dnn::ProfileResult* output_profile_result) {
2072   const CudnnRnnDescriptor& cudnn_rnn_desc =
2073       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2074   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2075       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2076   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2077       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2078   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2079       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2080   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2081       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2082   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2083       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2084   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2085       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2086   return IsStatusOk(
2087       DoRnnBackwardImpl<float>(
2088           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2089           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2090           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2091           output_h_data, cudnn_output_c_desc, output_c_data,
2092           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2093           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2094           params_backprop_data, reserve_space_data, workspace_allocator,
2095           output_profile_result),
2096       /*report_error=*/!output_profile_result);
2097 }
2098 
2099 bool CudnnSupport::DoRnnBackward(
2100     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2101     const dnn::RnnSequenceTensorDescriptor& input_desc,
2102     const DeviceMemory<double>& input_data,
2103     const dnn::RnnStateTensorDescriptor& input_h_desc,
2104     const DeviceMemory<double>& input_h_data,
2105     const dnn::RnnStateTensorDescriptor& input_c_desc,
2106     const DeviceMemory<double>& input_c_data,
2107     const DeviceMemory<double>& params,
2108     const dnn::RnnSequenceTensorDescriptor& output_desc,
2109     const DeviceMemory<double>& output_data,
2110     const dnn::RnnStateTensorDescriptor& output_h_desc,
2111     const DeviceMemory<double>& output_h_data,
2112     const dnn::RnnStateTensorDescriptor& output_c_desc,
2113     const DeviceMemory<double>& output_c_data,
2114     const DeviceMemory<double>& output_backprop_data,
2115     const DeviceMemory<double>& output_h_backprop_data,
2116     const DeviceMemory<double>& output_c_backprop_data,
2117     DeviceMemory<double>* input_backprop_data,
2118     DeviceMemory<double>* input_h_backprop_data,
2119     DeviceMemory<double>* input_c_backprop_data,
2120     DeviceMemory<double>* params_backprop_data,
2121     DeviceMemory<uint8>* reserve_space_data,
2122     ScratchAllocator* workspace_allocator,
2123     dnn::ProfileResult* output_profile_result) {
2124   const CudnnRnnDescriptor& cudnn_rnn_desc =
2125       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2126   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2127       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2128   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2129       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2130   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2131       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2132   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2133       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2134   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2135       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2136   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2137       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2138   return IsStatusOk(
2139       DoRnnBackwardImpl<double>(
2140           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2141           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2142           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2143           output_h_data, cudnn_output_c_desc, output_c_data,
2144           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2145           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2146           params_backprop_data, reserve_space_data, workspace_allocator,
2147           output_profile_result),
2148       /*report_error=*/!output_profile_result);
2149 }
2150 
2151 namespace {
2152 
2153 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2154 // and backward filter.
2155 
2156 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2157     const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2158     const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2159     const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2160     size_t memory_limit_bytes) {
2161   cudnnConvolutionFwdPreference_t preference =
2162       specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2163                               : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2164   cudnnConvolutionFwdAlgo_t algo_to_use;
2165   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2166       cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2167       output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2168   return algo_to_use;
2169 }
2170 
2171 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2172 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2173                                     const CudnnTensorDescriptor& input_nd,
2174                                     const CudnnFilterDescriptor& filter,
2175                                     const CudnnConvolutionDescriptor& conv,
2176                                     const CudnnTensorDescriptor& output_nd,
2177                                     bool specify_workspace_limit,
2178                                     size_t memory_limit_bytes) {
2179   cudnnConvolutionBwdDataPreference_t preference =
2180       specify_workspace_limit
2181           ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2182           : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2183   cudnnConvolutionBwdDataAlgo_t algo_to_use;
2184   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2185       cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2186       input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2187   return algo_to_use;
2188 }
2189 
2190 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2191 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2192                                       const CudnnTensorDescriptor& input_nd,
2193                                       const CudnnFilterDescriptor& filter,
2194                                       const CudnnConvolutionDescriptor& conv,
2195                                       const CudnnTensorDescriptor& output_nd,
2196                                       bool specify_workspace_limit,
2197                                       size_t memory_limit_bytes) {
2198   cudnnConvolutionBwdFilterPreference_t preference =
2199       specify_workspace_limit
2200           ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2201           : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2202   cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2203   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2204       cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2205       filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2206   return algo_to_use;
2207 }
2208 
2209 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2210     Stream* stream, const CudnnHandle& cudnn,
2211     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2212     const CudnnConvolutionDescriptor& conv,
2213     const CudnnTensorDescriptor& output_nd,
2214     const dnn::AlgorithmDesc& algorithm_desc,
2215     ScratchAllocator* scratch_allocator) {
2216   // TODO(csigg): This has side effects on the convolution descriptor. It is
2217   // functionally correct because the convolution is run with the algorithm of
2218   // the last call to this function, but should be fixed anyway.
2219   conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2220 
2221   // Query the size of the workspace and allocate it.
2222   size_t size_in_bytes;
2223   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2224       cudnn.handle(),
2225       /*xDesc=*/input_nd.handle(),
2226       /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2227       /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
2228       /*sizeInBytes=*/&size_in_bytes));
2229 
2230   int64 size_in_bytes_int64 = size_in_bytes;
2231 
2232   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2233     return port::Status(
2234         port::error::INTERNAL,
2235         "cudnnGetConvolutionForwardWorkspaceSize() returned "
2236         "negative sizeInBytes value. This could be a cudnn bug.");
2237   }
2238 
2239   if (size_in_bytes_int64 == 0) {
2240     return DeviceMemory<uint8>();
2241   }
2242 
2243   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2244     return port::Status(port::error::INVALID_ARGUMENT,
2245                         "No scratch allocator provided");
2246   }
2247 
2248   return scratch_allocator->AllocateBytes(stream, size_in_bytes);
2249 }
2250 
2251 port::StatusOr<DeviceMemory<uint8>>
2252 AllocateCudnnConvolutionBackwardDataWorkspace(
2253     Stream* stream, const CudnnHandle& cudnn,
2254     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2255     const CudnnConvolutionDescriptor& conv,
2256     const CudnnTensorDescriptor& output_nd,
2257     const dnn::AlgorithmDesc& algorithm_desc,
2258     ScratchAllocator* scratch_allocator) {
2259   // TODO(csigg): This has side effects on the convolution descriptor. It is
2260   // functionally correct because the convolution is run with the algorithm of
2261   // the last call to this function, but should be fixed anyway.
2262   conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2263 
2264   // Query the size of the workspace and allocate it.
2265   size_t size_in_bytes;
2266   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2267       cudnn.handle(),
2268       /*wDesc=*/filter.handle(),
2269       /*dyDesc=*/output_nd.handle(),
2270       /*convDesc=*/conv.handle(),
2271       /*dxDesc=*/input_nd.handle(),
2272       /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2273       /*sizeInBytes=*/&size_in_bytes));
2274 
2275   int64 size_in_bytes_int64 = size_in_bytes;
2276 
2277   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2278     return port::Status(
2279         port::error::INTERNAL,
2280         "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2281         "negative sizeInBytes value. This could be a cudnn bug.");
2282   }
2283 
2284   if (size_in_bytes_int64 == 0) {
2285     return DeviceMemory<uint8>();
2286   }
2287 
2288   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2289     return port::Status(port::error::INVALID_ARGUMENT,
2290                         "No scratch allocator provided");
2291   }
2292 
2293   return scratch_allocator->AllocateBytes(stream, size_in_bytes);
2294 }
2295 
2296 port::StatusOr<DeviceMemory<uint8>>
2297 AllocateCudnnConvolutionBackwardFilterWorkspace(
2298     Stream* stream, const CudnnHandle& cudnn,
2299     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2300     const CudnnConvolutionDescriptor& conv,
2301     const CudnnTensorDescriptor& output_nd,
2302     const dnn::AlgorithmDesc& algorithm_desc,
2303     ScratchAllocator* scratch_allocator) {
2304   // TODO(csigg): This has side effects on the convolution descriptor. It is
2305   // functionally correct because the convolution is run with the algorithm of
2306   // the last call to this function, but should be fixed anyway.
2307   conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2308 
2309   // Query the size of the workspace and allocate it.
2310   size_t size_in_bytes;
2311   RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
2312       cudnn.handle(),
2313       /*xDesc=*/input_nd.handle(),
2314       /*dyDesc=*/output_nd.handle(),
2315       /*convDesc=*/conv.handle(),
2316       /*gradDesc=*/filter.handle(),
2317       /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2318       /*sizeInBytes=*/&size_in_bytes));
2319 
2320   int64 size_in_bytes_int64 = size_in_bytes;
2321 
2322   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2323     return port::Status(
2324         port::error::INTERNAL,
2325         "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
2326         "negative sizeInBytes value. This could be a cudnn bug.");
2327   }
2328 
2329   if (size_in_bytes_int64 == 0) {
2330     return DeviceMemory<uint8>();
2331   }
2332 
2333   if (TF_PREDICT_FALSE(!scratch_allocator)) {
2334     return port::Status(port::error::INVALID_ARGUMENT,
2335                         "No scratch allocator provided");
2336   }
2337 
2338   return scratch_allocator->AllocateBytes(stream, size_in_bytes);
2339 }
2340 
2341 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
2342     Stream* stream, const CudnnHandle& cudnn,
2343     const dnn::AlgorithmConfig& algorithm_config,
2344     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2345     const CudnnConvolutionDescriptor& conv,
2346     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2347     DeviceMemory<uint8>* scratch) {
2348   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2349   if (!algo_desc.has_value()) {
2350     // Pick fastest algorithm within memory limit according to cuDNN's
2351     // heuristics.
2352     bool specify_workspace_limit = scratch_allocator != nullptr;
2353     auto memory_limit_bytes =
2354         specify_workspace_limit
2355             ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
2356             : 0ll;
2357     SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
2358                         GetCudnnConvolutionForwardAlgo(
2359                             cudnn, input_nd, filter, conv, output_nd,
2360                             specify_workspace_limit, memory_limit_bytes));
2361     algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true);
2362   }
2363 
2364   const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
2365       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2366       scratch_allocator);
2367 
2368   if (scratch_or.ok()) {
2369     *scratch = scratch_or.ValueOrDie();
2370     return *algo_desc;
2371   }
2372 
2373   algo_desc = algorithm_config.algorithm_no_scratch();
2374 
2375   // Failed to allocate workspace for the first algorithm, fall back to the
2376   // no_scratch algorithm.
2377   if (!algo_desc.has_value()) {
2378     return port::Status(
2379         port::error::INVALID_ARGUMENT,
2380         "The primary convolution algorithm failed memory allocation, "
2381         "and no secondary algorithm was provided.");
2382   }
2383 
2384   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
2385                                     stream, cudnn, input_nd, filter, conv,
2386                                     output_nd, *algo_desc, scratch_allocator));
2387   return *algo_desc;
2388 }
2389 
2390 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
2391     Stream* stream, const CudnnHandle& cudnn,
2392     const dnn::AlgorithmConfig& algorithm_config,
2393     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2394     const CudnnConvolutionDescriptor& conv,
2395     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2396     DeviceMemory<uint8>* scratch) {
2397   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2398   if (!algo_desc.has_value()) {
2399     // Pick fastest algorithm within memory limit according to cuDNN's
2400     // heuristics.
2401     bool specify_workspace_limit = scratch_allocator != nullptr;
2402     auto memory_limit_bytes =
2403         specify_workspace_limit
2404             ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
2405             : 0ll;
2406     SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
2407                         GetCudnnConvolutionBackwardDataAlgo(
2408                             cudnn, input_nd, filter, conv, output_nd,
2409                             specify_workspace_limit, memory_limit_bytes));
2410     algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true);
2411   }
2412 
2413   const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
2414       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2415       scratch_allocator);
2416 
2417   if (scratch_or.ok()) {
2418     *scratch = scratch_or.ValueOrDie();
2419     return *algo_desc;
2420   }
2421 
2422   algo_desc = algorithm_config.algorithm_no_scratch();
2423 
2424   // Failed to allocate workspace for the first algorithm, fall back to the
2425   // no_scratch algorithm.
2426   if (!algo_desc.has_value()) {
2427     return port::Status(
2428         port::error::INVALID_ARGUMENT,
2429         "The primary convolution algorithm failed memory allocation, "
2430         "and no secondary algorithm was provided.");
2431   }
2432 
2433   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
2434                                     stream, cudnn, input_nd, filter, conv,
2435                                     output_nd, *algo_desc, scratch_allocator));
2436   return *algo_desc;
2437 }
2438 
2439 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
2440     Stream* stream, const CudnnHandle& cudnn,
2441     const dnn::AlgorithmConfig& algorithm_config,
2442     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2443     const CudnnConvolutionDescriptor& conv,
2444     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2445     DeviceMemory<uint8>* scratch) {
2446   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2447   if (!algo_desc.has_value()) {
2448     // Pick fastest algorithm within memory limit according to cuDNN's
2449     // heuristics.
2450     bool specify_workspace_limit = scratch_allocator != nullptr;
2451     auto memory_limit_bytes =
2452         specify_workspace_limit
2453             ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
2454             : 0ll;
2455     SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
2456                         GetCudnnConvolutionBackwardFilterAlgo(
2457                             cudnn, input_nd, filter, conv, output_nd,
2458                             specify_workspace_limit, memory_limit_bytes));
2459     algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true);
2460   }
2461 
2462   auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
2463       stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2464       scratch_allocator);
2465 
2466   if (scratch_or.ok()) {
2467     *scratch = scratch_or.ValueOrDie();
2468     return *algo_desc;
2469   }
2470 
2471   algo_desc = algorithm_config.algorithm_no_scratch();
2472 
2473   // Failed to allocate workspace for the first algorithm, fall back to the
2474   // no_scratch algorithm.
2475   if (!algo_desc.has_value()) {
2476     return port::Status(
2477         port::error::INVALID_ARGUMENT,
2478         "The primary convolution algorithm failed memory allocation, "
2479         "and no secondary algorithm was provided.");
2480   }
2481 
2482   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
2483                                     stream, cudnn, input_nd, filter, conv,
2484                                     output_nd, *algo_desc, scratch_allocator));
2485   return *algo_desc;
2486 }
2487 
2488 // A helper class to set env-vars and choose options for cudnn-related
2489 // algorithms.
2490 template <typename EnvVar>
2491 class CudnnEnvVar {
2492  public:
2493   static bool IsEnabled() {
2494     static bool is_enabled = IsEnabledImpl();
2495     return is_enabled;
2496   }
2497 
2498  private:
2499   static bool IsEnabledImpl() {
2500     const char* tf_env_var_val = getenv(EnvVar::kName);
2501     if (tf_env_var_val != nullptr) {
2502       absl::string_view tf_env_var_val_str(tf_env_var_val);
2503       if (tf_env_var_val_str == "0") {
2504         return false;
2505       }
2506       return true;
2507     }
2508     return EnvVar::kDefaultFlag;
2509   }
2510 };
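//
// Illustrative usage sketch: a policy struct only needs a kName (the env-var
// to read) and a kDefaultFlag, like the FftTilingForward struct defined just
// below; callers then query it as, e.g.,
//
//   if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
//     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
//   }
//
// (see GetConvolveAlgorithms() further down in this file).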
2511 
2512 // A helper struct to decide whether to enable the FFT_TILING algorithms for
2513 // forward convolution. It is disabled for cuDNN < 7 due to memory corruption
2514 // caused by some shapes with this algorithm. Users can explicitly enable the
2515 // algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
2516 struct FftTilingForward {
2517   static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
2518   static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000;
2519 };
2520 
2521 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
2522 // By default it is turned on; users can explicitly disable them through the
2523 // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
2524 // https://github.com/tensorflow/tensorflow/pull/4901
2525 struct WinogradNonfused {
2526   static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
2527   // NVIDIA has fixed the winograd nonfused bug for cudnn v>=7. For older versions,
2528   // we have a workaround.
2529   static constexpr bool kDefaultFlag = true;
2530 };
2531 
2532 // A helper struct to decide whether to use FP32 as the internal compute type
2533 // for convolution when the input data type is FP16. By default it is turned on;
2534 // users can explicitly disable it (i.e. choose FP16 as the internal compute
2535 // type) through an env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
2536 struct ConvDoFP32ComputationFP16Input {
2537   static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
2538   // Using FP16 as the internal compute type for convolution when the input data
2539   // type is FP16 is only supported on architectures with true fp16 support
2540   // (compute capability 5.3 and 6.0). Setting this to false in an unsupported
2541   // architecture will cause internal errors.
2542   static constexpr bool kDefaultFlag = true;
2543 };
2544 
2545 // A helper struct to decide whether to use FP32 as the internal compute type
2546 // for RNNs when the input data type is FP16. At present it is turned off;
2547 // users can explicitly control it through the env-var
2548 // TF_FP16_RNN_USE_FP32_COMPUTE.
2549 // After the TODO below is fixed, users should almost always use fp32 compute
2550 // type for training. Using fp16 might give suboptimal accuracy due to loss
2551 // of precision.
2552 struct RnnDoFP32ComputationFP16Input {
2553   static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
2554   // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
2555   // Before cudnn 7.1.4, RNNs are always done in fp32, no matter what math
2556   // precision is set.
2557   // Set it temporarily to false so that no error is raised when using fp16
2558   // inputs with fp32 math precision.
2559   static constexpr bool kDefaultFlag = false;
2560 };
2561 
2562 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
2563   switch (data_type) {
2564     case dnn::DataType::kFloat:
2565       return CUDNN_DATA_FLOAT;
2566     case dnn::DataType::kDouble:
2567       return CUDNN_DATA_DOUBLE;
2568     case dnn::DataType::kHalf:
2569       if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
2570         return CUDNN_DATA_FLOAT;
2571       } else {
2572         return CUDNN_DATA_HALF;
2573       }
2574     default:
2575       LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
2576   }
2577 }
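// For example, with fp16 RNN data and TF_FP16_RNN_USE_FP32_COMPUTE=1 set in
// the environment, kHalf maps to CUDNN_DATA_FLOAT (fp32 math); with the
// current default (kDefaultFlag = false) it maps to CUDNN_DATA_HALF.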
2578 
2579 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
2580   switch (data_type) {
2581     case dnn::DataType::kFloat:
2582     case dnn::DataType::kDouble:
2583       return data_type;
2584     case dnn::DataType::kHalf:
2585       return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
2586                  ? dnn::DataType::kFloat
2587                  : dnn::DataType::kHalf;
2588     case dnn::DataType::kInt8:
2589     case dnn::DataType::kInt32:
2590       return dnn::DataType::kInt32;
2591     default:
2592       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
2593   }
2594 }
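// For example, kHalf inputs accumulate in kFloat unless
// TF_FP16_CONV_USE_FP32_COMPUTE=0 is set, and kInt8/kInt32 inputs always
// accumulate in kInt32.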
2595 
2596 // Determines whether we can safely perform a winograd non-fused convolution for
2597 // the given input and output shapes.  This works around b/68264959, an integer
2598 // overflow in cuDNNv5 and cuDNNv6.
2599 #if CUDNN_VERSION >= 7000
2600 bool ShouldIncludeWinogradNonfusedAlgo(const dnn::BatchDescriptor&,
2601                                        const dnn::BatchDescriptor&) {
2602   return true;
2603 }
2604 #else
2605 bool ShouldIncludeWinogradNonfusedAlgo(
2606     const dnn::BatchDescriptor& input_desc,
2607     const dnn::BatchDescriptor& output_desc) {
2608   int64 batch = input_desc.count();
2609   int64 in_depths = input_desc.feature_map_count();
2610   int64 in_rows = input_desc.height();
2611   int64 in_cols = input_desc.ndims() == 1 ? 1 : input_desc.width();
2612   int64 out_depths = output_desc.feature_map_count();
2613 
2614   int64 total_size = port::MathUtil::CeilOfRatio(batch, int64{16}) *
2615                      std::max(in_depths, out_depths) * in_cols * in_rows *
2616                      sizeof(float);
2617 
2618   const int64 threshold = 1L << 31;
2619   return total_size < threshold;
2620 }
2621 #endif
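// Worked example for the cuDNN < 7 path above (hypothetical shape): with
// batch = 32, in_depths = 256, out_depths = 512 and a 128x128 input,
//   total_size = ceil(32 / 16) * max(256, 512) * 128 * 128 * sizeof(float)
//              = 2 * 512 * 128 * 128 * 4 = 67,108,864 bytes (2^26),
// which is below the 2^31 threshold, so the winograd non-fused algorithm
// would still be allowed for that shape.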
2622 
2623 }  // namespace
2624 
2625 port::Status CudnnSupport::DoPrepareForConvolution(
2626     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2627     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2628     const dnn::FilterDescriptor& filter_descriptor,
2629     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2630     DeviceMemoryBase output_data,
2631     const dnn::ConvolutionDescriptor& convolution_descriptor,
2632     const dnn::AlgorithmConfig& algorithm_config,
2633     ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2634     DeviceMemory<uint8>* scratch_memory) {
2635   CudnnTensorDescriptor input_nd(
2636       input_descriptor,
2637       ToCudnnDataType(element_type, input_descriptor.layout()));
2638   CudnnFilterDescriptor filter_nd(
2639       filter_descriptor,
2640       ToCudnnDataType(element_type, filter_descriptor.layout()));
2641   CudnnTensorDescriptor output_nd(
2642       output_descriptor,
2643       ToCudnnDataType(element_type, output_descriptor.layout()));
2644   CudnnConvolutionDescriptor conv(
2645       convolution_descriptor,
2646       ToCudnnDataType(GetConvAccumulatorType(element_type)));
2647 
2648   auto cudnn = cudnn_->GetHandle(parent_, stream);
2649 
2650   switch (kind) {
2651     case dnn::ConvolutionKind::FORWARD: {
2652       SE_ASSIGN_OR_RETURN(
2653           *algorithm_desc,
2654           GetCudnnConvolutionForwardAlgorithm(
2655               stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2656               output_nd, scratch_allocator, scratch_memory));
2657       break;
2658     }
2659     case dnn::ConvolutionKind::BACKWARD_DATA: {
2660       SE_ASSIGN_OR_RETURN(
2661           *algorithm_desc,
2662           GetCudnnConvolutionBackwardDataAlgorithm(
2663               stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2664               output_nd, scratch_allocator, scratch_memory));
2665       break;
2666     }
2667     case dnn::ConvolutionKind::BACKWARD_FILTER: {
2668       SE_ASSIGN_OR_RETURN(
2669           *algorithm_desc,
2670           GetCudnnConvolutionBackwardFilterAlgorithm(
2671               stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2672               output_nd, scratch_allocator, scratch_memory));
2673       break;
2674     }
2675     default:
2676       return port::InternalError(
2677           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
2678   }
2679 
2680   return port::Status::OK();
2681 }
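// Note on the expected call sequence: DoPrepareForConvolution() selects the
// algorithm (honoring any algorithm in algorithm_config) and allocates the
// scratch memory; DoConvolve() below is then expected to be invoked with the
// resulting algorithm_desc and scratch_memory to run the actual convolution.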
2682 
2683 port::Status CudnnSupport::DoConvolve(
2684     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2685     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2686     const dnn::FilterDescriptor& filter_descriptor,
2687     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2688     DeviceMemoryBase output_data,
2689     const dnn::ConvolutionDescriptor& convolution_descriptor,
2690     dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
2691     dnn::ProfileResult* output_profile_result) {
2692   cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
2693   CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
2694   CudnnTensorDescriptor output_nd(output_descriptor, cudnn_type);
2695   CudnnFilterDescriptor filter_nd(filter_descriptor, cudnn_type);
2696   auto accumulator_type = GetConvAccumulatorType(element_type);
2697   CudnnConvolutionDescriptor conv(convolution_descriptor,
2698                                   ToCudnnDataType(accumulator_type));
2699 
2700   auto cudnn = cudnn_->GetHandle(parent_, stream);
2701   // Alpha is the scaling factor for input.
2702   float falpha = 1.0;
2703   double dalpha = 1.0;
2704   void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
2705                                                 : static_cast<void*>(&falpha);
2706   // Beta is the scaling factor for output.
2707   float fbeta = 0.0;
2708   double dbeta = 0.0;
2709   void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
2710                                                : static_cast<void*>(&fbeta);
2711 
2712   const bool is_profiling = output_profile_result != nullptr;
2713 
2714   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2715   if (is_profiling) {
2716     timer.reset(new GpuTimer(parent_));  // NOLINT
2717     // The start and stop of the timer should be as close to the Cudnn call as
2718     // possible. It is still possible for other threads to issue work onto
2719     // this stream, so it could take multiple profiling measurements.
2720     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2721       return port::Status(port::error::INTERNAL, "Failed to start timer");
2722     }
2723   }
2724 
2725   const auto get_fwd_bugs = [&]() -> port::Status {
2726     // Report an error if we might be hitting a cuDNN bug that accesses illegal
2727     // memory. See nvbugs/2138754, b/80018418.
2728     if (CUDNN_VERSION < 7300) {
2729       if (algorithm_desc.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) {
2730         return port::Status::OK();
2731       }
2732       if (input_descriptor.ndims() < 3) {
2733         return port::Status::OK();
2734       }
2735       // Checks that a*b is within the valid range (as provided by NVIDIA).
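      // The bit test below passes exactly when a * b * 4608 <= 2^31 (assuming
      // the size_t product does not overflow), i.e. the NVIDIA-provided limit.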
2736       const auto check_sizes = [](size_t a, size_t b) {
2737         if ((a * b * 4608 - 1) >> 31 == 0) {
2738           return port::Status::OK();
2739         }
2740         return port::Status(
2741             port::error::FAILED_PRECONDITION,
2742             "This configuration potentially accesses illegal memory.");
2743       };
2744       SE_RETURN_IF_ERROR(check_sizes(input_descriptor.feature_map_count(),
2745                                      output_descriptor.feature_map_count()));
2746       SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
2747                                      input_descriptor.feature_map_count()));
2748       SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
2749                                      output_descriptor.feature_map_count()));
2750       return port::Status::OK();
2751     }
2752     if (algorithm_desc.algo_id() ==
2753             CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2754         !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
2755                                            output_descriptor)) {
2756       return port::Status(
2757           port::error::FAILED_PRECONDITION,
2758           "This configuration has potential integer overflow in "
2759           "cuDNNv5 and cuDNNv6. See b/68264959.");
2760     }
2761     return port::Status::OK();
2762   };
2763 
2764   auto get_bwd_data_bugs = [&]() -> port::Status {
2765     if (algorithm_desc.algo_id() ==
2766             CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2767         !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
2768                                            output_descriptor)) {
2769       return port::Status(
2770           port::error::FAILED_PRECONDITION,
2771           "This configuration has potential integer overflow in "
2772           "cuDNNv5 and cuDNNv6. See b/68264959.");
2773     }
2774 
2775     // Cudnn 7.1.4 has a bug if the workspace of the following convolution is
2776     // not zero-initialized, nvbugs/2254619.
2777     if (CUDNN_VERSION >= 7000 && CUDNN_VERSION < 7300 &&
2778         algorithm_desc.algo_id() == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
2779         cudnn_type == CUDNN_DATA_HALF && algorithm_desc.tensor_ops_enabled() &&
2780         input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
2781         filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX &&
2782         output_descriptor.layout() == dnn::DataLayout::kBatchDepthYX &&
2783         (convolution_descriptor.vertical_filter_stride() > 1 ||
2784          convolution_descriptor.horizontal_filter_stride() > 1)) {
2785       stream->ThenMemZero(&scratch_memory, scratch_memory.size());
2786     }
2787     return port::Status::OK();
2788   };
2789 
2790   const auto get_bwd_filter_bugs = [&]() -> port::Status {
2791     // Report an error if we might be hitting a cuDNN bug that produces
2792     // incorrect results. See nvbugs/2072856
2793     if (CUDNN_VERSION < 7300) {
2794       SE_RETURN_IF_ERROR([&] {
2795         if (algorithm_desc.algo_id() !=
2796             CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) {
2797           return port::Status::OK();
2798         }
2799         if (output_descriptor.height() > 1 && output_descriptor.width() > 1) {
2800           return port::Status::OK();
2801         }
2802         int convolution_size = output_descriptor.height() > 1
2803                                    ? filter_descriptor.input_filter_height()
2804                                    : filter_descriptor.input_filter_width();
2805         if (convolution_size <= 32) {
2806           return port::Status::OK();
2807         }
2808         cudnnConvolutionMode_t convolution_mode;
2809         cudnnDataType_t compute_type;
2810         RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdDescriptor(
2811             conv.handle(), 0, nullptr, nullptr, nullptr, nullptr,
2812             &convolution_mode, &compute_type));
2813         if (convolution_mode != CUDNN_CONVOLUTION) {
2814           return port::Status::OK();
2815         }
2816         return port::Status(
2817             port::error::FAILED_PRECONDITION,
2818             "This configuration potentially produces incorrect results.");
2819       }());
2820     }
2821 
2822     if (algorithm_desc.algo_id() ==
2823             CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2824         !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
2825                                            output_descriptor)) {
2826       return port::Status(
2827           port::error::FAILED_PRECONDITION,
2828           "This configuration has potential integer overflow in "
2829           "cuDNNv5 and cuDNNv6. See b/68264959.");
2830     }
2831 
2832     // Zero out the result buffer for strided conv backward filter for NHWC
2833   // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer is
2834   // not zeroed.
2835   //
2836   // The wrong result caused by this bug is very flaky; it can take up to 20
2837   // runs to produce a mismatch.
2838     //
2839     // See nvbugs/2379553.
2840     if (CUDNN_VERSION >= 7100 && CUDNN_VERSION < 7300 &&
2841         algorithm_desc.algo_id() == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 &&
2842         cudnn_type == CUDNN_DATA_HALF &&
2843         input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
2844         filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
2845         output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
2846         (convolution_descriptor.vertical_filter_stride() > 1 ||
2847          convolution_descriptor.horizontal_filter_stride() > 1)) {
2848       stream->ThenMemZero(&filter_data, filter_data.size());
2849     }
2850     return port::Status::OK();
2851   };
2852 
2853   switch (kind) {
2854     case dnn::ConvolutionKind::FORWARD: {
2855       SE_RETURN_IF_ERROR(get_fwd_bugs());
2856       RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
2857           cudnn.handle(),
2858           /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
2859           /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_nd.handle(),
2860           /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
2861           /*algo=*/ToConvForwardAlgo(algorithm_desc),
2862           /*workSpace=*/scratch_memory.opaque(),
2863           /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
2864           /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque()));
2865       break;
2866     }
2867     case dnn::ConvolutionKind::BACKWARD_DATA: {
2868       SE_RETURN_IF_ERROR(get_bwd_data_bugs());
2869       RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
2870           cudnn.handle(),
2871           /*alpha=*/alpha,
2872           /*wDesc=*/filter_nd.handle(),
2873           /*w=*/filter_data.opaque(),
2874           /*dyDesc=*/output_nd.handle(),
2875           /*dy=*/output_data.opaque(),
2876           /*convDesc=*/conv.handle(),
2877           /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2878           /*workSpace=*/scratch_memory.opaque(),
2879           /*workSpaceSizeInBytes=*/scratch_memory.size(),
2880           /*beta=*/beta,
2881           /*dxDesc=*/input_nd.handle(),
2882           /*dx=*/input_data.opaque()));
2883       break;
2884     }
2885     case dnn::ConvolutionKind::BACKWARD_FILTER: {
2886       SE_RETURN_IF_ERROR(get_bwd_filter_bugs());
2887       RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
2888           cudnn.handle(),
2889           /*alpha=*/alpha,
2890           /*srcDesc=*/input_nd.handle(),
2891           /*srcData=*/input_data.opaque(),
2892           /*diffDesc=*/output_nd.handle(),
2893           /*diffData=*/output_data.opaque(),
2894           /*convDesc=*/conv.handle(),
2895           /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2896           /*workSpace=*/scratch_memory.opaque(),
2897           /*workSpaceSizeInBytes=*/scratch_memory.size(),
2898           /*beta=*/beta,
2899           /*gradDesc=*/filter_nd.handle(),
2900           /*dw=*/filter_data.opaque()));
2901       break;
2902     }
2903     default:
2904       return port::InternalError(
2905           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
2906   }
2907 
2908   if (is_profiling) {
2909     if (!timer->Stop(AsGpuStream(stream))) {
2910       return port::Status(port::error::INTERNAL, "Failed to stop timer");
2911     }
2912     output_profile_result->set_algorithm(algorithm_desc);
2913     output_profile_result->set_elapsed_time_in_ms(
2914         timer->GetElapsedMilliseconds());
2915     output_profile_result->set_scratch_size(scratch_memory.size());
2916   }
2917 
2918   return port::Status::OK();
2919 }
2920 
2921 template <typename ElementType, typename BiasType, typename ScaleType>
2922 port::Status CudnnSupport::DoFusedConvolveImpl(
2923     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
2924     const DeviceMemory<ElementType>& conv_input_data,
2925     ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
2926     const DeviceMemory<ElementType>& filter_data,
2927     const dnn::ConvolutionDescriptor& convolution_descriptor,
2928     const DeviceMemory<ElementType>& side_input_data,
2929     ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
2930     const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
2931     const dnn::BatchDescriptor& output_descriptor,
2932     DeviceMemory<ElementType>* output_data, dnn::DataType accumulator_type,
2933     ScratchAllocator* scratch_allocator,
2934     const dnn::AlgorithmConfig& algorithm_config,
2935     dnn::ProfileResult* output_profile_result) {
2936   if (activation_mode != dnn::ActivationMode::kRelu &&
2937       activation_mode != dnn::ActivationMode::kNone) {
2938     return port::Status(port::error::INVALID_ARGUMENT,
2939                         "cudnnConvolutionBiasActivationForward() only supports "
2940                         "Relu or None activation.");
2941   }
2942 
2943   CudnnTensorDescriptor conv_input_nd(
2944       conv_input_descriptor,
2945       GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
2946   CudnnTensorDescriptor output_nd(
2947       output_descriptor,
2948       GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
2949   CudnnFilterDescriptor filter(
2950       filter_descriptor,
2951       GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
2952   CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
2953   CudnnConvolutionDescriptor conv(convolution_descriptor,
2954                                   ToCudnnDataType(accumulator_type));
2955 
2956   auto cudnn = cudnn_->GetHandle(parent_, stream);
2957 
2958   const bool is_profiling = output_profile_result != nullptr;
2959 
2960   DeviceMemory<uint8> scratch;
2961   SE_ASSIGN_OR_RETURN(
2962       dnn::AlgorithmDesc algo_desc,
2963       GetCudnnConvolutionForwardAlgorithm(
2964           stream, cudnn, algorithm_config, conv_input_nd, filter, conv,
2965           output_nd, scratch_allocator, &scratch));
2966 
2967   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2968   if (is_profiling) {
2969     timer.reset(new GpuTimer(parent_));  // NOLINT
2970     // The start and stop of the timer should be as close to the Cudnn call as
2971     // possible. It is still possible for other threads to issue workload on
2972     // possible. It is still possible for other threads to issue work onto
2973     // this stream, so it could take multiple profiling measurements.
2974       return port::Status(port::error::INTERNAL, "Failed to start timer");
2975     }
2976   }
2977   // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
2978   // activation descriptor. Note that this will change the NaN propagation
2979   // behavior from separate conv, bias, and relu (which by default is
2980   // CUDNN_PROPAGATE_NAN).
2981   CudnnActivationDescriptor activation_desc(
2982       activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
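  // When side_input_scale == 0, the side input (z) does not affect the result
  // because alpha2 is 0; output_data is passed as a placeholder so cuDNN still
  // receives a valid pointer with a matching descriptor.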
2983   auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
2984                                                      : side_input_data.opaque();
2985 
2986   VLOG(2) << "\nconv_input_scale = " << conv_input_scale
2987           << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
2988           << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
2989           << "\nfilter.handle() = " << filter.handle()
2990           << "\nfilter_data.opaque() = " << filter_data.opaque()
2991           << "\nconv.handle() = " << conv.handle()
2992           << "\nalgo = " << algo_desc.algo_id()
2993           << "\nscratch.opaque() = " << scratch.opaque()
2994           << "\nscratch.size() = " << scratch.size()
2995           << "\nside_input_scale = " << side_input_scale
2996           << "\noutput_nd.handle() = " << output_nd.handle()
2997           << "\nside_input_data_ptr = " << side_input_data_ptr
2998           << "\nbias_nd.handle() = " << bias_nd.handle()
2999           << "\nbiases.opaque() = " << biases.opaque()
3000           << "\nactivation_desc.handle() = " << activation_desc.handle()
3001           << "\noutput_nd.handle() = " << output_nd.handle()
3002           << "\noutput_data->opaque() = " << output_data->opaque();
3003 
3004   if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3005       !ShouldIncludeWinogradNonfusedAlgo(conv_input_descriptor,
3006                                          output_descriptor)) {
3007     return port::Status(port::error::FAILED_PRECONDITION,
3008                         "This configuration has potential integer overflow in "
3009                         "cuDNNv5 and cuDNNv6. See around b/68264959.");
3010   }
3011 
3012   RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
3013       cudnn.handle(),
3014       /*alpha1=*/&conv_input_scale,
3015       /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
3016       /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
3017       /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
3018       /*workSpace=*/scratch.opaque(),
3019       /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
3020       /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
3021       /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
3022       /*activationDesc=*/activation_desc.handle(),
3023       /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
3024 
3025   if (is_profiling) {
3026     if (!timer->Stop(AsGpuStream(stream))) {
3027       return port::Status(port::error::INTERNAL, "Failed to stop timer");
3028     }
3029     output_profile_result->set_algorithm(algo_desc);
3030     output_profile_result->set_elapsed_time_in_ms(
3031         timer->GetElapsedMilliseconds());
3032     output_profile_result->set_scratch_size(scratch.size());
3033   }
3034 
3035   return port::Status::OK();
3036 }
3037 
3038 inline bool TensorOpMathAvailable(int cc_major) {
3039   return cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled();
3040 }
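// Tensor op (Tensor Core) math requires a compute capability 7.0+ device
// (Volta or newer) and cuDNN 7+, and is additionally gated by
// TensorOpMathEnabled().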
3041 
3042 bool CudnnSupport::GetConvolveAlgorithms(
3043     bool with_winograd_nonfused, int cc_major, int cc_minor,
3044     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3045   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3046   out_algorithms->clear();
3047 
3048   if (RequireDeterminism()) {
3049     out_algorithms->push_back({CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
3050                                tensor_op_math_available});
3051     return true;
3052   }
3053 
3054   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3055     // clang-format off
3056     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
3057     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
3058     CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
3059     CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
3060     CUDNN_CONVOLUTION_FWD_ALGO_FFT,
3061     CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
3062     // clang-format on
3063   };
3064   if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
3065     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
3066   }
3067   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3068     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
3069   }
3070 
3071   for (auto i : algo_types) {
3072     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3073     if (tensor_op_math_available) {
3074       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3075     }
3076   }
3077 
3078   return true;
3079 }
3080 
3081 bool CudnnSupport::GetRnnAlgorithms(
3082     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3083   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3084       // clang-format off
3085     CUDNN_RNN_ALGO_STANDARD,
3086     CUDNN_RNN_ALGO_PERSIST_STATIC,
3087     CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
3088       // clang-format on
3089   };
3090 
3091   out_algorithms->clear();
3092   for (auto i : algo_types) {
3093     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3094 #if CUDNN_VERSION >= 7100
3095     if (RnnTensorOpMathEnabled()) {
3096       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3097     }
3098 #endif
3099   }
3100   return true;
3101 }
3102 
3103 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
3104     bool with_winograd_nonfused, int cc_major, int cc_minor,
3105     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3106   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3107   out_algorithms->clear();
3108 
3109   if (RequireDeterminism()) {
3110     out_algorithms->push_back(
3111         {CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, tensor_op_math_available});
3112     return true;
3113   }
3114 
3115   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3116       // clang-format off
3117     CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
3118     CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
3119     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
3120     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
3121     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
3122       // clang-format on
3123   };
3124   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3125     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
3126   }
3127 
3128   for (auto i : algo_types) {
3129     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3130     if (tensor_op_math_available) {
3131       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3132     }
3133   }
3134 
3135   return true;
3136 }
3137 
3138 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
3139     bool with_winograd_nonfused, int cc_major, int cc_minor,
3140     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3141   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3142   out_algorithms->clear();
3143 
3144   if (RequireDeterminism()) {
3145     out_algorithms->push_back(
3146         {CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, tensor_op_math_available});
3147     return true;
3148   }
3149 
3150   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3151       // clang-format off
3152       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
3153       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
3154       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
3155       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
3156       // Based on cudnn.h, the following is not implemented.
3157       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
3158 
3159       // Produces incorrect results for some shapes. Disabled for now, see
3160       // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
3161       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
3162       // clang-format on
3163   };
3164   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3165     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
3166   }
3167 
3168   for (auto i : algo_types) {
3169     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3170     if (tensor_op_math_available) {
3171       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3172     }
3173   }
3174 
3175   return true;
3176 }
3177 
3178 bool CudnnSupport::DoBatchNormalizationForward(
3179     Stream* stream, const DeviceMemory<float>& x,
3180     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3181     const DeviceMemory<float>& estimated_mean,
3182     const DeviceMemory<float>& estimated_variance,
3183     const dnn::BatchDescriptor& x_desc,
3184     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3185     DeviceMemory<float>* y, DeviceMemory<float>* batch_mean,
3186     DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
3187     DeviceMemory<float>* saved_inv_var, bool is_training,
3188     std::function<const DeviceMemory<float>&()> var_to_inv_var,
3189     std::function<void()> inv_var_to_var) {
3190   return IsStatusOk(
3191       DoBatchNormalizationForwardImpl<float, float>(
3192           stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
3193           offset, estimated_mean, estimated_variance, x_desc, scale_offset_desc,
3194           epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3195           is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)),
3196       /*report_error=*/true);
3197 }
3198 
3199 bool CudnnSupport::DoBatchNormalizationForward(
3200     Stream* stream, const DeviceMemory<Eigen::half>& x,
3201     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3202     const DeviceMemory<float>& estimated_mean,
3203     const DeviceMemory<float>& estimated_variance,
3204     const dnn::BatchDescriptor& x_desc,
3205     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3206     DeviceMemory<Eigen::half>* y, DeviceMemory<float>* batch_mean,
3207     DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
3208     DeviceMemory<float>* saved_inv_var, bool is_training,
3209     std::function<const DeviceMemory<float>&()> var_to_inv_var,
3210     std::function<void()> inv_var_to_var) {
3211   return IsStatusOk(
3212       DoBatchNormalizationForwardImpl<Eigen::half, float>(
3213           stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3214           estimated_mean, estimated_variance, x_desc, scale_offset_desc,
3215           epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3216           is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)),
3217       /*report_error=*/true);
3218 }
3219 
3220 template <class T, class U>
3221 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
3222     Stream* stream, dnn::DataType input_data_type,
3223     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3224     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3225     const DeviceMemory<U>& estimated_mean,
3226     const DeviceMemory<U>& estimated_variance,
3227     const dnn::BatchDescriptor& x_desc,
3228     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3229     DeviceMemory<T>* y, DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3230     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3231     bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
3232     std::function<void()> inv_var_to_var) {
3233   CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
3234   CudnnTensorDescriptor scale_offset_descriptor(
3235       scale_offset_desc, ToCudnnDataType(scale_data_type));
3236   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3237 #if CUDNN_VERSION >= 7000
3238   if (BatchnormSpatialPersistentEnabled() && is_training) {
3239     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3240   }
3241 #endif
3242   float one = 1.0;
3243   float zero = 0.0;
3244   auto cudnn = cudnn_->GetHandle(parent_, stream);
3245 
3246   if (is_training) {
3247     CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
3248         << "batch_mean and batch_var must both be null or both be non-null";
3249 
3250     void* batch_mean_opaque;
3251     void* batch_var_opaque;
3252     if (!batch_mean->is_null() && !batch_var->is_null()) {
3253       stream->ThenMemZero(batch_mean, batch_mean->size());
3254       stream->ThenMemZero(batch_var, batch_var->size());
3255       batch_mean_opaque = batch_mean->opaque();
3256       batch_var_opaque = batch_var->opaque();
3257     } else {
3258       batch_mean_opaque = nullptr;
3259       batch_var_opaque = nullptr;
3260     }
3261 
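    // The 1.0 passed below is cuDNN's exponentialAverageFactor; with a
    // factor of 1.0 the "running" outputs (batch_mean_opaque /
    // batch_var_opaque, when provided) receive exactly this batch's mean and
    // variance rather than a moving average, while saved_mean /
    // saved_inv_var cache the statistics needed by the backward pass.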
3262     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
3263         cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3264         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3265         scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque,
3266         batch_var_opaque, epsilon, saved_mean->opaque(),
3267         saved_inv_var->opaque()));
3268   } else {
3269     const void* maybe_inv_var = estimated_variance.opaque();
3270     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
3271         cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3272         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3273         scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
3274         epsilon));
3275   }
3276   return port::Status::OK();
3277 }
3278 
3279 bool CudnnSupport::DoBatchNormalizationBackward(
3280     Stream* stream, const DeviceMemory<float>& y_backprop,
3281     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3282     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3283     const dnn::BatchDescriptor& x_desc,
3284     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3285     DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3286     DeviceMemory<float>* offset_backprop) {
3287   return IsStatusOk(DoBatchNormalizationBackwardImpl(
3288                         stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
3289                         x, scale, mean, inv_var, x_desc, scale_offset_desc,
3290                         epsilon, x_backprop, scale_backprop, offset_backprop),
3291                     /*report_error=*/true);
3292 }
3293 
3294 bool CudnnSupport::DoBatchNormalizationBackward(
3295     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3296     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3297     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3298     const dnn::BatchDescriptor& x_desc,
3299     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3300     DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3301     DeviceMemory<float>* offset_backprop) {
3302   return IsStatusOk(DoBatchNormalizationBackwardImpl(
3303                         stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
3304                         x, scale, mean, inv_var, x_desc, scale_offset_desc,
3305                         epsilon, x_backprop, scale_backprop, offset_backprop),
3306                     /*report_error=*/true);
3307 }
3308 
3309 template <class T, class U>
3310 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
3311     Stream* stream, int cudnn_input_type, int cudnn_scale_type,
3312     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3313     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3314     const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
3315     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3316     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3317     DeviceMemory<U>* offset_backprop) {
3318   CudnnTensorDescriptor x_descriptor(
3319       x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
3320   CudnnTensorDescriptor scale_offset_descriptor(
3321       scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
3322   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3323 #if CUDNN_VERSION >= 7000
3324   if (BatchnormSpatialPersistentEnabled()) {
3325     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3326   }
3327 #endif
3328   float one = 1.0;
3329   float zero = 0.0;
3330 
3331   auto cudnn = cudnn_->GetHandle(parent_, stream);
3332 
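  // For cudnnBatchNormalizationBackward, `mean` and `inv_var` are expected
  // to be the saved_mean / saved_inv_var cached by the forward training pass
  // (cuDNN's savedMean / savedInvVariance arguments); supplying them lets
  // cuDNN skip recomputing the batch statistics here.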
3333   RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
3334       cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3335       x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3336       x_descriptor.handle(), x_backprop->opaque(),
3337       scale_offset_descriptor.handle(), scale.opaque(),
3338       scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3339       mean.opaque(), inv_var.opaque()));
3340   return port::Status::OK();
3341 }
3342 
3343 bool CudnnSupport::DoFusedConvolve(
3344     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3345     const DeviceMemory<double>& conv_input_data, double conv_input_scale,
3346     const dnn::FilterDescriptor& filter_descriptor,
3347     const DeviceMemory<double>& filter_data,
3348     const dnn::ConvolutionDescriptor& convolution_descriptor,
3349     const DeviceMemory<double>& side_input_data, double side_input_scale,
3350     const dnn::BatchDescriptor& bias_descriptor,
3351     const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
3352     const dnn::BatchDescriptor& output_descriptor,
3353     DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
3354     const dnn::AlgorithmConfig& algorithm_config,
3355     dnn::ProfileResult* output_profile_result) {
3356   return IsStatusOk(
3357       DoFusedConvolveImpl(
3358           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3359           filter_descriptor, filter_data, convolution_descriptor,
3360           side_input_data, side_input_scale, bias_descriptor, biases,
3361           activation_mode, output_descriptor, output_data,
3362           GetConvAccumulatorType(dnn::DataType::kDouble), scratch_allocator,
3363           algorithm_config, output_profile_result),
3364       /*report_error=*/!output_profile_result);
3365 }
3366 
3367 bool CudnnSupport::DoFusedConvolve(
3368     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3369     const DeviceMemory<float>& conv_input_data, float conv_input_scale,
3370     const dnn::FilterDescriptor& filter_descriptor,
3371     const DeviceMemory<float>& filter_data,
3372     const dnn::ConvolutionDescriptor& convolution_descriptor,
3373     const DeviceMemory<float>& side_input_data, float side_input_scale,
3374     const dnn::BatchDescriptor& bias_descriptor,
3375     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3376     const dnn::BatchDescriptor& output_descriptor,
3377     DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3378     const dnn::AlgorithmConfig& algorithm_config,
3379     dnn::ProfileResult* output_profile_result) {
3380   return IsStatusOk(
3381       DoFusedConvolveImpl(
3382           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3383           filter_descriptor, filter_data, convolution_descriptor,
3384           side_input_data, side_input_scale, bias_descriptor, biases,
3385           activation_mode, output_descriptor, output_data,
3386           GetConvAccumulatorType(dnn::DataType::kFloat), scratch_allocator,
3387           algorithm_config, output_profile_result),
3388       /*report_error=*/!output_profile_result);
3389 }
3390 
3391 bool CudnnSupport::DoFusedConvolve(
3392     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3393     const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
3394     const dnn::FilterDescriptor& filter_descriptor,
3395     const DeviceMemory<Eigen::half>& filter_data,
3396     const dnn::ConvolutionDescriptor& convolution_descriptor,
3397     const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
3398     const dnn::BatchDescriptor& bias_descriptor,
3399     const DeviceMemory<Eigen::half>& biases,
3400     dnn::ActivationMode activation_mode,
3401     const dnn::BatchDescriptor& output_descriptor,
3402     DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3403     const dnn::AlgorithmConfig& algorithm_config,
3404     dnn::ProfileResult* output_profile_result) {
3405   return IsStatusOk(
3406       DoFusedConvolveImpl(
3407           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3408           filter_descriptor, filter_data, convolution_descriptor,
3409           side_input_data, side_input_scale, bias_descriptor, biases,
3410           activation_mode, output_descriptor, output_data,
3411           GetConvAccumulatorType(dnn::DataType::kHalf), scratch_allocator,
3412           algorithm_config, output_profile_result),
3413       /*report_error=*/!output_profile_result);
3414 }
3415 
3416 bool CudnnSupport::DoFusedConvolve(
3417     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3418     const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3419     const dnn::FilterDescriptor& filter_descriptor,
3420     const DeviceMemory<int8>& filter_data,
3421     const dnn::ConvolutionDescriptor& convolution_descriptor,
3422     const DeviceMemory<int8>& side_input_data, float side_input_scale,
3423     const dnn::BatchDescriptor& bias_descriptor,
3424     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3425     const dnn::BatchDescriptor& output_descriptor,
3426     DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
3427     const dnn::AlgorithmConfig& algorithm_config,
3428     dnn::ProfileResult* output_profile_result) {
3429   int cc_major, cc_minor;
3430   stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
3431                                                                    &cc_minor);
3432   if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3433     LOG(WARNING) << "cudnnConvolutionBiasActivationForward() for int8 is only "
3434                     "supported on GPUs with compute capability 6.1 or later.";
3435     return false;
3436   }
3437   return IsStatusOk(
3438       DoFusedConvolveImpl(
3439           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3440           filter_descriptor, filter_data, convolution_descriptor,
3441           side_input_data, side_input_scale, bias_descriptor, biases,
3442           activation_mode, output_descriptor, output_data,
3443           GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3444           algorithm_config, output_profile_result),
3445       /*report_error=*/!output_profile_result);
3446 }
3447 
3448 bool CudnnSupport::DoTransformTensor(Stream* stream,
3449                                      const dnn::BatchDescriptor& input_desc,
3450                                      dnn::DataType input_type,
3451                                      const DeviceMemoryBase& input_data,
3452                                      const dnn::BatchDescriptor& output_desc,
3453                                      dnn::DataType output_type, float scale,
3454                                      DeviceMemoryBase* output_data) {
3455   float beta = 0.0f;
3456   CudnnTensorDescriptor input_tensor_desc(
3457       input_desc, ToCudnnDataType(input_type, input_desc.layout()));
3458   CudnnTensorDescriptor output_tensor_desc(
3459       output_desc, ToCudnnDataType(output_type, output_desc.layout()));
3460   auto cudnn = cudnn_->GetHandle(parent_, stream);
3461   const auto status = [&] {
3462     RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
3463         cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
3464         &beta, output_tensor_desc.handle(), output_data->opaque()));
3465     return port::Status::OK();
3466   }();
3467   return IsStatusOk(status, /*report_error=*/true);
3468 }
3469 
3470 template <class T>
3471 port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
3472     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3473     const DeviceMemory<T>& input_data,
3474     const dnn::BatchDescriptor& bias_descriptor,
3475     DeviceMemory<T>* backward_bias_data) {
3476   cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
3477   CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
3478   CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
3479 
3480   // Alpha is the scaling factor for input.
3481   float alpha = 1.0;
3482   // Beta is the scaling factor for output.
3483   float beta = 0.0;
3484 
3485   auto cudnn = cudnn_->GetHandle(parent_, stream);
3486   RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias(
3487       cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
3488       bias_nd.handle(), backward_bias_data->opaque()));
3489   return port::Status::OK();
3490 }
3491 
3492 bool CudnnSupport::DoConvolveBackwardBias(
3493     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3494     const DeviceMemory<double>& input_data,
3495     const dnn::BatchDescriptor& bias_descriptor,
3496     DeviceMemory<double>* backward_bias_data) {
3497   return IsStatusOk(
3498       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
3499                                  bias_descriptor, backward_bias_data),
3500       /*report_error=*/true);
3501 }
3502 
3503 bool CudnnSupport::DoConvolveBackwardBias(
3504     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3505     const DeviceMemory<float>& input_data,
3506     const dnn::BatchDescriptor& bias_descriptor,
3507     DeviceMemory<float>* backward_bias_data) {
3508   return IsStatusOk(
3509       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
3510                                  bias_descriptor, backward_bias_data),
3511       /*report_error=*/true);
3512 }
3513 
3514 bool CudnnSupport::DoConvolveBackwardBias(
3515     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3516     const DeviceMemory<Eigen::half>& input_data,
3517     const dnn::BatchDescriptor& bias_descriptor,
3518     DeviceMemory<Eigen::half>* backward_bias_data) {
3519   return IsStatusOk(
3520       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
3521                                  bias_descriptor, backward_bias_data),
3522       /*report_error=*/true);
3523 }
3524 
3525 bool CudnnSupport::DoMatMul(Stream* stream,
3526                             const DeviceMemory<float>& input_data,
3527                             const DeviceMemory<float>& weights,
3528                             const dnn::BatchDescriptor& input_dimensions,
3529                             const dnn::BatchDescriptor& output_dimensions,
3530                             DeviceMemory<float>* output_data) {
3531   if (input_dimensions.count() != output_dimensions.count()) {
3532     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
3533     return false;
3534   }
3535 
3536   // We do not permute the input or output; instead we just
3537   // reinterpret the layout. We are working with row-major matrices
3538   // and the rows of the input and output correspond to batch, so
3539   // batch has to be outermost in both the input and output.
3540   //
3541   // By adding transposes to the BLAS gemm call we could perhaps make
3542   // the kYXDepthBatch layout work as well, but there has been no need
3543   // for that so far.
3544   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3545       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3546     LOG(ERROR) << "Unsupported MatMul input layout.";
3547     return false;
3548   }
3549   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3550       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3551     LOG(ERROR) << "Unsupported MatMul output layout.";
3552     return false;
3553   }
3554 
3555   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
3556     // This is a fast path that also supports the kBatchYXDepth layout.
3557 
3558     // The matrices here are in row-major format while BLAS expects
3559     // column-major, i.e. our matrices are transposed as far as BLAS
3560     // is concerned. So we need to compute output^T =
3561     // input^T*weights^T. There is no parameter for transposing the
3562     // output in BLAS gemm, but instead we can transpose both sides of
3563     // the equality to see that this is equivalent to
3564     // output=weights*input. So we only need to swap the order of
3565     // weights and input in the matrix product to correct for the
3566     // row-major versus column-major difference.
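    // As a rough illustration with hypothetical sizes: for a batch of 32
    // inputs with 100 features each and 10 output features, the call below
    // runs a column-major GEMM with m = 10 (the output's
    // NodesAcrossFeatureMaps()), n = 32 (the input's count()) and k = 100
    // (the input's NodesAcrossFeatureMaps()), i.e. output = weights * input
    // with weights treated as a 10x100 matrix and input as a 100x32 matrix.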
3567     const float alpha = 1.0f;  // Take the matrix product without scaling it.
3568     const float beta = 0.0f;   // Ignore the original values in output_data.
3569     const int64 m = output_dimensions.NodesAcrossFeatureMaps();
3570     const int64 n = input_dimensions.count();
3571     const int64 k = input_dimensions.NodesAcrossFeatureMaps();
3572     stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
3573                          blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
3574                          m, input_data, k, beta, output_data, m);
3575   } else {
3576     // This is a slower and more complex path that supports output
3577     // width() * height() > 1, though it only supports the
3578     // kBatchYXDepth layout. Does support kBatchDepthYX if output
3579     // feature_map_count() == 1, as then there is no difference
3580     // between the two layouts.
3581     //
3582     // The operation here is the same as above, except that we have to
3583     // do the matrix multiplication for each (y,x) output coordinate
3584     // separately. We then interpret weights as containing K = width()
3585     // * height() different matrices, each of which we multiply onto the
3586     // matrix from input_data, yielding K matrix products. We then
3587     // combine these together into one matrix by concatenating all the
3588     // first rows of these matrices, then all the second rows and so
3589     // on. We can do this with a batched matrix multiplication, where
3590     // the result is written to a different submatrix of the output
3591     // for each matrix multiplication.
3592     //
3593     // The reason that we only support the kBatchYXDepth output layout
3594     // is that we have to do something in the depth for each (y,x)
3595     // coordinate. The kBatchYXDepth layout has the depth information
3596     // for each point (y,x) in contiguous memory while the
3597     // kBatchDepthYX layout does not.
3598     //
3599     // TODO(broune): Consider a special case for when output depth ==
3600     // 1, as then possibly this could all be done as one matrix
3601     // multiplication instead of a batched one, which should be
3602     // faster. Another possibility would be to add a weights layout
3603     // parameter and then support kBatchDepthYX for a different
3604     // weights layout.
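    // As a rough illustration with hypothetical sizes: for an output with
    // feature_map_count() = 8 and a 2x2 spatial extent, batch_count = 4 and
    // each of the 4 GEMMs below multiplies its own block of weights (m * k
    // elements starting at element offset i * k * 8) against the shared
    // input matrix, writing its m = 8 result rows into the output starting
    // at element offset i * 8 with leading dimension ldc = 32, which
    // interleaves the per-coordinate results into the kBatchYXDepth output.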
3605     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3606         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
3607           output_dimensions.feature_map_count() == 1)) {
3608       LOG(ERROR) << "Unsupported MatMul output layout.";
3609       return false;
3610     }
3611 
3612     const float alpha = 1.0f;  // Take the matrix product without scaling it.
3613     const float beta = 0.0f;   // Ignore the original values in output_data.
3614     const uint64 m = output_dimensions.feature_map_count();
3615     const uint64 n = input_dimensions.count();
3616     const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
3617     const int lda = m;
3618     const int ldb = k;
3619     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
3620     const int batch_count = output_dimensions.NodesPerFeatureMap();
3621 
3622     std::vector<DeviceMemory<float>> a(batch_count);
3623     std::vector<DeviceMemory<float>> b(batch_count);
3624     std::vector<DeviceMemory<float>> c(batch_count);
3625     for (int i = 0; i < batch_count; ++i) {
3626       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
3627                                  output_dimensions.feature_map_count();
3628       a[i] = DeviceMemory<float>::MakeFromByteSize(
3629           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
3630               weights_offset,
3631           weights.ElementCount() - weights_offset);
3632 
3633       b[i] = input_data;
3634 
3635       const int output_offset = i * output_dimensions.feature_map_count();
3636       c[i] = DeviceMemory<float>::MakeFromByteSize(
3637           const_cast<float*>(
3638               reinterpret_cast<const float*>(output_data->opaque())) +
3639               output_offset,
3640           output_data->ElementCount() - output_offset);
3641     }
3642     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
3643       std::vector<DeviceMemory<float>*> ptrs;
3644       ptrs.reserve(v.size());
3645       for (auto& mem : v) {
3646         ptrs.push_back(&mem);
3647       }
3648       return ptrs;
3649     };
3650 
3651     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
3652                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
3653                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
3654                                 ldc, batch_count);
3655   }
3656 
3657   return stream->ok();
3658 }
3659 
3660 bool CudnnSupport::DoBiasAdd(Stream* stream,
3661                              const DeviceMemory<float>& input_data,
3662                              const DeviceMemory<float>& biases,
3663                              const dnn::BatchDescriptor& dimensions,
3664                              DeviceMemory<float>* output_data) {
3665   CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
3666 
3667   dnn::BatchDescriptor bias_dimensions;
3668   bias_dimensions.set_count(1)
3669       .set_feature_map_count(dimensions.feature_map_count())
3670       .set_height(1)
3671       .set_width(1)
3672       .set_layout(dnn::DataLayout::kBatchYXDepth);
3673   CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
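  // The bias is described as a 1 x C x 1 x 1 tensor (C =
  // feature_map_count()), so cudnnAddTensor below broadcasts it across the
  // batch and spatial dimensions of output_data.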
3674 
3675   // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
3676   // output_data before doing the addition, unless the input and
3677   // output are at the same address.
3678   if (input_data.opaque() != output_data->opaque()) {
3679     stream->ThenMemcpy(output_data, input_data,
3680                        dimensions.ElementCount() * sizeof(float));
3681     if (!stream->ok()) {
3682       LOG(ERROR)
3683           << "stream " << stream
3684           << " could not enqueue a tensor copy as part of bias addition.";
3685       return false;
3686     }
3687   }
3688 
3689   const float alpha = 1.0f;
3690   const float beta = 1.0f;
3691 
3692   auto cudnn = cudnn_->GetHandle(parent_, stream);
3693 
3694   const auto status = [&] {
3695     RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
3696         cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
3697         &beta, input_descriptor.handle(), output_data->opaque()));
3698     return port::Status::OK();
3699   }();
3700   return IsStatusOk(status, /*report_error=*/true);
3701 }
3702 
3703 bool CudnnSupport::DoActivate(Stream* stream,
3704                               dnn::ActivationMode activation_mode,
3705                               const dnn::BatchDescriptor& dimensions,
3706                               const DeviceMemory<float>& input_data,
3707                               DeviceMemory<float>* output_data,
3708                               uint64 options) {
3709   CudnnActivationDescriptor activation_desc(
3710       activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
3711 
3712   CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
3713   // Alpha is the input scaling factor.
3714   float alpha = 1.0;
3715   // Beta is the output scaling factor.
3716   float beta = 0.0;
3717 
3718   auto cudnn = cudnn_->GetHandle(parent_, stream);
3719   const auto status = [&] {
3720     RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
3721         cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
3722         input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
3723     return port::Status::OK();
3724   }();
3725   return IsStatusOk(status, /*report_error=*/true);
3726 }
3727 
3728 bool CudnnSupport::DoPoolForward(
3729     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3730     const dnn::BatchDescriptor& input_dimensions,
3731     const DeviceMemory<double>& input_data,
3732     const dnn::BatchDescriptor& output_dimensions,
3733     DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
3734   // Alpha is the scaling factor for input.
3735   double alpha = 1.0;
3736   // Beta is the scaling factor for output.
3737   double beta = 0.0;
3738 
3739   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
3740   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
3741   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3742 
3743   auto cudnn = cudnn_->GetHandle(parent_, stream);
3744   const auto status = [&] {
3745     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
3746         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3747         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
3748     return port::Status::OK();
3749   }();
3750   return IsStatusOk(status, /*report_error=*/true);
3751 }
3752 
3753 bool CudnnSupport::DoPoolForward(
3754     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3755     const dnn::BatchDescriptor& input_dimensions,
3756     const DeviceMemory<float>& input_data,
3757     const dnn::BatchDescriptor& output_dimensions,
3758     DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
3759   // Alpha is the scaling factor for input.
3760   float alpha = 1.0;
3761   // Beta is the scaling factor for output.
3762   float beta = 0.0;
3763 
3764   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
3765   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
3766   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3767 
3768   auto cudnn = cudnn_->GetHandle(parent_, stream);
3769   const auto status = [&] {
3770     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
3771         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3772         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
3773     return port::Status::OK();
3774   }();
3775   return IsStatusOk(status, /*report_error=*/true);
3776 }
3777 
3778 bool CudnnSupport::DoPoolForward(
3779     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3780     const dnn::BatchDescriptor& input_dimensions,
3781     const DeviceMemory<Eigen::half>& input_data,
3782     const dnn::BatchDescriptor& output_dimensions,
3783     DeviceMemory<Eigen::half>* output_data,
3784     ScratchAllocator* workspace_allocator) {
3785   // Alpha is the scaling factor for input.
3786   float alpha = 1.0;
3787   // Beta is the scaling factor for output.
3788   float beta = 0.0;
3789 
3790   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
3791   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
3792   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3793   auto cudnn = cudnn_->GetHandle(parent_, stream);
3794   const auto status = [&] {
3795     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
3796         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3797         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
3798     return port::Status::OK();
3799   }();
3800   return IsStatusOk(status, /*report_error=*/true);
3801 }
3802 
3803 bool CudnnSupport::DoPoolForward(
3804     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3805     const dnn::BatchDescriptor& input_dimensions,
3806     const DeviceMemory<int8>& input_data,
3807     const dnn::BatchDescriptor& output_dimensions,
3808     DeviceMemory<int8>* output_data, ScratchAllocator* workspace_allocator) {
3809   // Alpha is the scaling factor for input.
3810   float alpha = 1.0;
3811   // Beta is the scaling factor for output.
3812   float beta = 0.0;
3813 
3814   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_INT8);
3815   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_INT8);
3816   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3817 
3818   auto cudnn = cudnn_->GetHandle(parent_, stream);
3819   const auto status = [&] {
3820     RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
3821         cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3822         input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
3823     return port::Status::OK();
3824   }();
3825   return IsStatusOk(status, /*report_error=*/true);
3826 }
3827 
3828 bool CudnnSupport::DoPoolBackward(
3829     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3830     const dnn::BatchDescriptor& input_dimensions,
3831     const DeviceMemory<double>& input_data,
3832     const dnn::BatchDescriptor& output_dimensions,
3833     const DeviceMemory<double>& output_data,
3834     const DeviceMemory<double>& input_diff_data,
3835     DeviceMemory<double>* output_diff_data,
3836     ScratchAllocator* workspace_allocator) {
3837   // Alpha is the scaling factor for input.
3838   double alpha = 1.0;
3839   // Beta is the scaling factor for output.
3840   double beta = 0.0;
3841 
3842   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
3843   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
3844   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3845 
3846   auto cudnn = cudnn_->GetHandle(parent_, stream);
3847   const auto status = [&] {
3848     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
3849         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
3850         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
3851         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
3852         output_diff_data->opaque()));
3853     return port::Status::OK();
3854   }();
3855   return IsStatusOk(status, /*report_error=*/true);
3856 }
3857 
3858 bool CudnnSupport::DoPoolBackward(
3859     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3860     const dnn::BatchDescriptor& input_dimensions,
3861     const DeviceMemory<float>& input_data,
3862     const dnn::BatchDescriptor& output_dimensions,
3863     const DeviceMemory<float>& output_data,
3864     const DeviceMemory<float>& input_diff_data,
3865     DeviceMemory<float>* output_diff_data,
3866     ScratchAllocator* workspace_allocator) {
3867   // Alpha is the scaling factor for input.
3868   float alpha = 1.0;
3869   // Beta is the scaling factor for output.
3870   float beta = 0.0;
3871 
3872   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
3873   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
3874   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3875 
3876   auto cudnn = cudnn_->GetHandle(parent_, stream);
3877   const auto status = [&] {
3878     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
3879         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
3880         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
3881         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
3882         output_diff_data->opaque()));
3883     return port::Status::OK();
3884   }();
3885   return IsStatusOk(status, /*report_error=*/true);
3886 }
3887 
3888 bool CudnnSupport::DoPoolBackward(
3889     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3890     const dnn::BatchDescriptor& input_dimensions,
3891     const DeviceMemory<Eigen::half>& input_data,
3892     const dnn::BatchDescriptor& output_dimensions,
3893     const DeviceMemory<Eigen::half>& output_data,
3894     const DeviceMemory<Eigen::half>& input_diff_data,
3895     DeviceMemory<Eigen::half>* output_diff_data,
3896     ScratchAllocator* workspace_allocator) {
3897   // Alpha is the scaling factor for input.
3898   float alpha = 1.0;
3899   // Beta is the scaling factor for output.
3900   float beta = 0.0;
3901 
3902   CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
3903   CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
3904   CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3905 
3906   auto cudnn = cudnn_->GetHandle(parent_, stream);
3907   const auto status = [&] {
3908     RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
3909         cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
3910         output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
3911         src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
3912         output_diff_data->opaque()));
3913     return port::Status::OK();
3914   }();
3915   return IsStatusOk(status, /*report_error=*/true);
3916 }
3917 
3918 bool CudnnSupport::DoNormalizeWithDimensions(
3919     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
3920     const dnn::BatchDescriptor& dimensions,
3921     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
3922   // Check for unsupported modes.
3923   if (normalize_descriptor.wrap_around()) {
3924     LOG(ERROR) << "CUDA LRN does not support cudnn-around mode";
3925     return false;
3926   }
3927   if (normalize_descriptor.segment_size()) {
3928     LOG(ERROR) << "CUDA LRN does not support segmentation";
3929     return false;
3930   }
3931 
3932   CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
3933   CudnnNormalizeDescriptor normalize(normalize_descriptor);
3934 
3935   // Alpha is the scaling factor for input.
3936   float alpha = 1.0f;
3937   // Beta is the scaling factor for output.
3938   float beta = 0.0f;
3939 
3940   auto cudnn = cudnn_->GetHandle(parent_, stream);
3941 
3942   // Launch the normalization.
3943   const auto status = [&] {
3944     RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
3945         cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
3946         &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
3947         output_data->opaque()));
3948     return port::Status::OK();
3949   }();
3950   return IsStatusOk(status, /*report_error=*/true);
3951 }
3952 
3953 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
3954     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
3955     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
3956     const DeviceMemory<float>& normalized_data,
3957     const DeviceMemory<float>& normalized_variable_gradient,
3958     DeviceMemory<float>* raw_variable_gradient,
3959     ScratchAllocator* workspace_allocator) {
3960   // Check for unsupported modes.
3961   if (normalize_descriptor.wrap_around()) {
3962     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
3963     return false;
3964   }
3965   if (normalize_descriptor.segment_size()) {
3966     LOG(ERROR) << "CUDA LRN does not support segmentation";
3967     return false;
3968   }
3969 
3970   CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
3971   CudnnNormalizeDescriptor normalize(normalize_descriptor);
3972 
3973   float alpha = 1.0f;
3974   float beta = 0.0f;
3975 
3976   auto cudnn = cudnn_->GetHandle(parent_, stream);
3977   const auto status = [&] {
3978     RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
3979         cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
3980         &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
3981         normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
3982         &beta, dims.handle(), raw_variable_gradient->opaque()));
3983     return port::Status::OK();
3984   }();
3985   return IsStatusOk(status, /*report_error=*/true);
3986 }
3987 
3988 bool CudnnSupport::DoDepthConcatenate(
3989     Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
3990     port::ArraySlice<const DeviceMemory<float>*> input_data,
3991     DeviceMemory<float>* output_data) {
3992   CHECK_EQ(input_dimensions.size(), input_data.size());
3993 
3994   for (const auto& dimensions : input_dimensions) {
3995     if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3996       LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
3997                     "supports the kBatchDepthYX layout.";
3998       return false;
3999     }
4000   }
4001 
4002   if (input_dimensions.empty()) {
4003     return true;  // Nothing to do.
4004   }
4005 
4006   dnn::BatchDescriptor output_dimensions =
4007       dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
4008 
4009   const int64 area = output_dimensions.width() * output_dimensions.height();
4010   const auto index = [area](int64 batch, int64 depth, int64 yx,
4011                             int64 max_depth) {
4012     return (batch * max_depth + depth) * area + yx;
4013   };
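  // As a rough illustration with hypothetical sizes: for max_depth = 3 and a
  // 4x5 image (area = 20), element (batch=1, depth=2, yx=7) maps to flat
  // offset (1 * 3 + 2) * 20 + 7 = 107, i.e. the usual NCHW (kBatchDepthYX)
  // linearization used for both the per-input and concatenated host buffers.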
4014 
4015   std::vector<float> output_host(output_dimensions.ElementCount());
4016   std::vector<float> tmp;
4017   int64 depth_sum = 0;
4018   for (size_t i = 0; i < input_data.size(); ++i) {
4019     const auto& dimensions = input_dimensions[i];
4020     tmp.resize(dimensions.ElementCount());
4021     stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
4022     port::Status block_status = stream->BlockHostUntilDone();
4023     if (!block_status.ok()) {
4024       LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
4025       return false;
4026     }
4027 
4028     for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
4029       for (int64 yx = 0; yx < area; ++yx) {
4030         for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
4031           LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
4032                     << yx << ' ' << depth;
4033           output_host[index(batch, depth + depth_sum, yx,
4034                             output_dimensions.feature_map_count())] =
4035               tmp[index(batch, depth, yx, dimensions.feature_map_count())];
4036         }
4037       }
4038     }
4039     depth_sum += dimensions.feature_map_count();
4040   }
4041   stream->ThenMemcpyH2D<float>(output_host, output_data);
4042   return true;
4043 }
4044 
4045 bool CudnnSupport::DoElementwiseOperate(
4046     Stream* stream, dnn::ElementwiseOperation operation,
4047     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4048     port::ArraySlice<const DeviceMemory<float>*> input_data,
4049     const dnn::BatchDescriptor& output_dimensions,
4050     DeviceMemory<float>* output_data) {
4051   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4052   return false;
4053 }
4054 
4055 bool CudnnSupport::DoXYPad(Stream* stream,
4056                            const dnn::BatchDescriptor& dimensions,
4057                            const DeviceMemory<float>& input_data,
4058                            int64 left_pad, int64 right_pad, int64 top_pad,
4059                            int64 bottom_pad, DeviceMemory<float>* output_data) {
4060   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4061   return false;
4062 }
4063 
4064 bool CudnnSupport::DoXYSlice(Stream* stream,
4065                              const dnn::BatchDescriptor& dimensions,
4066                              const DeviceMemory<float>& input_data,
4067                              int64 left_trim, int64 right_trim, int64 top_trim,
4068                              int64 bottom_trim,
4069                              DeviceMemory<float>* output_data) {
4070   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4071   return false;
4072 }
4073 
4074 bool CudnnSupport::DoMemcpyD2HQuantized(
4075     Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
4076     dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
4077   LOG(ERROR) << "quantized memcpy not supported by cuDNN";
4078   return false;
4079 }
4080 
4081 bool CudnnSupport::DoMemcpyH2DQuantized(
4082     Stream* stream, const void* host_src, int64 size,
4083     dnn::QuantizedActivationMode mode,
4084     DeviceMemory<float>* gpu_unquantized_dst) {
4085   LOG(ERROR) << "quantized memcpy not supported by cuDNN";
4086   return false;
4087 }
4088 
4089 bool CudnnSupport::DeriveOutputBatchDescriptor(
4090     const dnn::BatchDescriptor& batch_descriptor,
4091     const dnn::FilterDescriptor& filter_descriptor,
4092     const dnn::ConvolutionDescriptor& convolution_descriptor,
4093     dnn::BatchDescriptor* output_batch_descriptor) {
4094   CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
4095   CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
4096   CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
4097 
4098   int dn = batch_descriptor.ndims() + 2;
4099   std::vector<int> dims(dn);  // in BDYX
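  // cudnnGetConvolutionNdForwardOutputDim fills dims as {N, C, spatial...}
  // with the innermost (X) dimension last, so the loop below walks dims in
  // reverse to assign spatial sizes starting from dnn::DimIndex::X.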
4100   const auto status = [&] {
4101     RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
4102         conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
4103     output_batch_descriptor->set_count(dims[0])
4104         .set_feature_map_count(dims[1])
4105         .set_layout(batch_descriptor.layout());
4106 
4107     for (int i = 0; i < batch_descriptor.ndims(); i++) {
4108       output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
4109                                                dims.rbegin()[i]);
4110     }
4111     return port::Status::OK();
4112   }();
4113   return IsStatusOk(status, /*report_error=*/true);
4114 }
4115 
4116 }  // namespace gpu
4117 
4118 void initialize_cudnn() {
4119   port::Status status =
4120       PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
4121           cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
4122           [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
4123             gpu::GpuExecutor* cuda_executor =
4124                 dynamic_cast<gpu::GpuExecutor*>(parent);
4125             if (cuda_executor == nullptr) {
4126               LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
4127                          << "support library with a non-CUDA StreamExecutor";
4128               return nullptr;
4129             }
4130 
4131             gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
4132             if (!dnn->Init().ok()) {
4133               // Note: Init() will log a more specific error.
4134               delete dnn;
4135               return nullptr;
4136             }
4137             return dnn;
4138           });
4139 
4140   if (!status.ok()) {
4141     LOG(ERROR) << "Unable to register cuDNN factory: "
4142                << status.error_message();
4143   }
4144 
4145   PluginRegistry::Instance()->SetDefaultFactory(
4146       cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
4147 }
4148 
4149 }  // namespace stream_executor
4150 
4151 #pragma clang diagnostic pop
4152 
4153 REGISTER_MODULE_INITIALIZER(register_cudnn,
4154                             { stream_executor::initialize_cudnn(); });
4155