/*
 * Copyright (C) 2021 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "ir3.h"
#include "ir3_nir.h"
#include "util/ralloc.h"

/* Lower several macro-instructions needed for shader subgroup support that
 * must be turned into if statements. We do this after RA and post-RA
 * scheduling to give the scheduler a chance to rearrange them, because RA
 * may need to insert OPC_META_READ_FIRST to handle splitting live ranges, and
 * also because some (e.g. BALLOT and READ_FIRST) must produce a shared
 * register that cannot be spilled to a normal register until after the if,
 * which makes implementing spilling more complicated if they are already
 * lowered.
 */
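
/* The generic lowering turns a single macro instruction in "before_block"
 * into an if-shaped piece of CFG (see create_if()):
 *
 *    before_block:  ... branch ...
 *    then_block:    lowered instruction sequence
 *    after_block:   rest of before_block
 *
 * where before_block's successor 0 is then_block (taken) and successor 1 is
 * after_block, and then_block falls through to after_block. The SCAN and
 * SCAN_CLUSTERS macros instead build small loops; see lower_instr().
 */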

static void
replace_pred(struct ir3_block *block, struct ir3_block *old_pred,
             struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->predecessors_count; i++) {
      if (block->predecessors[i] == old_pred) {
         block->predecessors[i] = new_pred;
         return;
      }
   }
}

static void
replace_physical_pred(struct ir3_block *block, struct ir3_block *old_pred,
                      struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->physical_predecessors_count; i++) {
      if (block->physical_predecessors[i] == old_pred) {
         block->physical_predecessors[i] = new_pred;
         return;
      }
   }
}

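/* Append "mov dst, #immed" to block, replicating the immediate across all
 * components covered by dst->wrmask via the (rpt) field.
 */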
static void
mov_immed(struct ir3_register *dst, struct ir3_block *block, unsigned immed)
{
   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);
   struct ir3_register *mov_dst = ir3_dst_create(mov, dst->num, dst->flags);
   mov_dst->wrmask = dst->wrmask;
   struct ir3_register *src = ir3_src_create(
      mov, INVALID_REG, (dst->flags & IR3_REG_HALF) | IR3_REG_IMMED);
   src->uim_val = immed;
   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = mov->cat1.dst_type;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;
}

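/* Append a register-to-register mov to block, preserving half/shared
 * register flags and the destination write-mask.
 */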
static void
mov_reg(struct ir3_block *block, struct ir3_register *dst,
        struct ir3_register *src)
{
   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);

   struct ir3_register *mov_dst =
      ir3_dst_create(mov, dst->num, dst->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   struct ir3_register *mov_src =
      ir3_src_create(mov, src->num, src->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   mov_dst->wrmask = dst->wrmask;
   mov_src->wrmask = src->wrmask;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;

   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = (src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
}

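/* Append a two-source ALU instruction "opc dst, src0, src1" to block,
 * propagating the half-register flag and the operands' write-masks.
 */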
static void
binop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1)
{
   struct ir3_instruction *instr = ir3_instr_create(block, opc, 1, 2);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

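/* Same as binop(), but for three-source instructions such as madsh.m16. */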
static void
triop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1,
      struct ir3_register *src2)
{
   struct ir3_instruction *instr = ir3_instr_create(block, opc, 1, 3);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);
   struct ir3_register *instr_src2 = ir3_src_create(instr, src2->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr_src2->wrmask = src2->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

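/* Emit the ALU sequence for one combining step of a subgroup
 * reduction/scan, i.e. dst = src0 OP src1 for the given reduce_op_t. Every
 * op maps to a single instruction except 32-bit MUL_U, which needs the
 * multi-instruction expansion below.
 */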
static void
do_reduce(struct ir3_block *block, reduce_op_t opc,
          struct ir3_register *dst, struct ir3_register *src0,
          struct ir3_register *src1)
{
   switch (opc) {
#define CASE(name)                                                            \
   case REDUCE_OP_##name:                                                     \
      binop(block, OPC_##name, dst, src0, src1);                              \
      break;

      CASE(ADD_U)
      CASE(ADD_F)
      CASE(MUL_F)
      CASE(MIN_U)
      CASE(MIN_S)
      CASE(MIN_F)
      CASE(MAX_U)
      CASE(MAX_S)
      CASE(MAX_F)
      CASE(AND_B)
      CASE(OR_B)
      CASE(XOR_B)

#undef CASE

   case REDUCE_OP_MUL_U:
      if (dst->flags & IR3_REG_HALF) {
         binop(block, OPC_MUL_S24, dst, src0, src1);
      } else {
         /* 32-bit multiplication macro - see ir3_nir_imul */
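         /* Roughly:
          *    dst  = lo16(src0) * lo16(src1)
          *    dst += (hi16(src0) * lo16(src1)) << 16
          *    dst += (hi16(src1) * lo16(src0)) << 16
          * which is a full 32-bit multiply mod 2^32, assuming the usual
          * mull.u/madsh.m16 semantics.
          */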
         binop(block, OPC_MULL_U, dst, src0, src1);
         triop(block, OPC_MADSH_M16, dst, src0, src1, dst);
         triop(block, OPC_MADSH_M16, dst, src1, src0, dst);
      }
      break;
   }
}

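/* Split before_block at instr: instr and everything after it move into a
 * newly-created block which also inherits before_block's successors,
 * physical successors, branch type and condition. before_block is left
 * with no successors so the caller can re-link it.
 */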
static struct ir3_block *
split_block(struct ir3 *ir, struct ir3_block *before_block,
            struct ir3_instruction *instr)
{
   struct ir3_block *after_block = ir3_block_create(ir);
   list_add(&after_block->node, &before_block->node);

   for (unsigned i = 0; i < ARRAY_SIZE(before_block->successors); i++) {
      after_block->successors[i] = before_block->successors[i];
      if (after_block->successors[i])
         replace_pred(after_block->successors[i], before_block, after_block);
   }

   for (unsigned i = 0; i < before_block->physical_successors_count; i++) {
      replace_physical_pred(before_block->physical_successors[i],
                            before_block, after_block);
   }

   ralloc_steal(after_block, before_block->physical_successors);
   after_block->physical_successors = before_block->physical_successors;
   after_block->physical_successors_sz = before_block->physical_successors_sz;
   after_block->physical_successors_count =
      before_block->physical_successors_count;

   before_block->successors[0] = before_block->successors[1] = NULL;
   before_block->physical_successors = NULL;
   before_block->physical_successors_count = 0;
   before_block->physical_successors_sz = 0;

   foreach_instr_from_safe (rem_instr, &instr->node,
                            &before_block->instr_list) {
      list_del(&rem_instr->node);
      list_addtail(&rem_instr->node, &after_block->instr_list);
      rem_instr->block = after_block;
   }

   after_block->brtype = before_block->brtype;
   after_block->condition = before_block->condition;

   return after_block;
}

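/* Make succ the index'th successor of pred, keeping the predecessor and
 * physical-edge bookkeeping consistent.
 */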
static void
link_blocks(struct ir3_block *pred, struct ir3_block *succ, unsigned index)
{
   pred->successors[index] = succ;
   ir3_block_add_predecessor(succ, pred);
   ir3_block_link_physical(pred, succ);
}

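/* Insert an empty "then" block between before_block and after_block:
 *
 *    before_block -> then_block  (successor 0, taken)
 *    before_block -> after_block (successor 1)
 *    then_block   -> after_block
 *
 * The caller fills in before_block's branch type/condition and then_block's
 * body.
 */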
static struct ir3_block *
create_if(struct ir3 *ir, struct ir3_block *before_block,
          struct ir3_block *after_block)
{
   struct ir3_block *then_block = ir3_block_create(ir);
   list_add(&then_block->node, &before_block->node);

   link_blocks(before_block, then_block, 0);
   link_blocks(before_block, after_block, 1);
   link_blocks(then_block, after_block, 0);

   return then_block;
}

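/* Lower a single subgroup macro. Returns true if instr was replaced by new
 * control flow, in which case *block is updated to the block following the
 * inserted CFG and the caller must restart its instruction walk. READ_FIRST
 * is rewritten into a plain mov without creating any blocks.
 */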
static bool
lower_instr(struct ir3 *ir, struct ir3_block **block, struct ir3_instruction *instr)
{
   switch (instr->opc) {
   case OPC_BALLOT_MACRO:
   case OPC_ANY_MACRO:
   case OPC_ALL_MACRO:
   case OPC_ELECT_MACRO:
   case OPC_READ_COND_MACRO:
   case OPC_SWZ_SHARED_MACRO:
   case OPC_SCAN_MACRO:
   case OPC_SCAN_CLUSTERS_MACRO:
      break;
   case OPC_READ_FIRST_MACRO:
      /* Moves to shared registers read the first active fiber, so we can just
       * turn read_first.macro into a move. However we must still use the macro
       * and lower it late because in ir3_cp we need to distinguish between
       * moves where all source fibers contain the same value, which can be
       * copy propagated, and moves generated from API-level
       * ReadFirstInvocation which cannot.
       */
      assert(instr->dsts[0]->flags & IR3_REG_SHARED);
      instr->opc = OPC_MOV;
      instr->cat1.dst_type = TYPE_U32;
      instr->cat1.src_type =
         (instr->srcs[0]->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
      return false;
   default:
      return false;
   }

   struct ir3_block *before_block = *block;
   struct ir3_block *after_block = split_block(ir, before_block, instr);

   if (instr->opc == OPC_SCAN_MACRO) {
      /* The pseudo-code for the scan macro is:
       *
       * while (true) {
       *    header:
       *    if (elect()) {
       *       exit:
       *       exclusive = reduce;
       *       inclusive = src OP exclusive;
       *       reduce = inclusive;
       *       break;
       *    }
       *    footer:
       * }
       *
       * This is based on the blob's sequence, and carefully crafted to avoid
       * using the shared register "reduce" except in move instructions, since
       * using it in the actual OP isn't possible for half-registers.
       */
      struct ir3_block *header = ir3_block_create(ir);
      list_add(&header->node, &before_block->node);

      struct ir3_block *exit = ir3_block_create(ir);
      list_add(&exit->node, &header->node);

      struct ir3_block *footer = ir3_block_create(ir);
      list_add(&footer->node, &exit->node);
      footer->reconvergence_point = true;

      after_block->reconvergence_point = true;

      link_blocks(before_block, header, 0);

      link_blocks(header, exit, 0);
      link_blocks(header, footer, 1);
      header->brtype = IR3_BRANCH_GETONE;

      link_blocks(exit, after_block, 0);
      ir3_block_link_physical(exit, footer);

      link_blocks(footer, header, 0);

      struct ir3_register *exclusive = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *reduce = instr->dsts[2];
      struct ir3_register *src = instr->srcs[0];

      mov_reg(exit, exclusive, reduce);
      do_reduce(exit, instr->cat1.reduce_op, inclusive, src, exclusive);
      mov_reg(exit, reduce, inclusive);
   } else if (instr->opc == OPC_SCAN_CLUSTERS_MACRO) {
      /* The pseudo-code for the scan clusters macro is:
       *
       * while (true) {
       *    body:
       *    scratch = reduce;
       *
       *    inclusive = inclusive_src OP scratch;
       *
       *    static if (is exclusive scan)
       *       exclusive = exclusive_src OP scratch
       *
       *    if (getlast()) {
       *       store:
       *       reduce = inclusive;
       *       if (elect())
       *          break;
       *    } else {
       *       break;
       *    }
       * }
       * after_block:
       */
      struct ir3_block *body = ir3_block_create(ir);
      list_add(&body->node, &before_block->node);

      struct ir3_block *store = ir3_block_create(ir);
      list_add(&store->node, &body->node);

      body->reconvergence_point = true;
      after_block->reconvergence_point = true;

      link_blocks(before_block, body, 0);

      link_blocks(body, store, 0);
      link_blocks(body, after_block, 1);
      body->brtype = IR3_BRANCH_GETLAST;

      link_blocks(store, after_block, 0);
      link_blocks(store, body, 1);
      store->brtype = IR3_BRANCH_GETONE;

      struct ir3_register *reduce = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *inclusive_src = instr->srcs[1];

      /* We need to perform the following operations:
       * - inclusive = inclusive_src OP reduce
       * - exclusive = exclusive_src OP reduce (iff exclusive scan)
       * Since reduce is initially in a shared register, we need to copy it to
       * a scratch register before performing the operations.
       *
       * The scratch register used is:
       * - an explicitly allocated one if op is 32b mul_u.
       *   - necessary because we cannot do 'foo = foo mul_u bar' since mul_u
       *     clobbers its destination.
       * - exclusive if this is an exclusive scan (and not 32b mul_u).
       *   - since we calculate inclusive first.
       * - inclusive otherwise.
       *
       * In all cases, this is the last destination.
       */
      struct ir3_register *scratch = instr->dsts[instr->dsts_count - 1];

      mov_reg(body, scratch, reduce);
      do_reduce(body, instr->cat1.reduce_op, inclusive, inclusive_src, scratch);

      /* exclusive scan */
      if (instr->srcs_count == 3) {
         struct ir3_register *exclusive_src = instr->srcs[2];
         struct ir3_register *exclusive = instr->dsts[2];
         do_reduce(body, instr->cat1.reduce_op, exclusive, exclusive_src,
                   scratch);
      }

      mov_reg(store, reduce, inclusive);
   } else {
      struct ir3_block *then_block = create_if(ir, before_block, after_block);

      /* For ballot, the destination must be initialized to 0 before we do
       * the movmsk because the condition may be 0 and then the movmsk will
       * be skipped. Because it's a shared register we have to wrap the
       * initialization in a getone block.
       */
      if (instr->opc == OPC_BALLOT_MACRO) {
         before_block->brtype = IR3_BRANCH_GETONE;
         before_block->condition = NULL;
         mov_immed(instr->dsts[0], then_block, 0);
         after_block->reconvergence_point = true;
         before_block = after_block;
         after_block = split_block(ir, before_block, instr);
         then_block = create_if(ir, before_block, after_block);
      }

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ALL_MACRO:
         before_block->condition = instr->srcs[0]->def->instr;
         break;
      default:
         before_block->condition = NULL;
         break;
      }

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
         before_block->brtype = IR3_BRANCH_COND;
         after_block->reconvergence_point = true;
         break;
      case OPC_ANY_MACRO:
         before_block->brtype = IR3_BRANCH_ANY;
         break;
      case OPC_ALL_MACRO:
         before_block->brtype = IR3_BRANCH_ALL;
         break;
      case OPC_ELECT_MACRO:
      case OPC_SWZ_SHARED_MACRO:
         before_block->brtype = IR3_BRANCH_GETONE;
         after_block->reconvergence_point = true;
         break;
      default:
         unreachable("bad opcode");
      }

      switch (instr->opc) {
      case OPC_ALL_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ELECT_MACRO:
         mov_immed(instr->dsts[0], then_block, 1);
         mov_immed(instr->dsts[0], before_block, 0);
         break;

      case OPC_BALLOT_MACRO: {
         unsigned comp_count = util_last_bit(instr->dsts[0]->wrmask);
         struct ir3_instruction *movmsk =
            ir3_instr_create(then_block, OPC_MOVMSK, 1, 0);
         ir3_dst_create(movmsk, instr->dsts[0]->num, instr->dsts[0]->flags);
         movmsk->repeat = comp_count - 1;
         break;
      }

      case OPC_READ_COND_MACRO: {
         struct ir3_instruction *mov =
            ir3_instr_create(then_block, OPC_MOV, 1, 1);
         ir3_dst_create(mov, instr->dsts[0]->num, instr->dsts[0]->flags);
         struct ir3_register *new_src = ir3_src_create(mov, 0, 0);
         *new_src = *instr->srcs[1];
         mov->cat1.dst_type = TYPE_U32;
         mov->cat1.src_type =
            (new_src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
         break;
      }

      case OPC_SWZ_SHARED_MACRO: {
         struct ir3_instruction *swz =
            ir3_instr_create(then_block, OPC_SWZ, 2, 2);
         ir3_dst_create(swz, instr->dsts[0]->num, instr->dsts[0]->flags);
         ir3_dst_create(swz, instr->dsts[1]->num, instr->dsts[1]->flags);
         ir3_src_create(swz, instr->srcs[0]->num, instr->srcs[0]->flags);
         ir3_src_create(swz, instr->srcs[1]->num, instr->srcs[1]->flags);
         swz->cat1.dst_type = swz->cat1.src_type = TYPE_U32;
         swz->repeat = 1;
         break;
      }

      default:
         unreachable("bad opcode");
      }
   }

   *block = after_block;
   list_delinit(&instr->node);
   return true;
}

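/* Lower all subgroup macros in *block, restarting the instruction walk each
 * time lower_instr() splits the block, since that invalidates the iterator.
 * On return, *block points at the last block of the lowered sequence.
 */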
static bool
lower_block(struct ir3 *ir, struct ir3_block **block)
{
   bool progress = false;

   bool inner_progress;
   do {
      inner_progress = false;
      foreach_instr (instr, &(*block)->instr_list) {
         if (lower_instr(ir, block, instr)) {
            /* restart the loop with the new block we created because the
             * iterator has been invalidated.
             */
            progress = inner_progress = true;
            break;
         }
      }
   } while (inner_progress);

   return progress;
}

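/* Entry point for the post-RA lowering: lower every subgroup macro in the
 * shader. The block iteration variable is passed by reference because
 * lowering may split blocks and leave the remaining instructions in a new
 * block.
 */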
bool
ir3_lower_subgroups(struct ir3 *ir)
{
   bool progress = false;

   foreach_block (block, &ir->block_list)
      progress |= lower_block(ir, &block);

   return progress;
}

static bool
filter_scan_reduce(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      return true;
   default:
      return false;
   }
}

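/* Lower reduce/inclusive_scan/exclusive_scan at the NIR level: a log2-depth
 * chain of brcst_active_ir3 ops combines values within clusters of up to 8
 * invocations, and the remaining cross-cluster combination is left to the
 * *_clusters_ir3 intrinsics, which end up as the SCAN_CLUSTERS macro
 * lowered above.
 */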
static nir_def *
lower_scan_reduce(struct nir_builder *b, nir_instr *instr, void *data)
{
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   unsigned bit_size = intrin->def.bit_size;

   nir_op op = nir_intrinsic_reduction_op(intrin);
   nir_const_value ident_val = nir_alu_binop_identity(op, bit_size);
   nir_def *ident = nir_build_imm(b, 1, bit_size, &ident_val);
   nir_def *inclusive = intrin->src[0].ssa;
   nir_def *exclusive = ident;

   for (unsigned cluster_size = 2; cluster_size <= 8; cluster_size *= 2) {
      nir_def *brcst = nir_brcst_active_ir3(b, ident, inclusive,
                                            .cluster_size = cluster_size);
      inclusive = nir_build_alu2(b, op, inclusive, brcst);

      if (intrin->intrinsic == nir_intrinsic_exclusive_scan)
         exclusive = nir_build_alu2(b, op, exclusive, brcst);
   }

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
      return nir_reduce_clusters_ir3(b, inclusive, .reduction_op = op);
   case nir_intrinsic_inclusive_scan:
      return nir_inclusive_scan_clusters_ir3(b, inclusive, .reduction_op = op);
   case nir_intrinsic_exclusive_scan:
      return nir_exclusive_scan_clusters_ir3(b, inclusive, exclusive,
                                             .reduction_op = op);
   default:
      unreachable("filtered intrinsic");
   }
}

bool
ir3_nir_opt_subgroups(nir_shader *nir, struct ir3_shader_variant *v)
{
   if (!v->compiler->has_getfiberid)
      return false;

   return nir_shader_lower_instructions(nir, filter_scan_reduce,
                                        lower_scan_reduce, NULL);
}