1 /*
2  * Copyright © 2019 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 /**
25  * Although it's called a load/store "vectorization" pass, this also combines
26  * intersecting and identical loads/stores. It currently supports derefs, ubo,
27  * ssbo, push constant, shared, global, task payload and scratch loads/stores.
28  *
29  * This doesn't handle copy_deref intrinsics and assumes that
30  * nir_lower_alu_to_scalar() has been called and that the IR is free from ALU
31  * modifiers. It also assumes that derefs have explicitly laid out types.
32  *
33  * After vectorization, the backend may want to call nir_lower_alu_to_scalar()
34  * and nir_lower_pack(). This pass also creates cast instructions taking derefs
35  * as a source, which some parts of NIR may not be able to handle well.
36  *
37  * There are a few situations where this doesn't vectorize as well as it could:
38  * - It won't turn four consecutive vec3 loads into 3 vec4 loads.
39  * - It doesn't do global vectorization.
40  * Handling these cases probably wouldn't provide much benefit though.
41  *
42  * This probably doesn't handle big-endian GPUs correctly.
43  */
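/*
 * A minimal usage sketch (illustrative only): the option fields below are the
 * ones this pass reads through ctx->options, and the callback parameters
 * follow the call site in new_bitsize_acceptable(); the 16-byte limit and the
 * alignment policy are made-up driver assumptions.
 *
 *    static bool
 *    mem_vectorize_cb(unsigned align_mul, unsigned align_offset,
 *                     unsigned bit_size, unsigned num_components,
 *                     nir_intrinsic_instr *low, nir_intrinsic_instr *high,
 *                     void *cb_data)
 *    {
 *       unsigned byte_size = bit_size / 8u * num_components;
 *       if (byte_size > 16)
 *          return false;
 *       // only accept merged accesses that stay naturally aligned
 *       return align_mul % byte_size == 0 && align_offset % byte_size == 0;
 *    }
 *
 *    const nir_load_store_vectorize_options opts = {
 *       .modes = nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared,
 *       .callback = mem_vectorize_cb,
 *       .cb_data = NULL,
 *       .robust_modes = 0,
 *    };
 *    bool progress = nir_opt_load_store_vectorize(shader, &opts);
 */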
44 
45 #include "util/u_dynarray.h"
46 #include "nir.h"
47 #include "nir_builder.h"
48 #include "nir_deref.h"
49 #include "nir_worklist.h"
50 
51 #include <stdlib.h>
52 
53 struct intrinsic_info {
54    nir_variable_mode mode; /* 0 if the mode is obtained from the deref. */
55    nir_intrinsic_op op;
56    bool is_atomic;
57    /* Indices into nir_intrinsic::src[] or -1 if not applicable. */
58    int resource_src; /* resource (e.g. from vulkan_resource_index) */
59    int base_src;     /* offset which it loads/stores from */
60    int deref_src;    /* deref which it loads/stores from */
61    int value_src;    /* the data it is storing */
62 };
63 
64 static const struct intrinsic_info *
65 get_info(nir_intrinsic_op op)
66 {
67    switch (op) {
68 #define INFO(mode, op, atomic, res, base, deref, val)                                                             \
69    case nir_intrinsic_##op: {                                                                                     \
70       static const struct intrinsic_info op##_info = { mode, nir_intrinsic_##op, atomic, res, base, deref, val }; \
71       return &op##_info;                                                                                          \
72    }
73 #define LOAD(mode, op, res, base, deref)       INFO(mode, load_##op, false, res, base, deref, -1)
74 #define STORE(mode, op, res, base, deref, val) INFO(mode, store_##op, false, res, base, deref, val)
75 #define ATOMIC(mode, type, res, base, deref, val)         \
76    INFO(mode, type##_atomic, true, res, base, deref, val) \
77    INFO(mode, type##_atomic_swap, true, res, base, deref, val)
78 
79       LOAD(nir_var_mem_push_const, push_constant, -1, 0, -1)
80       LOAD(nir_var_mem_ubo, ubo, 0, 1, -1)
81       LOAD(nir_var_mem_ssbo, ssbo, 0, 1, -1)
82       STORE(nir_var_mem_ssbo, ssbo, 1, 2, -1, 0)
83       LOAD(0, deref, -1, -1, 0)
84       STORE(0, deref, -1, -1, 0, 1)
85       LOAD(nir_var_mem_shared, shared, -1, 0, -1)
86       STORE(nir_var_mem_shared, shared, -1, 1, -1, 0)
87       LOAD(nir_var_mem_global, global, -1, 0, -1)
88       STORE(nir_var_mem_global, global, -1, 1, -1, 0)
89       LOAD(nir_var_mem_global, global_constant, -1, 0, -1)
90       LOAD(nir_var_mem_task_payload, task_payload, -1, 0, -1)
91       STORE(nir_var_mem_task_payload, task_payload, -1, 1, -1, 0)
92       ATOMIC(nir_var_mem_ssbo, ssbo, 0, 1, -1, 2)
93       ATOMIC(0, deref, -1, -1, 0, 1)
94       ATOMIC(nir_var_mem_shared, shared, -1, 0, -1, 1)
95       ATOMIC(nir_var_mem_global, global, -1, 0, -1, 1)
96       ATOMIC(nir_var_mem_task_payload, task_payload, -1, 0, -1, 1)
97       LOAD(nir_var_shader_temp, stack, -1, -1, -1)
98       STORE(nir_var_shader_temp, stack, -1, -1, -1, 0)
99       LOAD(nir_var_shader_temp, scratch, -1, 0, -1)
100       STORE(nir_var_shader_temp, scratch, -1, 1, -1, 0)
101       LOAD(nir_var_mem_ubo, ubo_uniform_block_intel, 0, 1, -1)
102       LOAD(nir_var_mem_ssbo, ssbo_uniform_block_intel, 0, 1, -1)
103       LOAD(nir_var_mem_shared, shared_uniform_block_intel, -1, 0, -1)
104       LOAD(nir_var_mem_global, global_constant_uniform_block_intel, -1, 0, -1)
105    default:
106       break;
107 #undef ATOMIC
108 #undef STORE
109 #undef LOAD
110 #undef INFO
111    }
112    return NULL;
113 }
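/*
 * For reference, one expansion of the table above: LOAD(nir_var_mem_ubo, ubo, 0, 1, -1)
 * becomes INFO(nir_var_mem_ubo, load_ubo, false, 0, 1, -1, -1), i.e.
 *
 *    case nir_intrinsic_load_ubo: {
 *       static const struct intrinsic_info load_ubo_info = {
 *          nir_var_mem_ubo, nir_intrinsic_load_ubo, false, 0, 1, -1, -1
 *       };
 *       return &load_ubo_info;
 *    }
 *
 * so load_ubo takes its resource from src[0] and its offset from src[1], and
 * has no deref or value source.
 */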
114 
115 /*
116  * Information used to compare memory operations.
117  * It canonically represents an offset as:
118  * `offset_defs[0]*offset_defs_mul[0] + offset_defs[1]*offset_defs_mul[1] + ...`
119  * "offset_defs" is sorted in descending order by the ssa definition's index.
120  * "resource" or "var" may be NULL.
121  */
122 struct entry_key {
123    nir_def *resource;
124    nir_variable *var;
125    unsigned offset_def_count;
126    nir_scalar *offset_defs;
127    uint64_t *offset_defs_mul;
128 };
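/*
 * Example: an access whose offset source computes "a * 16 + b * 4 + 20" gets
 * offset_defs = {a, b} (ordered by def index, see add_to_entry_key()) with the
 * matching offset_defs_mul entries 16 and 4. The constant 20 is not part of
 * the key; it is folded into entry::offset below, so accesses that differ only
 * by a constant share a key and can be sorted and compared by offset.
 */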
129 
130 /* Information on a single memory operation. */
131 struct entry {
132    struct list_head head;
133    unsigned index;
134 
135    struct entry_key *key;
136    union {
137       uint64_t offset; /* sign-extended */
138       int64_t offset_signed;
139    };
140    uint32_t align_mul;
141    uint32_t align_offset;
142 
143    nir_instr *instr;
144    nir_intrinsic_instr *intrin;
145    const struct intrinsic_info *info;
146    enum gl_access_qualifier access;
147    bool is_store;
148 
149    nir_deref_instr *deref;
150 };
151 
152 struct vectorize_ctx {
153    nir_shader *shader;
154    const nir_load_store_vectorize_options *options;
155    struct list_head entries[nir_num_variable_modes];
156    struct hash_table *loads[nir_num_variable_modes];
157    struct hash_table *stores[nir_num_variable_modes];
158 };
159 
160 static uint32_t
161 hash_entry_key(const void *key_)
162 {
163    /* this is careful to not include pointers in the hash calculation so that
164     * the order of the hash table walk is deterministic */
165    struct entry_key *key = (struct entry_key *)key_;
166 
167    uint32_t hash = 0;
168    if (key->resource)
169       hash = XXH32(&key->resource->index, sizeof(key->resource->index), hash);
170    if (key->var) {
171       hash = XXH32(&key->var->index, sizeof(key->var->index), hash);
172       unsigned mode = key->var->data.mode;
173       hash = XXH32(&mode, sizeof(mode), hash);
174    }
175 
176    for (unsigned i = 0; i < key->offset_def_count; i++) {
177       hash = XXH32(&key->offset_defs[i].def->index, sizeof(key->offset_defs[i].def->index), hash);
178       hash = XXH32(&key->offset_defs[i].comp, sizeof(key->offset_defs[i].comp), hash);
179    }
180 
181    hash = XXH32(key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t), hash);
182 
183    return hash;
184 }
185 
186 static bool
187 entry_key_equals(const void *a_, const void *b_)
188 {
189    struct entry_key *a = (struct entry_key *)a_;
190    struct entry_key *b = (struct entry_key *)b_;
191 
192    if (a->var != b->var || a->resource != b->resource)
193       return false;
194 
195    if (a->offset_def_count != b->offset_def_count)
196       return false;
197 
198    for (unsigned i = 0; i < a->offset_def_count; i++) {
199       if (!nir_scalar_equal(a->offset_defs[i], b->offset_defs[i]))
200          return false;
201    }
202 
203    size_t offset_def_mul_size = a->offset_def_count * sizeof(uint64_t);
204    if (a->offset_def_count &&
205        memcmp(a->offset_defs_mul, b->offset_defs_mul, offset_def_mul_size))
206       return false;
207 
208    return true;
209 }
210 
211 static void
212 delete_entry_dynarray(struct hash_entry *entry)
213 {
214    struct util_dynarray *arr = (struct util_dynarray *)entry->data;
215    ralloc_free(arr);
216 }
217 
218 static int
219 sort_entries(const void *a_, const void *b_)
220 {
221    struct entry *a = *(struct entry *const *)a_;
222    struct entry *b = *(struct entry *const *)b_;
223 
224    if (a->offset_signed > b->offset_signed)
225       return 1;
226    else if (a->offset_signed < b->offset_signed)
227       return -1;
228    else
229       return 0;
230 }
231 
232 static unsigned
233 get_bit_size(struct entry *entry)
234 {
235    unsigned size = entry->is_store ? entry->intrin->src[entry->info->value_src].ssa->bit_size : entry->intrin->def.bit_size;
236    return size == 1 ? 32u : size;
237 }
238 
239 /* If "def" is from an alu instruction with the opcode "op" and one of its
240  * sources is a constant, update "def" to be the non-constant source, fill "c"
241  * with the constant and return true. */
242 static bool
243 parse_alu(nir_scalar *def, nir_op op, uint64_t *c)
244 {
245    if (!nir_scalar_is_alu(*def) || nir_scalar_alu_op(*def) != op)
246       return false;
247 
248    nir_scalar src0 = nir_scalar_chase_alu_src(*def, 0);
249    nir_scalar src1 = nir_scalar_chase_alu_src(*def, 1);
250    if (op != nir_op_ishl && nir_scalar_is_const(src0)) {
251       *c = nir_scalar_as_uint(src0);
252       *def = src1;
253    } else if (nir_scalar_is_const(src1)) {
254       *c = nir_scalar_as_uint(src1);
255       *def = src0;
256    } else {
257       return false;
258    }
259    return true;
260 }
261 
262 /* Parses an offset expression such as "a * 16 + 4" and "(a * 16 + 4) * 64 + 32". */
263 static void
264 parse_offset(nir_scalar *base, uint64_t *base_mul, uint64_t *offset)
265 {
266    if (nir_scalar_is_const(*base)) {
267       *offset = nir_scalar_as_uint(*base);
268       base->def = NULL;
269       return;
270    }
271 
272    uint64_t mul = 1;
273    uint64_t add = 0;
274    bool progress = false;
275    do {
276       uint64_t mul2 = 1, add2 = 0;
277 
278       progress = parse_alu(base, nir_op_imul, &mul2);
279       mul *= mul2;
280 
281       mul2 = 0;
282       progress |= parse_alu(base, nir_op_ishl, &mul2);
283       mul <<= mul2;
284 
285       progress |= parse_alu(base, nir_op_iadd, &add2);
286       add += add2 * mul;
287 
288       if (nir_scalar_is_alu(*base) && nir_scalar_alu_op(*base) == nir_op_mov) {
289          *base = nir_scalar_chase_alu_src(*base, 0);
290          progress = true;
291       }
292    } while (progress);
293 
294    if (base->def->parent_instr->type == nir_instr_type_intrinsic) {
295       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(base->def->parent_instr);
296       if (intrin->intrinsic == nir_intrinsic_load_vulkan_descriptor)
297          base->def = NULL;
298    }
299 
300    *base_mul = mul;
301    *offset = add;
302 }
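/*
 * Worked example: for "(a * 16 + 4) * 64 + 32" the loop above first peels the
 * outer iadd (add = 32), then the imul by 64 (mul = 64) and the inner iadd
 * (add = 32 + 4 * 64 = 288), then the imul by 16 (mul = 1024), leaving
 * *base = a, *base_mul = 1024 and *offset = 288.
 */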
303 
304 static unsigned
305 type_scalar_size_bytes(const struct glsl_type *type)
306 {
307    assert(glsl_type_is_vector_or_scalar(type) ||
308           glsl_type_is_matrix(type));
309    return glsl_type_is_boolean(type) ? 4u : glsl_get_bit_size(type) / 8u;
310 }
311 
312 static unsigned
313 add_to_entry_key(nir_scalar *offset_defs, uint64_t *offset_defs_mul,
314                  unsigned offset_def_count, nir_scalar def, uint64_t mul)
315 {
316    mul = util_mask_sign_extend(mul, def.def->bit_size);
317 
318    for (unsigned i = 0; i <= offset_def_count; i++) {
319       if (i == offset_def_count || def.def->index > offset_defs[i].def->index) {
320          /* insert before i */
321          memmove(offset_defs + i + 1, offset_defs + i,
322                  (offset_def_count - i) * sizeof(nir_scalar));
323          memmove(offset_defs_mul + i + 1, offset_defs_mul + i,
324                  (offset_def_count - i) * sizeof(uint64_t));
325          offset_defs[i] = def;
326          offset_defs_mul[i] = mul;
327          return 1;
328       } else if (nir_scalar_equal(def, offset_defs[i])) {
329          /* merge with offset_def at i */
330          offset_defs_mul[i] += mul;
331          return 0;
332       }
333    }
334    unreachable("Unreachable.");
335    return 0;
336 }
337 
338 static struct entry_key *
339 create_entry_key_from_deref(void *mem_ctx,
340                             struct vectorize_ctx *ctx,
341                             nir_deref_path *path,
342                             uint64_t *offset_base)
343 {
344    unsigned path_len = 0;
345    while (path->path[path_len])
346       path_len++;
347 
348    nir_scalar offset_defs_stack[32];
349    uint64_t offset_defs_mul_stack[32];
350    nir_scalar *offset_defs = offset_defs_stack;
351    uint64_t *offset_defs_mul = offset_defs_mul_stack;
352    if (path_len > 32) {
353       offset_defs = malloc(path_len * sizeof(nir_scalar));
354       offset_defs_mul = malloc(path_len * sizeof(uint64_t));
355    }
356    unsigned offset_def_count = 0;
357 
358    struct entry_key *key = ralloc(mem_ctx, struct entry_key);
359    key->resource = NULL;
360    key->var = NULL;
361    *offset_base = 0;
362 
363    for (unsigned i = 0; i < path_len; i++) {
364       nir_deref_instr *parent = i ? path->path[i - 1] : NULL;
365       nir_deref_instr *deref = path->path[i];
366 
367       switch (deref->deref_type) {
368       case nir_deref_type_var: {
369          assert(!parent);
370          key->var = deref->var;
371          break;
372       }
373       case nir_deref_type_array:
374       case nir_deref_type_ptr_as_array: {
375          assert(parent);
376          nir_def *index = deref->arr.index.ssa;
377          uint32_t stride = nir_deref_instr_array_stride(deref);
378 
379          nir_scalar base = { .def = index, .comp = 0 };
380          uint64_t offset = 0, base_mul = 1;
381          parse_offset(&base, &base_mul, &offset);
382          offset = util_mask_sign_extend(offset, index->bit_size);
383 
384          *offset_base += offset * stride;
385          if (base.def) {
386             offset_def_count += add_to_entry_key(offset_defs, offset_defs_mul,
387                                                  offset_def_count,
388                                                  base, base_mul * stride);
389          }
390          break;
391       }
392       case nir_deref_type_struct: {
393          assert(parent);
394          int offset = glsl_get_struct_field_offset(parent->type, deref->strct.index);
395          *offset_base += offset;
396          break;
397       }
398       case nir_deref_type_cast: {
399          if (!parent)
400             key->resource = deref->parent.ssa;
401          break;
402       }
403       default:
404          unreachable("Unhandled deref type");
405       }
406    }
407 
408    key->offset_def_count = offset_def_count;
409    key->offset_defs = ralloc_array(mem_ctx, nir_scalar, offset_def_count);
410    key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, offset_def_count);
411    memcpy(key->offset_defs, offset_defs, offset_def_count * sizeof(nir_scalar));
412    memcpy(key->offset_defs_mul, offset_defs_mul, offset_def_count * sizeof(uint64_t));
413 
414    if (offset_defs != offset_defs_stack)
415       free(offset_defs);
416    if (offset_defs_mul != offset_defs_mul_stack)
417       free(offset_defs_mul);
418 
419    return key;
420 }
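/*
 * For example, a deref chain var -> array[i + 1] -> struct field, with an
 * array stride of 64 and a field offset of 16, yields key->var = var, a single
 * offset_def {i} with offset_defs_mul {64}, and *offset_base = 64 + 16 = 80
 * (the constant part of the index is folded in by parse_offset()).
 */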
421 
422 static unsigned
423 parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
424                             nir_scalar base, uint64_t base_mul, uint64_t *offset)
425 {
426    uint64_t new_mul;
427    uint64_t new_offset;
428    parse_offset(&base, &new_mul, &new_offset);
429    *offset += new_offset * base_mul;
430 
431    if (!base.def)
432       return 0;
433 
434    base_mul *= new_mul;
435 
436    assert(left >= 1);
437 
438    if (left >= 2) {
439       if (nir_scalar_is_alu(base) && nir_scalar_alu_op(base) == nir_op_iadd) {
440          nir_scalar src0 = nir_scalar_chase_alu_src(base, 0);
441          nir_scalar src1 = nir_scalar_chase_alu_src(base, 1);
442          unsigned amount = parse_entry_key_from_offset(key, size, left - 1, src0, base_mul, offset);
443          amount += parse_entry_key_from_offset(key, size + amount, left - amount, src1, base_mul, offset);
444          return amount;
445       }
446    }
447 
448    return add_to_entry_key(key->offset_defs, key->offset_defs_mul, size, base, base_mul);
449 }
450 
451 static struct entry_key *
452 create_entry_key_from_offset(void *mem_ctx, nir_def *base, uint64_t base_mul, uint64_t *offset)
453 {
454    struct entry_key *key = ralloc(mem_ctx, struct entry_key);
455    key->resource = NULL;
456    key->var = NULL;
457    if (base) {
458       nir_scalar offset_defs[32];
459       uint64_t offset_defs_mul[32];
460       key->offset_defs = offset_defs;
461       key->offset_defs_mul = offset_defs_mul;
462 
463       nir_scalar scalar = { .def = base, .comp = 0 };
464       key->offset_def_count = parse_entry_key_from_offset(key, 0, 32, scalar, base_mul, offset);
465 
466       key->offset_defs = ralloc_array(mem_ctx, nir_scalar, key->offset_def_count);
467       key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, key->offset_def_count);
468       memcpy(key->offset_defs, offset_defs, key->offset_def_count * sizeof(nir_scalar));
469       memcpy(key->offset_defs_mul, offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
470    } else {
471       key->offset_def_count = 0;
472       key->offset_defs = NULL;
473       key->offset_defs_mul = NULL;
474    }
475    return key;
476 }
477 
478 static nir_variable_mode
479 get_variable_mode(struct entry *entry)
480 {
481    if (entry->info->mode)
482       return entry->info->mode;
483    assert(entry->deref && util_bitcount(entry->deref->modes) == 1);
484    return entry->deref->modes;
485 }
486 
487 static unsigned
488 mode_to_index(nir_variable_mode mode)
489 {
490    assert(util_bitcount(mode) == 1);
491 
492    /* Globals and SSBOs should be tracked together */
493    if (mode == nir_var_mem_global)
494       mode = nir_var_mem_ssbo;
495 
496    return ffs(mode) - 1;
497 }
498 
499 static nir_variable_mode
500 aliasing_modes(nir_variable_mode modes)
501 {
502    /* Global and SSBO can alias */
503    if (modes & (nir_var_mem_ssbo | nir_var_mem_global))
504       modes |= nir_var_mem_ssbo | nir_var_mem_global;
505    return modes;
506 }
507 
508 static void
509 calc_alignment(struct entry *entry)
510 {
511    uint32_t align_mul = 31;
512    for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
513       if (entry->key->offset_defs_mul[i])
514          align_mul = MIN2(align_mul, ffsll(entry->key->offset_defs_mul[i]));
515    }
516 
517    entry->align_mul = 1u << (align_mul - 1);
518    bool has_align = nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
519    if (!has_align || entry->align_mul >= nir_intrinsic_align_mul(entry->intrin)) {
520       entry->align_offset = entry->offset % entry->align_mul;
521    } else {
522       entry->align_mul = nir_intrinsic_align_mul(entry->intrin);
523       entry->align_offset = nir_intrinsic_align_offset(entry->intrin);
524    }
525 }
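/*
 * E.g. offset_defs_mul = {64, 12} and offset = 20 give align_mul = 4 (the
 * lowest set bit across the multipliers, ffsll(12) == 3) and
 * align_offset = 20 % 4 = 0; the intrinsic's own ALIGN_MUL/ALIGN_OFFSET is
 * used instead when it promises a stricter alignment.
 */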
526 
527 static struct entry *
528 create_entry(struct vectorize_ctx *ctx,
529              const struct intrinsic_info *info,
530              nir_intrinsic_instr *intrin)
531 {
532    struct entry *entry = rzalloc(ctx, struct entry);
533    entry->intrin = intrin;
534    entry->instr = &intrin->instr;
535    entry->info = info;
536    entry->is_store = entry->info->value_src >= 0;
537 
538    if (entry->info->deref_src >= 0) {
539       entry->deref = nir_src_as_deref(intrin->src[entry->info->deref_src]);
540       nir_deref_path path;
541       nir_deref_path_init(&path, entry->deref, NULL);
542       entry->key = create_entry_key_from_deref(entry, ctx, &path, &entry->offset);
543       nir_deref_path_finish(&path);
544    } else {
545       nir_def *base = entry->info->base_src >= 0 ? intrin->src[entry->info->base_src].ssa : NULL;
546       uint64_t offset = 0;
547       if (nir_intrinsic_has_base(intrin))
548          offset += nir_intrinsic_base(intrin);
549       entry->key = create_entry_key_from_offset(entry, base, 1, &offset);
550       entry->offset = offset;
551 
552       if (base)
553          entry->offset = util_mask_sign_extend(entry->offset, base->bit_size);
554    }
555 
556    if (entry->info->resource_src >= 0)
557       entry->key->resource = intrin->src[entry->info->resource_src].ssa;
558 
559    if (nir_intrinsic_has_access(intrin))
560       entry->access = nir_intrinsic_access(intrin);
561    else if (entry->key->var)
562       entry->access = entry->key->var->data.access;
563 
564    if (nir_intrinsic_can_reorder(intrin))
565       entry->access |= ACCESS_CAN_REORDER;
566 
567    uint32_t restrict_modes = nir_var_shader_in | nir_var_shader_out;
568    restrict_modes |= nir_var_shader_temp | nir_var_function_temp;
569    restrict_modes |= nir_var_uniform | nir_var_mem_push_const;
570    restrict_modes |= nir_var_system_value | nir_var_mem_shared;
571    restrict_modes |= nir_var_mem_task_payload;
572    if (get_variable_mode(entry) & restrict_modes)
573       entry->access |= ACCESS_RESTRICT;
574 
575    calc_alignment(entry);
576 
577    return entry;
578 }
579 
580 static nir_deref_instr *
581 cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref_instr *deref)
582 {
583    if (glsl_get_components(deref->type) == num_components &&
584        type_scalar_size_bytes(deref->type) * 8u == bit_size)
585       return deref;
586 
587    enum glsl_base_type types[] = {
588       GLSL_TYPE_UINT8, GLSL_TYPE_UINT16, GLSL_TYPE_UINT, GLSL_TYPE_UINT64
589    };
590    enum glsl_base_type base = types[ffs(bit_size / 8u) - 1u];
591    const struct glsl_type *type = glsl_vector_type(base, num_components);
592 
593    if (deref->type == type)
594       return deref;
595 
596    return nir_build_deref_cast(b, &deref->def, deref->modes, type, 0);
597 }
598 
599 /* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
600  * of "low" and "high". */
601 static bool
602 new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
603                        struct entry *low, struct entry *high, unsigned size)
604 {
605    if (size % new_bit_size != 0)
606       return false;
607 
608    unsigned new_num_components = size / new_bit_size;
609    if (!nir_num_components_valid(new_num_components))
610       return false;
611 
612    unsigned high_offset = high->offset_signed - low->offset_signed;
613 
614    /* check nir_extract_bits limitations */
615    unsigned common_bit_size = MIN2(get_bit_size(low), get_bit_size(high));
616    common_bit_size = MIN2(common_bit_size, new_bit_size);
617    if (high_offset > 0)
618       common_bit_size = MIN2(common_bit_size, (1u << (ffs(high_offset * 8) - 1)));
619    if (new_bit_size / common_bit_size > NIR_MAX_VEC_COMPONENTS)
620       return false;
621 
622    if (!ctx->options->callback(low->align_mul,
623                                low->align_offset,
624                                new_bit_size, new_num_components,
625                                low->intrin, high->intrin,
626                                ctx->options->cb_data))
627       return false;
628 
629    if (low->is_store) {
630       unsigned low_size = low->intrin->num_components * get_bit_size(low);
631       unsigned high_size = high->intrin->num_components * get_bit_size(high);
632 
633       if (low_size % new_bit_size != 0)
634          return false;
635       if (high_size % new_bit_size != 0)
636          return false;
637 
638       unsigned write_mask = nir_intrinsic_write_mask(low->intrin);
639       if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(low), new_bit_size))
640          return false;
641 
642       write_mask = nir_intrinsic_write_mask(high->intrin);
643       if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(high), new_bit_size))
644          return false;
645    }
646 
647    return true;
648 }
649 
650 static nir_deref_instr *
651 subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
652 {
653    /* avoid adding another deref to the path */
654    if (deref->deref_type == nir_deref_type_ptr_as_array &&
655        nir_src_is_const(deref->arr.index) &&
656        offset % nir_deref_instr_array_stride(deref) == 0) {
657       unsigned stride = nir_deref_instr_array_stride(deref);
658       nir_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
659                                       deref->def.bit_size);
660       return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
661    }
662 
663    if (deref->deref_type == nir_deref_type_array &&
664        nir_src_is_const(deref->arr.index)) {
665       nir_deref_instr *parent = nir_deref_instr_parent(deref);
666       unsigned stride = glsl_get_explicit_stride(parent->type);
667       if (offset % stride == 0)
668          return nir_build_deref_array_imm(
669             b, parent, nir_src_as_int(deref->arr.index) - offset / stride);
670    }
671 
672    deref = nir_build_deref_cast(b, &deref->def, deref->modes,
673                                 glsl_scalar_type(GLSL_TYPE_UINT8), 1);
674    return nir_build_deref_ptr_as_array(
675       b, deref, nir_imm_intN_t(b, -offset, deref->def.bit_size));
676 }
677 
678 static void
679 vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
680                 struct entry *low, struct entry *high,
681                 struct entry *first, struct entry *second,
682                 unsigned new_bit_size, unsigned new_num_components,
683                 unsigned high_start)
684 {
685    unsigned low_bit_size = get_bit_size(low);
686    unsigned high_bit_size = get_bit_size(high);
687    bool low_bool = low->intrin->def.bit_size == 1;
688    bool high_bool = high->intrin->def.bit_size == 1;
689    nir_def *data = &first->intrin->def;
690 
691    b->cursor = nir_after_instr(first->instr);
692 
693    /* update the load's destination size and extract data for each of the original loads */
694    data->num_components = new_num_components;
695    data->bit_size = new_bit_size;
696 
697    nir_def *low_def = nir_extract_bits(
698       b, &data, 1, 0, low->intrin->num_components, low_bit_size);
699    nir_def *high_def = nir_extract_bits(
700       b, &data, 1, high_start, high->intrin->num_components, high_bit_size);
701 
702    /* convert booleans */
703    low_def = low_bool ? nir_i2b(b, low_def) : nir_mov(b, low_def);
704    high_def = high_bool ? nir_i2b(b, high_def) : nir_mov(b, high_def);
705 
706    /* update uses */
707    if (first == low) {
708       nir_def_rewrite_uses_after(&low->intrin->def, low_def,
709                                  high_def->parent_instr);
710       nir_def_rewrite_uses(&high->intrin->def, high_def);
711    } else {
712       nir_def_rewrite_uses(&low->intrin->def, low_def);
713       nir_def_rewrite_uses_after(&high->intrin->def, high_def,
714                                  high_def->parent_instr);
715    }
716 
717    /* update the intrinsic */
718    first->intrin->num_components = new_num_components;
719 
720    const struct intrinsic_info *info = first->info;
721 
722    /* update the offset */
723    if (first != low && info->base_src >= 0) {
724       /* let nir_opt_algebraic() remove this addition. this doesn't cause many
725        * issues when subtracting 16 from expressions like "(i + 1) * 16" because
726        * nir_opt_algebraic() turns them into "i * 16 + 16" */
727       b->cursor = nir_before_instr(first->instr);
728 
729       nir_def *new_base = first->intrin->src[info->base_src].ssa;
730       new_base = nir_iadd_imm(b, new_base, -(int)(high_start / 8u));
731 
732       nir_src_rewrite(&first->intrin->src[info->base_src], new_base);
733    }
734 
735    /* update the deref */
736    if (info->deref_src >= 0) {
737       b->cursor = nir_before_instr(first->instr);
738 
739       nir_deref_instr *deref = nir_src_as_deref(first->intrin->src[info->deref_src]);
740       if (first != low && high_start != 0)
741          deref = subtract_deref(b, deref, high_start / 8u);
742       first->deref = cast_deref(b, new_num_components, new_bit_size, deref);
743 
744       nir_src_rewrite(&first->intrin->src[info->deref_src],
745                       &first->deref->def);
746    }
747 
748    /* update range and base */
749    if (nir_intrinsic_has_range_base(first->intrin)) {
750       uint32_t low_base = nir_intrinsic_range_base(low->intrin);
751       uint32_t high_base = nir_intrinsic_range_base(high->intrin);
752       uint32_t low_end = low_base + nir_intrinsic_range(low->intrin);
753       uint32_t high_end = high_base + nir_intrinsic_range(high->intrin);
754 
755       nir_intrinsic_set_range_base(first->intrin, low_base);
756       nir_intrinsic_set_range(first->intrin, MAX2(low_end, high_end) - low_base);
757    } else if (nir_intrinsic_has_base(first->intrin) && info->base_src == -1 && info->deref_src == -1) {
758       nir_intrinsic_set_base(first->intrin, nir_intrinsic_base(low->intrin));
759    }
760 
761    first->key = low->key;
762    first->offset = low->offset;
763 
764    first->align_mul = low->align_mul;
765    first->align_offset = low->align_offset;
766 
767    nir_instr_remove(second->instr);
768 }
769 
770 static void
771 vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
772                  struct entry *low, struct entry *high,
773                  struct entry *first, struct entry *second,
774                  unsigned new_bit_size, unsigned new_num_components,
775                  unsigned high_start)
776 {
777    ASSERTED unsigned low_size = low->intrin->num_components * get_bit_size(low);
778    assert(low_size % new_bit_size == 0);
779 
780    b->cursor = nir_before_instr(second->instr);
781 
782    /* get new writemasks */
783    uint32_t low_write_mask = nir_intrinsic_write_mask(low->intrin);
784    uint32_t high_write_mask = nir_intrinsic_write_mask(high->intrin);
785    low_write_mask = nir_component_mask_reinterpret(low_write_mask,
786                                                    get_bit_size(low),
787                                                    new_bit_size);
788    high_write_mask = nir_component_mask_reinterpret(high_write_mask,
789                                                     get_bit_size(high),
790                                                     new_bit_size);
791    high_write_mask <<= high_start / new_bit_size;
792 
793    uint32_t write_mask = low_write_mask | high_write_mask;
794 
795    /* convert booleans */
796    nir_def *low_val = low->intrin->src[low->info->value_src].ssa;
797    nir_def *high_val = high->intrin->src[high->info->value_src].ssa;
798    low_val = low_val->bit_size == 1 ? nir_b2iN(b, low_val, 32) : low_val;
799    high_val = high_val->bit_size == 1 ? nir_b2iN(b, high_val, 32) : high_val;
800 
801    /* combine the data */
802    nir_def *data_channels[NIR_MAX_VEC_COMPONENTS];
803    for (unsigned i = 0; i < new_num_components; i++) {
804       bool set_low = low_write_mask & (1 << i);
805       bool set_high = high_write_mask & (1 << i);
806 
807       if (set_low && (!set_high || low == second)) {
808          unsigned offset = i * new_bit_size;
809          data_channels[i] = nir_extract_bits(b, &low_val, 1, offset, 1, new_bit_size);
810       } else if (set_high) {
811          assert(!set_low || high == second);
812          unsigned offset = i * new_bit_size - high_start;
813          data_channels[i] = nir_extract_bits(b, &high_val, 1, offset, 1, new_bit_size);
814       } else {
815          data_channels[i] = nir_undef(b, 1, new_bit_size);
816       }
817    }
818    nir_def *data = nir_vec(b, data_channels, new_num_components);
819 
820    /* update the intrinsic */
821    nir_intrinsic_set_write_mask(second->intrin, write_mask);
822    second->intrin->num_components = data->num_components;
823 
824    const struct intrinsic_info *info = second->info;
825    assert(info->value_src >= 0);
826    nir_src_rewrite(&second->intrin->src[info->value_src], data);
827 
828    /* update the offset */
829    if (second != low && info->base_src >= 0)
830       nir_src_rewrite(&second->intrin->src[info->base_src],
831                       low->intrin->src[info->base_src].ssa);
832 
833    /* update the deref */
834    if (info->deref_src >= 0) {
835       b->cursor = nir_before_instr(second->instr);
836       second->deref = cast_deref(b, new_num_components, new_bit_size,
837                                  nir_src_as_deref(low->intrin->src[info->deref_src]));
838       nir_src_rewrite(&second->intrin->src[info->deref_src],
839                       &second->deref->def);
840    }
841 
842    /* update base/align */
843    if (second != low && nir_intrinsic_has_base(second->intrin))
844       nir_intrinsic_set_base(second->intrin, nir_intrinsic_base(low->intrin));
845 
846    second->key = low->key;
847    second->offset = low->offset;
848 
849    second->align_mul = low->align_mul;
850    second->align_offset = low->align_offset;
851 
852    list_del(&first->head);
853    nir_instr_remove(first->instr);
854 }
855 
856 /* Returns true if it can prove that "a" and "b" point to different bindings
857  * and either one uses ACCESS_RESTRICT. */
858 static bool
859 bindings_different_restrict(nir_shader *shader, struct entry *a, struct entry *b)
860 {
861    bool different_bindings = false;
862    nir_variable *a_var = NULL, *b_var = NULL;
863    if (a->key->resource && b->key->resource) {
864       nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
865       nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
866       if (!a_res.success || !b_res.success)
867          return false;
868 
869       if (a_res.num_indices != b_res.num_indices ||
870           a_res.desc_set != b_res.desc_set ||
871           a_res.binding != b_res.binding)
872          different_bindings = true;
873 
874       for (unsigned i = 0; i < a_res.num_indices; i++) {
875          if (nir_src_is_const(a_res.indices[i]) && nir_src_is_const(b_res.indices[i]) &&
876              nir_src_as_uint(a_res.indices[i]) != nir_src_as_uint(b_res.indices[i]))
877             different_bindings = true;
878       }
879 
880       if (different_bindings) {
881          a_var = nir_get_binding_variable(shader, a_res);
882          b_var = nir_get_binding_variable(shader, b_res);
883       }
884    } else if (a->key->var && b->key->var) {
885       a_var = a->key->var;
886       b_var = b->key->var;
887       different_bindings = a_var != b_var;
888    } else if (!!a->key->resource != !!b->key->resource) {
889       /* comparing global and ssbo access */
890       different_bindings = true;
891 
892       if (a->key->resource) {
893          nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
894          a_var = nir_get_binding_variable(shader, a_res);
895       }
896 
897       if (b->key->resource) {
898          nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
899          b_var = nir_get_binding_variable(shader, b_res);
900       }
901    } else {
902       return false;
903    }
904 
905    unsigned a_access = a->access | (a_var ? a_var->data.access : 0);
906    unsigned b_access = b->access | (b_var ? b_var->data.access : 0);
907 
908    return different_bindings &&
909           ((a_access | b_access) & ACCESS_RESTRICT);
910 }
911 
912 static int64_t
913 compare_entries(struct entry *a, struct entry *b)
914 {
915    if (!entry_key_equals(a->key, b->key))
916       return INT64_MAX;
917    return b->offset_signed - a->offset_signed;
918 }
919 
920 static bool
921 may_alias(nir_shader *shader, struct entry *a, struct entry *b)
922 {
923    assert(mode_to_index(get_variable_mode(a)) ==
924           mode_to_index(get_variable_mode(b)));
925 
926    if ((a->access | b->access) & ACCESS_CAN_REORDER)
927       return false;
928 
929    /* if the resources/variables are definitively different and at least one
930     * of them uses ACCESS_RESTRICT, we can assume they do not alias. */
931    if (bindings_different_restrict(shader, a, b))
932       return false;
933 
934    /* we can't compare offsets if the resources/variables might be different */
935    if (a->key->var != b->key->var || a->key->resource != b->key->resource)
936       return true;
937 
938    /* use adjacency information */
939    /* TODO: we can look closer at the entry keys */
940    int64_t diff = compare_entries(a, b);
941    if (diff != INT64_MAX) {
942       /* with atomics, intrin->num_components can be 0 */
943       if (diff < 0)
944          return llabs(diff) < MAX2(b->intrin->num_components, 1u) * (get_bit_size(b) / 8u);
945       else
946          return diff < MAX2(a->intrin->num_components, 1u) * (get_bit_size(a) / 8u);
947    }
948 
949    /* TODO: we can use deref information */
950 
951    return true;
952 }
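/*
 * Example of the adjacency check above: with equal keys, a 4-byte access "a"
 * at offset 0 and an access "b" at offset 4 give diff = 4, which is not less
 * than a's 4 bytes, so they are treated as disjoint; with "b" at offset 2 the
 * ranges overlap and may_alias() returns true.
 */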
953 
954 static bool
955 check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
956 {
957    nir_variable_mode mode = get_variable_mode(first);
958    if (mode & (nir_var_uniform | nir_var_system_value |
959                nir_var_mem_push_const | nir_var_mem_ubo))
960       return false;
961 
962    unsigned mode_index = mode_to_index(mode);
963    if (first->is_store) {
964       /* find first entry that aliases "first" */
965       list_for_each_entry_from(struct entry, next, first, &ctx->entries[mode_index], head) {
966          if (next == first)
967             continue;
968          if (next == second)
969             return false;
970          if (may_alias(ctx->shader, first, next))
971             return true;
972       }
973    } else {
974       /* find previous store that aliases this load */
975       list_for_each_entry_from_rev(struct entry, prev, second, &ctx->entries[mode_index], head) {
976          if (prev == second)
977             continue;
978          if (prev == first)
979             return false;
980          if (prev->is_store && may_alias(ctx->shader, second, prev))
981             return true;
982       }
983    }
984 
985    return false;
986 }
987 
988 static uint64_t
989 calc_gcd(uint64_t a, uint64_t b)
990 {
991    while (b != 0) {
992       uint64_t tmp_a = a;
993       a = b;
994       b = tmp_a % b;
995    }
996    return a;
997 }
998 
999 static uint64_t
1000 round_down(uint64_t a, uint64_t b)
1001 {
1002    return a / b * b;
1003 }
1004 
1005 static bool
1006 addition_wraps(uint64_t a, uint64_t b, unsigned bits)
1007 {
1008    uint64_t mask = BITFIELD64_MASK(bits);
1009    return ((a + b) & mask) < (a & mask);
1010 }
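/*
 * E.g. addition_wraps(0xfffffffc, 8, 32) is true: the 32-bit sum is 0x4, which
 * is smaller than the first operand. This is the situation
 * check_for_robustness() below has to rule out.
 */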
1011 
1012 /* Return true if the addition of "low"'s offset and "high_offset" could wrap
1013  * around.
1014  *
1015  * This is to prevent a situation where the hardware considers the high load
1016  * out-of-bounds after vectorization if the low load is out-of-bounds, even if
1017  * the wrap-around from the addition could make the high load in-bounds.
1018  */
1019 static bool
1020 check_for_robustness(struct vectorize_ctx *ctx, struct entry *low, uint64_t high_offset)
1021 {
1022    nir_variable_mode mode = get_variable_mode(low);
1023    if (!(mode & ctx->options->robust_modes))
1024       return false;
1025 
1026    /* First, try to use alignment information in case the application provided some. If the addition
1027     * of the maximum offset of the low load and "high_offset" wraps around, we can't combine the low
1028     * and high loads.
1029     */
1030    uint64_t max_low = round_down(UINT64_MAX, low->align_mul) + low->align_offset;
1031    if (!addition_wraps(max_low, high_offset, 64))
1032       return false;
1033 
1034    /* We can't obtain addition_bits */
1035    if (low->info->base_src < 0)
1036       return true;
1037 
1038    /* Second, use information about the factors from address calculation (offset_defs_mul). These
1039     * are not guaranteed to be power-of-2.
1040     */
1041    uint64_t stride = 0;
1042    for (unsigned i = 0; i < low->key->offset_def_count; i++)
1043       stride = calc_gcd(low->key->offset_defs_mul[i], stride);
1044 
1045    unsigned addition_bits = low->intrin->src[low->info->base_src].ssa->bit_size;
1046    /* low's offset must be a multiple of "stride" plus "low->offset". */
1047    max_low = low->offset;
1048    if (stride)
1049       max_low = round_down(BITFIELD64_MASK(addition_bits), stride) + (low->offset % stride);
1050    return addition_wraps(max_low, high_offset, addition_bits);
1051 }
1052 
1053 static bool
1054 is_strided_vector(const struct glsl_type *type)
1055 {
1056    if (glsl_type_is_vector(type)) {
1057       unsigned explicit_stride = glsl_get_explicit_stride(type);
1058       return explicit_stride != 0 && explicit_stride !=
1059                                         type_scalar_size_bytes(glsl_get_array_element(type));
1060    } else {
1061       return false;
1062    }
1063 }
1064 
1065 static bool
1066 can_vectorize(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
1067 {
1068    if (!(get_variable_mode(first) & ctx->options->modes) ||
1069        !(get_variable_mode(second) & ctx->options->modes))
1070       return false;
1071 
1072    if (check_for_aliasing(ctx, first, second))
1073       return false;
1074 
1075    /* we can only vectorize non-volatile loads/stores of the same type and with
1076     * the same access */
1077    if (first->info != second->info || first->access != second->access ||
1078        (first->access & ACCESS_VOLATILE) || first->info->is_atomic)
1079       return false;
1080 
1081    return true;
1082 }
1083 
1084 static bool
1085 try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
1086               struct entry *low, struct entry *high,
1087               struct entry *first, struct entry *second)
1088 {
1089    if (!can_vectorize(ctx, first, second))
1090       return false;
1091 
1092    uint64_t diff = high->offset_signed - low->offset_signed;
1093    if (check_for_robustness(ctx, low, diff))
1094       return false;
1095 
1096    /* don't attempt to vectorize accesses of row-major matrix columns */
1097    if (first->deref) {
1098       const struct glsl_type *first_type = first->deref->type;
1099       const struct glsl_type *second_type = second->deref->type;
1100       if (is_strided_vector(first_type) || is_strided_vector(second_type))
1101          return false;
1102    }
1103 
1104    /* gather information */
1105    unsigned low_bit_size = get_bit_size(low);
1106    unsigned high_bit_size = get_bit_size(high);
1107    unsigned low_size = low->intrin->num_components * low_bit_size;
1108    unsigned high_size = high->intrin->num_components * high_bit_size;
1109    unsigned new_size = MAX2(diff * 8u + high_size, low_size);
1110 
1111    /* find a good bit size for the new load/store */
1112    unsigned new_bit_size = 0;
1113    if (new_bitsize_acceptable(ctx, low_bit_size, low, high, new_size)) {
1114       new_bit_size = low_bit_size;
1115    } else if (low_bit_size != high_bit_size &&
1116               new_bitsize_acceptable(ctx, high_bit_size, low, high, new_size)) {
1117       new_bit_size = high_bit_size;
1118    } else {
1119       new_bit_size = 64;
1120       for (; new_bit_size >= 8; new_bit_size /= 2) {
1121          /* don't repeat trying out bitsizes */
1122          if (new_bit_size == low_bit_size || new_bit_size == high_bit_size)
1123             continue;
1124          if (new_bitsize_acceptable(ctx, new_bit_size, low, high, new_size))
1125             break;
1126       }
1127       if (new_bit_size < 8)
1128          return false;
1129    }
1130    unsigned new_num_components = new_size / new_bit_size;
1131 
1132    /* vectorize the loads/stores */
1133    nir_builder b = nir_builder_create(impl);
1134 
1135    if (first->is_store)
1136       vectorize_stores(&b, ctx, low, high, first, second,
1137                        new_bit_size, new_num_components, diff * 8u);
1138    else
1139       vectorize_loads(&b, ctx, low, high, first, second,
1140                       new_bit_size, new_num_components, diff * 8u);
1141 
1142    return true;
1143 }
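/*
 * Sizing example: a 2x32-bit load at offset 0 and a 1x32-bit load at offset 8
 * give diff = 8, low_size = 64, high_size = 32 and new_size = 96, so the first
 * acceptable bit size is 32 and the result is a single 3x32-bit load, provided
 * options->callback approves the combination.
 */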
1144 
1145 static bool
1146 try_vectorize_shared2(struct vectorize_ctx *ctx,
1147                       struct entry *low, struct entry *high,
1148                       struct entry *first, struct entry *second)
1149 {
1150    if (!can_vectorize(ctx, first, second) || first->deref)
1151       return false;
1152 
1153    unsigned low_bit_size = get_bit_size(low);
1154    unsigned high_bit_size = get_bit_size(high);
1155    unsigned low_size = low->intrin->num_components * low_bit_size / 8;
1156    unsigned high_size = high->intrin->num_components * high_bit_size / 8;
1157    if ((low_size != 4 && low_size != 8) || (high_size != 4 && high_size != 8))
1158       return false;
1159    if (low_size != high_size)
1160       return false;
1161    if (low->align_mul % low_size || low->align_offset % low_size)
1162       return false;
1163    if (high->align_mul % low_size || high->align_offset % low_size)
1164       return false;
1165 
1166    uint64_t diff = high->offset_signed - low->offset_signed;
1167    bool st64 = diff % (64 * low_size) == 0;
1168    unsigned stride = st64 ? 64 * low_size : low_size;
1169    if (diff % stride || diff > 255 * stride)
1170       return false;
1171 
1172    /* try to avoid creating accesses we can't combine additions/offsets into */
1173    if (high->offset > 255 * stride || (st64 && high->offset % stride))
1174       return false;
1175 
1176    if (first->is_store) {
1177       if (nir_intrinsic_write_mask(low->intrin) != BITFIELD_MASK(low->intrin->num_components))
1178          return false;
1179       if (nir_intrinsic_write_mask(high->intrin) != BITFIELD_MASK(high->intrin->num_components))
1180          return false;
1181    }
1182 
1183    /* vectorize the accesses */
1184    nir_builder b = nir_builder_at(nir_after_instr(first->is_store ? second->instr : first->instr));
1185 
1186    nir_def *offset = first->intrin->src[first->is_store].ssa;
1187    offset = nir_iadd_imm(&b, offset, nir_intrinsic_base(first->intrin));
1188    if (first != low)
1189       offset = nir_iadd_imm(&b, offset, -(int)diff);
1190 
1191    if (first->is_store) {
1192       nir_def *low_val = low->intrin->src[low->info->value_src].ssa;
1193       nir_def *high_val = high->intrin->src[high->info->value_src].ssa;
1194       nir_def *val = nir_vec2(&b, nir_bitcast_vector(&b, low_val, low_size * 8u),
1195                               nir_bitcast_vector(&b, high_val, low_size * 8u));
1196       nir_store_shared2_amd(&b, val, offset, .offset1 = diff / stride, .st64 = st64);
1197    } else {
1198       nir_def *new_def = nir_load_shared2_amd(&b, low_size * 8u, offset, .offset1 = diff / stride,
1199                                               .st64 = st64);
1200       nir_def_rewrite_uses(&low->intrin->def,
1201                            nir_bitcast_vector(&b, nir_channel(&b, new_def, 0), low_bit_size));
1202       nir_def_rewrite_uses(&high->intrin->def,
1203                            nir_bitcast_vector(&b, nir_channel(&b, new_def, 1), high_bit_size));
1204    }
1205 
1206    nir_instr_remove(first->instr);
1207    nir_instr_remove(second->instr);
1208 
1209    return true;
1210 }
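/*
 * E.g. with options->has_shared2_amd set, two suitably aligned 8-byte shared
 * loads whose offsets differ by 512 bytes match the st64 form
 * (512 == 64 * 8), so they become one load_shared2_amd with .offset1 = 1 and
 * .st64 = true rather than a single wide load.
 */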
1211 
1212 static bool
1213 update_align(struct entry *entry)
1214 {
1215    if (nir_intrinsic_has_align_mul(entry->intrin) &&
1216        (entry->align_mul != nir_intrinsic_align_mul(entry->intrin) ||
1217         entry->align_offset != nir_intrinsic_align_offset(entry->intrin))) {
1218       nir_intrinsic_set_align(entry->intrin, entry->align_mul, entry->align_offset);
1219       return true;
1220    }
1221    return false;
1222 }
1223 
1224 static bool
1225 vectorize_sorted_entries(struct vectorize_ctx *ctx, nir_function_impl *impl,
1226                          struct util_dynarray *arr)
1227 {
1228    unsigned num_entries = util_dynarray_num_elements(arr, struct entry *);
1229 
1230    bool progress = false;
1231    for (unsigned first_idx = 0; first_idx < num_entries; first_idx++) {
1232       struct entry *low = *util_dynarray_element(arr, struct entry *, first_idx);
1233       if (!low)
1234          continue;
1235 
1236       for (unsigned second_idx = first_idx + 1; second_idx < num_entries; second_idx++) {
1237          struct entry *high = *util_dynarray_element(arr, struct entry *, second_idx);
1238          if (!high)
1239             continue;
1240 
1241          struct entry *first = low->index < high->index ? low : high;
1242          struct entry *second = low->index < high->index ? high : low;
1243 
1244          uint64_t diff = high->offset_signed - low->offset_signed;
1245          bool separate = diff > get_bit_size(low) / 8u * low->intrin->num_components;
1246          if (separate) {
1247             if (!ctx->options->has_shared2_amd ||
1248                 get_variable_mode(first) != nir_var_mem_shared)
1249                break;
1250 
1251             if (try_vectorize_shared2(ctx, low, high, first, second)) {
1252                low = NULL;
1253                *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
1254                progress = true;
1255                break;
1256             }
1257          } else {
1258             if (try_vectorize(impl, ctx, low, high, first, second)) {
1259                low = low->is_store ? second : first;
1260                *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
1261                progress = true;
1262             }
1263          }
1264       }
1265 
1266       *util_dynarray_element(arr, struct entry *, first_idx) = low;
1267    }
1268 
1269    return progress;
1270 }
1271 
1272 static bool
1273 vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct hash_table *ht)
1274 {
1275    if (!ht)
1276       return false;
1277 
1278    bool progress = false;
1279    hash_table_foreach(ht, entry) {
1280       struct util_dynarray *arr = entry->data;
1281       if (!arr->size)
1282          continue;
1283 
1284       qsort(util_dynarray_begin(arr),
1285             util_dynarray_num_elements(arr, struct entry *),
1286             sizeof(struct entry *), &sort_entries);
1287 
1288       while (vectorize_sorted_entries(ctx, impl, arr))
1289          progress = true;
1290 
1291       util_dynarray_foreach(arr, struct entry *, elem) {
1292          if (*elem)
1293             progress |= update_align(*elem);
1294       }
1295    }
1296 
1297    _mesa_hash_table_clear(ht, delete_entry_dynarray);
1298 
1299    return progress;
1300 }
1301 
1302 static bool
1303 handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *impl, nir_instr *instr)
1304 {
1305    unsigned modes = 0;
1306    bool acquire = true;
1307    bool release = true;
1308    if (instr->type == nir_instr_type_intrinsic) {
1309       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1310       switch (intrin->intrinsic) {
1311       /* prevent speculative loads/stores */
1312       case nir_intrinsic_discard_if:
1313       case nir_intrinsic_discard:
1314       case nir_intrinsic_terminate_if:
1315       case nir_intrinsic_terminate:
1316       case nir_intrinsic_launch_mesh_workgroups:
1317          modes = nir_var_all;
1318          break;
1319       case nir_intrinsic_demote_if:
1320       case nir_intrinsic_demote:
1321          acquire = false;
1322          modes = nir_var_all;
1323          break;
1324       case nir_intrinsic_barrier:
1325          if (nir_intrinsic_memory_scope(intrin) == SCOPE_NONE)
1326             break;
1327 
1328          modes = nir_intrinsic_memory_modes(intrin) & (nir_var_mem_ssbo |
1329                                                        nir_var_mem_shared |
1330                                                        nir_var_mem_global |
1331                                                        nir_var_mem_task_payload);
1332          acquire = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_ACQUIRE;
1333          release = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_RELEASE;
1334          switch (nir_intrinsic_memory_scope(intrin)) {
1335          case SCOPE_INVOCATION:
1336             /* a barrier should never be required for correctness with these scopes */
1337             modes = 0;
1338             break;
1339          default:
1340             break;
1341          }
1342          break;
1343       default:
1344          return false;
1345       }
1346    } else if (instr->type == nir_instr_type_call) {
1347       modes = nir_var_all;
1348    } else {
1349       return false;
1350    }
1351 
1352    while (modes) {
1353       unsigned mode_index = u_bit_scan(&modes);
1354       if ((1 << mode_index) == nir_var_mem_global) {
1355          /* Global should be rolled in with SSBO */
1356          assert(list_is_empty(&ctx->entries[mode_index]));
1357          assert(ctx->loads[mode_index] == NULL);
1358          assert(ctx->stores[mode_index] == NULL);
1359          continue;
1360       }
1361 
1362       if (acquire)
1363          *progress |= vectorize_entries(ctx, impl, ctx->loads[mode_index]);
1364       if (release)
1365          *progress |= vectorize_entries(ctx, impl, ctx->stores[mode_index]);
1366    }
1367 
1368    return true;
1369 }
1370 
1371 static bool
1372 process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *block)
1373 {
1374    bool progress = false;
1375 
1376    for (unsigned i = 0; i < nir_num_variable_modes; i++) {
1377       list_inithead(&ctx->entries[i]);
1378       if (ctx->loads[i])
1379          _mesa_hash_table_clear(ctx->loads[i], delete_entry_dynarray);
1380       if (ctx->stores[i])
1381          _mesa_hash_table_clear(ctx->stores[i], delete_entry_dynarray);
1382    }
1383 
1384    /* create entries */
1385    unsigned next_index = 0;
1386 
1387    nir_foreach_instr_safe(instr, block) {
1388       if (handle_barrier(ctx, &progress, impl, instr))
1389          continue;
1390 
1391       /* gather information */
1392       if (instr->type != nir_instr_type_intrinsic)
1393          continue;
1394       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1395 
1396       const struct intrinsic_info *info = get_info(intrin->intrinsic);
1397       if (!info)
1398          continue;
1399 
1400       nir_variable_mode mode = info->mode;
1401       if (!mode)
1402          mode = nir_src_as_deref(intrin->src[info->deref_src])->modes;
1403       if (!(mode & aliasing_modes(ctx->options->modes)))
1404          continue;
1405       unsigned mode_index = mode_to_index(mode);
1406 
1407       /* create entry */
1408       struct entry *entry = create_entry(ctx, info, intrin);
1409       entry->index = next_index++;
1410 
1411       list_addtail(&entry->head, &ctx->entries[mode_index]);
1412 
1413       /* add the entry to a hash table */
1414 
1415       struct hash_table *adj_ht = NULL;
1416       if (entry->is_store) {
1417          if (!ctx->stores[mode_index])
1418             ctx->stores[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
1419          adj_ht = ctx->stores[mode_index];
1420       } else {
1421          if (!ctx->loads[mode_index])
1422             ctx->loads[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
1423          adj_ht = ctx->loads[mode_index];
1424       }
1425 
1426       uint32_t key_hash = hash_entry_key(entry->key);
1427       struct hash_entry *adj_entry = _mesa_hash_table_search_pre_hashed(adj_ht, key_hash, entry->key);
1428       struct util_dynarray *arr;
1429       if (adj_entry && adj_entry->data) {
1430          arr = (struct util_dynarray *)adj_entry->data;
1431       } else {
1432          arr = ralloc(ctx, struct util_dynarray);
1433          util_dynarray_init(arr, arr);
1434          _mesa_hash_table_insert_pre_hashed(adj_ht, key_hash, entry->key, arr);
1435       }
1436       util_dynarray_append(arr, struct entry *, entry);
1437    }
1438 
1439    /* sort and combine entries */
1440    for (unsigned i = 0; i < nir_num_variable_modes; i++) {
1441       progress |= vectorize_entries(ctx, impl, ctx->loads[i]);
1442       progress |= vectorize_entries(ctx, impl, ctx->stores[i]);
1443    }
1444 
1445    return progress;
1446 }
1447 
1448 bool
1449 nir_opt_load_store_vectorize(nir_shader *shader, const nir_load_store_vectorize_options *options)
1450 {
1451    bool progress = false;
1452 
1453    struct vectorize_ctx *ctx = rzalloc(NULL, struct vectorize_ctx);
1454    ctx->shader = shader;
1455    ctx->options = options;
1456 
1457    nir_shader_index_vars(shader, options->modes);
1458 
1459    nir_foreach_function_impl(impl, shader) {
1460       if (options->modes & nir_var_function_temp)
1461          nir_function_impl_index_vars(impl);
1462 
1463       nir_foreach_block(block, impl)
1464          progress |= process_block(impl, ctx, block);
1465 
1466       nir_metadata_preserve(impl,
1467                             nir_metadata_block_index |
1468                                nir_metadata_dominance |
1469                                nir_metadata_live_defs);
1470    }
1471 
1472    ralloc_free(ctx);
1473    return progress;
1474 }
1475