• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/TensorUtils.h>
5 #include <tuple>
6 
7 namespace at::native {
8 
9 TORCH_API std::tuple<Tensor, Tensor, int64_t> softmax_sparse_input_preprocessing(
10     const Tensor& input_,
11     const int64_t dim_,
12     const bool half_to_float,
13     CheckedFrom function_name);
14 
15 TORCH_API std::tuple<Tensor, Tensor, Tensor, int64_t> softmax_backward_sparse_input_preprocessing(
16     const Tensor& grad_,
17     const Tensor& output_,
18     int64_t dim_,
19     const Tensor& input_,
20     CheckedFrom function_name);
21 
22 } // namespace at::native
23