1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #include <stddef.h>
18 #include <atomic>
19 #include <cmath>
20 #include <functional>
21 #include <limits>
22 #include <string>
23 #include <unordered_set>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/device_base.h"
27 #include "tensorflow/core/framework/kernel_def_builder.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def_builder.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/register_types.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/tensor_types.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/kernels/gpu_utils.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/core/status.h"
39 #include "tensorflow/core/lib/core/stringpiece.h"
40 #include "tensorflow/core/lib/gtl/inlined_vector.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/lib/strings/stringprintf.h"
43 #include "tensorflow/core/platform/fingerprint.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/types.h"
46 #include "tensorflow/core/util/env_var.h"
47 #include "tensorflow/core/util/use_cudnn.h"
48 
49 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
50 #include "tensorflow/core/platform/stream_executor.h"
51 #include "tensorflow/core/util/stream_executor_util.h"
52 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
53 
54 /*
55  * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
56  * using the underlying Cudnn library.
57  *
58  * The Cudnn RNN library exposes an opaque parameter buffer whose layout and
59  * format are unspecified, and a buffer saved on one GPU is unlikely to be
60  * usable on another. Users therefore need to query the size of the opaque
61  * parameter buffer and convert it to and from its canonical form, while each
62  * actual training step is carried out with the opaque parameter buffer.
63  *
64  * Similar to many other ops, the forward op has two flavors: training and
65  * inference. When training is specified, additional data is produced in
66  * reserve_space for the backward pass, at the cost of some performance.
67  *
68  * In addition to the actual data and reserve_space, Cudnn also needs extra
69  * memory as a temporary workspace. Memory handed to and from stream-executor
70  * is managed through ScratchAllocator: in general, stream-executor is
71  * responsible for creating memory of the proper size, and TensorFlow is
72  * responsible for keeping that memory alive long enough and recycling it
73  * afterwards.
74  *
75  */
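/*
 * As a rough orientation (a sketch, not exact op signatures), the expected
 * op-level flow using the kernels defined in this file is:
 *   1. CudnnRNNParamsSizeOp        - query the size of the opaque buffer.
 *   2. CudnnRNNParamsToCanonical / CudnnRNNCanonicalToParams
 *                                  - convert between the opaque buffer and the
 *                                    canonical weight/bias tensors.
 *   3. CudnnRNNForwardOp           - run the fused RNN; when training, it also
 *                                    produces reserve_space.
 *   4. CudnnRNNBackwardOp          - compute gradients, consuming reserve_space.
 */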
76 namespace tensorflow {
77 
78 using CPUDevice = Eigen::ThreadPoolDevice;
79 
80 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
81 
82 using GPUDevice = Eigen::GpuDevice;
83 using se::Stream;
84 using se::StreamExecutor;
85 using se::dnn::RnnDescriptor;
86 
87 template <typename Device, typename T, typename Index>
88 class CudnnRNNParamsSizeOp;
89 
90 template <typename Device, typename T>
91 class CudnnRNNParamsToCanonical;
92 
93 template <typename Device, typename T>
94 class CudnnRNNCanonicalToParams;
95 
96 template <typename Device, typename T>
97 class CudnnRNNForwardOp;
98 
99 template <typename Device, typename T>
100 class CudnnRNNBackwardOp;
101 
102 template <typename Device, typename T>
103 class CudnnRNNForwardOpV2;
104 
105 template <typename Device, typename T>
106 class CudnnRNNBackwardOpV2;
107 
108 template <typename Device, typename T>
109 class CudnnRNNForwardOpV3;
110 
111 template <typename Device, typename T>
112 class CudnnRNNBackwardOpV3;
113 
114 enum class TFRNNInputMode {
115   kRNNLinearInput = 0,
116   kRNNSkipInput = 1,
117   kAutoSelect = 9999999
118 };
119 
120 namespace {
121 using se::DeviceMemory;
122 using se::DeviceMemoryBase;
123 using se::ScratchAllocator;
124 using se::dnn::AlgorithmConfig;
125 using se::dnn::AlgorithmDesc;
126 using se::dnn::ProfileResult;
127 using se::dnn::RnnDirectionMode;
128 using se::dnn::RnnInputMode;
129 using se::dnn::RnnMode;
130 using se::dnn::RnnSequenceTensorDescriptor;
131 using se::dnn::RnnStateTensorDescriptor;
132 using se::dnn::ToDataType;
133 using se::port::StatusOr;
134 
135 uint64 HashList(const std::vector<int>& list) {
136   if (list.empty()) {
137     return 0;
138   }
139   uint64 hash_code = list[0];
140   for (int i = 1; i < list.size(); i++) {
141     hash_code = Hash64Combine(hash_code, list[i]);
142   }
143   return hash_code;
144 }
145 
146 // Encapsulate all the shape information that is used in both forward and
147 // backward rnn operations.
148 class CudnnRnnParameters {
149  public:
150   CudnnRnnParameters(int num_layers, int input_size, int num_units,
151                      int max_seq_length, int batch_size, int dir_count,
152                      bool has_dropout, bool is_training, RnnMode rnn_mode,
153                      TFRNNInputMode rnn_input_mode, DataType dtype)
154       : num_layers_(num_layers),
155         input_size_(input_size),
156         num_units_(num_units),
157         seq_length_(max_seq_length),
158         batch_size_(batch_size),
159         dir_count_(dir_count),
160         has_dropout_(has_dropout),
161         is_training_(is_training),
162         rnn_mode_(rnn_mode),
163         rnn_input_mode_(rnn_input_mode),
164         dtype_(dtype) {
165     hash_code_ =
166         HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
167                   dir_count, static_cast<int>(has_dropout),
168                   static_cast<int>(is_training), static_cast<int>(rnn_mode),
169                   static_cast<int>(rnn_input_mode), dtype});
170   }
171 
172   bool operator==(const CudnnRnnParameters& other) const {
173     return this->get_data_as_tuple() == other.get_data_as_tuple();
174   }
175 
176   bool operator!=(const CudnnRnnParameters& other) const {
177     return !(*this == other);
178   }
179   uint64 hash() const { return hash_code_; }
180 
181   string ToString() const {
182     std::vector<string> fields = {
183         std::to_string(num_layers_),
184         std::to_string(input_size_),
185         std::to_string(num_units_),
186         std::to_string(seq_length_),
187         std::to_string(batch_size_),
188         std::to_string(dir_count_),
189         std::to_string(has_dropout_),
190         std::to_string(is_training_),
191         std::to_string(static_cast<int>(rnn_mode_)),
192         std::to_string(static_cast<int>(rnn_input_mode_)),
193         std::to_string(static_cast<int>(dtype_))};
194     return absl::StrJoin(fields, ", ");
195   }
196 
197  private:
198   using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
199                                        RnnMode, TFRNNInputMode, DataType>;
200 
201   ParameterDataType get_data_as_tuple() const {
202     return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
203                            batch_size_, dir_count_, has_dropout_, is_training_,
204                            rnn_mode_, rnn_input_mode_, dtype_);
205   }
206 
207   const int num_layers_;
208   const int input_size_;
209   const int num_units_;
210   const int seq_length_;
211   const int batch_size_;
212   const int dir_count_;
213   const bool has_dropout_;
214   const bool is_training_;
215   const RnnMode rnn_mode_;
216   const TFRNNInputMode rnn_input_mode_;
217   const DataType dtype_;
218   uint64 hash_code_;
219 };
220 
221 struct RnnAutoTuneGroup {
222   static string name() { return "Rnn"; }
223 };
224 
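// Autotuning cache that maps a CudnnRnnParameters key to the AlgorithmConfig
// selected for that configuration.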
225 using AutoTuneRnnConfigMap =
226     AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>;
227 
228 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
229   if (str == "rnn_relu") {
230     *rnn_mode = RnnMode::kRnnRelu;
231     return Status::OK();
232   } else if (str == "rnn_tanh") {
233     *rnn_mode = RnnMode::kRnnTanh;
234     return Status::OK();
235   } else if (str == "lstm") {
236     *rnn_mode = RnnMode::kRnnLstm;
237     return Status::OK();
238   } else if (str == "gru") {
239     *rnn_mode = RnnMode::kRnnGru;
240     return Status::OK();
241   }
242   return errors::InvalidArgument("Invalid RNN mode: ", str);
243 }
244 
245 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
246   if (str == "linear_input") {
247     *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
248     return Status::OK();
249   } else if (str == "skip_input") {
250     *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
251     return Status::OK();
252   } else if (str == "auto_select") {
253     *rnn_input_mode = TFRNNInputMode::kAutoSelect;
254     return Status::OK();
255   }
256   return errors::InvalidArgument("Invalid RNN input mode: ", str);
257 }
258 
259 Status ParseRNNDirectionMode(const string& str,
260                              RnnDirectionMode* rnn_dir_mode) {
261   if (str == "unidirectional") {
262     *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
263     return Status::OK();
264   } else if (str == "bidirectional") {
265     *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
266     return Status::OK();
267   }
268   return errors::InvalidArgument("Invalid RNN direction mode: ", str);
269 }
270 
271 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
272                       int input_size, RnnInputMode* input_mode) {
273   switch (tf_input_mode) {
274     case TFRNNInputMode::kRNNLinearInput:
275       *input_mode = RnnInputMode::kRnnLinearSkip;
276       break;
277     case TFRNNInputMode::kRNNSkipInput:
278       *input_mode = RnnInputMode::kRnnSkipInput;
279       break;
280     case TFRNNInputMode::kAutoSelect:
281       *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
282                                               : RnnInputMode::kRnnLinearSkip;
283       break;
284     default:
285       return errors::InvalidArgument("Invalid TF input mode: ",
286                                      static_cast<int>(tf_input_mode));
287   }
288   return Status::OK();
289 }
290 
291 // TODO(zhengxq): Merge those into stream_executor_util.h.
292 template <typename T>
293 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
294   return DeviceMemory<T>::MakeFromByteSize(
295       const_cast<T*>(tensor->template flat<T>().data()),
296       tensor->template flat<T>().size() * sizeof(T));
297 }
298 
299 template <typename T>
300 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
301   return DeviceMemory<T>::MakeFromByteSize(
302       tensor->template flat<T>().data(),
303       tensor->template flat<T>().size() * sizeof(T));
304 }
305 
306 template <typename U, typename T>
307 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
308   return DeviceMemory<U>::MakeFromByteSize(
309       tensor->template flat<T>().data(),
310       tensor->template flat<T>().size() * sizeof(T));
311 }
312 
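// Returns a view of `device_memory` that starts `offset` bytes in and spans
// `size` bytes; the slice must lie within the original region.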
313 DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
314                                    int64 offset, int64 size) {
315   const void* base_ptr = device_memory.opaque();
316   void* offset_ptr =
317       const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
318   CHECK(offset + size <= device_memory.size())
319       << "The slice is not within the region of DeviceMemory.";
320   return DeviceMemoryBase(offset_ptr, size);
321 }
322 
323 inline Status FromExecutorStatus(const se::port::Status& s) {
324   return s.ok() ? Status::OK()
325                 : Status(static_cast<error::Code>(static_cast<int>(s.code())),
326                          s.error_message());
327 }
328 
329 template <typename T>
330 inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
331   return FromExecutorStatus(s.status());
332 }
333 
334 inline se::port::Status ToExecutorStatus(const Status& s) {
335   return s.ok() ? se::port::Status::OK()
336                 : se::port::Status(static_cast<se::port::error::Code>(
337                                        static_cast<int>(s.code())),
338                                    s.error_message());
339 }
340 
341 template <typename>
342 struct ToTFDataType;
343 
344 template <>
345 struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
346 
347 template <>
348 struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
349 
350 template <>
351 struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
352 
353 template <>
354 struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
355 
356 // A helper to allocate temporary scratch memory for Cudnn RNN models. It
357 // takes ownership of the underlying memory. The expectation is that the
358 // memory stays alive for the span of the Cudnn RNN call itself.
359 template <typename T>
360 class CudnnRnnAllocatorInTemp : public ScratchAllocator {
361  public:
362   ~CudnnRnnAllocatorInTemp() override = default;
363 
364   explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
365       : context_(context) {}
366   int64 GetMemoryLimitInBytes() override {
367     return std::numeric_limits<int64>::max();
368   }
369 
370   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
371     Tensor temporary_memory;
372     const DataType tf_data_type = ToTFDataType<T>::value;
373     int64 allocate_count =
374         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
375     Status allocation_status(context_->allocate_temp(
376         tf_data_type, TensorShape({allocate_count}), &temporary_memory));
377     if (!allocation_status.ok()) {
378       return ToExecutorStatus(allocation_status);
379     }
380     // Hold references to the allocated tensors until the allocator is
381     // destroyed.
382     allocated_tensors_.push_back(temporary_memory);
383     total_byte_size_ += byte_size;
384     return DeviceMemory<uint8>::MakeFromByteSize(
385         temporary_memory.template flat<T>().data(),
386         temporary_memory.template flat<T>().size() * sizeof(T));
387   }
388 
389   int64 TotalByteSize() const { return total_byte_size_; }
390 
391   Tensor get_allocated_tensor(int index) const {
392     return allocated_tensors_[index];
393   }
394 
395  private:
396   int64 total_byte_size_ = 0;
397   OpKernelContext* context_;  // not owned
398   std::vector<Tensor> allocated_tensors_;
399 };
400 
401 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
402 // used by the forward-pass kernel to feed its output to the backward pass.
403 // The memory is expected to remain alive until the backward pass has
404 // finished.
405 template <typename T>
406 class CudnnRnnAllocatorInOutput : public ScratchAllocator {
407  public:
408   ~CudnnRnnAllocatorInOutput() override {}
409   CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
410       : context_(context), output_index_(output_index) {}
411   int64 GetMemoryLimitInBytes() override {
412     return std::numeric_limits<int64>::max();
413   }
414   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
415     CHECK(total_byte_size_ == 0)
416         << "Reserve space allocator can only be called once";
417     int64 allocate_count =
418         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
419 
420     Tensor* temporary_memory = nullptr;
421     Status allocation_status(context_->allocate_output(
422         output_index_, TensorShape({allocate_count}), &temporary_memory));
423     if (!allocation_status.ok()) {
424       return ToExecutorStatus(allocation_status);
425     }
426     total_byte_size_ += byte_size;
427     auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
428         temporary_memory->template flat<T>().data(),
429         temporary_memory->template flat<T>().size() * sizeof(T));
430     return StatusOr<DeviceMemory<uint8>>(memory_uint8);
431   }
432   int64 TotalByteSize() { return total_byte_size_; }
433 
434  private:
435   int64 total_byte_size_ = 0;
436   OpKernelContext* context_;  // not owned
437   int output_index_;
438 };
439 
440 // A helper to allocate persistent memory for Cudnn RNN models, which is
441 // expected to live between kernel invocations.
442 // This class is not thread-safe.
443 class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
444  public:
445   explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context)
446       : context_(context) {}
447 
448   ~CudnnRNNPersistentSpaceAllocator() override {}
449 
450   int64 GetMemoryLimitInBytes() override {
451     return std::numeric_limits<int64>::max();
452   }
453 
454   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
455     if (total_byte_size_ != 0) {
456       return Status(error::FAILED_PRECONDITION,
457                     "Persistent space allocator can only be called once");
458     }
459 
460     Status allocation_status = context_->allocate_persistent(
461         DT_UINT8, TensorShape({byte_size}), &handle_, nullptr);
462     if (!allocation_status.ok()) {
463       return ToExecutorStatus(allocation_status);
464     }
465     total_byte_size_ += byte_size;
466     return AsDeviceMemory<uint8>(handle_.AccessTensor(context_));
467   }
468   int64 TotalByteSize() { return total_byte_size_; }
469 
470  private:
471   int64 total_byte_size_ = 0;
472   PersistentTensor handle_;
473   OpKernelContext* context_;  // not owned
474 };
475 
476 struct CudnnModelTypes {
477   RnnMode rnn_mode;
478   TFRNNInputMode rnn_input_mode;
479   RnnDirectionMode rnn_direction_mode;
480   bool HasInputC() const {
481     // For Cudnn 5.0, only LSTM has input-c. All other models use only
482     // input-h.
483     return rnn_mode == RnnMode::kRnnLstm;
484   }
485 
486   string DebugString() const {
487     return strings::Printf(
488         "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
489         static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
490         static_cast<int>(rnn_direction_mode));
491   }
492 };
493 
494 // A helper class that collects the shapes that describe an RNN model.
495 struct CudnnRnnModelShapes {
496   int num_layers;
497   int input_size;
498   int num_units;
499   int dir_count;
500   int max_seq_length;
501   int batch_size;
502   int cell_num_units = 0;
503   // If you add a new field to this structure, please take care of
504   // updating IsCompatibleWith() below as well as the hash function in
505   // CudnnRnnConfigHasher.
506   TensorShape input_shape;
507   TensorShape output_shape;
508   TensorShape hidden_state_shape;
509   TensorShape cell_state_shape;
510   // At present, only the fields relevant to the cached RnnDescriptor are compared.
511   bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
512     return num_layers == rhs.num_layers && input_size == rhs.input_size &&
513            num_units == rhs.num_units && dir_count == rhs.dir_count &&
514            cell_num_units == rhs.cell_num_units &&
515            max_seq_length == rhs.max_seq_length;
516   }
517   string DebugString() const {
518     return strings::Printf(
519         "[num_layers, input_size, num_units, dir_count, max_seq_length, "
520         "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
521         num_layers, input_size, num_units, dir_count, max_seq_length,
522         batch_size, cell_num_units);
523   }
524 };
525 
526 // Utility class for hashing a CudnnRnnModelShapes and AlgorithmDesc pair used
527 // as a hash table key.
528 struct CudnnRnnConfigHasher {
529   uint64 operator()(
530       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
531           to_hash) const {
532     auto& shapes = to_hash.first;
533     auto& algo_desc = to_hash.second;
534 
535     uint64 hash =
536         HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
537                   shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
538     if (algo_desc.has_value()) {
539       hash = Hash64Combine(hash, algo_desc->hash());
540     }
541     return hash;
542   }
543 };
544 
545 // Utility class for comparing CudnnRnnModelShapes and AlgorithmDesc pairs
546 // used as hash table keys.
547 struct CudnnRnnConfigComparator {
548   bool operator()(
549       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs,
550       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs)
551       const {
552     return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
553   }
554 };
555 
556 // Pointers to RNN scratch space for a specific set of shape parameters (used as
557 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
558 struct RnnScratchSpace {
559   std::unique_ptr<RnnDescriptor> rnn_desc;
560   std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
561 };
562 
563 // Extracts and checks the forward input tensors, parameters, and shapes from the
564 // OpKernelContext.
565 Status ExtractForwardInput(OpKernelContext* context,
566                            const CudnnModelTypes& model_types, bool time_major,
567                            const Tensor** input, const Tensor** input_h,
568                            const Tensor** input_c, const Tensor** params,
569                            const int num_proj,
570                            CudnnRnnModelShapes* model_shapes) {
571   TF_RETURN_IF_ERROR(context->input("input", input));
572   TF_RETURN_IF_ERROR(context->input("input_h", input_h));
573   if (model_types.HasInputC()) {
574     TF_RETURN_IF_ERROR(context->input("input_c", input_c));
575   }
576   TF_RETURN_IF_ERROR(context->input("params", params));
577 
578   if ((*input)->dims() != 3) {
579     return errors::InvalidArgument("RNN input must be a 3-D vector.");
580   }
581   if (time_major) {
582     model_shapes->max_seq_length = (*input)->dim_size(0);
583     model_shapes->batch_size = (*input)->dim_size(1);
584   } else {
585     model_shapes->max_seq_length = (*input)->dim_size(1);
586     model_shapes->batch_size = (*input)->dim_size(0);
587   }
588   model_shapes->input_size = (*input)->dim_size(2);
589   model_shapes->input_shape = (*input)->shape();
590   model_shapes->dir_count =
591       (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
592           ? 2
593           : 1;
594 
595   if ((*input_h)->dims() != 3) {
596     return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
597   }
598   if (time_major) {
599     model_shapes->num_layers =
600         (*input_h)->dim_size(0) / model_shapes->dir_count;
601   } else {
602     model_shapes->num_layers =
603         (*input_h)->dim_size(1) / model_shapes->dir_count;
604   }
605   model_shapes->num_units = (*input_h)->dim_size(2);
606 
607   if (time_major) {
608     model_shapes->hidden_state_shape =
609         TensorShape({model_shapes->dir_count * model_shapes->num_layers,
610                      model_shapes->batch_size, model_shapes->num_units});
611   } else {
612     model_shapes->hidden_state_shape =
613         TensorShape({model_shapes->batch_size,
614                      model_shapes->dir_count * model_shapes->num_layers,
615                      model_shapes->num_units});
616   }
617   if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
618     return errors::InvalidArgument(
619         "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
620         model_shapes->hidden_state_shape.DebugString());
621   }
622   if (model_types.HasInputC()) {
623     model_shapes->cell_num_units = (*input_c)->dim_size(2);
624     if (time_major) {
625       model_shapes->cell_state_shape =
626           TensorShape({model_shapes->dir_count * model_shapes->num_layers,
627                        model_shapes->batch_size, model_shapes->cell_num_units});
628     } else {
629       model_shapes->cell_state_shape =
630           TensorShape({model_shapes->batch_size,
631                        model_shapes->dir_count * model_shapes->num_layers,
632                        model_shapes->cell_num_units});
633     }
634     if (num_proj == 0) {
635       if ((*input_h)->shape() != (*input_c)->shape()) {
636         return errors::InvalidArgument(
637             "input_h and input_c must have the same shape w/o projection: ",
638             (*input_h)->shape().DebugString(), " ",
639             (*input_c)->shape().DebugString());
640       }
641     } else {
642       if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
643           num_proj != (*input_h)->dim_size(2) ||
644           (*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
645           (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
646         return errors::InvalidArgument(
647             "Invalid input_h and input_c w/ projection size: ", num_proj, " ",
648             (*input_h)->shape().DebugString(), " ",
649             (*input_c)->shape().DebugString());
650       }
651     }
652   } else {
653     // dummy cell_state_shape TODO(kaixih): remove the time_major branch
654     if (time_major) {
655       model_shapes->cell_state_shape =
656           TensorShape({model_shapes->dir_count * model_shapes->num_layers,
657                        model_shapes->batch_size, model_shapes->num_units});
658     } else {
659       model_shapes->cell_state_shape =
660           TensorShape({model_shapes->batch_size,
661                        model_shapes->dir_count * model_shapes->num_layers,
662                        model_shapes->num_units});
663     }
664     model_shapes->cell_num_units = 0;
665   }
666   if (time_major) {
667     model_shapes->output_shape =
668         TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
669                      model_shapes->dir_count * model_shapes->num_units});
670   } else {
671     model_shapes->output_shape =
672         TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
673                      model_shapes->dir_count * model_shapes->num_units});
674   }
675   return Status::OK();
676 }
677 
678 // Overload that additionally extracts the sequence_lengths input.
679 Status ExtractForwardInput(OpKernelContext* context,
680                            const CudnnModelTypes& model_types, bool time_major,
681                            const Tensor** input, const Tensor** input_h,
682                            const Tensor** input_c, const Tensor** params,
683                            const Tensor** sequence_lengths, const int num_proj,
684                            CudnnRnnModelShapes* model_shapes) {
685   TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
686   return ExtractForwardInput(context, model_types, time_major, input, input_h,
687                              input_c, params, num_proj, model_shapes);
688 }
689 
690 template <typename T>
691 Status CreateForwardAndBackwardIODescriptors(
692     OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
693     std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
694     std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
695     std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
696     std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
697     const absl::Span<const int>& seq_lengths, bool time_major) {
698   StreamExecutor* executor = context->op_device_context()->stream()->parent();
699   se::dnn::DataType data_type = ToDataType<T>::value;
700 
701   const TensorShape& input_shape = model_shapes.input_shape;
702   const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
703   const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
704   const TensorShape& output_shape = model_shapes.output_shape;
705 
706   DCHECK_EQ(input_shape.dims(), 3);
707   if (seq_lengths.data() != nullptr) {
708     if (time_major) {
709       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
710           input_shape.dim_size(0), input_shape.dim_size(1),
711           input_shape.dim_size(2), seq_lengths, time_major, data_type);
712       TF_RETURN_IF_ERROR(input_desc_s.status());
713       *input_desc = input_desc_s.ConsumeValueOrDie();
714     } else {
715       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
716           input_shape.dim_size(1), input_shape.dim_size(0),
717           input_shape.dim_size(2), seq_lengths, time_major, data_type);
718       TF_RETURN_IF_ERROR(input_desc_s.status());
719       *input_desc = input_desc_s.ConsumeValueOrDie();
720     }
721   } else {
722     auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
723         input_shape.dim_size(0), input_shape.dim_size(1),
724         input_shape.dim_size(2), data_type);
725     TF_RETURN_IF_ERROR(input_desc_s.status());
726     *input_desc = input_desc_s.ConsumeValueOrDie();
727   }
728 
729   DCHECK_EQ(hidden_state_shape.dims(), 3);
730   if (time_major) {
731     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
732         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
733         hidden_state_shape.dim_size(2), data_type);
734     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
735     *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
736   } else {
737     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
738         hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
739         hidden_state_shape.dim_size(2), data_type);
740     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
741     *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
742   }
743 
744   DCHECK_EQ(cell_state_shape.dims(), 3);
745   if (time_major) {
746     auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
747         cell_state_shape.dim_size(0), cell_state_shape.dim_size(1),
748         cell_state_shape.dim_size(2), data_type);
749     TF_RETURN_IF_ERROR(cell_state_desc_s.status());
750     *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
751   } else {
752     auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
753         cell_state_shape.dim_size(1), cell_state_shape.dim_size(0),
754         cell_state_shape.dim_size(2), data_type);
755     TF_RETURN_IF_ERROR(cell_state_desc_s.status());
756     *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
757   }
758 
759   DCHECK_EQ(output_shape.dims(), 3);
760   if (seq_lengths.data() != nullptr) {
761     if (time_major) {
762       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
763           output_shape.dim_size(0), output_shape.dim_size(1),
764           output_shape.dim_size(2), seq_lengths, time_major, data_type);
765       TF_RETURN_IF_ERROR(output_desc_s.status());
766       *output_desc = output_desc_s.ConsumeValueOrDie();
767     } else {
768       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
769           output_shape.dim_size(1), output_shape.dim_size(0),
770           output_shape.dim_size(2), seq_lengths, time_major, data_type);
771       TF_RETURN_IF_ERROR(output_desc_s.status());
772       *output_desc = output_desc_s.ConsumeValueOrDie();
773     }
774   } else {
775     auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
776         output_shape.dim_size(0), output_shape.dim_size(1),
777         output_shape.dim_size(2), data_type);
778     TF_RETURN_IF_ERROR(output_desc_s.status());
779     *output_desc = output_desc_s.ConsumeValueOrDie();
780   }
781 
782   return Status::OK();
783 }
784 
785 template <typename T>
786 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
787                  const CudnnModelTypes& model_types,
788                  const CudnnRnnModelShapes& model_shapes,
789                  /* forward inputs */
790                  const Tensor* input, const Tensor* input_h,
791                  const Tensor* input_c, const Tensor* params,
792                  const bool is_training,
793                  /* forward outputs, outputs of the function */
794                  Tensor* output, Tensor* output_h, Tensor* output_c,
795                  const Tensor* sequence_lengths, bool time_major,
796                  ScratchAllocator* reserve_space_allocator,
797                  ScratchAllocator* workspace_allocator,
798                  ProfileResult* output_profile_result) {
799   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
800   std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
801   std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
802   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
803 
804   absl::Span<const int> seq_lengths;
805   if (sequence_lengths != nullptr) {
806     seq_lengths = absl::Span<const int>(
807         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
808   }
809   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
810       context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
811       &output_desc, seq_lengths, time_major));
812 
813   auto input_data = AsDeviceMemory<T>(input);
814   auto input_h_data = AsDeviceMemory<T>(input_h);
815   DeviceMemory<T> input_c_data;
816   if (model_types.HasInputC()) {
817     input_c_data = AsDeviceMemory<T>(input_c);
818   }
819 
820   auto params_data = AsDeviceMemory<T>(params);
821   auto output_data = AsDeviceMemory<T>(output);
822   auto output_h_data = AsDeviceMemory<T>(output_h);
823   DeviceMemory<T> output_c_data;
824   if (model_types.HasInputC()) {
825     output_c_data = AsDeviceMemory<T>(output_c);
826   }
827 
828   Stream* stream = context->op_device_context()->stream();
829   bool launch_success =
830       stream
831           ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc,
832                            input_h_data, *c_state_desc, input_c_data,
833                            params_data, *output_desc, &output_data,
834                            *h_state_desc, &output_h_data, *c_state_desc,
835                            &output_c_data, is_training, reserve_space_allocator,
836                            workspace_allocator, output_profile_result)
837           .ok();
838   return launch_success
839              ? Status::OK()
840              : errors::Internal(
841                    "Failed to call ThenRnnForward with model config: ",
842                    model_types.DebugString(), ", ", model_shapes.DebugString());
843 }
844 
845 template <typename T>
846 Status DoBackward(
847     OpKernelContext* context, const RnnDescriptor& rnn_desc,
848     const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
849     /* forward inputs */
850     const Tensor* input, const Tensor* input_h, const Tensor* input_c,
851     const Tensor* params,
852     /* forward outputs */
853     const Tensor* output, const Tensor* output_h, const Tensor* output_c,
854     /* backprop inputs */
855     const Tensor* output_backprop, const Tensor* output_h_backprop,
856     const Tensor* output_c_backprop, const Tensor* reserve_space,
857     /* backprop outputs, output of the function */
858     Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
859     Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
860     ScratchAllocator* workspace_allocator,
861     ProfileResult* output_profile_result) {
862   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
863   std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
864   std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
865   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
866 
867   absl::Span<const int> seq_lengths;
868   if (sequence_lengths != nullptr) {
869     seq_lengths = absl::Span<const int>(
870         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
871   }
872   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
873       context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
874       &output_desc, seq_lengths, time_major));
875 
876   auto input_data = AsDeviceMemory<T>(input);
877   auto input_h_data = AsDeviceMemory<T>(input_h);
878   DeviceMemory<T> input_c_data;
879   if (model_types.HasInputC()) {
880     input_c_data = AsDeviceMemory<T>(input_c);
881   }
882   auto params_data = AsDeviceMemory<T>(params);
883   auto output_data = AsDeviceMemory<T>(output);
884   auto output_h_data = AsDeviceMemory<T>(output_h);
885   DeviceMemory<T> output_c_data;
886   if (model_types.HasInputC()) {
887     output_c_data = AsDeviceMemory<T>(output_c);
888   }
889   auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
890   auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
891   DeviceMemory<T> output_c_backprop_data;
892   if (model_types.HasInputC()) {
893     output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
894   }
895   auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
896   auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
897   DeviceMemory<T> input_c_backprop_data;
898   if (model_types.HasInputC()) {
899     input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
900   }
901   auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
902   auto reserve_space_uint8 =
903       CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
904 
905   // Creates a memory callback for the workspace. The memory lives until the
906   // end of this kernel call.
907   Stream* stream = context->op_device_context()->stream();
908   bool launch_success =
909       stream
910           ->ThenRnnBackward(
911               rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data,
912               *c_state_desc, input_c_data, params_data, *output_desc,
913               output_data, *h_state_desc, output_h_data, *c_state_desc,
914               output_c_data, output_backprop_data, output_h_backprop_data,
915               output_c_backprop_data, &input_backprop_data,
916               &input_h_backprop_data, &input_c_backprop_data,
917               &params_backprop_data, &reserve_space_uint8, workspace_allocator,
918               output_profile_result)
919           .ok();
920   return launch_success
921              ? Status::OK()
922              : errors::Internal(
923                    "Failed to call ThenRnnBackward with model config: ",
924                    model_types.DebugString(), ", ", model_shapes.DebugString());
925 }
926 
927 template <typename T>
928 void RestoreParams(const OpInputList params_input,
929                    const std::vector<RnnDescriptor::ParamsRegion>& params,
930                    DeviceMemoryBase* data_dst, Stream* stream) {
931   int num_params = params.size();
932   CHECK(params_input.size() == num_params)
933       << "Number of params mismatch. Expected " << params_input.size()
934       << ", got " << num_params;
935   for (int i = 0; i < params.size(); i++) {
936     int64 size_in_bytes = params[i].size;
937     int64 size = size_in_bytes / sizeof(T);
938     CHECK(size == params_input[i].NumElements())
939         << "Params size mismatch. Expected " << size << ", got "
940         << params_input[i].NumElements();
941     auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
942     DeviceMemoryBase data_dst_ptr =
943         SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
944     stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
945   }
946 }
947 
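// Padded (variable-length) IO is needed unless the input is time-major and
// every sequence in the batch already has length max_seq_length.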
948 bool ShouldUsePaddedIO(const Tensor* sequence_lengths,
949                        const CudnnRnnModelShapes& model_shapes,
950                        bool time_major) {
951   auto seq_array = sequence_lengths->template flat<int>().data();
952   bool all_max_seq_length = true;
953   for (int i = 0; i < model_shapes.batch_size; i++) {
954     if (seq_array[i] != model_shapes.max_seq_length) {
955       all_max_seq_length = false;
956       break;
957     }
958   }
959   return !(time_major && all_max_seq_length);
960 }
961 
962 }  // namespace
963 
964 // Note: all of the following kernels depend on an RnnDescriptor instance,
965 // which, according to the official Cudnn documentation, should be kept around
966 // and reused across all Cudnn kernels in the same model.
967 // In TensorFlow, we don't pass the reference across different OpKernels;
968 // rather, we recreate it separately in each OpKernel, which does not cause
969 // issues: CudnnDropoutDescriptor keeps a reference to the memory holding the
970 // random number generator state. That state is lost during recreation, but
971 // only the forward-pass Cudnn APIs make use of it.
972 
973 // A common base class for RNN kernels. It extracts common attributes and
974 // performs shape validations.
975 class CudnnRNNKernelCommon : public OpKernel {
976  protected:
977   explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
978       : OpKernel(context) {
979     OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
980     OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
981     OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
982     string str;
983     OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
984     OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
985     OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
986     OP_REQUIRES_OK(context,
987                    ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
988     OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
989     OP_REQUIRES_OK(
990         context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
991     // Reset the CudnnRnnDescriptor and the related random number generator
992     // state in every Compute() call.
993     OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
994                                                false, &reset_rnd_gen_state_));
995   }
996 
997   bool HasInputC() const { return model_types_.HasInputC(); }
998   RnnMode rnn_mode() const { return model_types_.rnn_mode; }
999   TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
1000   RnnDirectionMode rnn_direction_mode() const {
1001     return model_types_.rnn_direction_mode;
1002   }
1003   const CudnnModelTypes& model_types() const { return model_types_; }
1004   float dropout() const { return dropout_; }
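  // Packs the "seed" and "seed2" attrs into a single 64-bit seed.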
1005   uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
1006   bool ResetRndGenState() { return reset_rnd_gen_state_; }
1007 
1008   template <typename T>
1009   Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
1010                                    std::unique_ptr<RnnDescriptor>* rnn_desc) {
1011     const Tensor* num_layers_t = nullptr;
1012     TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
1013     if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
1014       return errors::InvalidArgument("num_layers is not a scalar");
1015     }
1016     int num_layers = num_layers_t->scalar<int>()();
1017     const Tensor* num_units_t = nullptr;
1018     TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
1019     if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
1020       return errors::InvalidArgument("num_units is not a scalar");
1021     }
1022     int num_units = num_units_t->scalar<int>()();
1023     const Tensor* input_size_t = nullptr;
1024     TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
1025     if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
1026       return errors::InvalidArgument("input_size is not a scalar");
1027     }
1028     int input_size = input_size_t->scalar<int>()();
1029 
1030     int h_num_units = (num_proj == 0 ? num_units : num_proj);
1031     int c_num_units = (num_proj == 0 ? 0 : num_units);
1032 
1033     RnnInputMode input_mode;
1034     TF_RETURN_IF_ERROR(
1035         ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
1036 
1037     Stream* stream = context->op_device_context()->stream();
1038     // ExtractCudnnRNNParamsInfo is only called by op kernels that do not
1039     // require a random number generator, so state_allocator is set to nullptr.
1040     const AlgorithmConfig algo_config;
1041     auto rnn_desc_s = stream->parent()->createRnnDescriptor(
1042         num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
1043         /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
1044         ToDataType<T>::value, algo_config, dropout(), seed(),
1045         /* state_allocator=*/nullptr, /*use_padded_io=*/false);
1046     if (!rnn_desc_s.ok()) {
1047       return FromExecutorStatus(rnn_desc_s);
1048     }
1049     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
1050     return Status::OK();
1051   }
1052 
1053   template <typename T>
1054   Status CreateRnnDescriptor(OpKernelContext* context,
1055                              const CudnnRnnModelShapes& model_shapes,
1056                              const RnnInputMode& input_mode,
1057                              const AlgorithmConfig& algo_config,
1058                              ScratchAllocator* dropout_state_allocator,
1059                              std::unique_ptr<RnnDescriptor>* rnn_desc,
1060                              bool use_padded_io) {
1061     StreamExecutor* executor = context->op_device_context()->stream()->parent();
1062     se::dnn::DataType data_type = ToDataType<T>::value;
1063     auto rnn_desc_s = executor->createRnnDescriptor(
1064         model_shapes.num_layers, model_shapes.num_units,
1065         model_shapes.input_size, model_shapes.cell_num_units,
1066         model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
1067         data_type, algo_config, dropout(), seed(), dropout_state_allocator,
1068         use_padded_io);
1069     TF_RETURN_IF_ERROR(rnn_desc_s.status());
1070 
1071     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
1072     return Status::OK();
1073   }
1074 
1075   using RnnStateCache = gtl::FlatMap<
1076       std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
1077       RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
1078   // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
1079   // should outlive the returned pointer.
1080   template <typename T>
1081   Status GetCachedRnnDescriptor(OpKernelContext* context,
1082                                 const CudnnRnnModelShapes& model_shapes,
1083                                 const RnnInputMode& input_mode,
1084                                 const AlgorithmConfig& algo_config,
1085                                 RnnStateCache* cache, RnnDescriptor** rnn_desc,
1086                                 bool use_padded_io) {
1087     auto key = std::make_pair(model_shapes, algo_config.algorithm());
1088     RnnScratchSpace& rnn_state = (*cache)[key];
1089     if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
1090       CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
1091           new CudnnRNNPersistentSpaceAllocator(context);
1092       rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
1093       Status status = CreateRnnDescriptor<T>(
1094           context, model_shapes, input_mode, algo_config,
1095           dropout_state_allocator, &rnn_state.rnn_desc, use_padded_io);
1096       TF_RETURN_IF_ERROR(status);
1097     }
1098     *rnn_desc = rnn_state.rnn_desc.get();
1099     return Status::OK();
1100   }
1101 
1102  private:
1103   int seed_;
1104   int seed2_;
1105   float dropout_;
1106   bool reset_rnd_gen_state_;
1107 
1108   CudnnModelTypes model_types_;
1109 };
1110 
1111 // A class that returns the size of the opaque parameter buffer. The user should
1112 // use that to create the actual parameter buffer for training. However, it
1113 // should not be used for saving and restoring.
1114 template <typename T, typename Index>
1115 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
1116  public:
1117   explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
1118       : CudnnRNNKernelCommon(context) {
1119     if (context->HasAttr("num_proj")) {
1120       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1121     } else {
1122       num_proj_ = 0;
1123     }
1124   }
1125 
1126   void Compute(OpKernelContext* context) override {
1127     std::unique_ptr<RnnDescriptor> rnn_desc;
1128     OP_REQUIRES_OK(context,
1129                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1130     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1131     CHECK(params_size_in_bytes % sizeof(T) == 0)
1132         << "params_size_in_bytes must be multiple of element size";
1133     int64 params_size = params_size_in_bytes / sizeof(T);
1134 
1135     Tensor* output_t = nullptr;
1136     OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
1137     *output_t->template flat<Index>().data() = params_size;
1138   }
1139 
1140  private:
1141   int num_proj_;
1142 };
1143 
1144 #define REGISTER_GPU(T)                                    \
1145   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")       \
1146                               .Device(DEVICE_GPU)          \
1147                               .HostMemory("num_layers")    \
1148                               .HostMemory("num_units")     \
1149                               .HostMemory("input_size")    \
1150                               .HostMemory("params_size")   \
1151                               .TypeConstraint<T>("T")      \
1152                               .TypeConstraint<int32>("S"), \
1153                           CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
1154 
1155 TF_CALL_half(REGISTER_GPU);
1156 TF_CALL_float(REGISTER_GPU);
1157 TF_CALL_double(REGISTER_GPU);
1158 #undef REGISTER_GPU
1159 
1160 // Convert weight and bias params from a platform-specific layout to the
1161 // canonical form.
1162 template <typename T>
1163 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
1164  public:
1165   explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
1166       : CudnnRNNKernelCommon(context) {
1167     if (context->HasAttr("num_params")) {
1168       OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
1169     } else {
1170       num_params_ = 0;
1171     }
1172     if (context->HasAttr("num_params_weights")) {
1173       OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
1174                                                &num_params_weights_));
1175     } else {
1176       num_params_weights_ = 0;
1177     }
1178     if (context->HasAttr("num_params_biases")) {
1179       OP_REQUIRES_OK(
1180           context, context->GetAttr("num_params_biases", &num_params_biases_));
1181     } else {
1182       num_params_biases_ = 0;
1183     }
1184     if (context->HasAttr("num_proj")) {
1185       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1186     } else {
1187       num_proj_ = 0;
1188     }
1189     if (num_proj_ == 0) {
1190       num_params_weights_ = num_params_;
1191       num_params_biases_ = num_params_;
1192     }
1193   }
1194 
1195   void Compute(OpKernelContext* context) override {
1196     const Tensor& input = context->input(3);
1197     auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
1198     Stream* stream = context->op_device_context()->stream();
1199 
1200     std::unique_ptr<RnnDescriptor> rnn_desc;
1201     OP_REQUIRES_OK(context,
1202                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1203     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1204     CHECK(params_size_in_bytes % sizeof(T) == 0)
1205         << "params_size_in_bytes must be a multiple of the element size";
1206 
1207     const Tensor* num_units_t = nullptr;
1208     OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
1209     CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
1210         << "num_units is not a scalar";
1211     int num_units = num_units_t->scalar<int>()();
1212 
1213     const Tensor* input_size_t = nullptr;
1214     OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
1215     CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
1216         << "input_size is not a scalar";
1217     int input_size = input_size_t->scalar<int>()();
1218 
1219     const Tensor* num_layers_t = nullptr;
1220     OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
1221     CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
1222         << "num_layers is not a scalar";
1223     int num_layers = num_layers_t->scalar<int>()();
1224     int num_dirs = 1;
1225     if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
1226       num_dirs = 2;
1227     }
1228     const int num_params_weights_per_layer =
1229         num_params_weights_ / num_layers / num_dirs;
1230     // Number of params applied to the inputs. The rest are applied to the
1231     // recurrent hidden states.
1232     const int num_params_input_state = num_params_weights_per_layer / 2;
1233     OP_REQUIRES(
1234         context, num_params_weights_ % (num_layers * num_dirs) == 0,
1235         errors::InvalidArgument("Number of params (weights) is not a multiple "
1236                                 "of num_layers * num_dirs."));
1237     OP_REQUIRES(
1238         context, num_params_biases_ % (num_layers * num_dirs) == 0,
1239         errors::InvalidArgument("Number of params (biases) is not a multiple "
1240                                 "of num_layers * num_dirs."));
1241     if (num_proj_ == 0) {
1242       OP_REQUIRES(
1243           context, num_params_weights_per_layer % 2 == 0,
1244           errors::InvalidArgument("Number of params (weights) per layer is not "
1245                                   "an even number with no projection."));
1246     } else {
1247       OP_REQUIRES(
1248           context, num_params_weights_per_layer % 2 != 0,
1249           errors::InvalidArgument("Number of params (weights) per layer is not "
1250                                   "an odd number with projection."));
1251     }
1252 
1253     OP_REQUIRES(
1254         context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(),
1255         errors::InvalidArgument("Number of params (weights) mismatch. Expected ",
1256                                 num_params_weights_, ", got ",
1257                                 rnn_desc->ParamsWeightRegions().size()));
1258     int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
1259     int c_num_units = (num_proj_ == 0 ? 0 : num_units);
1260     for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
1261       int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
1262       int64 size = size_in_bytes / sizeof(T);
1263       const int layer_idx = i / num_params_weights_per_layer;
1264       const int index_within_layer = i % num_params_weights_per_layer;
1265       int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
1266       // In the cuDNN layout, each layer has num_params_weights_per_layer
1267       // params, with the first half (a.k.a. num_params_input_state params)
1268       // applied to the inputs, and the second half applied to the recurrent
1269       // hidden states.
1270       bool apply_on_input_state = index_within_layer < num_params_input_state;
1271       if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
1272         if (layer_idx == 0 && apply_on_input_state) {
1273           width = input_size;
1274         } else {
1275           width = h_num_units;
1276         }
1277       } else {
1278         if (apply_on_input_state) {
1279           if (layer_idx <= 1) {
1280             // First forward or backward layer.
1281             width = input_size;
1282           } else {
1283             // In the following layers, cell inputs are the concatenated
1284             // outputs of the prior layer.
1285             width = 2 * h_num_units;
1286           }
1287         } else {
1288           width = h_num_units;
1289         }
1290       }
1291       CHECK(size == width * height) << "Params size mismatch. Expected "
1292                                     << width * height << ", got " << size;
1293       Tensor* output = nullptr;
1294       int id_in_layer = i % num_params_weights_per_layer;
1295       if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) {
1296         std::swap(height, width);
1297       }
1298       OP_REQUIRES_OK(context, context->allocate_output(
1299                                   i, TensorShape({height, width}), &output));
1300       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1301           input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
1302       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1303       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1304     }
1305 
1306     OP_REQUIRES(
1307         context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
1308         errors::InvalidArgument("Number of params (biases) mismatch. Expected ",
1309                                 num_params_biases_, ", got ",
1310                                 rnn_desc->ParamsBiasRegions().size()));
1311     for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
1312       int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
1313       int64 size = size_in_bytes / sizeof(T);
1314       OP_REQUIRES(context, size == num_units,
1315                   errors::InvalidArgument("Params size mismatch. Expected ",
1316                                           num_units, ", got ", size));
1317 
1318       Tensor* output = nullptr;
1319       OP_REQUIRES_OK(context,
1320                      context->allocate_output(num_params_weights_ + i,
1321                                               TensorShape({size}), &output));
1322       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1323           input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
1324       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1325       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1326     }
1327   }
1328 
1329  private:
1330   int num_params_;
1331   int num_params_weights_;
1332   int num_params_biases_;
1333   int num_proj_;
1334 };
1335 
1336 #define REGISTER_GPU(T)                                     \
1337   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
1338                               .Device(DEVICE_GPU)           \
1339                               .HostMemory("num_layers")     \
1340                               .HostMemory("num_units")      \
1341                               .HostMemory("input_size")     \
1342                               .TypeConstraint<T>("T"),      \
1343                           CudnnRNNParamsToCanonical<GPUDevice, T>);
1344 TF_CALL_half(REGISTER_GPU);
1345 TF_CALL_float(REGISTER_GPU);
1346 TF_CALL_double(REGISTER_GPU);
1347 #undef REGISTER_GPU
1348 
1349 #define REGISTER_GPU(T)                                       \
1350   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
1351                               .Device(DEVICE_GPU)             \
1352                               .HostMemory("num_layers")       \
1353                               .HostMemory("num_units")        \
1354                               .HostMemory("input_size")       \
1355                               .TypeConstraint<T>("T"),        \
1356                           CudnnRNNParamsToCanonical<GPUDevice, T>);
1357 TF_CALL_half(REGISTER_GPU);
1358 TF_CALL_float(REGISTER_GPU);
1359 TF_CALL_double(REGISTER_GPU);
1360 #undef REGISTER_GPU
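
// The loops above emit the canonical regions in cuDNN order. As a rough guide
// for the common case (unidirectional LSTM, linear input mode, no projection):
// layer 0 yields four input weight matrices of shape {num_units, input_size}
// followed by four recurrent weight matrices of shape {num_units, num_units};
// deeper layers use {num_units, num_units} for both halves; biases follow as
// eight {num_units} vectors per layer. A sketch of the expected layer-0 weight
// shapes, valid only under those assumptions:
//
//   std::vector<TensorShape> ExpectedLstmLayer0WeightShapes(int input_size,
//                                                           int num_units) {
//     std::vector<TensorShape> shapes;
//     for (int i = 0; i < 4; ++i) {
//       shapes.push_back(TensorShape({num_units, input_size}));
//     }
//     for (int i = 0; i < 4; ++i) {
//       shapes.push_back(TensorShape({num_units, num_units}));
//     }
//     return shapes;
//   }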
1361 
1362 // Convert weight and bias params from the canonical form to a
1363 // platform-specific layout.
1364 template <typename T>
1365 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
1366  public:
1367   explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
1368       : CudnnRNNKernelCommon(context) {
1369     if (context->HasAttr("num_proj")) {
1370       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1371     } else {
1372       num_proj_ = 0;
1373     }
1374   }
1375 
1376   void Compute(OpKernelContext* context) override {
1377     std::unique_ptr<RnnDescriptor> rnn_desc;
1378     OP_REQUIRES_OK(context,
1379                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1380     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1381     CHECK(params_size_in_bytes % sizeof(T) == 0)
1382         << "params_size_in_bytes must be a multiple of the element size";
1383     Tensor* output = nullptr;
1384     int params_size = params_size_in_bytes / sizeof(T);
1385     OP_REQUIRES_OK(context,
1386                    context->allocate_output(0, {params_size}, &output));
1387     auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1388     Stream* stream = context->op_device_context()->stream();
1389 
1390     OpInputList weights;
1391     OP_REQUIRES_OK(context, context->input_list("weights", &weights));
1392     RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
1393                      stream);
1394 
1395     OpInputList biases;
1396     OP_REQUIRES_OK(context, context->input_list("biases", &biases));
1397     RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
1398                      stream);
1399   }
1400 
1401  private:
1402   int num_proj_;
1403 };
1404 
1405 #define REGISTER_GPU(T)                                     \
1406   REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
1407                               .Device(DEVICE_GPU)           \
1408                               .HostMemory("num_layers")     \
1409                               .HostMemory("num_units")      \
1410                               .HostMemory("input_size")     \
1411                               .TypeConstraint<T>("T"),      \
1412                           CudnnRNNCanonicalToParams<GPUDevice, T>);
1413 TF_CALL_half(REGISTER_GPU);
1414 TF_CALL_float(REGISTER_GPU);
1415 TF_CALL_double(REGISTER_GPU);
1416 #undef REGISTER_GPU
1417 
1418 #define REGISTER_GPU(T)                                       \
1419   REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
1420                               .Device(DEVICE_GPU)             \
1421                               .HostMemory("num_layers")       \
1422                               .HostMemory("num_units")        \
1423                               .HostMemory("input_size")       \
1424                               .TypeConstraint<T>("T"),        \
1425                           CudnnRNNCanonicalToParams<GPUDevice, T>);
1426 TF_CALL_half(REGISTER_GPU);
1427 TF_CALL_float(REGISTER_GPU);
1428 TF_CALL_double(REGISTER_GPU);
1429 #undef REGISTER_GPU
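
// CudnnRNNCanonicalToParams is the inverse of CudnnRNNParamsToCanonical: it
// packs the canonical tensors back into the opaque cuDNN buffer, so a round
// trip through both ops should reproduce the original buffer on the same GPU
// and cuDNN version. A minimal sketch of what RestoreParams is assumed to do
// (the name and signature here are illustrative only; it mirrors the
// slice-and-memcpy pattern used in CudnnRNNParamsToCanonical above):
//
//   template <typename T>
//   void RestoreParamsSketch(const OpInputList& params,
//                            const RnnDescriptor::ParamsRegions& regions,
//                            DeviceMemoryBase* opaque_buffer, Stream* stream) {
//     for (int i = 0; i < params.size(); ++i) {
//       auto src = StreamExecutorUtil::AsDeviceMemory<T>(params[i]);
//       DeviceMemoryBase dst = SliceDeviceMemory(
//           *opaque_buffer, regions[i].offset, regions[i].size);
//       stream->ThenMemcpy(&dst, src, regions[i].size);
//     }
//   }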
1430 
1431 // Run the forward operation of the RNN model.
1432 template <typename T>
1433 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1434  public:
1435   explicit CudnnRNNForwardOp(OpKernelConstruction* context)
1436       : CudnnRNNKernelCommon(context) {
1437     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1438 
1439     // Read debug env variables.
1440     is_debug_mode_ = DebugCudnnRnn();
1441     debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
1442     debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
1443   }
1444 
1445   void Compute(OpKernelContext* context) override {
1446     AlgorithmConfig algo_config;
1447     ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
1448                               /*time_major=*/true, /*num_proj=*/0);
1449   }
1450 
1451  protected:
1452   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
1453                                          AlgorithmConfig* output_algo_config,
1454                                          bool var_seq_lengths, bool time_major,
1455                                          int num_proj) {
1456     CHECK_NE(output_algo_config, nullptr);
1457 
1458     const Tensor* input = nullptr;
1459     const Tensor* input_h = nullptr;
1460     const Tensor* input_c = nullptr;
1461     const Tensor* params = nullptr;
1462     const Tensor* sequence_lengths = nullptr;
1463     CudnnRnnModelShapes model_shapes;
1464     bool use_padded_io = false;
1465     if (var_seq_lengths) {
1466       OP_REQUIRES_OK(context, ExtractForwardInput(
1467                                   context, model_types(), time_major, &input,
1468                                   &input_h, &input_c, &params,
1469                                   &sequence_lengths, num_proj, &model_shapes));
1470       use_padded_io =
1471           ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1472     } else {
1473       OP_REQUIRES_OK(context,
1474                      ExtractForwardInput(context, model_types(), time_major,
1475                                          &input, &input_h, &input_c, &params,
1476                                          num_proj, &model_shapes));
1477     }
1478     RnnInputMode input_mode;
1479     OP_REQUIRES_OK(context,
1480                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1481                                   model_shapes.input_size, &input_mode));
1482 
1483     Tensor* output = nullptr;
1484     Tensor* output_h = nullptr;
1485     Tensor* output_c = nullptr;
1486     OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
1487                                             &output_h, &output_c));
1488 
1489     // Creates a memory callback for the reserve_space. The memory lives in the
1490     // output of this kernel, and it will be fed into the backward pass when
1491     // needed.
1492     CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
1493     // Creates a memory callback for the workspace. The memory lives until the
1494     // end of this kernel call.
1495     CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1496 
1497     if (is_debug_mode_) {
1498       AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
1499       output_algo_config->set_algorithm(algo_desc);
1500     } else {
1501       OP_REQUIRES_OK(context,
1502                      MaybeAutoTune(context, model_shapes, input_mode, input,
1503                                    input_h, input_c, params, output, output_h,
1504                                    output_c, output_algo_config));
1505     }
1506 
1507     Status launch_status;
1508     {
1509       mutex_lock l(mu_);
1510       RnnDescriptor* rnn_desc_ptr = nullptr;
1511       OP_REQUIRES_OK(context,
1512                      GetCachedRnnDescriptor<T>(
1513                          context, model_shapes, input_mode, *output_algo_config,
1514                          &rnn_state_cache_, &rnn_desc_ptr, use_padded_io));
1515       launch_status = DoForward<T>(
1516           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1517           input_c, params, is_training_, output, output_h, output_c,
1518           sequence_lengths, time_major, &reserve_space_allocator,
1519           &workspace_allocator, /*output_profile_result=*/nullptr);
1520     }
1521     OP_REQUIRES_OK(context, launch_status);
1522   }
1523 
1524  protected:
1525   virtual Status MaybeAutoTune(OpKernelContext* context,
1526                                const CudnnRnnModelShapes& model_shapes,
1527                                const RnnInputMode& input_mode,
1528                                const Tensor* input, const Tensor* input_h,
1529                                const Tensor* input_c, const Tensor* params,
1530                                Tensor* output, Tensor* output_h,
1531                                Tensor* output_c,
1532                                AlgorithmConfig* best_algo_config) {
1533     CHECK_NE(best_algo_config, nullptr);
1534     *best_algo_config = AlgorithmConfig();
1535     return Status::OK();
1536   }
1537 
1538   bool is_training() const { return is_training_; }
1539   bool is_debug_mode_;
1540   bool debug_use_tensor_ops_;
1541   int64 debug_cudnn_rnn_algo_;
1542 
1543  private:
1544   Status AllocateOutputs(OpKernelContext* context,
1545                          const CudnnRnnModelShapes& model_shapes,
1546                          Tensor** output, Tensor** output_h,
1547                          Tensor** output_c) {
1548     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1549     const TensorShape& output_shape = model_shapes.output_shape;
1550     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
1551 
1552     TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
1553     TF_RETURN_IF_ERROR(
1554         context->allocate_output(1, hidden_state_shape, output_h));
1555     if (HasInputC()) {
1556       TF_RETURN_IF_ERROR(
1557           context->allocate_output(2, cell_state_shape, output_c));
1558     } else {
1559       // Only LSTM uses input_c and output_c. So for all other models, we only
1560       // need to create dummy outputs.
1561       TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
1562     }
1563     if (!is_training_) {
1564       Tensor* dummy_reserve_space = nullptr;
1565       TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
1566     }
1567     return Status::OK();
1568   }
1569 
1570   mutex mu_;
1571   bool is_training_;
1572   RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1573 };
1574 
1575 #define REGISTER_GPU(T)                                           \
1576   REGISTER_KERNEL_BUILDER(                                        \
1577       Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1578       CudnnRNNForwardOp<GPUDevice, T>);
1579 
1580 TF_CALL_half(REGISTER_GPU);
1581 TF_CALL_float(REGISTER_GPU);
1582 TF_CALL_double(REGISTER_GPU);
1583 #undef REGISTER_GPU
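
// Output contract of the forward op (see AllocateOutputs above): output 0 is
// the RNN output with model_shapes.output_shape, output 1 is output_h with
// hidden_state_shape, output 2 is output_c with cell_state_shape for LSTM and
// a dummy empty tensor for other cell types, and output 3 is the reserve_space
// that is only materialized when is_training is true; in inference mode it is
// a dummy empty tensor and no extra memory is held for backprop.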
1584 
1585 template <typename T>
1586 class CudnnRNNForwardOpV2<GPUDevice, T>
1587     : public CudnnRNNForwardOp<GPUDevice, T> {
1588  private:
1589   using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1590   using CudnnRNNKernelCommon::CreateRnnDescriptor;
1591   using CudnnRNNKernelCommon::dropout;
1592   using CudnnRNNKernelCommon::HasInputC;
1593   using CudnnRNNKernelCommon::model_types;
1594 
1595  public:
1596   explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
1597       : CudnnRNNForwardOp<GPUDevice, T>(context) {}
1598 
1599   void Compute(OpKernelContext* context) override {
1600     AlgorithmConfig best_algo_config;
1601     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1602         context, &best_algo_config, /*var_seq_lengths=*/false,
1603         /*time_major=*/true, /*num_proj=*/0);
1604     if (!context->status().ok()) {
1605       return;
1606     }
1607 
1608     Tensor* output_host_reserved = nullptr;
1609     // output_host_reserved stores opaque info used for backprop when running
1610     // in training mode. At present, it includes a serialization of the best
1611     // AlgorithmDesc picked during the RNN forward-pass autotune:
1612     //   int8 algorithm_id
1613     //   int8 use_tensor_op
1614     // If autotune is not enabled, the algorithm_id is
1615     // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
1616     // running in inference mode, the output_host_reserved is currently not
1617     // populated.
1618     if (is_training()) {
1619       OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
1620                                                        &output_host_reserved));
1621       auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
1622       output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
1623       output_host_reserved_int8(1) =
1624           best_algo_config.algorithm()->tensor_ops_enabled();
1625     } else {
1626       OP_REQUIRES_OK(context,
1627                      context->allocate_output(4, {}, &output_host_reserved));
1628     }
1629   }
1630 
1631  protected:
1632   Status MaybeAutoTune(OpKernelContext* context,
1633                        const CudnnRnnModelShapes& model_shapes,
1634                        const RnnInputMode& input_mode, const Tensor* input,
1635                        const Tensor* input_h, const Tensor* input_c,
1636                        const Tensor* params, Tensor* output, Tensor* output_h,
1637                        Tensor* output_c,
1638                        AlgorithmConfig* algo_config) override {
1639     CHECK_NE(algo_config, nullptr);
1640     if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
1641       *algo_config = AlgorithmConfig();
1642       return Status::OK();
1643     }
1644 
1645     std::vector<AlgorithmDesc> algorithms;
1646     auto* stream = context->op_device_context()->stream();
1647     CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
1648     if (algorithms.empty()) {
1649       LOG(WARNING) << "No Rnn algorithm found";
1650       return Status::OK();
1651     }
1652 
1653     const auto& modeltypes = model_types();
1654     CudnnRnnParameters rnn_params(
1655         model_shapes.num_layers, model_shapes.input_size,
1656         model_shapes.num_units, model_shapes.max_seq_length,
1657         model_shapes.batch_size, model_shapes.dir_count,
1658         /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
1659         modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
1660 
1661     if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
1662       VLOG(1) << "Using existing best Cudnn RNN algorithm "
1663               << "(algo, tensor_op_enabled) = ("
1664               << algo_config->algorithm()->algo_id() << ", "
1665               << algo_config->algorithm()->tensor_ops_enabled() << ").";
1666       return Status::OK();
1667     }
1668 
1669     // Create temp tensors when profiling backprop pass.
1670     auto data_type = input->dtype();
1671     Tensor output_backprop;
1672     Tensor output_h_backprop;
1673     Tensor output_c_backprop;
1674     Tensor input_backprop;
1675     Tensor input_h_backprop;
1676     Tensor input_c_backprop;
1677     Tensor params_backprop;
1678     if (is_training()) {
1679       TF_RETURN_IF_ERROR(context->allocate_temp(
1680           data_type, model_shapes.output_shape, &output_backprop));
1681       TF_RETURN_IF_ERROR(context->allocate_temp(
1682           data_type, model_shapes.hidden_state_shape, &output_h_backprop));
1683 
1684       TF_RETURN_IF_ERROR(
1685           context->allocate_temp(data_type, params->shape(), &params_backprop));
1686       TF_RETURN_IF_ERROR(context->allocate_temp(
1687           data_type, model_shapes.input_shape, &input_backprop));
1688       TF_RETURN_IF_ERROR(context->allocate_temp(
1689           data_type, model_shapes.hidden_state_shape, &input_h_backprop));
1690       if (HasInputC()) {
1691         TF_RETURN_IF_ERROR(context->allocate_temp(
1692             data_type, model_shapes.hidden_state_shape, &output_c_backprop));
1693         TF_RETURN_IF_ERROR(context->allocate_temp(
1694             data_type, model_shapes.hidden_state_shape, &input_c_backprop));
1695       }
1696     }
1697     ProfileResult best_result;
1698     for (auto& algo : algorithms) {
1699       VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1700               << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
1701       Status status;
1702       ProfileResult final_profile_result;
1703 
1704       ProfileResult fwd_profile_result;
1705       ProfileResult bak_profile_result;
1706 
1707       // RnnDescriptor is algorithm-dependent, thus not reusable.
1708       std::unique_ptr<RnnDescriptor> rnn_desc;
1709       // Use a temp scratch allocator for the random number generator.
1710       CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
1711       if (!this->template CreateRnnDescriptor<T>(
1712                    context, model_shapes, input_mode, AlgorithmConfig(algo),
1713                    &dropout_state_allocator, &rnn_desc,
1714                    /*use_padded_io=*/false)
1715                .ok()) {
1716         continue;
1717       }
1718 
1719       // Again use temp scratch allocator during profiling.
1720       CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
1721       CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1722       status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
1723                             input, input_h, input_c, params, is_training(),
1724                             output, output_h, output_c, nullptr, true,
1725                             &reserve_space_allocator, &workspace_allocator,
1726                             &fwd_profile_result);
1727       if (!status.ok()) {
1728         continue;
1729       }
1730 
1731       if (is_training()) {
1732         // Get reserve space from the forward pass.
1733         Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
1734         status = DoBackward<T>(
1735             context, *rnn_desc, model_types(), model_shapes, input, input_h,
1736             input_c, params, output, output_h, output_c, &output_backprop,
1737             &output_h_backprop, &output_c_backprop, &reserve_space,
1738             &input_backprop, &input_h_backprop, &input_c_backprop,
1739             &params_backprop, nullptr, true, &workspace_allocator,
1740             &bak_profile_result);
1741         if (!status.ok()) {
1742           continue;
1743         }
1744         final_profile_result.set_elapsed_time_in_ms(
1745             fwd_profile_result.elapsed_time_in_ms() +
1746             bak_profile_result.elapsed_time_in_ms());
1747       } else {
1748         final_profile_result = fwd_profile_result;
1749       }
1750 
1751       auto total_time = final_profile_result.elapsed_time_in_ms();
1752       VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1753               << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
1754               << " run time: " << total_time << " ms.";
1755       if (total_time < best_result.elapsed_time_in_ms()) {
1756         best_result.set_elapsed_time_in_ms(total_time);
1757         best_result.set_algorithm(algo);
1758       }
1759     }
1760 
1761     if (!best_result.is_valid()) {
1762       return Status(error::Code::INTERNAL, "No algorithm worked!");
1763     }
1764     algo_config->set_algorithm(best_result.algorithm());
1765     VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1766             << best_result.algorithm().algo_id() << ", "
1767             << best_result.algorithm().tensor_ops_enabled() << ").";
1768     AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
1769     return Status::OK();
1770   }
1771 };
1772 
1773 #define REGISTER_GPU(T)                                    \
1774   REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2")               \
1775                               .Device(DEVICE_GPU)          \
1776                               .HostMemory("host_reserved") \
1777                               .TypeConstraint<T>("T"),     \
1778                           CudnnRNNForwardOpV2<GPUDevice, T>);
1779 
1780 TF_CALL_half(REGISTER_GPU);
1781 TF_CALL_float(REGISTER_GPU);
1782 TF_CALL_double(REGISTER_GPU);
1783 #undef REGISTER_GPU
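
// Sketch of the host_reserved handshake between CudnnRNNV2 and
// CudnnRNNBackpropV2 (it mirrors Compute() above and GetAlgorithm() below;
// the helper names are hypothetical):
//
//   // Forward pass (training): serialize the autotuned algorithm.
//   void EncodeHostReserved(const AlgorithmConfig& config, Tensor* t) {
//     auto v = t->vec<int8>();  // t has TensorShape({2}).
//     v(0) = config.algorithm()->algo_id();
//     v(1) = config.algorithm()->tensor_ops_enabled();
//   }
//
//   // Backward pass: reconstruct the same algorithm descriptor.
//   AlgorithmDesc DecodeHostReserved(const Tensor& t) {
//     auto v = t.vec<int8>();
//     return AlgorithmDesc(v(0), v(1));
//   }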
1784 
1785 template <typename T>
1786 class CudnnRNNForwardOpV3<GPUDevice, T>
1787     : public CudnnRNNForwardOp<GPUDevice, T> {
1788  private:
1789   using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1790   using CudnnRNNKernelCommon::CreateRnnDescriptor;
1791   using CudnnRNNKernelCommon::dropout;
1792   using CudnnRNNKernelCommon::HasInputC;
1793   using CudnnRNNKernelCommon::model_types;
1794   bool time_major_;
1795 
1796  protected:
1797   bool time_major() { return time_major_; }
1798 
1799  public:
1800   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
1801       : CudnnRNNForwardOp<GPUDevice, T>(context) {
1802     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
1803     if (context->HasAttr("num_proj")) {
1804       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1805     } else {
1806       num_proj_ = 0;
1807     }
1808   }
1809 
1810   void Compute(OpKernelContext* context) override {
1811     AlgorithmConfig best_algo_config;
1812     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1813         context, &best_algo_config, /*var_seq_lengths=*/true,
1814         /*time_major=*/time_major(), num_proj_);
1815     if (!context->status().ok()) {
1816       return;
1817     }
1818 
1819     Tensor* output_host_reserved = nullptr;
1820     // TODO: Currently V3 only uses the default standard algorithm to process
1821     // batches with variable sequence lengths, and the inputs should be padded.
1822     // Autotune is not supported yet.
1823     OP_REQUIRES_OK(context,
1824                    context->allocate_output(4, {}, &output_host_reserved));
1825   }
1826 
1827  private:
1828   int num_proj_;
1829 };
1830 
1831 #define REGISTER_GPU(T)                                       \
1832   REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3")                  \
1833                               .Device(DEVICE_GPU)             \
1834                               .HostMemory("sequence_lengths") \
1835                               .HostMemory("host_reserved")    \
1836                               .TypeConstraint<T>("T"),        \
1837                           CudnnRNNForwardOpV3<GPUDevice, T>);
1838 
1839 TF_CALL_half(REGISTER_GPU);
1840 TF_CALL_float(REGISTER_GPU);
1841 TF_CALL_double(REGISTER_GPU);
1842 #undef REGISTER_GPU
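
// For CudnnRNNV3, sequence_lengths is a host-resident vector of per-batch
// sequence lengths (see HostMemory("sequence_lengths") above), and the inputs
// are expected to be padded to the longest sequence in the batch. Roughly, for
// an LSTM with time_major=true, max_seq_length=5, batch_size=2, input_size=8,
// num_units=16 and sequence_lengths={5, 3}, the shapes are approximately:
//
//   input:   {5, 2, 8}    // time-major; batch-major would be {2, 5, 8}
//   input_h: {num_layers * dir_count, 2, 16}
//   input_c: {num_layers * dir_count, 2, 16}
//
// The authoritative shapes are computed by ExtractForwardInput; whether the
// padded-I/O cuDNN path is taken is decided by ShouldUsePaddedIO above.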
1843 
1844 // Run the backward operation of the RNN model.
1845 template <typename T>
1846 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1847  public:
1848   explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
1849       : CudnnRNNKernelCommon(context) {}
1850 
1851   void Compute(OpKernelContext* context) override {
1852     ComputeImpl(context, false, true, 0);
1853   }
1854 
1855  protected:
1856   virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
1857                            bool time_major, int num_proj) {
1858     const Tensor* input = nullptr;
1859     const Tensor* input_h = nullptr;
1860     const Tensor* input_c = nullptr;
1861     const Tensor* params = nullptr;
1862     const Tensor* sequence_lengths = nullptr;
1863     CudnnRnnModelShapes model_shapes;
1864     bool use_padded_io = false;
1865     if (var_seq_lengths) {
1866       OP_REQUIRES_OK(context, ExtractForwardInput(
1867                                   context, model_types(), time_major, &input,
1868                                   &input_h, &input_c, &params,
1869                                   &sequence_lengths, num_proj, &model_shapes));
1870       use_padded_io =
1871           ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1872     } else {
1873       OP_REQUIRES_OK(context,
1874                      ExtractForwardInput(context, model_types(), time_major,
1875                                          &input, &input_h, &input_c, &params,
1876                                          num_proj, &model_shapes));
1877     }
1878     RnnInputMode input_mode;
1879     OP_REQUIRES_OK(context,
1880                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1881                                   model_shapes.input_size, &input_mode));
1882 
1883     const Tensor* output = nullptr;
1884     const Tensor* output_h = nullptr;
1885     const Tensor* output_c = nullptr;
1886     const Tensor* output_backprop = nullptr;
1887     const Tensor* output_h_backprop = nullptr;
1888     const Tensor* output_c_backprop = nullptr;
1889     const Tensor* reserve_space = nullptr;
1890     OP_REQUIRES_OK(context,
1891                    ExtractBackwardInputs(context, model_shapes, model_types(),
1892                                          &output, &output_h, &output_c,
1893                                          &output_backprop, &output_h_backprop,
1894                                          &output_c_backprop, &reserve_space));
1895 
1896     Tensor* input_backprop = nullptr;
1897     Tensor* input_h_backprop = nullptr;
1898     Tensor* input_c_backprop = nullptr;
1899     Tensor* params_backprop = nullptr;
1900     OP_REQUIRES_OK(context,
1901                    AllocateOutputs(context, model_shapes, params->shape(),
1902                                    &input_backprop, &input_h_backprop,
1903                                    &input_c_backprop, &params_backprop));
1904 
1905     // Creates a memory callback for the workspace. The memory lives until the
1906     // end of this kernel call.
1907     CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1908     AlgorithmConfig algo_config;
1909     OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
1910     Status launch_status;
1911     {
1912       mutex_lock l(mu_);
1913       RnnDescriptor* rnn_desc_ptr = nullptr;
1914       OP_REQUIRES_OK(
1915           context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
1916                                              algo_config, &rnn_state_cache_,
1917                                              &rnn_desc_ptr, use_padded_io));
1918       launch_status = DoBackward<T>(
1919           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1920           input_c, params, output, output_h, output_c, output_backprop,
1921           output_h_backprop, output_c_backprop, reserve_space, input_backprop,
1922           input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
1923           time_major, &workspace_allocator,
1924           /*output_profile_result=*/nullptr);
1925     }
1926     OP_REQUIRES_OK(context, launch_status);
1927   }
1928 
1929  protected:
1930   virtual Status GetAlgorithm(OpKernelContext* context,
1931                               AlgorithmConfig* algo_config) {
1932     CHECK_NE(algo_config, nullptr);
1933     *algo_config = AlgorithmConfig();
1934     return Status::OK();
1935   }
1936 
1937  private:
1938   mutex mu_;
1939   RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1940 
1941   Status ExtractBackwardInputs(
1942       OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
1943       const CudnnModelTypes& model_types, const Tensor** output,
1944       const Tensor** output_h, const Tensor** output_c,
1945       const Tensor** output_backprop, const Tensor** output_h_backprop,
1946       const Tensor** output_c_backprop, const Tensor** reserve_space) {
1947     TF_RETURN_IF_ERROR(context->input("output", output));
1948     TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
1949     TF_RETURN_IF_ERROR(context->input("output_h", output_h));
1950     TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
1951     if (model_types.HasInputC()) {
1952       TF_RETURN_IF_ERROR(context->input("output_c", output_c));
1953       TF_RETURN_IF_ERROR(
1954           context->input("output_c_backprop", output_c_backprop));
1955     }
1956     TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
1957     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1958     const TensorShape& output_shape = model_shapes.output_shape;
1959     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
1960 
1961     if (output_shape != (*output)->shape()) {
1962       return errors::InvalidArgument(
1963           "Invalid output shape: ", (*output)->shape().DebugString(), " ",
1964           output_shape.DebugString());
1965     }
1966     if (hidden_state_shape != (*output_h)->shape()) {
1967       return errors::InvalidArgument(
1968           "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
1969           hidden_state_shape.DebugString());
1970     }
1971 
1972     if (output_shape != (*output_backprop)->shape()) {
1973       return errors::InvalidArgument("Invalid output_backprop shape: ",
1974                                      (*output_backprop)->shape().DebugString(),
1975                                      " ", output_shape.DebugString());
1976     }
1977     if (hidden_state_shape != (*output_h_backprop)->shape()) {
1978       return errors::InvalidArgument(
1979           "Invalid output_h_backprop shape: ",
1980           (*output_h_backprop)->shape().DebugString(), " ",
1981           hidden_state_shape.DebugString());
1982     }
1983 
1984     if (model_types.HasInputC()) {
1985       if (cell_state_shape != (*output_c)->shape()) {
1986         return errors::InvalidArgument(
1987             "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
1988             cell_state_shape.DebugString());
1989       }
1990       if (cell_state_shape != (*output_c_backprop)->shape()) {
1991         return errors::InvalidArgument(
1992             "Invalid output_c_backprop shape: ",
1993             (*output_c_backprop)->shape().DebugString(), " ",
1994             cell_state_shape.DebugString());
1995       }
1996     }
1997     return Status::OK();
1998   }
1999 
2000   Status AllocateOutputs(OpKernelContext* context,
2001                          const CudnnRnnModelShapes& model_shapes,
2002                          const TensorShape& params_shape,
2003                          Tensor** input_backprop, Tensor** input_h_backprop,
2004                          Tensor** input_c_backprop, Tensor** params_backprop) {
2005     const TensorShape& input_shape = model_shapes.input_shape;
2006     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2007     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2008 
2009     TF_RETURN_IF_ERROR(
2010         context->allocate_output(0, input_shape, input_backprop));
2011     TF_RETURN_IF_ERROR(
2012         context->allocate_output(1, hidden_state_shape, input_h_backprop));
2013     if (HasInputC()) {
2014       TF_RETURN_IF_ERROR(
2015           context->allocate_output(2, cell_state_shape, input_c_backprop));
2016     } else {
2017       // Only LSTM uses input_c and output_c. So for all other models, we only
2018       // need to create dummy outputs.
2019       TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
2020     }
2021     TF_RETURN_IF_ERROR(
2022         context->allocate_output(3, params_shape, params_backprop));
2023     return Status::OK();
2024   }
2025 };
2026 
2027 #define REGISTER_GPU(T)                                                   \
2028   REGISTER_KERNEL_BUILDER(                                                \
2029       Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
2030       CudnnRNNBackwardOp<GPUDevice, T>);
2031 
2032 TF_CALL_half(REGISTER_GPU);
2033 TF_CALL_float(REGISTER_GPU);
2034 TF_CALL_double(REGISTER_GPU);
2035 #undef REGISTER_GPU
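
// Gradient output contract of the backward op (see AllocateOutputs above):
// output 0 is input_backprop with the input shape, output 1 is
// input_h_backprop with hidden_state_shape, output 2 is input_c_backprop with
// cell_state_shape for LSTM (a dummy empty tensor otherwise), and output 3 is
// params_backprop with the same shape as the opaque params input.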
2036 
2037 template <typename T>
2038 class CudnnRNNBackwardOpV2<GPUDevice, T>
2039     : public CudnnRNNBackwardOp<GPUDevice, T> {
2040  public:
2041   explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
2042       : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
2043 
2044  protected:
2045   Status GetAlgorithm(OpKernelContext* context,
2046                       AlgorithmConfig* algo_config) override {
2047     CHECK_NE(algo_config, nullptr);
2048     const Tensor* host_reserved = nullptr;
2049     TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
2050 
2051     auto host_reserved_int8 = host_reserved->vec<int8>();
2052     const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
2053     algo_config->set_algorithm(algo_desc);
2054     return Status::OK();
2055   }
2056 };
2057 
2058 #define REGISTER_GPU(T)                                    \
2059   REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2")       \
2060                               .Device(DEVICE_GPU)          \
2061                               .HostMemory("host_reserved") \
2062                               .TypeConstraint<T>("T"),     \
2063                           CudnnRNNBackwardOpV2<GPUDevice, T>);
2064 
2065 TF_CALL_half(REGISTER_GPU);
2066 TF_CALL_float(REGISTER_GPU);
2067 TF_CALL_double(REGISTER_GPU);
2068 #undef REGISTER_GPU
2069 
2070 template <typename T>
2071 class CudnnRNNBackwardOpV3<GPUDevice, T>
2072     : public CudnnRNNBackwardOp<GPUDevice, T> {
2073  private:
2074   bool time_major_;
2075 
2076  protected:
2077   bool time_major() { return time_major_; }
2078 
2079  public:
2080   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
2081       : CudnnRNNBackwardOp<GPUDevice, T>(context) {
2082     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
2083     if (context->HasAttr("num_proj")) {
2084       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
2085     } else {
2086       num_proj_ = 0;
2087     }
2088   }
2089 
2090   void Compute(OpKernelContext* context) override {
2091     CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(),
2092                                                   num_proj_);
2093   }
2094 
2095  private:
2096   int num_proj_;
2097 };
2098 
2099 #define REGISTER_GPU(T)                                       \
2100   REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3")          \
2101                               .Device(DEVICE_GPU)             \
2102                               .HostMemory("sequence_lengths") \
2103                               .HostMemory("host_reserved")    \
2104                               .TypeConstraint<T>("T"),        \
2105                           CudnnRNNBackwardOpV3<GPUDevice, T>);
2106 
2107 TF_CALL_half(REGISTER_GPU);
2108 TF_CALL_float(REGISTER_GPU);
2109 TF_CALL_double(REGISTER_GPU);
2110 #undef REGISTER_GPU
2111 
2112 // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
2113 // its canonical form.
2114 
2115 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2116 
2117 }  // namespace tensorflow
2118