/*
 * Copyright © 2015 Connor Abbott
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 */

#include "nir.h"
#include "nir_vla.h"
#include "nir_builder.h"
#include "util/u_dynarray.h"

#define HASH(hash, data) XXH32(&data, sizeof(data), hash)

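/* Hash an SSA source. All constant sources hash to the same value (NULL) so
 * that two candidate instructions with different constants can still be
 * paired up and their constants merged into a single vector immediate.
 */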
static uint32_t
hash_src(uint32_t hash, const nir_src *src)
{
   assert(src->is_ssa);
   void *hash_data = nir_src_is_const(*src) ? NULL : src->ssa;

   return HASH(hash, hash_data);
}

static uint32_t
hash_alu_src(uint32_t hash, const nir_alu_src *src,
             uint32_t num_components, uint32_t max_vec)
{
   assert(!src->abs && !src->negate);

   /* hash whether a swizzle accesses elements beyond the maximum
    * vectorization factor:
    * For example accesses to .x and .y are considered different variables
    * compared to accesses to .z and .w for 16-bit vec2.
    */
   uint32_t swizzle = (src->swizzle[0] & ~(max_vec - 1));
   hash = HASH(hash, swizzle);

   return hash_src(hash, &src->src);
}

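/* Hashing callback for the instruction set: hashes an ALU instruction by
 * opcode, destination bit size and sources. pass_flags holds the maximum
 * vectorization width set in vec_instr_set_add_or_rewrite.
 */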
static uint32_t
hash_instr(const void *data)
{
   const nir_instr *instr = (nir_instr *) data;
   assert(instr->type == nir_instr_type_alu);
   nir_alu_instr *alu = nir_instr_as_alu(instr);

   uint32_t hash = HASH(0, alu->op);
   hash = HASH(hash, alu->dest.dest.ssa.bit_size);

   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
      hash = hash_alu_src(hash, &alu->src[i],
                          alu->dest.dest.ssa.num_components,
                          instr->pass_flags);

   return hash;
}

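/* Two sources are combinable if they are the same SSA def, or if both are
 * constants (which can later be merged into a new vector immediate).
 */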
static bool
srcs_equal(const nir_src *src1, const nir_src *src2)
{
   assert(src1->is_ssa);
   assert(src2->is_ssa);

   return src1->ssa == src2->ssa ||
          (nir_src_is_const(*src1) && nir_src_is_const(*src2));
}

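/* Besides the sources themselves matching, the swizzles must address the
 * same max_vec-aligned group of channels (instr_can_rewrite already
 * guarantees that all components of one swizzle fall into a single group).
 */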
static bool
alu_srcs_equal(const nir_alu_src *src1, const nir_alu_src *src2,
               uint32_t max_vec)
{
   assert(!src1->abs);
   assert(!src1->negate);
   assert(!src2->abs);
   assert(!src2->negate);

   uint32_t mask = ~(max_vec - 1);
   if ((src1->swizzle[0] & mask) != (src2->swizzle[0] & mask))
      return false;

   return srcs_equal(&src1->src, &src2->src);
}

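/* Equality callback for the instruction set: two ALU instructions can be
 * combined if they have the same opcode, the same destination bit size and
 * combinable sources.
 */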
static bool
instrs_equal(const void *data1, const void *data2)
{
   const nir_instr *instr1 = (nir_instr *) data1;
   const nir_instr *instr2 = (nir_instr *) data2;
   assert(instr1->type == nir_instr_type_alu);
   assert(instr2->type == nir_instr_type_alu);

   nir_alu_instr *alu1 = nir_instr_as_alu(instr1);
   nir_alu_instr *alu2 = nir_instr_as_alu(instr2);

   if (alu1->op != alu2->op)
      return false;

   if (alu1->dest.dest.ssa.bit_size != alu2->dest.dest.ssa.bit_size)
      return false;

   for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
      if (!alu_srcs_equal(&alu1->src[i], &alu2->src[i], instr1->pass_flags))
         return false;
   }

   return true;
}

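/* Check whether an instruction is a candidate for vectorization: only
 * per-component ALU instructions that are not already fully vectorized
 * (or, in vec2-16bit mode, only scalar 16-bit ALU instructions) qualify.
 */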
static bool
instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
{
   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);

      /* Don't try to vectorize movs. Either they'll be handled by copy
       * prop, or they're actually necessary and trying to vectorize them
       * would result in fighting with copy prop.
       */
      if (alu->op == nir_op_mov)
         return false;

      /* No need to hash instructions which are already vectorized. */
      if (alu->dest.dest.ssa.num_components >= 4)
         return false;

      if (vectorize_16bit &&
          (alu->dest.dest.ssa.num_components >= 2 ||
           alu->dest.dest.ssa.bit_size != 16))
         return false;

      if (nir_op_infos[alu->op].output_size != 0)
         return false;

      for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
         if (nir_op_infos[alu->op].input_sizes[i] != 0)
            return false;

         /* Don't hash instructions whose swizzles already cross the
          * max_components boundary: those are better left to scalarization.
          */
         uint32_t mask = vectorize_16bit ? ~1 : ~3;
         for (unsigned j = 0; j < alu->dest.dest.ssa.num_components; j++) {
            if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
               return false;
         }
      }

      return true;
   }

   /* TODO support phi nodes */
   default:
      break;
   }

   return false;
}

/*
 * Tries to combine two instructions whose corresponding sources are either
 * the same SSA definition (accessed through different components) or both
 * constants, into one vectorized instruction. Note that instr1 should
 * dominate instr2.
 */

static nir_instr *
instr_try_combine(struct nir_shader *nir, struct set *instr_set,
                  nir_instr *instr1, nir_instr *instr2)
{
   assert(instr1->type == nir_instr_type_alu);
   assert(instr2->type == nir_instr_type_alu);
   nir_alu_instr *alu1 = nir_instr_as_alu(instr1);
   nir_alu_instr *alu2 = nir_instr_as_alu(instr2);

   assert(alu1->dest.dest.ssa.bit_size == alu2->dest.dest.ssa.bit_size);
   unsigned alu1_components = alu1->dest.dest.ssa.num_components;
   unsigned alu2_components = alu2->dest.dest.ssa.num_components;
   unsigned total_components = alu1_components + alu2_components;

   if (total_components > 4)
      return NULL;

   if (nir->options->vectorize_vec2_16bit) {
      assert(total_components == 2);
      assert(alu1->dest.dest.ssa.bit_size == 16);
   }

   nir_builder b;
   nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
   b.cursor = nir_after_instr(instr1);

   nir_alu_instr *new_alu = nir_alu_instr_create(b.shader, alu1->op);
   nir_ssa_dest_init(&new_alu->instr, &new_alu->dest.dest,
                     total_components, alu1->dest.dest.ssa.bit_size, NULL);
   new_alu->dest.write_mask = (1 << total_components) - 1;
   new_alu->instr.pass_flags = alu1->instr.pass_flags;

   /* If either channel is exact, we have to preserve it even if it's
    * not optimal for other channels.
    */
   new_alu->exact = alu1->exact || alu2->exact;

   /* If all channels don't wrap, we can say that the whole vector doesn't
    * wrap.
    */
   new_alu->no_signed_wrap = alu1->no_signed_wrap && alu2->no_signed_wrap;
   new_alu->no_unsigned_wrap = alu1->no_unsigned_wrap && alu2->no_unsigned_wrap;

   for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
      /* handle constant merging case */
      if (alu1->src[i].src.ssa != alu2->src[i].src.ssa) {
         nir_const_value *c1 = nir_src_as_const_value(alu1->src[i].src);
         nir_const_value *c2 = nir_src_as_const_value(alu2->src[i].src);
         assert(c1 && c2);
         nir_const_value value[NIR_MAX_VEC_COMPONENTS];
         unsigned bit_size = alu1->src[i].src.ssa->bit_size;

         for (unsigned j = 0; j < total_components; j++) {
            value[j].u64 = j < alu1_components ?
                           c1[alu1->src[i].swizzle[j]].u64 :
                           c2[alu2->src[i].swizzle[j - alu1_components]].u64;
         }
         nir_ssa_def *def = nir_build_imm(&b, total_components, bit_size, value);

         new_alu->src[i].src = nir_src_for_ssa(def);
         for (unsigned j = 0; j < total_components; j++)
            new_alu->src[i].swizzle[j] = j;
         continue;
      }

      new_alu->src[i].src = alu1->src[i].src;

      for (unsigned j = 0; j < alu1_components; j++)
         new_alu->src[i].swizzle[j] = alu1->src[i].swizzle[j];

      for (unsigned j = 0; j < alu2_components; j++) {
         new_alu->src[i].swizzle[j + alu1_components] =
            alu2->src[i].swizzle[j];
      }
   }

   nir_builder_instr_insert(&b, &new_alu->instr);

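   /* Build swizzled views of the combined def: new_alu1 reads back alu1's
    * channels and new_alu2 reads back alu2's channels. They are used for
    * non-ALU users and if-conditions; ALU users are rewritten to read the
    * combined def directly.
    */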
   unsigned swiz[NIR_MAX_VEC_COMPONENTS];
   for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
      swiz[i] = i;
   nir_ssa_def *new_alu1 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz,
                                       alu1_components);

   for (unsigned i = 0; i < alu2_components; i++)
      swiz[i] += alu1_components;
   nir_ssa_def *new_alu2 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz,
                                       alu2_components);

   nir_foreach_use_safe(src, &alu1->dest.dest.ssa) {
      nir_instr *user_instr = src->parent_instr;
      if (user_instr->type == nir_instr_type_alu) {
         /* Check if user is found in the hashset */
         struct set_entry *entry = _mesa_set_search(instr_set, user_instr);

         /* For ALU instructions, rewrite the source directly to avoid a
          * round-trip through copy propagation.
          */
         nir_instr_rewrite_src(user_instr, src,
                               nir_src_for_ssa(&new_alu->dest.dest.ssa));

         /* Rehash user if it was found in the hashset */
         if (entry && entry->key == user_instr) {
            _mesa_set_remove(instr_set, entry);
            _mesa_set_add(instr_set, src->parent_instr);
         }
      } else {
         nir_instr_rewrite_src(user_instr, src, nir_src_for_ssa(new_alu1));
      }
   }

   nir_foreach_if_use_safe(src, &alu1->dest.dest.ssa) {
      nir_if_rewrite_condition(src->parent_if, nir_src_for_ssa(new_alu1));
   }

   assert(nir_ssa_def_is_unused(&alu1->dest.dest.ssa));

   nir_foreach_use_safe(src, &alu2->dest.dest.ssa) {
      if (src->parent_instr->type == nir_instr_type_alu) {
         /* For ALU instructions, rewrite the source directly to avoid a
          * round-trip through copy propagation.
          */

         nir_alu_instr *use = nir_instr_as_alu(src->parent_instr);

         unsigned src_index = 5;
         for (unsigned i = 0; i < nir_op_infos[use->op].num_inputs; i++) {
            if (&use->src[i].src == src) {
               src_index = i;
               break;
            }
         }
         assert(src_index != 5);

         nir_instr_rewrite_src(src->parent_instr, src,
                               nir_src_for_ssa(&new_alu->dest.dest.ssa));

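         /* alu2's channels live after alu1's in the combined def, so shift
          * this use's swizzle accordingly.
          */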
         for (unsigned i = 0;
              i < nir_ssa_alu_instr_src_components(use, src_index); i++) {
            use->src[src_index].swizzle[i] += alu1_components;
         }
      } else {
         nir_instr_rewrite_src(src->parent_instr, src,
                               nir_src_for_ssa(new_alu2));
      }
   }

   nir_foreach_if_use_safe(src, &alu2->dest.dest.ssa) {
      nir_if_rewrite_condition(src->parent_if, nir_src_for_ssa(new_alu2));
   }

   assert(nir_ssa_def_is_unused(&alu2->dest.dest.ssa));

   nir_instr_remove(instr1);
   nir_instr_remove(instr2);

   return &new_alu->instr;
}

static struct set *
vec_instr_set_create(void)
{
   return _mesa_set_create(NULL, hash_instr, instrs_equal);
}

static void
vec_instr_set_destroy(struct set *instr_set)
{
   _mesa_set_destroy(instr_set, NULL);
}

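/* Look for an instruction in the set that can be combined with the given
 * one. If a match is found, the pair is replaced by a vectorized
 * instruction, which is added back to the set if it can be combined
 * further. Returns true if the instruction was combined.
 */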
static bool
vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
                             nir_instr *instr,
                             nir_opt_vectorize_cb filter, void *data)
{
   if (!instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit))
      return false;

   if (filter && !filter(instr, data))
      return false;

   /* Stash the maximum vector width in the instruction's pass flags; it is
    * used when hashing and comparing swizzles.
    */
   instr->pass_flags = nir->options->vectorize_vec2_16bit ? 2 : 4;

   struct set_entry *entry = _mesa_set_search(instr_set, instr);
   if (entry) {
      nir_instr *old_instr = (nir_instr *) entry->key;
      _mesa_set_remove(instr_set, entry);
      nir_instr *new_instr = instr_try_combine(nir, instr_set,
                                               old_instr, instr);
      if (new_instr) {
         if (instr_can_rewrite(new_instr, nir->options->vectorize_vec2_16bit) &&
             (!filter || filter(new_instr, data)))
            _mesa_set_add(instr_set, new_instr);
         return true;
      }
   }

   _mesa_set_add(instr_set, instr);
   return false;
}

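/* Walk the dominator tree depth-first so that a combined instruction always
 * dominates the uses it replaces. Entries added by this block are removed
 * again before returning so that siblings in the dominator tree do not see
 * them.
 */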
static bool
vectorize_block(struct nir_shader *nir, nir_block *block,
                struct set *instr_set,
                nir_opt_vectorize_cb filter, void *data)
{
   bool progress = false;

   nir_foreach_instr_safe(instr, block) {
      if (vec_instr_set_add_or_rewrite(nir, instr_set, instr, filter, data))
         progress = true;
   }

   for (unsigned i = 0; i < block->num_dom_children; i++) {
      nir_block *child = block->dom_children[i];
      progress |= vectorize_block(nir, child, instr_set, filter, data);
   }

   nir_foreach_instr_reverse(instr, block) {
      if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit) &&
          (!filter || filter(instr, data)))
         _mesa_set_remove_key(instr_set, instr);
   }

   return progress;
}

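/* Run vectorization over one function implementation by walking its
 * dominator tree from the start block.
 */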
static bool
nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
                       nir_opt_vectorize_cb filter, void *data)
{
   struct set *instr_set = vec_instr_set_create();

   nir_metadata_require(impl, nir_metadata_dominance);

   bool progress = vectorize_block(nir, nir_start_block(impl), instr_set,
                                   filter, data);

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   vec_instr_set_destroy(instr_set);
   return progress;
}

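/* Combines scalar ALU operations of the same kind on the same sources into
 * vectorized instructions, e.g. two fadds on different components of the
 * same defs become a single two-component fadd.
 */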
bool
nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
                  void *data)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= nir_opt_vectorize_impl(shader, function->impl, filter, data);
   }

   return progress;
}