#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#ifndef AT_PER_OPERATOR_HEADERS
#include
#include
#else
#include
#include
#include
#include
#include
#include
#include
#include
#include
#endif

#if !AT_CUDNN_ENABLED()

namespace at {
namespace native {

// See Note [ATen preprocessor philosophy]

Tensor _cudnn_rnn_flatten_weight(
    TensorList weight_arr,
    int64_t weight_stride0,
    int64_t input_size,
    int64_t fn_mode,
    int64_t fn_hidden_size,
    int64_t fn_proj_size,
    int64_t fn_num_layers,
    bool batch_first,
    bool fn_bidirectional) {
  AT_ERROR("_cudnn_rnn_flatten_weight: ATen not compiled with cuDNN support");
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
    const Tensor& input_r,
    TensorList weight,
    int64_t weight_stride0,
    const std::optional<Tensor>& weight_buf_r_opt,
    const Tensor& hx,
    const std::optional<Tensor>& cx_opt,
    int64_t fn_mode,
    int64_t fn_hidden_size,
    int64_t fn_proj_size,
    int64_t fn_num_layers,
    bool batch_first,
    double fn_dropout,
    bool fn_train,
    bool fn_bidirectional,
    IntArrayRef fn_batch_sizes,
    const std::optional<Tensor>& fn_dropout_state_opt) {
  AT_ERROR("_cudnn_rnn: ATen not compiled with cuDNN support");
}

std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> _cudnn_rnn_backward(
    const Tensor& input,
    TensorList weight,
    int64_t weight_stride0,
    const Tensor& weight_buf,
    const Tensor& hx,
    const std::optional<Tensor>& cx_opt,
    const Tensor& output,
    const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt,
    const std::optional<Tensor>& grad_cy_r_opt,
    int64_t mode,
    int64_t hidden_size,
    int64_t proj_size,
    int64_t num_layers,
    bool batch_first,
    double dropout,
    bool train,
    bool bidirectional,
    IntArrayRef batch_sizes,
    const std::optional<Tensor>& dropout_state_opt,
    const Tensor& reserve,
    std::array<bool, 4> output_mask) {
  AT_ERROR("_cudnn_rnn_backward: ATen not compiled with cuDNN support");
}

Tensor _cudnn_init_dropout_state(
    double dropout,
    bool train,
    int64_t dropout_seed,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
      pin_memory);
  AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support");
}

} // namespace native
} // namespace at

#else // AT_CUDNN_ENABLED()

#include

namespace at {
namespace native {

namespace {

// DropoutDescriptor

struct DropoutDescriptorParams {
  bool train;
  double dropout;
  Tensor dropout_state;
  DropoutDescriptorParams() = default;
  void set(bool train_, double dropout_, Tensor dropout_state_) {
    train = train_;
    dropout = dropout_;
    dropout_state = dropout_state_;
  }
  DropoutDescriptor descriptor(cudnnHandle_t handle) const {
    auto dropout_p = train ? dropout : 0;
    DropoutDescriptor dropout_desc;
    if (dropout_p == 0) {
      dropout_desc.set_no_dropout(handle);
    } else {
      dropout_desc.set(handle, dropout_p, dropout_state);
    }
    return dropout_desc;
  }
};

// RNNDescriptor

struct RNNDescriptorParams {
#ifdef USE_CUDNN_RNN_V8_API
  int64_t input_size;
  bool packed;
#endif
  int64_t hidden_size;
  int64_t proj_size;
  int64_t num_layers;
  cudnnDirectionMode_t bidirectional;
  cudnnRNNMode_t mode;
  cudnnDataType_t datatype;
  cudnnDataType_t input_datatype;
  cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
  cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;

  int64_t num_directions() const {
    return bidirectional ?
2 : 1; } void set_mode(int64_t fn_mode) { switch (fn_mode) { case CUDNN_RNN_RELU: mode = CUDNN_RNN_RELU; break; case CUDNN_RNN_TANH: mode = CUDNN_RNN_TANH; break; case CUDNN_LSTM: mode = CUDNN_LSTM; break; case CUDNN_GRU: mode = CUDNN_GRU; break; default: { std::ostringstream oss; oss << "unrecognized cuDNN RNN mode " << fn_mode; AT_ERROR(oss.str()); } } } void set_bidirectional(bool fn_bidirectional) { bidirectional = fn_bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; } void set_algo(cudnnRNNAlgo_t algo) { this->algo = algo; } #ifndef USE_CUDNN_RNN_V8_API void set( int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype, cudnnDataType_t input_datatype){ #else void set( int64_t mode, int64_t input_size, bool packed, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype, cudnnDataType_t input_datatype) { #endif this->set_mode(mode); #ifdef USE_CUDNN_RNN_V8_API this->input_size = input_size; this->packed = packed; #endif this->hidden_size = hidden_size; this->proj_size = proj_size; this->num_layers = num_layers; this->set_bidirectional(bidirectional); this->datatype = datatype; this->input_datatype = input_datatype; } RNNDescriptor descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const { RNNDescriptor rnn_desc; #ifndef USE_CUDNN_RNN_V8_API rnn_desc.set( handle, hidden_size, proj_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo, at::globalContext().allowTF32CuDNN()); #else rnn_desc.set( handle, input_size, packed, hidden_size, proj_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo, at::globalContext().allowTF32CuDNN()); #endif return rnn_desc; } // In some cases, a use of RNNDescriptor does not rely on the // DropoutDescriptor. In this case, we fake up a no-dropout // descriptor to make the RNN descriptor initialization go through. // This is used by _cudnn_rnn_flatten_weight, which needs an // RNNDescriptor for get_parameters(), but does not actually need // a fully initialized dropout descriptor. This lets us avoid // having to pass the dropout state to flatten, which has no business // knowing what the dropout state is. RNNDescriptor descriptor(cudnnHandle_t handle) const { DropoutDescriptor dropout_desc; dropout_desc.set_no_dropout(handle); return descriptor(handle, std::move(dropout_desc)); } }; // namespace // TensorDescriptor list #ifndef USE_CUDNN_RNN_V8_API std::vector rnn_descriptor_sequence( const Tensor& tensor, IntArrayRef batch_sizes) { std::vector descriptors(batch_sizes.size()); size_t i = 0; // To be mutated in the loop auto batch_tensor_size = tensor.sizes().vec(); for (auto batch_size : batch_sizes) { batch_tensor_size[0] = batch_size; // NB: cuDNN RNN API does not support 2d descriptors, so we // must pad it out to 3d. 
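      // E.g. with batch_sizes = {3, 2, 1, 1} (the ABCD/EF/G example in the
      // layout comment below), descriptor i covers the batch_sizes[i] rows of
      // the packed input that belong to time step i.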
descriptors[i].set( getCudnnDataType(tensor), batch_tensor_size, tensor.strides(), 3); i++; } return descriptors; } std::vector rnn_descriptor(const Tensor& tensor, int64_t N) { std::vector descriptors(N); for (const auto i : c10::irange(N)) { descriptors[i].set(tensor, 5); } return descriptors; } #else auto rnn_descriptor_sequence( const Tensor& tensor, uint32_t batch_size, const IntArrayRef batch_sizes, uint32_t seq_len, uint32_t vector_size) { // packed case RNNDataDescriptor r; std::vector seqLengthArray(batch_size, 1); // cuDNN wants the sequence lengths for a packed batch as if they // were unpacked, e.g., for the // Sequence 1: ABCD // Sequence 2: EF // Sequence 3: G // case below, this would be [4, 2, 1] (has length == mini_batch) // TODO(eqy): There's probably a smarter way to do this than O(SN) for (auto it = batch_sizes.begin(); it != batch_sizes.end(); it++) { // everyone starts at sequence length 1 so we skip an iteration if (it == batch_sizes.begin()) { continue; } for (const auto idx : c10::irange(*it)) { seqLengthArray[idx]++; } } r.set( tensor, CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED, seq_len, batch_size, vector_size, seqLengthArray.data()); return r; } auto rnn_descriptor( const Tensor& tensor, uint32_t batch_size, uint32_t seq_len, uint32_t vector_size) { RNNDataDescriptor r; // NB: Looks like even if batch_first is true here we always want // SEQ_MAJOR_UNPACKED, because the input appears to be transposed if it is // barch-major const auto layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED; std::vector seqLengthArray(batch_size, seq_len); r.set( tensor, layout, seq_len, batch_size, vector_size, seqLengthArray.data()); return r; } #endif // The best way to understand the meaning of the values stored in // this struct is to consider each of the possible ways our // input can be structured. // // Suppose you want to run RNN on the following variable // length inputs: // // Sequence 1: ABCD // Sequence 2: EF // Sequence 3: G // // (Let _ be padding when we have non-packed representations.) // // # Packed input (batch_sizes is non-empty) // // input_size // +------+ + // | A | | // | E | mini_batch = | // | G | batch_sizes[0] = 3 | // +------+ | // | B | | batch_sizes_sum = 7 // | F | batch_sizes[1] = 2 | // +------+ | // | C | batch_sizes[2] = 1 | // +------+ | // | D | batch_sizes[3] = 1 | // +------+ + // // (seq_length = 4) // // input.size() = batch_sizes_sum x input_size // // # Unpacked input (batch_first = false) // // mini_batch = 3 // +-------+ // | A E G | // | B F _ | seq_length = 4 // | C _ _ | // | D _ _ | // +-------+ // ... input_size // +-------+ // // input.size() = seq_length x mini_batch x input_size // // # Unpacked input (batch_first = true) // // seq_length = 4 // +---------+ // | A B C D | // | E F _ _ | mini_batch = 3 // | G _ _ _ | // +---------+ // ... input_size // +---------+ // // input.size() = mini_batch x seq_length x input_size // struct TensorDescriptorListParams { IntArrayRef batch_sizes; int64_t seq_length; int64_t mini_batch; // NB: this is not input.size(), which is an IntArrayRef; instead, this // size of the inner-most dimension. In NL applications, this is usually // the size of the embedding. 
// You can also think of this as the size of the "channel" dimension
// (at risk of confusing vision researchers :)
  int64_t input_size;
  // Only valid when is_input_packed()
  int64_t batch_sizes_sum; // == sum(batch_sizes)

  bool is_input_packed() const {
    return batch_sizes.size() != 0;
  }

  void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) {
    batch_sizes = batch_sizes_;
    if (is_input_packed()) {
      seq_length = batch_sizes.size();
      mini_batch = batch_sizes[0];
      // NB: When input is packed, the mini_batch size is NOT the size
      // of the outer dimension
      batch_sizes_sum = input_sizes[0];
      input_size = input_sizes[1];
    } else {
      if (batch_first) {
        seq_length = input_sizes[1];
        mini_batch = input_sizes[0];
      } else {
        seq_length = input_sizes[0];
        mini_batch = input_sizes[1];
      }
      input_size = input_sizes[2];
      // TODO: Actually, would this make ASAN's job harder catching
      // an uninitialized access?
      batch_sizes_sum = -1; // something bogus in case we access it
    }
  }

#ifndef USE_CUDNN_RNN_V8_API
  // TODO: check x for consistency with input_size?
  std::vector<TensorDescriptor> descriptors(Tensor x) const {
    auto is_input_packed = batch_sizes.size() != 0;
    if (is_input_packed) {
      return rnn_descriptor_sequence(x, batch_sizes);
    } else {
      return rnn_descriptor(x[0], seq_length);
    }
  }
#else
  auto descriptors(Tensor x) const {
    auto is_input_packed = batch_sizes.size() != 0;
    if (is_input_packed) {
      return rnn_descriptor_sequence(
          x, mini_batch, batch_sizes, seq_length, x.size(-1));
    } else {
      return rnn_descriptor(x, mini_batch, seq_length, x.size(-1));
    }
  }
#endif
};

// Everything together

struct RNNParams {
  DropoutDescriptorParams dropout;
  RNNDescriptorParams rnn;
  TensorDescriptorListParams tensors;
};

// NB: Doesn't include the weight descriptor
struct RNNDescriptors {
  RNNDescriptor rnn_desc;
  // NB: this won't actually lay out the tensor descriptor pointers
  // in the right way, so you'll have to preprocess them
#ifndef USE_CUDNN_RNN_V8_API
  std::vector<TensorDescriptor> x_descs;
  std::vector<TensorDescriptor> y_descs;
#else
  RNNDataDescriptor x_descs;
  RNNDataDescriptor y_descs;
#endif
  TensorDescriptor hx_desc;
  TensorDescriptor hy_desc;
  TensorDescriptor cx_desc;
  TensorDescriptor cy_desc;

  RNNDescriptors(
      const RNNParams& fn,
      cudnnHandle_t handle,
      Tensor x,
      Tensor y,
      Tensor hx,
      Tensor cx) {
    rnn_desc = fn.rnn.descriptor(handle, fn.dropout.descriptor(handle));
    x_descs = fn.tensors.descriptors(x);
    y_descs = fn.tensors.descriptors(y);
    hx_desc.set(hx, 5);
    hy_desc.set(hx, 5);
    if (cx.defined()) {
      cx_desc.set(cx, 5);
      cy_desc.set(cx, 5);
    }
  }

  // TODO: This is annoying, having to put the cudnnTensorDescriptor_t
  // in a contiguous array...
std::vector get_descs( const std::vector& descs) { std::vector r; r.reserve(descs.size()); for (auto& desc : descs) { r.emplace_back(desc.desc()); } return r; } #ifndef USE_CUDNN_RNN_V8_API std::vector get_x_descs() { return get_descs(x_descs); } std::vector get_y_descs() { return get_descs(y_descs); } #endif }; int64_t get_num_weights( cudnnHandle_t handle, const RNNDescriptor& rnn_desc, #ifndef USE_CUDNN_RNN_V8_API const TensorDescriptor& x_desc, #endif cudnnDataType_t datatype) { size_t weight_size; #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnGetRNNParamsSize( handle, rnn_desc.desc(), x_desc.desc(), &weight_size, datatype)); #else AT_CUDNN_CHECK( cudnnGetRNNWeightSpaceSize(handle, rnn_desc.desc(), &weight_size)); #endif auto elem_size = dataSize(datatype); TORCH_INTERNAL_ASSERT( weight_size % elem_size == 0, "cudnnGetRNNParamsSize returned nonsensical weight_size"); return weight_size / elem_size; } int64_t _num_linear_layers(cudnnRNNMode_t mode) { switch (mode) { case CUDNN_LSTM: return 8; case CUDNN_GRU: return 6; case CUDNN_RNN_RELU: return 2; case CUDNN_RNN_TANH: return 2; default: AT_ERROR("unknown cuDNN RNN mode ", mode); } } void add_projection_weights( cudnnHandle_t handle, const RNNDescriptor& rnn_desc, #ifndef USE_CUDNN_RNN_V8_API const TensorDescriptor& x_desc, const FilterDescriptor& w_desc, #endif const Tensor& weight_buf, int64_t layer, std::vector& params) { void* matrix_pointer = nullptr; // assuming it's LSTM which has 8 "linear layers" (i.e. 4 weights and 4 // biases) int64_t linear_id = 8; #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams( /*handle=*/handle, /*rnnDesc=*/rnn_desc.desc(), /*layer=*/layer, /*xDesc=*/x_desc.desc(), /*wDesc=*/w_desc.desc(), /*w=*/weight_buf.data_ptr(), /*linLayerID=*/linear_id, /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), /*linLayerMat=*/&matrix_pointer)); #else TensorDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNWeightParams( /*handle=*/handle, /*rnnDesc=*/rnn_desc.desc(), /*layer=*/layer, /*wDesc=*/weight_buf.numel() * weight_buf.element_size(), /*w=*/weight_buf.data_ptr(), /*linLayerID=*/linear_id, /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), /*linLayerMat=*/&matrix_pointer, nullptr, nullptr)); #endif cudnnDataType_t data_type; #ifndef USE_CUDNN_RNN_V8_API cudnnTensorFormat_t format; #else int stride_dim_a[5]; #endif int nb_dims; constexpr int min_dim = 3; int filter_dim_a[min_dim]; #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor( lin_layer_mat_desc.desc(), min_dim, &data_type, &format, &nb_dims, filter_dim_a)); #else AT_CUDNN_CHECK(cudnnGetTensorNdDescriptor( lin_layer_mat_desc.desc(), min_dim, &data_type, &nb_dims, filter_dim_a, stride_dim_a)); #endif TORCH_INTERNAL_ASSERT( nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim); auto elem_size = dataSize(getCudnnDataType(weight_buf)); auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr(); TORCH_INTERNAL_ASSERT( offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size); size_t offset = offset_bytes / elem_size; int mat_numel = c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims); // Generate a new parameter tensor which is a view into the weight_buf. std::initializer_list size = {mat_numel, 1}; Tensor param = at::empty({0}, weight_buf.options()) .set_(weight_buf.storage(), offset, size); params.emplace_back(std::move(param)); } /* Returns weight and bias tensors for each layer of the RNN. 
These tensors are views on the underlying weight buffer allocated by CuDNN. Note: for LSTM and GRU, which have multiple parameters of each type (4 and 3, respectively), these parameters are concatenated along the first dimension. These parameters are returned in a consistent order by CuDNN: (reset, forget, cell, output) for LSTM (reset, input, new) for GRU Args: fn: The RNN function object holding the RNN state handle: a CuDNN handle weight_buf: a 1D tensor containing the CuDNN-allocated weight (or grad_weight) buffer Returns: parameters: [(weight_ih, weight_hh, bias_ih, bias_hh)*], with length equal to the num_layers. This is represented as a pair of vector, and outer-dimension stride (NB: Can't return MatrixRef because we need to allocate the underlying tensor) */ std::pair, size_t> // stride0 get_parameters( cudnnHandle_t handle, const RNNDescriptorParams& rnn, const RNNDescriptor& rnn_desc, #ifndef USE_CUDNN_RNN_V8_API const TensorDescriptor& x_desc, const FilterDescriptor& w_desc, #endif const Tensor& weight_buf, bool include_bias = true) { #ifndef USE_CUDNN_RNN_V8_API auto cudnn_methods = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams}; #else auto cudnn_methods = {true, false}; #endif std::vector params; int64_t num_linear_layers = _num_linear_layers(rnn.mode); int64_t num_layers = rnn.num_directions() * rnn.num_layers; size_t cur_offset = 0; size_t global_layer_params_count = 0; for (const auto layer : c10::irange(num_layers)) { size_t layer_params_count = 0; for (auto cudnn_method : cudnn_methods) { for (const auto linear_id : c10::irange(num_linear_layers)) { void* matrix_pointer; #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnn_method( handle, rnn_desc.desc(), layer, x_desc.desc(), w_desc.desc(), weight_buf.data_ptr(), linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else TensorDescriptor lin_layer_mat_desc; for (int stateless = 0; stateless < 100; stateless++) { if (cudnn_method) { // matrix AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, rnn_desc.desc(), layer, weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, nullptr, nullptr)); } else { // bias AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, rnn_desc.desc(), layer, weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, nullptr, nullptr, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); } } #endif cudnnDataType_t data_type; #ifndef USE_CUDNN_RNN_V8_API cudnnTensorFormat_t format; #else int stride_dim_a[5]; #endif int nb_dims; constexpr int min_dim = 3; int filter_dim_a[min_dim]; #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor( lin_layer_mat_desc.desc(), min_dim, &data_type, &format, &nb_dims, filter_dim_a)); #else AT_CUDNN_CHECK(cudnnGetTensorNdDescriptor( lin_layer_mat_desc.desc(), min_dim, &data_type, &nb_dims, filter_dim_a, stride_dim_a)); #endif TORCH_INTERNAL_ASSERT( nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim); auto elem_size = dataSize(getCudnnDataType(weight_buf)); auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr(); TORCH_INTERNAL_ASSERT( offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size); size_t offset = offset_bytes / elem_size; // for all the RNN types provided by CUDNN, all the ih weights // are the same size and are allocated in a contiguous chunk // (same for the hh weights, and the ih and hh biases). 
// Since we're storing all the weights in a single tensor anyway, // might as well merge the CUDNN ones into a single tensor as well int mat_numel = c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims); if (linear_id == 0 || linear_id == num_linear_layers / 2) { // We could also exclude bias params by restricting cudnn_methods to // just { cudnnGetRNNLinLayerMatrixParams } at the very top. However, // to do so would throw off the cur_offset account, which is currently // a strict and informative check that all params are laid out the way // we think they are. If include_bias is false, I'd rather keep full // cur_offset checks rather than save some CPU overhead by skipping // the cudnn_method = cudnnGetRNNLinLayerBiasParams iteration. #ifndef USE_CUDNN_RNN_V8_API if (include_bias || cudnn_method != cudnnGetRNNLinLayerBiasParams) { #else if (include_bias || cudnn_method) { #endif // Generate a new parameter tensor which is a view into the // weight_buf. std::initializer_list size = { mat_numel * num_linear_layers / 2, 1}; Tensor param = at::empty({0}, weight_buf.options()) .set_(weight_buf.storage(), offset, size); params.emplace_back(std::move(param)); layer_params_count++; } } else { TORCH_INTERNAL_ASSERT( cur_offset == offset, "cur_offset = ", cur_offset, "; offset = ", offset); } cur_offset = offset + mat_numel; } } // for cudnn_method if (rnn.proj_size != 0) { #ifndef USE_CUDNN_RNN_V8_API add_projection_weights( handle, rnn_desc, x_desc, w_desc, weight_buf, layer, params); #else add_projection_weights(handle, rnn_desc, weight_buf, layer, params); #endif layer_params_count++; } if (layer == 0) { global_layer_params_count = layer_params_count; } else { TORCH_INTERNAL_ASSERT( global_layer_params_count == layer_params_count, "global_layer_params_count = ", global_layer_params_count, "; layer_params_count = ", layer_params_count); } } // for layer return std::make_pair(params, global_layer_params_count); } // This is a lightweight version of the method above used to quickly get the // expected parameter offsets. std::vector get_expected_data_ptrs( const Tensor& weight_buf, cudnnHandle_t handle, const RNNDescriptorParams& rnn, const RNNDescriptor& rnn_desc, const TensorDescriptor& x_desc, cudnnDataType_t datatype) { #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor w_desc; w_desc.set(weight_buf, 3); #endif int64_t num_linear_layers = _num_linear_layers(rnn.mode); int64_t num_dir_layers = rnn.num_directions() * rnn.num_layers; #ifndef USE_CUDNN_RNN_V8_API const auto cudnn_methods = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams}; #else const auto cudnn_methods = {true, false}; #endif std::vector data_ptrs; if (rnn.proj_size != 0) { data_ptrs.reserve(num_dir_layers * (2 * 2 + 1)); } else { data_ptrs.reserve(num_dir_layers * 2 * 2); } for (const auto layer : c10::irange(num_dir_layers)) { for (auto cudnn_method : cudnn_methods) { // This API returns a separate pointer for weight of every gate, // but we represent them as a single tensor, so we're only interested // in a very limited subset of possible values. 
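      // For instance, with an LSTM _num_linear_layers returns 8, so the
      // offsets below are {0, 4}: the first ih matrix and the first hh
      // matrix, i.e. the start of each concatenated chunk whose pointer we
      // actually need.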
const std::array linear_offsets = {0, num_linear_layers / 2}; for (int64_t linear_id : linear_offsets) { void* matrix_pointer; #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnn_method( handle, rnn_desc.desc(), layer, x_desc.desc(), w_desc.desc(), weight_buf.data_ptr(), linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else TensorDescriptor lin_layer_mat_desc; if (cudnn_method) { // matrix AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, rnn_desc.desc(), layer, weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, nullptr, nullptr)); } else { // bias AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, rnn_desc.desc(), layer, weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, nullptr, nullptr, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); } #endif data_ptrs.push_back(matrix_pointer); } } if (rnn.proj_size != 0) { // assuming it's LSTM which has 8 "linear layers" (i.e. 4 weights and 4 // biases) int64_t linear_id = 8; void* matrix_pointer; #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams( handle, rnn_desc.desc(), layer, x_desc.desc(), w_desc.desc(), weight_buf.data_ptr(), linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else TensorDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, rnn_desc.desc(), layer, weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, nullptr, nullptr)); #endif data_ptrs.push_back(matrix_pointer); } } return data_ptrs; } void _viewOrCopyOneParam( const Tensor& param_from, const Tensor& param_to, bool copy, bool allow_type_change = false) { // if copying, allow_type_change may be true or false. // if viewing, allow_type_change must be false. TORCH_INTERNAL_ASSERT( copy || !allow_type_change, "if viewing, type change is not allowed."); TORCH_INTERNAL_ASSERT( allow_type_change || (param_from.scalar_type() == param_to.scalar_type()), "parameter types mismatch"); if (copy) { param_to.copy_(param_from.view_as(param_to)); } else { param_from.resize_as_(param_to); } } void _viewOrCopyParams( MatrixRef params_from, MatrixRef params_to, bool copy, bool allow_type_change = false) { TORCH_INTERNAL_ASSERT( params_from.size(0) == params_to.size(0), "number of layers mismatch"); for (const auto i : c10::irange(params_from.size(0))) { auto layer_params_from = params_from[i]; auto layer_params_to = params_to[i]; // NOTE: these lists have all weights before all biases, so if the layer // doesn't use biases, iteration will terminate once layer_params_from ends // and ignore them. // NOTE: there is an exception from the above statement. If LSTMs with // projections are used, weights layout will be w_ih, w_hh, b_ih, b_hh, // w_hr. So need to handle no-bias case specially, because will need to copy // 0->0, 1->1, 2->4. This case can be uniquely identified by checking if // number of defined parameters for each layer is 3. 
if (layer_params_from.size() == 3 && layer_params_to.size() != 3) { _viewOrCopyOneParam( layer_params_from[0], layer_params_to[0], copy, allow_type_change); _viewOrCopyOneParam( layer_params_from[1], layer_params_to[1], copy, allow_type_change); _viewOrCopyOneParam( layer_params_from[2], layer_params_to[4], copy, allow_type_change); continue; } if (layer_params_to.size() == 3 && layer_params_from.size() != 3) { _viewOrCopyOneParam( layer_params_from[0], layer_params_to[0], copy, allow_type_change); _viewOrCopyOneParam( layer_params_from[1], layer_params_to[1], copy, allow_type_change); _viewOrCopyOneParam( layer_params_from[4], layer_params_to[2], copy, allow_type_change); continue; } for (auto a = layer_params_from.begin(), b = layer_params_to.begin(); a != layer_params_from.end() && b != layer_params_to.end(); ++a, ++b) { _viewOrCopyOneParam(*a, *b, copy, allow_type_change); } } } void _copyParams(MatrixRef params_from, MatrixRef params_to) { _viewOrCopyParams(params_from, params_to, true); } void _viewParams(MatrixRef params_from, MatrixRef params_to) { _viewOrCopyParams(params_from, params_to, false); } std::vector _input_size(const TensorDescriptorListParams& tensors) { if (tensors.is_input_packed()) { return {tensors.batch_sizes_sum, tensors.input_size}; } else { return {tensors.seq_length, tensors.mini_batch, tensors.input_size}; } } std::vector _hidden_size( const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) { if (rnn.proj_size != 0) { return { rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.proj_size}; } else { return { rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size}; } } std::vector _cell_size( const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) { return { rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size}; } std::vector _output_size( const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) { auto out_size = rnn.hidden_size; if (rnn.proj_size != 0) { out_size = rnn.proj_size; } if (tensors.is_input_packed()) { return {tensors.batch_sizes_sum, out_size * rnn.num_directions()}; } else { return { tensors.seq_length, tensors.mini_batch, out_size * rnn.num_directions()}; } } inline bool use_persist_common_heuristics( const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) { return rnn.num_layers == 1 && rnn.hidden_size <= 1024 && rnn.num_directions() == 1 && rnn.hidden_size % 128 == 0 && tensors.input_size % 128 == 0; } inline bool use_persist_device_heuristics( const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) { auto bsize = tensors.mini_batch; cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); if (prop->major == 7) { if (prop->minor == 5) { // Excludes Turing from using persistent rnn. return false; } else { // technically, batch size should be multiple of 8, but there are quite a // few multiple-of-8 batchsizes that give bad perf, weed them out return ((bsize % 16 == 0 && bsize != 80 && bsize != 112) || bsize == 8) && ((tensors.seq_length >= 40 && bsize <= 128) || (tensors.seq_length >= 20 && bsize <= 96) || (tensors.seq_length >= 10 && bsize <= 32)); } } else if (prop->major >= 8 && prop->multiProcessorCount >= 98) { // SM count check excludes A30 (similar issue to A40) if (prop->minor == 6) { // Excludes sm_86 GPU devices from using persistent rnn. // This is because there are some edge cases that will throw exceptions // with cudnn 8.0.5 on Nvidia A40 GPU. 
return false; } // Based on tests by Vasily Volkov and xwang233. Vasily only tried bsize <= // 128, so conservatively enable persistence for bsize <= 128 only. // TODO: Run more tests for bsize > 128. if (rnn.mode == CUDNN_GRU) { // Persistent GRU performance is flakier than other RNN types. Exclude // them for now. // TODO: Write a more refined GRU heuristic. return false; } else if (rnn.mode == CUDNN_LSTM) { // Persistent LSTMs are comparable to or better than non-persistent for // bsize <= 128. return (bsize % 8 == 0) && (bsize <= 128); } else { // Persistent RNN_RELU and TANH show poor performance when bsize >= 96 AND // hidden size >= 896. return (bsize % 8 == 0) && (bsize <= 128) && (bsize < 96 || rnn.hidden_size < 896); } } else { return false; } } inline bool use_rnn_persist_small_h( const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, bool forward) { cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); if (prop->major < 6) return false; if (forward) { if (rnn.mode == CUDNN_RNN_RELU || rnn.mode == CUDNN_RNN_TANH) { return rnn.hidden_size <= 384; } if (rnn.mode == CUDNN_LSTM || rnn.mode == CUDNN_GRU) { return rnn.hidden_size <= 192; } } else /* backward */ { if (rnn.mode == CUDNN_RNN_RELU || rnn.mode == CUDNN_RNN_TANH) { return rnn.hidden_size <= 256; } if (rnn.mode == CUDNN_LSTM || rnn.mode == CUDNN_GRU) { return rnn.hidden_size <= 128; } } return false; } cudnnRNNAlgo_t get_algo( const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor input, bool forward) { // LSTM with projections only works with standard algorithm if (rnn.proj_size != 0) { return CUDNN_RNN_ALGO_STANDARD; } // Persistent algos typically don't work for packed inputs with sequence // lengths that vary across batch elements, and will return // CUDNN_STATUS_NOT_SUPPORTED if attempted. See // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#features-of-rnn-functions if (!tensors.is_input_packed()) { auto cudnnDataType = getCudnnDataType(input); if (cudnnDataType != CUDNN_DATA_DOUBLE) { if (use_rnn_persist_small_h(rnn, tensors, forward)) { return CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H; } } if (cudnnDataType == CUDNN_DATA_HALF) { if (use_persist_common_heuristics(rnn, tensors) && use_persist_device_heuristics(rnn, tensors)) { return CUDNN_RNN_ALGO_PERSIST_STATIC; } } } return CUDNN_RNN_ALGO_STANDARD; } cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) { if (dtype == CUDNN_DATA_HALF) { return CUDNN_DATA_FLOAT; } return dtype; } } // namespace native // Utilities exposed in RNNUtils.h namespace cudnn_rnn { TORCH_CUDA_CPP_API std::tuple> copy_weights_to_flat_buf_views( TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional, const cudnnDataType_t flat_buf_datatype, const TensorOptions& flat_buf_options, bool set_orig_weights_to_flat_buf, bool allow_type_change /*=false*/, bool include_bias /*=true*/) { // flat_buf_datatype is accepted as a separate argument (rather than extracted // from flat_buf_options) because to extract flat_buf_datatype from // flat_buf_options, we'd need to say auto flat_buf_datatype = // getCudnnDataTypeFromScalarType(typeMetaToScalarType(options.dtype())); // typeMetaToScalarType is a surprisingly nontrivial function. We should // avoid it if we can. 
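  // A rough usage sketch: _cudnn_rnn_flatten_weight (further below) calls this
  // with set_orig_weights_to_flat_buf=true, so the original per-layer weights
  // are rebased to become views into the returned flat buffer; a caller that
  // only needs a temporary flat copy can pass false and leave the originals
  // untouched.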
TORCH_CHECK( weight_arr.size() > 0, "copy_weights_to_flat_buf_views: cannot flatten empty weight list"); RNNDescriptorParams rnn; rnn.set( mode, #ifdef USE_CUDNN_RNN_V8_API input_size, false, // eqy: bogus as we do not know if the input is packed here // but it should not affect the weights (what are are interested // in) #endif hidden_size, proj_size, num_layers, bidirectional, promote_rnn_math_type(flat_buf_datatype), flat_buf_datatype); auto handle = getCudnnHandle(); RNNDescriptor rnn_desc = rnn.descriptor(handle); TensorGeometry x_geom({1, input_size}); TensorDescriptor x_desc; // Why do we pad to 5 dims here (and elsewhere)? // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNForwardTraining // expects descriptors padded to 3 dimensions. x_desc.set(flat_buf_datatype, x_geom.sizes(), x_geom.strides(), 5); auto num_weights = #ifndef USE_CUDNN_RNN_V8_API get_num_weights(handle, rnn_desc, x_desc, flat_buf_datatype); #else get_num_weights(handle, rnn_desc, flat_buf_datatype); #endif auto weight_buf = at::zeros(num_weights, flat_buf_options); #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor w_desc; w_desc.set(weight_buf, 3); #endif // Slice off views into weight_buf auto [params_arr, params_stride0] = get_parameters( #ifndef USE_CUDNN_RNN_V8_API handle, rnn, rnn_desc, x_desc, w_desc, weight_buf, include_bias); #else handle, rnn, rnn_desc, weight_buf, include_bias); #endif MatrixRef weight{weight_arr, static_cast(weight_stride0)}, params{params_arr, params_stride0}; // Copy weights _viewOrCopyParams(weight, params, /*copy=*/true, allow_type_change); if (set_orig_weights_to_flat_buf) { // Update the storage for (const auto i : c10::irange(weight.size(0))) { // There is a special case for LSTM with projections and no bias, // where weight copy is done in 0->0, 1->1, 2->4 layout if (weight[i].size() == 3 && params[i].size() == 5) { weight[i][0].set_(params[i][0].view_as(weight[i][0])); weight[i][1].set_(params[i][1].view_as(weight[i][1])); weight[i][2].set_(params[i][4].view_as(weight[i][2])); } else { for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin(); orig_param_it != weight[i].end() && new_param_it != params[i].end(); orig_param_it++, new_param_it++) { auto orig_param = *orig_param_it, new_param = *new_param_it; orig_param.set_(new_param.view_as(orig_param)); } } } } return std::make_tuple(weight_buf, params_arr); } } // namespace cudnn_rnn using namespace cudnn_rnn; // NB: does inplace update into TensorList // It would be a relatively simple matter to refactor this into multiple // functions, only one of which does an inplace update, but we leave this // for future work Tensor _cudnn_rnn_flatten_weight( TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, bool fn_bidirectional) { // returns flat weight_buf return std::get<0>(copy_weights_to_flat_buf_views( weight_arr, weight_stride0, input_size, fn_mode, fn_hidden_size, fn_proj_size, fn_num_layers, batch_first, fn_bidirectional, /*flat_buf_datatype=*/getCudnnDataType(weight_arr[0]), /*flat_buf_options=*/weight_arr[0].options(), /*set_orig_weights_to_flat_buf=*/true)); } const char* WEIGHT_FORMAT_WARN = "RNN module weights are not part of single contiguous " "chunk of memory. This means they need to be compacted " "at every call, possibly greatly increasing memory usage. 
" "To compact weights again call flatten_parameters()."; // NB: when fn_batch_sizes is empty, that means no batch sizes was specified std::tuple _cudnn_rnn( const Tensor& input_r, TensorList weight, int64_t weight_stride0, const std::optional& weight_buf_r_opt, const Tensor& hx, const std::optional& cx_opt, int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const std::optional& fn_dropout_state_opt) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_buf_r_maybe_owned = at::borrow_from_optional_tensor(weight_buf_r_opt); const Tensor& weight_buf_r = *weight_buf_r_maybe_owned; const Tensor& cx = c10::value_or_else(cx_opt, [] { return Tensor(); }); const Tensor& fn_dropout_state = c10::value_or_else(fn_dropout_state_opt, [] { return Tensor(); }); check_attributes(input_r, weight, {hx, cx}, /*check_dtype=*/true); auto input = input_r; auto weight_buf = weight_buf_r; if (!weight_buf.defined()) { TORCH_WARN(WEIGHT_FORMAT_WARN); } if (fn_dropout_state.defined()) { auto input_arg = TensorArg(input, "input", 1); auto dropout_state_arg = TensorArg(fn_dropout_state, "dropout_states", 15); checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg); } RNNParams fn; auto datatype = getCudnnDataType(input); #ifndef USE_CUDNN_RNN_V8_API fn.rnn.set( fn_mode, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); #else auto input_size = input_r.size(-1); auto packed = fn_batch_sizes.size() != 0; fn.rnn.set( fn_mode, input_size, packed, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); #endif fn.dropout.set(fn_train, fn_dropout, fn_dropout_state); fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first); // TODO: Set device to input if (fn.rnn.mode != CUDNN_LSTM) { TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN"); } // TODO: can batch_first be a wrapper around this function? 
auto is_input_packed = fn.tensors.batch_sizes.size() != 0; if (batch_first && !is_input_packed) { input = input.transpose(0, 1); } auto hidden_size = _hidden_size(fn.rnn, fn.tensors); auto cell_size = _cell_size(fn.rnn, fn.tensors); auto output_size = _output_size(fn.rnn, fn.tensors); TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous"); TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous"); auto x = input.contiguous(); auto output = at::empty(output_size, input.options()); auto hy = at::empty(hidden_size, hx.options()); Tensor cy; if (cx.defined()) { cy = at::empty(cell_size, cx.options()); } else { cy = at::empty( {0}, hx.options()); // NB: Not allowed to return undefined tensors } auto y = output; auto handle = getCudnnHandle(); cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, true); fn.rnn.set_algo(algo); RNNDescriptors descs(fn, handle, x, y, hx, cx); #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor w_desc; #endif if (!weight_buf.defined()) { #ifndef USE_CUDNN_RNN_V8_API auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype); #else auto num_weights = get_num_weights(handle, descs.rnn_desc, datatype); #endif weight_buf = at::empty(num_weights, x.options()); #ifndef USE_CUDNN_RNN_V8_API w_desc.set(weight_buf, 3); #endif weight_buf.zero_(); #ifndef USE_CUDNN_RNN_V8_API auto [params, params_stride0] = get_parameters( handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, weight_buf); #else auto [params, params_stride0] = get_parameters(handle, fn.rnn, descs.rnn_desc, weight_buf); #endif _copyParams( MatrixRef{weight, static_cast(weight_stride0)}, MatrixRef{params, params_stride0}); } else { #ifndef USE_CUDNN_RNN_V8_API w_desc.set(weight_buf, 3); #endif } TORCH_CHECK( !cx.defined() || cx.sizes().equals(cell_size), "Expected cell size ", IntArrayRef{cell_size}, ", got ", cx.sizes()); size_t workspace_size; #ifndef USE_CUDNN_RNN_V8_API auto x_descs_arr = descs.get_x_descs(); auto y_descs_arr = descs.get_y_descs(); #else auto& x_descs_arr = descs.x_descs; auto& y_descs_arr = descs.y_descs; #endif #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize( handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &workspace_size)); #endif Tensor workspace; Tensor reserve; // NB: Previously, the test was for fn.requires_grad, but we don't have // this information. Use 'train' as a proxy. if (fn_train) { size_t reserve_size; #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnGetRNNTrainingReserveSize( handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &reserve_size)); #else AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes( handle, descs.rnn_desc.desc(), CUDNN_FWD_MODE_TRAINING, x_descs_arr.desc(), &workspace_size, &reserve_size)); #endif workspace = at::empty(workspace_size, input.options().dtype(kByte)); reserve = at::empty(reserve_size, input.options().dtype(kByte)); #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnRNNForwardTraining( handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), x.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr, w_desc.desc(), weight_buf.data_ptr(), y_descs_arr.data(), y.data_ptr(), descs.hy_desc.desc(), hy.data_ptr(), descs.cy_desc.desc(), cy.defined() ? 
cy.data_ptr() : nullptr, workspace.data_ptr(), workspace.size(0), reserve.mutable_data_ptr(), reserve.size(0))); #else AT_CUDNN_CHECK(cudnnRNNForward( handle, descs.rnn_desc.desc(), CUDNN_FWD_MODE_TRAINING, nullptr, x_descs_arr.desc(), x.data_ptr(), y_descs_arr.desc(), y.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), hy.data_ptr(), descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr, cy.defined() ? cy.data_ptr() : nullptr, weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), workspace.size(0), workspace.data_ptr(), reserve.size(0), reserve.mutable_data_ptr())); #endif } else { // inference #ifdef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes( handle, descs.rnn_desc.desc(), CUDNN_FWD_MODE_INFERENCE, x_descs_arr.desc(), &workspace_size, NULL)); #endif workspace = at::empty(workspace_size, input.options().dtype(kByte)); reserve = at::empty({0}, input.options().dtype(kByte)); #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnRNNForwardInference( handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), x.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr, w_desc.desc(), weight_buf.data_ptr(), y_descs_arr.data(), y.data_ptr(), descs.hy_desc.desc(), hy.data_ptr(), descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr, workspace.data_ptr(), workspace.size(0))); #else AT_CUDNN_CHECK(cudnnRNNForward( handle, descs.rnn_desc.desc(), CUDNN_FWD_MODE_INFERENCE, nullptr, x_descs_arr.desc(), x.data_ptr(), y_descs_arr.desc(), y.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), hy.data_ptr(), descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr, cy.defined() ? cy.data_ptr() : nullptr, weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), workspace.size(0), workspace.data_ptr(), reserve.size(0), reserve.mutable_data_ptr())); #endif } if (batch_first && !is_input_packed) { output.transpose_(0, 1); } return std::make_tuple(output, hy, cy, reserve, weight_buf); } std::tuple _cudnn_rnn_backward_input( const Tensor& input_r, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx, const Tensor& output_r, const Tensor& grad_output_r, const Tensor& grad_hy, const Tensor& grad_cy, int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state, const Tensor& fn_reserve, std::array output_mask) { auto input = input_r; auto grad_output = grad_output_r; auto output = output_r; RNNParams fn; auto datatype = getCudnnDataType(input); #ifndef USE_CUDNN_RNN_V8_API fn.rnn.set( fn_mode, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); #else auto cudnn_input_size = input_r.size(-1); auto packed = fn_batch_sizes.size() != 0; fn.rnn.set( fn_mode, cudnn_input_size, packed, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); #endif fn.dropout.set(fn_train, fn_dropout, fn_dropout_state); fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first); // TODO: Set device to input auto handle = getCudnnHandle(); if (fn.rnn.mode != CUDNN_LSTM) { TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN"); } auto is_input_packed = fn_batch_sizes.size() != 0; if (batch_first && !is_input_packed) { input = input.transpose(0, 1); grad_output = grad_output.transpose(0, 1); output = output.transpose(0, 1); } auto 
input_size = _input_size(fn.tensors); auto hidden_size = _hidden_size(fn.rnn, fn.tensors); auto cell_size = _cell_size(fn.rnn, fn.tensors); auto output_size = _output_size(fn.rnn, fn.tensors); TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous"); TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous"); auto x = input.contiguous(); auto dy = grad_output.contiguous(); auto y = output; auto w = weight_buf; auto dx = at::empty( input.sizes(), input.options()); // TODO: more compact way of saying this auto dhy = grad_hy.contiguous().view(hidden_size); auto dcy = grad_cy.defined() ? grad_cy.contiguous().view(cell_size) : Tensor(); auto dhx = at::empty(hidden_size, hx.options()); TORCH_INTERNAL_ASSERT( cx.defined() || !output_mask[2], "illegally required grad of cx for non-LSTM RNN"); auto dcx = cx.defined() ? at::empty(cell_size, cx.options()) : Tensor(); TORCH_CHECK( fn_train, "cudnn RNN backward can only be called in training mode"); TORCH_CHECK( input.sizes().equals(input_size), "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes()); TORCH_CHECK( output.sizes().equals(output_size), "Expected output size ", IntArrayRef{output_size}, ", got ", output.sizes()); TORCH_CHECK( !hx.defined() || hx.sizes().equals(hidden_size), "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes()); TORCH_CHECK( !cx.defined() || cx.sizes().equals(cell_size), "Expected cell size ", IntArrayRef{cell_size}, ", got ", cx.sizes()); TORCH_CHECK( !dhy.defined() || dhy.sizes().equals(hidden_size), "Expected d_hidden size ", IntArrayRef{hidden_size}, ", got ", dhy.sizes()); TORCH_CHECK( !dcy.defined() || dcy.sizes().equals(cell_size), "Expected d_cell size ", IntArrayRef{cell_size}, ", got ", dcy.sizes()); TORCH_CHECK( dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()), "Gradients aren't CUDA tensors"); cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, false); fn.rnn.set_algo(algo); RNNDescriptors descs(fn, handle, x, y, hx, cx); #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor w_desc; w_desc.set(weight_buf, 3); #endif size_t workspace_size; #ifndef USE_CUDNN_RNN_V8_API auto x_descs_arr = descs.get_x_descs(); auto y_descs_arr = descs.get_y_descs(); AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize( handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &workspace_size)); #else auto& x_descs_arr = descs.x_descs; auto& y_descs_arr = descs.y_descs; AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes( handle, descs.rnn_desc.desc(), CUDNN_FWD_MODE_TRAINING, x_descs_arr.desc(), &workspace_size, NULL)); #endif // TODO: put this in the correct device??? Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte)); #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnRNNBackwardData( handle, descs.rnn_desc.desc(), fn.tensors.seq_length, y_descs_arr.data(), y.data_ptr(), y_descs_arr.data(), dy.data_ptr(), descs.hy_desc.desc(), dhy.data_ptr(), descs.cy_desc.desc(), cx.defined() ? dcy.data_ptr() : nullptr, w_desc.desc(), w.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr, x_descs_arr.data(), dx.data_ptr(), descs.hx_desc.desc(), dhx.data_ptr(), descs.cx_desc.desc(), cx.defined() ? 
dcx.data_ptr() : nullptr, workspace.data_ptr(), workspace.size(0), fn_reserve.data_ptr(), fn_reserve.size(0))); #else AT_CUDNN_CHECK(cudnnRNNBackwardData_v8( handle, descs.rnn_desc.desc(), nullptr, y_descs_arr.desc(), y.data_ptr(), dy.data_ptr(), x_descs_arr.desc(), dx.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), dhy.data_ptr(), dhx.data_ptr(), descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr, cx.defined() ? dcy.data_ptr() : nullptr, cx.defined() ? dcx.data_ptr() : nullptr, weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), workspace.size(0), workspace.data_ptr(), fn_reserve.size(0), fn_reserve.data_ptr())); #endif if (batch_first && !is_input_packed) { dx = dx.transpose_(0, 1); } return std::make_tuple(dx, dhx, dcx); } // NB: This MUST BE CALLED AFTER _cudnn_rnn_backward_input. // We'll give a user friendly combined function... std::vector _cudnn_rnn_backward_weight( // TODO: I think tensor geometry sufficient for weight_buf/weight const Tensor& input_r, TensorList weight_arr, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx, const Tensor& output_r, int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state, const Tensor& fn_reserve) { MatrixRef weight{weight_arr, static_cast(weight_stride0)}; auto input = input_r; auto output = output_r; RNNParams fn; auto datatype = getCudnnDataType(input); #ifndef USE_CUDNN_RNN_V8_API fn.rnn.set( fn_mode, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); #else auto cudnn_input_size = input_r.size(-1); auto packed = fn_batch_sizes.size() != 0; fn.rnn.set( fn_mode, cudnn_input_size, packed, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); #endif fn.dropout.set(fn_train, fn_dropout, fn_dropout_state); fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first); auto handle = getCudnnHandle(); if (fn.rnn.mode != CUDNN_LSTM) { TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN"); } auto is_input_packed = fn_batch_sizes.size() != 0; if (batch_first && !is_input_packed) { input = input.transpose(0, 1); output = output.transpose(0, 1); } auto input_size = _input_size(fn.tensors); auto hidden_size = _hidden_size(fn.rnn, fn.tensors); TORCH_CHECK( fn_train, "cudnn RNN backward can only be called in training mode"); TORCH_CHECK( input.sizes().equals(input_size), "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes()); TORCH_CHECK( !hx.defined() || hx.sizes().equals(hidden_size), "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes()); // TODO: the above were the only checks in rnn.py, but it doesn't seem // like these checks are enough TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous"); TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous"); auto x = input.contiguous(); const auto& y = output; auto dw = at::zeros(weight_buf.sizes(), weight_buf.options()); cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, false); fn.rnn.set_algo(algo); RNNDescriptors descs(fn, handle, x, y, hx, cx); #ifndef USE_CUDNN_RNN_V8_API FilterDescriptor w_desc; w_desc.set(weight_buf, 3); #endif size_t workspace_size; #ifndef USE_CUDNN_RNN_V8_API auto x_descs_arr = descs.get_x_descs(); auto y_descs_arr = descs.get_y_descs(); 
AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize( handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &workspace_size)); #else auto& x_descs_arr = descs.x_descs; auto& y_descs_arr = descs.y_descs; AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes( handle, descs.rnn_desc.desc(), CUDNN_FWD_MODE_TRAINING, x_descs_arr.desc(), &workspace_size, NULL)); #endif Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte)); #ifndef USE_CUDNN_RNN_V8_API AT_CUDNN_CHECK(cudnnRNNBackwardWeights( handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), x.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), y_descs_arr.data(), y.data_ptr(), workspace.data_ptr(), workspace.size(0), w_desc.desc(), dw.data_ptr(), fn_reserve.data_ptr(), fn_reserve.size(0))); #else AT_CUDNN_CHECK(cudnnRNNBackwardWeights_v8( handle, descs.rnn_desc.desc(), CUDNN_WGRAD_MODE_ADD, nullptr, x_descs_arr.desc(), x.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), y_descs_arr.desc(), y.data_ptr(), weight_buf.numel() * weight_buf.element_size(), dw.data_ptr(), workspace.size(0), workspace.data_ptr(), fn_reserve.size(0), fn_reserve.data_ptr())); #endif #ifndef USE_CUDNN_RNN_V8_API auto [grad_params_arr, grad_params_stride0] = get_parameters( handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, dw); #else auto [grad_params_arr, grad_params_stride0] = get_parameters(handle, fn.rnn, descs.rnn_desc, dw); #endif if (grad_params_stride0 == static_cast(weight_stride0)) { _viewParams( MatrixRef{grad_params_arr, grad_params_stride0}, MatrixRef{weight_arr, static_cast(weight_stride0)}); return grad_params_arr; } else { std::vector grad_weight_arr; grad_weight_arr.reserve(weight.numel()); for (const auto& w : weight_arr) { grad_weight_arr.emplace_back(at::empty(w.sizes(), w.options())); } _copyParams( MatrixRef{grad_params_arr, grad_params_stride0}, MatrixRef{ grad_weight_arr, static_cast(weight_stride0)}); return grad_weight_arr; } } // We need this dispatcher because _cudnn_rnn_backward_weight has a stringent // ordering requirement with _cudnn_rnn_backward_input std::tuple> _cudnn_rnn_backward( const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const std::optional& cx_opt, const Tensor& output, const std::optional& grad_output_r_opt, const std::optional& grad_hy_r_opt, const std::optional& grad_cy_r_opt, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const std::optional& dropout_state_opt, const Tensor& reserve, std::array output_mask) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt); const Tensor& cx = *cx_maybe_owned; const Tensor& grad_output_r = c10::value_or_else(grad_output_r_opt, [] { return Tensor(); }); const Tensor& grad_hy_r = c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); }); const Tensor& grad_cy_r = c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); }); const Tensor& dropout_state = c10::value_or_else(dropout_state_opt, [] { return Tensor(); }); if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) { return std::tuple>( Tensor(), Tensor(), Tensor(), std::vector(weight.size())); } auto grad_output = grad_output_r.defined() ? grad_output_r : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto grad_hy = grad_hy_r.defined() ? 
grad_hy_r
      : at::zeros_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_cy = cx.defined()
      ? (grad_cy_r.defined()
             ? grad_cy_r
             : at::zeros_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
      : grad_cy_r;

  // NB: unconditionally compute this gradient, because it mutates reserve
  auto [dx, dhx, dcx] = at::native::_cudnn_rnn_backward_input(
      input,
      weight_buf,
      hx,
      cx,
      output,
      grad_output,
      grad_hy,
      grad_cy,
      mode,
      hidden_size,
      proj_size,
      num_layers,
      batch_first,
      dropout,
      train,
      bidirectional,
      batch_sizes,
      dropout_state,
      reserve,
      {output_mask[0], output_mask[1], output_mask[2]});
  std::vector<Tensor> dw;
  if (output_mask[3]) {
    dw = at::native::_cudnn_rnn_backward_weight(
        input,
        weight,
        weight_stride0,
        weight_buf,
        hx,
        cx,
        output,
        mode,
        hidden_size,
        proj_size,
        num_layers,
        batch_first,
        dropout,
        train,
        bidirectional,
        batch_sizes,
        dropout_state,
        reserve);
  }
  return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>{
      dx, dhx, dcx, dw};
}

// TODO: I am not sure if we actually need the 'dropout' and 'train' parameters
// to initialize just the state tensor
//
// NB: You can have any color you like, as long as it's a CUDA byte
// tensor. Why does this function take a TensorOptions at all in that case?
// This is a factory function: it produces tensors but takes no tensors
// as input. The codegen currently assumes that ALL factory functions
// take TensorOptions, so it's just a lot easier for this function to
// be bound if it also does it.
Tensor _cudnn_init_dropout_state(
    double dropout,
    bool train,
    int64_t dropout_seed,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options =
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
          pin_memory);
  auto handle = getCudnnHandle();
  DropoutDescriptor dropout_desc;
  auto dropout_p = train ? dropout : 0;
  dropout_desc.initialize_rng(handle, dropout_p, dropout_seed, options);
  return dropout_desc.state;
}

////////////////////////////////////////////////////////////////////////////////
// CUDA dispatch for the generic RNN ops (at::lstm, at::gru, ...)
////////////////////////////////////////////////////////////////////////////////

namespace {

// Helpers for working with different hidden types.
std::tuple<Tensor, Tensor> unpack_hidden(const Tensor& hidden) {
  return std::make_tuple(hidden, at::Tensor{});
}

std::tuple<Tensor, Tensor> unpack_hidden(
    const std::tuple<Tensor, Tensor>& hidden) {
  return hidden;
}

template <typename hidden_type>
hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
  static_assert(
      false && sizeof(hidden_type),
      "pack_hidden not implemented for this type");
}

template <>
Tensor pack_hidden<Tensor>(const Tensor& hx, const Tensor& cx) {
  AT_ASSERT(cx.numel() == 0);
  return hx;
}

template <>
std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(
    const Tensor& hx, const Tensor& cx) {
  return std::make_tuple(hx, cx);
}

/**
 * Note [DropoutState and CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * (1) Telling a capturing stream to wait on an event recorded in a
 *     non-capturing stream is an error.
 * (2) Telling a non-capturing stream to wait on an event recorded during
 *     capture is also an error.
 *
 * So DropoutState's usage syncs could error if an RNN with dropout is called
 * in an uncaptured region then called in a captured region (triggering 1),
 * or called in a captured region then called in an uncaptured region
 * (triggering 2).
 *
 * To prevent 1 and 2, lock() only syncs on the last usage event if it was
 * recorded in the same capture state as the current state (which also means
 * the same graph, if capture is in progress).

/**
 * Note [DropoutState and CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * (1) Telling a capturing stream to wait on an event recorded in a
 *     non-capturing stream is an error.
 * (2) Telling a non-capturing stream to wait on an event recorded during
 *     capture is also an error.
 *
 * So DropoutState's usage syncs could error if an RNN with dropout is called
 * in an uncaptured region then called in a captured region (triggering 1),
 * or called in a captured region then called in an uncaptured region
 * (triggering 2).
 *
 * To prevent 1 and 2, lock() only syncs on the last usage event if it was
 * recorded in the same capture state as the current state (which also means
 * the same graph, if capture is in progress).
 *
 * The solution should be safe as long as capture obeys the following
 * restrictions:
 *  - Only one capture may be underway at a time in a given process.
 *  - While a capture is underway, no calls to eager ops on noncapturing
 *    streams (on any thread) may interleave with the captured ops.
 *
 * TODO: As people experiment with capture, keep an eye out for use cases
 * that might need to relax those restrictions.
 *
 * See https://github.com/pytorch/pytorch/pull/56433 for more discussion.
 */

struct DropoutState {
  // Both buffer and event are lazily instantiated when a dropout state is
  // needed for the first time. Note that in this case needed != used, as we
  // don't need a buffer to e.g. run RNNs in test mode.
  at::Tensor buffer;
  std::optional<cuda::CUDAEvent> event;
  std::mutex mutex;
#if !defined(USE_ROCM)
  // cudaStreamGetCaptureInfo will never give back a capture id of 0, so 0 can
  // serve as a sentinel value that capture was not underway.
  cuda::CaptureId_t capture_id_last_lock = 0;
  cuda::CaptureId_t capture_id_last_unlock = 0;
#endif

  // Every time we use a dropout state, we need to synchronize with its event,
  // to make sure all previous uses finish running before this one starts. Once
  // we're done, we record the event to allow others to synchronize with this
  // kernel. Those events are really needed only for inter-stream sync on a
  // single GPU. I doubt anyone will want to run cuDNN RNNs in parallel on a
  // single GPU, so they should end up being complete no-ops.
  void lock() {
    // NB: We can't ignore the lock even when event is undefined, because
    // someone could then define it before we get to unlock().
    mutex.lock();
    if (event) {
#if !defined(USE_ROCM)
      // See Note [DropoutState and CUDA graph capture]
      cudaStreamCaptureStatus status;
      AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
          cuda::getCurrentCUDAStream(), &status, &capture_id_last_lock));
      if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
        capture_id_last_lock = 0;
      }
      if (capture_id_last_lock == capture_id_last_unlock) {
        event->block(cuda::getCurrentCUDAStream());
      }
#else
      event->block(cuda::getCurrentCUDAStream());
#endif
    }
  }

  void unlock() {
    if (event) {
      event->record();
#if !defined(USE_ROCM)
      // See Note [DropoutState and CUDA graph capture]
      cudaStreamCaptureStatus status;
      AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
          cuda::getCurrentCUDAStream(), &status, &capture_id_last_unlock));
      if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
        capture_id_last_unlock = 0;
      }
      TORCH_INTERNAL_ASSERT(capture_id_last_unlock == capture_id_last_lock);
#endif
    }
    mutex.unlock();
  }
};
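
// Illustrative sketch of the intended lock()/unlock() protocol; it mirrors
// what _cudnn_impl below does. `run_rnn_kernel` is a hypothetical stand-in
// for the actual at::_cudnn_rnn_symint call, and the snippet is commented
// out rather than compiled here.
//
//   auto& state = get_dropout_state(dropout_p, train, input.options());
//   std::unique_lock<DropoutState> lock{state}; // constructor calls state.lock()
//   run_rnn_kernel(/*...,*/ state.buffer);      // may consume the dropout state
//   // When `lock` goes out of scope, state.unlock() records the event so the
//   // next user of this device's dropout state can synchronize with us.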

DropoutState& get_dropout_state(
    double dropout_p,
    bool train,
    TensorOptions options) {
  // Each state is slightly over 2MB and initialized lazily, so it's fine to
  // cache them.
  static std::vector<DropoutState> dropout_state_cache{
      static_cast<size_t>(cuda::getNumGPUs())};
  static std::mutex state_cache_mut;

  AT_ASSERT(options.device().is_cuda());
  auto device = options.device().index();

  std::unique_lock<std::mutex> lock{state_cache_mut};
  auto& state = dropout_state_cache.at(device);
  if (train && dropout_p > 0) {
    const auto& gen =
        at::detail::getCUDAHooks().getDefaultCUDAGenerator(device);
    auto gen_impl = gen.get<at::CUDAGeneratorImpl>();
    bool reset_rnn_state = gen_impl->reset_rnn_state();
    if (!state.buffer.defined() || reset_rnn_state) {
      std::unique_lock<std::mutex> lock{state.mutex};
      int64_t seed =
          at::empty({}, options.dtype(at::kLong)).random_(gen).item<int64_t>();
      state.buffer = at::_cudnn_init_dropout_state(
          dropout_p, train, seed, options.dtype(at::kByte));
      // NB: CUDA binds the event to a device at creation time, so we can
      // initialize it only now, when we know we're on the correct device.
      if (!state.event.has_value()) {
        state.event.emplace();
      }
    }
  }
  return state;
}
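
// Illustrative sketch of what the per-device cache above hands back. The
// dropout probability and options are hypothetical; this is commented out
// and not compiled.
//
//   auto& st = get_dropout_state(
//       /*dropout_p=*/0.5, /*train=*/true,
//       at::TensorOptions().device(at::kCUDA));
//   // st.buffer is a 1-D kByte CUDA tensor produced by
//   // at::_cudnn_init_dropout_state(...); it is reused for every dropout RNN
//   // call on this device until the default CUDA generator requests an
//   // RNN-state reset (reset_rnn_state() above), at which point it is rebuilt
//   // with a freshly drawn seed.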

Tensor try_get_weight_buf(
    const Tensor& input,
    TensorList parameters,
    bool has_biases,
    cudnnRNNMode_t mode,
    c10::SymInt hidden_size,
    c10::SymInt proj_size,
    int64_t num_layers,
    bool bidirectional) {
  // Prepare all relevant descriptors
  auto handle = getCudnnHandle();
  auto& any_param = parameters.at(0);
  auto datatype = getCudnnDataType(any_param);

  // Something very naughty is happening here. try_get_weight_buf
  // is called from _cudnn_impl, which is a *composite*. In other words,
  // inside the composite function we need to query cudnn to figure out how big
  // the weight buf actually is going to be. This clearly cannot be done
  // symbolically. For now, we insert guards here; but once we have the black
  // box handling for dynamic shapes, we could also hypothetically infer out
  // the relationships
  RNNDescriptorParams rnn;
#ifndef USE_CUDNN_RNN_V8_API
  rnn.set(
      mode,
      hidden_size.guard_int(__FILE__, __LINE__),
      proj_size.guard_int(__FILE__, __LINE__),
      num_layers,
      bidirectional,
      promote_rnn_math_type(datatype),
      datatype);
#else
  auto cudnn_input_size = input.size(-1);
  auto packed = false; // eqy: bogus as we do not know if the input is packed
                       // here again, it should also not affect the weights
  rnn.set(
      mode,
      cudnn_input_size,
      packed,
      hidden_size.guard_int(__FILE__, __LINE__),
      proj_size.guard_int(__FILE__, __LINE__),
      num_layers,
      bidirectional,
      promote_rnn_math_type(datatype),
      datatype);
#endif
  RNNDescriptor rnn_desc = rnn.descriptor(handle);

  TensorGeometry x_geom({1, input.sym_size(-1).guard_int(__FILE__, __LINE__)});
  TensorDescriptor x_desc;
  // datatype for x_desc comes from any_param, not input.
  // try_get_weight_buf's job is to check "is the weight buffer correctly laid
  // out for us to run it with input of the same datatype?"
  x_desc.set(datatype, x_geom.sizes(), x_geom.strides(), 5);

#ifndef USE_CUDNN_RNN_V8_API
  auto num_params = get_num_weights(handle, rnn_desc, x_desc, datatype);
#else
  auto num_params = get_num_weights(handle, rnn_desc, datatype);
#endif

  // Try to get parameter storage
  auto param_storage = any_param.storage();
  auto weight_buf = at::empty({0}, any_param.options()).set_(param_storage);
  if (weight_buf.size(0) < num_params) {
    return {};
  } else if (weight_buf.size(0) > num_params) {
    weight_buf = weight_buf.narrow(0, 0, num_params);
  }

  // Get and check data pointers
  auto expected_data_ptrs = get_expected_data_ptrs(
      weight_buf, handle, rnn, rnn_desc, x_desc, datatype);

  int64_t num_parameters = parameters.size();
  int64_t num_ptrs = expected_data_ptrs.size();
  if (proj_size != 0) {
    AT_ASSERT(num_parameters % (has_biases ? 5 : 3) == 0);
    AT_ASSERT(num_ptrs % 5 == 0);
    if (has_biases) {
      AT_ASSERT(num_ptrs == num_parameters);
      for (const auto i : c10::irange(num_parameters)) {
        if (expected_data_ptrs[i] != parameters[i].data_ptr())
          return {};
      }
    } else {
      AT_ASSERT(num_parameters % 3 == 0);
      AT_ASSERT(num_ptrs == num_parameters * 5 / 3);
      for (int64_t param_i = 0, ptr_i = 0; ptr_i < num_ptrs;
           ptr_i += 5, param_i += 3) {
        if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr())
          return {};
        if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr())
          return {};
        if (expected_data_ptrs[ptr_i + 4] != parameters[param_i + 2].data_ptr())
          return {};
      }
    }
  } else {
    AT_ASSERT(num_ptrs == (num_parameters * (has_biases ? 1 : 2)));
    AT_ASSERT(num_parameters % (has_biases ? 4 : 2) == 0);
    for (int64_t param_i = 0, ptr_i = 0; ptr_i < num_ptrs;
         ptr_i += (has_biases ? 2 : 4), param_i += 2) {
      if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr())
        return {};
      if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr())
        return {};
    }
  }
  if (!parameters[num_parameters - 1].is_contiguous())
    return {};
  return weight_buf;
}
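
// Illustrative sketch of the per-layer/per-direction parameter counts the
// checks above and the _cudnn_impl functions below expect. The weight names
// follow the usual PyTorch RNN convention and are given only for orientation.
//
//   int64_t num_params = has_biases ? 4 : 2; // e.g. w_ih, w_hh, b_ih, b_hh
//   if (proj_size != 0) {
//     ++num_params;                          // + w_hr for projected LSTMs
//   }
//   // So with biases: 4 tensors per layer/direction (5 with projections);
//   // without biases: 2 (3 with projections), matching the modulus asserts.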

template <typename hidden_type>
std::pair<Tensor, hidden_type> _cudnn_impl(
    const Tensor& input,
    const Tensor& _batch_sizes,
    const hidden_type& hidden,
    TensorList params,
    bool has_biases,
    cudnnRNNMode_t mode,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional) {
  auto [hx, cx] = unpack_hidden(hidden);
  auto hidden_size = hx.sym_size(2);
  SymInt proj_size = 0;
  // For LSTM models with projections hidden size could be different
  if (cx.defined() && cx.sym_size(2) != hx.sym_size(2)) {
    hidden_size = cx.sym_size(2);
    proj_size = hx.sym_size(2);
  }

  // TODO: try_get_weight_buf returns a Tensor, but _cudnn_rnn below takes a
  // std::optional<Tensor> in weight_buf's slot. Do we want try_get_weight_buf
  // to return a std::optional<Tensor> instead of a defined or undefined
  // Tensor?
  at::cuda::OptionalCUDAGuard guard(input.get_device());
  auto weight_buf = try_get_weight_buf(
      input,
      params,
      has_biases,
      mode,
      hidden_size,
      proj_size,
      num_layers,
      bidirectional);

  TORCH_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D");
  IntArrayRef batch_sizes{
      _batch_sizes.data_ptr<int64_t>(),
      static_cast<size_t>(_batch_sizes.size(0))};

  auto& dropout_state = get_dropout_state(dropout_p, train, input.options());
  std::unique_lock<DropoutState> lock{dropout_state};
  int64_t num_params = has_biases ? 4 : 2;
  if (proj_size != 0) {
    ++num_params;
  }
  auto sym_batch_sizes = c10::SymIntArrayRef(
      reinterpret_cast<const c10::SymInt*>(batch_sizes.data()),
      batch_sizes.size());
  // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
  auto cudnn_output = at::_cudnn_rnn_symint(
      input,
      params,
      num_params,
      weight_buf,
      hx,
      cx,
      static_cast<int>(mode),
      hidden_size,
      proj_size,
      num_layers,
      /*batch_first=*/false,
      dropout_p,
      train,
      bidirectional,
      sym_batch_sizes,
      dropout_state.buffer);

  return {
      std::get<0>(cudnn_output),
      pack_hidden<hidden_type>(
          std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
}

template <typename hidden_type>
std::pair<Tensor, hidden_type> _cudnn_impl(
    const Tensor& input,
    const hidden_type& hidden,
    TensorList params,
    bool has_biases,
    cudnnRNNMode_t mode,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    bool batch_first) {
  auto [hx, cx] = unpack_hidden(hidden);
  auto hidden_size = hx.sym_size(2);
  c10::SymInt proj_size = 0;
  // For LSTM models with projections hidden size could be different
  if (cx.defined() && cx.sym_size(2) != hx.sym_size(2)) {
    hidden_size = cx.sym_size(2);
    proj_size = hx.sym_size(2);
  }
  at::cuda::OptionalCUDAGuard guard(input.get_device());
  auto weight_buf = try_get_weight_buf(
      input,
      params,
      has_biases,
      mode,
      hidden_size,
      proj_size,
      num_layers,
      bidirectional);
  auto& dropout_state = get_dropout_state(dropout_p, train, input.options());
  std::unique_lock<DropoutState> lock{dropout_state};
  int64_t num_params = has_biases ? 4 : 2;
  if (proj_size != 0) {
    ++num_params;
  }
  // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
  auto cudnn_output = at::_cudnn_rnn_symint(
      input,
      params,
      num_params,
      weight_buf,
      hx,
      cx,
      static_cast<int>(mode),
      hidden_size,
      proj_size,
      num_layers,
      batch_first,
      dropout_p,
      train,
      bidirectional,
      /*batch_sizes=*/{},
      dropout_state.buffer);

  return {
      std::get<0>(cudnn_output),
      pack_hidden<hidden_type>(
          std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
}

#define ONE_HIDDEN_RNN(NAME, MODE)                          \
  void NAME##_cudnn(                                        \
      Tensor& output,                                       \
      Tensor& hy,                                           \
      const Tensor& input,                                  \
      const Tensor& hx,                                     \
      TensorList params,                                    \
      bool has_biases,                                      \
      int64_t num_layers,                                   \
      double dropout_p,                                     \
      bool train,                                           \
      bool bidirectional,                                   \
      bool batch_first) {                                   \
    std::tie(output, hy) = _cudnn_impl(                     \
        input,                                              \
        hx,                                                 \
        params,                                             \
        has_biases,                                         \
        MODE,                                               \
        num_layers,                                         \
        dropout_p,                                          \
        train,                                              \
        bidirectional,                                      \
        batch_first);                                       \
  }                                                         \
                                                            \
  void NAME##_packed_cudnn(                                 \
      Tensor& output,                                       \
      Tensor& hy,                                           \
      const Tensor& data,                                   \
      const Tensor& batch_sizes,                            \
      const Tensor& hx,                                     \
      TensorList params,                                    \
      bool has_biases,                                      \
      int64_t num_layers,                                   \
      double dropout_p,                                     \
      bool train,                                           \
      bool bidirectional) {                                 \
    std::tie(output, hy) = _cudnn_impl(                     \
        data,                                               \
        batch_sizes,                                        \
        hx,                                                 \
        params,                                             \
        has_biases,                                         \
        MODE,                                               \
        num_layers,                                         \
        dropout_p,                                          \
        train,                                              \
        bidirectional);                                     \
  }                                                         \
                                                            \
  REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn); \
  REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn);

ONE_HIDDEN_RNN(gru, CUDNN_GRU)
ONE_HIDDEN_RNN(rnn_tanh, CUDNN_RNN_TANH)
ONE_HIDDEN_RNN(rnn_relu, CUDNN_RNN_RELU)

void lstm_cudnn(
    Tensor& output,
    Tensor& hy,
    Tensor& cy,
    const Tensor& input,
    TensorList hx,
    TensorList params,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    bool batch_first) {
  auto result = _cudnn_impl(
      input,
      std::make_tuple(hx[0], hx[1]),
      params,
      has_biases,
      CUDNN_LSTM,
      num_layers,
      dropout_p,
      train,
      bidirectional,
      batch_first);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}
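
// Illustrative sketch: ONE_HIDDEN_RNN(gru, CUDNN_GRU) above expands to a pair
// of thin wrappers roughly equivalent to the following (hand-expanded here
// only for readability; the macro is the source of truth):
//
//   void gru_cudnn(
//       Tensor& output, Tensor& hy, const Tensor& input, const Tensor& hx,
//       TensorList params, bool has_biases, int64_t num_layers,
//       double dropout_p, bool train, bool bidirectional, bool batch_first) {
//     std::tie(output, hy) = _cudnn_impl(
//         input, hx, params, has_biases, CUDNN_GRU, num_layers, dropout_p,
//         train, bidirectional, batch_first);
//   }
//   // ...plus gru_packed_cudnn and the two REGISTER_CUDA_DISPATCH lines.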

void lstm_packed_cudnn(
    Tensor& output,
    Tensor& hy,
    Tensor& cy,
    const Tensor& data,
    const Tensor& batch_sizes,
    TensorList hx,
    TensorList params,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional) {
  auto result = _cudnn_impl(
      data,
      batch_sizes,
      std::make_tuple(hx[0], hx[1]),
      params,
      has_biases,
      CUDNN_LSTM,
      num_layers,
      dropout_p,
      train,
      bidirectional);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}

REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn);
REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn);

} // namespace

} // namespace native
} // namespace at

#endif // AT_CUDNN_ENABLED()
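
// Illustrative end-to-end sketch (C++ frontend). Assuming a CUDA build with
// cuDNN enabled and inputs that pass cudnn_is_acceptable, at::lstm reaches
// lstm_cudnn via the lstm_cudnn_stub registered above. The sizes and
// `flat_weights` list are hypothetical stand-ins; this is commented out and
// not compiled.
//
//   auto x  = at::randn({seq_len, batch, input_size}, at::kCUDA);
//   auto h0 = at::zeros({num_layers, batch, hidden_size}, at::kCUDA);
//   auto c0 = at::zeros({num_layers, batch, hidden_size}, at::kCUDA);
//   auto [y, hn, cn] = at::lstm(
//       x, {h0, c0}, flat_weights, /*has_biases=*/true, num_layers,
//       /*dropout=*/0.0, /*train=*/false, /*bidirectional=*/false,
//       /*batch_first=*/false);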