1 /*
2 * Copyright © 2015 Connor Abbott
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 */
24
25 /**
26 * nir_opt_vectorize() aims to vectorize ALU instructions.
27 *
28 * The default vectorization width is 4.
29 * If desired, a callback function which returns the max vectorization width
30 * per instruction can be provided.
31 *
 * The max vectorization width must be a power of 2. The callback may also
 * return 0, which leaves the instruction unvectorized.
33 */
34
35 #include "nir.h"
36 #include "nir_vla.h"
37 #include "nir_builder.h"
38 #include "util/u_dynarray.h"
39
/* Mixes the raw bytes of the lvalue `data` into the running XXH32 hash `hash`. */
#define HASH(hash, data) XXH32(&(data), sizeof(data), (hash))
41
42 static uint32_t
hash_src(uint32_t hash,const nir_src * src)43 hash_src(uint32_t hash, const nir_src *src)
44 {
45 assert(src->is_ssa);
46 void *hash_data = nir_src_is_const(*src) ? NULL : src->ssa;
47
48 return HASH(hash, hash_data);
49 }
50
51 static uint32_t
hash_alu_src(uint32_t hash,const nir_alu_src * src,uint32_t num_components,uint32_t max_vec)52 hash_alu_src(uint32_t hash, const nir_alu_src *src,
53 uint32_t num_components, uint32_t max_vec)
54 {
55 assert(!src->abs && !src->negate);
56
57 /* hash whether a swizzle accesses elements beyond the maximum
58 * vectorization factor:
59 * For example accesses to .x and .y are considered different variables
60 * compared to accesses to .z and .w for 16-bit vec2.
61 */
62 uint32_t swizzle = (src->swizzle[0] & ~(max_vec - 1));
63 hash = HASH(hash, swizzle);
64
65 return hash_src(hash, &src->src);
66 }
67
68 static uint32_t
hash_instr(const void * data)69 hash_instr(const void *data)
70 {
71 const nir_instr *instr = (nir_instr *) data;
72 assert(instr->type == nir_instr_type_alu);
73 nir_alu_instr *alu = nir_instr_as_alu(instr);
74
75 uint32_t hash = HASH(0, alu->op);
76 hash = HASH(hash, alu->dest.dest.ssa.bit_size);
77
78 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
79 hash = hash_alu_src(hash, &alu->src[i],
80 alu->dest.dest.ssa.num_components,
81 instr->pass_flags);
82
83 return hash;
84 }
85
86 static bool
srcs_equal(const nir_src * src1,const nir_src * src2)87 srcs_equal(const nir_src *src1, const nir_src *src2)
88 {
89 assert(src1->is_ssa);
90 assert(src2->is_ssa);
91
92 return src1->ssa == src2->ssa ||
93 (nir_src_is_const(*src1) && nir_src_is_const(*src2));
94 }
95
96 static bool
alu_srcs_equal(const nir_alu_src * src1,const nir_alu_src * src2,uint32_t max_vec)97 alu_srcs_equal(const nir_alu_src *src1, const nir_alu_src *src2,
98 uint32_t max_vec)
99 {
100 assert(!src1->abs);
101 assert(!src1->negate);
102 assert(!src2->abs);
103 assert(!src2->negate);
104
105 uint32_t mask = ~(max_vec - 1);
106 if ((src1->swizzle[0] & mask) != (src2->swizzle[0] & mask))
107 return false;
108
109 return srcs_equal(&src1->src, &src2->src);
110 }
111
112 static bool
instrs_equal(const void * data1,const void * data2)113 instrs_equal(const void *data1, const void *data2)
114 {
115 const nir_instr *instr1 = (nir_instr *) data1;
116 const nir_instr *instr2 = (nir_instr *) data2;
117 assert(instr1->type == nir_instr_type_alu);
118 assert(instr2->type == nir_instr_type_alu);
119
120 nir_alu_instr *alu1 = nir_instr_as_alu(instr1);
121 nir_alu_instr *alu2 = nir_instr_as_alu(instr2);
122
123 if (alu1->op != alu2->op)
124 return false;
125
126 if (alu1->dest.dest.ssa.bit_size != alu2->dest.dest.ssa.bit_size)
127 return false;
128
129 for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
130 if (!alu_srcs_equal(&alu1->src[i], &alu2->src[i], instr1->pass_flags))
131 return false;
132 }
133
134 return true;
135 }
136
137 static bool
instr_can_rewrite(nir_instr * instr)138 instr_can_rewrite(nir_instr *instr)
139 {
140 switch (instr->type) {
141 case nir_instr_type_alu: {
142 nir_alu_instr *alu = nir_instr_as_alu(instr);
143
144 /* Don't try and vectorize mov's. Either they'll be handled by copy
145 * prop, or they're actually necessary and trying to vectorize them
146 * would result in fighting with copy prop.
147 */
148 if (alu->op == nir_op_mov)
149 return false;
150
151 /* no need to hash instructions which are already vectorized */
152 if (alu->dest.dest.ssa.num_components >= instr->pass_flags)
153 return false;
154
155 if (nir_op_infos[alu->op].output_size != 0)
156 return false;
157
158 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
159 if (nir_op_infos[alu->op].input_sizes[i] != 0)
160 return false;
161
162 /* don't hash instructions which are already swizzled
163 * outside of max_components: these should better be scalarized */
164 uint32_t mask = ~(instr->pass_flags - 1);
165 for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
166 if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
167 return false;
168 }
169 }
170
171 return true;
172 }
173
174 /* TODO support phi nodes */
175 default:
176 break;
177 }
178
179 return false;
180 }
181
182 /*
183 * Tries to combine two instructions whose sources are different components of
184 * the same instructions into one vectorized instruction. Note that instr1
185 * should dominate instr2.
186 */
187 static nir_instr *
instr_try_combine(struct set * instr_set,nir_instr * instr1,nir_instr * instr2)188 instr_try_combine(struct set *instr_set, nir_instr *instr1, nir_instr *instr2)
189 {
190 assert(instr1->type == nir_instr_type_alu);
191 assert(instr2->type == nir_instr_type_alu);
192 nir_alu_instr *alu1 = nir_instr_as_alu(instr1);
193 nir_alu_instr *alu2 = nir_instr_as_alu(instr2);
194
195 assert(alu1->dest.dest.ssa.bit_size == alu2->dest.dest.ssa.bit_size);
196 unsigned alu1_components = alu1->dest.dest.ssa.num_components;
197 unsigned alu2_components = alu2->dest.dest.ssa.num_components;
198 unsigned total_components = alu1_components + alu2_components;
199
200 assert(instr1->pass_flags == instr2->pass_flags);
201 if (total_components > instr1->pass_flags)
202 return NULL;
203
204 nir_builder b;
205 nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
206 b.cursor = nir_after_instr(instr1);
207
208 nir_alu_instr *new_alu = nir_alu_instr_create(b.shader, alu1->op);
209 nir_ssa_dest_init(&new_alu->instr, &new_alu->dest.dest,
210 total_components, alu1->dest.dest.ssa.bit_size, NULL);
211 new_alu->dest.write_mask = (1 << total_components) - 1;
212 new_alu->instr.pass_flags = alu1->instr.pass_flags;
213
214 /* If either channel is exact, we have to preserve it even if it's
215 * not optimal for other channels.
216 */
217 new_alu->exact = alu1->exact || alu2->exact;
218
219 /* If all channels don't wrap, we can say that the whole vector doesn't
220 * wrap.
221 */
222 new_alu->no_signed_wrap = alu1->no_signed_wrap && alu2->no_signed_wrap;
223 new_alu->no_unsigned_wrap = alu1->no_unsigned_wrap && alu2->no_unsigned_wrap;
224
225 for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
226 /* handle constant merging case */
227 if (alu1->src[i].src.ssa != alu2->src[i].src.ssa) {
228 nir_const_value *c1 = nir_src_as_const_value(alu1->src[i].src);
229 nir_const_value *c2 = nir_src_as_const_value(alu2->src[i].src);
230 assert(c1 && c2);
231 nir_const_value value[NIR_MAX_VEC_COMPONENTS];
232 unsigned bit_size = alu1->src[i].src.ssa->bit_size;
233
234 for (unsigned j = 0; j < total_components; j++) {
235 value[j].u64 = j < alu1_components ?
236 c1[alu1->src[i].swizzle[j]].u64 :
237 c2[alu2->src[i].swizzle[j - alu1_components]].u64;
238 }
239 nir_ssa_def *def = nir_build_imm(&b, total_components, bit_size, value);
240
241 new_alu->src[i].src = nir_src_for_ssa(def);
242 for (unsigned j = 0; j < total_components; j++)
243 new_alu->src[i].swizzle[j] = j;
244 continue;
245 }
246
247 new_alu->src[i].src = alu1->src[i].src;
248
249 for (unsigned j = 0; j < alu1_components; j++)
250 new_alu->src[i].swizzle[j] = alu1->src[i].swizzle[j];
251
252 for (unsigned j = 0; j < alu2_components; j++) {
253 new_alu->src[i].swizzle[j + alu1_components] =
254 alu2->src[i].swizzle[j];
255 }
256 }
257
258 nir_builder_instr_insert(&b, &new_alu->instr);
259
260 /* update all ALU uses */
261 nir_foreach_use_safe(src, &alu1->dest.dest.ssa) {
262 nir_instr *user_instr = src->parent_instr;
263 if (user_instr->type == nir_instr_type_alu) {
264 /* Check if user is found in the hashset */
265 struct set_entry *entry = _mesa_set_search(instr_set, user_instr);
266
267 /* For ALU instructions, rewrite the source directly to avoid a
268 * round-trip through copy propagation.
269 */
270 nir_instr_rewrite_src(user_instr, src,
271 nir_src_for_ssa(&new_alu->dest.dest.ssa));
272
273 /* Rehash user if it was found in the hashset */
274 if (entry && entry->key == user_instr) {
275 _mesa_set_remove(instr_set, entry);
276 _mesa_set_add(instr_set, user_instr);
277 }
278 }
279 }
280
281 nir_foreach_use_safe(src, &alu2->dest.dest.ssa) {
282 if (src->parent_instr->type == nir_instr_type_alu) {
283 /* For ALU instructions, rewrite the source directly to avoid a
284 * round-trip through copy propagation.
285 */
286 nir_instr_rewrite_src(src->parent_instr, src,
287 nir_src_for_ssa(&new_alu->dest.dest.ssa));
288
289 nir_alu_src *alu_src = container_of(src, nir_alu_src, src);
290 nir_alu_instr *use = nir_instr_as_alu(src->parent_instr);
291 unsigned components = nir_ssa_alu_instr_src_components(use, alu_src - use->src);
292 for (unsigned i = 0; i < components; i++)
293 alu_src->swizzle[i] += alu1_components;
294 }
295 }
296
297 /* update all other uses if there are any */
298 unsigned swiz[NIR_MAX_VEC_COMPONENTS];
299
300 if (!nir_ssa_def_is_unused(&alu1->dest.dest.ssa)) {
301 for (unsigned i = 0; i < alu1_components; i++)
302 swiz[i] = i;
303 nir_ssa_def *new_alu1 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz,
304 alu1_components);
305 nir_ssa_def_rewrite_uses(&alu1->dest.dest.ssa, new_alu1);
306 }
307
308 if (!nir_ssa_def_is_unused(&alu2->dest.dest.ssa)) {
309 for (unsigned i = 0; i < alu2_components; i++)
310 swiz[i] = i + alu1_components;
311 nir_ssa_def *new_alu2 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz,
312 alu2_components);
313 nir_ssa_def_rewrite_uses(&alu2->dest.dest.ssa, new_alu2);
314 }
315
316 nir_instr_remove(instr1);
317 nir_instr_remove(instr2);
318
319 return &new_alu->instr;
320 }
321
322 static struct set *
vec_instr_set_create(void)323 vec_instr_set_create(void)
324 {
325 return _mesa_set_create(NULL, hash_instr, instrs_equal);
326 }
327
328 static void
vec_instr_set_destroy(struct set * instr_set)329 vec_instr_set_destroy(struct set *instr_set)
330 {
331 _mesa_set_destroy(instr_set, NULL);
332 }
333
334 static bool
vec_instr_set_add_or_rewrite(struct set * instr_set,nir_instr * instr,nir_vectorize_cb filter,void * data)335 vec_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr,
336 nir_vectorize_cb filter, void *data)
337 {
338 /* set max vector to instr pass flags: this is used to hash swizzles */
339 instr->pass_flags = filter ? filter(instr, data) : 4;
340 assert(util_is_power_of_two_or_zero(instr->pass_flags));
341
342 if (!instr_can_rewrite(instr))
343 return false;
344
345 struct set_entry *entry = _mesa_set_search(instr_set, instr);
346 if (entry) {
347 nir_instr *old_instr = (nir_instr *) entry->key;
348 _mesa_set_remove(instr_set, entry);
349 nir_instr *new_instr = instr_try_combine(instr_set, old_instr, instr);
350 if (new_instr) {
351 if (instr_can_rewrite(new_instr))
352 _mesa_set_add(instr_set, new_instr);
353 return true;
354 }
355 }
356
357 _mesa_set_add(instr_set, instr);
358 return false;
359 }
360
361 static bool
vectorize_block(nir_block * block,struct set * instr_set,nir_vectorize_cb filter,void * data)362 vectorize_block(nir_block *block, struct set *instr_set,
363 nir_vectorize_cb filter, void *data)
364 {
365 bool progress = false;
366
367 nir_foreach_instr_safe(instr, block) {
368 if (vec_instr_set_add_or_rewrite(instr_set, instr, filter, data))
369 progress = true;
370 }
371
372 for (unsigned i = 0; i < block->num_dom_children; i++) {
373 nir_block *child = block->dom_children[i];
374 progress |= vectorize_block(child, instr_set, filter, data);
375 }
376
377 nir_foreach_instr_reverse(instr, block) {
378 if (instr_can_rewrite(instr))
379 _mesa_set_remove_key(instr_set, instr);
380 }
381
382 return progress;
383 }
384
385 static bool
nir_opt_vectorize_impl(nir_function_impl * impl,nir_vectorize_cb filter,void * data)386 nir_opt_vectorize_impl(nir_function_impl *impl,
387 nir_vectorize_cb filter, void *data)
388 {
389 struct set *instr_set = vec_instr_set_create();
390
391 nir_metadata_require(impl, nir_metadata_dominance);
392
393 bool progress = vectorize_block(nir_start_block(impl), instr_set,
394 filter, data);
395
396 if (progress) {
397 nir_metadata_preserve(impl, nir_metadata_block_index |
398 nir_metadata_dominance);
399 } else {
400 nir_metadata_preserve(impl, nir_metadata_all);
401 }
402
403 vec_instr_set_destroy(instr_set);
404 return progress;
405 }
406
407 bool
nir_opt_vectorize(nir_shader * shader,nir_vectorize_cb filter,void * data)408 nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
409 void *data)
410 {
411 bool progress = false;
412
413 nir_foreach_function(function, shader) {
414 if (function->impl)
415 progress |= nir_opt_vectorize_impl(function->impl, filter, data);
416 }
417
418 return progress;
419 }
420