/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/cuda/cuda_dnn.h"

#include <functional>
#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
#include "tensorflow/stream_executor/cuda/cudnn_version.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/mathutil.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
// clang-format off
#include "third_party/gpus/cudnn/cudnn.h"
#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
#include "third_party/cudnn_frontend/include/cudnn_frontend.h"
#endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
#include "absl/strings/string_view.h"
// clang-format on

#pragma clang diagnostic push

// Make sure that Eigen::half forward declaration in dnn.h matches the
// declaration in Eigen.
#pragma clang diagnostic warning "-Wmismatched-tags"

namespace stream_executor {
namespace gpu {

PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);

namespace {

static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher");

// Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
#define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)

// If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
// function with a non-successful port::Status.
#define RETURN_IF_CUDNN_ERROR(expr)                                      \
  do {                                                                   \
    cudnnStatus_t _status = (expr);                                      \
    if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {             \
      std::ostringstream oss;                                            \
      oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
          << "): '" << #expr << "'";                                     \
      return port::Status(port::error::UNKNOWN, oss.str());              \
    }                                                                    \
  } while (false)

#define RETURN_MSG_IF_CUDNN_ERROR(expr)                                  \
  do {                                                                   \
    cudnnStatus_t _status = (expr).get_status();                         \
    if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {             \
      std::ostringstream oss;                                            \
      oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
          << "): '" << #expr << "' " << (expr).get_error();              \
      return port::Status(port::error::UNKNOWN, oss.str());              \
    }                                                                    \
  } while (false)

#define RETURN_FALSE_IF_CUDNN_ERROR(expr)                                \
  do {                                                                   \
    if (!SE_PREDICT_TRUE((expr).get_status() == CUDNN_STATUS_SUCCESS)) { \
      return false;                                                      \
    }                                                                    \
  } while (false)
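
// Illustrative usage of the macros above (a sketch, not code used verbatim
// here): inside a function returning port::Status, a cuDNN call can be wrapped
// as
//   RETURN_IF_CUDNN_ERROR(cudnnSetStream(handle, stream));
// so that any non-success status is turned into an error carrying file/line
// context, whereas CHECK_CUDNN_OK aborts the process on failure.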

// Converts (via narrowing) a value of type WideT to NarrowT, and checks that
// the value is unchanged by the conversion.
template <typename WideT, typename NarrowT>
NarrowT CheckedNarrowing(const WideT& wide) {
  NarrowT narrow = wide;
  CHECK_EQ(narrow, wide)
      << "checked narrowing failed; values not equal post-conversion";
  return narrow;
}
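
// For example, CheckedNarrowing<int64, int>(5) returns 5, while narrowing a
// value that does not fit in the destination type (e.g. int64{1} << 40) would
// fail the CHECK above.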

std::string ToString(cudnnStatus_t status) {
  switch (status) {
    case CUDNN_STATUS_SUCCESS:
      return "CUDNN_STATUS_SUCCESS";
    case CUDNN_STATUS_NOT_INITIALIZED:
      return "CUDNN_STATUS_NOT_INITIALIZED";
    case CUDNN_STATUS_ALLOC_FAILED:
      return "CUDNN_STATUS_ALLOC_FAILED";
    case CUDNN_STATUS_BAD_PARAM:
      return "CUDNN_STATUS_BAD_PARAM";
    case CUDNN_STATUS_INTERNAL_ERROR:
      return "CUDNN_STATUS_INTERNAL_ERROR";
    case CUDNN_STATUS_INVALID_VALUE:
      return "CUDNN_STATUS_INVALID_VALUE";
    case CUDNN_STATUS_ARCH_MISMATCH:
      return "CUDNN_STATUS_ARCH_MISMATCH";
    case CUDNN_STATUS_MAPPING_ERROR:
      return "CUDNN_STATUS_MAPPING_ERROR";
    case CUDNN_STATUS_EXECUTION_FAILED:
      return "CUDNN_STATUS_EXECUTION_FAILED";
    case CUDNN_STATUS_NOT_SUPPORTED:
      return "CUDNN_STATUS_NOT_SUPPORTED";
    case CUDNN_STATUS_LICENSE_ERROR:
      return "CUDNN_STATUS_LICENSE_ERROR";
    case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
      return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
    case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
      return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
    case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
      return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
    default:
      return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
                          ">");
  }
}

// RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
//
// See CudnnAccess::GetHandle() for details.
class CudnnHandle {
 public:
  // Takes ownership of the executor context and the lock to access cuDNN
  // using the handle.
  CudnnHandle(gpu::ScopedActivateExecutorContext context,
              std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
      : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}

  // Returns the cuDNN handle. Pass it directly to cuDNN APIs; don't keep a
  // copy.
  cudnnHandle_t handle() const { return handle_; }

 private:
  gpu::ScopedActivateExecutorContext context_;
  std::unique_ptr<absl::MutexLock> lock_;
  cudnnHandle_t handle_;  // Not owned.
};

}  // namespace

// Wraps a cuDNN handle and provides access to it through CudnnHandle
// instances, which also locks a mutex, acquires the CUDA context, and sets
// the stream that cuDNN should use to enqueue any work.
//
// Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
class CudnnAccess {
 public:
  // Takes ownership of the handle.
  explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}

  ~CudnnAccess() {
    absl::MutexLock lock(&mutex_);
    cudnnDestroy(handle_);
  }

  // Creates a CudnnHandle instance for stream.
  //
  // cuDNN API calls using the same handle instance need to be serialized
  // across threads. This is guaranteed by CudnnHandle instances locking the
  // mutex owned by this class.
  //
  // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
  // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
  // to use the provided stream.
  //
  // The stream argument may be null, which translates to the legacy default
  // stream. See
  // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
  // The legacy default stream synchronizes with all other streams and it is
  // therefore a bad idea (performance wise) to call any cuDNN APIs that
  // enqueue work in the stream.
  CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
    auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
    mutex_.AssertHeld();
    gpu::ScopedActivateExecutorContext context(executor);
    CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
    const auto status = cudnnSetStream(handle_, cu_stream);
    CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
    return CudnnHandle(std::move(context), std::move(lock), handle_);
  }

 private:
  // Guards the enqueueing of cuDNN operations via the handle_ below.
  absl::Mutex mutex_;

  // cuDNN library handle.
  cudnnHandle_t handle_ TF_GUARDED_BY(mutex_);  // Owned.
};
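
// Typical usage pattern in this file (sketch): a DNN entry point obtains a
// scoped handle before issuing cuDNN calls, e.g.
//   auto cudnn_launcher = cudnn_->GetHandle(parent_, stream);
//   RETURN_IF_CUDNN_ERROR(cudnnSomeApi(cudnn_launcher.handle(), ...));
// The mutex is released and the CUDA context restored when the CudnnHandle
// goes out of scope. (cudnnSomeApi is a placeholder, not a real cuDNN call.)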

namespace {

// A helper function to return the internal compute type for
// RNNs in cudnn.
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);

cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
  cudnnConvolutionFwdAlgo_t algo =
      cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
  switch (algo) {
    case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
    case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
    case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
    case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
    case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
    case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
    case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
    case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
      return algo;
    default:
      LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
                 << algorithm.algo_id();
  }
}

cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
    dnn::AlgorithmDesc algorithm) {
  cudnnConvolutionBwdDataAlgo_t algo =
      cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
  switch (algo) {
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
    case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
      return algo;
    default:
      LOG(FATAL)
          << "Unsupported Cudnn convolution backward algorithm for data: "
          << algorithm.algo_id();
  }
}

cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
    dnn::AlgorithmDesc algorithm) {
  cudnnConvolutionBwdFilterAlgo_t algo =
      cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
  switch (algo) {
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
    // Based on cudnn.h, the following is not implemented.
    // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
    case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
      return algo;
    default:
      LOG(FATAL)
          << "Unsupported Cudnn convolution backward algorithm for filter: "
          << algorithm.algo_id();
  }
}

port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
  int value;
  RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
  return value;
}

cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional<dnn::AlgorithmDesc> algorithm) {
  if (!algorithm.has_value()) {
    return CUDNN_RNN_ALGO_STANDARD;
  }
  cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
  switch (algo) {
    case CUDNN_RNN_ALGO_STANDARD:
    case CUDNN_RNN_ALGO_PERSIST_STATIC:
    case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
      return algo;
    default:
      LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
  }
}

port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
  SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
  SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
  SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
  return port::Status::OK();
}

#if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
void PreloadCudnnLibrary(cudnnStatus_t (*version_check_fn)(),
                         absl::string_view sub_library) {
  cudnnStatus_t status = version_check_fn();
  if (status != CUDNN_STATUS_SUCCESS) {
    VLOG(1) << "Could not pre-initialize cuDNN sub-library " << sub_library
            << ". Error: " << cudnnGetErrorString(status) << ".";
  }
}
#endif

}  // namespace

CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}

port::Status CudnnSupport::Init() {
  ScopedActivateExecutorContext context(parent_);

  // Peek at the last error to give more information in cases of errors.
  cudaError_t cerr = cudaPeekAtLastError();
  if (cerr != cudaSuccess) {
    LOG(WARNING) << "There was an error before creating cudnn handle: "
                 << cudaGetErrorName(cerr) << " : " << cudaGetErrorString(cerr);
  }

  cudnnHandle_t cudnn_handle = nullptr;
  const auto status = cudnnCreate(&cudnn_handle);
  if (status == CUDNN_STATUS_SUCCESS) {
    CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);

    CudnnVersion loaded_version;
    TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
    if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
      const std::string error = absl::StrCat(
          "Loaded runtime CuDNN library: ", loaded_version.ToString(),
          " but source was compiled with: ", source_version.ToString(),
          ". CuDNN library needs to have matching major version and equal or "
          "higher minor version. If using a binary install, upgrade your CuDNN "
          "library. If building from sources, make sure the library loaded at "
          "runtime is compatible with the version specified during compile "
          "configuration.");
      LOG(ERROR) << error;
      cudnnDestroy(cudnn_handle);
      return port::Status(port::error::INTERNAL, error);
    }

    cudnn_.reset(new CudnnAccess(cudnn_handle));

    LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion();
    return port::Status::OK();
  }

  CHECK_EQ(cudnn_handle, nullptr);
  LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
  if (status == CUDNN_STATUS_NOT_INITIALIZED) {
    auto result = gpu::Diagnostician::FindKernelDriverVersion();
    if (!result.ok()) {
      LOG(ERROR) << "Error retrieving driver version: "
                 << cuda::DriverVersionStatusToString(result);
    } else {
      const auto& version = result.ValueOrDie();
      LOG(ERROR) << "Possibly insufficient driver version: "
                 << cuda::DriverVersionToString(version);
    }
  }

  return port::Status(port::error::INTERNAL,
                      absl::StrCat("cudnn library could not create a handle: ",
                                   ToString(status)));
}

port::StatusOr<perftools::gputools::dnn::VersionInfo>
CudnnSupport::GetVersion() {
  CudnnVersion version;
  TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
  return perftools::gputools::dnn::VersionInfo(
      version.major_version, version.minor_version, version.patch_level);
}

namespace {

// Deleter functors for cuDNN types that need to be deleted.
struct TensorDescriptorDeleter {
  void operator()(cudnnTensorDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
  }
};
struct RNNDataDescriptorDeleter {
  void operator()(cudnnRNNDataDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
  }
};
struct FilterDescriptorDeleter {
  void operator()(cudnnFilterDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
  }
};
struct ConvolutionDescriptorDeleter {
  void operator()(cudnnConvolutionDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
  }
};
struct PoolingDescriptorDeleter {
  void operator()(cudnnPoolingDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
  }
};
struct LrnDescriptorDeleter {
  void operator()(cudnnLRNDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
  }
};

struct ActivationDescriptorDeleter {
  void operator()(cudnnActivationDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
  }
};
struct DropoutDescriptorDeleter {
  void operator()(cudnnDropoutDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
  }
};
struct RnnDescriptorDeleter {
  void operator()(cudnnRNNDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
  }
};
struct PersistentRnnPlanDeleter {
  void operator()(cudnnPersistentRNNPlan_t plan) const {
    CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
  }
};
#if CUDNN_VERSION >= 7603
struct CtcLossDescriptorDeleter {
  void operator()(cudnnCTCLossDescriptor_t descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
  }
};
#endif

// RAII wrappers for cuDNN types.
using TensorDescriptor =
    std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
using RNNDataDescriptor =
    std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
using FilterDescriptor =
    std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
using ConvolutionDescriptor =
    std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
using PoolingDescriptor =
    std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
using ActivationDescriptor =
    std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
using DropoutDescriptor =
    std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
using PersistentRnnPlan =
    std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
#if CUDNN_VERSION >= 7603
using CtcLossDescriptor =
    std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
#endif

// Factory methods for cuDNN types.
TensorDescriptor CreateTensorDescriptor() {
  cudnnTensorDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
  return TensorDescriptor(result);
}
RNNDataDescriptor CreateRNNDataDescriptor() {
  cudnnRNNDataDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
  return RNNDataDescriptor(result);
}
FilterDescriptor CreateFilterDescriptor() {
  cudnnFilterDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
  return FilterDescriptor(result);
}
ConvolutionDescriptor CreateConvolutionDescriptor() {
  cudnnConvolutionDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
  return ConvolutionDescriptor(result);
}
PoolingDescriptor CreatePoolingDescriptor() {
  cudnnPoolingDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
  return PoolingDescriptor(result);
}
LrnDescriptor CreateLrnDescriptor() {
  cudnnLRNDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
  return LrnDescriptor(result);
}
ActivationDescriptor CreateActivationDescriptor() {
  cudnnActivationDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
  return ActivationDescriptor(result);
}
DropoutDescriptor CreateDropoutDescriptor() {
  cudnnDropoutDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
  return DropoutDescriptor(result);
}
RnnDescriptor CreateRnnDescriptor() {
  cudnnRNNDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
  return RnnDescriptor(result);
}
#if CUDNN_VERSION >= 7603
CtcLossDescriptor CreateCtcLossDescriptor() {
  cudnnCTCLossDescriptor_t result;
  CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
  return CtcLossDescriptor(result);
}
#endif

port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
    cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
  cudnnPersistentRNNPlan_t result;
  RETURN_IF_CUDNN_ERROR(
      cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
  return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
}

// Turns a BatchDescriptor structure into a cudnn tensor handle within a
// scope.
class CudnnTensorDescriptor {
 public:
  CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
                        cudnnDataType_t elem_type)
      : handle_(CreateTensorDescriptor()) {
    switch (batch_descriptor.layout()) {
      case dnn::DataLayout::kBatchYXDepth:
      case dnn::DataLayout::kBatchDepthYX: {
        const int nd = batch_descriptor.ndims() + 2;
        // cuDNN requires the strides and dims to be ordered as BDYX.
        std::vector<int64> strides64 =
            batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
        std::vector<int64> dims64 =
            batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);

        // cuDNN requires arrays of ints.
        std::vector<int> strides(nd);
        std::vector<int> dims(nd);
        std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
                       &CheckedNarrowing<int64, int>);
        std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
                       &CheckedNarrowing<int64, int>);
        CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
                                                  dims.data(), strides.data()))
            << "batch_descriptor: " << batch_descriptor.ToString();
        break;
      }
      case dnn::DataLayout::kBatchDepthYX4:
      case dnn::DataLayout::kBatchDepthYX32: {
        auto expected_elem_ty =
            batch_descriptor.layout() == dnn::DataLayout::kBatchDepthYX4
                ? CUDNN_DATA_INT8x4
                : CUDNN_DATA_INT8x32;
        CHECK_EQ(elem_type, expected_elem_ty);
        CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
            handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
            batch_descriptor.count(), batch_descriptor.feature_map_count(),
            batch_descriptor.height(), batch_descriptor.width()))
            << "batch_descriptor: " << batch_descriptor.ToString();
        break;
      }
      default:
        LOG(FATAL) << "Unsupported tensor format "
                   << DataLayoutString(batch_descriptor.layout());
        break;
    }
  }

  cudnnTensorDescriptor_t handle() const { return handle_.get(); }

 private:
  TensorDescriptor handle_;

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
};
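
// Sketch of how the wrapper above is typically constructed elsewhere in this
// file (names are illustrative):
//   CudnnTensorDescriptor input_nd(batch_descriptor,
//                                  ToCudnnDataType(dnn::DataType::kFloat));
//   ... pass input_nd.handle() to cuDNN APIs ...
// The underlying cudnnTensorDescriptor_t is destroyed when the wrapper goes
// out of scope.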

// Turns a FilterDescriptor structure into a cudnn filter handle within a
// scope.
class CudnnFilterDescriptor {
 public:
  CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
                        cudnnDataType_t elem_type)
      : handle_(CreateFilterDescriptor()) {
    // TODO(b/23032134): Even if the filter layout is not supported,
    // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
    // it does not take layout as an input. Maybe force cuDNN by giving wrong
    // inputs intentionally?
    cudnnTensorFormat_t format;
    switch (filter_descriptor.layout()) {
      case dnn::FilterLayout::kOutputInputYX:
        format = CUDNN_TENSOR_NCHW;
        break;
      case dnn::FilterLayout::kOutputYXInput:
        format = CUDNN_TENSOR_NHWC;
        break;
      case dnn::FilterLayout::kOutputInputYX4:
      case dnn::FilterLayout::kOutputInputYX32: {
        auto expected_elem_ty =
            filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX4
                ? CUDNN_DATA_INT8x4
                : CUDNN_DATA_INT8x32;
        CHECK_EQ(elem_type, expected_elem_ty);
        format = CUDNN_TENSOR_NCHW_VECT_C;
        break;
      }
      default:
        LOG(FATAL) << "Unsupported filter format "
                   << FilterLayoutString(filter_descriptor.layout());
        break;
    }

    std::vector<int> dims(2 + filter_descriptor.ndims());
    dims[0] = filter_descriptor.output_feature_map_count();
    dims[1] = filter_descriptor.input_feature_map_count();
    absl::Span<const int64> spatial_dims =
        filter_descriptor.input_filter_dims();
    std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);

    CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
                                              dims.size(), dims.data()));
  }

  cudnnFilterDescriptor_t handle() const { return handle_.get(); }

 private:
  FilterDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
};

#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
// The errata sheet (JSON format) for marking the cudnn engines that might be
// buggy. For example, we don't want the engine 999 of forward convolution:
// R"({ "version" : 1,
//      "rules" : [
//        { "rule_id" : "ConvFwd_eng999",
//          "operation" : "ConvFwd",
//          "engine" : 999,
//          "knob" : [],
//          "cudnn_version_start" : 8000,
//          "cudnn_version_end" : -1
//        }
// ]})"
// We skip eng0 in the static filter because they are too slow. Additionally,
// users can specify an additional errata JSON file via
// CUDNN_ERRATA_JSON_FILE at runtime.
const json* CudnnExecutionPlanEngineFilterStatic() {
  static absl::string_view filter_str = R"({
      "version" : 1,
      "rules" : [
        { "rule_id" : "ConvFwd_eng0",
          "operation" : "ConvFwd",
          "engine" : 0,
          "knob" : [],
          "cudnn_version_start" : 8000,
          "cudnn_version_end" : -1
        },
        { "rule_id" : "ConvBwdData_eng0",
          "operation" : "ConvBwdData",
          "engine" : 0,
          "knob" : [],
          "cudnn_version_start" : 8000,
          "cudnn_version_end" : -1
        },
        { "rule_id" : "ConvBwdFilter_eng0",
          "operation" : "ConvBwdFilter",
          "engine" : 0,
          "knob" : [],
          "cudnn_version_start" : 8000,
          "cudnn_version_end" : -1
        }
  ]})";
  static const json* json_handle = new json(json::parse(filter_str));
  return json_handle;
}

const json* CudnnExecutionPlanEngineFilterRuntime() {
  static const json* json_handle = []() -> const json* {
    json j;
    if (cudnn_frontend::load_from_config(j, "")) {
      return new json(j);
    }
    return nullptr;
  }();
  return json_handle;
}
#endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
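
// Illustrative note: per the comment above, an extra errata filter can be
// supplied at runtime by pointing the CUDNN_ERRATA_JSON_FILE environment
// variable at a JSON file with the same schema as the static filter, in which
// case CudnnExecutionPlanEngineFilterRuntime() returns the parsed rules;
// otherwise it returns nullptr.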

// A helper function to decide whether to use
// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
// and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
// reason we set it to false by default is that this mode may use scaled
// atomic integer reduction that may cause a numerical overflow for certain
// input data range.
// TODO(yangzihao): Use autotune to choose between this mode and
// CUDNN_BATCHNORM_SPATIAL mode.
bool BatchnormSpatialPersistentEnabled() {
  static bool is_enabled = [] {
    bool is_enabled = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
        "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
        /*default_val=*/false, &is_enabled));
    return is_enabled;
  }();
  return is_enabled;
}

bool RequireCudnnDeterminism() {
  static bool require_cudnn_determinism = [] {
    // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var.
    bool cudnn_deterministic = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
                                               /*default_val=*/false,
                                               &cudnn_deterministic));
    return cudnn_deterministic;
  }();
  return tensorflow::OpDeterminismRequired() || require_cudnn_determinism;
}
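
// For example (shell usage, outside this file):
//   TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1 <program>
// opts in to the persistent batchnorm mode, and TF_CUDNN_DETERMINISTIC=1 (or
// enabling op determinism in TensorFlow) makes RequireCudnnDeterminism()
// return true, e.g. selecting CUDNN_POOLING_MAX_DETERMINISTIC in the pooling
// descriptor below.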

// A helper function to decide whether to force the default conv algorithm.
bool ConvUseDefaultAlgorithm() {
  static bool use_default = [] {
    bool use_default = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_USE_DEFAULT_CONV_ALGO",
                                               /*default_val=*/false,
                                               &use_default));
    return use_default;
  }();
  return use_default;
}

// Turns a ConvolutionDescriptor structure into a cudnn convolution handle
// within a scope.
class CudnnConvolutionDescriptor {
 public:
  CudnnConvolutionDescriptor(
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      cudnnDataType_t data_type)
      : handle_(CreateConvolutionDescriptor()) {
    absl::Span<const int64> strides64 = convolution_descriptor.strides();
    absl::Span<const int64> padding64 = convolution_descriptor.padding();
    absl::Span<const int64> dilations64 = convolution_descriptor.dilations();
    CHECK_NE(convolution_descriptor.pad_alignment(),
             dnn::PadAlignment::kTensorFlowPadding)
        << "TensorFlow padding alignment is not supported.";

    // cuDNN requires arrays of ints.
    std::vector<int> strides(convolution_descriptor.ndims());
    std::vector<int> padding(convolution_descriptor.ndims());
    std::vector<int> dilations(convolution_descriptor.ndims());
    std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
                   &CheckedNarrowing<int64, int>);
    std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
                   &CheckedNarrowing<int64, int>);
    // TODO(yangzihao): Test with negative dilation to make sure that cudnn
    // doesn't crash.
    std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
                   &CheckedNarrowing<int64, int>);

    CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
        handle_.get(), convolution_descriptor.ndims(), padding.data(),
        strides.data(), dilations.data(),
        convolution_descriptor.convolution_not_crosscorr()
            ? CUDNN_CONVOLUTION
            : CUDNN_CROSS_CORRELATION,
        data_type));

#if CUDNN_MAJOR >= 7
    VLOG(2) << "Requesting grouped convolution: "
            << convolution_descriptor.group_count();
    CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
        handle_.get(), convolution_descriptor.group_count()));
#else
    CHECK_EQ(convolution_descriptor.group_count(), 1)
        << "Requested grouped convolution for cuDNN version < 7";
#endif
  }

  void set_use_tensor_op_math(bool use_tensor_op_math) {
    cudnnMathType_t math_type =
#if CUDNN_VERSION >= 8000
        (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH);
#else
        (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
#endif
    CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
  }

  cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }

 private:
  ConvolutionDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
};

// A helper function to query whether a CudnnConvolutionDescriptor has tensor
// op math set.
static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
  cudnnMathType_t math_type;
  CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
#if CUDNN_VERSION >= 8000
  return math_type != CUDNN_FMA_MATH;
#else
  return math_type == CUDNN_TENSOR_OP_MATH;
#endif
}

static bool TensorOpMathAvailable(
    CudaComputeCapability cuda_compute_capability) {
  return cuda_compute_capability.IsAtLeast(7);
}

static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) {
  if (!TensorOpMathAvailable(stream->GetCudaComputeCapability())) {
    return false;
  }
  if (input_type == dnn::DataType::kFloat) {
#if CUDNN_VERSION < 8000
    return false;
#else
    if (!tensorflow::tensor_float_32_execution_enabled()) {
      return false;
    }
#endif
  }
  return true;
}

// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
// within a scope.
class CudnnPoolingDescriptor {
 public:
  explicit CudnnPoolingDescriptor(
      const dnn::PoolingDescriptor& pooling_descriptor)
      : handle_(CreatePoolingDescriptor()) {
    absl::Span<const int64> strides64 = pooling_descriptor.strides();
    absl::Span<const int64> padding64 = pooling_descriptor.padding();
    absl::Span<const int64> shape64 = pooling_descriptor.window();

    const int nd = pooling_descriptor.ndims();
    std::vector<int> shape(nd);
    std::vector<int> padding(nd);
    std::vector<int> strides(nd);
    std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
                   &CheckedNarrowing<int64, int>);
    std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
                   &CheckedNarrowing<int64, int>);
    std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
                   &CheckedNarrowing<int64, int>);
    bool propagate_nans = pooling_descriptor.propagate_nans();
    const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
                                            ? CUDNN_POOLING_MAX_DETERMINISTIC
                                            : CUDNN_POOLING_MAX;
    CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
        handle_.get(),
        (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
             ? cudnn_max_pooling_mode
             : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
        propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
        shape.data(), padding.data(), strides.data()));
  }

  cudnnPoolingDescriptor_t handle() const { return handle_.get(); }

 private:
  PoolingDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
};

// Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
class CudnnNormalizeDescriptor {
 public:
  explicit CudnnNormalizeDescriptor(
      const dnn::NormalizeDescriptor& normalize_descriptor)
      : handle_(CreateLrnDescriptor()) {
    // The range specifies that the indices in the closed range
    // [i - range, i + range] should be included in the normalization for index
    // i. The lrnN value is the total number of elements in the range, so
    // lrnN = 2*range + 1.
    unsigned lrnN = 2 * normalize_descriptor.range() + 1;

    // Note that SE defines the normalization operation as
    //
    //  U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
    //
    // but cuDNN defines it as
    //
    //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
    //
    // i.e. there is a factor of n difference between the meaning of the alphas
    // in the two contexts. The cuDNN alpha is n times the SE alpha.
    double lrnAlpha = lrnN * normalize_descriptor.alpha();

    double lrnBeta = normalize_descriptor.beta();
    double lrnK = normalize_descriptor.bias();
    CHECK_CUDNN_OK(
        cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
  }

  cudnnLRNDescriptor_t handle() const { return handle_.get(); }

 private:
  LrnDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
};
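
// Worked example of the alpha scaling above: for range = 2 the window covers
// n = lrnN = 2 * 2 + 1 = 5 elements, so an SE alpha of 1e-4 is passed to cuDNN
// as lrnAlpha = 5e-4; cuDNN then divides by n internally, recovering the SE
// semantics.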

// Turns an ActivationDescriptor structure into a cudnn activation
// descriptor handle within a scope.
class CudnnActivationDescriptor {
 public:
  CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
                            cudnnNanPropagation_t nan_propagation,
                            double value_max)
      : handle_(CreateActivationDescriptor()) {
    double relu_ceiling = 0.0;
    cudnnActivationMode_t mode;
    switch (activation_mode) {
      case dnn::ActivationMode::kNone:
        mode = CUDNN_ACTIVATION_IDENTITY;
        break;
      case dnn::ActivationMode::kRelu6:
        relu_ceiling = 6.0;
        mode = CUDNN_ACTIVATION_CLIPPED_RELU;
        break;
      case dnn::ActivationMode::kReluX:
        relu_ceiling = value_max;
        mode = CUDNN_ACTIVATION_CLIPPED_RELU;
        break;
      case dnn::ActivationMode::kRelu:
        mode = CUDNN_ACTIVATION_RELU;
        break;
      case dnn::ActivationMode::kSigmoid:
        mode = CUDNN_ACTIVATION_SIGMOID;
        break;
      case dnn::ActivationMode::kTanh:
        mode = CUDNN_ACTIVATION_TANH;
        break;
      default:
        LOG(FATAL) << "unrecognized activation mode: "
                   << static_cast<int>(activation_mode);
    }

    CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
                                                nan_propagation, relu_ceiling));
  }

  cudnnActivationDescriptor_t handle() const { return handle_.get(); }

 private:
  ActivationDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
};

cudnnDataType_t ToCudnnDataType(
    dnn::DataType data_type,
    dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
  switch (data_type) {
    case dnn::DataType::kFloat:
      return CUDNN_DATA_FLOAT;
    case dnn::DataType::kDouble:
      return CUDNN_DATA_DOUBLE;
    case dnn::DataType::kHalf:
      return CUDNN_DATA_HALF;
    case dnn::DataType::kInt8:
      switch (data_layout) {
        case dnn::DataLayout::kBatchDepthYX4:
          return CUDNN_DATA_INT8x4;
        case dnn::DataLayout::kBatchDepthYX32:
          return CUDNN_DATA_INT8x32;
        default:
          return CUDNN_DATA_INT8;
      }
    case dnn::DataType::kInt32:
      return CUDNN_DATA_INT32;
#if CUDNN_VERSION >= 8200
    case dnn::DataType::kBF16:
      return CUDNN_DATA_BFLOAT16;
#endif
    default:
      LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
  }
}

cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
                                dnn::FilterLayout filter_layout) {
  if (data_type == dnn::DataType::kInt8 &&
      filter_layout == dnn::FilterLayout::kOutputInputYX4) {
    return CUDNN_DATA_INT8x4;
  }
  if (data_type == dnn::DataType::kInt8 &&
      filter_layout == dnn::FilterLayout::kOutputInputYX32) {
    return CUDNN_DATA_INT8x32;
  }
  return ToCudnnDataType(data_type);
}

template <typename T>
cudnnDataType_t GetCudnnDataType(
    dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
  return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
}

template <typename T>
cudnnDataType_t GetCudnnDataType(dnn::FilterLayout filter_layout) {
  return ToCudnnDataType(dnn::ToDataType<T>::value, filter_layout);
}

cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
  switch (input_mode) {
    case dnn::RnnInputMode::kRnnLinearSkip:
    case dnn::RnnInputMode::kRnnSkipInput:
      return static_cast<cudnnRNNInputMode_t>(input_mode);
    default:
      LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
  }
}

cudnnDirectionMode_t ToCudnnRnnDirectionMode(
    dnn::RnnDirectionMode direction_mode) {
  switch (direction_mode) {
    case dnn::RnnDirectionMode::kRnnUnidirectional:
    case dnn::RnnDirectionMode::kRnnBidirectional:
      return static_cast<cudnnDirectionMode_t>(direction_mode);
    default:
      LOG(FATAL) << "Invalid RNN direction mode: "
                 << static_cast<int>(direction_mode);
  }
}

cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
  switch (rnn_mode) {
    case dnn::RnnMode::kRnnRelu:
    case dnn::RnnMode::kRnnTanh:
    case dnn::RnnMode::kRnnLstm:
    case dnn::RnnMode::kRnnGru:
      return static_cast<cudnnRNNMode_t>(rnn_mode);
    default:
      LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
  }
}

int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
  switch (data_type) {
    case CUDNN_DATA_FLOAT:
      return sizeof(float);
    case CUDNN_DATA_DOUBLE:
      return sizeof(double);
    case CUDNN_DATA_HALF:
      return sizeof(Eigen::half);
    default:
      LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
  }
}

class CudnnDropoutDescriptor {
  explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
      : handle_(std::move(handle)) {}

 public:
  CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;

  static port::StatusOr<CudnnDropoutDescriptor> Create(
      const CudnnHandle& cudnn, float dropout, uint64 seed,
      ScratchAllocator* state_allocator) {
    DropoutDescriptor handle = CreateDropoutDescriptor();

    if (dropout == 0.0f) {
      // Return 'empty' dropout descriptor.
      return CudnnDropoutDescriptor(std::move(handle));
    }

    DeviceMemory<uint8> state_memory;
    if (state_allocator) {
      size_t state_sizes_in_bytes = 0;
      RETURN_IF_CUDNN_ERROR(
          cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
      SE_ASSIGN_OR_RETURN(state_memory,
                          state_allocator->AllocateBytes(state_sizes_in_bytes));
    }
    RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
        handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
        state_memory.size(), seed));

    return CudnnDropoutDescriptor(std::move(handle));
  }

  cudnnDropoutDescriptor_t handle() const { return handle_.get(); }

 private:
  DropoutDescriptor handle_;  // Owned.
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
};
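
// Sketch of how the factory above is used by the RNN setup code below:
//   SE_ASSIGN_OR_RETURN(CudnnDropoutDescriptor dropout_desc,
//                       CudnnDropoutDescriptor::Create(cudnn, dropout, seed,
//                                                      state_allocator));
// A dropout probability of 0.0f yields an 'empty' descriptor and skips the
// allocation of cuDNN dropout state memory.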

class CudnnRnnParamsDescriptor {
  typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;

  CudnnRnnParamsDescriptor(FilterDescriptor handle,
                           int64_t params_size_in_bytes, ParamsRegions weights,
                           ParamsRegions biases)
      : handle_(std::move(handle)),
        params_size_in_bytes_(params_size_in_bytes),
        weights_(std::move(weights)),
        biases_(std::move(biases)) {}

 public:
  CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;

  static port::StatusOr<CudnnRnnParamsDescriptor> Create(
      const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
      cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
      cudnnDirectionMode_t direction_mode, int num_layers);

  cudnnFilterDescriptor_t handle() const { return handle_.get(); }
  int64 params_size_in_bytes() const { return params_size_in_bytes_; }
  ParamsRegions params_weights() const { return weights_; }
  ParamsRegions params_biases() const { return biases_; }

 private:
  FilterDescriptor handle_;
  int64 params_size_in_bytes_;
  ParamsRegions weights_;
  ParamsRegions biases_;
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
};

}  // namespace

class CudnnRnnDescriptor : public dnn::RnnDescriptor {
  CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
                     PersistentRnnPlan rnn_plan, int num_layers,
                     int hidden_size, int input_size, int cell_size,
                     int batch_size, cudnnRNNInputMode_t input_mode,
                     cudnnDirectionMode_t direction_mode,
                     cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
                     cudnnDataType_t compute_type,
                     const dnn::AlgorithmConfig& algorithm_config,
                     CudnnDropoutDescriptor dropout_desc,
                     CudnnRnnParamsDescriptor params_desc)
      : rnn_desc_(std::move(rnn_desc)),
        rnn_plan_(std::move(rnn_plan)),
        num_layers_(num_layers),
        hidden_size_(hidden_size),
        input_size_(input_size),
        cell_size_(cell_size),
        batch_size_(batch_size),
        rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
        input_mode_(input_mode),
        direction_mode_(direction_mode),
        rnn_mode_(rnn_mode),
        data_type_(data_type),
        compute_type_(compute_type),
        algorithm_config_(algorithm_config),
        dropout_desc_(std::move(dropout_desc)),
        params_desc_(std::move(params_desc)) {}

 public:
  CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;

  static port::StatusOr<CudnnRnnDescriptor> Create(
      const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
      int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
      cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
      cudnnDataType_t data_type, cudnnDataType_t compute_type,
      const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
      ScratchAllocator* state_allocator, bool use_padded_io) {
    SE_ASSIGN_OR_RETURN(
        CudnnDropoutDescriptor dropout_desc,
        CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));

    gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
    cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());

    // TODO: allow the user to choose an algorithm.
    auto proj_size = hidden_size;
    hidden_size = std::max(hidden_size, cell_size);

    // Require explicit algorithm config to enable tensor cores. Some configs
    // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
    // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
    // We can only reasonably expect the user to handle the subsequent failure
    // in profile mode, which is run with algorithms returned from
    // GetRnnAlgorithms() (which are non-default and explicitly set whether to
    // use tensor ops). CuDNN 7.2.1 fixed this issue.
    // TODO(csigg): Minimal support cuDNN version is 7.3, clean up.
    bool allow_tensor_ops = data_type == CUDNN_DATA_HALF;
    if (data_type == CUDNN_DATA_FLOAT)
      allow_tensor_ops = tensorflow::tensor_float_32_execution_enabled();
    bool use_tensor_ops =
        algorithm_config.algorithm().has_value()
            ? algorithm_config.algorithm()->tensor_ops_enabled()
            : allow_tensor_ops;
    if (use_tensor_ops && !allow_tensor_ops) {
      return port::Status(port::error::INVALID_ARGUMENT,
                          "Algo requests disallowed tensor op evaluation.");
    }

#if CUDNN_VERSION >= 8000
    cudnnMathType_t math_type =
        use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
#else
    cudnnMathType_t math_type =
        use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
#endif

#if CUDNN_VERSION >= 8000
    cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
    uint32_t aux_flags = 0;
    if (use_padded_io) aux_flags |= CUDNN_RNN_PADDED_IO_ENABLED;
    RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v8(
        /*rnnDesc=*/rnn_desc.get(), /*algo=*/rnn_algo, /*cellMode=*/rnn_mode,
        /*biasMode=*/bias_mode, /*dirMode=*/direction_mode,
        /*inputMode=*/input_mode,
        /*dataType=*/data_type, /*mathPrec=*/compute_type,
        /*mathType=*/math_type,
        /*inputSize=*/input_size,
        /*hiddenSize=*/hidden_size, /*projSize=*/proj_size,
        /*numLayers=*/num_layers,
        /*dropoutDesc=*/dropout_desc.handle(),
        /*auxFlags=*/aux_flags));
#else
    RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
        cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
        /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers,
        /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
        /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
        /*dataType=*/compute_type));
    CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));

    if (proj_size < hidden_size) {
      RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
          cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
          /*recProjSize=*/proj_size, /*outProjSize=*/0));
    }

    // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
    // But in the future if these APIs are used to process full length arrays,
    // we need to distinguish when to set it.
    if (use_padded_io) {
      RETURN_IF_CUDNN_ERROR(
          cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
    }
#endif

    port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
    PersistentRnnPlan rnn_plan;
    if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
      CHECK_GE(batch_size, 0);
      rnn_plan_wrapper =
          CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
      if (!rnn_plan_wrapper.ok()) {
        return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
      } else {
        rnn_plan = rnn_plan_wrapper.ConsumeValueOrDie();
        RETURN_IF_CUDNN_ERROR(
            cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
      }
    }

    // Create the params handle.
    // TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and
    // cudnnRNNForward***Ex are removed from the codebase, since the new API
    // doesn't need param descriptors any more.
    SE_ASSIGN_OR_RETURN(auto params_desc,
                        CudnnRnnParamsDescriptor::Create(
                            cudnn, input_size, data_type, rnn_desc.get(),
                            rnn_mode, direction_mode, num_layers));

    return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
                              num_layers, hidden_size, input_size, cell_size,
                              batch_size, input_mode, direction_mode, rnn_mode,
                              data_type, compute_type, algorithm_config,
                              std::move(dropout_desc), std::move(params_desc));
  }

  cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
  int num_layers() const { return num_layers_; }
  int hidden_size() const { return hidden_size_; }
  int input_size() const { return input_size_; }
  int cell_size() const { return cell_size_; }
  int batch_size() const { return batch_size_; }
  cudnnRNNInputMode_t input_mode() const { return input_mode_; }
  cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
  cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
  cudnnDataType_t data_type() const { return data_type_; }
  cudnnDataType_t compute_type() const { return compute_type_; }
  const dnn::AlgorithmConfig& algorithm_config() const {
    return algorithm_config_;
  }
  int64 ParamsSizeInBytes() const override {
    return params_desc_.params_size_in_bytes();
  }
  cudnnFilterDescriptor_t params_handle() const {
    return params_desc_.handle();
  }
  ParamsRegions ParamsWeightRegions() const override {
    return params_desc_.params_weights();
  }
  ParamsRegions ParamsBiasRegions() const override {
    return params_desc_.params_biases();
  }

 private:
  gpu::RnnDescriptor rnn_desc_;
  PersistentRnnPlan rnn_plan_;
  int num_layers_;
  int hidden_size_;
  int input_size_;
  // cell_size_ is the size of cell state, which will be different from
  // hidden_size_ if the projection is used.
  int cell_size_;
  // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
  // algorithm.
  int batch_size_;
  cudnnRNNAlgo_t rnn_algo_;
  cudnnRNNInputMode_t input_mode_;
  cudnnDirectionMode_t direction_mode_;
  cudnnRNNMode_t rnn_mode_;
  cudnnDataType_t data_type_;
  cudnnDataType_t compute_type_;
  dnn::AlgorithmConfig algorithm_config_;
  CudnnDropoutDescriptor dropout_desc_;
  CudnnRnnParamsDescriptor params_desc_;
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
};

#if CUDNN_VERSION >= 7603
class CudnnCtcLossDescriptor {
 public:
  explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
      : handle_(CreateCtcLossDescriptor()) {
    CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
        /*ctcLossDesc=*/handle_.get(),
        /*compType=*/data_type,
        /*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
        /*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
  }

  cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }

 private:
  CtcLossDescriptor handle_;  // Owned

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
};
#else
// dummy class
class CudnnCtcLossDescriptor {
 public:
  CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
};
#endif

namespace {

// Checks whether the LSTM projection is used. If so, an additional weight
// matrix (the projection matrix) is appended to 'weights'. Otherwise, nothing
// is done.
port::Status CheckAndFetchProjectionWeights(
    const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
    const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
    const FilterDescriptor& region_desc_handle,
    dnn::RnnDescriptor::ParamsRegions* weights) {
1394 int hidden_size_v;
1395 int num_layers_v;
1396 cudnnDropoutDescriptor_t dropout_desc;
1397 cudnnRNNInputMode_t input_mode;
1398 cudnnDirectionMode_t direction;
1399 cudnnRNNMode_t mode;
1400 cudnnRNNAlgo_t algo;
1401 cudnnDataType_t data_type;
1402 #if CUDNN_VERSION >= 8000
1403 RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
1404 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1405 /*hiddenSize=*/&hidden_size_v,
1406 /*numLayers=*/&num_layers_v,
1407 /*dropoutDesc=*/&dropout_desc,
1408 /*inputMode=*/&input_mode,
1409 /*direction=*/&direction,
1410 /*mode=*/&mode,
1411 /*algo=*/&algo,
1412 /*mathPrec=*/&data_type));
1413 #else
1414 RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
1415 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1416 /*hiddenSize=*/&hidden_size_v,
1417 /*numLayers=*/&num_layers_v,
1418 /*dropoutDesc=*/&dropout_desc,
1419 /*inputMode=*/&input_mode,
1420 /*direction=*/&direction,
1421 /*mode=*/&mode,
1422 /*algo=*/&algo,
1423 /*mathPrec=*/&data_type));
1424 #endif
1425 int rec_proj_size_v;
1426 int out_proj_size_v;
1427 RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
1428 /*handle=*/cudnn.handle(),
1429 /*rnnDesc=*/rnn_desc,
1430 /*recProjSize*/ &rec_proj_size_v,
1431 /*outProjSize*/ &out_proj_size_v));
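// A recurrent projection is present exactly when the recurrent projection
// size reported by cuDNN differs from the hidden size. In that case the
// projection matrix is exposed as an extra linear-layer region; id 8 follows
// the eight LSTM gate matrices (ids 0-7) in cuDNN's parameter layout.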
1432 if (rec_proj_size_v != hidden_size_v) {
1433 void* offset = nullptr;
1434 int region_id = 8;
1435 RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
1436 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1437 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1438 /*wDesc=*/filter_desc.get(),
1439 /*w=*/nullptr, /*linLayerID=*/region_id,
1440 /*linLayerMatDesc=*/region_desc_handle.get(),
1441 /*linLayerMat or linLayerBias=*/&offset));
1442 int dims[] = {1, 1, 1};
1443 cudnnDataType_t data_type;
1444 cudnnTensorFormat_t tensor_format;
1445 int n_dims;
1446 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1447 /*filterDesc=*/region_desc_handle.get(),
1448 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1449 /*dataType=*/&data_type, /*format=*/&tensor_format,
1450 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1451 int64_t size =
1452 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1453 dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset),
1454 size};
1455 weights->push_back(region);
1456 }
1457 return port::Status::OK();
1458 }
1459
1460 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1461 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1462 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1463 cudnnDirectionMode_t direction_mode, int num_layers) {
1464 // Query the params size.
1465 TensorDescriptor input_desc = CreateTensorDescriptor();
1466 int tensor_dims[] = {1, input_size, 1};
1467 int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1468 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1469 /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1470 /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1471 /*dimA=*/tensor_dims,
1472 /*strideA=*/strides));
1473
1474 size_t params_size = 0;
1475 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1476 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1477 /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size,
1478 /*dataType=*/data_type));
1479 int64_t params_size_in_bytes = static_cast<int64>(params_size);
1480
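// The entire params buffer is described as a flat NCHW filter with one
// element per weight/bias value; cuDNN then reports each linear layer as an
// offset into this buffer.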
1481 FilterDescriptor filter_desc = CreateFilterDescriptor();
1482 int64_t filter_dim0 =
1483 params_size_in_bytes / CudnnDataTypeToByteSize(data_type);
1484 int filter_dims[] = {static_cast<int>(filter_dim0), 1, 1};
1485 RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1486 /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1487 /*format=*/CUDNN_TENSOR_NCHW,
1488 /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1489 /*filterDimA=*/filter_dims));
1490
1491 // Determine the number of linear-layer regions (weights + biases) per layer.
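// The region counts below follow cuDNN's parameter layout: a vanilla RNN
// cell has one input and one recurrent matrix (2 regions), an LSTM has four
// gates with an input and a recurrent matrix each (8), and a GRU has three
// gates (6).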
1492 int region_count_per_layer = [&] {
1493 switch (rnn_mode) {
1494 case CUDNN_RNN_RELU:
1495 case CUDNN_RNN_TANH:
1496 return 2;
1497 case CUDNN_LSTM:
1498 return 8;
1499 case CUDNN_GRU:
1500 return 6;
1501 default:
1502 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1503 return 0;
1504 }
1505 }();
1506
1507 FilterDescriptor region_desc_handle = CreateFilterDescriptor();
1508 const int layer_count =
1509 direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1510
1511 ParamsRegions weights;
1512 ParamsRegions biases;
1513
1514 for (int layer = 0; layer < layer_count; layer++) {
1515 for (int region = 0; region < region_count_per_layer; region++) {
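// For each linear-layer id, query both the weight matrix (type == 0) and the
// bias (type == 1). Since w is passed as nullptr, the returned address is
// simply the byte offset of the region within the params buffer.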
1516 for (int type = 0; type < 2; type++) {
1517 void* offset = nullptr;
1518 RETURN_IF_CUDNN_ERROR(
1519 type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1520 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1521 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1522 /*wDesc=*/filter_desc.get(),
1523 /*w=*/nullptr, /*linLayerID=*/region,
1524 /*linLayerMatDesc=*/region_desc_handle.get(),
1525 /*linLayerMat or linLayerBias=*/&offset)
1526 : cudnnGetRNNLinLayerBiasParams(
1527 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1528 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1529 /*wDesc=*/filter_desc.get(),
1530 /*w=*/nullptr, /*linLayerID=*/region,
1531 /*linLayerMatDesc=*/region_desc_handle.get(),
1532 /*linLayerMat or linLayerBias=*/&offset));
1533 int dims[] = {1, 1, 1};
1534 cudnnDataType_t data_type;
1535 cudnnTensorFormat_t tensor_format;
1536 int n_dims;
1537 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1538 /*filterDesc=*/region_desc_handle.get(),
1539 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1540 /*dataType=*/&data_type, /*format=*/&tensor_format,
1541 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1542 int64_t size =
1543 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1544 dnn::RnnDescriptor::ParamsRegion region = {
1545 reinterpret_cast<int64>(offset), size};
1546 (type == 0 ? weights : biases).push_back(region);
1547 }
1548 }
1549 TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
1550 cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
1551 &weights));
1552 }
1553
1554 return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1555 weights, biases);
1556 }
1557
1558 } // namespace
1559
1560 class CudnnRnnSequenceTensorDescriptor
1561 : public dnn::RnnSequenceTensorDescriptor {
1562 CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1563 int batch_size, int data_size,
1564 cudnnDataType_t data_type,
1565 RNNDataDescriptor data_handle,
1566 TensorDescriptor handle)
1567 : max_seq_length_(max_seq_length),
1568 batch_size_(batch_size),
1569 data_size_(data_size),
1570 data_type_(data_type),
1571 handle_(std::move(handle)),
1572 rnn_data_handle_(std::move(data_handle)),
1573 handles_(max_seq_length, handle_.get()) {}
1574
1575 public:
1576 CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1577 default;
1578
1579 static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1580 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1581 cudnnDataType_t data_type) {
1582 if (max_seq_length <= 0) {
1583 return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1584 }
1585 int dims[] = {batch_size, data_size, 1};
1586 int strides[] = {dims[1] * dims[2], dims[2], 1};
1587 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1588 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1589 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1590 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1591 /*strideA=*/strides));
1592 return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1593 data_size, data_type, nullptr,
1594 std::move(tensor_desc));
1595 }
1596
1597 static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1598 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1599 const absl::Span<const int>& seq_lengths, bool time_major,
1600 cudnnDataType_t data_type) {
1601 if (max_seq_length <= 0) {
1602 return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1603 }
1604 int dims[] = {batch_size, data_size, 1};
1605 int strides[] = {dims[1] * dims[2], dims[2], 1};
1606 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1607 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1608 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1609 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1610 /*strideA=*/strides));
1611 const int* seq_lengths_array = seq_lengths.data();
1612 RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1613 float padding_fill = 0.0f;
1614 cudnnRNNDataLayout_t layout;
1615 if (time_major) {
1616 layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1617 } else {
1618 layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1619 }
1620 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1621 /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1622 /*layout=*/layout,
1623 /*maxSeqLength=*/max_seq_length,
1624 /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1625 /*seqLengthArray=*/seq_lengths_array,
1626 /*paddingFill*/ (void*)&padding_fill));
1627 return CudnnRnnSequenceTensorDescriptor(
1628 parent, max_seq_length, batch_size, data_size, data_type,
1629 std::move(data_desc), std::move(tensor_desc));
1630 }
1631
1632 const cudnnTensorDescriptor_t* handles() const { return handles_.data(); }
1633 const cudnnRNNDataDescriptor_t data_handle() const {
1634 return rnn_data_handle_.get();
1635 }
1636
1637 int max_seq_length() const { return max_seq_length_; }
1638 int batch_size() const { return batch_size_; }
1639 int data_size() const { return data_size_; }
1640 bool is_var_seq_lengths() const { return rnn_data_handle_ != nullptr; }
1641
1642 private:
1643 int max_seq_length_;
1644 int batch_size_;
1645 int data_size_;
1646 cudnnDataType_t data_type_;
1647 TensorDescriptor handle_;
1648 RNNDataDescriptor rnn_data_handle_;
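// The legacy cudnnRNNForward*/cudnnRNNBackward* APIs take one tensor
// descriptor per time step, so the same descriptor is replicated
// max_seq_length times in handles_.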
1649 std::vector<cudnnTensorDescriptor_t> handles_; // Copies of handle_.
1650 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1651 };
1652
1653 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1654 public:
1655 CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1656 int batch_size, int data_size,
1657 cudnnDataType_t data_type)
1658 : handle_(CreateTensorDescriptor()),
1659 num_layers_(num_layers),
1660 batch_size_(batch_size),
1661 data_size_(data_size),
1662 data_type_(data_type) {
1663 int dims[] = {num_layers, batch_size, data_size};
1664 int strides[] = {dims[1] * dims[2], dims[2], 1};
1665 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1666 /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1667 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1668 /*strideA=*/strides));
1669 }
1670
1671 cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1672
1673 int num_layers() const { return num_layers_; }
1674 int batch_size() const { return batch_size_; }
1675 int data_size() const { return data_size_; }
1676
1677 private:
1678 TensorDescriptor handle_;
1679 int num_layers_;
1680 int batch_size_;
1681 int data_size_;
1682 cudnnDataType_t data_type_;
1683 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1684 };
1685
1686 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1687 class CudnnConvolveExecutionPlan : public dnn::ConvolveExecutionPlan {
1688 public:
1689 CudnnConvolveExecutionPlan(cudnn_frontend::ExecutionPlan plan)
1690 : plan_(std::move(plan)) {}
1691 std::string getTag() override { return plan_.getTag(); }
1692 void* get_raw_desc() override { return plan_.get_raw_desc(); }
1693 int64_t getWorkspaceSize() override { return plan_.getWorkspaceSize(); }
1694
1695 private:
1696 cudnn_frontend::ExecutionPlan plan_;
1697 SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolveExecutionPlan);
1698 };
1699 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1700
1701 namespace {
1702
1703 struct RnnModelDims {
1704 int num_layers = 0;
1705 int batch_size = 0;
1706 int max_seq_length = 0;
1707 int hidden_size = 0;
1708 int input_size = 0;
1709 int cell_size = 0;
1710 int dir_count = 0;
1711 };
1712
1713 template <class T>
1714 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1715 const CudnnRnnDescriptor& rnn_desc,
1716 const CudnnRnnSequenceTensorDescriptor& input_desc,
1717 const DeviceMemory<T>& input_data,
1718 const CudnnRnnStateTensorDescriptor& input_h_desc,
1719 const DeviceMemory<T>& input_h_data,
1720 const CudnnRnnStateTensorDescriptor& input_c_desc,
1721 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1722 const CudnnRnnSequenceTensorDescriptor& output_desc,
1723 const DeviceMemory<T>& output_data,
1724 const CudnnRnnStateTensorDescriptor& output_h_desc,
1725 const DeviceMemory<T>& output_h_data,
1726 const CudnnRnnStateTensorDescriptor& output_c_desc,
1727 const DeviceMemory<T>& output_c_data) {
1728 // Extract the model dimensions.
1729 RnnModelDims model_dims;
1730 model_dims.num_layers = rnn_desc.num_layers();
1731 model_dims.batch_size = input_desc.batch_size();
1732 model_dims.max_seq_length = input_desc.max_seq_length();
1733 model_dims.hidden_size = rnn_desc.hidden_size();
1734 model_dims.input_size = input_desc.data_size();
1735 model_dims.cell_size = rnn_desc.cell_size();
1736 model_dims.dir_count =
1737 (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1738
1739 // Check that the tensor descriptors match the model dimensions.
1740 if (!(input_h_desc.num_layers() ==
1741 model_dims.num_layers * model_dims.dir_count &&
1742 input_h_desc.batch_size() == model_dims.batch_size &&
1743 input_h_desc.data_size() == model_dims.hidden_size)) {
1744 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1745 }
1746 // When the LSTM projection is used, input_h_desc.data_size() is smaller than
1747 // input_c_desc.data_size(), so only '<=' is checked here.
1748 if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1749 input_h_desc.batch_size() == input_c_desc.batch_size() &&
1750 input_h_desc.data_size() <= input_c_desc.data_size())) {
1751 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1752 }
1753 if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1754 output_desc.batch_size() == model_dims.batch_size &&
1755 output_desc.data_size() ==
1756 model_dims.hidden_size * model_dims.dir_count)) {
1757 return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1758 }
1759 if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1760 input_h_desc.batch_size() == output_h_desc.batch_size() &&
1761 input_h_desc.data_size() == output_h_desc.data_size())) {
1762 return port::Status(port::error::INVALID_ARGUMENT,
1763 "Invalid output_h shape");
1764 }
1765 if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1766 input_h_desc.batch_size() == output_c_desc.batch_size() &&
1767 input_h_desc.data_size() <= output_c_desc.data_size())) {
1768 return port::Status(port::error::INVALID_ARGUMENT,
1769 "Invalid output_c shape");
1770 }
1771
1772 return model_dims;
1773 }
1774
1775 port::Status CheckRNNParameterSize(
1776 const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1777 const CudnnRnnSequenceTensorDescriptor& input_desc) {
1778 size_t params_size_in_bytes = 0;
1779 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1780 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize(
1781 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1782 /*sizeInBytes=*/¶ms_size_in_bytes));
1783 #else
1784 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1785 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1786 /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes,
1787 /*dataType=*/rnn_desc.data_type()));
1788 #endif
1789 if (static_cast<int64>(params_size_in_bytes) !=
1790 rnn_desc.ParamsSizeInBytes()) {
1791 return port::Status(port::error::INVALID_ARGUMENT,
1792 "Mismatching RNN parameter size");
1793 }
1794 return port::Status::OK();
1795 }
1796
1797 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1798 Stream* stream, const CudnnHandle& cudnn,
1799 const CudnnRnnDescriptor& rnn_desc,
1800 const CudnnRnnSequenceTensorDescriptor& input_desc,
1801 ScratchAllocator* workspace_allocator) {
1802 // Query the workspace size.
1803 size_t workspace_size_in_bytes = 0;
1804 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1805 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1806 /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1807 /*sizeInBytes=*/&workspace_size_in_bytes));
1808 // Allocate the workspace.
1809 if (workspace_size_in_bytes == 0) {
1810 return DeviceMemory<uint8>();
1811 }
1812 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1813 }
1814
1815 #if CUDNN_VERSION >= 7402
1816 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
1817 Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1818 const cudnnBatchNormOps_t& bn_ops,
1819 const cudnnActivationDescriptor_t& activation_desc,
1820 const CudnnTensorDescriptor& x_descriptor,
1821 const CudnnTensorDescriptor& scale_offset_descriptor,
1822 ScratchAllocator* workspace_allocator) {
1823 // Query the workspace size.
1824 size_t workspace_size_in_bytes = 0;
1825 RETURN_IF_CUDNN_ERROR(
1826 cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
1827 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1828 /*xDesc=*/x_descriptor.handle(), /*zDesc=*/x_descriptor.handle(),
1829 /*yDesc=*/x_descriptor.handle(),
1830 /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
1831 /*activationDesc=*/activation_desc,
1832 /*sizeInBytes=*/&workspace_size_in_bytes));
1833 // Allocate the workspace.
1834 if (workspace_size_in_bytes == 0) {
1835 return DeviceMemory<uint8>();
1836 }
1837 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1838 }
1839
1840 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
1841 Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1842 const cudnnBatchNormOps_t& bn_ops,
1843 const CudnnTensorDescriptor& x_descriptor,
1844 const CudnnTensorDescriptor& scale_offset_descriptor,
1845 ScratchAllocator* workspace_allocator) {
1846 // Query the workspace size.
1847 size_t workspace_size_in_bytes = 0;
1848 RETURN_IF_CUDNN_ERROR(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
1849 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1850 /*xDesc=*/x_descriptor.handle(),
1851 /*yDesc=*/x_descriptor.handle(),
1852 /*dyDesc=*/x_descriptor.handle(),
1853 /*dzDesc=*/nullptr,
1854 /*dxDesc=*/x_descriptor.handle(),
1855 /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
1856 /*activationDesc=*/nullptr, /*sizeInBytes=*/&workspace_size_in_bytes));
1857 // Allocate the workspace.
1858 if (workspace_size_in_bytes == 0) {
1859 return DeviceMemory<uint8>();
1860 }
1861 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1862 }
1863
1864 #endif
1865
1866 } // namespace
1867
1868 template <class T>
1869 port::Status CudnnSupport::DoRnnForwardImpl(
1870 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1871 const CudnnRnnSequenceTensorDescriptor& input_desc,
1872 const DeviceMemory<T>& input_data,
1873 const DeviceMemory<int>& seq_lengths_data,
1874 const CudnnRnnStateTensorDescriptor& input_h_desc,
1875 const DeviceMemory<T>& input_h_data,
1876 const CudnnRnnStateTensorDescriptor& input_c_desc,
1877 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1878 const CudnnRnnSequenceTensorDescriptor& output_desc,
1879 DeviceMemory<T>* output_data,
1880 const CudnnRnnStateTensorDescriptor& output_h_desc,
1881 DeviceMemory<T>* output_h_data,
1882 const CudnnRnnStateTensorDescriptor& output_c_desc,
1883 DeviceMemory<T>* output_c_data, bool is_training,
1884 ScratchAllocator* reserve_space_allocator,
1885 ScratchAllocator* workspace_allocator,
1886 dnn::ProfileResult* output_profile_result) {
1887 SE_ASSIGN_OR_RETURN(
1888 RnnModelDims model_dims,
1889 ExtractAndCheckRnnForward(
1890 rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1891 input_c_desc, input_c_data, params, output_desc, *output_data,
1892 output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1893
1894 auto cudnn = cudnn_->GetHandle(parent_, stream);
1895
1896 SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1897
1898 // In cuDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex functions
1899 // were deprecated. We use cudnnRNNForward instead, which requires the
1900 // sequence_lengths parameter. For more info, see
1901 // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802.
1902 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1903 if (input_desc.is_var_seq_lengths()) {
1904 DeviceMemory<uint8> workspace;
1905 DeviceMemory<uint8> reserve_space;
1906 cudnnForwardMode_t rnn_fwd_mode;
1907 if (is_training) {
1908 rnn_fwd_mode = CUDNN_FWD_MODE_TRAINING;
1909 } else {
1910 rnn_fwd_mode = CUDNN_FWD_MODE_INFERENCE;
1911 }
1912 size_t reserve_space_size_in_bytes = 0;
1913 size_t workspace_size_in_bytes = 0;
1914 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
1915 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1916 /*fMode=*/rnn_fwd_mode, /*xDesc=*/input_desc.data_handle(),
1917 /*workSpaceSize=*/&workspace_size_in_bytes,
1918 /*reserveSpaceSize=*/&reserve_space_size_in_bytes));
1919
1920 if (workspace_size_in_bytes > 0) {
1921 SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
1922 workspace_size_in_bytes));
1923 }
1924 if (reserve_space_size_in_bytes > 0) {
1925 SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1926 reserve_space_size_in_bytes));
1927 }
1928
1929 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1930 const bool is_profiling = output_profile_result != nullptr;
1931 if (is_profiling) {
1932 timer.reset(new GpuTimer(parent_));
1933 // Start and stop the timer as close to the cuDNN call as possible; other
1934 // threads may still enqueue work onto this stream, so several profiling
1935 // measurements may be needed to obtain a stable reading.
1936 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1937 return port::Status(port::error::INTERNAL, "Failed to start timer");
1938 }
1939 }
1940
1941 RETURN_IF_CUDNN_ERROR(cudnnRNNForward(
1942 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1943 /*fwdMode=*/rnn_fwd_mode,
1944 /*devSeqLengths=*/
1945 reinterpret_cast<const int*>(seq_lengths_data.opaque()),
1946 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1947 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data->opaque(),
1948 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1949 /*hy=*/output_h_data->opaque(),
1950 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1951 /*cy=*/output_c_data->opaque(),
1952 /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
1953 /*weightSpace=*/params.opaque(),
1954 /*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(),
1955 /*reserveSpaceSizeInBytes=*/reserve_space.size(),
1956 /*reserveSpace=*/reserve_space.opaque()));
1957
1958 if (is_profiling) {
1959 if (!timer->Stop(AsGpuStream(stream))) {
1960 return port::Status(port::error::INTERNAL, "Failed to stop timer");
1961 }
1962 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1963 output_profile_result->set_algorithm(algo_desc);
1964 output_profile_result->set_elapsed_time_in_ms(
1965 timer->GetElapsedMilliseconds());
1966 }
1967 return port::Status::OK();
1968 }
1969 #endif
1970 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1971 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1972 workspace_allocator))
1973
1974 // Query and allocate the reserve space. It is only needed when training,
1975 // since the backward pass reads intermediate results from it.
1976 DeviceMemory<uint8> reserve_space;
1977 if (is_training) {
1978 size_t reserve_space_size_in_bytes = 0;
1979 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
1980 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1981 /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
1982 /*sizeInBytes=*/&reserve_space_size_in_bytes));
1983
1984 if (reserve_space_size_in_bytes > 0) {
1985 SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1986 reserve_space_size_in_bytes));
1987 }
1988 }
1989
1990 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1991 const bool is_profiling = output_profile_result != nullptr;
1992 if (is_profiling) {
1993 timer.reset(new GpuTimer(parent_));
1994 // Start and stop the timer as close to the cuDNN call as possible; other
1995 // threads may still enqueue work onto this stream, so several profiling
1996 // measurements may be needed to obtain a stable reading.
1997 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1998 return port::Status(port::error::INTERNAL, "Failed to start timer");
1999 }
2000 }
2001
2002 if (!is_training) {
2003 if (input_desc.is_var_seq_lengths()) {
2004 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
2005 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2006 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2007 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2008 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2009 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2010 /*yDesc=*/output_desc.data_handle(),
2011 /*y=*/output_data->opaque(),
2012 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
2013 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
2014 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
2015 nullptr,
2016 /*workspace=*/workspace.opaque(),
2017 /*workSpaceSizeInBytes=*/workspace.size()));
2018 } else {
2019 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
2020 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2021 /*seqLength=*/model_dims.max_seq_length,
2022 /*xDesc=*/input_desc.handles(),
2023 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2024 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
2025 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
2026 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
2027 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
2028 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
2029 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
2030 /*workSpaceSizeInBytes=*/workspace.size()));
2031 }
2032 } else {
2033 if (input_desc.is_var_seq_lengths()) {
2034 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
2035 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2036 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2037 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2038 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2039 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2040 /*yDesc=*/output_desc.data_handle(),
2041 /*y=*/output_data->opaque(),
2042 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
2043 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
2044 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
2045 nullptr,
2046 /*workspace=*/workspace.opaque(),
2047 /*workSpaceSizeInBytes=*/workspace.size(),
2048 /*reserveSpace=*/reserve_space.opaque(),
2049 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
2050 } else {
2051 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
2052 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2053 /*seqLength=*/model_dims.max_seq_length,
2054 /*xDesc=*/input_desc.handles(),
2055 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2056 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
2057 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
2058 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
2059 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
2060 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
2061 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
2062 /*workSpaceSizeInBytes=*/workspace.size(),
2063 /*reserveSpace=*/reserve_space.opaque(),
2064 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
2065 }
2066 }
2067
2068 if (is_profiling) {
2069 if (!timer->Stop(AsGpuStream(stream))) {
2070 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2071 }
2072 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2073 output_profile_result->set_algorithm(algo_desc);
2074 output_profile_result->set_elapsed_time_in_ms(
2075 timer->GetElapsedMilliseconds());
2076 }
2077
2078 return port::Status::OK();
2079 }
2080
2081 template <class T>
2082 port::Status CudnnSupport::DoRnnBackwardImpl(
2083 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
2084 const CudnnRnnSequenceTensorDescriptor& input_desc,
2085 const DeviceMemory<T>& input_data,
2086 const DeviceMemory<int>& seq_lengths_data,
2087 const CudnnRnnStateTensorDescriptor& input_h_desc,
2088 const DeviceMemory<T>& input_h_data,
2089 const CudnnRnnStateTensorDescriptor& input_c_desc,
2090 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2091 const CudnnRnnSequenceTensorDescriptor& output_desc,
2092 const DeviceMemory<T>& output_data,
2093 const CudnnRnnStateTensorDescriptor& output_h_desc,
2094 const DeviceMemory<T>& output_h_data,
2095 const CudnnRnnStateTensorDescriptor& output_c_desc,
2096 const DeviceMemory<T>& output_c_data,
2097 const DeviceMemory<T>& output_backprop_data,
2098 const DeviceMemory<T>& output_h_backprop_data,
2099 const DeviceMemory<T>& output_c_backprop_data,
2100 DeviceMemory<T>* input_backprop_data,
2101 DeviceMemory<T>* input_h_backprop_data,
2102 DeviceMemory<T>* input_c_backprop_data,
2103 DeviceMemory<T>* params_backprop_data,
2104 DeviceMemory<uint8>* reserve_space_data,
2105 ScratchAllocator* workspace_allocator,
2106 dnn::ProfileResult* output_profile_result) {
2107 SE_ASSIGN_OR_RETURN(
2108 RnnModelDims model_dims,
2109 ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
2110 input_h_data, input_c_desc, input_c_data,
2111 params, output_desc, output_data, output_h_desc,
2112 output_h_data, output_c_desc, output_c_data));
2113
2114 auto cudnn = cudnn_->GetHandle(parent_, stream);
2115
2116 SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
2117
2118 // In cuDNN v8.0, the cudnnRNNBackward*** and cudnnRNNBackward***Ex functions
2119 // were deprecated. We use cudnnRNNBackwardData_v8 and cudnnRNNBackwardWeights_v8
2120 // instead, which require the sequence_lengths parameter. For more info, see
2121 // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802.
2122 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
2123 if (input_desc.is_var_seq_lengths()) {
2124 DeviceMemory<uint8> workspace;
2125 size_t workspace_size_in_bytes = 0;
2126 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
2127 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2128 /*fMode=*/CUDNN_FWD_MODE_TRAINING, /*xDesc=*/input_desc.data_handle(),
2129 /*workSpaceSize=*/&workspace_size_in_bytes,
2130 /*reserveSpaceSize=*/NULL));
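// Only the workspace size is queried here; the reserve space was already
// allocated during the forward pass and is passed in via reserve_space_data.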
2131 if (workspace_size_in_bytes > 0) {
2132 SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
2133 workspace_size_in_bytes));
2134 }
2135
2136 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2137 const bool is_profiling = output_profile_result != nullptr;
2138 if (is_profiling) {
2139 timer.reset(new GpuTimer(parent_));
2140 // Start and stop the timer as close to the cuDNN call as possible; other
2141 // threads may still enqueue work onto this stream, so several profiling
2142 // measurements may be needed to obtain a stable reading.
2143 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2144 return port::Status(port::error::INTERNAL, "Failed to start timer");
2145 }
2146 }
2147
2148 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8(
2149 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2150 /*devSeqLengths=*/
2151 reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2152 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
2153 /*dy=*/output_backprop_data.opaque(),
2154 /*xDesc=*/input_desc.data_handle(),
2155 /*dx=*/input_backprop_data->opaque(),
2156 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2157 /*dhy=*/output_h_backprop_data.opaque(),
2158 /*dhx=*/input_h_backprop_data->opaque(),
2159 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2160 /*dcy=*/output_c_backprop_data.opaque(),
2161 /*dcx=*/input_c_backprop_data->opaque(),
2162 /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2163 /*weightSpace=*/params.opaque(),
2164 /*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(),
2165 /*reserveSpaceSize=*/reserve_space_data->size(),
2166 /*reserveSpace=*/reserve_space_data->opaque()));
2167
2168 if (params_backprop_data != nullptr) {
2169 // Zero dw first; cudnnRNNBackwardWeights_v8 is called with CUDNN_WGRAD_MODE_ADD and accumulates into it.
2170 stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2171 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8(
2172 /*handle=*/cudnn.handle(),
2173 /*rnnDesc=*/rnn_desc.handle(),
2174 /*addGrad=*/CUDNN_WGRAD_MODE_ADD,
2175 /*devSeqLengths=*/
2176 reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2177 /*xDesc=*/input_desc.data_handle(),
2178 /*x=*/input_data.opaque(),
2179 /*hDesc=*/input_h_desc.handle(),
2180 /*hx=*/input_h_data.opaque(),
2181 /*yDesc=*/output_desc.data_handle(),
2182 /*y=*/output_data.opaque(),
2183 /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2184 /*dweightSpace=*/params_backprop_data->opaque(),
2185 /*workSpaceSize=*/workspace.size(),
2186 /*workSpace=*/workspace.opaque(),
2187 /*reserveSpaceSize=*/reserve_space_data->size(),
2188 /*reserveSpace=*/reserve_space_data->opaque()));
2189 }
2190
2191 if (is_profiling) {
2192 if (!timer->Stop(AsGpuStream(stream))) {
2193 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2194 }
2195 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2196 output_profile_result->set_algorithm(algo_desc);
2197 output_profile_result->set_elapsed_time_in_ms(
2198 timer->GetElapsedMilliseconds());
2199 }
2200 return port::Status::OK();
2201 }
2202 #endif
2203 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
2204 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
2205 workspace_allocator));
2206
2207 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2208 const bool is_profiling = output_profile_result != nullptr;
2209 if (is_profiling) {
2210 timer.reset(new GpuTimer(parent_));
2211 // Start and stop the timer as close to the cuDNN call as possible; other
2212 // threads may still enqueue work onto this stream, so several profiling
2213 // measurements may be needed to obtain a stable reading.
2214 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2215 return port::Status(port::error::INTERNAL, "Failed to start timer");
2216 }
2217 }
2218
2219 if (input_desc.is_var_seq_lengths()) {
2220 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
2221 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2222 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
2223 /*dyDesc=*/output_desc.data_handle(),
2224 /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
2225 /*dhyDesc=*/output_h_desc.handle(),
2226 /*dhy=*/output_h_backprop_data.opaque(),
2227 /*dcyDesc=*/output_c_desc.handle(),
2228 /*dcy=*/output_c_backprop_data.opaque(),
2229 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2230 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2231 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2232 /*dxDesc=*/input_desc.data_handle(),
2233 /*dx=*/input_backprop_data->opaque(),
2234 /*dhxDesc=*/input_h_desc.handle(),
2235 /*dhx=*/input_h_backprop_data->opaque(),
2236 /*dcxDesc=*/input_c_desc.handle(),
2237 /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
2238 /*workspace=*/workspace.opaque(),
2239 /*workSpaceSizeInBytes=*/workspace.size(),
2240 /*reserveSpace=*/reserve_space_data->opaque(),
2241 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2242 } else {
2243 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
2244 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2245 /*seqLength=*/model_dims.max_seq_length,
2246 /*yDesc=*/output_desc.handles(),
2247 /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
2248 /*dy=*/output_backprop_data.opaque(),
2249 /*dhyDesc=*/output_h_desc.handle(),
2250 /*dhy=*/output_h_backprop_data.opaque(),
2251 /*dcyDesc=*/output_c_desc.handle(),
2252 /*dcy=*/output_c_backprop_data.opaque(),
2253 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2254 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2255 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2256 /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
2257 /*dhxDesc=*/input_h_desc.handle(),
2258 /*dhx=*/input_h_backprop_data->opaque(),
2259 /*dcxDesc=*/input_c_desc.handle(),
2260 /*dcx=*/input_c_backprop_data->opaque(),
2261 /*workspace=*/workspace.opaque(),
2262 /*workSpaceSizeInBytes=*/workspace.size(),
2263 /*reserveSpace=*/reserve_space_data->opaque(),
2264 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2265 }
2266
2267 if (params_backprop_data != nullptr) {
2268 // Zero dw first; cudnnRNNBackwardWeights accumulates the weight gradients into it.
2269 stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2270 if (input_desc.is_var_seq_lengths()) {
2271 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
2272 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2273 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2274 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2275 /*yDesc=*/output_desc.data_handle(),
2276 /*y=*/output_data.opaque(),
2277 /*workspace=*/workspace.opaque(),
2278 /*workSpaceSizeInBytes=*/workspace.size(),
2279 /*dwDesc=*/rnn_desc.params_handle(),
2280 /*dw=*/params_backprop_data->opaque(),
2281 /*reserveSpace=*/reserve_space_data->opaque(),
2282 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2283 } else {
2284 // Make the backward weights call.
2285 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
2286 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2287 /*seqLength=*/model_dims.max_seq_length,
2288 /*xDesc=*/input_desc.handles(),
2289 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2290 /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
2291 /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
2292 /*workSpaceSizeInBytes=*/workspace.size(),
2293 /*dwDesc=*/rnn_desc.params_handle(),
2294 /*dw=*/params_backprop_data->opaque(),
2295 /*reserveSpace=*/reserve_space_data->opaque(),
2296 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2297 }
2298 }
2299
2300 if (is_profiling) {
2301 if (!timer->Stop(AsGpuStream(stream))) {
2302 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2303 }
2304 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2305 output_profile_result->set_algorithm(algo_desc);
2306 output_profile_result->set_elapsed_time_in_ms(
2307 timer->GetElapsedMilliseconds());
2308 }
2309
2310 return port::Status::OK();
2311 }
2312
2313 port::Status CudnnSupport::DoCtcLossImpl(
2314 Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
2315 const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2316 absl::Span<const int> labels_lengths_data,
2317 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2318 const CudnnRnnStateTensorDescriptor& grads_desc,
2319 DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
2320 DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2321 auto cudnn = cudnn_->GetHandle(parent_, stream);
2322
2323 int kNumTimestamps = probs_desc.num_layers();
2324 int kBatchSize = probs_desc.batch_size();
2325 int kNumLabels = probs_desc.data_size();
2326 int total_size = kNumLabels * kNumTimestamps * kBatchSize;
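// total_size is currently unused; the void cast below suppresses
// unused-variable warnings.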
2327 (void)total_size;
2328
2329 #if CUDNN_VERSION >= 7603
2330 cudnnCTCLossAlgo_t ctc_loss_algo =
2331 static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
2332 RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
2333 /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
2334 /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
2335 /*labelLengths=*/labels_lengths_data.data(),
2336 /*inputLengths=*/input_lengths_data.data(),
2337 /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
2338 /*gradients=*/grads_data.opaque(),
2339 /*algo=*/ctc_loss_algo,
2340 /*ctcLossDesc=*/ctc_loss_desc.handle(),
2341 /*workspace=*/scratch_memory.opaque(),
2342 /*workSpaceSizeInBytes=*/scratch_memory.size()));
2343 #else
2344 return port::Status(port::error::INVALID_ARGUMENT,
2345 "cudnnCTCLoss is not supported when "
2346 "CUDNN_VERSION < 7.6.3");
2347 #endif
2348
2349 return port::Status::OK();
2350 }
2351
2352 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2353 CudnnSupport::createRnnDescriptor(
2354 int num_layers, int hidden_size, int input_size, int cell_size,
2355 int batch_size, dnn::RnnInputMode input_mode,
2356 dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2357 dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2358 float dropout, uint64 seed, ScratchAllocator* state_allocator,
2359 bool use_padded_io) {
2360 // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
2361 // not enqueueing anything into a stream, we pass in the null stream.
2362 auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
2363 SE_ASSIGN_OR_RETURN(
2364 CudnnRnnDescriptor rnn_desc,
2365 CudnnRnnDescriptor::Create(
2366 cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
2367 ToCudnnRnnInputMode(input_mode),
2368 ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
2369 ToCudnnDataType(data_type), GetRnnComputeType(data_type),
2370 algorithm_config, dropout, seed, state_allocator, use_padded_io));
2371 return std::unique_ptr<dnn::RnnDescriptor>(
2372 new CudnnRnnDescriptor(std::move(rnn_desc)));
2373 }
2374
2375 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2376 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
2377 int batch_size, int data_size,
2378 dnn::DataType data_type) {
2379 SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2380 CudnnRnnSequenceTensorDescriptor::Create(
2381 parent_, max_seq_length, batch_size, data_size,
2382 ToCudnnDataType(data_type)));
2383 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2384 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2385 }
2386
2387 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2388 CudnnSupport::createRnnSequenceTensorDescriptor(
2389 int max_seq_length, int batch_size, int data_size,
2390 const absl::Span<const int>& seq_lengths, bool time_major,
2391 dnn::DataType data_type) {
2392 SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2393 CudnnRnnSequenceTensorDescriptor::Create(
2394 parent_, max_seq_length, batch_size, data_size,
2395 seq_lengths, time_major, ToCudnnDataType(data_type)));
2396 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2397 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2398 }
2399
2400 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2401 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2402 int data_size,
2403 dnn::DataType data_type) {
2404 return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
2405 new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
2406 data_size, ToCudnnDataType(data_type)));
2407 }
2408
2409 bool CudnnSupport::DoRnnForward(
2410 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2411 const dnn::RnnSequenceTensorDescriptor& input_desc,
2412 const DeviceMemory<Eigen::half>& input_data,
2413 const DeviceMemory<int>& seq_lengths_data,
2414 const dnn::RnnStateTensorDescriptor& input_h_desc,
2415 const DeviceMemory<Eigen::half>& input_h_data,
2416 const dnn::RnnStateTensorDescriptor& input_c_desc,
2417 const DeviceMemory<Eigen::half>& input_c_data,
2418 const DeviceMemory<Eigen::half>& params,
2419 const dnn::RnnSequenceTensorDescriptor& output_desc,
2420 DeviceMemory<Eigen::half>* output_data,
2421 const dnn::RnnStateTensorDescriptor& output_h_desc,
2422 DeviceMemory<Eigen::half>* output_h_data,
2423 const dnn::RnnStateTensorDescriptor& output_c_desc,
2424 DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2425 ScratchAllocator* reserve_space_allocator,
2426 ScratchAllocator* workspace_allocator,
2427 dnn::ProfileResult* output_profile_result) {
2428 const CudnnRnnDescriptor& cudnn_rnn_desc =
2429 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2430 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2431 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2432 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2433 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2434 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2435 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2436 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2437 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2438 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2439 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2440 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2441 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2442 return IsStatusOk(
2443 DoRnnForwardImpl<Eigen::half>(
2444 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2445 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2446 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2447 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2448 output_c_data, is_training, reserve_space_allocator,
2449 workspace_allocator, output_profile_result),
2450 /*report_error=*/!output_profile_result);
2451 }
2452
2453 bool CudnnSupport::DoRnnForward(
2454 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2455 const dnn::RnnSequenceTensorDescriptor& input_desc,
2456 const DeviceMemory<float>& input_data,
2457 const DeviceMemory<int>& seq_lengths_data,
2458 const dnn::RnnStateTensorDescriptor& input_h_desc,
2459 const DeviceMemory<float>& input_h_data,
2460 const dnn::RnnStateTensorDescriptor& input_c_desc,
2461 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2462 const dnn::RnnSequenceTensorDescriptor& output_desc,
2463 DeviceMemory<float>* output_data,
2464 const dnn::RnnStateTensorDescriptor& output_h_desc,
2465 DeviceMemory<float>* output_h_data,
2466 const dnn::RnnStateTensorDescriptor& output_c_desc,
2467 DeviceMemory<float>* output_c_data, bool is_training,
2468 ScratchAllocator* reserve_space_allocator,
2469 ScratchAllocator* workspace_allocator,
2470 dnn::ProfileResult* output_profile_result) {
2471 const CudnnRnnDescriptor& cudnn_rnn_desc =
2472 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2473 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2474 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2475 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2476 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2477 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2478 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2479 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2480 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2481 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2482 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2483 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2484 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2485 return IsStatusOk(
2486 DoRnnForwardImpl<float>(
2487 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2488 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2489 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2490 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2491 output_c_data, is_training, reserve_space_allocator,
2492 workspace_allocator, output_profile_result),
2493 /*report_error=*/!output_profile_result);
2494 }
2495
2496 bool CudnnSupport::DoRnnForward(
2497 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2498 const dnn::RnnSequenceTensorDescriptor& input_desc,
2499 const DeviceMemory<double>& input_data,
2500 const DeviceMemory<int>& seq_lengths_data,
2501 const dnn::RnnStateTensorDescriptor& input_h_desc,
2502 const DeviceMemory<double>& input_h_data,
2503 const dnn::RnnStateTensorDescriptor& input_c_desc,
2504 const DeviceMemory<double>& input_c_data,
2505 const DeviceMemory<double>& params,
2506 const dnn::RnnSequenceTensorDescriptor& output_desc,
2507 DeviceMemory<double>* output_data,
2508 const dnn::RnnStateTensorDescriptor& output_h_desc,
2509 DeviceMemory<double>* output_h_data,
2510 const dnn::RnnStateTensorDescriptor& output_c_desc,
2511 DeviceMemory<double>* output_c_data, bool is_training,
2512 ScratchAllocator* reserve_space_allocator,
2513 ScratchAllocator* workspace_allocator,
2514 dnn::ProfileResult* output_profile_result) {
2515 const CudnnRnnDescriptor& cudnn_rnn_desc =
2516 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2517 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2518 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2519 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2520 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2521 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2522 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2523 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2524 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2525 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2526 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2527 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2528 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2529 return IsStatusOk(
2530 DoRnnForwardImpl<double>(
2531 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2532 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2533 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2534 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2535 output_c_data, is_training, reserve_space_allocator,
2536 workspace_allocator, output_profile_result),
2537 /*report_error=*/!output_profile_result);
2538 }
2539
2540 bool CudnnSupport::DoRnnBackward(
2541 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2542 const dnn::RnnSequenceTensorDescriptor& input_desc,
2543 const DeviceMemory<Eigen::half>& input_data,
2544 const DeviceMemory<int>& seq_lengths_data,
2545 const dnn::RnnStateTensorDescriptor& input_h_desc,
2546 const DeviceMemory<Eigen::half>& input_h_data,
2547 const dnn::RnnStateTensorDescriptor& input_c_desc,
2548 const DeviceMemory<Eigen::half>& input_c_data,
2549 const DeviceMemory<Eigen::half>& params,
2550 const dnn::RnnSequenceTensorDescriptor& output_desc,
2551 const DeviceMemory<Eigen::half>& output_data,
2552 const dnn::RnnStateTensorDescriptor& output_h_desc,
2553 const DeviceMemory<Eigen::half>& output_h_data,
2554 const dnn::RnnStateTensorDescriptor& output_c_desc,
2555 const DeviceMemory<Eigen::half>& output_c_data,
2556 const DeviceMemory<Eigen::half>& output_backprop_data,
2557 const DeviceMemory<Eigen::half>& output_h_backprop_data,
2558 const DeviceMemory<Eigen::half>& output_c_backprop_data,
2559 DeviceMemory<Eigen::half>* input_backprop_data,
2560 DeviceMemory<Eigen::half>* input_h_backprop_data,
2561 DeviceMemory<Eigen::half>* input_c_backprop_data,
2562 DeviceMemory<Eigen::half>* params_backprop_data,
2563 DeviceMemory<uint8>* reserve_space_data,
2564 ScratchAllocator* workspace_allocator,
2565 dnn::ProfileResult* output_profile_result) {
2566 const CudnnRnnDescriptor& cudnn_rnn_desc =
2567 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2568 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2569 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2570 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2571 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2572 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2573 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2574 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2575 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2576 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2577 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2578 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2579 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2580 return IsStatusOk(
2581 DoRnnBackwardImpl<Eigen::half>(
2582 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2583 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2584 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2585 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2586 output_c_data, output_backprop_data, output_h_backprop_data,
2587 output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2588 input_c_backprop_data, params_backprop_data, reserve_space_data,
2589 workspace_allocator, output_profile_result),
2590 /*report_error=*/!output_profile_result);
2591 }
2592
2593 bool CudnnSupport::DoRnnBackward(
2594 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2595 const dnn::RnnSequenceTensorDescriptor& input_desc,
2596 const DeviceMemory<float>& input_data,
2597 const DeviceMemory<int>& seq_lengths_data,
2598 const dnn::RnnStateTensorDescriptor& input_h_desc,
2599 const DeviceMemory<float>& input_h_data,
2600 const dnn::RnnStateTensorDescriptor& input_c_desc,
2601 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2602 const dnn::RnnSequenceTensorDescriptor& output_desc,
2603 const DeviceMemory<float>& output_data,
2604 const dnn::RnnStateTensorDescriptor& output_h_desc,
2605 const DeviceMemory<float>& output_h_data,
2606 const dnn::RnnStateTensorDescriptor& output_c_desc,
2607 const DeviceMemory<float>& output_c_data,
2608 const DeviceMemory<float>& output_backprop_data,
2609 const DeviceMemory<float>& output_h_backprop_data,
2610 const DeviceMemory<float>& output_c_backprop_data,
2611 DeviceMemory<float>* input_backprop_data,
2612 DeviceMemory<float>* input_h_backprop_data,
2613 DeviceMemory<float>* input_c_backprop_data,
2614 DeviceMemory<float>* params_backprop_data,
2615 DeviceMemory<uint8>* reserve_space_data,
2616 ScratchAllocator* workspace_allocator,
2617 dnn::ProfileResult* output_profile_result) {
2618 const CudnnRnnDescriptor& cudnn_rnn_desc =
2619 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2620 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2621 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2622 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2623 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2624 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2625 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2626 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2627 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2628 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2629 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2630 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2631 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2632 return IsStatusOk(
2633 DoRnnBackwardImpl<float>(
2634 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2635 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2636 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2637 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2638 output_c_data, output_backprop_data, output_h_backprop_data,
2639 output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2640 input_c_backprop_data, params_backprop_data, reserve_space_data,
2641 workspace_allocator, output_profile_result),
2642 /*report_error=*/!output_profile_result);
2643 }
2644
2645 bool CudnnSupport::DoRnnBackward(
2646 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2647 const dnn::RnnSequenceTensorDescriptor& input_desc,
2648 const DeviceMemory<double>& input_data,
2649 const DeviceMemory<int>& seq_lengths_data,
2650 const dnn::RnnStateTensorDescriptor& input_h_desc,
2651 const DeviceMemory<double>& input_h_data,
2652 const dnn::RnnStateTensorDescriptor& input_c_desc,
2653 const DeviceMemory<double>& input_c_data,
2654 const DeviceMemory<double>& params,
2655 const dnn::RnnSequenceTensorDescriptor& output_desc,
2656 const DeviceMemory<double>& output_data,
2657 const dnn::RnnStateTensorDescriptor& output_h_desc,
2658 const DeviceMemory<double>& output_h_data,
2659 const dnn::RnnStateTensorDescriptor& output_c_desc,
2660 const DeviceMemory<double>& output_c_data,
2661 const DeviceMemory<double>& output_backprop_data,
2662 const DeviceMemory<double>& output_h_backprop_data,
2663 const DeviceMemory<double>& output_c_backprop_data,
2664 DeviceMemory<double>* input_backprop_data,
2665 DeviceMemory<double>* input_h_backprop_data,
2666 DeviceMemory<double>* input_c_backprop_data,
2667 DeviceMemory<double>* params_backprop_data,
2668 DeviceMemory<uint8>* reserve_space_data,
2669 ScratchAllocator* workspace_allocator,
2670 dnn::ProfileResult* output_profile_result) {
2671 const CudnnRnnDescriptor& cudnn_rnn_desc =
2672 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2673 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2674 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2675 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2676 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2677 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2678 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2679 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2680 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2681 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2682 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2683 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2684 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2685 return IsStatusOk(
2686 DoRnnBackwardImpl<double>(
2687 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2688 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2689 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2690 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2691 output_c_data, output_backprop_data, output_h_backprop_data,
2692 output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2693 input_c_backprop_data, params_backprop_data, reserve_space_data,
2694 workspace_allocator, output_profile_result),
2695 /*report_error=*/!output_profile_result);
2696 }
2697
2698 namespace {
2699
2700 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2701 // and backward filter.
2702
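// Each of the three GetCudnnConvolution*Algo helpers below asks cuDNN's
// heuristics for candidate algorithms. With cuDNN 8+ this uses the
// cudnnGetConvolution*Algorithm_v7 API and returns the first candidate that
// succeeded, is not the WINOGRAD_NONFUSED algorithm, and fits within the
// workspace limit (0 bytes when no workspace limit was specified). With older
// cuDNN versions it falls back to the legacy preference-based query.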
2703 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2704 const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2705 const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2706 const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2707 size_t memory_limit_bytes) {
2708 #if CUDNN_VERSION >= 8000
2709 const int num_requested_algos = 5;
2710 int num_returned_algos = 0;
2711 cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos];
2712
2713 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
2714 cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2715 output_nd.handle(), num_requested_algos, &num_returned_algos,
2716 perf_results));
2717
2718 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2719 for (int r = 0; r < num_returned_algos; r++) {
2720 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2721 perf_results[r].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2722 perf_results[r].memory <= mem_limit) {
2723 return perf_results[r].algo;
2724 }
2725 }
2726 return port::Status(port::error::INTERNAL,
2727 "cudnnGetConvolutionForwardAlgorithm_v7 returned "
2728 "no suitable algorithms. This could be a cudnn bug.");
2729 #else
2730 cudnnConvolutionFwdPreference_t preference =
2731 specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2732 : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2733 cudnnConvolutionFwdAlgo_t algo_to_use;
2734 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2735 cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2736 output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2737 return algo_to_use;
2738 #endif
2739 }
2740
2741 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2742 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2743 const CudnnTensorDescriptor& input_nd,
2744 const CudnnFilterDescriptor& filter,
2745 const CudnnConvolutionDescriptor& conv,
2746 const CudnnTensorDescriptor& output_nd,
2747 bool specify_workspace_limit,
2748 size_t memory_limit_bytes) {
2749 #if CUDNN_VERSION >= 8000
2750 const int num_requested_algos = 5;
2751 int num_returned_algos = 0;
2752 cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos];
2753
2754 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(
2755 cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2756 input_nd.handle(), num_requested_algos, &num_returned_algos,
2757 perf_results));
2758
2759 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2760 for (int r = 0; r < num_returned_algos; r++) {
2761 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2762 perf_results[r].algo !=
2763 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED &&
2764 perf_results[r].memory <= mem_limit) {
2765 return perf_results[r].algo;
2766 }
2767 }
2768 return port::Status(port::error::INTERNAL,
2769 "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned "
2770 "no suitable algorithms. This could be a cudnn bug.");
2771 #else
2772 cudnnConvolutionBwdDataPreference_t preference =
2773 specify_workspace_limit
2774 ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2775 : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2776 cudnnConvolutionBwdDataAlgo_t algo_to_use;
2777 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2778 cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2779 input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2780 return algo_to_use;
2781 #endif
2782 }
2783
2784 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2785 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2786 const CudnnTensorDescriptor& input_nd,
2787 const CudnnFilterDescriptor& filter,
2788 const CudnnConvolutionDescriptor& conv,
2789 const CudnnTensorDescriptor& output_nd,
2790 bool specify_workspace_limit,
2791 size_t memory_limit_bytes) {
2792 #if CUDNN_VERSION >= 8000
2793 const int num_requested_algos = 5;
2794 int num_returned_algos = 0;
2795 cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos];
2796
2797 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
2798 cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2799 filter.handle(), num_requested_algos, &num_returned_algos, perf_results));
2800
2801 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2802 for (int r = 0; r < num_returned_algos; r++) {
2803 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2804 perf_results[r].algo !=
2805 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED &&
2806 perf_results[r].memory <= mem_limit) {
2807 return perf_results[r].algo;
2808 }
2809 }
2810 return port::Status(port::error::INTERNAL,
2811 "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned "
2812 "no suitable algorithms. This could be a cudnn bug.");
2813 #else
2814 cudnnConvolutionBwdFilterPreference_t preference =
2815 specify_workspace_limit
2816 ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2817 : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2818 cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2819 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2820 cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2821 filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2822 return algo_to_use;
2823 #endif
2824 }
2825
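// The Allocate*Workspace helpers below first check that the tensor-op setting
// baked into the convolution descriptor matches the chosen algorithm, then
// query cuDNN for the workspace size that algorithm needs and allocate it from
// the scratch allocator. A zero-sized requirement yields an empty
// DeviceMemory<uint8> without touching the allocator.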
2826 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2827 Stream* stream, const CudnnHandle& cudnn,
2828 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2829 const CudnnConvolutionDescriptor& conv,
2830 const CudnnTensorDescriptor& output_nd,
2831 const dnn::AlgorithmDesc& algorithm_desc,
2832 ScratchAllocator* scratch_allocator) {
2833 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2834 return port::Status(
2835 port::error::INTERNAL,
2836 "Mismatch between cudnn conv and algorithm descriptors.");
2837 }
2838
2839 // Query the size of the workspace and allocate it.
2840 size_t size_in_bytes;
2841 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2842 cudnn.handle(),
2843 /*xDesc=*/input_nd.handle(),
2844 /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2845 /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
2846 /*sizeInBytes=*/&size_in_bytes));
2847
2848 int64_t size_in_bytes_int64 = size_in_bytes;
2849
2850 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2851 return port::Status(
2852 port::error::INTERNAL,
2853 "cudnnGetConvolutionForwardWorkspaceSize() returned "
2854 "negative sizeInBytes value. This could be a cudnn bug.");
2855 }
2856
2857 if (size_in_bytes_int64 == 0) {
2858 return DeviceMemory<uint8>();
2859 }
2860
2861 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2862 return port::Status(port::error::INVALID_ARGUMENT,
2863 "No scratch allocator provided");
2864 }
2865
2866 return scratch_allocator->AllocateBytes(size_in_bytes);
2867 }
2868
2869 port::StatusOr<DeviceMemory<uint8>>
2870 AllocateCudnnConvolutionBackwardDataWorkspace(
2871 Stream* stream, const CudnnHandle& cudnn,
2872 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2873 const CudnnConvolutionDescriptor& conv,
2874 const CudnnTensorDescriptor& output_nd,
2875 const dnn::AlgorithmDesc& algorithm_desc,
2876 ScratchAllocator* scratch_allocator) {
2877 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2878 return port::Status(
2879 port::error::INTERNAL,
2880 "Mismatch between cudnn conv and algorithm descriptors.");
2881 }
2882
2883 // Query the size of the workspace and allocate it.
2884 size_t size_in_bytes;
2885 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2886 cudnn.handle(),
2887 /*wDesc=*/filter.handle(),
2888 /*dyDesc=*/output_nd.handle(),
2889 /*convDesc=*/conv.handle(),
2890 /*dxDesc=*/input_nd.handle(),
2891 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2892 /*sizeInBytes=*/&size_in_bytes));
2893
2894 int64_t size_in_bytes_int64 = size_in_bytes;
2895
2896 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2897 return port::Status(
2898 port::error::INTERNAL,
2899 "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2900 "negative sizeInBytes value. This could be a cudnn bug.");
2901 }
2902
2903 if (size_in_bytes_int64 == 0) {
2904 return DeviceMemory<uint8>();
2905 }
2906
2907 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2908 return port::Status(port::error::INVALID_ARGUMENT,
2909 "No scratch allocator provided");
2910 }
2911
2912 return scratch_allocator->AllocateBytes(size_in_bytes);
2913 }
2914
2915 port::StatusOr<DeviceMemory<uint8>>
2916 AllocateCudnnConvolutionBackwardFilterWorkspace(
2917 Stream* stream, const CudnnHandle& cudnn,
2918 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2919 const CudnnConvolutionDescriptor& conv,
2920 const CudnnTensorDescriptor& output_nd,
2921 const dnn::AlgorithmDesc& algorithm_desc,
2922 ScratchAllocator* scratch_allocator) {
2923 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2924 return port::Status(
2925 port::error::INTERNAL,
2926 "Mismatch between cudnn conv and algorithm descriptors.");
2927 }
2928
2929 // Query the size of the workspace and allocate it.
2930 size_t size_in_bytes;
2931 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
2932 cudnn.handle(),
2933 /*xDesc=*/input_nd.handle(),
2934 /*dyDesc=*/output_nd.handle(),
2935 /*convDesc=*/conv.handle(),
2936 /*gradDesc=*/filter.handle(),
2937 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2938 /*sizeInBytes=*/&size_in_bytes));
2939
2940 int64_t size_in_bytes_int64 = size_in_bytes;
2941
2942 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2943 return port::Status(
2944 port::error::INTERNAL,
2945 "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
2946 "negative sizeInBytes value. This could be a cudnn bug.");
2947 }
2948
2949 if (size_in_bytes_int64 == 0) {
2950 return DeviceMemory<uint8>();
2951 }
2952
2953 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2954 return port::Status(port::error::INVALID_ARGUMENT,
2955 "No scratch allocator provided");
2956 }
2957
2958 return scratch_allocator->AllocateBytes(size_in_bytes);
2959 }
2960
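// Decides whether tensor-op (Tensor Core) math should be used. An explicitly
// requested algorithm wins, but is rejected if it asks for tensor ops while
// they are disabled for this stream/data type; without an explicit algorithm,
// the global tensor-math policy for the data type applies. Illustrative use,
// mirroring the call sites below:
//
//   SE_ASSIGN_OR_RETURN(bool use_tensor_ops,
//                       UseTensorOps(stream, element_type, algo_desc));
//   conv.set_use_tensor_op_math(use_tensor_ops);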
2961 port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
2962 absl::optional<dnn::AlgorithmDesc> desc) {
2963 bool use_tensor_ops;
2964 if (desc.has_value()) {
2965 use_tensor_ops = desc->tensor_ops_enabled();
2966 if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) {
2967 return port::Status(port::error::INVALID_ARGUMENT,
2968 "Algo requests disabled tensor op evaluation.");
2969 }
2970 } else {
2971 use_tensor_ops = IsTensorMathEnabled(stream, type);
2972 }
2973 return use_tensor_ops;
2974 }
2975
2976 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
2977 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type);
2978
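// The GetCudnnConvolution*Algorithm wrappers below resolve the algorithm to
// run: use the one recorded in the AlgorithmConfig if present, otherwise ask
// the heuristics above, then try to allocate that algorithm's workspace. If
// the allocation fails they retry with algorithm_no_scratch(), and give up
// with an error when no fallback is configured.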
2979 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
2980 Stream* stream, const CudnnHandle& cudnn,
2981 const dnn::AlgorithmConfig& algorithm_config,
2982 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2983 dnn::DataType element_type,
2984 const dnn::ConvolutionDescriptor& convolution_descriptor,
2985 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2986 DeviceMemory<uint8>* scratch) {
2987 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2988
2989 CudnnConvolutionDescriptor conv(
2990 convolution_descriptor,
2991 ToCudnnDataType(GetConvAccumulatorType(element_type)));
2992 bool use_tensor_ops;
2993 SE_ASSIGN_OR_RETURN(use_tensor_ops,
2994 UseTensorOps(stream, element_type, algo_desc));
2995 conv.set_use_tensor_op_math(use_tensor_ops);
2996
2997 if (!algo_desc.has_value()) {
2998 // Pick fastest algorithm within memory limit according to cuDNN's
2999 // heuristics.
3000 bool specify_workspace_limit = scratch_allocator != nullptr;
3001 auto memory_limit_bytes =
3002 specify_workspace_limit
3003 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
3004 : int64{0};
3005 SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
3006 GetCudnnConvolutionForwardAlgo(
3007 cudnn, input_nd, filter, conv, output_nd,
3008 specify_workspace_limit, memory_limit_bytes));
3009 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3010 }
3011
3012 const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
3013 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3014 scratch_allocator);
3015
3016 if (scratch_or.ok()) {
3017 *scratch = scratch_or.ValueOrDie();
3018 return *algo_desc;
3019 }
3020
3021 algo_desc = algorithm_config.algorithm_no_scratch();
3022
3023 // Failed to allocate workspace for the first algorithm, fall back to the
3024 // no_scratch algorithm.
3025 if (!algo_desc.has_value()) {
3026 return port::Status(
3027 scratch_or.status().code(),
3028 absl::StrCat("The primary convolution algorithm failed, ",
3029 "while a secondary algorithm is not provided. ",
3030 "Returned status: ", scratch_or.status().ToString()));
3031 }
3032
3033 SE_ASSIGN_OR_RETURN(use_tensor_ops,
3034 UseTensorOps(stream, element_type, algo_desc));
3035 conv.set_use_tensor_op_math(use_tensor_ops);
3036 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
3037 stream, cudnn, input_nd, filter, conv,
3038 output_nd, *algo_desc, scratch_allocator));
3039 return *algo_desc;
3040 }
3041
3042 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
3043 Stream* stream, const CudnnHandle& cudnn,
3044 const dnn::AlgorithmConfig& algorithm_config,
3045 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3046 dnn::DataType element_type,
3047 const dnn::ConvolutionDescriptor& convolution_descriptor,
3048 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3049 DeviceMemory<uint8>* scratch) {
3050 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3051 CudnnConvolutionDescriptor conv(
3052 convolution_descriptor,
3053 ToCudnnDataType(GetConvAccumulatorType(element_type)));
3054 bool use_tensor_ops;
3055 SE_ASSIGN_OR_RETURN(use_tensor_ops,
3056 UseTensorOps(stream, element_type, algo_desc));
3057 conv.set_use_tensor_op_math(use_tensor_ops);
3058
3059 if (!algo_desc.has_value()) {
3060 // Pick fastest algorithm within memory limit according to cuDNN's
3061 // heuristics.
3062 bool specify_workspace_limit = scratch_allocator != nullptr;
3063 auto memory_limit_bytes =
3064 specify_workspace_limit
3065 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
3066 : int64{0};
3067 SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
3068 GetCudnnConvolutionBackwardDataAlgo(
3069 cudnn, input_nd, filter, conv, output_nd,
3070 specify_workspace_limit, memory_limit_bytes));
3071 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3072 }
3073
3074 const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
3075 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3076 scratch_allocator);
3077
3078 if (scratch_or.ok()) {
3079 *scratch = scratch_or.ValueOrDie();
3080 return *algo_desc;
3081 }
3082
3083 algo_desc = algorithm_config.algorithm_no_scratch();
3084
3085 // Failed to allocate workspace for the first algorithm, fall back to the
3086 // no_scratch algorithm.
3087 if (!algo_desc.has_value()) {
3088 return port::Status(
3089 port::error::INVALID_ARGUMENT,
3090 "The primary convolution algorithm failed memory allocation, "
3091 "while a secondary algorithm is not provided.");
3092 }
3093
3094 SE_ASSIGN_OR_RETURN(use_tensor_ops,
3095 UseTensorOps(stream, element_type, algo_desc));
3096 conv.set_use_tensor_op_math(use_tensor_ops);
3097 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
3098 stream, cudnn, input_nd, filter, conv,
3099 output_nd, *algo_desc, scratch_allocator));
3100 return *algo_desc;
3101 }
3102
3103 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
3104 Stream* stream, const CudnnHandle& cudnn,
3105 const dnn::AlgorithmConfig& algorithm_config,
3106 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3107 dnn::DataType element_type,
3108 const dnn::ConvolutionDescriptor& convolution_descriptor,
3109 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3110 DeviceMemory<uint8>* scratch) {
3111 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3112 CudnnConvolutionDescriptor conv(
3113 convolution_descriptor,
3114 ToCudnnDataType(GetConvAccumulatorType(element_type)));
3115 bool use_tensor_ops;
3116 SE_ASSIGN_OR_RETURN(use_tensor_ops,
3117 UseTensorOps(stream, element_type, algo_desc));
3118 conv.set_use_tensor_op_math(use_tensor_ops);
3119
3120 if (!algo_desc.has_value()) {
3121 // Pick fastest algorithm within memory limit according to cuDNN's
3122 // heuristics.
3123 bool specify_workspace_limit = scratch_allocator != nullptr;
3124 auto memory_limit_bytes =
3125 specify_workspace_limit
3126 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
3127 : int64{0};
3128 SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
3129 GetCudnnConvolutionBackwardFilterAlgo(
3130 cudnn, input_nd, filter, conv, output_nd,
3131 specify_workspace_limit, memory_limit_bytes));
3132 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3133 }
3134
3135 port::StatusOr<DeviceMemory<uint8>> scratch_or =
3136 AllocateCudnnConvolutionBackwardFilterWorkspace(
3137 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3138 scratch_allocator);
3139
3140 if (scratch_or.ok()) {
3141 *scratch = scratch_or.ValueOrDie();
3142 return *algo_desc;
3143 }
3144
3145 algo_desc = algorithm_config.algorithm_no_scratch();
3146
3147 // Failed to allocate workspace for the first algorithm, fall back to the
3148 // no_scratch algorithm.
3149 if (!algo_desc.has_value()) {
3150 return port::Status(
3151 port::error::INVALID_ARGUMENT,
3152 absl::StrCat(
3153 "The primary convolution algorithm failed memory allocation, "
3154 "while a secondary algorithm is not provided. Actual error: ",
3155 scratch_or.status().ToString()));
3156 }
3157
3158 SE_ASSIGN_OR_RETURN(use_tensor_ops,
3159 UseTensorOps(stream, element_type, algo_desc));
3160 conv.set_use_tensor_op_math(use_tensor_ops);
3161 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
3162 stream, cudnn, input_nd, filter, conv,
3163 output_nd, *algo_desc, scratch_allocator));
3164 return *algo_desc;
3165 }
3166
3167 // A helper class to set env-vars and choose options for cudnn-related
3168 // algorithms.
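// Illustrative use (mirrors the call sites below), where the flag struct
// supplies the env-var name and the compiled-in default:
//
//   if (CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()) { ... }
//
// This returns false only when TF_FP16_CONV_USE_FP32_COMPUTE is set to "0";
// any other value enables the flag, and an unset variable falls back to
// kDefaultFlag.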
3169 template <typename EnvVar>
3170 class CudnnEnvVar {
3171 public:
3172 static bool IsEnabled() {
3173 static bool is_enabled = IsEnabledImpl();
3174 return is_enabled;
3175 }
3176
3177 private:
3178 static bool IsEnabledImpl() {
3179 const char* tf_env_var_val = getenv(EnvVar::kName);
3180 if (tf_env_var_val != nullptr) {
3181 absl::string_view tf_env_var_val_str(tf_env_var_val);
3182 if (tf_env_var_val_str == "0") {
3183 return false;
3184 }
3185 return true;
3186 }
3187 return EnvVar::kDefaultFlag;
3188 }
3189 };
3190
3191 // A helper struct to decide whether to enable the FFT_TILING algorithms for
3192 // forward convolution. It is disabled for cuDNN < 7 due to memory corruption
3193 // caused by some shapes with this algorithm. Users can explicitly enable the
3194 // algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
3195 struct FftTilingForward {
3196 static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
3197 static constexpr bool kDefaultFlag = true;
3198 };
3199
3200 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
3201 // By default it is turned on, users can explicitly disable them through an
3202 // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
3203 // https://github.com/tensorflow/tensorflow/pull/4901
3204 // For CUDNN v8.1, when this env-var is turned off, both the winograd and
3205 // winograd-non-fused engines will be ruled out.
3206 struct WinogradNonfused {
3207 static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
3208 // NVIDIA has fixed the winograd nonfused bug for cudnn v>=7. For older
3209 // versions, we have a workaround.
3210 static constexpr bool kDefaultFlag = true;
3211 };
3212
3213 // A helper struct to decide whether to use FP32 as the internal compute type
3214 // for convolution when the input data type is FP16. By default it is turned on,
3215 // users can explicitly disable them (choose to use FP16 as the internal compute
3216 // type) through an env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
3217 struct ConvDoFP32ComputationFP16Input {
3218 static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
3219 // Using FP16 as the internal compute type for convolution when the input data
3220 // type is FP16 is only supported on architectures with true fp16 support
3221 // (compute capability 5.3 and 6.0). Setting this to false in an unsupported
3222 // architecture will cause internal errors.
3223 static constexpr bool kDefaultFlag = true;
3224 };
3225
3226 // A helper struct to decide whether to use FP32 as the internal compute type
3227 // for RNNs when the input data type is FP16. It defaults to on for
3228 // cuDNN >= 7.5 (see kDefaultFlag below); users can explicitly control it
3229 // through the env-var TF_FP16_RNN_USE_FP32_COMPUTE.
3230 // After the TODO below is fixed, users should almost always use fp32 compute
3231 // type for training. Using fp16 might suffer suboptimal accuracy due to loss
3232 // in precision.
3233 struct RnnDoFP32ComputationFP16Input {
3234 static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
3235 // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
3236 // Before cudnn 7.1.4 RNNs are always done in fp32, no matter what math
3237 // precision is set.
3238 // Set it temporarily to false so that no error is raised when using fp16
3239 // inputs with fp32 math precision.
3240 //
3241 // cuDNN == 7.5.0 is verified to have this fixed.
3242 static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500;
3243 };
3244
3245 namespace {
3246
3247 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3248
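// Returns true when the given engine config should be filtered out: engines
// that down-convert their inputs are always rejected, and the three flags
// additionally reject engines tagged with the winograd, non-deterministic, or
// tensor-core numerical notes.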
3249 bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config,
3250 bool disable_winograd, bool disable_nondeterminism,
3251 bool disable_tensor_core) {
3252 bool ret = cudnn_frontend::hasNumericalNote<
3253 CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(engine_config);
3254
3255 if (disable_winograd) {
3256 ret |= cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_WINOGRAD>(
3257 engine_config);
3258 }
3259
3260 if (disable_nondeterminism) {
3261 ret |=
3262 cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(
3263 engine_config);
3264 }
3265
3266 if (disable_tensor_core) {
3267 ret |= cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(
3268 engine_config);
3269 }
3270
3271 return ret;
3272 }
3273
3274 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3275
3276 } // namespace
3277
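// Maps the RNN input data type to the cuDNN compute type: fp16 inputs are
// computed in fp32 unless TF_FP16_RNN_USE_FP32_COMPUTE=0 (see
// RnnDoFP32ComputationFP16Input above).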
3278 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
3279 switch (data_type) {
3280 case dnn::DataType::kFloat:
3281 return CUDNN_DATA_FLOAT;
3282 case dnn::DataType::kDouble:
3283 return CUDNN_DATA_DOUBLE;
3284 case dnn::DataType::kHalf:
3285 if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
3286 return CUDNN_DATA_FLOAT;
3287 } else {
3288 return CUDNN_DATA_HALF;
3289 }
3290 default:
3291 LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
3292 }
3293 }
3294
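// Accumulator type used for a convolution with the given input type: fp32 for
// fp16 (and, with cuDNN >= 8.2, bf16) inputs unless
// TF_FP16_CONV_USE_FP32_COMPUTE=0, int32 for int8/int32 inputs, and the input
// type itself for fp32/fp64.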
3295 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
3296 switch (data_type) {
3297 case dnn::DataType::kFloat:
3298 case dnn::DataType::kDouble:
3299 return data_type;
3300 case dnn::DataType::kHalf:
3301 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3302 ? dnn::DataType::kFloat
3303 : dnn::DataType::kHalf;
3304 case dnn::DataType::kInt8:
3305 case dnn::DataType::kInt32:
3306 return dnn::DataType::kInt32;
3307 #if CUDNN_VERSION >= 8200
3308 case dnn::DataType::kBF16:
3309 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3310 ? dnn::DataType::kFloat
3311 : dnn::DataType::kBF16;
3312 #endif
3313 default:
3314 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
3315 }
3316 }
3317
3318 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3319 cudnnBackendDescriptorType_t GetCudnnConvolutionType(
3320 dnn::ConvolutionKind kind) {
3321 cudnnBackendDescriptorType_t conv_mode;
3322 switch (kind) {
3323 case dnn::ConvolutionKind::FORWARD: {
3324 conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
3325 break;
3326 }
3327 case dnn::ConvolutionKind::BACKWARD_DATA: {
3328 conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR;
3329 break;
3330 }
3331 case dnn::ConvolutionKind::BACKWARD_FILTER: {
3332 conv_mode =
3333 CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR;
3334 break;
3335 }
3336 default:
3337 LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3338 break;
3339 }
3340 return conv_mode;
3341 }
3342
3343 // Cudnn only supports vectorization over the channel dimension (e.g., int8x4,
3344 // or int8x32).
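// For example, an int8 tensor in kBatchDepthYX4 layout yields {4, 1} (vectors
// of 4 along dimension 1, the channel dimension), kBatchDepthYX32 yields
// {32, 1}, and every other type/layout combination yields {1, -1}.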
3345 std::tuple<int, int> GetTensorVectorSizeAndDim(
3346 const dnn::BatchDescriptor& tensor, dnn::DataType element_type) {
3347 int vector_size = 1;
3348 int vector_dim = -1;
3349 if (element_type == dnn::DataType::kInt8) {
3350 if (tensor.layout() == dnn::DataLayout::kBatchDepthYX4) {
3351 vector_size = 4;
3352 vector_dim = 1;
3353 } else if (tensor.layout() == dnn::DataLayout::kBatchDepthYX32) {
3354 vector_size = 32;
3355 vector_dim = 1;
3356 }
3357 }
3358 return std::make_tuple(vector_size, vector_dim);
3359 }
3360
3361 std::tuple<int, int> GetTensorVectorSizeAndDim(
3362 const dnn::FilterDescriptor& filter, dnn::DataType element_type) {
3363 int vector_size = 1;
3364 int vector_dim = -1;
3365 if (element_type == dnn::DataType::kInt8) {
3366 if (filter.layout() == dnn::FilterLayout::kOutputInputYX4) {
3367 vector_size = 4;
3368 vector_dim = 1;
3369 } else if (filter.layout() == dnn::FilterLayout::kOutputInputYX32) {
3370 vector_size = 32;
3371 vector_dim = 1;
3372 }
3373 }
3374 return std::make_tuple(vector_size, vector_dim);
3375 }
3376
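// Builds a cudnn_frontend operation graph containing a single convolution op
// (forward, backward-data, or backward-filter) over tensors tagged with the
// one-character UIDs 'x', 'w', and 'y'; the same UIDs identify which device
// buffers get bound to each tensor when the graph is executed.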
3377 port::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
3378 GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType element_type,
3379 Stream* stream,
3380 const dnn::BatchDescriptor& input_descriptor,
3381 const dnn::FilterDescriptor& filter_descriptor,
3382 const dnn::BatchDescriptor& output_descriptor,
3383 const dnn::ConvolutionDescriptor& convolution_descriptor,
3384 CudnnHandle& cudnn) {
3385 cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3386 cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
3387
3388 // x tensor.
3389 int vector_size, vector_dim;
3390 std::tie(vector_size, vector_dim) =
3391 GetTensorVectorSizeAndDim(input_descriptor, element_type);
3392 std::vector<int64> input_dims = input_descriptor.vectorized_dims(
3393 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3394 std::vector<int64> input_strides = input_descriptor.vectorized_strides(
3395 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3396
3397 if (vector_size == 32) {
3398 return port::InternalError(
3399 "cuDNN frontend doesn't support int8x32 at the "
3400 "moment.");
3401 }
3402
3403 auto tensor_x = cudnn_frontend::TensorBuilder()
3404 .setDim(input_dims.size(), input_dims.data())
3405 .setStrides(input_dims.size(), input_strides.data())
3406 .setId('x')
3407 .setAlignment(32)
3408 .setDataType(cudnn_type)
3409 .setVectorCountAndDimension(vector_size, vector_dim)
3410 .build();
3411 RETURN_MSG_IF_CUDNN_ERROR(tensor_x);
3412
3413 // y tensor.
3414 std::tie(vector_size, vector_dim) =
3415 GetTensorVectorSizeAndDim(output_descriptor, element_type);
3416 std::vector<int64> output_dims = output_descriptor.vectorized_dims(
3417 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3418 std::vector<int64> output_strides = output_descriptor.vectorized_strides(
3419 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3420
3421 auto tensor_y = cudnn_frontend::TensorBuilder()
3422 .setDim(output_dims.size(), output_dims.data())
3423 .setStrides(output_dims.size(), output_strides.data())
3424 .setId('y')
3425 .setAlignment(32)
3426 .setDataType(cudnn_type)
3427 .setVectorCountAndDimension(vector_size, vector_dim)
3428 .build();
3429 RETURN_MSG_IF_CUDNN_ERROR(tensor_y);
3430
3431 // w tensor.
3432 std::tie(vector_size, vector_dim) =
3433 GetTensorVectorSizeAndDim(filter_descriptor, element_type);
3434 std::vector<int64> filter_dims = filter_descriptor.vectorized_dims(
3435 dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3436 std::vector<int64> filter_strides = filter_descriptor.vectorized_strides(
3437 dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3438
3439 auto tensor_w = cudnn_frontend::TensorBuilder()
3440 .setDim(filter_dims.size(), filter_dims.data())
3441 .setStrides(filter_dims.size(), filter_strides.data())
3442 .setId('w')
3443 .setAlignment(32)
3444 .setDataType(cudnn_type)
3445 .setVectorCountAndDimension(vector_size, vector_dim)
3446 .build();
3447 RETURN_MSG_IF_CUDNN_ERROR(tensor_w);
3448
3449 // conv_desc.
3450 auto mode = convolution_descriptor.convolution_not_crosscorr()
3451 ? CUDNN_CONVOLUTION
3452 : CUDNN_CROSS_CORRELATION;
3453
3454 int conv_dim = convolution_descriptor.ndims();
3455
3456 auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(element_type));
3457 CHECK_NE(convolution_descriptor.pad_alignment(),
3458 dnn::PadAlignment::kTensorFlowPadding)
3459 << "TensorFlow padding alignment is not supported.";
3460
3461 auto conv_desc =
3462 cudnn_frontend::ConvDescBuilder()
3463 .setComputePrecision(accumulator_type)
3464 .setMathMode(mode)
3465 .setNDims(conv_dim)
3466 .setStrides(conv_dim, convolution_descriptor.strides().data())
3467 .setPrePadding(conv_dim, convolution_descriptor.padding().data())
3468 .setPostPadding(conv_dim, convolution_descriptor.padding().data())
3469 .setDilation(conv_dim, convolution_descriptor.dilations().data())
3470 .build();
3471 RETURN_MSG_IF_CUDNN_ERROR(conv_desc);
3472
3473 double alpha = 1.0;
3474 double beta = 0.0;
3475
3476 // CUDNN Operation
3477 auto op = cudnn_frontend::OperationBuilder(conv_mode)
3478 .setxDesc(tensor_x)
3479 .setyDesc(tensor_y)
3480 .setwDesc(tensor_w)
3481 .setcDesc(conv_desc)
3482 .setAlpha(alpha)
3483 .setBeta(beta)
3484 .build();
3485 RETURN_MSG_IF_CUDNN_ERROR(op);
3486
3487 // CUDNN OperationGraph
3488 std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
3489 auto opGraph = cudnn_frontend::OperationGraphBuilder()
3490 .setHandle(cudnn.handle())
3491 .setOperationGraph(ops.size(), ops.data())
3492 .build();
3493 RETURN_MSG_IF_CUDNN_ERROR(opGraph);
3494
3495 VLOG(4) << "\nTensor_x: " << tensor_x.describe()
3496 << "\nTensor_y: " << tensor_y.describe()
3497 << "\nTensor_w: " << tensor_w.describe()
3498 << "\nConv: " << conv_desc.describe() << "\nOp: " << op.describe()
3499 << "\nOpGraph: " << opGraph.describe();
3500
3501 return std::unique_ptr<cudnn_frontend::OperationGraph>(
3502 new cudnn_frontend::OperationGraph(std::move(opGraph)));
3503 }
3504
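// Like GetCudnnOperationGraph above, but for the fused
// Conv + Add + BiasAdd + Activation pattern. `alpha` scales the convolution
// output, `alpha2` scales the side input `z` feeding the Add, and only the
// kNone and kRelu activation modes are handled here; intermediate results flow
// through the virtual (temporary) tensors 'C', 'A', and 'B'.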
3505 port::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
3506 GetCudnnFusedOperationGraph(
3507 dnn::ConvolutionKind kind, dnn::DataType element_type, double alpha,
3508 double alpha2, Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3509 const dnn::FilterDescriptor& filter_descriptor,
3510 dnn::BatchDescriptor bias_descriptor,
3511 const dnn::BatchDescriptor& output_descriptor,
3512 const dnn::ConvolutionDescriptor& convolution_descriptor,
3513 const dnn::ActivationMode activation_mode, CudnnHandle& cudnn) {
3514 cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3515 cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
3516
3517 // CUDNN fused operation supports the pattern in the form of
3518 // Conv + Add + BiasAdd + Act. Therefore, we need to build a graph of the
3519 // four ops with their input/output tensor edges:
3520 // Conv : input: tensor_x, tensor_w; output: tensor_conv (virtual)
3521 // Add : input: tensor_conv, tensor_z; output: tensor_add (virtual)
3522 // BiasAdd: input: tensor_add, tensor_b; output: tensor_bias (virtual)
3523 // Act : input: tensor_bias; output: tensor_y
3524 int vector_size, vector_dim;
3525 std::tie(vector_size, vector_dim) =
3526 GetTensorVectorSizeAndDim(input_descriptor, element_type);
3527 std::vector<int64> input_dims = input_descriptor.vectorized_dims(
3528 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3529 std::vector<int64> input_strides = input_descriptor.vectorized_strides(
3530 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3531
3532 if (vector_size == 32) {
3533 return port::InternalError(
3534 "cuDNN frontend doesn't support int8x32 at the "
3535 "moment.");
3536 }
3537
3538 auto tensor_x = cudnn_frontend::TensorBuilder()
3539 .setDim(input_dims.size(), input_dims.data())
3540 .setStrides(input_dims.size(), input_strides.data())
3541 .setId('x')
3542 .setAlignment(32)
3543 .setDataType(cudnn_type)
3544 .setVectorCountAndDimension(vector_size, vector_dim)
3545 .build();
3546 RETURN_MSG_IF_CUDNN_ERROR(tensor_x);
3547
3548 std::tie(vector_size, vector_dim) =
3549 GetTensorVectorSizeAndDim(output_descriptor, element_type);
3550 std::vector<int64> output_dims = output_descriptor.vectorized_dims(
3551 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3552 std::vector<int64> output_strides = output_descriptor.vectorized_strides(
3553 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3554 auto tensor_y = cudnn_frontend::TensorBuilder()
3555 .setDim(output_dims.size(), output_dims.data())
3556 .setStrides(output_dims.size(), output_strides.data())
3557 .setId('y')
3558 .setAlignment(32)
3559 .setDataType(cudnn_type)
3560 .setVectorCountAndDimension(vector_size, vector_dim)
3561 .build();
3562 RETURN_MSG_IF_CUDNN_ERROR(tensor_y);
3563
3564 auto tensor_z = cudnn_frontend::TensorBuilder()
3565 .setDim(output_dims.size(), &output_dims[0])
3566 .setStrides(output_dims.size(), &output_strides[0])
3567 .setId('z')
3568 .setAlignment(32)
3569 .setDataType(cudnn_type)
3570 .setVectorCountAndDimension(vector_size, vector_dim)
3571 .build();
3572 RETURN_MSG_IF_CUDNN_ERROR(tensor_z);
3573
3574 std::tie(vector_size, vector_dim) =
3575 GetTensorVectorSizeAndDim(filter_descriptor, element_type);
3576 std::vector<int64> filter_dims = filter_descriptor.vectorized_dims(
3577 dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3578 std::vector<int64> filter_strides = filter_descriptor.vectorized_strides(
3579 dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3580 auto tensor_w = cudnn_frontend::TensorBuilder()
3581 .setDim(filter_dims.size(), filter_dims.data())
3582 .setStrides(filter_dims.size(), filter_strides.data())
3583 .setId('w')
3584 .setAlignment(32)
3585 .setDataType(cudnn_type)
3586 .setVectorCountAndDimension(vector_size, vector_dim)
3587 .build();
3588 RETURN_MSG_IF_CUDNN_ERROR(tensor_w);
3589
3590 // For the purposes of the cudnn graph, say that the bias tensor has the same
3591 // layout as the output tensor. It doesn't actually matter, because bias is a
3592 // 1D array. But we need to get the correct vectorization, otherwise the
3593 // cudnn graph API rejects this tensor.
3594 bias_descriptor.set_layout(output_descriptor.layout());
3595
3596 std::tie(vector_size, vector_dim) =
3597 GetTensorVectorSizeAndDim(bias_descriptor, element_type);
3598 std::vector<int64> bias_dims = bias_descriptor.vectorized_dims(
3599 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3600 std::vector<int64> bias_strides = bias_descriptor.vectorized_strides(
3601 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3602 auto tensor_b = cudnn_frontend::TensorBuilder()
3603 .setDim(bias_dims.size(), bias_dims.data())
3604 .setStrides(bias_dims.size(), bias_strides.data())
3605 .setId('b')
3606 .setAlignment(32)
3607 .setDataType(cudnn_type)
3608 .setVectorCountAndDimension(vector_size, vector_dim)
3609 .build();
3610 RETURN_MSG_IF_CUDNN_ERROR(tensor_b);
3611
3612 std::tie(vector_size, vector_dim) =
3613 GetTensorVectorSizeAndDim(output_descriptor, element_type);
3614 auto tensor_conv = cudnn_frontend::TensorBuilder()
3615 .setDim(output_dims.size(), &output_dims[0])
3616 .setStrides(output_dims.size(), &output_strides[0])
3617 .setVirtual()
3618 .setId('C')
3619 .setAlignment(32)
3620 .setDataType(cudnn_type)
3621 .setVectorCountAndDimension(vector_size, vector_dim)
3622 .build();
3623 RETURN_MSG_IF_CUDNN_ERROR(tensor_conv);
3624
3625 auto tensor_add = cudnn_frontend::TensorBuilder()
3626 .setDim(output_dims.size(), &output_dims[0])
3627 .setStrides(output_dims.size(), &output_strides[0])
3628 .setVirtual()
3629 .setId('A')
3630 .setAlignment(32)
3631 .setDataType(cudnn_type)
3632 .setVectorCountAndDimension(vector_size, vector_dim)
3633 .build();
3634 RETURN_MSG_IF_CUDNN_ERROR(tensor_add);
3635
3636 auto tensor_bias = cudnn_frontend::TensorBuilder()
3637 .setDim(output_dims.size(), &output_dims[0])
3638 .setStrides(output_dims.size(), &output_strides[0])
3639 .setVirtual()
3640 .setId('B')
3641 .setAlignment(32)
3642 .setDataType(cudnn_type)
3643 .setVectorCountAndDimension(vector_size, vector_dim)
3644 .build();
3645 RETURN_MSG_IF_CUDNN_ERROR(tensor_bias);
3646
3647 // conv_desc.
3648 auto mode = convolution_descriptor.convolution_not_crosscorr()
3649 ? CUDNN_CONVOLUTION
3650 : CUDNN_CROSS_CORRELATION;
3651
3652 int conv_dim = convolution_descriptor.ndims();
3653
3654 auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(element_type));
3655 CHECK_NE(convolution_descriptor.pad_alignment(),
3656 dnn::PadAlignment::kTensorFlowPadding)
3657 << "TensorFlow padding alignment is not supported.";
3658
3659 auto conv_desc =
3660 cudnn_frontend::ConvDescBuilder()
3661 .setComputePrecision(accumulator_type)
3662 .setMathMode(mode)
3663 .setNDims(conv_dim)
3664 .setStrides(conv_dim, convolution_descriptor.strides().data())
3665 .setPrePadding(conv_dim, convolution_descriptor.padding().data())
3666 .setPostPadding(conv_dim, convolution_descriptor.padding().data())
3667 .setDilation(conv_dim, convolution_descriptor.dilations().data())
3668 .build();
3669 RETURN_MSG_IF_CUDNN_ERROR(conv_desc);
3670
3671 // Beta is the scaling factor for output.
3672 double beta = 0.0;
3673
3674 // CUDNN Operation
3675 auto conv_op = cudnn_frontend::OperationBuilder(conv_mode)
3676 .setxDesc(tensor_x)
3677 .setyDesc(tensor_conv)
3678 .setwDesc(tensor_w)
3679 .setcDesc(conv_desc)
3680 .setAlpha(alpha)
3681 .setBeta(beta)
3682 .build();
3683 RETURN_MSG_IF_CUDNN_ERROR(conv_op);
3684
3685 auto add_desc = cudnn_frontend::PointWiseDescBuilder()
3686 .setMode(CUDNN_POINTWISE_ADD)
3687 .setMathPrecision(cudnn_type)
3688 .build();
3689 auto add_op = cudnn_frontend::OperationBuilder(
3690 CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3691 .setxDesc(conv_op.getOutputTensor())
3692 .setbDesc(tensor_z)
3693 .setyDesc(tensor_add)
3694 .setpwDesc(add_desc)
3695 .setAlpha(alpha)
3696 .setAlpha2(alpha2)
3697 .build();
3698 RETURN_MSG_IF_CUDNN_ERROR(add_op);
3699
3700 auto bias_add_desc = cudnn_frontend::PointWiseDescBuilder()
3701 .setMode(CUDNN_POINTWISE_ADD)
3702 .setMathPrecision(cudnn_type)
3703 .build();
3704
3705 // If the activation is the identity function, then the bias-add is the last
3706 // op, and it writes to the output, tensor_y. Otherwise, it writes to the
3707 // "virtual tensor" (temp buffer) tensor_bias, to which we apply the
3708 // activation.
3709 auto& bias_out_desc =
3710 activation_mode == dnn::ActivationMode::kNone ? tensor_y : tensor_bias;
3711 auto bias_add_op = cudnn_frontend::OperationBuilder(
3712 CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3713 .setxDesc(add_op.getOutputTensor())
3714 .setbDesc(tensor_b)
3715 .setyDesc(bias_out_desc)
3716 .setpwDesc(bias_add_desc)
3717 .build();
3718 RETURN_MSG_IF_CUDNN_ERROR(bias_add_op);
3719
3720 // CUDNN OperationGraph
3721 absl::InlinedVector<cudnn_frontend::Operation const*, 4> ops = {
3722 &conv_op, &add_op, &bias_add_op};
3723
3724 absl::optional<cudnn_frontend::PointWiseDesc_v8> act_desc;
3725 absl::optional<cudnn_frontend::Operation_v8> act_op;
3726 switch (activation_mode) {
3727 case dnn::ActivationMode::kNone:
3728 break;
3729 case dnn::ActivationMode::kRelu:
3730 act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3731 .setMode(CUDNN_POINTWISE_RELU_FWD)
3732 .setMathPrecision(cudnn_type)
3733 .build());
3734 act_op.emplace(cudnn_frontend::OperationBuilder(
3735 CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3736 .setxDesc(bias_add_op.getOutputTensor())
3737 .setyDesc(tensor_y)
3738 .setpwDesc(*act_desc)
3739 .build());
3740 RETURN_MSG_IF_CUDNN_ERROR(*act_op);
3741 ops.push_back(&*act_op);
3742 break;
3743 default:
3744 return port::InternalError(
3745 absl::StrCat("Unimplemented activation mode ",
3746 dnn::ActivationModeString(activation_mode)));
3747 }
3748
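  // Schematically, the fused graph assembled above is:
  //   conv(x, w)              -> tensor_conv (virtual)
  //   tensor_conv + scaled z  -> tensor_add  (virtual)
  //   tensor_add + bias b     -> y, or -> tensor_bias (virtual) when an
  //                              activation follows
  //   relu(tensor_bias)       -> y, only for kRelu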
3749 auto op_graph = cudnn_frontend::OperationGraphBuilder()
3750 .setHandle(cudnn.handle())
3751 .setOperationGraph(ops.size(), ops.data())
3752 .build();
3753 RETURN_MSG_IF_CUDNN_ERROR(op_graph);
3754
3755 VLOG(4) << "\nTensor_x: " << tensor_x.describe()
3756 << "\nTensor_y: " << tensor_y.describe()
3757 << "\nTensor_z: " << tensor_z.describe()
3758 << "\nTensor_w: " << tensor_w.describe()
3759 << "\nTensor_b: " << tensor_b.describe()
3760 << "\nTensor_conv: " << tensor_conv.describe()
3761 << "\nTensor_add: " << tensor_add.describe()
3762 << "\nTensor_bias: " << tensor_bias.describe()
3763 << "\nConv: " << conv_desc.describe()
3764 << "\nAdd: " << add_desc.describe()
3765 << "\nBiasAdd: " << bias_add_desc.describe() //
3766 << "\nAct: "
3767 << (act_desc.has_value() ? act_desc->describe() : "(identity)")
3768 << "\nConvOp: " << conv_op.describe()
3769 << "\nAddOp: " << add_op.describe()
3770 << "\nBiasAddOp: " << bias_add_op.describe() //
3771 << "\nActOp: "
3772 << (act_op.has_value() ? act_op->describe() : "(identity)")
3773 << "\nOpGraph: " << op_graph.describe();
3774
3775 return std::unique_ptr<cudnn_frontend::OperationGraph>(
3776 new cudnn_frontend::OperationGraph(std::move(op_graph)));
3777 }
3778
3779 port::StatusOr<std::unique_ptr<cudnn_frontend::ExecutionPlan>>
3780 GetFirstWorkingExecutionPlan(
3781 Stream* stream, dnn::DataType element_type,
3782 const std::unique_ptr<cudnn_frontend::OperationGraph>& op_graph,
3783 dnn::ConvolutionKind kind, CudnnHandle& cudnn,
3784 ScratchAllocator* scratch_allocator) {
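  // Plan selection, in outline: query the heuristics and fallback engine lists
  // for this graph, drop engines rejected by the generic filter (Winograd when
  // the corresponding env var disables it, nondeterministic engines when
  // determinism is required, Tensor Core engines when tensor math is off),
  // skip engines excluded by the static/runtime errata JSON filters, and
  // return the first remaining plan whose workspace fits the scratch
  // allocator's limit.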
3785 auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
3786 .setOperationGraph(*op_graph)
3787 .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
3788 .build();
3789 RETURN_MSG_IF_CUDNN_ERROR(heuristics);
3790
3791 cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3792 auto fallback = cudnn_frontend::EngineFallbackListBuilder()
3793 .setOperationGraph(*op_graph)
3794 .setOperation(conv_mode)
3795 .build();
3796 RETURN_MSG_IF_CUDNN_ERROR(fallback);
3797
3798 // cuDNN frontend sneakily puts error messages on the object and returns
3799 // partially-initialized results when there's an error; make sure to check
3800 // them.
3801 int64_t engine_count = heuristics.getEngineConfigCount();
3802 RETURN_MSG_IF_CUDNN_ERROR(heuristics);
3803 auto& engine_config = heuristics.getEngineConfig(engine_count);
3804 RETURN_MSG_IF_CUDNN_ERROR(heuristics);
3805
3806 auto& fallback_list = fallback.getFallbackList();
3807
3808 cudnn_frontend::EngineConfigList filtered_configs;
3809 auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
3810 return GenericEngineFilter(
3811 engine_config,
3812 /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
3813 /*disable_nondeterminism*/ RequireCudnnDeterminism(),
3814 /*disable_tensor_core*/ !IsTensorMathEnabled(stream, element_type));
3815 };
3816
3817 cudnn_frontend::filter(engine_config, filtered_configs, generic_filter_fn);
3818 cudnn_frontend::filter(fallback_list, filtered_configs, generic_filter_fn);
3819
3820 auto fn = []() { return true; };
3821 auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
3822 auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
3823
3824 for (int i = 0; i < filtered_configs.size(); i++) {
3825 auto plan = cudnn_frontend::ExecutionPlanBuilder()
3826 .setHandle(cudnn.handle())
3827 .setEngineConfig(filtered_configs[i], op_graph->getTag())
3828 .build();
3829 if (plan.get_status() == CUDNN_STATUS_SUCCESS) {
3830 if (maybe_json_handle_static &&
3831 cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
3832 cudnn.handle(), fn)) {
3833 VLOG(4) << "Exclude engine (static): " << plan.getTag();
3834 continue;
3835 }
3836 if (maybe_json_handle_runtime &&
3837 cudnn_frontend::check_errata(*maybe_json_handle_runtime,
3838 plan.getTag(), cudnn.handle(), fn)) {
3839 VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
3840 continue;
3841 }
3842
3843 bool specify_workspace_limit = scratch_allocator != nullptr;
3844 auto memory_limit_bytes =
3845 specify_workspace_limit
3846 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
3847 : int64{0};
3848 int64_t workspace_size = plan.getWorkspaceSize();
3849 if (workspace_size <= memory_limit_bytes) {
3850 return std::unique_ptr<cudnn_frontend::ExecutionPlan>(
3851 new cudnn_frontend::ExecutionPlan(std::move(plan)));
3852 }
3853 }
3854 }
3855 return port::Status(port::error::UNKNOWN,
3856 "CUDNN failed to get a working"
3857 " plan.");
3858 }
3859 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3860
3861 } // namespace
3862
3863 port::Status CudnnSupport::DoPrepareForConvolution(
3864 dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3865 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3866 const dnn::FilterDescriptor& filter_descriptor,
3867 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3868 DeviceMemoryBase output_data,
3869 const dnn::ConvolutionDescriptor& convolution_descriptor,
3870 const dnn::AlgorithmConfig& algorithm_config,
3871 ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
3872 DeviceMemory<uint8>* scratch_memory) {
3873 CudnnTensorDescriptor input_nd(
3874 input_descriptor,
3875 ToCudnnDataType(element_type, input_descriptor.layout()));
3876 CudnnFilterDescriptor filter_nd(
3877 filter_descriptor,
3878 ToCudnnDataType(element_type, filter_descriptor.layout()));
3879 CudnnTensorDescriptor output_nd(
3880 output_descriptor,
3881 ToCudnnDataType(element_type, output_descriptor.layout()));
3882
3883 auto cudnn = cudnn_->GetHandle(parent_, stream);
3884
3885 switch (kind) {
3886 case dnn::ConvolutionKind::FORWARD: {
3887 SE_ASSIGN_OR_RETURN(*algorithm_desc,
3888 GetCudnnConvolutionForwardAlgorithm(
3889 stream, cudnn, algorithm_config, input_nd,
3890 filter_nd, element_type, convolution_descriptor,
3891 output_nd, scratch_allocator, scratch_memory));
3892 break;
3893 }
3894 case dnn::ConvolutionKind::BACKWARD_DATA: {
3895 SE_ASSIGN_OR_RETURN(*algorithm_desc,
3896 GetCudnnConvolutionBackwardDataAlgorithm(
3897 stream, cudnn, algorithm_config, input_nd,
3898 filter_nd, element_type, convolution_descriptor,
3899 output_nd, scratch_allocator, scratch_memory));
3900 break;
3901 }
3902 case dnn::ConvolutionKind::BACKWARD_FILTER: {
3903 SE_ASSIGN_OR_RETURN(*algorithm_desc,
3904 GetCudnnConvolutionBackwardFilterAlgorithm(
3905 stream, cudnn, algorithm_config, input_nd,
3906 filter_nd, element_type, convolution_descriptor,
3907 output_nd, scratch_allocator, scratch_memory));
3908 break;
3909 }
3910 default:
3911 return port::InternalError(
3912 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
3913 }
3914
3915 return port::Status::OK();
3916 }
3917
3918 port::Status CudnnSupport::DoConvolve(
3919 dnn::ConvolutionKind kind, dnn::DataType element_type,
3920 dnn::DataType output_type, Stream* stream,
3921 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3922 const dnn::FilterDescriptor& filter_descriptor,
3923 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3924 DeviceMemoryBase output_data,
3925 const dnn::ConvolutionDescriptor& convolution_descriptor,
3926 dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
3927 dnn::ProfileResult* output_profile_result) {
3928 cudnnDataType_t cudnn_type =
3929 ToCudnnDataType(element_type, input_descriptor.layout());
3930
3931 CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
3932 CudnnTensorDescriptor output_nd(
3933 output_descriptor,
3934 ToCudnnDataType(output_type, output_descriptor.layout()));
3935 CudnnFilterDescriptor filter_nd(
3936 filter_descriptor,
3937 ToCudnnDataType(element_type, filter_descriptor.layout()));
3938
3939 auto accumulator_type = GetConvAccumulatorType(element_type);
3940 CudnnConvolutionDescriptor conv(convolution_descriptor,
3941 ToCudnnDataType(accumulator_type));
3942 SE_ASSIGN_OR_RETURN(bool use_tensor_ops,
3943 UseTensorOps(stream, element_type, algorithm_desc));
3944 conv.set_use_tensor_op_math(use_tensor_ops);
3945
3946 auto cudnn = cudnn_->GetHandle(parent_, stream);
3947 // Alpha is the scaling factor for input.
3948 float falpha = 1.0;
3949 double dalpha = 1.0;
3950 void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
3951 : static_cast<void*>(&falpha);
3952 // Beta is the scaling factor for output.
3953 float fbeta = 0.0;
3954 double dbeta = 0.0;
3955 void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
3956 : static_cast<void*>(&fbeta);
3957
3958 const bool is_profiling = output_profile_result != nullptr;
3959
3960 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3961 if (is_profiling) {
3962 timer.reset(new GpuTimer(parent_)); // NOLINT
3963     // The start and stop of the timer should be as close to the cuDNN call
3964     // as possible. Other threads may still issue work onto this stream, so a
3965     // clean measurement may require multiple profiling runs.
3966 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3967 return port::Status(port::error::INTERNAL, "Failed to start timer");
3968 }
3969 }
3970
3971 const auto get_fwd_bugs = [&]() -> port::Status {
3972 if (CUDNN_VERSION < 8000) {
3973 if (algorithm_desc.algo_id() ==
3974 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
3975 ToCudnnDataType(element_type) == CUDNN_DATA_INT8 &&
3976 ToCudnnDataType(output_type) == CUDNN_DATA_FLOAT) {
3977 return port::Status(
3978 port::error::FAILED_PRECONDITION,
3979 "This configuration potentially produces incorrect results.");
3980 }
3981 }
3982 return port::Status::OK();
3983 };
3984
3985 auto get_bwd_data_bugs = [&]() -> port::Status { return port::Status::OK(); };
3986
3987 const auto get_bwd_filter_bugs = [&]() -> port::Status {
3988 return port::Status::OK();
3989 };
3990
3991 switch (kind) {
3992 case dnn::ConvolutionKind::FORWARD: {
3993 SE_RETURN_IF_ERROR(get_fwd_bugs());
3994 RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
3995 cudnn.handle(),
3996 /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
3997 /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_nd.handle(),
3998 /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
3999 /*algo=*/ToConvForwardAlgo(algorithm_desc),
4000 /*workSpace=*/scratch_memory.opaque(),
4001 /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
4002 /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque()));
4003 break;
4004 }
4005 case dnn::ConvolutionKind::BACKWARD_DATA: {
4006 SE_RETURN_IF_ERROR(get_bwd_data_bugs());
4007 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
4008 cudnn.handle(),
4009 /*alpha=*/alpha,
4010 /*wDesc=*/filter_nd.handle(),
4011 /*w=*/filter_data.opaque(),
4012 /*dyDesc=*/output_nd.handle(),
4013 /*dy=*/output_data.opaque(),
4014 /*convDesc=*/conv.handle(),
4015 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
4016 /*workSpace=*/scratch_memory.opaque(),
4017 /*workSpaceSizeInBytes=*/scratch_memory.size(),
4018 /*beta=*/beta,
4019 /*dxDesc=*/input_nd.handle(),
4020 /*dx=*/input_data.opaque()));
4021 break;
4022 }
4023 case dnn::ConvolutionKind::BACKWARD_FILTER: {
4024 SE_RETURN_IF_ERROR(get_bwd_filter_bugs());
4025 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
4026 cudnn.handle(),
4027 /*alpha=*/alpha,
4028 /*srcDesc=*/input_nd.handle(),
4029 /*srcData=*/input_data.opaque(),
4030 /*diffDesc=*/output_nd.handle(),
4031 /*diffData=*/output_data.opaque(),
4032 /*convDesc=*/conv.handle(),
4033 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
4034 /*workSpace=*/scratch_memory.opaque(),
4035 /*workSpaceSizeInBytes=*/scratch_memory.size(),
4036 /*beta=*/beta,
4037 /*gradDesc=*/filter_nd.handle(),
4038 /*dw=*/filter_data.opaque()));
4039 break;
4040 }
4041 default:
4042 return port::InternalError(
4043 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
4044 }
4045
4046 if (is_profiling) {
4047 if (!timer->Stop(AsGpuStream(stream))) {
4048 return port::Status(port::error::INTERNAL, "Failed to stop timer");
4049 }
4050 output_profile_result->set_algorithm(algorithm_desc);
4051 output_profile_result->set_elapsed_time_in_ms(
4052 timer->GetElapsedMilliseconds());
4053 output_profile_result->set_scratch_size(scratch_memory.size());
4054 }
4055
4056 return port::Status::OK();
4057 }
4058
4059 port::Status CudnnSupport::DoConvolveWithExecutionPlan(
4060 dnn::ConvolutionKind kind, dnn::DataType element_type,
4061 dnn::DataType output_type, Stream* stream,
4062 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
4063 const dnn::FilterDescriptor& filter_descriptor,
4064 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
4065 DeviceMemoryBase output_data,
4066 const dnn::ConvolutionDescriptor& convolution_descriptor,
4067 const dnn::AlgorithmConfig& plan_config,
4068 ScratchAllocator* scratch_allocator,
4069 dnn::ProfileResult* output_profile_result) {
4070 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4071 auto cudnn = cudnn_->GetHandle(parent_, stream);
4072
4073 absl::optional<dnn::AlgorithmDesc> plan_or = plan_config.algorithm();
4074 absl::optional<dnn::AlgorithmDesc> plan_no_scratch_or =
4075 plan_config.algorithm_no_scratch();
4076
4077 std::unique_ptr<cudnn_frontend::ExecutionPlan> current_plan;
4078 if (!plan_or.has_value()) {
4079 SE_ASSIGN_OR_RETURN(
4080 std::unique_ptr<cudnn_frontend::OperationGraph> op_graph,
4081 GetCudnnOperationGraph(kind, element_type, stream, input_descriptor,
4082 filter_descriptor, output_descriptor,
4083 convolution_descriptor, cudnn));
4084
4085 SE_ASSIGN_OR_RETURN(current_plan, GetFirstWorkingExecutionPlan(
4086 stream, element_type, op_graph, kind,
4087 cudnn, scratch_allocator));
4088 }
4089
4090 size_t workspace_size = 0;
4091 cudnnBackendDescriptor_t plan_desc;
4092 std::string exec_plan_id = "unknown";
4093 if (current_plan) {
4094 exec_plan_id = current_plan->getTag();
4095 workspace_size = current_plan->getWorkspaceSize();
4096 plan_desc = current_plan->get_raw_desc();
4097 } else {
4098 exec_plan_id = plan_or->exec_plan_id();
4099 auto workspace_size_or = plan_config.scratch_size();
4100 if (workspace_size_or.has_value()) {
4101 workspace_size = *workspace_size_or;
4102 }
4103 plan_desc = plan_or->exec_plan_desc();
4104 }
4105 dnn::AlgorithmDesc selected_plan_(exec_plan_id, plan_desc);
4106
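  // Allocate the plan's workspace if it needs one. If allocation fails, fall
  // back to the no-scratch plan when the config provides one; otherwise fail.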
4107 DeviceMemory<uint8> scratch_memory;
4108 if (workspace_size > 0) {
4109 auto scratch_or = scratch_allocator->AllocateBytes(workspace_size);
4110 if (scratch_or.ok()) {
4111 scratch_memory = scratch_or.ValueOrDie();
4112 } else if (plan_no_scratch_or.has_value()) {
4113 selected_plan_ = {plan_no_scratch_or->exec_plan_id(),
4114 plan_no_scratch_or->exec_plan_desc()};
4115 } else {
4116 return port::Status(port::error::UNKNOWN,
4117 "CUDNN failed to allocate the scratch space for the "
4118 "plan or to find a working no-scratch plan.");
4119 }
4120 }
4121
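  // The uids below must match the ids ('x', 'y', 'w') given to the tensors
  // when the operation graph was built; data_ptrs is ordered to correspond to
  // the uids.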
4122 void* data_ptrs[] = {input_data.opaque(), output_data.opaque(),
4123 filter_data.opaque()};
4124 int64_t uids[] = {'x', 'y', 'w'};
4125 auto variantPack = cudnn_frontend::VariantPackBuilder()
4126 .setWorkspacePointer(scratch_memory.opaque())
4127 .setDataPointers(3, data_ptrs)
4128 .setUids(3, uids)
4129 .build();
4130 RETURN_MSG_IF_CUDNN_ERROR(variantPack);
4131
4132 VLOG(4) << "\nDo convolution with plan tag: " << selected_plan_.exec_plan_id()
4133 << "\nWorkspace size in bytes: " << workspace_size
4134 << "\nVariantPack: " << variantPack.describe();
4135
4136 const bool is_profiling = output_profile_result != nullptr;
4137
4138 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4139 if (is_profiling) {
4140 timer.reset(new GpuTimer(parent_)); // NOLINT
4141     // The start and stop of the timer should be as close to the cuDNN call
4142     // as possible. Other threads may still issue work onto this stream, so a
4143     // clean measurement may require multiple profiling runs.
4144 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4145 return port::Status(port::error::INTERNAL, "Failed to start timer");
4146 }
4147 }
4148
4149 cudnnStatus_t status =
4150 cudnnBackendExecute(cudnn.handle(), selected_plan_.exec_plan_desc(),
4151 variantPack.get_raw_desc());
4152 RETURN_IF_CUDNN_ERROR(status);
4153
4154 if (is_profiling) {
4155 if (!timer->Stop(AsGpuStream(stream))) {
4156 return port::Status(port::error::INTERNAL, "Failed to stop timer");
4157 }
4158 output_profile_result->set_algorithm(selected_plan_);
4159 output_profile_result->set_elapsed_time_in_ms(
4160 timer->GetElapsedMilliseconds());
4161 output_profile_result->set_scratch_size(scratch_memory.size());
4162 }
4163
4164 return port::Status::OK();
4165 #else
4166 return port::InternalError(
4167 "To use CuDNN frontend APIs, CuDNN v8.1 or later "
4168 "is required.");
4169 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4170 }
4171
4172 // Utility for dealing with CUDA's type-erased scaling parameters, where some
4173 // sets of parameters expect a void* pointing at a float while others expect it
4174 // to point at a double.
4175 //
4176 // This is rather ugly, but its purpose is to quarantine the corresponding
4177 // ugliness that already exists in the CUDA API.
4178 class ScalingParam {
4179 public:
4180   explicit ScalingParam(double value) : as_double_(value), as_float_(value) {}
4181
4182 // Return a pointer to the appropriate representation type for the given
4183 // element type.
4184 //
4185 // See
4186 // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
4187   // for more info; the behavior for int8 result tensors is not described
4188   // there, so we preserve the pre-existing behavior (namely, using a float
4189   // scaling parameter).
4190   void* ToVoidPointer(dnn::DataType element_type) {
4191 if (element_type == dnn::DataType::kDouble) {
4192 return &as_double_;
4193 } else {
4194 return &as_float_;
4195 }
4196 }
4197
4198 private:
4199 double as_double_;
4200 float as_float_;
4201 };
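// Illustrative usage (not taken from this file; element_type stands for
// whatever data type the surrounding call site uses):
//
//   ScalingParam alpha(1.0);
//   ScalingParam beta(0.0);
//   // Pass alpha.ToVoidPointer(element_type) / beta.ToVoidPointer(element_type)
//   // wherever cuDNN expects a type-erased float* or double* scaling factor.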
4202
4203 bool CudnnSupport::GetConvolveExecutionPlans(
4204 dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
4205 const dnn::BatchDescriptor& input_descriptor,
4206 const dnn::FilterDescriptor& filter_descriptor,
4207 const dnn::BatchDescriptor& output_descriptor,
4208 const dnn::ConvolutionDescriptor& convolution_descriptor,
4209 std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>* out_exec_plans) {
4210 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4211 auto cudnn = cudnn_->GetHandle(parent_, stream);
4212 auto op_graph_status = GetCudnnOperationGraph(
4213 kind, element_type, stream, input_descriptor, filter_descriptor,
4214 output_descriptor, convolution_descriptor, cudnn);
4215 if (!op_graph_status.status().ok()) {
4216 return false;
4217 }
4218 auto op_graph = op_graph_status.ConsumeValueOrDie();
4219
4220 auto heur = cudnn_frontend::EngineHeuristicsBuilder()
4221 .setOperationGraph(*op_graph)
4222 .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
4223 .build();
4224 RETURN_FALSE_IF_CUDNN_ERROR(heur);
4225
4226 auto fallback = cudnn_frontend::EngineFallbackListBuilder()
4227 .setOperationGraph(*op_graph)
4228 .setOperation(GetCudnnConvolutionType(kind))
4229 .build();
4230 RETURN_FALSE_IF_CUDNN_ERROR(fallback);
4231
4232 // cuDNN frontend sneakily puts error messages on the object and returns
4233 // partially-initialized results when there's an error; make sure to check
4234 // them.
4235 int64_t engine_count = heur.getEngineConfigCount();
4236 RETURN_FALSE_IF_CUDNN_ERROR(heur);
4237 auto& heur_configs = heur.getEngineConfig(engine_count);
4238 RETURN_FALSE_IF_CUDNN_ERROR(heur);
4239
4240 auto& fallback_configs = fallback.getFallbackList();
4241
4242 VLOG(4) << "\nHeuristics engine configs size: " << heur_configs.size()
4243 << "\nFallback engine configs size: " << fallback_configs.size();
4244
4245 cudnn_frontend::EngineConfigList filtered_configs;
4246 auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
4247 return GenericEngineFilter(
4248 engine_config,
4249 /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
4250 /*disable_nondeterminism*/ RequireCudnnDeterminism(),
4251 /*disable_tensor_core*/ !IsTensorMathEnabled(stream, element_type));
4252 };
4253
4254 cudnn_frontend::filter(heur_configs, filtered_configs, generic_filter_fn);
4255 cudnn_frontend::filter(fallback_configs, filtered_configs, generic_filter_fn);
4256
4257 auto fn = []() { return true; };
4258 auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
4259 auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
4260
4261 VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
4262
4263 out_exec_plans->clear();
4264 for (int i = 0; i < filtered_configs.size(); i++) {
4265 auto plan = cudnn_frontend::ExecutionPlanBuilder()
4266 .setHandle(cudnn.handle())
4267 .setEngineConfig(filtered_configs[i], op_graph->getTag())
4268 .build();
4269 if (plan.get_status() == CUDNN_STATUS_SUCCESS) {
4270 if (maybe_json_handle_static &&
4271 cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
4272 cudnn.handle(), fn)) {
4273 VLOG(4) << "Exclude engine (static): " << plan.getTag();
4274 continue;
4275 }
4276 if (maybe_json_handle_runtime &&
4277 cudnn_frontend::check_errata(*maybe_json_handle_runtime,
4278 plan.getTag(), cudnn.handle(), fn)) {
4279 VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
4280 continue;
4281 }
4282
4283 out_exec_plans->push_back(std::unique_ptr<dnn::ConvolveExecutionPlan>(
4284 new CudnnConvolveExecutionPlan(std::move(plan))));
4285 // We will use the first working plan when determinism is required.
4286 if (RequireCudnnDeterminism()) {
4287 break;
4288 }
4289 }
4290 }
4291
4292 VLOG(4) << "\nReturned execution plans size: " << out_exec_plans->size();
4293
4294 return true;
4295 #else
4296 return false;
4297 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4298 }
4299
4300 port::Status CudnnSupport::GetFusedConvolveExecutionPlans(
4301 dnn::ConvolutionKind kind, dnn::DataType element_type,
4302 double conv_input_scale, double side_input_scale, Stream* stream,
4303 const dnn::BatchDescriptor& input_descriptor,
4304 const dnn::FilterDescriptor& filter_descriptor,
4305 const dnn::BatchDescriptor& bias_descriptor,
4306 const dnn::BatchDescriptor& output_descriptor,
4307 const dnn::ConvolutionDescriptor& convolution_descriptor,
4308 const dnn::ActivationMode activation_mode,
4309 std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>* out_exec_plans) {
4310 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4311 auto cudnn = cudnn_->GetHandle(parent_, stream);
4312 auto op_graph_status = GetCudnnFusedOperationGraph(
4313 kind, element_type, conv_input_scale, side_input_scale, stream,
4314 input_descriptor, filter_descriptor, bias_descriptor, output_descriptor,
4315 convolution_descriptor, activation_mode, cudnn);
4316 if (!op_graph_status.status().ok()) {
4317 return port::Status(port::error::INTERNAL,
4318 absl::StrCat("Cudnn graph failed to build: ",
4319 op_graph_status.status().ToString()));
4320 }
4321 auto op_graph = op_graph_status.ConsumeValueOrDie();
4322
4323 auto heur = cudnn_frontend::EngineHeuristicsBuilder()
4324 .setOperationGraph(*op_graph)
4325 .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
4326 .build();
4327 RETURN_MSG_IF_CUDNN_ERROR(heur);
4328
4329 auto fallback =
4330 cudnn_frontend::EngineFallbackListBuilder()
4331 .setOperationGraph(*op_graph)
4332 .setOperation(GetCudnnConvolutionType(dnn::ConvolutionKind::FORWARD))
4333 .build();
4334 RETURN_MSG_IF_CUDNN_ERROR(fallback);
4335
4336 // cuDNN frontend sneakily puts error messages on the object and returns
4337 // partially-initialized results when there's an error; make sure to check
4338 // them.
4339 int64_t engine_count = heur.getEngineConfigCount();
4340 RETURN_MSG_IF_CUDNN_ERROR(heur);
4341 auto& heur_configs = heur.getEngineConfig(engine_count);
4342 RETURN_MSG_IF_CUDNN_ERROR(heur);
4343
4344 auto& fallback_configs = fallback.getFallbackList();
4345
4346 VLOG(4) << "\nHeuristics engine configs size: " << heur_configs.size()
4347 << "\nFallback engine configs size: " << fallback_configs.size();
4348
4349 cudnn_frontend::EngineConfigList filtered_configs;
4350 auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
4351 return GenericEngineFilter(
4352 engine_config,
4353 /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
4354 /*disable_nondeterminism*/ RequireCudnnDeterminism(),
4355 /*disable_tensor_core*/ !IsTensorMathEnabled(stream, element_type));
4356 };
4357
4358 cudnn_frontend::filter(heur_configs, filtered_configs, generic_filter_fn);
4359 cudnn_frontend::filter(fallback_configs, filtered_configs, generic_filter_fn);
4360
4361 auto fn = []() { return true; };
4362 auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
4363 auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
4364
4365 VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
4366
4367 out_exec_plans->clear();
4368 for (int i = 0; i < filtered_configs.size(); i++) {
4369 auto plan = cudnn_frontend::ExecutionPlanBuilder()
4370 .setHandle(cudnn.handle())
4371 .setEngineConfig(filtered_configs[i], op_graph->getTag())
4372 .build();
4373 if (plan.get_status() == CUDNN_STATUS_SUCCESS) {
4374 if (maybe_json_handle_static &&
4375 cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
4376 cudnn.handle(), fn)) {
4377 VLOG(4) << "Exclude engine (static): " << plan.getTag();
4378 continue;
4379 }
4380 if (maybe_json_handle_runtime &&
4381 cudnn_frontend::check_errata(*maybe_json_handle_runtime,
4382 plan.getTag(), cudnn.handle(), fn)) {
4383 VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
4384 continue;
4385 }
4386
4387 out_exec_plans->push_back(std::unique_ptr<dnn::ConvolveExecutionPlan>(
4388 new CudnnConvolveExecutionPlan(std::move(plan))));
4389 // We will use the first working plan when determinism is required.
4390 if (RequireCudnnDeterminism()) {
4391 break;
4392 }
4393 }
4394 }
4395
4396 VLOG(4) << "\nReturned execution plans size: " << out_exec_plans->size();
4397
4398 return port::Status::OK();
4399 #else
4400 return port::UnimplementedError(
4401 "Cudnn execution plans are only supported with Cudnn >= 8.1.");
4402 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4403 }
4404
4405 bool CudnnSupport::GetConvolveAlgorithms(
4406 CudaComputeCapability cuda_compute_capability,
4407 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
4408 // Preload sub libs for cudnn 8.0.4+
4409 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
4410 cudnnOpsInferVersionCheck();
4411 cudnnCnnInferVersionCheck();
4412 #endif
4413 bool tensor_op_math_available =
4414 TensorOpMathAvailable(cuda_compute_capability);
4415 out_algorithms->clear();
4416
4417 std::vector<dnn::AlgorithmDesc::Index> algo_types;
4418 if (ConvUseDefaultAlgorithm()) {
4419 // Force a fallback algorithm.
4420 algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM};
4421 } else {
4422 algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
4423 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
4424 CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
4425 CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
4426 CUDNN_CONVOLUTION_FWD_ALGO_FFT,
4427 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD};
4428 if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
4429 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
4430 }
4431 if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
4432 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
4433 }
4434 }
4435
4436 // The algorithms are intentionally ordered for deterministic operation
4437 for (auto i : algo_types) {
4438 if (tensor_op_math_available) {
4439 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
4440 }
4441 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
4442 }
4443
4444 return true;
4445 }
4446
4447 bool CudnnSupport::GetRnnAlgorithms(
4448 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
4449 // Preload sub libs for cudnn 8.0.4+
4450 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
4451 cudnnOpsInferVersionCheck();
4452 cudnnOpsTrainVersionCheck();
4453 cudnnAdvInferVersionCheck();
4454 cudnnAdvTrainVersionCheck();
4455 #endif
4456 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
4457 // clang-format off
4458 CUDNN_RNN_ALGO_STANDARD,
4459 CUDNN_RNN_ALGO_PERSIST_STATIC,
4460 CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
4461 // clang-format on
4462 };
4463
4464 out_algorithms->clear();
4465 for (auto i : algo_types) {
4466 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
4467 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
4468 }
4469 return true;
4470 }
4471
4472 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
4473 CudaComputeCapability cuda_compute_capability,
4474 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
4475 // Preload sub libs for cudnn 8.0.4+
4476 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
4477 cudnnOpsInferVersionCheck();
4478 cudnnOpsTrainVersionCheck();
4479 cudnnCnnInferVersionCheck();
4480 cudnnCnnTrainVersionCheck();
4481 #endif
4482 bool tensor_op_math_available =
4483 TensorOpMathAvailable(cuda_compute_capability);
4484 out_algorithms->clear();
4485
4486 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
4487 // clang-format off
4488 CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
4489 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
4490 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
4491 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
4492 // clang-format on
4493 };
4494 if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
4495 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
4496 }
4497 if (!RequireCudnnDeterminism()) {
4498 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
4499 }
4500
4501 // The algorithms are intentionally ordered for deterministic operation
4502 for (auto i : algo_types) {
4503 if (tensor_op_math_available) {
4504 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
4505 }
4506 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
4507 }
4508
4509 return true;
4510 }
4511
4512 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
4513 CudaComputeCapability cuda_compute_capability,
4514 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
4515 // Preload sub libs for cudnn 8.0.4+
4516 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
4517 cudnnOpsInferVersionCheck();
4518 cudnnOpsTrainVersionCheck();
4519 cudnnCnnInferVersionCheck();
4520 cudnnCnnTrainVersionCheck();
4521 #endif
4522 bool tensor_op_math_available =
4523 TensorOpMathAvailable(cuda_compute_capability);
4524 out_algorithms->clear();
4525
4526 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
4527 // clang-format off
4528 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
4529 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
4530 // Based on cudnn.h, the following is not implemented.
4531 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
4532
4533 // Produces incorrect results for some shapes. Disabled for now, see
4534 // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
4535 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
4536 // clang-format on
4537 };
4538 if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
4539 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
4540 }
4541 if (!RequireCudnnDeterminism()) {
4542 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
4543 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
4544 }
4545
4546 // The algorithms are intentionally ordered for deterministic operation
4547 for (auto i : algo_types) {
4548 if (tensor_op_math_available) {
4549 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
4550 }
4551 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
4552 }
4553
4554 return true;
4555 }
4556
4557 bool CudnnSupport::DoBatchNormalizationForward(
4558 Stream* stream, const DeviceMemory<float>& x,
4559 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
4560 const DeviceMemory<float>& estimated_mean,
4561 const DeviceMemory<float>& estimated_variance,
4562 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
4563 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4564 const double exponential_average_factor,
4565 dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
4566 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
4567 DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
4568 bool is_training, ScratchAllocator* reserve_space_allocator,
4569 ScratchAllocator* workspace_allocator) {
4570 return IsStatusOk(
4571 DoBatchNormalizationForwardImpl<float, float>(
4572 stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
4573 offset, estimated_mean, estimated_variance, side_input, x_desc,
4574 scale_offset_desc, epsilon, exponential_average_factor,
4575 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
4576 is_training, reserve_space_allocator, workspace_allocator),
4577 /*report_error=*/true);
4578 }
4579
4580 bool CudnnSupport::DoBatchNormalizationForward(
4581 Stream* stream, const DeviceMemory<Eigen::half>& x,
4582 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
4583 const DeviceMemory<float>& estimated_mean,
4584 const DeviceMemory<float>& estimated_variance,
4585 const DeviceMemory<Eigen::half>& side_input,
4586 const dnn::BatchDescriptor& x_desc,
4587 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4588 const double exponential_average_factor,
4589 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
4590 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
4591 DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
4592 bool is_training, ScratchAllocator* reserve_space_allocator,
4593 ScratchAllocator* workspace_allocator) {
4594 return IsStatusOk(
4595 DoBatchNormalizationForwardImpl<Eigen::half, float>(
4596 stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
4597 estimated_mean, estimated_variance, side_input, x_desc,
4598 scale_offset_desc, epsilon, exponential_average_factor,
4599 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
4600 is_training, reserve_space_allocator, workspace_allocator),
4601 /*report_error=*/true);
4602 }
4603
4604 template <class T, class U>
4605 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
4606 Stream* stream, dnn::DataType input_data_type,
4607 dnn::DataType scale_data_type, const DeviceMemory<T>& x,
4608 const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
4609 const DeviceMemory<U>& estimated_mean,
4610 const DeviceMemory<U>& estimated_variance,
4611 const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
4612 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4613 const double exponential_average_factor,
4614 dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
4615 DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
4616 DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
4617 bool is_training, ScratchAllocator* reserve_space_allocator,
4618 ScratchAllocator* workspace_allocator) {
4619 CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
4620 CudnnTensorDescriptor scale_offset_descriptor(
4621 scale_offset_desc, ToCudnnDataType(scale_data_type));
4622 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
4623 if (BatchnormSpatialPersistentEnabled() && is_training) {
4624 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
4625 }
4626 float one = 1.0;
4627 float zero = 0.0;
4628 auto cudnn = cudnn_->GetHandle(parent_, stream);
4629
4630 DeviceMemory<uint8> workspace;
4631 DeviceMemory<uint8> reserve_space;
4632
4633 #if CUDNN_VERSION >= 7402
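  // Select the fused batch-norm operation: plain BN, BN + activation, or
  // BN + side-input add + activation, depending on whether a side input and/or
  // an activation were requested.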
4634 const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t {
4635 if (side_input.is_null()) {
4636 return activation_mode == dnn::ActivationMode::kNone
4637 ? CUDNN_BATCHNORM_OPS_BN
4638 : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
4639 } else {
4640 return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
4641 }
4642 };
4643 const cudnnBatchNormOps_t bn_ops = get_bn_ops();
4644
4645   // We use NaN propagation to be consistent with CudnnSupport::DoActivate(...).
4646 CudnnActivationDescriptor activation_desc(
4647 activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
4648
4649 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
4650 SE_ASSIGN_OR_RETURN(
4651 workspace,
4652 CreateBatchNormForwardWorkspace(
4653 stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
4654 scale_offset_descriptor, workspace_allocator))
4655 if (is_training) {
4656 size_t reserve_space_size_in_bytes = 0;
4657 RETURN_IF_CUDNN_ERROR(
4658 cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
4659 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
4660 /*activationDesc=*/activation_desc.handle(),
4661 /*xDesc=*/x_descriptor.handle(),
4662 /*sizeInBytes=*/&reserve_space_size_in_bytes));
4663 SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
4664 reserve_space_size_in_bytes));
4665 }
4666 }
4667 #endif
4668
4669 auto check_no_side_input_or_activation = [&]() -> port::Status {
4670 if (activation_mode != dnn::ActivationMode::kNone ||
4671 !side_input.is_null()) {
4672 return port::Status(
4673 port::error::INTERNAL,
4674 absl::StrCat(
4675 "Side input and activation are not supported by cuDNN version: ",
4676 CUDNN_VERSION));
4677 } else {
4678 return port::Status::OK();
4679 }
4680 };
4681
4682 if (is_training) {
4683 CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
4684 << "batch_mean and batch_var must both be null or both be non-null";
4685
4686 void* batch_mean_opaque;
4687 void* batch_var_opaque;
4688 if (!batch_mean->is_null() && !batch_var->is_null()) {
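      // With an exponential average factor of 1.0, the running statistics are
      // fully replaced by the batch statistics. Zeroing the buffers first
      // presumably keeps NaNs in uninitialized memory from leaking through the
      // running = running * (1 - factor) + batch * factor update (NaN * 0 is
      // still NaN).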
4689 if (exponential_average_factor == 1.0) {
4690 stream->ThenMemZero(batch_mean, batch_mean->size());
4691 stream->ThenMemZero(batch_var, batch_var->size());
4692 }
4693 batch_mean_opaque = batch_mean->opaque();
4694 batch_var_opaque = batch_var->opaque();
4695 } else {
4696 batch_mean_opaque = nullptr;
4697 batch_var_opaque = nullptr;
4698 }
4699
4700 bool called = false;
4701 #if CUDNN_VERSION >= 7402
4702 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
4703 called = true;
4704 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx(
4705 /*handle=*/cudnn.handle(),
4706 /*mode=*/mode,
4707 /*bnOps=*/bn_ops,
4708 /*alpha=*/&one,
4709 /*beta=*/&zero,
4710 /*xDesc=*/x_descriptor.handle(),
4711 /*xData=*/x.opaque(),
4712 /*zDesc=*/x_descriptor.handle(),
4713 /*zData=*/side_input.opaque(),
4714 /*yDesc=*/x_descriptor.handle(),
4715 /*yData=*/y->opaque(),
4716 /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
4717 /*bnScale=*/scale.opaque(),
4718 /*bnBias=*/offset.opaque(),
4719 /*exponentialAverageFactor=*/exponential_average_factor,
4720 /*resultRunningMean=*/batch_mean_opaque,
4721 /*resultRunningVariance=*/batch_var_opaque,
4722 /*epsilon=*/epsilon,
4723 /*resultSaveMean=*/saved_mean->opaque(),
4724 /*resultSaveInvVariance=*/saved_inv_var->opaque(),
4725 /*activationDesc=*/activation_desc.handle(),
4726 /*workspace=*/workspace.opaque(),
4727 /*workSpaceSizeInBytes=*/workspace.size(),
4728 /*reserveSpace=*/reserve_space.opaque(),
4729 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
4730 }
4731 #endif
4732 if (!called) {
4733 SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
4734 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
4735 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
4736 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
4737 scale.opaque(), offset.opaque(), exponential_average_factor,
4738 batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
4739 saved_inv_var->opaque()));
4740 }
4741 } else {
4742 const void* maybe_inv_var = estimated_variance.opaque();
4743 SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
4744 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
4745 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
4746 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
4747 scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
4748 epsilon));
4749 }
4750 return port::Status::OK();
4751 }
4752
4753 bool CudnnSupport::DoBatchNormalizationBackward(
4754 Stream* stream, const DeviceMemory<float>& y_backprop,
4755 const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
4756 const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
4757 const dnn::BatchDescriptor& x_desc,
4758 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4759 DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
4760 DeviceMemory<float>* offset_backprop,
4761 DeviceMemory<uint8>* reserve_space_data,
4762 ScratchAllocator* workspace_allocator) {
4763 return IsStatusOk(DoBatchNormalizationBackwardImpl(
4764 stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
4765 x, scale, mean, inv_var, x_desc, scale_offset_desc,
4766 epsilon, x_backprop, scale_backprop, offset_backprop,
4767 reserve_space_data, workspace_allocator),
4768 /*report_error=*/true);
4769 }
4770
4771 bool CudnnSupport::DoBatchNormalizationBackward(
4772 Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
4773 const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
4774 const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
4775 const dnn::BatchDescriptor& x_desc,
4776 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4777 DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
4778 DeviceMemory<float>* offset_backprop,
4779 DeviceMemory<uint8>* reserve_space_data,
4780 ScratchAllocator* workspace_allocator) {
4781 return IsStatusOk(DoBatchNormalizationBackwardImpl(
4782 stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
4783 x, scale, mean, inv_var, x_desc, scale_offset_desc,
4784 epsilon, x_backprop, scale_backprop, offset_backprop,
4785 reserve_space_data, workspace_allocator),
4786 /*report_error=*/true);
4787 }
4788
4789 template <class T, class U>
4790 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
4791 Stream* stream, int cudnn_input_type, int cudnn_scale_type,
4792 const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
4793 const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
4794 const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
4795 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
4796 DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
4797 DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
4798 ScratchAllocator* workspace_allocator) {
4799 CudnnTensorDescriptor x_descriptor(
4800 x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
4801 CudnnTensorDescriptor scale_offset_descriptor(
4802 scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
4803 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
4804 if (BatchnormSpatialPersistentEnabled()) {
4805 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
4806 }
4807 float one = 1.0;
4808 float zero = 0.0;
4809
4810 auto cudnn = cudnn_->GetHandle(parent_, stream);
4811
4812 bool called = false;
4813 #if CUDNN_VERSION >= 7402
4814 if (reserve_space_data != nullptr && workspace_allocator != nullptr) {
4815 called = true;
4816 const cudnnBatchNormOps_t bn_ops = CUDNN_BATCHNORM_OPS_BN;
4817 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
4818 CreateBatchNormBackwardWorkspace(
4819 stream, cudnn, mode, bn_ops, x_descriptor,
4820 scale_offset_descriptor, workspace_allocator))
4821 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx(
4822 /*handle=*/cudnn.handle(),
4823 /*mode=*/mode,
4824 /*bnOps=*/bn_ops,
4825 /*alphaDataDiff=*/&one,
4826 /*betaDataDiff=*/&zero,
4827 /*alphaParamDiff=*/&one,
4828 /*betaParamDiff=*/&zero,
4829 /*xDesc=*/x_descriptor.handle(),
4830 /*xData=*/x.opaque(),
4831 /*yDesc=*/nullptr,
4832 /*yData=*/nullptr,
4833 /*dyDesc=*/x_descriptor.handle(),
4834 /*dyData=*/y_backprop.opaque(),
4835 /*dzDesc=*/nullptr,
4836 /*dzData=*/nullptr,
4837 /*dxDesc=*/x_descriptor.handle(),
4838 /*dxData=*/x_backprop->opaque(),
4839 /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
4840 /*bnScaleData=*/scale.opaque(),
4841 /*bnBiasData=*/nullptr,
4842 /*dBnScaleData=*/scale_backprop->opaque(),
4843 /*dBnBiasData=*/offset_backprop->opaque(),
4844 /*epsilon=*/epsilon,
4845 /*savedMean=*/mean.opaque(),
4846 /*savedInvVariance=*/inv_var.opaque(),
4847 /*activationDesc=*/nullptr,
4848 /*workspace=*/workspace.opaque(),
4849 /*workSpaceSizeInBytes=*/workspace.size(),
4850 /*reserveSpace=*/reserve_space_data->opaque(),
4851 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
4852 }
4853 #endif
4854 if (!called) {
4855 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
4856 cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
4857 x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
4858 x_descriptor.handle(), x_backprop->opaque(),
4859 scale_offset_descriptor.handle(), scale.opaque(),
4860 scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
4861 mean.opaque(), inv_var.opaque()));
4862 }
4863
4864 return port::Status::OK();
4865 }
4866
4867 port::Status CudnnSupport::DoFusedConvolve(
4868 Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
4869 dnn::DataType bias_type, dnn::DataType output_type,
4870 const dnn::BatchDescriptor& conv_input_descriptor,
4871 DeviceMemoryBase conv_input_data, double conv_input_scale,
4872 const dnn::FilterDescriptor& filter_descriptor,
4873 DeviceMemoryBase filter_data,
4874 const dnn::ConvolutionDescriptor& convolution_descriptor,
4875 DeviceMemoryBase side_input_data, double side_input_scale,
4876 const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
4877 dnn::ActivationMode activation_mode,
4878 const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
4879 ScratchAllocator* scratch_allocator,
4880 const dnn::AlgorithmConfig& algorithm_config,
4881 dnn::ProfileResult* output_profile_result) {
4882 if (input_type == dnn::DataType::kInt8 &&
4883 !stream->GetCudaComputeCapability().IsAtLeast(6, 1)) {
4884 return port::UnimplementedError(
4885 "cudnnConvolutionBiasActivationForward() for int8 is only supported on "
4886 "GPUs with compute capability 6.1 or later.");
4887 }
4888
4889 if (input_type == dnn::DataType::kInt8 &&
4890 output_type == dnn::DataType::kFloat &&
4891 (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8200)) {
4892 return port::UnimplementedError(
4893 "int8 -> float fused conv is disabled for this cuDNN version. See "
4894 "go/nvbugs/3326122");
4895 }
4896
4897 if (activation_mode != dnn::ActivationMode::kRelu &&
4898 activation_mode != dnn::ActivationMode::kNone) {
4899 return port::Status(port::error::INVALID_ARGUMENT,
4900 "cudnnConvolutionBiasActivationForward() only supports "
4901 "Relu or None activation.");
4902 }
4903
4904 CudnnTensorDescriptor conv_input_nd(
4905 conv_input_descriptor,
4906 ToCudnnDataType(input_type, conv_input_descriptor.layout()));
4907 CudnnTensorDescriptor output_nd(
4908 output_descriptor,
4909 ToCudnnDataType(output_type, conv_input_descriptor.layout()));
4910 CudnnFilterDescriptor filter(
4911 filter_descriptor,
4912 ToCudnnDataType(input_type, filter_descriptor.layout()));
4913 CudnnTensorDescriptor bias_nd(bias_descriptor, ToCudnnDataType(bias_type));
4914
4915 auto cudnn = cudnn_->GetHandle(parent_, stream);
4916
4917 const bool is_profiling = output_profile_result != nullptr;
4918
4919 DeviceMemory<uint8> scratch;
4920 SE_ASSIGN_OR_RETURN(
4921 dnn::AlgorithmDesc algo_desc,
4922 GetCudnnConvolutionForwardAlgorithm(
4923 stream, cudnn, algorithm_config, conv_input_nd, filter, input_type,
4924 convolution_descriptor, output_nd, scratch_allocator, &scratch));
4925
4926 CudnnConvolutionDescriptor conv(
4927 convolution_descriptor,
4928 ToCudnnDataType(GetConvAccumulatorType(input_type)));
4929 conv.set_use_tensor_op_math(algo_desc.tensor_ops_enabled());
4930
4931 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4932 if (is_profiling) {
4933 timer.reset(new GpuTimer(parent_)); // NOLINT
4934     // The start and stop of the timer should be as close to the cuDNN call as
4935     // possible. Other threads may still issue work onto this stream, so a single
4936     // measurement can include unrelated work; multiple profiling runs help.
4937 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4938 return port::Status(port::error::INTERNAL, "Failed to start timer");
4939 }
4940 }
4941   // cuDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
4942   // activation descriptor. Note that this changes the NaN propagation behavior
4943   // relative to running separate conv, bias, and relu ops (which default to
4944   // CUDNN_PROPAGATE_NAN).
4945 CudnnActivationDescriptor activation_desc(
4946 activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
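  // The fused op computes y = act(alpha1 * conv(x, w) + alpha2 * z + bias).
  // When side_input_scale (alpha2) is 0 the side input's contents are not used,
  // so output_data is passed as z below, presumably just to supply a valid
  // pointer whose descriptor matches yDesc.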
4947 auto side_input_data_ptr =
4948 (side_input_scale == 0) ? output_data.opaque() : side_input_data.opaque();
4949
4950 VLOG(2) << "\nconv_input_scale = " << conv_input_scale
4951 << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
4952 << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
4953 << "\nfilter.handle() = " << filter.handle()
4954 << "\nfilter_data.opaque() = " << filter_data.opaque()
4955 << "\nconv.handle() = " << conv.handle()
4956 << "\nalgo = " << algo_desc.algo_id()
4957 << ", tensor_ops_enabled=" << algo_desc.tensor_ops_enabled()
4958 << "\nscratch.opaque() = " << scratch.opaque()
4959 << "\nscratch.size() = " << scratch.size()
4960 << "\nside_input_scale = " << side_input_scale
4961 << "\noutput_nd.handle() = " << output_nd.handle()
4962 << "\nside_input_data_ptr = " << side_input_data_ptr
4963 << "\nbias_nd.handle() = " << bias_nd.handle()
4964 << "\nbiases.opaque() = " << biases.opaque()
4965 << "\nactivation_desc.handle() = " << activation_desc.handle()
4966 << "\noutput_nd.handle() = " << output_nd.handle()
4967 << "\noutput_data.opaque() = " << output_data.opaque();
4968
4969 if (IsTensorMathOpSet(conv) != algo_desc.tensor_ops_enabled()) {
4970 return port::Status(port::error::FAILED_PRECONDITION,
4971 "Tensor op math type in dnn::AlgorithmDesc does not "
4972 "match that of the CudnnConvolutionDescriptor");
4973 }
4974
4975 // N.B. the scaling parameters alpha1 and alpha2 are pointers to temporaries;
4976 // this API doesn't persist the pointers beyond its own stack frame.
4977 auto status = cudnnConvolutionBiasActivationForward(
4978 cudnn.handle(),
4979 /*alpha1=*/ScalingParam(conv_input_scale).ToVoidPointer(input_type),
4980 /*xDesc=*/conv_input_nd.handle(), /*x=*/conv_input_data.opaque(),
4981 /*wDesc=*/filter.handle(), /*w=*/filter_data.opaque(),
4982 /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
4983 /*workSpace=*/scratch.opaque(),
4984 /*workSpaceSizeInBytes=*/scratch.size(),
4985 /*alpha2=*/ScalingParam(side_input_scale).ToVoidPointer(input_type),
4986 /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
4987 /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
4988 /*activationDesc=*/activation_desc.handle(),
4989 /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque());
4990 if (status != CUDNN_STATUS_SUCCESS || !is_profiling) {
4991 VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo_desc)
4992 << ", workspace_size=" << scratch.size() << " -> "
4993 << ToString(status);
4994 }
4995 RETURN_IF_CUDNN_ERROR(status);
4996
4997 if (is_profiling) {
4998 if (!timer->Stop(AsGpuStream(stream))) {
4999 return port::Status(port::error::INTERNAL, "Failed to stop timer");
5000 }
5001 output_profile_result->set_algorithm(algo_desc);
5002 output_profile_result->set_elapsed_time_in_ms(
5003 timer->GetElapsedMilliseconds());
5004 output_profile_result->set_scratch_size(scratch.size());
5005 VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo_desc)
5006 << ", tensor_ops_enabled=" << algo_desc.tensor_ops_enabled()
5007 << ", workspace_size=" << scratch.size() << " -> "
5008 << ToString(status) << " in " << timer->GetElapsedMilliseconds()
5009 << "ms";
5010 }
5011
5012 return port::Status::OK();
5013 }
5014
5015 port::Status CudnnSupport::DoFusedConvolveWithExecutionPlan(
5016 Stream* stream, dnn::DataType element_type,
5017 const dnn::BatchDescriptor& conv_input_descriptor,
5018 DeviceMemoryBase conv_input_data, double conv_input_scale,
5019 const dnn::FilterDescriptor& filter_descriptor,
5020 DeviceMemoryBase filter_data,
5021 const dnn::ConvolutionDescriptor& convolution_descriptor,
5022 DeviceMemoryBase side_input_data, double side_input_scale,
5023 const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
5024 dnn::ActivationMode activation_mode,
5025 const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
5026 ScratchAllocator* scratch_allocator,
5027 const dnn::AlgorithmConfig& algorithm_config,
5028 dnn::ProfileResult* output_profile_result) {
5029 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5030 auto cudnn = cudnn_->GetHandle(parent_, stream);
5031
5032 absl::optional<dnn::AlgorithmDesc> plan_or = algorithm_config.algorithm();
5033 absl::optional<dnn::AlgorithmDesc> plan_no_scratch_or =
5034 algorithm_config.algorithm_no_scratch();
5035
5036 std::unique_ptr<cudnn_frontend::ExecutionPlan> current_plan;
5037 if (!plan_or.has_value()) {
5038 SE_ASSIGN_OR_RETURN(
5039 std::unique_ptr<cudnn_frontend::OperationGraph> op_graph,
5040 GetCudnnFusedOperationGraph(
5041 dnn::ConvolutionKind::FORWARD, element_type, conv_input_scale,
5042 side_input_scale, stream, conv_input_descriptor, filter_descriptor,
5043 bias_descriptor, output_descriptor, convolution_descriptor,
5044 activation_mode, cudnn));
5045
5046 SE_ASSIGN_OR_RETURN(current_plan, GetFirstWorkingExecutionPlan(
5047 stream, element_type, op_graph,
5048 dnn::ConvolutionKind::FORWARD, cudnn,
5049 scratch_allocator));
5050 }
5051
5052 size_t workspace_size = 0;
5053 cudnnBackendDescriptor_t plan_desc;
5054 std::string exec_plan_id = "unknown";
5055 if (current_plan) {
5056 exec_plan_id = current_plan->getTag();
5057 workspace_size = current_plan->getWorkspaceSize();
5058 plan_desc = current_plan->get_raw_desc();
5059 } else {
5060 exec_plan_id = plan_or->exec_plan_id();
5061 auto workspace_size_or = algorithm_config.scratch_size();
5062 if (workspace_size_or.has_value()) {
5063 workspace_size = *workspace_size_or;
5064 }
5065 plan_desc = plan_or->exec_plan_desc();
5066 }
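  // At this point either a freshly built plan (when algorithm_config carried
  // none) or the plan cached in algorithm_config has been chosen; exec_plan_id,
  // workspace_size, and plan_desc describe whichever plan is executed below.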
5067 dnn::AlgorithmDesc selected_plan(exec_plan_id, plan_desc);
5068
5069 DeviceMemory<uint8> scratch_memory;
5070 if (workspace_size > 0) {
5071 auto scratch_or = scratch_allocator->AllocateBytes(workspace_size);
5072 if (scratch_or.ok()) {
5073 scratch_memory = scratch_or.ValueOrDie();
5074 } else if (plan_no_scratch_or.has_value()) {
5075 selected_plan = {plan_no_scratch_or->exec_plan_id(),
5076 plan_no_scratch_or->exec_plan_desc()};
5077 } else {
5078 return port::Status(port::error::UNKNOWN,
5079 "CUDNN failed to allocate the scratch space for the "
5080 "plan or to find a working no-scratch plan.");
5081 }
5082 }
5083
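  // The UIDs below must match the tensor UIDs used when the fused operation
  // graph was built (presumably 'x', 'y', 'w', 'z', and 'b' inside
  // GetCudnnFusedOperationGraph); data_ptrs pairs with uids entry by entry.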
5084 void* data_ptrs[] = {
5085 conv_input_data.opaque(), output_data.opaque(), filter_data.opaque(),
5086 side_input_data.opaque(), biases.opaque(),
5087 };
5088 int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};
5089 auto variantPack = cudnn_frontend::VariantPackBuilder()
5090 .setWorkspacePointer(scratch_memory.opaque())
5091 .setDataPointers(5, data_ptrs)
5092 .setUids(5, uids)
5093 .build();
5094 RETURN_MSG_IF_CUDNN_ERROR(variantPack);
5095
5096 VLOG(4) << "\nDo fused convolution with plan tag: "
5097 << selected_plan.exec_plan_id()
5098 << "\nWorkspace size in bytes: " << workspace_size
5099 << "\nVariantPack: " << variantPack.describe();
5100
5101 const bool is_profiling = output_profile_result != nullptr;
5102
5103 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
5104 if (is_profiling) {
5105 timer.reset(new GpuTimer(parent_)); // NOLINT
5106     // The start and stop of the timer should be as close to the cuDNN call as
5107     // possible. Other threads may still issue work onto this stream, so a single
5108     // measurement can include unrelated work; multiple profiling runs help.
5109 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
5110 return port::Status(port::error::INTERNAL, "Failed to start timer");
5111 }
5112 }
5113
5114 cudnnStatus_t status =
5115 cudnnBackendExecute(cudnn.handle(), selected_plan.exec_plan_desc(),
5116 variantPack.get_raw_desc());
5117 if (status != CUDNN_STATUS_SUCCESS || !is_profiling) {
5118 VLOG(4) << "conv with plan " << selected_plan.exec_plan_id()
5119 << ", workspace_size=" << workspace_size << " -> "
5120 << ToString(status);
5121 }
5122 RETURN_IF_CUDNN_ERROR(status);
5123
5124 if (is_profiling) {
5125 if (!timer->Stop(AsGpuStream(stream))) {
5126 return port::Status(port::error::INTERNAL, "Failed to stop timer");
5127 }
5128 output_profile_result->set_algorithm(selected_plan);
5129 output_profile_result->set_elapsed_time_in_ms(
5130 timer->GetElapsedMilliseconds());
5131 output_profile_result->set_scratch_size(scratch_memory.size());
5132 VLOG(4) << "conv with plan " << selected_plan.exec_plan_id()
5133 << ", workspace_size=" << workspace_size << " -> "
5134 << ToString(status) << " in " << timer->GetElapsedMilliseconds()
5135 << "ms";
5136 }
5137
5138 return port::Status::OK();
5139 #else
5140 return port::InternalError(
5141 "To use CuDNN frontend APIs, CuDNN v8.1 or later "
5142 "is required.");
5143 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5144 }
5145
5146 port::Status CudnnSupport::DoPrepareForCtcLoss(
5147 Stream* stream, dnn::DataType element_type,
5148 const dnn::RnnStateTensorDescriptor& probs_desc,
5149 const dnn::RnnStateTensorDescriptor& grads_desc,
5150 absl::Span<const int> labels_data,
5151 absl::Span<const int> labels_lengths_data,
5152 absl::Span<const int> input_lengths_data,
5153 ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
5154 int* ctc_loss_algo_id) {
5155 auto cudnn = cudnn_->GetHandle(parent_, stream);
5156 // Query the workspace size.
5157 size_t workspace_size_in_bytes = 0;
5158 #if CUDNN_VERSION >= 7603
5159 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
5160 const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
5161 static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
5162 const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
5163 static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
5164
5165   // Try running with `algo`; if it succeeds, pick it. The non-deterministic
5166   // algorithm is tried first, so it is preferred whenever determinism is not
5167   // required.
5168 auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
5169 : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
5170 cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
5171 /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
5172 /*gradientsDesc=*/cudnn_grads_desc.handle(),
5173 /*labels=*/labels_data.data(),
5174 /*labelLengths=*/labels_lengths_data.data(),
5175 /*inputLengths=*/input_lengths_data.data(),
5176 /*algo=*/algo,
5177 /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
5178 /*sizeInBytes=*/&workspace_size_in_bytes);
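  // If determinism is required, the deterministic algorithm was already queried
  // above and any failure is fatal. Otherwise the non-deterministic algorithm
  // was queried first, and the code below falls back to the deterministic one
  // when that query fails.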
5179 if (RequireCudnnDeterminism()) {
5180 RETURN_IF_CUDNN_ERROR(status);
5181 }
5182
5183 if (status != CUDNN_STATUS_SUCCESS) {
5184 algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
5185 RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
5186 /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
5187 /*gradientsDesc=*/cudnn_grads_desc.handle(),
5188 /*labels=*/labels_data.data(),
5189 /*labelLengths=*/labels_lengths_data.data(),
5190 /*inputLengths=*/input_lengths_data.data(),
5191 /*algo=*/algo,
5192 /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
5193 /*sizeInBytes=*/&workspace_size_in_bytes));
5194 }
5195 *ctc_loss_algo_id = algo;
5196 #else
5197 return port::Status(port::error::INVALID_ARGUMENT,
5198 "No supported cudnnGetCTCLossWorkspaceSize when "
5199 "CUDNN_VERSION < 7.6.3");
5200 #endif
5201 // Allocate the workspace.
5202 if (workspace_size_in_bytes == 0) {
5203 *scratch_memory = DeviceMemory<uint8>();
5204 return port::Status::OK();
5205 }
5206 const auto scratch_or =
5207 scratch_allocator->AllocateBytes(workspace_size_in_bytes);
5208 if (scratch_or.ok()) {
5209 *scratch_memory = scratch_or.ValueOrDie();
5210 return port::Status::OK();
5211 }
5212 return port::InternalError(
5213 "Failed to allocate scratch memory for the CuDNN CTC Loss");
5214 }
5215
5216 port::Status CudnnSupport::DoCtcLoss(
5217 Stream* stream, dnn::DataType element_type,
5218 const dnn::RnnStateTensorDescriptor& probs_desc,
5219 const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
5220 absl::Span<const int> labels_lengths_data,
5221 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
5222 const dnn::RnnStateTensorDescriptor& grads_desc,
5223 DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
5224 int ctc_loss_algo_id) {
5225   // The current cuDNN CTC loss only supports the float data type.
5226 if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
5227 return port::Status(port::error::INVALID_ARGUMENT,
5228 "CudnnCtcLossDescriptor is supported only when the "
5229 "CUDNN_VERSION >= 7.6.3 and DataType is float");
5230 }
5231 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
5232 const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
5233 static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
5234 const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
5235 static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
5236 return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
5237 labels_lengths_data, input_lengths_data, costs_data,
5238 cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
5239 scratch_memory, ctc_loss_algo_id);
5240 }
5241
5242 bool CudnnSupport::DoTransformTensor(Stream* stream,
5243 const dnn::BatchDescriptor& input_desc,
5244 dnn::DataType input_type,
5245 const DeviceMemoryBase& input_data,
5246 const dnn::BatchDescriptor& output_desc,
5247 dnn::DataType output_type, float scale,
5248 DeviceMemoryBase* output_data) {
5249 float beta = 0.0f;
5250 CudnnTensorDescriptor input_tensor_desc(
5251 input_desc, ToCudnnDataType(input_type, input_desc.layout()));
5252 CudnnTensorDescriptor output_tensor_desc(
5253 output_desc, ToCudnnDataType(output_type, output_desc.layout()));
5254 auto cudnn = cudnn_->GetHandle(parent_, stream);
5255 const auto status = [&] {
5256 RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
5257 cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
5258 &beta, output_tensor_desc.handle(), output_data->opaque()));
5259 return port::Status::OK();
5260 }();
5261 return IsStatusOk(status, /*report_error=*/true);
5262 }
5263
5264 bool CudnnSupport::DoMatMul(Stream* stream,
5265 const DeviceMemory<float>& input_data,
5266 const DeviceMemory<float>& weights,
5267 const dnn::BatchDescriptor& input_dimensions,
5268 const dnn::BatchDescriptor& output_dimensions,
5269 DeviceMemory<float>* output_data) {
5270 if (input_dimensions.count() != output_dimensions.count()) {
5271 LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
5272 return false;
5273 }
5274
5275 // We do not permute the input or output, instead we just
5276 // reinterpret the layout. We are working with row-major matrices
5277 // and the rows of the input and output correspond to batch, so
5278 // batch has to be outermost in both the input and output.
5279 //
5280 // By adding transposes to the BLAS gemm call we could perhaps make
5281 // the kYXDepthBatch layout work as well, but there has been no need
5282 // for that so far.
5283 if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5284 input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5285 LOG(ERROR) << "Unsupported MatMul input layout.";
5286 return false;
5287 }
5288 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5289 output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5290 LOG(ERROR) << "Unsupported MatMul output layout.";
5291 return false;
5292 }
5293
5294 if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
5295 // This is a fast path that also supports the kBatchYXDepth layout.
5296
5297 // The matrices here are in row-major format while BLAS expects
5298 // column-major, i.e. our matrices are transposed as far as BLAS
5299 // is concerned. So we need to compute output^T =
5300 // input^T*weights^T. There is no parameter for transposing the
5301 // output in BLAS gemm, but instead we can transpose both sides of
5302 // the equality to see that this is equivalent to
5303 // output=weights*input. So we only need to swap the order of
5304 // weights and input in the matrix product to correct for the
5305 // row-major versus column-major difference.
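    // Concretely (shapes for illustration): with n = batch count, k = input
    // nodes, and m = output nodes, the row-major matrices are input (n x k),
    // weights (k x m), and output (n x m). Reinterpreted as column-major they
    // appear transposed, so BLAS is asked to compute
    //   output^T (m x n) = weights^T (m x k) * input^T (k x n),
    // which is the gemm(NoTranspose, NoTranspose, m, n, k, ...) call below.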
5306 const int64_t m = output_dimensions.NodesAcrossFeatureMaps();
5307 const int64_t n = input_dimensions.count();
5308 const int64_t k = input_dimensions.NodesAcrossFeatureMaps();
5309 if (!stream
5310 ->ThenBlasGemm(blas::Transpose::kNoTranspose,
5311 blas::Transpose::kNoTranspose, m, n, k, weights, m,
5312 input_data, k, output_data, m)
5313 .ok()) {
5314 return false;
5315 }
5316 } else {
5317 // This is a slower and more complex path that supports output
5318 // width() * height() > 1, though it only supports the
5319 // kBatchYXDepth layout. Does support kBatchDepthYX if output
5320 // feature_map_count() == 1, as then there is no difference
5321 // between the two layouts.
5322 //
5323 // The operation here is the same as above, except that we have to
5324 // do the matrix multiplication for each (y,x) output coordinate
5325 // separately. We then interpret weights as containing K = width()
5326 // * height() different matrices, which we all multiply onto the
5327 // matrix from input_data, yielding K matrix products. We then
5328 // combine these together into one matrix by concatenating all the
5329     // first rows of these matrices, then all the second rows, and so
5330 // on. We can do this with a batched matrix multiplication, where
5331 // the result is written to a different submatrix of the output
5332 // for each matrix multiplication.
5333 //
5334 // The reason that we only support the kBatchYXDepth output layout
5335 // is that we have to do something in the depth for each (y,x)
5336 // coordinate. The kBatchYXDepth layout has the depth information
5337 // for each point (y,x) in contiguous memory while the
5338 // kBatchDepthYX layout does not.
5339 //
5340 // TODO(broune): Consider a special case for when output depth ==
5341 // 1, as then possibly this could all be done as one matrix
5342 // multiplication instead of a batched one, which should be
5343 // faster. Another possibility would be to add a weights layout
5344 // parameter and then support kBatchDepthYX for a different
5345 // weights layout.
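    // A rough sketch of the resulting layout: with batch_count = width * height
    // output positions, GEMM i reads the m x k weights block at element offset
    // i * k * m, multiplies it by the shared k x n input, and writes its m x n
    // result at element offset i * m with ldc = m * batch_count, so each output
    // position's depth slice lands in its own band of rows, matching the
    // kBatchYXDepth interleaving described above.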
5346 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5347 !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
5348 output_dimensions.feature_map_count() == 1)) {
5349 LOG(ERROR) << "Unsupported MatMul output layout.";
5350 return false;
5351 }
5352
5353 const float alpha = 1.0f; // Take the matrix product without scaling it.
5354 const float beta = 0.0f; // Ignore the original values in output_data.
5355 const uint64 m = output_dimensions.feature_map_count();
5356 const uint64 n = input_dimensions.count();
5357 const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
5358 const int lda = m;
5359 const int ldb = k;
5360 const int ldc = output_dimensions.NodesAcrossFeatureMaps();
5361 const int batch_count = output_dimensions.NodesPerFeatureMap();
5362
5363 std::vector<DeviceMemory<float>> a(batch_count);
5364 std::vector<DeviceMemory<float>> b(batch_count);
5365 std::vector<DeviceMemory<float>> c(batch_count);
5366 for (int i = 0; i < batch_count; ++i) {
5367 const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
5368 output_dimensions.feature_map_count();
5369 a[i] = DeviceMemory<float>::MakeFromByteSize(
5370 const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
5371 weights_offset,
5372 weights.ElementCount() - weights_offset);
5373
5374 b[i] = input_data;
5375
5376 const int output_offset = i * output_dimensions.feature_map_count();
5377 c[i] = DeviceMemory<float>::MakeFromByteSize(
5378 const_cast<float*>(
5379 reinterpret_cast<const float*>(output_data->opaque())) +
5380 output_offset,
5381 output_data->ElementCount() - output_offset);
5382 }
5383 const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
5384 std::vector<DeviceMemory<float>*> ptrs;
5385 ptrs.reserve(v.size());
5386 for (auto& mem : v) {
5387 ptrs.push_back(&mem);
5388 }
5389 return ptrs;
5390 };
5391
5392 stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
5393 blas::Transpose::kNoTranspose, m, n, k, alpha,
5394 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
5395 ldc, batch_count);
5396 }
5397
5398 return stream->ok();
5399 }
5400
5401 bool CudnnSupport::DoBiasAdd(Stream* stream,
5402 const DeviceMemory<float>& input_data,
5403 const DeviceMemory<float>& biases,
5404 const dnn::BatchDescriptor& dimensions,
5405 DeviceMemory<float>* output_data) {
5406 CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
5407
5408 dnn::BatchDescriptor bias_dimensions;
5409 bias_dimensions.set_count(1)
5410 .set_feature_map_count(dimensions.feature_map_count())
5411 .set_height(1)
5412 .set_width(1)
5413 .set_layout(dnn::DataLayout::kBatchYXDepth);
5414 CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
5415
5416 // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
5417 // output_data before doing the addition, unless the input and
5418 // output are at the same address.
5419 if (input_data.opaque() != output_data->opaque()) {
5420 stream->ThenMemcpy(output_data, input_data,
5421 dimensions.ElementCount() * sizeof(float));
5422 if (!stream->ok()) {
5423 LOG(ERROR)
5424 << "stream " << stream
5425 << " could not enqueue a tensor copy as part of bias addition.";
5426 return false;
5427 }
5428 }
5429
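  // cudnnAddTensor computes C = alpha * A + beta * C, so alpha = beta = 1 adds
  // the broadcast bias tensor onto the input values that were just copied into
  // output_data.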
5430 const float alpha = 1.0f;
5431 const float beta = 1.0f;
5432
5433 auto cudnn = cudnn_->GetHandle(parent_, stream);
5434
5435 const auto status = [&] {
5436 RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
5437 cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
5438 &beta, input_descriptor.handle(), output_data->opaque()));
5439 return port::Status::OK();
5440 }();
5441 return IsStatusOk(status, /*report_error=*/true);
5442 }
5443
5444 bool CudnnSupport::DoActivate(Stream* stream,
5445 dnn::ActivationMode activation_mode,
5446 const dnn::BatchDescriptor& dimensions,
5447 const DeviceMemory<float>& input_data,
5448 DeviceMemory<float>* output_data,
5449 uint64 options) {
5450 CudnnActivationDescriptor activation_desc(
5451 activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
5452
5453 CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
5454 // Alpha is the input scaling factor.
5455 float alpha = 1.0;
5456 // Beta is the output scaling factor.
5457 float beta = 0.0;
5458
5459 auto cudnn = cudnn_->GetHandle(parent_, stream);
5460 const auto status = [&] {
5461 RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
5462 cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
5463 input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
5464 return port::Status::OK();
5465 }();
5466 return IsStatusOk(status, /*report_error=*/true);
5467 }
5468
5469 bool CudnnSupport::DoPoolForward(
5470 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5471 const dnn::BatchDescriptor& input_dimensions,
5472 const DeviceMemory<double>& input_data,
5473 const dnn::BatchDescriptor& output_dimensions,
5474 DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
5475 // Alpha is the scaling factor for input.
5476 double alpha = 1.0;
5477 // Beta is the scaling factor for output.
5478 double beta = 0.0;
5479
5480 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
5481 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
5482 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5483
5484 auto cudnn = cudnn_->GetHandle(parent_, stream);
5485 const auto status = [&] {
5486 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
5487 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
5488 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
5489 return port::Status::OK();
5490 }();
5491 return IsStatusOk(status, /*report_error=*/true);
5492 }
5493
5494 bool CudnnSupport::DoPoolForward(
5495 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5496 const dnn::BatchDescriptor& input_dimensions,
5497 const DeviceMemory<float>& input_data,
5498 const dnn::BatchDescriptor& output_dimensions,
5499 DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
5500 // Alpha is the scaling factor for input.
5501 float alpha = 1.0;
5502 // Beta is the scaling factor for output.
5503 float beta = 0.0;
5504
5505 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
5506 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
5507 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5508
5509 auto cudnn = cudnn_->GetHandle(parent_, stream);
5510 const auto status = [&] {
5511 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
5512 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
5513 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
5514 return port::Status::OK();
5515 }();
5516 return IsStatusOk(status, /*report_error=*/true);
5517 }
5518
5519 bool CudnnSupport::DoPoolForward(
5520 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5521 const dnn::BatchDescriptor& input_dimensions,
5522 const DeviceMemory<Eigen::half>& input_data,
5523 const dnn::BatchDescriptor& output_dimensions,
5524 DeviceMemory<Eigen::half>* output_data,
5525 ScratchAllocator* workspace_allocator) {
5526 // Alpha is the scaling factor for input.
5527 float alpha = 1.0;
5528 // Beta is the scaling factor for output.
5529 float beta = 0.0;
5530
5531 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
5532 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
5533 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5534 auto cudnn = cudnn_->GetHandle(parent_, stream);
5535 const auto status = [&] {
5536 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
5537 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
5538 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
5539 return port::Status::OK();
5540 }();
5541 return IsStatusOk(status, /*report_error=*/true);
5542 }
5543
5544 bool CudnnSupport::DoPoolForward(
5545 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5546 const dnn::BatchDescriptor& input_dimensions,
5547 const DeviceMemory<int8>& input_data,
5548 const dnn::BatchDescriptor& output_dimensions,
5549 DeviceMemory<int8>* output_data, ScratchAllocator* workspace_allocator) {
5550 // Alpha is the scaling factor for input.
5551 float alpha = 1.0;
5552 // Beta is the scaling factor for output.
5553 float beta = 0.0;
5554
5555 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_INT8);
5556 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_INT8);
5557 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5558
5559 auto cudnn = cudnn_->GetHandle(parent_, stream);
5560 const auto status = [&] {
5561 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
5562 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
5563 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
5564 return port::Status::OK();
5565 }();
5566 return IsStatusOk(status, /*report_error=*/true);
5567 }
5568
5569 bool CudnnSupport::DoPoolBackward(
5570 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5571 const dnn::BatchDescriptor& input_dimensions,
5572 const DeviceMemory<double>& input_data,
5573 const dnn::BatchDescriptor& output_dimensions,
5574 const DeviceMemory<double>& output_data,
5575 const DeviceMemory<double>& input_diff_data,
5576 DeviceMemory<double>* output_diff_data,
5577 ScratchAllocator* workspace_allocator) {
5578 // Alpha is the scaling factor for input.
5579 double alpha = 1.0;
5580 // Beta is the scaling factor for output.
5581 double beta = 0.0;
5582
5583 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
5584 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
5585 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5586
5587 auto cudnn = cudnn_->GetHandle(parent_, stream);
5588 const auto status = [&] {
5589 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
5590 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
5591 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
5592 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
5593 output_diff_data->opaque()));
5594 return port::Status::OK();
5595 }();
5596 return IsStatusOk(status, /*report_error=*/true);
5597 }
5598
5599 bool CudnnSupport::DoPoolBackward(
5600 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5601 const dnn::BatchDescriptor& input_dimensions,
5602 const DeviceMemory<float>& input_data,
5603 const dnn::BatchDescriptor& output_dimensions,
5604 const DeviceMemory<float>& output_data,
5605 const DeviceMemory<float>& input_diff_data,
5606 DeviceMemory<float>* output_diff_data,
5607 ScratchAllocator* workspace_allocator) {
5608 // Alpha is the scaling factor for input.
5609 float alpha = 1.0;
5610 // Beta is the scaling factor for output.
5611 float beta = 0.0;
5612
5613 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
5614 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
5615 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5616
5617 auto cudnn = cudnn_->GetHandle(parent_, stream);
5618 const auto status = [&] {
5619 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
5620 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
5621 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
5622 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
5623 output_diff_data->opaque()));
5624 return port::Status::OK();
5625 }();
5626 return IsStatusOk(status, /*report_error=*/true);
5627 }
5628
5629 bool CudnnSupport::DoPoolBackward(
5630 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
5631 const dnn::BatchDescriptor& input_dimensions,
5632 const DeviceMemory<Eigen::half>& input_data,
5633 const dnn::BatchDescriptor& output_dimensions,
5634 const DeviceMemory<Eigen::half>& output_data,
5635 const DeviceMemory<Eigen::half>& input_diff_data,
5636 DeviceMemory<Eigen::half>* output_diff_data,
5637 ScratchAllocator* workspace_allocator) {
5638 // Alpha is the scaling factor for input.
5639 float alpha = 1.0;
5640 // Beta is the scaling factor for output.
5641 float beta = 0.0;
5642
5643 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
5644 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
5645 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
5646
5647 auto cudnn = cudnn_->GetHandle(parent_, stream);
5648 const auto status = [&] {
5649 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
5650 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
5651 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
5652 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
5653 output_diff_data->opaque()));
5654 return port::Status::OK();
5655 }();
5656 return IsStatusOk(status, /*report_error=*/true);
5657 }
5658
5659 bool CudnnSupport::DoNormalizeWithDimensions(
5660 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
5661 const dnn::BatchDescriptor& dimensions,
5662 const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
5663 // Check for unsupported modes.
5664 if (normalize_descriptor.wrap_around()) {
5665     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
5666 return false;
5667 }
5668 if (normalize_descriptor.segment_size()) {
5669 LOG(ERROR) << "CUDA LRN does not support segmentation";
5670 return false;
5671 }
5672
5673 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
5674 CudnnNormalizeDescriptor normalize(normalize_descriptor);
5675
5676 // Alpha is the scaling factor for input.
5677 float alpha = 1.0f;
5678 // Beta is the scaling factor for output.
5679 float beta = 0.0f;
5680
5681 auto cudnn = cudnn_->GetHandle(parent_, stream);
5682
5683 // Launch the normalization.
5684 const auto status = [&] {
5685 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
5686 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
5687 &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
5688 output_data->opaque()));
5689 return port::Status::OK();
5690 }();
5691 return IsStatusOk(status, /*report_error=*/true);
5692 }
5693
5694 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
5695 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
5696 const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
5697 const DeviceMemory<float>& normalized_data,
5698 const DeviceMemory<float>& normalized_variable_gradient,
5699 DeviceMemory<float>* raw_variable_gradient,
5700 ScratchAllocator* workspace_allocator) {
5701 // Check for unsupported modes.
5702 if (normalize_descriptor.wrap_around()) {
5703     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
5704 return false;
5705 }
5706 if (normalize_descriptor.segment_size()) {
5707 LOG(ERROR) << "CUDA LRN does not support segmentation";
5708 return false;
5709 }
5710
5711 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
5712 CudnnNormalizeDescriptor normalize(normalize_descriptor);
5713
5714 float alpha = 1.0f;
5715 float beta = 0.0f;
5716
5717 auto cudnn = cudnn_->GetHandle(parent_, stream);
5718 const auto status = [&] {
5719 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
5720 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
5721 &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
5722 normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
5723 &beta, dims.handle(), raw_variable_gradient->opaque()));
5724 return port::Status::OK();
5725 }();
5726 return IsStatusOk(status, /*report_error=*/true);
5727 }
5728
5729 bool CudnnSupport::DoDepthConcatenate(
5730 Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
5731 port::ArraySlice<const DeviceMemory<float>*> input_data,
5732 DeviceMemory<float>* output_data) {
5733 CHECK_EQ(input_dimensions.size(), input_data.size());
5734
5735 for (const auto& dimensions : input_dimensions) {
5736 if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5737 LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
5738 "supports the kBatchDepthYX layout.";
5739 return false;
5740 }
5741 }
5742
5743 if (input_dimensions.empty()) {
5744 return true; // Nothing to do.
5745 }
5746
5747 dnn::BatchDescriptor output_dimensions =
5748 dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
5749
5750 const int64_t area = output_dimensions.width() * output_dimensions.height();
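  // index() returns the flat offset of element (batch, depth, yx) in a
  // kBatchDepthYX tensor of depth max_depth:
  //   (batch * max_depth + depth) * (height * width) + yx.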
5751 const auto index = [area](int64_t batch, int64_t depth, int64_t yx,
5752 int64_t max_depth) {
5753 return (batch * max_depth + depth) * area + yx;
5754 };
5755
5756 std::vector<float> output_host(output_dimensions.ElementCount());
5757 std::vector<float> tmp;
5758 int64_t depth_sum = 0;
5759 for (size_t i = 0; i < input_data.size(); ++i) {
5760 const auto& dimensions = input_dimensions[i];
5761 tmp.resize(dimensions.ElementCount());
5762 stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
5763 port::Status block_status = stream->BlockHostUntilDone();
5764 if (!block_status.ok()) {
5765 LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
5766 return false;
5767 }
5768
5769 for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) {
5770 for (int64_t yx = 0; yx < area; ++yx) {
5771 for (int64_t depth = 0; depth < dimensions.feature_map_count();
5772 ++depth) {
5773 LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
5774 << yx << ' ' << depth;
5775 output_host[index(batch, depth + depth_sum, yx,
5776 output_dimensions.feature_map_count())] =
5777 tmp[index(batch, depth, yx, dimensions.feature_map_count())];
5778 }
5779 }
5780 }
5781 depth_sum += dimensions.feature_map_count();
5782 }
5783 stream->ThenMemcpyH2D<float>(output_host, output_data);
5784 return true;
5785 }
5786
5787 bool CudnnSupport::DoElementwiseOperate(
5788 Stream* stream, dnn::ElementwiseOperation operation,
5789 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
5790 port::ArraySlice<const DeviceMemory<float>*> input_data,
5791 const dnn::BatchDescriptor& output_dimensions,
5792 DeviceMemory<float>* output_data) {
5793 LOG(FATAL) << "not yet implemented"; // TODO(leary)
5794 return false;
5795 }
5796
5797 bool CudnnSupport::DoXYPad(Stream* stream,
5798 const dnn::BatchDescriptor& dimensions,
5799 const DeviceMemory<float>& input_data,
5800 int64_t left_pad, int64_t right_pad, int64_t top_pad,
5801 int64_t bottom_pad,
5802 DeviceMemory<float>* output_data) {
5803 LOG(FATAL) << "not yet implemented"; // TODO(leary)
5804 return false;
5805 }
5806
5807 bool CudnnSupport::DoXYSlice(Stream* stream,
5808 const dnn::BatchDescriptor& dimensions,
5809 const DeviceMemory<float>& input_data,
5810 int64_t left_trim, int64_t right_trim,
5811 int64_t top_trim, int64_t bottom_trim,
5812 DeviceMemory<float>* output_data) {
5813 LOG(FATAL) << "not yet implemented"; // TODO(leary)
5814 return false;
5815 }
5816
5817 bool CudnnSupport::DoMemcpyD2HQuantized(
5818 Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
5819 dnn::QuantizedActivationMode mode, void* host_dst, int64_t size) {
5820 LOG(ERROR) << "quantized memcpy not supported by cuDNN";
5821 return false;
5822 }
5823
5824 bool CudnnSupport::DoMemcpyH2DQuantized(
5825 Stream* stream, const void* host_src, int64_t size,
5826 dnn::QuantizedActivationMode mode,
5827 DeviceMemory<float>* gpu_unquantized_dst) {
5828 LOG(ERROR) << "quantized memcpy not supported by cuDNN";
5829 return false;
5830 }
5831
5832 bool CudnnSupport::DeriveOutputBatchDescriptor(
5833 const dnn::BatchDescriptor& batch_descriptor,
5834 const dnn::FilterDescriptor& filter_descriptor,
5835 const dnn::ConvolutionDescriptor& convolution_descriptor,
5836 dnn::BatchDescriptor* output_batch_descriptor) {
5837 CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
5838 CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
5839 CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
5840
5841 int dn = batch_descriptor.ndims() + 2;
5842 std::vector<int> dims(dn); // in BDYX
5843 const auto status = [&] {
5844 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
5845 conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
5846 output_batch_descriptor->set_count(dims[0])
5847 .set_feature_map_count(dims[1])
5848 .set_layout(batch_descriptor.layout());
5849
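    // dims comes back in batch, depth, spatial... order, while DimIndex counts
    // spatial dimensions from the innermost one, so the spatial sizes are read
    // back to front via rbegin().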
5850 for (int i = 0; i < batch_descriptor.ndims(); i++) {
5851 output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
5852 dims.rbegin()[i]);
5853 }
5854 return port::Status::OK();
5855 }();
5856 return IsStatusOk(status, /*report_error=*/true);
5857 }
5858
5859 } // namespace gpu
5860
5861 void initialize_cudnn() {
5862 port::Status status =
5863 PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
5864 cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
5865 [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
5866 gpu::GpuExecutor* cuda_executor =
5867 dynamic_cast<gpu::GpuExecutor*>(parent);
5868 if (cuda_executor == nullptr) {
5869 LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
5870 << "support library with a non-CUDA StreamExecutor";
5871 return nullptr;
5872 }
5873
5874 gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
5875 if (!dnn->Init().ok()) {
5876 // Note: Init() will log a more specific error.
5877 delete dnn;
5878 return nullptr;
5879 }
5880 return dnn;
5881 });
5882
5883 if (!status.ok()) {
5884 LOG(ERROR) << "Unable to register cuDNN factory: "
5885 << status.error_message();
5886 }
5887
5888 PluginRegistry::Instance()->SetDefaultFactory(
5889 cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
5890 }
5891
5892 } // namespace stream_executor
5893
5894 #pragma clang diagnostic pop
5895
5896 REGISTER_MODULE_INITIALIZER(register_cudnn,
5897 { stream_executor::initialize_cudnn(); });
5898