1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
17
18 #include <functional>
19 #include <memory>
20 #include <utility>
21
22 #include "absl/strings/str_cat.h"
23 #include "third_party/eigen3/Eigen/Core"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/strings/stringprintf.h"
26 #include "tensorflow/core/util/env_var.h"
27 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
28 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
29 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
30 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
31 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
32 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
33 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
34 #include "tensorflow/stream_executor/cuda/cudnn_version.h"
35 #include "tensorflow/stream_executor/dnn.h"
36 #include "tensorflow/stream_executor/lib/env.h"
37 #include "tensorflow/stream_executor/lib/error.h"
38 #include "tensorflow/stream_executor/lib/initialize.h"
39 #include "tensorflow/stream_executor/lib/mathutil.h"
40 #include "tensorflow/stream_executor/lib/threadpool.h"
41 #include "tensorflow/stream_executor/platform/logging.h"
42 #include "tensorflow/stream_executor/plugin_registry.h"
43 #include "tensorflow/stream_executor/scratch_allocator.h"
44 #include "tensorflow/stream_executor/stream.h"
45 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
46 // clang-format off
47 #include "cuda/include/cudnn.h"
48 #include "absl/strings/string_view.h"
49 // clang-format on
50
51 #pragma clang diagnostic push
52
53 // Make sure that Eigen::half forward declaration in dnn.h matches the
54 // declaration in Eigen.
55 #pragma clang diagnostic warning "-Wmismatched-tags"
56
57 namespace stream_executor {
58 namespace gpu {
59
60 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
61
62 namespace {
63
64 static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher");
65
66 // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
67 #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
68
69 // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
70 // function with a non-successful port::Status.
71 #define RETURN_IF_CUDNN_ERROR(expr) \
72 do { \
73 cudnnStatus_t _status = expr; \
74 if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \
75 std::ostringstream oss; \
76 oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
77 << "): '" << #expr << "'"; \
78 return port::Status(port::error::UNKNOWN, oss.str().c_str()); \
79 } \
80 } while (false)
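// A minimal usage sketch (illustrative only, not part of the build): the macro
// is intended for functions that return port::Status, e.g.
//
//   port::Status SetCudnnStreamOrFail(cudnnHandle_t handle, cudaStream_t s) {
//     RETURN_IF_CUDNN_ERROR(cudnnSetStream(handle, s));
//     return port::Status::OK();
//   }
//
// The helper name above is hypothetical; cudnnSetStream is a real cuDNN call.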
81
82 // Converts (via narrowing) a value of type WideT to type NarrowT, and checks
83 // that the value is unchanged by the conversion.
84 template <typename WideT, typename NarrowT>
85 NarrowT CheckedNarrowing(const WideT& wide) {
86 NarrowT narrow = wide;
87 CHECK_EQ(narrow, wide)
88 << "checked narrowing failed; values not equal post-conversion";
89 return narrow;
90 }
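// For example (illustrative): CheckedNarrowing<int64, int>(42) returns 42,
// while narrowing a value that does not fit in an int (e.g. int64{1} << 40)
// would trip the CHECK instead of silently truncating.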
91
92 string ToString(cudnnStatus_t status) {
93 switch (status) {
94 case CUDNN_STATUS_SUCCESS:
95 return "CUDNN_STATUS_SUCCESS";
96 case CUDNN_STATUS_NOT_INITIALIZED:
97 return "CUDNN_STATUS_NOT_INITIALIZED";
98 case CUDNN_STATUS_ALLOC_FAILED:
99 return "CUDNN_STATUS_ALLOC_FAILED";
100 case CUDNN_STATUS_BAD_PARAM:
101 return "CUDNN_STATUS_BAD_PARAM";
102 case CUDNN_STATUS_INTERNAL_ERROR:
103 return "CUDNN_STATUS_INTERNAL_ERROR";
104 case CUDNN_STATUS_INVALID_VALUE:
105 return "CUDNN_STATUS_INVALID_VALUE";
106 case CUDNN_STATUS_ARCH_MISMATCH:
107 return "CUDNN_STATUS_ARCH_MISMATCH";
108 case CUDNN_STATUS_MAPPING_ERROR:
109 return "CUDNN_STATUS_MAPPING_ERROR";
110 case CUDNN_STATUS_EXECUTION_FAILED:
111 return "CUDNN_STATUS_EXECUTION_FAILED";
112 case CUDNN_STATUS_NOT_SUPPORTED:
113 return "CUDNN_STATUS_NOT_SUPPORTED";
114 case CUDNN_STATUS_LICENSE_ERROR:
115 return "CUDNN_STATUS_LICENSE_ERROR";
116 case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
117 return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
118 #if CUDNN_VERSION >= 7000
119 case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
120 return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
121 case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
122 return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
123 #endif
124 default:
125 return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
126 ">");
127 }
128 }
129
130 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
131 //
132 // See CudnnAccess::GetHandle() for details.
133 class CudnnHandle {
134 public:
135 // Takes ownership of the executor context and the lock to access cuDNN
136 // using handle.
137 CudnnHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock,
138 cudnnHandle_t handle)
139 : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
140
141 // Returns the cuDNN handle. Pass it directly to cuDNN APIs; don't keep
142 // a copy.
143 cudnnHandle_t handle() const { return handle_; }
144
145 private:
146 gpu::ScopedActivateExecutorContext context_;
147 mutex_lock lock_;
148 cudnnHandle_t handle_; // Not owned.
149 };
150
151 } // namespace
152
153 // Wraps a cuDNN handle and provides access to it through CudnnHandle
154 // instances, which also locks a mutex, acquires the CUDA context, and sets
155 // the stream that cuDNN should use to enqueue any work.
156 //
157 // Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
158 class CudnnAccess {
159 public:
160 // Takes ownership of the handle.
161 explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
162
163 ~CudnnAccess() {
164 mutex_lock lock(mutex_);
165 cudnnDestroy(handle_);
166 }
167
168 // Creates a CudnnHandle instance for stream.
169 //
170 // cuDNN API calls using the same handle instance need to be serialized
171 // across threads. This is guaranteed by CudnnHandle instances locking the
172 // mutex owned by this class.
173 //
174 // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
175 // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
176 // to use the provided stream.
177 //
178 // The stream argument may be null, which translates to the legacy default
179 // stream. See
180 // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
181 // The legacy default stream synchronizes with all other streams, so it is a
182 // bad idea (performance-wise) to call any cuDNN APIs that enqueue work on
183 // that stream.
184 CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
185 mutex_lock lock(mutex_);
186 gpu::ScopedActivateExecutorContext context(executor);
187 CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
188 const auto status = cudnnSetStream(handle_, cu_stream);
189 CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
190 return CudnnHandle(std::move(context), std::move(lock), handle_);
191 }
192
193 private:
194 // Guards the enqueueing of cuDNN operations via the handle_ below.
195 mutex mutex_;
196
197 // cuDNN library handle.
198 cudnnHandle_t handle_ GUARDED_BY(mutex_); // Owned.
199 };
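// A minimal usage sketch (hypothetical helper, assuming a CudnnAccess instance
// is available): the returned CudnnHandle holds the lock and the activated
// CUDA context until it goes out of scope.
//
//   port::Status QueryDropoutStateSize(GpuExecutor* executor, Stream* stream,
//                                      CudnnAccess* cudnn, size_t* bytes) {
//     auto handle = cudnn->GetHandle(executor, stream);  // locks, sets stream
//     RETURN_IF_CUDNN_ERROR(cudnnDropoutGetStatesSize(handle.handle(), bytes));
//     return port::Status::OK();  // lock and context released here
//   }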
200
201 namespace {
202
203 // A helper function to return the internal compute type for
204 // RNNs in cudnn.
205 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
206
207 cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
208 cudnnConvolutionFwdAlgo_t algo =
209 cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
210 switch (algo) {
211 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
212 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
213 case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
214 case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
215 case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
216 case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
217 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
218 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
219 return algo;
220 default:
221 LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
222 << algorithm.algo_id();
223 }
224 }
225
226 cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
227 dnn::AlgorithmDesc algorithm) {
228 cudnnConvolutionBwdDataAlgo_t algo =
229 cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
230 switch (algo) {
231 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
232 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
233 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
234 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
235 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
236 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
237 return algo;
238 default:
239 LOG(FATAL)
240 << "Unsupported Cudnn convolution backward algorithm for data: "
241 << algorithm.algo_id();
242 }
243 }
244
245 cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
246 dnn::AlgorithmDesc algorithm) {
247 cudnnConvolutionBwdFilterAlgo_t algo =
248 cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
249 switch (algo) {
250 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
251 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
252 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
253 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
254 // Based on cudnn.h, the following is not implemented.
255 // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
256 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
257 return algo;
258 // Produces incorrect results for some shapes. Disabled for now, see
259 // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
260 // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
261 default:
262 LOG(FATAL)
263 << "Unsupported Cudnn convolution backward algorithm for filter: "
264 << algorithm.algo_id();
265 }
266 }
267
268 port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
269 int value;
270 RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
271 return value;
272 }
273
274 cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional<dnn::AlgorithmDesc> algorithm) {
275 if (!algorithm.has_value()) {
276 return CUDNN_RNN_ALGO_STANDARD;
277 }
278 cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
279 switch (algo) {
280 case CUDNN_RNN_ALGO_STANDARD:
281 case CUDNN_RNN_ALGO_PERSIST_STATIC:
282 case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
283 return algo;
284 default:
285 LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
286 }
287 }
288
289 port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
290 SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
291 SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
292 SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
293 return port::Status::OK();
294 }
295
296 } // namespace
297
298 CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}
299
300 port::Status CudnnSupport::Init() {
301 ScopedActivateExecutorContext context(parent_);
302 cudnnHandle_t cudnn_handle = nullptr;
303 const auto status = cudnnCreate(&cudnn_handle);
304 if (status == CUDNN_STATUS_SUCCESS) {
305 CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
306
307 CudnnVersion loaded_version;
308 TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
309 if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
310 const tensorflow::string error = absl::StrCat(
311 "Loaded runtime CuDNN library: ", loaded_version.ToString(),
312 " but source was compiled with: ", source_version.ToString(),
313 ". The CuDNN library needs to have a matching major version and an "
314 "equal or higher minor version (for CuDNN 7.0 and later). If using "
315 "a binary install, upgrade your CuDNN library. If building "
316 "from sources, make sure the library loaded at runtime is "
317 "compatible "
318 "with the version specified during compile configuration.");
319 LOG(ERROR) << error;
320 cudnnDestroy(cudnn_handle);
321 return port::Status(port::error::INTERNAL, error);
322 }
323
324 cudnn_.reset(new CudnnAccess(cudnn_handle));
325 return port::Status::OK();
326 }
327
328 CHECK_EQ(cudnn_handle, nullptr);
329 LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
330 if (status == CUDNN_STATUS_NOT_INITIALIZED) {
331 auto result = gpu::Diagnostician::FindKernelDriverVersion();
332 if (!result.ok()) {
333 LOG(ERROR) << "Error retrieving driver version: "
334 << cuda::DriverVersionStatusToString(result);
335 } else {
336 const auto& version = result.ValueOrDie();
337 LOG(ERROR) << "Possibly insufficient driver version: "
338 << cuda::DriverVersionToString(version);
339 }
340 }
341
342 return port::Status(port::error::INTERNAL,
343 absl::StrCat("cudnn library could not create a handle: ",
344 ToString(status)));
345 }
346
347 port::StatusOr<perftools::gputools::dnn::VersionInfo>
348 CudnnSupport::GetVersion() {
349 CudnnVersion version;
350 TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
351 return perftools::gputools::dnn::VersionInfo(
352 version.major_version, version.minor_version, version.patch_level);
353 }
354
355 namespace {
356
357 // Deleter functors for cuDNN types that need to be deleted.
358 struct TensorDescriptorDeleter {
359 void operator()(cudnnTensorDescriptor_t descriptor) const {
360 CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
361 }
362 };
363 #if CUDNN_VERSION >= 7201
364 struct RNNDataDescriptorDeleter {
365 void operator()(cudnnRNNDataDescriptor_t descriptor) const {
366 CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
367 }
368 };
369 #endif
370 struct FilterDescriptorDeleter {
371 void operator()(cudnnFilterDescriptor_t descriptor) const {
372 CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
373 }
374 };
375 struct ConvolutionDescriptorDeleter {
376 void operator()(cudnnConvolutionDescriptor_t descriptor) const {
377 CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
378 }
379 };
380 struct PoolingDescriptorDeleter {
381 void operator()(cudnnPoolingDescriptor_t descriptor) const {
382 CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
383 }
384 };
385 struct LrnDescriptorDeleter {
386 void operator()(cudnnLRNDescriptor_t descriptor) const {
387 CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
388 }
389 };
390
391 struct ActivationDescriptorDeleter {
392 void operator()(cudnnActivationDescriptor_t descriptor) const {
393 CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
394 }
395 };
396 struct DropoutDescriptorDeleter {
397 void operator()(cudnnDropoutDescriptor_t descriptor) const {
398 CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
399 }
400 };
401 struct RnnDescriptorDeleter {
402 void operator()(cudnnRNNDescriptor_t descriptor) const {
403 CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
404 }
405 };
406 struct PersistentRnnPlanDeleter {
407 void operator()(cudnnPersistentRNNPlan_t plan) const {
408 CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
409 }
410 };
411
412 // RAII wrappers for cuDNN types.
413 using TensorDescriptor =
414 std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
415 #if CUDNN_VERSION >= 7201
416 using RNNDataDescriptor =
417 std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
418 #endif
419 using FilterDescriptor =
420 std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
421 using ConvolutionDescriptor =
422 std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
423 using PoolingDescriptor =
424 std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
425 using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
426 using ActivationDescriptor =
427 std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
428 using DropoutDescriptor =
429 std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
430 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
431 using PersistentRnnPlan =
432 std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
433
434 // Factory methods for cuDNN types.
435 TensorDescriptor CreateTensorDescriptor() {
436 cudnnTensorDescriptor_t result;
437 CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
438 return TensorDescriptor(result);
439 }
440 #if CUDNN_VERSION >= 7201
441 RNNDataDescriptor CreateRNNDataDescriptor() {
442 cudnnRNNDataDescriptor_t result;
443 CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
444 return RNNDataDescriptor(result);
445 }
446 #endif
447 FilterDescriptor CreateFilterDescriptor() {
448 cudnnFilterDescriptor_t result;
449 CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
450 return FilterDescriptor(result);
451 }
452 ConvolutionDescriptor CreateConvolutionDescriptor() {
453 cudnnConvolutionDescriptor_t result;
454 CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
455 return ConvolutionDescriptor(result);
456 }
457 PoolingDescriptor CreatePoolingDescriptor() {
458 cudnnPoolingDescriptor_t result;
459 CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
460 return PoolingDescriptor(result);
461 }
462 LrnDescriptor CreateLrnDescriptor() {
463 cudnnLRNDescriptor_t result;
464 CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
465 return LrnDescriptor(result);
466 }
467 ActivationDescriptor CreateActivationDescriptor() {
468 cudnnActivationDescriptor_t result;
469 CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
470 return ActivationDescriptor(result);
471 }
472 DropoutDescriptor CreateDropoutDescriptor() {
473 cudnnDropoutDescriptor_t result;
474 CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
475 return DropoutDescriptor(result);
476 }
477 RnnDescriptor CreateRnnDescriptor() {
478 cudnnRNNDescriptor_t result;
479 CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
480 return RnnDescriptor(result);
481 }
482
483 port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
484 cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
485 cudnnPersistentRNNPlan_t result;
486 RETURN_IF_CUDNN_ERROR(
487 cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
488 return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
489 }
490
491 // Turns a BatchDescriptor structure into a cudnn tensor handle within a
492 // scope.
493 class CudnnTensorDescriptor {
494 public:
495 CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
496 cudnnDataType_t elem_type)
497 : handle_(CreateTensorDescriptor()) {
498 switch (batch_descriptor.layout()) {
499 case dnn::DataLayout::kBatchYXDepth:
500 case dnn::DataLayout::kBatchDepthYX: {
501 const int nd = batch_descriptor.ndims() + 2;
502 // cuDNN requires the strides and dims to be ordered as BDYX.
503 std::vector<int64> strides64 =
504 batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
505 std::vector<int64> dims64 =
506 batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
507
508 // cuDNN requires arrays of ints.
509 std::vector<int> strides(nd);
510 std::vector<int> dims(nd);
511 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
512 &CheckedNarrowing<int64, int>);
513 std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
514 &CheckedNarrowing<int64, int>);
515 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
516 dims.data(), strides.data()))
517 << "batch_descriptor: " << batch_descriptor.ToString();
518 } break;
519 case dnn::DataLayout::kBatchDepthYX4: {
520 CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
521 handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
522 batch_descriptor.count(), batch_descriptor.feature_map_count(),
523 batch_descriptor.height(), batch_descriptor.width()))
524 << "batch_descriptor: " << batch_descriptor.ToString();
525 } break;
526 default:
527 LOG(FATAL) << "Unsupported tensor format "
528 << DataLayoutString(batch_descriptor.layout());
529 break;
530 }
531 }
532
533 cudnnTensorDescriptor_t handle() const { return handle_.get(); }
534
535 private:
536 TensorDescriptor handle_;
537
538 SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
539 };
540
541 // Turns a FilterDescriptor structure into a cudnn filter handle within a
542 // scope.
543 class CudnnFilterDescriptor {
544 public:
545 CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
546 cudnnDataType_t elem_type)
547 : handle_(CreateFilterDescriptor()) {
548 // TODO(b/23032134): Even if the filter layout is not supported,
549 // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
550 // it does not take layout as an input. Maybe force cuDNN by giving wrong
551 // inputs intentionally?
552 cudnnTensorFormat_t format;
553 switch (filter_descriptor.layout()) {
554 case dnn::FilterLayout::kOutputInputYX:
555 format = CUDNN_TENSOR_NCHW;
556 break;
557 case dnn::FilterLayout::kOutputYXInput:
558 format = CUDNN_TENSOR_NHWC;
559 break;
560 case dnn::FilterLayout::kOutputInputYX4:
561 format = CUDNN_TENSOR_NCHW_VECT_C;
562 break;
563 default:
564 LOG(FATAL) << "Unsupported filter format "
565 << FilterLayoutString(filter_descriptor.layout());
566 break;
567 }
568
569 std::vector<int> dims(2 + filter_descriptor.ndims());
570 dims[0] = filter_descriptor.output_feature_map_count();
571 dims[1] = filter_descriptor.input_feature_map_count();
572 absl::Span<const int64> spatial_dims =
573 filter_descriptor.input_filter_dims();
574 std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
575
576 CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
577 dims.size(), dims.data()));
578 }
579
580 cudnnFilterDescriptor_t handle() const { return handle_.get(); }
581
582 private:
583 FilterDescriptor handle_; // Owned.
584
585 SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
586 };
587
588 // A helper function to decide whether to enable the TENSOR_OP_MATH math type for convolutions.
589 bool TensorOpMathEnabled() {
590 static bool is_enabled = [] {
591 bool is_disabled = false;
592 TF_CHECK_OK(
593 tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
594 /*default_val=*/false, &is_disabled));
595 return !is_disabled;
596 }();
597 return is_enabled;
598 }
599
600 // A helper function to decide whether to enable the TENSOR_OP_MATH math type
601 // for RNNs.
602 bool RnnTensorOpMathEnabled() {
603 static bool is_enabled = [] {
604 bool is_disabled = false;
605 TF_CHECK_OK(
606 tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH",
607 /*default_val=*/false, &is_disabled));
608 return !is_disabled;
609 }();
610 return is_enabled;
611 }
612
613 // A helper function to decide whether to use
614 // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
615 // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
616 // and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
617 // reason we set it to false by default is that this mode may use scaled
618 // atomic integer reduction that may cause a numerical overflow for certain
619 // input data ranges.
620 // TODO(yangzihao): Use autotune to choose between this mode and
621 // CUDNN_BATCHNORM_SPATIAL mode.
622 bool BatchnormSpatialPersistentEnabled() {
623 static bool is_enabled = [] {
624 bool is_enabled = false;
625 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
626 "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
627 /*default_val=*/false, &is_enabled));
628 return is_enabled;
629 }();
630 return is_enabled;
631 }
632
633 // A helper function to decide whether to enable deterministic functionality.
634 bool RequireDeterminism() {
635 static bool is_enabled = [] {
636 bool is_enabled = false;
637 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
638 /*default_val=*/false,
639 &is_enabled));
640 return is_enabled;
641 }();
642 return is_enabled;
643 }
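// Taken together, these toggles are plain boolean environment variables; for
// example (illustrative):
//   TF_DISABLE_CUDNN_TENSOR_OP_MATH=1      -> TensorOpMathEnabled() == false
//   TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH=1  -> RnnTensorOpMathEnabled() == false
//   TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1
//                                          -> BatchnormSpatialPersistentEnabled() == true
//   TF_CUDNN_DETERMINISTIC=1               -> RequireDeterminism() == true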
644
645 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
646 // within a scope.
647 class CudnnConvolutionDescriptor {
648 public:
649 CudnnConvolutionDescriptor(
650 const dnn::ConvolutionDescriptor& convolution_descriptor,
651 cudnnDataType_t data_type)
652 : handle_(CreateConvolutionDescriptor()) {
653 absl::Span<const int64> strides64 = convolution_descriptor.strides();
654 absl::Span<const int64> padding64 = convolution_descriptor.padding();
655 absl::Span<const int64> dilations64 = convolution_descriptor.dilations();
656 CHECK_NE(convolution_descriptor.pad_alignment(),
657 dnn::PadAlignment::kTensorFlowPadding)
658 << "TensorFlow padding alignment is not supported.";
659
660 // cuDNN requires arrays of ints.
661 std::vector<int> strides(convolution_descriptor.ndims());
662 std::vector<int> padding(convolution_descriptor.ndims());
663 std::vector<int> dilations(convolution_descriptor.ndims());
664 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
665 &CheckedNarrowing<int64, int>);
666 std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
667 &CheckedNarrowing<int64, int>);
668 // TODO(yangzihao): Test with negative dilation to make sure that cudnn
669 // doesn't crash.
670 std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
671 &CheckedNarrowing<int64, int>);
672
673 CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
674 handle_.get(), convolution_descriptor.ndims(), padding.data(),
675 strides.data(), dilations.data(),
676 convolution_descriptor.convolution_not_crosscorr()
677 ? CUDNN_CONVOLUTION
678 : CUDNN_CROSS_CORRELATION,
679 data_type));
680
681 // NOTE(benbarsdell): This only applies if tensor op math is enabled
682 // and algo selection is set to Default.
683 this->set_use_tensor_op_math(true);
684
685 #if CUDNN_MAJOR >= 7
686 VLOG(2) << "Requesting grouped convolution: "
687 << convolution_descriptor.group_count();
688 CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
689 handle_.get(), convolution_descriptor.group_count()));
690 #else
691 CHECK_EQ(convolution_descriptor.group_count(), 1)
692 << "Requested grouped convolution for cuDNN version < 7";
693 #endif
694 }
695
696 void set_use_tensor_op_math(bool use_tensor_op_math) const {
697 #if CUDNN_VERSION >= 7000
698 cudnnMathType_t math_type =
699 (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
700 if (TensorOpMathEnabled()) {
701 CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
702 }
703 #endif
704 }
705
706 cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
707
708 private:
709 ConvolutionDescriptor handle_; // Owned.
710
711 SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
712 };
713
714 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
715 // within a scope.
716 class CudnnPoolingDescriptor {
717 public:
718 explicit CudnnPoolingDescriptor(
719 const dnn::PoolingDescriptor& pooling_descriptor)
720 : handle_(CreatePoolingDescriptor()) {
721 absl::Span<const int64> strides64 = pooling_descriptor.strides();
722 absl::Span<const int64> padding64 = pooling_descriptor.padding();
723 absl::Span<const int64> shape64 = pooling_descriptor.window();
724
725 const int nd = pooling_descriptor.ndims();
726 std::vector<int> shape(nd);
727 std::vector<int> padding(nd);
728 std::vector<int> strides(nd);
729 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
730 &CheckedNarrowing<int64, int>);
731 std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
732 &CheckedNarrowing<int64, int>);
733 std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
734 &CheckedNarrowing<int64, int>);
735 bool propagate_nans = pooling_descriptor.propagate_nans();
736 const auto cudnn_max_pooling_mode = RequireDeterminism()
737 ? CUDNN_POOLING_MAX_DETERMINISTIC
738 : CUDNN_POOLING_MAX;
739 CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
740 handle_.get(),
741 (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
742 ? cudnn_max_pooling_mode
743 : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
744 propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
745 shape.data(), padding.data(), strides.data()));
746 }
747
748 cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
749
750 private:
751 PoolingDescriptor handle_; // Owned.
752
753 SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
754 };
755
756 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
757 class CudnnNormalizeDescriptor {
758 public:
759 explicit CudnnNormalizeDescriptor(
760 const dnn::NormalizeDescriptor& normalize_descriptor)
761 : handle_(CreateLrnDescriptor()) {
762 // The range specifies that the indices in the closed range
763 // [i - range, i + range] should be included in the normalization for index
764 // i. The lrnN value is the total number of elements in the range, so
765 // lrnN = 2*range + 1.
766 unsigned lrnN = 2 * normalize_descriptor.range() + 1;
767
768 // Note that SE defines the normalization operation as
769 //
770 // U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
771 //
772 // but cuDNN defines it as
773 //
774 // U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
775 //
776 // i.e. there is a factor of n difference between the meaning of the alphas
777 // in the two contexts. The cuDNN alpha is n times the SE alpha.
778 double lrnAlpha = lrnN * normalize_descriptor.alpha();
779
780 double lrnBeta = normalize_descriptor.beta();
781 double lrnK = normalize_descriptor.bias();
782 CHECK_CUDNN_OK(
783 cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
784 }
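// Worked example of the conversion above (illustrative numbers): with
// range = 2 and an SE alpha of 1e-4, the window is lrnN = 2 * 2 + 1 = 5
// elements and the value handed to cuDNN is lrnAlpha = 5 * 1e-4 = 5e-4,
// matching cuDNN's (alpha / n) convention.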
785
786 cudnnLRNDescriptor_t handle() const { return handle_.get(); }
787
788 private:
789 LrnDescriptor handle_; // Owned.
790
791 SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
792 };
793
794 // Turns an ActivationDescriptor structure into a cudnn activation
795 // descriptor handle within a scope.
796 class CudnnActivationDescriptor {
797 public:
798 CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
799 cudnnNanPropagation_t nan_propagation,
800 double value_max)
801 : handle_(CreateActivationDescriptor()) {
802 double relu_ceiling = 0.0;
803 cudnnActivationMode_t mode;
804 switch (activation_mode) {
805 #if CUDNN_VERSION >= 7100
806 case dnn::ActivationMode::kNone:
807 mode = CUDNN_ACTIVATION_IDENTITY;
808 break;
809 #endif
810 case dnn::ActivationMode::kRelu6:
811 relu_ceiling = 6.0;
812 mode = CUDNN_ACTIVATION_CLIPPED_RELU;
813 break;
814 case dnn::ActivationMode::kReluX:
815 relu_ceiling = value_max;
816 mode = CUDNN_ACTIVATION_CLIPPED_RELU;
817 break;
818 case dnn::ActivationMode::kRelu:
819 mode = CUDNN_ACTIVATION_RELU;
820 break;
821 case dnn::ActivationMode::kSigmoid:
822 mode = CUDNN_ACTIVATION_SIGMOID;
823 break;
824 case dnn::ActivationMode::kTanh:
825 mode = CUDNN_ACTIVATION_TANH;
826 break;
827 default:
828 LOG(FATAL) << "unrecognized activation mode: "
829 << static_cast<int>(activation_mode);
830 }
831
832 CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
833 nan_propagation, relu_ceiling));
834 }
835
836 cudnnActivationDescriptor_t handle() const { return handle_.get(); }
837
838 private:
839 ActivationDescriptor handle_; // Owned.
840
841 SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
842 };
843
844 cudnnDataType_t ToCudnnDataType(
845 dnn::DataType data_type,
846 dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
847 switch (data_type) {
848 case dnn::DataType::kFloat:
849 return CUDNN_DATA_FLOAT;
850 case dnn::DataType::kDouble:
851 return CUDNN_DATA_DOUBLE;
852 case dnn::DataType::kHalf:
853 return CUDNN_DATA_HALF;
854 case dnn::DataType::kInt8:
855 return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
856 : CUDNN_DATA_INT8;
857 case dnn::DataType::kInt32:
858 return CUDNN_DATA_INT32;
859 default:
860 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
861 }
862 }
863
864 cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
865 dnn::FilterLayout filter_layout) {
866 if (data_type == dnn::DataType::kInt8 &&
867 filter_layout == dnn::FilterLayout::kOutputInputYX4) {
868 return CUDNN_DATA_INT8x4;
869 }
870 return ToCudnnDataType(data_type);
871 }
872
873 template <typename T>
874 cudnnDataType_t GetCudnnDataType(
875 dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
876 return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
877 }
878
879 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
880 switch (input_mode) {
881 case dnn::RnnInputMode::kRnnLinearSkip:
882 case dnn::RnnInputMode::kRnnSkipInput:
883 return static_cast<cudnnRNNInputMode_t>(input_mode);
884 default:
885 LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
886 }
887 }
888
889 cudnnDirectionMode_t ToCudnnRnnDirectionMode(
890 dnn::RnnDirectionMode direction_mode) {
891 switch (direction_mode) {
892 case dnn::RnnDirectionMode::kRnnUnidirectional:
893 case dnn::RnnDirectionMode::kRnnBidirectional:
894 return static_cast<cudnnDirectionMode_t>(direction_mode);
895 default:
896 LOG(FATAL) << "Invalid RNN direction mode: "
897 << static_cast<int>(direction_mode);
898 }
899 }
900
901 cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
902 switch (rnn_mode) {
903 case dnn::RnnMode::kRnnRelu:
904 case dnn::RnnMode::kRnnTanh:
905 case dnn::RnnMode::kRnnLstm:
906 case dnn::RnnMode::kRnnGru:
907 return static_cast<cudnnRNNMode_t>(rnn_mode);
908 default:
909 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
910 }
911 }
912
913 int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
914 switch (data_type) {
915 case CUDNN_DATA_FLOAT:
916 return sizeof(float);
917 case CUDNN_DATA_DOUBLE:
918 return sizeof(double);
919 case CUDNN_DATA_HALF:
920 return sizeof(Eigen::half);
921 default:
922 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
923 }
924 }
925
926 class CudnnDropoutDescriptor {
927 explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
928 : handle_(std::move(handle)) {}
929
930 public:
931 CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
932
933 static port::StatusOr<CudnnDropoutDescriptor> Create(
934 const CudnnHandle& cudnn, float dropout, uint64 seed,
935 ScratchAllocator* state_allocator) {
936 DropoutDescriptor handle = CreateDropoutDescriptor();
937
938 if (dropout == 0.0f) {
939 // Return 'empty' dropout descriptor.
940 return CudnnDropoutDescriptor(std::move(handle));
941 }
942
943 DeviceMemory<uint8> state_memory;
944 if (state_allocator) {
945 size_t state_sizes_in_bytes = 0;
946 RETURN_IF_CUDNN_ERROR(
947 cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
948 SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes(
949 nullptr, state_sizes_in_bytes));
950 }
951 RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
952 handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
953 state_memory.size(), seed));
954
955 return CudnnDropoutDescriptor(std::move(handle));
956 }
957
958 cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
959
960 private:
961 DropoutDescriptor handle_; // Owned.
962 SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
963 };
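// A minimal usage sketch (hypothetical values; mirrors the call made from
// CudnnRnnDescriptor::Create below):
//
//   SE_ASSIGN_OR_RETURN(
//       CudnnDropoutDescriptor dropout_desc,
//       CudnnDropoutDescriptor::Create(cudnn, /*dropout=*/0.5f, /*seed=*/42,
//                                      state_allocator));
//
// With dropout == 0.0f no scratch state is allocated and an 'empty' descriptor
// is returned.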
964
965 class CudnnRnnParamsDescriptor {
966 typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
967
968 CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes,
969 ParamsRegions weights, ParamsRegions biases)
970 : handle_(std::move(handle)),
971 params_size_in_bytes_(params_size_in_bytes),
972 weights_(std::move(weights)),
973 biases_(std::move(biases)) {}
974
975 public:
976 CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
977
978 static port::StatusOr<CudnnRnnParamsDescriptor> Create(
979 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
980 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
981 cudnnDirectionMode_t direction_mode, int num_layers);
982
983 cudnnFilterDescriptor_t handle() const { return handle_.get(); }
984 int64 params_size_in_bytes() const { return params_size_in_bytes_; }
985 ParamsRegions params_weights() const {
986 return weights_;
987 }
988 ParamsRegions params_biases() const {
989 return biases_;
990 }
991
992 private:
993 FilterDescriptor handle_;
994 int64 params_size_in_bytes_;
995 ParamsRegions weights_;
996 ParamsRegions biases_;
997 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
998 };
999
1000 } // namespace
1001
1002 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
1003 CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1004 PersistentRnnPlan rnn_plan, int num_layers,
1005 int hidden_size, int input_size, int batch_size,
1006 cudnnRNNInputMode_t input_mode,
1007 cudnnDirectionMode_t direction_mode,
1008 cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1009 cudnnDataType_t compute_type,
1010 const dnn::AlgorithmConfig& algorithm_config,
1011 CudnnDropoutDescriptor dropout_desc,
1012 CudnnRnnParamsDescriptor params_desc)
1013 : rnn_desc_(std::move(rnn_desc)),
1014 rnn_plan_(std::move(rnn_plan)),
1015 num_layers_(num_layers),
1016 hidden_size_(hidden_size),
1017 input_size_(input_size),
1018 batch_size_(batch_size),
1019 rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1020 input_mode_(input_mode),
1021 direction_mode_(direction_mode),
1022 rnn_mode_(rnn_mode),
1023 data_type_(data_type),
1024 compute_type_(compute_type),
1025 algorithm_config_(algorithm_config),
1026 dropout_desc_(std::move(dropout_desc)),
1027 params_desc_(std::move(params_desc)) {}
1028
1029 public:
1030 CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1031
1032 static port::StatusOr<CudnnRnnDescriptor> Create(
1033 const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1034 int batch_size, cudnnRNNInputMode_t input_mode,
1035 cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1036 cudnnDataType_t data_type, cudnnDataType_t compute_type,
1037 const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1038 ScratchAllocator* state_allocator) {
1039 SE_ASSIGN_OR_RETURN(
1040 CudnnDropoutDescriptor dropout_desc,
1041 CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1042
1043 gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1044 cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1045
1046 // TODO: allow the user to choose an algorithm.
1047 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1048 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size,
1049 /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(),
1050 /*inputMode=*/input_mode, /*direction=*/direction_mode,
1051 /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1052 /*dataType=*/compute_type));
1053
1054 // TODO: For now, we only use the cudnnRNN**Ex APIs to process padded inputs.
1055 // If in the future these APIs are used to process full-length arrays, we
1056 // will need to distinguish when to set the padding mode.
1057 #if CUDNN_VERSION >= 7201
1058 RETURN_IF_CUDNN_ERROR(
1059 cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1060 #endif
1061
1062 port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1063 PersistentRnnPlan rnn_plan;
1064 if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1065 CHECK_GE(batch_size, 0);
1066 rnn_plan_wrapper =
1067 CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1068 if (!rnn_plan_wrapper.ok()) {
1069 return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1070 } else {
1071 rnn_plan = rnn_plan_wrapper.ConsumeValueOrDie();
1072 RETURN_IF_CUDNN_ERROR(
1073 cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1074 }
1075 }
1076
1077 // Create the params handle.
1078 SE_ASSIGN_OR_RETURN(auto params_desc,
1079 CudnnRnnParamsDescriptor::Create(
1080 cudnn, input_size, data_type, rnn_desc.get(),
1081 rnn_mode, direction_mode, num_layers));
1082
1083 #if CUDNN_VERSION >= 7000
1084 // Require explicit algorithm config to enable tensor cores. Some configs
1085 // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1086 // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1087 // We can only reasonably expect the user to handle the subsequent failure
1088 // in profile mode, which is run with algorithms returned from
1089 // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1090 // use tensor ops).
1091 if (RnnTensorOpMathEnabled() && algorithm_config.algorithm().has_value()) {
1092 cudnnMathType_t math_type =
1093 algorithm_config.algorithm()->tensor_ops_enabled()
1094 ? CUDNN_TENSOR_OP_MATH
1095 : CUDNN_DEFAULT_MATH;
1096 CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1097 }
1098 #endif
1099
1100 return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1101 num_layers, hidden_size, input_size, batch_size,
1102 input_mode, direction_mode, rnn_mode, data_type,
1103 compute_type, algorithm_config,
1104 std::move(dropout_desc), std::move(params_desc));
1105 }
1106
1107 cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
1108 int num_layers() const { return num_layers_; }
1109 int hidden_size() const { return hidden_size_; }
1110 int input_size() const { return input_size_; }
1111 int batch_size() const { return batch_size_; }
1112 cudnnRNNInputMode_t input_mode() const { return input_mode_; }
1113 cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
1114 cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
1115 cudnnDataType_t data_type() const { return data_type_; }
1116 cudnnDataType_t compute_type() const { return compute_type_; }
1117 const dnn::AlgorithmConfig& algorithm_config() const {
1118 return algorithm_config_;
1119 }
1120 int64 ParamsSizeInBytes() const override {
1121 return params_desc_.params_size_in_bytes();
1122 }
1123 cudnnFilterDescriptor_t params_handle() const {
1124 return params_desc_.handle();
1125 }
1126 ParamsRegions ParamsWeightRegions() const override {
1127 return params_desc_.params_weights();
1128 }
1129 ParamsRegions ParamsBiasRegions() const override {
1130 return params_desc_.params_biases();
1131 }
1132
1133 private:
1134 gpu::RnnDescriptor rnn_desc_;
1135 PersistentRnnPlan rnn_plan_;
1136 int num_layers_;
1137 int hidden_size_;
1138 int input_size_;
1139 // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1140 // algorithm.
1141 int batch_size_;
1142 cudnnRNNAlgo_t rnn_algo_;
1143 cudnnRNNInputMode_t input_mode_;
1144 cudnnDirectionMode_t direction_mode_;
1145 cudnnRNNMode_t rnn_mode_;
1146 cudnnDataType_t data_type_;
1147 cudnnDataType_t compute_type_;
1148 dnn::AlgorithmConfig algorithm_config_;
1149 CudnnDropoutDescriptor dropout_desc_;
1150 CudnnRnnParamsDescriptor params_desc_;
1151 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1152 };
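// A minimal usage sketch (hypothetical argument values): constructing an LSTM
// descriptor mirrors the Create() signature above.
//
//   SE_ASSIGN_OR_RETURN(
//       CudnnRnnDescriptor rnn_desc,
//       CudnnRnnDescriptor::Create(
//           cudnn, /*num_layers=*/2, /*hidden_size=*/128, /*input_size=*/64,
//           /*batch_size=*/32, CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL,
//           CUDNN_LSTM, /*data_type=*/CUDNN_DATA_FLOAT,
//           /*compute_type=*/CUDNN_DATA_FLOAT, dnn::AlgorithmConfig(),
//           /*dropout=*/0.0f, /*seed=*/0, /*state_allocator=*/nullptr));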
1153
1154 namespace {
1155
1156 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1157 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1158 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1159 cudnnDirectionMode_t direction_mode, int num_layers) {
1160 // Query the params size.
1161 TensorDescriptor input_desc = CreateTensorDescriptor();
1162 int tensor_dims[] = {1, input_size, 1};
1163 int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1164 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1165 /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1166 /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1167 /*dimA=*/tensor_dims,
1168 /*strideA=*/strides));
1169
1170 size_t params_size = 0;
1171 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1172 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1173 /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size,
1174 /*dataType=*/data_type));
1175 int64 params_size_in_bytes = static_cast<int64>(params_size);
1176
1177 FilterDescriptor filter_desc = CreateFilterDescriptor();
1178 int filter_dims[] = {static_cast<int>(params_size_in_bytes), 1, 1};
1179 RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1180 /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1181 /*format=*/CUDNN_TENSOR_NCHW,
1182 /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1183 /*filterDimA=*/filter_dims));
1184
1185 // Create the weight and bias regions in the params buffer.
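// Per cuDNN's parameter layout, each gate has two linear layers (one applied
// to the layer input, one to the recurrent hidden state), and each linear
// layer carries a weight matrix plus a bias. Hence 2 regions per layer for
// plain RELU/TANH RNNs, 8 for LSTM (4 gates), and 6 for GRU (3 gates); the
// loop below queries the matrix and the bias for every region.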
1186 int region_count_per_layer = [&] {
1187 switch (rnn_mode) {
1188 case CUDNN_RNN_RELU:
1189 case CUDNN_RNN_TANH:
1190 return 2;
1191 case CUDNN_LSTM:
1192 return 8;
1193 case CUDNN_GRU:
1194 return 6;
1195 default:
1196 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1197 return 0;
1198 }
1199 }();
1200
1201 FilterDescriptor region_desc_handle = CreateFilterDescriptor();
1202 const int layer_count =
1203 direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1204
1205 ParamsRegions weights;
1206 ParamsRegions biases;
1207
1208 for (int layer = 0; layer < layer_count; layer++) {
1209 for (int region = 0; region < region_count_per_layer; region++) {
1210 for (int type = 0; type < 2; type++) {
1211 void* offset = nullptr;
1212 RETURN_IF_CUDNN_ERROR(
1213 type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1214 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1215 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1216 /*wDesc=*/filter_desc.get(),
1217 /*w=*/nullptr, /*linLayerID=*/region,
1218 /*linLayerMatDesc=*/region_desc_handle.get(),
1219 /*linLayerMat or linLayerBias=*/&offset)
1220 : cudnnGetRNNLinLayerBiasParams(
1221 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1222 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1223 /*wDesc=*/filter_desc.get(),
1224 /*w=*/nullptr, /*linLayerID=*/region,
1225 /*linLayerMatDesc=*/region_desc_handle.get(),
1226 /*linLayerMat or linLayerBias=*/&offset));
1227 int dims[] = {1, 1, 1};
1228 cudnnDataType_t data_type;
1229 cudnnTensorFormat_t tensor_format;
1230 int n_dims;
1231 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1232 /*filterDesc=*/region_desc_handle.get(),
1233 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1234 /*dataType=*/&data_type, /*format=*/&tensor_format,
1235 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1236 int64 size =
1237 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1238 dnn::RnnDescriptor::ParamsRegion region = {
1239 reinterpret_cast<int64>(offset), size};
1240 (type == 0 ? weights : biases).push_back(region);
1241 }
1242 }
1243 }
1244
1245 return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1246 weights, biases);
1247 }
1248
1249 } // namespace
1250
1251 class CudnnRnnSequenceTensorDescriptor
1252 : public dnn::RnnSequenceTensorDescriptor {
1253 CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1254 int batch_size, int data_size,
1255 cudnnDataType_t data_type,
1256 #if CUDNN_VERSION >= 7201
1257 RNNDataDescriptor data_handle,
1258 #endif
1259 TensorDescriptor handle)
1260 : max_seq_length_(max_seq_length),
1261 batch_size_(batch_size),
1262 data_size_(data_size),
1263 data_type_(data_type),
1264 handle_(std::move(handle)),
1265 #if CUDNN_VERSION >= 7201
1266 rnn_data_handle_(std::move(data_handle)),
1267 #endif
1268 handles_(max_seq_length, handle_.get()) {
1269 }
1270
1271 public:
1272 CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1273 default;
1274
1275 static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1276 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1277 cudnnDataType_t data_type) {
1278 CHECK_GT(max_seq_length, 0);
1279 int dims[] = {batch_size, data_size, 1};
1280 int strides[] = {dims[1] * dims[2], dims[2], 1};
1281 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1282 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1283 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1284 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1285 /*strideA=*/strides));
1286 return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1287 data_size, data_type,
1288 #if CUDNN_VERSION >= 7201
1289 nullptr,
1290 #endif
1291 std::move(tensor_desc));
1292 }
1293
1294 static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1295 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1296 const absl::Span<const int>& seq_lengths, bool time_major,
1297 cudnnDataType_t data_type) {
1298 #if CUDNN_VERSION >= 7201
1299 CHECK_GT(max_seq_length, 0);
1300 int dims[] = {batch_size, data_size, 1};
1301 int strides[] = {dims[1] * dims[2], dims[2], 1};
1302 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1303 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1304 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1305 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1306 /*strideA=*/strides));
1307 const int* seq_lengths_array = seq_lengths.data();
1308 RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1309 float padding_fill = 0.0f;
1310 cudnnRNNDataLayout_t layout;
1311 if (time_major) {
1312 layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1313 } else {
1314 layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1315 }
1316 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1317 /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1318 /*layout=*/layout,
1319 /*maxSeqLength=*/max_seq_length,
1320 /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1321 /*seqLengthArray=*/seq_lengths_array,
1322 /*paddingFill*/ (void*)&padding_fill));
1323 return CudnnRnnSequenceTensorDescriptor(
1324 parent, max_seq_length, batch_size, data_size, data_type,
1325 std::move(data_desc), std::move(tensor_desc));
1326 #else
1327 return port::Status(port::error::INVALID_ARGUMENT,
1328 "No supported cudnnSetRNNDataDescriptor when "
1329 "CUDNN_VERSION < 7.2.1");
1330 #endif
1331 }
1332
1333 const cudnnTensorDescriptor_t* handles() const {
1334 return handles_.data();
1335 }
1336 #if CUDNN_VERSION >= 7201
1337 const cudnnRNNDataDescriptor_t data_handle() const {
1338 return rnn_data_handle_.get();
1339 }
1340 #endif
1341
1342 int max_seq_length() const { return max_seq_length_; }
1343 int batch_size() const { return batch_size_; }
1344 int data_size() const { return data_size_; }
1345 bool is_var_seq_lengths() const {
1346 #if CUDNN_VERSION >= 7201
1347 return rnn_data_handle_ != nullptr;
1348 #else
1349 return false;
1350 #endif
1351 }
1352
1353 private:
1354 int max_seq_length_;
1355 int batch_size_;
1356 int data_size_;
1357 cudnnDataType_t data_type_;
1358 TensorDescriptor handle_;
1359 #if CUDNN_VERSION >= 7201
1360 RNNDataDescriptor rnn_data_handle_;
1361 #endif
1362 std::vector<cudnnTensorDescriptor_t> handles_; // Copies of handle_.
1363 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1364 };
1365
1366 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1367 public:
1368 CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1369 int batch_size, int data_size,
1370 cudnnDataType_t data_type)
1371 : handle_(CreateTensorDescriptor()),
1372 num_layers_(num_layers),
1373 batch_size_(batch_size),
1374 data_size_(data_size),
1375 data_type_(data_type) {
1376 int dims[] = {num_layers, batch_size, data_size};
1377 int strides[] = {dims[1] * dims[2], dims[2], 1};
1378 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1379 /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1380 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1381 /*strideA=*/strides));
1382 }
1383
1384   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1385
1386   int num_layers() const { return num_layers_; }
1387   int batch_size() const { return batch_size_; }
1388   int data_size() const { return data_size_; }
1389
1390 private:
1391 TensorDescriptor handle_;
1392 int num_layers_;
1393 int batch_size_;
1394 int data_size_;
1395 cudnnDataType_t data_type_;
1396 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1397 };
1398
1399 namespace {
1400
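// Dimensions describing an RNN model, gathered from the RNN and tensor
// descriptors so that all shapes can be cross-checked in one place.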
1401 struct RnnModelDims {
1402 int num_layers = 0;
1403 int batch_size = 0;
1404 int max_seq_length = 0;
1405 int hidden_size = 0;
1406 int input_size = 0;
1407 int dir_count = 0;
1408 };
1409
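// Extracts the model dimensions from the forward-pass descriptors and checks
// that the hidden/cell state and output descriptors are consistent with them.
// Returns INVALID_ARGUMENT if any shape does not match.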
1410 template <class T>
1411 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1412 const CudnnRnnDescriptor& rnn_desc,
1413 const CudnnRnnSequenceTensorDescriptor& input_desc,
1414 const DeviceMemory<T>& input_data,
1415 const CudnnRnnStateTensorDescriptor& input_h_desc,
1416 const DeviceMemory<T>& input_h_data,
1417 const CudnnRnnStateTensorDescriptor& input_c_desc,
1418 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1419 const CudnnRnnSequenceTensorDescriptor& output_desc,
1420 const DeviceMemory<T>& output_data,
1421 const CudnnRnnStateTensorDescriptor& output_h_desc,
1422 const DeviceMemory<T>& output_h_data,
1423 const CudnnRnnStateTensorDescriptor& output_c_desc,
1424 const DeviceMemory<T>& output_c_data) {
1425   // Extract the model dimensions.
1426 RnnModelDims model_dims;
1427 model_dims.num_layers = rnn_desc.num_layers();
1428 model_dims.batch_size = input_desc.batch_size();
1429 model_dims.max_seq_length = input_desc.max_seq_length();
1430 model_dims.hidden_size = rnn_desc.hidden_size();
1431 model_dims.input_size = input_desc.data_size();
1432 model_dims.dir_count =
1433 (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1434
1435   // Check that the descriptor shapes are consistent with the model.
1436 if (!(input_h_desc.num_layers() ==
1437 model_dims.num_layers * model_dims.dir_count &&
1438 input_h_desc.batch_size() == model_dims.batch_size &&
1439 input_h_desc.data_size() == model_dims.hidden_size)) {
1440 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1441 }
1442 if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1443 input_h_desc.batch_size() == input_c_desc.batch_size() &&
1444 input_h_desc.data_size() == input_c_desc.data_size())) {
1445 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1446 }
1447 if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1448 output_desc.batch_size() == model_dims.batch_size &&
1449 output_desc.data_size() ==
1450 model_dims.hidden_size * model_dims.dir_count)) {
1451 return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1452 }
1453 if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1454 input_h_desc.batch_size() == output_h_desc.batch_size() &&
1455 input_h_desc.data_size() == output_h_desc.data_size())) {
1456 return port::Status(port::error::INVALID_ARGUMENT,
1457 "Invalid output_h shape");
1458 }
1459 if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1460 input_h_desc.batch_size() == output_c_desc.batch_size() &&
1461 input_h_desc.data_size() == output_c_desc.data_size())) {
1462 return port::Status(port::error::INVALID_ARGUMENT,
1463 "Invalid output_c shape");
1464 }
1465
1466 return model_dims;
1467 }
1468
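// Checks that the parameter size reported by cudnnGetRNNParamsSize matches
// the size recorded in the CudnnRnnDescriptor.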
1469 port::Status CheckRNNParameterSize(
1470 const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1471 const CudnnRnnSequenceTensorDescriptor& input_desc) {
1472 size_t params_size_in_bytes = 0;
1473 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1474 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1475 /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes,
1476 /*dataType=*/rnn_desc.data_type()));
1477 if (static_cast<int64>(params_size_in_bytes) !=
1478 rnn_desc.ParamsSizeInBytes()) {
1479 return port::Status(port::error::INVALID_ARGUMENT,
1480 "Mismatching RNN parameter size");
1481 }
1482 return port::Status::OK();
1483 }
1484
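// Queries cuDNN for the RNN workspace size and allocates it from
// 'workspace_allocator'. Returns an empty DeviceMemory if no workspace is
// required.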
1485 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1486 Stream* stream, const CudnnHandle& cudnn,
1487 const CudnnRnnDescriptor& rnn_desc,
1488 const CudnnRnnSequenceTensorDescriptor& input_desc,
1489 ScratchAllocator* workspace_allocator) {
1490 // Query the workspace size.
1491 size_t workspace_size_in_bytes = 0;
1492 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1493 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1494 /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1495 /*sizeInBytes=*/&workspace_size_in_bytes));
1496 // Allocate the workspace.
1497 if (workspace_size_in_bytes == 0) {
1498 return DeviceMemory<uint8>();
1499 }
1500 return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
1501 }
1502
1503 } // namespace
1504
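// Runs the cuDNN RNN forward pass. For inference it dispatches to
// cudnnRNNForwardInference(Ex); for training it additionally allocates the
// reserve space and dispatches to cudnnRNNForwardTraining(Ex). The *Ex
// variants are used for variable sequence lengths and need cuDNN >= 7.2.1.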
1505 template <class T>
1506 port::Status CudnnSupport::DoRnnForwardImpl(
1507 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1508 const CudnnRnnSequenceTensorDescriptor& input_desc,
1509 const DeviceMemory<T>& input_data,
1510 const CudnnRnnStateTensorDescriptor& input_h_desc,
1511 const DeviceMemory<T>& input_h_data,
1512 const CudnnRnnStateTensorDescriptor& input_c_desc,
1513 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1514 const CudnnRnnSequenceTensorDescriptor& output_desc,
1515 DeviceMemory<T>* output_data,
1516 const CudnnRnnStateTensorDescriptor& output_h_desc,
1517 DeviceMemory<T>* output_h_data,
1518 const CudnnRnnStateTensorDescriptor& output_c_desc,
1519 DeviceMemory<T>* output_c_data, bool is_training,
1520 ScratchAllocator* reserve_space_allocator,
1521 ScratchAllocator* workspace_allocator,
1522 dnn::ProfileResult* output_profile_result) {
1523 SE_ASSIGN_OR_RETURN(
1524 RnnModelDims model_dims,
1525 ExtractAndCheckRnnForward(
1526 rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1527 input_c_desc, input_c_data, params, output_desc, *output_data,
1528 output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1529
1530 auto cudnn = cudnn_->GetHandle(parent_, stream);
1531
1532 SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1533 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1534 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1535                                            workspace_allocator));
1536
1537   // Query the reserve space size and allocate it; the reserve space is only
1538   // needed for training.
1539 DeviceMemory<uint8> reserve_space;
1540 if (is_training) {
1541 size_t reserve_space_size_in_bytes = 0;
1542 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
1543 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1544 /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
1545 /*sizeInBytes=*/&reserve_space_size_in_bytes));
1546
1547 if (reserve_space_size_in_bytes > 0) {
1548 SE_ASSIGN_OR_RETURN(reserve_space,
1549 reserve_space_allocator->AllocateBytes(
1550 stream, reserve_space_size_in_bytes));
1551 }
1552 }
1553
1554 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1555 const bool is_profiling = output_profile_result != nullptr;
1556 if (is_profiling) {
1557 timer.reset(new GpuTimer(parent_));
1558     // The start and stop of the timer should be as close to the cuDNN call
1559     // as possible. However, other threads may still enqueue work onto this
1560     // stream, so a reliable timing may require multiple profiling runs.
1561 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1562 return port::Status(port::error::INTERNAL, "Failed to start timer");
1563 }
1564 }
1565
1566 if (!is_training) {
1567 if (input_desc.is_var_seq_lengths()) {
1568 #if CUDNN_VERSION >= 7201
1569 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
1570 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1571 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1572 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1573 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1574 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1575 /*yDesc=*/output_desc.data_handle(),
1576 /*y=*/output_data->opaque(),
1577 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1578 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1579 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1580 nullptr,
1581 /*workspace=*/workspace.opaque(),
1582 /*workSpaceSizeInBytes=*/workspace.size()));
1583 #else
1584 return port::Status(port::error::INVALID_ARGUMENT,
1585                           "cudnnRNNForwardInferenceEx is not supported when "
1586                           "CUDNN_VERSION < 7.2.1");
1587 #endif
1588 } else {
1589 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
1590 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1591 /*seqLength=*/model_dims.max_seq_length,
1592 /*xDesc=*/input_desc.handles(),
1593 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1594 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1595 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1596 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1597 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1598 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1599 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1600 /*workSpaceSizeInBytes=*/workspace.size()));
1601 }
1602 } else {
1603 if (input_desc.is_var_seq_lengths()) {
1604 #if CUDNN_VERSION >= 7201
1605 // cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED);
1606 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
1607 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1608 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1609 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1610 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1611 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1612 /*yDesc=*/output_desc.data_handle(),
1613 /*y=*/output_data->opaque(),
1614 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1615 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1616 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1617 nullptr,
1618 /*workspace=*/workspace.opaque(),
1619 /*workSpaceSizeInBytes=*/workspace.size(),
1620 /*reserveSpace=*/reserve_space.opaque(),
1621 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1622 #else
1623 return port::Status(port::error::INVALID_ARGUMENT,
1624                           "cudnnRNNForwardTrainingEx is not supported when "
1625                           "CUDNN_VERSION < 7.2.1");
1626 #endif
1627 } else {
1628 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
1629 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1630 /*seqLength=*/model_dims.max_seq_length,
1631 /*xDesc=*/input_desc.handles(),
1632 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1633 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1634 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1635 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1636 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1637 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1638 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1639 /*workSpaceSizeInBytes=*/workspace.size(),
1640 /*reserveSpace=*/reserve_space.opaque(),
1641 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1642 }
1643 }
1644
1645 if (is_profiling) {
1646 if (!timer->Stop(AsGpuStream(stream))) {
1647 return port::Status(port::error::INTERNAL, "Failed to stop timer");
1648 }
1649 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1650 output_profile_result->set_algorithm(algo_desc);
1651 output_profile_result->set_elapsed_time_in_ms(
1652 timer->GetElapsedMilliseconds());
1653 }
1654
1655 return port::Status::OK();
1656 }
1657
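// Runs the cuDNN RNN backward pass: backward-data first, then (only when
// 'params_backprop_data' is non-null) backward-weights, reusing the reserve
// space produced by the corresponding forward training pass.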
1658 template <class T>
1659 port::Status CudnnSupport::DoRnnBackwardImpl(
1660 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1661 const CudnnRnnSequenceTensorDescriptor& input_desc,
1662 const DeviceMemory<T>& input_data,
1663 const CudnnRnnStateTensorDescriptor& input_h_desc,
1664 const DeviceMemory<T>& input_h_data,
1665 const CudnnRnnStateTensorDescriptor& input_c_desc,
1666 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1667 const CudnnRnnSequenceTensorDescriptor& output_desc,
1668 const DeviceMemory<T>& output_data,
1669 const CudnnRnnStateTensorDescriptor& output_h_desc,
1670 const DeviceMemory<T>& output_h_data,
1671 const CudnnRnnStateTensorDescriptor& output_c_desc,
1672 const DeviceMemory<T>& output_c_data,
1673 const DeviceMemory<T>& output_backprop_data,
1674 const DeviceMemory<T>& output_h_backprop_data,
1675 const DeviceMemory<T>& output_c_backprop_data,
1676 DeviceMemory<T>* input_backprop_data,
1677 DeviceMemory<T>* input_h_backprop_data,
1678 DeviceMemory<T>* input_c_backprop_data,
1679 DeviceMemory<T>* params_backprop_data,
1680 DeviceMemory<uint8>* reserve_space_data,
1681 ScratchAllocator* workspace_allocator,
1682 dnn::ProfileResult* output_profile_result) {
1683 SE_ASSIGN_OR_RETURN(
1684 RnnModelDims model_dims,
1685 ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
1686 input_h_data, input_c_desc, input_c_data,
1687 params, output_desc, output_data, output_h_desc,
1688 output_h_data, output_c_desc, output_c_data));
1689
1690 auto cudnn = cudnn_->GetHandle(parent_, stream);
1691
1692 SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1693 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1694 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1695 workspace_allocator));
1696
1697 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1698 const bool is_profiling = output_profile_result != nullptr;
1699 if (is_profiling) {
1700 timer.reset(new GpuTimer(parent_));
1701     // The start and stop of the timer should be as close to the cuDNN call
1702     // as possible. However, other threads may still enqueue work onto this
1703     // stream, so a reliable timing may require multiple profiling runs.
1704 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1705 return port::Status(port::error::INTERNAL, "Failed to start timer");
1706 }
1707 }
1708
1709 if (input_desc.is_var_seq_lengths()) {
1710 #if CUDNN_VERSION >= 7201
1711 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
1712 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1713 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
1714 /*dyDesc=*/output_desc.data_handle(),
1715 /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
1716 /*dhyDesc=*/output_h_desc.handle(),
1717 /*dhy=*/output_h_backprop_data.opaque(),
1718 /*dcyDesc=*/output_c_desc.handle(),
1719 /*dcy=*/output_c_backprop_data.opaque(),
1720 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1721 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1722 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1723 /*dxDesc=*/input_desc.data_handle(),
1724 /*dx=*/input_backprop_data->opaque(),
1725 /*dhxDesc=*/input_h_desc.handle(),
1726 /*dhx=*/input_h_backprop_data->opaque(),
1727 /*dcxDesc=*/input_c_desc.handle(),
1728 /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
1729 /*workspace=*/workspace.opaque(),
1730 /*workSpaceSizeInBytes=*/workspace.size(),
1731 /*reserveSpace=*/reserve_space_data->opaque(),
1732 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1733 #else
1734 return port::Status(port::error::INVALID_ARGUMENT,
1735                         "cudnnRNNBackwardDataEx is not supported when "
1736                         "CUDNN_VERSION < 7.2.1");
1737 #endif
1738 } else {
1739 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
1740 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1741 /*seqLength=*/model_dims.max_seq_length,
1742 /*yDesc=*/output_desc.handles(),
1743 /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
1744 /*dy=*/output_backprop_data.opaque(),
1745 /*dhyDesc=*/output_h_desc.handle(),
1746 /*dhy=*/output_h_backprop_data.opaque(),
1747 /*dcyDesc=*/output_c_desc.handle(),
1748 /*dcy=*/output_c_backprop_data.opaque(),
1749 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1750 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1751 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1752 /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
1753 /*dhxDesc=*/input_h_desc.handle(),
1754 /*dhx=*/input_h_backprop_data->opaque(),
1755 /*dcxDesc=*/input_c_desc.handle(),
1756 /*dcx=*/input_c_backprop_data->opaque(),
1757 /*workspace=*/workspace.opaque(),
1758 /*workSpaceSizeInBytes=*/workspace.size(),
1759 /*reserveSpace=*/reserve_space_data->opaque(),
1760 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1761 }
1762
1763 if (params_backprop_data != nullptr) {
1764     // Clear dw to zeros, since cuDNN accumulates the weight gradients into it.
1765 stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
1766 if (input_desc.is_var_seq_lengths()) {
1767 #if CUDNN_VERSION >= 7201
1768 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
1769 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1770 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1771 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1772 /*yDesc=*/output_desc.data_handle(),
1773 /*y=*/output_data.opaque(),
1774 /*workspace=*/workspace.opaque(),
1775 /*workSpaceSizeInBytes=*/workspace.size(),
1776 /*dwDesc=*/rnn_desc.params_handle(),
1777 /*dw=*/params_backprop_data->opaque(),
1778 /*reserveSpace=*/reserve_space_data->opaque(),
1779 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1780 #else
1781 return port::Status(port::error::INVALID_ARGUMENT,
1782                           "cudnnRNNBackwardWeightsEx is not supported when "
1783                           "CUDNN_VERSION < 7.2.1");
1784 #endif
1785 } else {
1786       // Make the backward-weights call.
1787 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
1788 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1789 /*seqLength=*/model_dims.max_seq_length,
1790 /*xDesc=*/input_desc.handles(),
1791 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1792 /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
1793 /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
1794 /*workSpaceSizeInBytes=*/workspace.size(),
1795 /*dwDesc=*/rnn_desc.params_handle(),
1796 /*dw=*/params_backprop_data->opaque(),
1797 /*reserveSpace=*/reserve_space_data->opaque(),
1798 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1799 }
1800 }
1801
1802 if (is_profiling) {
1803 if (!timer->Stop(AsGpuStream(stream))) {
1804 return port::Status(port::error::INTERNAL, "Failed to stop timer");
1805 }
1806 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1807 output_profile_result->set_algorithm(algo_desc);
1808 output_profile_result->set_elapsed_time_in_ms(
1809 timer->GetElapsedMilliseconds());
1810 }
1811
1812 return port::Status::OK();
1813 }
1814
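// Creates a cuDNN-backed dnn::RnnDescriptor. As an illustrative sketch (not
// code from this file), a typical caller would do something like:
//
//   auto rnn_desc_or = dnn->createRnnDescriptor(
//       num_layers, hidden_size, input_size, batch_size, input_mode,
//       direction_mode, rnn_mode, data_type, algorithm_config,
//       /*dropout=*/0.0f, /*seed=*/0, state_allocator);
//
// and then pass the resulting descriptor to DoRnnForward/DoRnnBackward.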
1815 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
1816 CudnnSupport::createRnnDescriptor(
1817 int num_layers, int hidden_size, int input_size, int batch_size,
1818 dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
1819 dnn::RnnMode rnn_mode, dnn::DataType data_type,
1820 const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1821 ScratchAllocator* state_allocator) {
1822 // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
1823 // not enqueueing anything into a stream, we pass in the null stream.
1824 auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
1825 SE_ASSIGN_OR_RETURN(
1826 CudnnRnnDescriptor rnn_desc,
1827 CudnnRnnDescriptor::Create(
1828 cudnn, num_layers, hidden_size, input_size, batch_size,
1829 ToCudnnRnnInputMode(input_mode),
1830 ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
1831 ToCudnnDataType(data_type), GetRnnComputeType(data_type),
1832 algorithm_config, dropout, seed, state_allocator));
1833 return std::unique_ptr<dnn::RnnDescriptor>(
1834 new CudnnRnnDescriptor(std::move(rnn_desc)));
1835 }
1836
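// Creates a sequence tensor descriptor for batches in which every sequence
// has the same length 'max_seq_length'.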
1837 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
1838 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
1839 int batch_size, int data_size,
1840 dnn::DataType data_type) {
1841 SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
1842 CudnnRnnSequenceTensorDescriptor::Create(
1843 parent_, max_seq_length, batch_size, data_size,
1844 ToCudnnDataType(data_type)));
1845 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
1846 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
1847 }
1848
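// Creates a sequence tensor descriptor for variable-length sequences; this
// requires cuDNN >= 7.2.1 (see CudnnRnnSequenceTensorDescriptor::Create).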
1849 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
1850 CudnnSupport::createRnnSequenceTensorDescriptor(
1851 int max_seq_length, int batch_size, int data_size,
1852 const absl::Span<const int>& seq_lengths, bool time_major,
1853 dnn::DataType data_type) {
1854 SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
1855 CudnnRnnSequenceTensorDescriptor::Create(
1856 parent_, max_seq_length, batch_size, data_size,
1857 seq_lengths, time_major, ToCudnnDataType(data_type)));
1858 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
1859 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
1860 }
1861
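// Creates a descriptor for the RNN hidden/cell state tensors.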
1862 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
1863 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
1864 int data_size,
1865 dnn::DataType data_type) {
1866 return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
1867 new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
1868 data_size, ToCudnnDataType(data_type)));
1869 }
1870
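// The DoRnnForward/DoRnnBackward overloads below (Eigen::half, float, double)
// downcast the dnn:: descriptor interfaces to their cuDNN implementations and
// forward to the templated DoRnnForwardImpl/DoRnnBackwardImpl helpers above.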
1871 bool CudnnSupport::DoRnnForward(
1872 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
1873 const dnn::RnnSequenceTensorDescriptor& input_desc,
1874 const DeviceMemory<Eigen::half>& input_data,
1875 const dnn::RnnStateTensorDescriptor& input_h_desc,
1876 const DeviceMemory<Eigen::half>& input_h_data,
1877 const dnn::RnnStateTensorDescriptor& input_c_desc,
1878 const DeviceMemory<Eigen::half>& input_c_data,
1879 const DeviceMemory<Eigen::half>& params,
1880 const dnn::RnnSequenceTensorDescriptor& output_desc,
1881 DeviceMemory<Eigen::half>* output_data,
1882 const dnn::RnnStateTensorDescriptor& output_h_desc,
1883 DeviceMemory<Eigen::half>* output_h_data,
1884 const dnn::RnnStateTensorDescriptor& output_c_desc,
1885 DeviceMemory<Eigen::half>* output_c_data, bool is_training,
1886 ScratchAllocator* reserve_space_allocator,
1887 ScratchAllocator* workspace_allocator,
1888 dnn::ProfileResult* output_profile_result) {
1889 const CudnnRnnDescriptor& cudnn_rnn_desc =
1890 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
1891 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
1892 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
1893 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
1894 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
1895 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
1896 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
1897 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
1898 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
1899 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
1900 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
1901 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
1902 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
1903 return IsStatusOk(
1904 DoRnnForwardImpl<Eigen::half>(
1905 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
1906 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
1907 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
1908 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
1909 reserve_space_allocator, workspace_allocator, output_profile_result),
1910 /*report_error=*/!output_profile_result);
1911 }
1912
1913 bool CudnnSupport::DoRnnForward(
1914 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
1915 const dnn::RnnSequenceTensorDescriptor& input_desc,
1916 const DeviceMemory<float>& input_data,
1917 const dnn::RnnStateTensorDescriptor& input_h_desc,
1918 const DeviceMemory<float>& input_h_data,
1919 const dnn::RnnStateTensorDescriptor& input_c_desc,
1920 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
1921 const dnn::RnnSequenceTensorDescriptor& output_desc,
1922 DeviceMemory<float>* output_data,
1923 const dnn::RnnStateTensorDescriptor& output_h_desc,
1924 DeviceMemory<float>* output_h_data,
1925 const dnn::RnnStateTensorDescriptor& output_c_desc,
1926 DeviceMemory<float>* output_c_data, bool is_training,
1927 ScratchAllocator* reserve_space_allocator,
1928 ScratchAllocator* workspace_allocator,
1929 dnn::ProfileResult* output_profile_result) {
1930 const CudnnRnnDescriptor& cudnn_rnn_desc =
1931 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
1932 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
1933 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
1934 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
1935 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
1936 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
1937 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
1938 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
1939 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
1940 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
1941 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
1942 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
1943 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
1944 return IsStatusOk(
1945 DoRnnForwardImpl<float>(
1946 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
1947 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
1948 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
1949 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
1950 reserve_space_allocator, workspace_allocator, output_profile_result),
1951 /*report_error=*/!output_profile_result);
1952 }
1953
1954 bool CudnnSupport::DoRnnForward(
1955 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
1956 const dnn::RnnSequenceTensorDescriptor& input_desc,
1957 const DeviceMemory<double>& input_data,
1958 const dnn::RnnStateTensorDescriptor& input_h_desc,
1959 const DeviceMemory<double>& input_h_data,
1960 const dnn::RnnStateTensorDescriptor& input_c_desc,
1961 const DeviceMemory<double>& input_c_data,
1962 const DeviceMemory<double>& params,
1963 const dnn::RnnSequenceTensorDescriptor& output_desc,
1964 DeviceMemory<double>* output_data,
1965 const dnn::RnnStateTensorDescriptor& output_h_desc,
1966 DeviceMemory<double>* output_h_data,
1967 const dnn::RnnStateTensorDescriptor& output_c_desc,
1968 DeviceMemory<double>* output_c_data, bool is_training,
1969 ScratchAllocator* reserve_space_allocator,
1970 ScratchAllocator* workspace_allocator,
1971 dnn::ProfileResult* output_profile_result) {
1972 const CudnnRnnDescriptor& cudnn_rnn_desc =
1973 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
1974 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
1975 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
1976 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
1977 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
1978 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
1979 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
1980 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
1981 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
1982 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
1983 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
1984 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
1985 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
1986 return IsStatusOk(
1987 DoRnnForwardImpl<double>(
1988 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
1989 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
1990 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
1991 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
1992 reserve_space_allocator, workspace_allocator, output_profile_result),
1993 /*report_error=*/!output_profile_result);
1994 }
1995
1996 bool CudnnSupport::DoRnnBackward(
1997 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
1998 const dnn::RnnSequenceTensorDescriptor& input_desc,
1999 const DeviceMemory<Eigen::half>& input_data,
2000 const dnn::RnnStateTensorDescriptor& input_h_desc,
2001 const DeviceMemory<Eigen::half>& input_h_data,
2002 const dnn::RnnStateTensorDescriptor& input_c_desc,
2003 const DeviceMemory<Eigen::half>& input_c_data,
2004 const DeviceMemory<Eigen::half>& params,
2005 const dnn::RnnSequenceTensorDescriptor& output_desc,
2006 const DeviceMemory<Eigen::half>& output_data,
2007 const dnn::RnnStateTensorDescriptor& output_h_desc,
2008 const DeviceMemory<Eigen::half>& output_h_data,
2009 const dnn::RnnStateTensorDescriptor& output_c_desc,
2010 const DeviceMemory<Eigen::half>& output_c_data,
2011 const DeviceMemory<Eigen::half>& output_backprop_data,
2012 const DeviceMemory<Eigen::half>& output_h_backprop_data,
2013 const DeviceMemory<Eigen::half>& output_c_backprop_data,
2014 DeviceMemory<Eigen::half>* input_backprop_data,
2015 DeviceMemory<Eigen::half>* input_h_backprop_data,
2016 DeviceMemory<Eigen::half>* input_c_backprop_data,
2017 DeviceMemory<Eigen::half>* params_backprop_data,
2018 DeviceMemory<uint8>* reserve_space_data,
2019 ScratchAllocator* workspace_allocator,
2020 dnn::ProfileResult* output_profile_result) {
2021 const CudnnRnnDescriptor& cudnn_rnn_desc =
2022 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2023 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2024 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2025 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2026 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2027 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2028 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2029 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2030 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2031 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2032 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2033 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2034 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2035 return IsStatusOk(
2036 DoRnnBackwardImpl<Eigen::half>(
2037 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2038 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2039 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2040 output_h_data, cudnn_output_c_desc, output_c_data,
2041 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2042 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2043 params_backprop_data, reserve_space_data, workspace_allocator,
2044 output_profile_result),
2045 /*report_error=*/!output_profile_result);
2046 }
2047
2048 bool CudnnSupport::DoRnnBackward(
2049 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2050 const dnn::RnnSequenceTensorDescriptor& input_desc,
2051 const DeviceMemory<float>& input_data,
2052 const dnn::RnnStateTensorDescriptor& input_h_desc,
2053 const DeviceMemory<float>& input_h_data,
2054 const dnn::RnnStateTensorDescriptor& input_c_desc,
2055 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2056 const dnn::RnnSequenceTensorDescriptor& output_desc,
2057 const DeviceMemory<float>& output_data,
2058 const dnn::RnnStateTensorDescriptor& output_h_desc,
2059 const DeviceMemory<float>& output_h_data,
2060 const dnn::RnnStateTensorDescriptor& output_c_desc,
2061 const DeviceMemory<float>& output_c_data,
2062 const DeviceMemory<float>& output_backprop_data,
2063 const DeviceMemory<float>& output_h_backprop_data,
2064 const DeviceMemory<float>& output_c_backprop_data,
2065 DeviceMemory<float>* input_backprop_data,
2066 DeviceMemory<float>* input_h_backprop_data,
2067 DeviceMemory<float>* input_c_backprop_data,
2068 DeviceMemory<float>* params_backprop_data,
2069 DeviceMemory<uint8>* reserve_space_data,
2070 ScratchAllocator* workspace_allocator,
2071 dnn::ProfileResult* output_profile_result) {
2072 const CudnnRnnDescriptor& cudnn_rnn_desc =
2073 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2074 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2075 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2076 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2077 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2078 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2079 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2080 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2081 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2082 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2083 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2084 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2085 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2086 return IsStatusOk(
2087 DoRnnBackwardImpl<float>(
2088 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2089 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2090 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2091 output_h_data, cudnn_output_c_desc, output_c_data,
2092 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2093 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2094 params_backprop_data, reserve_space_data, workspace_allocator,
2095 output_profile_result),
2096 /*report_error=*/!output_profile_result);
2097 }
2098
2099 bool CudnnSupport::DoRnnBackward(
2100 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2101 const dnn::RnnSequenceTensorDescriptor& input_desc,
2102 const DeviceMemory<double>& input_data,
2103 const dnn::RnnStateTensorDescriptor& input_h_desc,
2104 const DeviceMemory<double>& input_h_data,
2105 const dnn::RnnStateTensorDescriptor& input_c_desc,
2106 const DeviceMemory<double>& input_c_data,
2107 const DeviceMemory<double>& params,
2108 const dnn::RnnSequenceTensorDescriptor& output_desc,
2109 const DeviceMemory<double>& output_data,
2110 const dnn::RnnStateTensorDescriptor& output_h_desc,
2111 const DeviceMemory<double>& output_h_data,
2112 const dnn::RnnStateTensorDescriptor& output_c_desc,
2113 const DeviceMemory<double>& output_c_data,
2114 const DeviceMemory<double>& output_backprop_data,
2115 const DeviceMemory<double>& output_h_backprop_data,
2116 const DeviceMemory<double>& output_c_backprop_data,
2117 DeviceMemory<double>* input_backprop_data,
2118 DeviceMemory<double>* input_h_backprop_data,
2119 DeviceMemory<double>* input_c_backprop_data,
2120 DeviceMemory<double>* params_backprop_data,
2121 DeviceMemory<uint8>* reserve_space_data,
2122 ScratchAllocator* workspace_allocator,
2123 dnn::ProfileResult* output_profile_result) {
2124 const CudnnRnnDescriptor& cudnn_rnn_desc =
2125 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2126 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2127 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2128 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2129 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2130 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2131 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2132 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2133 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2134 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2135 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2136 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2137 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2138 return IsStatusOk(
2139 DoRnnBackwardImpl<double>(
2140 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2141 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2142 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2143 output_h_data, cudnn_output_c_desc, output_c_data,
2144 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2145 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2146 params_backprop_data, reserve_space_data, workspace_allocator,
2147 output_profile_result),
2148 /*report_error=*/!output_profile_result);
2149 }
2150
2151 namespace {
2152
2153 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2154 // and backward filter.
2155
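// The three helpers below ask cuDNN to pick a convolution algorithm (forward,
// backward-data, backward-filter) using its built-in heuristics, either with
// no workspace at all or restricted to algorithms whose workspace fits within
// 'memory_limit_bytes'.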
2156 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2157 const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2158 const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2159 const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2160 size_t memory_limit_bytes) {
2161 cudnnConvolutionFwdPreference_t preference =
2162 specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2163 : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2164 cudnnConvolutionFwdAlgo_t algo_to_use;
2165 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2166 cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2167 output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2168 return algo_to_use;
2169 }
2170
2171 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2172 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2173 const CudnnTensorDescriptor& input_nd,
2174 const CudnnFilterDescriptor& filter,
2175 const CudnnConvolutionDescriptor& conv,
2176 const CudnnTensorDescriptor& output_nd,
2177 bool specify_workspace_limit,
2178 size_t memory_limit_bytes) {
2179 cudnnConvolutionBwdDataPreference_t preference =
2180 specify_workspace_limit
2181 ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2182 : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2183 cudnnConvolutionBwdDataAlgo_t algo_to_use;
2184 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2185 cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2186 input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2187 return algo_to_use;
2188 }
2189
2190 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2191 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2192 const CudnnTensorDescriptor& input_nd,
2193 const CudnnFilterDescriptor& filter,
2194 const CudnnConvolutionDescriptor& conv,
2195 const CudnnTensorDescriptor& output_nd,
2196 bool specify_workspace_limit,
2197 size_t memory_limit_bytes) {
2198 cudnnConvolutionBwdFilterPreference_t preference =
2199 specify_workspace_limit
2200 ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2201 : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2202 cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2203 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2204 cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2205 filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2206 return algo_to_use;
2207 }
2208
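// The three Allocate*Workspace helpers below query the workspace size required
// by the chosen convolution algorithm and allocate it from 'scratch_allocator'.
// A zero-sized workspace yields an empty DeviceMemory; a missing allocator is
// only an error when a workspace is actually needed.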
2209 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2210 Stream* stream, const CudnnHandle& cudnn,
2211 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2212 const CudnnConvolutionDescriptor& conv,
2213 const CudnnTensorDescriptor& output_nd,
2214 const dnn::AlgorithmDesc& algorithm_desc,
2215 ScratchAllocator* scratch_allocator) {
2216 // TODO(csigg): This has side effects on the convolution descriptor. It is
2217 // functionally correct because the convolution is run with the algorithm of
2218 // the last call to this function, but should be fixed anyway.
2219 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2220
2221 // Query the size of the workspace and allocate it.
2222 size_t size_in_bytes;
2223 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2224 cudnn.handle(),
2225 /*xDesc=*/input_nd.handle(),
2226 /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2227 /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
2228 /*sizeInBytes=*/&size_in_bytes));
2229
2230 int64 size_in_bytes_int64 = size_in_bytes;
2231
2232 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2233 return port::Status(
2234 port::error::INTERNAL,
2235 "cudnnGetConvolutionForwardWorkspaceSize() returned "
2236 "negative sizeInBytes value. This could be a cudnn bug.");
2237 }
2238
2239 if (size_in_bytes_int64 == 0) {
2240 return DeviceMemory<uint8>();
2241 }
2242
2243 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2244 return port::Status(port::error::INVALID_ARGUMENT,
2245 "No scratch allocator provided");
2246 }
2247
2248 return scratch_allocator->AllocateBytes(stream, size_in_bytes);
2249 }
2250
2251 port::StatusOr<DeviceMemory<uint8>>
2252 AllocateCudnnConvolutionBackwardDataWorkspace(
2253 Stream* stream, const CudnnHandle& cudnn,
2254 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2255 const CudnnConvolutionDescriptor& conv,
2256 const CudnnTensorDescriptor& output_nd,
2257 const dnn::AlgorithmDesc& algorithm_desc,
2258 ScratchAllocator* scratch_allocator) {
2259 // TODO(csigg): This has side effects on the convolution descriptor. It is
2260 // functionally correct because the convolution is run with the algorithm of
2261 // the last call to this function, but should be fixed anyway.
2262 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2263
2264 // Query the size of the workspace and allocate it.
2265 size_t size_in_bytes;
2266 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2267 cudnn.handle(),
2268 /*wDesc=*/filter.handle(),
2269 /*dyDesc=*/output_nd.handle(),
2270 /*convDesc=*/conv.handle(),
2271 /*dxDesc=*/input_nd.handle(),
2272 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2273 /*sizeInBytes=*/&size_in_bytes));
2274
2275 int64 size_in_bytes_int64 = size_in_bytes;
2276
2277 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2278 return port::Status(
2279 port::error::INTERNAL,
2280 "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2281 "negative sizeInBytes value. This could be a cudnn bug.");
2282 }
2283
2284 if (size_in_bytes_int64 == 0) {
2285 return DeviceMemory<uint8>();
2286 }
2287
2288 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2289 return port::Status(port::error::INVALID_ARGUMENT,
2290 "No scratch allocator provided");
2291 }
2292
2293 return scratch_allocator->AllocateBytes(stream, size_in_bytes);
2294 }
2295
2296 port::StatusOr<DeviceMemory<uint8>>
2297 AllocateCudnnConvolutionBackwardFilterWorkspace(
2298 Stream* stream, const CudnnHandle& cudnn,
2299 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2300 const CudnnConvolutionDescriptor& conv,
2301 const CudnnTensorDescriptor& output_nd,
2302 const dnn::AlgorithmDesc& algorithm_desc,
2303 ScratchAllocator* scratch_allocator) {
2304 // TODO(csigg): This has side effects on the convolution descriptor. It is
2305 // functionally correct because the convolution is run with the algorithm of
2306 // the last call to this function, but should be fixed anyway.
2307 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
2308
2309 // Query the size of the workspace and allocate it.
2310 size_t size_in_bytes;
2311 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
2312 cudnn.handle(),
2313 /*xDesc=*/input_nd.handle(),
2314 /*dyDesc=*/output_nd.handle(),
2315 /*convDesc=*/conv.handle(),
2316 /*gradDesc=*/filter.handle(),
2317 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2318 /*sizeInBytes=*/&size_in_bytes));
2319
2320 int64 size_in_bytes_int64 = size_in_bytes;
2321
2322 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2323 return port::Status(
2324 port::error::INTERNAL,
2325 "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
2326 "negative sizeInBytes value. This could be a cudnn bug.");
2327 }
2328
2329 if (size_in_bytes_int64 == 0) {
2330 return DeviceMemory<uint8>();
2331 }
2332
2333 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2334 return port::Status(port::error::INVALID_ARGUMENT,
2335 "No scratch allocator provided");
2336 }
2337
2338 return scratch_allocator->AllocateBytes(stream, size_in_bytes);
2339 }
2340
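// Resolves which forward-convolution algorithm to use: either the one recorded
// in 'algorithm_config', or the one suggested by cuDNN's heuristics. If the
// workspace allocation for that algorithm fails, falls back to the config's
// no-scratch algorithm.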
2341 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
2342 Stream* stream, const CudnnHandle& cudnn,
2343 const dnn::AlgorithmConfig& algorithm_config,
2344 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2345 const CudnnConvolutionDescriptor& conv,
2346 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2347 DeviceMemory<uint8>* scratch) {
2348 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2349 if (!algo_desc.has_value()) {
2350 // Pick fastest algorithm within memory limit according to cuDNN's
2351 // heuristics.
2352 bool specify_workspace_limit = scratch_allocator != nullptr;
2353 auto memory_limit_bytes =
2354 specify_workspace_limit
2355 ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
2356 : 0ll;
2357 SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
2358 GetCudnnConvolutionForwardAlgo(
2359 cudnn, input_nd, filter, conv, output_nd,
2360 specify_workspace_limit, memory_limit_bytes));
2361 algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true);
2362 }
2363
2364 const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
2365 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2366 scratch_allocator);
2367
2368 if (scratch_or.ok()) {
2369 *scratch = scratch_or.ValueOrDie();
2370 return *algo_desc;
2371 }
2372
2373 algo_desc = algorithm_config.algorithm_no_scratch();
2374
2375   // Failed to allocate workspace for the primary algorithm; fall back to the
2376   // no_scratch algorithm.
2377 if (!algo_desc.has_value()) {
2378 return port::Status(
2379 port::error::INVALID_ARGUMENT,
2380         "The primary convolution algorithm failed to allocate its workspace, "
2381         "and no secondary algorithm was provided.");
2382 }
2383
2384 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
2385 stream, cudnn, input_nd, filter, conv,
2386 output_nd, *algo_desc, scratch_allocator));
2387 return *algo_desc;
2388 }
2389
2390 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
2391 Stream* stream, const CudnnHandle& cudnn,
2392 const dnn::AlgorithmConfig& algorithm_config,
2393 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2394 const CudnnConvolutionDescriptor& conv,
2395 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2396 DeviceMemory<uint8>* scratch) {
2397 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2398 if (!algo_desc.has_value()) {
2399 // Pick fastest algorithm within memory limit according to cuDNN's
2400 // heuristics.
2401 bool specify_workspace_limit = scratch_allocator != nullptr;
2402 auto memory_limit_bytes =
2403 specify_workspace_limit
2404 ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
2405 : 0ll;
2406 SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
2407 GetCudnnConvolutionBackwardDataAlgo(
2408 cudnn, input_nd, filter, conv, output_nd,
2409 specify_workspace_limit, memory_limit_bytes));
2410 algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true);
2411 }
2412
2413 const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
2414 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2415 scratch_allocator);
2416
2417 if (scratch_or.ok()) {
2418 *scratch = scratch_or.ValueOrDie();
2419 return *algo_desc;
2420 }
2421
2422 algo_desc = algorithm_config.algorithm_no_scratch();
2423
2424   // Failed to allocate workspace for the primary algorithm; fall back to the
2425   // no_scratch algorithm.
2426 if (!algo_desc.has_value()) {
2427 return port::Status(
2428 port::error::INVALID_ARGUMENT,
2429         "The primary convolution algorithm failed to allocate its workspace, "
2430         "and no secondary algorithm was provided.");
2431 }
2432
2433 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
2434 stream, cudnn, input_nd, filter, conv,
2435 output_nd, *algo_desc, scratch_allocator));
2436 return *algo_desc;
2437 }
2438
2439 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
2440 Stream* stream, const CudnnHandle& cudnn,
2441 const dnn::AlgorithmConfig& algorithm_config,
2442 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2443 const CudnnConvolutionDescriptor& conv,
2444 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2445 DeviceMemory<uint8>* scratch) {
2446 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2447 if (!algo_desc.has_value()) {
2448 // Pick fastest algorithm within memory limit according to cuDNN's
2449 // heuristics.
2450 bool specify_workspace_limit = scratch_allocator != nullptr;
2451 auto memory_limit_bytes =
2452 specify_workspace_limit
2453 ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
2454 : 0ll;
2455 SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
2456 GetCudnnConvolutionBackwardFilterAlgo(
2457 cudnn, input_nd, filter, conv, output_nd,
2458 specify_workspace_limit, memory_limit_bytes));
2459 algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true);
2460 }
2461
2462 auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
2463 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2464 scratch_allocator);
2465
2466 if (scratch_or.ok()) {
2467 *scratch = scratch_or.ValueOrDie();
2468 return *algo_desc;
2469 }
2470
2471 algo_desc = algorithm_config.algorithm_no_scratch();
2472
2473   // Failed to allocate workspace for the primary algorithm; fall back to the
2474   // no_scratch algorithm.
2475 if (!algo_desc.has_value()) {
2476 return port::Status(
2477 port::error::INVALID_ARGUMENT,
2478         "The primary convolution algorithm failed to allocate its workspace, "
2479         "and no secondary algorithm was provided.");
2480 }
2481
2482 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
2483 stream, cudnn, input_nd, filter, conv,
2484 output_nd, *algo_desc, scratch_allocator));
2485 return *algo_desc;
2486 }
2487
2488 // A helper class that reads an env-var to decide whether certain cudnn-related
2489 // algorithm options are enabled.
2490 template <typename EnvVar>
2491 class CudnnEnvVar {
2492 public:
2493   static bool IsEnabled() {
2494 static bool is_enabled = IsEnabledImpl();
2495 return is_enabled;
2496 }
2497
2498 private:
2499   static bool IsEnabledImpl() {
2500 const char* tf_env_var_val = getenv(EnvVar::kName);
2501 if (tf_env_var_val != nullptr) {
2502 absl::string_view tf_env_var_val_str(tf_env_var_val);
2503 if (tf_env_var_val_str == "0") {
2504 return false;
2505 }
2506 return true;
2507 }
2508 return EnvVar::kDefaultFlag;
2509 }
2510 };
2511
2512 // A helper struct to decide whether to enable the FFT_TILING algorithms for
2513 // forward convolution. It is disabled for cuDNN < 7 due to memory corruption
2514 // caused by some shapes with this algorithm. Users can explicitly enable the
2515 // algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
2516 struct FftTilingForward {
2517 static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
2518 static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000;
2519 };
2520
2521 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
2522 // By default it is turned on; users can explicitly disable it through the
2523 // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
2524 // https://github.com/tensorflow/tensorflow/pull/4901
2525 struct WinogradNonfused {
2526 static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
2527   // NVIDIA fixed the winograd-nonfused bug in cuDNN v7 and later. For older
2528   // versions, we have a workaround.
2529 static constexpr bool kDefaultFlag = true;
2530 };
2531
2532 // A helper struct to decide whether to use FP32 as the internal compute type
2533 // for convolution when the input data type is FP16. By default it is turned on;
2534 // users can explicitly disable it (i.e. choose to use FP16 as the internal
2535 // compute type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
2536 struct ConvDoFP32ComputationFP16Input {
2537 static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
2538 // Using FP16 as the internal compute type for convolution when the input data
2539 // type is FP16 is only supported on architectures with true fp16 support
2540   // (compute capability 5.3 and 6.0). Setting this to false on an unsupported
2541 // architecture will cause internal errors.
2542 static constexpr bool kDefaultFlag = true;
2543 };
2544
2545 // A helper struct to decide whether to use FP32 as the internal compute type
2546 // for RNNs when the input data type is FP16. At present it is turned off;
2547 // users can explicitly control it through the env-var
2548 // TF_FP16_RNN_USE_FP32_COMPUTE.
2549 // After the TODO below is fixed, users should almost always use the fp32
2550 // compute type for training. Using fp16 may suffer from suboptimal accuracy
2551 // due to loss of precision.
2552 struct RnnDoFP32ComputationFP16Input {
2553 static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
2554 // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
2555   // Before cuDNN 7.1.4, RNNs are always computed in fp32, no matter what math
2556   // precision is set.
2557   // Set it temporarily to false so that no error is raised when using fp16
2558   // inputs with fp32 math precision.
2559 static constexpr bool kDefaultFlag = false;
2560 };
2561
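// Returns the cuDNN compute type to use for RNN operations with the given
// input data type. For fp16 inputs, fp32 compute is used only when
// RnnDoFP32ComputationFP16Input (TF_FP16_RNN_USE_FP32_COMPUTE) is enabled.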
2562 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
2563 switch (data_type) {
2564 case dnn::DataType::kFloat:
2565 return CUDNN_DATA_FLOAT;
2566 case dnn::DataType::kDouble:
2567 return CUDNN_DATA_DOUBLE;
2568 case dnn::DataType::kHalf:
2569 if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
2570 return CUDNN_DATA_FLOAT;
2571 } else {
2572 return CUDNN_DATA_HALF;
2573 }
2574 default:
2575 LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
2576 }
2577 }
2578
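// Returns the accumulator (compute) data type to use for a convolution whose
// input/output element type is 'data_type': fp16 inputs accumulate in fp32
// unless ConvDoFP32ComputationFP16Input is disabled, and int8 inputs
// accumulate in int32.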
2579 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
2580 switch (data_type) {
2581 case dnn::DataType::kFloat:
2582 case dnn::DataType::kDouble:
2583 return data_type;
2584 case dnn::DataType::kHalf:
2585 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
2586 ? dnn::DataType::kFloat
2587 : dnn::DataType::kHalf;
2588 case dnn::DataType::kInt8:
2589 case dnn::DataType::kInt32:
2590 return dnn::DataType::kInt32;
2591 default:
2592 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
2593 }
2594 }
2595
2596 // Determines whether we can safely perform a winograd non-fused convolution for
2597 // the given input and output shapes. This works around b/68264959, an integer
2598 // overflow in cuDNNv5 and cuDNNv6.
2599 #if CUDNN_VERSION >= 7000
2600 bool ShouldIncludeWinogradNonfusedAlgo(const dnn::BatchDescriptor&,
2601 const dnn::BatchDescriptor&) {
2602 return true;
2603 }
2604 #else
2605 bool ShouldIncludeWinogradNonfusedAlgo(
2606 const dnn::BatchDescriptor& input_desc,
2607 const dnn::BatchDescriptor& output_desc) {
2608 int64 batch = input_desc.count();
2609 int64 in_depths = input_desc.feature_map_count();
2610 int64 in_rows = input_desc.height();
2611 int64 in_cols = input_desc.ndims() == 1 ? 1 : input_desc.width();
2612 int64 out_depths = output_desc.feature_map_count();
2613
2614 int64 total_size = port::MathUtil::CeilOfRatio(batch, int64{16}) *
2615 std::max(in_depths, out_depths) * in_cols * in_rows *
2616 sizeof(float);
2617
2618 const int64 threshold = 1L << 31;
2619 return total_size < threshold;
2620 }
2621 #endif
2622
2623 } // namespace
2624
2625 port::Status CudnnSupport::DoPrepareForConvolution(
2626 dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2627 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2628 const dnn::FilterDescriptor& filter_descriptor,
2629 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2630 DeviceMemoryBase output_data,
2631 const dnn::ConvolutionDescriptor& convolution_descriptor,
2632 const dnn::AlgorithmConfig& algorithm_config,
2633 ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2634 DeviceMemory<uint8>* scratch_memory) {
2635 CudnnTensorDescriptor input_nd(
2636 input_descriptor,
2637 ToCudnnDataType(element_type, input_descriptor.layout()));
2638 CudnnFilterDescriptor filter_nd(
2639 filter_descriptor,
2640 ToCudnnDataType(element_type, filter_descriptor.layout()));
2641 CudnnTensorDescriptor output_nd(
2642 output_descriptor,
2643 ToCudnnDataType(element_type, output_descriptor.layout()));
2644 CudnnConvolutionDescriptor conv(
2645 convolution_descriptor,
2646 ToCudnnDataType(GetConvAccumulatorType(element_type)));
2647
2648 auto cudnn = cudnn_->GetHandle(parent_, stream);
2649
2650 switch (kind) {
2651 case dnn::ConvolutionKind::FORWARD: {
2652 SE_ASSIGN_OR_RETURN(
2653 *algorithm_desc,
2654 GetCudnnConvolutionForwardAlgorithm(
2655 stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2656 output_nd, scratch_allocator, scratch_memory));
2657 break;
2658 }
2659 case dnn::ConvolutionKind::BACKWARD_DATA: {
2660 SE_ASSIGN_OR_RETURN(
2661 *algorithm_desc,
2662 GetCudnnConvolutionBackwardDataAlgorithm(
2663 stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2664 output_nd, scratch_allocator, scratch_memory));
2665 break;
2666 }
2667 case dnn::ConvolutionKind::BACKWARD_FILTER: {
2668 SE_ASSIGN_OR_RETURN(
2669 *algorithm_desc,
2670 GetCudnnConvolutionBackwardFilterAlgorithm(
2671 stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2672 output_nd, scratch_allocator, scratch_memory));
2673 break;
2674 }
2675 default:
2676 return port::InternalError(
2677 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
2678 }
2679
2680 return port::Status::OK();
2681 }
2682
2683 port::Status CudnnSupport::DoConvolve(
2684 dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2685 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2686 const dnn::FilterDescriptor& filter_descriptor,
2687 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2688 DeviceMemoryBase output_data,
2689 const dnn::ConvolutionDescriptor& convolution_descriptor,
2690 dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
2691 dnn::ProfileResult* output_profile_result) {
2692 cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
2693 CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
2694 CudnnTensorDescriptor output_nd(output_descriptor, cudnn_type);
2695 CudnnFilterDescriptor filter_nd(filter_descriptor, cudnn_type);
2696 auto accumulator_type = GetConvAccumulatorType(element_type);
2697 CudnnConvolutionDescriptor conv(convolution_descriptor,
2698 ToCudnnDataType(accumulator_type));
2699
2700 auto cudnn = cudnn_->GetHandle(parent_, stream);
2701 // Alpha is the scaling factor for input.
2702 float falpha = 1.0;
2703 double dalpha = 1.0;
2704 void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
2705 : static_cast<void*>(&falpha);
2706 // Beta is the scaling factor for output.
2707 float fbeta = 0.0;
2708 double dbeta = 0.0;
2709 void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
2710 : static_cast<void*>(&fbeta);
2711
2712 const bool is_profiling = output_profile_result != nullptr;
2713
2714 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2715 if (is_profiling) {
2716 timer.reset(new GpuTimer(parent_)); // NOLINT
2717     // The start and stop of the timer should be as close to the cuDNN call as
2718     // possible. Other threads may still issue work onto this stream, so an
2719     // accurate measurement may require multiple profiling runs.
2720 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2721 return port::Status(port::error::INTERNAL, "Failed to start timer");
2722 }
2723 }
2724
2725 const auto get_fwd_bugs = [&]() -> port::Status {
2726 // Report an error if we might be hitting a cuDNN bug that accesses illegal
2727 // memory. See nvbugs/2138754, b/80018418.
2728 if (CUDNN_VERSION < 7300) {
2729 if (algorithm_desc.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) {
2730 return port::Status::OK();
2731 }
2732 if (input_descriptor.ndims() < 3) {
2733 return port::Status::OK();
2734 }
2735 // Checks that a*b is within the valid range (as provided by NVIDIA).
2736 const auto check_sizes = [](size_t a, size_t b) {
2737 if ((a * b * 4608 - 1) >> 31 == 0) {
2738 return port::Status::OK();
2739 }
2740 return port::Status(
2741 port::error::FAILED_PRECONDITION,
2742 "This configuration potentially accesses illegal memory.");
2743 };
2744 SE_RETURN_IF_ERROR(check_sizes(input_descriptor.feature_map_count(),
2745 output_descriptor.feature_map_count()));
2746 SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
2747 input_descriptor.feature_map_count()));
2748 SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
2749 output_descriptor.feature_map_count()));
2750 return port::Status::OK();
2751 }
2752 if (algorithm_desc.algo_id() ==
2753 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2754 !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
2755 output_descriptor)) {
2756 return port::Status(
2757 port::error::FAILED_PRECONDITION,
2758 "This configuration has potential integer overflow in "
2759 "cuDNNv5 and cuDNNv6. See b/68264959.");
2760 }
2761 return port::Status::OK();
2762 };
2763
2764 auto get_bwd_data_bugs = [&]() -> port::Status {
2765 if (algorithm_desc.algo_id() ==
2766 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2767 !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
2768 output_descriptor)) {
2769 return port::Status(
2770 port::error::FAILED_PRECONDITION,
2771 "This configuration has potential integer overflow in "
2772 "cuDNNv5 and cuDNNv6. See b/68264959.");
2773 }
2774
2775     // cuDNN 7.1.4 has a bug that is triggered when the workspace of the
2776     // following convolution is not zero-initialized; see nvbugs/2254619.
2777 if (CUDNN_VERSION >= 7000 && CUDNN_VERSION < 7300 &&
2778 algorithm_desc.algo_id() == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
2779 cudnn_type == CUDNN_DATA_HALF && algorithm_desc.tensor_ops_enabled() &&
2780 input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
2781 filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX &&
2782 output_descriptor.layout() == dnn::DataLayout::kBatchDepthYX &&
2783 (convolution_descriptor.vertical_filter_stride() > 1 ||
2784 convolution_descriptor.horizontal_filter_stride() > 1)) {
2785 stream->ThenMemZero(&scratch_memory, scratch_memory.size());
2786 }
2787 return port::Status::OK();
2788 };
2789
2790 const auto get_bwd_filter_bugs = [&]() -> port::Status {
2791 // Report an error if we might be hitting a cuDNN bug that produces
2792 // incorrect results. See nvbugs/2072856
2793 if (CUDNN_VERSION < 7300) {
2794 SE_RETURN_IF_ERROR([&] {
2795 if (algorithm_desc.algo_id() !=
2796 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) {
2797 return port::Status::OK();
2798 }
2799 if (output_descriptor.height() > 1 && output_descriptor.width() > 1) {
2800 return port::Status::OK();
2801 }
2802 int convolution_size = output_descriptor.height() > 1
2803 ? filter_descriptor.input_filter_height()
2804 : filter_descriptor.input_filter_width();
2805 if (convolution_size <= 32) {
2806 return port::Status::OK();
2807 }
2808 cudnnConvolutionMode_t convolution_mode;
2809 cudnnDataType_t compute_type;
2810 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdDescriptor(
2811 conv.handle(), 0, nullptr, nullptr, nullptr, nullptr,
2812 &convolution_mode, &compute_type));
2813 if (convolution_mode != CUDNN_CONVOLUTION) {
2814 return port::Status::OK();
2815 }
2816 return port::Status(
2817 port::error::FAILED_PRECONDITION,
2818 "This configuration potentially produces incorrect results.");
2819 }());
2820 }
2821
2822 if (algorithm_desc.algo_id() ==
2823 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2824 !ShouldIncludeWinogradNonfusedAlgo(input_descriptor,
2825 output_descriptor)) {
2826 return port::Status(
2827 port::error::FAILED_PRECONDITION,
2828 "This configuration has potential integer overflow in "
2829 "cuDNNv5 and cuDNNv6. See b/68264959.");
2830 }
2831
2832     // Zero out the result buffer for strided conv backward filter for NHWC
2833     // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer
2834     // is not zeroed.
2835     //
2836     // The incorrect result caused by this bug is very flaky; it can take up to
2837     // 20 runs to produce a mismatch.
2838 //
2839 // See nvbugs/2379553.
2840 if (CUDNN_VERSION >= 7100 && CUDNN_VERSION < 7300 &&
2841 algorithm_desc.algo_id() == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 &&
2842 cudnn_type == CUDNN_DATA_HALF &&
2843 input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
2844 filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
2845 output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
2846 (convolution_descriptor.vertical_filter_stride() > 1 ||
2847 convolution_descriptor.horizontal_filter_stride() > 1)) {
2848 stream->ThenMemZero(&filter_data, filter_data.size());
2849 }
2850 return port::Status::OK();
2851 };
2852
2853 switch (kind) {
2854 case dnn::ConvolutionKind::FORWARD: {
2855 SE_RETURN_IF_ERROR(get_fwd_bugs());
2856 RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
2857 cudnn.handle(),
2858 /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
2859 /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_nd.handle(),
2860 /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
2861 /*algo=*/ToConvForwardAlgo(algorithm_desc),
2862 /*workSpace=*/scratch_memory.opaque(),
2863 /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
2864 /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque()));
2865 break;
2866 }
2867 case dnn::ConvolutionKind::BACKWARD_DATA: {
2868 SE_RETURN_IF_ERROR(get_bwd_data_bugs());
2869 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
2870 cudnn.handle(),
2871 /*alpha=*/alpha,
2872 /*wDesc=*/filter_nd.handle(),
2873 /*w=*/filter_data.opaque(),
2874 /*dyDesc=*/output_nd.handle(),
2875 /*dy=*/output_data.opaque(),
2876 /*convDesc=*/conv.handle(),
2877 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2878 /*workSpace=*/scratch_memory.opaque(),
2879 /*workSpaceSizeInBytes=*/scratch_memory.size(),
2880 /*beta=*/beta,
2881 /*dxDesc=*/input_nd.handle(),
2882 /*dx=*/input_data.opaque()));
2883 break;
2884 }
2885 case dnn::ConvolutionKind::BACKWARD_FILTER: {
2886 SE_RETURN_IF_ERROR(get_bwd_filter_bugs());
2887 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
2888 cudnn.handle(),
2889 /*alpha=*/alpha,
2890 /*srcDesc=*/input_nd.handle(),
2891 /*srcData=*/input_data.opaque(),
2892 /*diffDesc=*/output_nd.handle(),
2893 /*diffData=*/output_data.opaque(),
2894 /*convDesc=*/conv.handle(),
2895 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2896 /*workSpace=*/scratch_memory.opaque(),
2897 /*workSpaceSizeInBytes=*/scratch_memory.size(),
2898 /*beta=*/beta,
2899 /*gradDesc=*/filter_nd.handle(),
2900 /*dw=*/filter_data.opaque()));
2901 break;
2902 }
2903 default:
2904 return port::InternalError(
2905 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
2906 }
2907
2908 if (is_profiling) {
2909 if (!timer->Stop(AsGpuStream(stream))) {
2910 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2911 }
2912 output_profile_result->set_algorithm(algorithm_desc);
2913 output_profile_result->set_elapsed_time_in_ms(
2914 timer->GetElapsedMilliseconds());
2915 output_profile_result->set_scratch_size(scratch_memory.size());
2916 }
2917
2918 return port::Status::OK();
2919 }
2920
2921 template <typename ElementType, typename BiasType, typename ScaleType>
2922 port::Status CudnnSupport::DoFusedConvolveImpl(
2923 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
2924 const DeviceMemory<ElementType>& conv_input_data,
2925 ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
2926 const DeviceMemory<ElementType>& filter_data,
2927 const dnn::ConvolutionDescriptor& convolution_descriptor,
2928 const DeviceMemory<ElementType>& side_input_data,
2929 ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
2930 const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
2931 const dnn::BatchDescriptor& output_descriptor,
2932 DeviceMemory<ElementType>* output_data, dnn::DataType accumulator_type,
2933 ScratchAllocator* scratch_allocator,
2934 const dnn::AlgorithmConfig& algorithm_config,
2935 dnn::ProfileResult* output_profile_result) {
2936 if (activation_mode != dnn::ActivationMode::kRelu &&
2937 activation_mode != dnn::ActivationMode::kNone) {
2938 return port::Status(port::error::INVALID_ARGUMENT,
2939 "cudnnConvolutionBiasActivationForward() only supports "
2940 "Relu or None activation.");
2941 }
2942
2943 CudnnTensorDescriptor conv_input_nd(
2944 conv_input_descriptor,
2945 GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
2946 CudnnTensorDescriptor output_nd(
2947 output_descriptor,
2948 GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
2949 CudnnFilterDescriptor filter(
2950 filter_descriptor,
2951 GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
2952 CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
2953 CudnnConvolutionDescriptor conv(convolution_descriptor,
2954 ToCudnnDataType(accumulator_type));
2955
2956 auto cudnn = cudnn_->GetHandle(parent_, stream);
2957
2958 const bool is_profiling = output_profile_result != nullptr;
2959
2960 DeviceMemory<uint8> scratch;
2961 SE_ASSIGN_OR_RETURN(
2962 dnn::AlgorithmDesc algo_desc,
2963 GetCudnnConvolutionForwardAlgorithm(
2964 stream, cudnn, algorithm_config, conv_input_nd, filter, conv,
2965 output_nd, scratch_allocator, &scratch));
2966
2967 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2968 if (is_profiling) {
2969 timer.reset(new GpuTimer(parent_)); // NOLINT
2970     // The start and stop of the timer should be as close to the cuDNN call as
2971     // possible. Other threads may still issue work onto this stream, so an
2972     // accurate measurement may require multiple profiling runs.
2973 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2974 return port::Status(port::error::INTERNAL, "Failed to start timer");
2975 }
2976 }
2977   // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
2978   // activation descriptor. Note that this changes the NaN propagation behavior
2979   // relative to running separate conv, bias, and relu steps (which default to
2980   // CUDNN_PROPAGATE_NAN).
2981 CudnnActivationDescriptor activation_desc(
2982 activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
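  // When side_input_scale is zero, the 'z' term contributes nothing to the
  // result, so the output buffer is passed as a conveniently available valid
  // pointer for 'z'.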
2983 auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
2984 : side_input_data.opaque();
2985
2986 VLOG(2) << "\nconv_input_scale = " << conv_input_scale
2987 << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
2988 << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
2989 << "\nfilter.handle() = " << filter.handle()
2990 << "\nfilter_data.opaque() = " << filter_data.opaque()
2991 << "\nconv.handle() = " << conv.handle()
2992 << "\nalgo = " << algo_desc.algo_id()
2993 << "\nscratch.opaque() = " << scratch.opaque()
2994 << "\nscratch.size() = " << scratch.size()
2995 << "\nside_input_scale = " << side_input_scale
2996 << "\noutput_nd.handle() = " << output_nd.handle()
2997 << "\nside_input_data_ptr = " << side_input_data_ptr
2998 << "\nbias_nd.handle() = " << bias_nd.handle()
2999 << "\nbiases.opaque() = " << biases.opaque()
3000 << "\nactivation_desc.handle() = " << activation_desc.handle()
3001 << "\noutput_nd.handle() = " << output_nd.handle()
3002 << "\noutput_data->opaque() = " << output_data->opaque();
3003
3004 if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
3005 !ShouldIncludeWinogradNonfusedAlgo(conv_input_descriptor,
3006 output_descriptor)) {
3007 return port::Status(port::error::FAILED_PRECONDITION,
3008 "This configuration has potential integer overflow in "
3009                         "cuDNNv5 and cuDNNv6. See b/68264959.");
3010 }
3011
3012 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
3013 cudnn.handle(),
3014 /*alpha1=*/&conv_input_scale,
3015 /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
3016 /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
3017 /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
3018 /*workSpace=*/scratch.opaque(),
3019 /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
3020 /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
3021 /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
3022 /*activationDesc=*/activation_desc.handle(),
3023 /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
3024
3025 if (is_profiling) {
3026 if (!timer->Stop(AsGpuStream(stream))) {
3027 return port::Status(port::error::INTERNAL, "Failed to stop timer");
3028 }
3029 output_profile_result->set_algorithm(algo_desc);
3030 output_profile_result->set_elapsed_time_in_ms(
3031 timer->GetElapsedMilliseconds());
3032 output_profile_result->set_scratch_size(scratch.size());
3033 }
3034
3035 return port::Status::OK();
3036 }
3037
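// Whether tensor-op math may be used for convolutions on this device: requires
// a compute capability major version of at least 7, cuDNN 7 or newer, and the
// TensorOpMathEnabled() toggle to be on.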
3038 inline bool TensorOpMathAvailable(int cc_major) {
3039 return cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled();
3040 }
3041
3042 bool CudnnSupport::GetConvolveAlgorithms(
3043 bool with_winograd_nonfused, int cc_major, int cc_minor,
3044 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3045 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3046 out_algorithms->clear();
3047
3048 if (RequireDeterminism()) {
3049 out_algorithms->push_back({CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
3050 tensor_op_math_available});
3051 return true;
3052 }
3053
3054 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3055 // clang-format off
3056 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
3057 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
3058 CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
3059 CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
3060 CUDNN_CONVOLUTION_FWD_ALGO_FFT,
3061 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
3062 // clang-format on
3063 };
3064 if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
3065 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
3066 }
3067 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3068 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
3069 }
3070
3071 for (auto i : algo_types) {
3072 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3073 if (tensor_op_math_available) {
3074 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3075 }
3076 }
3077
3078 return true;
3079 }
3080
3081 bool CudnnSupport::GetRnnAlgorithms(
3082 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3083 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3084 // clang-format off
3085 CUDNN_RNN_ALGO_STANDARD,
3086 CUDNN_RNN_ALGO_PERSIST_STATIC,
3087 CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
3088 // clang-format on
3089 };
3090
3091 out_algorithms->clear();
3092 for (auto i : algo_types) {
3093 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3094 #if CUDNN_VERSION >= 7100
3095 if (RnnTensorOpMathEnabled()) {
3096 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3097 }
3098 #endif
3099 }
3100 return true;
3101 }
3102
3103 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
3104 bool with_winograd_nonfused, int cc_major, int cc_minor,
3105 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3106 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3107 out_algorithms->clear();
3108
3109 if (RequireDeterminism()) {
3110 out_algorithms->push_back(
3111 {CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, tensor_op_math_available});
3112 return true;
3113 }
3114
3115 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3116 // clang-format off
3117 CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
3118 CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
3119 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
3120 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
3121 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
3122 // clang-format on
3123 };
3124 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3125 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
3126 }
3127
3128 for (auto i : algo_types) {
3129 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3130 if (tensor_op_math_available) {
3131 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3132 }
3133 }
3134
3135 return true;
3136 }
3137
3138 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
3139 bool with_winograd_nonfused, int cc_major, int cc_minor,
3140 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3141 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3142 out_algorithms->clear();
3143
3144 if (RequireDeterminism()) {
3145 out_algorithms->push_back(
3146 {CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, tensor_op_math_available});
3147 return true;
3148 }
3149
3150 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3151 // clang-format off
3152 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
3153 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
3154 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
3155 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
3156 // Based on cudnn.h, the following is not implemented.
3157 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
3158
3159 // Produces incorrect results for some shapes. Disabled for now, see
3160 // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
3161 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
3162 // clang-format on
3163 };
3164 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3165 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
3166 }
3167
3168 for (auto i : algo_types) {
3169 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3170 if (tensor_op_math_available) {
3171 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3172 }
3173 }
3174
3175 return true;
3176 }
3177
3178 bool CudnnSupport::DoBatchNormalizationForward(
3179 Stream* stream, const DeviceMemory<float>& x,
3180 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3181 const DeviceMemory<float>& estimated_mean,
3182 const DeviceMemory<float>& estimated_variance,
3183 const dnn::BatchDescriptor& x_desc,
3184 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3185 DeviceMemory<float>* y, DeviceMemory<float>* batch_mean,
3186 DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
3187 DeviceMemory<float>* saved_inv_var, bool is_training,
3188 std::function<const DeviceMemory<float>&()> var_to_inv_var,
3189 std::function<void()> inv_var_to_var) {
3190 return IsStatusOk(
3191 DoBatchNormalizationForwardImpl<float, float>(
3192 stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
3193 offset, estimated_mean, estimated_variance, x_desc, scale_offset_desc,
3194 epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3195 is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)),
3196 /*report_error=*/true);
3197 }
3198
3199 bool CudnnSupport::DoBatchNormalizationForward(
3200 Stream* stream, const DeviceMemory<Eigen::half>& x,
3201 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3202 const DeviceMemory<float>& estimated_mean,
3203 const DeviceMemory<float>& estimated_variance,
3204 const dnn::BatchDescriptor& x_desc,
3205 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3206 DeviceMemory<Eigen::half>* y, DeviceMemory<float>* batch_mean,
3207 DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
3208 DeviceMemory<float>* saved_inv_var, bool is_training,
3209 std::function<const DeviceMemory<float>&()> var_to_inv_var,
3210 std::function<void()> inv_var_to_var) {
3211 return IsStatusOk(
3212 DoBatchNormalizationForwardImpl<Eigen::half, float>(
3213 stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3214 estimated_mean, estimated_variance, x_desc, scale_offset_desc,
3215 epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3216 is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)),
3217 /*report_error=*/true);
3218 }
3219
3220 template <class T, class U>
3221 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
3222 Stream* stream, dnn::DataType input_data_type,
3223 dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3224 const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3225 const DeviceMemory<U>& estimated_mean,
3226 const DeviceMemory<U>& estimated_variance,
3227 const dnn::BatchDescriptor& x_desc,
3228 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3229 DeviceMemory<T>* y, DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3230 DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3231 bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
3232 std::function<void()> inv_var_to_var) {
3233 CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
3234 CudnnTensorDescriptor scale_offset_descriptor(
3235 scale_offset_desc, ToCudnnDataType(scale_data_type));
3236 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3237 #if CUDNN_VERSION >= 7000
3238 if (BatchnormSpatialPersistentEnabled() && is_training) {
3239 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3240 }
3241 #endif
3242 float one = 1.0;
3243 float zero = 0.0;
3244 auto cudnn = cudnn_->GetHandle(parent_, stream);
3245
3246 if (is_training) {
3247 CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
3248 << "batch_mean and batch_var must both be null or both be non-null";
3249
3250 void* batch_mean_opaque;
3251 void* batch_var_opaque;
3252 if (!batch_mean->is_null() && !batch_var->is_null()) {
3253 stream->ThenMemZero(batch_mean, batch_mean->size());
3254 stream->ThenMemZero(batch_var, batch_var->size());
3255 batch_mean_opaque = batch_mean->opaque();
3256 batch_var_opaque = batch_var->opaque();
3257 } else {
3258 batch_mean_opaque = nullptr;
3259 batch_var_opaque = nullptr;
3260 }
3261
3262 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
3263 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3264 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3265 scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque,
3266 batch_var_opaque, epsilon, saved_mean->opaque(),
3267 saved_inv_var->opaque()));
3268 } else {
3269 const void* maybe_inv_var = estimated_variance.opaque();
3270 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
3271 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3272 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3273 scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
3274 epsilon));
3275 }
3276 return port::Status::OK();
3277 }
3278
3279 bool CudnnSupport::DoBatchNormalizationBackward(
3280 Stream* stream, const DeviceMemory<float>& y_backprop,
3281 const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3282 const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3283 const dnn::BatchDescriptor& x_desc,
3284 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3285 DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3286 DeviceMemory<float>* offset_backprop) {
3287 return IsStatusOk(DoBatchNormalizationBackwardImpl(
3288 stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
3289 x, scale, mean, inv_var, x_desc, scale_offset_desc,
3290 epsilon, x_backprop, scale_backprop, offset_backprop),
3291 /*report_error=*/true);
3292 }
3293
3294 bool CudnnSupport::DoBatchNormalizationBackward(
3295 Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3296 const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3297 const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3298 const dnn::BatchDescriptor& x_desc,
3299 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3300 DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3301 DeviceMemory<float>* offset_backprop) {
3302 return IsStatusOk(DoBatchNormalizationBackwardImpl(
3303 stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
3304 x, scale, mean, inv_var, x_desc, scale_offset_desc,
3305 epsilon, x_backprop, scale_backprop, offset_backprop),
3306 /*report_error=*/true);
3307 }
3308
3309 template <class T, class U>
3310 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
3311 Stream* stream, int cudnn_input_type, int cudnn_scale_type,
3312 const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3313 const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3314 const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
3315 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3316 DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3317 DeviceMemory<U>* offset_backprop) {
3318 CudnnTensorDescriptor x_descriptor(
3319 x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
3320 CudnnTensorDescriptor scale_offset_descriptor(
3321 scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
3322 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3323 #if CUDNN_VERSION >= 7000
3324 if (BatchnormSpatialPersistentEnabled()) {
3325 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3326 }
3327 #endif
3328 float one = 1.0;
3329 float zero = 0.0;
3330
3331 auto cudnn = cudnn_->GetHandle(parent_, stream);
3332
3333 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
3334 cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3335 x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3336 x_descriptor.handle(), x_backprop->opaque(),
3337 scale_offset_descriptor.handle(), scale.opaque(),
3338 scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3339 mean.opaque(), inv_var.opaque()));
3340 return port::Status::OK();
3341 }
3342
3343 bool CudnnSupport::DoFusedConvolve(
3344 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3345 const DeviceMemory<double>& conv_input_data, double conv_input_scale,
3346 const dnn::FilterDescriptor& filter_descriptor,
3347 const DeviceMemory<double>& filter_data,
3348 const dnn::ConvolutionDescriptor& convolution_descriptor,
3349 const DeviceMemory<double>& side_input_data, double side_input_scale,
3350 const dnn::BatchDescriptor& bias_descriptor,
3351 const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
3352 const dnn::BatchDescriptor& output_descriptor,
3353 DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
3354 const dnn::AlgorithmConfig& algorithm_config,
3355 dnn::ProfileResult* output_profile_result) {
3356 return IsStatusOk(
3357 DoFusedConvolveImpl(
3358 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3359 filter_descriptor, filter_data, convolution_descriptor,
3360 side_input_data, side_input_scale, bias_descriptor, biases,
3361 activation_mode, output_descriptor, output_data,
3362 GetConvAccumulatorType(dnn::DataType::kDouble), scratch_allocator,
3363 algorithm_config, output_profile_result),
3364 /*report_error=*/!output_profile_result);
3365 }
3366
3367 bool CudnnSupport::DoFusedConvolve(
3368 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3369 const DeviceMemory<float>& conv_input_data, float conv_input_scale,
3370 const dnn::FilterDescriptor& filter_descriptor,
3371 const DeviceMemory<float>& filter_data,
3372 const dnn::ConvolutionDescriptor& convolution_descriptor,
3373 const DeviceMemory<float>& side_input_data, float side_input_scale,
3374 const dnn::BatchDescriptor& bias_descriptor,
3375 const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3376 const dnn::BatchDescriptor& output_descriptor,
3377 DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3378 const dnn::AlgorithmConfig& algorithm_config,
3379 dnn::ProfileResult* output_profile_result) {
3380 return IsStatusOk(
3381 DoFusedConvolveImpl(
3382 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3383 filter_descriptor, filter_data, convolution_descriptor,
3384 side_input_data, side_input_scale, bias_descriptor, biases,
3385 activation_mode, output_descriptor, output_data,
3386 GetConvAccumulatorType(dnn::DataType::kFloat), scratch_allocator,
3387 algorithm_config, output_profile_result),
3388 /*report_error=*/!output_profile_result);
3389 }
3390
3391 bool CudnnSupport::DoFusedConvolve(
3392 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3393 const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
3394 const dnn::FilterDescriptor& filter_descriptor,
3395 const DeviceMemory<Eigen::half>& filter_data,
3396 const dnn::ConvolutionDescriptor& convolution_descriptor,
3397 const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
3398 const dnn::BatchDescriptor& bias_descriptor,
3399 const DeviceMemory<Eigen::half>& biases,
3400 dnn::ActivationMode activation_mode,
3401 const dnn::BatchDescriptor& output_descriptor,
3402 DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3403 const dnn::AlgorithmConfig& algorithm_config,
3404 dnn::ProfileResult* output_profile_result) {
3405 return IsStatusOk(
3406 DoFusedConvolveImpl(
3407 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3408 filter_descriptor, filter_data, convolution_descriptor,
3409 side_input_data, side_input_scale, bias_descriptor, biases,
3410 activation_mode, output_descriptor, output_data,
3411 GetConvAccumulatorType(dnn::DataType::kHalf), scratch_allocator,
3412 algorithm_config, output_profile_result),
3413 /*report_error=*/!output_profile_result);
3414 }
3415
3416 bool CudnnSupport::DoFusedConvolve(
3417 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3418 const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3419 const dnn::FilterDescriptor& filter_descriptor,
3420 const DeviceMemory<int8>& filter_data,
3421 const dnn::ConvolutionDescriptor& convolution_descriptor,
3422 const DeviceMemory<int8>& side_input_data, float side_input_scale,
3423 const dnn::BatchDescriptor& bias_descriptor,
3424 const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3425 const dnn::BatchDescriptor& output_descriptor,
3426 DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
3427 const dnn::AlgorithmConfig& algorithm_config,
3428 dnn::ProfileResult* output_profile_result) {
3429 int cc_major, cc_minor;
3430 stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
3431 &cc_minor);
3432 if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3433 LOG(WARNING) << "cudnnConvolutionBiasActivationForward() for int8 is only "
3434 "supported on GPUs with compute capability 6.1 or later.";
3435 return false;
3436 }
3437 return IsStatusOk(
3438 DoFusedConvolveImpl(
3439 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3440 filter_descriptor, filter_data, convolution_descriptor,
3441 side_input_data, side_input_scale, bias_descriptor, biases,
3442 activation_mode, output_descriptor, output_data,
3443 GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3444 algorithm_config, output_profile_result),
3445 /*report_error=*/!output_profile_result);
3446 }
3447
3448 bool CudnnSupport::DoTransformTensor(Stream* stream,
3449 const dnn::BatchDescriptor& input_desc,
3450 dnn::DataType input_type,
3451 const DeviceMemoryBase& input_data,
3452 const dnn::BatchDescriptor& output_desc,
3453 dnn::DataType output_type, float scale,
3454 DeviceMemoryBase* output_data) {
3455 float beta = 0.0f;
3456 CudnnTensorDescriptor input_tensor_desc(
3457 input_desc, ToCudnnDataType(input_type, input_desc.layout()));
3458 CudnnTensorDescriptor output_tensor_desc(
3459 output_desc, ToCudnnDataType(output_type, output_desc.layout()));
3460 auto cudnn = cudnn_->GetHandle(parent_, stream);
3461 const auto status = [&] {
3462 RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
3463 cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
3464 &beta, output_tensor_desc.handle(), output_data->opaque()));
3465 return port::Status::OK();
3466 }();
3467 return IsStatusOk(status, /*report_error=*/true);
3468 }
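
// Example use of the DoTransformTensor() entry point above (an illustrative
// sketch only; `dnn_support`, `stream`, `in_mem`, and `out_mem` are
// hypothetical): converting a 32x7x7x64 float tensor from kBatchYXDepth
// (NHWC) to kBatchDepthYX (NCHW) without scaling.
//
//   dnn::BatchDescriptor in_desc;
//   in_desc.set_count(32).set_feature_map_count(64).set_height(7).set_width(7)
//       .set_layout(dnn::DataLayout::kBatchYXDepth);
//   dnn::BatchDescriptor out_desc;
//   out_desc.set_count(32).set_feature_map_count(64).set_height(7)
//       .set_width(7).set_layout(dnn::DataLayout::kBatchDepthYX);
//   dnn_support->DoTransformTensor(stream, in_desc, dnn::DataType::kFloat,
//                                  in_mem, out_desc, dnn::DataType::kFloat,
//                                  /*scale=*/1.0f, &out_mem);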
3469
3470 template <class T>
3471 port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
3472 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3473 const DeviceMemory<T>& input_data,
3474 const dnn::BatchDescriptor& bias_descriptor,
3475 DeviceMemory<T>* backward_bias_data) {
3476 cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
3477 CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
3478 CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
3479
3480 // Alpha is the scaling factor for input.
3481 float alpha = 1.0;
3482 // Beta is the scaling factor for output.
3483 float beta = 0.0;
3484
3485 auto cudnn = cudnn_->GetHandle(parent_, stream);
3486 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias(
3487 cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
3488 bias_nd.handle(), backward_bias_data->opaque()));
3489 return port::Status::OK();
3490 }
3491
3492 bool CudnnSupport::DoConvolveBackwardBias(
3493 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3494 const DeviceMemory<double>& input_data,
3495 const dnn::BatchDescriptor& bias_descriptor,
3496 DeviceMemory<double>* backward_bias_data) {
3497 return IsStatusOk(
3498 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
3499 bias_descriptor, backward_bias_data),
3500 /*report_error=*/true);
3501 }
3502
3503 bool CudnnSupport::DoConvolveBackwardBias(
3504 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3505 const DeviceMemory<float>& input_data,
3506 const dnn::BatchDescriptor& bias_descriptor,
3507 DeviceMemory<float>* backward_bias_data) {
3508 return IsStatusOk(
3509 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
3510 bias_descriptor, backward_bias_data),
3511 /*report_error=*/true);
3512 }
3513
3514 bool CudnnSupport::DoConvolveBackwardBias(
3515 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3516 const DeviceMemory<Eigen::half>& input_data,
3517 const dnn::BatchDescriptor& bias_descriptor,
3518 DeviceMemory<Eigen::half>* backward_bias_data) {
3519 return IsStatusOk(
3520 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
3521 bias_descriptor, backward_bias_data),
3522 /*report_error=*/true);
3523 }
3524
3525 bool CudnnSupport::DoMatMul(Stream* stream,
3526 const DeviceMemory<float>& input_data,
3527 const DeviceMemory<float>& weights,
3528 const dnn::BatchDescriptor& input_dimensions,
3529 const dnn::BatchDescriptor& output_dimensions,
3530 DeviceMemory<float>* output_data) {
3531 if (input_dimensions.count() != output_dimensions.count()) {
3532 LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
3533 return false;
3534 }
3535
3536 // We do not permute the input or output; instead, we just
3537 // reinterpret the layout. We are working with row-major matrices
3538 // and the rows of the input and output correspond to batch, so
3539 // batch has to be outermost in both the input and output.
3540 //
3541 // By adding transposes to the BLAS gemm call we could perhaps make
3542 // the kYXDepthBatch layout work as well, but there has been no need
3543 // for that so far.
3544 if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3545 input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3546 LOG(ERROR) << "Unsupported MatMul input layout.";
3547 return false;
3548 }
3549 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3550 output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3551 LOG(ERROR) << "Unsupported MatMul output layout.";
3552 return false;
3553 }
3554
3555 if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
3556 // This is a fast path that also supports the kBatchYXDepth layout.
3557
3558 // The matrices here are in row-major format while BLAS expects
3559 // column-major, i.e. our matrices are transposed as far as BLAS
3560 // is concerned. So we need to compute output^T =
3561 // input^T*weights^T. There is no parameter for transposing the
3562 // output in BLAS gemm, but instead we can transpose both sides of
3563 // the equality to see that this is equivalent to
3564 // output=weights*input. So we only need to swap the order of
3565 // weights and input in the matrix product to correct for the
3566 // row-major versus column-major difference.
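// For example (the numbers are illustrative only): with batch n = 32, input
// size k = 1024, and output size m = 10, BLAS sees the row-major 1024x10
// weights as a column-major 10x1024 matrix A (lda = 10), the row-major
// 32x1024 input as a column-major 1024x32 matrix B (ldb = 1024), and writes
// C = A * B as a column-major 10x32 matrix (ldc = 10), which reinterpreted
// as row-major is exactly the desired 32x10 output.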
3567 const float alpha = 1.0f; // Take the matrix product without scaling it.
3568 const float beta = 0.0f; // Ignore the original values in output_data.
3569 const int64 m = output_dimensions.NodesAcrossFeatureMaps();
3570 const int64 n = input_dimensions.count();
3571 const int64 k = input_dimensions.NodesAcrossFeatureMaps();
3572 stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
3573 blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
3574 m, input_data, k, beta, output_data, m);
3575 } else {
3576 // This is a slower and more complex path that supports output
3577 // width() * height() > 1, though it only supports the
3578 // kBatchYXDepth layout. It also supports kBatchDepthYX if output
3579 // feature_map_count() == 1, as then there is no difference
3580 // between the two layouts.
3581 //
3582 // The operation here is the same as above, except that we have to
3583 // do the matrix multiplication for each (y,x) output coordinate
3584 // separately. We then interpret weights as containing K = width()
3585 // * height() different matrices, which we all multiply onto the
3586 // matrix from input_data, yielding K matrix products. We then
3587 // combine these together into one matrix by concatenating all the
3588 // first rows of these matrices, then all the second rows and so
3589 // on. We can do this with a batched matrix multiplication, where
3590 // the result is written to a different submatrix of the output
3591 // for each matrix multiplication.
3592 //
3593 // The reason that we only support the kBatchYXDepth output layout
3594 // is that we have to do something in the depth for each (y,x)
3595 // coordinate. The kBatchYXDepth layout has the depth information
3596 // for each point (y,x) in contiguous memory while the
3597 // kBatchDepthYX layout does not.
3598 //
3599 // TODO(broune): Consider a special case for when output depth ==
3600 // 1, as then possibly this could all be done as one matrix
3601 // multiplication instead of a batched one, which should be
3602 // faster. Another possibility would be to add a weights layout
3603 // parameter and then support kBatchDepthYX for a different
3604 // weights layout.
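// As a concrete (illustrative) example: for a 2x2 spatial output with
// feature_map_count() == 10, input NodesAcrossFeatureMaps() == 1024, and
// count() == 32, this runs batch_count = 4 GEMMs. GEMM i multiplies the i-th
// 10x1024 slice of the weights with the shared 1024x32 input (in the
// column-major view described in the fast path above) and writes its 10x32
// result at row offset i * 10 of the output, whose leading dimension is
// ldc = 40, so the four results end up interleaved in kBatchYXDepth order.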
3605 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3606 !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
3607 output_dimensions.feature_map_count() == 1)) {
3608 LOG(ERROR) << "Unsupported MatMul output layout.";
3609 return false;
3610 }
3611
3612 const float alpha = 1.0f; // Take the matrix product without scaling it.
3613 const float beta = 0.0f; // Ignore the original values in output_data.
3614 const uint64 m = output_dimensions.feature_map_count();
3615 const uint64 n = input_dimensions.count();
3616 const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
3617 const int lda = m;
3618 const int ldb = k;
3619 const int ldc = output_dimensions.NodesAcrossFeatureMaps();
3620 const int batch_count = output_dimensions.NodesPerFeatureMap();
3621
3622 std::vector<DeviceMemory<float>> a(batch_count);
3623 std::vector<DeviceMemory<float>> b(batch_count);
3624 std::vector<DeviceMemory<float>> c(batch_count);
3625 for (int i = 0; i < batch_count; ++i) {
3626 const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
3627 output_dimensions.feature_map_count();
3628 a[i] = DeviceMemory<float>::MakeFromByteSize(
3629 const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
3630 weights_offset,
3631 (weights.ElementCount() - weights_offset) * sizeof(float));
3632
3633 b[i] = input_data;
3634
3635 const int output_offset = i * output_dimensions.feature_map_count();
3636 c[i] = DeviceMemory<float>::MakeFromByteSize(
3637 const_cast<float*>(
3638 reinterpret_cast<const float*>(output_data->opaque())) +
3639 output_offset,
3640 (output_data->ElementCount() - output_offset) * sizeof(float));
3641 }
3642 const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
3643 std::vector<DeviceMemory<float>*> ptrs;
3644 ptrs.reserve(v.size());
3645 for (auto& mem : v) {
3646 ptrs.push_back(&mem);
3647 }
3648 return ptrs;
3649 };
3650
3651 stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
3652 blas::Transpose::kNoTranspose, m, n, k, alpha,
3653 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
3654 ldc, batch_count);
3655 }
3656
3657 return stream->ok();
3658 }
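
// Example use of DoMatMul() above as a fully connected layer (an illustrative
// sketch; `dnn_support`, `stream`, `input`, `weights`, and `output` are
// hypothetical device buffers/objects):
//
//   dnn::BatchDescriptor in_dims;
//   in_dims.set_count(32).set_feature_map_count(1024).set_height(1)
//       .set_width(1).set_layout(dnn::DataLayout::kBatchDepthYX);
//   dnn::BatchDescriptor out_dims;
//   out_dims.set_count(32).set_feature_map_count(10).set_height(1)
//       .set_width(1).set_layout(dnn::DataLayout::kBatchDepthYX);
//   // `weights` holds 1024 * 10 floats. Because the output is 1x1 spatially,
//   // this takes the single-GEMM fast path above.
//   dnn_support->DoMatMul(stream, input, weights, in_dims, out_dims, &output);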
3659
3660 bool CudnnSupport::DoBiasAdd(Stream* stream,
3661 const DeviceMemory<float>& input_data,
3662 const DeviceMemory<float>& biases,
3663 const dnn::BatchDescriptor& dimensions,
3664 DeviceMemory<float>* output_data) {
3665 CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
3666
3667 dnn::BatchDescriptor bias_dimensions;
3668 bias_dimensions.set_count(1)
3669 .set_feature_map_count(dimensions.feature_map_count())
3670 .set_height(1)
3671 .set_width(1)
3672 .set_layout(dnn::DataLayout::kBatchYXDepth);
3673 CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
3674
3675 // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
3676 // output_data before doing the addition, unless the input and
3677 // output are at the same address.
3678 if (input_data.opaque() != output_data->opaque()) {
3679 stream->ThenMemcpy(output_data, input_data,
3680 dimensions.ElementCount() * sizeof(float));
3681 if (!stream->ok()) {
3682 LOG(ERROR)
3683 << "stream " << stream
3684 << " could not enqueue a tensor copy as part of bias addition.";
3685 return false;
3686 }
3687 }
3688
3689 const float alpha = 1.0f;
3690 const float beta = 1.0f;
3691
3692 auto cudnn = cudnn_->GetHandle(parent_, stream);
3693
3694 const auto status = [&] {
3695 RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
3696 cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
3697 &beta, input_descriptor.handle(), output_data->opaque()));
3698 return port::Status::OK();
3699 }();
3700 return IsStatusOk(status, /*report_error=*/true);
3701 }
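
// Example use of DoBiasAdd() above (an illustrative sketch; `dnn_support`,
// `stream`, `activations`, `biases`, and `output` are hypothetical): adding a
// per-channel bias to a 32x7x7x64 NHWC float activation.
//
//   dnn::BatchDescriptor act_dims;
//   act_dims.set_count(32).set_feature_map_count(64).set_height(7)
//       .set_width(7).set_layout(dnn::DataLayout::kBatchYXDepth);
//   // `biases` holds 64 floats, one per feature map. If `activations` and
//   // `output` alias the same buffer, the copy above is skipped.
//   dnn_support->DoBiasAdd(stream, activations, biases, act_dims, &output);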
3702
3703 bool CudnnSupport::DoActivate(Stream* stream,
3704 dnn::ActivationMode activation_mode,
3705 const dnn::BatchDescriptor& dimensions,
3706 const DeviceMemory<float>& input_data,
3707 DeviceMemory<float>* output_data,
3708 uint64 options) {
3709 CudnnActivationDescriptor activation_desc(
3710 activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
3711
3712 CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
3713 // Alpha is the input scaling factor.
3714 float alpha = 1.0;
3715 // Beta is the output scaling factor.
3716 float beta = 0.0;
3717
3718 auto cudnn = cudnn_->GetHandle(parent_, stream);
3719 const auto status = [&] {
3720 RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
3721 cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
3722 input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
3723 return port::Status::OK();
3724 }();
3725 return IsStatusOk(status, /*report_error=*/true);
3726 }
3727
3728 bool CudnnSupport::DoPoolForward(
3729 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3730 const dnn::BatchDescriptor& input_dimensions,
3731 const DeviceMemory<double>& input_data,
3732 const dnn::BatchDescriptor& output_dimensions,
3733 DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
3734 // Alpha is the scaling factor for input.
3735 double alpha = 1.0;
3736 // Beta is the scaling factor for output.
3737 double beta = 0.0;
3738
3739 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
3740 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
3741 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3742
3743 auto cudnn = cudnn_->GetHandle(parent_, stream);
3744 const auto status = [&] {
3745 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
3746 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3747 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
3748 return port::Status::OK();
3749 }();
3750 return IsStatusOk(status, /*report_error=*/true);
3751 }
3752
3753 bool CudnnSupport::DoPoolForward(
3754 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3755 const dnn::BatchDescriptor& input_dimensions,
3756 const DeviceMemory<float>& input_data,
3757 const dnn::BatchDescriptor& output_dimensions,
3758 DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
3759 // Alpha is the scaling factor for input.
3760 float alpha = 1.0;
3761 // Beta is the scaling factor for output.
3762 float beta = 0.0;
3763
3764 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
3765 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
3766 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3767
3768 auto cudnn = cudnn_->GetHandle(parent_, stream);
3769 const auto status = [&] {
3770 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
3771 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3772 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
3773 return port::Status::OK();
3774 }();
3775 return IsStatusOk(status, /*report_error=*/true);
3776 }
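
// Example use of the float DoPoolForward() above (an illustrative sketch;
// the shapes, `dnn_support`, `stream`, `in_dims`, `out_dims`, `input`, and
// `output` are hypothetical, and the setters follow dnn::PoolingDescriptor's
// builder style): a 2x2 max pool with stride 2 over a 32x64x14x14 NCHW
// tensor, producing 32x64x7x7.
//
//   dnn::PoolingDescriptor pool;
//   pool.set_pooling_mode(dnn::PoolingMode::kMaximum)
//       .set_window_height(2).set_window_width(2)
//       .set_vertical_stride(2).set_horizontal_stride(2)
//       .set_vertical_padding(0).set_horizontal_padding(0);
//   dnn_support->DoPoolForward(stream, pool, in_dims, input, out_dims,
//                              &output, /*workspace_allocator=*/nullptr);
//
// The workspace allocator is unused by this implementation, so nullptr is
// acceptable here.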
3777
3778 bool CudnnSupport::DoPoolForward(
3779 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3780 const dnn::BatchDescriptor& input_dimensions,
3781 const DeviceMemory<Eigen::half>& input_data,
3782 const dnn::BatchDescriptor& output_dimensions,
3783 DeviceMemory<Eigen::half>* output_data,
3784 ScratchAllocator* workspace_allocator) {
3785 // Alpha is the scaling factor for input.
3786 float alpha = 1.0;
3787 // Beta is the scaling factor for output.
3788 float beta = 0.0;
3789
3790 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
3791 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
3792 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3793 auto cudnn = cudnn_->GetHandle(parent_, stream);
3794 const auto status = [&] {
3795 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
3796 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3797 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
3798 return port::Status::OK();
3799 }();
3800 return IsStatusOk(status, /*report_error=*/true);
3801 }
3802
3803 bool CudnnSupport::DoPoolForward(
3804 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3805 const dnn::BatchDescriptor& input_dimensions,
3806 const DeviceMemory<int8>& input_data,
3807 const dnn::BatchDescriptor& output_dimensions,
3808 DeviceMemory<int8>* output_data, ScratchAllocator* workspace_allocator) {
3809 // Alpha is the scaling factor for input.
3810 float alpha = 1.0;
3811 // Beta is the scaling factor for output.
3812 float beta = 0.0;
3813
3814 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_INT8);
3815 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_INT8);
3816 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3817
3818 auto cudnn = cudnn_->GetHandle(parent_, stream);
3819 const auto status = [&] {
3820 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
3821 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3822 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
3823 return port::Status::OK();
3824 }();
3825 return IsStatusOk(status, /*report_error=*/true);
3826 }
3827
3828 bool CudnnSupport::DoPoolBackward(
3829 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3830 const dnn::BatchDescriptor& input_dimensions,
3831 const DeviceMemory<double>& input_data,
3832 const dnn::BatchDescriptor& output_dimensions,
3833 const DeviceMemory<double>& output_data,
3834 const DeviceMemory<double>& input_diff_data,
3835 DeviceMemory<double>* output_diff_data,
3836 ScratchAllocator* workspace_allocator) {
3837 // Alpha is the scaling factor for input.
3838 double alpha = 1.0;
3839 // Beta is the scaling factor for output.
3840 double beta = 0.0;
3841
3842 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
3843 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
3844 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3845
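// Note on the cudnnPoolingBackward() argument order used below: it takes the
// forward pass's output y and its gradient dy (both described by dest_desc),
// then the forward input x (src_desc), and finally writes dx. Here
// output_data is y, input_diff_data is dy, input_data is x, and
// output_diff_data receives dx. The same ordering applies to the other
// DoPoolBackward overloads below.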
3846 auto cudnn = cudnn_->GetHandle(parent_, stream);
3847 const auto status = [&] {
3848 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
3849 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
3850 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
3851 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
3852 output_diff_data->opaque()));
3853 return port::Status::OK();
3854 }();
3855 return IsStatusOk(status, /*report_error=*/true);
3856 }
3857
3858 bool CudnnSupport::DoPoolBackward(
3859 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3860 const dnn::BatchDescriptor& input_dimensions,
3861 const DeviceMemory<float>& input_data,
3862 const dnn::BatchDescriptor& output_dimensions,
3863 const DeviceMemory<float>& output_data,
3864 const DeviceMemory<float>& input_diff_data,
3865 DeviceMemory<float>* output_diff_data,
3866 ScratchAllocator* workspace_allocator) {
3867 // Alpha is the scaling factor for input.
3868 float alpha = 1.0;
3869 // Beta is the scaling factor for output.
3870 float beta = 0.0;
3871
3872 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
3873 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
3874 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3875
3876 auto cudnn = cudnn_->GetHandle(parent_, stream);
3877 const auto status = [&] {
3878 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
3879 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
3880 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
3881 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
3882 output_diff_data->opaque()));
3883 return port::Status::OK();
3884 }();
3885 return IsStatusOk(status, /*report_error=*/true);
3886 }
3887
3888 bool CudnnSupport::DoPoolBackward(
3889 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3890 const dnn::BatchDescriptor& input_dimensions,
3891 const DeviceMemory<Eigen::half>& input_data,
3892 const dnn::BatchDescriptor& output_dimensions,
3893 const DeviceMemory<Eigen::half>& output_data,
3894 const DeviceMemory<Eigen::half>& input_diff_data,
3895 DeviceMemory<Eigen::half>* output_diff_data,
3896 ScratchAllocator* workspace_allocator) {
3897 // Alpha is the scaling factor for input.
3898 float alpha = 1.0;
3899 // Beta is the scaling factor for output.
3900 float beta = 0.0;
3901
3902 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
3903 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
3904 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
3905
3906 auto cudnn = cudnn_->GetHandle(parent_, stream);
3907 const auto status = [&] {
3908 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
3909 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
3910 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
3911 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
3912 output_diff_data->opaque()));
3913 return port::Status::OK();
3914 }();
3915 return IsStatusOk(status, /*report_error=*/true);
3916 }
3917
3918 bool CudnnSupport::DoNormalizeWithDimensions(
3919 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
3920 const dnn::BatchDescriptor& dimensions,
3921 const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
3922 // Check for unsupported modes.
3923 if (normalize_descriptor.wrap_around()) {
3924 LOG(ERROR) << "CUDA LRN does not support cudnn-around mode";
3925 return false;
3926 }
3927 if (normalize_descriptor.segment_size()) {
3928 LOG(ERROR) << "CUDA LRN does not support segmentation";
3929 return false;
3930 }
3931
3932 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
3933 CudnnNormalizeDescriptor normalize(normalize_descriptor);
3934
3935 // Alpha is the scaling factor for input.
3936 float alpha = 1.0f;
3937 // Beta is the scaling factor for output.
3938 float beta = 0.0f;
3939
3940 auto cudnn = cudnn_->GetHandle(parent_, stream);
3941
3942 // Launch the normalization.
3943 const auto status = [&] {
3944 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
3945 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
3946 &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
3947 output_data->opaque()));
3948 return port::Status::OK();
3949 }();
3950 return IsStatusOk(status, /*report_error=*/true);
3951 }
3952
3953 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
3954 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
3955 const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
3956 const DeviceMemory<float>& normalized_data,
3957 const DeviceMemory<float>& normalized_variable_gradient,
3958 DeviceMemory<float>* raw_variable_gradient,
3959 ScratchAllocator* workspace_allocator) {
3960 // Check for unsupported modes.
3961 if (normalize_descriptor.wrap_around()) {
3962 LOG(ERROR) << "CUDA LRN does not support cudnn-around mode";
3963 return false;
3964 }
3965 if (normalize_descriptor.segment_size()) {
3966 LOG(ERROR) << "CUDA LRN does not support segmentation";
3967 return false;
3968 }
3969
3970 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
3971 CudnnNormalizeDescriptor normalize(normalize_descriptor);
3972
3973 float alpha = 1.0f;
3974 float beta = 0.0f;
3975
3976 auto cudnn = cudnn_->GetHandle(parent_, stream);
3977 const auto status = [&] {
3978 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
3979 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
3980 &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
3981 normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
3982 &beta, dims.handle(), raw_variable_gradient->opaque()));
3983 return port::Status::OK();
3984 }();
3985 return IsStatusOk(status, /*report_error=*/true);
3986 }
3987
3988 bool CudnnSupport::DoDepthConcatenate(
3989 Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
3990 port::ArraySlice<const DeviceMemory<float>*> input_data,
3991 DeviceMemory<float>* output_data) {
3992 CHECK_EQ(input_dimensions.size(), input_data.size());
3993
3994 for (const auto& dimensions : input_dimensions) {
3995 if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3996 LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
3997 "supports the kBatchDepthYX layout.";
3998 return false;
3999 }
4000 }
4001
4002 if (input_dimensions.empty()) {
4003 return true; // Nothing to do.
4004 }
4005
4006 dnn::BatchDescriptor output_dimensions =
4007 dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
4008
4009 const int64 area = output_dimensions.width() * output_dimensions.height();
4010 const auto index = [area](int64 batch, int64 depth, int64 yx,
4011 int64 max_depth) {
4012 return (batch * max_depth + depth) * area + yx;
4013 };
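// For instance (illustrative numbers): with max_depth = 3 and a 2x2 image
// (area = 4), element (batch = 1, depth = 2, yx = 3) lands at flat index
// (1 * 3 + 2) * 4 + 3 = 23 of the BDYX-ordered host buffer.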
4014
4015 std::vector<float> output_host(output_dimensions.ElementCount());
4016 std::vector<float> tmp;
4017 int64 depth_sum = 0;
4018 for (size_t i = 0; i < input_data.size(); ++i) {
4019 const auto& dimensions = input_dimensions[i];
4020 tmp.resize(dimensions.ElementCount());
4021 stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
4022 port::Status block_status = stream->BlockHostUntilDone();
4023 if (!block_status.ok()) {
4024 LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
4025 return false;
4026 }
4027
4028 for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
4029 for (int64 yx = 0; yx < area; ++yx) {
4030 for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
4031 LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
4032 << yx << ' ' << depth;
4033 output_host[index(batch, depth + depth_sum, yx,
4034 output_dimensions.feature_map_count())] =
4035 tmp[index(batch, depth, yx, dimensions.feature_map_count())];
4036 }
4037 }
4038 }
4039 depth_sum += dimensions.feature_map_count();
4040 }
4041 stream->ThenMemcpyH2D<float>(output_host, output_data);
4042 return true;
4043 }
4044
4045 bool CudnnSupport::DoElementwiseOperate(
4046 Stream* stream, dnn::ElementwiseOperation operation,
4047 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4048 port::ArraySlice<const DeviceMemory<float>*> input_data,
4049 const dnn::BatchDescriptor& output_dimensions,
4050 DeviceMemory<float>* output_data) {
4051 LOG(FATAL) << "not yet implemented"; // TODO(leary)
4052 return false;
4053 }
4054
4055 bool CudnnSupport::DoXYPad(Stream* stream,
4056 const dnn::BatchDescriptor& dimensions,
4057 const DeviceMemory<float>& input_data,
4058 int64 left_pad, int64 right_pad, int64 top_pad,
4059 int64 bottom_pad, DeviceMemory<float>* output_data) {
4060 LOG(FATAL) << "not yet implemented"; // TODO(leary)
4061 return false;
4062 }
4063
4064 bool CudnnSupport::DoXYSlice(Stream* stream,
4065 const dnn::BatchDescriptor& dimensions,
4066 const DeviceMemory<float>& input_data,
4067 int64 left_trim, int64 right_trim, int64 top_trim,
4068 int64 bottom_trim,
4069 DeviceMemory<float>* output_data) {
4070 LOG(FATAL) << "not yet implemented"; // TODO(leary)
4071 return false;
4072 }
4073
4074 bool CudnnSupport::DoMemcpyD2HQuantized(
4075 Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
4076 dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
4077 LOG(ERROR) << "quantized memcpy not supported by cuDNN";
4078 return false;
4079 }
4080
4081 bool CudnnSupport::DoMemcpyH2DQuantized(
4082 Stream* stream, const void* host_src, int64 size,
4083 dnn::QuantizedActivationMode mode,
4084 DeviceMemory<float>* gpu_unquantized_dst) {
4085 LOG(ERROR) << "quantized memcpy not supported by cuDNN";
4086 return false;
4087 }
4088
4089 bool CudnnSupport::DeriveOutputBatchDescriptor(
4090 const dnn::BatchDescriptor& batch_descriptor,
4091 const dnn::FilterDescriptor& filter_descriptor,
4092 const dnn::ConvolutionDescriptor& convolution_descriptor,
4093 dnn::BatchDescriptor* output_batch_descriptor) {
4094 CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
4095 CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
4096 CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
4097
4098 int dn = batch_descriptor.ndims() + 2;
4099 std::vector<int> dims(dn); // in BDYX
4100 const auto status = [&] {
4101 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
4102 conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
4103 output_batch_descriptor->set_count(dims[0])
4104 .set_feature_map_count(dims[1])
4105 .set_layout(batch_descriptor.layout());
4106
4107 for (int i = 0; i < batch_descriptor.ndims(); i++) {
4108 output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
4109 dims.rbegin()[i]);
4110 }
4111 return port::Status::OK();
4112 }();
4113 return IsStatusOk(status, /*report_error=*/true);
4114 }
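
// Example of what DeriveOutputBatchDescriptor() above computes (illustrative
// shapes only): for a 32x3x224x224 input, a 64x3x7x7 filter, stride 2, and
// padding 3 in both spatial dimensions, cuDNN reports a 32x64x112x112 output,
// i.e. floor((224 + 2 * 3 - 7) / 2) + 1 = 112 per spatial dimension.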
4115
4116 } // namespace gpu
4117
4118 void initialize_cudnn() {
4119 port::Status status =
4120 PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
4121 cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
4122 [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
4123 gpu::GpuExecutor* cuda_executor =
4124 dynamic_cast<gpu::GpuExecutor*>(parent);
4125 if (cuda_executor == nullptr) {
4126 LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
4127 << "support library with a non-CUDA StreamExecutor";
4128 return nullptr;
4129 }
4130
4131 gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
4132 if (!dnn->Init().ok()) {
4133 // Note: Init() will log a more specific error.
4134 delete dnn;
4135 return nullptr;
4136 }
4137 return dnn;
4138 });
4139
4140 if (!status.ok()) {
4141 LOG(ERROR) << "Unable to register cuDNN factory: "
4142 << status.error_message();
4143 }
4144
4145 PluginRegistry::Instance()->SetDefaultFactory(
4146 cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
4147 }
4148
4149 } // namespace stream_executor
4150
4151 #pragma clang diagnostic pop
4152
4153 REGISTER_MODULE_INITIALIZER(register_cudnn,
4154 { stream_executor::initialize_cudnn(); });
4155