1 /*
2 * Copyright © 2017 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "util/u_math.h"
25 #include "nir.h"
26 #include "nir_builder.h"
27
28 /**
29 * \file nir_lower_subgroups.c
30 */
31
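/* Emit a single-component, 32-bit copy of a 64-bit subgroup intrinsic that
 * operates on either the low (component == 0) or high (component == 1) half
 * of the original 64-bit source.
 */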
32 static nir_intrinsic_instr *
33 lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
34 unsigned int component)
35 {
36 nir_def *comp;
37 if (component == 0)
38 comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
39 else
40 comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);
41
42 nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
43 nir_def_init(&intr->instr, &intr->def, 1, 32);
44 intr->const_index[0] = intrin->const_index[0];
45 intr->const_index[1] = intrin->const_index[1];
46 intr->src[0] = nir_src_for_ssa(comp);
47 if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
48 intr->src[1] = nir_src_for_ssa(intrin->src[1].ssa);
49
50 intr->num_components = 1;
51 nir_builder_instr_insert(b, &intr->instr);
52 return intr;
53 }
54
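/* Lower a 64-bit subgroup operation to two 32-bit operations: split the
 * source, run the intrinsic on each half, and pack the results back into a
 * 64-bit value.
 */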
55 static nir_def *
56 lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
57 {
58 assert(intrin->src[0].ssa->bit_size == 64);
59 nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
60 nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
61 return nir_pack_64_2x32_split(b, &intr_x->def, &intr_y->def);
62 }
63
64 static nir_def *
65 ballot_type_to_uint(nir_builder *b, nir_def *value,
66 const nir_lower_subgroups_options *options)
67 {
68 /* Only the new-style SPIR-V subgroup instructions take a ballot result as
69 * an argument, so we only use this on uvec4 types.
70 */
71 assert(value->num_components == 4 && value->bit_size == 32);
72
73 return nir_extract_bits(b, &value, 1, 0, options->ballot_components,
74 options->ballot_bit_size);
75 }
76
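/* Convert a ballot in the lowered uint representation to the component
 * count and bit size expected by the original intrinsic, zero-padding or
 * truncating as needed.
 */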
77 static nir_def *
78 uint_to_ballot_type(nir_builder *b, nir_def *value,
79 unsigned num_components, unsigned bit_size)
80 {
81 assert(util_is_power_of_two_nonzero(num_components));
82 assert(util_is_power_of_two_nonzero(value->num_components));
83
84 unsigned total_bits = bit_size * num_components;
85
86 /* If the source doesn't have enough bits, zero-pad */
87 if (total_bits > value->bit_size * value->num_components)
88 value = nir_pad_vector_imm_int(b, value, 0, total_bits / value->bit_size);
89
90 value = nir_bitcast_vector(b, value, bit_size);
91
92 /* If the source has too many components, truncate. This can happen if,
93 * for instance, we're implementing GL_ARB_shader_ballot or
94 * VK_EXT_shader_subgroup_ballot which have 64-bit ballot values on an
95 * architecture with a native 128-bit uvec4 ballot. This comes up in Zink
96 * for OpenGL on Vulkan. It's the job of the driver calling this lowering
97 * pass to ensure that it has restricted the subgroup size sufficiently that we
98 * have enough ballot bits.
99 */
100 if (value->num_components > num_components)
101 value = nir_trim_vector(b, value, num_components);
102
103 return value;
104 }
105
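/* Lower a vector subgroup operation to one scalar operation per component
 * and recombine the results with a vec.
 */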
106 static nir_def *
107 lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
108 {
109 /* This is safe to call on scalar things but it would be silly */
110 assert(intrin->def.num_components > 1);
111
112 nir_def *value = intrin->src[0].ssa;
113 nir_def *reads[NIR_MAX_VEC_COMPONENTS];
114
115 for (unsigned i = 0; i < intrin->num_components; i++) {
116 nir_intrinsic_instr *chan_intrin =
117 nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
118 nir_def_init(&chan_intrin->instr, &chan_intrin->def, 1,
119 intrin->def.bit_size);
120 chan_intrin->num_components = 1;
121
122 /* value */
123 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
124 /* invocation */
125 if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
126 assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
127 chan_intrin->src[1] = nir_src_for_ssa(intrin->src[1].ssa);
128 }
129
130 chan_intrin->const_index[0] = intrin->const_index[0];
131 chan_intrin->const_index[1] = intrin->const_index[1];
132
133 nir_builder_instr_insert(b, &chan_intrin->instr);
134 reads[i] = &chan_intrin->def;
135 }
136
137 return nir_vec(b, reads, intrin->num_components);
138 }
139
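/* Scalarize vote_feq/vote_ieq: vote on each component separately and AND
 * the per-component results together.
 */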
140 static nir_def *
141 lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
142 {
143 nir_def *value = intrin->src[0].ssa;
144
145 nir_def *result = NULL;
146 for (unsigned i = 0; i < intrin->num_components; i++) {
147 nir_def* chan = nir_channel(b, value, i);
148
149 if (intrin->intrinsic == nir_intrinsic_vote_feq) {
150 chan = nir_vote_feq(b, intrin->def.bit_size, chan);
151 } else {
152 chan = nir_vote_ieq(b, intrin->def.bit_size, chan);
153 }
154
155 if (result) {
156 result = nir_iand(b, result, chan);
157 } else {
158 result = chan;
159 }
160 }
161
162 return result;
163 }
164
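/* Lower vote_feq/vote_ieq by comparing each component against the value in
 * the first active invocation and then taking vote_all of the combined
 * comparison result.
 */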
165 static nir_def *
166 lower_vote_eq(nir_builder *b, nir_intrinsic_instr *intrin)
167 {
168 nir_def *value = intrin->src[0].ssa;
169
170 /* We have to implicitly lower to scalar */
171 nir_def *all_eq = NULL;
172 for (unsigned i = 0; i < intrin->num_components; i++) {
173 nir_def *rfi = nir_read_first_invocation(b, nir_channel(b, value, i));
174
175 nir_def *is_eq;
176 if (intrin->intrinsic == nir_intrinsic_vote_feq) {
177 is_eq = nir_feq(b, rfi, nir_channel(b, value, i));
178 } else {
179 is_eq = nir_ieq(b, rfi, nir_channel(b, value, i));
180 }
181
182 if (all_eq == NULL) {
183 all_eq = is_eq;
184 } else {
185 all_eq = nir_iand(b, all_eq, is_eq);
186 }
187 }
188
189 return nir_vote_all(b, 1, all_eq);
190 }
191
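/* Lower a constant shuffle_xor to AMD's masked swizzle, which can only
 * encode XOR masks smaller than 32. Returns NULL when the mask is out of
 * range so the caller can fall back to a regular shuffle.
 */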
192 static nir_def *
193 lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin)
194 {
195 unsigned mask = nir_src_as_uint(intrin->src[1]);
196
197 if (mask >= 32)
198 return NULL;
199
200 return nir_masked_swizzle_amd(b, intrin->src[0].ssa,
201 .swizzle_mask = (mask << 10) | 0x1f,
202 .fetch_inactive = true);
203 }
204
205 /* Lowers "specialized" shuffles to a generic nir_intrinsic_shuffle. */
206
207 static nir_def *
208 lower_to_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
209 const nir_lower_subgroups_options *options)
210 {
211 if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
212 options->lower_shuffle_to_swizzle_amd &&
213 nir_src_is_const(intrin->src[1])) {
214
215 nir_def *result = lower_shuffle_to_swizzle(b, intrin);
216 if (result)
217 return result;
218 }
219
220 nir_def *index = nir_load_subgroup_invocation(b);
221 switch (intrin->intrinsic) {
222 case nir_intrinsic_shuffle_xor:
223 index = nir_ixor(b, index, intrin->src[1].ssa);
224 break;
225 case nir_intrinsic_shuffle_up:
226 index = nir_isub(b, index, intrin->src[1].ssa);
227 break;
228 case nir_intrinsic_shuffle_down:
229 index = nir_iadd(b, index, intrin->src[1].ssa);
230 break;
231 case nir_intrinsic_quad_broadcast:
232 index = nir_ior(b, nir_iand_imm(b, index, ~0x3),
233 intrin->src[1].ssa);
234 break;
235 case nir_intrinsic_quad_swap_horizontal:
236 /* For Quad operations, subgroups are divided into quads where
237 * (invocation % 4) is the index to a square arranged as follows:
238 *
239 * +---+---+
240 * | 0 | 1 |
241 * +---+---+
242 * | 2 | 3 |
243 * +---+---+
244 */
245 index = nir_ixor(b, index, nir_imm_int(b, 0x1));
246 break;
247 case nir_intrinsic_quad_swap_vertical:
248 index = nir_ixor(b, index, nir_imm_int(b, 0x2));
249 break;
250 case nir_intrinsic_quad_swap_diagonal:
251 index = nir_ixor(b, index, nir_imm_int(b, 0x3));
252 break;
253 case nir_intrinsic_rotate: {
254 nir_def *delta = intrin->src[1].ssa;
255 nir_def *local_id = nir_load_subgroup_invocation(b);
256 const unsigned cluster_size = nir_intrinsic_cluster_size(intrin);
257
258 nir_def *rotation_group_mask =
259 cluster_size > 0 ? nir_imm_int(b, (int)(cluster_size - 1)) : nir_iadd_imm(b, nir_load_subgroup_size(b), -1);
260
261 index = nir_iand(b, nir_iadd(b, local_id, delta),
262 rotation_group_mask);
263 if (cluster_size > 0) {
264 index = nir_iadd(b, index,
265 nir_iand(b, local_id, nir_inot(b, rotation_group_mask)));
266 }
267 break;
268 }
269 default:
270 unreachable("Invalid intrinsic");
271 }
272
273 return nir_shuffle(b, intrin->src[0].ssa, index);
274 }
275
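/* Build a GLSL type (bool or uintN_t, possibly a vector) matching the bit
 * size and component count of an SSA def, for use as a local variable type.
 */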
276 static const struct glsl_type *
277 glsl_type_for_ssa(nir_def *def)
278 {
279 const struct glsl_type *comp_type = def->bit_size == 1 ? glsl_bool_type() : glsl_uintN_t_type(def->bit_size);
280 return glsl_replace_vector_type(comp_type, def->num_components);
281 }
282
283 /* Lower nir_intrinsic_shuffle to a waterfall loop + nir_read_invocation.
284 */
285 static nir_def *
286 lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin)
287 {
288 nir_def *val = intrin->src[0].ssa;
289 nir_def *id = intrin->src[1].ssa;
290
291 /* The loop is something like:
292 *
293 * while (true) {
294 * first_id = readFirstInvocation(gl_SubgroupInvocationID);
295 * first_val = readFirstInvocation(val);
296 * first_result = readInvocation(val, readFirstInvocation(id));
297 * if (id == first_id)
298 * result = first_val;
299 * if (elect()) {
300 * if (id > gl_SubgroupInvocationID) {
301 * result = first_result;
302 * }
303 * break;
304 * }
305 * }
306 *
307 * The idea is to guarantee, on each iteration of the loop, that anything
308 * reading from first_id gets the correct value, so that we can then kill
309 * it off by breaking out of the loop. Before doing that we also have to
310 * ensure that the first_id invocation gets the correct value. It will only
311 * lack the correct value if the invocation it's reading from hasn't been
312 * killed off yet, that is, if that invocation is later than its own ID.
313 * Invocations where id <= gl_SubgroupInvocationID will be assigned their
314 * result in the first if, and invocations where id >
315 * gl_SubgroupInvocationID will be assigned their result in the second if.
316 *
317 * We do this more complicated loop rather than looping over all id's
318 * explicitly because at this point we don't know the "actual" subgroup
319 * size and at the moment there's no way to get at it, which means we may
320 * loop over always-inactive invocations.
321 */
322
323 nir_def *subgroup_id = nir_load_subgroup_invocation(b);
324
325 nir_variable *result =
326 nir_local_variable_create(b->impl, glsl_type_for_ssa(val), "result");
327
328 nir_loop *loop = nir_push_loop(b);
329 {
330 nir_def *first_id = nir_read_first_invocation(b, subgroup_id);
331 nir_def *first_val = nir_read_first_invocation(b, val);
332 nir_def *first_result =
333 nir_read_invocation(b, val, nir_read_first_invocation(b, id));
334
335 nir_if *nif = nir_push_if(b, nir_ieq(b, id, first_id));
336 {
337 nir_store_var(b, result, first_val, BITFIELD_MASK(val->num_components));
338 }
339 nir_pop_if(b, nif);
340
341 nir_if *nif2 = nir_push_if(b, nir_elect(b, 1));
342 {
343 nir_if *nif3 = nir_push_if(b, nir_ult(b, subgroup_id, id));
344 {
345 nir_store_var(b, result, first_result, BITFIELD_MASK(val->num_components));
346 }
347 nir_pop_if(b, nif3);
348
349 nir_jump(b, nir_jump_break);
350 }
351 nir_pop_if(b, nif2);
352 }
353 nir_pop_loop(b, loop);
354
355 return nir_load_var(b, result);
356 }
357
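/* Lower shuffles of 1-bit values by taking a ballot of the value and
 * operating on the ballot: constant or uniform deltas shift/rotate the
 * ballot and then use inverse_ballot, while the remaining cases compute an
 * invocation index and test that bit of the ballot.
 */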
358 static nir_def *
359 lower_boolean_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
360 const nir_lower_subgroups_options *options)
361 {
362 assert(options->ballot_components == 1 && options->subgroup_size);
363 nir_def *ballot = nir_ballot_relaxed(b, 1, options->ballot_bit_size, intrin->src[0].ssa);
364
365 nir_def *index = NULL;
366
367 /* If the shuffle amount isn't constant, it might be divergent but
368 * inverse_ballot requires a uniform source, so take a different path.
369 * rotate allows us to assume the delta is uniform unlike shuffle_up/down.
370 */
371 switch (intrin->intrinsic) {
372 case nir_intrinsic_shuffle_up:
373 if (nir_src_is_const(intrin->src[1]))
374 ballot = nir_ishl(b, ballot, intrin->src[1].ssa);
375 else
376 index = nir_isub(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
377 break;
378 case nir_intrinsic_shuffle_down:
379 if (nir_src_is_const(intrin->src[1]))
380 ballot = nir_ushr(b, ballot, intrin->src[1].ssa);
381 else
382 index = nir_iadd(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
383 break;
384 case nir_intrinsic_shuffle_xor:
385 index = nir_ixor(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
386 break;
387 case nir_intrinsic_rotate: {
388 nir_def *delta = nir_as_uniform(b, intrin->src[1].ssa);
389 uint32_t cluster_size = nir_intrinsic_cluster_size(intrin);
390 cluster_size = cluster_size ? cluster_size : options->subgroup_size;
391 cluster_size = MIN2(cluster_size, options->subgroup_size);
392 if (cluster_size == 1) {
393 return intrin->src[0].ssa;
394 } else if (cluster_size == 2) {
395 delta = nir_iand_imm(b, delta, cluster_size - 1);
396 nir_def *lo = nir_iand_imm(b, nir_ushr_imm(b, ballot, 1), 0x5555555555555555ull);
397 nir_def *hi = nir_iand_imm(b, nir_ishl_imm(b, ballot, 1), 0xaaaaaaaaaaaaaaaaull);
398 ballot = nir_bcsel(b, nir_ine_imm(b, delta, 0), nir_ior(b, hi, lo), ballot);
399 } else if (cluster_size == ballot->bit_size) {
400 ballot = nir_uror(b, ballot, delta);
401 } else if (cluster_size == 32) {
402 nir_def *unpacked = nir_unpack_64_2x32(b, ballot);
403 unpacked = nir_uror(b, unpacked, delta);
404 ballot = nir_pack_64_2x32(b, unpacked);
405 } else {
406 delta = nir_iand_imm(b, delta, cluster_size - 1);
407 nir_def *delta_rev = nir_isub_imm(b, cluster_size, delta);
408 nir_def *mask = nir_mask(b, delta_rev, ballot->bit_size);
409 for (uint32_t i = cluster_size; i < ballot->bit_size; i *= 2) {
410 mask = nir_ior(b, nir_ishl_imm(b, mask, i), mask);
411 }
412 nir_def *lo = nir_iand(b, nir_ushr(b, ballot, delta), mask);
413 nir_def *hi = nir_iand(b, nir_ishl(b, ballot, delta_rev), nir_inot(b, mask));
414 ballot = nir_ior(b, lo, hi);
415 }
416 break;
417 }
418 case nir_intrinsic_shuffle:
419 index = intrin->src[1].ssa;
420 break;
421 case nir_intrinsic_read_invocation:
422 index = nir_as_uniform(b, intrin->src[1].ssa);
423 break;
424 default:
425 unreachable("not a boolean shuffle");
426 }
427
428 if (index) {
429 nir_def *mask = nir_ishl(b, nir_imm_intN_t(b, 1, ballot->bit_size), index);
430 return nir_ine_imm(b, nir_iand(b, ballot, mask), 0);
431 } else {
432 return nir_inverse_ballot(b, 1, ballot);
433 }
434 }
435
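/* Count the set bits across all components of a (possibly multi-component)
 * ballot value.
 */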
436 static nir_def *
437 vec_bit_count(nir_builder *b, nir_def *value)
438 {
439 nir_def *vec_result = nir_bit_count(b, value);
440 nir_def *result = nir_channel(b, vec_result, 0);
441 for (unsigned i = 1; i < value->num_components; i++)
442 result = nir_iadd(b, result, nir_channel(b, vec_result, i));
443 return result;
444 }
445
446 /* produce a bitmask of 111...000...111... alternating between "size"
447 * 1's and "size" 0's (the LSB is 1).
448 */
449 static uint64_t
450 reduce_mask(unsigned size, unsigned ballot_bit_size)
451 {
452 uint64_t mask = 0;
453 for (unsigned i = 0; i < ballot_bit_size; i += 2 * size) {
454 mask |= ((1ull << size) - 1) << i;
455 }
456
457 return mask;
458 }
459
460 /* operate on a uniform per-thread bitmask provided by ballot() to perform the
461 * desired Boolean reduction. Assumes that the identity of the operation is
462 * false (so, no iand).
463 */
464 static nir_def *
465 lower_boolean_reduce_internal(nir_builder *b, nir_def *src,
466 unsigned cluster_size, nir_op op,
467 const nir_lower_subgroups_options *options)
468 {
469 for (unsigned size = 1; size < cluster_size; size *= 2) {
470 nir_def *shifted = nir_ushr_imm(b, src, size);
471 src = nir_build_alu2(b, op, shifted, src);
472 uint64_t mask = reduce_mask(size, options->ballot_bit_size);
473 src = nir_iand_imm(b, src, mask);
474 shifted = nir_ishl_imm(b, src, size);
475 src = nir_ior(b, src, shifted);
476 }
477
478 return src;
479 }
480
481 /* operate on a uniform per-thread bitmask provided by ballot() to perform the
482 * desired Boolean inclusive scan. Assumes that the identity of the operation is
483 * false (so, no iand).
484 */
485 static nir_def *
486 lower_boolean_scan_internal(nir_builder *b, nir_def *src,
487 nir_op op,
488 const nir_lower_subgroups_options *options)
489 {
490 if (op == nir_op_ior) {
491 /* We want to return a bitmask with all 1's starting at the first 1 in
492 * src. -src is equivalent to ~src + 1. While src | ~src returns all
493 * 1's, src | (~src + 1) returns all 1's except for the bits changed by
494 * the increment. Any 1's before the least significant 0 of ~src are
495 * turned into 0 (zeroing those bits after or'ing) and the least
496 * significant 0 of ~src is turned into 1 (not doing anything). So the
497 * final output is what we want.
498 */
499 return nir_ior(b, src, nir_ineg(b, src));
500 } else {
501 assert(op == nir_op_ixor);
502 for (unsigned shift = 1; shift < options->ballot_bit_size; shift *= 2) {
503 src = nir_ixor(b, src, nir_ishl_imm(b, src, shift));
504 }
505 return src;
506 }
507 }
508
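/* Lower reductions and scans of 1-bit values to bit operations on a ballot
 * mask. Full-subgroup iand/ior reductions become vote_all/vote_any, a
 * cluster size of 4 can use the quad votes, and iand is handled by applying
 * De Morgan's law so the helpers only need an identity of 0.
 */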
509 static nir_def *
510 lower_boolean_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
511 const nir_lower_subgroups_options *options)
512 {
513 assert(intrin->num_components == 1);
514 assert(options->ballot_components == 1);
515
516 unsigned cluster_size =
517 intrin->intrinsic == nir_intrinsic_reduce ? nir_intrinsic_cluster_size(intrin) : 0;
518 nir_op op = nir_intrinsic_reduction_op(intrin);
519
520 /* For certain cluster sizes, reductions of iand and ior can be implemented
521 * more efficiently.
522 */
523 if (intrin->intrinsic == nir_intrinsic_reduce) {
524 if (cluster_size == 0) {
525 if (op == nir_op_iand)
526 return nir_vote_all(b, 1, intrin->src[0].ssa);
527 else if (op == nir_op_ior)
528 return nir_vote_any(b, 1, intrin->src[0].ssa);
529 else if (op == nir_op_ixor)
530 return nir_i2b(b, nir_iand_imm(b, vec_bit_count(b, nir_ballot(b,
531 options->ballot_components,
532 options->ballot_bit_size,
533 intrin->src[0].ssa)),
534 1));
535 else
536 unreachable("bad boolean reduction op");
537 }
538
539 if (cluster_size == 4) {
540 if (op == nir_op_iand)
541 return nir_quad_vote_all(b, 1, intrin->src[0].ssa);
542 else if (op == nir_op_ior)
543 return nir_quad_vote_any(b, 1, intrin->src[0].ssa);
544 }
545 }
546
547 nir_def *src = intrin->src[0].ssa;
548
549 /* Apply De Morgan's law to implement "and" reductions, since all the
550 * lower_boolean_*_internal() functions assume an identity of 0 to make the
551 * generated code shorter.
552 */
553 nir_op new_op = (op == nir_op_iand) ? nir_op_ior : op;
554 if (op == nir_op_iand) {
555 src = nir_inot(b, src);
556 }
557
558 nir_def *val = nir_ballot(b, options->ballot_components, options->ballot_bit_size, src);
559
560 switch (intrin->intrinsic) {
561 case nir_intrinsic_reduce:
562 val = lower_boolean_reduce_internal(b, val, cluster_size, new_op, options);
563 break;
564 case nir_intrinsic_inclusive_scan:
565 val = lower_boolean_scan_internal(b, val, new_op, options);
566 break;
567 case nir_intrinsic_exclusive_scan:
568 val = lower_boolean_scan_internal(b, val, new_op, options);
569 val = nir_ishl_imm(b, val, 1);
570 break;
571 default:
572 unreachable("bad intrinsic");
573 }
574
575 if (op == nir_op_iand) {
576 val = nir_inot(b, val);
577 }
578
579 return nir_inverse_ballot(b, 1, val);
580 }
581
582 static bool
583 lower_subgroups_filter(const nir_instr *instr, const void *_options)
584 {
585 return instr->type == nir_instr_type_intrinsic;
586 }
587
588 /* Return a ballot-mask-sized value which represents "val" sign-extended and
589 * then shifted left by "shift". Only particular values for "val" are
590 * supported, see below.
591 */
592 static nir_def *
593 build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_def *shift,
594 const nir_lower_subgroups_options *options)
595 {
596 /* This only works if all the high bits are the same as bit 1. */
597 assert((val >> 2) == (val & 0x2 ? -1 : 0));
598
599 /* First compute the result assuming one ballot component. */
600 nir_def *result =
601 nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);
602
603 if (options->ballot_components == 1)
604 return result;
605
606 /* Fix up the result when there is > 1 component. The idea is that nir_ishl
607 * masks out the high bits of the shift value already, so in case there's
608 * more than one component the component which 1 would be shifted into
609 * already has the right value and all we have to do is fix up the other
610 * components. Components below it should always be 0, and components above
611 * it must be either 0 or ~0 because of the assert above. For example, if
612 * the target ballot size is 2 x uint32, and we're shifting 1 by 33, then
613 * we'll feed 33 into ishl, which will mask it off to get 1, so we'll
614 * compute a single-component result of 2, which is correct for the second
615 * component, but the first component needs to be 0, which we get by
616 * comparing the high bits of the shift with 0 and selecting the original
617 * answer or 0 for the first component (and something similar with the
618 * second component). This idea is generalized here for any component count.
619 */
620 nir_const_value min_shift[4];
621 for (unsigned i = 0; i < options->ballot_components; i++)
622 min_shift[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
623 nir_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);
624
625 nir_const_value max_shift[4];
626 for (unsigned i = 0; i < options->ballot_components; i++)
627 max_shift[i] = nir_const_value_for_int((i + 1) * options->ballot_bit_size, 32);
628 nir_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);
629
630 return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
631 nir_bcsel(b, nir_ult(b, shift, min_shift_val),
632 nir_imm_intN_t(b, val >> 63, result->bit_size),
633 result),
634 nir_imm_intN_t(b, 0, result->bit_size));
635 }
636
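/* gl_SubgroupEqMask: only the bit for the current invocation is set. */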
637 static nir_def *
638 build_subgroup_eq_mask(nir_builder *b,
639 const nir_lower_subgroups_options *options)
640 {
641 nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
642
643 return build_ballot_imm_ishl(b, 1, subgroup_idx, options);
644 }
645
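/* gl_SubgroupGeMask before masking by the subgroup size: all bits at or
 * above the current invocation are set.
 */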
646 static nir_def *
647 build_subgroup_ge_mask(nir_builder *b,
648 const nir_lower_subgroups_options *options)
649 {
650 nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
651
652 return build_ballot_imm_ishl(b, ~0ull, subgroup_idx, options);
653 }
654
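/* gl_SubgroupGtMask before masking by the subgroup size: all bits strictly
 * above the current invocation are set.
 */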
655 static nir_def *
656 build_subgroup_gt_mask(nir_builder *b,
657 const nir_lower_subgroups_options *options)
658 {
659 nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
660
661 return build_ballot_imm_ishl(b, ~1ull, subgroup_idx, options);
662 }
663
664 /* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
665 * 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or
666 * above the subgroup size for the masks, but gt_mask and ge_mask make them 1
667 * so we have to "and" with this mask.
668 */
669 static nir_def *
670 build_subgroup_mask(nir_builder *b,
671 const nir_lower_subgroups_options *options)
672 {
673 nir_def *subgroup_size = nir_load_subgroup_size(b);
674
675 /* First compute the result assuming one ballot component. */
676 nir_def *result =
677 nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
678 nir_isub_imm(b, options->ballot_bit_size,
679 subgroup_size));
680
681 /* Since the subgroup size and ballot bitsize are both powers of two, there
682 * are two possible cases to consider:
683 *
684 * (1) The subgroup size is less than the ballot bitsize. We need to return
685 * "result" in the first component and 0 in every other component.
686 * (2) The subgroup size is a multiple of the ballot bitsize. We need to
687 * return ~0 for components whose index is less than the subgroup size
688 * divided by the ballot bitsize and 0 for the rest. For example, with a
689 * target ballot type of 4 x uint32 and subgroup_size = 64 we'd need to
690 * return { ~0, ~0, 0, 0 }.
691 *
692 * In case (2) it turns out that "result" will be ~0, because
693 * "ballot_bit_size - subgroup_size" is also a multiple of
694 * "ballot_bit_size" and since nir_ushr masks the shift value it will
695 * be shifted by 0. This means that the first component can just be "result"
696 * in all cases. The other components will also get the correct value in
697 * case (1) if we just use the rule in case (2), so we'll get the correct
698 * result if we just follow (2) and then replace the first component with
699 * "result".
700 */
701 nir_const_value min_idx[4];
702 for (unsigned i = 0; i < options->ballot_components; i++)
703 min_idx[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
704 nir_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);
705
706 nir_def *result_extended =
707 nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);
708
709 return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
710 result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
711 }
712
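/* Find the index of the least significant set bit across all components of
 * a ballot value, or -1 if no bit is set.
 */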
713 static nir_def *
714 vec_find_lsb(nir_builder *b, nir_def *value)
715 {
716 nir_def *vec_result = nir_find_lsb(b, value);
717 nir_def *result = nir_imm_int(b, -1);
718 for (int i = value->num_components - 1; i >= 0; i--) {
719 nir_def *channel = nir_channel(b, vec_result, i);
720 /* result = channel >= 0 ? (i * bitsize + channel) : result */
721 result = nir_bcsel(b, nir_ige_imm(b, channel, 0),
722 nir_iadd_imm(b, channel, i * value->bit_size),
723 result);
724 }
725 return result;
726 }
727
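/* Find the index of the most significant set bit across all components of
 * a ballot value, or -1 if no bit is set.
 */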
728 static nir_def *
729 vec_find_msb(nir_builder *b, nir_def *value)
730 {
731 nir_def *vec_result = nir_ufind_msb(b, value);
732 nir_def *result = nir_imm_int(b, -1);
733 for (unsigned i = 0; i < value->num_components; i++) {
734 nir_def *channel = nir_channel(b, vec_result, i);
735 /* result = channel >= 0 ? (i * bitsize + channel) : result */
736 result = nir_bcsel(b, nir_ige_imm(b, channel, 0),
737 nir_iadd_imm(b, channel, i * value->bit_size),
738 result);
739 }
740 return result;
741 }
742
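/* Either punt a quad op to lower_to_shuffle() or, when the driver wants
 * dynamic quad_broadcast lowered to constant indices, emit one
 * quad_broadcast per index and select between them with bcsel.
 */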
743 static nir_def *
744 lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
745 const nir_lower_subgroups_options *options)
746 {
747 if (!options->lower_quad_broadcast_dynamic_to_const)
748 return lower_to_shuffle(b, intrin, options);
749
750 nir_def *dst = NULL;
751
752 for (unsigned i = 0; i < 4; ++i) {
753 nir_def *qbcst = nir_quad_broadcast(b, intrin->src[0].ssa,
754 nir_imm_int(b, i));
755
756 if (i)
757 dst = nir_bcsel(b, nir_ieq_imm(b, intrin->src[1].ssa, i),
758 qbcst, dst);
759 else
760 dst = qbcst;
761 }
762
763 return dst;
764 }
765
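/* first_invocation is the position of the lowest set bit in ballot(true). */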
766 static nir_def *
767 lower_first_invocation_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
768 const nir_lower_subgroups_options *options)
769 {
770 return nir_ballot_find_lsb(b, 32, nir_ballot(b, 4, 32, nir_imm_true(b)));
771 }
772
773 static nir_def *
774 lower_read_first_invocation(nir_builder *b, nir_intrinsic_instr *intrin)
775 {
776 return nir_read_invocation(b, intrin->src[0].ssa, nir_first_invocation(b));
777 }
778
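/* ir3 path: read_invocation(val, id) becomes
 * read_invocation_cond_ir3(val, id == gl_SubgroupInvocationID).
 */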
779 static nir_def *
780 lower_read_invocation_to_cond(nir_builder *b, nir_intrinsic_instr *intrin)
781 {
782 return nir_read_invocation_cond_ir3(b, intrin->def.bit_size,
783 intrin->src[0].ssa,
784 nir_ieq(b, intrin->src[1].ssa,
785 nir_load_subgroup_invocation(b)));
786 }
787
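/* Lowering callback: returns a replacement SSA def for the intrinsic,
 * NIR_LOWER_INSTR_PROGRESS if the instruction was modified in place, or
 * NULL if no lowering applies.
 */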
788 static nir_def *
789 lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
790 {
791 const nir_lower_subgroups_options *options = _options;
792
793 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
794 switch (intrin->intrinsic) {
795 case nir_intrinsic_vote_any:
796 case nir_intrinsic_vote_all:
797 if (options->lower_vote_trivial)
798 return intrin->src[0].ssa;
799 break;
800
801 case nir_intrinsic_vote_feq:
802 case nir_intrinsic_vote_ieq:
803 if (options->lower_vote_trivial)
804 return nir_imm_true(b);
805
806 if (nir_src_bit_size(intrin->src[0]) == 1) {
807 if (options->lower_vote_bool_eq)
808 return lower_vote_eq(b, intrin);
809 } else {
810 if (options->lower_vote_eq)
811 return lower_vote_eq(b, intrin);
812 }
813
814 if (options->lower_to_scalar && intrin->num_components > 1)
815 return lower_vote_eq_to_scalar(b, intrin);
816 break;
817
818 case nir_intrinsic_load_subgroup_size:
819 if (options->subgroup_size)
820 return nir_imm_int(b, options->subgroup_size);
821 break;
822
823 case nir_intrinsic_first_invocation:
824 if (options->subgroup_size == 1)
825 return nir_imm_int(b, 0);
826
827 if (options->lower_first_invocation_to_ballot)
828 return lower_first_invocation_to_ballot(b, intrin, options);
829
830 break;
831
832 case nir_intrinsic_read_invocation:
833 if (options->lower_to_scalar && intrin->num_components > 1)
834 return lower_subgroup_op_to_scalar(b, intrin);
835
836 if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
837 return lower_boolean_shuffle(b, intrin, options);
838
839 if (options->lower_read_invocation_to_cond)
840 return lower_read_invocation_to_cond(b, intrin);
841
842 break;
843
844 case nir_intrinsic_read_first_invocation:
845 if (options->lower_to_scalar && intrin->num_components > 1)
846 return lower_subgroup_op_to_scalar(b, intrin);
847
848 if (options->lower_read_first_invocation)
849 return lower_read_first_invocation(b, intrin);
850 break;
851
852 case nir_intrinsic_load_subgroup_eq_mask:
853 case nir_intrinsic_load_subgroup_ge_mask:
854 case nir_intrinsic_load_subgroup_gt_mask:
855 case nir_intrinsic_load_subgroup_le_mask:
856 case nir_intrinsic_load_subgroup_lt_mask: {
857 if (!options->lower_subgroup_masks)
858 return NULL;
859
860 nir_def *val;
861 switch (intrin->intrinsic) {
862 case nir_intrinsic_load_subgroup_eq_mask:
863 val = build_subgroup_eq_mask(b, options);
864 break;
865 case nir_intrinsic_load_subgroup_ge_mask:
866 val = nir_iand(b, build_subgroup_ge_mask(b, options),
867 build_subgroup_mask(b, options));
868 break;
869 case nir_intrinsic_load_subgroup_gt_mask:
870 val = nir_iand(b, build_subgroup_gt_mask(b, options),
871 build_subgroup_mask(b, options));
872 break;
873 case nir_intrinsic_load_subgroup_le_mask:
874 val = nir_inot(b, build_subgroup_gt_mask(b, options));
875 break;
876 case nir_intrinsic_load_subgroup_lt_mask:
877 val = nir_inot(b, build_subgroup_ge_mask(b, options));
878 break;
879 default:
880 unreachable("you seriously can't tell this is unreachable?");
881 }
882
883 return uint_to_ballot_type(b, val,
884 intrin->def.num_components,
885 intrin->def.bit_size);
886 }
887
888 case nir_intrinsic_ballot: {
889 if (intrin->def.num_components == options->ballot_components &&
890 intrin->def.bit_size == options->ballot_bit_size)
891 return NULL;
892
893 nir_def *ballot =
894 nir_ballot(b, options->ballot_components, options->ballot_bit_size,
895 intrin->src[0].ssa);
896
897 return uint_to_ballot_type(b, ballot,
898 intrin->def.num_components,
899 intrin->def.bit_size);
900 }
901
902 case nir_intrinsic_inverse_ballot:
903 if (options->lower_inverse_ballot) {
904 return nir_ballot_bitfield_extract(b, 1, intrin->src[0].ssa,
905 nir_load_subgroup_invocation(b));
906 } else if (intrin->src[0].ssa->num_components != options->ballot_components ||
907 intrin->src[0].ssa->bit_size != options->ballot_bit_size) {
908 return nir_inverse_ballot(b, 1, ballot_type_to_uint(b, intrin->src[0].ssa, options));
909 }
910 break;
911
912 case nir_intrinsic_ballot_bitfield_extract:
913 case nir_intrinsic_ballot_bit_count_reduce:
914 case nir_intrinsic_ballot_find_lsb:
915 case nir_intrinsic_ballot_find_msb: {
916 nir_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
917 options);
918
919 if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
920 intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
921 /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
922 *
923 * "Find the most significant bit set to 1 in Value, considering
924 * only the bits in Value required to represent all bits of the
925 * group’s invocations. If none of the considered bits is set to
926 * 1, the result is undefined."
927 *
928 * It has similar text for the other three. This means that, in case
929 * the subgroup size is less than 32, we have to mask off the unused
930 * bits. If the subgroup size is fixed and greater than or equal to
931 * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
932 * the iand.
933 *
934 * We only have to worry about this for BitCount and FindMSB because
935 * FindLSB counts from the bottom and BitfieldExtract selects
936 * individual bits. In either case, if run outside the range of
937 * valid bits, we hit the undefined results case and we can return
938 * anything we want.
939 */
940 int_val = nir_iand(b, int_val, build_subgroup_mask(b, options));
941 }
942
943 switch (intrin->intrinsic) {
944 case nir_intrinsic_ballot_bitfield_extract: {
945 nir_def *idx = intrin->src[1].ssa;
946 if (int_val->num_components > 1) {
947 /* idx will be truncated by nir_ushr, so we just need to select
948 * the right component using the bits of idx that are truncated in
949 * the shift.
950 */
951 int_val =
952 nir_vector_extract(b, int_val,
953 nir_udiv_imm(b, idx, int_val->bit_size));
954 }
955
956 return nir_test_mask(b, nir_ushr(b, int_val, idx), 1);
957 }
958 case nir_intrinsic_ballot_bit_count_reduce:
959 return vec_bit_count(b, int_val);
960 case nir_intrinsic_ballot_find_lsb:
961 return vec_find_lsb(b, int_val);
962 case nir_intrinsic_ballot_find_msb:
963 return vec_find_msb(b, int_val);
964 default:
965 unreachable("you seriously can't tell this is unreachable?");
966 }
967 }
968
969 case nir_intrinsic_ballot_bit_count_exclusive:
970 case nir_intrinsic_ballot_bit_count_inclusive: {
971 nir_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
972 options);
973 if (options->lower_ballot_bit_count_to_mbcnt_amd) {
974 nir_def *acc;
975 if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_exclusive) {
976 acc = nir_imm_int(b, 0);
977 } else {
978 acc = nir_iand_imm(b, nir_u2u32(b, int_val), 0x1);
979 int_val = nir_ushr_imm(b, int_val, 1);
980 }
981 return nir_mbcnt_amd(b, int_val, acc);
982 }
983
984 nir_def *mask;
985 if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
986 mask = nir_inot(b, build_subgroup_gt_mask(b, options));
987 } else {
988 mask = nir_inot(b, build_subgroup_ge_mask(b, options));
989 }
990
991 return vec_bit_count(b, nir_iand(b, int_val, mask));
992 }
993
994 case nir_intrinsic_elect: {
995 if (!options->lower_elect)
996 return NULL;
997
998 return nir_ieq(b, nir_load_subgroup_invocation(b), nir_first_invocation(b));
999 }
1000
1001 case nir_intrinsic_shuffle:
1002 if (options->lower_shuffle &&
1003 (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1004 return lower_shuffle(b, intrin);
1005 else if (options->lower_to_scalar && intrin->num_components > 1)
1006 return lower_subgroup_op_to_scalar(b, intrin);
1007 else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1008 return lower_boolean_shuffle(b, intrin, options);
1009 else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1010 return lower_subgroup_op_to_32bit(b, intrin);
1011 break;
1012 case nir_intrinsic_shuffle_xor:
1013 case nir_intrinsic_shuffle_up:
1014 case nir_intrinsic_shuffle_down:
1015 if (options->lower_relative_shuffle &&
1016 (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1017 return lower_to_shuffle(b, intrin, options);
1018 else if (options->lower_to_scalar && intrin->num_components > 1)
1019 return lower_subgroup_op_to_scalar(b, intrin);
1020 else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1021 return lower_boolean_shuffle(b, intrin, options);
1022 else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1023 return lower_subgroup_op_to_32bit(b, intrin);
1024 break;
1025
1026 case nir_intrinsic_quad_broadcast:
1027 case nir_intrinsic_quad_swap_horizontal:
1028 case nir_intrinsic_quad_swap_vertical:
1029 case nir_intrinsic_quad_swap_diagonal:
1030 if (options->lower_quad ||
1031 (options->lower_quad_broadcast_dynamic &&
1032 intrin->intrinsic == nir_intrinsic_quad_broadcast &&
1033 !nir_src_is_const(intrin->src[1])))
1034 return lower_dynamic_quad_broadcast(b, intrin, options);
1035 else if (options->lower_to_scalar && intrin->num_components > 1)
1036 return lower_subgroup_op_to_scalar(b, intrin);
1037 break;
1038
1039 case nir_intrinsic_reduce: {
1040 nir_def *ret = NULL;
1041 /* A cluster size greater than the subgroup size is implementation-defined */
1042 if (options->subgroup_size &&
1043 nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
1044 nir_intrinsic_set_cluster_size(intrin, 0);
1045 ret = NIR_LOWER_INSTR_PROGRESS;
1046 }
1047 if (nir_intrinsic_cluster_size(intrin) == 1)
1048 return intrin->src[0].ssa;
1049 if (options->lower_to_scalar && intrin->num_components > 1)
1050 return lower_subgroup_op_to_scalar(b, intrin);
1051 if (options->lower_boolean_reduce && intrin->def.bit_size == 1)
1052 return lower_boolean_reduce(b, intrin, options);
1053 return ret;
1054 }
1055 case nir_intrinsic_inclusive_scan:
1056 case nir_intrinsic_exclusive_scan:
1057 if (options->lower_to_scalar && intrin->num_components > 1)
1058 return lower_subgroup_op_to_scalar(b, intrin);
1059 if (options->lower_boolean_reduce && intrin->def.bit_size == 1)
1060 return lower_boolean_reduce(b, intrin, options);
1061 break;
1062
1063 case nir_intrinsic_rotate:
1064 if (nir_intrinsic_execution_scope(intrin) == SCOPE_SUBGROUP) {
1065 if (options->lower_rotate_to_shuffle &&
1066 (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1067 return lower_to_shuffle(b, intrin, options);
1068 else if (options->lower_to_scalar && intrin->num_components > 1)
1069 return lower_subgroup_op_to_scalar(b, intrin);
1070 else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1071 return lower_boolean_shuffle(b, intrin, options);
1072 else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1073 return lower_subgroup_op_to_32bit(b, intrin);
1074 }
1075 break;
1076 case nir_intrinsic_masked_swizzle_amd:
1077 if (options->lower_to_scalar && intrin->num_components > 1) {
1078 return lower_subgroup_op_to_scalar(b, intrin);
1079 } else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) {
1080 return lower_subgroup_op_to_32bit(b, intrin);
1081 }
1082 break;
1083
1084 default:
1085 break;
1086 }
1087
1088 return NULL;
1089 }
1090
1091 bool
1092 nir_lower_subgroups(nir_shader *shader,
1093 const nir_lower_subgroups_options *options)
1094 {
1095 return nir_shader_lower_instructions(shader,
1096 lower_subgroups_filter,
1097 lower_subgroups_instr,
1098 (void *)options);
1099 }
1100