#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/MatrixRef.h>
#include <ATen/TensorUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/native/RNN.h>
#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/mkldnn_rnn_layer.h>
#include <ATen/ops/mkldnn_rnn_layer_native.h>
#include <ATen/ops/mkldnn_rnn_layer_backward_native.h>
#endif

#if !AT_MKLDNN_ENABLED()

namespace at::native {

std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(
    const Tensor& input,
    const Tensor& w0,
    const Tensor& w1,
    const Tensor& w2,
    const Tensor& w3,
    const Tensor& hx_,
    const Tensor& cx_,
    bool reverse,
    IntArrayRef batch_sizes,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool bidirectional,
    bool batch_first,
    bool train) {
  AT_ERROR("mkldnn_rnn_layer: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer_backward(
    const Tensor& input,
    const Tensor& weight0,
    const Tensor& weight1,
    const Tensor& weight2,
    const Tensor& weight3,
    const Tensor& hx_,
    const Tensor& cx_tmp,
    const Tensor& output,
    const Tensor& hy_,
    const Tensor& cy_,
    const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt,
    const std::optional<Tensor>& grad_cy_r_opt,
    bool reverse,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool train,
    bool bidirectional,
    at::IntArrayRef batch_sizes,
    bool batch_first,
    const at::Tensor& workspace) {
  AT_ERROR("mkldnn_rnn_layer_backward: ATen not compiled with MKLDNN support");
}

REGISTER_NO_CPU_DISPATCH(lstm_mkldnn_stub);

} // namespace at::native

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

namespace at::native {

struct RNNParams {
  ideep::rnn_kind mode;
  int64_t seq_length;
  int64_t mini_batch;
  int64_t input_size;
  int64_t hidden_size;
  int64_t num_directions;
  int64_t num_layers;
  bool batch_first;
  bool train;
  at::IntArrayRef batch_sizes;
  int64_t num_gates;
  int64_t num_bias_gates;

  RNNParams(
      const at::Tensor& input,
      at::IntArrayRef batch_sizes_,
      int64_t mode_,
      int64_t hidden_size_,
      int64_t num_layers_,
      bool bidirectional,
      bool batch_first_,
      bool train_) {
    mode = static_cast<ideep::rnn_kind>(mode_);
    batch_first = batch_first_;
    seq_length = input.size(0);
    mini_batch = input.size(1);
    input_size = input.size(2);
    hidden_size = hidden_size_;
    num_directions = bidirectional ? 2 : 1;
    num_layers = num_layers_;
    train = train_;
    batch_sizes = batch_sizes_;
    if (mode == ideep::rnn_kind::LSTM) {
      num_gates = 4;
      num_bias_gates = 4;
    } else if (mode == ideep::rnn_kind::GRU) {
      num_gates = 3;
      num_bias_gates = 4;
    } else {
      // RNN_RELU; RNN_TANH
      num_gates = 1;
      num_bias_gates = 1;
    }
  }

  // mkldnn memory descriptors
  using format = ideep::format_tag;
  using desc = ideep::tensor::desc;
  using dtype = ideep::tensor::data_type;
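  // Note on the format tags used below (oneDNN's canonical logical orders):
  // tnc = (time, batch, channels) for layer activations, ldnc = (layers,
  // directions, batch, channels) for the iter states hx/cx/hy/cy,
  // ldigo = (layers, directions, input channels, gates, output channels) and
  // its transposed ldgoi form for weights, and ldgo for bias.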
  desc src_layer_desc(int64_t _input_size, dtype dtype) const {
    return {{seq_length, mini_batch, _input_size}, dtype, format::tnc};
  }
  desc src_iter_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  desc src_iter_c_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  // logical size described as ldigo
  desc weights_layer_desc(int64_t _input_size, dtype dtype) const {
    return {{1, 1, _input_size, num_gates, hidden_size}, dtype, format::ldgoi};
  }
  desc weights_layer_ldigo_desc(int64_t _input_size, dtype dtype) const {
    return {{1, 1, _input_size, num_gates, hidden_size}, dtype, format::ldigo};
  }
  desc weights_iter_desc(dtype dtype) const {
    return {{1, 1, hidden_size, num_gates, hidden_size}, dtype, format::ldgoi};
  }
  desc weights_iter_ldigo_desc(dtype dtype) const {
    return {{1, 1, hidden_size, num_gates, hidden_size}, dtype, format::ldigo};
  }
  desc bias_desc(dtype dtype) const {
    return {{1, 1, num_bias_gates, hidden_size}, dtype, format::ldgo};
  }
  desc dst_layer_desc(dtype dtype) const {
    return {{seq_length, mini_batch, hidden_size}, dtype, format::tnc};
  }
  desc dst_iter_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  desc dst_iter_c_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
};

template <bool is_single_direction>
std::vector<int64_t> _output_size(const RNNParams& rnn) {
  auto output_channels = is_single_direction
      ? rnn.hidden_size
      : rnn.hidden_size * rnn.num_directions;
  return {rnn.seq_length, rnn.mini_batch, output_channels};
}
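// Worked example (illustrative only; the sizes are arbitrary): for a
// single-direction LSTM layer with seq_length = 5, mini_batch = 3,
// input_size = 10 and hidden_size = 20, the descriptors above resolve to
//   src_layer  : {5, 3, 10}        (tnc)
//   src_iter   : {1, 1, 3, 20}     (ldnc), likewise src_iter_c
//   weights_ih : {1, 1, 10, 4, 20} (4 gates for LSTM)
//   weights_hh : {1, 1, 20, 4, 20}
//   bias       : {1, 1, 4, 20}     (ldgo)
//   dst_layer  : {5, 3, 20}        (tnc)
// and _output_size<true>(rnn) returns {5, 3, 20}.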
// MKLDNN GRU gate order is different from PyTorch's, which requires a gate shuffle
// (let rt, zt, nt be the reset, update, new gates respectively)
//
// MKLDNN GRU weight_ih/weight_hh gates order: (zt, rt, nt)
// PyTorch GRU weight_ih/weight_hh gates order: (rt, zt, nt)
//
// MKLDNN GRU bias has 4 gates instead of 3
//  (PyTorch GRU bias)     (MKLDNN GRU bias)
//
//  bias_ih    bias_hh          bias
//  +-----+    +-----+       +---------+
//  | rt1 |    | rt2 |       | zt1+zt2 |
//  |-----|    |-----|       |---------|
//  | zt1 |    | zt2 |       | rt1+rt2 |
//  |-----|    |-----|       |---------|
//  | nt1 |    | nt2 |       |   nt1   |
//  +-----+    +-----+       |---------|
//                           |   nt2   |
//                           +---------+
//
static Tensor _shuffle_weight(const Tensor& weight, int64_t fn_mode) {
  auto weight_t = weight.contiguous();
  if (static_cast<ideep::rnn_kind>(fn_mode) == ideep::rnn_kind::GRU) {
    std::vector<Tensor> gates = weight_t.chunk(3, /*gates*/0);
    return at::cat({gates[1], gates[0], gates[2]}, /*gates*/0);
  }
  return weight_t;
}

static Tensor _shuffle_bias(const Tensor& bias_ih, const Tensor& bias_hh, int64_t fn_mode) {
  if (static_cast<ideep::rnn_kind>(fn_mode) == ideep::rnn_kind::GRU) {
    std::vector<Tensor> b1 = bias_ih.chunk(3, /*output_channels*/0);
    std::vector<Tensor> b2 = bias_hh.chunk(3, /*output_channels*/0);
    return at::cat({b1[1] + b2[1], b1[0] + b2[0], b1[2], b2[2]}, /*output_channels*/0);
  }
  return bias_ih + bias_hh;
}

std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(
    const Tensor& input,
    const Tensor& w0,
    const Tensor& w1,
    const Tensor& w2,
    const Tensor& w3,
    const Tensor& hx_,
    const Tensor& cx_,
    bool reverse,
    IntArrayRef batch_sizes,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool bidirectional,
    bool batch_first,
    bool train) {
  RNNParams rnn(
      input,
      batch_sizes,
      mode,
      hidden_size,
      num_layers,
      bidirectional,
      batch_first,
      train);

  auto output_size = _output_size</*is_single_direction*/true>(rnn);
  auto output = at::empty(output_size, input.options());
  auto hy_ = at::empty(hx_.sizes(), hx_.options());
  auto cy_ = at::empty(cx_.sizes(), cx_.options());

  auto weight_ih = _shuffle_weight(w0, rnn.mode);
  auto weight_hh = _shuffle_weight(w1, rnn.mode);

  // Packed weight will be mkldnn layout while bias won't be packed
  auto bias = has_biases
      ? _shuffle_bias(w2, w3, rnn.mode)
      : at::zeros(
            {rnn.num_bias_gates * rnn.hidden_size},
            weight_ih.options().layout(at::Layout::Strided));

  // per layer input size
  int64_t input_size = input.size(2);
  ideep::tensor w1_, w2_;
  auto x = itensor_view_from_dense(
      input, rnn.src_layer_desc(input_size, get_mkldnn_dtype(input)));
  auto hx = itensor_view_from_dense(
      hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_)));
  auto cx = itensor_view_from_dense(
      cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_)));
  auto b = itensor_view_from_dense(
      bias, rnn.bias_desc(get_mkldnn_dtype(bias)));
  auto y = itensor_view_from_dense(
      output, rnn.dst_layer_desc(get_mkldnn_dtype(output)));
  auto hy = itensor_view_from_dense(
      hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_)));
  auto cy = itensor_view_from_dense(
      cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_)));
  w1_ = weight_ih.is_mkldnn()
      ? itensor_from_tensor(weight_ih)
      : itensor_view_from_dense(
            weight_ih,
            rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
  w2_ = weight_hh.is_mkldnn()
      ? itensor_from_tensor(weight_hh)
      : itensor_view_from_dense(
            weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));
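  // Training/inference split: when autograd is recording we run the oneDNN
  // forward-training primitive and return its opaque workspace (consumed later
  // by mkldnn_rnn_layer_backward); otherwise we run the lighter
  // forward-inference primitive and return an undefined workspace tensor.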
  if (at::GradMode::is_enabled()) {
    Tensor workspace = Tensor();
    auto pd = ideep::lstm_forward_training::prepare(
        x, hx, cx, w1_, w2_, b, y, hy, cy, reverse);
    workspace = at::empty(
        pd.workspace_desc().get_size() / sizeof(uint8_t),
        input.options().dtype(at::kByte));
    ideep::tensor mkldnn_workspace;
    mkldnn_workspace.init(
        pd.workspace_desc(), workspace.template data_ptr<uint8_t>());
    ideep::lstm_forward_training::compute(
        pd, x, hx, cx, w1_, w2_, b, mkldnn_workspace, y, hy, cy,
        reverse, ideep::prop_kind::forward_training);
    return std::make_tuple(output, hy_, cy_, workspace);
  } else {
    ideep::lstm_forward_inference::compute(
        x, hx, cx, w1_, w2_, b, y, hy, cy,
        reverse, ideep::prop_kind::forward_inference);
    return std::make_tuple(output, hy_, cy_, Tensor());
  }
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer_backward(
    const Tensor& input,
    const Tensor& weight0,
    const Tensor& weight1,
    const Tensor& weight2,
    const Tensor& weight3,
    const Tensor& hx_,
    const Tensor& cx_tmp,
    const Tensor& output,
    const Tensor& hy_,
    const Tensor& cy_,
    const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt,
    const std::optional<Tensor>& grad_cy_r_opt,
    bool reverse,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool train,
    bool bidirectional,
    at::IntArrayRef batch_sizes,
    bool batch_first,
    const at::Tensor& workspace) {
  const Tensor& grad_output_r =
      c10::value_or_else(grad_output_r_opt, [] { return Tensor(); });
  const Tensor& grad_hy_r =
      c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); });
  const Tensor& grad_cy_r =
      c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); });
  if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) {
    return std::make_tuple(
        Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor());
  }
  auto grad_output = grad_output_r.defined()
      ? grad_output_r.contiguous()
      : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_hy = grad_hy_r.defined()
      ? grad_hy_r.contiguous()
      : at::zeros_like(hx_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_cy = cx_tmp.defined()
      ? (grad_cy_r.defined()
             ? grad_cy_r.contiguous()
             : at::zeros_like(cx_tmp, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
      : grad_cy_r.contiguous();

  RNNParams rnn(
      input,
      batch_sizes,
      mode,
      hidden_size,
      num_layers,
      bidirectional,
      batch_first,
      train);

  auto output_size = _output_size</*is_single_direction*/true>(rnn);

  auto weight_ih = _shuffle_weight(weight0, rnn.mode);
  auto weight_hh = _shuffle_weight(weight1, rnn.mode);
  auto bias = has_biases
      ? _shuffle_bias(weight2, weight3, rnn.mode)
      : at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options());

  // cx may share storage with hx; work on a clone in that case
  auto cx_ = hx_.storage().unsafeGetStorageImpl() ==
          cx_tmp.storage().unsafeGetStorageImpl()
      ? at::clone(cx_tmp)
      : cx_tmp;
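  // Rebuild the same ideep views over input/hx/cx/weights/bias/output/hy/cy
  // that the forward pass used. The diff_* gradient buffers below are
  // allocated as fp32, and any non-fp32 grad_output/grad_hy/grad_cy are
  // converted up to fp32 before being handed to the backward primitive.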
  // per layer input size
  int64_t input_size = input.size(2);
  auto x = itensor_view_from_dense(
      input,
      rnn.src_layer_desc(input_size, get_mkldnn_dtype(input.scalar_type())));
  auto hx = itensor_view_from_dense(
      hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_.scalar_type())));
  auto cx = itensor_view_from_dense(
      cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_.scalar_type())));
  auto w1 = itensor_view_from_dense(
      weight_ih,
      rnn.weights_layer_desc(
          input_size, get_mkldnn_dtype(weight_ih.scalar_type())));
  auto w2 = itensor_view_from_dense(
      weight_hh,
      rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh.scalar_type())));
  auto b = itensor_view_from_dense(
      bias, rnn.bias_desc(get_mkldnn_dtype(bias.scalar_type())));
  auto y = itensor_view_from_dense(
      output, rnn.dst_layer_desc(get_mkldnn_dtype(output.scalar_type())));
  auto hy = itensor_view_from_dense(
      hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_.scalar_type())));
  auto cy = itensor_view_from_dense(
      cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_.scalar_type())));

  // Create diff_* ATen tensor and corresponding ideep tensor as fp32
  auto diff_x_ =
      at::empty(input.sizes(), input.options().dtype(at::ScalarType::Float));
  auto diff_hx_ =
      at::empty(hx_.sizes(), hx_.options().dtype(at::ScalarType::Float));
  auto diff_cx_ =
      at::empty(cx_.sizes(), cx_.options().dtype(at::ScalarType::Float));
  auto diff_w1_ = at::empty(
      weight_ih.sizes(), weight_ih.options().dtype(at::ScalarType::Float));
  auto diff_w2_ = at::empty(
      weight_hh.sizes(), weight_hh.options().dtype(at::ScalarType::Float));
  auto diff_b_ =
      at::empty(bias.sizes(), bias.options().dtype(at::ScalarType::Float));
  auto diff_x = itensor_view_from_dense(
      diff_x_, rnn.src_layer_desc(input_size, ideep::tensor::data_type::f32));
  auto diff_hx = itensor_view_from_dense(
      diff_hx_, rnn.src_iter_desc(ideep::tensor::data_type::f32));
  auto diff_cx = itensor_view_from_dense(
      diff_cx_, rnn.src_iter_c_desc(ideep::tensor::data_type::f32));
  auto diff_w1 = itensor_view_from_dense(
      diff_w1_,
      rnn.weights_layer_desc(input_size, ideep::tensor::data_type::f32));
  auto diff_w2 = itensor_view_from_dense(
      diff_w2_, rnn.weights_iter_desc(ideep::tensor::data_type::f32));
  auto diff_b = itensor_view_from_dense(
      diff_b_, rnn.bias_desc(ideep::tensor::data_type::f32));

  // Convert grad_y, grad_hy, grad_cy to fp32 in non-fp32 backward
  ideep::tensor diff_y, diff_hy, diff_cy;
  at::Tensor grad_y_, grad_hy_, grad_cy_;
  if (input.scalar_type() != at::ScalarType::Float) {
    grad_y_ = at::empty(
        grad_output.sizes(),
        grad_output.options().dtype(at::ScalarType::Float));
    grad_y_.copy_(grad_output);
    grad_hy_ = at::empty(
        grad_hy.sizes(), grad_hy.options().dtype(at::ScalarType::Float));
    grad_hy_.copy_(grad_hy);
    grad_cy_ = at::empty(
        grad_cy.sizes(), grad_cy.options().dtype(at::ScalarType::Float));
    grad_cy_.copy_(grad_cy);

    diff_y = itensor_view_from_dense(
        grad_y_, rnn.dst_layer_desc(get_mkldnn_dtype(grad_y_.scalar_type())));
    diff_hy = itensor_view_from_dense(
        grad_hy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_hy_.scalar_type())));
    diff_cy = itensor_view_from_dense(
        grad_cy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_cy_.scalar_type())));
  } else {
    diff_y = itensor_view_from_dense(
        grad_output, rnn.dst_layer_desc(ideep::tensor::data_type::f32));
    diff_hy = itensor_view_from_dense(
        grad_hy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
    diff_cy = itensor_view_from_dense(
        grad_cy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
  }
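  // Re-create the forward-training primitive descriptor: it serves as the
  // hint for the backward primitive and describes the layout of the workspace
  // that was recorded during the forward call and passed back in here.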
  auto forward_hint = ideep::lstm_forward_training::prepare(
      x, hx, cx, w1, w2, b, y, hy, cy, reverse);
  ideep::tensor mkldnn_workspace;
  mkldnn_workspace.init(
      forward_hint.workspace_desc(), workspace.template data_ptr<uint8_t>());
  ideep::lstm_backward::compute(
      forward_hint, x, hx, cx, w1, w2, b, y, hy, cy,
      diff_y, diff_hy, diff_cy, mkldnn_workspace,
      diff_x, diff_hx, diff_cx, diff_w1, diff_w2, diff_b, reverse);
  auto diff_b2_ = at::clone(diff_b_);
  return std::make_tuple(
      diff_x_, diff_w1_, diff_w2_, diff_b_, diff_b2_, diff_hx_, diff_cx_);
}

// MKLDNN RNN integration notes:
// I. Memory Formats
//   a. mkldnn will use plain formats for input, hx/cx, output, hy/cy
//      and possibly use blocked formats for weights depending on shape info.
//   b. All mkldnn memories are created (in plain format) as views on ATen tensors;
//      the weight reorder (if any) is handled automatically inside ideep (the mkldnn bridge)
//
// II. MKLDNN Primitive Mapping
//   a. the mkldnn rnn primitive doesn't support training with dropout or padded input sequences.
//   b. here we break a single RNN module into { num_layers * num_directions } mkldnn rnn primitives
//      so that these feature gaps can be covered in the future.
//
// TODO: a. training with dropout
//       b. padded sequence input support
//
static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
    const Tensor& input_,
    TensorList weight,
    int64_t weight_stride0,
    const Tensor& hx_,
    const Tensor& cx_,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool batch_first,
    double dropout_p,
    bool train,
    bool bidirectional,
    IntArrayRef batch_sizes) {
  TORCH_CHECK(batch_sizes.size() == 0, "mkldnn_rnn doesn't support packed input");
  if (static_cast<ideep::rnn_kind>(mode) != ideep::rnn_kind::LSTM) {
    TORCH_CHECK(!cx_.defined(), "mkldnn_rnn: illegal defined cx for non-LSTM RNN");
  }

  auto input = input_;
  if (batch_first) {
    input = input.transpose(0, 1);
  }
  input = input.contiguous();

  auto hx = hx_.contiguous();
  auto cx = cx_.contiguous();

  MatrixRef<Tensor> weights{weight, static_cast<size_t>(weight_stride0)};

  auto num_directions = bidirectional ? 2 : 1;
  auto layer_input = input;
  std::vector<Tensor> layer_output(num_directions);
  std::vector<Tensor> layer_hy(num_layers * num_directions);
  std::vector<Tensor> layer_cy(num_layers * num_directions);
  for (const auto layer : c10::irange(num_layers)) {
    for (const auto direction : c10::irange(num_directions)) {
      const auto index = layer * num_directions + direction;
      auto layer_weights = weights[index];
      TORCH_CHECK(layer_weights.size() == 2 || layer_weights.size() == 4);
      auto layer_hx = hx[index];
      auto layer_cx = cx[index];
      auto reverse = (direction > 0);
      // bias won't be packed
      auto outputs = at::mkldnn_rnn_layer(
          layer_input,
          layer_weights[0],
          layer_weights[1],
          has_biases
              ? layer_weights[2]
              : at::zeros(
                    layer_weights[0].sizes(),
                    layer_weights[0].options().layout(at::Layout::Strided)),
          has_biases
              ? layer_weights[3]
              : at::zeros(
                    layer_weights[1].sizes(),
                    layer_weights[1].options().layout(at::Layout::Strided)),
          layer_hx,
          layer_cx,
          reverse,
          batch_sizes,
          mode,
          hidden_size,
          num_layers,
          has_biases,
          bidirectional,
          batch_first,
          train);
      layer_output[direction] = std::get<0>(outputs);
      layer_hy[index] = std::get<1>(outputs);
      layer_cy[index] = std::get<2>(outputs);
    }
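    // Concatenate the two directions along the channel dimension (when
    // bidirectional) and feed the result to the next layer; inter-layer
    // dropout is applied in training only, and never after the last layer.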
    layer_input = num_directions == 1
        ? layer_output[0]
        : at::cat(layer_output, /*output_channels*/-1);
    if (dropout_p != 0 && train && layer < num_layers - 1) {
      layer_input = at::dropout(layer_input, dropout_p, /*train=*/true);
    }
  }
  auto output = layer_input;
  auto hy = at::stack(layer_hy, 0);
  auto cy = at::stack(layer_cy, 0);
  if (batch_first) {
    output = output.transpose(0, 1);
  }

  return std::make_tuple(output, hy, cy);
}

////////////////////////////////////////////////////////////////////////////////
//// MKLDNN dispatch for the generic RNN ops (at::lstm, at::gru, ...)
////////////////////////////////////////////////////////////////////////////////

namespace {

// Helpers for working with different hidden types.
std::tuple<Tensor, Tensor> unpack_hidden(const std::tuple<Tensor, Tensor>& hidden) {
  return hidden;
}

template <typename hidden_type>
hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
  static_assert(
      false && sizeof(hidden_type),
      "pack_hidden not implemented for this type");
}

template <>
std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(
    const Tensor& hx, const Tensor& cx) {
  return std::make_tuple(hx, cx);
}

template <typename hidden_type>
std::pair<Tensor, hidden_type> mkldnn_impl(
    const Tensor& input,
    const hidden_type& hidden,
    TensorList params,
    bool has_biases,
    ideep::rnn_kind mode,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    bool batch_first) {
  auto [hx, cx] = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

  auto mkldnn_output = mkldnn_rnn(
      input, params, has_biases ? 4 : 2,
      hx, cx, static_cast<int64_t>(mode), hidden_size, num_layers, has_biases,
      batch_first, dropout_p, train, bidirectional, /*batch_sizes*/{});

  return {std::get<0>(mkldnn_output),
          pack_hidden<hidden_type>(
              std::get<1>(mkldnn_output), std::get<2>(mkldnn_output))};
}

void lstm_mkldnn(
    Tensor& output,
    Tensor& hy,
    Tensor& cy,
    const Tensor& input,
    TensorList hx,
    TensorList params,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    bool batch_first) {
  auto result = mkldnn_impl(
      input, std::make_tuple(hx[0], hx[1]), params, has_biases,
      ideep::rnn_kind::LSTM, num_layers, dropout_p, train, bidirectional,
      batch_first);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}

} // anonymous namespace

REGISTER_ALL_CPU_DISPATCH(lstm_mkldnn_stub, &lstm_mkldnn);

} // namespace at::native

#endif // AT_MKLDNN_ENABLED
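
// Usage sketch (illustrative only, not part of this translation unit): the
// lstm_mkldnn stub registered above is reached through the generic LSTM entry
// points in aten/src/ATen/native/RNN.cpp when MKLDNN is enabled and the input
// qualifies for the oneDNN path. A minimal C++ frontend example (the sizes
// below are arbitrary):
//
//   #include <torch/torch.h>
//
//   torch::nn::LSTM lstm(
//       torch::nn::LSTMOptions(/*input_size=*/10, /*hidden_size=*/20)
//           .num_layers(2)
//           .bidirectional(true));
//   auto input = torch::randn({/*seq_len=*/5, /*batch=*/3, /*input_size=*/10});
//   auto [output, state] = lstm->forward(input);  // state = (h_n, c_n)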