1 /*
2  * Copyright © 2019 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 /**
25  * Although it's called a load/store "vectorization" pass, this also combines
26  * intersecting and identical loads/stores. It currently supports derefs, ubo,
27  * ssbo and push constant loads/stores.
28  *
29  * This doesn't handle copy_deref intrinsics and assumes that
30  * nir_lower_alu_to_scalar() has been called and that the IR is free from ALU
31  * modifiers. It also assumes that derefs have explicitly laid out types.
32  *
33  * After vectorization, the backend may want to call nir_lower_alu_to_scalar()
34  * and nir_lower_pack(). This also creates cast instructions that take derefs
35  * as a source, and some parts of NIR may not be able to handle that well.
36  *
37  * There are a few situations where this doesn't vectorize as well as it could:
38  * - It won't turn four consecutive vec3 loads into 3 vec4 loads.
39  * - It doesn't do global vectorization.
40  * Handling these cases probably wouldn't provide much benefit though.
41  *
42  * This probably doesn't handle big-endian GPUs correctly.
43  */
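/*
 * For example (illustrative): two 32-bit SSBO loads reading offsets 0 and 4 of
 * the same buffer can be combined into one 2-component 32-bit load at offset
 * 0, with the two original values re-extracted from the vectorized result.
 *
 * Rough driver-side usage sketch (illustrative only; see nir.h for the
 * authoritative option struct and callback typedef). The callback returns
 * true to accept a proposed combined access:
 *
 *    static bool
 *    mem_vectorize_cb(unsigned align_mul, unsigned align_offset,
 *                     unsigned bit_size, unsigned num_components,
 *                     int64_t hole_size, nir_intrinsic_instr *low,
 *                     nir_intrinsic_instr *high, void *cb_data)
 *    {
 *       // e.g. accept combined accesses of at most 128 bits with no holes
 *       return hole_size <= 0 && bit_size * num_components <= 128;
 *    }
 *
 *    const nir_load_store_vectorize_options opts = {
 *       .modes = nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared,
 *       .callback = mem_vectorize_cb,
 *    };
 *    nir_opt_load_store_vectorize(shader, &opts);
 */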
44 
45 #include "util/u_dynarray.h"
46 #include "nir.h"
47 #include "nir_builder.h"
48 #include "nir_deref.h"
49 #include "nir_worklist.h"
50 
51 #include <stdlib.h>
52 
53 struct intrinsic_info {
54    nir_variable_mode mode; /* 0 if the mode is obtained from the deref. */
55    nir_intrinsic_op op;
56    bool is_atomic;
57    /* Indices into nir_intrinsic_instr::src[] or -1 if not applicable. */
58    int resource_src; /* resource (e.g. from vulkan_resource_index) */
59    int base_src;     /* offset which it loads/stores from */
60    int deref_src;    /* deref which it loads/stores from */
61    int value_src;    /* the data it is storing */
62 
63    /* Number of bytes for an offset delta of 1. */
64    unsigned offset_scale;
65 };
66 
67 static const struct intrinsic_info *
68 get_info(nir_intrinsic_op op)
69 {
70    switch (op) {
71 #define INFO(mode, op, atomic, res, base, deref, val, scale)                                                             \
72    case nir_intrinsic_##op: {                                                                                            \
73       static const struct intrinsic_info op##_info = { mode, nir_intrinsic_##op, atomic, res, base, deref, val, scale }; \
74       return &op##_info;                                                                                                 \
75    }
76 #define LOAD(mode, op, res, base, deref, scale)       INFO(mode, load_##op, false, res, base, deref, -1, scale)
77 #define STORE(mode, op, res, base, deref, val, scale) INFO(mode, store_##op, false, res, base, deref, val, scale)
78 #define ATOMIC(mode, type, res, base, deref, val, scale)         \
79    INFO(mode, type##_atomic, true, res, base, deref, val, scale) \
80    INFO(mode, type##_atomic_swap, true, res, base, deref, val, scale)
81 
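      /* Each entry below maps an intrinsic to the fields of struct
       * intrinsic_info. For example, LOAD(nir_var_mem_ssbo, ssbo, 0, 1, -1, 1)
       * records that load_ssbo takes its resource in src[0] and its offset in
       * src[1], has no deref source and uses an offset scale of 1. */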
82       LOAD(nir_var_mem_push_const, push_constant, -1, 0, -1, 1)
83       LOAD(nir_var_mem_ubo, ubo, 0, 1, -1, 1)
84       LOAD(nir_var_mem_ssbo, ssbo, 0, 1, -1, 1)
85       STORE(nir_var_mem_ssbo, ssbo, 1, 2, -1, 0, 1)
86       LOAD(0, deref, -1, -1, 0, 1)
87       STORE(0, deref, -1, -1, 0, 1, 1)
88       LOAD(nir_var_mem_shared, shared, -1, 0, -1, 1)
89       STORE(nir_var_mem_shared, shared, -1, 1, -1, 0, 1)
90       LOAD(nir_var_mem_global, global, -1, 0, -1, 1)
91       STORE(nir_var_mem_global, global, -1, 1, -1, 0, 1)
92       LOAD(nir_var_mem_global, global_constant, -1, 0, -1, 1)
93       LOAD(nir_var_mem_task_payload, task_payload, -1, 0, -1, 1)
94       STORE(nir_var_mem_task_payload, task_payload, -1, 1, -1, 0, 1)
95       ATOMIC(nir_var_mem_ssbo, ssbo, 0, 1, -1, 2, 1)
96       ATOMIC(0, deref, -1, -1, 0, 1, 1)
97       ATOMIC(nir_var_mem_shared, shared, -1, 0, -1, 1, 1)
98       ATOMIC(nir_var_mem_global, global, -1, 0, -1, 1, 1)
99       ATOMIC(nir_var_mem_task_payload, task_payload, -1, 0, -1, 1, 1)
100       LOAD(nir_var_shader_temp, stack, -1, -1, -1, 1)
101       STORE(nir_var_shader_temp, stack, -1, -1, -1, 0, 1)
102       LOAD(nir_var_shader_temp, scratch, -1, 0, -1, 1)
103       STORE(nir_var_shader_temp, scratch, -1, 1, -1, 0, 1)
104       LOAD(nir_var_mem_ubo, ubo_uniform_block_intel, 0, 1, -1, 1)
105       LOAD(nir_var_mem_ssbo, ssbo_uniform_block_intel, 0, 1, -1, 1)
106       LOAD(nir_var_mem_shared, shared_uniform_block_intel, -1, 0, -1, 1)
107       LOAD(nir_var_mem_global, global_constant_uniform_block_intel, -1, 0, -1, 1)
108       INFO(nir_var_mem_ubo, ldc_nv, false, 0, 1, -1, -1, 1)
109       INFO(nir_var_mem_ubo, ldcx_nv, false, 0, 1, -1, -1, 1)
110       LOAD(nir_var_uniform, const_ir3, -1, 0, -1, 4)
111       STORE(nir_var_uniform, const_ir3, -1, -1, -1, 0, 4)
112       INFO(nir_var_mem_shared, shared_append_amd, true, -1, -1, -1, -1, 1)
113       INFO(nir_var_mem_shared, shared_consume_amd, true, -1, -1, -1, -1, 1)
114       LOAD(nir_var_mem_global, smem_amd, 0, 1, -1, 1)
115    default:
116       break;
117 #undef ATOMIC
118 #undef STORE
119 #undef LOAD
120 #undef INFO
121    }
122    return NULL;
123 }
124 
125 /*
126  * Information used to compare memory operations.
127  * It canonically represents an offset as:
128  * `offset_defs[0]*offset_defs_mul[0] + offset_defs[1]*offset_defs_mul[1] + ...`
129  * "offset_defs" is sorted in ascenting order by the ssa definition's index.
130  * "resource" or "var" may be NULL.
131  */
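/* For example (illustrative): an offset computed as "(a * 16 + 4) * 4 + b" is
 * canonicalized to offset_defs = {a, b} with offset_defs_mul = {64, 1}, and
 * the constant part 16 is folded into entry::offset.
 */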
132 struct entry_key {
133    nir_def *resource;
134    nir_variable *var;
135    unsigned offset_def_count;
136    nir_scalar *offset_defs;
137    uint64_t *offset_defs_mul;
138 };
139 
140 /* Information on a single memory operation. */
141 struct entry {
142    struct list_head head;
143    unsigned index;
144 
145    struct entry_key *key;
146    union {
147       uint64_t offset; /* sign-extended */
148       int64_t offset_signed;
149    };
150    uint32_t align_mul;
151    uint32_t align_offset;
152 
153    nir_instr *instr;
154    nir_intrinsic_instr *intrin;
155    unsigned num_components;
156    const struct intrinsic_info *info;
157    enum gl_access_qualifier access;
158    bool is_store;
159 
160    nir_deref_instr *deref;
161 };
162 
163 struct vectorize_ctx {
164    nir_shader *shader;
165    const nir_load_store_vectorize_options *options;
166    struct list_head entries[nir_num_variable_modes];
167    struct hash_table *loads[nir_num_variable_modes];
168    struct hash_table *stores[nir_num_variable_modes];
169 };
170 
171 static uint32_t
172 hash_entry_key(const void *key_)
173 {
174    /* this is careful to not include pointers in the hash calculation so that
175     * the order of the hash table walk is deterministic */
176    struct entry_key *key = (struct entry_key *)key_;
177 
178    uint32_t hash = 0;
179    if (key->resource)
180       hash = XXH32(&key->resource->index, sizeof(key->resource->index), hash);
181    if (key->var) {
182       hash = XXH32(&key->var->index, sizeof(key->var->index), hash);
183       unsigned mode = key->var->data.mode;
184       hash = XXH32(&mode, sizeof(mode), hash);
185    }
186 
187    for (unsigned i = 0; i < key->offset_def_count; i++) {
188       hash = XXH32(&key->offset_defs[i].def->index, sizeof(key->offset_defs[i].def->index), hash);
189       hash = XXH32(&key->offset_defs[i].comp, sizeof(key->offset_defs[i].comp), hash);
190    }
191 
192    hash = XXH32(key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t), hash);
193 
194    return hash;
195 }
196 
197 static bool
198 entry_key_equals(const void *a_, const void *b_)
199 {
200    struct entry_key *a = (struct entry_key *)a_;
201    struct entry_key *b = (struct entry_key *)b_;
202 
203    if (a->var != b->var || a->resource != b->resource)
204       return false;
205 
206    if (a->offset_def_count != b->offset_def_count)
207       return false;
208 
209    for (unsigned i = 0; i < a->offset_def_count; i++) {
210       if (!nir_scalar_equal(a->offset_defs[i], b->offset_defs[i]))
211          return false;
212    }
213 
214    size_t offset_def_mul_size = a->offset_def_count * sizeof(uint64_t);
215    if (a->offset_def_count &&
216        memcmp(a->offset_defs_mul, b->offset_defs_mul, offset_def_mul_size))
217       return false;
218 
219    return true;
220 }
221 
222 static void
223 delete_entry_dynarray(struct hash_entry *entry)
224 {
225    struct util_dynarray *arr = (struct util_dynarray *)entry->data;
226    ralloc_free(arr);
227 }
228 
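/* qsort comparator: orders entries by ascending signed offset. */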
229 static int
230 sort_entries(const void *a_, const void *b_)
231 {
232    struct entry *a = *(struct entry *const *)a_;
233    struct entry *b = *(struct entry *const *)b_;
234 
235    if (a->offset_signed > b->offset_signed)
236       return 1;
237    else if (a->offset_signed < b->offset_signed)
238       return -1;
239    else
240       return 0;
241 }
242 
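/* Bit size of the loaded/stored data; 1-bit booleans are treated as 32-bit. */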
243 static unsigned
244 get_bit_size(struct entry *entry)
245 {
246    unsigned size = entry->is_store ? entry->intrin->src[entry->info->value_src].ssa->bit_size : entry->intrin->def.bit_size;
247    return size == 1 ? 32u : size;
248 }
249 
250 static unsigned
251 get_write_mask(const nir_intrinsic_instr *intrin)
252 {
253    if (nir_intrinsic_has_write_mask(intrin))
254       return nir_intrinsic_write_mask(intrin);
255 
256    const struct intrinsic_info *info = get_info(intrin->intrinsic);
257    assert(info->value_src >= 0);
258    return nir_component_mask(intrin->src[info->value_src].ssa->num_components);
259 }
260 
261 static nir_op
262 get_effective_alu_op(nir_scalar scalar)
263 {
264    nir_op op = nir_scalar_alu_op(scalar);
265 
266    /* amul can always be replaced by imul and we pattern match on the more
267     * general opcode, so return imul for amul.
268     */
269    if (op == nir_op_amul)
270       return nir_op_imul;
271    else
272       return op;
273 }
274 
275 /* If "def" is from an alu instruction with the opcode "op" and one of it's
276  * sources is a constant, update "def" to be the non-constant source, fill "c"
277  * with the constant and return true. */
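/* E.g. if *def is "a * 4", parse_alu(def, nir_op_imul, &c) updates *def to a
 * and sets c to 4. */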
278 static bool
279 parse_alu(nir_scalar *def, nir_op op, uint64_t *c)
280 {
281    if (!nir_scalar_is_alu(*def) || get_effective_alu_op(*def) != op)
282       return false;
283 
284    nir_scalar src0 = nir_scalar_chase_alu_src(*def, 0);
285    nir_scalar src1 = nir_scalar_chase_alu_src(*def, 1);
286    if (op != nir_op_ishl && nir_scalar_is_const(src0)) {
287       *c = nir_scalar_as_uint(src0);
288       *def = src1;
289    } else if (nir_scalar_is_const(src1)) {
290       *c = nir_scalar_as_uint(src1);
291       *def = src0;
292    } else {
293       return false;
294    }
295    return true;
296 }
297 
298 /* Parses an offset expression such as "a * 16 + 4" and "(a * 16 + 4) * 64 + 32". */
299 static void
300 parse_offset(nir_scalar *base, uint64_t *base_mul, uint64_t *offset)
301 {
302    if (nir_scalar_is_const(*base)) {
303       *offset = nir_scalar_as_uint(*base);
304       base->def = NULL;
305       return;
306    }
307 
308    uint64_t mul = 1;
309    uint64_t add = 0;
310    bool progress = false;
311    do {
312       uint64_t mul2 = 1, add2 = 0;
313 
314       progress = parse_alu(base, nir_op_imul, &mul2);
315       mul *= mul2;
316 
317       mul2 = 0;
318       progress |= parse_alu(base, nir_op_ishl, &mul2);
319       mul <<= mul2;
320 
321       progress |= parse_alu(base, nir_op_iadd, &add2);
322       add += add2 * mul;
323 
324       if (nir_scalar_is_alu(*base) && nir_scalar_alu_op(*base) == nir_op_mov) {
325          *base = nir_scalar_chase_alu_src(*base, 0);
326          progress = true;
327       }
328    } while (progress);
329 
330    if (base->def->parent_instr->type == nir_instr_type_intrinsic) {
331       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(base->def->parent_instr);
332       if (intrin->intrinsic == nir_intrinsic_load_vulkan_descriptor)
333          base->def = NULL;
334    }
335 
336    *base_mul = mul;
337    *offset = add;
338 }
339 
340 static unsigned
341 type_scalar_size_bytes(const struct glsl_type *type)
342 {
343    assert(glsl_type_is_vector_or_scalar(type) ||
344           glsl_type_is_matrix(type));
345    return glsl_type_is_boolean(type) ? 4u : glsl_get_bit_size(type) / 8u;
346 }
347 
348 static unsigned
349 add_to_entry_key(nir_scalar *offset_defs, uint64_t *offset_defs_mul,
350                  unsigned offset_def_count, nir_scalar def, uint64_t mul)
351 {
352    mul = util_mask_sign_extend(mul, def.def->bit_size);
353 
354    for (unsigned i = 0; i <= offset_def_count; i++) {
355       if (i == offset_def_count || def.def->index > offset_defs[i].def->index) {
356          /* insert before i */
357          memmove(offset_defs + i + 1, offset_defs + i,
358                  (offset_def_count - i) * sizeof(nir_scalar));
359          memmove(offset_defs_mul + i + 1, offset_defs_mul + i,
360                  (offset_def_count - i) * sizeof(uint64_t));
361          offset_defs[i] = def;
362          offset_defs_mul[i] = mul;
363          return 1;
364       } else if (nir_scalar_equal(def, offset_defs[i])) {
365          /* merge with offset_def at i */
366          offset_defs_mul[i] += mul;
367          return 0;
368       }
369    }
370    unreachable("Unreachable.");
371    return 0;
372 }
373 
374 static struct entry_key *
375 create_entry_key_from_deref(void *mem_ctx,
376                             nir_deref_path *path,
377                             uint64_t *offset_base)
378 {
379    unsigned path_len = 0;
380    while (path->path[path_len])
381       path_len++;
382 
383    nir_scalar offset_defs_stack[32];
384    uint64_t offset_defs_mul_stack[32];
385    nir_scalar *offset_defs = offset_defs_stack;
386    uint64_t *offset_defs_mul = offset_defs_mul_stack;
387    if (path_len > 32) {
388       offset_defs = malloc(path_len * sizeof(nir_scalar));
389       offset_defs_mul = malloc(path_len * sizeof(uint64_t));
390    }
391    unsigned offset_def_count = 0;
392 
393    struct entry_key *key = ralloc(mem_ctx, struct entry_key);
394    key->resource = NULL;
395    key->var = NULL;
396    *offset_base = 0;
397 
398    for (unsigned i = 0; i < path_len; i++) {
399       nir_deref_instr *parent = i ? path->path[i - 1] : NULL;
400       nir_deref_instr *deref = path->path[i];
401 
402       switch (deref->deref_type) {
403       case nir_deref_type_var: {
404          assert(!parent);
405          key->var = deref->var;
406          break;
407       }
408       case nir_deref_type_array:
409       case nir_deref_type_ptr_as_array: {
410          assert(parent);
411          nir_def *index = deref->arr.index.ssa;
412          uint32_t stride = nir_deref_instr_array_stride(deref);
413 
414          nir_scalar base = { .def = index, .comp = 0 };
415          uint64_t offset = 0, base_mul = 1;
416          parse_offset(&base, &base_mul, &offset);
417          offset = util_mask_sign_extend(offset, index->bit_size);
418 
419          *offset_base += offset * stride;
420          if (base.def) {
421             offset_def_count += add_to_entry_key(offset_defs, offset_defs_mul,
422                                                  offset_def_count,
423                                                  base, base_mul * stride);
424          }
425          break;
426       }
427       case nir_deref_type_struct: {
428          assert(parent);
429          int offset = glsl_get_struct_field_offset(parent->type, deref->strct.index);
430          *offset_base += offset;
431          break;
432       }
433       case nir_deref_type_cast: {
434          if (!parent)
435             key->resource = deref->parent.ssa;
436          break;
437       }
438       default:
439          unreachable("Unhandled deref type");
440       }
441    }
442 
443    key->offset_def_count = offset_def_count;
444    key->offset_defs = ralloc_array(mem_ctx, nir_scalar, offset_def_count);
445    key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, offset_def_count);
446    memcpy(key->offset_defs, offset_defs, offset_def_count * sizeof(nir_scalar));
447    memcpy(key->offset_defs_mul, offset_defs_mul, offset_def_count * sizeof(uint64_t));
448 
449    if (offset_defs != offset_defs_stack)
450       free(offset_defs);
451    if (offset_defs_mul != offset_defs_mul_stack)
452       free(offset_defs_mul);
453 
454    return key;
455 }
456 
457 static unsigned
458 parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
459                             nir_scalar base, uint64_t base_mul, uint64_t *offset)
460 {
461    uint64_t new_mul;
462    uint64_t new_offset;
463    parse_offset(&base, &new_mul, &new_offset);
464    *offset += new_offset * base_mul;
465 
466    if (!base.def)
467       return 0;
468 
469    base_mul *= new_mul;
470 
471    assert(left >= 1);
472 
473    if (left >= 2) {
474       if (nir_scalar_is_alu(base) && nir_scalar_alu_op(base) == nir_op_iadd) {
475          nir_scalar src0 = nir_scalar_chase_alu_src(base, 0);
476          nir_scalar src1 = nir_scalar_chase_alu_src(base, 1);
477          unsigned amount = parse_entry_key_from_offset(key, size, left - 1, src0, base_mul, offset);
478          amount += parse_entry_key_from_offset(key, size + amount, left - amount, src1, base_mul, offset);
479          return amount;
480       }
481    }
482 
483    return add_to_entry_key(key->offset_defs, key->offset_defs_mul, size, base, base_mul);
484 }
485 
486 static struct entry_key *
487 create_entry_key_from_offset(void *mem_ctx, nir_def *base, uint64_t base_mul, uint64_t *offset)
488 {
489    struct entry_key *key = ralloc(mem_ctx, struct entry_key);
490    key->resource = NULL;
491    key->var = NULL;
492    if (base) {
493       nir_scalar offset_defs[32];
494       uint64_t offset_defs_mul[32];
495       key->offset_defs = offset_defs;
496       key->offset_defs_mul = offset_defs_mul;
497 
498       nir_scalar scalar = { .def = base, .comp = 0 };
499       key->offset_def_count = parse_entry_key_from_offset(key, 0, 32, scalar, base_mul, offset);
500 
501       key->offset_defs = ralloc_array(mem_ctx, nir_scalar, key->offset_def_count);
502       key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, key->offset_def_count);
503       memcpy(key->offset_defs, offset_defs, key->offset_def_count * sizeof(nir_scalar));
504       memcpy(key->offset_defs_mul, offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
505    } else {
506       key->offset_def_count = 0;
507       key->offset_defs = NULL;
508       key->offset_defs_mul = NULL;
509    }
510    return key;
511 }
512 
513 static nir_variable_mode
514 get_variable_mode(struct entry *entry)
515 {
516    if (entry->info->mode)
517       return entry->info->mode;
518    assert(entry->deref && util_bitcount(entry->deref->modes) == 1);
519    return entry->deref->modes;
520 }
521 
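/* Map a single variable mode to an index into the per-mode entry/load/store
 * tables (globals are tracked together with SSBOs). */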
522 static unsigned
523 mode_to_index(nir_variable_mode mode)
524 {
525    assert(util_bitcount(mode) == 1);
526 
527    /* Globals and SSBOs should be tracked together */
528    if (mode == nir_var_mem_global)
529       mode = nir_var_mem_ssbo;
530 
531    return ffs(mode) - 1;
532 }
533 
534 static nir_variable_mode
535 aliasing_modes(nir_variable_mode modes)
536 {
537    /* Global and SSBO can alias */
538    if (modes & (nir_var_mem_ssbo | nir_var_mem_global))
539       modes |= nir_var_mem_ssbo | nir_var_mem_global;
540    return modes;
541 }
542 
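/* Derive align_mul/align_offset from the entry's offset factors and constant
 * offset, keeping the intrinsic's explicit align info when it is stronger. */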
543 static void
544 calc_alignment(struct entry *entry)
545 {
546    uint32_t align_mul = 31;
547    for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
548       if (entry->key->offset_defs_mul[i])
549          align_mul = MIN2(align_mul, ffsll(entry->key->offset_defs_mul[i]));
550    }
551 
552    entry->align_mul = 1u << (align_mul - 1);
553    bool has_align = nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
554    if (!has_align || entry->align_mul >= nir_intrinsic_align_mul(entry->intrin)) {
555       entry->align_offset = entry->offset % entry->align_mul;
556    } else {
557       entry->align_mul = nir_intrinsic_align_mul(entry->intrin);
558       entry->align_offset = nir_intrinsic_align_offset(entry->intrin);
559    }
560 }
561 
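/* Gather everything needed to compare and merge one load/store intrinsic:
 * its key, constant offset, access flags and alignment. */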
562 static struct entry *
563 create_entry(void *mem_ctx,
564              const struct intrinsic_info *info,
565              nir_intrinsic_instr *intrin)
566 {
567    struct entry *entry = rzalloc(mem_ctx, struct entry);
568    entry->intrin = intrin;
569    entry->instr = &intrin->instr;
570    entry->info = info;
571    entry->is_store = entry->info->value_src >= 0;
572    entry->num_components =
573       entry->is_store ? intrin->num_components :
574                         nir_def_last_component_read(&intrin->def) + 1;
575 
576    if (entry->info->deref_src >= 0) {
577       entry->deref = nir_src_as_deref(intrin->src[entry->info->deref_src]);
578       nir_deref_path path;
579       nir_deref_path_init(&path, entry->deref, NULL);
580       entry->key = create_entry_key_from_deref(entry, &path, &entry->offset);
581       nir_deref_path_finish(&path);
582    } else {
583       nir_def *base = entry->info->base_src >= 0 ? intrin->src[entry->info->base_src].ssa : NULL;
584       uint64_t offset = 0;
585       if (nir_intrinsic_has_base(intrin))
586          offset += nir_intrinsic_base(intrin) * info->offset_scale;
587       entry->key = create_entry_key_from_offset(entry, base, info->offset_scale, &offset);
588       entry->offset = offset;
589 
590       if (base)
591          entry->offset = util_mask_sign_extend(entry->offset, base->bit_size);
592    }
593 
594    if (entry->info->resource_src >= 0)
595       entry->key->resource = intrin->src[entry->info->resource_src].ssa;
596 
597    if (nir_intrinsic_has_access(intrin))
598       entry->access = nir_intrinsic_access(intrin);
599    else if (entry->key->var)
600       entry->access = entry->key->var->data.access;
601 
602    if (nir_intrinsic_can_reorder(intrin))
603       entry->access |= ACCESS_CAN_REORDER;
604 
605    uint32_t restrict_modes = nir_var_shader_in | nir_var_shader_out;
606    restrict_modes |= nir_var_shader_temp | nir_var_function_temp;
607    restrict_modes |= nir_var_uniform | nir_var_mem_push_const;
608    restrict_modes |= nir_var_system_value | nir_var_mem_shared;
609    restrict_modes |= nir_var_mem_task_payload;
610    if (get_variable_mode(entry) & restrict_modes)
611       entry->access |= ACCESS_RESTRICT;
612 
613    calc_alignment(entry);
614 
615    return entry;
616 }
617 
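/* Cast "deref" to a uint vector type with the given component count and bit
 * size, unless it already has a matching type. */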
618 static nir_deref_instr *
619 cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref_instr *deref)
620 {
621    if (glsl_get_components(deref->type) == num_components &&
622        type_scalar_size_bytes(deref->type) * 8u == bit_size)
623       return deref;
624 
625    enum glsl_base_type types[] = {
626       GLSL_TYPE_UINT8, GLSL_TYPE_UINT16, GLSL_TYPE_UINT, GLSL_TYPE_UINT64
627    };
628    enum glsl_base_type base = types[ffs(bit_size / 8u) - 1u];
629    const struct glsl_type *type = glsl_vector_type(base, num_components);
630 
631    if (deref->type == type)
632       return deref;
633 
634    return nir_build_deref_cast(b, &deref->def, deref->modes, type, 0);
635 }
636 
637 /* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
638  * of "low" and "high". */
639 static bool
640 new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
641                        struct entry *low, struct entry *high, unsigned size)
642 {
643    if (size % new_bit_size != 0)
644       return false;
645 
646    unsigned new_num_components = size / new_bit_size;
647 
648    if (low->is_store) {
649       if (!nir_num_components_valid(new_num_components))
650          return false;
651    } else {
652       /* Invalid component counts must be rejected by the callback, otherwise
653        * the load will overfetch by aligning the number to the next valid
654        * component count.
655        */
656       if (new_num_components > NIR_MAX_VEC_COMPONENTS)
657          return false;
658    }
659 
660    unsigned high_offset = high->offset_signed - low->offset_signed;
661 
662    /* check nir_extract_bits limitations */
663    unsigned common_bit_size = MIN2(get_bit_size(low), get_bit_size(high));
664    common_bit_size = MIN2(common_bit_size, new_bit_size);
665    if (high_offset > 0)
666       common_bit_size = MIN2(common_bit_size, (1u << (ffs(high_offset * 8) - 1)));
667    if (new_bit_size / common_bit_size > NIR_MAX_VEC_COMPONENTS)
668       return false;
669 
670    unsigned low_size = low->intrin->num_components * get_bit_size(low) / 8;
671    /* The hole size can be less than 0 if low and high instructions overlap. */
672    int64_t hole_size = high->offset_signed - (low->offset_signed + low_size);
673 
674    if (!ctx->options->callback(low->align_mul,
675                                low->align_offset,
676                                new_bit_size, new_num_components, hole_size,
677                                low->intrin, high->intrin,
678                                ctx->options->cb_data))
679       return false;
680 
681    if (low->is_store) {
682       unsigned low_size = low->num_components * get_bit_size(low);
683       unsigned high_size = high->num_components * get_bit_size(high);
684 
685       if (low_size % new_bit_size != 0)
686          return false;
687       if (high_size % new_bit_size != 0)
688          return false;
689 
690       unsigned write_mask = get_write_mask(low->intrin);
691       if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(low), new_bit_size))
692          return false;
693 
694       write_mask = get_write_mask(high->intrin);
695       if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(high), new_bit_size))
696          return false;
697    }
698 
699    return true;
700 }
701 
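/* Return a deref that points "offset" bytes before "deref", reusing the
 * parent array index when possible. */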
702 static nir_deref_instr *
703 subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
704 {
705    /* avoid adding another deref to the path */
706    if (deref->deref_type == nir_deref_type_ptr_as_array &&
707        nir_src_is_const(deref->arr.index) &&
708        offset % nir_deref_instr_array_stride(deref) == 0) {
709       unsigned stride = nir_deref_instr_array_stride(deref);
710       assert(stride != 0);
711       nir_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
712                                       deref->def.bit_size);
713       return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
714    }
715 
716    if (deref->deref_type == nir_deref_type_array &&
717        nir_src_is_const(deref->arr.index)) {
718       nir_deref_instr *parent = nir_deref_instr_parent(deref);
719       unsigned stride = nir_deref_instr_array_stride(deref);
720       assert(stride != 0);
721       if (offset % stride == 0)
722          return nir_build_deref_array_imm(
723             b, parent, nir_src_as_int(deref->arr.index) - offset / stride);
724    }
725 
726    deref = nir_build_deref_cast(b, &deref->def, deref->modes,
727                                 glsl_scalar_type(GLSL_TYPE_UINT8), 1);
728    return nir_build_deref_ptr_as_array(
729       b, deref, nir_imm_intN_t(b, -offset, deref->def.bit_size));
730 }
731 
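/* Replace the loads "low" and "high" with a single load performed by "first"
 * (the earlier instruction); "high_start" is the bit offset of high's data
 * within the combined value. "second" is removed. */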
732 static void
733 vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
734                 struct entry *low, struct entry *high,
735                 struct entry *first, struct entry *second,
736                 unsigned new_bit_size, unsigned new_num_components,
737                 unsigned high_start)
738 {
739    unsigned old_low_bit_size = get_bit_size(low);
740    unsigned old_high_bit_size = get_bit_size(high);
741    unsigned old_low_num_components = low->intrin->num_components;
742    unsigned old_high_num_components = high->intrin->num_components;
743    bool low_bool = low->intrin->def.bit_size == 1;
744    bool high_bool = high->intrin->def.bit_size == 1;
745    nir_def *data = &first->intrin->def;
746 
747    b->cursor = nir_after_instr(first->instr);
748 
749    /* Align num_components to a supported vector size, effectively
750     * overfetching. Drivers can reject this in the callback by returning
751     * false for invalid num_components.
752     */
753    new_num_components = nir_round_up_components(new_num_components);
754    new_num_components = MAX2(new_num_components, 1);
755 
756    /* update the load's destination size and extract data for each of the original loads */
757    data->num_components = new_num_components;
758    data->bit_size = new_bit_size;
759 
760    /* Since this pass shrinks merged loads that have unused components (via
761     * nir_def_last_component_read), nir_extract_bits might not have enough
762     * data to extract. Pad the extraction with these undefs to compensate;
763     * the padding will be eliminated by DCE.
764     */
765    nir_def *low_undef = nir_undef(b, old_low_num_components, old_low_bit_size);
766    nir_def *high_undef = nir_undef(b, old_high_num_components, old_high_bit_size);
767 
768    nir_def *low_def = nir_extract_bits(
769       b, (nir_def*[]){data, low_undef}, 2, 0, old_low_num_components,
770       old_low_bit_size);
771 
772    nir_def *high_def = nir_extract_bits(
773       b, (nir_def*[]){data, high_undef}, 2, high_start,
774       old_high_num_components, old_high_bit_size);
775 
776    /* convert booleans */
777    low_def = low_bool ? nir_i2b(b, low_def) : nir_mov(b, low_def);
778    high_def = high_bool ? nir_i2b(b, high_def) : nir_mov(b, high_def);
779 
780    /* update uses */
781    if (first == low) {
782       nir_def_rewrite_uses_after(&low->intrin->def, low_def,
783                                  high_def->parent_instr);
784       nir_def_rewrite_uses(&high->intrin->def, high_def);
785    } else {
786       nir_def_rewrite_uses(&low->intrin->def, low_def);
787       nir_def_rewrite_uses_after(&high->intrin->def, high_def,
788                                  high_def->parent_instr);
789    }
790 
791    /* update the intrinsic */
792    first->intrin->num_components = new_num_components;
793    first->num_components = nir_def_last_component_read(data) + 1;
794 
795    const struct intrinsic_info *info = first->info;
796 
797    /* update the offset */
798    if (first != low && info->base_src >= 0) {
799       /* Let nir_opt_algebraic() remove this addition. This doesn't cause many
800        * issues with subtracting 16 from expressions like "(i + 1) * 16" because
801        * nir_opt_algebraic() turns them into "i * 16 + 16". */
802       b->cursor = nir_before_instr(first->instr);
803 
804       nir_def *new_base = first->intrin->src[info->base_src].ssa;
805       new_base = nir_iadd_imm(b, new_base, -(int)(high_start / 8u / first->info->offset_scale));
806 
807       nir_src_rewrite(&first->intrin->src[info->base_src], new_base);
808    }
809 
810    /* update the deref */
811    if (info->deref_src >= 0) {
812       b->cursor = nir_before_instr(first->instr);
813 
814       nir_deref_instr *deref = nir_src_as_deref(first->intrin->src[info->deref_src]);
815       if (first != low && high_start != 0)
816          deref = subtract_deref(b, deref, high_start / 8u / first->info->offset_scale);
817       first->deref = cast_deref(b, new_num_components, new_bit_size, deref);
818 
819       nir_src_rewrite(&first->intrin->src[info->deref_src],
820                       &first->deref->def);
821    }
822 
823    /* update align */
824    if (nir_intrinsic_has_range_base(first->intrin)) {
825       uint32_t old_low_range_base = nir_intrinsic_range_base(low->intrin);
826       uint32_t old_high_range_base = nir_intrinsic_range_base(high->intrin);
827       uint32_t old_low_range_end = old_low_range_base + nir_intrinsic_range(low->intrin);
828       uint32_t old_high_range_end = old_high_range_base + nir_intrinsic_range(high->intrin);
829 
830       uint32_t old_low_size = old_low_num_components * old_low_bit_size / 8;
831       uint32_t old_high_size = old_high_num_components * old_high_bit_size / 8;
832       uint32_t old_combined_size_up_to_high = high_start / 8u + old_high_size;
833       uint32_t old_combined_size = MAX2(old_low_size, old_combined_size_up_to_high);
834       uint32_t old_combined_range = MAX2(old_low_range_end, old_high_range_end) -
835                                     old_low_range_base;
836       uint32_t new_size = new_num_components * new_bit_size / 8;
837 
838       /* If we are trimming (e.g. merging vec1 + vec7as8 removes 1 component)
839        * or overfetching, we need to adjust the range accordingly.
840        */
841       int size_change = new_size - old_combined_size;
842       uint32_t range = old_combined_range + size_change;
843       assert(range);
844 
845       nir_intrinsic_set_range_base(first->intrin, old_low_range_base);
846       nir_intrinsic_set_range(first->intrin, range);
847    } else if (nir_intrinsic_has_base(first->intrin) && info->base_src == -1 && info->deref_src == -1) {
848       nir_intrinsic_set_base(first->intrin, nir_intrinsic_base(low->intrin));
849    }
850 
851    first->key = low->key;
852    first->offset = low->offset;
853 
854    first->align_mul = low->align_mul;
855    first->align_offset = low->align_offset;
856 
857    nir_instr_remove(second->instr);
858 }
859 
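/* Replace the stores "low" and "high" with a single store performed by
 * "second" (the later instruction), merging their write masks and data;
 * "first" is removed. */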
860 static void
861 vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
862                  struct entry *low, struct entry *high,
863                  struct entry *first, struct entry *second,
864                  unsigned new_bit_size, unsigned new_num_components,
865                  unsigned high_start)
866 {
867    ASSERTED unsigned low_size = low->num_components * get_bit_size(low);
868    assert(low_size % new_bit_size == 0);
869 
870    b->cursor = nir_before_instr(second->instr);
871 
872    /* get new writemasks */
873    uint32_t low_write_mask = get_write_mask(low->intrin);
874    uint32_t high_write_mask = get_write_mask(high->intrin);
875    low_write_mask = nir_component_mask_reinterpret(low_write_mask,
876                                                    get_bit_size(low),
877                                                    new_bit_size);
878    high_write_mask = nir_component_mask_reinterpret(high_write_mask,
879                                                     get_bit_size(high),
880                                                     new_bit_size);
881    high_write_mask <<= high_start / new_bit_size;
882 
883    uint32_t write_mask = low_write_mask | high_write_mask;
884 
885    /* convert booleans */
886    nir_def *low_val = low->intrin->src[low->info->value_src].ssa;
887    nir_def *high_val = high->intrin->src[high->info->value_src].ssa;
888    low_val = low_val->bit_size == 1 ? nir_b2iN(b, low_val, 32) : low_val;
889    high_val = high_val->bit_size == 1 ? nir_b2iN(b, high_val, 32) : high_val;
890 
891    /* combine the data */
892    nir_def *data_channels[NIR_MAX_VEC_COMPONENTS];
893    for (unsigned i = 0; i < new_num_components; i++) {
894       bool set_low = low_write_mask & (1 << i);
895       bool set_high = high_write_mask & (1 << i);
896 
897       if (set_low && (!set_high || low == second)) {
898          unsigned offset = i * new_bit_size;
899          data_channels[i] = nir_extract_bits(b, &low_val, 1, offset, 1, new_bit_size);
900       } else if (set_high) {
901          assert(!set_low || high == second);
902          unsigned offset = i * new_bit_size - high_start;
903          data_channels[i] = nir_extract_bits(b, &high_val, 1, offset, 1, new_bit_size);
904       } else {
905          data_channels[i] = nir_undef(b, 1, new_bit_size);
906       }
907    }
908    nir_def *data = nir_vec(b, data_channels, new_num_components);
909 
910    /* update the intrinsic */
911    if (nir_intrinsic_has_write_mask(second->intrin))
912       nir_intrinsic_set_write_mask(second->intrin, write_mask);
913    second->intrin->num_components = data->num_components;
914    second->num_components = data->num_components;
915 
916    const struct intrinsic_info *info = second->info;
917    assert(info->value_src >= 0);
918    nir_src_rewrite(&second->intrin->src[info->value_src], data);
919 
920    /* update the offset */
921    if (second != low && info->base_src >= 0)
922       nir_src_rewrite(&second->intrin->src[info->base_src],
923                       low->intrin->src[info->base_src].ssa);
924 
925    /* update the deref */
926    if (info->deref_src >= 0) {
927       b->cursor = nir_before_instr(second->instr);
928       second->deref = cast_deref(b, new_num_components, new_bit_size,
929                                  nir_src_as_deref(low->intrin->src[info->deref_src]));
930       nir_src_rewrite(&second->intrin->src[info->deref_src],
931                       &second->deref->def);
932    }
933 
934    /* update base/align */
935    if (second != low && nir_intrinsic_has_base(second->intrin))
936       nir_intrinsic_set_base(second->intrin, nir_intrinsic_base(low->intrin));
937 
938    second->key = low->key;
939    second->offset = low->offset;
940 
941    second->align_mul = low->align_mul;
942    second->align_offset = low->align_offset;
943 
944    list_del(&first->head);
945    nir_instr_remove(first->instr);
946 }
947 
948 /* Returns true if it can prove that "a" and "b" point to different bindings
949  * and either one uses ACCESS_RESTRICT. */
950 static bool
951 bindings_different_restrict(nir_shader *shader, struct entry *a, struct entry *b)
952 {
953    bool different_bindings = false;
954    nir_variable *a_var = NULL, *b_var = NULL;
955    if (a->key->resource && b->key->resource) {
956       nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
957       nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
958       if (!a_res.success || !b_res.success)
959          return false;
960 
961       if (a_res.num_indices != b_res.num_indices ||
962           a_res.desc_set != b_res.desc_set ||
963           a_res.binding != b_res.binding)
964          different_bindings = true;
965 
966       for (unsigned i = 0; i < a_res.num_indices; i++) {
967          if (nir_src_is_const(a_res.indices[i]) && nir_src_is_const(b_res.indices[i]) &&
968              nir_src_as_uint(a_res.indices[i]) != nir_src_as_uint(b_res.indices[i]))
969             different_bindings = true;
970       }
971 
972       if (different_bindings) {
973          a_var = nir_get_binding_variable(shader, a_res);
974          b_var = nir_get_binding_variable(shader, b_res);
975       }
976    } else if (a->key->var && b->key->var) {
977       a_var = a->key->var;
978       b_var = b->key->var;
979       different_bindings = a_var != b_var;
980    } else if (!!a->key->resource != !!b->key->resource) {
981       /* comparing global and ssbo access */
982       different_bindings = true;
983 
984       if (a->key->resource) {
985          nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
986          a_var = nir_get_binding_variable(shader, a_res);
987       }
988 
989       if (b->key->resource) {
990          nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
991          b_var = nir_get_binding_variable(shader, b_res);
992       }
993    } else {
994       return false;
995    }
996 
997    unsigned a_access = a->access | (a_var ? a_var->data.access : 0);
998    unsigned b_access = b->access | (b_var ? b_var->data.access : 0);
999 
1000    return different_bindings &&
1001           ((a_access | b_access) & ACCESS_RESTRICT);
1002 }
1003 
1004 static int64_t
1005 compare_entries(struct entry *a, struct entry *b)
1006 {
1007    if (!entry_key_equals(a->key, b->key))
1008       return INT64_MAX;
1009    return b->offset_signed - a->offset_signed;
1010 }
1011 
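/* Returns false only if we can prove that "a" and "b" do not access
 * overlapping memory; otherwise conservatively returns true. */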
1012 static bool
1013 may_alias(nir_shader *shader, struct entry *a, struct entry *b)
1014 {
1015    assert(mode_to_index(get_variable_mode(a)) ==
1016           mode_to_index(get_variable_mode(b)));
1017 
1018    if ((a->access | b->access) & ACCESS_CAN_REORDER)
1019       return false;
1020 
1021    /* if the resources/variables are definitively different and both have
1022     * ACCESS_RESTRICT, we can assume they do not alias. */
1023    if (bindings_different_restrict(shader, a, b))
1024       return false;
1025 
1026    /* we can't compare offsets if the resources/variables might be different */
1027    if (a->key->var != b->key->var || a->key->resource != b->key->resource)
1028       return true;
1029 
1030    /* use adjacency information */
1031    /* TODO: we can look closer at the entry keys */
1032    int64_t diff = compare_entries(a, b);
1033    if (diff != INT64_MAX) {
1034       /* with atomics, nir_intrinsic_instr::num_components can be 0 */
1035       if (diff < 0)
1036          return llabs(diff) < MAX2(b->num_components, 1u) * (get_bit_size(b) / 8u);
1037       else
1038          return diff < MAX2(a->num_components, 1u) * (get_bit_size(a) / 8u);
1039    }
1040 
1041    /* TODO: we can use deref information */
1042 
1043    return true;
1044 }
1045 
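/* Returns true if an access between "first" and "second" (in program order)
 * may alias, in which case the two cannot be brought next to each other. */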
1046 static bool
1047 check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
1048 {
1049    nir_variable_mode mode = get_variable_mode(first);
1050    if (mode & (nir_var_uniform | nir_var_system_value |
1051                nir_var_mem_push_const | nir_var_mem_ubo))
1052       return false;
1053 
1054    unsigned mode_index = mode_to_index(mode);
1055    if (first->is_store) {
1056       /* find first entry that aliases "first" */
1057       list_for_each_entry_from(struct entry, next, first, &ctx->entries[mode_index], head) {
1058          if (next == first)
1059             continue;
1060          if (next == second)
1061             return false;
1062          if (may_alias(ctx->shader, first, next))
1063             return true;
1064       }
1065    } else {
1066       /* find previous store that aliases this load */
1067       list_for_each_entry_from_rev(struct entry, prev, second, &ctx->entries[mode_index], head) {
1068          if (prev == second)
1069             continue;
1070          if (prev == first)
1071             return false;
1072          if (prev->is_store && may_alias(ctx->shader, second, prev))
1073             return true;
1074       }
1075    }
1076 
1077    return false;
1078 }
1079 
1080 static uint64_t
1081 calc_gcd(uint64_t a, uint64_t b)
1082 {
1083    while (b != 0) {
1084       uint64_t tmp_a = a;
1085       a = b;
1086       b = tmp_a % b;
1087    }
1088    return a;
1089 }
1090 
1091 static uint64_t
1092 round_down(uint64_t a, uint64_t b)
1093 {
1094    return a / b * b;
1095 }
1096 
1097 static bool
1098 addition_wraps(uint64_t a, uint64_t b, unsigned bits)
1099 {
1100    uint64_t mask = BITFIELD64_MASK(bits);
1101    return ((a + b) & mask) < (a & mask);
1102 }
1103 
1104 /* Return true if the addition of "low"'s offset and "high_offset" could wrap
1105  * around.
1106  *
1107  * This is to prevent a situation where the hardware considers the high load
1108  * out-of-bounds after vectorization if the low load is out-of-bounds, even if
1109  * the wrap-around from the addition could make the high load in-bounds.
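 *
 * E.g. with 32-bit offsets (illustrative): if the low load's offset can be
 * 0xFFFFFFFC (out-of-bounds), the original high load at low + 8 wraps around
 * to 0x4 and may be in-bounds, so combining the two would change its result.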
1110  */
1111 static bool
1112 check_for_robustness(struct vectorize_ctx *ctx, struct entry *low, uint64_t high_offset)
1113 {
1114    nir_variable_mode mode = get_variable_mode(low);
1115    if (!(mode & ctx->options->robust_modes))
1116       return false;
1117 
1118    unsigned scale = low->info->offset_scale;
1119 
1120    /* First, try to use alignment information in case the application provided some. If the addition
1121     * of the maximum offset of the low load and "high_offset" wraps around, we can't combine the low
1122     * and high loads.
1123     */
1124    uint64_t max_low = round_down(UINT64_MAX, low->align_mul) + low->align_offset;
1125    if (!addition_wraps(max_low / scale, high_offset / scale, 64))
1126       return false;
1127 
1128    /* We can't obtain addition_bits */
1129    if (low->info->base_src < 0)
1130       return true;
1131 
1132    /* Second, use information about the factors from address calculation (offset_defs_mul). These
1133     * are not guaranteed to be power-of-2.
1134     */
1135    uint64_t stride = 0;
1136    for (unsigned i = 0; i < low->key->offset_def_count; i++)
1137       stride = calc_gcd(low->key->offset_defs_mul[i], stride);
1138 
1139    unsigned addition_bits = low->intrin->src[low->info->base_src].ssa->bit_size;
1140    /* low's offset must be a multiple of "stride" plus "low->offset". */
1141    max_low = low->offset;
1142    if (stride)
1143       max_low = round_down(BITFIELD64_MASK(addition_bits), stride) + (low->offset % stride);
1144    return addition_wraps(max_low / scale, high_offset / scale, addition_bits);
1145 }
1146 
1147 static bool
1148 is_strided_vector(const struct glsl_type *type)
1149 {
1150    if (glsl_type_is_vector(type)) {
1151       unsigned explicit_stride = glsl_get_explicit_stride(type);
1152       return explicit_stride != 0 && explicit_stride !=
1153                                         type_scalar_size_bytes(glsl_get_array_element(type));
1154    } else {
1155       return false;
1156    }
1157 }
1158 
1159 static bool
1160 can_vectorize(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
1161 {
1162    if ((first->access | second->access) & ACCESS_KEEP_SCALAR)
1163       return false;
1164 
1165    if (!(get_variable_mode(first) & ctx->options->modes) ||
1166        !(get_variable_mode(second) & ctx->options->modes))
1167       return false;
1168 
1169    if (check_for_aliasing(ctx, first, second))
1170       return false;
1171 
1172    /* we can only vectorize non-volatile loads/stores of the same type and with
1173     * the same access */
1174    if (first->info != second->info || first->access != second->access ||
1175        (first->access & ACCESS_VOLATILE) || first->info->is_atomic)
1176       return false;
1177 
1178    return true;
1179 }
1180 
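/* Try to merge "low" and "high" into a single access, picking a bit size and
 * component count that the driver callback accepts. */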
1181 static bool
1182 try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
1183               struct entry *low, struct entry *high,
1184               struct entry *first, struct entry *second)
1185 {
1186    if (!can_vectorize(ctx, first, second))
1187       return false;
1188 
1189    uint64_t diff = high->offset_signed - low->offset_signed;
1190    if (check_for_robustness(ctx, low, diff))
1191       return false;
1192 
1193    /* don't attempt to vectorize accesses of row-major matrix columns */
1194    if (first->deref) {
1195       const struct glsl_type *first_type = first->deref->type;
1196       const struct glsl_type *second_type = second->deref->type;
1197       if (is_strided_vector(first_type) || is_strided_vector(second_type))
1198          return false;
1199    }
1200 
1201    /* gather information */
1202    unsigned low_bit_size = get_bit_size(low);
1203    unsigned high_bit_size = get_bit_size(high);
1204    unsigned low_size = low->num_components * low_bit_size;
1205    unsigned high_size = high->num_components * high_bit_size;
1206    unsigned new_size = MAX2(diff * 8u + high_size, low_size);
1207 
1208    /* find a good bit size for the new load/store */
1209    unsigned new_bit_size = 0;
1210    if (new_bitsize_acceptable(ctx, low_bit_size, low, high, new_size)) {
1211       new_bit_size = low_bit_size;
1212    } else if (low_bit_size != high_bit_size &&
1213               new_bitsize_acceptable(ctx, high_bit_size, low, high, new_size)) {
1214       new_bit_size = high_bit_size;
1215    } else {
1216       new_bit_size = 64;
1217       for (; new_bit_size >= 8; new_bit_size /= 2) {
1218          /* don't repeat trying out bitsizes */
1219          if (new_bit_size == low_bit_size || new_bit_size == high_bit_size)
1220             continue;
1221          if (new_bitsize_acceptable(ctx, new_bit_size, low, high, new_size))
1222             break;
1223       }
1224       if (new_bit_size < 8)
1225          return false;
1226    }
1227    unsigned new_num_components = new_size / new_bit_size;
1228 
1229    /* vectorize the loads/stores */
1230    nir_builder b = nir_builder_create(impl);
1231 
1232    if (first->is_store)
1233       vectorize_stores(&b, ctx, low, high, first, second,
1234                        new_bit_size, new_num_components, diff * 8u);
1235    else
1236       vectorize_loads(&b, ctx, low, high, first, second,
1237                       new_bit_size, new_num_components, diff * 8u);
1238 
1239    return true;
1240 }
1241 
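/* Try to combine two same-sized 4- or 8-byte shared-memory accesses into a
 * single load/store_shared2_amd, which takes a second offset. */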
1242 static bool
1243 try_vectorize_shared2(struct vectorize_ctx *ctx,
1244                       struct entry *low, struct entry *high,
1245                       struct entry *first, struct entry *second)
1246 {
1247    if (!can_vectorize(ctx, first, second) || first->deref)
1248       return false;
1249 
1250    unsigned low_bit_size = get_bit_size(low);
1251    unsigned high_bit_size = get_bit_size(high);
1252    unsigned low_size = low->num_components * low_bit_size / 8;
1253    unsigned high_size = high->num_components * high_bit_size / 8;
1254    if ((low_size != 4 && low_size != 8) || (high_size != 4 && high_size != 8))
1255       return false;
1256    if (low_size != high_size)
1257       return false;
1258    if (low->align_mul % low_size || low->align_offset % low_size)
1259       return false;
1260    if (high->align_mul % low_size || high->align_offset % low_size)
1261       return false;
1262 
1263    uint64_t diff = high->offset_signed - low->offset_signed;
1264    bool st64 = diff % (64 * low_size) == 0;
1265    unsigned stride = st64 ? 64 * low_size : low_size;
1266    if (diff % stride || diff > 255 * stride)
1267       return false;
1268 
1269    /* try to avoid creating accesses we can't combine additions/offsets into */
1270    if (high->offset > 255 * stride || (st64 && high->offset % stride))
1271       return false;
1272 
1273    if (first->is_store) {
1274       if (get_write_mask(low->intrin) != BITFIELD_MASK(low->num_components))
1275          return false;
1276       if (get_write_mask(high->intrin) != BITFIELD_MASK(high->num_components))
1277          return false;
1278    }
1279 
1280    /* vectorize the accesses */
1281    nir_builder b = nir_builder_at(nir_after_instr(first->is_store ? second->instr : first->instr));
1282 
1283    nir_def *offset = first->intrin->src[first->is_store].ssa;
1284    offset = nir_iadd_imm(&b, offset, nir_intrinsic_base(first->intrin));
1285    if (first != low)
1286       offset = nir_iadd_imm(&b, offset, -(int)diff);
1287 
1288    if (first->is_store) {
1289       nir_def *low_val = low->intrin->src[low->info->value_src].ssa;
1290       nir_def *high_val = high->intrin->src[high->info->value_src].ssa;
1291       nir_def *val = nir_vec2(&b, nir_bitcast_vector(&b, low_val, low_size * 8u),
1292                               nir_bitcast_vector(&b, high_val, low_size * 8u));
1293       nir_store_shared2_amd(&b, val, offset, .offset1 = diff / stride, .st64 = st64);
1294    } else {
1295       nir_def *new_def = nir_load_shared2_amd(&b, low_size * 8u, offset, .offset1 = diff / stride,
1296                                               .st64 = st64);
1297       nir_def_rewrite_uses(&low->intrin->def,
1298                            nir_bitcast_vector(&b, nir_channel(&b, new_def, 0), low_bit_size));
1299       nir_def_rewrite_uses(&high->intrin->def,
1300                            nir_bitcast_vector(&b, nir_channel(&b, new_def, 1), high_bit_size));
1301    }
1302 
1303    nir_instr_remove(first->instr);
1304    nir_instr_remove(second->instr);
1305 
1306    return true;
1307 }
1308 
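/* Write the entry's recomputed alignment back to the intrinsic; returns true
 * if it changed. */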
1309 static bool
1310 update_align(struct entry *entry)
1311 {
1312    if (nir_intrinsic_has_align_mul(entry->intrin) &&
1313        (entry->align_mul != nir_intrinsic_align_mul(entry->intrin) ||
1314         entry->align_offset != nir_intrinsic_align_offset(entry->intrin))) {
1315       nir_intrinsic_set_align(entry->intrin, entry->align_mul, entry->align_offset);
1316       return true;
1317    }
1318    return false;
1319 }
1320 
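/* Walk the offset-sorted array and try to merge each entry with the ones
 * following it; merged-away entries are replaced with NULL. */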
1321 static bool
1322 vectorize_sorted_entries(struct vectorize_ctx *ctx, nir_function_impl *impl,
1323                          struct util_dynarray *arr)
1324 {
1325    unsigned num_entries = util_dynarray_num_elements(arr, struct entry *);
1326 
1327    bool progress = false;
1328    for (unsigned first_idx = 0; first_idx < num_entries; first_idx++) {
1329       struct entry *low = *util_dynarray_element(arr, struct entry *, first_idx);
1330       if (!low)
1331          continue;
1332 
1333       for (unsigned second_idx = first_idx + 1; second_idx < num_entries; second_idx++) {
1334          struct entry *high = *util_dynarray_element(arr, struct entry *, second_idx);
1335          if (!high)
1336             continue;
1337 
1338          struct entry *first = low->index < high->index ? low : high;
1339          struct entry *second = low->index < high->index ? high : low;
1340 
1341          uint64_t diff = high->offset_signed - low->offset_signed;
1342          /* Allow overfetching by 28 bytes, which can be rejected by the
1343           * callback if needed.  Driver callbacks will likely want to
1344           * restrict this to a smaller value, say 4 bytes (or none).
1345           */
1346          unsigned max_hole =
1347             first->is_store ||
1348             (ctx->options->has_shared2_amd &&
1349              get_variable_mode(first) == nir_var_mem_shared) ? 0 : 28;
1350          unsigned low_size = get_bit_size(low) / 8u * low->num_components;
1351          bool separate = diff > max_hole + low_size;
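         /* Example (illustrative numbers): two scalar 32-bit SSBO loads at
          * offsets 0 and 24 give diff = 24 and low_size = 4, so they are not
          * "separate" (24 <= 28 + 4) and try_vectorize() is attempted; the
          * 20-byte hole between them is what the driver callback can still
          * reject.
          */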

         if (separate) {
            if (!ctx->options->has_shared2_amd ||
                get_variable_mode(first) != nir_var_mem_shared)
               break;

            if (try_vectorize_shared2(ctx, low, high, first, second)) {
               low = NULL;
               *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
               progress = true;
               break;
            }
         } else {
            if (try_vectorize(impl, ctx, low, high, first, second)) {
               low = low->is_store ? second : first;
               *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
               progress = true;
            }
         }
      }

      *util_dynarray_element(arr, struct entry *, first_idx) = low;
   }

   return progress;
}

static bool
vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct hash_table *ht)
{
   if (!ht)
      return false;

   bool progress = false;
   hash_table_foreach(ht, entry) {
      struct util_dynarray *arr = entry->data;
      if (!arr->size)
         continue;

      qsort(util_dynarray_begin(arr),
            util_dynarray_num_elements(arr, struct entry *),
            sizeof(struct entry *), &sort_entries);

      while (vectorize_sorted_entries(ctx, impl, arr))
         progress = true;

      util_dynarray_foreach(arr, struct entry *, elem) {
         if (*elem)
            progress |= update_align(*elem);
      }
   }

   _mesa_hash_table_clear(ht, delete_entry_dynarray);

   return progress;
}

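/* Returns true if the instruction has to be treated as a barrier by this
 * pass.  Pending loads (on acquire) and stores (on release) of the affected
 * modes are vectorized immediately so that accesses are never combined
 * across the barrier.
 */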
static bool
handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *impl, nir_instr *instr)
{
   unsigned modes = 0;
   bool acquire = true;
   bool release = true;
   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      switch (intrin->intrinsic) {
      /* prevent speculative loads/stores */
      case nir_intrinsic_terminate_if:
      case nir_intrinsic_terminate:
      case nir_intrinsic_launch_mesh_workgroups:
         modes = nir_var_all;
         break;
      case nir_intrinsic_demote_if:
      case nir_intrinsic_demote:
         acquire = false;
         modes = nir_var_all;
         break;
      case nir_intrinsic_barrier:
         if (nir_intrinsic_memory_scope(intrin) == SCOPE_NONE)
            break;

         modes = nir_intrinsic_memory_modes(intrin) & (nir_var_mem_ssbo |
                                                       nir_var_mem_shared |
                                                       nir_var_mem_global |
                                                       nir_var_mem_task_payload);
         acquire = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_ACQUIRE;
         release = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_RELEASE;
         switch (nir_intrinsic_memory_scope(intrin)) {
         case SCOPE_INVOCATION:
            /* a barrier should never be required for correctness with this scope */
            modes = 0;
            break;
         default:
            break;
         }
         break;
      default:
         return false;
      }
   } else if (instr->type == nir_instr_type_call) {
      modes = nir_var_all;
   } else {
      return false;
   }

   while (modes) {
      unsigned mode_index = u_bit_scan(&modes);
      if ((1 << mode_index) == nir_var_mem_global) {
         /* Global should be rolled in with SSBO */
         assert(list_is_empty(&ctx->entries[mode_index]));
         assert(ctx->loads[mode_index] == NULL);
         assert(ctx->stores[mode_index] == NULL);
         continue;
      }

      if (acquire)
         *progress |= vectorize_entries(ctx, impl, ctx->loads[mode_index]);
      if (release)
         *progress |= vectorize_entries(ctx, impl, ctx->stores[mode_index]);
   }

   return true;
}

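/* Gather the load/store entries of a single block into per-mode hash tables
 * keyed by entry->key (flushing them at barriers), then sort and combine the
 * entries in each bucket.
 */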
static bool
process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *block)
{
   bool progress = false;

   for (unsigned i = 0; i < nir_num_variable_modes; i++) {
      list_inithead(&ctx->entries[i]);
      if (ctx->loads[i])
         _mesa_hash_table_clear(ctx->loads[i], delete_entry_dynarray);
      if (ctx->stores[i])
         _mesa_hash_table_clear(ctx->stores[i], delete_entry_dynarray);
   }

   /* create entries */
   unsigned next_index = 0;

   nir_foreach_instr_safe(instr, block) {
      if (handle_barrier(ctx, &progress, impl, instr))
         continue;

      /* gather information */
      if (instr->type != nir_instr_type_intrinsic)
         continue;
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

      const struct intrinsic_info *info = get_info(intrin->intrinsic);
      if (!info)
         continue;

      nir_variable_mode mode = info->mode;
      if (!mode)
         mode = nir_src_as_deref(intrin->src[info->deref_src])->modes;
      if (!(mode & aliasing_modes(ctx->options->modes)))
         continue;
      unsigned mode_index = mode_to_index(mode);

      /* create entry */
      struct entry *entry = create_entry(ctx, info, intrin);
      entry->index = next_index++;

      list_addtail(&entry->head, &ctx->entries[mode_index]);

      /* add the entry to a hash table */

      struct hash_table *adj_ht = NULL;
      if (entry->is_store) {
         if (!ctx->stores[mode_index])
            ctx->stores[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
         adj_ht = ctx->stores[mode_index];
      } else {
         if (!ctx->loads[mode_index])
            ctx->loads[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
         adj_ht = ctx->loads[mode_index];
      }

      uint32_t key_hash = hash_entry_key(entry->key);
      struct hash_entry *adj_entry = _mesa_hash_table_search_pre_hashed(adj_ht, key_hash, entry->key);
      struct util_dynarray *arr;
      if (adj_entry && adj_entry->data) {
         arr = (struct util_dynarray *)adj_entry->data;
      } else {
         arr = ralloc(ctx, struct util_dynarray);
         util_dynarray_init(arr, arr);
         _mesa_hash_table_insert_pre_hashed(adj_ht, key_hash, entry->key, arr);
      }
      util_dynarray_append(arr, struct entry *, entry);
   }

   /* sort and combine entries */
   for (unsigned i = 0; i < nir_num_variable_modes; i++) {
      progress |= vectorize_entries(ctx, impl, ctx->loads[i]);
      progress |= vectorize_entries(ctx, impl, ctx->stores[i]);
   }

   return progress;
}

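/* Example driver usage (a sketch, not taken from a real driver;
 * driver_mem_vectorize_cb stands in for the driver's
 * nir_should_vectorize_mem_func callback, see nir.h for the exact option
 * struct and callback prototype):
 *
 *    const nir_load_store_vectorize_options opts = {
 *       .modes = nir_var_mem_ssbo | nir_var_mem_shared,
 *       .callback = driver_mem_vectorize_cb,
 *    };
 *    NIR_PASS(progress, shader, nir_opt_load_store_vectorize, &opts);
 */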
bool
nir_opt_load_store_vectorize(nir_shader *shader, const nir_load_store_vectorize_options *options)
{
   bool progress = false;

   struct vectorize_ctx *ctx = rzalloc(NULL, struct vectorize_ctx);
   ctx->shader = shader;
   ctx->options = options;

   nir_shader_index_vars(shader, options->modes);

   nir_foreach_function_impl(impl, shader) {
      if (options->modes & nir_var_function_temp)
         nir_function_impl_index_vars(impl);

      nir_foreach_block(block, impl)
         progress |= process_block(impl, ctx, block);

      nir_metadata_preserve(impl,
                            nir_metadata_control_flow |
                            nir_metadata_live_defs);
   }

   ralloc_free(ctx);
   return progress;
}

static bool
opt_load_store_update_alignments_callback(struct nir_builder *b,
                                          nir_intrinsic_instr *intrin,
                                          UNUSED void *s)
{
   if (!nir_intrinsic_has_align_mul(intrin))
      return false;

   const struct intrinsic_info *info = get_info(intrin->intrinsic);
   if (!info)
      return false;

   struct entry *entry = create_entry(NULL, info, intrin);
   const bool progress = update_align(entry);
   ralloc_free(entry);

   return progress;
}

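/* Standalone pass that only refreshes the align_mul/align_offset of
 * load/store intrinsics using the vectorizer's per-access analysis, without
 * combining anything.
 */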
bool
nir_opt_load_store_update_alignments(nir_shader *shader)
{
   return nir_shader_intrinsics_pass(shader,
                                     opt_load_store_update_alignments_callback,
                                     nir_metadata_control_flow |
                                     nir_metadata_live_defs |
                                     nir_metadata_instr_index, NULL);
}
