1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "dxil_nir.h"
25 #include "dxil_module.h"
26 
27 #include "nir_builder.h"
28 #include "nir_deref.h"
29 #include "nir_worklist.h"
30 #include "nir_to_dxil.h"
31 #include "util/u_math.h"
32 #include "vulkan/vulkan_core.h"
33 
34 static void
35 cl_type_size_align(const struct glsl_type *type, unsigned *size,
36                    unsigned *align)
37 {
38    *size = glsl_get_cl_size(type);
39    *align = glsl_get_cl_alignment(type);
40 }
41 
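/* Packs an array of scalar components of src_bit_size into a vector whose
 * components are dst_bit_size wide (or splits them with nir_extract_bits when
 * the source components are wider). Illustrative example: with src_bit_size = 8,
 * num_src_comps = 4 and dst_bit_size = 32, the result is a single u32 equal to
 * c0 | (c1 << 8) | (c2 << 16) | (c3 << 24).
 */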
42 static nir_def *
43 load_comps_to_vec(nir_builder *b, unsigned src_bit_size,
44                   nir_def **src_comps, unsigned num_src_comps,
45                   unsigned dst_bit_size)
46 {
47    if (src_bit_size == dst_bit_size)
48       return nir_vec(b, src_comps, num_src_comps);
49    else if (src_bit_size > dst_bit_size)
50       return nir_extract_bits(b, src_comps, num_src_comps, 0, src_bit_size * num_src_comps / dst_bit_size, dst_bit_size);
51 
52    unsigned num_dst_comps = DIV_ROUND_UP(num_src_comps * src_bit_size, dst_bit_size);
53    unsigned comps_per_dst = dst_bit_size / src_bit_size;
54    nir_def *dst_comps[4];
55 
56    for (unsigned i = 0; i < num_dst_comps; i++) {
57       unsigned src_offs = i * comps_per_dst;
58 
59       dst_comps[i] = nir_u2uN(b, src_comps[src_offs], dst_bit_size);
60       for (unsigned j = 1; j < comps_per_dst && src_offs + j < num_src_comps; j++) {
61          nir_def *tmp = nir_ishl_imm(b, nir_u2uN(b, src_comps[src_offs + j], dst_bit_size),
62                                          j * src_bit_size);
63          dst_comps[i] = nir_ior(b, dst_comps[i], tmp);
64       }
65    }
66 
67    return nir_vec(b, dst_comps, num_dst_comps);
68 }
69 
70 static bool
71 lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
72 {
73    unsigned bit_size = intr->def.bit_size;
74    unsigned num_components = intr->def.num_components;
75    unsigned num_bits = num_components * bit_size;
76 
77    b->cursor = nir_before_instr(&intr->instr);
78 
79    nir_def *offset = intr->src[0].ssa;
80    if (intr->intrinsic == nir_intrinsic_load_shared)
81       offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
82    else
83       offset = nir_u2u32(b, offset);
84    nir_def *index = nir_ushr_imm(b, offset, 2);
85    nir_def *comps[NIR_MAX_VEC_COMPONENTS];
86    nir_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];
87 
88    /* We need to split loads into 32-bit accesses because the buffer
89     * is an i32 array and DXIL does not support type casts.
90     */
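   /* Illustrative example: a single 16-bit load at byte offset 6 becomes a
    * load of i32 element 1 (offset >> 2), a right-shift by (6 & 3) * 8 = 16
    * bits, and an extraction of the low 16 bits of that word.
    */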
91    unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
92    for (unsigned i = 0; i < num_32bit_comps; i++)
93       comps_32bit[i] = nir_load_array_var(b, var, nir_iadd_imm(b, index, i));
94    unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);
95 
96    for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
97       unsigned num_vec32_comps = MIN2(num_32bit_comps - i, 4);
98       unsigned num_dest_comps = num_vec32_comps * 32 / bit_size;
99       nir_def *vec32 = nir_vec(b, &comps_32bit[i], num_vec32_comps);
100 
101       /* If we have 16 bits or less to load we need to adjust the u32 value so
102        * we can always extract the LSB.
103        */
104       if (num_bits <= 16) {
105          nir_def *shift =
106             nir_imul_imm(b, nir_iand_imm(b, offset, 3), 8);
107          vec32 = nir_ushr(b, vec32, shift);
108       }
109 
110       /* And now comes the pack/unpack step to match the original type. */
111       unsigned dest_index = i * 32 / bit_size;
112       nir_def *temp_vec = nir_extract_bits(b, &vec32, 1, 0, num_dest_comps, bit_size);
113       for (unsigned comp = 0; comp < num_dest_comps; ++comp, ++dest_index)
114          comps[dest_index] = nir_channel(b, temp_vec, comp);
115    }
116 
117    nir_def *result = nir_vec(b, comps, num_components);
118    nir_def_replace(&intr->def, result);
119 
120    return true;
121 }
122 
123 static void
124 lower_masked_store_vec32(nir_builder *b, nir_def *offset, nir_def *index,
125                          nir_def *vec32, unsigned num_bits, nir_variable *var, unsigned alignment)
126 {
127    nir_def *mask = nir_imm_int(b, (1 << num_bits) - 1);
128 
129    /* If we have small alignments, we need to place them correctly in the u32 component. */
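   /* Illustrative example: a 16-bit store at byte offset 2 has mask = 0xffff
    * and shift = (2 & 3) * 8 = 16, so the value and the mask both move into
    * the upper half of the u32 slot while the iand/ior below preserves the
    * lower half.
    */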
130    if (alignment <= 2) {
131       nir_def *shift =
132          nir_imul_imm(b, nir_iand_imm(b, offset, 3), 8);
133 
134       vec32 = nir_ishl(b, vec32, shift);
135       mask = nir_ishl(b, mask, shift);
136    }
137 
138    if (var->data.mode == nir_var_mem_shared) {
139       /* Use the dedicated masked intrinsic */
140       nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
141       nir_deref_atomic(b, 32, &deref->def, nir_inot(b, mask), .atomic_op = nir_atomic_op_iand);
142       nir_deref_atomic(b, 32, &deref->def, vec32, .atomic_op = nir_atomic_op_ior);
143    } else {
144       /* For scratch, since we don't need atomics, just generate the read-modify-write in NIR */
145       nir_def *load = nir_load_array_var(b, var, index);
146 
147       nir_def *new_val = nir_ior(b, vec32,
148                                      nir_iand(b,
149                                               nir_inot(b, mask),
150                                               load));
151 
152       nir_store_array_var(b, var, index, new_val, 1);
153    }
154 }
155 
156 static bool
157 lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
158 {
159    unsigned num_components = nir_src_num_components(intr->src[0]);
160    unsigned bit_size = nir_src_bit_size(intr->src[0]);
161    unsigned num_bits = num_components * bit_size;
162 
163    b->cursor = nir_before_instr(&intr->instr);
164 
165    nir_def *offset = intr->src[1].ssa;
166    if (intr->intrinsic == nir_intrinsic_store_shared)
167       offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
168    else
169       offset = nir_u2u32(b, offset);
170    nir_def *comps[NIR_MAX_VEC_COMPONENTS];
171 
172    unsigned comp_idx = 0;
173    for (unsigned i = 0; i < num_components; i++)
174       comps[i] = nir_channel(b, intr->src[0].ssa, i);
175 
176    unsigned step = MAX2(bit_size, 32);
177    for (unsigned i = 0; i < num_bits; i += step) {
178       /* For each 4byte chunk (or smaller) we generate a 32bit scalar store. */
179       unsigned substore_num_bits = MIN2(num_bits - i, step);
180       nir_def *local_offset = nir_iadd_imm(b, offset, i / 8);
181       nir_def *vec32 = load_comps_to_vec(b, bit_size, &comps[comp_idx],
182                                              substore_num_bits / bit_size, 32);
183       nir_def *index = nir_ushr_imm(b, local_offset, 2);
184 
185       /* For anything less than 32bits we need to use the masked version of the
186        * intrinsic to preserve data living in the same 32bit slot. */
187       if (substore_num_bits < 32) {
188          lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, var, nir_intrinsic_align(intr));
189       } else {
190          for (unsigned i = 0; i < vec32->num_components; ++i)
191             nir_store_array_var(b, var, nir_iadd_imm(b, index, i), nir_channel(b, vec32, i), 1);
192       }
193 
194       comp_idx += substore_num_bits / bit_size;
195    }
196 
197    nir_instr_remove(&intr->instr);
198 
199    return true;
200 }
201 
202 #define CONSTANT_LOCATION_UNVISITED 0
203 #define CONSTANT_LOCATION_VALID 1
204 #define CONSTANT_LOCATION_INVALID 2
205 
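/* var->data.location is repurposed as per-variable scratch state for this pass:
 * constants start UNVISITED (or INVALID when they lack an initializer), become
 * INVALID on any complex use, and only variables left VALID are converted to
 * nir_var_shader_temp.
 */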
206 bool
207 dxil_nir_lower_constant_to_temp(nir_shader *nir)
208 {
209    bool progress = false;
210    nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant)
211       var->data.location = var->constant_initializer ?
212          CONSTANT_LOCATION_UNVISITED : CONSTANT_LOCATION_INVALID;
213 
214    /* First pass: collect all UBO accesses that could be turned into
215     * shader temp accesses.
216     */
217    nir_foreach_function(func, nir) {
218       if (!func->is_entrypoint)
219          continue;
220       assert(func->impl);
221 
222       nir_foreach_block(block, func->impl) {
223          nir_foreach_instr_safe(instr, block) {
224             if (instr->type != nir_instr_type_deref)
225                continue;
226 
227             nir_deref_instr *deref = nir_instr_as_deref(instr);
228             if (!nir_deref_mode_is(deref, nir_var_mem_constant) ||
229                 deref->deref_type != nir_deref_type_var ||
230                 deref->var->data.location == CONSTANT_LOCATION_INVALID)
231                continue;
232 
233             deref->var->data.location = nir_deref_instr_has_complex_use(deref, 0) ?
234                CONSTANT_LOCATION_INVALID : CONSTANT_LOCATION_VALID;
235          }
236       }
237    }
238 
239    nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant) {
240       if (var->data.location != CONSTANT_LOCATION_VALID)
241          continue;
242 
243       /* Change the variable mode. */
244       var->data.mode = nir_var_shader_temp;
245 
246       progress = true;
247    }
248 
249    /* Second pass: patch all derefs that were accessing the converted UBOs
250     * variables.
251     */
252    nir_foreach_function(func, nir) {
253       if (!func->is_entrypoint)
254          continue;
255       assert(func->impl);
256 
257       nir_builder b = nir_builder_create(func->impl);
258       nir_foreach_block(block, func->impl) {
259          nir_foreach_instr_safe(instr, block) {
260             if (instr->type != nir_instr_type_deref)
261                continue;
262 
263             nir_deref_instr *deref = nir_instr_as_deref(instr);
264             if (nir_deref_mode_is(deref, nir_var_mem_constant)) {
265                nir_deref_instr *parent = deref;
266                while (parent && parent->deref_type != nir_deref_type_var)
267                   parent = nir_src_as_deref(parent->parent);
268                if (parent && parent->var->data.mode != nir_var_mem_constant) {
269                   deref->modes = parent->var->data.mode;
270                   /* Also change "pointer" size to 32-bit since this is now a logical pointer */
271                   deref->def.bit_size = 32;
272                   if (deref->deref_type == nir_deref_type_array) {
273                      b.cursor = nir_before_instr(instr);
274                      nir_src_rewrite(&deref->arr.index, nir_u2u32(&b, deref->arr.index.ssa));
275                   }
276                }
277             }
278          }
279       }
280    }
281 
282    return progress;
283 }
284 
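/* Rewrites loads, stores and atomics on variables that flatten_var_array_types
 * turned into 1-D scalar arrays: each array level of the deref chain contributes
 * arr.index times the scalar slot count of its element type to a single flat
 * index, and vector loads/stores are split into one scalar deref per (written)
 * component.
 */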
285 static bool
286 flatten_var_arrays(nir_builder *b, nir_intrinsic_instr *intr, void *data)
287 {
288    switch (intr->intrinsic) {
289    case nir_intrinsic_load_deref:
290    case nir_intrinsic_store_deref:
291    case nir_intrinsic_deref_atomic:
292    case nir_intrinsic_deref_atomic_swap:
293       break;
294    default:
295       return false;
296    }
297 
298    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
299    nir_variable *var = NULL;
300    for (nir_deref_instr *d = deref; d; d = nir_deref_instr_parent(d)) {
301       if (d->deref_type == nir_deref_type_cast)
302          return false;
303       if (d->deref_type == nir_deref_type_var) {
304          var = d->var;
305          if (d->type == var->type)
306             return false;
307       }
308    }
309    if (!var)
310       return false;
311 
312    nir_deref_path path;
313    nir_deref_path_init(&path, deref, NULL);
314 
315    assert(path.path[0]->deref_type == nir_deref_type_var);
316    b->cursor = nir_before_instr(&path.path[0]->instr);
317    nir_deref_instr *new_var_deref = nir_build_deref_var(b, var);
318    nir_def *index = NULL;
319    for (unsigned level = 1; path.path[level]; ++level) {
320       nir_deref_instr *arr_deref = path.path[level];
321       assert(arr_deref->deref_type == nir_deref_type_array);
322       b->cursor = nir_before_instr(&arr_deref->instr);
323       nir_def *val = nir_imul_imm(b, arr_deref->arr.index.ssa,
324                                       glsl_get_component_slots(arr_deref->type));
325       if (index) {
326          index = nir_iadd(b, index, val);
327       } else {
328          index = val;
329       }
330    }
331 
332    unsigned vector_comps = intr->num_components;
333    if (vector_comps > 1) {
334       b->cursor = nir_before_instr(&intr->instr);
335       if (intr->intrinsic == nir_intrinsic_load_deref) {
336          nir_def *components[NIR_MAX_VEC_COMPONENTS];
337          for (unsigned i = 0; i < vector_comps; ++i) {
338             nir_def *final_index = index ? nir_iadd_imm(b, index, i) : nir_imm_int(b, i);
339             nir_deref_instr *comp_deref = nir_build_deref_array(b, new_var_deref, final_index);
340             components[i] = nir_load_deref(b, comp_deref);
341          }
342          nir_def_rewrite_uses(&intr->def, nir_vec(b, components, vector_comps));
343       } else if (intr->intrinsic == nir_intrinsic_store_deref) {
344          for (unsigned i = 0; i < vector_comps; ++i) {
345             if (((1 << i) & nir_intrinsic_write_mask(intr)) == 0)
346                continue;
347             nir_def *final_index = index ? nir_iadd_imm(b, index, i) : nir_imm_int(b, i);
348             nir_deref_instr *comp_deref = nir_build_deref_array(b, new_var_deref, final_index);
349             nir_store_deref(b, comp_deref, nir_channel(b, intr->src[1].ssa, i), 1);
350          }
351       }
352       nir_instr_remove(&intr->instr);
353    } else {
354       nir_src_rewrite(&intr->src[0], &nir_build_deref_array(b, new_var_deref, index)->def);
355    }
356 
357    nir_deref_path_finish(&path);
358    return true;
359 }
360 
361 static void
362 flatten_constant_initializer(nir_variable *var, nir_constant *src, nir_constant ***dest, unsigned vector_elements)
363 {
364    if (src->num_elements == 0) {
365       for (unsigned i = 0; i < vector_elements; ++i) {
366          nir_constant *new_scalar = rzalloc(var, nir_constant);
367          memcpy(&new_scalar->values[0], &src->values[i], sizeof(src->values[0]));
368          new_scalar->is_null_constant = src->values[i].u64 == 0;
369 
370          nir_constant **array_entry = (*dest)++;
371          *array_entry = new_scalar;
372       }
373    } else {
374       for (unsigned i = 0; i < src->num_elements; ++i)
375          flatten_constant_initializer(var, src->elements[i], dest, vector_elements);
376    }
377 }
378 
379 static bool
380 flatten_var_array_types(nir_variable *var)
381 {
382    assert(!glsl_type_is_struct(glsl_without_array(var->type)));
383    const struct glsl_type *matrix_type = glsl_without_array(var->type);
384    if (!glsl_type_is_array_of_arrays(var->type) && glsl_get_components(matrix_type) == 1)
385       return false;
386 
387    enum glsl_base_type base_type = glsl_get_base_type(matrix_type);
388    const struct glsl_type *flattened_type = glsl_array_type(glsl_scalar_type(base_type),
389                                                             glsl_get_component_slots(var->type), 0);
390    var->type = flattened_type;
391    if (var->constant_initializer) {
392       nir_constant **new_elements = ralloc_array(var, nir_constant *, glsl_get_length(flattened_type));
393       nir_constant **temp = new_elements;
394       flatten_constant_initializer(var, var->constant_initializer, &temp, glsl_get_vector_elements(matrix_type));
395       var->constant_initializer->num_elements = glsl_get_length(flattened_type);
396       var->constant_initializer->elements = new_elements;
397    }
398    return true;
399 }
400 
401 bool
402 dxil_nir_flatten_var_arrays(nir_shader *shader, nir_variable_mode modes)
403 {
404    bool progress = false;
405    nir_foreach_variable_with_modes(var, shader, modes & ~nir_var_function_temp)
406       progress |= flatten_var_array_types(var);
407 
408    if (modes & nir_var_function_temp) {
409       nir_foreach_function_impl(impl, shader) {
410          nir_foreach_function_temp_variable(var, impl)
411             progress |= flatten_var_array_types(var);
412       }
413    }
414 
415    if (!progress)
416       return false;
417 
418    nir_shader_intrinsics_pass(shader, flatten_var_arrays,
419                                 nir_metadata_control_flow |
420                                 nir_metadata_loop_analysis,
421                                 NULL);
422    nir_remove_dead_derefs(shader);
423    return true;
424 }
425 
426 static bool
427 lower_deref_bit_size(nir_builder *b, nir_intrinsic_instr *intr, void *data)
428 {
429    switch (intr->intrinsic) {
430    case nir_intrinsic_load_deref:
431    case nir_intrinsic_store_deref:
432       break;
433    default:
434       /* Atomics can't be smaller than 32-bit */
435       return false;
436    }
437 
438    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
439    nir_variable *var = nir_deref_instr_get_variable(deref);
440    /* Only interested in full deref chains */
441    if (!var)
442       return false;
443 
444    const struct glsl_type *var_scalar_type = glsl_without_array(var->type);
445    if (deref->type == var_scalar_type || !glsl_type_is_scalar(var_scalar_type))
446       return false;
447 
448    assert(deref->deref_type == nir_deref_type_var || deref->deref_type == nir_deref_type_array);
449    const struct glsl_type *old_glsl_type = deref->type;
450    nir_alu_type old_type = nir_get_nir_type_for_glsl_type(old_glsl_type);
451    nir_alu_type new_type = nir_get_nir_type_for_glsl_type(var_scalar_type);
452    if (glsl_get_bit_size(old_glsl_type) < glsl_get_bit_size(var_scalar_type)) {
453       deref->type = var_scalar_type;
454       if (intr->intrinsic == nir_intrinsic_load_deref) {
455          intr->def.bit_size = glsl_get_bit_size(var_scalar_type);
456          b->cursor = nir_after_instr(&intr->instr);
457          nir_def *downcast = nir_type_convert(b, &intr->def, new_type, old_type, nir_rounding_mode_undef);
458          nir_def_rewrite_uses_after(&intr->def, downcast, downcast->parent_instr);
459       }
460       else {
461          b->cursor = nir_before_instr(&intr->instr);
462          nir_def *upcast = nir_type_convert(b, intr->src[1].ssa, old_type, new_type, nir_rounding_mode_undef);
463          nir_src_rewrite(&intr->src[1], upcast);
464       }
465 
466       while (deref->deref_type == nir_deref_type_array) {
467          nir_deref_instr *parent = nir_deref_instr_parent(deref);
468          parent->type = glsl_type_wrap_in_arrays(deref->type, parent->type);
469          deref = parent;
470       }
471    } else {
472       /* Assumed arrays are already flattened */
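      /* The variable itself was re-typed to a u32 array (see
       * lower_var_bit_size_types), so a 64-bit access at array index i becomes
       * accesses to elements 2*i (low half) and 2*i + 1 (high half), glued
       * back together with the 64_2x32 split pack/unpack opcodes.
       */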
473       b->cursor = nir_before_instr(&deref->instr);
474       nir_deref_instr *parent = nir_build_deref_var(b, var);
475       if (deref->deref_type == nir_deref_type_array)
476          deref = nir_build_deref_array(b, parent, nir_imul_imm(b, deref->arr.index.ssa, 2));
477       else
478          deref = nir_build_deref_array_imm(b, parent, 0);
479       nir_deref_instr *deref2 = nir_build_deref_array(b, parent,
480                                                       nir_iadd_imm(b, deref->arr.index.ssa, 1));
481       b->cursor = nir_before_instr(&intr->instr);
482       if (intr->intrinsic == nir_intrinsic_load_deref) {
483          nir_def *src1 = nir_load_deref(b, deref);
484          nir_def *src2 = nir_load_deref(b, deref2);
485          nir_def_rewrite_uses(&intr->def, nir_pack_64_2x32_split(b, src1, src2));
486       } else {
487          nir_def *src1 = nir_unpack_64_2x32_split_x(b, intr->src[1].ssa);
488          nir_def *src2 = nir_unpack_64_2x32_split_y(b, intr->src[1].ssa);
489          nir_store_deref(b, deref, src1, 1);
490          nir_store_deref(b, deref2, src2, 1);
491       }
492       nir_instr_remove(&intr->instr);
493    }
494    return true;
495 }
496 
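/* Re-types a scalar (or array-of-scalars) variable so its bit size falls inside
 * [min_bit_size, max_bit_size]: bool/8-bit/16-bit variables are widened to the
 * minimum size (with constant initializers converted in place), while variables
 * wider than the maximum become a uint array twice as long holding the low and
 * high 32-bit halves of each element.
 */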
497 static bool
498 lower_var_bit_size_types(nir_variable *var, unsigned min_bit_size, unsigned max_bit_size)
499 {
500    assert(!glsl_type_is_array_of_arrays(var->type) && !glsl_type_is_struct(var->type));
501    const struct glsl_type *type = glsl_without_array(var->type);
502    assert(glsl_type_is_scalar(type));
503    enum glsl_base_type base_type = glsl_get_base_type(type);
504    if (glsl_base_type_get_bit_size(base_type) < min_bit_size) {
505       switch (min_bit_size) {
506       case 16:
507          switch (base_type) {
508          case GLSL_TYPE_BOOL:
509             base_type = GLSL_TYPE_UINT16;
510             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
511                var->constant_initializer->elements[i]->values[0].u16 = var->constant_initializer->elements[i]->values[0].b ? 0xffff : 0;
512             break;
513          case GLSL_TYPE_INT8:
514             base_type = GLSL_TYPE_INT16;
515             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
516                var->constant_initializer->elements[i]->values[0].i16 = var->constant_initializer->elements[i]->values[0].i8;
517             break;
518          case GLSL_TYPE_UINT8: base_type = GLSL_TYPE_UINT16; break;
519          default: unreachable("Unexpected base type");
520          }
521          break;
522       case 32:
523          switch (base_type) {
524          case GLSL_TYPE_BOOL:
525             base_type = GLSL_TYPE_UINT;
526             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
527                var->constant_initializer->elements[i]->values[0].u32 = var->constant_initializer->elements[i]->values[0].b ? 0xffffffff : 0;
528             break;
529          case GLSL_TYPE_INT8:
530             base_type = GLSL_TYPE_INT;
531             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
532                var->constant_initializer->elements[i]->values[0].i32 = var->constant_initializer->elements[i]->values[0].i8;
533             break;
534          case GLSL_TYPE_INT16:
535             base_type = GLSL_TYPE_INT;
536             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
537                var->constant_initializer->elements[i]->values[0].i32 = var->constant_initializer->elements[i]->values[0].i16;
538             break;
539          case GLSL_TYPE_FLOAT16:
540             base_type = GLSL_TYPE_FLOAT;
541             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
542                var->constant_initializer->elements[i]->values[0].f32 = _mesa_half_to_float(var->constant_initializer->elements[i]->values[0].u16);
543             break;
544          case GLSL_TYPE_UINT8: base_type = GLSL_TYPE_UINT; break;
545          case GLSL_TYPE_UINT16: base_type = GLSL_TYPE_UINT; break;
546          default: unreachable("Unexpected base type");
547          }
548          break;
549       default: unreachable("Unexpected min bit size");
550       }
551       var->type = glsl_type_wrap_in_arrays(glsl_scalar_type(base_type), var->type);
552       return true;
553    }
554    if (glsl_base_type_get_bit_size(base_type) > max_bit_size) {
555       assert(!glsl_type_is_array_of_arrays(var->type));
556       var->type = glsl_array_type(glsl_scalar_type(GLSL_TYPE_UINT),
557                                     glsl_type_is_array(var->type) ? glsl_get_length(var->type) * 2 : 2,
558                                     0);
559       if (var->constant_initializer) {
560          unsigned num_elements = var->constant_initializer->num_elements ?
561             var->constant_initializer->num_elements * 2 : 2;
562          nir_constant **element_arr = ralloc_array(var, nir_constant *, num_elements);
563          nir_constant *elements = rzalloc_array(var, nir_constant, num_elements);
564          for (unsigned i = 0; i < var->constant_initializer->num_elements; ++i) {
565             element_arr[i*2] = &elements[i*2];
566             element_arr[i*2+1] = &elements[i*2+1];
567             const nir_const_value *src = var->constant_initializer->num_elements ?
568                var->constant_initializer->elements[i]->values : var->constant_initializer->values;
569             elements[i*2].values[0].u32 = (uint32_t)src->u64;
570             elements[i*2].is_null_constant = (uint32_t)src->u64 == 0;
571             elements[i*2+1].values[0].u32 = (uint32_t)(src->u64 >> 32);
572             elements[i*2+1].is_null_constant = (uint32_t)(src->u64 >> 32) == 0;
573          }
574          var->constant_initializer->num_elements = num_elements;
575          var->constant_initializer->elements = element_arr;
576       }
577       return true;
578    }
579    return false;
580 }
581 
582 bool
583 dxil_nir_lower_var_bit_size(nir_shader *shader, nir_variable_mode modes,
584                             unsigned min_bit_size, unsigned max_bit_size)
585 {
586    bool progress = false;
587    nir_foreach_variable_with_modes(var, shader, modes & ~nir_var_function_temp)
588       progress |= lower_var_bit_size_types(var, min_bit_size, max_bit_size);
589 
590    if (modes & nir_var_function_temp) {
591       nir_foreach_function_impl(impl, shader) {
592          nir_foreach_function_temp_variable(var, impl)
593             progress |= lower_var_bit_size_types(var, min_bit_size, max_bit_size);
594       }
595    }
596 
597    if (!progress)
598       return false;
599 
600    nir_shader_intrinsics_pass(shader, lower_deref_bit_size,
601                                 nir_metadata_control_flow |
602                                 nir_metadata_loop_analysis,
603                                 NULL);
604    nir_remove_dead_derefs(shader);
605    return true;
606 }
607 
608 static bool
609 remove_oob_array_access(nir_builder *b, nir_intrinsic_instr *intr, void *data)
610 {
611    uint32_t num_derefs = 1;
612 
613    switch (intr->intrinsic) {
614    case nir_intrinsic_copy_deref:
615       num_derefs = 2;
616       FALLTHROUGH;
617    case nir_intrinsic_load_deref:
618    case nir_intrinsic_store_deref:
619    case nir_intrinsic_deref_atomic:
620    case nir_intrinsic_deref_atomic_swap:
621       break;
622    default:
623       return false;
624    }
625 
626    for (uint32_t i = 0; i < num_derefs; ++i) {
627       if (nir_deref_instr_is_known_out_of_bounds(nir_src_as_deref(intr->src[i]))) {
628          switch (intr->intrinsic) {
629          case nir_intrinsic_load_deref:
630          case nir_intrinsic_deref_atomic:
631          case nir_intrinsic_deref_atomic_swap:
632             b->cursor = nir_before_instr(&intr->instr);
633             nir_def *undef = nir_undef(b, intr->def.num_components, intr->def.bit_size);
634             nir_def_rewrite_uses(&intr->def, undef);
635             break;
636          default:
637             break;
638          }
639          nir_instr_remove(&intr->instr);
640          return true;
641       }
642    }
643 
644    return false;
645 }
646 
647 bool
648 dxil_nir_remove_oob_array_accesses(nir_shader *shader)
649 {
650    return nir_shader_intrinsics_pass(shader, remove_oob_array_access,
651                                      nir_metadata_control_flow |
652                                      nir_metadata_loop_analysis,
653                                      NULL);
654 }
655 
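/* Shared-memory atomics are retargeted at the lowered uint-array variable: the
 * byte offset (plus the intrinsic base) is turned into an i32 array index with
 * offset >> 2 and the operation becomes a deref atomic on that element.
 */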
656 static bool
657 lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
658 {
659    b->cursor = nir_before_instr(&intr->instr);
660 
661    nir_def *offset =
662       nir_iadd_imm(b, intr->src[0].ssa, nir_intrinsic_base(intr));
663    nir_def *index = nir_ushr_imm(b, offset, 2);
664 
665    nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
666    nir_def *result;
667    if (intr->intrinsic == nir_intrinsic_shared_atomic_swap)
668       result = nir_deref_atomic_swap(b, 32, &deref->def, intr->src[1].ssa, intr->src[2].ssa,
669                                      .atomic_op = nir_intrinsic_atomic_op(intr));
670    else
671       result = nir_deref_atomic(b, 32, &deref->def, intr->src[1].ssa,
672                                 .atomic_op = nir_intrinsic_atomic_op(intr));
673 
674    nir_def_replace(&intr->def, result);
675    return true;
676 }
677 
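/* Replaces shared and scratch memory with flat uint arrays ("lowered_shared_mem"
 * and "lowered_scratch_mem") and rewrites the corresponding load/store/atomic
 * intrinsics into 32-bit accesses on those variables.
 */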
678 bool
679 dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
680                                     const struct dxil_nir_lower_loads_stores_options *options)
681 {
682    bool progress = nir_remove_dead_variables(nir, nir_var_function_temp | nir_var_mem_shared, NULL);
683    nir_variable *shared_var = NULL;
684    if (nir->info.shared_size) {
685       shared_var = nir_variable_create(nir, nir_var_mem_shared,
686                                        glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->info.shared_size, 4), 4),
687                                        "lowered_shared_mem");
688    }
689 
690    unsigned ptr_size = nir->info.cs.ptr_size;
691    if (nir->info.stage == MESA_SHADER_KERNEL) {
692       /* All the derefs created here will be used as GEP indices so force 32-bit */
693       nir->info.cs.ptr_size = 32;
694    }
695    nir_foreach_function_impl(impl, nir) {
696       nir_builder b = nir_builder_create(impl);
697 
698       nir_variable *scratch_var = NULL;
699       if (nir->scratch_size) {
700          const struct glsl_type *scratch_type = glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->scratch_size, 4), 4);
701          scratch_var = nir_local_variable_create(impl, scratch_type, "lowered_scratch_mem");
702       }
703 
704       nir_foreach_block(block, impl) {
705          nir_foreach_instr_safe(instr, block) {
706             if (instr->type != nir_instr_type_intrinsic)
707                continue;
708             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
709 
710             switch (intr->intrinsic) {
711             case nir_intrinsic_load_shared:
712                progress |= lower_32b_offset_load(&b, intr, shared_var);
713                break;
714             case nir_intrinsic_load_scratch:
715                progress |= lower_32b_offset_load(&b, intr, scratch_var);
716                break;
717             case nir_intrinsic_store_shared:
718                progress |= lower_32b_offset_store(&b, intr, shared_var);
719                break;
720             case nir_intrinsic_store_scratch:
721                progress |= lower_32b_offset_store(&b, intr, scratch_var);
722                break;
723             case nir_intrinsic_shared_atomic:
724             case nir_intrinsic_shared_atomic_swap:
725                progress |= lower_shared_atomic(&b, intr, shared_var);
726                break;
727             default:
728                break;
729             }
730          }
731       }
732    }
733    if (nir->info.stage == MESA_SHADER_KERNEL) {
734       nir->info.cs.ptr_size = ptr_size;
735    }
736 
737    return progress;
738 }
739 
740 static bool
741 lower_deref_ssbo(nir_builder *b, nir_deref_instr *deref)
742 {
743    assert(nir_deref_mode_is(deref, nir_var_mem_ssbo));
744    assert(deref->deref_type == nir_deref_type_var ||
745           deref->deref_type == nir_deref_type_cast);
746    nir_variable *var = deref->var;
747 
748    b->cursor = nir_before_instr(&deref->instr);
749 
750    if (deref->deref_type == nir_deref_type_var) {
751       /* We turn all deref_var into deref_cast and build a pointer value based on
752        * the var binding which encodes the UAV id.
753        */
754       nir_def *ptr = nir_imm_int64(b, (uint64_t)var->data.binding << 32);
755       nir_deref_instr *deref_cast =
756          nir_build_deref_cast(b, ptr, nir_var_mem_ssbo, deref->type,
757                               glsl_get_explicit_stride(var->type));
758       nir_def_replace(&deref->def, &deref_cast->def);
759 
760       deref = deref_cast;
761       return true;
762    }
763    return false;
764 }
765 
766 bool
767 dxil_nir_lower_deref_ssbo(nir_shader *nir)
768 {
769    bool progress = false;
770 
771    foreach_list_typed(nir_function, func, node, &nir->functions) {
772       if (!func->is_entrypoint)
773          continue;
774       assert(func->impl);
775 
776       nir_builder b = nir_builder_create(func->impl);
777 
778       nir_foreach_block(block, func->impl) {
779          nir_foreach_instr_safe(instr, block) {
780             if (instr->type != nir_instr_type_deref)
781                continue;
782 
783             nir_deref_instr *deref = nir_instr_as_deref(instr);
784 
785             if (!nir_deref_mode_is(deref, nir_var_mem_ssbo) ||
786                 (deref->deref_type != nir_deref_type_var &&
787                  deref->deref_type != nir_deref_type_cast))
788                continue;
789 
790             progress |= lower_deref_ssbo(&b, deref);
791          }
792       }
793    }
794 
795    return progress;
796 }
797 
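/* When an ALU instruction consumes a deref chain rooted at a pointer cast (as
 * with kernel-style pointer arithmetic), the deref source is replaced by plain
 * integer arithmetic: the root pointer plus the byte offset of the deref,
 * computed with CL size/alignment rules.
 */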
798 static bool
799 lower_alu_deref_srcs(nir_builder *b, nir_alu_instr *alu)
800 {
801    const nir_op_info *info = &nir_op_infos[alu->op];
802    bool progress = false;
803 
804    b->cursor = nir_before_instr(&alu->instr);
805 
806    for (unsigned i = 0; i < info->num_inputs; i++) {
807       nir_deref_instr *deref = nir_src_as_deref(alu->src[i].src);
808 
809       if (!deref)
810          continue;
811 
812       nir_deref_path path;
813       nir_deref_path_init(&path, deref, NULL);
814       nir_deref_instr *root_deref = path.path[0];
815       nir_deref_path_finish(&path);
816 
817       if (root_deref->deref_type != nir_deref_type_cast)
818          continue;
819 
820       nir_def *ptr =
821          nir_iadd(b, root_deref->parent.ssa,
822                      nir_build_deref_offset(b, deref, cl_type_size_align));
823       nir_src_rewrite(&alu->src[i].src, ptr);
824       progress = true;
825    }
826 
827    return progress;
828 }
829 
830 bool
831 dxil_nir_opt_alu_deref_srcs(nir_shader *nir)
832 {
833    bool progress = false;
834 
835    foreach_list_typed(nir_function, func, node, &nir->functions) {
836       if (!func->is_entrypoint)
837          continue;
838       assert(func->impl);
839 
840       nir_builder b = nir_builder_create(func->impl);
841 
842       nir_foreach_block(block, func->impl) {
843          nir_foreach_instr_safe(instr, block) {
844             if (instr->type != nir_instr_type_alu)
845                continue;
846 
847             nir_alu_instr *alu = nir_instr_as_alu(instr);
848             progress |= lower_alu_deref_srcs(&b, alu);
849          }
850       }
851    }
852 
853    return progress;
854 }
855 
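/* Replaces a narrow phi with one of new_bit_size: each source is widened with
 * u2uN right after it is defined, the widened phi is inserted in place of the
 * original, and the result is converted back to the original bit size after
 * the block's phis so existing users are unaffected.
 */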
856 static void
857 cast_phi(nir_builder *b, nir_phi_instr *phi, unsigned new_bit_size)
858 {
859    nir_phi_instr *lowered = nir_phi_instr_create(b->shader);
860    int num_components = 0;
861    int old_bit_size = phi->def.bit_size;
862 
863    nir_foreach_phi_src(src, phi) {
864       assert(num_components == 0 || num_components == src->src.ssa->num_components);
865       num_components = src->src.ssa->num_components;
866 
867       b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);
868 
869       nir_def *cast = nir_u2uN(b, src->src.ssa, new_bit_size);
870 
871       nir_phi_instr_add_src(lowered, src->pred, cast);
872    }
873 
874    nir_def_init(&lowered->instr, &lowered->def, num_components,
875                 new_bit_size);
876 
877    b->cursor = nir_before_instr(&phi->instr);
878    nir_builder_instr_insert(b, &lowered->instr);
879 
880    b->cursor = nir_after_phis(nir_cursor_current_block(b->cursor));
881    nir_def *result = nir_u2uN(b, &lowered->def, old_bit_size);
882 
883    nir_def_replace(&phi->def, result);
884 }
885 
886 static bool
887 upcast_phi_impl(nir_function_impl *impl, unsigned min_bit_size)
888 {
889    nir_builder b = nir_builder_create(impl);
890    bool progress = false;
891 
892    nir_foreach_block_reverse(block, impl) {
893       nir_foreach_phi_safe(phi, block) {
894          if (phi->def.bit_size == 1 ||
895              phi->def.bit_size >= min_bit_size)
896             continue;
897 
898          cast_phi(&b, phi, min_bit_size);
899          progress = true;
900       }
901    }
902 
903    if (progress) {
904       nir_metadata_preserve(impl, nir_metadata_control_flow);
905    } else {
906       nir_metadata_preserve(impl, nir_metadata_all);
907    }
908 
909    return progress;
910 }
911 
912 bool
913 dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size)
914 {
915    bool progress = false;
916 
917    nir_foreach_function_impl(impl, shader) {
918       progress |= upcast_phi_impl(impl, min_bit_size);
919    }
920 
921    return progress;
922 }
923 
924 struct dxil_nir_split_clip_cull_distance_params {
925    nir_variable *new_var[2];
926    nir_shader *shader;
927 };
928 
929 /* In GLSL and SPIR-V, clip and cull distance are arrays of floats (with a limit of 8).
930  * In DXIL, clip and cull distances are up to 2 float4s combined.
931  * Coming from GLSL, we can request this 2 float4 format, but coming from SPIR-V,
932  * we can't, and have to accept a "compact" array of scalar floats.
933  *
934  * To help emitting a valid input signature for this case, split the variables so that they
935  * match what we need to put in the signature (e.g. { float clip[4]; float clip1; float cull[3]; })
936  *
937  * This pass can deal with splitting across two axes:
938  * 1. Given { float clip[5]; float cull[3]; }, split clip into clip[4] and clip1[1]. This is
939  *    what's produced by nir_lower_clip_cull_distance_arrays.
940  * 2. Given { float clip[4]; float clipcull[4]; }, split clipcull into clip1[1] and cull[3].
941  *    This is what's produced by the sequence of nir_lower_clip_cull_distance_arrays, then
942  *    I/O lowering, vectorization, optimization, and I/O un-lowering.
943  */
944 static bool
945 dxil_nir_split_clip_cull_distance_instr(nir_builder *b,
946                                         nir_instr *instr,
947                                         void *cb_data)
948 {
949    struct dxil_nir_split_clip_cull_distance_params *params = cb_data;
950 
951    if (instr->type != nir_instr_type_deref)
952       return false;
953 
954    nir_deref_instr *deref = nir_instr_as_deref(instr);
955    nir_variable *var = nir_deref_instr_get_variable(deref);
956    if (!var ||
957        var->data.location < VARYING_SLOT_CLIP_DIST0 ||
958        var->data.location > VARYING_SLOT_CULL_DIST1 ||
959        !var->data.compact)
960       return false;
961 
962    unsigned new_var_idx = var->data.mode == nir_var_shader_in ? 0 : 1;
963    nir_variable *new_var = params->new_var[new_var_idx];
964 
965    /* The location should only be inside clip distance, because clip
966     * and cull should've been merged by nir_lower_clip_cull_distance_arrays()
967     */
968    assert(var->data.location == VARYING_SLOT_CLIP_DIST0 ||
969           var->data.location == VARYING_SLOT_CLIP_DIST1);
970 
971    /* The deref chain to the clip/cull variables should be simple, just the
972     * var and an array with a constant index, otherwise more lowering/optimization
973     * might be needed before this pass, e.g. copy prop, lower_io_to_temporaries,
974     * split_var_copies, and/or lower_var_copies. In the case of arrayed I/O like
975     * inputs to the tessellation or geometry stages, there might be a second level
976     * of array index.
977     */
978    assert(deref->deref_type == nir_deref_type_var ||
979           deref->deref_type == nir_deref_type_array);
980 
981    bool clip_size_accurate = var->data.mode == nir_var_shader_out || b->shader->info.stage == MESA_SHADER_FRAGMENT;
982    bool is_clip_cull_split = false;
983 
984    b->cursor = nir_before_instr(instr);
985    unsigned arrayed_io_length = 0;
986    const struct glsl_type *old_type = var->type;
987    if (nir_is_arrayed_io(var, b->shader->info.stage)) {
988       arrayed_io_length = glsl_array_size(old_type);
989       old_type = glsl_get_array_element(old_type);
990    }
991    int old_length = glsl_array_size(old_type);
992    if (!new_var) {
993       /* Update lengths for new and old vars */
994       int new_length = (old_length + var->data.location_frac) - 4;
995 
996       /* The existing variable fits in the float4 */
997       if (new_length <= 0) {
998          /* If we don't have an accurate clip size in the shader info, then we're in case 1 (only
999           * lowered clip_cull, not lowered+unlowered I/O), since I/O optimization moves these varyings
1000           * out of sysval locations and into generic locations. */
1001          if (!clip_size_accurate)
1002             return false;
1003 
1004          assert(old_length <= 4);
1005          unsigned start = (var->data.location - VARYING_SLOT_CLIP_DIST0) * 4;
1006          unsigned end = start + old_length;
1007          /* If it doesn't straddle the clip array size, then it's either just clip or cull, not both. */
1008          if (start >= b->shader->info.clip_distance_array_size ||
1009                end <= b->shader->info.clip_distance_array_size)
1010             return false;
1011 
1012          new_length = end - b->shader->info.clip_distance_array_size;
1013          is_clip_cull_split = true;
1014       }
1015 
1016       old_length -= new_length;
1017       new_var = nir_variable_clone(var, params->shader);
1018       nir_shader_add_variable(params->shader, new_var);
1019       assert(glsl_get_base_type(glsl_get_array_element(old_type)) == GLSL_TYPE_FLOAT);
1020       var->type = glsl_array_type(glsl_float_type(), old_length, 0);
1021       new_var->type = glsl_array_type(glsl_float_type(), new_length, 0);
1022       if (arrayed_io_length) {
1023          var->type = glsl_array_type(var->type, arrayed_io_length, 0);
1024          new_var->type = glsl_array_type(new_var->type, arrayed_io_length, 0);
1025       }
1026 
1027       if (is_clip_cull_split) {
1028          new_var->data.location_frac = old_length;
1029       } else {
1030          new_var->data.location++;
1031          new_var->data.location_frac = 0;
1032       }
1033       params->new_var[new_var_idx] = new_var;
1034    }
1035 
1036    /* Update the type for derefs of the old var */
1037    if (deref->deref_type == nir_deref_type_var) {
1038       deref->type = var->type;
1039       return false;
1040    }
1041 
1042    if (glsl_type_is_array(deref->type)) {
1043       assert(arrayed_io_length > 0);
1044       deref->type = glsl_get_array_element(var->type);
1045       return false;
1046    }
1047 
1048    assert(glsl_get_base_type(deref->type) == GLSL_TYPE_FLOAT);
1049 
1050    nir_const_value *index = nir_src_as_const_value(deref->arr.index);
1051    assert(index);
1052 
1053    /* If we're indexing out-of-bounds of the old variable, then adjust to point
1054     * to the new variable with a smaller index.
1055     */
1056    if (index->u32 < old_length)
1057       return false;
1058 
1059    nir_deref_instr *new_var_deref = nir_build_deref_var(b, new_var);
1060    nir_deref_instr *new_intermediate_deref = new_var_deref;
1061    if (arrayed_io_length) {
1062       nir_deref_instr *parent = nir_src_as_deref(deref->parent);
1063       assert(parent->deref_type == nir_deref_type_array);
1064       new_intermediate_deref = nir_build_deref_array(b, new_intermediate_deref, parent->arr.index.ssa);
1065    }
1066    nir_deref_instr *new_array_deref = nir_build_deref_array(b, new_intermediate_deref, nir_imm_int(b, index->u32 - old_length));
1067    nir_def_rewrite_uses(&deref->def, &new_array_deref->def);
1068    return true;
1069 }
1070 
1071 bool
1072 dxil_nir_split_clip_cull_distance(nir_shader *shader)
1073 {
1074    struct dxil_nir_split_clip_cull_distance_params params = {
1075       .new_var = { NULL, NULL },
1076       .shader = shader,
1077    };
1078    nir_shader_instructions_pass(shader,
1079                                 dxil_nir_split_clip_cull_distance_instr,
1080                                 nir_metadata_control_flow |
1081                                 nir_metadata_loop_analysis,
1082                                 &params);
1083    return params.new_var[0] != NULL || params.new_var[1] != NULL;
1084 }
1085 
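/* Wraps 64-bit float ALU sources in unpack_64_2x32 + pack_double_2x32_dxil and
 * 64-bit float results in unpack_double_2x32_dxil + pack_64_2x32, so doubles
 * reach the backend through its dedicated pack/unpack opcodes; 64-bit float
 * reductions and scans get the same treatment.
 */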
1086 static bool
1087 dxil_nir_lower_double_math_instr(nir_builder *b,
1088                                  nir_instr *instr,
1089                                  UNUSED void *cb_data)
1090 {
1091    if (instr->type == nir_instr_type_intrinsic) {
1092       nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1093       switch (intr->intrinsic) {
1094          case nir_intrinsic_reduce:
1095          case nir_intrinsic_exclusive_scan:
1096          case nir_intrinsic_inclusive_scan:
1097             break;
1098          default:
1099             return false;
1100       }
1101       if (intr->def.bit_size != 64)
1102          return false;
1103       nir_op reduction = nir_intrinsic_reduction_op(intr);
1104       switch (reduction) {
1105          case nir_op_fmul:
1106          case nir_op_fadd:
1107          case nir_op_fmin:
1108          case nir_op_fmax:
1109             break;
1110          default:
1111             return false;
1112       }
1113       b->cursor = nir_before_instr(instr);
1114       nir_src_rewrite(&intr->src[0], nir_pack_double_2x32_dxil(b, nir_unpack_64_2x32(b, intr->src[0].ssa)));
1115       b->cursor = nir_after_instr(instr);
1116       nir_def *result = nir_pack_64_2x32(b, nir_unpack_double_2x32_dxil(b, &intr->def));
1117       nir_def_rewrite_uses_after(&intr->def, result, result->parent_instr);
1118       return true;
1119    }
1120 
1121    if (instr->type != nir_instr_type_alu)
1122       return false;
1123 
1124    nir_alu_instr *alu = nir_instr_as_alu(instr);
1125 
1126    /* TODO: See if we can apply this explicitly to packs/unpacks that are then
1127     * used as a double. As-is, if we had an app explicitly do a 64bit integer op,
1128     * then try to bitcast to double (not expressible in HLSL, but it is in other
1129     * source languages), this would unpack the integer and repack as a double, when
1130     * we probably want to just send the bitcast through to the backend.
1131     */
1132 
1133    b->cursor = nir_before_instr(&alu->instr);
1134 
1135    bool progress = false;
1136    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
1137       if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
1138           alu->src[i].src.ssa->bit_size == 64) {
1139          unsigned num_components = nir_op_infos[alu->op].input_sizes[i];
1140          if (!num_components)
1141             num_components = alu->def.num_components;
1142          nir_def *components[NIR_MAX_VEC_COMPONENTS];
1143          for (unsigned c = 0; c < num_components; ++c) {
1144             nir_def *packed_double = nir_channel(b, alu->src[i].src.ssa, alu->src[i].swizzle[c]);
1145             nir_def *unpacked_double = nir_unpack_64_2x32(b, packed_double);
1146             components[c] = nir_pack_double_2x32_dxil(b, unpacked_double);
1147             alu->src[i].swizzle[c] = c;
1148          }
1149          nir_src_rewrite(&alu->src[i].src,
1150                          nir_vec(b, components, num_components));
1151          progress = true;
1152       }
1153    }
1154 
1155    if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
1156        alu->def.bit_size == 64) {
1157       b->cursor = nir_after_instr(&alu->instr);
1158       nir_def *components[NIR_MAX_VEC_COMPONENTS];
1159       for (unsigned c = 0; c < alu->def.num_components; ++c) {
1160          nir_def *packed_double = nir_channel(b, &alu->def, c);
1161          nir_def *unpacked_double = nir_unpack_double_2x32_dxil(b, packed_double);
1162          components[c] = nir_pack_64_2x32(b, unpacked_double);
1163       }
1164       nir_def *repacked_dvec = nir_vec(b, components, alu->def.num_components);
1165       nir_def_rewrite_uses_after(&alu->def, repacked_dvec, repacked_dvec->parent_instr);
1166       progress = true;
1167    }
1168 
1169    return progress;
1170 }
1171 
1172 bool
1173 dxil_nir_lower_double_math(nir_shader *shader)
1174 {
1175    return nir_shader_instructions_pass(shader,
1176                                        dxil_nir_lower_double_math_instr,
1177                                        nir_metadata_control_flow,
1178                                        NULL);
1179 }
1180 
1181 typedef struct {
1182    gl_system_value *values;
1183    uint32_t count;
1184 } zero_system_values_state;
1185 
1186 static bool
1187 lower_system_value_to_zero_filter(const nir_instr* instr, const void* cb_state)
1188 {
1189    if (instr->type != nir_instr_type_intrinsic) {
1190       return false;
1191    }
1192 
1193    nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);
1194 
1195    /* All the intrinsics we care about are loads */
1196    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
1197       return false;
1198 
1199    zero_system_values_state* state = (zero_system_values_state*)cb_state;
1200    for (uint32_t i = 0; i < state->count; ++i) {
1201       gl_system_value value = state->values[i];
1202       nir_intrinsic_op value_op = nir_intrinsic_from_system_value(value);
1203 
1204       if (intrin->intrinsic == value_op) {
1205          return true;
1206       } else if (intrin->intrinsic == nir_intrinsic_load_deref) {
1207          nir_deref_instr* deref = nir_src_as_deref(intrin->src[0]);
1208          if (!nir_deref_mode_is(deref, nir_var_system_value))
1209             return false;
1210 
1211          nir_variable* var = deref->var;
1212          if (var->data.location == value) {
1213             return true;
1214          }
1215       }
1216    }
1217 
1218    return false;
1219 }
1220 
1221 static nir_def*
1222 lower_system_value_to_zero_instr(nir_builder* b, nir_instr* instr, void* _state)
1223 {
1224    return nir_imm_int(b, 0);
1225 }
1226 
1227 bool
1228 dxil_nir_lower_system_values_to_zero(nir_shader* shader,
1229                                      gl_system_value* system_values,
1230                                      uint32_t count)
1231 {
1232    zero_system_values_state state = { system_values, count };
1233    return nir_shader_lower_instructions(shader,
1234       lower_system_value_to_zero_filter,
1235       lower_system_value_to_zero_instr,
1236       &state);
1237 }
1238 
1239 static void
1240 lower_load_local_group_size(nir_builder *b, nir_intrinsic_instr *intr)
1241 {
1242    b->cursor = nir_after_instr(&intr->instr);
1243 
1244    nir_const_value v[3] = {
1245       nir_const_value_for_int(b->shader->info.workgroup_size[0], 32),
1246       nir_const_value_for_int(b->shader->info.workgroup_size[1], 32),
1247       nir_const_value_for_int(b->shader->info.workgroup_size[2], 32)
1248    };
1249    nir_def *size = nir_build_imm(b, 3, 32, v);
1250    nir_def_replace(&intr->def, size);
1251 }
1252 
1253 static bool
1254 lower_system_values_impl(nir_builder *b, nir_intrinsic_instr *intr,
1255                          void *_state)
1256 {
1257    switch (intr->intrinsic) {
1258    case nir_intrinsic_load_workgroup_size:
1259       lower_load_local_group_size(b, intr);
1260       return true;
1261    default:
1262       return false;
1263    }
1264 }
1265 
1266 bool
1267 dxil_nir_lower_system_values(nir_shader *shader)
1268 {
1269    return nir_shader_intrinsics_pass(shader, lower_system_values_impl,
1270                                      nir_metadata_control_flow | nir_metadata_loop_analysis,
1271                                      NULL);
1272 }
1273 
1274 static const struct glsl_type *
1275 get_bare_samplers_for_type(const struct glsl_type *type, bool is_shadow)
1276 {
1277    const struct glsl_type *base_sampler_type =
1278       is_shadow ?
1279       glsl_bare_shadow_sampler_type() : glsl_bare_sampler_type();
1280    return glsl_type_wrap_in_arrays(base_sampler_type, type);
1281 }
1282 
1283 static const struct glsl_type *
1284 get_textures_for_sampler_type(const struct glsl_type *type)
1285 {
1286    return glsl_type_wrap_in_arrays(
1287       glsl_sampler_type_to_texture(
1288          glsl_without_array(type)), type);
1289 }
1290 
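/* Redirects the sampler source (or sampler_index) of texture instructions to a
 * variable of bare (possibly shadow) sampler type, cloning the original combined
 * sampler variable the first time a binding or deref chain is seen and caching
 * the clone in the hash table passed via `data`.
 */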
1291 static bool
1292 redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1293 {
1294    if (instr->type != nir_instr_type_tex)
1295       return false;
1296 
1297    nir_tex_instr *tex = nir_instr_as_tex(instr);
1298 
1299    int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
1300    if (sampler_idx == -1) {
1301       /* No sampler deref - does this instruction even need a sampler? If not,
1302        * sampler_index doesn't necessarily point to a sampler, so early-out.
1303        */
1304       if (!nir_tex_instr_need_sampler(tex))
1305          return false;
1306 
1307       /* No derefs but needs a sampler, must be using indices */
1308       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->sampler_index);
1309 
1310       /* Already have a bare sampler here */
1311       if (bare_sampler)
1312          return false;
1313 
1314       nir_variable *old_sampler = NULL;
1315       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1316          if (var->data.binding <= tex->sampler_index &&
1317              var->data.binding + glsl_type_get_sampler_count(var->type) >
1318                 tex->sampler_index) {
1319 
1320             /* Already have a bare sampler for this binding and it is of the
1321              * correct type, add it to the table */
1322             if (glsl_type_is_bare_sampler(glsl_without_array(var->type)) &&
1323                 glsl_sampler_type_is_shadow(glsl_without_array(var->type)) ==
1324                    tex->is_shadow) {
1325                _mesa_hash_table_u64_insert(data, tex->sampler_index, var);
1326                return false;
1327             }
1328 
1329             old_sampler = var;
1330          }
1331       }
1332 
1333       assert(old_sampler);
1334 
1335       /* Clone the original sampler to a bare sampler of the correct type */
1336       bare_sampler = nir_variable_clone(old_sampler, b->shader);
1337       nir_shader_add_variable(b->shader, bare_sampler);
1338 
1339       bare_sampler->type =
1340          get_bare_samplers_for_type(old_sampler->type, tex->is_shadow);
1341       _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler);
1342       return true;
1343    }
1344 
1345    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1346    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src);
1347    nir_deref_path path;
1348    nir_deref_path_init(&path, final_deref, NULL);
1349 
1350    nir_deref_instr *old_tail = path.path[0];
1351    assert(old_tail->deref_type == nir_deref_type_var);
1352    nir_variable *old_var = old_tail->var;
1353    if (glsl_type_is_bare_sampler(glsl_without_array(old_var->type)) &&
1354        glsl_sampler_type_is_shadow(glsl_without_array(old_var->type)) ==
1355           tex->is_shadow) {
1356       nir_deref_path_finish(&path);
1357       return false;
1358    }
1359 
1360    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1361                       old_var->data.binding;
1362    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1363    if (!new_var) {
1364       new_var = nir_variable_clone(old_var, b->shader);
1365       nir_shader_add_variable(b->shader, new_var);
1366       new_var->type =
1367          get_bare_samplers_for_type(old_var->type, tex->is_shadow);
1368       _mesa_hash_table_u64_insert(data, var_key, new_var);
1369    }
1370 
1371    b->cursor = nir_after_instr(&old_tail->instr);
1372    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1373 
1374    for (unsigned i = 1; path.path[i]; ++i) {
1375       b->cursor = nir_after_instr(&path.path[i]->instr);
1376       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1377    }
1378 
1379    nir_deref_path_finish(&path);
1380    nir_src_rewrite(&tex->src[sampler_idx].src, &new_tail->def);
1381    return true;
1382 }
1383 
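/* Same idea as redirect_sampler_derefs, but for the texture source: point it at a
 * texture-typed clone of the original combined-sampler variable, caching clones in
 * the hash table passed as 'data'. */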
1384 static bool
1385 redirect_texture_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1386 {
1387    if (instr->type != nir_instr_type_tex)
1388       return false;
1389 
1390    nir_tex_instr *tex = nir_instr_as_tex(instr);
1391 
1392    int texture_idx = nir_tex_instr_src_index(tex, nir_tex_src_texture_deref);
1393    if (texture_idx == -1) {
1394       /* No derefs, must be using indices */
1395       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->texture_index);
1396 
1397       /* Already have a texture here */
1398       if (bare_sampler)
1399          return false;
1400 
1401       nir_variable *typed_sampler = NULL;
1402       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1403          if (var->data.binding <= tex->texture_index &&
1404              var->data.binding + glsl_type_get_texture_count(var->type) > tex->texture_index) {
1405             /* Already have a texture for this binding, add it to the table */
1406             _mesa_hash_table_u64_insert(data, tex->texture_index, var);
1407             return false;
1408          }
1409 
1410          if (var->data.binding <= tex->texture_index &&
1411              var->data.binding + glsl_type_get_sampler_count(var->type) > tex->texture_index &&
1412              !glsl_type_is_bare_sampler(glsl_without_array(var->type))) {
1413             typed_sampler = var;
1414          }
1415       }
1416 
1417       /* Clone the typed sampler to a texture and we're done */
1418       assert(typed_sampler);
1419       bare_sampler = nir_variable_clone(typed_sampler, b->shader);
1420       bare_sampler->type = get_textures_for_sampler_type(typed_sampler->type);
1421       nir_shader_add_variable(b->shader, bare_sampler);
1422       _mesa_hash_table_u64_insert(data, tex->texture_index, bare_sampler);
1423       return true;
1424    }
1425 
1426    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1427    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[texture_idx].src);
1428    nir_deref_path path;
1429    nir_deref_path_init(&path, final_deref, NULL);
1430 
1431    nir_deref_instr *old_tail = path.path[0];
1432    assert(old_tail->deref_type == nir_deref_type_var);
1433    nir_variable *old_var = old_tail->var;
1434    if (glsl_type_is_texture(glsl_without_array(old_var->type)) ||
1435        glsl_type_is_image(glsl_without_array(old_var->type))) {
1436       nir_deref_path_finish(&path);
1437       return false;
1438    }
1439 
1440    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1441                       old_var->data.binding;
1442    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1443    if (!new_var) {
1444       new_var = nir_variable_clone(old_var, b->shader);
1445       new_var->type = get_textures_for_sampler_type(old_var->type);
1446       nir_shader_add_variable(b->shader, new_var);
1447       _mesa_hash_table_u64_insert(data, var_key, new_var);
1448    }
1449 
1450    b->cursor = nir_after_instr(&old_tail->instr);
1451    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1452 
1453    for (unsigned i = 1; path.path[i]; ++i) {
1454       b->cursor = nir_after_instr(&path.path[i]->instr);
1455       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1456    }
1457 
1458    nir_deref_path_finish(&path);
1459    nir_src_rewrite(&tex->src[texture_idx].src, &new_tail->def);
1460 
1461    return true;
1462 }
1463 
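/* DXIL binds textures and samplers separately, so split every combined
 * texture/sampler variable into a texture variable plus a bare-sampler variable and
 * redirect all tex-instruction sources accordingly. */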
1464 bool
1465 dxil_nir_split_typed_samplers(nir_shader *nir)
1466 {
1467    struct hash_table_u64 *hash_table = _mesa_hash_table_u64_create(NULL);
1468 
1469    bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs,
1470       nir_metadata_control_flow | nir_metadata_loop_analysis, hash_table);
1471 
1472    _mesa_hash_table_u64_clear(hash_table);
1473 
1474    progress |= nir_shader_instructions_pass(nir, redirect_texture_derefs,
1475       nir_metadata_control_flow | nir_metadata_loop_analysis, hash_table);
1476 
1477    _mesa_hash_table_u64_destroy(hash_table);
1478    return progress;
1479 }
1480 
1481 
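/* Replace load_instance_id / load_vertex_id_zero_base with a plain load_input from
 * the corresponding variable in sysval_vars, so the value flows through the regular
 * input signature. */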
1482 static bool
1483 lower_sysval_to_load_input_impl(nir_builder *b, nir_intrinsic_instr *intr,
1484                                 void *data)
1485 {
1486    gl_system_value sysval = SYSTEM_VALUE_MAX;
1487    switch (intr->intrinsic) {
1488    case nir_intrinsic_load_instance_id:
1489       sysval = SYSTEM_VALUE_INSTANCE_ID;
1490       break;
1491    case nir_intrinsic_load_vertex_id_zero_base:
1492       sysval = SYSTEM_VALUE_VERTEX_ID_ZERO_BASE;
1493       break;
1494    default:
1495       return false;
1496    }
1497 
1498    nir_variable **sysval_vars = (nir_variable **)data;
1499    nir_variable *var = sysval_vars[sysval];
1500    assert(var);
1501 
1502    const nir_alu_type dest_type = nir_get_nir_type_for_glsl_type(var->type);
1503    const unsigned bit_size = intr->def.bit_size;
1504 
1505    b->cursor = nir_before_instr(&intr->instr);
1506    nir_def *result = nir_load_input(b, intr->def.num_components, bit_size, nir_imm_int(b, 0),
1507       .base = var->data.driver_location, .dest_type = dest_type);
1508 
1509    nir_def_rewrite_uses(&intr->def, result);
1510    return true;
1511 }
1512 
1513 bool
1514 dxil_nir_lower_sysval_to_load_input(nir_shader *s, nir_variable **sysval_vars)
1515 {
1516    return nir_shader_intrinsics_pass(s, lower_sysval_to_load_input_impl,
1517                                      nir_metadata_control_flow,
1518                                      sysval_vars);
1519 }
1520 
1521 /* Comparison function to sort I/O variables so that normal varyings come
1522  * first, then system values, and then system-generated values.
1523  */
1524 static int
1525 variable_location_cmp(const nir_variable* a, const nir_variable* b)
1526 {
1527    // Sort by stream, driver_location, location, location_frac, then index
1528    // If all else is equal, sort full vectors before partial ones
1529    unsigned a_location = a->data.location;
1530    if (a_location >= VARYING_SLOT_PATCH0)
1531       a_location -= VARYING_SLOT_PATCH0;
1532    unsigned b_location = b->data.location;
1533    if (b_location >= VARYING_SLOT_PATCH0)
1534       b_location -= VARYING_SLOT_PATCH0;
1535    unsigned a_stream = a->data.stream & ~NIR_STREAM_PACKED;
1536    unsigned b_stream = b->data.stream & ~NIR_STREAM_PACKED;
1537    return a_stream != b_stream ?
1538             a_stream - b_stream :
1539             a->data.driver_location != b->data.driver_location ?
1540                a->data.driver_location - b->data.driver_location :
1541                a_location !=  b_location ?
1542                   a_location - b_location :
1543                   a->data.location_frac != b->data.location_frac ?
1544                      a->data.location_frac - b->data.location_frac :
1545                      a->data.index != b->data.index ?
1546                         a->data.index - b->data.index :
1547                         glsl_get_component_slots(b->type) - glsl_get_component_slots(a->type);
1548 }
1549 
1550 /* Order varyings according to driver location */
1551 void
1552 dxil_sort_by_driver_location(nir_shader* s, nir_variable_mode modes)
1553 {
1554    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1555 }
1556 
1557 /* Sort PS outputs so that color outputs come first */
1558 void
1559 dxil_sort_ps_outputs(nir_shader* s)
1560 {
1561    nir_foreach_variable_with_modes_safe(var, s, nir_var_shader_out) {
1562       /* We use the driver_location here to avoid introducing a new
1563        * struct or member variable. The true, updated driver location
1564        * will be written below, after sorting. */
1565       switch (var->data.location) {
1566       case FRAG_RESULT_DEPTH:
1567          var->data.driver_location = 1;
1568          break;
1569       case FRAG_RESULT_STENCIL:
1570          var->data.driver_location = 2;
1571          break;
1572       case FRAG_RESULT_SAMPLE_MASK:
1573          var->data.driver_location = 3;
1574          break;
1575       default:
1576          var->data.driver_location = 0;
1577       }
1578    }
1579 
1580    nir_sort_variables_with_modes(s, variable_location_cmp,
1581                                  nir_var_shader_out);
1582 
1583    unsigned driver_loc = 0;
1584    nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
1585       /* Fractional vars should use the same driver_location as the base. These will
1586        * get fully merged during signature processing.
1587        */
1588       var->data.driver_location = var->data.location_frac ? driver_loc - 1 : driver_loc++;
1589    }
1590 }
1591 
1592 enum dxil_sysvalue_type {
1593    DXIL_NO_SYSVALUE = 0,
1594    DXIL_USED_SYSVALUE,
1595    DXIL_UNUSED_NO_SYSVALUE,
1596    DXIL_SYSVALUE,
1597    DXIL_GENERATED_SYSVALUE,
1598 };
1599 
1600 static enum dxil_sysvalue_type
1601 nir_var_to_dxil_sysvalue_type(nir_variable *var, uint64_t other_stage_mask,
1602                               const BITSET_WORD *other_stage_frac_mask)
1603 {
1604    switch (var->data.location) {
1605    case VARYING_SLOT_FACE:
1606       return DXIL_GENERATED_SYSVALUE;
1607    case VARYING_SLOT_POS:
1608    case VARYING_SLOT_PRIMITIVE_ID:
1609    case VARYING_SLOT_CLIP_DIST0:
1610    case VARYING_SLOT_CLIP_DIST1:
1611    case VARYING_SLOT_PSIZ:
1612    case VARYING_SLOT_TESS_LEVEL_INNER:
1613    case VARYING_SLOT_TESS_LEVEL_OUTER:
1614    case VARYING_SLOT_VIEWPORT:
1615    case VARYING_SLOT_LAYER:
1616    case VARYING_SLOT_VIEW_INDEX:
1617       if (!((1ull << var->data.location) & other_stage_mask))
1618          return DXIL_SYSVALUE;
1619       return DXIL_USED_SYSVALUE;
1620    default:
1621       if (var->data.location < VARYING_SLOT_PATCH0 &&
1622           !((1ull << var->data.location) & other_stage_mask))
1623          return DXIL_UNUSED_NO_SYSVALUE;
1624       if (var->data.location_frac && other_stage_frac_mask &&
1625           var->data.location >= VARYING_SLOT_VAR0 &&
1626           !BITSET_TEST(other_stage_frac_mask, ((var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac)))
1627          return DXIL_UNUSED_NO_SYSVALUE;
1628       return DXIL_NO_SYSVALUE;
1629    }
1630 }
1631 
1632 /* Order inter-stage values so that normal varyings come first,
1633  * then system values, and then system-generated values.
1634  */
1635 void
1636 dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
1637    uint64_t other_stage_mask, const BITSET_WORD *other_stage_frac_mask)
1638 {
1639    nir_foreach_variable_with_modes_safe(var, s, modes) {
1640       /* We use the driver_location here to avoid introducing a new
1641        * struct or member variable. The true, updated driver location
1642        * will be written below, after sorting. */
1643       var->data.driver_location = nir_var_to_dxil_sysvalue_type(var, other_stage_mask, other_stage_frac_mask);
1644    }
1645 
1646    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1647 
1648    unsigned driver_loc = 0, driver_patch_loc = 0;
1649    nir_foreach_variable_with_modes(var, s, modes) {
1650       /* Overlap patches with non-patch */
1651       unsigned *loc = var->data.patch ? &driver_patch_loc : &driver_loc;
1652       var->data.driver_location = *loc;
1653 
1654       const struct glsl_type *type = var->type;
1655       if (nir_is_arrayed_io(var, s->info.stage) && glsl_type_is_array(type))
1656          type = glsl_get_array_element(type);
1657       *loc += glsl_count_vec4_slots(type, false, false);
1658    }
1659 }
1660 
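/* A vulkan_resource_index into a UBO array of size 1 can only validly address
 * element 0, so rewrite any dynamic (or non-zero constant) index to a constant 0,
 * making the descriptor statically addressable. */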
1661 static bool
1662 lower_ubo_array_one_to_static(struct nir_builder *b,
1663                               nir_intrinsic_instr *intrin,
1664                               void *cb_data)
1665 {
1666    if (intrin->intrinsic != nir_intrinsic_load_vulkan_descriptor)
1667       return false;
1668 
1669    nir_variable *var =
1670       nir_get_binding_variable(b->shader, nir_chase_binding(intrin->src[0]));
1671 
1672    if (!var)
1673       return false;
1674 
1675    if (!glsl_type_is_array(var->type) || glsl_array_size(var->type) != 1)
1676       return false;
1677 
1678    nir_intrinsic_instr *index = nir_src_as_intrinsic(intrin->src[0]);
1679    /* We currently do not support reindex */
1680    assert(index && index->intrinsic == nir_intrinsic_vulkan_resource_index);
1681 
1682    if (nir_src_is_const(index->src[0]) && nir_src_as_uint(index->src[0]) == 0)
1683       return false;
1684 
1685    if (nir_intrinsic_desc_type(index) != VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
1686       return false;
1687 
1688    b->cursor = nir_instr_remove(&index->instr);
1689 
1690    // Indexing out of bounds on an array of UBOs is considered undefined
1691    // behavior. Therefore, we just hardcode the index to 0.
1692    uint8_t bit_size = index->def.bit_size;
1693    nir_def *zero = nir_imm_intN_t(b, 0, bit_size);
1694    nir_def *dest =
1695       nir_vulkan_resource_index(b, index->num_components, bit_size, zero,
1696                                 .desc_set = nir_intrinsic_desc_set(index),
1697                                 .binding = nir_intrinsic_binding(index),
1698                                 .desc_type = nir_intrinsic_desc_type(index));
1699 
1700    nir_def_rewrite_uses(&index->def, dest);
1701 
1702    return true;
1703 }
1704 
1705 bool
1706 dxil_nir_lower_ubo_array_one_to_static(nir_shader *s)
1707 {
1708    bool progress = nir_shader_intrinsics_pass(s,
1709                                               lower_ubo_array_one_to_static,
1710                                               nir_metadata_none, NULL);
1711 
1712    return progress;
1713 }
1714 
1715 static bool
1716 is_fquantize2f16(const nir_instr *instr, const void *data)
1717 {
1718    if (instr->type != nir_instr_type_alu)
1719       return false;
1720 
1721    nir_alu_instr *alu = nir_instr_as_alu(instr);
1722    return alu->op == nir_op_fquantize2f16;
1723 }
1724 
1725 static nir_def *
1726 lower_fquantize2f16(struct nir_builder *b, nir_instr *instr, void *data)
1727 {
1728    /*
1729     * SpvOpQuantizeToF16 documentation says:
1730     *
1731     * "
1732     * If Value is an infinity, the result is the same infinity.
1733     * If Value is a NaN, the result is a NaN, but not necessarily the same NaN.
1734     * If Value is positive with a magnitude too large to represent as a 16-bit
1735     * floating-point value, the result is positive infinity. If Value is negative
1736     * with a magnitude too large to represent as a 16-bit floating-point value,
1737     * the result is negative infinity. If the magnitude of Value is too small to
1738     * represent as a normalized 16-bit floating-point value, the result may be
1739     * either +0 or -0.
1740     * "
1741     *
1742     * which we turn into:
1743     *
1744     *   if (val < MIN_FLOAT16)
1745     *      return -INFINITY;
1746     *   else if (val > MAX_FLOAT16)
1747     *      return +INFINITY;
1748     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) != 0)
1749     *      return -0.0f;
1750     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) == 0)
1751     *      return +0.0f;
1752     *   else
1753     *      return round(val);
1754     */
1755    nir_alu_instr *alu = nir_instr_as_alu(instr);
1756    nir_def *src =
1757       alu->src[0].src.ssa;
1758 
1759    nir_def *neg_inf_cond =
1760       nir_flt_imm(b, src, -65504.0f);
1761    nir_def *pos_inf_cond =
1762       nir_fgt_imm(b, src, 65504.0f);
1763    nir_def *zero_cond =
1764       nir_flt_imm(b, nir_fabs(b, src), ldexpf(1.0, -14));
1765    nir_def *zero = nir_iand_imm(b, src, 1 << 31);
1766    nir_def *round = nir_iand_imm(b, src, ~BITFIELD_MASK(13));
1767 
1768    nir_def *res =
1769       nir_bcsel(b, neg_inf_cond, nir_imm_float(b, -INFINITY), round);
1770    res = nir_bcsel(b, pos_inf_cond, nir_imm_float(b, INFINITY), res);
1771    res = nir_bcsel(b, zero_cond, zero, res);
1772    return res;
1773 }
1774 
1775 bool
1776 dxil_nir_lower_fquantize2f16(nir_shader *s)
1777 {
1778    return nir_shader_lower_instructions(s, is_fquantize2f16, lower_fquantize2f16, NULL);
1779 }
1780 
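/* Derefs cache the dereferenced type; when fix_io_uint_type retypes a variable from
 * int to uint, every deref chain rooted at that variable has to be patched too. */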
1781 static bool
1782 fix_io_uint_deref_types(struct nir_builder *builder, nir_instr *instr, void *data)
1783 {
1784    if (instr->type != nir_instr_type_deref)
1785       return false;
1786 
1787    nir_deref_instr *deref = nir_instr_as_deref(instr);
1788    nir_variable *var = nir_deref_instr_get_variable(deref);
1789 
1790    if (var == data) {
1791       deref->type = glsl_type_wrap_in_arrays(glsl_uint_type(), deref->type);
1792       return true;
1793    }
1794 
1795    return false;
1796 }
1797 
1798 static bool
1799 fix_io_uint_type(nir_shader *s, nir_variable_mode modes, int slot)
1800 {
1801    nir_variable *fixed_var = NULL;
1802    nir_foreach_variable_with_modes(var, s, modes) {
1803       if (var->data.location == slot) {
1804          const struct glsl_type *plain_type = glsl_without_array(var->type);
1805          if (plain_type == glsl_uint_type())
1806             return false;
1807 
1808          assert(plain_type == glsl_int_type());
1809          var->type = glsl_type_wrap_in_arrays(glsl_uint_type(), var->type);
1810          fixed_var = var;
1811          break;
1812       }
1813    }
1814 
1815    assert(fixed_var);
1816 
1817    return nir_shader_instructions_pass(s, fix_io_uint_deref_types,
1818                                        nir_metadata_all, fixed_var);
1819 }
1820 
1821 bool
1822 dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mask)
1823 {
1824    if (!(s->info.outputs_written & out_mask) &&
1825        !(s->info.inputs_read & in_mask))
1826       return false;
1827 
1828    bool progress = false;
1829 
1830    while (in_mask) {
1831       int slot = u_bit_scan64(&in_mask);
1832       progress |= (s->info.inputs_read & (1ull << slot)) &&
1833                   fix_io_uint_type(s, nir_var_shader_in, slot);
1834    }
1835 
1836    while (out_mask) {
1837       int slot = u_bit_scan64(&out_mask);
1838       progress |= (s->info.outputs_written & (1ull << slot)) &&
1839                   fix_io_uint_type(s, nir_var_shader_out, slot);
1840    }
1841 
1842    return progress;
1843 }
1844 
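/* Lower terminate/terminate_if to a demote followed by a return from the (fully
 * inlined) entrypoint, so that no further side effects can occur for the terminated
 * pixel. */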
1845 static bool
1846 lower_kill(struct nir_builder *builder, nir_intrinsic_instr *intr,
1847            void *_cb_data)
1848 {
1849    if (intr->intrinsic != nir_intrinsic_terminate &&
1850        intr->intrinsic != nir_intrinsic_terminate_if)
1851       return false;
1852 
1853    builder->cursor = nir_instr_remove(&intr->instr);
1854    nir_def *condition;
1855 
1856    if (intr->intrinsic == nir_intrinsic_terminate) {
1857       nir_demote(builder);
1858       condition = nir_imm_true(builder);
1859    } else {
1860       nir_demote_if(builder, intr->src[0].ssa);
1861       condition = intr->src[0].ssa;
1862    }
1863 
1864    /* Create a new block by branching on the discard condition so that this return
1865     * is definitely the last instruction in its own block */
1866    nir_if *nif = nir_push_if(builder, condition);
1867    nir_jump(builder, nir_jump_return);
1868    nir_pop_if(builder, nif);
1869 
1870    return true;
1871 }
1872 
1873 bool
1874 dxil_nir_lower_discard_and_terminate(nir_shader *s)
1875 {
1876    if (s->info.stage != MESA_SHADER_FRAGMENT)
1877       return false;
1878 
1879    // This pass only works if all functions have been inlined
1880    assert(exec_list_length(&s->functions) == 1);
1881    return nir_shader_intrinsics_pass(s, lower_kill, nir_metadata_none, NULL);
1882 }
1883 
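/* Pad partial stores to VARYING_SLOT_POS with zeros so the position output is always
 * written as a full xyzw vector. */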
1884 static bool
1885 update_writes(struct nir_builder *b, nir_intrinsic_instr *intr, void *_state)
1886 {
1887    if (intr->intrinsic != nir_intrinsic_store_output)
1888       return false;
1889 
1890    nir_io_semantics io = nir_intrinsic_io_semantics(intr);
1891    if (io.location != VARYING_SLOT_POS)
1892       return false;
1893 
1894    nir_def *src = intr->src[0].ssa;
1895    unsigned write_mask = nir_intrinsic_write_mask(intr);
1896    if (src->num_components == 4 && write_mask == 0xf)
1897       return false;
1898 
1899    b->cursor = nir_before_instr(&intr->instr);
1900    unsigned first_comp = nir_intrinsic_component(intr);
1901    nir_def *channels[4] = { NULL, NULL, NULL, NULL };
1902    assert(first_comp + src->num_components <= ARRAY_SIZE(channels));
1903    for (unsigned i = 0; i < src->num_components; ++i)
1904       if (write_mask & (1 << i))
1905          channels[i + first_comp] = nir_channel(b, src, i);
1906    for (unsigned i = 0; i < 4; ++i)
1907       if (!channels[i])
1908          channels[i] = nir_imm_intN_t(b, 0, src->bit_size);
1909 
1910    intr->num_components = 4;
1911    nir_src_rewrite(&intr->src[0], nir_vec(b, channels, 4));
1912    nir_intrinsic_set_component(intr, 0);
1913    nir_intrinsic_set_write_mask(intr, 0xf);
1914    return true;
1915 }
1916 
1917 bool
1918 dxil_nir_ensure_position_writes(nir_shader *s)
1919 {
1920    if (s->info.stage != MESA_SHADER_VERTEX &&
1921        s->info.stage != MESA_SHADER_GEOMETRY &&
1922        s->info.stage != MESA_SHADER_TESS_EVAL)
1923       return false;
1924    if ((s->info.outputs_written & VARYING_BIT_POS) == 0)
1925       return false;
1926 
1927    return nir_shader_intrinsics_pass(s, update_writes,
1928                                        nir_metadata_control_flow,
1929                                        NULL);
1930 }
1931 
1932 static bool
1933 is_sample_pos(const nir_instr *instr, const void *_data)
1934 {
1935    if (instr->type != nir_instr_type_intrinsic)
1936       return false;
1937    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1938    return intr->intrinsic == nir_intrinsic_load_sample_pos;
1939 }
1940 
1941 static nir_def *
1942 lower_sample_pos(nir_builder *b, nir_instr *instr, void *_data)
1943 {
1944    return nir_load_sample_pos_from_id(b, 32, nir_load_sample_id(b));
1945 }
1946 
1947 bool
1948 dxil_nir_lower_sample_pos(nir_shader *s)
1949 {
1950    return nir_shader_lower_instructions(s, is_sample_pos, lower_sample_pos, NULL);
1951 }
1952 
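/* Provide load_subgroup_id: for Nx1x1 workgroups it is derived directly from the
 * flat local invocation index; otherwise one lane per subgroup atomically claims an
 * ID from a shared counter (zeroed by invocation 0 behind a barrier) and broadcasts
 * it to the rest of its subgroup. */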
1953 static bool
1954 lower_subgroup_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1955 {
1956    if (intr->intrinsic != nir_intrinsic_load_subgroup_id)
1957       return false;
1958 
1959    b->cursor = nir_before_impl(b->impl);
1960    if (b->shader->info.stage == MESA_SHADER_COMPUTE &&
1961        b->shader->info.workgroup_size[1] == 1 &&
1962        b->shader->info.workgroup_size[2] == 1) {
1963       /* When using Nx1x1 groups, use a simple stable algorithm
1964        * which is almost guaranteed to be correct. */
1965       nir_def *subgroup_id = nir_udiv(b, nir_load_local_invocation_index(b), nir_load_subgroup_size(b));
1966       nir_def_rewrite_uses(&intr->def, subgroup_id);
1967       return true;
1968    }
1969 
1970    nir_def **subgroup_id = (nir_def **)data;
1971    if (*subgroup_id == NULL) {
1972       nir_variable *subgroup_id_counter = nir_variable_create(b->shader, nir_var_mem_shared, glsl_uint_type(), "dxil_SubgroupID_counter");
1973       nir_variable *subgroup_id_local = nir_local_variable_create(b->impl, glsl_uint_type(), "dxil_SubgroupID_local");
1974       nir_store_var(b, subgroup_id_local, nir_imm_int(b, 0), 1);
1975 
1976       nir_deref_instr *counter_deref = nir_build_deref_var(b, subgroup_id_counter);
1977       nir_def *tid = nir_load_local_invocation_index(b);
1978       nir_if *nif = nir_push_if(b, nir_ieq_imm(b, tid, 0));
1979       nir_store_deref(b, counter_deref, nir_imm_int(b, 0), 1);
1980       nir_pop_if(b, nif);
1981 
1982       nir_barrier(b,
1983                          .execution_scope = SCOPE_WORKGROUP,
1984                          .memory_scope = SCOPE_WORKGROUP,
1985                          .memory_semantics = NIR_MEMORY_ACQ_REL,
1986                          .memory_modes = nir_var_mem_shared);
1987 
1988       nif = nir_push_if(b, nir_elect(b, 1));
1989       nir_def *subgroup_id_first_thread = nir_deref_atomic(b, 32, &counter_deref->def, nir_imm_int(b, 1),
1990                                                                .atomic_op = nir_atomic_op_iadd);
1991       nir_store_var(b, subgroup_id_local, subgroup_id_first_thread, 1);
1992       nir_pop_if(b, nif);
1993 
1994       nir_def *subgroup_id_loaded = nir_load_var(b, subgroup_id_local);
1995       *subgroup_id = nir_read_first_invocation(b, subgroup_id_loaded);
1996    }
1997    nir_def_rewrite_uses(&intr->def, *subgroup_id);
1998    return true;
1999 }
2000 
2001 bool
2002 dxil_nir_lower_subgroup_id(nir_shader *s)
2003 {
2004    nir_def *subgroup_id = NULL;
2005    return nir_shader_intrinsics_pass(s, lower_subgroup_id, nir_metadata_none,
2006                                      &subgroup_id);
2007 }
2008 
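/* num_subgroups = ceil(workgroup_size.x * .y * .z / subgroup_size), computed here as
 * (total + subgroup_size - 1) / subgroup_size. */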
2009 static bool
2010 lower_num_subgroups(nir_builder *b, nir_intrinsic_instr *intr, void *data)
2011 {
2012    if (intr->intrinsic != nir_intrinsic_load_num_subgroups)
2013       return false;
2014 
2015    b->cursor = nir_before_instr(&intr->instr);
2016    nir_def *subgroup_size = nir_load_subgroup_size(b);
2017    nir_def *size_minus_one = nir_iadd_imm(b, subgroup_size, -1);
2018    nir_def *workgroup_size_vec = nir_load_workgroup_size(b);
2019    nir_def *workgroup_size = nir_imul(b, nir_channel(b, workgroup_size_vec, 0),
2020                                              nir_imul(b, nir_channel(b, workgroup_size_vec, 1),
2021                                                          nir_channel(b, workgroup_size_vec, 2)));
2022    nir_def *ret = nir_idiv(b, nir_iadd(b, workgroup_size, size_minus_one), subgroup_size);
2023    nir_def_rewrite_uses(&intr->def, ret);
2024    return true;
2025 }
2026 
2027 bool
2028 dxil_nir_lower_num_subgroups(nir_shader *s)
2029 {
2030    return nir_shader_intrinsics_pass(s, lower_num_subgroups,
2031                                        nir_metadata_control_flow |
2032                                        nir_metadata_loop_analysis, NULL);
2033 }
2034 
2035 
2036 static const struct glsl_type *
2037 get_cast_type(unsigned bit_size)
2038 {
2039    switch (bit_size) {
2040    case 64:
2041       return glsl_int64_t_type();
2042    case 32:
2043       return glsl_int_type();
2044    case 16:
2045       return glsl_int16_t_type();
2046    case 8:
2047       return glsl_int8_t_type();
2048    }
2049    unreachable("Invalid bit_size");
2050 }
2051 
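/* Split an under-aligned load_deref into 'alignment'-byte loads through a cast deref
 * of a matching integer type, then reassemble the original value with
 * nir_extract_bits. */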
2052 static void
2053 split_unaligned_load(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
2054 {
2055    enum gl_access_qualifier access = nir_intrinsic_access(intrin);
2056    nir_def *srcs[NIR_MAX_VEC_COMPONENTS * NIR_MAX_VEC_COMPONENTS * sizeof(int64_t) / 8];
2057    unsigned comp_size = intrin->def.bit_size / 8;
2058    unsigned num_comps = intrin->def.num_components;
2059 
2060    b->cursor = nir_before_instr(&intrin->instr);
2061 
2062    nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);
2063 
2064    const struct glsl_type *cast_type = get_cast_type(alignment * 8);
2065    nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->def, ptr->modes, cast_type, alignment);
2066 
2067    unsigned num_loads = DIV_ROUND_UP(comp_size * num_comps, alignment);
2068    for (unsigned i = 0; i < num_loads; ++i) {
2069       nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->def.bit_size));
2070       srcs[i] = nir_load_deref_with_access(b, elem, access);
2071    }
2072 
2073    nir_def *new_dest = nir_extract_bits(b, srcs, num_loads, 0, num_comps, intrin->def.bit_size);
2074    nir_def_replace(&intrin->def, new_dest);
2075 }
2076 
2077 static void
2078 split_unaligned_store(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
2079 {
2080    enum gl_access_qualifier access = nir_intrinsic_access(intrin);
2081 
2082    nir_def *value = intrin->src[1].ssa;
2083    unsigned comp_size = value->bit_size / 8;
2084    unsigned num_comps = value->num_components;
2085 
2086    b->cursor = nir_before_instr(&intrin->instr);
2087 
2088    nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);
2089 
2090    const struct glsl_type *cast_type = get_cast_type(alignment * 8);
2091    nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->def, ptr->modes, cast_type, alignment);
2092 
2093    unsigned num_stores = DIV_ROUND_UP(comp_size * num_comps, alignment);
2094    for (unsigned i = 0; i < num_stores; ++i) {
2095       nir_def *substore_val = nir_extract_bits(b, &value, 1, i * alignment * 8, 1, alignment * 8);
2096       nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->def.bit_size));
2097       nir_store_deref_with_access(b, elem, substore_val, ~0, access);
2098    }
2099 
2100    nir_instr_remove(&intrin->instr);
2101 }
2102 
2103 bool
2104 dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes)
2105 {
2106    bool progress = false;
2107 
2108    nir_foreach_function_impl(impl, shader) {
2109       nir_builder b = nir_builder_create(impl);
2110 
2111       nir_foreach_block(block, impl) {
2112          nir_foreach_instr_safe(instr, block) {
2113             if (instr->type != nir_instr_type_intrinsic)
2114                continue;
2115             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2116             if (intrin->intrinsic != nir_intrinsic_load_deref &&
2117                 intrin->intrinsic != nir_intrinsic_store_deref)
2118                continue;
2119             nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
2120             if (!nir_deref_mode_may_be(deref, modes))
2121                continue;
2122 
2123             unsigned align_mul = 0, align_offset = 0;
2124             nir_get_explicit_deref_align(deref, true, &align_mul, &align_offset);
2125 
2126             unsigned alignment = align_offset ? 1 << (ffs(align_offset) - 1) : align_mul;
2127 
2128             /* We can load anything at 4-byte alignment, except for
2129              * UBOs (AKA CBs where the granularity is 16 bytes).
2130              */
2131             unsigned req_align = (nir_deref_mode_is_one_of(deref, nir_var_mem_ubo | nir_var_mem_push_const) ? 16 : 4);
2132             if (alignment >= req_align)
2133                continue;
2134 
2135             nir_def *val;
2136             if (intrin->intrinsic == nir_intrinsic_load_deref) {
2137                val = &intrin->def;
2138             } else {
2139                val = intrin->src[1].ssa;
2140             }
2141 
2142             unsigned scalar_byte_size = glsl_type_is_boolean(deref->type) ? 4 : glsl_get_bit_size(deref->type) / 8;
2143             unsigned num_components =
2144                /* If the vector stride is larger than the scalar size, lower_explicit_io will
2145                 * turn this into multiple scalar loads anyway, so we don't have to split it here. */
2146                glsl_get_explicit_stride(deref->type) > scalar_byte_size ? 1 :
2147                (val->num_components == 3 ? 4 : val->num_components);
2148             unsigned natural_alignment = scalar_byte_size * num_components;
2149 
2150             if (alignment >= natural_alignment)
2151                continue;
2152 
2153             if (intrin->intrinsic == nir_intrinsic_load_deref)
2154                split_unaligned_load(&b, intrin, alignment);
2155             else
2156                split_unaligned_store(&b, intrin, alignment);
2157             progress = true;
2158          }
2159       }
2160    }
2161 
2162    return progress;
2163 }
2164 
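/* An inclusive scan is an exclusive scan followed by one more application of the
 * reduction op with the invocation's own input value. */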
2165 static void
2166 lower_inclusive_to_exclusive(nir_builder *b, nir_intrinsic_instr *intr)
2167 {
2168    b->cursor = nir_after_instr(&intr->instr);
2169 
2170    nir_op op = nir_intrinsic_reduction_op(intr);
2171    intr->intrinsic = nir_intrinsic_exclusive_scan;
2172    nir_intrinsic_set_reduction_op(intr, op);
2173 
2174    nir_def *final_val = nir_build_alu2(b, nir_intrinsic_reduction_op(intr),
2175                                            &intr->def, intr->src[0].ssa);
2176    nir_def_rewrite_uses_after(&intr->def, final_val, final_val->parent_instr);
2177 }
2178 
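/* Scans that have no direct DXIL wave intrinsic are emulated with a loop over the
 * subgroup: each iteration reads one lane's value via read_invocation and folds it
 * into a local accumulator if that lane is active and in range for the scan. */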
2179 static bool
2180 lower_subgroup_scan(nir_builder *b, nir_intrinsic_instr *intr, void *data)
2181 {
2182    switch (intr->intrinsic) {
2183    case nir_intrinsic_exclusive_scan:
2184    case nir_intrinsic_inclusive_scan:
2185       switch ((nir_op)nir_intrinsic_reduction_op(intr)) {
2186       case nir_op_iadd:
2187       case nir_op_fadd:
2188       case nir_op_imul:
2189       case nir_op_fmul:
2190          if (intr->intrinsic == nir_intrinsic_exclusive_scan)
2191             return false;
2192          lower_inclusive_to_exclusive(b, intr);
2193          return true;
2194       default:
2195          break;
2196       }
2197       break;
2198    default:
2199       return false;
2200    }
2201 
2202    b->cursor = nir_before_instr(&intr->instr);
2203    nir_op op = nir_intrinsic_reduction_op(intr);
2204    nir_def *subgroup_id = nir_load_subgroup_invocation(b);
2205    nir_def *subgroup_size = nir_load_subgroup_size(b);
2206    nir_def *active_threads = nir_ballot(b, 4, 32, nir_imm_true(b));
2207    nir_def *base_value;
2208    uint32_t bit_size = intr->def.bit_size;
2209    if (op == nir_op_iand || op == nir_op_umin)
2210       base_value = nir_imm_intN_t(b, ~0ull, bit_size);
2211    else if (op == nir_op_imin)
2212       base_value = nir_imm_intN_t(b, (1ull << (bit_size - 1)) - 1, bit_size);
2213    else if (op == nir_op_imax)
2214       base_value = nir_imm_intN_t(b, 1ull << (bit_size - 1), bit_size);
2215    else if (op == nir_op_fmax)
2216       base_value = nir_imm_floatN_t(b, -INFINITY, bit_size);
2217    else if (op == nir_op_fmin)
2218       base_value = nir_imm_floatN_t(b, INFINITY, bit_size);
2219    else
2220       base_value = nir_imm_intN_t(b, 0, bit_size);
2221 
2222    nir_variable *loop_counter_var = nir_local_variable_create(b->impl, glsl_uint_type(), "subgroup_loop_counter");
2223    nir_variable *result_var = nir_local_variable_create(b->impl,
2224                                                         glsl_vector_type(nir_get_glsl_base_type_for_nir_type(
2225                                                            nir_op_infos[op].input_types[0] | bit_size), 1),
2226                                                         "subgroup_loop_result");
2227    nir_store_var(b, loop_counter_var, nir_imm_int(b, 0), 1);
2228    nir_store_var(b, result_var, base_value, 1);
2229    nir_loop *loop = nir_push_loop(b);
2230    nir_def *loop_counter = nir_load_var(b, loop_counter_var);
2231 
2232    nir_if *nif = nir_push_if(b, nir_ilt(b, loop_counter, subgroup_size));
2233    nir_def *other_thread_val = nir_read_invocation(b, intr->src[0].ssa, loop_counter);
2234    nir_def *thread_in_range = intr->intrinsic == nir_intrinsic_inclusive_scan ?
2235       nir_ige(b, subgroup_id, loop_counter) :
2236       nir_ilt(b, loop_counter, subgroup_id);
2237    nir_def *thread_active = nir_ballot_bitfield_extract(b, 1, active_threads, loop_counter);
2238 
2239    nir_if *if_active_thread = nir_push_if(b, nir_iand(b, thread_in_range, thread_active));
2240    nir_def *result = nir_build_alu2(b, op, nir_load_var(b, result_var), other_thread_val);
2241    nir_store_var(b, result_var, result, 1);
2242    nir_pop_if(b, if_active_thread);
2243 
2244    nir_store_var(b, loop_counter_var, nir_iadd_imm(b, loop_counter, 1), 1);
2245    nir_jump(b, nir_jump_continue);
2246    nir_pop_if(b, nif);
2247 
2248    nir_jump(b, nir_jump_break);
2249    nir_pop_loop(b, loop);
2250 
2251    result = nir_load_var(b, result_var);
2252    nir_def_rewrite_uses(&intr->def, result);
2253    return true;
2254 }
2255 
2256 bool
2257 dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s)
2258 {
2259    bool ret = nir_shader_intrinsics_pass(s, lower_subgroup_scan,
2260                                          nir_metadata_none, NULL);
2261    if (ret) {
2262       /* Lower the ballot bitfield tests */
2263       nir_lower_subgroups_options options = { .ballot_bit_size = 32, .ballot_components = 4 };
2264       nir_lower_subgroups(s, &options);
2265    }
2266    return ret;
2267 }
2268 
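/* Move the front-face input from VARYING_SLOT_FACE to the generic varying
 * VARYING_SLOT_VAR12 so it can be forwarded like any other input. */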
2269 bool
2270 dxil_nir_forward_front_face(nir_shader *nir)
2271 {
2272    assert(nir->info.stage == MESA_SHADER_FRAGMENT);
2273 
2274    nir_variable *var = nir_find_variable_with_location(nir, nir_var_shader_in, VARYING_SLOT_FACE);
2275    if (var) {
2276       var->data.location = VARYING_SLOT_VAR12;
2277       return true;
2278    }
2279    return false;
2280 }
2281 
2282 static bool
2283 move_consts(nir_builder *b, nir_instr *instr, void *data)
2284 {
2285    bool progress = false;
2286    switch (instr->type) {
2287    case nir_instr_type_load_const: {
2288       /* Sink load_const instructions to their uses if there are multiple uses */
2289       nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
2290       if (!list_is_singular(&load_const->def.uses)) {
2291          nir_foreach_use_safe(src, &load_const->def) {
2292             b->cursor = nir_before_src(src);
2293             nir_load_const_instr *new_load = nir_load_const_instr_create(b->shader,
2294                                                                          load_const->def.num_components,
2295                                                                          load_const->def.bit_size);
2296             memcpy(new_load->value, load_const->value, sizeof(load_const->value[0]) * load_const->def.num_components);
2297             nir_builder_instr_insert(b, &new_load->instr);
2298             nir_src_rewrite(src, &new_load->def);
2299             progress = true;
2300          }
2301       }
2302       return progress;
2303    }
2304    default:
2305       return false;
2306    }
2307 }
2308 
2309 /* Sink all consts so that they only have a single use.
2310  * The DXIL backend will already de-dupe the constants to the
2311  * same dxil_value if they have the same type, but this allows a single constant
2312  * to have different types without bitcasts. */
2313 bool
2314 dxil_nir_move_consts(nir_shader *s)
2315 {
2316    return nir_shader_instructions_pass(s, move_consts,
2317                                        nir_metadata_control_flow,
2318                                        NULL);
2319 }
2320 
2321 static void
2322 clear_pass_flags(nir_function_impl *impl)
2323 {
2324    nir_foreach_block(block, impl) {
2325       nir_foreach_instr(instr, block) {
2326          instr->pass_flags = 0;
2327       }
2328    }
2329 }
2330 
2331 static bool
2332 add_def_to_worklist(nir_def *def, void *state)
2333 {
2334    nir_foreach_use_including_if(src, def) {
2335       if (nir_src_is_if(src)) {
2336          nir_if *nif = nir_src_parent_if(src);
2337          nir_foreach_block_in_cf_node(block, &nif->cf_node) {
2338             nir_foreach_instr(instr, block)
2339                nir_instr_worklist_push_tail(state, instr);
2340          }
2341       } else
2342          nir_instr_worklist_push_tail(state, nir_src_parent_instr(src));
2343    }
2344    return true;
2345 }
2346 
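/* Record, in input_bits, every packed input component (row * 4 + component) that
 * this load may read, based on the signature elements its base maps to. For
 * domain-shader patch constants, the caller's table pointers are redirected to the
 * patch-constant tables. */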
2347 static bool
2348 set_input_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t ***tables, const uint32_t **table_sizes)
2349 {
2350    if (intr->intrinsic == nir_intrinsic_load_view_index) {
2351       BITSET_SET(input_bits, 0);
2352       return true;
2353    }
2354 
2355    bool any_bits_set = false;
2356    nir_src *row_src = intr->intrinsic == nir_intrinsic_load_per_vertex_input ? &intr->src[1] : &intr->src[0];
2357    bool is_patch_constant = mod->shader_kind == DXIL_DOMAIN_SHADER && intr->intrinsic == nir_intrinsic_load_input;
2358    const struct dxil_signature_record *sig_rec = is_patch_constant ?
2359       &mod->patch_consts[mod->patch_mappings[nir_intrinsic_base(intr)]] :
2360       &mod->inputs[mod->input_mappings[nir_intrinsic_base(intr)]];
2361    if (is_patch_constant) {
2362       /* Redirect to the second I/O table */
2363       *tables = *tables + 1;
2364       *table_sizes = *table_sizes + 1;
2365    }
2366    for (uint32_t component = 0; component < intr->num_components; ++component) {
2367       uint32_t base_element = 0;
2368       uint32_t num_elements = sig_rec->num_elements;
2369       if (nir_src_is_const(*row_src)) {
2370          base_element = (uint32_t)nir_src_as_uint(*row_src);
2371          num_elements = 1;
2372       }
2373       for (uint32_t element = 0; element < num_elements; ++element) {
2374          uint32_t row = sig_rec->elements[element + base_element].reg;
2375          if (row == 0xffffffff)
2376             continue;
2377          BITSET_SET(input_bits, row * 4 + component + nir_intrinsic_component(intr));
2378          any_bits_set = true;
2379       }
2380    }
2381    return any_bits_set;
2382 }
2383 
2384 static bool
2385 set_output_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t **tables, const uint32_t *table_sizes)
2386 {
2387    bool any_bits_set = false;
2388    nir_src *row_src = intr->intrinsic == nir_intrinsic_store_per_vertex_output ? &intr->src[2] : &intr->src[1];
2389    bool is_patch_constant = mod->shader_kind == DXIL_HULL_SHADER && intr->intrinsic == nir_intrinsic_store_output;
2390    const struct dxil_signature_record *sig_rec = is_patch_constant ?
2391       &mod->patch_consts[mod->patch_mappings[nir_intrinsic_base(intr)]] :
2392       &mod->outputs[mod->output_mappings[nir_intrinsic_base(intr)]];
2393    for (uint32_t component = 0; component < intr->num_components; ++component) {
2394       uint32_t base_element = 0;
2395       uint32_t num_elements = sig_rec->num_elements;
2396       if (nir_src_is_const(*row_src)) {
2397          base_element = (uint32_t)nir_src_as_uint(*row_src);
2398          num_elements = 1;
2399       }
2400       for (uint32_t element = 0; element < num_elements; ++element) {
2401          uint32_t row = sig_rec->elements[element + base_element].reg;
2402          if (row == 0xffffffff)
2403             continue;
2404          uint32_t stream = sig_rec->elements[element + base_element].stream;
2405          uint32_t table_idx = is_patch_constant ? 1 : stream;
2406          uint32_t *table = tables[table_idx];
2407          uint32_t output_component = component + nir_intrinsic_component(intr);
2408          uint32_t input_component;
2409          BITSET_FOREACH_SET(input_component, input_bits, 32 * 4) {
2410             uint32_t *table_for_input_component = table + table_sizes[table_idx] * input_component;
2411             BITSET_SET(table_for_input_component, row * 4 + output_component);
2412             any_bits_set = true;
2413          }
2414       }
2415    }
2416    return any_bits_set;
2417 }
2418 
2419 static bool
2420 propagate_input_to_output_dependencies(struct dxil_module *mod, nir_intrinsic_instr *load_intr, uint32_t **tables, const uint32_t *table_sizes)
2421 {
2422    /* Which input components are being loaded by this instruction */
2423    BITSET_DECLARE(input_bits, 32 * 4) = { 0 };
2424    if (!set_input_bits(mod, load_intr, input_bits, &tables, &table_sizes))
2425       return false;
2426 
2427    nir_instr_worklist *worklist = nir_instr_worklist_create();
2428    nir_instr_worklist_push_tail(worklist, &load_intr->instr);
2429    bool any_bits_set = false;
2430    nir_foreach_instr_in_worklist(instr, worklist) {
2431       if (instr->pass_flags)
2432          continue;
2433 
2434       instr->pass_flags = 1;
2435       nir_foreach_def(instr, add_def_to_worklist, worklist);
2436       switch (instr->type) {
2437       case nir_instr_type_jump: {
2438          nir_jump_instr *jump = nir_instr_as_jump(instr);
2439          switch (jump->type) {
2440          case nir_jump_break:
2441          case nir_jump_continue: {
2442             nir_cf_node *parent = &instr->block->cf_node;
2443             while (parent->type != nir_cf_node_loop)
2444                parent = parent->parent;
2445             nir_foreach_block_in_cf_node(block, parent)
2446                nir_foreach_instr(i, block)
2447                nir_instr_worklist_push_tail(worklist, i);
2448             }
2449             break;
2450          default:
2451             unreachable("Don't expect any other jumps");
2452          }
2453          break;
2454       }
2455       case nir_instr_type_intrinsic: {
2456          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2457          switch (intr->intrinsic) {
2458          case nir_intrinsic_store_output:
2459          case nir_intrinsic_store_per_vertex_output:
2460             any_bits_set |= set_output_bits(mod, intr, input_bits, tables, table_sizes);
2461             break;
2462             /* TODO: Memory writes */
2463          default:
2464             break;
2465          }
2466          break;
2467       }
2468       default:
2469          break;
2470       }
2471    }
2472 
2473    nir_instr_worklist_destroy(worklist);
2474    return any_bits_set;
2475 }
2476 
2477 /* For every input load, compute the set of output stores that it can contribute to.
2478  * If the loaded value is used for control flow, then any instruction in the CFG
2479  * that it impacts is considered to contribute.
2480  * Ideally, we should also handle stores to outputs/memory and subsequent loads from
2481  * that output/memory, but this is non-trivial and it's unclear how much impact it would have. */
2482 bool
2483 dxil_nir_analyze_io_dependencies(struct dxil_module *mod, nir_shader *s)
2484 {
2485    bool any_outputs = false;
2486    for (uint32_t i = 0; i < 4; ++i)
2487       any_outputs |= mod->num_psv_outputs[i] > 0;
2488    if (mod->shader_kind == DXIL_HULL_SHADER)
2489       any_outputs |= mod->num_psv_patch_consts > 0;
2490    if (!any_outputs)
2491       return false;
2492 
2493    bool any_bits_set = false;
2494    nir_foreach_function(func, s) {
2495       assert(func->impl);
2496       /* Hull shaders have a patch constant function */
2497       assert(func->is_entrypoint || s->info.stage == MESA_SHADER_TESS_CTRL);
2498 
2499       /* Pass 1: input/view ID -> output dependencies */
2500       nir_foreach_block(block, func->impl) {
2501          nir_foreach_instr(instr, block) {
2502             if (instr->type != nir_instr_type_intrinsic)
2503                continue;
2504             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2505             uint32_t **tables = mod->io_dependency_table;
2506             const uint32_t *table_sizes = mod->dependency_table_dwords_per_input;
2507             switch (intr->intrinsic) {
2508             case nir_intrinsic_load_view_index:
2509                tables = mod->viewid_dependency_table;
2510                FALLTHROUGH;
2511             case nir_intrinsic_load_input:
2512             case nir_intrinsic_load_per_vertex_input:
2513             case nir_intrinsic_load_interpolated_input:
2514                break;
2515             default:
2516                continue;
2517             }
2518 
2519             clear_pass_flags(func->impl);
2520             any_bits_set |= propagate_input_to_output_dependencies(mod, intr, tables, table_sizes);
2521          }
2522       }
2523 
2524       /* Pass 2: output -> output dependencies */
2525       /* TODO */
2526    }
2527    return any_bits_set;
2528 }
2529 
2530 static enum pipe_format
2531 get_format_for_var(unsigned num_comps, enum glsl_base_type sampled_type)
2532 {
2533    switch (sampled_type) {
2534    case GLSL_TYPE_INT:
2535    case GLSL_TYPE_INT64:
2536    case GLSL_TYPE_INT16:
2537       switch (num_comps) {
2538       case 1: return PIPE_FORMAT_R32_SINT;
2539       case 2: return PIPE_FORMAT_R32G32_SINT;
2540       case 3: return PIPE_FORMAT_R32G32B32_SINT;
2541       case 4: return PIPE_FORMAT_R32G32B32A32_SINT;
2542       default: unreachable("Invalid num_comps");
2543       }
2544    case GLSL_TYPE_UINT:
2545    case GLSL_TYPE_UINT64:
2546    case GLSL_TYPE_UINT16:
2547       switch (num_comps) {
2548       case 1: return PIPE_FORMAT_R32_UINT;
2549       case 2: return PIPE_FORMAT_R32G32_UINT;
2550       case 3: return PIPE_FORMAT_R32G32B32_UINT;
2551       case 4: return PIPE_FORMAT_R32G32B32A32_UINT;
2552       default: unreachable("Invalid num_comps");
2553       }
2554    case GLSL_TYPE_FLOAT:
2555    case GLSL_TYPE_FLOAT16:
2556    case GLSL_TYPE_DOUBLE:
2557       switch (num_comps) {
2558       case 1: return PIPE_FORMAT_R32_FLOAT;
2559       case 2: return PIPE_FORMAT_R32G32_FLOAT;
2560       case 3: return PIPE_FORMAT_R32G32B32_FLOAT;
2561       case 4: return PIPE_FORMAT_R32G32B32A32_FLOAT;
2562       default: unreachable("Invalid num_comps");
2563       }
2564    default: unreachable("Invalid sampler return type");
2565    }
2566 }
2567 
2568 static unsigned
2569 aoa_size(const struct glsl_type *type)
2570 {
2571    return glsl_type_is_array(type) ? glsl_get_aoa_size(type) : 1;
2572 }
2573 
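/* If an image variable has no declared format, infer one from its uses: atomics
 * force a single-component format, loads/stores grow the component count to match
 * the widest access, and completely unused images default to four components. */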
2574 static bool
2575 guess_image_format_for_var(nir_shader *s, nir_variable *var)
2576 {
2577    const struct glsl_type *base_type = glsl_without_array(var->type);
2578    if (!glsl_type_is_image(base_type))
2579       return false;
2580    if (var->data.image.format != PIPE_FORMAT_NONE)
2581       return false;
2582 
2583    nir_foreach_function_impl(impl, s) {
2584       nir_foreach_block(block, impl) {
2585          nir_foreach_instr(instr, block) {
2586             if (instr->type != nir_instr_type_intrinsic)
2587                continue;
2588             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2589             switch (intr->intrinsic) {
2590             case nir_intrinsic_image_deref_load:
2591             case nir_intrinsic_image_deref_store:
2592             case nir_intrinsic_image_deref_atomic:
2593             case nir_intrinsic_image_deref_atomic_swap:
2594                if (nir_intrinsic_get_var(intr, 0) != var)
2595                   continue;
2596                break;
2597             case nir_intrinsic_image_load:
2598             case nir_intrinsic_image_store:
2599             case nir_intrinsic_image_atomic:
2600             case nir_intrinsic_image_atomic_swap: {
2601                unsigned binding = nir_src_as_uint(intr->src[0]);
2602                if (binding < var->data.binding ||
2603                    binding >= var->data.binding + aoa_size(var->type))
2604                   continue;
2605                break;
2606                }
2607             default:
2608                continue;
2609             }
2610             break;
2611 
2612             switch (intr->intrinsic) {
2613             case nir_intrinsic_image_deref_load:
2614             case nir_intrinsic_image_load:
2615             case nir_intrinsic_image_deref_store:
2616             case nir_intrinsic_image_store:
2617                /* Increase unknown formats up to 4 components if a 4-component accessor is used */
2618                if (intr->num_components > util_format_get_nr_components(var->data.image.format))
2619                   var->data.image.format = get_format_for_var(intr->num_components, glsl_get_sampler_result_type(base_type));
2620                break;
2621             default:
2622                /* If an atomic is used, the image format must be 1-component; return immediately */
2623                var->data.image.format = get_format_for_var(1, glsl_get_sampler_result_type(base_type));
2624                return true;
2625             }
2626          }
2627       }
2628    }
2629    /* Dunno what it is, assume 4-component */
2630    if (var->data.image.format == PIPE_FORMAT_NONE)
2631       var->data.image.format = get_format_for_var(4, glsl_get_sampler_result_type(base_type));
2632    return true;
2633 }
2634 
2635 static void
2636 update_intrinsic_format_and_type(nir_intrinsic_instr *intr, nir_variable *var)
2637 {
2638    nir_intrinsic_set_format(intr, var->data.image.format);
2639    nir_alu_type alu_type =
2640       nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(glsl_without_array(var->type)));
2641    if (nir_intrinsic_has_src_type(intr))
2642       nir_intrinsic_set_src_type(intr, alu_type);
2643    else if (nir_intrinsic_has_dest_type(intr))
2644       nir_intrinsic_set_dest_type(intr, alu_type);
2645 }
2646 
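/* For intrinsics that carry a format index, find the image variable they
 * access (via deref or binding) and copy its format and sampled type onto
 * the intrinsic. */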
2647 static bool
2648 update_intrinsic_formats(nir_builder *b, nir_intrinsic_instr *intr,
2649                          void *data)
2650 {
2651    if (!nir_intrinsic_has_format(intr))
2652       return false;
2653    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
2654    if (deref) {
2655       nir_variable *var = nir_deref_instr_get_variable(deref);
2656       if (var)
2657          update_intrinsic_format_and_type(intr, var);
2658       return var != NULL;
2659    }
2660 
2661    if (!nir_intrinsic_has_range_base(intr))
2662       return false;
2663 
2664    unsigned binding = nir_src_as_uint(intr->src[0]);
2665    nir_foreach_variable_with_modes(var, b->shader, nir_var_image) {
2666       if (var->data.binding <= binding &&
2667           var->data.binding + aoa_size(var->type) > binding) {
2668          update_intrinsic_format_and_type(intr, var);
2669          return true;
2670       }
2671    }
2672    return false;
2673 }
2674 
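/* Assign a format to every format-less image variable based on its usage and
 * propagate the chosen formats onto the image intrinsics. */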
2675 bool
2676 dxil_nir_guess_image_formats(nir_shader *s)
2677 {
2678    bool progress = false;
2679    nir_foreach_variable_with_modes(var, s, nir_var_image) {
2680       progress |= guess_image_format_for_var(s, var);
2681    }
2682    nir_shader_intrinsics_pass(s, update_intrinsic_formats, nir_metadata_all,
2683                               NULL);
2684    return progress;
2685 }
2686 
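/* Mark every variable of the given modes that matches the binding and
 * descriptor set as coherent. */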
2687 static void
2688 set_binding_variables_coherent(nir_shader *s, nir_binding binding, nir_variable_mode modes)
2689 {
2690    nir_foreach_variable_with_modes(var, s, modes) {
2691       if (var->data.binding == binding.binding &&
2692           var->data.descriptor_set == binding.desc_set) {
2693          var->data.access |= ACCESS_COHERENT;
2694       }
2695    }
2696 }
2697 
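/* Walk a deref chain back to its variable and mark it coherent; for casts of
 * Vulkan descriptors, mark every SSBO variable bound to that descriptor. */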
2698 static void
2699 set_deref_variables_coherent(nir_shader *s, nir_deref_instr *deref)
2700 {
2701    while (deref->deref_type != nir_deref_type_var &&
2702           deref->deref_type != nir_deref_type_cast) {
2703       deref = nir_deref_instr_parent(deref);
2704    }
2705    if (deref->deref_type == nir_deref_type_var) {
2706       deref->var->data.access |= ACCESS_COHERENT;
2707       return;
2708    }
2709 
2710    /* For derefs with casts, we only support pre-lowered Vulkan accesses */
2711    assert(deref->deref_type == nir_deref_type_cast);
2712    nir_intrinsic_instr *cast_src = nir_instr_as_intrinsic(deref->parent.ssa->parent_instr);
2713    assert(cast_src->intrinsic == nir_intrinsic_load_vulkan_descriptor);
2714    nir_binding binding = nir_chase_binding(cast_src->src[0]);
2715    set_binding_variables_coherent(s, binding, nir_var_mem_ssbo);
2716 }
2717 
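/* Emit the atomic form of a load or store: loads become an atomic iadd of
 * zero, stores become an atomic exchange. Returns NULL for intrinsics this
 * helper doesn't handle. */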
2718 static nir_def *
2719 get_atomic_for_load_store(nir_builder *b, nir_intrinsic_instr *intr, unsigned bit_size)
2720 {
2721    nir_def *zero = nir_imm_intN_t(b, 0, bit_size);
2722    switch (intr->intrinsic) {
2723    case nir_intrinsic_load_deref:
2724       return nir_deref_atomic(b, bit_size, intr->src[0].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2725    case nir_intrinsic_load_ssbo:
2726       return nir_ssbo_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2727    case nir_intrinsic_image_deref_load:
2728       return nir_image_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2729    case nir_intrinsic_image_load:
2730       return nir_image_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2731    case nir_intrinsic_store_deref:
2732       return nir_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, .atomic_op = nir_atomic_op_xchg);
2733    case nir_intrinsic_store_ssbo:
2734       return nir_ssbo_atomic(b, bit_size, intr->src[1].ssa, intr->src[2].ssa, intr->src[0].ssa, .atomic_op = nir_atomic_op_xchg);
2735    case nir_intrinsic_image_deref_store:
2736       return nir_image_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, intr->src[3].ssa, .atomic_op = nir_atomic_op_xchg);
2737    case nir_intrinsic_image_store:
2738       return nir_image_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, intr->src[3].ssa, .atomic_op = nir_atomic_op_xchg);
2739    default:
2740       return NULL;
2741    }
2742 }
2743 
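/* Rewrite a coherent load/store into its atomic equivalent when the access is
 * a single component of at least 32 bits. Accesses that can't be expressed as
 * a single atomic instead get their underlying variable marked coherent. */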
2744 static bool
2745 lower_coherent_load_store(nir_builder *b, nir_intrinsic_instr *intr, void *context)
2746 {
2747    if (!nir_intrinsic_has_access(intr) || (nir_intrinsic_access(intr) & ACCESS_COHERENT) == 0)
2748       return false;
2749 
2750    nir_def *atomic_def = NULL;
2751    b->cursor = nir_before_instr(&intr->instr);
2752    switch (intr->intrinsic) {
2753    case nir_intrinsic_load_deref:
2754    case nir_intrinsic_load_ssbo:
2755    case nir_intrinsic_image_deref_load:
2756    case nir_intrinsic_image_load: {
2757       if (intr->def.bit_size < 32 || intr->def.num_components > 1) {
2758          if (intr->intrinsic == nir_intrinsic_load_deref)
2759             set_deref_variables_coherent(b->shader, nir_src_as_deref(intr->src[0]));
2760          else {
2761             nir_binding binding = {0};
2762             if (nir_src_is_const(intr->src[0]))
2763                binding.binding = nir_src_as_uint(intr->src[0]);
2764             set_binding_variables_coherent(b->shader, binding,
2765                                            intr->intrinsic == nir_intrinsic_load_ssbo ? nir_var_mem_ssbo : nir_var_image);
2766          }
2767          return false;
2768       }
2769 
2770       atomic_def = get_atomic_for_load_store(b, intr, intr->def.bit_size);
2771       nir_def_rewrite_uses(&intr->def, atomic_def);
2772       break;
2773    }
2774    case nir_intrinsic_store_deref:
2775    case nir_intrinsic_store_ssbo:
2776    case nir_intrinsic_image_deref_store:
2777    case nir_intrinsic_image_store: {
2778       int resource_idx = intr->intrinsic == nir_intrinsic_store_ssbo ? 1 : 0;
2779       int value_idx = intr->intrinsic == nir_intrinsic_store_ssbo ? 0 :
2780          intr->intrinsic == nir_intrinsic_store_deref ? 1 : 3;
2781       unsigned num_components = nir_intrinsic_has_write_mask(intr) ?
2782          util_bitcount(nir_intrinsic_write_mask(intr)) : intr->src[value_idx].ssa->num_components;
2783       if (intr->src[value_idx].ssa->bit_size < 32 || num_components > 1) {
2784          if (intr->intrinsic == nir_intrinsic_store_deref)
2785             set_deref_variables_coherent(b->shader, nir_src_as_deref(intr->src[resource_idx]));
2786          else {
2787             nir_binding binding = {0};
2788             if (nir_src_is_const(intr->src[resource_idx]))
2789                binding.binding = nir_src_as_uint(intr->src[resource_idx]);
2790             set_binding_variables_coherent(b->shader, binding,
2791                                            intr->intrinsic == nir_intrinsic_store_ssbo ? nir_var_mem_ssbo : nir_var_image);
2792          }
2793          return false;
2794       }
2795 
2796       atomic_def = get_atomic_for_load_store(b, intr, intr->src[value_idx].ssa->bit_size);
2797       break;
2798    }
2799    default:
2800       return false;
2801    }
2802 
2803    nir_intrinsic_instr *atomic = nir_instr_as_intrinsic(atomic_def->parent_instr);
2804    nir_intrinsic_set_access(atomic, nir_intrinsic_access(intr));
2805    if (nir_intrinsic_has_image_dim(intr))
2806       nir_intrinsic_set_image_dim(atomic, nir_intrinsic_image_dim(intr));
2807    if (nir_intrinsic_has_image_array(intr))
2808       nir_intrinsic_set_image_array(atomic, nir_intrinsic_image_array(intr));
2809    if (nir_intrinsic_has_format(intr))
2810       nir_intrinsic_set_format(atomic, nir_intrinsic_format(intr));
2811    if (nir_intrinsic_has_range_base(intr))
2812       nir_intrinsic_set_range_base(atomic, nir_intrinsic_range_base(intr));
2813    nir_instr_remove(&intr->instr);
2814    return true;
2815 }
2816 
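/* Entry point for the coherent load/store lowering above. */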
2817 bool
2818 dxil_nir_lower_coherent_loads_and_stores(nir_shader *s)
2819 {
2820    return nir_shader_intrinsics_pass(s, lower_coherent_load_store,
2821                                      nir_metadata_control_flow | nir_metadata_loop_analysis,
2822                                      NULL);
2823 }
2824 
2825 struct undefined_varying_masks {
2826    uint64_t io_mask;
2827    uint32_t patch_io_mask;
2828    const BITSET_WORD *frac_io_mask;
2829 };
2830 
2831 static bool
2832 is_dead_in_variable(nir_variable *var, void *data)
2833 {
2834    switch (var->data.location) {
2835    /* Only these values can be system generated values in addition to varyings */
2836    case VARYING_SLOT_PRIMITIVE_ID:
2837    case VARYING_SLOT_FACE:
2838    case VARYING_SLOT_VIEW_INDEX:
2839       return false;
2840    /* Tessellation input vars must remain untouched */
2841    case VARYING_SLOT_TESS_LEVEL_INNER:
2842    case VARYING_SLOT_TESS_LEVEL_OUTER:
2843       return false;
2844    default:
2845       return true;
2846    }
2847 }
2848 
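/* Replace loads of input varyings that the previous stage never writes with a
 * constant zero (see the note below about zero vs. undef). System-generated
 * and tessellation-level inputs are left untouched. */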
2849 static bool
2850 kill_undefined_varyings(struct nir_builder *b,
2851                         nir_instr *instr,
2852                         void *data)
2853 {
2854    const struct undefined_varying_masks *masks = data;
2855 
2856    if (instr->type != nir_instr_type_intrinsic)
2857       return false;
2858 
2859    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2860 
2861    if (intr->intrinsic != nir_intrinsic_load_deref)
2862       return false;
2863 
2864    nir_variable *var = nir_intrinsic_get_var(intr, 0);
2865    if (!var || var->data.mode != nir_var_shader_in)
2866       return false;
2867 
2868    if (!is_dead_in_variable(var, NULL))
2869       return false;
2870 
2871    uint32_t loc = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2872       var->data.location - VARYING_SLOT_PATCH0 :
2873       var->data.location;
2874    uint64_t written = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2875       masks->patch_io_mask : masks->io_mask;
2876    if (BITFIELD64_RANGE(loc, glsl_varying_count(var->type)) & written) {
2877       if (!masks->frac_io_mask || !var->data.location_frac ||
2878           var->data.location < VARYING_SLOT_VAR0 ||
2879           BITSET_TEST(masks->frac_io_mask, (var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac))
2880          return false;
2881    }
2882 
2883    b->cursor = nir_after_instr(instr);
2884    /* Note: zero is used instead of undef because optimization is not run here, only
2885     * later. If we loaded an undef here and that undef were later used in a store to
2886     * the position output, some or all components of that position write could be
2887     * optimized away. Removing all of them would delete the store instruction
2888     * entirely, which would make it tricky to satisfy the DXIL requirement that
2889     * every position component be written.
2890     */
2891    nir_def *zero = nir_imm_zero(b, intr->def.num_components,
2892                                        intr->def.bit_size);
2893    nir_def_replace(&intr->def, zero);
2894    return true;
2895 }
2896 
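/* Replace reads of inputs the previous stage never writes with zero, then
 * remove the resulting dead derefs and unused input variables. */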
2897 bool
2898 dxil_nir_kill_undefined_varyings(nir_shader *shader, uint64_t prev_stage_written_mask, uint32_t prev_stage_patch_written_mask,
2899                                  const BITSET_WORD *prev_stage_frac_output_mask)
2900 {
2901    struct undefined_varying_masks masks = {
2902       .io_mask = prev_stage_written_mask,
2903       .patch_io_mask = prev_stage_patch_written_mask,
2904       .frac_io_mask = prev_stage_frac_output_mask
2905    };
2906    bool progress = nir_shader_instructions_pass(shader,
2907                                                 kill_undefined_varyings,
2908                                                 nir_metadata_control_flow |
2909                                                 nir_metadata_loop_analysis,
2910                                                 (void *)&masks);
2911    if (progress) {
2912       nir_opt_dce(shader);
2913       nir_remove_dead_derefs(shader);
2914    }
2915 
2916    const struct nir_remove_dead_variables_options options = {
2917       .can_remove_var = is_dead_in_variable,
2918       .can_remove_var_data = &masks,
2919    };
2920    progress |= nir_remove_dead_variables(shader, nir_var_shader_in, &options);
2921    return progress;
2922 }
2923 
2924 static bool
2925 is_dead_out_variable(nir_variable *var, void *data)
2926 {
2927    return !nir_slot_is_sysval_output(var->data.location, MESA_SHADER_NONE);
2928 }
2929 
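/* Delete loads/stores of output varyings the next stage never reads; any load
 * is first replaced with zero so its users stay valid. Sysval outputs and
 * xfb-active (always_active_io) outputs outside of TCS are left untouched. */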
2930 static bool
2931 kill_unused_outputs(struct nir_builder *b,
2932                     nir_instr *instr,
2933                     void *data)
2934 {
2935    const struct undefined_varying_masks *masks = data;
2936 
2937    if (instr->type != nir_instr_type_intrinsic)
2938       return false;
2939 
2940    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2941 
2942    if (intr->intrinsic != nir_intrinsic_store_deref &&
2943        intr->intrinsic != nir_intrinsic_load_deref)
2944       return false;
2945 
2946    nir_variable *var = nir_intrinsic_get_var(intr, 0);
2947    if (!var || var->data.mode != nir_var_shader_out ||
2948        /* always_active_io can mean two things: xfb or GL separable shaders. We can't delete
2949         * varyings that are used for xfb (we'll just sort them last), but we must delete varyings
2950         * that are mismatching between TCS and TES. Fortunately TCS can't do xfb, so we can ignore
2951         * the always_active_io bit for TCS outputs. */
2952        (b->shader->info.stage != MESA_SHADER_TESS_CTRL && var->data.always_active_io))
2953       return false;
2954 
2955    if (!is_dead_out_variable(var, NULL))
2956       return false;
2957 
2958    unsigned loc = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2959       var->data.location - VARYING_SLOT_PATCH0 :
2960       var->data.location;
2961    uint64_t read = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2962       masks->patch_io_mask : masks->io_mask;
2963    if (BITFIELD64_RANGE(loc, glsl_varying_count(var->type)) & read) {
2964       if (!masks->frac_io_mask || !var->data.location_frac ||
2965           var->data.location < VARYING_SLOT_VAR0 ||
2966           BITSET_TEST(masks->frac_io_mask, (var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac))
2967          return false;
2968    }
2969 
2970    if (intr->intrinsic == nir_intrinsic_load_deref) {
2971       b->cursor = nir_after_instr(&intr->instr);
2972       nir_def *zero = nir_imm_zero(b, intr->def.num_components, intr->def.bit_size);
2973       nir_def_rewrite_uses(&intr->def, zero);
2974    }
2975    nir_instr_remove(instr);
2976    return true;
2977 }
2978 
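/* Remove writes to outputs the next stage never reads, then remove the
 * resulting dead derefs and unused output variables. */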
2979 bool
2980 dxil_nir_kill_unused_outputs(nir_shader *shader, uint64_t next_stage_read_mask, uint32_t next_stage_patch_read_mask,
2981                              const BITSET_WORD *next_stage_frac_input_mask)
2982 {
2983    struct undefined_varying_masks masks = {
2984       .io_mask = next_stage_read_mask,
2985       .patch_io_mask = next_stage_patch_read_mask,
2986       .frac_io_mask = next_stage_frac_input_mask
2987    };
2988 
2989    bool progress = nir_shader_instructions_pass(shader,
2990                                                 kill_unused_outputs,
2991                                                 nir_metadata_control_flow |
2992                                                 nir_metadata_loop_analysis,
2993                                                 (void *)&masks);
2994 
2995    if (progress) {
2996       nir_opt_dce(shader);
2997       nir_remove_dead_derefs(shader);
2998    }
2999    const struct nir_remove_dead_variables_options options = {
3000       .can_remove_var = is_dead_out_variable,
3001       .can_remove_var_data = &masks,
3002    };
3003    progress |= nir_remove_dead_variables(shader, nir_var_shader_out, &options);
3004    return progress;
3005 }
3006