1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16
17 #include <stddef.h>
18 #include <atomic>
19 #include <cmath>
20 #include <functional>
21 #include <limits>
22 #include <string>
23 #include <unordered_set>
24
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/device_base.h"
27 #include "tensorflow/core/framework/kernel_def_builder.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def_builder.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/register_types.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/tensor_types.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/kernels/gpu_utils.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/core/status.h"
39 #include "tensorflow/core/lib/core/stringpiece.h"
40 #include "tensorflow/core/lib/gtl/inlined_vector.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/lib/strings/stringprintf.h"
43 #include "tensorflow/core/platform/fingerprint.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/types.h"
46 #include "tensorflow/core/util/env_var.h"
47 #include "tensorflow/core/util/use_cudnn.h"
48
49 #if GOOGLE_CUDA
50 #include "tensorflow/core/platform/stream_executor.h"
51 #include "tensorflow/core/util/stream_executor_util.h"
52 #endif // GOOGLE_CUDA
53
54 /*
55 * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
56 * using the underlying Cudnn library.
57 *
58  * The Cudnn RNN library exposes an opaque parameter buffer with an unspecified
59  * layout and format, and a buffer saved on one GPU is unlikely to be usable on
60  * a different GPU. Users therefore need to query the size of the opaque
61  * parameter buffer first and convert it to and from its canonical forms, but
62  * each actual training step is carried out with the opaque parameter buffer.
63 *
64  * Similar to many other ops, the forward op has two flavors: training and
65  * inference. When training is specified, additional data is produced in
66  * reserve_space for the backward pass, which incurs a performance penalty.
67 *
68  * In addition to the actual data and reserve_space, Cudnn also needs extra
69  * memory as a temporary workspace. Memory management to and from
70  * stream-executor is done through ScratchAllocator: in general,
71  * stream-executor is responsible for creating memory of the proper size, and
72  * TensorFlow is responsible for making sure the memory stays alive long enough
73  * and is recycled afterwards.
74 *
75 */
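/*
 * A rough sketch of the op-level flow described above, using the op names
 * registered in this file (argument lists abbreviated; illustrative only):
 *
 *   // 1. Query the size of the opaque parameter buffer.
 *   //      CudnnRNNParamsSize(num_layers, num_units, input_size) -> params_size
 *   // 2. Pack canonical weights and biases into the opaque buffer.
 *   //      CudnnRNNCanonicalToParams(weights, biases) -> params
 *   // 3. Run the fused forward pass with the opaque buffer.
 *   //      CudnnRNN(input, input_h, input_c, params, is_training=true)
 *   //        -> output, output_h, output_c, reserve_space
 *   // 4. Unpack back to canonical form for saving or inspection.
 *   //      CudnnRNNParamsToCanonical(params) -> weights, biases
 */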
76 namespace tensorflow {
77
78 using CPUDevice = Eigen::ThreadPoolDevice;
79
80 #if GOOGLE_CUDA
81
82 using GPUDevice = Eigen::GpuDevice;
83 using se::Stream;
84 using se::StreamExecutor;
85 using se::dnn::RnnDescriptor;
86
87 template <typename Device, typename T, typename Index>
88 class CudnnRNNParamsSizeOp;
89
90 template <typename Device, typename T>
91 class CudnnRNNParamsToCanonical;
92
93 template <typename Device, typename T>
94 class CudnnRNNCanonicalToParams;
95
96 template <typename Device, typename T>
97 class CudnnRNNForwardOp;
98
99 template <typename Device, typename T>
100 class CudnnRNNBackwardOp;
101
102 template <typename Device, typename T>
103 class CudnnRNNForwardOpV2;
104
105 template <typename Device, typename T>
106 class CudnnRNNBackwardOpV2;
107
108 template <typename Device, typename T>
109 class CudnnRNNForwardOpV3;
110
111 template <typename Device, typename T>
112 class CudnnRNNBackwardOpV3;
113
114 enum class TFRNNInputMode {
115 kRNNLinearInput = 0,
116 kRNNSkipInput = 1,
117 kAutoSelect = 9999999
118 };
119
120 namespace {
121 using se::DeviceMemory;
122 using se::DeviceMemoryBase;
123 using se::ScratchAllocator;
124 using se::dnn::AlgorithmConfig;
125 using se::dnn::AlgorithmDesc;
126 using se::dnn::ProfileResult;
127 using se::dnn::RnnDirectionMode;
128 using se::dnn::RnnInputMode;
129 using se::dnn::RnnMode;
130 using se::dnn::RnnSequenceTensorDescriptor;
131 using se::dnn::RnnStateTensorDescriptor;
132 using se::dnn::ToDataType;
133 using se::port::StatusOr;
134
135 uint64 HashList(const std::vector<int>& list) {
136 if (list.empty()) {
137 return 0;
138 }
139 uint64 hash_code = list[0];
140 for (int i = 1; i < list.size(); i++) {
141 hash_code = Hash64Combine(hash_code, list[i]);
142 }
143 return hash_code;
144 }
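// For example, HashList({a, b, c}) == Hash64Combine(Hash64Combine(a, b), c),
// and a single-element list hashes to that element itself.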
145
146 // Encapsulates all the shape information that is used in both the forward
147 // and backward RNN operations.
148 class CudnnRnnParameters {
149 public:
150 CudnnRnnParameters(int num_layers, int input_size, int num_units,
151 int max_seq_length, int batch_size, int dir_count,
152 bool has_dropout, bool is_training, RnnMode rnn_mode,
153 TFRNNInputMode rnn_input_mode, DataType dtype)
154 : num_layers_(num_layers),
155 input_size_(input_size),
156 num_units_(num_units),
157 seq_length_(max_seq_length),
158 batch_size_(batch_size),
159 dir_count_(dir_count),
160 has_dropout_(has_dropout),
161 is_training_(is_training),
162 rnn_mode_(rnn_mode),
163 rnn_input_mode_(rnn_input_mode),
164 dtype_(dtype) {
165 hash_code_ =
166 HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
167 dir_count, static_cast<int>(has_dropout),
168 static_cast<int>(is_training), static_cast<int>(rnn_mode),
169 static_cast<int>(rnn_input_mode), dtype});
170 }
171
172 bool operator==(const CudnnRnnParameters& other) const {
173 return this->get_data_as_tuple() == other.get_data_as_tuple();
174 }
175 
176 bool operator!=(const CudnnRnnParameters& other) const {
177 return !(*this == other);
178 }
179 uint64 hash() const { return hash_code_; }
180 
181 string ToString() const {
182 std::vector<string> fields = {
183 std::to_string(num_layers_),
184 std::to_string(input_size_),
185 std::to_string(num_units_),
186 std::to_string(seq_length_),
187 std::to_string(batch_size_),
188 std::to_string(dir_count_),
189 std::to_string(has_dropout_),
190 std::to_string(is_training_),
191 std::to_string(static_cast<int>(rnn_mode_)),
192 std::to_string(static_cast<int>(rnn_input_mode_)),
193 std::to_string(static_cast<int>(dtype_))};
194 return str_util::Join(fields, ", ");
195 }
196
197 private:
198 using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
199 RnnMode, TFRNNInputMode, DataType>;
200
201 ParameterDataType get_data_as_tuple() const {
202 return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
203 batch_size_, dir_count_, has_dropout_, is_training_,
204 rnn_mode_, rnn_input_mode_, dtype_);
205 }
206
207 const int num_layers_;
208 const int input_size_;
209 const int num_units_;
210 const int seq_length_;
211 const int batch_size_;
212 const int dir_count_;
213 const bool has_dropout_;
214 const bool is_training_;
215 const RnnMode rnn_mode_;
216 const TFRNNInputMode rnn_input_mode_;
217 const DataType dtype_;
218 uint64 hash_code_;
219 };
220
221 struct RnnAutoTuneGroup {
222 static string name() { return "Rnn"; }
223 };
224
225 using AutoTuneRnnConfigMap =
226 AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>;
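// Illustrative sketch of how the cache above is typically consulted; the
// Find/Insert names follow the AutoTuneMap interface in gpu_utils.h, and the
// surrounding code is an approximation, not a verbatim excerpt:
//
//   CudnnRnnParameters key(num_layers, input_size, num_units, max_seq_length,
//                          batch_size, dir_count, has_dropout, is_training,
//                          rnn_mode, rnn_input_mode, dtype);
//   AlgorithmConfig algo_config;
//   if (!AutoTuneRnnConfigMap::GetInstance()->Find(key, &algo_config)) {
//     // ... profile candidate algorithms ...
//     AutoTuneRnnConfigMap::GetInstance()->Insert(key, algo_config);
//   }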
227
228 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
229 if (str == "rnn_relu") {
230 *rnn_mode = RnnMode::kRnnRelu;
231 return Status::OK();
232 } else if (str == "rnn_tanh") {
233 *rnn_mode = RnnMode::kRnnTanh;
234 return Status::OK();
235 } else if (str == "lstm") {
236 *rnn_mode = RnnMode::kRnnLstm;
237 return Status::OK();
238 } else if (str == "gru") {
239 *rnn_mode = RnnMode::kRnnGru;
240 return Status::OK();
241 }
242 return errors::InvalidArgument("Invalid RNN mode: ", str);
243 }
244
245 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
246 if (str == "linear_input") {
247 *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
248 return Status::OK();
249 } else if (str == "skip_input") {
250 *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
251 return Status::OK();
252 } else if (str == "auto_select") {
253 *rnn_input_mode = TFRNNInputMode::kAutoSelect;
254 return Status::OK();
255 }
256 return errors::InvalidArgument("Invalid RNN input mode: ", str);
257 }
258
259 Status ParseRNNDirectionMode(const string& str,
260 RnnDirectionMode* rnn_dir_mode) {
261 if (str == "unidirectional") {
262 *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
263 return Status::OK();
264 } else if (str == "bidirectional") {
265 *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
266 return Status::OK();
267 }
268 return errors::InvalidArgument("Invalid RNN direction mode: ", str);
269 }
270
271 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
272 int input_size, RnnInputMode* input_mode) {
273 switch (tf_input_mode) {
274 case TFRNNInputMode::kRNNLinearInput:
275 *input_mode = RnnInputMode::kRnnLinearSkip;
276 break;
277 case TFRNNInputMode::kRNNSkipInput:
278 *input_mode = RnnInputMode::kRnnSkipInput;
279 break;
280 case TFRNNInputMode::kAutoSelect:
281 *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
282 : RnnInputMode::kRnnLinearSkip;
283 break;
284 default:
285 return errors::InvalidArgument("Invalid TF input mode: ",
286 static_cast<int>(tf_input_mode));
287 }
288 return Status::OK();
289 }
290
291 // TODO(zhengxq): Merge those into stream_executor_util.h.
292 template <typename T>
293 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
294 return DeviceMemory<T>::MakeFromByteSize(
295 const_cast<T*>(tensor->template flat<T>().data()),
296 tensor->template flat<T>().size() * sizeof(T));
297 }
298
299 template <typename T>
300 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
301 return DeviceMemory<T>::MakeFromByteSize(
302 tensor->template flat<T>().data(),
303 tensor->template flat<T>().size() * sizeof(T));
304 }
305
306 template <typename U, typename T>
307 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
308 return DeviceMemory<U>::MakeFromByteSize(
309 tensor->template flat<T>().data(),
310 tensor->template flat<T>().size() * sizeof(T));
311 }
312
313 DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
314 int64 offset, int64 size) {
315 const void* base_ptr = device_memory.opaque();
316 void* offset_ptr =
317 const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
318 CHECK(offset + size <= device_memory.size())
319 << "The slice is not within the region of DeviceMemory.";
320 return DeviceMemoryBase(offset_ptr, size);
321 }
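// For example, a single weight region of the opaque params buffer can be
// viewed without copying (variable names are illustrative):
//
//   DeviceMemoryBase region =
//       SliceDeviceMemory(params_buffer, region_desc.offset, region_desc.size);
//
// This is how the canonical <-> opaque conversion kernels below address the
// individual weight and bias regions reported by the RnnDescriptor.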
322
323 inline Status FromExecutorStatus(const se::port::Status& s) {
324 return s.ok() ? Status::OK()
325 : Status(static_cast<error::Code>(static_cast<int>(s.code())),
326 s.error_message());
327 }
328
329 template <typename T>
330 inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
331 return FromExecutorStatus(s.status());
332 }
333
334 inline se::port::Status ToExecutorStatus(const Status& s) {
335 return s.ok() ? se::port::Status::OK()
336 : se::port::Status(static_cast<se::port::error::Code>(
337 static_cast<int>(s.code())),
338 s.error_message());
339 }
340
341 template <typename>
342 struct ToTFDataType;
343
344 template <>
345 struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
346
347 template <>
348 struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
349
350 template <>
351 struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
352
353 template <>
354 struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
355
356 // A helper to allocate temporary scratch memory for Cudnn RNN models. It
357 // takes ownership of the underlying memory. The expectation is that the
358 // memory stays alive for the span of the Cudnn RNN call itself.
359 template <typename T>
360 class CudnnRnnAllocatorInTemp : public ScratchAllocator {
361 public:
362 ~CudnnRnnAllocatorInTemp() override = default;
363
364 explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
365 : context_(context) {}
366 int64 GetMemoryLimitInBytes(Stream* stream) override {
367 return std::numeric_limits<int64>::max();
368 }
369 
370 StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
371 int64 byte_size) override {
372 Tensor temporary_memory;
373 const DataType tf_data_type = ToTFDataType<T>::value;
374 int64 allocate_count =
375 Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
376 Status allocation_status(context_->allocate_temp(
377 tf_data_type, TensorShape({allocate_count}), &temporary_memory));
378 if (!allocation_status.ok()) {
379 return ToExecutorStatus(allocation_status);
380 }
381 // Hold references to the allocated tensors until the end of the
382 // allocator's lifetime.
383 allocated_tensors_.push_back(temporary_memory);
384 total_byte_size_ += byte_size;
385 return DeviceMemory<uint8>::MakeFromByteSize(
386 temporary_memory.template flat<T>().data(),
387 temporary_memory.template flat<T>().size() * sizeof(T));
388 }
389
390 int64 TotalByteSize() const { return total_byte_size_; }
391 
392 Tensor get_allocated_tensor(int index) const {
393 return allocated_tensors_[index];
394 }
395
396 private:
397 int64 total_byte_size_ = 0;
398 OpKernelContext* context_; // not owned
399 std::vector<Tensor> allocated_tensors_;
400 };
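// Illustrative usage, mirroring the forward kernel further below: the
// allocator is handed to the launch helper, which requests its workspace
// through it (arguments abbreviated):
//
//   CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
//   DoForward<T>(..., &workspace_allocator, /*output_profile_result=*/nullptr);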
401
402 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
403 // used by the forward-pass kernel to feed its output to the backward pass.
404 // The memory is expected to remain alive at least until the backward pass
405 // has finished.
406 template <typename T>
407 class CudnnRnnAllocatorInOutput : public ScratchAllocator {
408 public:
409 ~CudnnRnnAllocatorInOutput() override {}
410 CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
411 : context_(context), output_index_(output_index) {}
412 int64 GetMemoryLimitInBytes(Stream* stream) override {
413 return std::numeric_limits<int64>::max();
414 }
415 StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
416 int64 byte_size) override {
417 CHECK(total_byte_size_ == 0)
418 << "Reserve space allocator can only be called once";
419 int64 allocate_count =
420 Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
421
422 Tensor* temporary_memory = nullptr;
423 Status allocation_status(context_->allocate_output(
424 output_index_, TensorShape({allocate_count}), &temporary_memory));
425 if (!allocation_status.ok()) {
426 return ToExecutorStatus(allocation_status);
427 }
428 total_byte_size_ += byte_size;
429 auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
430 temporary_memory->template flat<T>().data(),
431 temporary_memory->template flat<T>().size() * sizeof(T));
432 return StatusOr<DeviceMemory<uint8>>(memory_uint8);
433 }
434 int64 TotalByteSize() { return total_byte_size_; }
435
436 private:
437 int64 total_byte_size_ = 0;
438 OpKernelContext* context_; // not owned
439 int output_index_;
440 };
441
442 // A helper to allocate persistent memory for Cudnn RNN models, which is
443 // expected to live between kernel invocations.
444 // This class is not thread-safe.
445 class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
446 public:
447 explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context)
448 : context_(context) {}
449 
450 ~CudnnRNNPersistentSpaceAllocator() override {}
451 
452 int64 GetMemoryLimitInBytes(Stream* stream) override {
453 return std::numeric_limits<int64>::max();
454 }
455 
456 StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
457 int64 byte_size) override {
458 if (total_byte_size_ != 0) {
459 return Status(error::FAILED_PRECONDITION,
460 "Persistent space allocator can only be called once");
461 }
462
463 Status allocation_status = context_->allocate_persistent(
464 DT_UINT8, TensorShape({byte_size}), &handle_, nullptr);
465 if (!allocation_status.ok()) {
466 return ToExecutorStatus(allocation_status);
467 }
468 total_byte_size_ += byte_size;
469 return AsDeviceMemory<uint8>(handle_.AccessTensor(context_));
470 }
471 int64 TotalByteSize() { return total_byte_size_; }
472
473 private:
474 int64 total_byte_size_ = 0;
475 PersistentTensor handle_;
476 OpKernelContext* context_; // not owned
477 };
478
479 struct CudnnModelTypes {
480 RnnMode rnn_mode;
481 TFRNNInputMode rnn_input_mode;
482 RnnDirectionMode rnn_direction_mode;
483 bool HasInputC() const {
484 // For Cudnn 5.0, only LSTM has input-c. All other models use only
485 // input-h.
486 return rnn_mode == RnnMode::kRnnLstm;
487 }
488
489 string DebugString() const {
490 return strings::Printf(
491 "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
492 static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
493 static_cast<int>(rnn_direction_mode));
494 }
495 };
496
497 // A helper class that collects the shapes that describe an RNN model.
498 struct CudnnRnnModelShapes {
499 int num_layers;
500 int input_size;
501 int num_units;
502 int dir_count;
503 int max_seq_length;
504 int batch_size;
505 TensorShape input_shape;
506 TensorShape output_shape;
507 TensorShape hidden_state_shape;
508 // At present only the fields relevant to the cached RnnDescriptor are compared.
509 bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
510 return num_layers == rhs.num_layers && input_size == rhs.input_size &&
511 num_units == rhs.num_units && dir_count == rhs.dir_count;
512 }
513 string DebugString() const {
514 return strings::Printf(
515 "[num_layers, input_size, num_units, dir_count, max_seq_length, "
516 "batch_size]: [%d, %d, %d, %d, %d, %d] ",
517 num_layers, input_size, num_units, dir_count, max_seq_length,
518 batch_size);
519 }
520 };
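// For reference, in the time-major layout handled below these fields map to
// the following tensor shapes:
//   input:              [max_seq_length, batch_size, input_size]
//   input_h / input_c:  [num_layers * dir_count, batch_size, num_units]
//   output:             [max_seq_length, batch_size, dir_count * num_units]
// The batch-major layout simply swaps the first two dimensions.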
521
522 // Utility class for using a CudnnRnnModelShapes and AlgorithmDesc pair as a
523 // hash table key.
524 struct CudnnRnnConfigHasher {
525 uint64 operator()(
526 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
527 to_hash) const {
528 auto& shapes = to_hash.first;
529 auto& algo_desc = to_hash.second;
530
531 uint64 hash =
532 HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
533 shapes.dir_count, shapes.batch_size});
534 if (algo_desc.has_value()) {
535 hash = Hash64Combine(hash, algo_desc->hash());
536 }
537 return hash;
538 }
539 };
540
541 // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash
542 // table key.
543 struct CudnnRnnConfigComparator {
544 bool operator()(
545 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs,
546 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs)
547 const {
548 return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
549 }
550 };
551
552 // Pointers to RNN scratch space for a specific set of shape parameters (used as
553 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
554 struct RnnScratchSpace {
555 std::unique_ptr<RnnDescriptor> rnn_desc;
556 std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
557 };
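// Illustrative relationship, mirroring GetCachedRnnDescriptor below: each
// cache entry owns both the descriptor and the persistent dropout state it
// was created with, so the two are kept and dropped together.
//
//   RnnScratchSpace& rnn_state = (*cache)[key];
//   rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
//   CreateRnnDescriptor<T>(context, model_shapes, input_mode, algo_config,
//                          dropout_state_allocator, &rnn_state.rnn_desc);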
558
559 // Extracts and checks the forward input tensors, parameters, and shapes from
560 // the OpKernelContext.
561 Status ExtractForwardInput(OpKernelContext* context,
562 const CudnnModelTypes& model_types, bool time_major,
563 const Tensor** input, const Tensor** input_h,
564 const Tensor** input_c, const Tensor** params,
565 CudnnRnnModelShapes* model_shapes) {
566 TF_RETURN_IF_ERROR(context->input("input", input));
567 TF_RETURN_IF_ERROR(context->input("input_h", input_h));
568 if (model_types.HasInputC()) {
569 TF_RETURN_IF_ERROR(context->input("input_c", input_c));
570 }
571 TF_RETURN_IF_ERROR(context->input("params", params));
572
573 if ((*input)->dims() != 3) {
574 return errors::InvalidArgument("RNN input must be a 3-D vector.");
575 }
576 if (time_major) {
577 model_shapes->max_seq_length = (*input)->dim_size(0);
578 model_shapes->batch_size = (*input)->dim_size(1);
579 } else {
580 model_shapes->max_seq_length = (*input)->dim_size(1);
581 model_shapes->batch_size = (*input)->dim_size(0);
582 }
583 model_shapes->input_size = (*input)->dim_size(2);
584 model_shapes->input_shape = (*input)->shape();
585 model_shapes->dir_count =
586 (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
587 ? 2
588 : 1;
589
590 if ((*input_h)->dims() != 3) {
591 return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
592 }
593 if (time_major) {
594 model_shapes->num_layers =
595 (*input_h)->dim_size(0) / model_shapes->dir_count;
596 } else {
597 model_shapes->num_layers =
598 (*input_h)->dim_size(1) / model_shapes->dir_count;
599 }
600 model_shapes->num_units = (*input_h)->dim_size(2);
601
602 if (time_major) {
603 model_shapes->hidden_state_shape =
604 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
605 model_shapes->batch_size, model_shapes->num_units});
606 } else {
607 model_shapes->hidden_state_shape =
608 TensorShape({model_shapes->batch_size,
609 model_shapes->dir_count * model_shapes->num_layers,
610 model_shapes->num_units});
611 }
612 if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
613 return errors::InvalidArgument(
614 "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
615 model_shapes->hidden_state_shape.DebugString());
616 }
617 if (model_types.HasInputC()) {
618 if ((*input_h)->shape() != (*input_c)->shape()) {
619 return errors::InvalidArgument(
620 "input_h and input_c must have the same shape: ",
621 (*input_h)->shape().DebugString(), " ",
622 (*input_c)->shape().DebugString());
623 }
624 }
625 if (time_major) {
626 model_shapes->output_shape =
627 TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
628 model_shapes->dir_count * model_shapes->num_units});
629 } else {
630 model_shapes->output_shape =
631 TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
632 model_shapes->dir_count * model_shapes->num_units});
633 }
634 return Status::OK();
635 }
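// Worked example (time-major): a 2-layer unidirectional LSTM run for 20 steps
// over a batch of 32 with input_size 128 and num_units 256 arrives as
//   input:   [20, 32, 128] -> max_seq_length=20, batch_size=32, input_size=128
//   input_h: [2, 32, 256]  -> num_layers = 2 / dir_count(=1) = 2, num_units=256
// and yields output_shape [20, 32, 1 * 256].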
636
637 // Overload that additionally extracts the sequence_lengths input.
638 Status ExtractForwardInput(OpKernelContext* context,
639 const CudnnModelTypes& model_types, bool time_major,
640 const Tensor** input, const Tensor** input_h,
641 const Tensor** input_c, const Tensor** params,
642 const Tensor** sequence_lengths,
643 CudnnRnnModelShapes* model_shapes) {
644 TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
645 return ExtractForwardInput(context, model_types, time_major, input, input_h,
646 input_c, params, model_shapes);
647 }
648
649 template <typename T>
650 Status CreateForwardAndBackwardIODescriptors(
651 OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
652 std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
653 std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
654 std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
655 const absl::Span<const int>& seq_lengths, bool time_major) {
656 StreamExecutor* executor = context->op_device_context()->stream()->parent();
657 se::dnn::DataType data_type = ToDataType<T>::value;
658
659 const TensorShape& input_shape = model_shapes.input_shape;
660 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
661 const TensorShape& output_shape = model_shapes.output_shape;
662
663 DCHECK_EQ(input_shape.dims(), 3);
664 if (seq_lengths.data() != nullptr) {
665 if (time_major) {
666 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
667 input_shape.dim_size(0), input_shape.dim_size(1),
668 input_shape.dim_size(2), seq_lengths, time_major, data_type);
669 TF_RETURN_IF_ERROR(input_desc_s.status());
670 *input_desc = input_desc_s.ConsumeValueOrDie();
671 } else {
672 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
673 input_shape.dim_size(1), input_shape.dim_size(0),
674 input_shape.dim_size(2), seq_lengths, time_major, data_type);
675 TF_RETURN_IF_ERROR(input_desc_s.status());
676 *input_desc = input_desc_s.ConsumeValueOrDie();
677 }
678 } else {
679 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
680 input_shape.dim_size(0), input_shape.dim_size(1),
681 input_shape.dim_size(2), data_type);
682 TF_RETURN_IF_ERROR(input_desc_s.status());
683 *input_desc = input_desc_s.ConsumeValueOrDie();
684 }
685
686 DCHECK_EQ(hidden_state_shape.dims(), 3);
687 if (time_major) {
688 auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
689 hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
690 hidden_state_shape.dim_size(2), data_type);
691 TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
692 *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
693 } else {
694 auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
695 hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
696 hidden_state_shape.dim_size(2), data_type);
697 TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
698 *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
699 }
700
701 DCHECK_EQ(output_shape.dims(), 3);
702 if (seq_lengths.data() != nullptr) {
703 if (time_major) {
704 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
705 output_shape.dim_size(0), output_shape.dim_size(1),
706 output_shape.dim_size(2), seq_lengths, time_major, data_type);
707 TF_RETURN_IF_ERROR(output_desc_s.status());
708 *output_desc = output_desc_s.ConsumeValueOrDie();
709 } else {
710 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
711 output_shape.dim_size(1), output_shape.dim_size(0),
712 output_shape.dim_size(2), seq_lengths, time_major, data_type);
713 TF_RETURN_IF_ERROR(output_desc_s.status());
714 *output_desc = output_desc_s.ConsumeValueOrDie();
715 }
716 } else {
717 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
718 output_shape.dim_size(0), output_shape.dim_size(1),
719 output_shape.dim_size(2), data_type);
720 TF_RETURN_IF_ERROR(output_desc_s.status());
721 *output_desc = output_desc_s.ConsumeValueOrDie();
722 }
723
724 return Status::OK();
725 }
726
727 template <typename T>
728 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
729 const CudnnModelTypes& model_types,
730 const CudnnRnnModelShapes& model_shapes,
731 /* forward inputs */
732 const Tensor* input, const Tensor* input_h,
733 const Tensor* input_c, const Tensor* params,
734 const bool is_training,
735 /* forward outputs, outputs of the function */
736 Tensor* output, Tensor* output_h, Tensor* output_c,
737 const Tensor* sequence_lengths, bool time_major,
738 ScratchAllocator* reserve_space_allocator,
739 ScratchAllocator* workspace_allocator,
740 ProfileResult* output_profile_result) {
741 std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
742 std::unique_ptr<RnnStateTensorDescriptor> state_desc;
743 std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
744
745 absl::Span<const int> seq_lengths;
746 if (sequence_lengths != nullptr) {
747 seq_lengths = absl::Span<const int>(
748 sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
749 }
750 TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
751 context, model_shapes, &input_desc, &state_desc, &output_desc,
752 seq_lengths, time_major));
753
754 auto input_data = AsDeviceMemory<T>(input);
755 auto input_h_data = AsDeviceMemory<T>(input_h);
756 DeviceMemory<T> input_c_data;
757 if (model_types.HasInputC()) {
758 input_c_data = AsDeviceMemory<T>(input_c);
759 }
760
761 auto params_data = AsDeviceMemory<T>(params);
762 auto output_data = AsDeviceMemory<T>(output);
763 auto output_h_data = AsDeviceMemory<T>(output_h);
764 DeviceMemory<T> output_c_data;
765 if (model_types.HasInputC()) {
766 output_c_data = AsDeviceMemory<T>(output_c);
767 }
768
769 Stream* stream = context->op_device_context()->stream();
770 bool launch_success =
771 stream
772 ->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc,
773 input_h_data, *state_desc, input_c_data, params_data,
774 *output_desc, &output_data, *state_desc,
775 &output_h_data, *state_desc, &output_c_data,
776 is_training, reserve_space_allocator,
777 workspace_allocator, output_profile_result)
778 .ok();
779 return launch_success
780 ? Status::OK()
781 : errors::Internal(
782 "Failed to call ThenRnnForward with model config: ",
783 model_types.DebugString(), ", ", model_shapes.DebugString());
784 }
785
786 template <typename T>
787 Status DoBackward(
788 OpKernelContext* context, const RnnDescriptor& rnn_desc,
789 const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
790 /* forward inputs */
791 const Tensor* input, const Tensor* input_h, const Tensor* input_c,
792 const Tensor* params,
793 /* forward outputs */
794 const Tensor* output, const Tensor* output_h, const Tensor* output_c,
795 /* backprop inputs */
796 const Tensor* output_backprop, const Tensor* output_h_backprop,
797 const Tensor* output_c_backprop, const Tensor* reserve_space,
798 /* backprop outputs, output of the function */
799 Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
800 Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
801 ScratchAllocator* workspace_allocator,
802 ProfileResult* output_profile_result) {
803 std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
804 std::unique_ptr<RnnStateTensorDescriptor> state_desc;
805 std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
806
807 absl::Span<const int> seq_lengths;
808 if (sequence_lengths != nullptr) {
809 seq_lengths = absl::Span<const int>(
810 sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
811 }
812 TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
813 context, model_shapes, &input_desc, &state_desc, &output_desc,
814 seq_lengths, time_major));
815
816 auto input_data = AsDeviceMemory<T>(input);
817 auto input_h_data = AsDeviceMemory<T>(input_h);
818 DeviceMemory<T> input_c_data;
819 if (model_types.HasInputC()) {
820 input_c_data = AsDeviceMemory<T>(input_c);
821 }
822 auto params_data = AsDeviceMemory<T>(params);
823 auto output_data = AsDeviceMemory<T>(output);
824 auto output_h_data = AsDeviceMemory<T>(output_h);
825 DeviceMemory<T> output_c_data;
826 if (model_types.HasInputC()) {
827 output_c_data = AsDeviceMemory<T>(output_c);
828 }
829 auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
830 auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
831 DeviceMemory<T> output_c_backprop_data;
832 if (model_types.HasInputC()) {
833 output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
834 }
835 auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
836 auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
837 DeviceMemory<T> input_c_backprop_data;
838 if (model_types.HasInputC()) {
839 input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
840 }
841 auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
842 auto reserve_space_uint8 =
843 CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
844
845 // Creates a memory callback for the workspace. The memory lives until the
846 // end of this kernel call.
847 Stream* stream = context->op_device_context()->stream();
848 bool launch_success =
849 stream
850 ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc,
851 input_h_data, *state_desc, input_c_data,
852 params_data, *output_desc, output_data, *state_desc,
853 output_h_data, *state_desc, output_c_data,
854 output_backprop_data, output_h_backprop_data,
855 output_c_backprop_data, &input_backprop_data,
856 &input_h_backprop_data, &input_c_backprop_data,
857 ¶ms_backprop_data, &reserve_space_uint8,
858 workspace_allocator, output_profile_result)
859 .ok();
860 return launch_success
861 ? Status::OK()
862 : errors::Internal(
863 "Failed to call ThenRnnBackward with model config: ",
864 model_types.DebugString(), ", ", model_shapes.DebugString());
865 }
866
867 template <typename T>
868 void RestoreParams(const OpInputList params_input,
869 const std::vector<RnnDescriptor::ParamsRegion>& params,
870 DeviceMemoryBase* data_dst, Stream* stream) {
871 int num_params = params.size();
872 CHECK(params_input.size() == num_params)
873 << "Number of params mismatch. Expected " << params_input.size()
874 << ", got " << num_params;
875 for (int i = 0; i < params.size(); i++) {
876 int64 size_in_bytes = params[i].size;
877 int64 size = size_in_bytes / sizeof(T);
878 CHECK(size == params_input[i].NumElements())
879 << "Params size mismatch. Expected " << size << ", got "
880 << params_input[i].NumElements();
881 auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
882 DeviceMemoryBase data_dst_ptr =
883 SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
884 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
885 }
886 }
887
888 } // namespace
889
890 // Note: all of the following kernels depend on an RnnDescriptor instance,
891 // which according to the official Cudnn documentation should be kept around
892 // and reused across all Cudnn kernels in the same model.
893 // In TensorFlow, we don't pass the reference across different OpKernels;
894 // rather, we recreate it separately in each OpKernel, which does not cause
895 // issues: CudnnDropoutDescriptor keeps a reference to memory holding the
896 // random number generator state. During recreation this state is lost, but
897 // only forward-pass Cudnn APIs make use of the state.
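// The descriptor (and hence the dropout state) is normally cached per kernel
// instance (see RnnStateCache below); setting the environment variable
// TF_CUDNN_RESET_RND_GEN_STATE=true forces it to be recreated on every
// Compute() call instead.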
898
899 // A common base class for RNN kernels. It extracts common attributes and
900 // performs shape validations.
901 class CudnnRNNKernelCommon : public OpKernel {
902 protected:
903 explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
904 : OpKernel(context) {
905 OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
906 OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
907 OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
908 string str;
909 OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
910 OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
911 OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
912 OP_REQUIRES_OK(context,
913 ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
914 OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
915 OP_REQUIRES_OK(
916 context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
917 // Reset the CudnnRnnDescriptor and related random number generator state in
918 // every Compute() call.
919 OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
920 false, &reset_rnd_gen_state_));
921 }
922
923 bool HasInputC() const { return model_types_.HasInputC(); }
924 RnnMode rnn_mode() const { return model_types_.rnn_mode; }
925 TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
926 RnnDirectionMode rnn_direction_mode() const {
927 return model_types_.rnn_direction_mode;
928 }
929 const CudnnModelTypes& model_types() const { return model_types_; }
930 float dropout() const { return dropout_; }
931 uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
932 bool ResetRndGenState() { return reset_rnd_gen_state_; }
933
934 template <typename T>
935 Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
936 std::unique_ptr<RnnDescriptor>* rnn_desc) {
937 const Tensor* num_layers_t = nullptr;
938 TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
939 if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
940 return errors::InvalidArgument("num_layers is not a scalar");
941 }
942 int num_layers = num_layers_t->scalar<int>()();
943 const Tensor* num_units_t = nullptr;
944 TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
945 if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
946 return errors::InvalidArgument("num_units is not a scalar");
947 }
948 int num_units = num_units_t->scalar<int>()();
949 const Tensor* input_size_t = nullptr;
950 TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
951 if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
952 return errors::InvalidArgument("input_size is not a scalar");
953 }
954 int input_size = input_size_t->scalar<int>()();
955
956 RnnInputMode input_mode;
957 TF_RETURN_IF_ERROR(
958 ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
959
960 Stream* stream = context->op_device_context()->stream();
961 // ExtractCudnnRNNParamsInfo is only called by op_kernels that do not require
962 // a random number generator, therefore set state_allocator to nullptr.
963 const AlgorithmConfig algo_config;
964 auto rnn_desc_s = stream->parent()->createRnnDescriptor(
965 num_layers, num_units, input_size, /*batch_size=*/0, input_mode,
966 rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config,
967 dropout(), seed(), /* state_allocator=*/nullptr);
968 if (!rnn_desc_s.ok()) {
969 return FromExecutorStatus(rnn_desc_s);
970 }
971 *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
972 return Status::OK();
973 }
974
975 template <typename T>
976 Status CreateRnnDescriptor(OpKernelContext* context,
977 const CudnnRnnModelShapes& model_shapes,
978 const RnnInputMode& input_mode,
979 const AlgorithmConfig& algo_config,
980 ScratchAllocator* dropout_state_allocator,
981 std::unique_ptr<RnnDescriptor>* rnn_desc) {
982 StreamExecutor* executor = context->op_device_context()->stream()->parent();
983 se::dnn::DataType data_type = ToDataType<T>::value;
984 auto rnn_desc_s = executor->createRnnDescriptor(
985 model_shapes.num_layers, model_shapes.num_units,
986 model_shapes.input_size, model_shapes.batch_size, input_mode,
987 rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(),
988 seed(), dropout_state_allocator);
989 TF_RETURN_IF_ERROR(rnn_desc_s.status());
990
991 *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
992 return Status::OK();
993 }
994
995 using RnnStateCache = gtl::FlatMap<
996 std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
997 RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
998 // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
999 // should outlive the returned pointer.
1000 template <typename T>
1001 Status GetCachedRnnDescriptor(OpKernelContext* context,
1002 const CudnnRnnModelShapes& model_shapes,
1003 const RnnInputMode& input_mode,
1004 const AlgorithmConfig& algo_config,
1005 RnnStateCache* cache,
1006 RnnDescriptor** rnn_desc) {
1007 auto key = std::make_pair(model_shapes, algo_config.algorithm());
1008 RnnScratchSpace& rnn_state = (*cache)[key];
1009 if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
1010 CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
1011 new CudnnRNNPersistentSpaceAllocator(context);
1012 rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
1013 Status status =
1014 CreateRnnDescriptor<T>(context, model_shapes, input_mode, algo_config,
1015 dropout_state_allocator, &rnn_state.rnn_desc);
1016 TF_RETURN_IF_ERROR(status);
1017 }
1018 *rnn_desc = rnn_state.rnn_desc.get();
1019 return Status::OK();
1020 }
1021
1022 private:
1023 int seed_;
1024 int seed2_;
1025 float dropout_;
1026 bool reset_rnd_gen_state_;
1027
1028 CudnnModelTypes model_types_;
1029 };
1030
1031 // A class that returns the size of the opaque parameter buffer. The user
1032 // should use it to create the actual parameter buffer for training. However,
1033 // it should not be used for saving and restoring.
1034 template <typename T, typename Index>
1035 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
1036 public:
1037 explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
1038 : CudnnRNNKernelCommon(context) {}
1039
1040 void Compute(OpKernelContext* context) override {
1041 std::unique_ptr<RnnDescriptor> rnn_desc;
1042 OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
1043 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1044 CHECK(params_size_in_bytes % sizeof(T) == 0)
1045 << "params_size_in_bytes must be multiple of element size";
1046 int64 params_size = params_size_in_bytes / sizeof(T);
1047
1048 Tensor* output_t = nullptr;
1049 OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
1050 *output_t->template flat<Index>().data() = params_size;
1051 }
1052 };
1053
1054 #define REGISTER_GPU(T) \
1055 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \
1056 .Device(DEVICE_GPU) \
1057 .HostMemory("num_layers") \
1058 .HostMemory("num_units") \
1059 .HostMemory("input_size") \
1060 .HostMemory("params_size") \
1061 .TypeConstraint<T>("T") \
1062 .TypeConstraint<int32>("S"), \
1063 CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
1064
1065 TF_CALL_half(REGISTER_GPU);
1066 TF_CALL_float(REGISTER_GPU);
1067 TF_CALL_double(REGISTER_GPU);
1068 #undef REGISTER_GPU
1069
1070 // Convert weight and bias params from a platform-specific layout to the
1071 // canonical form.
1072 template <typename T>
1073 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
1074 public:
1075 explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
1076 : CudnnRNNKernelCommon(context) {
1077 OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
1078 }
1079
1080 void Compute(OpKernelContext* context) override {
1081 const Tensor& input = context->input(3);
1082 auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
1083 Stream* stream = context->op_device_context()->stream();
1084
1085 std::unique_ptr<RnnDescriptor> rnn_desc;
1086 OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
1087 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1088 CHECK(params_size_in_bytes % sizeof(T) == 0)
1089 << "params_size_in_bytes must be multiple of element size";
1090
1091 const Tensor* num_units_t = nullptr;
1092 OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
1093 CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
1094 << "num_units is not a scalar";
1095 int num_units = num_units_t->scalar<int>()();
1096
1097 const Tensor* input_size_t = nullptr;
1098 OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
1099 CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
1100 << "input_size is not a scalar";
1101 int input_size = input_size_t->scalar<int>()();
1102
1103 const Tensor* num_layers_t = nullptr;
1104 OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
1105 CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
1106 << "num_layers is not a scalar";
1107 int num_layers = num_layers_t->scalar<int>()();
1108 int num_dirs = 1;
1109 if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
1110 num_dirs = 2;
1111 }
1112 const int num_params_per_layer = num_params_ / num_layers / num_dirs;
1113 // Number of params applied to the inputs. The rest are applied to the
1114 // recurrent hidden states.
1115 const int num_params_input_state = num_params_per_layer / 2;
1116 CHECK(num_params_ % (num_layers * num_dirs) == 0)
1117 << "Number of params is not a multiple of num_layers * num_dirs.";
1118 CHECK(num_params_per_layer % 2 == 0)
1119 << "Number of params per layer is not an even number.";
1120
1121 CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
1122 << "Number of params mismatch. Expected " << num_params_ << ", got "
1123 << rnn_desc->ParamsWeightRegions().size();
1124 for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
1125 int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
1126 int64 size = size_in_bytes / sizeof(T);
1127 const int layer_idx = i / num_params_per_layer;
1128 const int index_within_layer = i % num_params_per_layer;
1129 int width = 0, height = num_units;
1130 // In the CuDNN layout, each layer has num_params_per_layer params, with the
1131 // first half (i.e. num_params_input_state params) applied to the inputs,
1132 // and the second half to the recurrent hidden states.
1133 bool apply_on_input_state = index_within_layer < num_params_input_state;
1134 if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
1135 if (layer_idx == 0 && apply_on_input_state) {
1136 width = input_size;
1137 } else {
1138 width = num_units;
1139 }
1140 } else {
1141 if (apply_on_input_state) {
1142 if (layer_idx <= 1) {
1143 // First forward or backward layer.
1144 width = input_size;
1145 } else {
1146 // In the following layers, the cell input is the concatenated output of
1147 // the prior layer.
1148 width = 2 * num_units;
1149 }
1150 } else {
1151 width = num_units;
1152 }
1153 }
1154 CHECK(size == width * height) << "Params size mismatch. Expected "
1155 << width * height << ", got " << size;
1156 Tensor* output = nullptr;
1157 OP_REQUIRES_OK(context, context->allocate_output(
1158 i, TensorShape({height, width}), &output));
1159 DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1160 input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
1161 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1162 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1163 }
1164
1165 OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
1166 errors::InvalidArgument("Number of params mismatch. Expected ",
1167 num_params_, ", got ",
1168 rnn_desc->ParamsBiasRegions().size()));
1169 for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
1170 int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
1171 int64 size = size_in_bytes / sizeof(T);
1172 OP_REQUIRES(context, size == num_units,
1173 errors::InvalidArgument("Params size mismatch. Expected ",
1174 num_units, ", got ", size));
1175
1176 Tensor* output = nullptr;
1177 OP_REQUIRES_OK(context,
1178 context->allocate_output(num_params_ + i,
1179 TensorShape({size}), &output));
1180 DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1181 input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
1182 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1183 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1184 }
1185 }
1186
1187 private:
1188 int num_params_;
1189 };
1190
1191 #define REGISTER_GPU(T) \
1192 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
1193 .Device(DEVICE_GPU) \
1194 .HostMemory("num_layers") \
1195 .HostMemory("num_units") \
1196 .HostMemory("input_size") \
1197 .TypeConstraint<T>("T"), \
1198 CudnnRNNParamsToCanonical<GPUDevice, T>);
1199 TF_CALL_half(REGISTER_GPU);
1200 TF_CALL_float(REGISTER_GPU);
1201 TF_CALL_double(REGISTER_GPU);
1202 #undef REGISTER_GPU
1203
1204 // Convert weight and bias params from the canonical form to a
1205 // platform-specific layout.
1206 template <typename T>
1207 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
1208 public:
1209 explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
1210 : CudnnRNNKernelCommon(context) {}
1211
1212 void Compute(OpKernelContext* context) override {
1213 std::unique_ptr<RnnDescriptor> rnn_desc;
1214 OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
1215 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1216 CHECK(params_size_in_bytes % sizeof(T) == 0)
1217 << "params_size_in_bytes must be multiple of element size";
1218 Tensor* output = nullptr;
1219 int params_size = params_size_in_bytes / sizeof(T);
1220 OP_REQUIRES_OK(context,
1221 context->allocate_output(0, {params_size}, &output));
1222 auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1223 Stream* stream = context->op_device_context()->stream();
1224
1225 OpInputList weights;
1226 OP_REQUIRES_OK(context, context->input_list("weights", &weights));
1227 RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
1228 stream);
1229
1230 OpInputList biases;
1231 OP_REQUIRES_OK(context, context->input_list("biases", &biases));
1232 RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
1233 stream);
1234 }
1235 };
1236
1237 #define REGISTER_GPU(T) \
1238 REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
1239 .Device(DEVICE_GPU) \
1240 .HostMemory("num_layers") \
1241 .HostMemory("num_units") \
1242 .HostMemory("input_size") \
1243 .TypeConstraint<T>("T"), \
1244 CudnnRNNCanonicalToParams<GPUDevice, T>);
1245 TF_CALL_half(REGISTER_GPU);
1246 TF_CALL_float(REGISTER_GPU);
1247 TF_CALL_double(REGISTER_GPU);
1248 #undef REGISTER_GPU
1249
1250 // Run the forward operation of the RNN model.
1251 template <typename T>
1252 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1253 public:
1254 explicit CudnnRNNForwardOp(OpKernelConstruction* context)
1255 : CudnnRNNKernelCommon(context) {
1256 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1257
1258 // Read debug env variables.
1259 is_debug_mode_ = DebugCudnnRnn();
1260 debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
1261 debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
1262 }
1263
1264 void Compute(OpKernelContext* context) override {
1265 AlgorithmConfig algo_config;
1266 ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
1267 /*time_major=*/true);
1268 }
1269
1270 protected:
1271 virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
1272 AlgorithmConfig* output_algo_config,
1273 bool var_seq_lengths,
1274 bool time_major) {
1275 CHECK_NE(output_algo_config, nullptr);
1276
1277 const Tensor* input = nullptr;
1278 const Tensor* input_h = nullptr;
1279 const Tensor* input_c = nullptr;
1280 const Tensor* params = nullptr;
1281 const Tensor* sequence_lengths = nullptr;
1282 CudnnRnnModelShapes model_shapes;
1283 if (var_seq_lengths) {
1284 OP_REQUIRES_OK(context,
1285 ExtractForwardInput(context, model_types(), time_major,
1286 &input, &input_h, &input_c, &params,
1287 &sequence_lengths, &model_shapes));
1288 } else {
1289 OP_REQUIRES_OK(context, ExtractForwardInput(
1290 context, model_types(), time_major, &input,
1291 &input_h, &input_c, &params, &model_shapes));
1292 }
1293 RnnInputMode input_mode;
1294 OP_REQUIRES_OK(context,
1295 ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1296 model_shapes.input_size, &input_mode));
1297
1298 Tensor* output = nullptr;
1299 Tensor* output_h = nullptr;
1300 Tensor* output_c = nullptr;
1301 OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
1302 &output_h, &output_c));
1303
1304 // Creates a memory callback for the reserve_space. The memory lives in the
1305 // output of this kernel, and it will be fed into the backward pass when
1306 // needed.
1307 CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
1308 // Creates a memory callback for the workspace. The memory lives until the
1309 // end of this kernel call.
1310 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1311
1312 if (is_debug_mode_) {
1313 AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
1314 output_algo_config->set_algorithm(algo_desc);
1315 } else {
1316 OP_REQUIRES_OK(context,
1317 MaybeAutoTune(context, model_shapes, input_mode, input,
1318 input_h, input_c, params, output, output_h,
1319 output_c, output_algo_config));
1320 }
1321
1322 Status launch_status;
1323 {
1324 mutex_lock l(mu_);
1325 RnnDescriptor* rnn_desc_ptr = nullptr;
1326 OP_REQUIRES_OK(
1327 context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
1328 *output_algo_config,
1329 &rnn_state_cache_, &rnn_desc_ptr));
1330 launch_status = DoForward<T>(
1331 context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1332 input_c, params, is_training_, output, output_h, output_c,
1333 sequence_lengths, time_major, &reserve_space_allocator,
1334 &workspace_allocator, /*output_profile_result=*/nullptr);
1335 }
1336 OP_REQUIRES_OK(context, launch_status);
1337 }
1338
1339 protected:
1340 virtual Status MaybeAutoTune(OpKernelContext* context,
1341 const CudnnRnnModelShapes& model_shapes,
1342 const RnnInputMode& input_mode,
1343 const Tensor* input, const Tensor* input_h,
1344 const Tensor* input_c, const Tensor* params,
1345 Tensor* output, Tensor* output_h,
1346 Tensor* output_c,
1347 AlgorithmConfig* best_algo_config) {
1348 CHECK_NE(best_algo_config, nullptr);
1349 *best_algo_config = AlgorithmConfig();
1350 return Status::OK();
1351 }
1352
1353 bool is_training() const { return is_training_; }
1354 bool is_debug_mode_;
1355 bool debug_use_tensor_ops_;
1356 int64 debug_cudnn_rnn_algo_;
1357
1358 private:
1359 Status AllocateOutputs(OpKernelContext* context,
1360 const CudnnRnnModelShapes& model_shapes,
1361 Tensor** output, Tensor** output_h,
1362 Tensor** output_c) {
1363 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1364 const TensorShape& output_shape = model_shapes.output_shape;
1365
1366 TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
1367 TF_RETURN_IF_ERROR(
1368 context->allocate_output(1, hidden_state_shape, output_h));
1369 if (HasInputC()) {
1370 TF_RETURN_IF_ERROR(
1371 context->allocate_output(2, hidden_state_shape, output_c));
1372 } else {
1373 // Only LSTM uses input_c and output_c. So for all other models, we only
1374 // need to create dummy outputs.
1375 TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
1376 }
1377 if (!is_training_) {
1378 Tensor* dummy_reserve_space = nullptr;
1379 TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
1380 }
1381 return Status::OK();
1382 }
1383
1384 mutex mu_;
1385 bool is_training_;
1386 RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
1387 };
1388
1389 #define REGISTER_GPU(T) \
1390 REGISTER_KERNEL_BUILDER( \
1391 Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1392 CudnnRNNForwardOp<GPUDevice, T>);
1393
1394 TF_CALL_half(REGISTER_GPU);
1395 TF_CALL_float(REGISTER_GPU);
1396 TF_CALL_double(REGISTER_GPU);
1397 #undef REGISTER_GPU
1398
1399 template <typename T>
1400 class CudnnRNNForwardOpV2<GPUDevice, T>
1401 : public CudnnRNNForwardOp<GPUDevice, T> {
1402 private:
1403 using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1404 using CudnnRNNKernelCommon::CreateRnnDescriptor;
1405 using CudnnRNNKernelCommon::dropout;
1406 using CudnnRNNKernelCommon::HasInputC;
1407 using CudnnRNNKernelCommon::model_types;
1408
1409 public:
1410 explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
1411 : CudnnRNNForwardOp<GPUDevice, T>(context) {}
1412
1413 void Compute(OpKernelContext* context) override {
1414 AlgorithmConfig best_algo_config;
1415 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1416 context, &best_algo_config, /*var_seq_lengths=*/false,
1417 /*time_major=*/true);
1418 if (!context->status().ok()) {
1419 return;
1420 }
1421
1422 Tensor* output_host_reserved = nullptr;
1423 // output_host_reserved stores opaque info used for backprop when running
1424 // in training mode. At present it holds a serialization of the best
1425 // AlgorithmDesc picked during the RNN forward-pass autotune:
1426 //   int8 algorithm_id
1427 //   int8 use_tensor_op
1428 // If autotune is not enabled, algorithm_id is
1429 // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false.
1430 // In inference mode, output_host_reserved is currently left
1431 // unpopulated.
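// Illustrative sketch, not part of the original code: a consumer such as
// CudnnRNNBackwardOpV2::GetAlgorithm() below can rebuild the AlgorithmConfig
// from these two int8 values, roughly as
//
//   auto host_reserved_int8 = host_reserved->vec<int8>();
//   AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
//   algo_config->set_algorithm(algo_desc);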
1432 if (is_training()) {
1433 OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
1434 &output_host_reserved));
1435 auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
1436 output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
1437 output_host_reserved_int8(1) =
1438 best_algo_config.algorithm()->tensor_ops_enabled();
1439 } else {
1440 OP_REQUIRES_OK(context,
1441 context->allocate_output(4, {}, &output_host_reserved));
1442 }
1443 }
1444
1445 protected:
1446 Status MaybeAutoTune(OpKernelContext* context,
1447 const CudnnRnnModelShapes& model_shapes,
1448 const RnnInputMode& input_mode, const Tensor* input,
1449 const Tensor* input_h, const Tensor* input_c,
1450 const Tensor* params, Tensor* output, Tensor* output_h,
1451 Tensor* output_c,
1452 AlgorithmConfig* algo_config) override {
1453 CHECK_NE(algo_config, nullptr);
1454 if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
1455 *algo_config = AlgorithmConfig();
1456 return Status::OK();
1457 }
1458
1459 std::vector<AlgorithmDesc> algorithms;
1460 auto* stream = context->op_device_context()->stream();
1461 CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
1462 if (algorithms.empty()) {
1463 LOG(WARNING) << "No Rnn algorithm found";
1464 return Status::OK();
1465 }
1466
1467 const auto& modeltypes = model_types();
1468 CudnnRnnParameters rnn_params(
1469 model_shapes.num_layers, model_shapes.input_size,
1470 model_shapes.num_units, model_shapes.max_seq_length,
1471 model_shapes.batch_size, model_shapes.dir_count,
1472 /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
1473 modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
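// Note added for clarity (not in the original source): rnn_params fingerprints
// the model configuration (layer/unit/input sizes, sequence length, batch
// size, direction count, dropout, training mode, RNN mode, and dtype).
// AutoTuneRnnConfigMap is keyed on it, so the profiling loop below runs at
// most once per distinct configuration; later calls hit the cached result
// inserted at the end of this function.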
1474
1475 if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
1476 VLOG(1) << "Using existing best Cudnn RNN algorithm "
1477 << "(algo, tensor_op_enabled) = ("
1478 << algo_config->algorithm()->algo_id() << ", "
1479 << algo_config->algorithm()->tensor_ops_enabled() << ").";
1480 return Status::OK();
1481 }
1482
1483 // Create temp tensors used when profiling the backprop pass.
1484 auto data_type = input->dtype();
1485 Tensor output_backprop;
1486 Tensor output_h_backprop;
1487 Tensor output_c_backprop;
1488 Tensor input_backprop;
1489 Tensor input_h_backprop;
1490 Tensor input_c_backprop;
1491 Tensor params_backprop;
1492 if (is_training()) {
1493 TF_RETURN_IF_ERROR(context->allocate_temp(
1494 data_type, model_shapes.output_shape, &output_backprop));
1495 TF_RETURN_IF_ERROR(context->allocate_temp(
1496 data_type, model_shapes.hidden_state_shape, &output_h_backprop));
1497
1498 TF_RETURN_IF_ERROR(
1499 context->allocate_temp(data_type, params->shape(), ¶ms_backprop));
1500 TF_RETURN_IF_ERROR(context->allocate_temp(
1501 data_type, model_shapes.input_shape, &input_backprop));
1502 TF_RETURN_IF_ERROR(context->allocate_temp(
1503 data_type, model_shapes.hidden_state_shape, &input_h_backprop));
1504 if (HasInputC()) {
1505 TF_RETURN_IF_ERROR(context->allocate_temp(
1506 data_type, model_shapes.hidden_state_shape, &output_c_backprop));
1507 TF_RETURN_IF_ERROR(context->allocate_temp(
1508 data_type, model_shapes.hidden_state_shape, &input_c_backprop));
1509 }
1510 }
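// Note added for clarity (not in the original source): in training mode each
// candidate algorithm is scored on the combined forward + backward time, so
// the temporaries above mirror every gradient argument DoBackward<T> expects;
// in inference mode only the forward time is measured.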
1511 ProfileResult best_result;
1512 for (auto& algo : algorithms) {
1513 VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1514 << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
1515 Status status;
1516 ProfileResult final_profile_result;
1517
1518 ProfileResult fwd_profile_result;
1519 ProfileResult bak_profile_result;
1520
1521 // RnnDescriptor is algorithm-dependent, thus not reusable.
1522 std::unique_ptr<RnnDescriptor> rnn_desc;
1523 // Use a temp scratch allocator for the random number generator.
1524 CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
1525 if (!this->template CreateRnnDescriptor<T>(
1526 context, model_shapes, input_mode, AlgorithmConfig(algo),
1527 &dropout_state_allocator, &rnn_desc)
1528 .ok()) {
1529 continue;
1530 }
1531
1532 // Again use temp scratch allocator during profiling.
1533 CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
1534 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1535 status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
1536 input, input_h, input_c, params, is_training(),
1537 output, output_h, output_c, /*sequence_lengths=*/nullptr,
1538 /*time_major=*/true, &reserve_space_allocator,
1539 &workspace_allocator, &fwd_profile_result);
1540 if (!status.ok()) {
1541 continue;
1542 }
1543
1544 if (is_training()) {
1545 // Get reserve space from the forward pass.
1546 Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
1547 status = DoBackward<T>(
1548 context, *rnn_desc, model_types(), model_shapes, input, input_h,
1549 input_c, params, output, output_h, output_c, &output_backprop,
1550 &output_h_backprop, &output_c_backprop, &reserve_space,
1551 &input_backprop, &input_h_backprop, &input_c_backprop,
1552 &params_backprop, /*sequence_lengths=*/nullptr, /*time_major=*/true,
1553 &workspace_allocator, &bak_profile_result);
1554 if (!status.ok()) {
1555 continue;
1556 }
1557 final_profile_result.set_elapsed_time_in_ms(
1558 fwd_profile_result.elapsed_time_in_ms() +
1559 bak_profile_result.elapsed_time_in_ms());
1560 } else {
1561 final_profile_result = fwd_profile_result;
1562 }
1563
1564 auto total_time = final_profile_result.elapsed_time_in_ms();
1565 VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1566 << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
1567 << " run time: " << total_time << " ms.";
1568 if (total_time < best_result.elapsed_time_in_ms()) {
1569 best_result.set_elapsed_time_in_ms(total_time);
1570 best_result.set_algorithm(algo);
1571 }
1572 }
1573
1574 if (!best_result.is_valid()) {
1575 return Status(error::Code::INTERNAL, "No algorithm worked!");
1576 }
1577 algo_config->set_algorithm(best_result.algorithm());
1578 VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1579 << best_result.algorithm().algo_id() << ", "
1580 << best_result.algorithm().tensor_ops_enabled() << ").";
1581 AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
1582 return Status::OK();
1583 }
1584 };
1585
1586 #define REGISTER_GPU(T) \
1587 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2") \
1588 .Device(DEVICE_GPU) \
1589 .HostMemory("host_reserved") \
1590 .TypeConstraint<T>("T"), \
1591 CudnnRNNForwardOpV2<GPUDevice, T>);
1592
1593 TF_CALL_half(REGISTER_GPU);
1594 TF_CALL_float(REGISTER_GPU);
1595 TF_CALL_double(REGISTER_GPU);
1596 #undef REGISTER_GPU
1597
1598 template <typename T>
1599 class CudnnRNNForwardOpV3<GPUDevice, T>
1600 : public CudnnRNNForwardOp<GPUDevice, T> {
1601 private:
1602 using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1603 using CudnnRNNKernelCommon::CreateRnnDescriptor;
1604 using CudnnRNNKernelCommon::dropout;
1605 using CudnnRNNKernelCommon::HasInputC;
1606 using CudnnRNNKernelCommon::model_types;
1607 bool time_major_;
1608
1609 protected:
1610 bool time_major() { return time_major_; }
1611
1612 public:
1613 explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
1614 : CudnnRNNForwardOp<GPUDevice, T>(context) {
1615 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
1616 }
1617
1618 void Compute(OpKernelContext* context) override {
1619 AlgorithmConfig best_algo_config;
1620 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1621 context, &best_algo_config, /*var_seq_lengths=*/true,
1622 /*time_major=*/time_major());
1623 if (!context->status().ok()) {
1624 return;
1625 }
1626
1627 Tensor* output_host_reserved = nullptr;
1628 // TODO: The current V3 kernel only uses the default (standard) algorithm
1629 // to process batches with variable sequence lengths, and the inputs must
1630 // be padded. Autotune is not supported yet.
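// Illustrative example of the expected layout (an assumption, not taken from
// the original code): with time_major=true, three sequences of lengths
// {5, 3, 4} are fed as one [max_seq_length=5, batch_size=3, input_size]
// tensor, zero-padded past each sequence's end, while the host-memory
// sequence_lengths input carries {5, 3, 4}.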
1631 OP_REQUIRES_OK(context,
1632 context->allocate_output(4, {}, &output_host_reserved));
1633 }
1634 };
1635
1636 #define REGISTER_GPU(T) \
1637 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3") \
1638 .Device(DEVICE_GPU) \
1639 .HostMemory("sequence_lengths") \
1640 .HostMemory("host_reserved") \
1641 .TypeConstraint<T>("T"), \
1642 CudnnRNNForwardOpV3<GPUDevice, T>);
1643
1644 TF_CALL_half(REGISTER_GPU);
1645 TF_CALL_float(REGISTER_GPU);
1646 TF_CALL_double(REGISTER_GPU);
1647 #undef REGISTER_GPU
1648
1649 // Run the backward operation of the RNN model.
1650 template <typename T>
1651 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1652 public:
1653 explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
1654 : CudnnRNNKernelCommon(context) {}
1655
1656 void Compute(OpKernelContext* context) override {
1657 ComputeImpl(context, false, true);
1658 }
1659
1660 protected:
1661 virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
1662 bool time_major) {
1663 const Tensor* input = nullptr;
1664 const Tensor* input_h = nullptr;
1665 const Tensor* input_c = nullptr;
1666 const Tensor* params = nullptr;
1667 const Tensor* sequence_lengths = nullptr;
1668 CudnnRnnModelShapes model_shapes;
1669 if (var_seq_lengths) {
1670 OP_REQUIRES_OK(context,
1671 ExtractForwardInput(context, model_types(), time_major,
1672 &input, &input_h, &input_c, ¶ms,
1673 &sequence_lengths, &model_shapes));
1674 } else {
1675 OP_REQUIRES_OK(context, ExtractForwardInput(
1676 context, model_types(), time_major, &input,
1677 &input_h, &input_c, ¶ms, &model_shapes));
1678 }
1679 RnnInputMode input_mode;
1680 OP_REQUIRES_OK(context,
1681 ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1682 model_shapes.input_size, &input_mode));
1683
1684 const Tensor* output = nullptr;
1685 const Tensor* output_h = nullptr;
1686 const Tensor* output_c = nullptr;
1687 const Tensor* output_backprop = nullptr;
1688 const Tensor* output_h_backprop = nullptr;
1689 const Tensor* output_c_backprop = nullptr;
1690 const Tensor* reserve_space = nullptr;
1691 OP_REQUIRES_OK(context,
1692 ExtractBackwardInputs(context, model_shapes, model_types(),
1693 &output, &output_h, &output_c,
1694 &output_backprop, &output_h_backprop,
1695 &output_c_backprop, &reserve_space));
1696
1697 Tensor* input_backprop = nullptr;
1698 Tensor* input_h_backprop = nullptr;
1699 Tensor* input_c_backprop = nullptr;
1700 Tensor* params_backprop = nullptr;
1701 OP_REQUIRES_OK(context,
1702 AllocateOutputs(context, model_shapes, params->shape(),
1703 &input_backprop, &input_h_backprop,
1704 &input_c_backprop, ¶ms_backprop));
1705
1706 // Creates a memory callback for the workspace. The memory lives until
1707 // the end of this kernel call.
1708 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1709 AlgorithmConfig algo_config;
1710 OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
1711 Status launch_status;
1712 {
1713 mutex_lock l(mu_);
1714 RnnDescriptor* rnn_desc_ptr = nullptr;
1715 OP_REQUIRES_OK(
1716 context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
1717 algo_config, &rnn_state_cache_,
1718 &rnn_desc_ptr));
1719 launch_status = DoBackward<T>(
1720 context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1721 input_c, params, output, output_h, output_c, output_backprop,
1722 output_h_backprop, output_c_backprop, reserve_space, input_backprop,
1723 input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
1724 time_major, &workspace_allocator,
1725 /*output_profile_result=*/nullptr);
1726 }
1727 OP_REQUIRES_OK(context, launch_status);
1728 }
1729
1730 protected:
1731 virtual Status GetAlgorithm(OpKernelContext* context,
1732 AlgorithmConfig* algo_config) {
1733 CHECK_NE(algo_config, nullptr);
1734 *algo_config = AlgorithmConfig();
1735 return Status::OK();
1736 }
1737
1738 private:
1739 mutex mu_;
1740 RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
1741
1742 Status ExtractBackwardInputs(
1743 OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
1744 const CudnnModelTypes& model_types, const Tensor** output,
1745 const Tensor** output_h, const Tensor** output_c,
1746 const Tensor** output_backprop, const Tensor** output_h_backprop,
1747 const Tensor** output_c_backprop, const Tensor** reserve_space) {
1748 TF_RETURN_IF_ERROR(context->input("output", output));
1749 TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
1750 TF_RETURN_IF_ERROR(context->input("output_h", output_h));
1751 TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
1752 if (model_types.HasInputC()) {
1753 TF_RETURN_IF_ERROR(context->input("output_c", output_c));
1754 TF_RETURN_IF_ERROR(
1755 context->input("output_c_backprop", output_c_backprop));
1756 }
1757 TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
1758 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1759 const TensorShape& output_shape = model_shapes.output_shape;
1760
1761 if (output_shape != (*output)->shape()) {
1762 return errors::InvalidArgument(
1763 "Invalid output shape: ", (*output)->shape().DebugString(), " ",
1764 output_shape.DebugString());
1765 }
1766 if (hidden_state_shape != (*output_h)->shape()) {
1767 return errors::InvalidArgument(
1768 "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
1769 hidden_state_shape.DebugString());
1770 }
1771
1772 if (output_shape != (*output_backprop)->shape()) {
1773 return errors::InvalidArgument("Invalid output_backprop shape: ",
1774 (*output_backprop)->shape().DebugString(),
1775 " ", output_shape.DebugString());
1776 }
1777 if (hidden_state_shape != (*output_h_backprop)->shape()) {
1778 return errors::InvalidArgument(
1779 "Invalid output_h_backprop shape: ",
1780 (*output_h_backprop)->shape().DebugString(), " ",
1781 hidden_state_shape.DebugString());
1782 }
1783
1784 if (model_types.HasInputC()) {
1785 if (hidden_state_shape != (*output_c)->shape()) {
1786 return errors::InvalidArgument(
1787 "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
1788 hidden_state_shape.DebugString());
1789 }
1790 if (hidden_state_shape != (*output_c_backprop)->shape()) {
1791 return errors::InvalidArgument(
1792 "Invalid output_c_backprop shape: ",
1793 (*output_c_backprop)->shape().DebugString(), " ",
1794 hidden_state_shape.DebugString());
1795 }
1796 }
1797 return Status::OK();
1798 }
1799
1800 Status AllocateOutputs(OpKernelContext* context,
1801 const CudnnRnnModelShapes& model_shapes,
1802 const TensorShape& params_shape,
1803 Tensor** input_backprop, Tensor** input_h_backprop,
1804 Tensor** input_c_backprop, Tensor** params_backprop) {
1805 const TensorShape& input_shape = model_shapes.input_shape;
1806 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1807
1808 TF_RETURN_IF_ERROR(
1809 context->allocate_output(0, input_shape, input_backprop));
1810 TF_RETURN_IF_ERROR(
1811 context->allocate_output(1, hidden_state_shape, input_h_backprop));
1812 if (HasInputC()) {
1813 TF_RETURN_IF_ERROR(
1814 context->allocate_output(2, hidden_state_shape, input_c_backprop));
1815 } else {
1816 // Only LSTM uses input_c and output_c. So for all other models, we only
1817 // need to create dummy outputs.
1818 TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
1819 }
1820 TF_RETURN_IF_ERROR(
1821 context->allocate_output(3, params_shape, params_backprop));
1822 return Status::OK();
1823 }
1824 };
1825
1826 #define REGISTER_GPU(T) \
1827 REGISTER_KERNEL_BUILDER( \
1828 Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1829 CudnnRNNBackwardOp<GPUDevice, T>);
1830
1831 TF_CALL_half(REGISTER_GPU);
1832 TF_CALL_float(REGISTER_GPU);
1833 TF_CALL_double(REGISTER_GPU);
1834 #undef REGISTER_GPU
1835
1836 template <typename T>
1837 class CudnnRNNBackwardOpV2<GPUDevice, T>
1838 : public CudnnRNNBackwardOp<GPUDevice, T> {
1839 public:
1840 explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
1841 : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
1842
1843 protected:
1844 Status GetAlgorithm(OpKernelContext* context,
1845 AlgorithmConfig* algo_config) override {
1846 CHECK_NE(algo_config, nullptr);
1847 const Tensor* host_reserved = nullptr;
1848 TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
1849
1850 auto host_reserved_int8 = host_reserved->vec<int8>();
1851 const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
1852 algo_config->set_algorithm(algo_desc);
1853 return Status::OK();
1854 }
1855 };
1856
1857 #define REGISTER_GPU(T) \
1858 REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2") \
1859 .Device(DEVICE_GPU) \
1860 .HostMemory("host_reserved") \
1861 .TypeConstraint<T>("T"), \
1862 CudnnRNNBackwardOpV2<GPUDevice, T>);
1863
1864 TF_CALL_half(REGISTER_GPU);
1865 TF_CALL_float(REGISTER_GPU);
1866 TF_CALL_double(REGISTER_GPU);
1867 #undef REGISTER_GPU
1868
1869 template <typename T>
1870 class CudnnRNNBackwardOpV3<GPUDevice, T>
1871 : public CudnnRNNBackwardOp<GPUDevice, T> {
1872 private:
1873 bool time_major_;
1874
1875 protected:
1876 bool time_major() { return time_major_; }
1877
1878 public:
1879 explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
1880 : CudnnRNNBackwardOp<GPUDevice, T>(context) {
1881 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
1882 }
1883
1884 void Compute(OpKernelContext* context) override {
1885 CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major());
1886 }
1887 };
1888
1889 #define REGISTER_GPU(T) \
1890 REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3") \
1891 .Device(DEVICE_GPU) \
1892 .HostMemory("sequence_lengths") \
1893 .HostMemory("host_reserved") \
1894 .TypeConstraint<T>("T"), \
1895 CudnnRNNBackwardOpV3<GPUDevice, T>);
1896
1897 TF_CALL_half(REGISTER_GPU);
1898 TF_CALL_float(REGISTER_GPU);
1899 TF_CALL_double(REGISTER_GPU);
1900 #undef REGISTER_GPU
1901
1902 // TODO(zhengxq): Add the conversion of the Cudnn RNN params buffer to and
1903 // from its canonical form.
1904
1905 #endif // GOOGLE_CUDA
1906
1907 } // namespace tensorflow
1908