1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "dxil_nir.h"
25 #include "dxil_module.h"
26 
27 #include "nir_builder.h"
28 #include "nir_deref.h"
29 #include "nir_worklist.h"
30 #include "nir_to_dxil.h"
31 #include "util/u_math.h"
32 #include "vulkan/vulkan_core.h"
33 
34 static void
35 cl_type_size_align(const struct glsl_type *type, unsigned *size,
36                    unsigned *align)
37 {
38    *size = glsl_get_cl_size(type);
39    *align = glsl_get_cl_alignment(type);
40 }
41 
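/* Repack an array of src_bit_size-wide components into a vector whose
 * components are dst_bit_size wide, splitting wide components or OR-merging
 * narrow ones as needed.
 */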
42 static nir_def *
43 load_comps_to_vec(nir_builder *b, unsigned src_bit_size,
44                   nir_def **src_comps, unsigned num_src_comps,
45                   unsigned dst_bit_size)
46 {
47    if (src_bit_size == dst_bit_size)
48       return nir_vec(b, src_comps, num_src_comps);
49    else if (src_bit_size > dst_bit_size)
50       return nir_extract_bits(b, src_comps, num_src_comps, 0, src_bit_size * num_src_comps / dst_bit_size, dst_bit_size);
51 
52    unsigned num_dst_comps = DIV_ROUND_UP(num_src_comps * src_bit_size, dst_bit_size);
53    unsigned comps_per_dst = dst_bit_size / src_bit_size;
54    nir_def *dst_comps[4];
55 
56    for (unsigned i = 0; i < num_dst_comps; i++) {
57       unsigned src_offs = i * comps_per_dst;
58 
59       dst_comps[i] = nir_u2uN(b, src_comps[src_offs], dst_bit_size);
60       for (unsigned j = 1; j < comps_per_dst && src_offs + j < num_src_comps; j++) {
61          nir_def *tmp = nir_ishl_imm(b, nir_u2uN(b, src_comps[src_offs + j], dst_bit_size),
62                                          j * src_bit_size);
63          dst_comps[i] = nir_ior(b, dst_comps[i], tmp);
64       }
65    }
66 
67    return nir_vec(b, dst_comps, num_dst_comps);
68 }
69 
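/* Lower a load_shared/load_scratch with a byte offset into 32-bit loads from
 * the backing i32 array variable, then repack the result into the original
 * bit size and component count.
 */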
70 static bool
71 lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
72 {
73    unsigned bit_size = intr->def.bit_size;
74    unsigned num_components = intr->def.num_components;
75    unsigned num_bits = num_components * bit_size;
76 
77    b->cursor = nir_before_instr(&intr->instr);
78 
79    nir_def *offset = intr->src[0].ssa;
80    if (intr->intrinsic == nir_intrinsic_load_shared)
81       offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
82    else
83       offset = nir_u2u32(b, offset);
84    nir_def *index = nir_ushr_imm(b, offset, 2);
85    nir_def *comps[NIR_MAX_VEC_COMPONENTS];
86    nir_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];
87 
88    /* We need to split loads into 32-bit accesses because the buffer
89     * is an i32 array and DXIL does not support type casts.
90     */
91    unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
92    for (unsigned i = 0; i < num_32bit_comps; i++)
93       comps_32bit[i] = nir_load_array_var(b, var, nir_iadd_imm(b, index, i));
94    unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);
95 
96    for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
97       unsigned num_vec32_comps = MIN2(num_32bit_comps - i, 4);
98       unsigned num_dest_comps = num_vec32_comps * 32 / bit_size;
99       nir_def *vec32 = nir_vec(b, &comps_32bit[i], num_vec32_comps);
100 
101       /* If we have 16 bits or less to load we need to adjust the u32 value so
102        * we can always extract the LSB.
103        */
104       if (num_bits <= 16) {
105          nir_def *shift =
106             nir_imul_imm(b, nir_iand_imm(b, offset, 3), 8);
107          vec32 = nir_ushr(b, vec32, shift);
108       }
109 
110       /* And now comes the pack/unpack step to match the original type. */
111       unsigned dest_index = i * 32 / bit_size;
112       nir_def *temp_vec = nir_extract_bits(b, &vec32, 1, 0, num_dest_comps, bit_size);
113       for (unsigned comp = 0; comp < num_dest_comps; ++comp, ++dest_index)
114          comps[dest_index] = nir_channel(b, temp_vec, comp);
115    }
116 
117    nir_def *result = nir_vec(b, comps, num_components);
118    nir_def_rewrite_uses(&intr->def, result);
119    nir_instr_remove(&intr->instr);
120 
121    return true;
122 }
123 
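/* Write a value narrower than 32 bits into an i32 array element while
 * preserving the untouched bits: shared memory uses an atomic AND/OR pair,
 * scratch uses a plain read-modify-write.
 */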
124 static void
125 lower_masked_store_vec32(nir_builder *b, nir_def *offset, nir_def *index,
126                          nir_def *vec32, unsigned num_bits, nir_variable *var, unsigned alignment)
127 {
128    nir_def *mask = nir_imm_int(b, (1 << num_bits) - 1);
129 
130    /* If the alignment is small, we need to shift the value and mask into the right position within the u32 component. */
131    if (alignment <= 2) {
132       nir_def *shift =
133          nir_imul_imm(b, nir_iand_imm(b, offset, 3), 8);
134 
135       vec32 = nir_ishl(b, vec32, shift);
136       mask = nir_ishl(b, mask, shift);
137    }
138 
139    if (var->data.mode == nir_var_mem_shared) {
140       /* Shared memory needs atomics: clear the affected bits, then OR in the new value */
141       nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
142       nir_deref_atomic(b, 32, &deref->def, nir_inot(b, mask), .atomic_op = nir_atomic_op_iand);
143       nir_deref_atomic(b, 32, &deref->def, vec32, .atomic_op = nir_atomic_op_ior);
144    } else {
145       /* For scratch, since we don't need atomics, just generate the read-modify-write in NIR */
146       nir_def *load = nir_load_array_var(b, var, index);
147 
148       nir_def *new_val = nir_ior(b, vec32,
149                                      nir_iand(b,
150                                               nir_inot(b, mask),
151                                               load));
152 
153       nir_store_array_var(b, var, index, new_val, 1);
154    }
155 }
156 
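/* Lower a store_shared/store_scratch with a byte offset into 32-bit scalar
 * stores to the backing i32 array variable, falling back to masked stores for
 * sub-32-bit chunks.
 */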
157 static bool
158 lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
159 {
160    unsigned num_components = nir_src_num_components(intr->src[0]);
161    unsigned bit_size = nir_src_bit_size(intr->src[0]);
162    unsigned num_bits = num_components * bit_size;
163 
164    b->cursor = nir_before_instr(&intr->instr);
165 
166    nir_def *offset = intr->src[1].ssa;
167    if (intr->intrinsic == nir_intrinsic_store_shared)
168       offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
169    else
170       offset = nir_u2u32(b, offset);
171    nir_def *comps[NIR_MAX_VEC_COMPONENTS];
172 
173    unsigned comp_idx = 0;
174    for (unsigned i = 0; i < num_components; i++)
175       comps[i] = nir_channel(b, intr->src[0].ssa, i);
176 
177    unsigned step = MAX2(bit_size, 32);
178    for (unsigned i = 0; i < num_bits; i += step) {
179       /* For each 4-byte chunk (or smaller) we generate a 32-bit scalar store. */
180       unsigned substore_num_bits = MIN2(num_bits - i, step);
181       nir_def *local_offset = nir_iadd_imm(b, offset, i / 8);
182       nir_def *vec32 = load_comps_to_vec(b, bit_size, &comps[comp_idx],
183                                              substore_num_bits / bit_size, 32);
184       nir_def *index = nir_ushr_imm(b, local_offset, 2);
185 
186       /* For anything less than 32 bits we need to use the masked version of the
187        * intrinsic to preserve data living in the same 32-bit slot. */
188       if (substore_num_bits < 32) {
189          lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, var, nir_intrinsic_align(intr));
190       } else {
191          for (unsigned i = 0; i < vec32->num_components; ++i)
192             nir_store_array_var(b, var, nir_iadd_imm(b, index, i), nir_channel(b, vec32, i), 1);
193       }
194 
195       comp_idx += substore_num_bits / bit_size;
196    }
197 
198    nir_instr_remove(&intr->instr);
199 
200    return true;
201 }
202 
203 #define CONSTANT_LOCATION_UNVISITED 0
204 #define CONSTANT_LOCATION_VALID 1
205 #define CONSTANT_LOCATION_INVALID 2
206 
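/* Turn constant-initialized nir_var_mem_constant variables into
 * nir_var_shader_temp variables and patch the derefs that reference them.
 */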
207 bool
208 dxil_nir_lower_constant_to_temp(nir_shader *nir)
209 {
210    bool progress = false;
211    nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant)
212       var->data.location = var->constant_initializer ?
213          CONSTANT_LOCATION_UNVISITED : CONSTANT_LOCATION_INVALID;
214 
215    /* First pass: collect all UBO accesses that could be turned into
216     * shader temp accesses.
217     */
218    nir_foreach_function(func, nir) {
219       if (!func->is_entrypoint)
220          continue;
221       assert(func->impl);
222 
223       nir_foreach_block(block, func->impl) {
224          nir_foreach_instr_safe(instr, block) {
225             if (instr->type != nir_instr_type_deref)
226                continue;
227 
228             nir_deref_instr *deref = nir_instr_as_deref(instr);
229             if (!nir_deref_mode_is(deref, nir_var_mem_constant) ||
230                 deref->deref_type != nir_deref_type_var ||
231                 deref->var->data.location == CONSTANT_LOCATION_INVALID)
232                continue;
233 
234             deref->var->data.location = nir_deref_instr_has_complex_use(deref, 0) ?
235                CONSTANT_LOCATION_INVALID : CONSTANT_LOCATION_VALID;
236          }
237       }
238    }
239 
240    nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant) {
241       if (var->data.location != CONSTANT_LOCATION_VALID)
242          continue;
243 
244       /* Change the variable mode. */
245       var->data.mode = nir_var_shader_temp;
246 
247       progress = true;
248    }
249 
250    /* Second pass: patch all derefs that were accessing the converted UBO
251     * variables.
252     */
253    nir_foreach_function(func, nir) {
254       if (!func->is_entrypoint)
255          continue;
256       assert(func->impl);
257 
258       nir_builder b = nir_builder_create(func->impl);
259       nir_foreach_block(block, func->impl) {
260          nir_foreach_instr_safe(instr, block) {
261             if (instr->type != nir_instr_type_deref)
262                continue;
263 
264             nir_deref_instr *deref = nir_instr_as_deref(instr);
265             if (nir_deref_mode_is(deref, nir_var_mem_constant)) {
266                nir_deref_instr *parent = deref;
267                while (parent && parent->deref_type != nir_deref_type_var)
268                   parent = nir_src_as_deref(parent->parent);
269                if (parent && parent->var->data.mode != nir_var_mem_constant) {
270                   deref->modes = parent->var->data.mode;
271                   /* Also change "pointer" size to 32-bit since this is now a logical pointer */
272                   deref->def.bit_size = 32;
273                   if (deref->deref_type == nir_deref_type_array) {
274                      b.cursor = nir_before_instr(instr);
275                      nir_src_rewrite(&deref->arr.index, nir_u2u32(&b, deref->arr.index.ssa));
276                   }
277                }
278             }
279          }
280       }
281    }
282 
283    return progress;
284 }
285 
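/* Callback for dxil_nir_flatten_var_arrays: rewrite accesses to a flattened
 * variable by collapsing the multi-level array/vector indexing into a single
 * scalar array index, splitting vector loads/stores into per-component
 * accesses.
 */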
286 static bool
287 flatten_var_arrays(nir_builder *b, nir_intrinsic_instr *intr, void *data)
288 {
289    switch (intr->intrinsic) {
290    case nir_intrinsic_load_deref:
291    case nir_intrinsic_store_deref:
292    case nir_intrinsic_deref_atomic:
293    case nir_intrinsic_deref_atomic_swap:
294       break;
295    default:
296       return false;
297    }
298 
299    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
300    nir_variable *var = NULL;
301    for (nir_deref_instr *d = deref; d; d = nir_deref_instr_parent(d)) {
302       if (d->deref_type == nir_deref_type_cast)
303          return false;
304       if (d->deref_type == nir_deref_type_var) {
305          var = d->var;
306          if (d->type == var->type)
307             return false;
308       }
309    }
310    if (!var)
311       return false;
312 
313    nir_deref_path path;
314    nir_deref_path_init(&path, deref, NULL);
315 
316    assert(path.path[0]->deref_type == nir_deref_type_var);
317    b->cursor = nir_before_instr(&path.path[0]->instr);
318    nir_deref_instr *new_var_deref = nir_build_deref_var(b, var);
319    nir_def *index = NULL;
320    for (unsigned level = 1; path.path[level]; ++level) {
321       nir_deref_instr *arr_deref = path.path[level];
322       assert(arr_deref->deref_type == nir_deref_type_array);
323       b->cursor = nir_before_instr(&arr_deref->instr);
324       nir_def *val = nir_imul_imm(b, arr_deref->arr.index.ssa,
325                                       glsl_get_component_slots(arr_deref->type));
326       if (index) {
327          index = nir_iadd(b, index, val);
328       } else {
329          index = val;
330       }
331    }
332 
333    unsigned vector_comps = intr->num_components;
334    if (vector_comps > 1) {
335       b->cursor = nir_before_instr(&intr->instr);
336       if (intr->intrinsic == nir_intrinsic_load_deref) {
337          nir_def *components[NIR_MAX_VEC_COMPONENTS];
338          for (unsigned i = 0; i < vector_comps; ++i) {
339             nir_def *final_index = index ? nir_iadd_imm(b, index, i) : nir_imm_int(b, i);
340             nir_deref_instr *comp_deref = nir_build_deref_array(b, new_var_deref, final_index);
341             components[i] = nir_load_deref(b, comp_deref);
342          }
343          nir_def_rewrite_uses(&intr->def, nir_vec(b, components, vector_comps));
344       } else if (intr->intrinsic == nir_intrinsic_store_deref) {
345          for (unsigned i = 0; i < vector_comps; ++i) {
346             if (((1 << i) & nir_intrinsic_write_mask(intr)) == 0)
347                continue;
348             nir_def *final_index = index ? nir_iadd_imm(b, index, i) : nir_imm_int(b, i);
349             nir_deref_instr *comp_deref = nir_build_deref_array(b, new_var_deref, final_index);
350             nir_store_deref(b, comp_deref, nir_channel(b, intr->src[1].ssa, i), 1);
351          }
352       }
353       nir_instr_remove(&intr->instr);
354    } else {
355       nir_src_rewrite(&intr->src[0], &nir_build_deref_array(b, new_var_deref, index)->def);
356    }
357 
358    nir_deref_path_finish(&path);
359    return true;
360 }
361 
362 static void
363 flatten_constant_initializer(nir_variable *var, nir_constant *src, nir_constant ***dest, unsigned vector_elements)
364 {
365    if (src->num_elements == 0) {
366       for (unsigned i = 0; i < vector_elements; ++i) {
367          nir_constant *new_scalar = rzalloc(var, nir_constant);
368          memcpy(&new_scalar->values[0], &src->values[i], sizeof(src->values[0]));
369          new_scalar->is_null_constant = src->values[i].u64 == 0;
370 
371          nir_constant **array_entry = (*dest)++;
372          *array_entry = new_scalar;
373       }
374    } else {
375       for (unsigned i = 0; i < src->num_elements; ++i)
376          flatten_constant_initializer(var, src->elements[i], dest, vector_elements);
377    }
378 }
379 
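/* Rewrite a variable's (possibly nested) array-of-vector type into a flat
 * array of scalars, flattening any constant initializer to match.
 */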
380 static bool
381 flatten_var_array_types(nir_variable *var)
382 {
383    assert(!glsl_type_is_struct(glsl_without_array(var->type)));
384    const struct glsl_type *matrix_type = glsl_without_array(var->type);
385    if (!glsl_type_is_array_of_arrays(var->type) && glsl_get_components(matrix_type) == 1)
386       return false;
387 
388    enum glsl_base_type base_type = glsl_get_base_type(matrix_type);
389    const struct glsl_type *flattened_type = glsl_array_type(glsl_scalar_type(base_type),
390                                                             glsl_get_component_slots(var->type), 0);
391    var->type = flattened_type;
392    if (var->constant_initializer) {
393       nir_constant **new_elements = ralloc_array(var, nir_constant *, glsl_get_length(flattened_type));
394       nir_constant **temp = new_elements;
395       flatten_constant_initializer(var, var->constant_initializer, &temp, glsl_get_vector_elements(matrix_type));
396       var->constant_initializer->num_elements = glsl_get_length(flattened_type);
397       var->constant_initializer->elements = new_elements;
398    }
399    return true;
400 }
401 
402 bool
403 dxil_nir_flatten_var_arrays(nir_shader *shader, nir_variable_mode modes)
404 {
405    bool progress = false;
406    nir_foreach_variable_with_modes(var, shader, modes & ~nir_var_function_temp)
407       progress |= flatten_var_array_types(var);
408 
409    if (modes & nir_var_function_temp) {
410       nir_foreach_function_impl(impl, shader) {
411          nir_foreach_function_temp_variable(var, impl)
412             progress |= flatten_var_array_types(var);
413       }
414    }
415 
416    if (!progress)
417       return false;
418 
419    nir_shader_intrinsics_pass(shader, flatten_var_arrays,
420                                 nir_metadata_block_index |
421                                    nir_metadata_dominance |
422                                    nir_metadata_loop_analysis,
423                                 NULL);
424    nir_remove_dead_derefs(shader);
425    return true;
426 }
427 
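/* Callback for dxil_nir_lower_var_bit_size: fix up loads/stores whose deref
 * type no longer matches the variable's lowered scalar type, either by
 * converting small accesses to the wider type or by splitting 64-bit accesses
 * into two 32-bit array elements.
 */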
428 static bool
429 lower_deref_bit_size(nir_builder *b, nir_intrinsic_instr *intr, void *data)
430 {
431    switch (intr->intrinsic) {
432    case nir_intrinsic_load_deref:
433    case nir_intrinsic_store_deref:
434       break;
435    default:
436       /* Atomics can't be smaller than 32-bit */
437       return false;
438    }
439 
440    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
441    nir_variable *var = nir_deref_instr_get_variable(deref);
442    /* Only interested in full deref chains */
443    if (!var)
444       return false;
445 
446    const struct glsl_type *var_scalar_type = glsl_without_array(var->type);
447    if (deref->type == var_scalar_type || !glsl_type_is_scalar(var_scalar_type))
448       return false;
449 
450    assert(deref->deref_type == nir_deref_type_var || deref->deref_type == nir_deref_type_array);
451    const struct glsl_type *old_glsl_type = deref->type;
452    nir_alu_type old_type = nir_get_nir_type_for_glsl_type(old_glsl_type);
453    nir_alu_type new_type = nir_get_nir_type_for_glsl_type(var_scalar_type);
454    if (glsl_get_bit_size(old_glsl_type) < glsl_get_bit_size(var_scalar_type)) {
455       deref->type = var_scalar_type;
456       if (intr->intrinsic == nir_intrinsic_load_deref) {
457          intr->def.bit_size = glsl_get_bit_size(var_scalar_type);
458          b->cursor = nir_after_instr(&intr->instr);
459          nir_def *downcast = nir_type_convert(b, &intr->def, new_type, old_type, nir_rounding_mode_undef);
460          nir_def_rewrite_uses_after(&intr->def, downcast, downcast->parent_instr);
461       }
462       else {
463          b->cursor = nir_before_instr(&intr->instr);
464          nir_def *upcast = nir_type_convert(b, intr->src[1].ssa, old_type, new_type, nir_rounding_mode_undef);
465          nir_src_rewrite(&intr->src[1], upcast);
466       }
467 
468       while (deref->deref_type == nir_deref_type_array) {
469          nir_deref_instr *parent = nir_deref_instr_parent(deref);
470          parent->type = glsl_type_wrap_in_arrays(deref->type, parent->type);
471          deref = parent;
472       }
473    } else {
474       /* Arrays are assumed to have already been flattened */
475       b->cursor = nir_before_instr(&deref->instr);
476       nir_deref_instr *parent = nir_build_deref_var(b, var);
477       if (deref->deref_type == nir_deref_type_array)
478          deref = nir_build_deref_array(b, parent, nir_imul_imm(b, deref->arr.index.ssa, 2));
479       else
480          deref = nir_build_deref_array_imm(b, parent, 0);
481       nir_deref_instr *deref2 = nir_build_deref_array(b, parent,
482                                                       nir_iadd_imm(b, deref->arr.index.ssa, 1));
483       b->cursor = nir_before_instr(&intr->instr);
484       if (intr->intrinsic == nir_intrinsic_load_deref) {
485          nir_def *src1 = nir_load_deref(b, deref);
486          nir_def *src2 = nir_load_deref(b, deref2);
487          nir_def_rewrite_uses(&intr->def, nir_pack_64_2x32_split(b, src1, src2));
488       } else {
489          nir_def *src1 = nir_unpack_64_2x32_split_x(b, intr->src[1].ssa);
490          nir_def *src2 = nir_unpack_64_2x32_split_y(b, intr->src[1].ssa);
491          nir_store_deref(b, deref, src1, 1);
492          nir_store_deref(b, deref2, src2, 1);
493       }
494       nir_instr_remove(&intr->instr);
495    }
496    return true;
497 }
498 
499 static bool
500 lower_var_bit_size_types(nir_variable *var, unsigned min_bit_size, unsigned max_bit_size)
501 {
502    assert(!glsl_type_is_array_of_arrays(var->type) && !glsl_type_is_struct(var->type));
503    const struct glsl_type *type = glsl_without_array(var->type);
504    assert(glsl_type_is_scalar(type));
505    enum glsl_base_type base_type = glsl_get_base_type(type);
506    if (glsl_base_type_get_bit_size(base_type) < min_bit_size) {
507       switch (min_bit_size) {
508       case 16:
509          switch (base_type) {
510          case GLSL_TYPE_BOOL:
511             base_type = GLSL_TYPE_UINT16;
512             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
513                var->constant_initializer->elements[i]->values[0].u16 = var->constant_initializer->elements[i]->values[0].b ? 0xffff : 0;
514             break;
515          case GLSL_TYPE_INT8:
516             base_type = GLSL_TYPE_INT16;
517             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
518                var->constant_initializer->elements[i]->values[0].i16 = var->constant_initializer->elements[i]->values[0].i8;
519             break;
520          case GLSL_TYPE_UINT8: base_type = GLSL_TYPE_UINT16; break;
521          default: unreachable("Unexpected base type");
522          }
523          break;
524       case 32:
525          switch (base_type) {
526          case GLSL_TYPE_BOOL:
527             base_type = GLSL_TYPE_UINT;
528             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
529                var->constant_initializer->elements[i]->values[0].u32 = var->constant_initializer->elements[i]->values[0].b ? 0xffffffff : 0;
530             break;
531          case GLSL_TYPE_INT8:
532             base_type = GLSL_TYPE_INT;
533             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
534                var->constant_initializer->elements[i]->values[0].i32 = var->constant_initializer->elements[i]->values[0].i8;
535             break;
536          case GLSL_TYPE_INT16:
537             base_type = GLSL_TYPE_INT;
538             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
539                var->constant_initializer->elements[i]->values[0].i32 = var->constant_initializer->elements[i]->values[0].i16;
540             break;
541          case GLSL_TYPE_FLOAT16:
542             base_type = GLSL_TYPE_FLOAT;
543             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
544                var->constant_initializer->elements[i]->values[0].f32 = _mesa_half_to_float(var->constant_initializer->elements[i]->values[0].u16);
545             break;
546          case GLSL_TYPE_UINT8: base_type = GLSL_TYPE_UINT; break;
547          case GLSL_TYPE_UINT16: base_type = GLSL_TYPE_UINT; break;
548          default: unreachable("Unexpected base type");
549          }
550          break;
551       default: unreachable("Unexpected min bit size");
552       }
553       var->type = glsl_type_wrap_in_arrays(glsl_scalar_type(base_type), var->type);
554       return true;
555    }
556    if (glsl_base_type_get_bit_size(base_type) > max_bit_size) {
557       assert(!glsl_type_is_array_of_arrays(var->type));
558       var->type = glsl_array_type(glsl_scalar_type(GLSL_TYPE_UINT),
559                                     glsl_type_is_array(var->type) ? glsl_get_length(var->type) * 2 : 2,
560                                     0);
561       if (var->constant_initializer) {
562          unsigned num_elements = var->constant_initializer->num_elements ?
563             var->constant_initializer->num_elements * 2 : 2;
564          nir_constant **element_arr = ralloc_array(var, nir_constant *, num_elements);
565          nir_constant *elements = rzalloc_array(var, nir_constant, num_elements);
566          for (unsigned i = 0; i < var->constant_initializer->num_elements; ++i) {
567             element_arr[i*2] = &elements[i*2];
568             element_arr[i*2+1] = &elements[i*2+1];
569             const nir_const_value *src = var->constant_initializer->num_elements ?
570                var->constant_initializer->elements[i]->values : var->constant_initializer->values;
571             elements[i*2].values[0].u32 = (uint32_t)src->u64;
572             elements[i*2].is_null_constant = (uint32_t)src->u64 == 0;
573             elements[i*2+1].values[0].u32 = (uint32_t)(src->u64 >> 32);
574             elements[i*2+1].is_null_constant = (uint32_t)(src->u64 >> 32) == 0;
575          }
576          var->constant_initializer->num_elements = num_elements;
577          var->constant_initializer->elements = element_arr;
578       }
579       return true;
580    }
581    return false;
582 }
583 
584 bool
585 dxil_nir_lower_var_bit_size(nir_shader *shader, nir_variable_mode modes,
586                             unsigned min_bit_size, unsigned max_bit_size)
587 {
588    bool progress = false;
589    nir_foreach_variable_with_modes(var, shader, modes & ~nir_var_function_temp)
590       progress |= lower_var_bit_size_types(var, min_bit_size, max_bit_size);
591 
592    if (modes & nir_var_function_temp) {
593       nir_foreach_function_impl(impl, shader) {
594          nir_foreach_function_temp_variable(var, impl)
595             progress |= lower_var_bit_size_types(var, min_bit_size, max_bit_size);
596       }
597    }
598 
599    if (!progress)
600       return false;
601 
602    nir_shader_intrinsics_pass(shader, lower_deref_bit_size,
603                                 nir_metadata_block_index |
604                                    nir_metadata_dominance |
605                                    nir_metadata_loop_analysis,
606                                 NULL);
607    nir_remove_dead_derefs(shader);
608    return true;
609 }
610 
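/* Lower a shared_atomic/shared_atomic_swap with a byte offset into a deref
 * atomic on the i32 shared-memory array variable.
 */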
611 static bool
612 lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
613 {
614    b->cursor = nir_before_instr(&intr->instr);
615 
616    nir_def *offset =
617       nir_iadd_imm(b, intr->src[0].ssa, nir_intrinsic_base(intr));
618    nir_def *index = nir_ushr_imm(b, offset, 2);
619 
620    nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
621    nir_def *result;
622    if (intr->intrinsic == nir_intrinsic_shared_atomic_swap)
623       result = nir_deref_atomic_swap(b, 32, &deref->def, intr->src[1].ssa, intr->src[2].ssa,
624                                      .atomic_op = nir_intrinsic_atomic_op(intr));
625    else
626       result = nir_deref_atomic(b, 32, &deref->def, intr->src[1].ssa,
627                                 .atomic_op = nir_intrinsic_atomic_op(intr));
628 
629    nir_def_rewrite_uses(&intr->def, result);
630    nir_instr_remove(&intr->instr);
631    return true;
632 }
633 
634 bool
635 dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
636                                     const struct dxil_nir_lower_loads_stores_options *options)
637 {
638    bool progress = nir_remove_dead_variables(nir, nir_var_function_temp | nir_var_mem_shared, NULL);
639    nir_variable *shared_var = NULL;
640    if (nir->info.shared_size) {
641       shared_var = nir_variable_create(nir, nir_var_mem_shared,
642                                        glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->info.shared_size, 4), 4),
643                                        "lowered_shared_mem");
644    }
645 
646    unsigned ptr_size = nir->info.cs.ptr_size;
647    if (nir->info.stage == MESA_SHADER_KERNEL) {
648       /* All the derefs created here will be used as GEP indices so force 32-bit */
649       nir->info.cs.ptr_size = 32;
650    }
651    nir_foreach_function_impl(impl, nir) {
652       nir_builder b = nir_builder_create(impl);
653 
654       nir_variable *scratch_var = NULL;
655       if (nir->scratch_size) {
656          const struct glsl_type *scratch_type = glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->scratch_size, 4), 4);
657          scratch_var = nir_local_variable_create(impl, scratch_type, "lowered_scratch_mem");
658       }
659 
660       nir_foreach_block(block, impl) {
661          nir_foreach_instr_safe(instr, block) {
662             if (instr->type != nir_instr_type_intrinsic)
663                continue;
664             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
665 
666             switch (intr->intrinsic) {
667             case nir_intrinsic_load_shared:
668                progress |= lower_32b_offset_load(&b, intr, shared_var);
669                break;
670             case nir_intrinsic_load_scratch:
671                progress |= lower_32b_offset_load(&b, intr, scratch_var);
672                break;
673             case nir_intrinsic_store_shared:
674                progress |= lower_32b_offset_store(&b, intr, shared_var);
675                break;
676             case nir_intrinsic_store_scratch:
677                progress |= lower_32b_offset_store(&b, intr, scratch_var);
678                break;
679             case nir_intrinsic_shared_atomic:
680             case nir_intrinsic_shared_atomic_swap:
681                progress |= lower_shared_atomic(&b, intr, shared_var);
682                break;
683             default:
684                break;
685             }
686          }
687       }
688    }
689    if (nir->info.stage == MESA_SHADER_KERNEL) {
690       nir->info.cs.ptr_size = ptr_size;
691    }
692 
693    return progress;
694 }
695 
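/* Replace an SSBO deref_var with a deref_cast of an immediate 64-bit pointer
 * whose upper 32 bits hold the variable's binding (the UAV id).
 */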
696 static bool
697 lower_deref_ssbo(nir_builder *b, nir_deref_instr *deref)
698 {
699    assert(nir_deref_mode_is(deref, nir_var_mem_ssbo));
700    assert(deref->deref_type == nir_deref_type_var ||
701           deref->deref_type == nir_deref_type_cast);
702    nir_variable *var = deref->var;
703 
704    b->cursor = nir_before_instr(&deref->instr);
705 
706    if (deref->deref_type == nir_deref_type_var) {
707       /* We turn each deref_var into a deref_cast and build a pointer value based
708        * on the var binding, which encodes the UAV id.
709        */
710       nir_def *ptr = nir_imm_int64(b, (uint64_t)var->data.binding << 32);
711       nir_deref_instr *deref_cast =
712          nir_build_deref_cast(b, ptr, nir_var_mem_ssbo, deref->type,
713                               glsl_get_explicit_stride(var->type));
714       nir_def_rewrite_uses(&deref->def,
715                                &deref_cast->def);
716       nir_instr_remove(&deref->instr);
717 
718       deref = deref_cast;
719       return true;
720    }
721    return false;
722 }
723 
724 bool
725 dxil_nir_lower_deref_ssbo(nir_shader *nir)
726 {
727    bool progress = false;
728 
729    foreach_list_typed(nir_function, func, node, &nir->functions) {
730       if (!func->is_entrypoint)
731          continue;
732       assert(func->impl);
733 
734       nir_builder b = nir_builder_create(func->impl);
735 
736       nir_foreach_block(block, func->impl) {
737          nir_foreach_instr_safe(instr, block) {
738             if (instr->type != nir_instr_type_deref)
739                continue;
740 
741             nir_deref_instr *deref = nir_instr_as_deref(instr);
742 
743             if (!nir_deref_mode_is(deref, nir_var_mem_ssbo) ||
744                 (deref->deref_type != nir_deref_type_var &&
745                  deref->deref_type != nir_deref_type_cast))
746                continue;
747 
748             progress |= lower_deref_ssbo(&b, deref);
749          }
750       }
751    }
752 
753    return progress;
754 }
755 
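/* Replace deref sources of ALU instructions that are rooted at a deref_cast
 * with the cast's pointer plus the deref's byte offset, so the ALU operates
 * on a plain pointer value.
 */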
756 static bool
757 lower_alu_deref_srcs(nir_builder *b, nir_alu_instr *alu)
758 {
759    const nir_op_info *info = &nir_op_infos[alu->op];
760    bool progress = false;
761 
762    b->cursor = nir_before_instr(&alu->instr);
763 
764    for (unsigned i = 0; i < info->num_inputs; i++) {
765       nir_deref_instr *deref = nir_src_as_deref(alu->src[i].src);
766 
767       if (!deref)
768          continue;
769 
770       nir_deref_path path;
771       nir_deref_path_init(&path, deref, NULL);
772       nir_deref_instr *root_deref = path.path[0];
773       nir_deref_path_finish(&path);
774 
775       if (root_deref->deref_type != nir_deref_type_cast)
776          continue;
777 
778       nir_def *ptr =
779          nir_iadd(b, root_deref->parent.ssa,
780                      nir_build_deref_offset(b, deref, cl_type_size_align));
781       nir_src_rewrite(&alu->src[i].src, ptr);
782       progress = true;
783    }
784 
785    return progress;
786 }
787 
788 bool
789 dxil_nir_opt_alu_deref_srcs(nir_shader *nir)
790 {
791    bool progress = false;
792 
793    foreach_list_typed(nir_function, func, node, &nir->functions) {
794       if (!func->is_entrypoint)
795          continue;
796       assert(func->impl);
797 
798       nir_builder b = nir_builder_create(func->impl);
799 
800       nir_foreach_block(block, func->impl) {
801          nir_foreach_instr_safe(instr, block) {
802             if (instr->type != nir_instr_type_alu)
803                continue;
804 
805             nir_alu_instr *alu = nir_instr_as_alu(instr);
806             progress |= lower_alu_deref_srcs(&b, alu);
807          }
808       }
809    }
810 
811    return progress;
812 }
813 
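/* Replace a phi narrower than new_bit_size with one of new_bit_size by
 * upcasting every source and downcasting the result back to the original
 * width.
 */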
814 static void
815 cast_phi(nir_builder *b, nir_phi_instr *phi, unsigned new_bit_size)
816 {
817    nir_phi_instr *lowered = nir_phi_instr_create(b->shader);
818    int num_components = 0;
819    int old_bit_size = phi->def.bit_size;
820 
821    nir_foreach_phi_src(src, phi) {
822       assert(num_components == 0 || num_components == src->src.ssa->num_components);
823       num_components = src->src.ssa->num_components;
824 
825       b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);
826 
827       nir_def *cast = nir_u2uN(b, src->src.ssa, new_bit_size);
828 
829       nir_phi_instr_add_src(lowered, src->pred, cast);
830    }
831 
832    nir_def_init(&lowered->instr, &lowered->def, num_components,
833                 new_bit_size);
834 
835    b->cursor = nir_before_instr(&phi->instr);
836    nir_builder_instr_insert(b, &lowered->instr);
837 
838    b->cursor = nir_after_phis(nir_cursor_current_block(b->cursor));
839    nir_def *result = nir_u2uN(b, &lowered->def, old_bit_size);
840 
841    nir_def_rewrite_uses(&phi->def, result);
842    nir_instr_remove(&phi->instr);
843 }
844 
845 static bool
846 upcast_phi_impl(nir_function_impl *impl, unsigned min_bit_size)
847 {
848    nir_builder b = nir_builder_create(impl);
849    bool progress = false;
850 
851    nir_foreach_block_reverse(block, impl) {
852       nir_foreach_phi_safe(phi, block) {
853          if (phi->def.bit_size == 1 ||
854              phi->def.bit_size >= min_bit_size)
855             continue;
856 
857          cast_phi(&b, phi, min_bit_size);
858          progress = true;
859       }
860    }
861 
862    if (progress) {
863       nir_metadata_preserve(impl, nir_metadata_block_index |
864                                   nir_metadata_dominance);
865    } else {
866       nir_metadata_preserve(impl, nir_metadata_all);
867    }
868 
869    return progress;
870 }
871 
872 bool
873 dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size)
874 {
875    bool progress = false;
876 
877    nir_foreach_function_impl(impl, shader) {
878       progress |= upcast_phi_impl(impl, min_bit_size);
879    }
880 
881    return progress;
882 }
883 
884 struct dxil_nir_split_clip_cull_distance_params {
885    nir_variable *new_var[2];
886    nir_shader *shader;
887 };
888 
889 /* In GLSL and SPIR-V, clip and cull distance are arrays of floats (with a limit of 8).
890  * In DXIL, clip and cull distances are up to 2 float4s combined.
891  * Coming from GLSL, we can request this 2 float4 format, but coming from SPIR-V,
892  * we can't, and have to accept a "compact" array of scalar floats.
893  *
894  * To help emit a valid input signature for this case, split the variables so that they
895  * match what we need to put in the signature (e.g. { float clip[4]; float clip1; float cull[3]; })
896  */
897 static bool
898 dxil_nir_split_clip_cull_distance_instr(nir_builder *b,
899                                         nir_instr *instr,
900                                         void *cb_data)
901 {
902    struct dxil_nir_split_clip_cull_distance_params *params = cb_data;
903 
904    if (instr->type != nir_instr_type_deref)
905       return false;
906 
907    nir_deref_instr *deref = nir_instr_as_deref(instr);
908    nir_variable *var = nir_deref_instr_get_variable(deref);
909    if (!var ||
910        var->data.location < VARYING_SLOT_CLIP_DIST0 ||
911        var->data.location > VARYING_SLOT_CULL_DIST1 ||
912        !var->data.compact)
913       return false;
914 
915    unsigned new_var_idx = var->data.mode == nir_var_shader_in ? 0 : 1;
916    nir_variable *new_var = params->new_var[new_var_idx];
917 
918    /* The location should only be inside clip distance, because clip
919     * and cull should've been merged by nir_lower_clip_cull_distance_arrays()
920     */
921    assert(var->data.location == VARYING_SLOT_CLIP_DIST0 ||
922           var->data.location == VARYING_SLOT_CLIP_DIST1);
923 
924    /* The deref chain to the clip/cull variables should be simple, just the
925     * var and an array with a constant index, otherwise more lowering/optimization
926     * might be needed before this pass, e.g. copy prop, lower_io_to_temporaries,
927     * split_var_copies, and/or lower_var_copies. In the case of arrayed I/O like
928     * inputs to the tessellation or geometry stages, there might be a second level
929     * of array index.
930     */
931    assert(deref->deref_type == nir_deref_type_var ||
932           deref->deref_type == nir_deref_type_array);
933 
934    b->cursor = nir_before_instr(instr);
935    unsigned arrayed_io_length = 0;
936    const struct glsl_type *old_type = var->type;
937    if (nir_is_arrayed_io(var, b->shader->info.stage)) {
938       arrayed_io_length = glsl_array_size(old_type);
939       old_type = glsl_get_array_element(old_type);
940    }
941    if (!new_var) {
942       /* Update lengths for new and old vars */
943       int old_length = glsl_array_size(old_type);
944       int new_length = (old_length + var->data.location_frac) - 4;
945       old_length -= new_length;
946 
947       /* The existing variable fits in the float4 */
948       if (new_length <= 0)
949          return false;
950 
951       new_var = nir_variable_clone(var, params->shader);
952       nir_shader_add_variable(params->shader, new_var);
953       assert(glsl_get_base_type(glsl_get_array_element(old_type)) == GLSL_TYPE_FLOAT);
954       var->type = glsl_array_type(glsl_float_type(), old_length, 0);
955       new_var->type = glsl_array_type(glsl_float_type(), new_length, 0);
956       if (arrayed_io_length) {
957          var->type = glsl_array_type(var->type, arrayed_io_length, 0);
958          new_var->type = glsl_array_type(new_var->type, arrayed_io_length, 0);
959       }
960       new_var->data.location++;
961       new_var->data.location_frac = 0;
962       params->new_var[new_var_idx] = new_var;
963    }
964 
965    /* Update the type for derefs of the old var */
966    if (deref->deref_type == nir_deref_type_var) {
967       deref->type = var->type;
968       return false;
969    }
970 
971    if (glsl_type_is_array(deref->type)) {
972       assert(arrayed_io_length > 0);
973       deref->type = glsl_get_array_element(var->type);
974       return false;
975    }
976 
977    assert(glsl_get_base_type(deref->type) == GLSL_TYPE_FLOAT);
978 
979    nir_const_value *index = nir_src_as_const_value(deref->arr.index);
980    assert(index);
981 
982    /* Treat this array as a vector starting at the component index in location_frac,
983     * so if location_frac is 1 and index is 0, then it's accessing the 'y' component
984     * of the vector. If index + location_frac is >= 4, there's no component there,
985     * so we need to add a new variable and adjust the index.
986     */
987    unsigned total_index = index->u32 + var->data.location_frac;
988    if (total_index < 4)
989       return false;
990 
991    nir_deref_instr *new_var_deref = nir_build_deref_var(b, new_var);
992    nir_deref_instr *new_intermediate_deref = new_var_deref;
993    if (arrayed_io_length) {
994       nir_deref_instr *parent = nir_src_as_deref(deref->parent);
995       assert(parent->deref_type == nir_deref_type_array);
996       new_intermediate_deref = nir_build_deref_array(b, new_intermediate_deref, parent->arr.index.ssa);
997    }
998    nir_deref_instr *new_array_deref = nir_build_deref_array(b, new_intermediate_deref, nir_imm_int(b, total_index % 4));
999    nir_def_rewrite_uses(&deref->def, &new_array_deref->def);
1000    return true;
1001 }
1002 
1003 bool
1004 dxil_nir_split_clip_cull_distance(nir_shader *shader)
1005 {
1006    struct dxil_nir_split_clip_cull_distance_params params = {
1007       .new_var = { NULL, NULL },
1008       .shader = shader,
1009    };
1010    nir_shader_instructions_pass(shader,
1011                                 dxil_nir_split_clip_cull_distance_instr,
1012                                 nir_metadata_block_index |
1013                                 nir_metadata_dominance |
1014                                 nir_metadata_loop_analysis,
1015                                 &params);
1016    return params.new_var[0] != NULL || params.new_var[1] != NULL;
1017 }
1018 
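/* Insert pack_double_2x32_dxil/unpack_double_2x32_dxil conversions around
 * 64-bit float ALU sources and results (and around 64-bit float
 * reductions/scans) so the backend sees doubles in the representation it
 * expects.
 */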
1019 static bool
1020 dxil_nir_lower_double_math_instr(nir_builder *b,
1021                                  nir_instr *instr,
1022                                  UNUSED void *cb_data)
1023 {
1024    if (instr->type == nir_instr_type_intrinsic) {
1025       nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1026       switch (intr->intrinsic) {
1027          case nir_intrinsic_reduce:
1028          case nir_intrinsic_exclusive_scan:
1029          case nir_intrinsic_inclusive_scan:
1030             break;
1031          default:
1032             return false;
1033       }
1034       if (intr->def.bit_size != 64)
1035          return false;
1036       nir_op reduction = nir_intrinsic_reduction_op(intr);
1037       switch (reduction) {
1038          case nir_op_fmul:
1039          case nir_op_fadd:
1040          case nir_op_fmin:
1041          case nir_op_fmax:
1042             break;
1043          default:
1044             return false;
1045       }
1046       b->cursor = nir_before_instr(instr);
1047       nir_src_rewrite(&intr->src[0], nir_pack_double_2x32_dxil(b, nir_unpack_64_2x32(b, intr->src[0].ssa)));
1048       b->cursor = nir_after_instr(instr);
1049       nir_def *result = nir_pack_64_2x32(b, nir_unpack_double_2x32_dxil(b, &intr->def));
1050       nir_def_rewrite_uses_after(&intr->def, result, result->parent_instr);
1051       return true;
1052    }
1053 
1054    if (instr->type != nir_instr_type_alu)
1055       return false;
1056 
1057    nir_alu_instr *alu = nir_instr_as_alu(instr);
1058 
1059    /* TODO: See if we can apply this explicitly to packs/unpacks that are then
1060     * used as a double. As-is, if we had an app explicitly do a 64-bit integer op,
1061     * then try to bitcast to double (not expressible in HLSL, but it is in other
1062     * source languages), this would unpack the integer and repack as a double, when
1063     * we probably want to just send the bitcast through to the backend.
1064     */
1065 
1066    b->cursor = nir_before_instr(&alu->instr);
1067 
1068    bool progress = false;
1069    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
1070       if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
1071           alu->src[i].src.ssa->bit_size == 64) {
1072          unsigned num_components = nir_op_infos[alu->op].input_sizes[i];
1073          if (!num_components)
1074             num_components = alu->def.num_components;
1075          nir_def *components[NIR_MAX_VEC_COMPONENTS];
1076          for (unsigned c = 0; c < num_components; ++c) {
1077             nir_def *packed_double = nir_channel(b, alu->src[i].src.ssa, alu->src[i].swizzle[c]);
1078             nir_def *unpacked_double = nir_unpack_64_2x32(b, packed_double);
1079             components[c] = nir_pack_double_2x32_dxil(b, unpacked_double);
1080             alu->src[i].swizzle[c] = c;
1081          }
1082          nir_src_rewrite(&alu->src[i].src,
1083                          nir_vec(b, components, num_components));
1084          progress = true;
1085       }
1086    }
1087 
1088    if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
1089        alu->def.bit_size == 64) {
1090       b->cursor = nir_after_instr(&alu->instr);
1091       nir_def *components[NIR_MAX_VEC_COMPONENTS];
1092       for (unsigned c = 0; c < alu->def.num_components; ++c) {
1093          nir_def *packed_double = nir_channel(b, &alu->def, c);
1094          nir_def *unpacked_double = nir_unpack_double_2x32_dxil(b, packed_double);
1095          components[c] = nir_pack_64_2x32(b, unpacked_double);
1096       }
1097       nir_def *repacked_dvec = nir_vec(b, components, alu->def.num_components);
1098       nir_def_rewrite_uses_after(&alu->def, repacked_dvec, repacked_dvec->parent_instr);
1099       progress = true;
1100    }
1101 
1102    return progress;
1103 }
1104 
1105 bool
1106 dxil_nir_lower_double_math(nir_shader *shader)
1107 {
1108    return nir_shader_instructions_pass(shader,
1109                                        dxil_nir_lower_double_math_instr,
1110                                        nir_metadata_block_index |
1111                                        nir_metadata_dominance |
1112                                        nir_metadata_loop_analysis,
1113                                        NULL);
1114 }
1115 
1116 typedef struct {
1117    gl_system_value *values;
1118    uint32_t count;
1119 } zero_system_values_state;
1120 
1121 static bool
1122 lower_system_value_to_zero_filter(const nir_instr* instr, const void* cb_state)
1123 {
1124    if (instr->type != nir_instr_type_intrinsic) {
1125       return false;
1126    }
1127 
1128    nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);
1129 
1130    /* All the intrinsics we care about are loads */
1131    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
1132       return false;
1133 
1134    zero_system_values_state* state = (zero_system_values_state*)cb_state;
1135    for (uint32_t i = 0; i < state->count; ++i) {
1136       gl_system_value value = state->values[i];
1137       nir_intrinsic_op value_op = nir_intrinsic_from_system_value(value);
1138 
1139       if (intrin->intrinsic == value_op) {
1140          return true;
1141       } else if (intrin->intrinsic == nir_intrinsic_load_deref) {
1142          nir_deref_instr* deref = nir_src_as_deref(intrin->src[0]);
1143          if (!nir_deref_mode_is(deref, nir_var_system_value))
1144             return false;
1145 
1146          nir_variable* var = deref->var;
1147          if (var->data.location == value) {
1148             return true;
1149          }
1150       }
1151    }
1152 
1153    return false;
1154 }
1155 
1156 static nir_def*
1157 lower_system_value_to_zero_instr(nir_builder* b, nir_instr* instr, void* _state)
1158 {
1159    return nir_imm_int(b, 0);
1160 }
1161 
1162 bool
1163 dxil_nir_lower_system_values_to_zero(nir_shader* shader,
1164                                      gl_system_value* system_values,
1165                                      uint32_t count)
1166 {
1167    zero_system_values_state state = { system_values, count };
1168    return nir_shader_lower_instructions(shader,
1169       lower_system_value_to_zero_filter,
1170       lower_system_value_to_zero_instr,
1171       &state);
1172 }
1173 
1174 static void
1175 lower_load_local_group_size(nir_builder *b, nir_intrinsic_instr *intr)
1176 {
1177    b->cursor = nir_after_instr(&intr->instr);
1178 
1179    nir_const_value v[3] = {
1180       nir_const_value_for_int(b->shader->info.workgroup_size[0], 32),
1181       nir_const_value_for_int(b->shader->info.workgroup_size[1], 32),
1182       nir_const_value_for_int(b->shader->info.workgroup_size[2], 32)
1183    };
1184    nir_def *size = nir_build_imm(b, 3, 32, v);
1185    nir_def_rewrite_uses(&intr->def, size);
1186    nir_instr_remove(&intr->instr);
1187 }
1188 
1189 static bool
1190 lower_system_values_impl(nir_builder *b, nir_intrinsic_instr *intr,
1191                          void *_state)
1192 {
1193    switch (intr->intrinsic) {
1194    case nir_intrinsic_load_workgroup_size:
1195       lower_load_local_group_size(b, intr);
1196       return true;
1197    default:
1198       return false;
1199    }
1200 }
1201 
1202 bool
1203 dxil_nir_lower_system_values(nir_shader *shader)
1204 {
1205    return nir_shader_intrinsics_pass(shader, lower_system_values_impl,
1206                                      nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis,
1207                                      NULL);
1208 }
1209 
1210 static const struct glsl_type *
1211 get_bare_samplers_for_type(const struct glsl_type *type, bool is_shadow)
1212 {
1213    const struct glsl_type *base_sampler_type =
1214       is_shadow ?
1215       glsl_bare_shadow_sampler_type() : glsl_bare_sampler_type();
1216    return glsl_type_wrap_in_arrays(base_sampler_type, type);
1217 }
1218 
1219 static const struct glsl_type *
1220 get_textures_for_sampler_type(const struct glsl_type *type)
1221 {
1222    return glsl_type_wrap_in_arrays(
1223       glsl_sampler_type_to_texture(
1224          glsl_without_array(type)), type);
1225 }
1226 
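/* Make sampler sources point at bare-sampler variables of the right shadow
 * type, cloning the original typed sampler variable and rewriting the deref
 * chain (or the index-based lookup table) as needed.
 */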
1227 static bool
1228 redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1229 {
1230    if (instr->type != nir_instr_type_tex)
1231       return false;
1232 
1233    nir_tex_instr *tex = nir_instr_as_tex(instr);
1234 
1235    int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
1236    if (sampler_idx == -1) {
1237       /* No sampler deref - does this instruction even need a sampler? If not,
1238        * sampler_index doesn't necessarily point to a sampler, so early-out.
1239        */
1240       if (!nir_tex_instr_need_sampler(tex))
1241          return false;
1242 
1243       /* No derefs but needs a sampler, must be using indices */
1244       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->sampler_index);
1245 
1246       /* Already have a bare sampler here */
1247       if (bare_sampler)
1248          return false;
1249 
1250       nir_variable *old_sampler = NULL;
1251       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1252          if (var->data.binding <= tex->sampler_index &&
1253              var->data.binding + glsl_type_get_sampler_count(var->type) >
1254                 tex->sampler_index) {
1255 
1256             /* Already have a bare sampler for this binding and it is of the
1257              * correct type, add it to the table */
1258             if (glsl_type_is_bare_sampler(glsl_without_array(var->type)) &&
1259                 glsl_sampler_type_is_shadow(glsl_without_array(var->type)) ==
1260                    tex->is_shadow) {
1261                _mesa_hash_table_u64_insert(data, tex->sampler_index, var);
1262                return false;
1263             }
1264 
1265             old_sampler = var;
1266          }
1267       }
1268 
1269       assert(old_sampler);
1270 
1271       /* Clone the original sampler to a bare sampler of the correct type */
1272       bare_sampler = nir_variable_clone(old_sampler, b->shader);
1273       nir_shader_add_variable(b->shader, bare_sampler);
1274 
1275       bare_sampler->type =
1276          get_bare_samplers_for_type(old_sampler->type, tex->is_shadow);
1277       _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler);
1278       return true;
1279    }
1280 
1281    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1282    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src);
1283    nir_deref_path path;
1284    nir_deref_path_init(&path, final_deref, NULL);
1285 
1286    nir_deref_instr *old_tail = path.path[0];
1287    assert(old_tail->deref_type == nir_deref_type_var);
1288    nir_variable *old_var = old_tail->var;
1289    if (glsl_type_is_bare_sampler(glsl_without_array(old_var->type)) &&
1290        glsl_sampler_type_is_shadow(glsl_without_array(old_var->type)) ==
1291           tex->is_shadow) {
1292       nir_deref_path_finish(&path);
1293       return false;
1294    }
1295 
1296    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1297                       old_var->data.binding;
1298    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1299    if (!new_var) {
1300       new_var = nir_variable_clone(old_var, b->shader);
1301       nir_shader_add_variable(b->shader, new_var);
1302       new_var->type =
1303          get_bare_samplers_for_type(old_var->type, tex->is_shadow);
1304       _mesa_hash_table_u64_insert(data, var_key, new_var);
1305    }
1306 
1307    b->cursor = nir_after_instr(&old_tail->instr);
1308    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1309 
1310    for (unsigned i = 1; path.path[i]; ++i) {
1311       b->cursor = nir_after_instr(&path.path[i]->instr);
1312       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1313    }
1314 
1315    nir_deref_path_finish(&path);
1316    nir_src_rewrite(&tex->src[sampler_idx].src, &new_tail->def);
1317    return true;
1318 }
1319 
1320 static bool
1321 redirect_texture_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1322 {
1323    if (instr->type != nir_instr_type_tex)
1324       return false;
1325 
1326    nir_tex_instr *tex = nir_instr_as_tex(instr);
1327 
1328    int texture_idx = nir_tex_instr_src_index(tex, nir_tex_src_texture_deref);
1329    if (texture_idx == -1) {
1330       /* No derefs, must be using indices */
1331       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->texture_index);
1332 
1333       /* Already have a texture here */
1334       if (bare_sampler)
1335          return false;
1336 
1337       nir_variable *typed_sampler = NULL;
1338       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1339          if (var->data.binding <= tex->texture_index &&
1340              var->data.binding + glsl_type_get_texture_count(var->type) > tex->texture_index) {
1341             /* Already have a texture for this binding, add it to the table */
1342             _mesa_hash_table_u64_insert(data, tex->texture_index, var);
1343             return false;
1344          }
1345 
1346          if (var->data.binding <= tex->texture_index &&
1347              var->data.binding + glsl_type_get_sampler_count(var->type) > tex->texture_index &&
1348              !glsl_type_is_bare_sampler(glsl_without_array(var->type))) {
1349             typed_sampler = var;
1350          }
1351       }
1352 
1353       /* Clone the typed sampler to a texture and we're done */
1354       assert(typed_sampler);
1355       bare_sampler = nir_variable_clone(typed_sampler, b->shader);
1356       bare_sampler->type = get_textures_for_sampler_type(typed_sampler->type);
1357       nir_shader_add_variable(b->shader, bare_sampler);
1358       _mesa_hash_table_u64_insert(data, tex->texture_index, bare_sampler);
1359       return true;
1360    }
1361 
1362    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1363    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[texture_idx].src);
1364    nir_deref_path path;
1365    nir_deref_path_init(&path, final_deref, NULL);
1366 
1367    nir_deref_instr *old_tail = path.path[0];
1368    assert(old_tail->deref_type == nir_deref_type_var);
1369    nir_variable *old_var = old_tail->var;
1370    if (glsl_type_is_texture(glsl_without_array(old_var->type)) ||
1371        glsl_type_is_image(glsl_without_array(old_var->type))) {
1372       nir_deref_path_finish(&path);
1373       return false;
1374    }
1375 
1376    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1377                       old_var->data.binding;
1378    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1379    if (!new_var) {
1380       new_var = nir_variable_clone(old_var, b->shader);
1381       new_var->type = get_textures_for_sampler_type(old_var->type);
1382       nir_shader_add_variable(b->shader, new_var);
1383       _mesa_hash_table_u64_insert(data, var_key, new_var);
1384    }
1385 
1386    b->cursor = nir_after_instr(&old_tail->instr);
1387    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1388 
1389    for (unsigned i = 1; path.path[i]; ++i) {
1390       b->cursor = nir_after_instr(&path.path[i]->instr);
1391       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1392    }
1393 
1394    nir_deref_path_finish(&path);
1395    nir_src_rewrite(&tex->src[texture_idx].src, &new_tail->def);
1396 
1397    return true;
1398 }
1399 
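/* Split each combined texture/sampler ("typed sampler") uniform into two
 * variables: a texture of the matching dimensionality plus a bare sampler,
 * redirecting the deref chains (or binding indices) of every tex instruction
 * accordingly.
 */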
1400 bool
1401 dxil_nir_split_typed_samplers(nir_shader *nir)
1402 {
1403    struct hash_table_u64 *hash_table = _mesa_hash_table_u64_create(NULL);
1404 
1405    bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs,
1406       nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, hash_table);
1407 
1408    _mesa_hash_table_u64_clear(hash_table);
1409 
1410    progress |= nir_shader_instructions_pass(nir, redirect_texture_derefs,
1411       nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, hash_table);
1412 
1413    _mesa_hash_table_u64_destroy(hash_table);
1414    return progress;
1415 }
1416 
1417 
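/* Replace load_front_face / load_instance_id / load_vertex_id_zero_base with
 * load_input intrinsics reading the matching caller-provided sysval variable;
 * front-face is read as a uint and converted back to a boolean by comparing
 * with zero.
 */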
1418 static bool
1419 lower_sysval_to_load_input_impl(nir_builder *b, nir_intrinsic_instr *intr,
1420                                 void *data)
1421 {
1422    gl_system_value sysval = SYSTEM_VALUE_MAX;
1423    switch (intr->intrinsic) {
1424    case nir_intrinsic_load_front_face:
1425       sysval = SYSTEM_VALUE_FRONT_FACE;
1426       break;
1427    case nir_intrinsic_load_instance_id:
1428       sysval = SYSTEM_VALUE_INSTANCE_ID;
1429       break;
1430    case nir_intrinsic_load_vertex_id_zero_base:
1431       sysval = SYSTEM_VALUE_VERTEX_ID_ZERO_BASE;
1432       break;
1433    default:
1434       return false;
1435    }
1436 
1437    nir_variable **sysval_vars = (nir_variable **)data;
1438    nir_variable *var = sysval_vars[sysval];
1439    assert(var);
1440 
1441    const nir_alu_type dest_type = (sysval == SYSTEM_VALUE_FRONT_FACE)
1442       ? nir_type_uint32 : nir_get_nir_type_for_glsl_type(var->type);
1443    const unsigned bit_size = (sysval == SYSTEM_VALUE_FRONT_FACE)
1444       ? 32 : intr->def.bit_size;
1445 
1446    b->cursor = nir_before_instr(&intr->instr);
1447    nir_def *result = nir_load_input(b, intr->def.num_components, bit_size, nir_imm_int(b, 0),
1448       .base = var->data.driver_location, .dest_type = dest_type);
1449 
1450    /* The nir_type_uint32 is really a nir_type_bool32, but that type is very
1451     * inconvenient at this point during compilation.  Convert to
1452     * nir_type_bool1 by comparing with zero.
1453     */
1454    if (sysval == SYSTEM_VALUE_FRONT_FACE)
1455       result = nir_ine_imm(b, result, 0);
1456 
1457    nir_def_rewrite_uses(&intr->def, result);
1458    return true;
1459 }
1460 
1461 bool
1462 dxil_nir_lower_sysval_to_load_input(nir_shader *s, nir_variable **sysval_vars)
1463 {
1464    return nir_shader_intrinsics_pass(s, lower_sysval_to_load_input_impl,
1465                                      nir_metadata_block_index | nir_metadata_dominance,
1466                                      sysval_vars);
1467 }
1468 
1469 /* Comparison function to sort I/O values so that normal varyings come first,
1470  * then system values, and then system-generated values.
1471  */
1472 static int
1473 variable_location_cmp(const nir_variable* a, const nir_variable* b)
1474 {
1475    // Sort by stream, driver_location, location, location_frac, then index
1476    // If all else is equal, sort full vectors before partial ones
1477    unsigned a_location = a->data.location;
1478    if (a_location >= VARYING_SLOT_PATCH0)
1479       a_location -= VARYING_SLOT_PATCH0;
1480    unsigned b_location = b->data.location;
1481    if (b_location >= VARYING_SLOT_PATCH0)
1482       b_location -= VARYING_SLOT_PATCH0;
1483    unsigned a_stream = a->data.stream & ~NIR_STREAM_PACKED;
1484    unsigned b_stream = b->data.stream & ~NIR_STREAM_PACKED;
1485    return a_stream != b_stream ?
1486             a_stream - b_stream :
1487             a->data.driver_location != b->data.driver_location ?
1488                a->data.driver_location - b->data.driver_location :
1489                a_location !=  b_location ?
1490                   a_location - b_location :
1491                   a->data.location_frac != b->data.location_frac ?
1492                      a->data.location_frac - b->data.location_frac :
1493                      a->data.index != b->data.index ?
1494                         a->data.index - b->data.index :
1495                         glsl_get_component_slots(b->type) - glsl_get_component_slots(a->type);
1496 }
1497 
1498 /* Order varyings according to driver location */
1499 uint64_t
1500 dxil_sort_by_driver_location(nir_shader* s, nir_variable_mode modes)
1501 {
1502    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1503 
1504    uint64_t result = 0;
1505    nir_foreach_variable_with_modes(var, s, modes) {
1506       result |= 1ull << var->data.location;
1507    }
1508    return result;
1509 }
1510 
1511 /* Sort PS outputs so that color outputs come first */
1512 void
1513 dxil_sort_ps_outputs(nir_shader* s)
1514 {
1515    nir_foreach_variable_with_modes_safe(var, s, nir_var_shader_out) {
1516       /* We use the driver_location here to avoid introducing a new
1517        * struct or member variable here. The true, updated driver location
1518        * will be written below, after sorting */
1519       switch (var->data.location) {
1520       case FRAG_RESULT_DEPTH:
1521          var->data.driver_location = 1;
1522          break;
1523       case FRAG_RESULT_STENCIL:
1524          var->data.driver_location = 2;
1525          break;
1526       case FRAG_RESULT_SAMPLE_MASK:
1527          var->data.driver_location = 3;
1528          break;
1529       default:
1530          var->data.driver_location = 0;
1531       }
1532    }
1533 
1534    nir_sort_variables_with_modes(s, variable_location_cmp,
1535                                  nir_var_shader_out);
1536 
1537    unsigned driver_loc = 0;
1538    nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
1539       /* Fractional vars should use the same driver_location as the base. These will
1540        * get fully merged during signature processing.
1541        */
1542       var->data.driver_location = var->data.location_frac ? driver_loc - 1 : driver_loc++;
1543    }
1544 }
1545 
1546 enum dxil_sysvalue_type {
1547    DXIL_NO_SYSVALUE = 0,
1548    DXIL_USED_SYSVALUE,
1549    DXIL_SYSVALUE,
1550    DXIL_GENERATED_SYSVALUE
1551 };
1552 
1553 static enum dxil_sysvalue_type
1554 nir_var_to_dxil_sysvalue_type(nir_variable *var, uint64_t other_stage_mask)
1555 {
1556    switch (var->data.location) {
1557    case VARYING_SLOT_FACE:
1558       return DXIL_GENERATED_SYSVALUE;
1559    case VARYING_SLOT_POS:
1560    case VARYING_SLOT_PRIMITIVE_ID:
1561    case VARYING_SLOT_CLIP_DIST0:
1562    case VARYING_SLOT_CLIP_DIST1:
1563    case VARYING_SLOT_PSIZ:
1564    case VARYING_SLOT_TESS_LEVEL_INNER:
1565    case VARYING_SLOT_TESS_LEVEL_OUTER:
1566    case VARYING_SLOT_VIEWPORT:
1567    case VARYING_SLOT_LAYER:
1568    case VARYING_SLOT_VIEW_INDEX:
1569       if (!((1ull << var->data.location) & other_stage_mask))
1570          return DXIL_SYSVALUE;
1571       return DXIL_USED_SYSVALUE;
1572    default:
1573       return DXIL_NO_SYSVALUE;
1574    }
1575 }
1576 
1577 /* Order inter-stage values so that normal varyings come first,
1578  * then system values, and then system-generated values.
1579  */
1580 uint64_t
1581 dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
1582    uint64_t other_stage_mask)
1583 {
1584    nir_foreach_variable_with_modes_safe(var, s, modes) {
1585       /* We use the driver_location here to avoid introducing a new
1586        * struct or member variable here. The true, updated driver location
1587        * will be written below, after sorting */
1588       var->data.driver_location = nir_var_to_dxil_sysvalue_type(var, other_stage_mask);
1589    }
1590 
1591    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1592 
1593    uint64_t result = 0;
1594    unsigned driver_loc = 0, driver_patch_loc = 0;
1595    nir_foreach_variable_with_modes(var, s, modes) {
1596       if (var->data.location < 64)
1597          result |= 1ull << var->data.location;
1598       /* Overlap patches with non-patch */
1599       var->data.driver_location = var->data.patch ?
1600          driver_patch_loc++ : driver_loc++;
1601    }
1602    return result;
1603 }
1604 
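/* For UBO arrays of size 1, force the vulkan_resource_index back to a
 * constant zero; any other index on such an array is out of bounds and
 * therefore undefined.
 */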
1605 static bool
1606 lower_ubo_array_one_to_static(struct nir_builder *b,
1607                               nir_intrinsic_instr *intrin,
1608                               void *cb_data)
1609 {
1610    if (intrin->intrinsic != nir_intrinsic_load_vulkan_descriptor)
1611       return false;
1612 
1613    nir_variable *var =
1614       nir_get_binding_variable(b->shader, nir_chase_binding(intrin->src[0]));
1615 
1616    if (!var)
1617       return false;
1618 
1619    if (!glsl_type_is_array(var->type) || glsl_array_size(var->type) != 1)
1620       return false;
1621 
1622    nir_intrinsic_instr *index = nir_src_as_intrinsic(intrin->src[0]);
1623    /* We currently do not support reindex */
1624    assert(index && index->intrinsic == nir_intrinsic_vulkan_resource_index);
1625 
1626    if (nir_src_is_const(index->src[0]) && nir_src_as_uint(index->src[0]) == 0)
1627       return false;
1628 
1629    if (nir_intrinsic_desc_type(index) != VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
1630       return false;
1631 
1632    b->cursor = nir_instr_remove(&index->instr);
1633 
1634    // Indexing out of bounds on array of UBOs is considered undefined
1635    // behavior. Therefore, we just hardcode the index to 0.
1636    uint8_t bit_size = index->def.bit_size;
1637    nir_def *zero = nir_imm_intN_t(b, 0, bit_size);
1638    nir_def *dest =
1639       nir_vulkan_resource_index(b, index->num_components, bit_size, zero,
1640                                 .desc_set = nir_intrinsic_desc_set(index),
1641                                 .binding = nir_intrinsic_binding(index),
1642                                 .desc_type = nir_intrinsic_desc_type(index));
1643 
1644    nir_def_rewrite_uses(&index->def, dest);
1645 
1646    return true;
1647 }
1648 
1649 bool
1650 dxil_nir_lower_ubo_array_one_to_static(nir_shader *s)
1651 {
1652    bool progress = nir_shader_intrinsics_pass(s,
1653                                               lower_ubo_array_one_to_static,
1654                                               nir_metadata_none, NULL);
1655 
1656    return progress;
1657 }
1658 
1659 static bool
1660 is_fquantize2f16(const nir_instr *instr, const void *data)
1661 {
1662    if (instr->type != nir_instr_type_alu)
1663       return false;
1664 
1665    nir_alu_instr *alu = nir_instr_as_alu(instr);
1666    return alu->op == nir_op_fquantize2f16;
1667 }
1668 
1669 static nir_def *
1670 lower_fquantize2f16(struct nir_builder *b, nir_instr *instr, void *data)
1671 {
1672    /*
1673     * SpvOpQuantizeToF16 documentation says:
1674     *
1675     * "
1676     * If Value is an infinity, the result is the same infinity.
1677     * If Value is a NaN, the result is a NaN, but not necessarily the same NaN.
1678     * If Value is positive with a magnitude too large to represent as a 16-bit
1679     * floating-point value, the result is positive infinity. If Value is negative
1680     * with a magnitude too large to represent as a 16-bit floating-point value,
1681     * the result is negative infinity. If the magnitude of Value is too small to
1682     * represent as a normalized 16-bit floating-point value, the result may be
1683     * either +0 or -0.
1684     * "
1685     *
1686     * which we turn into:
1687     *
1688     *   if (val < MIN_FLOAT16)
1689     *      return -INFINITY;
1690     *   else if (val > MAX_FLOAT16)
1691     *      return INFINITY;
1692     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) != 0)
1693     *      return -0.0f;
1694     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) == 0)
1695     *      return +0.0f;
1696     *   else
1697     *      return round(val);
1698     */
1699    nir_alu_instr *alu = nir_instr_as_alu(instr);
1700    nir_def *src =
1701       alu->src[0].src.ssa;
1702 
1703    nir_def *neg_inf_cond =
1704       nir_flt_imm(b, src, -65504.0f);
1705    nir_def *pos_inf_cond =
1706       nir_fgt_imm(b, src, 65504.0f);
1707    nir_def *zero_cond =
1708       nir_flt_imm(b, nir_fabs(b, src), ldexpf(1.0, -14));
1709    nir_def *zero = nir_iand_imm(b, src, 1 << 31);
1710    nir_def *round = nir_iand_imm(b, src, ~BITFIELD_MASK(13));
1711 
1712    nir_def *res =
1713       nir_bcsel(b, neg_inf_cond, nir_imm_float(b, -INFINITY), round);
1714    res = nir_bcsel(b, pos_inf_cond, nir_imm_float(b, INFINITY), res);
1715    res = nir_bcsel(b, zero_cond, zero, res);
1716    return res;
1717 }
1718 
1719 bool
1720 dxil_nir_lower_fquantize2f16(nir_shader *s)
1721 {
1722    return nir_shader_lower_instructions(s, is_fquantize2f16, lower_fquantize2f16, NULL);
1723 }
1724 
1725 static bool
1726 fix_io_uint_deref_types(struct nir_builder *builder, nir_instr *instr, void *data)
1727 {
1728    if (instr->type != nir_instr_type_deref)
1729       return false;
1730 
1731    nir_deref_instr *deref = nir_instr_as_deref(instr);
1732    nir_variable *var = nir_deref_instr_get_variable(deref);
1733 
1734    if (var == data) {
1735       deref->type = glsl_type_wrap_in_arrays(glsl_uint_type(), deref->type);
1736       return true;
1737    }
1738 
1739    return false;
1740 }
1741 
1742 static bool
1743 fix_io_uint_type(nir_shader *s, nir_variable_mode modes, int slot)
1744 {
1745    nir_variable *fixed_var = NULL;
1746    nir_foreach_variable_with_modes(var, s, modes) {
1747       if (var->data.location == slot) {
1748          const struct glsl_type *plain_type = glsl_without_array(var->type);
1749          if (plain_type == glsl_uint_type())
1750             return false;
1751 
1752          assert(plain_type == glsl_int_type());
1753          var->type = glsl_type_wrap_in_arrays(glsl_uint_type(), var->type);
1754          fixed_var = var;
1755          break;
1756       }
1757    }
1758 
1759    assert(fixed_var);
1760 
1761    return nir_shader_instructions_pass(s, fix_io_uint_deref_types,
1762                                        nir_metadata_all, fixed_var);
1763 }
1764 
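/* For the slots named in the caller's input/output masks, retype int-typed
 * I/O variables to uint and patch every deref that references them.
 */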
1765 bool
1766 dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mask)
1767 {
1768    if (!(s->info.outputs_written & out_mask) &&
1769        !(s->info.inputs_read & in_mask))
1770       return false;
1771 
1772    bool progress = false;
1773 
1774    while (in_mask) {
1775       int slot = u_bit_scan64(&in_mask);
1776       progress |= (s->info.inputs_read & (1ull << slot)) &&
1777                   fix_io_uint_type(s, nir_var_shader_in, slot);
1778    }
1779 
1780    while (out_mask) {
1781       int slot = u_bit_scan64(&out_mask);
1782       progress |= (s->info.outputs_written & (1ull << slot)) &&
1783                   fix_io_uint_type(s, nir_var_shader_out, slot);
1784    }
1785 
1786    return progress;
1787 }
1788 
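/* Helpers for dxil_nir_lower_discard_and_terminate: strip any instructions
 * that follow a discard/terminate within the same block, then replace the
 * discard/terminate itself with a demote (or demote_if) followed by a return.
 */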
1789 struct remove_after_discard_state {
1790    struct nir_block *active_block;
1791 };
1792 
1793 static bool
1794 remove_after_discard(struct nir_builder *builder, nir_instr *instr,
1795                       void *cb_data)
1796 {
1797    struct remove_after_discard_state *state = cb_data;
1798    if (instr->block == state->active_block) {
1799       nir_instr_remove_v(instr);
1800       return true;
1801    }
1802 
1803    if (instr->type != nir_instr_type_intrinsic)
1804       return false;
1805 
1806    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1807 
1808    if (intr->intrinsic != nir_intrinsic_discard &&
1809        intr->intrinsic != nir_intrinsic_terminate &&
1810        intr->intrinsic != nir_intrinsic_discard_if &&
1811        intr->intrinsic != nir_intrinsic_terminate_if)
1812       return false;
1813 
1814    state->active_block = instr->block;
1815 
1816    return false;
1817 }
1818 
1819 static bool
1820 lower_kill(struct nir_builder *builder, nir_intrinsic_instr *intr,
1821            void *_cb_data)
1822 {
1823    if (intr->intrinsic != nir_intrinsic_discard &&
1824        intr->intrinsic != nir_intrinsic_terminate &&
1825        intr->intrinsic != nir_intrinsic_discard_if &&
1826        intr->intrinsic != nir_intrinsic_terminate_if)
1827       return false;
1828 
1829    builder->cursor = nir_instr_remove(&intr->instr);
1830    if (intr->intrinsic == nir_intrinsic_discard ||
1831        intr->intrinsic == nir_intrinsic_terminate) {
1832       nir_demote(builder);
1833    } else {
1834       nir_demote_if(builder, intr->src[0].ssa);
1835    }
1836 
1837    nir_jump(builder, nir_jump_return);
1838 
1839    return true;
1840 }
1841 
1842 bool
1843 dxil_nir_lower_discard_and_terminate(nir_shader *s)
1844 {
1845    if (s->info.stage != MESA_SHADER_FRAGMENT)
1846       return false;
1847 
1848    // This pass only works if all functions have been inlined
1849    assert(exec_list_length(&s->functions) == 1);
1850    struct remove_after_discard_state state;
1851    state.active_block = NULL;
1852    nir_shader_instructions_pass(s, remove_after_discard, nir_metadata_none,
1853                                 &state);
1854    return nir_shader_intrinsics_pass(s, lower_kill, nir_metadata_none,
1855                                        NULL);
1856 }
1857 
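/* Expand partial gl_Position stores into full 4-component writes, filling the
 * channels that were not written with zero.
 */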
1858 static bool
1859 update_writes(struct nir_builder *b, nir_intrinsic_instr *intr, void *_state)
1860 {
1861    if (intr->intrinsic != nir_intrinsic_store_output)
1862       return false;
1863 
1864    nir_io_semantics io = nir_intrinsic_io_semantics(intr);
1865    if (io.location != VARYING_SLOT_POS)
1866       return false;
1867 
1868    nir_def *src = intr->src[0].ssa;
1869    unsigned write_mask = nir_intrinsic_write_mask(intr);
1870    if (src->num_components == 4 && write_mask == 0xf)
1871       return false;
1872 
1873    b->cursor = nir_before_instr(&intr->instr);
1874    unsigned first_comp = nir_intrinsic_component(intr);
1875    nir_def *channels[4] = { NULL, NULL, NULL, NULL };
1876    assert(first_comp + src->num_components <= ARRAY_SIZE(channels));
1877    for (unsigned i = 0; i < src->num_components; ++i)
1878       if (write_mask & (1 << i))
1879          channels[i + first_comp] = nir_channel(b, src, i);
1880    for (unsigned i = 0; i < 4; ++i)
1881       if (!channels[i])
1882          channels[i] = nir_imm_intN_t(b, 0, src->bit_size);
1883 
1884    intr->num_components = 4;
1885    nir_src_rewrite(&intr->src[0], nir_vec(b, channels, 4));
1886    nir_intrinsic_set_component(intr, 0);
1887    nir_intrinsic_set_write_mask(intr, 0xf);
1888    return true;
1889 }
1890 
1891 bool
1892 dxil_nir_ensure_position_writes(nir_shader *s)
1893 {
1894    if (s->info.stage != MESA_SHADER_VERTEX &&
1895        s->info.stage != MESA_SHADER_GEOMETRY &&
1896        s->info.stage != MESA_SHADER_TESS_EVAL)
1897       return false;
1898    if ((s->info.outputs_written & VARYING_BIT_POS) == 0)
1899       return false;
1900 
1901    return nir_shader_intrinsics_pass(s, update_writes,
1902                                        nir_metadata_block_index | nir_metadata_dominance,
1903                                        NULL);
1904 }
1905 
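/* Lower load_sample_pos to load_sample_pos_from_id(load_sample_id). */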
1906 static bool
1907 is_sample_pos(const nir_instr *instr, const void *_data)
1908 {
1909    if (instr->type != nir_instr_type_intrinsic)
1910       return false;
1911    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1912    return intr->intrinsic == nir_intrinsic_load_sample_pos;
1913 }
1914 
1915 static nir_def *
1916 lower_sample_pos(nir_builder *b, nir_instr *instr, void *_data)
1917 {
1918    return nir_load_sample_pos_from_id(b, 32, nir_load_sample_id(b));
1919 }
1920 
1921 bool
1922 dxil_nir_lower_sample_pos(nir_shader *s)
1923 {
1924    return nir_shader_lower_instructions(s, is_sample_pos, lower_sample_pos, NULL);
1925 }
1926 
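/* Lower load_subgroup_id: for Nx1x1 compute workgroups derive it from the
 * local invocation index, otherwise number the subgroups at runtime with a
 * shared-memory atomic counter that one elected lane per subgroup increments,
 * broadcasting the result to the rest of the subgroup.
 */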
1927 static bool
1928 lower_subgroup_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1929 {
1930    if (intr->intrinsic != nir_intrinsic_load_subgroup_id)
1931       return false;
1932 
1933    b->cursor = nir_before_impl(b->impl);
1934    if (b->shader->info.stage == MESA_SHADER_COMPUTE &&
1935        b->shader->info.workgroup_size[1] == 1 &&
1936        b->shader->info.workgroup_size[2] == 1) {
1937       /* When using Nx1x1 groups, use a simple stable algorithm
1938        * which is almost guaranteed to be correct. */
1939       nir_def *subgroup_id = nir_udiv(b, nir_load_local_invocation_index(b), nir_load_subgroup_size(b));
1940       nir_def_rewrite_uses(&intr->def, subgroup_id);
1941       return true;
1942    }
1943 
1944    nir_def **subgroup_id = (nir_def **)data;
1945    if (*subgroup_id == NULL) {
1946       nir_variable *subgroup_id_counter = nir_variable_create(b->shader, nir_var_mem_shared, glsl_uint_type(), "dxil_SubgroupID_counter");
1947       nir_variable *subgroup_id_local = nir_local_variable_create(b->impl, glsl_uint_type(), "dxil_SubgroupID_local");
1948       nir_store_var(b, subgroup_id_local, nir_imm_int(b, 0), 1);
1949 
1950       nir_deref_instr *counter_deref = nir_build_deref_var(b, subgroup_id_counter);
1951       nir_def *tid = nir_load_local_invocation_index(b);
1952       nir_if *nif = nir_push_if(b, nir_ieq_imm(b, tid, 0));
1953       nir_store_deref(b, counter_deref, nir_imm_int(b, 0), 1);
1954       nir_pop_if(b, nif);
1955 
1956       nir_barrier(b,
1957                          .execution_scope = SCOPE_WORKGROUP,
1958                          .memory_scope = SCOPE_WORKGROUP,
1959                          .memory_semantics = NIR_MEMORY_ACQ_REL,
1960                          .memory_modes = nir_var_mem_shared);
1961 
1962       nif = nir_push_if(b, nir_elect(b, 1));
1963       nir_def *subgroup_id_first_thread = nir_deref_atomic(b, 32, &counter_deref->def, nir_imm_int(b, 1),
1964                                                                .atomic_op = nir_atomic_op_iadd);
1965       nir_store_var(b, subgroup_id_local, subgroup_id_first_thread, 1);
1966       nir_pop_if(b, nif);
1967 
1968       nir_def *subgroup_id_loaded = nir_load_var(b, subgroup_id_local);
1969       *subgroup_id = nir_read_first_invocation(b, subgroup_id_loaded);
1970    }
1971    nir_def_rewrite_uses(&intr->def, *subgroup_id);
1972    return true;
1973 }
1974 
1975 bool
1976 dxil_nir_lower_subgroup_id(nir_shader *s)
1977 {
1978    nir_def *subgroup_id = NULL;
1979    return nir_shader_intrinsics_pass(s, lower_subgroup_id, nir_metadata_none,
1980                                      &subgroup_id);
1981 }
1982 
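/* Lower load_num_subgroups to a runtime computation of
 * ceil(workgroup_size / subgroup_size).
 */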
1983 static bool
1984 lower_num_subgroups(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1985 {
1986    if (intr->intrinsic != nir_intrinsic_load_num_subgroups)
1987       return false;
1988 
1989    b->cursor = nir_before_instr(&intr->instr);
1990    nir_def *subgroup_size = nir_load_subgroup_size(b);
1991    nir_def *size_minus_one = nir_iadd_imm(b, subgroup_size, -1);
1992    nir_def *workgroup_size_vec = nir_load_workgroup_size(b);
1993    nir_def *workgroup_size = nir_imul(b, nir_channel(b, workgroup_size_vec, 0),
1994                                              nir_imul(b, nir_channel(b, workgroup_size_vec, 1),
1995                                                          nir_channel(b, workgroup_size_vec, 2)));
1996    nir_def *ret = nir_idiv(b, nir_iadd(b, workgroup_size, size_minus_one), subgroup_size);
1997    nir_def_rewrite_uses(&intr->def, ret);
1998    return true;
1999 }
2000 
2001 bool
2002 dxil_nir_lower_num_subgroups(nir_shader *s)
2003 {
2004    return nir_shader_intrinsics_pass(s, lower_num_subgroups,
2005                                        nir_metadata_block_index |
2006                                        nir_metadata_dominance |
2007                                        nir_metadata_loop_analysis, NULL);
2008 }
2009 
2010 
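/* Split load_deref/store_deref whose alignment is below what the access
 * needs into a series of accesses of the aligned size, done through a cast
 * deref, and reassemble or scatter the value with nir_extract_bits.
 */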
2011 static const struct glsl_type *
2012 get_cast_type(unsigned bit_size)
2013 {
2014    switch (bit_size) {
2015    case 64:
2016       return glsl_int64_t_type();
2017    case 32:
2018       return glsl_int_type();
2019    case 16:
2020       return glsl_int16_t_type();
2021    case 8:
2022       return glsl_int8_t_type();
2023    }
2024    unreachable("Invalid bit_size");
2025 }
2026 
2027 static void
2028 split_unaligned_load(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
2029 {
2030    enum gl_access_qualifier access = nir_intrinsic_access(intrin);
2031    nir_def *srcs[NIR_MAX_VEC_COMPONENTS * NIR_MAX_VEC_COMPONENTS * sizeof(int64_t) / 8];
2032    unsigned comp_size = intrin->def.bit_size / 8;
2033    unsigned num_comps = intrin->def.num_components;
2034 
2035    b->cursor = nir_before_instr(&intrin->instr);
2036 
2037    nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);
2038 
2039    const struct glsl_type *cast_type = get_cast_type(alignment * 8);
2040    nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->def, ptr->modes, cast_type, alignment);
2041 
2042    unsigned num_loads = DIV_ROUND_UP(comp_size * num_comps, alignment);
2043    for (unsigned i = 0; i < num_loads; ++i) {
2044       nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->def.bit_size));
2045       srcs[i] = nir_load_deref_with_access(b, elem, access);
2046    }
2047 
2048    nir_def *new_dest = nir_extract_bits(b, srcs, num_loads, 0, num_comps, intrin->def.bit_size);
2049    nir_def_rewrite_uses(&intrin->def, new_dest);
2050    nir_instr_remove(&intrin->instr);
2051 }
2052 
2053 static void
2054 split_unaligned_store(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
2055 {
2056    enum gl_access_qualifier access = nir_intrinsic_access(intrin);
2057 
2058    nir_def *value = intrin->src[1].ssa;
2059    unsigned comp_size = value->bit_size / 8;
2060    unsigned num_comps = value->num_components;
2061 
2062    b->cursor = nir_before_instr(&intrin->instr);
2063 
2064    nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);
2065 
2066    const struct glsl_type *cast_type = get_cast_type(alignment * 8);
2067    nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->def, ptr->modes, cast_type, alignment);
2068 
2069    unsigned num_stores = DIV_ROUND_UP(comp_size * num_comps, alignment);
2070    for (unsigned i = 0; i < num_stores; ++i) {
2071       nir_def *substore_val = nir_extract_bits(b, &value, 1, i * alignment * 8, 1, alignment * 8);
2072       nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->def.bit_size));
2073       nir_store_deref_with_access(b, elem, substore_val, ~0, access);
2074    }
2075 
2076    nir_instr_remove(&intrin->instr);
2077 }
2078 
2079 bool
2080 dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes)
2081 {
2082    bool progress = false;
2083 
2084    nir_foreach_function_impl(impl, shader) {
2085       nir_builder b = nir_builder_create(impl);
2086 
2087       nir_foreach_block(block, impl) {
2088          nir_foreach_instr_safe(instr, block) {
2089             if (instr->type != nir_instr_type_intrinsic)
2090                continue;
2091             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2092             if (intrin->intrinsic != nir_intrinsic_load_deref &&
2093                 intrin->intrinsic != nir_intrinsic_store_deref)
2094                continue;
2095             nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
2096             if (!nir_deref_mode_may_be(deref, modes))
2097                continue;
2098 
2099             unsigned align_mul = 0, align_offset = 0;
2100             nir_get_explicit_deref_align(deref, true, &align_mul, &align_offset);
2101 
2102             unsigned alignment = align_offset ? 1 << (ffs(align_offset) - 1) : align_mul;
2103 
2104             /* We can load anything at 4-byte alignment, except for
2105              * UBOs (AKA CBs where the granularity is 16 bytes).
2106              */
2107             unsigned req_align = (nir_deref_mode_is_one_of(deref, nir_var_mem_ubo | nir_var_mem_push_const) ? 16 : 4);
2108             if (alignment >= req_align)
2109                continue;
2110 
2111             nir_def *val;
2112             if (intrin->intrinsic == nir_intrinsic_load_deref) {
2113                val = &intrin->def;
2114             } else {
2115                val = intrin->src[1].ssa;
2116             }
2117 
2118             unsigned scalar_byte_size = glsl_type_is_boolean(deref->type) ? 4 : glsl_get_bit_size(deref->type) / 8;
2119             unsigned num_components =
2120                /* If the vector stride is larger than the scalar size, lower_explicit_io will
2121                 * turn this into multiple scalar loads anyway, so we don't have to split it here. */
2122                glsl_get_explicit_stride(deref->type) > scalar_byte_size ? 1 :
2123                (val->num_components == 3 ? 4 : val->num_components);
2124             unsigned natural_alignment = scalar_byte_size * num_components;
2125 
2126             if (alignment >= natural_alignment)
2127                continue;
2128 
2129             if (intrin->intrinsic == nir_intrinsic_load_deref)
2130                split_unaligned_load(&b, intrin, alignment);
2131             else
2132                split_unaligned_store(&b, intrin, alignment);
2133             progress = true;
2134          }
2135       }
2136    }
2137 
2138    return progress;
2139 }
2140 
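/* Lower subgroup scans the backend does not support directly: inclusive
 * add/mul scans become an exclusive scan plus one extra op, and scans with
 * any other reduction op are emulated with a loop over the active lanes
 * using read_invocation.
 */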
2141 static void
2142 lower_inclusive_to_exclusive(nir_builder *b, nir_intrinsic_instr *intr)
2143 {
2144    b->cursor = nir_after_instr(&intr->instr);
2145 
2146    nir_op op = nir_intrinsic_reduction_op(intr);
2147    intr->intrinsic = nir_intrinsic_exclusive_scan;
2148    nir_intrinsic_set_reduction_op(intr, op);
2149 
2150    nir_def *final_val = nir_build_alu2(b, nir_intrinsic_reduction_op(intr),
2151                                            &intr->def, intr->src[0].ssa);
2152    nir_def_rewrite_uses_after(&intr->def, final_val, final_val->parent_instr);
2153 }
2154 
2155 static bool
2156 lower_subgroup_scan(nir_builder *b, nir_intrinsic_instr *intr, void *data)
2157 {
2158    switch (intr->intrinsic) {
2159    case nir_intrinsic_exclusive_scan:
2160    case nir_intrinsic_inclusive_scan:
2161       switch ((nir_op)nir_intrinsic_reduction_op(intr)) {
2162       case nir_op_iadd:
2163       case nir_op_fadd:
2164       case nir_op_imul:
2165       case nir_op_fmul:
2166          if (intr->intrinsic == nir_intrinsic_exclusive_scan)
2167             return false;
2168          lower_inclusive_to_exclusive(b, intr);
2169          return true;
2170       default:
2171          break;
2172       }
2173       break;
2174    default:
2175       return false;
2176    }
2177 
2178    b->cursor = nir_before_instr(&intr->instr);
2179    nir_op op = nir_intrinsic_reduction_op(intr);
2180    nir_def *subgroup_id = nir_load_subgroup_invocation(b);
2181    nir_def *subgroup_size = nir_load_subgroup_size(b);
2182    nir_def *active_threads = nir_ballot(b, 4, 32, nir_imm_true(b));
2183    nir_def *base_value;
2184    uint32_t bit_size = intr->def.bit_size;
2185    if (op == nir_op_iand || op == nir_op_umin)
2186       base_value = nir_imm_intN_t(b, ~0ull, bit_size);
2187    else if (op == nir_op_imin)
2188       base_value = nir_imm_intN_t(b, (1ull << (bit_size - 1)) - 1, bit_size);
2189    else if (op == nir_op_imax)
2190       base_value = nir_imm_intN_t(b, 1ull << (bit_size - 1), bit_size);
2191    else if (op == nir_op_fmax)
2192       base_value = nir_imm_floatN_t(b, -INFINITY, bit_size);
2193    else if (op == nir_op_fmin)
2194       base_value = nir_imm_floatN_t(b, INFINITY, bit_size);
2195    else
2196       base_value = nir_imm_intN_t(b, 0, bit_size);
2197 
2198    nir_variable *loop_counter_var = nir_local_variable_create(b->impl, glsl_uint_type(), "subgroup_loop_counter");
2199    nir_variable *result_var = nir_local_variable_create(b->impl,
2200                                                         glsl_vector_type(nir_get_glsl_base_type_for_nir_type(
2201                                                            nir_op_infos[op].input_types[0] | bit_size), 1),
2202                                                         "subgroup_loop_result");
2203    nir_store_var(b, loop_counter_var, nir_imm_int(b, 0), 1);
2204    nir_store_var(b, result_var, base_value, 1);
2205    nir_loop *loop = nir_push_loop(b);
2206    nir_def *loop_counter = nir_load_var(b, loop_counter_var);
2207 
2208    nir_if *nif = nir_push_if(b, nir_ilt(b, loop_counter, subgroup_size));
2209    nir_def *other_thread_val = nir_read_invocation(b, intr->src[0].ssa, loop_counter);
2210    nir_def *thread_in_range = intr->intrinsic == nir_intrinsic_inclusive_scan ?
2211       nir_ige(b, subgroup_id, loop_counter) :
2212       nir_ilt(b, loop_counter, subgroup_id);
2213    nir_def *thread_active = nir_ballot_bitfield_extract(b, 1, active_threads, loop_counter);
2214 
2215    nir_if *if_active_thread = nir_push_if(b, nir_iand(b, thread_in_range, thread_active));
2216    nir_def *result = nir_build_alu2(b, op, nir_load_var(b, result_var), other_thread_val);
2217    nir_store_var(b, result_var, result, 1);
2218    nir_pop_if(b, if_active_thread);
2219 
2220    nir_store_var(b, loop_counter_var, nir_iadd_imm(b, loop_counter, 1), 1);
2221    nir_jump(b, nir_jump_continue);
2222    nir_pop_if(b, nif);
2223 
2224    nir_jump(b, nir_jump_break);
2225    nir_pop_loop(b, loop);
2226 
2227    result = nir_load_var(b, result_var);
2228    nir_def_rewrite_uses(&intr->def, result);
2229    return true;
2230 }
2231 
2232 bool
2233 dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s)
2234 {
2235    bool ret = nir_shader_intrinsics_pass(s, lower_subgroup_scan,
2236                                          nir_metadata_none, NULL);
2237    if (ret) {
2238       /* Lower the ballot bitfield tests */
2239       nir_lower_subgroups_options options = { .ballot_bit_size = 32, .ballot_components = 4 };
2240       nir_lower_subgroups(s, &options);
2241    }
2242    return ret;
2243 }
2244 
2245 static bool
2246 lower_load_face(nir_builder *b, nir_intrinsic_instr *intr, void *data)
2247 {
2248    if (intr->intrinsic != nir_intrinsic_load_front_face)
2249       return false;
2250 
2251    b->cursor = nir_before_instr(&intr->instr);
2252 
2253    nir_variable *var = data;
2254    nir_def *load = nir_ine_imm(b, nir_load_var(b, var), 0);
2255 
2256    nir_def_rewrite_uses(&intr->def, load);
2257    nir_instr_remove(&intr->instr);
2258    return true;
2259 }
2260 
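/* Replace load_front_face with a read of a flat uint varying
 * (VARYING_SLOT_VAR12), converted back to a boolean; the producing stage is
 * expected to write that varying.
 */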
2261 bool
2262 dxil_nir_forward_front_face(nir_shader *nir)
2263 {
2264    assert(nir->info.stage == MESA_SHADER_FRAGMENT);
2265 
2266    nir_variable *var = nir_variable_create(nir, nir_var_shader_in,
2267                                            glsl_uint_type(),
2268                                            "gl_FrontFacing");
2269    var->data.location = VARYING_SLOT_VAR12;
2270    var->data.interpolation = INTERP_MODE_FLAT;
2271 
2272    return nir_shader_intrinsics_pass(nir, lower_load_face,
2273                                        nir_metadata_block_index | nir_metadata_dominance,
2274                                        var);
2275 }
2276 
2277 static bool
2278 move_consts(nir_builder *b, nir_instr *instr, void *data)
2279 {
2280    bool progress = false;
2281    switch (instr->type) {
2282    case nir_instr_type_load_const: {
2283       /* Sink load_consts to their uses if there are multiple */
2284       nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
2285       if (!list_is_singular(&load_const->def.uses)) {
2286          nir_foreach_use_safe(src, &load_const->def) {
2287             b->cursor = nir_before_src(src);
2288             nir_load_const_instr *new_load = nir_load_const_instr_create(b->shader,
2289                                                                          load_const->def.num_components,
2290                                                                          load_const->def.bit_size);
2291             memcpy(new_load->value, load_const->value, sizeof(load_const->value[0]) * load_const->def.num_components);
2292             nir_builder_instr_insert(b, &new_load->instr);
2293             nir_src_rewrite(src, &new_load->def);
2294             progress = true;
2295          }
2296       }
2297       return progress;
2298    }
2299    default:
2300       return false;
2301    }
2302 }
2303 
2304 /* Sink all consts so that they only have a single use.
2305  * The DXIL backend will already de-dupe the constants to the
2306  * same dxil_value if they have the same type, but this allows a single constant
2307  * to have different types without bitcasts. */
2308 bool
2309 dxil_nir_move_consts(nir_shader *s)
2310 {
2311    return nir_shader_instructions_pass(s, move_consts,
2312                                        nir_metadata_block_index | nir_metadata_dominance,
2313                                        NULL);
2314 }
2315 
2316 static void
2317 clear_pass_flags(nir_function_impl *impl)
2318 {
2319    nir_foreach_block(block, impl) {
2320       nir_foreach_instr(instr, block) {
2321          instr->pass_flags = 0;
2322       }
2323    }
2324 }
2325 
2326 static bool
2327 add_def_to_worklist(nir_def *def, void *state)
2328 {
2329    nir_foreach_use_including_if(src, def) {
2330       if (nir_src_is_if(src)) {
2331          nir_if *nif = nir_src_parent_if(src);
2332          nir_foreach_block_in_cf_node(block, &nif->cf_node) {
2333             nir_foreach_instr(instr, block)
2334                nir_instr_worklist_push_tail(state, instr);
2335          }
2336       } else
2337          nir_instr_worklist_push_tail(state, nir_src_parent_instr(src));
2338    }
2339    return true;
2340 }
2341 
2342 static bool
2343 set_input_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t ***tables, const uint32_t **table_sizes)
2344 {
2345    if (intr->intrinsic == nir_intrinsic_load_view_index) {
2346       BITSET_SET(input_bits, 0);
2347       return true;
2348    }
2349 
2350    bool any_bits_set = false;
2351    nir_src *row_src = intr->intrinsic == nir_intrinsic_load_per_vertex_input ? &intr->src[1] : &intr->src[0];
2352    bool is_patch_constant = mod->shader_kind == DXIL_DOMAIN_SHADER && intr->intrinsic == nir_intrinsic_load_input;
2353    const struct dxil_signature_record *sig_rec = is_patch_constant ?
2354       &mod->patch_consts[nir_intrinsic_base(intr)] :
2355       &mod->inputs[mod->input_mappings[nir_intrinsic_base(intr)]];
2356    if (is_patch_constant) {
2357       /* Redirect to the second I/O table */
2358       *tables = *tables + 1;
2359       *table_sizes = *table_sizes + 1;
2360    }
2361    for (uint32_t component = 0; component < intr->num_components; ++component) {
2362       uint32_t base_element = 0;
2363       uint32_t num_elements = sig_rec->num_elements;
2364       if (nir_src_is_const(*row_src)) {
2365          base_element = (uint32_t)nir_src_as_uint(*row_src);
2366          num_elements = 1;
2367       }
2368       for (uint32_t element = 0; element < num_elements; ++element) {
2369          uint32_t row = sig_rec->elements[element + base_element].reg;
2370          if (row == 0xffffffff)
2371             continue;
2372          BITSET_SET(input_bits, row * 4 + component + nir_intrinsic_component(intr));
2373          any_bits_set = true;
2374       }
2375    }
2376    return any_bits_set;
2377 }
2378 
2379 static bool
2380 set_output_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t **tables, const uint32_t *table_sizes)
2381 {
2382    bool any_bits_set = false;
2383    nir_src *row_src = intr->intrinsic == nir_intrinsic_store_per_vertex_output ? &intr->src[2] : &intr->src[1];
2384    bool is_patch_constant = mod->shader_kind == DXIL_HULL_SHADER && intr->intrinsic == nir_intrinsic_store_output;
2385    const struct dxil_signature_record *sig_rec = is_patch_constant ?
2386       &mod->patch_consts[nir_intrinsic_base(intr)] :
2387       &mod->outputs[nir_intrinsic_base(intr)];
2388    for (uint32_t component = 0; component < intr->num_components; ++component) {
2389       uint32_t base_element = 0;
2390       uint32_t num_elements = sig_rec->num_elements;
2391       if (nir_src_is_const(*row_src)) {
2392          base_element = (uint32_t)nir_src_as_uint(*row_src);
2393          num_elements = 1;
2394       }
2395       for (uint32_t element = 0; element < num_elements; ++element) {
2396          uint32_t row = sig_rec->elements[element + base_element].reg;
2397          if (row == 0xffffffff)
2398             continue;
2399          uint32_t stream = sig_rec->elements[element + base_element].stream;
2400          uint32_t table_idx = is_patch_constant ? 1 : stream;
2401          uint32_t *table = tables[table_idx];
2402          uint32_t output_component = component + nir_intrinsic_component(intr);
2403          uint32_t input_component;
2404          BITSET_FOREACH_SET(input_component, input_bits, 32 * 4) {
2405             uint32_t *table_for_input_component = table + table_sizes[table_idx] * input_component;
2406             BITSET_SET(table_for_input_component, row * 4 + output_component);
2407             any_bits_set = true;
2408          }
2409       }
2410    }
2411    return any_bits_set;
2412 }
2413 
2414 static bool
2415 propagate_input_to_output_dependencies(struct dxil_module *mod, nir_intrinsic_instr *load_intr, uint32_t **tables, const uint32_t *table_sizes)
2416 {
2417    /* Which input components are being loaded by this instruction */
2418    BITSET_DECLARE(input_bits, 32 * 4) = { 0 };
2419    if (!set_input_bits(mod, load_intr, input_bits, &tables, &table_sizes))
2420       return false;
2421 
2422    nir_instr_worklist *worklist = nir_instr_worklist_create();
2423    nir_instr_worklist_push_tail(worklist, &load_intr->instr);
2424    bool any_bits_set = false;
2425    nir_foreach_instr_in_worklist(instr, worklist) {
2426       if (instr->pass_flags)
2427          continue;
2428 
2429       instr->pass_flags = 1;
2430       nir_foreach_def(instr, add_def_to_worklist, worklist);
2431       switch (instr->type) {
2432       case nir_instr_type_jump: {
2433          nir_jump_instr *jump = nir_instr_as_jump(instr);
2434          switch (jump->type) {
2435          case nir_jump_break:
2436          case nir_jump_continue: {
2437             nir_cf_node *parent = &instr->block->cf_node;
2438             while (parent->type != nir_cf_node_loop)
2439                parent = parent->parent;
2440             nir_foreach_block_in_cf_node(block, parent)
2441                nir_foreach_instr(i, block)
2442                nir_instr_worklist_push_tail(worklist, i);
2443             }
2444             break;
2445          default:
2446             unreachable("Don't expect any other jumps");
2447          }
2448          break;
2449       }
2450       case nir_instr_type_intrinsic: {
2451          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2452          switch (intr->intrinsic) {
2453          case nir_intrinsic_store_output:
2454          case nir_intrinsic_store_per_vertex_output:
2455             any_bits_set |= set_output_bits(mod, intr, input_bits, tables, table_sizes);
2456             break;
2457             /* TODO: Memory writes */
2458          default:
2459             break;
2460          }
2461          break;
2462       }
2463       default:
2464          break;
2465       }
2466    }
2467 
2468    nir_instr_worklist_destroy(worklist);
2469    return any_bits_set;
2470 }
2471 
2472 /* For every input load, compute the set of output stores that it can contribute to.
2473  * If it contributes to a store to memory, or if it's used for control flow, then any
2474  * instruction in the CFG that it impacts is considered to contribute.
2475  * Ideally, we should also handle stores to outputs/memory and then loads from that
2476  * output/memory, but this is non-trivial and unclear how much impact that would have. */
2477 bool
2478 dxil_nir_analyze_io_dependencies(struct dxil_module *mod, nir_shader *s)
2479 {
2480    bool any_outputs = false;
2481    for (uint32_t i = 0; i < 4; ++i)
2482       any_outputs |= mod->num_psv_outputs[i] > 0;
2483    if (mod->shader_kind == DXIL_HULL_SHADER)
2484       any_outputs |= mod->num_psv_patch_consts > 0;
2485    if (!any_outputs)
2486       return false;
2487 
2488    bool any_bits_set = false;
2489    nir_foreach_function(func, s) {
2490       assert(func->impl);
2491       /* Hull shaders have a patch constant function */
2492       assert(func->is_entrypoint || s->info.stage == MESA_SHADER_TESS_CTRL);
2493 
2494       /* Pass 1: input/view ID -> output dependencies */
2495       nir_foreach_block(block, func->impl) {
2496          nir_foreach_instr(instr, block) {
2497             if (instr->type != nir_instr_type_intrinsic)
2498                continue;
2499             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2500             uint32_t **tables = mod->io_dependency_table;
2501             const uint32_t *table_sizes = mod->dependency_table_dwords_per_input;
2502             switch (intr->intrinsic) {
2503             case nir_intrinsic_load_view_index:
2504                tables = mod->viewid_dependency_table;
2505                FALLTHROUGH;
2506             case nir_intrinsic_load_input:
2507             case nir_intrinsic_load_per_vertex_input:
2508             case nir_intrinsic_load_interpolated_input:
2509                break;
2510             default:
2511                continue;
2512             }
2513 
2514             clear_pass_flags(func->impl);
2515             any_bits_set |= propagate_input_to_output_dependencies(mod, intr, tables, table_sizes);
2516          }
2517       }
2518 
2519       /* Pass 2: output -> output dependencies */
2520       /* TODO */
2521    }
2522    return any_bits_set;
2523 }
2524 
2525 static enum pipe_format
2526 get_format_for_var(unsigned num_comps, enum glsl_base_type sampled_type)
2527 {
2528    switch (sampled_type) {
2529    case GLSL_TYPE_INT:
2530    case GLSL_TYPE_INT64:
2531    case GLSL_TYPE_INT16:
2532       switch (num_comps) {
2533       case 1: return PIPE_FORMAT_R32_SINT;
2534       case 2: return PIPE_FORMAT_R32G32_SINT;
2535       case 3: return PIPE_FORMAT_R32G32B32_SINT;
2536       case 4: return PIPE_FORMAT_R32G32B32A32_SINT;
2537       default: unreachable("Invalid num_comps");
2538       }
2539    case GLSL_TYPE_UINT:
2540    case GLSL_TYPE_UINT64:
2541    case GLSL_TYPE_UINT16:
2542       switch (num_comps) {
2543       case 1: return PIPE_FORMAT_R32_UINT;
2544       case 2: return PIPE_FORMAT_R32G32_UINT;
2545       case 3: return PIPE_FORMAT_R32G32B32_UINT;
2546       case 4: return PIPE_FORMAT_R32G32B32A32_UINT;
2547       default: unreachable("Invalid num_comps");
2548       }
2549    case GLSL_TYPE_FLOAT:
2550    case GLSL_TYPE_FLOAT16:
2551    case GLSL_TYPE_DOUBLE:
2552       switch (num_comps) {
2553       case 1: return PIPE_FORMAT_R32_FLOAT;
2554       case 2: return PIPE_FORMAT_R32G32_FLOAT;
2555       case 3: return PIPE_FORMAT_R32G32B32_FLOAT;
2556       case 4: return PIPE_FORMAT_R32G32B32A32_FLOAT;
2557       default: unreachable("Invalid num_comps");
2558       }
2559    default: unreachable("Invalid sampler return type");
2560    }
2561 }
2562 
2563 static unsigned
2564 aoa_size(const struct glsl_type *type)
2565 {
2566    return glsl_type_is_array(type) ? glsl_get_aoa_size(type) : 1;
2567 }
2568 
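/* For images declared without an explicit format, infer an R32-based format
 * from how the image is actually used: widen to the widest load/store seen,
 * force a single component as soon as any atomic touches it, and fall back
 * to four components when nothing else is known. The chosen format is then
 * propagated onto the image intrinsics.
 */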
2569 static bool
2570 guess_image_format_for_var(nir_shader *s, nir_variable *var)
2571 {
2572    const struct glsl_type *base_type = glsl_without_array(var->type);
2573    if (!glsl_type_is_image(base_type))
2574       return false;
2575    if (var->data.image.format != PIPE_FORMAT_NONE)
2576       return false;
2577 
2578    nir_foreach_function_impl(impl, s) {
2579       nir_foreach_block(block, impl) {
2580          nir_foreach_instr(instr, block) {
2581             if (instr->type != nir_instr_type_intrinsic)
2582                continue;
2583             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2584             switch (intr->intrinsic) {
2585             case nir_intrinsic_image_deref_load:
2586             case nir_intrinsic_image_deref_store:
2587             case nir_intrinsic_image_deref_atomic:
2588             case nir_intrinsic_image_deref_atomic_swap:
2589                if (nir_intrinsic_get_var(intr, 0) != var)
2590                   continue;
2591                break;
2592             case nir_intrinsic_image_load:
2593             case nir_intrinsic_image_store:
2594             case nir_intrinsic_image_atomic:
2595             case nir_intrinsic_image_atomic_swap: {
2596                unsigned binding = nir_src_as_uint(intr->src[0]);
2597                if (binding < var->data.binding ||
2598                    binding >= var->data.binding + aoa_size(var->type))
2599                   continue;
2600                break;
2601                }
2602             default:
2603                continue;
2604             }
2606 
2607             switch (intr->intrinsic) {
2608             case nir_intrinsic_image_deref_load:
2609             case nir_intrinsic_image_load:
2610             case nir_intrinsic_image_deref_store:
2611             case nir_intrinsic_image_store:
2612                /* Increase unknown formats up to 4 components if a 4-component accessor is used */
2613                if (intr->num_components > util_format_get_nr_components(var->data.image.format))
2614                   var->data.image.format = get_format_for_var(intr->num_components, glsl_get_sampler_result_type(base_type));
2615                break;
2616             default:
2617                /* If an atomic is used, the image format must be 1-component; return immediately */
2618                var->data.image.format = get_format_for_var(1, glsl_get_sampler_result_type(base_type));
2619                return true;
2620             }
2621          }
2622       }
2623    }
2624    /* No accesses were found; assume a 4-component format */
2625    if (var->data.image.format == PIPE_FORMAT_NONE)
2626       var->data.image.format = get_format_for_var(4, glsl_get_sampler_result_type(base_type));
2627    return true;
2628 }
2629 
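/* Copy the (possibly guessed) format and matching ALU type from the image
 * variable onto a single intrinsic.
 */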
2630 static void
2631 update_intrinsic_format_and_type(nir_intrinsic_instr *intr, nir_variable *var)
2632 {
2633    nir_intrinsic_set_format(intr, var->data.image.format);
2634    nir_alu_type alu_type =
2635       nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(glsl_without_array(var->type)));
2636    if (nir_intrinsic_has_src_type(intr))
2637       nir_intrinsic_set_src_type(intr, alu_type);
2638    else if (nir_intrinsic_has_dest_type(intr))
2639       nir_intrinsic_set_dest_type(intr, alu_type);
2640 }
2641 
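/* Per-intrinsic callback: locate the image variable behind intr->src[0],
 * either through its deref chain or by matching a constant binding index
 * against the variables' binding ranges, and stamp that variable's format
 * onto the intrinsic.
 */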
2642 static bool
2643 update_intrinsic_formats(nir_builder *b, nir_intrinsic_instr *intr,
2644                          void *data)
2645 {
2646    if (!nir_intrinsic_has_format(intr))
2647       return false;
2648    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
2649    if (deref) {
2650       nir_variable *var = nir_deref_instr_get_variable(deref);
2651       if (var)
2652          update_intrinsic_format_and_type(intr, var);
2653       return var != NULL;
2654    }
2655 
2656    if (!nir_intrinsic_has_range_base(intr))
2657       return false;
2658 
2659    unsigned binding = nir_src_as_uint(intr->src[0]);
2660    nir_foreach_variable_with_modes(var, b->shader, nir_var_image) {
2661       if (var->data.binding <= binding &&
2662           var->data.binding + aoa_size(var->type) > binding) {
2663          update_intrinsic_format_and_type(intr, var);
2664          return true;
2665       }
2666    }
2667    return false;
2668 }
2669 
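/* Entry point: guess a format for every format-less image variable, then
 * propagate the chosen formats onto the image intrinsics.  A hypothetical
 * call site (not part of this file) would run it as a regular NIR pass:
 *
 *    NIR_PASS(progress, shader, dxil_nir_guess_image_formats);
 */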
2670 bool
2671 dxil_nir_guess_image_formats(nir_shader *s)
2672 {
2673    bool progress = false;
2674    nir_foreach_variable_with_modes(var, s, nir_var_image) {
2675       progress |= guess_image_format_for_var(s, var);
2676    }
2677    nir_shader_intrinsics_pass(s, update_intrinsic_formats, nir_metadata_all,
2678                               NULL);
2679    return progress;
2680 }
2681 
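/* Mark every variable that lives at the given descriptor-set/binding pair as
 * coherent.
 */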
2682 static void
2683 set_binding_variables_coherent(nir_shader *s, nir_binding binding, nir_variable_mode modes)
2684 {
2685    nir_foreach_variable_with_modes(var, s, modes) {
2686       if (var->data.binding == binding.binding &&
2687           var->data.descriptor_set == binding.desc_set) {
2688          var->data.access |= ACCESS_COHERENT;
2689       }
2690    }
2691 }
2692 
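/* Walk a deref chain back to its variable and mark that variable coherent.
 * If the chain ends in a cast (a pre-lowered Vulkan SSBO access), chase the
 * descriptor back to its binding and mark the matching SSBO variables
 * instead.
 */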
2693 static void
2694 set_deref_variables_coherent(nir_shader *s, nir_deref_instr *deref)
2695 {
2696    while (deref->deref_type != nir_deref_type_var &&
2697           deref->deref_type != nir_deref_type_cast) {
2698       deref = nir_deref_instr_parent(deref);
2699    }
2700    if (deref->deref_type == nir_deref_type_var) {
2701       deref->var->data.access |= ACCESS_COHERENT;
2702       return;
2703    }
2704 
2705    /* For derefs with casts, we only support pre-lowered Vulkan accesses */
2706    assert(deref->deref_type == nir_deref_type_cast);
2707    nir_intrinsic_instr *cast_src = nir_instr_as_intrinsic(deref->parent.ssa->parent_instr);
2708    assert(cast_src->intrinsic == nir_intrinsic_load_vulkan_descriptor);
2709    nir_binding binding = nir_chase_binding(cast_src->src[0]);
2710    set_binding_variables_coherent(s, binding, nir_var_mem_ssbo);
2711 }
2712 
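/* Build the atomic equivalent of a load or store: loads become an atomic add
 * of zero, stores become an atomic exchange whose result is ignored.  The
 * point is that an atomic access is inherently coherent, so it can stand in
 * for a coherent load or store.
 */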
2713 static nir_def *
2714 get_atomic_for_load_store(nir_builder *b, nir_intrinsic_instr *intr, unsigned bit_size)
2715 {
2716    nir_def *zero = nir_imm_intN_t(b, 0, bit_size);
2717    switch (intr->intrinsic) {
2718    case nir_intrinsic_load_deref:
2719       return nir_deref_atomic(b, bit_size, intr->src[0].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2720    case nir_intrinsic_load_ssbo:
2721       return nir_ssbo_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2722    case nir_intrinsic_image_deref_load:
2723       return nir_image_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2724    case nir_intrinsic_image_load:
2725       return nir_image_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2726    case nir_intrinsic_store_deref:
2727       return nir_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, .atomic_op = nir_atomic_op_xchg);
2728    case nir_intrinsic_store_ssbo:
2729       return nir_ssbo_atomic(b, bit_size, intr->src[1].ssa, intr->src[2].ssa, intr->src[0].ssa, .atomic_op = nir_atomic_op_xchg);
2730    case nir_intrinsic_image_deref_store:
2731       return nir_image_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, intr->src[3].ssa, .atomic_op = nir_atomic_op_xchg);
2732    case nir_intrinsic_image_store:
2733       return nir_image_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, intr->src[3].ssa, .atomic_op = nir_atomic_op_xchg);
2734    default:
2735       return NULL;
2736    }
2737 }
2738 
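/* Rewrite a coherent load/store as its atomic equivalent where possible.
 * Only single-component accesses of at least 32 bits can be expressed that
 * way; anything narrower or wider instead has the underlying variable marked
 * coherent.  A simplified sketch of the rewrite (pseudo-NIR):
 *
 *    %x = load_ssbo %buf, %off (access=coherent)
 * becomes
 *    %x = ssbo_atomic %buf, %off, 0 (atomic_op=iadd, access=coherent)
 */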
2739 static bool
2740 lower_coherent_load_store(nir_builder *b, nir_intrinsic_instr *intr, void *context)
2741 {
2742    if (!nir_intrinsic_has_access(intr) || (nir_intrinsic_access(intr) & ACCESS_COHERENT) == 0)
2743       return false;
2744 
2745    nir_def *atomic_def = NULL;
2746    b->cursor = nir_before_instr(&intr->instr);
2747    switch (intr->intrinsic) {
2748    case nir_intrinsic_load_deref:
2749    case nir_intrinsic_load_ssbo:
2750    case nir_intrinsic_image_deref_load:
2751    case nir_intrinsic_image_load: {
2752       if (intr->def.bit_size < 32 || intr->def.num_components > 1) {
2753          if (intr->intrinsic == nir_intrinsic_load_deref)
2754             set_deref_variables_coherent(b->shader, nir_src_as_deref(intr->src[0]));
2755          else {
2756             nir_binding binding = {0};
2757             if (nir_src_is_const(intr->src[0]))
2758                binding.binding = nir_src_as_uint(intr->src[0]);
2759             set_binding_variables_coherent(b->shader, binding,
2760                                            intr->intrinsic == nir_intrinsic_load_ssbo ? nir_var_mem_ssbo : nir_var_image);
2761          }
2762          return false;
2763       }
2764 
2765       atomic_def = get_atomic_for_load_store(b, intr, intr->def.bit_size);
2766       nir_def_rewrite_uses(&intr->def, atomic_def);
2767       break;
2768    }
2769    case nir_intrinsic_store_deref:
2770    case nir_intrinsic_store_ssbo:
2771    case nir_intrinsic_image_deref_store:
2772    case nir_intrinsic_image_store: {
2773       int resource_idx = intr->intrinsic == nir_intrinsic_store_ssbo ? 1 : 0;
2774       int value_idx = intr->intrinsic == nir_intrinsic_store_ssbo ? 0 :
2775          intr->intrinsic == nir_intrinsic_store_deref ? 1 : 3;
2776       unsigned num_components = nir_intrinsic_has_write_mask(intr) ?
2777          util_bitcount(nir_intrinsic_write_mask(intr)) : intr->src[value_idx].ssa->num_components;
2778       if (intr->src[value_idx].ssa->bit_size < 32 || num_components > 1) {
2779          if (intr->intrinsic == nir_intrinsic_store_deref)
2780             set_deref_variables_coherent(b->shader, nir_src_as_deref(intr->src[resource_idx]));
2781          else {
2782             nir_binding binding = {0};
2783             if (nir_src_is_const(intr->src[resource_idx]))
2784                binding.binding = nir_src_as_uint(intr->src[resource_idx]);
2785             set_binding_variables_coherent(b->shader, binding,
2786                                            intr->intrinsic == nir_intrinsic_store_ssbo ? nir_var_mem_ssbo : nir_var_image);
2787          }
2788          return false;
2789       }
2790 
2791       atomic_def = get_atomic_for_load_store(b, intr, intr->src[value_idx].ssa->bit_size);
2792       break;
2793    }
2794    default:
2795       return false;
2796    }
2797 
2798    nir_intrinsic_instr *atomic = nir_instr_as_intrinsic(atomic_def->parent_instr);
2799    nir_intrinsic_set_access(atomic, nir_intrinsic_access(intr));
2800    if (nir_intrinsic_has_image_dim(intr))
2801       nir_intrinsic_set_image_dim(atomic, nir_intrinsic_image_dim(intr));
2802    if (nir_intrinsic_has_image_array(intr))
2803       nir_intrinsic_set_image_array(atomic, nir_intrinsic_image_array(intr));
2804    if (nir_intrinsic_has_format(intr))
2805       nir_intrinsic_set_format(atomic, nir_intrinsic_format(intr));
2806    if (nir_intrinsic_has_range_base(intr))
2807       nir_intrinsic_set_range_base(atomic, nir_intrinsic_range_base(intr));
2808    nir_instr_remove(&intr->instr);
2809    return true;
2810 }
2811 
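/* Entry point for the coherent load/store lowering above.  Instructions are
 * only replaced in place, so block-index, dominance and loop-analysis
 * metadata are preserved.  Hypothetical call site (not part of this file):
 *
 *    NIR_PASS(progress, shader, dxil_nir_lower_coherent_loads_and_stores);
 */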
2812 bool
2813 dxil_nir_lower_coherent_loads_and_stores(nir_shader *s)
2814 {
2815    return nir_shader_intrinsics_pass(s, lower_coherent_load_store,
2816                                      nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis,
2817                                      NULL);
2818 }
2819