#define TORCH_ASSERT_NO_OPERATORS

#include <ATen/native/UnaryOps.h>

#include <limits>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>

namespace at::native {
namespace {

// Name of the device function the jiterator JIT-compiles from
// modified_bessel_k0_string (declared in ATen/native/cuda/Math.cuh).
CONSTEXPR_EXCEPT_WIN_CUDA char modified_bessel_k0_name[] = "modified_bessel_k0_forward";

// Elementwise CUDA kernel for the special_modified_bessel_k0 stub
// (backs torch.special.modified_bessel_k0).
//
// Dispatches on the iterator's common dtype over the floating types
// (float, double). Two build configurations:
//   - With the jiterator (AT_USE_JITERATOR()): the kernel is JIT-compiled
//     at first use from the modified_bessel_k0_string source, keyed by
//     modified_bessel_k0_name — this keeps the precompiled binary smaller.
//   - Without it: a precompiled gpu_kernel applies
//     modified_bessel_k0_forward(a) to each element.
// Both paths compute the same function; `iterator` supplies the input and
// output tensors (1 input, as the jitted_gpu_kernel arity argument states).
void modified_bessel_k0_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cuda", [&]() {
    jitted_gpu_kernel<modified_bessel_k0_name, scalar_t, scalar_t, 1>(iterator, modified_bessel_k0_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cuda", [&]() {
    gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return modified_bessel_k0_forward(a);
    });
  });
#endif // AT_USE_JITERATOR()
}

} // anonymous namespace

REGISTER_DISPATCH(special_modified_bessel_k0_stub, &modified_bessel_k0_kernel_cuda);

} // namespace at::native