/*
 * Copyright © 2020 Google LLC
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

/**
 * @file
 *
 * Trims off the unused trailing components of SSA defs.
 *
 * Due to various optimization passes (or frontend implementations,
 * particularly prog_to_nir), we may have instructions that generate vectors
 * whose components are never read by any other instruction. Eliminating
 * unused low components, or channels in the middle of a writemask, can be
 * tricky (it might require adjusting an offset on a load_uniform, for
 * example), but dropping unused trailing components is trivial. For vector
 * ALU results that are only used by other ALU instructions, this pass can
 * also eliminate arbitrary channels and reswizzle the uses.
 *
 * This pass is probably only of use to vector backends -- scalar backends
 * typically get unused def channel trimming from scalarization followed by
 * dead code elimination.
 */
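
/*
 * A minimal sketch of the kind of rewrite this pass performs (illustrative
 * NIR only; the SSA names below are made up):
 *
 *    vec4 32 ssa_2 = fadd ssa_0, ssa_1
 *    vec1 32 ssa_3 = fdot2 ssa_2.xy, ssa_2.xy
 *
 * Only ssa_2.xy is ever read, so the fadd can be shrunk to two components:
 *
 *    vec2 32 ssa_2 = fadd ssa_0.xy, ssa_1.xy
 *    vec1 32 ssa_3 = fdot2 ssa_2.xy, ssa_2.xy
 */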

#include "nir.h"
#include "nir_builder.h"

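/* Shrink the given def to its highest read channel, i.e. only unused
 * trailing channels are dropped.  Defs that feed an intrinsic are left
 * alone, since intrinsics generally expect a fixed number of components.
 */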
static bool
shrink_dest_to_read_mask(nir_ssa_def *def)
{
   /* early out if there's nothing to do. */
   if (def->num_components == 1)
      return false;

   /* don't remove any channels if used by an intrinsic */
   nir_foreach_use(use_src, def) {
      if (use_src->parent_instr->type == nir_instr_type_intrinsic)
         return false;
   }

   unsigned mask = nir_ssa_def_components_read(def);
   int last_bit = util_last_bit(mask);

   /* If nothing was read, leave it up to DCE. */
   if (!mask)
      return false;

   if (def->num_components > last_bit) {
      def->num_components = last_bit;
      return true;
   }

   return false;
}

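/* Shrink an ALU def to the channels that are actually read.  vecN
 * instructions are replaced by a smaller vecN built from only the live
 * sources; other per-component ALU instructions get their write mask and
 * source swizzles compacted.  When channels below the highest read channel
 * are dropped, the swizzles of the (ALU-only) uses are rewritten to match
 * the new layout.
 */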
static bool
opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
{
   nir_ssa_def *def = &instr->dest.dest.ssa;

   /* Nothing to shrink */
   if (def->num_components == 1)
      return false;

   bool is_vec = false;
   switch (instr->op) {
      /* don't use nir_op_is_vec() as not all vector sizes are supported. */
      case nir_op_vec4:
      case nir_op_vec3:
      case nir_op_vec2:
         is_vec = true;
         break;
      default:
         if (nir_op_infos[instr->op].output_size != 0)
            return false;
         break;
   }

   /* don't remove any channels if used by an intrinsic */
   nir_foreach_use(use_src, def) {
      if (use_src->parent_instr->type == nir_instr_type_intrinsic)
         return false;
   }

   unsigned mask = nir_ssa_def_components_read(def);
   unsigned last_bit = util_last_bit(mask);
   unsigned num_components = util_bitcount(mask);

   /* return, if there is nothing to do */
   if (mask == 0 || num_components == def->num_components)
      return false;

   const bool is_bitfield_mask = last_bit == num_components;

   if (is_vec) {
      /* replace vecN with smaller version */
      nir_ssa_def *srcs[NIR_MAX_VEC_COMPONENTS] = { 0 };
      unsigned index = 0;
      for (int i = 0; i < last_bit; i++) {
         if ((mask >> i) & 0x1)
            srcs[index++] = nir_ssa_for_alu_src(b, instr, i);
      }
      assert(index == num_components);
      nir_ssa_def *new_vec = nir_vec(b, srcs, num_components);
      nir_ssa_def_rewrite_uses(def, new_vec);
      def = new_vec;
   }

   if (is_bitfield_mask) {
      /* just reduce the number of components and return */
      def->num_components = num_components;
      instr->dest.write_mask = mask;
      return true;
   }

   if (!is_vec) {
      /* update sources */
      for (int i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
         unsigned index = 0;
         for (int j = 0; j < last_bit; j++) {
            if ((mask >> j) & 0x1)
               instr->src[i].swizzle[index++] = instr->src[i].swizzle[j];
         }
         assert(index == num_components);
      }

      /* update dest */
      def->num_components = num_components;
      instr->dest.write_mask = BITFIELD_MASK(num_components);
   }

   /* compute new dest swizzles */
   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned index = 0;
   for (int i = 0; i < last_bit; i++) {
      if ((mask >> i) & 0x1)
         reswizzle[i] = index++;
   }
   assert(index == num_components);

   /* update uses */
   nir_foreach_use(use_src, def) {
      assert(use_src->parent_instr->type == nir_instr_type_alu);
      nir_alu_src *alu_src = (nir_alu_src*)use_src;
      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
         alu_src->swizzle[i] = reswizzle[alu_src->swizzle[i]];
   }

   return true;
}

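/* For image stores with a known format, trim the stored data (src[3]) down
 * to the number of components the format can actually hold.
 */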
static bool
opt_shrink_vectors_image_store(nir_builder *b, nir_intrinsic_instr *instr)
{
   enum pipe_format format;
   if (instr->intrinsic == nir_intrinsic_image_deref_store) {
      nir_deref_instr *deref = nir_src_as_deref(instr->src[0]);
      format = nir_deref_instr_get_variable(deref)->data.image.format;
   } else {
      format = nir_intrinsic_format(instr);
   }
   if (format == PIPE_FORMAT_NONE)
      return false;

   unsigned components = util_format_get_nr_components(format);
   if (components >= instr->num_components)
      return false;

   nir_ssa_def *data = nir_channels(b, instr->src[3].ssa, BITSET_MASK(components));
   nir_instr_rewrite_src(&instr->instr, &instr->src[3], nir_src_for_ssa(data));
   instr->num_components = components;

   return true;
}

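/* Shrink the recognized vectorized load/store intrinsics: loads get their
 * destination trimmed to the channels that are read, stores get their value
 * source and num_components trimmed to the last channel in the write mask.
 */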
static bool
opt_shrink_vectors_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, bool shrink_image_store)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_uniform:
   case nir_intrinsic_load_ubo:
   case nir_intrinsic_load_input:
   case nir_intrinsic_load_input_vertex:
   case nir_intrinsic_load_per_vertex_input:
   case nir_intrinsic_load_interpolated_input:
   case nir_intrinsic_load_ssbo:
   case nir_intrinsic_load_push_constant:
   case nir_intrinsic_load_constant:
   case nir_intrinsic_load_shared:
   case nir_intrinsic_load_global:
   case nir_intrinsic_load_global_constant:
   case nir_intrinsic_load_kernel_input:
   case nir_intrinsic_load_scratch:
   case nir_intrinsic_store_output:
   case nir_intrinsic_store_per_vertex_output:
   case nir_intrinsic_store_ssbo:
   case nir_intrinsic_store_shared:
   case nir_intrinsic_store_global:
   case nir_intrinsic_store_scratch:
      break;
   case nir_intrinsic_bindless_image_store:
   case nir_intrinsic_image_deref_store:
   case nir_intrinsic_image_store:
      return shrink_image_store && opt_shrink_vectors_image_store(b, instr);
   default:
      return false;
   }

   /* Must be a vectorized intrinsic that we can resize. */
   assert(instr->num_components != 0);

   if (nir_intrinsic_infos[instr->intrinsic].has_dest) {
      /* loads: Trim the dest to the used channels */

      if (shrink_dest_to_read_mask(&instr->dest.ssa)) {
         instr->num_components = instr->dest.ssa.num_components;
         return true;
      }
   } else {
      /* Stores: trim the num_components stored according to the write
       * mask.
       */
      unsigned write_mask = nir_intrinsic_write_mask(instr);
      unsigned last_bit = util_last_bit(write_mask);
      if (last_bit < instr->num_components && instr->src[0].is_ssa) {
         nir_ssa_def *def = nir_channels(b, instr->src[0].ssa,
                                         BITSET_MASK(last_bit));
         nir_instr_rewrite_src(&instr->instr,
                               &instr->src[0],
                               nir_src_for_ssa(def));
         instr->num_components = last_bit;

         return true;
      }
   }

   return false;
}

static bool
opt_shrink_vectors_load_const(nir_load_const_instr *instr)
{
   return shrink_dest_to_read_mask(&instr->def);
}

static bool
opt_shrink_vectors_ssa_undef(nir_ssa_undef_instr *instr)
{
   return shrink_dest_to_read_mask(&instr->def);
}

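/* Per-instruction dispatch to the handlers above. */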
static bool
opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr, bool shrink_image_store)
{
   b->cursor = nir_before_instr(instr);

   switch (instr->type) {
   case nir_instr_type_alu:
      return opt_shrink_vectors_alu(b, nir_instr_as_alu(instr));

   case nir_instr_type_intrinsic:
      return opt_shrink_vectors_intrinsic(b, nir_instr_as_intrinsic(instr), shrink_image_store);

   case nir_instr_type_load_const:
      return opt_shrink_vectors_load_const(nir_instr_as_load_const(instr));

   case nir_instr_type_ssa_undef:
      return opt_shrink_vectors_ssa_undef(nir_instr_as_ssa_undef(instr));

   default:
      return false;
   }

   return true;
}

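/* Entry point.  A backend would typically run this from its optimization
 * loop, interleaved with copy propagation and DCE so the leftovers of the
 * shrinking get cleaned up; roughly (a sketch using the usual NIR_PASS
 * idiom, the surrounding pass ordering is illustrative only):
 *
 *    bool progress;
 *    do {
 *       progress = false;
 *       NIR_PASS(progress, shader, nir_copy_prop);
 *       NIR_PASS(progress, shader, nir_opt_dce);
 *       NIR_PASS(progress, shader, nir_opt_shrink_vectors, true);
 *    } while (progress);
 */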
bool
nir_opt_shrink_vectors(nir_shader *shader, bool shrink_image_store)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      nir_builder b;
      nir_builder_init(&b, function->impl);

      nir_foreach_block_reverse(block, function->impl) {
         nir_foreach_instr_reverse(instr, block) {
            progress |= opt_shrink_vectors_instr(&b, instr, shrink_image_store);
         }
      }

      if (progress) {
         nir_metadata_preserve(function->impl,
                               nir_metadata_block_index |
                               nir_metadata_dominance);
      } else {
         nir_metadata_preserve(function->impl, nir_metadata_all);
      }
   }

   return progress;
}