1 // Copyright 2021 The Fuchsia Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_
6 #define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_
7
//
// Requires several defines.
//
// The including shader must provide all of the following before this
// header is processed:
//
//   RS_PREFIX_LIMITS   - per the error text below, expected to be defined
//                        by "prefix_limits.h"
//   RS_PREFIX_ARGS     - expands to the parameter list of rs_prefix()
//   RS_PREFIX_LOAD(i)  - expression rs_prefix() reads a histogram value from
//   RS_PREFIX_STORE(i) - assignable expression rs_prefix() writes a prefix to
//
#ifndef RS_PREFIX_LIMITS
#error "Error: \"prefix_limits.h\" not loaded"
#endif

#ifndef RS_PREFIX_ARGS
#error "Error: RS_PREFIX_ARGS undefined"
#endif

#ifndef RS_PREFIX_LOAD
#error "Error: RS_PREFIX_LOAD undefined"
#endif

#ifndef RS_PREFIX_STORE
#error "Error: RS_PREFIX_STORE undefined"
#endif
26
27 //
28 // Optional switches:
29 //
30 // * Disable holding original inclusively scanned histogram values in registers.
31 //
32 // #define RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
33 //
34
//
// Compute exclusive prefix of uint32_t[256]
//
// Histogram values are read through RS_PREFIX_LOAD() and the exclusive
// prefix sums are written through RS_PREFIX_STORE().  The parameter list is
// supplied by the includer via RS_PREFIX_ARGS.
//
// Two strategies are selected by RS_WORKGROUP_SUBGROUPS:
//
//   * == 1 : the lone subgroup walks the histogram RS_SUBGROUP_SIZE entries
//            at a time and carries the running total in `h_last`.
//
//   * >  1 : every subgroup scans its own slice, the per-subgroup totals are
//            scanned through RS_PREFIX_SWEEP0/1/2 with a barrier() between
//            levels, and the scanned totals are finally added back to each
//            invocation's value.
//
void
rs_prefix(RS_PREFIX_ARGS)
{
  if (RS_WORKGROUP_SUBGROUPS == 1)
  {
    //
    // Workgroup is a single subgroup so no shared memory is required.
    //

    //
    // Exclusive scan-add the histogram
    //
    const uint32_t h0     = RS_PREFIX_LOAD(0);
    const uint32_t h0_inc = subgroupInclusiveAdd(h0);

    // Running total: the inclusive sum held by the last invocation.
    RS_SUBGROUP_UNIFORM uint32_t h_last = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);

    RS_PREFIX_STORE(0) = h0_inc - h0;  // exclusive

    //
    // Each iteration is dependent on the previous so no unrolling. The
    // compiler is free to hoist the loads upward though.
    //
    for (RS_SUBGROUP_UNIFORM uint32_t ii = RS_SUBGROUP_SIZE;  //
         ii < RS_RADIX_SIZE;
         ii += RS_SUBGROUP_SIZE)
    {
      const uint32_t h     = RS_PREFIX_LOAD(ii);
      const uint32_t h_inc = subgroupInclusiveAdd(h) + h_last;  // add carried total

      h_last = subgroupBroadcast(h_inc, RS_SUBGROUP_SIZE - 1);

      RS_PREFIX_STORE(ii) = h_inc - h;  // exclusive
    }
  }
  else
  {
    //
    // Workgroup is multiple subgroups and uses shared memory to store
    // the scan's intermediate results.
    //
    // Assumes a power-of-two subgroup, workgroup and radix size.
    //
    // Downsweep: Repeatedly scan reductions until they fit in a single
    // subgroup.
    //
    // Upsweep: Then uniformly apply reductions to each subgroup.
    //
    //
    //   Subgroup Size |  4 |  8 | 16 | 32 | 64 | 128 |
    //   --------------+----+----+----+----+----+-----+
    //   Sweep 0       | 64 | 32 | 16 |  8 |  4 |   2 | sweep_0[]
    //   Sweep 1       | 16 |  4 |  - |  - |  - |   - | sweep_1[]
    //   Sweep 2       |  4 |  - |  - |  - |  - |   - | sweep_2[]
    //   --------------+----+----+----+----+----+-----+
    //   Total dwords  | 84 | 36 | 16 |  8 |  4 |   2 |
    //   --------------+----+----+----+----+----+-----+
    //
#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
    // Per-invocation exclusive values kept until the final upsweep.
    uint32_t h_exc[RS_H_COMPONENTS];
#endif

    //
    // Downsweep 0
    //
    // Scan each component within the subgroup and publish the subgroup's
    // total to SWEEP0.
    //
    [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
    {
      const uint32_t h = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);

      const uint32_t h_inc = subgroupInclusiveAdd(h);

      const uint32_t smem_idx = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

      RS_PREFIX_SWEEP0(smem_idx) = subgroupBroadcast(h_inc, RS_SUBGROUP_SIZE - 1);

      //
      // Keep the subgroup-local exclusive value either in a register or in
      // the destination, depending on the optional switch.
      //
#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
      h_exc[ii] = h_inc - h;
#else
      RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_inc - h;
#endif
    }

    barrier();

    //
    // Skip generalizing these sweeps for all possible subgroups -- just
    // write them directly.
    //
    if (RS_SUBGROUP_SIZE == 128)
    {
      // There are only two elements in SWEEP0 per subgroup.  The scan is
      // trivial so we fold it into the upsweep.
    }
    else if (RS_SUBGROUP_SIZE >= 16)
    {
      //////////////////////////////////////////////////////////////////////
      //
      // Scan 0
      //
      // The bounds test is only required when the workgroup has inactive
      // components, i.e. when SWEEP0 is smaller than the workgroup.  The
      // previous form (`!=` combined with `&&`) skipped the scan altogether
      // when the two sizes were equal.
      //
      if (RS_SWEEP_0_SIZE == RS_WORKGROUP_SIZE ||  // every invocation is active
          gl_LocalInvocationID.x < RS_SWEEP_0_SIZE)
      {
        const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x);
        const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

        RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red;
      }

      barrier();
    }
    else if (RS_SUBGROUP_SIZE == 8)
    {
      if (RS_SWEEP_0_SIZE < RS_WORKGROUP_SIZE)
      {
        //////////////////////////////////////////////////////////////////////
        //
        // Scan 0 and Downsweep 1
        //
        // SWEEP0 fits in one pass over the workgroup.
        //
        if (gl_LocalInvocationID.x < RS_SWEEP_0_SIZE)  // 32 invocations
        {
          const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x);
          const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

          RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red;
          RS_PREFIX_SWEEP1(gl_SubgroupID) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);
        }
      }
      else
      {
        //////////////////////////////////////////////////////////////////////
        //
        // Scan 0 and Downsweep 1
        //
        // SWEEP0 is at least as large as the workgroup: cover it in
        // RS_S0_PASSES workgroup-wide passes.
        //
        [[unroll]] for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++)  // 32 invocations
        {
          const uint32_t idx0 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
          const uint32_t idx1 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

          const uint32_t h0_red = RS_PREFIX_SWEEP0(idx0);
          const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

          RS_PREFIX_SWEEP0(idx0) = h0_inc - h0_red;
          RS_PREFIX_SWEEP1(idx1) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);
        }
      }

      barrier();

      //
      // Scan 1
      //
      if (gl_LocalInvocationID.x < RS_SWEEP_1_SIZE)  // 4 invocations
      {
        const uint32_t h1_red = RS_PREFIX_SWEEP1(gl_LocalInvocationID.x);
        const uint32_t h1_inc = subgroupInclusiveAdd(h1_red);

        RS_PREFIX_SWEEP1(gl_LocalInvocationID.x) = h1_inc - h1_red;
      }

      barrier();
    }
    else if (RS_SUBGROUP_SIZE == 4)
    {
      //////////////////////////////////////////////////////////////////////
      //
      // Scan 0 and Downsweep 1
      //
      if (RS_SWEEP_0_SIZE < RS_WORKGROUP_SIZE)
      {
        if (gl_LocalInvocationID.x < RS_SWEEP_0_SIZE)  // 64 invocations
        {
          const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x);
          const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

          RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red;
          RS_PREFIX_SWEEP1(gl_SubgroupID) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);
        }
      }
      else
      {
        [[unroll]] for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++)  // 64 invocations
        {
          const uint32_t idx0 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
          const uint32_t idx1 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

          const uint32_t h0_red = RS_PREFIX_SWEEP0(idx0);
          const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

          RS_PREFIX_SWEEP0(idx0) = h0_inc - h0_red;
          RS_PREFIX_SWEEP1(idx1) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);
        }
      }

      barrier();

      //
      // Scan 1 and Downsweep 2
      //
      if (RS_SWEEP_1_SIZE < RS_WORKGROUP_SIZE)
      {
        if (gl_LocalInvocationID.x < RS_SWEEP_1_SIZE)  // 16 invocations
        {
          const uint32_t h1_red = RS_PREFIX_SWEEP1(gl_LocalInvocationID.x);
          const uint32_t h1_inc = subgroupInclusiveAdd(h1_red);

          RS_PREFIX_SWEEP1(gl_LocalInvocationID.x) = h1_inc - h1_red;
          RS_PREFIX_SWEEP2(gl_SubgroupID) = subgroupBroadcast(h1_inc, RS_SUBGROUP_SIZE - 1);
        }
      }
      else
      {
        [[unroll]] for (uint32_t ii = 0; ii < RS_S1_PASSES; ii++)  // 16 invocations
        {
          const uint32_t idx1 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
          const uint32_t idx2 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

          const uint32_t h1_red = RS_PREFIX_SWEEP1(idx1);
          const uint32_t h1_inc = subgroupInclusiveAdd(h1_red);

          RS_PREFIX_SWEEP1(idx1) = h1_inc - h1_red;
          RS_PREFIX_SWEEP2(idx2) = subgroupBroadcast(h1_inc, RS_SUBGROUP_SIZE - 1);
        }
      }

      barrier();

      //
      // Scan 2
      //
      // 4 invocations
      //
      if (gl_LocalInvocationID.x < RS_SWEEP_2_SIZE)
      {
        const uint32_t h2_red = RS_PREFIX_SWEEP2(gl_LocalInvocationID.x);
        const uint32_t h2_inc = subgroupInclusiveAdd(h2_red);

        RS_PREFIX_SWEEP2(gl_LocalInvocationID.x) = h2_inc - h2_red;
      }

      barrier();
    }

    //////////////////////////////////////////////////////////////////////
    //
    // Final upsweep 0
    //
    // Add the scanned totals of every sweep level to the subgroup-local
    // exclusive value.  When RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS is
    // defined that value is fetched with RS_PREFIX_LOAD() instead of h_exc[].
    // NOTE(review): this presumes RS_PREFIX_LOAD() observes what Downsweep 0
    // wrote through RS_PREFIX_STORE() -- confirm against the includer.
    //
    if (RS_SUBGROUP_SIZE == 128)
    {
      // There must be more than one subgroup per workgroup, but the maximum
      // workgroup size is 256 so there must be exactly two subgroups per
      // workgroup and RS_H_COMPONENTS must be 1.
      //
      // SWEEP0 was never scanned for this size: subgroup 0 adds nothing and
      // subgroup 1 adds subgroup 0's total.
#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
      RS_PREFIX_STORE(0) = h_exc[0] + (gl_SubgroupID > 0 ? RS_PREFIX_SWEEP0(0) : 0);
#else
      const uint32_t h_exc = RS_PREFIX_LOAD(0);

      RS_PREFIX_STORE(0) = h_exc + (gl_SubgroupID > 0 ? RS_PREFIX_SWEEP0(0) : 0);
#endif
    }
    else if (RS_SUBGROUP_SIZE >= 16)
    {
      [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
      {
        const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

        // clang format issue
#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_exc[ii] + RS_PREFIX_SWEEP0(idx0);
#else
        const uint32_t h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);

        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_exc + RS_PREFIX_SWEEP0(idx0);
#endif
      }
    }
    else if (RS_SUBGROUP_SIZE == 8)
    {
      [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
      {
        // Each sweep level is indexed by the previous level's index divided
        // by the subgroup size.
        const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
        const uint32_t idx1 = idx0 / RS_SUBGROUP_SIZE;

#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) =
          h_exc[ii] + RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1);
#else
        const uint32_t h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);

        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) =
          h_exc + RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1);
#endif
      }
    }
    else if (RS_SUBGROUP_SIZE == 4)
    {
      [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
      {
        // Each sweep level is indexed by the previous level's index divided
        // by the subgroup size.
        const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
        const uint32_t idx1 = idx0 / RS_SUBGROUP_SIZE;
        const uint32_t idx2 = idx1 / RS_SUBGROUP_SIZE;

#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) =
          h_exc[ii] + (RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1) + RS_PREFIX_SWEEP2(idx2));
#else
        const uint32_t h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);

        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) =
          h_exc + (RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1) + RS_PREFIX_SWEEP2(idx2));
#endif
      }
    }
  }
}
351
352 //
353 //
354 //
355
356 #endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_
357