#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/RNN.h>
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/MatrixRef.h>
#include <ATen/TensorUtils.h>

#include <ATen/cuda/CUDAConfig.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/miopen_rnn.h>
#include <ATen/ops/miopen_rnn_native.h>
#include <ATen/ops/miopen_rnn_backward_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#if !AT_ROCM_ENABLED()

namespace at { namespace native {

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
    const Tensor& input_r, TensorList weight, int64_t weight_stride0,
    const Tensor& hx, const std::optional<Tensor>& cx_opt,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers,
    bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional,
    IntArrayRef fn_batch_sizes, const std::optional<Tensor>& fn_dropout_state_opt
    ) {
  AT_ERROR("miopen_rnn : ATen not compiled with MIOpen support.");
}

std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> miopen_rnn_backward(
    const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const std::optional<Tensor>& cx_opt,
    const Tensor& output, const std::optional<Tensor>& grad_output_r_opt, const std::optional<Tensor>& grad_hy_r_opt, const std::optional<Tensor>& grad_cy_r_opt, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first,
    double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const std::optional<Tensor>& dropout_state_opt,
    const Tensor& reserve, std::array<bool, 4> output_mask
    ) {
  AT_ERROR("miopen_rnn_backward: ATen not compiled with MIOpen support.");
}

}} // namespace at::native

#else // AT_ROCM_ENABLED()

#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/miopen/Descriptors.h>
#include <ATen/miopen/Types.h>
#include <ATen/miopen/Utils.h>

#include <ATen/TensorUtils.h>

#include <functional>
#include <iterator>
#include <sstream>
#include <algorithm>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <unordered_map>

namespace at { namespace native {

// RNNDescriptor.
struct RNNDescriptorParams {
  int64_t hidden_size;
  int64_t num_layers;
  miopenRNNDirectionMode_t direction;
  miopenRNNMode_t rnn_mode;
  miopenDataType_t datatype;
  miopenRNNAlgo_t algo = miopenRNNdefault;
  miopenRNNInputMode_t input_mode = miopenRNNlinear;
  miopenRNNBiasMode_t bias_mode = miopenRNNNoBias;

  int64_t num_directions() const {
    return (direction == miopenRNNbidirection) ? 2 : 1;
  }

  void set_bidirectional(bool fn_bidirectional) {
    direction = fn_bidirectional ? miopenRNNbidirection : miopenRNNunidirection;
  }

  void set_algo(miopenRNNAlgo_t algo) {
    this->algo = algo;
  }

  void set_mode(int64_t fn_mode) {
    switch (fn_mode) {
      case 0:
        rnn_mode = miopenRNNRELU;
        break;
      case 1:
        rnn_mode = miopenRNNTANH;
        break;
      case 2:
        rnn_mode = miopenLSTM;
        break;
      case 3:
        rnn_mode = miopenGRU;
        break;
      default:
        {
          std::ostringstream oss;
          oss << "unrecognized miopen RNN mode " << fn_mode;
          AT_ERROR(oss.str());
        }
    }
  }

  void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
    this->set_mode(mode);
    this->hidden_size = hidden_size;
    this->num_layers = num_layers;
    this->set_bidirectional(bidirectional);
    this->datatype = datatype;
    this->bias_mode = bias_mode;
  }

  RNNDescriptor descriptor() const {
    RNNDescriptor rnn_desc;
    rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
    return rnn_desc;
  }
};

// TensorDescriptor list.
std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
  std::vector<TensorDescriptor> descriptors(batch_sizes.size());
  size_t i = 0;

  auto batch_tensor_size = tensor.sizes().vec();
  for (auto batch_size : batch_sizes) {
    batch_tensor_size[0] = batch_size;

    descriptors[i].set(getMiopenDataType(tensor), batch_tensor_size, tensor.strides(), 3);
    i++;
  }

  return descriptors;
}

std::vector<TensorDescriptor> rnn_descriptor(const Tensor& tensor, int64_t N) {
  std::vector<TensorDescriptor> descriptors(N);
  for (const auto i : c10::irange(N)) {
    descriptors[i].set(tensor, 5);
  }

  return descriptors;
}

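// Shape bookkeeping for the (possibly packed) input sequence.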
struct TensorDescriptorListParams {
  IntArrayRef batch_sizes;
  int64_t seq_length;
  int64_t mini_batch;

  int64_t input_size;
  int64_t batch_sizes_sum;

  bool is_input_packed() const {
    return batch_sizes.size() != 0;
  }

  void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) {
    batch_sizes = batch_sizes_;
    if (is_input_packed()) {
      seq_length = batch_sizes.size();
      mini_batch = batch_sizes[0];
      batch_sizes_sum = input_sizes[0];
      input_size = input_sizes[1];
    } else {
      if (batch_first) {
        seq_length = input_sizes[1];
        mini_batch = input_sizes[0];
      } else {
        seq_length = input_sizes[0];
        mini_batch = input_sizes[1];
      }
      input_size = input_sizes[2];
      batch_sizes_sum = -1;
    }
  }

  std::vector<TensorDescriptor> descriptors(Tensor x) const {
    auto is_input_packed = batch_sizes.size() != 0;
    if (is_input_packed) {
      return rnn_descriptor_sequence(x, batch_sizes);
    } else {
      return rnn_descriptor(x[0], seq_length);
    }
  }
};

struct RNNParams {
  RNNDescriptorParams rnn;
  TensorDescriptorListParams tensors;
};

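// Bundles every MIOpen descriptor needed for a single RNN call.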
struct RNNDescriptors {
  RNNDescriptor rnn_desc;
  std::vector<TensorDescriptor> x_descs;
  std::vector<TensorDescriptor> y_descs;
  TensorDescriptor hx_desc;
  TensorDescriptor hy_desc;
  TensorDescriptor cx_desc;
  TensorDescriptor cy_desc;

  RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
    rnn_desc = fn.rnn.descriptor();
    x_descs = fn.tensors.descriptors(x);
    y_descs = fn.tensors.descriptors(y);
    hx_desc.set(hx, 5);
    hy_desc.set(hx, 5);
    cx_desc.set(hx, 5);
    cy_desc.set(hx, 5);
  }

  std::vector<miopenTensorDescriptor_t> get_descs(const std::vector<TensorDescriptor>& descs) {
    std::vector<miopenTensorDescriptor_t> r;
    r.reserve(descs.size());
    for (auto& desc : descs) {
      r.emplace_back(desc.desc());
    }
    return r;
  }

  std::vector<miopenTensorDescriptor_t> get_x_descs() {
    return get_descs(x_descs);
  }

  std::vector<miopenTensorDescriptor_t> get_y_descs() {
    return get_descs(y_descs);
  }
};

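// MIOpen expects the LSTM/GRU gate weights in a different order than PyTorch:
// for LSTM (mode 2) the last two of the four gate chunks are swapped, for GRU
// (mode 3) the first two of the three gate chunks are swapped. Plain RNNs
// (mode < 2) are returned unchanged.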
Tensor permute_wei_for_miopen(Tensor wei, int64_t mode)
{
  if (mode < 2)
    return wei;

  Tensor permuted_wei;
  if (mode == 2) { // LSTM
    auto sliced_tensor = wei.chunk(4, 0);
    permuted_wei = at::cat({sliced_tensor[0], sliced_tensor[1], sliced_tensor[3], sliced_tensor[2]});
  }
  else if (mode == 3) { // GRU
    auto sliced_tensor = wei.chunk(3, 0);
    permuted_wei = at::cat({sliced_tensor[1], sliced_tensor[0], sliced_tensor[2]});
  }
  return permuted_wei;
}

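// Copies parameters between two per-layer parameter lists (copy == true), or
// resizes the source tensors to match the destination layout (copy == false).
// _copyParams and _viewParams below are thin wrappers over this helper.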
void _viewOrCopyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to, bool copy) {
  TORCH_CHECK(params_from.size(0) == params_to.size(0), "number of layers mismatch");
  for (const auto i : c10::irange(params_from.size(0))) {
    auto layer_params_from = params_from[i];
    auto layer_params_to = params_to[i];
    // NOTE: these lists have all weights before all biases, so if the layer
    // doesn't use biases, iteration will terminate once layer_params_from ends
    // and ignore them.
    for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
         a != layer_params_from.end() && b != layer_params_to.end();
         ++a, ++b) {
      auto param_from = *a, param_to = *b;
      TORCH_CHECK(param_from.type() == param_to.type(), "parameter types mismatch");
      if (copy) {
        param_to.copy_(param_from.view_as(param_to));
      } else {
        param_from.resize_as_(param_to);
      }
    }
  }
}

void _copyParams_and_permute(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to, int64_t mode) {
  TORCH_CHECK(params_from.size(0) == params_to.size(0), "number of layers mismatch");
  for (const auto i : c10::irange(params_from.size(0))) {
    auto layer_params_from = params_from[i];
    auto layer_params_to = params_to[i];
    for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
         a != layer_params_from.end() && b != layer_params_to.end();
         ++a, ++b) {
      auto param_from = *a, param_to = *b;
      TORCH_CHECK(param_from.type() == param_to.type(), "parameter types mismatch");
      auto tmp = permute_wei_for_miopen(param_from, mode);
      param_to.copy_(tmp.view_as(param_to));
    }
  }
}

void _copyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
  _viewOrCopyParams(params_from, params_to, true);
}

void _viewParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
  _viewOrCopyParams(params_from, params_to, false);
}

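// Queries MIOpen for the size of the flat weight buffer and returns it as a
// number of elements of the given datatype.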
int64_t get_num_weights(miopenHandle_t handle, const RNNDescriptor& rnn_desc,
    const TensorDescriptor& x_desc, miopenDataType_t datatype)
{
  size_t weight_size;
  MIOPEN_CHECK(miopenGetRNNParamsSize(handle, rnn_desc.desc(), x_desc.desc(), &weight_size, datatype));
  auto element_size = dataSize(datatype);
  TORCH_CHECK(weight_size % element_size == 0, "miopenGetRNNParamsSize returned nonsensical weight_size.");
  return weight_size / element_size;
}

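// Number of linear (weight-matrix) layers MIOpen uses per RNN layer:
// 4 gates x 2 matrices for LSTM, 3 x 2 for GRU, 1 x 2 for plain RNNs.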
int64_t _num_linear_layers(miopenRNNMode_t mode) {
  switch (mode) {
    case miopenLSTM:
      return 8;
    case miopenGRU:
      return 6;
    case miopenRNNRELU:
      return 2;
    case miopenRNNTANH:
      return 2;
    default:
      AT_ERROR("Unknown miopen RNN mode : ", mode);
  }
}

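// Builds tensors that view into the flat MIOpen weight buffer, one per weight
// matrix (and per bias vector when biases are enabled), and returns them
// together with the number of parameter tensors per layer.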
std::pair<std::vector<Tensor>, size_t> get_parameters(miopenHandle_t handle, const RNNDescriptorParams& rnn,
    const RNNDescriptor& rnn_desc, const TensorDescriptor& x_desc, const FilterDescriptor& w_desc,
    const Tensor& weight_buf)
{
  std::vector<Tensor> params;
  int64_t num_linear_layers = _num_linear_layers(rnn.rnn_mode);
  int64_t num_layers = rnn.num_directions() * rnn.num_layers;
  size_t cur_offset = 0;
  size_t global_layer_params_count = 0;
  auto elem_size = dataSize(getMiopenDataType(weight_buf));
  auto bias_mode = rnn.bias_mode;

  for (const auto layer : c10::irange(num_layers)) {
    size_t layer_params_count = 0;

    // Get layer params
    for (const auto linear_id : c10::irange(num_linear_layers)) {
      FilterDescriptor lin_layer_mat_desc;
      size_t offset;
      MIOPEN_CHECK(miopenGetRNNLayerParamOffset(
          rnn_desc.desc(),
          layer,
          x_desc.desc(),
          linear_id,
          lin_layer_mat_desc.mut_desc(),
          &offset));

      size_t param_size;
      MIOPEN_CHECK(miopenGetRNNLayerParamSize(
          handle,
          rnn_desc.desc(),
          layer,
          x_desc.desc(),
          linear_id,
          &param_size));
      param_size /= elem_size;

      if (linear_id == 0 || linear_id == num_linear_layers / 2) {
        std::initializer_list<int64_t> size = { static_cast<int64_t>(param_size * num_linear_layers / 2), 1L};
        Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
        params.emplace_back(std::move(param));
        layer_params_count++;
      } else {
        TORCH_INTERNAL_ASSERT(cur_offset == offset,
            "cur_offset = ", cur_offset, " ; offset = ", offset);
      }
      cur_offset = offset + param_size;
    }

    // Get bias params
    if (bias_mode == miopenRNNwithBias) {
      for (const auto linear_id : c10::irange(num_linear_layers)) {
        FilterDescriptor lin_layer_mat_desc;
        size_t offset;
        MIOPEN_CHECK(miopenGetRNNLayerBiasOffset(
            rnn_desc.desc(),
            layer,
            x_desc.desc(),
            linear_id,
            lin_layer_mat_desc.mut_desc(),
            &offset));

        size_t bias_size;
        MIOPEN_CHECK(miopenGetRNNLayerBiasSize(
            handle,
            rnn_desc.desc(),
            layer,
            linear_id,
            &bias_size));
        bias_size /= elem_size;

        if (linear_id == 0 || linear_id == num_linear_layers / 2) {
          std::initializer_list<int64_t> size = { static_cast<int64_t>(bias_size * num_linear_layers / 2), 1L};
          Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
          params.emplace_back(std::move(param));
          layer_params_count++;
        } else {
          TORCH_INTERNAL_ASSERT(cur_offset == offset,
              "cur_offset = ", cur_offset, " ; offset = ", offset);
        }
        cur_offset = offset + bias_size;
      }
    }

    if (layer == 0) {
      global_layer_params_count = layer_params_count;
    } else {
      TORCH_INTERNAL_ASSERT(global_layer_params_count == layer_params_count,
          "global_layer_params_count = ", global_layer_params_count,
          "; layer_params_count = ", layer_params_count);
    }
  } // layer
  return std::make_pair(params, global_layer_params_count);
}

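// Expected shapes of the input, hidden-state, and output tensors, for both the
// packed and the unpacked (seq-major) layouts.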
std::vector<int64_t> _input_size(const TensorDescriptorListParams& tensors) {
  if (tensors.is_input_packed()) {
    return {tensors.batch_sizes_sum, tensors.input_size};
  } else {
    return {tensors.seq_length, tensors.mini_batch, tensors.input_size};
  }
}

std::vector<int64_t> _hidden_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
  return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size};
}

std::vector<int64_t> _output_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
  if (tensors.is_input_packed()) {
    return {tensors.batch_sizes_sum, rnn.hidden_size * rnn.num_directions()};
  } else {
    return {tensors.seq_length, tensors.mini_batch, rnn.hidden_size * rnn.num_directions()};
  }
}

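// Forward pass: sets up MIOpen descriptors, packs the PyTorch weights into
// MIOpen's flat weight buffer, and runs the training or inference kernel.
// Returns (output, hy, cy, reserve, weight_buf).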
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
    const Tensor& input_r, TensorList weight, int64_t weight_stride0,
    const Tensor& hx, const std::optional<Tensor>& cx_opt,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers,
    bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional,
    IntArrayRef fn_batch_sizes, const std::optional<Tensor>& fn_dropout_state_opt
    ) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt);
  const Tensor& cx = *cx_maybe_owned;
  const Tensor& fn_dropout_state = c10::value_or_else(fn_dropout_state_opt, [] {return Tensor();});

  check_attributes(input_r, weight, {hx, cx});
  auto input = input_r;

  RNNParams fn;
  auto datatype = getMiopenDataType(input);
  miopenRNNBiasMode_t bias_mode = (weight_stride0 == 4) ? miopenRNNwithBias : miopenRNNNoBias;
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, bias_mode);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  if (fn.rnn.rnn_mode != miopenLSTM) {
    TORCH_CHECK(!cx.defined(), "miopen_rnn: illegal defined cx for non-LSTM RNN.");
  }

  auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
  }

  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
  auto output_size = _output_size(fn.rnn, fn.tensors);

  TORCH_CHECK(hx.is_contiguous(), "miopen_rnn : hx is not contiguous.");
  TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "miopen_rnn : cx is not contiguous.");

  auto x = input.contiguous();
  auto output = at::empty(output_size, input.options());
  auto hy = at::empty(hidden_size, hx.options());
  Tensor cy;
  if (cx.defined()) {
    cy = at::empty(hidden_size, cx.options());
  } else {
    cy = at::empty({0}, hx.options());
  }

  auto y = output;
  auto handle = getMiopenHandle();
  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);

  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
  auto weight_buf = at::empty(num_weights, x.options());
  w_desc.set(weight_buf, 3);
  weight_buf.zero_();
  auto [params, params_stride0] = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, weight_buf);
  if (fn_mode < 2)
    _copyParams(MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
                MatrixRef<Tensor>{params, params_stride0});
  else
    _copyParams_and_permute(MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
                            MatrixRef<Tensor>{params, params_stride0}, fn_mode);

  TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size), "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());

  size_t workspace_size;
  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();

  // Allocate the workspace.
  MIOPEN_CHECK(miopenGetRNNWorkspaceSize(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &workspace_size));
  auto workspace = at::empty(workspace_size, input.options().dtype(kByte));

  // Train or inference.
  Tensor reserve;
  if (fn_train) { // Training.
    size_t reserve_size;
    MIOPEN_CHECK(miopenGetRNNTrainingReserveSize(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &reserve_size));
    reserve = at::empty(reserve_size, input.options().dtype(kByte));
    MIOPEN_CHECK(miopenRNNForwardTraining(handle, descs.rnn_desc.desc(), fn.tensors.seq_length,
        x_descs_arr.data(), x.data_ptr(),
        descs.hx_desc.desc(), hx.data_ptr(),
        descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
        w_desc.desc(), weight_buf.data_ptr(),
        y_descs_arr.data(), y.data_ptr(),
        descs.hy_desc.desc(), hy.data_ptr(),
        descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr,
        workspace.data_ptr(), workspace_size, reserve.mutable_data_ptr(), reserve_size));
  } else { // Inference.
    reserve = at::empty({0}, input.options().dtype(kByte));
    MIOPEN_CHECK(miopenRNNForwardInference(handle, descs.rnn_desc.desc(), fn.tensors.seq_length,
        x_descs_arr.data(), x.data_ptr(),
        descs.hx_desc.desc(), hx.data_ptr(),
        descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
        w_desc.desc(), weight_buf.data_ptr(),
        y_descs_arr.data(), y.data_ptr(),
        descs.hy_desc.desc(), hy.data_ptr(),
        descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr,
        workspace.data_ptr(), workspace_size));
  }

  if (batch_first && !is_input_packed) {
    output.transpose_(0, 1);
  }

  return std::make_tuple(output, hy, cy, reserve, weight_buf);
}

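// Backward-data pass: computes gradients with respect to the input and the
// initial hidden (and cell) states, and returns the workspace so the
// backward-weights pass can reuse it.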
std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
    const Tensor& input_r, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
    const Tensor& output_r, const Tensor& grad_output_r, const Tensor& grad_hy,
    const Tensor& grad_cy,
    int64_t fn_mode, int64_t fn_hidden_size,
    int64_t fn_num_layers, bool batch_first, double fn_dropout,
    bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes,
    const Tensor& fn_dropout_state, const Tensor& fn_reserve,
    std::array<bool, 3> output_mask
    ) {
  auto input = input_r;
  auto grad_output = grad_output_r;
  auto output = output_r;

  RNNParams fn;
  auto datatype = getMiopenDataType(input);
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, miopenRNNwithBias);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  auto handle = getMiopenHandle();

  if (fn.rnn.rnn_mode != miopenLSTM) {
    TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
  }

  auto is_input_packed = fn_batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
    grad_output = grad_output.transpose(0, 1);
    output = output.transpose(0, 1);
  }

  auto input_size = _input_size(fn.tensors);
  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
  auto output_size = _output_size(fn.rnn, fn.tensors);

  TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
  TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");

  auto x = input.contiguous();
  auto dy = grad_output.contiguous();
  auto y = output;
  auto w = weight_buf;
  auto dx = at::empty(input.sizes(), input.options());
  auto dhy = grad_hy.contiguous().view(hidden_size);
  auto dcy = grad_cy.defined() ? grad_cy.contiguous().view(hidden_size) : Tensor();
  auto dhx = at::empty(hidden_size, hx.options());
  TORCH_INTERNAL_ASSERT(cx.defined() || !output_mask[2],
      "illegally required grad of cx for non-LSTM RNN");
  auto dcx = cx.defined() ? at::empty(hidden_size, cx.options()) : Tensor();

  TORCH_CHECK(fn_train, "miopen RNN backward can only be called in training mode");

  TORCH_CHECK(input.sizes().equals(input_size),
      "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
  TORCH_CHECK(output.sizes().equals(output_size),
      "Expected output size ", IntArrayRef{output_size}, ", got ", output.sizes());

  TORCH_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
      "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());
  TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
      "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());
  TORCH_CHECK(!dhy.defined() || dhy.sizes().equals(hidden_size),
      "Expected d_hidden size ", IntArrayRef{hidden_size}, ", got ", dhy.sizes());
  TORCH_CHECK(!dcy.defined() || dcy.sizes().equals(hidden_size),
      "Expected d_cell size ", IntArrayRef{hidden_size}, ", got ", dcy.sizes());

  TORCH_CHECK(dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
      "Gradients aren't HIP tensors");

  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  size_t workspace_size;
  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();

  MIOPEN_CHECK(miopenGetRNNWorkspaceSize(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      x_descs_arr.data(),
      &workspace_size
      ));
  auto workspace = at::empty(workspace_size, input.options().dtype(kByte));

  MIOPEN_CHECK(miopenRNNBackwardData(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      y_descs_arr.data(), y.data_ptr(),
      y_descs_arr.data(), dy.data_ptr(),
      descs.hy_desc.desc(), dhy.data_ptr(),
      descs.cy_desc.desc(), cx.defined() ? dcy.data_ptr() : nullptr,
      w_desc.desc(), w.data_ptr(),
      descs.hx_desc.desc(), hx.data_ptr(),
      descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
      x_descs_arr.data(), dx.data_ptr(),
      descs.hx_desc.desc(), dhx.data_ptr(),
      descs.cx_desc.desc(), cx.defined() ? dcx.data_ptr() : nullptr,
      workspace.data_ptr(), workspace.size(0),
      fn_reserve.data_ptr(), fn_reserve.size(0)
      ));

  if (batch_first && !is_input_packed) {
    dx = dx.transpose_(0, 1);
  }

  return std::make_tuple(dx, dhx, dcx, workspace);
}

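// Backward-weights pass: accumulates weight gradients into a flat buffer and
// returns them as per-parameter tensors matching the layout of weight_arr.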
std::vector<Tensor> miopen_rnn_backward_weight(
    const Tensor& input_r, TensorList weight_arr, int64_t weight_stride0,
    const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
    const Tensor& output_r,
    int64_t fn_mode, int64_t fn_hidden_size,
    int64_t fn_num_layers, bool batch_first, double fn_dropout,
    bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes,
    const Tensor& fn_dropout_state, const Tensor& fn_reserve, const Tensor& fn_workspace
    ) {
  MatrixRef<Tensor> weight{ weight_arr, static_cast<size_t>(weight_stride0) };

  auto input = input_r;
  auto output = output_r;

  RNNParams fn;
  auto datatype = getMiopenDataType(input);
  miopenRNNBiasMode_t bias_mode = (weight_stride0 == 4) ? miopenRNNwithBias : miopenRNNNoBias;
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, bias_mode);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  auto handle = getMiopenHandle();

  if (fn.rnn.rnn_mode != miopenLSTM) {
    TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
  }

  auto is_input_packed = fn_batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
    output = output.transpose(0, 1);
  }

  auto input_size = _input_size(fn.tensors);
  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);

  TORCH_CHECK(fn_train, "miopen RNN backward can only be called in training mode");

  TORCH_CHECK(input.sizes().equals(input_size),
      "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
  TORCH_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
      "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());

  TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
  TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");

  auto x = input.contiguous();
  const auto& y = output;
  auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());

  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();

  MIOPEN_CHECK(miopenRNNBackwardWeights(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      x_descs_arr.data(), x.data_ptr(),
      descs.hx_desc.desc(), hx.data_ptr(),
      y_descs_arr.data(), y.data_ptr(),
      w_desc.desc(), dw.data_ptr(),
      fn_workspace.data_ptr(), fn_workspace.size(0),
      fn_reserve.data_ptr(), fn_reserve.size(0)
      ));

  auto [grad_params_arr, grad_params_stride0] = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, dw);
  if (grad_params_stride0 == static_cast<size_t>(weight_stride0)) {
    _viewParams(MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
                MatrixRef<Tensor>{weight_arr, static_cast<size_t>(weight_stride0)});
    return grad_params_arr;
  } else {
    std::vector<Tensor> grad_weight_arr;
    grad_weight_arr.reserve(weight.numel());
    for (const auto& w : weight_arr) {
      grad_weight_arr.emplace_back(at::empty(w.sizes(), w.options()));
    }
    _copyParams(MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
                MatrixRef<Tensor>{grad_weight_arr, static_cast<size_t>(weight_stride0)});
    return grad_weight_arr;
  }
}

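// Dispatcher-facing backward: always runs the backward-data pass, runs the
// backward-weights pass only when output_mask[3] is set, and permutes the
// LSTM/GRU gate gradients back into PyTorch's ordering.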
std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> miopen_rnn_backward(
    const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const std::optional<Tensor>& cx_opt,
    const Tensor& output, const std::optional<Tensor>& grad_output_r_opt, const std::optional<Tensor>& grad_hy_r_opt, const std::optional<Tensor>& grad_cy_r_opt, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first,
    double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const std::optional<Tensor>& dropout_state_opt,
    const Tensor& reserve, std::array<bool, 4> output_mask
    ) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt);
  const Tensor& cx = *cx_maybe_owned;
  const Tensor& grad_output_r = c10::value_or_else(grad_output_r_opt, [] {return Tensor();});
  const Tensor& grad_hy_r = c10::value_or_else(grad_hy_r_opt, [] {return Tensor();});
  const Tensor& grad_cy_r = c10::value_or_else(grad_cy_r_opt, [] {return Tensor();});
  const Tensor& dropout_state = c10::value_or_else(dropout_state_opt, [] {return Tensor();});

  if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) {
    return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>(Tensor(), Tensor(), Tensor(), std::vector<Tensor>(weight.size()));
  }
  auto grad_output = grad_output_r.defined() ? grad_output_r : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_cy = cx.defined() ? (grad_cy_r.defined() ? grad_cy_r : at::zeros_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT)) : grad_cy_r;

  auto [dx, dhx, dcx, ws] = at::native::miopen_rnn_backward_input(input, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, {output_mask[0], output_mask[1], output_mask[2]});
  std::vector<Tensor> dw;
  if (output_mask[3]) {
    dw = at::native::miopen_rnn_backward_weight(input, weight, weight_stride0, weight_buf, hx, cx, output, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, ws);
    if (mode > 1) {
      for (const auto i : c10::irange(dw.size())) {
        dw[i] = permute_wei_for_miopen(dw[i], mode);
      }
    }
  }
  return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>{dx, dhx, dcx, dw};
}

namespace {

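// Helpers that convert between the templated hidden representation (a single
// Tensor for GRU/plain RNNs, an (hx, cx) pair for LSTM) and separate hx/cx
// tensors.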
std::tuple<Tensor, Tensor> unpack_hidden(const Tensor& hidden) {
  return std::make_tuple(hidden, at::Tensor{});
}

std::tuple<Tensor, Tensor> unpack_hidden(const std::tuple<Tensor, Tensor>& hidden) {
  return hidden;
}

template<typename hidden_type>
hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
  static_assert(std::is_same<hidden_type, void>::value, "pack_hidden not implemented for this type");
  AT_ERROR("NOT IMPLEMENTED");
}

template<>
Tensor pack_hidden<Tensor>(const Tensor& hx, const Tensor& cx) {
  AT_ASSERT(cx.numel() == 0);
  return hx;
}

template<>
std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor& hx, const Tensor& cx) {
  return std::make_tuple(hx, cx);
}

template<typename hidden_type>
std::pair<Tensor, hidden_type> _miopen_impl(
    const Tensor& input, const Tensor& _batch_sizes, const hidden_type& hidden,
    TensorList params, bool has_biases, miopenRNNMode_t mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  auto [hx, cx] = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

  TORCH_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D");
  IntArrayRef batch_sizes { _batch_sizes.data_ptr<int64_t>(), static_cast<size_t>(_batch_sizes.size(0)) };

  Tensor dropout_state = at::empty({0}, input.options());

  auto miopen_output = at::miopen_rnn(
      input, params, has_biases ? 4 : 2,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers, /*batch_first=*/false,
      dropout_p, train, bidirectional, batch_sizes, dropout_state);

  return {std::get<0>(miopen_output),
          pack_hidden<hidden_type>(std::get<1>(miopen_output), std::get<2>(miopen_output))};
}

template<typename hidden_type>
std::pair<Tensor, hidden_type> _miopen_impl(
    const Tensor& input, const hidden_type& hidden,
    TensorList params, bool has_biases, miopenRNNMode_t mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto [hx, cx] = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

  Tensor dropout_state = at::empty({0}, input.options());

  auto miopen_output = at::miopen_rnn(
      input, params, has_biases ? 4 : 2,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers, batch_first, dropout_p,
      train, bidirectional, /*batch_sizes=*/{}, dropout_state);

  return {std::get<0>(miopen_output),
          pack_hidden<hidden_type>(std::get<1>(miopen_output), std::get<2>(miopen_output))};
}

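// Generates the dispatch-stub glue for single-hidden-state RNNs (GRU, RNN-tanh,
// RNN-relu), in both the padded and the packed-sequence variants.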
#define ONE_HIDDEN_RNN(NAME, MODE)                                            \
void NAME##_miopen(Tensor& output, Tensor& hy,                                \
      const Tensor& input, const Tensor& hx,                                  \
      TensorList params, bool has_biases,                                     \
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { \
  std::tie(output, hy) = _miopen_impl(input, hx, params, has_biases,          \
      MODE, num_layers, dropout_p, train, bidirectional, batch_first);        \
}                                                                             \
                                                                              \
void NAME##_packed_miopen(Tensor& output, Tensor& hy,                         \
      const Tensor& data, const Tensor& batch_sizes, const Tensor& hx,        \
      TensorList params, bool has_biases,                                     \
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) { \
  std::tie(output, hy) = _miopen_impl(data, batch_sizes, hx, params,          \
      has_biases, MODE, num_layers, dropout_p, train, bidirectional);         \
}                                                                             \
                                                                              \
REGISTER_CUDA_DISPATCH(NAME##_miopen_stub, &NAME##_miopen);                   \
REGISTER_CUDA_DISPATCH(NAME##_packed_miopen_stub, &NAME##_packed_miopen);

ONE_HIDDEN_RNN(gru, miopenGRU)
ONE_HIDDEN_RNN(rnn_tanh, miopenRNNTANH)
ONE_HIDDEN_RNN(rnn_relu, miopenRNNRELU)

void lstm_miopen(Tensor& output, Tensor& hy, Tensor& cy,
      const Tensor& input, TensorList hx,
      TensorList params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto result = _miopen_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
      miopenLSTM, num_layers, dropout_p, train, bidirectional, batch_first);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}

void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
      const Tensor& data, const Tensor& batch_sizes, TensorList hx,
      TensorList params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  auto result = _miopen_impl(data, batch_sizes, std::make_tuple(hx[0], hx[1]),
      params, has_biases, miopenLSTM, num_layers, dropout_p, train, bidirectional);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}

REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen);
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen);

} // anonymous namespace
}} // namespace at::native

#endif