#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <limits>

namespace at::native {

#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char sinh_name[] = "sinh_impl";
#endif

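// CUDA implementation of the sinh kernel: dispatches on the iterator's common
// dtype, using a jiterator-compiled kernel for complex inputs when available
// and precompiled gpu_kernel paths otherwise.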
void sinh_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
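    // Complex inputs: JIT-compile the kernel from the stringified sinh_impl
    // via the jiterator (this also covers c10::complex<at::Half>).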
    static const auto sinh_string = jiterator_stringify(
        template <typename T> T sinh_impl(T a) { return std::sinh(a); });
    AT_DISPATCH_COMPLEX_TYPES_AND(
        kComplexHalf, common_dtype, "sinh_name", [&]() {
          jitted_gpu_kernel<
              /*name=*/sinh_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(iter, sinh_string);
        });
#else
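    // Jiterator disabled: fall back to a precompiled kernel, computing in
    // opmath precision (e.g. c10::complex<float> for c10::complex<at::Half>).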
    AT_DISPATCH_COMPLEX_TYPES_AND(
        kComplexHalf, common_dtype, "sinh_name", [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            using opmath_t = at::opmath_type<scalar_t>;
            return ::sinh(static_cast<opmath_t>(a));
          });
        });
#endif
  } else {
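    // Real floating types (including Half and BFloat16): call ::sinh directly
    // from a precompiled elementwise kernel.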
    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        common_dtype,
        "sinh_cuda",
        [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            return ::sinh(a);
          });
        });
  }
}

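// Register this kernel as the CUDA implementation of the sinh dispatch stub.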
REGISTER_DISPATCH(sinh_stub, &sinh_kernel_cuda);

} // namespace at::native