/*
 * Copyright © 2020 Google LLC
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

/**
 * @file
 *
 * Removes unused components of SSA defs.
 *
 * Due to various optimization passes (or frontend implementations,
 * particularly prog_to_nir), we may have instructions generating vectors
 * whose components don't get read by any instruction.
 *
 * For memory loads, while it can be tricky to eliminate unused low components
 * or channels in the middle of a writemask (you might need to increment some
 * offset from a load_uniform, for example), it is trivial to just drop the
 * trailing components.
 * For vector ALU instructions and load_const, when they are used only by
 * other ALU instructions, this pass eliminates arbitrary channels as well as
 * duplicate channels, and reswizzles the uses.
 *
 * This pass is probably only of use to vector backends -- scalar backends
 * typically get unused def channel trimming for free from scalarization
 * followed by dead code elimination.
 */
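
/*
 * As an illustrative example (NIR syntax simplified, SSA names hypothetical),
 * given a vec4 whose y and w channels are never read:
 *
 *    vec4 ssa_1 = fadd ssa_0.xyzw, ssa_2.yyyy
 *    vec1 ssa_3 = fmul ssa_1.x, ssa_1.z
 *
 * the pass shrinks the def to the two live channels and reswizzles the use:
 *
 *    vec2 ssa_1 = fadd ssa_0.xz, ssa_2.yy
 *    vec1 ssa_3 = fmul ssa_1.x, ssa_1.y
 */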

#include "util/u_math.h"
#include "nir.h"
#include "nir_builder.h"

/*
 * Round up a vector size to one that's valid in NIR. At present, NIR supports
 * only vec2-5, vec8, and vec16; attempting to generate other sizes will fail
 * validation. For example, 6 or 7 components round up to vec8.
 */
static unsigned
round_up_components(unsigned n)
{
   return (n > 5) ? util_next_power_of_two(n) : n;
}

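/*
 * Trim the trailing unused channels of a def. E.g. (illustrative), a vec4 def
 * whose users read only .xy shrinks to vec2. Note that a def read only
 * through .y still keeps two components: leading and interior channels can't
 * be dropped here, only trailing ones.
 */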
static bool
shrink_dest_to_read_mask(nir_def *def)
{
   /* early out if there's nothing to do. */
   if (def->num_components == 1)
      return false;

   /* don't remove any channels if used by an intrinsic */
   nir_foreach_use(use_src, def) {
      if (nir_src_parent_instr(use_src)->type == nir_instr_type_intrinsic)
         return false;
   }

   unsigned mask = nir_def_components_read(def);
   int last_bit = util_last_bit(mask);

   /* If nothing was read, leave it up to DCE. */
   if (!mask)
      return false;

   unsigned rounded = round_up_components(last_bit);
   assert(rounded <= def->num_components);
   last_bit = rounded;

   if (def->num_components > last_bit) {
      def->num_components = last_bit;
      return true;
   }

   return false;
}

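/*
 * E.g. (illustrative): a vec5 image_sparse_load whose fifth (residency)
 * component is never read drops that trailing component and becomes a plain
 * vec4 image_load.
 */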
static bool
shrink_intrinsic_to_non_sparse(nir_intrinsic_instr *instr)
{
   unsigned mask = nir_def_components_read(&instr->def);
   int last_bit = util_last_bit(mask);

   /* If the sparse component is used, do nothing. */
   if (last_bit == instr->def.num_components)
      return false;

   instr->def.num_components -= 1;
   instr->num_components = instr->def.num_components;

   /* Switch to the non-sparse intrinsic. */
   switch (instr->intrinsic) {
   case nir_intrinsic_image_sparse_load:
      instr->intrinsic = nir_intrinsic_image_load;
      break;
   case nir_intrinsic_bindless_image_sparse_load:
      instr->intrinsic = nir_intrinsic_bindless_image_load;
      break;
   case nir_intrinsic_image_deref_sparse_load:
      instr->intrinsic = nir_intrinsic_image_deref_load;
      break;
   default:
      break;
   }

   return true;
}

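/*
 * Remap the swizzles of all (ALU) users of def through the old-channel to
 * new-channel mapping in reswizzle. E.g. (illustrative), if channel 1 was
 * dropped and channel 2 moved down to slot 1 (reswizzle = {0, 0, 1, ...}),
 * a user reading def.xz now reads def.xy.
 */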
static void
reswizzle_alu_uses(nir_def *def, uint8_t *reswizzle)
{
   nir_foreach_use(use_src, def) {
      /* all uses must be ALU instructions */
      assert(nir_src_parent_instr(use_src)->type == nir_instr_type_alu);
      nir_alu_src *alu_src = (nir_alu_src *)use_src;

      /* reswizzle ALU sources */
      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
         alu_src->swizzle[i] = reswizzle[alu_src->swizzle[i]];
   }
}

static bool
is_only_used_by_alu(nir_def *def)
{
   nir_foreach_use(use_src, def) {
      if (nir_src_parent_instr(use_src)->type != nir_instr_type_alu)
         return false;
   }

   return true;
}

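/*
 * Shrink a vecN instruction to its live channels, reusing channels that hold
 * the same scalar value. E.g. (illustrative, SSA names hypothetical):
 *
 *    vec4 ssa_2 = vec4 ssa_0.x, ssa_1.x, ssa_0.x, ssa_1.y
 *
 * with only .xyz read becomes vec2 ssa_0.x, ssa_1.x, and reads of .z are
 * reswizzled to .x since channel 2 duplicated channel 0.
 */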
static bool
opt_shrink_vector(nir_builder *b, nir_alu_instr *instr)
{
   nir_def *def = &instr->def;
   unsigned mask = nir_def_components_read(def);

   /* If nothing was read, leave it up to DCE. */
   if (mask == 0)
      return false;

   /* don't remove any channels if used by non-ALU */
   if (!is_only_used_by_alu(def))
      return false;

   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   nir_scalar srcs[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned num_components = 0;
   for (unsigned i = 0; i < def->num_components; i++) {
      if (!((mask >> i) & 0x1))
         continue;

      nir_scalar scalar = nir_get_scalar(instr->src[i].src.ssa, instr->src[i].swizzle[0]);

      /* Try to reuse a component with the same value */
      unsigned j;
      for (j = 0; j < num_components; j++) {
         if (nir_scalar_equal(scalar, srcs[j])) {
            reswizzle[i] = j;
            break;
         }
      }

      /* Otherwise, just append the value */
      if (j == num_components) {
         srcs[num_components] = scalar;
         reswizzle[i] = num_components++;
      }
   }

   /* return if no component was removed */
   if (num_components == def->num_components)
      return false;

   /* create new vecN and replace uses */
   nir_def *new_vec = nir_vec_scalars(b, srcs, num_components);
   nir_def_rewrite_uses(def, new_vec);
   reswizzle_alu_uses(new_vec, reswizzle);

   return true;
}

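/*
 * Shrink a non-vecN ALU instruction by dropping unused channels and folding
 * channels whose per-source swizzles are identical. E.g. (illustrative):
 * vec4 ssa_1 = fneg ssa_0.xxyz with all channels read becomes
 * vec3 ssa_1 = fneg ssa_0.xyz, with uses of .y/.z/.w reswizzled to .x/.y/.z.
 */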
static bool
opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
{
   nir_def *def = &instr->def;

   /* Nothing to shrink */
   if (def->num_components == 1)
      return false;

   switch (instr->op) {
   /* don't use nir_op_is_vec() as not all vector sizes are supported. */
   case nir_op_vec4:
   case nir_op_vec3:
   case nir_op_vec2:
      return opt_shrink_vector(b, instr);
   default:
      if (nir_op_infos[instr->op].output_size != 0)
         return false;
      break;
   }

   /* don't remove any channels if used by non-ALU */
   if (!is_only_used_by_alu(def))
      return false;

   unsigned mask = nir_def_components_read(def);
   /* return if there is nothing to do */
   if (mask == 0)
      return false;

   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned num_components = 0;
   bool progress = false;
   for (unsigned i = 0; i < def->num_components; i++) {
      /* skip unused components */
      if (!((mask >> i) & 0x1))
         continue;

      /* Try to reuse a component with the same swizzles */
      unsigned j;
      for (j = 0; j < num_components; j++) {
         bool duplicate_channel = true;
         for (unsigned k = 0; k < nir_op_infos[instr->op].num_inputs; k++) {
            if (nir_op_infos[instr->op].input_sizes[k] != 0 ||
                instr->src[k].swizzle[i] != instr->src[k].swizzle[j]) {
               duplicate_channel = false;
               break;
            }
         }

         if (duplicate_channel) {
            reswizzle[i] = j;
            progress = true;
            break;
         }
      }

      /* Otherwise, just append the value */
      if (j == num_components) {
         for (unsigned k = 0; k < nir_op_infos[instr->op].num_inputs; k++) {
            instr->src[k].swizzle[num_components] = instr->src[k].swizzle[i];
         }
         if (i != num_components)
            progress = true;
         reswizzle[i] = num_components++;
      }
   }

   /* update uses */
   if (progress)
      reswizzle_alu_uses(def, reswizzle);

   unsigned rounded = round_up_components(num_components);
   assert(rounded <= def->num_components);
   if (rounded < def->num_components)
      progress = true;

   /* update dest */
   def->num_components = rounded;

   return progress;
}

static bool
opt_shrink_vectors_intrinsic(nir_builder *b, nir_intrinsic_instr *instr)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_uniform:
   case nir_intrinsic_load_ubo:
   case nir_intrinsic_load_input:
   case nir_intrinsic_load_input_vertex:
   case nir_intrinsic_load_per_vertex_input:
   case nir_intrinsic_load_interpolated_input:
   case nir_intrinsic_load_ssbo:
   case nir_intrinsic_load_push_constant:
   case nir_intrinsic_load_constant:
   case nir_intrinsic_load_shared:
   case nir_intrinsic_load_global:
   case nir_intrinsic_load_global_constant:
   case nir_intrinsic_load_kernel_input:
   case nir_intrinsic_load_scratch: {
      /* Must be a vectorized intrinsic that we can resize. */
      assert(instr->num_components != 0);

      /* Trim the dest to the used channels */
      if (!shrink_dest_to_read_mask(&instr->def))
         return false;

      instr->num_components = instr->def.num_components;
      return true;
   }
   case nir_intrinsic_image_sparse_load:
   case nir_intrinsic_bindless_image_sparse_load:
   case nir_intrinsic_image_deref_sparse_load:
      return shrink_intrinsic_to_non_sparse(instr);
   default:
      return false;
   }
}

static bool
opt_shrink_vectors_tex(nir_builder *b, nir_tex_instr *tex)
{
   if (!tex->is_sparse)
      return false;

   unsigned mask = nir_def_components_read(&tex->def);
   int last_bit = util_last_bit(mask);

   /* If the sparse component is used, do nothing. */
   if (last_bit == tex->def.num_components)
      return false;

   tex->def.num_components -= 1;
   tex->is_sparse = false;

   return true;
}

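/*
 * Shrink a load_const by dropping unused channels and deduplicating equal
 * constants. E.g. (illustrative): load_const (1.0, 2.0, 1.0, 3.0) with all
 * channels read becomes load_const (1.0, 2.0, 3.0); reads of .z are
 * reswizzled to .x and reads of .w to .z.
 */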
static bool
opt_shrink_vectors_load_const(nir_load_const_instr *instr)
{
   nir_def *def = &instr->def;

   /* early out if there's nothing to do. */
   if (def->num_components == 1)
      return false;

   /* don't remove any channels if used by non-ALU */
   if (!is_only_used_by_alu(def))
      return false;

   unsigned mask = nir_def_components_read(def);

   /* If nothing was read, leave it up to DCE. */
   if (!mask)
      return false;

   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned num_components = 0;
   bool progress = false;
   for (unsigned i = 0; i < def->num_components; i++) {
      if (!((mask >> i) & 0x1))
         continue;

      /* Try to reuse a component with the same constant */
      unsigned j;
      for (j = 0; j < num_components; j++) {
         if (instr->value[i].u64 == instr->value[j].u64) {
            reswizzle[i] = j;
            progress = true;
            break;
         }
      }

      /* Otherwise, just append the value */
      if (j == num_components) {
         instr->value[num_components] = instr->value[i];
         if (i != num_components)
            progress = true;
         reswizzle[i] = num_components++;
      }
   }

   if (progress)
      reswizzle_alu_uses(def, reswizzle);

   unsigned rounded = round_up_components(num_components);
   assert(rounded <= def->num_components);
   if (rounded < def->num_components)
      progress = true;

   def->num_components = rounded;

   return progress;
}

static bool
opt_shrink_vectors_ssa_undef(nir_undef_instr *instr)
{
   return shrink_dest_to_read_mask(&instr->def);
}

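/*
 * Shrink a phi to the channels its ALU readers actually use. E.g.
 * (illustrative): a vec4 phi read only through .xy becomes a vec2 phi; each
 * incoming source ssa_k is routed through a new "vec2 mov ssa_k.xy" inserted
 * after ssa_k's defining instruction, and the readers are reswizzled.
 */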
static bool
opt_shrink_vectors_phi(nir_builder *b, nir_phi_instr *instr)
{
   nir_def *def = &instr->def;

   /* early out if there's nothing to do. */
   if (def->num_components == 1)
      return false;

   /* Ignore large vectors for now. */
   if (def->num_components > 4)
      return false;

   /* Check the uses. */
   nir_component_mask_t mask = 0;
   nir_foreach_use(src, def) {
      if (nir_src_parent_instr(src)->type != nir_instr_type_alu)
         return false;

      nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(src));

      nir_alu_src *alu_src = exec_node_data(nir_alu_src, src, src);
      int src_idx = alu_src - &alu->src[0];
      nir_component_mask_t src_read_mask = nir_alu_instr_src_read_mask(alu, src_idx);

      nir_def *alu_def = &alu->def;

      /* We don't mark the channels used if the only reader is the original phi.
       * This can happen in the case of loops.
       */
      nir_foreach_use(alu_use_src, alu_def) {
         if (nir_src_parent_instr(alu_use_src) != &instr->instr) {
            mask |= src_read_mask;
         }
      }

      /* However, even if the instruction only points back at the phi, we still
       * need to check that the swizzles are trivial.
       */
      if (nir_op_is_vec(alu->op)) {
         if (src_idx != alu->src[src_idx].swizzle[0]) {
            mask |= src_read_mask;
         }
      } else if (!nir_alu_src_is_trivial_ssa(alu, src_idx)) {
         mask |= src_read_mask;
      }
   }

   /* DCE will handle this. */
   if (mask == 0)
      return false;

   /* Nothing to shrink? */
   if (BITFIELD_MASK(def->num_components) == mask)
      return false;

   /* Set up the reswizzles. */
   unsigned num_components = 0;
   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   uint8_t src_reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   for (unsigned i = 0; i < def->num_components; i++) {
      if (!((mask >> i) & 0x1))
         continue;
      src_reswizzle[num_components] = i;
      reswizzle[i] = num_components++;
   }

   /* Shrink the phi; this part is simple. */
   def->num_components = num_components;

   /* We can't swizzle phi sources directly, so just insert extra movs with
    * the correct swizzles and let the other parts of nir_opt_shrink_vectors
    * do their job on the original source instructions. If an original source
    * was used only by the phi, the movs will disappear later after copy
    * propagation.
    */
   nir_foreach_phi_src(phi_src, instr) {
      b->cursor = nir_after_instr_and_phis(phi_src->src.ssa->parent_instr);

      nir_alu_src alu_src = {
         .src = nir_src_for_ssa(phi_src->src.ssa)
      };

      for (unsigned i = 0; i < num_components; i++)
         alu_src.swizzle[i] = src_reswizzle[i];
      nir_def *mov = nir_mov_alu(b, alu_src, num_components);

      nir_src_rewrite(&phi_src->src, mov);
   }
   b->cursor = nir_before_instr(&instr->instr);

   /* Reswizzle readers. */
   reswizzle_alu_uses(def, reswizzle);

   return true;
}

static bool
opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr)
{
   b->cursor = nir_before_instr(instr);

   switch (instr->type) {
   case nir_instr_type_alu:
      return opt_shrink_vectors_alu(b, nir_instr_as_alu(instr));

   case nir_instr_type_tex:
      return opt_shrink_vectors_tex(b, nir_instr_as_tex(instr));

   case nir_instr_type_intrinsic:
      return opt_shrink_vectors_intrinsic(b, nir_instr_as_intrinsic(instr));

   case nir_instr_type_load_const:
      return opt_shrink_vectors_load_const(nir_instr_as_load_const(instr));

   case nir_instr_type_undef:
      return opt_shrink_vectors_ssa_undef(nir_instr_as_undef(instr));

   case nir_instr_type_phi:
      return opt_shrink_vectors_phi(b, nir_instr_as_phi(instr));

   default:
      return false;
   }
}

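/*
 * Callers typically run this pass from their optimization loop until it
 * stops making progress, e.g. (illustrative):
 *
 *    bool progress;
 *    do {
 *       progress = false;
 *       NIR_PASS(progress, shader, nir_opt_shrink_vectors);
 *       ...
 *    } while (progress);
 */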
bool
nir_opt_shrink_vectors(nir_shader *shader)
{
   bool progress = false;

   nir_foreach_function_impl(impl, shader) {
      nir_builder b = nir_builder_create(impl);

      nir_foreach_block_reverse(block, impl) {
         nir_foreach_instr_reverse(instr, block) {
            progress |= opt_shrink_vectors_instr(&b, instr);
         }
      }

      if (progress) {
         nir_metadata_preserve(impl,
                               nir_metadata_block_index |
                               nir_metadata_dominance);
      } else {
         nir_metadata_preserve(impl, nir_metadata_all);
      }
   }

   return progress;
}