1 /*
2  * Copyright © 2021 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "ac_nir.h"
8 #include "ac_nir_helpers.h"
9 #include "ac_gpu_info.h"
10 
11 #include "nir_builder.h"
12 
13 #define SPECIAL_MS_OUT_MASK \
14    (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \
15     BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
16     BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
17 
18 #define MS_PRIM_ARG_EXP_MASK \
19    (VARYING_BIT_LAYER | \
20     VARYING_BIT_VIEWPORT | \
21     VARYING_BIT_PRIMITIVE_SHADING_RATE)
22 
23 #define MS_VERT_ARG_EXP_MASK \
24    (VARYING_BIT_CULL_DIST0 | \
25     VARYING_BIT_CULL_DIST1 | \
26     VARYING_BIT_CLIP_DIST0 | \
27     VARYING_BIT_CLIP_DIST1 | \
28     VARYING_BIT_PSIZ)
29 
30 /* LDS layout of Mesh Shader workgroup info. */
31 enum {
32    /* DW0: number of primitives */
33    lds_ms_num_prims = 0,
34    /* DW1: number of vertices */
35    lds_ms_num_vtx = 4,
36    /* DW2: workgroup index within the current dispatch */
37    lds_ms_wg_index = 8,
38    /* DW3: number of API workgroups in flight */
39    lds_ms_num_api_waves = 12,
40 };
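/* These are byte offsets relative to layout.lds.workgroup_info_addr, which
 * reserves a single 16-byte block per workgroup; users add them to that base,
 * for example:
 *
 *    nir_store_shared(b, workgroup_index, zero,
 *                     .base = s->layout.lds.workgroup_info_addr + lds_ms_wg_index);
 */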
41 
42 /* Possible locations for Mesh Shader outputs. */
43 typedef enum {
44    ms_out_mode_lds,
45    ms_out_mode_scratch_ring,
46    ms_out_mode_attr_ring,
47    ms_out_mode_var,
48 } ms_out_mode;
49 
50 typedef struct
51 {
52    uint64_t mask; /* Mask of output locations */
53    uint32_t addr; /* Base address */
54 } ms_out_part;
55 
56 typedef struct
57 {
58    /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
59    struct {
60       uint32_t workgroup_info_addr;
61       ms_out_part vtx_attr;
62       ms_out_part prm_attr;
63       uint32_t indices_addr;
64       uint32_t cull_flags_addr;
65       uint32_t total_size;
66    } lds;
67 
68    /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS.
69     * Not to be confused with scratch memory.
70     */
71    struct {
72       ms_out_part vtx_attr;
73       ms_out_part prm_attr;
74    } scratch_ring;
75 
76    /* VRAM attributes ring (supported GPUs only) for all non-position outputs.
77     * We don't have to reload attributes from this ring at the end of the shader.
78     */
79    struct {
80       ms_out_part vtx_attr;
81       ms_out_part prm_attr;
82    } attr_ring;
83 
84    /* Outputs without cross-invocation access can be stored in variables. */
85    struct {
86       ms_out_part vtx_attr;
87       ms_out_part prm_attr;
88    } var;
89 } ms_out_mem_layout;
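/* Each output location lands in exactly one of the four parts above; see
 * ms_calculate_output_layout: generic attributes go to the attribute ring on
 * GPUs that have one, outputs without cross-invocation access go to variables,
 * and the remaining cross-invocation outputs go to LDS, spilling into the
 * scratch ring when LDS space runs out. ms_get_out_layout_part resolves a
 * location to its part at lowering time.
 */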
90 
91 typedef struct
92 {
93    const struct radeon_info *hw_info;
94    bool fast_launch_2;
95    bool vert_multirow_export;
96    bool prim_multirow_export;
97 
98    ms_out_mem_layout layout;
99    uint64_t per_vertex_outputs;
100    uint64_t per_primitive_outputs;
101    unsigned vertices_per_prim;
102 
103    unsigned wave_size;
104    unsigned api_workgroup_size;
105    unsigned hw_workgroup_size;
106 
107    nir_def *workgroup_index;
108    nir_variable *out_variables[VARYING_SLOT_MAX * 4];
109    nir_variable *primitive_count_var;
110    nir_variable *vertex_count_var;
111 
112    ac_nir_prerast_out out;
113 
114    /* True if the lowering needs to insert the layer output. */
115    bool insert_layer_output;
116    /* True if cull flags are used */
117    bool uses_cull_flags;
118 
119    uint32_t clipdist_enable_mask;
120    const uint8_t *vs_output_param_offset;
121    bool has_param_exports;
122 
123    /* True if the lowering needs to insert the shader query. */
124    bool has_query;
125 } lower_ngg_ms_state;
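/* out_variables is indexed as (location * 4 + component). Each entry is a
 * 32-bit variable; 16-bit outputs are packed into its low/high half (see
 * ms_store_arrayed_output).
 */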
126 
127 static void
128 ms_store_prim_indices(nir_builder *b,
129                       nir_intrinsic_instr *intrin,
130                       lower_ngg_ms_state *s)
131 {
132    /* EXT_mesh_shader primitive indices: array of vectors.
133     * They don't count as per-primitive outputs, but the array is indexed
134     * by the primitive index, so they are practically per-primitive.
135     */
136    assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
137    assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
138 
139    const unsigned component_offset = nir_intrinsic_component(intrin);
140    nir_def *store_val = intrin->src[0].ssa;
141    assert(store_val->num_components <= 3);
142 
143    if (store_val->num_components > s->vertices_per_prim)
144       store_val = nir_trim_vector(b, store_val, s->vertices_per_prim);
145 
146    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
147       for (unsigned c = 0; c < store_val->num_components; ++c) {
148          const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset;
149          nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1);
150       }
151       return;
152    }
153 
154    nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
155    nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
156 
157    /* The max vertex count is 256, so these indices always fit 8 bits.
158     * To reduce LDS use, store these as a flat array of 8-bit values.
159     */
160    nir_store_shared(b, nir_u2u8(b, store_val), offset, .base = s->layout.lds.indices_addr + component_offset);
161 }
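/* The resulting LDS layout is a flat uint8_t array starting at
 * layout.lds.indices_addr, indexed by
 * primitive_index * vertices_per_prim + vertex_within_primitive.
 */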
162 
163 static void
164 ms_store_cull_flag(nir_builder *b,
165                    nir_intrinsic_instr *intrin,
166                    lower_ngg_ms_state *s)
167 {
168    /* EXT_mesh_shader cull primitive: per-primitive bool. */
169    assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
170    assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
171    assert(nir_intrinsic_component(intrin) == 0);
172    assert(nir_intrinsic_write_mask(intrin) == 1);
173 
174    nir_def *store_val = intrin->src[0].ssa;
175 
176    assert(store_val->num_components == 1);
177    assert(store_val->bit_size == 1);
178 
179    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) {
180       nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, store_val), 0x1);
181       return;
182    }
183 
184    nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
185    nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
186 
187    /* To reduce LDS use, store these as an array of 8-bit values. */
188    nir_store_shared(b, nir_b2i8(b, store_val), offset, .base = s->layout.lds.cull_flags_addr);
189 }
190 
191 static nir_def *
192 ms_arrayed_output_base_addr(nir_builder *b,
193                             nir_def *arr_index,
194                             unsigned mapped_location,
195                             unsigned num_arrayed_outputs)
196 {
197    /* Address offset of the array item (vertex or primitive). */
198    unsigned arr_index_stride = num_arrayed_outputs * 16u;
199    nir_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);
200 
201    /* IO address offset within the vertex or primitive data. */
202    unsigned io_offset = mapped_location * 16u;
203    nir_def *io_off = nir_imm_int(b, io_offset);
204 
205    return nir_iadd_nuw(b, arr_index_off, io_off);
206 }
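/* The address computed above is:
 *
 *    arr_index * num_arrayed_outputs * 16 + mapped_location * 16
 *
 * i.e. each vertex/primitive owns num_arrayed_outputs consecutive 16-byte
 * (vec4) slots and mapped_location selects the slot within that range.
 */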
207 
208 static void
209 update_ms_output_info(const nir_io_semantics io_sem,
210                       const nir_src *base_offset_src,
211                       const uint32_t write_mask,
212                       const unsigned component_offset,
213                       const unsigned bit_size,
214                       const ms_out_part *out,
215                       lower_ngg_ms_state *s)
216 {
217    const uint32_t components_mask = write_mask << component_offset;
218 
219    /* 64-bit outputs should have already been lowered to 32-bit. */
220    assert(bit_size <= 32);
221    assert(components_mask <= 0xf);
222 
223    /* When the base offset is constant, only mark the components of the current slot as used.
224     * Otherwise, mark the components of all possibly affected slots as used.
225     */
226    const unsigned base_off_start = nir_src_is_const(*base_offset_src) ? nir_src_as_uint(*base_offset_src) : 0;
227    const unsigned num_slots = nir_src_is_const(*base_offset_src) ? 1 : io_sem.num_slots;
228 
229    for (unsigned base_off = base_off_start; base_off < num_slots; ++base_off) {
230       ac_nir_prerast_per_output_info *info = &s->out.infos[io_sem.location + base_off];
231       info->components_mask |= components_mask;
232 
233       if (!io_sem.no_sysval_output)
234          info->as_sysval_mask |= components_mask;
235       if (!io_sem.no_varying)
236          info->as_varying_mask |= components_mask;
237    }
238 }
239 
240 static const ms_out_part *
241 ms_get_out_layout_part(unsigned location,
242                        shader_info *info,
243                        ms_out_mode *out_mode,
244                        lower_ngg_ms_state *s)
245 {
246    uint64_t mask = BITFIELD64_BIT(location);
247 
248    if (info->per_primitive_outputs & mask) {
249       if (mask & s->layout.lds.prm_attr.mask) {
250          *out_mode = ms_out_mode_lds;
251          return &s->layout.lds.prm_attr;
252       } else if (mask & s->layout.scratch_ring.prm_attr.mask) {
253          *out_mode = ms_out_mode_scratch_ring;
254          return &s->layout.scratch_ring.prm_attr;
255       } else if (mask & s->layout.attr_ring.prm_attr.mask) {
256          *out_mode = ms_out_mode_attr_ring;
257          return &s->layout.attr_ring.prm_attr;
258       } else if (mask & s->layout.var.prm_attr.mask) {
259          *out_mode = ms_out_mode_var;
260          return &s->layout.var.prm_attr;
261       }
262    } else {
263       if (mask & s->layout.lds.vtx_attr.mask) {
264          *out_mode = ms_out_mode_lds;
265          return &s->layout.lds.vtx_attr;
266       } else if (mask & s->layout.scratch_ring.vtx_attr.mask) {
267          *out_mode = ms_out_mode_scratch_ring;
268          return &s->layout.scratch_ring.vtx_attr;
269       } else if (mask & s->layout.attr_ring.vtx_attr.mask) {
270          *out_mode = ms_out_mode_attr_ring;
271          return &s->layout.attr_ring.vtx_attr;
272       } else if (mask & s->layout.var.vtx_attr.mask) {
273          *out_mode = ms_out_mode_var;
274          return &s->layout.var.vtx_attr;
275       }
276    }
277 
278    unreachable("Couldn't figure out mesh shader output mode.");
279 }
280 
281 static void
282 ms_store_arrayed_output(nir_builder *b,
283                         nir_src *base_off_src,
284                         nir_def *store_val,
285                         nir_def *arr_index,
286                         const nir_io_semantics io_sem,
287                         const unsigned component_offset,
288                         const unsigned write_mask,
289                         lower_ngg_ms_state *s)
290 {
291    ms_out_mode out_mode;
292    const ms_out_part *out = ms_get_out_layout_part(io_sem.location, &b->shader->info, &out_mode, s);
293    update_ms_output_info(io_sem, base_off_src, write_mask, component_offset, store_val->bit_size, out, s);
294 
295    bool hi_16b = io_sem.high_16bits;
296    bool lo_16b = !hi_16b && store_val->bit_size == 16;
297 
298    unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, io_sem.location));
299    unsigned num_outputs = util_bitcount64(out->mask);
300    unsigned const_off = out->addr + component_offset * 4 + (hi_16b ? 2 : 0);
301 
302    nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
303    nir_def *base_offset = base_off_src->ssa;
304    nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
305    nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
306 
307    if (out_mode == ms_out_mode_lds) {
308       nir_store_shared(b, store_val, addr, .base = const_off,
309                      .write_mask = write_mask, .align_mul = 16,
310                      .align_offset = const_off % 16);
311    } else if (out_mode == ms_out_mode_scratch_ring) {
312       nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
313       nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
314       nir_def *zero = nir_imm_int(b, 0);
315       nir_store_buffer_amd(b, store_val, ring, addr, off, zero,
316                            .base = const_off,
317                            .write_mask = write_mask,
318                            .memory_modes = nir_var_shader_out,
319                            .access = ACCESS_COHERENT);
320    } else if (out_mode == ms_out_mode_attr_ring) {
321       /* Store params straight to the attribute ring.
322        * Even though the access pattern may not be the most optimal,
323        * this is still much better than reserving LDS and losing waves.
324        * (Also much better than storing and reloading from the scratch ring.)
325        */
326       unsigned param_offset = s->vs_output_param_offset[io_sem.location];
327       nir_def *ring = nir_load_ring_attr_amd(b);
328       nir_def *soffset = nir_load_ring_attr_offset_amd(b);
329       nir_store_buffer_amd(b, store_val, ring, base_addr_off, soffset, arr_index,
330                            .base = const_off + param_offset * 16,
331                            .write_mask = write_mask,
332                            .memory_modes = nir_var_shader_out,
333                            .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
334    } else if (out_mode == ms_out_mode_var) {
335       unsigned write_mask_32 = write_mask;
336       if (store_val->bit_size > 32) {
337          /* Split 64-bit store values to 32-bit components. */
338          store_val = nir_bitcast_vector(b, store_val, 32);
339          /* Widen the write mask so it is in 32-bit components. */
340          write_mask_32 = util_widen_mask(write_mask, store_val->bit_size / 32);
341       }
342 
343       u_foreach_bit(comp, write_mask_32) {
344          unsigned idx = io_sem.location * 4 + comp + component_offset;
345          nir_def *val = nir_channel(b, store_val, comp);
346          nir_def *v = nir_load_var(b, s->out_variables[idx]);
347 
348          if (lo_16b) {
349             nir_def *var_hi = nir_unpack_32_2x16_split_y(b, v);
350             val = nir_pack_32_2x16_split(b, val, var_hi);
351          } else if (hi_16b) {
352             nir_def *var_lo = nir_unpack_32_2x16_split_x(b, v);
353             val = nir_pack_32_2x16_split(b, var_lo, val);
354          }
355 
356          nir_store_var(b, s->out_variables[idx], val, 0x1);
357       }
358    } else {
359       unreachable("Invalid MS output mode for store");
360    }
361 }
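/* Note on the ms_out_mode_var path above: 16-bit stores are merged into the
 * existing 32-bit variable, replacing only the low (lo_16b) or high (hi_16b)
 * half so the other 16 bits of the packed value are preserved.
 */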
362 
363 static void
364 ms_store_arrayed_output_intrin(nir_builder *b,
365                                nir_intrinsic_instr *intrin,
366                                lower_ngg_ms_state *s)
367 {
368    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
369 
370    if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
371       ms_store_prim_indices(b, intrin, s);
372       return;
373    } else if (io_sem.location == VARYING_SLOT_CULL_PRIMITIVE) {
374       ms_store_cull_flag(b, intrin, s);
375       return;
376    }
377 
378    unsigned component_offset = nir_intrinsic_component(intrin);
379    unsigned write_mask = nir_intrinsic_write_mask(intrin);
380 
381    nir_def *store_val = intrin->src[0].ssa;
382    nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
383    nir_src *base_off_src = nir_get_io_offset_src(intrin);
384 
385    if (store_val->bit_size < 32) {
386       /* Split 16-bit output stores to ensure each 16-bit component is stored
387        * in the correct location, without overwriting the other 16 bits there.
388        */
389       u_foreach_bit(c, write_mask) {
390          nir_def *store_component = nir_channel(b, store_val, c);
391          ms_store_arrayed_output(b, base_off_src, store_component, arr_index, io_sem, c + component_offset, 1, s);
392       }
393    } else {
394       ms_store_arrayed_output(b, base_off_src, store_val, arr_index, io_sem, component_offset, write_mask, s);
395    }
396 }
397 
398 static nir_def *
399 ms_load_arrayed_output(nir_builder *b,
400                        nir_def *arr_index,
401                        nir_def *base_offset,
402                        unsigned location,
403                        unsigned component_offset,
404                        unsigned num_components,
405                        unsigned load_bit_size,
406                        lower_ngg_ms_state *s)
407 {
408    ms_out_mode out_mode;
409    const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
410 
411    unsigned component_addr_off = component_offset * 4;
412    unsigned num_outputs = util_bitcount64(out->mask);
413    unsigned const_off = out->addr + component_offset * 4;
414 
415    /* Use compacted location instead of the original semantic location. */
416    unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
417 
418    nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
419    nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
420    nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
421 
422    if (out_mode == ms_out_mode_lds) {
423       return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
424                              .align_offset = component_addr_off % 16,
425                              .base = const_off);
426    } else if (out_mode == ms_out_mode_scratch_ring) {
427       nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
428       nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
429       nir_def *zero = nir_imm_int(b, 0);
430       return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, zero,
431                                  .base = const_off,
432                                  .memory_modes = nir_var_shader_out,
433                                  .access = ACCESS_COHERENT);
434    } else if (out_mode == ms_out_mode_var) {
435       assert(load_bit_size == 32);
436       nir_def *arr[8] = {0};
437       for (unsigned comp = 0; comp < num_components; ++comp) {
438          unsigned idx = location * 4 + comp + component_addr_off;
439          arr[comp] = nir_load_var(b, s->out_variables[idx]);
440       }
441       return nir_vec(b, arr, num_components);
442    } else {
443       unreachable("Invalid MS output mode for load");
444    }
445 }
446 
447 static nir_def *
448 lower_ms_load_workgroup_index(nir_builder *b,
449                               UNUSED nir_intrinsic_instr *intrin,
450                               lower_ngg_ms_state *s)
451 {
452    return s->workgroup_index;
453 }
454 
455 static nir_def *
456 lower_ms_set_vertex_and_primitive_count(nir_builder *b,
457                                         nir_intrinsic_instr *intrin,
458                                         lower_ngg_ms_state *s)
459 {
460    /* If either the number of vertices or primitives is zero, set both of them to zero. */
461    nir_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
462    nir_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
463    nir_def *zero = nir_imm_int(b, 0);
464    nir_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
465    num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
466    num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);
467 
468    nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
469    nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
470 
471    return NIR_LOWER_INSTR_PROGRESS_REPLACE;
472 }
473 
474 static nir_def *
475 update_ms_barrier(nir_builder *b,
476                   nir_intrinsic_instr *intrin,
477                   lower_ngg_ms_state *s)
478 {
479    /* Output loads and stores are lowered to shared memory access,
480     * so we have to update the barriers to also reflect this.
481     */
482    unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
483    if (mem_modes & nir_var_shader_out)
484       mem_modes |= nir_var_mem_shared;
485    else
486       return NULL;
487 
488    nir_intrinsic_set_memory_modes(intrin, mem_modes);
489 
490    return NIR_LOWER_INSTR_PROGRESS;
491 }
492 
493 static nir_def *
494 lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
495 {
496    lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;
497 
498    if (instr->type != nir_instr_type_intrinsic)
499       return NULL;
500 
501    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
502 
503    switch (intrin->intrinsic) {
504    case nir_intrinsic_store_per_vertex_output:
505    case nir_intrinsic_store_per_primitive_output:
506       ms_store_arrayed_output_intrin(b, intrin, s);
507       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
508    case nir_intrinsic_barrier:
509       return update_ms_barrier(b, intrin, s);
510    case nir_intrinsic_load_workgroup_index:
511       return lower_ms_load_workgroup_index(b, intrin, s);
512    case nir_intrinsic_set_vertex_and_primitive_count:
513       return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
514    default:
515       unreachable("Not a lowerable mesh shader intrinsic.");
516    }
517 }
518 
519 static bool
520 filter_ms_intrinsic(const nir_instr *instr,
521                     UNUSED const void *s)
522 {
523    if (instr->type != nir_instr_type_intrinsic)
524       return false;
525 
526    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
527    return intrin->intrinsic == nir_intrinsic_store_output ||
528           intrin->intrinsic == nir_intrinsic_load_output ||
529           intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
530           intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
531           intrin->intrinsic == nir_intrinsic_barrier ||
532           intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
533           intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
534 }
535 
536 static void
537 lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
538 {
539    nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
540 }
541 
542 static void
543 ms_emit_arrayed_outputs(nir_builder *b,
544                         nir_def *invocation_index,
545                         uint64_t mask,
546                         lower_ngg_ms_state *s)
547 {
548    nir_def *zero = nir_imm_int(b, 0);
549 
550    u_foreach_bit64(slot, mask) {
551       /* Should not occur here, handled separately. */
552       assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
553 
554       unsigned component_mask = s->out.infos[slot].components_mask;
555 
556       while (component_mask) {
557          int start_comp = 0, num_components = 1;
558          u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
559 
560          nir_def *load =
561             ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
562                                    num_components, 32, s);
563 
564          for (int i = 0; i < num_components; i++)
565             s->out.outputs[slot][start_comp + i] = nir_channel(b, load, i);
566       }
567    }
568 }
569 
570 static void
571 ms_create_same_invocation_vars(nir_builder *b, lower_ngg_ms_state *s)
572 {
573    /* Initialize NIR variables for same-invocation outputs. */
574    uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;
575 
576    u_foreach_bit64(slot, same_invocation_output_mask) {
577       for (unsigned comp = 0; comp < 4; ++comp) {
578          unsigned idx = slot * 4 + comp;
579          s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
580       }
581    }
582 }
583 
584 static void
585 ms_emit_legacy_workgroup_index(nir_builder *b, lower_ngg_ms_state *s)
586 {
587    /* Workgroup ID should have been lowered to workgroup index. */
588    assert(!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID));
589 
590    /* No need to do anything if the shader doesn't use the workgroup index. */
591    if (!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX))
592       return;
593 
594    b->cursor = nir_before_impl(b->impl);
595 
596    /* Legacy fast launch mode (FAST_LAUNCH=1):
597     *
598     * The HW doesn't support a proper workgroup index for vertex processing stages,
599     * so we use the vertex ID which is equivalent to the index of the current workgroup
600     * within the current dispatch.
601     *
602     * Due to the register programming of mesh shaders, this value is only filled for
603     * the first invocation of the first wave. To let other waves know, we use LDS.
604     */
605    nir_def *workgroup_index = nir_load_vertex_id_zero_base(b);
606 
607    if (s->api_workgroup_size <= s->wave_size) {
608       /* API workgroup is small, so we don't need to use LDS. */
609       s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
610       return;
611    }
612 
613    unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
614 
615    nir_def *zero = nir_imm_int(b, 0);
616    nir_def *dont_care = nir_undef(b, 1, 32);
617    nir_def *loaded_workgroup_index = NULL;
618 
619    /* Use elect to make sure only 1 invocation uses LDS. */
620    nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
621    {
622       nir_def *wave_id = nir_load_subgroup_id(b);
623       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
624       {
625          nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
626          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
627                                .memory_scope = SCOPE_WORKGROUP,
628                                .memory_semantics = NIR_MEMORY_ACQ_REL,
629                                .memory_modes = nir_var_mem_shared);
630       }
631       nir_push_else(b, if_wave_0);
632       {
633          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
634                                .memory_scope = SCOPE_WORKGROUP,
635                                .memory_semantics = NIR_MEMORY_ACQ_REL,
636                                .memory_modes = nir_var_mem_shared);
637          loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
638       }
639       nir_pop_if(b, if_wave_0);
640 
641       workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
642    }
643    nir_pop_if(b, if_elected);
644 
645    workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
646    s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
647 }
648 
649 static void
650 set_ms_final_output_counts(nir_builder *b,
651                            lower_ngg_ms_state *s,
652                            nir_def **out_num_prm,
653                            nir_def **out_num_vtx)
654 {
655    /* The spec allows the numbers to be divergent, and in that case we need to
656     * use the values from the first invocation. Also the HW requires us to set
657     * both to 0 if either was 0.
658     *
659     * These are already done by the lowering.
660     */
661    nir_def *num_prm = nir_load_var(b, s->primitive_count_var);
662    nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
663 
664    if (s->hw_workgroup_size <= s->wave_size) {
665       /* Single-wave mesh shader workgroup. */
666       ac_nir_ngg_alloc_vertices_and_primitives(b, num_vtx, num_prm, false);
667 
668       *out_num_prm = num_prm;
669       *out_num_vtx = num_vtx;
670       return;
671    }
672 
673    /* Multi-wave mesh shader workgroup:
674     * We need to use LDS to distribute the correct values to the other waves.
675     *
676     * TODO:
677     * If we can prove that the values are workgroup-uniform, we can skip this
678     * and just use whatever the current wave has. However, NIR divergence analysis
679     * currently doesn't support this.
680     */
681 
682    nir_def *zero = nir_imm_int(b, 0);
683 
684    nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
685    {
686       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
687       {
688          nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
689                           .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
690       }
691       nir_pop_if(b, if_elected);
692 
693       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
694                             .memory_scope = SCOPE_WORKGROUP,
695                             .memory_semantics = NIR_MEMORY_ACQ_REL,
696                             .memory_modes = nir_var_mem_shared);
697 
698       ac_nir_ngg_alloc_vertices_and_primitives(b, num_vtx, num_prm, false);
699    }
700    nir_push_else(b, if_wave_0);
701    {
702       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
703                             .memory_scope = SCOPE_WORKGROUP,
704                             .memory_semantics = NIR_MEMORY_ACQ_REL,
705                             .memory_modes = nir_var_mem_shared);
706 
707       nir_def *prm_vtx = NULL;
708       nir_def *dont_care_2x32 = nir_undef(b, 2, 32);
709       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
710       {
711          prm_vtx = nir_load_shared(b, 2, 32, zero,
712                                    .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
713       }
714       nir_pop_if(b, if_elected);
715 
716       prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
717       num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
718       num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));
719 
720       nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
721       nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
722    }
723    nir_pop_if(b, if_wave_0);
724 
725    *out_num_prm = nir_load_var(b, s->primitive_count_var);
726    *out_num_vtx = nir_load_var(b, s->vertex_count_var);
727 }
728 
729 static void
730 ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask,
731                                      nir_def *idx, lower_ngg_ms_state *s)
732 {
733    if (!outputs_mask)
734       return;
735 
736    nir_def *ring = nir_load_ring_attr_amd(b);
737    nir_def *off = nir_load_ring_attr_offset_amd(b);
738    nir_def *zero = nir_imm_int(b, 0);
739 
740    u_foreach_bit64 (slot, outputs_mask) {
741       if (s->vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31)
742          continue;
743 
744       nir_def *soffset = nir_iadd_imm(b, off, s->vs_output_param_offset[slot] * 16 * 32);
745       nir_def *store_val = nir_undef(b, 4, 32);
746       unsigned store_val_components = 0;
747       for (unsigned c = 0; c < 4; ++c) {
748          if (s->out.outputs[slot][c]) {
749             store_val = nir_vector_insert_imm(b, store_val, s->out.outputs[slot][c], c);
750             store_val_components = c + 1;
751          }
752       }
753 
754       store_val = nir_trim_vector(b, store_val, store_val_components);
755       nir_store_buffer_amd(b, store_val, ring, zero, soffset, idx,
756                            .memory_modes = nir_var_shader_out,
757                            .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
758    }
759 }
760 
761 static nir_def *
762 ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s)
763 {
764    /* Primitive connectivity data: describes which vertices the primitive uses. */
765    nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
766    nir_def *indices_loaded = NULL;
767    nir_def *cull_flag = NULL;
768 
769    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
770       nir_def *indices[3] = {0};
771       for (unsigned c = 0; c < s->vertices_per_prim; ++c)
772          indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
773       indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
774    } else {
775       indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
776       indices_loaded = nir_u2u32(b, indices_loaded);
777    }
778 
779    if (s->uses_cull_flags) {
780       nir_def *loaded_cull_flag = NULL;
781       if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
782          loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
783       else
784          loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
785 
786       cull_flag = nir_i2b(b, loaded_cull_flag);
787    }
788 
789    nir_def *indices[3];
790    nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
791 
792    for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
793       indices[i] = nir_channel(b, indices_loaded, i);
794       indices[i] = nir_umin(b, indices[i], max_vtx_idx);
795    }
796 
797    return ac_nir_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, s->hw_info->gfx_level);
798 }
799 
800 static nir_def *
801 ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s)
802 {
803    nir_def *prim_exp_arg_ch2 = NULL;
804 
805    if (outputs_mask) {
806       /* When layer, viewport etc. are per-primitive, they need to be encoded in
807        * the primitive export instruction's second channel. The encoding is:
808        *
809        * --- GFX10.3 ---
810        * bits 31..30: VRS rate Y
811        * bits 29..28: VRS rate X
812        * bits 23..20: viewport
813        * bits 19..17: layer
814        *
815        * --- GFX11 ---
816        * bits 31..28: VRS rate enum
817        * bits 23..20: viewport
818        * bits 12..00: layer
819        */
820       prim_exp_arg_ch2 = nir_imm_int(b, 0);
821 
822       if (outputs_mask & VARYING_BIT_LAYER) {
823          nir_def *layer =
824             nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_LAYER][0], s->hw_info->gfx_level >= GFX11 ? 0 : 17);
825          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer);
826       }
827 
828       if (outputs_mask & VARYING_BIT_VIEWPORT) {
829          nir_def *view = nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_VIEWPORT][0], 20);
830          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view);
831       }
832 
833       if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) {
834          nir_def *rate = s->out.outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0];
835          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate);
836       }
837    }
838 
839    return prim_exp_arg_ch2;
840 }
841 
842 static void
843 ms_prim_gen_query(nir_builder *b,
844                   nir_def *invocation_index,
845                   nir_def *num_prm,
846                   lower_ngg_ms_state *s)
847 {
848    if (!s->has_query)
849       return;
850 
851    nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
852    {
853       nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
854       {
855          nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0);
856       }
857       nir_pop_if(b, if_shader_query);
858    }
859    nir_pop_if(b, if_invocation_index_zero);
860 }
861 
862 static void
863 ms_invocation_query(nir_builder *b,
864                     nir_def *invocation_index,
865                     lower_ngg_ms_state *s)
866 {
867    if (!s->has_query)
868       return;
869 
870    nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
871    {
872       nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b));
873       {
874          nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size));
875       }
876       nir_pop_if(b, if_pipeline_query);
877    }
878    nir_pop_if(b, if_invocation_index_zero);
879 }
880 
881 static void
882 emit_ms_vertex(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
883                uint64_t per_vertex_outputs, lower_ngg_ms_state *s)
884 {
885    ms_emit_arrayed_outputs(b, index, per_vertex_outputs, s);
886 
887    if (exports) {
888       ac_nir_export_position(b, s->hw_info->gfx_level, s->clipdist_enable_mask,
889                              !s->has_param_exports, false, true,
890                              s->per_vertex_outputs | VARYING_BIT_POS, &s->out, row);
891    }
892 
893    if (parameters) {
894       /* Export generic attributes when there is no attribute ring. */
895       if (s->has_param_exports && !s->hw_info->has_attr_ring) {
896          ac_nir_export_parameters(b, s->vs_output_param_offset, per_vertex_outputs, 0, &s->out);
897       }
898 
899       /* Also store special outputs to the attribute ring so PS can load them. */
900       if (s->hw_info->has_attr_ring && (per_vertex_outputs & MS_VERT_ARG_EXP_MASK))
901          ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & MS_VERT_ARG_EXP_MASK, index, s);
902    }
903 }
904 
905 static void
906 emit_ms_primitive(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
907                   uint64_t per_primitive_outputs, lower_ngg_ms_state *s)
908 {
909    ms_emit_arrayed_outputs(b, index, per_primitive_outputs, s);
910 
911    /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
912    if (s->insert_layer_output) {
913       s->out.outputs[VARYING_SLOT_LAYER][0] = nir_load_view_index(b);
914       s->out.infos[VARYING_SLOT_LAYER].as_sysval_mask |= 1;
915    }
916 
917    if (exports) {
918       const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK;
919       nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
920       nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, index, num_vtx, s);
921       nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s);
922 
923       nir_def *prim_exp_arg = prim_exp_arg_ch2 ?
924          nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1;
925 
926       ac_nir_export_primitive(b, prim_exp_arg, row);
927    }
928 
929    if (parameters) {
930       /* Export generic attributes when there is no attribute ring. */
931       if (s->has_param_exports && !s->hw_info->has_attr_ring) {
932          ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0, &s->out);
933       }
934 
935       /* Also store special outputs to the attribute ring so PS can load them. */
936       if (s->hw_info->has_attr_ring)
937          ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, index, s);
938    }
939 }
940 
941 static void
942 emit_ms_outputs(nir_builder *b, nir_def *invocation_index, nir_def *row_start,
943                 nir_def *count, bool exports, bool parameters, uint64_t mask,
944                 void (*cb)(nir_builder *, nir_def *, nir_def *, bool, bool,
945                            uint64_t, lower_ngg_ms_state *),
946                 lower_ngg_ms_state *s)
947 {
948    if (cb == &emit_ms_primitive ? s->prim_multirow_export : s->vert_multirow_export) {
949       assert(s->hw_workgroup_size % s->wave_size == 0);
950       const unsigned num_waves = s->hw_workgroup_size / s->wave_size;
951 
952       nir_loop *row_loop = nir_push_loop(b);
953       {
954          nir_block *preheader = nir_cf_node_as_block(nir_cf_node_prev(&row_loop->cf_node));
955 
956          nir_phi_instr *index = nir_phi_instr_create(b->shader);
957          nir_phi_instr *row = nir_phi_instr_create(b->shader);
958          nir_def_init(&index->instr, &index->def, 1, 32);
959          nir_def_init(&row->instr, &row->def, 1, 32);
960 
961          nir_phi_instr_add_src(index, preheader, invocation_index);
962          nir_phi_instr_add_src(row, preheader, row_start);
963 
964          nir_if *if_break = nir_push_if(b, nir_uge(b, &index->def, count));
965          {
966             nir_jump(b, nir_jump_break);
967          }
968          nir_pop_if(b, if_break);
969 
970          cb(b, &index->def, &row->def, exports, parameters, mask, s);
971 
972          nir_block *body = nir_cursor_current_block(b->cursor);
973          nir_phi_instr_add_src(index, body,
974                                nir_iadd_imm(b, &index->def, s->hw_workgroup_size));
975          nir_phi_instr_add_src(row, body,
976                                nir_iadd_imm(b, &row->def, num_waves));
977 
978          nir_instr_insert_before_cf_list(&row_loop->body, &row->instr);
979          nir_instr_insert_before_cf_list(&row_loop->body, &index->instr);
980       }
981       nir_pop_loop(b, row_loop);
982    } else {
983       nir_def *has_output = nir_ilt(b, invocation_index, count);
984       nir_if *if_has_output = nir_push_if(b, has_output);
985       {
986          cb(b, invocation_index, row_start, exports, parameters, mask, s);
987       }
988       nir_pop_if(b, if_has_output);
989    }
990 }
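/* In the multi-row export path above, each loop iteration advances the
 * invocation index by hw_workgroup_size and the export row by the number of
 * waves, so every thread exports one vertex/primitive per row until the
 * workgroup-wide count is reached.
 */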
991 
992 static void
993 emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
994 {
995    /* We assume there is always a single end block in the shader. */
996    nir_block *last_block = nir_impl_last_block(b->impl);
997    b->cursor = nir_after_block(last_block);
998 
999    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1000                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);
1001 
1002    nir_def *num_prm;
1003    nir_def *num_vtx;
1004 
1005    set_ms_final_output_counts(b, s, &num_prm, &num_vtx);
1006 
1007    nir_def *invocation_index = nir_load_local_invocation_index(b);
1008 
1009    ms_prim_gen_query(b, invocation_index, num_prm, s);
1010 
1011    nir_def *row_start = NULL;
1012    if (s->fast_launch_2)
1013       row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b);
1014 
1015    /* Load vertex/primitive attributes from shared memory and
1016     * emit store_output intrinsics for them.
1017     *
1018     * Contrary to the semantics of the API mesh shader, these are now
1019     * compliant with NGG HW semantics, meaning that these store the
1020     * current thread's vertex attributes in a way the HW can export.
1021     */
1022 
1023    uint64_t per_vertex_outputs =
1024       s->per_vertex_outputs & ~s->layout.attr_ring.vtx_attr.mask;
1025    uint64_t per_primitive_outputs =
1026       s->per_primitive_outputs & ~s->layout.attr_ring.prm_attr.mask & ~SPECIAL_MS_OUT_MASK;
1027 
1028    /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
1029    if (s->insert_layer_output) {
1030       b->shader->info.outputs_written |= VARYING_BIT_LAYER;
1031       b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
1032       per_primitive_outputs |= VARYING_BIT_LAYER;
1033    }
1034 
1035    const bool has_special_param_exports =
1036       (per_vertex_outputs & MS_VERT_ARG_EXP_MASK) ||
1037       (per_primitive_outputs & MS_PRIM_ARG_EXP_MASK);
1038    const bool wait_attr_ring = has_special_param_exports && s->hw_info->has_attr_ring_wait_bug;
1039 
1040    /* Export vertices. */
1041    if ((per_vertex_outputs & ~VARYING_BIT_POS) || !wait_attr_ring) {
1042       emit_ms_outputs(b, invocation_index, row_start, num_vtx, !wait_attr_ring, true,
1043                       per_vertex_outputs, &emit_ms_vertex, s);
1044    }
1045 
1046    /* Export primitives. */
1047    if (per_primitive_outputs || !wait_attr_ring) {
1048       emit_ms_outputs(b, invocation_index, row_start, num_prm, !wait_attr_ring, true,
1049                       per_primitive_outputs, &emit_ms_primitive, s);
1050    }
1051 
1052    /* When we need to wait for attribute ring stores, we emit both position and primitive
1053     * export instructions after a barrier to make sure both per-vertex and per-primitive
1054     * attribute ring stores are finished before the GPU starts rasterization.
1055     */
1056    if (wait_attr_ring) {
1057       /* Wait for attribute stores to finish. */
1058       nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
1059                      .memory_scope = SCOPE_DEVICE,
1060                      .memory_semantics = NIR_MEMORY_RELEASE,
1061                      .memory_modes = nir_var_shader_out);
1062 
1063       /* Position/primitive export only */
1064       emit_ms_outputs(b, invocation_index, row_start, num_vtx, true, false,
1065                       per_vertex_outputs, &emit_ms_vertex, s);
1066       emit_ms_outputs(b, invocation_index, row_start, num_prm, true, false,
1067                       per_primitive_outputs, &emit_ms_primitive, s);
1068    }
1069 }
1070 
1071 static void
1072 handle_smaller_ms_api_workgroup(nir_builder *b,
1073                                 lower_ngg_ms_state *s)
1074 {
1075    if (s->api_workgroup_size >= s->hw_workgroup_size)
1076       return;
1077 
1078    /* Handle barriers manually when the API workgroup
1079     * size is less than the HW workgroup size.
1080     *
1081     * The problem is that the real workgroup launched on NGG HW
1082     * will be larger than the size specified by the API, and the
1083     * extra waves need to keep up with barriers in the API waves.
1084     *
1085     * There are 2 different cases:
1086     * 1. The whole API workgroup fits in a single wave.
1087     *    We can shrink the barriers to subgroup scope and
1088     *    don't need to insert any extra ones.
1089     * 2. The API workgroup occupies multiple waves, but not
1090     *    all. In this case, we emit code that consumes every
1091     *    barrier on the extra waves.
1092     */
1093    assert(s->hw_workgroup_size % s->wave_size == 0);
1094    bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
1095    bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
1096    bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
1097 
1098    unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
1099    unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
1100 
1101    /* Scan the shader for workgroup barriers. */
1102    if (scan_barriers) {
1103       bool has_any_workgroup_barriers = false;
1104 
1105       nir_foreach_block(block, b->impl) {
1106          nir_foreach_instr_safe(instr, block) {
1107             if (instr->type != nir_instr_type_intrinsic)
1108                continue;
1109 
1110             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1111             bool is_workgroup_barrier =
1112                intrin->intrinsic == nir_intrinsic_barrier &&
1113                nir_intrinsic_execution_scope(intrin) == SCOPE_WORKGROUP;
1114 
1115             if (!is_workgroup_barrier)
1116                continue;
1117 
1118             if (can_shrink_barriers) {
1119                /* Every API invocation runs in the first wave.
1120                 * In this case, we can change the barriers to subgroup scope
1121                 * and avoid adding additional barriers.
1122                 */
1123                nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP);
1124                nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP);
1125             } else {
1126                has_any_workgroup_barriers = true;
1127             }
1128          }
1129       }
1130 
1131       need_additional_barriers &= has_any_workgroup_barriers;
1132    }
1133 
1134    /* Extract the full control flow of the shader. */
1135    nir_cf_list extracted;
1136    nir_cf_extract(&extracted, nir_before_impl(b->impl),
1137                   nir_after_cf_list(&b->impl->body));
1138    b->cursor = nir_before_impl(b->impl);
1139 
1140    /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */
1141    nir_def *invocation_index = nir_load_local_invocation_index(b);
1142    nir_def *zero = nir_imm_int(b, 0);
1143 
1144    if (need_additional_barriers) {
1145       /* First invocation stores 0 to number of API waves in flight. */
1146       nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
1147       {
1148          nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
1149       }
1150       nir_pop_if(b, if_first_in_workgroup);
1151 
1152       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1153                             .memory_scope = SCOPE_WORKGROUP,
1154                             .memory_semantics = NIR_MEMORY_ACQ_REL,
1155                             .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1156    }
1157 
1158    nir_def *has_api_ms_invocation = nir_ult_imm(b, invocation_index, s->api_workgroup_size);
1159    nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
1160    {
1161       nir_cf_reinsert(&extracted, b->cursor);
1162       b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);
1163 
1164       if (need_additional_barriers) {
1165          /* One invocation in each API wave decrements the number of API waves in flight. */
1166          nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
1167          {
1168             nir_shared_atomic(b, 32, zero, nir_imm_int(b, -1u),
1169                               .base = api_waves_in_flight_addr,
1170                               .atomic_op = nir_atomic_op_iadd);
1171          }
1172          nir_pop_if(b, if_elected_again);
1173 
1174          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1175                                .memory_scope = SCOPE_WORKGROUP,
1176                                .memory_semantics = NIR_MEMORY_ACQ_REL,
1177                                .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1178       }
1179 
1180       ms_invocation_query(b, invocation_index, s);
1181    }
1182    nir_pop_if(b, if_has_api_ms_invocation);
1183 
1184    if (need_additional_barriers) {
1185       /* Make sure that waves that don't run any API invocations execute
1186        * the same amount of barriers as those that do.
1187        *
1188        * We do this by executing a barrier until the number of API waves
1189        * in flight becomes zero.
1190        */
1191       nir_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
1192       nir_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
1193       nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
1194       {
1195          nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
1196          {
1197             nir_loop *loop = nir_push_loop(b);
1198             {
1199                nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1200                                      .memory_scope = SCOPE_WORKGROUP,
1201                                      .memory_semantics = NIR_MEMORY_ACQ_REL,
1202                                      .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1203 
1204                nir_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
1205                nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
1206                {
1207                   nir_jump(b, nir_jump_break);
1208                }
1209                nir_pop_if(b, if_break);
1210             }
1211             nir_pop_loop(b, loop);
1212          }
1213          nir_pop_if(b, if_elected);
1214       }
1215       nir_pop_if(b, if_wave_has_no_api_ms);
1216    }
1217 }
1218 
1219 static void
1220 ms_move_output(ms_out_part *from, ms_out_part *to)
1221 {
1222    uint64_t loc = util_logbase2_64(from->mask);
1223    uint64_t bit = BITFIELD64_BIT(loc);
1224    from->mask ^= bit;
1225    to->mask |= bit;
1226 }
1227 
1228 static void
1229 ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
1230                                    unsigned max_vertices,
1231                                    unsigned max_primitives)
1232 {
1233    uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
1234    uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
1235    l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
1236    l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
1237 
1238    uint32_t scratch_ring_vtx_attr_size =
1239       util_bitcount64(l->scratch_ring.vtx_attr.mask) * max_vertices * 16;
1240    l->scratch_ring.prm_attr.addr =
1241       ALIGN(l->scratch_ring.vtx_attr.addr + scratch_ring_vtx_attr_size, 16);
1242 }
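/* Each arrayed part occupies util_bitcount64(mask) * count * 16 bytes, where
 * count is max_vertices or max_primitives; the per-primitive block is placed
 * (16-byte aligned) directly after the per-vertex block, both in LDS and in
 * the scratch ring.
 */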
1243 
1244 static ms_out_mem_layout
1245 ms_calculate_output_layout(const struct radeon_info *hw_info, unsigned api_shared_size,
1246                            uint64_t per_vertex_output_mask, uint64_t per_primitive_output_mask,
1247                            uint64_t cross_invocation_output_access, unsigned max_vertices,
1248                            unsigned max_primitives, unsigned vertices_per_prim)
1249 {
1250    /* These outputs always need export instructions and can't use the attributes ring. */
1251    const uint64_t always_export_mask =
1252       VARYING_BIT_POS | VARYING_BIT_CULL_DIST0 | VARYING_BIT_CULL_DIST1 | VARYING_BIT_CLIP_DIST0 |
1253       VARYING_BIT_CLIP_DIST1 | VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT |
1254       VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_LAYER |
1255       BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) |
1256       BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1257 
1258    const bool use_attr_ring = hw_info->has_attr_ring;
1259    const uint64_t attr_ring_per_vertex_output_mask =
1260       use_attr_ring ? per_vertex_output_mask & ~always_export_mask : 0;
1261    const uint64_t attr_ring_per_primitive_output_mask =
1262       use_attr_ring ? per_primitive_output_mask & ~always_export_mask : 0;
1263 
1264    const uint64_t lds_per_vertex_output_mask =
1265       per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & cross_invocation_output_access &
1266       ~SPECIAL_MS_OUT_MASK;
1267    const uint64_t lds_per_primitive_output_mask =
1268       per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
1269       cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
1270 
1271    const bool cross_invocation_indices =
1272       cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
1273    const bool cross_invocation_cull_primitive =
1274       cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1275 
1276    /* Shared memory used by the API shader. */
1277    ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
1278 
1279    /* Use attribute ring for all generic attributes (on GPUs with an attribute ring). */
1280    l.attr_ring.vtx_attr.mask = attr_ring_per_vertex_output_mask;
1281    l.attr_ring.prm_attr.mask = attr_ring_per_primitive_output_mask;
1282 
1283    /* Outputs without cross-invocation access can be stored in variables. */
1284    l.var.vtx_attr.mask =
1285       per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & ~cross_invocation_output_access;
1286    l.var.prm_attr.mask = per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
1287                          ~cross_invocation_output_access;
1288 
1289    /* Workgroup information, see ms_workgroup_* for the layout. */
1290    l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
1291    l.lds.total_size = l.lds.workgroup_info_addr + 16;
1292 
1293    /* Per-vertex and per-primitive output attributes.
1294     * Outputs without cross-invocation access are not included here.
1295     * First, try to put all outputs into LDS (shared memory).
1296     * If they don't fit, try to move them to VRAM one by one.
1297     */
1298    l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
1299    l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
1300    l.lds.prm_attr.mask = lds_per_primitive_output_mask;
1301    ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
1302 
1303    /* NGG shaders can only address up to 32K of LDS memory.
1304     * The spec requires us to let the application use at least 28K of shared
1305     * memory, and we reserve up to 2K for driver-internal use (eg. primitive
1306     * indices and cull flags, see below), hence the 30K/31K budget enforced below.
1307     *
1308     * Move the outputs that don't fit in LDS to VRAM.
1309     * Start with per-primitive attributes, because those are grouped at the end.
1310     */
1311    const unsigned usable_lds_kbytes =
1312       (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31;
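   /* Hypothetical illustration: if the API shader already uses 16K of shared
    * memory and the arrayed outputs would need another 20K, the loop below
    * keeps evicting outputs to the scratch ring until the total fits the
    * 30K/31K budget.
    */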
1313    while (l.lds.total_size >= usable_lds_kbytes * 1024) {
1314       if (l.lds.prm_attr.mask)
1315          ms_move_output(&l.lds.prm_attr, &l.scratch_ring.prm_attr);
1316       else if (l.lds.vtx_attr.mask)
1317          ms_move_output(&l.lds.vtx_attr, &l.scratch_ring.vtx_attr);
1318       else
1319          unreachable("API shader uses too much shared memory.");
1320 
1321       ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
1322    }
1323 
1324    if (cross_invocation_indices) {
1325       /* Indices: flat array of 8-bit vertex indices for each primitive. */
1326       l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
1327       l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
1328    }
1329 
1330    if (cross_invocation_cull_primitive) {
1331       /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
1332       l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
1333       l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
1334    }
1335 
1336    /* NGG is only allowed to address up to 32K of LDS. */
1337    assert(l.lds.total_size <= 32 * 1024);
1338    return l;
1339 }
1340 
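/* Lowers a Mesh Shader for the NGG HW pipeline: computes where each output
 * lives (see ms_calculate_output_layout), lowers the MS intrinsics
 * accordingly, and emits the final vertex/primitive export code at the end
 * of the shader.
 */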
1341 void
1342 ac_nir_lower_ngg_mesh(nir_shader *shader,
1343                       const struct radeon_info *hw_info,
1344                       uint32_t clipdist_enable_mask,
1345                       const uint8_t *vs_output_param_offset,
1346                       bool has_param_exports,
1347                       bool *out_needs_scratch_ring,
1348                       unsigned wave_size,
1349                       unsigned hw_workgroup_size,
1350                       bool multiview,
1351                       bool has_query,
1352                       bool fast_launch_2)
1353 {
1354    unsigned vertices_per_prim =
1355       mesa_vertices_per_prim(shader->info.mesh.primitive_type);
1356 
1357    uint64_t per_vertex_outputs =
1358       shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK;
1359    uint64_t per_primitive_outputs =
1360       shader->info.per_primitive_outputs & shader->info.outputs_written;
1361 
1362    /* Whether the shader uses CullPrimitiveEXT */
1363    bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1364    /* We can't handle indirect addressing of outputs, so treat indirectly accessed outputs as cross-invocation. */
1365    uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
1366                                       shader->info.outputs_accessed_indirectly;
1367 
1368    unsigned max_vertices = shader->info.mesh.max_vertices_out;
1369    unsigned max_primitives = shader->info.mesh.max_primitives_out;
1370 
1371    ms_out_mem_layout layout = ms_calculate_output_layout(
1372       hw_info, shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
1373       cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
1374 
1375    shader->info.shared_size = layout.lds.total_size;
1376    *out_needs_scratch_ring = layout.scratch_ring.vtx_attr.mask || layout.scratch_ring.prm_attr.mask;
1377 
1378    /* The workgroup size specified by the API shader may differ from the size
1379     * of the workgroup that actually runs on the HW, due to the limitations of
1380     * NGG: each lane may export at most one vertex and at most one primitive.
1381     *
1382     * Therefore, when the API workgroup size is smaller, we must make sure that
1383     * we don't run the API shader on more HW invocations than necessary.
1384     */
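   /* Illustrative example (hypothetical numbers): an API workgroup of 32
    * invocations whose outputs require a 128-lane HW workgroup means that
    * only the first 32 lanes may execute the API shader code.
    */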
1385    unsigned api_workgroup_size = shader->info.workgroup_size[0] *
1386                                  shader->info.workgroup_size[1] *
1387                                  shader->info.workgroup_size[2];
1388 
1389    lower_ngg_ms_state state = {
1390       .layout = layout,
1391       .wave_size = wave_size,
1392       .per_vertex_outputs = per_vertex_outputs,
1393       .per_primitive_outputs = per_primitive_outputs,
1394       .vertices_per_prim = vertices_per_prim,
1395       .api_workgroup_size = api_workgroup_size,
1396       .hw_workgroup_size = hw_workgroup_size,
1397       .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
1398       .uses_cull_flags = uses_cull,
1399       .hw_info = hw_info,
1400       .fast_launch_2 = fast_launch_2,
1401       .vert_multirow_export = fast_launch_2 && max_vertices > hw_workgroup_size,
1402       .prim_multirow_export = fast_launch_2 && max_primitives > hw_workgroup_size,
1403       .clipdist_enable_mask = clipdist_enable_mask,
1404       .vs_output_param_offset = vs_output_param_offset,
1405       .has_param_exports = has_param_exports,
1406       .has_query = has_query,
1407    };
1408 
1409    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1410    assert(impl);
1411 
1412    state.vertex_count_var =
1413       nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
1414    state.primitive_count_var =
1415       nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");
1416 
1417    nir_builder builder = nir_builder_at(nir_before_impl(impl));
1418    nir_builder *b = &builder; /* Shorthand to avoid writing &builder everywhere. */
1419 
1420    handle_smaller_ms_api_workgroup(b, &state);
1421    if (!fast_launch_2)
1422       ms_emit_legacy_workgroup_index(b, &state);
1423    ms_create_same_invocation_vars(b, &state);
1424    nir_metadata_preserve(impl, nir_metadata_none);
1425 
1426    lower_ms_intrinsics(shader, &state);
1427 
1428    emit_ms_finale(b, &state);
1429    nir_metadata_preserve(impl, nir_metadata_none);
1430 
1431    /* Cleanup */
1432    nir_lower_vars_to_ssa(shader);
1433    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1434    nir_lower_alu_to_scalar(shader, NULL, NULL);
1435    nir_lower_phis_to_scalar(shader, true);
1436 
1437    /* Optimize load_local_invocation_index. When the API workgroup is smaller than the HW workgroup,
1438     * local_invocation_id isn't initialized for all lanes, so this optimization can't be applied to
1439     * every use of load_local_invocation_index.
1440     */
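   /* Hypothetical example: a 64x1x1 API workgroup matching the HW workgroup
    * satisfies the check below (two of the three dimensions are 1), so
    * load_local_invocation_index can be rewritten in terms of
    * load_local_invocation_id.
    */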
1441    if (fast_launch_2 && api_workgroup_size == hw_workgroup_size &&
1442        ((shader->info.workgroup_size[0] == 1) + (shader->info.workgroup_size[1] == 1) +
1443         (shader->info.workgroup_size[2] == 1)) == 2) {
1444       nir_lower_compute_system_values_options csv_options = {
1445          .lower_local_invocation_index = true,
1446       };
1447       nir_lower_compute_system_values(shader, &csv_options);
1448    }
1449 
1450    nir_validate_shader(shader, "after emitting NGG MS");
1451 }
1452