1 /*
2  * Copyright © 2019 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 /**
25  * Although it's called a load/store "vectorization" pass, this also combines
26  * intersecting and identical loads/stores. It currently supports derefs, ubo,
27  * ssbo and push constant loads/stores.
28  *
29  * This doesn't handle copy_deref intrinsics and assumes that
30  * nir_lower_alu_to_scalar() has been called and that the IR is free from ALU
31  * modifiers. It also assumes that derefs have explicitly laid out types.
32  *
33  * After vectorization, the backend may want to call nir_lower_alu_to_scalar()
34  * and nir_lower_pack(). Also, this pass creates cast instructions that take
35  * derefs as a source, and some parts of NIR may not be able to handle that well.
36  *
37  * There are a few situations where this doesn't vectorize as well as it could:
38  * - It won't turn four consecutive vec3 loads into 3 vec4 loads.
39  * - It doesn't do global vectorization.
40  * Handling these cases probably wouldn't provide much benefit though.
41  *
42  * This probably doesn't handle big-endian GPUs correctly.
43 */
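/*
 * A minimal sketch of how a backend might configure this pass, based only on
 * the option fields referenced in this file (modes, robust_modes, callback,
 * cb_data, has_shared2_amd). The entry-point name and the exact callback
 * signature are assumptions and should be checked against nir.h:
 *
 *    static bool
 *    mem_vectorize_cb(unsigned align_mul, unsigned align_offset,
 *                     unsigned bit_size, unsigned num_components,
 *                     nir_intrinsic_instr *low, nir_intrinsic_instr *high,
 *                     void *cb_data)
 *    {
 *       // Allow the combined access if it is at least 4-byte aligned and no
 *       // wider than 128 bits.
 *       return align_mul >= 4 && bit_size * num_components <= 128;
 *    }
 *
 *    nir_load_store_vectorize_options options = {
 *       .modes = nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared,
 *       .callback = mem_vectorize_cb,
 *       .robust_modes = nir_var_mem_ssbo, // only if robust buffer access is enabled
 *       .cb_data = NULL,
 *    };
 *    NIR_PASS(progress, shader, nir_opt_load_store_vectorize, &options);
 */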
44 
45 #include "nir.h"
46 #include "nir_deref.h"
47 #include "nir_builder.h"
48 #include "nir_worklist.h"
49 #include "util/u_dynarray.h"
50 
51 #include <stdlib.h>
52 
53 struct intrinsic_info {
54    nir_variable_mode mode; /* 0 if the mode is obtained from the deref. */
55    nir_intrinsic_op op;
56    bool is_atomic;
57    /* Indices into nir_intrinsic::src[] or -1 if not applicable. */
58    int resource_src; /* resource (e.g. from vulkan_resource_index) */
59    int base_src; /* offset it loads from or stores to */
60    int deref_src; /* deref it loads from or stores to */
61    int value_src; /* the data it is storing */
62 };
63 
64 static const struct intrinsic_info *
65 get_info(nir_intrinsic_op op) {
66    switch (op) {
67 #define INFO(mode, op, atomic, res, base, deref, val) \
68 case nir_intrinsic_##op: {\
69    static const struct intrinsic_info op##_info = {mode, nir_intrinsic_##op, atomic, res, base, deref, val};\
70    return &op##_info;\
71 }
72 #define LOAD(mode, op, res, base, deref) INFO(mode, load_##op, false, res, base, deref, -1)
73 #define STORE(mode, op, res, base, deref, val) INFO(mode, store_##op, false, res, base, deref, val)
74 #define ATOMIC(mode, type, op, res, base, deref, val) INFO(mode, type##_atomic_##op, true, res, base, deref, val)
75    LOAD(nir_var_mem_push_const, push_constant, -1, 0, -1)
76    LOAD(nir_var_mem_ubo, ubo, 0, 1, -1)
77    LOAD(nir_var_mem_ssbo, ssbo, 0, 1, -1)
78    STORE(nir_var_mem_ssbo, ssbo, 1, 2, -1, 0)
79    LOAD(0, deref, -1, -1, 0)
80    STORE(0, deref, -1, -1, 0, 1)
81    LOAD(nir_var_mem_shared, shared, -1, 0, -1)
82    STORE(nir_var_mem_shared, shared, -1, 1, -1, 0)
83    LOAD(nir_var_mem_global, global, -1, 0, -1)
84    STORE(nir_var_mem_global, global, -1, 1, -1, 0)
85    LOAD(nir_var_mem_task_payload, task_payload, -1, 0, -1)
86    STORE(nir_var_mem_task_payload, task_payload, -1, 1, -1, 0)
87    ATOMIC(nir_var_mem_ssbo, ssbo, add, 0, 1, -1, 2)
88    ATOMIC(nir_var_mem_ssbo, ssbo, imin, 0, 1, -1, 2)
89    ATOMIC(nir_var_mem_ssbo, ssbo, umin, 0, 1, -1, 2)
90    ATOMIC(nir_var_mem_ssbo, ssbo, imax, 0, 1, -1, 2)
91    ATOMIC(nir_var_mem_ssbo, ssbo, umax, 0, 1, -1, 2)
92    ATOMIC(nir_var_mem_ssbo, ssbo, and, 0, 1, -1, 2)
93    ATOMIC(nir_var_mem_ssbo, ssbo, or, 0, 1, -1, 2)
94    ATOMIC(nir_var_mem_ssbo, ssbo, xor, 0, 1, -1, 2)
95    ATOMIC(nir_var_mem_ssbo, ssbo, exchange, 0, 1, -1, 2)
96    ATOMIC(nir_var_mem_ssbo, ssbo, comp_swap, 0, 1, -1, 2)
97    ATOMIC(nir_var_mem_ssbo, ssbo, fadd, 0, 1, -1, 2)
98    ATOMIC(nir_var_mem_ssbo, ssbo, fmin, 0, 1, -1, 2)
99    ATOMIC(nir_var_mem_ssbo, ssbo, fmax, 0, 1, -1, 2)
100    ATOMIC(nir_var_mem_ssbo, ssbo, fcomp_swap, 0, 1, -1, 2)
101    ATOMIC(0, deref, add, -1, -1, 0, 1)
102    ATOMIC(0, deref, imin, -1, -1, 0, 1)
103    ATOMIC(0, deref, umin, -1, -1, 0, 1)
104    ATOMIC(0, deref, imax, -1, -1, 0, 1)
105    ATOMIC(0, deref, umax, -1, -1, 0, 1)
106    ATOMIC(0, deref, and, -1, -1, 0, 1)
107    ATOMIC(0, deref, or, -1, -1, 0, 1)
108    ATOMIC(0, deref, xor, -1, -1, 0, 1)
109    ATOMIC(0, deref, exchange, -1, -1, 0, 1)
110    ATOMIC(0, deref, comp_swap, -1, -1, 0, 1)
111    ATOMIC(0, deref, fadd, -1, -1, 0, 1)
112    ATOMIC(0, deref, fmin, -1, -1, 0, 1)
113    ATOMIC(0, deref, fmax, -1, -1, 0, 1)
114    ATOMIC(0, deref, fcomp_swap, -1, -1, 0, 1)
115    ATOMIC(nir_var_mem_shared, shared, add, -1, 0, -1, 1)
116    ATOMIC(nir_var_mem_shared, shared, imin, -1, 0, -1, 1)
117    ATOMIC(nir_var_mem_shared, shared, umin, -1, 0, -1, 1)
118    ATOMIC(nir_var_mem_shared, shared, imax, -1, 0, -1, 1)
119    ATOMIC(nir_var_mem_shared, shared, umax, -1, 0, -1, 1)
120    ATOMIC(nir_var_mem_shared, shared, and, -1, 0, -1, 1)
121    ATOMIC(nir_var_mem_shared, shared, or, -1, 0, -1, 1)
122    ATOMIC(nir_var_mem_shared, shared, xor, -1, 0, -1, 1)
123    ATOMIC(nir_var_mem_shared, shared, exchange, -1, 0, -1, 1)
124    ATOMIC(nir_var_mem_shared, shared, comp_swap, -1, 0, -1, 1)
125    ATOMIC(nir_var_mem_shared, shared, fadd, -1, 0, -1, 1)
126    ATOMIC(nir_var_mem_shared, shared, fmin, -1, 0, -1, 1)
127    ATOMIC(nir_var_mem_shared, shared, fmax, -1, 0, -1, 1)
128    ATOMIC(nir_var_mem_shared, shared, fcomp_swap, -1, 0, -1, 1)
129    ATOMIC(nir_var_mem_global, global, add, -1, 0, -1, 1)
130    ATOMIC(nir_var_mem_global, global, imin, -1, 0, -1, 1)
131    ATOMIC(nir_var_mem_global, global, umin, -1, 0, -1, 1)
132    ATOMIC(nir_var_mem_global, global, imax, -1, 0, -1, 1)
133    ATOMIC(nir_var_mem_global, global, umax, -1, 0, -1, 1)
134    ATOMIC(nir_var_mem_global, global, and, -1, 0, -1, 1)
135    ATOMIC(nir_var_mem_global, global, or, -1, 0, -1, 1)
136    ATOMIC(nir_var_mem_global, global, xor, -1, 0, -1, 1)
137    ATOMIC(nir_var_mem_global, global, exchange, -1, 0, -1, 1)
138    ATOMIC(nir_var_mem_global, global, comp_swap, -1, 0, -1, 1)
139    ATOMIC(nir_var_mem_global, global, fadd, -1, 0, -1, 1)
140    ATOMIC(nir_var_mem_global, global, fmin, -1, 0, -1, 1)
141    ATOMIC(nir_var_mem_global, global, fmax, -1, 0, -1, 1)
142    ATOMIC(nir_var_mem_global, global, fcomp_swap, -1, 0, -1, 1)
143    ATOMIC(nir_var_mem_task_payload, task_payload, add, -1, 0, -1, 1)
144    ATOMIC(nir_var_mem_task_payload, task_payload, imin, -1, 0, -1, 1)
145    ATOMIC(nir_var_mem_task_payload, task_payload, umin, -1, 0, -1, 1)
146    ATOMIC(nir_var_mem_task_payload, task_payload, imax, -1, 0, -1, 1)
147    ATOMIC(nir_var_mem_task_payload, task_payload, umax, -1, 0, -1, 1)
148    ATOMIC(nir_var_mem_task_payload, task_payload, and, -1, 0, -1, 1)
149    ATOMIC(nir_var_mem_task_payload, task_payload, or, -1, 0, -1, 1)
150    ATOMIC(nir_var_mem_task_payload, task_payload, xor, -1, 0, -1, 1)
151    ATOMIC(nir_var_mem_task_payload, task_payload, exchange, -1, 0, -1, 1)
152    ATOMIC(nir_var_mem_task_payload, task_payload, comp_swap, -1, 0, -1, 1)
153    ATOMIC(nir_var_mem_task_payload, task_payload, fadd, -1, 0, -1, 1)
154    ATOMIC(nir_var_mem_task_payload, task_payload, fmin, -1, 0, -1, 1)
155    ATOMIC(nir_var_mem_task_payload, task_payload, fmax, -1, 0, -1, 1)
156    ATOMIC(nir_var_mem_task_payload, task_payload, fcomp_swap, -1, 0, -1, 1)
157    default:
158       break;
159 #undef ATOMIC
160 #undef STORE
161 #undef LOAD
162 #undef INFO
163    }
164    return NULL;
165 }
166 
167 /*
168  * Information used to compare memory operations.
169  * It canonically represents an offset as:
170  * `offset_defs[0]*offset_defs_mul[0] + offset_defs[1]*offset_defs_mul[1] + ...`
171  * "offset_defs" is sorted in ascenting order by the ssa definition's index.
172  * "resource" or "var" may be NULL.
173  */
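/* For example, an access whose offset computes to "idx0 * 16 + idx1 * 4 + 8"
 * is keyed with offset_defs = {idx0, idx1} and offset_defs_mul = {16, 4}
 * (assuming idx0 has the lower ssa index); the constant 8 is not part of the
 * key and ends up in struct entry::offset below. */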
174 struct entry_key {
175    nir_ssa_def *resource;
176    nir_variable *var;
177    unsigned offset_def_count;
178    nir_ssa_scalar *offset_defs;
179    uint64_t *offset_defs_mul;
180 };
181 
182 /* Information on a single memory operation. */
183 struct entry {
184    struct list_head head;
185    unsigned index;
186 
187    struct entry_key *key;
188    union {
189       uint64_t offset; /* sign-extended */
190       int64_t offset_signed;
191    };
192    uint32_t align_mul;
193    uint32_t align_offset;
194 
195    nir_instr *instr;
196    nir_intrinsic_instr *intrin;
197    const struct intrinsic_info *info;
198    enum gl_access_qualifier access;
199    bool is_store;
200 
201    nir_deref_instr *deref;
202 };
203 
204 struct vectorize_ctx {
205    nir_shader *shader;
206    const nir_load_store_vectorize_options *options;
207    struct list_head entries[nir_num_variable_modes];
208    struct hash_table *loads[nir_num_variable_modes];
209    struct hash_table *stores[nir_num_variable_modes];
210 };
211 
212 static uint32_t hash_entry_key(const void *key_)
213 {
214    /* this is careful to not include pointers in the hash calculation so that
215     * the order of the hash table walk is deterministic */
216    struct entry_key *key = (struct entry_key*)key_;
217 
218    uint32_t hash = 0;
219    if (key->resource)
220       hash = XXH32(&key->resource->index, sizeof(key->resource->index), hash);
221    if (key->var) {
222       hash = XXH32(&key->var->index, sizeof(key->var->index), hash);
223       unsigned mode = key->var->data.mode;
224       hash = XXH32(&mode, sizeof(mode), hash);
225    }
226 
227    for (unsigned i = 0; i < key->offset_def_count; i++) {
228       hash = XXH32(&key->offset_defs[i].def->index, sizeof(key->offset_defs[i].def->index), hash);
229       hash = XXH32(&key->offset_defs[i].comp, sizeof(key->offset_defs[i].comp), hash);
230    }
231 
232    hash = XXH32(key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t), hash);
233 
234    return hash;
235 }
236 
237 static bool entry_key_equals(const void *a_, const void *b_)
238 {
239    struct entry_key *a = (struct entry_key*)a_;
240    struct entry_key *b = (struct entry_key*)b_;
241 
242    if (a->var != b->var || a->resource != b->resource)
243       return false;
244 
245    if (a->offset_def_count != b->offset_def_count)
246       return false;
247 
248    for (unsigned i = 0; i < a->offset_def_count; i++) {
249       if (a->offset_defs[i].def != b->offset_defs[i].def ||
250           a->offset_defs[i].comp != b->offset_defs[i].comp)
251          return false;
252    }
253 
254    size_t offset_def_mul_size = a->offset_def_count * sizeof(uint64_t);
255    if (a->offset_def_count &&
256        memcmp(a->offset_defs_mul, b->offset_defs_mul, offset_def_mul_size))
257       return false;
258 
259    return true;
260 }
261 
262 static void delete_entry_dynarray(struct hash_entry *entry)
263 {
264    struct util_dynarray *arr = (struct util_dynarray *)entry->data;
265    ralloc_free(arr);
266 }
267 
268 static int sort_entries(const void *a_, const void *b_)
269 {
270    struct entry *a = *(struct entry*const*)a_;
271    struct entry *b = *(struct entry*const*)b_;
272 
273    if (a->offset_signed > b->offset_signed)
274       return 1;
275    else if (a->offset_signed < b->offset_signed)
276       return -1;
277    else
278       return 0;
279 }
280 
281 static unsigned
282 get_bit_size(struct entry *entry)
283 {
284    unsigned size = entry->is_store ?
285                    entry->intrin->src[entry->info->value_src].ssa->bit_size :
286                    entry->intrin->dest.ssa.bit_size;
287    return size == 1 ? 32u : size;
288 }
289 
290 /* If "def" is from an alu instruction with the opcode "op" and one of its
291  * sources is a constant, update "def" to be the non-constant source, fill "c"
292  * with the constant and return true. */
293 static bool
294 parse_alu(nir_ssa_scalar *def, nir_op op, uint64_t *c)
295 {
296    if (!nir_ssa_scalar_is_alu(*def) || nir_ssa_scalar_alu_op(*def) != op)
297       return false;
298 
299    nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(*def, 0);
300    nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(*def, 1);
301    if (op != nir_op_ishl && nir_ssa_scalar_is_const(src0)) {
302       *c = nir_ssa_scalar_as_uint(src0);
303       *def = src1;
304    } else if (nir_ssa_scalar_is_const(src1)) {
305       *c = nir_ssa_scalar_as_uint(src1);
306       *def = src0;
307    } else {
308       return false;
309    }
310    return true;
311 }
312 
313 /* Parses an offset expression such as "a * 16 + 4" and "(a * 16 + 4) * 64 + 32". */
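/* For instance, for "a * 16 + 4" this leaves *base pointing at "a" and returns
 * *base_mul = 16 and *offset = 4; for "(a * 16 + 4) * 64 + 32" it returns
 * *base_mul = 16 * 64 = 1024 and *offset = 4 * 64 + 32 = 288. */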
314 static void
315 parse_offset(nir_ssa_scalar *base, uint64_t *base_mul, uint64_t *offset)
316 {
317    if (nir_ssa_scalar_is_const(*base)) {
318       *offset = nir_ssa_scalar_as_uint(*base);
319       base->def = NULL;
320       return;
321    }
322 
323    uint64_t mul = 1;
324    uint64_t add = 0;
325    bool progress = false;
326    do {
327       uint64_t mul2 = 1, add2 = 0;
328 
329       progress = parse_alu(base, nir_op_imul, &mul2);
330       mul *= mul2;
331 
332       mul2 = 0;
333       progress |= parse_alu(base, nir_op_ishl, &mul2);
334       mul <<= mul2;
335 
336       progress |= parse_alu(base, nir_op_iadd, &add2);
337       add += add2 * mul;
338 
339       if (nir_ssa_scalar_is_alu(*base) && nir_ssa_scalar_alu_op(*base) == nir_op_mov) {
340          *base = nir_ssa_scalar_chase_alu_src(*base, 0);
341          progress = true;
342       }
343    } while (progress);
344 
345    if (base->def->parent_instr->type == nir_instr_type_intrinsic) {
346       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(base->def->parent_instr);
347       if (intrin->intrinsic == nir_intrinsic_load_vulkan_descriptor)
348          base->def = NULL;
349    }
350 
351    *base_mul = mul;
352    *offset = add;
353 }
354 
355 static unsigned
356 type_scalar_size_bytes(const struct glsl_type *type)
357 {
358    assert(glsl_type_is_vector_or_scalar(type) ||
359           glsl_type_is_matrix(type));
360    return glsl_type_is_boolean(type) ? 4u : glsl_get_bit_size(type) / 8u;
361 }
362 
363 static unsigned
364 add_to_entry_key(nir_ssa_scalar *offset_defs, uint64_t *offset_defs_mul,
365                  unsigned offset_def_count, nir_ssa_scalar def, uint64_t mul)
366 {
367    mul = util_mask_sign_extend(mul, def.def->bit_size);
368 
369    for (unsigned i = 0; i <= offset_def_count; i++) {
370       if (i == offset_def_count || def.def->index > offset_defs[i].def->index) {
371          /* insert before i */
372          memmove(offset_defs + i + 1, offset_defs + i,
373                  (offset_def_count - i) * sizeof(nir_ssa_scalar));
374          memmove(offset_defs_mul + i + 1, offset_defs_mul + i,
375                  (offset_def_count - i) * sizeof(uint64_t));
376          offset_defs[i] = def;
377          offset_defs_mul[i] = mul;
378          return 1;
379       } else if (def.def == offset_defs[i].def &&
380                  def.comp == offset_defs[i].comp) {
381          /* merge with offset_def at i */
382          offset_defs_mul[i] += mul;
383          return 0;
384       }
385    }
386    unreachable("Unreachable.");
387    return 0;
388 }
389 
390 static struct entry_key *
391 create_entry_key_from_deref(void *mem_ctx,
392                             struct vectorize_ctx *ctx,
393                             nir_deref_path *path,
394                             uint64_t *offset_base)
395 {
396    unsigned path_len = 0;
397    while (path->path[path_len])
398       path_len++;
399 
400    nir_ssa_scalar offset_defs_stack[32];
401    uint64_t offset_defs_mul_stack[32];
402    nir_ssa_scalar *offset_defs = offset_defs_stack;
403    uint64_t *offset_defs_mul = offset_defs_mul_stack;
404    if (path_len > 32) {
405       offset_defs = malloc(path_len * sizeof(nir_ssa_scalar));
406       offset_defs_mul = malloc(path_len * sizeof(uint64_t));
407    }
408    unsigned offset_def_count = 0;
409 
410    struct entry_key* key = ralloc(mem_ctx, struct entry_key);
411    key->resource = NULL;
412    key->var = NULL;
413    *offset_base = 0;
414 
415    for (unsigned i = 0; i < path_len; i++) {
416       nir_deref_instr *parent = i ? path->path[i - 1] : NULL;
417       nir_deref_instr *deref = path->path[i];
418 
419       switch (deref->deref_type) {
420       case nir_deref_type_var: {
421          assert(!parent);
422          key->var = deref->var;
423          break;
424       }
425       case nir_deref_type_array:
426       case nir_deref_type_ptr_as_array: {
427          assert(parent);
428          nir_ssa_def *index = deref->arr.index.ssa;
429          uint32_t stride = nir_deref_instr_array_stride(deref);
430 
431          nir_ssa_scalar base = {.def=index, .comp=0};
432          uint64_t offset = 0, base_mul = 1;
433          parse_offset(&base, &base_mul, &offset);
434          offset = util_mask_sign_extend(offset, index->bit_size);
435 
436          *offset_base += offset * stride;
437          if (base.def) {
438             offset_def_count += add_to_entry_key(offset_defs, offset_defs_mul,
439                                                  offset_def_count,
440                                                  base, base_mul * stride);
441          }
442          break;
443       }
444       case nir_deref_type_struct: {
445          assert(parent);
446          int offset = glsl_get_struct_field_offset(parent->type, deref->strct.index);
447          *offset_base += offset;
448          break;
449       }
450       case nir_deref_type_cast: {
451          if (!parent)
452             key->resource = deref->parent.ssa;
453          break;
454       }
455       default:
456          unreachable("Unhandled deref type");
457       }
458    }
459 
460    key->offset_def_count = offset_def_count;
461    key->offset_defs = ralloc_array(mem_ctx, nir_ssa_scalar, offset_def_count);
462    key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, offset_def_count);
463    memcpy(key->offset_defs, offset_defs, offset_def_count * sizeof(nir_ssa_scalar));
464    memcpy(key->offset_defs_mul, offset_defs_mul, offset_def_count * sizeof(uint64_t));
465 
466    if (offset_defs != offset_defs_stack)
467       free(offset_defs);
468    if (offset_defs_mul != offset_defs_mul_stack)
469       free(offset_defs_mul);
470 
471    return key;
472 }
473 
474 static unsigned
475 parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
476                             nir_ssa_scalar base, uint64_t base_mul, uint64_t *offset)
477 {
478    uint64_t new_mul;
479    uint64_t new_offset;
480    parse_offset(&base, &new_mul, &new_offset);
481    *offset += new_offset * base_mul;
482 
483    if (!base.def)
484       return 0;
485 
486    base_mul *= new_mul;
487 
488    assert(left >= 1);
489 
490    if (left >= 2) {
491       if (nir_ssa_scalar_is_alu(base) && nir_ssa_scalar_alu_op(base) == nir_op_iadd) {
492          nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(base, 0);
493          nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(base, 1);
494          unsigned amount = parse_entry_key_from_offset(key, size, left - 1, src0, base_mul, offset);
495          amount += parse_entry_key_from_offset(key, size + amount, left - amount, src1, base_mul, offset);
496          return amount;
497       }
498    }
499 
500    return add_to_entry_key(key->offset_defs, key->offset_defs_mul, size, base, base_mul);
501 }
502 
503 static struct entry_key *
504 create_entry_key_from_offset(void *mem_ctx, nir_ssa_def *base, uint64_t base_mul, uint64_t *offset)
505 {
506    struct entry_key *key = ralloc(mem_ctx, struct entry_key);
507    key->resource = NULL;
508    key->var = NULL;
509    if (base) {
510       nir_ssa_scalar offset_defs[32];
511       uint64_t offset_defs_mul[32];
512       key->offset_defs = offset_defs;
513       key->offset_defs_mul = offset_defs_mul;
514 
515       nir_ssa_scalar scalar = {.def=base, .comp=0};
516       key->offset_def_count = parse_entry_key_from_offset(key, 0, 32, scalar, base_mul, offset);
517 
518       key->offset_defs = ralloc_array(mem_ctx, nir_ssa_scalar, key->offset_def_count);
519       key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, key->offset_def_count);
520       memcpy(key->offset_defs, offset_defs, key->offset_def_count * sizeof(nir_ssa_scalar));
521       memcpy(key->offset_defs_mul, offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
522    } else {
523       key->offset_def_count = 0;
524       key->offset_defs = NULL;
525       key->offset_defs_mul = NULL;
526    }
527    return key;
528 }
529 
530 static nir_variable_mode
531 get_variable_mode(struct entry *entry)
532 {
533    if (entry->info->mode)
534       return entry->info->mode;
535    assert(entry->deref && util_bitcount(entry->deref->modes) == 1);
536    return entry->deref->modes;
537 }
538 
539 static unsigned
540 mode_to_index(nir_variable_mode mode)
541 {
542    assert(util_bitcount(mode) == 1);
543 
544    /* Globals and SSBOs should be tracked together */
545    if (mode == nir_var_mem_global)
546       mode = nir_var_mem_ssbo;
547 
548    return ffs(mode) - 1;
549 }
550 
551 static nir_variable_mode
552 aliasing_modes(nir_variable_mode modes)
553 {
554    /* Global and SSBO can alias */
555    if (modes & (nir_var_mem_ssbo | nir_var_mem_global))
556       modes |= nir_var_mem_ssbo | nir_var_mem_global;
557    return modes;
558 }
559 
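/* Derive alignment from the multipliers of the non-constant offset terms and
 * from the constant offset. For example, with offset_defs_mul = {24, 8} and
 * offset = 20, the lowest set bit across the multipliers gives align_mul = 8
 * and align_offset = 20 % 8 = 4, unless the intrinsic already carries a larger
 * align_mul, in which case that alignment is kept. */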
560 static void
561 calc_alignment(struct entry *entry)
562 {
563    uint32_t align_mul = 31;
564    for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
565       if (entry->key->offset_defs_mul[i])
566          align_mul = MIN2(align_mul, ffsll(entry->key->offset_defs_mul[i]));
567    }
568 
569    entry->align_mul = 1u << (align_mul - 1);
570    bool has_align = nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
571    if (!has_align || entry->align_mul >= nir_intrinsic_align_mul(entry->intrin)) {
572       entry->align_offset = entry->offset % entry->align_mul;
573    } else {
574       entry->align_mul = nir_intrinsic_align_mul(entry->intrin);
575       entry->align_offset = nir_intrinsic_align_offset(entry->intrin);
576    }
577 }
578 
579 static struct entry *
580 create_entry(struct vectorize_ctx *ctx,
581              const struct intrinsic_info *info,
582              nir_intrinsic_instr *intrin)
583 {
584    struct entry *entry = rzalloc(ctx, struct entry);
585    entry->intrin = intrin;
586    entry->instr = &intrin->instr;
587    entry->info = info;
588    entry->is_store = entry->info->value_src >= 0;
589 
590    if (entry->info->deref_src >= 0) {
591       entry->deref = nir_src_as_deref(intrin->src[entry->info->deref_src]);
592       nir_deref_path path;
593       nir_deref_path_init(&path, entry->deref, NULL);
594       entry->key = create_entry_key_from_deref(entry, ctx, &path, &entry->offset);
595       nir_deref_path_finish(&path);
596    } else {
597       nir_ssa_def *base = entry->info->base_src >= 0 ?
598                           intrin->src[entry->info->base_src].ssa : NULL;
599       uint64_t offset = 0;
600       if (nir_intrinsic_has_base(intrin))
601          offset += nir_intrinsic_base(intrin);
602       entry->key = create_entry_key_from_offset(entry, base, 1, &offset);
603       entry->offset = offset;
604 
605       if (base)
606          entry->offset = util_mask_sign_extend(entry->offset, base->bit_size);
607    }
608 
609    if (entry->info->resource_src >= 0)
610       entry->key->resource = intrin->src[entry->info->resource_src].ssa;
611 
612    if (nir_intrinsic_has_access(intrin))
613       entry->access = nir_intrinsic_access(intrin);
614    else if (entry->key->var)
615       entry->access = entry->key->var->data.access;
616 
617    if (nir_intrinsic_can_reorder(intrin))
618       entry->access |= ACCESS_CAN_REORDER;
619 
620    uint32_t restrict_modes = nir_var_shader_in | nir_var_shader_out;
621    restrict_modes |= nir_var_shader_temp | nir_var_function_temp;
622    restrict_modes |= nir_var_uniform | nir_var_mem_push_const;
623    restrict_modes |= nir_var_system_value | nir_var_mem_shared;
624    restrict_modes |= nir_var_mem_task_payload;
625    if (get_variable_mode(entry) & restrict_modes)
626       entry->access |= ACCESS_RESTRICT;
627 
628    calc_alignment(entry);
629 
630    return entry;
631 }
632 
633 static nir_deref_instr *
634 cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref_instr *deref)
635 {
636    if (glsl_get_components(deref->type) == num_components &&
637        type_scalar_size_bytes(deref->type)*8u == bit_size)
638       return deref;
639 
640    enum glsl_base_type types[] = {
641       GLSL_TYPE_UINT8, GLSL_TYPE_UINT16, GLSL_TYPE_UINT, GLSL_TYPE_UINT64};
642    enum glsl_base_type base = types[ffs(bit_size / 8u) - 1u];
643    const struct glsl_type *type = glsl_vector_type(base, num_components);
644 
645    if (deref->type == type)
646       return deref;
647 
648    return nir_build_deref_cast(b, &deref->dest.ssa, deref->modes, type, 0);
649 }
650 
651 /* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
652  * of "low" and "high". */
653 static bool
654 new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
655                        struct entry *low, struct entry *high, unsigned size)
656 {
657    if (size % new_bit_size != 0)
658       return false;
659 
660    unsigned new_num_components = size / new_bit_size;
661    if (!nir_num_components_valid(new_num_components))
662       return false;
663 
664    unsigned high_offset = high->offset_signed - low->offset_signed;
665 
666    /* check nir_extract_bits limitations */
667    unsigned common_bit_size = MIN2(get_bit_size(low), get_bit_size(high));
668    common_bit_size = MIN2(common_bit_size, new_bit_size);
669    if (high_offset > 0)
670       common_bit_size = MIN2(common_bit_size, (1u << (ffs(high_offset * 8) - 1)));
671    if (new_bit_size / common_bit_size > NIR_MAX_VEC_COMPONENTS)
672       return false;
673 
674    if (!ctx->options->callback(low->align_mul,
675                                low->align_offset,
676                                new_bit_size, new_num_components,
677                                low->intrin, high->intrin,
678                                ctx->options->cb_data))
679       return false;
680 
681    if (low->is_store) {
682       unsigned low_size = low->intrin->num_components * get_bit_size(low);
683       unsigned high_size = high->intrin->num_components * get_bit_size(high);
684 
685       if (low_size % new_bit_size != 0)
686          return false;
687       if (high_size % new_bit_size != 0)
688          return false;
689 
690       unsigned write_mask = nir_intrinsic_write_mask(low->intrin);
691       if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(low), new_bit_size))
692          return false;
693 
694       write_mask = nir_intrinsic_write_mask(high->intrin);
695       if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(high), new_bit_size))
696          return false;
697    }
698 
699    return true;
700 }
701 
702 static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
703 {
704    /* avoid adding another deref to the path */
705    if (deref->deref_type == nir_deref_type_ptr_as_array &&
706        nir_src_is_const(deref->arr.index) &&
707        offset % nir_deref_instr_array_stride(deref) == 0) {
708       unsigned stride = nir_deref_instr_array_stride(deref);
709       nir_ssa_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
710                                           deref->dest.ssa.bit_size);
711       return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
712    }
713 
714    if (deref->deref_type == nir_deref_type_array &&
715        nir_src_is_const(deref->arr.index)) {
716       nir_deref_instr *parent = nir_deref_instr_parent(deref);
717       unsigned stride = glsl_get_explicit_stride(parent->type);
718       if (offset % stride == 0)
719          return nir_build_deref_array_imm(
720             b, parent, nir_src_as_int(deref->arr.index) - offset / stride);
721    }
722 
723 
724    deref = nir_build_deref_cast(b, &deref->dest.ssa, deref->modes,
725                                 glsl_scalar_type(GLSL_TYPE_UINT8), 1);
726    return nir_build_deref_ptr_as_array(
727       b, deref, nir_imm_intN_t(b, -offset, deref->dest.ssa.bit_size));
728 }
729 
730 static void
731 vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
732                 struct entry *low, struct entry *high,
733                 struct entry *first, struct entry *second,
734                 unsigned new_bit_size, unsigned new_num_components,
735                 unsigned high_start)
736 {
737    unsigned low_bit_size = get_bit_size(low);
738    unsigned high_bit_size = get_bit_size(high);
739    bool low_bool = low->intrin->dest.ssa.bit_size == 1;
740    bool high_bool = high->intrin->dest.ssa.bit_size == 1;
741    nir_ssa_def *data = &first->intrin->dest.ssa;
742 
743    b->cursor = nir_after_instr(first->instr);
744 
745    /* update the load's destination size and extract data for each of the original loads */
746    data->num_components = new_num_components;
747    data->bit_size = new_bit_size;
748 
749    nir_ssa_def *low_def = nir_extract_bits(
750       b, &data, 1, 0, low->intrin->num_components, low_bit_size);
751    nir_ssa_def *high_def = nir_extract_bits(
752       b, &data, 1, high_start, high->intrin->num_components, high_bit_size);
753 
754    /* convert booleans */
755    low_def = low_bool ? nir_i2b(b, low_def) : nir_mov(b, low_def);
756    high_def = high_bool ? nir_i2b(b, high_def) : nir_mov(b, high_def);
757 
758    /* update uses */
759    if (first == low) {
760       nir_ssa_def_rewrite_uses_after(&low->intrin->dest.ssa, low_def,
761                                      high_def->parent_instr);
762       nir_ssa_def_rewrite_uses(&high->intrin->dest.ssa, high_def);
763    } else {
764       nir_ssa_def_rewrite_uses(&low->intrin->dest.ssa, low_def);
765       nir_ssa_def_rewrite_uses_after(&high->intrin->dest.ssa, high_def,
766                                      high_def->parent_instr);
767    }
768 
769    /* update the intrinsic */
770    first->intrin->num_components = new_num_components;
771 
772    const struct intrinsic_info *info = first->info;
773 
774    /* update the offset */
775    if (first != low && info->base_src >= 0) {
776       /* let nir_opt_algebraic() remove this addition. this doesn't have many
777        * issues with subtracting 16 from expressions like "(i + 1) * 16" because
778        * nir_opt_algebraic() turns them into "i * 16 + 16" */
779       b->cursor = nir_before_instr(first->instr);
780 
781       nir_ssa_def *new_base = first->intrin->src[info->base_src].ssa;
782       new_base = nir_iadd_imm(b, new_base, -(int)(high_start / 8u));
783 
784       nir_instr_rewrite_src(first->instr, &first->intrin->src[info->base_src],
785                             nir_src_for_ssa(new_base));
786    }
787 
788    /* update the deref */
789    if (info->deref_src >= 0) {
790       b->cursor = nir_before_instr(first->instr);
791 
792       nir_deref_instr *deref = nir_src_as_deref(first->intrin->src[info->deref_src]);
793       if (first != low && high_start != 0)
794          deref = subtract_deref(b, deref, high_start / 8u);
795       first->deref = cast_deref(b, new_num_components, new_bit_size, deref);
796 
797       nir_instr_rewrite_src(first->instr, &first->intrin->src[info->deref_src],
798                             nir_src_for_ssa(&first->deref->dest.ssa));
799    }
800 
801    /* update the range base and range */
802    if (nir_intrinsic_has_range_base(first->intrin)) {
803       uint32_t low_base = nir_intrinsic_range_base(low->intrin);
804       uint32_t high_base = nir_intrinsic_range_base(high->intrin);
805       uint32_t low_end = low_base + nir_intrinsic_range(low->intrin);
806       uint32_t high_end = high_base + nir_intrinsic_range(high->intrin);
807 
808       nir_intrinsic_set_range_base(first->intrin, low_base);
809       nir_intrinsic_set_range(first->intrin, MAX2(low_end, high_end) - low_base);
810    }
811 
812    first->key = low->key;
813    first->offset = low->offset;
814 
815    first->align_mul = low->align_mul;
816    first->align_offset = low->align_offset;
817 
818    nir_instr_remove(second->instr);
819 }
820 
821 static void
822 vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
823                  struct entry *low, struct entry *high,
824                  struct entry *first, struct entry *second,
825                  unsigned new_bit_size, unsigned new_num_components,
826                  unsigned high_start)
827 {
828    ASSERTED unsigned low_size = low->intrin->num_components * get_bit_size(low);
829    assert(low_size % new_bit_size == 0);
830 
831    b->cursor = nir_before_instr(second->instr);
832 
833    /* get new writemasks */
834    uint32_t low_write_mask = nir_intrinsic_write_mask(low->intrin);
835    uint32_t high_write_mask = nir_intrinsic_write_mask(high->intrin);
836    low_write_mask = nir_component_mask_reinterpret(low_write_mask,
837                                                    get_bit_size(low),
838                                                    new_bit_size);
839    high_write_mask = nir_component_mask_reinterpret(high_write_mask,
840                                                     get_bit_size(high),
841                                                     new_bit_size);
842    high_write_mask <<= high_start / new_bit_size;
843 
844    uint32_t write_mask = low_write_mask | high_write_mask;
845 
846    /* convert booleans */
847    nir_ssa_def *low_val = low->intrin->src[low->info->value_src].ssa;
848    nir_ssa_def *high_val = high->intrin->src[high->info->value_src].ssa;
849    low_val = low_val->bit_size == 1 ? nir_b2i(b, low_val, 32) : low_val;
850    high_val = high_val->bit_size == 1 ? nir_b2i(b, high_val, 32) : high_val;
851 
852    /* combine the data */
853    nir_ssa_def *data_channels[NIR_MAX_VEC_COMPONENTS];
854    for (unsigned i = 0; i < new_num_components; i++) {
855       bool set_low = low_write_mask & (1 << i);
856       bool set_high = high_write_mask & (1 << i);
857 
858       if (set_low && (!set_high || low == second)) {
859          unsigned offset = i * new_bit_size;
860          data_channels[i] = nir_extract_bits(b, &low_val, 1, offset, 1, new_bit_size);
861       } else if (set_high) {
862          assert(!set_low || high == second);
863          unsigned offset = i * new_bit_size - high_start;
864          data_channels[i] = nir_extract_bits(b, &high_val, 1, offset, 1, new_bit_size);
865       } else {
866          data_channels[i] = nir_ssa_undef(b, 1, new_bit_size);
867       }
868    }
869    nir_ssa_def *data = nir_vec(b, data_channels, new_num_components);
870 
871    /* update the intrinsic */
872    nir_intrinsic_set_write_mask(second->intrin, write_mask);
873    second->intrin->num_components = data->num_components;
874 
875    const struct intrinsic_info *info = second->info;
876    assert(info->value_src >= 0);
877    nir_instr_rewrite_src(second->instr, &second->intrin->src[info->value_src],
878                          nir_src_for_ssa(data));
879 
880    /* update the offset */
881    if (second != low && info->base_src >= 0)
882       nir_instr_rewrite_src(second->instr, &second->intrin->src[info->base_src],
883                             low->intrin->src[info->base_src]);
884 
885    /* update the deref */
886    if (info->deref_src >= 0) {
887       b->cursor = nir_before_instr(second->instr);
888       second->deref = cast_deref(b, new_num_components, new_bit_size,
889                                  nir_src_as_deref(low->intrin->src[info->deref_src]));
890       nir_instr_rewrite_src(second->instr, &second->intrin->src[info->deref_src],
891                             nir_src_for_ssa(&second->deref->dest.ssa));
892    }
893 
894    /* update base/align */
895    if (second != low && nir_intrinsic_has_base(second->intrin))
896       nir_intrinsic_set_base(second->intrin, nir_intrinsic_base(low->intrin));
897 
898    second->key = low->key;
899    second->offset = low->offset;
900 
901    second->align_mul = low->align_mul;
902    second->align_offset = low->align_offset;
903 
904    list_del(&first->head);
905    nir_instr_remove(first->instr);
906 }
907 
908 /* Returns true if it can prove that "a" and "b" point to different bindings
909  * and either one uses ACCESS_RESTRICT. */
910 static bool
911 bindings_different_restrict(nir_shader *shader, struct entry *a, struct entry *b)
912 {
913    bool different_bindings = false;
914    nir_variable *a_var = NULL, *b_var = NULL;
915    if (a->key->resource && b->key->resource) {
916       nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
917       nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
918       if (!a_res.success || !b_res.success)
919          return false;
920 
921       if (a_res.num_indices != b_res.num_indices ||
922           a_res.desc_set != b_res.desc_set ||
923           a_res.binding != b_res.binding)
924          different_bindings = true;
925 
926       for (unsigned i = 0; i < a_res.num_indices; i++) {
927          if (nir_src_is_const(a_res.indices[i]) && nir_src_is_const(b_res.indices[i]) &&
928              nir_src_as_uint(a_res.indices[i]) != nir_src_as_uint(b_res.indices[i]))
929             different_bindings = true;
930       }
931 
932       if (different_bindings) {
933          a_var = nir_get_binding_variable(shader, a_res);
934          b_var = nir_get_binding_variable(shader, b_res);
935       }
936    } else if (a->key->var && b->key->var) {
937       a_var = a->key->var;
938       b_var = b->key->var;
939       different_bindings = a_var != b_var;
940    } else if (!!a->key->resource != !!b->key->resource) {
941       /* comparing global and ssbo access */
942       different_bindings = true;
943 
944       if (a->key->resource) {
945          nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
946          a_var = nir_get_binding_variable(shader, a_res);
947       }
948 
949       if (b->key->resource) {
950          nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
951          b_var = nir_get_binding_variable(shader, b_res);
952       }
953    } else {
954       return false;
955    }
956 
957    unsigned a_access = a->access | (a_var ? a_var->data.access : 0);
958    unsigned b_access = b->access | (b_var ? b_var->data.access : 0);
959 
960    return different_bindings &&
961           ((a_access | b_access) & ACCESS_RESTRICT);
962 }
963 
964 static int64_t
965 compare_entries(struct entry *a, struct entry *b)
966 {
967    if (!entry_key_equals(a->key, b->key))
968       return INT64_MAX;
969    return b->offset_signed - a->offset_signed;
970 }
971 
972 static bool
973 may_alias(nir_shader *shader, struct entry *a, struct entry *b)
974 {
975    assert(mode_to_index(get_variable_mode(a)) ==
976           mode_to_index(get_variable_mode(b)));
977 
978    if ((a->access | b->access) & ACCESS_CAN_REORDER)
979       return false;
980 
981    /* if the resources/variables are definitively different and both have
982     * ACCESS_RESTRICT, we can assume they do not alias. */
983    if (bindings_different_restrict(shader, a, b))
984       return false;
985 
986    /* we can't compare offsets if the resources/variables might be different */
987    if (a->key->var != b->key->var || a->key->resource != b->key->resource)
988       return true;
989 
990    /* use adjacency information */
991    /* TODO: we can look closer at the entry keys */
992    int64_t diff = compare_entries(a, b);
993    if (diff != INT64_MAX) {
994       /* with atomics, intrin->num_components can be 0 */
995       if (diff < 0)
996          return llabs(diff) < MAX2(b->intrin->num_components, 1u) * (get_bit_size(b) / 8u);
997       else
998          return diff < MAX2(a->intrin->num_components, 1u) * (get_bit_size(a) / 8u);
999    }
1000 
1001    /* TODO: we can use deref information */
1002 
1003    return true;
1004 }
1005 
1006 static bool
1007 check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
1008 {
1009    nir_variable_mode mode = get_variable_mode(first);
1010    if (mode & (nir_var_uniform | nir_var_system_value |
1011                nir_var_mem_push_const | nir_var_mem_ubo))
1012       return false;
1013 
1014    unsigned mode_index = mode_to_index(mode);
1015    if (first->is_store) {
1016       /* find first entry that aliases "first" */
1017       list_for_each_entry_from(struct entry, next, first, &ctx->entries[mode_index], head) {
1018          if (next == first)
1019             continue;
1020          if (next == second)
1021             return false;
1022          if (may_alias(ctx->shader, first, next))
1023             return true;
1024       }
1025    } else {
1026       /* find previous store that aliases this load */
1027       list_for_each_entry_from_rev(struct entry, prev, second, &ctx->entries[mode_index], head) {
1028          if (prev == second)
1029             continue;
1030          if (prev == first)
1031             return false;
1032          if (prev->is_store && may_alias(ctx->shader, second, prev))
1033             return true;
1034       }
1035    }
1036 
1037    return false;
1038 }
1039 
1040 static uint64_t
1041 calc_gcd(uint64_t a, uint64_t b)
1042 {
1043    while (b != 0) {
1044       uint64_t tmp_a = a;
1045       a = b;
1046       b = tmp_a % b;
1047    }
1048    return a;
1049 }
1050 
1051 static uint64_t
1052 round_down(uint64_t a, uint64_t b)
1053 {
1054    return a / b * b;
1055 }
1056 
1057 static bool
1058 addition_wraps(uint64_t a, uint64_t b, unsigned bits)
1059 {
1060    uint64_t mask = BITFIELD64_MASK(bits);
1061    return ((a + b) & mask) < (a & mask);
1062 }
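/* For example, addition_wraps(0xfff0, 0x20, 16) is true: the 16-bit sum wraps
 * to 0x0010, which is smaller than 0xfff0. */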
1063 
1064 /* Return true if the addition of "low"'s offset and "high_offset" could wrap
1065  * around.
1066  *
1067  * This is to prevent a situation where the hardware considers the high load
1068  * out-of-bounds after vectorization if the low load is out-of-bounds, even if
1069  * the wrap-around from the addition could make the high load in-bounds.
1070  */
1071 static bool
1072 check_for_robustness(struct vectorize_ctx *ctx, struct entry *low, uint64_t high_offset)
1073 {
1074    nir_variable_mode mode = get_variable_mode(low);
1075    if (!(mode & ctx->options->robust_modes))
1076       return false;
1077 
1078    /* First, try to use alignment information in case the application provided some. If the addition
1079     * of the maximum offset of the low load and "high_offset" wraps around, we can't combine the low
1080     * and high loads.
1081     */
1082    uint64_t max_low = round_down(UINT64_MAX, low->align_mul) + low->align_offset;
1083    if (!addition_wraps(max_low, high_offset, 64))
1084       return false;
1085 
1086    /* We can't obtain addition_bits */
1087    if (low->info->base_src < 0)
1088       return true;
1089 
1090    /* Second, use information about the factors from address calculation (offset_defs_mul). These
1091     * are not guaranteed to be power-of-2.
1092     */
1093    uint64_t stride = 0;
1094    for (unsigned i = 0; i < low->key->offset_def_count; i++)
1095       stride = calc_gcd(low->key->offset_defs_mul[i], stride);
1096 
1097    unsigned addition_bits = low->intrin->src[low->info->base_src].ssa->bit_size;
1098    /* low's offset must be a multiple of "stride" plus "low->offset". */
1099    max_low = low->offset;
1100    if (stride)
1101       max_low = round_down(BITFIELD64_MASK(addition_bits), stride) + (low->offset % stride);
1102    return addition_wraps(max_low, high_offset, addition_bits);
1103 }
1104 
1105 static bool
1106 is_strided_vector(const struct glsl_type *type)
1107 {
1108    if (glsl_type_is_vector(type)) {
1109       unsigned explicit_stride = glsl_get_explicit_stride(type);
1110       return explicit_stride != 0 && explicit_stride !=
1111              type_scalar_size_bytes(glsl_get_array_element(type));
1112    } else {
1113       return false;
1114    }
1115 }
1116 
1117 static bool
1118 can_vectorize(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
1119 {
1120    if (!(get_variable_mode(first) & ctx->options->modes) ||
1121        !(get_variable_mode(second) & ctx->options->modes))
1122       return false;
1123 
1124    if (check_for_aliasing(ctx, first, second))
1125       return false;
1126 
1127    /* we can only vectorize non-volatile loads/stores of the same type and with
1128     * the same access */
1129    if (first->info != second->info || first->access != second->access ||
1130        (first->access & ACCESS_VOLATILE) || first->info->is_atomic)
1131       return false;
1132 
1133    return true;
1134 }
1135 
1136 static bool
1137 try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
1138               struct entry *low, struct entry *high,
1139               struct entry *first, struct entry *second)
1140 {
1141    if (!can_vectorize(ctx, first, second))
1142       return false;
1143 
1144    uint64_t diff = high->offset_signed - low->offset_signed;
1145    if (check_for_robustness(ctx, low, diff))
1146       return false;
1147 
1148    /* don't attempt to vectorize accesses of row-major matrix columns */
1149    if (first->deref) {
1150       const struct glsl_type *first_type = first->deref->type;
1151       const struct glsl_type *second_type = second->deref->type;
1152       if (is_strided_vector(first_type) || is_strided_vector(second_type))
1153          return false;
1154    }
1155 
1156    /* gather information */
1157    unsigned low_bit_size = get_bit_size(low);
1158    unsigned high_bit_size = get_bit_size(high);
1159    unsigned low_size = low->intrin->num_components * low_bit_size;
1160    unsigned high_size = high->intrin->num_components * high_bit_size;
1161    unsigned new_size = MAX2(diff * 8u + high_size, low_size);
1162 
1163    /* find a good bit size for the new load/store */
1164    unsigned new_bit_size = 0;
1165    if (new_bitsize_acceptable(ctx, low_bit_size, low, high, new_size)) {
1166       new_bit_size = low_bit_size;
1167    } else if (low_bit_size != high_bit_size &&
1168               new_bitsize_acceptable(ctx, high_bit_size, low, high, new_size)) {
1169       new_bit_size = high_bit_size;
1170    } else {
1171       new_bit_size = 64;
1172       for (; new_bit_size >= 8; new_bit_size /= 2) {
1173          /* don't repeat trying out bitsizes */
1174          if (new_bit_size == low_bit_size || new_bit_size == high_bit_size)
1175             continue;
1176          if (new_bitsize_acceptable(ctx, new_bit_size, low, high, new_size))
1177             break;
1178       }
1179       if (new_bit_size < 8)
1180          return false;
1181    }
1182    unsigned new_num_components = new_size / new_bit_size;
1183 
1184    /* vectorize the loads/stores */
1185    nir_builder b;
1186    nir_builder_init(&b, impl);
1187 
1188    if (first->is_store)
1189       vectorize_stores(&b, ctx, low, high, first, second,
1190                        new_bit_size, new_num_components, diff * 8u);
1191    else
1192       vectorize_loads(&b, ctx, low, high, first, second,
1193                       new_bit_size, new_num_components, diff * 8u);
1194 
1195    return true;
1196 }
1197 
1198 static bool
1199 try_vectorize_shared2(nir_function_impl *impl, struct vectorize_ctx *ctx,
1200                       struct entry *low, struct entry *high,
1201                       struct entry *first, struct entry *second)
1202 {
1203    if (!can_vectorize(ctx, first, second) || first->deref)
1204       return false;
1205 
1206    unsigned low_bit_size = get_bit_size(low);
1207    unsigned high_bit_size = get_bit_size(high);
1208    unsigned low_size = low->intrin->num_components * low_bit_size / 8;
1209    unsigned high_size = high->intrin->num_components * high_bit_size / 8;
1210    if ((low_size != 4 && low_size != 8) || (high_size != 4 && high_size != 8))
1211       return false;
1212    if (low_size != high_size)
1213       return false;
1214    if (low->align_mul % low_size || low->align_offset % low_size)
1215       return false;
1216    if (high->align_mul % low_size || high->align_offset % low_size)
1217       return false;
1218 
1219    uint64_t diff = high->offset_signed - low->offset_signed;
1220    bool st64 = diff % (64 * low_size) == 0;
1221    unsigned stride = st64 ? 64 * low_size : low_size;
1222    if (diff % stride || diff > 255 * stride)
1223       return false;
1224 
1225    /* try to avoid creating accesses we can't combine additions/offsets into */
1226    if (high->offset > 255 * stride || (st64 && high->offset % stride))
1227       return false;
1228 
1229    if (first->is_store) {
1230       if (nir_intrinsic_write_mask(low->intrin) != BITFIELD_MASK(low->intrin->num_components))
1231          return false;
1232       if (nir_intrinsic_write_mask(high->intrin) != BITFIELD_MASK(high->intrin->num_components))
1233          return false;
1234    }
1235 
1236    /* vectorize the accesses */
1237    nir_builder b;
1238    nir_builder_init(&b, impl);
1239 
1240    b.cursor = nir_after_instr(first->is_store ? second->instr : first->instr);
1241 
1242    nir_ssa_def *offset = first->intrin->src[first->is_store].ssa;
1243    offset = nir_iadd_imm(&b, offset, nir_intrinsic_base(first->intrin));
1244    if (first != low)
1245       offset = nir_iadd_imm(&b, offset, -(int)diff);
1246 
1247    if (first->is_store) {
1248       nir_ssa_def *low_val = low->intrin->src[low->info->value_src].ssa;
1249       nir_ssa_def *high_val = high->intrin->src[high->info->value_src].ssa;
1250       nir_ssa_def *val = nir_vec2(&b, nir_bitcast_vector(&b, low_val, low_size * 8u),
1251                                       nir_bitcast_vector(&b, high_val, low_size * 8u));
1252       nir_store_shared2_amd(&b, val, offset, .offset1=diff/stride, .st64=st64);
1253    } else {
1254       nir_ssa_def *new_def = nir_load_shared2_amd(&b, low_size * 8u, offset, .offset1=diff/stride,
1255                                                   .st64=st64);
1256       nir_ssa_def_rewrite_uses(&low->intrin->dest.ssa,
1257                                nir_bitcast_vector(&b, nir_channel(&b, new_def, 0), low_bit_size));
1258       nir_ssa_def_rewrite_uses(&high->intrin->dest.ssa,
1259                                nir_bitcast_vector(&b, nir_channel(&b, new_def, 1), high_bit_size));
1260    }
1261 
1262    nir_instr_remove(first->instr);
1263    nir_instr_remove(second->instr);
1264 
1265    return true;
1266 }
1267 
1268 static bool
1269 update_align(struct entry *entry)
1270 {
1271    if (nir_intrinsic_has_align_mul(entry->intrin) &&
1272        (entry->align_mul != nir_intrinsic_align_mul(entry->intrin) ||
1273         entry->align_offset != nir_intrinsic_align_offset(entry->intrin))) {
1274       nir_intrinsic_set_align(entry->intrin, entry->align_mul, entry->align_offset);
1275       return true;
1276    }
1277    return false;
1278 }
1279 
1280 static bool
1281 vectorize_sorted_entries(struct vectorize_ctx *ctx, nir_function_impl *impl,
1282                          struct util_dynarray *arr)
1283 {
1284    unsigned num_entries = util_dynarray_num_elements(arr, struct entry *);
1285 
1286    bool progress = false;
1287    for (unsigned first_idx = 0; first_idx < num_entries; first_idx++) {
1288       struct entry *low = *util_dynarray_element(arr, struct entry *, first_idx);
1289       if (!low)
1290          continue;
1291 
1292       for (unsigned second_idx = first_idx + 1; second_idx < num_entries; second_idx++) {
1293          struct entry *high = *util_dynarray_element(arr, struct entry *, second_idx);
1294          if (!high)
1295             continue;
1296 
1297          struct entry *first = low->index < high->index ? low : high;
1298          struct entry *second = low->index < high->index ? high : low;
1299 
         uint64_t diff = high->offset_signed - low->offset_signed;
         bool separate = diff > get_bit_size(low) / 8u * low->intrin->num_components;
         if (separate) {
            if (!ctx->options->has_shared2_amd ||
                get_variable_mode(first) != nir_var_mem_shared)
               break;

            if (try_vectorize_shared2(impl, ctx, low, high, first, second)) {
               low = NULL;
               *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
               progress = true;
               break;
            }
         } else {
            if (try_vectorize(impl, ctx, low, high, first, second)) {
               low = low->is_store ? second : first;
               *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
               progress = true;
            }
         }
      }

      *util_dynarray_element(arr, struct entry *, first_idx) = low;
   }

   return progress;
}

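/* Sort every adjacency list in the hash table by offset, run
 * vectorize_sorted_entries() on it until it stops making progress and then
 * refresh the alignment info on the surviving entries.
 */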
static bool
vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct hash_table *ht)
{
   if (!ht)
      return false;

   bool progress = false;
   hash_table_foreach(ht, entry) {
      struct util_dynarray *arr = entry->data;
      if (!arr->size)
         continue;

      qsort(util_dynarray_begin(arr),
            util_dynarray_num_elements(arr, struct entry *),
            sizeof(struct entry *), &sort_entries);

      while (vectorize_sorted_entries(ctx, impl, arr))
         progress = true;

      util_dynarray_foreach(arr, struct entry *, elem) {
         if (*elem)
            progress |= update_align(*elem);
      }
   }

   _mesa_hash_table_clear(ht, delete_entry_dynarray);

   return progress;
}

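/* Check whether an instruction acts as a memory barrier and, if so, flush
 * (vectorize) the entries gathered so far for the affected modes. Returns
 * true if the instruction was handled as a barrier.
 */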
static bool
handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *impl, nir_instr *instr)
{
   unsigned modes = 0;
   bool acquire = true;
   bool release = true;
   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      switch (intrin->intrinsic) {
      case nir_intrinsic_group_memory_barrier:
      case nir_intrinsic_memory_barrier:
         modes = nir_var_mem_ssbo | nir_var_mem_shared | nir_var_mem_global |
                 nir_var_mem_task_payload;
         break;
      /* prevent speculative loads/stores */
      case nir_intrinsic_discard_if:
      case nir_intrinsic_discard:
      case nir_intrinsic_terminate_if:
      case nir_intrinsic_terminate:
      case nir_intrinsic_launch_mesh_workgroups:
         modes = nir_var_all;
         break;
      case nir_intrinsic_demote_if:
      case nir_intrinsic_demote:
         acquire = false;
         modes = nir_var_all;
         break;
      case nir_intrinsic_memory_barrier_buffer:
         modes = nir_var_mem_ssbo | nir_var_mem_global;
         break;
      case nir_intrinsic_memory_barrier_shared:
         modes = nir_var_mem_shared | nir_var_mem_task_payload;
         break;
      case nir_intrinsic_scoped_barrier:
         if (nir_intrinsic_memory_scope(intrin) == NIR_SCOPE_NONE)
            break;

         modes = nir_intrinsic_memory_modes(intrin) & (nir_var_mem_ssbo |
                                                       nir_var_mem_shared |
                                                       nir_var_mem_global |
                                                       nir_var_mem_task_payload);
         acquire = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_ACQUIRE;
         release = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_RELEASE;
         switch (nir_intrinsic_memory_scope(intrin)) {
         case NIR_SCOPE_INVOCATION:
            /* a barrier should never be required for correctness with these scopes */
            modes = 0;
            break;
         default:
            break;
         }
         break;
      default:
         return false;
      }
   } else if (instr->type == nir_instr_type_call) {
      modes = nir_var_all;
   } else {
      return false;
   }

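   /* Flush the gathered loads and/or stores for each affected mode so that
    * nothing gets combined across the barrier: acquire semantics fence the
    * loads, release semantics fence the stores.
    */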
   while (modes) {
      unsigned mode_index = u_bit_scan(&modes);
      if ((1 << mode_index) == nir_var_mem_global) {
         /* Global should be rolled in with SSBO */
         assert(list_is_empty(&ctx->entries[mode_index]));
         assert(ctx->loads[mode_index] == NULL);
         assert(ctx->stores[mode_index] == NULL);
         continue;
      }

      if (acquire)
         *progress |= vectorize_entries(ctx, impl, ctx->loads[mode_index]);
      if (release)
         *progress |= vectorize_entries(ctx, impl, ctx->stores[mode_index]);
   }

   return true;
}

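/* Gather load/store entries for a single block, group them per mode and per
 * adjacency key, and then try to vectorize each group. Barriers encountered
 * along the way flush the pending groups early.
 */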
static bool
process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *block)
{
   bool progress = false;

   for (unsigned i = 0; i < nir_num_variable_modes; i++) {
      list_inithead(&ctx->entries[i]);
      if (ctx->loads[i])
         _mesa_hash_table_clear(ctx->loads[i], delete_entry_dynarray);
      if (ctx->stores[i])
         _mesa_hash_table_clear(ctx->stores[i], delete_entry_dynarray);
   }

   /* create entries */
   unsigned next_index = 0;

   nir_foreach_instr_safe(instr, block) {
      if (handle_barrier(ctx, &progress, impl, instr))
         continue;

      /* gather information */
      if (instr->type != nir_instr_type_intrinsic)
         continue;
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

      const struct intrinsic_info *info = get_info(intrin->intrinsic);
      if (!info)
         continue;

      nir_variable_mode mode = info->mode;
      if (!mode)
         mode = nir_src_as_deref(intrin->src[info->deref_src])->modes;
      if (!(mode & aliasing_modes(ctx->options->modes)))
         continue;
      unsigned mode_index = mode_to_index(mode);

      /* create entry */
      struct entry *entry = create_entry(ctx, info, intrin);
      entry->index = next_index++;

      list_addtail(&entry->head, &ctx->entries[mode_index]);

      /* add the entry to a hash table */

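      /* Entries whose keys compare equal end up in the same util_dynarray;
       * vectorize_entries() later sorts that array by offset and scans it
       * for mergeable pairs.
       */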
      struct hash_table *adj_ht = NULL;
      if (entry->is_store) {
         if (!ctx->stores[mode_index])
            ctx->stores[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
         adj_ht = ctx->stores[mode_index];
      } else {
         if (!ctx->loads[mode_index])
            ctx->loads[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
         adj_ht = ctx->loads[mode_index];
      }

      uint32_t key_hash = hash_entry_key(entry->key);
      struct hash_entry *adj_entry = _mesa_hash_table_search_pre_hashed(adj_ht, key_hash, entry->key);
      struct util_dynarray *arr;
      if (adj_entry && adj_entry->data) {
         arr = (struct util_dynarray *)adj_entry->data;
      } else {
         arr = ralloc(ctx, struct util_dynarray);
         util_dynarray_init(arr, arr);
         _mesa_hash_table_insert_pre_hashed(adj_ht, key_hash, entry->key, arr);
      }
      util_dynarray_append(arr, struct entry *, entry);
   }

   /* sort and combine entries */
   for (unsigned i = 0; i < nir_num_variable_modes; i++) {
      progress |= vectorize_entries(ctx, impl, ctx->loads[i]);
      progress |= vectorize_entries(ctx, impl, ctx->stores[i]);
   }

   return progress;
}

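/* Entry point of the pass: walks every block of every function, runs
 * process_block() on it and frees the per-pass context. Returns true if any
 * instruction was changed.
 */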
bool
nir_opt_load_store_vectorize(nir_shader *shader, const nir_load_store_vectorize_options *options)
{
   bool progress = false;

   struct vectorize_ctx *ctx = rzalloc(NULL, struct vectorize_ctx);
   ctx->shader = shader;
   ctx->options = options;

   nir_shader_index_vars(shader, options->modes);

   nir_foreach_function(function, shader) {
      if (function->impl) {
         if (options->modes & nir_var_function_temp)
            nir_function_impl_index_vars(function->impl);

         nir_foreach_block(block, function->impl)
            progress |= process_block(function->impl, ctx, block);

         nir_metadata_preserve(function->impl,
                               nir_metadata_block_index |
                               nir_metadata_dominance |
                               nir_metadata_live_ssa_defs);
      }
   }

   ralloc_free(ctx);
   return progress;
}