1 /*
2 * Copyright © 2023 Collabora, Ltd.
3 * Copyright © 2017 Intel Corporation
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a
6 * copy of this software and associated documentation files (the "Software"),
7 * to deal in the Software without restriction, including without limitation
8 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 * and/or sell copies of the Software, and to permit persons to whom the
10 * Software is furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 */
24
25 #include "util/u_math.h"
26 #include "nir.h"
27 #include "nir_builder.h"
28
29 /**
 * \file nir_lower_subgroups.c
31 */
32
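/* Upper bound on the subgroup size: use the compile-time size if the driver
 * supplies one, otherwise fall back to the number of invocations the ballot
 * type can represent (ballot_components components of ballot_bit_size bits,
 * one bit per invocation).
 */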
33 static unsigned
get_max_subgroup_size(const nir_lower_subgroups_options *options)
35 {
36 return options->subgroup_size
37 ? options->subgroup_size
38 : options->ballot_components * options->ballot_bit_size;
39 }
40
41 static nir_intrinsic_instr *
lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
                                      unsigned int component)
44 {
45 nir_def *comp;
46 if (component == 0)
47 comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
48 else
49 comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);
50
51 nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
52 nir_def_init(&intr->instr, &intr->def, 1, 32);
53 intr->const_index[0] = intrin->const_index[0];
54 intr->const_index[1] = intrin->const_index[1];
55 intr->src[0] = nir_src_for_ssa(comp);
56 if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
57 intr->src[1] = nir_src_for_ssa(intrin->src[1].ssa);
58
59 intr->num_components = 1;
60 nir_builder_instr_insert(b, &intr->instr);
61 return intr;
62 }
63
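/* Split a subgroup operation on a 64-bit value into two 32-bit operations,
 * one per half, and repack the result. For example, a 64-bit shuffle becomes
 * two 32-bit shuffles of the unpacked halves followed by
 * nir_pack_64_2x32_split.
 */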
64 static nir_def *
lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
66 {
67 assert(intrin->src[0].ssa->bit_size == 64);
68 nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
69 nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
70 return nir_pack_64_2x32_split(b, &intr_x->def, &intr_y->def);
71 }
72
73 /* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
74 * 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or
75 * above the subgroup size for the masks, but gt_mask and ge_mask make them 1
76 * so we have to "and" with this mask.
77 */
78 static nir_def *
build_subgroup_mask(nir_builder *b,
                    const nir_lower_subgroups_options *options)
81 {
82 nir_def *subgroup_size = nir_load_subgroup_size(b);
83
84 /* First compute the result assuming one ballot component. */
85 nir_def *result =
86 nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
87 nir_isub_imm(b, options->ballot_bit_size,
88 subgroup_size));
89
90 /* Since the subgroup size and ballot bitsize are both powers of two, there
91 * are two possible cases to consider:
92 *
93 * (1) The subgroup size is less than the ballot bitsize. We need to return
94 * "result" in the first component and 0 in every other component.
95 * (2) The subgroup size is a multiple of the ballot bitsize. We need to
 * return ~0 in each component whose index is less than the subgroup size
 * divided by the ballot bitsize and 0 otherwise. For example,
98 * with a target ballot type of 4 x uint32 and subgroup_size = 64 we'd need
99 * to return { ~0, ~0, 0, 0 }.
100 *
101 * In case (2) it turns out that "result" will be ~0, because
102 * "ballot_bit_size - subgroup_size" is also a multiple of
 * "ballot_bit_size" and since nir_ushr masks the shift value it will be
 * shifted by 0. This means that the first component can just be "result"
105 * in all cases. The other components will also get the correct value in
106 * case (1) if we just use the rule in case (2), so we'll get the correct
107 * result if we just follow (2) and then replace the first component with
108 * "result".
109 */
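   /* For illustration, with a 2 x uint32 ballot: subgroup_size = 16 gives
    * result = 0xffff and a final mask of { 0xffff, 0 } (case 1), while
    * subgroup_size = 64 gives result = ~0 and a final mask of { ~0, ~0 }
    * (case 2).
    */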
110 nir_const_value min_idx[4];
111 for (unsigned i = 0; i < options->ballot_components; i++)
112 min_idx[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
113 nir_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);
114
115 nir_def *result_extended =
116 nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);
117
118 return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
119 result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
120 }
121
122 /* Return a ballot-mask-sized value which represents "val" sign-extended and
123 * then shifted left by "shift". Only particular values for "val" are
124 * supported, see below.
125 *
126 * This function assumes that `val << shift` will never span a ballot_bit_size
127 * word and that the high bit of val can be extended across the entire result.
128 * This is trivially satisfied for 0, 1, ~0, and ~1. However, it may also be
129 * fine for other values if the shift is guaranteed to be sufficiently
130 * aligned. One example is 0xf when the shift is known to be a multiple of 4.
131 */
132 static nir_def *
build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_def *shift,
                      const nir_lower_subgroups_options *options)
135 {
136 /* First compute the result assuming one ballot component. */
137 nir_def *result =
138 nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);
139
140 if (options->ballot_components == 1)
141 return result;
142
143 /* Fix up the result when there is > 1 component. The idea is that nir_ishl
144 * masks out the high bits of the shift value already, so in case there's
145 * more than one component the component which 1 would be shifted into
 * already has the right value and all we have to do is fix up the other
 * components. Components below it should always be 0, and components above
 * it must be either 0 or ~0 because of the assumption above. For example, if
149 * the target ballot size is 2 x uint32, and we're shifting 1 by 33, then
150 * we'll feed 33 into ishl, which will mask it off to get 1, so we'll
151 * compute a single-component result of 2, which is correct for the second
152 * component, but the first component needs to be 0, which we get by
153 * comparing the high bits of the shift with 0 and selecting the original
154 * answer or 0 for the first component (and something similar with the
 * second component). This idea is generalized here for any component count.
156 */
157 nir_const_value min_shift[4];
158 for (unsigned i = 0; i < options->ballot_components; i++)
159 min_shift[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
160 nir_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);
161
162 nir_const_value max_shift[4];
163 for (unsigned i = 0; i < options->ballot_components; i++)
164 max_shift[i] = nir_const_value_for_int((i + 1) * options->ballot_bit_size, 32);
165 nir_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);
166
167 return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
168 nir_bcsel(b, nir_ult(b, shift, min_shift_val),
169 nir_imm_intN_t(b, val >> 63, result->bit_size),
170 result),
171 nir_imm_intN_t(b, 0, result->bit_size));
172 }
173
174 static nir_def *
ballot_type_to_uint(nir_builder *b, nir_def *value,
                    const nir_lower_subgroups_options *options)
177 {
   /* Allow internally generated ballots to pass through */
179 if (value->num_components == options->ballot_components &&
180 value->bit_size == options->ballot_bit_size)
181 return value;
182
183 /* Only the new-style SPIR-V subgroup instructions take a ballot result as
184 * an argument, so we only use this on uvec4 types.
185 */
186 assert(value->num_components == 4 && value->bit_size == 32);
187
188 return nir_extract_bits(b, &value, 1, 0, options->ballot_components,
189 options->ballot_bit_size);
190 }
191
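/* Convert a lowered ballot value into the vector type the original intrinsic
 * expects. For example, a 4 x uint32 ballot converted to a single 64-bit
 * result is bitcast to 2 x uint64 and then trimmed to one component, while a
 * 1 x uint32 ballot converted to a uvec4 is zero-padded first.
 */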
192 static nir_def *
uint_to_ballot_type(nir_builder *b, nir_def *value,
                    unsigned num_components, unsigned bit_size)
195 {
196 assert(util_is_power_of_two_nonzero(num_components));
197 assert(util_is_power_of_two_nonzero(value->num_components));
198
199 unsigned total_bits = bit_size * num_components;
200
201 /* If the source doesn't have enough bits, zero-pad */
202 if (total_bits > value->bit_size * value->num_components)
203 value = nir_pad_vector_imm_int(b, value, 0, total_bits / value->bit_size);
204
205 value = nir_bitcast_vector(b, value, bit_size);
206
207 /* If the source has too many components, truncate. This can happen if,
208 * for instance, we're implementing GL_ARB_shader_ballot or
209 * VK_EXT_shader_subgroup_ballot which have 64-bit ballot values on an
210 * architecture with a native 128-bit uvec4 ballot. This comes up in Zink
211 * for OpenGL on Vulkan. It's the job of the driver calling this lowering
 * pass to ensure that it has restricted the subgroup size sufficiently
 * that we have enough ballot bits.
214 */
215 if (value->num_components > num_components)
216 value = nir_trim_vector(b, value, num_components);
217
218 return value;
219 }
220
221 static nir_def *
lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
223 {
224 /* This is safe to call on scalar things but it would be silly */
225 assert(intrin->def.num_components > 1);
226
227 nir_def *value = intrin->src[0].ssa;
228 nir_def *reads[NIR_MAX_VEC_COMPONENTS];
229
230 for (unsigned i = 0; i < intrin->num_components; i++) {
231 nir_intrinsic_instr *chan_intrin =
232 nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
233 nir_def_init(&chan_intrin->instr, &chan_intrin->def, 1,
234 intrin->def.bit_size);
235 chan_intrin->num_components = 1;
236
237 /* value */
238 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
239 /* invocation */
240 if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
241 assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
242 chan_intrin->src[1] = nir_src_for_ssa(intrin->src[1].ssa);
243 }
244
245 chan_intrin->const_index[0] = intrin->const_index[0];
246 chan_intrin->const_index[1] = intrin->const_index[1];
247
248 nir_builder_instr_insert(b, &chan_intrin->instr);
249 reads[i] = &chan_intrin->def;
250 }
251
252 return nir_vec(b, reads, intrin->num_components);
253 }
254
255 static nir_def *
lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
257 {
258 nir_def *value = intrin->src[0].ssa;
259
260 nir_def *result = NULL;
261 for (unsigned i = 0; i < intrin->num_components; i++) {
262 nir_def* chan = nir_channel(b, value, i);
263
264 if (intrin->intrinsic == nir_intrinsic_vote_feq) {
265 chan = nir_vote_feq(b, intrin->def.bit_size, chan);
266 } else {
267 chan = nir_vote_ieq(b, intrin->def.bit_size, chan);
268 }
269
270 if (result) {
271 result = nir_iand(b, result, chan);
272 } else {
273 result = chan;
274 }
275 }
276
277 return result;
278 }
279
280 static nir_def *
lower_vote_eq(nir_builder *b, nir_intrinsic_instr *intrin)
282 {
283 nir_def *value = intrin->src[0].ssa;
284
285 /* We have to implicitly lower to scalar */
286 nir_def *all_eq = NULL;
287 for (unsigned i = 0; i < intrin->num_components; i++) {
288 nir_def *rfi = nir_read_first_invocation(b, nir_channel(b, value, i));
289
290 nir_def *is_eq;
291 if (intrin->intrinsic == nir_intrinsic_vote_feq) {
292 is_eq = nir_feq(b, rfi, nir_channel(b, value, i));
293 } else {
294 is_eq = nir_ieq(b, rfi, nir_channel(b, value, i));
295 }
296
297 if (all_eq == NULL) {
298 all_eq = is_eq;
299 } else {
300 all_eq = nir_iand(b, all_eq, is_eq);
301 }
302 }
303
304 return nir_vote_all(b, 1, all_eq);
305 }
306
307 static nir_def *
lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin)
309 {
310 unsigned mask = nir_src_as_uint(intrin->src[1]);
311
312 if (mask >= 32)
313 return NULL;
314
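   /* masked_swizzle_amd computes, within each group of 32 invocations,
    * new_id = ((id & and_mask) | or_mask) ^ xor_mask, with and_mask in bits
    * 4:0, or_mask in bits 9:5 and xor_mask in bits 14:10 of swizzle_mask.
    * (mask << 10) | 0x1f therefore keeps the lane bits and xors them with
    * mask, i.e. a shuffle_xor limited to 32 lanes, hence the mask < 32
    * restriction above.
    */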
315 return nir_masked_swizzle_amd(b, intrin->src[0].ssa,
316 .swizzle_mask = (mask << 10) | 0x1f,
317 .fetch_inactive = true);
318 }
319
320 /* Lowers "specialized" shuffles to a generic nir_intrinsic_shuffle. */
321
322 static nir_def *
lower_to_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
                 const nir_lower_subgroups_options *options)
325 {
326 if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
327 options->lower_shuffle_to_swizzle_amd &&
328 nir_src_is_const(intrin->src[1])) {
329
330 nir_def *result = lower_shuffle_to_swizzle(b, intrin);
331 if (result)
332 return result;
333 }
334
335 nir_def *index = nir_load_subgroup_invocation(b);
336 switch (intrin->intrinsic) {
337 case nir_intrinsic_shuffle_xor:
338 index = nir_ixor(b, index, intrin->src[1].ssa);
339 break;
340 case nir_intrinsic_shuffle_up:
341 index = nir_isub(b, index, intrin->src[1].ssa);
342 break;
343 case nir_intrinsic_shuffle_down:
344 index = nir_iadd(b, index, intrin->src[1].ssa);
345 break;
346 case nir_intrinsic_quad_broadcast:
347 index = nir_ior(b, nir_iand_imm(b, index, ~0x3),
348 intrin->src[1].ssa);
349 break;
350 case nir_intrinsic_quad_swap_horizontal:
351 /* For Quad operations, subgroups are divided into quads where
352 * (invocation % 4) is the index to a square arranged as follows:
353 *
354 * +---+---+
355 * | 0 | 1 |
356 * +---+---+
357 * | 2 | 3 |
358 * +---+---+
359 */
360 index = nir_ixor(b, index, nir_imm_int(b, 0x1));
361 break;
362 case nir_intrinsic_quad_swap_vertical:
363 index = nir_ixor(b, index, nir_imm_int(b, 0x2));
364 break;
365 case nir_intrinsic_quad_swap_diagonal:
366 index = nir_ixor(b, index, nir_imm_int(b, 0x3));
367 break;
368 case nir_intrinsic_rotate: {
369 nir_def *delta = intrin->src[1].ssa;
370 nir_def *local_id = nir_load_subgroup_invocation(b);
371 const unsigned cluster_size = nir_intrinsic_cluster_size(intrin);
372
373 nir_def *rotation_group_mask =
374 cluster_size > 0 ? nir_imm_int(b, (int)(cluster_size - 1)) : nir_iadd_imm(b, nir_load_subgroup_size(b), -1);
375
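      /* For illustration: with cluster_size = 8, local_id = 13 and delta = 3,
       * rotation_group_mask = 7, so index = ((13 + 3) & 7) + (13 & ~7)
       * = 0 + 8 = 8, i.e. the rotation wraps around within the 8-wide
       * cluster.
       */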
376 index = nir_iand(b, nir_iadd(b, local_id, delta),
377 rotation_group_mask);
378 if (cluster_size > 0) {
379 index = nir_iadd(b, index,
380 nir_iand(b, local_id, nir_inot(b, rotation_group_mask)));
381 }
382 break;
383 }
384 default:
385 unreachable("Invalid intrinsic");
386 }
387
388 return nir_shuffle(b, intrin->src[0].ssa, index);
389 }
390
391 static const struct glsl_type *
glsl_type_for_ssa(nir_def *def)
393 {
394 const struct glsl_type *comp_type = def->bit_size == 1 ? glsl_bool_type() : glsl_uintN_t_type(def->bit_size);
395 return glsl_replace_vector_type(comp_type, def->num_components);
396 }
397
398 /* Lower nir_intrinsic_shuffle to a waterfall loop + nir_read_invocation.
399 */
400 static nir_def *
lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin)
402 {
403 nir_def *val = intrin->src[0].ssa;
404 nir_def *id = intrin->src[1].ssa;
405
406 /* The loop is something like:
407 *
408 * while (true) {
409 * first_id = readFirstInvocation(gl_SubgroupInvocationID);
410 * first_val = readFirstInvocation(val);
411 * first_result = readInvocation(val, readFirstInvocation(id));
412 * if (id == first_id)
413 * result = first_val;
414 * if (elect()) {
415 * if (id > gl_SubgroupInvocationID) {
416 * result = first_result;
417 * }
418 * break;
419 * }
420 * }
421 *
422 * The idea is to guarantee, on each iteration of the loop, that anything
423 * reading from first_id gets the correct value, so that we can then kill
424 * it off by breaking out of the loop. Before doing that we also have to
 * ensure that the first_id invocation gets the correct value. The only way
 * it won't already have the correct value is if the invocation it's reading
 * from hasn't been killed off yet, that is, if that invocation comes later
 * than its own ID.
428 * Invocations where id <= gl_SubgroupInvocationID will be assigned their
429 * result in the first if, and invocations where id >
430 * gl_SubgroupInvocationID will be assigned their result in the second if.
431 *
 * We use this more complicated loop rather than looping over all ids
 * explicitly because at this point we don't know the "actual" subgroup
434 * size and at the moment there's no way to get at it, which means we may
435 * loop over always-inactive invocations.
436 */
437
438 nir_def *subgroup_id = nir_load_subgroup_invocation(b);
439
440 nir_variable *result =
441 nir_local_variable_create(b->impl, glsl_type_for_ssa(val), "result");
442
443 nir_loop *loop = nir_push_loop(b);
444 {
445 nir_def *first_id = nir_read_first_invocation(b, subgroup_id);
446 nir_def *first_val = nir_read_first_invocation(b, val);
447 nir_def *first_result =
448 nir_read_invocation(b, val, nir_read_first_invocation(b, id));
449
450 nir_if *nif = nir_push_if(b, nir_ieq(b, id, first_id));
451 {
452 nir_store_var(b, result, first_val, BITFIELD_MASK(val->num_components));
453 }
454 nir_pop_if(b, nif);
455
456 nir_if *nif2 = nir_push_if(b, nir_elect(b, 1));
457 {
458 nir_if *nif3 = nir_push_if(b, nir_ult(b, subgroup_id, id));
459 {
460 nir_store_var(b, result, first_result, BITFIELD_MASK(val->num_components));
461 }
462 nir_pop_if(b, nif3);
463
464 nir_jump(b, nir_jump_break);
465 }
466 nir_pop_if(b, nif2);
467 }
468 nir_pop_loop(b, loop);
469
470 return nir_load_var(b, result);
471 }
472
473 static nir_def *
lower_boolean_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
                      const nir_lower_subgroups_options *options)
476 {
477 assert(options->ballot_components == 1 && options->subgroup_size);
478 nir_def *ballot = nir_ballot_relaxed(b, 1, options->ballot_bit_size, intrin->src[0].ssa);
479
480 nir_def *index = NULL;
481
482 /* If the shuffle amount isn't constant, it might be divergent but
483 * inverse_ballot requires a uniform source, so take a different path.
 * Unlike shuffle_up/down, rotate allows us to assume the delta is uniform.
485 */
486 switch (intrin->intrinsic) {
487 case nir_intrinsic_shuffle_up:
488 if (nir_src_is_const(intrin->src[1]))
489 ballot = nir_ishl(b, ballot, intrin->src[1].ssa);
490 else
491 index = nir_isub(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
492 break;
493 case nir_intrinsic_shuffle_down:
494 if (nir_src_is_const(intrin->src[1]))
495 ballot = nir_ushr(b, ballot, intrin->src[1].ssa);
496 else
497 index = nir_iadd(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
498 break;
499 case nir_intrinsic_shuffle_xor:
500 index = nir_ixor(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
501 break;
502 case nir_intrinsic_rotate: {
503 nir_def *delta = nir_as_uniform(b, intrin->src[1].ssa);
504 uint32_t cluster_size = nir_intrinsic_cluster_size(intrin);
505 unsigned subgroup_size = get_max_subgroup_size(options);
506 cluster_size = cluster_size ? cluster_size : subgroup_size;
507 cluster_size = MIN2(cluster_size, subgroup_size);
508 if (cluster_size == 1) {
509 return intrin->src[0].ssa;
510 } else if (cluster_size == 2) {
511 delta = nir_iand_imm(b, delta, cluster_size - 1);
512 nir_def *lo = nir_iand_imm(b, nir_ushr_imm(b, ballot, 1), 0x5555555555555555ull);
513 nir_def *hi = nir_iand_imm(b, nir_ishl_imm(b, ballot, 1), 0xaaaaaaaaaaaaaaaaull);
514 ballot = nir_bcsel(b, nir_ine_imm(b, delta, 0), nir_ior(b, hi, lo), ballot);
515 } else if (cluster_size == ballot->bit_size) {
516 ballot = nir_uror(b, ballot, delta);
517 } else if (cluster_size == 32) {
518 nir_def *unpacked = nir_unpack_64_2x32(b, ballot);
519 unpacked = nir_uror(b, unpacked, delta);
520 ballot = nir_pack_64_2x32(b, unpacked);
521 } else {
522 delta = nir_iand_imm(b, delta, cluster_size - 1);
523 nir_def *delta_rev = nir_isub_imm(b, cluster_size, delta);
524 nir_def *mask = nir_mask(b, delta_rev, ballot->bit_size);
525 for (uint32_t i = cluster_size; i < ballot->bit_size; i *= 2) {
526 mask = nir_ior(b, nir_ishl_imm(b, mask, i), mask);
527 }
528 nir_def *lo = nir_iand(b, nir_ushr(b, ballot, delta), mask);
529 nir_def *hi = nir_iand(b, nir_ishl(b, ballot, delta_rev), nir_inot(b, mask));
530 ballot = nir_ior(b, lo, hi);
531 }
532 break;
533 }
534 case nir_intrinsic_shuffle:
535 index = intrin->src[1].ssa;
536 break;
537 case nir_intrinsic_read_invocation:
538 index = nir_as_uniform(b, intrin->src[1].ssa);
539 break;
540 default:
541 unreachable("not a boolean shuffle");
542 }
543
544 if (index) {
545 nir_def *mask = nir_ishl(b, nir_imm_intN_t(b, 1, ballot->bit_size), index);
546 return nir_ine_imm(b, nir_iand(b, ballot, mask), 0);
547 } else {
548 return nir_inverse_ballot(b, 1, ballot);
549 }
550 }
551
552 static nir_def *
vec_bit_count(nir_builder *b, nir_def *value)
554 {
555 nir_def *vec_result = nir_bit_count(b, value);
556 nir_def *result = nir_channel(b, vec_result, 0);
557 for (unsigned i = 1; i < value->num_components; i++)
558 result = nir_iadd(b, result, nir_channel(b, vec_result, i));
559 return result;
560 }
561
562 /* produce a bitmask of 111...000...111... alternating between "size"
563 * 1's and "size" 0's (the LSB is 1).
564 */
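/* For example, with a 32-bit ballot: reduce_mask(1, 32) = 0x55555555,
 * reduce_mask(2, 32) = 0x33333333 and reduce_mask(4, 32) = 0x0f0f0f0f.
 */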
565 static uint64_t
reduce_mask(unsigned size, unsigned ballot_bit_size)
567 {
568 uint64_t mask = 0;
569 for (unsigned i = 0; i < ballot_bit_size; i += 2 * size) {
570 mask |= ((1ull << size) - 1) << i;
571 }
572
573 return mask;
574 }
575
576 /* operate on a uniform per-thread bitmask provided by ballot() to perform the
577 * desired Boolean reduction. Assumes that the identity of the operation is
578 * false (so, no iand).
579 */
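/* For illustration, a single ior step with size = 1 on the 8-bit mask
 * 0b01100100 produces 0b11111100: after the step for a given "size", every
 * aligned group of 2 * size bits holds that group's reduction replicated
 * across all of its bits.
 */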
580 static nir_def *
lower_boolean_reduce_internal(nir_builder *b, nir_def *src,
                              unsigned cluster_size, nir_op op,
                              const nir_lower_subgroups_options *options)
584 {
585 for (unsigned size = 1; size < cluster_size; size *= 2) {
586 nir_def *shifted = nir_ushr_imm(b, src, size);
587 src = nir_build_alu2(b, op, shifted, src);
588 uint64_t mask = reduce_mask(size, options->ballot_bit_size);
589 src = nir_iand_imm(b, src, mask);
590 shifted = nir_ishl_imm(b, src, size);
591 src = nir_ior(b, src, shifted);
592 }
593
594 return src;
595 }
596
597 /* operate on a uniform per-thread bitmask provided by ballot() to perform the
598 * desired Boolean inclusive scan. Assumes that the identity of the operation is
599 * false (so, no iand).
600 */
601 static nir_def *
lower_boolean_scan_internal(nir_builder *b, nir_def *src,
                            nir_op op,
                            const nir_lower_subgroups_options *options)
605 {
606 if (op == nir_op_ior) {
607 /* We want to return a bitmask with all 1's starting at the first 1 in
608 * src. -src is equivalent to ~src + 1. While src | ~src returns all
609 * 1's, src | (~src + 1) returns all 1's except for the bits changed by
610 * the increment. Any 1's before the least significant 0 of ~src are
611 * turned into 0 (zeroing those bits after or'ing) and the least
 * significant 0 of ~src is turned into 1 (not doing anything). So the
613 * final output is what we want.
614 */
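      /* For example, in 8 bits: src = 0b00010000 gives
       * src | -src = 0b11110000, all 1's from the first set bit upward.
       */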
615 return nir_ior(b, src, nir_ineg(b, src));
616 } else {
617 assert(op == nir_op_ixor);
618 for (unsigned shift = 1; shift < options->ballot_bit_size; shift *= 2) {
619 src = nir_ixor(b, src, nir_ishl_imm(b, src, shift));
620 }
621 return src;
622 }
623 }
624
625 static nir_def *
lower_boolean_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
                     const nir_lower_subgroups_options *options)
628 {
629 assert(intrin->num_components == 1);
630 assert(options->ballot_components == 1);
631
632 unsigned cluster_size =
633 intrin->intrinsic == nir_intrinsic_reduce ? nir_intrinsic_cluster_size(intrin) : 0;
634 nir_op op = nir_intrinsic_reduction_op(intrin);
635
   /* For certain cluster sizes, Boolean reductions can be implemented
    * more efficiently.
638 */
639 if (intrin->intrinsic == nir_intrinsic_reduce) {
640 if (cluster_size == 0) {
641 if (op == nir_op_iand)
642 return nir_vote_all(b, 1, intrin->src[0].ssa);
643 else if (op == nir_op_ior)
644 return nir_vote_any(b, 1, intrin->src[0].ssa);
645 else if (op == nir_op_ixor)
646 return nir_i2b(b, nir_iand_imm(b, vec_bit_count(b, nir_ballot(b,
647 options->ballot_components,
648 options->ballot_bit_size,
649 intrin->src[0].ssa)),
650 1));
651 else
652 unreachable("bad boolean reduction op");
653 }
654
655 if (cluster_size == 4) {
656 if (op == nir_op_iand)
657 return nir_quad_vote_all(b, 1, intrin->src[0].ssa);
658 else if (op == nir_op_ior)
659 return nir_quad_vote_any(b, 1, intrin->src[0].ssa);
660 }
661 }
662
663 nir_def *src = intrin->src[0].ssa;
664
665 /* Apply DeMorgan's law to implement "and" reductions, since all the
666 * lower_boolean_*_internal() functions assume an identity of 0 to make the
667 * generated code shorter.
668 */
669 nir_op new_op = (op == nir_op_iand) ? nir_op_ior : op;
670 if (op == nir_op_iand) {
671 src = nir_inot(b, src);
672 }
673
674 nir_def *val = nir_ballot(b, options->ballot_components, options->ballot_bit_size, src);
675
676 switch (intrin->intrinsic) {
677 case nir_intrinsic_reduce:
678 val = lower_boolean_reduce_internal(b, val, cluster_size, new_op, options);
679 break;
680 case nir_intrinsic_inclusive_scan:
681 val = lower_boolean_scan_internal(b, val, new_op, options);
682 break;
683 case nir_intrinsic_exclusive_scan:
684 val = lower_boolean_scan_internal(b, val, new_op, options);
685 val = nir_ishl_imm(b, val, 1);
686 break;
687 default:
688 unreachable("bad intrinsic");
689 }
690
691 if (op == nir_op_iand) {
692 val = nir_inot(b, val);
693 }
694
695 return nir_inverse_ballot(b, 1, val);
696 }
697
698 static nir_def *
build_identity(nir_builder *b, unsigned bit_size, nir_op op)
700 {
701 nir_const_value ident_const = nir_alu_binop_identity(op, bit_size);
702 return nir_build_imm(b, 1, bit_size, &ident_const);
703 }
704
705 /* Implementation of scan/reduce that assumes a full subgroup */
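/* This is the usual log-step scan: at step i every invocation that has a
 * partner i lanes below it combines with that partner, so after
 * log2(cluster_size) steps each invocation has folded in everything up to
 * cluster_size - 1 lanes below itself (an inclusive scan). The reduce case
 * instead uses a shuffle_xor butterfly so every invocation ends up with the
 * full cluster reduction.
 */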
706 static nir_def *
build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                nir_def *data, unsigned cluster_size)
709 {
710 switch (op) {
711 case nir_intrinsic_exclusive_scan:
712 case nir_intrinsic_inclusive_scan: {
713 for (unsigned i = 1; i < cluster_size; i *= 2) {
714 nir_def *idx = nir_load_subgroup_invocation(b);
715 nir_def *has_buddy = nir_ige_imm(b, idx, i);
716
717 nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
718 nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
719 data = nir_bcsel(b, has_buddy, accum, data);
720 }
721
722 if (op == nir_intrinsic_exclusive_scan) {
723 /* For exclusive scans, we need to shift one more time and fill in the
724 * bottom channel with identity.
725 */
726 nir_def *idx = nir_load_subgroup_invocation(b);
727 nir_def *has_buddy = nir_ige_imm(b, idx, 1);
728
729 nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, 1));
730 nir_def *identity = build_identity(b, data->bit_size, red_op);
731 data = nir_bcsel(b, has_buddy, buddy_data, identity);
732 }
733
734 return data;
735 }
736
737 case nir_intrinsic_reduce: {
738 for (unsigned i = 1; i < cluster_size; i *= 2) {
739 nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
740 data = nir_build_alu2(b, red_op, data, buddy_data);
741 }
742 return data;
743 }
744
745 default:
746 unreachable("Unsupported scan/reduce op");
747 }
748 }
749
750 /* Fully generic implementation of scan/reduce that takes a mask */
751 static nir_def *
build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                  nir_def *data, nir_def *mask, unsigned max_mask_bits,
                  const nir_lower_subgroups_options *options)
755 {
756 nir_def *lt_mask = nir_load_subgroup_lt_mask(b, options->ballot_components,
757 options->ballot_bit_size);
758
   /* Mask of all channels whose values we still need to accumulate. Our own
    * value is already included in "data", so we only need to consider
    * lower-indexed invocations.
    */
763 nir_def *remaining = nir_iand(b, mask, lt_mask);
764
765 for (unsigned i = 1; i < max_mask_bits; i *= 2) {
766 /* At each step, our buddy channel is the first channel we have yet to
767 * take into account in the accumulator.
768 */
769 nir_def *has_buddy = nir_bany_inequal(b, remaining, nir_imm_int(b, 0));
770 nir_def *buddy = nir_ballot_find_msb(b, 32, remaining);
771
772 /* Accumulate with our buddy channel, if any */
773 nir_def *buddy_data = nir_shuffle(b, data, buddy);
774 nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
775 data = nir_bcsel(b, has_buddy, accum, data);
776
777 /* We just took into account everything in our buddy's accumulator from
778 * the previous step. The only things remaining are whatever channels
779 * were remaining for our buddy.
780 */
781 nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
782 remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
783 }
784
785 switch (op) {
786 case nir_intrinsic_exclusive_scan: {
787 /* For exclusive scans, we need to shift one more time and fill in the
788 * bottom channel with identity.
789 *
790 * Some of this will get CSE'd with the first step but that's okay. The
791 * code is cleaner this way.
792 */
793 nir_def *lower = nir_iand(b, mask, lt_mask);
794 nir_def *has_buddy = nir_bany_inequal(b, lower, nir_imm_int(b, 0));
795 nir_def *buddy = nir_ballot_find_msb(b, 32, lower);
796
797 nir_def *buddy_data = nir_shuffle(b, data, buddy);
798 nir_def *identity = build_identity(b, data->bit_size, red_op);
799 return nir_bcsel(b, has_buddy, buddy_data, identity);
800 }
801
802 case nir_intrinsic_inclusive_scan:
803 return data;
804
805 case nir_intrinsic_reduce: {
806 /* For reductions, we need to take the top value of the scan */
807 nir_def *idx = nir_ballot_find_msb(b, 32, mask);
808 return nir_shuffle(b, data, idx);
809 }
810
811 default:
812 unreachable("Unsupported scan/reduce op");
813 }
814 }
815
816 static nir_def *
build_cluster_mask(nir_builder *b, unsigned cluster_size,
                   const nir_lower_subgroups_options *options)
819 {
820 nir_def *idx = nir_load_subgroup_invocation(b);
821 nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));
822
823 if (cluster_size <= options->ballot_bit_size) {
824 return build_ballot_imm_ishl(b, BITFIELD_MASK(cluster_size), cluster,
825 options);
826 }
827
828 /* Since the cluster size and the ballot bit size are both powers of 2,
829 * cluster size will be a multiple of the ballot bit size. Therefore, each
830 * ballot component will be either all ones or all zeros. Build a vec for
831 * which each component holds the value of `cluster` for which the mask
832 * should be all ones.
833 */
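   /* For example, with a 4 x uint32 ballot and cluster_size = 64, an
    * invocation in the second cluster (cluster == 64) compares against
    * cluster_sel = { 0, 0, 64, 64 } and gets the mask { 0, 0, ~0, ~0 }.
    */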
834 nir_const_value cluster_sel_const[4];
835 assert(ARRAY_SIZE(cluster_sel_const) >= options->ballot_components);
836
837 for (unsigned i = 0; i < options->ballot_components; i++) {
838 unsigned cluster_val =
839 ROUND_DOWN_TO(i * options->ballot_bit_size, cluster_size);
840 cluster_sel_const[i] =
841 nir_const_value_for_uint(cluster_val, options->ballot_bit_size);
842 }
843
844 nir_def *cluster_sel =
845 nir_build_imm(b, options->ballot_components, options->ballot_bit_size,
846 cluster_sel_const);
847 nir_def *ones = nir_imm_intN_t(b, -1, options->ballot_bit_size);
848 nir_def *zeros = nir_imm_intN_t(b, 0, options->ballot_bit_size);
849 return nir_bcsel(b, nir_ieq(b, cluster, cluster_sel), ones, zeros);
850 }
851
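/* Lower reduce/inclusive_scan/exclusive_scan to shuffles: if the ballot shows
 * that every invocation in the subgroup is active we can use the fixed-stride
 * build_scan_full() path, otherwise we fall back to the mask-driven
 * build_scan_reduce() path, with the mask further restricted to the requested
 * cluster.
 */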
852 static nir_def *
lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
                  const nir_lower_subgroups_options *options)
855 {
856 const nir_op red_op = nir_intrinsic_reduction_op(intrin);
857 unsigned subgroup_size = get_max_subgroup_size(options);
858
859 /* Grab the cluster size */
860 unsigned cluster_size = subgroup_size;
861 if (nir_intrinsic_has_cluster_size(intrin)) {
862 cluster_size = nir_intrinsic_cluster_size(intrin);
863 if (cluster_size == 0 || cluster_size > subgroup_size)
864 cluster_size = subgroup_size;
865 }
866
867 /* Check if all invocations are active. If so, we use the fast path. */
868 nir_def *mask = nir_ballot(b, options->ballot_components,
869 options->ballot_bit_size, nir_imm_true(b));
870
871 nir_def *full, *partial;
872 nir_push_if(b, nir_ball_iequal(b, mask, build_subgroup_mask(b, options)));
873 {
874 full = build_scan_full(b, intrin->intrinsic, red_op,
875 intrin->src[0].ssa, cluster_size);
876 }
877 nir_push_else(b, NULL);
878 {
879 /* Mask according to the cluster size */
880 if (cluster_size < subgroup_size) {
881 nir_def *cluster_mask = build_cluster_mask(b, cluster_size, options);
882 mask = nir_iand(b, mask, cluster_mask);
883 }
884
885 partial = build_scan_reduce(b, intrin->intrinsic, red_op,
886 intrin->src[0].ssa, mask, cluster_size,
887 options);
888 }
889 nir_pop_if(b, NULL);
890 return nir_if_phi(b, full, partial);
891 }
892
893 static bool
lower_subgroups_filter(const nir_instr *instr, const void *_options)
895 {
896 const nir_lower_subgroups_options *options = _options;
897
898 if (options->filter) {
899 return options->filter(instr, options->filter_data);
900 }
901
902 return instr->type == nir_instr_type_intrinsic;
903 }
904
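/* The four helpers below build the classic subgroup masks from the
 * invocation index using build_ballot_imm_ishl(): eq is 1 << id, ge is
 * ~0 << id, gt is ~1 << id and the quad mask is 0xf shifted to the
 * invocation's quad.
 */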
905 static nir_def *
build_subgroup_eq_mask(nir_builder *b,
                       const nir_lower_subgroups_options *options)
908 {
909 nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
910
911 return build_ballot_imm_ishl(b, 1, subgroup_idx, options);
912 }
913
914 static nir_def *
build_subgroup_ge_mask(nir_builder *b,
                       const nir_lower_subgroups_options *options)
917 {
918 nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
919
920 return build_ballot_imm_ishl(b, ~0ull, subgroup_idx, options);
921 }
922
923 static nir_def *
build_subgroup_gt_mask(nir_builder *b,
                       const nir_lower_subgroups_options *options)
926 {
927 nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
928
929 return build_ballot_imm_ishl(b, ~1ull, subgroup_idx, options);
930 }
931
932 static nir_def *
build_subgroup_quad_mask(nir_builder *b,
                         const nir_lower_subgroups_options *options)
935 {
936 nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
937 nir_def *quad_first_idx = nir_iand_imm(b, subgroup_idx, ~0x3);
938
939 return build_ballot_imm_ishl(b, 0xf, quad_first_idx, options);
940 }
941
942 static nir_def *
build_quad_vote_any(nir_builder *b, nir_def *src,
                    const nir_lower_subgroups_options *options)
945 {
946 nir_def *ballot = nir_ballot(b, options->ballot_components,
947 options->ballot_bit_size,
948 src);
949 nir_def *mask = build_subgroup_quad_mask(b, options);
950
951 return nir_ine_imm(b, nir_iand(b, ballot, mask), 0);
952 }
953
954 static nir_def *
vec_find_lsb(nir_builder *b, nir_def *value)
956 {
957 nir_def *vec_result = nir_find_lsb(b, value);
958 nir_def *result = nir_imm_int(b, -1);
959 for (int i = value->num_components - 1; i >= 0; i--) {
960 nir_def *channel = nir_channel(b, vec_result, i);
961 /* result = channel >= 0 ? (i * bitsize + channel) : result */
962 result = nir_bcsel(b, nir_ige_imm(b, channel, 0),
963 nir_iadd_imm(b, channel, i * value->bit_size),
964 result);
965 }
966 return result;
967 }
968
969 static nir_def *
vec_find_msb(nir_builder *b, nir_def *value)
971 {
972 nir_def *vec_result = nir_ufind_msb(b, value);
973 nir_def *result = nir_imm_int(b, -1);
974 for (unsigned i = 0; i < value->num_components; i++) {
975 nir_def *channel = nir_channel(b, vec_result, i);
976 /* result = channel >= 0 ? (i * bitsize + channel) : result */
977 result = nir_bcsel(b, nir_ige_imm(b, channel, 0),
978 nir_iadd_imm(b, channel, i * value->bit_size),
979 result);
980 }
981 return result;
982 }
983
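/* Lower a quad_broadcast whose index is not compile-time constant. If the
 * driver requests lower_quad_broadcast_dynamic_to_const, emit all four
 * constant-index quad broadcasts and select between them with bcsel;
 * otherwise lower it to a generic shuffle.
 */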
984 static nir_def *
lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
                             const nir_lower_subgroups_options *options)
987 {
988 if (!options->lower_quad_broadcast_dynamic_to_const)
989 return lower_to_shuffle(b, intrin, options);
990
991 nir_def *dst = NULL;
992
993 for (unsigned i = 0; i < 4; ++i) {
994 nir_def *qbcst = nir_quad_broadcast(b, intrin->src[0].ssa,
995 nir_imm_int(b, i));
996
997 if (i)
998 dst = nir_bcsel(b, nir_ieq_imm(b, intrin->src[1].ssa, i),
999 qbcst, dst);
1000 else
1001 dst = qbcst;
1002 }
1003
1004 return dst;
1005 }
1006
1007 static nir_def *
lower_first_invocation_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
                                 const nir_lower_subgroups_options *options)
1010 {
1011 return nir_ballot_find_lsb(b, 32, nir_ballot(b, 4, 32, nir_imm_true(b)));
1012 }
1013
1014 static nir_def *
lower_read_first_invocation(nir_builder *b, nir_intrinsic_instr *intrin)
1016 {
1017 return nir_read_invocation(b, intrin->src[0].ssa, nir_first_invocation(b));
1018 }
1019
1020 static nir_def *
lower_read_invocation_to_cond(nir_builder *b, nir_intrinsic_instr *intrin)
1022 {
1023 return nir_read_invocation_cond_ir3(b, intrin->def.bit_size,
1024 intrin->src[0].ssa,
1025 nir_ieq(b, intrin->src[1].ssa,
1026 nir_load_subgroup_invocation(b)));
1027 }
1028
1029 static nir_def *
lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
1031 {
1032 const nir_lower_subgroups_options *options = _options;
1033
1034 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1035 switch (intrin->intrinsic) {
1036 case nir_intrinsic_vote_any:
1037 case nir_intrinsic_vote_all:
1038 if (options->lower_vote_trivial)
1039 return intrin->src[0].ssa;
1040 break;
1041
1042 case nir_intrinsic_vote_feq:
1043 case nir_intrinsic_vote_ieq:
1044 if (options->lower_vote_trivial)
1045 return nir_imm_true(b);
1046
1047 if (nir_src_bit_size(intrin->src[0]) == 1) {
1048 if (options->lower_vote_bool_eq)
1049 return lower_vote_eq(b, intrin);
1050 } else {
1051 if (options->lower_vote_eq)
1052 return lower_vote_eq(b, intrin);
1053 }
1054
1055 if (options->lower_to_scalar && intrin->num_components > 1)
1056 return lower_vote_eq_to_scalar(b, intrin);
1057 break;
1058
1059 case nir_intrinsic_load_subgroup_size:
1060 if (options->subgroup_size)
1061 return nir_imm_int(b, options->subgroup_size);
1062 break;
1063
1064 case nir_intrinsic_first_invocation:
1065 if (options->subgroup_size == 1)
1066 return nir_imm_int(b, 0);
1067
1068 if (options->lower_first_invocation_to_ballot)
1069 return lower_first_invocation_to_ballot(b, intrin, options);
1070
1071 break;
1072
1073 case nir_intrinsic_read_invocation:
1074 if (options->lower_to_scalar && intrin->num_components > 1)
1075 return lower_subgroup_op_to_scalar(b, intrin);
1076
1077 if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1078 return lower_boolean_shuffle(b, intrin, options);
1079
1080 if (options->lower_read_invocation_to_cond)
1081 return lower_read_invocation_to_cond(b, intrin);
1082
1083 break;
1084
1085 case nir_intrinsic_read_first_invocation:
1086 if (options->lower_to_scalar && intrin->num_components > 1)
1087 return lower_subgroup_op_to_scalar(b, intrin);
1088
1089 if (options->lower_read_first_invocation)
1090 return lower_read_first_invocation(b, intrin);
1091 break;
1092
1093 case nir_intrinsic_load_subgroup_eq_mask:
1094 case nir_intrinsic_load_subgroup_ge_mask:
1095 case nir_intrinsic_load_subgroup_gt_mask:
1096 case nir_intrinsic_load_subgroup_le_mask:
1097 case nir_intrinsic_load_subgroup_lt_mask: {
1098 if (!options->lower_subgroup_masks)
1099 return NULL;
1100
1101 nir_def *val;
1102 switch (intrin->intrinsic) {
1103 case nir_intrinsic_load_subgroup_eq_mask:
1104 val = build_subgroup_eq_mask(b, options);
1105 break;
1106 case nir_intrinsic_load_subgroup_ge_mask:
1107 val = nir_iand(b, build_subgroup_ge_mask(b, options),
1108 build_subgroup_mask(b, options));
1109 break;
1110 case nir_intrinsic_load_subgroup_gt_mask:
1111 val = nir_iand(b, build_subgroup_gt_mask(b, options),
1112 build_subgroup_mask(b, options));
1113 break;
1114 case nir_intrinsic_load_subgroup_le_mask:
1115 val = nir_inot(b, build_subgroup_gt_mask(b, options));
1116 break;
1117 case nir_intrinsic_load_subgroup_lt_mask:
1118 val = nir_inot(b, build_subgroup_ge_mask(b, options));
1119 break;
1120 default:
1121 unreachable("you seriously can't tell this is unreachable?");
1122 }
1123
1124 return uint_to_ballot_type(b, val,
1125 intrin->def.num_components,
1126 intrin->def.bit_size);
1127 }
1128
1129 case nir_intrinsic_ballot: {
1130 if (intrin->def.num_components == options->ballot_components &&
1131 intrin->def.bit_size == options->ballot_bit_size)
1132 return NULL;
1133
1134 nir_def *ballot =
1135 nir_ballot(b, options->ballot_components, options->ballot_bit_size,
1136 intrin->src[0].ssa);
1137
1138 return uint_to_ballot_type(b, ballot,
1139 intrin->def.num_components,
1140 intrin->def.bit_size);
1141 }
1142
1143 case nir_intrinsic_inverse_ballot:
1144 if (options->lower_inverse_ballot) {
1145 return nir_ballot_bitfield_extract(b, 1, intrin->src[0].ssa,
1146 nir_load_subgroup_invocation(b));
1147 } else if (intrin->src[0].ssa->num_components != options->ballot_components ||
1148 intrin->src[0].ssa->bit_size != options->ballot_bit_size) {
1149 return nir_inverse_ballot(b, 1, ballot_type_to_uint(b, intrin->src[0].ssa, options));
1150 }
1151 break;
1152
1153 case nir_intrinsic_ballot_bitfield_extract:
1154 case nir_intrinsic_ballot_bit_count_reduce:
1155 case nir_intrinsic_ballot_find_lsb:
1156 case nir_intrinsic_ballot_find_msb: {
1157 nir_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
1158 options);
1159
1160 if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
1161 intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
1162 /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
1163 *
1164 * "Find the most significant bit set to 1 in Value, considering
1165 * only the bits in Value required to represent all bits of the
1166 * group’s invocations. If none of the considered bits is set to
1167 * 1, the result is undefined."
1168 *
1169 * It has similar text for the other three. This means that, in case
1170 * the subgroup size is less than 32, we have to mask off the unused
1171 * bits. If the subgroup size is fixed and greater than or equal to
1172 * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
1173 * the iand.
1174 *
1175 * We only have to worry about this for BitCount and FindMSB because
1176 * FindLSB counts from the bottom and BitfieldExtract selects
1177 * individual bits. In either case, if run outside the range of
1178 * valid bits, we hit the undefined results case and we can return
1179 * anything we want.
1180 */
1181 int_val = nir_iand(b, int_val, build_subgroup_mask(b, options));
1182 }
1183
1184 switch (intrin->intrinsic) {
1185 case nir_intrinsic_ballot_bitfield_extract: {
1186 nir_def *idx = intrin->src[1].ssa;
1187 if (int_val->num_components > 1) {
1188 /* idx will be truncated by nir_ushr, so we just need to select
1189 * the right component using the bits of idx that are truncated in
1190 * the shift.
1191 */
1192 int_val =
1193 nir_vector_extract(b, int_val,
1194 nir_udiv_imm(b, idx, int_val->bit_size));
1195 }
1196
1197 return nir_test_mask(b, nir_ushr(b, int_val, idx), 1);
1198 }
1199 case nir_intrinsic_ballot_bit_count_reduce:
1200 return vec_bit_count(b, int_val);
1201 case nir_intrinsic_ballot_find_lsb:
1202 return vec_find_lsb(b, int_val);
1203 case nir_intrinsic_ballot_find_msb:
1204 return vec_find_msb(b, int_val);
1205 default:
1206 unreachable("you seriously can't tell this is unreachable?");
1207 }
1208 }
1209
1210 case nir_intrinsic_ballot_bit_count_exclusive:
1211 case nir_intrinsic_ballot_bit_count_inclusive: {
1212 nir_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
1213 options);
1214 if (options->lower_ballot_bit_count_to_mbcnt_amd) {
1215 nir_def *acc;
1216 if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_exclusive) {
1217 acc = nir_imm_int(b, 0);
1218 } else {
1219 acc = nir_iand_imm(b, nir_u2u32(b, int_val), 0x1);
1220 int_val = nir_ushr_imm(b, int_val, 1);
1221 }
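         /* mbcnt_amd(x, add) returns add plus the number of set bits of x in
          * lanes strictly below the current one. The exclusive count uses
          * the ballot as-is with add = 0; for the inclusive count the ballot
          * is shifted down by one (so lane i also counts bit i) and
          * invocation 0's bit is carried in through "add".
          */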
1222 return nir_mbcnt_amd(b, int_val, acc);
1223 }
1224
1225 nir_def *mask;
1226 if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
1227 mask = nir_inot(b, build_subgroup_gt_mask(b, options));
1228 } else {
1229 mask = nir_inot(b, build_subgroup_ge_mask(b, options));
1230 }
1231
1232 return vec_bit_count(b, nir_iand(b, int_val, mask));
1233 }
1234
1235 case nir_intrinsic_elect: {
1236 if (!options->lower_elect)
1237 return NULL;
1238
1239 return nir_ieq(b, nir_load_subgroup_invocation(b), nir_first_invocation(b));
1240 }
1241
1242 case nir_intrinsic_shuffle:
1243 if (options->lower_shuffle &&
1244 (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1245 return lower_shuffle(b, intrin);
1246 else if (options->lower_to_scalar && intrin->num_components > 1)
1247 return lower_subgroup_op_to_scalar(b, intrin);
1248 else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1249 return lower_boolean_shuffle(b, intrin, options);
1250 else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1251 return lower_subgroup_op_to_32bit(b, intrin);
1252 break;
1253 case nir_intrinsic_shuffle_xor:
1254 case nir_intrinsic_shuffle_up:
1255 case nir_intrinsic_shuffle_down:
1256 if (options->lower_relative_shuffle &&
1257 (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1258 return lower_to_shuffle(b, intrin, options);
1259 else if (options->lower_to_scalar && intrin->num_components > 1)
1260 return lower_subgroup_op_to_scalar(b, intrin);
1261 else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1262 return lower_boolean_shuffle(b, intrin, options);
1263 else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1264 return lower_subgroup_op_to_32bit(b, intrin);
1265 break;
1266
1267 case nir_intrinsic_quad_broadcast:
1268 case nir_intrinsic_quad_swap_horizontal:
1269 case nir_intrinsic_quad_swap_vertical:
1270 case nir_intrinsic_quad_swap_diagonal:
1271 if (options->lower_quad ||
1272 (options->lower_quad_broadcast_dynamic &&
1273 intrin->intrinsic == nir_intrinsic_quad_broadcast &&
1274 !nir_src_is_const(intrin->src[1])))
1275 return lower_dynamic_quad_broadcast(b, intrin, options);
1276 else if (options->lower_to_scalar && intrin->num_components > 1)
1277 return lower_subgroup_op_to_scalar(b, intrin);
1278 break;
1279
1280 case nir_intrinsic_quad_vote_any:
1281 if (options->lower_quad_vote)
1282 return build_quad_vote_any(b, intrin->src[0].ssa, options);
1283 break;
1284 case nir_intrinsic_quad_vote_all:
1285 if (options->lower_quad_vote) {
1286 nir_def *not_src = nir_inot(b, intrin->src[0].ssa);
1287 nir_def *any_not = build_quad_vote_any(b, not_src, options);
1288 return nir_inot(b, any_not);
1289 }
1290 break;
1291
1292 case nir_intrinsic_reduce: {
1293 nir_def *ret = NULL;
      /* A cluster size greater than the subgroup size is implementation-defined */
1295 if (options->subgroup_size &&
1296 nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
1297 nir_intrinsic_set_cluster_size(intrin, 0);
1298 ret = NIR_LOWER_INSTR_PROGRESS;
1299 }
1300 if (nir_intrinsic_cluster_size(intrin) == 1)
1301 return intrin->src[0].ssa;
1302 if (options->lower_to_scalar && intrin->num_components > 1)
1303 return lower_subgroup_op_to_scalar(b, intrin);
1304 if (intrin->def.bit_size == 1 && options->ballot_components == 1 &&
1305 (options->lower_boolean_reduce || options->lower_reduce))
1306 return lower_boolean_reduce(b, intrin, options);
1307 if (options->lower_reduce)
1308 return lower_scan_reduce(b, intrin, options);
1309 return ret;
1310 }
1311 case nir_intrinsic_inclusive_scan:
1312 case nir_intrinsic_exclusive_scan:
1313 if (options->lower_to_scalar && intrin->num_components > 1)
1314 return lower_subgroup_op_to_scalar(b, intrin);
1315 if (intrin->def.bit_size == 1 && options->ballot_components == 1 &&
1316 (options->lower_boolean_reduce || options->lower_reduce))
1317 return lower_boolean_reduce(b, intrin, options);
1318 if (options->lower_reduce)
1319 return lower_scan_reduce(b, intrin, options);
1320 break;
1321
1322 case nir_intrinsic_rotate:
1323 if (options->lower_rotate_to_shuffle &&
1324 (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1325 return lower_to_shuffle(b, intrin, options);
1326 else if (options->lower_rotate_clustered_to_shuffle &&
1327 nir_intrinsic_cluster_size(intrin) > 0 &&
1328 (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1329 return lower_to_shuffle(b, intrin, options);
1330 else if (options->lower_to_scalar && intrin->num_components > 1)
1331 return lower_subgroup_op_to_scalar(b, intrin);
1332 else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1333 return lower_boolean_shuffle(b, intrin, options);
1334 else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1335 return lower_subgroup_op_to_32bit(b, intrin);
1336 break;
1337 case nir_intrinsic_masked_swizzle_amd:
1338 if (options->lower_to_scalar && intrin->num_components > 1) {
1339 return lower_subgroup_op_to_scalar(b, intrin);
1340 } else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) {
1341 return lower_subgroup_op_to_32bit(b, intrin);
1342 }
1343 break;
1344
1345 default:
1346 break;
1347 }
1348
1349 return NULL;
1350 }
1351
1352 bool
nir_lower_subgroups(nir_shader *shader,
                    const nir_lower_subgroups_options *options)
1355 {
1356 return nir_shader_lower_instructions(shader, lower_subgroups_filter,
1357 lower_subgroups_instr,
1358 (void *)options);
1359 }
1360