/*
 * Copyright © 2021 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "ir3.h"
#include "ir3_nir.h"
#include "util/ralloc.h"

/* Lower several macro-instructions needed for shader subgroup support that
 * must be turned into if statements. We do this after RA and post-RA
 * scheduling to give the scheduler a chance to rearrange them, because RA
 * may need to insert OPC_META_READ_FIRST to handle splitting live ranges, and
 * also because some (e.g. BALLOT and READ_FIRST) must produce a shared
 * register that cannot be spilled to a normal register until after the if,
 * which makes implementing spilling more complicated if they are already
 * lowered.
 */

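/* Replace old_pred with new_pred in block's logical predecessor array; the
 * _physical_ variant below does the same for the physical predecessors.
 */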
static void
replace_pred(struct ir3_block *block, struct ir3_block *old_pred,
             struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->predecessors_count; i++) {
      if (block->predecessors[i] == old_pred) {
         block->predecessors[i] = new_pred;
         return;
      }
   }
}

static void
replace_physical_pred(struct ir3_block *block, struct ir3_block *old_pred,
                      struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->physical_predecessors_count; i++) {
      if (block->physical_predecessors[i] == old_pred) {
         block->physical_predecessors[i] = new_pred;
         return;
      }
   }
}

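/* Write an immediate into dst with a u16/u32 mov emitted right before the
 * block's terminator.
 */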
static void
mov_immed(struct ir3_register *dst, struct ir3_block *block, unsigned immed)
{
   struct ir3_instruction *mov =
      ir3_instr_create_at(ir3_before_terminator(block), OPC_MOV, 1, 1);
   struct ir3_register *mov_dst = ir3_dst_create(mov, dst->num, dst->flags);
   mov_dst->wrmask = dst->wrmask;
   struct ir3_register *src = ir3_src_create(
      mov, INVALID_REG, (dst->flags & IR3_REG_HALF) | IR3_REG_IMMED);
   src->uim_val = immed;
   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = mov->cat1.dst_type;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;
}

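/* Copy src into dst with a mov emitted right before the block's terminator,
 * preserving the half/shared flags of each side.
 */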
static void
mov_reg(struct ir3_block *block, struct ir3_register *dst,
        struct ir3_register *src)
{
   struct ir3_instruction *mov =
      ir3_instr_create_at(ir3_before_terminator(block), OPC_MOV, 1, 1);

   struct ir3_register *mov_dst =
      ir3_dst_create(mov, dst->num, dst->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   struct ir3_register *mov_src =
      ir3_src_create(mov, src->num, src->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   mov_dst->wrmask = dst->wrmask;
   mov_src->wrmask = src->wrmask;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;

   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = (src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
}

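/* Emit a two-source ALU instruction of the given opcode before the block's
 * terminator; triop() below is the three-source equivalent.
 */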
static void
binop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1)
{
   struct ir3_instruction *instr =
      ir3_instr_create_at(ir3_before_terminator(block), opc, 1, 2);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

static void
triop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1,
      struct ir3_register *src2)
{
   struct ir3_instruction *instr =
      ir3_instr_create_at(ir3_before_terminator(block), opc, 1, 3);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);
   struct ir3_register *instr_src2 = ir3_src_create(instr, src2->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr_src2->wrmask = src2->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

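/* Emit "dst = src0 OP src1" for the given reduction operation. 32-bit
 * unsigned multiplication has no single instruction and is expanded into the
 * usual mull.u/madsh.m16 sequence.
 */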
static void
do_reduce(struct ir3_block *block, reduce_op_t opc,
          struct ir3_register *dst, struct ir3_register *src0,
          struct ir3_register *src1)
{
   switch (opc) {
#define CASE(name)                                                             \
   case REDUCE_OP_##name:                                                      \
      binop(block, OPC_##name, dst, src0, src1);                               \
      break;

      CASE(ADD_U)
      CASE(ADD_F)
      CASE(MUL_F)
      CASE(MIN_U)
      CASE(MIN_S)
      CASE(MIN_F)
      CASE(MAX_U)
      CASE(MAX_S)
      CASE(MAX_F)
      CASE(AND_B)
      CASE(OR_B)
      CASE(XOR_B)

#undef CASE

   case REDUCE_OP_MUL_U:
      if (dst->flags & IR3_REG_HALF) {
         binop(block, OPC_MUL_S24, dst, src0, src1);
      } else {
         /* 32-bit multiplication macro - see ir3_nir_imul */
         binop(block, OPC_MULL_U, dst, src0, src1);
         triop(block, OPC_MADSH_M16, dst, src0, src1, dst);
         triop(block, OPC_MADSH_M16, dst, src1, src0, dst);
      }
      break;
   }
}

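/* Split before_block at instr: instr and everything after it, together with
 * all successor/physical-successor edges, move to a newly created after_block
 * inserted right after before_block in the block list.
 */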
static struct ir3_block *
split_block(struct ir3 *ir, struct ir3_block *before_block,
            struct ir3_instruction *instr)
{
   struct ir3_block *after_block = ir3_block_create(ir);
   list_add(&after_block->node, &before_block->node);

   for (unsigned i = 0; i < ARRAY_SIZE(before_block->successors); i++) {
      after_block->successors[i] = before_block->successors[i];
      if (after_block->successors[i])
         replace_pred(after_block->successors[i], before_block, after_block);
   }

   for (unsigned i = 0; i < before_block->physical_successors_count; i++) {
      replace_physical_pred(before_block->physical_successors[i],
                            before_block, after_block);
   }

   ralloc_steal(after_block, before_block->physical_successors);
   after_block->physical_successors = before_block->physical_successors;
   after_block->physical_successors_sz = before_block->physical_successors_sz;
   after_block->physical_successors_count =
      before_block->physical_successors_count;

   before_block->successors[0] = before_block->successors[1] = NULL;
   before_block->physical_successors = NULL;
   before_block->physical_successors_count = 0;
   before_block->physical_successors_sz = 0;

   foreach_instr_from_safe (rem_instr, &instr->node,
                            &before_block->instr_list) {
      list_del(&rem_instr->node);
      list_addtail(&rem_instr->node, &after_block->instr_list);
      rem_instr->block = after_block;
   }

   after_block->divergent_condition = before_block->divergent_condition;
   before_block->divergent_condition = false;
   return after_block;
}

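/* Make succ the index'th logical successor of pred and keep the predecessor
 * and physical CFG edges consistent.
 */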
static void
link_blocks(struct ir3_block *pred, struct ir3_block *succ, unsigned index)
{
   pred->successors[index] = succ;
   ir3_block_add_predecessor(succ, pred);
   ir3_block_link_physical(pred, succ);
}

static void
link_blocks_jump(struct ir3_block *pred, struct ir3_block *succ)
{
   struct ir3_builder build = ir3_builder_at(ir3_after_block(pred));
   ir3_JUMP(&build);
   link_blocks(pred, succ, 0);
}

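/* Terminate pred with a branch of the given opcode: target becomes the taken
 * successor (index 0) and fallthrough the not-taken one (index 1).
 */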
static void
link_blocks_branch(struct ir3_block *pred, struct ir3_block *target,
                   struct ir3_block *fallthrough, unsigned opc, unsigned flags,
                   struct ir3_instruction *condition)
{
   unsigned nsrc = condition ? 1 : 0;
   struct ir3_instruction *branch =
      ir3_instr_create_at(ir3_after_block(pred), opc, 0, nsrc);
   branch->flags |= flags;

   if (condition) {
      struct ir3_register *cond_dst = condition->dsts[0];
      struct ir3_register *src =
         ir3_src_create(branch, cond_dst->num, cond_dst->flags);
      src->def = cond_dst;
   }

   link_blocks(pred, target, 0);
   link_blocks(pred, fallthrough, 1);

   if (opc != OPC_BALL && opc != OPC_BANY) {
      pred->divergent_condition = true;
   }
}

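/* Create an empty then_block and wire up the if: before_block branches to
 * then_block with the given opcode and falls through to after_block, while
 * then_block jumps to after_block. Returns the new then_block.
 */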
static struct ir3_block *
create_if(struct ir3 *ir, struct ir3_block *before_block,
          struct ir3_block *after_block, unsigned opc, unsigned flags,
          struct ir3_instruction *condition)
{
   struct ir3_block *then_block = ir3_block_create(ir);
   list_add(&then_block->node, &before_block->node);

   link_blocks_branch(before_block, then_block, after_block, opc, flags,
                      condition);
   link_blocks_jump(then_block, after_block);

   return then_block;
}

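/* Lower a single subgroup macro. Returns true if the containing block was
 * split, in which case *block is updated to the block holding the remaining
 * instructions so the caller can restart iteration there.
 */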
static bool
lower_instr(struct ir3 *ir, struct ir3_block **block, struct ir3_instruction *instr)
{
   switch (instr->opc) {
   case OPC_BALLOT_MACRO:
   case OPC_ANY_MACRO:
   case OPC_ALL_MACRO:
   case OPC_ELECT_MACRO:
   case OPC_READ_COND_MACRO:
   case OPC_READ_GETLAST_MACRO:
   case OPC_SCAN_MACRO:
   case OPC_SCAN_CLUSTERS_MACRO:
      break;
   case OPC_READ_FIRST_MACRO:
      /* Moves to shared registers read the first active fiber, so we can just
       * turn read_first.macro into a move. However, we must still use the
       * macro and lower it late because in ir3_cp we need to distinguish
       * between moves where all source fibers contain the same value, which
       * can be copy propagated, and moves generated from API-level
       * ReadFirstInvocation, which cannot.
       */
      assert(instr->dsts[0]->flags & IR3_REG_SHARED);
      instr->opc = OPC_MOV;
      instr->cat1.dst_type = TYPE_U32;
      instr->cat1.src_type =
         (instr->srcs[0]->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
      return false;
   default:
      return false;
   }

   struct ir3_block *before_block = *block;
   struct ir3_block *after_block = split_block(ir, before_block, instr);

   if (instr->opc == OPC_SCAN_MACRO) {
      /* The pseudo-code for the scan macro is:
       *
       * while (true) {
       *    header:
       *    if (elect()) {
       *       exit:
       *       exclusive = reduce;
       *       inclusive = src OP exclusive;
       *       reduce = inclusive;
       *       break;
       *    }
       *    footer:
       * }
       *
       * This is based on the blob's sequence, and carefully crafted to avoid
       * using the shared register "reduce" except in move instructions, since
       * using it in the actual OP isn't possible for half-registers.
       */
      struct ir3_block *header = ir3_block_create(ir);
      list_add(&header->node, &before_block->node);

      struct ir3_block *exit = ir3_block_create(ir);
      list_add(&exit->node, &header->node);

      struct ir3_block *footer = ir3_block_create(ir);
      list_add(&footer->node, &exit->node);
      footer->reconvergence_point = true;

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, header);

      link_blocks_branch(header, exit, footer, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      link_blocks_jump(exit, after_block);
      ir3_block_link_physical(exit, footer);

      link_blocks_jump(footer, header);

      struct ir3_register *exclusive = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *reduce = instr->dsts[2];
      struct ir3_register *src = instr->srcs[0];

      mov_reg(exit, exclusive, reduce);
      do_reduce(exit, instr->cat1.reduce_op, inclusive, src, exclusive);
      mov_reg(exit, reduce, inclusive);
   } else if (instr->opc == OPC_SCAN_CLUSTERS_MACRO) {
      /* The pseudo-code for the clustered scan macro is:
       *
       * while (true) {
       *    body:
       *    scratch = reduce;
       *
       *    inclusive = inclusive_src OP scratch;
       *
       *    static if (is exclusive scan)
       *       exclusive = exclusive_src OP scratch
       *
       *    if (getlast()) {
       *       store:
       *       reduce = inclusive;
       *       if (elect())
       *          break;
       *    } else {
       *       break;
       *    }
       * }
       * after_block:
       */
      struct ir3_block *body = ir3_block_create(ir);
      list_add(&body->node, &before_block->node);

      struct ir3_block *store = ir3_block_create(ir);
      list_add(&store->node, &body->node);

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, body);

      link_blocks_branch(body, store, after_block, OPC_GETLAST, 0, NULL);

      link_blocks_branch(store, after_block, body, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      struct ir3_register *reduce = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *inclusive_src = instr->srcs[1];

      /* We need to perform the following operations:
       * - inclusive = inclusive_src OP reduce
       * - exclusive = exclusive_src OP reduce (iff exclusive scan)
       * Since reduce is initially in a shared register, we need to copy it to
       * a scratch register before performing the operations.
       *
       * The scratch register used is:
       * - an explicitly allocated one if op is 32b mul_u.
       *   - necessary because we cannot do 'foo = foo mul_u bar' since mul_u
       *     clobbers its destination.
       * - exclusive if this is an exclusive scan (and not 32b mul_u).
       *   - since we calculate inclusive first.
       * - inclusive otherwise.
       *
       * In all cases, this is the last destination.
       */
      struct ir3_register *scratch = instr->dsts[instr->dsts_count - 1];

      mov_reg(body, scratch, reduce);
      do_reduce(body, instr->cat1.reduce_op, inclusive, inclusive_src, scratch);

      /* exclusive scan */
      if (instr->srcs_count == 3) {
         struct ir3_register *exclusive_src = instr->srcs[2];
         struct ir3_register *exclusive = instr->dsts[2];
         do_reduce(body, instr->cat1.reduce_op, exclusive, exclusive_src,
                   scratch);
      }

      mov_reg(store, reduce, inclusive);
   } else {
      /* For ballot, the destination must be initialized to 0 before we do
       * the movmsk because the condition may be 0 and then the movmsk will
       * be skipped.
       */
      if (instr->opc == OPC_BALLOT_MACRO) {
         mov_immed(instr->dsts[0], before_block, 0);
      }

      struct ir3_instruction *condition = NULL;
      unsigned branch_opc = 0;
      unsigned branch_flags = 0;

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ALL_MACRO:
         condition = instr->srcs[0]->def->instr;
         break;
      default:
         break;
      }

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_BR;
         break;
      case OPC_ANY_MACRO:
         branch_opc = OPC_BANY;
         break;
      case OPC_ALL_MACRO:
         branch_opc = OPC_BALL;
         break;
      case OPC_ELECT_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_GETONE;
         branch_flags = instr->flags & IR3_INSTR_NEEDS_HELPERS;
         break;
      case OPC_READ_GETLAST_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_GETLAST;
         branch_flags = instr->flags & IR3_INSTR_NEEDS_HELPERS;
         break;
      default:
         unreachable("bad opcode");
      }

      struct ir3_block *then_block =
         create_if(ir, before_block, after_block, branch_opc, branch_flags,
                   condition);

      switch (instr->opc) {
      case OPC_ALL_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ELECT_MACRO:
         mov_immed(instr->dsts[0], then_block, 1);
         mov_immed(instr->dsts[0], before_block, 0);
         break;

      case OPC_BALLOT_MACRO: {
         unsigned wrmask = instr->dsts[0]->wrmask;
         unsigned comp_count = util_last_bit(wrmask);
         struct ir3_instruction *movmsk = ir3_instr_create_at(
            ir3_before_terminator(then_block), OPC_MOVMSK, 1, 0);
         struct ir3_register *dst =
            ir3_dst_create(movmsk, instr->dsts[0]->num, instr->dsts[0]->flags);
         dst->wrmask = wrmask;
         movmsk->repeat = comp_count - 1;
         break;
      }

      case OPC_READ_GETLAST_MACRO:
      case OPC_READ_COND_MACRO: {
         struct ir3_instruction *mov = ir3_instr_create_at(
            ir3_before_terminator(then_block), OPC_MOV, 1, 1);
         ir3_dst_create(mov, instr->dsts[0]->num, instr->dsts[0]->flags);
         struct ir3_register *new_src = ir3_src_create(mov, 0, 0);
         unsigned idx = instr->opc == OPC_READ_COND_MACRO ? 1 : 0;
         *new_src = *instr->srcs[idx];
         mov->cat1.dst_type = TYPE_U32;
         mov->cat1.src_type =
            (new_src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
         mov->flags |= IR3_INSTR_NEEDS_HELPERS;
         break;
      }

      default:
         unreachable("bad opcode");
      }
   }

   *block = after_block;
   list_delinit(&instr->node);
   return true;
}

static bool
lower_block(struct ir3 *ir, struct ir3_block **block)
{
   bool progress = false;

   bool inner_progress;
   do {
      inner_progress = false;
      foreach_instr (instr, &(*block)->instr_list) {
         if (lower_instr(ir, block, instr)) {
            /* restart the loop with the new block we created because the
             * iterator has been invalidated.
             */
            progress = inner_progress = true;
            break;
         }
      }
   } while (inner_progress);

   return progress;
}

bool
ir3_lower_subgroups(struct ir3 *ir)
{
   bool progress = false;

   foreach_block (block, &ir->block_list)
      progress |= lower_block(ir, &block);

   return progress;
}

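/* Scalar GLSL type matching a NIR def, used for the local variables created
 * by the lowering helpers below.
 */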
static const struct glsl_type *
glsl_type_for_def(nir_def *def)
{
   assert(def->num_components == 1);
   return def->bit_size == 1 ? glsl_bool_type()
                             : glsl_uintN_t_type(def->bit_size);
}

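/* Match the reduce/scan intrinsics handled by lower_scan_reduce(). */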
static bool
filter_scan_reduce(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      return true;
   default:
      return false;
   }
}

typedef nir_def *(*reduce_cluster)(nir_builder *, nir_op, nir_def *);

/* Execute `reduce` for each cluster in the subgroup with only the invocations
 * in the current cluster active.
 */
static nir_def *
foreach_cluster(nir_builder *b, nir_op op, nir_def *inclusive,
                unsigned cluster_size, reduce_cluster reduce)
{
   nir_def *id = nir_load_subgroup_invocation(b);
   nir_def *cluster_size_imm = nir_imm_int(b, cluster_size);

   /* cur_cluster_end = cluster_size;
    * while (true) {
    *    if (gl_SubgroupInvocationID < cur_cluster_end) {
    *       cluster_val = reduce(inclusive);
    *       break;
    *    }
    *
    *    cur_cluster_end += cluster_size;
    * }
    */
   nir_variable *cur_cluster_end_var =
      nir_local_variable_create(b->impl, glsl_uint_type(), "cur_cluster_end");
   nir_store_var(b, cur_cluster_end_var, cluster_size_imm, 1);
   nir_variable *cluster_val_var = nir_local_variable_create(
      b->impl, glsl_type_for_def(inclusive), "cluster_val");

   nir_loop *loop = nir_push_loop(b);
   {
      nir_def *cur_cluster_end = nir_load_var(b, cur_cluster_end_var);
      nir_def *in_cur_cluster = nir_ult(b, id, cur_cluster_end);

      nir_if *nif = nir_push_if(b, in_cur_cluster);
      {
         nir_def *reduced = reduce(b, op, inclusive);
         nir_store_var(b, cluster_val_var, reduced, 1);
         nir_jump(b, nir_jump_break);
      }
      nir_pop_if(b, nif);

      nir_def *next_cluster_end =
         nir_iadd(b, cur_cluster_end, cluster_size_imm);
      nir_store_var(b, cur_cluster_end_var, next_cluster_end, 1);
   }
   nir_pop_loop(b, loop);

   return nir_load_var(b, cluster_val_var);
}

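/* reduce_cluster callback that reads a cluster's reduction from its last
 * active fiber.
 */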
static nir_def *
read_last(nir_builder *b, nir_op op, nir_def *val)
{
   return nir_read_getlast_ir3(b, val);
}

static nir_def *
reduce_clusters(nir_builder *b, nir_op op, nir_def *val)
{
   return nir_reduce_clusters_ir3(b, val, .reduction_op = op);
}

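/* Lower reduce/inclusive_scan/exclusive_scan: combine values across brcst
 * clusters of doubling size (2, 4, 8) and then finish either with the
 * ir3-specific cluster intrinsics or, for clustered reductions, with a
 * per-cluster loop.
 */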
static nir_def *
lower_scan_reduce(struct nir_builder *b, nir_instr *instr, void *data)
{
   struct ir3_shader_variant *v = data;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   unsigned bit_size = intrin->def.bit_size;
   assert(bit_size < 64);

   nir_op op = nir_intrinsic_reduction_op(intrin);
   nir_const_value ident_val = nir_alu_binop_identity(op, bit_size);
   nir_def *ident = nir_build_imm(b, 1, bit_size, &ident_val);
   nir_def *inclusive = intrin->src[0].ssa;
   nir_def *exclusive = ident;
   unsigned cluster_size = nir_intrinsic_has_cluster_size(intrin)
                              ? nir_intrinsic_cluster_size(intrin)
                              : 0;
   bool clustered = cluster_size != 0;
   unsigned subgroup_size, max_subgroup_size;
   ir3_shader_get_subgroup_size(v->compiler, &v->shader_options, v->type,
                                &subgroup_size, &max_subgroup_size);

   if (subgroup_size == 0) {
      subgroup_size = max_subgroup_size;
   }

   /* Should have been lowered by nir_lower_subgroups. */
   assert(cluster_size != 1);

   /* Only clustered reduce operations are supported. */
   assert(intrin->intrinsic == nir_intrinsic_reduce || !clustered);

   unsigned max_brcst_cluster_size = clustered ? MIN2(cluster_size, 8) : 8;

   for (unsigned brcst_cluster_size = 2;
        brcst_cluster_size <= max_brcst_cluster_size; brcst_cluster_size *= 2) {
      nir_def *brcst = nir_brcst_active_ir3(b, ident, inclusive,
                                            .cluster_size = brcst_cluster_size);
      inclusive = nir_build_alu2(b, op, inclusive, brcst);

      if (intrin->intrinsic == nir_intrinsic_exclusive_scan)
         exclusive = nir_build_alu2(b, op, exclusive, brcst);
   }

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
      if (!clustered || cluster_size >= subgroup_size) {
         /* The normal (non-clustered) path does a full reduction of all brcst
          * clusters.
          */
         return nir_reduce_clusters_ir3(b, inclusive, .reduction_op = op);
      } else if (cluster_size <= 8) {
         /* After the brcsts have been executed, each brcst cluster has its
          * reduction in its last fiber. So if the cluster size is at most the
          * maximum brcst cluster size (8) we can simply iterate the clusters
          * and read the value from their last fibers.
          */
         return foreach_cluster(b, op, inclusive, cluster_size, read_last);
      } else {
         /* For larger clusters, we do a normal reduction for every cluster.
          */
         return foreach_cluster(b, op, inclusive, cluster_size,
                                reduce_clusters);
      }
   case nir_intrinsic_inclusive_scan:
      return nir_inclusive_scan_clusters_ir3(b, inclusive, .reduction_op = op);
   case nir_intrinsic_exclusive_scan:
      return nir_exclusive_scan_clusters_ir3(b, inclusive, exclusive,
                                             .reduction_op = op);
   default:
      unreachable("filtered intrinsic");
   }
}

bool
ir3_nir_opt_subgroups(nir_shader *nir, struct ir3_shader_variant *v)
{
   if (!v->compiler->has_getfiberid)
      return false;

   return nir_shader_lower_instructions(nir, filter_scan_reduce,
                                        lower_scan_reduce, v);
}

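/* Filter for nir_lower_subgroups(): returns true for the subgroup intrinsics
 * that should be lowered generically instead of being handled by the backend
 * paths above.
 */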
bool
ir3_nir_lower_subgroups_filter(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   const struct ir3_compiler *compiler = data;

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
      if (nir_intrinsic_cluster_size(intrin) == 1) {
         return true;
      }
      if (nir_intrinsic_cluster_size(intrin) > 0 && !compiler->has_getfiberid) {
         return true;
      }
      FALLTHROUGH;
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      switch (nir_intrinsic_reduction_op(intrin)) {
      case nir_op_imul:
      case nir_op_imin:
      case nir_op_imax:
      case nir_op_umin:
      case nir_op_umax:
         if (intrin->def.bit_size == 64) {
            return true;
         }
         FALLTHROUGH;
      default:
         return intrin->def.num_components > 1;
      }
   default:
      return true;
   }
}

static bool
filter_shuffle(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_shuffle:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
   case nir_intrinsic_shuffle_xor:
      return true;
   default:
      return false;
   }
}

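/* Emit the ir3 uniform-index variant of a shuffle intrinsic. Plain shuffles
 * have already been rewritten into relative rotates by the caller.
 */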
static nir_def *
shuffle_to_uniform(nir_builder *b, nir_intrinsic_op op, struct nir_def *val,
                   struct nir_def *id)
{
   switch (op) {
   case nir_intrinsic_shuffle:
      return nir_rotate(b, val, id);
   case nir_intrinsic_shuffle_up:
      return nir_shuffle_up_uniform_ir3(b, val, id);
   case nir_intrinsic_shuffle_down:
      return nir_shuffle_down_uniform_ir3(b, val, id);
   case nir_intrinsic_shuffle_xor:
      return nir_shuffle_xor_uniform_ir3(b, val, id);
   default:
      unreachable("filtered intrinsic");
   }
}

/* Transforms a shuffle operation into a loop that only uses shuffles with
 * (dynamically) uniform indices. This is based on the blob's sequence and
 * carefully makes sure that the minimum number of iterations is performed
 * (i.e., one iteration per distinct index) while keeping all invocations
 * active during each shfl operation. This is necessary since shfl does not
 * update its dst when its src is inactive.
 *
 * done = false;
 * while (true) {
 *    next_index = read_invocation_cond_ir3(index, !done);
 *    shuffled = op_uniform(val, next_index);
 *
 *    if (index == next_index) {
 *       result = shuffled;
 *       done = true;
 *    }
 *
 *    if (subgroupAll(done)) {
 *       break;
 *    }
 * }
 */
static nir_def *
make_shuffle_uniform(nir_builder *b, nir_def *val, nir_def *index,
                     nir_intrinsic_op op)
{
   nir_variable *done =
      nir_local_variable_create(b->impl, glsl_bool_type(), "done");
   nir_store_var(b, done, nir_imm_false(b), 1);
   nir_variable *result =
      nir_local_variable_create(b->impl, glsl_type_for_def(val), "result");

   nir_loop *loop = nir_push_loop(b);
   {
      nir_def *next_index = nir_read_invocation_cond_ir3(
         b, index->bit_size, index, nir_inot(b, nir_load_var(b, done)));
      next_index->divergent = false;
      nir_def *shuffled = shuffle_to_uniform(b, op, val, next_index);

      nir_if *nif = nir_push_if(b, nir_ieq(b, index, next_index));
      {
         nir_store_var(b, result, shuffled, 1);
         nir_store_var(b, done, nir_imm_true(b), 1);
      }
      nir_pop_if(b, nif);

      nir_break_if(b, nir_vote_all(b, 1, nir_load_var(b, done)));
   }
   nir_pop_loop(b, loop);

   return nir_load_var(b, result);
}

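/* Rewrite one shuffle intrinsic: convert absolute shuffles into relative
 * rotates, then either emit the uniform-index variant directly or fall back
 * to the loop above when the index is divergent.
 */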
static nir_def *
lower_shuffle(nir_builder *b, nir_instr *instr, void *data)
{
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   nir_def *val = intrin->src[0].ssa;
   nir_def *index = intrin->src[1].ssa;

   if (intrin->intrinsic == nir_intrinsic_shuffle) {
      /* The hw only does relative shuffles/rotates, so transform
       * shuffle(val, x) into rotate(val, x - gl_SubgroupInvocationID), which
       * is valid since we make sure to only use it with uniform indices.
       */
      index = nir_isub(b, index, nir_load_subgroup_invocation(b));
   }

   if (!index->divergent) {
      return shuffle_to_uniform(b, intrin->intrinsic, val, index);
   }

   return make_shuffle_uniform(b, val, index, intrin->intrinsic);
}

/* Lower (relative) shuffles to be able to use the shfl instruction. One quirk
 * of shfl is that its index has to be dynamically uniform, so we transform the
 * standard NIR intrinsics into ir3-specific ones which require their index to
 * be uniform.
 */
bool
ir3_nir_lower_shuffle(nir_shader *nir, struct ir3_shader *shader)
{
   if (!shader->compiler->has_shfl) {
      return false;
   }

   nir_divergence_analysis(nir);
   return nir_shader_lower_instructions(nir, filter_shuffle, lower_shuffle,
                                        NULL);
}