/*
 * Copyright © 2021 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "ir3.h"
#include "ir3_nir.h"
#include "util/ralloc.h"

/* Lower several macro-instructions needed for shader subgroup support that
 * must be turned into if statements. We do this after RA and post-RA
 * scheduling to give the scheduler a chance to rearrange them, because RA
 * may need to insert OPC_META_READ_FIRST to handle splitting live ranges, and
 * also because some (e.g. BALLOT and READ_FIRST) must produce a shared
 * register that cannot be spilled to a normal register until after the if,
 * which makes implementing spilling more complicated if they are already
 * lowered.
 */

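/* Replace old_pred with new_pred in a block's predecessor list
 * (replace_physical_pred() does the same for the physical predecessor list).
 * Used by split_block() below so that the successors inherited by the newly
 * created block point back at it instead of the original block.
 */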
static void
replace_pred(struct ir3_block *block, struct ir3_block *old_pred,
             struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->predecessors_count; i++) {
      if (block->predecessors[i] == old_pred) {
         block->predecessors[i] = new_pred;
         return;
      }
   }
}

static void
replace_physical_pred(struct ir3_block *block, struct ir3_block *old_pred,
                      struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->physical_predecessors_count; i++) {
      if (block->physical_predecessors[i] == old_pred) {
         block->physical_predecessors[i] = new_pred;
         return;
      }
   }
}

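/* Emit "dst = immed" as a (possibly repeated) mov before the block's
 * terminator, picking a 16- or 32-bit move based on dst's register class.
 */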
static void
mov_immed(struct ir3_register *dst, struct ir3_block *block, unsigned immed)
{
   struct ir3_instruction *mov =
      ir3_instr_create_at(ir3_before_terminator(block), OPC_MOV, 1, 1);
   struct ir3_register *mov_dst = ir3_dst_create(mov, dst->num, dst->flags);
   mov_dst->wrmask = dst->wrmask;
   struct ir3_register *src = ir3_src_create(
      mov, INVALID_REG, (dst->flags & IR3_REG_HALF) | IR3_REG_IMMED);
   src->uim_val = immed;
   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = mov->cat1.dst_type;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;
}

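/* Emit "dst = src" before the block's terminator, preserving the half and
 * shared register flags of each side.
 */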
static void
mov_reg(struct ir3_block *block, struct ir3_register *dst,
        struct ir3_register *src)
{
   struct ir3_instruction *mov =
      ir3_instr_create_at(ir3_before_terminator(block), OPC_MOV, 1, 1);

   struct ir3_register *mov_dst =
      ir3_dst_create(mov, dst->num, dst->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   struct ir3_register *mov_src =
      ir3_src_create(mov, src->num, src->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   mov_dst->wrmask = dst->wrmask;
   mov_src->wrmask = src->wrmask;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;

   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = (src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
}

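/* Emit a two- or three-source ALU instruction of the given opcode before the
 * block's terminator, reusing the already-assigned register numbers of
 * dst/src0/src1(/src2).
 */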
static void
binop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1)
{
   struct ir3_instruction *instr =
      ir3_instr_create_at(ir3_before_terminator(block), opc, 1, 2);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

static void
triop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1,
      struct ir3_register *src2)
{
   struct ir3_instruction *instr =
      ir3_instr_create_at(ir3_before_terminator(block), opc, 1, 3);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);
   struct ir3_register *instr_src2 = ir3_src_create(instr, src2->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr_src2->wrmask = src2->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

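/* Emit "dst = src0 OP src1" for a reduce_op_t. Most ops map 1:1 to an ALU
 * instruction; 32-bit MUL_U has no single-instruction form and is expanded
 * into the mull.u + two madsh.m16 sequence (see ir3_nir_imul).
 */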
static void
do_reduce(struct ir3_block *block, reduce_op_t opc,
          struct ir3_register *dst, struct ir3_register *src0,
          struct ir3_register *src1)
{
   switch (opc) {
#define CASE(name)                                                             \
   case REDUCE_OP_##name:                                                      \
      binop(block, OPC_##name, dst, src0, src1);                               \
      break;

   CASE(ADD_U)
   CASE(ADD_F)
   CASE(MUL_F)
   CASE(MIN_U)
   CASE(MIN_S)
   CASE(MIN_F)
   CASE(MAX_U)
   CASE(MAX_S)
   CASE(MAX_F)
   CASE(AND_B)
   CASE(OR_B)
   CASE(XOR_B)

#undef CASE

   case REDUCE_OP_MUL_U:
      if (dst->flags & IR3_REG_HALF) {
         binop(block, OPC_MUL_S24, dst, src0, src1);
      } else {
         /* 32-bit multiplication macro - see ir3_nir_imul */
         binop(block, OPC_MULL_U, dst, src0, src1);
         triop(block, OPC_MADSH_M16, dst, src0, src1, dst);
         triop(block, OPC_MADSH_M16, dst, src1, src0, dst);
      }
      break;
   }
}

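/* Split before_block at instr: a new block is created after it, instr and
 * every following instruction are moved into the new block, and the logical
 * and physical successor edges are transferred to it. before_block is left
 * with no successors so that the caller can wire up the new control flow.
 */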
static struct ir3_block *
split_block(struct ir3 *ir, struct ir3_block *before_block,
            struct ir3_instruction *instr)
{
   struct ir3_block *after_block = ir3_block_create(ir);
   list_add(&after_block->node, &before_block->node);

   for (unsigned i = 0; i < ARRAY_SIZE(before_block->successors); i++) {
      after_block->successors[i] = before_block->successors[i];
      if (after_block->successors[i])
         replace_pred(after_block->successors[i], before_block, after_block);
   }

   for (unsigned i = 0; i < before_block->physical_successors_count; i++) {
      replace_physical_pred(before_block->physical_successors[i],
                            before_block, after_block);
   }

   ralloc_steal(after_block, before_block->physical_successors);
   after_block->physical_successors = before_block->physical_successors;
   after_block->physical_successors_sz = before_block->physical_successors_sz;
   after_block->physical_successors_count =
      before_block->physical_successors_count;

   before_block->successors[0] = before_block->successors[1] = NULL;
   before_block->physical_successors = NULL;
   before_block->physical_successors_count = 0;
   before_block->physical_successors_sz = 0;

   foreach_instr_from_safe (rem_instr, &instr->node,
                            &before_block->instr_list) {
      list_del(&rem_instr->node);
      list_addtail(&rem_instr->node, &after_block->instr_list);
      rem_instr->block = after_block;
   }

   after_block->divergent_condition = before_block->divergent_condition;
   before_block->divergent_condition = false;
   return after_block;
}

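/* Helpers for wiring up new control flow: link_blocks() adds the successor,
 * predecessor and physical edges, link_blocks_jump() additionally terminates
 * pred with an unconditional jump, and link_blocks_branch() terminates it
 * with a conditional branch to target that falls through to fallthrough
 * (marking the condition divergent except for (b)all/(b)any branches).
 */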
static void
link_blocks(struct ir3_block *pred, struct ir3_block *succ, unsigned index)
{
   pred->successors[index] = succ;
   ir3_block_add_predecessor(succ, pred);
   ir3_block_link_physical(pred, succ);
}

static void
link_blocks_jump(struct ir3_block *pred, struct ir3_block *succ)
{
   struct ir3_builder build = ir3_builder_at(ir3_after_block(pred));
   ir3_JUMP(&build);
   link_blocks(pred, succ, 0);
}

static void
link_blocks_branch(struct ir3_block *pred, struct ir3_block *target,
                   struct ir3_block *fallthrough, unsigned opc, unsigned flags,
                   struct ir3_instruction *condition)
{
   unsigned nsrc = condition ? 1 : 0;
   struct ir3_instruction *branch =
      ir3_instr_create_at(ir3_after_block(pred), opc, 0, nsrc);
   branch->flags |= flags;

   if (condition) {
      struct ir3_register *cond_dst = condition->dsts[0];
      struct ir3_register *src =
         ir3_src_create(branch, cond_dst->num, cond_dst->flags);
      src->def = cond_dst;
   }

   link_blocks(pred, target, 0);
   link_blocks(pred, fallthrough, 1);

   if (opc != OPC_BALL && opc != OPC_BANY) {
      pred->divergent_condition = true;
   }
}

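/* Create the "then" side of an if: before_block branches (with the given
 * opcode/flags/condition) into a new then_block and falls through to
 * after_block, while then_block jumps to after_block. Returns then_block so
 * the caller can fill it in.
 */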
static struct ir3_block *
create_if(struct ir3 *ir, struct ir3_block *before_block,
          struct ir3_block *after_block, unsigned opc, unsigned flags,
          struct ir3_instruction *condition)
{
   struct ir3_block *then_block = ir3_block_create(ir);
   list_add(&then_block->node, &before_block->node);

   link_blocks_branch(before_block, then_block, after_block, opc, flags,
                      condition);
   link_blocks_jump(then_block, after_block);

   return then_block;
}

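/* Lower a single subgroup macro. Returns true if the containing block was
 * split, in which case *block is updated to the block following the new
 * control flow and the caller must restart iteration. READ_FIRST is rewritten
 * in place and returns false, as do non-macro instructions.
 */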
static bool
lower_instr(struct ir3 *ir, struct ir3_block **block, struct ir3_instruction *instr)
{
   switch (instr->opc) {
   case OPC_BALLOT_MACRO:
   case OPC_ANY_MACRO:
   case OPC_ALL_MACRO:
   case OPC_ELECT_MACRO:
   case OPC_READ_COND_MACRO:
   case OPC_READ_GETLAST_MACRO:
   case OPC_SCAN_MACRO:
   case OPC_SCAN_CLUSTERS_MACRO:
      break;
   case OPC_READ_FIRST_MACRO:
      /* Moves to shared registers read the first active fiber, so we can just
       * turn read_first.macro into a move. However we must still use the macro
       * and lower it late because in ir3_cp we need to distinguish between
       * moves where all source fibers contain the same value, which can be copy
       * propagated, and moves generated from API-level ReadFirstInvocation
       * which cannot.
       */
      assert(instr->dsts[0]->flags & IR3_REG_SHARED);
      instr->opc = OPC_MOV;
      instr->cat1.dst_type = TYPE_U32;
      instr->cat1.src_type =
         (instr->srcs[0]->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
      return false;
   default:
      return false;
   }

   struct ir3_block *before_block = *block;
   struct ir3_block *after_block = split_block(ir, before_block, instr);

   if (instr->opc == OPC_SCAN_MACRO) {
      /* The pseudo-code for the scan macro is:
       *
       * while (true) {
       *    header:
       *    if (elect()) {
       *       exit:
       *       exclusive = reduce;
       *       inclusive = src OP exclusive;
       *       reduce = inclusive;
       *       break;
       *    }
       *    footer:
       * }
       *
       * This is based on the blob's sequence, and carefully crafted to avoid
       * using the shared register "reduce" except in move instructions, since
       * using it in the actual OP isn't possible for half-registers.
       */
      struct ir3_block *header = ir3_block_create(ir);
      list_add(&header->node, &before_block->node);

      struct ir3_block *exit = ir3_block_create(ir);
      list_add(&exit->node, &header->node);

      struct ir3_block *footer = ir3_block_create(ir);
      list_add(&footer->node, &exit->node);
      footer->reconvergence_point = true;

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, header);

      link_blocks_branch(header, exit, footer, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      link_blocks_jump(exit, after_block);
      ir3_block_link_physical(exit, footer);

      link_blocks_jump(footer, header);

      struct ir3_register *exclusive = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *reduce = instr->dsts[2];
      struct ir3_register *src = instr->srcs[0];

      mov_reg(exit, exclusive, reduce);
      do_reduce(exit, instr->cat1.reduce_op, inclusive, src, exclusive);
      mov_reg(exit, reduce, inclusive);
   } else if (instr->opc == OPC_SCAN_CLUSTERS_MACRO) {
      /* The pseudo-code for the clustered scan macro is:
       *
       * while (true) {
       *    body:
       *    scratch = reduce;
       *
       *    inclusive = inclusive_src OP scratch;
       *
       *    static if (is exclusive scan)
       *       exclusive = exclusive_src OP scratch
       *
       *    if (getlast()) {
       *       store:
       *       reduce = inclusive;
       *       if (elect())
       *           break;
       *    } else {
       *       break;
       *    }
       * }
       * after_block:
       */
      struct ir3_block *body = ir3_block_create(ir);
      list_add(&body->node, &before_block->node);

      struct ir3_block *store = ir3_block_create(ir);
      list_add(&store->node, &body->node);

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, body);

      link_blocks_branch(body, store, after_block, OPC_GETLAST, 0, NULL);

      link_blocks_branch(store, after_block, body, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      struct ir3_register *reduce = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *inclusive_src = instr->srcs[1];

      /* We need to perform the following operations:
       *  - inclusive = inclusive_src OP reduce
       *  - exclusive = exclusive_src OP reduce (iff exclusive scan)
       * Since reduce is initially in a shared register, we need to copy it to a
       * scratch register before performing the operations.
       *
       * The scratch register used is:
       *  - an explicitly allocated one if op is 32b mul_u.
       *    - necessary because we cannot do 'foo = foo mul_u bar' since mul_u
       *      clobbers its destination.
       *  - exclusive if this is an exclusive scan (and not 32b mul_u).
       *    - since we calculate inclusive first.
       *  - inclusive otherwise.
       *
       * In all cases, this is the last destination.
       */
      struct ir3_register *scratch = instr->dsts[instr->dsts_count - 1];

      mov_reg(body, scratch, reduce);
      do_reduce(body, instr->cat1.reduce_op, inclusive, inclusive_src, scratch);

      /* exclusive scan */
      if (instr->srcs_count == 3) {
         struct ir3_register *exclusive_src = instr->srcs[2];
         struct ir3_register *exclusive = instr->dsts[2];
         do_reduce(body, instr->cat1.reduce_op, exclusive, exclusive_src,
                   scratch);
      }

      mov_reg(store, reduce, inclusive);
   } else {
      /* For ballot, the destination must be initialized to 0 before we do
       * the movmsk because the condition may be 0 and then the movmsk will
       * be skipped.
       */
      if (instr->opc == OPC_BALLOT_MACRO) {
         mov_immed(instr->dsts[0], before_block, 0);
      }

      struct ir3_instruction *condition = NULL;
      unsigned branch_opc = 0;
      unsigned branch_flags = 0;

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ALL_MACRO:
         condition = instr->srcs[0]->def->instr;
         break;
      default:
         break;
      }

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_BR;
         break;
      case OPC_ANY_MACRO:
         branch_opc = OPC_BANY;
         break;
      case OPC_ALL_MACRO:
         branch_opc = OPC_BALL;
         break;
      case OPC_ELECT_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_GETONE;
         branch_flags = instr->flags & IR3_INSTR_NEEDS_HELPERS;
         break;
      case OPC_READ_GETLAST_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_GETLAST;
         branch_flags = instr->flags & IR3_INSTR_NEEDS_HELPERS;
         break;
      default:
         unreachable("bad opcode");
      }

      struct ir3_block *then_block =
         create_if(ir, before_block, after_block, branch_opc, branch_flags,
                   condition);

      switch (instr->opc) {
      case OPC_ALL_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ELECT_MACRO:
         mov_immed(instr->dsts[0], then_block, 1);
         mov_immed(instr->dsts[0], before_block, 0);
         break;

      case OPC_BALLOT_MACRO: {
         unsigned wrmask = instr->dsts[0]->wrmask;
         unsigned comp_count = util_last_bit(wrmask);
         struct ir3_instruction *movmsk = ir3_instr_create_at(
            ir3_before_terminator(then_block), OPC_MOVMSK, 1, 0);
         struct ir3_register *dst =
            ir3_dst_create(movmsk, instr->dsts[0]->num, instr->dsts[0]->flags);
         dst->wrmask = wrmask;
         movmsk->repeat = comp_count - 1;
         break;
      }

      case OPC_READ_GETLAST_MACRO:
      case OPC_READ_COND_MACRO: {
         struct ir3_instruction *mov = ir3_instr_create_at(
            ir3_before_terminator(then_block), OPC_MOV, 1, 1);
         ir3_dst_create(mov, instr->dsts[0]->num, instr->dsts[0]->flags);
         struct ir3_register *new_src = ir3_src_create(mov, 0, 0);
         unsigned idx = instr->opc == OPC_READ_COND_MACRO ? 1 : 0;
         *new_src = *instr->srcs[idx];
         mov->cat1.dst_type = TYPE_U32;
         mov->cat1.src_type =
            (new_src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
         mov->flags |= IR3_INSTR_NEEDS_HELPERS;
         break;
      }

      default:
         unreachable("bad opcode");
      }
   }

   *block = after_block;
   list_delinit(&instr->node);
   return true;
}

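/* Lower all subgroup macros in a block, restarting from the newly created
 * block each time one of them splits the current block.
 */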
static bool
lower_block(struct ir3 *ir, struct ir3_block **block)
{
   bool progress = false;

   bool inner_progress;
   do {
      inner_progress = false;
      foreach_instr (instr, &(*block)->instr_list) {
         if (lower_instr(ir, block, instr)) {
            /* restart the loop with the new block we created because the
             * iterator has been invalidated.
             */
            progress = inner_progress = true;
            break;
         }
      }
   } while (inner_progress);

   return progress;
}

bool
ir3_lower_subgroups(struct ir3 *ir)
{
   bool progress = false;

   foreach_block (block, &ir->block_list)
      progress |= lower_block(ir, &block);

   return progress;
}

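/* Pick the GLSL scalar type matching a scalar nir_def, used for the local
 * variables created by the lowering helpers below.
 */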
static const struct glsl_type *
glsl_type_for_def(nir_def *def)
{
   assert(def->num_components == 1);
   return def->bit_size == 1 ? glsl_bool_type()
                             : glsl_uintN_t_type(def->bit_size);
}

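/* Matches the reduce/scan intrinsics handled by lower_scan_reduce() below. */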
static bool
filter_scan_reduce(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      return true;
   default:
      return false;
   }
}

typedef nir_def *(*reduce_cluster)(nir_builder *, nir_op, nir_def *);

/* Execute `reduce` for each cluster in the subgroup with only the invocations
 * in the current cluster active.
 */
static nir_def *
foreach_cluster(nir_builder *b, nir_op op, nir_def *inclusive,
                unsigned cluster_size, reduce_cluster reduce)
{
   nir_def *id = nir_load_subgroup_invocation(b);
   nir_def *cluster_size_imm = nir_imm_int(b, cluster_size);

   /* cur_cluster_end = cluster_size;
    * while (true) {
    *    if (gl_SubgroupInvocationID < cur_cluster_end) {
    *       cluster_val = reduce(inclusive);
    *       break;
    *    }
    *
    *    cur_cluster_end += cluster_size;
    * }
    */
   nir_variable *cur_cluster_end_var =
      nir_local_variable_create(b->impl, glsl_uint_type(), "cur_cluster_end");
   nir_store_var(b, cur_cluster_end_var, cluster_size_imm, 1);
   nir_variable *cluster_val_var = nir_local_variable_create(
      b->impl, glsl_type_for_def(inclusive), "cluster_val");

   nir_loop *loop = nir_push_loop(b);
   {
      nir_def *cur_cluster_end = nir_load_var(b, cur_cluster_end_var);
      nir_def *in_cur_cluster = nir_ult(b, id, cur_cluster_end);

      nir_if *nif = nir_push_if(b, in_cur_cluster);
      {
         nir_def *reduced = reduce(b, op, inclusive);
         nir_store_var(b, cluster_val_var, reduced, 1);
         nir_jump(b, nir_jump_break);
      }
      nir_pop_if(b, nif);

      nir_def *next_cluster_end =
         nir_iadd(b, cur_cluster_end, cluster_size_imm);
      nir_store_var(b, cur_cluster_end_var, next_cluster_end, 1);
   }
   nir_pop_loop(b, loop);

   return nir_load_var(b, cluster_val_var);
}

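/* The two reduce_cluster callbacks used by lower_scan_reduce(): read the
 * cluster's partial result from its last fiber, or perform a reduction over
 * the cluster's fibers (the only ones active at that point).
 */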
static nir_def *
read_last(nir_builder *b, nir_op op, nir_def *val)
{
   return nir_read_getlast_ir3(b, val);
}

static nir_def *
reduce_clusters(nir_builder *b, nir_op op, nir_def *val)
{
   return nir_reduce_clusters_ir3(b, val, .reduction_op = op);
}

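/* Lower reduce/inclusive_scan/exclusive_scan: nir_brcst_active_ir3 combines
 * values within clusters of 2, 4 and 8 fibers (capped at the requested
 * cluster size for clustered reductions), and the per-cluster partial results
 * are then combined using the ir3 cross-cluster intrinsics or, for clustered
 * reductions, via foreach_cluster().
 */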
static nir_def *
lower_scan_reduce(struct nir_builder *b, nir_instr *instr, void *data)
{
   struct ir3_shader_variant *v = data;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   unsigned bit_size = intrin->def.bit_size;
   assert(bit_size < 64);

   nir_op op = nir_intrinsic_reduction_op(intrin);
   nir_const_value ident_val = nir_alu_binop_identity(op, bit_size);
   nir_def *ident = nir_build_imm(b, 1, bit_size, &ident_val);
   nir_def *inclusive = intrin->src[0].ssa;
   nir_def *exclusive = ident;
   unsigned cluster_size = nir_intrinsic_has_cluster_size(intrin)
                              ? nir_intrinsic_cluster_size(intrin)
                              : 0;
   bool clustered = cluster_size != 0;
   unsigned subgroup_size, max_subgroup_size;
   ir3_shader_get_subgroup_size(v->compiler, &v->shader_options, v->type,
                                &subgroup_size, &max_subgroup_size);

   if (subgroup_size == 0) {
      subgroup_size = max_subgroup_size;
   }

   /* Should have been lowered by nir_lower_subgroups. */
   assert(cluster_size != 1);

   /* Only clustered reduce operations are supported. */
   assert(intrin->intrinsic == nir_intrinsic_reduce || !clustered);

   unsigned max_brcst_cluster_size = clustered ? MIN2(cluster_size, 8) : 8;

   for (unsigned brcst_cluster_size = 2;
        brcst_cluster_size <= max_brcst_cluster_size; brcst_cluster_size *= 2) {
      nir_def *brcst = nir_brcst_active_ir3(b, ident, inclusive,
                                            .cluster_size = brcst_cluster_size);
      inclusive = nir_build_alu2(b, op, inclusive, brcst);

      if (intrin->intrinsic == nir_intrinsic_exclusive_scan)
         exclusive = nir_build_alu2(b, op, exclusive, brcst);
   }

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
      if (!clustered || cluster_size >= subgroup_size) {
         /* The normal (non-clustered) path does a full reduction of all brcst
          * clusters.
          */
         return nir_reduce_clusters_ir3(b, inclusive, .reduction_op = op);
      } else if (cluster_size <= 8) {
         /* After the brcsts have been executed, each brcst cluster has its
          * reduction in its last fiber. So if the cluster size is at most the
          * maximum brcst cluster size (8) we can simply iterate the clusters
          * and read the value from their last fibers.
          */
         return foreach_cluster(b, op, inclusive, cluster_size, read_last);
      } else {
         /* For larger clusters, we do a normal reduction for every cluster.
          */
         return foreach_cluster(b, op, inclusive, cluster_size,
                                reduce_clusters);
      }
   case nir_intrinsic_inclusive_scan:
      return nir_inclusive_scan_clusters_ir3(b, inclusive, .reduction_op = op);
   case nir_intrinsic_exclusive_scan:
      return nir_exclusive_scan_clusters_ir3(b, inclusive, exclusive,
                                             .reduction_op = op);
   default:
      unreachable("filtered intrinsic");
   }
}

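/* Lower reduce/scan intrinsics to the brcst.active-based sequence above. This
 * is gated on getfiberid support; without it the intrinsics are handled by
 * the fallback paths (e.g. the scan macros lowered at the top of this file).
 */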
bool
ir3_nir_opt_subgroups(nir_shader *nir, struct ir3_shader_variant *v)
{
   if (!v->compiler->has_getfiberid)
      return false;

   return nir_shader_lower_instructions(nir, filter_scan_reduce,
                                        lower_scan_reduce, v);
}

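/* Filter for the generic NIR subgroup lowering: returns true for the
 * operations ir3 has no native path for, i.e. reductions with a cluster size
 * of 1, clustered reductions without getfiberid, 64-bit
 * imul/imin/imax/umin/umax reductions and scans, vector reductions and scans,
 * and all other subgroup intrinsics covered by that pass.
 */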
bool
ir3_nir_lower_subgroups_filter(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   const struct ir3_compiler *compiler = data;

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
      if (nir_intrinsic_cluster_size(intrin) == 1) {
         return true;
      }
      if (nir_intrinsic_cluster_size(intrin) > 0 && !compiler->has_getfiberid) {
         return true;
      }
      FALLTHROUGH;
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      switch (nir_intrinsic_reduction_op(intrin)) {
      case nir_op_imul:
      case nir_op_imin:
      case nir_op_imax:
      case nir_op_umin:
      case nir_op_umax:
         if (intrin->def.bit_size == 64) {
            return true;
         }
         FALLTHROUGH;
      default:
         return intrin->def.num_components > 1;
      }
   default:
      return true;
   }
}

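/* Matches the shuffle intrinsics rewritten by lower_shuffle() below. */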
static bool
filter_shuffle(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_shuffle:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
   case nir_intrinsic_shuffle_xor:
      return true;
   default:
      return false;
   }
}

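/* Emit the uniform-index ir3 variant of a shuffle. Plain shuffle is mapped to
 * a rotate since lower_shuffle() has already rebased the index relative to
 * gl_SubgroupInvocationID.
 */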
static nir_def *
shuffle_to_uniform(nir_builder *b, nir_intrinsic_op op, struct nir_def *val,
                   struct nir_def *id)
{
   switch (op) {
   case nir_intrinsic_shuffle:
      return nir_rotate(b, val, id);
   case nir_intrinsic_shuffle_up:
      return nir_shuffle_up_uniform_ir3(b, val, id);
   case nir_intrinsic_shuffle_down:
      return nir_shuffle_down_uniform_ir3(b, val, id);
   case nir_intrinsic_shuffle_xor:
      return nir_shuffle_xor_uniform_ir3(b, val, id);
   default:
      unreachable("filtered intrinsic");
   }
}

/* Transforms a shuffle operation into a loop that only uses shuffles with
 * (dynamically) uniform indices. This is based on the blob's sequence and
 * carefully makes sure that the smallest number of iterations is performed
 * (i.e., one iteration per distinct index) while keeping all invocations
 * active during each shfl operation. This is necessary since shfl does not
 * update its dst when its src is inactive.
 *
 * done = false;
 * while (true) {
 *    next_index = read_invocation_cond_ir3(index, !done);
 *    shuffled = op_uniform(val, next_index);
 *
 *    if (index == next_index) {
 *       result = shuffled;
 *       done = true;
 *    }
 *
 *    if (subgroupAll(done)) {
 *       break;
 *    }
 * }
 */
static nir_def *
make_shuffle_uniform(nir_builder *b, nir_def *val, nir_def *index,
                     nir_intrinsic_op op)
{
   nir_variable *done =
      nir_local_variable_create(b->impl, glsl_bool_type(), "done");
   nir_store_var(b, done, nir_imm_false(b), 1);
   nir_variable *result =
      nir_local_variable_create(b->impl, glsl_type_for_def(val), "result");

   nir_loop *loop = nir_push_loop(b);
   {
      nir_def *next_index = nir_read_invocation_cond_ir3(
         b, index->bit_size, index, nir_inot(b, nir_load_var(b, done)));
      next_index->divergent = false;
      nir_def *shuffled = shuffle_to_uniform(b, op, val, next_index);

      nir_if *nif = nir_push_if(b, nir_ieq(b, index, next_index));
      {
         nir_store_var(b, result, shuffled, 1);
         nir_store_var(b, done, nir_imm_true(b), 1);
      }
      nir_pop_if(b, nif);

      nir_break_if(b, nir_vote_all(b, 1, nir_load_var(b, done)));
   }
   nir_pop_loop(b, loop);

   return nir_load_var(b, result);
}

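/* Lower a single shuffle intrinsic: plain shuffles are first rebased into a
 * rotate relative to gl_SubgroupInvocationID; then, if the index is already
 * dynamically uniform, the uniform ir3 variant is emitted directly, otherwise
 * the operation is wrapped in make_shuffle_uniform()'s loop.
 */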
static nir_def *
lower_shuffle(nir_builder *b, nir_instr *instr, void *data)
{
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   nir_def *val = intrin->src[0].ssa;
   nir_def *index = intrin->src[1].ssa;

   if (intrin->intrinsic == nir_intrinsic_shuffle) {
      /* The hw only does relative shuffles/rotates so transform shuffle(val, x)
       * into rotate(val, x - gl_SubgroupInvocationID) which is valid since we
       * make sure to only use it with uniform indices.
       */
      index = nir_isub(b, index, nir_load_subgroup_invocation(b));
   }

   if (!index->divergent) {
      return shuffle_to_uniform(b, intrin->intrinsic, val, index);
   }

   return make_shuffle_uniform(b, val, index, intrin->intrinsic);
}

/* Lower (relative) shuffles to be able to use the shfl instruction. One quirk
 * of shfl is that its index has to be dynamically uniform, so we transform the
 * standard NIR intrinsics into ir3-specific ones which require their index to
 * be uniform.
 */
bool
ir3_nir_lower_shuffle(nir_shader *nir, struct ir3_shader *shader)
{
   if (!shader->compiler->has_shfl) {
      return false;
   }

   nir_divergence_analysis(nir);
   return nir_shader_lower_instructions(nir, filter_shuffle, lower_shuffle,
                                        NULL);
}