#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <gtest/gtest.h>

#include <memory>
4
// Verifies at::cuda::cub::get_num_bits at every bit width from 1 to 64.
// For each width w, both the smallest w-bit value (a lone bit at position
// w - 1) and the largest w-bit value (w consecutive one-bits) must report w.
// Zero is checked separately and is expected to report 1, the same as 1.
TEST(NumBits, CubTest) {
  using at::cuda::cub::get_num_bits;

  ASSERT_EQ(get_num_bits(0ULL), 1);

  for (int width = 1; width <= 64; ++width) {
    // Width 1's smallest value is 1, which the all-ones check below already
    // covers; zero took its place above, so skip it here.
    if (width > 1) {
      const unsigned long long lone_high_bit = 1ULL << (width - 1);
      ASSERT_EQ(get_num_bits(lone_high_bit), width) << "lone high bit, width " << width;
    }
    // Shifting all-ones right keeps exactly `width` low bits set; the shift
    // amount stays within [0, 63], so this is well-defined for width == 64.
    const unsigned long long all_ones = ~0ULL >> (64 - width);
    ASSERT_EQ(get_num_bits(all_ones), width) << "all ones, width " << width;
  }
}
136
// Input shared by the scan tests: the integers 1 through 10, placed in
// managed (unified) memory so the same array can be read by the scan
// kernels on the device and by the assertions on the host.
__managed__ int input[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
138
// Exercises at::cuda::cub::inclusive_scan with max_cub_size=2, which (going by
// the parameter and test names) presumably forces the scan over `input` to be
// split into several small CUB invocations. The result must equal the plain
// running sum of `input`.
TEST(InclusiveScanSplit, CubTest) {
  if (!at::cuda::is_available()) return;
  at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator.

  // Derive the element count from `input` so the two can never disagree.
  constexpr int num_items = static_cast<int>(sizeof(input) / sizeof(input[0]));

  int *output1 = nullptr;
  ASSERT_EQ(cudaMallocManaged(&output1, sizeof(int) * num_items), cudaSuccess);
  // Release the managed buffer on every exit path, including the early
  // return taken when any ASSERT_* below fails. The cudaFree status is
  // intentionally ignored: this is cleanup only.
  auto free_managed = [](int *ptr) { cudaFree(ptr); };
  std::unique_ptr<int, decltype(free_managed)> output1_owner(output1, free_managed);

  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
  at::cuda::cub::inclusive_scan<int *, int *, ::at_cuda_detail::cub::Sum, /*max_cub_size=*/2>(
      input, output1, ::at_cuda_detail::cub::Sum(), num_items);
  // Wait for the scan before the host touches managed memory; this also
  // surfaces any asynchronous error raised by the scan kernels.
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  // Host reference: output1[i] == input[0] + ... + input[i]
  // (for input = 1..10 this is 1, 3, 6, 10, 15, 21, 28, 36, 45, 55).
  int expected = 0;
  for (int i = 0; i < num_items; ++i) {
    expected += input[i];
    ASSERT_EQ(output1[i], expected) << "mismatch at index " << i;
  }
}
162
// Exercises at::cuda::cub::exclusive_scan with init value 0 and
// max_cub_size=2, which (going by the parameter and test names) presumably
// forces the scan over `input` to be split into several small CUB
// invocations. The result must equal the exclusive prefix sum of `input`.
TEST(ExclusiveScanSplit, CubTest) {
  if (!at::cuda::is_available()) return;
  at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator.

  // Derive the element count from `input` so the two can never disagree.
  constexpr int num_items = static_cast<int>(sizeof(input) / sizeof(input[0]));

  int *output2 = nullptr;
  ASSERT_EQ(cudaMallocManaged(&output2, sizeof(int) * num_items), cudaSuccess);
  // Release the managed buffer on every exit path, including the early
  // return taken when any ASSERT_* below fails. The cudaFree status is
  // intentionally ignored: this is cleanup only.
  auto free_managed = [](int *ptr) { cudaFree(ptr); };
  std::unique_ptr<int, decltype(free_managed)> output2_owner(output2, free_managed);

  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
  at::cuda::cub::exclusive_scan<int *, int *, ::at_cuda_detail::cub::Sum, int, /*max_cub_size=*/2>(
      input, output2, ::at_cuda_detail::cub::Sum(), 0, num_items);
  // Wait for the scan before the host touches managed memory; this also
  // surfaces any asynchronous error raised by the scan kernels.
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  // Host reference: output2[0] == 0 (the init value) and
  // output2[i] == input[0] + ... + input[i - 1]
  // (for input = 1..10 this is 0, 1, 3, 6, 10, 15, 21, 28, 36, 45).
  int expected = 0;
  for (int i = 0; i < num_items; ++i) {
    ASSERT_EQ(output2[i], expected) << "mismatch at index " << i;
    expected += input[i];
  }
}
186