1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16
17 #include <stddef.h>
18
19 #include <atomic>
20 #include <cmath>
21 #include <functional>
22 #include <limits>
23 #include <string>
24 #include <unordered_set>
25
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/framework/device_base.h"
28 #include "tensorflow/core/framework/kernel_def_builder.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_def_builder.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/register_types.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/tensor_types.h"
36 #include "tensorflow/core/framework/types.h"
37 #include "tensorflow/core/kernels/gpu_utils.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/lib/core/stringpiece.h"
41 #include "tensorflow/core/lib/gtl/inlined_vector.h"
42 #include "tensorflow/core/lib/hash/hash.h"
43 #include "tensorflow/core/lib/strings/stringprintf.h"
44 #include "tensorflow/core/platform/fingerprint.h"
45 #include "tensorflow/core/platform/mutex.h"
46 #include "tensorflow/core/platform/types.h"
47 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
48 #include "tensorflow/core/util/env_var.h"
49 #include "tensorflow/core/util/use_cudnn.h"
50
51 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
52 #include "tensorflow/core/platform/stream_executor.h"
53 #include "tensorflow/core/util/stream_executor_util.h"
54 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
55
56 /*
57 * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
58 * using the underlying Cudnn library.
59 *
60 * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
61 * layout and format, and a buffer saved on one GPU is very likely unusable on
62 * a different GPU. Users therefore need to first query the size of the opaque
63 * parameter buffer and convert it to and from its canonical form, while each
64 * actual training step is carried out with the opaque parameter buffer.
65 *
66 * Similar to many other ops, the forward op has two flavors: training and
67 * inference. When training is specified, additional data in reserve_space will
68 * be produced for the backward pass. So there is a performance penalty.
69 *
70 * In addition to the actual data and reserve_space, Cudnn also needs more
71 * memory as a temporary workspace. Memory management to and from
72 * stream-executor is done through ScratchAllocator: in general,
73 * stream-executor is responsible for requesting memory of the proper size,
74 * and TensorFlow is responsible for making sure the memory stays alive long
75 * enough and is recycled afterwards.
76 *
77 */
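// A rough sketch of the intended call sequence at the op level (illustrative
// only; the exact op names and attribute sets depend on the op version used):
//   params_size <- CudnnRNNParamsSize(num_layers, num_units, input_size, ...)
//   params      <- CudnnRNNCanonicalToParams(..., weights, biases)
//   outputs     <- CudnnRNN(input, input_h, input_c, params, is_training=true)
//   gradients   <- CudnnRNNBackprop(..., reserve_space, ...)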
78 namespace tensorflow {
79
80 using CPUDevice = Eigen::ThreadPoolDevice;
81
82 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
83
84 using GPUDevice = Eigen::GpuDevice;
85 using se::Stream;
86 using se::StreamExecutor;
87 using se::dnn::RnnDescriptor;
88
89 template <typename Device, typename T, typename Index>
90 class CudnnRNNParamsSizeOp;
91
92 template <typename Device, typename T>
93 class CudnnRNNParamsToCanonical;
94
95 template <typename Device, typename T>
96 class CudnnRNNCanonicalToParams;
97
98 template <typename Device, typename T>
99 class CudnnRNNForwardOp;
100
101 template <typename Device, typename T>
102 class CudnnRNNBackwardOp;
103
104 template <typename Device, typename T>
105 class CudnnRNNForwardOpV2;
106
107 template <typename Device, typename T>
108 class CudnnRNNBackwardOpV2;
109
110 template <typename Device, typename T>
111 class CudnnRNNForwardOpV3;
112
113 template <typename Device, typename T>
114 class CudnnRNNBackwardOpV3;
115
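// TF-level RNN input modes. kAutoSelect defers the choice between linear and
// skip input to ToRNNInputMode() below, based on whether input_size matches
// num_units.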
116 enum class TFRNNInputMode {
117 kRNNLinearInput = 0,
118 kRNNSkipInput = 1,
119 kAutoSelect = 9999999
120 };
121
122 namespace {
123 using se::DeviceMemory;
124 using se::DeviceMemoryBase;
125 using se::ScratchAllocator;
126 using se::dnn::AlgorithmConfig;
127 using se::dnn::AlgorithmDesc;
128 using se::dnn::ProfileResult;
129 using se::dnn::RnnDirectionMode;
130 using se::dnn::RnnInputMode;
131 using se::dnn::RnnMode;
132 using se::dnn::RnnSequenceTensorDescriptor;
133 using se::dnn::RnnStateTensorDescriptor;
134 using se::dnn::ToDataType;
135 using se::port::StatusOr;
136
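// Combines a list of ints into a single hash value with Hash64Combine, seeded
// by the first element; an empty list hashes to 0.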
137 uint64 HashList(const std::vector<int>& list) {
138 if (list.empty()) {
139 return 0;
140 }
141 uint64 hash_code = list[0];
142 for (int i = 1; i < list.size(); i++) {
143 hash_code = Hash64Combine(hash_code, list[i]);
144 }
145 return hash_code;
146 }
147
148 // Encapsulates all the shape information that is used in both forward and
149 // backward RNN operations.
150 class CudnnRnnParameters {
151 public:
152 CudnnRnnParameters(int num_layers, int input_size, int num_units,
153 int max_seq_length, int batch_size, int dir_count,
154 bool has_dropout, bool is_training, RnnMode rnn_mode,
155 TFRNNInputMode rnn_input_mode, DataType dtype)
156 : num_layers_(num_layers),
157 input_size_(input_size),
158 num_units_(num_units),
159 seq_length_(max_seq_length),
160 batch_size_(batch_size),
161 dir_count_(dir_count),
162 has_dropout_(has_dropout),
163 is_training_(is_training),
164 rnn_mode_(rnn_mode),
165 rnn_input_mode_(rnn_input_mode),
166 dtype_(dtype) {
167 hash_code_ =
168 HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
169 dir_count, static_cast<int>(has_dropout),
170 static_cast<int>(is_training), static_cast<int>(rnn_mode),
171 static_cast<int>(rnn_input_mode), dtype});
172 }
173
174 bool operator==(const CudnnRnnParameters& other) const {
175 return this->get_data_as_tuple() == other.get_data_as_tuple();
176 }
177
178 bool operator!=(const CudnnRnnParameters& other) const {
179 return !(*this == other);
180 }
181 uint64 hash() const { return hash_code_; }
182
183 string ToString() const {
184 std::vector<string> fields = {
185 std::to_string(num_layers_),
186 std::to_string(input_size_),
187 std::to_string(num_units_),
188 std::to_string(seq_length_),
189 std::to_string(batch_size_),
190 std::to_string(dir_count_),
191 std::to_string(has_dropout_),
192 std::to_string(is_training_),
193 std::to_string(static_cast<int>(rnn_mode_)),
194 std::to_string(static_cast<int>(rnn_input_mode_)),
195 std::to_string(static_cast<int>(dtype_))};
196 return absl::StrJoin(fields, ", ");
197 }
198
199 private:
200 using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
201 RnnMode, TFRNNInputMode, DataType>;
202
203 ParameterDataType get_data_as_tuple() const {
204 return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
205 batch_size_, dir_count_, has_dropout_, is_training_,
206 rnn_mode_, rnn_input_mode_, dtype_);
207 }
208
209 const int num_layers_;
210 const int input_size_;
211 const int num_units_;
212 const int seq_length_;
213 const int batch_size_;
214 const int dir_count_;
215 const bool has_dropout_;
216 const bool is_training_;
217 const RnnMode rnn_mode_;
218 const TFRNNInputMode rnn_input_mode_;
219 const DataType dtype_;
220 uint64 hash_code_;
221 };
222
223 struct RnnAutotuneGroup {
224 static string name() { return "Rnn"; }
225 };
226
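// Process-wide singleton map from CudnnRnnParameters to the autotuned
// AlgorithmConfig, shared by all Cudnn RNN kernels in this process.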
227 using AutotuneRnnConfigMap =
228 AutotuneSingleton<RnnAutotuneGroup, CudnnRnnParameters, AlgorithmConfig>;
229
230 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
231 if (str == "rnn_relu") {
232 *rnn_mode = RnnMode::kRnnRelu;
233 return Status::OK();
234 } else if (str == "rnn_tanh") {
235 *rnn_mode = RnnMode::kRnnTanh;
236 return Status::OK();
237 } else if (str == "lstm") {
238 *rnn_mode = RnnMode::kRnnLstm;
239 return Status::OK();
240 } else if (str == "gru") {
241 *rnn_mode = RnnMode::kRnnGru;
242 return Status::OK();
243 }
244 return errors::InvalidArgument("Invalid RNN mode: ", str);
245 }
246
247 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
248 if (str == "linear_input") {
249 *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
250 return Status::OK();
251 } else if (str == "skip_input") {
252 *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
253 return Status::OK();
254 } else if (str == "auto_select") {
255 *rnn_input_mode = TFRNNInputMode::kAutoSelect;
256 return Status::OK();
257 }
258 return errors::InvalidArgument("Invalid RNN input mode: ", str);
259 }
260
261 Status ParseRNNDirectionMode(const string& str,
262 RnnDirectionMode* rnn_dir_mode) {
263 if (str == "unidirectional") {
264 *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
265 return Status::OK();
266 } else if (str == "bidirectional") {
267 *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
268 return Status::OK();
269 }
270 return errors::InvalidArgument("Invalid RNN direction mode: ", str);
271 }
272
273 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
274 int input_size, RnnInputMode* input_mode) {
275 switch (tf_input_mode) {
276 case TFRNNInputMode::kRNNLinearInput:
277 *input_mode = RnnInputMode::kRnnLinearSkip;
278 break;
279 case TFRNNInputMode::kRNNSkipInput:
280 *input_mode = RnnInputMode::kRnnSkipInput;
281 break;
282 case TFRNNInputMode::kAutoSelect:
283 *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
284 : RnnInputMode::kRnnLinearSkip;
285 break;
286 default:
287 return errors::InvalidArgument("Invalid TF input mode: ",
288 static_cast<int>(tf_input_mode));
289 }
290 return Status::OK();
291 }
292
293 // TODO(zhengxq): Merge those into stream_executor_util.h.
294 template <typename T>
295 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
296 return DeviceMemory<T>::MakeFromByteSize(
297 const_cast<T*>(tensor->template flat<T>().data()),
298 tensor->template flat<T>().size() * sizeof(T));
299 }
300
301 template <typename T>
302 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
303 return DeviceMemory<T>::MakeFromByteSize(
304 tensor->template flat<T>().data(),
305 tensor->template flat<T>().size() * sizeof(T));
306 }
307
308 template <typename U, typename T>
309 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
310 return DeviceMemory<U>::MakeFromByteSize(
311 tensor->template flat<T>().data(),
312 tensor->template flat<T>().size() * sizeof(T));
313 }
314
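// Returns a non-owning view of the byte range [offset, offset + size) inside
// device_memory.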
315 DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
316 int64_t offset, int64_t size) {
317 const void* base_ptr = device_memory.opaque();
318 void* offset_ptr =
319 const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
320 CHECK(offset + size <= device_memory.size())
321 << "The slice is not within the region of DeviceMemory.";
322 return DeviceMemoryBase(offset_ptr, size);
323 }
324
325 inline Status FromExecutorStatus(const se::port::Status& s) {
326 return s.ok() ? Status::OK()
327 : Status(static_cast<error::Code>(static_cast<int>(s.code())),
328 s.error_message());
329 }
330
331 template <typename T>
332 inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
333 return FromExecutorStatus(s.status());
334 }
335
336 inline se::port::Status ToExecutorStatus(const Status& s) {
337 return s.ok() ? se::port::Status::OK()
338 : se::port::Status(static_cast<se::port::error::Code>(
339 static_cast<int>(s.code())),
340 s.error_message());
341 }
342
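// Compile-time map from a C++ element type to the corresponding TensorFlow
// DataType enum value, used when allocating typed temporary tensors below.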
343 template <typename>
344 struct ToTFDataType;
345
346 template <>
347 struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
348
349 template <>
350 struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
351
352 template <>
353 struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
354
355 template <>
356 struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
357
358 // A helper to allocate temporary scratch memory for Cudnn RNN models. It
359 // takes ownership of the underlying memory. The expectation is that the
360 // memory stays alive for the span of the Cudnn RNN call itself.
361 template <typename T>
362 class CudnnRnnAllocatorInTemp : public ScratchAllocator {
363 public:
364 ~CudnnRnnAllocatorInTemp() override = default;
365
366 explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
367 : context_(context) {}
368 int64 GetMemoryLimitInBytes() override {
369 return std::numeric_limits<int64>::max();
370 }
371
372 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
373 Tensor temporary_memory;
374 const DataType tf_data_type = ToTFDataType<T>::value;
375 int64_t allocate_count =
376 Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
377 Status allocation_status(context_->allocate_temp(
378 tf_data_type, TensorShape({allocate_count}), &temporary_memory));
379 if (!allocation_status.ok()) {
380 return ToExecutorStatus(allocation_status);
381 }
382 // Hold a reference to the allocated tensors until the allocator itself is
383 // destroyed.
384 allocated_tensors_.push_back(temporary_memory);
385 total_byte_size_ += byte_size;
386 return DeviceMemory<uint8>::MakeFromByteSize(
387 temporary_memory.template flat<T>().data(),
388 temporary_memory.template flat<T>().size() * sizeof(T));
389 }
390
391 int64 TotalByteSize() const { return total_byte_size_; }
392
393 Tensor get_allocated_tensor(int index) const {
394 return allocated_tensors_[index];
395 }
396
397 private:
398 int64 total_byte_size_ = 0;
399 OpKernelContext* context_; // not owned
400 std::vector<Tensor> allocated_tensors_;
401 };
402
403 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
404 // used by the forward pass kernel to feed its output to the backward pass.
405 // The memory is expected to remain alive until the backward pass has
406 // finished.
407 template <typename T>
408 class CudnnRnnAllocatorInOutput : public ScratchAllocator {
409 public:
410 ~CudnnRnnAllocatorInOutput() override {}
411 CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
412 : context_(context), output_index_(output_index) {}
413 int64 GetMemoryLimitInBytes() override {
414 return std::numeric_limits<int64>::max();
415 }
416 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
417 CHECK(total_byte_size_ == 0)
418 << "Reserve space allocator can only be called once";
419 int64_t allocate_count =
420 Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
421
422 Tensor* temporary_memory = nullptr;
423 Status allocation_status(context_->allocate_output(
424 output_index_, TensorShape({allocate_count}), &temporary_memory));
425 if (!allocation_status.ok()) {
426 return ToExecutorStatus(allocation_status);
427 }
428 total_byte_size_ += byte_size;
429 auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
430 temporary_memory->template flat<T>().data(),
431 temporary_memory->template flat<T>().size() * sizeof(T));
432 return StatusOr<DeviceMemory<uint8>>(memory_uint8);
433 }
434 int64 TotalByteSize() { return total_byte_size_; }
435
436 private:
437 int64 total_byte_size_ = 0;
438 OpKernelContext* context_; // not owned
439 int output_index_;
440 };
441
442 // A helper to allocate memory for Cudnn RNN models, which is
443 // expected to live between kernel invocations.
444 // This class is not thread-safe.
445 class CudnnRNNSpaceAllocator : public ScratchAllocator {
446 public:
447 explicit CudnnRNNSpaceAllocator(OpKernelContext* context)
448 : context_(context) {}
449
450 ~CudnnRNNSpaceAllocator() override {}
451
452 int64 GetMemoryLimitInBytes() override {
453 return std::numeric_limits<int64>::max();
454 }
455
456 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
457 if (total_byte_size_ != 0) {
458 return Status(error::FAILED_PRECONDITION,
459 "Space allocator can only be called once");
460 }
461
462 Status allocation_status =
463 context_->allocate_temp(DT_UINT8, TensorShape({byte_size}), &tensor_);
464 if (!allocation_status.ok()) {
465 return ToExecutorStatus(allocation_status);
466 }
467 total_byte_size_ += byte_size;
468 return AsDeviceMemory<uint8>(&tensor_);
469 }
470 int64 TotalByteSize() { return total_byte_size_; }
471
472 private:
473 int64 total_byte_size_ = 0;
474 Tensor tensor_;
475 OpKernelContext* context_; // not owned
476 };
477
478 struct CudnnModelTypes {
479 RnnMode rnn_mode;
480 TFRNNInputMode rnn_input_mode;
481 RnnDirectionMode rnn_direction_mode;
482 bool HasInputC() const {
483 // For Cudnn 5.0, only LSTM has input-c. All other models use only
484 // input-h.
485 return rnn_mode == RnnMode::kRnnLstm;
486 }
487
488 string DebugString() const {
489 return strings::Printf(
490 "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
491 static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
492 static_cast<int>(rnn_direction_mode));
493 }
494 };
495
496 // A helper class that collects the shapes that describe an RNN model.
497 struct CudnnRnnModelShapes {
498 int num_layers;
499 int input_size;
500 int num_units;
501 int dir_count;
502 int max_seq_length;
503 int batch_size;
504 int cell_num_units = 0;
505 // If you add a new field to this structure, please take care of
506 // updating IsCompatibleWith() below as well as the hash function in
507 // CudnnRnnConfigHasher.
508 TensorShape input_shape;
509 TensorShape output_shape;
510 TensorShape hidden_state_shape;
511 TensorShape cell_state_shape;
512 // At present, only the fields relevant to the cached RnnDescriptor are compared.
513 bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
514 return num_layers == rhs.num_layers && input_size == rhs.input_size &&
515 num_units == rhs.num_units && dir_count == rhs.dir_count &&
516 cell_num_units == rhs.cell_num_units &&
517 max_seq_length == rhs.max_seq_length;
518 }
519 string DebugString() const {
520 return strings::Printf(
521 "[num_layers, input_size, num_units, dir_count, max_seq_length, "
522 "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
523 num_layers, input_size, num_units, dir_count, max_seq_length,
524 batch_size, cell_num_units);
525 }
526 };
527
528 // Utility class for using a CudnnRnnModelShapes and AlgorithmDesc pair as a
529 // hash table key.
530 struct CudnnRnnConfigHasher {
531 uint64 operator()(
532 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
533 to_hash) const {
534 auto& shapes = to_hash.first;
535 auto& algo_desc = to_hash.second;
536
537 uint64 hash =
538 HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
539 shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
540 if (algo_desc.has_value()) {
541 hash = Hash64Combine(hash, algo_desc->hash());
542 }
543 return hash;
544 }
545 };
546
547 // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash
548 // table key.
549 struct CudnnRnnConfigComparator {
550 bool operator()(
551 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs,
552 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs)
553 const {
554 return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
555 }
556 };
557
558 // Pointers to RNN scratch space for a specific set of shape parameters (used as
559 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
560 struct RnnScratchSpace {
561 std::unique_ptr<RnnDescriptor> rnn_desc;
562 std::unique_ptr<CudnnRNNSpaceAllocator> dropout_state_allocator;
563 };
564
565 // Extracts and checks the forward input tensors, parameters, and shapes from
566 // the OpKernelContext.
567 Status ExtractForwardInput(OpKernelContext* context,
568 const CudnnModelTypes& model_types, bool time_major,
569 const Tensor** input, const Tensor** input_h,
570 const Tensor** input_c, const Tensor** params,
571 const int num_proj,
572 CudnnRnnModelShapes* model_shapes) {
573 TF_RETURN_IF_ERROR(context->input("input", input));
574 TF_RETURN_IF_ERROR(context->input("input_h", input_h));
575 if (model_types.HasInputC()) {
576 TF_RETURN_IF_ERROR(context->input("input_c", input_c));
577 }
578 TF_RETURN_IF_ERROR(context->input("params", params));
579
580 if ((*input)->dims() != 3) {
581 return errors::InvalidArgument("RNN input must be a 3-D vector.");
582 }
583 if (time_major) {
584 model_shapes->max_seq_length = (*input)->dim_size(0);
585 model_shapes->batch_size = (*input)->dim_size(1);
586 } else {
587 model_shapes->max_seq_length = (*input)->dim_size(1);
588 model_shapes->batch_size = (*input)->dim_size(0);
589 }
590 model_shapes->input_size = (*input)->dim_size(2);
591 model_shapes->input_shape = (*input)->shape();
592 model_shapes->dir_count =
593 (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
594 ? 2
595 : 1;
596
597 if ((*input_h)->dims() != 3) {
598 return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
599 }
600 if (time_major) {
601 model_shapes->num_layers =
602 (*input_h)->dim_size(0) / model_shapes->dir_count;
603 } else {
604 model_shapes->num_layers =
605 (*input_h)->dim_size(1) / model_shapes->dir_count;
606 }
607 model_shapes->num_units = (*input_h)->dim_size(2);
608
609 if (time_major) {
610 model_shapes->hidden_state_shape =
611 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
612 model_shapes->batch_size, model_shapes->num_units});
613 } else {
614 model_shapes->hidden_state_shape =
615 TensorShape({model_shapes->batch_size,
616 model_shapes->dir_count * model_shapes->num_layers,
617 model_shapes->num_units});
618 }
619 if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
620 return errors::InvalidArgument(
621 "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
622 model_shapes->hidden_state_shape.DebugString());
623 }
624 if (model_types.HasInputC()) {
625 model_shapes->cell_num_units = (*input_c)->dim_size(2);
626 if (time_major) {
627 model_shapes->cell_state_shape =
628 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
629 model_shapes->batch_size, model_shapes->cell_num_units});
630 } else {
631 model_shapes->cell_state_shape =
632 TensorShape({model_shapes->batch_size,
633 model_shapes->dir_count * model_shapes->num_layers,
634 model_shapes->cell_num_units});
635 }
636 if (num_proj == 0) {
637 if ((*input_h)->shape() != (*input_c)->shape()) {
638 return errors::InvalidArgument(
639 "input_h and input_c must have the same shape w/o projection: ",
640 (*input_h)->shape().DebugString(), " ",
641 (*input_c)->shape().DebugString());
642 }
643 } else {
644 if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
645 num_proj != (*input_h)->dim_size(2) ||
646 (*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
647 (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
648 return errors::InvalidArgument(
649 "Invalid input_h and input_c w/ projection size: ", num_proj, " ",
650 (*input_h)->shape().DebugString(), " ",
651 (*input_c)->shape().DebugString());
652 }
653 }
654 } else {
655 // Dummy cell_state_shape. TODO(kaixih): remove the time_major branch.
656 if (time_major) {
657 model_shapes->cell_state_shape =
658 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
659 model_shapes->batch_size, model_shapes->num_units});
660 } else {
661 model_shapes->cell_state_shape =
662 TensorShape({model_shapes->batch_size,
663 model_shapes->dir_count * model_shapes->num_layers,
664 model_shapes->num_units});
665 }
666 model_shapes->cell_num_units = 0;
667 }
668 if (time_major) {
669 model_shapes->output_shape =
670 TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
671 model_shapes->dir_count * model_shapes->num_units});
672 } else {
673 model_shapes->output_shape =
674 TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
675 model_shapes->dir_count * model_shapes->num_units});
676 }
677 return Status::OK();
678 }
679
680 // Overload that additionally extracts the sequence_lengths input.
681 Status ExtractForwardInput(OpKernelContext* context,
682 const CudnnModelTypes& model_types, bool time_major,
683 const Tensor** input, const Tensor** input_h,
684 const Tensor** input_c, const Tensor** params,
685 const Tensor** sequence_lengths, const int num_proj,
686 CudnnRnnModelShapes* model_shapes) {
687 TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
688 return ExtractForwardInput(context, model_types, time_major, input, input_h,
689 input_c, params, num_proj, model_shapes);
690 }
691
692 template <typename T>
693 Status CreateForwardAndBackwardIODescriptors(
694 OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
695 std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
696 std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
697 std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
698 std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
699 const absl::Span<const int> seq_lengths, bool time_major) {
700 StreamExecutor* executor = context->op_device_context()->stream()->parent();
701 se::dnn::DataType data_type = ToDataType<T>::value;
702
703 const TensorShape& input_shape = model_shapes.input_shape;
704 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
705 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
706 const TensorShape& output_shape = model_shapes.output_shape;
707
708 DCHECK_EQ(input_shape.dims(), 3);
709 if (seq_lengths.data() != nullptr) {
710 if (time_major) {
711 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
712 input_shape.dim_size(0), input_shape.dim_size(1),
713 input_shape.dim_size(2), seq_lengths, time_major, data_type);
714 TF_RETURN_IF_ERROR(input_desc_s.status());
715 *input_desc = input_desc_s.ConsumeValueOrDie();
716 } else {
717 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
718 input_shape.dim_size(1), input_shape.dim_size(0),
719 input_shape.dim_size(2), seq_lengths, time_major, data_type);
720 TF_RETURN_IF_ERROR(input_desc_s.status());
721 *input_desc = input_desc_s.ConsumeValueOrDie();
722 }
723 } else {
724 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
725 input_shape.dim_size(0), input_shape.dim_size(1),
726 input_shape.dim_size(2), data_type);
727 TF_RETURN_IF_ERROR(input_desc_s.status());
728 *input_desc = input_desc_s.ConsumeValueOrDie();
729 }
730
731 DCHECK_EQ(hidden_state_shape.dims(), 3);
732 if (time_major) {
733 auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
734 hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
735 hidden_state_shape.dim_size(2), data_type);
736 TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
737 *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
738 } else {
739 auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
740 hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
741 hidden_state_shape.dim_size(2), data_type);
742 TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
743 *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
744 }
745
746 DCHECK_EQ(cell_state_shape.dims(), 3);
747 if (time_major) {
748 auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
749 cell_state_shape.dim_size(0), cell_state_shape.dim_size(1),
750 cell_state_shape.dim_size(2), data_type);
751 TF_RETURN_IF_ERROR(cell_state_desc_s.status());
752 *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
753 } else {
754 auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
755 cell_state_shape.dim_size(1), cell_state_shape.dim_size(0),
756 cell_state_shape.dim_size(2), data_type);
757 TF_RETURN_IF_ERROR(cell_state_desc_s.status());
758 *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
759 }
760
761 DCHECK_EQ(output_shape.dims(), 3);
762 if (seq_lengths.data() != nullptr) {
763 if (time_major) {
764 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
765 output_shape.dim_size(0), output_shape.dim_size(1),
766 output_shape.dim_size(2), seq_lengths, time_major, data_type);
767 TF_RETURN_IF_ERROR(output_desc_s.status());
768 *output_desc = output_desc_s.ConsumeValueOrDie();
769 } else {
770 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
771 output_shape.dim_size(1), output_shape.dim_size(0),
772 output_shape.dim_size(2), seq_lengths, time_major, data_type);
773 TF_RETURN_IF_ERROR(output_desc_s.status());
774 *output_desc = output_desc_s.ConsumeValueOrDie();
775 }
776 } else {
777 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
778 output_shape.dim_size(0), output_shape.dim_size(1),
779 output_shape.dim_size(2), data_type);
780 TF_RETURN_IF_ERROR(output_desc_s.status());
781 *output_desc = output_desc_s.ConsumeValueOrDie();
782 }
783
784 return Status::OK();
785 }
786
787 template <typename T>
788 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
789 const CudnnModelTypes& model_types,
790 const CudnnRnnModelShapes& model_shapes,
791 /* forward inputs */
792 const Tensor* input, const Tensor* input_h,
793 const Tensor* input_c, const Tensor* params,
794 const bool is_training,
795 /* forward outputs, outputs of the function */
796 Tensor* output, Tensor* output_h, Tensor* output_c,
797 const Tensor* sequence_lengths, bool time_major,
798 ScratchAllocator* reserve_space_allocator,
799 ScratchAllocator* workspace_allocator,
800 ProfileResult* output_profile_result) {
801 std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
802 std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
803 std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
804 std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
805
806 absl::Span<const int> seq_lengths;
807 if (sequence_lengths != nullptr) {
808 seq_lengths = absl::Span<const int>(
809 sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
810 }
811 TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
812 context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
813 &output_desc, seq_lengths, time_major));
814
815 auto input_data = AsDeviceMemory<T>(input);
816 auto input_h_data = AsDeviceMemory<T>(input_h);
817 DeviceMemory<T> input_c_data;
818 if (model_types.HasInputC()) {
819 input_c_data = AsDeviceMemory<T>(input_c);
820 }
821
822 auto params_data = AsDeviceMemory<T>(params);
823 auto output_data = AsDeviceMemory<T>(output);
824 auto output_h_data = AsDeviceMemory<T>(output_h);
825 DeviceMemory<T> output_c_data;
826 if (model_types.HasInputC()) {
827 output_c_data = AsDeviceMemory<T>(output_c);
828 }
829
830 Stream* stream = context->op_device_context()->stream();
831
832 Tensor seq_lengths_tensor;
833 DeviceMemory<int> seq_lengths_ptr;
834 if (sequence_lengths != nullptr) {
835 TF_RETURN_IF_ERROR(context->allocate_temp(
836 DT_INT32, {static_cast<long>(seq_lengths.size())},
837 &seq_lengths_tensor));
838 seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
839 if (!stream
840 ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(),
841 seq_lengths.size() * sizeof(int))
842 .ok()) {
843 return errors::InvalidArgument(
844 "Failed to copy memory from host to "
845 "device for sequence_lengths in "
846 "CudnnRNNV3");
847 }
848 }
849
850 bool launch_success =
851 stream
852 ->ThenRnnForward(rnn_desc, *input_desc, input_data, seq_lengths_ptr,
853 *h_state_desc, input_h_data, *c_state_desc,
854 input_c_data, params_data, *output_desc,
855 &output_data, *h_state_desc, &output_h_data,
856 *c_state_desc, &output_c_data, is_training,
857 reserve_space_allocator, workspace_allocator,
858 output_profile_result)
859 .ok();
860 return launch_success
861 ? Status::OK()
862 : errors::Internal(
863 "Failed to call ThenRnnForward with model config: ",
864 model_types.DebugString(), ", ", model_shapes.DebugString());
865 }
866
867 template <typename T>
868 Status DoBackward(
869 OpKernelContext* context, const RnnDescriptor& rnn_desc,
870 const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
871 /* forward inputs */
872 const Tensor* input, const Tensor* input_h, const Tensor* input_c,
873 const Tensor* params,
874 /* forward outputs */
875 const Tensor* output, const Tensor* output_h, const Tensor* output_c,
876 /* backprop inputs */
877 const Tensor* output_backprop, const Tensor* output_h_backprop,
878 const Tensor* output_c_backprop, const Tensor* reserve_space,
879 /* backprop outputs, output of the function */
880 Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
881 Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
882 ScratchAllocator* workspace_allocator,
883 ProfileResult* output_profile_result) {
884 std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
885 std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
886 std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
887 std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
888
889 absl::Span<const int> seq_lengths;
890 if (sequence_lengths != nullptr) {
891 seq_lengths = absl::Span<const int>(
892 sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
893 }
894 TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
895 context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
896 &output_desc, seq_lengths, time_major));
897
898 auto input_data = AsDeviceMemory<T>(input);
899 auto input_h_data = AsDeviceMemory<T>(input_h);
900 DeviceMemory<T> input_c_data;
901 if (model_types.HasInputC()) {
902 input_c_data = AsDeviceMemory<T>(input_c);
903 }
904 auto params_data = AsDeviceMemory<T>(params);
905 auto output_data = AsDeviceMemory<T>(output);
906 auto output_h_data = AsDeviceMemory<T>(output_h);
907 DeviceMemory<T> output_c_data;
908 if (model_types.HasInputC()) {
909 output_c_data = AsDeviceMemory<T>(output_c);
910 }
911 auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
912 auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
913 DeviceMemory<T> output_c_backprop_data;
914 if (model_types.HasInputC()) {
915 output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
916 }
917 auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
918 auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
919 DeviceMemory<T> input_c_backprop_data;
920 if (model_types.HasInputC()) {
921 input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
922 }
923 auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
924 auto reserve_space_uint8 =
925 CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
926
927 // Creates a memory callback for the workspace. The memory lives until the
928 // end of this kernel call.
929 Stream* stream = context->op_device_context()->stream();
930
931 Tensor seq_lengths_tensor;
932 DeviceMemory<int> seq_lengths_ptr;
933 if (sequence_lengths != nullptr) {
934 TF_RETURN_IF_ERROR(context->allocate_temp(
935 DT_INT32, {static_cast<long>(seq_lengths.size())},
936 &seq_lengths_tensor));
937 seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
938 if (!stream
939 ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(),
940 seq_lengths.size() * sizeof(int))
941 .ok()) {
942 return errors::InvalidArgument(
943 "Failed to copy memory from host to "
944 "device for sequence_lengths in "
945 "CudnnRNNBackwardOpV3");
946 }
947 }
948
949 bool launch_success =
950 stream
951 ->ThenRnnBackward(
952 rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc,
953 input_h_data, *c_state_desc, input_c_data, params_data,
954 *output_desc, output_data, *h_state_desc, output_h_data,
955 *c_state_desc, output_c_data, output_backprop_data,
956 output_h_backprop_data, output_c_backprop_data,
957 &input_backprop_data, &input_h_backprop_data,
958 &input_c_backprop_data, ¶ms_backprop_data,
959 &reserve_space_uint8, workspace_allocator, output_profile_result)
960 .ok();
961 return launch_success
962 ? Status::OK()
963 : errors::Internal(
964 "Failed to call ThenRnnBackward with model config: ",
965 model_types.DebugString(), ", ", model_shapes.DebugString());
966 }
967
968 template <typename T>
969 void RestoreParams(const OpInputList params_input,
970 const std::vector<RnnDescriptor::ParamsRegion>& params,
971 DeviceMemoryBase* data_dst, Stream* stream) {
972 int num_params = params.size();
973 CHECK(params_input.size() == num_params)
974 << "Number of params mismatch. Expected " << params_input.size()
975 << ", got " << num_params;
976 for (int i = 0; i < params.size(); i++) {
977 int64_t size_in_bytes = params[i].size;
978 int64_t size = size_in_bytes / sizeof(T);
979 CHECK(size == params_input[i].NumElements())
980 << "Params size mismatch. Expected " << size << ", got "
981 << params_input[i].NumElements();
982 auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
983 DeviceMemoryBase data_dst_ptr =
984 SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
985 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
986 }
987 }
988
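// Returns true if the padded-IO (variable sequence length) Cudnn path should
// be used, i.e. whenever the input is not time-major or any sequence in the
// batch is shorter than max_seq_length.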
989 bool ShouldUsePaddedIO(const Tensor* sequence_lengths,
990 const CudnnRnnModelShapes& model_shapes,
991 bool time_major) {
992 auto seq_array = sequence_lengths->template flat<int>().data();
993 bool all_max_seq_length = true;
994 for (int i = 0; i < model_shapes.batch_size; i++) {
995 if (seq_array[i] != model_shapes.max_seq_length) {
996 all_max_seq_length = false;
997 break;
998 }
999 }
1000 return !(time_major && all_max_seq_length);
1001 }
1002
1003 } // namespace
1004
1005 // Note: all of the following kernels depend on an RnnDescriptor instance,
1006 // which, according to the official Cudnn documentation, should be kept around
1007 // and reused across all Cudnn kernels in the same model.
1008 // In TensorFlow, we don't pass the reference across different OpKernels;
1009 // rather, we recreate it separately in each OpKernel, which does not cause
1010 // issues: CudnnDropoutDescriptor keeps a reference to memory for the
1011 // random number generator state. During recreation, this state is lost.
1012 // However, only forward-pass Cudnn APIs make use of that state.
1013
1014 // A common base class for RNN kernels. It extracts common attributes and
1015 // performs shape validations.
1016 class CudnnRNNKernelCommon : public OpKernel {
1017 protected:
1018 explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
1019 : OpKernel(context) {
1020 OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
1021 OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
1022 OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
1023 string str;
1024 OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
1025 OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
1026 OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
1027 OP_REQUIRES_OK(context,
1028 ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
1029 OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
1030 OP_REQUIRES_OK(
1031 context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
1032 // Reset the CudnnRnnDescriptor and related random number generator state in
1033 // every Compute() call.
1034 OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
1035 false, &reset_rnd_gen_state_));
1036 }
1037
1038 bool HasInputC() const { return model_types_.HasInputC(); }
1039 RnnMode rnn_mode() const { return model_types_.rnn_mode; }
1040 TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
1041 RnnDirectionMode rnn_direction_mode() const {
1042 return model_types_.rnn_direction_mode;
1043 }
1044 const CudnnModelTypes& model_types() const { return model_types_; }
1045 float dropout() const { return dropout_; }
1046 uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
1047 bool ResetRndGenState() { return reset_rnd_gen_state_; }
1048
1049 template <typename T>
1050 Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
1051 std::unique_ptr<RnnDescriptor>* rnn_desc) {
1052 const Tensor* num_layers_t = nullptr;
1053 TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
1054 if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
1055 return errors::InvalidArgument("num_layers is not a scalar");
1056 }
1057 int num_layers = num_layers_t->scalar<int>()();
1058 const Tensor* num_units_t = nullptr;
1059 TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
1060 if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
1061 return errors::InvalidArgument("num_units is not a scalar");
1062 }
1063 int num_units = num_units_t->scalar<int>()();
1064 const Tensor* input_size_t = nullptr;
1065 TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
1066 if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
1067 return errors::InvalidArgument("input_size is not a scalar");
1068 }
1069 int input_size = input_size_t->scalar<int>()();
1070
1071 int h_num_units = (num_proj == 0 ? num_units : num_proj);
1072 int c_num_units = (num_proj == 0 ? 0 : num_units);
1073
1074 RnnInputMode input_mode;
1075 TF_RETURN_IF_ERROR(
1076 ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
1077
1078 Stream* stream = context->op_device_context()->stream();
1079 // ExtractCudnnRNNParamsInfo is only called by op kernels that do not require
1080 // a random number generator, therefore set state_allocator to nullptr.
1081 const AlgorithmConfig algo_config;
1082 auto rnn_desc_s = stream->parent()->createRnnDescriptor(
1083 num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
1084 /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
1085 ToDataType<T>::value, algo_config, dropout(), seed(),
1086 /* state_allocator=*/nullptr, /*use_padded_io=*/false);
1087 if (!rnn_desc_s.ok()) {
1088 return FromExecutorStatus(rnn_desc_s);
1089 }
1090 *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
1091 return Status::OK();
1092 }
1093
1094 template <typename T>
1095 Status CreateRnnDescriptor(OpKernelContext* context,
1096 const CudnnRnnModelShapes& model_shapes,
1097 const RnnInputMode& input_mode,
1098 const AlgorithmConfig& algo_config,
1099 ScratchAllocator* dropout_state_allocator,
1100 std::unique_ptr<RnnDescriptor>* rnn_desc,
1101 bool use_padded_io) {
1102 StreamExecutor* executor = context->op_device_context()->stream()->parent();
1103 se::dnn::DataType data_type = ToDataType<T>::value;
1104 auto rnn_desc_s = executor->createRnnDescriptor(
1105 model_shapes.num_layers, model_shapes.num_units,
1106 model_shapes.input_size, model_shapes.cell_num_units,
1107 model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
1108 data_type, algo_config, dropout(), seed(), dropout_state_allocator,
1109 use_padded_io);
1110 TF_RETURN_IF_ERROR(rnn_desc_s.status());
1111
1112 *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
1113 return Status::OK();
1114 }
1115
1116 using RnnStateCache = gtl::FlatMap<
1117 std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
1118 RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
1119 // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
1120 // should outlive the returned pointer.
1121 template <typename T>
1122 Status GetCachedRnnDescriptor(OpKernelContext* context,
1123 const CudnnRnnModelShapes& model_shapes,
1124 const RnnInputMode& input_mode,
1125 const AlgorithmConfig& algo_config,
1126 RnnStateCache* cache, RnnDescriptor** rnn_desc,
1127 bool use_padded_io) {
1128 auto key = std::make_pair(model_shapes, algo_config.algorithm());
1129 RnnScratchSpace& rnn_state = (*cache)[key];
1130 if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
1131 CudnnRNNSpaceAllocator* dropout_state_allocator =
1132 new CudnnRNNSpaceAllocator(context);
1133 rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
1134 Status status = CreateRnnDescriptor<T>(
1135 context, model_shapes, input_mode, algo_config,
1136 dropout_state_allocator, &rnn_state.rnn_desc, use_padded_io);
1137 TF_RETURN_IF_ERROR(status);
1138 }
1139 *rnn_desc = rnn_state.rnn_desc.get();
1140 return Status::OK();
1141 }
1142
1143 private:
1144 int seed_;
1145 int seed2_;
1146 float dropout_;
1147 bool reset_rnd_gen_state_;
1148
1149 CudnnModelTypes model_types_;
1150 };
1151
1152 // A class that returns the size of the opaque parameter buffer. The user should
1153 // use it to create the actual parameter buffer for training. However, it
1154 // should not be used for saving and restoring.
1155 template <typename T, typename Index>
1156 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
1157 public:
1158 explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
1159 : CudnnRNNKernelCommon(context) {
1160 if (context->HasAttr("num_proj")) {
1161 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1162 } else {
1163 num_proj_ = 0;
1164 }
1165 }
1166
1167 void Compute(OpKernelContext* context) override {
1168 std::unique_ptr<RnnDescriptor> rnn_desc;
1169 OP_REQUIRES_OK(context,
1170 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1171 int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1172 CHECK(params_size_in_bytes % sizeof(T) == 0)
1173 << "params_size_in_bytes must be multiple of element size";
1174 int64_t params_size = params_size_in_bytes / sizeof(T);
1175
1176 Tensor* output_t = nullptr;
1177 OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
1178 *output_t->template flat<Index>().data() = params_size;
1179 }
1180
1181 private:
1182 int num_proj_;
1183 };
1184
1185 #define REGISTER_GPU(T) \
1186 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \
1187 .Device(DEVICE_GPU) \
1188 .HostMemory("num_layers") \
1189 .HostMemory("num_units") \
1190 .HostMemory("input_size") \
1191 .HostMemory("params_size") \
1192 .TypeConstraint<T>("T") \
1193 .TypeConstraint<int32>("S"), \
1194 CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
1195
1196 TF_CALL_half(REGISTER_GPU);
1197 TF_CALL_float(REGISTER_GPU);
1198 TF_CALL_double(REGISTER_GPU);
1199 #undef REGISTER_GPU
1200
1201 // Convert weight and bias params from a platform-specific layout to the
1202 // canonical form.
1203 template <typename T>
1204 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
1205 public:
1206 explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
1207 : CudnnRNNKernelCommon(context) {
1208 if (context->HasAttr("num_params")) {
1209 OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
1210 } else {
1211 num_params_ = 0;
1212 }
1213 if (context->HasAttr("num_params_weights")) {
1214 OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
1215 &num_params_weights_));
1216 } else {
1217 num_params_weights_ = 0;
1218 }
1219 if (context->HasAttr("num_params_biases")) {
1220 OP_REQUIRES_OK(
1221 context, context->GetAttr("num_params_biases", &num_params_biases_));
1222 } else {
1223 num_params_biases_ = 0;
1224 }
1225 if (context->HasAttr("num_proj")) {
1226 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1227 } else {
1228 num_proj_ = 0;
1229 }
1230 if (num_proj_ == 0) {
1231 num_params_weights_ = num_params_;
1232 num_params_biases_ = num_params_;
1233 }
1234 }
1235
1236 void Compute(OpKernelContext* context) override {
1237 const Tensor& input = context->input(3);
1238 auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
1239 Stream* stream = context->op_device_context()->stream();
1240
1241 std::unique_ptr<RnnDescriptor> rnn_desc;
1242 OP_REQUIRES_OK(context,
1243 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1244 int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1245 CHECK(params_size_in_bytes % sizeof(T) == 0)
1246 << "params_size_in_bytes must be multiple of element size";
1247
1248 const Tensor* num_units_t = nullptr;
1249 OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
1250 CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
1251 << "num_units is not a scalar";
1252 int num_units = num_units_t->scalar<int>()();
1253
1254 const Tensor* input_size_t = nullptr;
1255 OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
1256 CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
1257 << "input_size is not a scalar";
1258 int input_size = input_size_t->scalar<int>()();
1259
1260 const Tensor* num_layers_t = nullptr;
1261 OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
1262 CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
1263 << "num_layers is not a scalar";
1264 int num_layers = num_layers_t->scalar<int>()();
1265 int num_dirs = 1;
1266 if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
1267 num_dirs = 2;
1268 }
1269 const int num_params_weights_per_layer =
1270 num_params_weights_ / num_layers / num_dirs;
1271 // Number of params applied on inputs. The rest are applied on recurrent
1272 // hidden states.
1273 const int num_params_input_state = num_params_weights_per_layer / 2;
1274 OP_REQUIRES(
1275 context, num_params_weights_ % (num_layers * num_dirs) == 0,
1276 errors::InvalidArgument("Number of params (weights) is not a multiple"
1277 "of num_layers * num_dirs."));
1278 OP_REQUIRES(
1279 context, num_params_biases_ % (num_layers * num_dirs) == 0,
1280 errors::InvalidArgument("Number of params (biases) is not a multiple"
1281 "of num_layers * num_dirs."));
1282 if (num_proj_ == 0) {
1283 OP_REQUIRES(
1284 context, num_params_weights_per_layer % 2 == 0,
1285 errors::InvalidArgument("Number of params (weights) per layer is not"
1286 "an even number with no projection."));
1287 } else {
1288 OP_REQUIRES(
1289 context, num_params_weights_per_layer % 2 != 0,
1290 errors::InvalidArgument("Number of params (weights) per layer is not"
1291 "an odl number with projection."));
1292 }
1293
1294 OP_REQUIRES(
1295 context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(),
1296 errors::InvalidArgument("C Number of params mismatch. Expected ",
1297 num_params_weights_, ", got ",
1298 rnn_desc->ParamsWeightRegions().size()));
1299 int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
1300 int c_num_units = (num_proj_ == 0 ? 0 : num_units);
1301 for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
1302 int64_t size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
1303 int64_t size = size_in_bytes / sizeof(T);
1304 const int layer_idx = i / num_params_weights_per_layer;
1305 const int index_within_layer = i % num_params_weights_per_layer;
1306 int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
1307 // In the CuDNN layout, each layer has num_params_weights_per_layer params,
1308 // with the first half (a.k.a. num_params_input_state params) applied on
1309 // the inputs, and the second half applied on the recurrent hidden
1310 // states.
1311 bool apply_on_input_state = index_within_layer < num_params_input_state;
1312 if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
1313 if (layer_idx == 0 && apply_on_input_state) {
1314 width = input_size;
1315 } else {
1316 width = h_num_units;
1317 }
1318 } else {
1319 if (apply_on_input_state) {
1320 if (layer_idx <= 1) {
1321 // First fwd or bak layer.
1322 width = input_size;
1323 } else {
1324 // Following layers, cell inputs are concatenated outputs of
1325 // its prior layer.
1326 width = 2 * h_num_units;
1327 }
1328 } else {
1329 width = h_num_units;
1330 }
1331 }
1332 CHECK(size == width * height) << "Params size mismatch. Expected "
1333 << width * height << ", got " << size;
1334 Tensor* output = nullptr;
1335 int id_in_layer = i % num_params_weights_per_layer;
1336 if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) {
1337 std::swap(height, width);
1338 }
1339 OP_REQUIRES_OK(context, context->allocate_output(
1340 i, TensorShape({height, width}), &output));
1341 DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1342 input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
1343 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1344 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1345 }
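
    // Shape sketch for the weight outputs produced above (assuming a
    // unidirectional LSTM with num_layers = 1, input_size = 5, num_units = 3,
    // and no projection):
    //
    //   outputs 0..3: TensorShape({3, 5})  // weights applied to the input
    //   outputs 4..7: TensorShape({3, 3})  // weights applied to the recurrent state
    //
    // Each output i is a straight device-to-device copy of
    // ParamsWeightRegions()[i] out of the opaque cuDNN buffer; only the
    // {height, width} metadata is synthesized here.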
1346
1347 OP_REQUIRES(
1348 context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
1349 errors::InvalidArgument("A Number of params mismatch. Expected ",
1350 num_params_biases_, ", got ",
1351 rnn_desc->ParamsBiasRegions().size()));
1352 for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
1353 int64_t size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
1354 int64_t size = size_in_bytes / sizeof(T);
1355 OP_REQUIRES(context, size == num_units,
1356 errors::InvalidArgument("Params size mismatch. Expected ",
1357 num_units, ", got ", size));
1358
1359 Tensor* output = nullptr;
1360 OP_REQUIRES_OK(context,
1361 context->allocate_output(num_params_weights_ + i,
1362 TensorShape({size}), &output));
1363 DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1364 input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
1365 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1366 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1367 }
1368 }
1369
1370 private:
1371 int num_params_;
1372 int num_params_weights_;
1373 int num_params_biases_;
1374 int num_proj_;
1375 };
1376
1377 #define REGISTER_GPU(T) \
1378 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
1379 .Device(DEVICE_GPU) \
1380 .HostMemory("num_layers") \
1381 .HostMemory("num_units") \
1382 .HostMemory("input_size") \
1383 .TypeConstraint<T>("T"), \
1384 CudnnRNNParamsToCanonical<GPUDevice, T>);
1385 TF_CALL_half(REGISTER_GPU);
1386 TF_CALL_float(REGISTER_GPU);
1387 TF_CALL_double(REGISTER_GPU);
1388 #undef REGISTER_GPU
1389
1390 #define REGISTER_GPU(T) \
1391 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
1392 .Device(DEVICE_GPU) \
1393 .HostMemory("num_layers") \
1394 .HostMemory("num_units") \
1395 .HostMemory("input_size") \
1396 .TypeConstraint<T>("T"), \
1397 CudnnRNNParamsToCanonical<GPUDevice, T>);
1398 TF_CALL_half(REGISTER_GPU);
1399 TF_CALL_float(REGISTER_GPU);
1400 TF_CALL_double(REGISTER_GPU);
1401 #undef REGISTER_GPU
1402
1403 // Convert weight and bias params from the canonical form to a
1404 // platform-specific layout.
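//
// A minimal sketch of the inverse copy performed here (illustrative only; the
// actual work is done by the RestoreParams<T>() helper invoked in Compute()
// below): every canonical weight/bias tensor is memcpy'd back into its region
// of the opaque cuDNN parameter buffer, mirroring the slicing done by
// CudnnRNNParamsToCanonical above.
//
//   for (int i = 0; i < regions.size(); ++i) {
//     DeviceMemoryBase dst =
//         SliceDeviceMemory(output_ptr, regions[i].offset, regions[i].size);
//     auto src = StreamExecutorUtil::AsDeviceMemory<T>(canonical_tensors[i]);
//     stream->ThenMemcpy(&dst, src, regions[i].size);
//   }
//
// (regions and canonical_tensors are placeholder names for
// ParamsWeightRegions()/ParamsBiasRegions() and the "weights"/"biases" input
// lists, respectively.)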
1405 template <typename T>
1406 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
1407 public:
CudnnRNNCanonicalToParams(OpKernelConstruction * context)1408 explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
1409 : CudnnRNNKernelCommon(context) {
1410 if (context->HasAttr("num_proj")) {
1411 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1412 } else {
1413 num_proj_ = 0;
1414 }
1415 }
1416
1417   void Compute(OpKernelContext* context) override {
1418 std::unique_ptr<RnnDescriptor> rnn_desc;
1419 OP_REQUIRES_OK(context,
1420 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1421 int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1422 CHECK(params_size_in_bytes % sizeof(T) == 0)
1423 << "params_size_in_bytes must be multiple of element size";
1424 Tensor* output = nullptr;
1425 int params_size = params_size_in_bytes / sizeof(T);
1426 OP_REQUIRES_OK(context,
1427 context->allocate_output(0, {params_size}, &output));
1428 auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1429 Stream* stream = context->op_device_context()->stream();
1430
1431 OpInputList weights;
1432 OP_REQUIRES_OK(context, context->input_list("weights", &weights));
1433 RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
1434 stream);
1435
1436 OpInputList biases;
1437 OP_REQUIRES_OK(context, context->input_list("biases", &biases));
1438 RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
1439 stream);
1440 }
1441
1442 private:
1443 int num_proj_;
1444 };
1445
1446 #define REGISTER_GPU(T) \
1447 REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
1448 .Device(DEVICE_GPU) \
1449 .HostMemory("num_layers") \
1450 .HostMemory("num_units") \
1451 .HostMemory("input_size") \
1452 .TypeConstraint<T>("T"), \
1453 CudnnRNNCanonicalToParams<GPUDevice, T>);
1454 TF_CALL_half(REGISTER_GPU);
1455 TF_CALL_float(REGISTER_GPU);
1456 TF_CALL_double(REGISTER_GPU);
1457 #undef REGISTER_GPU
1458
1459 #define REGISTER_GPU(T) \
1460 REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
1461 .Device(DEVICE_GPU) \
1462 .HostMemory("num_layers") \
1463 .HostMemory("num_units") \
1464 .HostMemory("input_size") \
1465 .TypeConstraint<T>("T"), \
1466 CudnnRNNCanonicalToParams<GPUDevice, T>);
1467 TF_CALL_half(REGISTER_GPU);
1468 TF_CALL_float(REGISTER_GPU);
1469 TF_CALL_double(REGISTER_GPU);
1470 #undef REGISTER_GPU
1471
1472 // Run the forward operation of the RNN model.
1473 template <typename T>
1474 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1475 public:
1476   explicit CudnnRNNForwardOp(OpKernelConstruction* context)
1477 : CudnnRNNKernelCommon(context) {
1478 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1479
1480 // Read debug env variables.
1481 is_debug_mode_ = DebugCudnnRnn();
1482 debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
1483 debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
1484 }
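
  // Debug overrides (helpers from tensorflow/core/util/use_cudnn.h). When
  // is_debug_mode_ is set, ComputeAndReturnAlgorithm() below skips autotuning
  // and forces AlgorithmDesc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_)
  // instead. These helpers are typically driven by environment variables, e.g.
  // (variable names quoted from memory -- check use_cudnn.cc for the exact
  // spelling):
  //
  //   TF_DEBUG_CUDNN_RNN=1 TF_DEBUG_CUDNN_RNN_ALGO=1 \
  //   TF_DEBUG_CUDNN_RNN_USE_TENSOR_OPS=1 python train_rnn.py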
1485
1486   void Compute(OpKernelContext* context) override {
1487 AlgorithmConfig algo_config;
1488 ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
1489 /*time_major=*/true, /*num_proj=*/0);
1490 }
1491
1492 protected:
1493   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
1494 AlgorithmConfig* output_algo_config,
1495 bool var_seq_lengths, bool time_major,
1496 int num_proj) {
1497 CHECK_NE(output_algo_config, nullptr);
1498
1499 const Tensor* input = nullptr;
1500 const Tensor* input_h = nullptr;
1501 const Tensor* input_c = nullptr;
1502 const Tensor* params = nullptr;
1503 const Tensor* sequence_lengths = nullptr;
1504 CudnnRnnModelShapes model_shapes;
1505 bool use_padded_io = false;
1506 if (var_seq_lengths) {
1507 OP_REQUIRES_OK(context, ExtractForwardInput(
1508 context, model_types(), time_major, &input,
1509 &input_h, &input_c, ¶ms,
1510 &sequence_lengths, num_proj, &model_shapes));
1511 use_padded_io =
1512 ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1513 } else {
1514 OP_REQUIRES_OK(context,
1515 ExtractForwardInput(context, model_types(), time_major,
1516 &input, &input_h, &input_c, ¶ms,
1517 num_proj, &model_shapes));
1518 }
1519 RnnInputMode input_mode;
1520 OP_REQUIRES_OK(context,
1521 ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1522 model_shapes.input_size, &input_mode));
1523
1524 Tensor* output = nullptr;
1525 Tensor* output_h = nullptr;
1526 Tensor* output_c = nullptr;
1527 OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
1528 &output_h, &output_c));
1529
1530     // Creates a memory callback for the reserve_space. The memory lives in an
1531     // output of this kernel and will be fed into the backward pass when
1532     // needed.
1533 CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
1534     // Creates a memory callback for the workspace. The memory lives until the
1535     // end of this kernel call.
1536 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1537
1538 if (is_debug_mode_) {
1539 AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
1540 output_algo_config->set_algorithm(algo_desc);
1541 } else {
1542 OP_REQUIRES_OK(context,
1543 MaybeAutotune(context, model_shapes, input_mode, input,
1544 input_h, input_c, params, output, output_h,
1545 output_c, output_algo_config));
1546 }
1547
1548 Status launch_status;
1549 {
1550 mutex_lock l(mu_);
1551 RnnDescriptor* rnn_desc_ptr = nullptr;
1552 OP_REQUIRES_OK(context,
1553 GetCachedRnnDescriptor<T>(
1554 context, model_shapes, input_mode, *output_algo_config,
1555 &rnn_state_cache_, &rnn_desc_ptr, use_padded_io));
1556 launch_status = DoForward<T>(
1557 context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1558 input_c, params, is_training_, output, output_h, output_c,
1559 sequence_lengths, time_major, &reserve_space_allocator,
1560 &workspace_allocator, /*output_profile_result=*/nullptr);
1561 }
1562 OP_REQUIRES_OK(context, launch_status);
1563 }
1564
1565 protected:
1566   virtual Status MaybeAutotune(OpKernelContext* context,
1567 const CudnnRnnModelShapes& model_shapes,
1568 const RnnInputMode& input_mode,
1569 const Tensor* input, const Tensor* input_h,
1570 const Tensor* input_c, const Tensor* params,
1571 Tensor* output, Tensor* output_h,
1572 Tensor* output_c,
1573 AlgorithmConfig* best_algo_config) {
1574 CHECK_NE(best_algo_config, nullptr);
1575 *best_algo_config = AlgorithmConfig();
1576 return Status::OK();
1577 }
1578
1579   bool is_training() const { return is_training_; }
1580 bool is_debug_mode_;
1581 bool debug_use_tensor_ops_;
1582 int64 debug_cudnn_rnn_algo_;
1583
1584 private:
1585   Status AllocateOutputs(OpKernelContext* context,
1586 const CudnnRnnModelShapes& model_shapes,
1587 Tensor** output, Tensor** output_h,
1588 Tensor** output_c) {
1589 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1590 const TensorShape& output_shape = model_shapes.output_shape;
1591 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
1592
1593 TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
1594 TF_RETURN_IF_ERROR(
1595 context->allocate_output(1, hidden_state_shape, output_h));
1596 if (HasInputC()) {
1597 TF_RETURN_IF_ERROR(
1598 context->allocate_output(2, cell_state_shape, output_c));
1599 } else {
1600 // Only LSTM uses input_c and output_c. So for all other models, we only
1601 // need to create dummy outputs.
1602 TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
1603 }
1604 if (!is_training_) {
1605 Tensor* dummy_reserve_space = nullptr;
1606 TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
1607 }
1608 return Status::OK();
1609 }
1610
1611 mutex mu_;
1612 bool is_training_;
1613 RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1614 };
1615
1616 #define REGISTER_GPU(T) \
1617 REGISTER_KERNEL_BUILDER( \
1618 Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1619 CudnnRNNForwardOp<GPUDevice, T>);
1620
1621 TF_CALL_half(REGISTER_GPU);
1622 TF_CALL_float(REGISTER_GPU);
1623 TF_CALL_double(REGISTER_GPU);
1624 #undef REGISTER_GPU
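
// For reference, each TF_CALL_<dtype> invocation above expands REGISTER_GPU
// once per dtype; for float the net effect is roughly:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<float>("T"),
//       CudnnRNNForwardOp<GPUDevice, float>);
//
// i.e. one GPU kernel registration of this op per supported element type
// (half, float, double).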
1625
1626 template <typename T>
1627 class CudnnRNNForwardOpV2<GPUDevice, T>
1628 : public CudnnRNNForwardOp<GPUDevice, T> {
1629 private:
1630 using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1631 using CudnnRNNKernelCommon::CreateRnnDescriptor;
1632 using CudnnRNNKernelCommon::dropout;
1633 using CudnnRNNKernelCommon::HasInputC;
1634 using CudnnRNNKernelCommon::model_types;
1635
1636 public:
1637   explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
1638 : CudnnRNNForwardOp<GPUDevice, T>(context) {}
1639
1640   void Compute(OpKernelContext* context) override {
1641 AlgorithmConfig best_algo_config;
1642 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1643 context, &best_algo_config, /*var_seq_lengths=*/false,
1644 /*time_major=*/true, /*num_proj=*/0);
1645 if (!context->status().ok()) {
1646 return;
1647 }
1648
1649 Tensor* output_host_reserved = nullptr;
1650     // output_host_reserved stores opaque info used for backprop when running
1651     // in training mode. At present, it includes a serialization of the best
1652     // AlgorithmDesc picked during the RNN forward-pass autotune:
1653     //   int8 algorithm_id
1654     //   int8 use_tensor_op
1655     // If autotune is not enabled, algorithm_id is
1656     // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
1657     // running in inference mode, output_host_reserved is currently not
1658     // populated.
1659 if (is_training()) {
1660 OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
1661 &output_host_reserved));
1662 auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
1663 output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
1664 output_host_reserved_int8(1) =
1665 best_algo_config.algorithm()->tensor_ops_enabled();
1666 } else {
1667 OP_REQUIRES_OK(context,
1668 context->allocate_output(4, {}, &output_host_reserved));
1669 }
1670 }
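
  // Sketch of the host_reserved handshake between CudnnRNNV2 and
  // CudnnRNNBackpropV2 (see CudnnRNNBackwardOpV2::GetAlgorithm below):
  //
  //   // Forward, training mode: output 4 is a length-2 int8 vector.
  //   host_reserved(0) = best_algo.algo_id();
  //   host_reserved(1) = best_algo.tensor_ops_enabled();
  //
  //   // Backward: rebuild the same algorithm choice.
  //   AlgorithmDesc algo_desc(host_reserved(0), host_reserved(1));
  //   algo_config->set_algorithm(algo_desc);
  //
  // This keeps the forward and backward passes on the same cuDNN algorithm
  // without re-running autotune in the backward op.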
1671
1672 protected:
1673   Status MaybeAutotune(OpKernelContext* context,
1674 const CudnnRnnModelShapes& model_shapes,
1675 const RnnInputMode& input_mode, const Tensor* input,
1676 const Tensor* input_h, const Tensor* input_c,
1677 const Tensor* params, Tensor* output, Tensor* output_h,
1678 Tensor* output_c,
1679 AlgorithmConfig* algo_config) override {
1680 CHECK_NE(algo_config, nullptr);
1681 if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
1682 *algo_config = AlgorithmConfig();
1683 return Status::OK();
1684 }
1685
1686 std::vector<AlgorithmDesc> algorithms;
1687 auto* stream = context->op_device_context()->stream();
1688 CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
1689 if (algorithms.empty()) {
1690       LOG(WARNING) << "No RNN algorithm found";
1691 return Status::OK();
1692 }
1693
1694 const auto& modeltypes = model_types();
1695 CudnnRnnParameters rnn_params(
1696 model_shapes.num_layers, model_shapes.input_size,
1697 model_shapes.num_units, model_shapes.max_seq_length,
1698 model_shapes.batch_size, model_shapes.dir_count,
1699 /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
1700 modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
1701
1702 if (AutotuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
1703 VLOG(1) << "Using existing best Cudnn RNN algorithm "
1704 << "(algo, tensor_op_enabled) = ("
1705 << algo_config->algorithm()->algo_id() << ", "
1706 << algo_config->algorithm()->tensor_ops_enabled() << ").";
1707 return Status::OK();
1708 }
1709 profiler::ScopedAnnotation trace("cudnn_autotuning");
1710
1711 // Create temp tensors when profiling backprop pass.
1712 auto data_type = input->dtype();
1713 Tensor output_backprop;
1714 Tensor output_h_backprop;
1715 Tensor output_c_backprop;
1716 Tensor input_backprop;
1717 Tensor input_h_backprop;
1718 Tensor input_c_backprop;
1719 Tensor params_backprop;
1720 if (is_training()) {
1721 TF_RETURN_IF_ERROR(context->allocate_temp(
1722 data_type, model_shapes.output_shape, &output_backprop));
1723 TF_RETURN_IF_ERROR(context->allocate_temp(
1724 data_type, model_shapes.hidden_state_shape, &output_h_backprop));
1725
1726 TF_RETURN_IF_ERROR(
1727 context->allocate_temp(data_type, params->shape(), ¶ms_backprop));
1728 TF_RETURN_IF_ERROR(context->allocate_temp(
1729 data_type, model_shapes.input_shape, &input_backprop));
1730 TF_RETURN_IF_ERROR(context->allocate_temp(
1731 data_type, model_shapes.hidden_state_shape, &input_h_backprop));
1732 if (HasInputC()) {
1733 TF_RETURN_IF_ERROR(context->allocate_temp(
1734 data_type, model_shapes.hidden_state_shape, &output_c_backprop));
1735 TF_RETURN_IF_ERROR(context->allocate_temp(
1736 data_type, model_shapes.hidden_state_shape, &input_c_backprop));
1737 }
1738 }
1739 ProfileResult best_result;
1740 for (auto& algo : algorithms) {
1741 VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1742 << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
1743 Status status;
1744 ProfileResult final_profile_result;
1745
1746 ProfileResult fwd_profile_result;
1747 ProfileResult bak_profile_result;
1748
1749 // RnnDescriptor is algorithm-dependent, thus not reusable.
1750 std::unique_ptr<RnnDescriptor> rnn_desc;
1751       // Use a temp scratch allocator for the random number generator.
1752 CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
1753 if (!this->template CreateRnnDescriptor<T>(
1754 context, model_shapes, input_mode, AlgorithmConfig(algo),
1755 &dropout_state_allocator, &rnn_desc,
1756 /*use_padded_io=*/false)
1757 .ok()) {
1758 continue;
1759 }
1760
1761       // Again, use temp scratch allocators during profiling.
1762 CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
1763 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1764 status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
1765 input, input_h, input_c, params, is_training(),
1766 output, output_h, output_c, nullptr, true,
1767 &reserve_space_allocator, &workspace_allocator,
1768 &fwd_profile_result);
1769 if (!status.ok()) {
1770 continue;
1771 }
1772
1773 if (is_training()) {
1774 // Get reserve space from the forward pass.
1775 Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
1776 status = DoBackward<T>(
1777 context, *rnn_desc, model_types(), model_shapes, input, input_h,
1778 input_c, params, output, output_h, output_c, &output_backprop,
1779 &output_h_backprop, &output_c_backprop, &reserve_space,
1780 &input_backprop, &input_h_backprop, &input_c_backprop,
1781 ¶ms_backprop, nullptr, true, &workspace_allocator,
1782 &bak_profile_result);
1783 if (!status.ok()) {
1784 continue;
1785 }
1786 final_profile_result.set_elapsed_time_in_ms(
1787 fwd_profile_result.elapsed_time_in_ms() +
1788 bak_profile_result.elapsed_time_in_ms());
1789 } else {
1790 final_profile_result = fwd_profile_result;
1791 }
1792
1793 auto total_time = final_profile_result.elapsed_time_in_ms();
1794 VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1795 << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
1796 << " run time: " << total_time << " ms.";
1797 if (total_time < best_result.elapsed_time_in_ms()) {
1798 best_result.set_elapsed_time_in_ms(total_time);
1799 best_result.set_algorithm(algo);
1800 }
1801 }
1802
1803 if (!best_result.is_valid()) {
1804 return Status(error::Code::INTERNAL, "No algorithm worked!");
1805 }
1806 algo_config->set_algorithm(best_result.algorithm());
1807 VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1808 << best_result.algorithm().algo_id() << ", "
1809 << best_result.algorithm().tensor_ops_enabled() << ").";
1810 AutotuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
1811 return Status::OK();
1812 }
1813 };
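
// The autotuning flow implemented by CudnnRNNForwardOpV2::MaybeAutotune above,
// in simplified pseudo-code (error handling and allocator setup omitted):
//
//   key = CudnnRnnParameters(num_layers, input_size, num_units, ...);
//   if (AutotuneRnnConfigMap::GetInstance()->Find(key, algo_config)) return;
//   for (const AlgorithmDesc& algo : GetRnnAlgorithms()) {
//     t_fwd = profile DoForward(algo);
//     t_bak = is_training() ? profile DoBackward(algo) : 0;
//     keep the algo with the smallest t_fwd + t_bak;
//   }
//   AutotuneRnnConfigMap::GetInstance()->Insert(key, best_config);
//
// Because AutotuneRnnConfigMap is a process-wide singleton, autotuning runs
// once per distinct RNN configuration rather than once per training step.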
1814
1815 #define REGISTER_GPU(T) \
1816 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2") \
1817 .Device(DEVICE_GPU) \
1818 .HostMemory("host_reserved") \
1819 .TypeConstraint<T>("T"), \
1820 CudnnRNNForwardOpV2<GPUDevice, T>);
1821
1822 TF_CALL_half(REGISTER_GPU);
1823 TF_CALL_float(REGISTER_GPU);
1824 TF_CALL_double(REGISTER_GPU);
1825 #undef REGISTER_GPU
1826
1827 template <typename T>
1828 class CudnnRNNForwardOpV3<GPUDevice, T>
1829 : public CudnnRNNForwardOp<GPUDevice, T> {
1830 private:
1831 using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1832 using CudnnRNNKernelCommon::CreateRnnDescriptor;
1833 using CudnnRNNKernelCommon::dropout;
1834 using CudnnRNNKernelCommon::HasInputC;
1835 using CudnnRNNKernelCommon::model_types;
1836 bool time_major_;
1837
1838 protected:
1839   bool time_major() { return time_major_; }
1840
1841 public:
1842   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
1843 : CudnnRNNForwardOp<GPUDevice, T>(context) {
1844 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
1845 if (context->HasAttr("num_proj")) {
1846 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1847 } else {
1848 num_proj_ = 0;
1849 }
1850 }
1851
1852   void Compute(OpKernelContext* context) override {
1853 AlgorithmConfig best_algo_config;
1854 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1855 context, &best_algo_config, /*var_seq_lengths=*/true,
1856 /*time_major=*/time_major(), num_proj_);
1857 if (!context->status().ok()) {
1858 return;
1859 }
1860
1861 Tensor* output_host_reserved = nullptr;
1862     // TODO: The current V3 op only uses the default (standard) algorithm to
1863     // process batches with variable sequence lengths, and the inputs should
1864     // be padded. Autotune is not supported yet.
1865 OP_REQUIRES_OK(context,
1866 context->allocate_output(4, {}, &output_host_reserved));
1867 }
1868
1869 private:
1870 int num_proj_;
1871 };
1872
1873 #define REGISTER_GPU(T) \
1874 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3") \
1875 .Device(DEVICE_GPU) \
1876 .HostMemory("sequence_lengths") \
1877 .HostMemory("host_reserved") \
1878 .TypeConstraint<T>("T"), \
1879 CudnnRNNForwardOpV3<GPUDevice, T>);
1880
1881 TF_CALL_half(REGISTER_GPU);
1882 TF_CALL_float(REGISTER_GPU);
1883 TF_CALL_double(REGISTER_GPU);
1884 #undef REGISTER_GPU
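
// Notes on CudnnRNNV3 (registered above): it takes an extra host-memory
// "sequence_lengths" input, supports batch-major data via the time_major
// attribute, and uses cuDNN's padded-IO path when ShouldUsePaddedIO() requests
// it. Unlike V2, it currently always runs the default algorithm (no
// autotuning), so its host_reserved output is left empty. Illustrative input
// layout with time_major = false:
//
//   input:            [batch_size, max_time, input_size]
//   sequence_lengths: [batch_size]   // host-resident vector of valid lengths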
1885
1886 // Run the backward operation of the RNN model.
1887 template <typename T>
1888 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1889 public:
1890   explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
1891 : CudnnRNNKernelCommon(context) {}
1892
1893   void Compute(OpKernelContext* context) override {
1894 ComputeImpl(context, false, true, 0);
1895 }
1896
1897 protected:
1898   virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
1899 bool time_major, int num_proj) {
1900 const Tensor* input = nullptr;
1901 const Tensor* input_h = nullptr;
1902 const Tensor* input_c = nullptr;
1903 const Tensor* params = nullptr;
1904 const Tensor* sequence_lengths = nullptr;
1905 CudnnRnnModelShapes model_shapes;
1906 bool use_padded_io = false;
1907 if (var_seq_lengths) {
1908 OP_REQUIRES_OK(context, ExtractForwardInput(
1909 context, model_types(), time_major, &input,
1910 &input_h, &input_c, ¶ms,
1911 &sequence_lengths, num_proj, &model_shapes));
1912 use_padded_io =
1913 ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1914 } else {
1915 OP_REQUIRES_OK(context,
1916 ExtractForwardInput(context, model_types(), time_major,
1917 &input, &input_h, &input_c, ¶ms,
1918 num_proj, &model_shapes));
1919 }
1920 RnnInputMode input_mode;
1921 OP_REQUIRES_OK(context,
1922 ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1923 model_shapes.input_size, &input_mode));
1924
1925 const Tensor* output = nullptr;
1926 const Tensor* output_h = nullptr;
1927 const Tensor* output_c = nullptr;
1928 const Tensor* output_backprop = nullptr;
1929 const Tensor* output_h_backprop = nullptr;
1930 const Tensor* output_c_backprop = nullptr;
1931 const Tensor* reserve_space = nullptr;
1932 OP_REQUIRES_OK(context,
1933 ExtractBackwardInputs(context, model_shapes, model_types(),
1934 &output, &output_h, &output_c,
1935 &output_backprop, &output_h_backprop,
1936 &output_c_backprop, &reserve_space));
1937
1938 Tensor* input_backprop = nullptr;
1939 Tensor* input_h_backprop = nullptr;
1940 Tensor* input_c_backprop = nullptr;
1941 Tensor* params_backprop = nullptr;
1942 OP_REQUIRES_OK(context,
1943 AllocateOutputs(context, model_shapes, params->shape(),
1944 &input_backprop, &input_h_backprop,
1945 &input_c_backprop, ¶ms_backprop));
1946
1947     // Creates a memory callback for the workspace. The memory lives until the
1948     // end of this kernel call.
1949 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1950 AlgorithmConfig algo_config;
1951 OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
1952 Status launch_status;
1953 {
1954 mutex_lock l(mu_);
1955 RnnDescriptor* rnn_desc_ptr = nullptr;
1956 OP_REQUIRES_OK(
1957 context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
1958 algo_config, &rnn_state_cache_,
1959 &rnn_desc_ptr, use_padded_io));
1960 launch_status = DoBackward<T>(
1961 context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1962 input_c, params, output, output_h, output_c, output_backprop,
1963 output_h_backprop, output_c_backprop, reserve_space, input_backprop,
1964 input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
1965 time_major, &workspace_allocator,
1966 /*output_profile_result=*/nullptr);
1967 }
1968 OP_REQUIRES_OK(context, launch_status);
1969 }
1970
1971 protected:
1972   virtual Status GetAlgorithm(OpKernelContext* context,
1973 AlgorithmConfig* algo_config) {
1974 CHECK_NE(algo_config, nullptr);
1975 *algo_config = AlgorithmConfig();
1976 return Status::OK();
1977 }
1978
1979 private:
1980 mutex mu_;
1981 RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1982
1983   Status ExtractBackwardInputs(
1984 OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
1985 const CudnnModelTypes& model_types, const Tensor** output,
1986 const Tensor** output_h, const Tensor** output_c,
1987 const Tensor** output_backprop, const Tensor** output_h_backprop,
1988 const Tensor** output_c_backprop, const Tensor** reserve_space) {
1989 TF_RETURN_IF_ERROR(context->input("output", output));
1990 TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
1991 TF_RETURN_IF_ERROR(context->input("output_h", output_h));
1992 TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
1993 if (model_types.HasInputC()) {
1994 TF_RETURN_IF_ERROR(context->input("output_c", output_c));
1995 TF_RETURN_IF_ERROR(
1996 context->input("output_c_backprop", output_c_backprop));
1997 }
1998 TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
1999 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2000 const TensorShape& output_shape = model_shapes.output_shape;
2001 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2002
2003 if (output_shape != (*output)->shape()) {
2004 return errors::InvalidArgument(
2005 "Invalid output shape: ", (*output)->shape().DebugString(), " ",
2006 output_shape.DebugString());
2007 }
2008 if (hidden_state_shape != (*output_h)->shape()) {
2009 return errors::InvalidArgument(
2010 "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
2011 hidden_state_shape.DebugString());
2012 }
2013
2014 if (output_shape != (*output_backprop)->shape()) {
2015 return errors::InvalidArgument("Invalid output_backprop shape: ",
2016 (*output_backprop)->shape().DebugString(),
2017 " ", output_shape.DebugString());
2018 }
2019 if (hidden_state_shape != (*output_h_backprop)->shape()) {
2020 return errors::InvalidArgument(
2021 "Invalid output_h_backprop shape: ",
2022 (*output_h_backprop)->shape().DebugString(), " ",
2023 hidden_state_shape.DebugString());
2024 }
2025
2026 if (model_types.HasInputC()) {
2027 if (cell_state_shape != (*output_c)->shape()) {
2028 return errors::InvalidArgument(
2029 "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
2030 cell_state_shape.DebugString());
2031 }
2032 if (cell_state_shape != (*output_c_backprop)->shape()) {
2033 return errors::InvalidArgument(
2034 "Invalid output_c_backprop shape: ",
2035 (*output_c_backprop)->shape().DebugString(), " ",
2036 cell_state_shape.DebugString());
2037 }
2038 }
2039 return Status::OK();
2040 }
2041
2042   Status AllocateOutputs(OpKernelContext* context,
2043 const CudnnRnnModelShapes& model_shapes,
2044 const TensorShape& params_shape,
2045 Tensor** input_backprop, Tensor** input_h_backprop,
2046 Tensor** input_c_backprop, Tensor** params_backprop) {
2047 const TensorShape& input_shape = model_shapes.input_shape;
2048 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2049 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2050
2051 TF_RETURN_IF_ERROR(
2052 context->allocate_output(0, input_shape, input_backprop));
2053 TF_RETURN_IF_ERROR(
2054 context->allocate_output(1, hidden_state_shape, input_h_backprop));
2055 if (HasInputC()) {
2056 TF_RETURN_IF_ERROR(
2057 context->allocate_output(2, cell_state_shape, input_c_backprop));
2058 } else {
2059 // Only LSTM uses input_c and output_c. So for all other models, we only
2060 // need to create dummy outputs.
2061 TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
2062 }
2063 TF_RETURN_IF_ERROR(
2064 context->allocate_output(3, params_shape, params_backprop));
2065 return Status::OK();
2066 }
2067 };
2068
2069 #define REGISTER_GPU(T) \
2070 REGISTER_KERNEL_BUILDER( \
2071 Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
2072 CudnnRNNBackwardOp<GPUDevice, T>);
2073
2074 TF_CALL_half(REGISTER_GPU);
2075 TF_CALL_float(REGISTER_GPU);
2076 TF_CALL_double(REGISTER_GPU);
2077 #undef REGISTER_GPU
2078
2079 template <typename T>
2080 class CudnnRNNBackwardOpV2<GPUDevice, T>
2081 : public CudnnRNNBackwardOp<GPUDevice, T> {
2082 public:
2083   explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
2084 : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
2085
2086 protected:
2087   Status GetAlgorithm(OpKernelContext* context,
2088 AlgorithmConfig* algo_config) override {
2089 CHECK_NE(algo_config, nullptr);
2090 const Tensor* host_reserved = nullptr;
2091 TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
2092
2093 auto host_reserved_int8 = host_reserved->vec<int8>();
2094 const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
2095 algo_config->set_algorithm(algo_desc);
2096 return Status::OK();
2097 }
2098 };
2099
2100 #define REGISTER_GPU(T) \
2101 REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2") \
2102 .Device(DEVICE_GPU) \
2103 .HostMemory("host_reserved") \
2104 .TypeConstraint<T>("T"), \
2105 CudnnRNNBackwardOpV2<GPUDevice, T>);
2106
2107 TF_CALL_half(REGISTER_GPU);
2108 TF_CALL_float(REGISTER_GPU);
2109 TF_CALL_double(REGISTER_GPU);
2110 #undef REGISTER_GPU
2111
2112 template <typename T>
2113 class CudnnRNNBackwardOpV3<GPUDevice, T>
2114 : public CudnnRNNBackwardOp<GPUDevice, T> {
2115 private:
2116 bool time_major_;
2117
2118 protected:
2119   bool time_major() { return time_major_; }
2120
2121 public:
2122   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
2123 : CudnnRNNBackwardOp<GPUDevice, T>(context) {
2124 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
2125 if (context->HasAttr("num_proj")) {
2126 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
2127 } else {
2128 num_proj_ = 0;
2129 }
2130 }
2131
2132   void Compute(OpKernelContext* context) override {
2133 CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(),
2134 num_proj_);
2135 }
2136
2137 private:
2138 int num_proj_;
2139 };
2140
2141 #define REGISTER_GPU(T) \
2142 REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3") \
2143 .Device(DEVICE_GPU) \
2144 .HostMemory("sequence_lengths") \
2145 .HostMemory("host_reserved") \
2146 .TypeConstraint<T>("T"), \
2147 CudnnRNNBackwardOpV3<GPUDevice, T>);
2148
2149 TF_CALL_half(REGISTER_GPU);
2150 TF_CALL_float(REGISTER_GPU);
2151 TF_CALL_double(REGISTER_GPU);
2152 #undef REGISTER_GPU
2153
2154 // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
2155 // its canonical form.
2156
2157 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2158
2159 } // namespace tensorflow
2160