1 #define TORCH_ASSERT_NO_OPERATORS
2 
3 #include <ATen/native/UnaryOps.h>
4 
5 #include <limits>
6 
7 #include <ATen/AccumulateType.h>
8 #include <ATen/Dispatch.h>
9 #include <ATen/native/DispatchStub.h>
10 #include <ATen/native/Math.h>
11 #include <ATen/native/TensorIterator.h>
12 #include <ATen/native/cuda/JitLoops.cuh>
13 #include <ATen/native/cuda/Loops.cuh>
14 #include <ATen/native/cuda/Math.cuh>
15 #include <ATen/native/cuda/jit_utils.h>
16 #include <ATen/NumericUtils.h>
17 #include <c10/core/Scalar.h>
18 #include <c10/cuda/CUDAMathCompat.h>
19 #include <c10/util/complex.h>
20 
21 namespace at::native {
namespace {
// Kernel name handed to the jiterator; must match the device function name
// referenced by airy_ai_string (presumably provided by one of the included
// CUDA math headers — confirm against ATen/native/cuda/Math.cuh).
CONSTEXPR_EXCEPT_WIN_CUDA char airy_ai_name[] = "airy_ai_forward";

// Elementwise CUDA evaluation of the Airy function Ai over `iterator`.
// Dispatch covers float and double only (AT_DISPATCH_FLOATING_TYPES does
// not include half/bfloat16). When the jiterator is available the kernel
// is JIT-compiled at runtime from airy_ai_string; otherwise a precompiled
// gpu_kernel invoking airy_ai_forward() per element is used.
void airy_ai_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "airy_ai_cuda", [&]() {
        // Template args: <name, result_t, input_t, arity = 1 input tensor>.
        jitted_gpu_kernel<airy_ai_name, scalar_t, scalar_t, 1>(iterator, airy_ai_string);
    });
#else
    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "airy_ai_cuda", [&]() {
        gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            return airy_ai_forward(a);
        });
    });
#endif // AT_USE_JITERATOR()
}

} // anonymous namespace
40 
// Wire the CUDA kernel into the special_airy_ai dispatch stub.
REGISTER_DISPATCH(special_airy_ai_stub, &airy_ai_kernel_cuda);
42 } // namespace at::native
43