1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/CUDAGeneratorImpl.h>
3 #include <ATen/native/UnaryOps.h>
4 #include <ATen/native/cuda/DistributionTemplates.h>
5
namespace at::native {

namespace {

// Chooses the CUDA generator a random kernel should draw from.
// get_generator_or_default<CUDAGeneratorImpl> picks between the caller's
// `gen_` and the default CUDA generator. The default generator is looked up
// unconditionally because it is passed in as the fallback argument.
CUDAGeneratorImpl* pick_cuda_generator(const std::optional<Generator>& gen_) {
  const Generator& fallback = cuda::detail::getDefaultCUDAGenerator();
  return get_generator_or_default<CUDAGeneratorImpl>(gen_, fallback);
}

} // namespace

// CUDA entry point for random_from_to_stub.
// Forwards `range` and `base` unchanged to the shared template
// at::native::templates::cuda::random_from_to_kernel.
// NOTE(review): the names suggest outputs are drawn from
// [base, base + range); confirm against the template.
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen_) {
  CUDAGeneratorImpl* generator = pick_cuda_generator(gen_);
  at::native::templates::cuda::random_from_to_kernel(iter, range, base, generator);
}

// CUDA entry point for random_full_64_bits_range_stub.
// Delegates to at::native::templates::cuda::random_full_64_bits_range_kernel.
// NOTE(review): the name implies sampling over the full 64-bit range;
// the exact semantics live in the template, not here.
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, std::optional<Generator> gen_) {
  CUDAGeneratorImpl* generator = pick_cuda_generator(gen_);
  at::native::templates::cuda::random_full_64_bits_range_kernel(iter, generator);
}

// CUDA entry point for random_stub (no explicit bounds).
// Delegates to at::native::templates::cuda::random_kernel, which decides
// the sampled range.
void random_kernel(TensorIteratorBase& iter, std::optional<Generator> gen_) {
  CUDAGeneratorImpl* generator = pick_cuda_generator(gen_);
  at::native::templates::cuda::random_kernel(iter, generator);
}

// Bind the CUDA implementations above to their dispatch stubs.
// The stubs are declared elsewhere, presumably via ATen/native/UnaryOps.h.
REGISTER_DISPATCH(random_from_to_stub, &random_from_to_kernel);
REGISTER_DISPATCH(random_stub, &random_kernel);
REGISTER_DISPATCH(random_full_64_bits_range_stub, &random_full_64_bits_range_kernel);

} // namespace at::native
28