1 /*
2 * Copyright © 2019 Valve Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 /**
25 * Although it's called a load/store "vectorization" pass, this also combines
26 * intersecting and identical loads/stores. It currently supports derefs, ubo,
27 * ssbo, push constant, shared, global and task payload loads/stores.
28 *
29 * This doesn't handle copy_deref intrinsics and assumes that
30 * nir_lower_alu_to_scalar() has been called and that the IR is free from ALU
31 * modifiers. It also assumes that derefs have explicitly laid out types.
32 *
33 * After vectorization, the backend may want to call nir_lower_alu_to_scalar()
34 * and nir_lower_pack(). This pass also creates cast instructions that take derefs as a
35 * source, which some parts of NIR may not be able to handle well.
36 *
37 * There are a few situations where this doesn't vectorize as well as it could:
38 * - It won't turn four consecutive vec3 loads into 3 vec4 loads.
39 * - It doesn't do global vectorization.
40 * Handling these cases probably wouldn't provide much benefit though.
41 *
42 * This probably doesn't handle big-endian GPUs correctly.
43 */
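/*
 * Rough usage sketch (illustrative only): the callback body, its name and the
 * option values below are made-up examples, but the callback parameters mirror
 * how options->callback is invoked by this pass, and the option fields used
 * (modes, callback, cb_data, robust_modes) are the ones this pass reads.
 *
 *    static bool
 *    mem_vectorize_cb(unsigned align_mul, unsigned align_offset,
 *                     unsigned bit_size, unsigned num_components,
 *                     nir_intrinsic_instr *low, nir_intrinsic_instr *high,
 *                     void *cb_data)
 *    {
 *       // allow the combined access if it fits in a naturally aligned
 *       // 128-bit slot
 *       unsigned byte_size = bit_size / 8u * num_components;
 *       return byte_size <= 16 && align_mul % byte_size == 0 &&
 *              align_offset % byte_size == 0;
 *    }
 *
 *    const nir_load_store_vectorize_options opts = {
 *       .modes = nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared,
 *       .callback = mem_vectorize_cb,
 *       .cb_data = NULL,
 *       .robust_modes = 0,
 *    };
 *    progress |= nir_opt_load_store_vectorize(shader, &opts);
 */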
44
45 #include "nir.h"
46 #include "nir_deref.h"
47 #include "nir_builder.h"
48 #include "nir_worklist.h"
49 #include "util/u_dynarray.h"
50
51 #include <stdlib.h>
52
53 struct intrinsic_info {
54 nir_variable_mode mode; /* 0 if the mode is obtained from the deref. */
55 nir_intrinsic_op op;
56 bool is_atomic;
57 /* Indices into nir_intrinsic::src[] or -1 if not applicable. */
58 int resource_src; /* resource (e.g. from vulkan_resource_index) */
59 int base_src; /* offset it loads from or stores to */
60 int deref_src; /* deref it loads from or stores to */
61 int value_src; /* the data it is storing */
62 };
63
64 static const struct intrinsic_info *
65 get_info(nir_intrinsic_op op) {
66 switch (op) {
67 #define INFO(mode, op, atomic, res, base, deref, val) \
68 case nir_intrinsic_##op: {\
69 static const struct intrinsic_info op##_info = {mode, nir_intrinsic_##op, atomic, res, base, deref, val};\
70 return &op##_info;\
71 }
72 #define LOAD(mode, op, res, base, deref) INFO(mode, load_##op, false, res, base, deref, -1)
73 #define STORE(mode, op, res, base, deref, val) INFO(mode, store_##op, false, res, base, deref, val)
74 #define ATOMIC(mode, type, op, res, base, deref, val) INFO(mode, type##_atomic_##op, true, res, base, deref, val)
75 LOAD(nir_var_mem_push_const, push_constant, -1, 0, -1)
76 LOAD(nir_var_mem_ubo, ubo, 0, 1, -1)
77 LOAD(nir_var_mem_ssbo, ssbo, 0, 1, -1)
78 STORE(nir_var_mem_ssbo, ssbo, 1, 2, -1, 0)
79 LOAD(0, deref, -1, -1, 0)
80 STORE(0, deref, -1, -1, 0, 1)
81 LOAD(nir_var_mem_shared, shared, -1, 0, -1)
82 STORE(nir_var_mem_shared, shared, -1, 1, -1, 0)
83 LOAD(nir_var_mem_global, global, -1, 0, -1)
84 STORE(nir_var_mem_global, global, -1, 1, -1, 0)
85 LOAD(nir_var_mem_task_payload, task_payload, -1, 0, -1)
86 STORE(nir_var_mem_task_payload, task_payload, -1, 1, -1, 0)
87 ATOMIC(nir_var_mem_ssbo, ssbo, add, 0, 1, -1, 2)
88 ATOMIC(nir_var_mem_ssbo, ssbo, imin, 0, 1, -1, 2)
89 ATOMIC(nir_var_mem_ssbo, ssbo, umin, 0, 1, -1, 2)
90 ATOMIC(nir_var_mem_ssbo, ssbo, imax, 0, 1, -1, 2)
91 ATOMIC(nir_var_mem_ssbo, ssbo, umax, 0, 1, -1, 2)
92 ATOMIC(nir_var_mem_ssbo, ssbo, and, 0, 1, -1, 2)
93 ATOMIC(nir_var_mem_ssbo, ssbo, or, 0, 1, -1, 2)
94 ATOMIC(nir_var_mem_ssbo, ssbo, xor, 0, 1, -1, 2)
95 ATOMIC(nir_var_mem_ssbo, ssbo, exchange, 0, 1, -1, 2)
96 ATOMIC(nir_var_mem_ssbo, ssbo, comp_swap, 0, 1, -1, 2)
97 ATOMIC(nir_var_mem_ssbo, ssbo, fadd, 0, 1, -1, 2)
98 ATOMIC(nir_var_mem_ssbo, ssbo, fmin, 0, 1, -1, 2)
99 ATOMIC(nir_var_mem_ssbo, ssbo, fmax, 0, 1, -1, 2)
100 ATOMIC(nir_var_mem_ssbo, ssbo, fcomp_swap, 0, 1, -1, 2)
101 ATOMIC(0, deref, add, -1, -1, 0, 1)
102 ATOMIC(0, deref, imin, -1, -1, 0, 1)
103 ATOMIC(0, deref, umin, -1, -1, 0, 1)
104 ATOMIC(0, deref, imax, -1, -1, 0, 1)
105 ATOMIC(0, deref, umax, -1, -1, 0, 1)
106 ATOMIC(0, deref, and, -1, -1, 0, 1)
107 ATOMIC(0, deref, or, -1, -1, 0, 1)
108 ATOMIC(0, deref, xor, -1, -1, 0, 1)
109 ATOMIC(0, deref, exchange, -1, -1, 0, 1)
110 ATOMIC(0, deref, comp_swap, -1, -1, 0, 1)
111 ATOMIC(0, deref, fadd, -1, -1, 0, 1)
112 ATOMIC(0, deref, fmin, -1, -1, 0, 1)
113 ATOMIC(0, deref, fmax, -1, -1, 0, 1)
114 ATOMIC(0, deref, fcomp_swap, -1, -1, 0, 1)
115 ATOMIC(nir_var_mem_shared, shared, add, -1, 0, -1, 1)
116 ATOMIC(nir_var_mem_shared, shared, imin, -1, 0, -1, 1)
117 ATOMIC(nir_var_mem_shared, shared, umin, -1, 0, -1, 1)
118 ATOMIC(nir_var_mem_shared, shared, imax, -1, 0, -1, 1)
119 ATOMIC(nir_var_mem_shared, shared, umax, -1, 0, -1, 1)
120 ATOMIC(nir_var_mem_shared, shared, and, -1, 0, -1, 1)
121 ATOMIC(nir_var_mem_shared, shared, or, -1, 0, -1, 1)
122 ATOMIC(nir_var_mem_shared, shared, xor, -1, 0, -1, 1)
123 ATOMIC(nir_var_mem_shared, shared, exchange, -1, 0, -1, 1)
124 ATOMIC(nir_var_mem_shared, shared, comp_swap, -1, 0, -1, 1)
125 ATOMIC(nir_var_mem_shared, shared, fadd, -1, 0, -1, 1)
126 ATOMIC(nir_var_mem_shared, shared, fmin, -1, 0, -1, 1)
127 ATOMIC(nir_var_mem_shared, shared, fmax, -1, 0, -1, 1)
128 ATOMIC(nir_var_mem_shared, shared, fcomp_swap, -1, 0, -1, 1)
129 ATOMIC(nir_var_mem_global, global, add, -1, 0, -1, 1)
130 ATOMIC(nir_var_mem_global, global, imin, -1, 0, -1, 1)
131 ATOMIC(nir_var_mem_global, global, umin, -1, 0, -1, 1)
132 ATOMIC(nir_var_mem_global, global, imax, -1, 0, -1, 1)
133 ATOMIC(nir_var_mem_global, global, umax, -1, 0, -1, 1)
134 ATOMIC(nir_var_mem_global, global, and, -1, 0, -1, 1)
135 ATOMIC(nir_var_mem_global, global, or, -1, 0, -1, 1)
136 ATOMIC(nir_var_mem_global, global, xor, -1, 0, -1, 1)
137 ATOMIC(nir_var_mem_global, global, exchange, -1, 0, -1, 1)
138 ATOMIC(nir_var_mem_global, global, comp_swap, -1, 0, -1, 1)
139 ATOMIC(nir_var_mem_global, global, fadd, -1, 0, -1, 1)
140 ATOMIC(nir_var_mem_global, global, fmin, -1, 0, -1, 1)
141 ATOMIC(nir_var_mem_global, global, fmax, -1, 0, -1, 1)
142 ATOMIC(nir_var_mem_global, global, fcomp_swap, -1, 0, -1, 1)
143 ATOMIC(nir_var_mem_task_payload, task_payload, add, -1, 0, -1, 1)
144 ATOMIC(nir_var_mem_task_payload, task_payload, imin, -1, 0, -1, 1)
145 ATOMIC(nir_var_mem_task_payload, task_payload, umin, -1, 0, -1, 1)
146 ATOMIC(nir_var_mem_task_payload, task_payload, imax, -1, 0, -1, 1)
147 ATOMIC(nir_var_mem_task_payload, task_payload, umax, -1, 0, -1, 1)
148 ATOMIC(nir_var_mem_task_payload, task_payload, and, -1, 0, -1, 1)
149 ATOMIC(nir_var_mem_task_payload, task_payload, or, -1, 0, -1, 1)
150 ATOMIC(nir_var_mem_task_payload, task_payload, xor, -1, 0, -1, 1)
151 ATOMIC(nir_var_mem_task_payload, task_payload, exchange, -1, 0, -1, 1)
152 ATOMIC(nir_var_mem_task_payload, task_payload, comp_swap, -1, 0, -1, 1)
153 ATOMIC(nir_var_mem_task_payload, task_payload, fadd, -1, 0, -1, 1)
154 ATOMIC(nir_var_mem_task_payload, task_payload, fmin, -1, 0, -1, 1)
155 ATOMIC(nir_var_mem_task_payload, task_payload, fmax, -1, 0, -1, 1)
156 ATOMIC(nir_var_mem_task_payload, task_payload, fcomp_swap, -1, 0, -1, 1)
157 default:
158 break;
159 #undef ATOMIC
160 #undef STORE
161 #undef LOAD
162 #undef INFO
163 }
164 return NULL;
165 }
166
167 /*
168 * Information used to compare memory operations.
169 * It canonically represents an offset as:
170 * `offset_defs[0]*offset_defs_mul[0] + offset_defs[1]*offset_defs_mul[1] + ...`
171 * "offset_defs" is sorted in descending order by the ssa definition's index.
172 * "resource" or "var" may be NULL.
173 */
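/* As a hypothetical example, an access at offset "i*16 + j*4 + 8" would be
 * keyed with offset_defs = {i, j} and the matching offset_defs_mul = {16, 4};
 * the constant 8 is not part of the key and instead ends up in the entry's
 * "offset" field. */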
174 struct entry_key {
175 nir_ssa_def *resource;
176 nir_variable *var;
177 unsigned offset_def_count;
178 nir_ssa_scalar *offset_defs;
179 uint64_t *offset_defs_mul;
180 };
181
182 /* Information on a single memory operation. */
183 struct entry {
184 struct list_head head;
185 unsigned index;
186
187 struct entry_key *key;
188 union {
189 uint64_t offset; /* sign-extended */
190 int64_t offset_signed;
191 };
192 uint32_t align_mul;
193 uint32_t align_offset;
194
195 nir_instr *instr;
196 nir_intrinsic_instr *intrin;
197 const struct intrinsic_info *info;
198 enum gl_access_qualifier access;
199 bool is_store;
200
201 nir_deref_instr *deref;
202 };
203
204 struct vectorize_ctx {
205 nir_shader *shader;
206 const nir_load_store_vectorize_options *options;
207 struct list_head entries[nir_num_variable_modes];
208 struct hash_table *loads[nir_num_variable_modes];
209 struct hash_table *stores[nir_num_variable_modes];
210 };
211
212 static uint32_t hash_entry_key(const void *key_)
213 {
214 /* this is careful to not include pointers in the hash calculation so that
215 * the order of the hash table walk is deterministic */
216 struct entry_key *key = (struct entry_key*)key_;
217
218 uint32_t hash = 0;
219 if (key->resource)
220 hash = XXH32(&key->resource->index, sizeof(key->resource->index), hash);
221 if (key->var) {
222 hash = XXH32(&key->var->index, sizeof(key->var->index), hash);
223 unsigned mode = key->var->data.mode;
224 hash = XXH32(&mode, sizeof(mode), hash);
225 }
226
227 for (unsigned i = 0; i < key->offset_def_count; i++) {
228 hash = XXH32(&key->offset_defs[i].def->index, sizeof(key->offset_defs[i].def->index), hash);
229 hash = XXH32(&key->offset_defs[i].comp, sizeof(key->offset_defs[i].comp), hash);
230 }
231
232 hash = XXH32(key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t), hash);
233
234 return hash;
235 }
236
237 static bool entry_key_equals(const void *a_, const void *b_)
238 {
239 struct entry_key *a = (struct entry_key*)a_;
240 struct entry_key *b = (struct entry_key*)b_;
241
242 if (a->var != b->var || a->resource != b->resource)
243 return false;
244
245 if (a->offset_def_count != b->offset_def_count)
246 return false;
247
248 for (unsigned i = 0; i < a->offset_def_count; i++) {
249 if (a->offset_defs[i].def != b->offset_defs[i].def ||
250 a->offset_defs[i].comp != b->offset_defs[i].comp)
251 return false;
252 }
253
254 size_t offset_def_mul_size = a->offset_def_count * sizeof(uint64_t);
255 if (a->offset_def_count &&
256 memcmp(a->offset_defs_mul, b->offset_defs_mul, offset_def_mul_size))
257 return false;
258
259 return true;
260 }
261
262 static void delete_entry_dynarray(struct hash_entry *entry)
263 {
264 struct util_dynarray *arr = (struct util_dynarray *)entry->data;
265 ralloc_free(arr);
266 }
267
268 static int sort_entries(const void *a_, const void *b_)
269 {
270 struct entry *a = *(struct entry*const*)a_;
271 struct entry *b = *(struct entry*const*)b_;
272
273 if (a->offset_signed > b->offset_signed)
274 return 1;
275 else if (a->offset_signed < b->offset_signed)
276 return -1;
277 else
278 return 0;
279 }
280
281 static unsigned
282 get_bit_size(struct entry *entry)
283 {
284 unsigned size = entry->is_store ?
285 entry->intrin->src[entry->info->value_src].ssa->bit_size :
286 entry->intrin->dest.ssa.bit_size;
287 return size == 1 ? 32u : size;
288 }
289
290 /* If "def" is from an alu instruction with the opcode "op" and one of its
291 * sources is a constant, update "def" to be the non-constant source, fill "c"
292 * with the constant and return true. */
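/* For example, with op == nir_op_imul and *def == (a * 16), this leaves
 * *def == a and *c == 16. For nir_op_ishl only a constant in src1 is
 * accepted, since that shift is not commutative. */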
293 static bool
294 parse_alu(nir_ssa_scalar *def, nir_op op, uint64_t *c)
295 {
296 if (!nir_ssa_scalar_is_alu(*def) || nir_ssa_scalar_alu_op(*def) != op)
297 return false;
298
299 nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(*def, 0);
300 nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(*def, 1);
301 if (op != nir_op_ishl && nir_ssa_scalar_is_const(src0)) {
302 *c = nir_ssa_scalar_as_uint(src0);
303 *def = src1;
304 } else if (nir_ssa_scalar_is_const(src1)) {
305 *c = nir_ssa_scalar_as_uint(src1);
306 *def = src0;
307 } else {
308 return false;
309 }
310 return true;
311 }
312
313 /* Parses an offset expression such as "a * 16 + 4" and "(a * 16 + 4) * 64 + 32". */
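/* For "a * 16 + 4" this leaves *base pointing at "a" with *base_mul == 16 and
 * *offset == 4; for "(a * 16 + 4) * 64 + 32" it produces *base_mul == 1024 and
 * *offset == 288. */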
314 static void
315 parse_offset(nir_ssa_scalar *base, uint64_t *base_mul, uint64_t *offset)
316 {
317 if (nir_ssa_scalar_is_const(*base)) {
318 *offset = nir_ssa_scalar_as_uint(*base);
319 base->def = NULL;
320 return;
321 }
322
323 uint64_t mul = 1;
324 uint64_t add = 0;
325 bool progress = false;
326 do {
327 uint64_t mul2 = 1, add2 = 0;
328
329 progress = parse_alu(base, nir_op_imul, &mul2);
330 mul *= mul2;
331
332 mul2 = 0;
333 progress |= parse_alu(base, nir_op_ishl, &mul2);
334 mul <<= mul2;
335
336 progress |= parse_alu(base, nir_op_iadd, &add2);
337 add += add2 * mul;
338
339 if (nir_ssa_scalar_is_alu(*base) && nir_ssa_scalar_alu_op(*base) == nir_op_mov) {
340 *base = nir_ssa_scalar_chase_alu_src(*base, 0);
341 progress = true;
342 }
343 } while (progress);
344
345 if (base->def->parent_instr->type == nir_instr_type_intrinsic) {
346 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(base->def->parent_instr);
347 if (intrin->intrinsic == nir_intrinsic_load_vulkan_descriptor)
348 base->def = NULL;
349 }
350
351 *base_mul = mul;
352 *offset = add;
353 }
354
355 static unsigned
356 type_scalar_size_bytes(const struct glsl_type *type)
357 {
358 assert(glsl_type_is_vector_or_scalar(type) ||
359 glsl_type_is_matrix(type));
360 return glsl_type_is_boolean(type) ? 4u : glsl_get_bit_size(type) / 8u;
361 }
362
363 static unsigned
364 add_to_entry_key(nir_ssa_scalar *offset_defs, uint64_t *offset_defs_mul,
365 unsigned offset_def_count, nir_ssa_scalar def, uint64_t mul)
366 {
367 mul = util_mask_sign_extend(mul, def.def->bit_size);
368
369 for (unsigned i = 0; i <= offset_def_count; i++) {
370 if (i == offset_def_count || def.def->index > offset_defs[i].def->index) {
371 /* insert before i */
372 memmove(offset_defs + i + 1, offset_defs + i,
373 (offset_def_count - i) * sizeof(nir_ssa_scalar));
374 memmove(offset_defs_mul + i + 1, offset_defs_mul + i,
375 (offset_def_count - i) * sizeof(uint64_t));
376 offset_defs[i] = def;
377 offset_defs_mul[i] = mul;
378 return 1;
379 } else if (def.def == offset_defs[i].def &&
380 def.comp == offset_defs[i].comp) {
381 /* merge with offset_def at i */
382 offset_defs_mul[i] += mul;
383 return 0;
384 }
385 }
386 unreachable("Unreachable.");
387 return 0;
388 }
389
390 static struct entry_key *
391 create_entry_key_from_deref(void *mem_ctx,
392 struct vectorize_ctx *ctx,
393 nir_deref_path *path,
394 uint64_t *offset_base)
395 {
396 unsigned path_len = 0;
397 while (path->path[path_len])
398 path_len++;
399
400 nir_ssa_scalar offset_defs_stack[32];
401 uint64_t offset_defs_mul_stack[32];
402 nir_ssa_scalar *offset_defs = offset_defs_stack;
403 uint64_t *offset_defs_mul = offset_defs_mul_stack;
404 if (path_len > 32) {
405 offset_defs = malloc(path_len * sizeof(nir_ssa_scalar));
406 offset_defs_mul = malloc(path_len * sizeof(uint64_t));
407 }
408 unsigned offset_def_count = 0;
409
410 struct entry_key* key = ralloc(mem_ctx, struct entry_key);
411 key->resource = NULL;
412 key->var = NULL;
413 *offset_base = 0;
414
415 for (unsigned i = 0; i < path_len; i++) {
416 nir_deref_instr *parent = i ? path->path[i - 1] : NULL;
417 nir_deref_instr *deref = path->path[i];
418
419 switch (deref->deref_type) {
420 case nir_deref_type_var: {
421 assert(!parent);
422 key->var = deref->var;
423 break;
424 }
425 case nir_deref_type_array:
426 case nir_deref_type_ptr_as_array: {
427 assert(parent);
428 nir_ssa_def *index = deref->arr.index.ssa;
429 uint32_t stride = nir_deref_instr_array_stride(deref);
430
431 nir_ssa_scalar base = {.def=index, .comp=0};
432 uint64_t offset = 0, base_mul = 1;
433 parse_offset(&base, &base_mul, &offset);
434 offset = util_mask_sign_extend(offset, index->bit_size);
435
436 *offset_base += offset * stride;
437 if (base.def) {
438 offset_def_count += add_to_entry_key(offset_defs, offset_defs_mul,
439 offset_def_count,
440 base, base_mul * stride);
441 }
442 break;
443 }
444 case nir_deref_type_struct: {
445 assert(parent);
446 int offset = glsl_get_struct_field_offset(parent->type, deref->strct.index);
447 *offset_base += offset;
448 break;
449 }
450 case nir_deref_type_cast: {
451 if (!parent)
452 key->resource = deref->parent.ssa;
453 break;
454 }
455 default:
456 unreachable("Unhandled deref type");
457 }
458 }
459
460 key->offset_def_count = offset_def_count;
461 key->offset_defs = ralloc_array(mem_ctx, nir_ssa_scalar, offset_def_count);
462 key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, offset_def_count);
463 memcpy(key->offset_defs, offset_defs, offset_def_count * sizeof(nir_ssa_scalar));
464 memcpy(key->offset_defs_mul, offset_defs_mul, offset_def_count * sizeof(uint64_t));
465
466 if (offset_defs != offset_defs_stack)
467 free(offset_defs);
468 if (offset_defs_mul != offset_defs_mul_stack)
469 free(offset_defs_mul);
470
471 return key;
472 }
473
474 static unsigned
475 parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
476 nir_ssa_scalar base, uint64_t base_mul, uint64_t *offset)
477 {
478 uint64_t new_mul;
479 uint64_t new_offset;
480 parse_offset(&base, &new_mul, &new_offset);
481 *offset += new_offset * base_mul;
482
483 if (!base.def)
484 return 0;
485
486 base_mul *= new_mul;
487
488 assert(left >= 1);
489
490 if (left >= 2) {
491 if (nir_ssa_scalar_is_alu(base) && nir_ssa_scalar_alu_op(base) == nir_op_iadd) {
492 nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(base, 0);
493 nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(base, 1);
494 unsigned amount = parse_entry_key_from_offset(key, size, left - 1, src0, base_mul, offset);
495 amount += parse_entry_key_from_offset(key, size + amount, left - amount, src1, base_mul, offset);
496 return amount;
497 }
498 }
499
500 return add_to_entry_key(key->offset_defs, key->offset_defs_mul, size, base, base_mul);
501 }
502
503 static struct entry_key *
504 create_entry_key_from_offset(void *mem_ctx, nir_ssa_def *base, uint64_t base_mul, uint64_t *offset)
505 {
506 struct entry_key *key = ralloc(mem_ctx, struct entry_key);
507 key->resource = NULL;
508 key->var = NULL;
509 if (base) {
510 nir_ssa_scalar offset_defs[32];
511 uint64_t offset_defs_mul[32];
512 key->offset_defs = offset_defs;
513 key->offset_defs_mul = offset_defs_mul;
514
515 nir_ssa_scalar scalar = {.def=base, .comp=0};
516 key->offset_def_count = parse_entry_key_from_offset(key, 0, 32, scalar, base_mul, offset);
517
518 key->offset_defs = ralloc_array(mem_ctx, nir_ssa_scalar, key->offset_def_count);
519 key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, key->offset_def_count);
520 memcpy(key->offset_defs, offset_defs, key->offset_def_count * sizeof(nir_ssa_scalar));
521 memcpy(key->offset_defs_mul, offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
522 } else {
523 key->offset_def_count = 0;
524 key->offset_defs = NULL;
525 key->offset_defs_mul = NULL;
526 }
527 return key;
528 }
529
530 static nir_variable_mode
531 get_variable_mode(struct entry *entry)
532 {
533 if (entry->info->mode)
534 return entry->info->mode;
535 assert(entry->deref && util_bitcount(entry->deref->modes) == 1);
536 return entry->deref->modes;
537 }
538
539 static unsigned
540 mode_to_index(nir_variable_mode mode)
541 {
542 assert(util_bitcount(mode) == 1);
543
544 /* Globals and SSBOs should be tracked together */
545 if (mode == nir_var_mem_global)
546 mode = nir_var_mem_ssbo;
547
548 return ffs(mode) - 1;
549 }
550
551 static nir_variable_mode
552 aliasing_modes(nir_variable_mode modes)
553 {
554 /* Global and SSBO can alias */
555 if (modes & (nir_var_mem_ssbo | nir_var_mem_global))
556 modes |= nir_var_mem_ssbo | nir_var_mem_global;
557 return modes;
558 }
559
560 static void
561 calc_alignment(struct entry *entry)
562 {
563 uint32_t align_mul = 31;
564 for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
565 if (entry->key->offset_defs_mul[i])
566 align_mul = MIN2(align_mul, ffsll(entry->key->offset_defs_mul[i]));
567 }
568
569 entry->align_mul = 1u << (align_mul - 1);
570 bool has_align = nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
571 if (!has_align || entry->align_mul >= nir_intrinsic_align_mul(entry->intrin)) {
572 entry->align_offset = entry->offset % entry->align_mul;
573 } else {
574 entry->align_mul = nir_intrinsic_align_mul(entry->intrin);
575 entry->align_offset = nir_intrinsic_align_offset(entry->intrin);
576 }
577 }
578
579 static struct entry *
580 create_entry(struct vectorize_ctx *ctx,
581 const struct intrinsic_info *info,
582 nir_intrinsic_instr *intrin)
583 {
584 struct entry *entry = rzalloc(ctx, struct entry);
585 entry->intrin = intrin;
586 entry->instr = &intrin->instr;
587 entry->info = info;
588 entry->is_store = entry->info->value_src >= 0;
589
590 if (entry->info->deref_src >= 0) {
591 entry->deref = nir_src_as_deref(intrin->src[entry->info->deref_src]);
592 nir_deref_path path;
593 nir_deref_path_init(&path, entry->deref, NULL);
594 entry->key = create_entry_key_from_deref(entry, ctx, &path, &entry->offset);
595 nir_deref_path_finish(&path);
596 } else {
597 nir_ssa_def *base = entry->info->base_src >= 0 ?
598 intrin->src[entry->info->base_src].ssa : NULL;
599 uint64_t offset = 0;
600 if (nir_intrinsic_has_base(intrin))
601 offset += nir_intrinsic_base(intrin);
602 entry->key = create_entry_key_from_offset(entry, base, 1, &offset);
603 entry->offset = offset;
604
605 if (base)
606 entry->offset = util_mask_sign_extend(entry->offset, base->bit_size);
607 }
608
609 if (entry->info->resource_src >= 0)
610 entry->key->resource = intrin->src[entry->info->resource_src].ssa;
611
612 if (nir_intrinsic_has_access(intrin))
613 entry->access = nir_intrinsic_access(intrin);
614 else if (entry->key->var)
615 entry->access = entry->key->var->data.access;
616
617 if (nir_intrinsic_can_reorder(intrin))
618 entry->access |= ACCESS_CAN_REORDER;
619
620 uint32_t restrict_modes = nir_var_shader_in | nir_var_shader_out;
621 restrict_modes |= nir_var_shader_temp | nir_var_function_temp;
622 restrict_modes |= nir_var_uniform | nir_var_mem_push_const;
623 restrict_modes |= nir_var_system_value | nir_var_mem_shared;
624 restrict_modes |= nir_var_mem_task_payload;
625 if (get_variable_mode(entry) & restrict_modes)
626 entry->access |= ACCESS_RESTRICT;
627
628 calc_alignment(entry);
629
630 return entry;
631 }
632
633 static nir_deref_instr *
634 cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref_instr *deref)
635 {
636 if (glsl_get_components(deref->type) == num_components &&
637 type_scalar_size_bytes(deref->type)*8u == bit_size)
638 return deref;
639
640 enum glsl_base_type types[] = {
641 GLSL_TYPE_UINT8, GLSL_TYPE_UINT16, GLSL_TYPE_UINT, GLSL_TYPE_UINT64};
642 enum glsl_base_type base = types[ffs(bit_size / 8u) - 1u];
643 const struct glsl_type *type = glsl_vector_type(base, num_components);
644
645 if (deref->type == type)
646 return deref;
647
648 return nir_build_deref_cast(b, &deref->dest.ssa, deref->modes, type, 0);
649 }
650
651 /* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
652 * of "low" and "high". */
653 static bool
654 new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
655 struct entry *low, struct entry *high, unsigned size)
656 {
657 if (size % new_bit_size != 0)
658 return false;
659
660 unsigned new_num_components = size / new_bit_size;
661 if (!nir_num_components_valid(new_num_components))
662 return false;
663
664 unsigned high_offset = high->offset_signed - low->offset_signed;
665
666 /* check nir_extract_bits limitations */
667 unsigned common_bit_size = MIN2(get_bit_size(low), get_bit_size(high));
668 common_bit_size = MIN2(common_bit_size, new_bit_size);
669 if (high_offset > 0)
670 common_bit_size = MIN2(common_bit_size, (1u << (ffs(high_offset * 8) - 1)));
671 if (new_bit_size / common_bit_size > NIR_MAX_VEC_COMPONENTS)
672 return false;
673
674 if (!ctx->options->callback(low->align_mul,
675 low->align_offset,
676 new_bit_size, new_num_components,
677 low->intrin, high->intrin,
678 ctx->options->cb_data))
679 return false;
680
681 if (low->is_store) {
682 unsigned low_size = low->intrin->num_components * get_bit_size(low);
683 unsigned high_size = high->intrin->num_components * get_bit_size(high);
684
685 if (low_size % new_bit_size != 0)
686 return false;
687 if (high_size % new_bit_size != 0)
688 return false;
689
690 unsigned write_mask = nir_intrinsic_write_mask(low->intrin);
691 if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(low), new_bit_size))
692 return false;
693
694 write_mask = nir_intrinsic_write_mask(high->intrin);
695 if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(high), new_bit_size))
696 return false;
697 }
698
699 return true;
700 }
701
702 static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
703 {
704 /* avoid adding another deref to the path */
705 if (deref->deref_type == nir_deref_type_ptr_as_array &&
706 nir_src_is_const(deref->arr.index) &&
707 offset % nir_deref_instr_array_stride(deref) == 0) {
708 unsigned stride = nir_deref_instr_array_stride(deref);
709 nir_ssa_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
710 deref->dest.ssa.bit_size);
711 return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
712 }
713
714 if (deref->deref_type == nir_deref_type_array &&
715 nir_src_is_const(deref->arr.index)) {
716 nir_deref_instr *parent = nir_deref_instr_parent(deref);
717 unsigned stride = glsl_get_explicit_stride(parent->type);
718 if (offset % stride == 0)
719 return nir_build_deref_array_imm(
720 b, parent, nir_src_as_int(deref->arr.index) - offset / stride);
721 }
722
723
724 deref = nir_build_deref_cast(b, &deref->dest.ssa, deref->modes,
725 glsl_scalar_type(GLSL_TYPE_UINT8), 1);
726 return nir_build_deref_ptr_as_array(
727 b, deref, nir_imm_intN_t(b, -offset, deref->dest.ssa.bit_size));
728 }
729
730 static void
731 vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
732 struct entry *low, struct entry *high,
733 struct entry *first, struct entry *second,
734 unsigned new_bit_size, unsigned new_num_components,
735 unsigned high_start)
736 {
737 unsigned low_bit_size = get_bit_size(low);
738 unsigned high_bit_size = get_bit_size(high);
739 bool low_bool = low->intrin->dest.ssa.bit_size == 1;
740 bool high_bool = high->intrin->dest.ssa.bit_size == 1;
741 nir_ssa_def *data = &first->intrin->dest.ssa;
742
743 b->cursor = nir_after_instr(first->instr);
744
745 /* update the load's destination size and extract data for each of the original loads */
746 data->num_components = new_num_components;
747 data->bit_size = new_bit_size;
748
749 nir_ssa_def *low_def = nir_extract_bits(
750 b, &data, 1, 0, low->intrin->num_components, low_bit_size);
751 nir_ssa_def *high_def = nir_extract_bits(
752 b, &data, 1, high_start, high->intrin->num_components, high_bit_size);
753
754 /* convert booleans */
755 low_def = low_bool ? nir_i2b(b, low_def) : nir_mov(b, low_def);
756 high_def = high_bool ? nir_i2b(b, high_def) : nir_mov(b, high_def);
757
758 /* update uses */
759 if (first == low) {
760 nir_ssa_def_rewrite_uses_after(&low->intrin->dest.ssa, low_def,
761 high_def->parent_instr);
762 nir_ssa_def_rewrite_uses(&high->intrin->dest.ssa, high_def);
763 } else {
764 nir_ssa_def_rewrite_uses(&low->intrin->dest.ssa, low_def);
765 nir_ssa_def_rewrite_uses_after(&high->intrin->dest.ssa, high_def,
766 high_def->parent_instr);
767 }
768
769 /* update the intrinsic */
770 first->intrin->num_components = new_num_components;
771
772 const struct intrinsic_info *info = first->info;
773
774 /* update the offset */
775 if (first != low && info->base_src >= 0) {
776 /* let nir_opt_algebraic() remove this addition. subtracting 16 from
777 * expressions like "(i + 1) * 16" isn't much of an issue because
778 * nir_opt_algebraic() turns them into "i * 16 + 16" */
779 b->cursor = nir_before_instr(first->instr);
780
781 nir_ssa_def *new_base = first->intrin->src[info->base_src].ssa;
782 new_base = nir_iadd_imm(b, new_base, -(int)(high_start / 8u));
783
784 nir_instr_rewrite_src(first->instr, &first->intrin->src[info->base_src],
785 nir_src_for_ssa(new_base));
786 }
787
788 /* update the deref */
789 if (info->deref_src >= 0) {
790 b->cursor = nir_before_instr(first->instr);
791
792 nir_deref_instr *deref = nir_src_as_deref(first->intrin->src[info->deref_src]);
793 if (first != low && high_start != 0)
794 deref = subtract_deref(b, deref, high_start / 8u);
795 first->deref = cast_deref(b, new_num_components, new_bit_size, deref);
796
797 nir_instr_rewrite_src(first->instr, &first->intrin->src[info->deref_src],
798 nir_src_for_ssa(&first->deref->dest.ssa));
799 }
800
801 /* update align */
802 if (nir_intrinsic_has_range_base(first->intrin)) {
803 uint32_t low_base = nir_intrinsic_range_base(low->intrin);
804 uint32_t high_base = nir_intrinsic_range_base(high->intrin);
805 uint32_t low_end = low_base + nir_intrinsic_range(low->intrin);
806 uint32_t high_end = high_base + nir_intrinsic_range(high->intrin);
807
808 nir_intrinsic_set_range_base(first->intrin, low_base);
809 nir_intrinsic_set_range(first->intrin, MAX2(low_end, high_end) - low_base);
810 }
811
812 first->key = low->key;
813 first->offset = low->offset;
814
815 first->align_mul = low->align_mul;
816 first->align_offset = low->align_offset;
817
818 nir_instr_remove(second->instr);
819 }
820
821 static void
822 vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
823 struct entry *low, struct entry *high,
824 struct entry *first, struct entry *second,
825 unsigned new_bit_size, unsigned new_num_components,
826 unsigned high_start)
827 {
828 ASSERTED unsigned low_size = low->intrin->num_components * get_bit_size(low);
829 assert(low_size % new_bit_size == 0);
830
831 b->cursor = nir_before_instr(second->instr);
832
833 /* get new writemasks */
834 uint32_t low_write_mask = nir_intrinsic_write_mask(low->intrin);
835 uint32_t high_write_mask = nir_intrinsic_write_mask(high->intrin);
836 low_write_mask = nir_component_mask_reinterpret(low_write_mask,
837 get_bit_size(low),
838 new_bit_size);
839 high_write_mask = nir_component_mask_reinterpret(high_write_mask,
840 get_bit_size(high),
841 new_bit_size);
842 high_write_mask <<= high_start / new_bit_size;
843
844 uint32_t write_mask = low_write_mask | high_write_mask;
845
846 /* convert booleans */
847 nir_ssa_def *low_val = low->intrin->src[low->info->value_src].ssa;
848 nir_ssa_def *high_val = high->intrin->src[high->info->value_src].ssa;
849 low_val = low_val->bit_size == 1 ? nir_b2i(b, low_val, 32) : low_val;
850 high_val = high_val->bit_size == 1 ? nir_b2i(b, high_val, 32) : high_val;
851
852 /* combine the data */
853 nir_ssa_def *data_channels[NIR_MAX_VEC_COMPONENTS];
854 for (unsigned i = 0; i < new_num_components; i++) {
855 bool set_low = low_write_mask & (1 << i);
856 bool set_high = high_write_mask & (1 << i);
857
858 if (set_low && (!set_high || low == second)) {
859 unsigned offset = i * new_bit_size;
860 data_channels[i] = nir_extract_bits(b, &low_val, 1, offset, 1, new_bit_size);
861 } else if (set_high) {
862 assert(!set_low || high == second);
863 unsigned offset = i * new_bit_size - high_start;
864 data_channels[i] = nir_extract_bits(b, &high_val, 1, offset, 1, new_bit_size);
865 } else {
866 data_channels[i] = nir_ssa_undef(b, 1, new_bit_size);
867 }
868 }
869 nir_ssa_def *data = nir_vec(b, data_channels, new_num_components);
870
871 /* update the intrinsic */
872 nir_intrinsic_set_write_mask(second->intrin, write_mask);
873 second->intrin->num_components = data->num_components;
874
875 const struct intrinsic_info *info = second->info;
876 assert(info->value_src >= 0);
877 nir_instr_rewrite_src(second->instr, &second->intrin->src[info->value_src],
878 nir_src_for_ssa(data));
879
880 /* update the offset */
881 if (second != low && info->base_src >= 0)
882 nir_instr_rewrite_src(second->instr, &second->intrin->src[info->base_src],
883 low->intrin->src[info->base_src]);
884
885 /* update the deref */
886 if (info->deref_src >= 0) {
887 b->cursor = nir_before_instr(second->instr);
888 second->deref = cast_deref(b, new_num_components, new_bit_size,
889 nir_src_as_deref(low->intrin->src[info->deref_src]));
890 nir_instr_rewrite_src(second->instr, &second->intrin->src[info->deref_src],
891 nir_src_for_ssa(&second->deref->dest.ssa));
892 }
893
894 /* update base/align */
895 if (second != low && nir_intrinsic_has_base(second->intrin))
896 nir_intrinsic_set_base(second->intrin, nir_intrinsic_base(low->intrin));
897
898 second->key = low->key;
899 second->offset = low->offset;
900
901 second->align_mul = low->align_mul;
902 second->align_offset = low->align_offset;
903
904 list_del(&first->head);
905 nir_instr_remove(first->instr);
906 }
907
908 /* Returns true if it can prove that "a" and "b" point to different bindings
909 * and either one uses ACCESS_RESTRICT. */
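/* For example, two ssbo accesses whose resource indices resolve to different
 * (descriptor set, binding) pairs, where at least one of the bindings is
 * declared "restrict". */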
910 static bool
911 bindings_different_restrict(nir_shader *shader, struct entry *a, struct entry *b)
912 {
913 bool different_bindings = false;
914 nir_variable *a_var = NULL, *b_var = NULL;
915 if (a->key->resource && b->key->resource) {
916 nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
917 nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
918 if (!a_res.success || !b_res.success)
919 return false;
920
921 if (a_res.num_indices != b_res.num_indices ||
922 a_res.desc_set != b_res.desc_set ||
923 a_res.binding != b_res.binding)
924 different_bindings = true;
925
926 for (unsigned i = 0; i < a_res.num_indices; i++) {
927 if (nir_src_is_const(a_res.indices[i]) && nir_src_is_const(b_res.indices[i]) &&
928 nir_src_as_uint(a_res.indices[i]) != nir_src_as_uint(b_res.indices[i]))
929 different_bindings = true;
930 }
931
932 if (different_bindings) {
933 a_var = nir_get_binding_variable(shader, a_res);
934 b_var = nir_get_binding_variable(shader, b_res);
935 }
936 } else if (a->key->var && b->key->var) {
937 a_var = a->key->var;
938 b_var = b->key->var;
939 different_bindings = a_var != b_var;
940 } else if (!!a->key->resource != !!b->key->resource) {
941 /* comparing global and ssbo access */
942 different_bindings = true;
943
944 if (a->key->resource) {
945 nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
946 a_var = nir_get_binding_variable(shader, a_res);
947 }
948
949 if (b->key->resource) {
950 nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
951 b_var = nir_get_binding_variable(shader, b_res);
952 }
953 } else {
954 return false;
955 }
956
957 unsigned a_access = a->access | (a_var ? a_var->data.access : 0);
958 unsigned b_access = b->access | (b_var ? b_var->data.access : 0);
959
960 return different_bindings &&
961 ((a_access | b_access) & ACCESS_RESTRICT);
962 }
963
964 static int64_t
965 compare_entries(struct entry *a, struct entry *b)
966 {
967 if (!entry_key_equals(a->key, b->key))
968 return INT64_MAX;
969 return b->offset_signed - a->offset_signed;
970 }
971
972 static bool
973 may_alias(nir_shader *shader, struct entry *a, struct entry *b)
974 {
975 assert(mode_to_index(get_variable_mode(a)) ==
976 mode_to_index(get_variable_mode(b)));
977
978 if ((a->access | b->access) & ACCESS_CAN_REORDER)
979 return false;
980
981 /* if the resources/variables are definitively different and at least one of
982 * them has ACCESS_RESTRICT, we can assume they do not alias. */
983 if (bindings_different_restrict(shader, a, b))
984 return false;
985
986 /* we can't compare offsets if the resources/variables might be different */
987 if (a->key->var != b->key->var || a->key->resource != b->key->resource)
988 return true;
989
990 /* use adjacency information */
991 /* TODO: we can look closer at the entry keys */
992 int64_t diff = compare_entries(a, b);
993 if (diff != INT64_MAX) {
994 /* with atomics, intrin->num_components can be 0 */
995 if (diff < 0)
996 return llabs(diff) < MAX2(b->intrin->num_components, 1u) * (get_bit_size(b) / 8u);
997 else
998 return diff < MAX2(a->intrin->num_components, 1u) * (get_bit_size(a) / 8u);
999 }
1000
1001 /* TODO: we can use deref information */
1002
1003 return true;
1004 }
1005
1006 static bool
1007 check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
1008 {
1009 nir_variable_mode mode = get_variable_mode(first);
1010 if (mode & (nir_var_uniform | nir_var_system_value |
1011 nir_var_mem_push_const | nir_var_mem_ubo))
1012 return false;
1013
1014 unsigned mode_index = mode_to_index(mode);
1015 if (first->is_store) {
1016 /* find first entry that aliases "first" */
1017 list_for_each_entry_from(struct entry, next, first, &ctx->entries[mode_index], head) {
1018 if (next == first)
1019 continue;
1020 if (next == second)
1021 return false;
1022 if (may_alias(ctx->shader, first, next))
1023 return true;
1024 }
1025 } else {
1026 /* find previous store that aliases this load */
1027 list_for_each_entry_from_rev(struct entry, prev, second, &ctx->entries[mode_index], head) {
1028 if (prev == second)
1029 continue;
1030 if (prev == first)
1031 return false;
1032 if (prev->is_store && may_alias(ctx->shader, second, prev))
1033 return true;
1034 }
1035 }
1036
1037 return false;
1038 }
1039
1040 static uint64_t
1041 calc_gcd(uint64_t a, uint64_t b)
1042 {
1043 while (b != 0) {
1044 uint64_t tmp_a = a;
1045 a = b;
1046 b = tmp_a % b;
1047 }
1048 return a;
1049 }
1050
1051 static uint64_t
1052 round_down(uint64_t a, uint64_t b)
1053 {
1054 return a / b * b;
1055 }
1056
1057 static bool
1058 addition_wraps(uint64_t a, uint64_t b, unsigned bits)
1059 {
1060 uint64_t mask = BITFIELD64_MASK(bits);
1061 return ((a + b) & mask) < (a & mask);
1062 }
1063
1064 /* Return true if the addition of "low"'s offset and "high_offset" could wrap
1065 * around.
1066 *
1067 * This is to prevent a situation where the hardware considers the high load
1068 * out-of-bounds after vectorization if the low load is out-of-bounds, even if
1069 * the wrap-around from the addition could make the high load in-bounds.
1070 */
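/* Hypothetical example: with a 32-bit offset source, a low access whose
 * offset could be 0xfffffffc combined with high_offset == 8 would address
 * 0x00000004 after wrap-around, so the pair is rejected for robust modes. */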
1071 static bool
1072 check_for_robustness(struct vectorize_ctx *ctx, struct entry *low, uint64_t high_offset)
1073 {
1074 nir_variable_mode mode = get_variable_mode(low);
1075 if (!(mode & ctx->options->robust_modes))
1076 return false;
1077
1078 /* First, try to use alignment information in case the application provided some. If the addition
1079 * of the maximum offset of the low load and "high_offset" wraps around, we can't combine the low
1080 * and high loads.
1081 */
1082 uint64_t max_low = round_down(UINT64_MAX, low->align_mul) + low->align_offset;
1083 if (!addition_wraps(max_low, high_offset, 64))
1084 return false;
1085
1086 /* Without a base offset source we can't know addition_bits, so conservatively assume the addition can wrap. */
1087 if (low->info->base_src < 0)
1088 return true;
1089
1090 /* Second, use information about the factors from address calculation (offset_defs_mul). These
1091 * are not guaranteed to be power-of-2.
1092 */
1093 uint64_t stride = 0;
1094 for (unsigned i = 0; i < low->key->offset_def_count; i++)
1095 stride = calc_gcd(low->key->offset_defs_mul[i], stride);
1096
1097 unsigned addition_bits = low->intrin->src[low->info->base_src].ssa->bit_size;
1098 /* low's offset must be a multiple of "stride" plus "low->offset". */
1099 max_low = low->offset;
1100 if (stride)
1101 max_low = round_down(BITFIELD64_MASK(addition_bits), stride) + (low->offset % stride);
1102 return addition_wraps(max_low, high_offset, addition_bits);
1103 }
1104
1105 static bool
1106 is_strided_vector(const struct glsl_type *type)
1107 {
1108 if (glsl_type_is_vector(type)) {
1109 unsigned explicit_stride = glsl_get_explicit_stride(type);
1110 return explicit_stride != 0 && explicit_stride !=
1111 type_scalar_size_bytes(glsl_get_array_element(type));
1112 } else {
1113 return false;
1114 }
1115 }
1116
1117 static bool
1118 can_vectorize(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
1119 {
1120 if (!(get_variable_mode(first) & ctx->options->modes) ||
1121 !(get_variable_mode(second) & ctx->options->modes))
1122 return false;
1123
1124 if (check_for_aliasing(ctx, first, second))
1125 return false;
1126
1127 /* we can only vectorize non-volatile loads/stores of the same type and with
1128 * the same access */
1129 if (first->info != second->info || first->access != second->access ||
1130 (first->access & ACCESS_VOLATILE) || first->info->is_atomic)
1131 return false;
1132
1133 return true;
1134 }
1135
1136 static bool
1137 try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
1138 struct entry *low, struct entry *high,
1139 struct entry *first, struct entry *second)
1140 {
1141 if (!can_vectorize(ctx, first, second))
1142 return false;
1143
1144 uint64_t diff = high->offset_signed - low->offset_signed;
1145 if (check_for_robustness(ctx, low, diff))
1146 return false;
1147
1148 /* don't attempt to vectorize accesses of row-major matrix columns */
1149 if (first->deref) {
1150 const struct glsl_type *first_type = first->deref->type;
1151 const struct glsl_type *second_type = second->deref->type;
1152 if (is_strided_vector(first_type) || is_strided_vector(second_type))
1153 return false;
1154 }
1155
1156 /* gather information */
1157 unsigned low_bit_size = get_bit_size(low);
1158 unsigned high_bit_size = get_bit_size(high);
1159 unsigned low_size = low->intrin->num_components * low_bit_size;
1160 unsigned high_size = high->intrin->num_components * high_bit_size;
1161 unsigned new_size = MAX2(diff * 8u + high_size, low_size);
1162
1163 /* find a good bit size for the new load/store */
1164 unsigned new_bit_size = 0;
1165 if (new_bitsize_acceptable(ctx, low_bit_size, low, high, new_size)) {
1166 new_bit_size = low_bit_size;
1167 } else if (low_bit_size != high_bit_size &&
1168 new_bitsize_acceptable(ctx, high_bit_size, low, high, new_size)) {
1169 new_bit_size = high_bit_size;
1170 } else {
1171 new_bit_size = 64;
1172 for (; new_bit_size >= 8; new_bit_size /= 2) {
1173 /* don't repeat trying out bitsizes */
1174 if (new_bit_size == low_bit_size || new_bit_size == high_bit_size)
1175 continue;
1176 if (new_bitsize_acceptable(ctx, new_bit_size, low, high, new_size))
1177 break;
1178 }
1179 if (new_bit_size < 8)
1180 return false;
1181 }
1182 unsigned new_num_components = new_size / new_bit_size;
1183
1184 /* vectorize the loads/stores */
1185 nir_builder b;
1186 nir_builder_init(&b, impl);
1187
1188 if (first->is_store)
1189 vectorize_stores(&b, ctx, low, high, first, second,
1190 new_bit_size, new_num_components, diff * 8u);
1191 else
1192 vectorize_loads(&b, ctx, low, high, first, second,
1193 new_bit_size, new_num_components, diff * 8u);
1194
1195 return true;
1196 }
1197
1198 static bool
1199 try_vectorize_shared2(nir_function_impl *impl, struct vectorize_ctx *ctx,
1200 struct entry *low, struct entry *high,
1201 struct entry *first, struct entry *second)
1202 {
1203 if (!can_vectorize(ctx, first, second) || first->deref)
1204 return false;
1205
1206 unsigned low_bit_size = get_bit_size(low);
1207 unsigned high_bit_size = get_bit_size(high);
1208 unsigned low_size = low->intrin->num_components * low_bit_size / 8;
1209 unsigned high_size = high->intrin->num_components * high_bit_size / 8;
1210 if ((low_size != 4 && low_size != 8) || (high_size != 4 && high_size != 8))
1211 return false;
1212 if (low_size != high_size)
1213 return false;
1214 if (low->align_mul % low_size || low->align_offset % low_size)
1215 return false;
1216 if (high->align_mul % low_size || high->align_offset % low_size)
1217 return false;
1218
1219 uint64_t diff = high->offset_signed - low->offset_signed;
1220 bool st64 = diff % (64 * low_size) == 0;
1221 unsigned stride = st64 ? 64 * low_size : low_size;
1222 if (diff % stride || diff > 255 * stride)
1223 return false;
1224
1225 /* try to avoid creating accesses that later additions/offsets can't be folded into */
1226 if (high->offset > 255 * stride || (st64 && high->offset % stride))
1227 return false;
1228
1229 if (first->is_store) {
1230 if (nir_intrinsic_write_mask(low->intrin) != BITFIELD_MASK(low->intrin->num_components))
1231 return false;
1232 if (nir_intrinsic_write_mask(high->intrin) != BITFIELD_MASK(high->intrin->num_components))
1233 return false;
1234 }
1235
1236 /* vectorize the accesses */
1237 nir_builder b;
1238 nir_builder_init(&b, impl);
1239
1240 b.cursor = nir_after_instr(first->is_store ? second->instr : first->instr);
1241
1242 nir_ssa_def *offset = first->intrin->src[first->is_store].ssa;
1243 offset = nir_iadd_imm(&b, offset, nir_intrinsic_base(first->intrin));
1244 if (first != low)
1245 offset = nir_iadd_imm(&b, offset, -(int)diff);
1246
1247 if (first->is_store) {
1248 nir_ssa_def *low_val = low->intrin->src[low->info->value_src].ssa;
1249 nir_ssa_def *high_val = high->intrin->src[high->info->value_src].ssa;
1250 nir_ssa_def *val = nir_vec2(&b, nir_bitcast_vector(&b, low_val, low_size * 8u),
1251 nir_bitcast_vector(&b, high_val, low_size * 8u));
1252 nir_store_shared2_amd(&b, val, offset, .offset1=diff/stride, .st64=st64);
1253 } else {
1254 nir_ssa_def *new_def = nir_load_shared2_amd(&b, low_size * 8u, offset, .offset1=diff/stride,
1255 .st64=st64);
1256 nir_ssa_def_rewrite_uses(&low->intrin->dest.ssa,
1257 nir_bitcast_vector(&b, nir_channel(&b, new_def, 0), low_bit_size));
1258 nir_ssa_def_rewrite_uses(&high->intrin->dest.ssa,
1259 nir_bitcast_vector(&b, nir_channel(&b, new_def, 1), high_bit_size));
1260 }
1261
1262 nir_instr_remove(first->instr);
1263 nir_instr_remove(second->instr);
1264
1265 return true;
1266 }
1267
1268 static bool
1269 update_align(struct entry *entry)
1270 {
1271 if (nir_intrinsic_has_align_mul(entry->intrin) &&
1272 (entry->align_mul != nir_intrinsic_align_mul(entry->intrin) ||
1273 entry->align_offset != nir_intrinsic_align_offset(entry->intrin))) {
1274 nir_intrinsic_set_align(entry->intrin, entry->align_mul, entry->align_offset);
1275 return true;
1276 }
1277 return false;
1278 }
1279
1280 static bool
1281 vectorize_sorted_entries(struct vectorize_ctx *ctx, nir_function_impl *impl,
1282 struct util_dynarray *arr)
1283 {
1284 unsigned num_entries = util_dynarray_num_elements(arr, struct entry *);
1285
1286 bool progress = false;
1287 for (unsigned first_idx = 0; first_idx < num_entries; first_idx++) {
1288 struct entry *low = *util_dynarray_element(arr, struct entry *, first_idx);
1289 if (!low)
1290 continue;
1291
1292 for (unsigned second_idx = first_idx + 1; second_idx < num_entries; second_idx++) {
1293 struct entry *high = *util_dynarray_element(arr, struct entry *, second_idx);
1294 if (!high)
1295 continue;
1296
1297 struct entry *first = low->index < high->index ? low : high;
1298 struct entry *second = low->index < high->index ? high : low;
1299
1300 uint64_t diff = high->offset_signed - low->offset_signed;
1301 bool separate = diff > get_bit_size(low) / 8u * low->intrin->num_components;
1302 if (separate) {
1303 if (!ctx->options->has_shared2_amd ||
1304 get_variable_mode(first) != nir_var_mem_shared)
1305 break;
1306
1307 if (try_vectorize_shared2(impl, ctx, low, high, first, second)) {
1308 low = NULL;
1309 *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
1310 progress = true;
1311 break;
1312 }
1313 } else {
1314 if (try_vectorize(impl, ctx, low, high, first, second)) {
1315 low = low->is_store ? second : first;
1316 *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
1317 progress = true;
1318 }
1319 }
1320 }
1321
1322 *util_dynarray_element(arr, struct entry *, first_idx) = low;
1323 }
1324
1325 return progress;
1326 }
1327
1328 static bool
1329 vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct hash_table *ht)
1330 {
1331 if (!ht)
1332 return false;
1333
1334 bool progress = false;
1335 hash_table_foreach(ht, entry) {
1336 struct util_dynarray *arr = entry->data;
1337 if (!arr->size)
1338 continue;
1339
1340 qsort(util_dynarray_begin(arr),
1341 util_dynarray_num_elements(arr, struct entry *),
1342 sizeof(struct entry *), &sort_entries);
1343
1344 while (vectorize_sorted_entries(ctx, impl, arr))
1345 progress = true;
1346
1347 util_dynarray_foreach(arr, struct entry *, elem) {
1348 if (*elem)
1349 progress |= update_align(*elem);
1350 }
1351 }
1352
1353 _mesa_hash_table_clear(ht, delete_entry_dynarray);
1354
1355 return progress;
1356 }
1357
1358 static bool
1359 handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *impl, nir_instr *instr)
1360 {
1361 unsigned modes = 0;
1362 bool acquire = true;
1363 bool release = true;
1364 if (instr->type == nir_instr_type_intrinsic) {
1365 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1366 switch (intrin->intrinsic) {
1367 case nir_intrinsic_group_memory_barrier:
1368 case nir_intrinsic_memory_barrier:
1369 modes = nir_var_mem_ssbo | nir_var_mem_shared | nir_var_mem_global |
1370 nir_var_mem_task_payload;
1371 break;
1372 /* prevent speculative loads/stores */
1373 case nir_intrinsic_discard_if:
1374 case nir_intrinsic_discard:
1375 case nir_intrinsic_terminate_if:
1376 case nir_intrinsic_terminate:
1377 case nir_intrinsic_launch_mesh_workgroups:
1378 modes = nir_var_all;
1379 break;
1380 case nir_intrinsic_demote_if:
1381 case nir_intrinsic_demote:
1382 acquire = false;
1383 modes = nir_var_all;
1384 break;
1385 case nir_intrinsic_memory_barrier_buffer:
1386 modes = nir_var_mem_ssbo | nir_var_mem_global;
1387 break;
1388 case nir_intrinsic_memory_barrier_shared:
1389 modes = nir_var_mem_shared | nir_var_mem_task_payload;
1390 break;
1391 case nir_intrinsic_scoped_barrier:
1392 if (nir_intrinsic_memory_scope(intrin) == NIR_SCOPE_NONE)
1393 break;
1394
1395 modes = nir_intrinsic_memory_modes(intrin) & (nir_var_mem_ssbo |
1396 nir_var_mem_shared |
1397 nir_var_mem_global |
1398 nir_var_mem_task_payload);
1399 acquire = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_ACQUIRE;
1400 release = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_RELEASE;
1401 switch (nir_intrinsic_memory_scope(intrin)) {
1402 case NIR_SCOPE_INVOCATION:
1403 /* a barrier should never be required for correctness with this scope */
1404 modes = 0;
1405 break;
1406 default:
1407 break;
1408 }
1409 break;
1410 default:
1411 return false;
1412 }
1413 } else if (instr->type == nir_instr_type_call) {
1414 modes = nir_var_all;
1415 } else {
1416 return false;
1417 }
1418
1419 while (modes) {
1420 unsigned mode_index = u_bit_scan(&modes);
1421 if ((1 << mode_index) == nir_var_mem_global) {
1422 /* Global should be rolled in with SSBO */
1423 assert(list_is_empty(&ctx->entries[mode_index]));
1424 assert(ctx->loads[mode_index] == NULL);
1425 assert(ctx->stores[mode_index] == NULL);
1426 continue;
1427 }
1428
1429 if (acquire)
1430 *progress |= vectorize_entries(ctx, impl, ctx->loads[mode_index]);
1431 if (release)
1432 *progress |= vectorize_entries(ctx, impl, ctx->stores[mode_index]);
1433 }
1434
1435 return true;
1436 }
1437
1438 static bool
1439 process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *block)
1440 {
1441 bool progress = false;
1442
1443 for (unsigned i = 0; i < nir_num_variable_modes; i++) {
1444 list_inithead(&ctx->entries[i]);
1445 if (ctx->loads[i])
1446 _mesa_hash_table_clear(ctx->loads[i], delete_entry_dynarray);
1447 if (ctx->stores[i])
1448 _mesa_hash_table_clear(ctx->stores[i], delete_entry_dynarray);
1449 }
1450
1451 /* create entries */
1452 unsigned next_index = 0;
1453
1454 nir_foreach_instr_safe(instr, block) {
1455 if (handle_barrier(ctx, &progress, impl, instr))
1456 continue;
1457
1458 /* gather information */
1459 if (instr->type != nir_instr_type_intrinsic)
1460 continue;
1461 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1462
1463 const struct intrinsic_info *info = get_info(intrin->intrinsic);
1464 if (!info)
1465 continue;
1466
1467 nir_variable_mode mode = info->mode;
1468 if (!mode)
1469 mode = nir_src_as_deref(intrin->src[info->deref_src])->modes;
1470 if (!(mode & aliasing_modes(ctx->options->modes)))
1471 continue;
1472 unsigned mode_index = mode_to_index(mode);
1473
1474 /* create entry */
1475 struct entry *entry = create_entry(ctx, info, intrin);
1476 entry->index = next_index++;
1477
1478 list_addtail(&entry->head, &ctx->entries[mode_index]);
1479
1480 /* add the entry to a hash table */
1481
1482 struct hash_table *adj_ht = NULL;
1483 if (entry->is_store) {
1484 if (!ctx->stores[mode_index])
1485 ctx->stores[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
1486 adj_ht = ctx->stores[mode_index];
1487 } else {
1488 if (!ctx->loads[mode_index])
1489 ctx->loads[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
1490 adj_ht = ctx->loads[mode_index];
1491 }
1492
1493 uint32_t key_hash = hash_entry_key(entry->key);
1494 struct hash_entry *adj_entry = _mesa_hash_table_search_pre_hashed(adj_ht, key_hash, entry->key);
1495 struct util_dynarray *arr;
1496 if (adj_entry && adj_entry->data) {
1497 arr = (struct util_dynarray *)adj_entry->data;
1498 } else {
1499 arr = ralloc(ctx, struct util_dynarray);
1500 util_dynarray_init(arr, arr);
1501 _mesa_hash_table_insert_pre_hashed(adj_ht, key_hash, entry->key, arr);
1502 }
1503 util_dynarray_append(arr, struct entry *, entry);
1504 }
1505
1506 /* sort and combine entries */
1507 for (unsigned i = 0; i < nir_num_variable_modes; i++) {
1508 progress |= vectorize_entries(ctx, impl, ctx->loads[i]);
1509 progress |= vectorize_entries(ctx, impl, ctx->stores[i]);
1510 }
1511
1512 return progress;
1513 }
1514
1515 bool
1516 nir_opt_load_store_vectorize(nir_shader *shader, const nir_load_store_vectorize_options *options)
1517 {
1518 bool progress = false;
1519
1520 struct vectorize_ctx *ctx = rzalloc(NULL, struct vectorize_ctx);
1521 ctx->shader = shader;
1522 ctx->options = options;
1523
1524 nir_shader_index_vars(shader, options->modes);
1525
1526 nir_foreach_function(function, shader) {
1527 if (function->impl) {
1528 if (options->modes & nir_var_function_temp)
1529 nir_function_impl_index_vars(function->impl);
1530
1531 nir_foreach_block(block, function->impl)
1532 progress |= process_block(function->impl, ctx, block);
1533
1534 nir_metadata_preserve(function->impl,
1535 nir_metadata_block_index |
1536 nir_metadata_dominance |
1537 nir_metadata_live_ssa_defs);
1538 }
1539 }
1540
1541 ralloc_free(ctx);
1542 return progress;
1543 }
1544