1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "dxil_nir.h"
25 #include "dxil_module.h"
26 
27 #include "nir_builder.h"
28 #include "nir_deref.h"
29 #include "nir_worklist.h"
30 #include "nir_to_dxil.h"
31 #include "util/u_math.h"
32 #include "vulkan/vulkan_core.h"
33 
34 static void
35 cl_type_size_align(const struct glsl_type *type, unsigned *size,
36                    unsigned *align)
37 {
38    *size = glsl_get_cl_size(type);
39    *align = glsl_get_cl_alignment(type);
40 }
41 
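/* Packs num_src_comps scalar components of src_bit_size bits into a vector
 * whose components are dst_bit_size wide: equal sizes just build a vector,
 * wider sources are narrowed with nir_extract_bits(), and narrower sources
 * are zero-extended, shifted and OR'd into each destination component.
 */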
42 static nir_def *
43 load_comps_to_vec(nir_builder *b, unsigned src_bit_size,
44                   nir_def **src_comps, unsigned num_src_comps,
45                   unsigned dst_bit_size)
46 {
47    if (src_bit_size == dst_bit_size)
48       return nir_vec(b, src_comps, num_src_comps);
49    else if (src_bit_size > dst_bit_size)
50       return nir_extract_bits(b, src_comps, num_src_comps, 0, src_bit_size * num_src_comps / dst_bit_size, dst_bit_size);
51 
52    unsigned num_dst_comps = DIV_ROUND_UP(num_src_comps * src_bit_size, dst_bit_size);
53    unsigned comps_per_dst = dst_bit_size / src_bit_size;
54    nir_def *dst_comps[4];
55 
56    for (unsigned i = 0; i < num_dst_comps; i++) {
57       unsigned src_offs = i * comps_per_dst;
58 
59       dst_comps[i] = nir_u2uN(b, src_comps[src_offs], dst_bit_size);
60       for (unsigned j = 1; j < comps_per_dst && src_offs + j < num_src_comps; j++) {
61          nir_def *tmp = nir_ishl_imm(b, nir_u2uN(b, src_comps[src_offs + j], dst_bit_size),
62                                          j * src_bit_size);
63          dst_comps[i] = nir_ior(b, dst_comps[i], tmp);
64       }
65    }
66 
67    return nir_vec(b, dst_comps, num_dst_comps);
68 }
69 
70 static bool
71 lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
72 {
73    unsigned bit_size = intr->def.bit_size;
74    unsigned num_components = intr->def.num_components;
75    unsigned num_bits = num_components * bit_size;
76 
77    b->cursor = nir_before_instr(&intr->instr);
78 
79    nir_def *offset = intr->src[0].ssa;
80    if (intr->intrinsic == nir_intrinsic_load_shared)
81       offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
82    else
83       offset = nir_u2u32(b, offset);
84    nir_def *index = nir_ushr_imm(b, offset, 2);
85    nir_def *comps[NIR_MAX_VEC_COMPONENTS];
86    nir_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];
87 
88    /* We need to split loads into 32-bit accesses because the buffer
89     * is an i32 array and DXIL does not support type casts.
90     */
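   /* For example, a 64-bit scalar load becomes two consecutive i32 array loads
    * that are repacked below, while a 16-bit load reads a single i32 element
    * and shifts the requested half into the low bits.
    */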
91    unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
92    for (unsigned i = 0; i < num_32bit_comps; i++)
93       comps_32bit[i] = nir_load_array_var(b, var, nir_iadd_imm(b, index, i));
94    unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);
95 
96    for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
97       unsigned num_vec32_comps = MIN2(num_32bit_comps - i, 4);
98       unsigned num_dest_comps = num_vec32_comps * 32 / bit_size;
99       nir_def *vec32 = nir_vec(b, &comps_32bit[i], num_vec32_comps);
100 
101       /* If we have 16 bits or less to load we need to adjust the u32 value so
102        * we can always extract the LSB.
103        */
104       if (num_bits <= 16) {
105          nir_def *shift =
106             nir_imul_imm(b, nir_iand_imm(b, offset, 3), 8);
107          vec32 = nir_ushr(b, vec32, shift);
108       }
109 
110       /* And now comes the pack/unpack step to match the original type. */
111       unsigned dest_index = i * 32 / bit_size;
112       nir_def *temp_vec = nir_extract_bits(b, &vec32, 1, 0, num_dest_comps, bit_size);
113       for (unsigned comp = 0; comp < num_dest_comps; ++comp, ++dest_index)
114          comps[dest_index] = nir_channel(b, temp_vec, comp);
115    }
116 
117    nir_def *result = nir_vec(b, comps, num_components);
118    nir_def_replace(&intr->def, result);
119 
120    return true;
121 }
122 
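/* Stores a value narrower than 32 bits into the i32 backing array without
 * clobbering the neighbouring bytes: shared memory uses a pair of iand/ior
 * deref atomics, scratch uses a plain load/modify/store sequence.
 */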
123 static void
124 lower_masked_store_vec32(nir_builder *b, nir_def *offset, nir_def *index,
125                          nir_def *vec32, unsigned num_bits, nir_variable *var, unsigned alignment)
126 {
127    nir_def *mask = nir_imm_int(b, (1 << num_bits) - 1);
128 
129    /* If we have small alignments, we need to place them correctly in the u32 component. */
130    if (alignment <= 2) {
131       nir_def *shift =
132          nir_imul_imm(b, nir_iand_imm(b, offset, 3), 8);
133 
134       vec32 = nir_ishl(b, vec32, shift);
135       mask = nir_ishl(b, mask, shift);
136    }
137 
138    if (var->data.mode == nir_var_mem_shared) {
139       /* Use the dedicated masked intrinsic */
140       nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
141       nir_deref_atomic(b, 32, &deref->def, nir_inot(b, mask), .atomic_op = nir_atomic_op_iand);
142       nir_deref_atomic(b, 32, &deref->def, vec32, .atomic_op = nir_atomic_op_ior);
143    } else {
144       /* For scratch, since we don't need atomics, just generate the read-modify-write in NIR */
145       nir_def *load = nir_load_array_var(b, var, index);
146 
147       nir_def *new_val = nir_ior(b, vec32,
148                                      nir_iand(b,
149                                               nir_inot(b, mask),
150                                               load));
151 
152       nir_store_array_var(b, var, index, new_val, 1);
153    }
154 }
155 
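/* Splits a shared/scratch store into 32-bit scalar stores on the i32 backing
 * array, falling back to the masked helper above for sub-32-bit chunks.
 */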
156 static bool
157 lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
158 {
159    unsigned num_components = nir_src_num_components(intr->src[0]);
160    unsigned bit_size = nir_src_bit_size(intr->src[0]);
161    unsigned num_bits = num_components * bit_size;
162 
163    b->cursor = nir_before_instr(&intr->instr);
164 
165    nir_def *offset = intr->src[1].ssa;
166    if (intr->intrinsic == nir_intrinsic_store_shared)
167       offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
168    else
169       offset = nir_u2u32(b, offset);
170    nir_def *comps[NIR_MAX_VEC_COMPONENTS];
171 
172    unsigned comp_idx = 0;
173    for (unsigned i = 0; i < num_components; i++)
174       comps[i] = nir_channel(b, intr->src[0].ssa, i);
175 
176    unsigned step = MAX2(bit_size, 32);
177    for (unsigned i = 0; i < num_bits; i += step) {
178       /* For each 4byte chunk (or smaller) we generate a 32bit scalar store. */
179       unsigned substore_num_bits = MIN2(num_bits - i, step);
180       nir_def *local_offset = nir_iadd_imm(b, offset, i / 8);
181       nir_def *vec32 = load_comps_to_vec(b, bit_size, &comps[comp_idx],
182                                              substore_num_bits / bit_size, 32);
183       nir_def *index = nir_ushr_imm(b, local_offset, 2);
184 
185       /* For anything less than 32bits we need to use the masked version of the
186        * intrinsic to preserve data living in the same 32bit slot. */
187       if (substore_num_bits < 32) {
188          lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, var, nir_intrinsic_align(intr));
189       } else {
190          for (unsigned i = 0; i < vec32->num_components; ++i)
191             nir_store_array_var(b, var, nir_iadd_imm(b, index, i), nir_channel(b, vec32, i), 1);
192       }
193 
194       comp_idx += substore_num_bits / bit_size;
195    }
196 
197    nir_instr_remove(&intr->instr);
198 
199    return true;
200 }
201 
202 #define CONSTANT_LOCATION_UNVISITED 0
203 #define CONSTANT_LOCATION_VALID 1
204 #define CONSTANT_LOCATION_INVALID 2
205 
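/* Turns constant-initialized nir_var_mem_constant variables into
 * nir_var_shader_temp variables (when none of their uses are "complex")
 * and rewrites the derefs that referenced them accordingly.
 */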
206 bool
207 dxil_nir_lower_constant_to_temp(nir_shader *nir)
208 {
209    bool progress = false;
210    nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant)
211       var->data.location = var->constant_initializer ?
212          CONSTANT_LOCATION_UNVISITED : CONSTANT_LOCATION_INVALID;
213 
214    /* First pass: collect all UBO accesses that could be turned into
215     * shader temp accesses.
216     */
217    nir_foreach_function(func, nir) {
218       if (!func->is_entrypoint)
219          continue;
220       assert(func->impl);
221 
222       nir_foreach_block(block, func->impl) {
223          nir_foreach_instr_safe(instr, block) {
224             if (instr->type != nir_instr_type_deref)
225                continue;
226 
227             nir_deref_instr *deref = nir_instr_as_deref(instr);
228             if (!nir_deref_mode_is(deref, nir_var_mem_constant) ||
229                 deref->deref_type != nir_deref_type_var ||
230                 deref->var->data.location == CONSTANT_LOCATION_INVALID)
231                continue;
232 
233             deref->var->data.location = nir_deref_instr_has_complex_use(deref, 0) ?
234                CONSTANT_LOCATION_INVALID : CONSTANT_LOCATION_VALID;
235          }
236       }
237    }
238 
239    nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant) {
240       if (var->data.location != CONSTANT_LOCATION_VALID)
241          continue;
242 
243       /* Change the variable mode. */
244       var->data.mode = nir_var_shader_temp;
245 
246       progress = true;
247    }
248 
249    /* Second pass: patch all derefs that were accessing the converted UBO
250     * variables.
251     */
252    nir_foreach_function(func, nir) {
253       if (!func->is_entrypoint)
254          continue;
255       assert(func->impl);
256 
257       nir_builder b = nir_builder_create(func->impl);
258       nir_foreach_block(block, func->impl) {
259          nir_foreach_instr_safe(instr, block) {
260             if (instr->type != nir_instr_type_deref)
261                continue;
262 
263             nir_deref_instr *deref = nir_instr_as_deref(instr);
264             if (nir_deref_mode_is(deref, nir_var_mem_constant)) {
265                nir_deref_instr *parent = deref;
266                while (parent && parent->deref_type != nir_deref_type_var)
267                   parent = nir_src_as_deref(parent->parent);
268                if (parent && parent->var->data.mode != nir_var_mem_constant) {
269                   deref->modes = parent->var->data.mode;
270                   /* Also change "pointer" size to 32-bit since this is now a logical pointer */
271                   deref->def.bit_size = 32;
272                   if (deref->deref_type == nir_deref_type_array) {
273                      b.cursor = nir_before_instr(instr);
274                      nir_src_rewrite(&deref->arr.index, nir_u2u32(&b, deref->arr.index.ssa));
275                   }
276                }
277             }
278          }
279       }
280    }
281 
282    return progress;
283 }
284 
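/* Rewrites loads/stores/atomics on (arrays of) vector variables into
 * per-scalar accesses on the flattened array, building a linear index
 * from the deref chain.
 */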
285 static bool
286 flatten_var_arrays(nir_builder *b, nir_intrinsic_instr *intr, void *data)
287 {
288    switch (intr->intrinsic) {
289    case nir_intrinsic_load_deref:
290    case nir_intrinsic_store_deref:
291    case nir_intrinsic_deref_atomic:
292    case nir_intrinsic_deref_atomic_swap:
293       break;
294    default:
295       return false;
296    }
297 
298    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
299    nir_variable *var = NULL;
300    for (nir_deref_instr *d = deref; d; d = nir_deref_instr_parent(d)) {
301       if (d->deref_type == nir_deref_type_cast)
302          return false;
303       if (d->deref_type == nir_deref_type_var) {
304          var = d->var;
305          if (d->type == var->type)
306             return false;
307       }
308    }
309    if (!var)
310       return false;
311 
312    nir_deref_path path;
313    nir_deref_path_init(&path, deref, NULL);
314 
315    assert(path.path[0]->deref_type == nir_deref_type_var);
316    b->cursor = nir_before_instr(&path.path[0]->instr);
317    nir_deref_instr *new_var_deref = nir_build_deref_var(b, var);
318    nir_def *index = NULL;
319    for (unsigned level = 1; path.path[level]; ++level) {
320       nir_deref_instr *arr_deref = path.path[level];
321       assert(arr_deref->deref_type == nir_deref_type_array);
322       b->cursor = nir_before_instr(&arr_deref->instr);
323       nir_def *val = nir_imul_imm(b, arr_deref->arr.index.ssa,
324                                       glsl_get_component_slots(arr_deref->type));
325       if (index) {
326          index = nir_iadd(b, index, val);
327       } else {
328          index = val;
329       }
330    }
331 
332    unsigned vector_comps = intr->num_components;
333    if (vector_comps > 1) {
334       b->cursor = nir_before_instr(&intr->instr);
335       if (intr->intrinsic == nir_intrinsic_load_deref) {
336          nir_def *components[NIR_MAX_VEC_COMPONENTS];
337          for (unsigned i = 0; i < vector_comps; ++i) {
338             nir_def *final_index = index ? nir_iadd_imm(b, index, i) : nir_imm_int(b, i);
339             nir_deref_instr *comp_deref = nir_build_deref_array(b, new_var_deref, final_index);
340             components[i] = nir_load_deref(b, comp_deref);
341          }
342          nir_def_rewrite_uses(&intr->def, nir_vec(b, components, vector_comps));
343       } else if (intr->intrinsic == nir_intrinsic_store_deref) {
344          for (unsigned i = 0; i < vector_comps; ++i) {
345             if (((1 << i) & nir_intrinsic_write_mask(intr)) == 0)
346                continue;
347             nir_def *final_index = index ? nir_iadd_imm(b, index, i) : nir_imm_int(b, i);
348             nir_deref_instr *comp_deref = nir_build_deref_array(b, new_var_deref, final_index);
349             nir_store_deref(b, comp_deref, nir_channel(b, intr->src[1].ssa, i), 1);
350          }
351       }
352       nir_instr_remove(&intr->instr);
353    } else {
354       nir_src_rewrite(&intr->src[0], &nir_build_deref_array(b, new_var_deref, index)->def);
355    }
356 
357    nir_deref_path_finish(&path);
358    return true;
359 }
360 
361 static void
362 flatten_constant_initializer(nir_variable *var, nir_constant *src, nir_constant ***dest, unsigned vector_elements)
363 {
364    if (src->num_elements == 0) {
365       for (unsigned i = 0; i < vector_elements; ++i) {
366          nir_constant *new_scalar = rzalloc(var, nir_constant);
367          memcpy(&new_scalar->values[0], &src->values[i], sizeof(src->values[0]));
368          new_scalar->is_null_constant = src->values[i].u64 == 0;
369 
370          nir_constant **array_entry = (*dest)++;
371          *array_entry = new_scalar;
372       }
373    } else {
374       for (unsigned i = 0; i < src->num_elements; ++i)
375          flatten_constant_initializer(var, src->elements[i], dest, vector_elements);
376    }
377 }
378 
379 static bool
380 flatten_var_array_types(nir_variable *var)
381 {
382    assert(!glsl_type_is_struct(glsl_without_array(var->type)));
383    const struct glsl_type *matrix_type = glsl_without_array(var->type);
384    if (!glsl_type_is_array_of_arrays(var->type) && glsl_get_components(matrix_type) == 1)
385       return false;
386 
387    enum glsl_base_type base_type = glsl_get_base_type(matrix_type);
388    const struct glsl_type *flattened_type = glsl_array_type(glsl_scalar_type(base_type),
389                                                             glsl_get_component_slots(var->type), 0);
390    var->type = flattened_type;
391    if (var->constant_initializer) {
392       nir_constant **new_elements = ralloc_array(var, nir_constant *, glsl_get_length(flattened_type));
393       nir_constant **temp = new_elements;
394       flatten_constant_initializer(var, var->constant_initializer, &temp, glsl_get_vector_elements(matrix_type));
395       var->constant_initializer->num_elements = glsl_get_length(flattened_type);
396       var->constant_initializer->elements = new_elements;
397    }
398    return true;
399 }
400 
401 bool
402 dxil_nir_flatten_var_arrays(nir_shader *shader, nir_variable_mode modes)
403 {
404    bool progress = false;
405    nir_foreach_variable_with_modes(var, shader, modes & ~nir_var_function_temp)
406       progress |= flatten_var_array_types(var);
407 
408    if (modes & nir_var_function_temp) {
409       nir_foreach_function_impl(impl, shader) {
410          nir_foreach_function_temp_variable(var, impl)
411             progress |= flatten_var_array_types(var);
412       }
413    }
414 
415    if (!progress)
416       return false;
417 
418    nir_shader_intrinsics_pass(shader, flatten_var_arrays,
419                                 nir_metadata_control_flow |
420                                 nir_metadata_loop_analysis,
421                                 NULL);
422    nir_remove_dead_derefs(shader);
423    return true;
424 }
425 
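/* Fixes up load/store derefs after lower_var_bit_size_types changed the
 * variable types: narrow accesses are widened with converts, and 64-bit
 * accesses on the doubled uint representation are split into two 32-bit ones.
 */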
426 static bool
427 lower_deref_bit_size(nir_builder *b, nir_intrinsic_instr *intr, void *data)
428 {
429    switch (intr->intrinsic) {
430    case nir_intrinsic_load_deref:
431    case nir_intrinsic_store_deref:
432       break;
433    default:
434       /* Atomics can't be smaller than 32-bit */
435       return false;
436    }
437 
438    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
439    nir_variable *var = nir_deref_instr_get_variable(deref);
440    /* Only interested in full deref chains */
441    if (!var)
442       return false;
443 
444    const struct glsl_type *var_scalar_type = glsl_without_array(var->type);
445    if (deref->type == var_scalar_type || !glsl_type_is_scalar(var_scalar_type))
446       return false;
447 
448    assert(deref->deref_type == nir_deref_type_var || deref->deref_type == nir_deref_type_array);
449    const struct glsl_type *old_glsl_type = deref->type;
450    nir_alu_type old_type = nir_get_nir_type_for_glsl_type(old_glsl_type);
451    nir_alu_type new_type = nir_get_nir_type_for_glsl_type(var_scalar_type);
452    if (glsl_get_bit_size(old_glsl_type) < glsl_get_bit_size(var_scalar_type)) {
453       deref->type = var_scalar_type;
454       if (intr->intrinsic == nir_intrinsic_load_deref) {
455          intr->def.bit_size = glsl_get_bit_size(var_scalar_type);
456          b->cursor = nir_after_instr(&intr->instr);
457          nir_def *downcast = nir_type_convert(b, &intr->def, new_type, old_type, nir_rounding_mode_undef);
458          nir_def_rewrite_uses_after(&intr->def, downcast, downcast->parent_instr);
459       }
460       else {
461          b->cursor = nir_before_instr(&intr->instr);
462          nir_def *upcast = nir_type_convert(b, intr->src[1].ssa, old_type, new_type, nir_rounding_mode_undef);
463          nir_src_rewrite(&intr->src[1], upcast);
464       }
465 
466       while (deref->deref_type == nir_deref_type_array) {
467          nir_deref_instr *parent = nir_deref_instr_parent(deref);
468          parent->type = glsl_type_wrap_in_arrays(deref->type, parent->type);
469          deref = parent;
470       }
471    } else {
472       /* Assumed arrays are already flattened */
473       b->cursor = nir_before_instr(&deref->instr);
474       nir_deref_instr *parent = nir_build_deref_var(b, var);
475       if (deref->deref_type == nir_deref_type_array)
476          deref = nir_build_deref_array(b, parent, nir_imul_imm(b, deref->arr.index.ssa, 2));
477       else
478          deref = nir_build_deref_array_imm(b, parent, 0);
479       nir_deref_instr *deref2 = nir_build_deref_array(b, parent,
480                                                       nir_iadd_imm(b, deref->arr.index.ssa, 1));
481       b->cursor = nir_before_instr(&intr->instr);
482       if (intr->intrinsic == nir_intrinsic_load_deref) {
483          nir_def *src1 = nir_load_deref(b, deref);
484          nir_def *src2 = nir_load_deref(b, deref2);
485          nir_def_rewrite_uses(&intr->def, nir_pack_64_2x32_split(b, src1, src2));
486       } else {
487          nir_def *src1 = nir_unpack_64_2x32_split_x(b, intr->src[1].ssa);
488          nir_def *src2 = nir_unpack_64_2x32_split_y(b, intr->src[1].ssa);
489          nir_store_deref(b, deref, src1, 1);
490          nir_store_deref(b, deref2, src2, 1);
491       }
492       nir_instr_remove(&intr->instr);
493    }
494    return true;
495 }
496 
497 static bool
498 lower_var_bit_size_types(nir_variable *var, unsigned min_bit_size, unsigned max_bit_size)
499 {
500    assert(!glsl_type_is_array_of_arrays(var->type) && !glsl_type_is_struct(var->type));
501    const struct glsl_type *type = glsl_without_array(var->type);
502    assert(glsl_type_is_scalar(type));
503    enum glsl_base_type base_type = glsl_get_base_type(type);
504    if (glsl_base_type_get_bit_size(base_type) < min_bit_size) {
505       switch (min_bit_size) {
506       case 16:
507          switch (base_type) {
508          case GLSL_TYPE_BOOL:
509             base_type = GLSL_TYPE_UINT16;
510             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
511                var->constant_initializer->elements[i]->values[0].u16 = var->constant_initializer->elements[i]->values[0].b ? 0xffff : 0;
512             break;
513          case GLSL_TYPE_INT8:
514             base_type = GLSL_TYPE_INT16;
515             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
516                var->constant_initializer->elements[i]->values[0].i16 = var->constant_initializer->elements[i]->values[0].i8;
517             break;
518          case GLSL_TYPE_UINT8: base_type = GLSL_TYPE_UINT16; break;
519          default: unreachable("Unexpected base type");
520          }
521          break;
522       case 32:
523          switch (base_type) {
524          case GLSL_TYPE_BOOL:
525             base_type = GLSL_TYPE_UINT;
526             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
527                var->constant_initializer->elements[i]->values[0].u32 = var->constant_initializer->elements[i]->values[0].b ? 0xffffffff : 0;
528             break;
529          case GLSL_TYPE_INT8:
530             base_type = GLSL_TYPE_INT;
531             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
532                var->constant_initializer->elements[i]->values[0].i32 = var->constant_initializer->elements[i]->values[0].i8;
533             break;
534          case GLSL_TYPE_INT16:
535             base_type = GLSL_TYPE_INT;
536             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
537                var->constant_initializer->elements[i]->values[0].i32 = var->constant_initializer->elements[i]->values[0].i16;
538             break;
539          case GLSL_TYPE_FLOAT16:
540             base_type = GLSL_TYPE_FLOAT;
541             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
542                var->constant_initializer->elements[i]->values[0].f32 = _mesa_half_to_float(var->constant_initializer->elements[i]->values[0].u16);
543             break;
544          case GLSL_TYPE_UINT8: base_type = GLSL_TYPE_UINT; break;
545          case GLSL_TYPE_UINT16: base_type = GLSL_TYPE_UINT; break;
546          default: unreachable("Unexpected base type");
547          }
548          break;
549       default: unreachable("Unexpected min bit size");
550       }
551       var->type = glsl_type_wrap_in_arrays(glsl_scalar_type(base_type), var->type);
552       return true;
553    }
554    if (glsl_base_type_get_bit_size(base_type) > max_bit_size) {
555       assert(!glsl_type_is_array_of_arrays(var->type));
556       var->type = glsl_array_type(glsl_scalar_type(GLSL_TYPE_UINT),
557                                     glsl_type_is_array(var->type) ? glsl_get_length(var->type) * 2 : 2,
558                                     0);
559       if (var->constant_initializer) {
560          unsigned num_elements = var->constant_initializer->num_elements ?
561             var->constant_initializer->num_elements * 2 : 2;
562          nir_constant **element_arr = ralloc_array(var, nir_constant *, num_elements);
563          nir_constant *elements = rzalloc_array(var, nir_constant, num_elements);
564          for (unsigned i = 0; i < var->constant_initializer->num_elements; ++i) {
565             element_arr[i*2] = &elements[i*2];
566             element_arr[i*2+1] = &elements[i*2+1];
567             const nir_const_value *src = var->constant_initializer->num_elements ?
568                var->constant_initializer->elements[i]->values : var->constant_initializer->values;
569             elements[i*2].values[0].u32 = (uint32_t)src->u64;
570             elements[i*2].is_null_constant = (uint32_t)src->u64 == 0;
571             elements[i*2+1].values[0].u32 = (uint32_t)(src->u64 >> 32);
572             elements[i*2+1].is_null_constant = (uint32_t)(src->u64 >> 32) == 0;
573          }
574          var->constant_initializer->num_elements = num_elements;
575          var->constant_initializer->elements = element_arr;
576       }
577       return true;
578    }
579    return false;
580 }
581 
582 bool
583 dxil_nir_lower_var_bit_size(nir_shader *shader, nir_variable_mode modes,
584                             unsigned min_bit_size, unsigned max_bit_size)
585 {
586    bool progress = false;
587    nir_foreach_variable_with_modes(var, shader, modes & ~nir_var_function_temp)
588       progress |= lower_var_bit_size_types(var, min_bit_size, max_bit_size);
589 
590    if (modes & nir_var_function_temp) {
591       nir_foreach_function_impl(impl, shader) {
592          nir_foreach_function_temp_variable(var, impl)
593             progress |= lower_var_bit_size_types(var, min_bit_size, max_bit_size);
594       }
595    }
596 
597    if (!progress)
598       return false;
599 
600    nir_shader_intrinsics_pass(shader, lower_deref_bit_size,
601                                 nir_metadata_control_flow |
602                                 nir_metadata_loop_analysis,
603                                 NULL);
604    nir_remove_dead_derefs(shader);
605    return true;
606 }
607 
608 static bool
609 remove_oob_array_access(nir_builder *b, nir_intrinsic_instr *intr, void *data)
610 {
611    uint32_t num_derefs = 1;
612 
613    switch (intr->intrinsic) {
614    case nir_intrinsic_copy_deref:
615       num_derefs = 2;
616       FALLTHROUGH;
617    case nir_intrinsic_load_deref:
618    case nir_intrinsic_store_deref:
619    case nir_intrinsic_deref_atomic:
620    case nir_intrinsic_deref_atomic_swap:
621       break;
622    default:
623       return false;
624    }
625 
626    for (uint32_t i = 0; i < num_derefs; ++i) {
627       if (nir_deref_instr_is_known_out_of_bounds(nir_src_as_deref(intr->src[i]))) {
628          switch (intr->intrinsic) {
629          case nir_intrinsic_load_deref:
630          case nir_intrinsic_deref_atomic:
631          case nir_intrinsic_deref_atomic_swap:
632             b->cursor = nir_before_instr(&intr->instr);
633             nir_def *undef = nir_undef(b, intr->def.num_components, intr->def.bit_size);
634             nir_def_rewrite_uses(&intr->def, undef);
635             break;
636          default:
637             break;
638          }
639          nir_instr_remove(&intr->instr);
640          return true;
641       }
642    }
643 
644    return false;
645 }
646 
647 bool
648 dxil_nir_remove_oob_array_accesses(nir_shader *shader)
649 {
650    return nir_shader_intrinsics_pass(shader, remove_oob_array_access,
651                                      nir_metadata_control_flow |
652                                      nir_metadata_loop_analysis,
653                                      NULL);
654 }
655 
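/* Rewrites shared_atomic/shared_atomic_swap on a byte offset into a
 * deref_atomic on the lowered i32 shared-memory array.
 */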
656 static bool
657 lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
658 {
659    b->cursor = nir_before_instr(&intr->instr);
660 
661    nir_def *offset =
662       nir_iadd_imm(b, intr->src[0].ssa, nir_intrinsic_base(intr));
663    nir_def *index = nir_ushr_imm(b, offset, 2);
664 
665    nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
666    nir_def *result;
667    if (intr->intrinsic == nir_intrinsic_shared_atomic_swap)
668       result = nir_deref_atomic_swap(b, 32, &deref->def, intr->src[1].ssa, intr->src[2].ssa,
669                                      .atomic_op = nir_intrinsic_atomic_op(intr));
670    else
671       result = nir_deref_atomic(b, 32, &deref->def, intr->src[1].ssa,
672                                 .atomic_op = nir_intrinsic_atomic_op(intr));
673 
674    nir_def_replace(&intr->def, result);
675    return true;
676 }
677 
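/* Entry point: replaces shared and scratch memory with i32 array variables
 * and lowers all loads, stores and atomics on them using the helpers above.
 */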
678 bool
679 dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
680                                     const struct dxil_nir_lower_loads_stores_options *options)
681 {
682    bool progress = nir_remove_dead_variables(nir, nir_var_function_temp | nir_var_mem_shared, NULL);
683    nir_variable *shared_var = NULL;
684    if (nir->info.shared_size) {
685       shared_var = nir_variable_create(nir, nir_var_mem_shared,
686                                        glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->info.shared_size, 4), 4),
687                                        "lowered_shared_mem");
688    }
689 
690    unsigned ptr_size = nir->info.cs.ptr_size;
691    if (nir->info.stage == MESA_SHADER_KERNEL) {
692       /* All the derefs created here will be used as GEP indices so force 32-bit */
693       nir->info.cs.ptr_size = 32;
694    }
695    nir_foreach_function_impl(impl, nir) {
696       nir_builder b = nir_builder_create(impl);
697 
698       nir_variable *scratch_var = NULL;
699       if (nir->scratch_size) {
700          const struct glsl_type *scratch_type = glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->scratch_size, 4), 4);
701          scratch_var = nir_local_variable_create(impl, scratch_type, "lowered_scratch_mem");
702       }
703 
704       nir_foreach_block(block, impl) {
705          nir_foreach_instr_safe(instr, block) {
706             if (instr->type != nir_instr_type_intrinsic)
707                continue;
708             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
709 
710             switch (intr->intrinsic) {
711             case nir_intrinsic_load_shared:
712                progress |= lower_32b_offset_load(&b, intr, shared_var);
713                break;
714             case nir_intrinsic_load_scratch:
715                progress |= lower_32b_offset_load(&b, intr, scratch_var);
716                break;
717             case nir_intrinsic_store_shared:
718                progress |= lower_32b_offset_store(&b, intr, shared_var);
719                break;
720             case nir_intrinsic_store_scratch:
721                progress |= lower_32b_offset_store(&b, intr, scratch_var);
722                break;
723             case nir_intrinsic_shared_atomic:
724             case nir_intrinsic_shared_atomic_swap:
725                progress |= lower_shared_atomic(&b, intr, shared_var);
726                break;
727             default:
728                break;
729             }
730          }
731       }
732    }
733    if (nir->info.stage == MESA_SHADER_KERNEL) {
734       nir->info.cs.ptr_size = ptr_size;
735    }
736 
737    return progress;
738 }
739 
740 static bool
741 lower_deref_ssbo(nir_builder *b, nir_deref_instr *deref)
742 {
743    assert(nir_deref_mode_is(deref, nir_var_mem_ssbo));
744    assert(deref->deref_type == nir_deref_type_var ||
745           deref->deref_type == nir_deref_type_cast);
746    nir_variable *var = deref->var;
747 
748    b->cursor = nir_before_instr(&deref->instr);
749 
750    if (deref->deref_type == nir_deref_type_var) {
751       /* We turn all deref_var into deref_cast and build a pointer value based on
752        * the var binding which encodes the UAV id.
753        */
754       nir_def *ptr = nir_imm_int64(b, (uint64_t)var->data.binding << 32);
755       nir_deref_instr *deref_cast =
756          nir_build_deref_cast(b, ptr, nir_var_mem_ssbo, deref->type,
757                               glsl_get_explicit_stride(var->type));
758       nir_def_replace(&deref->def, &deref_cast->def);
759 
760       deref = deref_cast;
761       return true;
762    }
763    return false;
764 }
765 
766 bool
767 dxil_nir_lower_deref_ssbo(nir_shader *nir)
768 {
769    bool progress = false;
770 
771    foreach_list_typed(nir_function, func, node, &nir->functions) {
772       if (!func->is_entrypoint)
773          continue;
774       assert(func->impl);
775 
776       nir_builder b = nir_builder_create(func->impl);
777 
778       nir_foreach_block(block, func->impl) {
779          nir_foreach_instr_safe(instr, block) {
780             if (instr->type != nir_instr_type_deref)
781                continue;
782 
783             nir_deref_instr *deref = nir_instr_as_deref(instr);
784 
785             if (!nir_deref_mode_is(deref, nir_var_mem_ssbo) ||
786                 (deref->deref_type != nir_deref_type_var &&
787                  deref->deref_type != nir_deref_type_cast))
788                continue;
789 
790             progress |= lower_deref_ssbo(&b, deref);
791          }
792       }
793    }
794 
795    return progress;
796 }
797 
798 static bool
799 lower_alu_deref_srcs(nir_builder *b, nir_alu_instr *alu)
800 {
801    const nir_op_info *info = &nir_op_infos[alu->op];
802    bool progress = false;
803 
804    b->cursor = nir_before_instr(&alu->instr);
805 
806    for (unsigned i = 0; i < info->num_inputs; i++) {
807       nir_deref_instr *deref = nir_src_as_deref(alu->src[i].src);
808 
809       if (!deref)
810          continue;
811 
812       nir_deref_path path;
813       nir_deref_path_init(&path, deref, NULL);
814       nir_deref_instr *root_deref = path.path[0];
815       nir_deref_path_finish(&path);
816 
817       if (root_deref->deref_type != nir_deref_type_cast)
818          continue;
819 
820       nir_def *ptr =
821          nir_iadd(b, root_deref->parent.ssa,
822                      nir_build_deref_offset(b, deref, cl_type_size_align));
823       nir_src_rewrite(&alu->src[i].src, ptr);
824       progress = true;
825    }
826 
827    return progress;
828 }
829 
830 bool
831 dxil_nir_opt_alu_deref_srcs(nir_shader *nir)
832 {
833    bool progress = false;
834 
835    foreach_list_typed(nir_function, func, node, &nir->functions) {
836       if (!func->is_entrypoint)
837          continue;
838       assert(func->impl);
839 
840       nir_builder b = nir_builder_create(func->impl);
841 
842       nir_foreach_block(block, func->impl) {
843          nir_foreach_instr_safe(instr, block) {
844             if (instr->type != nir_instr_type_alu)
845                continue;
846 
847             nir_alu_instr *alu = nir_instr_as_alu(instr);
848             progress |= lower_alu_deref_srcs(&b, alu);
849          }
850       }
851    }
852 
853    return progress;
854 }
855 
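/* Replaces a phi narrower than new_bit_size with one at new_bit_size,
 * inserting u2uN casts on every source and a cast back to the original
 * width after the phis.
 */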
856 static void
857 cast_phi(nir_builder *b, nir_phi_instr *phi, unsigned new_bit_size)
858 {
859    nir_phi_instr *lowered = nir_phi_instr_create(b->shader);
860    int num_components = 0;
861    int old_bit_size = phi->def.bit_size;
862 
863    nir_foreach_phi_src(src, phi) {
864       assert(num_components == 0 || num_components == src->src.ssa->num_components);
865       num_components = src->src.ssa->num_components;
866 
867       b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);
868 
869       nir_def *cast = nir_u2uN(b, src->src.ssa, new_bit_size);
870 
871       nir_phi_instr_add_src(lowered, src->pred, cast);
872    }
873 
874    nir_def_init(&lowered->instr, &lowered->def, num_components,
875                 new_bit_size);
876 
877    b->cursor = nir_before_instr(&phi->instr);
878    nir_builder_instr_insert(b, &lowered->instr);
879 
880    b->cursor = nir_after_phis(nir_cursor_current_block(b->cursor));
881    nir_def *result = nir_u2uN(b, &lowered->def, old_bit_size);
882 
883    nir_def_replace(&phi->def, result);
884 }
885 
886 static bool
887 upcast_phi_impl(nir_function_impl *impl, unsigned min_bit_size)
888 {
889    nir_builder b = nir_builder_create(impl);
890    bool progress = false;
891 
892    nir_foreach_block_reverse(block, impl) {
893       nir_foreach_phi_safe(phi, block) {
894          if (phi->def.bit_size == 1 ||
895              phi->def.bit_size >= min_bit_size)
896             continue;
897 
898          cast_phi(&b, phi, min_bit_size);
899          progress = true;
900       }
901    }
902 
903    if (progress) {
904       nir_metadata_preserve(impl, nir_metadata_control_flow);
905    } else {
906       nir_metadata_preserve(impl, nir_metadata_all);
907    }
908 
909    return progress;
910 }
911 
912 bool
913 dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size)
914 {
915    bool progress = false;
916 
917    nir_foreach_function_impl(impl, shader) {
918       progress |= upcast_phi_impl(impl, min_bit_size);
919    }
920 
921    return progress;
922 }
923 
924 struct dxil_nir_split_clip_cull_distance_params {
925    nir_variable *new_var[2];
926    nir_shader *shader;
927 };
928 
929 /* In GLSL and SPIR-V, clip and cull distance are arrays of floats (with a limit of 8).
930  * In DXIL, clip and cull distances are up to 2 float4s combined.
931  * Coming from GLSL, we can request this 2 float4 format, but coming from SPIR-V,
932  * we can't, and have to accept a "compact" array of scalar floats.
933  *
934  * To help emitting a valid input signature for this case, split the variables so that they
935  * match what we need to put in the signature (e.g. { float clip[4]; float clip1; float cull[3]; })
936  *
937  * This pass can deal with splitting across two axes:
938  * 1. Given { float clip[5]; float cull[3]; }, split clip into clip[4] and clip1[1]. This is
939  *    what's produced by nir_lower_clip_cull_distance_arrays.
940  * 2. Given { float clip[4]; float clipcull[4]; }, split clipcull into clip1[1] and cull[3].
941  *    This is what's produced by the sequence of nir_lower_clip_cull_distance_arrays, then
942  *    I/O lowering, vectorization, optimization, and I/O un-lowering.
943  */
944 static bool
945 dxil_nir_split_clip_cull_distance_instr(nir_builder *b,
946                                         nir_instr *instr,
947                                         void *cb_data)
948 {
949    struct dxil_nir_split_clip_cull_distance_params *params = cb_data;
950 
951    if (instr->type != nir_instr_type_deref)
952       return false;
953 
954    nir_deref_instr *deref = nir_instr_as_deref(instr);
955    nir_variable *var = nir_deref_instr_get_variable(deref);
956    if (!var ||
957        var->data.location < VARYING_SLOT_CLIP_DIST0 ||
958        var->data.location > VARYING_SLOT_CULL_DIST1 ||
959        !var->data.compact)
960       return false;
961 
962    unsigned new_var_idx = var->data.mode == nir_var_shader_in ? 0 : 1;
963    nir_variable *new_var = params->new_var[new_var_idx];
964 
965    /* The location should only be inside clip distance, because clip
966     * and cull should've been merged by nir_lower_clip_cull_distance_arrays()
967     */
968    assert(var->data.location == VARYING_SLOT_CLIP_DIST0 ||
969           var->data.location == VARYING_SLOT_CLIP_DIST1);
970 
971    /* The deref chain to the clip/cull variables should be simple, just the
972     * var and an array with a constant index, otherwise more lowering/optimization
973     * might be needed before this pass, e.g. copy prop, lower_io_to_temporaries,
974     * split_var_copies, and/or lower_var_copies. In the case of arrayed I/O like
975     * inputs to the tessellation or geometry stages, there might be a second level
976     * of array index.
977     */
978    assert(deref->deref_type == nir_deref_type_var ||
979           deref->deref_type == nir_deref_type_array);
980 
981    bool clip_size_accurate = var->data.mode == nir_var_shader_out || b->shader->info.stage == MESA_SHADER_FRAGMENT;
982    bool is_clip_cull_split = false;
983 
984    b->cursor = nir_before_instr(instr);
985    unsigned arrayed_io_length = 0;
986    const struct glsl_type *old_type = var->type;
987    if (nir_is_arrayed_io(var, b->shader->info.stage)) {
988       arrayed_io_length = glsl_array_size(old_type);
989       old_type = glsl_get_array_element(old_type);
990    }
991    int old_length = glsl_array_size(old_type);
992    if (!new_var) {
993       /* Update lengths for new and old vars */
994       int new_length = (old_length + var->data.location_frac) - 4;
995 
996       /* The existing variable fits in the float4 */
997       if (new_length <= 0) {
998          /* If we don't have an accurate clip size in the shader info, then we're in case 1 (only
999           * lowered clip_cull, not lowered+unlowered I/O), since I/O optimization moves these varyings
1000           * out of sysval locations and into generic locations. */
1001          if (!clip_size_accurate)
1002             return false;
1003 
1004          assert(old_length <= 4);
1005          unsigned start = (var->data.location - VARYING_SLOT_CLIP_DIST0) * 4;
1006          unsigned end = start + old_length;
1007          /* If it doesn't straddle the clip array size, then it's either just clip or cull, not both. */
1008          if (start >= b->shader->info.clip_distance_array_size ||
1009                end <= b->shader->info.clip_distance_array_size)
1010             return false;
1011 
1012          new_length = end - b->shader->info.clip_distance_array_size;
1013          is_clip_cull_split = true;
1014       }
1015 
1016       old_length -= new_length;
1017       new_var = nir_variable_clone(var, params->shader);
1018       nir_shader_add_variable(params->shader, new_var);
1019       assert(glsl_get_base_type(glsl_get_array_element(old_type)) == GLSL_TYPE_FLOAT);
1020       var->type = glsl_array_type(glsl_float_type(), old_length, 0);
1021       new_var->type = glsl_array_type(glsl_float_type(), new_length, 0);
1022       if (arrayed_io_length) {
1023          var->type = glsl_array_type(var->type, arrayed_io_length, 0);
1024          new_var->type = glsl_array_type(new_var->type, arrayed_io_length, 0);
1025       }
1026 
1027       if (is_clip_cull_split) {
1028          new_var->data.location_frac = old_length;
1029       } else {
1030          new_var->data.location++;
1031          new_var->data.location_frac = 0;
1032       }
1033       params->new_var[new_var_idx] = new_var;
1034    }
1035 
1036    /* Update the type for derefs of the old var */
1037    if (deref->deref_type == nir_deref_type_var) {
1038       deref->type = var->type;
1039       return false;
1040    }
1041 
1042    if (glsl_type_is_array(deref->type)) {
1043       assert(arrayed_io_length > 0);
1044       deref->type = glsl_get_array_element(var->type);
1045       return false;
1046    }
1047 
1048    assert(glsl_get_base_type(deref->type) == GLSL_TYPE_FLOAT);
1049 
1050    nir_const_value *index = nir_src_as_const_value(deref->arr.index);
1051    assert(index);
1052 
1053    /* If we're indexing out-of-bounds of the old variable, then adjust to point
1054     * to the new variable with a smaller index.
1055     */
1056    if (index->u32 < old_length)
1057       return false;
1058 
1059    nir_deref_instr *new_var_deref = nir_build_deref_var(b, new_var);
1060    nir_deref_instr *new_intermediate_deref = new_var_deref;
1061    if (arrayed_io_length) {
1062       nir_deref_instr *parent = nir_src_as_deref(deref->parent);
1063       assert(parent->deref_type == nir_deref_type_array);
1064       new_intermediate_deref = nir_build_deref_array(b, new_intermediate_deref, parent->arr.index.ssa);
1065    }
1066    nir_deref_instr *new_array_deref = nir_build_deref_array(b, new_intermediate_deref, nir_imm_int(b, index->u32 - old_length));
1067    nir_def_rewrite_uses(&deref->def, &new_array_deref->def);
1068    return true;
1069 }
1070 
1071 bool
1072 dxil_nir_split_clip_cull_distance(nir_shader *shader)
1073 {
1074    struct dxil_nir_split_clip_cull_distance_params params = {
1075       .new_var = { NULL, NULL },
1076       .shader = shader,
1077    };
1078    nir_shader_instructions_pass(shader,
1079                                 dxil_nir_split_clip_cull_distance_instr,
1080                                 nir_metadata_control_flow |
1081                                 nir_metadata_loop_analysis,
1082                                 &params);
1083    return params.new_var[0] != NULL || params.new_var[1] != NULL;
1084 }
1085 
1086 static bool
1087 dxil_nir_lower_double_math_instr(nir_builder *b,
1088                                  nir_instr *instr,
1089                                  UNUSED void *cb_data)
1090 {
1091    if (instr->type == nir_instr_type_intrinsic) {
1092       nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1093       switch (intr->intrinsic) {
1094          case nir_intrinsic_reduce:
1095          case nir_intrinsic_exclusive_scan:
1096          case nir_intrinsic_inclusive_scan:
1097             break;
1098          default:
1099             return false;
1100       }
1101       if (intr->def.bit_size != 64)
1102          return false;
1103       nir_op reduction = nir_intrinsic_reduction_op(intr);
1104       switch (reduction) {
1105          case nir_op_fmul:
1106          case nir_op_fadd:
1107          case nir_op_fmin:
1108          case nir_op_fmax:
1109             break;
1110          default:
1111             return false;
1112       }
1113       b->cursor = nir_before_instr(instr);
1114       nir_src_rewrite(&intr->src[0], nir_pack_double_2x32_dxil(b, nir_unpack_64_2x32(b, intr->src[0].ssa)));
1115       b->cursor = nir_after_instr(instr);
1116       nir_def *result = nir_pack_64_2x32(b, nir_unpack_double_2x32_dxil(b, &intr->def));
1117       nir_def_rewrite_uses_after(&intr->def, result, result->parent_instr);
1118       return true;
1119    }
1120 
1121    if (instr->type != nir_instr_type_alu)
1122       return false;
1123 
1124    nir_alu_instr *alu = nir_instr_as_alu(instr);
1125 
1126    /* TODO: See if we can apply this explicitly to packs/unpacks that are then
1127     * used as a double. As-is, if we had an app explicitly do a 64bit integer op,
1128     * then try to bitcast to double (not expressible in HLSL, but it is in other
1129     * source languages), this would unpack the integer and repack as a double, when
1130     * we probably want to just send the bitcast through to the backend.
1131     */
1132 
1133    b->cursor = nir_before_instr(&alu->instr);
1134 
1135    bool progress = false;
1136    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
1137       if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
1138           alu->src[i].src.ssa->bit_size == 64) {
1139          unsigned num_components = nir_op_infos[alu->op].input_sizes[i];
1140          if (!num_components)
1141             num_components = alu->def.num_components;
1142          nir_def *components[NIR_MAX_VEC_COMPONENTS];
1143          for (unsigned c = 0; c < num_components; ++c) {
1144             nir_def *packed_double = nir_channel(b, alu->src[i].src.ssa, alu->src[i].swizzle[c]);
1145             nir_def *unpacked_double = nir_unpack_64_2x32(b, packed_double);
1146             components[c] = nir_pack_double_2x32_dxil(b, unpacked_double);
1147             alu->src[i].swizzle[c] = c;
1148          }
1149          nir_src_rewrite(&alu->src[i].src,
1150                          nir_vec(b, components, num_components));
1151          progress = true;
1152       }
1153    }
1154 
1155    if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
1156        alu->def.bit_size == 64) {
1157       b->cursor = nir_after_instr(&alu->instr);
1158       nir_def *components[NIR_MAX_VEC_COMPONENTS];
1159       for (unsigned c = 0; c < alu->def.num_components; ++c) {
1160          nir_def *packed_double = nir_channel(b, &alu->def, c);
1161          nir_def *unpacked_double = nir_unpack_double_2x32_dxil(b, packed_double);
1162          components[c] = nir_pack_64_2x32(b, unpacked_double);
1163       }
1164       nir_def *repacked_dvec = nir_vec(b, components, alu->def.num_components);
1165       nir_def_rewrite_uses_after(&alu->def, repacked_dvec, repacked_dvec->parent_instr);
1166       progress = true;
1167    }
1168 
1169    return progress;
1170 }
1171 
1172 bool
1173 dxil_nir_lower_double_math(nir_shader *shader)
1174 {
1175    return nir_shader_instructions_pass(shader,
1176                                        dxil_nir_lower_double_math_instr,
1177                                        nir_metadata_control_flow |
1178                                        nir_metadata_loop_analysis,
1179                                        NULL);
1180 }
1181 
1182 typedef struct {
1183    gl_system_value *values;
1184    uint32_t count;
1185 } zero_system_values_state;
1186 
1187 static bool
1188 lower_system_value_to_zero_filter(const nir_instr* instr, const void* cb_state)
1189 {
1190    if (instr->type != nir_instr_type_intrinsic) {
1191       return false;
1192    }
1193 
1194    nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);
1195 
1196    /* All the intrinsics we care about are loads */
1197    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
1198       return false;
1199 
1200    zero_system_values_state* state = (zero_system_values_state*)cb_state;
1201    for (uint32_t i = 0; i < state->count; ++i) {
1202       gl_system_value value = state->values[i];
1203       nir_intrinsic_op value_op = nir_intrinsic_from_system_value(value);
1204 
1205       if (intrin->intrinsic == value_op) {
1206          return true;
1207       } else if (intrin->intrinsic == nir_intrinsic_load_deref) {
1208          nir_deref_instr* deref = nir_src_as_deref(intrin->src[0]);
1209          if (!nir_deref_mode_is(deref, nir_var_system_value))
1210             return false;
1211 
1212          nir_variable* var = deref->var;
1213          if (var->data.location == value) {
1214             return true;
1215          }
1216       }
1217    }
1218 
1219    return false;
1220 }
1221 
1222 static nir_def*
1223 lower_system_value_to_zero_instr(nir_builder* b, nir_instr* instr, void* _state)
1224 {
1225    return nir_imm_int(b, 0);
1226 }
1227 
1228 bool
1229 dxil_nir_lower_system_values_to_zero(nir_shader* shader,
1230                                      gl_system_value* system_values,
1231                                      uint32_t count)
1232 {
1233    zero_system_values_state state = { system_values, count };
1234    return nir_shader_lower_instructions(shader,
1235       lower_system_value_to_zero_filter,
1236       lower_system_value_to_zero_instr,
1237       &state);
1238 }
1239 
1240 static void
1241 lower_load_local_group_size(nir_builder *b, nir_intrinsic_instr *intr)
1242 {
1243    b->cursor = nir_after_instr(&intr->instr);
1244 
1245    nir_const_value v[3] = {
1246       nir_const_value_for_int(b->shader->info.workgroup_size[0], 32),
1247       nir_const_value_for_int(b->shader->info.workgroup_size[1], 32),
1248       nir_const_value_for_int(b->shader->info.workgroup_size[2], 32)
1249    };
1250    nir_def *size = nir_build_imm(b, 3, 32, v);
1251    nir_def_replace(&intr->def, size);
1252 }
1253 
1254 static bool
1255 lower_system_values_impl(nir_builder *b, nir_intrinsic_instr *intr,
1256                          void *_state)
1257 {
1258    switch (intr->intrinsic) {
1259    case nir_intrinsic_load_workgroup_size:
1260       lower_load_local_group_size(b, intr);
1261       return true;
1262    default:
1263       return false;
1264    }
1265 }
1266 
1267 bool
1268 dxil_nir_lower_system_values(nir_shader *shader)
1269 {
1270    return nir_shader_intrinsics_pass(shader, lower_system_values_impl,
1271                                      nir_metadata_control_flow | nir_metadata_loop_analysis,
1272                                      NULL);
1273 }
1274 
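/* Helpers to derive bare-sampler and texture types that mirror the array
 * structure of an existing combined-sampler type.
 */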
1275 static const struct glsl_type *
1276 get_bare_samplers_for_type(const struct glsl_type *type, bool is_shadow)
1277 {
1278    const struct glsl_type *base_sampler_type =
1279       is_shadow ?
1280       glsl_bare_shadow_sampler_type() : glsl_bare_sampler_type();
1281    return glsl_type_wrap_in_arrays(base_sampler_type, type);
1282 }
1283 
1284 static const struct glsl_type *
1285 get_textures_for_sampler_type(const struct glsl_type *type)
1286 {
1287    return glsl_type_wrap_in_arrays(
1288       glsl_sampler_type_to_texture(
1289          glsl_without_array(type)), type);
1290 }
1291 
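/* Points sampler srcs (whether deref-based or index-based) at bare-sampler
 * variables of the right shadow-ness, cloning and retyping variables as
 * needed and caching them in the hash table passed as data.
 */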
1292 static bool
1293 redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1294 {
1295    if (instr->type != nir_instr_type_tex)
1296       return false;
1297 
1298    nir_tex_instr *tex = nir_instr_as_tex(instr);
1299 
1300    int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
1301    if (sampler_idx == -1) {
1302       /* No sampler deref - does this instruction even need a sampler? If not,
1303        * sampler_index doesn't necessarily point to a sampler, so early-out.
1304        */
1305       if (!nir_tex_instr_need_sampler(tex))
1306          return false;
1307 
1308       /* No derefs but needs a sampler, must be using indices */
1309       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->sampler_index);
1310 
1311       /* Already have a bare sampler here */
1312       if (bare_sampler)
1313          return false;
1314 
1315       nir_variable *old_sampler = NULL;
1316       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1317          if (var->data.binding <= tex->sampler_index &&
1318              var->data.binding + glsl_type_get_sampler_count(var->type) >
1319                 tex->sampler_index) {
1320 
1321             /* Already have a bare sampler for this binding and it is of the
1322              * correct type, add it to the table */
1323             if (glsl_type_is_bare_sampler(glsl_without_array(var->type)) &&
1324                 glsl_sampler_type_is_shadow(glsl_without_array(var->type)) ==
1325                    tex->is_shadow) {
1326                _mesa_hash_table_u64_insert(data, tex->sampler_index, var);
1327                return false;
1328             }
1329 
1330             old_sampler = var;
1331          }
1332       }
1333 
1334       assert(old_sampler);
1335 
1336       /* Clone the original sampler to a bare sampler of the correct type */
1337       bare_sampler = nir_variable_clone(old_sampler, b->shader);
1338       nir_shader_add_variable(b->shader, bare_sampler);
1339 
1340       bare_sampler->type =
1341          get_bare_samplers_for_type(old_sampler->type, tex->is_shadow);
1342       _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler);
1343       return true;
1344    }
1345 
1346    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1347    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src);
1348    nir_deref_path path;
1349    nir_deref_path_init(&path, final_deref, NULL);
1350 
1351    nir_deref_instr *old_tail = path.path[0];
1352    assert(old_tail->deref_type == nir_deref_type_var);
1353    nir_variable *old_var = old_tail->var;
1354    if (glsl_type_is_bare_sampler(glsl_without_array(old_var->type)) &&
1355        glsl_sampler_type_is_shadow(glsl_without_array(old_var->type)) ==
1356           tex->is_shadow) {
1357       nir_deref_path_finish(&path);
1358       return false;
1359    }
1360 
1361    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1362                       old_var->data.binding;
1363    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1364    if (!new_var) {
1365       new_var = nir_variable_clone(old_var, b->shader);
1366       nir_shader_add_variable(b->shader, new_var);
1367       new_var->type =
1368          get_bare_samplers_for_type(old_var->type, tex->is_shadow);
1369       _mesa_hash_table_u64_insert(data, var_key, new_var);
1370    }
1371 
1372    b->cursor = nir_after_instr(&old_tail->instr);
1373    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1374 
1375    for (unsigned i = 1; path.path[i]; ++i) {
1376       b->cursor = nir_after_instr(&path.path[i]->instr);
1377       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1378    }
1379 
1380    nir_deref_path_finish(&path);
1381    nir_src_rewrite(&tex->src[sampler_idx].src, &new_tail->def);
1382    return true;
1383 }
1384 
1385 static bool
1386 redirect_texture_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1387 {
1388    if (instr->type != nir_instr_type_tex)
1389       return false;
1390 
1391    nir_tex_instr *tex = nir_instr_as_tex(instr);
1392 
1393    int texture_idx = nir_tex_instr_src_index(tex, nir_tex_src_texture_deref);
1394    if (texture_idx == -1) {
1395       /* No derefs, must be using indices */
1396       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->texture_index);
1397 
1398       /* Already have a texture here */
1399       if (bare_sampler)
1400          return false;
1401 
1402       nir_variable *typed_sampler = NULL;
1403       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1404          if (var->data.binding <= tex->texture_index &&
1405              var->data.binding + glsl_type_get_texture_count(var->type) > tex->texture_index) {
1406             /* Already have a texture for this binding, add it to the table */
1407             _mesa_hash_table_u64_insert(data, tex->texture_index, var);
1408             return false;
1409          }
1410 
1411          if (var->data.binding <= tex->texture_index &&
1412              var->data.binding + glsl_type_get_sampler_count(var->type) > tex->texture_index &&
1413              !glsl_type_is_bare_sampler(glsl_without_array(var->type))) {
1414             typed_sampler = var;
1415          }
1416       }
1417 
1418       /* Clone the typed sampler to a texture and we're done */
1419       assert(typed_sampler);
1420       bare_sampler = nir_variable_clone(typed_sampler, b->shader);
1421       bare_sampler->type = get_textures_for_sampler_type(typed_sampler->type);
1422       nir_shader_add_variable(b->shader, bare_sampler);
1423       _mesa_hash_table_u64_insert(data, tex->texture_index, bare_sampler);
1424       return true;
1425    }
1426 
1427    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1428    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[texture_idx].src);
1429    nir_deref_path path;
1430    nir_deref_path_init(&path, final_deref, NULL);
1431 
1432    nir_deref_instr *old_tail = path.path[0];
1433    assert(old_tail->deref_type == nir_deref_type_var);
1434    nir_variable *old_var = old_tail->var;
1435    if (glsl_type_is_texture(glsl_without_array(old_var->type)) ||
1436        glsl_type_is_image(glsl_without_array(old_var->type))) {
1437       nir_deref_path_finish(&path);
1438       return false;
1439    }
1440 
1441    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1442                       old_var->data.binding;
1443    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1444    if (!new_var) {
1445       new_var = nir_variable_clone(old_var, b->shader);
1446       new_var->type = get_textures_for_sampler_type(old_var->type);
1447       nir_shader_add_variable(b->shader, new_var);
1448       _mesa_hash_table_u64_insert(data, var_key, new_var);
1449    }
1450 
1451    b->cursor = nir_after_instr(&old_tail->instr);
1452    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1453 
1454    for (unsigned i = 1; path.path[i]; ++i) {
1455       b->cursor = nir_after_instr(&path.path[i]->instr);
1456       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1457    }
1458 
1459    nir_deref_path_finish(&path);
1460    nir_src_rewrite(&tex->src[texture_idx].src, &new_tail->def);
1461 
1462    return true;
1463 }
1464 
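/* Split combined typed samplers into separate texture and bare-sampler
 * variables by redirecting sampler and texture derefs/indices with the two
 * helpers above, sharing one binding -> variable table per step. */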
1465 bool
1466 dxil_nir_split_typed_samplers(nir_shader *nir)
1467 {
1468    struct hash_table_u64 *hash_table = _mesa_hash_table_u64_create(NULL);
1469 
1470    bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs,
1471       nir_metadata_control_flow | nir_metadata_loop_analysis, hash_table);
1472 
1473    _mesa_hash_table_u64_clear(hash_table);
1474 
1475    progress |= nir_shader_instructions_pass(nir, redirect_texture_derefs,
1476       nir_metadata_control_flow | nir_metadata_loop_analysis, hash_table);
1477 
1478    _mesa_hash_table_u64_destroy(hash_table);
1479    return progress;
1480 }
1481 
1482 
1483 static bool
1484 lower_sysval_to_load_input_impl(nir_builder *b, nir_intrinsic_instr *intr,
1485                                 void *data)
1486 {
1487    gl_system_value sysval = SYSTEM_VALUE_MAX;
1488    switch (intr->intrinsic) {
1489    case nir_intrinsic_load_instance_id:
1490       sysval = SYSTEM_VALUE_INSTANCE_ID;
1491       break;
1492    case nir_intrinsic_load_vertex_id_zero_base:
1493       sysval = SYSTEM_VALUE_VERTEX_ID_ZERO_BASE;
1494       break;
1495    default:
1496       return false;
1497    }
1498 
1499    nir_variable **sysval_vars = (nir_variable **)data;
1500    nir_variable *var = sysval_vars[sysval];
1501    assert(var);
1502 
1503    const nir_alu_type dest_type = nir_get_nir_type_for_glsl_type(var->type);
1504    const unsigned bit_size = intr->def.bit_size;
1505 
1506    b->cursor = nir_before_instr(&intr->instr);
1507    nir_def *result = nir_load_input(b, intr->def.num_components, bit_size, nir_imm_int(b, 0),
1508       .base = var->data.driver_location, .dest_type = dest_type);
1509 
1510    nir_def_rewrite_uses(&intr->def, result);
1511    return true;
1512 }
1513 
1514 bool
1515 dxil_nir_lower_sysval_to_load_input(nir_shader *s, nir_variable **sysval_vars)
1516 {
1517    return nir_shader_intrinsics_pass(s, lower_sysval_to_load_input_impl,
1518                                      nir_metadata_control_flow,
1519                                      sysval_vars);
1520 }
1521 
1522 /* Comparison function to sort I/O values so that normal varyings come first,
1523  * then system values, and then system-generated values.
1524  */
1525 static int
1526 variable_location_cmp(const nir_variable* a, const nir_variable* b)
1527 {
1528    // Sort by stream, driver_location, location, location_frac, then index
1529    // If all else is equal, sort full vectors before partial ones
1530    unsigned a_location = a->data.location;
1531    if (a_location >= VARYING_SLOT_PATCH0)
1532       a_location -= VARYING_SLOT_PATCH0;
1533    unsigned b_location = b->data.location;
1534    if (b_location >= VARYING_SLOT_PATCH0)
1535       b_location -= VARYING_SLOT_PATCH0;
1536    unsigned a_stream = a->data.stream & ~NIR_STREAM_PACKED;
1537    unsigned b_stream = b->data.stream & ~NIR_STREAM_PACKED;
1538    return a_stream != b_stream ?
1539             a_stream - b_stream :
1540             a->data.driver_location != b->data.driver_location ?
1541                a->data.driver_location - b->data.driver_location :
1542                a_location !=  b_location ?
1543                   a_location - b_location :
1544                   a->data.location_frac != b->data.location_frac ?
1545                      a->data.location_frac - b->data.location_frac :
1546                      a->data.index != b->data.index ?
1547                         a->data.index - b->data.index :
1548                         glsl_get_component_slots(b->type) - glsl_get_component_slots(a->type);
1549 }
1550 
1551 /* Order varyings according to driver location */
1552 void
1553 dxil_sort_by_driver_location(nir_shader* s, nir_variable_mode modes)
1554 {
1555    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1556 }
1557 
1558 /* Sort PS outputs so that color outputs come first */
1559 void
1560 dxil_sort_ps_outputs(nir_shader* s)
1561 {
1562    nir_foreach_variable_with_modes_safe(var, s, nir_var_shader_out) {
1563       /* We use the driver_location here to avoid introducing a new
1564        * struct or member variable. The true, updated driver location
1565        * will be written below, after sorting */
1566       switch (var->data.location) {
1567       case FRAG_RESULT_DEPTH:
1568          var->data.driver_location = 1;
1569          break;
1570       case FRAG_RESULT_STENCIL:
1571          var->data.driver_location = 2;
1572          break;
1573       case FRAG_RESULT_SAMPLE_MASK:
1574          var->data.driver_location = 3;
1575          break;
1576       default:
1577          var->data.driver_location = 0;
1578       }
1579    }
1580 
1581    nir_sort_variables_with_modes(s, variable_location_cmp,
1582                                  nir_var_shader_out);
1583 
1584    unsigned driver_loc = 0;
1585    nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
1586       /* Fractional vars should use the same driver_location as the base. These will
1587        * get fully merged during signature processing.
1588        */
1589       var->data.driver_location = var->data.location_frac ? driver_loc - 1 : driver_loc++;
1590    }
1591 }
1592 
1593 enum dxil_sysvalue_type {
1594    DXIL_NO_SYSVALUE = 0,
1595    DXIL_USED_SYSVALUE,
1596    DXIL_UNUSED_NO_SYSVALUE,
1597    DXIL_SYSVALUE,
1598    DXIL_GENERATED_SYSVALUE,
1599 };
1600 
1601 static enum dxil_sysvalue_type
1602 nir_var_to_dxil_sysvalue_type(nir_variable *var, uint64_t other_stage_mask,
1603                               const BITSET_WORD *other_stage_frac_mask)
1604 {
1605    switch (var->data.location) {
1606    case VARYING_SLOT_FACE:
1607       return DXIL_GENERATED_SYSVALUE;
1608    case VARYING_SLOT_POS:
1609    case VARYING_SLOT_PRIMITIVE_ID:
1610    case VARYING_SLOT_CLIP_DIST0:
1611    case VARYING_SLOT_CLIP_DIST1:
1612    case VARYING_SLOT_PSIZ:
1613    case VARYING_SLOT_TESS_LEVEL_INNER:
1614    case VARYING_SLOT_TESS_LEVEL_OUTER:
1615    case VARYING_SLOT_VIEWPORT:
1616    case VARYING_SLOT_LAYER:
1617    case VARYING_SLOT_VIEW_INDEX:
1618       if (!((1ull << var->data.location) & other_stage_mask))
1619          return DXIL_SYSVALUE;
1620       return DXIL_USED_SYSVALUE;
1621    default:
1622       if (var->data.location < VARYING_SLOT_PATCH0 &&
1623           !((1ull << var->data.location) & other_stage_mask))
1624          return DXIL_UNUSED_NO_SYSVALUE;
1625       if (var->data.location_frac && other_stage_frac_mask &&
1626           var->data.location >= VARYING_SLOT_VAR0 &&
1627           !BITSET_TEST(other_stage_frac_mask, ((var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac)))
1628          return DXIL_UNUSED_NO_SYSVALUE;
1629       return DXIL_NO_SYSVALUE;
1630    }
1631 }
1632 
1633 /* Order values passed between stages so that normal varyings come first,
1634  * then sysvalues, and then system-generated values.
1635  */
1636 void
1637 dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
1638    uint64_t other_stage_mask, const BITSET_WORD *other_stage_frac_mask)
1639 {
1640    nir_foreach_variable_with_modes_safe(var, s, modes) {
1641       /* We use the driver_location here to avoid introducing a new
1642        * struct or member variable. The true, updated driver location
1643        * will be written below, after sorting */
1644       var->data.driver_location = nir_var_to_dxil_sysvalue_type(var, other_stage_mask, other_stage_frac_mask);
1645    }
1646 
1647    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1648 
1649    unsigned driver_loc = 0, driver_patch_loc = 0;
1650    nir_foreach_variable_with_modes(var, s, modes) {
1651       /* Overlap patches with non-patch */
1652       unsigned *loc = var->data.patch ? &driver_patch_loc : &driver_loc;
1653       var->data.driver_location = *loc;
1654 
1655       const struct glsl_type *type = var->type;
1656       if (nir_is_arrayed_io(var, s->info.stage) && glsl_type_is_array(type))
1657          type = glsl_get_array_element(type);
1658       *loc += glsl_count_vec4_slots(type, false, false);
1659    }
1660 }
1661 
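/* For a load_vulkan_descriptor whose backing UBO variable is a one-element
 * array, rewrite its vulkan_resource_index to use a constant index of 0
 * unless it already does (out-of-bounds indexing is undefined anyway). */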
1662 static bool
1663 lower_ubo_array_one_to_static(struct nir_builder *b,
1664                               nir_intrinsic_instr *intrin,
1665                               void *cb_data)
1666 {
1667    if (intrin->intrinsic != nir_intrinsic_load_vulkan_descriptor)
1668       return false;
1669 
1670    nir_variable *var =
1671       nir_get_binding_variable(b->shader, nir_chase_binding(intrin->src[0]));
1672 
1673    if (!var)
1674       return false;
1675 
1676    if (!glsl_type_is_array(var->type) || glsl_array_size(var->type) != 1)
1677       return false;
1678 
1679    nir_intrinsic_instr *index = nir_src_as_intrinsic(intrin->src[0]);
1680    /* We currently do not support reindex */
1681    assert(index && index->intrinsic == nir_intrinsic_vulkan_resource_index);
1682 
1683    if (nir_src_is_const(index->src[0]) && nir_src_as_uint(index->src[0]) == 0)
1684       return false;
1685 
1686    if (nir_intrinsic_desc_type(index) != VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
1687       return false;
1688 
1689    b->cursor = nir_instr_remove(&index->instr);
1690 
1691    // Indexing out of bounds on an array of UBOs is considered undefined
1692    // behavior. Therefore, we just hardcode the index to 0.
1693    uint8_t bit_size = index->def.bit_size;
1694    nir_def *zero = nir_imm_intN_t(b, 0, bit_size);
1695    nir_def *dest =
1696       nir_vulkan_resource_index(b, index->num_components, bit_size, zero,
1697                                 .desc_set = nir_intrinsic_desc_set(index),
1698                                 .binding = nir_intrinsic_binding(index),
1699                                 .desc_type = nir_intrinsic_desc_type(index));
1700 
1701    nir_def_rewrite_uses(&index->def, dest);
1702 
1703    return true;
1704 }
1705 
1706 bool
1707 dxil_nir_lower_ubo_array_one_to_static(nir_shader *s)
1708 {
1709    bool progress = nir_shader_intrinsics_pass(s,
1710                                               lower_ubo_array_one_to_static,
1711                                               nir_metadata_none, NULL);
1712 
1713    return progress;
1714 }
1715 
1716 static bool
1717 is_fquantize2f16(const nir_instr *instr, const void *data)
1718 {
1719    if (instr->type != nir_instr_type_alu)
1720       return false;
1721 
1722    nir_alu_instr *alu = nir_instr_as_alu(instr);
1723    return alu->op == nir_op_fquantize2f16;
1724 }
1725 
1726 static nir_def *
1727 lower_fquantize2f16(struct nir_builder *b, nir_instr *instr, void *data)
1728 {
1729    /*
1730     * SpvOpQuantizeToF16 documentation says:
1731     *
1732     * "
1733     * If Value is an infinity, the result is the same infinity.
1734     * If Value is a NaN, the result is a NaN, but not necessarily the same NaN.
1735     * If Value is positive with a magnitude too large to represent as a 16-bit
1736     * floating-point value, the result is positive infinity. If Value is negative
1737     * with a magnitude too large to represent as a 16-bit floating-point value,
1738     * the result is negative infinity. If the magnitude of Value is too small to
1739     * represent as a normalized 16-bit floating-point value, the result may be
1740     * either +0 or -0.
1741     * "
1742     *
1743     * which we turn into:
1744     *
1745     *   if (val < MIN_FLOAT16)
1746     *      return -INFINITY;
1747     *   else if (val > MAX_FLOAT16)
1748     *      return +INFINITY;
1749     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) != 0)
1750     *      return -0.0f;
1751     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) == 0)
1752     *      return +0.0f;
1753     *   else
1754     *      return round(val);
1755     */
1756    nir_alu_instr *alu = nir_instr_as_alu(instr);
1757    nir_def *src =
1758       alu->src[0].src.ssa;
1759 
1760    nir_def *neg_inf_cond =
1761       nir_flt_imm(b, src, -65504.0f);
1762    nir_def *pos_inf_cond =
1763       nir_fgt_imm(b, src, 65504.0f);
1764    nir_def *zero_cond =
1765       nir_flt_imm(b, nir_fabs(b, src), ldexpf(1.0, -14));
1766    nir_def *zero = nir_iand_imm(b, src, 1 << 31);
1767    nir_def *round = nir_iand_imm(b, src, ~BITFIELD_MASK(13));
1768 
1769    nir_def *res =
1770       nir_bcsel(b, neg_inf_cond, nir_imm_float(b, -INFINITY), round);
1771    res = nir_bcsel(b, pos_inf_cond, nir_imm_float(b, INFINITY), res);
1772    res = nir_bcsel(b, zero_cond, zero, res);
1773    return res;
1774 }
1775 
1776 bool
1777 dxil_nir_lower_fquantize2f16(nir_shader *s)
1778 {
1779    return nir_shader_lower_instructions(s, is_fquantize2f16, lower_fquantize2f16, NULL);
1780 }
1781 
1782 static bool
1783 fix_io_uint_deref_types(struct nir_builder *builder, nir_instr *instr, void *data)
1784 {
1785    if (instr->type != nir_instr_type_deref)
1786       return false;
1787 
1788    nir_deref_instr *deref = nir_instr_as_deref(instr);
1789    nir_variable *var = nir_deref_instr_get_variable(deref);
1790 
1791    if (var == data) {
1792       deref->type = glsl_type_wrap_in_arrays(glsl_uint_type(), deref->type);
1793       return true;
1794    }
1795 
1796    return false;
1797 }
1798 
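/* Retype the int-typed I/O variable at the given slot to uint and rewrite
 * the types of any derefs that reference it. */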
1799 static bool
1800 fix_io_uint_type(nir_shader *s, nir_variable_mode modes, int slot)
1801 {
1802    nir_variable *fixed_var = NULL;
1803    nir_foreach_variable_with_modes(var, s, modes) {
1804       if (var->data.location == slot) {
1805          const struct glsl_type *plain_type = glsl_without_array(var->type);
1806          if (plain_type == glsl_uint_type())
1807             return false;
1808 
1809          assert(plain_type == glsl_int_type());
1810          var->type = glsl_type_wrap_in_arrays(glsl_uint_type(), var->type);
1811          fixed_var = var;
1812          break;
1813       }
1814    }
1815 
1816    assert(fixed_var);
1817 
1818    return nir_shader_instructions_pass(s, fix_io_uint_deref_types,
1819                                        nir_metadata_all, fixed_var);
1820 }
1821 
1822 bool
1823 dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mask)
1824 {
1825    if (!(s->info.outputs_written & out_mask) &&
1826        !(s->info.inputs_read & in_mask))
1827       return false;
1828 
1829    bool progress = false;
1830 
1831    while (in_mask) {
1832       int slot = u_bit_scan64(&in_mask);
1833       progress |= (s->info.inputs_read & (1ull << slot)) &&
1834                   fix_io_uint_type(s, nir_var_shader_in, slot);
1835    }
1836 
1837    while (out_mask) {
1838       int slot = u_bit_scan64(&out_mask);
1839       progress |= (s->info.outputs_written & (1ull << slot)) &&
1840                   fix_io_uint_type(s, nir_var_shader_out, slot);
1841    }
1842 
1843    return progress;
1844 }
1845 
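/* Rewrite terminate/terminate_if as demote/demote_if followed by a return
 * guarded by the same condition, so the killed invocation stops executing
 * the (fully inlined) entry point. */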
1846 static bool
1847 lower_kill(struct nir_builder *builder, nir_intrinsic_instr *intr,
1848            void *_cb_data)
1849 {
1850    if (intr->intrinsic != nir_intrinsic_terminate &&
1851        intr->intrinsic != nir_intrinsic_terminate_if)
1852       return false;
1853 
1854    builder->cursor = nir_instr_remove(&intr->instr);
1855    nir_def *condition;
1856 
1857    if (intr->intrinsic == nir_intrinsic_terminate) {
1858       nir_demote(builder);
1859       condition = nir_imm_true(builder);
1860    } else {
1861       nir_demote_if(builder, intr->src[0].ssa);
1862       condition = intr->src[0].ssa;
1863    }
1864 
1865    /* Create a new block by branching on the discard condition so that this return
1866     * is definitely the last instruction in its own block */
1867    nir_if *nif = nir_push_if(builder, condition);
1868    nir_jump(builder, nir_jump_return);
1869    nir_pop_if(builder, nif);
1870 
1871    return true;
1872 }
1873 
1874 bool
1875 dxil_nir_lower_discard_and_terminate(nir_shader *s)
1876 {
1877    if (s->info.stage != MESA_SHADER_FRAGMENT)
1878       return false;
1879 
1880    // This pass only works if all functions have been inlined
1881    assert(exec_list_length(&s->functions) == 1);
1882    return nir_shader_intrinsics_pass(s, lower_kill, nir_metadata_none, NULL);
1883 }
1884 
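/* Pad partial stores to VARYING_SLOT_POS out to a full 4-component write,
 * filling unwritten channels with zero. */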
1885 static bool
1886 update_writes(struct nir_builder *b, nir_intrinsic_instr *intr, void *_state)
1887 {
1888    if (intr->intrinsic != nir_intrinsic_store_output)
1889       return false;
1890 
1891    nir_io_semantics io = nir_intrinsic_io_semantics(intr);
1892    if (io.location != VARYING_SLOT_POS)
1893       return false;
1894 
1895    nir_def *src = intr->src[0].ssa;
1896    unsigned write_mask = nir_intrinsic_write_mask(intr);
1897    if (src->num_components == 4 && write_mask == 0xf)
1898       return false;
1899 
1900    b->cursor = nir_before_instr(&intr->instr);
1901    unsigned first_comp = nir_intrinsic_component(intr);
1902    nir_def *channels[4] = { NULL, NULL, NULL, NULL };
1903    assert(first_comp + src->num_components <= ARRAY_SIZE(channels));
1904    for (unsigned i = 0; i < src->num_components; ++i)
1905       if (write_mask & (1 << i))
1906          channels[i + first_comp] = nir_channel(b, src, i);
1907    for (unsigned i = 0; i < 4; ++i)
1908       if (!channels[i])
1909          channels[i] = nir_imm_intN_t(b, 0, src->bit_size);
1910 
1911    intr->num_components = 4;
1912    nir_src_rewrite(&intr->src[0], nir_vec(b, channels, 4));
1913    nir_intrinsic_set_component(intr, 0);
1914    nir_intrinsic_set_write_mask(intr, 0xf);
1915    return true;
1916 }
1917 
1918 bool
1919 dxil_nir_ensure_position_writes(nir_shader *s)
1920 {
1921    if (s->info.stage != MESA_SHADER_VERTEX &&
1922        s->info.stage != MESA_SHADER_GEOMETRY &&
1923        s->info.stage != MESA_SHADER_TESS_EVAL)
1924       return false;
1925    if ((s->info.outputs_written & VARYING_BIT_POS) == 0)
1926       return false;
1927 
1928    return nir_shader_intrinsics_pass(s, update_writes,
1929                                        nir_metadata_control_flow,
1930                                        NULL);
1931 }
1932 
1933 static bool
1934 is_sample_pos(const nir_instr *instr, const void *_data)
1935 {
1936    if (instr->type != nir_instr_type_intrinsic)
1937       return false;
1938    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1939    return intr->intrinsic == nir_intrinsic_load_sample_pos;
1940 }
1941 
1942 static nir_def *
1943 lower_sample_pos(nir_builder *b, nir_instr *instr, void *_data)
1944 {
1945    return nir_load_sample_pos_from_id(b, 32, nir_load_sample_id(b));
1946 }
1947 
1948 bool
1949 dxil_nir_lower_sample_pos(nir_shader *s)
1950 {
1951    return nir_shader_lower_instructions(s, is_sample_pos, lower_sample_pos, NULL);
1952 }
1953 
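/* Lower load_subgroup_id: for Nx1x1 compute workgroups it is simply
 * local_invocation_index / subgroup_size; otherwise one elected lane per
 * subgroup atomically bumps a shared-memory counter and the result is
 * broadcast to the rest of the subgroup with read_first_invocation. */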
1954 static bool
1955 lower_subgroup_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1956 {
1957    if (intr->intrinsic != nir_intrinsic_load_subgroup_id)
1958       return false;
1959 
1960    b->cursor = nir_before_impl(b->impl);
1961    if (b->shader->info.stage == MESA_SHADER_COMPUTE &&
1962        b->shader->info.workgroup_size[1] == 1 &&
1963        b->shader->info.workgroup_size[2] == 1) {
1964       /* When using Nx1x1 groups, use a simple stable algorithm
1965        * which is almost guaranteed to be correct. */
1966       nir_def *subgroup_id = nir_udiv(b, nir_load_local_invocation_index(b), nir_load_subgroup_size(b));
1967       nir_def_rewrite_uses(&intr->def, subgroup_id);
1968       return true;
1969    }
1970 
1971    nir_def **subgroup_id = (nir_def **)data;
1972    if (*subgroup_id == NULL) {
1973       nir_variable *subgroup_id_counter = nir_variable_create(b->shader, nir_var_mem_shared, glsl_uint_type(), "dxil_SubgroupID_counter");
1974       nir_variable *subgroup_id_local = nir_local_variable_create(b->impl, glsl_uint_type(), "dxil_SubgroupID_local");
1975       nir_store_var(b, subgroup_id_local, nir_imm_int(b, 0), 1);
1976 
1977       nir_deref_instr *counter_deref = nir_build_deref_var(b, subgroup_id_counter);
1978       nir_def *tid = nir_load_local_invocation_index(b);
1979       nir_if *nif = nir_push_if(b, nir_ieq_imm(b, tid, 0));
1980       nir_store_deref(b, counter_deref, nir_imm_int(b, 0), 1);
1981       nir_pop_if(b, nif);
1982 
1983       nir_barrier(b,
1984                          .execution_scope = SCOPE_WORKGROUP,
1985                          .memory_scope = SCOPE_WORKGROUP,
1986                          .memory_semantics = NIR_MEMORY_ACQ_REL,
1987                          .memory_modes = nir_var_mem_shared);
1988 
1989       nif = nir_push_if(b, nir_elect(b, 1));
1990       nir_def *subgroup_id_first_thread = nir_deref_atomic(b, 32, &counter_deref->def, nir_imm_int(b, 1),
1991                                                                .atomic_op = nir_atomic_op_iadd);
1992       nir_store_var(b, subgroup_id_local, subgroup_id_first_thread, 1);
1993       nir_pop_if(b, nif);
1994 
1995       nir_def *subgroup_id_loaded = nir_load_var(b, subgroup_id_local);
1996       *subgroup_id = nir_read_first_invocation(b, subgroup_id_loaded);
1997    }
1998    nir_def_rewrite_uses(&intr->def, *subgroup_id);
1999    return true;
2000 }
2001 
2002 bool
2003 dxil_nir_lower_subgroup_id(nir_shader *s)
2004 {
2005    nir_def *subgroup_id = NULL;
2006    return nir_shader_intrinsics_pass(s, lower_subgroup_id, nir_metadata_none,
2007                                      &subgroup_id);
2008 }
2009 
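/* Lower load_num_subgroups to
 * (workgroup_size_x * y * z + subgroup_size - 1) / subgroup_size. */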
2010 static bool
2011 lower_num_subgroups(nir_builder *b, nir_intrinsic_instr *intr, void *data)
2012 {
2013    if (intr->intrinsic != nir_intrinsic_load_num_subgroups)
2014       return false;
2015 
2016    b->cursor = nir_before_instr(&intr->instr);
2017    nir_def *subgroup_size = nir_load_subgroup_size(b);
2018    nir_def *size_minus_one = nir_iadd_imm(b, subgroup_size, -1);
2019    nir_def *workgroup_size_vec = nir_load_workgroup_size(b);
2020    nir_def *workgroup_size = nir_imul(b, nir_channel(b, workgroup_size_vec, 0),
2021                                              nir_imul(b, nir_channel(b, workgroup_size_vec, 1),
2022                                                          nir_channel(b, workgroup_size_vec, 2)));
2023    nir_def *ret = nir_idiv(b, nir_iadd(b, workgroup_size, size_minus_one), subgroup_size);
2024    nir_def_rewrite_uses(&intr->def, ret);
2025    return true;
2026 }
2027 
2028 bool
2029 dxil_nir_lower_num_subgroups(nir_shader *s)
2030 {
2031    return nir_shader_intrinsics_pass(s, lower_num_subgroups,
2032                                        nir_metadata_control_flow |
2033                                        nir_metadata_loop_analysis, NULL);
2034 }
2035 
2036 
2037 static const struct glsl_type *
2038 get_cast_type(unsigned bit_size)
2039 {
2040    switch (bit_size) {
2041    case 64:
2042       return glsl_int64_t_type();
2043    case 32:
2044       return glsl_int_type();
2045    case 16:
2046       return glsl_int16_t_type();
2047    case 8:
2048       return glsl_int8_t_type();
2049    }
2050    unreachable("Invalid bit_size");
2051 }
2052 
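/* Split an insufficiently aligned load_deref into alignment-sized loads
 * through a cast deref and reassemble the original value with
 * nir_extract_bits; split_unaligned_store below does the reverse for stores. */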
2053 static void
2054 split_unaligned_load(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
2055 {
2056    enum gl_access_qualifier access = nir_intrinsic_access(intrin);
2057    nir_def *srcs[NIR_MAX_VEC_COMPONENTS * NIR_MAX_VEC_COMPONENTS * sizeof(int64_t) / 8];
2058    unsigned comp_size = intrin->def.bit_size / 8;
2059    unsigned num_comps = intrin->def.num_components;
2060 
2061    b->cursor = nir_before_instr(&intrin->instr);
2062 
2063    nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);
2064 
2065    const struct glsl_type *cast_type = get_cast_type(alignment * 8);
2066    nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->def, ptr->modes, cast_type, alignment);
2067 
2068    unsigned num_loads = DIV_ROUND_UP(comp_size * num_comps, alignment);
2069    for (unsigned i = 0; i < num_loads; ++i) {
2070       nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->def.bit_size));
2071       srcs[i] = nir_load_deref_with_access(b, elem, access);
2072    }
2073 
2074    nir_def *new_dest = nir_extract_bits(b, srcs, num_loads, 0, num_comps, intrin->def.bit_size);
2075    nir_def_replace(&intrin->def, new_dest);
2076 }
2077 
2078 static void
2079 split_unaligned_store(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
2080 {
2081    enum gl_access_qualifier access = nir_intrinsic_access(intrin);
2082 
2083    nir_def *value = intrin->src[1].ssa;
2084    unsigned comp_size = value->bit_size / 8;
2085    unsigned num_comps = value->num_components;
2086 
2087    b->cursor = nir_before_instr(&intrin->instr);
2088 
2089    nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);
2090 
2091    const struct glsl_type *cast_type = get_cast_type(alignment * 8);
2092    nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->def, ptr->modes, cast_type, alignment);
2093 
2094    unsigned num_stores = DIV_ROUND_UP(comp_size * num_comps, alignment);
2095    for (unsigned i = 0; i < num_stores; ++i) {
2096       nir_def *substore_val = nir_extract_bits(b, &value, 1, i * alignment * 8, 1, alignment * 8);
2097       nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->def.bit_size));
2098       nir_store_deref_with_access(b, elem, substore_val, ~0, access);
2099    }
2100 
2101    nir_instr_remove(&intrin->instr);
2102 }
2103 
2104 bool
2105 dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes)
2106 {
2107    bool progress = false;
2108 
2109    nir_foreach_function_impl(impl, shader) {
2110       nir_builder b = nir_builder_create(impl);
2111 
2112       nir_foreach_block(block, impl) {
2113          nir_foreach_instr_safe(instr, block) {
2114             if (instr->type != nir_instr_type_intrinsic)
2115                continue;
2116             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2117             if (intrin->intrinsic != nir_intrinsic_load_deref &&
2118                 intrin->intrinsic != nir_intrinsic_store_deref)
2119                continue;
2120             nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
2121             if (!nir_deref_mode_may_be(deref, modes))
2122                continue;
2123 
2124             unsigned align_mul = 0, align_offset = 0;
2125             nir_get_explicit_deref_align(deref, true, &align_mul, &align_offset);
2126 
2127             unsigned alignment = align_offset ? 1 << (ffs(align_offset) - 1) : align_mul;
2128 
2129             /* We can load anything at 4-byte alignment, except for
2130              * UBOs (AKA CBs where the granularity is 16 bytes).
2131              */
2132             unsigned req_align = (nir_deref_mode_is_one_of(deref, nir_var_mem_ubo | nir_var_mem_push_const) ? 16 : 4);
2133             if (alignment >= req_align)
2134                continue;
2135 
2136             nir_def *val;
2137             if (intrin->intrinsic == nir_intrinsic_load_deref) {
2138                val = &intrin->def;
2139             } else {
2140                val = intrin->src[1].ssa;
2141             }
2142 
2143             unsigned scalar_byte_size = glsl_type_is_boolean(deref->type) ? 4 : glsl_get_bit_size(deref->type) / 8;
2144             unsigned num_components =
2145                /* If the vector stride is larger than the scalar size, lower_explicit_io will
2146                 * turn this into multiple scalar loads anyway, so we don't have to split it here. */
2147                glsl_get_explicit_stride(deref->type) > scalar_byte_size ? 1 :
2148                (val->num_components == 3 ? 4 : val->num_components);
2149             unsigned natural_alignment = scalar_byte_size * num_components;
2150 
2151             if (alignment >= natural_alignment)
2152                continue;
2153 
2154             if (intrin->intrinsic == nir_intrinsic_load_deref)
2155                split_unaligned_load(&b, intrin, alignment);
2156             else
2157                split_unaligned_store(&b, intrin, alignment);
2158             progress = true;
2159          }
2160       }
2161    }
2162 
2163    return progress;
2164 }
2165 
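/* Convert an inclusive scan into an exclusive scan whose result is then
 * combined with the invocation's own source value using the same reduction op. */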
2166 static void
2167 lower_inclusive_to_exclusive(nir_builder *b, nir_intrinsic_instr *intr)
2168 {
2169    b->cursor = nir_after_instr(&intr->instr);
2170 
2171    nir_op op = nir_intrinsic_reduction_op(intr);
2172    intr->intrinsic = nir_intrinsic_exclusive_scan;
2173    nir_intrinsic_set_reduction_op(intr, op);
2174 
2175    nir_def *final_val = nir_build_alu2(b, nir_intrinsic_reduction_op(intr),
2176                                            &intr->def, intr->src[0].ssa);
2177    nir_def_rewrite_uses_after(&intr->def, final_val, final_val->parent_instr);
2178 }
2179 
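/* Lower subgroup scans that aren't natively supported: exclusive
 * iadd/fadd/imul/fmul scans are left alone, inclusive ones are converted with
 * lower_inclusive_to_exclusive, and every other reduction op is emulated with
 * a loop over the active subgroup invocations using read_invocation,
 * accumulating from the op's identity value. */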
2180 static bool
2181 lower_subgroup_scan(nir_builder *b, nir_intrinsic_instr *intr, void *data)
2182 {
2183    switch (intr->intrinsic) {
2184    case nir_intrinsic_exclusive_scan:
2185    case nir_intrinsic_inclusive_scan:
2186       switch ((nir_op)nir_intrinsic_reduction_op(intr)) {
2187       case nir_op_iadd:
2188       case nir_op_fadd:
2189       case nir_op_imul:
2190       case nir_op_fmul:
2191          if (intr->intrinsic == nir_intrinsic_exclusive_scan)
2192             return false;
2193          lower_inclusive_to_exclusive(b, intr);
2194          return true;
2195       default:
2196          break;
2197       }
2198       break;
2199    default:
2200       return false;
2201    }
2202 
2203    b->cursor = nir_before_instr(&intr->instr);
2204    nir_op op = nir_intrinsic_reduction_op(intr);
2205    nir_def *subgroup_id = nir_load_subgroup_invocation(b);
2206    nir_def *subgroup_size = nir_load_subgroup_size(b);
2207    nir_def *active_threads = nir_ballot(b, 4, 32, nir_imm_true(b));
2208    nir_def *base_value;
2209    uint32_t bit_size = intr->def.bit_size;
2210    if (op == nir_op_iand || op == nir_op_umin)
2211       base_value = nir_imm_intN_t(b, ~0ull, bit_size);
2212    else if (op == nir_op_imin)
2213       base_value = nir_imm_intN_t(b, (1ull << (bit_size - 1)) - 1, bit_size);
2214    else if (op == nir_op_imax)
2215       base_value = nir_imm_intN_t(b, 1ull << (bit_size - 1), bit_size);
2216    else if (op == nir_op_fmax)
2217       base_value = nir_imm_floatN_t(b, -INFINITY, bit_size);
2218    else if (op == nir_op_fmin)
2219       base_value = nir_imm_floatN_t(b, INFINITY, bit_size);
2220    else
2221       base_value = nir_imm_intN_t(b, 0, bit_size);
2222 
2223    nir_variable *loop_counter_var = nir_local_variable_create(b->impl, glsl_uint_type(), "subgroup_loop_counter");
2224    nir_variable *result_var = nir_local_variable_create(b->impl,
2225                                                         glsl_vector_type(nir_get_glsl_base_type_for_nir_type(
2226                                                            nir_op_infos[op].input_types[0] | bit_size), 1),
2227                                                         "subgroup_loop_result");
2228    nir_store_var(b, loop_counter_var, nir_imm_int(b, 0), 1);
2229    nir_store_var(b, result_var, base_value, 1);
2230    nir_loop *loop = nir_push_loop(b);
2231    nir_def *loop_counter = nir_load_var(b, loop_counter_var);
2232 
2233    nir_if *nif = nir_push_if(b, nir_ilt(b, loop_counter, subgroup_size));
2234    nir_def *other_thread_val = nir_read_invocation(b, intr->src[0].ssa, loop_counter);
2235    nir_def *thread_in_range = intr->intrinsic == nir_intrinsic_inclusive_scan ?
2236       nir_ige(b, subgroup_id, loop_counter) :
2237       nir_ilt(b, loop_counter, subgroup_id);
2238    nir_def *thread_active = nir_ballot_bitfield_extract(b, 1, active_threads, loop_counter);
2239 
2240    nir_if *if_active_thread = nir_push_if(b, nir_iand(b, thread_in_range, thread_active));
2241    nir_def *result = nir_build_alu2(b, op, nir_load_var(b, result_var), other_thread_val);
2242    nir_store_var(b, result_var, result, 1);
2243    nir_pop_if(b, if_active_thread);
2244 
2245    nir_store_var(b, loop_counter_var, nir_iadd_imm(b, loop_counter, 1), 1);
2246    nir_jump(b, nir_jump_continue);
2247    nir_pop_if(b, nif);
2248 
2249    nir_jump(b, nir_jump_break);
2250    nir_pop_loop(b, loop);
2251 
2252    result = nir_load_var(b, result_var);
2253    nir_def_rewrite_uses(&intr->def, result);
2254    return true;
2255 }
2256 
2257 bool
2258 dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s)
2259 {
2260    bool ret = nir_shader_intrinsics_pass(s, lower_subgroup_scan,
2261                                          nir_metadata_none, NULL);
2262    if (ret) {
2263       /* Lower the ballot bitfield tests */
2264       nir_lower_subgroups_options options = { .ballot_bit_size = 32, .ballot_components = 4 };
2265       nir_lower_subgroups(s, &options);
2266    }
2267    return ret;
2268 }
2269 
2270 bool
2271 dxil_nir_forward_front_face(nir_shader *nir)
2272 {
2273    assert(nir->info.stage == MESA_SHADER_FRAGMENT);
2274 
2275    nir_variable *var = nir_find_variable_with_location(nir, nir_var_shader_in, VARYING_SLOT_FACE);
2276    if (var) {
2277       var->data.location = VARYING_SLOT_VAR12;
2278       return true;
2279    }
2280    return false;
2281 }
2282 
2283 static bool
2284 move_consts(nir_builder *b, nir_instr *instr, void *data)
2285 {
2286    bool progress = false;
2287    switch (instr->type) {
2288    case nir_instr_type_load_const: {
2289       /* Sink load_const instructions to their uses if there is more than one use */
2290       nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
2291       if (!list_is_singular(&load_const->def.uses)) {
2292          nir_foreach_use_safe(src, &load_const->def) {
2293             b->cursor = nir_before_src(src);
2294             nir_load_const_instr *new_load = nir_load_const_instr_create(b->shader,
2295                                                                          load_const->def.num_components,
2296                                                                          load_const->def.bit_size);
2297             memcpy(new_load->value, load_const->value, sizeof(load_const->value[0]) * load_const->def.num_components);
2298             nir_builder_instr_insert(b, &new_load->instr);
2299             nir_src_rewrite(src, &new_load->def);
2300             progress = true;
2301          }
2302       }
2303       return progress;
2304    }
2305    default:
2306       return false;
2307    }
2308 }
2309 
2310 /* Sink all consts so that they only have a single use.
2311  * The DXIL backend will already de-dupe the constants to the
2312  * same dxil_value if they have the same type, but this allows a single constant
2313  * to have different types without bitcasts. */
2314 bool
2315 dxil_nir_move_consts(nir_shader *s)
2316 {
2317    return nir_shader_instructions_pass(s, move_consts,
2318                                        nir_metadata_control_flow,
2319                                        NULL);
2320 }
2321 
2322 static void
2323 clear_pass_flags(nir_function_impl *impl)
2324 {
2325    nir_foreach_block(block, impl) {
2326       nir_foreach_instr(instr, block) {
2327          instr->pass_flags = 0;
2328       }
2329    }
2330 }
2331 
2332 static bool
2333 add_def_to_worklist(nir_def *def, void *state)
2334 {
2335    nir_foreach_use_including_if(src, def) {
2336       if (nir_src_is_if(src)) {
2337          nir_if *nif = nir_src_parent_if(src);
2338          nir_foreach_block_in_cf_node(block, &nif->cf_node) {
2339             nir_foreach_instr(instr, block)
2340                nir_instr_worklist_push_tail(state, instr);
2341          }
2342       } else
2343          nir_instr_worklist_push_tail(state, nir_src_parent_instr(src));
2344    }
2345    return true;
2346 }
2347 
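/* Mark, in input_bits, bit (signature row * 4 + component) for every input
 * component this load may read; load_view_index uses bit 0, and domain-shader
 * patch-constant loads are redirected to the second dependency table. */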
2348 static bool
2349 set_input_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t ***tables, const uint32_t **table_sizes)
2350 {
2351    if (intr->intrinsic == nir_intrinsic_load_view_index) {
2352       BITSET_SET(input_bits, 0);
2353       return true;
2354    }
2355 
2356    bool any_bits_set = false;
2357    nir_src *row_src = intr->intrinsic == nir_intrinsic_load_per_vertex_input ? &intr->src[1] : &intr->src[0];
2358    bool is_patch_constant = mod->shader_kind == DXIL_DOMAIN_SHADER && intr->intrinsic == nir_intrinsic_load_input;
2359    const struct dxil_signature_record *sig_rec = is_patch_constant ?
2360       &mod->patch_consts[mod->patch_mappings[nir_intrinsic_base(intr)]] :
2361       &mod->inputs[mod->input_mappings[nir_intrinsic_base(intr)]];
2362    if (is_patch_constant) {
2363       /* Redirect to the second I/O table */
2364       *tables = *tables + 1;
2365       *table_sizes = *table_sizes + 1;
2366    }
2367    for (uint32_t component = 0; component < intr->num_components; ++component) {
2368       uint32_t base_element = 0;
2369       uint32_t num_elements = sig_rec->num_elements;
2370       if (nir_src_is_const(*row_src)) {
2371          base_element = (uint32_t)nir_src_as_uint(*row_src);
2372          num_elements = 1;
2373       }
2374       for (uint32_t element = 0; element < num_elements; ++element) {
2375          uint32_t row = sig_rec->elements[element + base_element].reg;
2376          if (row == 0xffffffff)
2377             continue;
2378          BITSET_SET(input_bits, row * 4 + component + nir_intrinsic_component(intr));
2379          any_bits_set = true;
2380       }
2381    }
2382    return any_bits_set;
2383 }
2384 
2385 static bool
2386 set_output_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t **tables, const uint32_t *table_sizes)
2387 {
2388    bool any_bits_set = false;
2389    nir_src *row_src = intr->intrinsic == nir_intrinsic_store_per_vertex_output ? &intr->src[2] : &intr->src[1];
2390    bool is_patch_constant = mod->shader_kind == DXIL_HULL_SHADER && intr->intrinsic == nir_intrinsic_store_output;
2391    const struct dxil_signature_record *sig_rec = is_patch_constant ?
2392       &mod->patch_consts[mod->patch_mappings[nir_intrinsic_base(intr)]] :
2393       &mod->outputs[mod->output_mappings[nir_intrinsic_base(intr)]];
2394    for (uint32_t component = 0; component < intr->num_components; ++component) {
2395       uint32_t base_element = 0;
2396       uint32_t num_elements = sig_rec->num_elements;
2397       if (nir_src_is_const(*row_src)) {
2398          base_element = (uint32_t)nir_src_as_uint(*row_src);
2399          num_elements = 1;
2400       }
2401       for (uint32_t element = 0; element < num_elements; ++element) {
2402          uint32_t row = sig_rec->elements[element + base_element].reg;
2403          if (row == 0xffffffff)
2404             continue;
2405          uint32_t stream = sig_rec->elements[element + base_element].stream;
2406          uint32_t table_idx = is_patch_constant ? 1 : stream;
2407          uint32_t *table = tables[table_idx];
2408          uint32_t output_component = component + nir_intrinsic_component(intr);
2409          uint32_t input_component;
2410          BITSET_FOREACH_SET(input_component, input_bits, 32 * 4) {
2411             uint32_t *table_for_input_component = table + table_sizes[table_idx] * input_component;
2412             BITSET_SET(table_for_input_component, row * 4 + output_component);
2413             any_bits_set = true;
2414          }
2415       }
2416    }
2417    return any_bits_set;
2418 }
2419 
2420 static bool
2421 propagate_input_to_output_dependencies(struct dxil_module *mod, nir_intrinsic_instr *load_intr, uint32_t **tables, const uint32_t *table_sizes)
2422 {
2423    /* Which input components are being loaded by this instruction */
2424    BITSET_DECLARE(input_bits, 32 * 4) = { 0 };
2425    if (!set_input_bits(mod, load_intr, input_bits, &tables, &table_sizes))
2426       return false;
2427 
2428    nir_instr_worklist *worklist = nir_instr_worklist_create();
2429    nir_instr_worklist_push_tail(worklist, &load_intr->instr);
2430    bool any_bits_set = false;
2431    nir_foreach_instr_in_worklist(instr, worklist) {
2432       if (instr->pass_flags)
2433          continue;
2434 
2435       instr->pass_flags = 1;
2436       nir_foreach_def(instr, add_def_to_worklist, worklist);
2437       switch (instr->type) {
2438       case nir_instr_type_jump: {
2439          nir_jump_instr *jump = nir_instr_as_jump(instr);
2440          switch (jump->type) {
2441          case nir_jump_break:
2442          case nir_jump_continue: {
2443             nir_cf_node *parent = &instr->block->cf_node;
2444             while (parent->type != nir_cf_node_loop)
2445                parent = parent->parent;
2446             nir_foreach_block_in_cf_node(block, parent)
2447                nir_foreach_instr(i, block)
2448                nir_instr_worklist_push_tail(worklist, i);
2449             }
2450             break;
2451          default:
2452             unreachable("Don't expect any other jumps");
2453          }
2454          break;
2455       }
2456       case nir_instr_type_intrinsic: {
2457          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2458          switch (intr->intrinsic) {
2459          case nir_intrinsic_store_output:
2460          case nir_intrinsic_store_per_vertex_output:
2461             any_bits_set |= set_output_bits(mod, intr, input_bits, tables, table_sizes);
2462             break;
2463             /* TODO: Memory writes */
2464          default:
2465             break;
2466          }
2467          break;
2468       }
2469       default:
2470          break;
2471       }
2472    }
2473 
2474    nir_instr_worklist_destroy(worklist);
2475    return any_bits_set;
2476 }
2477 
2478 /* For every input load, compute the set of output stores that it can contribute to.
2479  * If it contributes to a store to memory, or if it's used for control flow, then any
2480  * instruction in the CFG that it impacts is considered to contribute.
2481  * Ideally, we should also handle stores to outputs/memory and then loads from that
2482  * output/memory, but this is non-trivial and unclear how much impact that would have. */
2483 bool
2484 dxil_nir_analyze_io_dependencies(struct dxil_module *mod, nir_shader *s)
2485 {
2486    bool any_outputs = false;
2487    for (uint32_t i = 0; i < 4; ++i)
2488       any_outputs |= mod->num_psv_outputs[i] > 0;
2489    if (mod->shader_kind == DXIL_HULL_SHADER)
2490       any_outputs |= mod->num_psv_patch_consts > 0;
2491    if (!any_outputs)
2492       return false;
2493 
2494    bool any_bits_set = false;
2495    nir_foreach_function(func, s) {
2496       assert(func->impl);
2497       /* Hull shaders have a patch constant function */
2498       assert(func->is_entrypoint || s->info.stage == MESA_SHADER_TESS_CTRL);
2499 
2500       /* Pass 1: input/view ID -> output dependencies */
2501       nir_foreach_block(block, func->impl) {
2502          nir_foreach_instr(instr, block) {
2503             if (instr->type != nir_instr_type_intrinsic)
2504                continue;
2505             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2506             uint32_t **tables = mod->io_dependency_table;
2507             const uint32_t *table_sizes = mod->dependency_table_dwords_per_input;
2508             switch (intr->intrinsic) {
2509             case nir_intrinsic_load_view_index:
2510                tables = mod->viewid_dependency_table;
2511                FALLTHROUGH;
2512             case nir_intrinsic_load_input:
2513             case nir_intrinsic_load_per_vertex_input:
2514             case nir_intrinsic_load_interpolated_input:
2515                break;
2516             default:
2517                continue;
2518             }
2519 
2520             clear_pass_flags(func->impl);
2521             any_bits_set |= propagate_input_to_output_dependencies(mod, intr, tables, table_sizes);
2522          }
2523       }
2524 
2525       /* Pass 2: output -> output dependencies */
2526       /* TODO */
2527    }
2528    return any_bits_set;
2529 }
2530 
2531 static enum pipe_format
2532 get_format_for_var(unsigned num_comps, enum glsl_base_type sampled_type)
2533 {
2534    switch (sampled_type) {
2535    case GLSL_TYPE_INT:
2536    case GLSL_TYPE_INT64:
2537    case GLSL_TYPE_INT16:
2538       switch (num_comps) {
2539       case 1: return PIPE_FORMAT_R32_SINT;
2540       case 2: return PIPE_FORMAT_R32G32_SINT;
2541       case 3: return PIPE_FORMAT_R32G32B32_SINT;
2542       case 4: return PIPE_FORMAT_R32G32B32A32_SINT;
2543       default: unreachable("Invalid num_comps");
2544       }
2545    case GLSL_TYPE_UINT:
2546    case GLSL_TYPE_UINT64:
2547    case GLSL_TYPE_UINT16:
2548       switch (num_comps) {
2549       case 1: return PIPE_FORMAT_R32_UINT;
2550       case 2: return PIPE_FORMAT_R32G32_UINT;
2551       case 3: return PIPE_FORMAT_R32G32B32_UINT;
2552       case 4: return PIPE_FORMAT_R32G32B32A32_UINT;
2553       default: unreachable("Invalid num_comps");
2554       }
2555    case GLSL_TYPE_FLOAT:
2556    case GLSL_TYPE_FLOAT16:
2557    case GLSL_TYPE_DOUBLE:
2558       switch (num_comps) {
2559       case 1: return PIPE_FORMAT_R32_FLOAT;
2560       case 2: return PIPE_FORMAT_R32G32_FLOAT;
2561       case 3: return PIPE_FORMAT_R32G32B32_FLOAT;
2562       case 4: return PIPE_FORMAT_R32G32B32A32_FLOAT;
2563       default: unreachable("Invalid num_comps");
2564       }
2565    default: unreachable("Invalid sampler return type");
2566    }
2567 }
2568 
2569 static unsigned
2570 aoa_size(const struct glsl_type *type)
2571 {
2572    return glsl_type_is_array(type) ? glsl_get_aoa_size(type) : 1;
2573 }
2574 
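/* Pick a format for an image variable declared without one by looking at how
 * it is accessed: an atomic forces a single-component format, loads/stores
 * widen the guess to the accessed component count, and an unused image
 * defaults to four components. */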
2575 static bool
2576 guess_image_format_for_var(nir_shader *s, nir_variable *var)
2577 {
2578    const struct glsl_type *base_type = glsl_without_array(var->type);
2579    if (!glsl_type_is_image(base_type))
2580       return false;
2581    if (var->data.image.format != PIPE_FORMAT_NONE)
2582       return false;
2583 
2584    nir_foreach_function_impl(impl, s) {
2585       nir_foreach_block(block, impl) {
2586          nir_foreach_instr(instr, block) {
2587             if (instr->type != nir_instr_type_intrinsic)
2588                continue;
2589             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2590             switch (intr->intrinsic) {
2591             case nir_intrinsic_image_deref_load:
2592             case nir_intrinsic_image_deref_store:
2593             case nir_intrinsic_image_deref_atomic:
2594             case nir_intrinsic_image_deref_atomic_swap:
2595                if (nir_intrinsic_get_var(intr, 0) != var)
2596                   continue;
2597                break;
2598             case nir_intrinsic_image_load:
2599             case nir_intrinsic_image_store:
2600             case nir_intrinsic_image_atomic:
2601             case nir_intrinsic_image_atomic_swap: {
2602                unsigned binding = nir_src_as_uint(intr->src[0]);
2603                if (binding < var->data.binding ||
2604                    binding >= var->data.binding + aoa_size(var->type))
2605                   continue;
2606                break;
2607                }
2608             default:
2609                continue;
2610             }
2612 
2613             switch (intr->intrinsic) {
2614             case nir_intrinsic_image_deref_load:
2615             case nir_intrinsic_image_load:
2616             case nir_intrinsic_image_deref_store:
2617             case nir_intrinsic_image_store:
2618                /* Increase unknown formats up to 4 components if a 4-component accessor is used */
2619                if (intr->num_components > util_format_get_nr_components(var->data.image.format))
2620                   var->data.image.format = get_format_for_var(intr->num_components, glsl_get_sampler_result_type(base_type));
2621                break;
2622             default:
2623                /* If an atomic is used, the image format must be 1-component; return immediately */
2624                var->data.image.format = get_format_for_var(1, glsl_get_sampler_result_type(base_type));
2625                return true;
2626             }
2627          }
2628       }
2629    }
2630    /* Dunno what it is, assume 4-component */
2631    if (var->data.image.format == PIPE_FORMAT_NONE)
2632       var->data.image.format = get_format_for_var(4, glsl_get_sampler_result_type(base_type));
2633    return true;
2634 }
2635 
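/* Propagate the (possibly guessed) variable format and sampled type onto an
 * image intrinsic's format and src/dest type indices.
 */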
2636 static void
2637 update_intrinsic_format_and_type(nir_intrinsic_instr *intr, nir_variable *var)
2638 {
2639    nir_intrinsic_set_format(intr, var->data.image.format);
2640    nir_alu_type alu_type =
2641       nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(glsl_without_array(var->type)));
2642    if (nir_intrinsic_has_src_type(intr))
2643       nir_intrinsic_set_src_type(intr, alu_type);
2644    else if (nir_intrinsic_has_dest_type(intr))
2645       nir_intrinsic_set_dest_type(intr, alu_type);
2646 }
2647 
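/* Callback for nir_shader_intrinsics_pass: for intrinsics that carry a format
 * index, resolve the image variable either through the deref source or through
 * a constant binding, then update the intrinsic's format/type to match.
 */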
2648 static bool
2649 update_intrinsic_formats(nir_builder *b, nir_intrinsic_instr *intr,
2650                          void *data)
2651 {
2652    if (!nir_intrinsic_has_format(intr))
2653       return false;
2654    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
2655    if (deref) {
2656       nir_variable *var = nir_deref_instr_get_variable(deref);
2657       if (var)
2658          update_intrinsic_format_and_type(intr, var);
2659       return var != NULL;
2660    }
2661 
2662    if (!nir_intrinsic_has_range_base(intr))
2663       return false;
2664 
2665    unsigned binding = nir_src_as_uint(intr->src[0]);
2666    nir_foreach_variable_with_modes(var, b->shader, nir_var_image) {
2667       if (var->data.binding <= binding &&
2668           var->data.binding + aoa_size(var->type) > binding) {
2669          update_intrinsic_format_and_type(intr, var);
2670          return true;
2671       }
2672    }
2673    return false;
2674 }
2675 
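/* Entry point: guess a format for every format-less image variable, then
 * rewrite image intrinsics so their format/type indices agree with the
 * variables they reference.
 */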
2676 bool
2677 dxil_nir_guess_image_formats(nir_shader *s)
2678 {
2679    bool progress = false;
2680    nir_foreach_variable_with_modes(var, s, nir_var_image) {
2681       progress |= guess_image_format_for_var(s, var);
2682    }
2683    nir_shader_intrinsics_pass(s, update_intrinsic_formats, nir_metadata_all,
2684                               NULL);
2685    return progress;
2686 }
2687 
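/* Mark every variable of the given modes that matches this binding and
 * descriptor set as ACCESS_COHERENT.
 */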
2688 static void
2689 set_binding_variables_coherent(nir_shader *s, nir_binding binding, nir_variable_mode modes)
2690 {
2691    nir_foreach_variable_with_modes(var, s, modes) {
2692       if (var->data.binding == binding.binding &&
2693           var->data.descriptor_set == binding.desc_set) {
2694          var->data.access |= ACCESS_COHERENT;
2695       }
2696    }
2697 }
2698 
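/* Walk up a deref chain: if it bottoms out at a variable, mark that variable
 * coherent; otherwise it must be a cast of a load_vulkan_descriptor, in which
 * case mark the SSBO variables behind that binding coherent.
 */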
2699 static void
2700 set_deref_variables_coherent(nir_shader *s, nir_deref_instr *deref)
2701 {
2702    while (deref->deref_type != nir_deref_type_var &&
2703           deref->deref_type != nir_deref_type_cast) {
2704       deref = nir_deref_instr_parent(deref);
2705    }
2706    if (deref->deref_type == nir_deref_type_var) {
2707       deref->var->data.access |= ACCESS_COHERENT;
2708       return;
2709    }
2710 
2711    /* For derefs with casts, we only support pre-lowered Vulkan accesses */
2712    assert(deref->deref_type == nir_deref_type_cast);
2713    nir_intrinsic_instr *cast_src = nir_instr_as_intrinsic(deref->parent.ssa->parent_instr);
2714    assert(cast_src->intrinsic == nir_intrinsic_load_vulkan_descriptor);
2715    nir_binding binding = nir_chase_binding(cast_src->src[0]);
2716    set_binding_variables_coherent(s, binding, nir_var_mem_ssbo);
2717 }
2718 
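/* Build an atomic that is semantically equivalent to a scalar load or store:
 * loads become an atomic iadd of 0, stores become an atomic exchange.
 * Returns NULL for intrinsics this lowering doesn't handle.
 */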
2719 static nir_def *
2720 get_atomic_for_load_store(nir_builder *b, nir_intrinsic_instr *intr, unsigned bit_size)
2721 {
2722    nir_def *zero = nir_imm_intN_t(b, 0, bit_size);
2723    switch (intr->intrinsic) {
2724    case nir_intrinsic_load_deref:
2725       return nir_deref_atomic(b, bit_size, intr->src[0].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2726    case nir_intrinsic_load_ssbo:
2727       return nir_ssbo_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2728    case nir_intrinsic_image_deref_load:
2729       return nir_image_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2730    case nir_intrinsic_image_load:
2731       return nir_image_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2732    case nir_intrinsic_store_deref:
2733       return nir_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, .atomic_op = nir_atomic_op_xchg);
2734    case nir_intrinsic_store_ssbo:
2735       return nir_ssbo_atomic(b, bit_size, intr->src[1].ssa, intr->src[2].ssa, intr->src[0].ssa, .atomic_op = nir_atomic_op_xchg);
2736    case nir_intrinsic_image_deref_store:
2737       return nir_image_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, intr->src[3].ssa, .atomic_op = nir_atomic_op_xchg);
2738    case nir_intrinsic_image_store:
2739       return nir_image_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, intr->src[3].ssa, .atomic_op = nir_atomic_op_xchg);
2740    default:
2741       return NULL;
2742    }
2743 }
2744 
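/* Rewrite scalar 32/64-bit coherent loads and stores as equivalent atomics
 * (load -> iadd 0, store -> xchg), copying the relevant intrinsic indices
 * across. Accesses that can't be expressed as a single atomic (sub-32-bit or
 * multi-component) instead get their backing variables marked
 * ACCESS_COHERENT.
 */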
2745 static bool
2746 lower_coherent_load_store(nir_builder *b, nir_intrinsic_instr *intr, void *context)
2747 {
2748    if (!nir_intrinsic_has_access(intr) || (nir_intrinsic_access(intr) & ACCESS_COHERENT) == 0)
2749       return false;
2750 
2751    nir_def *atomic_def = NULL;
2752    b->cursor = nir_before_instr(&intr->instr);
2753    switch (intr->intrinsic) {
2754    case nir_intrinsic_load_deref:
2755    case nir_intrinsic_load_ssbo:
2756    case nir_intrinsic_image_deref_load:
2757    case nir_intrinsic_image_load: {
2758       if (intr->def.bit_size < 32 || intr->def.num_components > 1) {
2759          if (intr->intrinsic == nir_intrinsic_load_deref)
2760             set_deref_variables_coherent(b->shader, nir_src_as_deref(intr->src[0]));
2761          else {
2762             nir_binding binding = {0};
2763             if (nir_src_is_const(intr->src[0]))
2764                binding.binding = nir_src_as_uint(intr->src[0]);
2765             set_binding_variables_coherent(b->shader, binding,
2766                                            intr->intrinsic == nir_intrinsic_load_ssbo ? nir_var_mem_ssbo : nir_var_image);
2767          }
2768          return false;
2769       }
2770 
2771       atomic_def = get_atomic_for_load_store(b, intr, intr->def.bit_size);
2772       nir_def_rewrite_uses(&intr->def, atomic_def);
2773       break;
2774    }
2775    case nir_intrinsic_store_deref:
2776    case nir_intrinsic_store_ssbo:
2777    case nir_intrinsic_image_deref_store:
2778    case nir_intrinsic_image_store: {
2779       int resource_idx = intr->intrinsic == nir_intrinsic_store_ssbo ? 1 : 0;
2780       int value_idx = intr->intrinsic == nir_intrinsic_store_ssbo ? 0 :
2781          intr->intrinsic == nir_intrinsic_store_deref ? 1 : 3;
2782       unsigned num_components = nir_intrinsic_has_write_mask(intr) ?
2783          util_bitcount(nir_intrinsic_write_mask(intr)) : intr->src[value_idx].ssa->num_components;
2784       if (intr->src[value_idx].ssa->bit_size < 32 || num_components > 1) {
2785          if (intr->intrinsic == nir_intrinsic_store_deref)
2786             set_deref_variables_coherent(b->shader, nir_src_as_deref(intr->src[resource_idx]));
2787          else {
2788             nir_binding binding = {0};
2789             if (nir_src_is_const(intr->src[resource_idx]))
2790                binding.binding = nir_src_as_uint(intr->src[resource_idx]);
2791             set_binding_variables_coherent(b->shader, binding,
2792                                            intr->intrinsic == nir_intrinsic_store_ssbo ? nir_var_mem_ssbo : nir_var_image);
2793          }
2794          return false;
2795       }
2796 
2797       atomic_def = get_atomic_for_load_store(b, intr, intr->src[value_idx].ssa->bit_size);
2798       break;
2799    }
2800    default:
2801       return false;
2802    }
2803 
2804    nir_intrinsic_instr *atomic = nir_instr_as_intrinsic(atomic_def->parent_instr);
2805    nir_intrinsic_set_access(atomic, nir_intrinsic_access(intr));
2806    if (nir_intrinsic_has_image_dim(intr))
2807       nir_intrinsic_set_image_dim(atomic, nir_intrinsic_image_dim(intr));
2808    if (nir_intrinsic_has_image_array(intr))
2809       nir_intrinsic_set_image_array(atomic, nir_intrinsic_image_array(intr));
2810    if (nir_intrinsic_has_format(intr))
2811       nir_intrinsic_set_format(atomic, nir_intrinsic_format(intr));
2812    if (nir_intrinsic_has_range_base(intr))
2813       nir_intrinsic_set_range_base(atomic, nir_intrinsic_range_base(intr));
2814    nir_instr_remove(&intr->instr);
2815    return true;
2816 }
2817 
2818 bool
2819 dxil_nir_lower_coherent_loads_and_stores(nir_shader *s)
2820 {
2821    return nir_shader_intrinsics_pass(s, lower_coherent_load_store,
2822                                      nir_metadata_control_flow | nir_metadata_loop_analysis,
2823                                      NULL);
2824 }
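/* A minimal sketch of how a driver might invoke the image-format and
 * coherency fixups above when preparing a NIR shader for DXIL emission.
 * The surrounding function and the pass ordering are hypothetical and not
 * taken from this file:
 *
 *    static void
 *    example_prepare_for_dxil(nir_shader *s)
 *    {
 *       dxil_nir_guess_image_formats(s);
 *       dxil_nir_lower_coherent_loads_and_stores(s);
 *    }
 */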
2825 
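/* Masks describing which varying slots (and, via frac_io_mask, which
 * individual components of the VARn slots) the adjacent stage writes or
 * reads.
 */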
2826 struct undefined_varying_masks {
2827    uint64_t io_mask;
2828    uint32_t patch_io_mask;
2829    const BITSET_WORD *frac_io_mask;
2830 };
2831 
2832 static bool
2833 is_dead_in_variable(nir_variable *var, void *data)
2834 {
2835    switch (var->data.location) {
2836    /* Only these values can be system generated values in addition to varyings */
2837    case VARYING_SLOT_PRIMITIVE_ID:
2838    case VARYING_SLOT_FACE:
2839    case VARYING_SLOT_VIEW_INDEX:
2840       return false;
2841    /* Tessellation input vars must remain untouched */
2842    case VARYING_SLOT_TESS_LEVEL_INNER:
2843    case VARYING_SLOT_TESS_LEVEL_OUTER:
2844       return false;
2845    default:
2846       return true;
2847    }
2848 }
2849 
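/* Replace loads of inputs that the previous stage never writes with zero
 * (see the note below on why zero is used rather than undef).
 */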
2850 static bool
2851 kill_undefined_varyings(struct nir_builder *b,
2852                         nir_instr *instr,
2853                         void *data)
2854 {
2855    const struct undefined_varying_masks *masks = data;
2856 
2857    if (instr->type != nir_instr_type_intrinsic)
2858       return false;
2859 
2860    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2861 
2862    if (intr->intrinsic != nir_intrinsic_load_deref)
2863       return false;
2864 
2865    nir_variable *var = nir_intrinsic_get_var(intr, 0);
2866    if (!var || var->data.mode != nir_var_shader_in)
2867       return false;
2868 
2869    if (!is_dead_in_variable(var, NULL))
2870       return false;
2871 
2872    uint32_t loc = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2873       var->data.location - VARYING_SLOT_PATCH0 :
2874       var->data.location;
2875    uint64_t written = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2876       masks->patch_io_mask : masks->io_mask;
2877    if (BITFIELD64_RANGE(loc, glsl_varying_count(var->type)) & written) {
2878       if (!masks->frac_io_mask || !var->data.location_frac ||
2879           var->data.location < VARYING_SLOT_VAR0 ||
2880           BITSET_TEST(masks->frac_io_mask, (var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac))
2881          return false;
2882    }
2883 
2884    b->cursor = nir_after_instr(instr);
2885    /* Note: zero is used instead of undef, because optimization is not run here, but is
2886     * run later on. If we load an undef here, and that undef ends up being used to store
2887     * to position later on, that can cause some or all of the components in that position
2888     * write to be removed, which is problematic especially in the case of all components,
2889     * since that would remove the store instruction, and would make it tricky to satisfy
2890     * the DXIL requirements of writing all position components.
2891     */
2892    nir_def *zero = nir_imm_zero(b, intr->def.num_components,
2893                                        intr->def.bit_size);
2894    nir_def_replace(&intr->def, zero);
2895    return true;
2896 }
2897 
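/* Zero out and then remove input varyings that the previous stage is known
 * not to write, based on that stage's written/patch-written/per-component
 * output masks.
 */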
2898 bool
2899 dxil_nir_kill_undefined_varyings(nir_shader *shader, uint64_t prev_stage_written_mask, uint32_t prev_stage_patch_written_mask,
2900                                  const BITSET_WORD *prev_stage_frac_output_mask)
2901 {
2902    struct undefined_varying_masks masks = {
2903       .io_mask = prev_stage_written_mask,
2904       .patch_io_mask = prev_stage_patch_written_mask,
2905       .frac_io_mask = prev_stage_frac_output_mask
2906    };
2907    bool progress = nir_shader_instructions_pass(shader,
2908                                                 kill_undefined_varyings,
2909                                                 nir_metadata_control_flow |
2910                                                 nir_metadata_loop_analysis,
2911                                                 (void *)&masks);
2912    if (progress) {
2913       nir_opt_dce(shader);
2914       nir_remove_dead_derefs(shader);
2915    }
2916 
2917    const struct nir_remove_dead_variables_options options = {
2918       .can_remove_var = is_dead_in_variable,
2919       .can_remove_var_data = &masks,
2920    };
2921    progress |= nir_remove_dead_variables(shader, nir_var_shader_in, &options);
2922    return progress;
2923 }
2924 
2925 static bool
2926 is_dead_out_variable(nir_variable *var, void *data)
2927 {
2928    return !nir_slot_is_sysval_output(var->data.location, MESA_SHADER_NONE);
2929 }
2930 
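/* Remove stores to (and loads of) output varyings that the next stage never
 * reads; loads are rewritten to zero before removal so later passes don't
 * see dangling uses.
 */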
2931 static bool
2932 kill_unused_outputs(struct nir_builder *b,
2933                     nir_instr *instr,
2934                     void *data)
2935 {
2936    const struct undefined_varying_masks *masks = data;
2937 
2938    if (instr->type != nir_instr_type_intrinsic)
2939       return false;
2940 
2941    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2942 
2943    if (intr->intrinsic != nir_intrinsic_store_deref &&
2944        intr->intrinsic != nir_intrinsic_load_deref)
2945       return false;
2946 
2947    nir_variable *var = nir_intrinsic_get_var(intr, 0);
2948    if (!var || var->data.mode != nir_var_shader_out ||
2949        /* always_active_io can mean two things: xfb or GL separable shaders. We can't delete
2950         * varyings that are used for xfb (we'll just sort them last), but we must delete varyings
2951         * that are mismatching between TCS and TES. Fortunately TCS can't do xfb, so we can ignore
2952         * the always_active_io bit for TCS outputs. */
2953        (b->shader->info.stage != MESA_SHADER_TESS_CTRL && var->data.always_active_io))
2954       return false;
2955 
2956    if (!is_dead_out_variable(var, NULL))
2957       return false;
2958 
2959    unsigned loc = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2960       var->data.location - VARYING_SLOT_PATCH0 :
2961       var->data.location;
2962    uint64_t read = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2963       masks->patch_io_mask : masks->io_mask;
2964    if (BITFIELD64_RANGE(loc, glsl_varying_count(var->type)) & read) {
2965       if (!masks->frac_io_mask || !var->data.location_frac ||
2966           var->data.location < VARYING_SLOT_VAR0 ||
2967           BITSET_TEST(masks->frac_io_mask, (var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac))
2968          return false;
2969    }
2970 
2971    if (intr->intrinsic == nir_intrinsic_load_deref) {
2972       b->cursor = nir_after_instr(&intr->instr);
2973       nir_def *zero = nir_imm_zero(b, intr->def.num_components, intr->def.bit_size);
2974       nir_def_rewrite_uses(&intr->def, zero);
2975    }
2976    nir_instr_remove(instr);
2977    return true;
2978 }
2979 
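/* Counterpart of dxil_nir_kill_undefined_varyings for the producing stage:
 * drop output varyings that the consuming stage never reads, based on its
 * read/patch-read/per-component input masks.
 */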
2980 bool
2981 dxil_nir_kill_unused_outputs(nir_shader *shader, uint64_t next_stage_read_mask, uint32_t next_stage_patch_read_mask,
2982                              const BITSET_WORD *next_stage_frac_input_mask)
2983 {
2984    struct undefined_varying_masks masks = {
2985       .io_mask = next_stage_read_mask,
2986       .patch_io_mask = next_stage_patch_read_mask,
2987       .frac_io_mask = next_stage_frac_input_mask
2988    };
2989 
2990    bool progress = nir_shader_instructions_pass(shader,
2991                                                 kill_unused_outputs,
2992                                                 nir_metadata_control_flow |
2993                                                 nir_metadata_loop_analysis,
2994                                                 (void *)&masks);
2995 
2996    if (progress) {
2997       nir_opt_dce(shader);
2998       nir_remove_dead_derefs(shader);
2999    }
3000    const struct nir_remove_dead_variables_options options = {
3001       .can_remove_var = is_dead_out_variable,
3002       .can_remove_var_data = &masks,
3003    };
3004    progress |= nir_remove_dead_variables(shader, nir_var_shader_out, &options);
3005    return progress;
3006 }
3007