1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h"
17
18 #include <functional>
19 #include <memory>
20 #include <optional>
21 #include <utility>
22
23 #include "absl/base/thread_annotations.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "third_party/eigen3/Eigen/Core"
29 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
30 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.h"
31 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.h"
32 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.h"
33 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h"
34 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h"
35 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_timer.h"
36 #include "tensorflow/compiler/xla/stream_executor/cuda/cudnn_version.h"
37 #include "tensorflow/compiler/xla/stream_executor/dnn.h"
38 #include "tensorflow/compiler/xla/stream_executor/lib/env.h"
39 #include "tensorflow/compiler/xla/stream_executor/lib/error.h"
40 #include "tensorflow/compiler/xla/stream_executor/lib/initialize.h"
41 #include "tensorflow/compiler/xla/stream_executor/lib/mathutil.h"
42 #include "tensorflow/compiler/xla/stream_executor/lib/threadpool.h"
43 #include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
44 #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
45 #include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h"
46 #include "tensorflow/compiler/xla/stream_executor/stream.h"
47 #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
48 #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/platform/tensor_float_32_utils.h"
51 #include "tensorflow/core/util/determinism.h"
52 #include "tensorflow/core/util/env_var.h"
53 #include "tensorflow/core/util/use_cudnn.h"
54 // clang-format off
55 #include "third_party/gpus/cudnn/cudnn.h"
56 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
57 #include "third_party/cudnn_frontend/include/cudnn_frontend.h"
58 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
59 #include "absl/strings/string_view.h"
60 // clang-format on
61
62 #pragma clang diagnostic push
63
64 // Make sure that Eigen::half forward declaration in dnn.h matches the
65 // declaration in Eigen.
66 #pragma clang diagnostic warning "-Wmismatched-tags"
67
68 namespace stream_executor {
69 namespace gpu {
70
71 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
72
73 namespace {
74
75 static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher");
76
77 // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
78 #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
79
80 // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
81 // function with a non-successful port::Status.
82 #define RETURN_IF_CUDNN_ERROR(expr) \
83 do { \
84 cudnnStatus_t _status = (expr); \
85 if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \
86 std::ostringstream oss; \
87 oss << CudnnStatusToString(_status) << "\nin " << __FILE__ << "(" \
88 << __LINE__ << "): '" << #expr << "'"; \
89 return port::Status(port::error::UNKNOWN, oss.str()); \
90 } \
91 } while (false)
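// Illustrative usage (editor's sketch, not part of the build): the macro above
// is intended for functions that return port::Status, e.g.
//
//   port::Status SetCudnnStream(cudnnHandle_t handle, cudaStream_t stream) {
//     RETURN_IF_CUDNN_ERROR(cudnnSetStream(handle, stream));
//     return ::tensorflow::OkStatus();
//   }
//
// On failure the returned status carries the cuDNN status string plus the
// file, line, and stringified expression. (SetCudnnStream is a hypothetical
// helper used only for illustration.)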
92
93 #define RETURN_MSG_IF_CUDNN_ERROR(expr) \
94 do { \
95 cudnnStatus_t _status = (expr).get_status(); \
96 if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \
97 std::ostringstream oss; \
98 oss << CudnnStatusToString(_status) << "\nin " << __FILE__ << "(" \
99 << __LINE__ << "): '" << #expr << "' " << (expr).get_error(); \
100 return port::Status(port::error::UNKNOWN, oss.str()); \
101 } \
102 } while (false)
103
104 #define RETURN_FALSE_IF_CUDNN_ERROR(expr) \
105 do { \
106 if (!SE_PREDICT_TRUE((expr).get_status() == CUDNN_STATUS_SUCCESS)) { \
107 return false; \
108 } \
109 } while (false)
110
111 // Converts (via narrowing) a WideT value to a NarrowT value, and checks
112 // that the value is unchanged by the conversion.
113 template <typename WideT, typename NarrowT>
114 NarrowT CheckedNarrowing(const WideT& wide) {
115 NarrowT narrow = wide;
116 CHECK_EQ(narrow, wide)
117 << "checked narrowing failed; values not equal post-conversion";
118 return narrow;
119 }
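// For example (illustrative), CheckedNarrowing<int64_t, int>(5) returns 5,
// whereas a value outside the int range (e.g. int64_t{1} << 40) would trip
// the CHECK above instead of silently wrapping.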
120
121 std::string CudnnStatusToString(cudnnStatus_t status) {
122 switch (status) {
123 case CUDNN_STATUS_SUCCESS:
124 return "CUDNN_STATUS_SUCCESS";
125 case CUDNN_STATUS_NOT_INITIALIZED:
126 return "CUDNN_STATUS_NOT_INITIALIZED";
127 case CUDNN_STATUS_ALLOC_FAILED:
128 return "CUDNN_STATUS_ALLOC_FAILED";
129 case CUDNN_STATUS_BAD_PARAM:
130 return "CUDNN_STATUS_BAD_PARAM";
131 case CUDNN_STATUS_INTERNAL_ERROR:
132 return "CUDNN_STATUS_INTERNAL_ERROR";
133 case CUDNN_STATUS_INVALID_VALUE:
134 return "CUDNN_STATUS_INVALID_VALUE";
135 case CUDNN_STATUS_ARCH_MISMATCH:
136 return "CUDNN_STATUS_ARCH_MISMATCH";
137 case CUDNN_STATUS_MAPPING_ERROR:
138 return "CUDNN_STATUS_MAPPING_ERROR";
139 case CUDNN_STATUS_EXECUTION_FAILED:
140 return "CUDNN_STATUS_EXECUTION_FAILED";
141 case CUDNN_STATUS_NOT_SUPPORTED:
142 return "CUDNN_STATUS_NOT_SUPPORTED";
143 case CUDNN_STATUS_LICENSE_ERROR:
144 return "CUDNN_STATUS_LICENSE_ERROR";
145 case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
146 return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
147 case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
148 return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
149 case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
150 return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
151 default:
152 return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
153 ">");
154 }
155 }
156
157 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
158 //
159 // See CudnnAccess::GetHandle() for details.
160 class CudnnHandle {
161 public:
162 // Takes ownership of the executor context and of the lock used to access
163 // cuDNN through the handle.
164 CudnnHandle(gpu::ScopedActivateExecutorContext context,
165 std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
166 : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
167
168 // Returns the cuDNN handle. Pass it directly to cuDNN APIs; do not keep
169 // a copy.
170 cudnnHandle_t handle() const { return handle_; }
171
172 private:
173 gpu::ScopedActivateExecutorContext context_;
174 std::unique_ptr<absl::MutexLock> lock_;
175 cudnnHandle_t handle_; // Not owned.
176 };
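// Illustrative call pattern (sketch only): callers obtain a CudnnHandle from
// CudnnAccess::GetHandle() and pass handle() straight into cuDNN, e.g.
//
//   auto cudnn = cudnn_->GetHandle(parent_, stream);
//   RETURN_IF_CUDNN_ERROR(cudnnSomeCall(cudnn.handle(), ...));
//
// The mutex and CUDA context are released when `cudnn` goes out of scope.
// (cudnnSomeCall is a placeholder, not a real cuDNN API.)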
177
178 } // namespace
179
180 // Wraps a cuDNN handle and provides access to it through CudnnHandle
181 // instances, which also locks a mutex, acquires the CUDA context, and sets
182 // the stream that cuDNN should use to enqueue any work.
183 //
184 // Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
185 class CudnnAccess {
186 public:
187 // Takes ownership of the handle.
188 explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
189
190 ~CudnnAccess() {
191 absl::MutexLock lock(&mutex_);
192 cudnnDestroy(handle_);
193 }
194
195 // Creates a CudnnHandle instance for stream.
196 //
197 // cuDNN API calls using the same handle instance need to be serialized
198 // across threads. This is guaranteed by CudnnHandle instances locking the
199 // mutex owned by this class.
200 //
201 // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
202 // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
203 // to use the provided stream.
204 //
205 // The stream argument may be null, which translates to the legacy default
206 // stream. See
207 // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
208 // The legacy default stream synchronizes with all other streams and it is
209 // therefore a bad idea (performance-wise) to call any cuDNN APIs that
210 // enqueue work in the stream.
211 CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
212 auto lock = std::make_unique<absl::MutexLock>(&mutex_);
213 mutex_.AssertHeld();
214 gpu::ScopedActivateExecutorContext context(executor);
215 CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
216 if (!current_stream_ || cu_stream != *current_stream_) {
217 current_stream_ = cu_stream;
218 const auto status = cudnnSetStream(handle_, cu_stream);
219 CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
220 }
221 return CudnnHandle(std::move(context), std::move(lock), handle_);
222 }
223
224 void NotifyStreamDestroyed(Stream* stream) {
225 CUstream cu_stream = AsGpuStreamValue(stream);
226 absl::MutexLock lock(&mutex_);
227 if (current_stream_ && cu_stream == *current_stream_) {
228 current_stream_.reset();
229 }
230 }
231
232 private:
233 // Guards current_stream_ and the enqueueing of cuDNN operations via the
234 // handle_ below.
235 absl::Mutex mutex_;
236
237 // If set, indicates the stream currently active on handle_, to avoid the
238 // overhead of re-setting the same stream unnecessarily.
239 std::optional<CUstream> current_stream_ ABSL_GUARDED_BY(mutex_);
240
241 // cuDNN library handle.
242 cudnnHandle_t handle_ ABSL_GUARDED_BY(mutex_); // Owned.
243 };
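// Design note (editorial): current_stream_ is only a cache that avoids
// redundant cudnnSetStream() calls; NotifyStreamDestroyed() clears it,
// presumably so a stale CUstream value is never matched against a later
// stream that happens to reuse the same address.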
244
245 namespace {
246
247 // A helper function to return the internal compute type for
248 // RNNs in cudnn.
249 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
250
251 cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
252 cudnnConvolutionFwdAlgo_t algo =
253 cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
254 switch (algo) {
255 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
256 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
257 case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
258 case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
259 case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
260 case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
261 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
262 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
263 return algo;
264 default:
265 LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
266 << algorithm.algo_id();
267 }
268 }
269
270 cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
271 dnn::AlgorithmDesc algorithm) {
272 cudnnConvolutionBwdDataAlgo_t algo =
273 cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
274 switch (algo) {
275 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
276 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
277 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
278 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
279 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
280 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
281 return algo;
282 default:
283 LOG(FATAL)
284 << "Unsupported Cudnn convolution backward algorithm for data: "
285 << algorithm.algo_id();
286 }
287 }
288
289 cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
290 dnn::AlgorithmDesc algorithm) {
291 cudnnConvolutionBwdFilterAlgo_t algo =
292 cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
293 switch (algo) {
294 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
295 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
296 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
297 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
298 // Based on cudnn.h, the following is not implemented.
299 // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
300 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
301 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
302 return algo;
303 default:
304 LOG(FATAL)
305 << "Unsupported Cudnn convolution backward algorithm for filter: "
306 << algorithm.algo_id();
307 }
308 }
309
310 port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
311 int value;
312 RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
313 return value;
314 }
315
316 cudnnRNNAlgo_t ToCudnnRNNAlgo(std::optional<dnn::AlgorithmDesc> algorithm) {
317 if (!algorithm.has_value()) {
318 return CUDNN_RNN_ALGO_STANDARD;
319 }
320 cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
321 switch (algo) {
322 case CUDNN_RNN_ALGO_STANDARD:
323 case CUDNN_RNN_ALGO_PERSIST_STATIC:
324 case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
325 return algo;
326 default:
327 LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
328 }
329 }
330
331 port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
332 TF_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
333 TF_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
334 TF_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
335 return ::tensorflow::OkStatus();
336 }
337
338 enum class PreloadCudnnType { ConvFwd, ConvBwdFilter, ConvBwdData, Rnn };
339
340 // Preloads cuDNN sub-libraries for cuDNN 8.0.4+ so that their loading time
341 // isn't measured during autotuning.
342 void PreloadCudnnSubLibs(PreloadCudnnType type) {
343 #if CUDNN_VERSION >= 8004
344 switch (type) {
345 case PreloadCudnnType::ConvBwdFilter:
346 case PreloadCudnnType::ConvBwdData: {
347 cudnnOpsTrainVersionCheck();
348 cudnnCnnTrainVersionCheck();
349 [[clang::fallthrough]];
350 }
351 case PreloadCudnnType::ConvFwd: {
352 cudnnOpsInferVersionCheck();
353 cudnnCnnInferVersionCheck();
354 break;
355 }
356 case PreloadCudnnType::Rnn: {
357 cudnnOpsInferVersionCheck();
358 cudnnAdvInferVersionCheck();
359 cudnnOpsTrainVersionCheck();
360 cudnnAdvTrainVersionCheck();
361 break;
362 }
363 }
364 #endif // CUDNN_VERSION >= 8004
365 }
366
367 void PreloadCudnnSubLibsHelper(dnn::ConvolutionKind kind) {
368 switch (kind) {
369 case dnn::ConvolutionKind::FORWARD: {
370 PreloadCudnnSubLibs(PreloadCudnnType::ConvFwd);
371 break;
372 }
373 case dnn::ConvolutionKind::BACKWARD_DATA: {
374 PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdData);
375 break;
376 }
377 case dnn::ConvolutionKind::BACKWARD_FILTER: {
378 PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdFilter);
379 break;
380 }
381 default: {
382 LOG(WARNING) << "Unsupported dnn::ConvolutionKind: "
383 << static_cast<int>(kind) << " for cuDNN preload.";
384 break;
385 }
386 }
387 }
388
389 } // namespace
390
391 CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}
392
393 port::Status CudnnSupport::Init() {
394 ScopedActivateExecutorContext context(parent_);
395
396 // Peek at the last error to give more information in cases of errors.
397 cudaError_t cerr = cudaPeekAtLastError();
398 if (cerr != cudaSuccess) {
399 LOG(WARNING) << "There was an error before creating cudnn handle: "
400 << cudaGetErrorName(cerr) << " : " << cudaGetErrorString(cerr);
401 }
402
403 cudnnHandle_t cudnn_handle = nullptr;
404 const auto status = cudnnCreate(&cudnn_handle);
405 if (status == CUDNN_STATUS_SUCCESS) {
406 CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
407
408 CudnnVersion loaded_version;
409 TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
410 if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
411 const std::string error = absl::StrCat(
412 "Loaded runtime CuDNN library: ", loaded_version.ToString(),
413 " but source was compiled with: ", source_version.ToString(),
414 ". CuDNN library needs to have matching major version and equal or "
415 "higher minor version. If using a binary install, upgrade your CuDNN "
416 "library. If building from sources, make sure the library loaded at "
417 "runtime is compatible with the version specified during compile "
418 "configuration.");
419 LOG(ERROR) << error;
420 cudnnDestroy(cudnn_handle);
421 return port::Status(port::error::INTERNAL, error);
422 }
423
424 cudnn_.reset(new CudnnAccess(cudnn_handle));
425
426 LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion();
427 return ::tensorflow::OkStatus();
428 }
429
430 CHECK_EQ(cudnn_handle, nullptr);
431 LOG(ERROR) << "Could not create cudnn handle: "
432 << CudnnStatusToString(status);
433 if (status == CUDNN_STATUS_NOT_INITIALIZED) {
434 auto result = gpu::Diagnostician::FindKernelDriverVersion();
435 if (!result.ok()) {
436 LOG(ERROR) << "Error retrieving driver version: "
437 << cuda::DriverVersionStatusToString(result);
438 } else {
439 const auto& version = result.ValueOrDie();
440 LOG(ERROR) << "Possibly insufficient driver version: "
441 << cuda::DriverVersionToString(version);
442 }
443 }
444
445 return port::Status(port::error::INTERNAL,
446 absl::StrCat("cudnn library could not create a handle: ",
447 CudnnStatusToString(status)));
448 }
449
450 void CudnnSupport::NotifyStreamDestroyed(Stream* stream) /* override */ {
451 cudnn_->NotifyStreamDestroyed(stream);
452 }
453
454 port::StatusOr<perftools::gputools::dnn::VersionInfo>
455 CudnnSupport::GetVersion() {
456 CudnnVersion version;
457 TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
458 return perftools::gputools::dnn::VersionInfo(
459 version.major_version, version.minor_version, version.patch_level);
460 }
461
462 namespace {
463
464 // Deleter functors for cuDNN types that need to be deleted.
465 struct TensorDescriptorDeleter {
466 void operator()(cudnnTensorDescriptor_t descriptor) const {
467 CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
468 }
469 };
470 struct RNNDataDescriptorDeleter {
471 void operator()(cudnnRNNDataDescriptor_t descriptor) const {
472 CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
473 }
474 };
475 struct FilterDescriptorDeleter {
476 void operator()(cudnnFilterDescriptor_t descriptor) const {
477 CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
478 }
479 };
480 struct ConvolutionDescriptorDeleter {
481 void operator()(cudnnConvolutionDescriptor_t descriptor) const {
482 CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
483 }
484 };
485 struct PoolingDescriptorDeleter {
486 void operator()(cudnnPoolingDescriptor_t descriptor) const {
487 CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
488 }
489 };
490 struct LrnDescriptorDeleter {
491 void operator()(cudnnLRNDescriptor_t descriptor) const {
492 CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
493 }
494 };
495
496 struct ActivationDescriptorDeleter {
497 void operator()(cudnnActivationDescriptor_t descriptor) const {
498 CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
499 }
500 };
501 struct DropoutDescriptorDeleter {
502 void operator()(cudnnDropoutDescriptor_t descriptor) const {
503 CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
504 }
505 };
506 struct RnnDescriptorDeleter {
507 void operator()(cudnnRNNDescriptor_t descriptor) const {
508 CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
509 }
510 };
511 struct PersistentRnnPlanDeleter {
512 void operator()(cudnnPersistentRNNPlan_t plan) const {
513 CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
514 }
515 };
516 #if CUDNN_VERSION >= 7603
517 struct CtcLossDescriptorDeleter {
518 void operator()(cudnnCTCLossDescriptor_t descriptor) const {
519 CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
520 }
521 };
522 #endif
523
524 // RAII wrappers for cuDNN types.
525 using TensorDescriptor =
526 std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
527 using RNNDataDescriptor =
528 std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
529 using FilterDescriptor =
530 std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
531 using ConvolutionDescriptor =
532 std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
533 using PoolingDescriptor =
534 std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
535 using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
536 using ActivationDescriptor =
537 std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
538 using DropoutDescriptor =
539 std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
540 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
541 using PersistentRnnPlan =
542 std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
543 #if CUDNN_VERSION >= 7603
544 using CtcLossDescriptor =
545 std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
546 #endif
547
548 // Factory methods for cuDNN types.
549 TensorDescriptor CreateTensorDescriptor() {
550 cudnnTensorDescriptor_t result;
551 CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
552 return TensorDescriptor(result);
553 }
554 RNNDataDescriptor CreateRNNDataDescriptor() {
555 cudnnRNNDataDescriptor_t result;
556 CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
557 return RNNDataDescriptor(result);
558 }
559 FilterDescriptor CreateFilterDescriptor() {
560 cudnnFilterDescriptor_t result;
561 CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
562 return FilterDescriptor(result);
563 }
564 ConvolutionDescriptor CreateConvolutionDescriptor() {
565 cudnnConvolutionDescriptor_t result;
566 CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
567 return ConvolutionDescriptor(result);
568 }
569 PoolingDescriptor CreatePoolingDescriptor() {
570 cudnnPoolingDescriptor_t result;
571 CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
572 return PoolingDescriptor(result);
573 }
574 LrnDescriptor CreateLrnDescriptor() {
575 cudnnLRNDescriptor_t result;
576 CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
577 return LrnDescriptor(result);
578 }
579 ActivationDescriptor CreateActivationDescriptor() {
580 cudnnActivationDescriptor_t result;
581 CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
582 return ActivationDescriptor(result);
583 }
584 DropoutDescriptor CreateDropoutDescriptor() {
585 cudnnDropoutDescriptor_t result;
586 CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
587 return DropoutDescriptor(result);
588 }
589 RnnDescriptor CreateRnnDescriptor() {
590 cudnnRNNDescriptor_t result;
591 CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
592 return RnnDescriptor(result);
593 }
594 #if CUDNN_VERSION >= 7603
595 CtcLossDescriptor CreateCtcLossDescriptor() {
596 cudnnCTCLossDescriptor_t result;
597 CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
598 return CtcLossDescriptor(result);
599 }
600 #endif
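// Illustrative pattern (sketch only): the factories above pair with the RAII
// aliases, so raw descriptors never need manual destruction, e.g.
//
//   TensorDescriptor desc = CreateTensorDescriptor();
//   CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(desc.get(), CUDNN_TENSOR_NCHW,
//                                             CUDNN_DATA_FLOAT, 1, 1, 1, 1));
//   // cudnnDestroyTensorDescriptor() runs automatically when `desc` dies.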
601
602 port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
603 cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
604 cudnnPersistentRNNPlan_t result;
605 RETURN_IF_CUDNN_ERROR(
606 cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
607 return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
608 }
609
610 // Turns a BatchDescriptor structure into a cudnn tensor handle within a
611 // scope.
612 class CudnnTensorDescriptor {
613 public:
614 CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
615 cudnnDataType_t elem_type)
616 : handle_(CreateTensorDescriptor()) {
617 switch (batch_descriptor.layout()) {
618 case dnn::DataLayout::kBatchYXDepth:
619 case dnn::DataLayout::kBatchDepthYX: {
620 const int nd = batch_descriptor.ndims() + 2;
621 // cuDNN requires the strides and dims to be ordered as BDYX.
622 std::vector<int64_t> strides64 =
623 batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
624 std::vector<int64_t> dims64 =
625 batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
626
627 // cuDNN requires arrays of ints.
628 std::vector<int> strides(nd);
629 std::vector<int> dims(nd);
630 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
631 &CheckedNarrowing<int64_t, int>);
632 std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
633 &CheckedNarrowing<int64_t, int>);
634 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
635 dims.data(), strides.data()))
636 << "batch_descriptor: " << batch_descriptor.ToString();
637 break;
638 }
639 case dnn::DataLayout::kBatchDepthYX4:
640 case dnn::DataLayout::kBatchDepthYX32: {
641 auto expected_elem_ty =
642 batch_descriptor.layout() == dnn::DataLayout::kBatchDepthYX4
643 ? CUDNN_DATA_INT8x4
644 : CUDNN_DATA_INT8x32;
645 CHECK_EQ(elem_type, expected_elem_ty);
646 CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
647 handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
648 batch_descriptor.count(), batch_descriptor.feature_map_count(),
649 batch_descriptor.height(), batch_descriptor.width()))
650 << "batch_descriptor: " << batch_descriptor.ToString();
651 break;
652 }
653 default:
654 LOG(FATAL) << "Unsupported tensor format "
655 << DataLayoutString(batch_descriptor.layout());
656 break;
657 }
658 }
659
660 cudnnTensorDescriptor_t handle() const { return handle_.get(); }
661
662 private:
663 TensorDescriptor handle_;
664 };
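// Illustrative usage (sketch only):
//
//   dnn::BatchDescriptor batch_desc;  // populated by the caller
//   CudnnTensorDescriptor input_nd(batch_desc, CUDNN_DATA_FLOAT);
//   // input_nd.handle() can now be passed to cuDNN tensor APIs.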
665
666 // Turns a FilterDescriptor structure into a cudnn filter handle within a
667 // scope.
668 class CudnnFilterDescriptor {
669 public:
670 CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
671 cudnnDataType_t elem_type)
672 : handle_(CreateFilterDescriptor()) {
673 // TODO(b/23032134): Even if the filter layout is not supported,
674 // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
675 // it does not take layout as an input. Maybe force cuDNN by giving wrong
676 // inputs intentionally?
677 cudnnTensorFormat_t format;
678 switch (filter_descriptor.layout()) {
679 case dnn::FilterLayout::kOutputInputYX:
680 format = CUDNN_TENSOR_NCHW;
681 break;
682 case dnn::FilterLayout::kOutputYXInput:
683 format = CUDNN_TENSOR_NHWC;
684 break;
685 case dnn::FilterLayout::kOutputInputYX4:
686 case dnn::FilterLayout::kOutputInputYX32: {
687 auto expected_elem_ty =
688 filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX4
689 ? CUDNN_DATA_INT8x4
690 : CUDNN_DATA_INT8x32;
691 CHECK_EQ(elem_type, expected_elem_ty);
692 format = CUDNN_TENSOR_NCHW_VECT_C;
693 break;
694 }
695 default:
696 LOG(FATAL) << "Unsupported filter format "
697 << FilterLayoutString(filter_descriptor.layout());
698 break;
699 }
700
701 std::vector<int> dims(2 + filter_descriptor.ndims());
702 dims[0] = filter_descriptor.output_feature_map_count();
703 dims[1] = filter_descriptor.input_feature_map_count();
704 absl::Span<const int64_t> spatial_dims =
705 filter_descriptor.input_filter_dims();
706 std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
707
708 CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
709 dims.size(), dims.data()));
710 }
711
712 cudnnFilterDescriptor_t handle() const { return handle_.get(); }
713
714 private:
715 FilterDescriptor handle_; // Owned.
716 };
717
718 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
719 // The errata sheet (JSON format) for marking the cudnn engines that might be
720 // buggy. For example, we don't want the engine 999 of forward convolution:
721 // R"({ "version" : 1,
722 // "rules" : [
723 // { "rule_id" : "ConvFwd_eng999",
724 // "operation" : "ConvFwd",
725 // "engine" : 999,
726 // "knob" : [],
727 // "cudnn_version_start" : 8000,
728 // "cudnn_version_end" : -1
729 // }
730 // ]})"
731 // We skip a non-existing eng999 in the static filter as a placeholder.
732 // Additionally, users can specify an additional errata JSON file via
733 // CUDNN_ERRATA_JSON_FILE at runtime.
734 // We are also excluding two flavors of ConvFwd_eng42 due to b/234183340.
735 const json* CudnnExecutionPlanEngineFilterStatic() {
736 static absl::string_view filter_str = R"({
737 "version" : 1,
738 "rules" : [
739 { "rule_id" : "ConvFwd_eng999",
740 "operation" : "ConvFwd",
741 "engine" : 999,
742 "knob" : [],
743 "cudnn_version_start" : 8000,
744 "cudnn_version_end" : -1
745 },
746 { "rule_id" : "ConvFwd_eng42_k2=2_k4=3_k5=0_k6=0_k7=0",
747 "operation" : "ConvFwd",
748 "engine" : 42,
749 "knob" :
750 {
751 "k2" : "2",
752 "k4" : "3",
753 "k5" : "0",
754 "k6" : "0",
755 "k7" : "0"
756 },
757 "cudnn_version_start" : 8000,
758 "cudnn_version_end" : -1
759 },
760 { "rule_id" : "ConvFwd_eng42_k2=1_k4=3_k5=1_k6=0_k7=0",
761 "operation" : "ConvFwd",
762 "engine" : 42,
763 "knob" :
764 {
765 "k2" : "1",
766 "k4" : "3",
767 "k5" : "1",
768 "k6" : "0",
769 "k7" : "0"
770 },
771 "cudnn_version_start" : 8000,
772 "cudnn_version_end" : -1
773 }
774 ]})";
775 static const json* json_handle = new json(json::parse(filter_str));
776 return json_handle;
777 }
778
779 const json* CudnnExecutionPlanEngineFilterRuntime() {
780 static const json* json_handle = []() -> const json* {
781 json j;
782 if (cudnn_frontend::load_from_config(j, "")) {
783 return new json(j);
784 }
785 return nullptr;
786 }();
787 return json_handle;
788 }
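// Illustrative (environment-specific, not verified here): the runtime filter
// is typically enabled by pointing cudnn_frontend at a JSON file with the
// same schema as the static filter above, e.g.
//
//   export CUDNN_ERRATA_JSON_FILE=/path/to/errata.json
//
// before starting the process; when no such file is configured this function
// returns nullptr and only the static filter applies.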
789
790 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
791
792 // A helper function to decide whether to use
793 // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
794 // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
795 // and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
796 // reason we set it to false by default is that this mode may use scaled
797 // atomic integer reduction that may cause a numerical overflow for certain
798 // input data ranges.
799 // TODO(yangzihao): Use autotune to choose between this mode and
800 // CUDNN_BATCHNORM_SPATIAL mode.
801 bool BatchnormSpatialPersistentEnabled() {
802 static bool is_enabled = [] {
803 bool is_enabled = false;
804 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
805 "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
806 /*default_val=*/false, &is_enabled));
807 return is_enabled;
808 }();
809 return is_enabled;
810 }
811
812 bool RequireCudnnDeterminism() {
813 static bool require_cudnn_determinism = [] {
814 // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var.
815 bool cudnn_deterministic = false;
816 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
817 /*default_val=*/false,
818 &cudnn_deterministic));
819 return cudnn_deterministic;
820 }();
821 return tensorflow::OpDeterminismRequired() || require_cudnn_determinism;
822 }
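// Illustrative: determinism can therefore be requested either through
// TensorFlow's op-determinism setting or via the environment variable that is
// slated for removal (see TODO above), e.g. TF_CUDNN_DETERMINISTIC=1.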
823
824 // A helper function to decide whether to force the default conv algorithm.
825 bool ConvUseDefaultAlgorithm() {
826 static bool use_default = [] {
827 bool use_default = false;
828 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_USE_DEFAULT_CONV_ALGO",
829 /*default_val=*/false,
830 &use_default));
831 return use_default;
832 }();
833 return use_default;
834 }
835
836 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
837 // within a scope.
838 class CudnnConvolutionDescriptor {
839 public:
840 CudnnConvolutionDescriptor(
841 const dnn::ConvolutionDescriptor& convolution_descriptor,
842 cudnnDataType_t data_type)
843 : handle_(CreateConvolutionDescriptor()) {
844 absl::Span<const int64_t> strides64 = convolution_descriptor.strides();
845 absl::Span<const int64_t> padding64 = convolution_descriptor.padding();
846 absl::Span<const int64_t> dilations64 = convolution_descriptor.dilations();
847 CHECK_NE(convolution_descriptor.pad_alignment(),
848 dnn::PadAlignment::kTensorFlowPadding)
849 << "TensorFlow padding alignment is not supported.";
850
851 // cuDNN requires arrays of ints.
852 std::vector<int> strides(convolution_descriptor.ndims());
853 std::vector<int> padding(convolution_descriptor.ndims());
854 std::vector<int> dilations(convolution_descriptor.ndims());
855 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
856 &CheckedNarrowing<int64_t, int>);
857 std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
858 &CheckedNarrowing<int64_t, int>);
859 // TODO(yangzihao): Test with negative dilation to make sure that cudnn
860 // doesn't crash.
861 std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
862 &CheckedNarrowing<int64_t, int>);
863
864 CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
865 handle_.get(), convolution_descriptor.ndims(), padding.data(),
866 strides.data(), dilations.data(),
867 convolution_descriptor.convolution_not_crosscorr()
868 ? CUDNN_CONVOLUTION
869 : CUDNN_CROSS_CORRELATION,
870 data_type));
871
872 #if CUDNN_MAJOR >= 7
873 VLOG(2) << "Requesting grouped convolution: "
874 << convolution_descriptor.group_count();
875 CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
876 handle_.get(), convolution_descriptor.group_count()));
877 #else
878 CHECK_EQ(convolution_descriptor.group_count(), 1)
879 << "Requested grouped convolution for cuDNN version < 7";
880 #endif
881 }
882
883 void set_use_tensor_op_math(bool use_tensor_op_math) {
884 cudnnMathType_t math_type =
885 #if CUDNN_VERSION >= 8000
886 (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH);
887 #else
888 (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
889 #endif
890 CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
891 }
892
893 cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
894
895 private:
896 ConvolutionDescriptor handle_; // Owned.
897 };
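// Illustrative usage (sketch only):
//
//   CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
//   conv.set_use_tensor_op_math(
//       IsTensorMathEnabled(stream, dnn::DataType::kFloat));
//   // conv.handle() is then passed to the cuDNN convolution APIs.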
898
899 // A helper function to query whether a CudnnConvolutionDescriptor has
900 // tensor op math set.
901 static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
902 cudnnMathType_t math_type;
903 CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
904 #if CUDNN_VERSION >= 8000
905 return math_type != CUDNN_FMA_MATH;
906 #else
907 return math_type == CUDNN_TENSOR_OP_MATH;
908 #endif
909 }
910
911 static bool TensorOpMathAvailable(
912 CudaComputeCapability cuda_compute_capability) {
913 return cuda_compute_capability.IsAtLeast(7);
914 }
915
916 static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) {
917 if (!TensorOpMathAvailable(stream->GetCudaComputeCapability())) {
918 return false;
919 }
920 if (input_type == dnn::DataType::kFloat) {
921 #if CUDNN_VERSION < 8000
922 return false;
923 #else
924 if (!tensorflow::tensor_float_32_execution_enabled()) {
925 return false;
926 }
927 #endif
928 }
929 return true;
930 }
931
932 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
933 // within a scope.
934 class CudnnPoolingDescriptor {
935 public:
936 explicit CudnnPoolingDescriptor(
937 const dnn::PoolingDescriptor& pooling_descriptor)
938 : handle_(CreatePoolingDescriptor()) {
939 absl::Span<const int64_t> strides64 = pooling_descriptor.strides();
940 absl::Span<const int64_t> padding64 = pooling_descriptor.padding();
941 absl::Span<const int64_t> shape64 = pooling_descriptor.window();
942
943 const int nd = pooling_descriptor.ndims();
944 std::vector<int> shape(nd);
945 std::vector<int> padding(nd);
946 std::vector<int> strides(nd);
947 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
948 &CheckedNarrowing<int64_t, int>);
949 std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
950 &CheckedNarrowing<int64_t, int>);
951 std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
952 &CheckedNarrowing<int64_t, int>);
953 bool propagate_nans = pooling_descriptor.propagate_nans();
954 const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
955 ? CUDNN_POOLING_MAX_DETERMINISTIC
956 : CUDNN_POOLING_MAX;
957 CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
958 handle_.get(),
959 (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
960 ? cudnn_max_pooling_mode
961 : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
962 propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
963 shape.data(), padding.data(), strides.data()));
964 }
965
966 cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
967
968 private:
969 PoolingDescriptor handle_; // Owned.
970
971 SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
972 };
973
974 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
975 class CudnnNormalizeDescriptor {
976 public:
977 explicit CudnnNormalizeDescriptor(
978 const dnn::NormalizeDescriptor& normalize_descriptor)
979 : handle_(CreateLrnDescriptor()) {
980 // The range specifies that the indices in the closed range
981 // [i - range, i + range] should be included in the normalization for index
982 // i. The lrnN value is the total number of elements in the range, so
983 // lrnN = 2*range + 1.
984 unsigned lrnN = 2 * normalize_descriptor.range() + 1;
985
986 // Note that SE defines the normalization operation as
987 //
988 // U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
989 //
990 // but cuDNN defines it as
991 //
992 // U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
993 //
994 // i.e. there is a factor of n difference between the meaning of the alphas
995 // in the two contexts. The cuDNN alpha is n times the SE alpha.
996 double lrnAlpha = lrnN * normalize_descriptor.alpha();
997
998 double lrnBeta = normalize_descriptor.beta();
999 double lrnK = normalize_descriptor.bias();
1000 CHECK_CUDNN_OK(
1001 cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
1002 }
1003
1004 cudnnLRNDescriptor_t handle() const { return handle_.get(); }
1005
1006 private:
1007 LrnDescriptor handle_; // Owned.
1008
1009 SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
1010 };
1011
1012 // Turns an ActivationDescriptor structure into a cudnn activation
1013 // descriptor handle within a scope.
1014 class CudnnActivationDescriptor {
1015 public:
1016 CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
1017 cudnnNanPropagation_t nan_propagation,
1018 double value_max)
1019 : handle_(CreateActivationDescriptor()) {
1020 double relu_ceiling = 0.0;
1021 cudnnActivationMode_t mode;
1022 switch (activation_mode) {
1023 case dnn::ActivationMode::kNone:
1024 mode = CUDNN_ACTIVATION_IDENTITY;
1025 break;
1026 case dnn::ActivationMode::kRelu6:
1027 relu_ceiling = 6.0;
1028 mode = CUDNN_ACTIVATION_CLIPPED_RELU;
1029 break;
1030 case dnn::ActivationMode::kReluX:
1031 relu_ceiling = value_max;
1032 mode = CUDNN_ACTIVATION_CLIPPED_RELU;
1033 break;
1034 case dnn::ActivationMode::kRelu:
1035 mode = CUDNN_ACTIVATION_RELU;
1036 break;
1037 case dnn::ActivationMode::kSigmoid:
1038 mode = CUDNN_ACTIVATION_SIGMOID;
1039 break;
1040 case dnn::ActivationMode::kTanh:
1041 mode = CUDNN_ACTIVATION_TANH;
1042 break;
1043 default:
1044 LOG(FATAL) << "unrecognized activation mode: "
1045 << static_cast<int>(activation_mode);
1046 }
1047
1048 CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
1049 nan_propagation, relu_ceiling));
1050 }
1051
1052 cudnnActivationDescriptor_t handle() const { return handle_.get(); }
1053
1054 private:
1055 ActivationDescriptor handle_; // Owned.
1056 };
1057
1058 cudnnDataType_t ToCudnnDataType(
1059 dnn::DataType data_type,
1060 dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1061 switch (data_type) {
1062 case dnn::DataType::kFloat:
1063 return CUDNN_DATA_FLOAT;
1064 case dnn::DataType::kDouble:
1065 return CUDNN_DATA_DOUBLE;
1066 case dnn::DataType::kHalf:
1067 return CUDNN_DATA_HALF;
1068 case dnn::DataType::kInt8:
1069 switch (data_layout) {
1070 case dnn::DataLayout::kBatchDepthYX4:
1071 return CUDNN_DATA_INT8x4;
1072 case dnn::DataLayout::kBatchDepthYX32:
1073 return CUDNN_DATA_INT8x32;
1074 default:
1075 return CUDNN_DATA_INT8;
1076 }
1077 case dnn::DataType::kInt32:
1078 return CUDNN_DATA_INT32;
1079 #if CUDNN_VERSION >= 8200
1080 case dnn::DataType::kBF16:
1081 return CUDNN_DATA_BFLOAT16;
1082 #endif
1083 default:
1084 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1085 }
1086 }
1087
1088 cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
1089 dnn::FilterLayout filter_layout) {
1090 if (data_type == dnn::DataType::kInt8 &&
1091 filter_layout == dnn::FilterLayout::kOutputInputYX4) {
1092 return CUDNN_DATA_INT8x4;
1093 }
1094 if (data_type == dnn::DataType::kInt8 &&
1095 filter_layout == dnn::FilterLayout::kOutputInputYX32) {
1096 return CUDNN_DATA_INT8x32;
1097 }
1098 return ToCudnnDataType(data_type);
1099 }
1100
1101 template <typename T>
1102 cudnnDataType_t GetCudnnDataType(
1103 dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1104 return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
1105 }
1106
1107 template <typename T>
1108 cudnnDataType_t GetCudnnDataType(dnn::FilterLayout filter_layout) {
1109 return ToCudnnDataType(dnn::ToDataType<T>::value, filter_layout);
1110 }
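// For example (illustrative), GetCudnnDataType<float>() yields
// CUDNN_DATA_FLOAT, GetCudnnDataType<Eigen::half>() yields CUDNN_DATA_HALF,
// and an int8 filter in kOutputInputYX32 layout maps to CUDNN_DATA_INT8x32.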
1111
1112 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
1113 switch (input_mode) {
1114 case dnn::RnnInputMode::kRnnLinearSkip:
1115 case dnn::RnnInputMode::kRnnSkipInput:
1116 return static_cast<cudnnRNNInputMode_t>(input_mode);
1117 default:
1118 LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
1119 }
1120 }
1121
1122 cudnnDirectionMode_t ToCudnnRnnDirectionMode(
1123 dnn::RnnDirectionMode direction_mode) {
1124 switch (direction_mode) {
1125 case dnn::RnnDirectionMode::kRnnUnidirectional:
1126 case dnn::RnnDirectionMode::kRnnBidirectional:
1127 return static_cast<cudnnDirectionMode_t>(direction_mode);
1128 default:
1129 LOG(FATAL) << "Invalid RNN direction mode: "
1130 << static_cast<int>(direction_mode);
1131 }
1132 }
1133
1134 cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
1135 switch (rnn_mode) {
1136 case dnn::RnnMode::kRnnRelu:
1137 case dnn::RnnMode::kRnnTanh:
1138 case dnn::RnnMode::kRnnLstm:
1139 case dnn::RnnMode::kRnnGru:
1140 return static_cast<cudnnRNNMode_t>(rnn_mode);
1141 default:
1142 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1143 }
1144 }
1145
1146 int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
1147 switch (data_type) {
1148 case CUDNN_DATA_FLOAT:
1149 return sizeof(float);
1150 case CUDNN_DATA_DOUBLE:
1151 return sizeof(double);
1152 case CUDNN_DATA_HALF:
1153 return sizeof(Eigen::half);
1154 default:
1155 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1156 }
1157 }
1158
1159 class CudnnDropoutDescriptor {
1160 explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
1161 : handle_(std::move(handle)) {}
1162
1163 public:
1164 CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
1165
1166 static port::StatusOr<CudnnDropoutDescriptor> Create(
1167 const CudnnHandle& cudnn, float dropout, uint64_t seed,
1168 ScratchAllocator* state_allocator) {
1169 DropoutDescriptor handle = CreateDropoutDescriptor();
1170
1171 if (dropout == 0.0f) {
1172 // Return 'empty' dropout descriptor.
1173 return CudnnDropoutDescriptor(std::move(handle));
1174 }
1175
1176 DeviceMemory<uint8> state_memory;
1177 if (state_allocator) {
1178 size_t state_sizes_in_bytes = 0;
1179 RETURN_IF_CUDNN_ERROR(
1180 cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
1181 TF_ASSIGN_OR_RETURN(state_memory,
1182 state_allocator->AllocateBytes(state_sizes_in_bytes));
1183 }
1184 RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
1185 handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
1186 state_memory.size(), seed));
1187
1188 return CudnnDropoutDescriptor(std::move(handle));
1189 }
1190
1191 cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
1192
1193 private:
1194 DropoutDescriptor handle_; // Owned.
1195 SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
1196 };
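// Illustrative usage (sketch): the RNN setup below creates one of these via
//
//   TF_ASSIGN_OR_RETURN(CudnnDropoutDescriptor dropout_desc,
//                       CudnnDropoutDescriptor::Create(cudnn, dropout, seed,
//                                                      state_allocator));
//
// A dropout rate of 0.0f yields an empty descriptor with no state allocation.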
1197
1198 class CudnnRnnParamsDescriptor {
1199 typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1200
1201 CudnnRnnParamsDescriptor(FilterDescriptor handle,
1202 int64_t params_size_in_bytes, ParamsRegions weights,
1203 ParamsRegions biases)
1204 : handle_(std::move(handle)),
1205 params_size_in_bytes_(params_size_in_bytes),
1206 weights_(std::move(weights)),
1207 biases_(std::move(biases)) {}
1208
1209 public:
1210 CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
1211
1212 static port::StatusOr<CudnnRnnParamsDescriptor> Create(
1213 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1214 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1215 cudnnDirectionMode_t direction_mode, int num_layers);
1216
1217 cudnnFilterDescriptor_t handle() const { return handle_.get(); }
1218 int64_t params_size_in_bytes() const { return params_size_in_bytes_; }
1219 ParamsRegions params_weights() const { return weights_; }
1220 ParamsRegions params_biases() const { return biases_; }
1221
1222 private:
1223 FilterDescriptor handle_;
1224 int64_t params_size_in_bytes_;
1225 ParamsRegions weights_;
1226 ParamsRegions biases_;
1227 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
1228 };
1229
1230 } // namespace
1231
1232 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
1233 CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1234 PersistentRnnPlan rnn_plan, int num_layers,
1235 int hidden_size, int input_size, int cell_size,
1236 int batch_size, cudnnRNNInputMode_t input_mode,
1237 cudnnDirectionMode_t direction_mode,
1238 cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1239 cudnnDataType_t compute_type,
1240 const dnn::AlgorithmConfig& algorithm_config,
1241 CudnnDropoutDescriptor dropout_desc,
1242 CudnnRnnParamsDescriptor params_desc)
1243 : rnn_desc_(std::move(rnn_desc)),
1244 rnn_plan_(std::move(rnn_plan)),
1245 num_layers_(num_layers),
1246 hidden_size_(hidden_size),
1247 input_size_(input_size),
1248 cell_size_(cell_size),
1249 batch_size_(batch_size),
1250 rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1251 input_mode_(input_mode),
1252 direction_mode_(direction_mode),
1253 rnn_mode_(rnn_mode),
1254 data_type_(data_type),
1255 compute_type_(compute_type),
1256 algorithm_config_(algorithm_config),
1257 dropout_desc_(std::move(dropout_desc)),
1258 params_desc_(std::move(params_desc)) {}
1259
1260 public:
1261 CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1262
1263 static port::StatusOr<CudnnRnnDescriptor> Create(
1264 const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1265 int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
1266 cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1267 cudnnDataType_t data_type, cudnnDataType_t compute_type,
1268 const dnn::AlgorithmConfig& algorithm_config, float dropout,
1269 uint64_t seed, ScratchAllocator* state_allocator, bool use_padded_io) {
1270 TF_ASSIGN_OR_RETURN(
1271 CudnnDropoutDescriptor dropout_desc,
1272 CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1273
1274 gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1275 cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1276
1277 // TODO: allow the user to choose an algorithm.
1278 auto proj_size = hidden_size;
1279 hidden_size = std::max(hidden_size, cell_size);
1280
1281 // Require explicit algorithm config to enable tensor cores. Some configs
1282 // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1283 // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1284 // We can only reasonably expect the user to handle the subsequent failure
1285 // in profile mode, which is run with algorithms returned from
1286 // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1287 // use tensor ops). CuDNN 7.2.1 fixed this issue.
1288 // TODO(csigg): Minimal support cuDNN version is 7.3, clean up.
1289 bool allow_tensor_ops = data_type == CUDNN_DATA_HALF;
1290 if (data_type == CUDNN_DATA_FLOAT)
1291 allow_tensor_ops = tensorflow::tensor_float_32_execution_enabled();
1292 bool use_tensor_ops =
1293 algorithm_config.algorithm().has_value()
1294 ? algorithm_config.algorithm()->tensor_ops_enabled()
1295 : allow_tensor_ops;
1296 if (use_tensor_ops && !allow_tensor_ops) {
1297 return port::Status(port::error::INVALID_ARGUMENT,
1298 "Algo requests disallowed tensor op evaluation.");
1299 }
1300
1301 #if CUDNN_VERSION >= 8000
1302 cudnnMathType_t math_type =
1303 use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
1304 #else
1305 cudnnMathType_t math_type =
1306 use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
1307 #endif
1308
1309 #if CUDNN_VERSION >= 8000
1310 cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
1311 uint32_t aux_flags = 0;
1312 if (use_padded_io) aux_flags |= CUDNN_RNN_PADDED_IO_ENABLED;
1313 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v8(
1314 /*rnnDesc=*/rnn_desc.get(), /*algo=*/rnn_algo, /*cellMode=*/rnn_mode,
1315 /*biasMode=*/bias_mode, /*dirMode=*/direction_mode,
1316 /*inputMode=*/input_mode,
1317 /*dataType=*/data_type, /*mathPrec=*/compute_type,
1318 /*mathType=*/math_type,
1319 /*inputSize=*/input_size,
1320 /*hiddenSize=*/hidden_size, /*projSize=*/proj_size,
1321 /*numLayers=*/num_layers,
1322 /*dropoutDesc=*/dropout_desc.handle(),
1323 /*auxFlags=*/aux_flags));
1324 #else
1325 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1326 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1327 /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers,
1328 /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
1329 /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1330 /*dataType=*/compute_type));
1331 CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1332
1333 if (proj_size < hidden_size) {
1334 RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
1335 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1336 /*recProjSize=*/proj_size, /*outProjSize=*/0));
1337 }
1338
1339 // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
1340 // But in the future if these APIs are used to process full length arrays,
1341 // we need to distinguish when to set it.
1342 if (use_padded_io) {
1343 RETURN_IF_CUDNN_ERROR(
1344 cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1345 }
1346 #endif
1347
1348 port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1349 PersistentRnnPlan rnn_plan;
1350 if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1351 CHECK_GE(batch_size, 0);
1352 rnn_plan_wrapper =
1353 CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1354 if (!rnn_plan_wrapper.ok()) {
1355 return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1356 } else {
1357 rnn_plan = std::move(rnn_plan_wrapper).value();
1358 RETURN_IF_CUDNN_ERROR(
1359 cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1360 }
1361 }
1362
1363 // Create the params handle.
1364 // TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and
1365 // cudnnRNNForward***Ex are removed from the codebase, since the new API
1366 // doesn't need param descriptors any more.
1367 TF_ASSIGN_OR_RETURN(auto params_desc,
1368 CudnnRnnParamsDescriptor::Create(
1369 cudnn, input_size, data_type, rnn_desc.get(),
1370 rnn_mode, direction_mode, num_layers));
1371
1372 return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1373 num_layers, hidden_size, input_size, cell_size,
1374 batch_size, input_mode, direction_mode, rnn_mode,
1375 data_type, compute_type, algorithm_config,
1376 std::move(dropout_desc), std::move(params_desc));
1377 }
1378
1379 cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
1380 int num_layers() const { return num_layers_; }
1381 int hidden_size() const { return hidden_size_; }
1382 int input_size() const { return input_size_; }
1383 int cell_size() const { return cell_size_; }
1384 int batch_size() const { return batch_size_; }
1385 cudnnRNNInputMode_t input_mode() const { return input_mode_; }
1386 cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
1387 cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
1388 cudnnDataType_t data_type() const { return data_type_; }
1389 cudnnDataType_t compute_type() const { return compute_type_; }
1390 const dnn::AlgorithmConfig& algorithm_config() const {
1391 return algorithm_config_;
1392 }
1393 int64_t ParamsSizeInBytes() const override {
1394 return params_desc_.params_size_in_bytes();
1395 }
1396 cudnnFilterDescriptor_t params_handle() const {
1397 return params_desc_.handle();
1398 }
1399 ParamsRegions ParamsWeightRegions() const override {
1400 return params_desc_.params_weights();
1401 }
1402 ParamsRegions ParamsBiasRegions() const override {
1403 return params_desc_.params_biases();
1404 }
1405
1406 private:
1407 gpu::RnnDescriptor rnn_desc_;
1408 PersistentRnnPlan rnn_plan_;
1409 int num_layers_;
1410 int hidden_size_;
1411 int input_size_;
1412 // cell_size_ is the size of cell state, which will be different from
1413 // hidden_size_ if the projection is used.
1414 int cell_size_;
1415 // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1416 // algorithm.
1417 int batch_size_;
1418 cudnnRNNAlgo_t rnn_algo_;
1419 cudnnRNNInputMode_t input_mode_;
1420 cudnnDirectionMode_t direction_mode_;
1421 cudnnRNNMode_t rnn_mode_;
1422 cudnnDataType_t data_type_;
1423 cudnnDataType_t compute_type_;
1424 dnn::AlgorithmConfig algorithm_config_;
1425 CudnnDropoutDescriptor dropout_desc_;
1426 CudnnRnnParamsDescriptor params_desc_;
1427 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1428 };
1429
1430 #if CUDNN_VERSION >= 7603
1431 class CudnnCtcLossDescriptor {
1432 public:
1433   explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
1434 : handle_(CreateCtcLossDescriptor()) {
1435 CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
1436 /*ctcLossDesc=*/handle_.get(),
1437 /*compType=*/data_type,
1438 /*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
1439 /*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
1440 }
1441
1442   cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
1443
1444 private:
1445 CtcLossDescriptor handle_; // Owned
1446
1447 SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
1448 };
1449 #else
1450 // Dummy placeholder class used when CUDNN_VERSION < 7603, where the cuDNN
// CTC loss API is unavailable.
1451 class CudnnCtcLossDescriptor {
1452 public:
1453   CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
1454 };
1455 #endif
1456
1457 namespace {
1458
1459 // Checks whether the LSTM recurrent projection is in use. If it is, the
1460 // additional projection weight matrix is appended to 'weights'. Otherwise,
1461 // nothing is done.
1462 port::Status CheckAndFetchProjectionWeights(
1463 const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
1464 const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
1465 const FilterDescriptor& region_desc_handle,
1466 dnn::RnnDescriptor::ParamsRegions* weights) {
1467 int hidden_size_v;
1468 int num_layers_v;
1469 cudnnDropoutDescriptor_t dropout_desc;
1470 cudnnRNNInputMode_t input_mode;
1471 cudnnDirectionMode_t direction;
1472 cudnnRNNMode_t mode;
1473 cudnnRNNAlgo_t algo;
1474 cudnnDataType_t data_type;
1475 #if CUDNN_VERSION >= 8000
1476 RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
1477 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1478 /*hiddenSize=*/&hidden_size_v,
1479 /*numLayers=*/&num_layers_v,
1480 /*dropoutDesc=*/&dropout_desc,
1481 /*inputMode=*/&input_mode,
1482 /*direction=*/&direction,
1483 /*mode=*/&mode,
1484 /*algo=*/&algo,
1485 /*mathPrec=*/&data_type));
1486 #else
1487 RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
1488 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1489 /*hiddenSize=*/&hidden_size_v,
1490 /*numLayers=*/&num_layers_v,
1491 /*dropoutDesc=*/&dropout_desc,
1492 /*inputMode=*/&input_mode,
1493 /*direction=*/&direction,
1494 /*mode=*/&mode,
1495 /*algo=*/&algo,
1496 /*mathPrec=*/&data_type));
1497 #endif
1498 int rec_proj_size_v;
1499 int out_proj_size_v;
1500 RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
1501 /*handle=*/cudnn.handle(),
1502 /*rnnDesc=*/rnn_desc,
1503 /*recProjSize*/ &rec_proj_size_v,
1504 /*outProjSize*/ &out_proj_size_v));
1505 if (rec_proj_size_v != hidden_size_v) {
1506 void* offset = nullptr;
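    // In the legacy parameter-query API, the recurrent projection matrix of an
    // LSTM is (per the cuDNN documentation) exposed under linear-layer ID 8,
    // after the eight gate matrices with IDs 0-7.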
1507 int region_id = 8;
1508 RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
1509 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1510 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1511 /*wDesc=*/filter_desc.get(),
1512 /*w=*/nullptr, /*linLayerID=*/region_id,
1513 /*linLayerMatDesc=*/region_desc_handle.get(),
1514 /*linLayerMat or linLayerBias=*/&offset));
1515 int dims[] = {1, 1, 1};
1516 cudnnDataType_t data_type;
1517 cudnnTensorFormat_t tensor_format;
1518 int n_dims;
1519 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1520 /*filterDesc=*/region_desc_handle.get(),
1521 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1522 /*dataType=*/&data_type, /*format=*/&tensor_format,
1523 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1524 int64_t size =
1525 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1526 dnn::RnnDescriptor::ParamsRegion region = {
1527 reinterpret_cast<int64_t>(offset), size};
1528 weights->push_back(region);
1529 }
1530 return ::tensorflow::OkStatus();
1531 }
1532
1533 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1534 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1535 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1536 cudnnDirectionMode_t direction_mode, int num_layers) {
1537 // Query the params size.
1538 TensorDescriptor input_desc = CreateTensorDescriptor();
1539 int tensor_dims[] = {1, input_size, 1};
1540 int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1541 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1542 /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1543 /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1544 /*dimA=*/tensor_dims,
1545 /*strideA=*/strides));
1546
1547 size_t params_size = 0;
1548 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1549 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1550 /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size,
1551 /*dataType=*/data_type));
1552 int64_t params_size_in_bytes = static_cast<int64_t>(params_size);
1553
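  // Describe the whole opaque parameter buffer as a flat 1-D filter; its
  // length is the parameter size in bytes divided by the element size of the
  // chosen data type.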
1554 FilterDescriptor filter_desc = CreateFilterDescriptor();
1555 int64_t filter_dim0 =
1556 params_size_in_bytes / CudnnDataTypeToByteSize(data_type);
1557 int filter_dims[] = {static_cast<int>(filter_dim0), 1, 1};
1558 RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1559 /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1560 /*format=*/CUDNN_TENSOR_NCHW,
1561 /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1562 /*filterDimA=*/filter_dims));
1563
1564   // Locate the weight and bias regions inside the opaque params buffer.
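  // Each (pseudo-)layer exposes region_count_per_layer weight matrices and the
  // same number of bias vectors: 2 for the plain RNN cells (input and
  // recurrent), 8 for LSTM (4 gates x {input, recurrent}), and 6 for GRU
  // (3 gates x {input, recurrent}).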
1565 int region_count_per_layer = [&] {
1566 switch (rnn_mode) {
1567 case CUDNN_RNN_RELU:
1568 case CUDNN_RNN_TANH:
1569 return 2;
1570 case CUDNN_LSTM:
1571 return 8;
1572 case CUDNN_GRU:
1573 return 6;
1574 default:
1575 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1576 return 0;
1577 }
1578 }();
1579
1580 FilterDescriptor region_desc_handle = CreateFilterDescriptor();
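  // cuDNN treats each direction of a bidirectional RNN as its own
  // pseudo-layer, so a bidirectional model exposes twice as many layers of
  // parameters.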
1581 const int layer_count =
1582 direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1583
1584 ParamsRegions weights;
1585 ParamsRegions biases;
1586
1587 for (int layer = 0; layer < layer_count; layer++) {
1588 for (int region = 0; region < region_count_per_layer; region++) {
1589 for (int type = 0; type < 2; type++) {
1590 void* offset = nullptr;
1591 RETURN_IF_CUDNN_ERROR(
1592 type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1593 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1594 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1595 /*wDesc=*/filter_desc.get(),
1596 /*w=*/nullptr, /*linLayerID=*/region,
1597 /*linLayerMatDesc=*/region_desc_handle.get(),
1598 /*linLayerMat or linLayerBias=*/&offset)
1599 : cudnnGetRNNLinLayerBiasParams(
1600 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1601 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1602 /*wDesc=*/filter_desc.get(),
1603 /*w=*/nullptr, /*linLayerID=*/region,
1604 /*linLayerMatDesc=*/region_desc_handle.get(),
1605 /*linLayerMat or linLayerBias=*/&offset));
1606 int dims[] = {1, 1, 1};
1607 cudnnDataType_t data_type;
1608 cudnnTensorFormat_t tensor_format;
1609 int n_dims;
1610 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1611 /*filterDesc=*/region_desc_handle.get(),
1612 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1613 /*dataType=*/&data_type, /*format=*/&tensor_format,
1614 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1615 int64_t size =
1616 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1617 dnn::RnnDescriptor::ParamsRegion region = {
1618 reinterpret_cast<int64_t>(offset), size};
1619 (type == 0 ? weights : biases).push_back(region);
1620 }
1621 }
1622 TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
1623 cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
1624 &weights));
1625 }
1626
1627 return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1628 weights, biases);
1629 }
1630
1631 } // namespace
1632
1633 class CudnnRnnSequenceTensorDescriptor
1634 : public dnn::RnnSequenceTensorDescriptor {
1635   CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1636 int batch_size, int data_size,
1637 cudnnDataType_t data_type,
1638 RNNDataDescriptor data_handle,
1639 TensorDescriptor handle)
1640 : max_seq_length_(max_seq_length),
1641 batch_size_(batch_size),
1642 data_size_(data_size),
1643 data_type_(data_type),
1644 handle_(std::move(handle)),
1645 rnn_data_handle_(std::move(data_handle)),
1646 handles_(max_seq_length, handle_.get()) {}
1647
1648 public:
1649 CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1650 default;
1651
1652   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1653 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1654 cudnnDataType_t data_type) {
1655 if (max_seq_length <= 0) {
1656 return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1657 }
1658 int dims[] = {batch_size, data_size, 1};
1659 int strides[] = {dims[1] * dims[2], dims[2], 1};
1660 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1661 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1662 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1663 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1664 /*strideA=*/strides));
1665 return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1666 data_size, data_type, nullptr,
1667 std::move(tensor_desc));
1668 }
1669
1670   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1671 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1672 const absl::Span<const int>& seq_lengths, bool time_major,
1673 cudnnDataType_t data_type) {
1674 if (max_seq_length <= 0) {
1675 return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1676 }
1677 int dims[] = {batch_size, data_size, 1};
1678 int strides[] = {dims[1] * dims[2], dims[2], 1};
1679 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1680 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1681 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1682 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1683 /*strideA=*/strides));
1684 const int* seq_lengths_array = seq_lengths.data();
1685 RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1686 float padding_fill = 0.0f;
1687 cudnnRNNDataLayout_t layout;
1688 if (time_major) {
1689 layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1690 } else {
1691 layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1692 }
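    // Both unpacked layouts keep every sequence padded to max_seq_length;
    // padded output positions are filled with padding_fill (zero here).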
1693 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1694 /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1695 /*layout=*/layout,
1696 /*maxSeqLength=*/max_seq_length,
1697 /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1698 /*seqLengthArray=*/seq_lengths_array,
1699 /*paddingFill*/ (void*)&padding_fill));
1700 return CudnnRnnSequenceTensorDescriptor(
1701 parent, max_seq_length, batch_size, data_size, data_type,
1702 std::move(data_desc), std::move(tensor_desc));
1703 }
1704
1705   const cudnnTensorDescriptor_t* handles() const { return handles_.data(); }
1706   const cudnnRNNDataDescriptor_t data_handle() const {
1707 return rnn_data_handle_.get();
1708 }
1709
1710   int max_seq_length() const { return max_seq_length_; }
1711   int batch_size() const { return batch_size_; }
1712   int data_size() const { return data_size_; }
1713   bool is_var_seq_lengths() const { return rnn_data_handle_ != nullptr; }
1714
1715 private:
1716 int max_seq_length_;
1717 int batch_size_;
1718 int data_size_;
1719 cudnnDataType_t data_type_;
1720 TensorDescriptor handle_;
1721 RNNDataDescriptor rnn_data_handle_;
1722 std::vector<cudnnTensorDescriptor_t> handles_; // Copies of handle_.
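  // The legacy cudnnRNNForward*/cudnnRNNBackward* APIs expect one tensor
  // descriptor per time step; every step has the same shape, so the same
  // descriptor is simply replicated max_seq_length times.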
1723 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1724 };
1725
1726 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1727 public:
1728   CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1729 int batch_size, int data_size,
1730 cudnnDataType_t data_type)
1731 : handle_(CreateTensorDescriptor()),
1732 num_layers_(num_layers),
1733 batch_size_(batch_size),
1734 data_size_(data_size),
1735 data_type_(data_type) {
1736 int dims[] = {num_layers, batch_size, data_size};
1737 int strides[] = {dims[1] * dims[2], dims[2], 1};
1738 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1739 /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1740 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1741 /*strideA=*/strides));
1742 }
1743
1744   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1745
1746   int num_layers() const { return num_layers_; }
1747   int batch_size() const { return batch_size_; }
1748   int data_size() const { return data_size_; }
1749
1750 private:
1751 TensorDescriptor handle_;
1752 int num_layers_;
1753 int batch_size_;
1754 int data_size_;
1755 cudnnDataType_t data_type_;
1756 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1757 };
1758
1759 namespace {
1760
1761 struct RnnModelDims {
1762 int num_layers = 0;
1763 int batch_size = 0;
1764 int max_seq_length = 0;
1765 int hidden_size = 0;
1766 int input_size = 0;
1767 int cell_size = 0;
1768 int dir_count = 0;
1769 };
1770
1771 template <class T>
1772 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1773 const CudnnRnnDescriptor& rnn_desc,
1774 const CudnnRnnSequenceTensorDescriptor& input_desc,
1775 const DeviceMemory<T>& input_data,
1776 const CudnnRnnStateTensorDescriptor& input_h_desc,
1777 const DeviceMemory<T>& input_h_data,
1778 const CudnnRnnStateTensorDescriptor& input_c_desc,
1779 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1780 const CudnnRnnSequenceTensorDescriptor& output_desc,
1781 const DeviceMemory<T>& output_data,
1782 const CudnnRnnStateTensorDescriptor& output_h_desc,
1783 const DeviceMemory<T>& output_h_data,
1784 const CudnnRnnStateTensorDescriptor& output_c_desc,
1785 const DeviceMemory<T>& output_c_data) {
1786 // extract model parameters
1787 RnnModelDims model_dims;
1788 model_dims.num_layers = rnn_desc.num_layers();
1789 model_dims.batch_size = input_desc.batch_size();
1790 model_dims.max_seq_length = input_desc.max_seq_length();
1791 model_dims.hidden_size = rnn_desc.hidden_size();
1792 model_dims.input_size = input_desc.data_size();
1793 model_dims.cell_size = rnn_desc.cell_size();
1794 model_dims.dir_count =
1795 (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1796
1797 // check parameters
1798 if (!(input_h_desc.num_layers() ==
1799 model_dims.num_layers * model_dims.dir_count &&
1800 input_h_desc.batch_size() == model_dims.batch_size &&
1801 input_h_desc.data_size() == model_dims.hidden_size)) {
1802 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1803 }
1804 // The LSTM projection will be used if input_h_desc.data_size() <
1805 // input_c_desc.data_size()
1806 if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1807 input_h_desc.batch_size() == input_c_desc.batch_size() &&
1808 input_h_desc.data_size() <= input_c_desc.data_size())) {
1809 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1810 }
1811 if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1812 output_desc.batch_size() == model_dims.batch_size &&
1813 output_desc.data_size() ==
1814 model_dims.hidden_size * model_dims.dir_count)) {
1815 return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1816 }
1817 if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1818 input_h_desc.batch_size() == output_h_desc.batch_size() &&
1819 input_h_desc.data_size() == output_h_desc.data_size())) {
1820 return port::Status(port::error::INVALID_ARGUMENT,
1821 "Invalid output_h shape");
1822 }
1823 if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1824 input_h_desc.batch_size() == output_c_desc.batch_size() &&
1825 input_h_desc.data_size() <= output_c_desc.data_size())) {
1826 return port::Status(port::error::INVALID_ARGUMENT,
1827 "Invalid output_c shape");
1828 }
1829
1830 return model_dims;
1831 }
1832
1833 port::Status CheckRNNParameterSize(
1834 const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1835 const CudnnRnnSequenceTensorDescriptor& input_desc) {
1836 size_t params_size_in_bytes = 0;
1837 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1838 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize(
1839 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1840 /*sizeInBytes=*/¶ms_size_in_bytes));
1841 #else
1842 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1843 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1844 /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes,
1845 /*dataType=*/rnn_desc.data_type()));
1846 #endif
1847 if (static_cast<int64_t>(params_size_in_bytes) !=
1848 rnn_desc.ParamsSizeInBytes()) {
1849 return port::Status(port::error::INVALID_ARGUMENT,
1850 "Mismatching RNN parameter size");
1851 }
1852 return ::tensorflow::OkStatus();
1853 }
1854
1855 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1856 Stream* stream, const CudnnHandle& cudnn,
1857 const CudnnRnnDescriptor& rnn_desc,
1858 const CudnnRnnSequenceTensorDescriptor& input_desc,
1859 ScratchAllocator* workspace_allocator) {
1860 // Query the workspace size.
1861 size_t workspace_size_in_bytes = 0;
1862 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1863 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1864 /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1865 /*sizeInBytes=*/&workspace_size_in_bytes));
1866 // Allocate the workspace.
1867 if (workspace_size_in_bytes == 0) {
1868 return DeviceMemory<uint8>();
1869 }
1870 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1871 }
1872
1873 #if CUDNN_VERSION >= 7402
1874 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
1875 Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1876 const cudnnBatchNormOps_t& bn_ops,
1877 const cudnnActivationDescriptor_t& activation_desc,
1878 const CudnnTensorDescriptor& x_descriptor,
1879 const CudnnTensorDescriptor& scale_offset_descriptor,
1880 ScratchAllocator* workspace_allocator) {
1881 // Query the workspace size.
1882 size_t workspace_size_in_bytes = 0;
1883 RETURN_IF_CUDNN_ERROR(
1884 cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
1885 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1886 /*xDesc=*/x_descriptor.handle(), /*zDesc=*/x_descriptor.handle(),
1887 /*yDesc=*/x_descriptor.handle(),
1888 /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
1889 /*activationDesc=*/activation_desc,
1890 /*sizeInBytes=*/&workspace_size_in_bytes));
1891 // Allocate the workspace.
1892 if (workspace_size_in_bytes == 0) {
1893 return DeviceMemory<uint8>();
1894 }
1895 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1896 }
1897
1898 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
1899 Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1900 const cudnnBatchNormOps_t& bn_ops,
1901 const cudnnActivationDescriptor_t& activation_desc,
1902 const CudnnTensorDescriptor& x_descriptor,
1903 const CudnnTensorDescriptor& scale_offset_descriptor,
1904 ScratchAllocator* workspace_allocator) {
1905 // Query the workspace size.
1906 size_t workspace_size_in_bytes = 0;
1907 RETURN_IF_CUDNN_ERROR(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
1908 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1909 /*xDesc=*/x_descriptor.handle(),
1910 /*yDesc=*/x_descriptor.handle(),
1911 /*dyDesc=*/x_descriptor.handle(),
1912 /*dzDesc=*/x_descriptor.handle(),
1913 /*dxDesc=*/x_descriptor.handle(),
1914 /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
1915 /*activationDesc=*/activation_desc,
1916 /*sizeInBytes=*/&workspace_size_in_bytes));
1917 // Allocate the workspace.
1918 if (workspace_size_in_bytes == 0) {
1919 return DeviceMemory<uint8>();
1920 }
1921 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1922 }
1923
1924 #endif
1925
1926 } // namespace
1927
1928 template <class T>
1929 port::Status CudnnSupport::DoRnnForwardImpl(
1930 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1931 const CudnnRnnSequenceTensorDescriptor& input_desc,
1932 const DeviceMemory<T>& input_data,
1933 const DeviceMemory<int>& seq_lengths_data,
1934 const CudnnRnnStateTensorDescriptor& input_h_desc,
1935 const DeviceMemory<T>& input_h_data,
1936 const CudnnRnnStateTensorDescriptor& input_c_desc,
1937 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1938 const CudnnRnnSequenceTensorDescriptor& output_desc,
1939 DeviceMemory<T>* output_data,
1940 const CudnnRnnStateTensorDescriptor& output_h_desc,
1941 DeviceMemory<T>* output_h_data,
1942 const CudnnRnnStateTensorDescriptor& output_c_desc,
1943 DeviceMemory<T>* output_c_data, bool is_training,
1944 ScratchAllocator* reserve_space_allocator,
1945 ScratchAllocator* workspace_allocator,
1946 dnn::ProfileResult* output_profile_result) {
1947 TF_ASSIGN_OR_RETURN(
1948 RnnModelDims model_dims,
1949 ExtractAndCheckRnnForward(
1950 rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1951 input_c_desc, input_c_data, params, output_desc, *output_data,
1952 output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1953
1954 auto cudnn = cudnn_->GetHandle(parent_, stream);
1955
1956 TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1957
1958   // In cuDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex calls were
1959   // deprecated. We therefore use cudnnRNNForward, which requires the
1960   // sequence_lengths parameter. For more info, see
1961   // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802.
1962 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
1963 if (input_desc.is_var_seq_lengths()) {
1964 DeviceMemory<uint8> workspace;
1965 DeviceMemory<uint8> reserve_space;
1966 cudnnForwardMode_t rnn_fwd_mode;
1967 if (is_training) {
1968 rnn_fwd_mode = CUDNN_FWD_MODE_TRAINING;
1969 } else {
1970 rnn_fwd_mode = CUDNN_FWD_MODE_INFERENCE;
1971 }
1972 size_t reserve_space_size_in_bytes = 0;
1973 size_t workspace_size_in_bytes = 0;
1974 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
1975 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1976 /*fMode=*/rnn_fwd_mode, /*xDesc=*/input_desc.data_handle(),
1977 /*workSpaceSize=*/&workspace_size_in_bytes,
1978 /*reserveSpaceSize=*/&reserve_space_size_in_bytes));
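    // One query returns both sizes: the workspace is scratch memory for this
    // call only, while the reserve space must outlive the forward pass so the
    // backward pass can reuse it (it is only needed in training mode).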
1979
1980 if (workspace_size_in_bytes > 0) {
1981 TF_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
1982 workspace_size_in_bytes));
1983 }
1984 if (reserve_space_size_in_bytes > 0) {
1985 TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1986 reserve_space_size_in_bytes));
1987 }
1988
1989 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1990 const bool is_profiling = output_profile_result != nullptr;
1991 if (is_profiling) {
1992 timer.reset(new GpuTimer(parent_));
1993       // The start and stop of the timer should be as close to the cuDNN
1994       // call as possible. Other threads may still issue work onto this
1995       // stream, so the caller may need to take multiple profiling measurements.
1996 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1997 return port::Status(port::error::INTERNAL, "Failed to start timer");
1998 }
1999 }
2000
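    // Note: unlike the legacy *Ex calls, whose sequence lengths come from the
    // host array stored in the RNN data descriptor, cudnnRNNForward takes the
    // per-example sequence lengths through devSeqLengths, which must reside in
    // device memory (hence DeviceMemory<int> here).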
2001 RETURN_IF_CUDNN_ERROR(cudnnRNNForward(
2002 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2003 /*fwdMode=*/rnn_fwd_mode,
2004 /*devSeqLengths=*/
2005 reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2006 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2007 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data->opaque(),
2008 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2009 /*hy=*/output_h_data->opaque(),
2010 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2011 /*cy=*/output_c_data->opaque(),
2012 /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2013 /*weightSpace=*/params.opaque(),
2014 /*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(),
2015 /*reserveSpaceSizeInBytes=*/reserve_space.size(),
2016 /*reserveSpace=*/reserve_space.opaque()));
2017
2018 if (is_profiling) {
2019 if (!timer->Stop(AsGpuStream(stream))) {
2020 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2021 }
2022 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2023 output_profile_result->set_algorithm(algo_desc);
2024 output_profile_result->set_elapsed_time_in_ms(
2025 timer->GetElapsedMilliseconds());
2026 }
2027 return port::Status::OK();
2028 }
2029 #endif
2030 TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
2031 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
2032 workspace_allocator));
2033
2034 // query the reserve space size
2035 // allocate the reserve space
2036 DeviceMemory<uint8> reserve_space;
2037 if (is_training) {
2038 size_t reserve_space_size_in_bytes = 0;
2039 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
2040 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2041 /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
2042 /*sizeInBytes=*/&reserve_space_size_in_bytes));
2043
2044 if (reserve_space_size_in_bytes > 0) {
2045 TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
2046 reserve_space_size_in_bytes));
2047 }
2048 }
2049
2050 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2051 const bool is_profiling = output_profile_result != nullptr;
2052 if (is_profiling) {
2053 timer.reset(new GpuTimer(parent_));
2054     // The start and stop of the timer should be as close to the cuDNN call
2055     // as possible. Other threads may still issue work onto this stream, so
2056     // the caller may need to take multiple profiling measurements.
2057 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2058 return port::Status(port::error::INTERNAL, "Failed to start timer");
2059 }
2060 }
2061
2062 if (!is_training) {
2063 if (input_desc.is_var_seq_lengths()) {
2064 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
2065 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2066 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2067 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2068 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2069 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2070 /*yDesc=*/output_desc.data_handle(),
2071 /*y=*/output_data->opaque(),
2072 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
2073 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
2074 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
2075 nullptr,
2076 /*workspace=*/workspace.opaque(),
2077 /*workSpaceSizeInBytes=*/workspace.size()));
2078 } else {
2079 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
2080 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2081 /*seqLength=*/model_dims.max_seq_length,
2082 /*xDesc=*/input_desc.handles(),
2083 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2084 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
2085 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
2086 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
2087 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
2088 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
2089 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
2090 /*workSpaceSizeInBytes=*/workspace.size()));
2091 }
2092 } else {
2093 if (input_desc.is_var_seq_lengths()) {
2094 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
2095 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2096 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2097 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2098 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2099 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2100 /*yDesc=*/output_desc.data_handle(),
2101 /*y=*/output_data->opaque(),
2102 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
2103 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
2104 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
2105 nullptr,
2106 /*workspace=*/workspace.opaque(),
2107 /*workSpaceSizeInBytes=*/workspace.size(),
2108 /*reserveSpace=*/reserve_space.opaque(),
2109 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
2110 } else {
2111 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
2112 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2113 /*seqLength=*/model_dims.max_seq_length,
2114 /*xDesc=*/input_desc.handles(),
2115 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2116 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
2117 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
2118 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
2119 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
2120 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
2121 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
2122 /*workSpaceSizeInBytes=*/workspace.size(),
2123 /*reserveSpace=*/reserve_space.opaque(),
2124 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
2125 }
2126 }
2127
2128 if (is_profiling) {
2129 if (!timer->Stop(AsGpuStream(stream))) {
2130 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2131 }
2132 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2133 output_profile_result->set_algorithm(algo_desc);
2134 output_profile_result->set_elapsed_time_in_ms(
2135 timer->GetElapsedMilliseconds());
2136 }
2137
2138 return ::tensorflow::OkStatus();
2139 }
2140
2141 template <class T>
2142 port::Status CudnnSupport::DoRnnBackwardImpl(
2143 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
2144 const CudnnRnnSequenceTensorDescriptor& input_desc,
2145 const DeviceMemory<T>& input_data,
2146 const DeviceMemory<int>& seq_lengths_data,
2147 const CudnnRnnStateTensorDescriptor& input_h_desc,
2148 const DeviceMemory<T>& input_h_data,
2149 const CudnnRnnStateTensorDescriptor& input_c_desc,
2150 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2151 const CudnnRnnSequenceTensorDescriptor& output_desc,
2152 const DeviceMemory<T>& output_data,
2153 const CudnnRnnStateTensorDescriptor& output_h_desc,
2154 const DeviceMemory<T>& output_h_data,
2155 const CudnnRnnStateTensorDescriptor& output_c_desc,
2156 const DeviceMemory<T>& output_c_data,
2157 const DeviceMemory<T>& output_backprop_data,
2158 const DeviceMemory<T>& output_h_backprop_data,
2159 const DeviceMemory<T>& output_c_backprop_data,
2160 DeviceMemory<T>* input_backprop_data,
2161 DeviceMemory<T>* input_h_backprop_data,
2162 DeviceMemory<T>* input_c_backprop_data,
2163 DeviceMemory<T>* params_backprop_data,
2164 DeviceMemory<uint8>* reserve_space_data,
2165 ScratchAllocator* workspace_allocator,
2166 dnn::ProfileResult* output_profile_result) {
2167 TF_ASSIGN_OR_RETURN(
2168 RnnModelDims model_dims,
2169 ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
2170 input_h_data, input_c_desc, input_c_data,
2171 params, output_desc, output_data, output_h_desc,
2172 output_h_data, output_c_desc, output_c_data));
2173
2174 auto cudnn = cudnn_->GetHandle(parent_, stream);
2175
2176 TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
2177
2178   // In cuDNN v8.0, the cudnnRNNBackward*** and cudnnRNNBackward***Ex calls
2179   // were deprecated. We therefore use cudnnRNNBackwardData_v8 and
2180   // cudnnRNNBackwardWeights_v8, which require the sequence_lengths parameter.
2181   // See https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802.
2182 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
2183 if (input_desc.is_var_seq_lengths()) {
2184 DeviceMemory<uint8> workspace;
2185 size_t workspace_size_in_bytes = 0;
2186 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
2187 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2188 /*fMode=*/CUDNN_FWD_MODE_TRAINING, /*xDesc=*/input_desc.data_handle(),
2189 /*workSpaceSize=*/&workspace_size_in_bytes,
2190 /*reserveSpaceSize=*/NULL));
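    // Only the workspace size is queried here; the reserve space was already
    // allocated during the forward pass and is handed in via
    // reserve_space_data.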
2191 if (workspace_size_in_bytes > 0) {
2192 TF_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
2193 workspace_size_in_bytes));
2194 }
2195
2196 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2197 const bool is_profiling = output_profile_result != nullptr;
2198 if (is_profiling) {
2199 timer.reset(new GpuTimer(parent_));
2200       // The start and stop of the timer should be as close to the cuDNN
2201       // call as possible. Other threads may still issue work onto this
2202       // stream, so the caller may need to take multiple profiling measurements.
2203 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2204 return port::Status(port::error::INTERNAL, "Failed to start timer");
2205 }
2206 }
2207
2208 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8(
2209 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2210 /*devSeqLengths=*/
2211 reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2212 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
2213 /*dy=*/output_backprop_data.opaque(),
2214 /*xDesc=*/input_desc.data_handle(),
2215 /*dx=*/input_backprop_data->opaque(),
2216 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2217 /*dhy=*/output_h_backprop_data.opaque(),
2218 /*dhx=*/input_h_backprop_data->opaque(),
2219 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2220 /*dcy=*/output_c_backprop_data.opaque(),
2221 /*dcx=*/input_c_backprop_data->opaque(),
2222 /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2223 /*weightSpace=*/params.opaque(),
2224 /*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(),
2225 /*reserveSpaceSize=*/reserve_space_data->size(),
2226 /*reserveSpace=*/reserve_space_data->opaque()));
2227
2228 if (params_backprop_data != nullptr) {
2229 // Clear the dw to zeros.
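      // cudnnRNNBackwardWeights_v8 is invoked with CUDNN_WGRAD_MODE_ADD, which
      // accumulates into dweightSpace, so the buffer must start out zeroed.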
2230 stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2231 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8(
2232 /*handle=*/cudnn.handle(),
2233 /*rnnDesc=*/rnn_desc.handle(),
2234 /*addGrad=*/CUDNN_WGRAD_MODE_ADD,
2235 /*devSeqLengths=*/
2236 reinterpret_cast<const int*>(seq_lengths_data.opaque()),
2237 /*xDesc=*/input_desc.data_handle(),
2238 /*x=*/input_data.opaque(),
2239 /*hDesc=*/input_h_desc.handle(),
2240 /*hx=*/input_h_data.opaque(),
2241 /*yDesc=*/output_desc.data_handle(),
2242 /*y=*/output_data.opaque(),
2243 /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
2244 /*dweightSpace=*/params_backprop_data->opaque(),
2245 /*workSpaceSize=*/workspace.size(),
2246 /*workSpace=*/workspace.opaque(),
2247 /*reserveSpaceSize=*/reserve_space_data->size(),
2248 /*reserveSpace=*/reserve_space_data->opaque()));
2249 }
2250
2251 if (is_profiling) {
2252 if (!timer->Stop(AsGpuStream(stream))) {
2253 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2254 }
2255 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2256 output_profile_result->set_algorithm(algo_desc);
2257 output_profile_result->set_elapsed_time_in_ms(
2258 timer->GetElapsedMilliseconds());
2259 }
2260 return port::Status::OK();
2261 }
2262 #endif
2263 TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
2264 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
2265 workspace_allocator));
2266
2267 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2268 const bool is_profiling = output_profile_result != nullptr;
2269 if (is_profiling) {
2270 timer.reset(new GpuTimer(parent_));
2271     // The start and stop of the timer should be as close to the cuDNN call
2272     // as possible. Other threads may still issue work onto this stream, so
2273     // the caller may need to take multiple profiling measurements.
2274 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2275 return port::Status(port::error::INTERNAL, "Failed to start timer");
2276 }
2277 }
2278
2279 if (input_desc.is_var_seq_lengths()) {
2280 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
2281 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2282 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
2283 /*dyDesc=*/output_desc.data_handle(),
2284 /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
2285 /*dhyDesc=*/output_h_desc.handle(),
2286 /*dhy=*/output_h_backprop_data.opaque(),
2287 /*dcyDesc=*/output_c_desc.handle(),
2288 /*dcy=*/output_c_backprop_data.opaque(),
2289 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2290 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2291 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2292 /*dxDesc=*/input_desc.data_handle(),
2293 /*dx=*/input_backprop_data->opaque(),
2294 /*dhxDesc=*/input_h_desc.handle(),
2295 /*dhx=*/input_h_backprop_data->opaque(),
2296 /*dcxDesc=*/input_c_desc.handle(),
2297 /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
2298 /*workspace=*/workspace.opaque(),
2299 /*workSpaceSizeInBytes=*/workspace.size(),
2300 /*reserveSpace=*/reserve_space_data->opaque(),
2301 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2302 } else {
2303 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
2304 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2305 /*seqLength=*/model_dims.max_seq_length,
2306 /*yDesc=*/output_desc.handles(),
2307 /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
2308 /*dy=*/output_backprop_data.opaque(),
2309 /*dhyDesc=*/output_h_desc.handle(),
2310 /*dhy=*/output_h_backprop_data.opaque(),
2311 /*dcyDesc=*/output_c_desc.handle(),
2312 /*dcy=*/output_c_backprop_data.opaque(),
2313 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
2314 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2315 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
2316 /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
2317 /*dhxDesc=*/input_h_desc.handle(),
2318 /*dhx=*/input_h_backprop_data->opaque(),
2319 /*dcxDesc=*/input_c_desc.handle(),
2320 /*dcx=*/input_c_backprop_data->opaque(),
2321 /*workspace=*/workspace.opaque(),
2322 /*workSpaceSizeInBytes=*/workspace.size(),
2323 /*reserveSpace=*/reserve_space_data->opaque(),
2324 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2325 }
2326
2327 if (params_backprop_data != nullptr) {
2328 // Clear the dw to zeros.
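    // Both cudnnRNNBackwardWeights and cudnnRNNBackwardWeightsEx accumulate
    // gradients into dw rather than overwriting it, so dw is zeroed first.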
2329 stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2330 if (input_desc.is_var_seq_lengths()) {
2331 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
2332 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2333 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
2334 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
2335 /*yDesc=*/output_desc.data_handle(),
2336 /*y=*/output_data.opaque(),
2337 /*workspace=*/workspace.opaque(),
2338 /*workSpaceSizeInBytes=*/workspace.size(),
2339 /*dwDesc=*/rnn_desc.params_handle(),
2340 /*dw=*/params_backprop_data->opaque(),
2341 /*reserveSpace=*/reserve_space_data->opaque(),
2342 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2343 } else {
2344 // make the backward weight call
2345 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
2346 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2347 /*seqLength=*/model_dims.max_seq_length,
2348 /*xDesc=*/input_desc.handles(),
2349 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2350 /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
2351 /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
2352 /*workSpaceSizeInBytes=*/workspace.size(),
2353 /*dwDesc=*/rnn_desc.params_handle(),
2354 /*dw=*/params_backprop_data->opaque(),
2355 /*reserveSpace=*/reserve_space_data->opaque(),
2356 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2357 }
2358 }
2359
2360 if (is_profiling) {
2361 if (!timer->Stop(AsGpuStream(stream))) {
2362 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2363 }
2364 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2365 output_profile_result->set_algorithm(algo_desc);
2366 output_profile_result->set_elapsed_time_in_ms(
2367 timer->GetElapsedMilliseconds());
2368 }
2369
2370 return ::tensorflow::OkStatus();
2371 }
2372
2373 port::Status CudnnSupport::DoCtcLossImpl(
2374 Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
2375 const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2376 absl::Span<const int> labels_lengths_data,
2377 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2378 const CudnnRnnStateTensorDescriptor& grads_desc,
2379 DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
2380 DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2381 auto cudnn = cudnn_->GetHandle(parent_, stream);
2382
2383 int kNumTimestamps = probs_desc.num_layers();
2384 int kBatchSize = probs_desc.batch_size();
2385 int kNumLabels = probs_desc.data_size();
2386 int total_size = kNumLabels * kNumTimestamps * kBatchSize;
2387 (void)total_size;
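  // total_size is computed only for readability/debugging; the (void) cast
  // silences unused-variable warnings since the cuDNN call below does not
  // need it.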
2388
2389 #if CUDNN_VERSION >= 7603
2390 cudnnCTCLossAlgo_t ctc_loss_algo =
2391 static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
2392 RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
2393 /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
2394 /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
2395 /*labelLengths=*/labels_lengths_data.data(),
2396 /*inputLengths=*/input_lengths_data.data(),
2397 /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
2398 /*gradients=*/grads_data.opaque(),
2399 /*algo=*/ctc_loss_algo,
2400 /*ctcLossDesc=*/ctc_loss_desc.handle(),
2401 /*workspace=*/scratch_memory.opaque(),
2402 /*workSpaceSizeInBytes=*/scratch_memory.size()));
2403 #else
2404   return port::Status(port::error::INVALID_ARGUMENT,
2405                       "cudnnCTCLoss is not supported when "
2406                       "CUDNN_VERSION < 7.6.3");
2407 #endif
2408
2409 return ::tensorflow::OkStatus();
2410 }
2411
2412 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2413 CudnnSupport::createRnnDescriptor(
2414 int num_layers, int hidden_size, int input_size, int cell_size,
2415 int batch_size, dnn::RnnInputMode input_mode,
2416 dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2417 dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2418 float dropout, uint64_t seed, ScratchAllocator* state_allocator,
2419 bool use_padded_io) {
2420 // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
2421 // not enqueueing anything into a stream, we pass in the null stream.
2422 auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
2423 TF_ASSIGN_OR_RETURN(
2424 CudnnRnnDescriptor rnn_desc,
2425 CudnnRnnDescriptor::Create(
2426 cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
2427 ToCudnnRnnInputMode(input_mode),
2428 ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
2429 ToCudnnDataType(data_type), GetRnnComputeType(data_type),
2430 algorithm_config, dropout, seed, state_allocator, use_padded_io));
2431 return std::unique_ptr<dnn::RnnDescriptor>(
2432 new CudnnRnnDescriptor(std::move(rnn_desc)));
2433 }
2434
2435 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2436 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
2437 int batch_size, int data_size,
2438 dnn::DataType data_type) {
2439 TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2440 CudnnRnnSequenceTensorDescriptor::Create(
2441 parent_, max_seq_length, batch_size, data_size,
2442 ToCudnnDataType(data_type)));
2443 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2444 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2445 }
2446
2447 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2448 CudnnSupport::createRnnSequenceTensorDescriptor(
2449 int max_seq_length, int batch_size, int data_size,
2450 const absl::Span<const int>& seq_lengths, bool time_major,
2451 dnn::DataType data_type) {
2452 TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2453 CudnnRnnSequenceTensorDescriptor::Create(
2454 parent_, max_seq_length, batch_size, data_size,
2455 seq_lengths, time_major, ToCudnnDataType(data_type)));
2456 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2457 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2458 }
2459
2460 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2461 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2462 int data_size,
2463 dnn::DataType data_type) {
2464 return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
2465 new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
2466 data_size, ToCudnnDataType(data_type)));
2467 }
2468
2469 bool CudnnSupport::DoRnnForward(
2470 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2471 const dnn::RnnSequenceTensorDescriptor& input_desc,
2472 const DeviceMemory<Eigen::half>& input_data,
2473 const DeviceMemory<int>& seq_lengths_data,
2474 const dnn::RnnStateTensorDescriptor& input_h_desc,
2475 const DeviceMemory<Eigen::half>& input_h_data,
2476 const dnn::RnnStateTensorDescriptor& input_c_desc,
2477 const DeviceMemory<Eigen::half>& input_c_data,
2478 const DeviceMemory<Eigen::half>& params,
2479 const dnn::RnnSequenceTensorDescriptor& output_desc,
2480 DeviceMemory<Eigen::half>* output_data,
2481 const dnn::RnnStateTensorDescriptor& output_h_desc,
2482 DeviceMemory<Eigen::half>* output_h_data,
2483 const dnn::RnnStateTensorDescriptor& output_c_desc,
2484 DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2485 ScratchAllocator* reserve_space_allocator,
2486 ScratchAllocator* workspace_allocator,
2487 dnn::ProfileResult* output_profile_result) {
2488 const CudnnRnnDescriptor& cudnn_rnn_desc =
2489 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2490 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2491 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2492 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2493 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2494 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2495 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2496 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2497 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2498 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2499 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2500 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2501 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2502 return IsStatusOk(
2503 DoRnnForwardImpl<Eigen::half>(
2504 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2505 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2506 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2507 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2508 output_c_data, is_training, reserve_space_allocator,
2509 workspace_allocator, output_profile_result),
2510 /*report_error=*/!output_profile_result);
2511 }
2512
2513 bool CudnnSupport::DoRnnForward(
2514 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2515 const dnn::RnnSequenceTensorDescriptor& input_desc,
2516 const DeviceMemory<float>& input_data,
2517 const DeviceMemory<int>& seq_lengths_data,
2518 const dnn::RnnStateTensorDescriptor& input_h_desc,
2519 const DeviceMemory<float>& input_h_data,
2520 const dnn::RnnStateTensorDescriptor& input_c_desc,
2521 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2522 const dnn::RnnSequenceTensorDescriptor& output_desc,
2523 DeviceMemory<float>* output_data,
2524 const dnn::RnnStateTensorDescriptor& output_h_desc,
2525 DeviceMemory<float>* output_h_data,
2526 const dnn::RnnStateTensorDescriptor& output_c_desc,
2527 DeviceMemory<float>* output_c_data, bool is_training,
2528 ScratchAllocator* reserve_space_allocator,
2529 ScratchAllocator* workspace_allocator,
2530 dnn::ProfileResult* output_profile_result) {
2531 const CudnnRnnDescriptor& cudnn_rnn_desc =
2532 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2533 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2534 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2535 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2536 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2537 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2538 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2539 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2540 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2541 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2542 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2543 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2544 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2545 return IsStatusOk(
2546 DoRnnForwardImpl<float>(
2547 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2548 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2549 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2550 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2551 output_c_data, is_training, reserve_space_allocator,
2552 workspace_allocator, output_profile_result),
2553 /*report_error=*/!output_profile_result);
2554 }
2555
2556 bool CudnnSupport::DoRnnForward(
2557 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2558 const dnn::RnnSequenceTensorDescriptor& input_desc,
2559 const DeviceMemory<double>& input_data,
2560 const DeviceMemory<int>& seq_lengths_data,
2561 const dnn::RnnStateTensorDescriptor& input_h_desc,
2562 const DeviceMemory<double>& input_h_data,
2563 const dnn::RnnStateTensorDescriptor& input_c_desc,
2564 const DeviceMemory<double>& input_c_data,
2565 const DeviceMemory<double>& params,
2566 const dnn::RnnSequenceTensorDescriptor& output_desc,
2567 DeviceMemory<double>* output_data,
2568 const dnn::RnnStateTensorDescriptor& output_h_desc,
2569 DeviceMemory<double>* output_h_data,
2570 const dnn::RnnStateTensorDescriptor& output_c_desc,
2571 DeviceMemory<double>* output_c_data, bool is_training,
2572 ScratchAllocator* reserve_space_allocator,
2573 ScratchAllocator* workspace_allocator,
2574 dnn::ProfileResult* output_profile_result) {
2575 const CudnnRnnDescriptor& cudnn_rnn_desc =
2576 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2577 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2578 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2579 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2580 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2581 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2582 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2583 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2584 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2585 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2586 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2587 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2588 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2589 return IsStatusOk(
2590 DoRnnForwardImpl<double>(
2591 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2592 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2593 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2594 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2595 output_c_data, is_training, reserve_space_allocator,
2596 workspace_allocator, output_profile_result),
2597 /*report_error=*/!output_profile_result);
2598 }
2599
2600 bool CudnnSupport::DoRnnBackward(
2601 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2602 const dnn::RnnSequenceTensorDescriptor& input_desc,
2603 const DeviceMemory<Eigen::half>& input_data,
2604 const DeviceMemory<int>& seq_lengths_data,
2605 const dnn::RnnStateTensorDescriptor& input_h_desc,
2606 const DeviceMemory<Eigen::half>& input_h_data,
2607 const dnn::RnnStateTensorDescriptor& input_c_desc,
2608 const DeviceMemory<Eigen::half>& input_c_data,
2609 const DeviceMemory<Eigen::half>& params,
2610 const dnn::RnnSequenceTensorDescriptor& output_desc,
2611 const DeviceMemory<Eigen::half>& output_data,
2612 const dnn::RnnStateTensorDescriptor& output_h_desc,
2613 const DeviceMemory<Eigen::half>& output_h_data,
2614 const dnn::RnnStateTensorDescriptor& output_c_desc,
2615 const DeviceMemory<Eigen::half>& output_c_data,
2616 const DeviceMemory<Eigen::half>& output_backprop_data,
2617 const DeviceMemory<Eigen::half>& output_h_backprop_data,
2618 const DeviceMemory<Eigen::half>& output_c_backprop_data,
2619 DeviceMemory<Eigen::half>* input_backprop_data,
2620 DeviceMemory<Eigen::half>* input_h_backprop_data,
2621 DeviceMemory<Eigen::half>* input_c_backprop_data,
2622 DeviceMemory<Eigen::half>* params_backprop_data,
2623 DeviceMemory<uint8>* reserve_space_data,
2624 ScratchAllocator* workspace_allocator,
2625 dnn::ProfileResult* output_profile_result) {
2626 const CudnnRnnDescriptor& cudnn_rnn_desc =
2627 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2628 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2629 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2630 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2631 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2632 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2633 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2634 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2635 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2636 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2637 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2638 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2639 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2640 return IsStatusOk(
2641 DoRnnBackwardImpl<Eigen::half>(
2642 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2643 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2644 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2645 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2646 output_c_data, output_backprop_data, output_h_backprop_data,
2647 output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2648 input_c_backprop_data, params_backprop_data, reserve_space_data,
2649 workspace_allocator, output_profile_result),
2650 /*report_error=*/!output_profile_result);
2651 }
2652
2653 bool CudnnSupport::DoRnnBackward(
2654 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2655 const dnn::RnnSequenceTensorDescriptor& input_desc,
2656 const DeviceMemory<float>& input_data,
2657 const DeviceMemory<int>& seq_lengths_data,
2658 const dnn::RnnStateTensorDescriptor& input_h_desc,
2659 const DeviceMemory<float>& input_h_data,
2660 const dnn::RnnStateTensorDescriptor& input_c_desc,
2661 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2662 const dnn::RnnSequenceTensorDescriptor& output_desc,
2663 const DeviceMemory<float>& output_data,
2664 const dnn::RnnStateTensorDescriptor& output_h_desc,
2665 const DeviceMemory<float>& output_h_data,
2666 const dnn::RnnStateTensorDescriptor& output_c_desc,
2667 const DeviceMemory<float>& output_c_data,
2668 const DeviceMemory<float>& output_backprop_data,
2669 const DeviceMemory<float>& output_h_backprop_data,
2670 const DeviceMemory<float>& output_c_backprop_data,
2671 DeviceMemory<float>* input_backprop_data,
2672 DeviceMemory<float>* input_h_backprop_data,
2673 DeviceMemory<float>* input_c_backprop_data,
2674 DeviceMemory<float>* params_backprop_data,
2675 DeviceMemory<uint8>* reserve_space_data,
2676 ScratchAllocator* workspace_allocator,
2677 dnn::ProfileResult* output_profile_result) {
2678 const CudnnRnnDescriptor& cudnn_rnn_desc =
2679 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2680 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2681 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2682 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2683 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2684 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2685 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2686 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2687 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2688 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2689 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2690 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2691 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2692 return IsStatusOk(
2693 DoRnnBackwardImpl<float>(
2694 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2695 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2696 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2697 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2698 output_c_data, output_backprop_data, output_h_backprop_data,
2699 output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2700 input_c_backprop_data, params_backprop_data, reserve_space_data,
2701 workspace_allocator, output_profile_result),
2702 /*report_error=*/!output_profile_result);
2703 }
2704
2705 bool CudnnSupport::DoRnnBackward(
2706 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2707 const dnn::RnnSequenceTensorDescriptor& input_desc,
2708 const DeviceMemory<double>& input_data,
2709 const DeviceMemory<int>& seq_lengths_data,
2710 const dnn::RnnStateTensorDescriptor& input_h_desc,
2711 const DeviceMemory<double>& input_h_data,
2712 const dnn::RnnStateTensorDescriptor& input_c_desc,
2713 const DeviceMemory<double>& input_c_data,
2714 const DeviceMemory<double>& params,
2715 const dnn::RnnSequenceTensorDescriptor& output_desc,
2716 const DeviceMemory<double>& output_data,
2717 const dnn::RnnStateTensorDescriptor& output_h_desc,
2718 const DeviceMemory<double>& output_h_data,
2719 const dnn::RnnStateTensorDescriptor& output_c_desc,
2720 const DeviceMemory<double>& output_c_data,
2721 const DeviceMemory<double>& output_backprop_data,
2722 const DeviceMemory<double>& output_h_backprop_data,
2723 const DeviceMemory<double>& output_c_backprop_data,
2724 DeviceMemory<double>* input_backprop_data,
2725 DeviceMemory<double>* input_h_backprop_data,
2726 DeviceMemory<double>* input_c_backprop_data,
2727 DeviceMemory<double>* params_backprop_data,
2728 DeviceMemory<uint8>* reserve_space_data,
2729 ScratchAllocator* workspace_allocator,
2730 dnn::ProfileResult* output_profile_result) {
2731 const CudnnRnnDescriptor& cudnn_rnn_desc =
2732 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2733 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2734 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2735 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2736 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2737 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2738 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2739 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2740 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2741 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2742 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2743 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2744 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2745 return IsStatusOk(
2746 DoRnnBackwardImpl<double>(
2747 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2748 seq_lengths_data, cudnn_input_h_desc, input_h_data,
2749 cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
2750 output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
2751 output_c_data, output_backprop_data, output_h_backprop_data,
2752 output_c_backprop_data, input_backprop_data, input_h_backprop_data,
2753 input_c_backprop_data, params_backprop_data, reserve_space_data,
2754 workspace_allocator, output_profile_result),
2755 /*report_error=*/!output_profile_result);
2756 }
2757
2758 namespace {
2759
2760 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2761 // and backward filter.
2762
2763 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2764 const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2765 const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2766 const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2767 size_t memory_limit_bytes) {
2768 #if CUDNN_VERSION >= 8000
2769 const int num_requested_algos = 5;
2770 int num_returned_algos = 0;
2771 cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos];
2772
2773 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
2774 cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2775 output_nd.handle(), num_requested_algos, &num_returned_algos,
2776 perf_results));
2777
2778 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2779 for (int r = 0; r < num_returned_algos; r++) {
2780 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2781 perf_results[r].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2782 perf_results[r].memory <= mem_limit) {
2783 return perf_results[r].algo;
2784 }
2785 }
2786 return port::Status(port::error::INTERNAL,
2787 "cudnnGetConvolutionForwardAlgorithm_v7 returned "
2788 "no suitable algorithms. This could be a cudnn bug.");
2789 #else
2790 cudnnConvolutionFwdPreference_t preference =
2791 specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2792 : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2793 cudnnConvolutionFwdAlgo_t algo_to_use;
2794 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2795 cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2796 output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2797 return algo_to_use;
2798 #endif
2799 }
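// Usage sketch, assuming the caller has already built the cuDNN descriptors;
// this mirrors the call site in GetCudnnConvolutionForwardAlgorithm below:
//   TF_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
//                       GetCudnnConvolutionForwardAlgo(
//                           cudnn, input_nd, filter, conv, output_nd,
//                           specify_workspace_limit, memory_limit_bytes));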
2800
2801 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2802 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2803 const CudnnTensorDescriptor& input_nd,
2804 const CudnnFilterDescriptor& filter,
2805 const CudnnConvolutionDescriptor& conv,
2806 const CudnnTensorDescriptor& output_nd,
2807 bool specify_workspace_limit,
2808 size_t memory_limit_bytes) {
2809 #if CUDNN_VERSION >= 8000
2810 const int num_requested_algos = 5;
2811 int num_returned_algos = 0;
2812 cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos];
2813
2814 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(
2815 cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2816 input_nd.handle(), num_requested_algos, &num_returned_algos,
2817 perf_results));
2818
2819 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2820 for (int r = 0; r < num_returned_algos; r++) {
2821 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2822 perf_results[r].algo !=
2823 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED &&
2824 perf_results[r].memory <= mem_limit) {
2825 return perf_results[r].algo;
2826 }
2827 }
2828 return port::Status(port::error::INTERNAL,
2829 "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned "
2830 "no suitable algorithms. This could be a cudnn bug.");
2831 #else
2832 cudnnConvolutionBwdDataPreference_t preference =
2833 specify_workspace_limit
2834 ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2835 : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2836 cudnnConvolutionBwdDataAlgo_t algo_to_use;
2837 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2838 cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2839 input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2840 return algo_to_use;
2841 #endif
2842 }
2843
2844 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2845 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2846 const CudnnTensorDescriptor& input_nd,
2847 const CudnnFilterDescriptor& filter,
2848 const CudnnConvolutionDescriptor& conv,
2849 const CudnnTensorDescriptor& output_nd,
2850 bool specify_workspace_limit,
2851 size_t memory_limit_bytes) {
2852 #if CUDNN_VERSION >= 8000
2853 const int num_requested_algos = 5;
2854 int num_returned_algos = 0;
2855 cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos];
2856
2857 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
2858 cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2859 filter.handle(), num_requested_algos, &num_returned_algos, perf_results));
2860
2861 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2862 for (int r = 0; r < num_returned_algos; r++) {
2863 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2864 perf_results[r].algo !=
2865 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED &&
2866 perf_results[r].memory <= mem_limit) {
2867 return perf_results[r].algo;
2868 }
2869 }
2870 return port::Status(port::error::INTERNAL,
2871 "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned "
2872 "no suitable algorithms. This could be a cudnn bug.");
2873 #else
2874 cudnnConvolutionBwdFilterPreference_t preference =
2875 specify_workspace_limit
2876 ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2877 : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2878 cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2879 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2880 cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2881 filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2882 return algo_to_use;
2883 #endif
2884 }
2885
2886 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2887 Stream* stream, const CudnnHandle& cudnn,
2888 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2889 const CudnnConvolutionDescriptor& conv,
2890 const CudnnTensorDescriptor& output_nd,
2891 const dnn::AlgorithmDesc& algorithm_desc,
2892 ScratchAllocator* scratch_allocator) {
2893 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2894 return port::Status(
2895 port::error::INTERNAL,
2896 "Mismatch between cudnn conv and algorithm descriptors.");
2897 }
2898
2899 // Query the size of the workspace and allocate it.
2900 size_t size_in_bytes;
2901 if (algorithm_desc.workspace_size()) {
2902 size_in_bytes = *algorithm_desc.workspace_size();
2903 } else {
2904 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2905 cudnn.handle(),
2906 /*xDesc=*/input_nd.handle(),
2907 /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2908 /*yDesc=*/output_nd.handle(),
2909 /*algo=*/ToConvForwardAlgo(algorithm_desc),
2910 /*sizeInBytes=*/&size_in_bytes));
2911 }
2912
2913 int64_t size_in_bytes_int64_t = size_in_bytes;
2914
2915 if (TF_PREDICT_FALSE(size_in_bytes_int64_t < 0)) {
2916 return port::Status(
2917 port::error::INTERNAL,
2918 "cudnnGetConvolutionForwardWorkspaceSize() returned "
2919 "negative sizeInBytes value. This could be a cudnn bug.");
2920 }
2921
2922 if (size_in_bytes_int64_t == 0) {
2923 return DeviceMemory<uint8>();
2924 }
2925
2926 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2927 return port::Status(port::error::INVALID_ARGUMENT,
2928 "No scratch allocator provided");
2929 }
2930
2931 return scratch_allocator->AllocateBytes(size_in_bytes);
2932 }
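// The two workspace helpers that follow repeat the same query-then-allocate
// pattern; they differ only in which cudnnGet*WorkspaceSize query is issued
// and in how the descriptors are mapped onto cuDNN's xDesc/wDesc/yDesc (or
// the corresponding dxDesc/dyDesc/gradDesc) arguments.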
2933
2934 port::StatusOr<DeviceMemory<uint8>>
2935 AllocateCudnnConvolutionBackwardDataWorkspace(
2936 Stream* stream, const CudnnHandle& cudnn,
2937 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2938 const CudnnConvolutionDescriptor& conv,
2939 const CudnnTensorDescriptor& output_nd,
2940 const dnn::AlgorithmDesc& algorithm_desc,
2941 ScratchAllocator* scratch_allocator) {
2942 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2943 return port::Status(
2944 port::error::INTERNAL,
2945 "Mismatch between cudnn conv and algorithm descriptors.");
2946 }
2947
2948 // Query the size of the workspace and allocate it.
2949 size_t size_in_bytes;
2950 if (algorithm_desc.workspace_size()) {
2951 size_in_bytes = *algorithm_desc.workspace_size();
2952 } else {
2953 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2954 cudnn.handle(),
2955 /*wDesc=*/filter.handle(),
2956 /*dyDesc=*/output_nd.handle(),
2957 /*convDesc=*/conv.handle(),
2958 /*dxDesc=*/input_nd.handle(),
2959 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2960 /*sizeInBytes=*/&size_in_bytes));
2961 }
2962
2963 int64_t size_in_bytes_int64_t = size_in_bytes;
2964
2965 if (TF_PREDICT_FALSE(size_in_bytes_int64_t < 0)) {
2966 return port::Status(
2967 port::error::INTERNAL,
2968 "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2969 "negative sizeInBytes value. This could be a cudnn bug.");
2970 }
2971
2972 if (size_in_bytes_int64_t == 0) {
2973 return DeviceMemory<uint8>();
2974 }
2975
2976 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2977 return port::Status(port::error::INVALID_ARGUMENT,
2978 "No scratch allocator provided");
2979 }
2980
2981 return scratch_allocator->AllocateBytes(size_in_bytes);
2982 }
2983
2984 port::StatusOr<DeviceMemory<uint8>>
2985 AllocateCudnnConvolutionBackwardFilterWorkspace(
2986 Stream* stream, const CudnnHandle& cudnn,
2987 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2988 const CudnnConvolutionDescriptor& conv,
2989 const CudnnTensorDescriptor& output_nd,
2990 const dnn::AlgorithmDesc& algorithm_desc,
2991 ScratchAllocator* scratch_allocator) {
2992 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2993 return port::Status(
2994 port::error::INTERNAL,
2995 "Mismatch between cudnn conv and algorithm descriptors.");
2996 }
2997
2998 // Query the size of the workspace and allocate it.
2999 size_t size_in_bytes;
3000 if (algorithm_desc.workspace_size()) {
3001 size_in_bytes = *algorithm_desc.workspace_size();
3002 } else {
3003 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
3004 cudnn.handle(),
3005 /*xDesc=*/input_nd.handle(),
3006 /*dyDesc=*/output_nd.handle(),
3007 /*convDesc=*/conv.handle(),
3008 /*gradDesc=*/filter.handle(),
3009 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
3010 /*sizeInBytes=*/&size_in_bytes));
3011 }
3012
3013 int64_t size_in_bytes_int64_t = size_in_bytes;
3014
3015 if (TF_PREDICT_FALSE(size_in_bytes_int64_t < 0)) {
3016 return port::Status(
3017 port::error::INTERNAL,
3018 "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
3019 "negative sizeInBytes value. This could be a cudnn bug.");
3020 }
3021
3022 if (size_in_bytes_int64_t == 0) {
3023 return DeviceMemory<uint8>();
3024 }
3025
3026 if (TF_PREDICT_FALSE(!scratch_allocator)) {
3027 return port::Status(port::error::INVALID_ARGUMENT,
3028 "No scratch allocator provided");
3029 }
3030
3031 return scratch_allocator->AllocateBytes(size_in_bytes);
3032 }
3033
3034 port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
3035 std::optional<dnn::AlgorithmDesc> desc) {
3036 bool use_tensor_ops;
3037 if (desc.has_value()) {
3038 use_tensor_ops = desc->tensor_ops_enabled();
3039 if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) {
3040 return port::Status(port::error::INVALID_ARGUMENT,
3041 "Algo requests disabled tensor op evaluation.");
3042 }
3043 } else {
3044 use_tensor_ops = IsTensorMathEnabled(stream, type);
3045 }
3046 return use_tensor_ops;
3047 }
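// In practice (see the algorithm-selection helpers that follow): an
// explicitly requested algorithm that asks for tensor ops while tensor math
// is disabled for this stream/type yields INVALID_ARGUMENT; otherwise the
// global tensor-math policy decides. For example:
//   TF_ASSIGN_OR_RETURN(bool use_tensor_ops,
//                       UseTensorOps(stream, element_type, algo_desc));
//   conv.set_use_tensor_op_math(use_tensor_ops);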
3048
3049 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
3050 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type);
3051
3052 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
3053 Stream* stream, const CudnnHandle& cudnn,
3054 const dnn::AlgorithmConfig& algorithm_config,
3055 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3056 dnn::DataType element_type,
3057 const dnn::ConvolutionDescriptor& convolution_descriptor,
3058 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3059 DeviceMemory<uint8>* scratch) {
3060 std::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3061
3062 CudnnConvolutionDescriptor conv(
3063 convolution_descriptor,
3064 ToCudnnDataType(GetConvAccumulatorType(element_type)));
3065 bool use_tensor_ops;
3066 TF_ASSIGN_OR_RETURN(use_tensor_ops,
3067 UseTensorOps(stream, element_type, algo_desc));
3068 conv.set_use_tensor_op_math(use_tensor_ops);
3069
3070 if (!algo_desc.has_value()) {
3071 // Pick fastest algorithm within memory limit according to cuDNN's
3072 // heuristics.
3073 bool specify_workspace_limit = scratch_allocator != nullptr;
3074 auto memory_limit_bytes =
3075 specify_workspace_limit
3076 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0})
3077 : int64_t{0};
3078 TF_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
3079 GetCudnnConvolutionForwardAlgo(
3080 cudnn, input_nd, filter, conv, output_nd,
3081 specify_workspace_limit, memory_limit_bytes));
3082 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3083 }
3084
3085 const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
3086 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3087 scratch_allocator);
3088
3089 if (scratch_or.ok()) {
3090 *scratch = scratch_or.ValueOrDie();
3091 return *algo_desc;
3092 }
3093
3094 algo_desc = algorithm_config.algorithm_no_scratch();
3095
3096   // Failed to allocate workspace for the first algorithm; fall back to the
3097   // no_scratch algorithm.
3098 if (!algo_desc.has_value()) {
3099 return port::Status(
3100 scratch_or.status().code(),
3101 absl::StrCat("The primary convolution algorithm failed, ",
3102 "while a secondary algorithm is not provided. ",
3103 "Returned status: ", scratch_or.status().ToString()));
3104 }
3105
3106 TF_ASSIGN_OR_RETURN(use_tensor_ops,
3107 UseTensorOps(stream, element_type, algo_desc));
3108 conv.set_use_tensor_op_math(use_tensor_ops);
3109 TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
3110 stream, cudnn, input_nd, filter, conv,
3111 output_nd, *algo_desc, scratch_allocator));
3112 return *algo_desc;
3113 }
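// A minimal calling sketch (hypothetical caller; argument names simply follow
// the signature above):
//   DeviceMemory<uint8> scratch;
//   TF_ASSIGN_OR_RETURN(
//       dnn::AlgorithmDesc algo,
//       GetCudnnConvolutionForwardAlgorithm(
//           stream, cudnn, algorithm_config, input_nd, filter, element_type,
//           convolution_descriptor, output_nd, scratch_allocator, &scratch));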
3114
3115 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
3116 Stream* stream, const CudnnHandle& cudnn,
3117 const dnn::AlgorithmConfig& algorithm_config,
3118 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3119 dnn::DataType element_type,
3120 const dnn::ConvolutionDescriptor& convolution_descriptor,
3121 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3122 DeviceMemory<uint8>* scratch) {
3123 std::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3124 CudnnConvolutionDescriptor conv(
3125 convolution_descriptor,
3126 ToCudnnDataType(GetConvAccumulatorType(element_type)));
3127 bool use_tensor_ops;
3128 TF_ASSIGN_OR_RETURN(use_tensor_ops,
3129 UseTensorOps(stream, element_type, algo_desc));
3130 conv.set_use_tensor_op_math(use_tensor_ops);
3131
3132 if (!algo_desc.has_value()) {
3133 // Pick fastest algorithm within memory limit according to cuDNN's
3134 // heuristics.
3135 bool specify_workspace_limit = scratch_allocator != nullptr;
3136 auto memory_limit_bytes =
3137 specify_workspace_limit
3138 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0})
3139 : int64_t{0};
3140 TF_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
3141 GetCudnnConvolutionBackwardDataAlgo(
3142 cudnn, input_nd, filter, conv, output_nd,
3143 specify_workspace_limit, memory_limit_bytes));
3144 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3145 }
3146
3147 const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
3148 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3149 scratch_allocator);
3150
3151 if (scratch_or.ok()) {
3152 *scratch = scratch_or.ValueOrDie();
3153 return *algo_desc;
3154 }
3155
3156 algo_desc = algorithm_config.algorithm_no_scratch();
3157
3158   // Failed to allocate workspace for the first algorithm; fall back to the
3159   // no_scratch algorithm.
3160 if (!algo_desc.has_value()) {
3161 return port::Status(
3162 port::error::INVALID_ARGUMENT,
3163 "The primary convolution algorithm failed memory allocation, "
3164 "while a secondary algorithm is not provided.");
3165 }
3166
3167 TF_ASSIGN_OR_RETURN(use_tensor_ops,
3168 UseTensorOps(stream, element_type, algo_desc));
3169 conv.set_use_tensor_op_math(use_tensor_ops);
3170 TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
3171 stream, cudnn, input_nd, filter, conv,
3172 output_nd, *algo_desc, scratch_allocator));
3173 return *algo_desc;
3174 }
3175
3176 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
3177 Stream* stream, const CudnnHandle& cudnn,
3178 const dnn::AlgorithmConfig& algorithm_config,
3179 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
3180 dnn::DataType element_type,
3181 const dnn::ConvolutionDescriptor& convolution_descriptor,
3182 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
3183 DeviceMemory<uint8>* scratch) {
3184 std::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
3185 CudnnConvolutionDescriptor conv(
3186 convolution_descriptor,
3187 ToCudnnDataType(GetConvAccumulatorType(element_type)));
3188 bool use_tensor_ops;
3189 TF_ASSIGN_OR_RETURN(use_tensor_ops,
3190 UseTensorOps(stream, element_type, algo_desc));
3191 conv.set_use_tensor_op_math(use_tensor_ops);
3192
3193 if (!algo_desc.has_value()) {
3194 // Pick fastest algorithm within memory limit according to cuDNN's
3195 // heuristics.
3196 bool specify_workspace_limit = scratch_allocator != nullptr;
3197 auto memory_limit_bytes =
3198 specify_workspace_limit
3199 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0})
3200 : int64_t{0};
3201 TF_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
3202 GetCudnnConvolutionBackwardFilterAlgo(
3203 cudnn, input_nd, filter, conv, output_nd,
3204 specify_workspace_limit, memory_limit_bytes));
3205 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
3206 }
3207
3208 port::StatusOr<DeviceMemory<uint8>> scratch_or =
3209 AllocateCudnnConvolutionBackwardFilterWorkspace(
3210 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
3211 scratch_allocator);
3212
3213 if (scratch_or.ok()) {
3214 *scratch = scratch_or.ValueOrDie();
3215 return *algo_desc;
3216 }
3217
3218 algo_desc = algorithm_config.algorithm_no_scratch();
3219
3220   // Failed to allocate workspace for the first algorithm; fall back to the
3221   // no_scratch algorithm.
3222 if (!algo_desc.has_value()) {
3223 return port::Status(
3224 port::error::INVALID_ARGUMENT,
3225 absl::StrCat(
3226 "The primary convolution algorithm failed memory allocation, "
3227 "while a secondary algorithm is not provided. Actual error: ",
3228 scratch_or.status().ToString()));
3229 }
3230
3231 TF_ASSIGN_OR_RETURN(use_tensor_ops,
3232 UseTensorOps(stream, element_type, algo_desc));
3233 conv.set_use_tensor_op_math(use_tensor_ops);
3234 TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
3235 stream, cudnn, input_nd, filter, conv,
3236 output_nd, *algo_desc, scratch_allocator));
3237 return *algo_desc;
3238 }
3239
3240 // A helper class to set env-vars and choose options for cudnn-related
3241 // algorithms.
3242 template <typename EnvVar>
3243 class CudnnEnvVar {
3244 public:
3245   static bool IsEnabled() {
3246 static bool is_enabled = IsEnabledImpl();
3247 return is_enabled;
3248 }
3249
3250 private:
3251   static bool IsEnabledImpl() {
3252 const char* tf_env_var_val = getenv(EnvVar::kName);
3253 if (tf_env_var_val != nullptr) {
3254 absl::string_view tf_env_var_val_str(tf_env_var_val);
3255 if (tf_env_var_val_str == "0") {
3256 return false;
3257 }
3258 return true;
3259 }
3260 return EnvVar::kDefaultFlag;
3261 }
3262 };
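// Each policy struct below pairs a kName env-var with a kDefaultFlag and is
// queried as, e.g., CudnnEnvVar<WinogradNonfused>::IsEnabled(); the result is
// computed once and cached for the lifetime of the process (note the static
// local in IsEnabled above).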
3263
3264 // A helper struct to decide whether to enable the FFT_TILING algorithms for
3265 // forward convolution. It is disabled for cuDNN < 7 due to memory corruption
3266 // caused by some shapes with this algorithm. Users can explicitly enable the
3267 // algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
3268 struct FftTilingForward {
3269 static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
3270 static constexpr bool kDefaultFlag = true;
3271 };
3272
3273 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
3274 // By default they are turned on; users can explicitly disable them through
3275 // the env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
3276 // https://github.com/tensorflow/tensorflow/pull/4901
3277 // For CUDNN v8.1, when this env-var is turned off, both the winograd and
3278 // winograd-non-fused engines will be ruled out.
3279 struct WinogradNonfused {
3280 static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
3281   // NVIDIA has fixed the winograd nonfused bug for cuDNN v>=7. For older
3282 // we have a workaround.
3283 static constexpr bool kDefaultFlag = true;
3284 };
3285
3286 // A helper struct to decide whether to use FP32 as the internal compute type
3287 // for convolution when the input data type is FP16. By default it is turned
3288 // on; users can explicitly disable it (i.e., choose FP16 as the internal
3289 // compute type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
3290 struct ConvDoFP32ComputationFP16Input {
3291 static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
3292 // Using FP16 as the internal compute type for convolution when the input data
3293 // type is FP16 is only supported on architectures with true fp16 support
3294   // (compute capability 5.3 and 6.0). Setting this to false on an unsupported
3295   // architecture will cause internal errors.
3296 static constexpr bool kDefaultFlag = true;
3297 };
3298
3299 // A helper struct to decide whether to use FP32 as the internal compute type
3300 // for RNNs when the input data type is FP16. It defaults to on only for
3301 // cuDNN >= 7.5 (see kDefaultFlag below); users can explicitly control it
3302 // through the env-var "TF_FP16_RNN_USE_FP32_COMPUTE".
3303 // After the TODO below is fixed, users should almost always use fp32 compute
3304 // type for training. Using fp16 might yield suboptimal accuracy due to loss
3305 // of precision.
3306 struct RnnDoFP32ComputationFP16Input {
3307 static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
3308 // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
3309   // Before cuDNN 7.1.4, RNNs are always computed in fp32, no matter what math
3310   // precision is set.
3311   // It is temporarily set to false so that no error is raised when using fp16
3312   // inputs with fp32 math precision.
3313 //
3314 // cuDNN == 7.5.0 is verified to have this fixed.
3315 static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500;
3316 };
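// For example, launching the process with TF_FP16_RNN_USE_FP32_COMPUTE=0 in
// the environment forces fp16 compute for fp16 RNN inputs regardless of
// kDefaultFlag; see GetRnnComputeType below, which consults this policy.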
3317
3318 namespace {
3319
3320 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3321
3322 bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config,
3323 bool disable_winograd, bool disable_nondeterminism,
3324 bool disable_tensor_core) {
3325 bool ret = cudnn_frontend::hasNumericalNote<
3326 CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(engine_config);
3327
3328 if (disable_winograd) {
3329 ret |= cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_WINOGRAD>(
3330 engine_config);
3331 }
3332
3333 if (disable_nondeterminism) {
3334 ret |=
3335 cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(
3336 engine_config);
3337 }
3338
3339 if (disable_tensor_core) {
3340 ret |= cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(
3341 engine_config);
3342 }
3343
3344 return ret;
3345 }
3346
3347 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3348
3349 } // namespace
3350
3351 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
3352 switch (data_type) {
3353 case dnn::DataType::kFloat:
3354 return CUDNN_DATA_FLOAT;
3355 case dnn::DataType::kDouble:
3356 return CUDNN_DATA_DOUBLE;
3357 case dnn::DataType::kHalf:
3358 if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
3359 return CUDNN_DATA_FLOAT;
3360 } else {
3361 return CUDNN_DATA_HALF;
3362 }
3363 default:
3364 LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
3365 }
3366 }
3367
3368 dnn::DataType GetConvActivationType(dnn::DataType data_type) {
3369 switch (data_type) {
3370 case dnn::DataType::kFloat:
3371 case dnn::DataType::kDouble:
3372 return data_type;
3373 // TODO(awpr): it's not clear whether float-precision activations on
3374 // half-precision convs are supported; double-check.
3375 case dnn::DataType::kHalf:
3376 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3377 ? dnn::DataType::kFloat
3378 : dnn::DataType::kHalf;
3379 case dnn::DataType::kInt8:
3380 case dnn::DataType::kInt32: // TODO(awpr): does int32 do blending in float?
3381 return dnn::DataType::kFloat;
3382 #if CUDNN_VERSION >= 8200
3383 // TODO(awpr): as with kHalf, this is not clear.
3384 case dnn::DataType::kBF16:
3385 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3386 ? dnn::DataType::kFloat
3387 : dnn::DataType::kBF16;
3388 #endif
3389 default:
3390 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
3391 }
3392 }
3393
3394 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
3395 switch (data_type) {
3396 case dnn::DataType::kFloat:
3397 case dnn::DataType::kDouble:
3398 return data_type;
3399 case dnn::DataType::kHalf:
3400 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3401 ? dnn::DataType::kFloat
3402 : dnn::DataType::kHalf;
3403 case dnn::DataType::kInt8:
3404 case dnn::DataType::kInt32:
3405 return dnn::DataType::kInt32;
3406 #if CUDNN_VERSION >= 8200
3407 case dnn::DataType::kBF16:
3408 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
3409 ? dnn::DataType::kFloat
3410 : dnn::DataType::kBF16;
3411 #endif
3412 default:
3413 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
3414 }
3415 }
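// Summary of the mapping above under the default env-var settings:
// float->float, double->double, half->float, int8/int32->int32, and
// (with cuDNN >= 8.2) bf16->float.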
3416
3417 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3418
3419 namespace {
3420 cudnnBackendHeurMode_t GetCudnnFrontendHeurMode() {
3421 #if CUDNN_VERSION >= 8300
3422 return CUDNN_HEUR_MODE_B;
3423 #else
3424 return CUDNN_HEUR_MODE_INSTANT;
3425 #endif // CUDNN_VERSION >= 8300
3426 }
3427
3428 cudnnBackendDescriptorType_t GetCudnnConvolutionType(
3429 dnn::ConvolutionKind kind) {
3430 cudnnBackendDescriptorType_t conv_mode;
3431 switch (kind) {
3432 case dnn::ConvolutionKind::FORWARD: {
3433 conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
3434 break;
3435 }
3436 case dnn::ConvolutionKind::BACKWARD_DATA: {
3437 conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR;
3438 break;
3439 }
3440 case dnn::ConvolutionKind::BACKWARD_FILTER: {
3441 conv_mode =
3442 CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR;
3443 break;
3444 }
3445 default:
3446 LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3447 break;
3448 }
3449 return conv_mode;
3450 }
3451
3452 // Cudnn only supports vectorization over the channel dimension (e.g., int8x4,
3453 // or int8x32).
3454 std::tuple<int, int> GetTensorVectorSizeAndDim(
3455 const dnn::BatchDescriptor& tensor, dnn::DataType element_type) {
3456 int vector_size = 1;
3457 int vector_dim = -1;
3458 if (element_type == dnn::DataType::kInt8) {
3459 if (tensor.layout() == dnn::DataLayout::kBatchDepthYX4) {
3460 vector_size = 4;
3461 vector_dim = 1;
3462 } else if (tensor.layout() == dnn::DataLayout::kBatchDepthYX32) {
3463 vector_size = 32;
3464 vector_dim = 1;
3465 }
3466 }
3467 return std::make_tuple(vector_size, vector_dim);
3468 }
3469
3470 std::tuple<int, int> GetTensorVectorSizeAndDim(
3471 const dnn::FilterDescriptor& filter, dnn::DataType element_type) {
3472 int vector_size = 1;
3473 int vector_dim = -1;
3474 if (element_type == dnn::DataType::kInt8) {
3475 if (filter.layout() == dnn::FilterLayout::kOutputInputYX4) {
3476 vector_size = 4;
3477 vector_dim = 1;
3478 } else if (filter.layout() == dnn::FilterLayout::kOutputInputYX32) {
3479 vector_size = 32;
3480 vector_dim = 1;
3481 }
3482 }
3483 return std::make_tuple(vector_size, vector_dim);
3484 }
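// Example: an int8 tensor in kBatchDepthYX4 layout (int8x4) yields
// (vector_size=4, vector_dim=1), kBatchDepthYX32 (int8x32) yields (32, 1),
// and every other type/layout combination yields (1, -1), i.e. no
// vectorization.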
3485
3486 port::StatusOr<cudnn_frontend::Tensor> CreateCudnnTensor(
3487 absl::Span<const int64_t> dims, absl::Span<const int64_t> strides,
3488 int64_t uid, dnn::DataType dtype, int64_t vec_count, int64_t vec_dim,
3489 bool is_virtual = false) {
3490 auto tensor = cudnn_frontend::TensorBuilder()
3491 .setDim(dims.size(), dims.data())
3492 .setStrides(strides.size(), strides.data())
3493 .setId(uid)
3494 .setAlignment(32)
3495 .setDataType(ToCudnnDataType(dtype))
3496 .setVectorCountAndDimension(vec_count, vec_dim)
3497 .setVirtual(is_virtual)
3498 .build();
3499 RETURN_MSG_IF_CUDNN_ERROR(tensor);
3500 return tensor;
3501 }
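// Usage sketch (mirrors the call sites below), building the 'x' input tensor
// of an operation graph from vectorized dims and strides:
//   TF_ASSIGN_OR_RETURN(auto tensor_x,
//                       CreateCudnnTensor(input_dims, input_strides, 'x',
//                                         input_type, vector_size,
//                                         vector_dim));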
3502
3503 port::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
3504 GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type,
3505 dnn::DataType output_type,
3506 const dnn::BatchDescriptor& input_descriptor,
3507 const dnn::FilterDescriptor& filter_descriptor,
3508 const dnn::BatchDescriptor& output_descriptor,
3509 const dnn::ConvolutionDescriptor& convolution_descriptor,
3510 CudnnHandle& cudnn) {
3511 PreloadCudnnSubLibsHelper(kind);
3512
3513 cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3514
3515 // x tensor.
3516 int vector_size, vector_dim;
3517 std::tie(vector_size, vector_dim) =
3518 GetTensorVectorSizeAndDim(input_descriptor, input_type);
3519 std::vector<int64_t> input_dims = input_descriptor.vectorized_dims(
3520 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3521 std::vector<int64_t> input_strides = input_descriptor.vectorized_strides(
3522 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3523
3524 if (vector_size == 32) {
3525 return port::InternalError(
3526 "cuDNN frontend doesn't support Tx32 at the moment.");
3527 }
3528
3529 TF_ASSIGN_OR_RETURN(auto tensor_x,
3530 CreateCudnnTensor(input_dims, input_strides, 'x',
3531 input_type, vector_size, vector_dim));
3532
3533 // y tensor.
3534 std::tie(vector_size, vector_dim) =
3535 GetTensorVectorSizeAndDim(output_descriptor, output_type);
3536 std::vector<int64_t> output_dims = output_descriptor.vectorized_dims(
3537 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3538 std::vector<int64_t> output_strides = output_descriptor.vectorized_strides(
3539 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3540
3541 TF_ASSIGN_OR_RETURN(auto tensor_y,
3542 CreateCudnnTensor(output_dims, output_strides, 'y',
3543 output_type, vector_size, vector_dim));
3544
3545 // w tensor.
3546 std::tie(vector_size, vector_dim) =
3547 GetTensorVectorSizeAndDim(filter_descriptor, input_type);
3548 std::vector<int64_t> filter_dims = filter_descriptor.vectorized_dims(
3549 dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3550 std::vector<int64_t> filter_strides = filter_descriptor.vectorized_strides(
3551 dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3552
3553 TF_ASSIGN_OR_RETURN(auto tensor_w,
3554 CreateCudnnTensor(filter_dims, filter_strides, 'w',
3555 input_type, vector_size, vector_dim));
3556
3557 // conv_desc.
3558 auto mode = convolution_descriptor.convolution_not_crosscorr()
3559 ? CUDNN_CONVOLUTION
3560 : CUDNN_CROSS_CORRELATION;
3561
3562 int conv_dim = convolution_descriptor.ndims();
3563
3564 auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(input_type));
3565 CHECK_NE(convolution_descriptor.pad_alignment(),
3566 dnn::PadAlignment::kTensorFlowPadding)
3567 << "TensorFlow padding alignment is not supported.";
3568
3569 auto conv_desc =
3570 cudnn_frontend::ConvDescBuilder()
3571 .setComputePrecision(accumulator_type)
3572 .setMathMode(mode)
3573 .setNDims(conv_dim)
3574 .setStrides(conv_dim, convolution_descriptor.strides().data())
3575 .setPrePadding(conv_dim, convolution_descriptor.padding().data())
3576 .setPostPadding(conv_dim, convolution_descriptor.padding().data())
3577 .setDilation(conv_dim, convolution_descriptor.dilations().data())
3578 .build();
3579 RETURN_MSG_IF_CUDNN_ERROR(conv_desc);
3580
3581 double alpha = 1.0;
3582 double beta = 0.0;
3583
3584 // CUDNN Operation
3585 auto op = cudnn_frontend::OperationBuilder(conv_mode)
3586 .setxDesc(tensor_x)
3587 .setyDesc(tensor_y)
3588 .setwDesc(tensor_w)
3589 .setcDesc(conv_desc)
3590 .setAlpha(alpha)
3591 .setBeta(beta)
3592 .build();
3593 RETURN_MSG_IF_CUDNN_ERROR(op);
3594
3595 // CUDNN OperationGraph
3596 std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
3597 auto opGraph = cudnn_frontend::OperationGraphBuilder()
3598 .setHandle(cudnn.handle())
3599 .setOperationGraph(ops.size(), ops.data())
3600 .build();
3601 RETURN_MSG_IF_CUDNN_ERROR(opGraph);
3602
3603 VLOG(4) << "\nTensor_x: " << tensor_x.describe()
3604 << "\nTensor_y: " << tensor_y.describe()
3605 << "\nTensor_w: " << tensor_w.describe()
3606 << "\nConv: " << conv_desc.describe() << "\nOp: " << op.describe()
3607 << "\nOpGraph: " << opGraph.describe();
3608
3609 return std::unique_ptr<cudnn_frontend::OperationGraph>(
3610 new cudnn_frontend::OperationGraph(std::move(opGraph)));
3611 }
3612
3613 bool SideInputNeeded(dnn::ActivationMode activation_mode, double conv_scale,
3614 double side_input_scale) {
3615   // Cudnn uses precompiled kernels to perform the Conv-Add-BiasAdd-Act when
3616   // the activation is Relu or Identity, and this requires the "side_input"
3617   // for the Add. For other activations, cudnn uses runtime-compiled kernels;
3618   // in that case we need to drop the Add node and use the Conv-BiasAdd-Act
3619   // pattern to trigger the correct cudnn path.
3620   // TODO(kaixih@nvidia): We should remove this WAR when cudnn fixes it.
3621 bool check_activation = activation_mode == dnn::ActivationMode::kNone ||
3622 activation_mode == dnn::ActivationMode::kRelu;
3623 bool check_scale = conv_scale != 1.0 || side_input_scale != 0.0;
3624 return check_activation || check_scale;
3625 }
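// Example of the rule above: kRelu with conv_scale == 1.0 and
// side_input_scale == 0.0 returns true (the precompiled kernel path keeps the
// Add), whereas a non-Relu activation such as kTanh with the same scales
// returns false, so the Add node is dropped.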
3626
3627 port::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
3628 GetCudnnFusedOperationGraph(
3629 dnn::ConvolutionKind kind, dnn::DataType input_type,
3630 dnn::DataType bias_type, dnn::DataType output_type, double alpha,
3631 double alpha2, double leakyrelu_alpha,
3632 const dnn::BatchDescriptor& input_descriptor,
3633 const dnn::FilterDescriptor& filter_descriptor,
3634 dnn::BatchDescriptor bias_descriptor,
3635 const dnn::BatchDescriptor& output_descriptor,
3636 const dnn::ConvolutionDescriptor& convolution_descriptor,
3637 const dnn::ActivationMode activation_mode, CudnnHandle& cudnn) {
3638 PreloadCudnnSubLibsHelper(kind);
3639
3640 cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind);
3641 dnn::DataType accumulator_type = GetConvAccumulatorType(input_type);
3642 dnn::DataType activation_type = GetConvActivationType(input_type);
3643
3644 // CUDNN fused operation supports the pattern in the form of
3645 // Conv + Add + BiasAdd + Act. Therefore, we need to build a graph of the
3646 // four ops with their input/output tensor edges:
3647 // Conv : input: tensor_x, tensor_w; output: tensor_conv (virtual)
3648 // Add : input: tensor_conv, tensor_z; output: tensor_add (virtual)
3649 // BiasAdd: input: tensor_add, tensor_b; output: tensor_bias (virtual)
3650 // Act : input: tensor_bias; output: tensor_y
3651 int vector_size, vector_dim;
3652 std::tie(vector_size, vector_dim) =
3653 GetTensorVectorSizeAndDim(input_descriptor, input_type);
3654 std::vector<int64_t> input_dims = input_descriptor.vectorized_dims(
3655 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3656 std::vector<int64_t> input_strides = input_descriptor.vectorized_strides(
3657 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3658
3659 if (vector_size == 32) {
3660 return port::InternalError(
3661 "cuDNN frontend doesn't support Tx32 at the moment.");
3662 }
3663
3664 TF_ASSIGN_OR_RETURN(auto tensor_x,
3665 CreateCudnnTensor(input_dims, input_strides, 'x',
3666 input_type, vector_size, vector_dim));
3667
3668 std::tie(vector_size, vector_dim) =
3669 GetTensorVectorSizeAndDim(output_descriptor, output_type);
3670 std::vector<int64_t> output_dims = output_descriptor.vectorized_dims(
3671 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3672 std::vector<int64_t> output_strides = output_descriptor.vectorized_strides(
3673 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3674 TF_ASSIGN_OR_RETURN(auto tensor_y,
3675 CreateCudnnTensor(output_dims, output_strides, 'y',
3676 output_type, vector_size, vector_dim));
3677
3678 TF_ASSIGN_OR_RETURN(auto tensor_z,
3679 CreateCudnnTensor(output_dims, output_strides, 'z',
3680 output_type, vector_size, vector_dim));
3681
3682 std::tie(vector_size, vector_dim) =
3683 GetTensorVectorSizeAndDim(filter_descriptor, input_type);
3684 std::vector<int64_t> filter_dims = filter_descriptor.vectorized_dims(
3685 dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3686 std::vector<int64_t> filter_strides = filter_descriptor.vectorized_strides(
3687 dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim);
3688 TF_ASSIGN_OR_RETURN(auto tensor_w,
3689 CreateCudnnTensor(filter_dims, filter_strides, 'w',
3690 input_type, vector_size, vector_dim));
3691
3692 // For the purposes of the cudnn graph, say that the bias tensor has the same
3693 // layout as the output tensor. It doesn't actually matter, because bias is a
3694 // 1D array. But we need to get the correct vectorization, otherwise the
3695 // cudnn graph API rejects this tensor, even though vectorized float tensors
3696 // aren't even a thing in cuDNN.
3697 bias_descriptor.set_layout(output_descriptor.layout());
3698
3699 // Even more unnecessarily subtle: since vectorized float tensors don't exist,
3700 // `GetTensorVectorSizeAndDim` ignores vectorized layouts for floating-point
3701 // types, so we have to ask it for vector sizes as if the type were
3702 // `input_type` rather than `bias_type`. For non-int8 types these are the same
3703 // anyway, so this only affects int8 convolutions.
3704 std::tie(vector_size, vector_dim) =
3705 GetTensorVectorSizeAndDim(bias_descriptor, input_type);
3706 std::vector<int64_t> bias_dims = bias_descriptor.vectorized_dims(
3707 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3708 std::vector<int64_t> bias_strides = bias_descriptor.vectorized_strides(
3709 dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim);
3710 TF_ASSIGN_OR_RETURN(auto tensor_b,
3711 CreateCudnnTensor(bias_dims, bias_strides, 'b', bias_type,
3712 vector_size, vector_dim));
3713
3714 std::tie(vector_size, vector_dim) =
3715 GetTensorVectorSizeAndDim(output_descriptor, output_type);
3716 TF_ASSIGN_OR_RETURN(
3717 auto tensor_conv,
3718 CreateCudnnTensor(output_dims, output_strides, 'C', accumulator_type,
3719 vector_size, vector_dim, /*is_virtual=*/true));
3720
3721 TF_ASSIGN_OR_RETURN(
3722 auto tensor_add,
3723 CreateCudnnTensor(output_dims, output_strides, 'A', activation_type,
3724 vector_size, vector_dim, /*is_virtual=*/true));
3725
3726 TF_ASSIGN_OR_RETURN(
3727 auto tensor_bias,
3728 CreateCudnnTensor(output_dims, output_strides, 'B', activation_type,
3729 vector_size, vector_dim, /*is_virtual=*/true));
3730
3731 // conv_desc.
3732 auto mode = convolution_descriptor.convolution_not_crosscorr()
3733 ? CUDNN_CONVOLUTION
3734 : CUDNN_CROSS_CORRELATION;
3735
3736 int conv_dim = convolution_descriptor.ndims();
3737
3738 CHECK_NE(convolution_descriptor.pad_alignment(),
3739 dnn::PadAlignment::kTensorFlowPadding)
3740 << "TensorFlow padding alignment is not supported.";
3741
3742 cudnnDataType_t cudnn_convolution_type = ToCudnnDataType(accumulator_type);
3743 cudnnDataType_t cudnn_activation_type = ToCudnnDataType(activation_type);
3744 auto conv_desc =
3745 cudnn_frontend::ConvDescBuilder()
3746 .setComputePrecision(cudnn_convolution_type)
3747 .setMathMode(mode)
3748 .setNDims(conv_dim)
3749 .setStrides(conv_dim, convolution_descriptor.strides().data())
3750 .setPrePadding(conv_dim, convolution_descriptor.padding().data())
3751 .setPostPadding(conv_dim, convolution_descriptor.padding().data())
3752 .setDilation(conv_dim, convolution_descriptor.dilations().data())
3753 .build();
3754 RETURN_MSG_IF_CUDNN_ERROR(conv_desc);
3755
3756 // CUDNN Operation
3757 auto conv_op = cudnn_frontend::OperationBuilder(conv_mode)
3758 .setxDesc(tensor_x)
3759 .setyDesc(tensor_conv)
3760 .setwDesc(tensor_w)
3761 .setcDesc(conv_desc)
3762 .setAlpha(1.0f)
3763 .setBeta(0.0f)
3764 .build();
3765 RETURN_MSG_IF_CUDNN_ERROR(conv_op);
3766
3767 // CUDNN OperationGraph
3768 absl::InlinedVector<cudnn_frontend::Operation const*, 4> ops = {&conv_op};
3769
3770 bool need_add_op = SideInputNeeded(activation_mode, alpha, alpha2);
3771
3772 std::optional<cudnn_frontend::PointWiseDesc_v8> add_desc;
3773 std::optional<cudnn_frontend::Operation_v8> add_op;
3774 if (need_add_op) {
3775 add_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3776 .setMode(CUDNN_POINTWISE_ADD)
3777 .setMathPrecision(cudnn_activation_type)
3778 .build());
3779 RETURN_MSG_IF_CUDNN_ERROR(*add_desc);
3780 add_op.emplace(cudnn_frontend::OperationBuilder(
3781 CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3782 .setxDesc(conv_op.getOutputTensor())
3783 .setbDesc(tensor_z)
3784 .setyDesc(tensor_add)
3785 .setpwDesc(*add_desc)
3786 .setAlpha(alpha)
3787 .setAlpha2(alpha2)
3788 .build());
3789 RETURN_MSG_IF_CUDNN_ERROR(*add_op);
3790 ops.push_back(&*add_op);
3791 }
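// Note: when need_add_op is false, tensor_z (created above) is not attached to
// any op in the graph; it only appears in the VLOG output below.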
3792
3793 auto bias_add_desc = cudnn_frontend::PointWiseDescBuilder()
3794 .setMode(CUDNN_POINTWISE_ADD)
3795 .setMathPrecision(cudnn_activation_type)
3796 .build();
3797
3798 // If the activation is the identity function, then the bias-add is the last
3799 // op, and it writes to the output, tensor_y. Otherwise, it writes to the
3800 // "virtual tensor" (temp buffer) tensor_bias, to which we apply the
3801 // activation.
3802 auto& bias_out_desc =
3803 activation_mode == dnn::ActivationMode::kNone ? tensor_y : tensor_bias;
3804 auto& bias_in_desc = need_add_op ? tensor_add : tensor_conv;
3805 auto bias_add_op = cudnn_frontend::OperationBuilder(
3806 CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3807 .setxDesc(bias_in_desc)
3808 .setbDesc(tensor_b)
3809 .setyDesc(bias_out_desc)
3810 .setpwDesc(bias_add_desc)
3811 .build();
3812 RETURN_MSG_IF_CUDNN_ERROR(bias_add_op);
3813 ops.push_back(&bias_add_op);
3814
3815 std::optional<cudnn_frontend::PointWiseDesc_v8> act_desc;
3816 switch (activation_mode) {
3817 case dnn::ActivationMode::kNone:
3818 break;
3819 case dnn::ActivationMode::kRelu:
3820 act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3821 .setMode(CUDNN_POINTWISE_RELU_FWD)
3822 .setMathPrecision(cudnn_activation_type)
3823 .build());
3824 RETURN_MSG_IF_CUDNN_ERROR(*act_desc);
3825
3826 break;
3827 case dnn::ActivationMode::kRelu6:
3828 act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3829 .setMode(CUDNN_POINTWISE_RELU_FWD)
3830 .setReluUpperClip(6.0)
3831 .setMathPrecision(cudnn_activation_type)
3832 .build());
3833 RETURN_MSG_IF_CUDNN_ERROR(*act_desc);
3834 break;
3835 case dnn::ActivationMode::kElu:
3836 act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3837 .setMode(CUDNN_POINTWISE_ELU_FWD)
3838 .setMathPrecision(cudnn_activation_type)
3839 .build());
3840 RETURN_MSG_IF_CUDNN_ERROR(*act_desc);
3841 break;
3842 case dnn::ActivationMode::kLeakyRelu:
3843 act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
3844 .setMode(CUDNN_POINTWISE_RELU_FWD)
3845 .setReluLowerClipSlope(leakyrelu_alpha)
3846 .setMathPrecision(cudnn_activation_type)
3847 .build());
3848 RETURN_MSG_IF_CUDNN_ERROR(*act_desc);
3849 break;
3850 default:
3851 return port::InternalError(
3852 absl::StrCat("Unimplemented activation mode ",
3853 dnn::ActivationModeString(activation_mode)));
3854 }
3855
3856 std::optional<cudnn_frontend::Operation_v8> act_op;
3857 if (activation_mode != dnn::ActivationMode::kNone) {
3858 act_op.emplace(cudnn_frontend::OperationBuilder(
3859 CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
3860 .setxDesc(bias_add_op.getOutputTensor())
3861 .setyDesc(tensor_y)
3862 .setpwDesc(*act_desc)
3863 .build());
3864 RETURN_MSG_IF_CUDNN_ERROR(*act_op);
3865 ops.push_back(&*act_op);
3866 }
3867
3868 auto op_graph = cudnn_frontend::OperationGraphBuilder()
3869 .setHandle(cudnn.handle())
3870 .setOperationGraph(ops.size(), ops.data())
3871 .build();
3872 RETURN_MSG_IF_CUDNN_ERROR(op_graph);
3873
3874 VLOG(4) << "\nTensor_x: " << tensor_x.describe()
3875 << "\nTensor_y: " << tensor_y.describe()
3876 << "\nTensor_z: " << tensor_z.describe()
3877 << "\nTensor_w: " << tensor_w.describe()
3878 << "\nTensor_b: " << tensor_b.describe()
3879 << "\nTensor_conv: " << tensor_conv.describe()
3880 << "\nTensor_add: " << tensor_add.describe()
3881 << "\nTensor_bias: " << tensor_bias.describe()
3882 << "\nConv: " << conv_desc.describe() << "\nAdd: "
3883 << (add_desc.has_value() ? add_desc->describe() : "(skipped)")
3884 << "\nBiasAdd: " << bias_add_desc.describe() //
3885 << "\nAct: "
3886 << (act_desc.has_value() ? act_desc->describe() : "(identity)")
3887 << "\nConvOp: " << conv_op.describe() << "\nAddOp: "
3888 << (add_op.has_value() ? add_op->describe() : "(skipped)")
3889 << "\nBiasAddOp: " << bias_add_op.describe() //
3890 << "\nActOp: "
3891 << (act_op.has_value() ? act_op->describe() : "(identity)")
3892 << "\nOpGraph: " << op_graph.describe();
3893
3894 return std::unique_ptr<cudnn_frontend::OperationGraph>(
3895 new cudnn_frontend::OperationGraph(std::move(op_graph)));
3896 }
3897
3898 } // namespace
3899
3900 static port::StatusOr<cudnn_frontend::ExecutionPlan> RebuildExecutionPlan(
3901 const CudnnHandle& cudnn, const dnn::AlgorithmDesc& desc,
3902 const cudnn_frontend::OperationGraph& op_graph) {
3903 if (!desc.is_cudnn_frontend()) {
3904 return port::InternalError(
3905 "Got legacy cuDNN algorithm enum in RebuildExecutionPlan.");
3906 }
3907
3908 // Errors encountered when building a cuDNN operation graph are surfaced in an
3909 // unprecedented and innovative way: they're written into a field of the
3910 // contained engine object, but then clobbered by the object's move
3911 // constructor which makes more cuDNN API calls and encounters further errors.
3912 // The only way to get the actual errors is to peek at them via the returned
3913 // rvalue reference before actually moving the object to finish its
3914 // initialization.
3915 cudnn_frontend::EngineBuilder engine_builder;
3916 engine_builder.setOperationGraph(op_graph).setGlobalEngineIdx(desc.algo_id());
3917 auto&& unmoved = engine_builder.build();
3918 RETURN_MSG_IF_CUDNN_ERROR(unmoved);
3919 cudnn_frontend::Engine engine = std::move(unmoved);
3920 RETURN_MSG_IF_CUDNN_ERROR(engine);
3921
3922 // Miscellaneous compiler bugs and linker issues conspired to make it
3923 // impossible for AlgorithmDesc to just give us a map initially. Get the
3924 // vector of tuning knobs and build the map locally.
3925 auto tuning_knobs_vec = desc.TuningKnobs();
3926 absl::flat_hash_map<int64_t, int64_t> tuning_knobs;
3927 tuning_knobs.reserve(tuning_knobs_vec.size());
3928 for (const auto& pair : tuning_knobs_vec) {
3929 tuning_knobs[pair.first] = pair.second;
3930 }
3931
3932 for (auto& knob : engine.getSupportedKnobs()) {
3933 const auto it = tuning_knobs.find(static_cast<int64_t>(knob.getKnobType()));
3934 if (it != tuning_knobs.end()) {
3935 knob.setChoice(it->second);
3936 }
3937 }
3938
3939 auto engine_config =
3940 cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
3941 RETURN_MSG_IF_CUDNN_ERROR(engine_config);
3942
3943 auto plan = cudnn_frontend::ExecutionPlanBuilder()
3944 .setHandle(cudnn.handle())
3945 .setEngineConfig(engine_config)
3946 .build();
3947 RETURN_MSG_IF_CUDNN_ERROR(plan);
3948
3949 return {std::move(plan)};
3950 }
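// The AlgorithmDesc consumed here is the kind produced by
// ExecutionPlanToAlgorithmDesc() further below, so the engine index and tuning
// knobs round-trip between the two functions.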
3951
3952 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
3953
3954 } // namespace
3955
3956 port::Status CudnnSupport::DoPrepareForConvolution(
3957 dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3958 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3959 const dnn::FilterDescriptor& filter_descriptor,
3960 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3961 DeviceMemoryBase output_data,
3962 const dnn::ConvolutionDescriptor& convolution_descriptor,
3963 const dnn::AlgorithmConfig& algorithm_config,
3964 ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
3965 DeviceMemory<uint8>* scratch_memory) {
3966 CudnnTensorDescriptor input_nd(
3967 input_descriptor,
3968 ToCudnnDataType(element_type, input_descriptor.layout()));
3969 CudnnFilterDescriptor filter_nd(
3970 filter_descriptor,
3971 ToCudnnDataType(element_type, filter_descriptor.layout()));
3972 CudnnTensorDescriptor output_nd(
3973 output_descriptor,
3974 ToCudnnDataType(element_type, output_descriptor.layout()));
3975
3976 auto cudnn = cudnn_->GetHandle(parent_, stream);
3977
3978 switch (kind) {
3979 case dnn::ConvolutionKind::FORWARD: {
3980 TF_ASSIGN_OR_RETURN(*algorithm_desc,
3981 GetCudnnConvolutionForwardAlgorithm(
3982 stream, cudnn, algorithm_config, input_nd,
3983 filter_nd, element_type, convolution_descriptor,
3984 output_nd, scratch_allocator, scratch_memory));
3985 break;
3986 }
3987 case dnn::ConvolutionKind::BACKWARD_DATA: {
3988 TF_ASSIGN_OR_RETURN(*algorithm_desc,
3989 GetCudnnConvolutionBackwardDataAlgorithm(
3990 stream, cudnn, algorithm_config, input_nd,
3991 filter_nd, element_type, convolution_descriptor,
3992 output_nd, scratch_allocator, scratch_memory));
3993 break;
3994 }
3995 case dnn::ConvolutionKind::BACKWARD_FILTER: {
3996 TF_ASSIGN_OR_RETURN(*algorithm_desc,
3997 GetCudnnConvolutionBackwardFilterAlgorithm(
3998 stream, cudnn, algorithm_config, input_nd,
3999 filter_nd, element_type, convolution_descriptor,
4000 output_nd, scratch_allocator, scratch_memory));
4001 break;
4002 }
4003 default:
4004 return port::InternalError(
4005 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
4006 }
4007
4008 return ::tensorflow::OkStatus();
4009 }
4010
4011 class CudnnLegacyConvRunner : public dnn::ConvRunner {
4012 public:
4013 // Queries the workspace size and constructs a 'CudnnLegacyConvRunner'.
4014 static port::StatusOr<CudnnLegacyConvRunner> Create(
4015 GpuExecutor* parent, Stream* stream, CudnnAccess* cudnn,
4016 const dnn::AlgorithmDesc& algo, dnn::DataType input_type,
4017 dnn::DataType output_type, dnn::ConvolutionKind kind,
4018 CudnnTensorDescriptor input_nd, CudnnTensorDescriptor output_nd,
4019 CudnnFilterDescriptor filter, CudnnConvolutionDescriptor conv) {
4020 size_t workspace_size;
4021 if (algo.workspace_size()) {
4022 workspace_size = *algo.workspace_size();
4023 } else {
4024 // For old AlgorithmProtos loaded from serialized autotune maps and for
4025 // AlgorithmDescs constructed by manually specifying an algorithm ID, we
4026 // need to compute the workspace size here.
4027 auto handle = cudnn->GetHandle(parent, stream);
4028
4029 switch (kind) {
4030 case dnn::ConvolutionKind::FORWARD:
4031 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
4032 handle.handle(),
4033 /*xDesc=*/input_nd.handle(),
4034 /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
4035 /*yDesc=*/output_nd.handle(),
4036 /*algo=*/ToConvForwardAlgo(algo),
4037 /*sizeInBytes=*/&workspace_size));
4038 break;
4039 case dnn::ConvolutionKind::BACKWARD_FILTER:
4040 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
4041 handle.handle(),
4042 /*xDesc=*/input_nd.handle(),
4043 /*dyDesc=*/output_nd.handle(),
4044 /*convDesc=*/conv.handle(),
4045 /*gradDesc=*/filter.handle(),
4046 /*algo=*/ToConvBackwardFilterAlgo(algo),
4047 /*sizeInBytes=*/&workspace_size));
4048 break;
4049 case dnn::ConvolutionKind::BACKWARD_DATA:
4050 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
4051 handle.handle(),
4052 /*wDesc=*/filter.handle(),
4053 /*dyDesc=*/output_nd.handle(),
4054 /*convDesc=*/conv.handle(),
4055 /*dxDesc=*/input_nd.handle(),
4056 /*algo=*/ToConvBackwardDataAlgo(algo),
4057 /*sizeInBytes=*/&workspace_size));
4058 break;
4059 default:
4060 return port::InternalError(
4061 "Invalid ConvolutionKind for CudnnLegacyConvRunner.");
4062 }
4063 }
4064
4065 return {{parent, cudnn, algo.algo_id(), algo.tensor_ops_enabled(),
4066 workspace_size, input_type, output_type, kind, std::move(input_nd),
4067 std::move(output_nd), std::move(filter), std::move(conv)}};
4068 }
4069
4070 std::string ToString() const override {
4071 return MakeAlgorithmDesc().ToString();
4072 }
4073
4074 size_t GetWorkspaceSize() const override { return workspace_size_; }
4075
4076 port::StatusOr<dnn::AlgorithmDesc> ToAlgorithmDesc() const override {
4077 return MakeAlgorithmDesc();
4078 }
4079
4080 port::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
4081 DeviceMemoryBase scratch_memory,
4082 DeviceMemoryBase input_data,
4083 DeviceMemoryBase filter_data,
4084 DeviceMemoryBase output_data) const override {
4085 auto algo = MakeAlgorithmDesc();
4086
4087 // Check that the current stream supports tensor ops if they're requested.
4088 TF_RETURN_IF_ERROR(UseTensorOps(stream, input_type_, algo).status());
4089
4090 if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
4091 stream->parent()->implementation()) {
4092 return port::InternalError(
4093 "CudnnLegacyConvRunner cached across multiple StreamExecutors.");
4094 }
4095
4096 auto cudnn = cudnn_->GetHandle(parent_, stream);
4097 // Alpha is the scaling factor for input.
4098 float falpha = 1.0;
4099 double dalpha = 1.0;
4100 void* alpha = input_type_ == dnn::DataType::kDouble
4101 ? static_cast<void*>(&dalpha)
4102 : static_cast<void*>(&falpha);
4103 // Beta is the scaling factor for output.
4104 float fbeta = 0.0;
4105 double dbeta = 0.0;
4106 void* beta = input_type_ == dnn::DataType::kDouble
4107 ? static_cast<void*>(&dbeta)
4108 : static_cast<void*>(&fbeta);
4109
4110 const bool is_profiling = profile_result != nullptr;
4111
4112 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4113 if (is_profiling) {
4114 timer.reset(new GpuTimer(parent_)); // NOLINT
4115 // The start and stop of the timer should be as close to the cuDNN call as
4116 // possible. Other threads can still issue work onto this stream, so the
4117 // measurement may include their work; multiple profiling runs may be needed.
4118 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4119 return port::Status(port::error::INTERNAL, "Failed to start timer");
4120 }
4121 }
4122
4123 const auto get_fwd_bugs = [&]() -> port::Status {
4124 #if CUDNN_VERSION < 8000
4125 if (algo_id_ == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
4126 ToCudnnDataType(input_type_) == CUDNN_DATA_INT8 &&
4127 ToCudnnDataType(output_type_) == CUDNN_DATA_FLOAT) {
4128 return port::Status(
4129 port::error::FAILED_PRECONDITION,
4130 "This configuration potentially produces incorrect results.");
4131 }
4132 #else
4133 (void)output_type_; // To stop clang-tidy saying it's unused.
4134 #endif
4135 return ::tensorflow::OkStatus();
4136 };
4137
4138 auto get_bwd_data_bugs = [&]() -> port::Status {
4139 return ::tensorflow::OkStatus();
4140 };
4141
4142 const auto get_bwd_filter_bugs = [&]() -> port::Status {
4143 return ::tensorflow::OkStatus();
4144 };
4145
4146 switch (kind_) {
4147 case dnn::ConvolutionKind::FORWARD: {
4148 TF_RETURN_IF_ERROR(get_fwd_bugs());
4149 RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
4150 cudnn.handle(),
4151 /*alpha=*/alpha, /*srcDesc=*/input_nd_.handle(),
4152 /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_.handle(),
4153 /*filterData=*/filter_data.opaque(), /*convDesc=*/conv_.handle(),
4154 /*algo=*/ToConvForwardAlgo(algo),
4155 /*workSpace=*/scratch_memory.opaque(),
4156 /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
4157 /*yDesc=*/output_nd_.handle(), /*y=*/output_data.opaque()));
4158 break;
4159 }
4160 case dnn::ConvolutionKind::BACKWARD_DATA: {
4161 TF_RETURN_IF_ERROR(get_bwd_data_bugs());
4162 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
4163 cudnn.handle(),
4164 /*alpha=*/alpha,
4165 /*wDesc=*/filter_.handle(),
4166 /*w=*/filter_data.opaque(),
4167 /*dyDesc=*/output_nd_.handle(),
4168 /*dy=*/output_data.opaque(),
4169 /*convDesc=*/conv_.handle(),
4170 /*algo=*/ToConvBackwardDataAlgo(algo),
4171 /*workSpace=*/scratch_memory.opaque(),
4172 /*workSpaceSizeInBytes=*/scratch_memory.size(),
4173 /*beta=*/beta,
4174 /*dxDesc=*/input_nd_.handle(),
4175 /*dx=*/input_data.opaque()));
4176 break;
4177 }
4178 case dnn::ConvolutionKind::BACKWARD_FILTER: {
4179 TF_RETURN_IF_ERROR(get_bwd_filter_bugs());
4180 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
4181 cudnn.handle(),
4182 /*alpha=*/alpha,
4183 /*srcDesc=*/input_nd_.handle(),
4184 /*srcData=*/input_data.opaque(),
4185 /*diffDesc=*/output_nd_.handle(),
4186 /*diffData=*/output_data.opaque(),
4187 /*convDesc=*/conv_.handle(),
4188 /*algo=*/ToConvBackwardFilterAlgo(algo),
4189 /*workSpace=*/scratch_memory.opaque(),
4190 /*workSpaceSizeInBytes=*/scratch_memory.size(),
4191 /*beta=*/beta,
4192 /*gradDesc=*/filter_.handle(),
4193 /*dw=*/filter_data.opaque()));
4194 break;
4195 }
4196 default:
4197 return port::InternalError(absl::StrCat("Unexpected convolution kind ",
4198 static_cast<int>(kind_)));
4199 }
4200
4201 if (is_profiling) {
4202 if (!timer->Stop(AsGpuStream(stream))) {
4203 return port::Status(port::error::INTERNAL, "Failed to stop timer");
4204 }
4205 profile_result->set_algorithm(algo);
4206 profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
4207 profile_result->set_scratch_size(scratch_memory.size());
4208 }
4209
4210 return ::tensorflow::OkStatus();
4211 }
4212
4213 private:
4214 // Private to prevent passing in the wrong workspace_size.
4215 CudnnLegacyConvRunner(GpuExecutor* parent, CudnnAccess* cudnn,
4216 int64_t algo_id, bool tensor_ops_enabled,
4217 size_t workspace_size, dnn::DataType input_type,
4218 dnn::DataType output_type, dnn::ConvolutionKind kind,
4219 CudnnTensorDescriptor input_nd,
4220 CudnnTensorDescriptor output_nd,
4221 CudnnFilterDescriptor filter,
4222 CudnnConvolutionDescriptor conv)
4223 : parent_(parent),
4224 cudnn_(cudnn),
4225 algo_id_(algo_id),
4226 tensor_ops_enabled_(tensor_ops_enabled),
4227 workspace_size_(workspace_size),
4228 kind_(kind),
4229 input_type_(input_type),
4230 output_type_(output_type),
4231 input_nd_(std::move(input_nd)),
4232 output_nd_(std::move(output_nd)),
4233 filter_(std::move(filter)),
4234 conv_(std::move(conv)) {}
4235
4236 // Internal form of ToAlgorithmDesc without the StatusOr.
4237 dnn::AlgorithmDesc MakeAlgorithmDesc() const {
4238 return {algo_id_, tensor_ops_enabled_, workspace_size_};
4239 }
4240
4241 GpuExecutor* parent_;
4242 CudnnAccess* cudnn_;
4243 int64_t algo_id_;
4244 bool tensor_ops_enabled_;
4245 size_t workspace_size_;
4246 dnn::ConvolutionKind kind_;
4247 dnn::DataType input_type_;
4248 dnn::DataType output_type_;
4249
4250 CudnnTensorDescriptor input_nd_;
4251 CudnnTensorDescriptor output_nd_;
4252 CudnnFilterDescriptor filter_;
4253 CudnnConvolutionDescriptor conv_;
4254 };
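// Usage sketch: a CudnnLegacyConvRunner is built via Create() (which resolves
// the workspace size from the AlgorithmDesc or by querying cuDNN) and then
// invoked as a callable; DoConvolve() below follows exactly that two-step
// pattern.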
4255
4256 port::Status CudnnSupport::DoConvolve(
4257 dnn::ConvolutionKind kind, dnn::DataType element_type,
4258 dnn::DataType output_type, Stream* stream,
4259 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
4260 const dnn::FilterDescriptor& filter_descriptor,
4261 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
4262 DeviceMemoryBase output_data,
4263 const dnn::ConvolutionDescriptor& convolution_descriptor,
4264 dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
4265 dnn::ProfileResult* profile_result) {
4266 cudnnDataType_t cudnn_type =
4267 ToCudnnDataType(element_type, input_descriptor.layout());
4268 CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
4269 CudnnTensorDescriptor output_nd(
4270 output_descriptor,
4271 ToCudnnDataType(output_type, output_descriptor.layout()));
4272 CudnnFilterDescriptor filter_nd(
4273 filter_descriptor,
4274 ToCudnnDataType(element_type, filter_descriptor.layout()));
4275
4276 auto accumulator_type = GetConvAccumulatorType(element_type);
4277 CudnnConvolutionDescriptor conv(convolution_descriptor,
4278 ToCudnnDataType(accumulator_type));
4279 TF_ASSIGN_OR_RETURN(bool use_tensor_ops,
4280 UseTensorOps(stream, element_type, algorithm_desc));
4281 conv.set_use_tensor_op_math(use_tensor_ops);
4282
4283 TF_ASSIGN_OR_RETURN(
4284 auto runner,
4285 CudnnLegacyConvRunner::Create(
4286 parent_, stream, cudnn_.get(), algorithm_desc, element_type,
4287 output_type, kind, std::move(input_nd), std::move(output_nd),
4288 std::move(filter_nd), std::move(conv)));
4289 return runner(stream, profile_result, scratch_memory, input_data, filter_data,
4290 output_data);
4291 }
4292
4293 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4294
4295 struct BackendDescriptorDeleter {
4296 void operator()(cudnnBackendDescriptor_t desc) {
4297 cudnnBackendDestroyDescriptor(desc);
4298 }
4299 };
4300
4301 using BackendDescriptor = std::unique_ptr<void, BackendDescriptorDeleter>;
4302
4303 port::StatusOr<BackendDescriptor> CreateBackendDesc(
4304 cudnnBackendDescriptorType_t type) {
4305 void* result;
4306 RETURN_IF_CUDNN_ERROR(cudnnBackendCreateDescriptor(type, &result));
4307 return BackendDescriptor(result);
4308 }
4309
4310 // Get the values of a CUDNN_TYPE_BACKEND_DESCRIPTOR attribute as a vector.
4311 //
4312 // This is fetching the entirety of a single sequence-valued attribute, as
4313 // opposed to a sequence of multiple attributes. The distinction is a bit
4314 // meaningless, but this is the presentation the cuDNN docs use, so it may as
4315 // well be consistent.
4316 port::StatusOr<std::vector<BackendDescriptor>> GetDescriptorAttribute(
4317 cudnnBackendDescriptor_t desc, cudnnBackendAttributeName_t name,
4318 cudnnBackendDescriptorType_t type) {
4319 int64_t n;
4320 RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute(
4321 desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, 0, &n, nullptr));
4322
4323 std::vector<BackendDescriptor> result(n);
4324 for (int i = 0; i < n; ++i) {
4325 TF_ASSIGN_OR_RETURN(result[i], CreateBackendDesc(type));
4326 }
4327
4328 std::vector<cudnnBackendDescriptor_t> raw_ptrs;
4329 raw_ptrs.reserve(result.size());
4330 absl::c_transform(result, std::back_inserter(raw_ptrs),
4331 [](const BackendDescriptor& ptr) { return ptr.get(); });
4332
4333 // This API evidently does a deep copy of the descriptors into the pointers in
4334 // the output array, rather than writing pointers to the descriptors into the
4335 // output array. So, this writes the memory behind each BackendDescriptor in
4336 // result, rather than writing the contents of raw_ptrs.
4337 RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute(
4338 desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, n, &n, raw_ptrs.data()));
4339
4340 return result;
4341 }
4342
4343 // Extract the engine ID and tuning knobs from the ExecutionPlan, and return
4344 // them in the form of an AlgorithmDesc for use with RebuildExecutionPlan.
4345 port::StatusOr<dnn::AlgorithmDesc> ExecutionPlanToAlgorithmDesc(
4346 const cudnn_frontend::ExecutionPlan& plan, size_t workspace_size) {
4347 TF_ASSIGN_OR_RETURN(
4348 auto engine_cfgs,
4349 GetDescriptorAttribute(plan.get_raw_desc(),
4350 CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG,
4351 CUDNN_BACKEND_ENGINECFG_DESCRIPTOR));
4352 if (engine_cfgs.size() != 1) {
4353 return port::InternalError(
4354 "CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG had more than one element.");
4355 }
4356
4357 TF_ASSIGN_OR_RETURN(
4358 auto engines,
4359 GetDescriptorAttribute(engine_cfgs[0].get(), CUDNN_ATTR_ENGINECFG_ENGINE,
4360 CUDNN_BACKEND_ENGINE_DESCRIPTOR));
4361 if (engines.size() != 1) {
4362 return port::InternalError(
4363 "CUDNN_ATTR_ENGINECFG_ENGINE had more than one element.");
4364 }
4365
4366 int64_t n;
4367 int64_t engine_id;
4368 RETURN_IF_CUDNN_ERROR(
4369 cudnnBackendGetAttribute(engines[0].get(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX,
4370 CUDNN_TYPE_INT64, 1, &n, &engine_id));
4371
4372 // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query the
4373 // number of elements in the attribute by using an output limit value of 0
4374 // just returns 0; the only way to find out how many there are is to
4375 // pre-allocate space for every existing knob type (as an upper bound on the
4376 // number of knob choices a config can have), and then look back at how many
4377 // were filled.
4378 std::vector<BackendDescriptor> knobs(CUDNN_KNOB_TYPE_COUNTS);
4379 for (int i = 0; i < knobs.size(); ++i) {
4380 TF_ASSIGN_OR_RETURN(
4381 knobs[i], CreateBackendDesc(CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR));
4382 }
4383 std::vector<cudnnBackendDescriptor_t> raw_knob_ptrs;
4384 raw_knob_ptrs.reserve(knobs.size());
4385 absl::c_transform(knobs, std::back_inserter(raw_knob_ptrs),
4386 [](const BackendDescriptor& ptr) { return ptr.get(); });
4387 RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute(
4388 engine_cfgs[0].get(), CUDNN_ATTR_ENGINECFG_KNOB_CHOICES,
4389 CUDNN_TYPE_BACKEND_DESCRIPTOR, raw_knob_ptrs.size(), &n,
4390 raw_knob_ptrs.data()));
4391 knobs.resize(n);
4392
4393 absl::flat_hash_map<int64_t, int64_t> tuning_knobs;
4394 for (const auto& knob : knobs) {
4395 cudnnBackendKnobType_t knob_type;
4396 int64_t knob_value;
4397
4398 RETURN_IF_CUDNN_ERROR(
4399 cudnnBackendGetAttribute(knob.get(), CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE,
4400 CUDNN_TYPE_KNOB_TYPE, 1, &n, &knob_type));
4401 if (n != 1) {
4402 return port::InternalError(
4403 absl::StrCat("Knob should have exactly one KNOB_TYPE; had ", n));
4404 }
4405
4406 RETURN_IF_CUDNN_ERROR(
4407 cudnnBackendGetAttribute(knob.get(), CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE,
4408 CUDNN_TYPE_INT64, 1, &n, &knob_value));
4409 if (n != 1) {
4410 return port::InternalError(
4411 absl::StrCat("Knob should have exactly one KNOB_VALUE; had ", n));
4412 }
4413
4414 auto emplaced = tuning_knobs.try_emplace(knob_type, knob_value).second;
4415 if (!emplaced) {
4416 return port::InternalError(absl::StrFormat(
4417 "cuDNN gave multiple knob values for the same knob type.\n"
4418 " KNOB_TYPE: %d\n"
4419 " new KNOB_VALUE: %d\n"
4420 " old KNOB_VALUE: %d",
4421 knob_type, knob_value, tuning_knobs.at(knob_type)));
4422 }
4423 }
4424
4425 std::vector<std::pair<int64_t, int64_t>> tuning_knobs_vec;
4426 tuning_knobs_vec.reserve(tuning_knobs.size());
4427 absl::c_copy(tuning_knobs, std::back_inserter(tuning_knobs_vec));
4428
4429 return dnn::AlgorithmDesc(engine_id, tuning_knobs_vec, workspace_size);
4430 }
4431
4432 template <typename Sig>
4433 class CudnnExecutionPlanRunner;
4434 // An OpRunner implemented by an ExecutionPlan.
4435 //
4436 // This is the class holding the implementation of ToString, GetWorkspaceSize,
4437 // and operator() for use by the cudnn frontend op runners.
4438 template <typename... Args>
4439 class CudnnExecutionPlanRunner<void(Args...)>
4440 : public dnn::OpRunner<void(Args...)> {
4441 public:
4442 std::string ToString() const override { return plan_.getTag(); }
4443
4444 size_t GetWorkspaceSize() const override { return workspace_size_; }
4445
4446 port::StatusOr<dnn::AlgorithmDesc> ToAlgorithmDesc() const override {
4447 return ExecutionPlanToAlgorithmDesc(plan_, workspace_size_);
4448 }
4449
4450 port::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
4451 DeviceMemoryBase scratch_memory,
4452 Args... inputs) const override {
4453 if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
4454 stream->parent()->implementation()) {
4455 return port::InternalError(
4456 "CudnnExecutionPlanRunner cached across multiple StreamExecutors.");
4457 }
4458
4459 auto cudnn = cudnn_->GetHandle(parent_, stream);
4460
4461 size_t workspace_size = plan_.getWorkspaceSize();
4462 RETURN_MSG_IF_CUDNN_ERROR(plan_);
4463
4464 std::array<void*, sizeof...(Args)> data_ptrs = {inputs.opaque()...};
4465
4466 absl::InlinedVector<int64_t, sizeof...(Args)> data_uids_vec(
4467 data_uids_.cbegin(), data_uids_.cend());
4468 absl::InlinedVector<void*, sizeof...(Args)> data_ptrs_vec(
4469 data_ptrs.cbegin(), data_ptrs.cend());
4470 // We use need_side_input to determine if the side input 'z' from
4471 // {'x', 'w', 'z', 'b', 'y'} is needed for the conv-<add>-bias-act patterns.
4472 if (sizeof...(Args) == 5 && !need_side_input_) {
4473 data_uids_vec.erase(data_uids_vec.begin() + 2);
4474 data_ptrs_vec.erase(data_ptrs_vec.begin() + 2);
4475 }
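// After the erase, the uid and pointer vectors line up as {'x', 'w', 'b', 'y'}
// and their corresponding device pointers for the conv-bias-act pattern.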
4476
4477 auto variantPack =
4478 cudnn_frontend::VariantPackBuilder()
4479 .setWorkspacePointer(scratch_memory.opaque())
4480 .setDataPointers(data_ptrs_vec.size(), data_ptrs_vec.data())
4481 .setUids(data_uids_vec.size(), data_uids_vec.data())
4482 .build();
4483 RETURN_MSG_IF_CUDNN_ERROR(variantPack);
4484
4485 VLOG(4) << "\nDo cudnn execution plan with plan tag: " << plan_.getTag()
4486 << "\nWorkspace size in bytes: " << workspace_size
4487 << "\nVariantPack: " << variantPack.describe();
4488
4489 const bool is_profiling = profile_result != nullptr;
4490
4491 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4492 if (is_profiling) {
4493 timer.reset(new GpuTimer(parent_)); // NOLINT
4494 // The start and stop of the timer should be as close to the cuDNN call as
4495 // possible. Other threads can still issue work onto this stream, so the
4496 // measurement may include their work; multiple profiling runs may be needed.
4497 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4498 return port::Status(port::error::INTERNAL, "Failed to start timer");
4499 }
4500 }
4501
4502 cudnnStatus_t status = cudnnBackendExecute(
4503 cudnn.handle(), plan_.get_raw_desc(), variantPack.get_raw_desc());
4504 RETURN_IF_CUDNN_ERROR(status);
4505
4506 if (is_profiling) {
4507 if (!timer->Stop(AsGpuStream(stream))) {
4508 return port::Status(port::error::INTERNAL, "Failed to stop timer");
4509 }
4510 TF_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc());
4511 profile_result->set_algorithm(desc);
4512 profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
4513 profile_result->set_scratch_size(scratch_memory.size());
4514
4515 VLOG(4) << "cudnn op with plan " << plan_.getTag()
4516 << ", workspace_size=" << workspace_size << " -> "
4517 << CudnnStatusToString(status) << " in "
4518 << timer->GetElapsedMilliseconds() << "ms";
4519 }
4520
4521 return port::Status::OK();
4522 }
4523
4524 static port::StatusOr<CudnnExecutionPlanRunner> Create(
4525 GpuExecutor* parent, CudnnAccess* cudnn,
4526 cudnn_frontend::ExecutionPlan plan, absl::Span<const int64_t> uids,
4527 bool need_side_input) {
4528 auto workspace_size = static_cast<uint64_t>(plan.getWorkspaceSize());
4529 RETURN_MSG_IF_CUDNN_ERROR(plan);
4530 return {{parent, cudnn, std::move(plan), workspace_size, uids,
4531 need_side_input}};
4532 }
4533
4534 private:
4535 CudnnExecutionPlanRunner(GpuExecutor* parent, CudnnAccess* cudnn,
4536 cudnn_frontend::ExecutionPlan plan,
4537 size_t workspace_size,
4538 absl::Span<const int64_t> uids, bool need_side_input)
4539 : parent_(parent),
4540 cudnn_(cudnn),
4541 plan_(std::move(plan)),
4542 workspace_size_(workspace_size),
4543 data_uids_(uids.begin(), uids.end()),
4544 need_side_input_(need_side_input) {}
4545 GpuExecutor* parent_;
4546 CudnnAccess* cudnn_;
4547 cudnn_frontend::ExecutionPlan plan_;
4548 size_t workspace_size_;
4549 absl::InlinedVector<int64_t, sizeof...(Args)> data_uids_;
4550 bool need_side_input_;
4551 };
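// Usage sketch: instances are produced by Create(), which validates the plan
// and captures its workspace size; CreateOpRunners() and
// ConvolveRunnerFromDesc() below construct them from cuDNN frontend
// ExecutionPlans.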
4552 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4553
4554 // Utility for dealing with CUDA's type-erased scaling parameters, where some
4555 // sets of parameters expect a void* pointing at a float while others expect it
4556 // to point at a double.
4557 //
4558 // This is rather ugly, but its purpose is to quarantine the corresponding
4559 // ugliness that already exists in the CUDA API.
4560 class ScalingParam {
4561 public:
4562 explicit ScalingParam(double value) : as_double_(value), as_float_(value) {}
4563
4564 // Return a pointer to the appropriate representation type for the given
4565 // element type.
4566 //
4567 // See
4568 // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
4569 // for more info; the behavior for int8 result tensors is not described there,
4570 // but is maintained from the existing behavior (namely, using a float scaling
4571 // parameter).
4572 void* ToVoidPointer(dnn::DataType element_type) {
4573 if (element_type == dnn::DataType::kDouble) {
4574 return &as_double_;
4575 } else {
4576 return &as_float_;
4577 }
4578 }
4579
4580 private:
4581 double as_double_;
4582 float as_float_;
4583 };
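// Usage sketch (mirroring the cudnnConvolutionBiasActivationForward call site
// below): ScalingParam(conv_scale).ToVoidPointer(input_type) yields a void*
// that points at a float for most element types and at a double for kDouble;
// the pointee only needs to outlive the cuDNN call that consumes it.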
4584
4585 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4586 namespace {
4587
4588 template <typename Sig>
4589 port::Status CreateOpRunners(
4590 Stream* stream, CudnnHandle& cudnn, GpuExecutor* gpu_executor,
4591 CudnnAccess* cudnn_access,
4592 std::unique_ptr<cudnn_frontend::OperationGraph> op_graph,
4593 dnn::ConvolutionKind kind, dnn::DataType input_type,
4594 absl::Span<const int64_t> input_uids, bool use_fallback,
4595 std::vector<std::unique_ptr<const dnn::OpRunner<Sig>>>* out_runners,
4596 bool need_side_input) {
4597 cudnn_frontend::EngineConfigList filtered_configs;
4598 auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
4599 return GenericEngineFilter(
4600 engine_config,
4601 /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
4602 /*disable_nondeterminism*/ RequireCudnnDeterminism(),
4603 /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
4604 };
4605
4606 if (!use_fallback) {
4607 // In theory, the mode returned by GetCudnnFrontendHeurMode() is supposed to
4608 // fall back to HEUR_MODE_INSTANT if it can't find any working engines, but a
4609 // known cuDNN issue prevents that fallback when runtime-compiled fusion
4610 // engines are involved. So we do the fallback manually here.
4611 // TODO(kaixih@nvidia): remove this when cuDNN fixes the issue.
4612 cudnn_frontend::EngineHeuristics heuristics = [&] {
4613 auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
4614 .setOperationGraph(*op_graph)
4615 .setHeurMode(GetCudnnFrontendHeurMode())
4616 .build();
4617 if (heuristics.get_status() == CUDNN_STATUS_SUCCESS) return heuristics;
4618 return cudnn_frontend::EngineHeuristicsBuilder()
4619 .setOperationGraph(*op_graph)
4620 .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
4621 .build();
4622 }();
4623 RETURN_MSG_IF_CUDNN_ERROR(heuristics);
4624
4625 // cuDNN frontend sneakily puts error messages on the object and returns
4626 // partially-initialized results when there's an error; make sure to check
4627 // them.
4628 int64_t engine_count = heuristics.getEngineConfigCount();
4629 RETURN_MSG_IF_CUDNN_ERROR(heuristics);
4630 auto& heuristics_configs = heuristics.getEngineConfig(engine_count);
4631 RETURN_MSG_IF_CUDNN_ERROR(heuristics);
4632 VLOG(4) << "\nHeuristics engine configs size: "
4633 << heuristics_configs.size();
4634
4635 cudnn_frontend::filter(heuristics_configs, filtered_configs,
4636 generic_filter_fn);
4637 } else {
4638 auto fallback = cudnn_frontend::EngineFallbackListBuilder()
4639 .setOperationGraph(*op_graph)
4640 .setOperation(GetCudnnConvolutionType(kind))
4641 .build();
4642 RETURN_MSG_IF_CUDNN_ERROR(fallback);
4643
4644 auto& fallback_configs = fallback.getFallbackList();
4645 VLOG(4) << "\nFallback engine configs size: " << fallback_configs.size();
4646
4647 cudnn_frontend::filter(fallback_configs, filtered_configs,
4648 generic_filter_fn);
4649 }
4650 VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
4651
4652 auto fn = []() { return true; };
4653 auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
4654 auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
4655
4656 out_runners->clear();
4657 for (int i = 0; i < filtered_configs.size(); i++) {
4658 auto plan = cudnn_frontend::ExecutionPlanBuilder()
4659 .setHandle(cudnn.handle())
4660 .setEngineConfig(filtered_configs[i], op_graph->getTag())
4661 .build();
4662 if (plan.get_status() != CUDNN_STATUS_SUCCESS) {
4663 continue;
4664 }
4665
4666 if (maybe_json_handle_static &&
4667 cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
4668 cudnn.handle(), fn)) {
4669 VLOG(4) << "Exclude engine (static): " << plan.getTag();
4670 continue;
4671 }
4672 if (maybe_json_handle_runtime &&
4673 cudnn_frontend::check_errata(*maybe_json_handle_runtime, plan.getTag(),
4674 cudnn.handle(), fn)) {
4675 VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
4676 continue;
4677 }
4678
4679 auto runner_or = CudnnExecutionPlanRunner<Sig>::Create(
4680 gpu_executor, cudnn_access, std::move(plan), input_uids,
4681 need_side_input);
4682 if (!runner_or.ok()) {
4683 // Note this can happen if cuDNN Frontend gives us partially-initialized
4684 // ExecutionPlans because its error handling is broken in non-exception
4685 // builds; those were meant to be filtered out earlier inside cuDNN
4686 // Frontend, but instead they get filtered out here.
4687 VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed "
4688 "getting its workspace size): "
4689 << runner_or.status().ToString();
4690 continue;
4691 }
4692
4693 out_runners->push_back(std::make_unique<CudnnExecutionPlanRunner<Sig>>(
4694 std::move(runner_or).value()));
4695
4696 // We will use the first working plan when determinism is required.
4697 if (RequireCudnnDeterminism()) {
4698 break;
4699 }
4700 }
4701
4702 VLOG(4) << "\nReturned execution plans size: " << out_runners->size();
4703
4704 return port::Status::OK();
4705 }
4706
4707 } // namespace
4708 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4709
4710 port::Status CudnnSupport::GetConvolveRunners(
4711 bool use_cudnn_frontend, dnn::ConvolutionKind kind,
4712 dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
4713 const dnn::BatchDescriptor& input_descriptor,
4714 DeviceMemoryBase /*input_data*/,
4715 const dnn::FilterDescriptor& filter_descriptor,
4716 DeviceMemoryBase /*filter_data*/,
4717 const dnn::BatchDescriptor& output_descriptor,
4718 DeviceMemoryBase /*output_data*/,
4719 const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback,
4720 ScratchAllocator* /*scratch_allocator*/,
4721 std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans) {
4722 // All current versions of the frontend API lack support for Tx32
4723 // convolutions.
4724 const bool is_unsupported_x32 =
4725 input_descriptor.layout() == dnn::kBatchDepthYX32;
4726
4727 // cuDNN frontend support became sufficiently stable to use in 8.1.
4728 // TODO(awpr): remove this condition once support for cuDNN 8.0 is dropped.
4729 const bool is_pre_frontend_cudnn = CUDNN_VERSION < 8100;
4730
4731 const bool actually_use_cudnn_frontend =
4732 use_cudnn_frontend && !is_pre_frontend_cudnn && !is_unsupported_x32;
4733
4734 if (use_cudnn_frontend && !actually_use_cudnn_frontend) {
4735 // This will happen once per unique conv configuration/shape that gets
4736 // affected (and not, for example, on every conv launch). Confusion over
4737 // whether this has happened or not has repeatedly wasted a lot of time
4738 // debugging, so be sure it shows up in the logs.
4739 LOG(INFO) << "Disabling cuDNN frontend for the following convolution:\n"
4740 << " input: " << input_descriptor.ToString() << "\n"
4741 << " filter: " << filter_descriptor.ToString() << "\n"
4742 << " " << convolution_descriptor.ToString() << "\n"
4743 << " ... because "
4744 << (is_unsupported_x32
4745 ? "Tx32 convolutions are unsupported."
4746 : "the current cuDNN version does not support it.");
4747 }
4748
4749 if (!actually_use_cudnn_frontend) {
4750 auto cuda_compute_capability = stream->GetCudaComputeCapability();
4751 std::vector<dnn::AlgorithmDesc> algorithms;
4752 bool got_algos = false;
4753 switch (kind) {
4754 default:
4755 return port::InternalError(absl::StrFormat(
4756 "Unknown ConvolutionKind for unfused conv: %d", kind));
4757 case dnn::ConvolutionKind::FORWARD:
4758 got_algos = GetConvolveAlgorithms(cuda_compute_capability, &algorithms);
4759 break;
4760 case dnn::ConvolutionKind::BACKWARD_FILTER:
4761 got_algos = GetConvolveBackwardFilterAlgorithms(cuda_compute_capability,
4762 &algorithms);
4763 break;
4764 case dnn::ConvolutionKind::BACKWARD_DATA:
4765 got_algos = GetConvolveBackwardDataAlgorithms(cuda_compute_capability,
4766 &algorithms);
4767 break;
4768 }
4769 if (!got_algos) {
4770 return port::Status(
4771 port::error::UNKNOWN,
4772 absl::StrFormat("Listing algorithms failed for kind %d", kind));
4773 }
4774
4775 for (const auto& algo : algorithms) {
4776 auto runner_or = ConvolveRunnerFromDesc(
4777 stream, algo, kind, input_type, output_type, input_descriptor,
4778 filter_descriptor, output_descriptor, convolution_descriptor);
4779 if (!runner_or.ok()) {
4780 // Failures here can result from trying to query the workspace size for
4781 // algorithms that aren't supported for the present configuration. This
4782 // means we'll now return only supported algorithms, unlike the
4783 // predecessor 'GetConvolveAlgorithms', which returned all existing
4784 // algorithms regardless of any particular configuration.
4785 //
4786 // TODO(awpr): can we arrange for the expected errors here to have a
4787 // particular error code (e.g. UNIMPLEMENTED or INVALID_ARGUMENT) and
4788 // log errors for anything unexpected?
4789 continue;
4790 }
4791 out_exec_plans->push_back(std::move(runner_or).value());
4792 }
4793
4794 return ::tensorflow::OkStatus();
4795 }
4796
4797 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4798 auto cudnn = cudnn_->GetHandle(parent_, stream);
4799 TF_ASSIGN_OR_RETURN(
4800 auto op_graph,
4801 GetCudnnOperationGraph(kind, input_type, output_type, input_descriptor,
4802 filter_descriptor, output_descriptor,
4803 convolution_descriptor, cudnn));
4804
4805 return CreateOpRunners<dnn::ConvSignature>(
4806 stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind,
4807 input_type, {'x', 'w', 'y'}, use_fallback, out_exec_plans,
4808 /*need_side_input=*/false);
4809 #else
4810 return port::UnimplementedError(
4811 "Cudnn execution plans are only supported with Cudnn >= 8.1.");
4812 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4813 }
4814
4815 port::StatusOr<std::unique_ptr<const dnn::ConvRunner>>
4816 CudnnSupport::ConvolveRunnerFromDesc(
4817 Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
4818 dnn::ConvolutionKind kind, dnn::DataType input_type,
4819 dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
4820 const dnn::FilterDescriptor& filter_descriptor,
4821 const dnn::BatchDescriptor& output_descriptor,
4822 const dnn::ConvolutionDescriptor& convolution_descriptor) {
4823 if (!algorithm_desc.is_cudnn_frontend()) {
4824 CudnnConvolutionDescriptor conv(
4825 convolution_descriptor,
4826 ToCudnnDataType(GetConvAccumulatorType(input_type)));
4827 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
4828
4829 TF_ASSIGN_OR_RETURN(
4830 auto runner,
4831 CudnnLegacyConvRunner::Create(
4832 parent_, stream, cudnn_.get(), algorithm_desc, input_type,
4833 output_type, kind,
4834 /* input_nd = */
4835 CudnnTensorDescriptor(
4836 input_descriptor,
4837 ToCudnnDataType(input_type, input_descriptor.layout())),
4838 /* output_nd = */
4839 CudnnTensorDescriptor(
4840 output_descriptor,
4841 ToCudnnDataType(output_type, output_descriptor.layout())),
4842 /* filter = */
4843 CudnnFilterDescriptor(
4844 filter_descriptor,
4845 ToCudnnDataType(input_type, filter_descriptor.layout())),
4846 std::move(conv)));
4847
4848 return {std::make_unique<CudnnLegacyConvRunner>(std::move(runner))};
4849 }
4850
4851 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
4852 auto cudnn = cudnn_->GetHandle(parent_, stream);
4853
4854 TF_ASSIGN_OR_RETURN(
4855 auto op_graph,
4856 GetCudnnOperationGraph(kind, input_type, output_type, input_descriptor,
4857 filter_descriptor, output_descriptor,
4858 convolution_descriptor, cudnn));
4859
4860 TF_ASSIGN_OR_RETURN(auto execution_plan,
4861 RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph));
4862
4863 TF_ASSIGN_OR_RETURN(
4864 auto runner,
4865 CudnnExecutionPlanRunner<dnn::ConvSignature>::Create(
4866 parent_, cudnn_.get(), std::move(execution_plan), {'x', 'w', 'y'},
4867 /*need_side_input=*/false));
4868 return {std::make_unique<CudnnExecutionPlanRunner<dnn::ConvSignature>>(
4869 std::move(runner))};
4870 #else
4871 return port::UnimplementedError(
4872 "Cudnn execution plans are only supported with Cudnn >= 8.1.");
4873 #endif
4874 }
4875
4876 class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner {
4877 public:
4878 // Queries the workspace size and constructs a 'CudnnLegacyFusedConvRunner'.
4879 static port::StatusOr<CudnnLegacyFusedConvRunner> Create(
4880 GpuExecutor* parent, Stream* stream, CudnnAccess* cudnn,
4881 const dnn::AlgorithmDesc& algo, dnn::DataType input_type,
4882 double conv_scale, double side_input_scale,
4883 CudnnTensorDescriptor input_nd, CudnnTensorDescriptor output_nd,
4884 CudnnFilterDescriptor filter, CudnnTensorDescriptor bias_nd,
4885 CudnnConvolutionDescriptor conv,
4886 CudnnActivationDescriptor activation_desc) {
4887 size_t workspace_size;
4888 if (algo.workspace_size()) {
4889 workspace_size = *algo.workspace_size();
4890 } else {
4891 auto handle = cudnn->GetHandle(parent, stream);
4892
4893 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
4894 handle.handle(),
4895 /*xDesc=*/input_nd.handle(),
4896 /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
4897 /*yDesc=*/output_nd.handle(),
4898 /*algo=*/ToConvForwardAlgo(algo),
4899 /*sizeInBytes=*/&workspace_size));
4900 }
4901
4902 return {{parent, cudnn, algo.algo_id(), algo.tensor_ops_enabled(),
4903 workspace_size, input_type, conv_scale, side_input_scale,
4904 std::move(input_nd), std::move(output_nd), std::move(filter),
4905 std::move(bias_nd), std::move(conv), std::move(activation_desc)}};
4906 }
4907
4908 std::string ToString() const override {
4909 return MakeAlgorithmDesc().ToString();
4910 }
4911
4912 uint64_t GetWorkspaceSize() const override { return workspace_size_; }
4913
4914 port::StatusOr<dnn::AlgorithmDesc> ToAlgorithmDesc() const override {
4915 return MakeAlgorithmDesc();
4916 }
4917
4918 port::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
4919 DeviceMemoryBase scratch_memory,
4920 DeviceMemoryBase input_data,
4921 DeviceMemoryBase filter_data,
4922 DeviceMemoryBase side_input_data,
4923 DeviceMemoryBase bias_data,
4924 DeviceMemoryBase output_data) const override {
4925 if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
4926 stream->parent()->implementation()) {
4927 return port::InternalError(
4928 "CudnnLegacyFusedConvRunner cached across multiple StreamExecutors.");
4929 }
4930
4931 auto algo = MakeAlgorithmDesc();
4932
4933 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
4934 if (profile_result) {
4935 timer.reset(new GpuTimer(parent_)); // NOLINT
4936 // The start and stop of the timer should be as close to the cuDNN call as
4937 // possible. Other threads can still issue work onto this stream, so the
4938 // measurement may include their work; multiple profiling runs may be needed.
4939 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
4940 return port::Status(port::error::INTERNAL, "Failed to start timer");
4941 }
4942 }
4943 auto side_input_data_ptr = (side_input_scale_ == 0)
4944 ? output_data.opaque()
4945 : side_input_data.opaque();
4946
4947 auto cudnn = cudnn_->GetHandle(parent_, stream);
4948
4949 VLOG(2) << "\nconv_scale = " << conv_scale_
4950 << "\nconv_input_nd.handle() = " << input_nd_.handle()
4951 << "\nconv_input_data.opaque() = " << input_data.opaque()
4952 << "\nfilter.handle() = " << filter_.handle()
4953 << "\nfilter_data.opaque() = " << filter_data.opaque()
4954 << "\nconv.handle() = " << conv_.handle() << "\nalgo = " << algo_id_
4955 << ", tensor_ops_enabled=" << tensor_ops_enabled_
4956 << "\nscratch.opaque() = " << scratch_memory.opaque()
4957 << "\nscratch.size() = " << scratch_memory.size()
4958 << "\nside_input_scale = " << side_input_scale_
4959 << "\noutput_nd.handle() = " << output_nd_.handle()
4960 << "\nside_input_data_ptr = " << side_input_data_ptr
4961 << "\nbias_nd.handle() = " << bias_nd_.handle()
4962 << "\nbiases.opaque() = " << bias_data.opaque()
4963 << "\nactivation_desc.handle() = " << activation_desc_.handle()
4964 << "\noutput_nd.handle() = " << output_nd_.handle()
4965 << "\noutput_data.opaque() = " << output_data.opaque();
4966
4967 if (IsTensorMathOpSet(conv_) != tensor_ops_enabled_) {
4968 return port::Status(port::error::FAILED_PRECONDITION,
4969 "Tensor op math type in dnn::AlgorithmDesc does not "
4970 "match that of the CudnnConvolutionDescriptor");
4971 }
4972
4973 // N.B. the scaling parameters alpha1 and alpha2 are pointers to
4974 // temporaries; this API doesn't persist the pointers beyond its own stack
4975 // frame.
4976 auto status = cudnnConvolutionBiasActivationForward(
4977 cudnn.handle(),
4978 /*alpha1=*/ScalingParam(conv_scale_).ToVoidPointer(input_type_),
4979 /*xDesc=*/input_nd_.handle(), /*x=*/input_data.opaque(),
4980 /*wDesc=*/filter_.handle(), /*w=*/filter_data.opaque(),
4981 /*convDesc=*/conv_.handle(), ToConvForwardAlgo(algo),
4982 /*workSpace=*/scratch_memory.opaque(),
4983 /*workSpaceSizeInBytes=*/scratch_memory.size(),
4984 /*alpha2=*/ScalingParam(side_input_scale_).ToVoidPointer(input_type_),
4985 /*zDesc=*/output_nd_.handle(), /*z=*/side_input_data_ptr,
4986 /*biasDesc=*/bias_nd_.handle(), /*bias=*/bias_data.opaque(),
4987 /*activationDesc=*/activation_desc_.handle(),
4988 /*yDesc=*/output_nd_.handle(), /*y=*/output_data.opaque());
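    // Log the outcome here on failure or when not profiling; the profiled
    // success case is logged below together with the elapsed time.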
4989 if (status != CUDNN_STATUS_SUCCESS || !profile_result) {
4990 VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo)
4991 << ", workspace_size=" << scratch_memory.size() << " -> "
4992 << CudnnStatusToString(status);
4993 }
4994 RETURN_IF_CUDNN_ERROR(status);
4995
4996 if (profile_result) {
4997 if (!timer->Stop(AsGpuStream(stream))) {
4998 return port::Status(port::error::INTERNAL, "Failed to stop timer");
4999 }
5000 profile_result->set_algorithm(algo);
5001 profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
5002 profile_result->set_scratch_size(scratch_memory.size());
5003 VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo)
5004 << ", tensor_ops_enabled=" << tensor_ops_enabled_
5005 << ", workspace_size=" << scratch_memory.size() << " -> "
5006 << CudnnStatusToString(status) << " in "
5007 << timer->GetElapsedMilliseconds() << "ms";
5008 }
5009
5010 return ::tensorflow::OkStatus();
5011 }
5012
5013 private:
5014 // Private to prevent passing in the wrong workspace_size.
5015   CudnnLegacyFusedConvRunner(GpuExecutor* parent, CudnnAccess* cudnn,
5016 int64_t algo_id, bool tensor_ops_enabled,
5017 size_t workspace_size, dnn::DataType input_type,
5018 double conv_scale, double side_input_scale,
5019 CudnnTensorDescriptor input_nd,
5020 CudnnTensorDescriptor output_nd,
5021 CudnnFilterDescriptor filter,
5022 CudnnTensorDescriptor bias_nd,
5023 CudnnConvolutionDescriptor conv,
5024 CudnnActivationDescriptor activation_desc)
5025 : parent_(parent),
5026 cudnn_(cudnn),
5027 algo_id_(algo_id),
5028 tensor_ops_enabled_(tensor_ops_enabled),
5029 workspace_size_(workspace_size),
5030 input_type_(input_type),
5031 conv_scale_(conv_scale),
5032 side_input_scale_(side_input_scale),
5033 input_nd_(std::move(input_nd)),
5034 output_nd_(std::move(output_nd)),
5035 filter_(std::move(filter)),
5036 bias_nd_(std::move(bias_nd)),
5037 conv_(std::move(conv)),
5038 activation_desc_(std::move(activation_desc)) {}
5039
5040 // Internal form of ToAlgorithmDesc without the StatusOr.
5041   dnn::AlgorithmDesc MakeAlgorithmDesc() const {
5042 return {algo_id_, tensor_ops_enabled_, workspace_size_};
5043 }
5044
5045 GpuExecutor* parent_;
5046 CudnnAccess* cudnn_;
5047 int64_t algo_id_;
5048 bool tensor_ops_enabled_;
5049 size_t workspace_size_;
5050 dnn::DataType input_type_;
5051 double conv_scale_, side_input_scale_;
5052
5053 CudnnTensorDescriptor input_nd_;
5054 CudnnTensorDescriptor output_nd_;
5055 CudnnFilterDescriptor filter_;
5056 CudnnTensorDescriptor bias_nd_;
5057 CudnnConvolutionDescriptor conv_;
5058 CudnnActivationDescriptor activation_desc_;
5059 };
5060
5061 port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
5062 CudnnSupport::FusedConvolveRunnerFromDesc(
5063 Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
5064 dnn::ConvolutionKind kind, dnn::DataType input_type,
5065 dnn::DataType bias_type, dnn::DataType output_type, double conv_scale,
5066 double side_input_scale, double leakyrelu_alpha,
5067 const dnn::BatchDescriptor& input_descriptor,
5068 const dnn::FilterDescriptor& filter_descriptor,
5069 const dnn::BatchDescriptor& bias_descriptor,
5070 const dnn::BatchDescriptor& output_descriptor,
5071 const dnn::ConvolutionDescriptor& convolution_descriptor,
5072 dnn::ActivationMode activation_mode) {
5073 if (!algorithm_desc.is_cudnn_frontend()) {
5074 CudnnTensorDescriptor conv_input_nd(
5075 input_descriptor,
5076 ToCudnnDataType(input_type, input_descriptor.layout()));
5077 CudnnTensorDescriptor output_nd(
5078 output_descriptor,
5079 ToCudnnDataType(output_type, input_descriptor.layout()));
5080 CudnnFilterDescriptor filter(
5081 filter_descriptor,
5082 ToCudnnDataType(input_type, filter_descriptor.layout()));
5083 CudnnTensorDescriptor bias_nd(bias_descriptor, ToCudnnDataType(bias_type));
5084
5085 CudnnConvolutionDescriptor conv(
5086 convolution_descriptor,
5087 ToCudnnDataType(GetConvAccumulatorType(input_type)));
5088 conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
5089
5090     // cuDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
5091     // activation descriptor. Note that this changes the NaN propagation
5092     // behavior relative to running separate conv, bias, and relu ops (which
5093     // default to CUDNN_PROPAGATE_NAN).
5094 //
5095 // TODO(awpr): reevaluate this for newer cuDNN versions.
5096 CudnnActivationDescriptor activation_desc(activation_mode,
5097 CUDNN_NOT_PROPAGATE_NAN,
5098 output_descriptor.value_max());
5099
5100 TF_ASSIGN_OR_RETURN(
5101 auto runner,
5102 CudnnLegacyFusedConvRunner::Create(
5103 parent_, stream, cudnn_.get(), algorithm_desc, input_type,
5104 conv_scale, side_input_scale, std::move(conv_input_nd),
5105 std::move(output_nd), std::move(filter), std::move(bias_nd),
5106 std::move(conv), std::move(activation_desc)));
5107 return {std::make_unique<CudnnLegacyFusedConvRunner>(std::move(runner))};
5108 }
5109
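  // cuDNN frontend path: rebuild the fused operation graph for this algorithm,
  // turn it into an execution plan, and wrap it in a runner. The uids
  // 'x', 'w', 'z', 'b', 'y' correspond to the input, filter, side input, bias,
  // and output operands passed to the runner at launch time.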
5110 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5111 auto cudnn = cudnn_->GetHandle(parent_, stream);
5112
5113 TF_ASSIGN_OR_RETURN(auto op_graph,
5114 GetCudnnFusedOperationGraph(
5115 kind, input_type, bias_type, output_type, conv_scale,
5116 side_input_scale, leakyrelu_alpha, input_descriptor,
5117 filter_descriptor, bias_descriptor, output_descriptor,
5118 convolution_descriptor, activation_mode, cudnn));
5119
5120 TF_ASSIGN_OR_RETURN(auto execution_plan,
5121 RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph));
5122
5123 bool need_side_input =
5124 SideInputNeeded(activation_mode, conv_scale, side_input_scale);
5125 TF_ASSIGN_OR_RETURN(auto runner,
5126 CudnnExecutionPlanRunner<dnn::FusedConvSignature>::Create(
5127 parent_, cudnn_.get(), std::move(execution_plan),
5128 {'x', 'w', 'z', 'b', 'y'}, need_side_input));
5129 return {std::make_unique<CudnnExecutionPlanRunner<dnn::FusedConvSignature>>(
5130 std::move(runner))};
5131 #else
5132 return port::UnimplementedError(
5133 "Cudnn execution plans are only supported with Cudnn >= 8.1.");
5134 #endif
5135 }
5136
5137 port::Status CudnnSupport::GetFusedConvolveRunners(
5138 bool use_cudnn_frontend, dnn::ConvolutionKind kind,
5139 dnn::DataType input_type, dnn::DataType bias_type,
5140 dnn::DataType output_type, double conv_scale, double side_input_scale,
5141 double leakyrelu_alpha, Stream* stream,
5142 const dnn::BatchDescriptor& input_descriptor,
5143 const dnn::FilterDescriptor& filter_descriptor,
5144 const dnn::BatchDescriptor& bias_descriptor,
5145 const dnn::BatchDescriptor& output_descriptor,
5146 const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback,
5147 const dnn::ActivationMode activation_mode,
5148 std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans) {
5149 // Fused convolutions with identity activations are broken in that they
5150 // implicitly do ReLU on some engines, and we can't reliably detect which
5151 // ones.
5152 const bool is_broken_identity_fused_conv =
5153 #if CUDNN_VERSION < 8205
5154 activation_mode == dnn::ActivationMode::kNone;
5155 #else
5156 false;
5157 #endif
5158
5159 // All current versions of the frontend API lack support for Tx32
5160 // convolutions.
5161 const bool is_unsupported_x32 =
5162 input_descriptor.layout() == dnn::kBatchDepthYX32;
5163
5164 // cuDNN frontend support became sufficiently stable to use in 8.1.
5165 // TODO(awpr): remove this condition once support for cuDNN 8.0 is dropped.
5166 const bool is_pre_frontend_cudnn = CUDNN_VERSION < 8100;
5167
5168 const bool actually_use_cudnn_frontend =
5169 use_cudnn_frontend && !is_pre_frontend_cudnn &&
5170 !is_broken_identity_fused_conv && !is_unsupported_x32;
5171
5172 if (use_cudnn_frontend && !actually_use_cudnn_frontend) {
5173 const char* reason = "the current cuDNN version does not support it.";
5174 if (is_unsupported_x32) {
5175 reason = "Tx32 convolutions are unsupported.";
5176 } else if (is_broken_identity_fused_conv) {
5177 reason = "it uses an identity activation.";
5178 }
5179
5180 // This will happen once per unique conv configuration/shape that gets
5181 // affected (and not, for example, on every conv launch). Confusion over
5182 // whether this has happened or not has repeatedly wasted a lot of time
5183 // debugging, so be sure it shows up in the logs.
5184 LOG(INFO) << "Disabling cuDNN frontend for the following convolution:\n"
5185 << " input: " << input_descriptor.ToString() << "\n"
5186 << " filter: " << filter_descriptor.ToString() << "\n"
5187 << " " << convolution_descriptor.ToString() << "\n"
5188 << " ... because " << reason;
5189 }
5190
5191 if (input_type == dnn::DataType::kInt8 &&
5192 !stream->GetCudaComputeCapability().IsAtLeast(6, 1)) {
5193 return port::UnimplementedError(
5194 "cudnnConvolutionBiasActivationForward() for int8 is only supported "
5195 "on GPUs with compute capability 6.1 or later.");
5196 }
5197
5198 if (input_type == dnn::DataType::kInt8 &&
5199 output_type == dnn::DataType::kFloat &&
5200 (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8200)) {
5201 return port::UnimplementedError(
5202 "int8 -> float fused conv is disabled for this cuDNN version. See "
5203 "go/nvbugs/3326122");
5204 }
5205
5206 if (activation_mode != dnn::ActivationMode::kRelu &&
5207 activation_mode != dnn::ActivationMode::kRelu6 &&
5208 activation_mode != dnn::ActivationMode::kElu &&
5209 activation_mode != dnn::ActivationMode::kLeakyRelu &&
5210 activation_mode != dnn::ActivationMode::kNone) {
5211 return port::Status(port::error::INVALID_ARGUMENT,
5212 "CuDNN fusion only supports activations of "
5213                         "{Relu, Relu6, Elu, LeakyRelu, <None>}.");
5214 }
5215
5216 if (!actually_use_cudnn_frontend) {
5217 std::vector<dnn::AlgorithmDesc> algorithms;
5218
5219 auto cuda_compute_capability = stream->GetCudaComputeCapability();
5220 if (!GetConvolveAlgorithms(cuda_compute_capability, &algorithms)) {
5221 return port::Status(port::error::UNKNOWN,
5222 "Listing fused convolve algorithms failed.");
5223 }
5224
5225 for (const auto& algo : algorithms) {
5226 // Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM is supported
5227       // for identity activation; other algorithms seem to quietly apply ReLU. See
5228 // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
5229 if (activation_mode == dnn::ActivationMode::kNone &&
5230 algo.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
5231 continue;
5232 }
5233 auto runner_or = FusedConvolveRunnerFromDesc(
5234 stream, algo, kind, input_type, bias_type, output_type, conv_scale,
5235 side_input_scale, leakyrelu_alpha, input_descriptor,
5236 filter_descriptor, bias_descriptor, output_descriptor,
5237 convolution_descriptor, activation_mode);
5238 if (!runner_or.ok()) {
5239 // See the corresponding error handling in
5240 // CudnnSupport::GetConvolveRunners: this filters out algorithms that
5241 // don't support this conv.
5242 continue;
5243 }
5244 out_exec_plans->push_back(std::move(runner_or).value());
5245 }
5246 return ::tensorflow::OkStatus();
5247 }
5248
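  // cuDNN frontend path: build the fused operation graph once and let
  // CreateOpRunners enumerate matching engine configurations (including
  // fallback engines when requested) into out_exec_plans.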
5249 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5250 auto cudnn = cudnn_->GetHandle(parent_, stream);
5251 auto op_graph_status = GetCudnnFusedOperationGraph(
5252 kind, input_type, bias_type, output_type, conv_scale, side_input_scale,
5253 leakyrelu_alpha, input_descriptor, filter_descriptor, bias_descriptor,
5254 output_descriptor, convolution_descriptor, activation_mode, cudnn);
5255 if (!op_graph_status.status().ok()) {
5256 return port::Status(port::error::INTERNAL,
5257 absl::StrCat("Cudnn graph failed to build: ",
5258 op_graph_status.status().ToString()));
5259 }
5260 auto op_graph = std::move(op_graph_status).value();
5261
5262 bool need_side_input =
5263 SideInputNeeded(activation_mode, conv_scale, side_input_scale);
5264 return CreateOpRunners<dnn::FusedConvSignature>(
5265 stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind,
5266 input_type, {'x', 'w', 'z', 'b', 'y'}, use_fallback, out_exec_plans,
5267 need_side_input);
5268 #else
5269 return port::UnimplementedError(
5270 "Cudnn execution plans are only supported with Cudnn >= 8.1.");
5271 #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
5272 }
5273
5274 bool CudnnSupport::GetConvolveAlgorithms(
5275 CudaComputeCapability cuda_compute_capability,
5276 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
5277 PreloadCudnnSubLibs(PreloadCudnnType::ConvFwd);
5278
5279 bool tensor_op_math_available =
5280 TensorOpMathAvailable(cuda_compute_capability);
5281 out_algorithms->clear();
5282
5283 std::vector<dnn::AlgorithmDesc::Index> algo_types;
5284 if (ConvUseDefaultAlgorithm()) {
5285 // Force a fallback algorithm.
5286 algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM};
5287 } else {
5288 algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
5289 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
5290 CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
5291 CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
5292 CUDNN_CONVOLUTION_FWD_ALGO_FFT,
5293 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD};
5294 if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
5295 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
5296 }
5297 if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
5298 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
5299 }
5300 }
5301
5302 // The algorithms are intentionally ordered for deterministic operation
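  // Each algorithm index is emitted twice -- with and without tensor-op
  // (Tensor Core) math -- when the GPU supports it, so callers can try both
  // variants.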
5303 for (auto i : algo_types) {
5304 if (tensor_op_math_available) {
5305 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
5306 }
5307 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
5308 }
5309
5310 return true;
5311 }
5312
5313 bool CudnnSupport::GetRnnAlgorithms(
5314 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
5315 PreloadCudnnSubLibs(PreloadCudnnType::Rnn);
5316
5317 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
5318 // clang-format off
5319 CUDNN_RNN_ALGO_STANDARD,
5320 CUDNN_RNN_ALGO_PERSIST_STATIC,
5321 CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
5322 // clang-format on
5323 };
5324
5325 out_algorithms->clear();
5326 for (auto i : algo_types) {
5327 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
5328 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
5329 }
5330 return true;
5331 }
5332
5333 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
5334 CudaComputeCapability cuda_compute_capability,
5335 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
5336 PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdData);
5337
5338 bool tensor_op_math_available =
5339 TensorOpMathAvailable(cuda_compute_capability);
5340 out_algorithms->clear();
5341
5342 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
5343 // clang-format off
5344 CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
5345 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
5346 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
5347 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
5348 // clang-format on
5349 };
5350 if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
5351 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
5352 }
5353 if (!RequireCudnnDeterminism()) {
5354 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
5355 }
5356
5357 // The algorithms are intentionally ordered for deterministic operation
5358 for (auto i : algo_types) {
5359 if (tensor_op_math_available) {
5360 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
5361 }
5362 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
5363 }
5364
5365 return true;
5366 }
5367
5368 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
5369 CudaComputeCapability cuda_compute_capability,
5370 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
5371 PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdFilter);
5372
5373 bool tensor_op_math_available =
5374 TensorOpMathAvailable(cuda_compute_capability);
5375 out_algorithms->clear();
5376
5377 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
5378 // clang-format off
5379 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
5380 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
5381 // Based on cudnn.h, the following is not implemented.
5382 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
5383
5384 // Produces incorrect results for some shapes. Disabled for now, see
5385 // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
5386 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
5387 // clang-format on
5388 };
5389 if (CudnnEnvVar<WinogradNonfused>::IsEnabled()) {
5390 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
5391 }
5392 if (!RequireCudnnDeterminism()) {
5393 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
5394 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
5395 }
5396
5397 // The algorithms are intentionally ordered for deterministic operation
5398 for (auto i : algo_types) {
5399 if (tensor_op_math_available) {
5400 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
5401 }
5402 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
5403 }
5404
5405 return true;
5406 }
5407
5408 bool CudnnSupport::DoBatchNormalizationForward(
5409 Stream* stream, const DeviceMemory<float>& x,
5410 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
5411 const DeviceMemory<float>& estimated_mean,
5412 const DeviceMemory<float>& estimated_variance,
5413 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
5414 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5415 const double exponential_average_factor,
5416 dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
5417 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
5418 DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
5419 bool is_training, ScratchAllocator* reserve_space_allocator,
5420 ScratchAllocator* workspace_allocator) {
5421 return IsStatusOk(
5422 DoBatchNormalizationForwardImpl<float, float>(
5423 stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
5424 offset, estimated_mean, estimated_variance, side_input, x_desc,
5425 scale_offset_desc, epsilon, exponential_average_factor,
5426 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
5427 is_training, reserve_space_allocator, workspace_allocator),
5428 /*report_error=*/true);
5429 }
5430
5431 bool CudnnSupport::DoBatchNormalizationForward(
5432 Stream* stream, const DeviceMemory<Eigen::half>& x,
5433 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
5434 const DeviceMemory<float>& estimated_mean,
5435 const DeviceMemory<float>& estimated_variance,
5436 const DeviceMemory<Eigen::half>& side_input,
5437 const dnn::BatchDescriptor& x_desc,
5438 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5439 const double exponential_average_factor,
5440 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
5441 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
5442 DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
5443 bool is_training, ScratchAllocator* reserve_space_allocator,
5444 ScratchAllocator* workspace_allocator) {
5445 return IsStatusOk(
5446 DoBatchNormalizationForwardImpl<Eigen::half, float>(
5447 stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
5448 estimated_mean, estimated_variance, side_input, x_desc,
5449 scale_offset_desc, epsilon, exponential_average_factor,
5450 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
5451 is_training, reserve_space_allocator, workspace_allocator),
5452 /*report_error=*/true);
5453 }
5454
5455 template <class T, class U>
5456 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
5457 Stream* stream, dnn::DataType input_data_type,
5458 dnn::DataType scale_data_type, const DeviceMemory<T>& x,
5459 const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
5460 const DeviceMemory<U>& estimated_mean,
5461 const DeviceMemory<U>& estimated_variance,
5462 const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
5463 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5464 const double exponential_average_factor,
5465 dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
5466 DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
5467 DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
5468 bool is_training, ScratchAllocator* reserve_space_allocator,
5469 ScratchAllocator* workspace_allocator) {
5470 CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
5471 CudnnTensorDescriptor scale_offset_descriptor(
5472 scale_offset_desc, ToCudnnDataType(scale_data_type));
5473 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
5474 if (BatchnormSpatialPersistentEnabled() && is_training) {
5475 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
5476 }
5477 float one = 1.0;
5478 float zero = 0.0;
5479 auto cudnn = cudnn_->GetHandle(parent_, stream);
5480
5481 DeviceMemory<uint8> workspace;
5482 DeviceMemory<uint8> reserve_space;
5483
5484 #if CUDNN_VERSION >= 7402
5485 const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t {
5486 if (side_input.is_null()) {
5487 return activation_mode == dnn::ActivationMode::kNone
5488 ? CUDNN_BATCHNORM_OPS_BN
5489 : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
5490 } else {
5491 return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
5492 }
5493 };
5494 const cudnnBatchNormOps_t bn_ops = get_bn_ops();
5495
5496   // We use NaN propagation to be consistent with CudnnSupport::DoActivate(...).
5497 CudnnActivationDescriptor activation_desc(
5498 activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
5499
5500 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
5501 TF_ASSIGN_OR_RETURN(
5502 workspace,
5503 CreateBatchNormForwardWorkspace(
5504 stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
5505 scale_offset_descriptor, workspace_allocator));
5506 if (is_training) {
5507 size_t reserve_space_size_in_bytes = 0;
5508 RETURN_IF_CUDNN_ERROR(
5509 cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
5510 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
5511 /*activationDesc=*/activation_desc.handle(),
5512 /*xDesc=*/x_descriptor.handle(),
5513 /*sizeInBytes=*/&reserve_space_size_in_bytes));
5514 TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
5515 reserve_space_size_in_bytes));
5516 }
5517 }
5518 #endif
5519
5520 auto check_no_side_input_or_activation = [&]() -> port::Status {
5521 if (activation_mode != dnn::ActivationMode::kNone ||
5522 !side_input.is_null()) {
5523 return port::Status(
5524 port::error::INTERNAL,
5525 absl::StrCat(
5526 "Side input and activation are not supported by cuDNN version: ",
5527 CUDNN_VERSION));
5528 } else {
5529 return ::tensorflow::OkStatus();
5530 }
5531 };
5532
5533 if (is_training) {
5534 CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
5535 << "batch_mean and batch_var must both be null or both be non-null";
5536
5537 void* batch_mean_opaque;
5538 void* batch_var_opaque;
5539 if (!batch_mean->is_null() && !batch_var->is_null()) {
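      // With exponential_average_factor == 1.0 cuDNN fully overwrites the
      // running statistics (running = (1 - factor) * old + factor * new), so
      // the buffers are zeroed first -- presumably to keep the
      // (1 - factor) * old term from turning uninitialized Inf/NaN contents
      // into NaNs.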
5540 if (exponential_average_factor == 1.0) {
5541 stream->ThenMemZero(batch_mean, batch_mean->size());
5542 stream->ThenMemZero(batch_var, batch_var->size());
5543 }
5544 batch_mean_opaque = batch_mean->opaque();
5545 batch_var_opaque = batch_var->opaque();
5546 } else {
5547 batch_mean_opaque = nullptr;
5548 batch_var_opaque = nullptr;
5549 }
5550
5551 bool called = false;
5552 #if CUDNN_VERSION >= 7402
5553 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
5554 called = true;
5555 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx(
5556 /*handle=*/cudnn.handle(),
5557 /*mode=*/mode,
5558 /*bnOps=*/bn_ops,
5559 /*alpha=*/&one,
5560 /*beta=*/&zero,
5561 /*xDesc=*/x_descriptor.handle(),
5562 /*xData=*/x.opaque(),
5563 /*zDesc=*/x_descriptor.handle(),
5564 /*zData=*/side_input.opaque(),
5565 /*yDesc=*/x_descriptor.handle(),
5566 /*yData=*/y->opaque(),
5567 /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
5568 /*bnScale=*/scale.opaque(),
5569 /*bnBias=*/offset.opaque(),
5570 /*exponentialAverageFactor=*/exponential_average_factor,
5571 /*resultRunningMean=*/batch_mean_opaque,
5572 /*resultRunningVariance=*/batch_var_opaque,
5573 /*epsilon=*/epsilon,
5574 /*resultSaveMean=*/saved_mean->opaque(),
5575 /*resultSaveInvVariance=*/saved_inv_var->opaque(),
5576 /*activationDesc=*/activation_desc.handle(),
5577 /*workspace=*/workspace.opaque(),
5578 /*workSpaceSizeInBytes=*/workspace.size(),
5579 /*reserveSpace=*/reserve_space.opaque(),
5580 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
5581 }
5582 #endif
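    // Fall back to the plain batch-norm training API when the Ex path above
    // was not compiled in or not used; it supports neither a side input nor a
    // fused activation, hence the check below.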
5583 if (!called) {
5584 TF_RETURN_IF_ERROR(check_no_side_input_or_activation());
5585 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
5586 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
5587 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
5588 scale.opaque(), offset.opaque(), exponential_average_factor,
5589 batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
5590 saved_inv_var->opaque()));
5591 }
5592 } else {
5593 const void* maybe_inv_var = estimated_variance.opaque();
5594 TF_RETURN_IF_ERROR(check_no_side_input_or_activation());
5595 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
5596 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
5597 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
5598 scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
5599 epsilon));
5600 }
5601 return ::tensorflow::OkStatus();
5602 }
5603
5604 bool CudnnSupport::DoBatchNormalizationBackward(
5605 Stream* stream, const DeviceMemory<float>& y_backprop,
5606 const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
5607 const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
5608 const DeviceMemory<float>& inv_var, const DeviceMemory<float>& y,
5609 const dnn::BatchDescriptor& x_desc,
5610 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5611 dnn::ActivationMode activation_mode, DeviceMemory<float>* x_backprop,
5612 DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
5613 DeviceMemory<float>* side_input_backprop,
5614 DeviceMemory<uint8>* reserve_space_data,
5615 ScratchAllocator* workspace_allocator) {
5616 return IsStatusOk(
5617 DoBatchNormalizationBackwardImpl(
5618 stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, x, scale,
5619 offset, mean, inv_var, y, x_desc, scale_offset_desc, epsilon,
5620 activation_mode, x_backprop, scale_backprop, offset_backprop,
5621 side_input_backprop, reserve_space_data, workspace_allocator),
5622 /*report_error=*/true);
5623 }
5624
5625 bool CudnnSupport::DoBatchNormalizationBackward(
5626 Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
5627 const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
5628 const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
5629 const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::half>& y,
5630 const dnn::BatchDescriptor& x_desc,
5631 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5632 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* x_backprop,
5633 DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
5634 DeviceMemory<Eigen::half>* side_input_backprop,
5635 DeviceMemory<uint8>* reserve_space_data,
5636 ScratchAllocator* workspace_allocator) {
5637 return IsStatusOk(
5638 DoBatchNormalizationBackwardImpl(
5639 stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, x, scale,
5640 offset, mean, inv_var, y, x_desc, scale_offset_desc, epsilon,
5641 activation_mode, x_backprop, scale_backprop, offset_backprop,
5642 side_input_backprop, reserve_space_data, workspace_allocator),
5643 /*report_error=*/true);
5644 }
5645
5646 template <class T, class U>
5647 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
5648 Stream* stream, int cudnn_input_type, int cudnn_scale_type,
5649 const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
5650 const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
5651 const DeviceMemory<U>& mean, const DeviceMemory<U>& inv_var,
5652 const DeviceMemory<T>& y, const dnn::BatchDescriptor& x_desc,
5653 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
5654 dnn::ActivationMode activation_mode, DeviceMemory<T>* x_backprop,
5655 DeviceMemory<U>* scale_backprop, DeviceMemory<U>* offset_backprop,
5656 DeviceMemory<T>* side_input_backprop,
5657 DeviceMemory<uint8>* reserve_space_data,
5658 ScratchAllocator* workspace_allocator) {
5659 CudnnTensorDescriptor x_descriptor(
5660 x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
5661 CudnnTensorDescriptor scale_offset_descriptor(
5662 scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
5663 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
5664 if (BatchnormSpatialPersistentEnabled()) {
5665 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
5666 }
5667 float one = 1.0;
5668 float zero = 0.0;
5669
5670 auto cudnn = cudnn_->GetHandle(parent_, stream);
5671
5672 bool called = false;
5673 #if CUDNN_VERSION >= 7402
5674 if (reserve_space_data != nullptr && workspace_allocator != nullptr) {
5675 called = true;
5676 const cudnnBatchNormOps_t bn_ops = [&]() {
5677 if (side_input_backprop->is_null()) {
5678 return activation_mode == dnn::ActivationMode::kNone
5679 ? CUDNN_BATCHNORM_OPS_BN
5680 : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
5681 } else {
5682 return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
5683 }
5684 }();
5685
5686     // We use NaN propagation to be consistent with
5687 // CudnnSupport::DoActivate(...).
5688 CudnnActivationDescriptor activation_desc(
5689 activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
5690
5691 TF_ASSIGN_OR_RETURN(
5692 DeviceMemory<uint8> workspace,
5693 CreateBatchNormBackwardWorkspace(
5694 stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
5695 scale_offset_descriptor, workspace_allocator));
5696 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx(
5697 /*handle=*/cudnn.handle(),
5698 /*mode=*/mode,
5699 /*bnOps=*/bn_ops,
5700 /*alphaDataDiff=*/&one,
5701 /*betaDataDiff=*/&zero,
5702 /*alphaParamDiff=*/&one,
5703 /*betaParamDiff=*/&zero,
5704 /*xDesc=*/x_descriptor.handle(),
5705 /*xData=*/x.opaque(),
5706 /*yDesc=*/x_descriptor.handle(),
5707 /*yData=*/y.opaque(),
5708 /*dyDesc=*/x_descriptor.handle(),
5709 /*dyData=*/y_backprop.opaque(),
5710 /*dzDesc=*/x_descriptor.handle(),
5711 /*dzData=*/side_input_backprop->opaque(),
5712 /*dxDesc=*/x_descriptor.handle(),
5713 /*dxData=*/x_backprop->opaque(),
5714 /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
5715 /*bnScaleData=*/scale.opaque(),
5716 /*bnBiasData=*/offset.opaque(),
5717 /*dBnScaleData=*/scale_backprop->opaque(),
5718 /*dBnBiasData=*/offset_backprop->opaque(),
5719 /*epsilon=*/epsilon,
5720 /*savedMean=*/mean.opaque(),
5721 /*savedInvVariance=*/inv_var.opaque(),
5722 /*activationDesc=*/activation_desc.handle(),
5723 /*workspace=*/workspace.opaque(),
5724 /*workSpaceSizeInBytes=*/workspace.size(),
5725 /*reserveSpace=*/reserve_space_data->opaque(),
5726 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
5727 }
5728 #endif
5729 auto check_no_side_input_or_activation = [&]() -> port::Status {
5730 if (activation_mode != dnn::ActivationMode::kNone ||
5731 !side_input_backprop->is_null()) {
5732 return port::InternalError(absl::StrCat(
5733 "Side input and activation are not supported by cuDNN version: ",
5734 CUDNN_VERSION));
5735 } else {
5736 return ::tensorflow::OkStatus();
5737 }
5738 };
5739
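  // Fall back to the plain batch-norm backward API when the Ex path above was
  // not used; it likewise supports neither side-input gradients nor a fused
  // activation.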
5740 if (!called && check_no_side_input_or_activation().ok()) {
5741 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
5742 cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
5743 x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
5744 x_descriptor.handle(), x_backprop->opaque(),
5745 scale_offset_descriptor.handle(), scale.opaque(),
5746 scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
5747 mean.opaque(), inv_var.opaque()));
5748 }
5749
5750 return ::tensorflow::OkStatus();
5751 }
5752
5753 port::Status CudnnSupport::DoFusedConvolve(
5754 Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
5755 dnn::DataType bias_type, dnn::DataType output_type,
5756 const dnn::BatchDescriptor& conv_input_descriptor,
5757 DeviceMemoryBase conv_input_data, double conv_scale,
5758 const dnn::FilterDescriptor& filter_descriptor,
5759 DeviceMemoryBase filter_data,
5760 const dnn::ConvolutionDescriptor& convolution_descriptor,
5761 DeviceMemoryBase side_input_data, double side_input_scale,
5762 const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
5763 dnn::ActivationMode activation_mode,
5764 const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
5765 ScratchAllocator* scratch_allocator,
5766 const dnn::AlgorithmConfig& algorithm_config,
5767 dnn::ProfileResult* output_profile_result) {
5768 if (input_type == dnn::DataType::kInt8 &&
5769 !stream->GetCudaComputeCapability().IsAtLeast(6, 1)) {
5770 return port::UnimplementedError(
5771 "cudnnConvolutionBiasActivationForward() for int8 is only supported "
5772 "on GPUs with compute capability 6.1 or later.");
5773 }
5774
5775 if (input_type == dnn::DataType::kInt8 &&
5776 output_type == dnn::DataType::kFloat &&
5777 (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8200)) {
5778 return port::UnimplementedError(
5779 "int8 -> float fused conv is disabled for this cuDNN version. See "
5780 "go/nvbugs/3326122");
5781 }
5782
5783 if (activation_mode != dnn::ActivationMode::kRelu &&
5784 activation_mode != dnn::ActivationMode::kNone) {
5785 return port::Status(port::error::INVALID_ARGUMENT,
5786 "cudnnConvolutionBiasActivationForward() only supports "
5787 "Relu or None activation.");
5788 }
5789
5790 CudnnTensorDescriptor conv_input_nd(
5791 conv_input_descriptor,
5792 ToCudnnDataType(input_type, conv_input_descriptor.layout()));
5793 CudnnTensorDescriptor output_nd(
5794 output_descriptor,
5795 ToCudnnDataType(output_type, conv_input_descriptor.layout()));
5796 CudnnFilterDescriptor filter(
5797 filter_descriptor,
5798 ToCudnnDataType(input_type, filter_descriptor.layout()));
5799 CudnnTensorDescriptor bias_nd(bias_descriptor, ToCudnnDataType(bias_type));
5800
5801 DeviceMemory<uint8> scratch;
5802 dnn::AlgorithmDesc algo_desc;
5803 {
5804 auto cudnn = cudnn_->GetHandle(parent_, stream);
5805 TF_ASSIGN_OR_RETURN(
5806 algo_desc,
5807 GetCudnnConvolutionForwardAlgorithm(
5808 stream, cudnn, algorithm_config, conv_input_nd, filter, input_type,
5809 convolution_descriptor, output_nd, scratch_allocator, &scratch));
5810 } // Explicitly release cuDNN handle.
5811
5812 CudnnConvolutionDescriptor conv(
5813 convolution_descriptor,
5814 ToCudnnDataType(GetConvAccumulatorType(input_type)));
5815 conv.set_use_tensor_op_math(algo_desc.tensor_ops_enabled());
5816
5817   // cuDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
5818   // activation descriptor. Note that this changes the NaN propagation behavior
5819   // relative to running separate conv, bias, and relu ops (which default to
5820   // CUDNN_PROPAGATE_NAN).
5821 CudnnActivationDescriptor activation_desc(
5822 activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
5823
5824 TF_ASSIGN_OR_RETURN(
5825 auto runner,
5826 CudnnLegacyFusedConvRunner::Create(
5827 parent_, stream, cudnn_.get(), std::move(algo_desc), input_type,
5828 conv_scale, side_input_scale, std::move(conv_input_nd),
5829 std::move(output_nd), std::move(filter), std::move(bias_nd),
5830 std::move(conv), std::move(activation_desc)));
5831
5832 return runner(stream, output_profile_result, scratch, conv_input_data,
5833 filter_data, side_input_data, biases, output_data);
5834 }
5835
5836 port::Status CudnnSupport::DoPrepareForCtcLoss(
5837 Stream* stream, dnn::DataType element_type,
5838 const dnn::RnnStateTensorDescriptor& probs_desc,
5839 const dnn::RnnStateTensorDescriptor& grads_desc,
5840 absl::Span<const int> labels_data,
5841 absl::Span<const int> labels_lengths_data,
5842 absl::Span<const int> input_lengths_data,
5843 ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
5844 int* ctc_loss_algo_id) {
5845 auto cudnn = cudnn_->GetHandle(parent_, stream);
5846 // Query the workspace size.
5847 size_t workspace_size_in_bytes = 0;
5848 #if CUDNN_VERSION >= 7603
5849 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
5850 const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
5851 static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
5852 const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
5853 static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
5854
5855   // Query the workspace size with `algo` and pick it if the query succeeds.
5856   // The non-deterministic algorithm is tried first, so it is preferred when
5857   // determinism is not required.
5858 auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
5859 : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
5860 cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
5861 /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
5862 /*gradientsDesc=*/cudnn_grads_desc.handle(),
5863 /*labels=*/labels_data.data(),
5864 /*labelLengths=*/labels_lengths_data.data(),
5865 /*inputLengths=*/input_lengths_data.data(),
5866 /*algo=*/algo,
5867 /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
5868 /*sizeInBytes=*/&workspace_size_in_bytes);
5869 if (RequireCudnnDeterminism()) {
5870 RETURN_IF_CUDNN_ERROR(status);
5871 }
5872
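  // The preferred (non-deterministic) algorithm was rejected; retry the
  // workspace query with the deterministic algorithm.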
5873 if (status != CUDNN_STATUS_SUCCESS) {
5874 algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
5875 RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
5876 /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
5877 /*gradientsDesc=*/cudnn_grads_desc.handle(),
5878 /*labels=*/labels_data.data(),
5879 /*labelLengths=*/labels_lengths_data.data(),
5880 /*inputLengths=*/input_lengths_data.data(),
5881 /*algo=*/algo,
5882 /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
5883 /*sizeInBytes=*/&workspace_size_in_bytes));
5884 }
5885 *ctc_loss_algo_id = algo;
5886 #else
5887 return port::Status(port::error::INVALID_ARGUMENT,
5888                       "cudnnGetCTCLossWorkspaceSize is not supported when "
5889                       "CUDNN_VERSION < 7.6.3");
5890 #endif
5891 // Allocate the workspace.
5892 if (workspace_size_in_bytes == 0) {
5893 *scratch_memory = DeviceMemory<uint8>();
5894 return ::tensorflow::OkStatus();
5895 }
5896 const auto scratch_or =
5897 scratch_allocator->AllocateBytes(workspace_size_in_bytes);
5898 if (scratch_or.ok()) {
5899 *scratch_memory = scratch_or.ValueOrDie();
5900 return ::tensorflow::OkStatus();
5901 }
5902 return port::InternalError(
5903 "Failed to allocate scratch memory for the CuDNN CTC Loss");
5904 }
5905
5906 port::Status CudnnSupport::DoCtcLoss(
5907 Stream* stream, dnn::DataType element_type,
5908 const dnn::RnnStateTensorDescriptor& probs_desc,
5909 const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
5910 absl::Span<const int> labels_lengths_data,
5911 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
5912 const dnn::RnnStateTensorDescriptor& grads_desc,
5913 DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
5914 int ctc_loss_algo_id) {
5915 // Current cuDNN CTC Loss only supports the float datatype
5916 if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
5917 return port::Status(port::error::INVALID_ARGUMENT,
5918                       "CudnnCtcLossDescriptor is supported only when "
5919                       "CUDNN_VERSION >= 7.6.3 and the data type is float");
5920 }
5921 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
5922 const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
5923 static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
5924 const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
5925 static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
5926 return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
5927 labels_lengths_data, input_lengths_data, costs_data,
5928 cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
5929 scratch_memory, ctc_loss_algo_id);
5930 }
5931
5932 bool CudnnSupport::DoTransformTensor(Stream* stream,
5933 const dnn::BatchDescriptor& input_desc,
5934 dnn::DataType input_type,
5935 const DeviceMemoryBase& input_data,
5936 const dnn::BatchDescriptor& output_desc,
5937 dnn::DataType output_type, float scale,
5938 DeviceMemoryBase* output_data) {
5939 float beta = 0.0f;
5940 CudnnTensorDescriptor input_tensor_desc(
5941 input_desc, ToCudnnDataType(input_type, input_desc.layout()));
5942 CudnnTensorDescriptor output_tensor_desc(
5943 output_desc, ToCudnnDataType(output_type, output_desc.layout()));
5944 auto cudnn = cudnn_->GetHandle(parent_, stream);
5945 const auto status = [&] {
5946 RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
5947 cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
5948 &beta, output_tensor_desc.handle(), output_data->opaque()));
5949 return ::tensorflow::OkStatus();
5950 }();
5951 return IsStatusOk(status, /*report_error=*/true);
5952 }
5953
5954 bool CudnnSupport::DoMatMul(Stream* stream,
5955 const DeviceMemory<float>& input_data,
5956 const DeviceMemory<float>& weights,
5957 const dnn::BatchDescriptor& input_dimensions,
5958 const dnn::BatchDescriptor& output_dimensions,
5959 DeviceMemory<float>* output_data) {
5960 if (input_dimensions.count() != output_dimensions.count()) {
5961 LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
5962 return false;
5963 }
5964
5965   // We do not permute the input or output; instead we just
5966 // reinterpret the layout. We are working with row-major matrices
5967 // and the rows of the input and output correspond to batch, so
5968 // batch has to be outermost in both the input and output.
5969 //
5970 // By adding transposes to the BLAS gemm call we could perhaps make
5971 // the kYXDepthBatch layout work as well, but there has been no need
5972 // for that so far.
5973 if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5974 input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5975 LOG(ERROR) << "Unsupported MatMul input layout.";
5976 return false;
5977 }
5978 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
5979 output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
5980 LOG(ERROR) << "Unsupported MatMul output layout.";
5981 return false;
5982 }
5983
5984 if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
5985 // This is a fast path that also supports the kBatchYXDepth layout.
5986
5987 // The matrices here are in row-major format while BLAS expects
5988 // column-major, i.e. our matrices are transposed as far as BLAS
5989 // is concerned. So we need to compute output^T =
5990 // input^T*weights^T. There is no parameter for transposing the
5991 // output in BLAS gemm, but instead we can transpose both sides of
5992 // the equality to see that this is equivalent to
5993 // output=weights*input. So we only need to swap the order of
5994 // weights and input in the matrix product to correct for the
5995 // row-major versus column-major difference.
5996 const int64_t m = output_dimensions.NodesAcrossFeatureMaps();
5997 const int64_t n = input_dimensions.count();
5998 const int64_t k = input_dimensions.NodesAcrossFeatureMaps();
5999 if (!stream
6000 ->ThenBlasGemm(blas::Transpose::kNoTranspose,
6001 blas::Transpose::kNoTranspose, m, n, k, weights, m,
6002 input_data, k, output_data, m,
6003 blas::kDefaultComputePrecision)
6004 .ok()) {
6005 return false;
6006 }
6007 } else {
6008 // This is a slower and more complex path that supports output
6009 // width() * height() > 1, though it only supports the
6010 // kBatchYXDepth layout. Does support kBatchDepthYX if output
6011 // feature_map_count() == 1, as then there is no difference
6012 // between the two layouts.
6013 //
6014 // The operation here is the same as above, except that we have to
6015 // do the matrix multiplication for each (y,x) output coordinate
6016 // separately. We then interpret weights as containing K = width()
6017 // * height() different matrices, which we all multiply onto the
6018 // matrix from input_data, yielding K matrix products. We then
6019 // combine these together into one matrix by concatenating all the
6020     // first rows of these matrices, then all the second rows and so
6021 // on. We can do this with a batched matrix multiplication, where
6022 // the result is written to a different submatrix of the output
6023 // for each matrix multiplication.
6024 //
6025 // The reason that we only support the kBatchYXDepth output layout
6026 // is that we have to do something in the depth for each (y,x)
6027 // coordinate. The kBatchYXDepth layout has the depth information
6028 // for each point (y,x) in contiguous memory while the
6029 // kBatchDepthYX layout does not.
6030 //
6031 // TODO(broune): Consider a special case for when output depth ==
6032 // 1, as then possibly this could all be done as one matrix
6033 // multiplication instead of a batched one, which should be
6034 // faster. Another possibility would be to add a weights layout
6035 // parameter and then support kBatchDepthYX for a different
6036 // weights layout.
6037 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
6038 !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
6039 output_dimensions.feature_map_count() == 1)) {
6040 LOG(ERROR) << "Unsupported MatMul output layout.";
6041 return false;
6042 }
6043
6044 const float alpha = 1.0f; // Take the matrix product without scaling it.
6045 const float beta = 0.0f; // Ignore the original values in output_data.
6046 const uint64_t m = output_dimensions.feature_map_count();
6047 const uint64_t n = input_dimensions.count();
6048 const uint64_t k = input_dimensions.NodesAcrossFeatureMaps();
6049 const int lda = m;
6050 const int ldb = k;
6051 const int ldc = output_dimensions.NodesAcrossFeatureMaps();
6052 const int batch_count = output_dimensions.NodesPerFeatureMap();
6053
6054 std::vector<DeviceMemory<float>> a(batch_count);
6055 std::vector<DeviceMemory<float>> b(batch_count);
6056 std::vector<DeviceMemory<float>> c(batch_count);
6057 for (int i = 0; i < batch_count; ++i) {
6058 const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
6059 output_dimensions.feature_map_count();
6060 a[i] = DeviceMemory<float>::MakeFromByteSize(
6061 const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
6062 weights_offset,
6063 weights.ElementCount() - weights_offset);
6064
6065 b[i] = input_data;
6066
6067 const int output_offset = i * output_dimensions.feature_map_count();
6068 c[i] = DeviceMemory<float>::MakeFromByteSize(
6069 const_cast<float*>(
6070 reinterpret_cast<const float*>(output_data->opaque())) +
6071 output_offset,
6072 output_data->ElementCount() - output_offset);
6073 }
6074 const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
6075 std::vector<DeviceMemory<float>*> ptrs;
6076 ptrs.reserve(v.size());
6077 for (auto& mem : v) {
6078 ptrs.push_back(&mem);
6079 }
6080 return ptrs;
6081 };
6082
6083 stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
6084 blas::Transpose::kNoTranspose, m, n, k, alpha,
6085 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
6086 ldc, batch_count);
6087 }
6088
6089 return stream->ok();
6090 }
6091
6092 bool CudnnSupport::DoBiasAdd(Stream* stream,
6093 const DeviceMemory<float>& input_data,
6094 const DeviceMemory<float>& biases,
6095 const dnn::BatchDescriptor& dimensions,
6096 DeviceMemory<float>* output_data) {
6097 CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
6098
6099 dnn::BatchDescriptor bias_dimensions;
6100 bias_dimensions.set_count(1)
6101 .set_feature_map_count(dimensions.feature_map_count())
6102 .set_height(1)
6103 .set_width(1)
6104 .set_layout(dnn::DataLayout::kBatchYXDepth);
6105 CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
6106
6107 // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
6108 // output_data before doing the addition, unless the input and
6109 // output are at the same address.
6110 if (input_data.opaque() != output_data->opaque()) {
6111 stream->ThenMemcpy(output_data, input_data,
6112 dimensions.ElementCount() * sizeof(float));
6113 if (!stream->ok()) {
6114 LOG(ERROR)
6115 << "stream " << stream
6116 << " could not enqueue a tensor copy as part of bias addition.";
6117 return false;
6118 }
6119 }
6120
6121 const float alpha = 1.0f;
6122 const float beta = 1.0f;
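  // cudnnAddTensor computes C = alpha * A + beta * C, so alpha = beta = 1
  // accumulates the (broadcast) bias into the copy of the input already in
  // output_data, i.e. output[n][c][h][w] += biases[c].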
6123
6124 auto cudnn = cudnn_->GetHandle(parent_, stream);
6125
6126 const auto status = [&] {
6127 RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
6128 cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
6129 &beta, input_descriptor.handle(), output_data->opaque()));
6130 return ::tensorflow::OkStatus();
6131 }();
6132 return IsStatusOk(status, /*report_error=*/true);
6133 }
6134
6135 bool CudnnSupport::DoActivate(Stream* stream,
6136 dnn::ActivationMode activation_mode,
6137 const dnn::BatchDescriptor& dimensions,
6138 const DeviceMemory<float>& input_data,
6139 DeviceMemory<float>* output_data,
6140 uint64_t options) {
6141 CudnnActivationDescriptor activation_desc(
6142 activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
6143
6144 CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
6145 // Alpha is the input scaling factor.
6146 float alpha = 1.0;
6147 // Beta is the output scaling factor.
6148 float beta = 0.0;
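  // cudnnActivationForward blends as y = alpha * act(x) + beta * y, so these
  // values simply overwrite output_data with the activation of input_data.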
6149
6150 auto cudnn = cudnn_->GetHandle(parent_, stream);
6151 const auto status = [&] {
6152 RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
6153 cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
6154 input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
6155 return ::tensorflow::OkStatus();
6156 }();
6157 return IsStatusOk(status, /*report_error=*/true);
6158 }
6159
6160 namespace {
6161
6162 // Cudnn legacy API only supports int32 indexing and can handle a maximum of
6163 // 2^31-1 elements. For pooling operations, we split the big tensor along the
6164 // batch axis into multiple small tensors when possible and then call cudnn API
6165 // sequentially.
6166 struct PoolingSplitsSpec {
6167 int64_t num_batches;
6168 int64_t input_offset_in_bytes;
6169 int64_t output_offset_in_bytes;
6170 };
6171
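// Worked example with hypothetical sizes: a float tensor with count() = 6
// batches and 600,000,000 input elements per batch gives
//   max_batches_per_split = 2147483647 / 600000000 = 3,
// so the splitting loop in GetTensorSplits yields two splits of 3 batches
// each; the second split starts at a byte offset of
// 3 * 600000000 * sizeof(float) into the input, with the output offset
// computed analogously from elements_per_batch_output.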
6172 port::StatusOr<std::vector<PoolingSplitsSpec>> GetTensorSplits(
6173 const dnn::BatchDescriptor& input_descriptor,
6174 const dnn::BatchDescriptor& output_descriptor, dnn::DataType element_type) {
6175 std::vector<PoolingSplitsSpec> out;
6176 if (element_type == dnn::DataType::kInt8) {
6177 out.push_back({input_descriptor.count(), 0, 0});
6178 return out;
6179 }
6180
6181 cudnnDataType_t cudnn_input_type =
6182 ToCudnnDataType(element_type, input_descriptor.layout());
6183 cudnnDataType_t cudnn_output_type =
6184 ToCudnnDataType(element_type, output_descriptor.layout());
6185
6186 std::vector<int64_t> dims64 =
6187 input_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
6188
6189 int64_t num_batches = input_descriptor.count();
6190 int64_t elements_per_batch_input = input_descriptor.NodesAcrossFeatureMaps();
6191 int64_t elements_per_batch_output =
6192 output_descriptor.NodesAcrossFeatureMaps();
6193
6194 int64_t max_batches_per_split =
6195 std::numeric_limits<int>::max() / elements_per_batch_input;
6196
6197 if (max_batches_per_split == 0) {
6198 return port::Status(
6199 port::error::INTERNAL,
6200 absl::StrCat(
6201 "Tensor has too many elements for int32 indexing: batches=",
6202 num_batches, " elements_per_batch=", elements_per_batch_input,
6203 "."));
6204 }
6205
6206 int64_t processed_batches = 0;
6207 while (processed_batches < num_batches) {
6208 int64_t num_batches_per_split =
6209 std::min(max_batches_per_split, num_batches - processed_batches);
6210 int64_t offset_input = processed_batches * elements_per_batch_input *
6211 CudnnDataTypeToByteSize(cudnn_input_type);
6212 int64_t offset_output = processed_batches * elements_per_batch_output *
6213 CudnnDataTypeToByteSize(cudnn_output_type);
6214 out.push_back({num_batches_per_split, offset_input, offset_output});
6215 processed_batches += num_batches_per_split;
6216 }
6217 return out;
6218 }
6219 } // namespace
6220
6221 port::Status CudnnSupport::DoPoolForward(
6222 dnn::DataType element_type, Stream* stream,
6223 const dnn::PoolingDescriptor& pooling_dimensions,
6224 const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
6225 const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data,
6226 ScratchAllocator* workspace_allocator) {
6227 // Alpha is the scaling factor for input.
6228 const float alpha_f = 1.0f;
6229 const double alpha_d = 1.0;
6230 const void* alpha = element_type == dnn::DataType::kDouble
6231 ? static_cast<const void*>(&alpha_d)
6232 : static_cast<const void*>(&alpha_f);
6233 // Beta is the scaling factor for output.
6234 const float beta_f = 0.0f;
6235 const double beta_d = 0.0;
6236 const void* beta = element_type == dnn::DataType::kDouble
6237 ? static_cast<const void*>(&beta_d)
6238 : static_cast<const void*>(&beta_f);
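  // cudnnPoolingForward blends as y = alpha * pool(x) + beta * y, and cuDNN
  // expects these scaling factors as doubles for double tensors and floats
  // otherwise, hence the two typed copies selected above.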
6239
6240 cudnnDataType_t cudnn_input_type =
6241 ToCudnnDataType(element_type, input_dimensions.layout());
6242 cudnnDataType_t cudnn_output_type =
6243 ToCudnnDataType(element_type, output_dimensions.layout());
6244 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
6245 auto cudnn = cudnn_->GetHandle(parent_, stream);
6246
6247 auto cudnn_launcher = [&](CudnnTensorDescriptor& src_desc,
6248 CudnnTensorDescriptor& dest_desc,
6249 const void* input_ptr, void* output_ptr) {
6250 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
6251 cudnn.handle(), pooling_desc.handle(), alpha, src_desc.handle(),
6252 input_ptr, beta, dest_desc.handle(), output_ptr));
6253 return ::tensorflow::OkStatus();
6254 };
6255
6256 auto splits_or =
6257 GetTensorSplits(input_dimensions, output_dimensions, element_type);
6258 if (!splits_or.ok()) {
6259 return port::Status(port::error::INTERNAL, "Cudnn pooling failed to split");
6260 }
6261 auto splits = std::move(splits_or.ValueOrDie());
6262
6263 dnn::BatchDescriptor input_split = input_dimensions;
6264 dnn::BatchDescriptor output_split = output_dimensions;
6265 for (int i = 0; i < splits.size(); i++) {
6266 // It is safe to cap the batch dimension, since it is the leading
6267 // dimension and will have no effect on the computation of strides in both
6268 // kBatchYXDepth and kBatchDepthYX formats.
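    // For example, with kBatchDepthYX dims {N, C, H, W} the per-sample stride
    // stays C * H * W whatever N is, so shrinking N only limits how many
    // samples this particular cuDNN call touches.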
6269 input_split.set_count(splits[i].num_batches);
6270 output_split.set_count(splits[i].num_batches);
6271 CudnnTensorDescriptor src_desc(input_split, cudnn_input_type);
6272 CudnnTensorDescriptor dest_desc(output_split, cudnn_output_type);
6273
6274 void* input_data_ptr = static_cast<char*>(input_data.opaque()) +
6275 splits[i].input_offset_in_bytes;
6276 void* output_data_ptr = static_cast<char*>(output_data.opaque()) +
6277 splits[i].output_offset_in_bytes;
6278 const auto status =
6279 cudnn_launcher(src_desc, dest_desc, input_data_ptr, output_data_ptr);
6280 if (!IsStatusOk(status, /*report_error=*/true)) {
6281 return status;
6282 }
6283 }
6284 return ::tensorflow::OkStatus();
6285 }
6286
6287 port::Status CudnnSupport::DoPoolBackward(
6288 dnn::DataType element_type, Stream* stream,
6289 const dnn::PoolingDescriptor& pooling_dimensions,
6290 const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
6291 const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data,
6292 DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data,
6293 ScratchAllocator* workspace_allocator) {
6294 // Alpha is the scaling factor for input.
6295 const float alpha_f = 1.0f;
6296 const double alpha_d = 1.0;
6297 const void* alpha = element_type == dnn::DataType::kDouble
6298 ? static_cast<const void*>(&alpha_d)
6299 : static_cast<const void*>(&alpha_f);
6300 // Beta is the scaling factor for output.
6301 const float beta_f = 0.0f;
6302 const double beta_d = 0.0;
6303 const void* beta = element_type == dnn::DataType::kDouble
6304 ? static_cast<const void*>(&beta_d)
6305 : static_cast<const void*>(&beta_f);
6306
6307 cudnnDataType_t cudnn_input_type =
6308 ToCudnnDataType(element_type, input_dimensions.layout());
6309 cudnnDataType_t cudnn_output_type =
6310 ToCudnnDataType(element_type, output_dimensions.layout());
6311 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
6312 auto cudnn = cudnn_->GetHandle(parent_, stream);
6313
6314 auto cudnn_launcher = [&](CudnnTensorDescriptor& src_desc,
6315 CudnnTensorDescriptor& dest_desc,
6316 const void* output_ptr, const void* input_diff_ptr,
6317 const void* input_ptr, void* output_diff_ptr) {
6318 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
6319 cudnn.handle(), pooling_desc.handle(), alpha, dest_desc.handle(),
6320 output_ptr, dest_desc.handle(), input_diff_ptr, src_desc.handle(),
6321 input_ptr, beta, src_desc.handle(), output_diff_ptr));
6322 return ::tensorflow::OkStatus();
6323 };
6324
6325 auto splits_or =
6326 GetTensorSplits(input_dimensions, output_dimensions, element_type);
6327 if (!splits_or.ok()) {
6328 return port::Status(port::error::INTERNAL, "Cudnn pooling failed to split");
6329 }
6330 auto splits = std::move(splits_or.ValueOrDie());
6331
6332 dnn::BatchDescriptor input_split = input_dimensions;
6333 dnn::BatchDescriptor output_split = output_dimensions;
6334 for (int i = 0; i < splits.size(); i++) {
6335 // It is safe to cap the batch dimension, since it is the leading
6336 // dimension and will have no effect on the computation of strides in both
6337 // kBatchYXDepth and kBatchDepthYX formats.
6338 input_split.set_count(splits[i].num_batches);
6339 output_split.set_count(splits[i].num_batches);
6340 CudnnTensorDescriptor src_desc(input_split, cudnn_input_type);
6341 CudnnTensorDescriptor dest_desc(output_split, cudnn_output_type);
6342
6343 void* output_data_ptr = static_cast<char*>(output_data.opaque()) +
6344 splits[i].output_offset_in_bytes;
6345 void* input_diff_data_ptr = static_cast<char*>(input_diff_data.opaque()) +
6346 splits[i].output_offset_in_bytes;
6347 void* input_data_ptr = static_cast<char*>(input_data.opaque()) +
6348 splits[i].input_offset_in_bytes;
6349 void* output_diff_data_ptr = static_cast<char*>(output_diff_data.opaque()) +
6350 splits[i].input_offset_in_bytes;
6351 const auto status = cudnn_launcher(src_desc, dest_desc, output_data_ptr,
6352 input_diff_data_ptr, input_data_ptr,
6353 output_diff_data_ptr);
6354 if (!IsStatusOk(status, /*report_error=*/true)) {
6355 return status;
6356 }
6357 }
6358 return ::tensorflow::OkStatus();
6359 }
6360
6361 bool CudnnSupport::DoNormalizeWithDimensions(
6362 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
6363 const dnn::BatchDescriptor& dimensions,
6364 const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
6365 // Check for unsupported modes.
6366 if (normalize_descriptor.wrap_around()) {
6367     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
6368 return false;
6369 }
6370 if (normalize_descriptor.segment_size()) {
6371 LOG(ERROR) << "CUDA LRN does not support segmentation";
6372 return false;
6373 }
6374
6375 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
6376 CudnnNormalizeDescriptor normalize(normalize_descriptor);
6377
6378 // Alpha is the scaling factor for input.
6379 float alpha = 1.0f;
6380 // Beta is the scaling factor for output.
6381 float beta = 0.0f;
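  // These are the cuDNN result-blending factors (y = alpha * LRN(x) + beta * y);
  // the LRN parameters themselves (local size, bias, and the normalization
  // alpha/beta) come from normalize_descriptor via CudnnNormalizeDescriptor.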
6382
6383 auto cudnn = cudnn_->GetHandle(parent_, stream);
6384
6385 // Launch the normalization.
6386 const auto status = [&] {
6387 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
6388 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
6389 &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
6390 output_data->opaque()));
6391 return ::tensorflow::OkStatus();
6392 }();
6393 return IsStatusOk(status, /*report_error=*/true);
6394 }
6395
6396 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
6397 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
6398 const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
6399 const DeviceMemory<float>& normalized_data,
6400 const DeviceMemory<float>& normalized_variable_gradient,
6401 DeviceMemory<float>* raw_variable_gradient,
6402 ScratchAllocator* workspace_allocator) {
6403 // Check for unsupported modes.
6404 if (normalize_descriptor.wrap_around()) {
6405     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
6406 return false;
6407 }
6408 if (normalize_descriptor.segment_size()) {
6409 LOG(ERROR) << "CUDA LRN does not support segmentation";
6410 return false;
6411 }
6412
6413 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
6414 CudnnNormalizeDescriptor normalize(normalize_descriptor);
6415
6416 float alpha = 1.0f;
6417 float beta = 0.0f;
6418
6419 auto cudnn = cudnn_->GetHandle(parent_, stream);
6420 const auto status = [&] {
6421 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
6422 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
6423 &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
6424 normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
6425 &beta, dims.handle(), raw_variable_gradient->opaque()));
6426 return ::tensorflow::OkStatus();
6427 }();
6428 return IsStatusOk(status, /*report_error=*/true);
6429 }
6430
6431 bool CudnnSupport::DoDepthConcatenate(Stream* stream,
6432 BatchDescriptorSlice input_dimensions,
6433 DeviceMemorySlice<float> input_data,
6434 DeviceMemory<float>* output_data) {
6435 CHECK_EQ(input_dimensions.size(), input_data.size());
6436
6437 for (const auto& dimensions : input_dimensions) {
6438 if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
6439 LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
6440 "supports the kBatchDepthYX layout.";
6441 return false;
6442 }
6443 }
6444
6445 if (input_dimensions.empty()) {
6446 return true; // Nothing to do.
6447 }
6448
6449 dnn::BatchDescriptor output_dimensions =
6450 dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
6451
6452 const int64_t area = output_dimensions.width() * output_dimensions.height();
6453 const auto index = [area](int64_t batch, int64_t depth, int64_t yx,
6454 int64_t max_depth) {
6455 return (batch * max_depth + depth) * area + yx;
6456 };
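  // Illustrative example: in the kBatchDepthYX layout, element
  // (batch, depth, y, x) of a tensor with max_depth feature maps lives at flat
  // offset (batch * max_depth + depth) * area + (y * width + x), which is what
  // index(batch, depth, yx, max_depth) computes with yx = y * width + x.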
6457
6458 std::vector<float> output_host(output_dimensions.ElementCount());
6459 std::vector<float> tmp;
6460 int64_t depth_sum = 0;
6461 for (size_t i = 0; i < input_data.size(); ++i) {
6462 const auto& dimensions = input_dimensions[i];
6463 tmp.resize(dimensions.ElementCount());
6464 stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
6465 port::Status block_status = stream->BlockHostUntilDone();
6466 if (!block_status.ok()) {
6467 LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
6468 return false;
6469 }
6470
6471 for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) {
6472 for (int64_t yx = 0; yx < area; ++yx) {
6473 for (int64_t depth = 0; depth < dimensions.feature_map_count();
6474 ++depth) {
6475 LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
6476 << yx << ' ' << depth;
6477 output_host[index(batch, depth + depth_sum, yx,
6478 output_dimensions.feature_map_count())] =
6479 tmp[index(batch, depth, yx, dimensions.feature_map_count())];
6480 }
6481 }
6482 }
6483 depth_sum += dimensions.feature_map_count();
6484 }
6485 stream->ThenMemcpyH2D<float>(output_host, output_data);
6486 return true;
6487 }
6488
6489 bool CudnnSupport::DoElementwiseOperate(Stream*, dnn::ElementwiseOperation,
6490 BatchDescriptorSlice,
6491 DeviceMemorySlice<float>,
6492 const dnn::BatchDescriptor&,
6493 DeviceMemory<float>*) {
6494 LOG(FATAL) << "not yet implemented"; // TODO(leary)
6495 return false;
6496 }
6497
6498 bool CudnnSupport::DoXYPad(Stream* stream,
6499 const dnn::BatchDescriptor& dimensions,
6500 const DeviceMemory<float>& input_data,
6501 int64_t left_pad, int64_t right_pad, int64_t top_pad,
6502 int64_t bottom_pad,
6503 DeviceMemory<float>* output_data) {
6504 LOG(FATAL) << "not yet implemented"; // TODO(leary)
6505 return false;
6506 }
6507
6508 bool CudnnSupport::DoXYSlice(Stream* stream,
6509 const dnn::BatchDescriptor& dimensions,
6510 const DeviceMemory<float>& input_data,
6511 int64_t left_trim, int64_t right_trim,
6512 int64_t top_trim, int64_t bottom_trim,
6513 DeviceMemory<float>* output_data) {
6514 LOG(FATAL) << "not yet implemented"; // TODO(leary)
6515 return false;
6516 }
6517
6518 bool CudnnSupport::DoMemcpyD2HQuantized(
6519 Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
6520 dnn::QuantizedActivationMode mode, void* host_dst, int64_t size) {
6521 LOG(ERROR) << "quantized memcpy not supported by cuDNN";
6522 return false;
6523 }
6524
6525 bool CudnnSupport::DoMemcpyH2DQuantized(
6526 Stream* stream, const void* host_src, int64_t size,
6527 dnn::QuantizedActivationMode mode,
6528 DeviceMemory<float>* gpu_unquantized_dst) {
6529 LOG(ERROR) << "quantized memcpy not supported by cuDNN";
6530 return false;
6531 }
6532
6533 bool CudnnSupport::DeriveOutputBatchDescriptor(
6534 const dnn::BatchDescriptor& batch_descriptor,
6535 const dnn::FilterDescriptor& filter_descriptor,
6536 const dnn::ConvolutionDescriptor& convolution_descriptor,
6537 dnn::BatchDescriptor* output_batch_descriptor) {
6538 CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
6539 CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
6540 CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
6541
6542 int dn = batch_descriptor.ndims() + 2;
6543 std::vector<int> dims(dn); // in BDYX
6544 const auto status = [&] {
6545 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
6546 conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
6547 output_batch_descriptor->set_count(dims[0])
6548 .set_feature_map_count(dims[1])
6549 .set_layout(batch_descriptor.layout());
6550
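    // cuDNN returns dims in {N, C, spatial...} order (e.g. {N, C, H, W} for a
    // 2D convolution), while dnn::DimIndex counts spatial dimensions from the
    // innermost one, so dims.rbegin()[0] (W) maps to DimIndex::X and
    // dims.rbegin()[1] (H) maps to DimIndex::Y.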
6551 for (int i = 0; i < batch_descriptor.ndims(); i++) {
6552 output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
6553 dims.rbegin()[i]);
6554 }
6555 return ::tensorflow::OkStatus();
6556 }();
6557 return IsStatusOk(status, /*report_error=*/true);
6558 }
6559
6560 } // namespace gpu
6561
6562 void initialize_cudnn() {
6563 port::Status status =
6564 PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
6565 cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
6566 [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
6567 gpu::GpuExecutor* cuda_executor =
6568 dynamic_cast<gpu::GpuExecutor*>(parent);
6569 if (cuda_executor == nullptr) {
6570 LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
6571 << "support library with a non-CUDA StreamExecutor";
6572 return nullptr;
6573 }
6574
6575 gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
6576 if (!dnn->Init().ok()) {
6577 // Note: Init() will log a more specific error.
6578 delete dnn;
6579 return nullptr;
6580 }
6581 return dnn;
6582 });
6583
6584 if (!status.ok()) {
6585 LOG(ERROR) << "Unable to register cuDNN factory: "
6586 << status.error_message();
6587 }
6588
6589 PluginRegistry::Instance()->SetDefaultFactory(
6590 cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
6591 }
6592
6593 } // namespace stream_executor
6594
6595 #pragma clang diagnostic pop
6596
6597 REGISTER_MODULE_INITIALIZER(register_cudnn,
6598 { stream_executor::initialize_cudnn(); });
6599