1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "dxil_nir.h"
25 
26 #include "nir_builder.h"
27 #include "nir_deref.h"
28 #include "nir_to_dxil.h"
29 #include "util/u_math.h"
30 #include "vulkan/vulkan_core.h"
31 
32 static void
33 cl_type_size_align(const struct glsl_type *type, unsigned *size,
34                    unsigned *align)
35 {
36    *size = glsl_get_cl_size(type);
37    *align = glsl_get_cl_alignment(type);
38 }
39 
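/* Split a vec of 32-bit values into num_dst_comps components of dst_bit_size
 * bits each. For example, with dst_bit_size == 16 each 32-bit channel is
 * unpacked into two 16-bit components (low half first), and with
 * dst_bit_size == 64 two consecutive 32-bit channels are packed into a
 * single 64-bit component.
 */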
40 static void
41 extract_comps_from_vec32(nir_builder *b, nir_ssa_def *vec32,
42                          unsigned dst_bit_size,
43                          nir_ssa_def **dst_comps,
44                          unsigned num_dst_comps)
45 {
46    unsigned step = DIV_ROUND_UP(dst_bit_size, 32);
47    unsigned comps_per32b = 32 / dst_bit_size;
48    nir_ssa_def *tmp;
49 
50    for (unsigned i = 0; i < vec32->num_components; i += step) {
51       switch (dst_bit_size) {
52       case 64:
53          tmp = nir_pack_64_2x32_split(b, nir_channel(b, vec32, i),
54                                          nir_channel(b, vec32, i + 1));
55          dst_comps[i / 2] = tmp;
56          break;
57       case 32:
58          dst_comps[i] = nir_channel(b, vec32, i);
59          break;
60       case 16:
61       case 8: {
62          unsigned dst_offs = i * comps_per32b;
63 
64          tmp = nir_unpack_bits(b, nir_channel(b, vec32, i), dst_bit_size);
65          for (unsigned j = 0; j < comps_per32b && dst_offs + j < num_dst_comps; j++)
66             dst_comps[dst_offs + j] = nir_channel(b, tmp, j);
67          }
68 
69          break;
70       }
71    }
72 }
73 
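/* Inverse of extract_comps_from_vec32(): pack num_src_comps components of
 * src_bit_size bits into 32-bit words. For example, four 8-bit components
 * are zero-extended, shifted by 0, 8, 16 and 24 bits, and OR'ed into one
 * 32-bit word.
 */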
74 static nir_ssa_def *
75 load_comps_to_vec32(nir_builder *b, unsigned src_bit_size,
76                     nir_ssa_def **src_comps, unsigned num_src_comps)
77 {
78    unsigned num_vec32comps = DIV_ROUND_UP(num_src_comps * src_bit_size, 32);
79    unsigned step = DIV_ROUND_UP(src_bit_size, 32);
80    unsigned comps_per32b = 32 / src_bit_size;
81    nir_ssa_def *vec32comps[4];
82 
83    for (unsigned i = 0; i < num_vec32comps; i += step) {
84       switch (src_bit_size) {
85       case 64:
86          vec32comps[i] = nir_unpack_64_2x32_split_x(b, src_comps[i / 2]);
87          vec32comps[i + 1] = nir_unpack_64_2x32_split_y(b, src_comps[i / 2]);
88          break;
89       case 32:
90          vec32comps[i] = src_comps[i];
91          break;
92       case 16:
93       case 8: {
94          unsigned src_offs = i * comps_per32b;
95 
96          vec32comps[i] = nir_u2u32(b, src_comps[src_offs]);
97          for (unsigned j = 1; j < comps_per32b && src_offs + j < num_src_comps; j++) {
98             nir_ssa_def *tmp = nir_ishl(b, nir_u2u32(b, src_comps[src_offs + j]),
99                                            nir_imm_int(b, j * src_bit_size));
100             vec32comps[i] = nir_ior(b, vec32comps[i], tmp);
101          }
102          break;
103       }
104       }
105    }
106 
107    return nir_vec(b, vec32comps, num_vec32comps);
108 }
109 
110 static nir_ssa_def *
111 build_load_ptr_dxil(nir_builder *b, nir_deref_instr *deref, nir_ssa_def *idx)
112 {
113    return nir_load_ptr_dxil(b, 1, 32, &deref->dest.ssa, idx);
114 }
115 
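/* Lower a load_deref of a shader-temp variable into one or more 32-bit
 * load_ptr_dxil accesses of the underlying i32 array, then repack the loaded
 * words into the original component count and bit size.
 */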
116 static bool
117 lower_load_deref(nir_builder *b, nir_intrinsic_instr *intr)
118 {
119    assert(intr->dest.is_ssa);
120 
121    b->cursor = nir_before_instr(&intr->instr);
122 
123    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
124    if (!nir_deref_mode_is(deref, nir_var_shader_temp))
125       return false;
126    nir_ssa_def *ptr = nir_u2u32(b, nir_build_deref_offset(b, deref, cl_type_size_align));
127    nir_ssa_def *offset = nir_iand(b, ptr, nir_inot(b, nir_imm_int(b, 3)));
128 
129    assert(intr->dest.is_ssa);
130    unsigned num_components = nir_dest_num_components(intr->dest);
131    unsigned bit_size = nir_dest_bit_size(intr->dest);
132    unsigned load_size = MAX2(32, bit_size);
133    unsigned num_bits = num_components * bit_size;
134    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
135    unsigned comp_idx = 0;
136 
137    nir_deref_path path;
138    nir_deref_path_init(&path, deref, NULL);
139    nir_ssa_def *base_idx = nir_ishr(b, offset, nir_imm_int(b, 2 /* log2(32 / 8) */));
140 
141    /* Split loads into 32-bit chunks */
142    for (unsigned i = 0; i < num_bits; i += load_size) {
143       unsigned subload_num_bits = MIN2(num_bits - i, load_size);
144       nir_ssa_def *idx = nir_iadd(b, base_idx, nir_imm_int(b, i / 32));
145       nir_ssa_def *vec32 = build_load_ptr_dxil(b, path.path[0], idx);
146 
147       if (load_size == 64) {
148          idx = nir_iadd(b, idx, nir_imm_int(b, 1));
149          vec32 = nir_vec2(b, vec32,
150                              build_load_ptr_dxil(b, path.path[0], idx));
151       }
152 
153       /* If we have 2 bytes or less to load we need to adjust the u32 value so
154        * we can always extract the LSB.
155        */
156       if (subload_num_bits <= 16) {
157          nir_ssa_def *shift = nir_imul(b, nir_iand(b, ptr, nir_imm_int(b, 3)),
158                                           nir_imm_int(b, 8));
159          vec32 = nir_ushr(b, vec32, shift);
160       }
161 
162       /* And now comes the pack/unpack step to match the original type. */
163       extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
164                                subload_num_bits / bit_size);
165       comp_idx += subload_num_bits / bit_size;
166    }
167 
168    nir_deref_path_finish(&path);
169    assert(comp_idx == num_components);
170    nir_ssa_def *result = nir_vec(b, comps, num_components);
171    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
172    nir_instr_remove(&intr->instr);
173    return true;
174 }
175 
176 static nir_ssa_def *
177 ubo_load_select_32b_comps(nir_builder *b, nir_ssa_def *vec32,
178                           nir_ssa_def *offset, unsigned num_bytes)
179 {
180    assert(num_bytes == 16 || num_bytes == 12 || num_bytes == 8 ||
181           num_bytes == 4 || num_bytes == 3 || num_bytes == 2 ||
182           num_bytes == 1);
183    assert(vec32->num_components == 4);
184 
185    /* 16 and 12 byte types are always aligned on 16 bytes. */
186    if (num_bytes > 8)
187       return vec32;
188 
189    nir_ssa_def *comps[4];
190    nir_ssa_def *cond;
191 
192    for (unsigned i = 0; i < 4; i++)
193       comps[i] = nir_channel(b, vec32, i);
194 
195    /* If we have 8bytes or less to load, select which half of the vec4
196     * should be used.
197     */
198    cond = nir_ine(b, nir_iand(b, offset, nir_imm_int(b, 0x8)),
199                                  nir_imm_int(b, 0));
200 
201    comps[0] = nir_bcsel(b, cond, comps[2], comps[0]);
202    comps[1] = nir_bcsel(b, cond, comps[3], comps[1]);
203 
204    /* Thanks to the CL alignment constraints, if we want 8 bytes we're done. */
205    if (num_bytes == 8)
206       return nir_vec(b, comps, 2);
207 
208    /* 4 bytes or less needed, select which of the 32bit components should be
209     * used and return it. The sub-32bit split is handled in
210     * extract_comps_from_vec32().
211     */
212    cond = nir_ine(b, nir_iand(b, offset, nir_imm_int(b, 0x4)),
213                                  nir_imm_int(b, 0));
214    return nir_bcsel(b, cond, comps[1], comps[0]);
215 }
216 
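/* UBO loads are emitted in 16-byte (float4) chunks to match
 * cBufferLoadLegacy(): the byte offset is turned into a 16-byte slot index
 * (offset >> 4), the relevant 32-bit words of each slot are selected, and
 * the result is unpacked to the requested bit size. For example, a
 * 2-component 32-bit load at byte offset 8 reads slot 0 and keeps the upper
 * half of the vec4.
 */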
217 nir_ssa_def *
218 build_load_ubo_dxil(nir_builder *b, nir_ssa_def *buffer,
219                     nir_ssa_def *offset, unsigned num_components,
220                     unsigned bit_size)
221 {
222    nir_ssa_def *idx = nir_ushr(b, offset, nir_imm_int(b, 4));
223    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
224    unsigned num_bits = num_components * bit_size;
225    unsigned comp_idx = 0;
226 
227    /* We need to split loads in 16byte chunks because that's the
228     * granularity of cBufferLoadLegacy().
229     */
230    for (unsigned i = 0; i < num_bits; i += (16 * 8)) {
231       /* For each 16byte chunk (or smaller) we generate a 32bit ubo vec
232        * load.
233        */
234       unsigned subload_num_bits = MIN2(num_bits - i, 16 * 8);
235       nir_ssa_def *vec32 =
236          nir_load_ubo_dxil(b, 4, 32, buffer, nir_iadd(b, idx, nir_imm_int(b, i / (16 * 8))));
237 
238       /* First re-arrange the vec32 to account for intra 16-byte offset. */
239       vec32 = ubo_load_select_32b_comps(b, vec32, offset, subload_num_bits / 8);
240 
241       /* If we have 2 bytes or less to load we need to adjust the u32 value so
242        * we can always extract the LSB.
243        */
244       if (subload_num_bits <= 16) {
245          nir_ssa_def *shift = nir_imul(b, nir_iand(b, offset,
246                                                       nir_imm_int(b, 3)),
247                                           nir_imm_int(b, 8));
248          vec32 = nir_ushr(b, vec32, shift);
249       }
250 
251       /* And now comes the pack/unpack step to match the original type. */
252       extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
253                                subload_num_bits / bit_size);
254       comp_idx += subload_num_bits / bit_size;
255    }
256 
257    assert(comp_idx == num_components);
258    return nir_vec(b, comps, num_components);
259 }
260 
261 static bool
262 lower_load_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
263 {
264    assert(intr->dest.is_ssa);
265    assert(intr->src[0].is_ssa);
266    assert(intr->src[1].is_ssa);
267 
268    b->cursor = nir_before_instr(&intr->instr);
269 
270    nir_ssa_def *buffer = intr->src[0].ssa;
271    nir_ssa_def *offset = nir_iand(b, intr->src[1].ssa, nir_imm_int(b, ~3));
272    enum gl_access_qualifier access = nir_intrinsic_access(intr);
273    unsigned bit_size = nir_dest_bit_size(intr->dest);
274    unsigned num_components = nir_dest_num_components(intr->dest);
275    unsigned num_bits = num_components * bit_size;
276 
277    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
278    unsigned comp_idx = 0;
279 
280    /* We need to split loads in 16byte chunks because that's the optimal
281     * granularity of bufferLoad(). Minimum alignment is 4byte, which saves
282     * us from extra complexity when extracting >= 32 bit components.
283     */
284    for (unsigned i = 0; i < num_bits; i += 4 * 32) {
285       /* For each 16byte chunk (or smaller) we generate a 32bit ssbo vec
286        * load.
287        */
288       unsigned subload_num_bits = MIN2(num_bits - i, 4 * 32);
289 
290       /* The number of components to load depends on the number of bytes. */
291       nir_ssa_def *vec32 =
292          nir_load_ssbo(b, DIV_ROUND_UP(subload_num_bits, 32), 32,
293                        buffer, nir_iadd(b, offset, nir_imm_int(b, i / 8)),
294                        .align_mul = 4,
295                        .align_offset = 0,
296                        .access = access);
297 
298       /* If we have 2 bytes or less to load we need to adjust the u32 value so
299        * we can always extract the LSB.
300        */
301       if (subload_num_bits <= 16) {
302          nir_ssa_def *shift = nir_imul(b, nir_iand(b, intr->src[1].ssa, nir_imm_int(b, 3)),
303                                           nir_imm_int(b, 8));
304          vec32 = nir_ushr(b, vec32, shift);
305       }
306 
307       /* And now comes the pack/unpack step to match the original type. */
308       extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
309                                subload_num_bits / bit_size);
310       comp_idx += subload_num_bits / bit_size;
311    }
312 
313    assert(comp_idx == num_components);
314    nir_ssa_def *result = nir_vec(b, comps, num_components);
315    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
316    nir_instr_remove(&intr->instr);
317    return true;
318 }
319 
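/* SSBO stores are split the same way as the loads above. Components masked
 * out by the write mask are skipped, and any chunk smaller than 32 bits is
 * written with store_ssbo_masked_dxil so that the other bytes sharing the
 * same 32-bit word are preserved.
 */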
320 static bool
321 lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
322 {
323    b->cursor = nir_before_instr(&intr->instr);
324 
325    assert(intr->src[0].is_ssa);
326    assert(intr->src[1].is_ssa);
327    assert(intr->src[2].is_ssa);
328 
329    nir_ssa_def *val = intr->src[0].ssa;
330    nir_ssa_def *buffer = intr->src[1].ssa;
331    nir_ssa_def *offset = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, ~3));
332 
333    unsigned bit_size = val->bit_size;
334    unsigned num_components = val->num_components;
335    unsigned num_bits = num_components * bit_size;
336 
337    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS] = { 0 };
338    unsigned comp_idx = 0;
339 
340    unsigned write_mask = nir_intrinsic_write_mask(intr);
341    for (unsigned i = 0; i < num_components; i++)
342       if (write_mask & (1 << i))
343          comps[i] = nir_channel(b, val, i);
344 
345    /* We split stores in 16byte chunks because that's the optimal granularity
346     * of bufferStore(). Minimum alignment is 4byte, which saves us from
347     * extra complexity when storing >= 32 bit components.
348     */
349    unsigned bit_offset = 0;
350    while (true) {
351       /* Skip over holes in the write mask */
352       while (comp_idx < num_components && comps[comp_idx] == NULL) {
353          comp_idx++;
354          bit_offset += bit_size;
355       }
356       if (comp_idx >= num_components)
357          break;
358 
359       /* For each 16byte chunk (or smaller) we generate a 32bit ssbo vec
360        * store. If a component is skipped by the write mask, do a smaller
361        * sub-store.
362        */
363       unsigned num_src_comps_stored = 0, substore_num_bits = 0;
364       while(num_src_comps_stored + comp_idx < num_components &&
365             substore_num_bits + bit_offset < num_bits &&
366             substore_num_bits < 4 * 32 &&
367             comps[comp_idx + num_src_comps_stored]) {
368          ++num_src_comps_stored;
369          substore_num_bits += bit_size;
370       }
371       nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, bit_offset / 8));
372       nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
373                                                num_src_comps_stored);
374       nir_intrinsic_instr *store;
375 
376       if (substore_num_bits < 32) {
377          nir_ssa_def *mask = nir_imm_int(b, (1 << substore_num_bits) - 1);
378 
379         /* If we have 16 bits or less to store we need to place them
380          * correctly in the u32 component. Anything greater than 16 bits
381          * (including uchar3) is naturally aligned on 32bits.
382          */
383          if (substore_num_bits <= 16) {
384             nir_ssa_def *pos = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, 3));
385             nir_ssa_def *shift = nir_imul_imm(b, pos, 8);
386 
387             vec32 = nir_ishl(b, vec32, shift);
388             mask = nir_ishl(b, mask, shift);
389          }
390 
391          store = nir_intrinsic_instr_create(b->shader,
392                                             nir_intrinsic_store_ssbo_masked_dxil);
393          store->src[0] = nir_src_for_ssa(vec32);
394          store->src[1] = nir_src_for_ssa(nir_inot(b, mask));
395          store->src[2] = nir_src_for_ssa(buffer);
396          store->src[3] = nir_src_for_ssa(local_offset);
397       } else {
398          store = nir_intrinsic_instr_create(b->shader,
399                                             nir_intrinsic_store_ssbo);
400          store->src[0] = nir_src_for_ssa(vec32);
401          store->src[1] = nir_src_for_ssa(buffer);
402          store->src[2] = nir_src_for_ssa(local_offset);
403 
404          nir_intrinsic_set_align(store, 4, 0);
405       }
406 
407       /* The number of components to store depends on the number of bits. */
408       store->num_components = DIV_ROUND_UP(substore_num_bits, 32);
409       nir_builder_instr_insert(b, &store->instr);
410       comp_idx += num_src_comps_stored;
411       bit_offset += substore_num_bits;
412 
413       if (nir_intrinsic_has_write_mask(store))
414          nir_intrinsic_set_write_mask(store, (1 << store->num_components) - 1);
415    }
416 
417    nir_instr_remove(&intr->instr);
418    return true;
419 }
420 
421 static void
422 lower_load_vec32(nir_builder *b, nir_ssa_def *index, unsigned num_comps, nir_ssa_def **comps, nir_intrinsic_op op)
423 {
424    for (unsigned i = 0; i < num_comps; i++) {
425       nir_intrinsic_instr *load =
426          nir_intrinsic_instr_create(b->shader, op);
427 
428       load->num_components = 1;
429       load->src[0] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
430       nir_ssa_dest_init(&load->instr, &load->dest, 1, 32, NULL);
431       nir_builder_instr_insert(b, &load->instr);
432       comps[i] = &load->dest.ssa;
433    }
434 }
435 
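/* Shared and scratch memory are both backed by i32 arrays, so loads are
 * rewritten as one load_shared_dxil/load_scratch_dxil per 32-bit word
 * (indexed by byte offset / 4) and then repacked to the original type.
 */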
436 static bool
437 lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr)
438 {
439    assert(intr->dest.is_ssa);
440    unsigned bit_size = nir_dest_bit_size(intr->dest);
441    unsigned num_components = nir_dest_num_components(intr->dest);
442    unsigned num_bits = num_components * bit_size;
443 
444    b->cursor = nir_before_instr(&intr->instr);
445    nir_intrinsic_op op = intr->intrinsic;
446 
447    assert(intr->src[0].is_ssa);
448    nir_ssa_def *offset = intr->src[0].ssa;
449    if (op == nir_intrinsic_load_shared) {
450       offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
451       op = nir_intrinsic_load_shared_dxil;
452    } else {
453       offset = nir_u2u32(b, offset);
454       op = nir_intrinsic_load_scratch_dxil;
455    }
456    nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
457    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
458    nir_ssa_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];
459 
460    /* We need to split loads in 32-bit accesses because the buffer
461     * is an i32 array and DXIL does not support type casts.
462     */
463    unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
464    lower_load_vec32(b, index, num_32bit_comps, comps_32bit, op);
465    unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);
466 
467    for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
468       unsigned num_vec32_comps = MIN2(num_32bit_comps - i, 4);
469       unsigned num_dest_comps = num_vec32_comps * 32 / bit_size;
470       nir_ssa_def *vec32 = nir_vec(b, &comps_32bit[i], num_vec32_comps);
471 
472       /* If we have 16 bits or less to load we need to adjust the u32 value so
473        * we can always extract the LSB.
474        */
475       if (num_bits <= 16) {
476          nir_ssa_def *shift =
477             nir_imul(b, nir_iand(b, offset, nir_imm_int(b, 3)),
478                         nir_imm_int(b, 8));
479          vec32 = nir_ushr(b, vec32, shift);
480       }
481 
482       /* And now comes the pack/unpack step to match the original type. */
483       unsigned dest_index = i * 32 / bit_size;
484       extract_comps_from_vec32(b, vec32, bit_size, &comps[dest_index], num_dest_comps);
485    }
486 
487    nir_ssa_def *result = nir_vec(b, comps, num_components);
488    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
489    nir_instr_remove(&intr->instr);
490 
491    return true;
492 }
493 
494 static void
495 lower_store_vec32(nir_builder *b, nir_ssa_def *index, nir_ssa_def *vec32, nir_intrinsic_op op)
496 {
497 
498    for (unsigned i = 0; i < vec32->num_components; i++) {
499       nir_intrinsic_instr *store =
500          nir_intrinsic_instr_create(b->shader, op);
501 
502       store->src[0] = nir_src_for_ssa(nir_channel(b, vec32, i));
503       store->src[1] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
504       store->num_components = 1;
505       nir_builder_instr_insert(b, &store->instr);
506    }
507 }
508 
509 static void
510 lower_masked_store_vec32(nir_builder *b, nir_ssa_def *offset, nir_ssa_def *index,
511                          nir_ssa_def *vec32, unsigned num_bits, nir_intrinsic_op op)
512 {
513    nir_ssa_def *mask = nir_imm_int(b, (1 << num_bits) - 1);
514 
515    /* If we have 16 bits or less to store we need to place them correctly in
516     * the u32 component. Anything greater than 16 bits (including uchar3) is
517     * naturally aligned on 32bits.
518     */
519    if (num_bits <= 16) {
520       nir_ssa_def *shift =
521          nir_imul_imm(b, nir_iand(b, offset, nir_imm_int(b, 3)), 8);
522 
523       vec32 = nir_ishl(b, vec32, shift);
524       mask = nir_ishl(b, mask, shift);
525    }
526 
527    if (op == nir_intrinsic_store_shared_dxil) {
528       /* Use the dedicated masked intrinsic */
529       nir_store_shared_masked_dxil(b, vec32, nir_inot(b, mask), index);
530    } else {
531       /* For scratch, since we don't need atomics, just generate the read-modify-write in NIR */
532       nir_ssa_def *load = nir_load_scratch_dxil(b, 1, 32, index);
533 
534       nir_ssa_def *new_val = nir_ior(b, vec32,
535                                      nir_iand(b,
536                                               nir_inot(b, mask),
537                                               load));
538 
539       lower_store_vec32(b, index, new_val, op);
540    }
541 }
542 
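/* Store counterpart of lower_32b_offset_load(): source components are packed
 * into 32-bit words and written one word at a time; stores narrower than 32
 * bits go through the masked path so data sharing the same word survives.
 */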
543 static bool
544 lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr)
545 {
546    assert(intr->src[0].is_ssa);
547    unsigned num_components = nir_src_num_components(intr->src[0]);
548    unsigned bit_size = nir_src_bit_size(intr->src[0]);
549    unsigned num_bits = num_components * bit_size;
550 
551    b->cursor = nir_before_instr(&intr->instr);
552    nir_intrinsic_op op = intr->intrinsic;
553 
554    nir_ssa_def *offset = intr->src[1].ssa;
555    if (op == nir_intrinsic_store_shared) {
556       offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
557       op = nir_intrinsic_store_shared_dxil;
558    } else {
559       offset = nir_u2u32(b, offset);
560       op = nir_intrinsic_store_scratch_dxil;
561    }
562    nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
563 
564    unsigned comp_idx = 0;
565    for (unsigned i = 0; i < num_components; i++)
566       comps[i] = nir_channel(b, intr->src[0].ssa, i);
567 
568    for (unsigned i = 0; i < num_bits; i += 4 * 32) {
569       /* For each 4byte chunk (or smaller) we generate a 32bit scalar store.
570        */
571       unsigned substore_num_bits = MIN2(num_bits - i, 4 * 32);
572       nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, i / 8));
573       nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
574                                                substore_num_bits / bit_size);
575       nir_ssa_def *index = nir_ushr(b, local_offset, nir_imm_int(b, 2));
576 
577       /* For anything less than 32bits we need to use the masked version of the
578        * intrinsic to preserve data living in the same 32bit slot.
579        */
580       if (num_bits < 32) {
581          lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, op);
582       } else {
583          lower_store_vec32(b, index, vec32, op);
584       }
585 
586       comp_idx += substore_num_bits / bit_size;
587    }
588 
589    nir_instr_remove(&intr->instr);
590 
591    return true;
592 }
593 
594 static void
595 ubo_to_temp_patch_deref_mode(nir_deref_instr *deref)
596 {
597    deref->modes = nir_var_shader_temp;
598    nir_foreach_use(use_src, &deref->dest.ssa) {
599       if (use_src->parent_instr->type != nir_instr_type_deref)
600          continue;
601 
602       nir_deref_instr *parent = nir_instr_as_deref(use_src->parent_instr);
603       ubo_to_temp_patch_deref_mode(parent);
604    }
605 }
606 
607 static void
608 ubo_to_temp_update_entry(nir_deref_instr *deref, struct hash_entry *he)
609 {
610    assert(nir_deref_mode_is(deref, nir_var_mem_constant));
611    assert(deref->dest.is_ssa);
612    assert(he->data);
613 
614    nir_foreach_use(use_src, &deref->dest.ssa) {
615       if (use_src->parent_instr->type == nir_instr_type_deref) {
616          ubo_to_temp_update_entry(nir_instr_as_deref(use_src->parent_instr), he);
617       } else if (use_src->parent_instr->type == nir_instr_type_intrinsic) {
618          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(use_src->parent_instr);
619          if (intr->intrinsic != nir_intrinsic_load_deref)
620             he->data = NULL;
621       } else {
622          he->data = NULL;
623       }
624 
625       if (!he->data)
626          break;
627    }
628 }
629 
630 bool
631 dxil_nir_lower_ubo_to_temp(nir_shader *nir)
632 {
633    struct hash_table *ubo_to_temp = _mesa_pointer_hash_table_create(NULL);
634    bool progress = false;
635 
636    /* First pass: collect all UBO accesses that could be turned into
637     * shader temp accesses.
638     */
639    foreach_list_typed(nir_function, func, node, &nir->functions) {
640       if (!func->is_entrypoint)
641          continue;
642       assert(func->impl);
643 
644       nir_foreach_block(block, func->impl) {
645          nir_foreach_instr_safe(instr, block) {
646             if (instr->type != nir_instr_type_deref)
647                continue;
648 
649             nir_deref_instr *deref = nir_instr_as_deref(instr);
650             if (!nir_deref_mode_is(deref, nir_var_mem_constant) ||
651                 deref->deref_type != nir_deref_type_var)
652                   continue;
653 
654             struct hash_entry *he =
655                _mesa_hash_table_search(ubo_to_temp, deref->var);
656 
657             if (!he)
658                he = _mesa_hash_table_insert(ubo_to_temp, deref->var, deref->var);
659 
660             if (!he->data)
661                continue;
662 
663             ubo_to_temp_update_entry(deref, he);
664          }
665       }
666    }
667 
668    hash_table_foreach(ubo_to_temp, he) {
669       nir_variable *var = he->data;
670 
671       if (!var)
672          continue;
673 
674       /* Change the variable mode. */
675       var->data.mode = nir_var_shader_temp;
676 
677       /* Make sure the variable has a name.
678        * DXIL variables must have names.
679        */
680       if (!var->name)
681          var->name = ralloc_asprintf(nir, "global_%d", exec_list_length(&nir->variables));
682 
683       progress = true;
684    }
685    _mesa_hash_table_destroy(ubo_to_temp, NULL);
686 
687    /* Second pass: patch all derefs that were accessing the converted UBO
688     * variables.
689     */
690    foreach_list_typed(nir_function, func, node, &nir->functions) {
691       if (!func->is_entrypoint)
692          continue;
693       assert(func->impl);
694 
695       nir_foreach_block(block, func->impl) {
696          nir_foreach_instr_safe(instr, block) {
697             if (instr->type != nir_instr_type_deref)
698                continue;
699 
700             nir_deref_instr *deref = nir_instr_as_deref(instr);
701             if (nir_deref_mode_is(deref, nir_var_mem_constant) &&
702                 deref->deref_type == nir_deref_type_var &&
703                 deref->var->data.mode == nir_var_shader_temp)
704                ubo_to_temp_patch_deref_mode(deref);
705          }
706       }
707    }
708 
709    return progress;
710 }
711 
712 static bool
713 lower_load_ubo(nir_builder *b, nir_intrinsic_instr *intr)
714 {
715    assert(intr->dest.is_ssa);
716    assert(intr->src[0].is_ssa);
717    assert(intr->src[1].is_ssa);
718 
719    b->cursor = nir_before_instr(&intr->instr);
720 
721    nir_ssa_def *result =
722       build_load_ubo_dxil(b, intr->src[0].ssa, intr->src[1].ssa,
723                              nir_dest_num_components(intr->dest),
724                              nir_dest_bit_size(intr->dest));
725 
726    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
727    nir_instr_remove(&intr->instr);
728    return true;
729 }
730 
731 bool
732 dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir)
733 {
734    bool progress = false;
735 
736    foreach_list_typed(nir_function, func, node, &nir->functions) {
737       if (!func->is_entrypoint)
738          continue;
739       assert(func->impl);
740 
741       nir_builder b;
742       nir_builder_init(&b, func->impl);
743 
744       nir_foreach_block(block, func->impl) {
745          nir_foreach_instr_safe(instr, block) {
746             if (instr->type != nir_instr_type_intrinsic)
747                continue;
748             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
749 
750             switch (intr->intrinsic) {
751             case nir_intrinsic_load_deref:
752                progress |= lower_load_deref(&b, intr);
753                break;
754             case nir_intrinsic_load_shared:
755             case nir_intrinsic_load_scratch:
756                progress |= lower_32b_offset_load(&b, intr);
757                break;
758             case nir_intrinsic_load_ssbo:
759                progress |= lower_load_ssbo(&b, intr);
760                break;
761             case nir_intrinsic_load_ubo:
762                progress |= lower_load_ubo(&b, intr);
763                break;
764             case nir_intrinsic_store_shared:
765             case nir_intrinsic_store_scratch:
766                progress |= lower_32b_offset_store(&b, intr);
767                break;
768             case nir_intrinsic_store_ssbo:
769                progress |= lower_store_ssbo(&b, intr);
770                break;
771             default:
772                break;
773             }
774          }
775       }
776    }
777 
778    return progress;
779 }
780 
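/* Shared-memory atomics address the same i32 array as the shared loads and
 * stores above, so the byte offset (plus the intrinsic base) is converted to
 * a 32-bit element index before emitting the *_dxil atomic intrinsic.
 */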
781 static bool
782 lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr,
783                     nir_intrinsic_op dxil_op)
784 {
785    b->cursor = nir_before_instr(&intr->instr);
786 
787    assert(intr->src[0].is_ssa);
788    nir_ssa_def *offset =
789       nir_iadd(b, intr->src[0].ssa, nir_imm_int(b, nir_intrinsic_base(intr)));
790    nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
791 
792    nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b->shader, dxil_op);
793    atomic->src[0] = nir_src_for_ssa(index);
794    assert(intr->src[1].is_ssa);
795    atomic->src[1] = nir_src_for_ssa(intr->src[1].ssa);
796    if (dxil_op == nir_intrinsic_shared_atomic_comp_swap_dxil) {
797       assert(intr->src[2].is_ssa);
798       atomic->src[2] = nir_src_for_ssa(intr->src[2].ssa);
799    }
800    atomic->num_components = 0;
801    nir_ssa_dest_init(&atomic->instr, &atomic->dest, 1, 32, NULL);
802 
803    nir_builder_instr_insert(b, &atomic->instr);
804    nir_ssa_def_rewrite_uses(&intr->dest.ssa, &atomic->dest.ssa);
805    nir_instr_remove(&intr->instr);
806    return true;
807 }
808 
809 bool
810 dxil_nir_lower_atomics_to_dxil(nir_shader *nir)
811 {
812    bool progress = false;
813 
814    foreach_list_typed(nir_function, func, node, &nir->functions) {
815       if (!func->is_entrypoint)
816          continue;
817       assert(func->impl);
818 
819       nir_builder b;
820       nir_builder_init(&b, func->impl);
821 
822       nir_foreach_block(block, func->impl) {
823          nir_foreach_instr_safe(instr, block) {
824             if (instr->type != nir_instr_type_intrinsic)
825                continue;
826             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
827 
828             switch (intr->intrinsic) {
829 
830 #define ATOMIC(op)                                                            \
831   case nir_intrinsic_shared_atomic_##op:                                     \
832      progress |= lower_shared_atomic(&b, intr,                                \
833                                      nir_intrinsic_shared_atomic_##op##_dxil); \
834      break
835 
836             ATOMIC(add);
837             ATOMIC(imin);
838             ATOMIC(umin);
839             ATOMIC(imax);
840             ATOMIC(umax);
841             ATOMIC(and);
842             ATOMIC(or);
843             ATOMIC(xor);
844             ATOMIC(exchange);
845             ATOMIC(comp_swap);
846 
847 #undef ATOMIC
848             default:
849                break;
850             }
851          }
852       }
853    }
854 
855    return progress;
856 }
857 
858 static bool
859 lower_deref_ssbo(nir_builder *b, nir_deref_instr *deref)
860 {
861    assert(nir_deref_mode_is(deref, nir_var_mem_ssbo));
862    assert(deref->deref_type == nir_deref_type_var ||
863           deref->deref_type == nir_deref_type_cast);
864    nir_variable *var = deref->var;
865 
866    b->cursor = nir_before_instr(&deref->instr);
867 
868    if (deref->deref_type == nir_deref_type_var) {
869       /* We turn all deref_var into deref_cast and build a pointer value based on
870        * the var binding which encodes the UAV id.
871        */
872       nir_ssa_def *ptr = nir_imm_int64(b, (uint64_t)var->data.binding << 32);
873       nir_deref_instr *deref_cast =
874          nir_build_deref_cast(b, ptr, nir_var_mem_ssbo, deref->type,
875                               glsl_get_explicit_stride(var->type));
876       nir_ssa_def_rewrite_uses(&deref->dest.ssa,
877                                &deref_cast->dest.ssa);
878       nir_instr_remove(&deref->instr);
879 
880       deref = deref_cast;
881       return true;
882    }
883    return false;
884 }
885 
886 bool
887 dxil_nir_lower_deref_ssbo(nir_shader *nir)
888 {
889    bool progress = false;
890 
891    foreach_list_typed(nir_function, func, node, &nir->functions) {
892       if (!func->is_entrypoint)
893          continue;
894       assert(func->impl);
895 
896       nir_builder b;
897       nir_builder_init(&b, func->impl);
898 
899       nir_foreach_block(block, func->impl) {
900          nir_foreach_instr_safe(instr, block) {
901             if (instr->type != nir_instr_type_deref)
902                continue;
903 
904             nir_deref_instr *deref = nir_instr_as_deref(instr);
905 
906             if (!nir_deref_mode_is(deref, nir_var_mem_ssbo) ||
907                 (deref->deref_type != nir_deref_type_var &&
908                  deref->deref_type != nir_deref_type_cast))
909                continue;
910 
911             progress |= lower_deref_ssbo(&b, deref);
912          }
913       }
914    }
915 
916    return progress;
917 }
918 
919 static bool
920 lower_alu_deref_srcs(nir_builder *b, nir_alu_instr *alu)
921 {
922    const nir_op_info *info = &nir_op_infos[alu->op];
923    bool progress = false;
924 
925    b->cursor = nir_before_instr(&alu->instr);
926 
927    for (unsigned i = 0; i < info->num_inputs; i++) {
928       nir_deref_instr *deref = nir_src_as_deref(alu->src[i].src);
929 
930       if (!deref)
931          continue;
932 
933       nir_deref_path path;
934       nir_deref_path_init(&path, deref, NULL);
935       nir_deref_instr *root_deref = path.path[0];
936       nir_deref_path_finish(&path);
937 
938       if (root_deref->deref_type != nir_deref_type_cast)
939          continue;
940 
941       nir_ssa_def *ptr =
942          nir_iadd(b, root_deref->parent.ssa,
943                      nir_build_deref_offset(b, deref, cl_type_size_align));
944       nir_instr_rewrite_src(&alu->instr, &alu->src[i].src, nir_src_for_ssa(ptr));
945       progress = true;
946    }
947 
948    return progress;
949 }
950 
951 bool
952 dxil_nir_opt_alu_deref_srcs(nir_shader *nir)
953 {
954    bool progress = false;
955 
956    foreach_list_typed(nir_function, func, node, &nir->functions) {
957       if (!func->is_entrypoint)
958          continue;
959       assert(func->impl);
960 
961       nir_builder b;
962       nir_builder_init(&b, func->impl);
963 
964       nir_foreach_block(block, func->impl) {
965          nir_foreach_instr_safe(instr, block) {
966             if (instr->type != nir_instr_type_alu)
967                continue;
968 
969             nir_alu_instr *alu = nir_instr_as_alu(instr);
970             progress |= lower_alu_deref_srcs(&b, alu);
971          }
972       }
973    }
974 
975    return progress;
976 }
977 
978 static nir_ssa_def *
979 memcpy_load_deref_elem(nir_builder *b, nir_deref_instr *parent,
980                        nir_ssa_def *index)
981 {
982    nir_deref_instr *deref;
983 
984    index = nir_i2i(b, index, nir_dest_bit_size(parent->dest));
985    assert(parent->deref_type == nir_deref_type_cast);
986    deref = nir_build_deref_ptr_as_array(b, parent, index);
987 
988    return nir_load_deref(b, deref);
989 }
990 
991 static void
992 memcpy_store_deref_elem(nir_builder *b, nir_deref_instr *parent,
993                         nir_ssa_def *index, nir_ssa_def *value)
994 {
995    nir_deref_instr *deref;
996 
997    index = nir_i2i(b, index, nir_dest_bit_size(parent->dest));
998    assert(parent->deref_type == nir_deref_type_cast);
999    deref = nir_build_deref_ptr_as_array(b, parent, index);
1000    nir_store_deref(b, deref, value, 1);
1001 }
1002 
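/* memcpy_deref is lowered to a byte-wise copy loop: both derefs are cast to
 * uint8_t pointers and a NIR loop loads and stores one byte per iteration
 * until the (32-bit truncated) byte count is reached.
 */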
1003 static bool
1004 lower_memcpy_deref(nir_builder *b, nir_intrinsic_instr *intr)
1005 {
1006    nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
1007    nir_deref_instr *src_deref = nir_src_as_deref(intr->src[1]);
1008    assert(intr->src[2].is_ssa);
1009    nir_ssa_def *num_bytes = intr->src[2].ssa;
1010 
1011    assert(dst_deref && src_deref);
1012 
1013    b->cursor = nir_after_instr(&intr->instr);
1014 
1015    dst_deref = nir_build_deref_cast(b, &dst_deref->dest.ssa, dst_deref->modes,
1016                                        glsl_uint8_t_type(), 1);
1017    src_deref = nir_build_deref_cast(b, &src_deref->dest.ssa, src_deref->modes,
1018                                        glsl_uint8_t_type(), 1);
1019 
1020    /*
1021     * We want to avoid 64b instructions, so let's assume we'll always be
1022     * passed a value that fits in a 32b type and truncate the 64b value.
1023     */
1024    num_bytes = nir_u2u32(b, num_bytes);
1025 
1026    nir_variable *loop_index_var =
1027      nir_local_variable_create(b->impl, glsl_uint_type(), "loop_index");
1028    nir_deref_instr *loop_index_deref = nir_build_deref_var(b, loop_index_var);
1029    nir_store_deref(b, loop_index_deref, nir_imm_int(b, 0), 1);
1030 
1031    nir_loop *loop = nir_push_loop(b);
1032    nir_ssa_def *loop_index = nir_load_deref(b, loop_index_deref);
1033    nir_ssa_def *cmp = nir_ige(b, loop_index, num_bytes);
1034    nir_if *loop_check = nir_push_if(b, cmp);
1035    nir_jump(b, nir_jump_break);
1036    nir_pop_if(b, loop_check);
1037    nir_ssa_def *val = memcpy_load_deref_elem(b, src_deref, loop_index);
1038    memcpy_store_deref_elem(b, dst_deref, loop_index, val);
1039    nir_store_deref(b, loop_index_deref, nir_iadd_imm(b, loop_index, 1), 1);
1040    nir_pop_loop(b, loop);
1041    nir_instr_remove(&intr->instr);
1042    return true;
1043 }
1044 
1045 bool
1046 dxil_nir_lower_memcpy_deref(nir_shader *nir)
1047 {
1048    bool progress = false;
1049 
1050    foreach_list_typed(nir_function, func, node, &nir->functions) {
1051       if (!func->is_entrypoint)
1052          continue;
1053       assert(func->impl);
1054 
1055       nir_builder b;
1056       nir_builder_init(&b, func->impl);
1057 
1058       nir_foreach_block(block, func->impl) {
1059          nir_foreach_instr_safe(instr, block) {
1060             if (instr->type != nir_instr_type_intrinsic)
1061                continue;
1062 
1063             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1064 
1065             if (intr->intrinsic == nir_intrinsic_memcpy_deref)
1066                progress |= lower_memcpy_deref(&b, intr);
1067          }
1068       }
1069    }
1070 
1071    return progress;
1072 }
1073 
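/* Replace a phi of a narrow integer type with a phi of new_bit_size bits:
 * each source is zero-extended right after it is defined, the phi is rebuilt
 * at the wider size, and the result is truncated back to the original bit
 * size after the phis of the block.
 */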
1074 static void
1075 cast_phi(nir_builder *b, nir_phi_instr *phi, unsigned new_bit_size)
1076 {
1077    nir_phi_instr *lowered = nir_phi_instr_create(b->shader);
1078    int num_components = 0;
1079    int old_bit_size = phi->dest.ssa.bit_size;
1080 
1081    nir_op upcast_op = nir_type_conversion_op(nir_type_uint | old_bit_size,
1082                                              nir_type_uint | new_bit_size,
1083                                              nir_rounding_mode_undef);
1084    nir_op downcast_op = nir_type_conversion_op(nir_type_uint | new_bit_size,
1085                                                nir_type_uint | old_bit_size,
1086                                                nir_rounding_mode_undef);
1087 
1088    nir_foreach_phi_src(src, phi) {
1089       assert(num_components == 0 || num_components == src->src.ssa->num_components);
1090       num_components = src->src.ssa->num_components;
1091 
1092       b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);
1093 
1094       nir_ssa_def *cast = nir_build_alu(b, upcast_op, src->src.ssa, NULL, NULL, NULL);
1095       nir_phi_instr_add_src(lowered, src->pred, nir_src_for_ssa(cast));
1096    }
1097 
1098    nir_ssa_dest_init(&lowered->instr, &lowered->dest,
1099                      num_components, new_bit_size, NULL);
1100 
1101    b->cursor = nir_before_instr(&phi->instr);
1102    nir_builder_instr_insert(b, &lowered->instr);
1103 
1104    b->cursor = nir_after_phis(nir_cursor_current_block(b->cursor));
1105    nir_ssa_def *result = nir_build_alu(b, downcast_op, &lowered->dest.ssa, NULL, NULL, NULL);
1106 
1107    nir_ssa_def_rewrite_uses(&phi->dest.ssa, result);
1108    nir_instr_remove(&phi->instr);
1109 }
1110 
1111 static bool
1112 upcast_phi_impl(nir_function_impl *impl, unsigned min_bit_size)
1113 {
1114    nir_builder b;
1115    nir_builder_init(&b, impl);
1116    bool progress = false;
1117 
1118    nir_foreach_block_reverse(block, impl) {
1119       nir_foreach_instr_safe(instr, block) {
1120          if (instr->type != nir_instr_type_phi)
1121             continue;
1122 
1123          nir_phi_instr *phi = nir_instr_as_phi(instr);
1124          assert(phi->dest.is_ssa);
1125 
1126          if (phi->dest.ssa.bit_size == 1 ||
1127              phi->dest.ssa.bit_size >= min_bit_size)
1128             continue;
1129 
1130          cast_phi(&b, phi, min_bit_size);
1131          progress = true;
1132       }
1133    }
1134 
1135    if (progress) {
1136       nir_metadata_preserve(impl, nir_metadata_block_index |
1137                                   nir_metadata_dominance);
1138    } else {
1139       nir_metadata_preserve(impl, nir_metadata_all);
1140    }
1141 
1142    return progress;
1143 }
1144 
1145 bool
1146 dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size)
1147 {
1148    bool progress = false;
1149 
1150    nir_foreach_function(function, shader) {
1151       if (function->impl)
1152          progress |= upcast_phi_impl(function->impl, min_bit_size);
1153    }
1154 
1155    return progress;
1156 }
1157 
1158 struct dxil_nir_split_clip_cull_distance_params {
1159    nir_variable *new_var;
1160    nir_shader *shader;
1161 };
1162 
1163 /* In GLSL and SPIR-V, clip and cull distance are arrays of floats (with a limit of 8).
1164  * In DXIL, clip and cull distances are up to 2 float4s combined.
1165  * Coming from GLSL, we can request this 2 float4 format, but coming from SPIR-V,
1166  * we can't, and have to accept a "compact" array of scalar floats.
1167  *
1168  * To help emitting a valid input signature for this case, split the variables so that they
1169  * match what we need to put in the signature (e.g. { float clip[4]; float clip1; float cull[3]; })
1170  */
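/* For example, a compact float[7] clip/cull array with location_frac == 0 is
 * split into a float[4] kept at VARYING_SLOT_CLIP_DIST0 and a new float[3]
 * variable at VARYING_SLOT_CLIP_DIST1; array derefs whose index falls in the
 * second half are redirected to the new variable with index % 4.
 */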
1171 static bool
1172 dxil_nir_split_clip_cull_distance_instr(nir_builder *b,
1173                                         nir_instr *instr,
1174                                         void *cb_data)
1175 {
1176    struct dxil_nir_split_clip_cull_distance_params *params = cb_data;
1177    nir_variable *new_var = params->new_var;
1178 
1179    if (instr->type != nir_instr_type_deref)
1180       return false;
1181 
1182    nir_deref_instr *deref = nir_instr_as_deref(instr);
1183    nir_variable *var = nir_deref_instr_get_variable(deref);
1184    if (!var ||
1185        var->data.location < VARYING_SLOT_CLIP_DIST0 ||
1186        var->data.location > VARYING_SLOT_CULL_DIST1 ||
1187        !var->data.compact)
1188       return false;
1189 
1190    /* The location should only be inside clip distance, because clip
1191     * and cull should've been merged by nir_lower_clip_cull_distance_arrays()
1192     */
1193    assert(var->data.location == VARYING_SLOT_CLIP_DIST0 ||
1194           var->data.location == VARYING_SLOT_CLIP_DIST1);
1195 
1196    /* The deref chain to the clip/cull variables should be simple, just the
1197     * var and an array with a constant index, otherwise more lowering/optimization
1198     * might be needed before this pass, e.g. copy prop, lower_io_to_temporaries,
1199     * split_var_copies, and/or lower_var_copies. In the case of arrayed I/O like
1200     * inputs to the tessellation or geometry stages, there might be a second level
1201     * of array index.
1202     */
1203    assert(deref->deref_type == nir_deref_type_var ||
1204           deref->deref_type == nir_deref_type_array);
1205 
1206    b->cursor = nir_before_instr(instr);
1207    unsigned arrayed_io_length = 0;
1208    const struct glsl_type *old_type = var->type;
1209    if (nir_is_arrayed_io(var, b->shader->info.stage)) {
1210       arrayed_io_length = glsl_array_size(old_type);
1211       old_type = glsl_get_array_element(old_type);
1212    }
1213    if (!new_var) {
1214       /* Update lengths for new and old vars */
1215       int old_length = glsl_array_size(old_type);
1216       int new_length = (old_length + var->data.location_frac) - 4;
1217       old_length -= new_length;
1218 
1219       /* The existing variable fits in the float4 */
1220       if (new_length <= 0)
1221          return false;
1222 
1223       new_var = nir_variable_clone(var, params->shader);
1224       nir_shader_add_variable(params->shader, new_var);
1225       assert(glsl_get_base_type(glsl_get_array_element(old_type)) == GLSL_TYPE_FLOAT);
1226       var->type = glsl_array_type(glsl_float_type(), old_length, 0);
1227       new_var->type = glsl_array_type(glsl_float_type(), new_length, 0);
1228       if (arrayed_io_length) {
1229          var->type = glsl_array_type(var->type, arrayed_io_length, 0);
1230          new_var->type = glsl_array_type(new_var->type, arrayed_io_length, 0);
1231       }
1232       new_var->data.location++;
1233       new_var->data.location_frac = 0;
1234       params->new_var = new_var;
1235    }
1236 
1237    /* Update the type for derefs of the old var */
1238    if (deref->deref_type == nir_deref_type_var) {
1239       deref->type = var->type;
1240       return false;
1241    }
1242 
1243    if (glsl_type_is_array(deref->type)) {
1244       assert(arrayed_io_length > 0);
1245       deref->type = glsl_get_array_element(var->type);
1246       return false;
1247    }
1248 
1249    assert(glsl_get_base_type(deref->type) == GLSL_TYPE_FLOAT);
1250 
1251    nir_const_value *index = nir_src_as_const_value(deref->arr.index);
1252    assert(index);
1253 
1254    /* Treat this array as a vector starting at the component index in location_frac,
1255     * so if location_frac is 1 and index is 0, then it's accessing the 'y' component
1256     * of the vector. If index + location_frac is >= 4, there's no component there,
1257     * so we need to add a new variable and adjust the index.
1258     */
1259    unsigned total_index = index->u32 + var->data.location_frac;
1260    if (total_index < 4)
1261       return false;
1262 
1263    nir_deref_instr *new_var_deref = nir_build_deref_var(b, new_var);
1264    nir_deref_instr *new_intermediate_deref = new_var_deref;
1265    if (arrayed_io_length) {
1266       nir_deref_instr *parent = nir_src_as_deref(deref->parent);
1267       assert(parent->deref_type == nir_deref_type_array);
1268       new_intermediate_deref = nir_build_deref_array(b, new_intermediate_deref, parent->arr.index.ssa);
1269    }
1270    nir_deref_instr *new_array_deref = nir_build_deref_array(b, new_intermediate_deref, nir_imm_int(b, total_index % 4));
1271    nir_ssa_def_rewrite_uses(&deref->dest.ssa, &new_array_deref->dest.ssa);
1272    return true;
1273 }
1274 
1275 bool
1276 dxil_nir_split_clip_cull_distance(nir_shader *shader)
1277 {
1278    struct dxil_nir_split_clip_cull_distance_params params = {
1279       .new_var = NULL,
1280       .shader = shader,
1281    };
1282    nir_shader_instructions_pass(shader,
1283                                 dxil_nir_split_clip_cull_distance_instr,
1284                                 nir_metadata_block_index |
1285                                 nir_metadata_dominance |
1286                                 nir_metadata_loop_analysis,
1287                                 &params);
1288    return params.new_var != NULL;
1289 }
1290 
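/* DXIL uses its own packed representation for doubles. For float ALU
 * instructions with 64-bit sources or destinations, each 64-bit value is
 * split into its two 32-bit halves and converted with
 * pack_double_2x32_dxil / unpack_double_2x32_dxil around the instruction.
 */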
1291 static bool
1292 dxil_nir_lower_double_math_instr(nir_builder *b,
1293                                  nir_instr *instr,
1294                                  UNUSED void *cb_data)
1295 {
1296    if (instr->type != nir_instr_type_alu)
1297       return false;
1298 
1299    nir_alu_instr *alu = nir_instr_as_alu(instr);
1300 
1301    /* TODO: See if we can apply this explicitly to packs/unpacks that are then
1302     * used as a double. As-is, if we had an app explicitly do a 64bit integer op,
1303     * then try to bitcast to double (not expressible in HLSL, but it is in other
1304     * source languages), this would unpack the integer and repack as a double, when
1305     * we probably want to just send the bitcast through to the backend.
1306     */
1307 
1308    b->cursor = nir_before_instr(&alu->instr);
1309 
1310    bool progress = false;
1311    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
1312       if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
1313           alu->src[i].src.ssa->bit_size == 64) {
1314          unsigned num_components = nir_op_infos[alu->op].input_sizes[i];
1315          if (!num_components)
1316             num_components = alu->dest.dest.ssa.num_components;
1317          nir_ssa_def *components[NIR_MAX_VEC_COMPONENTS];
1318          for (unsigned c = 0; c < num_components; ++c) {
1319             nir_ssa_def *packed_double = nir_channel(b, alu->src[i].src.ssa, alu->src[i].swizzle[c]);
1320             nir_ssa_def *unpacked_double = nir_unpack_64_2x32(b, packed_double);
1321             components[c] = nir_pack_double_2x32_dxil(b, unpacked_double);
1322             alu->src[i].swizzle[c] = c;
1323          }
1324          nir_instr_rewrite_src_ssa(instr, &alu->src[i].src, nir_vec(b, components, num_components));
1325          progress = true;
1326       }
1327    }
1328 
1329    if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
1330        alu->dest.dest.ssa.bit_size == 64) {
1331       b->cursor = nir_after_instr(&alu->instr);
1332       nir_ssa_def *components[NIR_MAX_VEC_COMPONENTS];
1333       for (unsigned c = 0; c < alu->dest.dest.ssa.num_components; ++c) {
1334          nir_ssa_def *packed_double = nir_channel(b, &alu->dest.dest.ssa, c);
1335          nir_ssa_def *unpacked_double = nir_unpack_double_2x32_dxil(b, packed_double);
1336          components[c] = nir_pack_64_2x32(b, unpacked_double);
1337       }
1338       nir_ssa_def *repacked_dvec = nir_vec(b, components, alu->dest.dest.ssa.num_components);
1339       nir_ssa_def_rewrite_uses_after(&alu->dest.dest.ssa, repacked_dvec, repacked_dvec->parent_instr);
1340       progress = true;
1341    }
1342 
1343    return progress;
1344 }
1345 
1346 bool
1347 dxil_nir_lower_double_math(nir_shader *shader)
1348 {
1349    return nir_shader_instructions_pass(shader,
1350                                        dxil_nir_lower_double_math_instr,
1351                                        nir_metadata_block_index |
1352                                        nir_metadata_dominance |
1353                                        nir_metadata_loop_analysis,
1354                                        NULL);
1355 }
1356 
1357 typedef struct {
1358    gl_system_value *values;
1359    uint32_t count;
1360 } zero_system_values_state;
1361 
1362 static bool
1363 lower_system_value_to_zero_filter(const nir_instr* instr, const void* cb_state)
1364 {
1365    if (instr->type != nir_instr_type_intrinsic) {
1366       return false;
1367    }
1368 
1369    nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);
1370 
1371    /* All the intrinsics we care about are loads */
1372    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
1373       return false;
1374 
1375    assert(intrin->dest.is_ssa);
1376 
1377    zero_system_values_state* state = (zero_system_values_state*)cb_state;
1378    for (uint32_t i = 0; i < state->count; ++i) {
1379       gl_system_value value = state->values[i];
1380       nir_intrinsic_op value_op = nir_intrinsic_from_system_value(value);
1381 
1382       if (intrin->intrinsic == value_op) {
1383          return true;
1384       } else if (intrin->intrinsic == nir_intrinsic_load_deref) {
1385          nir_deref_instr* deref = nir_src_as_deref(intrin->src[0]);
1386          if (!nir_deref_mode_is(deref, nir_var_system_value))
1387             return false;
1388 
1389          nir_variable* var = deref->var;
1390          if (var->data.location == value) {
1391             return true;
1392          }
1393       }
1394    }
1395 
1396    return false;
1397 }
1398 
1399 static nir_ssa_def*
1400 lower_system_value_to_zero_instr(nir_builder* b, nir_instr* instr, void* _state)
1401 {
1402    return nir_imm_int(b, 0);
1403 }
1404 
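/* Rewrites loads of the caller-provided system values to a constant 0.
 * Both direct sysval intrinsics and load_deref of nir_var_system_value
 * variables whose location matches one of the requested values are handled.
 */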
1405 bool
1406 dxil_nir_lower_system_values_to_zero(nir_shader* shader,
1407                                      gl_system_value* system_values,
1408                                      uint32_t count)
1409 {
1410    zero_system_values_state state = { system_values, count };
1411    return nir_shader_lower_instructions(shader,
1412       lower_system_value_to_zero_filter,
1413       lower_system_value_to_zero_instr,
1414       &state);
1415 }
1416 
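/* The workgroup size is known at compile time here, so load_workgroup_size is
 * replaced with an immediate 3-component 32-bit vector built from
 * shader->info.workgroup_size.
 */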
1417 static void
1418 lower_load_local_group_size(nir_builder *b, nir_intrinsic_instr *intr)
1419 {
1420    b->cursor = nir_after_instr(&intr->instr);
1421 
1422    nir_const_value v[3] = {
1423       nir_const_value_for_int(b->shader->info.workgroup_size[0], 32),
1424       nir_const_value_for_int(b->shader->info.workgroup_size[1], 32),
1425       nir_const_value_for_int(b->shader->info.workgroup_size[2], 32)
1426    };
1427    nir_ssa_def *size = nir_build_imm(b, 3, 32, v);
1428    nir_ssa_def_rewrite_uses(&intr->dest.ssa, size);
1429    nir_instr_remove(&intr->instr);
1430 }
1431 
1432 static bool
1433 lower_system_values_impl(nir_builder *b, nir_instr *instr, void *_state)
1434 {
1435    if (instr->type != nir_instr_type_intrinsic)
1436       return false;
1437    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1438    switch (intr->intrinsic) {
1439    case nir_intrinsic_load_workgroup_size:
1440       lower_load_local_group_size(b, intr);
1441       return true;
1442    default:
1443       return false;
1444    }
1445 }
1446 
1447 bool
1448 dxil_nir_lower_system_values(nir_shader *shader)
1449 {
1450    return nir_shader_instructions_pass(shader, lower_system_values_impl,
1451       nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, NULL);
1452 }
1453 
1454 static const struct glsl_type *
1455 get_bare_samplers_for_type(const struct glsl_type *type, bool is_shadow)
1456 {
1457    const struct glsl_type *base_sampler_type =
1458       is_shadow ?
1459       glsl_bare_shadow_sampler_type() : glsl_bare_sampler_type();
1460    return glsl_type_wrap_in_arrays(base_sampler_type, type);
1461 }
1462 
1463 static const struct glsl_type *
1464 get_textures_for_sampler_type(const struct glsl_type *type)
1465 {
1466    return glsl_type_wrap_in_arrays(
1467       glsl_sampler_type_to_texture(
1468          glsl_without_array(type)), type);
1469 }
1470 
1471 static bool
1472 redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1473 {
1474    if (instr->type != nir_instr_type_tex)
1475       return false;
1476 
1477    nir_tex_instr *tex = nir_instr_as_tex(instr);
1478 
1479    int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
1480    if (sampler_idx == -1) {
1481       /* No sampler deref - does this instruction even need a sampler? If not,
1482        * sampler_index doesn't necessarily point to a sampler, so early-out.
1483        */
1484       if (!nir_tex_instr_need_sampler(tex))
1485          return false;
1486 
1487       /* No derefs but needs a sampler, must be using indices */
1488       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->sampler_index);
1489 
1490       /* Already have a bare sampler here */
1491       if (bare_sampler)
1492          return false;
1493 
1494       nir_variable *old_sampler = NULL;
1495       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1496          if (var->data.binding <= tex->sampler_index &&
1497              var->data.binding + glsl_type_get_sampler_count(var->type) >
1498                 tex->sampler_index) {
1499 
1500             /* If we already have a bare sampler of the correct type for
1501              * this binding, add it to the table */
1502             if (glsl_type_is_bare_sampler(glsl_without_array(var->type)) &&
1503                 glsl_sampler_type_is_shadow(glsl_without_array(var->type)) ==
1504                    tex->is_shadow) {
1505                _mesa_hash_table_u64_insert(data, tex->sampler_index, var);
1506                return false;
1507             }
1508 
1509             old_sampler = var;
1510          }
1511       }
1512 
1513       assert(old_sampler);
1514 
1515       /* Clone the original sampler to a bare sampler of the correct type */
1516       bare_sampler = nir_variable_clone(old_sampler, b->shader);
1517       nir_shader_add_variable(b->shader, bare_sampler);
1518 
1519       bare_sampler->type =
1520          get_bare_samplers_for_type(old_sampler->type, tex->is_shadow);
1521       _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler);
1522       return true;
1523    }
1524 
1525    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1526    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src);
1527    nir_deref_path path;
1528    nir_deref_path_init(&path, final_deref, NULL);
1529 
1530    nir_deref_instr *old_tail = path.path[0];
1531    assert(old_tail->deref_type == nir_deref_type_var);
1532    nir_variable *old_var = old_tail->var;
1533    if (glsl_type_is_bare_sampler(glsl_without_array(old_var->type)) &&
1534        glsl_sampler_type_is_shadow(glsl_without_array(old_var->type)) ==
1535           tex->is_shadow) {
1536       nir_deref_path_finish(&path);
1537       return false;
1538    }
1539 
1540    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1541                       old_var->data.binding;
1542    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1543    if (!new_var) {
1544       new_var = nir_variable_clone(old_var, b->shader);
1545       nir_shader_add_variable(b->shader, new_var);
1546       new_var->type =
1547          get_bare_samplers_for_type(old_var->type, tex->is_shadow);
1548       _mesa_hash_table_u64_insert(data, var_key, new_var);
1549    }
1550 
1551    b->cursor = nir_after_instr(&old_tail->instr);
1552    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1553 
1554    for (unsigned i = 1; path.path[i]; ++i) {
1555       b->cursor = nir_after_instr(&path.path[i]->instr);
1556       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1557    }
1558 
1559    nir_deref_path_finish(&path);
1560    nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[sampler_idx].src, &new_tail->dest.ssa);
1561    return true;
1562 }
1563 
1564 static bool
1565 redirect_texture_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1566 {
1567    if (instr->type != nir_instr_type_tex)
1568       return false;
1569 
1570    nir_tex_instr *tex = nir_instr_as_tex(instr);
1571 
1572    int texture_idx = nir_tex_instr_src_index(tex, nir_tex_src_texture_deref);
1573    if (texture_idx == -1) {
1574       /* No derefs, must be using indices */
1575       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->texture_index);
1576 
1577       /* Already have a texture here */
1578       if (bare_sampler)
1579          return false;
1580 
1581       nir_variable *typed_sampler = NULL;
1582       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1583          if (var->data.binding <= tex->texture_index &&
1584              var->data.binding + glsl_type_get_texture_count(var->type) > tex->texture_index) {
1585             /* Already have a texture for this binding, add it to the table */
1586             _mesa_hash_table_u64_insert(data, tex->texture_index, var);
1587             return false;
1588          }
1589 
1590          if (var->data.binding <= tex->texture_index &&
1591              var->data.binding + glsl_type_get_sampler_count(var->type) > tex->texture_index &&
1592              !glsl_type_is_bare_sampler(glsl_without_array(var->type))) {
1593             typed_sampler = var;
1594          }
1595       }
1596 
1597       /* Clone the typed sampler to a texture and we're done */
1598       assert(typed_sampler);
1599       bare_sampler = nir_variable_clone(typed_sampler, b->shader);
1600       bare_sampler->type = get_textures_for_sampler_type(typed_sampler->type);
1601       nir_shader_add_variable(b->shader, bare_sampler);
1602       _mesa_hash_table_u64_insert(data, tex->texture_index, bare_sampler);
1603       return true;
1604    }
1605 
1606    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1607    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[texture_idx].src);
1608    nir_deref_path path;
1609    nir_deref_path_init(&path, final_deref, NULL);
1610 
1611    nir_deref_instr *old_tail = path.path[0];
1612    assert(old_tail->deref_type == nir_deref_type_var);
1613    nir_variable *old_var = old_tail->var;
1614    if (glsl_type_is_texture(glsl_without_array(old_var->type)) ||
1615        glsl_type_is_image(glsl_without_array(old_var->type))) {
1616       nir_deref_path_finish(&path);
1617       return false;
1618    }
1619 
1620    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1621                       old_var->data.binding;
1622    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1623    if (!new_var) {
1624       new_var = nir_variable_clone(old_var, b->shader);
1625       new_var->type = get_textures_for_sampler_type(old_var->type);
1626       nir_shader_add_variable(b->shader, new_var);
1627       _mesa_hash_table_u64_insert(data, var_key, new_var);
1628    }
1629 
1630    b->cursor = nir_after_instr(&old_tail->instr);
1631    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1632 
1633    for (unsigned i = 1; path.path[i]; ++i) {
1634       b->cursor = nir_after_instr(&path.path[i]->instr);
1635       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1636    }
1637 
1638    nir_deref_path_finish(&path);
1639    nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[texture_idx].src, &new_tail->dest.ssa);
1640 
1641    return true;
1642 }
1643 
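/* Splits GLSL-style combined texture/samplers into separate texture and bare
 * sampler variables, as DXIL's separate SRV/sampler binding model expects:
 * the first pass clones a bare (possibly shadow) sampler for every sampler
 * use, the second clones a texture-typed variable for every texture use, and
 * deref chains are rewritten to point at the clones. Clones are cached in a
 * u64 hash table keyed by binding index (or descriptor_set << 32 | binding
 * for deref-based access). A typical invocation would look something like
 * NIR_PASS(progress, nir, dxil_nir_split_typed_samplers); the exact call
 * site depends on the driver (illustrative only).
 */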
1644 bool
1645 dxil_nir_split_typed_samplers(nir_shader *nir)
1646 {
1647    struct hash_table_u64 *hash_table = _mesa_hash_table_u64_create(NULL);
1648 
1649    bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs,
1650       nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, hash_table);
1651 
1652    _mesa_hash_table_u64_clear(hash_table);
1653 
1654    progress |= nir_shader_instructions_pass(nir, redirect_texture_derefs,
1655       nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, hash_table);
1656 
1657    _mesa_hash_table_u64_destroy(hash_table);
1658    return progress;
1659 }
1660 
1661 
1662 static bool
1663 lower_bool_input_filter(const nir_instr *instr,
1664                         UNUSED const void *_options)
1665 {
1666    if (instr->type != nir_instr_type_intrinsic)
1667       return false;
1668 
1669    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1670    if (intr->intrinsic == nir_intrinsic_load_front_face)
1671       return true;
1672 
1673    if (intr->intrinsic == nir_intrinsic_load_deref) {
1674       nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
1675       nir_variable *var = nir_deref_instr_get_variable(deref);
1676       return var->data.mode == nir_var_shader_in &&
1677              glsl_get_base_type(var->type) == GLSL_TYPE_BOOL;
1678    }
1679 
1680    return false;
1681 }
1682 
1683 static nir_ssa_def *
1684 lower_bool_input_impl(nir_builder *b, nir_instr *instr,
1685                       UNUSED void *_options)
1686 {
1687    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1688 
1689    if (intr->intrinsic == nir_intrinsic_load_deref) {
1690       nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
1691       nir_variable *var = nir_deref_instr_get_variable(deref);
1692 
1693       /* rewrite var->type */
1694       var->type = glsl_vector_type(GLSL_TYPE_UINT,
1695                                    glsl_get_vector_elements(var->type));
1696       deref->type = var->type;
1697    }
1698 
1699    intr->dest.ssa.bit_size = 32;
1700    return nir_i2b1(b, &intr->dest.ssa);
1701 }
1702 
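/* Boolean inputs (front-face loads and load_deref of boolean shader inputs)
 * are given a 32-bit destination and converted back to a 1-bit boolean with
 * nir_i2b1; the variable and deref types are rewritten to uint vectors of
 * the same width.
 */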
1703 bool
1704 dxil_nir_lower_bool_input(struct nir_shader *s)
1705 {
1706    return nir_shader_lower_instructions(s, lower_bool_input_filter,
1707                                         lower_bool_input_impl, NULL);
1708 }
1709 
1710 static bool
1711 lower_sysval_to_load_input_impl(nir_builder *b, nir_instr *instr, void *data)
1712 {
1713    if (instr->type != nir_instr_type_intrinsic)
1714       return false;
1715 
1716    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1717    gl_system_value sysval = SYSTEM_VALUE_MAX;
1718    switch (intr->intrinsic) {
1719    case nir_intrinsic_load_front_face:
1720       sysval = SYSTEM_VALUE_FRONT_FACE;
1721       break;
1722    case nir_intrinsic_load_instance_id:
1723       sysval = SYSTEM_VALUE_INSTANCE_ID;
1724       break;
1725    case nir_intrinsic_load_vertex_id_zero_base:
1726       sysval = SYSTEM_VALUE_VERTEX_ID_ZERO_BASE;
1727       break;
1728    default:
1729       return false;
1730    }
1731 
1732    nir_variable **sysval_vars = (nir_variable **)data;
1733    nir_variable *var = sysval_vars[sysval];
1734    assert(var);
1735 
1736    b->cursor = nir_before_instr(instr);
1737    nir_ssa_def *result = nir_build_load_input(b, intr->dest.ssa.num_components, intr->dest.ssa.bit_size, nir_imm_int(b, 0),
1738       .base = var->data.driver_location, .dest_type = nir_get_nir_type_for_glsl_type(var->type));
1739    nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
1740    return true;
1741 }
1742 
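/* Turns front_face, instance_id and vertex_id_zero_base intrinsics into plain
 * load_input intrinsics that read the matching variable from sysval_vars,
 * using that variable's driver_location as the base.
 */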
1743 bool
1744 dxil_nir_lower_sysval_to_load_input(nir_shader *s, nir_variable **sysval_vars)
1745 {
1746    return nir_shader_instructions_pass(s, lower_sysval_to_load_input_impl,
1747       nir_metadata_block_index | nir_metadata_dominance, sysval_vars);
1748 }
1749 
1750 /* Comparison function to sort IO values so that normal varyings come first,
1751  * then system values, and then system-generated values.
1752  */
1753 static int
1754 variable_location_cmp(const nir_variable* a, const nir_variable* b)
1755 {
1756    // Sort by stream, driver_location, location, location_frac, then index
1757    unsigned a_location = a->data.location;
1758    if (a_location >= VARYING_SLOT_PATCH0)
1759       a_location -= VARYING_SLOT_PATCH0;
1760    unsigned b_location = b->data.location;
1761    if (b_location >= VARYING_SLOT_PATCH0)
1762       b_location -= VARYING_SLOT_PATCH0;
1763    unsigned a_stream = a->data.stream & ~NIR_STREAM_PACKED;
1764    unsigned b_stream = b->data.stream & ~NIR_STREAM_PACKED;
1765    return a_stream != b_stream ?
1766             a_stream - b_stream :
1767             a->data.driver_location != b->data.driver_location ?
1768                a->data.driver_location - b->data.driver_location :
1769                a_location !=  b_location ?
1770                   a_location - b_location :
1771                   a->data.location_frac != b->data.location_frac ?
1772                      a->data.location_frac - b->data.location_frac :
1773                      a->data.index - b->data.index;
1774 }
1775 
1776 /* Order varyings according to driver location */
1777 uint64_t
1778 dxil_sort_by_driver_location(nir_shader* s, nir_variable_mode modes)
1779 {
1780    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1781 
1782    uint64_t result = 0;
1783    nir_foreach_variable_with_modes(var, s, modes) {
1784       result |= 1ull << var->data.location;
1785    }
1786    return result;
1787 }
1788 
1789 /* Sort PS outputs so that color outputs come first */
1790 void
1791 dxil_sort_ps_outputs(nir_shader* s)
1792 {
1793    nir_foreach_variable_with_modes_safe(var, s, nir_var_shader_out) {
1794       /* We use driver_location here to avoid introducing a new struct or
1795        * member variable. The true, updated driver location will be
1796        * written below, after sorting. */
1797       switch (var->data.location) {
1798       case FRAG_RESULT_DEPTH:
1799          var->data.driver_location = 1;
1800          break;
1801       case FRAG_RESULT_STENCIL:
1802          var->data.driver_location = 2;
1803          break;
1804       case FRAG_RESULT_SAMPLE_MASK:
1805          var->data.driver_location = 3;
1806          break;
1807       default:
1808          var->data.driver_location = 0;
1809       }
1810    }
1811 
1812    nir_sort_variables_with_modes(s, variable_location_cmp,
1813                                  nir_var_shader_out);
1814 
1815    unsigned driver_loc = 0;
1816    nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
1817       var->data.driver_location = driver_loc++;
1818    }
1819 }
1820 
1821 /* Order inter-stage values so that normal varyings come first,
1822  * then sysvals, and then system-generated values.
1823  */
1824 uint64_t
1825 dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
1826    uint64_t other_stage_mask)
1827 {
1828    nir_foreach_variable_with_modes_safe(var, s, modes) {
1829       /* We use driver_location here to avoid introducing a new struct or
1830        * member variable. The true, updated driver location will be
1831        * written below, after sorting. */
1832       var->data.driver_location = nir_var_to_dxil_sysvalue_type(var, other_stage_mask);
1833    }
1834 
1835    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1836 
1837    uint64_t result = 0;
1838    unsigned driver_loc = 0, driver_patch_loc = 0;
1839    nir_foreach_variable_with_modes(var, s, modes) {
1840       if (var->data.location < 64)
1841          result |= 1ull << var->data.location;
1842       /* Overlap patches with non-patch */
1843       var->data.driver_location = var->data.patch ?
1844          driver_patch_loc++ : driver_loc++;
1845    }
1846    return result;
1847 }
1848 
1849 static bool
1850 lower_ubo_array_one_to_static(struct nir_builder *b, nir_instr *inst,
1851                               void *cb_data)
1852 {
1853    if (inst->type != nir_instr_type_intrinsic)
1854       return false;
1855 
1856    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(inst);
1857 
1858    if (intrin->intrinsic != nir_intrinsic_load_vulkan_descriptor)
1859       return false;
1860 
1861    nir_variable *var =
1862       nir_get_binding_variable(b->shader, nir_chase_binding(intrin->src[0]));
1863 
1864    if (!var)
1865       return false;
1866 
1867    if (!glsl_type_is_array(var->type) || glsl_array_size(var->type) != 1)
1868       return false;
1869 
1870    nir_intrinsic_instr *index = nir_src_as_intrinsic(intrin->src[0]);
1871    /* We currently do not support reindex */
1872    assert(index && index->intrinsic == nir_intrinsic_vulkan_resource_index);
1873 
1874    if (nir_src_is_const(index->src[0]) && nir_src_as_uint(index->src[0]) == 0)
1875       return false;
1876 
1877    if (nir_intrinsic_desc_type(index) != VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
1878       return false;
1879 
1880    b->cursor = nir_instr_remove(&index->instr);
1881 
1882    // Indexing out of bounds on an array of UBOs is considered undefined
1883    // behavior. Therefore, we just hardcode the index to 0.
1884    uint8_t bit_size = index->dest.ssa.bit_size;
1885    nir_ssa_def *zero = nir_imm_intN_t(b, 0, bit_size);
1886    nir_ssa_def *dest =
1887       nir_vulkan_resource_index(b, index->num_components, bit_size, zero,
1888                                 .desc_set = nir_intrinsic_desc_set(index),
1889                                 .binding = nir_intrinsic_binding(index),
1890                                 .desc_type = nir_intrinsic_desc_type(index));
1891 
1892    nir_ssa_def_rewrite_uses(&index->dest.ssa, dest);
1893 
1894    return true;
1895 }
1896 
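/* For UBO arrays with exactly one element, any dynamic index can only legally
 * select element 0 (out-of-bounds access is undefined), so the
 * vulkan_resource_index source is replaced with a constant 0 and the
 * descriptor access becomes static.
 */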
1897 bool
1898 dxil_nir_lower_ubo_array_one_to_static(nir_shader *s)
1899 {
1900    bool progress = nir_shader_instructions_pass(
1901       s, lower_ubo_array_one_to_static, nir_metadata_none, NULL);
1902 
1903    return progress;
1904 }
1905 
1906 static bool
1907 is_fquantize2f16(const nir_instr *instr, const void *data)
1908 {
1909    if (instr->type != nir_instr_type_alu)
1910       return false;
1911 
1912    nir_alu_instr *alu = nir_instr_as_alu(instr);
1913    return alu->op == nir_op_fquantize2f16;
1914 }
1915 
1916 static nir_ssa_def *
1917 lower_fquantize2f16(struct nir_builder *b, nir_instr *instr, void *data)
1918 {
1919    /*
1920     * SpvOpQuantizeToF16 documentation says:
1921     *
1922     * "
1923     * If Value is an infinity, the result is the same infinity.
1924     * If Value is a NaN, the result is a NaN, but not necessarily the same NaN.
1925     * If Value is positive with a magnitude too large to represent as a 16-bit
1926     * floating-point value, the result is positive infinity. If Value is negative
1927     * with a magnitude too large to represent as a 16-bit floating-point value,
1928     * the result is negative infinity. If the magnitude of Value is too small to
1929     * represent as a normalized 16-bit floating-point value, the result may be
1930     * either +0 or -0.
1931     * "
1932     *
1933     * which we turn into:
1934     *
1935     *   if (val < MIN_FLOAT16)
1936     *      return -INFINITY;
1937     *   else if (val > MAX_FLOAT16)
1938     *      return +INFINITY;
1939     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) != 0)
1940     *      return -0.0f;
1941     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) == 0)
1942     *      return +0.0f;
1943     *   else
1944     *      return round(val);
1945     */
1946    nir_alu_instr *alu = nir_instr_as_alu(instr);
1947    nir_ssa_def *src =
1948       nir_ssa_for_src(b, alu->src[0].src, nir_src_num_components(alu->src[0].src));
1949 
1950    nir_ssa_def *neg_inf_cond =
1951       nir_flt(b, src, nir_imm_float(b, -65504.0f));
1952    nir_ssa_def *pos_inf_cond =
1953       nir_flt(b, nir_imm_float(b, 65504.0f), src);
1954    nir_ssa_def *zero_cond =
1955       nir_flt(b, nir_fabs(b, src), nir_imm_float(b, ldexpf(1.0, -14)));
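   /* "zero" keeps only the sign bit (the flush-to-signed-zero case) and
    * "round" clears the low 13 mantissa bits: float32 carries 23 mantissa
    * bits and float16 only 10, so masking them off quantizes the magnitude.
    */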
1956    nir_ssa_def *zero = nir_iand_imm(b, src, 1 << 31);
1957    nir_ssa_def *round = nir_iand_imm(b, src, ~BITFIELD_MASK(13));
1958 
1959    nir_ssa_def *res =
1960       nir_bcsel(b, neg_inf_cond, nir_imm_float(b, -INFINITY), round);
1961    res = nir_bcsel(b, pos_inf_cond, nir_imm_float(b, INFINITY), res);
1962    res = nir_bcsel(b, zero_cond, zero, res);
1963    return res;
1964 }
1965 
1966 bool
1967 dxil_nir_lower_fquantize2f16(nir_shader *s)
1968 {
1969    return nir_shader_lower_instructions(s, is_fquantize2f16, lower_fquantize2f16, NULL);
1970 }
1971 
1972 static bool
1973 fix_io_uint_deref_types(struct nir_builder *builder, nir_instr *instr, void *data)
1974 {
1975    if (instr->type != nir_instr_type_deref)
1976       return false;
1977 
1978    nir_deref_instr *deref = nir_instr_as_deref(instr);
1979    nir_variable *var =
1980       deref->deref_type == nir_deref_type_var ? deref->var : NULL;
1981 
1982    if (var == data) {
1983       deref->type = var->type;
1984       return true;
1985    }
1986 
1987    return false;
1988 }
1989 
1990 static bool
1991 fix_io_uint_type(nir_shader *s, nir_variable_mode modes, int slot)
1992 {
1993    nir_variable *fixed_var = NULL;
1994    nir_foreach_variable_with_modes(var, s, modes) {
1995       if (var->data.location == slot) {
1996          if (var->type == glsl_uint_type())
1997             return false;
1998 
1999          assert(var->type == glsl_int_type());
2000          var->type = glsl_uint_type();
2001          fixed_var = var;
2002          break;
2003       }
2004    }
2005 
2006    assert(fixed_var);
2007 
2008    return nir_shader_instructions_pass(s, fix_io_uint_deref_types,
2009                                        nir_metadata_all, fixed_var);
2010 }
2011 
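/* For every slot in in_mask/out_mask that the shader actually reads or
 * writes, an input/output declared as int is retyped to uint and all var
 * derefs of it are updated, presumably so the declared type matches what the
 * adjacent stage or the runtime expects.
 */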
2012 bool
2013 dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mask)
2014 {
2015    if (!(s->info.outputs_written & out_mask) &&
2016        !(s->info.inputs_read & in_mask))
2017       return false;
2018 
2019    bool progress = false;
2020 
2021    while (in_mask) {
2022       int slot = u_bit_scan64(&in_mask);
2023       progress |= (s->info.inputs_read & (1ull << slot)) &&
2024                   fix_io_uint_type(s, nir_var_shader_in, slot);
2025    }
2026 
2027    while (out_mask) {
2028       int slot = u_bit_scan64(&out_mask);
2029       progress |= (s->info.outputs_written & (1ull << slot)) &&
2030                   fix_io_uint_type(s, nir_var_shader_out, slot);
2031    }
2032 
2033    return progress;
2034 }
2035 
2036 struct remove_after_discard_state {
2037    struct nir_block *active_block;
2038 };
2039 
2040 static bool
2041 remove_after_discard(struct nir_builder *builder, nir_instr *instr,
2042                       void *cb_data)
2043 {
2044    struct remove_after_discard_state *state = cb_data;
2045    if (instr->block == state->active_block) {
2046       nir_instr_remove_v(instr);
2047       return true;
2048    }
2049 
2050    if (instr->type != nir_instr_type_intrinsic)
2051       return false;
2052 
2053    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2054 
2055    if (intr->intrinsic != nir_intrinsic_discard &&
2056        intr->intrinsic != nir_intrinsic_terminate &&
2057        intr->intrinsic != nir_intrinsic_discard_if &&
2058        intr->intrinsic != nir_intrinsic_terminate_if)
2059       return false;
2060 
2061    state->active_block = instr->block;
2062 
2063    return false;
2064 }
2065 
2066 static bool
2067 lower_kill(struct nir_builder *builder, nir_instr *instr, void *_cb_data)
2068 {
2069    if (instr->type != nir_instr_type_intrinsic)
2070       return false;
2071 
2072    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2073 
2074    if (intr->intrinsic != nir_intrinsic_discard &&
2075        intr->intrinsic != nir_intrinsic_terminate &&
2076        intr->intrinsic != nir_intrinsic_discard_if &&
2077        intr->intrinsic != nir_intrinsic_terminate_if)
2078       return false;
2079 
2080    builder->cursor = nir_instr_remove(instr);
2081    if (intr->intrinsic == nir_intrinsic_discard ||
2082        intr->intrinsic == nir_intrinsic_terminate) {
2083       nir_demote(builder);
2084    } else {
2085       assert(intr->src[0].is_ssa);
2086       nir_demote_if(builder, intr->src[0].ssa);
2087    }
2088 
2089    nir_jump(builder, nir_jump_return);
2090 
2091    return true;
2092 }
2093 
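/* Lowers discard/terminate (and their _if variants) to nir_demote /
 * nir_demote_if followed by a return, after first removing every instruction
 * that follows a discard within the same block.
 */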
2094 bool
2095 dxil_nir_lower_discard_and_terminate(nir_shader *s)
2096 {
2097    if (s->info.stage != MESA_SHADER_FRAGMENT)
2098       return false;
2099 
2100    // This pass only works if all functions have been inlined
2101    assert(exec_list_length(&s->functions) == 1);
2102    struct remove_after_discard_state state;
2103    state.active_block = NULL;
2104    nir_shader_instructions_pass(s, remove_after_discard, nir_metadata_none,
2105                                 &state);
2106    return nir_shader_instructions_pass(s, lower_kill, nir_metadata_none,
2107                                        NULL);
2108 }
2109 
2110 static bool
2111 update_writes(struct nir_builder *b, nir_instr *instr, void *_state)
2112 {
2113    if (instr->type != nir_instr_type_intrinsic)
2114       return false;
2115    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2116    if (intr->intrinsic != nir_intrinsic_store_output)
2117       return false;
2118 
2119    nir_io_semantics io = nir_intrinsic_io_semantics(intr);
2120    if (io.location != VARYING_SLOT_POS)
2121       return false;
2122 
2123    nir_ssa_def *src = intr->src[0].ssa;
2124    unsigned write_mask = nir_intrinsic_write_mask(intr);
2125    if (src->num_components == 4 && write_mask == 0xf)
2126       return false;
2127 
2128    b->cursor = nir_before_instr(instr);
2129    unsigned first_comp = nir_intrinsic_component(intr);
2130    nir_ssa_def *channels[4] = { NULL, NULL, NULL, NULL };
2131    assert(first_comp + src->num_components <= ARRAY_SIZE(channels));
2132    for (unsigned i = 0; i < src->num_components; ++i)
2133       if (write_mask & (1 << i))
2134          channels[i + first_comp] = nir_channel(b, src, i);
2135    for (unsigned i = 0; i < 4; ++i)
2136       if (!channels[i])
2137          channels[i] = nir_imm_intN_t(b, 0, src->bit_size);
2138 
2139    nir_instr_rewrite_src_ssa(instr, &intr->src[0], nir_vec(b, channels, 4));
2140    nir_intrinsic_set_component(intr, 0);
2141    nir_intrinsic_set_write_mask(intr, 0xf);
2142    return true;
2143 }
2144 
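/* Widens partial gl_Position stores to full 4-component writes: channels
 * covered by the original write mask are kept, missing channels are filled
 * with zero, and component/write_mask are reset to cover xyzw.
 */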
2145 bool
2146 dxil_nir_ensure_position_writes(nir_shader *s)
2147 {
2148    if (s->info.stage != MESA_SHADER_VERTEX &&
2149        s->info.stage != MESA_SHADER_GEOMETRY &&
2150        s->info.stage != MESA_SHADER_TESS_EVAL)
2151       return false;
2152    if ((s->info.outputs_written & VARYING_BIT_POS) == 0)
2153       return false;
2154 
2155    return nir_shader_instructions_pass(s, update_writes,
2156                                        nir_metadata_block_index | nir_metadata_dominance,
2157                                        NULL);
2158 }
2159