1 /*
2  * Copyright © 2021 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "ac_nir.h"
8 #include "ac_nir_helpers.h"
9 #include "amdgfxregs.h"
10 #include "nir_builder.h"
11 #include "nir_xfb_info.h"
12 #include "util/u_math.h"
13 #include "util/u_vector.h"
14 
15 enum {
16    nggc_passflag_used_by_pos = 1,
17    nggc_passflag_used_by_other = 2,
18    nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
19 };
20 
21 typedef struct
22 {
23    nir_def *ssa;
24    nir_variable *var;
25 } reusable_nondeferred_variable;
26 
27 typedef struct
28 {
29    gl_varying_slot slot;
30    nir_def *chan[4];
31 } vs_output;
32 
33 typedef struct
34 {
35    const ac_nir_lower_ngg_options *options;
36 
37    nir_variable *position_value_var;
38    nir_variable *prim_exp_arg_var;
39    nir_variable *es_accepted_var;
40    nir_variable *gs_accepted_var;
41    nir_variable *gs_exported_var;
42    nir_variable *gs_vtx_indices_vars[3];
43 
44    nir_def *vtx_addr[3];
45 
46    struct u_vector reusable_nondeferred_variables;
47 
48    bool early_prim_export;
49    bool streamout_enabled;
50    bool has_user_edgeflags;
51    bool skip_primitive_id;
52    unsigned max_num_waves;
53 
54    /* LDS params */
55    unsigned pervertex_lds_bytes;
56 
57    uint64_t inputs_needed_by_pos;
58    uint64_t inputs_needed_by_others;
59 
60    nir_instr *compact_arg_stores[4];
61    nir_intrinsic_instr *overwrite_args;
62    nir_variable *repacked_rel_patch_id;
63 
64    /* clip distance */
65    nir_variable *clip_vertex_var;
66    nir_variable *clipdist_neg_mask_var;
67    bool has_clipdist;
68 
69    /* outputs */
70    ac_nir_prerast_out out;
71 } lower_ngg_nogs_state;
72 
73 typedef struct
74 {
75    const ac_nir_lower_ngg_options *options;
76 
77    nir_function_impl *impl;
78    int const_out_vtxcnt[4];
79    int const_out_prmcnt[4];
80    unsigned max_num_waves;
81    unsigned num_vertices_per_primitive;
82    nir_def *lds_addr_gs_out_vtx;
83    nir_def *lds_addr_gs_scratch;
84    unsigned lds_bytes_per_gs_out_vertex;
85    unsigned lds_offs_primflags;
86    bool output_compile_time_known;
87    bool streamout_enabled;
88    /* Outputs */
89    ac_nir_prerast_out out;
90    /* Count per stream. */
91    nir_def *vertex_count[4];
92    nir_def *primitive_count[4];
93 } lower_ngg_gs_state;
94 
95 /* Per-vertex LDS layout of culling shaders */
96 enum {
97    /* Position of the ES vertex (at the beginning for alignment reasons) */
98    lds_es_pos_x = 0,
99    lds_es_pos_y = 4,
100    lds_es_pos_z = 8,
101    lds_es_pos_w = 12,
102 
103    /* 1 when the vertex is accepted, 0 if it should be culled */
104    lds_es_vertex_accepted = 16,
105    /* ID of the thread which will export the current thread's vertex */
106    lds_es_exporter_tid = 17,
107    /* bit i is set when the i'th clip distance of a vertex is negative */
108    lds_es_clipdist_neg_mask = 18,
109    /* TES only, relative patch ID, less than max workgroup size */
110    lds_es_tes_rel_patch_id = 19,
111 
112    /* Repacked arguments - also listed separately for VS and TES */
113    lds_es_arg_0 = 20,
114 };
115 
116 typedef struct {
117    nir_def *num_repacked_invocations;
118    nir_def *repacked_invocation_index;
119 } wg_repack_result;
120 
121 /**
122  * Computes a horizontal sum of 8-bit packed values loaded from LDS.
123  *
124  * Each lane N will sum packed bytes 0 to N.
125  * We only care about the results from up to wave_id lanes.
126  * (Other lanes are not deactivated but their calculation is not used.)
127  */
128 static nir_def *
129 summarize_repack(nir_builder *b, nir_def *packed_counts, bool mask_lane_id, unsigned num_lds_dwords)
130 {
131    /* We'll use shift to filter out the bytes not needed by the current lane.
132     *
133     * For each row:
134     * Need to shift by: `num_lds_dwords * 4 - 1 - lane_id_in_row` (in bytes)
135     * in order to implement an inclusive scan.
136     *
137     * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
138     * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
139     * therefore v_dot can get rid of the unneeded values.
140     *
141     * If the v_dot instruction can't be used, we left-shift the packed bytes
142     * in order to shift out the unneeded bytes and shift in zeroes instead,
143     * then we sum them using v_msad_u8.
144     */
145 
146    nir_def *lane_id = nir_load_subgroup_invocation(b);
147 
148    /* Mask lane ID so that lanes 16...31 also have the ID 0...15,
149     * in order to perform a second horizontal sum in parallel when needed.
150     */
151    if (mask_lane_id)
152       lane_id = nir_iand_imm(b, lane_id, 0xf);
153 
154    nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8);
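   /* i.e. shift = (num_lds_dwords * 4 - 1 - lane_id) * 8, expressed in bits */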
155    assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8);
156    bool use_dot = b->shader->options->has_udot_4x8;
157 
158    if (num_lds_dwords == 1) {
159       /* Broadcast the packed data we read from LDS
160        * (to the first 16 lanes of the row, but we only care up to num_waves).
161        */
162       nir_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
163 
164       /* Horizontally add the packed bytes. */
165       if (use_dot) {
166          nir_def *dot_op = nir_ushr(b, nir_imm_int(b, 0x01010101), shift);
167          return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
168       } else {
169          nir_def *sad_op = nir_ishl(b, packed, shift);
170          return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
171       }
172    } else if (num_lds_dwords == 2) {
173       /* Broadcast the packed data we read from LDS
174        * (to the first 16 lanes of the row, but we only care up to num_waves).
175        */
176       nir_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
177       nir_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
178 
179       /* Horizontally add the packed bytes. */
180       if (use_dot) {
181          nir_def *dot_op = nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift);
182          nir_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
183          return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
184       } else {
185          nir_def *sad_op = nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift);
186          nir_def *sum = nir_msad_4x8(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
187          return nir_msad_4x8(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
188       }
189    } else {
190       unreachable("Unimplemented NGG wave count");
191    }
192 }
193 
194 /**
195  * Repacks invocations in the current workgroup to eliminate gaps between them.
196  *
197  * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave) for each repack.
198  * Assumes that all invocations in the workgroup are active (exec = -1).
199  */
200 static void
201 repack_invocations_in_workgroup(nir_builder *b, nir_def **input_bool,
202                                 wg_repack_result *results, const unsigned num_repacks,
203                                 nir_def *lds_addr_base, unsigned max_num_waves,
204                                 unsigned wave_size)
205 {
206    /* We can currently only do up to 2 repacks at a time. */
207    assert(num_repacks <= 2);
208 
209    /* STEP 1. Count surviving invocations in the current wave.
210     *
211     * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
212     */
213 
214    nir_def *input_mask[2];
215    nir_def *surviving_invocations_in_current_wave[2];
216 
217    for (unsigned i = 0; i < num_repacks; ++i) {
218       /* Input should be boolean: 1 if the current invocation should survive the repack. */
219       assert(input_bool[i]->bit_size == 1);
220 
221       input_mask[i] = nir_ballot(b, 1, wave_size, input_bool[i]);
222       surviving_invocations_in_current_wave[i] = nir_bit_count(b, input_mask[i]);
223    }
224 
225    /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
226    if (max_num_waves == 1) {
227       for (unsigned i = 0; i < num_repacks; ++i) {
228          results[i].num_repacked_invocations = surviving_invocations_in_current_wave[i];
229          results[i].repacked_invocation_index = nir_mbcnt_amd(b, input_mask[i], nir_imm_int(b, 0));
230       }
231       return;
232    }
233 
234    /* STEP 2. Waves tell each other their number of surviving invocations.
235     *
236     * Row 0 (lanes 0-15) performs the first repack, and Row 1 (lanes 16-31) the second in parallel.
237     * Each wave activates only its first lane per row, which stores the number of surviving
238     * invocations in that wave into the LDS for that repack, then reads the numbers from every wave.
239     *
240     * The workgroup size of NGG shaders is at most 256, which means
241     * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
242     * For each repack:
243     * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
244     * (The maximum is 4 dwords for 2 repacks in Wave32 mode.)
245     */
246 
247    const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
248    assert(num_lds_dwords <= 2);
249 
250    /* The first lane of each row (per repack) needs to access the LDS. */
251    const unsigned ballot = num_repacks == 1 ? 1 : 0x10001;
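   /* Bit 0 selects lane 0 of row 0; with 2 repacks, bit 16 also selects lane 16, the first lane of row 1. */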
252 
253    nir_def *wave_id = nir_load_subgroup_id(b);
254    nir_def *dont_care = nir_undef(b, 1, num_lds_dwords * 32);
255    nir_def *packed_counts = NULL;
256 
257    nir_if *if_use_lds = nir_push_if(b, nir_inverse_ballot(b, 1, nir_imm_intN_t(b, ballot, wave_size)));
258    {
259       nir_def *store_val = surviving_invocations_in_current_wave[0];
260 
261       if (num_repacks == 2) {
262          nir_def *lane_id_0 = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 1, wave_size));
263          nir_def *off = nir_bcsel(b, lane_id_0, nir_imm_int(b, 0), nir_imm_int(b, num_lds_dwords * 4));
264          lds_addr_base = nir_iadd_nuw(b, lds_addr_base, off);
265          store_val = nir_bcsel(b, lane_id_0, store_val, surviving_invocations_in_current_wave[1]);
266       }
267 
268       nir_def *store_byte = nir_u2u8(b, store_val);
269       nir_def *lds_offset = nir_iadd(b, lds_addr_base, wave_id);
270       nir_store_shared(b, store_byte, lds_offset);
271 
272       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
273                      .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
274 
275       packed_counts = nir_load_shared(b, 1, num_lds_dwords * 32, lds_addr_base, .align_mul = 8u);
276    }
277    nir_pop_if(b, if_use_lds);
278 
279    packed_counts = nir_if_phi(b, packed_counts, dont_care);
280 
281    /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
282     *
283     * By now, every wave knows the number of surviving invocations in all waves.
284     * Each number is 1 byte, and they are packed into up to 2 dwords.
285     *
286     * For each row (of 16 lanes):
287     * Each lane N (in the row) will sum the number of surviving invocations inclusively from waves 0 to N.
288     * If the workgroup has M waves, then each row will use only its first M lanes for this.
289     * (Other lanes are not deactivated but their calculation is not used.)
290     *
291     * - We read the sum from the lane whose id  (in the row) is the current wave's id,
292     *   and subtract the number of its own surviving invocations.
293     *   Add the masked bitcount to this, and we get the repacked invocation index.
294     * - We read the sum from the lane whose id (in the row) is the number of waves in the workgroup minus 1.
295     *   This is the total number of surviving invocations in the workgroup.
296     */
297 
298    nir_def *num_waves = nir_load_num_subgroups(b);
299    nir_def *sum = summarize_repack(b, packed_counts, num_repacks == 2, num_lds_dwords);
300 
301    for (unsigned i = 0; i < num_repacks; ++i) {
302       nir_def *index_base_lane = nir_iadd_imm_nuw(b, wave_id, i * 16);
303       nir_def *num_invocations_lane = nir_iadd_imm(b, num_waves, i * 16 - 1);
304       nir_def *wg_repacked_index_base =
305          nir_isub(b, nir_read_invocation(b, sum, index_base_lane), surviving_invocations_in_current_wave[i]);
306       results[i].num_repacked_invocations =
307          nir_read_invocation(b, sum, num_invocations_lane);
308       results[i].repacked_invocation_index =
309          nir_mbcnt_amd(b, input_mask[i], wg_repacked_index_base);
310    }
311 }
312 
313 static nir_def *
314 pervertex_lds_addr(nir_builder *b, nir_def *vertex_idx, unsigned per_vtx_bytes)
315 {
316    return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
317 }
318 
319 static void
320 alloc_vertices_and_primitives(nir_builder *b,
321                               nir_def *num_vtx,
322                               nir_def *num_prim)
323 {
324    /* The caller should only call this conditionally on wave 0.
325     *
326     * Send GS Alloc Request message from the first wave of the group to SPI.
327     * Message payload (in the m0 register) is:
328     * - bits 0..10: number of vertices in group
329     * - bits 12..22: number of primitives in group
330     */
331 
332    nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prim, 12), num_vtx);
333    nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ);
334 }
335 
336 static void
337 alloc_vertices_and_primitives_gfx10_workaround(nir_builder *b,
338                                                nir_def *num_vtx,
339                                                nir_def *num_prim)
340 {
341    /* HW workaround for a GPU hang with 100% culling on GFX10.
342     * We always have to export at least 1 primitive.
343     * Export a degenerate triangle using vertex 0 for all 3 vertices.
344     *
345     * NOTE: We rely on the caller to set the vertex count also to 0 when the primitive count is 0.
346     */
347    nir_def *is_prim_cnt_0 = nir_ieq_imm(b, num_prim, 0);
348    nir_if *if_prim_cnt_0 = nir_push_if(b, is_prim_cnt_0);
349    {
350       nir_def *one = nir_imm_int(b, 1);
351       alloc_vertices_and_primitives(b, one, one);
352 
353       nir_def *tid = nir_load_subgroup_invocation(b);
354       nir_def *is_thread_0 = nir_ieq_imm(b, tid, 0);
355       nir_if *if_thread_0 = nir_push_if(b, is_thread_0);
356       {
357          /* The vertex indices are 0, 0, 0. */
358          nir_export_amd(b, nir_imm_zero(b, 4, 32),
359                         .base = V_008DFC_SQ_EXP_PRIM,
360                         .flags = AC_EXP_FLAG_DONE,
361                         .write_mask = 1);
362 
363          /* The HW culls primitives with NaN. -1 is also a NaN and saves
364           * a dword in the binary because the constant can be inlined.
365           */
366          nir_export_amd(b, nir_imm_ivec4(b, -1, -1, -1, -1),
367                         .base = V_008DFC_SQ_EXP_POS,
368                         .flags = AC_EXP_FLAG_DONE,
369                         .write_mask = 0xf);
370       }
371       nir_pop_if(b, if_thread_0);
372    }
373    nir_push_else(b, if_prim_cnt_0);
374    {
375       alloc_vertices_and_primitives(b, num_vtx, num_prim);
376    }
377    nir_pop_if(b, if_prim_cnt_0);
378 }
379 
380 static void
381 ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *s)
382 {
383    for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
384       s->gs_vtx_indices_vars[v] = nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx_addr");
385 
386       nir_def *vtx;
387 
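      /* GFX12 packs each vertex index into 8 bits at a 9-bit stride of the passthrough
       * primitive; older passthrough mode uses 9-bit indices at a 10-bit stride; otherwise
       * the indices come from the 16-bit halves of the GS vertex offset arguments.
       */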
388       if (s->options->gfx_level >= GFX12) {
389          vtx = nir_ubfe_imm(b, nir_load_packed_passthrough_primitive_amd(b), 9 * v, 8);
390       } else if (s->options->passthrough) {
391          vtx = nir_ubfe_imm(b, nir_load_packed_passthrough_primitive_amd(b), 10 * v, 9);
392       } else {
393          vtx = nir_ubfe_imm(b, nir_load_gs_vertex_offset_amd(b, .base = v / 2u),
394                             (v & 1u) * 16u, 16u);
395       }
396 
397       nir_store_var(b, s->gs_vtx_indices_vars[v], vtx, 0x1);
398    }
399 }
400 
401 static nir_def *
402 emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *s)
403 {
404    if (s->options->gfx_level >= GFX12 || s->options->passthrough) {
405       return nir_load_packed_passthrough_primitive_amd(b);
406    } else {
407       nir_def *vtx_idx[3] = {0};
408 
409       for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v)
410          vtx_idx[v] = nir_load_var(b, s->gs_vtx_indices_vars[v]);
411 
412       return ac_nir_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive, vtx_idx, NULL,
413                                         s->options->gfx_level);
414    }
415 }
416 
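/* The merged wave info packs the wave's ES (vertex) thread count in bits 0..7
 * and its GS (primitive) thread count in bits 8..15.
 */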
417 static nir_def *
418 has_input_vertex(nir_builder *b)
419 {
420    return nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b));
421 }
422 
423 static nir_def *
424 has_input_primitive(nir_builder *b)
425 {
426    return nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b), .base = 8);
427 }
428 
429 static void
430 nogs_prim_gen_query(nir_builder *b, lower_ngg_nogs_state *s)
431 {
432    if (!s->options->has_gen_prim_query)
433       return;
434 
435    nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
436    {
437       /* Activate only 1 lane and add the number of primitives to the query result. */
438       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
439       {
440          /* Number of input primitives in the current wave. */
441          nir_def *num_input_prims = nir_ubfe_imm(b, nir_load_merged_wave_info_amd(b),
442                                                      8, 8);
443 
444          /* Add to stream 0 primitive generated counter. */
445          nir_atomic_add_gen_prim_count_amd(b, num_input_prims, .stream_id = 0);
446       }
447       nir_pop_if(b, if_elected);
448    }
449    nir_pop_if(b, if_shader_query);
450 }
451 
452 static void
453 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *s, nir_def *arg)
454 {
455    nir_if *if_gs_thread = nir_push_if(b, nir_load_var(b, s->gs_exported_var));
456    {
457       if (!arg)
458          arg = emit_ngg_nogs_prim_exp_arg(b, s);
459 
460       /* pack user edge flag info into arg */
461       if (s->has_user_edgeflags) {
462          /* Workgroup barrier: wait for ES threads to store user edge flags to LDS */
463          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
464                             .memory_scope = SCOPE_WORKGROUP,
465                             .memory_semantics = NIR_MEMORY_ACQ_REL,
466                             .memory_modes = nir_var_mem_shared);
467 
468          unsigned edge_flag_bits = ac_get_all_edge_flag_bits(s->options->gfx_level);
469          nir_def *mask = nir_imm_intN_t(b, ~edge_flag_bits, 32);
470 
471          unsigned edge_flag_offset = 0;
472          if (s->streamout_enabled) {
473             unsigned packed_location =
474                util_bitcount64(b->shader->info.outputs_written &
475                                BITFIELD64_MASK(VARYING_SLOT_EDGE));
476             edge_flag_offset = packed_location * 16;
477          }
478 
479          for (int i = 0; i < s->options->num_vertices_per_primitive; i++) {
480             nir_def *vtx_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
481             nir_def *addr = pervertex_lds_addr(b, vtx_idx, s->pervertex_lds_bytes);
482             nir_def *edge = nir_load_shared(b, 1, 32, addr, .base = edge_flag_offset);
483 
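            /* Place the edge flag above the corresponding vertex index in the export
             * argument: bit 8 + i * 9 on GFX12, bit 9 + i * 10 on older generations.
             */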
484             if (s->options->gfx_level >= GFX12)
485                mask = nir_ior(b, mask, nir_ishl_imm(b, edge, 8 + i * 9));
486             else
487                mask = nir_ior(b, mask, nir_ishl_imm(b, edge, 9 + i * 10));
488          }
489          arg = nir_iand(b, arg, mask);
490       }
491 
492       ac_nir_export_primitive(b, arg, NULL);
493 
494       /* Store implicit primitive ID when configured as a per-primitive output on GFX10.3.
495        * Because this uses the export space, do it together with the primitive export.
496        */
497       if (s->options->gfx_level == GFX10_3 && s->options->export_primitive_id_per_prim) {
498          const uint8_t offset = s->options->vs_output_param_offset[VARYING_SLOT_PRIMITIVE_ID];
499          nir_def *prim_id = nir_load_primitive_id(b);
500          nir_def *undef = nir_undef(b, 1, 32);
501          ac_nir_prerast_out out = {
502             .infos = {{.components_mask = 1, .as_varying_mask = 1}},
503             .outputs = {{prim_id, undef, undef, undef}}
504          };
505 
506          ac_nir_export_parameters(b, &offset, 1, 0, &out);
507       }
508    }
509    nir_pop_if(b, if_gs_thread);
510 }
511 
512 static void
513 emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *s)
514 {
515    nir_def *gs_thread =
516       s->gs_accepted_var ? nir_load_var(b, s->gs_accepted_var) : has_input_primitive(b);
517 
518    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
519    {
520       /* Copy Primitive IDs from GS threads to the LDS address
521        * corresponding to the ES thread of the provoking vertex.
522        * It will be exported as a per-vertex attribute.
523        */
524       nir_def *gs_vtx_indices[3];
525       for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++)
526          gs_vtx_indices[i] = nir_load_var(b, s->gs_vtx_indices_vars[i]);
527 
528       nir_def *provoking_vertex = nir_load_provoking_vtx_in_prim_amd(b);
529       nir_def *provoking_vtx_idx = nir_select_from_ssa_def_array(
530          b, gs_vtx_indices, s->options->num_vertices_per_primitive, provoking_vertex);
531 
532       nir_def *prim_id = nir_load_primitive_id(b);
533       nir_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, s->pervertex_lds_bytes);
534 
535       /* The primitive id is always stored at the end of the vertex's LDS space. */
536       nir_store_shared(b, prim_id, addr, .base = s->pervertex_lds_bytes - 4);
537    }
538    nir_pop_if(b, if_gs_thread);
539 }
540 
541 /* Store implicit primitive ID when configured as a per-primitive output on GFX11+.
542  * This is done separately from the primitive export on GFX11 in order to
543  * optimize attribute ring access.
544  */
545 static void
546 emit_ngg_nogs_prim_id_store_per_prim_to_attr_ring(nir_builder *b, lower_ngg_nogs_state *s)
547 {
548    assert(s->options->gfx_level >= GFX11);
549 
550    nir_def *is_gs_thread = nir_load_var(b, s->gs_exported_var);
551    nir_def *highest_gs_thread = nir_ufind_msb(b, nir_ballot(b, 1, s->options->wave_size, is_gs_thread));
552    nir_def *max_num_gs_threads = nir_iadd_imm_nuw(b, highest_gs_thread, 1);
553 
554    const uint8_t offset = s->options->vs_output_param_offset[VARYING_SLOT_PRIMITIVE_ID];
555    ac_nir_prerast_out out = {
556       .infos = {{.components_mask = 1, .as_varying_mask = 1}},
557       .outputs = {{nir_load_primitive_id(b), NULL, NULL, NULL}}
558    };
559 
560    ac_nir_store_parameters_to_attr_ring(b, &offset, 1, 0, &out, NULL, max_num_gs_threads);
561 }
562 
563 static void
564 emit_store_ngg_nogs_es_primitive_id(nir_builder *b, lower_ngg_nogs_state *s)
565 {
566    nir_def *prim_id = NULL;
567 
568    if (b->shader->info.stage == MESA_SHADER_VERTEX) {
569       /* LDS address where the primitive ID is stored */
570       nir_def *thread_id_in_threadgroup = nir_load_local_invocation_index(b);
571       nir_def *addr =
572          pervertex_lds_addr(b, thread_id_in_threadgroup, s->pervertex_lds_bytes);
573 
574       /* Load primitive ID from LDS */
575       prim_id = nir_load_shared(b, 1, 32, addr, .base = s->pervertex_lds_bytes - 4);
576    } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
577       /* Just use tess eval primitive ID, which is the same as the patch ID. */
578       prim_id = nir_load_primitive_id(b);
579    }
580 
581    s->out.outputs[VARYING_SLOT_PRIMITIVE_ID][0] = prim_id;
582    s->out.infos[VARYING_SLOT_PRIMITIVE_ID].as_varying_mask |= 1;
583 
584    /* Update outputs_written to reflect that the pass added a new output. */
585    b->shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
586 }
587 
588 static void
589 add_clipdist_bit(nir_builder *b, nir_def *dist, unsigned index, nir_variable *mask)
590 {
591    nir_def *is_neg = nir_flt_imm(b, dist, 0);
592    nir_def *neg_mask = nir_ishl_imm(b, nir_b2i32(b, is_neg), index);
593    neg_mask = nir_ior(b, neg_mask, nir_load_var(b, mask));
594    nir_store_var(b, mask, neg_mask, 1);
595 }
596 
597 static bool
598 remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
599 {
600    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *) state;
601 
602    if (instr->type != nir_instr_type_intrinsic)
603       return false;
604 
605    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
606 
607    /* These are not allowed in VS / TES */
608    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
609           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
610 
611    /* We are only interested in output stores now */
612    if (intrin->intrinsic != nir_intrinsic_store_output)
613       return false;
614 
615    b->cursor = nir_before_instr(instr);
616 
617    /* no indirect output */
618    assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
619 
620    unsigned writemask = nir_intrinsic_write_mask(intrin);
621    unsigned component = nir_intrinsic_component(intrin);
622    nir_def *store_val = intrin->src[0].ssa;
623 
624    /* Position output - store the value to a variable, remove output store */
625    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
626    switch (io_sem.location) {
627    case VARYING_SLOT_POS:
628       ac_nir_store_var_components(b, s->position_value_var, store_val, component, writemask);
629       break;
630    case VARYING_SLOT_CLIP_DIST0:
631    case VARYING_SLOT_CLIP_DIST1: {
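      /* CLIP_DIST0 holds clip/cull distances 0..3, CLIP_DIST1 holds distances 4..7. */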
632       unsigned base = io_sem.location == VARYING_SLOT_CLIP_DIST1 ? 4 : 0;
633       base += component;
634 
635       /* valid clipdist component mask */
636       unsigned mask = (s->options->clip_cull_dist_mask >> base) & writemask;
637       u_foreach_bit(i, mask) {
638          add_clipdist_bit(b, nir_channel(b, store_val, i), base + i,
639                           s->clipdist_neg_mask_var);
640          s->has_clipdist = true;
641       }
642       break;
643    }
644    case VARYING_SLOT_CLIP_VERTEX:
645       ac_nir_store_var_components(b, s->clip_vertex_var, store_val, component, writemask);
646       break;
647    default:
648       break;
649    }
650 
651    /* Remove all output stores */
652    nir_instr_remove(instr);
653    return true;
654 }
655 
656 static void
657 remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *s)
658 {
659    nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
660                                 nir_metadata_control_flow, s);
661 
662    /* Remove dead code resulting from the deleted outputs. */
663    bool progress;
664    do {
665       progress = false;
666       NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
667       NIR_PASS(progress, culling_shader, nir_opt_dce);
668       NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
669    } while (progress);
670 }
671 
672 static void
673 rewrite_uses_to_var(nir_builder *b, nir_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
674 {
675    if (old_def->parent_instr->type == nir_instr_type_load_const)
676       return;
677 
678    b->cursor = nir_after_instr(old_def->parent_instr);
679    if (b->cursor.instr->type == nir_instr_type_phi)
680       b->cursor = nir_after_phis(old_def->parent_instr->block);
681 
682    nir_def *pos_val_rep = nir_load_var(b, replacement_var);
683    nir_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
684 
685    if (old_def->num_components > 1) {
686       /* old_def uses a swizzled vector component.
687        * There is no way to replace the uses of just a single vector component,
688        * so instead create a new vector and replace all uses of the old vector.
689        */
690       nir_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
691       for (unsigned j = 0; j < old_def->num_components; ++j)
692          old_def_elements[j] = nir_channel(b, old_def, j);
693       replacement = nir_vec(b, old_def_elements, old_def->num_components);
694    }
695 
696    nir_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
697 }
698 
699 static bool
700 remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
701 {
702    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *) state;
703 
704    if (instr->type != nir_instr_type_intrinsic)
705       return false;
706 
707    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
708 
709    /* These are not allowed in VS / TES */
710    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
711           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
712 
713    /* We are only interested in output stores now */
714    if (intrin->intrinsic != nir_intrinsic_store_output)
715       return false;
716 
717    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
718    if (io_sem.location != VARYING_SLOT_POS)
719       return false;
720 
721    b->cursor = nir_before_instr(instr);
722 
723    /* In case other outputs use what we calculated for pos,
724     * try to avoid calculating it again by rewriting the usages
725     * of the store components here.
726     */
727    nir_def *store_val = intrin->src[0].ssa;
728    unsigned store_pos_component = nir_intrinsic_component(intrin);
729 
730    nir_instr_remove(instr);
731 
732    if (store_val->parent_instr->type == nir_instr_type_alu) {
733       nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
734       if (nir_op_is_vec_or_mov(alu->op)) {
735          /* Output store uses a vector, we can easily rewrite uses of each vector element. */
736 
737          unsigned num_vec_src = 0;
738          if (alu->op == nir_op_mov)
739             num_vec_src = 1;
740          else if (alu->op == nir_op_vec2)
741             num_vec_src = 2;
742          else if (alu->op == nir_op_vec3)
743             num_vec_src = 3;
744          else if (alu->op == nir_op_vec4)
745             num_vec_src = 4;
746          assert(num_vec_src);
747 
748          /* Remember the current components whose uses we wish to replace.
749           * This is needed because rewriting one source can affect the others too.
750           */
751          nir_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
752          for (unsigned i = 0; i < num_vec_src; i++)
753             vec_comps[i] = alu->src[i].src.ssa;
754 
755          for (unsigned i = 0; i < num_vec_src; i++)
756             rewrite_uses_to_var(b, vec_comps[i], s->position_value_var, store_pos_component + i);
757       } else {
758          rewrite_uses_to_var(b, store_val, s->position_value_var, store_pos_component);
759       }
760    } else {
761       rewrite_uses_to_var(b, store_val, s->position_value_var, store_pos_component);
762    }
763 
764    return true;
765 }
766 
767 static void
768 remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *s)
769 {
770    nir_shader_instructions_pass(shader, remove_extra_pos_output,
771                                 nir_metadata_control_flow,
772                                 s);
773 }
774 
775 static bool
776 remove_compacted_arg(lower_ngg_nogs_state *s, nir_builder *b, unsigned idx)
777 {
778    nir_instr *store_instr = s->compact_arg_stores[idx];
779    if (!store_instr)
780       return false;
781 
782    /* Simply remove the store. */
783    nir_instr_remove(store_instr);
784 
785    /* Find the intrinsic that overwrites the shader arguments,
786     * and change its corresponding source.
787     * This will cause NIR's DCE to recognize the load and its phis as dead.
788     */
789    b->cursor = nir_before_instr(&s->overwrite_args->instr);
790    nir_def *undef_arg = nir_undef(b, 1, 32);
791    nir_def_rewrite_uses(s->overwrite_args->src[idx].ssa, undef_arg);
792 
793    s->compact_arg_stores[idx] = NULL;
794    return true;
795 }
796 
797 static bool
798 cleanup_culling_shader_after_dce(nir_shader *shader,
799                                  nir_function_impl *function_impl,
800                                  lower_ngg_nogs_state *s)
801 {
802    bool uses_vs_vertex_id = false;
803    bool uses_vs_instance_id = false;
804    bool uses_tes_u = false;
805    bool uses_tes_v = false;
806    bool uses_tes_rel_patch_id = false;
807    bool uses_tes_patch_id = false;
808 
809    bool progress = false;
810    nir_builder b = nir_builder_create(function_impl);
811 
812    nir_foreach_block_reverse_safe(block, function_impl) {
813       nir_foreach_instr_reverse_safe(instr, block) {
814          if (instr->type != nir_instr_type_intrinsic)
815             continue;
816 
817          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
818 
819          switch (intrin->intrinsic) {
820          case nir_intrinsic_sendmsg_amd:
821             goto cleanup_culling_shader_after_dce_done;
822          case nir_intrinsic_load_vertex_id:
823          case nir_intrinsic_load_vertex_id_zero_base:
824             uses_vs_vertex_id = true;
825             break;
826          case nir_intrinsic_load_instance_id:
827             uses_vs_instance_id = true;
828             break;
829          case nir_intrinsic_load_input: {
830             const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
831             if (s->options->instance_rate_inputs & BITFIELD_BIT(io_sem.location))
832                uses_vs_instance_id = true;
833             else
834                uses_vs_vertex_id = true;
835             break;
836          }
837          case nir_intrinsic_load_tess_coord:
838             uses_tes_u = uses_tes_v = true;
839             break;
840          case nir_intrinsic_load_tess_rel_patch_id_amd:
841             uses_tes_rel_patch_id = true;
842             break;
843          case nir_intrinsic_load_primitive_id:
844             if (shader->info.stage == MESA_SHADER_TESS_EVAL)
845                uses_tes_patch_id = true;
846             break;
847          default:
848             break;
849          }
850       }
851    }
852 
853    cleanup_culling_shader_after_dce_done:
854 
855    if (shader->info.stage == MESA_SHADER_VERTEX) {
856       if (!uses_vs_vertex_id)
857          progress |= remove_compacted_arg(s, &b, 0);
858       if (!uses_vs_instance_id)
859          progress |= remove_compacted_arg(s, &b, 1);
860    } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) {
861       if (!uses_tes_u)
862          progress |= remove_compacted_arg(s, &b, 0);
863       if (!uses_tes_v)
864          progress |= remove_compacted_arg(s, &b, 1);
865       if (!uses_tes_rel_patch_id)
866          progress |= remove_compacted_arg(s, &b, 3);
867       if (!uses_tes_patch_id)
868          progress |= remove_compacted_arg(s, &b, 2);
869    }
870 
871    return progress;
872 }
873 
874 /**
875  * Perform vertex compaction after culling.
876  *
877  * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
878  * 2. Surviving ES vertex invocations store their data to LDS
879  * 3. Emit GS_ALLOC_REQ
880  * 4. Repacked invocations load the vertex data from LDS
881  * 5. GS threads update their vertex indices
882  * 6. Optionally, do the same for primitives.
883  */
884 static void
885 compact_vertices_after_culling(nir_builder *b,
886                                lower_ngg_nogs_state *s,
887                                nir_variable **repacked_variables,
888                                nir_variable **gs_vtxaddr_vars,
889                                nir_def *invocation_index,
890                                nir_def *es_vertex_lds_addr,
891                                nir_def *es_exporter_tid,
892                                nir_def *num_live_vertices_in_workgroup,
893                                nir_def *gs_exporter_tid,
894                                nir_def *num_live_primitives_in_workgroup,
895                                unsigned pervertex_lds_bytes,
896                                unsigned num_repacked_variables)
897 {
898    nir_variable *es_accepted_var = s->es_accepted_var;
899    nir_variable *gs_accepted_var = s->gs_accepted_var;
900    nir_variable *position_value_var = s->position_value_var;
901    nir_variable *prim_exp_arg_var = s->prim_exp_arg_var;
902 
903    nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var));
904    {
905       nir_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
906 
907       /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
908       nir_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid);
909 
910       /* Store the current thread's position output to the exporter thread's LDS space */
911       nir_def *pos = nir_load_var(b, position_value_var);
912       nir_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x);
913 
914       /* Store the current thread's repackable arguments to the exporter thread's LDS space */
915       for (unsigned i = 0; i < num_repacked_variables; ++i) {
916          nir_def *arg_val = nir_load_var(b, repacked_variables[i]);
917          nir_intrinsic_instr *store = nir_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i);
918 
919          s->compact_arg_stores[i] = &store->instr;
920       }
921 
922       /* The TES rel patch id does not cost an extra dword */
923       if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
924          nir_def *arg_val = nir_load_var(b, s->repacked_rel_patch_id);
925          nir_intrinsic_instr *store =
926             nir_store_shared(b, nir_u2u8(b, arg_val), exporter_addr,
927                              .base = lds_es_tes_rel_patch_id);
928 
929          s->compact_arg_stores[3] = &store->instr;
930       }
931    }
932    nir_pop_if(b, if_es_accepted);
933 
934    /* TODO: Consider adding a shortcut exit.
935     * Waves that have no vertices and primitives left can s_endpgm right here.
936     */
937 
938    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
939                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
940 
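   /* After repacking, the first num_live_vertices_in_workgroup invocations each own a surviving vertex. */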
941    nir_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
942    nir_if *if_packed_es_thread = nir_push_if(b, es_survived);
943    {
944       /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
945       nir_def *exported_pos = nir_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x);
946       nir_store_var(b, position_value_var, exported_pos, 0xfu);
947 
948       /* Read the repacked arguments */
949       for (unsigned i = 0; i < num_repacked_variables; ++i) {
950          nir_def *arg_val = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i);
951          nir_store_var(b, repacked_variables[i], arg_val, 0x1u);
952       }
953 
954       if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
955          nir_def *arg_val = nir_load_shared(b, 1, 8, es_vertex_lds_addr,
956                                                 .base = lds_es_tes_rel_patch_id);
957          nir_store_var(b, s->repacked_rel_patch_id, nir_u2u32(b, arg_val), 0x1u);
958       }
959    }
960    nir_push_else(b, if_packed_es_thread);
961    {
962       nir_store_var(b, position_value_var, nir_undef(b, 4, 32), 0xfu);
963       for (unsigned i = 0; i < num_repacked_variables; ++i)
964          nir_store_var(b, repacked_variables[i], nir_undef(b, 1, 32), 0x1u);
965    }
966    nir_pop_if(b, if_packed_es_thread);
967 
968    nir_def *gs_accepted = nir_load_var(b, gs_accepted_var);
969    nir_if *if_gs_accepted = nir_push_if(b, gs_accepted);
970    {
971       nir_def *exporter_vtx_indices[3] = {0};
972 
973       /* Load the index of the ES threads that will export the current GS thread's vertices */
974       for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
975          nir_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
976          nir_def *exporter_vtx_idx = nir_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid);
977          exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
978          nir_store_var(b, s->gs_vtx_indices_vars[v], exporter_vtx_indices[v], 0x1);
979       }
980 
981       nir_def *prim_exp_arg =
982          ac_nir_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive,
983                                     exporter_vtx_indices, NULL, s->options->gfx_level);
984       nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
985    }
986    nir_pop_if(b, if_gs_accepted);
987 
988    nir_store_var(b, es_accepted_var, es_survived, 0x1u);
989 
990    if (s->options->compact_primitives) {
991       /* For primitive compaction, re-use the same LDS space that we used for
992        * vertex compaction, so we need to wait until vertex threads are finished reading it.
993        * Considering we only need 1 DWORD per primitive, let's assume we always have enough space,
994        * since vertex compaction requires at least 5 DWORDs per vertex.
995        */
996       nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
997                      .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
998 
999       if_gs_accepted = nir_push_if(b, gs_accepted);
1000       {
1001          nir_def *exporter_addr = pervertex_lds_addr(b, gs_exporter_tid, pervertex_lds_bytes);
1002          nir_def *prim_exp_arg = nir_load_var(b, prim_exp_arg_var);
1003 
1004          /* Store the primitive export argument into the address of the exporter thread. */
1005          nir_store_shared(b, prim_exp_arg, exporter_addr, .base = lds_es_pos_x);
1006       }
1007       nir_pop_if(b, if_gs_accepted);
1008 
1009       nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1010                      .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1011 
1012       nir_def *gs_survived = nir_ilt(b, invocation_index, num_live_primitives_in_workgroup);
1013       nir_if *if_packed_gs_thread = nir_push_if(b, gs_survived);
1014       {
1015          /* Load the primitive export argument that the current thread will export. */
1016          nir_def *prim_exp_arg = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_pos_x);
1017 
1018          nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
1019       }
1020       nir_push_else(b, if_packed_gs_thread);
1021       {
1022          nir_store_var(b, prim_exp_arg_var, nir_undef(b, 1, 32), 0x1u);
1023       }
1024       nir_pop_if(b, if_packed_gs_thread);
1025 
1026       nir_store_var(b, gs_accepted_var, gs_survived, 0x1u);
1027       nir_store_var(b, s->gs_exported_var, gs_survived, 0x1u);
1028    }
1029 }
1030 
1031 static void
1032 analyze_shader_before_culling_walk(nir_def *ssa,
1033                                    uint8_t flag,
1034                                    lower_ngg_nogs_state *s)
1035 {
1036    nir_instr *instr = ssa->parent_instr;
1037    uint8_t old_pass_flags = instr->pass_flags;
1038    instr->pass_flags |= flag;
1039 
1040    if (instr->pass_flags == old_pass_flags)
1041       return; /* Already visited. */
1042 
1043    switch (instr->type) {
1044    case nir_instr_type_intrinsic: {
1045       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1046 
1047       /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
1048       switch (intrin->intrinsic) {
1049       case nir_intrinsic_load_input: {
1050          nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
1051          uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
1052          if (instr->pass_flags & nggc_passflag_used_by_pos)
1053             s->inputs_needed_by_pos |= in_mask;
1054          else if (instr->pass_flags & nggc_passflag_used_by_other)
1055             s->inputs_needed_by_others |= in_mask;
1056          break;
1057       }
1058       default:
1059          break;
1060       }
1061 
1062       break;
1063    }
1064    case nir_instr_type_alu: {
1065       nir_alu_instr *alu = nir_instr_as_alu(instr);
1066       unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
1067 
1068       for (unsigned i = 0; i < num_srcs; ++i) {
1069          analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, s);
1070       }
1071 
1072       break;
1073    }
1074    case nir_instr_type_tex: {
1075       nir_tex_instr *tex = nir_instr_as_tex(instr);
1076       unsigned num_srcs = tex->num_srcs;
1077 
1078       for (unsigned i = 0; i < num_srcs; ++i) {
1079          analyze_shader_before_culling_walk(tex->src[i].src.ssa, flag, s);
1080       }
1081 
1082       break;
1083    }
1084    case nir_instr_type_phi: {
1085       nir_phi_instr *phi = nir_instr_as_phi(instr);
1086       nir_foreach_phi_src_safe(phi_src, phi) {
1087          analyze_shader_before_culling_walk(phi_src->src.ssa, flag, s);
1088       }
1089 
1090       break;
1091    }
1092    default:
1093       break;
1094    }
1095 }
1096 
1097 static void
1098 analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *s)
1099 {
1100    /* We need divergence info for culling shaders. */
1101    nir_divergence_analysis(shader);
1102 
1103    nir_foreach_function_impl(impl, shader) {
1104       nir_foreach_block(block, impl) {
1105          nir_foreach_instr(instr, block) {
1106             instr->pass_flags = 0;
1107 
1108             if (instr->type != nir_instr_type_intrinsic)
1109                continue;
1110 
1111             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1112             if (intrin->intrinsic != nir_intrinsic_store_output)
1113                continue;
1114 
1115             nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
1116             nir_def *store_val = intrin->src[0].ssa;
1117             uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
1118             analyze_shader_before_culling_walk(store_val, flag, s);
1119          }
1120       }
1121    }
1122 }
1123 
1124 static nir_def *
1125 find_reusable_ssa_def(nir_instr *instr)
1126 {
1127    /* Find instructions whose SSA definitions are used by both
1128     * the top and bottom parts of the shader (before and after culling).
1129     * Only in this case, it makes sense for the bottom part
1130     * to try to reuse these from the top part.
1131     */
1132    if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
1133       return NULL;
1134 
1135    switch (instr->type) {
1136    case nir_instr_type_alu: {
1137       nir_alu_instr *alu = nir_instr_as_alu(instr);
1138       if (alu->def.divergent)
1139          return NULL;
1140       /* Ignore uniform floats because they regress VGPR usage too much */
1141       if (nir_op_infos[alu->op].output_type & nir_type_float)
1142          return NULL;
1143       return &alu->def;
1144    }
1145    case nir_instr_type_intrinsic: {
1146       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1147       if (!nir_intrinsic_can_reorder(intrin) ||
1148             !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
1149             intrin->def.divergent)
1150          return NULL;
1151       return &intrin->def;
1152    }
1153    case nir_instr_type_phi: {
1154       nir_phi_instr *phi = nir_instr_as_phi(instr);
1155       if (phi->def.divergent)
1156          return NULL;
1157       return &phi->def;
1158    }
1159    default:
1160       return NULL;
1161    }
1162 }
1163 
1164 static const struct glsl_type *
1165 glsl_uint_type_for_ssa(nir_def *ssa)
1166 {
1167    enum glsl_base_type base_type = GLSL_TYPE_UINT;
1168    switch (ssa->bit_size) {
1169    case 8: base_type = GLSL_TYPE_UINT8; break;
1170    case 16: base_type = GLSL_TYPE_UINT16; break;
1171    case 32: base_type = GLSL_TYPE_UINT; break;
1172    case 64: base_type = GLSL_TYPE_UINT64; break;
1173    default: return NULL;
1174    }
1175 
1176    return ssa->num_components == 1
1177           ? glsl_scalar_type(base_type)
1178           : glsl_vector_type(base_type, ssa->num_components);
1179 }
1180 
1181 /**
1182  * Save the reusable SSA definitions to variables so that the
1183  * bottom shader part can reuse them from the top part.
1184  *
1185  * 1. We create a new function temporary variable for reusables,
1186  *    and insert a store+load.
1187  * 2. The shader is cloned (the top part is created), then the
1188  *    control flow is reinserted (for the bottom part.)
1189  * 3. For reusables, we delete the variable stores from the
1190  *    bottom part. This will make them use the variables from
1191  *    the top part and DCE the redundant instructions.
1192  */
1193 static void
1194 save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *s)
1195 {
1196    ASSERTED int vec_ok = u_vector_init(&s->reusable_nondeferred_variables, 4, sizeof(reusable_nondeferred_variable));
1197    assert(vec_ok);
1198 
1199    /* Upper limit on reusable uniforms in order to reduce SGPR spilling. */
1200    unsigned remaining_reusable_uniforms = 48;
1201 
1202    nir_block *block = nir_start_block(b->impl);
1203    while (block) {
1204       /* Process the instructions in the current block. */
1205       nir_foreach_instr_safe(instr, block) {
1206          /* Determine if we can reuse the current SSA value.
1207           * When vertex compaction is used, it is possible that the same shader invocation
1208           * processes a different vertex in the top and bottom part of the shader.
1209           * Therefore, we only reuse uniform values.
1210           */
1211          nir_def *ssa = find_reusable_ssa_def(instr);
1212          if (!ssa)
1213             continue;
1214 
1215          /* Determine a suitable type for the SSA value. */
1216          const struct glsl_type *t = glsl_uint_type_for_ssa(ssa);
1217          if (!t)
1218             continue;
1219 
1220          if (!ssa->divergent) {
1221             if (remaining_reusable_uniforms < ssa->num_components)
1222                continue;
1223 
1224             remaining_reusable_uniforms -= ssa->num_components;
1225          }
1226 
1227          reusable_nondeferred_variable *saved = (reusable_nondeferred_variable *) u_vector_add(&s->reusable_nondeferred_variables);
1228          assert(saved);
1229 
1230          /* Create a new NIR variable where we store the reusable value.
1231           * Then, we reload the variable and replace the uses of the value
1232           * with the reloaded variable.
1233           */
1234          saved->var = nir_local_variable_create(b->impl, t, NULL);
1235          saved->ssa = ssa;
1236 
1237          b->cursor = instr->type == nir_instr_type_phi
1238                      ? nir_after_instr_and_phis(instr)
1239                      : nir_after_instr(instr);
1240          nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
1241          nir_def *reloaded = nir_load_var(b, saved->var);
1242          nir_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
1243       }
1244 
1245       /* Look at the next CF node. */
1246       nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
1247       if (next_cf_node) {
1248          /* It makes no sense to try to reuse things from within loops. */
1249          bool next_is_loop = next_cf_node->type == nir_cf_node_loop;
1250 
1251          /* Don't reuse if we're in divergent control flow.
1252           *
1253           * Thanks to vertex repacking, the same shader invocation may process a different vertex
1254           * in the top and bottom part, and it's even possible that this different vertex was initially
1255           * processed in a different wave. So the two parts may take a different divergent code path.
1256           * Therefore, these variables in divergent control flow may stay undefined.
1257           *
1258           * Note that this problem doesn't exist if vertices are not repacked or if the
1259           * workgroup only has a single wave.
1260           */
1261          bool next_is_divergent_if =
1262             next_cf_node->type == nir_cf_node_if &&
1263             nir_src_is_divergent(&nir_cf_node_as_if(next_cf_node)->condition);
1264 
1265          if (next_is_loop || next_is_divergent_if) {
1266             block = nir_cf_node_cf_tree_next(next_cf_node);
1267             continue;
1268          }
1269       }
1270 
1271       /* Go to the next block. */
1272       block = nir_block_cf_tree_next(block);
1273    }
1274 }
1275 
1276 /**
1277  * Reuses suitable variables from the top part of the shader,
1278  * by deleting their stores from the bottom part.
1279  */
1280 static void
1281 apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *s)
1282 {
1283    if (!u_vector_length(&s->reusable_nondeferred_variables)) {
1284       u_vector_finish(&s->reusable_nondeferred_variables);
1285       return;
1286    }
1287 
1288    nir_foreach_block_reverse_safe(block, b->impl) {
1289       nir_foreach_instr_reverse_safe(instr, block) {
1290          if (instr->type != nir_instr_type_intrinsic)
1291             continue;
1292          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1293 
1294          /* When we find any of these intrinsics, it means
1295           * we've reached the top part and we must stop.
1296           */
1297          if (intrin->intrinsic == nir_intrinsic_sendmsg_amd)
1298             goto done;
1299 
1300          if (intrin->intrinsic != nir_intrinsic_store_deref)
1301             continue;
1302          nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1303          if (deref->deref_type != nir_deref_type_var)
1304             continue;
1305 
1306          reusable_nondeferred_variable *saved;
1307          u_vector_foreach(saved, &s->reusable_nondeferred_variables) {
1308             if (saved->var == deref->var) {
1309                nir_instr_remove(instr);
1310             }
1311          }
1312       }
1313    }
1314 
1315    done:
1316    u_vector_finish(&s->reusable_nondeferred_variables);
1317 }
1318 
1319 static void
1320 cull_primitive_accepted(nir_builder *b, void *state)
1321 {
1322    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *)state;
1323 
1324    nir_store_var(b, s->gs_accepted_var, nir_imm_true(b), 0x1u);
1325 
1326    /* Store the accepted state to LDS for ES threads */
1327    for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx)
1328       nir_store_shared(b, nir_imm_intN_t(b, 1, 8), s->vtx_addr[vtx], .base = lds_es_vertex_accepted);
1329 }
1330 
1331 static void
1332 clipdist_culling_es_part(nir_builder *b, lower_ngg_nogs_state *s,
1333                          nir_def *es_vertex_lds_addr)
1334 {
1335    /* No gl_ClipDistance output is used, but user-defined clip planes are enabled. */
1336    if (s->options->user_clip_plane_enable_mask && !s->has_clipdist) {
1337       /* use gl_ClipVertex if defined */
1338       nir_variable *clip_vertex_var =
1339          b->shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CLIP_VERTEX) ?
1340          s->clip_vertex_var : s->position_value_var;
1341       nir_def *clip_vertex = nir_load_var(b, clip_vertex_var);
1342 
1343       /* clip against user defined clip planes */
1344       for (unsigned i = 0; i < 8; i++) {
1345          if (!(s->options->user_clip_plane_enable_mask & BITFIELD_BIT(i)))
1346             continue;
1347 
1348          nir_def *plane = nir_load_user_clip_plane(b, .ucp_id = i);
1349          nir_def *dist = nir_fdot(b, clip_vertex, plane);
1350          add_clipdist_bit(b, dist, i, s->clipdist_neg_mask_var);
1351       }
1352 
1353       s->has_clipdist = true;
1354    }
1355 
1356    /* Store clipdist_neg_mask to LDS for culling later in the GS threads. */
1357    if (s->has_clipdist) {
1358       nir_def *mask = nir_load_var(b, s->clipdist_neg_mask_var);
1359       nir_store_shared(b, nir_u2u8(b, mask), es_vertex_lds_addr,
1360                        .base = lds_es_clipdist_neg_mask);
1361    }
1362 }
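     /* Worked example of the clip-distance culling (illustrative, assuming only user
      * clip plane 0 is enabled): each ES thread sets bit 0 of its vertex's
      * clipdist_neg_mask when the dot product of the clip vertex and plane 0 is
      * negative. The GS thread in add_deferred_attribute_culling() then ANDs the
      * masks of all vertices of the primitive:
      *
      *    vertex masks 0b1, 0b1, 0b1 -> AND = 0b1 -> all vertices behind the plane -> culled
      *    vertex masks 0b1, 0b0, 0b1 -> AND = 0b0 -> at least one vertex in front  -> kept
      */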
1363 
1364 static unsigned
1365 ngg_nogs_get_culling_pervertex_lds_size(gl_shader_stage stage,
1366                                         bool uses_instance_id,
1367                                         bool uses_primitive_id,
1368                                         unsigned *num_repacked_variables)
1369 {
1370    /* Culling shaders must repack some variables because
1371     * the same shader invocation may process different vertices
1372     * before and after the culling algorithm.
1373     */
1374 
1375    unsigned num_repacked;
1376    if (stage == MESA_SHADER_VERTEX) {
1377       /* Vertex shaders repack:
1378        * - Vertex ID
1379        * - Instance ID (only if used)
1380        */
1381       num_repacked = uses_instance_id ? 2 : 1;
1382    } else {
1383       /* Tess eval shaders repack:
1384        * - U, V coordinates
1385        * - primitive ID (aka. patch id, only if used)
1386        * - relative patch id (not included here because doesn't need a dword)
1387        */
1388       assert(stage == MESA_SHADER_TESS_EVAL);
1389       num_repacked = uses_primitive_id ? 3 : 2;
1390    }
1391 
1392    if (num_repacked_variables)
1393       *num_repacked_variables = num_repacked;
1394 
1395    /* Round up to an odd number of dwords to reduce LDS bank conflicts. */
1396    return (lds_es_arg_0 + num_repacked * 4u) | 4u;
1397 }
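     /* Worked examples for the size computation above (illustrative only, using
      * lds_es_arg_0 == 20):
      *
      *    VS without instance ID:  20 + 1 * 4 = 24 -> | 4 = 28 bytes (7 dwords)
      *    VS with instance ID:     20 + 2 * 4 = 28 -> | 4 = 28 bytes (already odd)
      *    TES with primitive ID:   20 + 3 * 4 = 32 -> | 4 = 36 bytes (9 dwords)
      *
      * OR-ing with 4 makes the per-vertex stride an odd number of dwords so that the
      * same field of consecutive vertices lands in different LDS banks.
      */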
1398 
1399 static void
1400 add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *s)
1401 {
1402    bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
1403    bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
1404 
1405    unsigned num_repacked_variables;
1406    unsigned pervertex_lds_bytes =
1407       ngg_nogs_get_culling_pervertex_lds_size(b->shader->info.stage,
1408                                               uses_instance_id,
1409                                               uses_tess_primitive_id,
1410                                               &num_repacked_variables);
1411 
1412    nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
1413 
1414    /* Create some helper variables. */
1415    nir_variable *gs_vtxaddr_vars[3] = {
1416       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
1417       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
1418       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
1419    };
1420 
1421    nir_variable *repacked_variables[3] = {
1422       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_0"),
1423       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_1"),
1424       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_2"),
1425    };
1426 
1427    /* The relative patch ID is a special case because it doesn't need an extra dword, so it is repacked separately. */
1428    s->repacked_rel_patch_id = nir_local_variable_create(impl, glsl_uint_type(), "repacked_rel_patch_id");
1429 
1430    if (s->options->clip_cull_dist_mask ||
1431        s->options->user_clip_plane_enable_mask) {
1432       s->clip_vertex_var =
1433          nir_local_variable_create(impl, glsl_vec4_type(), "clip_vertex");
1434       s->clipdist_neg_mask_var =
1435          nir_local_variable_create(impl, glsl_uint_type(), "clipdist_neg_mask");
1436 
1437       /* init mask to 0 */
1438       nir_store_var(b, s->clipdist_neg_mask_var, nir_imm_int(b, 0), 1);
1439    }
1440 
1441    /* Top part of the culling shader (aka. position shader part)
1442     *
1443     * We clone the full ES shader and emit it here, but we only really care
1444     * about its position output, so we delete every other output from this part.
1445     * The position output is stored into a temporary variable, and reloaded later.
1446     */
1447 
1448    nir_def *es_thread = has_input_vertex(b);
1449    nir_if *if_es_thread = nir_push_if(b, es_thread);
1450    {
1451       /* Initialize the position output variable in case not all VS/TES invocations store the output.
1452        * The spec doesn't require any particular default, but we use (0, 0, 0, 1) because some games rely on that.
1453        */
1454       nir_store_var(b, s->position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
1455 
1456       /* Now reinsert a clone of the shader code */
1457       struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
1458       nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
1459       _mesa_hash_table_destroy(remap_table, NULL);
1460       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1461 
1462       /* Remember the current thread's shader arguments */
1463       if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1464          nir_store_var(b, repacked_variables[0], nir_load_vertex_id_zero_base(b), 0x1u);
1465          if (uses_instance_id)
1466             nir_store_var(b, repacked_variables[1], nir_load_instance_id(b), 0x1u);
1467       } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1468          nir_store_var(b, s->repacked_rel_patch_id, nir_load_tess_rel_patch_id_amd(b), 0x1u);
1469          nir_def *tess_coord = nir_load_tess_coord(b);
1470          nir_store_var(b, repacked_variables[0], nir_channel(b, tess_coord, 0), 0x1u);
1471          nir_store_var(b, repacked_variables[1], nir_channel(b, tess_coord, 1), 0x1u);
1472          if (uses_tess_primitive_id)
1473             nir_store_var(b, repacked_variables[2], nir_load_primitive_id(b), 0x1u);
1474       } else {
1475          unreachable("Should be VS or TES.");
1476       }
1477    }
1478    nir_pop_if(b, if_es_thread);
1479 
1480    nir_store_var(b, s->es_accepted_var, es_thread, 0x1u);
1481    nir_def *gs_thread = has_input_primitive(b);
1482    nir_store_var(b, s->gs_accepted_var, gs_thread, 0x1u);
1483 
1484    /* Remove all non-position outputs, and put the position output into the variable. */
1485    nir_metadata_preserve(impl, nir_metadata_none);
1486    remove_culling_shader_outputs(b->shader, s);
1487    b->cursor = nir_after_impl(impl);
1488 
1489    nir_def *lds_scratch_base = nir_load_lds_ngg_scratch_base_amd(b);
1490 
1491    /* Run culling algorithms if culling is enabled.
1492     *
1493     * NGG culling can be enabled or disabled at runtime.
1494     * This is determined by an SGPR shader argument which is accessed
1495     * by the following NIR intrinsic.
1496     */
1497 
1498    nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
1499    {
1500       nir_def *invocation_index = nir_load_local_invocation_index(b);
1501       nir_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
1502 
1503       /* ES invocations store their vertex data to LDS for GS threads to read. */
1504       if_es_thread = nir_push_if(b, es_thread);
1505       if_es_thread->control = nir_selection_control_divergent_always_taken;
1506       {
1507          /* Store position components that are relevant to culling in LDS */
1508          nir_def *pre_cull_pos = nir_load_var(b, s->position_value_var);
1509          nir_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
1510          nir_store_shared(b, pre_cull_w, es_vertex_lds_addr, .base = lds_es_pos_w);
1511          nir_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
1512          nir_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
1513          nir_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .base = lds_es_pos_x);
1514 
1515          /* Clear out the ES accepted flag in LDS */
1516          nir_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .align_mul = 4, .base = lds_es_vertex_accepted);
1517 
1518          /* For clipdist culling */
1519          clipdist_culling_es_part(b, s, es_vertex_lds_addr);
1520       }
1521       nir_pop_if(b, if_es_thread);
1522 
1523       nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1524                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1525 
1526       nir_store_var(b, s->gs_accepted_var, nir_imm_false(b), 0x1u);
1527       nir_store_var(b, s->prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
1528 
1529       /* GS invocations load the vertex data and perform the culling. */
1530       nir_if *if_gs_thread = nir_push_if(b, gs_thread);
1531       {
1532          /* Load vertex indices from input VGPRs */
1533          nir_def *vtx_idx[3] = {0};
1534          for (unsigned vertex = 0; vertex < s->options->num_vertices_per_primitive;
1535               ++vertex)
1536             vtx_idx[vertex] = nir_load_var(b, s->gs_vtx_indices_vars[vertex]);
1537 
1538          nir_def *pos[3][4] = {0};
1539 
1540          /* Load W positions of vertices first because the culling code will use these first */
1541          for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1542             s->vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1543             pos[vtx][3] = nir_load_shared(b, 1, 32, s->vtx_addr[vtx], .base = lds_es_pos_w);
1544             nir_store_var(b, gs_vtxaddr_vars[vtx], s->vtx_addr[vtx], 0x1u);
1545          }
1546 
1547          /* Load the X/W, Y/W positions of vertices */
1548          for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1549             nir_def *xy = nir_load_shared(b, 2, 32, s->vtx_addr[vtx], .base = lds_es_pos_x);
1550             pos[vtx][0] = nir_channel(b, xy, 0);
1551             pos[vtx][1] = nir_channel(b, xy, 1);
1552          }
1553 
1554          nir_def *accepted_by_clipdist;
1555          if (s->has_clipdist) {
1556             nir_def *clipdist_neg_mask = nir_imm_intN_t(b, 0xff, 8);
1557             for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1558                nir_def *mask =
1559                   nir_load_shared(b, 1, 8, s->vtx_addr[vtx],
1560                                   .base = lds_es_clipdist_neg_mask);
1561                clipdist_neg_mask = nir_iand(b, clipdist_neg_mask, mask);
1562             }
1563             /* The primitive is culled if there is a clip plane for which all vertices have negative clip distances. */
1564             accepted_by_clipdist = nir_ieq_imm(b, clipdist_neg_mask, 0);
1565          } else {
1566             accepted_by_clipdist = nir_imm_true(b);
1567          }
1568 
1569          /* See if the current primitive is accepted */
1570          ac_nir_cull_primitive(b, accepted_by_clipdist, pos,
1571                                s->options->num_vertices_per_primitive,
1572                                cull_primitive_accepted, s);
1573       }
1574       nir_pop_if(b, if_gs_thread);
1575 
1576       nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1577                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1578 
1579       nir_store_var(b, s->es_accepted_var, nir_imm_false(b), 0x1u);
1580 
1581       /* ES invocations load their accepted flag from LDS. */
1582       if_es_thread = nir_push_if(b, es_thread);
1583       if_es_thread->control = nir_selection_control_divergent_always_taken;
1584       {
1585          nir_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1586          nir_def *accepted_bool = nir_ine_imm(b, nir_u2u32(b, accepted), 0);
1587          nir_store_var(b, s->es_accepted_var, accepted_bool, 0x1u);
1588       }
1589       nir_pop_if(b, if_es_thread);
1590 
1591       nir_def *es_accepted = nir_load_var(b, s->es_accepted_var);
1592       nir_def *gs_accepted = nir_load_var(b, s->gs_accepted_var);
1593 
1594       /* Repack the vertices (always) and primitives (optional) that survived the culling. */
1595       nir_def *accepted[] = { es_accepted, gs_accepted };
1596       wg_repack_result rep[2] = {0};
1597       const unsigned num_rep = s->options->compact_primitives ? 2 : 1;
1598       repack_invocations_in_workgroup(b, accepted, rep, num_rep, lds_scratch_base,
1599                                       s->max_num_waves, s->options->wave_size);
1600       nir_def *num_live_vertices_in_workgroup = rep[0].num_repacked_invocations;
1601       nir_def *es_exporter_tid = rep[0].repacked_invocation_index;
1602       nir_def *num_exported_prims = NULL;
1603       nir_def *gs_exporter_tid = NULL;
1604 
1605       if (s->options->compact_primitives) {
1606          num_exported_prims = rep[1].num_repacked_invocations;
1607          gs_exporter_tid = rep[1].repacked_invocation_index;
1608       } else {
1609          /* If all vertices are culled, set primitive count to 0 as well. */
1610          nir_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
1611          num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), nir_load_workgroup_num_input_primitives_amd(b));
1612          nir_store_var(b, s->gs_exported_var, nir_iand(b, nir_inot(b, fully_culled), has_input_primitive(b)), 0x1u);
1613       }
1614 
1615       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
1616       {
1617          /* Tell the final vertex and primitive count to the HW. */
1618          if (s->options->gfx_level == GFX10) {
1619             alloc_vertices_and_primitives_gfx10_workaround(
1620                b, num_live_vertices_in_workgroup, num_exported_prims);
1621          } else {
1622             alloc_vertices_and_primitives(
1623                b, num_live_vertices_in_workgroup, num_exported_prims);
1624          }
1625       }
1626       nir_pop_if(b, if_wave_0);
1627 
1628       /* Vertex compaction. */
1629       compact_vertices_after_culling(b, s,
1630                                      repacked_variables, gs_vtxaddr_vars,
1631                                      invocation_index, es_vertex_lds_addr,
1632                                      es_exporter_tid, num_live_vertices_in_workgroup,
1633                                      gs_exporter_tid, num_exported_prims,
1634                                      pervertex_lds_bytes, num_repacked_variables);
1635    }
1636    nir_push_else(b, if_cull_en);
1637    {
1638       /* When culling is disabled, we do the same as we would without culling. */
1639       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
1640       {
1641          nir_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
1642          nir_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
1643          alloc_vertices_and_primitives(b, vtx_cnt, prim_cnt);
1644       }
1645       nir_pop_if(b, if_wave_0);
1646       nir_store_var(b, s->prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, s), 0x1u);
1647    }
1648    nir_pop_if(b, if_cull_en);
1649 
1650    /* Update shader arguments.
1651     *
1652     * The registers which hold information about the subgroup's
1653     * vertices and primitives are updated here, so the rest of the shader
1654     * doesn't need to worry about the culling.
1655     *
1656     * These "overwrite" intrinsics must be at top level control flow,
1657     * otherwise they can mess up the backend (eg. ACO's SSA).
1658     *
1659     * TODO:
1660     * A cleaner solution would be to simply replace all usages of these args
1661     * with the load of the variables.
1662     * However, this wouldn't work right now because the backend uses the arguments
1663     * for purposes not expressed in NIR, eg. VS input loads, etc.
1664     * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
1665     */
1666 
1667    if (b->shader->info.stage == MESA_SHADER_VERTEX)
1668       s->overwrite_args =
1669          nir_overwrite_vs_arguments_amd(b,
1670             nir_load_var(b, repacked_variables[0]), nir_load_var(b, repacked_variables[1]));
1671    else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1672       s->overwrite_args =
1673          nir_overwrite_tes_arguments_amd(b,
1674             nir_load_var(b, repacked_variables[0]), nir_load_var(b, repacked_variables[1]),
1675             nir_load_var(b, repacked_variables[2]), nir_load_var(b, s->repacked_rel_patch_id));
1676    else
1677       unreachable("Should be VS or TES.");
1678 }
1679 
1680 static void
1681 ngg_nogs_store_edgeflag_to_lds(nir_builder *b, lower_ngg_nogs_state *s)
1682 {
1683    if (!s->out.outputs[VARYING_SLOT_EDGE][0])
1684       return;
1685 
1686    /* Clamp the user edge flag to 1 for later bit operations. */
1687    nir_def *edgeflag = s->out.outputs[VARYING_SLOT_EDGE][0];
1688    edgeflag = nir_umin(b, edgeflag, nir_imm_int(b, 1));
1689 
1690    /* The user edge flag is stored at the beginning of the vertex's LDS space when streamout is not enabled. */
1691    unsigned offset = 0;
1692    if (s->streamout_enabled) {
1693       unsigned packed_location =
1694          util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(VARYING_SLOT_EDGE));
1695       offset = packed_location * 16;
1696    }
1697 
1698    nir_def *tid = nir_load_local_invocation_index(b);
1699    nir_def *addr = pervertex_lds_addr(b, tid, s->pervertex_lds_bytes);
1700 
1701    nir_store_shared(b, edgeflag, addr, .base = offset);
1702 }
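     /* Illustrative example of the offset computed above: with streamout enabled and
      * VARYING_SLOT_POS as the only output below VARYING_SLOT_EDGE, packed_location
      * is 1, so the edge flag is stored at byte offset 1 * 16 = 16 of the vertex's
      * LDS space (each packed output slot occupies a 16-byte vec4).
      */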
1703 
1704 static void
1705 ngg_nogs_store_xfb_outputs_to_lds(nir_builder *b, lower_ngg_nogs_state *s)
1706 {
1707    nir_xfb_info *info = ac_nir_get_sorted_xfb_info(b->shader);
1708 
1709    uint64_t xfb_outputs = 0;
1710    unsigned xfb_outputs_16bit = 0;
1711    uint8_t xfb_mask[VARYING_SLOT_MAX] = {0};
1712    uint8_t xfb_mask_16bit_lo[16] = {0};
1713    uint8_t xfb_mask_16bit_hi[16] = {0};
1714 
1715    /* Get XFB output mask for each slot. */
1716    for (int i = 0; i < info->output_count; i++) {
1717       nir_xfb_output_info *out = info->outputs + i;
1718 
1719       if (out->location < VARYING_SLOT_VAR0_16BIT) {
1720          xfb_outputs |= BITFIELD64_BIT(out->location);
1721          xfb_mask[out->location] |= out->component_mask;
1722       } else {
1723          unsigned index = out->location - VARYING_SLOT_VAR0_16BIT;
1724          xfb_outputs_16bit |= BITFIELD_BIT(index);
1725 
1726          if (out->high_16bits)
1727             xfb_mask_16bit_hi[index] |= out->component_mask;
1728          else
1729             xfb_mask_16bit_lo[index] |= out->component_mask;
1730       }
1731    }
1732 
1733    nir_def *tid = nir_load_local_invocation_index(b);
1734    nir_def *addr = pervertex_lds_addr(b, tid, s->pervertex_lds_bytes);
1735 
1736    u_foreach_bit64(slot, xfb_outputs) {
1737       uint64_t outputs_written = b->shader->info.outputs_written;
1738       if (s->skip_primitive_id)
1739          outputs_written &= ~VARYING_BIT_PRIMITIVE_ID;
1740       unsigned packed_location =
1741          util_bitcount64(outputs_written & BITFIELD64_MASK(slot));
1742 
1743       unsigned mask = xfb_mask[slot];
1744 
1745       /* Clear unused components. */
1746       for (unsigned i = 0; i < 4; i++) {
1747          if (!s->out.outputs[slot][i])
1748             mask &= ~BITFIELD_BIT(i);
1749       }
1750 
1751       while (mask) {
1752          int start, count;
1753          u_bit_scan_consecutive_range(&mask, &start, &count);
1754          /* Outputs here are guaranteed to be 32-bit.
1755           *
1756           * 64-bit outputs have already been lowered to two 32-bit components. As for 16-bit outputs:
1757           *   Vulkan does not allow streamout outputs smaller than 32 bits.
1758           *   OpenGL puts 16-bit outputs in VARYING_SLOT_VAR0_16BIT.
1759           */
1760          nir_def *store_val = nir_vec(b, &s->out.outputs[slot][start], (unsigned)count);
1761          nir_store_shared(b, store_val, addr, .base = packed_location * 16 + start * 4);
1762       }
1763    }
1764 
1765    unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
1766    u_foreach_bit64(slot, xfb_outputs_16bit) {
1767       unsigned packed_location = num_32bit_outputs +
1768          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));
1769 
1770       unsigned mask_lo = xfb_mask_16bit_lo[slot];
1771       unsigned mask_hi = xfb_mask_16bit_hi[slot];
1772 
1773       /* Clear unused components. */
1774       for (unsigned i = 0; i < 4; i++) {
1775          if (!s->out.outputs_16bit_lo[slot][i])
1776             mask_lo &= ~BITFIELD_BIT(i);
1777          if (!s->out.outputs_16bit_hi[slot][i])
1778             mask_hi &= ~BITFIELD_BIT(i);
1779       }
1780 
1781       nir_def **outputs_lo = s->out.outputs_16bit_lo[slot];
1782       nir_def **outputs_hi = s->out.outputs_16bit_hi[slot];
1783       nir_def *undef = nir_undef(b, 1, 16);
1784 
1785       unsigned mask = mask_lo | mask_hi;
1786       while (mask) {
1787          int start, count;
1788          u_bit_scan_consecutive_range(&mask, &start, &count);
1789 
1790          nir_def *values[4] = {0};
1791          for (int c = start; c < start + count; ++c) {
1792             nir_def *lo = mask_lo & BITFIELD_BIT(c) ? outputs_lo[c] : undef;
1793             nir_def *hi = mask_hi & BITFIELD_BIT(c) ? outputs_hi[c] : undef;
1794 
1795             /* Pack the low and high 16-bit halves into one 32-bit value (64-bit outputs have already been lowered). */
1796             values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
1797          }
1798 
1799          nir_def *store_val = nir_vec(b, values, (unsigned)count);
1800          nir_store_shared(b, store_val, addr, .base = packed_location * 16 + start * 4);
1801       }
1802    }
1803 }
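     /* Sketch of the 16-bit packing above (hypothetical masks): if a 16-bit slot has
      * xfb_mask_16bit_lo = 0b11 and xfb_mask_16bit_hi = 0b01, the consecutive range
      * covers components 0..1 and the stored values are:
      *
      *    values[0] = pack_32_2x16_split(lo[0], hi[0])
      *    values[1] = pack_32_2x16_split(lo[1], undef)
      *
      * which are written as a 2-dword vector at packed_location * 16.
      */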
1804 
1805 static nir_def *
1806 write_values_to_lanes(nir_builder *b, nir_def **values, unsigned lane_mask)
1807 {
1808    nir_def *lanes = nir_imm_int(b, 0);
1809 
1810    u_foreach_bit(i, lane_mask) {
1811       lanes = nir_write_invocation_amd(b, lanes, values[i], nir_imm_int(b, i));
1812    }
1813    return lanes;
1814 }
1815 
1816 static nir_def *
1817 read_values_from_4_lanes(nir_builder *b, nir_def *values, unsigned lane_mask)
1818 {
1819    nir_def *undef = nir_undef(b, 1, 32);
1820    nir_def *per_lane[4] = {undef, undef, undef, undef};
1821 
1822    u_foreach_bit(i, lane_mask) {
1823       per_lane[i] = nir_read_invocation(b, values, nir_imm_int(b, i));
1824    }
1825    return nir_vec(b, per_lane, 4);
1826 }
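     /* Usage note (illustrative): with lane_mask = 0b0101, write_values_to_lanes()
      * places values[0] into lane 0 and values[2] into lane 2 of the returned def,
      * and read_values_from_4_lanes() does the inverse, gathering the value held by
      * lanes 0 and 2 into components 0 and 2 of a vec4 (other components stay undef).
      * The streamout code below uses this pair to move per-buffer values between
      * SGPRs and the first 4 lanes of the wave.
      */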
1827 
1828 static void
1829 ngg_build_streamout_buffer_info(nir_builder *b,
1830                                 nir_xfb_info *info,
1831                                 enum amd_gfx_level gfx_level,
1832                                 bool has_xfb_prim_query,
1833                                 bool use_gfx12_xfb_intrinsic,
1834                                 nir_def *scratch_base,
1835                                 nir_def *tid_in_tg,
1836                                 nir_def *gen_prim[4],
1837                                 nir_def *so_buffer_ret[4],
1838                                 nir_def *buffer_offsets_ret[4],
1839                                 nir_def *emit_prim_ret[4])
1840 {
1841    nir_def *prim_stride[4] = {0};
1842    nir_def *undef = nir_undef(b, 1, 32);
1843 
1844    /* For radeonsi, which passes this value as a shader argument for VS. Streamout needs
1845     * an accurate vertex count per primitive to write the correct amount of data to each buffer.
1846     */
1847    nir_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
1848    for (unsigned buffer = 0; buffer < 4; buffer++) {
1849       if (!(info->buffers_written & BITFIELD_BIT(buffer)))
1850          continue;
1851 
1852       assert(info->buffers[buffer].stride);
1853 
1854       prim_stride[buffer] =
1855          nir_imul_imm(b, num_vert_per_prim, info->buffers[buffer].stride);
1856       so_buffer_ret[buffer] = nir_load_streamout_buffer_amd(b, .base = buffer);
1857    }
1858 
1859    nir_if *if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
1860    {
1861       nir_def *any_buffer_valid = nir_imm_false(b);
1862       nir_def *workgroup_buffer_sizes[4];
1863 
1864       for (unsigned buffer = 0; buffer < 4; buffer++) {
1865          if (info->buffers_written & BITFIELD_BIT(buffer)) {
1866             nir_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
1867             /* In radeonsi, we may not know at compile time whether a feedback buffer
1868              * has been bound, so we have to check the buffer size at runtime and skip
1869              * the GDS update for unbound buffers. Otherwise, a previous draw that was
1870              * compiled with streamout but didn't bind a feedback buffer could miss its
1871              * GDS update, which would affect the current draw's streamout.
1872              */
1873             nir_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
1874             nir_def *inc_buffer_size =
1875                nir_imul(b, gen_prim[info->buffer_to_stream[buffer]], prim_stride[buffer]);
1876             workgroup_buffer_sizes[buffer] =
1877                nir_bcsel(b, buffer_valid, inc_buffer_size, nir_imm_int(b, 0));
1878             any_buffer_valid = nir_ior(b, any_buffer_valid, buffer_valid);
1879          } else
1880             workgroup_buffer_sizes[buffer] = undef;
1881       }
1882 
1883       nir_def *buffer_offsets = NULL, *xfb_state_address = NULL, *xfb_voffset = NULL;
1884 
1885       /* Get the current global offset of each buffer and increase it by the
1886        * workgroup buffer size. This is an ordered operation sorted by
1887        * ordered_id; each buffer's info is in one channel of a vec4.
1888        */
1889       if (gfx_level >= GFX12) {
1890          nir_pop_if(b, if_invocation_0);
1891 
1892          for (unsigned buffer = 0; buffer < 4; buffer++)
1893             workgroup_buffer_sizes[buffer] = nir_if_phi(b, workgroup_buffer_sizes[buffer], undef);
1894          any_buffer_valid = nir_if_phi(b, any_buffer_valid, nir_undef(b, 1, 1));
1895 
1896          /* These must be set after nir_pop_if and phis. */
1897          xfb_state_address = nir_load_xfb_state_address_gfx12_amd(b);
1898          xfb_voffset = nir_imul_imm(b, tid_in_tg, 8);
1899 
1900          nir_if *if_4lanes = nir_push_if(b, nir_iand(b, any_buffer_valid, nir_ult_imm(b, tid_in_tg, 4)));
1901          {
1902             /* Move workgroup buffer sizes from SGPRs to the first 4 lanes. */
1903             nir_def *workgroup_buffer_size_per_lane =
1904                write_values_to_lanes(b, workgroup_buffer_sizes, info->buffers_written);
1905             nir_def *ordered_id = nir_load_ordered_id_amd(b);
1906 
1907             /* The atomic value for the 4 lanes is:
1908              *    lane 0: uvec2(ordered_id, workgroup_buffer_size0)
1909              *    lane 1: uvec2(ordered_id, workgroup_buffer_size1)
1910              *    lane 2: uvec2(ordered_id, workgroup_buffer_size2)
1911              *    lane 3: uvec2(ordered_id, workgroup_buffer_size3)
1912              */
1913             nir_def *atomic_src = nir_pack_64_2x32_split(b, ordered_id,
1914                                                          workgroup_buffer_size_per_lane);
1915 
1916             /* The memory layout of the xfb state is:
1917              *    struct {
1918              *       unsigned ordered_id;
1919              *       unsigned dwords_written0;
1920              *       unsigned ordered_id;
1921              *       unsigned dwords_written1;
1922              *       unsigned ordered_id;
1923              *       unsigned dwords_written2;
1924              *       unsigned ordered_id;
1925              *       unsigned dwords_written3;
1926              *    };
1927              *
1928              * Notes:
1929              * - global_atomic_ordered_add_b64 is semantically a 64-bit atomic, requiring 8-byte
1930              *   address alignment, even though it operates on a pair of 32-bit values.
1931              * - The whole structure is updated at once by issuing the atomic from 4 lanes
1932              *   with 8-byte address increments.
1933              * - The whole structure should be entirely within one 64B block of memory
1934              *   for performance. (the address bits above 64B should not differ between lanes)
1935              */
1936             nir_def *buffer_offset_per_lane;
1937 
1938             /* The gfx12 intrinsic inserts hand-written assembly producing better code than current
1939              * LLVM.
1940              */
1941             if (use_gfx12_xfb_intrinsic) {
1942                buffer_offset_per_lane =
1943                   nir_ordered_add_loop_gfx12_amd(b, xfb_state_address, xfb_voffset, ordered_id,
1944                                                  atomic_src);
1945 
1946                /* Move the buffer offsets from the 4 lanes to lane 0. */
1947                buffer_offsets = read_values_from_4_lanes(b, buffer_offset_per_lane, info->buffers_written);
1948             } else {
1949                /* The NIR version of the above using nir_atomic_op_ordered_add_gfx12_amd. */
1950                enum { NUM_ATOMICS_IN_FLIGHT = 6 };
1951 
1952                nir_variable *result_ring[NUM_ATOMICS_IN_FLIGHT] = {0};
1953                for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT; i++)
1954                   result_ring[i] = nir_local_variable_create(b->impl, glsl_uint64_t_type(), "result");
1955 
1956                /* Issue the first N-1 atomics. The shader must not wait because we want them to be
1957                 * pipelined. It will only wait for the oldest atomic in the NIR loop.
1958                 */
1959                for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT - 1; i++) {
1960                   nir_store_var(b, result_ring[i],
1961                                 nir_global_atomic_amd(b, 64, xfb_state_address, atomic_src, xfb_voffset,
1962                                                       .atomic_op = nir_atomic_op_ordered_add_gfx12_amd), 0x1);
1963                   ac_nir_sleep(b, 24);
1964                }
1965 
1966                nir_variable *buffer_offsets_var =
1967                   nir_local_variable_create(b->impl, glsl_vec4_type(), "buffer_offset_per_lane");
1968 
1969                nir_loop *loop = nir_push_loop(b);
1970                {
1971                   for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT; i++) {
1972                      int issue_index = (NUM_ATOMICS_IN_FLIGHT - 1 + i) % NUM_ATOMICS_IN_FLIGHT;
1973                      int read_index = i;
1974 
1975                      /* Issue (or repeat) the atomic. */
1976                      nir_store_var(b, result_ring[issue_index],
1977                                    nir_global_atomic_amd(b, 64, xfb_state_address, atomic_src, xfb_voffset,
1978                                                          .atomic_op = nir_atomic_op_ordered_add_gfx12_amd), 0x1);
1979 
1980                      /* Break if the oldest atomic succeeded in incrementing the offsets. */
1981                      nir_def *oldest_result = nir_load_var(b, result_ring[read_index]);
1982                      nir_def *loaded_ordered_id = nir_unpack_64_2x32_split_x(b, oldest_result);
1983 
1984                      /* Debug: Write the vec4 into a shader log ring buffer. */
1985 #if 0
1986                      nir_def *loaded_dwords_written = nir_unpack_64_2x32_split_y(b, oldest_result);
1987                      ac_nir_store_debug_log_amd(b, nir_vec4(b, nir_u2u32(b, xfb_state_address),
1988                                                             ordered_id, loaded_ordered_id,
1989                                                             loaded_dwords_written));
1990 #endif
1991 
1992                      nir_def *continue_if = nir_ieq(b, loaded_ordered_id, ordered_id);
1993                      continue_if = nir_inot(b, nir_vote_any(b, 1, continue_if));
1994                      nir_push_if(b, continue_if);
1995                   }
1996                   nir_jump(b, nir_jump_continue);
1997 
1998                   for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT; i++) {
1999                      int read_index = NUM_ATOMICS_IN_FLIGHT - 1 - i;
2000                      nir_push_else(b, NULL);
2001                      {
2002                         nir_def *result = nir_load_var(b, result_ring[read_index]);
2003                         buffer_offset_per_lane = nir_unpack_64_2x32_split_y(b, result);
2004                         buffer_offsets = read_values_from_4_lanes(b, buffer_offset_per_lane, info->buffers_written);
2005                         nir_store_var(b, buffer_offsets_var, buffer_offsets, info->buffers_written);
2006                      }
2007                      nir_pop_if(b, NULL);
2008                   }
2009                   nir_jump(b, nir_jump_break);
2010                }
2011                nir_pop_loop(b, loop);
2012                buffer_offsets = nir_load_var(b, buffer_offsets_var);
2013             }
2014          }
2015          nir_pop_if(b, if_4lanes);
2016          buffer_offsets = nir_if_phi(b, buffer_offsets, nir_undef(b, 4, 32));
2017 
2018          if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
2019       } else {
2020          nir_def *ordered_id = nir_load_ordered_id_amd(b);
2021          buffer_offsets =
2022             nir_ordered_xfb_counter_add_gfx11_amd(b, ordered_id,
2023                                                   nir_vec(b, workgroup_buffer_sizes, 4),
2024                                                   /* mask of buffers to update */
2025                                                   .write_mask = info->buffers_written);
2026       }
2027 
2028       nir_def *emit_prim[4];
2029       memcpy(emit_prim, gen_prim, 4 * sizeof(nir_def *));
2030 
2031       nir_def *any_overflow = nir_imm_false(b);
2032       nir_def *overflow_amount[4] = {undef, undef, undef, undef};
2033 
2034       for (unsigned buffer = 0; buffer < 4; buffer++) {
2035          if (!(info->buffers_written & BITFIELD_BIT(buffer)))
2036             continue;
2037 
2038          nir_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
2039 
2040          /* Only consider overflow for valid feedback buffers because
2041           * otherwise the ordered operation above (GDS atomic return) might
2042           * return non-zero offsets for invalid buffers.
2043           */
2044          nir_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
2045          nir_def *buffer_offset = nir_channel(b, buffer_offsets, buffer);
2046          buffer_offset = nir_bcsel(b, buffer_valid, buffer_offset, nir_imm_int(b, 0));
2047 
2048          nir_def *remain_size = nir_isub(b, buffer_size, buffer_offset);
2049          nir_def *remain_prim = nir_idiv(b, remain_size, prim_stride[buffer]);
2050          nir_def *overflow = nir_ilt(b, buffer_size, buffer_offset);
2051 
2052          any_overflow = nir_ior(b, any_overflow, overflow);
2053          overflow_amount[buffer] = nir_imax(b, nir_imm_int(b, 0),
2054                                             nir_isub(b, buffer_offset, buffer_size));
2055 
2056          unsigned stream = info->buffer_to_stream[buffer];
2057          /* If previous workgroups already overflowed the buffer, we can't emit any primitives. */
2058          emit_prim[stream] = nir_bcsel(
2059             b, overflow, nir_imm_int(b, 0),
2060             /* Otherwise, emit only as many primitives as still fit, limited by the smallest buffer. */
2061             nir_imin(b, emit_prim[stream], remain_prim));
2062 
2063          /* Save to LDS for being accessed by other waves in this workgroup. */
2064          nir_store_shared(b, buffer_offset, scratch_base, .base = buffer * 4);
2065       }
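           /* Worked example of the overflow handling above (hypothetical numbers):
            * with buffer_size = 256, buffer_offset = 208 and prim_stride = 48,
            * remain_size = 48, remain_prim = 1 and overflow = false, so this
            * workgroup emits at most 1 primitive to that stream. If buffer_offset
            * were 272 instead, overflow = true and overflow_amount = 16, so
            * emit_prim becomes 0 and the global streamout offset is later
            * corrected by subtracting 16.
            */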
2066 
2067       /* We have to fix up the streamout offsets if we overflowed because they determine
2068        * the vertex count for DrawTransformFeedback.
2069        */
2070       if (gfx_level >= GFX12) {
2071          nir_pop_if(b, if_invocation_0);
2072 
2073          any_overflow = nir_if_phi(b, any_overflow, nir_undef(b, 1, 1));
2074          for (unsigned buffer = 0; buffer < 4; buffer++)
2075             overflow_amount[buffer] = nir_if_phi(b, overflow_amount[buffer], undef);
2076          for (unsigned stream = 0; stream < 4; stream++) {
2077             if (emit_prim[stream])
2078                emit_prim[stream] = nir_if_phi(b, emit_prim[stream], undef);
2079          }
2080 
2081          nir_if *if_any_overflow_4_lanes =
2082             nir_push_if(b, nir_iand(b, any_overflow, nir_ult_imm(b, tid_in_tg, 4)));
2083          {
2084             /* Move overflow amounts from SGPRs to the first 4 lanes. */
2085             nir_def *overflow_amount_per_lane =
2086                write_values_to_lanes(b, overflow_amount, info->buffers_written);
2087 
2088             nir_global_atomic_amd(b, 32, xfb_state_address, nir_ineg(b, overflow_amount_per_lane),
2089                                   xfb_voffset, .base = 4, .atomic_op = nir_atomic_op_iadd);
2090          }
2091          nir_pop_if(b, if_any_overflow_4_lanes);
2092 
2093          if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
2094       } else {
2095          nir_if *if_any_overflow = nir_push_if(b, any_overflow);
2096          nir_xfb_counter_sub_gfx11_amd(b, nir_vec(b, overflow_amount, 4),
2097                                        /* mask of buffers to update */
2098                                        .write_mask = info->buffers_written);
2099          nir_pop_if(b, if_any_overflow);
2100       }
2101 
2102       /* Save to LDS for being accessed by other waves in this workgroup. */
2103       for (unsigned stream = 0; stream < 4; stream++) {
2104          if (!(info->streams_written & BITFIELD_BIT(stream)))
2105             continue;
2106 
2107          nir_store_shared(b, emit_prim[stream], scratch_base, .base = 16 + stream * 4);
2108       }
2109 
2110       /* Update shader query. */
2111       if (has_xfb_prim_query) {
2112          nir_if *if_shader_query = nir_push_if(b, nir_load_prim_xfb_query_enabled_amd(b));
2113          {
2114             for (unsigned stream = 0; stream < 4; stream++) {
2115                if (info->streams_written & BITFIELD_BIT(stream))
2116                   nir_atomic_add_xfb_prim_count_amd(b, emit_prim[stream], .stream_id = stream);
2117             }
2118          }
2119          nir_pop_if(b, if_shader_query);
2120       }
2121    }
2122    nir_pop_if(b, if_invocation_0);
2123 
2124    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
2125                       .memory_scope = SCOPE_WORKGROUP,
2126                       .memory_semantics = NIR_MEMORY_ACQ_REL,
2127                       .memory_modes = nir_var_mem_shared);
2128 
2129    /* Fetch the per-buffer offsets in all waves. */
2130    for (unsigned buffer = 0; buffer < 4; buffer++) {
2131       if (!(info->buffers_written & BITFIELD_BIT(buffer)))
2132          continue;
2133 
2134       buffer_offsets_ret[buffer] =
2135          nir_load_shared(b, 1, 32, scratch_base, .base = buffer * 4);
2136    }
2137 
2138    /* Fetch the per-stream emit prim in all waves. */
2139    for (unsigned stream = 0; stream < 4; stream++) {
2140       if (!(info->streams_written & BITFIELD_BIT(stream)))
2141          continue;
2142 
2143       emit_prim_ret[stream] =
2144          nir_load_shared(b, 1, 32, scratch_base, .base = 16 + stream * 4);
2145    }
2146 }
2147 
2148 static void
2149 ngg_build_streamout_vertex(nir_builder *b, nir_xfb_info *info,
2150                            unsigned stream, nir_def *so_buffer[4],
2151                            nir_def *buffer_offsets[4],
2152                            unsigned vertex_index, nir_def *vtx_lds_addr,
2153                            ac_nir_prerast_out *pr_out,
2154                            bool skip_primitive_id)
2155 {
2156    unsigned vertex_offset[NIR_MAX_XFB_BUFFERS] = {0};
2157 
2158    u_foreach_bit(buffer, info->buffers_written) {
2159       /* We use imm_offset for the vertex offset within a primitive, and GFX11 only supports
2160        * 12-bit unsigned imm_offset. (GFX12 supports 24-bit signed imm_offset)
2161        */
2162       assert(info->buffers[buffer].stride * 3 < 4096);
2163       vertex_offset[buffer] = vertex_index * info->buffers[buffer].stride;
2164    }
2165 
2166    nir_def *zero = nir_imm_int(b, 0);
2167    unsigned num_values = 0, store_offset = 0, store_buffer_index = 0;
2168    nir_def *values[4];
2169 
2170    for (unsigned i = 0; i < info->output_count; i++) {
2171       nir_xfb_output_info *out = info->outputs + i;
2172       if (!out->component_mask || info->buffer_to_stream[out->buffer] != stream)
2173          continue;
2174 
2175       unsigned base;
2176       if (out->location >= VARYING_SLOT_VAR0_16BIT) {
2177          base =
2178             util_bitcount64(b->shader->info.outputs_written) +
2179             util_bitcount(b->shader->info.outputs_written_16bit &
2180                           BITFIELD_MASK(out->location - VARYING_SLOT_VAR0_16BIT));
2181       } else {
2182          uint64_t outputs_written = b->shader->info.outputs_written;
2183          if (skip_primitive_id)
2184             outputs_written &= ~VARYING_BIT_PRIMITIVE_ID;
2185 
2186          base =
2187             util_bitcount64(outputs_written &
2188                             BITFIELD64_MASK(out->location));
2189       }
2190 
2191       unsigned offset = (base * 4 + out->component_offset) * 4;
2192       unsigned count = util_bitcount(out->component_mask);
2193 
2194       assert(u_bit_consecutive(out->component_offset, count) == out->component_mask);
2195 
2196       nir_def *out_data =
2197          nir_load_shared(b, count, 32, vtx_lds_addr, .base = offset);
2198 
2199       for (unsigned comp = 0; comp < count; comp++) {
2200          nir_def *data = nir_channel(b, out_data, comp);
2201 
2202          /* Convert 16-bit outputs to 32-bit.
2203           *
2204           * OpenGL ES will put 16-bit medium precision varyings to VARYING_SLOT_VAR0_16BIT.
2205           * We need to convert them to 32-bit for streamout.
2206           *
2207           * Vulkan does not allow 8/16bit varyings for streamout.
2208           */
2209          if (out->location >= VARYING_SLOT_VAR0_16BIT) {
2210             unsigned index = out->location - VARYING_SLOT_VAR0_16BIT;
2211             unsigned c = out->component_offset + comp;
2212             nir_def *v;
2213             nir_alu_type t;
2214 
2215             if (out->high_16bits) {
2216                v = nir_unpack_32_2x16_split_y(b, data);
2217                t = pr_out->types_16bit_hi[index][c];
2218             } else {
2219                v = nir_unpack_32_2x16_split_x(b, data);
2220                t = pr_out->types_16bit_lo[index][c];
2221             }
2222 
2223             t = nir_alu_type_get_base_type(t);
2224             data = nir_convert_to_bit_size(b, v, t, 32);
2225          }
2226 
2227          const unsigned store_comp_offset = out->offset + comp * 4;
2228          const bool has_hole = store_offset + num_values * 4 != store_comp_offset;
2229 
2230          /* Flush the gathered components to memory as a vec4 store or less if there is a hole. */
2231          if (num_values && (num_values == 4 || store_buffer_index != out->buffer || has_hole)) {
2232             nir_store_buffer_amd(b, nir_vec(b, values, num_values), so_buffer[store_buffer_index],
2233                                  buffer_offsets[store_buffer_index], zero, zero,
2234                                  .base = vertex_offset[store_buffer_index] + store_offset,
2235                                  .access = ACCESS_NON_TEMPORAL);
2236             num_values = 0;
2237          }
2238 
2239          /* Initialize the buffer index and offset if we are beginning a new vec4 store. */
2240          if (num_values == 0) {
2241             store_buffer_index = out->buffer;
2242             store_offset = store_comp_offset;
2243          }
2244 
2245          values[num_values++] = data;
2246       }
2247    }
2248 
2249    if (num_values) {
2250       /* Flush the remaining components to memory (as an up to vec4 store) */
2251       nir_store_buffer_amd(b, nir_vec(b, values, num_values), so_buffer[store_buffer_index],
2252                            buffer_offsets[store_buffer_index], zero, zero,
2253                            .base = vertex_offset[store_buffer_index] + store_offset,
2254                            .access = ACCESS_NON_TEMPORAL);
2255    }
2256 }
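     /* Illustrative example of the vec4 batching above (hypothetical XFB layout):
      * two captured outputs going to the same buffer at out->offset 0 (3 components)
      * and out->offset 12 (1 component) are contiguous, so they are flushed as one
      * vec4 nir_store_buffer_amd. If the second output started at offset 16 instead,
      * the 4-byte hole would flush the first batch as a vec3 and start a new batch
      * at offset 16.
      */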
2257 
2258 static void
2259 ngg_nogs_build_streamout(nir_builder *b, lower_ngg_nogs_state *s)
2260 {
2261    nir_xfb_info *info = ac_nir_get_sorted_xfb_info(b->shader);
2262 
2263    nir_def *lds_scratch_base = nir_load_lds_ngg_scratch_base_amd(b);
2264 
2265    /* Get global buffer offset where this workgroup will stream out data to. */
2266    nir_def *generated_prim = nir_load_workgroup_num_input_primitives_amd(b);
2267    nir_def *gen_prim_per_stream[4] = {generated_prim, 0, 0, 0};
2268    nir_def *emit_prim_per_stream[4] = {0};
2269    nir_def *buffer_offsets[4] = {0};
2270    nir_def *so_buffer[4] = {0};
2271    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
2272    ngg_build_streamout_buffer_info(b, info, s->options->gfx_level, s->options->has_xfb_prim_query,
2273                                    s->options->use_gfx12_xfb_intrinsic, lds_scratch_base, tid_in_tg,
2274                                    gen_prim_per_stream,
2275                                    so_buffer, buffer_offsets,
2276                                    emit_prim_per_stream);
2277 
2278    /* Write out primitive data */
2279    nir_if *if_emit = nir_push_if(b, nir_ilt(b, tid_in_tg, emit_prim_per_stream[0]));
2280    {
2281       unsigned vtx_lds_stride = (b->shader->num_outputs * 4 + 1) * 4;
2282       nir_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
2283       nir_def *first_vertex_idx = nir_imul(b, tid_in_tg, num_vert_per_prim);
2284 
2285       u_foreach_bit(buffer, info->buffers_written) {
2286          buffer_offsets[buffer] = nir_iadd(b, buffer_offsets[buffer],
2287                                            nir_imul_imm(b, first_vertex_idx,
2288                                                         info->buffers[buffer].stride));
2289       }
2290 
2291       for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++) {
2292          nir_if *if_valid_vertex =
2293             nir_push_if(b, nir_igt_imm(b, num_vert_per_prim, i));
2294          {
2295             nir_def *vtx_lds_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
2296             nir_def *vtx_lds_addr = pervertex_lds_addr(b, vtx_lds_idx, vtx_lds_stride);
2297             ngg_build_streamout_vertex(b, info, 0, so_buffer, buffer_offsets, i,
2298                                        vtx_lds_addr, &s->out, s->skip_primitive_id);
2299          }
2300          nir_pop_if(b, if_valid_vertex);
2301       }
2302    }
2303    nir_pop_if(b, if_emit);
2304 
2305    /* Wait for streamout memory operations to finish before exporting the primitive,
2306     * otherwise they may not have completed by the time the shader ends.
2307     *
2308     * If a shader has no param exports, rasterization can start before
2309     * the shader finishes, and thus memory stores might not finish before
2310     * the pixel shader starts.
2311     *
2312     * TODO: we only need this when there are no param exports.
2313     *
2314     * TODO: not sure if we need this barrier with late prim export, as no
2315     *       test failures have been observed without it.
2316     */
2317    nir_scoped_memory_barrier(b, SCOPE_DEVICE, NIR_MEMORY_RELEASE, nir_var_mem_ssbo);
2318 }
2319 
2320 static unsigned
2321 ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
2322                                 unsigned shader_num_outputs,
2323                                 bool streamout_enabled,
2324                                 bool export_prim_id,
2325                                 bool has_user_edgeflags)
2326 {
2327    unsigned pervertex_lds_bytes = 0;
2328 
2329    if (streamout_enabled) {
2330       /* The extra dword is used to avoid LDS bank conflicts and store the primitive id.
2331        * TODO: only alloc space for outputs that really need streamout.
2332        */
2333       pervertex_lds_bytes = (shader_num_outputs * 4 + 1) * 4;
2334    }
2335 
2336    bool need_prim_id_store_shared = export_prim_id && stage == MESA_SHADER_VERTEX;
2337    if (need_prim_id_store_shared || has_user_edgeflags) {
2338       unsigned size = 0;
2339       if (need_prim_id_store_shared)
2340          size += 4;
2341       if (has_user_edgeflags)
2342          size += 4;
2343 
2344       /* Pad to an odd number of dwords to avoid LDS bank conflicts. */
2345       size |= 4;
2346 
2347       pervertex_lds_bytes = MAX2(pervertex_lds_bytes, size);
2348    }
2349 
2350    return pervertex_lds_bytes;
2351 }
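     /* Worked examples for the size computation above (illustrative only):
      *
      *    8 outputs, streamout enabled:               (8 * 4 + 1) * 4 = 132 bytes
      *    no streamout, VS prim id + user edgeflags:  (4 + 4) | 4     = 12 bytes
      *    both:                                       MAX2(132, 12)   = 132 bytes
      */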
2352 
2353 static void
2354 ngg_nogs_gather_outputs(nir_builder *b, struct exec_list *cf_list, lower_ngg_nogs_state *s)
2355 {
2356    /* Assume:
2357     * - the shader used nir_lower_io_to_temporaries
2358     * - 64-bit outputs are lowered
2359     * - no indirect indexing is present
2360     */
2361    struct nir_cf_node *first_node =
2362       exec_node_data(nir_cf_node, exec_list_get_head(cf_list), node);
2363 
2364    for (nir_block *block = nir_cf_node_cf_tree_first(first_node); block != NULL;
2365         block = nir_block_cf_tree_next(block)) {
2366       nir_foreach_instr_safe (instr, block) {
2367          if (instr->type != nir_instr_type_intrinsic)
2368             continue;
2369 
2370          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2371          if (intrin->intrinsic != nir_intrinsic_store_output)
2372             continue;
2373 
2374          ac_nir_gather_prerast_store_output_info(b, intrin, &s->out);
2375          nir_instr_remove(instr);
2376       }
2377    }
2378 }
2379 
2380 static void
2381 create_output_phis(nir_builder *b, const uint64_t outputs_written, const uint64_t outputs_written_16bit, ac_nir_prerast_out *out)
2382 {
2383    nir_def *undef = nir_undef(b, 1, 32); /* inserted at the start of the shader */
2384 
2385    u_foreach_bit64(slot, outputs_written) {
2386       for (unsigned j = 0; j < 4; j++) {
2387          if (out->outputs[slot][j])
2388             out->outputs[slot][j] = nir_if_phi(b, out->outputs[slot][j], undef);
2389       }
2390    }
2391 
2392    u_foreach_bit64(i, outputs_written_16bit) {
2393       for (unsigned j = 0; j < 4; j++) {
2394          if (out->outputs_16bit_hi[i][j])
2395             out->outputs_16bit_hi[i][j] = nir_if_phi(b, out->outputs_16bit_hi[i][j], undef);
2396 
2397          if (out->outputs_16bit_lo[i][j])
2398             out->outputs_16bit_lo[i][j] = nir_if_phi(b, out->outputs_16bit_lo[i][j], undef);
2399       }
2400    }
2401 }
2402 
2403 static bool must_wait_attr_ring(enum amd_gfx_level gfx_level, bool has_param_exports)
2404 {
2405    return (gfx_level == GFX11 || gfx_level == GFX11_5) && has_param_exports;
2406 }
2407 
2408 static void
2409 export_pos0_wait_attr_ring(nir_builder *b, nir_if *if_es_thread, nir_def *outputs[VARYING_SLOT_MAX][4], const ac_nir_lower_ngg_options *options)
2410 {
2411    b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2412 
2413    /* Create phi for the position output values. */
2414    ac_nir_prerast_out out = {
2415       .outputs = {{outputs[VARYING_SLOT_POS][0], outputs[VARYING_SLOT_POS][1], outputs[VARYING_SLOT_POS][2], outputs[VARYING_SLOT_POS][3]}},
2416       .infos = {{.components_mask = 0xf, .as_sysval_mask = 0xf}},
2417    };
2418 
2419    b->cursor = nir_after_cf_list(&b->impl->body);
2420 
2421    /* Wait for attribute stores to finish. */
2422    nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
2423                   .memory_scope = SCOPE_DEVICE,
2424                   .memory_semantics = NIR_MEMORY_RELEASE,
2425                   .memory_modes = nir_var_mem_ssbo | nir_var_shader_out | nir_var_mem_global | nir_var_image);
2426 
2427    /* Export just the pos0 output. */
2428    nir_if *if_export_empty_pos = nir_push_if(b, if_es_thread->condition.ssa);
2429    {
2430       ac_nir_export_position(b, options->gfx_level,
2431                              options->clip_cull_dist_mask,
2432                              !options->has_param_exports,
2433                              options->force_vrs, true,
2434                              VARYING_BIT_POS, &out, NULL);
2435    }
2436    nir_pop_if(b, if_export_empty_pos);
2437 }
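/* Illustrative sketch (not generated code, and the exact export flags are an assumption
 * based on the calls in this file) of the ordering that the pass builds when
 * must_wait_attr_ring() returns true, i.e. on GFX11/GFX11.5 with param exports:
 *
 *    if (es_thread) {
 *       ...run the API shader...
 *       export positions except pos0;        // main ac_nir_export_position, VARYING_BIT_POS removed
 *       store params to the attribute ring;  // ac_nir_store_parameters_to_attr_ring
 *    }
 *    subgroup barrier, RELEASE, ssbo|out|global|image;  // order the attribute ring stores
 *    if (es_thread)
 *       export pos0;                                    // export_pos0_wait_attr_ring
 *
 * Deferring only pos0 lets most of the position/param work overlap, while the final
 * position export still happens only after the attribute ring writes are visible.
 */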
2438 
2439 
2440 static void
2441 nogs_export_vertex_params(nir_builder *b, nir_function_impl *impl,
2442                           nir_if *if_es_thread, nir_def *num_es_threads,
2443                           lower_ngg_nogs_state *s)
2444 {
2445    if (!s->options->has_param_exports)
2446       return;
2447 
2448    if (s->options->gfx_level >= GFX11) {
2449       /* Export varyings for GFX11+ */
2450       b->cursor = nir_after_impl(impl);
2451       if (!num_es_threads)
2452          num_es_threads = nir_load_merged_wave_info_amd(b);
2453 
2454       ac_nir_store_parameters_to_attr_ring(b, s->options->vs_output_param_offset,
2455                                            b->shader->info.outputs_written,
2456                                            b->shader->info.outputs_written_16bit,
2457                                            &s->out, NULL, num_es_threads);
2458    } else {
2459       ac_nir_export_parameters(b, s->options->vs_output_param_offset,
2460                                  b->shader->info.outputs_written,
2461                                  b->shader->info.outputs_written_16bit,
2462                                  &s->out);
2463    }
2464 }
2465 
2466 void
2467 ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
2468 {
2469    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
2470    assert(impl);
2471    assert(options->max_workgroup_size && options->wave_size);
2472    assert(!(options->can_cull && options->passthrough));
2473 
2474    nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
2475    nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
2476    nir_variable *es_accepted_var =
2477       options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
2478    nir_variable *gs_accepted_var =
2479       options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
2480    nir_variable *gs_exported_var = nir_local_variable_create(impl, glsl_bool_type(), "gs_exported");
2481 
2482    bool streamout_enabled = shader->xfb_info && !options->disable_streamout;
2483    bool has_user_edgeflags =
2484       options->use_edgeflags && (shader->info.outputs_written & VARYING_BIT_EDGE);
2485    /* Streamout needs to be done before either the prim or the vertex export. When there
2486     * is no param export, rasterization can start right after the prim and vertex exports,
2487     * which would leave the streamout buffer writes unfinished.
2488     *
2489     * Always use late prim export when user edge flags are enabled.
2490     * This is because edge flags are written by ES threads but they
2491     * are exported by GS threads as part of the primitive export.
2492     */
2493    bool early_prim_export =
2494       options->early_prim_export && !(streamout_enabled || has_user_edgeflags);
2495 
2496    lower_ngg_nogs_state state = {
2497       .options = options,
2498       .early_prim_export = early_prim_export,
2499       .streamout_enabled = streamout_enabled,
2500       .position_value_var = position_value_var,
2501       .prim_exp_arg_var = prim_exp_arg_var,
2502       .es_accepted_var = es_accepted_var,
2503       .gs_accepted_var = gs_accepted_var,
2504       .gs_exported_var = gs_exported_var,
2505       .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
2506       .has_user_edgeflags = has_user_edgeflags,
2507       .skip_primitive_id = streamout_enabled && (options->export_primitive_id || options->export_primitive_id_per_prim),
2508    };
2509 
2510    /* Can't export the primitive ID both as per-vertex and per-primitive. */
2511    assert(!options->export_primitive_id || !options->export_primitive_id_per_prim);
2512 
2513    const bool need_prim_id_store_shared =
2514       options->export_primitive_id && shader->info.stage == MESA_SHADER_VERTEX;
2515 
2516    if (options->export_primitive_id) {
2517       shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
2518    }
2519 
2520    if (options->export_primitive_id_per_prim) {
2521       /* The HW preloads the primitive ID to VGPRs of GS threads for VS, but not for TES. */
2522       assert(shader->info.stage == MESA_SHADER_VERTEX);
2523       assert(options->gfx_level >= GFX10_3);
2524    }
2525 
2526    nir_builder builder = nir_builder_create(impl);
2527    nir_builder *b = &builder; /* This is to avoid the & */
2528 
2529    if (options->can_cull) {
2530       analyze_shader_before_culling(shader, &state);
2531       save_reusable_variables(b, &state);
2532    }
2533 
2534    nir_cf_list extracted;
2535    nir_cf_extract(&extracted, nir_before_impl(impl),
2536                   nir_after_impl(impl));
2537    b->cursor = nir_before_impl(impl);
2538 
2539    ngg_nogs_init_vertex_indices_vars(b, impl, &state);
2540 
2541    /* Emit primitives generated query code here, so that
2542     * it executes before culling and isn't in the extracted CF.
2543     */
2544    nogs_prim_gen_query(b, &state);
2545 
2546    /* Whether a shader invocation should export a primitive,
2547     * initialize to all invocations that have an input primitive.
2548     */
2549    nir_store_var(b, gs_exported_var, has_input_primitive(b), 0x1u);
2550 
2551    if (!options->can_cull) {
2552       /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
2553       if (!(options->passthrough && options->family >= CHIP_NAVI23)) {
2554          /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
2555          nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
2556          {
2557             nir_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
2558             nir_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
2559             alloc_vertices_and_primitives(b, vtx_cnt, prim_cnt);
2560          }
2561          nir_pop_if(b, if_wave_0);
2562       }
2563 
2564       /* Take care of early primitive export, otherwise just pack the primitive export argument */
2565       if (state.early_prim_export)
2566          emit_ngg_nogs_prim_export(b, &state, NULL);
2567       else
2568          nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
2569    } else {
2570       add_deferred_attribute_culling(b, &extracted, &state);
2571       b->cursor = nir_after_impl(impl);
2572 
2573       if (state.early_prim_export)
2574          emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
2575 
2576       /* Wait for culling to finish using LDS. */
2577       if (need_prim_id_store_shared || has_user_edgeflags) {
2578          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
2579                                .memory_scope = SCOPE_WORKGROUP,
2580                                .memory_semantics = NIR_MEMORY_ACQ_REL,
2581                                .memory_modes = nir_var_mem_shared);
2582       }
2583    }
2584 
2585    /* determine the LDS vertex stride */
2586    state.pervertex_lds_bytes =
2587       ngg_nogs_get_pervertex_lds_size(shader->info.stage,
2588                                       shader->num_outputs,
2589                                       state.streamout_enabled,
2590                                       options->export_primitive_id,
2591                                       state.has_user_edgeflags);
2592 
2593    if (need_prim_id_store_shared) {
2594       emit_ngg_nogs_prim_id_store_shared(b, &state);
2595 
2596       /* Wait for GS threads to store primitive ID in LDS. */
2597       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
2598                             .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
2599    } else if (options->export_primitive_id_per_prim && options->gfx_level >= GFX11) {
2600       emit_ngg_nogs_prim_id_store_per_prim_to_attr_ring(b, &state);
2601    }
2602 
2603    nir_def *es_thread =
2604       options->can_cull ? nir_load_var(b, es_accepted_var) : has_input_vertex(b);
2605 
2606    /* Calculate the bit count here instead of below for lower SGPR usage and better ALU
2607     * scheduling.
2608     */
2609    nir_def *num_es_threads = NULL;
2610    if (state.options->gfx_level >= GFX11 && options->can_cull) {
2611       nir_def *es_accepted_mask =
2612          nir_ballot(b, 1, options->wave_size, nir_load_var(b, es_accepted_var));
2613       num_es_threads = nir_bit_count(b, es_accepted_mask);
2614    }
2615 
2616    nir_if *if_es_thread = nir_push_if(b, es_thread);
2617    {
2618       /* Run the actual shader */
2619       nir_cf_reinsert(&extracted, b->cursor);
2620       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2621 
2622       if (options->export_primitive_id)
2623          emit_store_ngg_nogs_es_primitive_id(b, &state);
2624    }
2625    nir_pop_if(b, if_es_thread);
2626 
2627    if (options->can_cull) {
2628       /* Replace uniforms. */
2629       apply_reusable_variables(b, &state);
2630 
2631       /* Remove the redundant position output. */
2632       remove_extra_pos_outputs(shader, &state);
2633 
2634       /* After looking at the performance in apps, e.g. Doom Eternal and The Witcher 3,
2635        * it seems best to always put the position export at the end, and
2636        * then let ACO schedule it up (slightly) only when early prim export is used.
2637        */
2638       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2639 
2640       nir_def *pos_val = nir_load_var(b, state.position_value_var);
2641       for (int i = 0; i < 4; i++)
2642          state.out.outputs[VARYING_SLOT_POS][i] = nir_channel(b, pos_val, i);
2643    }
2644 
2645    /* Gather outputs data and types */
2646    ngg_nogs_gather_outputs(b, &if_es_thread->then_list, &state);
2647    b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2648 
2649    if (state.has_user_edgeflags)
2650       ngg_nogs_store_edgeflag_to_lds(b, &state);
2651 
2652    if (state.streamout_enabled) {
2653       /* TODO: support culling after streamout. */
2654       assert(!options->can_cull);
2655 
2656       ngg_nogs_store_xfb_outputs_to_lds(b, &state);
2657 
2658       b->cursor = nir_after_impl(impl);
2659       ngg_nogs_build_streamout(b, &state);
2660    }
2661 
2662    /* Take care of late primitive export */
2663    if (!state.early_prim_export) {
2664       b->cursor = nir_after_impl(impl);
2665       emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
2666    }
2667 
2668    uint64_t export_outputs = shader->info.outputs_written | VARYING_BIT_POS;
2669    if (options->kill_pointsize)
2670       export_outputs &= ~VARYING_BIT_PSIZ;
2671    if (options->kill_layer)
2672       export_outputs &= ~VARYING_BIT_LAYER;
2673 
2674    const bool wait_attr_ring = must_wait_attr_ring(options->gfx_level, options->has_param_exports);
2675    if (wait_attr_ring)
2676       export_outputs &= ~VARYING_BIT_POS;
2677 
2678    bool phis_created = false;
2679 
2680    /* Add position exports.
2681     *
2682     * If streamout is enabled, export positions after streamout. This increases streamout performance
2683     * for up to 4 vec4 xfb outputs on GFX12 because the streamout code doesn't have to go through
2684     * the export allocation bottleneck. With more xfb outputs, performance becomes limited by memory
2685     * bandwidth instead.
2686     */
2687    nir_if *if_pos_exports = NULL;
2688    if (state.streamout_enabled) {
2689       b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2690       create_output_phis(b, b->shader->info.outputs_written, b->shader->info.outputs_written_16bit,
2691                          &state.out);
2692       phis_created = true;
2693 
2694       b->cursor = nir_after_impl(impl);
2695       if_pos_exports = nir_push_if(b, es_thread);
2696    } else {
2697       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2698    }
2699 
2700    ac_nir_export_position(b, options->gfx_level,
2701                           options->clip_cull_dist_mask,
2702                           !options->has_param_exports,
2703                           options->force_vrs, !wait_attr_ring,
2704                           export_outputs, &state.out, NULL);
2705 
2706    if (if_pos_exports)
2707       nir_pop_if(b, if_pos_exports);
2708 
2709    if (options->has_param_exports && options->gfx_level >= GFX11 && !phis_created) {
2710       b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2711       create_output_phis(b, b->shader->info.outputs_written, b->shader->info.outputs_written_16bit,
2712                          &state.out);
2713    }
2714 
2715    b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2716    nogs_export_vertex_params(b, impl, if_es_thread, num_es_threads, &state);
2717 
2718    if (wait_attr_ring)
2719       export_pos0_wait_attr_ring(b, if_es_thread, state.out.outputs, options);
2720 
2721    nir_metadata_preserve(impl, nir_metadata_none);
2722    nir_validate_shader(shader, "after emitting NGG VS/TES");
2723 
2724    /* Cleanup */
2725    nir_opt_dead_write_vars(shader);
2726    nir_lower_vars_to_ssa(shader);
2727    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
2728    nir_lower_alu_to_scalar(shader, NULL, NULL);
2729    nir_lower_phis_to_scalar(shader, true);
2730 
2731    if (options->can_cull) {
2732       /* It's beneficial to redo these opts after splitting the shader. */
2733       nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
2734       nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
2735    }
2736 
2737    bool progress;
2738    do {
2739       progress = false;
2740       NIR_PASS(progress, shader, nir_opt_undef);
2741       NIR_PASS(progress, shader, nir_opt_dce);
2742       NIR_PASS(progress, shader, nir_opt_dead_cf);
2743 
2744       if (options->can_cull)
2745          progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
2746    } while (progress);
2747 }
2748 
2749 /**
2750  * Return the address of the LDS storage reserved for the N'th vertex,
2751  * where N is in emit order, meaning:
2752  * - during the finale, N is the invocation_index (within the workgroup)
2753  * - during vertex emit, i.e. while the API GS shader invocation is running,
2754  *   N = invocation_index * gs_max_out_vertices + emit_idx
2755  *   where emit_idx is the vertex index in the current API GS invocation.
2756  *
2757  * Goals of the LDS memory layout:
2758  * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
2759  *    in uniform control flow
2760  * 2. Eliminate bank conflicts on read for export if, additionally, there is no
2761  *    culling
2762  * 3. Agnostic to the number of waves (since we don't know it before compiling)
2763  * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
2764  * 5. Avoid wasting memory.
2765  *
2766  * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
2767  * layout, elimination of bank conflicts requires that each vertex occupy an
2768  * odd number of dwords. We use the additional dword to store the output stream
2769  * index as well as a flag to indicate whether this vertex ends a primitive
2770  * for rasterization.
2771  *
2772  * Swizzling is required to satisfy points 1 and 2 simultaneously.
2773  *
2774  * Vertices are stored in export order (gsthread * gs_max_out_vertices + emitidx).
2775  * Indices are swizzled in groups of 32, which ensures point 1 without
2776  * disturbing point 2.
2777  *
2778  * \return an LDS pointer to type {[N x i32], [4 x i8]}
2779  */
2780 static nir_def *
2781 ngg_gs_out_vertex_addr(nir_builder *b, nir_def *out_vtx_idx, lower_ngg_gs_state *s)
2782 {
2783    unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
2784 
2785    /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
2786    if (write_stride_2exp) {
2787       nir_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
2788       nir_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
2789       out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
2790    }
2791 
2792    nir_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
2793    return nir_iadd_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
2794 }
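/* Illustrative example of the swizzle above (a sketch, not part of the pass). Assume
 * gs.vertices_out == 4, so ffs(4) - 1 gives write_stride_2exp == 2, and assume a
 * hypothetical per-vertex stride of 36 bytes (32 output bytes + 4 primflag bytes):
 *
 *    unsigned row     = out_vtx_idx >> 5;       // which group of 32 vertices
 *    unsigned swizzle = row & 0x3;              // (1u << write_stride_2exp) - 1u
 *    unsigned idx     = out_vtx_idx ^ swizzle;  // rotate the index within the group
 *    unsigned addr    = lds_base + idx * 36;    // AoS element address
 *
 * For out_vtx_idx == 65: row == 2, swizzle == 2, idx == 67, so the vertex is stored at
 * lds_base + 67 * 36. Swizzling in groups of 32 breaks the regular stride between
 * threads, which avoids LDS bank conflicts for emits in uniform control flow (point 1
 * of the layout comment above) without disturbing the read pattern (point 2).
 */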
2795 
2796 static nir_def *
2797 ngg_gs_emit_vertex_addr(nir_builder *b, nir_def *gs_vtx_idx, lower_ngg_gs_state *s)
2798 {
2799    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
2800    nir_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
2801    nir_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
2802 
2803    return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
2804 }
2805 
2806 static void
2807 ngg_gs_clear_primflags(nir_builder *b, nir_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
2808 {
2809    char name[32];
2810    snprintf(name, sizeof(name), "clear_primflag_idx_%u", stream);
2811    nir_variable *clear_primflag_idx_var = nir_local_variable_create(b->impl, glsl_uint_type(), name);
2812 
2813    nir_def *zero_u8 = nir_imm_zero(b, 1, 8);
2814    nir_store_var(b, clear_primflag_idx_var, num_vertices, 0x1u);
2815 
2816    nir_loop *loop = nir_push_loop(b);
2817    {
2818       nir_def *clear_primflag_idx = nir_load_var(b, clear_primflag_idx_var);
2819       nir_if *if_break = nir_push_if(b, nir_uge_imm(b, clear_primflag_idx, b->shader->info.gs.vertices_out));
2820       {
2821          nir_jump(b, nir_jump_break);
2822       }
2823       nir_push_else(b, if_break);
2824       {
2825          nir_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, clear_primflag_idx, s);
2826          nir_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream);
2827          nir_store_var(b, clear_primflag_idx_var, nir_iadd_imm_nuw(b, clear_primflag_idx, 1), 0x1u);
2828       }
2829       nir_pop_if(b, if_break);
2830    }
2831    nir_pop_loop(b, loop);
2832 }
2833 
2834 static bool
2835 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2836 {
2837    ac_nir_gather_prerast_store_output_info(b, intrin, &s->out);
2838    nir_instr_remove(&intrin->instr);
2839    return true;
2840 }
2841 
2842 static unsigned
2843 gs_output_component_mask_with_stream(ac_nir_prerast_per_output_info *info, unsigned stream)
2844 {
2845    unsigned mask = info->components_mask;
2846    if (!mask)
2847       return 0;
2848 
2849    /* Clear components that don't belong to the requested stream. */
2850    for (int i = 0; i < 4; i++) {
2851       if (((info->stream >> (i * 2)) & 3) != stream)
2852          mask &= ~(1 << i);
2853    }
2854 
2855    return mask;
2856 }
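/* A small worked example of the per-component stream decoding above (illustrative
 * only). The stream field packs 2 bits per component, so component i belongs to
 * stream (info->stream >> (i * 2)) & 3. With components_mask == 0xf and
 * info->stream == 0b01000100 (components 1 and 3 on stream 1, components 0 and 2 on
 * stream 0):
 *
 *    gs_output_component_mask_with_stream(info, 0) == 0b0101
 *    gs_output_component_mask_with_stream(info, 1) == 0b1010
 */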
2857 
2858 static bool
2859 lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2860 {
2861    b->cursor = nir_before_instr(&intrin->instr);
2862 
2863    unsigned stream = nir_intrinsic_stream_id(intrin);
2864    if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
2865       nir_instr_remove(&intrin->instr);
2866       return true;
2867    }
2868 
2869    nir_def *gs_emit_vtx_idx = intrin->src[0].ssa;
2870    nir_def *current_vtx_per_prim = intrin->src[1].ssa;
2871    nir_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);
2872 
2873    /* Store generic 32-bit outputs to LDS.
2874     * In case of packed 16-bit outputs, we assume they have already been packed into 32-bit slots by now.
2875     */
2876    u_foreach_bit64(slot, b->shader->info.outputs_written) {
2877       const unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
2878       unsigned mask = gs_output_component_mask_with_stream(&s->out.infos[slot], stream);
2879 
2880       nir_def **output = s->out.outputs[slot];
2881       nir_def *undef = nir_undef(b, 1, 32);
2882 
2883       while (mask) {
2884          int start, count;
2885          u_bit_scan_consecutive_range(&mask, &start, &count);
2886          nir_def *values[4] = {0};
2887          for (int c = start; c < start + count; ++c) {
2888             if (!output[c]) {
2889                /* The shader hasn't written this output. */
2890                values[c - start] = undef;
2891             } else {
2892                assert(output[c]->bit_size == 32);
2893                values[c - start] = output[c];
2894             }
2895          }
2896 
2897          nir_def *store_val = nir_vec(b, values, (unsigned)count);
2898          nir_store_shared(b, store_val, gs_emit_vtx_addr,
2899                           .base = packed_location * 16 + start * 4,
2900                           .align_mul = 4);
2901       }
2902 
2903       /* Clear all outputs (they are undefined after emit_vertex) */
2904       memset(s->out.outputs[slot], 0, sizeof(s->out.outputs[slot]));
2905    }
2906 
2907    const unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
2908 
2909    /* Store dedicated 16-bit outputs to LDS. */
2910    u_foreach_bit(slot, b->shader->info.outputs_written_16bit) {
2911       const unsigned packed_location = num_32bit_outputs +
2912          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));
2913 
2914       const unsigned mask_lo = gs_output_component_mask_with_stream(s->out.infos_16bit_lo + slot, stream);
2915       const unsigned mask_hi = gs_output_component_mask_with_stream(s->out.infos_16bit_hi + slot, stream);
2916       unsigned mask = mask_lo | mask_hi;
2917 
2918       nir_def **output_lo = s->out.outputs_16bit_lo[slot];
2919       nir_def **output_hi = s->out.outputs_16bit_hi[slot];
2920       nir_def *undef = nir_undef(b, 1, 16);
2921 
2922       while (mask) {
2923          int start, count;
2924          u_bit_scan_consecutive_range(&mask, &start, &count);
2925          nir_def *values[4] = {0};
2926          for (int c = start; c < start + count; ++c) {
2927             nir_def *lo = output_lo[c] ? output_lo[c] : undef;
2928             nir_def *hi = output_hi[c] ? output_hi[c] : undef;
2929 
2930             values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
2931          }
2932 
2933          nir_def *store_val = nir_vec(b, values, (unsigned)count);
2934          nir_store_shared(b, store_val, gs_emit_vtx_addr,
2935                           .base = packed_location * 16 + start * 4,
2936                           .align_mul = 4);
2937       }
2938 
2939       /* Clear all outputs (they are undefined after emit_vertex) */
2940       memset(s->out.outputs_16bit_lo[slot], 0, sizeof(s->out.outputs_16bit_lo[slot]));
2941       memset(s->out.outputs_16bit_hi[slot], 0, sizeof(s->out.outputs_16bit_hi[slot]));
2942    }
2943 
2944    /* Calculate and store per-vertex primitive flags based on vertex counts:
2945     * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
2946     * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
2947     *          only set when the vertex also finishes the primitive
2948     * - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
2949     */
2950 
2951    nir_def *vertex_live_flag =
2952       !stream && s->options->can_cull
2953          ? nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2)
2954          : nir_imm_int(b, 0b100);
2955 
2956    nir_def *completes_prim = nir_ige_imm(b, current_vtx_per_prim, s->num_vertices_per_primitive - 1);
2957    nir_def *complete_flag = nir_b2i32(b, completes_prim);
2958 
2959    nir_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
2960    if (s->num_vertices_per_primitive == 3) {
2961       nir_def *odd = nir_iand(b, current_vtx_per_prim, complete_flag);
2962       nir_def *odd_flag = nir_ishl_imm(b, odd, 1);
2963       prim_flag = nir_ior(b, prim_flag, odd_flag);
2964    }
2965 
2966    nir_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr,
2967                     .base = s->lds_offs_primflags + stream,
2968                     .align_mul = 4, .align_offset = stream);
2969 
2970    nir_instr_remove(&intrin->instr);
2971    return true;
2972 }
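/* An illustrative example of the primflag values computed above (not generated code).
 * Assume a single triangle strip of 4 emitted vertices (current_vtx_per_prim == 0, 1,
 * 2, 3) with culling disabled, so the live bit (bit 2) is always set:
 *
 *    vertex 0: 0b100   (does not complete a primitive)
 *    vertex 1: 0b100   (does not complete a primitive)
 *    vertex 2: 0b101   (completes triangle 0, even -> bit 1 clear)
 *    vertex 3: 0b111   (completes triangle 1, odd  -> bit 1 set)
 *
 * Bit 1 is what later lets ngg_gs_export_primitives() restore the correct winding for
 * the odd triangles of the strip.
 */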
2973 
2974 static bool
2975 lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
2976 {
2977    b->cursor = nir_before_instr(&intrin->instr);
2978 
2979    /* These are not needed, we can simply remove them */
2980    nir_instr_remove(&intrin->instr);
2981    return true;
2982 }
2983 
2984 static bool
2985 lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2986 {
2987    b->cursor = nir_before_instr(&intrin->instr);
2988 
2989    unsigned stream = nir_intrinsic_stream_id(intrin);
2990    if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
2991       nir_instr_remove(&intrin->instr);
2992       return true;
2993    }
2994 
2995    s->vertex_count[stream] = intrin->src[0].ssa;
2996    s->primitive_count[stream] = intrin->src[1].ssa;
2997 
2998    /* Clear the primitive flags of non-emitted vertices */
2999    if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
3000       ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
3001 
3002    nir_instr_remove(&intrin->instr);
3003    return true;
3004 }
3005 
3006 static bool
3007 lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
3008 {
3009    lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;
3010 
3011    if (instr->type != nir_instr_type_intrinsic)
3012       return false;
3013 
3014    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
3015 
3016    if (intrin->intrinsic == nir_intrinsic_store_output)
3017       return lower_ngg_gs_store_output(b, intrin, s);
3018    else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
3019       return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
3020    else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
3021       return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
3022    else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
3023       return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);
3024 
3025    return false;
3026 }
3027 
3028 static void
3029 lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
3030 {
3031    nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
3032 }
3033 
3034 static void
3035 ngg_gs_export_primitives(nir_builder *b, nir_def *max_num_out_prims, nir_def *tid_in_tg,
3036                          nir_def *exporter_tid_in_tg, nir_def *primflag_0,
3037                          lower_ngg_gs_state *s)
3038 {
3039    nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
3040 
3041    /* Only bit 0 matters here - set it to 1 when the primitive should be null */
3042    nir_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));
3043 
3044    nir_def *vtx_indices[3] = {0};
3045    vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
3046    if (s->num_vertices_per_primitive >= 2)
3047       vtx_indices[s->num_vertices_per_primitive - 2] = nir_iadd_imm(b, exporter_tid_in_tg, -1);
3048    if (s->num_vertices_per_primitive == 3)
3049       vtx_indices[s->num_vertices_per_primitive - 3] = nir_iadd_imm(b, exporter_tid_in_tg, -2);
3050 
3051    if (s->num_vertices_per_primitive == 3) {
3052       /* API GS outputs triangle strips, but NGG HW understands triangles.
3053        * We already know the triangles due to how we set the primitive flags, but we need to
3054        * make sure the vertex order is such that the front/back face is correct and the provoking vertex is kept.
3055        */
3056 
3057       nir_def *is_odd = nir_ubfe_imm(b, primflag_0, 1, 1);
3058       nir_def *provoking_vertex_index = nir_load_provoking_vtx_in_prim_amd(b);
3059       nir_def *provoking_vertex_first = nir_ieq_imm(b, provoking_vertex_index, 0);
3060 
3061       vtx_indices[0] = nir_bcsel(b, provoking_vertex_first, vtx_indices[0],
3062                                  nir_iadd(b, vtx_indices[0], is_odd));
3063       vtx_indices[1] = nir_bcsel(b, provoking_vertex_first,
3064                                  nir_iadd(b, vtx_indices[1], is_odd),
3065                                  nir_isub(b, vtx_indices[1], is_odd));
3066       vtx_indices[2] = nir_bcsel(b, provoking_vertex_first,
3067                                  nir_isub(b, vtx_indices[2], is_odd), vtx_indices[2]);
3068    }
3069 
3070    nir_def *arg = ac_nir_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices,
3071                                              is_null_prim, s->options->gfx_level);
3072    ac_nir_export_primitive(b, arg, NULL);
3073    nir_pop_if(b, if_prim_export_thread);
3074 }
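/* A worked example (illustrative only) of the odd-triangle reordering above. For
 * exporter_tid_in_tg == 5 the base indices are {3, 4, 5}. With is_odd == 1:
 *
 *    provoking vertex first: {3, 5, 4}   (vertices 1 and 2 swapped, provoking vtx 3 kept)
 *    provoking vertex last:  {4, 3, 5}   (vertices 0 and 1 swapped, provoking vtx 5 kept)
 *
 * Either way the winding flips back to what the strip would have produced, while the
 * provoking vertex stays in place. With is_odd == 0 the indices are unchanged.
 */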
3075 
3076 static void
3077 ngg_gs_export_vertices(nir_builder *b, nir_def *max_num_out_vtx, nir_def *tid_in_tg,
3078                        nir_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
3079 {
3080    nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
3081    nir_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;
3082 
3083    if (!s->output_compile_time_known) {
3084       /* Vertex compaction.
3085        * The current thread will export a vertex that was live in another invocation.
3086        * Load the index of the vertex that the current thread will have to export.
3087        */
3088       nir_def *exported_vtx_idx = nir_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1);
3089       exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
3090    }
3091 
3092    u_foreach_bit64(slot, b->shader->info.outputs_written) {
3093       const unsigned packed_location =
3094          util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
3095 
3096       unsigned mask = gs_output_component_mask_with_stream(&s->out.infos[slot], 0);
3097 
3098       while (mask) {
3099          int start, count;
3100          u_bit_scan_consecutive_range(&mask, &start, &count);
3101          nir_def *load =
3102             nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
3103                             .base = packed_location * 16 + start * 4,
3104                             .align_mul = 4);
3105 
3106          for (int i = 0; i < count; i++)
3107             s->out.outputs[slot][start + i] = nir_channel(b, load, i);
3108       }
3109    }
3110 
3111    const unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
3112 
3113    /* Dedicated 16-bit outputs. */
3114    u_foreach_bit(i, b->shader->info.outputs_written_16bit) {
3115       const unsigned packed_location = num_32bit_outputs +
3116          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(i));
3117 
3118       const unsigned mask_lo = gs_output_component_mask_with_stream(&s->out.infos_16bit_lo[i], 0);
3119       const unsigned mask_hi = gs_output_component_mask_with_stream(&s->out.infos_16bit_hi[i], 0);
3120       unsigned mask = mask_lo | mask_hi;
3121 
3122       while (mask) {
3123          int start, count;
3124          u_bit_scan_consecutive_range(&mask, &start, &count);
3125          nir_def *load =
3126             nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
3127                             .base = packed_location * 16 + start * 4,
3128                             .align_mul = 4);
3129 
3130          for (int j = 0; j < count; j++) {
3131             nir_def *val = nir_channel(b, load, j);
3132             unsigned comp = start + j;
3133 
3134             if (mask_lo & BITFIELD_BIT(comp))
3135                s->out.outputs_16bit_lo[i][comp] = nir_unpack_32_2x16_split_x(b, val);
3136 
3137             if (mask_hi & BITFIELD_BIT(comp))
3138                s->out.outputs_16bit_hi[i][comp] = nir_unpack_32_2x16_split_y(b, val);
3139          }
3140       }
3141    }
3142 
3143    uint64_t export_outputs = b->shader->info.outputs_written | VARYING_BIT_POS;
3144    if (s->options->kill_pointsize)
3145       export_outputs &= ~VARYING_BIT_PSIZ;
3146    if (s->options->kill_layer)
3147       export_outputs &= ~VARYING_BIT_LAYER;
3148 
3149    const bool wait_attr_ring = must_wait_attr_ring(s->options->gfx_level, s->options->has_param_exports);
3150    if (wait_attr_ring)
3151       export_outputs &= ~VARYING_BIT_POS;
3152 
3153    ac_nir_export_position(b, s->options->gfx_level,
3154                           s->options->clip_cull_dist_mask,
3155                           !s->options->has_param_exports,
3156                           s->options->force_vrs, !wait_attr_ring,
3157                           export_outputs, &s->out, NULL);
3158 
3159    if (s->options->has_param_exports && s->options->gfx_level < GFX11) {
3160       /* Emit vertex parameter exports.
3161        * Only the vertex export threads should do this.
3162        */
3163       ac_nir_export_parameters(b, s->options->vs_output_param_offset,
3164                                b->shader->info.outputs_written,
3165                                b->shader->info.outputs_written_16bit,
3166                                &s->out);
3167    }
3168 
3169    nir_pop_if(b, if_vtx_export_thread);
3170 
3171    if (s->options->has_param_exports && s->options->gfx_level >= GFX11) {
3172       /* Store vertex parameters to attribute ring.
3173        * For optimal attribute ring access, this should happen in top level CF.
3174        */
3175       create_output_phis(b, b->shader->info.outputs_written, b->shader->info.outputs_written_16bit, &s->out);
3176       ac_nir_store_parameters_to_attr_ring(b, s->options->vs_output_param_offset,
3177                                            b->shader->info.outputs_written,
3178                                            b->shader->info.outputs_written_16bit,
3179                                            &s->out, tid_in_tg, max_num_out_vtx);
3180 
3181       if (wait_attr_ring)
3182          export_pos0_wait_attr_ring(b, if_vtx_export_thread, s->out.outputs, s->options);
3183    }
3184 }
3185 
3186 static void
3187 ngg_gs_setup_vertex_compaction(nir_builder *b, nir_def *vertex_live, nir_def *tid_in_tg,
3188                                nir_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
3189 {
3190    assert(vertex_live->bit_size == 1);
3191    nir_if *if_vertex_live = nir_push_if(b, vertex_live);
3192    {
3193       /* Set up the vertex compaction.
3194        * Save the current thread's id for the thread which will export the current vertex.
3195        * We reuse the stream-1 primitive flag byte of the other thread's vertex to store this.
3196        */
3197 
3198       nir_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
3199       nir_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
3200       nir_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1);
3201    }
3202    nir_pop_if(b, if_vertex_live);
3203 }
3204 
3205 static nir_def *
3206 ngg_gs_load_out_vtx_primflag(nir_builder *b, unsigned stream, nir_def *tid_in_tg,
3207                              nir_def *vtx_lds_addr, nir_def *max_num_out_vtx,
3208                              lower_ngg_gs_state *s)
3209 {
3210    nir_def *zero = nir_imm_int(b, 0);
3211 
3212    nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
3213    nir_def *primflag = nir_load_shared(b, 1, 8, vtx_lds_addr,
3214                                            .base = s->lds_offs_primflags + stream);
3215    primflag = nir_u2u32(b, primflag);
3216    nir_pop_if(b, if_outvtx_thread);
3217 
3218    return nir_if_phi(b, primflag, zero);
3219 }
3220 
3221 static void
3222 ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_def *last_vtxidx, nir_def *last_vtxptr,
3223                            nir_def *last_vtx_primflag, lower_ngg_gs_state *s,
3224                            nir_def *vtxptr[3])
3225 {
3226    unsigned last_vtx = s->num_vertices_per_primitive - 1;
3227    vtxptr[last_vtx]= last_vtxptr;
3228 
3229    bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
3230    nir_def *is_odd = primitive_is_triangle ?
3231       nir_ubfe_imm(b, last_vtx_primflag, 1, 1) : NULL;
3232 
3233    for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
3234       nir_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));
3235 
3236       /* Need to swap vertex 0 and vertex 1 when vertex 2 index is odd to keep
3237        * CW/CCW order for correct front/back face culling.
3238        */
3239       if (primitive_is_triangle)
3240          vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);
3241 
3242       vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
3243    }
3244 }
3245 
3246 static nir_def *
3247 ngg_gs_cull_primitive(nir_builder *b, nir_def *tid_in_tg, nir_def *max_vtxcnt,
3248                       nir_def *out_vtx_lds_addr, nir_def *out_vtx_primflag_0,
3249                       lower_ngg_gs_state *s)
3250 {
3251    /* Point culling is not enabled; if it were, this function could be further optimized. */
3252    assert(s->num_vertices_per_primitive > 1);
3253 
3254    /* save the primflag so that we don't need to load it from LDS again */
3255    nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
3256    nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);
3257 
3258    /* Bit 0 of primflag indicates whether this is the final vertex of a primitive. */
3259    nir_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
3260    nir_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
3261    nir_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);
3262 
3263    nir_if *if_prim_enable = nir_push_if(b, prim_enable);
3264    {
3265       /* Calculate the LDS address of every vertex in the current primitive. */
3266       nir_def *vtxptr[3];
3267       ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);
3268 
3269       /* Load the positions from LDS. */
3270       nir_def *pos[3][4];
3271       for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
3272          /* VARYING_SLOT_POS == 0, so base won't count packed location */
3273          pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
3274          nir_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
3275          pos[i][0] = nir_channel(b, xy, 0);
3276          pos[i][1] = nir_channel(b, xy, 1);
3277 
3278          pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
3279          pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
3280       }
3281 
3282       /* TODO: support clipdist culling in GS */
3283       nir_def *accepted_by_clipdist = nir_imm_true(b);
3284 
3285       nir_def *accepted = ac_nir_cull_primitive(
3286          b, accepted_by_clipdist, pos, s->num_vertices_per_primitive, NULL, NULL);
3287 
3288       nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
3289       {
3290          /* clear the primflag if rejected */
3291          nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
3292                           .base = s->lds_offs_primflags);
3293 
3294          nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
3295       }
3296       nir_pop_if(b, if_rejected);
3297    }
3298    nir_pop_if(b, if_prim_enable);
3299 
3300    /* Wait for the LDS primflag accesses to finish. */
3301    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
3302                          .memory_scope = SCOPE_WORKGROUP,
3303                          .memory_semantics = NIR_MEMORY_ACQ_REL,
3304                          .memory_modes = nir_var_mem_shared);
3305 
3306    /* Only dead vertices need a chance to be revived. */
3307    nir_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
3308    nir_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
3309    nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
3310    {
3311       /* get succeeding vertices' primflag to detect this vertex's liveness */
3312       for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
3313          nir_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
3314          nir_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
3315          nir_if *if_not_overflow = nir_push_if(b, not_overflow);
3316          {
3317             nir_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
3318             nir_def *vtx_primflag =
3319                nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
3320             vtx_primflag = nir_u2u32(b, vtx_primflag);
3321 
3322             /* If the succeeding vertex is a live end-of-primitive vertex, we need to set
3323              * the current thread's vertex liveness flag (bit 2).
3324              */
3325             nir_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
3326             nir_def *vtx_live_flag =
3327                nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));
3328 
3329             /* update this vertex's primflag */
3330             nir_def *primflag = nir_load_var(b, primflag_var);
3331             primflag = nir_ior(b, primflag, vtx_live_flag);
3332             nir_store_var(b, primflag_var, primflag, 1);
3333          }
3334          nir_pop_if(b, if_not_overflow);
3335       }
3336    }
3337    nir_pop_if(b, if_update_primflag);
3338 
3339    return nir_load_var(b, primflag_var);
3340 }
3341 
3342 static void
3343 ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s)
3344 {
3345    nir_xfb_info *info = ac_nir_get_sorted_xfb_info(b->shader);
3346 
3347    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
3348    nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
3349    nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
3350    nir_def *prim_live[4] = {0};
3351    nir_def *gen_prim[4] = {0};
3352    nir_def *export_seq[4] = {0};
3353    nir_def *out_vtx_primflag[4] = {0};
3354    for (unsigned stream = 0; stream < 4; stream++) {
3355       if (!(info->streams_written & BITFIELD_BIT(stream)))
3356          continue;
3357 
3358       out_vtx_primflag[stream] =
3359          ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
3360 
3361       /* Check bit 0 of primflag to see whether the primitive is alive; it is set
3362        * for the last vertex of every primitive.
3363        */
3364       prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));
3365 
3366       unsigned scratch_stride = ALIGN(s->max_num_waves, 4);
3367       nir_def *scratch_base =
3368          nir_iadd_imm(b, s->lds_addr_gs_scratch, stream * scratch_stride);
3369 
3370       /* We want to export primitives to the streamout buffer in sequence,
3371        * but not all vertices are alive or mark the end of a primitive, so
3372        * there are "holes". Unlike the final vertex export, we don't need
3373        * contiguous invocations to write primitives to the streamout buffer,
3374        * so just repacking to get the sequence (export_seq) is enough; no
3375        * compaction is needed. See the worked example after this function.
3376        *
3377        * Use separate scratch space for each stream to avoid a barrier.
3378        * TODO: we may further reduce barriers by writing the LDS of all
3379        * streams at once; then we would only need one barrier instead of
3380        * one per stream.
3381        */
3382       wg_repack_result rep = {0};
3383       repack_invocations_in_workgroup(b, &prim_live[stream], &rep, 1, scratch_base,
3384                                       s->max_num_waves, s->options->wave_size);
3385 
3386       /* nir_intrinsic_set_vertex_and_primitive_count can also provide the primitive count
3387        * of the current wave, but we would still need LDS to sum all waves' counts into the
3388        * workgroup count. We need the repack to export primitives to the streamout buffer anyway, so do it here.
3389        */
3390       gen_prim[stream] = rep.num_repacked_invocations;
3391       export_seq[stream] = rep.repacked_invocation_index;
3392    }
3393 
3394    /* Workgroup barrier: wait for LDS scratch reads to finish. */
3395    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
3396                       .memory_scope = SCOPE_WORKGROUP,
3397                       .memory_semantics = NIR_MEMORY_ACQ_REL,
3398                       .memory_modes = nir_var_mem_shared);
3399 
3400    /* Get global buffer offset where this workgroup will stream out data to. */
3401    nir_def *emit_prim[4] = {0};
3402    nir_def *buffer_offsets[4] = {0};
3403    nir_def *so_buffer[4] = {0};
3404    ngg_build_streamout_buffer_info(b, info, s->options->gfx_level, s->options->has_xfb_prim_query,
3405                                    s->options->use_gfx12_xfb_intrinsic, s->lds_addr_gs_scratch, tid_in_tg,
3406                                    gen_prim, so_buffer, buffer_offsets, emit_prim);
3407 
3408    for (unsigned stream = 0; stream < 4; stream++) {
3409       if (!(info->streams_written & BITFIELD_BIT(stream)))
3410          continue;
3411 
3412       nir_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
3413       nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
3414       {
3415          /* Get streamout buffer vertex index for the first vertex of this primitive. */
3416          nir_def *first_vertex_idx =
3417             nir_imul_imm(b, export_seq[stream], s->num_vertices_per_primitive);
3418          nir_def *stream_buffer_offsets[NIR_MAX_XFB_BUFFERS];
3419 
3420          u_foreach_bit(buffer, info->buffers_written) {
3421             stream_buffer_offsets[buffer] = nir_iadd(b, buffer_offsets[buffer],
3422                                                      nir_imul_imm(b, first_vertex_idx,
3423                                                                   info->buffers[buffer].stride));
3424          }
3425 
3426          /* Get all vertices' lds address of this primitive. */
3427          nir_def *exported_vtx_lds_addr[3];
3428          ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
3429                                     out_vtx_primflag[stream], s,
3430                                     exported_vtx_lds_addr);
3431 
3432          /* Write all vertices of this primitive to streamout buffer. */
3433          for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
3434             ngg_build_streamout_vertex(b, info, stream, so_buffer,
3435                                        stream_buffer_offsets, i,
3436                                        exported_vtx_lds_addr[i],
3437                                        &s->out, false);
3438          }
3439       }
3440       nir_pop_if(b, if_emit);
3441    }
3442 }
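/* A worked example (illustrative only) of the repack referenced in the comment inside
 * this function. Suppose 8 invocations in a stream have
 * prim_live = {1, 0, 1, 1, 0, 0, 1, 0}. Then:
 *
 *    gen_prim[stream]   == 4                          (number of live primitives)
 *    export_seq[stream] == {0, -, 1, 2, -, -, 3, -}   (per live invocation)
 *
 * Each live invocation with export_seq < emit_prim writes its primitive starting at
 * streamout vertex index export_seq * num_vertices_per_primitive, so the buffer
 * contents end up densely packed even though the live invocations are not contiguous.
 */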
3443 
3444 static void
3445 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
3446 {
3447    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
3448    nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
3449    nir_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
3450    nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
3451 
3452    if (s->output_compile_time_known) {
3453       /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
3454        * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
3455        */
3456       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
3457       alloc_vertices_and_primitives(b, max_vtxcnt, max_prmcnt);
3458       nir_pop_if(b, if_wave_0);
3459    }
3460 
3461    /* Workgroup barrier already emitted, we can assume all GS output stores are done by now. */
3462 
3463    nir_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
3464 
3465    if (s->output_compile_time_known) {
3466       ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
3467       ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
3468       return;
3469    }
3470 
3471    /* cull primitives */
3472    if (s->options->can_cull) {
3473       nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
3474 
3475       /* culling code will update the primflag */
3476       nir_def *updated_primflag =
3477          ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
3478                                out_vtx_primflag_0, s);
3479 
3480       nir_pop_if(b, if_cull_en);
3481 
3482       out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
3483    }
3484 
3485    /* When the output vertex count is not known at compile time:
3486     * There may be gaps between invocations that have live vertices, but NGG hardware
3487     * requires that the invocations that export vertices are packed (ie. compact).
3488     * To ensure this, we need to repack invocations that have a live vertex.
3489     */
3490    nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0);
3491    wg_repack_result rep = {0};
3492 
3493    repack_invocations_in_workgroup(b, &vertex_live, &rep, 1, s->lds_addr_gs_scratch,
3494                                    s->max_num_waves, s->options->wave_size);
3495 
3496    nir_def *workgroup_num_vertices = rep.num_repacked_invocations;
3497    nir_def *exporter_tid_in_tg = rep.repacked_invocation_index;
3498 
3499    /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
3500    nir_def *any_output = nir_ine_imm(b, workgroup_num_vertices, 0);
3501    max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));
3502 
3503    /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
3504    nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
3505    {
3506       if (s->options->gfx_level == GFX10)
3507          alloc_vertices_and_primitives_gfx10_workaround(b, workgroup_num_vertices, max_prmcnt);
3508       else
3509          alloc_vertices_and_primitives(b, workgroup_num_vertices, max_prmcnt);
3510    }
3511    nir_pop_if(b, if_wave_0);
3512 
3513    /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
3514    ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);
3515 
3516    /* Workgroup barrier: wait for all LDS stores to finish. */
3517    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3518                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3519 
3520    ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
3521    ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
3522 }
3523 
3524 void
3525 ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
3526 {
3527    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
3528    assert(impl);
3529 
3530    lower_ngg_gs_state state = {
3531       .options = options,
3532       .impl = impl,
3533       .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
3534       .lds_offs_primflags = options->gs_out_vtx_bytes,
3535       .lds_bytes_per_gs_out_vertex = options->gs_out_vtx_bytes + 4u,
3536       .streamout_enabled = shader->xfb_info && !options->disable_streamout,
3537    };
3538 
3539    if (!options->can_cull) {
3540       nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
3541                                            state.const_out_prmcnt, NULL, 4u);
3542       state.output_compile_time_known =
3543          state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
3544          state.const_out_prmcnt[0] != -1;
3545    }
3546 
3547    if (shader->info.gs.output_primitive == MESA_PRIM_POINTS)
3548       state.num_vertices_per_primitive = 1;
3549    else if (shader->info.gs.output_primitive == MESA_PRIM_LINE_STRIP)
3550       state.num_vertices_per_primitive = 2;
3551    else if (shader->info.gs.output_primitive == MESA_PRIM_TRIANGLE_STRIP)
3552       state.num_vertices_per_primitive = 3;
3553    else
3554       unreachable("Invalid GS output primitive.");
3555 
3556    /* Extract the full control flow. It is going to be wrapped in an if statement. */
3557    nir_cf_list extracted;
3558    nir_cf_extract(&extracted, nir_before_impl(impl),
3559                   nir_after_impl(impl));
3560 
3561    nir_builder builder = nir_builder_at(nir_before_impl(impl));
3562    nir_builder *b = &builder; /* This is to avoid the & */
3563 
3564    /* Workgroup barrier: wait for ES threads */
3565    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3566                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3567 
3568    state.lds_addr_gs_out_vtx = nir_load_lds_ngg_gs_out_vertex_base_amd(b);
3569    state.lds_addr_gs_scratch = nir_load_lds_ngg_scratch_base_amd(b);
3570 
3571    /* Wrap the GS control flow. */
3572    nir_if *if_gs_thread = nir_push_if(b, has_input_primitive(b));
3573 
3574    nir_cf_reinsert(&extracted, b->cursor);
3575    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
3576    nir_pop_if(b, if_gs_thread);
3577 
3578    /* Workgroup barrier: wait for all GS threads to finish */
3579    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3580                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3581 
3582    if (state.streamout_enabled)
3583       ngg_gs_build_streamout(b, &state);
3584 
3585    /* Lower the GS intrinsics */
3586    lower_ngg_gs_intrinsics(shader, &state);
3587 
3588    if (!state.vertex_count[0]) {
3589       fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.");
3590       abort();
3591    }

   /* Emit shader queries */
   b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
   ac_nir_gs_shader_query(b,
                          state.options->has_gen_prim_query,
                          state.options->has_gs_invocations_query,
                          state.options->has_gs_primitives_query,
                          state.num_vertices_per_primitive,
                          state.options->wave_size,
                          state.vertex_count,
                          state.primitive_count);

   b->cursor = nir_after_impl(impl);

   /* Emit the finale sequence */
   ngg_gs_finale(b, &state);
   nir_validate_shader(shader, "after emitting NGG GS");

   /* Cleanup */
   nir_lower_vars_to_ssa(shader);
   nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
   nir_metadata_preserve(impl, nir_metadata_none);
}

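/*
 * Return the per-vertex LDS size (in bytes) needed by an NGG VS/TES (no GS).
 *
 * This is the maximum of the culling-time layout (only when culling is
 * possible) and the regular pre-rasterization output layout.
 */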
unsigned
ac_ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
                                   unsigned shader_num_outputs,
                                   bool streamout_enabled,
                                   bool export_prim_id,
                                   bool has_user_edgeflags,
                                   bool can_cull,
                                   bool uses_instance_id,
                                   bool uses_primitive_id)
{
   /* LDS size for the culling-time layout only. */
   unsigned culling_pervertex_lds_bytes = can_cull ?
      ngg_nogs_get_culling_pervertex_lds_size(
         stage, uses_instance_id, uses_primitive_id, NULL) : 0;

   unsigned pervertex_lds_bytes =
      ngg_nogs_get_pervertex_lds_size(stage, shader_num_outputs, streamout_enabled,
                                      export_prim_id, has_user_edgeflags);

   return MAX2(culling_pervertex_lds_bytes, pervertex_lds_bytes);
}

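/*
 * Return the scratch LDS size (in bytes) needed by an NGG workgroup.
 *
 * For VS/TES this covers the streamout buffer offsets and primitive count,
 * or the repack buffers used by culling. For GS it covers the workgroup
 * repack buffer and, when streamout is enabled, the per-stream streamout data.
 */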
unsigned
ac_ngg_get_scratch_lds_size(gl_shader_stage stage,
                            unsigned workgroup_size,
                            unsigned wave_size,
                            bool streamout_enabled,
                            bool can_cull,
                            bool compact_primitives)
{
   unsigned scratch_lds_size = 0;
   unsigned max_num_waves = DIV_ROUND_UP(workgroup_size, wave_size);

   if (stage == MESA_SHADER_VERTEX || stage == MESA_SHADER_TESS_EVAL) {
      if (streamout_enabled) {
         /* 4 dwords for the 4 streamout buffer offsets, 1 dword for the emitted primitive count */
         scratch_lds_size = 20;
      } else if (can_cull) {
         /* 1 byte per wave per repack, max 8 waves */
         unsigned num_rep = compact_primitives ? 2 : 1;
         scratch_lds_size = ALIGN(max_num_waves, 4u) * num_rep;
      }
   } else {
      assert(stage == MESA_SHADER_GEOMETRY);

      scratch_lds_size = ALIGN(max_num_waves, 4u);
      /* Streamout takes 8 dwords for the buffer offsets and the emitted vertex count per stream. */
      if (streamout_enabled)
         scratch_lds_size = MAX2(scratch_lds_size, 32);
   }

   return scratch_lds_size;
}
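
/*
 * Worked example (illustrative only, values follow directly from the code
 * above): a GS workgroup of 256 invocations at wave_size 64 has
 * max_num_waves = 4, so without streamout the scratch LDS size is
 * ALIGN(4, 4) = 4 bytes; with streamout it becomes MAX2(4, 32) = 32 bytes.
 */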