/*
 * Copyright 2023 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "nir_builder.h"
#include "nir_constant_expressions.h"
#include "radv_nir.h"

/* This pass optimizes shuffles and boolean ALU instructions where the source
 * can be expressed as a function of tid (with only subgroup_invocation_id,
 * local_invocation_id/index or constants as inputs).
 * Shuffles are replaced by specialized intrinsics, boolean ALU by inverse_ballot.
 * The pass first computes the function-of-tid (fotid) mask, and then uses
 * constant folding to compute the source for each invocation.
 *
 * This pass assumes that
 *    local_invocation_index = subgroup_id * subgroup_size + subgroup_invocation_id.
 * That is not guaranteed by the Vulkan spec, but it is how AMD hardware works, as
 * long as the GFX12 INTERLEAVE_BITS_X/Y fields are not used. This is also the main
 * reason why this pass is currently RADV-specific.
 */
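
/* As an illustrative sketch (not actual NIR syntax): with use_shuffle_xor set,
 * the pass turns
 *
 *    %idx = ixor %tid, 1
 *    %res = shuffle %data, %idx
 *
 * into
 *
 *    %res = shuffle_xor %data, 1
 *
 * and, with the ballot options set, a wave32 boolean like `ult %tid, 16`
 * becomes `inverse_ballot 0x0000ffff`.
 */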

#define NIR_MAX_SUBGROUP_SIZE     128
#define FOTID_MAX_RECURSION_DEPTH 16 /* totally arbitrary */

static inline unsigned
src_get_fotid_mask(nir_src src)
{
   return src.ssa->parent_instr->pass_flags;
}

static inline unsigned
alu_src_get_fotid_mask(nir_alu_instr *instr, unsigned idx)
{
   unsigned unswizzled = src_get_fotid_mask(instr->src[idx].src);
   unsigned result = 0;
   for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(instr, idx); i++) {
      bool is_fotid = unswizzled & (1u << instr->src[idx].swizzle[i]);
      result |= is_fotid << i;
   }
   return result;
}

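/* Propagate the fotid mask through an ALU instruction: a result component
 * is a function of tid only if every source component feeding it is.
 * E.g. (illustrative) with a fotid scalar %tid and a load_const %c,
 * `iadd %tid, %c` gets mask 0x1, while a vec2 keeps one mask bit per
 * component, following the source swizzles.
 */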
static void
update_fotid_alu(nir_builder *b, nir_alu_instr *instr, const radv_nir_opt_tid_function_options *options)
{
   const nir_op_info *info = &nir_op_infos[instr->op];

   unsigned res = BITFIELD_MASK(instr->def.num_components);
   for (unsigned i = 0; res != 0 && i < info->num_inputs; i++) {
      unsigned src_mask = alu_src_get_fotid_mask(instr, i);
      if (info->input_sizes[i] == 0)
         res &= src_mask;
      else if (src_mask != BITFIELD_MASK(info->input_sizes[i]))
         res = 0;
   }

   instr->instr.pass_flags = (uint8_t)res;
}

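/* Mark tid-reading intrinsics. For load_local_invocation_id this assumes
 * linear dispatch and marks a prefix of the components, e.g. (illustrative)
 * with an 8x8x2 workgroup and hw_subgroup_size = 64, partial_size hits 64
 * after Y, so .xy get the fotid mask 0x3; a 4x4x4 workgroup fits entirely
 * in one subgroup, so the mask is 0x7.
 */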
static void
update_fotid_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, const radv_nir_opt_tid_function_options *options)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_subgroup_invocation: {
      instr->instr.pass_flags = 1;
      break;
   }
   case nir_intrinsic_load_local_invocation_id: {
      if (b->shader->info.workgroup_size_variable)
         break;
      /* This assumes linear subgroup dispatch. */
      unsigned partial_size = 1;
      for (unsigned i = 0; i < 3; i++) {
         partial_size *= b->shader->info.workgroup_size[i];
         if (partial_size == options->hw_subgroup_size)
            instr->instr.pass_flags = (uint8_t)BITFIELD_MASK(i + 1);
      }
      if (partial_size <= options->hw_subgroup_size)
         instr->instr.pass_flags = 0x7;
      break;
   }
   case nir_intrinsic_load_local_invocation_index: {
      if (b->shader->info.workgroup_size_variable)
         break;
      unsigned workgroup_size =
         b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2];
      if (workgroup_size <= options->hw_subgroup_size)
         instr->instr.pass_flags = 0x1;
      break;
   }
   case nir_intrinsic_inverse_ballot: {
      if (src_get_fotid_mask(instr->src[0]) == BITFIELD_MASK(instr->src[0].ssa->num_components)) {
         instr->instr.pass_flags = 0x1;
      }
      break;
   }
   default: {
      break;
   }
   }
}

static void
update_fotid_load_const(nir_load_const_instr *instr)
{
   instr->instr.pass_flags = (uint8_t)BITFIELD_MASK(instr->def.num_components);
}

static bool
update_fotid_instr(nir_builder *b, nir_instr *instr, const radv_nir_opt_tid_function_options *options)
{
   /* Gather a mask of components that are functions of tid. */
   instr->pass_flags = 0;

   switch (instr->type) {
   case nir_instr_type_alu:
      update_fotid_alu(b, nir_instr_as_alu(instr), options);
      break;
   case nir_instr_type_intrinsic:
      update_fotid_intrinsic(b, nir_instr_as_intrinsic(instr), options);
      break;
   case nir_instr_type_load_const:
      update_fotid_load_const(nir_instr_as_load_const(instr));
      break;
   default:
      break;
   }

   return false;
}

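/* Evaluate a fotid scalar for one specific invocation by recursively
 * folding its sources, substituting invocation_id for the tid-reading
 * intrinsics. E.g. (illustrative) for invocation 5,
 * `ixor subgroup_invocation, 3` folds to the constant 6.
 */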
static bool
constant_fold_scalar(nir_scalar s, unsigned invocation_id, nir_shader *shader, nir_const_value *dest, unsigned depth)
{
   if (depth > FOTID_MAX_RECURSION_DEPTH)
      return false;

   memset(dest, 0, sizeof(*dest));

   if (nir_scalar_is_alu(s)) {
      nir_alu_instr *alu = nir_instr_as_alu(s.def->parent_instr);
      nir_const_value sources[NIR_ALU_MAX_INPUTS][NIR_MAX_VEC_COMPONENTS];
      const nir_op_info *op_info = &nir_op_infos[alu->op];

      unsigned bit_size = 0;
      if (!nir_alu_type_get_type_size(op_info->output_type))
         bit_size = alu->def.bit_size;

      for (unsigned i = 0; i < op_info->num_inputs; i++) {
         if (!bit_size && !nir_alu_type_get_type_size(op_info->input_types[i]))
            bit_size = alu->src[i].src.ssa->bit_size;

         unsigned offset = 0;
         unsigned num_comp = op_info->input_sizes[i];
         if (num_comp == 0) {
            num_comp = 1;
            offset = s.comp;
         }

         for (unsigned j = 0; j < num_comp; j++) {
            nir_scalar src_scalar = nir_get_scalar(alu->src[i].src.ssa, alu->src[i].swizzle[offset + j]);
            if (!constant_fold_scalar(src_scalar, invocation_id, shader, &sources[i][j], depth + 1))
               return false;
         }
      }

      if (!bit_size)
         bit_size = 32;

      unsigned exec_mode = shader->info.float_controls_execution_mode;

      nir_const_value *srcs[NIR_ALU_MAX_INPUTS];
      for (unsigned i = 0; i < op_info->num_inputs; ++i)
         srcs[i] = sources[i];
      nir_const_value dests[NIR_MAX_VEC_COMPONENTS];
      if (op_info->output_size == 0) {
         nir_eval_const_opcode(alu->op, dests, 1, bit_size, srcs, exec_mode);
         *dest = dests[0];
      } else {
         nir_eval_const_opcode(alu->op, dests, s.def->num_components, bit_size, srcs, exec_mode);
         *dest = dests[s.comp];
      }
      return true;
   } else if (nir_scalar_is_intrinsic(s)) {
      switch (nir_scalar_intrinsic_op(s)) {
      case nir_intrinsic_load_subgroup_invocation:
      case nir_intrinsic_load_local_invocation_index: {
         *dest = nir_const_value_for_uint(invocation_id, s.def->bit_size);
         return true;
      }
      case nir_intrinsic_load_local_invocation_id: {
         unsigned local_ids[3];
         local_ids[2] = invocation_id / (shader->info.workgroup_size[0] * shader->info.workgroup_size[1]);
         unsigned xy = invocation_id % (shader->info.workgroup_size[0] * shader->info.workgroup_size[1]);
         local_ids[1] = xy / shader->info.workgroup_size[0];
         local_ids[0] = xy % shader->info.workgroup_size[0];
         *dest = nir_const_value_for_uint(local_ids[s.comp], s.def->bit_size);
         return true;
      }
      case nir_intrinsic_inverse_ballot: {
         nir_def *src = nir_instr_as_intrinsic(s.def->parent_instr)->src[0].ssa;
         unsigned comp = invocation_id / src->bit_size;
         unsigned bit = invocation_id % src->bit_size;
         if (!constant_fold_scalar(nir_get_scalar(src, comp), invocation_id, shader, dest, depth + 1))
            return false;
         uint64_t ballot = nir_const_value_as_uint(*dest, src->bit_size);
         *dest = nir_const_value_for_bool(ballot & (1ull << bit), 1);
         return true;
      }
      default:
         break;
      }
   } else if (nir_scalar_is_const(s)) {
      *dest = nir_scalar_as_const_value(s);
      return true;
   }

   unreachable("unhandled scalar type");
   return false;
}

struct fotid_context {
   const radv_nir_opt_tid_function_options *options;
   uint8_t src_invoc[NIR_MAX_SUBGROUP_SIZE];
   bool reads_zero[NIR_MAX_SUBGROUP_SIZE];
   nir_shader *shader;
};

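/* Fill ctx->src_invoc so that src_invoc[i] is the invocation whose data
 * lane i reads, by folding the shuffle index for every invocation.
 * E.g. an index of `ixor tid, 1` yields the table 1, 0, 3, 2, 5, 4, ...
 * (illustrative; reads are clamped to UINT8_MAX).
 */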
static bool
gather_read_invocation_shuffle(nir_def *src, struct fotid_context *ctx)
{
   nir_scalar s = {src, 0};

   /* Recursive constant folding for each invocation. */
   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      nir_const_value value;
      if (!constant_fold_scalar(s, i, ctx->shader, &value, 0))
         return false;
      ctx->src_invoc[i] = MIN2(nir_const_value_as_uint(value, src->bit_size), UINT8_MAX);
   }

   return true;
}

static nir_alu_instr *
get_singular_user_bcsel(nir_def *def, unsigned *src_idx)
{
   if (def->num_components != 1 || !list_is_singular(&def->uses))
      return NULL;

   nir_alu_instr *bcsel = NULL;
   nir_foreach_use_including_if_safe (src, def) {
      if (nir_src_is_if(src) || nir_src_parent_instr(src)->type != nir_instr_type_alu)
         return NULL;
      bcsel = nir_instr_as_alu(nir_src_parent_instr(src));
      if (bcsel->op != nir_op_bcsel || bcsel->def.num_components != 1)
         return NULL;
      *src_idx = list_entry(src, nir_alu_src, src) - bcsel->src;
      break;
   }
   assert(*src_idx < 3);

   /* The def must feed one of the selected sources, not the condition. */
   if (*src_idx == 0)
      return NULL;
   return bcsel;
}

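/* A bcsel user can relax the read table: invocations that select the other
 * bcsel source do not care what the shuffle returns. E.g. (illustrative)
 * for `bcsel (ult tid, 16), shuffle, 0` on wave32, lanes 16..31 may read
 * anything (UINT8_MAX), and since the other source is the constant zero,
 * they may even read zero from an out-of-bounds lane (reads_zero), in
 * which case the bcsel itself can be removed.
 */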
static bool
gather_invocation_uses(nir_alu_instr *bcsel, unsigned shuffle_idx, struct fotid_context *ctx)
{
   if (!alu_src_get_fotid_mask(bcsel, 0))
      return false;

   nir_scalar s = {bcsel->src[0].src.ssa, bcsel->src[0].swizzle[0]};

   bool can_remove_bcsel =
      nir_src_is_const(bcsel->src[3 - shuffle_idx].src) && nir_src_as_uint(bcsel->src[3 - shuffle_idx].src) == 0;

   /* Recursive constant folding for each invocation. */
   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      nir_const_value value;
      if (!constant_fold_scalar(s, i, ctx->shader, &value, 0)) {
         can_remove_bcsel = false;
         continue;
      }

      /* If this invocation selects the other source, the shuffle result is
       * unused, so we can read an undefined result. */
      if (nir_const_value_as_bool(value, 1) == (shuffle_idx != 1)) {
         ctx->src_invoc[i] = UINT8_MAX;
         ctx->reads_zero[i] = can_remove_bcsel;
      }
   }

   if (can_remove_bcsel) {
      return true;
   } else {
      memset(ctx->reads_zero, 0, sizeof(ctx->reads_zero));
      return false;
   }
}

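/* Try to express the read pattern as read = (tid & and_mask) ^ xor_mask.
 * For each bit we track whether it is copied, inverted, or constant
 * one/zero across all defined reads. Worked example (wave64):
 * src_invoc[i] = i ^ 5 gives invert = 0x05 and copy = 0x7a, so
 * and_mask = 0x7f and xor_mask = 0x05, which matches shuffle_xor with 5.
 */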
static nir_def *
try_opt_bitwise_mask(nir_builder *b, nir_def *def, struct fotid_context *ctx)
{
   unsigned one = NIR_MAX_SUBGROUP_SIZE - 1;
   unsigned zero = NIR_MAX_SUBGROUP_SIZE - 1;
   unsigned copy = NIR_MAX_SUBGROUP_SIZE - 1;
   unsigned invert = NIR_MAX_SUBGROUP_SIZE - 1;

   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      unsigned read = ctx->src_invoc[i];
      if (read >= ctx->options->hw_subgroup_size)
         continue; /* undefined result */

      copy &= ~(read ^ i);
      invert &= read ^ i;
      one &= read;
      zero &= ~read;
   }

   /* We didn't find valid masks for at least one bit. */
   if ((copy | zero | one | invert) != NIR_MAX_SUBGROUP_SIZE - 1)
      return NULL;

   unsigned and_mask = copy | invert;
   unsigned xor_mask = (one | invert) & ~copy;

#if 0
   fprintf(stderr, "and %x, xor %x \n", and_mask, xor_mask);

   assert(false);
#endif

   if ((and_mask & (ctx->options->hw_subgroup_size - 1)) == 0) {
      return nir_read_invocation(b, def, nir_imm_int(b, xor_mask));
   } else if (and_mask == 0x7f && xor_mask == 0) {
      return def;
   } else if (ctx->options->use_shuffle_xor && and_mask == 0x7f) {
      return nir_shuffle_xor(b, def, nir_imm_int(b, xor_mask));
   } else if (ctx->options->use_masked_swizzle_amd && (and_mask & 0x60) == 0x60 && xor_mask <= 0x1f) {
      return nir_masked_swizzle_amd(b, def, (xor_mask << 10) | (and_mask & 0x1f), .fetch_inactive = true);
   }

   return NULL;
}

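/* Try to match a clustered rotate: read = ((tid + delta) % csize) within
 * each csize-aligned cluster. E.g. src_invoc = 1, 2, 3, 0, 5, 6, 7, 4, ...
 * matches delta = 1 with cluster_size = 4 (illustrative example).
 */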
static nir_def *
try_opt_rotate(nir_builder *b, nir_def *def, struct fotid_context *ctx)
{
   for (unsigned csize = 4; csize <= ctx->options->hw_subgroup_size; csize *= 2) {
      unsigned cmask = csize - 1;

      /* Derive the rotation delta from the first invocation with a defined read. */
      unsigned delta = UINT_MAX;
      for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
         if (ctx->src_invoc[i] >= ctx->options->hw_subgroup_size)
            continue;

         if (ctx->src_invoc[i] >= i)
            delta = ctx->src_invoc[i] - i;
         else
            delta = csize - i + ctx->src_invoc[i];
         break;
      }

      if (delta >= csize || delta == 0)
         continue;

      bool use_rotate = true;
      for (unsigned i = 0; use_rotate && i < ctx->options->hw_subgroup_size; i++) {
         if (ctx->src_invoc[i] >= ctx->options->hw_subgroup_size)
            continue;
         use_rotate &= (((i + delta) & cmask) + (i & ~cmask)) == ctx->src_invoc[i];
      }

      if (use_rotate)
         return nir_rotate(b, def, nir_imm_int(b, delta), .cluster_size = csize);
   }

   return NULL;
}

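/* Try to match a DPP row shift: every lane reads lane + delta within its
 * group of 16, and reads that would cross a 16-lane boundary are only
 * allowed for lanes that accept zero (DPP16 shifts zero in). E.g. a shift
 * by +1 needs src_invoc = 1..15, X, 17..31, X, ... where the X lanes have
 * reads_zero set (illustrative example).
 */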
static nir_def *
try_opt_dpp16_shift(nir_builder *b, nir_def *def, struct fotid_context *ctx)
{
   /* Derive the shift delta from the first invocation with a defined read. */
   int delta = INT_MAX;
   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      if (ctx->src_invoc[i] >= ctx->options->hw_subgroup_size)
         continue;
      delta = ctx->src_invoc[i] - i;
      break;
   }

   if (delta < -15 || delta > 15 || delta == 0)
      return NULL;

   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      int read = i + delta;
      bool out_of_bounds = (read & ~0xf) != (i & ~0xf);
      if (ctx->reads_zero[i] && !out_of_bounds)
         return NULL;
      if (ctx->src_invoc[i] >= ctx->options->hw_subgroup_size)
         continue;
      if (read != ctx->src_invoc[i] || out_of_bounds)
         return NULL;
   }

   return nir_dpp16_shift_amd(b, def, .base = delta);
}

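/* Entry point for shuffle optimization: build the per-invocation read
 * table, optionally generalize it through a single bcsel user, and try
 * the specialized patterns in order of preference.
 */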
static bool
opt_fotid_shuffle(nir_builder *b, nir_intrinsic_instr *instr, const radv_nir_opt_tid_function_options *options,
                  bool revisit_bcsel)
{
   if (instr->intrinsic != nir_intrinsic_shuffle)
      return false;
   if (!instr->src[1].ssa->parent_instr->pass_flags)
      return false;

   unsigned src_idx = 0;
   nir_alu_instr *bcsel = get_singular_user_bcsel(&instr->def, &src_idx);
   /* Skip this shuffle; it will be revisited later, when the
    * function-of-tid mask has been set on the bcsel.
    */
   if (bcsel && !revisit_bcsel)
      return false;

   /* We already tried (and failed) to optimize this shuffle. */
   if (!bcsel && revisit_bcsel)
      return false;

   struct fotid_context ctx = {
      .options = options,
      .reads_zero = {0},
      .shader = b->shader,
   };

   memset(ctx.src_invoc, 0xff, sizeof(ctx.src_invoc));

   if (!gather_read_invocation_shuffle(instr->src[1].ssa, &ctx))
      return false;

   /* Generalize src_invoc by taking into account which invocations
    * do not use the shuffle result because of the bcsel.
    */
   bool can_remove_bcsel = false;
   if (bcsel)
      can_remove_bcsel = gather_invocation_uses(bcsel, src_idx, &ctx);

#if 0
   for (int i = 0; i < options->hw_subgroup_size; i++) {
      fprintf(stderr, "invocation %d reads %d\n", i, ctx.src_invoc[i]);
   }

   for (int i = 0; i < options->hw_subgroup_size; i++) {
      fprintf(stderr, "invocation %d zero %d\n", i, ctx.reads_zero[i]);
   }
#endif

   b->cursor = nir_after_instr(&instr->instr);

   nir_def *res = NULL;

   if (can_remove_bcsel && options->use_dpp16_shift_amd) {
      res = try_opt_dpp16_shift(b, instr->src[0].ssa, &ctx);
      if (res) {
         nir_def_rewrite_uses(&bcsel->def, res);
         return true;
      }
   }

   if (!res)
      res = try_opt_bitwise_mask(b, instr->src[0].ssa, &ctx);
   if (!res && options->use_clustered_rotate)
      res = try_opt_rotate(b, instr->src[0].ssa, &ctx);

   if (res) {
      nir_def_replace(&instr->def, res);
      return true;
   } else {
      return false;
   }
}

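/* Replace a 1-bit fotid ALU result with an inverse_ballot of a constant.
 * E.g. (illustrative) `ult tid, 16` on wave64 with 64-bit ballots becomes
 * inverse_ballot(0x000000000000ffff).
 */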
static bool
opt_fotid_bool(nir_builder *b, nir_alu_instr *instr, const radv_nir_opt_tid_function_options *options)
{
   nir_scalar s = {&instr->def, 0};

   b->cursor = nir_after_instr(&instr->instr);

   nir_def *ballot_comp[NIR_MAX_VEC_COMPONENTS];

   for (unsigned comp = 0; comp < options->hw_ballot_num_comp; comp++) {
      uint64_t cballot = 0;
      for (unsigned i = 0; i < options->hw_ballot_bit_size; i++) {
         unsigned invocation_id = comp * options->hw_ballot_bit_size + i;
         if (invocation_id >= options->hw_subgroup_size)
            break;
         nir_const_value value;
         if (!constant_fold_scalar(s, invocation_id, b->shader, &value, 0))
            return false;
         cballot |= nir_const_value_as_uint(value, 1) << i;
      }
      ballot_comp[comp] = nir_imm_intN_t(b, cballot, options->hw_ballot_bit_size);
   }

   nir_def *ballot = nir_vec(b, ballot_comp, options->hw_ballot_num_comp);
   nir_def *res = nir_inverse_ballot(b, 1, ballot);
   res->parent_instr->pass_flags = 1;

   nir_def_replace(&instr->def, res);
   return true;
}

static bool
visit_instr(nir_builder *b, nir_instr *instr, void *params)
{
   const radv_nir_opt_tid_function_options *options = params;
   update_fotid_instr(b, instr, options);

   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);

      if (alu->op == nir_op_bcsel && alu->def.bit_size != 1) {
         /* Revisit shuffles that we skipped previously. */
         bool progress = false;
         for (unsigned i = 1; i < 3; i++) {
            nir_instr *src_instr = alu->src[i].src.ssa->parent_instr;
            if (src_instr->type == nir_instr_type_intrinsic) {
               nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(src_instr);
               progress |= opt_fotid_shuffle(b, intrin, options, true);
               if (list_is_empty(&alu->def.uses))
                  break;
            }
         }
         return progress;
      }

      if (!options->hw_ballot_bit_size || !options->hw_ballot_num_comp)
         return false;
      if (alu->def.bit_size != 1 || alu->def.num_components > 1 || !instr->pass_flags)
         return false;
      return opt_fotid_bool(b, alu, options);
   }
   case nir_instr_type_intrinsic: {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      return opt_fotid_shuffle(b, intrin, options, false);
   }
   default:
      return false;
   }
}

bool
radv_nir_opt_tid_function(nir_shader *shader, const radv_nir_opt_tid_function_options *options)
{
   return nir_shader_instructions_pass(shader, visit_instr, nir_metadata_control_flow, (void *)options);
}