#define TORCH_ASSERT_NO_OPERATORS

#include <ATen/native/UnaryOps.h>

#include <limits>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>

namespace at::native {
namespace {

// Kernel name the jiterator uses to cache and look up the generated function.
CONSTEXPR_EXCEPT_WIN_CUDA char modified_bessel_i1_name[] = "modified_bessel_i1_forward";

void modified_bessel_i1_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
  // Jiterator path: the elementwise kernel is JIT-compiled on first use from
  // the CUDA source in modified_bessel_i1_string (ATen/native/cuda/Math.cuh).
  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cuda", [&]() {
    jitted_gpu_kernel<modified_bessel_i1_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_i1_string);
  });
#else
  // Fallback path: a precompiled elementwise kernel that evaluates
  // modified_bessel_i1_forward for each input element.
  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cuda", [&]() {
    gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return modified_bessel_i1_forward(a);
    });
  });
#endif // AT_USE_JITERATOR()
}

} // anonymous namespace
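// How this hooks into the dispatcher (a simplified sketch; the real
// DispatchStub machinery also handles per-CPU-architecture dispatch).
// The stub is declared in ATen/native/UnaryOps.h along the lines of:
//
//   using unary_fn = void (*)(TensorIteratorBase&);
//   DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
//
// and REGISTER_DISPATCH below fills the CUDA slot of that table with
// modified_bessel_i1_kernel_cuda.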

REGISTER_DISPATCH(special_modified_bessel_i1_stub, &modified_bessel_i1_kernel_cuda);
} // namespace at::native
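
// Usage sketch (illustrative, not part of this file): once the stub is
// registered, the kernel is reached through the public ATen operator, e.g.
//
//   #include <ATen/ATen.h>
//   at::Tensor x = at::rand({1024}, at::kCUDA);
//   at::Tensor y = at::special_modified_bessel_i1(x);
//
// or from Python via torch.special.modified_bessel_i1.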