#define TORCH_ASSERT_NO_OPERATORS

#include <ATen/Dispatch.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/Math.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>

namespace at::native {
namespace {
CONSTEXPR_EXCEPT_WIN_CUDA char hermite_polynomial_h_name[] = "hermite_polynomial_h_forward";

void hermite_polynomial_h_kernel_cuda(TensorIteratorBase& iterator) {
#if AT_USE_JITERATOR()
  // Jiterator path: compile the kernel at runtime from the CUDA source string
  // for hermite_polynomial_h_forward.
  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "hermite_polynomial_h_cuda", [&]() {
    opmath_jitted_gpu_kernel_with_scalars<hermite_polynomial_h_name, scalar_t, scalar_t>(iterator, hermite_polynomial_h_string);
  });
#else
  // Fallback path: precompiled device lambda calling hermite_polynomial_h_forward directly.
  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "hermite_polynomial_h_cuda", [&]() {
    gpu_kernel_with_scalars(iterator, []GPU_LAMBDA(scalar_t x, scalar_t n) -> scalar_t {
      return hermite_polynomial_h_forward<scalar_t, true>(x, n);
    });
  });
#endif
} // hermite_polynomial_h_kernel_cuda
} // namespace (anonymous)

REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel_cuda);
} // namespace at::native
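For reference, `hermite_polynomial_h_forward` (provided by the included math headers, with `hermite_polynomial_h_string` supplying the jiterator source) evaluates the physicists' Hermite polynomial H_n(x). The sketch below is an illustrative host-side reference for the underlying three-term recurrence, not the ATen implementation; the name `hermite_polynomial_h_reference` and the standalone `main` are assumptions for demonstration only.

#include <cstdint>
#include <cstdio>

// Illustrative sketch (not the ATen implementation): physicists' Hermite
// polynomial H_n(x) via the recurrence
//   H_0(x) = 1,  H_1(x) = 2x,  H_{k+1}(x) = 2x * H_k(x) - 2k * H_{k-1}(x).
template <typename T>
T hermite_polynomial_h_reference(T x, int64_t n) {
  if (n < 0) {
    return T(0);  // this sketch treats a negative degree as 0
  }
  if (n == 0) {
    return T(1);  // H_0(x) = 1
  }
  T p = T(1);     // H_{k-1}(x)
  T q = x + x;    // H_k(x), starting from H_1(x) = 2x
  for (int64_t k = 1; k < n; ++k) {
    T r = (x + x) * q - T(2 * k) * p;  // H_{k+1}(x)
    p = q;
    q = r;
  }
  return q;
}

int main() {
  // H_3(x) = 8x^3 - 12x, so H_3(2) = 64 - 24 = 40.
  std::printf("%f\n", hermite_polynomial_h_reference(2.0, 3));
  return 0;
}

The call in main prints 40.000000, matching H_3(2) evaluated directly from the closed form.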