#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/RNN.h>
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/MatrixRef.h>
#include <ATen/TensorUtils.h>

#include <ATen/cuda/CUDAConfig.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/miopen_rnn.h>
#include <ATen/ops/miopen_rnn_native.h>
#include <ATen/ops/miopen_rnn_backward_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#if !AT_ROCM_ENABLED()

namespace at { namespace native {

    std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
            const Tensor& input_r, TensorList weight, int64_t weight_stride0,
            const Tensor& hx, const std::optional<Tensor>& cx_opt,
            int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers,
            bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional,
            IntArrayRef fn_batch_sizes, const std::optional<Tensor>& fn_dropout_state_opt
            ) {
        AT_ERROR("miopen_rnn : ATen not compiled with MIOpen support.");
    }

    std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> miopen_rnn_backward(
            const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const std::optional<Tensor>& cx_opt,
            const Tensor& output, const std::optional<Tensor>& grad_output_r_opt, const std::optional<Tensor>& grad_hy_r_opt, const std::optional<Tensor>& grad_cy_r_opt, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first,
            double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const std::optional<Tensor>& dropout_state_opt,
            const Tensor& reserve, std::array<bool, 4> output_mask
            ) {
        AT_ERROR("miopen_rnn_backward: ATen not compiled with MIOpen support.");
    }

}} // namespace at::native

#else // AT_ROCM_ENABLED()

#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/miopen/Descriptors.h>
#include <ATen/miopen/Types.h>
#include <ATen/miopen/Utils.h>

#include <ATen/TensorUtils.h>

#include <functional>
#include <iterator>
#include <sstream>
#include <algorithm>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <unordered_map>

namespace at { namespace native {

// RNNDescriptor.
struct RNNDescriptorParams {
    int64_t hidden_size;
    int64_t num_layers;
    miopenRNNDirectionMode_t direction;
    miopenRNNMode_t rnn_mode;
    miopenDataType_t datatype;
    miopenRNNAlgo_t algo = miopenRNNdefault;
    miopenRNNInputMode_t input_mode = miopenRNNlinear;
    miopenRNNBiasMode_t bias_mode = miopenRNNNoBias;

    int64_t num_directions() const {
        return (direction == miopenRNNbidirection) ? 2 : 1;
    }

    void set_bidirectional(bool fn_bidirectional) {
        direction = fn_bidirectional ? miopenRNNbidirection : miopenRNNunidirection;
    }

    void set_algo(miopenRNNAlgo_t algo) {
        this->algo = algo;
    }

    void set_mode(int64_t fn_mode) {
        switch (fn_mode) {
            case 0:
                rnn_mode = miopenRNNRELU;
                break;
            case 1:
                rnn_mode = miopenRNNTANH;
                break;
            case 2:
                rnn_mode = miopenLSTM;
                break;
            case 3:
                rnn_mode = miopenGRU;
                break;
            default:
                {
                    std::ostringstream oss;
                    oss << "unrecognized miopen RNN mode " << fn_mode;
                    AT_ERROR(oss.str());
                }
        }
    }

    void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
        this->set_mode(mode);
        this->hidden_size = hidden_size;
        this->num_layers = num_layers;
        this->set_bidirectional(bidirectional);
        this->datatype = datatype;
        this->bias_mode = bias_mode;
    }

    RNNDescriptor descriptor() const {
        RNNDescriptor rnn_desc;
        rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
        return rnn_desc;
    }
};

// TensorDescriptor list.
std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
    std::vector<TensorDescriptor> descriptors(batch_sizes.size());
    size_t i = 0;

    auto batch_tensor_size = tensor.sizes().vec();
    for (auto batch_size : batch_sizes) {
        batch_tensor_size[0] = batch_size;

        descriptors[i].set(getMiopenDataType(tensor), batch_tensor_size, tensor.strides(), 3);
        i++;
    }

    return descriptors;
}

std::vector<TensorDescriptor> rnn_descriptor(const Tensor& tensor, int64_t N) {
    std::vector<TensorDescriptor> descriptors(N);
    for (const auto i : c10::irange(N)) {
        descriptors[i].set(tensor, 5);
    }

    return descriptors;
}

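// Shape bookkeeping for the input sequence. Two layouts are handled: a packed
// sequence (batch_sizes non-empty, input of shape (sum(batch_sizes), input_size))
// and a padded batch (input of shape (seq_len, batch, input_size), optionally
// batch-first).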
struct TensorDescriptorListParams {
    IntArrayRef batch_sizes;
    int64_t seq_length;
    int64_t mini_batch;

    int64_t input_size;
    int64_t batch_sizes_sum;

    bool is_input_packed() const {
        return batch_sizes.size() != 0;
    }

    void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) {
        batch_sizes = batch_sizes_;
        if (is_input_packed()) {
            seq_length = batch_sizes.size();
            mini_batch = batch_sizes[0];
            batch_sizes_sum = input_sizes[0];
            input_size = input_sizes[1];
        } else {
            if (batch_first) {
                seq_length = input_sizes[1];
                mini_batch = input_sizes[0];
            } else {
                seq_length = input_sizes[0];
                mini_batch = input_sizes[1];
            }
            input_size = input_sizes[2];
            batch_sizes_sum = -1;
        }
    }

    std::vector<TensorDescriptor> descriptors(Tensor x) const {
        auto is_input_packed = batch_sizes.size() != 0;
        if (is_input_packed) {
            return rnn_descriptor_sequence(x, batch_sizes);
        } else {
            return rnn_descriptor(x[0], seq_length);
        }
    }
};

struct RNNParams {
    RNNDescriptorParams rnn;
    TensorDescriptorListParams tensors;
};

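// Bundles the MIOpen descriptors needed for a single RNN call: the RNN
// descriptor, one tensor descriptor per time step for x and y, and the
// hidden/cell state descriptors (all four state descriptors are built from
// hx's shape).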
struct RNNDescriptors {
    RNNDescriptor rnn_desc;
    std::vector<TensorDescriptor> x_descs;
    std::vector<TensorDescriptor> y_descs;
    TensorDescriptor hx_desc;
    TensorDescriptor hy_desc;
    TensorDescriptor cx_desc;
    TensorDescriptor cy_desc;

    RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
        rnn_desc = fn.rnn.descriptor();
        x_descs = fn.tensors.descriptors(x);
        y_descs = fn.tensors.descriptors(y);
        hx_desc.set(hx, 5);
        hy_desc.set(hx, 5);
        cx_desc.set(hx, 5);
        cy_desc.set(hx, 5);
    }

    std::vector<miopenTensorDescriptor_t> get_descs(const std::vector<TensorDescriptor>& descs) {
        std::vector<miopenTensorDescriptor_t> r;
        r.reserve(descs.size());
        for (auto& desc : descs) {
            r.emplace_back(desc.desc());
        }
        return r;
    }

    std::vector<miopenTensorDescriptor_t> get_x_descs() {
        return get_descs(x_descs);
    }

    std::vector<miopenTensorDescriptor_t> get_y_descs() {
        return get_descs(y_descs);
    }
};

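// PyTorch and MIOpen lay out the per-gate weight chunks differently. For LSTM
// (mode == 2) the 3rd and 4th of the four gate chunks are swapped, i.e.
// PyTorch's documented (input, forget, cell, output) order becomes
// (input, forget, output, cell); for GRU (mode == 3) the first two of the three
// chunks are swapped, i.e. (reset, update, new) becomes (update, reset, new).
// Plain RNNs (mode < 2) need no permutation.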
Tensor permute_wei_for_miopen(Tensor wei, int64_t mode)
{
    if (mode < 2)
        return wei;

    Tensor permuted_wei;
    if (mode == 2) { // LSTM
        auto sliced_tensor = wei.chunk(4, 0);
        permuted_wei = at::cat({sliced_tensor[0], sliced_tensor[1], sliced_tensor[3], sliced_tensor[2]});
    }
    else if (mode == 3) { // GRU
        auto sliced_tensor = wei.chunk(3, 0);
        permuted_wei = at::cat({sliced_tensor[1], sliced_tensor[0], sliced_tensor[2]});
    }
    return permuted_wei;
}

void _viewOrCopyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to, bool copy) {
    TORCH_CHECK(params_from.size(0) == params_to.size(0), "number of layers mismatch");
    for (const auto i : c10::irange(params_from.size(0))) {
        auto layer_params_from = params_from[i];
        auto layer_params_to = params_to[i];
        // NOTE: these lists have all weights before all biases, so if the layer
        // doesn't use biases, iteration will terminate once layer_params_from ends
        // and ignore them.
        for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
                a != layer_params_from.end() && b != layer_params_to.end();
                ++a, ++b) {
            auto param_from = *a, param_to = *b;
            TORCH_CHECK(param_from.type() == param_to.type(), "parameter types mismatch");
            if (copy) {
                param_to.copy_(param_from.view_as(param_to));
            } else {
                param_from.resize_as_(param_to);
            }
        }
    }
}

void _copyParams_and_permute(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to, int64_t mode) {
    TORCH_CHECK(params_from.size(0) == params_to.size(0), "number of layers mismatch");
    for (const auto i : c10::irange(params_from.size(0))) {
        auto layer_params_from = params_from[i];
        auto layer_params_to = params_to[i];
        for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
                a != layer_params_from.end() && b != layer_params_to.end();
                ++a, ++b) {
            auto param_from = *a, param_to = *b;
            TORCH_CHECK(param_from.type() == param_to.type(), "parameter types mismatch");
            auto tmp = permute_wei_for_miopen(param_from, mode);
            param_to.copy_(tmp.view_as(param_to));
        }
    }
}

void _copyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
    _viewOrCopyParams(params_from, params_to, true);
}

void _viewParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
    _viewOrCopyParams(params_from, params_to, false);
}

int64_t get_num_weights(miopenHandle_t handle, const RNNDescriptor& rnn_desc,
        const TensorDescriptor& x_desc, miopenDataType_t datatype)
{
    size_t weight_size;
    MIOPEN_CHECK(miopenGetRNNParamsSize(handle, rnn_desc.desc(), x_desc.desc(), &weight_size, datatype));
    auto element_size = dataSize(datatype);
    TORCH_CHECK(weight_size % element_size == 0, "miopenGetRNNParamsSize returned nonsensical weight_size.");
    return weight_size / element_size;
}

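// Number of weight matrices MIOpen uses per layer and direction: every gate has
// an input-to-hidden and a hidden-to-hidden matrix, so LSTM = 4 * 2, GRU = 3 * 2,
// and the plain ReLU/tanh RNNs = 1 * 2.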
int64_t _num_linear_layers(miopenRNNMode_t mode) {
    switch (mode) {
        case miopenLSTM:
            return 8;
        case miopenGRU:
            return 6;
        case miopenRNNRELU:
            return 2;
        case miopenRNNTANH:
            return 2;
        default:
            AT_ERROR("Unknown miopen RNN mode : ", mode);
    }
}

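// Builds tensors that view into the flat MIOpen weight buffer so the per-layer
// PyTorch parameters can be copied in and out without extra allocations
// (weights first, then biases when bias_mode is miopenRNNwithBias). MIOpen
// reports an offset per gate, but the gate matrices of a layer are contiguous,
// so only two views are created per layer: one covering the input-to-hidden
// gates and one covering the hidden-to-hidden gates; the remaining offsets are
// only sanity-checked. Returns the views and the number of views per layer.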
std::pair<std::vector<Tensor>, size_t> get_parameters(miopenHandle_t handle, const RNNDescriptorParams& rnn,
                    const RNNDescriptor& rnn_desc, const TensorDescriptor& x_desc, const FilterDescriptor& w_desc,
                    const Tensor& weight_buf)
{
    std::vector<Tensor> params;
    int64_t num_linear_layers = _num_linear_layers(rnn.rnn_mode);
    int64_t num_layers = rnn.num_directions() * rnn.num_layers;
    size_t cur_offset = 0;
    size_t global_layer_params_count = 0;
    auto elem_size = dataSize(getMiopenDataType(weight_buf));
    auto bias_mode = rnn.bias_mode;

    for (const auto layer : c10::irange(num_layers)) {
        size_t layer_params_count = 0;

        // Get layer params
        for (const auto linear_id : c10::irange(num_linear_layers)) {
            FilterDescriptor lin_layer_mat_desc;
            size_t offset;
            MIOPEN_CHECK(miopenGetRNNLayerParamOffset(
                rnn_desc.desc(),
                layer,
                x_desc.desc(),
                linear_id,
                lin_layer_mat_desc.mut_desc(),
                &offset));

            size_t param_size;
            MIOPEN_CHECK(miopenGetRNNLayerParamSize(
                handle,
                rnn_desc.desc(),
                layer,
                x_desc.desc(),
                linear_id,
                &param_size));
            param_size /= elem_size;

            if (linear_id == 0 || linear_id == num_linear_layers / 2) {
                std::initializer_list<int64_t> size = { static_cast<int64_t>(param_size * num_linear_layers / 2), 1L};
                Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
                params.emplace_back(std::move(param));
                layer_params_count++;
            } else {
                TORCH_INTERNAL_ASSERT(cur_offset == offset,
                                      "cur_offset = ", cur_offset, " ; offset = ", offset);
            }
            cur_offset = offset + param_size;
        }

        // Get bias params
        if (bias_mode == miopenRNNwithBias) {
            for (const auto linear_id : c10::irange(num_linear_layers)) {
                FilterDescriptor lin_layer_mat_desc;
                size_t offset;
                MIOPEN_CHECK(miopenGetRNNLayerBiasOffset(
                    rnn_desc.desc(),
                    layer,
                    x_desc.desc(),
                    linear_id,
                    lin_layer_mat_desc.mut_desc(),
                    &offset));

                size_t bias_size;
                MIOPEN_CHECK(miopenGetRNNLayerBiasSize(
                    handle,
                    rnn_desc.desc(),
                    layer,
                    linear_id,
                    &bias_size));
                bias_size /= elem_size;

                if (linear_id == 0 || linear_id == num_linear_layers / 2) {
                    std::initializer_list<int64_t> size = { static_cast<int64_t>(bias_size * num_linear_layers / 2), 1L};
                    Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
                    params.emplace_back(std::move(param));
                    layer_params_count++;
                } else {
                    TORCH_INTERNAL_ASSERT(cur_offset == offset,
                                          "cur_offset = ", cur_offset, " ; offset = ", offset);
                }
                cur_offset = offset + bias_size;
            }
        }

        if (layer == 0) {
            global_layer_params_count = layer_params_count;
        } else {
            TORCH_INTERNAL_ASSERT(global_layer_params_count == layer_params_count,
                                  "global_layer_params_count = ", global_layer_params_count,
                                  "; layer_params_count = ", layer_params_count);
        }
    } // layer
    return std::make_pair(params, global_layer_params_count);
}

std::vector<int64_t> _input_size(const TensorDescriptorListParams& tensors) {
    if (tensors.is_input_packed()) {
        return {tensors.batch_sizes_sum, tensors.input_size};
    } else {
        return {tensors.seq_length, tensors.mini_batch, tensors.input_size};
    }
}

std::vector<int64_t> _hidden_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
    return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size};
}

std::vector<int64_t> _output_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
    if (tensors.is_input_packed()) {
        return {tensors.batch_sizes_sum, rnn.hidden_size * rnn.num_directions()};
    } else {
        return {tensors.seq_length, tensors.mini_batch, rnn.hidden_size * rnn.num_directions()};
    }
}

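// Forward entry point. Builds the MIOpen descriptors, packs the PyTorch weights
// into a single flat weight buffer (permuting the gate order for LSTM/GRU), and
// runs miopenRNNForwardTraining or miopenRNNForwardInference. Returns
// (output, hy, cy, reserve, weight_buf); the reserve buffer and the flat weight
// buffer are carried over to the backward pass.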
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
        const Tensor& input_r, TensorList weight, int64_t weight_stride0,
        const Tensor& hx, const std::optional<Tensor>& cx_opt,
        int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers,
        bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional,
        IntArrayRef fn_batch_sizes, const std::optional<Tensor>& fn_dropout_state_opt
        ) {
    // See [Note: hacky wrapper removal for optional tensor]
    c10::MaybeOwned<Tensor> cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt);
    const Tensor& cx = *cx_maybe_owned;
    const Tensor& fn_dropout_state = c10::value_or_else(fn_dropout_state_opt, [] {return Tensor();});

    check_attributes(input_r, weight, {hx, cx});
    auto input = input_r;

    RNNParams fn;
    auto datatype = getMiopenDataType(input);
    miopenRNNBiasMode_t bias_mode = (weight_stride0 == 4) ? miopenRNNwithBias : miopenRNNNoBias;
    fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, bias_mode);
    fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

    if (fn.rnn.rnn_mode != miopenLSTM) {
        TORCH_CHECK(!cx.defined(), "miopen_rnn: illegal defined cx for non-LSTM RNN.");
    }

    auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
    if (batch_first && !is_input_packed) {
        input = input.transpose(0, 1);
    }

    auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
    auto output_size = _output_size(fn.rnn, fn.tensors);

    TORCH_CHECK(hx.is_contiguous(), "miopen_rnn : hx is not contiguous.");
    TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "miopen_rnn : cx is not contiguous.");

    auto x = input.contiguous();
    auto output = at::empty(output_size, input.options());
    auto hy = at::empty(hidden_size, hx.options());
    Tensor cy;
    if (cx.defined()) {
        cy = at::empty(hidden_size, cx.options());
    } else {
        cy = at::empty({0}, hx.options());
    }

    auto y = output;
    auto handle = getMiopenHandle();
    miopenRNNAlgo_t algo = miopenRNNdefault;
    fn.rnn.set_algo(algo);

    RNNDescriptors descs(fn, handle, x, y, hx, cx);

    FilterDescriptor w_desc;
    auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
    auto weight_buf = at::empty(num_weights, x.options());
    w_desc.set(weight_buf, 3);
    weight_buf.zero_();
    auto [params, params_stride0] = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, weight_buf);
    if (fn_mode < 2)
        _copyParams(MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
                MatrixRef<Tensor>{params, params_stride0});
    else
        _copyParams_and_permute(MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
                    MatrixRef<Tensor>{params, params_stride0}, fn_mode);

    TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size), "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());

    size_t workspace_size;
    auto x_descs_arr = descs.get_x_descs();
    auto y_descs_arr = descs.get_y_descs();

    // Allocate the workspace.
    MIOPEN_CHECK(miopenGetRNNWorkspaceSize(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &workspace_size));
    auto workspace = at::empty(workspace_size, input.options().dtype(kByte));

    // Train or inference.
    Tensor reserve;
    if (fn_train) { // Train.
        size_t reserver_size;
        MIOPEN_CHECK(miopenGetRNNTrainingReserveSize(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &reserver_size));
        reserve = at::empty(reserver_size, input.options().dtype(kByte));
        MIOPEN_CHECK(miopenRNNForwardTraining(handle, descs.rnn_desc.desc(), fn.tensors.seq_length,
                x_descs_arr.data(), x.data_ptr(),
                descs.hx_desc.desc(), hx.data_ptr(),
                descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
                w_desc.desc(), weight_buf.data_ptr(),
                y_descs_arr.data(), y.data_ptr(),
                descs.hy_desc.desc(), hy.data_ptr(),
                descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr,
                workspace.data_ptr(), workspace_size, reserve.mutable_data_ptr(), reserver_size));
    } else { // Inference.
        reserve = at::empty({0}, input.options().dtype(kByte));
        MIOPEN_CHECK(miopenRNNForwardInference(handle, descs.rnn_desc.desc(), fn.tensors.seq_length,
                x_descs_arr.data(), x.data_ptr(),
                descs.hx_desc.desc(), hx.data_ptr(),
                descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
                w_desc.desc(), weight_buf.data_ptr(),
                y_descs_arr.data(), y.data_ptr(),
                descs.hy_desc.desc(), hy.data_ptr(),
                descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr,
                workspace.data_ptr(), workspace_size));
    }

    if (batch_first && !is_input_packed) {
        output.transpose_(0, 1);
    }

    return std::make_tuple(output, hy, cy, reserve, weight_buf);
}

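// Backward pass w.r.t. the inputs: runs miopenRNNBackwardData to produce dx,
// dhx and (for LSTM) dcx. Also returns the workspace it allocated, since
// miopen_rnn_backward_weight reuses it for the weight-gradient pass.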
std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
        const Tensor& input_r, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
        const Tensor& output_r, const Tensor& grad_output_r, const Tensor& grad_hy,
        const Tensor& grad_cy,
        int64_t fn_mode, int64_t fn_hidden_size,
        int64_t fn_num_layers, bool batch_first, double fn_dropout,
        bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes,
        const Tensor& fn_dropout_state, const Tensor& fn_reserve,
        std::array<bool, 3> output_mask
        ) {
    auto input = input_r;
    auto grad_output = grad_output_r;
    auto output = output_r;

    RNNParams fn;
    auto datatype = getMiopenDataType(input);
    fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, miopenRNNwithBias);
    fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

    auto handle = getMiopenHandle();

    if (fn.rnn.rnn_mode != miopenLSTM) {
        TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
    }

    auto is_input_packed = fn_batch_sizes.size() != 0;
    if (batch_first && !is_input_packed) {
        input = input.transpose(0, 1);
        grad_output = grad_output.transpose(0, 1);
        output = output.transpose(0, 1);
    }

    auto input_size = _input_size(fn.tensors);
    auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
    auto output_size = _output_size(fn.rnn, fn.tensors);

    TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
    TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");

    auto x = input.contiguous();
    auto dy = grad_output.contiguous();
    auto y = output;
    auto w = weight_buf;
    auto dx = at::empty(input.sizes(), input.options());
    auto dhy = grad_hy.contiguous().view(hidden_size);
    auto dcy = grad_cy.defined() ? grad_cy.contiguous().view(hidden_size) : Tensor();
    auto dhx = at::empty(hidden_size, hx.options());
    TORCH_INTERNAL_ASSERT(cx.defined() || !output_mask[2],
                          "illegally required grad of cx for non-LSTM RNN");
    auto dcx = cx.defined() ? at::empty(hidden_size, cx.options()) : Tensor();

    TORCH_CHECK(fn_train, "miopen RNN backward can only be called in training mode");

    TORCH_CHECK(input.sizes().equals(input_size),
        "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
    TORCH_CHECK(output.sizes().equals(output_size),
        "Expected output size ", IntArrayRef{output_size}, ", got ", output.sizes());

    TORCH_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
        "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());
    TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
        "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());
    TORCH_CHECK(!dhy.defined() || dhy.sizes().equals(hidden_size),
        "Expected d_hidden size ", IntArrayRef{hidden_size}, ", got ", dhy.sizes());
    TORCH_CHECK(!dcy.defined() || dcy.sizes().equals(hidden_size),
        "Expected d_cell size ", IntArrayRef{hidden_size}, ", got ", dcy.sizes());

    TORCH_CHECK(dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
        "Gradients aren't HIP tensors");

    miopenRNNAlgo_t algo = miopenRNNdefault;
    fn.rnn.set_algo(algo);
    RNNDescriptors descs(fn, handle, x, y, hx, cx);

    FilterDescriptor w_desc;
    w_desc.set(weight_buf, 3);

    size_t workspace_size;
    auto x_descs_arr = descs.get_x_descs();
    auto y_descs_arr = descs.get_y_descs();

    MIOPEN_CHECK(miopenGetRNNWorkspaceSize(
        handle,
        descs.rnn_desc.desc(),
        fn.tensors.seq_length,
        x_descs_arr.data(),
        &workspace_size
        ));
    auto workspace = at::empty(workspace_size, input.options().dtype(kByte));

    MIOPEN_CHECK(miopenRNNBackwardData(
        handle,
        descs.rnn_desc.desc(),
        fn.tensors.seq_length,
        y_descs_arr.data(), y.data_ptr(),
        y_descs_arr.data(), dy.data_ptr(),
        descs.hy_desc.desc(), dhy.data_ptr(),
        descs.cy_desc.desc(), cx.defined() ? dcy.data_ptr() : nullptr,
        w_desc.desc(), w.data_ptr(),
        descs.hx_desc.desc(), hx.data_ptr(),
        descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
        x_descs_arr.data(), dx.data_ptr(),
        descs.hx_desc.desc(), dhx.data_ptr(),
        descs.cx_desc.desc(), cx.defined() ? dcx.data_ptr() : nullptr,
        workspace.data_ptr(), workspace.size(0),
        fn_reserve.data_ptr(), fn_reserve.size(0)
        ));

    if (batch_first && !is_input_packed) {
        dx = dx.transpose_(0, 1);
    }

    return std::make_tuple(dx, dhx, dcx, workspace);
}

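// Backward pass w.r.t. the weights: runs miopenRNNBackwardWeights into a flat
// gradient buffer, then exposes it as per-parameter tensors. When the number of
// parameters per layer matches PyTorch's layout, the views into the buffer are
// reshaped to the weight shapes and returned directly; otherwise the gradients
// are copied into freshly allocated tensors of the expected shapes.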
std::vector<Tensor> miopen_rnn_backward_weight(
        const Tensor& input_r, TensorList weight_arr, int64_t weight_stride0,
        const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
        const Tensor& output_r,
        int64_t fn_mode, int64_t fn_hidden_size,
        int64_t fn_num_layers, bool batch_first, double fn_dropout,
        bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes,
        const Tensor& fn_dropout_state, const Tensor& fn_reserve, const Tensor& fn_workspace
        ) {
    MatrixRef<Tensor> weight{ weight_arr, static_cast<size_t>(weight_stride0) };

    auto input = input_r;
    auto output = output_r;

    RNNParams fn;
    auto datatype = getMiopenDataType(input);
    miopenRNNBiasMode_t bias_mode = (weight_stride0 == 4) ? miopenRNNwithBias : miopenRNNNoBias;
    fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, bias_mode);
    fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

    auto handle = getMiopenHandle();

    if (fn.rnn.rnn_mode != miopenLSTM) {
        TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
    }

    auto is_input_packed = fn_batch_sizes.size() != 0;
    if (batch_first && !is_input_packed) {
        input = input.transpose(0, 1);
        output = output.transpose(0, 1);
    }

    auto input_size = _input_size(fn.tensors);
    auto hidden_size = _hidden_size(fn.rnn, fn.tensors);

    TORCH_CHECK(fn_train, "miopen RNN backward can only be called in training mode");

    TORCH_CHECK(input.sizes().equals(input_size),
        "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
    TORCH_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
        "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());

    TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
    TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");

    auto x = input.contiguous();
    const auto& y = output;
    auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());

    miopenRNNAlgo_t algo = miopenRNNdefault;
    fn.rnn.set_algo(algo);
    RNNDescriptors descs(fn, handle, x, y, hx, cx);

    FilterDescriptor w_desc;
    w_desc.set(weight_buf, 3);

    auto x_descs_arr = descs.get_x_descs();
    auto y_descs_arr = descs.get_y_descs();

    MIOPEN_CHECK(miopenRNNBackwardWeights(
        handle,
        descs.rnn_desc.desc(),
        fn.tensors.seq_length,
        x_descs_arr.data(), x.data_ptr(),
        descs.hx_desc.desc(), hx.data_ptr(),
        y_descs_arr.data(), y.data_ptr(),
        w_desc.desc(), dw.data_ptr(),
        fn_workspace.data_ptr(), fn_workspace.size(0),
        fn_reserve.data_ptr(), fn_reserve.size(0)
        ));

    auto [grad_params_arr, grad_params_stride0] = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, dw);
    if (grad_params_stride0 == static_cast<size_t>(weight_stride0)) {
        _viewParams(MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
            MatrixRef<Tensor>{weight_arr, static_cast<size_t>(weight_stride0)});
        return grad_params_arr;
    } else {
        std::vector<Tensor> grad_weight_arr;
        grad_weight_arr.reserve(weight.numel());
        for (const auto& w : weight_arr) {
            grad_weight_arr.emplace_back(at::empty(w.sizes(), w.options()));
        }
        _copyParams(MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
            MatrixRef<Tensor>{grad_weight_arr, static_cast<size_t>(weight_stride0)});
        return grad_weight_arr;
    }
}

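// Combined backward entry point. Fills in zero gradients for any undefined grad
// inputs, dispatches to the input- and (optionally) weight-gradient passes
// above, and permutes the weight gradients back to PyTorch's gate order for
// LSTM/GRU (mode > 1).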
std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> miopen_rnn_backward(
        const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const std::optional<Tensor>& cx_opt,
        const Tensor& output, const std::optional<Tensor>& grad_output_r_opt, const std::optional<Tensor>& grad_hy_r_opt, const std::optional<Tensor>& grad_cy_r_opt, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first,
        double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const std::optional<Tensor>& dropout_state_opt,
        const Tensor& reserve, std::array<bool, 4> output_mask
        ) {
    // See [Note: hacky wrapper removal for optional tensor]
    c10::MaybeOwned<Tensor> cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt);
    const Tensor& cx = *cx_maybe_owned;
    const Tensor& grad_output_r = c10::value_or_else(grad_output_r_opt, [] {return Tensor();});
    const Tensor& grad_hy_r = c10::value_or_else(grad_hy_r_opt, [] {return Tensor();});
    const Tensor& grad_cy_r = c10::value_or_else(grad_cy_r_opt, [] {return Tensor();});
    const Tensor& dropout_state = c10::value_or_else(dropout_state_opt, [] {return Tensor();});

    if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) {
        return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>(Tensor(), Tensor(), Tensor(), std::vector<Tensor>(weight.size()));
    }
    auto grad_output = grad_output_r.defined() ? grad_output_r : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto grad_cy = cx.defined() ? (grad_cy_r.defined() ? grad_cy_r : at::zeros_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT)) : grad_cy_r;

    auto [dx, dhx, dcx, ws] = at::native::miopen_rnn_backward_input(input, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, {output_mask[0], output_mask[1], output_mask[2]});
    std::vector<Tensor> dw;
    if (output_mask[3]) {
        dw = at::native::miopen_rnn_backward_weight(input, weight, weight_stride0, weight_buf, hx, cx, output, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, ws);
        if (mode > 1) {
            for (const auto i : c10::irange(dw.size())) {
                dw[i] = permute_wei_for_miopen(dw[i], mode);
            }
        }
    }
    return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>{dx, dhx, dcx, dw};
}

namespace {

std::tuple<Tensor, Tensor> unpack_hidden(const Tensor& hidden) {
    return std::make_tuple(hidden, at::Tensor{});
}

std::tuple<Tensor, Tensor> unpack_hidden(const std::tuple<Tensor, Tensor>& hidden) {
    return hidden;
}

template<typename hidden_type>
hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
    static_assert(std::is_same<hidden_type, void>::value, "pack_hidden not implemented for this type");
    AT_ERROR("NOT IMPLEMENTED");
}

template<>
Tensor pack_hidden<Tensor>(const Tensor& hx, const Tensor& cx) {
    AT_ASSERT(cx.numel() == 0);
    return hx;
}

template<>
std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor& hx, const Tensor& cx) {
    return std::make_tuple(hx, cx);
}

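// Glue between the generic ATen RNN frontend and at::miopen_rnn. The hidden
// state is either a single Tensor (GRU / plain RNN) or an (h, c) tuple (LSTM);
// unpack_hidden/pack_hidden convert between the two representations, and the
// weight stride is 4 per layer with biases, 2 without. The first overload takes
// a packed sequence (explicit batch_sizes), the second a padded batch.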
template<typename hidden_type>
std::pair<Tensor, hidden_type> _miopen_impl(
    const Tensor& input, const Tensor& _batch_sizes, const hidden_type& hidden,
    TensorList params, bool has_biases, miopenRNNMode_t mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
    auto [hx, cx] = unpack_hidden(hidden);
    int64_t hidden_size = hx.size(2);

    TORCH_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D");
    IntArrayRef batch_sizes { _batch_sizes.data_ptr<int64_t>(), static_cast<size_t>(_batch_sizes.size(0)) };

    Tensor dropout_state = at::empty({0}, input.options());

    auto miopen_output = at::miopen_rnn(
        input, params, has_biases ? 4 : 2,
        hx, cx, static_cast<int>(mode), hidden_size, num_layers, /*batch_first=*/false,
        dropout_p, train, bidirectional, batch_sizes, dropout_state);

    return {std::get<0>(miopen_output),
        pack_hidden<hidden_type>(std::get<1>(miopen_output), std::get<2>(miopen_output))};
}

template<typename hidden_type>
std::pair<Tensor, hidden_type> _miopen_impl(
    const Tensor& input, const hidden_type& hidden,
    TensorList params, bool has_biases, miopenRNNMode_t mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
    auto [hx, cx] = unpack_hidden(hidden);
    int64_t hidden_size = hx.size(2);

    Tensor dropout_state = at::empty({0}, input.options());

    auto miopen_output = at::miopen_rnn(
        input, params, has_biases ? 4 : 2,
        hx, cx, static_cast<int>(mode), hidden_size, num_layers, batch_first, dropout_p,
        train, bidirectional, /*batch_sizes=*/{}, dropout_state);

    return {std::get<0>(miopen_output),
        pack_hidden<hidden_type>(std::get<1>(miopen_output), std::get<2>(miopen_output))};
}

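// Generates the dispatch-stub implementations (padded and packed variants) for
// RNN flavours with a single hidden state and registers them. LSTM is handled
// separately below because it also returns a cell state.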
#define ONE_HIDDEN_RNN(NAME, MODE)                                             \
void NAME##_miopen(Tensor& output, Tensor& hy,                                 \
      const Tensor& input, const Tensor& hx,                                   \
      TensorList params, bool has_biases,                                      \
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { \
  std::tie(output, hy) = _miopen_impl(input, hx, params, has_biases,           \
      MODE, num_layers, dropout_p, train, bidirectional, batch_first);         \
}                                                                              \
                                                                               \
void NAME##_packed_miopen(Tensor& output, Tensor& hy,                          \
      const Tensor& data, const Tensor& batch_sizes, const Tensor& hx,         \
      TensorList params, bool has_biases,                                      \
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {  \
  std::tie(output, hy) = _miopen_impl(data, batch_sizes, hx, params,           \
      has_biases, MODE, num_layers, dropout_p, train, bidirectional);          \
}                                                                              \
                                                                               \
REGISTER_CUDA_DISPATCH(NAME##_miopen_stub, &NAME##_miopen);                    \
REGISTER_CUDA_DISPATCH(NAME##_packed_miopen_stub, &NAME##_packed_miopen);

ONE_HIDDEN_RNN(gru, miopenGRU)
ONE_HIDDEN_RNN(rnn_tanh, miopenRNNTANH)
ONE_HIDDEN_RNN(rnn_relu, miopenRNNRELU)

void lstm_miopen(Tensor& output, Tensor& hy, Tensor& cy,
      const Tensor& input, TensorList hx,
      TensorList params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
    auto result = _miopen_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
        miopenLSTM, num_layers, dropout_p, train, bidirectional, batch_first);
    output = result.first;
    hy = std::get<0>(result.second);
    cy = std::get<1>(result.second);
}

void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
      const Tensor& data, const Tensor& batch_sizes, TensorList hx,
      TensorList params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
    auto result = _miopen_impl(data, batch_sizes, std::make_tuple(hx[0], hx[1]),
        params, has_biases, miopenLSTM, num_layers, dropout_p, train, bidirectional);
    output = result.first;
    hy = std::get<0>(result.second);
    cy = std::get<1>(result.second);
}

REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen);
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen);

} // anonymous namespace
}} // namespace at::native

#endif