1 /*
2  * Copyright © 2023 Collabora, Ltd.
3  * Copyright © 2017 Intel Corporation
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22  * IN THE SOFTWARE.
23  */
24 
25 #include "util/u_math.h"
26 #include "nir.h"
27 #include "nir_builder.h"
28 
29 /**
30  * \file nir_lower_subgroups.c
31  */
32 
33 static unsigned
34 get_max_subgroup_size(const nir_lower_subgroups_options *options)
35 {
36    return options->subgroup_size
37              ? options->subgroup_size
38              : options->ballot_components * options->ballot_bit_size;
39 }
40 
41 static nir_intrinsic_instr *
42 lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
43                                       unsigned int component)
44 {
45    nir_def *comp;
46    if (component == 0)
47       comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
48    else
49       comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);
50 
51    nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
52    nir_def_init(&intr->instr, &intr->def, 1, 32);
53    intr->const_index[0] = intrin->const_index[0];
54    intr->const_index[1] = intrin->const_index[1];
55    intr->src[0] = nir_src_for_ssa(comp);
56    if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
57       intr->src[1] = nir_src_for_ssa(intrin->src[1].ssa);
58 
59    intr->num_components = 1;
60    nir_builder_instr_insert(b, &intr->instr);
61    return intr;
62 }
63 
64 static nir_def *
65 lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
66 {
67    assert(intrin->src[0].ssa->bit_size == 64);
68    nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
69    nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
70    return nir_pack_64_2x32_split(b, &intr_x->def, &intr_y->def);
71 }
72 
73 /* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
74  * 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or
75  * above the subgroup size for the masks, but gt_mask and ge_mask make them 1
76  * so we have to "and" with this mask.
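 *
 * For example (assuming ballot_bit_size = 32, ballot_components = 1 and a
 * run-time subgroup size of 16), the mask built below is
 * ~0u >> (32 - 16) = 0x0000ffff.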
77  */
78 static nir_def *
79 build_subgroup_mask(nir_builder *b,
80                     const nir_lower_subgroups_options *options)
81 {
82    nir_def *subgroup_size = nir_load_subgroup_size(b);
83 
84    /* First compute the result assuming one ballot component. */
85    nir_def *result =
86       nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
87                nir_isub_imm(b, options->ballot_bit_size,
88                             subgroup_size));
89 
90    /* Since the subgroup size and ballot bitsize are both powers of two, there
91     * are two possible cases to consider:
92     *
93     * (1) The subgroup size is less than the ballot bitsize. We need to return
94     * "result" in the first component and 0 in every other component.
95     * (2) The subgroup size is a multiple of the ballot bitsize. We need to
96     * return ~0 in components whose index is less than the subgroup size
97     * divided by the ballot bitsize, and 0 otherwise. For example,
98     * with a target ballot type of 4 x uint32 and subgroup_size = 64 we'd need
99     * to return { ~0, ~0, 0, 0 }.
100     *
101     * In case (2) it turns out that "result" will be ~0, because
102     * "ballot_bit_size - subgroup_size" is also a multiple of
103     * "ballot_bit_size" and since nir_ushr masks the shift value it will
104     * be shifted by 0. This means that the first component can just be "result"
105     * in all cases.  The other components will also get the correct value in
106     * case (1) if we just use the rule in case (2), so we'll get the correct
107     * result if we just follow (2) and then replace the first component with
108     * "result".
109     */
110    nir_const_value min_idx[4];
111    for (unsigned i = 0; i < options->ballot_components; i++)
112       min_idx[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
113    nir_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);
114 
115    nir_def *result_extended =
116       nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);
117 
118    return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
119                     result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
120 }
121 
122 /* Return a ballot-mask-sized value which represents "val" sign-extended and
123  * then shifted left by "shift". Only particular values for "val" are
124  * supported, see below.
125  *
126  * This function assumes that `val << shift` will never span a ballot_bit_size
127  * word and that the high bit of val can be extended across the entire result.
128  * This is trivially satisfied for 0, 1, ~0, and ~1.  However, it may also be
129  * fine for other values if the shift is guaranteed to be sufficiently
130  * aligned.  One example is 0xf when the shift is known to be a multiple of 4.
131  */
132 static nir_def *
133 build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_def *shift,
134                       const nir_lower_subgroups_options *options)
135 {
136    /* First compute the result assuming one ballot component. */
137    nir_def *result =
138       nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);
139 
140    if (options->ballot_components == 1)
141       return result;
142 
143    /* Fix up the result when there is > 1 component. The idea is that nir_ishl
144     * masks out the high bits of the shift value already, so in case there's
145     * more than one component the component which 1 would be shifted into
146     * already has the right value and all we have to do is fix up the other
147     * components. Components below it should always be 0, and components above
148     * it must be either 0 or ~0 because of the assumptions above. For example, if
149     * the target ballot size is 2 x uint32, and we're shifting 1 by 33, then
150     * we'll feed 33 into ishl, which will mask it off to get 1, so we'll
151     * compute a single-component result of 2, which is correct for the second
152     * component, but the first component needs to be 0, which we get by
153     * comparing the high bits of the shift with 0 and selecting the original
154     * answer or 0 for the first component (and something similar with the
155     * second component). This idea is generalized here for any component count.
156     */
157    nir_const_value min_shift[4];
158    for (unsigned i = 0; i < options->ballot_components; i++)
159       min_shift[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
160    nir_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);
161 
162    nir_const_value max_shift[4];
163    for (unsigned i = 0; i < options->ballot_components; i++)
164       max_shift[i] = nir_const_value_for_int((i + 1) * options->ballot_bit_size, 32);
165    nir_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);
166 
167    return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
168                     nir_bcsel(b, nir_ult(b, shift, min_shift_val),
169                               nir_imm_intN_t(b, val >> 63, result->bit_size),
170                               result),
171                     nir_imm_intN_t(b, 0, result->bit_size));
172 }
173 
174 static nir_def *
175 ballot_type_to_uint(nir_builder *b, nir_def *value,
176                     const nir_lower_subgroups_options *options)
177 {
178    /* Allow internally generated ballots to pass through */
179    if (value->num_components == options->ballot_components &&
180        value->bit_size == options->ballot_bit_size)
181       return value;
182 
183    /* Only the new-style SPIR-V subgroup instructions take a ballot result as
184     * an argument, so we only use this on uvec4 types.
185     */
186    assert(value->num_components == 4 && value->bit_size == 32);
187 
188    return nir_extract_bits(b, &value, 1, 0, options->ballot_components,
189                            options->ballot_bit_size);
190 }
191 
192 static nir_def *
193 uint_to_ballot_type(nir_builder *b, nir_def *value,
194                     unsigned num_components, unsigned bit_size)
195 {
196    assert(util_is_power_of_two_nonzero(num_components));
197    assert(util_is_power_of_two_nonzero(value->num_components));
198 
199    unsigned total_bits = bit_size * num_components;
200 
201    /* If the source doesn't have enough bits, zero-pad */
202    if (total_bits > value->bit_size * value->num_components)
203       value = nir_pad_vector_imm_int(b, value, 0, total_bits / value->bit_size);
204 
205    value = nir_bitcast_vector(b, value, bit_size);
206 
207    /* If the source has too many components, truncate.  This can happen if,
208     * for instance, we're implementing GL_ARB_shader_ballot or
209     * VK_EXT_shader_subgroup_ballot which have 64-bit ballot values on an
210     * architecture with a native 128-bit uvec4 ballot.  This comes up in Zink
211     * for OpenGL on Vulkan.  It's the job of the driver calling this lowering
212     * pass to ensure that it has restricted subgroup sizes sufficiently that we
213     * have enough ballot bits.
214     */
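   /* For example (an assumed case matching the Zink scenario above): a
    * 4 x 32-bit native ballot converted to a 1 x 64-bit GL/Vulkan ballot is
    * bitcast to 2 x 64 bits and then trimmed to 1 x 64 bits below.
    */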
215    if (value->num_components > num_components)
216       value = nir_trim_vector(b, value, num_components);
217 
218    return value;
219 }
220 
221 static nir_def *
222 lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
223 {
224    /* This is safe to call on scalar things but it would be silly */
225    assert(intrin->def.num_components > 1);
226 
227    nir_def *value = intrin->src[0].ssa;
228    nir_def *reads[NIR_MAX_VEC_COMPONENTS];
229 
230    for (unsigned i = 0; i < intrin->num_components; i++) {
231       nir_intrinsic_instr *chan_intrin =
232          nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
233       nir_def_init(&chan_intrin->instr, &chan_intrin->def, 1,
234                    intrin->def.bit_size);
235       chan_intrin->num_components = 1;
236 
237       /* value */
238       chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
239       /* invocation */
240       if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
241          assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
242          chan_intrin->src[1] = nir_src_for_ssa(intrin->src[1].ssa);
243       }
244 
245       chan_intrin->const_index[0] = intrin->const_index[0];
246       chan_intrin->const_index[1] = intrin->const_index[1];
247 
248       nir_builder_instr_insert(b, &chan_intrin->instr);
249       reads[i] = &chan_intrin->def;
250    }
251 
252    return nir_vec(b, reads, intrin->num_components);
253 }
254 
255 static nir_def *
256 lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
257 {
258    nir_def *value = intrin->src[0].ssa;
259 
260    nir_def *result = NULL;
261    for (unsigned i = 0; i < intrin->num_components; i++) {
262       nir_def* chan = nir_channel(b, value, i);
263 
264       if (intrin->intrinsic == nir_intrinsic_vote_feq) {
265          chan = nir_vote_feq(b, intrin->def.bit_size, chan);
266       } else {
267          chan = nir_vote_ieq(b, intrin->def.bit_size, chan);
268       }
269 
270       if (result) {
271          result = nir_iand(b, result, chan);
272       } else {
273          result = chan;
274       }
275    }
276 
277    return result;
278 }
279 
280 static nir_def *
281 lower_vote_eq(nir_builder *b, nir_intrinsic_instr *intrin)
282 {
283    nir_def *value = intrin->src[0].ssa;
284 
285    /* We have to implicitly lower to scalar */
286    nir_def *all_eq = NULL;
287    for (unsigned i = 0; i < intrin->num_components; i++) {
288       nir_def *rfi = nir_read_first_invocation(b, nir_channel(b, value, i));
289 
290       nir_def *is_eq;
291       if (intrin->intrinsic == nir_intrinsic_vote_feq) {
292          is_eq = nir_feq(b, rfi, nir_channel(b, value, i));
293       } else {
294          is_eq = nir_ieq(b, rfi, nir_channel(b, value, i));
295       }
296 
297       if (all_eq == NULL) {
298          all_eq = is_eq;
299       } else {
300          all_eq = nir_iand(b, all_eq, is_eq);
301       }
302    }
303 
304    return nir_vote_all(b, 1, all_eq);
305 }
306 
307 static nir_def *
308 lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin)
309 {
310    unsigned mask = nir_src_as_uint(intrin->src[1]);
311 
312    if (mask >= 32)
313       return NULL;
314 
315    return nir_masked_swizzle_amd(b, intrin->src[0].ssa,
316                                  .swizzle_mask = (mask << 10) | 0x1f,
317                                  .fetch_inactive = true);
318 }
319 
320 /* Lowers "specialized" shuffles to a generic nir_intrinsic_shuffle. */
321 
322 static nir_def *
323 lower_to_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
324                  const nir_lower_subgroups_options *options)
325 {
326    if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
327        options->lower_shuffle_to_swizzle_amd &&
328        nir_src_is_const(intrin->src[1])) {
329 
330       nir_def *result = lower_shuffle_to_swizzle(b, intrin);
331       if (result)
332          return result;
333    }
334 
335    nir_def *index = nir_load_subgroup_invocation(b);
336    switch (intrin->intrinsic) {
337    case nir_intrinsic_shuffle_xor:
338       index = nir_ixor(b, index, intrin->src[1].ssa);
339       break;
340    case nir_intrinsic_shuffle_up:
341       index = nir_isub(b, index, intrin->src[1].ssa);
342       break;
343    case nir_intrinsic_shuffle_down:
344       index = nir_iadd(b, index, intrin->src[1].ssa);
345       break;
346    case nir_intrinsic_quad_broadcast:
347       index = nir_ior(b, nir_iand_imm(b, index, ~0x3),
348                       intrin->src[1].ssa);
349       break;
350    case nir_intrinsic_quad_swap_horizontal:
351       /* For Quad operations, subgroups are divided into quads where
352        * (invocation % 4) is the index to a square arranged as follows:
353        *
354        *    +---+---+
355        *    | 0 | 1 |
356        *    +---+---+
357        *    | 2 | 3 |
358        *    +---+---+
359        */
360       index = nir_ixor(b, index, nir_imm_int(b, 0x1));
361       break;
362    case nir_intrinsic_quad_swap_vertical:
363       index = nir_ixor(b, index, nir_imm_int(b, 0x2));
364       break;
365    case nir_intrinsic_quad_swap_diagonal:
366       index = nir_ixor(b, index, nir_imm_int(b, 0x3));
367       break;
368    case nir_intrinsic_rotate: {
369       nir_def *delta = intrin->src[1].ssa;
370       nir_def *local_id = nir_load_subgroup_invocation(b);
371       const unsigned cluster_size = nir_intrinsic_cluster_size(intrin);
372 
373       nir_def *rotation_group_mask =
374          cluster_size > 0 ? nir_imm_int(b, (int)(cluster_size - 1)) : nir_iadd_imm(b, nir_load_subgroup_size(b), -1);
375 
376       index = nir_iand(b, nir_iadd(b, local_id, delta),
377                        rotation_group_mask);
378       if (cluster_size > 0) {
379          index = nir_iadd(b, index,
380                           nir_iand(b, local_id, nir_inot(b, rotation_group_mask)));
381       }
382       break;
383    }
384    default:
385       unreachable("Invalid intrinsic");
386    }
387 
388    return nir_shuffle(b, intrin->src[0].ssa, index);
389 }
390 
391 static const struct glsl_type *
392 glsl_type_for_ssa(nir_def *def)
393 {
394    const struct glsl_type *comp_type = def->bit_size == 1 ? glsl_bool_type() : glsl_uintN_t_type(def->bit_size);
395    return glsl_replace_vector_type(comp_type, def->num_components);
396 }
397 
398 /* Lower nir_intrinsic_shuffle to a waterfall loop + nir_read_invocation.
399  */
400 static nir_def *
401 lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin)
402 {
403    nir_def *val = intrin->src[0].ssa;
404    nir_def *id = intrin->src[1].ssa;
405 
406    /* The loop is something like:
407     *
408     * while (true) {
409     *    first_id = readFirstInvocation(gl_SubgroupInvocationID);
410     *    first_val = readFirstInvocation(val);
411     *    first_result = readInvocation(val, readFirstInvocation(id));
412     *    if (id == first_id)
413     *       result = first_val;
414     *    if (elect()) {
415     *       if (id > gl_SubgroupInvocationID) {
416     *          result = first_result;
417     *       }
418     *       break;
419     *    }
420     * }
421     *
422     * The idea is to guarantee, on each iteration of the loop, that anything
423     * reading from first_id gets the correct value, so that we can then kill
424     * it off by breaking out of the loop. Before doing that we also have to
425     * ensure that the first_id invocation gets the correct value. It will only
426     * lack the correct value if the invocation it is reading from hasn't been
427     * killed off yet, that is, if that invocation's ID is greater than its own.
428     * Invocations where id <= gl_SubgroupInvocationID will be assigned their
429     * result in the first if, and invocations where id >
430     * gl_SubgroupInvocationID will be assigned their result in the second if.
431     *
432     * We do this more complicated loop rather than looping over all id's
433     * explicitly because at this point we don't know the "actual" subgroup
434     * size and at the moment there's no way to get at it, which means we may
435     * loop over always-inactive invocations.
436     */
437 
438    nir_def *subgroup_id = nir_load_subgroup_invocation(b);
439 
440    nir_variable *result =
441       nir_local_variable_create(b->impl, glsl_type_for_ssa(val), "result");
442 
443    nir_loop *loop = nir_push_loop(b);
444    {
445       nir_def *first_id = nir_read_first_invocation(b, subgroup_id);
446       nir_def *first_val = nir_read_first_invocation(b, val);
447       nir_def *first_result =
448          nir_read_invocation(b, val, nir_read_first_invocation(b, id));
449 
450       nir_if *nif = nir_push_if(b, nir_ieq(b, id, first_id));
451       {
452          nir_store_var(b, result, first_val, BITFIELD_MASK(val->num_components));
453       }
454       nir_pop_if(b, nif);
455 
456       nir_if *nif2 = nir_push_if(b, nir_elect(b, 1));
457       {
458          nir_if *nif3 = nir_push_if(b, nir_ult(b, subgroup_id, id));
459          {
460             nir_store_var(b, result, first_result, BITFIELD_MASK(val->num_components));
461          }
462          nir_pop_if(b, nif3);
463 
464          nir_jump(b, nir_jump_break);
465       }
466       nir_pop_if(b, nif2);
467    }
468    nir_pop_loop(b, loop);
469 
470    return nir_load_var(b, result);
471 }
472 
473 static nir_def *
474 lower_boolean_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
475                       const nir_lower_subgroups_options *options)
476 {
477    assert(options->ballot_components == 1 && options->subgroup_size);
478    nir_def *ballot = nir_ballot_relaxed(b, 1, options->ballot_bit_size, intrin->src[0].ssa);
479 
480    nir_def *index = NULL;
481 
482    /* If the shuffle amount isn't constant, it might be divergent but
483     * inverse_ballot requires a uniform source, so take a different path.
484     * rotate allows us to assume the delta is uniform unlike shuffle_up/down.
485     */
486    switch (intrin->intrinsic) {
487    case nir_intrinsic_shuffle_up:
488       if (nir_src_is_const(intrin->src[1]))
489          ballot = nir_ishl(b, ballot, intrin->src[1].ssa);
490       else
491          index = nir_isub(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
492       break;
493    case nir_intrinsic_shuffle_down:
494       if (nir_src_is_const(intrin->src[1]))
495          ballot = nir_ushr(b, ballot, intrin->src[1].ssa);
496       else
497          index = nir_iadd(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
498       break;
499    case nir_intrinsic_shuffle_xor:
500       index = nir_ixor(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
501       break;
502    case nir_intrinsic_rotate: {
503       nir_def *delta = nir_as_uniform(b, intrin->src[1].ssa);
504       uint32_t cluster_size = nir_intrinsic_cluster_size(intrin);
505       unsigned subgroup_size = get_max_subgroup_size(options);
506       cluster_size = cluster_size ? cluster_size : subgroup_size;
507       cluster_size = MIN2(cluster_size, subgroup_size);
508       if (cluster_size == 1) {
509          return intrin->src[0].ssa;
510       } else if (cluster_size == 2) {
511          delta = nir_iand_imm(b, delta, cluster_size - 1);
512          nir_def *lo = nir_iand_imm(b, nir_ushr_imm(b, ballot, 1), 0x5555555555555555ull);
513          nir_def *hi = nir_iand_imm(b, nir_ishl_imm(b, ballot, 1), 0xaaaaaaaaaaaaaaaaull);
514          ballot = nir_bcsel(b, nir_ine_imm(b, delta, 0), nir_ior(b, hi, lo), ballot);
515       } else if (cluster_size == ballot->bit_size) {
516          ballot = nir_uror(b, ballot, delta);
517       } else if (cluster_size == 32) {
518          nir_def *unpacked = nir_unpack_64_2x32(b, ballot);
519          unpacked = nir_uror(b, unpacked, delta);
520          ballot = nir_pack_64_2x32(b, unpacked);
521       } else {
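         /* Generic per-cluster rotate: replicate a mask of the low
          * (cluster_size - delta) bits into every cluster, then combine the
          * bits rotated down within each cluster (lo) with the bits that
          * wrap around to the top of each cluster (hi).
          */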
522          delta = nir_iand_imm(b, delta, cluster_size - 1);
523          nir_def *delta_rev = nir_isub_imm(b, cluster_size, delta);
524          nir_def *mask = nir_mask(b, delta_rev, ballot->bit_size);
525          for (uint32_t i = cluster_size; i < ballot->bit_size; i *= 2) {
526             mask = nir_ior(b, nir_ishl_imm(b, mask, i), mask);
527          }
528          nir_def *lo = nir_iand(b, nir_ushr(b, ballot, delta), mask);
529          nir_def *hi = nir_iand(b, nir_ishl(b, ballot, delta_rev), nir_inot(b, mask));
530          ballot = nir_ior(b, lo, hi);
531       }
532       break;
533    }
534    case nir_intrinsic_shuffle:
535       index = intrin->src[1].ssa;
536       break;
537    case nir_intrinsic_read_invocation:
538       index = nir_as_uniform(b, intrin->src[1].ssa);
539       break;
540    default:
541       unreachable("not a boolean shuffle");
542    }
543 
544    if (index) {
545       nir_def *mask = nir_ishl(b, nir_imm_intN_t(b, 1, ballot->bit_size), index);
546       return nir_ine_imm(b, nir_iand(b, ballot, mask), 0);
547    } else {
548       return nir_inverse_ballot(b, 1, ballot);
549    }
550 }
551 
552 static nir_def *
553 vec_bit_count(nir_builder *b, nir_def *value)
554 {
555    nir_def *vec_result = nir_bit_count(b, value);
556    nir_def *result = nir_channel(b, vec_result, 0);
557    for (unsigned i = 1; i < value->num_components; i++)
558       result = nir_iadd(b, result, nir_channel(b, vec_result, i));
559    return result;
560 }
561 
562 /* produce a bitmask of 111...000...111... alternating between "size"
563  * 1's and "size" 0's (the LSB is 1).
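 *
 * For example, reduce_mask(2, 16) gives 0b0011001100110011 (0x3333) and
 * reduce_mask(4, 32) gives 0x0f0f0f0f.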
564  */
565 static uint64_t
566 reduce_mask(unsigned size, unsigned ballot_bit_size)
567 {
568    uint64_t mask = 0;
569    for (unsigned i = 0; i < ballot_bit_size; i += 2 * size) {
570       mask |= ((1ull << size) - 1) << i;
571    }
572 
573    return mask;
574 }
575 
576 /* operate on a uniform per-thread bitmask provided by ballot() to perform the
577  * desired Boolean reduction. Assumes that the identity of the operation is
578  * false (so, no iand).
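 *
 * For example (illustrative values), a single nir_op_ior step with a cluster
 * size of 2 on the 8-bit ballot 0b01000010 yields 0b11000011: each 2-bit
 * cluster becomes all 1's if either of its input bits was set.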
579  */
580 static nir_def *
581 lower_boolean_reduce_internal(nir_builder *b, nir_def *src,
582                               unsigned cluster_size, nir_op op,
583                               const nir_lower_subgroups_options *options)
584 {
585    for (unsigned size = 1; size < cluster_size; size *= 2) {
586       nir_def *shifted = nir_ushr_imm(b, src, size);
587       src = nir_build_alu2(b, op, shifted, src);
588       uint64_t mask = reduce_mask(size, options->ballot_bit_size);
589       src = nir_iand_imm(b, src, mask);
590       shifted = nir_ishl_imm(b, src, size);
591       src = nir_ior(b, src, shifted);
592    }
593 
594    return src;
595 }
596 
597 /* operate on a uniform per-thread bitmask provided by ballot() to perform the
598  * desired Boolean inclusive scan. Assumes that the identity of the operation is
599  * false (so, no iand).
600  */
601 static nir_def *
602 lower_boolean_scan_internal(nir_builder *b, nir_def *src,
603                             nir_op op,
604                             const nir_lower_subgroups_options *options)
605 {
606    if (op == nir_op_ior) {
607       /* We want to return a bitmask with all 1's starting at the first 1 in
608        * src. -src is equivalent to ~src + 1. While src | ~src returns all
609        * 1's, src | (~src + 1) returns all 1's except for the bits changed by
610        * the increment. Any 1's before the least significant 0 of ~src are
611        * turned into 0 (zeroing those bits after or'ing) and the least
612        * significant 0 of ~src is turned into 1 (not doing anything). So the
613        * final output is what we want.
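       *
       * For example, with src = 0b00101000: -src is 0b11011000 and
       * src | -src is 0b11111000, i.e. all 1's from the first set bit upward.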
614        */
615       return nir_ior(b, src, nir_ineg(b, src));
616    } else {
617       assert(op == nir_op_ixor);
618       for (unsigned shift = 1; shift < options->ballot_bit_size; shift *= 2) {
619          src = nir_ixor(b, src, nir_ishl_imm(b, src, shift));
620       }
621       return src;
622    }
623 }
624 
625 static nir_def *
626 lower_boolean_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
627                      const nir_lower_subgroups_options *options)
628 {
629    assert(intrin->num_components == 1);
630    assert(options->ballot_components == 1);
631 
632    unsigned cluster_size =
633       intrin->intrinsic == nir_intrinsic_reduce ? nir_intrinsic_cluster_size(intrin) : 0;
634    nir_op op = nir_intrinsic_reduction_op(intrin);
635 
636    /* For certain cluster sizes, reductions of iand and ior can be implemented
637     * more efficiently.
638     */
639    if (intrin->intrinsic == nir_intrinsic_reduce) {
640       if (cluster_size == 0) {
641          if (op == nir_op_iand)
642             return nir_vote_all(b, 1, intrin->src[0].ssa);
643          else if (op == nir_op_ior)
644             return nir_vote_any(b, 1, intrin->src[0].ssa);
645          else if (op == nir_op_ixor)
646             return nir_i2b(b, nir_iand_imm(b, vec_bit_count(b, nir_ballot(b,
647                                                                           options->ballot_components,
648                                                                           options->ballot_bit_size,
649                                                                           intrin->src[0].ssa)),
650                                            1));
651          else
652             unreachable("bad boolean reduction op");
653       }
654 
655       if (cluster_size == 4) {
656          if (op == nir_op_iand)
657             return nir_quad_vote_all(b, 1, intrin->src[0].ssa);
658          else if (op == nir_op_ior)
659             return nir_quad_vote_any(b, 1, intrin->src[0].ssa);
660       }
661    }
662 
663    nir_def *src = intrin->src[0].ssa;
664 
665    /* Apply DeMorgan's law to implement "and" reductions, since all the
666     * lower_boolean_*_internal() functions assume an identity of 0 to make the
667     * generated code shorter.
668     */
669    nir_op new_op = (op == nir_op_iand) ? nir_op_ior : op;
670    if (op == nir_op_iand) {
671       src = nir_inot(b, src);
672    }
673 
674    nir_def *val = nir_ballot(b, options->ballot_components, options->ballot_bit_size, src);
675 
676    switch (intrin->intrinsic) {
677    case nir_intrinsic_reduce:
678       val = lower_boolean_reduce_internal(b, val, cluster_size, new_op, options);
679       break;
680    case nir_intrinsic_inclusive_scan:
681       val = lower_boolean_scan_internal(b, val, new_op, options);
682       break;
683    case nir_intrinsic_exclusive_scan:
684       val = lower_boolean_scan_internal(b, val, new_op, options);
685       val = nir_ishl_imm(b, val, 1);
686       break;
687    default:
688       unreachable("bad intrinsic");
689    }
690 
691    if (op == nir_op_iand) {
692       val = nir_inot(b, val);
693    }
694 
695    return nir_inverse_ballot(b, 1, val);
696 }
697 
698 static nir_def *
699 build_identity(nir_builder *b, unsigned bit_size, nir_op op)
700 {
701    nir_const_value ident_const = nir_alu_binop_identity(op, bit_size);
702    return nir_build_imm(b, 1, bit_size, &ident_const);
703 }
704 
705 /* Implementation of scan/reduce that assumes a full subgroup */
706 static nir_def *
707 build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
708                 nir_def *data, unsigned cluster_size)
709 {
710    switch (op) {
711    case nir_intrinsic_exclusive_scan:
712    case nir_intrinsic_inclusive_scan: {
713       for (unsigned i = 1; i < cluster_size; i *= 2) {
714          nir_def *idx = nir_load_subgroup_invocation(b);
715          nir_def *has_buddy = nir_ige_imm(b, idx, i);
716 
717          nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
718          nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
719          data = nir_bcsel(b, has_buddy, accum, data);
720       }
721 
722       if (op == nir_intrinsic_exclusive_scan) {
723          /* For exclusive scans, we need to shift one more time and fill in the
724           * bottom channel with identity.
725           */
726          nir_def *idx = nir_load_subgroup_invocation(b);
727          nir_def *has_buddy = nir_ige_imm(b, idx, 1);
728 
729          nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, 1));
730          nir_def *identity = build_identity(b, data->bit_size, red_op);
731          data = nir_bcsel(b, has_buddy, buddy_data, identity);
732       }
733 
734       return data;
735    }
736 
737    case nir_intrinsic_reduce: {
738       for (unsigned i = 1; i < cluster_size; i *= 2) {
739          nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
740          data = nir_build_alu2(b, red_op, data, buddy_data);
741       }
742       return data;
743    }
744 
745    default:
746       unreachable("Unsupported scan/reduce op");
747    }
748 }
749 
750 /* Fully generic implementation of scan/reduce that takes a mask */
751 static nir_def *
752 build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
753                   nir_def *data, nir_def *mask, unsigned max_mask_bits,
754                   const nir_lower_subgroups_options *options)
755 {
756    nir_def *lt_mask = nir_load_subgroup_lt_mask(b, options->ballot_components,
757                                                 options->ballot_bit_size);
758 
759    /* Mask of all channels whose values we need to accumulate.  Our own value
760     * is already in "data", in the inclusive case, since "data" starts out as
761     * We only need to consider lower indexed invocations.
762     */
763    nir_def *remaining = nir_iand(b, mask, lt_mask);
764 
765    for (unsigned i = 1; i < max_mask_bits; i *= 2) {
766       /* At each step, our buddy channel is the first channel we have yet to
767        * take into account in the accumulator.
768        */
769       nir_def *has_buddy = nir_bany_inequal(b, remaining, nir_imm_int(b, 0));
770       nir_def *buddy = nir_ballot_find_msb(b, 32, remaining);
771 
772       /* Accumulate with our buddy channel, if any */
773       nir_def *buddy_data = nir_shuffle(b, data, buddy);
774       nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
775       data = nir_bcsel(b, has_buddy, accum, data);
776 
777       /* We just took into account everything in our buddy's accumulator from
778        * the previous step.  The only things remaining are whatever channels
779        * were remaining for our buddy.
780        */
781       nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
782       remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
783    }
784 
785    switch (op) {
786    case nir_intrinsic_exclusive_scan: {
787       /* For exclusive scans, we need to shift one more time and fill in the
788        * bottom channel with identity.
789        *
790        * Some of this will get CSE'd with the first step but that's okay. The
791        * code is cleaner this way.
792        */
793       nir_def *lower = nir_iand(b, mask, lt_mask);
794       nir_def *has_buddy = nir_bany_inequal(b, lower, nir_imm_int(b, 0));
795       nir_def *buddy = nir_ballot_find_msb(b, 32, lower);
796 
797       nir_def *buddy_data = nir_shuffle(b, data, buddy);
798       nir_def *identity = build_identity(b, data->bit_size, red_op);
799       return nir_bcsel(b, has_buddy, buddy_data, identity);
800    }
801 
802    case nir_intrinsic_inclusive_scan:
803       return data;
804 
805    case nir_intrinsic_reduce: {
806       /* For reductions, we need to take the top value of the scan */
807       nir_def *idx = nir_ballot_find_msb(b, 32, mask);
808       return nir_shuffle(b, data, idx);
809    }
810 
811    default:
812       unreachable("Unsupported scan/reduce op");
813    }
814 }
815 
816 static nir_def *
817 build_cluster_mask(nir_builder *b, unsigned cluster_size,
818                    const nir_lower_subgroups_options *options)
819 {
820    nir_def *idx = nir_load_subgroup_invocation(b);
821    nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));
822 
823    if (cluster_size <= options->ballot_bit_size) {
824       return build_ballot_imm_ishl(b, BITFIELD_MASK(cluster_size), cluster,
825                                    options);
826    }
827 
828    /* Since the cluster size and the ballot bit size are both powers of 2,
829     * cluster size will be a multiple of the ballot bit size. Therefore, each
830     * ballot component will be either all ones or all zeros. Build a vec for
831     * which each component holds the value of `cluster` for which the mask
832     * should be all ones.
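    *
    * For example (assuming a 4 x 32-bit ballot and cluster_size = 64),
    * cluster_sel below is { 0, 0, 64, 64 }: an invocation in the first
    * cluster (cluster = 0) gets { ~0, ~0, 0, 0 } and one in the second
    * cluster (cluster = 64) gets { 0, 0, ~0, ~0 }.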
833     */
834    nir_const_value cluster_sel_const[4];
835    assert(ARRAY_SIZE(cluster_sel_const) >= options->ballot_components);
836 
837    for (unsigned i = 0; i < options->ballot_components; i++) {
838       unsigned cluster_val =
839          ROUND_DOWN_TO(i * options->ballot_bit_size, cluster_size);
840       cluster_sel_const[i] =
841          nir_const_value_for_uint(cluster_val, options->ballot_bit_size);
842    }
843 
844    nir_def *cluster_sel =
845       nir_build_imm(b, options->ballot_components, options->ballot_bit_size,
846                     cluster_sel_const);
847    nir_def *ones = nir_imm_intN_t(b, -1, options->ballot_bit_size);
848    nir_def *zeros = nir_imm_intN_t(b, 0, options->ballot_bit_size);
849    return nir_bcsel(b, nir_ieq(b, cluster, cluster_sel), ones, zeros);
850 }
851 
852 static nir_def *
853 lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
854                   const nir_lower_subgroups_options *options)
855 {
856    const nir_op red_op = nir_intrinsic_reduction_op(intrin);
857    unsigned subgroup_size = get_max_subgroup_size(options);
858 
859    /* Grab the cluster size */
860    unsigned cluster_size = subgroup_size;
861    if (nir_intrinsic_has_cluster_size(intrin)) {
862       cluster_size = nir_intrinsic_cluster_size(intrin);
863       if (cluster_size == 0 || cluster_size > subgroup_size)
864          cluster_size = subgroup_size;
865    }
866 
867    /* Check if all invocations are active. If so, we use the fast path. */
868    nir_def *mask = nir_ballot(b, options->ballot_components,
869                               options->ballot_bit_size, nir_imm_true(b));
870 
871    nir_def *full, *partial;
872    nir_push_if(b, nir_ball_iequal(b, mask, build_subgroup_mask(b, options)));
873    {
874       full = build_scan_full(b, intrin->intrinsic, red_op,
875                              intrin->src[0].ssa, cluster_size);
876    }
877    nir_push_else(b, NULL);
878    {
879       /* Mask according to the cluster size */
880       if (cluster_size < subgroup_size) {
881          nir_def *cluster_mask = build_cluster_mask(b, cluster_size, options);
882          mask = nir_iand(b, mask, cluster_mask);
883       }
884 
885       partial = build_scan_reduce(b, intrin->intrinsic, red_op,
886                                   intrin->src[0].ssa, mask, cluster_size,
887                                   options);
888    }
889    nir_pop_if(b, NULL);
890    return nir_if_phi(b, full, partial);
891 }
892 
893 static bool
894 lower_subgroups_filter(const nir_instr *instr, const void *_options)
895 {
896    const nir_lower_subgroups_options *options = _options;
897 
898    if (options->filter) {
899       return options->filter(instr, options->filter_data);
900    }
901 
902    return instr->type == nir_instr_type_intrinsic;
903 }
904 
905 static nir_def *
906 build_subgroup_eq_mask(nir_builder *b,
907                        const nir_lower_subgroups_options *options)
908 {
909    nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
910 
911    return build_ballot_imm_ishl(b, 1, subgroup_idx, options);
912 }
913 
914 static nir_def *
915 build_subgroup_ge_mask(nir_builder *b,
916                        const nir_lower_subgroups_options *options)
917 {
918    nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
919 
920    return build_ballot_imm_ishl(b, ~0ull, subgroup_idx, options);
921 }
922 
923 static nir_def *
924 build_subgroup_gt_mask(nir_builder *b,
925                        const nir_lower_subgroups_options *options)
926 {
927    nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
928 
929    return build_ballot_imm_ishl(b, ~1ull, subgroup_idx, options);
930 }
931 
932 static nir_def *
933 build_subgroup_quad_mask(nir_builder *b,
934                          const nir_lower_subgroups_options *options)
935 {
936    nir_def *subgroup_idx = nir_load_subgroup_invocation(b);
937    nir_def *quad_first_idx = nir_iand_imm(b, subgroup_idx, ~0x3);
938 
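   /* For example, for invocation 6 (quad_first_idx = 4) with a single 32-bit
    * ballot component, the mask below is 0xf << 4 = 0xf0.
    */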
939    return build_ballot_imm_ishl(b, 0xf, quad_first_idx, options);
940 }
941 
942 static nir_def *
943 build_quad_vote_any(nir_builder *b, nir_def *src,
944                     const nir_lower_subgroups_options *options)
945 {
946    nir_def *ballot = nir_ballot(b, options->ballot_components,
947                                    options->ballot_bit_size,
948                                    src);
949    nir_def *mask = build_subgroup_quad_mask(b, options);
950 
951    return nir_ine_imm(b, nir_iand(b, ballot, mask), 0);
952 }
953 
954 static nir_def *
955 vec_find_lsb(nir_builder *b, nir_def *value)
956 {
957    nir_def *vec_result = nir_find_lsb(b, value);
958    nir_def *result = nir_imm_int(b, -1);
959    for (int i = value->num_components - 1; i >= 0; i--) {
960       nir_def *channel = nir_channel(b, vec_result, i);
961       /* result = channel >= 0 ? (i * bitsize + channel) : result */
962       result = nir_bcsel(b, nir_ige_imm(b, channel, 0),
963                          nir_iadd_imm(b, channel, i * value->bit_size),
964                          result);
965    }
966    return result;
967 }
968 
969 static nir_def *
970 vec_find_msb(nir_builder *b, nir_def *value)
971 {
972    nir_def *vec_result = nir_ufind_msb(b, value);
973    nir_def *result = nir_imm_int(b, -1);
974    for (unsigned i = 0; i < value->num_components; i++) {
975       nir_def *channel = nir_channel(b, vec_result, i);
976       /* result = channel >= 0 ? (i * bitsize + channel) : result */
977       result = nir_bcsel(b, nir_ige_imm(b, channel, 0),
978                          nir_iadd_imm(b, channel, i * value->bit_size),
979                          result);
980    }
981    return result;
982 }
983 
984 static nir_def *
985 lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
986                              const nir_lower_subgroups_options *options)
987 {
988    if (!options->lower_quad_broadcast_dynamic_to_const)
989       return lower_to_shuffle(b, intrin, options);
990 
991    nir_def *dst = NULL;
992 
993    for (unsigned i = 0; i < 4; ++i) {
994       nir_def *qbcst = nir_quad_broadcast(b, intrin->src[0].ssa,
995                                               nir_imm_int(b, i));
996 
997       if (i)
998          dst = nir_bcsel(b, nir_ieq_imm(b, intrin->src[1].ssa, i),
999                          qbcst, dst);
1000       else
1001          dst = qbcst;
1002    }
1003 
1004    return dst;
1005 }
1006 
1007 static nir_def *
1008 lower_first_invocation_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
1009                                  const nir_lower_subgroups_options *options)
1010 {
1011    return nir_ballot_find_lsb(b, 32, nir_ballot(b, 4, 32, nir_imm_true(b)));
1012 }
1013 
1014 static nir_def *
1015 lower_read_first_invocation(nir_builder *b, nir_intrinsic_instr *intrin)
1016 {
1017    return nir_read_invocation(b, intrin->src[0].ssa, nir_first_invocation(b));
1018 }
1019 
1020 static nir_def *
1021 lower_read_invocation_to_cond(nir_builder *b, nir_intrinsic_instr *intrin)
1022 {
1023    return nir_read_invocation_cond_ir3(b, intrin->def.bit_size,
1024                                        intrin->src[0].ssa,
1025                                        nir_ieq(b, intrin->src[1].ssa,
1026                                                nir_load_subgroup_invocation(b)));
1027 }
1028 
1029 static nir_def *
1030 lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
1031 {
1032    const nir_lower_subgroups_options *options = _options;
1033 
1034    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1035    switch (intrin->intrinsic) {
1036    case nir_intrinsic_vote_any:
1037    case nir_intrinsic_vote_all:
1038       if (options->lower_vote_trivial)
1039          return intrin->src[0].ssa;
1040       break;
1041 
1042    case nir_intrinsic_vote_feq:
1043    case nir_intrinsic_vote_ieq:
1044       if (options->lower_vote_trivial)
1045          return nir_imm_true(b);
1046 
1047       if (nir_src_bit_size(intrin->src[0]) == 1) {
1048          if (options->lower_vote_bool_eq)
1049             return lower_vote_eq(b, intrin);
1050       } else {
1051          if (options->lower_vote_eq)
1052             return lower_vote_eq(b, intrin);
1053       }
1054 
1055       if (options->lower_to_scalar && intrin->num_components > 1)
1056          return lower_vote_eq_to_scalar(b, intrin);
1057       break;
1058 
1059    case nir_intrinsic_load_subgroup_size:
1060       if (options->subgroup_size)
1061          return nir_imm_int(b, options->subgroup_size);
1062       break;
1063 
1064    case nir_intrinsic_first_invocation:
1065       if (options->subgroup_size == 1)
1066          return nir_imm_int(b, 0);
1067 
1068       if (options->lower_first_invocation_to_ballot)
1069          return lower_first_invocation_to_ballot(b, intrin, options);
1070 
1071       break;
1072 
1073    case nir_intrinsic_read_invocation:
1074       if (options->lower_to_scalar && intrin->num_components > 1)
1075          return lower_subgroup_op_to_scalar(b, intrin);
1076 
1077       if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1078          return lower_boolean_shuffle(b, intrin, options);
1079 
1080       if (options->lower_read_invocation_to_cond)
1081          return lower_read_invocation_to_cond(b, intrin);
1082 
1083       break;
1084 
1085    case nir_intrinsic_read_first_invocation:
1086       if (options->lower_to_scalar && intrin->num_components > 1)
1087          return lower_subgroup_op_to_scalar(b, intrin);
1088 
1089       if (options->lower_read_first_invocation)
1090          return lower_read_first_invocation(b, intrin);
1091       break;
1092 
1093    case nir_intrinsic_load_subgroup_eq_mask:
1094    case nir_intrinsic_load_subgroup_ge_mask:
1095    case nir_intrinsic_load_subgroup_gt_mask:
1096    case nir_intrinsic_load_subgroup_le_mask:
1097    case nir_intrinsic_load_subgroup_lt_mask: {
1098       if (!options->lower_subgroup_masks)
1099          return NULL;
1100 
1101       nir_def *val;
1102       switch (intrin->intrinsic) {
1103       case nir_intrinsic_load_subgroup_eq_mask:
1104          val = build_subgroup_eq_mask(b, options);
1105          break;
1106       case nir_intrinsic_load_subgroup_ge_mask:
1107          val = nir_iand(b, build_subgroup_ge_mask(b, options),
1108                         build_subgroup_mask(b, options));
1109          break;
1110       case nir_intrinsic_load_subgroup_gt_mask:
1111          val = nir_iand(b, build_subgroup_gt_mask(b, options),
1112                         build_subgroup_mask(b, options));
1113          break;
1114       case nir_intrinsic_load_subgroup_le_mask:
1115          val = nir_inot(b, build_subgroup_gt_mask(b, options));
1116          break;
1117       case nir_intrinsic_load_subgroup_lt_mask:
1118          val = nir_inot(b, build_subgroup_ge_mask(b, options));
1119          break;
1120       default:
1121          unreachable("you seriously can't tell this is unreachable?");
1122       }
1123 
1124       return uint_to_ballot_type(b, val,
1125                                  intrin->def.num_components,
1126                                  intrin->def.bit_size);
1127    }
1128 
1129    case nir_intrinsic_ballot: {
1130       if (intrin->def.num_components == options->ballot_components &&
1131           intrin->def.bit_size == options->ballot_bit_size)
1132          return NULL;
1133 
1134       nir_def *ballot =
1135          nir_ballot(b, options->ballot_components, options->ballot_bit_size,
1136                     intrin->src[0].ssa);
1137 
1138       return uint_to_ballot_type(b, ballot,
1139                                  intrin->def.num_components,
1140                                  intrin->def.bit_size);
1141    }
1142 
1143    case nir_intrinsic_inverse_ballot:
1144       if (options->lower_inverse_ballot) {
1145          return nir_ballot_bitfield_extract(b, 1, intrin->src[0].ssa,
1146                                             nir_load_subgroup_invocation(b));
1147       } else if (intrin->src[0].ssa->num_components != options->ballot_components ||
1148                  intrin->src[0].ssa->bit_size != options->ballot_bit_size) {
1149          return nir_inverse_ballot(b, 1, ballot_type_to_uint(b, intrin->src[0].ssa, options));
1150       }
1151       break;
1152 
1153    case nir_intrinsic_ballot_bitfield_extract:
1154    case nir_intrinsic_ballot_bit_count_reduce:
1155    case nir_intrinsic_ballot_find_lsb:
1156    case nir_intrinsic_ballot_find_msb: {
1157       nir_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
1158                                              options);
1159 
1160       if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
1161           intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
1162          /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
1163           *
1164           *    "Find the most significant bit set to 1 in Value, considering
1165           *    only the bits in Value required to represent all bits of the
1166           *    group’s invocations.  If none of the considered bits is set to
1167           *    1, the result is undefined."
1168           *
1169           * It has similar text for the other three.  This means that, in case
1170           * the subgroup size is less than 32, we have to mask off the unused
1171           * bits.  If the subgroup size is fixed and greater than or equal to
1172           * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
1173           * the iand.
1174           *
1175           * We only have to worry about this for BitCount and FindMSB because
1176           * FindLSB counts from the bottom and BitfieldExtract selects
1177           * individual bits.  In either case, if run outside the range of
1178           * valid bits, we hit the undefined results case and we can return
1179           * anything we want.
1180           */
1181          int_val = nir_iand(b, int_val, build_subgroup_mask(b, options));
1182       }
1183 
1184       switch (intrin->intrinsic) {
1185       case nir_intrinsic_ballot_bitfield_extract: {
1186          nir_def *idx = intrin->src[1].ssa;
1187          if (int_val->num_components > 1) {
1188             /* idx will be truncated by nir_ushr, so we just need to select
1189              * the right component using the bits of idx that are truncated in
1190              * the shift.
1191              */
1192             int_val =
1193                nir_vector_extract(b, int_val,
1194                                   nir_udiv_imm(b, idx, int_val->bit_size));
1195          }
1196 
1197          return nir_test_mask(b, nir_ushr(b, int_val, idx), 1);
1198       }
1199       case nir_intrinsic_ballot_bit_count_reduce:
1200          return vec_bit_count(b, int_val);
1201       case nir_intrinsic_ballot_find_lsb:
1202          return vec_find_lsb(b, int_val);
1203       case nir_intrinsic_ballot_find_msb:
1204          return vec_find_msb(b, int_val);
1205       default:
1206          unreachable("you seriously can't tell this is unreachable?");
1207       }
1208    }
1209 
1210    case nir_intrinsic_ballot_bit_count_exclusive:
1211    case nir_intrinsic_ballot_bit_count_inclusive: {
1212       nir_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
1213                                              options);
1214       if (options->lower_ballot_bit_count_to_mbcnt_amd) {
1215          nir_def *acc;
1216          if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_exclusive) {
1217             acc = nir_imm_int(b, 0);
1218          } else {
1219             acc = nir_iand_imm(b, nir_u2u32(b, int_val), 0x1);
1220             int_val = nir_ushr_imm(b, int_val, 1);
1221          }
1222          return nir_mbcnt_amd(b, int_val, acc);
1223       }
1224 
1225       nir_def *mask;
1226       if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
1227          mask = nir_inot(b, build_subgroup_gt_mask(b, options));
1228       } else {
1229          mask = nir_inot(b, build_subgroup_ge_mask(b, options));
1230       }
1231 
1232       return vec_bit_count(b, nir_iand(b, int_val, mask));
1233    }
1234 
1235    case nir_intrinsic_elect: {
1236       if (!options->lower_elect)
1237          return NULL;
1238 
1239       return nir_ieq(b, nir_load_subgroup_invocation(b), nir_first_invocation(b));
1240    }
1241 
1242    case nir_intrinsic_shuffle:
1243       if (options->lower_shuffle &&
1244           (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1245          return lower_shuffle(b, intrin);
1246       else if (options->lower_to_scalar && intrin->num_components > 1)
1247          return lower_subgroup_op_to_scalar(b, intrin);
1248       else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1249          return lower_boolean_shuffle(b, intrin, options);
1250       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1251          return lower_subgroup_op_to_32bit(b, intrin);
1252       break;
1253    case nir_intrinsic_shuffle_xor:
1254    case nir_intrinsic_shuffle_up:
1255    case nir_intrinsic_shuffle_down:
1256       if (options->lower_relative_shuffle &&
1257           (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1258          return lower_to_shuffle(b, intrin, options);
1259       else if (options->lower_to_scalar && intrin->num_components > 1)
1260          return lower_subgroup_op_to_scalar(b, intrin);
1261       else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1262          return lower_boolean_shuffle(b, intrin, options);
1263       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1264          return lower_subgroup_op_to_32bit(b, intrin);
1265       break;
1266 
1267    case nir_intrinsic_quad_broadcast:
1268    case nir_intrinsic_quad_swap_horizontal:
1269    case nir_intrinsic_quad_swap_vertical:
1270    case nir_intrinsic_quad_swap_diagonal:
1271       if (options->lower_quad ||
1272           (options->lower_quad_broadcast_dynamic &&
1273            intrin->intrinsic == nir_intrinsic_quad_broadcast &&
1274            !nir_src_is_const(intrin->src[1])))
1275          return lower_dynamic_quad_broadcast(b, intrin, options);
1276       else if (options->lower_to_scalar && intrin->num_components > 1)
1277          return lower_subgroup_op_to_scalar(b, intrin);
1278       break;
1279 
1280    case nir_intrinsic_quad_vote_any:
1281       if (options->lower_quad_vote)
1282          return build_quad_vote_any(b, intrin->src[0].ssa, options);
1283       break;
1284    case nir_intrinsic_quad_vote_all:
1285       if (options->lower_quad_vote) {
1286          nir_def *not_src = nir_inot(b, intrin->src[0].ssa);
1287          nir_def *any_not = build_quad_vote_any(b, not_src, options);
1288          return nir_inot(b, any_not);
1289       }
1290       break;
1291 
1292    case nir_intrinsic_reduce: {
1293       nir_def *ret = NULL;
1294       /* A cluster size greater than the subgroup size is implementation-defined */
1295       if (options->subgroup_size &&
1296           nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
1297          nir_intrinsic_set_cluster_size(intrin, 0);
1298          ret = NIR_LOWER_INSTR_PROGRESS;
1299       }
1300       if (nir_intrinsic_cluster_size(intrin) == 1)
1301          return intrin->src[0].ssa;
1302       if (options->lower_to_scalar && intrin->num_components > 1)
1303          return lower_subgroup_op_to_scalar(b, intrin);
1304       if (intrin->def.bit_size == 1 && options->ballot_components == 1 &&
1305           (options->lower_boolean_reduce || options->lower_reduce))
1306          return lower_boolean_reduce(b, intrin, options);
1307       if (options->lower_reduce)
1308          return lower_scan_reduce(b, intrin, options);
1309       return ret;
1310    }
1311    case nir_intrinsic_inclusive_scan:
1312    case nir_intrinsic_exclusive_scan:
1313       if (options->lower_to_scalar && intrin->num_components > 1)
1314          return lower_subgroup_op_to_scalar(b, intrin);
1315       if (intrin->def.bit_size == 1 && options->ballot_components == 1 &&
1316           (options->lower_boolean_reduce || options->lower_reduce))
1317          return lower_boolean_reduce(b, intrin, options);
1318       if (options->lower_reduce)
1319          return lower_scan_reduce(b, intrin, options);
1320       break;
1321 
1322    case nir_intrinsic_rotate:
1323       if (options->lower_rotate_to_shuffle &&
1324           (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1325          return lower_to_shuffle(b, intrin, options);
1326       else if (options->lower_rotate_clustered_to_shuffle &&
1327                nir_intrinsic_cluster_size(intrin) > 0 &&
1328                (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
1329          return lower_to_shuffle(b, intrin, options);
1330       else if (options->lower_to_scalar && intrin->num_components > 1)
1331          return lower_subgroup_op_to_scalar(b, intrin);
1332       else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
1333          return lower_boolean_shuffle(b, intrin, options);
1334       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
1335          return lower_subgroup_op_to_32bit(b, intrin);
1336       break;
1337    case nir_intrinsic_masked_swizzle_amd:
1338       if (options->lower_to_scalar && intrin->num_components > 1) {
1339          return lower_subgroup_op_to_scalar(b, intrin);
1340       } else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) {
1341          return lower_subgroup_op_to_32bit(b, intrin);
1342       }
1343       break;
1344 
1345    default:
1346       break;
1347    }
1348 
1349    return NULL;
1350 }
1351 
1352 bool
1353 nir_lower_subgroups(nir_shader *shader,
1354                     const nir_lower_subgroups_options *options)
1355 {
1356    return nir_shader_lower_instructions(shader, lower_subgroups_filter,
1357                                         lower_subgroups_instr,
1358                                         (void *)options);
1359 }
1360