#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/MatrixRef.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/RNN.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/miopen_rnn.h>
#include <ATen/ops/miopen_rnn_backward_native.h>
#include <ATen/ops/miopen_rnn_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#if !AT_ROCM_ENABLED()

namespace at { namespace native {

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
    const Tensor& input_r, TensorList weight, int64_t weight_stride0,
    const Tensor& hx, const std::optional<Tensor>& cx_opt,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers,
    bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional,
    IntArrayRef fn_batch_sizes, const std::optional<Tensor>& fn_dropout_state_opt) {
  AT_ERROR("miopen_rnn : ATen not compiled with MIOpen support.");
}

std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> miopen_rnn_backward(
    const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf,
    const Tensor& hx, const std::optional<Tensor>& cx_opt,
    const Tensor& output, const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt, const std::optional<Tensor>& grad_cy_r_opt,
    int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout,
    bool train, bool bidirectional, IntArrayRef batch_sizes,
    const std::optional<Tensor>& dropout_state_opt, const Tensor& reserve,
    std::array<bool, 4> output_mask) {
  AT_ERROR("miopen_rnn_backward: ATen not compiled with MIOpen support.");
}

}} // namespace at::native

#else // AT_ROCM_ENABLED()

#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/miopen/Descriptors.h>
#include <ATen/miopen/Exceptions.h>
#include <ATen/miopen/Types.h>
#include <ATen/miopen/Utils.h>
#include <ATen/miopen/Handle.h>

#include <functional>
#include <iterator>
#include <sstream>
#include <algorithm>
#include <memory>
#include <string>
#include <tuple>

namespace at { namespace native {

// RNNDescriptor.
struct RNNDescriptorParams {
  int64_t hidden_size;
  int64_t num_layers;
  miopenRNNDirectionMode_t direction;
  miopenRNNMode_t rnn_mode;
  miopenDataType_t datatype;
  miopenRNNAlgo_t algo = miopenRNNdefault;
  miopenRNNInputMode_t input_mode = miopenRNNlinear;
  miopenRNNBiasMode_t bias_mode = miopenRNNNoBias;

  int64_t num_directions() const {
    return (direction == miopenRNNbidirection) ? 2 : 1;
  }

  void set_bidirectional(bool fn_bidirectional) {
    direction = fn_bidirectional ? miopenRNNbidirection : miopenRNNunidirection;
  }

  void set_algo(miopenRNNAlgo_t algo) {
    this->algo = algo;
  }

  void set_mode(int64_t fn_mode) {
    switch (fn_mode) {
      case 0:
        rnn_mode = miopenRNNRELU;
        break;
      case 1:
        rnn_mode = miopenRNNTANH;
        break;
      case 2:
        rnn_mode = miopenLSTM;
        break;
      case 3:
        rnn_mode = miopenGRU;
        break;
      default: {
        std::ostringstream oss;
        oss << "unrecognized miopen RNN mode " << fn_mode;
        AT_ERROR(oss.str());
      }
    }
  }

  void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional,
      miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
    this->set_mode(mode);
    this->hidden_size = hidden_size;
    this->num_layers = num_layers;
    this->set_bidirectional(bidirectional);
    this->datatype = datatype;
    this->bias_mode = bias_mode;
  }

  RNNDescriptor descriptor() const {
    RNNDescriptor rnn_desc;
    rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
    return rnn_desc;
  }
};

// TensorDescriptor list.
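// For packed (variable-length) input, MIOpen is fed one TensorDescriptor per
// time step, with the batch dimension set to that step's batch size. For
// padded input, the same descriptor is simply repeated seq_length times.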
std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
  std::vector<TensorDescriptor> descriptors(batch_sizes.size());
  size_t i = 0;
  auto batch_tensor_size = tensor.sizes().vec();
  for (auto batch_size : batch_sizes) {
    batch_tensor_size[0] = batch_size;
    descriptors[i].set(getMiopenDataType(tensor), batch_tensor_size, tensor.strides(), 3);
    i++;
  }
  return descriptors;
}

std::vector<TensorDescriptor> rnn_descriptor(const Tensor& tensor, int64_t N) {
  std::vector<TensorDescriptor> descriptors(N);
  for (const auto i : c10::irange(N)) {
    descriptors[i].set(tensor, 5);
  }
  return descriptors;
}

struct TensorDescriptorListParams {
  IntArrayRef batch_sizes;
  int64_t seq_length;
  int64_t mini_batch;
  int64_t input_size;
  int64_t batch_sizes_sum;

  bool is_input_packed() const {
    return batch_sizes.size() != 0;
  }

  void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) {
    batch_sizes = batch_sizes_;
    if (is_input_packed()) {
      seq_length = batch_sizes.size();
      mini_batch = batch_sizes[0];
      batch_sizes_sum = input_sizes[0];
      input_size = input_sizes[1];
    } else {
      if (batch_first) {
        seq_length = input_sizes[1];
        mini_batch = input_sizes[0];
      } else {
        seq_length = input_sizes[0];
        mini_batch = input_sizes[1];
      }
      input_size = input_sizes[2];
      batch_sizes_sum = -1;
    }
  }

  std::vector<TensorDescriptor> descriptors(Tensor x) const {
    auto is_input_packed = batch_sizes.size() != 0;
    if (is_input_packed) {
      return rnn_descriptor_sequence(x, batch_sizes);
    } else {
      return rnn_descriptor(x[0], seq_length);
    }
  }
};

struct RNNParams {
  RNNDescriptorParams rnn;
  TensorDescriptorListParams tensors;
};

struct RNNDescriptors {
  RNNDescriptor rnn_desc;
  std::vector<TensorDescriptor> x_descs;
  std::vector<TensorDescriptor> y_descs;
  TensorDescriptor hx_desc;
  TensorDescriptor hy_desc;
  TensorDescriptor cx_desc;
  TensorDescriptor cy_desc;

  RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
    rnn_desc = fn.rnn.descriptor();
    x_descs = fn.tensors.descriptors(x);
    y_descs = fn.tensors.descriptors(y);
    hx_desc.set(hx, 5);
    hy_desc.set(hx, 5);
    cx_desc.set(hx, 5);
    cy_desc.set(hx, 5);
  }

  std::vector<miopenTensorDescriptor_t> get_descs(const std::vector<TensorDescriptor>& descs) {
    std::vector<miopenTensorDescriptor_t> r;
    r.reserve(descs.size());
    for (auto& desc : descs) {
      r.emplace_back(desc.desc());
    }
    return r;
  }

  std::vector<miopenTensorDescriptor_t> get_x_descs() {
    return get_descs(x_descs);
  }

  std::vector<miopenTensorDescriptor_t> get_y_descs() {
    return get_descs(y_descs);
  }
};

Tensor permute_wei_for_miopen(Tensor wei, int64_t mode) {
  if (mode < 2)
    return wei;

  Tensor permuted_wei;
  if (mode == 2) { // LSTM
    auto sliced_tensor = wei.chunk(4, 0);
    permuted_wei = at::cat({sliced_tensor[0], sliced_tensor[1], sliced_tensor[3], sliced_tensor[2]});
  } else if (mode == 3) { // GRU
    auto sliced_tensor = wei.chunk(3, 0);
    permuted_wei = at::cat({sliced_tensor[1], sliced_tensor[0], sliced_tensor[2]});
  }
  return permuted_wei;
}
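// Helpers for moving parameters between the user-visible per-layer weight
// tensors and the flat MIOpen weight buffer. permute_wei_for_miopen above
// reorders gate blocks before copying: PyTorch stores LSTM weights as
// (i, f, g, o) and GRU weights as (r, z, n); the chunk swaps produce the
// (i, f, o, g) and (z, r, n) orders consumed by MIOpen.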
void _viewOrCopyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to, bool copy) {
  TORCH_CHECK(params_from.size(0) == params_to.size(0), "number of layers mismatch");
  for (const auto i : c10::irange(params_from.size(0))) {
    auto layer_params_from = params_from[i];
    auto layer_params_to = params_to[i];
    // NOTE: these lists have all weights before all biases, so if the layer
    // doesn't use biases, iteration will terminate once layer_params_from ends
    // and ignore them.
    for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
         a != layer_params_from.end() && b != layer_params_to.end();
         ++a, ++b) {
      auto param_from = *a, param_to = *b;
      TORCH_CHECK(param_from.type() == param_to.type(), "parameter types mismatch");
      if (copy) {
        param_to.copy_(param_from.view_as(param_to));
      } else {
        param_from.resize_as_(param_to);
      }
    }
  }
}

void _copyParams_and_permute(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to, int64_t mode) {
  TORCH_CHECK(params_from.size(0) == params_to.size(0), "number of layers mismatch");
  for (const auto i : c10::irange(params_from.size(0))) {
    auto layer_params_from = params_from[i];
    auto layer_params_to = params_to[i];
    for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
         a != layer_params_from.end() && b != layer_params_to.end();
         ++a, ++b) {
      auto param_from = *a, param_to = *b;
      TORCH_CHECK(param_from.type() == param_to.type(), "parameter types mismatch");
      auto tmp = permute_wei_for_miopen(param_from, mode);
      param_to.copy_(tmp.view_as(param_to));
    }
  }
}

void _copyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
  _viewOrCopyParams(params_from, params_to, true);
}

void _viewParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
  _viewOrCopyParams(params_from, params_to, false);
}

int64_t get_num_weights(miopenHandle_t handle, const RNNDescriptor& rnn_desc,
    const TensorDescriptor& x_desc, miopenDataType_t datatype) {
  size_t weight_size;
  MIOPEN_CHECK(miopenGetRNNParamsSize(handle, rnn_desc.desc(), x_desc.desc(), &weight_size, datatype));
  auto element_size = dataSize(datatype);
  TORCH_CHECK(weight_size % element_size == 0, "miopenGetRNNParamsSize returned nonsensical weight_size.");
  return weight_size / element_size;
}

int64_t _num_linear_layers(miopenRNNMode_t mode) {
  switch (mode) {
    case miopenLSTM:
      return 8;
    case miopenGRU:
      return 6;
    case miopenRNNRELU:
      return 2;
    case miopenRNNTANH:
      return 2;
    default:
      AT_ERROR("Unknown miopen RNN mode : ", mode);
  }
}
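// get_parameters carves the flat MIOpen weight buffer into tensors that alias
// its storage. MIOpen reports one matrix (and optionally one bias) per linear
// layer; the matrices of each half (input-hidden and hidden-hidden) are laid
// out contiguously, so only the first matrix of each half is materialized,
// sized to cover all gates, and the remaining offsets are asserted to follow
// contiguously.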
std::pair<std::vector<Tensor>, size_t> get_parameters(miopenHandle_t handle, const RNNDescriptorParams& rnn,
    const RNNDescriptor& rnn_desc, const TensorDescriptor& x_desc, const FilterDescriptor& w_desc,
    const Tensor& weight_buf) {
  std::vector<Tensor> params;
  int64_t num_linear_layers = _num_linear_layers(rnn.rnn_mode);
  int64_t num_layers = rnn.num_directions() * rnn.num_layers;
  size_t cur_offset = 0;
  size_t global_layer_params_count = 0;
  auto elem_size = dataSize(getMiopenDataType(weight_buf));
  auto bias_mode = rnn.bias_mode;

  for (const auto layer : c10::irange(num_layers)) {
    size_t layer_params_count = 0;

    // Get layer params
    for (const auto linear_id : c10::irange(num_linear_layers)) {
      FilterDescriptor lin_layer_mat_desc;
      size_t offset;
      MIOPEN_CHECK(miopenGetRNNLayerParamOffset(
          rnn_desc.desc(),
          layer,
          x_desc.desc(),
          linear_id,
          lin_layer_mat_desc.mut_desc(),
          &offset));

      size_t param_size;
      MIOPEN_CHECK(miopenGetRNNLayerParamSize(
          handle,
          rnn_desc.desc(),
          layer,
          x_desc.desc(),
          linear_id,
          &param_size));
      param_size /= elem_size;

      if (linear_id == 0 || linear_id == num_linear_layers / 2) {
        std::initializer_list<int64_t> size = {
            static_cast<int64_t>(param_size * num_linear_layers / 2), 1L};
        Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
        params.emplace_back(std::move(param));
        layer_params_count++;
      } else {
        TORCH_INTERNAL_ASSERT(cur_offset == offset,
            "cur_offset = ", cur_offset, " ; offset = ", offset);
      }
      cur_offset = offset + param_size;
    }

    // Get bias params
    if (bias_mode == miopenRNNwithBias) {
      for (const auto linear_id : c10::irange(num_linear_layers)) {
        FilterDescriptor lin_layer_mat_desc;
        size_t offset;
        MIOPEN_CHECK(miopenGetRNNLayerBiasOffset(
            rnn_desc.desc(),
            layer,
            x_desc.desc(),
            linear_id,
            lin_layer_mat_desc.mut_desc(),
            &offset));

        size_t bias_size;
        MIOPEN_CHECK(miopenGetRNNLayerBiasSize(
            handle,
            rnn_desc.desc(),
            layer,
            linear_id,
            &bias_size));
        bias_size /= elem_size;

        if (linear_id == 0 || linear_id == num_linear_layers / 2) {
          std::initializer_list<int64_t> size = {
              static_cast<int64_t>(bias_size * num_linear_layers / 2), 1L};
          Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
          params.emplace_back(std::move(param));
          layer_params_count++;
        } else {
          TORCH_INTERNAL_ASSERT(cur_offset == offset,
              "cur_offset = ", cur_offset, " ; offset = ", offset);
        }
        cur_offset = offset + bias_size;
      }
    }

    if (layer == 0) {
      global_layer_params_count = layer_params_count;
    } else {
      TORCH_INTERNAL_ASSERT(global_layer_params_count == layer_params_count,
          "global_layer_params_count = ", global_layer_params_count,
          "; layer_params_count = ", layer_params_count);
    }
  } // layer

  return std::make_pair(params, global_layer_params_count);
}

std::vector<int64_t> _input_size(const TensorDescriptorListParams& tensors) {
  if (tensors.is_input_packed()) {
    return {tensors.batch_sizes_sum, tensors.input_size};
  } else {
    return {tensors.seq_length, tensors.mini_batch, tensors.input_size};
  }
}

std::vector<int64_t> _hidden_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
  return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size};
}

std::vector<int64_t> _output_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
  if (tensors.is_input_packed()) {
    return {tensors.batch_sizes_sum, rnn.hidden_size * rnn.num_directions()};
  } else {
    return {tensors.seq_length, tensors.mini_batch, rnn.hidden_size * rnn.num_directions()};
  }
}
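// Forward entry point. Returns (output, hy, cy, reserve, weight_buf): cy is
// empty for non-LSTM modes, reserve is empty in inference mode, and
// weight_buf is the flat buffer the per-layer weights were copied into.
// weight_stride0 == 4 signals that bias tensors are present.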
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
    const Tensor& input_r, TensorList weight, int64_t weight_stride0,
    const Tensor& hx, const std::optional<Tensor>& cx_opt,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers,
    bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional,
    IntArrayRef fn_batch_sizes, const std::optional<Tensor>& fn_dropout_state_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt);
  const Tensor& cx = *cx_maybe_owned;
  const Tensor& fn_dropout_state = c10::value_or_else(fn_dropout_state_opt, [] { return Tensor(); });

  check_attributes(input_r, weight, {hx, cx});
  auto input = input_r;

  RNNParams fn;
  auto datatype = getMiopenDataType(input);
  miopenRNNBiasMode_t bias_mode = (weight_stride0 == 4) ? miopenRNNwithBias : miopenRNNNoBias;
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, bias_mode);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  if (fn.rnn.rnn_mode != miopenLSTM) {
    TORCH_CHECK(!cx.defined(), "miopen_rnn: illegal defined cx for non-LSTM RNN.");
  }

  auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
  }

  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
  auto output_size = _output_size(fn.rnn, fn.tensors);

  TORCH_CHECK(hx.is_contiguous(), "miopen_rnn : hx is not contiguous.");
  TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "miopen_rnn : cx is not contiguous.");

  auto x = input.contiguous();
  auto output = at::empty(output_size, input.options());
  auto hy = at::empty(hidden_size, hx.options());
  Tensor cy;
  if (cx.defined()) {
    cy = at::empty(hidden_size, cx.options());
  } else {
    cy = at::empty({0}, hx.options());
  }

  auto y = output;
  auto handle = getMiopenHandle();
  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);

  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
  auto weight_buf = at::empty(num_weights, x.options());
  w_desc.set(weight_buf, 3);
  weight_buf.zero_();
  auto [params, params_stride0] = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, weight_buf);
  if (fn_mode < 2)
    _copyParams(MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
                MatrixRef<Tensor>{params, params_stride0});
  else
    _copyParams_and_permute(MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
                            MatrixRef<Tensor>{params, params_stride0}, fn_mode);

  TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
      "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());

  size_t workspace_size;
  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();

  //Allocate workspace size.
  MIOPEN_CHECK(miopenGetRNNWorkspaceSize(handle, descs.rnn_desc.desc(), fn.tensors.seq_length,
      x_descs_arr.data(), &workspace_size));
  auto workspace = at::empty(workspace_size, input.options().dtype(kByte));

  //Train or inference.
  Tensor reserve;
  if (fn_train) { //Train.
    size_t reserver_size;
    MIOPEN_CHECK(miopenGetRNNTrainingReserveSize(handle, descs.rnn_desc.desc(), fn.tensors.seq_length,
        x_descs_arr.data(), &reserver_size));
    reserve = at::empty(reserver_size, input.options().dtype(kByte));
    MIOPEN_CHECK(miopenRNNForwardTraining(handle, descs.rnn_desc.desc(), fn.tensors.seq_length,
        x_descs_arr.data(), x.data_ptr(),
        descs.hx_desc.desc(), hx.data_ptr(),
        descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
        w_desc.desc(), weight_buf.data_ptr(),
        y_descs_arr.data(), y.data_ptr(),
        descs.hy_desc.desc(), hy.data_ptr(),
        descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr,
        workspace.data_ptr(), workspace_size,
        reserve.mutable_data_ptr(), reserver_size));
  } else { //Inference.
    reserve = at::empty({0}, input.options().dtype(kByte));
    MIOPEN_CHECK(miopenRNNForwardInference(handle, descs.rnn_desc.desc(), fn.tensors.seq_length,
        x_descs_arr.data(), x.data_ptr(),
        descs.hx_desc.desc(), hx.data_ptr(),
        descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
        w_desc.desc(), weight_buf.data_ptr(),
        y_descs_arr.data(), y.data_ptr(),
        descs.hy_desc.desc(), hy.data_ptr(),
        descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr,
        workspace.data_ptr(), workspace_size));
  }

  if (batch_first && !is_input_packed) {
    output.transpose_(0, 1);
  }

  return std::make_tuple(output, hy, cy, reserve, weight_buf);
}
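// Backward pass w.r.t. the input and the hidden/cell states. Requires the
// reserve buffer produced by miopenRNNForwardTraining; the workspace is also
// returned so miopen_rnn_backward_weight can reuse it.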
std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
    const Tensor& input_r, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
    const Tensor& output_r, const Tensor& grad_output_r, const Tensor& grad_hy,
    const Tensor& grad_cy,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers,
    bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional,
    IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state, const Tensor& fn_reserve,
    std::array<bool, 3> output_mask) {
  auto input = input_r;
  auto grad_output = grad_output_r;
  auto output = output_r;

  RNNParams fn;
  auto datatype = getMiopenDataType(input);
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, miopenRNNwithBias);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  auto handle = getMiopenHandle();

  if (fn.rnn.rnn_mode != miopenLSTM) {
    TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
  }

  auto is_input_packed = fn_batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
    grad_output = grad_output.transpose(0, 1);
    output = output.transpose(0, 1);
  }

  auto input_size = _input_size(fn.tensors);
  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
  auto output_size = _output_size(fn.rnn, fn.tensors);

  TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
  TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");

  auto x = input.contiguous();
  auto dy = grad_output.contiguous();
  auto y = output;
  auto w = weight_buf;
  auto dx = at::empty(input.sizes(), input.options());
  auto dhy = grad_hy.contiguous().view(hidden_size);
  auto dcy = grad_cy.defined() ? grad_cy.contiguous().view(hidden_size) : Tensor();
  auto dhx = at::empty(hidden_size, hx.options());
  TORCH_INTERNAL_ASSERT(cx.defined() || !output_mask[2],
      "illegally required grad of cx for non-LSTM RNN");
  auto dcx = cx.defined() ? at::empty(hidden_size, cx.options()) : Tensor();

  TORCH_CHECK(fn_train, "miopen RNN backward can only be called in training mode");

  TORCH_CHECK(input.sizes().equals(input_size),
      "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
  TORCH_CHECK(output.sizes().equals(output_size),
      "Expected output size ", IntArrayRef{output_size}, ", got ", output.sizes());

  TORCH_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
      "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());
  TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
      "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());
  TORCH_CHECK(!dhy.defined() || dhy.sizes().equals(hidden_size),
      "Expected d_hidden size ", IntArrayRef{hidden_size}, ", got ", dhy.sizes());
  TORCH_CHECK(!dcy.defined() || dcy.sizes().equals(hidden_size),
      "Expected d_cell size ", IntArrayRef{hidden_size}, ", got ", dcy.sizes());

  TORCH_CHECK(dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
      "Gradients aren't HIP tensors");

  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  size_t workspace_size;
  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();

  MIOPEN_CHECK(miopenGetRNNWorkspaceSize(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      x_descs_arr.data(),
      &workspace_size));
  auto workspace = at::empty(workspace_size, input.options().dtype(kByte));

  MIOPEN_CHECK(miopenRNNBackwardData(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      y_descs_arr.data(), y.data_ptr(),
      y_descs_arr.data(), dy.data_ptr(),
      descs.hy_desc.desc(), dhy.data_ptr(),
      descs.cy_desc.desc(), cx.defined() ? dcy.data_ptr() : nullptr,
      w_desc.desc(), w.data_ptr(),
      descs.hx_desc.desc(), hx.data_ptr(),
      descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
      x_descs_arr.data(), dx.data_ptr(),
      descs.hx_desc.desc(), dhx.data_ptr(),
      descs.cx_desc.desc(), cx.defined() ? dcx.data_ptr() : nullptr,
      workspace.data_ptr(), workspace.size(0),
      fn_reserve.data_ptr(), fn_reserve.size(0)));

  if (batch_first && !is_input_packed) {
    dx = dx.transpose_(0, 1);
  }

  return std::make_tuple(dx, dhx, dcx, workspace);
}
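// Backward pass w.r.t. the weights. miopenRNNBackwardWeights accumulates into
// the flat buffer dw; the result is then viewed (or copied) back into tensors
// shaped like the caller's per-layer weights.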
std::vector<Tensor> miopen_rnn_backward_weight(
    const Tensor& input_r, TensorList weight_arr, int64_t weight_stride0,
    const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
    const Tensor& output_r,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers,
    bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional,
    IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state, const Tensor& fn_reserve,
    const Tensor& fn_workspace) {
  MatrixRef<Tensor> weight{ weight_arr, static_cast<size_t>(weight_stride0) };

  auto input = input_r;
  auto output = output_r;

  RNNParams fn;
  auto datatype = getMiopenDataType(input);
  miopenRNNBiasMode_t bias_mode = (weight_stride0 == 4) ? miopenRNNwithBias : miopenRNNNoBias;
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, datatype, bias_mode);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  auto handle = getMiopenHandle();

  if (fn.rnn.rnn_mode != miopenLSTM) {
    TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
  }

  auto is_input_packed = fn_batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
    output = output.transpose(0, 1);
  }

  auto input_size = _input_size(fn.tensors);
  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);

  TORCH_CHECK(fn_train, "miopen RNN backward can only be called in training mode");

  TORCH_CHECK(input.sizes().equals(input_size),
      "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
  TORCH_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
      "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());

  TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
  TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");

  auto x = input.contiguous();
  const auto& y = output;
  auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());

  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();

  MIOPEN_CHECK(miopenRNNBackwardWeights(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      x_descs_arr.data(), x.data_ptr(),
      descs.hx_desc.desc(), hx.data_ptr(),
      y_descs_arr.data(), y.data_ptr(),
      w_desc.desc(), dw.data_ptr(),
      fn_workspace.data_ptr(), fn_workspace.size(0),
      fn_reserve.data_ptr(), fn_reserve.size(0)));

  auto [grad_params_arr, grad_params_stride0] = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, dw);
  if (grad_params_stride0 == static_cast<size_t>(weight_stride0)) {
    _viewParams(MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
                MatrixRef<Tensor>{weight_arr, static_cast<size_t>(weight_stride0)});
    return grad_params_arr;
  } else {
    std::vector<Tensor> grad_weight_arr;
    grad_weight_arr.reserve(weight.numel());
    for (const auto& w : weight_arr) {
      grad_weight_arr.emplace_back(at::empty(w.sizes(), w.options()));
    }
    _copyParams(MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
                MatrixRef<Tensor>{grad_weight_arr, static_cast<size_t>(weight_stride0)});
    return grad_weight_arr;
  }
}
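// Dispatcher-facing backward: fills in zero defaults for undefined gradients,
// runs the data and (optionally) weight backward passes, and permutes the
// LSTM/GRU weight gradients back to PyTorch's gate order.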
std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> miopen_rnn_backward(
    const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf,
    const Tensor& hx, const std::optional<Tensor>& cx_opt,
    const Tensor& output, const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt, const std::optional<Tensor>& grad_cy_r_opt,
    int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout,
    bool train, bool bidirectional, IntArrayRef batch_sizes,
    const std::optional<Tensor>& dropout_state_opt, const Tensor& reserve,
    std::array<bool, 4> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt);
  const Tensor& cx = *cx_maybe_owned;
  const Tensor& grad_output_r = c10::value_or_else(grad_output_r_opt, [] { return Tensor(); });
  const Tensor& grad_hy_r = c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); });
  const Tensor& grad_cy_r = c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); });
  const Tensor& dropout_state = c10::value_or_else(dropout_state_opt, [] { return Tensor(); });

  if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) {
    return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>(
        Tensor(), Tensor(), Tensor(), std::vector<Tensor>(weight.size()));
  }

  auto grad_output = grad_output_r.defined() ? grad_output_r : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_cy = cx.defined()
      ? (grad_cy_r.defined() ? grad_cy_r : at::zeros_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
      : grad_cy_r;

  auto [dx, dhx, dcx, ws] = at::native::miopen_rnn_backward_input(
      input, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy,
      mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional,
      batch_sizes, dropout_state, reserve,
      {output_mask[0], output_mask[1], output_mask[2]});
  std::vector<Tensor> dw;
  if (output_mask[3]) {
    dw = at::native::miopen_rnn_backward_weight(
        input, weight, weight_stride0, weight_buf, hx, cx, output,
        mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional,
        batch_sizes, dropout_state, reserve, ws);
    if (mode > 1) {
      for (const auto i : c10::irange(dw.size())) {
        dw[i] = permute_wei_for_miopen(dw[i], mode);
      }
    }
  }
  return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>{dx, dhx, dcx, dw};
}
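// Glue between the generic RNN dispatch stubs and at::miopen_rnn. The
// unpack_hidden/pack_hidden helpers let the same _miopen_impl template serve
// both LSTM (hidden state is an (hx, cx) tuple) and single-hidden-state RNNs.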
namespace {

std::tuple<Tensor, Tensor> unpack_hidden(const Tensor& hidden) {
  return std::make_tuple(hidden, at::Tensor{});
}

std::tuple<Tensor, Tensor> unpack_hidden(const std::tuple<Tensor, Tensor>& hidden) {
  return hidden;
}

template<typename hidden_type>
hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
  static_assert(std::is_same<hidden_type, void>::value, "pack_hidden not implemented for this type");
  AT_ERROR("NOT IMPLEMENTED");
}

template<>
Tensor pack_hidden<Tensor>(const Tensor& hx, const Tensor& cx) {
  AT_ASSERT(cx.numel() == 0);
  return hx;
}

template<>
std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor& hx, const Tensor& cx) {
  return std::make_tuple(hx, cx);
}

template<typename hidden_type>
std::pair<Tensor, hidden_type> _miopen_impl(
    const Tensor& input, const Tensor& _batch_sizes, const hidden_type& hidden,
    TensorList params, bool has_biases, miopenRNNMode_t mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  auto [hx, cx] = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

  TORCH_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D");
  IntArrayRef batch_sizes { _batch_sizes.data_ptr<int64_t>(), static_cast<size_t>(_batch_sizes.size(0)) };

  Tensor dropout_state = at::empty({0}, input.options());

  auto miopen_output = at::miopen_rnn(
      input, params, has_biases ? 4 : 2,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers, /*batch_first=*/false,
      dropout_p, train, bidirectional, batch_sizes, dropout_state);

  return {std::get<0>(miopen_output),
          pack_hidden<hidden_type>(std::get<1>(miopen_output), std::get<2>(miopen_output))};
}
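// Overload for padded (non-packed) input: batch_sizes is empty and
// batch_first is forwarded to at::miopen_rnn.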
template<typename hidden_type>
std::pair<Tensor, hidden_type> _miopen_impl(
    const Tensor& input, const hidden_type& hidden,
    TensorList params, bool has_biases, miopenRNNMode_t mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto [hx, cx] = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

  Tensor dropout_state = at::empty({0}, input.options());

  auto miopen_output = at::miopen_rnn(
      input, params, has_biases ? 4 : 2,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers, batch_first, dropout_p,
      train, bidirectional, /*batch_sizes=*/{}, dropout_state);

  return {std::get<0>(miopen_output),
          pack_hidden<hidden_type>(std::get<1>(miopen_output), std::get<2>(miopen_output))};
}

#define ONE_HIDDEN_RNN(NAME, MODE)                                                               \
void NAME##_miopen(Tensor& output, Tensor& hy,                                                   \
      const Tensor& input, const Tensor& hx,                                                     \
      TensorList params, bool has_biases,                                                        \
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {  \
  std::tie(output, hy) = _miopen_impl(input, hx, params, has_biases,                             \
      MODE, num_layers, dropout_p, train, bidirectional, batch_first);                           \
}                                                                                                \
                                                                                                 \
void NAME##_packed_miopen(Tensor& output, Tensor& hy,                                            \
      const Tensor& data, const Tensor& batch_sizes, const Tensor& hx,                           \
      TensorList params, bool has_biases,                                                        \
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {                    \
  std::tie(output, hy) = _miopen_impl(data, batch_sizes, hx, params,                             \
      has_biases, MODE, num_layers, dropout_p, train, bidirectional);                            \
}                                                                                                \
                                                                                                 \
REGISTER_CUDA_DISPATCH(NAME##_miopen_stub, &NAME##_miopen);                                      \
REGISTER_CUDA_DISPATCH(NAME##_packed_miopen_stub, &NAME##_packed_miopen);

ONE_HIDDEN_RNN(gru, miopenGRU)
ONE_HIDDEN_RNN(rnn_tanh, miopenRNNTANH)
ONE_HIDDEN_RNN(rnn_relu, miopenRNNRELU)

void lstm_miopen(Tensor& output, Tensor& hy, Tensor& cy,
      const Tensor& input, TensorList hx,
      TensorList params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto result = _miopen_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
      miopenLSTM, num_layers, dropout_p, train, bidirectional, batch_first);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}

void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
      const Tensor& data, const Tensor& batch_sizes, TensorList hx,
      TensorList params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  auto result = _miopen_impl(data, batch_sizes, std::make_tuple(hx[0], hx[1]), params,
      has_biases, miopenLSTM, num_layers, dropout_p, train, bidirectional);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}

REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen);
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen);

} // anonymous namespace

}} // namespace at::native

#endif