• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/CUDAGeneratorImpl.h>
3 #include <ATen/native/UnaryOps.h>
4 #include <ATen/native/cuda/DistributionTemplates.h>
5 
6 namespace at::native {
7 
// CUDA entry point for random_from_to_stub.
//
// Resolves which generator to use: the CUDAGeneratorImpl held by `gen_`, or
// the default CUDA generator when `gen_` does not provide one. Then it hands
// `iter`, `range`, `base` and that generator to
// templates::cuda::random_from_to_kernel, which does the actual sampling.
// NOTE(review): presumably the drawn values lie in [base, base + range) —
// confirm against the templates implementation.
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen_) {
  namespace cuda_templates = at::native::templates::cuda;
  auto* const resolved_gen =
      get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  cuda_templates::random_from_to_kernel(iter, range, base, resolved_gen);
}
12 
// CUDA entry point for random_full_64_bits_range_stub.
//
// Picks the CUDAGeneratorImpl from `gen_`, falling back to the default CUDA
// generator, and delegates all work to
// templates::cuda::random_full_64_bits_range_kernel.
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, std::optional<Generator> gen_) {
  namespace cuda_templates = at::native::templates::cuda;
  auto* const resolved_gen =
      get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  cuda_templates::random_full_64_bits_range_kernel(iter, resolved_gen);
}
17 
// CUDA entry point for random_stub.
//
// Uses the CUDAGeneratorImpl supplied through `gen_`, or the default CUDA
// generator otherwise, and forwards to templates::cuda::random_kernel.
void random_kernel(TensorIteratorBase& iter, std::optional<Generator> gen_) {
  namespace cuda_templates = at::native::templates::cuda;
  auto* const resolved_gen =
      get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  cuda_templates::random_kernel(iter, resolved_gen);
}
22 
// Bind each wrapper defined above to its dispatch stub. The entries are
// listed in the same order as the wrapper definitions. Each one targets a
// distinct stub, so their relative order carries no meaning.
REGISTER_DISPATCH(random_from_to_stub, &random_from_to_kernel);
REGISTER_DISPATCH(random_full_64_bits_range_stub, &random_full_64_bits_range_kernel);
REGISTER_DISPATCH(random_stub, &random_kernel);
26 
27 } // namespace at::native
28