1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/CUDAConfig.h>
3 #include <ATen/cuda/cub.cuh>
4
5 namespace at::cuda::cub {
6
// Sorts the `n` keys read from `keys_in` into `keys_out` using CUB's
// DeviceRadixSort, enqueued on c10::cuda::getCurrentCUDAStream().
//
//   keys_in / keys_out : input and output key buffers. These are presumably
//                        device pointers, since the sort runs through CUB on
//                        the current CUDA stream (TODO confirm at call sites).
//   n                  : number of keys. Values above INT_MAX are rejected
//                        with TORCH_CHECK.
//   descending         : true  -> DeviceRadixSort::SortKeysDescending,
//                        false -> DeviceRadixSort::SortKeys.
//   begin_bit, end_bit : bit range of each key used for the sort, forwarded
//                        to CUB unchanged.
//
// Temporary-storage sizing and allocation are delegated to CUB_WRAPPER.
template <typename key_t>
void radix_sort_keys(
    const key_t* keys_in,
    key_t* keys_out,
    int64_t n,
    bool descending,
    int64_t begin_bit,
    int64_t end_bit) {
  // As the error message states, the CUB sort path used here does not accept
  // more than INT_MAX elements.
  const int64_t max_items = std::numeric_limits<int>::max();
  TORCH_CHECK(
      max_items >= n,
      "cub sort does not support sorting more than INT_MAX elements");

  // View the ATen key type through its detail::cuda_type counterpart before
  // handing the buffers to CUB.
  using cub_key_t = typename detail::cuda_type<key_t>::type;
  const auto* src = reinterpret_cast<const cub_key_t*>(keys_in);
  auto* dst = reinterpret_cast<cub_key_t*>(keys_out);

  if (!descending) {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys,
        src,
        dst,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
    return;
  }

  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending,
      src,
      dst,
      n,
      begin_bit,
      end_bit,
      c10::cuda::getCurrentCUDAStream());
}
43
// Emits one explicit instantiation of radix_sort_keys for `key_type`.
// The second argument (the ScalarType tag) exists only so that this macro
// matches the two-argument callback shape used by the AT_FORALL_* iteration
// macros. It is not used in the expansion.
#define AT_INSTANTIATE_RADIX_SORT_KEYS(key_type, unused_scalar_type) \
  template void radix_sort_keys<key_type>(                           \
      const key_type*, key_type*, int64_t, bool, int64_t, int64_t);

// Instantiate for the scalar types enumerated by AT_FORALL_SCALAR_TYPES_AND2,
// extended with Bool and Half...
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_RADIX_SORT_KEYS)
// ...and for the unsigned integer types, which are listed one by one.
AT_INSTANTIATE_RADIX_SORT_KEYS(uint16_t, UInt16)
AT_INSTANTIATE_RADIX_SORT_KEYS(uint32_t, UInt32)
AT_INSTANTIATE_RADIX_SORT_KEYS(uint64_t, UInt64)

// The helper macro is only needed above; keep it from leaking further.
#undef AT_INSTANTIATE_RADIX_SORT_KEYS
57
58 } // namespace at::cuda::cub
59