• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Fuchsia Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_
6 #define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_
7 
8 //
9 // Requires several defines
10 //
11 #ifndef RS_PREFIX_LIMITS
12 #error "Error: \"prefix_limits.h\" not loaded"
13 #endif
14 
15 #ifndef RS_PREFIX_ARGS
16 #error "Error: RS_PREFIX_ARGS undefined"
17 #endif
18 
19 #ifndef RS_PREFIX_LOAD
20 #error "Error: RS_PREFIX_LOAD undefined"
21 #endif
22 
23 #ifndef RS_PREFIX_STORE
24 #error "Error: RS_PREFIX_STORE undefined"
25 #endif
26 
27 //
28 // Optional switches:
29 //
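//   * Disable holding the per-subgroup exclusively scanned histogram values
//     in registers between the first downsweep and the final upsweep.  When
//     defined, the partial results are written with RS_PREFIX_STORE() and
//     read back with RS_PREFIX_LOAD() instead.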
31 //
32 //     #define RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
33 //
34 
//
// Compute the exclusive prefix sum of a uint32_t[RS_RADIX_SIZE] histogram.
//
// Histogram values are read with RS_PREFIX_LOAD(base) and results are
// written by assigning to RS_PREFIX_STORE(base).  Only a base offset is
// passed to the macros here; presumably they add the per-invocation index
// themselves -- confirm against the including shader's definitions.
//
// Two strategies are selected on compile-time constants:
//
//   * RS_WORKGROUP_SUBGROUPS == 1: the single subgroup walks the histogram
//     RS_SUBGROUP_SIZE entries at a time, carrying the running total in a
//     subgroup-uniform register.
//
//   * Otherwise: a multi-level scan through the RS_PREFIX_SWEEP0/1/2()
//     arrays.  This path executes barrier(), so every invocation in the
//     workgroup must reach this function.
//
void
rs_prefix(RS_PREFIX_ARGS)
{
  if (RS_WORKGROUP_SUBGROUPS == 1)
  {
    //
    // Workgroup is a single subgroup so no shared memory is required.
    //

    //
    // Exclusive scan-add the first RS_SUBGROUP_SIZE histogram entries.
    //
    const uint32_t               h0     = RS_PREFIX_LOAD(0);
    const uint32_t               h0_inc = subgroupInclusiveAdd(h0);
    RS_SUBGROUP_UNIFORM uint32_t h_last = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);

    RS_PREFIX_STORE(0) = h0_inc - h0;  // exclusive

    //
    // Each iteration is dependent on the previous so no unrolling.  The
    // compiler is free to hoist the loads upward though.
    //
    // `h_last` carries the inclusive total of everything scanned so far.
    //
    for (RS_SUBGROUP_UNIFORM uint32_t ii = RS_SUBGROUP_SIZE;  //
         ii < RS_RADIX_SIZE;
         ii += RS_SUBGROUP_SIZE)
      {
        const uint32_t h     = RS_PREFIX_LOAD(ii);
        const uint32_t h_inc = subgroupInclusiveAdd(h) + h_last;
        h_last               = subgroupBroadcast(h_inc, RS_SUBGROUP_SIZE - 1);

        RS_PREFIX_STORE(ii) = h_inc - h;  // exclusive
      }
  }
  else
  {
    //
    // Workgroup is multiple subgroups and uses shared memory to store
    // the scan's intermediate results.
    //
    // Assumes a power-of-two subgroup, workgroup and radix size.
    //
    // Downsweep: Repeatedly scan reductions until they fit in a single
    //            subgroup.
    //
    // Upsweep:   Then uniformly apply reductions to each subgroup.
    //
    //
    //   Subgroup Size |  4 |  8 | 16 | 32 | 64 | 128 |
    //   --------------+----+----+----+----+----+-----+
    //   Sweep 0       | 64 | 32 | 16 |  8 |  4 |   2 | sweep_0[]
    //   Sweep 1       | 16 |  4 |  - |  - |  - |   - | sweep_1[]
    //   Sweep 2       |  4 |  - |  - |  - |  - |   - | sweep_2[]
    //   --------------+----+----+----+----+----+-----+
    //   Total dwords  | 84 | 36 | 16 |  8 |  4 |   2 |
    //   --------------+----+----+----+----+----+-----+
    //
#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
    // Per-invocation subgroup-exclusive partial results, kept until the
    // final upsweep adds the cross-subgroup offsets.
    uint32_t h_exc[RS_H_COMPONENTS];
#endif

    //
    // Downsweep 0
    //
    // Scan each subgroup's slice of the histogram and publish the slice
    // total (the last lane's inclusive value) to sweep_0[].
    //
    [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
    {
      const uint32_t h = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);

      const uint32_t h_inc = subgroupInclusiveAdd(h);

      const uint32_t smem_idx = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

      RS_PREFIX_SWEEP0(smem_idx) = subgroupBroadcast(h_inc, RS_SUBGROUP_SIZE - 1);

      //
      // Stash the subgroup-local exclusive value for the final upsweep.
      //
#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
      h_exc[ii] = h_inc - h;
#else
      // NOTE(review): the upsweep reads this value back with
      // RS_PREFIX_LOAD() at the same offset, which assumes LOAD and STORE
      // address the same storage -- confirm against the including shader.
      RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_inc - h;
#endif
    }

    barrier();

    //
    // Skip generalizing these sweeps for all possible subgroups -- just
    // write them directly.
    //
    if (RS_SUBGROUP_SIZE == 128)
    {
      // There are only two elements in SWEEP0 per subgroup. The scan is
      // trivial so we fold it into the upsweep.
    }
    else if (RS_SUBGROUP_SIZE >= 16)
    {
      //////////////////////////////////////////////////////////////////////
      //
      // Scan 0
      //
      // Only the first RS_SWEEP_0_SIZE invocations own a sweep_0[] entry.
      // The bound check alone is sufficient: if RS_SWEEP_0_SIZE equals
      // RS_WORKGROUP_SIZE it holds for every invocation, so the scan is
      // never skipped.
      //
      if (gl_LocalInvocationID.x < RS_SWEEP_0_SIZE)
        {
          const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x);
          const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

          RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red;
        }

      barrier();
    }
    else if (RS_SUBGROUP_SIZE == 8)
    {
      if (RS_SWEEP_0_SIZE < RS_WORKGROUP_SIZE)
      {
        //////////////////////////////////////////////////////////////////////
        //
        // Scan 0 and Downsweep 1
        //
        // The workgroup is wider than sweep_0[] so one guarded pass
        // covers it.
        //
        if (gl_LocalInvocationID.x < RS_SWEEP_0_SIZE)  // 32 invocations
          {
            const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x);
            const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

            RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red;
            RS_PREFIX_SWEEP1(gl_SubgroupID) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);
          }
      }
      else
      {
        //////////////////////////////////////////////////////////////////////
        //
        // Scan 0 and Downsweep 1
        //
        // sweep_0[] is at least as large as the workgroup so the whole
        // workgroup makes RS_S0_PASSES passes over it.
        //
        [[unroll]] for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++)  // 32 invocations
        {
          const uint32_t idx0 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
          const uint32_t idx1 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

          const uint32_t h0_red = RS_PREFIX_SWEEP0(idx0);
          const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

          RS_PREFIX_SWEEP0(idx0) = h0_inc - h0_red;
          RS_PREFIX_SWEEP1(idx1) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);
        }
      }

      barrier();

      //
      // Scan 1
      //
      if (gl_LocalInvocationID.x < RS_SWEEP_1_SIZE)  // 4 invocations
        {
          const uint32_t h1_red = RS_PREFIX_SWEEP1(gl_LocalInvocationID.x);
          const uint32_t h1_inc = subgroupInclusiveAdd(h1_red);

          RS_PREFIX_SWEEP1(gl_LocalInvocationID.x) = h1_inc - h1_red;
        }

      barrier();
    }
    else if (RS_SUBGROUP_SIZE == 4)
    {
      //////////////////////////////////////////////////////////////////////
      //
      // Scan 0 and Downsweep 1
      //
      if (RS_SWEEP_0_SIZE < RS_WORKGROUP_SIZE)
      {
        if (gl_LocalInvocationID.x < RS_SWEEP_0_SIZE)  // 64 invocations
          {
            const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x);
            const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

            RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red;
            RS_PREFIX_SWEEP1(gl_SubgroupID)          = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);
          }
      }
      else
      {
        [[unroll]] for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++)  // 64 invocations
        {
          const uint32_t idx0 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
          const uint32_t idx1 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

          const uint32_t h0_red = RS_PREFIX_SWEEP0(idx0);
          const uint32_t h0_inc = subgroupInclusiveAdd(h0_red);

          RS_PREFIX_SWEEP0(idx0) = h0_inc - h0_red;
          RS_PREFIX_SWEEP1(idx1) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1);
        }
      }

      barrier();

      //
      // Scan 1 and Downsweep 2
      //
      if (RS_SWEEP_1_SIZE < RS_WORKGROUP_SIZE)
      {
        if (gl_LocalInvocationID.x < RS_SWEEP_1_SIZE)  // 16 invocations
          {
            const uint32_t h1_red = RS_PREFIX_SWEEP1(gl_LocalInvocationID.x);
            const uint32_t h1_inc = subgroupInclusiveAdd(h1_red);

            RS_PREFIX_SWEEP1(gl_LocalInvocationID.x) = h1_inc - h1_red;
            RS_PREFIX_SWEEP2(gl_SubgroupID)          = subgroupBroadcast(h1_inc, RS_SUBGROUP_SIZE - 1);
          }
      }
      else
      {
        [[unroll]] for (uint32_t ii = 0; ii < RS_S1_PASSES; ii++)  // 16 invocations
        {
          const uint32_t idx1 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
          const uint32_t idx2 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

          const uint32_t h1_red = RS_PREFIX_SWEEP1(idx1);
          const uint32_t h1_inc = subgroupInclusiveAdd(h1_red);

          RS_PREFIX_SWEEP1(idx1) = h1_inc - h1_red;
          RS_PREFIX_SWEEP2(idx2) = subgroupBroadcast(h1_inc, RS_SUBGROUP_SIZE - 1);
        }
      }

      barrier();

      //
      // Scan 2
      //
      // 4 invocations
      //
      if (gl_LocalInvocationID.x < RS_SWEEP_2_SIZE)
        {
          const uint32_t h2_red = RS_PREFIX_SWEEP2(gl_LocalInvocationID.x);
          const uint32_t h2_inc = subgroupInclusiveAdd(h2_red);

          RS_PREFIX_SWEEP2(gl_LocalInvocationID.x) = h2_inc - h2_red;
        }

      barrier();
    }

    //////////////////////////////////////////////////////////////////////
    //
    // Final upsweep 0
    //
    // Add the scanned offsets from each sweep level to the subgroup-local
    // exclusive values and write the final result.
    //
    if (RS_SUBGROUP_SIZE == 128)
    {
      // There must be more than one subgroup per workgroup, but the maximum
      // workgroup size is 256 so there must be exactly two subgroups per
      // workgroup and RS_H_COMPONENTS must be 1.
      //
      // sweep_0[0] still holds subgroup 0's unscanned total, which is
      // exactly the offset subgroup 1 needs; subgroup 0 needs none.
#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
      RS_PREFIX_STORE(0) = h_exc[0] + (gl_SubgroupID > 0 ? RS_PREFIX_SWEEP0(0) : 0);
#else
      const uint32_t h_exc = RS_PREFIX_LOAD(0);

      RS_PREFIX_STORE(0) = h_exc + (gl_SubgroupID > 0 ? RS_PREFIX_SWEEP0(0) : 0);
#endif
    }
    else if (RS_SUBGROUP_SIZE >= 16)
    {
      [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
      {
        const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;

#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_exc[ii] + RS_PREFIX_SWEEP0(idx0);
#else
        const uint32_t h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);

        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_exc + RS_PREFIX_SWEEP0(idx0);
#endif
      }
    }
    else if (RS_SUBGROUP_SIZE == 8)
    {
      [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
      {
        // Each sweep level groups RS_SUBGROUP_SIZE entries of the level
        // below, so the parent index is the child index divided by it.
        const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
        const uint32_t idx1 = idx0 / RS_SUBGROUP_SIZE;

#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) =
          h_exc[ii] + RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1);
#else
        const uint32_t h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);

        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) =
          h_exc + RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1);
#endif
      }
    }
    else if (RS_SUBGROUP_SIZE == 4)
    {
      [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
      {
        // Parent index at each level is the child index divided by
        // RS_SUBGROUP_SIZE.
        const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
        const uint32_t idx1 = idx0 / RS_SUBGROUP_SIZE;
        const uint32_t idx2 = idx1 / RS_SUBGROUP_SIZE;

#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) =
          h_exc[ii] + (RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1) + RS_PREFIX_SWEEP2(idx2));
#else
        const uint32_t h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);

        RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) =
          h_exc + (RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1) + RS_PREFIX_SWEEP2(idx2));
#endif
      }
    }
  }
}
351 
352 //
353 //
354 //
355 
356 #endif  // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_
357