1 /*
2  * Copyright © 2021 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "ac_nir.h"
8 #include "amdgfxregs.h"
9 #include "nir_builder.h"
10 #include "nir_xfb_info.h"
11 #include "util/u_math.h"
12 #include "util/u_vector.h"
13 
14 #define SPECIAL_MS_OUT_MASK \
15    (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \
16     BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
17     BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
18 
19 #define MS_PRIM_ARG_EXP_MASK \
20    (VARYING_BIT_LAYER | \
21     VARYING_BIT_VIEWPORT | \
22     VARYING_BIT_PRIMITIVE_SHADING_RATE)
23 
24 #define MS_VERT_ARG_EXP_MASK \
25    (VARYING_BIT_CULL_DIST0 | \
26     VARYING_BIT_CULL_DIST1 | \
27     VARYING_BIT_CLIP_DIST0 | \
28     VARYING_BIT_CLIP_DIST1 | \
29     VARYING_BIT_PSIZ)
30 
31 enum {
32    nggc_passflag_used_by_pos = 1,
33    nggc_passflag_used_by_other = 2,
34    nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
35 };
36 
37 typedef struct
38 {
39    nir_def *ssa;
40    nir_variable *var;
41 } reusable_nondeferred_variable;
42 
43 typedef struct
44 {
45    gl_varying_slot slot;
46    nir_def *chan[4];
47 } vs_output;
48 
49 typedef struct
50 {
51    nir_alu_type types[VARYING_SLOT_MAX][4];
52    nir_alu_type types_16bit_lo[16][4];
53    nir_alu_type types_16bit_hi[16][4];
54 } shader_output_types;
55 
56 typedef struct
57 {
58    const ac_nir_lower_ngg_options *options;
59 
60    nir_variable *position_value_var;
61    nir_variable *prim_exp_arg_var;
62    nir_variable *es_accepted_var;
63    nir_variable *gs_accepted_var;
64    nir_variable *gs_exported_var;
65    nir_variable *gs_vtx_indices_vars[3];
66 
67    nir_def *vtx_addr[3];
68 
69    struct u_vector reusable_nondeferred_variables;
70 
71    bool early_prim_export;
72    bool streamout_enabled;
73    bool has_user_edgeflags;
74    unsigned max_num_waves;
75 
76    /* LDS params */
77    unsigned pervertex_lds_bytes;
78 
79    uint64_t inputs_needed_by_pos;
80    uint64_t inputs_needed_by_others;
81 
82    nir_instr *compact_arg_stores[4];
83    nir_intrinsic_instr *overwrite_args;
84    nir_variable *repacked_rel_patch_id;
85 
86    /* clip distance */
87    nir_variable *clip_vertex_var;
88    nir_variable *clipdist_neg_mask_var;
89    bool has_clipdist;
90 
91    /* outputs */
92    nir_def *outputs[VARYING_SLOT_MAX][4];
93    nir_def *outputs_16bit_lo[16][4];
94    nir_def *outputs_16bit_hi[16][4];
95    shader_output_types output_types;
96 } lower_ngg_nogs_state;
97 
98 typedef struct
99 {
100    /* output stream index, 2 bit per component */
101    uint8_t stream;
102    /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
103    uint8_t components_mask : 4;
104 } gs_output_info;
105 
106 typedef struct
107 {
108    const ac_nir_lower_ngg_options *options;
109 
110    nir_function_impl *impl;
111    int const_out_vtxcnt[4];
112    int const_out_prmcnt[4];
113    unsigned max_num_waves;
114    unsigned num_vertices_per_primitive;
115    nir_def *lds_addr_gs_out_vtx;
116    nir_def *lds_addr_gs_scratch;
117    unsigned lds_bytes_per_gs_out_vertex;
118    unsigned lds_offs_primflags;
119    bool output_compile_time_known;
120    bool streamout_enabled;
121    /* 32 bit outputs */
122    nir_def *outputs[VARYING_SLOT_MAX][4];
123    gs_output_info output_info[VARYING_SLOT_MAX];
124    /* 16 bit outputs */
125    nir_def *outputs_16bit_hi[16][4];
126    nir_def *outputs_16bit_lo[16][4];
127    gs_output_info output_info_16bit_hi[16];
128    gs_output_info output_info_16bit_lo[16];
129    /* output types for both 32bit and 16bit */
130    shader_output_types output_types;
131    /* Count per stream. */
132    nir_def *vertex_count[4];
133    nir_def *primitive_count[4];
134 } lower_ngg_gs_state;
135 
136 /* LDS layout of Mesh Shader workgroup info. */
137 enum {
138    /* DW0: number of primitives */
139    lds_ms_num_prims = 0,
140    /* DW1: number of vertices */
141    lds_ms_num_vtx = 4,
142    /* DW2: workgroup index within the current dispatch */
143    lds_ms_wg_index = 8,
144    /* DW3: number of API workgroups in flight */
145    lds_ms_num_api_waves = 12,
146 };
147 
148 /* Potential location for Mesh Shader outputs. */
149 typedef enum {
150    ms_out_mode_lds,
151    ms_out_mode_scratch_ring,
152    ms_out_mode_attr_ring,
153    ms_out_mode_var,
154 } ms_out_mode;
155 
156 typedef struct
157 {
158    uint64_t mask; /* Mask of output locations */
159    uint32_t addr; /* Base address */
160 } ms_out_part;
161 
162 typedef struct
163 {
164    /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
165    struct {
166       uint32_t workgroup_info_addr;
167       ms_out_part vtx_attr;
168       ms_out_part prm_attr;
169       uint32_t indices_addr;
170       uint32_t cull_flags_addr;
171       uint32_t total_size;
172    } lds;
173 
174    /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS.
175     * Not to be confused with scratch memory.
176     */
177    struct {
178       ms_out_part vtx_attr;
179       ms_out_part prm_attr;
180    } scratch_ring;
181 
182    /* VRAM attributes ring (GFX11 only) for all non-position outputs.
183     * GFX11 doesn't have to reload attributes from this ring at the end of the shader.
184     */
185    struct {
186       ms_out_part vtx_attr;
187       ms_out_part prm_attr;
188    } attr_ring;
189 
190    /* Outputs without cross-invocation access can be stored in variables. */
191    struct {
192       ms_out_part vtx_attr;
193       ms_out_part prm_attr;
194    } var;
195 } ms_out_mem_layout;
196 
197 typedef struct
198 {
199    enum amd_gfx_level gfx_level;
200    bool fast_launch_2;
201    bool vert_multirow_export;
202    bool prim_multirow_export;
203 
204    ms_out_mem_layout layout;
205    uint64_t per_vertex_outputs;
206    uint64_t per_primitive_outputs;
207    unsigned vertices_per_prim;
208 
209    unsigned wave_size;
210    unsigned api_workgroup_size;
211    unsigned hw_workgroup_size;
212 
213    nir_def *workgroup_index;
214    nir_variable *out_variables[VARYING_SLOT_MAX * 4];
215    nir_variable *primitive_count_var;
216    nir_variable *vertex_count_var;
217 
218    /* True if the lowering needs to insert the layer output. */
219    bool insert_layer_output;
220    /* True if cull flags are used */
221    bool uses_cull_flags;
222 
223    struct {
224       /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
225       uint32_t components_mask;
226    } output_info[VARYING_SLOT_MAX];
227 
228    /* Used by outputs export. */
229    nir_def *outputs[VARYING_SLOT_MAX][4];
230    uint32_t clipdist_enable_mask;
231    const uint8_t *vs_output_param_offset;
232    bool has_param_exports;
233 
234    /* True if the lowering needs to insert shader query. */
235    bool has_query;
236 } lower_ngg_ms_state;
237 
238 /* Per-vertex LDS layout of culling shaders */
239 enum {
240    /* Position of the ES vertex (at the beginning for alignment reasons) */
241    lds_es_pos_x = 0,
242    lds_es_pos_y = 4,
243    lds_es_pos_z = 8,
244    lds_es_pos_w = 12,
245 
246    /* 1 when the vertex is accepted, 0 if it should be culled */
247    lds_es_vertex_accepted = 16,
248    /* ID of the thread which will export the current thread's vertex */
249    lds_es_exporter_tid = 17,
250    /* bit i is set when the i'th clip distance of a vertex is negative */
251    lds_es_clipdist_neg_mask = 18,
252    /* TES only, relative patch ID, less than max workgroup size */
253    lds_es_tes_rel_patch_id = 19,
254 
255    /* Repacked arguments - also listed separately for VS and TES */
256    lds_es_arg_0 = 20,
257 };
258 
259 typedef struct {
260    nir_def *num_repacked_invocations;
261    nir_def *repacked_invocation_index;
262 } wg_repack_result;
263 
264 /**
265  * Computes a horizontal sum of 8-bit packed values loaded from LDS.
266  *
267  * Each lane N will sum packed bytes 0 to N-1.
268  * We only care about the results from up to wave_id+1 lanes.
269  * (Other lanes are not deactivated but their calculation is not used.)
270  */
271 static nir_def *
272 summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords)
273 {
274    /* We'll use shift to filter out the bytes not needed by the current lane.
275     *
276     * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
277     * However, two shifts are needed because one can't go all the way,
278     * so the shift amount is half that (and in bits).
279     *
280     * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
281     * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
282     * therefore v_dot can get rid of the unneeded values.
283     * This sequence is preferable because it better hides the latency of the LDS.
284     *
285     * If the v_dot instruction can't be used, we left-shift the packed bytes.
286     * This will shift out the unneeded bytes and shift in zeroes instead,
287     * then we sum them using v_msad_u8.
288     */
289 
290    nir_def *lane_id = nir_load_subgroup_invocation(b);
291    nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
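   /* shift = num_lds_dwords * 16 - lane_id * 4; applying it twice gives the
    * (num_lds_dwords * 4 - lane_id) * 8 bit shift described above.
    * E.g. 2 LDS dwords, lane 3: 2 * 16 - 3 * 4 = 20 bits, 40 bits total,
    * which discards 5 bytes and keeps the counts of waves 0..2.
    */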
292    bool use_dot = b->shader->options->has_udot_4x8;
293 
294    if (num_lds_dwords == 1) {
295       nir_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
296 
297       /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
298       nir_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
299 
300       /* Horizontally add the packed bytes. */
301       if (use_dot) {
302          return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
303       } else {
304          nir_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
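         /* The shift replaced the unwanted bytes with zeroes; with a zero
          * reference and a zero accumulator, the masked SAD therefore just
          * adds up the remaining per-wave count bytes.
          */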
305          return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
306       }
307    } else if (num_lds_dwords == 2) {
308       nir_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
309 
310       /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
311       nir_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
312       nir_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
313 
314       /* Horizontally add the packed bytes. */
315       if (use_dot) {
316          nir_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
317          return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
318       } else {
319          nir_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
320          nir_def *sum = nir_msad_4x8(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
321          return nir_msad_4x8(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
322       }
323    } else {
324       unreachable("Unimplemented NGG wave count");
325    }
326 }
327 
328 /**
329  * Repacks invocations in the current workgroup to eliminate gaps between them.
330  *
331  * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
332  * Assumes that all invocations in the workgroup are active (exec = -1).
333  */
334 static wg_repack_result
335 repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
336                                 nir_def *lds_addr_base, unsigned max_num_waves,
337                                 unsigned wave_size)
338 {
339    /* Input boolean: 1 if the current invocation should survive the repack. */
340    assert(input_bool->bit_size == 1);
341 
342    /* STEP 1. Count surviving invocations in the current wave.
343     *
344     * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
345     */
346 
347    nir_def *input_mask = nir_ballot(b, 1, wave_size, input_bool);
348    nir_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
349 
350    /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
351    if (max_num_waves == 1) {
352       wg_repack_result r = {
353          .num_repacked_invocations = surviving_invocations_in_current_wave,
354          .repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
355       };
356       return r;
357    }
358 
359    /* STEP 2. Waves tell each other their number of surviving invocations.
360     *
361     * Each wave activates only its first lane (exec = 1), which stores the number of surviving
362     * invocations in that wave into the LDS, then reads the numbers from every wave.
363     *
364     * The workgroup size of NGG shaders is at most 256, which means
365     * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
366     * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
367     */
368 
369    const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
370    assert(num_lds_dwords <= 2);
371 
372    nir_def *wave_id = nir_load_subgroup_id(b);
373    nir_def *lds_offset = nir_iadd(b, lds_addr_base, wave_id);
374    nir_def *dont_care = nir_undef(b, 1, num_lds_dwords * 32);
375    nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1));
376 
377    nir_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), lds_offset);
378 
379    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
380                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
381 
382    nir_def *packed_counts =
383       nir_load_shared(b, 1, num_lds_dwords * 32, lds_addr_base, .align_mul = 8u);
384 
385    nir_pop_if(b, if_first_lane);
386 
387    packed_counts = nir_if_phi(b, packed_counts, dont_care);
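   /* Only the elected lane of each wave holds the loaded value; that is fine
    * because summarize_repack() broadcasts lane 0 with nir_lane_permute_16_amd,
    * so the undef merged in for the other lanes is never observed.
    */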
388 
389    /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
390     *
391     * By now, every wave knows the number of surviving invocations in all waves.
392     * Each number is 1 byte, and they are packed into up to 2 dwords.
393     *
394     * Each lane N will sum the number of surviving invocations from waves 0 to N-1.
395     * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
396     * (Other lanes are not deactivated but their calculation is not used.)
397     *
398     * - We read the sum from the lane whose id is the current wave's id.
399     *   Add the masked bitcount to this, and we get the repacked invocation index.
400     * - We read the sum from the lane whose id is the number of waves in the workgroup.
401     *   This is the total number of surviving invocations in the workgroup.
402     */
403 
404    nir_def *num_waves = nir_load_num_subgroups(b);
405    nir_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
406 
407    nir_def *wg_repacked_index_base = nir_read_invocation(b, sum, wave_id);
408    nir_def *wg_num_repacked_invocations = nir_read_invocation(b, sum, num_waves);
409    nir_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
410 
411    wg_repack_result r = {
412       .num_repacked_invocations = wg_num_repacked_invocations,
413       .repacked_invocation_index = wg_repacked_index,
414    };
415 
416    return r;
417 }
418 
419 static nir_def *
420 pervertex_lds_addr(nir_builder *b, nir_def *vertex_idx, unsigned per_vtx_bytes)
421 {
422    return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
423 }
424 
425 static nir_def *
426 emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
427                            nir_def *vertex_indices[3], nir_def *is_null_prim)
428 {
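   /* Primitive export argument layout used by this pass: one vertex index per
    * vertex at 10-bit intervals (bits 0, 10, 20), edge flags at bits 9/19/29
    * (see emit_ngg_nogs_prim_export), and the null-primitive flag at bit 31.
    */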
429    nir_def *arg = nir_load_initial_edgeflags_amd(b);
430 
431    for (unsigned i = 0; i < num_vertices_per_primitives; ++i) {
432       assert(vertex_indices[i]);
433       arg = nir_ior(b, arg, nir_ishl_imm(b, vertex_indices[i], 10u * i));
434    }
435 
436    if (is_null_prim) {
437       if (is_null_prim->bit_size == 1)
438          is_null_prim = nir_b2i32(b, is_null_prim);
439       assert(is_null_prim->bit_size == 32);
440       arg = nir_ior(b, arg, nir_ishl_imm(b, is_null_prim, 31u));
441    }
442 
443    return arg;
444 }
445 
446 static void
447 alloc_vertices_and_primitives(nir_builder *b,
448                               nir_def *num_vtx,
449                               nir_def *num_prim)
450 {
451    /* The caller should only call this conditionally on wave 0.
452     *
453     * Send GS Alloc Request message from the first wave of the group to SPI.
454     * Message payload (in the m0 register) is:
455     * - bits 0..10: number of vertices in group
456     * - bits 12..22: number of primitives in group
457     */
458 
459    nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prim, 12), num_vtx);
460    nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ);
461 }
462 
463 static void
464 alloc_vertices_and_primitives_gfx10_workaround(nir_builder *b,
465                                                nir_def *num_vtx,
466                                                nir_def *num_prim)
467 {
468    /* HW workaround for a GPU hang with 100% culling on GFX10.
469     * We always have to export at least 1 primitive.
470     * Export a degenerate triangle using vertex 0 for all 3 vertices.
471     *
472     * NOTE: We rely on the caller to set the vertex count also to 0 when the primitive count is 0.
473     */
474    nir_def *is_prim_cnt_0 = nir_ieq_imm(b, num_prim, 0);
475    nir_if *if_prim_cnt_0 = nir_push_if(b, is_prim_cnt_0);
476    {
477       nir_def *one = nir_imm_int(b, 1);
478       alloc_vertices_and_primitives(b, one, one);
479 
480       nir_def *tid = nir_load_subgroup_invocation(b);
481       nir_def *is_thread_0 = nir_ieq_imm(b, tid, 0);
482       nir_if *if_thread_0 = nir_push_if(b, is_thread_0);
483       {
484          /* The vertex indices are 0, 0, 0. */
485          nir_export_amd(b, nir_imm_zero(b, 4, 32),
486                         .base = V_008DFC_SQ_EXP_PRIM,
487                         .flags = AC_EXP_FLAG_DONE,
488                         .write_mask = 1);
489 
490          /* The HW culls primitives with NaN. -1 is also a NaN and saves a
491           * dword in the binary because it can be encoded as an inline constant.
492           */
493          nir_export_amd(b, nir_imm_ivec4(b, -1, -1, -1, -1),
494                         .base = V_008DFC_SQ_EXP_POS,
495                         .flags = AC_EXP_FLAG_DONE,
496                         .write_mask = 0xf);
497       }
498       nir_pop_if(b, if_thread_0);
499    }
500    nir_push_else(b, if_prim_cnt_0);
501    {
502       alloc_vertices_and_primitives(b, num_vtx, num_prim);
503    }
504    nir_pop_if(b, if_prim_cnt_0);
505 }
506 
507 static void
508 ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *s)
509 {
510    for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
511       s->gs_vtx_indices_vars[v] = nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx_addr");
512 
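      /* In passthrough mode the HW primitive descriptor packs one 9-bit vertex
       * index per vertex at 10-bit intervals; otherwise every 32-bit
       * gs_vertex_offset argument packs two 16-bit vertex indices.
       */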
513       nir_def *vtx = s->options->passthrough ?
514          nir_ubfe_imm(b, nir_load_packed_passthrough_primitive_amd(b),
515                       10 * v, 9) :
516          nir_ubfe_imm(b, nir_load_gs_vertex_offset_amd(b, .base = v / 2u),
517                       (v & 1u) * 16u, 16u);
518 
519       nir_store_var(b, s->gs_vtx_indices_vars[v], vtx, 0x1);
520    }
521 }
522 
523 static nir_def *
524 emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *s)
525 {
526    if (s->options->passthrough) {
527       return nir_load_packed_passthrough_primitive_amd(b);
528    } else {
529       nir_def *vtx_idx[3] = {0};
530 
531       for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v)
532          vtx_idx[v] = nir_load_var(b, s->gs_vtx_indices_vars[v]);
533 
534       return emit_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive, vtx_idx, NULL);
535    }
536 }
537 
538 static nir_def *
539 has_input_vertex(nir_builder *b)
540 {
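   /* The low byte of the merged wave info is the number of ES (vertex) threads
    * in this wave; the next byte is the number of GS (primitive) threads, which
    * has_input_primitive() and nogs_prim_gen_query() extract below.
    */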
541    return nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b));
542 }
543 
544 static nir_def *
545 has_input_primitive(nir_builder *b)
546 {
547    return nir_is_subgroup_invocation_lt_amd(b,
548                                             nir_ushr_imm(b, nir_load_merged_wave_info_amd(b), 8));
549 }
550 
551 static void
552 nogs_prim_gen_query(nir_builder *b, lower_ngg_nogs_state *s)
553 {
554    if (!s->options->has_gen_prim_query)
555       return;
556 
557    nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
558    {
559       /* Activate only 1 lane and add the number of primitives to query result. */
560       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
561       {
562          /* Number of input primitives in the current wave. */
563          nir_def *num_input_prims = nir_ubfe_imm(b, nir_load_merged_wave_info_amd(b),
564                                                      8, 8);
565 
566          /* Add to stream 0 primitive generated counter. */
567          nir_atomic_add_gen_prim_count_amd(b, num_input_prims, .stream_id = 0);
568       }
569       nir_pop_if(b, if_elected);
570    }
571    nir_pop_if(b, if_shader_query);
572 }
573 
574 static void
575 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *s, nir_def *arg)
576 {
577    nir_if *if_gs_thread = nir_push_if(b, nir_load_var(b, s->gs_exported_var));
578    {
579       if (!arg)
580          arg = emit_ngg_nogs_prim_exp_arg(b, s);
581 
582       /* pack user edge flag info into arg */
583       if (s->has_user_edgeflags) {
584          /* Workgroup barrier: wait for ES threads to store user edge flags to LDS */
585          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
586                             .memory_scope = SCOPE_WORKGROUP,
587                             .memory_semantics = NIR_MEMORY_ACQ_REL,
588                             .memory_modes = nir_var_mem_shared);
589 
590          unsigned edge_flag_bits = ac_get_all_edge_flag_bits();
591          nir_def *mask = nir_imm_intN_t(b, ~edge_flag_bits, 32);
592 
593          unsigned edge_flag_offset = 0;
594          if (s->streamout_enabled) {
595             unsigned packed_location =
596                util_bitcount64(b->shader->info.outputs_written &
597                                BITFIELD64_MASK(VARYING_SLOT_EDGE));
598             edge_flag_offset = packed_location * 16;
599          }
600 
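         /* Build a mask that is all-ones except at the edge flag bits (9/19/29),
          * where it holds the user edge flags loaded from LDS. ANDing it into
          * arg keeps an edge flag only if both the HW flag and the user flag
          * are set.
          */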
601          for (int i = 0; i < s->options->num_vertices_per_primitive; i++) {
602             nir_def *vtx_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
603             nir_def *addr = pervertex_lds_addr(b, vtx_idx, s->pervertex_lds_bytes);
604             nir_def *edge = nir_load_shared(b, 1, 32, addr, .base = edge_flag_offset);
605             mask = nir_ior(b, mask, nir_ishl_imm(b, edge, 9 + i * 10));
606          }
607          arg = nir_iand(b, arg, mask);
608       }
609 
610       ac_nir_export_primitive(b, arg, NULL);
611    }
612    nir_pop_if(b, if_gs_thread);
613 }
614 
615 static void
616 emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *s)
617 {
618    nir_def *gs_thread =
619       s->gs_accepted_var ? nir_load_var(b, s->gs_accepted_var) : has_input_primitive(b);
620 
621    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
622    {
623       /* Copy Primitive IDs from GS threads to the LDS address
624        * corresponding to the ES thread of the provoking vertex.
625        * It will be exported as a per-vertex attribute.
626        */
627       nir_def *gs_vtx_indices[3];
628       for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++)
629          gs_vtx_indices[i] = nir_load_var(b, s->gs_vtx_indices_vars[i]);
630 
631       nir_def *provoking_vertex = nir_load_provoking_vtx_in_prim_amd(b);
632       nir_def *provoking_vtx_idx = nir_select_from_ssa_def_array(
633          b, gs_vtx_indices, s->options->num_vertices_per_primitive, provoking_vertex);
634 
635       nir_def *prim_id = nir_load_primitive_id(b);
636       nir_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, s->pervertex_lds_bytes);
637 
638       /* The primitive ID is always stored in the last dword of a vertex. */
639       nir_store_shared(b, prim_id, addr, .base = s->pervertex_lds_bytes - 4);
640    }
641    nir_pop_if(b, if_gs_thread);
642 }
643 
644 static void
645 emit_store_ngg_nogs_es_primitive_id(nir_builder *b, lower_ngg_nogs_state *s)
646 {
647    nir_def *prim_id = NULL;
648 
649    if (b->shader->info.stage == MESA_SHADER_VERTEX) {
650       /* LDS address where the primitive ID is stored */
651       nir_def *thread_id_in_threadgroup = nir_load_local_invocation_index(b);
652       nir_def *addr =
653          pervertex_lds_addr(b, thread_id_in_threadgroup, s->pervertex_lds_bytes);
654 
655       /* Load primitive ID from LDS */
656       prim_id = nir_load_shared(b, 1, 32, addr, .base = s->pervertex_lds_bytes - 4);
657    } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
658       /* Just use tess eval primitive ID, which is the same as the patch ID. */
659       prim_id = nir_load_primitive_id(b);
660    }
661 
662    s->outputs[VARYING_SLOT_PRIMITIVE_ID][0] = prim_id;
663 
664    /* Update outputs_written to reflect that the pass added a new output. */
665    b->shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
666 }
667 
668 static void
669 add_clipdist_bit(nir_builder *b, nir_def *dist, unsigned index, nir_variable *mask)
670 {
671    nir_def *is_neg = nir_flt_imm(b, dist, 0);
672    nir_def *neg_mask = nir_ishl_imm(b, nir_b2i32(b, is_neg), index);
673    neg_mask = nir_ior(b, neg_mask, nir_load_var(b, mask));
674    nir_store_var(b, mask, neg_mask, 1);
675 }
676 
677 static bool
678 remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
679 {
680    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *) state;
681 
682    if (instr->type != nir_instr_type_intrinsic)
683       return false;
684 
685    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
686 
687    /* These are not allowed in VS / TES */
688    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
689           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
690 
691    /* We are only interested in output stores now */
692    if (intrin->intrinsic != nir_intrinsic_store_output)
693       return false;
694 
695    b->cursor = nir_before_instr(instr);
696 
697    /* no indirect output */
698    assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
699 
700    unsigned writemask = nir_intrinsic_write_mask(intrin);
701    unsigned component = nir_intrinsic_component(intrin);
702    nir_def *store_val = intrin->src[0].ssa;
703 
704    /* Position output - store the value to a variable, remove output store */
705    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
706    switch (io_sem.location) {
707    case VARYING_SLOT_POS:
708       ac_nir_store_var_components(b, s->position_value_var, store_val, component, writemask);
709       break;
710    case VARYING_SLOT_CLIP_DIST0:
711    case VARYING_SLOT_CLIP_DIST1: {
712       unsigned base = io_sem.location == VARYING_SLOT_CLIP_DIST1 ? 4 : 0;
713       base += component;
714 
715       /* valid clipdist component mask */
716       unsigned mask = (s->options->clip_cull_dist_mask >> base) & writemask;
717       u_foreach_bit(i, mask) {
718          add_clipdist_bit(b, nir_channel(b, store_val, i), base + i,
719                           s->clipdist_neg_mask_var);
720          s->has_clipdist = true;
721       }
722       break;
723    }
724    case VARYING_SLOT_CLIP_VERTEX:
725       ac_nir_store_var_components(b, s->clip_vertex_var, store_val, component, writemask);
726       break;
727    default:
728       break;
729    }
730 
731    /* Remove all output stores */
732    nir_instr_remove(instr);
733    return true;
734 }
735 
736 static void
737 remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *s)
738 {
739    nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
740                                 nir_metadata_block_index | nir_metadata_dominance, s);
741 
742    /* Remove dead code resulting from the deleted outputs. */
743    bool progress;
744    do {
745       progress = false;
746       NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
747       NIR_PASS(progress, culling_shader, nir_opt_dce);
748       NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
749    } while (progress);
750 }
751 
752 static void
753 rewrite_uses_to_var(nir_builder *b, nir_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
754 {
755    if (old_def->parent_instr->type == nir_instr_type_load_const)
756       return;
757 
758    b->cursor = nir_after_instr(old_def->parent_instr);
759    if (b->cursor.instr->type == nir_instr_type_phi)
760       b->cursor = nir_after_phis(old_def->parent_instr->block);
761 
762    nir_def *pos_val_rep = nir_load_var(b, replacement_var);
763    nir_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
764 
765    if (old_def->num_components > 1) {
766       /* old_def uses a swizzled vector component.
767        * There is no way to replace the uses of just a single vector component,
768        * so instead create a new vector and replace all uses of the old vector.
769        */
770       nir_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
771       for (unsigned j = 0; j < old_def->num_components; ++j)
772          old_def_elements[j] = nir_channel(b, old_def, j);
773       replacement = nir_vec(b, old_def_elements, old_def->num_components);
774    }
775 
776    nir_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
777 }
778 
779 static bool
780 remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
781 {
782    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *) state;
783 
784    if (instr->type != nir_instr_type_intrinsic)
785       return false;
786 
787    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
788 
789    /* These are not allowed in VS / TES */
790    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
791           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
792 
793    /* We are only interested in output stores now */
794    if (intrin->intrinsic != nir_intrinsic_store_output)
795       return false;
796 
797    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
798    if (io_sem.location != VARYING_SLOT_POS)
799       return false;
800 
801    b->cursor = nir_before_instr(instr);
802 
803    /* In case other outputs use what we calculated for pos,
804     * try to avoid calculating it again by rewriting the usages
805     * of the store components here.
806     */
807    nir_def *store_val = intrin->src[0].ssa;
808    unsigned store_pos_component = nir_intrinsic_component(intrin);
809 
810    nir_instr_remove(instr);
811 
812    if (store_val->parent_instr->type == nir_instr_type_alu) {
813       nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
814       if (nir_op_is_vec_or_mov(alu->op)) {
815          /* Output store uses a vector, we can easily rewrite uses of each vector element. */
816 
817          unsigned num_vec_src = 0;
818          if (alu->op == nir_op_mov)
819             num_vec_src = 1;
820          else if (alu->op == nir_op_vec2)
821             num_vec_src = 2;
822          else if (alu->op == nir_op_vec3)
823             num_vec_src = 3;
824          else if (alu->op == nir_op_vec4)
825             num_vec_src = 4;
826          assert(num_vec_src);
827 
828          /* Remember the current components whose uses we wish to replace.
829           * This is needed because rewriting one source can affect the others too.
830           */
831          nir_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
832          for (unsigned i = 0; i < num_vec_src; i++)
833             vec_comps[i] = alu->src[i].src.ssa;
834 
835          for (unsigned i = 0; i < num_vec_src; i++)
836             rewrite_uses_to_var(b, vec_comps[i], s->position_value_var, store_pos_component + i);
837       } else {
838          rewrite_uses_to_var(b, store_val, s->position_value_var, store_pos_component);
839       }
840    } else {
841       rewrite_uses_to_var(b, store_val, s->position_value_var, store_pos_component);
842    }
843 
844    return true;
845 }
846 
847 static void
848 remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *s)
849 {
850    nir_shader_instructions_pass(shader, remove_extra_pos_output,
851                                 nir_metadata_block_index | nir_metadata_dominance,
852                                 s);
853 }
854 
855 static bool
856 remove_compacted_arg(lower_ngg_nogs_state *s, nir_builder *b, unsigned idx)
857 {
858    nir_instr *store_instr = s->compact_arg_stores[idx];
859    if (!store_instr)
860       return false;
861 
862    /* Simply remove the store. */
863    nir_instr_remove(store_instr);
864 
865    /* Find the intrinsic that overwrites the shader arguments,
866     * and change its corresponding source.
867     * This will cause NIR's DCE to recognize the load and its phis as dead.
868     */
869    b->cursor = nir_before_instr(&s->overwrite_args->instr);
870    nir_def *undef_arg = nir_undef(b, 1, 32);
871    nir_def_rewrite_uses(s->overwrite_args->src[idx].ssa, undef_arg);
872 
873    s->compact_arg_stores[idx] = NULL;
874    return true;
875 }
876 
877 static bool
878 cleanup_culling_shader_after_dce(nir_shader *shader,
879                                  nir_function_impl *function_impl,
880                                  lower_ngg_nogs_state *s)
881 {
882    bool uses_vs_vertex_id = false;
883    bool uses_vs_instance_id = false;
884    bool uses_tes_u = false;
885    bool uses_tes_v = false;
886    bool uses_tes_rel_patch_id = false;
887    bool uses_tes_patch_id = false;
888 
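   /* Repacked argument indices (see compact_vertices_after_culling and
    * ngg_nogs_get_culling_pervertex_lds_size):
    * VS:  0 = vertex ID, 1 = instance ID;
    * TES: 0 = tess coord U, 1 = tess coord V, 2 = patch (primitive) ID,
    *      3 = relative patch ID.
    */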
889    bool progress = false;
890    nir_builder b = nir_builder_create(function_impl);
891 
892    nir_foreach_block_reverse_safe(block, function_impl) {
893       nir_foreach_instr_reverse_safe(instr, block) {
894          if (instr->type != nir_instr_type_intrinsic)
895             continue;
896 
897          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
898 
899          switch (intrin->intrinsic) {
900          case nir_intrinsic_sendmsg_amd:
901             goto cleanup_culling_shader_after_dce_done;
902          case nir_intrinsic_load_vertex_id:
903          case nir_intrinsic_load_vertex_id_zero_base:
904             uses_vs_vertex_id = true;
905             break;
906          case nir_intrinsic_load_instance_id:
907             uses_vs_instance_id = true;
908             break;
909          case nir_intrinsic_load_input:
910             if (s->options->instance_rate_inputs & BITFIELD_BIT(nir_intrinsic_base(intrin)))
911                uses_vs_instance_id = true;
912             else
913                uses_vs_vertex_id = true;
914             break;
915          case nir_intrinsic_load_tess_coord:
916             uses_tes_u = uses_tes_v = true;
917             break;
918          case nir_intrinsic_load_tess_rel_patch_id_amd:
919             uses_tes_rel_patch_id = true;
920             break;
921          case nir_intrinsic_load_primitive_id:
922             if (shader->info.stage == MESA_SHADER_TESS_EVAL)
923                uses_tes_patch_id = true;
924             break;
925          default:
926             break;
927          }
928       }
929    }
930 
931    cleanup_culling_shader_after_dce_done:
932 
933    if (shader->info.stage == MESA_SHADER_VERTEX) {
934       if (!uses_vs_vertex_id)
935          progress |= remove_compacted_arg(s, &b, 0);
936       if (!uses_vs_instance_id)
937          progress |= remove_compacted_arg(s, &b, 1);
938    } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) {
939       if (!uses_tes_u)
940          progress |= remove_compacted_arg(s, &b, 0);
941       if (!uses_tes_v)
942          progress |= remove_compacted_arg(s, &b, 1);
943       if (!uses_tes_rel_patch_id)
944          progress |= remove_compacted_arg(s, &b, 3);
945       if (!uses_tes_patch_id)
946          progress |= remove_compacted_arg(s, &b, 2);
947    }
948 
949    return progress;
950 }
951 
952 /**
953  * Perform vertex compaction after culling.
954  *
955  * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
956  * 2. Surviving ES vertex invocations store their data to LDS
957  * 3. Emit GS_ALLOC_REQ
958  * 4. Repacked invocations load the vertex data from LDS
959  * 5. GS threads update their vertex indices
960  */
961 static void
962 compact_vertices_after_culling(nir_builder *b,
963                                lower_ngg_nogs_state *s,
964                                nir_variable **repacked_variables,
965                                nir_variable **gs_vtxaddr_vars,
966                                nir_def *invocation_index,
967                                nir_def *es_vertex_lds_addr,
968                                nir_def *es_exporter_tid,
969                                nir_def *num_live_vertices_in_workgroup,
970                                unsigned pervertex_lds_bytes,
971                                unsigned num_repacked_variables)
972 {
973    nir_variable *es_accepted_var = s->es_accepted_var;
974    nir_variable *gs_accepted_var = s->gs_accepted_var;
975    nir_variable *position_value_var = s->position_value_var;
976    nir_variable *prim_exp_arg_var = s->prim_exp_arg_var;
977 
978    nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var));
979    {
980       nir_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
981 
982       /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
983       nir_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid);
984 
985       /* Store the current thread's position output to the exporter thread's LDS space */
986       nir_def *pos = nir_load_var(b, position_value_var);
987       nir_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x);
988 
989       /* Store the current thread's repackable arguments to the exporter thread's LDS space */
990       for (unsigned i = 0; i < num_repacked_variables; ++i) {
991          nir_def *arg_val = nir_load_var(b, repacked_variables[i]);
992          nir_intrinsic_instr *store = nir_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i);
993 
994          s->compact_arg_stores[i] = &store->instr;
995       }
996 
997       /* The TES relative patch ID does not cost an extra dword. */
998       if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
999          nir_def *arg_val = nir_load_var(b, s->repacked_rel_patch_id);
1000          nir_intrinsic_instr *store =
1001             nir_store_shared(b, nir_u2u8(b, arg_val), exporter_addr,
1002                              .base = lds_es_tes_rel_patch_id);
1003 
1004          s->compact_arg_stores[3] = &store->instr;
1005       }
1006    }
1007    nir_pop_if(b, if_es_accepted);
1008 
1009    /* TODO: Consider adding a shortcut exit.
1010     * Waves that have no vertices and primitives left can s_endpgm right here.
1011     */
1012 
1013    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1014                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1015 
1016    nir_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
1017    nir_if *if_packed_es_thread = nir_push_if(b, es_survived);
1018    {
1019       /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
1020       nir_def *exported_pos = nir_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x);
1021       nir_store_var(b, position_value_var, exported_pos, 0xfu);
1022 
1023       /* Read the repacked arguments */
1024       for (unsigned i = 0; i < num_repacked_variables; ++i) {
1025          nir_def *arg_val = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i);
1026          nir_store_var(b, repacked_variables[i], arg_val, 0x1u);
1027       }
1028 
1029       if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1030          nir_def *arg_val = nir_load_shared(b, 1, 8, es_vertex_lds_addr,
1031                                                 .base = lds_es_tes_rel_patch_id);
1032          nir_store_var(b, s->repacked_rel_patch_id, nir_u2u32(b, arg_val), 0x1u);
1033       }
1034    }
1035    nir_push_else(b, if_packed_es_thread);
1036    {
1037       nir_store_var(b, position_value_var, nir_undef(b, 4, 32), 0xfu);
1038       for (unsigned i = 0; i < num_repacked_variables; ++i)
1039          nir_store_var(b, repacked_variables[i], nir_undef(b, 1, 32), 0x1u);
1040    }
1041    nir_pop_if(b, if_packed_es_thread);
1042 
1043    nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
1044    {
1045       nir_def *exporter_vtx_indices[3] = {0};
1046 
1047       /* Load the index of the ES threads that will export the current GS thread's vertices */
1048       for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
1049          nir_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
1050          nir_def *exporter_vtx_idx = nir_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid);
1051          exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
1052          nir_store_var(b, s->gs_vtx_indices_vars[v], exporter_vtx_indices[v], 0x1);
1053       }
1054 
1055       nir_def *prim_exp_arg =
1056          emit_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive,
1057                                     exporter_vtx_indices, NULL);
1058       nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
1059    }
1060    nir_pop_if(b, if_gs_accepted);
1061 
1062    nir_store_var(b, es_accepted_var, es_survived, 0x1u);
1063 }
1064 
1065 static void
1066 analyze_shader_before_culling_walk(nir_def *ssa,
1067                                    uint8_t flag,
1068                                    lower_ngg_nogs_state *s)
1069 {
1070    nir_instr *instr = ssa->parent_instr;
1071    uint8_t old_pass_flags = instr->pass_flags;
1072    instr->pass_flags |= flag;
1073 
1074    if (instr->pass_flags == old_pass_flags)
1075       return; /* Already visited. */
1076 
1077    switch (instr->type) {
1078    case nir_instr_type_intrinsic: {
1079       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1080 
1081       /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
1082       switch (intrin->intrinsic) {
1083       case nir_intrinsic_load_input: {
1084          nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
1085          uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
1086          if (instr->pass_flags & nggc_passflag_used_by_pos)
1087             s->inputs_needed_by_pos |= in_mask;
1088          else if (instr->pass_flags & nggc_passflag_used_by_other)
1089             s->inputs_needed_by_others |= in_mask;
1090          break;
1091       }
1092       default:
1093          break;
1094       }
1095 
1096       break;
1097    }
1098    case nir_instr_type_alu: {
1099       nir_alu_instr *alu = nir_instr_as_alu(instr);
1100       unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
1101 
1102       for (unsigned i = 0; i < num_srcs; ++i) {
1103          analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, s);
1104       }
1105 
1106       break;
1107    }
1108    case nir_instr_type_tex: {
1109       nir_tex_instr *tex = nir_instr_as_tex(instr);
1110       unsigned num_srcs = tex->num_srcs;
1111 
1112       for (unsigned i = 0; i < num_srcs; ++i) {
1113          analyze_shader_before_culling_walk(tex->src[i].src.ssa, flag, s);
1114       }
1115 
1116       break;
1117    }
1118    case nir_instr_type_phi: {
1119       nir_phi_instr *phi = nir_instr_as_phi(instr);
1120       nir_foreach_phi_src_safe(phi_src, phi) {
1121          analyze_shader_before_culling_walk(phi_src->src.ssa, flag, s);
1122       }
1123 
1124       break;
1125    }
1126    default:
1127       break;
1128    }
1129 }
1130 
1131 static void
1132 analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *s)
1133 {
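   /* Walk the use-def chains of every output store and flag each contributing
    * instruction as used by the position output, by other outputs, or by both.
    * The flags feed inputs_needed_by_pos / inputs_needed_by_others and later
    * let find_reusable_ssa_def() pick values worth reusing after culling.
    */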
1134    /* LCSSA is needed to get correct results from divergence analysis. */
1135    nir_convert_to_lcssa(shader, true, true);
1136    /* We need divergence info for culling shaders. */
1137    nir_divergence_analysis(shader);
1138 
1139    nir_foreach_function_impl(impl, shader) {
1140       nir_foreach_block(block, impl) {
1141          nir_foreach_instr(instr, block) {
1142             instr->pass_flags = 0;
1143 
1144             if (instr->type != nir_instr_type_intrinsic)
1145                continue;
1146 
1147             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1148             if (intrin->intrinsic != nir_intrinsic_store_output)
1149                continue;
1150 
1151             nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
1152             nir_def *store_val = intrin->src[0].ssa;
1153             uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
1154             analyze_shader_before_culling_walk(store_val, flag, s);
1155          }
1156       }
1157    }
1158 }
1159 
1160 static nir_def *
1161 find_reusable_ssa_def(nir_instr *instr)
1162 {
1163    /* Find instructions whose SSA definitions are used by both
1164     * the top and bottom parts of the shader (before and after culling).
1165     * Only in this case, it makes sense for the bottom part
1166     * to try to reuse these from the top part.
1167     */
1168    if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
1169       return NULL;
1170 
1171    switch (instr->type) {
1172    case nir_instr_type_alu: {
1173       nir_alu_instr *alu = nir_instr_as_alu(instr);
1174       if (alu->def.divergent)
1175          return NULL;
1176       /* Ignore uniform floats because they regress VGPR usage too much */
1177       if (nir_op_infos[alu->op].output_type & nir_type_float)
1178          return NULL;
1179       return &alu->def;
1180    }
1181    case nir_instr_type_intrinsic: {
1182       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1183       if (!nir_intrinsic_can_reorder(intrin) ||
1184             !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
1185             intrin->def.divergent)
1186          return NULL;
1187       return &intrin->def;
1188    }
1189    case nir_instr_type_phi: {
1190       nir_phi_instr *phi = nir_instr_as_phi(instr);
1191       if (phi->def.divergent)
1192          return NULL;
1193       return &phi->def;
1194    }
1195    default:
1196       return NULL;
1197    }
1198 }
1199 
1200 static const struct glsl_type *
1201 glsl_uint_type_for_ssa(nir_def *ssa)
1202 {
1203    enum glsl_base_type base_type = GLSL_TYPE_UINT;
1204    switch (ssa->bit_size) {
1205    case 8: base_type = GLSL_TYPE_UINT8; break;
1206    case 16: base_type = GLSL_TYPE_UINT16; break;
1207    case 32: base_type = GLSL_TYPE_UINT; break;
1208    case 64: base_type = GLSL_TYPE_UINT64; break;
1209    default: return NULL;
1210    }
1211 
1212    return ssa->num_components == 1
1213           ? glsl_scalar_type(base_type)
1214           : glsl_vector_type(base_type, ssa->num_components);
1215 }
1216 
1217 /**
1218  * Save the reusable SSA definitions to variables so that the
1219  * bottom shader part can reuse them from the top part.
1220  *
1221  * 1. We create a new function temporary variable for reusables,
1222  *    and insert a store+load.
1223  * 2. The shader is cloned (the top part is created), then the
1224  *    control flow is reinserted (for the bottom part.)
1225  * 3. For reusables, we delete the variable stores from the
1226  *    bottom part. This will make them use the variables from
1227  *    the top part and DCE the redundant instructions.
1228  */
1229 static void
1230 save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *s)
1231 {
1232    ASSERTED int vec_ok = u_vector_init(&s->reusable_nondeferred_variables, 4, sizeof(reusable_nondeferred_variable));
1233    assert(vec_ok);
1234 
1235    /* Upper limit on reusable uniforms in order to reduce SGPR spilling. */
1236    unsigned remaining_reusable_uniforms = 48;
1237 
1238    nir_block *block = nir_start_block(b->impl);
1239    while (block) {
1240       /* Process the instructions in the current block. */
1241       nir_foreach_instr_safe(instr, block) {
1242          /* Determine if we can reuse the current SSA value.
1243           * When vertex compaction is used, it is possible that the same shader invocation
1244           * processes a different vertex in the top and bottom part of the shader.
1245           * Therefore, we only reuse uniform values.
1246           */
1247          nir_def *ssa = find_reusable_ssa_def(instr);
1248          if (!ssa)
1249             continue;
1250 
1251          /* Determine a suitable type for the SSA value. */
1252          const struct glsl_type *t = glsl_uint_type_for_ssa(ssa);
1253          if (!t)
1254             continue;
1255 
1256          if (!ssa->divergent) {
1257             if (remaining_reusable_uniforms < ssa->num_components)
1258                continue;
1259 
1260             remaining_reusable_uniforms -= ssa->num_components;
1261          }
1262 
1263          reusable_nondeferred_variable *saved = (reusable_nondeferred_variable *) u_vector_add(&s->reusable_nondeferred_variables);
1264          assert(saved);
1265 
1266          /* Create a new NIR variable where we store the reusable value.
1267           * Then, we reload the variable and replace the uses of the value
1268           * with the reloaded variable.
1269           */
1270          saved->var = nir_local_variable_create(b->impl, t, NULL);
1271          saved->ssa = ssa;
1272 
1273          b->cursor = instr->type == nir_instr_type_phi
1274                      ? nir_after_instr_and_phis(instr)
1275                      : nir_after_instr(instr);
1276          nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
1277          nir_def *reloaded = nir_load_var(b, saved->var);
1278          nir_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
1279       }
1280 
1281       /* Look at the next CF node. */
1282       nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
1283       if (next_cf_node) {
1284          /* It makes no sense to try to reuse things from within loops. */
1285          bool next_is_loop = next_cf_node->type == nir_cf_node_loop;
1286 
1287          /* Don't reuse if we're in divergent control flow.
1288           *
1289           * Thanks to vertex repacking, the same shader invocation may process a different vertex
1290           * in the top and bottom part, and it's even possible that this different vertex was initially
1291           * processed in a different wave. So the two parts may take a different divergent code path.
1292           * Therefore, these variables in divergent control flow may stay undefined.
1293           *
1294           * Note that this problem doesn't exist if vertices are not repacked or if the
1295           * workgroup only has a single wave.
1296           */
1297          bool next_is_divergent_if =
1298             next_cf_node->type == nir_cf_node_if &&
1299             nir_cf_node_as_if(next_cf_node)->condition.ssa->divergent;
1300 
1301          if (next_is_loop || next_is_divergent_if) {
1302             block = nir_cf_node_cf_tree_next(next_cf_node);
1303             continue;
1304          }
1305       }
1306 
1307       /* Go to the next block. */
1308       block = nir_block_cf_tree_next(block);
1309    }
1310 }
1311 
1312 /**
1313  * Reuses suitable variables from the top part of the shader,
1314  * by deleting their stores from the bottom part.
1315  */
1316 static void
1317 apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *s)
1318 {
1319    if (!u_vector_length(&s->reusable_nondeferred_variables)) {
1320       u_vector_finish(&s->reusable_nondeferred_variables);
1321       return;
1322    }
1323 
1324    nir_foreach_block_reverse_safe(block, b->impl) {
1325       nir_foreach_instr_reverse_safe(instr, block) {
1326          if (instr->type != nir_instr_type_intrinsic)
1327             continue;
1328          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1329 
1330          /* When we find this intrinsic, it means we have reached
1331           * the top part and we must stop.
1332           */
1333          if (intrin->intrinsic == nir_intrinsic_sendmsg_amd)
1334             goto done;
1335 
1336          if (intrin->intrinsic != nir_intrinsic_store_deref)
1337             continue;
1338          nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1339          if (deref->deref_type != nir_deref_type_var)
1340             continue;
1341 
1342          reusable_nondeferred_variable *saved;
1343          u_vector_foreach(saved, &s->reusable_nondeferred_variables) {
1344             if (saved->var == deref->var) {
1345                nir_instr_remove(instr);
1346             }
1347          }
1348       }
1349    }
1350 
1351    done:
1352    u_vector_finish(&s->reusable_nondeferred_variables);
1353 }
1354 
1355 static void
1356 cull_primitive_accepted(nir_builder *b, void *state)
1357 {
1358    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *)state;
1359 
1360    nir_store_var(b, s->gs_accepted_var, nir_imm_true(b), 0x1u);
1361 
1362    /* Store the accepted state to LDS for ES threads */
1363    for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx)
1364       nir_store_shared(b, nir_imm_intN_t(b, 1, 8), s->vtx_addr[vtx], .base = lds_es_vertex_accepted);
1365 }
1366 
1367 static void
1368 clipdist_culling_es_part(nir_builder *b, lower_ngg_nogs_state *s,
1369                          nir_def *es_vertex_lds_addr)
1370 {
1371    /* gl_ClipDistance is not used, but user-defined clip planes are enabled */
1372    if (s->options->user_clip_plane_enable_mask && !s->has_clipdist) {
1373       /* use gl_ClipVertex if defined */
1374       nir_variable *clip_vertex_var =
1375          b->shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CLIP_VERTEX) ?
1376          s->clip_vertex_var : s->position_value_var;
1377       nir_def *clip_vertex = nir_load_var(b, clip_vertex_var);
1378 
1379       /* clip against user defined clip planes */
1380       for (unsigned i = 0; i < 8; i++) {
1381          if (!(s->options->user_clip_plane_enable_mask & BITFIELD_BIT(i)))
1382             continue;
1383 
1384          nir_def *plane = nir_load_user_clip_plane(b, .ucp_id = i);
1385          nir_def *dist = nir_fdot(b, clip_vertex, plane);
1386          add_clipdist_bit(b, dist, i, s->clipdist_neg_mask_var);
1387       }
1388 
1389       s->has_clipdist = true;
1390    }
1391 
1392    /* store clipdist_neg_mask to LDS for culling later in the GS thread */
1393    if (s->has_clipdist) {
1394       nir_def *mask = nir_load_var(b, s->clipdist_neg_mask_var);
1395       nir_store_shared(b, nir_u2u8(b, mask), es_vertex_lds_addr,
1396                        .base = lds_es_clipdist_neg_mask);
1397    }
1398 }
1399 
1400 static unsigned
1401 ngg_nogs_get_culling_pervertex_lds_size(gl_shader_stage stage,
1402                                         bool uses_instance_id,
1403                                         bool uses_primitive_id,
1404                                         unsigned *num_repacked_variables)
1405 {
1406    /* Culling shaders must repack some variables because
1407     * the same shader invocation may process different vertices
1408     * before and after the culling algorithm.
1409     */
1410 
1411    unsigned num_repacked;
1412    if (stage == MESA_SHADER_VERTEX) {
1413       /* Vertex shaders repack:
1414        * - Vertex ID
1415        * - Instance ID (only if used)
1416        */
1417       num_repacked = uses_instance_id ? 2 : 1;
1418    } else {
1419       /* Tess eval shaders repack:
1420        * - U, V coordinates
1421        * - primitive ID (aka. patch id, only if used)
1422        * - relative patch id (not counted here because it doesn't need a dword)
1423        */
1424       assert(stage == MESA_SHADER_TESS_EVAL);
1425       num_repacked = uses_primitive_id ? 3 : 2;
1426    }
1427 
1428    if (num_repacked_variables)
1429       *num_repacked_variables = num_repacked;
1430 
1431    /* pad to an odd number of dwords to reduce LDS bank conflicts */
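   /* For example, a base size of 16 bytes (4 dwords) becomes 16 | 4 = 20 bytes
    * (5 dwords), so the same field of consecutive vertices lands in different
    * LDS banks (assuming the base size is a multiple of 4 bytes).
    */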
1432    return (lds_es_arg_0 + num_repacked * 4u) | 4u;
1433 }
1434 
1435 static void
1436 add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *s)
1437 {
1438    bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
1439    bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
1440 
1441    unsigned num_repacked_variables;
1442    unsigned pervertex_lds_bytes =
1443       ngg_nogs_get_culling_pervertex_lds_size(b->shader->info.stage,
1444                                               uses_instance_id,
1445                                               uses_tess_primitive_id,
1446                                               &num_repacked_variables);
1447 
1448    nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
1449 
1450    /* Create some helper variables. */
1451    nir_variable *gs_vtxaddr_vars[3] = {
1452       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
1453       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
1454       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
1455    };
1456 
1457    nir_variable *repacked_variables[3] = {
1458       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_0"),
1459       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_1"),
1460       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_2"),
1461    };
1462 
1463    /* Relative patch ID is a special case because it doesn't need an extra dword, so repack it separately. */
1464    s->repacked_rel_patch_id = nir_local_variable_create(impl, glsl_uint_type(), "repacked_rel_patch_id");
1465 
1466    if (s->options->clip_cull_dist_mask ||
1467        s->options->user_clip_plane_enable_mask) {
1468       s->clip_vertex_var =
1469          nir_local_variable_create(impl, glsl_vec4_type(), "clip_vertex");
1470       s->clipdist_neg_mask_var =
1471          nir_local_variable_create(impl, glsl_uint_type(), "clipdist_neg_mask");
1472 
1473       /* init mask to 0 */
1474       nir_store_var(b, s->clipdist_neg_mask_var, nir_imm_int(b, 0), 1);
1475    }
1476 
1477    /* Top part of the culling shader (aka. position shader part)
1478     *
1479     * We clone the full ES shader and emit it here, but we only really care
1480     * about its position output, so we delete every other output from this part.
1481     * The position output is stored into a temporary variable, and reloaded later.
1482     */
1483 
1484    nir_def *es_thread = has_input_vertex(b);
1485    nir_if *if_es_thread = nir_push_if(b, es_thread);
1486    {
1487       /* Initialize the position output variable to zeroes, in case not all VS/TES invocations store the output.
1488        * The spec doesn't require it, but we use (0, 0, 0, 1) because some games rely on that.
1489        */
1490       nir_store_var(b, s->position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
1491 
1492       /* Now reinsert a clone of the shader code */
1493       struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
1494       nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
1495       _mesa_hash_table_destroy(remap_table, NULL);
1496       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1497 
1498       /* Remember the current thread's shader arguments */
1499       if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1500          nir_store_var(b, repacked_variables[0], nir_load_vertex_id_zero_base(b), 0x1u);
1501          if (uses_instance_id)
1502             nir_store_var(b, repacked_variables[1], nir_load_instance_id(b), 0x1u);
1503       } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1504          nir_store_var(b, s->repacked_rel_patch_id, nir_load_tess_rel_patch_id_amd(b), 0x1u);
1505          nir_def *tess_coord = nir_load_tess_coord(b);
1506          nir_store_var(b, repacked_variables[0], nir_channel(b, tess_coord, 0), 0x1u);
1507          nir_store_var(b, repacked_variables[1], nir_channel(b, tess_coord, 1), 0x1u);
1508          if (uses_tess_primitive_id)
1509             nir_store_var(b, repacked_variables[2], nir_load_primitive_id(b), 0x1u);
1510       } else {
1511          unreachable("Should be VS or TES.");
1512       }
1513    }
1514    nir_pop_if(b, if_es_thread);
1515 
1516    nir_store_var(b, s->es_accepted_var, es_thread, 0x1u);
1517    nir_def *gs_thread = has_input_primitive(b);
1518    nir_store_var(b, s->gs_accepted_var, gs_thread, 0x1u);
1519 
1520    /* Remove all non-position outputs, and put the position output into the variable. */
1521    nir_metadata_preserve(impl, nir_metadata_none);
1522    remove_culling_shader_outputs(b->shader, s);
1523    b->cursor = nir_after_impl(impl);
1524 
1525    nir_def *lds_scratch_base = nir_load_lds_ngg_scratch_base_amd(b);
1526 
1527    /* Run culling algorithms if culling is enabled.
1528     *
1529     * NGG culling can be enabled or disabled at runtime.
1530     * This is determined by a SGPR shader argument which is accessed
1531     * by the following NIR intrinsic.
1532     */
1533 
1534    nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
1535    {
1536       nir_def *invocation_index = nir_load_local_invocation_index(b);
1537       nir_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
1538 
1539       /* ES invocations store their vertex data to LDS for GS threads to read. */
1540       if_es_thread = nir_push_if(b, es_thread);
1541       if_es_thread->control = nir_selection_control_divergent_always_taken;
1542       {
1543          /* Store position components that are relevant to culling in LDS */
1544          nir_def *pre_cull_pos = nir_load_var(b, s->position_value_var);
1545          nir_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
1546          nir_store_shared(b, pre_cull_w, es_vertex_lds_addr, .base = lds_es_pos_w);
1547          nir_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
1548          nir_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
1549          nir_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .base = lds_es_pos_x);
1550 
1551          /* Clear out the ES accepted flag in LDS */
1552          nir_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .align_mul = 4, .base = lds_es_vertex_accepted);
1553 
1554          /* For clipdist culling */
1555          clipdist_culling_es_part(b, s, es_vertex_lds_addr);
1556       }
1557       nir_pop_if(b, if_es_thread);
1558 
1559       nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1560                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1561 
1562       nir_store_var(b, s->gs_accepted_var, nir_imm_false(b), 0x1u);
1563       nir_store_var(b, s->prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
1564 
1565       /* GS invocations load the vertex data and perform the culling. */
1566       nir_if *if_gs_thread = nir_push_if(b, gs_thread);
1567       {
1568          /* Load vertex indices from input VGPRs */
1569          nir_def *vtx_idx[3] = {0};
1570          for (unsigned vertex = 0; vertex < s->options->num_vertices_per_primitive;
1571               ++vertex)
1572             vtx_idx[vertex] = nir_load_var(b, s->gs_vtx_indices_vars[vertex]);
1573 
1574          nir_def *pos[3][4] = {0};
1575 
1576          /* Load W positions of vertices first because the culling code will use these first */
1577          for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1578             s->vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1579             pos[vtx][3] = nir_load_shared(b, 1, 32, s->vtx_addr[vtx], .base = lds_es_pos_w);
1580             nir_store_var(b, gs_vtxaddr_vars[vtx], s->vtx_addr[vtx], 0x1u);
1581          }
1582 
1583          /* Load the X/W, Y/W positions of vertices */
1584          for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1585             nir_def *xy = nir_load_shared(b, 2, 32, s->vtx_addr[vtx], .base = lds_es_pos_x);
1586             pos[vtx][0] = nir_channel(b, xy, 0);
1587             pos[vtx][1] = nir_channel(b, xy, 1);
1588          }
1589 
1590          nir_def *accepted_by_clipdist;
1591          if (s->has_clipdist) {
1592             nir_def *clipdist_neg_mask = nir_imm_intN_t(b, 0xff, 8);
1593             for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1594                nir_def *mask =
1595                   nir_load_shared(b, 1, 8, s->vtx_addr[vtx],
1596                                   .base = lds_es_clipdist_neg_mask);
1597                clipdist_neg_mask = nir_iand(b, clipdist_neg_mask, mask);
1598             }
1599             /* the primitive is culled if, for any plane, the clip distances of all vertices are negative */
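            /* For example, per-vertex masks 0b01, 0b01 and 0b11 AND to 0b01:
             * every vertex is on the negative side of clip plane 0, so the
             * primitive can be culled.
             */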
1600             accepted_by_clipdist = nir_ieq_imm(b, clipdist_neg_mask, 0);
1601          } else {
1602             accepted_by_clipdist = nir_imm_true(b);
1603          }
1604 
1605          /* See if the current primitive is accepted */
1606          ac_nir_cull_primitive(b, accepted_by_clipdist, pos,
1607                                s->options->num_vertices_per_primitive,
1608                                cull_primitive_accepted, s);
1609       }
1610       nir_pop_if(b, if_gs_thread);
1611 
1612       nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1613                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1614 
1615       nir_store_var(b, s->es_accepted_var, nir_imm_false(b), 0x1u);
1616 
1617       /* ES invocations load their accepted flag from LDS. */
1618       if_es_thread = nir_push_if(b, es_thread);
1619       if_es_thread->control = nir_selection_control_divergent_always_taken;
1620       {
1621          nir_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1622          nir_def *accepted_bool = nir_ine_imm(b, nir_u2u32(b, accepted), 0);
1623          nir_store_var(b, s->es_accepted_var, accepted_bool, 0x1u);
1624       }
1625       nir_pop_if(b, if_es_thread);
1626 
1627       nir_def *es_accepted = nir_load_var(b, s->es_accepted_var);
1628 
1629       /* Repack the vertices that survived the culling. */
1630       wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, lds_scratch_base,
1631                                                              s->max_num_waves,
1632                                                              s->options->wave_size);
1633       nir_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
1634       nir_def *es_exporter_tid = rep.repacked_invocation_index;
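      /* As used here, rep.num_repacked_invocations is the number of surviving
       * vertices in the whole workgroup, and rep.repacked_invocation_index is
       * this invocation's new (compacted) vertex index.
       */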
1635 
1636       /* If all vertices are culled, set primitive count to 0 as well. */
1637       nir_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b);
1638       nir_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
1639       num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), num_exported_prims);
1640       nir_store_var(b, s->gs_exported_var, nir_iand(b, nir_inot(b, fully_culled), has_input_primitive(b)), 0x1u);
1641 
1642       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
1643       {
1644          /* Tell the final vertex and primitive count to the HW. */
1645          if (s->options->gfx_level == GFX10) {
1646             alloc_vertices_and_primitives_gfx10_workaround(
1647                b, num_live_vertices_in_workgroup, num_exported_prims);
1648          } else {
1649             alloc_vertices_and_primitives(
1650                b, num_live_vertices_in_workgroup, num_exported_prims);
1651          }
1652       }
1653       nir_pop_if(b, if_wave_0);
1654 
1655       /* Vertex compaction. */
1656       compact_vertices_after_culling(b, s,
1657                                      repacked_variables, gs_vtxaddr_vars,
1658                                      invocation_index, es_vertex_lds_addr,
1659                                      es_exporter_tid, num_live_vertices_in_workgroup,
1660                                      pervertex_lds_bytes, num_repacked_variables);
1661    }
1662    nir_push_else(b, if_cull_en);
1663    {
1664       /* When culling is disabled, we do the same as we would without culling. */
1665       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
1666       {
1667          nir_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
1668          nir_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
1669          alloc_vertices_and_primitives(b, vtx_cnt, prim_cnt);
1670       }
1671       nir_pop_if(b, if_wave_0);
1672       nir_store_var(b, s->prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, s), 0x1u);
1673    }
1674    nir_pop_if(b, if_cull_en);
1675 
1676    /* Update shader arguments.
1677     *
1678     * The registers which hold information about the subgroup's
1679     * vertices and primitives are updated here, so the rest of the shader
1680     * doesn't need to worry about the culling.
1681     *
1682     * These "overwrite" intrinsics must be at top level control flow,
1683     * otherwise they can mess up the backend (eg. ACO's SSA).
1684     *
1685     * TODO:
1686     * A cleaner solution would be to simply replace all usages of these args
1687     * with the load of the variables.
1688     * However, this wouldn't work right now because the backend uses the arguments
1689     * for purposes not expressed in NIR, eg. VS input loads, etc.
1690     * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
1691     */
1692 
1693    if (b->shader->info.stage == MESA_SHADER_VERTEX)
1694       s->overwrite_args =
1695          nir_overwrite_vs_arguments_amd(b,
1696             nir_load_var(b, repacked_variables[0]), nir_load_var(b, repacked_variables[1]));
1697    else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1698       s->overwrite_args =
1699          nir_overwrite_tes_arguments_amd(b,
1700             nir_load_var(b, repacked_variables[0]), nir_load_var(b, repacked_variables[1]),
1701             nir_load_var(b, repacked_variables[2]), nir_load_var(b, s->repacked_rel_patch_id));
1702    else
1703       unreachable("Should be VS or TES.");
1704 }
1705 
1706 static void
1707 ngg_nogs_store_edgeflag_to_lds(nir_builder *b, lower_ngg_nogs_state *s)
1708 {
1709    if (!s->outputs[VARYING_SLOT_EDGE][0])
1710       return;
1711 
1712    /* clamp the user edge flag to 1 for the bit operations below */
1713    nir_def *edgeflag = s->outputs[VARYING_SLOT_EDGE][0];
1714    edgeflag = nir_umin(b, edgeflag, nir_imm_int(b, 1));
1715 
1716    /* the user edge flag is stored at the beginning of the vertex's LDS area if streamout is not enabled */
1717    unsigned offset = 0;
1718    if (s->streamout_enabled) {
1719       unsigned packed_location =
1720          util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(VARYING_SLOT_EDGE));
1721       offset = packed_location * 16;
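      /* With streamout, this is the start of the 16-byte slot that VARYING_SLOT_EDGE
       * itself occupies in the packed per-vertex layout used for streamout outputs.
       */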
1722    }
1723 
1724    nir_def *tid = nir_load_local_invocation_index(b);
1725    nir_def *addr = pervertex_lds_addr(b, tid, s->pervertex_lds_bytes);
1726 
1727    nir_store_shared(b, edgeflag, addr, .base = offset);
1728 }
1729 
1730 static void
1731 ngg_nogs_store_xfb_outputs_to_lds(nir_builder *b, lower_ngg_nogs_state *s)
1732 {
1733    nir_xfb_info *info = b->shader->xfb_info;
1734 
1735    uint64_t xfb_outputs = 0;
1736    unsigned xfb_outputs_16bit = 0;
1737    uint8_t xfb_mask[VARYING_SLOT_MAX] = {0};
1738    uint8_t xfb_mask_16bit_lo[16] = {0};
1739    uint8_t xfb_mask_16bit_hi[16] = {0};
1740 
1741    /* Get XFB output mask for each slot. */
1742    for (int i = 0; i < info->output_count; i++) {
1743       nir_xfb_output_info *out = info->outputs + i;
1744 
1745       if (out->location < VARYING_SLOT_VAR0_16BIT) {
1746          xfb_outputs |= BITFIELD64_BIT(out->location);
1747          xfb_mask[out->location] |= out->component_mask;
1748       } else {
1749          unsigned index = out->location - VARYING_SLOT_VAR0_16BIT;
1750          xfb_outputs_16bit |= BITFIELD_BIT(index);
1751 
1752          if (out->high_16bits)
1753             xfb_mask_16bit_hi[index] |= out->component_mask;
1754          else
1755             xfb_mask_16bit_lo[index] |= out->component_mask;
1756       }
1757    }
1758 
1759    nir_def *tid = nir_load_local_invocation_index(b);
1760    nir_def *addr = pervertex_lds_addr(b, tid, s->pervertex_lds_bytes);
1761 
1762    u_foreach_bit64(slot, xfb_outputs) {
1763       unsigned packed_location =
1764          util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(slot));
1765 
1766       unsigned mask = xfb_mask[slot];
1767 
1768       /* Clear unused components. */
1769       for (unsigned i = 0; i < 4; i++) {
1770          if (!s->outputs[slot][i])
1771             mask &= ~BITFIELD_BIT(i);
1772       }
1773 
1774       while (mask) {
1775          int start, count;
1776          u_bit_scan_consecutive_range(&mask, &start, &count);
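         /* u_bit_scan_consecutive_range pops one run of set bits at a time:
          * e.g. mask = 0b0110 yields start = 1, count = 2, so components y and z
          * are written with a single vec2 store.
          */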
1777          /* Outputs here are guaranteed to be 32-bit.
1778           *
1779           * 64-bit outputs have already been lowered to two 32-bit components.
1780           * As for 16-bit outputs: Vulkan does not allow streamout of outputs
1781           * smaller than 32 bits, and OpenGL puts 16-bit outputs in VARYING_SLOT_VAR0_16BIT.
1782           */
1783          nir_def *store_val = nir_vec(b, &s->outputs[slot][start], (unsigned)count);
1784          nir_store_shared(b, store_val, addr, .base = packed_location * 16 + start * 4);
1785       }
1786    }
1787 
1788    unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
1789    u_foreach_bit64(slot, xfb_outputs_16bit) {
1790       unsigned packed_location = num_32bit_outputs +
1791          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));
1792 
1793       unsigned mask_lo = xfb_mask_16bit_lo[slot];
1794       unsigned mask_hi = xfb_mask_16bit_hi[slot];
1795 
1796       /* Clear unused components. */
1797       for (unsigned i = 0; i < 4; i++) {
1798          if (!s->outputs_16bit_lo[slot][i])
1799             mask_lo &= ~BITFIELD_BIT(i);
1800          if (!s->outputs_16bit_hi[slot][i])
1801             mask_hi &= ~BITFIELD_BIT(i);
1802       }
1803 
1804       nir_def **outputs_lo = s->outputs_16bit_lo[slot];
1805       nir_def **outputs_hi = s->outputs_16bit_hi[slot];
1806       nir_def *undef = nir_undef(b, 1, 16);
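      /* The lo/hi halves of a 16-bit slot share one 32-bit LDS dword per component:
       * nir_pack_32_2x16_split puts the lo half in bits [15:0] and the hi half in
       * bits [31:16]; an undef fills whichever half was not written.
       */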
1807 
1808       unsigned mask = mask_lo | mask_hi;
1809       while (mask) {
1810          int start, count;
1811          u_bit_scan_consecutive_range(&mask, &start, &count);
1812 
1813          nir_def *values[4] = {0};
1814          for (int c = start; c < start + count; ++c) {
1815             nir_def *lo = mask_lo & BITFIELD_BIT(c) ? outputs_lo[c] : undef;
1816             nir_def *hi = mask_hi & BITFIELD_BIT(c) ? outputs_hi[c] : undef;
1817 
1818             /* pack the 16-bit lo/hi halves into one 32-bit value; 64-bit outputs have already been lowered */
1819             values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
1820          }
1821 
1822          nir_def *store_val = nir_vec(b, values, (unsigned)count);
1823          nir_store_shared(b, store_val, addr, .base = packed_location * 16 + start * 4);
1824       }
1825    }
1826 }
1827 
1828 static void
1829 ngg_build_streamout_buffer_info(nir_builder *b,
1830                                 nir_xfb_info *info,
1831                                 bool has_xfb_prim_query,
1832                                 nir_def *scratch_base,
1833                                 nir_def *tid_in_tg,
1834                                 nir_def *gen_prim[4],
1835                                 nir_def *prim_stride_ret[4],
1836                                 nir_def *so_buffer_ret[4],
1837                                 nir_def *buffer_offsets_ret[4],
1838                                 nir_def *emit_prim_ret[4])
1839 {
1840    nir_def *undef = nir_undef(b, 1, 32);
1841 
1842    /* radeonsi passes this value as a shader argument for VS. Streamout needs the
1843     * accurate number of vertices per primitive to write the right amount of data to the buffer.
1844     */
1845    nir_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
1846    for (unsigned buffer = 0; buffer < 4; buffer++) {
1847       if (!(info->buffers_written & BITFIELD_BIT(buffer)))
1848          continue;
1849 
1850       assert(info->buffers[buffer].stride);
1851 
1852       prim_stride_ret[buffer] =
1853          nir_imul_imm(b, num_vert_per_prim, info->buffers[buffer].stride);
1854       so_buffer_ret[buffer] = nir_load_streamout_buffer_amd(b, .base = buffer);
1855    }
1856 
1857    nir_if *if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
1858    {
1859       nir_def *workgroup_buffer_sizes[4];
1860       for (unsigned buffer = 0; buffer < 4; buffer++) {
1861          if (info->buffers_written & BITFIELD_BIT(buffer)) {
1862             nir_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
1863             /* In radeonsi, we may not know at compile time whether a feedback
1864              * buffer has been bound, so we have to check the buffer size at runtime
1865              * and skip the GDS update for unbound buffers. Otherwise a previous draw
1866              * compiled with streamout that did not bind a feedback buffer would miss
1867              * its GDS update, which would break the current draw's streamout.
1868              */
1869             nir_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
1870             nir_def *inc_buffer_size =
1871                nir_imul(b, gen_prim[info->buffer_to_stream[buffer]], prim_stride_ret[buffer]);
1872             workgroup_buffer_sizes[buffer] =
1873                nir_bcsel(b, buffer_valid, inc_buffer_size, nir_imm_int(b, 0));
1874          } else
1875             workgroup_buffer_sizes[buffer] = undef;
1876       }
1877 
1878       nir_def *ordered_id = nir_load_ordered_id_amd(b);
1879       /* Get the current global offset of each buffer and increase it by the
1880        * workgroup's buffer size. This is an ordered operation sorted by
1881        * ordered_id; each buffer's info is in one channel of a vec4.
1882        */
1883       nir_def *buffer_offsets =
1884          nir_ordered_xfb_counter_add_amd(b, ordered_id, nir_vec(b, workgroup_buffer_sizes, 4),
1885                                          /* mask of buffers to update */
1886                                          .write_mask = info->buffers_written);
1887 
1888       nir_def *emit_prim[4];
1889       memcpy(emit_prim, gen_prim, 4 * sizeof(nir_def *));
1890 
1891       nir_def *any_overflow = nir_imm_false(b);
1892       nir_def *overflow_amount[4] = {undef, undef, undef, undef};
1893 
1894       for (unsigned buffer = 0; buffer < 4; buffer++) {
1895          if (!(info->buffers_written & BITFIELD_BIT(buffer)))
1896             continue;
1897 
1898          nir_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
1899 
1900          /* Only consider overflow for valid feedback buffers because
1901           * otherwise the ordered operation above (GDS atomic return) might
1902           * return non-zero offsets for invalid buffers.
1903           */
1904          nir_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
1905          nir_def *buffer_offset = nir_channel(b, buffer_offsets, buffer);
1906          buffer_offset = nir_bcsel(b, buffer_valid, buffer_offset, nir_imm_int(b, 0));
1907 
1908          nir_def *remain_size = nir_isub(b, buffer_size, buffer_offset);
1909          nir_def *remain_prim = nir_idiv(b, remain_size, prim_stride_ret[buffer]);
1910          nir_def *overflow = nir_ilt(b, buffer_size, buffer_offset);
1911 
1912          any_overflow = nir_ior(b, any_overflow, overflow);
1913          overflow_amount[buffer] = nir_imax(b, nir_imm_int(b, 0),
1914                                             nir_isub(b, buffer_offset, buffer_size));
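         /* Example: with buffer_size = 256, buffer_offset = 320 and prim_stride = 48,
          * overflow is true and overflow_amount = 64; with buffer_offset = 192 instead,
          * remain_size = 64 and remain_prim = 1.
          */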
1915 
1916          unsigned stream = info->buffer_to_stream[buffer];
1917          /* if a previous workgroup overflowed, we can't emit any primitives */
1918          emit_prim[stream] = nir_bcsel(
1919             b, overflow, nir_imm_int(b, 0),
1920             /* we can emit some of the primitives, limited by the smallest buffer */
1921             nir_imin(b, emit_prim[stream], remain_prim));
1922 
1923          /* Save to LDS so other waves in this workgroup can access it. */
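         /* Scratch layout used here: bytes 0..15 hold the per-buffer offsets and
          * bytes 16..31 hold the per-stream emitted-primitive counts.
          */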
1924          nir_store_shared(b, buffer_offset, scratch_base, .base = buffer * 4);
1925       }
1926 
1927       /* We have to fix up the streamout offsets if we overflowed because they determine
1928        * the vertex count for DrawTransformFeedback.
1929        */
1930       nir_if *if_any_overflow = nir_push_if(b, any_overflow);
1931       {
1932          nir_xfb_counter_sub_amd(b, nir_vec(b, overflow_amount, 4),
1933                                  /* mask of buffers to update */
1934                                  .write_mask = info->buffers_written);
1935       }
1936       nir_pop_if(b, if_any_overflow);
1937 
1938       /* Save to LDS so other waves in this workgroup can access it. */
1939       for (unsigned stream = 0; stream < 4; stream++) {
1940          if (!(info->streams_written & BITFIELD_BIT(stream)))
1941             continue;
1942 
1943          nir_store_shared(b, emit_prim[stream], scratch_base, .base = 16 + stream * 4);
1944       }
1945 
1946       /* Update shader query. */
1947       if (has_xfb_prim_query) {
1948          nir_if *if_shader_query = nir_push_if(b, nir_load_prim_xfb_query_enabled_amd(b));
1949          {
1950             for (unsigned stream = 0; stream < 4; stream++) {
1951                if (info->streams_written & BITFIELD_BIT(stream))
1952                   nir_atomic_add_xfb_prim_count_amd(b, emit_prim[stream], .stream_id = stream);
1953             }
1954          }
1955          nir_pop_if(b, if_shader_query);
1956       }
1957    }
1958    nir_pop_if(b, if_invocation_0);
1959 
1960    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1961                       .memory_scope = SCOPE_WORKGROUP,
1962                       .memory_semantics = NIR_MEMORY_ACQ_REL,
1963                       .memory_modes = nir_var_mem_shared);
1964 
1965    /* Fetch the per-buffer offsets in all waves. */
1966    for (unsigned buffer = 0; buffer < 4; buffer++) {
1967       if (!(info->buffers_written & BITFIELD_BIT(buffer)))
1968          continue;
1969 
1970       buffer_offsets_ret[buffer] =
1971          nir_load_shared(b, 1, 32, scratch_base, .base = buffer * 4);
1972    }
1973 
1974    /* Fetch the per-stream emit prim in all waves. */
1975    for (unsigned stream = 0; stream < 4; stream++) {
1976       if (!(info->streams_written & BITFIELD_BIT(stream)))
1977          continue;
1978 
1979       emit_prim_ret[stream] =
1980          nir_load_shared(b, 1, 32, scratch_base, .base = 16 + stream * 4);
1981    }
1982 }
1983 
1984 static void
1985 ngg_build_streamout_vertex(nir_builder *b, nir_xfb_info *info,
1986                            unsigned stream, nir_def *so_buffer[4],
1987                            nir_def *buffer_offsets[4],
1988                            nir_def *vtx_buffer_idx, nir_def *vtx_lds_addr,
1989                            shader_output_types *output_types)
1990 {
1991    nir_def *vtx_buffer_offsets[4];
1992    for (unsigned buffer = 0; buffer < 4; buffer++) {
1993       if (!(info->buffers_written & BITFIELD_BIT(buffer)))
1994          continue;
1995 
1996       nir_def *offset = nir_imul_imm(b, vtx_buffer_idx, info->buffers[buffer].stride);
1997       vtx_buffer_offsets[buffer] = nir_iadd(b, buffer_offsets[buffer], offset);
1998    }
1999 
2000    for (unsigned i = 0; i < info->output_count; i++) {
2001       nir_xfb_output_info *out = info->outputs + i;
2002       if (!out->component_mask || info->buffer_to_stream[out->buffer] != stream)
2003          continue;
2004 
2005       unsigned base;
2006       if (out->location >= VARYING_SLOT_VAR0_16BIT) {
2007          base =
2008             util_bitcount64(b->shader->info.outputs_written) +
2009             util_bitcount(b->shader->info.outputs_written_16bit &
2010                           BITFIELD_MASK(out->location - VARYING_SLOT_VAR0_16BIT));
2011       } else {
2012          base =
2013             util_bitcount64(b->shader->info.outputs_written &
2014                             BITFIELD64_MASK(out->location));
2015       }
2016 
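      /* Each packed slot occupies 16 bytes (4 dwords) in the per-vertex LDS area,
       * so e.g. base = 2 and component_offset = 1 give a byte offset of
       * (2 * 4 + 1) * 4 = 36.
       */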
2017       unsigned offset = (base * 4 + out->component_offset) * 4;
2018       unsigned count = util_bitcount(out->component_mask);
2019 
2020       assert(u_bit_consecutive(out->component_offset, count) == out->component_mask);
2021 
2022       nir_def *out_data =
2023          nir_load_shared(b, count, 32, vtx_lds_addr, .base = offset);
2024 
2025       /* Up-scale 16-bit outputs to 32-bit.
2026        *
2027        * OpenGL ES puts 16-bit medium-precision varyings in VARYING_SLOT_VAR0_16BIT,
2028        * and we need to up-scale them to 32-bit when streaming out to a buffer.
2029        *
2030        * Vulkan does not allow 8/16-bit varyings to be streamed out.
2031        */
2032       if (out->location >= VARYING_SLOT_VAR0_16BIT) {
2033          unsigned index = out->location - VARYING_SLOT_VAR0_16BIT;
2034          nir_def *values[4];
2035 
2036          for (int j = 0; j < count; j++) {
2037             unsigned c = out->component_offset + j;
2038             nir_def *v = nir_channel(b, out_data, j);
2039             nir_alu_type t;
2040 
2041             if (out->high_16bits) {
2042                v = nir_unpack_32_2x16_split_y(b, v);
2043                t = output_types->types_16bit_hi[index][c];
2044             } else {
2045                v = nir_unpack_32_2x16_split_x(b, v);
2046                t = output_types->types_16bit_lo[index][c];
2047             }
2048 
2049             t = nir_alu_type_get_base_type(t);
2050             values[j] = nir_convert_to_bit_size(b, v, t, 32);
2051          }
2052 
2053          out_data = nir_vec(b, values, count);
2054       }
2055 
2056       nir_def *zero = nir_imm_int(b, 0);
2057       nir_store_buffer_amd(b, out_data, so_buffer[out->buffer],
2058                            vtx_buffer_offsets[out->buffer],
2059                            zero, zero,
2060                            .base = out->offset,
2061                            .memory_modes = nir_var_mem_ssbo,
2062                            .access = ACCESS_NON_TEMPORAL);
2063    }
2064 }
2065 
2066 static void
2067 ngg_nogs_build_streamout(nir_builder *b, lower_ngg_nogs_state *s)
2068 {
2069    nir_xfb_info *info = b->shader->xfb_info;
2070 
2071    nir_def *lds_scratch_base = nir_load_lds_ngg_scratch_base_amd(b);
2072 
2073    /* Get global buffer offset where this workgroup will stream out data to. */
2074    nir_def *generated_prim = nir_load_workgroup_num_input_primitives_amd(b);
2075    nir_def *gen_prim_per_stream[4] = {generated_prim, 0, 0, 0};
2076    nir_def *emit_prim_per_stream[4] = {0};
2077    nir_def *buffer_offsets[4] = {0};
2078    nir_def *so_buffer[4] = {0};
2079    nir_def *prim_stride[4] = {0};
2080    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
2081    ngg_build_streamout_buffer_info(b, info, s->options->has_xfb_prim_query,
2082                                    lds_scratch_base, tid_in_tg,
2083                                    gen_prim_per_stream, prim_stride,
2084                                    so_buffer, buffer_offsets,
2085                                    emit_prim_per_stream);
2086 
2087    /* Write out primitive data */
2088    nir_if *if_emit = nir_push_if(b, nir_ilt(b, tid_in_tg, emit_prim_per_stream[0]));
2089    {
2090       unsigned vtx_lds_stride = (b->shader->num_outputs * 4 + 1) * 4;
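      /* This stride must match ngg_nogs_get_pervertex_lds_size(): 4 dwords per
       * output plus 1 padding dword, in bytes; e.g. 8 outputs -> (8 * 4 + 1) * 4 = 132.
       */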
2091       nir_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
2092       nir_def *vtx_buffer_idx = nir_imul(b, tid_in_tg, num_vert_per_prim);
2093 
2094       for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++) {
2095          nir_if *if_valid_vertex =
2096             nir_push_if(b, nir_igt_imm(b, num_vert_per_prim, i));
2097          {
2098             nir_def *vtx_lds_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
2099             nir_def *vtx_lds_addr = pervertex_lds_addr(b, vtx_lds_idx, vtx_lds_stride);
2100             ngg_build_streamout_vertex(b, info, 0, so_buffer, buffer_offsets,
2101                                        nir_iadd_imm(b, vtx_buffer_idx, i),
2102                                        vtx_lds_addr, &s->output_types);
2103          }
2104          nir_pop_if(b, if_valid_vertex);
2105       }
2106    }
2107    nir_pop_if(b, if_emit);
2108 
2109    /* Wait for streamout memory operations to finish before the primitive export,
2110     * otherwise they may not be done by the time the shader ends.
2111     *
2112     * If a shader has no param exports, rasterization can start before
2113     * the shader finishes and thus memory stores might not finish before
2114     * the pixel shader starts.
2115     *
2116     * TODO: we only need this when there are no param exports.
2117     *
2118     * TODO: not sure if we need this barrier with late prim export, as I
2119     *       can't observe any test failing without it.
2120     */
2121    nir_scoped_memory_barrier(b, SCOPE_DEVICE, NIR_MEMORY_RELEASE, nir_var_mem_ssbo);
2122 }
2123 
2124 static unsigned
2125 ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
2126                                 unsigned shader_num_outputs,
2127                                 bool streamout_enabled,
2128                                 bool export_prim_id,
2129                                 bool has_user_edgeflags)
2130 {
2131    unsigned pervertex_lds_bytes = 0;
2132 
2133    if (streamout_enabled) {
2134       /* The extra dword is used to avoid LDS bank conflicts and store the primitive id.
2135        * TODO: only alloc space for outputs that really need streamout.
2136        */
2137       pervertex_lds_bytes = (shader_num_outputs * 4 + 1) * 4;
2138    }
2139 
2140    bool need_prim_id_store_shared = export_prim_id && stage == MESA_SHADER_VERTEX;
2141    if (need_prim_id_store_shared || has_user_edgeflags) {
2142       unsigned size = 0;
2143       if (need_prim_id_store_shared)
2144          size += 4;
2145       if (has_user_edgeflags)
2146          size += 4;
2147 
2148       /* pad to an odd number of dwords to avoid LDS bank conflicts */
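      /* E.g. prim id + edge flag need 8 bytes; 8 | 4 = 12 bytes (3 dwords) keeps
       * the stride at an odd number of dwords.
       */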
2149       size |= 4;
2150 
2151       pervertex_lds_bytes = MAX2(pervertex_lds_bytes, size);
2152    }
2153 
2154    return pervertex_lds_bytes;
2155 }
2156 
2157 static void
2158 ngg_nogs_gather_outputs(nir_builder *b, struct exec_list *cf_list, lower_ngg_nogs_state *s)
2159 {
2160    /* Assume:
2161     * - the shader used nir_lower_io_to_temporaries
2162     * - 64-bit outputs are lowered
2163     * - no indirect indexing is present
2164     */
2165    struct nir_cf_node *first_node =
2166       exec_node_data(nir_cf_node, exec_list_get_head(cf_list), node);
2167 
2168    for (nir_block *block = nir_cf_node_cf_tree_first(first_node); block != NULL;
2169         block = nir_block_cf_tree_next(block)) {
2170       nir_foreach_instr_safe (instr, block) {
2171          if (instr->type != nir_instr_type_intrinsic)
2172             continue;
2173 
2174          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2175          if (intrin->intrinsic != nir_intrinsic_store_output)
2176             continue;
2177 
2178          assert(nir_src_is_const(intrin->src[1]) && !nir_src_as_uint(intrin->src[1]));
2179 
2180          nir_io_semantics sem = nir_intrinsic_io_semantics(intrin);
2181          unsigned slot = sem.location;
2182 
2183          nir_def **output;
2184          nir_alu_type *type;
2185          if (slot >= VARYING_SLOT_VAR0_16BIT) {
2186             unsigned index = slot - VARYING_SLOT_VAR0_16BIT;
2187             if (sem.high_16bits) {
2188                output = s->outputs_16bit_hi[index];
2189                type = s->output_types.types_16bit_hi[index];
2190             } else {
2191                output = s->outputs_16bit_lo[index];
2192                type = s->output_types.types_16bit_lo[index];
2193             }
2194          } else {
2195             output = s->outputs[slot];
2196             type = s->output_types.types[slot];
2197          }
2198 
2199          unsigned component = nir_intrinsic_component(intrin);
2200          unsigned write_mask = nir_intrinsic_write_mask(intrin);
2201          nir_alu_type src_type = nir_intrinsic_src_type(intrin);
2202 
2203          u_foreach_bit (i, write_mask) {
2204             unsigned c = component + i;
2205             output[c] = nir_channel(b, intrin->src[0].ssa, i);
2206             type[c] = src_type;
2207          }
2208 
2209          /* remove all store output instructions */
2210          nir_instr_remove(instr);
2211       }
2212    }
2213 }
2214 
2215 static unsigned
2216 gather_vs_outputs(nir_builder *b, vs_output *outputs,
2217                   const uint8_t *param_offsets,
2218                   nir_def *(*data)[4],
2219                   nir_def *(*data_16bit_lo)[4],
2220                   nir_def *(*data_16bit_hi)[4])
2221 {
2222    unsigned num_outputs = 0;
2223    u_foreach_bit64 (slot, b->shader->info.outputs_written) {
2224       if (param_offsets[slot] > AC_EXP_PARAM_OFFSET_31)
2225          continue;
2226 
2227       nir_def **output = data[slot];
2228 
2229       /* skip the output if no component of it has been written */
2230       if (!output[0] && !output[1] && !output[2] && !output[3])
2231          continue;
2232 
2233       outputs[num_outputs].slot = slot;
2234       for (int i = 0; i < 4; i++) {
2235          nir_def *chan = output[i];
2236          /* RADV implements 16-bit outputs as 32-bit with VARYING_SLOT_VAR0-31. */
2237          outputs[num_outputs].chan[i] = chan && chan->bit_size == 16 ? nir_u2u32(b, chan) : chan;
2238       }
2239       num_outputs++;
2240    }
2241 
2242    u_foreach_bit (i, b->shader->info.outputs_written_16bit) {
2243       unsigned slot = VARYING_SLOT_VAR0_16BIT + i;
2244       if (param_offsets[slot] > AC_EXP_PARAM_OFFSET_31)
2245          continue;
2246 
2247       nir_def **output_lo = data_16bit_lo[i];
2248       nir_def **output_hi = data_16bit_hi[i];
2249 
2250       /* skip the output if no component of it has been written */
2251       if (!output_lo[0] && !output_lo[1] && !output_lo[2] && !output_lo[3] &&
2252           !output_hi[0] && !output_hi[1] && !output_hi[2] && !output_hi[3])
2253          continue;
2254 
2255       vs_output *output = &outputs[num_outputs++];
2256       output->slot = slot;
2257 
2258       nir_def *undef = nir_undef(b, 1, 16);
2259       for (int j = 0; j < 4; j++) {
2260          nir_def *lo = output_lo[j] ? output_lo[j] : undef;
2261          nir_def *hi = output_hi[j] ? output_hi[j] : undef;
2262          if (output_lo[j] || output_hi[j])
2263             output->chan[j] = nir_pack_32_2x16_split(b, lo, hi);
2264          else
2265             output->chan[j] = NULL;
2266       }
2267    }
2268 
2269    return num_outputs;
2270 }
2271 
2272 static void
2273 create_vertex_param_phis(nir_builder *b, unsigned num_outputs, vs_output *outputs)
2274 {
2275    nir_def *undef = nir_undef(b, 1, 32); /* inserted at the start of the shader */
2276 
2277    for (unsigned i = 0; i < num_outputs; i++) {
2278       for (unsigned j = 0; j < 4; j++) {
2279          if (outputs[i].chan[j])
2280             outputs[i].chan[j] = nir_if_phi(b, outputs[i].chan[j], undef);
2281       }
2282    }
2283 }
2284 
2285 static void
2286 export_vertex_params_gfx11(nir_builder *b, nir_def *export_tid, nir_def *num_export_threads,
2287                            unsigned num_outputs, vs_output *outputs,
2288                            const uint8_t *vs_output_param_offset)
2289 {
2290    nir_def *attr_rsrc = nir_load_ring_attr_amd(b);
2291 
2292    /* We should always store full vec4s in groups of 8 lanes for the best performance even if
2293     * some of them are garbage or have unused components, so align the number of export threads
2294     * to 8.
2295     */
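   /* E.g. 13 export threads round up to (13 + 7) & ~7 = 16, two full groups of 8 lanes. */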
2296    num_export_threads = nir_iand_imm(b, nir_iadd_imm(b, num_export_threads, 7), ~7);
2297    if (!export_tid)
2298       nir_push_if(b, nir_is_subgroup_invocation_lt_amd(b, num_export_threads));
2299    else
2300       nir_push_if(b, nir_ult(b, export_tid, num_export_threads));
2301 
2302    nir_def *attr_offset = nir_load_ring_attr_offset_amd(b);
2303    nir_def *vindex = nir_load_local_invocation_index(b);
2304    nir_def *voffset = nir_imm_int(b, 0);
2305    nir_def *undef = nir_undef(b, 1, 32);
2306 
2307    uint32_t exported_params = 0;
2308 
2309    for (unsigned i = 0; i < num_outputs; i++) {
2310       gl_varying_slot slot = outputs[i].slot;
2311       unsigned offset = vs_output_param_offset[slot];
2312 
2313       /* Since vs_output_param_offset[] can map multiple varying slots to
2314        * the same param export index (that's radeonsi-specific behavior),
2315        * we need to do this so as not to emit duplicated exports.
2316        */
2317       if (exported_params & BITFIELD_BIT(offset))
2318          continue;
2319 
2320       nir_def *comp[4];
2321       for (unsigned j = 0; j < 4; j++)
2322          comp[j] = outputs[i].chan[j] ? outputs[i].chan[j] : undef;
2323       nir_store_buffer_amd(b, nir_vec(b, comp, 4), attr_rsrc, voffset, attr_offset, vindex,
2324                            .base = offset * 16,
2325                            .memory_modes = nir_var_shader_out,
2326                            .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
2327       exported_params |= BITFIELD_BIT(offset);
2328    }
2329 
2330    nir_pop_if(b, NULL);
2331 }
2332 
2333 static bool must_wait_attr_ring(enum amd_gfx_level gfx_level, bool has_param_exports)
2334 {
2335    return (gfx_level == GFX11 || gfx_level == GFX11_5) && has_param_exports;
2336 }
2337 
2338 static void
2339 export_pos0_wait_attr_ring(nir_builder *b, nir_if *if_es_thread, nir_def *outputs[VARYING_SLOT_MAX][4], const ac_nir_lower_ngg_options *options)
2340 {
2341    b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2342 
2343    /* Create phi for the position output values. */
2344    vs_output pos_output = {
2345       .slot = VARYING_SLOT_POS,
2346       .chan = {
2347          outputs[VARYING_SLOT_POS][0],
2348          outputs[VARYING_SLOT_POS][1],
2349          outputs[VARYING_SLOT_POS][2],
2350          outputs[VARYING_SLOT_POS][3],
2351       },
2352    };
2353    create_vertex_param_phis(b, 1, &pos_output);
2354 
2355    b->cursor = nir_after_cf_list(&b->impl->body);
2356 
2357    /* Wait for attribute stores to finish. */
2358    nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
2359                   .memory_scope = SCOPE_DEVICE,
2360                   .memory_semantics = NIR_MEMORY_RELEASE,
2361                   .memory_modes = nir_var_mem_ssbo | nir_var_shader_out | nir_var_mem_global | nir_var_image);
2362 
2363    /* Export just the pos0 output. */
2364    nir_if *if_export_empty_pos = nir_push_if(b, if_es_thread->condition.ssa);
2365    {
2366       nir_def *pos_output_array[VARYING_SLOT_MAX][4] = {0};
2367       memcpy(pos_output_array[VARYING_SLOT_POS], pos_output.chan, sizeof(pos_output.chan));
2368 
2369       ac_nir_export_position(b, options->gfx_level,
2370                              options->clip_cull_dist_mask,
2371                              !options->has_param_exports,
2372                              options->force_vrs, true,
2373                              VARYING_BIT_POS, pos_output_array, NULL);
2374    }
2375    nir_pop_if(b, if_export_empty_pos);
2376 }
2377 
2378 static void
2379 nogs_export_vertex_params(nir_builder *b, nir_function_impl *impl,
2380                           nir_if *if_es_thread, nir_def *num_es_threads,
2381                           lower_ngg_nogs_state *s)
2382 {
2383    if (!s->options->has_param_exports)
2384       return;
2385 
2386    if (s->options->gfx_level >= GFX11) {
2387       /* Export varyings for GFX11+ */
2388       vs_output outputs[64];
2389       const unsigned num_outputs =
2390          gather_vs_outputs(b, outputs,
2391                            s->options->vs_output_param_offset,
2392                            s->outputs,
2393                            s->outputs_16bit_lo,
2394                            s->outputs_16bit_hi);
2395 
2396       if (!num_outputs)
2397          return;
2398 
2399       b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2400       create_vertex_param_phis(b, num_outputs, outputs);
2401 
2402       b->cursor = nir_after_impl(impl);
2403       if (!num_es_threads)
2404          num_es_threads = nir_load_merged_wave_info_amd(b);
2405 
2406       export_vertex_params_gfx11(b, NULL, num_es_threads, num_outputs, outputs,
2407                                  s->options->vs_output_param_offset);
2408    } else {
2409       ac_nir_export_parameters(b, s->options->vs_output_param_offset,
2410                                  b->shader->info.outputs_written,
2411                                  b->shader->info.outputs_written_16bit,
2412                                  s->outputs, s->outputs_16bit_lo,
2413                                  s->outputs_16bit_hi);
2414    }
2415 }
2416 
2417 void
2418 ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
2419 {
2420    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
2421    assert(impl);
2422    assert(options->max_workgroup_size && options->wave_size);
2423    assert(!(options->can_cull && options->passthrough));
2424 
2425    nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
2426    nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
2427    nir_variable *es_accepted_var =
2428       options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
2429    nir_variable *gs_accepted_var =
2430       options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
2431    nir_variable *gs_exported_var = nir_local_variable_create(impl, glsl_bool_type(), "gs_exported");
2432 
2433    bool streamout_enabled = shader->xfb_info && !options->disable_streamout;
2434    bool has_user_edgeflags =
2435       options->use_edgeflags && (shader->info.outputs_written & VARYING_BIT_EDGE);
2436    /* Streamout must be done before either the prim or the vertex export, because
2437     * when there are no param exports, rasterization can start right after the prim
2438     * and vertex exports, which would leave streamout buffer writes unfinished.
2439     *
2440     * Always use late prim export when user edge flags are enabled.
2441     * This is because edge flags are written by ES threads but they
2442     * are exported by GS threads as part of the primitive export.
2443     */
2444    bool early_prim_export =
2445       options->early_prim_export && !(streamout_enabled || has_user_edgeflags);
2446 
2447    lower_ngg_nogs_state state = {
2448       .options = options,
2449       .early_prim_export = early_prim_export,
2450       .streamout_enabled = streamout_enabled,
2451       .position_value_var = position_value_var,
2452       .prim_exp_arg_var = prim_exp_arg_var,
2453       .es_accepted_var = es_accepted_var,
2454       .gs_accepted_var = gs_accepted_var,
2455       .gs_exported_var = gs_exported_var,
2456       .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
2457       .has_user_edgeflags = has_user_edgeflags,
2458    };
2459 
2460    const bool need_prim_id_store_shared =
2461       options->export_primitive_id && shader->info.stage == MESA_SHADER_VERTEX;
2462 
2463    if (options->export_primitive_id) {
2464       nir_variable *prim_id_var = nir_variable_create(shader, nir_var_shader_out, glsl_uint_type(), "ngg_prim_id");
2465       prim_id_var->data.location = VARYING_SLOT_PRIMITIVE_ID;
2466       prim_id_var->data.driver_location = VARYING_SLOT_PRIMITIVE_ID;
2467       prim_id_var->data.interpolation = INTERP_MODE_NONE;
2468       shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
2469    }
2470 
2471    nir_builder builder = nir_builder_create(impl);
2472    nir_builder *b = &builder; /* This is to avoid the & */
2473 
2474    if (options->can_cull) {
2475       analyze_shader_before_culling(shader, &state);
2476       save_reusable_variables(b, &state);
2477    }
2478 
2479    nir_cf_list extracted;
2480    nir_cf_extract(&extracted, nir_before_impl(impl),
2481                   nir_after_impl(impl));
2482    b->cursor = nir_before_impl(impl);
2483 
2484    ngg_nogs_init_vertex_indices_vars(b, impl, &state);
2485 
2486    /* Emit the "primitives generated" query code here, so that
2487     * it executes before culling and isn't in the extracted CF.
2488     */
2489    nogs_prim_gen_query(b, &state);
2490 
2491    /* Whether a shader invocation should export a primitive,
2492     * initialize to all invocations that have an input primitive.
2493     */
2494    nir_store_var(b, gs_exported_var, has_input_primitive(b), 0x1u);
2495 
2496    if (!options->can_cull) {
2497       /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
2498       if (!(options->passthrough && options->family >= CHIP_NAVI23)) {
2499          /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
2500          nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
2501          {
2502             nir_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
2503             nir_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
2504             alloc_vertices_and_primitives(b, vtx_cnt, prim_cnt);
2505          }
2506          nir_pop_if(b, if_wave_0);
2507       }
2508 
2509       /* Take care of early primitive export, otherwise just pack the primitive export argument */
2510       if (state.early_prim_export)
2511          emit_ngg_nogs_prim_export(b, &state, NULL);
2512       else
2513          nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
2514    } else {
2515       add_deferred_attribute_culling(b, &extracted, &state);
2516       b->cursor = nir_after_impl(impl);
2517 
2518       if (state.early_prim_export)
2519          emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
2520 
2521       /* Wait for culling to finish using LDS. */
2522       if (need_prim_id_store_shared || has_user_edgeflags) {
2523          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
2524                                .memory_scope = SCOPE_WORKGROUP,
2525                                .memory_semantics = NIR_MEMORY_ACQ_REL,
2526                                .memory_modes = nir_var_mem_shared);
2527       }
2528    }
2529 
2530    /* determine the LDS vertex stride */
2531    state.pervertex_lds_bytes =
2532       ngg_nogs_get_pervertex_lds_size(shader->info.stage,
2533                                       shader->num_outputs,
2534                                       state.streamout_enabled,
2535                                       options->export_primitive_id,
2536                                       state.has_user_edgeflags);
2537 
2538    if (need_prim_id_store_shared) {
2539       emit_ngg_nogs_prim_id_store_shared(b, &state);
2540 
2541       /* Wait for GS threads to store primitive ID in LDS. */
2542       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
2543                             .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
2544    }
2545 
2546    nir_def *es_thread =
2547       options->can_cull ? nir_load_var(b, es_accepted_var) : has_input_vertex(b);
2548 
2549    /* Calculate the bit count here instead of below for lower SGPR usage and better ALU
2550     * scheduling.
2551     */
2552    nir_def *num_es_threads = NULL;
2553    if (state.options->gfx_level >= GFX11 && options->can_cull) {
2554       nir_def *es_accepted_mask =
2555          nir_ballot(b, 1, options->wave_size, nir_load_var(b, es_accepted_var));
2556       num_es_threads = nir_bit_count(b, es_accepted_mask);
2557    }
2558 
2559    nir_if *if_es_thread = nir_push_if(b, es_thread);
2560    {
2561       /* Run the actual shader */
2562       nir_cf_reinsert(&extracted, b->cursor);
2563       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2564 
2565       if (options->export_primitive_id)
2566          emit_store_ngg_nogs_es_primitive_id(b, &state);
2567    }
2568    nir_pop_if(b, if_es_thread);
2569 
2570    if (options->can_cull) {
2571       /* Replace uniforms. */
2572       apply_reusable_variables(b, &state);
2573 
2574       /* Remove the redundant position output. */
2575       remove_extra_pos_outputs(shader, &state);
2576 
2577       /* After looking at the performance in apps such as Doom Eternal and The Witcher 3,
2578        * it seems best to always put the position export at the end, and
2579        * then let ACO schedule it up (slightly) only when early prim export is used.
2580        */
2581       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2582 
2583       nir_def *pos_val = nir_load_var(b, state.position_value_var);
2584       for (int i = 0; i < 4; i++)
2585          state.outputs[VARYING_SLOT_POS][i] = nir_channel(b, pos_val, i);
2586    }
2587 
2588    /* Gather output data and types */
2589    b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2590    ngg_nogs_gather_outputs(b, &if_es_thread->then_list, &state);
2591 
2592    if (state.has_user_edgeflags)
2593       ngg_nogs_store_edgeflag_to_lds(b, &state);
2594 
2595    if (state.streamout_enabled) {
2596       /* TODO: support culling after streamout. */
2597       assert(!options->can_cull);
2598 
2599       ngg_nogs_store_xfb_outputs_to_lds(b, &state);
2600 
2601       b->cursor = nir_after_impl(impl);
2602       ngg_nogs_build_streamout(b, &state);
2603    }
2604 
2605    /* Take care of late primitive export */
2606    if (!state.early_prim_export) {
2607       b->cursor = nir_after_impl(impl);
2608       emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
2609    }
2610 
2611    uint64_t export_outputs = shader->info.outputs_written | VARYING_BIT_POS;
2612    if (options->kill_pointsize)
2613       export_outputs &= ~VARYING_BIT_PSIZ;
2614    if (options->kill_layer)
2615       export_outputs &= ~VARYING_BIT_LAYER;
2616 
2617    const bool wait_attr_ring = must_wait_attr_ring(options->gfx_level, options->has_param_exports);
2618    if (wait_attr_ring)
2619       export_outputs &= ~VARYING_BIT_POS;
2620 
2621    b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2622 
2623    ac_nir_export_position(b, options->gfx_level,
2624                           options->clip_cull_dist_mask,
2625                           !options->has_param_exports,
2626                           options->force_vrs, !wait_attr_ring,
2627                           export_outputs, state.outputs, NULL);
2628 
2629    nogs_export_vertex_params(b, impl, if_es_thread, num_es_threads, &state);
2630 
2631    if (wait_attr_ring)
2632       export_pos0_wait_attr_ring(b, if_es_thread, state.outputs, options);
2633 
2634    nir_metadata_preserve(impl, nir_metadata_none);
2635    nir_validate_shader(shader, "after emitting NGG VS/TES");
2636 
2637    /* Cleanup */
2638    nir_opt_dead_write_vars(shader);
2639    nir_lower_vars_to_ssa(shader);
2640    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
2641    nir_lower_alu_to_scalar(shader, NULL, NULL);
2642    nir_lower_phis_to_scalar(shader, true);
2643 
2644    if (options->can_cull) {
2645       /* It's beneficial to redo these opts after splitting the shader. */
2646       nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
2647       nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
2648    }
2649 
2650    bool progress;
2651    do {
2652       progress = false;
2653       NIR_PASS(progress, shader, nir_opt_undef);
2654       NIR_PASS(progress, shader, nir_opt_dce);
2655       NIR_PASS(progress, shader, nir_opt_dead_cf);
2656 
2657       if (options->can_cull)
2658          progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
2659    } while (progress);
2660 }
2661 
2662 /**
2663  * Return the address of the LDS storage reserved for the N'th vertex,
2664  * where N is in emit order, meaning:
2665  * - during the finale, N is the invocation_index (within the workgroup)
2666  * - during vertex emit, i.e. while the API GS shader invocation is running,
2667  *   N = invocation_index * gs_max_out_vertices + emit_idx
2668  *   where emit_idx is the vertex index in the current API GS invocation.
2669  *
2670  * Goals of the LDS memory layout:
2671  * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
2672  *    in uniform control flow
2673  * 2. Eliminate bank conflicts on read for export if, additionally, there is no
2674  *    culling
2675  * 3. Agnostic to the number of waves (since we don't know it before compiling)
2676  * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
2677  * 5. Avoid wasting memory.
2678  *
2679  * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
2680  * layout, elimination of bank conflicts requires that each vertex occupy an
2681  * odd number of dwords. We use the additional dword to store the output stream
2682  * index as well as a flag to indicate whether this vertex ends a primitive
2683  * for rasterization.
2684  *
2685  * Swizzling is required to satisfy points 1 and 2 simultaneously.
2686  *
2687  * Vertices are stored in export order (gsthread * gs_max_out_vertices + emitidx).
2688  * Indices are swizzled in groups of 32, which ensures point 1 without
2689  * disturbing point 2.
2690  *
2691  * \return an LDS pointer to type {[N x i32], [4 x i8]}
2692  */
2693 static nir_def *
2694 ngg_gs_out_vertex_addr(nir_builder *b, nir_def *out_vtx_idx, lower_ngg_gs_state *s)
2695 {
2696    unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
2697 
2698    /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
2699    if (write_stride_2exp) {
2700       nir_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
2701       nir_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
2702       out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
2703    }
2704 
2705    nir_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
2706    return nir_iadd_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
2707 }
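/* Illustrative sketch (not part of the lowering): the same swizzle evaluated
 * with plain integers, so the effect of the ushr/iand/ixor sequence above is
 * easy to see. The concrete numbers below are hypothetical.
 */
#if 0
static unsigned
example_swizzled_out_vtx_idx(unsigned out_vtx_idx, unsigned gs_vertices_out)
{
   /* gs_vertices_out = 2^(write_stride_2exp) * some odd number */
   unsigned write_stride_2exp = ffs(MAX2(gs_vertices_out, 1)) - 1;
   if (!write_stride_2exp)
      return out_vtx_idx; /* e.g. gs_vertices_out = 3 (odd): no swizzle */

   unsigned row = out_vtx_idx >> 5; /* vertices are swizzled in groups of 32 */
   unsigned swizzle = row & ((1u << write_stride_2exp) - 1u);

   /* e.g. gs_vertices_out = 4: write_stride_2exp = 2,
    * out_vtx_idx = 70 -> row = 2, swizzle = 2, result = 68
    */
   return out_vtx_idx ^ swizzle;
}
#endif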
2708 
2709 static nir_def *
2710 ngg_gs_emit_vertex_addr(nir_builder *b, nir_def *gs_vtx_idx, lower_ngg_gs_state *s)
2711 {
2712    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
2713    nir_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
2714    nir_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
2715 
2716    return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
2717 }
2718 
2719 static void
2720 ngg_gs_clear_primflags(nir_builder *b, nir_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
2721 {
2722    char name[32];
2723    snprintf(name, sizeof(name), "clear_primflag_idx_%u", stream);
2724    nir_variable *clear_primflag_idx_var = nir_local_variable_create(b->impl, glsl_uint_type(), name);
2725 
2726    nir_def *zero_u8 = nir_imm_zero(b, 1, 8);
2727    nir_store_var(b, clear_primflag_idx_var, num_vertices, 0x1u);
2728 
2729    nir_loop *loop = nir_push_loop(b);
2730    {
2731       nir_def *clear_primflag_idx = nir_load_var(b, clear_primflag_idx_var);
2732       nir_if *if_break = nir_push_if(b, nir_uge_imm(b, clear_primflag_idx, b->shader->info.gs.vertices_out));
2733       {
2734          nir_jump(b, nir_jump_break);
2735       }
2736       nir_push_else(b, if_break);
2737       {
2738          nir_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, clear_primflag_idx, s);
2739          nir_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream);
2740          nir_store_var(b, clear_primflag_idx_var, nir_iadd_imm_nuw(b, clear_primflag_idx, 1), 0x1u);
2741       }
2742       nir_pop_if(b, if_break);
2743    }
2744    nir_pop_loop(b, loop);
2745 }
2746 
2747 static bool
2748 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2749 {
2750    assert(nir_src_is_const(intrin->src[1]) && !nir_src_as_uint(intrin->src[1]));
2751    b->cursor = nir_before_instr(&intrin->instr);
2752 
2753    unsigned writemask = nir_intrinsic_write_mask(intrin);
2754    unsigned component_offset = nir_intrinsic_component(intrin);
2755    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
2756 
2757    unsigned location = io_sem.location;
2758 
2759    nir_def *store_val = intrin->src[0].ssa;
2760    nir_alu_type src_type = nir_intrinsic_src_type(intrin);
2761 
2762    /* Small bitsize components consume the same amount of space as 32-bit components,
2763     * but 64-bit ones consume twice as much. (Vulkan spec 15.1.5)
2764     *
2765     * 64-bit IO has already been lowered to multiple 32-bit IO.
2766     */
2767    assert(store_val->bit_size <= 32);
2768    assert(nir_alu_type_get_type_size(src_type) == store_val->bit_size);
2769 
2770    /* Get corresponding output variable and usage info. */
2771    nir_def **output;
2772    nir_alu_type *type;
2773    gs_output_info *info;
2774    if (location >= VARYING_SLOT_VAR0_16BIT) {
2775       unsigned index = location - VARYING_SLOT_VAR0_16BIT;
2776       assert(index < 16);
2777 
2778       if (io_sem.high_16bits) {
2779          output = s->outputs_16bit_hi[index];
2780          type = s->output_types.types_16bit_hi[index];
2781          info = s->output_info_16bit_hi + index;
2782       } else {
2783          output = s->outputs_16bit_lo[index];
2784          type = s->output_types.types_16bit_lo[index];
2785          info = s->output_info_16bit_lo + index;
2786       }
2787    } else {
2788       assert(location < VARYING_SLOT_MAX);
2789       output = s->outputs[location];
2790       type = s->output_types.types[location];
2791       info = s->output_info + location;
2792    }
2793 
2794    for (unsigned comp = 0; comp < store_val->num_components; ++comp) {
2795       if (!(writemask & (1 << comp)))
2796          continue;
2797       unsigned stream = (io_sem.gs_streams >> (comp * 2)) & 0x3;
2798       if (!(b->shader->info.gs.active_stream_mask & (1 << stream)))
2799          continue;
2800 
2801       unsigned component = component_offset + comp;
2802 
2803       /* The same output component should always belong to the same stream. */
2804       assert(!(info->components_mask & (1 << component)) ||
2805              ((info->stream >> (component * 2)) & 3) == stream);
2806 
2807       /* Components of the same output slot may belong to different streams. */
2808       info->stream |= stream << (component * 2);
2809       info->components_mask |= BITFIELD_BIT(component);
2810 
2811       /* If the type is set multiple times, the value must be the same. */
2812       assert(type[component] == nir_type_invalid || type[component] == src_type);
2813       type[component] = src_type;
2814 
2815       /* Assume we have called nir_lower_io_to_temporaries, which stores outputs in the
2816        * same block as EmitVertex, so we don't need to use nir_variable for outputs.
2817        */
2818       output[component] = nir_channel(b, store_val, comp);
2819    }
2820 
2821    nir_instr_remove(&intrin->instr);
2822    return true;
2823 }
2824 
2825 static unsigned
2826 gs_output_component_mask_with_stream(gs_output_info *info, unsigned stream)
2827 {
2828    unsigned mask = info->components_mask;
2829    if (!mask)
2830       return 0;
2831 
2832    /* Clear components that don't belong to the requested stream. */
2833    for (int i = 0; i < 4; i++) {
2834       if (((info->stream >> (i * 2)) & 3) != stream)
2835          mask &= ~(1 << i);
2836    }
2837 
2838    return mask;
2839 }
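/* Example of the gs_output_info packing (hypothetical values): if only
 * component 2 of a slot was written and it went to stream 1, then
 * components_mask = 0x4 and stream = 0x10 (2 bits per component), so the
 * loop above keeps bit 2 only when stream 1 is requested.
 */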
2840 
2841 static bool
2842 lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2843 {
2844    b->cursor = nir_before_instr(&intrin->instr);
2845 
2846    unsigned stream = nir_intrinsic_stream_id(intrin);
2847    if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
2848       nir_instr_remove(&intrin->instr);
2849       return true;
2850    }
2851 
2852    nir_def *gs_emit_vtx_idx = intrin->src[0].ssa;
2853    nir_def *current_vtx_per_prim = intrin->src[1].ssa;
2854    nir_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);
2855 
2856    u_foreach_bit64(slot, b->shader->info.outputs_written) {
2857       unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
2858       gs_output_info *info = &s->output_info[slot];
2859       nir_def **output = s->outputs[slot];
2860 
2861       unsigned mask = gs_output_component_mask_with_stream(info, stream);
2862       while (mask) {
2863          int start, count;
2864          u_bit_scan_consecutive_range(&mask, &start, &count);
2865          nir_def *values[4] = {0};
2866          for (int c = start; c < start + count; ++c) {
2867             if (!output[c]) {
2868                /* Nothing wrote to this output before. */
2869                values[c - start] = nir_undef(b, 1, 32);
2870                continue;
2871             }
2872 
2873             /* Extend 8/16-bit to 32-bit; 64-bit has already been lowered. */
2874             values[c - start] = nir_u2uN(b, output[c], 32);
2875          }
2876 
2877          nir_def *store_val = nir_vec(b, values, (unsigned)count);
2878          nir_store_shared(b, store_val, gs_emit_vtx_addr,
2879                           .base = packed_location * 16 + start * 4,
2880                           .align_mul = 4);
2881       }
2882 
2883       /* Clear all outputs (they are undefined after emit_vertex) */
2884       memset(s->outputs[slot], 0, sizeof(s->outputs[slot]));
2885    }
2886 
2887    /* Store 16bit outputs to LDS. */
2888    unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
2889    u_foreach_bit(slot, b->shader->info.outputs_written_16bit) {
2890       unsigned packed_location = num_32bit_outputs +
2891          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));
2892 
2893       unsigned mask_lo = gs_output_component_mask_with_stream(s->output_info_16bit_lo + slot, stream);
2894       unsigned mask_hi = gs_output_component_mask_with_stream(s->output_info_16bit_hi + slot, stream);
2895       unsigned mask = mask_lo | mask_hi;
2896 
2897       nir_def **output_lo = s->outputs_16bit_lo[slot];
2898       nir_def **output_hi = s->outputs_16bit_hi[slot];
2899       nir_def *undef = nir_undef(b, 1, 16);
2900 
2901       while (mask) {
2902          int start, count;
2903          u_bit_scan_consecutive_range(&mask, &start, &count);
2904          nir_def *values[4] = {0};
2905          for (int c = start; c < start + count; ++c) {
2906             nir_def *lo = output_lo[c] ? output_lo[c] : undef;
2907             nir_def *hi = output_hi[c] ? output_hi[c] : undef;
2908 
2909             values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
2910          }
2911 
2912          nir_def *store_val = nir_vec(b, values, (unsigned)count);
2913          nir_store_shared(b, store_val, gs_emit_vtx_addr,
2914                           .base = packed_location * 16 + start * 4,
2915                           .align_mul = 4);
2916       }
2917 
2918       /* Clear all outputs (they are undefined after emit_vertex) */
2919       memset(s->outputs_16bit_lo[slot], 0, sizeof(s->outputs_16bit_lo[slot]));
2920       memset(s->outputs_16bit_hi[slot], 0, sizeof(s->outputs_16bit_hi[slot]));
2921    }
2922 
2923    /* Calculate and store per-vertex primitive flags based on vertex counts:
2924     * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
2925     * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
2926     *          only set when the vertex also finishes the primitive
2927     * - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
2928     */
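   /* Example (hypothetical, triangle strip, culling disabled): the first two
    * vertices emitted by an invocation get primflag 0b100 (live, no primitive
    * finished yet), the 3rd vertex (current_vtx_per_prim = 2, even triangle)
    * gets 0b101, and the 4th vertex (current_vtx_per_prim = 3, odd triangle)
    * gets 0b111.
    */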
2929 
2930    nir_def *vertex_live_flag =
2931       !stream && s->options->can_cull
2932          ? nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2)
2933          : nir_imm_int(b, 0b100);
2934 
2935    nir_def *completes_prim = nir_ige_imm(b, current_vtx_per_prim, s->num_vertices_per_primitive - 1);
2936    nir_def *complete_flag = nir_b2i32(b, completes_prim);
2937 
2938    nir_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
2939    if (s->num_vertices_per_primitive == 3) {
2940       nir_def *odd = nir_iand(b, current_vtx_per_prim, complete_flag);
2941       nir_def *odd_flag = nir_ishl_imm(b, odd, 1);
2942       prim_flag = nir_ior(b, prim_flag, odd_flag);
2943    }
2944 
2945    nir_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr,
2946                     .base = s->lds_offs_primflags + stream,
2947                     .align_mul = 4, .align_offset = stream);
2948 
2949    nir_instr_remove(&intrin->instr);
2950    return true;
2951 }
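/* LDS layout note (hypothetical example): packed_location counts the written
 * slots below the current one, so with outputs_written = POS | VAR0, the
 * VARYING_SLOT_POS dwords start at byte 0 of the vertex record and the
 * VARYING_SLOT_VAR0 dwords at byte 16 (16 bytes per packed slot), followed by
 * the packed 16-bit slots and finally the per-stream primflag bytes at
 * lds_offs_primflags.
 */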
2952 
2953 static bool
2954 lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
2955 {
2956    b->cursor = nir_before_instr(&intrin->instr);
2957 
2958    /* These are not needed; we can simply remove them. */
2959    nir_instr_remove(&intrin->instr);
2960    return true;
2961 }
2962 
2963 static bool
2964 lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2965 {
2966    b->cursor = nir_before_instr(&intrin->instr);
2967 
2968    unsigned stream = nir_intrinsic_stream_id(intrin);
2969    if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
2970       nir_instr_remove(&intrin->instr);
2971       return true;
2972    }
2973 
2974    s->vertex_count[stream] = intrin->src[0].ssa;
2975    s->primitive_count[stream] = intrin->src[1].ssa;
2976 
2977    /* Clear the primitive flags of non-emitted vertices */
2978    if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
2979       ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
2980 
2981    nir_instr_remove(&intrin->instr);
2982    return true;
2983 }
2984 
2985 static bool
2986 lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
2987 {
2988    lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;
2989 
2990    if (instr->type != nir_instr_type_intrinsic)
2991       return false;
2992 
2993    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2994 
2995    if (intrin->intrinsic == nir_intrinsic_store_output)
2996       return lower_ngg_gs_store_output(b, intrin, s);
2997    else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
2998       return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
2999    else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
3000       return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
3001    else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
3002       return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);
3003 
3004    return false;
3005 }
3006 
3007 static void
3008 lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
3009 {
3010    nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
3011 }
3012 
3013 static void
3014 ngg_gs_export_primitives(nir_builder *b, nir_def *max_num_out_prims, nir_def *tid_in_tg,
3015                          nir_def *exporter_tid_in_tg, nir_def *primflag_0,
3016                          lower_ngg_gs_state *s)
3017 {
3018    nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
3019 
3020    /* Only bit 0 matters here - set it to 1 when the primitive should be null */
3021    nir_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));
3022 
3023    nir_def *vtx_indices[3] = {0};
3024    vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
3025    if (s->num_vertices_per_primitive >= 2)
3026       vtx_indices[s->num_vertices_per_primitive - 2] = nir_iadd_imm(b, exporter_tid_in_tg, -1);
3027    if (s->num_vertices_per_primitive == 3)
3028       vtx_indices[s->num_vertices_per_primitive - 3] = nir_iadd_imm(b, exporter_tid_in_tg, -2);
3029 
3030    if (s->num_vertices_per_primitive == 3) {
3031       /* API GS outputs triangle strips, but NGG HW understands triangles.
3032        * We already know the triangles due to how we set the primitive flags, but we need to
3033        * make sure the vertex order is such that the front/back face is correct and the provoking vertex is kept.
3034        */
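      /* Example (hypothetical): with exporter_tid_in_tg = e, the initial
       * indices are (e-2, e-1, e). For an odd triangle (is_odd = 1) the
       * bcsels below swap the last two indices when the provoking vertex is
       * first, giving (e-2, e, e-1), and the first two when the provoking
       * vertex is last, giving (e-1, e-2, e). Even triangles are unchanged.
       */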
3035 
3036       nir_def *is_odd = nir_ubfe_imm(b, primflag_0, 1, 1);
3037       nir_def *provoking_vertex_index = nir_load_provoking_vtx_in_prim_amd(b);
3038       nir_def *provoking_vertex_first = nir_ieq_imm(b, provoking_vertex_index, 0);
3039 
3040       vtx_indices[0] = nir_bcsel(b, provoking_vertex_first, vtx_indices[0],
3041                                  nir_iadd(b, vtx_indices[0], is_odd));
3042       vtx_indices[1] = nir_bcsel(b, provoking_vertex_first,
3043                                  nir_iadd(b, vtx_indices[1], is_odd),
3044                                  nir_isub(b, vtx_indices[1], is_odd));
3045       vtx_indices[2] = nir_bcsel(b, provoking_vertex_first,
3046                                  nir_isub(b, vtx_indices[2], is_odd), vtx_indices[2]);
3047    }
3048 
3049    nir_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices,
3050                                                  is_null_prim);
3051    ac_nir_export_primitive(b, arg, NULL);
3052    nir_pop_if(b, if_prim_export_thread);
3053 }
3054 
3055 static void
3056 ngg_gs_export_vertices(nir_builder *b, nir_def *max_num_out_vtx, nir_def *tid_in_tg,
3057                        nir_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
3058 {
3059    nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
3060    nir_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;
3061 
3062    if (!s->output_compile_time_known) {
3063       /* Vertex compaction.
3064        * The current thread will export a vertex that was live in another invocation.
3065        * Load the index of the vertex that the current thread will have to export.
3066        */
3067       nir_def *exported_vtx_idx = nir_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1);
3068       exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
3069    }
3070 
3071    u_foreach_bit64(slot, b->shader->info.outputs_written) {
3072       unsigned packed_location =
3073          util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
3074 
3075       gs_output_info *info = &s->output_info[slot];
3076       unsigned mask = gs_output_component_mask_with_stream(info, 0);
3077 
3078       while (mask) {
3079          int start, count;
3080          u_bit_scan_consecutive_range(&mask, &start, &count);
3081          nir_def *load =
3082             nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
3083                             .base = packed_location * 16 + start * 4,
3084                             .align_mul = 4);
3085 
3086          for (int i = 0; i < count; i++)
3087             s->outputs[slot][start + i] = nir_channel(b, load, i);
3088       }
3089    }
3090 
3091    /* 16bit outputs */
3092    unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
3093    u_foreach_bit(i, b->shader->info.outputs_written_16bit) {
3094       unsigned packed_location = num_32bit_outputs +
3095          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(i));
3096 
3097       gs_output_info *info_lo = s->output_info_16bit_lo + i;
3098       gs_output_info *info_hi = s->output_info_16bit_hi + i;
3099       unsigned mask_lo = gs_output_component_mask_with_stream(info_lo, 0);
3100       unsigned mask_hi = gs_output_component_mask_with_stream(info_hi, 0);
3101       unsigned mask = mask_lo | mask_hi;
3102 
3103       while (mask) {
3104          int start, count;
3105          u_bit_scan_consecutive_range(&mask, &start, &count);
3106          nir_def *load =
3107             nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
3108                             .base = packed_location * 16 + start * 4,
3109                             .align_mul = 4);
3110 
3111          for (int j = 0; j < count; j++) {
3112             nir_def *val = nir_channel(b, load, j);
3113             unsigned comp = start + j;
3114 
3115             if (mask_lo & BITFIELD_BIT(comp))
3116                s->outputs_16bit_lo[i][comp] = nir_unpack_32_2x16_split_x(b, val);
3117 
3118             if (mask_hi & BITFIELD_BIT(comp))
3119                s->outputs_16bit_hi[i][comp] = nir_unpack_32_2x16_split_y(b, val);
3120          }
3121       }
3122    }
3123 
3124    uint64_t export_outputs = b->shader->info.outputs_written | VARYING_BIT_POS;
3125    if (s->options->kill_pointsize)
3126       export_outputs &= ~VARYING_BIT_PSIZ;
3127    if (s->options->kill_layer)
3128       export_outputs &= ~VARYING_BIT_LAYER;
3129 
3130    const bool wait_attr_ring = must_wait_attr_ring(s->options->gfx_level, s->options->has_param_exports);
3131    if (wait_attr_ring)
3132       export_outputs &= ~VARYING_BIT_POS;
3133 
3134    ac_nir_export_position(b, s->options->gfx_level,
3135                           s->options->clip_cull_dist_mask,
3136                           !s->options->has_param_exports,
3137                           s->options->force_vrs, !wait_attr_ring,
3138                           export_outputs, s->outputs, NULL);
3139 
3140    nir_pop_if(b, if_vtx_export_thread);
3141 
3142    if (s->options->has_param_exports) {
3143       b->cursor = nir_after_cf_list(&if_vtx_export_thread->then_list);
3144 
3145       if (s->options->gfx_level >= GFX11) {
3146          vs_output outputs[64];
3147          unsigned num_outputs = gather_vs_outputs(b, outputs,
3148                                                   s->options->vs_output_param_offset,
3149                                                   s->outputs, s->outputs_16bit_lo,
3150                                                   s->outputs_16bit_hi);
3151 
3152          if (num_outputs) {
3153             b->cursor = nir_after_impl(s->impl);
3154             create_vertex_param_phis(b, num_outputs, outputs);
3155 
3156             export_vertex_params_gfx11(b, tid_in_tg, max_num_out_vtx, num_outputs, outputs,
3157                                        s->options->vs_output_param_offset);
3158          }
3159       } else {
3160          ac_nir_export_parameters(b, s->options->vs_output_param_offset,
3161                                   b->shader->info.outputs_written,
3162                                   b->shader->info.outputs_written_16bit,
3163                                   s->outputs, s->outputs_16bit_lo,
3164                                   s->outputs_16bit_hi);
3165       }
3166    }
3167 
3168    if (wait_attr_ring)
3169       export_pos0_wait_attr_ring(b, if_vtx_export_thread, s->outputs, s->options);
3170 }
3171 
3172 static void
3173 ngg_gs_setup_vertex_compaction(nir_builder *b, nir_def *vertex_live, nir_def *tid_in_tg,
3174                                nir_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
3175 {
3176    assert(vertex_live->bit_size == 1);
3177    nir_if *if_vertex_live = nir_push_if(b, vertex_live);
3178    {
3179       /* Set up the vertex compaction.
3180        * Save the current thread's ID for the thread which will export the current vertex.
3181        * We reuse the stream 1 primitive flag byte of the other thread's vertex to store this.
3182        */
3183 
3184       nir_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
3185       nir_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
3186       nir_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1);
3187    }
3188    nir_pop_if(b, if_vertex_live);
3189 }
3190 
3191 static nir_def *
3192 ngg_gs_load_out_vtx_primflag(nir_builder *b, unsigned stream, nir_def *tid_in_tg,
3193                              nir_def *vtx_lds_addr, nir_def *max_num_out_vtx,
3194                              lower_ngg_gs_state *s)
3195 {
3196    nir_def *zero = nir_imm_int(b, 0);
3197 
3198    nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
3199    nir_def *primflag = nir_load_shared(b, 1, 8, vtx_lds_addr,
3200                                            .base = s->lds_offs_primflags + stream);
3201    primflag = nir_u2u32(b, primflag);
3202    nir_pop_if(b, if_outvtx_thread);
3203 
3204    return nir_if_phi(b, primflag, zero);
3205 }
3206 
3207 static void
3208 ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_def *last_vtxidx, nir_def *last_vtxptr,
3209                            nir_def *last_vtx_primflag, lower_ngg_gs_state *s,
3210                            nir_def *vtxptr[3])
3211 {
3212    unsigned last_vtx = s->num_vertices_per_primitive - 1;
3213    vtxptr[last_vtx] = last_vtxptr;
3214 
3215    bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
3216    nir_def *is_odd = primitive_is_triangle ?
3217       nir_ubfe_imm(b, last_vtx_primflag, 1, 1) : NULL;
3218 
3219    for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
3220       nir_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));
3221 
3222       /* Need to swap vertex 0 and vertex 1 when vertex 2 index is odd to keep
3223        * CW/CCW order for correct front/back face culling.
3224        */
3225       if (primitive_is_triangle)
3226          vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);
3227 
3228       vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
3229    }
3230 }
3231 
3232 static nir_def *
3233 ngg_gs_cull_primitive(nir_builder *b, nir_def *tid_in_tg, nir_def *max_vtxcnt,
3234                       nir_def *out_vtx_lds_addr, nir_def *out_vtx_primflag_0,
3235                       lower_ngg_gs_state *s)
3236 {
3237    /* We haven't enabled point culling; if it were enabled, this function could be further optimized. */
3238    assert(s->num_vertices_per_primitive > 1);
3239 
3240    /* save the primflag so that we don't need to load it from LDS again */
3241    nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
3242    nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);
3243 
3244    /* Bit 0 (the lowest bit) of the primflag indicates whether this is the final vertex of a primitive. */
3245    nir_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
3246    nir_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
3247    nir_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);
3248 
3249    nir_if *if_prim_enable = nir_push_if(b, prim_enable);
3250    {
3251       /* Calculate the LDS address of every vertex in the current primitive. */
3252       nir_def *vtxptr[3];
3253       ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);
3254 
3255       /* Load the positions from LDS. */
3256       nir_def *pos[3][4];
3257       for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
3258          /* VARYING_SLOT_POS == 0, so its packed location is 0 and doesn't add to the base offset. */
3259          pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
3260          nir_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
3261          pos[i][0] = nir_channel(b, xy, 0);
3262          pos[i][1] = nir_channel(b, xy, 1);
3263 
3264          pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
3265          pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
3266       }
3267 
3268       /* TODO: support clipdist culling in GS */
3269       nir_def *accepted_by_clipdist = nir_imm_true(b);
3270 
3271       nir_def *accepted = ac_nir_cull_primitive(
3272          b, accepted_by_clipdist, pos, s->num_vertices_per_primitive, NULL, NULL);
3273 
3274       nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
3275       {
3276          /* clear the primflag if rejected */
3277          nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
3278                           .base = s->lds_offs_primflags);
3279 
3280          nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
3281       }
3282       nir_pop_if(b, if_rejected);
3283    }
3284    nir_pop_if(b, if_prim_enable);
3285 
3286    /* Wait for the LDS primflag accesses to finish. */
3287    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
3288                          .memory_scope = SCOPE_WORKGROUP,
3289                          .memory_semantics = NIR_MEMORY_ACQ_REL,
3290                          .memory_modes = nir_var_mem_shared);
3291 
3292    /* Only dead vertices need a chance to become live again. */
3293    nir_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
3294    nir_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
3295    nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
3296    {
3297       /* Check the succeeding vertices' primflags to detect whether this vertex needs to stay live. */
3298       for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
3299          nir_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
3300          nir_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
3301          nir_if *if_not_overflow = nir_push_if(b, not_overflow);
3302          {
3303             nir_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
3304             nir_def *vtx_primflag =
3305                nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
3306             vtx_primflag = nir_u2u32(b, vtx_primflag);
3307 
3308             /* If a succeeding vertex is the live final vertex of a primitive, the current
3309              * thread's vertex needs its liveness flag (bit 2) set.
3310              */
3311             nir_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
3312             nir_def *vtx_live_flag =
3313                nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));
3314 
3315             /* update this vertex's primflag */
3316             nir_def *primflag = nir_load_var(b, primflag_var);
3317             primflag = nir_ior(b, primflag, vtx_live_flag);
3318             nir_store_var(b, primflag_var, primflag, 1);
3319          }
3320          nir_pop_if(b, if_not_overflow);
3321       }
3322    }
3323    nir_pop_if(b, if_update_primflag);
3324 
3325    return nir_load_var(b, primflag_var);
3326 }
3327 
3328 static void
3329 ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s)
3330 {
3331    nir_xfb_info *info = b->shader->xfb_info;
3332 
3333    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
3334    nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
3335    nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
3336    nir_def *prim_live[4] = {0};
3337    nir_def *gen_prim[4] = {0};
3338    nir_def *export_seq[4] = {0};
3339    nir_def *out_vtx_primflag[4] = {0};
3340    for (unsigned stream = 0; stream < 4; stream++) {
3341       if (!(info->streams_written & BITFIELD_BIT(stream)))
3342          continue;
3343 
3344       out_vtx_primflag[stream] =
3345          ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
3346 
3347       /* Check bit 0 of the primflag for a live primitive; it's set for the last
3348        * vertex of every primitive.
3349        */
3350       prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));
3351 
3352       unsigned scratch_stride = ALIGN(s->max_num_waves, 4);
3353       nir_def *scratch_base =
3354          nir_iadd_imm(b, s->lds_addr_gs_scratch, stream * scratch_stride);
3355 
3356       /* We want to export primitives to the streamout buffer in sequence,
3357        * but not every vertex is alive or marks the end of a primitive, so
3358        * there are "holes". Unlike the final vertex export, we don't need
3359        * contiguous invocations to write primitives to the streamout buffer,
3360        * so repacking to get the sequence (export_seq) is enough; there is
3361        * no need to do compaction.
3362        *
3363        * Use separate scratch space for each stream to avoid a barrier.
3364        * TODO: we may further reduce barriers by writing to all streams'
3365        * LDS at once, then we'd only need one barrier in total instead of
3366        * one per stream.
3367        */
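      /* Example (hypothetical): if the first 8 invocations of a stream have
       * prim_live = 1,0,1,1,0,0,1,0, the repack yields gen_prim = 4 and
       * export_seq = 0,?,1,2,?,?,3,? (the value is only meaningful for live
       * invocations).
       */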
3368       wg_repack_result rep =
3369          repack_invocations_in_workgroup(b, prim_live[stream], scratch_base,
3370                                          s->max_num_waves, s->options->wave_size);
3371 
3372       /* nir_intrinsic_set_vertex_and_primitive_count can also provide the primitive count of
3373        * the current wave, but we'd still need LDS to sum all waves' counts to get the workgroup
3374        * count. We need the repack to export primitives to the streamout buffer anyway, so do it here.
3375        */
3376       gen_prim[stream] = rep.num_repacked_invocations;
3377       export_seq[stream] = rep.repacked_invocation_index;
3378    }
3379 
3380    /* Workgroup barrier: wait for the LDS scratch reads to finish. */
3381    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
3382                       .memory_scope = SCOPE_WORKGROUP,
3383                       .memory_semantics = NIR_MEMORY_ACQ_REL,
3384                       .memory_modes = nir_var_mem_shared);
3385 
3386    /* Get global buffer offset where this workgroup will stream out data to. */
3387    nir_def *emit_prim[4] = {0};
3388    nir_def *buffer_offsets[4] = {0};
3389    nir_def *so_buffer[4] = {0};
3390    nir_def *prim_stride[4] = {0};
3391    ngg_build_streamout_buffer_info(b, info, s->options->has_xfb_prim_query,
3392                                    s->lds_addr_gs_scratch, tid_in_tg, gen_prim,
3393                                    prim_stride, so_buffer, buffer_offsets, emit_prim);
3394 
3395    for (unsigned stream = 0; stream < 4; stream++) {
3396       if (!(info->streams_written & BITFIELD_BIT(stream)))
3397          continue;
3398 
3399       nir_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
3400       nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
3401       {
3402          /* Get streamout buffer vertex index for the first vertex of this primitive. */
3403          nir_def *vtx_buffer_idx =
3404             nir_imul_imm(b, export_seq[stream], s->num_vertices_per_primitive);
3405 
3406          /* Get the LDS address of every vertex of this primitive. */
3407          nir_def *exported_vtx_lds_addr[3];
3408          ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
3409                                     out_vtx_primflag[stream], s,
3410                                     exported_vtx_lds_addr);
3411 
3412          /* Write all vertices of this primitive to streamout buffer. */
3413          for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
3414             ngg_build_streamout_vertex(b, info, stream, so_buffer,
3415                                        buffer_offsets,
3416                                        nir_iadd_imm(b, vtx_buffer_idx, i),
3417                                        exported_vtx_lds_addr[i],
3418                                        &s->output_types);
3419          }
3420       }
3421       nir_pop_if(b, if_emit);
3422    }
3423 }
3424 
3425 static void
3426 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
3427 {
3428    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
3429    nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
3430    nir_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
3431    nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
3432 
3433    if (s->output_compile_time_known) {
3434       /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
3435        * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
3436        */
3437       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
3438       alloc_vertices_and_primitives(b, max_vtxcnt, max_prmcnt);
3439       nir_pop_if(b, if_wave_0);
3440    }
3441 
3442    /* Workgroup barrier already emitted, we can assume all GS output stores are done by now. */
3443 
3444    nir_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
3445 
3446    if (s->output_compile_time_known) {
3447       ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
3448       ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
3449       return;
3450    }
3451 
3452    /* cull primitives */
3453    if (s->options->can_cull) {
3454       nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
3455 
3456       /* culling code will update the primflag */
3457       nir_def *updated_primflag =
3458          ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
3459                                out_vtx_primflag_0, s);
3460 
3461       nir_pop_if(b, if_cull_en);
3462 
3463       out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
3464    }
3465 
3466    /* When the output vertex count is not known at compile time:
3467     * There may be gaps between invocations that have live vertices, but NGG hardware
3468     * requires that the invocations that export vertices are packed (ie. compact).
3469     * To ensure this, we need to repack invocations that have a live vertex.
3470     */
3471    nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0);
3472    wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch,
3473                                                           s->max_num_waves, s->options->wave_size);
3474 
3475    nir_def *workgroup_num_vertices = rep.num_repacked_invocations;
3476    nir_def *exporter_tid_in_tg = rep.repacked_invocation_index;
3477 
3478    /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
3479    nir_def *any_output = nir_ine_imm(b, workgroup_num_vertices, 0);
3480    max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));
3481 
3482    /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
3483    nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
3484    {
3485       if (s->options->gfx_level == GFX10)
3486          alloc_vertices_and_primitives_gfx10_workaround(b, workgroup_num_vertices, max_prmcnt);
3487       else
3488          alloc_vertices_and_primitives(b, workgroup_num_vertices, max_prmcnt);
3489    }
3490    nir_pop_if(b, if_wave_0);
3491 
3492    /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
3493    ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);
3494 
3495    /* Workgroup barrier: wait for all LDS stores to finish. */
3496    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3497                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3498 
3499    ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
3500    ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
3501 }
3502 
3503 void
3504 ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
3505 {
3506    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
3507    assert(impl);
3508 
3509    lower_ngg_gs_state state = {
3510       .options = options,
3511       .impl = impl,
3512       .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
3513       .lds_offs_primflags = options->gs_out_vtx_bytes,
3514       .lds_bytes_per_gs_out_vertex = options->gs_out_vtx_bytes + 4u,
3515       .streamout_enabled = shader->xfb_info && !options->disable_streamout,
3516    };
3517 
3518    if (!options->can_cull) {
3519       nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
3520                                            state.const_out_prmcnt, NULL, 4u);
3521       state.output_compile_time_known =
3522          state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
3523          state.const_out_prmcnt[0] != -1;
3524    }
3525 
3526    if (shader->info.gs.output_primitive == MESA_PRIM_POINTS)
3527       state.num_vertices_per_primitive = 1;
3528    else if (shader->info.gs.output_primitive == MESA_PRIM_LINE_STRIP)
3529       state.num_vertices_per_primitive = 2;
3530    else if (shader->info.gs.output_primitive == MESA_PRIM_TRIANGLE_STRIP)
3531       state.num_vertices_per_primitive = 3;
3532    else
3533       unreachable("Invalid GS output primitive.");
3534 
3535    /* Extract the full control flow. It is going to be wrapped in an if statement. */
3536    nir_cf_list extracted;
3537    nir_cf_extract(&extracted, nir_before_impl(impl),
3538                   nir_after_impl(impl));
3539 
3540    nir_builder builder = nir_builder_at(nir_before_impl(impl));
3541    nir_builder *b = &builder; /* This is to avoid the & */
3542 
3543    /* Workgroup barrier: wait for ES threads */
3544    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3545                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3546 
3547    state.lds_addr_gs_out_vtx = nir_load_lds_ngg_gs_out_vertex_base_amd(b);
3548    state.lds_addr_gs_scratch = nir_load_lds_ngg_scratch_base_amd(b);
3549 
3550    /* Wrap the GS control flow. */
3551    nir_if *if_gs_thread = nir_push_if(b, has_input_primitive(b));
3552 
3553    nir_cf_reinsert(&extracted, b->cursor);
3554    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
3555    nir_pop_if(b, if_gs_thread);
3556 
3557    /* Workgroup barrier: wait for all GS threads to finish */
3558    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3559                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3560 
3561    if (state.streamout_enabled)
3562       ngg_gs_build_streamout(b, &state);
3563 
3564    /* Lower the GS intrinsics */
3565    lower_ngg_gs_intrinsics(shader, &state);
3566 
3567    if (!state.vertex_count[0]) {
3568       fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.");
3569       abort();
3570    }
3571 
3572    /* Emit shader queries */
3573    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
3574    ac_nir_gs_shader_query(b,
3575                           state.options->has_gen_prim_query,
3576                           state.options->has_gs_invocations_query,
3577                           state.options->has_gs_primitives_query,
3578                           state.num_vertices_per_primitive,
3579                           state.options->wave_size,
3580                           state.vertex_count,
3581                           state.primitive_count);
3582 
3583    b->cursor = nir_after_impl(impl);
3584 
3585    /* Emit the finale sequence */
3586    ngg_gs_finale(b, &state);
3587    nir_validate_shader(shader, "after emitting NGG GS");
3588 
3589    /* Cleanup */
3590    nir_lower_vars_to_ssa(shader);
3591    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
3592    nir_metadata_preserve(impl, nir_metadata_none);
3593 }
3594 
3595 unsigned
3596 ac_ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
3597                                    unsigned shader_num_outputs,
3598                                    bool streamout_enabled,
3599                                    bool export_prim_id,
3600                                    bool has_user_edgeflags,
3601                                    bool can_cull,
3602                                    bool uses_instance_id,
3603                                    bool uses_primitive_id)
3604 {
3605    /* For the culling-time LDS layout only. */
3606    unsigned culling_pervertex_lds_bytes = can_cull ?
3607       ngg_nogs_get_culling_pervertex_lds_size(
3608          stage, uses_instance_id, uses_primitive_id, NULL) : 0;
3609 
3610    unsigned pervertex_lds_bytes =
3611       ngg_nogs_get_pervertex_lds_size(stage, shader_num_outputs, streamout_enabled,
3612                                       export_prim_id, has_user_edgeflags);
3613 
3614    return MAX2(culling_pervertex_lds_bytes, pervertex_lds_bytes);
3615 }
3616 
3617 unsigned
3618 ac_ngg_get_scratch_lds_size(gl_shader_stage stage,
3619                             unsigned workgroup_size,
3620                             unsigned wave_size,
3621                             bool streamout_enabled,
3622                             bool can_cull)
3623 {
3624    unsigned scratch_lds_size = 0;
3625    unsigned max_num_waves = DIV_ROUND_UP(workgroup_size, wave_size);
3626 
3627    if (stage == MESA_SHADER_VERTEX || stage == MESA_SHADER_TESS_EVAL) {
3628       if (streamout_enabled) {
3629          /* 4 dwords for the 4 streamout buffer offsets, 1 dword for the emitted primitive count */
3630          scratch_lds_size = 20;
3631       } else if (can_cull) {
3632          scratch_lds_size = ALIGN(max_num_waves, 4u);
3633       }
3634    } else {
3635       assert(stage == MESA_SHADER_GEOMETRY);
3636 
3637       scratch_lds_size = ALIGN(max_num_waves, 4u);
3638       /* Streamout takes 8 dwords for the buffer offsets and emitted vertices per stream. */
3639       if (streamout_enabled)
3640          scratch_lds_size = MAX2(scratch_lds_size, 32);
3641    }
3642 
3643    return scratch_lds_size;
3644 }
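/* Worked example (hypothetical numbers): workgroup_size = 256 and
 * wave_size = 64 give max_num_waves = 4, so a VS/TES with culling needs
 * ALIGN(4, 4) = 4 bytes of scratch, a VS/TES with streamout needs 20 bytes,
 * and a GS with streamout needs MAX2(4, 32) = 32 bytes.
 */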
3645 
3646 static void
3647 ms_store_prim_indices(nir_builder *b,
3648                       nir_def *val,
3649                       nir_def *offset_src,
3650                       lower_ngg_ms_state *s)
3651 {
3652    assert(val->num_components <= 3);
3653 
3654    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
3655       for (unsigned c = 0; c < s->vertices_per_prim; ++c)
3656          nir_store_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c], nir_channel(b, val, c), 0x1);
3657       return;
3658    }
3659 
3660    if (!offset_src)
3661       offset_src = nir_imm_int(b, 0);
3662 
3663    nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->layout.lds.indices_addr);
3664 }
3665 
3666 static void
3667 ms_store_cull_flag(nir_builder *b,
3668                    nir_def *val,
3669                    nir_def *offset_src,
3670                    lower_ngg_ms_state *s)
3671 {
3672    assert(val->num_components == 1);
3673    assert(val->bit_size == 1);
3674 
3675    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) {
3676       nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, val), 0x1);
3677       return;
3678    }
3679 
3680    if (!offset_src)
3681       offset_src = nir_imm_int(b, 0);
3682 
3683    nir_store_shared(b, nir_b2i8(b, val), offset_src, .base = s->layout.lds.cull_flags_addr);
3684 }
3685 
3686 static nir_def *
3687 ms_arrayed_output_base_addr(nir_builder *b,
3688                             nir_def *arr_index,
3689                             unsigned driver_location,
3690                             unsigned num_arrayed_outputs)
3691 {
3692    /* Address offset of the array item (vertex or primitive). */
3693    unsigned arr_index_stride = num_arrayed_outputs * 16u;
3694    nir_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);
3695 
3696    /* IO address offset within the vertex or primitive data. */
3697    unsigned io_offset = driver_location * 16u;
3698    nir_def *io_off = nir_imm_int(b, io_offset);
3699 
3700    return nir_iadd_nuw(b, arr_index_off, io_off);
3701 }
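/* Worked example (hypothetical numbers): with num_arrayed_outputs = 5 the
 * per-item stride is 5 * 16 = 80 bytes, so arr_index = 3 and
 * driver_location = 2 yield an offset of 3 * 80 + 2 * 16 = 272 bytes.
 */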
3702 
3703 static void
3704 update_ms_output_info_slot(lower_ngg_ms_state *s,
3705                            unsigned slot, unsigned base_off,
3706                            uint32_t components_mask)
3707 {
3708    while (components_mask) {
3709       s->output_info[slot + base_off].components_mask |= components_mask & 0xF;
3710 
3711       components_mask >>= 4;
3712       base_off++;
3713    }
3714 }
3715 
3716 static void
3717 update_ms_output_info(nir_intrinsic_instr *intrin,
3718                       const ms_out_part *out,
3719                       lower_ngg_ms_state *s)
3720 {
3721    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
3722    nir_src *base_offset_src = nir_get_io_offset_src(intrin);
3723    uint32_t write_mask = nir_intrinsic_write_mask(intrin);
3724    unsigned component_offset = nir_intrinsic_component(intrin);
3725 
3726    nir_def *store_val = intrin->src[0].ssa;
3727    write_mask = util_widen_mask(write_mask, DIV_ROUND_UP(store_val->bit_size, 32));
3728    uint32_t components_mask = write_mask << component_offset;
3729 
3730    if (nir_src_is_const(*base_offset_src)) {
3731       /* Simply mark the components of the current slot as used. */
3732       unsigned base_off = nir_src_as_uint(*base_offset_src);
3733       update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
3734    } else {
3735       /* Indirect offset: mark the components of all slots as used. */
3736       for (unsigned base_off = 0; base_off < io_sem.num_slots; ++base_off)
3737          update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
3738    }
3739 }
3740 
3741 static nir_def *
3742 regroup_store_val(nir_builder *b, nir_def *store_val)
3743 {
3744    /* Vulkan spec 15.1.4-15.1.5:
3745     *
3746     * The shader interface consists of output slots with 4x 32-bit components.
3747     * Small bitsize components consume the same space as 32-bit components,
3748     * but 64-bit ones consume twice as much.
3749     *
3750     * The same output slot may consist of components of different bit sizes.
3751     * Therefore, for simplicity we don't pack small bitsize components
3752     * tightly, but pad them instead: in practice, each one is converted to
3753     * 32 bits and occupies a full dword of its slot.
3754     */
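   /* For example, a 16-bit vec2 still takes up two full dwords of its output
    * slot here: each component is simply zero-extended to 32 bits below.
    */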
3755 
3756    if (store_val->bit_size < 32) {
3757       assert(store_val->num_components <= 4);
3758       nir_def *comps[4] = {0};
3759       for (unsigned c = 0; c < store_val->num_components; ++c)
3760          comps[c] = nir_u2u32(b, nir_channel(b, store_val, c));
3761       return nir_vec(b, comps, store_val->num_components);
3762    }
3763 
3764    return store_val;
3765 }
3766 
3767 static nir_def *
3768 regroup_load_val(nir_builder *b, nir_def *load, unsigned dest_bit_size)
3769 {
3770    if (dest_bit_size == load->bit_size)
3771       return load;
3772 
3773    /* Small bitsize components are not stored contiguously, take care of that here. */
3774    unsigned num_components = load->num_components;
3775    assert(num_components <= 4);
3776    nir_def *components[4] = {0};
3777    for (unsigned i = 0; i < num_components; ++i)
3778       components[i] = nir_u2uN(b, nir_channel(b, load, i), dest_bit_size);
3779 
3780    return nir_vec(b, components, num_components);
3781 }
3782 
3783 static const ms_out_part *
3784 ms_get_out_layout_part(unsigned location,
3785                        shader_info *info,
3786                        ms_out_mode *out_mode,
3787                        lower_ngg_ms_state *s)
3788 {
3789    uint64_t mask = BITFIELD64_BIT(location);
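   /* Each output location belongs to exactly one part of the layout
    * (LDS, scratch ring, attribute ring, or repacked variables);
    * the masks below are disjoint by construction in ms_calculate_output_layout.
    */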
3790 
3791    if (info->per_primitive_outputs & mask) {
3792       if (mask & s->layout.lds.prm_attr.mask) {
3793          *out_mode = ms_out_mode_lds;
3794          return &s->layout.lds.prm_attr;
3795       } else if (mask & s->layout.scratch_ring.prm_attr.mask) {
3796          *out_mode = ms_out_mode_scratch_ring;
3797          return &s->layout.scratch_ring.prm_attr;
3798       } else if (mask & s->layout.attr_ring.prm_attr.mask) {
3799          *out_mode = ms_out_mode_attr_ring;
3800          return &s->layout.attr_ring.prm_attr;
3801       } else if (mask & s->layout.var.prm_attr.mask) {
3802          *out_mode = ms_out_mode_var;
3803          return &s->layout.var.prm_attr;
3804       }
3805    } else {
3806       if (mask & s->layout.lds.vtx_attr.mask) {
3807          *out_mode = ms_out_mode_lds;
3808          return &s->layout.lds.vtx_attr;
3809       } else if (mask & s->layout.scratch_ring.vtx_attr.mask) {
3810          *out_mode = ms_out_mode_scratch_ring;
3811          return &s->layout.scratch_ring.vtx_attr;
3812       } else if (mask & s->layout.attr_ring.vtx_attr.mask) {
3813          *out_mode = ms_out_mode_attr_ring;
3814          return &s->layout.attr_ring.vtx_attr;
3815       } else if (mask & s->layout.var.vtx_attr.mask) {
3816          *out_mode = ms_out_mode_var;
3817          return &s->layout.var.vtx_attr;
3818       }
3819    }
3820 
3821    unreachable("Couldn't figure out mesh shader output mode.");
3822 }
3823 
3824 static void
3825 ms_store_arrayed_output_intrin(nir_builder *b,
3826                                nir_intrinsic_instr *intrin,
3827                                lower_ngg_ms_state *s)
3828 {
3829    unsigned location = nir_intrinsic_io_semantics(intrin).location;
3830 
3831    if (location == VARYING_SLOT_PRIMITIVE_INDICES) {
3832       /* EXT_mesh_shader primitive indices: array of vectors.
3833        * They don't count as per-primitive outputs, but the array is indexed
3834        * by the primitive index, so they are practically per-primitive.
3835        *
3836        * The max vertex count is 256, so these indices always fit 8 bits.
3837        * To reduce LDS use, store these as a flat array of 8-bit values.
3838        */
3839       assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
3840       assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
3841       assert(nir_intrinsic_component(intrin) == 0);
3842 
3843       nir_def *store_val = intrin->src[0].ssa;
3844       nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
3845       nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
3846       ms_store_prim_indices(b, store_val, offset, s);
3847       return;
3848    } else if (location == VARYING_SLOT_CULL_PRIMITIVE) {
3849       /* EXT_mesh_shader cull primitive: per-primitive bool.
3850        * To reduce LDS use, store these as an array of 8-bit values.
3851        */
3852       assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
3853       assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
3854       assert(nir_intrinsic_component(intrin) == 0);
3855       assert(nir_intrinsic_write_mask(intrin) == 1);
3856 
3857       nir_def *store_val = intrin->src[0].ssa;
3858       nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
3859       nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
3860       ms_store_cull_flag(b, store_val, offset, s);
3861       return;
3862    }
3863 
3864    ms_out_mode out_mode;
3865    const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
3866    update_ms_output_info(intrin, out, s);
3867 
3868    /* We compact the LDS size (we don't reserve LDS space for outputs which can
3869     * be stored in variables), so we can't rely on the original driver_location.
3870     * Instead, we compute the first free location based on the output mask.
3871     */
3872    unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
3873    unsigned component_offset = nir_intrinsic_component(intrin);
3874    unsigned write_mask = nir_intrinsic_write_mask(intrin);
3875    unsigned num_outputs = util_bitcount64(out->mask);
3876    unsigned const_off = out->addr + component_offset * 4;
3877 
3878    nir_def *store_val = regroup_store_val(b, intrin->src[0].ssa);
3879    nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
3880    nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
3881    nir_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
3882    nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
3883    nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
3884 
3885    if (out_mode == ms_out_mode_lds) {
3886       nir_store_shared(b, store_val, addr, .base = const_off,
3887                      .write_mask = write_mask, .align_mul = 16,
3888                      .align_offset = const_off % 16);
3889    } else if (out_mode == ms_out_mode_scratch_ring) {
3890       nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
3891       nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
3892       nir_def *zero = nir_imm_int(b, 0);
3893       nir_store_buffer_amd(b, store_val, ring, addr, off, zero,
3894                            .base = const_off,
3895                            .write_mask = write_mask,
3896                            .memory_modes = nir_var_shader_out,
3897                            .access = ACCESS_COHERENT);
3898    } else if (out_mode == ms_out_mode_attr_ring) {
3899       /* GFX11+: Store params straight to the attribute ring.
3900        *
3901        * Even though the access pattern may not be the most optimal,
3902        * this is still much better than reserving LDS and losing waves.
3903        * (Also much better than storing and reloading from the scratch ring.)
3904        */
3905       const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
3906       unsigned param_offset = s->vs_output_param_offset[io_sem.location];
3907       nir_def *ring = nir_load_ring_attr_amd(b);
3908       nir_def *soffset = nir_load_ring_attr_offset_amd(b);
3909       nir_store_buffer_amd(b, store_val, ring, base_addr_off, soffset, arr_index,
3910                            .base = const_off + param_offset * 16,
3911                            .write_mask = write_mask,
3912                            .memory_modes = nir_var_shader_out,
3913                            .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
3914    } else if (out_mode == ms_out_mode_var) {
3915       if (store_val->bit_size > 32) {
3916          /* Split 64-bit store values to 32-bit components. */
3917          store_val = nir_bitcast_vector(b, store_val, 32);
3918          /* Widen the write mask so it is in 32-bit components. */
3919          write_mask = util_widen_mask(write_mask, store_val->bit_size / 32);
3920       }
3921 
3922       u_foreach_bit(comp, write_mask) {
3923          nir_def *val = nir_channel(b, store_val, comp);
3924          unsigned idx = location * 4 + comp + component_offset;
3925          nir_store_var(b, s->out_variables[idx], val, 0x1);
3926       }
3927    } else {
3928       unreachable("Invalid MS output mode for store");
3929    }
3930 }
3931 
3932 static nir_def *
3933 ms_load_arrayed_output(nir_builder *b,
3934                        nir_def *arr_index,
3935                        nir_def *base_offset,
3936                        unsigned location,
3937                        unsigned component_offset,
3938                        unsigned num_components,
3939                        unsigned load_bit_size,
3940                        lower_ngg_ms_state *s)
3941 {
3942    ms_out_mode out_mode;
3943    const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
3944 
3945    unsigned component_addr_off = component_offset * 4;
3946    unsigned num_outputs = util_bitcount64(out->mask);
3947    unsigned const_off = out->addr + component_offset * 4;
3948 
3949    /* Use compacted driver location instead of the original. */
3950    unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
3951 
3952    nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
3953    nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
3954    nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
3955 
3956    if (out_mode == ms_out_mode_lds) {
3957       return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
3958                              .align_offset = component_addr_off % 16,
3959                              .base = const_off);
3960    } else if (out_mode == ms_out_mode_scratch_ring) {
3961       nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
3962       nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
3963       nir_def *zero = nir_imm_int(b, 0);
3964       return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, zero,
3965                                  .base = const_off,
3966                                  .memory_modes = nir_var_shader_out,
3967                                  .access = ACCESS_COHERENT);
3968    } else if (out_mode == ms_out_mode_var) {
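      /* The repacked output variables always hold 32-bit values;
       * wider loads are reconstructed from them below.
       */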
3969       nir_def *arr[8] = {0};
3970       unsigned num_32bit_components = num_components * load_bit_size / 32;
3971       for (unsigned comp = 0; comp < num_32bit_components; ++comp) {
3972          unsigned idx = location * 4 + comp + component_addr_off;
3973          arr[comp] = nir_load_var(b, s->out_variables[idx]);
3974       }
3975       if (load_bit_size > 32)
3976          return nir_extract_bits(b, arr, 1, 0, num_components, load_bit_size);
3977       return nir_vec(b, arr, num_components);
3978    } else {
3979       unreachable("Invalid MS output mode for load");
3980    }
3981 }
3982 
3983 static nir_def *
3984 ms_load_arrayed_output_intrin(nir_builder *b,
3985                               nir_intrinsic_instr *intrin,
3986                               lower_ngg_ms_state *s)
3987 {
3988    nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
3989    nir_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
3990 
3991    unsigned location = nir_intrinsic_io_semantics(intrin).location;
3992    unsigned component_offset = nir_intrinsic_component(intrin);
3993    unsigned bit_size = intrin->def.bit_size;
3994    unsigned num_components = intrin->def.num_components;
3995    unsigned load_bit_size = MAX2(bit_size, 32);
3996 
3997    nir_def *load =
3998       ms_load_arrayed_output(b, arr_index, base_offset, location, component_offset,
3999                              num_components, load_bit_size, s);
4000 
4001    return regroup_load_val(b, load, bit_size);
4002 }
4003 
4004 static nir_def *
4005 lower_ms_load_workgroup_index(nir_builder *b,
4006                               UNUSED nir_intrinsic_instr *intrin,
4007                               lower_ngg_ms_state *s)
4008 {
4009    return s->workgroup_index;
4010 }
4011 
4012 static nir_def *
4013 lower_ms_set_vertex_and_primitive_count(nir_builder *b,
4014                                         nir_intrinsic_instr *intrin,
4015                                         lower_ngg_ms_state *s)
4016 {
4017    /* If either the number of vertices or primitives is zero, set both of them to zero. */
4018    nir_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
4019    nir_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
4020    nir_def *zero = nir_imm_int(b, 0);
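   /* umin(num_vtx, num_prm) is 0 if and only if either count is 0. */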
4021    nir_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
4022    num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
4023    num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);
4024 
4025    nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
4026    nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
4027 
4028    return NIR_LOWER_INSTR_PROGRESS_REPLACE;
4029 }
4030 
4031 static nir_def *
4032 update_ms_barrier(nir_builder *b,
4033                   nir_intrinsic_instr *intrin,
4034                   lower_ngg_ms_state *s)
4035 {
4036    /* Output loads and stores are lowered to shared memory access,
4037     * so we have to update the barriers to also reflect this.
4038     */
4039    unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
4040    if (mem_modes & nir_var_shader_out)
4041       mem_modes |= nir_var_mem_shared;
4042    else
4043       return NULL;
4044 
4045    nir_intrinsic_set_memory_modes(intrin, mem_modes);
4046 
4047    return NIR_LOWER_INSTR_PROGRESS;
4048 }
4049 
4050 static nir_def *
4051 lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
4052 {
4053    lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;
4054 
4055    if (instr->type != nir_instr_type_intrinsic)
4056       return NULL;
4057 
4058    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
4059 
4060    switch (intrin->intrinsic) {
4061    case nir_intrinsic_store_per_vertex_output:
4062    case nir_intrinsic_store_per_primitive_output:
4063       ms_store_arrayed_output_intrin(b, intrin, s);
4064       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
4065    case nir_intrinsic_load_per_vertex_output:
4066    case nir_intrinsic_load_per_primitive_output:
4067       return ms_load_arrayed_output_intrin(b, intrin, s);
4068    case nir_intrinsic_barrier:
4069       return update_ms_barrier(b, intrin, s);
4070    case nir_intrinsic_load_workgroup_index:
4071       return lower_ms_load_workgroup_index(b, intrin, s);
4072    case nir_intrinsic_set_vertex_and_primitive_count:
4073       return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
4074    default:
4075       unreachable("Not a lowerable mesh shader intrinsic.");
4076    }
4077 }
4078 
4079 static bool
4080 filter_ms_intrinsic(const nir_instr *instr,
4081                     UNUSED const void *s)
4082 {
4083    if (instr->type != nir_instr_type_intrinsic)
4084       return false;
4085 
4086    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
4087    return intrin->intrinsic == nir_intrinsic_store_output ||
4088           intrin->intrinsic == nir_intrinsic_load_output ||
4089           intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
4090           intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
4091           intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
4092           intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
4093           intrin->intrinsic == nir_intrinsic_barrier ||
4094           intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
4095           intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
4096 }
4097 
4098 static void
4099 lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
4100 {
4101    nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
4102 }
4103 
4104 static void
4105 ms_emit_arrayed_outputs(nir_builder *b,
4106                         nir_def *invocation_index,
4107                         uint64_t mask,
4108                         lower_ngg_ms_state *s)
4109 {
4110    nir_def *zero = nir_imm_int(b, 0);
4111 
4112    u_foreach_bit64(slot, mask) {
4113       /* Should not occur here, handled separately. */
4114       assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
4115 
4116       unsigned component_mask = s->output_info[slot].components_mask;
4117 
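      /* Load each consecutively used range of components with a single load. */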
4118       while (component_mask) {
4119          int start_comp = 0, num_components = 1;
4120          u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
4121 
4122          nir_def *load =
4123             ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
4124                                    num_components, 32, s);
4125 
4126          for (int i = 0; i < num_components; i++)
4127             s->outputs[slot][start_comp + i] = nir_channel(b, load, i);
4128       }
4129    }
4130 }
4131 
4132 static void
4133 ms_create_same_invocation_vars(nir_builder *b, lower_ngg_ms_state *s)
4134 {
4135    /* Initialize NIR variables for same-invocation outputs. */
4136    uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;
4137 
4138    u_foreach_bit64(slot, same_invocation_output_mask) {
4139       for (unsigned comp = 0; comp < 4; ++comp) {
4140          unsigned idx = slot * 4 + comp;
4141          s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
4142       }
4143    }
4144 }
4145 
4146 static void
4147 ms_emit_legacy_workgroup_index(nir_builder *b, lower_ngg_ms_state *s)
4148 {
4149    /* Workgroup ID should have been lowered to workgroup index. */
4150    assert(!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID));
4151 
4152    /* No need to do anything if the shader doesn't use the workgroup index. */
4153    if (!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX))
4154       return;
4155 
4156    b->cursor = nir_before_impl(b->impl);
4157 
4158    /* Legacy fast launch mode (FAST_LAUNCH=1):
4159     *
4160     * The HW doesn't support a proper workgroup index for vertex processing stages,
4161     * so we use the vertex ID which is equivalent to the index of the current workgroup
4162     * within the current dispatch.
4163     *
4164     * Due to the register programming of mesh shaders, this value is only filled for
4165     * the first invocation of the first wave. To let other waves know, we use LDS.
4166     */
4167    nir_def *workgroup_index = nir_load_vertex_id_zero_base(b);
4168 
4169    if (s->api_workgroup_size <= s->wave_size) {
4170       /* API workgroup is small, so we don't need to use LDS. */
4171       s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
4172       return;
4173    }
4174 
4175    unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
4176 
4177    nir_def *zero = nir_imm_int(b, 0);
4178    nir_def *dont_care = nir_undef(b, 1, 32);
4179    nir_def *loaded_workgroup_index = NULL;
4180 
4181    /* Use elect to make sure only 1 invocation uses LDS. */
4182    nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
4183    {
4184       nir_def *wave_id = nir_load_subgroup_id(b);
4185       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
4186       {
4187          nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
4188          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4189                                .memory_scope = SCOPE_WORKGROUP,
4190                                .memory_semantics = NIR_MEMORY_ACQ_REL,
4191                                .memory_modes = nir_var_mem_shared);
4192       }
4193       nir_push_else(b, if_wave_0);
4194       {
4195          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4196                                .memory_scope = SCOPE_WORKGROUP,
4197                                .memory_semantics = NIR_MEMORY_ACQ_REL,
4198                                .memory_modes = nir_var_mem_shared);
4199          loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
4200       }
4201       nir_pop_if(b, if_wave_0);
4202 
4203       workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
4204    }
4205    nir_pop_if(b, if_elected);
4206 
4207    workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
4208    s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
4209 }
4210 
4211 static void
4212 set_ms_final_output_counts(nir_builder *b,
4213                            lower_ngg_ms_state *s,
4214                            nir_def **out_num_prm,
4215                            nir_def **out_num_vtx)
4216 {
4217    /* The spec allows the counts to be divergent, in which case we need to
4218     * use the values from the first invocation. Also, the HW requires us to set
4219     * both to 0 if either of them is 0.
4220     *
4221     * Both of these are already handled by the lowering.
4222     */
4223    nir_def *num_prm = nir_load_var(b, s->primitive_count_var);
4224    nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
4225 
4226    if (s->hw_workgroup_size <= s->wave_size) {
4227       /* Single-wave mesh shader workgroup. */
4228       alloc_vertices_and_primitives(b, num_vtx, num_prm);
4229       *out_num_prm = num_prm;
4230       *out_num_vtx = num_vtx;
4231       return;
4232    }
4233 
4234    /* Multi-wave mesh shader workgroup:
4235     * We need to use LDS to distribute the correct values to the other waves.
4236     *
4237     * TODO:
4238     * If we can prove that the values are workgroup-uniform, we can skip this
4239     * and just use whatever the current wave has. However, NIR divergence analysis
4240     * currently doesn't support this.
4241     */
4242 
4243    nir_def *zero = nir_imm_int(b, 0);
4244 
4245    nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
4246    {
4247       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
4248       {
4249          nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
4250                           .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
4251       }
4252       nir_pop_if(b, if_elected);
4253 
4254       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4255                             .memory_scope = SCOPE_WORKGROUP,
4256                             .memory_semantics = NIR_MEMORY_ACQ_REL,
4257                             .memory_modes = nir_var_mem_shared);
4258 
4259       alloc_vertices_and_primitives(b, num_vtx, num_prm);
4260    }
4261    nir_push_else(b, if_wave_0);
4262    {
4263       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4264                             .memory_scope = SCOPE_WORKGROUP,
4265                             .memory_semantics = NIR_MEMORY_ACQ_REL,
4266                             .memory_modes = nir_var_mem_shared);
4267 
4268       nir_def *prm_vtx = NULL;
4269       nir_def *dont_care_2x32 = nir_undef(b, 2, 32);
4270       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
4271       {
4272          prm_vtx = nir_load_shared(b, 2, 32, zero,
4273                                    .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
4274       }
4275       nir_pop_if(b, if_elected);
4276 
4277       prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
4278       num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
4279       num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));
4280 
4281       nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
4282       nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
4283    }
4284    nir_pop_if(b, if_wave_0);
4285 
4286    *out_num_prm = nir_load_var(b, s->primitive_count_var);
4287    *out_num_vtx = nir_load_var(b, s->vertex_count_var);
4288 }
4289 
4290 static void
4291 ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask,
4292                                      nir_def *idx, lower_ngg_ms_state *s)
4293 {
4294    if (!outputs_mask)
4295       return;
4296 
4297    nir_def *ring = nir_load_ring_attr_amd(b);
4298    nir_def *off = nir_load_ring_attr_offset_amd(b);
4299    nir_def *zero = nir_imm_int(b, 0);
4300 
4301    u_foreach_bit64 (slot, outputs_mask) {
4302       if (s->vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31)
4303          continue;
4304 
4305       nir_def *soffset = nir_iadd_imm(b, off, s->vs_output_param_offset[slot] * 16 * 32);
4306       nir_def *store_val = nir_undef(b, 4, 32);
4307       unsigned store_val_components = 0;
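      /* Collect the components written to this slot; unwritten channels below
       * the last written one remain undef, and the vector is trimmed there.
       */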
4308       for (unsigned c = 0; c < 4; ++c) {
4309          if (s->outputs[slot][c]) {
4310             store_val = nir_vector_insert_imm(b, store_val, s->outputs[slot][c], c);
4311             store_val_components = c + 1;
4312          }
4313       }
4314 
4315       store_val = nir_trim_vector(b, store_val, store_val_components);
4316       nir_store_buffer_amd(b, store_val, ring, zero, soffset, idx,
4317                            .memory_modes = nir_var_shader_out,
4318                            .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
4319    }
4320 }
4321 
4322 static nir_def *
4323 ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s)
4324 {
4325    /* Primitive connectivity data: describes which vertices the primitive uses. */
4326    nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
4327    nir_def *indices_loaded = NULL;
4328    nir_def *cull_flag = NULL;
4329 
4330    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
4331       nir_def *indices[3] = {0};
4332       for (unsigned c = 0; c < s->vertices_per_prim; ++c)
4333          indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
4334       indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
4335    } else {
4336       indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
4337       indices_loaded = nir_u2u32(b, indices_loaded);
4338    }
4339 
4340    if (s->uses_cull_flags) {
4341       nir_def *loaded_cull_flag = NULL;
4342       if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
4343          loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
4344       else
4345          loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
4346 
4347       cull_flag = nir_i2b(b, loaded_cull_flag);
4348    }
4349 
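   /* Clamp the vertex indices to the number of exported vertices minus one,
    * so that out-of-range indices are never passed to the HW.
    */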
4350    nir_def *indices[3];
4351    nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
4352 
4353    for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
4354       indices[i] = nir_channel(b, indices_loaded, i);
4355       indices[i] = nir_umin(b, indices[i], max_vtx_idx);
4356    }
4357 
4358    return emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag);
4359 }
4360 
4361 static nir_def *
4362 ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s)
4363 {
4364    nir_def *prim_exp_arg_ch2 = NULL;
4365 
4366    if (outputs_mask) {
4367       /* When layer, viewport etc. are per-primitive, they need to be encoded in
4368        * the primitive export instruction's second channel. The encoding is:
4369        *
4370        * --- GFX10.3 ---
4371        * bits 31..30: VRS rate Y
4372        * bits 29..28: VRS rate X
4373        * bits 23..20: viewport
4374        * bits 19..17: layer
4375        *
4376        * --- GFX11 ---
4377        * bits 31..28: VRS rate enum
4378        * bits 23..20: viewport
4379        * bits 12..00: layer
4380        */
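      /* Example (GFX11): layer = 2 and viewport index = 1 yield
       * (1 << 20) | 2 = 0x100002 in this channel.
       */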
4381       prim_exp_arg_ch2 = nir_imm_int(b, 0);
4382 
4383       if (outputs_mask & VARYING_BIT_LAYER) {
4384          nir_def *layer =
4385             nir_ishl_imm(b, s->outputs[VARYING_SLOT_LAYER][0], s->gfx_level >= GFX11 ? 0 : 17);
4386          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer);
4387       }
4388 
4389       if (outputs_mask & VARYING_BIT_VIEWPORT) {
4390          nir_def *view = nir_ishl_imm(b, s->outputs[VARYING_SLOT_VIEWPORT][0], 20);
4391          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view);
4392       }
4393 
4394       if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) {
4395          nir_def *rate = s->outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0];
4396          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate);
4397       }
4398    }
4399 
4400    return prim_exp_arg_ch2;
4401 }
4402 
4403 static void
4404 ms_prim_gen_query(nir_builder *b,
4405                   nir_def *invocation_index,
4406                   nir_def *num_prm,
4407                   lower_ngg_ms_state *s)
4408 {
4409    if (!s->has_query)
4410       return;
4411 
4412    nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
4413    {
4414       nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
4415       {
4416          nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0);
4417       }
4418       nir_pop_if(b, if_shader_query);
4419    }
4420    nir_pop_if(b, if_invocation_index_zero);
4421 }
4422 
4423 static void
4424 ms_invocation_query(nir_builder *b,
4425                     nir_def *invocation_index,
4426                     lower_ngg_ms_state *s)
4427 {
4428    if (!s->has_query)
4429       return;
4430 
4431    nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
4432    {
4433       nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b));
4434       {
4435          nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size));
4436       }
4437       nir_pop_if(b, if_pipeline_query);
4438    }
4439    nir_pop_if(b, if_invocation_index_zero);
4440 }
4441 
4442 static void
4443 emit_ms_vertex(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
4444                uint64_t per_vertex_outputs, lower_ngg_ms_state *s)
4445 {
4446    ms_emit_arrayed_outputs(b, index, per_vertex_outputs, s);
4447 
4448    if (exports) {
4449       ac_nir_export_position(b, s->gfx_level, s->clipdist_enable_mask,
4450                              !s->has_param_exports, false, true,
4451                              s->per_vertex_outputs | VARYING_BIT_POS, s->outputs, row);
4452    }
4453 
4454    if (parameters) {
4455       /* Export generic attributes on GFX10.3
4456        * (On GFX11 they are already stored in the attribute ring.)
4457        */
4458       if (s->has_param_exports && s->gfx_level == GFX10_3) {
4459          ac_nir_export_parameters(b, s->vs_output_param_offset, per_vertex_outputs, 0, s->outputs,
4460                                   NULL, NULL);
4461       }
4462 
4463       /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
4464       if (s->gfx_level >= GFX11 && (per_vertex_outputs & MS_VERT_ARG_EXP_MASK))
4465          ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & MS_VERT_ARG_EXP_MASK, index, s);
4466    }
4467 }
4468 
4469 static void
4470 emit_ms_primitive(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
4471                   uint64_t per_primitive_outputs, lower_ngg_ms_state *s)
4472 {
4473    ms_emit_arrayed_outputs(b, index, per_primitive_outputs, s);
4474 
4475    /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
4476    if (s->insert_layer_output)
4477       s->outputs[VARYING_SLOT_LAYER][0] = nir_load_view_index(b);
4478 
4479    if (exports) {
4480       const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK;
4481       nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
4482       nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, index, num_vtx, s);
4483       nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s);
4484 
4485       nir_def *prim_exp_arg = prim_exp_arg_ch2 ?
4486          nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1;
4487 
4488       ac_nir_export_primitive(b, prim_exp_arg, row);
4489    }
4490 
4491    if (parameters) {
4492       /* Export generic attributes on GFX10.3
4493        * (On GFX11 they are already stored in the attribute ring.)
4494        */
4495       if (s->has_param_exports && s->gfx_level == GFX10_3) {
4496          ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0,
4497                                   s->outputs, NULL, NULL);
4498       }
4499 
4500       /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
4501       if (s->gfx_level >= GFX11)
4502          ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, index, s);
4503    }
4504 }
4505 
4506 static void
4507 emit_ms_outputs(nir_builder *b, nir_def *invocation_index, nir_def *row_start,
4508                 nir_def *count, bool exports, bool parameters, uint64_t mask,
4509                 void (*cb)(nir_builder *, nir_def *, nir_def *, bool, bool,
4510                            uint64_t, lower_ngg_ms_state *),
4511                 lower_ngg_ms_state *s)
4512 {
4513    if (cb == &emit_ms_primitive ? s->prim_multirow_export : s->vert_multirow_export) {
4514       assert(s->hw_workgroup_size % s->wave_size == 0);
4515       const unsigned num_waves = s->hw_workgroup_size / s->wave_size;
4516 
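      /* Multi-row export: loop until every vertex/primitive has been exported.
       * Each iteration advances the invocation index by the HW workgroup size
       * and the export row by the number of waves; both values are carried by
       * phis inserted manually at the top of the loop body below.
       */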
4517       nir_loop *row_loop = nir_push_loop(b);
4518       {
4519          nir_block *preheader = nir_cf_node_as_block(nir_cf_node_prev(&row_loop->cf_node));
4520 
4521          nir_phi_instr *index = nir_phi_instr_create(b->shader);
4522          nir_phi_instr *row = nir_phi_instr_create(b->shader);
4523          nir_def_init(&index->instr, &index->def, 1, 32);
4524          nir_def_init(&row->instr, &row->def, 1, 32);
4525 
4526          nir_phi_instr_add_src(index, preheader, invocation_index);
4527          nir_phi_instr_add_src(row, preheader, row_start);
4528 
4529          nir_if *if_break = nir_push_if(b, nir_uge(b, &index->def, count));
4530          {
4531             nir_jump(b, nir_jump_break);
4532          }
4533          nir_pop_if(b, if_break);
4534 
4535          cb(b, &index->def, &row->def, exports, parameters, mask, s);
4536 
4537          nir_block *body = nir_cursor_current_block(b->cursor);
4538          nir_phi_instr_add_src(index, body,
4539                                nir_iadd_imm(b, &index->def, s->hw_workgroup_size));
4540          nir_phi_instr_add_src(row, body,
4541                                nir_iadd_imm(b, &row->def, num_waves));
4542 
4543          nir_instr_insert_before_cf_list(&row_loop->body, &row->instr);
4544          nir_instr_insert_before_cf_list(&row_loop->body, &index->instr);
4545       }
4546       nir_pop_loop(b, row_loop);
4547    } else {
4548       nir_def *has_output = nir_ilt(b, invocation_index, count);
4549       nir_if *if_has_output = nir_push_if(b, has_output);
4550       {
4551          cb(b, invocation_index, row_start, exports, parameters, mask, s);
4552       }
4553       nir_pop_if(b, if_has_output);
4554    }
4555 }
4556 
4557 static void
4558 emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
4559 {
4560    /* We assume there is always a single end block in the shader. */
4561    nir_block *last_block = nir_impl_last_block(b->impl);
4562    b->cursor = nir_after_block(last_block);
4563 
4564    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
4565                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);
4566 
4567    nir_def *num_prm;
4568    nir_def *num_vtx;
4569 
4570    set_ms_final_output_counts(b, s, &num_prm, &num_vtx);
4571 
4572    nir_def *invocation_index = nir_load_local_invocation_index(b);
4573 
4574    ms_prim_gen_query(b, invocation_index, num_prm, s);
4575 
4576    nir_def *row_start = NULL;
4577    if (s->fast_launch_2)
4578       row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b);
4579 
4580    /* Load vertex/primitive attributes from shared memory and
4581     * emit store_output intrinsics for them.
4582     *
4583     * Contrary to the semantics of the API mesh shader, these are now
4584     * compliant with NGG HW semantics, meaning that these store the
4585     * current thread's vertex attributes in a way the HW can export.
4586     */
4587 
4588    uint64_t per_vertex_outputs =
4589       s->per_vertex_outputs & ~s->layout.attr_ring.vtx_attr.mask;
4590    uint64_t per_primitive_outputs =
4591       s->per_primitive_outputs & ~s->layout.attr_ring.prm_attr.mask & ~SPECIAL_MS_OUT_MASK;
4592 
4593    /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
4594    if (s->insert_layer_output) {
4595       b->shader->info.outputs_written |= VARYING_BIT_LAYER;
4596       b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
4597       per_primitive_outputs |= VARYING_BIT_LAYER;
4598    }
4599 
4600    const bool has_special_param_exports =
4601       (per_vertex_outputs & MS_VERT_ARG_EXP_MASK) ||
4602       (per_primitive_outputs & MS_PRIM_ARG_EXP_MASK);
4603 
4604    const bool wait_attr_ring = must_wait_attr_ring(s->gfx_level, has_special_param_exports);
4605 
4606    /* Export vertices. */
4607    if ((per_vertex_outputs & ~VARYING_BIT_POS) || !wait_attr_ring) {
4608       emit_ms_outputs(b, invocation_index, row_start, num_vtx, !wait_attr_ring, true,
4609                       per_vertex_outputs, &emit_ms_vertex, s);
4610    }
4611 
4612    /* Export primitives. */
4613    if (per_primitive_outputs || !wait_attr_ring) {
4614       emit_ms_outputs(b, invocation_index, row_start, num_prm, !wait_attr_ring, true,
4615                       per_primitive_outputs, &emit_ms_primitive, s);
4616    }
4617 
4618    /* When we need to wait for attribute ring stores, we emit both position and primitive
4619     * export instructions after a barrier to make sure both per-vertex and per-primitive
4620     * attribute ring stores are finished before the GPU starts rasterization.
4621     */
4622    if (wait_attr_ring) {
4623       /* Wait for attribute stores to finish. */
4624       nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
4625                      .memory_scope = SCOPE_DEVICE,
4626                      .memory_semantics = NIR_MEMORY_RELEASE,
4627                      .memory_modes = nir_var_shader_out);
4628 
4629       /* Position/primitive export only */
4630       emit_ms_outputs(b, invocation_index, row_start, num_vtx, true, false,
4631                       per_vertex_outputs, &emit_ms_vertex, s);
4632       emit_ms_outputs(b, invocation_index, row_start, num_prm, true, false,
4633                       per_primitive_outputs, &emit_ms_primitive, s);
4634    }
4635 }
4636 
4637 static void
4638 handle_smaller_ms_api_workgroup(nir_builder *b,
4639                                 lower_ngg_ms_state *s)
4640 {
4641    if (s->api_workgroup_size >= s->hw_workgroup_size)
4642       return;
4643 
4644    /* Handle barriers manually when the API workgroup
4645     * size is less than the HW workgroup size.
4646     *
4647     * The problem is that the real workgroup launched on NGG HW
4648     * will be larger than the size specified by the API, and the
4649     * extra waves need to keep up with barriers in the API waves.
4650     *
4651     * There are 2 different cases:
4652     * 1. The whole API workgroup fits in a single wave.
4653     *    We can shrink the barriers to subgroup scope and
4654     *    don't need to insert any extra ones.
4655     * 2. The API workgroup occupies multiple waves, but not
4656     *    all. In this case, we emit code that consumes every
4657     *    barrier on the extra waves.
4658     */
4659    assert(s->hw_workgroup_size % s->wave_size == 0);
4660    bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
4661    bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
4662    bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
4663 
4664    unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
4665    unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
4666 
4667    /* Scan the shader for workgroup barriers. */
4668    if (scan_barriers) {
4669       bool has_any_workgroup_barriers = false;
4670 
4671       nir_foreach_block(block, b->impl) {
4672          nir_foreach_instr_safe(instr, block) {
4673             if (instr->type != nir_instr_type_intrinsic)
4674                continue;
4675 
4676             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
4677             bool is_workgroup_barrier =
4678                intrin->intrinsic == nir_intrinsic_barrier &&
4679                nir_intrinsic_execution_scope(intrin) == SCOPE_WORKGROUP;
4680 
4681             if (!is_workgroup_barrier)
4682                continue;
4683 
4684             if (can_shrink_barriers) {
4685                /* Every API invocation runs in the first wave.
4686                 * In this case, we can change the barriers to subgroup scope
4687                 * and avoid adding additional barriers.
4688                 */
4689                nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP);
4690                nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP);
4691             } else {
4692                has_any_workgroup_barriers = true;
4693             }
4694          }
4695       }
4696 
4697       need_additional_barriers &= has_any_workgroup_barriers;
4698    }
4699 
4700    /* Extract the full control flow of the shader. */
4701    nir_cf_list extracted;
4702    nir_cf_extract(&extracted, nir_before_impl(b->impl),
4703                   nir_after_cf_list(&b->impl->body));
4704    b->cursor = nir_before_impl(b->impl);
4705 
4706    /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */
4707    nir_def *invocation_index = nir_load_local_invocation_index(b);
4708    nir_def *zero = nir_imm_int(b, 0);
4709 
4710    if (need_additional_barriers) {
4711       /* The first invocation stores the initial number of API waves in flight. */
4712       nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
4713       {
4714          nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
4715       }
4716       nir_pop_if(b, if_first_in_workgroup);
4717 
4718       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4719                             .memory_scope = SCOPE_WORKGROUP,
4720                             .memory_semantics = NIR_MEMORY_ACQ_REL,
4721                             .memory_modes = nir_var_shader_out | nir_var_mem_shared);
4722    }
4723 
4724    nir_def *has_api_ms_invocation = nir_ult_imm(b, invocation_index, s->api_workgroup_size);
4725    nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
4726    {
4727       nir_cf_reinsert(&extracted, b->cursor);
4728       b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);
4729 
4730       if (need_additional_barriers) {
4731          /* One invocation in each API wave decrements the number of API waves in flight. */
4732          nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
4733          {
4734             nir_shared_atomic(b, 32, zero, nir_imm_int(b, -1u),
4735                               .base = api_waves_in_flight_addr,
4736                               .atomic_op = nir_atomic_op_iadd);
4737          }
4738          nir_pop_if(b, if_elected_again);
4739 
4740          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4741                                .memory_scope = SCOPE_WORKGROUP,
4742                                .memory_semantics = NIR_MEMORY_ACQ_REL,
4743                                .memory_modes = nir_var_shader_out | nir_var_mem_shared);
4744       }
4745 
4746       ms_invocation_query(b, invocation_index, s);
4747    }
4748    nir_pop_if(b, if_has_api_ms_invocation);
4749 
4750    if (need_additional_barriers) {
4751       /* Make sure that waves that don't run any API invocations execute
4752        * the same amount of barriers as those that do.
4753        *
4754        * We do this by executing a barrier until the number of API waves
4755        * in flight becomes zero.
4756        */
4757       nir_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
4758       nir_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
4759       nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
4760       {
4761          nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
4762          {
4763             nir_loop *loop = nir_push_loop(b);
4764             {
4765                nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4766                                      .memory_scope = SCOPE_WORKGROUP,
4767                                      .memory_semantics = NIR_MEMORY_ACQ_REL,
4768                                      .memory_modes = nir_var_shader_out | nir_var_mem_shared);
4769 
4770                nir_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
4771                nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
4772                {
4773                   nir_jump(b, nir_jump_break);
4774                }
4775                nir_pop_if(b, if_break);
4776             }
4777             nir_pop_loop(b, loop);
4778          }
4779          nir_pop_if(b, if_elected);
4780       }
4781       nir_pop_if(b, if_wave_has_no_api_ms);
4782    }
4783 }
4784 
4785 static void
4786 ms_move_output(ms_out_part *from, ms_out_part *to)
4787 {
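   /* Move the highest-numbered output slot from 'from' to 'to'. */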
4788    uint64_t loc = util_logbase2_64(from->mask);
4789    uint64_t bit = BITFIELD64_BIT(loc);
4790    from->mask ^= bit;
4791    to->mask |= bit;
4792 }
4793 
4794 static void
4795 ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
4796                                    unsigned max_vertices,
4797                                    unsigned max_primitives)
4798 {
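   /* Each output slot takes 16 bytes (4x 32-bit components) per vertex or primitive. */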
4799    uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
4800    uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
4801    l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
4802    l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
4803 
4804    uint32_t scratch_ring_vtx_attr_size =
4805       util_bitcount64(l->scratch_ring.vtx_attr.mask) * max_vertices * 16;
4806    l->scratch_ring.prm_attr.addr =
4807       ALIGN(l->scratch_ring.vtx_attr.addr + scratch_ring_vtx_attr_size, 16);
4808 }
4809 
4810 static ms_out_mem_layout
4811 ms_calculate_output_layout(enum amd_gfx_level gfx_level, unsigned api_shared_size,
4812                            uint64_t per_vertex_output_mask, uint64_t per_primitive_output_mask,
4813                            uint64_t cross_invocation_output_access, unsigned max_vertices,
4814                            unsigned max_primitives, unsigned vertices_per_prim)
4815 {
4816    /* These outputs always need export instructions and can't use the attributes ring. */
4817    const uint64_t always_export_mask =
4818       VARYING_BIT_POS | VARYING_BIT_CULL_DIST0 | VARYING_BIT_CULL_DIST1 | VARYING_BIT_CLIP_DIST0 |
4819       VARYING_BIT_CLIP_DIST1 | VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT |
4820       VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_LAYER |
4821       BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) |
4822       BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
4823 
4824    const bool use_attr_ring = gfx_level >= GFX11;
4825    const uint64_t attr_ring_per_vertex_output_mask =
4826       use_attr_ring ? per_vertex_output_mask & ~always_export_mask : 0;
4827    const uint64_t attr_ring_per_primitive_output_mask =
4828       use_attr_ring ? per_primitive_output_mask & ~always_export_mask : 0;
4829 
4830    const uint64_t lds_per_vertex_output_mask =
4831       per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & cross_invocation_output_access &
4832       ~SPECIAL_MS_OUT_MASK;
4833    const uint64_t lds_per_primitive_output_mask =
4834       per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
4835       cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
4836 
4837    const bool cross_invocation_indices =
4838       cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
4839    const bool cross_invocation_cull_primitive =
4840       cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
4841 
4842    /* Shared memory used by the API shader. */
4843    ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
4844 
4845    /* GFX11+: use attribute ring for all generic attributes. */
4846    l.attr_ring.vtx_attr.mask = attr_ring_per_vertex_output_mask;
4847    l.attr_ring.prm_attr.mask = attr_ring_per_primitive_output_mask;
4848 
4849    /* Outputs without cross-invocation access can be stored in variables. */
4850    l.var.vtx_attr.mask =
4851       per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & ~cross_invocation_output_access;
4852    l.var.prm_attr.mask = per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
4853                          ~cross_invocation_output_access;
4854 
4855    /* Workgroup information, see ms_workgroup_* for the layout. */
4856    l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
4857    l.lds.total_size = l.lds.workgroup_info_addr + 16;
4858 
4859    /* Per-vertex and per-primitive output attributes.
4860     * Outputs without cross-invocation access are not included here.
4861     * First, try to put all outputs into LDS (shared memory).
4862     * If they don't fit, try to move them to VRAM one by one.
4863     */
4864    l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
4865    l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
4866    l.lds.prm_attr.mask = lds_per_primitive_output_mask;
4867    ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
4868 
4869    /* NGG shaders can only address up to 32K of LDS memory.
4870     * The spec requires us to allow the application to use at least 28K of
4871     * shared memory. Additionally, we reserve 2K for driver internal use
4872     * (e.g. primitive indices and such, see below).
4873     *
4874     * Move the outputs that do not fit into LDS to VRAM.
4875     * Start with per-primitive attributes, because those are grouped at the end.
4876     */
4877    const unsigned usable_lds_kbytes =
4878       (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31;
   while (l.lds.total_size >= usable_lds_kbytes * 1024) {
      if (l.lds.prm_attr.mask)
         ms_move_output(&l.lds.prm_attr, &l.scratch_ring.prm_attr);
      else if (l.lds.vtx_attr.mask)
         ms_move_output(&l.lds.vtx_attr, &l.scratch_ring.vtx_attr);
      else
         unreachable("API shader uses too much shared memory.");

      ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
   }

   if (cross_invocation_indices) {
      /* Indices: flat array of 8-bit vertex indices for each primitive. */
      l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
      l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
   }

   if (cross_invocation_cull_primitive) {
      /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
      l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
      l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
   }

   /* NGG is only allowed to address up to 32K of LDS. */
   assert(l.lds.total_size <= 32 * 1024);
   return l;
}

void
ac_nir_lower_ngg_ms(nir_shader *shader,
                    enum amd_gfx_level gfx_level,
                    uint32_t clipdist_enable_mask,
                    const uint8_t *vs_output_param_offset,
                    bool has_param_exports,
                    bool *out_needs_scratch_ring,
                    unsigned wave_size,
                    unsigned hw_workgroup_size,
                    bool multiview,
                    bool has_query,
                    bool fast_launch_2)
{
   unsigned vertices_per_prim =
      mesa_vertices_per_prim(shader->info.mesh.primitive_type);

   uint64_t per_vertex_outputs =
      shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK;
   uint64_t per_primitive_outputs =
      shader->info.per_primitive_outputs & shader->info.outputs_written;

   /* Whether the shader uses CullPrimitiveEXT */
   bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
   /* We can't handle indirect register addressing, so treat indirectly accessed outputs as cross-invocation. */
   uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
                                      shader->info.outputs_accessed_indirectly;

   unsigned max_vertices = shader->info.mesh.max_vertices_out;
   unsigned max_primitives = shader->info.mesh.max_primitives_out;

   ms_out_mem_layout layout = ms_calculate_output_layout(
      gfx_level, shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
      cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);

   shader->info.shared_size = layout.lds.total_size;
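   /* Tell the caller whether the attribute scratch ring (in VRAM) is needed,
    * i.e. whether any arrayed outputs had to be moved out of LDS.
    */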
   *out_needs_scratch_ring = layout.scratch_ring.vtx_attr.mask || layout.scratch_ring.prm_attr.mask;

   /* The workgroup size that is specified by the API shader may be different
    * from the size of the workgroup that actually runs on the HW, due to the
    * limitations of NGG: at most 1 vertex and 1 primitive per lane is allowed.
    *
    * Therefore, we must make sure that when the API workgroup size is smaller,
    * we don't run the API shader on more HW invocations than necessary.
    */
   unsigned api_workgroup_size = shader->info.workgroup_size[0] *
                                 shader->info.workgroup_size[1] *
                                 shader->info.workgroup_size[2];

   lower_ngg_ms_state state = {
      .layout = layout,
      .wave_size = wave_size,
      .per_vertex_outputs = per_vertex_outputs,
      .per_primitive_outputs = per_primitive_outputs,
      .vertices_per_prim = vertices_per_prim,
      .api_workgroup_size = api_workgroup_size,
      .hw_workgroup_size = hw_workgroup_size,
      .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
      .uses_cull_flags = uses_cull,
      .gfx_level = gfx_level,
      .fast_launch_2 = fast_launch_2,
      .vert_multirow_export = fast_launch_2 && max_vertices > hw_workgroup_size,
      .prim_multirow_export = fast_launch_2 && max_primitives > hw_workgroup_size,
      .clipdist_enable_mask = clipdist_enable_mask,
      .vs_output_param_offset = vs_output_param_offset,
      .has_param_exports = has_param_exports,
      .has_query = has_query,
   };
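
   /* Note on vert/prim_multirow_export above: with fast_launch_2 the maximum
    * number of output vertices or primitives can exceed the HW workgroup size,
    * in which case a lane has to export more than one vertex or primitive.
    */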

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

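   /* Local variables that hold the vertex and primitive counts set by the API shader. */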
   state.vertex_count_var =
      nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
   state.primitive_count_var =
      nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");

   nir_builder builder = nir_builder_at(nir_before_impl(impl));
   nir_builder *b = &builder; /* Shortcut so we don't have to write &builder everywhere. */

   handle_smaller_ms_api_workgroup(b, &state);
   if (!fast_launch_2)
      ms_emit_legacy_workgroup_index(b, &state);
   ms_create_same_invocation_vars(b, &state);
   nir_metadata_preserve(impl, nir_metadata_none);

   lower_ms_intrinsics(shader, &state);

   emit_ms_finale(b, &state);
   nir_metadata_preserve(impl, nir_metadata_none);

   /* Cleanup */
   nir_lower_vars_to_ssa(shader);
   nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
   nir_lower_alu_to_scalar(shader, NULL, NULL);
   nir_lower_phis_to_scalar(shader, true);

   /* Optimize load_local_invocation_index. When the API workgroup is smaller than
    * the HW workgroup, local_invocation_id isn't initialized for all lanes, so this
    * optimization can't be applied to every load_local_invocation_index.
    */
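   /* Note: the condition below also requires the workgroup to be effectively
    * 1-dimensional, presumably so the lowered index reduces to a single
    * local_invocation_id component.
    */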
   if (fast_launch_2 && api_workgroup_size == hw_workgroup_size &&
       ((shader->info.workgroup_size[0] == 1) + (shader->info.workgroup_size[1] == 1) +
        (shader->info.workgroup_size[2] == 1)) == 2) {
      nir_lower_compute_system_values_options csv_options = {
         .lower_local_invocation_index = true,
      };
      nir_lower_compute_system_values(shader, &csv_options);
   }

   nir_validate_shader(shader, "after emitting NGG MS");
}
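
/* Example (a sketch, not part of this file): a driver would typically run this
 * pass on a mesh shader before backend compilation, roughly like the following.
 * The variable and helper names here are hypothetical.
 *
 *    bool needs_scratch_ring = false;
 *    ac_nir_lower_ngg_ms(nir, gfx_level, clipdist_enable_mask, vs_output_param_offset,
 *                        has_param_exports, &needs_scratch_ring, wave_size,
 *                        hw_workgroup_size, multiview, has_query, fast_launch_2);
 *    if (needs_scratch_ring)
 *       driver_allocate_attribute_scratch_ring(device); // hypothetical helper
 */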