1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
17
18 #include <functional>
19 #include <memory>
20 #include <utility>
21
22 #include "absl/strings/str_cat.h"
23 #include "third_party/eigen3/Eigen/Core"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/util/env_var.h"
26 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
27 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
28 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
29 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
30 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
31 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
32 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
33 #include "tensorflow/stream_executor/cuda/cudnn_version.h"
34 #include "tensorflow/stream_executor/dnn.h"
35 #include "tensorflow/stream_executor/lib/env.h"
36 #include "tensorflow/stream_executor/lib/error.h"
37 #include "tensorflow/stream_executor/lib/initialize.h"
38 #include "tensorflow/stream_executor/lib/mathutil.h"
39 #include "tensorflow/stream_executor/lib/threadpool.h"
40 #include "tensorflow/stream_executor/platform/logging.h"
41 #include "tensorflow/stream_executor/plugin_registry.h"
42 #include "tensorflow/stream_executor/scratch_allocator.h"
43 #include "tensorflow/stream_executor/stream.h"
44 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
45 // clang-format off
46 #include "third_party/gpus/cudnn/cudnn.h"
47 #include "absl/strings/string_view.h"
48 // clang-format on
49
50 #pragma clang diagnostic push
51
// Make sure that the Eigen::half forward declaration in dnn.h matches the
// declaration in Eigen.
54 #pragma clang diagnostic warning "-Wmismatched-tags"
55
56 namespace stream_executor {
57 namespace gpu {
58
59 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
60
61 namespace {
62
63 static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher");
64
65 // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
66 #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
67
68 // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
69 // function with a non-successful port::Status.
70 #define RETURN_IF_CUDNN_ERROR(expr) \
71 do { \
72 cudnnStatus_t _status = expr; \
73 if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \
74 std::ostringstream oss; \
75 oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
76 << "): '" << #expr << "'"; \
77 return port::Status(port::error::UNKNOWN, oss.str().c_str()); \
78 } \
79 } while (false)
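// Illustrative sketch (hypothetical helper, not part of this file) of how the
// two macros above are typically used: RETURN_IF_CUDNN_ERROR propagates a
// failure to the caller as a port::Status, while CHECK_CUDNN_OK aborts the
// process and is reserved for unrecoverable situations such as destroying
// descriptors.
//
//   port::Status ResetStream(cudnnHandle_t handle) {
//     RETURN_IF_CUDNN_ERROR(cudnnSetStream(handle, /*streamId=*/nullptr));
//     return port::Status::OK();
//   }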
80
// Converts (via narrowing) a WideT value to a NarrowT value, and checks that
// the value does not change as a result of the conversion.
template <typename WideT, typename NarrowT>
NarrowT CheckedNarrowing(const WideT& wide) {
85 NarrowT narrow = wide;
86 CHECK_EQ(narrow, wide)
87 << "checked narrowing failed; values not equal post-conversion";
88 return narrow;
89 }
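// For example, CheckedNarrowing<int64, int>(int64{42}) returns 42, whereas a
// wide value that does not fit into an int fails the CHECK above.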
90
string ToString(cudnnStatus_t status) {
92 switch (status) {
93 case CUDNN_STATUS_SUCCESS:
94 return "CUDNN_STATUS_SUCCESS";
95 case CUDNN_STATUS_NOT_INITIALIZED:
96 return "CUDNN_STATUS_NOT_INITIALIZED";
97 case CUDNN_STATUS_ALLOC_FAILED:
98 return "CUDNN_STATUS_ALLOC_FAILED";
99 case CUDNN_STATUS_BAD_PARAM:
100 return "CUDNN_STATUS_BAD_PARAM";
101 case CUDNN_STATUS_INTERNAL_ERROR:
102 return "CUDNN_STATUS_INTERNAL_ERROR";
103 case CUDNN_STATUS_INVALID_VALUE:
104 return "CUDNN_STATUS_INVALID_VALUE";
105 case CUDNN_STATUS_ARCH_MISMATCH:
106 return "CUDNN_STATUS_ARCH_MISMATCH";
107 case CUDNN_STATUS_MAPPING_ERROR:
108 return "CUDNN_STATUS_MAPPING_ERROR";
109 case CUDNN_STATUS_EXECUTION_FAILED:
110 return "CUDNN_STATUS_EXECUTION_FAILED";
111 case CUDNN_STATUS_NOT_SUPPORTED:
112 return "CUDNN_STATUS_NOT_SUPPORTED";
113 case CUDNN_STATUS_LICENSE_ERROR:
114 return "CUDNN_STATUS_LICENSE_ERROR";
115 case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
116 return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
117 #if CUDNN_VERSION >= 7000
118 case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
119 return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
120 case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
121 return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
122 #endif
123 default:
124 return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
125 ">");
126 }
127 }
128
129 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
130 //
131 // See CudnnAccess::GetHandle() for details.
132 class CudnnHandle {
133 public:
// Takes ownership of the executor context and of the lock used to access
// cuDNN through the handle.
CudnnHandle(gpu::ScopedActivateExecutorContext context,
std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
138 : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
139
// Returns the cuDNN handle. Pass it directly to cuDNN APIs; don't keep a
// copy.
cudnnHandle_t handle() const { return handle_; }
143
144 private:
145 gpu::ScopedActivateExecutorContext context_;
146 std::unique_ptr<absl::MutexLock> lock_;
147 cudnnHandle_t handle_; // Not owned.
148 };
149
150 } // namespace
151
152 // Wraps a cuDNN handle and provides access to it through CudnnHandle
153 // instances, which also locks a mutex, acquires the CUDA context, and sets
154 // the stream that cuDNN should use to enqueue any work.
155 //
156 // Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
157 class CudnnAccess {
158 public:
159 // Takes ownership of the handle.
explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
161
~CudnnAccess() {
163 absl::MutexLock lock(&mutex_);
164 cudnnDestroy(handle_);
165 }
166
// Creates a CudnnHandle instance for the given stream.
168 //
169 // cuDNN API calls using the same handle instance need to be serialized
170 // across threads. This is guaranteed by CudnnHandle instances locking the
171 // mutex owned by this class.
172 //
173 // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
174 // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
175 // to use the provided stream.
176 //
177 // The stream argument may be null, which translates to the legacy default
178 // stream. See
179 // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
// The legacy default stream synchronizes with all other streams, and it is
// therefore a bad idea (performance-wise) to call any cuDNN APIs that
// enqueue work in the stream.
CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
184 auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
185 mutex_.AssertHeld();
186 gpu::ScopedActivateExecutorContext context(executor);
187 CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
188 const auto status = cudnnSetStream(handle_, cu_stream);
189 CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
190 return CudnnHandle(std::move(context), std::move(lock), handle_);
191 }
192
193 private:
194 // Guards the enqueueing of cuDNN operations via the handle_ below.
195 absl::Mutex mutex_;
196
197 // cuDNN library handle.
198 cudnnHandle_t handle_ GUARDED_BY(mutex_); // Owned.
199 };
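// Illustrative call-site sketch (hypothetical; cudnnSomeCall is a placeholder
// for any cuDNN API taking a handle):
//
//   auto cudnn = cudnn_->GetHandle(parent_, stream);
//   RETURN_IF_CUDNN_ERROR(cudnnSomeCall(cudnn.handle(), ...));
//
// The returned CudnnHandle holds the lock and keeps the CUDA context active
// until it goes out of scope, serializing the call against other users of the
// same cuDNN handle.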
200
201 namespace {
202
203 // A helper function to return the internal compute type for
204 // RNNs in cudnn.
205 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
206
cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
208 cudnnConvolutionFwdAlgo_t algo =
209 cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
210 switch (algo) {
211 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
212 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
213 case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
214 case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
215 case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
216 case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
217 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
218 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
219 return algo;
220 default:
221 LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
222 << algorithm.algo_id();
223 }
224 }
225
cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
227 dnn::AlgorithmDesc algorithm) {
228 cudnnConvolutionBwdDataAlgo_t algo =
229 cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
230 switch (algo) {
231 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
232 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
233 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
234 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
235 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
236 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
237 return algo;
238 default:
239 LOG(FATAL)
240 << "Unsupported Cudnn convolution backward algorithm for data: "
241 << algorithm.algo_id();
242 }
243 }
244
cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
246 dnn::AlgorithmDesc algorithm) {
247 cudnnConvolutionBwdFilterAlgo_t algo =
248 cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
249 switch (algo) {
250 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
251 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
252 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
253 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
254 // Based on cudnn.h, the following is not implemented.
255 // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
256 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
257 return algo;
258 // Produces incorrect results for some shapes. Disabled for now, see
259 // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
260 // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
261 default:
262 LOG(FATAL)
263 << "Unsupported Cudnn convolution backward algorithm for filter: "
264 << algorithm.algo_id();
265 }
266 }
267
port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
269 int value;
270 RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
271 return value;
272 }
273
cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional<dnn::AlgorithmDesc> algorithm) {
275 if (!algorithm.has_value()) {
276 return CUDNN_RNN_ALGO_STANDARD;
277 }
278 cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
279 switch (algo) {
280 case CUDNN_RNN_ALGO_STANDARD:
281 case CUDNN_RNN_ALGO_PERSIST_STATIC:
282 case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
283 return algo;
284 default:
285 LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
286 }
287 }
288
port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
290 SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
291 SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
292 SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
293 return port::Status::OK();
294 }
295
296 } // namespace
297
CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}
299
port::Status CudnnSupport::Init() {
301 ScopedActivateExecutorContext context(parent_);
302 cudnnHandle_t cudnn_handle = nullptr;
303 const auto status = cudnnCreate(&cudnn_handle);
304 if (status == CUDNN_STATUS_SUCCESS) {
305 CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
306
307 CudnnVersion loaded_version;
308 TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
309 if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
310 const string error = absl::StrCat(
311 "Loaded runtime CuDNN library: ", loaded_version.ToString(),
312 " but source was compiled with: ", source_version.ToString(),
313 ". CuDNN library major and minor version needs to match or have "
314 "higher minor version in case of CuDNN 7.0 or later version. If "
315 "using a binary install, upgrade your CuDNN library. If building "
316 "from sources, make sure the library loaded at runtime is "
317 "compatible "
318 "with the version specified during compile configuration.");
319 LOG(ERROR) << error;
320 cudnnDestroy(cudnn_handle);
321 return port::Status(port::error::INTERNAL, error);
322 }
323
324 cudnn_.reset(new CudnnAccess(cudnn_handle));
325 return port::Status::OK();
326 }
327
328 CHECK_EQ(cudnn_handle, nullptr);
329 LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
330 if (status == CUDNN_STATUS_NOT_INITIALIZED) {
331 auto result = gpu::Diagnostician::FindKernelDriverVersion();
332 if (!result.ok()) {
333 LOG(ERROR) << "Error retrieving driver version: "
334 << cuda::DriverVersionStatusToString(result);
335 } else {
336 const auto& version = result.ValueOrDie();
337 LOG(ERROR) << "Possibly insufficient driver version: "
338 << cuda::DriverVersionToString(version);
339 }
340 }
341
342 return port::Status(port::error::INTERNAL,
343 absl::StrCat("cudnn library could not create a handle: ",
344 ToString(status)));
345 }
346
347 port::StatusOr<perftools::gputools::dnn::VersionInfo>
CudnnSupport::GetVersion() {
349 CudnnVersion version;
350 TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
351 return perftools::gputools::dnn::VersionInfo(
352 version.major_version, version.minor_version, version.patch_level);
353 }
354
355 namespace {
356
357 // Deleter functors for cuDNN types that need to be deleted.
358 struct TensorDescriptorDeleter {
void operator()(cudnnTensorDescriptor_t descriptor) const {
360 CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
361 }
362 };
363 #if CUDNN_VERSION >= 7201
364 struct RNNDataDescriptorDeleter {
void operator()(cudnnRNNDataDescriptor_t descriptor) const {
366 CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
367 }
368 };
369 #endif
370 struct FilterDescriptorDeleter {
void operator()(cudnnFilterDescriptor_t descriptor) const {
372 CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
373 }
374 };
375 struct ConvolutionDescriptorDeleter {
void operator()(cudnnConvolutionDescriptor_t descriptor) const {
377 CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
378 }
379 };
380 struct PoolingDescriptorDeleter {
void operator()(cudnnPoolingDescriptor_t descriptor) const {
382 CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
383 }
384 };
385 struct LrnDescriptorDeleter {
void operator()(cudnnLRNDescriptor_t descriptor) const {
387 CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
388 }
389 };
390
391 struct ActivationDescriptorDeleter {
void operator()(cudnnActivationDescriptor_t descriptor) const {
393 CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
394 }
395 };
396 struct DropoutDescriptorDeleter {
void operator()(cudnnDropoutDescriptor_t descriptor) const {
398 CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
399 }
400 };
401 struct RnnDescriptorDeleter {
void operator()(cudnnRNNDescriptor_t descriptor) const {
403 CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
404 }
405 };
406 struct PersistentRnnPlanDeleter {
void operator()(cudnnPersistentRNNPlan_t plan) const {
408 CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
409 }
410 };
411 #if CUDNN_VERSION >= 7603
412 struct CtcLossDescriptorDeleter {
void operator()(cudnnCTCLossDescriptor_t descriptor) const {
414 CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
415 }
416 };
417 #endif
418
419 // RAII wrappers for cuDNN types.
420 using TensorDescriptor =
421 std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
422 #if CUDNN_VERSION >= 7201
423 using RNNDataDescriptor =
424 std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
425 #endif
426 using FilterDescriptor =
427 std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
428 using ConvolutionDescriptor =
429 std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
430 using PoolingDescriptor =
431 std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
432 using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
433 using ActivationDescriptor =
434 std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
435 using DropoutDescriptor =
436 std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
437 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
438 using PersistentRnnPlan =
439 std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
440 #if CUDNN_VERSION >= 7603
441 using CtcLossDescriptor =
442 std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
443 #endif
444
445 // Factory methods for cuDNN types.
TensorDescriptor CreateTensorDescriptor() {
447 cudnnTensorDescriptor_t result;
448 CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
449 return TensorDescriptor(result);
450 }
451 #if CUDNN_VERSION >= 7201
RNNDataDescriptor CreateRNNDataDescriptor() {
453 cudnnRNNDataDescriptor_t result;
454 CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
455 return RNNDataDescriptor(result);
456 }
457 #endif
FilterDescriptor CreateFilterDescriptor() {
459 cudnnFilterDescriptor_t result;
460 CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
461 return FilterDescriptor(result);
462 }
ConvolutionDescriptor CreateConvolutionDescriptor() {
464 cudnnConvolutionDescriptor_t result;
465 CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
466 return ConvolutionDescriptor(result);
467 }
PoolingDescriptor CreatePoolingDescriptor() {
469 cudnnPoolingDescriptor_t result;
470 CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
471 return PoolingDescriptor(result);
472 }
LrnDescriptor CreateLrnDescriptor() {
474 cudnnLRNDescriptor_t result;
475 CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
476 return LrnDescriptor(result);
477 }
ActivationDescriptor CreateActivationDescriptor() {
479 cudnnActivationDescriptor_t result;
480 CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
481 return ActivationDescriptor(result);
482 }
DropoutDescriptor CreateDropoutDescriptor() {
484 cudnnDropoutDescriptor_t result;
485 CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
486 return DropoutDescriptor(result);
487 }
RnnDescriptor CreateRnnDescriptor() {
489 cudnnRNNDescriptor_t result;
490 CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
491 return RnnDescriptor(result);
492 }
493 #if CUDNN_VERSION >= 7603
CtcLossDescriptor CreateCtcLossDescriptor() {
495 cudnnCTCLossDescriptor_t result;
496 CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
497 return CtcLossDescriptor(result);
498 }
499 #endif
500
port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
502 cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
503 cudnnPersistentRNNPlan_t result;
504 RETURN_IF_CUDNN_ERROR(
505 cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
506 return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
507 }
508
509 // Turns a BatchDescriptor structure into a cudnn tensor handle within a
510 // scope.
511 class CudnnTensorDescriptor {
512 public:
CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
514 cudnnDataType_t elem_type)
515 : handle_(CreateTensorDescriptor()) {
516 switch (batch_descriptor.layout()) {
517 case dnn::DataLayout::kBatchYXDepth:
518 case dnn::DataLayout::kBatchDepthYX: {
519 const int nd = batch_descriptor.ndims() + 2;
520 // cuDNN requires the strides and dims to be ordered as BDYX.
521 std::vector<int64> strides64 =
522 batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
523 std::vector<int64> dims64 =
524 batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
525
526 // cuDNN requires arrays of ints.
527 std::vector<int> strides(nd);
528 std::vector<int> dims(nd);
529 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
530 &CheckedNarrowing<int64, int>);
531 std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
532 &CheckedNarrowing<int64, int>);
533 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
534 dims.data(), strides.data()))
535 << "batch_descriptor: " << batch_descriptor.ToString();
536 } break;
537 case dnn::DataLayout::kBatchDepthYX4: {
538 CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
539 handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
540 batch_descriptor.count(), batch_descriptor.feature_map_count(),
541 batch_descriptor.height(), batch_descriptor.width()))
542 << "batch_descriptor: " << batch_descriptor.ToString();
543 } break;
544 default:
545 LOG(FATAL) << "Unsupported tensor format "
546 << DataLayoutString(batch_descriptor.layout());
547 break;
548 }
549 }
550
cudnnTensorDescriptor_t handle() const { return handle_.get(); }
552
553 private:
554 TensorDescriptor handle_;
555
556 SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
557 };
558
559 // Turns a FilterDescriptor structure into a cudnn filter handle within a
560 // scope.
561 class CudnnFilterDescriptor {
562 public:
CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
564 cudnnDataType_t elem_type)
565 : handle_(CreateFilterDescriptor()) {
566 // TODO(b/23032134): Even if the filter layout is not supported,
567 // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
568 // it does not take layout as an input. Maybe force cuDNN by giving wrong
569 // inputs intentionally?
570 cudnnTensorFormat_t format;
571 switch (filter_descriptor.layout()) {
572 case dnn::FilterLayout::kOutputInputYX:
573 format = CUDNN_TENSOR_NCHW;
574 break;
575 case dnn::FilterLayout::kOutputYXInput:
576 format = CUDNN_TENSOR_NHWC;
577 break;
578 case dnn::FilterLayout::kOutputInputYX4:
579 format = CUDNN_TENSOR_NCHW_VECT_C;
580 break;
581 default:
582 LOG(FATAL) << "Unsupported filter format "
583 << FilterLayoutString(filter_descriptor.layout());
584 break;
585 }
586
587 std::vector<int> dims(2 + filter_descriptor.ndims());
588 dims[0] = filter_descriptor.output_feature_map_count();
589 dims[1] = filter_descriptor.input_feature_map_count();
590 absl::Span<const int64> spatial_dims =
591 filter_descriptor.input_filter_dims();
592 std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
593
594 CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
595 dims.size(), dims.data()));
596 }
597
cudnnFilterDescriptor_t handle() const { return handle_.get(); }
599
600 private:
601 FilterDescriptor handle_; // Owned.
602
603 SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
604 };
605
// A helper function to decide whether to enable the TENSOR_OP_MATH math type.
bool TensorOpMathEnabled() {
608 static bool is_enabled = [] {
609 bool is_disabled = false;
610 TF_CHECK_OK(
611 tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
612 /*default_val=*/false, &is_disabled));
613 return !is_disabled;
614 }();
615 return is_enabled;
616 }
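// For example, launching the process with TF_DISABLE_CUDNN_TENSOR_OP_MATH=1
// in the environment makes this return false, so the convolution descriptors
// below are left with CUDNN_DEFAULT_MATH.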
617
618 // A helper function to decide whether to enable the TENSOR_OP_MATH math type
619 // for RNNs.
bool RnnTensorOpMathEnabled() {
621 static bool is_enabled = [] {
622 bool is_disabled = false;
623 TF_CHECK_OK(
624 tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH",
625 /*default_val=*/false, &is_disabled));
626 return !is_disabled;
627 }();
628 return is_enabled;
629 }
630
// A helper function to decide whether to use
// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
// and CUDNN_DATA_HALF data types on GPUs with compute capability 6.0 or
// higher. The reason we set it to false by default is that this mode may use
// scaled atomic integer reduction, which may cause numerical overflow for
// certain input data ranges.
// TODO(yangzihao): Use autotune to choose between this mode and
// CUDNN_BATCHNORM_SPATIAL mode.
bool BatchnormSpatialPersistentEnabled() {
641 static bool is_enabled = [] {
642 bool is_enabled = false;
643 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
644 "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
645 /*default_val=*/false, &is_enabled));
646 return is_enabled;
647 }();
648 return is_enabled;
649 }
650
651 // The following function allows deterministic ops to be implemented relatively
652 // quickly using environment variables. It is intended to be temporary. The
653 // longer-term intention is to enable deterministic ops via tf.config and
654 // appropriate plumbing. See the discussion on PR 34951 for more information:
655 // https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
656 // This function and associated comment are replicated in the following three
657 // places:
658 // 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
659 // 2. tensorflow/core/kernels/gpu_utils.cc
660 // 3. tensorflow/stream_executor/cuda/cuda_dnn.cc
661 // When implementing the plumbing, you should also search for the use of
662 // TF_DETERMINISTIC_OPS on its own.
663 // TODO(duncanriach): move to an API that uses tf.config and implement the first
664 // phase of plumbing.
bool RequireCudnnDeterminism() {
666 static bool require_cudnn_determinism = [] {
667 bool deterministic_ops = false;
668 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
669 /*default_val=*/false,
670 &deterministic_ops));
671 bool cudnn_deterministic = false;
672 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
673 /*default_val=*/false,
674 &cudnn_deterministic));
675 return deterministic_ops || cudnn_deterministic;
676 }();
677 return require_cudnn_determinism;
678 }
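// For example, running with TF_CUDNN_DETERMINISTIC=1 (or
// TF_DETERMINISTIC_OPS=1) makes the max-pooling descriptor below select
// CUDNN_POOLING_MAX_DETERMINISTIC instead of CUDNN_POOLING_MAX.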
679
std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
681 int cc_major, cc_minor;
682 stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
683 &cc_minor);
684 return std::make_tuple(cc_major, cc_minor);
685 }
686
687 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
688 // within a scope.
689 class CudnnConvolutionDescriptor {
690 public:
CudnnConvolutionDescriptor(
692 const dnn::ConvolutionDescriptor& convolution_descriptor,
693 cudnnDataType_t data_type)
694 : handle_(CreateConvolutionDescriptor()) {
695 absl::Span<const int64> strides64 = convolution_descriptor.strides();
696 absl::Span<const int64> padding64 = convolution_descriptor.padding();
697 absl::Span<const int64> dilations64 = convolution_descriptor.dilations();
698 CHECK_NE(convolution_descriptor.pad_alignment(),
699 dnn::PadAlignment::kTensorFlowPadding)
700 << "TensorFlow padding alignment is not supported.";
701
702 // cuDNN requires arrays of ints.
703 std::vector<int> strides(convolution_descriptor.ndims());
704 std::vector<int> padding(convolution_descriptor.ndims());
705 std::vector<int> dilations(convolution_descriptor.ndims());
706 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
707 &CheckedNarrowing<int64, int>);
708 std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
709 &CheckedNarrowing<int64, int>);
710 // TODO(yangzihao): Test with negative dilation to make sure that cudnn
711 // doesn't crash.
712 std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
713 &CheckedNarrowing<int64, int>);
714
715 CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
716 handle_.get(), convolution_descriptor.ndims(), padding.data(),
717 strides.data(), dilations.data(),
718 convolution_descriptor.convolution_not_crosscorr()
719 ? CUDNN_CONVOLUTION
720 : CUDNN_CROSS_CORRELATION,
721 data_type));
722
723 // NOTE(benbarsdell): This only applies if tensor op math is enabled
724 // and algo selection is set to Default.
725 this->set_use_tensor_op_math(true);
726
727 #if CUDNN_MAJOR >= 7
728 VLOG(2) << "Requesting grouped convolution: "
729 << convolution_descriptor.group_count();
730 CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
731 handle_.get(), convolution_descriptor.group_count()));
732 #else
733 CHECK_EQ(convolution_descriptor.group_count(), 1)
734 << "Requested grouped convolution for cuDNN version < 7";
735 #endif
736 }
737
void set_use_tensor_op_math(bool use_tensor_op_math) const {
739 #if CUDNN_VERSION >= 7000
740 cudnnMathType_t math_type =
741 (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
742 if (TensorOpMathEnabled()) {
743 CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
744 }
745 #endif
746 }
747
cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
749
750 private:
751 ConvolutionDescriptor handle_; // Owned.
752
753 SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
754 };
755
756 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
757 // within a scope.
758 class CudnnPoolingDescriptor {
759 public:
explicit CudnnPoolingDescriptor(
761 const dnn::PoolingDescriptor& pooling_descriptor)
762 : handle_(CreatePoolingDescriptor()) {
763 absl::Span<const int64> strides64 = pooling_descriptor.strides();
764 absl::Span<const int64> padding64 = pooling_descriptor.padding();
765 absl::Span<const int64> shape64 = pooling_descriptor.window();
766
767 const int nd = pooling_descriptor.ndims();
768 std::vector<int> shape(nd);
769 std::vector<int> padding(nd);
770 std::vector<int> strides(nd);
771 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
772 &CheckedNarrowing<int64, int>);
773 std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
774 &CheckedNarrowing<int64, int>);
775 std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
776 &CheckedNarrowing<int64, int>);
777 bool propagate_nans = pooling_descriptor.propagate_nans();
778 const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
779 ? CUDNN_POOLING_MAX_DETERMINISTIC
780 : CUDNN_POOLING_MAX;
781 CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
782 handle_.get(),
783 (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
784 ? cudnn_max_pooling_mode
785 : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
786 propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
787 shape.data(), padding.data(), strides.data()));
788 }
789
cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
791
792 private:
793 PoolingDescriptor handle_; // Owned.
794
795 SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
796 };
797
798 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
799 class CudnnNormalizeDescriptor {
800 public:
explicit CudnnNormalizeDescriptor(
802 const dnn::NormalizeDescriptor& normalize_descriptor)
803 : handle_(CreateLrnDescriptor()) {
804 // The range specifies that the indices in the closed range
805 // [i - range, i + range] should be included in the normalization for index
806 // i. The lrnN value is the total number of elements in the range, so
807 // lrnN = 2*range + 1.
808 unsigned lrnN = 2 * normalize_descriptor.range() + 1;
809
810 // Note that SE defines the normalization operation as
811 //
812 // U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
813 //
814 // but cuDNN defines it as
815 //
816 // U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
817 //
818 // i.e. there is a factor of n difference between the meaning of the alphas
819 // in the two contexts. The cuDNN alpha is n times the SE alpha.
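// For example, with range = 2 the window contains lrnN = 5 elements, so an
// SE alpha of 1e-4 is passed to cuDNN as lrnAlpha = 5e-4.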
820 double lrnAlpha = lrnN * normalize_descriptor.alpha();
821
822 double lrnBeta = normalize_descriptor.beta();
823 double lrnK = normalize_descriptor.bias();
824 CHECK_CUDNN_OK(
825 cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
826 }
827
cudnnLRNDescriptor_t handle() const { return handle_.get(); }
829
830 private:
831 LrnDescriptor handle_; // Owned.
832
833 SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
834 };
835
// Turns an ActivationDescriptor structure into a cudnn activation
// descriptor handle within a scope.
838 class CudnnActivationDescriptor {
839 public:
CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
841 cudnnNanPropagation_t nan_propagation,
842 double value_max)
843 : handle_(CreateActivationDescriptor()) {
844 double relu_ceiling = 0.0;
845 cudnnActivationMode_t mode;
846 switch (activation_mode) {
847 #if CUDNN_VERSION >= 7100
848 case dnn::ActivationMode::kNone:
849 mode = CUDNN_ACTIVATION_IDENTITY;
850 break;
851 #endif
852 case dnn::ActivationMode::kRelu6:
853 relu_ceiling = 6.0;
854 mode = CUDNN_ACTIVATION_CLIPPED_RELU;
855 break;
856 case dnn::ActivationMode::kReluX:
857 relu_ceiling = value_max;
858 mode = CUDNN_ACTIVATION_CLIPPED_RELU;
859 break;
860 case dnn::ActivationMode::kRelu:
861 mode = CUDNN_ACTIVATION_RELU;
862 break;
863 case dnn::ActivationMode::kSigmoid:
864 mode = CUDNN_ACTIVATION_SIGMOID;
865 break;
866 case dnn::ActivationMode::kTanh:
867 mode = CUDNN_ACTIVATION_TANH;
868 break;
869 default:
870 LOG(FATAL) << "unrecognized activation mode: "
871 << static_cast<int>(activation_mode);
872 }
873
874 CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
875 nan_propagation, relu_ceiling));
876 }
877
cudnnActivationDescriptor_t handle() const { return handle_.get(); }
879
880 private:
881 ActivationDescriptor handle_; // Owned.
882
883 SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
884 };
885
cudnnDataType_t ToCudnnDataType(
887 dnn::DataType data_type,
888 dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
889 switch (data_type) {
890 case dnn::DataType::kFloat:
891 return CUDNN_DATA_FLOAT;
892 case dnn::DataType::kDouble:
893 return CUDNN_DATA_DOUBLE;
894 case dnn::DataType::kHalf:
895 return CUDNN_DATA_HALF;
896 case dnn::DataType::kInt8:
897 return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
898 : CUDNN_DATA_INT8;
899 case dnn::DataType::kInt32:
900 return CUDNN_DATA_INT32;
901 default:
902 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
903 }
904 }
905
cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
907 dnn::FilterLayout filter_layout) {
908 if (data_type == dnn::DataType::kInt8 &&
909 filter_layout == dnn::FilterLayout::kOutputInputYX4) {
910 return CUDNN_DATA_INT8x4;
911 }
912 return ToCudnnDataType(data_type);
913 }
914
915 template <typename T>
cudnnDataType_t GetCudnnDataType(
917 dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
918 return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
919 }
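// For example, GetCudnnDataType<float>() yields CUDNN_DATA_FLOAT and
// GetCudnnDataType<Eigen::half>() yields CUDNN_DATA_HALF.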
920
cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
922 switch (input_mode) {
923 case dnn::RnnInputMode::kRnnLinearSkip:
924 case dnn::RnnInputMode::kRnnSkipInput:
925 return static_cast<cudnnRNNInputMode_t>(input_mode);
926 default:
927 LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
928 }
929 }
930
cudnnDirectionMode_t ToCudnnRnnDirectionMode(
932 dnn::RnnDirectionMode direction_mode) {
933 switch (direction_mode) {
934 case dnn::RnnDirectionMode::kRnnUnidirectional:
935 case dnn::RnnDirectionMode::kRnnBidirectional:
936 return static_cast<cudnnDirectionMode_t>(direction_mode);
937 default:
938 LOG(FATAL) << "Invalid RNN direction mode: "
939 << static_cast<int>(direction_mode);
940 }
941 }
942
cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
944 switch (rnn_mode) {
945 case dnn::RnnMode::kRnnRelu:
946 case dnn::RnnMode::kRnnTanh:
947 case dnn::RnnMode::kRnnLstm:
948 case dnn::RnnMode::kRnnGru:
949 return static_cast<cudnnRNNMode_t>(rnn_mode);
950 default:
951 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
952 }
953 }
954
int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
956 switch (data_type) {
957 case CUDNN_DATA_FLOAT:
958 return sizeof(float);
959 case CUDNN_DATA_DOUBLE:
960 return sizeof(double);
961 case CUDNN_DATA_HALF:
962 return sizeof(Eigen::half);
963 default:
964 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
965 }
966 }
967
968 class CudnnDropoutDescriptor {
explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
970 : handle_(std::move(handle)) {}
971
972 public:
973 CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
974
static port::StatusOr<CudnnDropoutDescriptor> Create(
976 const CudnnHandle& cudnn, float dropout, uint64 seed,
977 ScratchAllocator* state_allocator) {
978 DropoutDescriptor handle = CreateDropoutDescriptor();
979
980 if (dropout == 0.0f) {
981 // Return 'empty' dropout descriptor.
982 return CudnnDropoutDescriptor(std::move(handle));
983 }
984
985 DeviceMemory<uint8> state_memory;
986 if (state_allocator) {
987 size_t state_sizes_in_bytes = 0;
988 RETURN_IF_CUDNN_ERROR(
989 cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
990 SE_ASSIGN_OR_RETURN(state_memory,
991 state_allocator->AllocateBytes(state_sizes_in_bytes));
992 }
993 RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
994 handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
995 state_memory.size(), seed));
996
997 return CudnnDropoutDescriptor(std::move(handle));
998 }
999
cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
1001
1002 private:
1003 DropoutDescriptor handle_; // Owned.
1004 SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
1005 };
1006
1007 class CudnnRnnParamsDescriptor {
1008 typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1009
CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes,
1011 ParamsRegions weights, ParamsRegions biases)
1012 : handle_(std::move(handle)),
1013 params_size_in_bytes_(params_size_in_bytes),
1014 weights_(std::move(weights)),
1015 biases_(std::move(biases)) {}
1016
1017 public:
1018 CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
1019
1020 static port::StatusOr<CudnnRnnParamsDescriptor> Create(
1021 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1022 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1023 cudnnDirectionMode_t direction_mode, int num_layers);
1024
cudnnFilterDescriptor_t handle() const { return handle_.get(); }
int64 params_size_in_bytes() const { return params_size_in_bytes_; }
ParamsRegions params_weights() const { return weights_; }
ParamsRegions params_biases() const { return biases_; }
1029
1030 private:
1031 FilterDescriptor handle_;
1032 int64 params_size_in_bytes_;
1033 ParamsRegions weights_;
1034 ParamsRegions biases_;
1035 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
1036 };
1037
1038 } // namespace
1039
1040 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1042 PersistentRnnPlan rnn_plan, int num_layers,
1043 int hidden_size, int input_size, int cell_size,
1044 int batch_size, cudnnRNNInputMode_t input_mode,
1045 cudnnDirectionMode_t direction_mode,
1046 cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1047 cudnnDataType_t compute_type,
1048 const dnn::AlgorithmConfig& algorithm_config,
1049 CudnnDropoutDescriptor dropout_desc,
1050 CudnnRnnParamsDescriptor params_desc)
1051 : rnn_desc_(std::move(rnn_desc)),
1052 rnn_plan_(std::move(rnn_plan)),
1053 num_layers_(num_layers),
1054 hidden_size_(hidden_size),
1055 input_size_(input_size),
1056 cell_size_(cell_size),
1057 batch_size_(batch_size),
1058 rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1059 input_mode_(input_mode),
1060 direction_mode_(direction_mode),
1061 rnn_mode_(rnn_mode),
1062 data_type_(data_type),
1063 compute_type_(compute_type),
1064 algorithm_config_(algorithm_config),
1065 dropout_desc_(std::move(dropout_desc)),
1066 params_desc_(std::move(params_desc)) {}
1067
1068 public:
1069 CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1070
static port::StatusOr<CudnnRnnDescriptor> Create(
1072 const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1073 int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
1074 cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1075 cudnnDataType_t data_type, cudnnDataType_t compute_type,
1076 const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1077 ScratchAllocator* state_allocator, bool use_padded_io) {
1078 SE_ASSIGN_OR_RETURN(
1079 CudnnDropoutDescriptor dropout_desc,
1080 CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1081
1082 gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1083 cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1084
1085 // TODO: allow the user to choose an algorithm.
1086 int unified_size = hidden_size;
1087 bool use_projection = cell_size != 0 && hidden_size < cell_size;
1088 if (use_projection) {
1089 unified_size = cell_size;
1090 }
1091 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1092 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1093 /*hiddenSize=*/unified_size, /*numLayers=*/num_layers,
1094 /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
1095 /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1096 /*dataType=*/compute_type));
1097 if (use_projection) {
1098 #if CUDNN_VERSION >= 7101
1099 RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
1100 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1101 /*recProjSize=*/hidden_size, /*outProjSize=*/0));
1102 #else
1103 return port::Status(port::error::INVALID_ARGUMENT,
1104 "No supported cudnnSetRNNProjectionLayers when "
1105 "CUDNN_VERSION < 7.1.1");
1106 #endif
1107 }
1108
1109 // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
1110 // But in the future if these APIs are used to process full length arrays,
1111 // we need to distinguish when to set it.
1112 #if CUDNN_VERSION >= 7201
1113 if (use_padded_io) {
1114 RETURN_IF_CUDNN_ERROR(
1115 cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1116 }
1117 #endif
1118
1119 port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1120 PersistentRnnPlan rnn_plan;
1121 if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1122 CHECK_GE(batch_size, 0);
1123 rnn_plan_wrapper =
1124 CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1125 if (!rnn_plan_wrapper.ok()) {
1126 return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1127 } else {
1128 rnn_plan = rnn_plan_wrapper.ConsumeValueOrDie();
1129 RETURN_IF_CUDNN_ERROR(
1130 cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1131 }
1132 }
1133
1134 // Create the params handle.
1135 SE_ASSIGN_OR_RETURN(auto params_desc,
1136 CudnnRnnParamsDescriptor::Create(
1137 cudnn, input_size, data_type, rnn_desc.get(),
1138 rnn_mode, direction_mode, num_layers));
1139
1140 #if CUDNN_VERSION >= 7000
1141 // Require explicit algorithm config to enable tensor cores. Some configs
1142 // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1143 // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1144 // We can only reasonably expect the user to handle the subsequent failure
1145 // in profile mode, which is run with algorithms returned from
1146 // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1147 // use tensor ops). CuDNN 7.2.1 fixed this issue
1148 if (RnnTensorOpMathEnabled()) {
1149 cudnnMathType_t math_type;
1150 if (algorithm_config.algorithm().has_value()) {
1151 math_type = algorithm_config.algorithm()->tensor_ops_enabled()
1152 ? CUDNN_TENSOR_OP_MATH
1153 : CUDNN_DEFAULT_MATH;
1154 } else {
1155 #if CUDNN_VERSION >= 7201
1156 math_type = CUDNN_TENSOR_OP_MATH;
1157 #else
1158 math_type = CUDNN_DEFAULT_MATH;
1159 #endif // CUDNN_VERSION >= 7201
1160 }
1161 CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1162 }
1163 #endif // CUDNN_VERSION >= 7000
1164
1165 return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1166 num_layers, hidden_size, input_size, cell_size,
1167 batch_size, input_mode, direction_mode, rnn_mode,
1168 data_type, compute_type, algorithm_config,
1169 std::move(dropout_desc), std::move(params_desc));
1170 }
1171
cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
int num_layers() const { return num_layers_; }
int hidden_size() const { return hidden_size_; }
int input_size() const { return input_size_; }
int cell_size() const { return cell_size_; }
int batch_size() const { return batch_size_; }
cudnnRNNInputMode_t input_mode() const { return input_mode_; }
cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
cudnnDataType_t data_type() const { return data_type_; }
cudnnDataType_t compute_type() const { return compute_type_; }
const dnn::AlgorithmConfig& algorithm_config() const {
1184 return algorithm_config_;
1185 }
int64 ParamsSizeInBytes() const override {
1187 return params_desc_.params_size_in_bytes();
1188 }
cudnnFilterDescriptor_t params_handle() const {
1190 return params_desc_.handle();
1191 }
ParamsRegions ParamsWeightRegions() const override {
1193 return params_desc_.params_weights();
1194 }
ParamsRegions ParamsBiasRegions() const override {
1196 return params_desc_.params_biases();
1197 }
1198
1199 private:
1200 gpu::RnnDescriptor rnn_desc_;
1201 PersistentRnnPlan rnn_plan_;
1202 int num_layers_;
1203 int hidden_size_;
1204 int input_size_;
1205 // cell_size_ is the size of cell state, which will be different from
1206 // hidden_size_ if the projection is used.
1207 int cell_size_;
1208 // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1209 // algorithm.
1210 int batch_size_;
1211 cudnnRNNAlgo_t rnn_algo_;
1212 cudnnRNNInputMode_t input_mode_;
1213 cudnnDirectionMode_t direction_mode_;
1214 cudnnRNNMode_t rnn_mode_;
1215 cudnnDataType_t data_type_;
1216 cudnnDataType_t compute_type_;
1217 dnn::AlgorithmConfig algorithm_config_;
1218 CudnnDropoutDescriptor dropout_desc_;
1219 CudnnRnnParamsDescriptor params_desc_;
1220 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1221 };
1222
1223 #if CUDNN_VERSION >= 7603
1224 class CudnnCtcLossDescriptor {
1225 public:
explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
1227 : handle_(CreateCtcLossDescriptor()) {
1228 CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
1229 /*ctcLossDesc=*/handle_.get(),
1230 /*compType=*/data_type,
1231 /*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
1232 /*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
1233 }
1234
cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
1236
1237 private:
1238 CtcLossDescriptor handle_; // Owned
1239
1240 SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
1241 };
1242 #else
1243 // dummy class
1244 class CudnnCtcLossDescriptor {
1245 public:
CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
1247 };
1248 #endif
1249
1250 namespace {
1251
// Checks whether the LSTM projection is used. If so, the additional weight
// matrix (the projection matrix) is fetched and appended to 'weights'.
// Otherwise, nothing is done.
port::Status CheckAndFetchProjectionWeights(
1256 const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
1257 const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
1258 const FilterDescriptor& region_desc_handle,
1259 dnn::RnnDescriptor::ParamsRegions* weights) {
1260 #if CUDNN_VERSION >= 7101
1261 int hidden_size_v;
1262 int num_layers_v;
1263 cudnnDropoutDescriptor_t dropout_desc;
1264 cudnnRNNInputMode_t input_mode;
1265 cudnnDirectionMode_t direction;
1266 cudnnRNNMode_t mode;
1267 cudnnRNNAlgo_t algo;
1268 cudnnDataType_t data_type;
1269 RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
1270 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1271 /*hiddenSize=*/&hidden_size_v,
1272 /*numLayers=*/&num_layers_v,
1273 /*dropoutDesc=*/&dropout_desc,
1274 /*inputMode=*/&input_mode,
1275 /*direction=*/&direction,
1276 /*mode=*/&mode,
1277 /*algo=*/&algo,
1278 /*dataType=*/&data_type));
1279 int rec_proj_size_v;
1280 int out_proj_size_v;
1281 RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
1282 /*handle=*/cudnn.handle(),
1283 /*rnnDesc=*/rnn_desc,
1284 /*recProjSize*/ &rec_proj_size_v,
1285 /*outProjSize*/ &out_proj_size_v));
1286 if (rec_proj_size_v != hidden_size_v) {
1287 void* offset = nullptr;
1288 int region_id = 8;
1289 RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
1290 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1291 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1292 /*wDesc=*/filter_desc.get(),
1293 /*w=*/nullptr, /*linLayerID=*/region_id,
1294 /*linLayerMatDesc=*/region_desc_handle.get(),
1295 /*linLayerMat or linLayerBias=*/&offset));
1296 int dims[] = {1, 1, 1};
1297 cudnnDataType_t data_type;
1298 cudnnTensorFormat_t tensor_format;
1299 int n_dims;
1300 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1301 /*filterDesc=*/region_desc_handle.get(),
1302 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1303 /*dataType=*/&data_type, /*format=*/&tensor_format,
1304 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1305 int64 size =
1306 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1307 dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset),
1308 size};
1309 weights->push_back(region);
1310 }
1311 #endif // CUDNN_VERSION >= 7101
1312 return port::Status::OK();
1313 }
1314
port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1316 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1317 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1318 cudnnDirectionMode_t direction_mode, int num_layers) {
1319 // Query the params size.
1320 TensorDescriptor input_desc = CreateTensorDescriptor();
1321 int tensor_dims[] = {1, input_size, 1};
1322 int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1323 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1324 /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1325 /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1326 /*dimA=*/tensor_dims,
1327 /*strideA=*/strides));
1328
1329 size_t params_size = 0;
1330 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1331 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1332 /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size,
1333 /*dataType=*/data_type));
1334 int64 params_size_in_bytes = static_cast<int64>(params_size);
1335
1336 FilterDescriptor filter_desc = CreateFilterDescriptor();
1337 int filter_dims[] = {static_cast<int>(params_size_in_bytes), 1, 1};
1338 RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1339 /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1340 /*format=*/CUDNN_TENSOR_NCHW,
1341 /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1342 /*filterDimA=*/filter_dims));
1343
1344 // Create the weights and biases into the params buffer
1345 int region_count_per_layer = [&] {
1346 switch (rnn_mode) {
1347 case CUDNN_RNN_RELU:
1348 case CUDNN_RNN_TANH:
1349 return 2;
1350 case CUDNN_LSTM:
1351 return 8;
1352 case CUDNN_GRU:
1353 return 6;
1354 default:
1355 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1356 return 0;
1357 }
1358 }();
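// For example, an LSTM layer has 8 weight regions (the input and recurrent
// matrices for each of the four gates); the loop below also enumerates the
// matching bias region for each of them.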
1359
1360 FilterDescriptor region_desc_handle = CreateFilterDescriptor();
1361 const int layer_count =
1362 direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1363
1364 ParamsRegions weights;
1365 ParamsRegions biases;
1366
1367 for (int layer = 0; layer < layer_count; layer++) {
1368 for (int region = 0; region < region_count_per_layer; region++) {
1369 for (int type = 0; type < 2; type++) {
1370 void* offset = nullptr;
1371 RETURN_IF_CUDNN_ERROR(
1372 type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1373 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1374 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1375 /*wDesc=*/filter_desc.get(),
1376 /*w=*/nullptr, /*linLayerID=*/region,
1377 /*linLayerMatDesc=*/region_desc_handle.get(),
1378 /*linLayerMat or linLayerBias=*/&offset)
1379 : cudnnGetRNNLinLayerBiasParams(
1380 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1381 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1382 /*wDesc=*/filter_desc.get(),
1383 /*w=*/nullptr, /*linLayerID=*/region,
1384 /*linLayerMatDesc=*/region_desc_handle.get(),
1385 /*linLayerMat or linLayerBias=*/&offset));
1386 int dims[] = {1, 1, 1};
1387 cudnnDataType_t data_type;
1388 cudnnTensorFormat_t tensor_format;
1389 int n_dims;
1390 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1391 /*filterDesc=*/region_desc_handle.get(),
1392 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1393 /*dataType=*/&data_type, /*format=*/&tensor_format,
1394 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1395 int64 size =
1396 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1397 dnn::RnnDescriptor::ParamsRegion region = {
1398 reinterpret_cast<int64>(offset), size};
1399 (type == 0 ? weights : biases).push_back(region);
1400 }
1401 }
1402 TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
1403 cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
1404 &weights));
1405 }
1406
1407 return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1408 weights, biases);
1409 }
1410
1411 } // namespace
1412
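// Wraps the cuDNN descriptors that describe an RNN input or output sequence.
// The fixed-length variant is described by one tensor descriptor replicated
// per time step; with cuDNN >= 7.2.1, the variable-length variant is described
// by a cudnnRNNDataDescriptor_t that also carries per-batch sequence lengths.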
1413 class CudnnRnnSequenceTensorDescriptor
1414 : public dnn::RnnSequenceTensorDescriptor {
1415   CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1416 int batch_size, int data_size,
1417 cudnnDataType_t data_type,
1418 #if CUDNN_VERSION >= 7201
1419 RNNDataDescriptor data_handle,
1420 #endif
1421 TensorDescriptor handle)
1422 : max_seq_length_(max_seq_length),
1423 batch_size_(batch_size),
1424 data_size_(data_size),
1425 data_type_(data_type),
1426 handle_(std::move(handle)),
1427 #if CUDNN_VERSION >= 7201
1428 rnn_data_handle_(std::move(data_handle)),
1429 #endif
1430 handles_(max_seq_length, handle_.get()) {
1431 }
1432
1433 public:
1434 CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1435 default;
1436
1437   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1438 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1439 cudnnDataType_t data_type) {
1440 CHECK_GT(max_seq_length, 0);
1441 int dims[] = {batch_size, data_size, 1};
1442 int strides[] = {dims[1] * dims[2], dims[2], 1};
1443 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1444 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1445 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1446 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1447 /*strideA=*/strides));
1448 return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1449 data_size, data_type,
1450 #if CUDNN_VERSION >= 7201
1451 nullptr,
1452 #endif
1453 std::move(tensor_desc));
1454 }
1455
1456   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1457 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1458 const absl::Span<const int>& seq_lengths, bool time_major,
1459 cudnnDataType_t data_type) {
1460 #if CUDNN_VERSION >= 7201
1461 CHECK_GT(max_seq_length, 0);
1462 int dims[] = {batch_size, data_size, 1};
1463 int strides[] = {dims[1] * dims[2], dims[2], 1};
1464 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1465 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1466 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1467 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1468 /*strideA=*/strides));
1469 const int* seq_lengths_array = seq_lengths.data();
1470 RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1471 float padding_fill = 0.0f;
1472 cudnnRNNDataLayout_t layout;
1473 if (time_major) {
1474 layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1475 } else {
1476 layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1477 }
1478 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1479 /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1480 /*layout=*/layout,
1481 /*maxSeqLength=*/max_seq_length,
1482 /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1483 /*seqLengthArray=*/seq_lengths_array,
1484 /*paddingFill*/ (void*)&padding_fill));
1485 return CudnnRnnSequenceTensorDescriptor(
1486 parent, max_seq_length, batch_size, data_size, data_type,
1487 std::move(data_desc), std::move(tensor_desc));
1488 #else
1489 return port::Status(port::error::INVALID_ARGUMENT,
1490 "No supported cudnnSetRNNDataDescriptor when "
1491 "CUDNN_VERSION < 7.2.1");
1492 #endif
1493 }
1494
1495   const cudnnTensorDescriptor_t* handles() const { return handles_.data(); }
1496 #if CUDNN_VERSION >= 7201
1497   const cudnnRNNDataDescriptor_t data_handle() const {
1498 return rnn_data_handle_.get();
1499 }
1500 #endif
1501
1502   int max_seq_length() const { return max_seq_length_; }
1503   int batch_size() const { return batch_size_; }
1504   int data_size() const { return data_size_; }
1505   bool is_var_seq_lengths() const {
1506 #if CUDNN_VERSION >= 7201
1507 return rnn_data_handle_ != nullptr;
1508 #else
1509 return false;
1510 #endif
1511 }
1512
1513 private:
1514 int max_seq_length_;
1515 int batch_size_;
1516 int data_size_;
1517 cudnnDataType_t data_type_;
1518 TensorDescriptor handle_;
1519 #if CUDNN_VERSION >= 7201
1520 RNNDataDescriptor rnn_data_handle_;
1521 #endif
1522 std::vector<cudnnTensorDescriptor_t> handles_; // Copies of handle_.
1523 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1524 };
1525
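// Wraps the 3-D cuDNN tensor descriptor used for RNN hidden and cell states,
// with dimensions {num_layers, batch_size, data_size}.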
1526 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1527 public:
1528   CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1529 int batch_size, int data_size,
1530 cudnnDataType_t data_type)
1531 : handle_(CreateTensorDescriptor()),
1532 num_layers_(num_layers),
1533 batch_size_(batch_size),
1534 data_size_(data_size),
1535 data_type_(data_type) {
1536 int dims[] = {num_layers, batch_size, data_size};
1537 int strides[] = {dims[1] * dims[2], dims[2], 1};
1538 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1539 /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1540 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1541 /*strideA=*/strides));
1542 }
1543
1544   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1545
1546   int num_layers() const { return num_layers_; }
1547   int batch_size() const { return batch_size_; }
1548   int data_size() const { return data_size_; }
1549
1550 private:
1551 TensorDescriptor handle_;
1552 int num_layers_;
1553 int batch_size_;
1554 int data_size_;
1555 cudnnDataType_t data_type_;
1556 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1557 };
1558
1559 namespace {
1560
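// Dimensions of an RNN model, gathered from its descriptors for validation.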
1561 struct RnnModelDims {
1562 int num_layers = 0;
1563 int batch_size = 0;
1564 int max_seq_length = 0;
1565 int hidden_size = 0;
1566 int input_size = 0;
1567 int cell_size = 0;
1568 int dir_count = 0;
1569 };
1570
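// Collects the RNN model dimensions from the forward-pass arguments and checks
// that the state and output shapes are mutually consistent.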
1571 template <class T>
1572 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1573 const CudnnRnnDescriptor& rnn_desc,
1574 const CudnnRnnSequenceTensorDescriptor& input_desc,
1575 const DeviceMemory<T>& input_data,
1576 const CudnnRnnStateTensorDescriptor& input_h_desc,
1577 const DeviceMemory<T>& input_h_data,
1578 const CudnnRnnStateTensorDescriptor& input_c_desc,
1579 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1580 const CudnnRnnSequenceTensorDescriptor& output_desc,
1581 const DeviceMemory<T>& output_data,
1582 const CudnnRnnStateTensorDescriptor& output_h_desc,
1583 const DeviceMemory<T>& output_h_data,
1584 const CudnnRnnStateTensorDescriptor& output_c_desc,
1585 const DeviceMemory<T>& output_c_data) {
1586 // extract model parameters
1587 RnnModelDims model_dims;
1588 model_dims.num_layers = rnn_desc.num_layers();
1589 model_dims.batch_size = input_desc.batch_size();
1590 model_dims.max_seq_length = input_desc.max_seq_length();
1591 model_dims.hidden_size = rnn_desc.hidden_size();
1592 model_dims.input_size = input_desc.data_size();
1593 model_dims.cell_size = rnn_desc.cell_size();
1594 model_dims.dir_count =
1595 (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1596
1597 // check parameters
1598 if (!(input_h_desc.num_layers() ==
1599 model_dims.num_layers * model_dims.dir_count &&
1600 input_h_desc.batch_size() == model_dims.batch_size &&
1601 input_h_desc.data_size() == model_dims.hidden_size)) {
1602 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1603 }
1604   // When the LSTM projection is in use, input_h_desc.data_size() is smaller
1605   // than input_c_desc.data_size(), so the check below only requires '<='.
1606 if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1607 input_h_desc.batch_size() == input_c_desc.batch_size() &&
1608 input_h_desc.data_size() <= input_c_desc.data_size())) {
1609 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1610 }
1611 if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1612 output_desc.batch_size() == model_dims.batch_size &&
1613 output_desc.data_size() ==
1614 model_dims.hidden_size * model_dims.dir_count)) {
1615 return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1616 }
1617 if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1618 input_h_desc.batch_size() == output_h_desc.batch_size() &&
1619 input_h_desc.data_size() == output_h_desc.data_size())) {
1620 return port::Status(port::error::INVALID_ARGUMENT,
1621 "Invalid output_h shape");
1622 }
1623 if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1624 input_h_desc.batch_size() == output_c_desc.batch_size() &&
1625 input_h_desc.data_size() <= output_c_desc.data_size())) {
1626 return port::Status(port::error::INVALID_ARGUMENT,
1627 "Invalid output_c shape");
1628 }
1629
1630 return model_dims;
1631 }
1632
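// Verifies that the params size reported by cuDNN matches the size recorded in
// the RNN descriptor.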
1633 port::Status CheckRNNParameterSize(
1634 const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1635 const CudnnRnnSequenceTensorDescriptor& input_desc) {
1636 size_t params_size_in_bytes = 0;
1637 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1638 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1639       /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
1640 /*dataType=*/rnn_desc.data_type()));
1641 if (static_cast<int64>(params_size_in_bytes) !=
1642 rnn_desc.ParamsSizeInBytes()) {
1643 return port::Status(port::error::INVALID_ARGUMENT,
1644 "Mismatching RNN parameter size");
1645 }
1646 return port::Status::OK();
1647 }
1648
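// Queries the RNN workspace size and allocates it from 'workspace_allocator';
// returns an empty buffer if no workspace is required.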
1649 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1650 Stream* stream, const CudnnHandle& cudnn,
1651 const CudnnRnnDescriptor& rnn_desc,
1652 const CudnnRnnSequenceTensorDescriptor& input_desc,
1653 ScratchAllocator* workspace_allocator) {
1654 // Query the workspace size.
1655 size_t workspace_size_in_bytes = 0;
1656 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1657 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1658 /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1659 /*sizeInBytes=*/&workspace_size_in_bytes));
1660 // Allocate the workspace.
1661 if (workspace_size_in_bytes == 0) {
1662 return DeviceMemory<uint8>();
1663 }
1664 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1665 }
1666
1667 #if CUDNN_VERSION >= 7402
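// Queries and allocates the workspace for fused batch normalization forward
// training.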
1668 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
1669 Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1670 const cudnnBatchNormOps_t& bn_ops,
1671 const cudnnActivationDescriptor_t& activation_desc,
1672 const CudnnTensorDescriptor& x_descriptor,
1673 const CudnnTensorDescriptor& scale_offset_descriptor,
1674 ScratchAllocator* workspace_allocator) {
1675 // Query the workspace size.
1676 size_t workspace_size_in_bytes = 0;
1677 RETURN_IF_CUDNN_ERROR(
1678 cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
1679 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1680 /*xDesc=*/x_descriptor.handle(), /*zDesc=*/x_descriptor.handle(),
1681 /*yDesc=*/x_descriptor.handle(),
1682 /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
1683 /*activationDesc=*/activation_desc,
1684 /*sizeInBytes=*/&workspace_size_in_bytes));
1685 // Allocate the workspace.
1686 if (workspace_size_in_bytes == 0) {
1687 return DeviceMemory<uint8>();
1688 }
1689 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1690 }
1691
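// Queries and allocates the workspace for fused batch normalization backward.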
1692 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
1693 Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1694 const cudnnBatchNormOps_t& bn_ops,
1695 const CudnnTensorDescriptor& x_descriptor,
1696 const CudnnTensorDescriptor& scale_offset_descriptor,
1697 ScratchAllocator* workspace_allocator) {
1698 // Query the workspace size.
1699 size_t workspace_size_in_bytes = 0;
1700 RETURN_IF_CUDNN_ERROR(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
1701 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1702 /*xDesc=*/x_descriptor.handle(),
1703 /*yDesc=*/x_descriptor.handle(),
1704 /*dyDesc=*/x_descriptor.handle(),
1705 /*dzDesc=*/nullptr,
1706 /*dxDesc=*/x_descriptor.handle(),
1707 /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
1708 /*activationDesc=*/nullptr, /*sizeInBytes=*/&workspace_size_in_bytes));
1709 // Allocate the workspace.
1710 if (workspace_size_in_bytes == 0) {
1711 return DeviceMemory<uint8>();
1712 }
1713 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1714 }
1715
1716 #endif
1717
1718 } // namespace
1719
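// Runs the RNN forward pass. Inference uses cudnnRNNForwardInference(Ex) and
// training uses cudnnRNNForwardTraining(Ex); the *Ex variants handle variable
// sequence lengths (cuDNN >= 7.2.1). Training additionally allocates the
// reserve space that the backward pass later consumes.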
1720 template <class T>
1721 port::Status CudnnSupport::DoRnnForwardImpl(
1722 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1723 const CudnnRnnSequenceTensorDescriptor& input_desc,
1724 const DeviceMemory<T>& input_data,
1725 const CudnnRnnStateTensorDescriptor& input_h_desc,
1726 const DeviceMemory<T>& input_h_data,
1727 const CudnnRnnStateTensorDescriptor& input_c_desc,
1728 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1729 const CudnnRnnSequenceTensorDescriptor& output_desc,
1730 DeviceMemory<T>* output_data,
1731 const CudnnRnnStateTensorDescriptor& output_h_desc,
1732 DeviceMemory<T>* output_h_data,
1733 const CudnnRnnStateTensorDescriptor& output_c_desc,
1734 DeviceMemory<T>* output_c_data, bool is_training,
1735 ScratchAllocator* reserve_space_allocator,
1736 ScratchAllocator* workspace_allocator,
1737 dnn::ProfileResult* output_profile_result) {
1738 SE_ASSIGN_OR_RETURN(
1739 RnnModelDims model_dims,
1740 ExtractAndCheckRnnForward(
1741 rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1742 input_c_desc, input_c_data, params, output_desc, *output_data,
1743 output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1744
1745 auto cudnn = cudnn_->GetHandle(parent_, stream);
1746
1747 SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1748 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1749 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1750                                          workspace_allocator));
1751
1752   // Query the reserve space size and allocate it. The reserve space is only
1753   // needed when training.
1754 DeviceMemory<uint8> reserve_space;
1755 if (is_training) {
1756 size_t reserve_space_size_in_bytes = 0;
1757 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
1758 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1759 /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
1760 /*sizeInBytes=*/&reserve_space_size_in_bytes));
1761
1762 if (reserve_space_size_in_bytes > 0) {
1763 SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1764 reserve_space_size_in_bytes));
1765 }
1766 }
1767
1768 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1769 const bool is_profiling = output_profile_result != nullptr;
1770 if (is_profiling) {
1771 timer.reset(new GpuTimer(parent_));
1772     // The start and stop of the timer should be as close to the cuDNN call as
1773     // possible. Other threads may still issue work onto this stream, so the
1774     // caller may need to take multiple profiling measurements.
1775 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1776 return port::Status(port::error::INTERNAL, "Failed to start timer");
1777 }
1778 }
1779
1780 if (!is_training) {
1781 if (input_desc.is_var_seq_lengths()) {
1782 #if CUDNN_VERSION >= 7201
1783 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
1784 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1785 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1786 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1787 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1788 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1789 /*yDesc=*/output_desc.data_handle(),
1790 /*y=*/output_data->opaque(),
1791 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1792 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1793 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1794 nullptr,
1795 /*workspace=*/workspace.opaque(),
1796 /*workSpaceSizeInBytes=*/workspace.size()));
1797 #else
1798 return port::Status(port::error::INVALID_ARGUMENT,
1799 "No supported cudnnRNNForwardInferenceEx when "
1800 "CUDNN_VERSION < 7.2.1");
1801 #endif
1802 } else {
1803 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
1804 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1805 /*seqLength=*/model_dims.max_seq_length,
1806 /*xDesc=*/input_desc.handles(),
1807 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1808 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1809 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1810 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1811 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1812 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1813 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1814 /*workSpaceSizeInBytes=*/workspace.size()));
1815 }
1816 } else {
1817 if (input_desc.is_var_seq_lengths()) {
1818 #if CUDNN_VERSION >= 7201
1819 // cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED);
1820 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
1821 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1822 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1823 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1824 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1825 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1826 /*yDesc=*/output_desc.data_handle(),
1827 /*y=*/output_data->opaque(),
1828 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1829 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1830 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1831 nullptr,
1832 /*workspace=*/workspace.opaque(),
1833 /*workSpaceSizeInBytes=*/workspace.size(),
1834 /*reserveSpace=*/reserve_space.opaque(),
1835 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1836 #else
1837 return port::Status(port::error::INVALID_ARGUMENT,
1838 "No supported cudnnRNNForwardTrainingEx when "
1839 "CUDNN_VERSION < 7.2.1");
1840 #endif
1841 } else {
1842 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
1843 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1844 /*seqLength=*/model_dims.max_seq_length,
1845 /*xDesc=*/input_desc.handles(),
1846 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1847 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1848 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1849 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1850 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1851 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1852 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1853 /*workSpaceSizeInBytes=*/workspace.size(),
1854 /*reserveSpace=*/reserve_space.opaque(),
1855 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1856 }
1857 }
1858
1859 if (is_profiling) {
1860 if (!timer->Stop(AsGpuStream(stream))) {
1861 return port::Status(port::error::INTERNAL, "Failed to stop timer");
1862 }
1863 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1864 output_profile_result->set_algorithm(algo_desc);
1865 output_profile_result->set_elapsed_time_in_ms(
1866 timer->GetElapsedMilliseconds());
1867 }
1868
1869 return port::Status::OK();
1870 }
1871
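// Runs the RNN backward data pass and, when 'params_backprop_data' is given,
// the backward weights pass, reusing the reserve space produced by the forward
// training pass.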
1872 template <class T>
1873 port::Status CudnnSupport::DoRnnBackwardImpl(
1874 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1875 const CudnnRnnSequenceTensorDescriptor& input_desc,
1876 const DeviceMemory<T>& input_data,
1877 const CudnnRnnStateTensorDescriptor& input_h_desc,
1878 const DeviceMemory<T>& input_h_data,
1879 const CudnnRnnStateTensorDescriptor& input_c_desc,
1880 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1881 const CudnnRnnSequenceTensorDescriptor& output_desc,
1882 const DeviceMemory<T>& output_data,
1883 const CudnnRnnStateTensorDescriptor& output_h_desc,
1884 const DeviceMemory<T>& output_h_data,
1885 const CudnnRnnStateTensorDescriptor& output_c_desc,
1886 const DeviceMemory<T>& output_c_data,
1887 const DeviceMemory<T>& output_backprop_data,
1888 const DeviceMemory<T>& output_h_backprop_data,
1889 const DeviceMemory<T>& output_c_backprop_data,
1890 DeviceMemory<T>* input_backprop_data,
1891 DeviceMemory<T>* input_h_backprop_data,
1892 DeviceMemory<T>* input_c_backprop_data,
1893 DeviceMemory<T>* params_backprop_data,
1894 DeviceMemory<uint8>* reserve_space_data,
1895 ScratchAllocator* workspace_allocator,
1896 dnn::ProfileResult* output_profile_result) {
1897 SE_ASSIGN_OR_RETURN(
1898 RnnModelDims model_dims,
1899 ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
1900 input_h_data, input_c_desc, input_c_data,
1901 params, output_desc, output_data, output_h_desc,
1902 output_h_data, output_c_desc, output_c_data));
1903
1904 auto cudnn = cudnn_->GetHandle(parent_, stream);
1905
1906 SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1907 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1908 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1909 workspace_allocator));
1910
1911 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1912 const bool is_profiling = output_profile_result != nullptr;
1913 if (is_profiling) {
1914 timer.reset(new GpuTimer(parent_));
1915     // The start and stop of the timer should be as close to the cuDNN call as
1916     // possible. Other threads may still issue work onto this stream, so the
1917     // caller may need to take multiple profiling measurements.
1918 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1919 return port::Status(port::error::INTERNAL, "Failed to start timer");
1920 }
1921 }
1922
1923 if (input_desc.is_var_seq_lengths()) {
1924 #if CUDNN_VERSION >= 7201
1925 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
1926 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1927 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
1928 /*dyDesc=*/output_desc.data_handle(),
1929 /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
1930 /*dhyDesc=*/output_h_desc.handle(),
1931 /*dhy=*/output_h_backprop_data.opaque(),
1932 /*dcyDesc=*/output_c_desc.handle(),
1933 /*dcy=*/output_c_backprop_data.opaque(),
1934 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1935 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1936 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1937 /*dxDesc=*/input_desc.data_handle(),
1938 /*dx=*/input_backprop_data->opaque(),
1939 /*dhxDesc=*/input_h_desc.handle(),
1940 /*dhx=*/input_h_backprop_data->opaque(),
1941 /*dcxDesc=*/input_c_desc.handle(),
1942 /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
1943 /*workspace=*/workspace.opaque(),
1944 /*workSpaceSizeInBytes=*/workspace.size(),
1945 /*reserveSpace=*/reserve_space_data->opaque(),
1946 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1947 #else
1948 return port::Status(port::error::INVALID_ARGUMENT,
1949 "No supported cudnnRNNBackwardDataEx when "
1950 "CUDNN_VERSION < 7.2.1");
1951 #endif
1952 } else {
1953 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
1954 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1955 /*seqLength=*/model_dims.max_seq_length,
1956 /*yDesc=*/output_desc.handles(),
1957 /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
1958 /*dy=*/output_backprop_data.opaque(),
1959 /*dhyDesc=*/output_h_desc.handle(),
1960 /*dhy=*/output_h_backprop_data.opaque(),
1961 /*dcyDesc=*/output_c_desc.handle(),
1962 /*dcy=*/output_c_backprop_data.opaque(),
1963 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1964 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1965 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1966 /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
1967 /*dhxDesc=*/input_h_desc.handle(),
1968 /*dhx=*/input_h_backprop_data->opaque(),
1969 /*dcxDesc=*/input_c_desc.handle(),
1970 /*dcx=*/input_c_backprop_data->opaque(),
1971 /*workspace=*/workspace.opaque(),
1972 /*workSpaceSizeInBytes=*/workspace.size(),
1973 /*reserveSpace=*/reserve_space_data->opaque(),
1974 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1975 }
1976
1977 if (params_backprop_data != nullptr) {
1978     // Zero out dw; cudnnRNNBackwardWeights accumulates gradients into it.
1979 stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
1980 if (input_desc.is_var_seq_lengths()) {
1981 #if CUDNN_VERSION >= 7201
1982 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
1983 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1984 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1985 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1986 /*yDesc=*/output_desc.data_handle(),
1987 /*y=*/output_data.opaque(),
1988 /*workspace=*/workspace.opaque(),
1989 /*workSpaceSizeInBytes=*/workspace.size(),
1990 /*dwDesc=*/rnn_desc.params_handle(),
1991 /*dw=*/params_backprop_data->opaque(),
1992 /*reserveSpace=*/reserve_space_data->opaque(),
1993 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1994 #else
1995 return port::Status(port::error::INVALID_ARGUMENT,
1996 "No supported cudnnRNNBackwardWeightsEx when "
1997 "CUDNN_VERSION < 7.2.1");
1998 #endif
1999 } else {
2000       // Make the backward weights call.
2001 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
2002 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2003 /*seqLength=*/model_dims.max_seq_length,
2004 /*xDesc=*/input_desc.handles(),
2005 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2006 /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
2007 /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
2008 /*workSpaceSizeInBytes=*/workspace.size(),
2009 /*dwDesc=*/rnn_desc.params_handle(),
2010 /*dw=*/params_backprop_data->opaque(),
2011 /*reserveSpace=*/reserve_space_data->opaque(),
2012 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2013 }
2014 }
2015
2016 if (is_profiling) {
2017 if (!timer->Stop(AsGpuStream(stream))) {
2018 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2019 }
2020 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2021 output_profile_result->set_algorithm(algo_desc);
2022 output_profile_result->set_elapsed_time_in_ms(
2023 timer->GetElapsedMilliseconds());
2024 }
2025
2026 return port::Status::OK();
2027 }
2028
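// Computes the CTC loss and its gradients via cudnnCTCLoss (cuDNN >= 7.6.3).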
2029 port::Status CudnnSupport::DoCtcLossImpl(
2030 Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
2031 const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2032 absl::Span<const int> labels_lengths_data,
2033 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2034 const CudnnRnnStateTensorDescriptor& grads_desc,
2035 DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
2036 DeviceMemory<uint8> scratch_memory) {
2037 auto cudnn = cudnn_->GetHandle(parent_, stream);
2038
2039 int kNumTimestamps = probs_desc.num_layers();
2040 int kBatchSize = probs_desc.batch_size();
2041 int kNumLabels = probs_desc.data_size();
2042 int total_size = kNumLabels * kNumTimestamps * kBatchSize;
2043 (void)total_size;
2044
2045 #if CUDNN_VERSION >= 7603
2046 RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
2047 /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
2048 /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
2049 /*labelLengths=*/labels_lengths_data.data(),
2050 /*inputLengths=*/input_lengths_data.data(),
2051 /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
2052 /*gradients=*/grads_data.opaque(),
2053 /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
2054 /*ctcLossDesc=*/ctc_loss_desc.handle(),
2055 /*workspace=*/scratch_memory.opaque(),
2056 /*workSpaceSizeInBytes=*/scratch_memory.size()));
2057 #else
2058 return port::Status(port::error::INVALID_ARGUMENT,
2059 "No supported cudnnCTCLoss when "
2060 "CUDNN_VERSION < 7.6.3");
2061 #endif
2062
2063 return port::Status::OK();
2064 }
2065
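// The factory methods below adapt the generic dnn:: descriptor interfaces to
// their cuDNN-backed implementations.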
2066 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2067 CudnnSupport::createRnnDescriptor(
2068 int num_layers, int hidden_size, int input_size, int cell_size,
2069 int batch_size, dnn::RnnInputMode input_mode,
2070 dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2071 dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2072 float dropout, uint64 seed, ScratchAllocator* state_allocator,
2073 bool use_padded_io) {
2074 // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
2075 // not enqueueing anything into a stream, we pass in the null stream.
2076 auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
2077 SE_ASSIGN_OR_RETURN(
2078 CudnnRnnDescriptor rnn_desc,
2079 CudnnRnnDescriptor::Create(
2080 cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
2081 ToCudnnRnnInputMode(input_mode),
2082 ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
2083 ToCudnnDataType(data_type), GetRnnComputeType(data_type),
2084 algorithm_config, dropout, seed, state_allocator, use_padded_io));
2085 return std::unique_ptr<dnn::RnnDescriptor>(
2086 new CudnnRnnDescriptor(std::move(rnn_desc)));
2087 }
2088
2089 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2090 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
2091 int batch_size, int data_size,
2092 dnn::DataType data_type) {
2093 SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2094 CudnnRnnSequenceTensorDescriptor::Create(
2095 parent_, max_seq_length, batch_size, data_size,
2096 ToCudnnDataType(data_type)));
2097 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2098 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2099 }
2100
2101 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2102 CudnnSupport::createRnnSequenceTensorDescriptor(
2103 int max_seq_length, int batch_size, int data_size,
2104 const absl::Span<const int>& seq_lengths, bool time_major,
2105 dnn::DataType data_type) {
2106 SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2107 CudnnRnnSequenceTensorDescriptor::Create(
2108 parent_, max_seq_length, batch_size, data_size,
2109 seq_lengths, time_major, ToCudnnDataType(data_type)));
2110 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2111 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2112 }
2113
2114 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2115 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2116 int data_size,
2117 dnn::DataType data_type) {
2118 return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
2119 new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
2120 data_size, ToCudnnDataType(data_type)));
2121 }
2122
2123 bool CudnnSupport::DoRnnForward(
2124 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2125 const dnn::RnnSequenceTensorDescriptor& input_desc,
2126 const DeviceMemory<Eigen::half>& input_data,
2127 const dnn::RnnStateTensorDescriptor& input_h_desc,
2128 const DeviceMemory<Eigen::half>& input_h_data,
2129 const dnn::RnnStateTensorDescriptor& input_c_desc,
2130 const DeviceMemory<Eigen::half>& input_c_data,
2131 const DeviceMemory<Eigen::half>& params,
2132 const dnn::RnnSequenceTensorDescriptor& output_desc,
2133 DeviceMemory<Eigen::half>* output_data,
2134 const dnn::RnnStateTensorDescriptor& output_h_desc,
2135 DeviceMemory<Eigen::half>* output_h_data,
2136 const dnn::RnnStateTensorDescriptor& output_c_desc,
2137 DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2138 ScratchAllocator* reserve_space_allocator,
2139 ScratchAllocator* workspace_allocator,
2140 dnn::ProfileResult* output_profile_result) {
2141 const CudnnRnnDescriptor& cudnn_rnn_desc =
2142 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2143 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2144 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2145 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2146 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2147 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2148 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2149 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2150 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2151 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2152 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2153 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2154 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2155 return IsStatusOk(
2156 DoRnnForwardImpl<Eigen::half>(
2157 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2158 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2159 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2160 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2161 reserve_space_allocator, workspace_allocator, output_profile_result),
2162 /*report_error=*/!output_profile_result);
2163 }
2164
2165 bool CudnnSupport::DoRnnForward(
2166 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2167 const dnn::RnnSequenceTensorDescriptor& input_desc,
2168 const DeviceMemory<float>& input_data,
2169 const dnn::RnnStateTensorDescriptor& input_h_desc,
2170 const DeviceMemory<float>& input_h_data,
2171 const dnn::RnnStateTensorDescriptor& input_c_desc,
2172 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2173 const dnn::RnnSequenceTensorDescriptor& output_desc,
2174 DeviceMemory<float>* output_data,
2175 const dnn::RnnStateTensorDescriptor& output_h_desc,
2176 DeviceMemory<float>* output_h_data,
2177 const dnn::RnnStateTensorDescriptor& output_c_desc,
2178 DeviceMemory<float>* output_c_data, bool is_training,
2179 ScratchAllocator* reserve_space_allocator,
2180 ScratchAllocator* workspace_allocator,
2181 dnn::ProfileResult* output_profile_result) {
2182 const CudnnRnnDescriptor& cudnn_rnn_desc =
2183 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2184 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2185 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2186 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2187 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2188 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2189 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2190 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2191 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2192 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2193 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2194 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2195 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2196 return IsStatusOk(
2197 DoRnnForwardImpl<float>(
2198 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2199 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2200 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2201 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2202 reserve_space_allocator, workspace_allocator, output_profile_result),
2203 /*report_error=*/!output_profile_result);
2204 }
2205
2206 bool CudnnSupport::DoRnnForward(
2207 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2208 const dnn::RnnSequenceTensorDescriptor& input_desc,
2209 const DeviceMemory<double>& input_data,
2210 const dnn::RnnStateTensorDescriptor& input_h_desc,
2211 const DeviceMemory<double>& input_h_data,
2212 const dnn::RnnStateTensorDescriptor& input_c_desc,
2213 const DeviceMemory<double>& input_c_data,
2214 const DeviceMemory<double>& params,
2215 const dnn::RnnSequenceTensorDescriptor& output_desc,
2216 DeviceMemory<double>* output_data,
2217 const dnn::RnnStateTensorDescriptor& output_h_desc,
2218 DeviceMemory<double>* output_h_data,
2219 const dnn::RnnStateTensorDescriptor& output_c_desc,
2220 DeviceMemory<double>* output_c_data, bool is_training,
2221 ScratchAllocator* reserve_space_allocator,
2222 ScratchAllocator* workspace_allocator,
2223 dnn::ProfileResult* output_profile_result) {
2224 const CudnnRnnDescriptor& cudnn_rnn_desc =
2225 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2226 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2227 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2228 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2229 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2230 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2231 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2232 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2233 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2234 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2235 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2236 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2237 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2238 return IsStatusOk(
2239 DoRnnForwardImpl<double>(
2240 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2241 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2242 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2243 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2244 reserve_space_allocator, workspace_allocator, output_profile_result),
2245 /*report_error=*/!output_profile_result);
2246 }
2247
2248 bool CudnnSupport::DoRnnBackward(
2249 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2250 const dnn::RnnSequenceTensorDescriptor& input_desc,
2251 const DeviceMemory<Eigen::half>& input_data,
2252 const dnn::RnnStateTensorDescriptor& input_h_desc,
2253 const DeviceMemory<Eigen::half>& input_h_data,
2254 const dnn::RnnStateTensorDescriptor& input_c_desc,
2255 const DeviceMemory<Eigen::half>& input_c_data,
2256 const DeviceMemory<Eigen::half>& params,
2257 const dnn::RnnSequenceTensorDescriptor& output_desc,
2258 const DeviceMemory<Eigen::half>& output_data,
2259 const dnn::RnnStateTensorDescriptor& output_h_desc,
2260 const DeviceMemory<Eigen::half>& output_h_data,
2261 const dnn::RnnStateTensorDescriptor& output_c_desc,
2262 const DeviceMemory<Eigen::half>& output_c_data,
2263 const DeviceMemory<Eigen::half>& output_backprop_data,
2264 const DeviceMemory<Eigen::half>& output_h_backprop_data,
2265 const DeviceMemory<Eigen::half>& output_c_backprop_data,
2266 DeviceMemory<Eigen::half>* input_backprop_data,
2267 DeviceMemory<Eigen::half>* input_h_backprop_data,
2268 DeviceMemory<Eigen::half>* input_c_backprop_data,
2269 DeviceMemory<Eigen::half>* params_backprop_data,
2270 DeviceMemory<uint8>* reserve_space_data,
2271 ScratchAllocator* workspace_allocator,
2272 dnn::ProfileResult* output_profile_result) {
2273 const CudnnRnnDescriptor& cudnn_rnn_desc =
2274 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2275 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2276 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2277 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2278 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2279 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2280 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2281 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2282 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2283 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2284 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2285 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2286 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2287 return IsStatusOk(
2288 DoRnnBackwardImpl<Eigen::half>(
2289 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2290 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2291 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2292 output_h_data, cudnn_output_c_desc, output_c_data,
2293 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2294 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2295 params_backprop_data, reserve_space_data, workspace_allocator,
2296 output_profile_result),
2297 /*report_error=*/!output_profile_result);
2298 }
2299
2300 bool CudnnSupport::DoRnnBackward(
2301 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2302 const dnn::RnnSequenceTensorDescriptor& input_desc,
2303 const DeviceMemory<float>& input_data,
2304 const dnn::RnnStateTensorDescriptor& input_h_desc,
2305 const DeviceMemory<float>& input_h_data,
2306 const dnn::RnnStateTensorDescriptor& input_c_desc,
2307 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2308 const dnn::RnnSequenceTensorDescriptor& output_desc,
2309 const DeviceMemory<float>& output_data,
2310 const dnn::RnnStateTensorDescriptor& output_h_desc,
2311 const DeviceMemory<float>& output_h_data,
2312 const dnn::RnnStateTensorDescriptor& output_c_desc,
2313 const DeviceMemory<float>& output_c_data,
2314 const DeviceMemory<float>& output_backprop_data,
2315 const DeviceMemory<float>& output_h_backprop_data,
2316 const DeviceMemory<float>& output_c_backprop_data,
2317 DeviceMemory<float>* input_backprop_data,
2318 DeviceMemory<float>* input_h_backprop_data,
2319 DeviceMemory<float>* input_c_backprop_data,
2320 DeviceMemory<float>* params_backprop_data,
2321 DeviceMemory<uint8>* reserve_space_data,
2322 ScratchAllocator* workspace_allocator,
2323 dnn::ProfileResult* output_profile_result) {
2324 const CudnnRnnDescriptor& cudnn_rnn_desc =
2325 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2326 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2327 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2328 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2329 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2330 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2331 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2332 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2333 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2334 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2335 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2336 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2337 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2338 return IsStatusOk(
2339 DoRnnBackwardImpl<float>(
2340 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2341 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2342 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2343 output_h_data, cudnn_output_c_desc, output_c_data,
2344 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2345 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2346 params_backprop_data, reserve_space_data, workspace_allocator,
2347 output_profile_result),
2348 /*report_error=*/!output_profile_result);
2349 }
2350
2351 bool CudnnSupport::DoRnnBackward(
2352 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2353 const dnn::RnnSequenceTensorDescriptor& input_desc,
2354 const DeviceMemory<double>& input_data,
2355 const dnn::RnnStateTensorDescriptor& input_h_desc,
2356 const DeviceMemory<double>& input_h_data,
2357 const dnn::RnnStateTensorDescriptor& input_c_desc,
2358 const DeviceMemory<double>& input_c_data,
2359 const DeviceMemory<double>& params,
2360 const dnn::RnnSequenceTensorDescriptor& output_desc,
2361 const DeviceMemory<double>& output_data,
2362 const dnn::RnnStateTensorDescriptor& output_h_desc,
2363 const DeviceMemory<double>& output_h_data,
2364 const dnn::RnnStateTensorDescriptor& output_c_desc,
2365 const DeviceMemory<double>& output_c_data,
2366 const DeviceMemory<double>& output_backprop_data,
2367 const DeviceMemory<double>& output_h_backprop_data,
2368 const DeviceMemory<double>& output_c_backprop_data,
2369 DeviceMemory<double>* input_backprop_data,
2370 DeviceMemory<double>* input_h_backprop_data,
2371 DeviceMemory<double>* input_c_backprop_data,
2372 DeviceMemory<double>* params_backprop_data,
2373 DeviceMemory<uint8>* reserve_space_data,
2374 ScratchAllocator* workspace_allocator,
2375 dnn::ProfileResult* output_profile_result) {
2376 const CudnnRnnDescriptor& cudnn_rnn_desc =
2377 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2378 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2379 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2380 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2381 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2382 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2383 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2384 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2385 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2386 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2387 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2388 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2389 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2390 return IsStatusOk(
2391 DoRnnBackwardImpl<double>(
2392 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2393 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2394 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2395 output_h_data, cudnn_output_c_desc, output_c_data,
2396 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2397 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2398 params_backprop_data, reserve_space_data, workspace_allocator,
2399 output_profile_result),
2400 /*report_error=*/!output_profile_result);
2401 }
2402
2403 namespace {
2404
2405 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2406 // and backward filter.
2407
2408 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2409 const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2410 const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2411 const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2412 size_t memory_limit_bytes) {
2413 cudnnConvolutionFwdPreference_t preference =
2414 specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2415 : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2416 cudnnConvolutionFwdAlgo_t algo_to_use;
2417 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2418 cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2419 output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2420 return algo_to_use;
2421 }
2422
2423 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2424 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2425 const CudnnTensorDescriptor& input_nd,
2426 const CudnnFilterDescriptor& filter,
2427 const CudnnConvolutionDescriptor& conv,
2428 const CudnnTensorDescriptor& output_nd,
2429 bool specify_workspace_limit,
2430 size_t memory_limit_bytes) {
2431 cudnnConvolutionBwdDataPreference_t preference =
2432 specify_workspace_limit
2433 ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2434 : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2435 cudnnConvolutionBwdDataAlgo_t algo_to_use;
2436 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2437 cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2438 input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2439 return algo_to_use;
2440 }
2441
2442 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2443 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2444 const CudnnTensorDescriptor& input_nd,
2445 const CudnnFilterDescriptor& filter,
2446 const CudnnConvolutionDescriptor& conv,
2447 const CudnnTensorDescriptor& output_nd,
2448 bool specify_workspace_limit,
2449 size_t memory_limit_bytes) {
2450 cudnnConvolutionBwdFilterPreference_t preference =
2451 specify_workspace_limit
2452 ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2453 : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2454 cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2455 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2456 cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2457 filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2458 return algo_to_use;
2459 }
2460
2461 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2462 Stream* stream, const CudnnHandle& cudnn,
2463 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2464 const CudnnConvolutionDescriptor& conv,
2465 const CudnnTensorDescriptor& output_nd,
2466 const dnn::AlgorithmDesc& algorithm_desc,
2467 ScratchAllocator* scratch_allocator) {
2468 // TODO(csigg): This has side effects on the convolution descriptor. It is
2469 // functionally correct because the convolution is run with the algorithm of
2470 // the last call to this function, but should be fixed anyway.
2471 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2472
2473 // Query the size of the workspace and allocate it.
2474 size_t size_in_bytes;
2475 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2476 cudnn.handle(),
2477 /*xDesc=*/input_nd.handle(),
2478 /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2479 /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
2480 /*sizeInBytes=*/&size_in_bytes));
2481
2482 int64 size_in_bytes_int64 = size_in_bytes;
2483
2484 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2485 return port::Status(
2486 port::error::INTERNAL,
2487 "cudnnGetConvolutionForwardWorkspaceSize() returned "
2488 "negative sizeInBytes value. This could be a cudnn bug.");
2489 }
2490
2491 if (size_in_bytes_int64 == 0) {
2492 return DeviceMemory<uint8>();
2493 }
2494
2495 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2496 return port::Status(port::error::INVALID_ARGUMENT,
2497 "No scratch allocator provided");
2498 }
2499
2500 return scratch_allocator->AllocateBytes(size_in_bytes);
2501 }
2502
2503 port::StatusOr<DeviceMemory<uint8>>
2504 AllocateCudnnConvolutionBackwardDataWorkspace(
2505 Stream* stream, const CudnnHandle& cudnn,
2506 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2507 const CudnnConvolutionDescriptor& conv,
2508 const CudnnTensorDescriptor& output_nd,
2509 const dnn::AlgorithmDesc& algorithm_desc,
2510 ScratchAllocator* scratch_allocator) {
2511 // TODO(csigg): This has side effects on the convolution descriptor. It is
2512 // functionally correct because the convolution is run with the algorithm of
2513 // the last call to this function, but should be fixed anyway.
2514 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2515
2516 // Query the size of the workspace and allocate it.
2517 size_t size_in_bytes;
2518 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2519 cudnn.handle(),
2520 /*wDesc=*/filter.handle(),
2521 /*dyDesc=*/output_nd.handle(),
2522 /*convDesc=*/conv.handle(),
2523 /*dxDesc=*/input_nd.handle(),
2524 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2525 /*sizeInBytes=*/&size_in_bytes));
2526
2527 int64 size_in_bytes_int64 = size_in_bytes;
2528
2529 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2530 return port::Status(
2531 port::error::INTERNAL,
2532 "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2533 "negative sizeInBytes value. This could be a cudnn bug.");
2534 }
2535
2536 if (size_in_bytes_int64 == 0) {
2537 return DeviceMemory<uint8>();
2538 }
2539
2540 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2541 return port::Status(port::error::INVALID_ARGUMENT,
2542 "No scratch allocator provided");
2543 }
2544
2545 return scratch_allocator->AllocateBytes(size_in_bytes);
2546 }
2547
2548 port::StatusOr<DeviceMemory<uint8>>
2549 AllocateCudnnConvolutionBackwardFilterWorkspace(
2550 Stream* stream, const CudnnHandle& cudnn,
2551 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2552 const CudnnConvolutionDescriptor& conv,
2553 const CudnnTensorDescriptor& output_nd,
2554 const dnn::AlgorithmDesc& algorithm_desc,
2555 ScratchAllocator* scratch_allocator) {
2556 // TODO(csigg): This has side effects on the convolution descriptor. It is
2557 // functionally correct because the convolution is run with the algorithm of
2558 // the last call to this function, but should be fixed anyway.
2559 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2560
2561 // Query the size of the workspace and allocate it.
2562 size_t size_in_bytes;
2563 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
2564 cudnn.handle(),
2565 /*xDesc=*/input_nd.handle(),
2566 /*dyDesc=*/output_nd.handle(),
2567 /*convDesc=*/conv.handle(),
2568 /*gradDesc=*/filter.handle(),
2569 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2570 /*sizeInBytes=*/&size_in_bytes));
2571
2572 int64 size_in_bytes_int64 = size_in_bytes;
2573
2574 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2575 return port::Status(
2576 port::error::INTERNAL,
2577 "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
2578 "negative sizeInBytes value. This could be a cudnn bug.");
2579 }
2580
2581 if (size_in_bytes_int64 == 0) {
2582 return DeviceMemory<uint8>();
2583 }
2584
2585 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2586 return port::Status(port::error::INVALID_ARGUMENT,
2587 "No scratch allocator provided");
2588 }
2589
2590 return scratch_allocator->AllocateBytes(size_in_bytes);
2591 }
2592
2593 static bool TensorOpMathAvailable(int cc_major) {
2594 return cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled();
2595 }
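// For example (illustrative note, no behavior change): on a Volta-class GPU
// (compute capability 7.0 or higher) with a cuDNN 7+ build, tensor op math is
// reported as available unless it has been disabled via TensorOpMathEnabled().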
2596
2597 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
2598 Stream* stream, const CudnnHandle& cudnn,
2599 const dnn::AlgorithmConfig& algorithm_config,
2600 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2601 const CudnnConvolutionDescriptor& conv,
2602 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2603 DeviceMemory<uint8>* scratch) {
2604 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2605 if (!algo_desc.has_value()) {
2606 // Pick fastest algorithm within memory limit according to cuDNN's
2607 // heuristics.
2608 bool specify_workspace_limit = scratch_allocator != nullptr;
2609 auto memory_limit_bytes =
2610 specify_workspace_limit
2611 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
2612 : 0ll;
2613 SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
2614 GetCudnnConvolutionForwardAlgo(
2615 cudnn, input_nd, filter, conv, output_nd,
2616 specify_workspace_limit, memory_limit_bytes));
2617 int cc_major, cc_minor;
2618 std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
2619 algo_desc = dnn::AlgorithmDesc(
2620 algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
2621 }
2622
2623 const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
2624 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2625 scratch_allocator);
2626
2627 if (scratch_or.ok()) {
2628 *scratch = scratch_or.ValueOrDie();
2629 return *algo_desc;
2630 }
2631
2632 algo_desc = algorithm_config.algorithm_no_scratch();
2633
2634   // Failed to allocate workspace for the first algorithm; fall back to the
2635   // no_scratch algorithm.
2636 if (!algo_desc.has_value()) {
2637 return port::Status(
2638 port::error::INVALID_ARGUMENT,
2639 absl::StrCat("The primary convolution algorithm failed, ",
2640 "while a secondary algorithm is not provided. ",
2641 "Returned status: ", scratch_or.status().ToString()));
2642 }
2643
2644 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
2645 stream, cudnn, input_nd, filter, conv,
2646 output_nd, *algo_desc, scratch_allocator));
2647 return *algo_desc;
2648 }
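// Note on the selection flow above (and in the backward-data/filter variants
// below): if the caller pinned an algorithm via AlgorithmConfig, it is used
// directly; otherwise cuDNN's heuristics pick one, capped by the scratch
// allocator's memory limit. If workspace allocation for that choice fails,
// the code retries with algorithm_no_scratch(), and returns INVALID_ARGUMENT
// when no fallback was provided.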
2649
2650 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
2651 Stream* stream, const CudnnHandle& cudnn,
2652 const dnn::AlgorithmConfig& algorithm_config,
2653 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2654 const CudnnConvolutionDescriptor& conv,
2655 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2656 DeviceMemory<uint8>* scratch) {
2657 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2658 if (!algo_desc.has_value()) {
2659 // Pick fastest algorithm within memory limit according to cuDNN's
2660 // heuristics.
2661 bool specify_workspace_limit = scratch_allocator != nullptr;
2662 auto memory_limit_bytes =
2663 specify_workspace_limit
2664 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
2665 : 0ll;
2666 SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
2667 GetCudnnConvolutionBackwardDataAlgo(
2668 cudnn, input_nd, filter, conv, output_nd,
2669 specify_workspace_limit, memory_limit_bytes));
2670 int cc_major, cc_minor;
2671 std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
2672 algo_desc = dnn::AlgorithmDesc(
2673 algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
2674 }
2675
2676 const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
2677 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2678 scratch_allocator);
2679
2680 if (scratch_or.ok()) {
2681 *scratch = scratch_or.ValueOrDie();
2682 return *algo_desc;
2683 }
2684
2685 algo_desc = algorithm_config.algorithm_no_scratch();
2686
2687   // Failed to allocate workspace for the first algorithm; fall back to the
2688   // no_scratch algorithm.
2689 if (!algo_desc.has_value()) {
2690 return port::Status(
2691 port::error::INVALID_ARGUMENT,
2692 "The primary convolution algorithm failed memory allocation, "
2693 "while a secondary algorithm is not provided.");
2694 }
2695
2696 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
2697 stream, cudnn, input_nd, filter, conv,
2698 output_nd, *algo_desc, scratch_allocator));
2699 return *algo_desc;
2700 }
2701
2702 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
2703 Stream* stream, const CudnnHandle& cudnn,
2704 const dnn::AlgorithmConfig& algorithm_config,
2705 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2706 const CudnnConvolutionDescriptor& conv,
2707 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2708 DeviceMemory<uint8>* scratch) {
2709 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2710 if (!algo_desc.has_value()) {
2711 // Pick fastest algorithm within memory limit according to cuDNN's
2712 // heuristics.
2713 bool specify_workspace_limit = scratch_allocator != nullptr;
2714 auto memory_limit_bytes =
2715 specify_workspace_limit
2716 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
2717 : 0ll;
2718 SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
2719 GetCudnnConvolutionBackwardFilterAlgo(
2720 cudnn, input_nd, filter, conv, output_nd,
2721 specify_workspace_limit, memory_limit_bytes));
2722 int cc_major, cc_minor;
2723 std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
2724 algo_desc = dnn::AlgorithmDesc(
2725 algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
2726 }
2727
2728 auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
2729 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2730 scratch_allocator);
2731
2732 if (scratch_or.ok()) {
2733 *scratch = scratch_or.ValueOrDie();
2734 return *algo_desc;
2735 }
2736
2737 algo_desc = algorithm_config.algorithm_no_scratch();
2738
2739   // Failed to allocate workspace for the first algorithm; fall back to the
2740   // no_scratch algorithm.
2741 if (!algo_desc.has_value()) {
2742 return port::Status(
2743 port::error::INVALID_ARGUMENT,
2744 "The primary convolution algorithm failed memory allocation, "
2745 "while a secondary algorithm is not provided.");
2746 }
2747
2748 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
2749 stream, cudnn, input_nd, filter, conv,
2750 output_nd, *algo_desc, scratch_allocator));
2751 return *algo_desc;
2752 }
2753
2754 // A helper class that reads env-vars to choose options for cudnn-related
2755 // algorithms.
2756 template <typename EnvVar>
2757 class CudnnEnvVar {
2758 public:
2759   static bool IsEnabled() {
2760 static bool is_enabled = IsEnabledImpl();
2761 return is_enabled;
2762 }
2763
2764 private:
2765   static bool IsEnabledImpl() {
2766 const char* tf_env_var_val = getenv(EnvVar::kName);
2767 if (tf_env_var_val != nullptr) {
2768 absl::string_view tf_env_var_val_str(tf_env_var_val);
2769 if (tf_env_var_val_str == "0") {
2770 return false;
2771 }
2772 return true;
2773 }
2774 return EnvVar::kDefaultFlag;
2775 }
2776 };
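// Usage sketch (illustrative only; "MyAlgoFlag" and "TF_MY_ALGO_FLAG" are
// made-up names). A policy struct only needs to provide kName and
// kDefaultFlag:
//
//   struct MyAlgoFlag {
//     static constexpr const char* kName = "TF_MY_ALGO_FLAG";
//     static constexpr bool kDefaultFlag = true;
//   };
//   bool enabled = CudnnEnvVar<MyAlgoFlag>::IsEnabled();
//
// Setting the env-var to "0" disables the flag, any other value enables it,
// and kDefaultFlag applies when the env-var is unset. The concrete policies
// used in this file follow below.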
2777
2778 // A helper struct to decide whether to enable the FFT_TILING algorithms for
2779 // forward convolution. It is disabled for cuDNN < 7 due to memory corruption
2780 // caused by some shapes with this algorithm. Users can explicitly enable the
2781 // algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
2782 struct FftTilingForward {
2783 static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
2784 static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000;
2785 };
2786
2787 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
2788 // By default they are enabled; users can explicitly disable them through the
2789 // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
2790 // https://github.com/tensorflow/tensorflow/pull/4901
2791 struct WinogradNonfused {
2792 static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
2793   // NVIDIA has fixed the winograd nonfused bug for cudnn v>=7. For older
2794   // versions, we have a workaround.
2795 static constexpr bool kDefaultFlag = true;
2796 };
2797
2798 // A helper struct to decide whether to use FP32 as the internal compute type
2799 // for convolution when the input data type is FP16. By default it is enabled;
2800 // users can explicitly disable it (i.e. choose FP16 as the internal compute
2801 // type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
2802 struct ConvDoFP32ComputationFP16Input {
2803 static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
2804 // Using FP16 as the internal compute type for convolution when the input data
2805 // type is FP16 is only supported on architectures with true fp16 support
2806   // (compute capability 5.3 and 6.0). Setting this to false on an unsupported
2807   // architecture will cause internal errors.
2808 static constexpr bool kDefaultFlag = true;
2809 };
2810
2811 // A helper struct to decide whether to use FP32 as the internal compute type
2812 // for rnn when the input data type is FP16. It defaults to off for cuDNN
2813 // versions older than 7.5 (see kDefaultFlag below); users can explicitly
2814 // control it through the env-var TF_FP16_RNN_USE_FP32_COMPUTE.
2815 // After the TODO below is fixed, users should almost always use the fp32
2816 // compute type for training. Using fp16 may give suboptimal accuracy due to
2817 // loss of precision.
2818 struct RnnDoFP32ComputationFP16Input {
2819 static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
2820 // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
2821   // Before cudnn 7.1.4, RNNs are always run in fp32, no matter what math
2822   // precision is set.
2823   // Keep it false temporarily so that no error is raised when using fp16
2824   // inputs with fp32 math precision.
2825 //
2826 // cuDNN == 7.5.0 is verified to have this fixed.
2827 static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500;
2828 };
2829
2830 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
2831 switch (data_type) {
2832 case dnn::DataType::kFloat:
2833 return CUDNN_DATA_FLOAT;
2834 case dnn::DataType::kDouble:
2835 return CUDNN_DATA_DOUBLE;
2836 case dnn::DataType::kHalf:
2837 if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
2838 return CUDNN_DATA_FLOAT;
2839 } else {
2840 return CUDNN_DATA_HALF;
2841 }
2842 default:
2843 LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
2844 }
2845 }
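// For example (illustrative): with fp16 RNN inputs and
// TF_FP16_RNN_USE_FP32_COMPUTE unset, builds against cuDNN >= 7.5 compute in
// CUDNN_DATA_FLOAT, while older builds fall back to CUDNN_DATA_HALF.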
2846
2847 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
2848 switch (data_type) {
2849 case dnn::DataType::kFloat:
2850 case dnn::DataType::kDouble:
2851 return data_type;
2852 case dnn::DataType::kHalf:
2853 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
2854 ? dnn::DataType::kFloat
2855 : dnn::DataType::kHalf;
2856 case dnn::DataType::kInt8:
2857 case dnn::DataType::kInt32:
2858 return dnn::DataType::kInt32;
2859 default:
2860 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
2861 }
2862 }
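// For example (illustrative): int8 convolutions accumulate in int32, and fp16
// convolutions accumulate in fp32 unless TF_FP16_CONV_USE_FP32_COMPUTE=0 is
// set, in which case they accumulate in fp16.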
2863
2864 // Determines whether we can safely perform a winograd non-fused convolution for
2865 // the given input and output shapes. This works around b/68264959, an integer
2866 // overflow in cuDNNv5 and cuDNNv6.
2867 #if CUDNN_VERSION >= 7000
2868 bool ShouldIncludeWinogradNonfusedAlgo(const dnn::BatchDescriptor&,
2869 const dnn::BatchDescriptor&) {
2870 return true;
2871 }
2872 #else
2873 bool ShouldIncludeWinogradNonfusedAlgo(
2874 const dnn::BatchDescriptor& input_desc,
2875 const dnn::BatchDescriptor& output_desc) {
2876 int64 batch = input_desc.count();
2877 int64 in_depths = input_desc.feature_map_count();
2878 int64 in_rows = input_desc.height();
2879 int64 in_cols = input_desc.ndims() == 1 ? 1 : input_desc.width();
2880 int64 out_depths = output_desc.feature_map_count();
2881
2882 int64 total_size = port::MathUtil::CeilOfRatio(batch, int64{16}) *
2883 std::max(in_depths, out_depths) * in_cols * in_rows *
2884 sizeof(float);
2885
2886 const int64 threshold = 1L << 31;
2887 return total_size < threshold;
2888 }
2889 #endif
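// Illustrative sizing with a hypothetical shape: batch=32, 256 feature maps on
// both sides, and a 128x128 image gives total_size = ceil(32/16) * 256 * 128 *
// 128 * sizeof(float) = 32 MiB, well below the 1 << 31 threshold, so the
// winograd non-fused algorithm would be allowed on cuDNN < 7.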
2890
2891 } // namespace
2892
2893 port::Status CudnnSupport::DoPrepareForConvolution(
2894 dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2895 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2896 const dnn::FilterDescriptor& filter_descriptor,
2897 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2898 DeviceMemoryBase output_data,
2899 const dnn::ConvolutionDescriptor& convolution_descriptor,
2900 const dnn::AlgorithmConfig& algorithm_config,
2901 ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2902 DeviceMemory<uint8>* scratch_memory) {
2903 CudnnTensorDescriptor input_nd(
2904 input_descriptor,
2905 ToCudnnDataType(element_type, input_descriptor.layout()));
2906 CudnnFilterDescriptor filter_nd(
2907 filter_descriptor,
2908 ToCudnnDataType(element_type, filter_descriptor.layout()));
2909 CudnnTensorDescriptor output_nd(
2910 output_descriptor,
2911 ToCudnnDataType(element_type, output_descriptor.layout()));
2912 CudnnConvolutionDescriptor conv(
2913 convolution_descriptor,
2914 ToCudnnDataType(GetConvAccumulatorType(element_type)));
2915
2916 auto cudnn = cudnn_->GetHandle(parent_, stream);
2917
2918 switch (kind) {
2919 case dnn::ConvolutionKind::FORWARD: {
2920 SE_ASSIGN_OR_RETURN(
2921 *algorithm_desc,
2922 GetCudnnConvolutionForwardAlgorithm(
2923 stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2924 output_nd, scratch_allocator, scratch_memory));
2925 break;
2926 }
2927 case dnn::ConvolutionKind::BACKWARD_DATA: {
2928 SE_ASSIGN_OR_RETURN(
2929 *algorithm_desc,
2930 GetCudnnConvolutionBackwardDataAlgorithm(
2931 stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2932 output_nd, scratch_allocator, scratch_memory));
2933 break;
2934 }
2935 case dnn::ConvolutionKind::BACKWARD_FILTER: {
2936 SE_ASSIGN_OR_RETURN(
2937 *algorithm_desc,
2938 GetCudnnConvolutionBackwardFilterAlgorithm(
2939 stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2940 output_nd, scratch_allocator, scratch_memory));
2941 break;
2942 }
2943 default:
2944 return port::InternalError(
2945 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
2946 }
2947
2948 return port::Status::OK();
2949 }
2950
2951 port::Status CudnnSupport::DoConvolve(
2952 dnn::ConvolutionKind kind, dnn::DataType element_type,
2953 dnn::DataType output_type, Stream* stream,
2954 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2955 const dnn::FilterDescriptor& filter_descriptor,
2956 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2957 DeviceMemoryBase output_data,
2958 const dnn::ConvolutionDescriptor& convolution_descriptor,
2959 dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
2960 dnn::ProfileResult* output_profile_result) {
2961 cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
2962 CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
2963 CudnnTensorDescriptor output_nd(output_descriptor,
2964 ToCudnnDataType(output_type));
2965 CudnnFilterDescriptor filter_nd(filter_descriptor, cudnn_type);
2966 auto accumulator_type = GetConvAccumulatorType(element_type);
2967 CudnnConvolutionDescriptor conv(convolution_descriptor,
2968 ToCudnnDataType(accumulator_type));
2969   // Set the use_tensor_op_math flag to match the chosen algorithm.
2970 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2971
2972 auto cudnn = cudnn_->GetHandle(parent_, stream);
2973 // Alpha is the scaling factor for input.
2974 float falpha = 1.0;
2975 double dalpha = 1.0;
2976 void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
2977 : static_cast<void*>(&falpha);
2978 // Beta is the scaling factor for output.
2979 float fbeta = 0.0;
2980 double dbeta = 0.0;
2981 void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
2982 : static_cast<void*>(&fbeta);
2983
2984 const bool is_profiling = output_profile_result != nullptr;
2985
2986 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2987 if (is_profiling) {
2988 timer.reset(new GpuTimer(parent_)); // NOLINT
2989     // The start and stop of the timer should be as close to the cuDNN call as
2990     // possible. Other threads may still issue work onto this stream, so
2991     // multiple profiling measurements may be needed.
2992 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2993 return port::Status(port::error::INTERNAL, "Failed to start timer");
2994 }
2995 }
2996
2997 const auto get_fwd_bugs = [&]() -> port::Status {
2998 // Report an error if we might be hitting a cuDNN bug that accesses illegal
2999 // memory. See nvbugs/2138754, b/80018418.
3000 if (CUDNN_VERSION < 7300) {
3001 if (algorithm_desc.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) {
3002 return port::Status::OK();
3003 }
3004 if (input_descriptor.ndims() < 3) {
3005 return port::Status::OK();
3006 }
3007 // Checks that a*b is within the valid range (as provided by NVIDIA).
3008 const auto check_sizes = [](size_t a, size_t b) {
3009 if ((a * b * 4608 - 1) >> 31 == 0) {
3010 return port::Status::OK();
3011 }
3012 return port::Status(
3013 port::error::FAILED_PRECONDITION,
3014 "This configuration potentially accesses illegal memory.");
3015 };
3016 SE_RETURN_IF_ERROR(check_sizes(input_descriptor.feature_map_count(),
3017 output_descriptor.feature_map_count()));
3018 SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
3019 input_descriptor.feature_map_count()));
3020 SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
3021 output_descriptor.feature_map_count()));
3022 return port::Status::OK();
3023 }
3024 if (algorithm_desc.algo_id() ==
3025 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3026 !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
3027 output_descriptor)) {
3028 return port::Status(
3029 port::error::FAILED_PRECONDITION,
3030 "This configuration has potential integer overflow in "
3031 "cuDNNv5 and cuDNNv6. See b/68264959.");
3032 }
3033 if (CUDNN_VERSION < 8000) {
3034 if (algorithm_desc.algo_id() ==
3035 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
3036 ToCudnnDataType(element_type) == CUDNN_DATA_INT8 &&
3037 ToCudnnDataType(output_type) == CUDNN_DATA_FLOAT) {
3038 return port::Status(
3039 port::error::FAILED_PRECONDITION,
3040 "This configuration potentially produces incorrect results.");
3041 }
3042 }
3043 return port::Status::OK();
3044 };
3045
3046 auto get_bwd_data_bugs = [&]() -> port::Status {
3047 if (algorithm_desc.algo_id() ==
3048 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3049 !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
3050 output_descriptor)) {
3051 return port::Status(
3052 port::error::FAILED_PRECONDITION,
3053 "This configuration has potential integer overflow in "
3054 "cuDNNv5 and cuDNNv6. See b/68264959.");
3055 }
3056
3057     // cuDNN 7.1.4 has a bug if the workspace of the following convolution is
3058     // not zero-initialized; see nvbugs/2254619.
3059 if (CUDNN_VERSION >= 7000 && CUDNN_VERSION < 7300 &&
3060 algorithm_desc.algo_id() == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
3061 cudnn_type == CUDNN_DATA_HALF && algorithm_desc.tensor_ops_enabled() &&
3062 input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
3063 filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX &&
3064 output_descriptor.layout() == dnn::DataLayout::kBatchDepthYX &&
3065 (convolution_descriptor.vertical_filter_stride() > 1 ||
3066 convolution_descriptor.horizontal_filter_stride() > 1)) {
3067 stream->ThenMemZero(&scratch_memory, scratch_memory.size());
3068 }
3069 return port::Status::OK();
3070 };
3071
3072 const auto get_bwd_filter_bugs = [&]() -> port::Status {
3073 // Report an error if we might be hitting a cuDNN bug that produces
3074 // incorrect results. See nvbugs/2072856
3075 if (CUDNN_VERSION < 7300) {
3076 SE_RETURN_IF_ERROR([&] {
3077 if (algorithm_desc.algo_id() !=
3078 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) {
3079 return port::Status::OK();
3080 }
3081 if (output_descriptor.height() > 1 && output_descriptor.width() > 1) {
3082 return port::Status::OK();
3083 }
3084 int convolution_size = output_descriptor.height() > 1
3085 ? filter_descriptor.input_filter_height()
3086 : filter_descriptor.input_filter_width();
3087 if (convolution_size <= 32) {
3088 return port::Status::OK();
3089 }
3090 cudnnConvolutionMode_t convolution_mode;
3091 cudnnDataType_t compute_type;
3092 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdDescriptor(
3093 conv.handle(), 0, nullptr, nullptr, nullptr, nullptr,
3094 &convolution_mode, &compute_type));
3095 if (convolution_mode != CUDNN_CONVOLUTION) {
3096 return port::Status::OK();
3097 }
3098 return port::Status(
3099 port::error::FAILED_PRECONDITION,
3100 "This configuration potentially produces incorrect results.");
3101 }());
3102 }
3103
3104 if (algorithm_desc.algo_id() ==
3105 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3106 !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
3107 output_descriptor)) {
3108 return port::Status(
3109 port::error::FAILED_PRECONDITION,
3110 "This configuration has potential integer overflow in "
3111 "cuDNNv5 and cuDNNv6. See b/68264959.");
3112 }
3113
3114 // Zero out the result buffer for strided conv backward filter for NHWC
3115     // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer
3116     // is not zeroed.
3117     //
3118     // The wrong results caused by this bug are very flaky; it can take up to
3119     // 20 runs to produce a mismatch.
3120 //
3121 // See nvbugs/2379553.
3122 if (CUDNN_VERSION >= 7100 && CUDNN_VERSION < 7300 &&
3123 algorithm_desc.algo_id() == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 &&
3124 cudnn_type == CUDNN_DATA_HALF &&
3125 input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
3126 filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
3127 output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
3128 (convolution_descriptor.vertical_filter_stride() > 1 ||
3129 convolution_descriptor.horizontal_filter_stride() > 1)) {
3130 stream->ThenMemZero(&filter_data, filter_data.size());
3131 }
3132 return port::Status::OK();
3133 };
3134
3135 switch (kind) {
3136 case dnn::ConvolutionKind::FORWARD: {
3137 SE_RETURN_IF_ERROR(get_fwd_bugs());
3138 RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
3139 cudnn.handle(),
3140 /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
3141 /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_nd.handle(),
3142 /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
3143 /*algo=*/ToConvForwardAlgo(algorithm_desc),
3144 /*workSpace=*/scratch_memory.opaque(),
3145 /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
3146 /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque()));
3147 break;
3148 }
3149 case dnn::ConvolutionKind::BACKWARD_DATA: {
3150 SE_RETURN_IF_ERROR(get_bwd_data_bugs());
3151 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
3152 cudnn.handle(),
3153 /*alpha=*/alpha,
3154 /*wDesc=*/filter_nd.handle(),
3155 /*w=*/filter_data.opaque(),
3156 /*dyDesc=*/output_nd.handle(),
3157 /*dy=*/output_data.opaque(),
3158 /*convDesc=*/conv.handle(),
3159 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
3160 /*workSpace=*/scratch_memory.opaque(),
3161 /*workSpaceSizeInBytes=*/scratch_memory.size(),
3162 /*beta=*/beta,
3163 /*dxDesc=*/input_nd.handle(),
3164 /*dx=*/input_data.opaque()));
3165 break;
3166 }
3167 case dnn::ConvolutionKind::BACKWARD_FILTER: {
3168 SE_RETURN_IF_ERROR(get_bwd_filter_bugs());
3169 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
3170 cudnn.handle(),
3171 /*alpha=*/alpha,
3172 /*srcDesc=*/input_nd.handle(),
3173 /*srcData=*/input_data.opaque(),
3174 /*diffDesc=*/output_nd.handle(),
3175 /*diffData=*/output_data.opaque(),
3176 /*convDesc=*/conv.handle(),
3177 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
3178 /*workSpace=*/scratch_memory.opaque(),
3179 /*workSpaceSizeInBytes=*/scratch_memory.size(),
3180 /*beta=*/beta,
3181 /*gradDesc=*/filter_nd.handle(),
3182 /*dw=*/filter_data.opaque()));
3183 break;
3184 }
3185 default:
3186 return port::InternalError(
3187 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
3188 }
3189
3190 if (is_profiling) {
3191 if (!timer->Stop(AsGpuStream(stream))) {
3192 return port::Status(port::error::INTERNAL, "Failed to stop timer");
3193 }
3194 output_profile_result->set_algorithm(algorithm_desc);
3195 output_profile_result->set_elapsed_time_in_ms(
3196 timer->GetElapsedMilliseconds());
3197 output_profile_result->set_scratch_size(scratch_memory.size());
3198 }
3199
3200 return port::Status::OK();
3201 }
3202
3203 // A helper function to query whether a CudnnConvolutionDescriptor has
3204 // tensor_op_math set.
3205 static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
3206 cudnnMathType_t math_type;
3207 CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
3208 return math_type == CUDNN_TENSOR_OP_MATH;
3209 }
3210
3211 template <typename ElementType, typename BiasType, typename ScaleType,
3212 typename OutputType>
3213 port::Status CudnnSupport::DoFusedConvolveImpl(
3214 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3215 const DeviceMemory<ElementType>& conv_input_data,
3216 ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
3217 const DeviceMemory<ElementType>& filter_data,
3218 const dnn::ConvolutionDescriptor& convolution_descriptor,
3219 const DeviceMemory<OutputType>& side_input_data, ScaleType side_input_scale,
3220 const dnn::BatchDescriptor& bias_descriptor,
3221 const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
3222 const dnn::BatchDescriptor& output_descriptor,
3223 DeviceMemory<OutputType>* output_data, dnn::DataType accumulator_type,
3224 ScratchAllocator* scratch_allocator,
3225 const dnn::AlgorithmConfig& algorithm_config,
3226 dnn::ProfileResult* output_profile_result) {
3227 if (activation_mode != dnn::ActivationMode::kRelu &&
3228 activation_mode != dnn::ActivationMode::kNone) {
3229 return port::Status(port::error::INVALID_ARGUMENT,
3230 "cudnnConvolutionBiasActivationForward() only supports "
3231 "Relu or None activation.");
3232 }
3233
3234 CudnnTensorDescriptor conv_input_nd(
3235 conv_input_descriptor,
3236 GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
3237 CudnnTensorDescriptor output_nd(
3238 output_descriptor,
3239 GetCudnnDataType<OutputType>(conv_input_descriptor.layout()));
3240 CudnnFilterDescriptor filter(
3241 filter_descriptor,
3242 GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
3243 CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
3244 CudnnConvolutionDescriptor conv(convolution_descriptor,
3245 ToCudnnDataType(accumulator_type));
3246
3247 auto cudnn = cudnn_->GetHandle(parent_, stream);
3248
3249 const bool is_profiling = output_profile_result != nullptr;
3250
3251 DeviceMemory<uint8> scratch;
3252 SE_ASSIGN_OR_RETURN(
3253 dnn::AlgorithmDesc algo_desc,
3254 GetCudnnConvolutionForwardAlgorithm(
3255 stream, cudnn, algorithm_config, conv_input_nd, filter, conv,
3256 output_nd, scratch_allocator, &scratch));
3257
3258 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3259 if (is_profiling) {
3260 timer.reset(new GpuTimer(parent_)); // NOLINT
3261     // The start and stop of the timer should be as close to the cuDNN call as
3262     // possible. Other threads may still issue work onto this stream, so
3263     // multiple profiling measurements may be needed.
3264 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3265 return port::Status(port::error::INTERNAL, "Failed to start timer");
3266 }
3267 }
3268   // cuDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
3269   // activation descriptor. Note that this changes the NaN propagation behavior
3270   // relative to separate conv, bias, and relu ops (which default to
3271   // CUDNN_PROPAGATE_NAN).
3272 CudnnActivationDescriptor activation_desc(
3273 activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
3274 auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
3275 : side_input_data.opaque();
3276
3277 VLOG(2) << "\nconv_input_scale = " << conv_input_scale
3278 << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
3279 << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
3280 << "\nfilter.handle() = " << filter.handle()
3281 << "\nfilter_data.opaque() = " << filter_data.opaque()
3282 << "\nconv.handle() = " << conv.handle()
3283 << "\nalgo = " << algo_desc.algo_id()
3284 << "\nscratch.opaque() = " << scratch.opaque()
3285 << "\nscratch.size() = " << scratch.size()
3286 << "\nside_input_scale = " << side_input_scale
3287 << "\noutput_nd.handle() = " << output_nd.handle()
3288 << "\nside_input_data_ptr = " << side_input_data_ptr
3289 << "\nbias_nd.handle() = " << bias_nd.handle()
3290 << "\nbiases.opaque() = " << biases.opaque()
3291 << "\nactivation_desc.handle() = " << activation_desc.handle()
3292 << "\noutput_nd.handle() = " << output_nd.handle()
3293 << "\noutput_data->opaque() = " << output_data->opaque();
3294
3295 if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3296 !ShouldIncludeWinogradNonfusedAlgo(conv_input_descriptor,
3297 output_descriptor)) {
3298 return port::Status(port::error::FAILED_PRECONDITION,
3299 "This configuration has potential integer overflow in "
3300 "cuDNNv5 and cuDNNv6. See around b/68264959.");
3301 }
3302 if (IsTensorMathOpSet(conv) != algo_desc.tensor_ops_enabled()) {
3303 return port::Status(port::error::FAILED_PRECONDITION,
3304 "Tensor op math type in dnn::AlgorithmDesc does not "
3305 "match that of the CudnnConvolutionDescriptor");
3306 }
3307
3308 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
3309 cudnn.handle(),
3310 /*alpha1=*/&conv_input_scale,
3311 /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
3312 /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
3313 /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
3314 /*workSpace=*/scratch.opaque(),
3315 /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
3316 /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
3317 /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
3318 /*activationDesc=*/activation_desc.handle(),
3319 /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
3320
3321 if (is_profiling) {
3322 if (!timer->Stop(AsGpuStream(stream))) {
3323 return port::Status(port::error::INTERNAL, "Failed to stop timer");
3324 }
3325 output_profile_result->set_algorithm(algo_desc);
3326 output_profile_result->set_elapsed_time_in_ms(
3327 timer->GetElapsedMilliseconds());
3328 output_profile_result->set_scratch_size(scratch.size());
3329 }
3330
3331 return port::Status::OK();
3332 }
3333
3334 bool CudnnSupport::GetConvolveAlgorithms(
3335 bool with_winograd_nonfused, int cc_major, int cc_minor,
3336 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3337 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3338 out_algorithms->clear();
3339
3340 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3341 // clang-format off
3342 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
3343 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
3344 CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
3345 CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
3346 CUDNN_CONVOLUTION_FWD_ALGO_FFT,
3347 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
3348 // clang-format on
3349 };
3350 if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
3351 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
3352 }
3353 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3354 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
3355 }
3356
3357 // The algorithms are intentionally ordered for deterministic operation
3358 for (auto i : algo_types) {
3359 if (tensor_op_math_available) {
3360 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3361 }
3362 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3363 }
3364
3365 return true;
3366 }
3367
3368 bool CudnnSupport::GetRnnAlgorithms(
3369 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3370 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3371 // clang-format off
3372 CUDNN_RNN_ALGO_STANDARD,
3373 CUDNN_RNN_ALGO_PERSIST_STATIC,
3374 CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
3375 // clang-format on
3376 };
3377
3378 out_algorithms->clear();
3379 for (auto i : algo_types) {
3380 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3381 #if CUDNN_VERSION >= 7100
3382 if (RnnTensorOpMathEnabled()) {
3383 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3384 }
3385 #endif
3386 }
3387 return true;
3388 }
3389
3390 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
3391 bool with_winograd_nonfused, int cc_major, int cc_minor,
3392 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3393 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3394 out_algorithms->clear();
3395
3396 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3397 // clang-format off
3398 CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
3399 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
3400 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
3401 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
3402 // clang-format on
3403 };
3404 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3405 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
3406 }
3407 if (!RequireCudnnDeterminism()) {
3408 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
3409 }
3410
3411 // The algorithms are intentionally ordered for deterministic operation
3412 for (auto i : algo_types) {
3413 if (tensor_op_math_available) {
3414 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3415 }
3416 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3417 }
3418
3419 return true;
3420 }
3421
3422 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
3423 bool with_winograd_nonfused, int cc_major, int cc_minor,
3424 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3425 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3426 out_algorithms->clear();
3427
3428 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3429 // clang-format off
3430 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
3431 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
3432 // Based on cudnn.h, the following is not implemented.
3433 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
3434
3435 // Produces incorrect results for some shapes. Disabled for now, see
3436 // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
3437 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
3438 // clang-format on
3439 };
3440 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3441 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
3442 }
3443 if (!RequireCudnnDeterminism()) {
3444 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
3445 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
3446 }
3447
3448 // The algorithms are intentionally ordered for deterministic operation
3449 for (auto i : algo_types) {
3450 if (tensor_op_math_available) {
3451 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3452 }
3453 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3454 }
3455
3456 return true;
3457 }
3458
3459 bool CudnnSupport::DoBatchNormalizationForward(
3460 Stream* stream, const DeviceMemory<float>& x,
3461 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3462 const DeviceMemory<float>& estimated_mean,
3463 const DeviceMemory<float>& estimated_variance,
3464 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3465 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3466 const double exponential_average_factor,
3467 dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
3468 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3469 DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3470 bool is_training, ScratchAllocator* reserve_space_allocator,
3471 ScratchAllocator* workspace_allocator,
3472 std::function<const DeviceMemory<float>&()> var_to_inv_var,
3473 std::function<void()> inv_var_to_var) {
3474 return IsStatusOk(
3475 DoBatchNormalizationForwardImpl<float, float>(
3476 stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
3477 offset, estimated_mean, estimated_variance, side_input, x_desc,
3478 scale_offset_desc, epsilon, exponential_average_factor,
3479 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3480 is_training, reserve_space_allocator, workspace_allocator,
3481 std::move(var_to_inv_var), std::move(inv_var_to_var)),
3482 /*report_error=*/true);
3483 }
3484
3485 bool CudnnSupport::DoBatchNormalizationForward(
3486 Stream* stream, const DeviceMemory<Eigen::half>& x,
3487 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3488 const DeviceMemory<float>& estimated_mean,
3489 const DeviceMemory<float>& estimated_variance,
3490 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3491 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3492 const double exponential_average_factor,
3493 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
3494 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3495 DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3496 bool is_training, ScratchAllocator* reserve_space_allocator,
3497 ScratchAllocator* workspace_allocator,
3498 std::function<const DeviceMemory<float>&()> var_to_inv_var,
3499 std::function<void()> inv_var_to_var) {
3500 return IsStatusOk(
3501 DoBatchNormalizationForwardImpl<Eigen::half, float>(
3502 stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3503 estimated_mean, estimated_variance, side_input, x_desc,
3504 scale_offset_desc, epsilon, exponential_average_factor,
3505 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3506 is_training, reserve_space_allocator, workspace_allocator,
3507 std::move(var_to_inv_var), std::move(inv_var_to_var)),
3508 /*report_error=*/true);
3509 }
3510
3511 template <class T, class U>
3512 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
3513 Stream* stream, dnn::DataType input_data_type,
3514 dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3515 const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3516 const DeviceMemory<U>& estimated_mean,
3517 const DeviceMemory<U>& estimated_variance,
3518 const DeviceMemory<U>& side_input, const dnn::BatchDescriptor& x_desc,
3519 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3520 const double exponential_average_factor,
3521 dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
3522 DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3523 DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3524 bool is_training, ScratchAllocator* reserve_space_allocator,
3525 ScratchAllocator* workspace_allocator,
3526 std::function<const DeviceMemory<U>&()> var_to_inv_var,
3527 std::function<void()> inv_var_to_var) {
3528 CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
3529 CudnnTensorDescriptor scale_offset_descriptor(
3530 scale_offset_desc, ToCudnnDataType(scale_data_type));
3531 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3532 #if CUDNN_VERSION >= 7000
3533 if (BatchnormSpatialPersistentEnabled() && is_training) {
3534 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3535 }
3536 #endif
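  // Note on the mode selected above: CUDNN_BATCHNORM_SPATIAL_PERSISTENT is an
  // opt-in variant (gated by the BatchnormSpatialPersistentEnabled()
  // environment-variable check) that is only chosen here for training. The
  // cuDNN documentation describes it as potentially faster but with extra
  // numerical caveats, which is presumably why it is not the default.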
3537 float one = 1.0;
3538 float zero = 0.0;
3539 auto cudnn = cudnn_->GetHandle(parent_, stream);
3540
3541 DeviceMemory<uint8> workspace;
3542 DeviceMemory<uint8> reserve_space;
3543
3544 #if CUDNN_VERSION >= 7402
3545 const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t {
3546 if (side_input.is_null()) {
3547 return activation_mode == dnn::ActivationMode::kNone
3548 ? CUDNN_BATCHNORM_OPS_BN
3549 : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
3550 } else {
3551 return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
3552 }
3553 };
3554 const cudnnBatchNormOps_t bn_ops = get_bn_ops();
3555
3556   // We use NaN propagation to be consistent with CudnnSupport::DoActivate(...).
3557 CudnnActivationDescriptor activation_desc(
3558 activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
3559
3560 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
3561 SE_ASSIGN_OR_RETURN(
3562 workspace,
3563 CreateBatchNormForwardWorkspace(
3564 stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
3565 scale_offset_descriptor, workspace_allocator))
3566 if (is_training) {
3567 size_t reserve_space_size_in_bytes = 0;
3568 RETURN_IF_CUDNN_ERROR(
3569 cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
3570 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
3571 /*activationDesc=*/activation_desc.handle(),
3572 /*xDesc=*/x_descriptor.handle(),
3573 /*sizeInBytes=*/&reserve_space_size_in_bytes));
3574 SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
3575 reserve_space_size_in_bytes));
3576 }
3577 }
3578 #endif
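  // Sketch of the cuDNN >= 7.4.2 "Ex" flow used above, assuming both
  // allocators were provided: the workspace size comes from
  // CreateBatchNormForwardWorkspace(), the reserve-space size from
  // cudnnGetBatchNormalizationTrainingExReserveSpaceSize(), and both buffers
  // are handed to cudnnBatchNormalizationForwardTrainingEx() below. The
  // reserve space is expected to be passed back, unchanged, to the matching
  // cudnnBatchNormalizationBackwardEx() call.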
3579
3580 auto check_no_side_input_or_activation = [&]() -> port::Status {
3581 if (activation_mode != dnn::ActivationMode::kNone ||
3582 !side_input.is_null()) {
3583 return port::Status(
3584 port::error::INTERNAL,
3585 absl::StrCat(
3586 "Side input and activation are not supported by cuDNN version: ",
3587 CUDNN_VERSION));
3588 } else {
3589 return port::Status::OK();
3590 }
3591 };
3592
3593 if (is_training) {
3594 CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
3595 << "batch_mean and batch_var must both be null or both be non-null";
3596
3597 void* batch_mean_opaque;
3598 void* batch_var_opaque;
3599 if (!batch_mean->is_null() && !batch_var->is_null()) {
3600 stream->ThenMemZero(batch_mean, batch_mean->size());
3601 stream->ThenMemZero(batch_var, batch_var->size());
3602 batch_mean_opaque = batch_mean->opaque();
3603 batch_var_opaque = batch_var->opaque();
3604 } else {
3605 batch_mean_opaque = nullptr;
3606 batch_var_opaque = nullptr;
3607 }
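    // Both opaque pointers are forwarded to cuDNN as-is; when the caller did
    // not request running statistics they stay null and no running mean or
    // variance is written back.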
3608
3609 bool called = false;
3610 #if CUDNN_VERSION >= 7402
3611 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
3612 called = true;
3613 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx(
3614 /*handle=*/cudnn.handle(),
3615 /*mode=*/mode,
3616 /*bnOps=*/bn_ops,
3617 /*alpha=*/&one,
3618 /*beta=*/&zero,
3619 /*xDesc=*/x_descriptor.handle(),
3620 /*xData=*/x.opaque(),
3621 /*zDesc=*/x_descriptor.handle(),
3622 /*zData=*/side_input.opaque(),
3623 /*yDesc=*/x_descriptor.handle(),
3624 /*yData=*/y->opaque(),
3625 /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
3626 /*bnScale=*/scale.opaque(),
3627 /*bnBias=*/offset.opaque(),
3628 /*exponentialAverageFactor=*/exponential_average_factor,
3629 /*resultRunningMean=*/batch_mean_opaque,
3630 /*resultRunningVariance=*/batch_var_opaque,
3631 /*epsilon=*/epsilon,
3632 /*resultSaveMean=*/saved_mean->opaque(),
3633 /*resultSaveInvVariance=*/saved_inv_var->opaque(),
3634 /*activationDesc=*/activation_desc.handle(),
3635 /*workspace=*/workspace.opaque(),
3636 /*workSpaceSizeInBytes=*/workspace.size(),
3637 /*reserveSpace=*/reserve_space.opaque(),
3638 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
3639 }
3640 #endif
3641 if (!called) {
3642 SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
3643 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
3644 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3645 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3646 scale.opaque(), offset.opaque(), exponential_average_factor,
3647 batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
3648 saved_inv_var->opaque()));
3649 }
3650 } else {
3651 const void* maybe_inv_var = estimated_variance.opaque();
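    // Despite its name, this is the running variance passed straight through:
    // cudnnBatchNormalizationForwardInference() expects the estimated
    // variance, not its inverse. The name appears to be a leftover from the
    // var_to_inv_var plumbing in this function's signature.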
3652 SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
3653 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
3654 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3655 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3656 scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
3657 epsilon));
3658 }
3659 return port::Status::OK();
3660 }
3661
3662 bool CudnnSupport::DoBatchNormalizationBackward(
3663 Stream* stream, const DeviceMemory<float>& y_backprop,
3664 const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3665 const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3666 const dnn::BatchDescriptor& x_desc,
3667 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3668 DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3669 DeviceMemory<float>* offset_backprop,
3670 DeviceMemory<uint8>* reserve_space_data,
3671 ScratchAllocator* workspace_allocator) {
3672 return IsStatusOk(DoBatchNormalizationBackwardImpl(
3673 stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
3674 x, scale, mean, inv_var, x_desc, scale_offset_desc,
3675 epsilon, x_backprop, scale_backprop, offset_backprop,
3676 reserve_space_data, workspace_allocator),
3677 /*report_error=*/true);
3678 }
3679
3680 bool CudnnSupport::DoBatchNormalizationBackward(
3681 Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3682 const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3683 const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3684 const dnn::BatchDescriptor& x_desc,
3685 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3686 DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3687 DeviceMemory<float>* offset_backprop,
3688 DeviceMemory<uint8>* reserve_space_data,
3689 ScratchAllocator* workspace_allocator) {
3690 return IsStatusOk(DoBatchNormalizationBackwardImpl(
3691 stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
3692 x, scale, mean, inv_var, x_desc, scale_offset_desc,
3693 epsilon, x_backprop, scale_backprop, offset_backprop,
3694 reserve_space_data, workspace_allocator),
3695 /*report_error=*/true);
3696 }
3697
3698 template <class T, class U>
3699 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
3700 Stream* stream, int cudnn_input_type, int cudnn_scale_type,
3701 const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3702 const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3703 const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
3704 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3705 DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3706 DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
3707 ScratchAllocator* workspace_allocator) {
3708 CudnnTensorDescriptor x_descriptor(
3709 x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
3710 CudnnTensorDescriptor scale_offset_descriptor(
3711 scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
3712 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3713 #if CUDNN_VERSION >= 7000
3714 if (BatchnormSpatialPersistentEnabled()) {
3715 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3716 }
3717 #endif
3718 float one = 1.0;
3719 float zero = 0.0;
3720
3721 auto cudnn = cudnn_->GetHandle(parent_, stream);
3722
3723 bool called = false;
3724 #if CUDNN_VERSION >= 7402
3725 if (reserve_space_data != nullptr && workspace_allocator != nullptr) {
3726 called = true;
3727 const cudnnBatchNormOps_t bn_ops = CUDNN_BATCHNORM_OPS_BN;
3728 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
3729 CreateBatchNormBackwardWorkspace(
3730 stream, cudnn, mode, bn_ops, x_descriptor,
3731 scale_offset_descriptor, workspace_allocator))
3732 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx(
3733 /*handle=*/cudnn.handle(),
3734 /*mode=*/mode,
3735 /*bnOps=*/bn_ops,
3736 /*alphaDataDiff=*/&one,
3737 /*betaDataDiff=*/&zero,
3738 /*alphaParamDiff=*/&one,
3739 /*betaParamDiff=*/&zero,
3740 /*xDesc=*/x_descriptor.handle(),
3741 /*xData=*/x.opaque(),
3742 /*yDesc=*/nullptr,
3743 /*yData=*/nullptr,
3744 /*dyDesc=*/x_descriptor.handle(),
3745 /*dyData=*/y_backprop.opaque(),
3746 /*dzDesc=*/nullptr,
3747 /*dzData=*/nullptr,
3748 /*dxDesc=*/x_descriptor.handle(),
3749 /*dxData=*/x_backprop->opaque(),
3750 /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
3751 /*bnScaleData=*/scale.opaque(),
3752 /*bnBiasData=*/nullptr,
3753 /*dBnScaleData=*/scale_backprop->opaque(),
3754 /*dBnBiasData=*/offset_backprop->opaque(),
3755 /*epsilon=*/epsilon,
3756 /*savedMean=*/mean.opaque(),
3757 /*savedInvVariance=*/inv_var.opaque(),
3758 /*activationDesc=*/nullptr,
3759 /*workspace=*/workspace.opaque(),
3760 /*workSpaceSizeInBytes=*/workspace.size(),
3761 /*reserveSpace=*/reserve_space_data->opaque(),
3762 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
3763 }
3764 #endif
3765 if (!called) {
3766 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
3767 cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3768 x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3769 x_descriptor.handle(), x_backprop->opaque(),
3770 scale_offset_descriptor.handle(), scale.opaque(),
3771 scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3772 mean.opaque(), inv_var.opaque()));
3773 }
3774
3775 return port::Status::OK();
3776 }
3777
3778 bool CudnnSupport::DoFusedConvolve(
3779 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3780 const DeviceMemory<double>& conv_input_data, double conv_input_scale,
3781 const dnn::FilterDescriptor& filter_descriptor,
3782 const DeviceMemory<double>& filter_data,
3783 const dnn::ConvolutionDescriptor& convolution_descriptor,
3784 const DeviceMemory<double>& side_input_data, double side_input_scale,
3785 const dnn::BatchDescriptor& bias_descriptor,
3786 const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
3787 const dnn::BatchDescriptor& output_descriptor,
3788 DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
3789 const dnn::AlgorithmConfig& algorithm_config,
3790 dnn::ProfileResult* output_profile_result) {
3791 return IsStatusOk(
3792 DoFusedConvolveImpl(
3793 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3794 filter_descriptor, filter_data, convolution_descriptor,
3795 side_input_data, side_input_scale, bias_descriptor, biases,
3796 activation_mode, output_descriptor, output_data,
3797 GetConvAccumulatorType(dnn::DataType::kDouble), scratch_allocator,
3798 algorithm_config, output_profile_result),
3799 /*report_error=*/!output_profile_result);
3800 }
3801
3802 bool CudnnSupport::DoFusedConvolve(
3803 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3804 const DeviceMemory<float>& conv_input_data, float conv_input_scale,
3805 const dnn::FilterDescriptor& filter_descriptor,
3806 const DeviceMemory<float>& filter_data,
3807 const dnn::ConvolutionDescriptor& convolution_descriptor,
3808 const DeviceMemory<float>& side_input_data, float side_input_scale,
3809 const dnn::BatchDescriptor& bias_descriptor,
3810 const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3811 const dnn::BatchDescriptor& output_descriptor,
3812 DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3813 const dnn::AlgorithmConfig& algorithm_config,
3814 dnn::ProfileResult* output_profile_result) {
3815 return IsStatusOk(
3816 DoFusedConvolveImpl(
3817 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3818 filter_descriptor, filter_data, convolution_descriptor,
3819 side_input_data, side_input_scale, bias_descriptor, biases,
3820 activation_mode, output_descriptor, output_data,
3821 GetConvAccumulatorType(dnn::DataType::kFloat), scratch_allocator,
3822 algorithm_config, output_profile_result),
3823 /*report_error=*/!output_profile_result);
3824 }
3825
3826 bool CudnnSupport::DoFusedConvolve(
3827 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3828 const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
3829 const dnn::FilterDescriptor& filter_descriptor,
3830 const DeviceMemory<Eigen::half>& filter_data,
3831 const dnn::ConvolutionDescriptor& convolution_descriptor,
3832 const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
3833 const dnn::BatchDescriptor& bias_descriptor,
3834 const DeviceMemory<Eigen::half>& biases,
3835 dnn::ActivationMode activation_mode,
3836 const dnn::BatchDescriptor& output_descriptor,
3837 DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3838 const dnn::AlgorithmConfig& algorithm_config,
3839 dnn::ProfileResult* output_profile_result) {
3840 return IsStatusOk(
3841 DoFusedConvolveImpl(
3842 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3843 filter_descriptor, filter_data, convolution_descriptor,
3844 side_input_data, side_input_scale, bias_descriptor, biases,
3845 activation_mode, output_descriptor, output_data,
3846 GetConvAccumulatorType(dnn::DataType::kHalf), scratch_allocator,
3847 algorithm_config, output_profile_result),
3848 /*report_error=*/!output_profile_result);
3849 }
3850
3851 bool CudnnSupport::DoFusedConvolve(
3852 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3853 const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3854 const dnn::FilterDescriptor& filter_descriptor,
3855 const DeviceMemory<int8>& filter_data,
3856 const dnn::ConvolutionDescriptor& convolution_descriptor,
3857 const DeviceMemory<int8>& side_input_data, float side_input_scale,
3858 const dnn::BatchDescriptor& bias_descriptor,
3859 const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3860 const dnn::BatchDescriptor& output_descriptor,
3861 DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
3862 const dnn::AlgorithmConfig& algorithm_config,
3863 dnn::ProfileResult* output_profile_result) {
3864 int cc_major, cc_minor;
3865 std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
3866
3867 if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3868 LOG(WARNING) << "cudnnConvolutionBiasActivationForward() for int8 is only "
3869 "supported on GPUs with compute capability 6.1 or later.";
3870 return false;
3871 }
3872
3873 return IsStatusOk(
3874 DoFusedConvolveImpl(
3875 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3876 filter_descriptor, filter_data, convolution_descriptor,
3877 side_input_data, side_input_scale, bias_descriptor, biases,
3878 activation_mode, output_descriptor, output_data,
3879 GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3880 algorithm_config, output_profile_result),
3881 /*report_error=*/!output_profile_result);
3882 }
3883
3884 bool CudnnSupport::DoFusedConvolve(
3885 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3886 const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3887 const dnn::FilterDescriptor& filter_descriptor,
3888 const DeviceMemory<int8>& filter_data,
3889 const dnn::ConvolutionDescriptor& convolution_descriptor,
3890 const DeviceMemory<float>& side_input_data, float side_input_scale,
3891 const dnn::BatchDescriptor& bias_descriptor,
3892 const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3893 const dnn::BatchDescriptor& output_descriptor,
3894 DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3895 const dnn::AlgorithmConfig& algorithm_config,
3896 dnn::ProfileResult* output_profile_result) {
3897 int cc_major, cc_minor;
3898 stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
3899 &cc_minor);
3900 if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3901 LOG(WARNING) << "cudnnConvolutionBiasActivationForward() for int8 is only "
3902 "supported on GPUs with compute capability 6.1 or later.";
3903 return false;
3904 }
3905
3906 return IsStatusOk(
3907 DoFusedConvolveImpl(
3908 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3909 filter_descriptor, filter_data, convolution_descriptor,
3910 side_input_data, side_input_scale, bias_descriptor, biases,
3911 activation_mode, output_descriptor, output_data,
3912 GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3913 algorithm_config, output_profile_result),
3914 /*report_error=*/!output_profile_result);
3915 }
3916
3917 port::Status CudnnSupport::DoPrepareForCtcLoss(
3918 Stream* stream, dnn::DataType element_type,
3919 const dnn::RnnStateTensorDescriptor& probs_desc,
3920 const dnn::RnnStateTensorDescriptor& grads_desc,
3921 absl::Span<const int> labels_data,
3922 absl::Span<const int> labels_lengths_data,
3923 absl::Span<const int> input_lengths_data,
3924 ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
3925 auto cudnn = cudnn_->GetHandle(parent_, stream);
3926 // Query the workspace size.
3927 size_t workspace_size_in_bytes = 0;
3928 #if CUDNN_VERSION >= 7603
3929 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
3930 const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
3931 static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
3932 const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
3933 static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
3934 RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
3935 /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
3936 /*gradientsDesc=*/cudnn_grads_desc.handle(),
3937 /*labels=*/labels_data.data(),
3938 /*labelLengths=*/labels_lengths_data.data(),
3939 /*inputLengths=*/input_lengths_data.data(),
3940 /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
3941 /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
3942 /*sizeInBytes=*/&workspace_size_in_bytes));
3943 #else
3944 return port::Status(port::error::INVALID_ARGUMENT,
3945                       "cudnnGetCTCLossWorkspaceSize is not supported when "
3946                       "CUDNN_VERSION < 7.6.3");
3947 #endif
3948 // Allocate the workspace.
3949 if (workspace_size_in_bytes == 0) {
3950 *scratch_memory = DeviceMemory<uint8>();
3951 return port::Status::OK();
3952 }
3953 const auto scratch_or =
3954 scratch_allocator->AllocateBytes(workspace_size_in_bytes);
3955 if (scratch_or.ok()) {
3956 *scratch_memory = scratch_or.ValueOrDie();
3957 return port::Status::OK();
3958 }
3959 return port::InternalError(
3960 "Failed to allocate scratch memory for the CuDNN CTC Loss");
3961 }
3962
3963 port::Status CudnnSupport::DoCtcLoss(
3964 Stream* stream, dnn::DataType element_type,
3965 const dnn::RnnStateTensorDescriptor& probs_desc,
3966 const DeviceMemoryBase probs_data,
3967
3968 absl::Span<const int> labels_data,
3969 absl::Span<const int> labels_lengths_data,
3970 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
3971 const dnn::RnnStateTensorDescriptor& grads_desc,
3972 DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
3973   // The current cuDNN CTC loss only supports the float data type.
3974 if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
3975 return port::Status(port::error::INVALID_ARGUMENT,
3976                         "CudnnCtcLossDescriptor is supported only when "
3977                         "CUDNN_VERSION >= 7.6.3 and the data type is float");
3978 }
3979 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
3980 const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
3981 static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
3982 const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
3983 static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
3984 return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
3985 labels_lengths_data, input_lengths_data, costs_data,
3986 cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
3987 scratch_memory);
3988 }
3989
3990 bool CudnnSupport::DoTransformTensor(Stream* stream,
3991 const dnn::BatchDescriptor& input_desc,
3992 dnn::DataType input_type,
3993 const DeviceMemoryBase& input_data,
3994 const dnn::BatchDescriptor& output_desc,
3995 dnn::DataType output_type, float scale,
3996 DeviceMemoryBase* output_data) {
3997 float beta = 0.0f;
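  // 'scale' and 'beta' act as the usual cuDNN alpha/beta blend factors, i.e.
  // output = scale * transform(input) + beta * output, so beta == 0 simply
  // overwrites the destination tensor.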
3998 CudnnTensorDescriptor input_tensor_desc(
3999 input_desc, ToCudnnDataType(input_type, input_desc.layout()));
4000 CudnnTensorDescriptor output_tensor_desc(
4001 output_desc, ToCudnnDataType(output_type, output_desc.layout()));
4002 auto cudnn = cudnn_->GetHandle(parent_, stream);
4003 const auto status = [&] {
4004 RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
4005 cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
4006 &beta, output_tensor_desc.handle(), output_data->opaque()));
4007 return port::Status::OK();
4008 }();
4009 return IsStatusOk(status, /*report_error=*/true);
4010 }
4011
4012 template <class T>
4013 port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
4014 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4015 const DeviceMemory<T>& input_data,
4016 const dnn::BatchDescriptor& bias_descriptor,
4017 DeviceMemory<T>* backward_bias_data) {
4018 cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
4019 CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
4020 CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
4021
4022 // Alpha is the scaling factor for input.
4023 float alpha = 1.0;
4024 // Beta is the scaling factor for output.
4025 float beta = 0.0;
4026
4027 auto cudnn = cudnn_->GetHandle(parent_, stream);
4028 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias(
4029 cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
4030 bias_nd.handle(), backward_bias_data->opaque()));
4031 return port::Status::OK();
4032 }
4033
4034 bool CudnnSupport::DoConvolveBackwardBias(
4035 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4036 const DeviceMemory<double>& input_data,
4037 const dnn::BatchDescriptor& bias_descriptor,
4038 DeviceMemory<double>* backward_bias_data) {
4039 return IsStatusOk(
4040 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4041 bias_descriptor, backward_bias_data),
4042 /*report_error=*/true);
4043 }
4044
4045 bool CudnnSupport::DoConvolveBackwardBias(
4046 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4047 const DeviceMemory<float>& input_data,
4048 const dnn::BatchDescriptor& bias_descriptor,
4049 DeviceMemory<float>* backward_bias_data) {
4050 return IsStatusOk(
4051 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4052 bias_descriptor, backward_bias_data),
4053 /*report_error=*/true);
4054 }
4055
4056 bool CudnnSupport::DoConvolveBackwardBias(
4057 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4058 const DeviceMemory<Eigen::half>& input_data,
4059 const dnn::BatchDescriptor& bias_descriptor,
4060 DeviceMemory<Eigen::half>* backward_bias_data) {
4061 return IsStatusOk(
4062 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4063 bias_descriptor, backward_bias_data),
4064 /*report_error=*/true);
4065 }
4066
4067 bool CudnnSupport::DoMatMul(Stream* stream,
4068 const DeviceMemory<float>& input_data,
4069 const DeviceMemory<float>& weights,
4070 const dnn::BatchDescriptor& input_dimensions,
4071 const dnn::BatchDescriptor& output_dimensions,
4072 DeviceMemory<float>* output_data) {
4073 if (input_dimensions.count() != output_dimensions.count()) {
4074 LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
4075 return false;
4076 }
4077
4078   // We do not permute the input or output; instead we just
4079 // reinterpret the layout. We are working with row-major matrices
4080 // and the rows of the input and output correspond to batch, so
4081 // batch has to be outermost in both the input and output.
4082 //
4083 // By adding transposes to the BLAS gemm call we could perhaps make
4084 // the kYXDepthBatch layout work as well, but there has been no need
4085 // for that so far.
4086 if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4087 input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4088 LOG(ERROR) << "Unsupported MatMul input layout.";
4089 return false;
4090 }
4091 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4092 output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4093 LOG(ERROR) << "Unsupported MatMul output layout.";
4094 return false;
4095 }
4096
4097 if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
4098 // This is a fast path that also supports the kBatchYXDepth layout.
4099
4100 // The matrices here are in row-major format while BLAS expects
4101 // column-major, i.e. our matrices are transposed as far as BLAS
4102 // is concerned. So we need to compute output^T =
4103 // input^T*weights^T. There is no parameter for transposing the
4104 // output in BLAS gemm, but instead we can transpose both sides of
4105 // the equality to see that this is equivalent to
4106 // output=weights*input. So we only need to swap the order of
4107 // weights and input in the matrix product to correct for the
4108 // row-major versus column-major difference.
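    // Illustrative example (values not from the original code): with a batch
    // of n = 32 inputs, k = 512 input nodes and m = 128 output nodes, the
    // row-major buffers are input[32 x 512], weights[512 x 128] and
    // output[32 x 128]. Reinterpreted as column-major they become input^T,
    // weights^T and output^T, so the single GEMM below computes the
    // (m x n) = (m x k) * (k x n) product output^T = weights^T * input^T,
    // which is exactly output = input * weights in row-major terms.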
4109 const float alpha = 1.0f; // Take the matrix product without scaling it.
4110 const float beta = 0.0f; // Ignore the original values in output_data.
4111 const int64 m = output_dimensions.NodesAcrossFeatureMaps();
4112 const int64 n = input_dimensions.count();
4113 const int64 k = input_dimensions.NodesAcrossFeatureMaps();
4114 stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
4115 blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
4116 m, input_data, k, beta, output_data, m);
4117 } else {
4118 // This is a slower and more complex path that supports output
4119 // width() * height() > 1, though it only supports the
4120     // kBatchYXDepth layout. It does support kBatchDepthYX when the
4121     // output feature_map_count() == 1, as then there is no difference
4122     // between the two layouts.
4123 //
4124 // The operation here is the same as above, except that we have to
4125 // do the matrix multiplication for each (y,x) output coordinate
4126 // separately. We then interpret weights as containing K = width()
4127 // * height() different matrices, which we all multiply onto the
4128 // matrix from input_data, yielding K matrix products. We then
4129 // combine these together into one matrix by concatenating all the
4130     // first rows of these matrices, then all the second rows and so
4131 // on. We can do this with a batched matrix multiplication, where
4132 // the result is written to a different submatrix of the output
4133 // for each matrix multiplication.
4134 //
4135 // The reason that we only support the kBatchYXDepth output layout
4136 // is that we have to do something in the depth for each (y,x)
4137 // coordinate. The kBatchYXDepth layout has the depth information
4138 // for each point (y,x) in contiguous memory while the
4139 // kBatchDepthYX layout does not.
4140 //
4141 // TODO(broune): Consider a special case for when output depth ==
4142 // 1, as then possibly this could all be done as one matrix
4143 // multiplication instead of a batched one, which should be
4144 // faster. Another possibility would be to add a weights layout
4145 // parameter and then support kBatchDepthYX for a different
4146 // weights layout.
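    // Illustrative example (dimensions chosen for exposition only): for a
    // 2x2 output with feature_map_count() == m, 'weights' is treated as
    // batch_count = 4 separate (k x m) matrices. Iteration i multiplies the
    // i-th weights matrix by the shared input matrix and writes the result
    // into the depth block that starts at offset i * m within each output
    // row; using ldc == NodesAcrossFeatureMaps() below is what makes those
    // blocks land in the right place for the kBatchYXDepth layout.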
4147 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4148 !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
4149 output_dimensions.feature_map_count() == 1)) {
4150 LOG(ERROR) << "Unsupported MatMul output layout.";
4151 return false;
4152 }
4153
4154 const float alpha = 1.0f; // Take the matrix product without scaling it.
4155 const float beta = 0.0f; // Ignore the original values in output_data.
4156 const uint64 m = output_dimensions.feature_map_count();
4157 const uint64 n = input_dimensions.count();
4158 const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
4159 const int lda = m;
4160 const int ldb = k;
4161 const int ldc = output_dimensions.NodesAcrossFeatureMaps();
4162 const int batch_count = output_dimensions.NodesPerFeatureMap();
4163
4164 std::vector<DeviceMemory<float>> a(batch_count);
4165 std::vector<DeviceMemory<float>> b(batch_count);
4166 std::vector<DeviceMemory<float>> c(batch_count);
4167 for (int i = 0; i < batch_count; ++i) {
4168 const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
4169 output_dimensions.feature_map_count();
4170 a[i] = DeviceMemory<float>::MakeFromByteSize(
4171 const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
4172 weights_offset,
4173 weights.ElementCount() - weights_offset);
4174
4175 b[i] = input_data;
4176
4177 const int output_offset = i * output_dimensions.feature_map_count();
4178 c[i] = DeviceMemory<float>::MakeFromByteSize(
4179 const_cast<float*>(
4180 reinterpret_cast<const float*>(output_data->opaque())) +
4181 output_offset,
4182 output_data->ElementCount() - output_offset);
4183 }
4184 const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
4185 std::vector<DeviceMemory<float>*> ptrs;
4186 ptrs.reserve(v.size());
4187 for (auto& mem : v) {
4188 ptrs.push_back(&mem);
4189 }
4190 return ptrs;
4191 };
4192
4193 stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
4194 blas::Transpose::kNoTranspose, m, n, k, alpha,
4195 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
4196 ldc, batch_count);
4197 }
4198
4199 return stream->ok();
4200 }
4201
4202 bool CudnnSupport::DoBiasAdd(Stream* stream,
4203 const DeviceMemory<float>& input_data,
4204 const DeviceMemory<float>& biases,
4205 const dnn::BatchDescriptor& dimensions,
4206 DeviceMemory<float>* output_data) {
4207 CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
4208
4209 dnn::BatchDescriptor bias_dimensions;
4210 bias_dimensions.set_count(1)
4211 .set_feature_map_count(dimensions.feature_map_count())
4212 .set_height(1)
4213 .set_width(1)
4214 .set_layout(dnn::DataLayout::kBatchYXDepth);
4215 CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
4216
4217 // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
4218 // output_data before doing the addition, unless the input and
4219 // output are at the same address.
4220 if (input_data.opaque() != output_data->opaque()) {
4221 stream->ThenMemcpy(output_data, input_data,
4222 dimensions.ElementCount() * sizeof(float));
4223 if (!stream->ok()) {
4224 LOG(ERROR)
4225 << "stream " << stream
4226 << " could not enqueue a tensor copy as part of bias addition.";
4227 return false;
4228 }
4229 }
4230
4231 const float alpha = 1.0f;
4232 const float beta = 1.0f;
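  // cudnnAddTensor() computes C = alpha * A + beta * C with broadcasting, so
  // with alpha == beta == 1 the (1 x feature_map_count x 1 x 1) bias tensor
  // is broadcast across the batch and spatial dimensions and accumulated onto
  // the copy of the input that is already sitting in output_data.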
4233
4234 auto cudnn = cudnn_->GetHandle(parent_, stream);
4235
4236 const auto status = [&] {
4237 RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
4238 cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
4239 &beta, input_descriptor.handle(), output_data->opaque()));
4240 return port::Status::OK();
4241 }();
4242 return IsStatusOk(status, /*report_error=*/true);
4243 }
4244
4245 bool CudnnSupport::DoActivate(Stream* stream,
4246 dnn::ActivationMode activation_mode,
4247 const dnn::BatchDescriptor& dimensions,
4248 const DeviceMemory<float>& input_data,
4249 DeviceMemory<float>* output_data,
4250 uint64 options) {
4251 CudnnActivationDescriptor activation_desc(
4252 activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
4253
4254 CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
4255 // Alpha is the input scaling factor.
4256 float alpha = 1.0;
4257 // Beta is the output scaling factor.
4258 float beta = 0.0;
4259
4260 auto cudnn = cudnn_->GetHandle(parent_, stream);
4261 const auto status = [&] {
4262 RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
4263 cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
4264 input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
4265 return port::Status::OK();
4266 }();
4267 return IsStatusOk(status, /*report_error=*/true);
4268 }
4269
4270 bool CudnnSupport::DoPoolForward(
4271 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4272 const dnn::BatchDescriptor& input_dimensions,
4273 const DeviceMemory<double>& input_data,
4274 const dnn::BatchDescriptor& output_dimensions,
4275 DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
4276 // Alpha is the scaling factor for input.
4277 double alpha = 1.0;
4278 // Beta is the scaling factor for output.
4279 double beta = 0.0;
4280
4281 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
4282 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
4283 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4284
4285 auto cudnn = cudnn_->GetHandle(parent_, stream);
4286 const auto status = [&] {
4287 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4288 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4289 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4290 return port::Status::OK();
4291 }();
4292 return IsStatusOk(status, /*report_error=*/true);
4293 }
4294
4295 bool CudnnSupport::DoPoolForward(
4296 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4297 const dnn::BatchDescriptor& input_dimensions,
4298 const DeviceMemory<float>& input_data,
4299 const dnn::BatchDescriptor& output_dimensions,
4300 DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
4301 // Alpha is the scaling factor for input.
4302 float alpha = 1.0;
4303 // Beta is the scaling factor for output.
4304 float beta = 0.0;
4305
4306 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
4307 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
4308 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4309
4310 auto cudnn = cudnn_->GetHandle(parent_, stream);
4311 const auto status = [&] {
4312 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4313 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4314 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4315 return port::Status::OK();
4316 }();
4317 return IsStatusOk(status, /*report_error=*/true);
4318 }
4319
4320 bool CudnnSupport::DoPoolForward(
4321 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4322 const dnn::BatchDescriptor& input_dimensions,
4323 const DeviceMemory<Eigen::half>& input_data,
4324 const dnn::BatchDescriptor& output_dimensions,
4325 DeviceMemory<Eigen::half>* output_data,
4326 ScratchAllocator* workspace_allocator) {
4327 // Alpha is the scaling factor for input.
4328 float alpha = 1.0;
4329 // Beta is the scaling factor for output.
4330 float beta = 0.0;
4331
4332 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
4333 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
4334 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4335 auto cudnn = cudnn_->GetHandle(parent_, stream);
4336 const auto status = [&] {
4337 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4338 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4339 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4340 return port::Status::OK();
4341 }();
4342 return IsStatusOk(status, /*report_error=*/true);
4343 }
4344
4345 bool CudnnSupport::DoPoolForward(
4346 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4347 const dnn::BatchDescriptor& input_dimensions,
4348 const DeviceMemory<int8>& input_data,
4349 const dnn::BatchDescriptor& output_dimensions,
4350 DeviceMemory<int8>* output_data, ScratchAllocator* workspace_allocator) {
4351 // Alpha is the scaling factor for input.
4352 float alpha = 1.0;
4353 // Beta is the scaling factor for output.
4354 float beta = 0.0;
4355
4356 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_INT8);
4357 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_INT8);
4358 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4359
4360 auto cudnn = cudnn_->GetHandle(parent_, stream);
4361 const auto status = [&] {
4362 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4363 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4364 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4365 return port::Status::OK();
4366 }();
4367 return IsStatusOk(status, /*report_error=*/true);
4368 }
4369
4370 bool CudnnSupport::DoPoolBackward(
4371 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4372 const dnn::BatchDescriptor& input_dimensions,
4373 const DeviceMemory<double>& input_data,
4374 const dnn::BatchDescriptor& output_dimensions,
4375 const DeviceMemory<double>& output_data,
4376 const DeviceMemory<double>& input_diff_data,
4377 DeviceMemory<double>* output_diff_data,
4378 ScratchAllocator* workspace_allocator) {
4379 // Alpha is the scaling factor for input.
4380 double alpha = 1.0;
4381 // Beta is the scaling factor for output.
4382 double beta = 0.0;
4383
4384 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
4385 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
4386 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4387
4388 auto cudnn = cudnn_->GetHandle(parent_, stream);
4389 const auto status = [&] {
4390 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4391 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4392 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4393 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4394 output_diff_data->opaque()));
4395 return port::Status::OK();
4396 }();
4397 return IsStatusOk(status, /*report_error=*/true);
4398 }
4399
4400 bool CudnnSupport::DoPoolBackward(
4401 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4402 const dnn::BatchDescriptor& input_dimensions,
4403 const DeviceMemory<float>& input_data,
4404 const dnn::BatchDescriptor& output_dimensions,
4405 const DeviceMemory<float>& output_data,
4406 const DeviceMemory<float>& input_diff_data,
4407 DeviceMemory<float>* output_diff_data,
4408 ScratchAllocator* workspace_allocator) {
4409 // Alpha is the scaling factor for input.
4410 float alpha = 1.0;
4411 // Beta is the scaling factor for output.
4412 float beta = 0.0;
4413
4414 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
4415 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
4416 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4417
4418 auto cudnn = cudnn_->GetHandle(parent_, stream);
4419 const auto status = [&] {
4420 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4421 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4422 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4423 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4424 output_diff_data->opaque()));
4425 return port::Status::OK();
4426 }();
4427 return IsStatusOk(status, /*report_error=*/true);
4428 }
4429
4430 bool CudnnSupport::DoPoolBackward(
4431 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4432 const dnn::BatchDescriptor& input_dimensions,
4433 const DeviceMemory<Eigen::half>& input_data,
4434 const dnn::BatchDescriptor& output_dimensions,
4435 const DeviceMemory<Eigen::half>& output_data,
4436 const DeviceMemory<Eigen::half>& input_diff_data,
4437 DeviceMemory<Eigen::half>* output_diff_data,
4438 ScratchAllocator* workspace_allocator) {
4439 // Alpha is the scaling factor for input.
4440 float alpha = 1.0;
4441 // Beta is the scaling factor for output.
4442 float beta = 0.0;
4443
4444 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
4445 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
4446 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4447
4448 auto cudnn = cudnn_->GetHandle(parent_, stream);
4449 const auto status = [&] {
4450 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4451 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4452 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4453 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4454 output_diff_data->opaque()));
4455 return port::Status::OK();
4456 }();
4457 return IsStatusOk(status, /*report_error=*/true);
4458 }
4459
4460 bool CudnnSupport::DoNormalizeWithDimensions(
4461 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4462 const dnn::BatchDescriptor& dimensions,
4463 const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
4464 // Check for unsupported modes.
4465 if (normalize_descriptor.wrap_around()) {
4466     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
4467 return false;
4468 }
4469 if (normalize_descriptor.segment_size()) {
4470 LOG(ERROR) << "CUDA LRN does not support segmentation";
4471 return false;
4472 }
4473
4474 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
4475 CudnnNormalizeDescriptor normalize(normalize_descriptor);
4476
4477 // Alpha is the scaling factor for input.
4478 float alpha = 1.0f;
4479 // Beta is the scaling factor for output.
4480 float beta = 0.0f;
4481
4482 auto cudnn = cudnn_->GetHandle(parent_, stream);
4483
4484 // Launch the normalization.
4485 const auto status = [&] {
4486 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
4487 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
4488 &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
4489 output_data->opaque()));
4490 return port::Status::OK();
4491 }();
4492 return IsStatusOk(status, /*report_error=*/true);
4493 }
4494
4495 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
4496 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4497 const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
4498 const DeviceMemory<float>& normalized_data,
4499 const DeviceMemory<float>& normalized_variable_gradient,
4500 DeviceMemory<float>* raw_variable_gradient,
4501 ScratchAllocator* workspace_allocator) {
4502 // Check for unsupported modes.
4503 if (normalize_descriptor.wrap_around()) {
4504     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
4505 return false;
4506 }
4507 if (normalize_descriptor.segment_size()) {
4508 LOG(ERROR) << "CUDA LRN does not support segmentation";
4509 return false;
4510 }
4511
4512 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
4513 CudnnNormalizeDescriptor normalize(normalize_descriptor);
4514
4515 float alpha = 1.0f;
4516 float beta = 0.0f;
4517
4518 auto cudnn = cudnn_->GetHandle(parent_, stream);
4519 const auto status = [&] {
4520 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
4521 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
4522 &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
4523 normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
4524 &beta, dims.handle(), raw_variable_gradient->opaque()));
4525 return port::Status::OK();
4526 }();
4527 return IsStatusOk(status, /*report_error=*/true);
4528 }
4529
4530 bool CudnnSupport::DoDepthConcatenate(
4531 Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4532 port::ArraySlice<const DeviceMemory<float>*> input_data,
4533 DeviceMemory<float>* output_data) {
4534 CHECK_EQ(input_dimensions.size(), input_data.size());
4535
4536 for (const auto& dimensions : input_dimensions) {
4537 if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4538 LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
4539 "supports the kBatchDepthYX layout.";
4540 return false;
4541 }
4542 }
4543
4544 if (input_dimensions.empty()) {
4545 return true; // Nothing to do.
4546 }
4547
4548 dnn::BatchDescriptor output_dimensions =
4549 dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
4550
4551 const int64 area = output_dimensions.width() * output_dimensions.height();
4552 const auto index = [area](int64 batch, int64 depth, int64 yx,
4553 int64 max_depth) {
4554 return (batch * max_depth + depth) * area + yx;
4555 };
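  // Worked example of the indexing above (illustrative values): with
  // max_depth = 3 and area = 4, element (batch = 1, depth = 2, yx = 1) lives
  // at flat offset (1 * 3 + 2) * 4 + 1 = 21 in the kBatchDepthYX layout.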
4556
4557 std::vector<float> output_host(output_dimensions.ElementCount());
4558 std::vector<float> tmp;
4559 int64 depth_sum = 0;
4560 for (size_t i = 0; i < input_data.size(); ++i) {
4561 const auto& dimensions = input_dimensions[i];
4562 tmp.resize(dimensions.ElementCount());
4563 stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
4564 port::Status block_status = stream->BlockHostUntilDone();
4565 if (!block_status.ok()) {
4566 LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
4567 return false;
4568 }
4569
4570 for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
4571 for (int64 yx = 0; yx < area; ++yx) {
4572 for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
4573 LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
4574 << yx << ' ' << depth;
4575 output_host[index(batch, depth + depth_sum, yx,
4576 output_dimensions.feature_map_count())] =
4577 tmp[index(batch, depth, yx, dimensions.feature_map_count())];
4578 }
4579 }
4580 }
4581 depth_sum += dimensions.feature_map_count();
4582 }
4583 stream->ThenMemcpyH2D<float>(output_host, output_data);
4584 return true;
4585 }
4586
4587 bool CudnnSupport::DoElementwiseOperate(
4588 Stream* stream, dnn::ElementwiseOperation operation,
4589 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4590 port::ArraySlice<const DeviceMemory<float>*> input_data,
4591 const dnn::BatchDescriptor& output_dimensions,
4592 DeviceMemory<float>* output_data) {
4593 LOG(FATAL) << "not yet implemented"; // TODO(leary)
4594 return false;
4595 }
4596
DoXYPad(Stream * stream,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)4597 bool CudnnSupport::DoXYPad(Stream* stream,
4598 const dnn::BatchDescriptor& dimensions,
4599 const DeviceMemory<float>& input_data,
4600 int64 left_pad, int64 right_pad, int64 top_pad,
4601 int64 bottom_pad, DeviceMemory<float>* output_data) {
4602 LOG(FATAL) << "not yet implemented"; // TODO(leary)
4603 return false;
4604 }
4605
DoXYSlice(Stream * stream,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)4606 bool CudnnSupport::DoXYSlice(Stream* stream,
4607 const dnn::BatchDescriptor& dimensions,
4608 const DeviceMemory<float>& input_data,
4609 int64 left_trim, int64 right_trim, int64 top_trim,
4610 int64 bottom_trim,
4611 DeviceMemory<float>* output_data) {
4612 LOG(FATAL) << "not yet implemented"; // TODO(leary)
4613 return false;
4614 }
4615
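// Quantized host <-> device copies are not supported by the cuDNN plugin;
// both directions log an error and report failure to the caller.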
bool CudnnSupport::DoMemcpyD2HQuantized(
    Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
    dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
  LOG(ERROR) << "quantized memcpy not supported by cuDNN";
  return false;
}

bool CudnnSupport::DoMemcpyH2DQuantized(
    Stream* stream, const void* host_src, int64 size,
    dnn::QuantizedActivationMode mode,
    DeviceMemory<float>* gpu_unquantized_dst) {
  LOG(ERROR) << "quantized memcpy not supported by cuDNN";
  return false;
}

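// Computes the output shape of a convolution by wrapping the arguments in
// cuDNN descriptors and querying cudnnGetConvolutionNdForwardOutputDim. The
// returned dimensions are ordered batch, feature maps, then spatial dims, so
// the spatial entries are written to |output_batch_descriptor| in reverse.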
bool CudnnSupport::DeriveOutputBatchDescriptor(
    const dnn::BatchDescriptor& batch_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    dnn::BatchDescriptor* output_batch_descriptor) {
  CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
  CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
  CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);

  int dn = batch_descriptor.ndims() + 2;
  std::vector<int> dims(dn);  // in BDYX
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
        conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
    output_batch_descriptor->set_count(dims[0])
        .set_feature_map_count(dims[1])
        .set_layout(batch_descriptor.layout());

    for (int i = 0; i < batch_descriptor.ndims(); i++) {
      output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
                                               dims.rbegin()[i]);
    }
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

}  // namespace gpu

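// Registers the cuDNN DNN factory with the plugin registry for the CUDA
// platform and makes it the default DNN plugin. The factory refuses to build
// a CudnnSupport for a non-CUDA StreamExecutor and returns nullptr if cuDNN
// initialization fails.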
void initialize_cudnn() {
  port::Status status =
      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
          cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
          [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
            gpu::GpuExecutor* cuda_executor =
                dynamic_cast<gpu::GpuExecutor*>(parent);
            if (cuda_executor == nullptr) {
              LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
                         << "support library with a non-CUDA StreamExecutor";
              return nullptr;
            }

            gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
            if (!dnn->Init().ok()) {
              // Note: Init() will log a more specific error.
              delete dnn;
              return nullptr;
            }
            return dnn;
          });

  if (!status.ok()) {
    LOG(ERROR) << "Unable to register cuDNN factory: "
               << status.error_message();
  }

  PluginRegistry::Instance()->SetDefaultFactory(
      cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
}

}  // namespace stream_executor

#pragma clang diagnostic pop

REGISTER_MODULE_INITIALIZER(register_cudnn,
                            { stream_executor::initialize_cudnn(); });