/*
 * Copyright © 2020 Google LLC
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

/**
 * @file
 *
 * Trims off the unused trailing components of SSA defs.
 *
 * Due to various optimization passes (or frontend implementations,
 * particularly prog_to_nir), we may have instructions generating vectors
 * whose components don't get read by any instruction.  While it can be
 * tricky to eliminate unused low components or channels in the middle of
 * a writemask (you might need to increment some offset from a load_uniform,
 * for example), dropping the trailing components is trivial.  For vector
 * ALU results used only by other ALU instructions, this pass can eliminate
 * arbitrary channels and reswizzle the uses.
 *
 * This pass is probably only of use to vector backends -- scalar backends
 * typically get unused def channel trimming through scalarization and dead
 * code elimination.
 */
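
/*
 * A sketch of the ALU reswizzle case, in NIR-print-style pseudocode with
 * made-up SSA names (not real pass output):
 *
 *    vec4 32 ssa_1 = fadd ssa_0, ssa_2
 *    vec1 32 ssa_3 = fmul ssa_1.y, ssa_1.w
 *
 * Only .y and .w of ssa_1 are read, so the def shrinks to two components
 * and both its source swizzles and its uses are reswizzled:
 *
 *    vec2 32 ssa_1 = fadd ssa_0.yw, ssa_2.yw
 *    vec1 32 ssa_3 = fmul ssa_1.x, ssa_1.y
 */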

#include "nir.h"
#include "nir_builder.h"

static bool
shrink_dest_to_read_mask(nir_ssa_def *def)
{
   /* Early out if there's nothing to do. */
   if (def->num_components == 1)
      return false;

   /* Don't remove any channels if the def is used by an intrinsic. */
   nir_foreach_use(use_src, def) {
      if (use_src->parent_instr->type == nir_instr_type_intrinsic)
         return false;
   }

   unsigned mask = nir_ssa_def_components_read(def);
   int last_bit = util_last_bit(mask);

   /* If nothing was read, leave it up to DCE. */
   if (!mask)
      return false;

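   /* Worked example (hypothetical values): for a vec4 def whose .x and .z
    * channels are read, mask is 0b0101 and last_bit is 3, so the def is
    * trimmed from 4 to 3 components; the unread .y in the middle is kept.
    */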
   if (def->num_components > last_bit) {
      def->num_components = last_bit;
      return true;
   }

   return false;
}

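/* Shrink an ALU instruction's dest to its read channels.  An overview of
 * the cases handled below: a vecN op is rebuilt as a smaller vecN of only
 * the read sources; a contiguous read mask just lowers num_components; a
 * non-contiguous mask additionally compacts the source swizzles and
 * reswizzles every ALU use of the def.
 */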
static bool
opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
{
   nir_ssa_def *def = &instr->dest.dest.ssa;

   /* Nothing to shrink. */
   if (def->num_components == 1)
      return false;

   bool is_vec = false;
   switch (instr->op) {
   /* Don't use nir_op_is_vec() as not all vector sizes are supported. */
   case nir_op_vec4:
   case nir_op_vec3:
   case nir_op_vec2:
      is_vec = true;
      break;
   default:
      if (nir_op_infos[instr->op].output_size != 0)
         return false;
      break;
   }

   /* Don't remove any channels if the def is used by an intrinsic: the
    * remaining uses may be reswizzled below, which is only possible for
    * ALU instructions.
    */
   nir_foreach_use(use_src, def) {
      if (use_src->parent_instr->type == nir_instr_type_intrinsic)
         return false;
   }

   unsigned mask = nir_ssa_def_components_read(def);
   unsigned last_bit = util_last_bit(mask);
   unsigned num_components = util_bitcount(mask);

   /* Return if there is nothing to do: either no channels are read, or all
    * of the def's channels are.
    */
   if (mask == 0 || num_components == def->num_components)
      return false;

   /* The read mask is contiguous from bit 0 (i.e. only trailing channels
    * are unused) exactly when the number of set bits equals
    * util_last_bit(mask).
    */
   const bool is_bitfield_mask = last_bit == num_components;

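   /* A sketch with made-up values: a vec4 whose .x and .z channels are read
    * (mask 0b0101) is rebuilt below as vec2(src[0], src[2]), and every use
    * is rewritten to point at the new, smaller def.
    */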
   if (is_vec) {
      /* Replace the vecN with a smaller version. */
      nir_ssa_def *srcs[NIR_MAX_VEC_COMPONENTS] = { 0 };
      unsigned index = 0;
      for (int i = 0; i < last_bit; i++) {
         if ((mask >> i) & 0x1)
            srcs[index++] = nir_ssa_for_alu_src(b, instr, i);
      }
      assert(index == num_components);
      nir_ssa_def *new_vec = nir_vec(b, srcs, num_components);
      nir_ssa_def_rewrite_uses(def, new_vec);
      def = new_vec;
   }

   if (is_bitfield_mask) {
      /* Just reduce the number of components and return. */
      def->num_components = num_components;
      instr->dest.write_mask = mask;
      return true;
   }

   if (!is_vec) {
      /* Update the sources: compact each swizzle over the kept channels. */
      for (int i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
         unsigned index = 0;
         for (int j = 0; j < last_bit; j++) {
            if ((mask >> j) & 0x1)
               instr->src[i].swizzle[index++] = instr->src[i].swizzle[j];
         }
         assert(index == num_components);
      }

      /* Update the dest. */
      def->num_components = num_components;
      instr->dest.write_mask = BITFIELD_MASK(num_components);
   }

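   /* Worked example (made-up values): for mask 0b1010 the table below ends
    * up as reswizzle[1] = 0 and reswizzle[3] = 1, so a use that read .y now
    * reads .x and a use that read .w now reads .y.
    */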
   /* Compute the reswizzle table mapping old channel numbers to new. */
   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned index = 0;
   for (int i = 0; i < last_bit; i++) {
      if ((mask >> i) & 0x1)
         reswizzle[i] = index++;
   }
   assert(index == num_components);

   /* Update the uses. */
   nir_foreach_use(use_src, def) {
      assert(use_src->parent_instr->type == nir_instr_type_alu);
      nir_alu_src *alu_src = (nir_alu_src *)use_src;
      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
         alu_src->swizzle[i] = reswizzle[alu_src->swizzle[i]];
   }

   return true;
}

static bool
opt_shrink_vectors_image_store(nir_builder *b, nir_intrinsic_instr *instr)
{
   enum pipe_format format;
   if (instr->intrinsic == nir_intrinsic_image_deref_store) {
      nir_deref_instr *deref = nir_src_as_deref(instr->src[0]);
      format = nir_deref_instr_get_variable(deref)->data.image.format;
   } else {
      format = nir_intrinsic_format(instr);
   }
   if (format == PIPE_FORMAT_NONE)
      return false;

   unsigned components = util_format_get_nr_components(format);
   if (components >= instr->num_components)
      return false;

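   /* e.g. (hypothetical) a vec4 store to a two-component format such as
    * PIPE_FORMAT_R8G8_UNORM only needs .xy, so trim the stored data down
    * to the format's component count.
    */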
   nir_ssa_def *data = nir_channels(b, instr->src[3].ssa,
                                    BITSET_MASK(components));
   nir_instr_rewrite_src(&instr->instr, &instr->src[3], nir_src_for_ssa(data));
   instr->num_components = components;

   return true;
}

static bool
opt_shrink_vectors_intrinsic(nir_builder *b, nir_intrinsic_instr *instr,
                             bool shrink_image_store)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_uniform:
   case nir_intrinsic_load_ubo:
   case nir_intrinsic_load_input:
   case nir_intrinsic_load_input_vertex:
   case nir_intrinsic_load_per_vertex_input:
   case nir_intrinsic_load_interpolated_input:
   case nir_intrinsic_load_ssbo:
   case nir_intrinsic_load_push_constant:
   case nir_intrinsic_load_constant:
   case nir_intrinsic_load_shared:
   case nir_intrinsic_load_global:
   case nir_intrinsic_load_global_constant:
   case nir_intrinsic_load_kernel_input:
   case nir_intrinsic_load_scratch:
   case nir_intrinsic_store_output:
   case nir_intrinsic_store_per_vertex_output:
   case nir_intrinsic_store_ssbo:
   case nir_intrinsic_store_shared:
   case nir_intrinsic_store_global:
   case nir_intrinsic_store_scratch:
      break;
   case nir_intrinsic_bindless_image_store:
   case nir_intrinsic_image_deref_store:
   case nir_intrinsic_image_store:
      return shrink_image_store && opt_shrink_vectors_image_store(b, instr);
   default:
      return false;
   }

   /* Must be a vectorized intrinsic that we can resize. */
   assert(instr->num_components != 0);

   if (nir_intrinsic_infos[instr->intrinsic].has_dest) {
      /* Loads: trim the dest to the read channels. */
      if (shrink_dest_to_read_mask(&instr->dest.ssa)) {
         instr->num_components = instr->dest.ssa.num_components;
         return true;
      }
   } else {
      /* Stores: trim num_components according to the write mask.  Only
       * trailing unwritten channels can be dropped; anything below the
       * last set bit of the mask is kept.
       */
      unsigned write_mask = nir_intrinsic_write_mask(instr);
      unsigned last_bit = util_last_bit(write_mask);
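      /* e.g. (hypothetical) a store_ssbo of a vec4 value with write mask
       * 0x3 only writes .xy, so the stored source shrinks to 2 components.
       */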
      if (last_bit < instr->num_components && instr->src[0].is_ssa) {
         nir_ssa_def *def = nir_channels(b, instr->src[0].ssa,
                                         BITSET_MASK(last_bit));
         nir_instr_rewrite_src(&instr->instr, &instr->src[0],
                               nir_src_for_ssa(def));
         instr->num_components = last_bit;

         return true;
      }
   }

   return false;
}

static bool
opt_shrink_vectors_load_const(nir_load_const_instr *instr)
{
   return shrink_dest_to_read_mask(&instr->def);
}

static bool
opt_shrink_vectors_ssa_undef(nir_ssa_undef_instr *instr)
{
   return shrink_dest_to_read_mask(&instr->def);
}

static bool
opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr,
                         bool shrink_image_store)
{
   b->cursor = nir_before_instr(instr);

   switch (instr->type) {
   case nir_instr_type_alu:
      return opt_shrink_vectors_alu(b, nir_instr_as_alu(instr));

   case nir_instr_type_intrinsic:
      return opt_shrink_vectors_intrinsic(b, nir_instr_as_intrinsic(instr),
                                          shrink_image_store);

   case nir_instr_type_load_const:
      return opt_shrink_vectors_load_const(nir_instr_as_load_const(instr));

   case nir_instr_type_ssa_undef:
      return opt_shrink_vectors_ssa_undef(nir_instr_as_ssa_undef(instr));

   default:
      return false;
   }
}

bool
nir_opt_shrink_vectors(nir_shader *shader, bool shrink_image_store)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      nir_builder b;
      nir_builder_init(&b, function->impl);

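      /* Walk the shader in reverse so consumers are visited before
       * producers: shrinking a consumer can drop reads of its sources,
       * which lets those defs shrink later in this same pass.
       */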
      nir_foreach_block_reverse(block, function->impl) {
         nir_foreach_instr_reverse(instr, block) {
            progress |= opt_shrink_vectors_instr(&b, instr,
                                                 shrink_image_store);
         }
      }

      if (progress) {
         nir_metadata_preserve(function->impl,
                               nir_metadata_block_index |
                               nir_metadata_dominance);
      } else {
         nir_metadata_preserve(function->impl, nir_metadata_all);
      }
   }

   return progress;
}