/*
 * Copyright 2024 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

/**
 * This pass:
 * - vectorizes lowered input/output loads and stores
 * - vectorizes low and high 16-bit loads and stores by merging them into
 *   a single 32-bit load or store (except load_interpolated_input, which has
 *   to keep bit_size=16)
 * - performs DCE of output stores that overwrite the previous value by writing
 *   into the same slot and component.
 *
 * Vectorization is local to each basic block. No vectorization occurs
 * across basic block boundaries, barriers (TCS outputs only), emits (GS
 * outputs only), or output load <-> output store dependencies.
 *
 * All loads and stores must be scalar. 64-bit loads and stores are forbidden.
 *
 * For each basic block, the time complexity is O(n*log(n)) where n is
 * the number of IO instructions within that block.
 */
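
/* For example, two scalar store_output instructions writing components 0
 * and 1 of the same slot become one store_output of a vec2 with
 * component=0 and write_mask=0x3, and the now-dead scalar stores are
 * removed.
 */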

#include "nir.h"
#include "nir_builder.h"
#include "util/u_dynarray.h"

/* Return 0 if loads/stores are vectorizable. Return 1 or -1 to define
 * an ordering between non-vectorizable instructions. This is used by qsort
 * to sort all gathered instructions into groups of vectorizable instructions.
 */
static int
compare_is_not_vectorizable(nir_intrinsic_instr *a, nir_intrinsic_instr *b)
{
   if (a->intrinsic != b->intrinsic)
      return a->intrinsic > b->intrinsic ? 1 : -1;

   nir_src *offset0 = nir_get_io_offset_src(a);
   nir_src *offset1 = nir_get_io_offset_src(b);
   if (offset0 && offset0->ssa != offset1->ssa)
      return offset0->ssa->index > offset1->ssa->index ? 1 : -1;

   nir_src *array_idx0 = nir_get_io_arrayed_index_src(a);
   nir_src *array_idx1 = nir_get_io_arrayed_index_src(b);
   if (array_idx0 && array_idx0->ssa != array_idx1->ssa)
      return array_idx0->ssa->index > array_idx1->ssa->index ? 1 : -1;

   /* Compare barycentrics or vertex index. */
   if ((a->intrinsic == nir_intrinsic_load_interpolated_input ||
        a->intrinsic == nir_intrinsic_load_input_vertex) &&
       a->src[0].ssa != b->src[0].ssa)
      return a->src[0].ssa->index > b->src[0].ssa->index ? 1 : -1;

   nir_io_semantics sem0 = nir_intrinsic_io_semantics(a);
   nir_io_semantics sem1 = nir_intrinsic_io_semantics(b);
   if (sem0.location != sem1.location)
      return sem0.location > sem1.location ? 1 : -1;

   /* Different mediump flags can't be merged. */
   if (sem0.medium_precision != sem1.medium_precision)
      return sem0.medium_precision > sem1.medium_precision ? 1 : -1;

   /* Don't merge per-view attributes with non-per-view attributes. */
   if (sem0.per_view != sem1.per_view)
      return sem0.per_view > sem1.per_view ? 1 : -1;

   if (sem0.interp_explicit_strict != sem1.interp_explicit_strict)
      return sem0.interp_explicit_strict > sem1.interp_explicit_strict ? 1 : -1;

   /* Only load_interpolated_input can't merge low and high halves of 16-bit
    * loads/stores.
    */
   if (a->intrinsic == nir_intrinsic_load_interpolated_input &&
       sem0.high_16bits != sem1.high_16bits)
      return sem0.high_16bits > sem1.high_16bits ? 1 : -1;

   nir_shader *shader =
      nir_cf_node_get_function(&a->instr.block->cf_node)->function->shader;

   /* Compare the types. */
   if (!(shader->options->io_options & nir_io_vectorizer_ignores_types)) {
      unsigned type_a, type_b;

      if (nir_intrinsic_has_src_type(a)) {
         type_a = nir_intrinsic_src_type(a);
         type_b = nir_intrinsic_src_type(b);
      } else {
         type_a = nir_intrinsic_dest_type(a);
         type_b = nir_intrinsic_dest_type(b);
      }

      if (type_a != type_b)
         return type_a > type_b ? 1 : -1;
   }

   return 0;
}

static int
compare_intr(const void *xa, const void *xb)
{
   nir_intrinsic_instr *a = *(nir_intrinsic_instr **)xa;
   nir_intrinsic_instr *b = *(nir_intrinsic_instr **)xb;

   int comp = compare_is_not_vectorizable(a, b);
   if (comp)
      return comp;

   /* qsort isn't stable. This ensures that later stores aren't moved before earlier stores. */
   return a->instr.index > b->instr.index ? 1 : -1;
}

static void
vectorize_load(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
               bool merge_low_high_16_to_32)
{
   nir_intrinsic_instr *first = NULL;

   /* Find the first instruction where the vectorized load will be
    * inserted.
    */
   for (unsigned i = start; i < start + count; i++) {
      first = !first || chan[i]->instr.index < first->instr.index ?
                 chan[i] : first;
      if (merge_low_high_16_to_32) {
         first = !first || chan[4 + i]->instr.index < first->instr.index ?
                    chan[4 + i] : first;
      }
   }

   /* Insert the vectorized load. */
   nir_builder b = nir_builder_at(nir_before_instr(&first->instr));
   nir_intrinsic_instr *new_intr =
      nir_intrinsic_instr_create(b.shader, first->intrinsic);

   new_intr->num_components = count;
   nir_def_init(&new_intr->instr, &new_intr->def, count,
                merge_low_high_16_to_32 ? 32 : first->def.bit_size);
   memcpy(new_intr->src, first->src,
          nir_intrinsic_infos[first->intrinsic].num_srcs * sizeof(nir_src));
   nir_intrinsic_copy_const_indices(new_intr, first);
   nir_intrinsic_set_component(new_intr, start);

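   /* The merged load returns a whole 32-bit value per channel, so drop the
    * high_16bits flag and switch the 16-bit destination type to the
    * corresponding 32-bit type.
    */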
   if (merge_low_high_16_to_32) {
      nir_io_semantics sem = nir_intrinsic_io_semantics(new_intr);
      sem.high_16bits = 0;
      nir_intrinsic_set_io_semantics(new_intr, sem);
      nir_intrinsic_set_dest_type(new_intr,
                                  (nir_intrinsic_dest_type(new_intr) & ~16) | 32);
   }

   nir_builder_instr_insert(&b, &new_intr->instr);
   nir_def *def = &new_intr->def;

   /* Replace the scalar loads. */
   if (merge_low_high_16_to_32) {
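      /* Each 32-bit channel holds both 16-bit halves: the low half goes to
       * the original low load, the high half to the original high load.
       */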
      for (unsigned i = start; i < start + count; i++) {
         nir_def *comp = nir_channel(&b, def, i - start);

         nir_def_rewrite_uses(&chan[i]->def,
                              nir_unpack_32_2x16_split_x(&b, comp));
         nir_def_rewrite_uses(&chan[4 + i]->def,
                              nir_unpack_32_2x16_split_y(&b, comp));
         nir_instr_remove(&chan[i]->instr);
         nir_instr_remove(&chan[4 + i]->instr);
      }
   } else {
      for (unsigned i = start; i < start + count; i++) {
         nir_def_replace(&chan[i]->def, nir_channel(&b, def, i - start));
      }
   }
}

static void
vectorize_store(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
                bool merge_low_high_16_to_32)
{
   nir_intrinsic_instr *last = NULL;

   /* Find the last instruction where the vectorized store will be
    * inserted.
    */
   for (unsigned i = start; i < start + count; i++) {
      last = !last || chan[i]->instr.index > last->instr.index ?
                chan[i] : last;
      if (merge_low_high_16_to_32) {
         last = !last || chan[4 + i]->instr.index > last->instr.index ?
                   chan[4 + i] : last;
      }
   }

   /* Change the last instruction to a vectorized store. Update xfb first
    * because we need to read some info from "last" before overwriting it.
    */
   if (nir_intrinsic_has_io_xfb(last)) {
      /* 0 = low/full XY channels
       * 1 = low/full ZW channels
       * 2 = high XY channels
       * 3 = high ZW channels
       */
      nir_io_xfb xfb[4] = {{{{0}}}};

      for (unsigned i = start; i < start + count; i++) {
         xfb[i / 2].out[i % 2] =
            (i < 2 ? nir_intrinsic_io_xfb(chan[i]) :
                     nir_intrinsic_io_xfb2(chan[i])).out[i % 2];

         /* Merging low and high 16 bits to 32 bits is not possible with xfb
          * in some cases (and it's not implemented for the cases where it is
          * possible).
          */
         assert(!xfb[i / 2].out[i % 2].num_components ||
                !merge_low_high_16_to_32);
      }

      /* Now vectorize xfb info by merging the individual elements. */
      for (unsigned i = start; i < start + count; i++) {
         /* mediump means that xfb upconverts to 32 bits when writing to
          * memory.
          */
         unsigned xfb_comp_size =
            nir_intrinsic_io_semantics(chan[i]).medium_precision ?
                  32 : chan[i]->src[0].ssa->bit_size;

         for (unsigned j = i + 1; j < start + count; j++) {
            if (xfb[i / 2].out[i % 2].buffer != xfb[j / 2].out[j % 2].buffer ||
                xfb[i / 2].out[i % 2].offset != xfb[j / 2].out[j % 2].offset +
                xfb_comp_size * (j - i))
               break;

            xfb[i / 2].out[i % 2].num_components++;
            memset(&xfb[j / 2].out[j % 2], 0, sizeof(xfb[j / 2].out[j % 2]));
         }
      }

      nir_intrinsic_set_io_xfb(last, xfb[0]);
      nir_intrinsic_set_io_xfb2(last, xfb[1]);
   }

   /* Update gs_streams. */
   unsigned gs_streams = 0;
   for (unsigned i = start; i < start + count; i++) {
      gs_streams |= (nir_intrinsic_io_semantics(chan[i]).gs_streams & 0x3) <<
                    ((i - start) * 2);
   }

   nir_io_semantics sem = nir_intrinsic_io_semantics(last);
   sem.gs_streams = gs_streams;

   /* Update other flags. */
   for (unsigned i = start; i < start + count; i++) {
      if (!nir_intrinsic_io_semantics(chan[i]).no_sysval_output)
         sem.no_sysval_output = 0;
      if (!nir_intrinsic_io_semantics(chan[i]).no_varying)
         sem.no_varying = 0;
      if (nir_intrinsic_io_semantics(chan[i]).invariant)
         sem.invariant = 1;
   }

   if (merge_low_high_16_to_32) {
      /* Update "no" flags for high bits. */
      for (unsigned i = start; i < start + count; i++) {
         if (!nir_intrinsic_io_semantics(chan[4 + i]).no_sysval_output)
            sem.no_sysval_output = 0;
         if (!nir_intrinsic_io_semantics(chan[4 + i]).no_varying)
            sem.no_varying = 0;
         if (nir_intrinsic_io_semantics(chan[4 + i]).invariant)
            sem.invariant = 1;
      }

      /* Update the type. */
      sem.high_16bits = 0;
      nir_intrinsic_set_src_type(last,
                                 (nir_intrinsic_src_type(last) & ~16) | 32);
   }

   /* TODO: Merge names? */

   /* Update the rest. */
   nir_intrinsic_set_io_semantics(last, sem);
   nir_intrinsic_set_component(last, start);
   nir_intrinsic_set_write_mask(last, BITFIELD_MASK(count));
   last->num_components = count;

   nir_builder b = nir_builder_at(nir_before_instr(&last->instr));

   /* Replace the stored scalar with the vector. */
   if (merge_low_high_16_to_32) {
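      /* Pack each low/high 16-bit pair into one 32-bit component first. */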
      nir_def *value[4];
      for (unsigned i = start; i < start + count; i++) {
         value[i] = nir_pack_32_2x16_split(&b, chan[i]->src[0].ssa,
                                           chan[4 + i]->src[0].ssa);
      }

      nir_src_rewrite(&last->src[0], nir_vec(&b, &value[start], count));
   } else {
      nir_def *value[8];
      for (unsigned i = start; i < start + count; i++)
         value[i] = chan[i]->src[0].ssa;

      nir_src_rewrite(&last->src[0], nir_vec(&b, &value[start], count));
   }

   /* Remove the scalar stores. */
   for (unsigned i = start; i < start + count; i++) {
      if (chan[i] != last)
         nir_instr_remove(&chan[i]->instr);
      if (merge_low_high_16_to_32 && chan[4 + i] != last)
         nir_instr_remove(&chan[4 + i]->instr);
   }
}

/* Vectorize the scalar instructions of one slot. chan[8] holds the channels
 * (the last 4 are the high 16-bit channels).
 */
static bool
vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
{
   bool progress = false;

   /* First, merge low and high 16-bit halves into 32 bits separately when
    * possible. Then vectorize what's left.
    */
   for (int merge_low_high_16_to_32 = 1; merge_low_high_16_to_32 >= 0;
        merge_low_high_16_to_32--) {
      unsigned scan_mask;

      if (merge_low_high_16_to_32) {
         /* Get the subset of the mask where both low and high bits are set. */
         scan_mask = 0;
         for (unsigned i = 0; i < 4; i++) {
            unsigned low_high_bits = BITFIELD_BIT(i) | BITFIELD_BIT(i + 4);

            if ((mask & low_high_bits) == low_high_bits) {
               /* Merging low and high 16 bits to 32 bits is not possible
                * with xfb in some cases (and it's not implemented for the
                * cases where it is possible).
                */
               if (nir_intrinsic_has_io_xfb(chan[i])) {
                  unsigned hi = i + 4;

                  if ((i < 2 ? nir_intrinsic_io_xfb(chan[i])
                             : nir_intrinsic_io_xfb2(chan[i])).out[i % 2].num_components ||
                      (i < 2 ? nir_intrinsic_io_xfb(chan[hi])
                             : nir_intrinsic_io_xfb2(chan[hi])).out[i % 2].num_components)
                     continue;
               }

               /* The GS stream must be the same for both halves. */
               if ((nir_intrinsic_io_semantics(chan[i]).gs_streams & 0x3) !=
                   (nir_intrinsic_io_semantics(chan[4 + i]).gs_streams & 0x3))
                  continue;

               scan_mask |= BITFIELD_BIT(i);
               mask &= ~low_high_bits;
            }
         }
      } else {
         scan_mask = mask;
      }

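      /* Each run of consecutive set bits becomes one vectorized load/store. */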
      while (scan_mask) {
         int start, count;

         u_bit_scan_consecutive_range(&scan_mask, &start, &count);

         if (count == 1 && !merge_low_high_16_to_32)
            continue; /* There is nothing to vectorize. */

         bool is_load = nir_intrinsic_infos[chan[start]->intrinsic].has_dest;

         if (is_load)
            vectorize_load(chan, start, count, merge_low_high_16_to_32);
         else
            vectorize_store(chan, start, count, merge_low_high_16_to_32);

         progress = true;
      }
   }

   return progress;
}

static bool
vectorize_batch(struct util_dynarray *io_instructions)
{
   unsigned num_instr = util_dynarray_num_elements(io_instructions, void *);

   /* We need at least 2 instructions to have something to do. */
   if (num_instr <= 1) {
      /* Clear the array. The next block will reuse it. */
      util_dynarray_clear(io_instructions);
      return false;
   }

   /* The instructions are sorted such that groups of vectorizable
    * instructions are next to each other. Multiple incompatible
    * groups of vectorizable instructions can occur in this array.
    * The reason why 2 groups would be incompatible is that they
    * could have a different intrinsic, indirect index, array index,
    * vertex index, barycentrics, or location. Each group is vectorized
    * separately.
    *
    * This reorders instructions in the array, but not in the shader.
    */
   qsort(io_instructions->data, num_instr, sizeof(void*), compare_intr);

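   /* chan[0..3] hold the low/full channels x..w of the current group and
    * chan[4..7] hold the corresponding high 16-bit halves.
    */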
   nir_intrinsic_instr *chan[8] = {0}, *prev = NULL;
   unsigned chan_mask = 0;
   bool progress = false;

   /* Vectorize all groups.
    *
    * The channels for each group are gathered. If 2 stores overwrite
    * the same channel, the earlier store is DCE'd here.
    */
   util_dynarray_foreach(io_instructions, nir_intrinsic_instr *, intr) {
      /* If the next instruction is not vectorizable, vectorize what
       * we have gathered so far.
       */
      if (prev && compare_is_not_vectorizable(prev, *intr)) {
         /* We need at least 2 instructions to have something to do. */
         if (util_bitcount(chan_mask) > 1)
            progress |= vectorize_slot(chan, chan_mask);

         prev = NULL;
         memset(chan, 0, sizeof(chan));
         chan_mask = 0;
      }

      /* This performs DCE of output stores because the previous value
       * is being overwritten.
       */
      unsigned index = nir_intrinsic_io_semantics(*intr).high_16bits * 4 +
                       nir_intrinsic_component(*intr);
      bool is_store = !nir_intrinsic_infos[(*intr)->intrinsic].has_dest;
      if (is_store && chan[index])
         nir_instr_remove(&chan[index]->instr);

      /* Gather the channel. */
      chan[index] = *intr;
      prev = *intr;
      chan_mask |= BITFIELD_BIT(index);
   }

   /* Vectorize the last group. */
   if (prev && util_bitcount(chan_mask) > 1)
      progress |= vectorize_slot(chan, chan_mask);

   /* Clear the array. The next block will reuse it. */
   util_dynarray_clear(io_instructions);
   return progress;
}

bool
nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
{
   assert(!(modes & ~(nir_var_shader_in | nir_var_shader_out)));

   if (shader->info.stage == MESA_SHADER_FRAGMENT &&
       shader->options->io_options & nir_io_prefer_scalar_fs_inputs)
      modes &= ~nir_var_shader_in;

   if ((shader->info.stage == MESA_SHADER_TESS_CTRL ||
        shader->info.stage == MESA_SHADER_GEOMETRY) &&
       util_bitcount(modes) == 2) {
      /* When vectorizing TCS and GS IO, inputs can ignore barriers and emits,
       * but that is only done when outputs are ignored, so vectorize them
       * separately.
       */
      bool progress_in = nir_opt_vectorize_io(shader, nir_var_shader_in);
      bool progress_out = nir_opt_vectorize_io(shader, nir_var_shader_out);
      return progress_in || progress_out;
   }

   /* Initialize dynamic arrays. */
   struct util_dynarray io_instructions;
   util_dynarray_init(&io_instructions, NULL);
   bool global_progress = false;

   nir_foreach_function_impl(impl, shader) {
      bool progress = false;
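      /* instr.index gives each instruction its position within the impl. It
       * keeps the qsort ordering deterministic and is used to pick the
       * insertion point for vectorized loads and stores.
       */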
      nir_metadata_require(impl, nir_metadata_instr_index);

      nir_foreach_block(block, impl) {
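         /* One bit per output channel: location * 8 + high_16bits * 4 +
          * component. These track output loads and stores to detect
          * load <-> store dependencies, which end the current batch.
          */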
         BITSET_DECLARE(has_output_loads, NUM_TOTAL_VARYING_SLOTS * 8);
         BITSET_DECLARE(has_output_stores, NUM_TOTAL_VARYING_SLOTS * 8);
         BITSET_ZERO(has_output_loads);
         BITSET_ZERO(has_output_stores);

         /* Gather load/store intrinsics within the block. */
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            bool is_load = nir_intrinsic_infos[intr->intrinsic].has_dest;
            bool is_output = false;
            nir_io_semantics sem = {0};
            unsigned index = 0;

            if (nir_intrinsic_has_io_semantics(intr)) {
               sem = nir_intrinsic_io_semantics(intr);
               assert(sem.location < NUM_TOTAL_VARYING_SLOTS);
               index = sem.location * 8 + sem.high_16bits * 4 +
                       nir_intrinsic_component(intr);
            }

            switch (intr->intrinsic) {
            case nir_intrinsic_load_input:
            case nir_intrinsic_load_per_primitive_input:
            case nir_intrinsic_load_input_vertex:
            case nir_intrinsic_load_interpolated_input:
            case nir_intrinsic_load_per_vertex_input:
               if (!(modes & nir_var_shader_in))
                  continue;
               break;

            case nir_intrinsic_load_output:
            case nir_intrinsic_load_per_vertex_output:
            case nir_intrinsic_load_per_view_output:
            case nir_intrinsic_load_per_primitive_output:
            case nir_intrinsic_store_output:
            case nir_intrinsic_store_per_vertex_output:
            case nir_intrinsic_store_per_view_output:
            case nir_intrinsic_store_per_primitive_output:
               if (!(modes & nir_var_shader_out))
                  continue;

               /* Break the batch if an output load is followed by an output
                * store to the same channel and vice versa.
                */
               if (BITSET_TEST(is_load ? has_output_stores : has_output_loads,
                               index)) {
                  progress |= vectorize_batch(&io_instructions);
                  BITSET_ZERO(has_output_loads);
                  BITSET_ZERO(has_output_stores);
               }
               is_output = true;
               break;

            case nir_intrinsic_barrier:
               /* Don't vectorize across TCS barriers. */
               if (modes & nir_var_shader_out &&
                   nir_intrinsic_memory_modes(intr) & nir_var_shader_out) {
                  progress |= vectorize_batch(&io_instructions);
                  BITSET_ZERO(has_output_loads);
                  BITSET_ZERO(has_output_stores);
               }
               continue;

            case nir_intrinsic_emit_vertex:
               /* Don't vectorize across GS emits. */
               progress |= vectorize_batch(&io_instructions);
               BITSET_ZERO(has_output_loads);
               BITSET_ZERO(has_output_stores);
               continue;

            default:
               continue;
            }

            /* Only scalar 16 and 32-bit instructions are allowed. */
            ASSERTED nir_def *value = is_load ? &intr->def : intr->src[0].ssa;
            assert(value->num_components == 1);
            assert(value->bit_size == 16 || value->bit_size == 32);

            util_dynarray_append(&io_instructions, void *, intr);
            if (is_output)
               BITSET_SET(is_load ? has_output_loads : has_output_stores, index);
         }

         progress |= vectorize_batch(&io_instructions);
      }

      nir_metadata_preserve(impl, progress ? (nir_metadata_block_index |
                                              nir_metadata_dominance) :
                                             nir_metadata_all);
      global_progress |= progress;
   }
   util_dynarray_fini(&io_instructions);

   return global_progress;
}