• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/cuda/cub.cuh>
2 #include <ATen/cuda/CUDAContext.h>
3 #include <gtest/gtest.h>
4 
TEST(NumBits,CubTest)5 TEST(NumBits, CubTest) {
6   using at::cuda::cub::get_num_bits;
7   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000000000UL), 1);
8   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000000001UL), 1);
9   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000000010UL), 2);
10   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000000011UL), 2);
11   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000000100UL), 3);
12   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000000111UL), 3);
13   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000001000UL), 4);
14   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000001111UL), 4);
15   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000010000UL), 5);
16   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000011111UL), 5);
17   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000100000UL), 6);
18   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000000111111UL), 6);
19   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000001000000UL), 7);
20   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000001111111UL), 7);
21   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000010000000UL), 8);
22   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000011111111UL), 8);
23   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000100000000UL), 9);
24   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000000111111111UL), 9);
25   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000001000000000UL), 10);
26   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000001111111111UL), 10);
27   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000010000000000UL), 11);
28   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000011111111111UL), 11);
29   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000100000000000UL), 12);
30   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000000111111111111UL), 12);
31   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000001000000000000UL), 13);
32   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000001111111111111UL), 13);
33   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000010000000000000UL), 14);
34   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000011111111111111UL), 14);
35   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000100000000000000UL), 15);
36   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000000111111111111111UL), 15);
37   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000001000000000000000UL), 16);
38   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000001111111111111111UL), 16);
39   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000010000000000000000UL), 17);
40   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000011111111111111111UL), 17);
41   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000100000000000000000UL), 18);
42   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000000111111111111111111UL), 18);
43   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000001000000000000000000UL), 19);
44   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000001111111111111111111UL), 19);
45   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000010000000000000000000UL), 20);
46   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000011111111111111111111UL), 20);
47   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000100000000000000000000UL), 21);
48   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000000111111111111111111111UL), 21);
49   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000001000000000000000000000UL), 22);
50   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000001111111111111111111111UL), 22);
51   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000010000000000000000000000UL), 23);
52   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000011111111111111111111111UL), 23);
53   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000100000000000000000000000UL), 24);
54   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000000111111111111111111111111UL), 24);
55   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000001000000000000000000000000UL), 25);
56   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000001111111111111111111111111UL), 25);
57   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000010000000000000000000000000UL), 26);
58   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000011111111111111111111111111UL), 26);
59   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000100000000000000000000000000UL), 27);
60   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000000111111111111111111111111111UL), 27);
61   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000001000000000000000000000000000UL), 28);
62   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000001111111111111111111111111111UL), 28);
63   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000010000000000000000000000000000UL), 29);
64   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000011111111111111111111111111111UL), 29);
65   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000100000000000000000000000000000UL), 30);
66   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000000111111111111111111111111111111UL), 30);
67   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000001000000000000000000000000000000UL), 31);
68   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000001111111111111111111111111111111UL), 31);
69   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000010000000000000000000000000000000UL), 32);
70   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000011111111111111111111111111111111UL), 32);
71   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000100000000000000000000000000000000UL), 33);
72   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000000111111111111111111111111111111111UL), 33);
73   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000001000000000000000000000000000000000UL), 34);
74   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000001111111111111111111111111111111111UL), 34);
75   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000010000000000000000000000000000000000UL), 35);
76   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000011111111111111111111111111111111111UL), 35);
77   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000100000000000000000000000000000000000UL), 36);
78   ASSERT_EQ(get_num_bits(0b0000000000000000000000000000111111111111111111111111111111111111UL), 36);
79   ASSERT_EQ(get_num_bits(0b0000000000000000000000000001000000000000000000000000000000000000UL), 37);
80   ASSERT_EQ(get_num_bits(0b0000000000000000000000000001111111111111111111111111111111111111UL), 37);
81   ASSERT_EQ(get_num_bits(0b0000000000000000000000000010000000000000000000000000000000000000UL), 38);
82   ASSERT_EQ(get_num_bits(0b0000000000000000000000000011111111111111111111111111111111111111UL), 38);
83   ASSERT_EQ(get_num_bits(0b0000000000000000000000000100000000000000000000000000000000000000UL), 39);
84   ASSERT_EQ(get_num_bits(0b0000000000000000000000000111111111111111111111111111111111111111UL), 39);
85   ASSERT_EQ(get_num_bits(0b0000000000000000000000001000000000000000000000000000000000000000UL), 40);
86   ASSERT_EQ(get_num_bits(0b0000000000000000000000001111111111111111111111111111111111111111UL), 40);
87   ASSERT_EQ(get_num_bits(0b0000000000000000000000010000000000000000000000000000000000000000UL), 41);
88   ASSERT_EQ(get_num_bits(0b0000000000000000000000011111111111111111111111111111111111111111UL), 41);
89   ASSERT_EQ(get_num_bits(0b0000000000000000000000100000000000000000000000000000000000000000UL), 42);
90   ASSERT_EQ(get_num_bits(0b0000000000000000000000111111111111111111111111111111111111111111UL), 42);
91   ASSERT_EQ(get_num_bits(0b0000000000000000000001000000000000000000000000000000000000000000UL), 43);
92   ASSERT_EQ(get_num_bits(0b0000000000000000000001111111111111111111111111111111111111111111UL), 43);
93   ASSERT_EQ(get_num_bits(0b0000000000000000000010000000000000000000000000000000000000000000UL), 44);
94   ASSERT_EQ(get_num_bits(0b0000000000000000000011111111111111111111111111111111111111111111UL), 44);
95   ASSERT_EQ(get_num_bits(0b0000000000000000000100000000000000000000000000000000000000000000UL), 45);
96   ASSERT_EQ(get_num_bits(0b0000000000000000000111111111111111111111111111111111111111111111UL), 45);
97   ASSERT_EQ(get_num_bits(0b0000000000000000001000000000000000000000000000000000000000000000UL), 46);
98   ASSERT_EQ(get_num_bits(0b0000000000000000001111111111111111111111111111111111111111111111UL), 46);
99   ASSERT_EQ(get_num_bits(0b0000000000000000010000000000000000000000000000000000000000000000UL), 47);
100   ASSERT_EQ(get_num_bits(0b0000000000000000011111111111111111111111111111111111111111111111UL), 47);
101   ASSERT_EQ(get_num_bits(0b0000000000000000100000000000000000000000000000000000000000000000UL), 48);
102   ASSERT_EQ(get_num_bits(0b0000000000000000111111111111111111111111111111111111111111111111UL), 48);
103   ASSERT_EQ(get_num_bits(0b0000000000000001000000000000000000000000000000000000000000000000UL), 49);
104   ASSERT_EQ(get_num_bits(0b0000000000000001111111111111111111111111111111111111111111111111UL), 49);
105   ASSERT_EQ(get_num_bits(0b0000000000000010000000000000000000000000000000000000000000000000UL), 50);
106   ASSERT_EQ(get_num_bits(0b0000000000000011111111111111111111111111111111111111111111111111UL), 50);
107   ASSERT_EQ(get_num_bits(0b0000000000000100000000000000000000000000000000000000000000000000UL), 51);
108   ASSERT_EQ(get_num_bits(0b0000000000000111111111111111111111111111111111111111111111111111UL), 51);
109   ASSERT_EQ(get_num_bits(0b0000000000001000000000000000000000000000000000000000000000000000UL), 52);
110   ASSERT_EQ(get_num_bits(0b0000000000001111111111111111111111111111111111111111111111111111UL), 52);
111   ASSERT_EQ(get_num_bits(0b0000000000010000000000000000000000000000000000000000000000000000UL), 53);
112   ASSERT_EQ(get_num_bits(0b0000000000011111111111111111111111111111111111111111111111111111UL), 53);
113   ASSERT_EQ(get_num_bits(0b0000000000100000000000000000000000000000000000000000000000000000UL), 54);
114   ASSERT_EQ(get_num_bits(0b0000000000111111111111111111111111111111111111111111111111111111UL), 54);
115   ASSERT_EQ(get_num_bits(0b0000000001000000000000000000000000000000000000000000000000000000UL), 55);
116   ASSERT_EQ(get_num_bits(0b0000000001111111111111111111111111111111111111111111111111111111UL), 55);
117   ASSERT_EQ(get_num_bits(0b0000000010000000000000000000000000000000000000000000000000000000UL), 56);
118   ASSERT_EQ(get_num_bits(0b0000000011111111111111111111111111111111111111111111111111111111UL), 56);
119   ASSERT_EQ(get_num_bits(0b0000000100000000000000000000000000000000000000000000000000000000UL), 57);
120   ASSERT_EQ(get_num_bits(0b0000000111111111111111111111111111111111111111111111111111111111UL), 57);
121   ASSERT_EQ(get_num_bits(0b0000001000000000000000000000000000000000000000000000000000000000UL), 58);
122   ASSERT_EQ(get_num_bits(0b0000001111111111111111111111111111111111111111111111111111111111UL), 58);
123   ASSERT_EQ(get_num_bits(0b0000010000000000000000000000000000000000000000000000000000000000UL), 59);
124   ASSERT_EQ(get_num_bits(0b0000011111111111111111111111111111111111111111111111111111111111UL), 59);
125   ASSERT_EQ(get_num_bits(0b0000100000000000000000000000000000000000000000000000000000000000UL), 60);
126   ASSERT_EQ(get_num_bits(0b0000111111111111111111111111111111111111111111111111111111111111UL), 60);
127   ASSERT_EQ(get_num_bits(0b0001000000000000000000000000000000000000000000000000000000000000UL), 61);
128   ASSERT_EQ(get_num_bits(0b0001111111111111111111111111111111111111111111111111111111111111UL), 61);
129   ASSERT_EQ(get_num_bits(0b0010000000000000000000000000000000000000000000000000000000000000UL), 62);
130   ASSERT_EQ(get_num_bits(0b0011111111111111111111111111111111111111111111111111111111111111UL), 62);
131   ASSERT_EQ(get_num_bits(0b0100000000000000000000000000000000000000000000000000000000000000UL), 63);
132   ASSERT_EQ(get_num_bits(0b0111111111111111111111111111111111111111111111111111111111111111UL), 63);
133   ASSERT_EQ(get_num_bits(0b1000000000000000000000000000000000000000000000000000000000000000UL), 64);
134   ASSERT_EQ(get_num_bits(0b1111111111111111111111111111111111111111111111111111111111111111UL), 64);
135 }
136 
137 __managed__ int input[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
138 
TEST(InclusiveScanSplit,CubTest)139 TEST(InclusiveScanSplit, CubTest) {
140   if (!at::cuda::is_available()) return;
141   at::globalContext().lazyInitCUDA();  // This is required to use PyTorch's caching allocator.
142 
143   int *output1;
144   cudaMallocManaged(&output1, sizeof(int) * 10);
145 
146   cudaDeviceSynchronize();
147   at::cuda::cub::inclusive_scan<int *, int *, ::at_cuda_detail::cub::Sum, /*max_cub_size=*/2>(
148     input, output1, ::at_cuda_detail::cub::Sum(), 10);
149   cudaDeviceSynchronize();
150 
151   ASSERT_EQ(output1[0], 1);
152   ASSERT_EQ(output1[1], 3);
153   ASSERT_EQ(output1[2], 6);
154   ASSERT_EQ(output1[3], 10);
155   ASSERT_EQ(output1[4], 15);
156   ASSERT_EQ(output1[5], 21);
157   ASSERT_EQ(output1[6], 28);
158   ASSERT_EQ(output1[7], 36);
159   ASSERT_EQ(output1[8], 45);
160   ASSERT_EQ(output1[9], 55);
161 }
162 
TEST(ExclusiveScanSplit,CubTest)163 TEST(ExclusiveScanSplit, CubTest) {
164   if (!at::cuda::is_available()) return;
165   at::globalContext().lazyInitCUDA();  // This is required to use PyTorch's caching allocator.
166 
167   int *output2;
168   cudaMallocManaged(&output2, sizeof(int) * 10);
169 
170   cudaDeviceSynchronize();
171   at::cuda::cub::exclusive_scan<int *, int *, ::at_cuda_detail::cub::Sum, int, /*max_cub_size=*/2>(
172     input, output2, ::at_cuda_detail::cub::Sum(), 0, 10);
173   cudaDeviceSynchronize();
174 
175   ASSERT_EQ(output2[0], 0);
176   ASSERT_EQ(output2[1], 1);
177   ASSERT_EQ(output2[2], 3);
178   ASSERT_EQ(output2[3], 6);
179   ASSERT_EQ(output2[4], 10);
180   ASSERT_EQ(output2[5], 15);
181   ASSERT_EQ(output2[6], 21);
182   ASSERT_EQ(output2[7], 28);
183   ASSERT_EQ(output2[8], 36);
184   ASSERT_EQ(output2[9], 45);
185 }
186