• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/CUDAConfig.h>
3 #include <ATen/cuda/cub.cuh>
4 
namespace at::cuda::cub {

// Sorts the `n` keys read from `keys_in` and writes the sorted sequence to
// `keys_out` using CUB's DeviceRadixSort (SortKeys, or SortKeysDescending
// when `descending` is true). The work is issued on the current CUDA stream
// (c10::cuda::getCurrentCUDAStream()) through the CUB_WRAPPER macro.
//
// Parameters:
//   keys_in    - input keys; presumably a device pointer to at least `n`
//                elements (TODO confirm against callers).
//   keys_out   - destination for the sorted keys; presumably a device buffer
//                of at least `n` elements distinct from `keys_in`
//                (NOTE(review): confirm CUB's aliasing rules for this overload).
//   n          - number of keys to sort; must lie in [0, INT_MAX].
//   descending - true selects SortKeysDescending, false selects SortKeys.
//   begin_bit, end_bit - forwarded unchanged to CUB as the [begin_bit, end_bit)
//                range of key bits used for the comparison.
//
// Throws (via TORCH_CHECK) if `n` is negative or exceeds INT_MAX.
template <typename key_t>
void radix_sort_keys(
    const key_t* keys_in,
    key_t* keys_out,
    int64_t n,
    bool descending,
    int64_t begin_bit,
    int64_t end_bit) {
  // `n` is a 64-bit signed count, so both ends of the range must be checked.
  // Only values in [0, INT_MAX] are safe to hand to CUB.
  TORCH_CHECK(
      n >= 0,
      "cub sort expects a non-negative number of elements, got ",
      n);
  TORCH_CHECK(
      n <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");
  if (n == 0) {
    // Nothing to sort: skip the CUB call entirely.
    return;
  }

  // Map the ATen key type to the type CUB should see (for example, presumably
  // c10::Half -> __half; see detail::cuda_type). The two types are
  // reinterpreted in place, so they are expected to share a layout.
  using key_t_ = typename detail::cuda_type<key_t>::type;

  const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  // The two CUB entry points are distinct function templates, so the branch
  // cannot be folded into a single CUB_WRAPPER call.
  if (descending) {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending,
        keys_in_,
        keys_out_,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys,
        keys_in_,
        keys_out_,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  }
}

// Emits an explicit instantiation of radix_sort_keys for one key type.
// The second parameter is unused here. It exists so the macro matches the
// two-argument (C++ type, ScalarType name) callback shape that
// AT_FORALL_SCALAR_TYPES_AND2 invokes it with.
#define AT_INSTANTIATE_CUB_TEMPLATES(scalar_t, ScalarType) \
  template void radix_sort_keys(                          \
      const scalar_t* keys_in,                            \
      scalar_t* keys_out,                                 \
      int64_t n,                                          \
      bool descending,                                    \
      int64_t begin_bit,                                  \
      int64_t end_bit);

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_CUB_TEMPLATES)
// The wider unsigned types are instantiated explicitly, presumably because
// AT_FORALL_SCALAR_TYPES_AND2 does not include them (TODO confirm).
AT_INSTANTIATE_CUB_TEMPLATES(uint16_t, UInt16)
AT_INSTANTIATE_CUB_TEMPLATES(uint32_t, UInt32)
AT_INSTANTIATE_CUB_TEMPLATES(uint64_t, UInt64)

#undef AT_INSTANTIATE_CUB_TEMPLATES

} // namespace at::cuda::cub
59