1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16
17 #include <stddef.h>
18 #include <atomic>
19 #include <cmath>
20 #include <functional>
21 #include <limits>
22 #include <string>
23 #include <unordered_set>
24
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/device_base.h"
27 #include "tensorflow/core/framework/kernel_def_builder.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def_builder.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/register_types.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/tensor_types.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/kernels/gpu_utils.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/core/status.h"
39 #include "tensorflow/core/lib/core/stringpiece.h"
40 #include "tensorflow/core/lib/gtl/inlined_vector.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/lib/strings/stringprintf.h"
43 #include "tensorflow/core/platform/fingerprint.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/types.h"
46 #include "tensorflow/core/util/env_var.h"
47 #include "tensorflow/core/util/use_cudnn.h"
48
49 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
50 #include "tensorflow/core/platform/stream_executor.h"
51 #include "tensorflow/core/util/stream_executor_util.h"
52 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
53
54 /*
55 * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
56 * using the underlying Cudnn library.
57 *
58  * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
59  * layout and format, and a buffer saved on one GPU is unlikely to be usable
60  * on a different GPU. Users therefore need to first query the size of the
61  * opaque parameter buffer and convert it to and from its canonical form,
62  * while each actual training step is carried out with the opaque buffer.
63  *
64  * Similar to many other ops, the forward op has two flavors: training and
65  * inference. When training is specified, additional data in reserve_space is
66  * produced for the backward pass, so there is a performance penalty.
67  *
68  * In addition to the actual data and reserve_space, Cudnn also needs extra
69  * memory as a temporary workspace. The memory management to and from
70  * stream-executor is done through ScratchAllocator. In general,
71  * stream-executor is responsible for creating memory of the proper size, and
72  * TensorFlow is responsible for making sure the memory stays alive long
73  * enough and is recycled afterwards.
74 *
75 */
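//
// Illustrative sketch only (not compiled): assuming the helper functions and
// classes defined later in this file, a forward kernel wires them together
// roughly as follows. The variable names here are hypothetical.
//
//   CudnnRnnModelShapes model_shapes;
//   TF_RETURN_IF_ERROR(ExtractForwardInput(context, model_types, time_major,
//                                          &input, &input_h, &input_c,
//                                          &params, num_proj, &model_shapes));
//   CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
//   TF_RETURN_IF_ERROR(DoForward<T>(
//       context, *rnn_desc, model_types, model_shapes, input, input_h,
//       input_c, params, is_training, output, output_h, output_c,
//       sequence_lengths, time_major, reserve_space_allocator,
//       &workspace_allocator, output_profile_result));
//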
76 namespace tensorflow {
77
78 using CPUDevice = Eigen::ThreadPoolDevice;
79
80 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
81
82 using GPUDevice = Eigen::GpuDevice;
83 using se::Stream;
84 using se::StreamExecutor;
85 using se::dnn::RnnDescriptor;
86
87 template <typename Device, typename T, typename Index>
88 class CudnnRNNParamsSizeOp;
89
90 template <typename Device, typename T>
91 class CudnnRNNParamsToCanonical;
92
93 template <typename Device, typename T>
94 class CudnnRNNCanonicalToParams;
95
96 template <typename Device, typename T>
97 class CudnnRNNForwardOp;
98
99 template <typename Device, typename T>
100 class CudnnRNNBackwardOp;
101
102 template <typename Device, typename T>
103 class CudnnRNNForwardOpV2;
104
105 template <typename Device, typename T>
106 class CudnnRNNBackwardOpV2;
107
108 template <typename Device, typename T>
109 class CudnnRNNForwardOpV3;
110
111 template <typename Device, typename T>
112 class CudnnRNNBackwardOpV3;
113
114 enum class TFRNNInputMode {
115 kRNNLinearInput = 0,
116 kRNNSkipInput = 1,
117 kAutoSelect = 9999999
118 };
119
120 namespace {
121 using se::DeviceMemory;
122 using se::DeviceMemoryBase;
123 using se::ScratchAllocator;
124 using se::dnn::AlgorithmConfig;
125 using se::dnn::AlgorithmDesc;
126 using se::dnn::ProfileResult;
127 using se::dnn::RnnDirectionMode;
128 using se::dnn::RnnInputMode;
129 using se::dnn::RnnMode;
130 using se::dnn::RnnSequenceTensorDescriptor;
131 using se::dnn::RnnStateTensorDescriptor;
132 using se::dnn::ToDataType;
133 using se::port::StatusOr;
134
135 uint64 HashList(const std::vector<int>& list) {
136 if (list.empty()) {
137 return 0;
138 }
139 uint64 hash_code = list[0];
140 for (int i = 1; i < list.size(); i++) {
141 hash_code = Hash64Combine(hash_code, list[i]);
142 }
143 return hash_code;
144 }
145
146 // Encapsulates all the shape information that is used in both the forward
147 // and backward RNN operations.
148 class CudnnRnnParameters {
149 public:
150 CudnnRnnParameters(int num_layers, int input_size, int num_units,
151 int max_seq_length, int batch_size, int dir_count,
152 bool has_dropout, bool is_training, RnnMode rnn_mode,
153 TFRNNInputMode rnn_input_mode, DataType dtype)
154 : num_layers_(num_layers),
155 input_size_(input_size),
156 num_units_(num_units),
157 seq_length_(max_seq_length),
158 batch_size_(batch_size),
159 dir_count_(dir_count),
160 has_dropout_(has_dropout),
161 is_training_(is_training),
162 rnn_mode_(rnn_mode),
163 rnn_input_mode_(rnn_input_mode),
164 dtype_(dtype) {
165 hash_code_ =
166 HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
167 dir_count, static_cast<int>(has_dropout),
168 static_cast<int>(is_training), static_cast<int>(rnn_mode),
169 static_cast<int>(rnn_input_mode), dtype});
170 }
171
172 bool operator==(const CudnnRnnParameters& other) const {
173 return this->get_data_as_tuple() == other.get_data_as_tuple();
174 }
175
176 bool operator!=(const CudnnRnnParameters& other) const {
177 return !(*this == other);
178 }
179 uint64 hash() const { return hash_code_; }
180
181 string ToString() const {
182 std::vector<string> fields = {
183 std::to_string(num_layers_),
184 std::to_string(input_size_),
185 std::to_string(num_units_),
186 std::to_string(seq_length_),
187 std::to_string(batch_size_),
188 std::to_string(dir_count_),
189 std::to_string(has_dropout_),
190 std::to_string(is_training_),
191 std::to_string(static_cast<int>(rnn_mode_)),
192 std::to_string(static_cast<int>(rnn_input_mode_)),
193 std::to_string(static_cast<int>(dtype_))};
194 return absl::StrJoin(fields, ", ");
195 }
196
197 private:
198 using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
199 RnnMode, TFRNNInputMode, DataType>;
200
201 ParameterDataType get_data_as_tuple() const {
202 return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
203 batch_size_, dir_count_, has_dropout_, is_training_,
204 rnn_mode_, rnn_input_mode_, dtype_);
205 }
206
207 const int num_layers_;
208 const int input_size_;
209 const int num_units_;
210 const int seq_length_;
211 const int batch_size_;
212 const int dir_count_;
213 const bool has_dropout_;
214 const bool is_training_;
215 const RnnMode rnn_mode_;
216 const TFRNNInputMode rnn_input_mode_;
217 const DataType dtype_;
218 uint64 hash_code_;
219 };
220
221 struct RnnAutoTuneGroup {
222 static string name() { return "Rnn"; }
223 };
224
225 using AutoTuneRnnConfigMap =
226 AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>;
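// Illustrative sketch only (not compiled): the usual lookup/insert pattern on
// the autotune map above, assuming the AutoTuneSingleton interface from
// gpu_utils.h; the surrounding profiling loop is omitted and the variable
// names are hypothetical.
//
//   AlgorithmConfig algo_config;
//   if (!AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, &algo_config)) {
//     // Profile candidate algorithms, pick the best one ...
//     AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, best_config);
//   }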
227
228 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
229 if (str == "rnn_relu") {
230 *rnn_mode = RnnMode::kRnnRelu;
231 return Status::OK();
232 } else if (str == "rnn_tanh") {
233 *rnn_mode = RnnMode::kRnnTanh;
234 return Status::OK();
235 } else if (str == "lstm") {
236 *rnn_mode = RnnMode::kRnnLstm;
237 return Status::OK();
238 } else if (str == "gru") {
239 *rnn_mode = RnnMode::kRnnGru;
240 return Status::OK();
241 }
242 return errors::InvalidArgument("Invalid RNN mode: ", str);
243 }
244
245 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
246 if (str == "linear_input") {
247 *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
248 return Status::OK();
249 } else if (str == "skip_input") {
250 *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
251 return Status::OK();
252 } else if (str == "auto_select") {
253 *rnn_input_mode = TFRNNInputMode::kAutoSelect;
254 return Status::OK();
255 }
256 return errors::InvalidArgument("Invalid RNN input mode: ", str);
257 }
258
259 Status ParseRNNDirectionMode(const string& str,
260 RnnDirectionMode* rnn_dir_mode) {
261 if (str == "unidirectional") {
262 *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
263 return Status::OK();
264 } else if (str == "bidirectional") {
265 *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
266 return Status::OK();
267 }
268 return errors::InvalidArgument("Invalid RNN direction mode: ", str);
269 }
270
271 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
272 int input_size, RnnInputMode* input_mode) {
273 switch (tf_input_mode) {
274 case TFRNNInputMode::kRNNLinearInput:
275 *input_mode = RnnInputMode::kRnnLinearSkip;
276 break;
277 case TFRNNInputMode::kRNNSkipInput:
278 *input_mode = RnnInputMode::kRnnSkipInput;
279 break;
280 case TFRNNInputMode::kAutoSelect:
281 *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
282 : RnnInputMode::kRnnLinearSkip;
283 break;
284 default:
285 return errors::InvalidArgument("Invalid TF input mode: ",
286 static_cast<int>(tf_input_mode));
287 }
288 return Status::OK();
289 }
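// For example (hypothetical sizes): with kAutoSelect, num_units == 128 and
// input_size == 128, the skip-input path is chosen; any other combination
// falls back to the linear-input path.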
290
291 // TODO(zhengxq): Merge those into stream_executor_util.h.
292 template <typename T>
293 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
294 return DeviceMemory<T>::MakeFromByteSize(
295 const_cast<T*>(tensor->template flat<T>().data()),
296 tensor->template flat<T>().size() * sizeof(T));
297 }
298
299 template <typename T>
300 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
301 return DeviceMemory<T>::MakeFromByteSize(
302 tensor->template flat<T>().data(),
303 tensor->template flat<T>().size() * sizeof(T));
304 }
305
306 template <typename U, typename T>
307 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
308 return DeviceMemory<U>::MakeFromByteSize(
309 tensor->template flat<T>().data(),
310 tensor->template flat<T>().size() * sizeof(T));
311 }
312
313 DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
314 int64 offset, int64 size) {
315 const void* base_ptr = device_memory.opaque();
316 void* offset_ptr =
317 const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
318 CHECK(offset + size <= device_memory.size())
319 << "The slice is not within the region of DeviceMemory.";
320 return DeviceMemoryBase(offset_ptr, size);
321 }
322
323 inline Status FromExecutorStatus(const se::port::Status& s) {
324 return s.ok() ? Status::OK()
325 : Status(static_cast<error::Code>(static_cast<int>(s.code())),
326 s.error_message());
327 }
328
329 template <typename T>
330 inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
331 return FromExecutorStatus(s.status());
332 }
333
334 inline se::port::Status ToExecutorStatus(const Status& s) {
335 return s.ok() ? se::port::Status::OK()
336 : se::port::Status(static_cast<se::port::error::Code>(
337 static_cast<int>(s.code())),
338 s.error_message());
339 }
340
341 template <typename>
342 struct ToTFDataType;
343
344 template <>
345 struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
346
347 template <>
348 struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
349
350 template <>
351 struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
352
353 template <>
354 struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
355
356 // A helper to allocate temporary scratch memory for Cudnn RNN models. It
357 // takes ownership of the underlying memory. The expectation is that the
358 // memory stays alive for the span of the Cudnn RNN call itself.
359 template <typename T>
360 class CudnnRnnAllocatorInTemp : public ScratchAllocator {
361 public:
362 ~CudnnRnnAllocatorInTemp() override = default;
363
364 explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
365 : context_(context) {}
366 int64 GetMemoryLimitInBytes() override {
367 return std::numeric_limits<int64>::max();
368 }
369
370 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
371 Tensor temporary_memory;
372 const DataType tf_data_type = ToTFDataType<T>::value;
373 int64 allocate_count =
374 Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
375 Status allocation_status(context_->allocate_temp(
376 tf_data_type, TensorShape({allocate_count}), &temporary_memory));
377 if (!allocation_status.ok()) {
378 return ToExecutorStatus(allocation_status);
379 }
380 // Hold a reference to the allocated tensors until the end of the
381 // allocator's lifetime.
382 allocated_tensors_.push_back(temporary_memory);
383 total_byte_size_ += byte_size;
384 return DeviceMemory<uint8>::MakeFromByteSize(
385 temporary_memory.template flat<T>().data(),
386 temporary_memory.template flat<T>().size() * sizeof(T));
387 }
388
389 int64 TotalByteSize() const { return total_byte_size_; }
390
391 Tensor get_allocated_tensor(int index) const {
392 return allocated_tensors_[index];
393 }
394
395 private:
396 int64 total_byte_size_ = 0;
397 OpKernelContext* context_; // not owned
398 std::vector<Tensor> allocated_tensors_;
399 };
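// Illustrative sketch only (not compiled): how a caller might use this
// allocator together with the DoForward helper defined below. The variable
// names are hypothetical.
//
//   CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
//   CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
//   TF_RETURN_IF_ERROR(DoForward<T>(..., &reserve_space_allocator,
//                                   &workspace_allocator, nullptr));
//   Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);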
400
401 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
402 // used by the forward-pass kernel to feed the output to the backward pass.
403 // The memory is expected to stay alive until the backward pass has
404 // finished.
405 template <typename T>
406 class CudnnRnnAllocatorInOutput : public ScratchAllocator {
407 public:
408 ~CudnnRnnAllocatorInOutput() override {}
409 CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
410 : context_(context), output_index_(output_index) {}
411 int64 GetMemoryLimitInBytes() override {
412 return std::numeric_limits<int64>::max();
413 }
414 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
415 CHECK(total_byte_size_ == 0)
416 << "Reserve space allocator can only be called once";
417 int64 allocate_count =
418 Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
419
420 Tensor* temporary_memory = nullptr;
421 Status allocation_status(context_->allocate_output(
422 output_index_, TensorShape({allocate_count}), &temporary_memory));
423 if (!allocation_status.ok()) {
424 return ToExecutorStatus(allocation_status);
425 }
426 total_byte_size_ += byte_size;
427 auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
428 temporary_memory->template flat<T>().data(),
429 temporary_memory->template flat<T>().size() * sizeof(T));
430 return StatusOr<DeviceMemory<uint8>>(memory_uint8);
431 }
432 int64 TotalByteSize() { return total_byte_size_; }
433
434 private:
435 int64 total_byte_size_ = 0;
436 OpKernelContext* context_; // not owned
437 int output_index_;
438 };
439
440 // A helper to allocate persistent memory for Cudnn RNN models, which is
441 // expected to live between kernel invocations.
442 // This class is not thread-safe.
443 class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
444 public:
445 explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context)
446 : context_(context) {}
447
448 ~CudnnRNNPersistentSpaceAllocator() override {}
449
450 int64 GetMemoryLimitInBytes() override {
451 return std::numeric_limits<int64>::max();
452 }
453
454 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
455 if (total_byte_size_ != 0) {
456 return Status(error::FAILED_PRECONDITION,
457 "Persistent space allocator can only be called once");
458 }
459
460 Status allocation_status = context_->allocate_persistent(
461 DT_UINT8, TensorShape({byte_size}), &handle_, nullptr);
462 if (!allocation_status.ok()) {
463 return ToExecutorStatus(allocation_status);
464 }
465 total_byte_size_ += byte_size;
466 return AsDeviceMemory<uint8>(handle_.AccessTensor(context_));
467 }
468 int64 TotalByteSize() { return total_byte_size_; }
469
470 private:
471 int64 total_byte_size_ = 0;
472 PersistentTensor handle_;
473 OpKernelContext* context_; // not owned
474 };
475
476 struct CudnnModelTypes {
477 RnnMode rnn_mode;
478 TFRNNInputMode rnn_input_mode;
479 RnnDirectionMode rnn_direction_mode;
480 bool HasInputC() const {
481 // For Cudnn 5.0, only LSTM has input-c. All other models use only
482 // input-h.
483 return rnn_mode == RnnMode::kRnnLstm;
484 }
485
486 string DebugString() const {
487 return strings::Printf(
488 "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
489 static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
490 static_cast<int>(rnn_direction_mode));
491 }
492 };
493
494 // A helper class that collects the shapes that describe an RNN model.
495 struct CudnnRnnModelShapes {
496 int num_layers;
497 int input_size;
498 int num_units;
499 int dir_count;
500 int max_seq_length;
501 int batch_size;
502 int cell_num_units = 0;
503 // If you add a new field to this structure, please take care of
504 // updating IsCompatibleWith() below as well as the hash function in
505 // CudnnRnnConfigHasher.
506 TensorShape input_shape;
507 TensorShape output_shape;
508 TensorShape hidden_state_shape;
509 TensorShape cell_state_shape;
510 // At present, only the fields relevant to the cached RnnDescriptor are compared.
511 bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
512 return num_layers == rhs.num_layers && input_size == rhs.input_size &&
513 num_units == rhs.num_units && dir_count == rhs.dir_count &&
514 cell_num_units == rhs.cell_num_units &&
515 max_seq_length == rhs.max_seq_length;
516 }
517 string DebugString() const {
518 return strings::Printf(
519 "[num_layers, input_size, num_units, dir_count, max_seq_length, "
520 "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
521 num_layers, input_size, num_units, dir_count, max_seq_length,
522 batch_size, cell_num_units);
523 }
524 };
525
526 // Utility class for using a CudnnRnnModelShapes and AlgorithmDesc pair as a
527 // hash table key.
528 struct CudnnRnnConfigHasher {
529 uint64 operator()(
530 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
531 to_hash) const {
532 auto& shapes = to_hash.first;
533 auto& algo_desc = to_hash.second;
534
535 uint64 hash =
536 HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
537 shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
538 if (algo_desc.has_value()) {
539 hash = Hash64Combine(hash, algo_desc->hash());
540 }
541 return hash;
542 }
543 };
544
545 // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash
546 // table key.
547 struct CudnnRnnConfigComparator {
548 bool operator()(
549 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs,
550 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs)
551 const {
552 return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
553 }
554 };
555
556 // Pointers to RNN scratch space for a specific set of shape parameters (used as
557 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
558 struct RnnScratchSpace {
559 std::unique_ptr<RnnDescriptor> rnn_desc;
560 std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
561 };
562
563 // Extracts and checks the forward input tensors, parameters, and shapes
564 // from the OpKernelContext.
565 Status ExtractForwardInput(OpKernelContext* context,
566 const CudnnModelTypes& model_types, bool time_major,
567 const Tensor** input, const Tensor** input_h,
568 const Tensor** input_c, const Tensor** params,
569 const int num_proj,
570 CudnnRnnModelShapes* model_shapes) {
571 TF_RETURN_IF_ERROR(context->input("input", input));
572 TF_RETURN_IF_ERROR(context->input("input_h", input_h));
573 if (model_types.HasInputC()) {
574 TF_RETURN_IF_ERROR(context->input("input_c", input_c));
575 }
576 TF_RETURN_IF_ERROR(context->input("params", params));
577
578 if ((*input)->dims() != 3) {
579 return errors::InvalidArgument("RNN input must be a 3-D vector.");
580 }
581 if (time_major) {
582 model_shapes->max_seq_length = (*input)->dim_size(0);
583 model_shapes->batch_size = (*input)->dim_size(1);
584 } else {
585 model_shapes->max_seq_length = (*input)->dim_size(1);
586 model_shapes->batch_size = (*input)->dim_size(0);
587 }
588 model_shapes->input_size = (*input)->dim_size(2);
589 model_shapes->input_shape = (*input)->shape();
590 model_shapes->dir_count =
591 (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
592 ? 2
593 : 1;
594
595 if ((*input_h)->dims() != 3) {
596 return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
597 }
598 if (time_major) {
599 model_shapes->num_layers =
600 (*input_h)->dim_size(0) / model_shapes->dir_count;
601 } else {
602 model_shapes->num_layers =
603 (*input_h)->dim_size(1) / model_shapes->dir_count;
604 }
605 model_shapes->num_units = (*input_h)->dim_size(2);
606
607 if (time_major) {
608 model_shapes->hidden_state_shape =
609 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
610 model_shapes->batch_size, model_shapes->num_units});
611 } else {
612 model_shapes->hidden_state_shape =
613 TensorShape({model_shapes->batch_size,
614 model_shapes->dir_count * model_shapes->num_layers,
615 model_shapes->num_units});
616 }
617 if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
618 return errors::InvalidArgument(
619 "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
620 model_shapes->hidden_state_shape.DebugString());
621 }
622 if (model_types.HasInputC()) {
623 model_shapes->cell_num_units = (*input_c)->dim_size(2);
624 if (time_major) {
625 model_shapes->cell_state_shape =
626 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
627 model_shapes->batch_size, model_shapes->cell_num_units});
628 } else {
629 model_shapes->cell_state_shape =
630 TensorShape({model_shapes->batch_size,
631 model_shapes->dir_count * model_shapes->num_layers,
632 model_shapes->cell_num_units});
633 }
634 if (num_proj == 0) {
635 if ((*input_h)->shape() != (*input_c)->shape()) {
636 return errors::InvalidArgument(
637 "input_h and input_c must have the same shape w/o projection: ",
638 (*input_h)->shape().DebugString(), " ",
639 (*input_c)->shape().DebugString());
640 }
641 } else {
642 if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
643 num_proj != (*input_h)->dim_size(2) ||
644 (*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
645 (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
646 return errors::InvalidArgument(
647 "Invalid input_h and input_c w/ projection size: ", num_proj, " ",
648 (*input_h)->shape().DebugString(), " ",
649 (*input_c)->shape().DebugString());
650 }
651 }
652 } else {
653 // dummy cell_state_shape TODO(kaixih): remove the time_major branch
654 if (time_major) {
655 model_shapes->cell_state_shape =
656 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
657 model_shapes->batch_size, model_shapes->num_units});
658 } else {
659 model_shapes->cell_state_shape =
660 TensorShape({model_shapes->batch_size,
661 model_shapes->dir_count * model_shapes->num_layers,
662 model_shapes->num_units});
663 }
664 model_shapes->cell_num_units = 0;
665 }
666 if (time_major) {
667 model_shapes->output_shape =
668 TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
669 model_shapes->dir_count * model_shapes->num_units});
670 } else {
671 model_shapes->output_shape =
672 TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
673 model_shapes->dir_count * model_shapes->num_units});
674 }
675 return Status::OK();
676 }
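// Worked example of the shape bookkeeping above (hypothetical sizes): a
// time-major, unidirectional LSTM with max_seq_length=8, batch_size=4,
// input_size=16, num_units=32 and num_layers=2 produces
//   input_shape        = [8, 4, 16]
//   hidden_state_shape = [2, 4, 32]
//   cell_state_shape   = [2, 4, 32]
//   output_shape       = [8, 4, 32]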
677
678 // Overloaded version that additionally extracts the sequence_lengths input.
679 Status ExtractForwardInput(OpKernelContext* context,
680 const CudnnModelTypes& model_types, bool time_major,
681 const Tensor** input, const Tensor** input_h,
682 const Tensor** input_c, const Tensor** params,
683 const Tensor** sequence_lengths, const int num_proj,
684 CudnnRnnModelShapes* model_shapes) {
685 TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
686 return ExtractForwardInput(context, model_types, time_major, input, input_h,
687 input_c, params, num_proj, model_shapes);
688 }
689
690 template <typename T>
691 Status CreateForwardAndBackwardIODescriptors(
692 OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
693 std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
694 std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
695 std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
696 std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
697 const absl::Span<const int>& seq_lengths, bool time_major) {
698 StreamExecutor* executor = context->op_device_context()->stream()->parent();
699 se::dnn::DataType data_type = ToDataType<T>::value;
700
701 const TensorShape& input_shape = model_shapes.input_shape;
702 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
703 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
704 const TensorShape& output_shape = model_shapes.output_shape;
705
706 DCHECK_EQ(input_shape.dims(), 3);
707 if (seq_lengths.data() != nullptr) {
708 if (time_major) {
709 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
710 input_shape.dim_size(0), input_shape.dim_size(1),
711 input_shape.dim_size(2), seq_lengths, time_major, data_type);
712 TF_RETURN_IF_ERROR(input_desc_s.status());
713 *input_desc = input_desc_s.ConsumeValueOrDie();
714 } else {
715 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
716 input_shape.dim_size(1), input_shape.dim_size(0),
717 input_shape.dim_size(2), seq_lengths, time_major, data_type);
718 TF_RETURN_IF_ERROR(input_desc_s.status());
719 *input_desc = input_desc_s.ConsumeValueOrDie();
720 }
721 } else {
722 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
723 input_shape.dim_size(0), input_shape.dim_size(1),
724 input_shape.dim_size(2), data_type);
725 TF_RETURN_IF_ERROR(input_desc_s.status());
726 *input_desc = input_desc_s.ConsumeValueOrDie();
727 }
728
729 DCHECK_EQ(hidden_state_shape.dims(), 3);
730 if (time_major) {
731 auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
732 hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
733 hidden_state_shape.dim_size(2), data_type);
734 TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
735 *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
736 } else {
737 auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
738 hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
739 hidden_state_shape.dim_size(2), data_type);
740 TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
741 *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
742 }
743
744 DCHECK_EQ(cell_state_shape.dims(), 3);
745 if (time_major) {
746 auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
747 cell_state_shape.dim_size(0), cell_state_shape.dim_size(1),
748 cell_state_shape.dim_size(2), data_type);
749 TF_RETURN_IF_ERROR(cell_state_desc_s.status());
750 *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
751 } else {
752 auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
753 cell_state_shape.dim_size(1), cell_state_shape.dim_size(0),
754 cell_state_shape.dim_size(2), data_type);
755 TF_RETURN_IF_ERROR(cell_state_desc_s.status());
756 *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
757 }
758
759 DCHECK_EQ(output_shape.dims(), 3);
760 if (seq_lengths.data() != nullptr) {
761 if (time_major) {
762 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
763 output_shape.dim_size(0), output_shape.dim_size(1),
764 output_shape.dim_size(2), seq_lengths, time_major, data_type);
765 TF_RETURN_IF_ERROR(output_desc_s.status());
766 *output_desc = output_desc_s.ConsumeValueOrDie();
767 } else {
768 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
769 output_shape.dim_size(1), output_shape.dim_size(0),
770 output_shape.dim_size(2), seq_lengths, time_major, data_type);
771 TF_RETURN_IF_ERROR(output_desc_s.status());
772 *output_desc = output_desc_s.ConsumeValueOrDie();
773 }
774 } else {
775 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
776 output_shape.dim_size(0), output_shape.dim_size(1),
777 output_shape.dim_size(2), data_type);
778 TF_RETURN_IF_ERROR(output_desc_s.status());
779 *output_desc = output_desc_s.ConsumeValueOrDie();
780 }
781
782 return Status::OK();
783 }
784
785 template <typename T>
786 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
787 const CudnnModelTypes& model_types,
788 const CudnnRnnModelShapes& model_shapes,
789 /* forward inputs */
790 const Tensor* input, const Tensor* input_h,
791 const Tensor* input_c, const Tensor* params,
792 const bool is_training,
793 /* forward outputs, outputs of the function */
794 Tensor* output, Tensor* output_h, Tensor* output_c,
795 const Tensor* sequence_lengths, bool time_major,
796 ScratchAllocator* reserve_space_allocator,
797 ScratchAllocator* workspace_allocator,
798 ProfileResult* output_profile_result) {
799 std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
800 std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
801 std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
802 std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
803
804 absl::Span<const int> seq_lengths;
805 if (sequence_lengths != nullptr) {
806 seq_lengths = absl::Span<const int>(
807 sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
808 }
809 TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
810 context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
811 &output_desc, seq_lengths, time_major));
812
813 auto input_data = AsDeviceMemory<T>(input);
814 auto input_h_data = AsDeviceMemory<T>(input_h);
815 DeviceMemory<T> input_c_data;
816 if (model_types.HasInputC()) {
817 input_c_data = AsDeviceMemory<T>(input_c);
818 }
819
820 auto params_data = AsDeviceMemory<T>(params);
821 auto output_data = AsDeviceMemory<T>(output);
822 auto output_h_data = AsDeviceMemory<T>(output_h);
823 DeviceMemory<T> output_c_data;
824 if (model_types.HasInputC()) {
825 output_c_data = AsDeviceMemory<T>(output_c);
826 }
827
828 Stream* stream = context->op_device_context()->stream();
829 bool launch_success =
830 stream
831 ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc,
832 input_h_data, *c_state_desc, input_c_data,
833 params_data, *output_desc, &output_data,
834 *h_state_desc, &output_h_data, *c_state_desc,
835 &output_c_data, is_training, reserve_space_allocator,
836 workspace_allocator, output_profile_result)
837 .ok();
838 return launch_success
839 ? Status::OK()
840 : errors::Internal(
841 "Failed to call ThenRnnForward with model config: ",
842 model_types.DebugString(), ", ", model_shapes.DebugString());
843 }
844
845 template <typename T>
846 Status DoBackward(
847 OpKernelContext* context, const RnnDescriptor& rnn_desc,
848 const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
849 /* forward inputs */
850 const Tensor* input, const Tensor* input_h, const Tensor* input_c,
851 const Tensor* params,
852 /* forward outputs */
853 const Tensor* output, const Tensor* output_h, const Tensor* output_c,
854 /* backprop inputs */
855 const Tensor* output_backprop, const Tensor* output_h_backprop,
856 const Tensor* output_c_backprop, const Tensor* reserve_space,
857 /* backprop outputs, output of the function */
858 Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
859 Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
860 ScratchAllocator* workspace_allocator,
861 ProfileResult* output_profile_result) {
862 std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
863 std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
864 std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
865 std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
866
867 absl::Span<const int> seq_lengths;
868 if (sequence_lengths != nullptr) {
869 seq_lengths = absl::Span<const int>(
870 sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
871 }
872 TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
873 context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
874 &output_desc, seq_lengths, time_major));
875
876 auto input_data = AsDeviceMemory<T>(input);
877 auto input_h_data = AsDeviceMemory<T>(input_h);
878 DeviceMemory<T> input_c_data;
879 if (model_types.HasInputC()) {
880 input_c_data = AsDeviceMemory<T>(input_c);
881 }
882 auto params_data = AsDeviceMemory<T>(params);
883 auto output_data = AsDeviceMemory<T>(output);
884 auto output_h_data = AsDeviceMemory<T>(output_h);
885 DeviceMemory<T> output_c_data;
886 if (model_types.HasInputC()) {
887 output_c_data = AsDeviceMemory<T>(output_c);
888 }
889 auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
890 auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
891 DeviceMemory<T> output_c_backprop_data;
892 if (model_types.HasInputC()) {
893 output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
894 }
895 auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
896 auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
897 DeviceMemory<T> input_c_backprop_data;
898 if (model_types.HasInputC()) {
899 input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
900 }
901 auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
902 auto reserve_space_uint8 =
903 CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
904
905 // Creates a memory callback for the workspace. The memory lives until the
906 // end of this kernel call.
907 Stream* stream = context->op_device_context()->stream();
908 bool launch_success =
909 stream
910 ->ThenRnnBackward(
911 rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data,
912 *c_state_desc, input_c_data, params_data, *output_desc,
913 output_data, *h_state_desc, output_h_data, *c_state_desc,
914 output_c_data, output_backprop_data, output_h_backprop_data,
915 output_c_backprop_data, &input_backprop_data,
916 &input_h_backprop_data, &input_c_backprop_data,
917 ¶ms_backprop_data, &reserve_space_uint8, workspace_allocator,
918 output_profile_result)
919 .ok();
920 return launch_success
921 ? Status::OK()
922 : errors::Internal(
923 "Failed to call ThenRnnBackward with model config: ",
924 model_types.DebugString(), ", ", model_shapes.DebugString());
925 }
926
927 template <typename T>
928 void RestoreParams(const OpInputList params_input,
929 const std::vector<RnnDescriptor::ParamsRegion>& params,
930 DeviceMemoryBase* data_dst, Stream* stream) {
931 int num_params = params.size();
932 CHECK(params_input.size() == num_params)
933 << "Number of params mismatch. Expected " << params_input.size()
934 << ", got " << num_params;
935 for (int i = 0; i < params.size(); i++) {
936 int64 size_in_bytes = params[i].size;
937 int64 size = size_in_bytes / sizeof(T);
938 CHECK(size == params_input[i].NumElements())
939 << "Params size mismatch. Expected " << size << ", got "
940 << params_input[i].NumElements();
941 auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
942 DeviceMemoryBase data_dst_ptr =
943 SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
944 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
945 }
946 }
947
948 bool ShouldUsePaddedIO(const Tensor* sequence_lengths,
949 const CudnnRnnModelShapes& model_shapes,
950 bool time_major) {
951 auto seq_array = sequence_lengths->template flat<int>().data();
952 bool all_max_seq_length = true;
953 for (int i = 0; i < model_shapes.batch_size; i++) {
954 if (seq_array[i] != model_shapes.max_seq_length) {
955 all_max_seq_length = false;
956 break;
957 }
958 }
959 return !(time_major && all_max_seq_length);
960 }
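// For example (hypothetical values): with max_seq_length=8 and time-major
// inputs, sequence_lengths {8, 8, 8, 8} allow the non-padded path, whereas
// sequence_lengths {8, 5, 8, 8} (or batch-major inputs) require padded IO.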
961
962 } // namespace
963
964 // Note: all of the following kernels depend on an RnnDescriptor instance,
965 // which according to the official Cudnn documentation should be kept around
966 // and reused across all Cudnn kernels in the same model.
967 // In TensorFlow, we don't pass the reference across different OpKernels;
968 // instead, it is recreated separately in each OpKernel, which does not cause
969 // issues: CudnnDropoutDescriptor keeps a reference to memory holding the
970 // random number generator state. During recreation, this state is lost.
971 // However, only the forward-pass Cudnn APIs make use of the state.
972
973 // A common base class for RNN kernels. It extracts common attributes and
974 // performs common shape validations.
975 class CudnnRNNKernelCommon : public OpKernel {
976 protected:
977 explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
978 : OpKernel(context) {
979 OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
980 OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
981 OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
982 string str;
983 OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
984 OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
985 OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
986 OP_REQUIRES_OK(context,
987 ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
988 OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
989 OP_REQUIRES_OK(
990 context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
991 // Reset the CudnnRnnDescriptor and the related random number generator
992 // state in every Compute() call.
993 OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
994 false, &reset_rnd_gen_state_));
995 }
996
997 bool HasInputC() const { return model_types_.HasInputC(); }
998 RnnMode rnn_mode() const { return model_types_.rnn_mode; }
999 TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
1000 RnnDirectionMode rnn_direction_mode() const {
1001 return model_types_.rnn_direction_mode;
1002 }
1003 const CudnnModelTypes& model_types() const { return model_types_; }
1004 float dropout() const { return dropout_; }
1005 uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
1006 bool ResetRndGenState() { return reset_rnd_gen_state_; }
1007
1008 template <typename T>
1009 Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
1010 std::unique_ptr<RnnDescriptor>* rnn_desc) {
1011 const Tensor* num_layers_t = nullptr;
1012 TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
1013 if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
1014 return errors::InvalidArgument("num_layers is not a scalar");
1015 }
1016 int num_layers = num_layers_t->scalar<int>()();
1017 const Tensor* num_units_t = nullptr;
1018 TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
1019 if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
1020 return errors::InvalidArgument("num_units is not a scalar");
1021 }
1022 int num_units = num_units_t->scalar<int>()();
1023 const Tensor* input_size_t = nullptr;
1024 TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
1025 if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
1026 return errors::InvalidArgument("input_size is not a scalar");
1027 }
1028 int input_size = input_size_t->scalar<int>()();
1029
1030 int h_num_units = (num_proj == 0 ? num_units : num_proj);
1031 int c_num_units = (num_proj == 0 ? 0 : num_units);
1032
1033 RnnInputMode input_mode;
1034 TF_RETURN_IF_ERROR(
1035 ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
1036
1037 Stream* stream = context->op_device_context()->stream();
1038 // ExtractCudnnRNNParamsInfo is only called by op kernels that do not require
1039 // a random number generator, therefore set state_allocator to nullptr.
1040 const AlgorithmConfig algo_config;
1041 auto rnn_desc_s = stream->parent()->createRnnDescriptor(
1042 num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
1043 /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
1044 ToDataType<T>::value, algo_config, dropout(), seed(),
1045 /* state_allocator=*/nullptr, /*use_padded_io=*/false);
1046 if (!rnn_desc_s.ok()) {
1047 return FromExecutorStatus(rnn_desc_s);
1048 }
1049 *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
1050 return Status::OK();
1051 }
1052
1053 template <typename T>
1054 Status CreateRnnDescriptor(OpKernelContext* context,
1055 const CudnnRnnModelShapes& model_shapes,
1056 const RnnInputMode& input_mode,
1057 const AlgorithmConfig& algo_config,
1058 ScratchAllocator* dropout_state_allocator,
1059 std::unique_ptr<RnnDescriptor>* rnn_desc,
1060 bool use_padded_io) {
1061 StreamExecutor* executor = context->op_device_context()->stream()->parent();
1062 se::dnn::DataType data_type = ToDataType<T>::value;
1063 auto rnn_desc_s = executor->createRnnDescriptor(
1064 model_shapes.num_layers, model_shapes.num_units,
1065 model_shapes.input_size, model_shapes.cell_num_units,
1066 model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
1067 data_type, algo_config, dropout(), seed(), dropout_state_allocator,
1068 use_padded_io);
1069 TF_RETURN_IF_ERROR(rnn_desc_s.status());
1070
1071 *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
1072 return Status::OK();
1073 }
1074
1075 using RnnStateCache = gtl::FlatMap<
1076 std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
1077 RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
1078 // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
1079 // should outlive the returned pointer.
1080 template <typename T>
1081 Status GetCachedRnnDescriptor(OpKernelContext* context,
1082 const CudnnRnnModelShapes& model_shapes,
1083 const RnnInputMode& input_mode,
1084 const AlgorithmConfig& algo_config,
1085 RnnStateCache* cache, RnnDescriptor** rnn_desc,
1086 bool use_padded_io) {
1087 auto key = std::make_pair(model_shapes, algo_config.algorithm());
1088 RnnScratchSpace& rnn_state = (*cache)[key];
1089 if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
1090 CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
1091 new CudnnRNNPersistentSpaceAllocator(context);
1092 rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
1093 Status status = CreateRnnDescriptor<T>(
1094 context, model_shapes, input_mode, algo_config,
1095 dropout_state_allocator, &rnn_state.rnn_desc, use_padded_io);
1096 TF_RETURN_IF_ERROR(status);
1097 }
1098 *rnn_desc = rnn_state.rnn_desc.get();
1099 return Status::OK();
1100 }
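// Illustrative sketch only (not compiled): typical use of the descriptor
// cache from a Compute() method, assuming a mutex-guarded RnnStateCache
// member (here hypothetically named rnn_state_cache_).
//
//   mutex_lock l(mu_);
//   RnnDescriptor* rnn_desc = nullptr;
//   OP_REQUIRES_OK(context, GetCachedRnnDescriptor<T>(
//                               context, model_shapes, input_mode,
//                               algo_config, &rnn_state_cache_, &rnn_desc,
//                               use_padded_io));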
1101
1102 private:
1103 int seed_;
1104 int seed2_;
1105 float dropout_;
1106 bool reset_rnd_gen_state_;
1107
1108 CudnnModelTypes model_types_;
1109 };
1110
1111 // A class that returns the size of the opaque parameter buffer. The user
1112 // should use it to create the actual parameter buffer for training. However,
1113 // it should not be used for saving and restoring.
1114 template <typename T, typename Index>
1115 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
1116 public:
1117 explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
1118 : CudnnRNNKernelCommon(context) {
1119 if (context->HasAttr("num_proj")) {
1120 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1121 } else {
1122 num_proj_ = 0;
1123 }
1124 }
1125
1126 void Compute(OpKernelContext* context) override {
1127 std::unique_ptr<RnnDescriptor> rnn_desc;
1128 OP_REQUIRES_OK(context,
1129 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1130 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1131 CHECK(params_size_in_bytes % sizeof(T) == 0)
1132 << "params_size_in_bytes must be multiple of element size";
1133 int64 params_size = params_size_in_bytes / sizeof(T);
1134
1135 Tensor* output_t = nullptr;
1136 OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
1137 *output_t->template flat<Index>().data() = params_size;
1138 }
1139
1140 private:
1141 int num_proj_;
1142 };
1143
1144 #define REGISTER_GPU(T) \
1145 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \
1146 .Device(DEVICE_GPU) \
1147 .HostMemory("num_layers") \
1148 .HostMemory("num_units") \
1149 .HostMemory("input_size") \
1150 .HostMemory("params_size") \
1151 .TypeConstraint<T>("T") \
1152 .TypeConstraint<int32>("S"), \
1153 CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
1154
1155 TF_CALL_half(REGISTER_GPU);
1156 TF_CALL_float(REGISTER_GPU);
1157 TF_CALL_double(REGISTER_GPU);
1158 #undef REGISTER_GPU
1159
1160 // Convert weight and bias params from a platform-specific layout to the
1161 // canonical form.
1162 template <typename T>
1163 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
1164 public:
1165 explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
1166 : CudnnRNNKernelCommon(context) {
1167 if (context->HasAttr("num_params")) {
1168 OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
1169 } else {
1170 num_params_ = 0;
1171 }
1172 if (context->HasAttr("num_params_weights")) {
1173 OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
1174 &num_params_weights_));
1175 } else {
1176 num_params_weights_ = 0;
1177 }
1178 if (context->HasAttr("num_params_biases")) {
1179 OP_REQUIRES_OK(
1180 context, context->GetAttr("num_params_biases", &num_params_biases_));
1181 } else {
1182 num_params_biases_ = 0;
1183 }
1184 if (context->HasAttr("num_proj")) {
1185 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1186 } else {
1187 num_proj_ = 0;
1188 }
1189 if (num_proj_ == 0) {
1190 num_params_weights_ = num_params_;
1191 num_params_biases_ = num_params_;
1192 }
1193 }
1194
1195 void Compute(OpKernelContext* context) override {
1196 const Tensor& input = context->input(3);
1197 auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
1198 Stream* stream = context->op_device_context()->stream();
1199
1200 std::unique_ptr<RnnDescriptor> rnn_desc;
1201 OP_REQUIRES_OK(context,
1202 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1203 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1204 CHECK(params_size_in_bytes % sizeof(T) == 0)
1205 << "params_size_in_bytes must be multiple of element size";
1206
1207 const Tensor* num_units_t = nullptr;
1208 OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
1209 CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
1210 << "num_units is not a scalar";
1211 int num_units = num_units_t->scalar<int>()();
1212
1213 const Tensor* input_size_t = nullptr;
1214 OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
1215 CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
1216 << "input_size is not a scalar";
1217 int input_size = input_size_t->scalar<int>()();
1218
1219 const Tensor* num_layers_t = nullptr;
1220 OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
1221 CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
1222 << "num_layers is not a scalar";
1223 int num_layers = num_layers_t->scalar<int>()();
1224 int num_dirs = 1;
1225 if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
1226 num_dirs = 2;
1227 }
1228 const int num_params_weights_per_layer =
1229 num_params_weights_ / num_layers / num_dirs;
1230 // Number of params applied on inputs. The rest are applied on recurrent
1231 // hidden states.
1232 const int num_params_input_state = num_params_weights_per_layer / 2;
1233 OP_REQUIRES(
1234 context, num_params_weights_ % (num_layers * num_dirs) == 0,
1235 errors::InvalidArgument("Number of params (weights) is not a multiple"
1236 "of num_layers * num_dirs."));
1237 OP_REQUIRES(
1238 context, num_params_biases_ % (num_layers * num_dirs) == 0,
1239 errors::InvalidArgument("Number of params (biases) is not a multiple"
1240 "of num_layers * num_dirs."));
1241 if (num_proj_ == 0) {
1242 OP_REQUIRES(
1243 context, num_params_weights_per_layer % 2 == 0,
1244 errors::InvalidArgument("Number of params (weights) per layer is not"
1245 "an even number with no projection."));
1246 } else {
1247 OP_REQUIRES(
1248 context, num_params_weights_per_layer % 2 != 0,
1249 errors::InvalidArgument("Number of params (weights) per layer is not"
1250 "an odl number with projection."));
1251 }
1252
1253 OP_REQUIRES(
1254 context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(),
1255 errors::InvalidArgument("C Number of params mismatch. Expected ",
1256 num_params_weights_, ", got ",
1257 rnn_desc->ParamsWeightRegions().size()));
1258 int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
1259 int c_num_units = (num_proj_ == 0 ? 0 : num_units);
1260 for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
1261 int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
1262 int64 size = size_in_bytes / sizeof(T);
1263 const int layer_idx = i / num_params_weights_per_layer;
1264 const int index_within_layer = i % num_params_weights_per_layer;
1265 int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
1266 // In the CuDNN layout, each layer has num_params_weights_per_layer params:
1267 // the first half (the num_params_input_state params) are applied to the
1268 // inputs, and the second half are applied to the recurrent hidden
1269 // states.
1270 bool apply_on_input_state = index_within_layer < num_params_input_state;
1271 if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
1272 if (layer_idx == 0 && apply_on_input_state) {
1273 width = input_size;
1274 } else {
1275 width = h_num_units;
1276 }
1277 } else {
1278 if (apply_on_input_state) {
1279 if (layer_idx <= 1) {
1280 // First forward or backward layer.
1281 width = input_size;
1282 } else {
1283 // In the following layers, the cell inputs are the concatenated outputs
1284 // of the prior layer.
1285 width = 2 * h_num_units;
1286 }
1287 } else {
1288 width = h_num_units;
1289 }
1290 }
1291 CHECK(size == width * height) << "Params size mismatch. Expected "
1292 << width * height << ", got " << size;
1293 Tensor* output = nullptr;
1294 int id_in_layer = i % num_params_weights_per_layer;
1295 if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) {
1296 std::swap(height, width);
1297 }
1298 OP_REQUIRES_OK(context, context->allocate_output(
1299 i, TensorShape({height, width}), &output));
1300 DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1301 input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
1302 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1303 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1304 }
1305
1306 OP_REQUIRES(
1307 context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
1308 errors::InvalidArgument("A Number of params mismatch. Expected ",
1309 num_params_biases_, ", got ",
1310 rnn_desc->ParamsBiasRegions().size()));
1311 for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
1312 int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
1313 int64 size = size_in_bytes / sizeof(T);
1314 OP_REQUIRES(context, size == num_units,
1315 errors::InvalidArgument("Params size mismatch. Expected ",
1316 num_units, ", got ", size));
1317
1318 Tensor* output = nullptr;
1319 OP_REQUIRES_OK(context,
1320 context->allocate_output(num_params_weights_ + i,
1321 TensorShape({size}), &output));
1322 DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1323 input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
1324 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1325 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1326 }
1327 }
1328
1329 private:
1330 int num_params_;
1331 int num_params_weights_;
1332 int num_params_biases_;
1333 int num_proj_;
1334 };
1335
1336 #define REGISTER_GPU(T) \
1337 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
1338 .Device(DEVICE_GPU) \
1339 .HostMemory("num_layers") \
1340 .HostMemory("num_units") \
1341 .HostMemory("input_size") \
1342 .TypeConstraint<T>("T"), \
1343 CudnnRNNParamsToCanonical<GPUDevice, T>);
1344 TF_CALL_half(REGISTER_GPU);
1345 TF_CALL_float(REGISTER_GPU);
1346 TF_CALL_double(REGISTER_GPU);
1347 #undef REGISTER_GPU
1348
1349 #define REGISTER_GPU(T) \
1350 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
1351 .Device(DEVICE_GPU) \
1352 .HostMemory("num_layers") \
1353 .HostMemory("num_units") \
1354 .HostMemory("input_size") \
1355 .TypeConstraint<T>("T"), \
1356 CudnnRNNParamsToCanonical<GPUDevice, T>);
1357 TF_CALL_half(REGISTER_GPU);
1358 TF_CALL_float(REGISTER_GPU);
1359 TF_CALL_double(REGISTER_GPU);
1360 #undef REGISTER_GPU
1361
1362 // Convert weight and bias params from the canonical form to a
1363 // platform-specific layout.
1364 template <typename T>
1365 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
1366 public:
1367 explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
1368 : CudnnRNNKernelCommon(context) {
1369 if (context->HasAttr("num_proj")) {
1370 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1371 } else {
1372 num_proj_ = 0;
1373 }
1374 }
1375
1376 void Compute(OpKernelContext* context) override {
1377 std::unique_ptr<RnnDescriptor> rnn_desc;
1378 OP_REQUIRES_OK(context,
1379 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1380 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1381 CHECK(params_size_in_bytes % sizeof(T) == 0)
1382 << "params_size_in_bytes must be multiple of element size";
1383 Tensor* output = nullptr;
1384 int params_size = params_size_in_bytes / sizeof(T);
1385 OP_REQUIRES_OK(context,
1386 context->allocate_output(0, {params_size}, &output));
1387 auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1388 Stream* stream = context->op_device_context()->stream();
1389
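// Copy each canonical weight and bias tensor into its corresponding region of
// the opaque parameter buffer on the device.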
1390 OpInputList weights;
1391 OP_REQUIRES_OK(context, context->input_list("weights", &weights));
1392 RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
1393 stream);
1394
1395 OpInputList biases;
1396 OP_REQUIRES_OK(context, context->input_list("biases", &biases));
1397 RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
1398 stream);
1399 }
1400
1401 private:
1402 int num_proj_;
1403 };
1404
1405 #define REGISTER_GPU(T) \
1406 REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
1407 .Device(DEVICE_GPU) \
1408 .HostMemory("num_layers") \
1409 .HostMemory("num_units") \
1410 .HostMemory("input_size") \
1411 .TypeConstraint<T>("T"), \
1412 CudnnRNNCanonicalToParams<GPUDevice, T>);
1413 TF_CALL_half(REGISTER_GPU);
1414 TF_CALL_float(REGISTER_GPU);
1415 TF_CALL_double(REGISTER_GPU);
1416 #undef REGISTER_GPU
1417
1418 #define REGISTER_GPU(T) \
1419 REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
1420 .Device(DEVICE_GPU) \
1421 .HostMemory("num_layers") \
1422 .HostMemory("num_units") \
1423 .HostMemory("input_size") \
1424 .TypeConstraint<T>("T"), \
1425 CudnnRNNCanonicalToParams<GPUDevice, T>);
1426 TF_CALL_half(REGISTER_GPU);
1427 TF_CALL_float(REGISTER_GPU);
1428 TF_CALL_double(REGISTER_GPU);
1429 #undef REGISTER_GPU
1430
1431 // Run the forward operation of the RNN model.
1432 template <typename T>
1433 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1434 public:
1435 explicit CudnnRNNForwardOp(OpKernelConstruction* context)
1436 : CudnnRNNKernelCommon(context) {
1437 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1438
1439 // Read debug env variables.
1440 is_debug_mode_ = DebugCudnnRnn();
1441 debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
1442 debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
1443 }
1444
1445 void Compute(OpKernelContext* context) override {
1446 AlgorithmConfig algo_config;
1447 ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
1448 /*time_major=*/true, /*num_proj=*/0);
1449 }
1450
1451 protected:
1452 virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
1453 AlgorithmConfig* output_algo_config,
1454 bool var_seq_lengths, bool time_major,
1455 int num_proj) {
1456 CHECK_NE(output_algo_config, nullptr);
1457
1458 const Tensor* input = nullptr;
1459 const Tensor* input_h = nullptr;
1460 const Tensor* input_c = nullptr;
1461 const Tensor* params = nullptr;
1462 const Tensor* sequence_lengths = nullptr;
1463 CudnnRnnModelShapes model_shapes;
1464 bool use_padded_io = false;
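// With variable sequence lengths, ShouldUsePaddedIO decides whether cuDNN is
// fed padded input/output, based on the provided lengths, shapes, and layout.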
1465 if (var_seq_lengths) {
1466 OP_REQUIRES_OK(context, ExtractForwardInput(
1467 context, model_types(), time_major, &input,
1468 &input_h, &input_c, ¶ms,
1469 &sequence_lengths, num_proj, &model_shapes));
1470 use_padded_io =
1471 ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1472 } else {
1473 OP_REQUIRES_OK(context,
1474 ExtractForwardInput(context, model_types(), time_major,
1475 &input, &input_h, &input_c, ¶ms,
1476 num_proj, &model_shapes));
1477 }
1478 RnnInputMode input_mode;
1479 OP_REQUIRES_OK(context,
1480 ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1481 model_shapes.input_size, &input_mode));
1482
1483 Tensor* output = nullptr;
1484 Tensor* output_h = nullptr;
1485 Tensor* output_c = nullptr;
1486 OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
1487 &output_h, &output_c));
1488
1489 // Creates a memory callback for the reserve_space. The memory lives in the
1490 // output of this kernel, and it will be fed into the backward pass when
1491 // needed.
1492 CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
1493 // Creates a memory callback for the workspace. The memory lives until the
1494 // end of this kernel call.
1495 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1496
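// In debug mode, skip autotuning and use the algorithm forced via the debug
// environment variables.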
1497 if (is_debug_mode_) {
1498 AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
1499 output_algo_config->set_algorithm(algo_desc);
1500 } else {
1501 OP_REQUIRES_OK(context,
1502 MaybeAutoTune(context, model_shapes, input_mode, input,
1503 input_h, input_c, params, output, output_h,
1504 output_c, output_algo_config));
1505 }
1506
1507 Status launch_status;
1508 {
1509 mutex_lock l(mu_);
1510 RnnDescriptor* rnn_desc_ptr = nullptr;
1511 OP_REQUIRES_OK(context,
1512 GetCachedRnnDescriptor<T>(
1513 context, model_shapes, input_mode, *output_algo_config,
1514 &rnn_state_cache_, &rnn_desc_ptr, use_padded_io));
1515 launch_status = DoForward<T>(
1516 context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1517 input_c, params, is_training_, output, output_h, output_c,
1518 sequence_lengths, time_major, &reserve_space_allocator,
1519 &workspace_allocator, /*output_profile_result=*/nullptr);
1520 }
1521 OP_REQUIRES_OK(context, launch_status);
1522 }
1523
1524 protected:
1525 virtual Status MaybeAutoTune(OpKernelContext* context,
1526 const CudnnRnnModelShapes& model_shapes,
1527 const RnnInputMode& input_mode,
1528 const Tensor* input, const Tensor* input_h,
1529 const Tensor* input_c, const Tensor* params,
1530 Tensor* output, Tensor* output_h,
1531 Tensor* output_c,
1532 AlgorithmConfig* best_algo_config) {
1533 CHECK_NE(best_algo_config, nullptr);
1534 *best_algo_config = AlgorithmConfig();
1535 return Status::OK();
1536 }
1537
1538 bool is_training() const { return is_training_; }
1539 bool is_debug_mode_;
1540 bool debug_use_tensor_ops_;
1541 int64 debug_cudnn_rnn_algo_;
1542
1543 private:
1544 Status AllocateOutputs(OpKernelContext* context,
1545 const CudnnRnnModelShapes& model_shapes,
1546 Tensor** output, Tensor** output_h,
1547 Tensor** output_c) {
1548 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1549 const TensorShape& output_shape = model_shapes.output_shape;
1550 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
1551
1552 TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
1553 TF_RETURN_IF_ERROR(
1554 context->allocate_output(1, hidden_state_shape, output_h));
1555 if (HasInputC()) {
1556 TF_RETURN_IF_ERROR(
1557 context->allocate_output(2, cell_state_shape, output_c));
1558 } else {
1559 // Only LSTM uses input_c and output_c. So for all other models, we only
1560 // need to create dummy outputs.
1561 TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
1562 }
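// In inference mode no reserve space is produced, but output 3 must still
// exist, so allocate an empty tensor.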
1563 if (!is_training_) {
1564 Tensor* dummy_reserve_space = nullptr;
1565 TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
1566 }
1567 return Status::OK();
1568 }
1569
1570 mutex mu_;
1571 bool is_training_;
1572 RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1573 };
1574
1575 #define REGISTER_GPU(T) \
1576 REGISTER_KERNEL_BUILDER( \
1577 Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1578 CudnnRNNForwardOp<GPUDevice, T>);
1579
1580 TF_CALL_half(REGISTER_GPU);
1581 TF_CALL_float(REGISTER_GPU);
1582 TF_CALL_double(REGISTER_GPU);
1583 #undef REGISTER_GPU
1584
1585 template <typename T>
1586 class CudnnRNNForwardOpV2<GPUDevice, T>
1587 : public CudnnRNNForwardOp<GPUDevice, T> {
1588 private:
1589 using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1590 using CudnnRNNKernelCommon::CreateRnnDescriptor;
1591 using CudnnRNNKernelCommon::dropout;
1592 using CudnnRNNKernelCommon::HasInputC;
1593 using CudnnRNNKernelCommon::model_types;
1594
1595 public:
1596 explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
1597 : CudnnRNNForwardOp<GPUDevice, T>(context) {}
1598
1599 void Compute(OpKernelContext* context) override {
1600 AlgorithmConfig best_algo_config;
1601 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1602 context, &best_algo_config, /*var_seq_lengths=*/false,
1603 /*time_major=*/true, /*num_proj=*/0);
1604 if (!context->status().ok()) {
1605 return;
1606 }
1607
1608 Tensor* output_host_reserved = nullptr;
1609 // output_host_reserved stores opaque info used for backprop when running
1610 // in training mode. At present, it includes a serialization of the best
1611 // AlgorithmDesc picked during rnn forward pass autotune.
1612 // int8 algorithm_id
1613 // int8 use_tensor_op
1614 // If autotune is not enabled, the algorithm_id is
1615 // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
1616 // running in inference mode, the output_host_reserved is currently not
1617 // populated.
1618 if (is_training()) {
1619 OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
1620 &output_host_reserved));
1621 auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
1622 output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
1623 output_host_reserved_int8(1) =
1624 best_algo_config.algorithm()->tensor_ops_enabled();
1625 } else {
1626 OP_REQUIRES_OK(context,
1627 context->allocate_output(4, {}, &output_host_reserved));
1628 }
1629 }
1630
1631 protected:
1632 Status MaybeAutoTune(OpKernelContext* context,
1633 const CudnnRnnModelShapes& model_shapes,
1634 const RnnInputMode& input_mode, const Tensor* input,
1635 const Tensor* input_h, const Tensor* input_c,
1636 const Tensor* params, Tensor* output, Tensor* output_h,
1637 Tensor* output_c,
1638 AlgorithmConfig* algo_config) override {
1639 CHECK_NE(algo_config, nullptr);
1640 if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
1641 *algo_config = AlgorithmConfig();
1642 return Status::OK();
1643 }
1644
1645 std::vector<AlgorithmDesc> algorithms;
1646 auto* stream = context->op_device_context()->stream();
1647 CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
1648 if (algorithms.empty()) {
1649 LOG(WARNING) << "No Rnn algorithm found";
1650 return Status::OK();
1651 }
1652
1653 const auto& modeltypes = model_types();
1654 CudnnRnnParameters rnn_params(
1655 model_shapes.num_layers, model_shapes.input_size,
1656 model_shapes.num_units, model_shapes.max_seq_length,
1657 model_shapes.batch_size, model_shapes.dir_count,
1658 /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
1659 modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
1660
1661 if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
1662 VLOG(1) << "Using existing best Cudnn RNN algorithm "
1663 << "(algo, tensor_op_enabled) = ("
1664 << algo_config->algorithm()->algo_id() << ", "
1665 << algo_config->algorithm()->tensor_ops_enabled() << ").";
1666 return Status::OK();
1667 }
1668
1669 // Create temp tensors used when profiling the backprop pass.
1670 auto data_type = input->dtype();
1671 Tensor output_backprop;
1672 Tensor output_h_backprop;
1673 Tensor output_c_backprop;
1674 Tensor input_backprop;
1675 Tensor input_h_backprop;
1676 Tensor input_c_backprop;
1677 Tensor params_backprop;
1678 if (is_training()) {
1679 TF_RETURN_IF_ERROR(context->allocate_temp(
1680 data_type, model_shapes.output_shape, &output_backprop));
1681 TF_RETURN_IF_ERROR(context->allocate_temp(
1682 data_type, model_shapes.hidden_state_shape, &output_h_backprop));
1683
1684 TF_RETURN_IF_ERROR(
1685 context->allocate_temp(data_type, params->shape(), ¶ms_backprop));
1686 TF_RETURN_IF_ERROR(context->allocate_temp(
1687 data_type, model_shapes.input_shape, &input_backprop));
1688 TF_RETURN_IF_ERROR(context->allocate_temp(
1689 data_type, model_shapes.hidden_state_shape, &input_h_backprop));
1690 if (HasInputC()) {
1691 TF_RETURN_IF_ERROR(context->allocate_temp(
1692 data_type, model_shapes.hidden_state_shape, &output_c_backprop));
1693 TF_RETURN_IF_ERROR(context->allocate_temp(
1694 data_type, model_shapes.hidden_state_shape, &input_c_backprop));
1695 }
1696 }
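// Profile each candidate algorithm. In training mode the forward and backward
// passes are timed together, since the winning algorithm is reused for backprop.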
1697 ProfileResult best_result;
1698 for (auto& algo : algorithms) {
1699 VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1700 << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
1701 Status status;
1702 ProfileResult final_profile_result;
1703
1704 ProfileResult fwd_profile_result;
1705 ProfileResult bak_profile_result;
1706
1707 // RnnDescriptor is algorithm-dependent, thus not reusable.
1708 std::unique_ptr<RnnDescriptor> rnn_desc;
1709 // Use a temp scratch allocator for the random number generator.
1710 CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
1711 if (!this->template CreateRnnDescriptor<T>(
1712 context, model_shapes, input_mode, AlgorithmConfig(algo),
1713 &dropout_state_allocator, &rnn_desc,
1714 /*use_padded_io=*/false)
1715 .ok()) {
1716 continue;
1717 }
1718
1719 // Again use temp scratch allocator during profiling.
1720 CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
1721 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1722 status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
1723 input, input_h, input_c, params, is_training(),
1724 output, output_h, output_c, nullptr, true,
1725 &reserve_space_allocator, &workspace_allocator,
1726 &fwd_profile_result);
1727 if (!status.ok()) {
1728 continue;
1729 }
1730
1731 if (is_training()) {
1732 // Get reserve space from the forward pass.
1733 Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
1734 status = DoBackward<T>(
1735 context, *rnn_desc, model_types(), model_shapes, input, input_h,
1736 input_c, params, output, output_h, output_c, &output_backprop,
1737 &output_h_backprop, &output_c_backprop, &reserve_space,
1738 &input_backprop, &input_h_backprop, &input_c_backprop,
1739 ¶ms_backprop, nullptr, true, &workspace_allocator,
1740 &bak_profile_result);
1741 if (!status.ok()) {
1742 continue;
1743 }
1744 final_profile_result.set_elapsed_time_in_ms(
1745 fwd_profile_result.elapsed_time_in_ms() +
1746 bak_profile_result.elapsed_time_in_ms());
1747 } else {
1748 final_profile_result = fwd_profile_result;
1749 }
1750
1751 auto total_time = final_profile_result.elapsed_time_in_ms();
1752 VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1753 << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
1754 << " run time: " << total_time << " ms.";
1755 if (total_time < best_result.elapsed_time_in_ms()) {
1756 best_result.set_elapsed_time_in_ms(total_time);
1757 best_result.set_algorithm(algo);
1758 }
1759 }
1760
1761 if (!best_result.is_valid()) {
1762 return Status(error::Code::INTERNAL, "No algorithm worked!");
1763 }
1764 algo_config->set_algorithm(best_result.algorithm());
1765 VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1766 << best_result.algorithm().algo_id() << ", "
1767 << best_result.algorithm().tensor_ops_enabled() << ").";
1768 AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
1769 return Status::OK();
1770 }
1771 };
1772
1773 #define REGISTER_GPU(T) \
1774 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2") \
1775 .Device(DEVICE_GPU) \
1776 .HostMemory("host_reserved") \
1777 .TypeConstraint<T>("T"), \
1778 CudnnRNNForwardOpV2<GPUDevice, T>);
1779
1780 TF_CALL_half(REGISTER_GPU);
1781 TF_CALL_float(REGISTER_GPU);
1782 TF_CALL_double(REGISTER_GPU);
1783 #undef REGISTER_GPU
1784
1785 template <typename T>
1786 class CudnnRNNForwardOpV3<GPUDevice, T>
1787 : public CudnnRNNForwardOp<GPUDevice, T> {
1788 private:
1789 using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1790 using CudnnRNNKernelCommon::CreateRnnDescriptor;
1791 using CudnnRNNKernelCommon::dropout;
1792 using CudnnRNNKernelCommon::HasInputC;
1793 using CudnnRNNKernelCommon::model_types;
1794 bool time_major_;
1795
1796 protected:
1797 bool time_major() { return time_major_; }
1798
1799 public:
1800 explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
1801 : CudnnRNNForwardOp<GPUDevice, T>(context) {
1802 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
1803 if (context->HasAttr("num_proj")) {
1804 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1805 } else {
1806 num_proj_ = 0;
1807 }
1808 }
1809
1810 void Compute(OpKernelContext* context) override {
1811 AlgorithmConfig best_algo_config;
1812 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1813 context, &best_algo_config, /*var_seq_lengths=*/true,
1814 /*time_major=*/time_major(), num_proj_);
1815 if (!context->status().ok()) {
1816 return;
1817 }
1818
1819 Tensor* output_host_reserved = nullptr;
1820 // TODO: Currently V3 only uses the default standard algorithm to process
1821 // batches with variable sequence lengths, and the inputs should be padded.
1822 // Autotune is not supported yet.
1823 OP_REQUIRES_OK(context,
1824 context->allocate_output(4, {}, &output_host_reserved));
1825 }
1826
1827 private:
1828 int num_proj_;
1829 };
1830
1831 #define REGISTER_GPU(T) \
1832 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3") \
1833 .Device(DEVICE_GPU) \
1834 .HostMemory("sequence_lengths") \
1835 .HostMemory("host_reserved") \
1836 .TypeConstraint<T>("T"), \
1837 CudnnRNNForwardOpV3<GPUDevice, T>);
1838
1839 TF_CALL_half(REGISTER_GPU);
1840 TF_CALL_float(REGISTER_GPU);
1841 TF_CALL_double(REGISTER_GPU);
1842 #undef REGISTER_GPU
1843
1844 // Run the backward operation of the RNN model.
1845 template <typename T>
1846 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1847 public:
1848 explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
1849 : CudnnRNNKernelCommon(context) {}
1850
1851 void Compute(OpKernelContext* context) override {
1852 ComputeImpl(context, false, true, 0);
1853 }
1854
1855 protected:
1856 virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
1857 bool time_major, int num_proj) {
1858 const Tensor* input = nullptr;
1859 const Tensor* input_h = nullptr;
1860 const Tensor* input_c = nullptr;
1861 const Tensor* params = nullptr;
1862 const Tensor* sequence_lengths = nullptr;
1863 CudnnRnnModelShapes model_shapes;
1864 bool use_padded_io = false;
1865 if (var_seq_lengths) {
1866 OP_REQUIRES_OK(context, ExtractForwardInput(
1867 context, model_types(), time_major, &input,
1868 &input_h, &input_c, ¶ms,
1869 &sequence_lengths, num_proj, &model_shapes));
1870 use_padded_io =
1871 ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1872 } else {
1873 OP_REQUIRES_OK(context,
1874 ExtractForwardInput(context, model_types(), time_major,
1875 &input, &input_h, &input_c, ¶ms,
1876 num_proj, &model_shapes));
1877 }
1878 RnnInputMode input_mode;
1879 OP_REQUIRES_OK(context,
1880 ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1881 model_shapes.input_size, &input_mode));
1882
1883 const Tensor* output = nullptr;
1884 const Tensor* output_h = nullptr;
1885 const Tensor* output_c = nullptr;
1886 const Tensor* output_backprop = nullptr;
1887 const Tensor* output_h_backprop = nullptr;
1888 const Tensor* output_c_backprop = nullptr;
1889 const Tensor* reserve_space = nullptr;
1890 OP_REQUIRES_OK(context,
1891 ExtractBackwardInputs(context, model_shapes, model_types(),
1892 &output, &output_h, &output_c,
1893 &output_backprop, &output_h_backprop,
1894 &output_c_backprop, &reserve_space));
1895
1896 Tensor* input_backprop = nullptr;
1897 Tensor* input_h_backprop = nullptr;
1898 Tensor* input_c_backprop = nullptr;
1899 Tensor* params_backprop = nullptr;
1900 OP_REQUIRES_OK(context,
1901 AllocateOutputs(context, model_shapes, params->shape(),
1902 &input_backprop, &input_h_backprop,
1903 &input_c_backprop, ¶ms_backprop));
1904
1905 // Creates a memory callback for the workspace. The memory lives until the
1906 // end of this kernel call.
1907 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1908 AlgorithmConfig algo_config;
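// Pick the algorithm for backprop: the base op falls back to the default
// config, while the V2 op overrides GetAlgorithm to reuse the algorithm
// recorded by its forward pass.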
1909 OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
1910 Status launch_status;
1911 {
1912 mutex_lock l(mu_);
1913 RnnDescriptor* rnn_desc_ptr = nullptr;
1914 OP_REQUIRES_OK(
1915 context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
1916 algo_config, &rnn_state_cache_,
1917 &rnn_desc_ptr, use_padded_io));
1918 launch_status = DoBackward<T>(
1919 context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1920 input_c, params, output, output_h, output_c, output_backprop,
1921 output_h_backprop, output_c_backprop, reserve_space, input_backprop,
1922 input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
1923 time_major, &workspace_allocator,
1924 /*output_profile_result=*/nullptr);
1925 }
1926 OP_REQUIRES_OK(context, launch_status);
1927 }
1928
1929 protected:
1930 virtual Status GetAlgorithm(OpKernelContext* context,
1931 AlgorithmConfig* algo_config) {
1932 CHECK_NE(algo_config, nullptr);
1933 *algo_config = AlgorithmConfig();
1934 return Status::OK();
1935 }
1936
1937 private:
1938 mutex mu_;
1939 RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1940
1941 Status ExtractBackwardInputs(
1942 OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
1943 const CudnnModelTypes& model_types, const Tensor** output,
1944 const Tensor** output_h, const Tensor** output_c,
1945 const Tensor** output_backprop, const Tensor** output_h_backprop,
1946 const Tensor** output_c_backprop, const Tensor** reserve_space) {
1947 TF_RETURN_IF_ERROR(context->input("output", output));
1948 TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
1949 TF_RETURN_IF_ERROR(context->input("output_h", output_h));
1950 TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
1951 if (model_types.HasInputC()) {
1952 TF_RETURN_IF_ERROR(context->input("output_c", output_c));
1953 TF_RETURN_IF_ERROR(
1954 context->input("output_c_backprop", output_c_backprop));
1955 }
1956 TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
1957 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1958 const TensorShape& output_shape = model_shapes.output_shape;
1959 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
1960
1961 if (output_shape != (*output)->shape()) {
1962 return errors::InvalidArgument(
1963 "Invalid output shape: ", (*output)->shape().DebugString(), " ",
1964 output_shape.DebugString());
1965 }
1966 if (hidden_state_shape != (*output_h)->shape()) {
1967 return errors::InvalidArgument(
1968 "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
1969 hidden_state_shape.DebugString());
1970 }
1971
1972 if (output_shape != (*output_backprop)->shape()) {
1973 return errors::InvalidArgument("Invalid output_backprop shape: ",
1974 (*output_backprop)->shape().DebugString(),
1975 " ", output_shape.DebugString());
1976 }
1977 if (hidden_state_shape != (*output_h_backprop)->shape()) {
1978 return errors::InvalidArgument(
1979 "Invalid output_h_backprop shape: ",
1980 (*output_h_backprop)->shape().DebugString(), " ",
1981 hidden_state_shape.DebugString());
1982 }
1983
1984 if (model_types.HasInputC()) {
1985 if (cell_state_shape != (*output_c)->shape()) {
1986 return errors::InvalidArgument(
1987 "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
1988 cell_state_shape.DebugString());
1989 }
1990 if (cell_state_shape != (*output_c_backprop)->shape()) {
1991 return errors::InvalidArgument(
1992 "Invalid output_c_backprop shape: ",
1993 (*output_c_backprop)->shape().DebugString(), " ",
1994 cell_state_shape.DebugString());
1995 }
1996 }
1997 return Status::OK();
1998 }
1999
2000 Status AllocateOutputs(OpKernelContext* context,
2001 const CudnnRnnModelShapes& model_shapes,
2002 const TensorShape& params_shape,
2003 Tensor** input_backprop, Tensor** input_h_backprop,
2004 Tensor** input_c_backprop, Tensor** params_backprop) {
2005 const TensorShape& input_shape = model_shapes.input_shape;
2006 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2007 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2008
2009 TF_RETURN_IF_ERROR(
2010 context->allocate_output(0, input_shape, input_backprop));
2011 TF_RETURN_IF_ERROR(
2012 context->allocate_output(1, hidden_state_shape, input_h_backprop));
2013 if (HasInputC()) {
2014 TF_RETURN_IF_ERROR(
2015 context->allocate_output(2, cell_state_shape, input_c_backprop));
2016 } else {
2017 // Only LSTM uses input_c and output_c. So for all other models, we only
2018 // need to create dummy outputs.
2019 TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
2020 }
2021 TF_RETURN_IF_ERROR(
2022 context->allocate_output(3, params_shape, params_backprop));
2023 return Status::OK();
2024 }
2025 };
2026
2027 #define REGISTER_GPU(T) \
2028 REGISTER_KERNEL_BUILDER( \
2029 Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
2030 CudnnRNNBackwardOp<GPUDevice, T>);
2031
2032 TF_CALL_half(REGISTER_GPU);
2033 TF_CALL_float(REGISTER_GPU);
2034 TF_CALL_double(REGISTER_GPU);
2035 #undef REGISTER_GPU
2036
2037 template <typename T>
2038 class CudnnRNNBackwardOpV2<GPUDevice, T>
2039 : public CudnnRNNBackwardOp<GPUDevice, T> {
2040 public:
2041 explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
2042 : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
2043
2044 protected:
2045 Status GetAlgorithm(OpKernelContext* context,
2046 AlgorithmConfig* algo_config) override {
2047 CHECK_NE(algo_config, nullptr);
2048 const Tensor* host_reserved = nullptr;
2049 TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
2050
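// host_reserved is written by the CudnnRNNV2 forward kernel: element 0 is the
// algorithm id and element 1 records whether tensor ops were enabled.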
2051 auto host_reserved_int8 = host_reserved->vec<int8>();
2052 const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
2053 algo_config->set_algorithm(algo_desc);
2054 return Status::OK();
2055 }
2056 };
2057
2058 #define REGISTER_GPU(T) \
2059 REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2") \
2060 .Device(DEVICE_GPU) \
2061 .HostMemory("host_reserved") \
2062 .TypeConstraint<T>("T"), \
2063 CudnnRNNBackwardOpV2<GPUDevice, T>);
2064
2065 TF_CALL_half(REGISTER_GPU);
2066 TF_CALL_float(REGISTER_GPU);
2067 TF_CALL_double(REGISTER_GPU);
2068 #undef REGISTER_GPU
2069
2070 template <typename T>
2071 class CudnnRNNBackwardOpV3<GPUDevice, T>
2072 : public CudnnRNNBackwardOp<GPUDevice, T> {
2073 private:
2074 bool time_major_;
2075
2076 protected:
2077 bool time_major() { return time_major_; }
2078
2079 public:
2080 explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
2081 : CudnnRNNBackwardOp<GPUDevice, T>(context) {
2082 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
2083 if (context->HasAttr("num_proj")) {
2084 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
2085 } else {
2086 num_proj_ = 0;
2087 }
2088 }
2089
2090 void Compute(OpKernelContext* context) override {
2091 CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(),
2092 num_proj_);
2093 }
2094
2095 private:
2096 int num_proj_;
2097 };
2098
2099 #define REGISTER_GPU(T) \
2100 REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3") \
2101 .Device(DEVICE_GPU) \
2102 .HostMemory("sequence_lengths") \
2103 .HostMemory("host_reserved") \
2104 .TypeConstraint<T>("T"), \
2105 CudnnRNNBackwardOpV3<GPUDevice, T>);
2106
2107 TF_CALL_half(REGISTER_GPU);
2108 TF_CALL_float(REGISTER_GPU);
2109 TF_CALL_double(REGISTER_GPU);
2110 #undef REGISTER_GPU
2111
2112 // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
2113 // its canonical form.
2114
2115 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2116
2117 } // namespace tensorflow
2118