/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "ac_nir_helpers.h"

#include "nir_builder.h"

#define SPECIAL_MS_OUT_MASK \
   (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \
    BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
    BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))

#define MS_PRIM_ARG_EXP_MASK \
   (VARYING_BIT_LAYER | \
    VARYING_BIT_VIEWPORT | \
    VARYING_BIT_PRIMITIVE_SHADING_RATE)

#define MS_VERT_ARG_EXP_MASK \
   (VARYING_BIT_CULL_DIST0 | \
    VARYING_BIT_CULL_DIST1 | \
    VARYING_BIT_CLIP_DIST0 | \
    VARYING_BIT_CLIP_DIST1 | \
    VARYING_BIT_PSIZ)

/* LDS layout of Mesh Shader workgroup info. */
enum {
   /* DW0: number of primitives */
   lds_ms_num_prims = 0,
   /* DW1: number of vertices */
   lds_ms_num_vtx = 4,
   /* DW2: workgroup index within the current dispatch */
   lds_ms_wg_index = 8,
   /* DW3: number of API workgroups in flight */
   lds_ms_num_api_waves = 12,
};

/* Potential location for Mesh Shader outputs. */
typedef enum {
   ms_out_mode_lds,
   ms_out_mode_scratch_ring,
   ms_out_mode_attr_ring,
   ms_out_mode_var,
} ms_out_mode;

typedef struct
{
   uint64_t mask; /* Mask of output locations */
   uint32_t addr; /* Base address */
} ms_out_part;

typedef struct
{
   /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
   struct {
      uint32_t workgroup_info_addr;
      ms_out_part vtx_attr;
      ms_out_part prm_attr;
      uint32_t indices_addr;
      uint32_t cull_flags_addr;
      uint32_t total_size;
   } lds;

   /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS.
    * Not to be confused with scratch memory.
    */
   struct {
      ms_out_part vtx_attr;
      ms_out_part prm_attr;
   } scratch_ring;

   /* VRAM attributes ring (GFX11 only) for all non-position outputs.
    * GFX11 doesn't have to reload attributes from this ring at the end of the shader.
    */
   struct {
      ms_out_part vtx_attr;
      ms_out_part prm_attr;
   } attr_ring;

   /* Outputs without cross-invocation access can be stored in variables. */
   struct {
      ms_out_part vtx_attr;
      ms_out_part prm_attr;
   } var;
} ms_out_mem_layout;
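
/* Illustrative sketch (hypothetical helper, not used by the lowering): how the
 * workgroup-info fields in the enum above are addressed in LDS. Each enum value
 * is the byte offset of one dword relative to lds.workgroup_info_addr, which is
 * computed by ms_calculate_output_layout below.
 */
static inline uint32_t
ms_workgroup_info_field_addr_sketch(const ms_out_mem_layout *layout, unsigned field_offset)
{
   /* e.g. field_offset == lds_ms_wg_index selects DW2 of the 16-byte block */
   return layout->lds.workgroup_info_addr + field_offset;
}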

typedef struct
{
   enum amd_gfx_level gfx_level;
   bool fast_launch_2;
   bool vert_multirow_export;
   bool prim_multirow_export;

   ms_out_mem_layout layout;
   uint64_t per_vertex_outputs;
   uint64_t per_primitive_outputs;
   unsigned vertices_per_prim;

   unsigned wave_size;
   unsigned api_workgroup_size;
   unsigned hw_workgroup_size;

   nir_def *workgroup_index;
   nir_variable *out_variables[VARYING_SLOT_MAX * 4];
   nir_variable *primitive_count_var;
   nir_variable *vertex_count_var;

   ac_nir_prerast_out out;

   /* True if the lowering needs to insert the layer output. */
   bool insert_layer_output;
   /* True if cull flags are used */
   bool uses_cull_flags;

   uint32_t clipdist_enable_mask;
   const uint8_t *vs_output_param_offset;
   bool has_param_exports;

   /* True if the lowering needs to insert shader query. */
   bool has_query;
} lower_ngg_ms_state;

static bool must_wait_attr_ring(enum amd_gfx_level gfx_level, bool has_param_exports)
{
   return (gfx_level == GFX11 || gfx_level == GFX11_5) && has_param_exports;
}

static void
ms_store_prim_indices(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      lower_ngg_ms_state *s)
{
   /* EXT_mesh_shader primitive indices: array of vectors.
    * They don't count as per-primitive outputs, but the array is indexed
    * by the primitive index, so they are practically per-primitive.
    */
   assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
   assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);

   const unsigned component_offset = nir_intrinsic_component(intrin);
   nir_def *store_val = intrin->src[0].ssa;
   assert(store_val->num_components <= 3);

   if (store_val->num_components > s->vertices_per_prim)
      store_val = nir_trim_vector(b, store_val, s->vertices_per_prim);

   if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
      for (unsigned c = 0; c < store_val->num_components; ++c) {
         const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset;
         nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1);
      }
      return;
   }

   nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
   nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);

   /* The max vertex count is 256, so these indices always fit 8 bits.
    * To reduce LDS use, store these as a flat array of 8-bit values.
    */
   nir_store_shared(b, nir_u2u8(b, store_val), offset, .base = s->layout.lds.indices_addr + component_offset);
}
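
/* Illustrative sketch (hypothetical helper, not used by the pass): byte offset
 * of one primitive index inside the flat 8-bit index array written by
 * ms_store_prim_indices, relative to layout.lds.indices_addr. Each primitive
 * owns vertices_per_prim consecutive bytes.
 */
static inline unsigned
ms_prim_index_lds_offset_sketch(unsigned prim_index, unsigned component, unsigned vertices_per_prim)
{
   return prim_index * vertices_per_prim + component;
}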

static void
ms_store_cull_flag(nir_builder *b,
                   nir_intrinsic_instr *intrin,
                   lower_ngg_ms_state *s)
{
   /* EXT_mesh_shader cull primitive: per-primitive bool. */
   assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
   assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
   assert(nir_intrinsic_component(intrin) == 0);
   assert(nir_intrinsic_write_mask(intrin) == 1);

   nir_def *store_val = intrin->src[0].ssa;

   assert(store_val->num_components == 1);
   assert(store_val->bit_size == 1);

   if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) {
      nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, store_val), 0x1);
      return;
   }

   nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
   nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);

   /* To reduce LDS use, store these as an array of 8-bit values. */
   nir_store_shared(b, nir_b2i8(b, store_val), offset, .base = s->layout.lds.cull_flags_addr);
}

static nir_def *
ms_arrayed_output_base_addr(nir_builder *b,
                            nir_def *arr_index,
                            unsigned mapped_location,
                            unsigned num_arrayed_outputs)
{
   /* Address offset of the array item (vertex or primitive). */
   unsigned arr_index_stride = num_arrayed_outputs * 16u;
   nir_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);

   /* IO address offset within the vertex or primitive data. */
   unsigned io_offset = mapped_location * 16u;
   nir_def *io_off = nir_imm_int(b, io_offset);

   return nir_iadd_nuw(b, arr_index_off, io_off);
}
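
/* Illustrative sketch (hypothetical helper): the same addressing as
 * ms_arrayed_output_base_addr, evaluated on the CPU. Every vertex/primitive
 * owns a contiguous block of num_arrayed_outputs 16-byte slots, and the
 * compacted mapped_location selects the slot within that block.
 */
static inline unsigned
ms_arrayed_output_base_addr_sketch(unsigned arr_index, unsigned mapped_location,
                                   unsigned num_arrayed_outputs)
{
   return arr_index * num_arrayed_outputs * 16u + mapped_location * 16u;
}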

static void
update_ms_output_info(const nir_io_semantics io_sem,
                      const nir_src *base_offset_src,
                      const uint32_t write_mask,
                      const unsigned component_offset,
                      const unsigned bit_size,
                      const ms_out_part *out,
                      lower_ngg_ms_state *s)
{
   const uint32_t components_mask = write_mask << component_offset;

   /* 64-bit outputs should have already been lowered to 32-bit. */
   assert(bit_size <= 32);
   assert(components_mask <= 0xf);

   /* When the base offset is constant, only mark the components of the current slot as used.
    * Otherwise, mark the components of all possibly affected slots as used.
    */
   const unsigned base_off_start = nir_src_is_const(*base_offset_src) ? nir_src_as_uint(*base_offset_src) : 0;
   const unsigned num_slots = nir_src_is_const(*base_offset_src) ? 1 : io_sem.num_slots;

   for (unsigned base_off = base_off_start; base_off < num_slots; ++base_off) {
      ac_nir_prerast_per_output_info *info = &s->out.infos[io_sem.location + base_off];
      info->components_mask |= components_mask;

      if (!io_sem.no_sysval_output)
         info->as_sysval_mask |= components_mask;
      if (!io_sem.no_varying)
         info->as_varying_mask |= components_mask;
   }
}

static const ms_out_part *
ms_get_out_layout_part(unsigned location,
                       shader_info *info,
                       ms_out_mode *out_mode,
                       lower_ngg_ms_state *s)
{
   uint64_t mask = BITFIELD64_BIT(location);

   if (info->per_primitive_outputs & mask) {
      if (mask & s->layout.lds.prm_attr.mask) {
         *out_mode = ms_out_mode_lds;
         return &s->layout.lds.prm_attr;
      } else if (mask & s->layout.scratch_ring.prm_attr.mask) {
         *out_mode = ms_out_mode_scratch_ring;
         return &s->layout.scratch_ring.prm_attr;
      } else if (mask & s->layout.attr_ring.prm_attr.mask) {
         *out_mode = ms_out_mode_attr_ring;
         return &s->layout.attr_ring.prm_attr;
      } else if (mask & s->layout.var.prm_attr.mask) {
         *out_mode = ms_out_mode_var;
         return &s->layout.var.prm_attr;
      }
   } else {
      if (mask & s->layout.lds.vtx_attr.mask) {
         *out_mode = ms_out_mode_lds;
         return &s->layout.lds.vtx_attr;
      } else if (mask & s->layout.scratch_ring.vtx_attr.mask) {
         *out_mode = ms_out_mode_scratch_ring;
         return &s->layout.scratch_ring.vtx_attr;
      } else if (mask & s->layout.attr_ring.vtx_attr.mask) {
         *out_mode = ms_out_mode_attr_ring;
         return &s->layout.attr_ring.vtx_attr;
      } else if (mask & s->layout.var.vtx_attr.mask) {
         *out_mode = ms_out_mode_var;
         return &s->layout.var.vtx_attr;
      }
   }

   unreachable("Couldn't figure out mesh shader output mode.");
}

static void
ms_store_arrayed_output(nir_builder *b,
                        nir_src *base_off_src,
                        nir_def *store_val,
                        nir_def *arr_index,
                        const nir_io_semantics io_sem,
                        const unsigned component_offset,
                        const unsigned write_mask,
                        lower_ngg_ms_state *s)
{
   ms_out_mode out_mode;
   const ms_out_part *out = ms_get_out_layout_part(io_sem.location, &b->shader->info, &out_mode, s);
   update_ms_output_info(io_sem, base_off_src, write_mask, component_offset, store_val->bit_size, out, s);

   bool hi_16b = io_sem.high_16bits;
   bool lo_16b = !hi_16b && store_val->bit_size == 16;

   unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, io_sem.location));
   unsigned num_outputs = util_bitcount64(out->mask);
   unsigned const_off = out->addr + component_offset * 4 + (hi_16b ? 2 : 0);

   nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
   nir_def *base_offset = base_off_src->ssa;
   nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
   nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);

   if (out_mode == ms_out_mode_lds) {
      nir_store_shared(b, store_val, addr, .base = const_off,
                     .write_mask = write_mask, .align_mul = 16,
                     .align_offset = const_off % 16);
   } else if (out_mode == ms_out_mode_scratch_ring) {
      nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
      nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
      nir_def *zero = nir_imm_int(b, 0);
      nir_store_buffer_amd(b, store_val, ring, addr, off, zero,
                           .base = const_off,
                           .write_mask = write_mask,
                           .memory_modes = nir_var_shader_out,
                           .access = ACCESS_COHERENT);
   } else if (out_mode == ms_out_mode_attr_ring) {
      /* GFX11+: Store params straight to the attribute ring.
       *
       * Even though the access pattern may not be the most optimal,
       * this is still much better than reserving LDS and losing waves.
       * (Also much better than storing and reloading from the scratch ring.)
       */
      unsigned param_offset = s->vs_output_param_offset[io_sem.location];
      nir_def *ring = nir_load_ring_attr_amd(b);
      nir_def *soffset = nir_load_ring_attr_offset_amd(b);
      nir_store_buffer_amd(b, store_val, ring, base_addr_off, soffset, arr_index,
                           .base = const_off + param_offset * 16,
                           .write_mask = write_mask,
                           .memory_modes = nir_var_shader_out,
                           .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
   } else if (out_mode == ms_out_mode_var) {
      unsigned write_mask_32 = write_mask;
      if (store_val->bit_size > 32) {
         /* Split 64-bit store values to 32-bit components. */
         store_val = nir_bitcast_vector(b, store_val, 32);
         /* Widen the write mask so it is in 32-bit components. */
         write_mask_32 = util_widen_mask(write_mask, store_val->bit_size / 32);
      }

      u_foreach_bit(comp, write_mask_32) {
         unsigned idx = io_sem.location * 4 + comp + component_offset;
         nir_def *val = nir_channel(b, store_val, comp);
         nir_def *v = nir_load_var(b, s->out_variables[idx]);

         if (lo_16b) {
            nir_def *var_hi = nir_unpack_32_2x16_split_y(b, v);
            val = nir_pack_32_2x16_split(b, val, var_hi);
         } else if (hi_16b) {
            nir_def *var_lo = nir_unpack_32_2x16_split_x(b, v);
            val = nir_pack_32_2x16_split(b, var_lo, val);
         }

         nir_store_var(b, s->out_variables[idx], val, 0x1);
      }
   } else {
      unreachable("Invalid MS output mode for store");
   }
}

static void
ms_store_arrayed_output_intrin(nir_builder *b,
                               nir_intrinsic_instr *intrin,
                               lower_ngg_ms_state *s)
{
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);

   if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
      ms_store_prim_indices(b, intrin, s);
      return;
   } else if (io_sem.location == VARYING_SLOT_CULL_PRIMITIVE) {
      ms_store_cull_flag(b, intrin, s);
      return;
   }

   unsigned component_offset = nir_intrinsic_component(intrin);
   unsigned write_mask = nir_intrinsic_write_mask(intrin);

   nir_def *store_val = intrin->src[0].ssa;
   nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
   nir_src *base_off_src = nir_get_io_offset_src(intrin);

   if (store_val->bit_size < 32) {
      /* Split 16-bit output stores to ensure each 16-bit component is stored
       * in the correct location, without overwriting the other 16 bits there.
       */
      u_foreach_bit(c, write_mask) {
         nir_def *store_component = nir_channel(b, store_val, c);
         ms_store_arrayed_output(b, base_off_src, store_component, arr_index, io_sem, c + component_offset, 1, s);
      }
   } else {
      ms_store_arrayed_output(b, base_off_src, store_val, arr_index, io_sem, component_offset, write_mask, s);
   }
}

static nir_def *
ms_load_arrayed_output(nir_builder *b,
                       nir_def *arr_index,
                       nir_def *base_offset,
                       unsigned location,
                       unsigned component_offset,
                       unsigned num_components,
                       unsigned load_bit_size,
                       lower_ngg_ms_state *s)
{
   ms_out_mode out_mode;
   const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);

   unsigned component_addr_off = component_offset * 4;
   unsigned num_outputs = util_bitcount64(out->mask);
   unsigned const_off = out->addr + component_offset * 4;

   /* Use compacted location instead of the original semantic location. */
   unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));

   nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
   nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
   nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);

   if (out_mode == ms_out_mode_lds) {
      return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
                             .align_offset = component_addr_off % 16,
                             .base = const_off);
   } else if (out_mode == ms_out_mode_scratch_ring) {
      nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
      nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
      nir_def *zero = nir_imm_int(b, 0);
      return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, zero,
                                 .base = const_off,
                                 .memory_modes = nir_var_shader_out,
                                 .access = ACCESS_COHERENT);
   } else if (out_mode == ms_out_mode_var) {
      assert(load_bit_size == 32);
      nir_def *arr[8] = {0};
      for (unsigned comp = 0; comp < num_components; ++comp) {
         unsigned idx = location * 4 + comp + component_addr_off;
         arr[comp] = nir_load_var(b, s->out_variables[idx]);
      }
      return nir_vec(b, arr, num_components);
   } else {
      unreachable("Invalid MS output mode for load");
   }
}

static nir_def *
lower_ms_load_workgroup_index(nir_builder *b,
                              UNUSED nir_intrinsic_instr *intrin,
                              lower_ngg_ms_state *s)
{
   return s->workgroup_index;
}

static nir_def *
lower_ms_set_vertex_and_primitive_count(nir_builder *b,
                                        nir_intrinsic_instr *intrin,
                                        lower_ngg_ms_state *s)
{
   /* If either the number of vertices or primitives is zero, set both of them to zero. */
   nir_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
   nir_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
   nir_def *zero = nir_imm_int(b, 0);
   nir_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
   num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
   num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);

   nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
   nir_store_var(b, s->primitive_count_var, num_prm, 0x1);

   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
}

static nir_def *
update_ms_barrier(nir_builder *b,
                  nir_intrinsic_instr *intrin,
                  lower_ngg_ms_state *s)
{
   /* Output loads and stores are lowered to shared memory access,
    * so we have to update the barriers to also reflect this.
    */
   unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
   if (mem_modes & nir_var_shader_out)
      mem_modes |= nir_var_mem_shared;
   else
      return NULL;

   nir_intrinsic_set_memory_modes(intrin, mem_modes);

   return NIR_LOWER_INSTR_PROGRESS;
}

static nir_def *
lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
{
   lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;

   if (instr->type != nir_instr_type_intrinsic)
      return NULL;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_store_per_vertex_output:
   case nir_intrinsic_store_per_primitive_output:
      ms_store_arrayed_output_intrin(b, intrin, s);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   case nir_intrinsic_barrier:
      return update_ms_barrier(b, intrin, s);
   case nir_intrinsic_load_workgroup_index:
      return lower_ms_load_workgroup_index(b, intrin, s);
   case nir_intrinsic_set_vertex_and_primitive_count:
      return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
   default:
      unreachable("Not a lowerable mesh shader intrinsic.");
   }
}

static bool
filter_ms_intrinsic(const nir_instr *instr,
                    UNUSED const void *s)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_store_output ||
          intrin->intrinsic == nir_intrinsic_load_output ||
          intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
          intrin->intrinsic == nir_intrinsic_barrier ||
          intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
          intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
}

static void
lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
{
   nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
}

static void
ms_emit_arrayed_outputs(nir_builder *b,
                        nir_def *invocation_index,
                        uint64_t mask,
                        lower_ngg_ms_state *s)
{
   nir_def *zero = nir_imm_int(b, 0);

   u_foreach_bit64(slot, mask) {
      /* Should not occur here, handled separately. */
      assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);

      unsigned component_mask = s->out.infos[slot].components_mask;

      while (component_mask) {
         int start_comp = 0, num_components = 1;
         u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);

         nir_def *load =
            ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
                                   num_components, 32, s);

         for (int i = 0; i < num_components; i++)
            s->out.outputs[slot][start_comp + i] = nir_channel(b, load, i);
      }
   }
}

static void
ms_create_same_invocation_vars(nir_builder *b, lower_ngg_ms_state *s)
{
   /* Initialize NIR variables for same-invocation outputs. */
   uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;

   u_foreach_bit64(slot, same_invocation_output_mask) {
      for (unsigned comp = 0; comp < 4; ++comp) {
         unsigned idx = slot * 4 + comp;
         s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
      }
   }
}

static void
ms_emit_legacy_workgroup_index(nir_builder *b, lower_ngg_ms_state *s)
{
   /* Workgroup ID should have been lowered to workgroup index. */
   assert(!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID));

   /* No need to do anything if the shader doesn't use the workgroup index. */
   if (!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX))
      return;

   b->cursor = nir_before_impl(b->impl);

   /* Legacy fast launch mode (FAST_LAUNCH=1):
    *
    * The HW doesn't support a proper workgroup index for vertex processing stages,
    * so we use the vertex ID which is equivalent to the index of the current workgroup
    * within the current dispatch.
    *
    * Due to the register programming of mesh shaders, this value is only filled for
    * the first invocation of the first wave. To let other waves know, we use LDS.
    */
   nir_def *workgroup_index = nir_load_vertex_id_zero_base(b);

   if (s->api_workgroup_size <= s->wave_size) {
      /* API workgroup is small, so we don't need to use LDS. */
      s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
      return;
   }

   unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;

   nir_def *zero = nir_imm_int(b, 0);
   nir_def *dont_care = nir_undef(b, 1, 32);
   nir_def *loaded_workgroup_index = NULL;

   /* Use elect to make sure only 1 invocation uses LDS. */
   nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
   {
      nir_def *wave_id = nir_load_subgroup_id(b);
      nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
      {
         nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
         nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                               .memory_scope = SCOPE_WORKGROUP,
                               .memory_semantics = NIR_MEMORY_ACQ_REL,
                               .memory_modes = nir_var_mem_shared);
      }
      nir_push_else(b, if_wave_0);
      {
         nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                               .memory_scope = SCOPE_WORKGROUP,
                               .memory_semantics = NIR_MEMORY_ACQ_REL,
                               .memory_modes = nir_var_mem_shared);
         loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
      }
      nir_pop_if(b, if_wave_0);

      workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
   }
   nir_pop_if(b, if_elected);

   workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
   s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
}

static void
set_ms_final_output_counts(nir_builder *b,
                           lower_ngg_ms_state *s,
                           nir_def **out_num_prm,
                           nir_def **out_num_vtx)
{
   /* The spec allows the numbers to be divergent, and in that case we need to
    * use the values from the first invocation. Also the HW requires us to set
    * both to 0 if either was 0.
    *
    * These are already done by the lowering.
    */
   nir_def *num_prm = nir_load_var(b, s->primitive_count_var);
   nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);

   if (s->hw_workgroup_size <= s->wave_size) {
      /* Single-wave mesh shader workgroup. */
      nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prm, 12), num_vtx);
      nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ);

      *out_num_prm = num_prm;
      *out_num_vtx = num_vtx;
      return;
   }

   /* Multi-wave mesh shader workgroup:
    * We need to use LDS to distribute the correct values to the other waves.
    *
    * TODO:
    * If we can prove that the values are workgroup-uniform, we can skip this
    * and just use whatever the current wave has. However, NIR divergence analysis
    * currently doesn't support this.
    */

   nir_def *zero = nir_imm_int(b, 0);

   nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
   {
      nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
      {
         nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
                          .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
      }
      nir_pop_if(b, if_elected);

      nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                            .memory_scope = SCOPE_WORKGROUP,
                            .memory_semantics = NIR_MEMORY_ACQ_REL,
                            .memory_modes = nir_var_mem_shared);

      nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prm, 12), num_vtx);
      nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ);
   }
   nir_push_else(b, if_wave_0);
   {
      nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                            .memory_scope = SCOPE_WORKGROUP,
                            .memory_semantics = NIR_MEMORY_ACQ_REL,
                            .memory_modes = nir_var_mem_shared);

      nir_def *prm_vtx = NULL;
      nir_def *dont_care_2x32 = nir_undef(b, 2, 32);
      nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
      {
         prm_vtx = nir_load_shared(b, 2, 32, zero,
                                   .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
      }
      nir_pop_if(b, if_elected);

      prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
      num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
      num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));

      nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
      nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
   }
   nir_pop_if(b, if_wave_0);

   *out_num_prm = nir_load_var(b, s->primitive_count_var);
   *out_num_vtx = nir_load_var(b, s->vertex_count_var);
}
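
/* Illustrative sketch (hypothetical helper): the m0 payload layout of the
 * GS_ALLOC_REQ message sent above, with the vertex count in the low bits and
 * the primitive count starting at bit 12.
 */
static inline uint32_t
ms_gs_alloc_req_m0_sketch(uint32_t num_prm, uint32_t num_vtx)
{
   return (num_prm << 12) | num_vtx;
}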

static void
ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask,
                                     nir_def *idx, lower_ngg_ms_state *s)
{
   if (!outputs_mask)
      return;

   nir_def *ring = nir_load_ring_attr_amd(b);
   nir_def *off = nir_load_ring_attr_offset_amd(b);
   nir_def *zero = nir_imm_int(b, 0);

   u_foreach_bit64 (slot, outputs_mask) {
      if (s->vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31)
         continue;

      nir_def *soffset = nir_iadd_imm(b, off, s->vs_output_param_offset[slot] * 16 * 32);
      nir_def *store_val = nir_undef(b, 4, 32);
      unsigned store_val_components = 0;
      for (unsigned c = 0; c < 4; ++c) {
         if (s->out.outputs[slot][c]) {
            store_val = nir_vector_insert_imm(b, store_val, s->out.outputs[slot][c], c);
            store_val_components = c + 1;
         }
      }

      store_val = nir_trim_vector(b, store_val, store_val_components);
      nir_store_buffer_amd(b, store_val, ring, zero, soffset, idx,
                           .memory_modes = nir_var_shader_out,
                           .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
   }
}
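
/* Illustrative sketch (hypothetical helper, interpretation of the constant
 * used above): the soffset added per exported parameter. Each parameter is
 * assumed to occupy one 32-lane row of 16-byte vec4 slots in the GFX11
 * attribute ring, hence the 16 * 32 stride.
 */
static inline unsigned
ms_attr_ring_param_offset_sketch(unsigned vs_output_param_offset)
{
   return vs_output_param_offset * 16 * 32;
}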

static nir_def *
ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s)
{
   /* Primitive connectivity data: describes which vertices the primitive uses. */
   nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
   nir_def *indices_loaded = NULL;
   nir_def *cull_flag = NULL;

   if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
      nir_def *indices[3] = {0};
      for (unsigned c = 0; c < s->vertices_per_prim; ++c)
         indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
      indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
   } else {
      indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
      indices_loaded = nir_u2u32(b, indices_loaded);
   }

   if (s->uses_cull_flags) {
      nir_def *loaded_cull_flag = NULL;
      if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
         loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
      else
         loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));

      cull_flag = nir_i2b(b, loaded_cull_flag);
   }

   nir_def *indices[3];
   nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);

   for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
      indices[i] = nir_channel(b, indices_loaded, i);
      indices[i] = nir_umin(b, indices[i], max_vtx_idx);
   }

   return ac_nir_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, s->gfx_level);
}

static nir_def *
ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s)
{
   nir_def *prim_exp_arg_ch2 = NULL;

   if (outputs_mask) {
      /* When layer, viewport etc. are per-primitive, they need to be encoded in
       * the primitive export instruction's second channel. The encoding is:
       *
       * --- GFX10.3 ---
       * bits 31..30: VRS rate Y
       * bits 29..28: VRS rate X
       * bits 23..20: viewport
       * bits 19..17: layer
       *
       * --- GFX11 ---
       * bits 31..28: VRS rate enum
       * bits 23..20: viewport
       * bits 12..00: layer
       */
      prim_exp_arg_ch2 = nir_imm_int(b, 0);

      if (outputs_mask & VARYING_BIT_LAYER) {
         nir_def *layer =
            nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_LAYER][0], s->gfx_level >= GFX11 ? 0 : 17);
         prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer);
      }

      if (outputs_mask & VARYING_BIT_VIEWPORT) {
         nir_def *view = nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_VIEWPORT][0], 20);
         prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view);
      }

      if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) {
         nir_def *rate = s->out.outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0];
         prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate);
      }
   }

   return prim_exp_arg_ch2;
}
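
/* Illustrative sketch (hypothetical helper): CPU-side packing of the layer and
 * viewport fields of the second primitive export channel, following the bit
 * layouts documented in ms_prim_exp_arg_ch2 above (VRS rate omitted).
 */
static inline uint32_t
ms_pack_prim_exp_arg_ch2_sketch(enum amd_gfx_level gfx_level, uint32_t layer, uint32_t viewport)
{
   const unsigned layer_shift = gfx_level >= GFX11 ? 0 : 17;
   return (layer << layer_shift) | (viewport << 20);
}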

static void
ms_prim_gen_query(nir_builder *b,
                  nir_def *invocation_index,
                  nir_def *num_prm,
                  lower_ngg_ms_state *s)
{
   if (!s->has_query)
      return;

   nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
   {
      nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
      {
         nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0);
      }
      nir_pop_if(b, if_shader_query);
   }
   nir_pop_if(b, if_invocation_index_zero);
}

static void
ms_invocation_query(nir_builder *b,
                    nir_def *invocation_index,
                    lower_ngg_ms_state *s)
{
   if (!s->has_query)
      return;

   nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
   {
      nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b));
      {
         nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size));
      }
      nir_pop_if(b, if_pipeline_query);
   }
   nir_pop_if(b, if_invocation_index_zero);
}

static void
emit_ms_vertex(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
               uint64_t per_vertex_outputs, lower_ngg_ms_state *s)
{
   ms_emit_arrayed_outputs(b, index, per_vertex_outputs, s);

   if (exports) {
      ac_nir_export_position(b, s->gfx_level, s->clipdist_enable_mask,
                             !s->has_param_exports, false, true,
                             s->per_vertex_outputs | VARYING_BIT_POS, &s->out, row);
   }

   if (parameters) {
      /* Export generic attributes on GFX10.3
       * (On GFX11 they are already stored in the attribute ring.)
       */
      if (s->has_param_exports && s->gfx_level == GFX10_3) {
         ac_nir_export_parameters(b, s->vs_output_param_offset, per_vertex_outputs, 0, &s->out);
      }

      /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
      if (s->gfx_level >= GFX11 && (per_vertex_outputs & MS_VERT_ARG_EXP_MASK))
         ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & MS_VERT_ARG_EXP_MASK, index, s);
   }
}

static void
emit_ms_primitive(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
                  uint64_t per_primitive_outputs, lower_ngg_ms_state *s)
{
   ms_emit_arrayed_outputs(b, index, per_primitive_outputs, s);

   /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
   if (s->insert_layer_output) {
      s->out.outputs[VARYING_SLOT_LAYER][0] = nir_load_view_index(b);
      s->out.infos[VARYING_SLOT_LAYER].as_sysval_mask |= 1;
   }

   if (exports) {
      const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK;
      nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
      nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, index, num_vtx, s);
      nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s);

      nir_def *prim_exp_arg = prim_exp_arg_ch2 ?
         nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1;

      ac_nir_export_primitive(b, prim_exp_arg, row);
   }

   if (parameters) {
      /* Export generic attributes on GFX10.3
       * (On GFX11 they are already stored in the attribute ring.)
       */
      if (s->has_param_exports && s->gfx_level == GFX10_3) {
         ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0, &s->out);
      }

      /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
      if (s->gfx_level >= GFX11)
         ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, index, s);
   }
}

static void
emit_ms_outputs(nir_builder *b, nir_def *invocation_index, nir_def *row_start,
                nir_def *count, bool exports, bool parameters, uint64_t mask,
                void (*cb)(nir_builder *, nir_def *, nir_def *, bool, bool,
                           uint64_t, lower_ngg_ms_state *),
                lower_ngg_ms_state *s)
{
   if (cb == &emit_ms_primitive ? s->prim_multirow_export : s->vert_multirow_export) {
      assert(s->hw_workgroup_size % s->wave_size == 0);
      const unsigned num_waves = s->hw_workgroup_size / s->wave_size;

      nir_loop *row_loop = nir_push_loop(b);
      {
         nir_block *preheader = nir_cf_node_as_block(nir_cf_node_prev(&row_loop->cf_node));

         nir_phi_instr *index = nir_phi_instr_create(b->shader);
         nir_phi_instr *row = nir_phi_instr_create(b->shader);
         nir_def_init(&index->instr, &index->def, 1, 32);
         nir_def_init(&row->instr, &row->def, 1, 32);

         nir_phi_instr_add_src(index, preheader, invocation_index);
         nir_phi_instr_add_src(row, preheader, row_start);

         nir_if *if_break = nir_push_if(b, nir_uge(b, &index->def, count));
         {
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, if_break);

         cb(b, &index->def, &row->def, exports, parameters, mask, s);

         nir_block *body = nir_cursor_current_block(b->cursor);
         nir_phi_instr_add_src(index, body,
                               nir_iadd_imm(b, &index->def, s->hw_workgroup_size));
         nir_phi_instr_add_src(row, body,
                               nir_iadd_imm(b, &row->def, num_waves));

         nir_instr_insert_before_cf_list(&row_loop->body, &row->instr);
         nir_instr_insert_before_cf_list(&row_loop->body, &index->instr);
      }
      nir_pop_loop(b, row_loop);
   } else {
      nir_def *has_output = nir_ilt(b, invocation_index, count);
      nir_if *if_has_output = nir_push_if(b, has_output);
      {
         cb(b, invocation_index, row_start, exports, parameters, mask, s);
      }
      nir_pop_if(b, if_has_output);
   }
}

static void
emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
{
   /* We assume there is always a single end block in the shader. */
   nir_block *last_block = nir_impl_last_block(b->impl);
   b->cursor = nir_after_block(last_block);

   nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);

   nir_def *num_prm;
   nir_def *num_vtx;

   set_ms_final_output_counts(b, s, &num_prm, &num_vtx);

   nir_def *invocation_index = nir_load_local_invocation_index(b);

   ms_prim_gen_query(b, invocation_index, num_prm, s);

   nir_def *row_start = NULL;
   if (s->fast_launch_2)
      row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b);

   /* Load vertex/primitive attributes from shared memory and
    * emit store_output intrinsics for them.
    *
    * Contrary to the semantics of the API mesh shader, these are now
    * compliant with NGG HW semantics, meaning that these store the
    * current thread's vertex attributes in a way the HW can export.
    */

   uint64_t per_vertex_outputs =
      s->per_vertex_outputs & ~s->layout.attr_ring.vtx_attr.mask;
   uint64_t per_primitive_outputs =
      s->per_primitive_outputs & ~s->layout.attr_ring.prm_attr.mask & ~SPECIAL_MS_OUT_MASK;

   /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
   if (s->insert_layer_output) {
      b->shader->info.outputs_written |= VARYING_BIT_LAYER;
      b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
      per_primitive_outputs |= VARYING_BIT_LAYER;
   }

   const bool has_special_param_exports =
      (per_vertex_outputs & MS_VERT_ARG_EXP_MASK) ||
      (per_primitive_outputs & MS_PRIM_ARG_EXP_MASK);

   const bool wait_attr_ring = must_wait_attr_ring(s->gfx_level, has_special_param_exports);

   /* Export vertices. */
   if ((per_vertex_outputs & ~VARYING_BIT_POS) || !wait_attr_ring) {
      emit_ms_outputs(b, invocation_index, row_start, num_vtx, !wait_attr_ring, true,
                      per_vertex_outputs, &emit_ms_vertex, s);
   }

   /* Export primitives. */
   if (per_primitive_outputs || !wait_attr_ring) {
      emit_ms_outputs(b, invocation_index, row_start, num_prm, !wait_attr_ring, true,
                      per_primitive_outputs, &emit_ms_primitive, s);
   }

   /* When we need to wait for attribute ring stores, we emit both position and primitive
    * export instructions after a barrier to make sure both per-vertex and per-primitive
    * attribute ring stores are finished before the GPU starts rasterization.
    */
   if (wait_attr_ring) {
      /* Wait for attribute stores to finish. */
      nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
                     .memory_scope = SCOPE_DEVICE,
                     .memory_semantics = NIR_MEMORY_RELEASE,
                     .memory_modes = nir_var_shader_out);

      /* Position/primitive export only */
      emit_ms_outputs(b, invocation_index, row_start, num_vtx, true, false,
                      per_vertex_outputs, &emit_ms_vertex, s);
      emit_ms_outputs(b, invocation_index, row_start, num_prm, true, false,
                      per_primitive_outputs, &emit_ms_primitive, s);
   }
}

static void
handle_smaller_ms_api_workgroup(nir_builder *b,
                                lower_ngg_ms_state *s)
{
   if (s->api_workgroup_size >= s->hw_workgroup_size)
      return;

   /* Handle barriers manually when the API workgroup
    * size is less than the HW workgroup size.
    *
    * The problem is that the real workgroup launched on NGG HW
    * will be larger than the size specified by the API, and the
    * extra waves need to keep up with barriers in the API waves.
    *
    * There are 2 different cases:
    * 1. The whole API workgroup fits in a single wave.
    *    We can shrink the barriers to subgroup scope and
    *    don't need to insert any extra ones.
    * 2. The API workgroup occupies multiple waves, but not
    *    all. In this case, we emit code that consumes every
    *    barrier on the extra waves.
    */
   assert(s->hw_workgroup_size % s->wave_size == 0);
   bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
   bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
   bool need_additional_barriers = scan_barriers && !can_shrink_barriers;

   unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
   unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);

   /* Scan the shader for workgroup barriers. */
   if (scan_barriers) {
      bool has_any_workgroup_barriers = false;

      nir_foreach_block(block, b->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            bool is_workgroup_barrier =
               intrin->intrinsic == nir_intrinsic_barrier &&
               nir_intrinsic_execution_scope(intrin) == SCOPE_WORKGROUP;

            if (!is_workgroup_barrier)
               continue;

            if (can_shrink_barriers) {
               /* Every API invocation runs in the first wave.
                * In this case, we can change the barriers to subgroup scope
                * and avoid adding additional barriers.
                */
               nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP);
               nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP);
            } else {
               has_any_workgroup_barriers = true;
            }
         }
      }

      need_additional_barriers &= has_any_workgroup_barriers;
   }

   /* Extract the full control flow of the shader. */
   nir_cf_list extracted;
   nir_cf_extract(&extracted, nir_before_impl(b->impl),
                  nir_after_cf_list(&b->impl->body));
   b->cursor = nir_before_impl(b->impl);

   /* Wrap the shader in an if to ensure that only the necessary number of lanes run it. */
   nir_def *invocation_index = nir_load_local_invocation_index(b);
   nir_def *zero = nir_imm_int(b, 0);

   if (need_additional_barriers) {
      /* The first invocation stores the number of API waves in flight. */
      nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
      {
         nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
      }
      nir_pop_if(b, if_first_in_workgroup);

      nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                            .memory_scope = SCOPE_WORKGROUP,
                            .memory_semantics = NIR_MEMORY_ACQ_REL,
                            .memory_modes = nir_var_shader_out | nir_var_mem_shared);
   }

   nir_def *has_api_ms_invocation = nir_ult_imm(b, invocation_index, s->api_workgroup_size);
   nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
   {
      nir_cf_reinsert(&extracted, b->cursor);
      b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);

      if (need_additional_barriers) {
         /* One invocation in each API wave decrements the number of API waves in flight. */
         nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
         {
            nir_shared_atomic(b, 32, zero, nir_imm_int(b, -1u),
                              .base = api_waves_in_flight_addr,
                              .atomic_op = nir_atomic_op_iadd);
         }
         nir_pop_if(b, if_elected_again);

         nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                               .memory_scope = SCOPE_WORKGROUP,
                               .memory_semantics = NIR_MEMORY_ACQ_REL,
                               .memory_modes = nir_var_shader_out | nir_var_mem_shared);
      }

      ms_invocation_query(b, invocation_index, s);
   }
   nir_pop_if(b, if_has_api_ms_invocation);

   if (need_additional_barriers) {
      /* Make sure that waves that don't run any API invocations execute
       * the same amount of barriers as those that do.
       *
       * We do this by executing a barrier until the number of API waves
       * in flight becomes zero.
       */
      nir_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
      nir_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
      nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
      {
         nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
         {
            nir_loop *loop = nir_push_loop(b);
            {
               nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                                     .memory_scope = SCOPE_WORKGROUP,
                                     .memory_semantics = NIR_MEMORY_ACQ_REL,
                                     .memory_modes = nir_var_shader_out | nir_var_mem_shared);

               nir_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
               nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
               {
                  nir_jump(b, nir_jump_break);
               }
               nir_pop_if(b, if_break);
            }
            nir_pop_loop(b, loop);
         }
         nir_pop_if(b, if_elected);
      }
      nir_pop_if(b, if_wave_has_no_api_ms);
   }
}

static void
ms_move_output(ms_out_part *from, ms_out_part *to)
{
   uint64_t loc = util_logbase2_64(from->mask);
   uint64_t bit = BITFIELD64_BIT(loc);
   from->mask ^= bit;
   to->mask |= bit;
}

static void
ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
                                   unsigned max_vertices,
                                   unsigned max_primitives)
{
   uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
   uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
   l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
   l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;

   uint32_t scratch_ring_vtx_attr_size =
      util_bitcount64(l->scratch_ring.vtx_attr.mask) * max_vertices * 16;
   l->scratch_ring.prm_attr.addr =
      ALIGN(l->scratch_ring.vtx_attr.addr + scratch_ring_vtx_attr_size, 16);
}
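
/* Illustrative sketch (hypothetical helper): the per-part size rule used above.
 * Every output slot in the mask takes one 16-byte vec4 per vertex or primitive,
 * so a part needs bitcount(mask) * max_items * 16 bytes before alignment.
 */
static inline uint32_t
ms_arrayed_part_size_sketch(uint64_t out_mask, unsigned max_items)
{
   return util_bitcount64(out_mask) * max_items * 16u;
}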

static ms_out_mem_layout
ms_calculate_output_layout(enum amd_gfx_level gfx_level, unsigned api_shared_size,
                           uint64_t per_vertex_output_mask, uint64_t per_primitive_output_mask,
                           uint64_t cross_invocation_output_access, unsigned max_vertices,
                           unsigned max_primitives, unsigned vertices_per_prim)
{
   /* These outputs always need export instructions and can't use the attributes ring. */
   const uint64_t always_export_mask =
      VARYING_BIT_POS | VARYING_BIT_CULL_DIST0 | VARYING_BIT_CULL_DIST1 | VARYING_BIT_CLIP_DIST0 |
      VARYING_BIT_CLIP_DIST1 | VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT |
      VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_LAYER |
      BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) |
      BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);

   const bool use_attr_ring = gfx_level >= GFX11;
   const uint64_t attr_ring_per_vertex_output_mask =
      use_attr_ring ? per_vertex_output_mask & ~always_export_mask : 0;
   const uint64_t attr_ring_per_primitive_output_mask =
      use_attr_ring ? per_primitive_output_mask & ~always_export_mask : 0;

   const uint64_t lds_per_vertex_output_mask =
      per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & cross_invocation_output_access &
      ~SPECIAL_MS_OUT_MASK;
   const uint64_t lds_per_primitive_output_mask =
      per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
      cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;

   const bool cross_invocation_indices =
      cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
   const bool cross_invocation_cull_primitive =
      cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);

   /* Shared memory used by the API shader. */
   ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };

   /* GFX11+: use attribute ring for all generic attributes. */
   l.attr_ring.vtx_attr.mask = attr_ring_per_vertex_output_mask;
   l.attr_ring.prm_attr.mask = attr_ring_per_primitive_output_mask;

   /* Outputs without cross-invocation access can be stored in variables. */
   l.var.vtx_attr.mask =
      per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & ~cross_invocation_output_access;
   l.var.prm_attr.mask = per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
                         ~cross_invocation_output_access;

1301    /* Workgroup information, see ms_workgroup_* for the layout. */
1302    l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
1303    l.lds.total_size = l.lds.workgroup_info_addr + 16;
1304 
1305    /* Per-vertex and per-primitive output attributes.
1306     * Outputs without cross-invocation access are not included here.
1307     * First, try to put all outputs into LDS (shared memory).
1308     * If they don't fit, try to move them to VRAM one by one.
1309     */
1310    l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
1311    l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
1312    l.lds.prm_attr.mask = lds_per_primitive_output_mask;
1313    ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
1314 
   /* NGG shaders can only address up to 32K of LDS memory.
    * The spec requires us to allow the application to use at least 28K of
    * shared memory. Additionally, we reserve 2K for driver-internal use
    * (e.g. the primitive indices and cull flags, see below).
    *
    * Move the outputs that don't fit into LDS to VRAM.
    * Start with per-primitive attributes, because those are grouped at the end.
    */
   const unsigned usable_lds_kbytes =
      (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31;
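   /* Spill until the arrayed attributes fit: the limit keeps 2K of LDS in
    * reserve when the primitive index or cull flag arrays (allocated below)
    * also need LDS space, and 1K otherwise.
    */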
   while (l.lds.total_size >= usable_lds_kbytes * 1024) {
      if (l.lds.prm_attr.mask)
         ms_move_output(&l.lds.prm_attr, &l.scratch_ring.prm_attr);
      else if (l.lds.vtx_attr.mask)
         ms_move_output(&l.lds.vtx_attr, &l.scratch_ring.vtx_attr);
      else
         unreachable("API shader uses too much shared memory.");

      ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
   }

   if (cross_invocation_indices) {
      /* Indices: flat array of 8-bit vertex indices for each primitive. */
      l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
      l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
   }

   if (cross_invocation_cull_primitive) {
      /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
      l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
      l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
   }

   /* NGG is only allowed to address up to 32K of LDS. */
   assert(l.lds.total_size <= 32 * 1024);
   return l;
}

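/* Entry point of the NGG lowering for mesh shaders.
 *
 * Calculates the output memory layout, lowers the mesh shader intrinsics
 * accordingly, and emits the epilogue (emit_ms_finale).
 *
 * Illustrative call from a driver (the argument values below are made-up
 * examples, not requirements):
 *
 *    bool needs_scratch_ring = false;
 *    ac_nir_lower_ngg_mesh(nir, GFX11, clipdist_enable_mask,
 *                          vs_output_param_offset, true, &needs_scratch_ring,
 *                          64, 128, false, true, false);
 */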
void
ac_nir_lower_ngg_mesh(nir_shader *shader,
                      enum amd_gfx_level gfx_level,
                      uint32_t clipdist_enable_mask,
                      const uint8_t *vs_output_param_offset,
                      bool has_param_exports,
                      bool *out_needs_scratch_ring,
                      unsigned wave_size,
                      unsigned hw_workgroup_size,
                      bool multiview,
                      bool has_query,
                      bool fast_launch_2)
{
   unsigned vertices_per_prim =
      mesa_vertices_per_prim(shader->info.mesh.primitive_type);

   uint64_t per_vertex_outputs =
      shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK;
   uint64_t per_primitive_outputs =
      shader->info.per_primitive_outputs & shader->info.outputs_written;

   /* Whether the shader uses CullPrimitiveEXT */
   bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
   /* Indirect register addressing can't be handled here, so treat indirectly
    * accessed outputs as if they were accessed cross-invocation.
    */
   uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
                                      shader->info.outputs_accessed_indirectly;

   unsigned max_vertices = shader->info.mesh.max_vertices_out;
   unsigned max_primitives = shader->info.mesh.max_primitives_out;

   ms_out_mem_layout layout = ms_calculate_output_layout(
      gfx_level, shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
      cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);

   shader->info.shared_size = layout.lds.total_size;
   *out_needs_scratch_ring = layout.scratch_ring.vtx_attr.mask || layout.scratch_ring.prm_attr.mask;

   /* The workgroup size specified by the API shader may differ from the size
    * of the workgroup that actually runs on the HW, due to the limitations of
    * NGG: at most 1 vertex and 1 primitive are allowed per lane.
    *
    * Therefore, when the API workgroup size is smaller, we must make sure
    * that we don't run the API shader on more HW invocations than necessary.
    */
   unsigned api_workgroup_size = shader->info.workgroup_size[0] *
                                 shader->info.workgroup_size[1] *
                                 shader->info.workgroup_size[2];

   lower_ngg_ms_state state = {
      .layout = layout,
      .wave_size = wave_size,
      .per_vertex_outputs = per_vertex_outputs,
      .per_primitive_outputs = per_primitive_outputs,
      .vertices_per_prim = vertices_per_prim,
      .api_workgroup_size = api_workgroup_size,
      .hw_workgroup_size = hw_workgroup_size,
      .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
      .uses_cull_flags = uses_cull,
      .gfx_level = gfx_level,
      .fast_launch_2 = fast_launch_2,
      .vert_multirow_export = fast_launch_2 && max_vertices > hw_workgroup_size,
      .prim_multirow_export = fast_launch_2 && max_primitives > hw_workgroup_size,
      .clipdist_enable_mask = clipdist_enable_mask,
      .vs_output_param_offset = vs_output_param_offset,
      .has_param_exports = has_param_exports,
      .has_query = has_query,
   };

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   state.vertex_count_var =
      nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
   state.primitive_count_var =
      nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");

   nir_builder builder = nir_builder_at(nir_before_impl(impl));
   nir_builder *b = &builder; /* Shorthand to avoid writing &builder everywhere. */

   handle_smaller_ms_api_workgroup(b, &state);
   if (!fast_launch_2)
      ms_emit_legacy_workgroup_index(b, &state);
   ms_create_same_invocation_vars(b, &state);
   nir_metadata_preserve(impl, nir_metadata_none);

   lower_ms_intrinsics(shader, &state);

   emit_ms_finale(b, &state);
   nir_metadata_preserve(impl, nir_metadata_none);

   /* Cleanup */
   nir_lower_vars_to_ssa(shader);
   nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
   nir_lower_alu_to_scalar(shader, NULL, NULL);
   nir_lower_phis_to_scalar(shader, true);

   /* Optimize load_local_invocation_index. When the API workgroup is smaller than the HW workgroup,
    * local_invocation_id isn't initialized for all lanes, so this optimization can't be applied.
    * It is therefore only done for fast_launch_2 when the sizes match and the workgroup is
    * effectively one-dimensional.
    */
   if (fast_launch_2 && api_workgroup_size == hw_workgroup_size &&
       ((shader->info.workgroup_size[0] == 1) + (shader->info.workgroup_size[1] == 1) +
        (shader->info.workgroup_size[2] == 1)) == 2) {
      nir_lower_compute_system_values_options csv_options = {
         .lower_local_invocation_index = true,
      };
      nir_lower_compute_system_values(shader, &csv_options);
   }

   nir_validate_shader(shader, "after emitting NGG MS");
}