1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #include <stddef.h>
18 
19 #include <atomic>
20 #include <cmath>
21 #include <functional>
22 #include <limits>
23 #include <string>
24 #include <unordered_set>
25 
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/framework/device_base.h"
28 #include "tensorflow/core/framework/kernel_def_builder.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_def_builder.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/register_types.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/tensor_types.h"
36 #include "tensorflow/core/framework/types.h"
37 #include "tensorflow/core/kernels/gpu_utils.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/lib/core/stringpiece.h"
41 #include "tensorflow/core/lib/gtl/inlined_vector.h"
42 #include "tensorflow/core/lib/hash/hash.h"
43 #include "tensorflow/core/lib/strings/stringprintf.h"
44 #include "tensorflow/core/platform/fingerprint.h"
45 #include "tensorflow/core/platform/mutex.h"
46 #include "tensorflow/core/platform/types.h"
47 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
48 #include "tensorflow/core/util/env_var.h"
49 #include "tensorflow/core/util/use_cudnn.h"
50 
51 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
52 #include "tensorflow/core/platform/stream_executor.h"
53 #include "tensorflow/core/util/stream_executor_util.h"
54 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
55 
56 /*
57  * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
58  * using the underlying Cudnn library.
59  *
60  * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
61  * layout and format, and a saved buffer is unlikely to be usable across
62  * different GPUs. Users therefore need to first query the size of the opaque
63  * parameter buffer and convert it to and from its canonical form; each
64  * actual training step, however, is carried out with the opaque buffer.
65  *
66  * Similar to many other ops, the forward op has two flavors: training and
67  * inference. When training is specified, additional data in reserve_space is
68  * produced for the backward pass, at the cost of a performance penalty.
69  *
70  * In addition to the actual data and reserve_space, Cudnn also needs extra
71  * memory as temporary workspace. Memory management to and from
72  * stream-executor is done through ScratchAllocator. In general,
73  * stream-executor is responsible for creating memory of the proper size, and
74  * TensorFlow is responsible for making sure the memory stays alive long
75  * enough and is recycled afterwards.
76  *
77  */
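// A rough sketch of how the ops defined below fit together (per the comment
// above):
//   1. CudnnRNNParamsSize: query the size of the opaque parameter buffer.
//   2. CudnnRNNCanonicalToParams: pack canonical weights/biases into it.
//   3. CudnnRNNForwardOp / CudnnRNNBackwardOp (and their V2/V3 variants):
//      run each training/inference step against the opaque buffer, with
//      reserve_space produced in training mode for the backward pass.
//   4. CudnnRNNParamsToCanonical: convert the buffer back to canonical form.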
78 namespace tensorflow {
79 
80 using CPUDevice = Eigen::ThreadPoolDevice;
81 
82 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
83 
84 using GPUDevice = Eigen::GpuDevice;
85 using se::Stream;
86 using se::StreamExecutor;
87 using se::dnn::RnnDescriptor;
88 
89 template <typename Device, typename T, typename Index>
90 class CudnnRNNParamsSizeOp;
91 
92 template <typename Device, typename T>
93 class CudnnRNNParamsToCanonical;
94 
95 template <typename Device, typename T>
96 class CudnnRNNCanonicalToParams;
97 
98 template <typename Device, typename T>
99 class CudnnRNNForwardOp;
100 
101 template <typename Device, typename T>
102 class CudnnRNNBackwardOp;
103 
104 template <typename Device, typename T>
105 class CudnnRNNForwardOpV2;
106 
107 template <typename Device, typename T>
108 class CudnnRNNBackwardOpV2;
109 
110 template <typename Device, typename T>
111 class CudnnRNNForwardOpV3;
112 
113 template <typename Device, typename T>
114 class CudnnRNNBackwardOpV3;
115 
116 enum class TFRNNInputMode {
117   kRNNLinearInput = 0,
118   kRNNSkipInput = 1,
119   kAutoSelect = 9999999
120 };
121 
122 namespace {
123 using se::DeviceMemory;
124 using se::DeviceMemoryBase;
125 using se::ScratchAllocator;
126 using se::dnn::AlgorithmConfig;
127 using se::dnn::AlgorithmDesc;
128 using se::dnn::ProfileResult;
129 using se::dnn::RnnDirectionMode;
130 using se::dnn::RnnInputMode;
131 using se::dnn::RnnMode;
132 using se::dnn::RnnSequenceTensorDescriptor;
133 using se::dnn::RnnStateTensorDescriptor;
134 using se::dnn::ToDataType;
135 using se::port::StatusOr;
136 
137 uint64 HashList(const std::vector<int>& list) {
138   if (list.empty()) {
139     return 0;
140   }
141   uint64 hash_code = list[0];
142   for (int i = 1; i < list.size(); i++) {
143     hash_code = Hash64Combine(hash_code, list[i]);
144   }
145   return hash_code;
146 }
147 
148 // Encapsulate all the shape information that is used in both forward and
149 // backward rnn operations.
150 class CudnnRnnParameters {
151  public:
152   CudnnRnnParameters(int num_layers, int input_size, int num_units,
153                      int max_seq_length, int batch_size, int dir_count,
154                      bool has_dropout, bool is_training, RnnMode rnn_mode,
155                      TFRNNInputMode rnn_input_mode, DataType dtype)
156       : num_layers_(num_layers),
157         input_size_(input_size),
158         num_units_(num_units),
159         seq_length_(max_seq_length),
160         batch_size_(batch_size),
161         dir_count_(dir_count),
162         has_dropout_(has_dropout),
163         is_training_(is_training),
164         rnn_mode_(rnn_mode),
165         rnn_input_mode_(rnn_input_mode),
166         dtype_(dtype) {
167     hash_code_ =
168         HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
169                   dir_count, static_cast<int>(has_dropout),
170                   static_cast<int>(is_training), static_cast<int>(rnn_mode),
171                   static_cast<int>(rnn_input_mode), dtype});
172   }
173 
174   bool operator==(const CudnnRnnParameters& other) const {
175     return this->get_data_as_tuple() == other.get_data_as_tuple();
176   }
177 
178   bool operator!=(const CudnnRnnParameters& other) const {
179     return !(*this == other);
180   }
181   uint64 hash() const { return hash_code_; }
182 
183   string ToString() const {
184     std::vector<string> fields = {
185         std::to_string(num_layers_),
186         std::to_string(input_size_),
187         std::to_string(num_units_),
188         std::to_string(seq_length_),
189         std::to_string(batch_size_),
190         std::to_string(dir_count_),
191         std::to_string(has_dropout_),
192         std::to_string(is_training_),
193         std::to_string(static_cast<int>(rnn_mode_)),
194         std::to_string(static_cast<int>(rnn_input_mode_)),
195         std::to_string(static_cast<int>(dtype_))};
196     return absl::StrJoin(fields, ", ");
197   }
198 
199  private:
200   using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
201                                        RnnMode, TFRNNInputMode, DataType>;
202 
203   ParameterDataType get_data_as_tuple() const {
204     return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
205                            batch_size_, dir_count_, has_dropout_, is_training_,
206                            rnn_mode_, rnn_input_mode_, dtype_);
207   }
208 
209   const int num_layers_;
210   const int input_size_;
211   const int num_units_;
212   const int seq_length_;
213   const int batch_size_;
214   const int dir_count_;
215   const bool has_dropout_;
216   const bool is_training_;
217   const RnnMode rnn_mode_;
218   const TFRNNInputMode rnn_input_mode_;
219   const DataType dtype_;
220   uint64 hash_code_;
221 };
222 
223 struct RnnAutotuneGroup {
224   static string name() { return "Rnn"; }
225 };
226 
227 using AutotuneRnnConfigMap =
228     AutotuneSingleton<RnnAutotuneGroup, CudnnRnnParameters, AlgorithmConfig>;
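// Illustrative sketch of how this cache is meant to be consulted, assuming the
// usual AutotuneMap interface exposed by AutotuneSingleton in gpu_utils.h:
//
//   CudnnRnnParameters key(num_layers, input_size, num_units, max_seq_length,
//                          batch_size, dir_count, has_dropout, is_training,
//                          rnn_mode, rnn_input_mode, dtype);
//   AlgorithmConfig algo_config;
//   if (!AutotuneRnnConfigMap::GetInstance()->Find(key, &algo_config)) {
//     // ... profile candidate algorithms, pick the best, then cache it ...
//     AutotuneRnnConfigMap::GetInstance()->Insert(key, algo_config);
//   }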
229 
230 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
231   if (str == "rnn_relu") {
232     *rnn_mode = RnnMode::kRnnRelu;
233     return Status::OK();
234   } else if (str == "rnn_tanh") {
235     *rnn_mode = RnnMode::kRnnTanh;
236     return Status::OK();
237   } else if (str == "lstm") {
238     *rnn_mode = RnnMode::kRnnLstm;
239     return Status::OK();
240   } else if (str == "gru") {
241     *rnn_mode = RnnMode::kRnnGru;
242     return Status::OK();
243   }
244   return errors::InvalidArgument("Invalid RNN mode: ", str);
245 }
246 
247 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
248   if (str == "linear_input") {
249     *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
250     return Status::OK();
251   } else if (str == "skip_input") {
252     *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
253     return Status::OK();
254   } else if (str == "auto_select") {
255     *rnn_input_mode = TFRNNInputMode::kAutoSelect;
256     return Status::OK();
257   }
258   return errors::InvalidArgument("Invalid RNN input mode: ", str);
259 }
260 
261 Status ParseRNNDirectionMode(const string& str,
262                              RnnDirectionMode* rnn_dir_mode) {
263   if (str == "unidirectional") {
264     *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
265     return Status::OK();
266   } else if (str == "bidirectional") {
267     *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
268     return Status::OK();
269   }
270   return errors::InvalidArgument("Invalid RNN direction mode: ", str);
271 }
272 
273 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
274                       int input_size, RnnInputMode* input_mode) {
275   switch (tf_input_mode) {
276     case TFRNNInputMode::kRNNLinearInput:
277       *input_mode = RnnInputMode::kRnnLinearSkip;
278       break;
279     case TFRNNInputMode::kRNNSkipInput:
280       *input_mode = RnnInputMode::kRnnSkipInput;
281       break;
282     case TFRNNInputMode::kAutoSelect:
283       *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
284                                               : RnnInputMode::kRnnLinearSkip;
285       break;
286     default:
287       return errors::InvalidArgument("Invalid TF input mode: ",
288                                      static_cast<int>(tf_input_mode));
289   }
290   return Status::OK();
291 }
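// For example, with TFRNNInputMode::kAutoSelect the skip-input mode is chosen
// when input_size == num_units (no input projection is needed); otherwise the
// linear-input mode is used.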
292 
293 // TODO(zhengxq): Merge those into stream_executor_util.h.
294 template <typename T>
295 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
296   return DeviceMemory<T>::MakeFromByteSize(
297       const_cast<T*>(tensor->template flat<T>().data()),
298       tensor->template flat<T>().size() * sizeof(T));
299 }
300 
301 template <typename T>
302 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
303   return DeviceMemory<T>::MakeFromByteSize(
304       tensor->template flat<T>().data(),
305       tensor->template flat<T>().size() * sizeof(T));
306 }
307 
308 template <typename U, typename T>
309 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
310   return DeviceMemory<U>::MakeFromByteSize(
311       tensor->template flat<T>().data(),
312       tensor->template flat<T>().size() * sizeof(T));
313 }
314 
315 DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
316                                    int64_t offset, int64_t size) {
317   const void* base_ptr = device_memory.opaque();
318   void* offset_ptr =
319       const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
320   CHECK(offset + size <= device_memory.size())
321       << "The slice is not within the region of DeviceMemory.";
322   return DeviceMemoryBase(offset_ptr, size);
323 }
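// Example use (see RestoreParams further below): carving one canonical
// parameter region out of the opaque params buffer, where `blob` here just
// names the whole-buffer DeviceMemoryBase:
//   DeviceMemoryBase region =
//       SliceDeviceMemory(blob, params[i].offset, params[i].size);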
324 
325 inline Status FromExecutorStatus(const se::port::Status& s) {
326   return s.ok() ? Status::OK()
327                 : Status(static_cast<error::Code>(static_cast<int>(s.code())),
328                          s.error_message());
329 }
330 
331 template <typename T>
332 inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
333   return FromExecutorStatus(s.status());
334 }
335 
336 inline se::port::Status ToExecutorStatus(const Status& s) {
337   return s.ok() ? se::port::Status::OK()
338                 : se::port::Status(static_cast<se::port::error::Code>(
339                                        static_cast<int>(s.code())),
340                                    s.error_message());
341 }
342 
343 template <typename>
344 struct ToTFDataType;
345 
346 template <>
347 struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
348 
349 template <>
350 struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
351 
352 template <>
353 struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
354 
355 template <>
356 struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
357 
358 // A helper to allocate temporary scratch memory for Cudnn RNN models. It
359 // takes ownership of the underlying memory. The expectation is that the
360 // memory should stay alive for the span of the Cudnn RNN call itself.
361 template <typename T>
362 class CudnnRnnAllocatorInTemp : public ScratchAllocator {
363  public:
364   ~CudnnRnnAllocatorInTemp() override = default;
365 
366   explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
367       : context_(context) {}
368   int64 GetMemoryLimitInBytes() override {
369     return std::numeric_limits<int64>::max();
370   }
371 
372   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
373     Tensor temporary_memory;
374     const DataType tf_data_type = ToTFDataType<T>::value;
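    // Round the requested byte count up to a whole number of T elements so the
    // temporary tensor backing this allocation is large enough.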
375     int64_t allocate_count =
376         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
377     Status allocation_status(context_->allocate_temp(
378         tf_data_type, TensorShape({allocate_count}), &temporary_memory));
379     if (!allocation_status.ok()) {
380       return ToExecutorStatus(allocation_status);
381     }
382     // Hold a reference to the allocated tensors until the end of the
383     // allocator's lifetime.
384     allocated_tensors_.push_back(temporary_memory);
385     total_byte_size_ += byte_size;
386     return DeviceMemory<uint8>::MakeFromByteSize(
387         temporary_memory.template flat<T>().data(),
388         temporary_memory.template flat<T>().size() * sizeof(T));
389   }
390 
391   int64 TotalByteSize() const { return total_byte_size_; }
392 
393   Tensor get_allocated_tensor(int index) const {
394     return allocated_tensors_[index];
395   }
396 
397  private:
398   int64 total_byte_size_ = 0;
399   OpKernelContext* context_;  // not owned
400   std::vector<Tensor> allocated_tensors_;
401 };
402 
403 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
404 // used by the forward-pass kernel to feed the output to the backward pass.
405 // The memory is expected to stay alive at least until the backward pass is
406 // finished.
407 template <typename T>
408 class CudnnRnnAllocatorInOutput : public ScratchAllocator {
409  public:
410   ~CudnnRnnAllocatorInOutput() override {}
411   CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
412       : context_(context), output_index_(output_index) {}
413   int64 GetMemoryLimitInBytes() override {
414     return std::numeric_limits<int64>::max();
415   }
416   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
417     CHECK(total_byte_size_ == 0)
418         << "Reserve space allocator can only be called once";
419     int64_t allocate_count =
420         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
421 
422     Tensor* temporary_memory = nullptr;
423     Status allocation_status(context_->allocate_output(
424         output_index_, TensorShape({allocate_count}), &temporary_memory));
425     if (!allocation_status.ok()) {
426       return ToExecutorStatus(allocation_status);
427     }
428     total_byte_size_ += byte_size;
429     auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
430         temporary_memory->template flat<T>().data(),
431         temporary_memory->template flat<T>().size() * sizeof(T));
432     return StatusOr<DeviceMemory<uint8>>(memory_uint8);
433   }
434   int64 TotalByteSize() { return total_byte_size_; }
435 
436  private:
437   int64 total_byte_size_ = 0;
438   OpKernelContext* context_;  // not owned
439   int output_index_;
440 };
441 
442 // A helper to allocate memory for Cudnn RNN models, which is
443 // expected to live between kernel invocations.
444 // This class is not thread-safe.
445 class CudnnRNNSpaceAllocator : public ScratchAllocator {
446  public:
447   explicit CudnnRNNSpaceAllocator(OpKernelContext* context)
448       : context_(context) {}
449 
450   ~CudnnRNNSpaceAllocator() override {}
451 
452   int64 GetMemoryLimitInBytes() override {
453     return std::numeric_limits<int64>::max();
454   }
455 
456   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
457     if (total_byte_size_ != 0) {
458       return Status(error::FAILED_PRECONDITION,
459                     "Space allocator can only be called once");
460     }
461 
462     Status allocation_status =
463         context_->allocate_temp(DT_UINT8, TensorShape({byte_size}), &tensor_);
464     if (!allocation_status.ok()) {
465       return ToExecutorStatus(allocation_status);
466     }
467     total_byte_size_ += byte_size;
468     return AsDeviceMemory<uint8>(&tensor_);
469   }
470   int64 TotalByteSize() { return total_byte_size_; }
471 
472  private:
473   int64 total_byte_size_ = 0;
474   Tensor tensor_;
475   OpKernelContext* context_;  // not owned
476 };
477 
478 struct CudnnModelTypes {
479   RnnMode rnn_mode;
480   TFRNNInputMode rnn_input_mode;
481   RnnDirectionMode rnn_direction_mode;
482   bool HasInputC() const {
483     // For Cudnn 5.0, only LSTM has input-c. All other models use only
484     // input-h.
485     return rnn_mode == RnnMode::kRnnLstm;
486   }
487 
488   string DebugString() const {
489     return strings::Printf(
490         "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
491         static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
492         static_cast<int>(rnn_direction_mode));
493   }
494 };
495 
496 // A helper class that collects the shapes describing an RNN model.
497 struct CudnnRnnModelShapes {
498   int num_layers;
499   int input_size;
500   int num_units;
501   int dir_count;
502   int max_seq_length;
503   int batch_size;
504   int cell_num_units = 0;
505   // If you add a new field to this structure, please take care of
506   // updating IsCompatibleWith() below as well as the hash function in
507   // CudnnRnnConfigHasher.
508   TensorShape input_shape;
509   TensorShape output_shape;
510   TensorShape hidden_state_shape;
511   TensorShape cell_state_shape;
512   // At present only the fields relevant to the cached RnnDescriptor are compared.
513   bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
514     return num_layers == rhs.num_layers && input_size == rhs.input_size &&
515            num_units == rhs.num_units && dir_count == rhs.dir_count &&
516            cell_num_units == rhs.cell_num_units &&
517            max_seq_length == rhs.max_seq_length;
518   }
519   string DebugString() const {
520     return strings::Printf(
521         "[num_layers, input_size, num_units, dir_count, max_seq_length, "
522         "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
523         num_layers, input_size, num_units, dir_count, max_seq_length,
524         batch_size, cell_num_units);
525   }
526 };
527 
528 // Utility class for using a CudnnRnnModelShapes and AlgorithmDesc pair as a
529 // hash table key.
530 struct CudnnRnnConfigHasher {
531   uint64 operator()(
532       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
533           to_hash) const {
534     auto& shapes = to_hash.first;
535     auto& algo_desc = to_hash.second;
536 
537     uint64 hash =
538         HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
539                   shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
540     if (algo_desc.has_value()) {
541       hash = Hash64Combine(hash, algo_desc->hash());
542     }
543     return hash;
544   }
545 };
546 
547 // Utility class for comparing CudnnRnnModelShapes and AlgorithmDesc pairs
548 // used as hash table keys.
549 struct CudnnRnnConfigComparator {
550   bool operator()(
551       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs,
552       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs)
553       const {
554     return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
555   }
556 };
557 
558 // Pointers to RNN scratch space for a specific set of shape parameters (used as
559 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
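// The dropout_state_allocator below owns the dropout-state memory referenced
// by the cached rnn_desc, so the two are cached and destroyed together.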
560 struct RnnScratchSpace {
561   std::unique_ptr<RnnDescriptor> rnn_desc;
562   std::unique_ptr<CudnnRNNSpaceAllocator> dropout_state_allocator;
563 };
564 
565 // Extracts and checks the forward input tensors, parameters, and shapes from the
566 // OpKernelContext.
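// With time_major == true the layouts read below are:
//   input:   [max_seq_length, batch_size, input_size]
//   input_h: [dir_count * num_layers, batch_size, num_units]
// With time_major == false the leading two dimensions are swapped.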
567 Status ExtractForwardInput(OpKernelContext* context,
568                            const CudnnModelTypes& model_types, bool time_major,
569                            const Tensor** input, const Tensor** input_h,
570                            const Tensor** input_c, const Tensor** params,
571                            const int num_proj,
572                            CudnnRnnModelShapes* model_shapes) {
573   TF_RETURN_IF_ERROR(context->input("input", input));
574   TF_RETURN_IF_ERROR(context->input("input_h", input_h));
575   if (model_types.HasInputC()) {
576     TF_RETURN_IF_ERROR(context->input("input_c", input_c));
577   }
578   TF_RETURN_IF_ERROR(context->input("params", params));
579 
580   if ((*input)->dims() != 3) {
581     return errors::InvalidArgument("RNN input must be a 3-D vector.");
582   }
583   if (time_major) {
584     model_shapes->max_seq_length = (*input)->dim_size(0);
585     model_shapes->batch_size = (*input)->dim_size(1);
586   } else {
587     model_shapes->max_seq_length = (*input)->dim_size(1);
588     model_shapes->batch_size = (*input)->dim_size(0);
589   }
590   model_shapes->input_size = (*input)->dim_size(2);
591   model_shapes->input_shape = (*input)->shape();
592   model_shapes->dir_count =
593       (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
594           ? 2
595           : 1;
596 
597   if ((*input_h)->dims() != 3) {
598     return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
599   }
600   if (time_major) {
601     model_shapes->num_layers =
602         (*input_h)->dim_size(0) / model_shapes->dir_count;
603   } else {
604     model_shapes->num_layers =
605         (*input_h)->dim_size(1) / model_shapes->dir_count;
606   }
607   model_shapes->num_units = (*input_h)->dim_size(2);
608 
609   if (time_major) {
610     model_shapes->hidden_state_shape =
611         TensorShape({model_shapes->dir_count * model_shapes->num_layers,
612                      model_shapes->batch_size, model_shapes->num_units});
613   } else {
614     model_shapes->hidden_state_shape =
615         TensorShape({model_shapes->batch_size,
616                      model_shapes->dir_count * model_shapes->num_layers,
617                      model_shapes->num_units});
618   }
619   if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
620     return errors::InvalidArgument(
621         "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
622         model_shapes->hidden_state_shape.DebugString());
623   }
624   if (model_types.HasInputC()) {
625     model_shapes->cell_num_units = (*input_c)->dim_size(2);
626     if (time_major) {
627       model_shapes->cell_state_shape =
628           TensorShape({model_shapes->dir_count * model_shapes->num_layers,
629                        model_shapes->batch_size, model_shapes->cell_num_units});
630     } else {
631       model_shapes->cell_state_shape =
632           TensorShape({model_shapes->batch_size,
633                        model_shapes->dir_count * model_shapes->num_layers,
634                        model_shapes->cell_num_units});
635     }
636     if (num_proj == 0) {
637       if ((*input_h)->shape() != (*input_c)->shape()) {
638         return errors::InvalidArgument(
639             "input_h and input_c must have the same shape w/o projection: ",
640             (*input_h)->shape().DebugString(), " ",
641             (*input_c)->shape().DebugString());
642       }
643     } else {
644       if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
645           num_proj != (*input_h)->dim_size(2) ||
646           (*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
647           (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
648         return errors::InvalidArgument(
649             "Invalid input_h and input_c w/ projection size: ", num_proj, " ",
650             (*input_h)->shape().DebugString(), " ",
651             (*input_c)->shape().DebugString());
652       }
653     }
654   } else {
655     // Dummy cell_state_shape. TODO(kaixih): remove the time_major branch
656     if (time_major) {
657       model_shapes->cell_state_shape =
658           TensorShape({model_shapes->dir_count * model_shapes->num_layers,
659                        model_shapes->batch_size, model_shapes->num_units});
660     } else {
661       model_shapes->cell_state_shape =
662           TensorShape({model_shapes->batch_size,
663                        model_shapes->dir_count * model_shapes->num_layers,
664                        model_shapes->num_units});
665     }
666     model_shapes->cell_num_units = 0;
667   }
668   if (time_major) {
669     model_shapes->output_shape =
670         TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
671                      model_shapes->dir_count * model_shapes->num_units});
672   } else {
673     model_shapes->output_shape =
674         TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
675                      model_shapes->dir_count * model_shapes->num_units});
676   }
677   return Status::OK();
678 }
679 
680 // Overloaded variant that additionally reads the sequence_lengths input.
681 Status ExtractForwardInput(OpKernelContext* context,
682                            const CudnnModelTypes& model_types, bool time_major,
683                            const Tensor** input, const Tensor** input_h,
684                            const Tensor** input_c, const Tensor** params,
685                            const Tensor** sequence_lengths, const int num_proj,
686                            CudnnRnnModelShapes* model_shapes) {
687   TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
688   return ExtractForwardInput(context, model_types, time_major, input, input_h,
689                              input_c, params, num_proj, model_shapes);
690 }
691 
692 template <typename T>
693 Status CreateForwardAndBackwardIODescriptors(
694     OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
695     std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
696     std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
697     std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
698     std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
699     const absl::Span<const int> seq_lengths, bool time_major) {
700   StreamExecutor* executor = context->op_device_context()->stream()->parent();
701   se::dnn::DataType data_type = ToDataType<T>::value;
702 
703   const TensorShape& input_shape = model_shapes.input_shape;
704   const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
705   const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
706   const TensorShape& output_shape = model_shapes.output_shape;
707 
708   DCHECK_EQ(input_shape.dims(), 3);
709   if (seq_lengths.data() != nullptr) {
710     if (time_major) {
711       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
712           input_shape.dim_size(0), input_shape.dim_size(1),
713           input_shape.dim_size(2), seq_lengths, time_major, data_type);
714       TF_RETURN_IF_ERROR(input_desc_s.status());
715       *input_desc = input_desc_s.ConsumeValueOrDie();
716     } else {
717       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
718           input_shape.dim_size(1), input_shape.dim_size(0),
719           input_shape.dim_size(2), seq_lengths, time_major, data_type);
720       TF_RETURN_IF_ERROR(input_desc_s.status());
721       *input_desc = input_desc_s.ConsumeValueOrDie();
722     }
723   } else {
724     auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
725         input_shape.dim_size(0), input_shape.dim_size(1),
726         input_shape.dim_size(2), data_type);
727     TF_RETURN_IF_ERROR(input_desc_s.status());
728     *input_desc = input_desc_s.ConsumeValueOrDie();
729   }
730 
731   DCHECK_EQ(hidden_state_shape.dims(), 3);
732   if (time_major) {
733     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
734         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
735         hidden_state_shape.dim_size(2), data_type);
736     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
737     *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
738   } else {
739     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
740         hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
741         hidden_state_shape.dim_size(2), data_type);
742     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
743     *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
744   }
745 
746   DCHECK_EQ(cell_state_shape.dims(), 3);
747   if (time_major) {
748     auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
749         cell_state_shape.dim_size(0), cell_state_shape.dim_size(1),
750         cell_state_shape.dim_size(2), data_type);
751     TF_RETURN_IF_ERROR(cell_state_desc_s.status());
752     *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
753   } else {
754     auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
755         cell_state_shape.dim_size(1), cell_state_shape.dim_size(0),
756         cell_state_shape.dim_size(2), data_type);
757     TF_RETURN_IF_ERROR(cell_state_desc_s.status());
758     *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
759   }
760 
761   DCHECK_EQ(output_shape.dims(), 3);
762   if (seq_lengths.data() != nullptr) {
763     if (time_major) {
764       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
765           output_shape.dim_size(0), output_shape.dim_size(1),
766           output_shape.dim_size(2), seq_lengths, time_major, data_type);
767       TF_RETURN_IF_ERROR(output_desc_s.status());
768       *output_desc = output_desc_s.ConsumeValueOrDie();
769     } else {
770       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
771           output_shape.dim_size(1), output_shape.dim_size(0),
772           output_shape.dim_size(2), seq_lengths, time_major, data_type);
773       TF_RETURN_IF_ERROR(output_desc_s.status());
774       *output_desc = output_desc_s.ConsumeValueOrDie();
775     }
776   } else {
777     auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
778         output_shape.dim_size(0), output_shape.dim_size(1),
779         output_shape.dim_size(2), data_type);
780     TF_RETURN_IF_ERROR(output_desc_s.status());
781     *output_desc = output_desc_s.ConsumeValueOrDie();
782   }
783 
784   return Status::OK();
785 }
786 
787 template <typename T>
788 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
789                  const CudnnModelTypes& model_types,
790                  const CudnnRnnModelShapes& model_shapes,
791                  /* forward inputs */
792                  const Tensor* input, const Tensor* input_h,
793                  const Tensor* input_c, const Tensor* params,
794                  const bool is_training,
795                  /* forward outputs, outputs of the function */
796                  Tensor* output, Tensor* output_h, Tensor* output_c,
797                  const Tensor* sequence_lengths, bool time_major,
798                  ScratchAllocator* reserve_space_allocator,
799                  ScratchAllocator* workspace_allocator,
800                  ProfileResult* output_profile_result) {
801   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
802   std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
803   std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
804   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
805 
806   absl::Span<const int> seq_lengths;
807   if (sequence_lengths != nullptr) {
808     seq_lengths = absl::Span<const int>(
809         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
810   }
811   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
812       context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
813       &output_desc, seq_lengths, time_major));
814 
815   auto input_data = AsDeviceMemory<T>(input);
816   auto input_h_data = AsDeviceMemory<T>(input_h);
817   DeviceMemory<T> input_c_data;
818   if (model_types.HasInputC()) {
819     input_c_data = AsDeviceMemory<T>(input_c);
820   }
821 
822   auto params_data = AsDeviceMemory<T>(params);
823   auto output_data = AsDeviceMemory<T>(output);
824   auto output_h_data = AsDeviceMemory<T>(output_h);
825   DeviceMemory<T> output_c_data;
826   if (model_types.HasInputC()) {
827     output_c_data = AsDeviceMemory<T>(output_c);
828   }
829 
830   Stream* stream = context->op_device_context()->stream();
831 
832   Tensor seq_lengths_tensor;
833   DeviceMemory<int> seq_lengths_ptr;
834   if (sequence_lengths != nullptr) {
835     TF_RETURN_IF_ERROR(context->allocate_temp(
836         DT_INT32, {static_cast<long>(seq_lengths.size())},
837         &seq_lengths_tensor));
838     seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
839     if (!stream
840              ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(),
841                           seq_lengths.size() * sizeof(int))
842              .ok()) {
843       return errors::InvalidArgument(
844           "Failed to copy memory from host to "
845           "device for sequence_lengths in "
846           "CudnnRNNV3");
847     }
848   }
849 
850   bool launch_success =
851       stream
852           ->ThenRnnForward(rnn_desc, *input_desc, input_data, seq_lengths_ptr,
853                            *h_state_desc, input_h_data, *c_state_desc,
854                            input_c_data, params_data, *output_desc,
855                            &output_data, *h_state_desc, &output_h_data,
856                            *c_state_desc, &output_c_data, is_training,
857                            reserve_space_allocator, workspace_allocator,
858                            output_profile_result)
859           .ok();
860   return launch_success
861              ? Status::OK()
862              : errors::Internal(
863                    "Failed to call ThenRnnForward with model config: ",
864                    model_types.DebugString(), ", ", model_shapes.DebugString());
865 }
866 
867 template <typename T>
868 Status DoBackward(
869     OpKernelContext* context, const RnnDescriptor& rnn_desc,
870     const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
871     /* forward inputs */
872     const Tensor* input, const Tensor* input_h, const Tensor* input_c,
873     const Tensor* params,
874     /* forward outputs */
875     const Tensor* output, const Tensor* output_h, const Tensor* output_c,
876     /* backprop inputs */
877     const Tensor* output_backprop, const Tensor* output_h_backprop,
878     const Tensor* output_c_backprop, const Tensor* reserve_space,
879     /* backprop outputs, output of the function */
880     Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
881     Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
882     ScratchAllocator* workspace_allocator,
883     ProfileResult* output_profile_result) {
884   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
885   std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
886   std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
887   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
888 
889   absl::Span<const int> seq_lengths;
890   if (sequence_lengths != nullptr) {
891     seq_lengths = absl::Span<const int>(
892         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
893   }
894   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
895       context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
896       &output_desc, seq_lengths, time_major));
897 
898   auto input_data = AsDeviceMemory<T>(input);
899   auto input_h_data = AsDeviceMemory<T>(input_h);
900   DeviceMemory<T> input_c_data;
901   if (model_types.HasInputC()) {
902     input_c_data = AsDeviceMemory<T>(input_c);
903   }
904   auto params_data = AsDeviceMemory<T>(params);
905   auto output_data = AsDeviceMemory<T>(output);
906   auto output_h_data = AsDeviceMemory<T>(output_h);
907   DeviceMemory<T> output_c_data;
908   if (model_types.HasInputC()) {
909     output_c_data = AsDeviceMemory<T>(output_c);
910   }
911   auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
912   auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
913   DeviceMemory<T> output_c_backprop_data;
914   if (model_types.HasInputC()) {
915     output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
916   }
917   auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
918   auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
919   DeviceMemory<T> input_c_backprop_data;
920   if (model_types.HasInputC()) {
921     input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
922   }
923   auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
924   auto reserve_space_uint8 =
925       CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
926 
927   // Creates a memory callback for the workspace. The memory lives until the
928   // end of this kernel call.
929   Stream* stream = context->op_device_context()->stream();
930 
931   Tensor seq_lengths_tensor;
932   DeviceMemory<int> seq_lengths_ptr;
933   if (sequence_lengths != nullptr) {
934     TF_RETURN_IF_ERROR(context->allocate_temp(
935         DT_INT32, {static_cast<long>(seq_lengths.size())},
936         &seq_lengths_tensor));
937     seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
938     if (!stream
939              ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(),
940                           seq_lengths.size() * sizeof(int))
941              .ok()) {
942       return errors::InvalidArgument(
943           "Failed to copy memory from host to "
944           "device for sequence_lengths in "
945           "CudnnRNNBackwardOpV3");
946     }
947   }
948 
949   bool launch_success =
950       stream
951           ->ThenRnnBackward(
952               rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc,
953               input_h_data, *c_state_desc, input_c_data, params_data,
954               *output_desc, output_data, *h_state_desc, output_h_data,
955               *c_state_desc, output_c_data, output_backprop_data,
956               output_h_backprop_data, output_c_backprop_data,
957               &input_backprop_data, &input_h_backprop_data,
958               &input_c_backprop_data, &params_backprop_data,
959               &reserve_space_uint8, workspace_allocator, output_profile_result)
960           .ok();
961   return launch_success
962              ? Status::OK()
963              : errors::Internal(
964                    "Failed to call ThenRnnBackward with model config: ",
965                    model_types.DebugString(), ", ", model_shapes.DebugString());
966 }
967 
968 template <typename T>
969 void RestoreParams(const OpInputList params_input,
970                    const std::vector<RnnDescriptor::ParamsRegion>& params,
971                    DeviceMemoryBase* data_dst, Stream* stream) {
972   int num_params = params.size();
973   CHECK(params_input.size() == num_params)
974       << "Number of params mismatch. Expected " << params_input.size()
975       << ", got " << num_params;
976   for (int i = 0; i < params.size(); i++) {
977     int64_t size_in_bytes = params[i].size;
978     int64_t size = size_in_bytes / sizeof(T);
979     CHECK(size == params_input[i].NumElements())
980         << "Params size mismatch. Expected " << size << ", got "
981         << params_input[i].NumElements();
982     auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
983     DeviceMemoryBase data_dst_ptr =
984         SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
985     stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
986   }
987 }
988 
989 bool ShouldUsePaddedIO(const Tensor* sequence_lengths,
990                        const CudnnRnnModelShapes& model_shapes,
991                        bool time_major) {
992   auto seq_array = sequence_lengths->template flat<int>().data();
993   bool all_max_seq_length = true;
994   for (int i = 0; i < model_shapes.batch_size; i++) {
995     if (seq_array[i] != model_shapes.max_seq_length) {
996       all_max_seq_length = false;
997       break;
998     }
999   }
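  // Padded IO can be skipped only when the input is time-major and every
  // sequence is exactly max_seq_length long.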
1000   return !(time_major && all_max_seq_length);
1001 }
1002 
1003 }  // namespace
1004 
1005 // Note: all of the following kernels depend on a RnnDescriptor instance, which
1006 // according to the official Cudnn documentation should be kept around and
1007 // reused across all Cudnn kernels in the same model.
1008 // In TensorFlow, we don't pass the reference across different OpKernels;
1009 // rather, we recreate it separately in each OpKernel. This does not cause
1010 // issues: CudnnDropoutDescriptor keeps a reference to memory for the
1011 // random number generator state. During recreation, this state is lost.
1012 // However, only forward-pass Cudnn APIs make use of the state.
1013 
1014 // A common base class for RNN kernels. It extracts common attributes and
1015 // performs shape validations.
1016 class CudnnRNNKernelCommon : public OpKernel {
1017  protected:
1018   explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
1019       : OpKernel(context) {
1020     OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
1021     OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
1022     OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
1023     string str;
1024     OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
1025     OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
1026     OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
1027     OP_REQUIRES_OK(context,
1028                    ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
1029     OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
1030     OP_REQUIRES_OK(
1031         context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
1032     // Reset the CudnnRnnDescriptor and related random number generator state in
1033     // every Compute() call.
1034     OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
1035                                                false, &reset_rnd_gen_state_));
1036   }
1037 
1038   bool HasInputC() const { return model_types_.HasInputC(); }
1039   RnnMode rnn_mode() const { return model_types_.rnn_mode; }
1040   TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
1041   RnnDirectionMode rnn_direction_mode() const {
1042     return model_types_.rnn_direction_mode;
1043   }
1044   const CudnnModelTypes& model_types() const { return model_types_; }
1045   float dropout() const { return dropout_; }
1046   uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
1047   bool ResetRndGenState() { return reset_rnd_gen_state_; }
1048 
1049   template <typename T>
1050   Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
1051                                    std::unique_ptr<RnnDescriptor>* rnn_desc) {
1052     const Tensor* num_layers_t = nullptr;
1053     TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
1054     if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
1055       return errors::InvalidArgument("num_layers is not a scalar");
1056     }
1057     int num_layers = num_layers_t->scalar<int>()();
1058     const Tensor* num_units_t = nullptr;
1059     TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
1060     if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
1061       return errors::InvalidArgument("num_units is not a scalar");
1062     }
1063     int num_units = num_units_t->scalar<int>()();
1064     const Tensor* input_size_t = nullptr;
1065     TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
1066     if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
1067       return errors::InvalidArgument("input_size is not a scalar");
1068     }
1069     int input_size = input_size_t->scalar<int>()();
1070 
1071     int h_num_units = (num_proj == 0 ? num_units : num_proj);
1072     int c_num_units = (num_proj == 0 ? 0 : num_units);
1073 
1074     RnnInputMode input_mode;
1075     TF_RETURN_IF_ERROR(
1076         ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
1077 
1078     Stream* stream = context->op_device_context()->stream();
1079     // ExtractCudnnRNNParamsInfo is only called by op_kernels that do not require
1080     // a random number generator, therefore set state_allocator to nullptr.
1081     const AlgorithmConfig algo_config;
1082     auto rnn_desc_s = stream->parent()->createRnnDescriptor(
1083         num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
1084         /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
1085         ToDataType<T>::value, algo_config, dropout(), seed(),
1086         /* state_allocator=*/nullptr, /*use_padded_io=*/false);
1087     if (!rnn_desc_s.ok()) {
1088       return FromExecutorStatus(rnn_desc_s);
1089     }
1090     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
1091     return Status::OK();
1092   }
1093 
1094   template <typename T>
1095   Status CreateRnnDescriptor(OpKernelContext* context,
1096                              const CudnnRnnModelShapes& model_shapes,
1097                              const RnnInputMode& input_mode,
1098                              const AlgorithmConfig& algo_config,
1099                              ScratchAllocator* dropout_state_allocator,
1100                              std::unique_ptr<RnnDescriptor>* rnn_desc,
1101                              bool use_padded_io) {
1102     StreamExecutor* executor = context->op_device_context()->stream()->parent();
1103     se::dnn::DataType data_type = ToDataType<T>::value;
1104     auto rnn_desc_s = executor->createRnnDescriptor(
1105         model_shapes.num_layers, model_shapes.num_units,
1106         model_shapes.input_size, model_shapes.cell_num_units,
1107         model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
1108         data_type, algo_config, dropout(), seed(), dropout_state_allocator,
1109         use_padded_io);
1110     TF_RETURN_IF_ERROR(rnn_desc_s.status());
1111 
1112     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
1113     return Status::OK();
1114   }
1115 
1116   using RnnStateCache = gtl::FlatMap<
1117       std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
1118       RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
1119   // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
1120   // should outlive the returned pointer.
1121   template <typename T>
1122   Status GetCachedRnnDescriptor(OpKernelContext* context,
1123                                 const CudnnRnnModelShapes& model_shapes,
1124                                 const RnnInputMode& input_mode,
1125                                 const AlgorithmConfig& algo_config,
1126                                 RnnStateCache* cache, RnnDescriptor** rnn_desc,
1127                                 bool use_padded_io) {
1128     auto key = std::make_pair(model_shapes, algo_config.algorithm());
1129     RnnScratchSpace& rnn_state = (*cache)[key];
1130     if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
1131       CudnnRNNSpaceAllocator* dropout_state_allocator =
1132           new CudnnRNNSpaceAllocator(context);
1133       rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
1134       Status status = CreateRnnDescriptor<T>(
1135           context, model_shapes, input_mode, algo_config,
1136           dropout_state_allocator, &rnn_state.rnn_desc, use_padded_io);
1137       TF_RETURN_IF_ERROR(status);
1138     }
1139     *rnn_desc = rnn_state.rnn_desc.get();
1140     return Status::OK();
1141   }
1142 
1143  private:
1144   int seed_;
1145   int seed2_;
1146   float dropout_;
1147   bool reset_rnd_gen_state_;
1148 
1149   CudnnModelTypes model_types_;
1150 };
1151 
1152 // A class that returns the size of the opaque parameter buffer. The user should
1153 // use it to create the actual parameter buffer for training. However, it
1154 // should not be used for saving and restoring.
1155 template <typename T, typename Index>
1156 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
1157  public:
1158   explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
1159       : CudnnRNNKernelCommon(context) {
1160     if (context->HasAttr("num_proj")) {
1161       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1162     } else {
1163       num_proj_ = 0;
1164     }
1165   }
1166 
1167   void Compute(OpKernelContext* context) override {
1168     std::unique_ptr<RnnDescriptor> rnn_desc;
1169     OP_REQUIRES_OK(context,
1170                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1171     int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1172     CHECK(params_size_in_bytes % sizeof(T) == 0)
1173         << "params_size_in_bytes must be multiple of element size";
1174     int64_t params_size = params_size_in_bytes / sizeof(T);
1175 
1176     Tensor* output_t = nullptr;
1177     OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
1178     *output_t->template flat<Index>().data() = params_size;
1179   }
1180 
1181  private:
1182   int num_proj_;
1183 };
1184 
1185 #define REGISTER_GPU(T)                                    \
1186   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")       \
1187                               .Device(DEVICE_GPU)          \
1188                               .HostMemory("num_layers")    \
1189                               .HostMemory("num_units")     \
1190                               .HostMemory("input_size")    \
1191                               .HostMemory("params_size")   \
1192                               .TypeConstraint<T>("T")      \
1193                               .TypeConstraint<int32>("S"), \
1194                           CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
1195 
1196 TF_CALL_half(REGISTER_GPU);
1197 TF_CALL_float(REGISTER_GPU);
1198 TF_CALL_double(REGISTER_GPU);
1199 #undef REGISTER_GPU
1200 
1201 // Convert weight and bias params from a platform-specific layout to the
1202 // canonical form.
1203 template <typename T>
1204 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
1205  public:
1206   explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
1207       : CudnnRNNKernelCommon(context) {
1208     if (context->HasAttr("num_params")) {
1209       OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
1210     } else {
1211       num_params_ = 0;
1212     }
1213     if (context->HasAttr("num_params_weights")) {
1214       OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
1215                                                &num_params_weights_));
1216     } else {
1217       num_params_weights_ = 0;
1218     }
1219     if (context->HasAttr("num_params_biases")) {
1220       OP_REQUIRES_OK(
1221           context, context->GetAttr("num_params_biases", &num_params_biases_));
1222     } else {
1223       num_params_biases_ = 0;
1224     }
1225     if (context->HasAttr("num_proj")) {
1226       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1227     } else {
1228       num_proj_ = 0;
1229     }
1230     if (num_proj_ == 0) {
1231       num_params_weights_ = num_params_;
1232       num_params_biases_ = num_params_;
1233     }
1234   }
1235 
1236   void Compute(OpKernelContext* context) override {
1237     const Tensor& input = context->input(3);
1238     auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
1239     Stream* stream = context->op_device_context()->stream();
1240 
1241     std::unique_ptr<RnnDescriptor> rnn_desc;
1242     OP_REQUIRES_OK(context,
1243                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1244     int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1245     CHECK(params_size_in_bytes % sizeof(T) == 0)
1246         << "params_size_in_bytes must be a multiple of the element size";
1247 
1248     const Tensor* num_units_t = nullptr;
1249     OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
1250     CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
1251         << "num_units is not a scalar";
1252     int num_units = num_units_t->scalar<int>()();
1253 
1254     const Tensor* input_size_t = nullptr;
1255     OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
1256     CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
1257         << "input_size is not a scalar";
1258     int input_size = input_size_t->scalar<int>()();
1259 
1260     const Tensor* num_layers_t = nullptr;
1261     OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
1262     CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
1263         << "num_layers is not a scalar";
1264     int num_layers = num_layers_t->scalar<int>()();
1265     int num_dirs = 1;
1266     if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
1267       num_dirs = 2;
1268     }
1269     const int num_params_weights_per_layer =
1270         num_params_weights_ / num_layers / num_dirs;
1271     // Number of params applied on inputs. The rest are applied on recurrent
1272     // hidden states.
1273     const int num_params_input_state = num_params_weights_per_layer / 2;
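    // E.g., for an LSTM without projection, cuDNN exposes 8 weight regions per
    // layer per direction: 4 applied to the layer input and 4 applied to the
    // recurrent hidden state, so num_params_input_state would be 4. (Counts are
    // illustrative; the authoritative values come from
    // rnn_desc->ParamsWeightRegions().)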
1274     OP_REQUIRES(
1275         context, num_params_weights_ % (num_layers * num_dirs) == 0,
1276         errors::InvalidArgument("Number of params (weights) is not a multiple "
1277                                 "of num_layers * num_dirs."));
1278     OP_REQUIRES(
1279         context, num_params_biases_ % (num_layers * num_dirs) == 0,
1280         errors::InvalidArgument("Number of params (biases) is not a multiple "
1281                                 "of num_layers * num_dirs."));
1282     if (num_proj_ == 0) {
1283       OP_REQUIRES(
1284           context, num_params_weights_per_layer % 2 == 0,
1285           errors::InvalidArgument("Number of params (weights) per layer is not "
1286                                   "an even number with no projection."));
1287     } else {
1288       OP_REQUIRES(
1289           context, num_params_weights_per_layer % 2 != 0,
1290           errors::InvalidArgument("Number of params (weights) per layer is not "
1291                                   "an odd number with projection."));
1292     }
1293 
1294     OP_REQUIRES(
1295         context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(),
1296         errors::InvalidArgument("Number of params (weights) mismatch. Expected ",
1297                                 num_params_weights_, ", got ",
1298                                 rnn_desc->ParamsWeightRegions().size()));
1299     int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
1300     int c_num_units = (num_proj_ == 0 ? 0 : num_units);
1301     for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
1302       int64_t size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
1303       int64_t size = size_in_bytes / sizeof(T);
1304       const int layer_idx = i / num_params_weights_per_layer;
1305       const int index_within_layer = i % num_params_weights_per_layer;
1306       int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
1307       // In the cuDNN layout, each layer has num_params_weights_per_layer
1308       // params, with the first half (i.e. num_params_input_state params)
1309       // applied to the inputs, and the second half applied to the recurrent
1310       // hidden states.
1311       bool apply_on_input_state = index_within_layer < num_params_input_state;
1312       if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
1313         if (layer_idx == 0 && apply_on_input_state) {
1314           width = input_size;
1315         } else {
1316           width = h_num_units;
1317         }
1318       } else {
1319         if (apply_on_input_state) {
1320           if (layer_idx <= 1) {
1321             // First forward or backward layer.
1322             width = input_size;
1323           } else {
1324             // For the following layers, the cell inputs are the
1325             // concatenated outputs of the previous layer.
1326             width = 2 * h_num_units;
1327           }
1328         } else {
1329           width = h_num_units;
1330         }
1331       }
1332       CHECK(size == width * height) << "Params size mismatch. Expected "
1333                                     << width * height << ", got " << size;
1334       Tensor* output = nullptr;
1335       int id_in_layer = i % num_params_weights_per_layer;
1336       if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) {
1337         std::swap(height, width);
1338       }
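      // Explanatory note: with projection, the last weight region in each
      // layer is the recurrent projection matrix, whose canonical shape is the
      // transpose of the layout used above, hence the swap.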
1339       OP_REQUIRES_OK(context, context->allocate_output(
1340                                   i, TensorShape({height, width}), &output));
1341       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1342           input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
1343       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1344       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1345     }
1346 
1347     OP_REQUIRES(
1348         context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
1349         errors::InvalidArgument("Number of params (biases) mismatch. Expected ",
1350                                 num_params_biases_, ", got ",
1351                                 rnn_desc->ParamsBiasRegions().size()));
1352     for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
1353       int64_t size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
1354       int64_t size = size_in_bytes / sizeof(T);
1355       OP_REQUIRES(context, size == num_units,
1356                   errors::InvalidArgument("Params size mismatch. Expected ",
1357                                           num_units, ", got ", size));
1358 
1359       Tensor* output = nullptr;
1360       OP_REQUIRES_OK(context,
1361                      context->allocate_output(num_params_weights_ + i,
1362                                               TensorShape({size}), &output));
1363       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1364           input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
1365       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1366       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1367     }
1368   }
1369 
1370  private:
1371   int num_params_;
1372   int num_params_weights_;
1373   int num_params_biases_;
1374   int num_proj_;
1375 };
1376 
1377 #define REGISTER_GPU(T)                                     \
1378   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
1379                               .Device(DEVICE_GPU)           \
1380                               .HostMemory("num_layers")     \
1381                               .HostMemory("num_units")      \
1382                               .HostMemory("input_size")     \
1383                               .TypeConstraint<T>("T"),      \
1384                           CudnnRNNParamsToCanonical<GPUDevice, T>);
1385 TF_CALL_half(REGISTER_GPU);
1386 TF_CALL_float(REGISTER_GPU);
1387 TF_CALL_double(REGISTER_GPU);
1388 #undef REGISTER_GPU
1389 
1390 #define REGISTER_GPU(T)                                       \
1391   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
1392                               .Device(DEVICE_GPU)             \
1393                               .HostMemory("num_layers")       \
1394                               .HostMemory("num_units")        \
1395                               .HostMemory("input_size")       \
1396                               .TypeConstraint<T>("T"),        \
1397                           CudnnRNNParamsToCanonical<GPUDevice, T>);
1398 TF_CALL_half(REGISTER_GPU);
1399 TF_CALL_float(REGISTER_GPU);
1400 TF_CALL_double(REGISTER_GPU);
1401 #undef REGISTER_GPU
1402 
1403 // Convert weight and bias params from the canonical form to a
1404 // platform-specific layout.
1405 template <typename T>
1406 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
1407  public:
1408   explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
1409       : CudnnRNNKernelCommon(context) {
1410     if (context->HasAttr("num_proj")) {
1411       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1412     } else {
1413       num_proj_ = 0;
1414     }
1415   }
1416 
1417   void Compute(OpKernelContext* context) override {
1418     std::unique_ptr<RnnDescriptor> rnn_desc;
1419     OP_REQUIRES_OK(context,
1420                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1421     int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1422     CHECK(params_size_in_bytes % sizeof(T) == 0)
1423         << "params_size_in_bytes must be a multiple of the element size";
1424     Tensor* output = nullptr;
1425     int params_size = params_size_in_bytes / sizeof(T);
1426     OP_REQUIRES_OK(context,
1427                    context->allocate_output(0, {params_size}, &output));
1428     auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1429     Stream* stream = context->op_device_context()->stream();
1430 
1431     OpInputList weights;
1432     OP_REQUIRES_OK(context, context->input_list("weights", &weights));
1433     RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
1434                      stream);
1435 
1436     OpInputList biases;
1437     OP_REQUIRES_OK(context, context->input_list("biases", &biases));
1438     RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
1439                      stream);
1440   }
1441 
1442  private:
1443   int num_proj_;
1444 };
1445 
1446 #define REGISTER_GPU(T)                                     \
1447   REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
1448                               .Device(DEVICE_GPU)           \
1449                               .HostMemory("num_layers")     \
1450                               .HostMemory("num_units")      \
1451                               .HostMemory("input_size")     \
1452                               .TypeConstraint<T>("T"),      \
1453                           CudnnRNNCanonicalToParams<GPUDevice, T>);
1454 TF_CALL_half(REGISTER_GPU);
1455 TF_CALL_float(REGISTER_GPU);
1456 TF_CALL_double(REGISTER_GPU);
1457 #undef REGISTER_GPU
1458 
1459 #define REGISTER_GPU(T)                                       \
1460   REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
1461                               .Device(DEVICE_GPU)             \
1462                               .HostMemory("num_layers")       \
1463                               .HostMemory("num_units")        \
1464                               .HostMemory("input_size")       \
1465                               .TypeConstraint<T>("T"),        \
1466                           CudnnRNNCanonicalToParams<GPUDevice, T>);
1467 TF_CALL_half(REGISTER_GPU);
1468 TF_CALL_float(REGISTER_GPU);
1469 TF_CALL_double(REGISTER_GPU);
1470 #undef REGISTER_GPU
1471 
1472 // Run the forward operation of the RNN model.
1473 template <typename T>
1474 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1475  public:
1476   explicit CudnnRNNForwardOp(OpKernelConstruction* context)
1477       : CudnnRNNKernelCommon(context) {
1478     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1479 
1480     // Read debug env variables.
1481     is_debug_mode_ = DebugCudnnRnn();
1482     debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
1483     debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
1484   }
1485 
1486   void Compute(OpKernelContext* context) override {
1487     AlgorithmConfig algo_config;
1488     ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
1489                               /*time_major=*/true, /*num_proj=*/0);
1490   }
1491 
1492  protected:
1493   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
1494                                          AlgorithmConfig* output_algo_config,
1495                                          bool var_seq_lengths, bool time_major,
1496                                          int num_proj) {
1497     CHECK_NE(output_algo_config, nullptr);
1498 
1499     const Tensor* input = nullptr;
1500     const Tensor* input_h = nullptr;
1501     const Tensor* input_c = nullptr;
1502     const Tensor* params = nullptr;
1503     const Tensor* sequence_lengths = nullptr;
1504     CudnnRnnModelShapes model_shapes;
1505     bool use_padded_io = false;
1506     if (var_seq_lengths) {
1507       OP_REQUIRES_OK(context, ExtractForwardInput(
1508                                   context, model_types(), time_major, &input,
1509                                   &input_h, &input_c, &params,
1510                                   &sequence_lengths, num_proj, &model_shapes));
1511       use_padded_io =
1512           ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1513     } else {
1514       OP_REQUIRES_OK(context,
1515                      ExtractForwardInput(context, model_types(), time_major,
1516                                          &input, &input_h, &input_c, &params,
1517                                          num_proj, &model_shapes));
1518     }
1519     RnnInputMode input_mode;
1520     OP_REQUIRES_OK(context,
1521                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1522                                   model_shapes.input_size, &input_mode));
1523 
1524     Tensor* output = nullptr;
1525     Tensor* output_h = nullptr;
1526     Tensor* output_c = nullptr;
1527     OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
1528                                             &output_h, &output_c));
1529 
1530     // Creates a memory callback for the reserve_space. The memory lives in the
1531     // output of this kernel, and it will be fed into the backward pass when
1532     // needed.
1533     CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
1534     // Creates a memory callback for the workspace. The memory lives until the
1535     // end of this kernel call.
1536     CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1537 
1538     if (is_debug_mode_) {
1539       AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
1540       output_algo_config->set_algorithm(algo_desc);
1541     } else {
1542       OP_REQUIRES_OK(context,
1543                      MaybeAutotune(context, model_shapes, input_mode, input,
1544                                    input_h, input_c, params, output, output_h,
1545                                    output_c, output_algo_config));
1546     }
1547 
1548     Status launch_status;
1549     {
1550       mutex_lock l(mu_);
1551       RnnDescriptor* rnn_desc_ptr = nullptr;
1552       OP_REQUIRES_OK(context,
1553                      GetCachedRnnDescriptor<T>(
1554                          context, model_shapes, input_mode, *output_algo_config,
1555                          &rnn_state_cache_, &rnn_desc_ptr, use_padded_io));
1556       launch_status = DoForward<T>(
1557           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1558           input_c, params, is_training_, output, output_h, output_c,
1559           sequence_lengths, time_major, &reserve_space_allocator,
1560           &workspace_allocator, /*output_profile_result=*/nullptr);
1561     }
1562     OP_REQUIRES_OK(context, launch_status);
1563   }
1564 
1565  protected:
1566   virtual Status MaybeAutotune(OpKernelContext* context,
1567                                const CudnnRnnModelShapes& model_shapes,
1568                                const RnnInputMode& input_mode,
1569                                const Tensor* input, const Tensor* input_h,
1570                                const Tensor* input_c, const Tensor* params,
1571                                Tensor* output, Tensor* output_h,
1572                                Tensor* output_c,
1573                                AlgorithmConfig* best_algo_config) {
1574     CHECK_NE(best_algo_config, nullptr);
1575     *best_algo_config = AlgorithmConfig();
1576     return Status::OK();
1577   }
1578 
1579   bool is_training() const { return is_training_; }
1580   bool is_debug_mode_;
1581   bool debug_use_tensor_ops_;
1582   int64 debug_cudnn_rnn_algo_;
1583 
1584  private:
1585   Status AllocateOutputs(OpKernelContext* context,
1586                          const CudnnRnnModelShapes& model_shapes,
1587                          Tensor** output, Tensor** output_h,
1588                          Tensor** output_c) {
1589     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1590     const TensorShape& output_shape = model_shapes.output_shape;
1591     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
1592 
1593     TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
1594     TF_RETURN_IF_ERROR(
1595         context->allocate_output(1, hidden_state_shape, output_h));
1596     if (HasInputC()) {
1597       TF_RETURN_IF_ERROR(
1598           context->allocate_output(2, cell_state_shape, output_c));
1599     } else {
1600       // Only LSTM uses input_c and output_c. So for all other models, we only
1601       // need to create dummy outputs.
1602       TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
1603     }
1604     if (!is_training_) {
1605       Tensor* dummy_reserve_space = nullptr;
1606       TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
1607     }
1608     return Status::OK();
1609   }
1610 
1611   mutex mu_;
1612   bool is_training_;
1613   RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1614 };
1615 
1616 #define REGISTER_GPU(T)                                           \
1617   REGISTER_KERNEL_BUILDER(                                        \
1618       Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1619       CudnnRNNForwardOp<GPUDevice, T>);
1620 
1621 TF_CALL_half(REGISTER_GPU);
1622 TF_CALL_float(REGISTER_GPU);
1623 TF_CALL_double(REGISTER_GPU);
1624 #undef REGISTER_GPU
1625 
1626 template <typename T>
1627 class CudnnRNNForwardOpV2<GPUDevice, T>
1628     : public CudnnRNNForwardOp<GPUDevice, T> {
1629  private:
1630   using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1631   using CudnnRNNKernelCommon::CreateRnnDescriptor;
1632   using CudnnRNNKernelCommon::dropout;
1633   using CudnnRNNKernelCommon::HasInputC;
1634   using CudnnRNNKernelCommon::model_types;
1635 
1636  public:
1637   explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
1638       : CudnnRNNForwardOp<GPUDevice, T>(context) {}
1639 
1640   void Compute(OpKernelContext* context) override {
1641     AlgorithmConfig best_algo_config;
1642     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1643         context, &best_algo_config, /*var_seq_lengths=*/false,
1644         /*time_major=*/true, /*num_proj=*/0);
1645     if (!context->status().ok()) {
1646       return;
1647     }
1648 
1649     Tensor* output_host_reserved = nullptr;
1650     // output_host_reserved stores opaque info used for backprop when running
1651     // in training mode. At present, it includes a serialization of the best
1652     // AlgorithmDesc picked during rnn forward pass autotune.
1653     // int8 algorithm_id
1654     // int8 use_tensor_op
1655     // If autotune is not enabled, the algorithm_id is
1656     // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
1657     // running in inference mode, the output_host_reserved is currently not
1658     // populated.
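    // These two entries are read back by CudnnRNNBackwardOpV2::GetAlgorithm()
    // below to reconstruct the AlgorithmDesc used in the forward pass.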
1659     if (is_training()) {
1660       OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
1661                                                        &output_host_reserved));
1662       auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
1663       output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
1664       output_host_reserved_int8(1) =
1665           best_algo_config.algorithm()->tensor_ops_enabled();
1666     } else {
1667       OP_REQUIRES_OK(context,
1668                      context->allocate_output(4, {}, &output_host_reserved));
1669     }
1670   }
1671 
1672  protected:
1673   Status MaybeAutotune(OpKernelContext* context,
1674                        const CudnnRnnModelShapes& model_shapes,
1675                        const RnnInputMode& input_mode, const Tensor* input,
1676                        const Tensor* input_h, const Tensor* input_c,
1677                        const Tensor* params, Tensor* output, Tensor* output_h,
1678                        Tensor* output_c,
1679                        AlgorithmConfig* algo_config) override {
1680     CHECK_NE(algo_config, nullptr);
1681     if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
1682       *algo_config = AlgorithmConfig();
1683       return Status::OK();
1684     }
1685 
1686     std::vector<AlgorithmDesc> algorithms;
1687     auto* stream = context->op_device_context()->stream();
1688     CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
1689     if (algorithms.empty()) {
1690       LOG(WARNING) << "No Rnn algorithm found";
1691       return Status::OK();
1692     }
1693 
1694     const auto& modeltypes = model_types();
1695     CudnnRnnParameters rnn_params(
1696         model_shapes.num_layers, model_shapes.input_size,
1697         model_shapes.num_units, model_shapes.max_seq_length,
1698         model_shapes.batch_size, model_shapes.dir_count,
1699         /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
1700         modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
1701 
1702     if (AutotuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
1703       VLOG(1) << "Using existing best Cudnn RNN algorithm "
1704               << "(algo, tensor_op_enabled) = ("
1705               << algo_config->algorithm()->algo_id() << ", "
1706               << algo_config->algorithm()->tensor_ops_enabled() << ").";
1707       return Status::OK();
1708     }
1709     profiler::ScopedAnnotation trace("cudnn_autotuning");
1710 
1711     // Create temp tensors when profiling backprop pass.
1712     auto data_type = input->dtype();
1713     Tensor output_backprop;
1714     Tensor output_h_backprop;
1715     Tensor output_c_backprop;
1716     Tensor input_backprop;
1717     Tensor input_h_backprop;
1718     Tensor input_c_backprop;
1719     Tensor params_backprop;
1720     if (is_training()) {
1721       TF_RETURN_IF_ERROR(context->allocate_temp(
1722           data_type, model_shapes.output_shape, &output_backprop));
1723       TF_RETURN_IF_ERROR(context->allocate_temp(
1724           data_type, model_shapes.hidden_state_shape, &output_h_backprop));
1725 
1726       TF_RETURN_IF_ERROR(
1727           context->allocate_temp(data_type, params->shape(), &params_backprop));
1728       TF_RETURN_IF_ERROR(context->allocate_temp(
1729           data_type, model_shapes.input_shape, &input_backprop));
1730       TF_RETURN_IF_ERROR(context->allocate_temp(
1731           data_type, model_shapes.hidden_state_shape, &input_h_backprop));
1732       if (HasInputC()) {
1733         TF_RETURN_IF_ERROR(context->allocate_temp(
1734             data_type, model_shapes.hidden_state_shape, &output_c_backprop));
1735         TF_RETURN_IF_ERROR(context->allocate_temp(
1736             data_type, model_shapes.hidden_state_shape, &input_c_backprop));
1737       }
1738     }
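    // When training, each candidate algorithm is timed over both the forward
    // and backward passes (their times are summed below), so the selection
    // reflects the cost of a full training step.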
1739     ProfileResult best_result;
1740     for (auto& algo : algorithms) {
1741       VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1742               << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
1743       Status status;
1744       ProfileResult final_profile_result;
1745 
1746       ProfileResult fwd_profile_result;
1747       ProfileResult bak_profile_result;
1748 
1749       // RnnDescriptor is algorithm-dependent, thus not reusable.
1750       std::unique_ptr<RnnDescriptor> rnn_desc;
1751       // Use a temp scratch allocator for the random num generator.
1752       CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
1753       if (!this->template CreateRnnDescriptor<T>(
1754                    context, model_shapes, input_mode, AlgorithmConfig(algo),
1755                    &dropout_state_allocator, &rnn_desc,
1756                    /*use_padded_io=*/false)
1757                .ok()) {
1758         continue;
1759       }
1760 
1761       // Again use temp scratch allocator during profiling.
1762       CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
1763       CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1764       status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
1765                             input, input_h, input_c, params, is_training(),
1766                             output, output_h, output_c, nullptr, true,
1767                             &reserve_space_allocator, &workspace_allocator,
1768                             &fwd_profile_result);
1769       if (!status.ok()) {
1770         continue;
1771       }
1772 
1773       if (is_training()) {
1774         // Get reserve space from the forward pass.
1775         Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
1776         status = DoBackward<T>(
1777             context, *rnn_desc, model_types(), model_shapes, input, input_h,
1778             input_c, params, output, output_h, output_c, &output_backprop,
1779             &output_h_backprop, &output_c_backprop, &reserve_space,
1780             &input_backprop, &input_h_backprop, &input_c_backprop,
1781             &params_backprop, nullptr, true, &workspace_allocator,
1782             &bak_profile_result);
1783         if (!status.ok()) {
1784           continue;
1785         }
1786         final_profile_result.set_elapsed_time_in_ms(
1787             fwd_profile_result.elapsed_time_in_ms() +
1788             bak_profile_result.elapsed_time_in_ms());
1789       } else {
1790         final_profile_result = fwd_profile_result;
1791       }
1792 
1793       auto total_time = final_profile_result.elapsed_time_in_ms();
1794       VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1795               << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
1796               << " run time: " << total_time << " ms.";
1797       if (total_time < best_result.elapsed_time_in_ms()) {
1798         best_result.set_elapsed_time_in_ms(total_time);
1799         best_result.set_algorithm(algo);
1800       }
1801     }
1802 
1803     if (!best_result.is_valid()) {
1804       return Status(error::Code::INTERNAL, "No algorithm worked!");
1805     }
1806     algo_config->set_algorithm(best_result.algorithm());
1807     VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1808             << best_result.algorithm().algo_id() << ", "
1809             << best_result.algorithm().tensor_ops_enabled() << ").";
1810     AutotuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
1811     return Status::OK();
1812   }
1813 };
1814 
1815 #define REGISTER_GPU(T)                                    \
1816   REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2")               \
1817                               .Device(DEVICE_GPU)          \
1818                               .HostMemory("host_reserved") \
1819                               .TypeConstraint<T>("T"),     \
1820                           CudnnRNNForwardOpV2<GPUDevice, T>);
1821 
1822 TF_CALL_half(REGISTER_GPU);
1823 TF_CALL_float(REGISTER_GPU);
1824 TF_CALL_double(REGISTER_GPU);
1825 #undef REGISTER_GPU
1826 
1827 template <typename T>
1828 class CudnnRNNForwardOpV3<GPUDevice, T>
1829     : public CudnnRNNForwardOp<GPUDevice, T> {
1830  private:
1831   using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1832   using CudnnRNNKernelCommon::CreateRnnDescriptor;
1833   using CudnnRNNKernelCommon::dropout;
1834   using CudnnRNNKernelCommon::HasInputC;
1835   using CudnnRNNKernelCommon::model_types;
1836   bool time_major_;
1837 
1838  protected:
1839   bool time_major() { return time_major_; }
1840 
1841  public:
1842   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
1843       : CudnnRNNForwardOp<GPUDevice, T>(context) {
1844     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
1845     if (context->HasAttr("num_proj")) {
1846       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1847     } else {
1848       num_proj_ = 0;
1849     }
1850   }
1851 
1852   void Compute(OpKernelContext* context) override {
1853     AlgorithmConfig best_algo_config;
1854     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1855         context, &best_algo_config, /*var_seq_lengths=*/true,
1856         /*time_major=*/time_major(), num_proj_);
1857     if (!context->status().ok()) {
1858       return;
1859     }
1860 
1861     Tensor* output_host_reserved = nullptr;
1862     // TODO: Currently V3 only uses the default standard algorithm to process
1863     // batches with variable sequence lengths, and the inputs should be padded.
1864     // Autotune is not supported yet.
1865     OP_REQUIRES_OK(context,
1866                    context->allocate_output(4, {}, &output_host_reserved));
1867   }
1868 
1869  private:
1870   int num_proj_;
1871 };
1872 
1873 #define REGISTER_GPU(T)                                       \
1874   REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3")                  \
1875                               .Device(DEVICE_GPU)             \
1876                               .HostMemory("sequence_lengths") \
1877                               .HostMemory("host_reserved")    \
1878                               .TypeConstraint<T>("T"),        \
1879                           CudnnRNNForwardOpV3<GPUDevice, T>);
1880 
1881 TF_CALL_half(REGISTER_GPU);
1882 TF_CALL_float(REGISTER_GPU);
1883 TF_CALL_double(REGISTER_GPU);
1884 #undef REGISTER_GPU
1885 
1886 // Run the backward operation of the RNN model.
1887 template <typename T>
1888 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1889  public:
1890   explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
1891       : CudnnRNNKernelCommon(context) {}
1892 
1893   void Compute(OpKernelContext* context) override {
1894     ComputeImpl(context, false, true, 0);
1895   }
1896 
1897  protected:
1898   virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
1899                            bool time_major, int num_proj) {
1900     const Tensor* input = nullptr;
1901     const Tensor* input_h = nullptr;
1902     const Tensor* input_c = nullptr;
1903     const Tensor* params = nullptr;
1904     const Tensor* sequence_lengths = nullptr;
1905     CudnnRnnModelShapes model_shapes;
1906     bool use_padded_io = false;
1907     if (var_seq_lengths) {
1908       OP_REQUIRES_OK(context, ExtractForwardInput(
1909                                   context, model_types(), time_major, &input,
1910                                   &input_h, &input_c, &params,
1911                                   &sequence_lengths, num_proj, &model_shapes));
1912       use_padded_io =
1913           ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1914     } else {
1915       OP_REQUIRES_OK(context,
1916                      ExtractForwardInput(context, model_types(), time_major,
1917                                          &input, &input_h, &input_c, &params,
1918                                          num_proj, &model_shapes));
1919     }
1920     RnnInputMode input_mode;
1921     OP_REQUIRES_OK(context,
1922                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1923                                   model_shapes.input_size, &input_mode));
1924 
1925     const Tensor* output = nullptr;
1926     const Tensor* output_h = nullptr;
1927     const Tensor* output_c = nullptr;
1928     const Tensor* output_backprop = nullptr;
1929     const Tensor* output_h_backprop = nullptr;
1930     const Tensor* output_c_backprop = nullptr;
1931     const Tensor* reserve_space = nullptr;
1932     OP_REQUIRES_OK(context,
1933                    ExtractBackwardInputs(context, model_shapes, model_types(),
1934                                          &output, &output_h, &output_c,
1935                                          &output_backprop, &output_h_backprop,
1936                                          &output_c_backprop, &reserve_space));
1937 
1938     Tensor* input_backprop = nullptr;
1939     Tensor* input_h_backprop = nullptr;
1940     Tensor* input_c_backprop = nullptr;
1941     Tensor* params_backprop = nullptr;
1942     OP_REQUIRES_OK(context,
1943                    AllocateOutputs(context, model_shapes, params->shape(),
1944                                    &input_backprop, &input_h_backprop,
1945                                    &input_c_backprop, &params_backprop));
1946 
1947     // Creates a memory callback for the workspace. The memory lives until the
1948     // end of this kernel call.
1949     CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1950     AlgorithmConfig algo_config;
1951     OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
1952     Status launch_status;
1953     {
1954       mutex_lock l(mu_);
1955       RnnDescriptor* rnn_desc_ptr = nullptr;
1956       OP_REQUIRES_OK(
1957           context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
1958                                              algo_config, &rnn_state_cache_,
1959                                              &rnn_desc_ptr, use_padded_io));
1960       launch_status = DoBackward<T>(
1961           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1962           input_c, params, output, output_h, output_c, output_backprop,
1963           output_h_backprop, output_c_backprop, reserve_space, input_backprop,
1964           input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
1965           time_major, &workspace_allocator,
1966           /*output_profile_result=*/nullptr);
1967     }
1968     OP_REQUIRES_OK(context, launch_status);
1969   }
1970 
1971  protected:
1972   virtual Status GetAlgorithm(OpKernelContext* context,
1973                               AlgorithmConfig* algo_config) {
1974     CHECK_NE(algo_config, nullptr);
1975     *algo_config = AlgorithmConfig();
1976     return Status::OK();
1977   }
1978 
1979  private:
1980   mutex mu_;
1981   RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1982 
1983   Status ExtractBackwardInputs(
1984       OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
1985       const CudnnModelTypes& model_types, const Tensor** output,
1986       const Tensor** output_h, const Tensor** output_c,
1987       const Tensor** output_backprop, const Tensor** output_h_backprop,
1988       const Tensor** output_c_backprop, const Tensor** reserve_space) {
1989     TF_RETURN_IF_ERROR(context->input("output", output));
1990     TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
1991     TF_RETURN_IF_ERROR(context->input("output_h", output_h));
1992     TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
1993     if (model_types.HasInputC()) {
1994       TF_RETURN_IF_ERROR(context->input("output_c", output_c));
1995       TF_RETURN_IF_ERROR(
1996           context->input("output_c_backprop", output_c_backprop));
1997     }
1998     TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
1999     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2000     const TensorShape& output_shape = model_shapes.output_shape;
2001     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2002 
2003     if (output_shape != (*output)->shape()) {
2004       return errors::InvalidArgument(
2005           "Invalid output shape: ", (*output)->shape().DebugString(), " ",
2006           output_shape.DebugString());
2007     }
2008     if (hidden_state_shape != (*output_h)->shape()) {
2009       return errors::InvalidArgument(
2010           "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
2011           hidden_state_shape.DebugString());
2012     }
2013 
2014     if (output_shape != (*output_backprop)->shape()) {
2015       return errors::InvalidArgument("Invalid output_backprop shape: ",
2016                                      (*output_backprop)->shape().DebugString(),
2017                                      " ", output_shape.DebugString());
2018     }
2019     if (hidden_state_shape != (*output_h_backprop)->shape()) {
2020       return errors::InvalidArgument(
2021           "Invalid output_h_backprop shape: ",
2022           (*output_h_backprop)->shape().DebugString(), " ",
2023           hidden_state_shape.DebugString());
2024     }
2025 
2026     if (model_types.HasInputC()) {
2027       if (cell_state_shape != (*output_c)->shape()) {
2028         return errors::InvalidArgument(
2029             "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
2030             cell_state_shape.DebugString());
2031       }
2032       if (cell_state_shape != (*output_c_backprop)->shape()) {
2033         return errors::InvalidArgument(
2034             "Invalid output_c_backprop shape: ",
2035             (*output_c_backprop)->shape().DebugString(), " ",
2036             cell_state_shape.DebugString());
2037       }
2038     }
2039     return Status::OK();
2040   }
2041 
2042   Status AllocateOutputs(OpKernelContext* context,
2043                          const CudnnRnnModelShapes& model_shapes,
2044                          const TensorShape& params_shape,
2045                          Tensor** input_backprop, Tensor** input_h_backprop,
2046                          Tensor** input_c_backprop, Tensor** params_backprop) {
2047     const TensorShape& input_shape = model_shapes.input_shape;
2048     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2049     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2050 
2051     TF_RETURN_IF_ERROR(
2052         context->allocate_output(0, input_shape, input_backprop));
2053     TF_RETURN_IF_ERROR(
2054         context->allocate_output(1, hidden_state_shape, input_h_backprop));
2055     if (HasInputC()) {
2056       TF_RETURN_IF_ERROR(
2057           context->allocate_output(2, cell_state_shape, input_c_backprop));
2058     } else {
2059       // Only LSTM uses input_c and output_c. So for all other models, we only
2060       // need to create dummy outputs.
2061       TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
2062     }
2063     TF_RETURN_IF_ERROR(
2064         context->allocate_output(3, params_shape, params_backprop));
2065     return Status::OK();
2066   }
2067 };
2068 
2069 #define REGISTER_GPU(T)                                                   \
2070   REGISTER_KERNEL_BUILDER(                                                \
2071       Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
2072       CudnnRNNBackwardOp<GPUDevice, T>);
2073 
2074 TF_CALL_half(REGISTER_GPU);
2075 TF_CALL_float(REGISTER_GPU);
2076 TF_CALL_double(REGISTER_GPU);
2077 #undef REGISTER_GPU
2078 
2079 template <typename T>
2080 class CudnnRNNBackwardOpV2<GPUDevice, T>
2081     : public CudnnRNNBackwardOp<GPUDevice, T> {
2082  public:
2083   explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
2084       : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
2085 
2086  protected:
2087   Status GetAlgorithm(OpKernelContext* context,
2088                       AlgorithmConfig* algo_config) override {
2089     CHECK_NE(algo_config, nullptr);
2090     const Tensor* host_reserved = nullptr;
2091     TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
2092 
2093     auto host_reserved_int8 = host_reserved->vec<int8>();
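    // host_reserved layout (written by CudnnRNNForwardOpV2::Compute):
    // [0] = algorithm id, [1] = whether tensor ops were enabled.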
2094     const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
2095     algo_config->set_algorithm(algo_desc);
2096     return Status::OK();
2097   }
2098 };
2099 
2100 #define REGISTER_GPU(T)                                    \
2101   REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2")       \
2102                               .Device(DEVICE_GPU)          \
2103                               .HostMemory("host_reserved") \
2104                               .TypeConstraint<T>("T"),     \
2105                           CudnnRNNBackwardOpV2<GPUDevice, T>);
2106 
2107 TF_CALL_half(REGISTER_GPU);
2108 TF_CALL_float(REGISTER_GPU);
2109 TF_CALL_double(REGISTER_GPU);
2110 #undef REGISTER_GPU
2111 
2112 template <typename T>
2113 class CudnnRNNBackwardOpV3<GPUDevice, T>
2114     : public CudnnRNNBackwardOp<GPUDevice, T> {
2115  private:
2116   bool time_major_;
2117 
2118  protected:
2119   bool time_major() { return time_major_; }
2120 
2121  public:
2122   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
2123       : CudnnRNNBackwardOp<GPUDevice, T>(context) {
2124     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
2125     if (context->HasAttr("num_proj")) {
2126       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
2127     } else {
2128       num_proj_ = 0;
2129     }
2130   }
2131 
2132   void Compute(OpKernelContext* context) override {
2133     CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(),
2134                                                   num_proj_);
2135   }
2136 
2137  private:
2138   int num_proj_;
2139 };
2140 
2141 #define REGISTER_GPU(T)                                       \
2142   REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3")          \
2143                               .Device(DEVICE_GPU)             \
2144                               .HostMemory("sequence_lengths") \
2145                               .HostMemory("host_reserved")    \
2146                               .TypeConstraint<T>("T"),        \
2147                           CudnnRNNBackwardOpV3<GPUDevice, T>);
2148 
2149 TF_CALL_half(REGISTER_GPU);
2150 TF_CALL_float(REGISTER_GPU);
2151 TF_CALL_double(REGISTER_GPU);
2152 #undef REGISTER_GPU
2153 
2154 // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
2155 // their canonical form.
2156 
2157 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2158 
2159 }  // namespace tensorflow
2160