1 /*
2  * Copyright © 2021 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  */
24 
25 #include "ac_nir.h"
26 #include "nir_builder.h"
27 #include "u_math.h"
28 #include "u_vector.h"
29 
30 enum {
31    nggc_passflag_used_by_pos = 1,
32    nggc_passflag_used_by_other = 2,
33    nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
34 };
35 
36 typedef struct
37 {
38    nir_ssa_def *ssa;
39    nir_variable *var;
40 } saved_uniform;
41 
42 typedef struct
43 {
44    nir_variable *position_value_var;
45    nir_variable *prim_exp_arg_var;
46    nir_variable *es_accepted_var;
47    nir_variable *gs_accepted_var;
48    nir_variable *gs_vtx_indices_vars[3];
49 
50    struct u_vector saved_uniforms;
51 
52    bool passthrough;
53    bool export_prim_id;
54    bool early_prim_export;
55    bool use_edgeflags;
56    bool has_prim_query;
57    bool can_cull;
58    unsigned wave_size;
59    unsigned max_num_waves;
60    unsigned num_vertices_per_primitives;
61    unsigned provoking_vtx_idx;
62    unsigned max_es_num_vertices;
63    unsigned total_lds_bytes;
64 
65    uint64_t inputs_needed_by_pos;
66    uint64_t inputs_needed_by_others;
67    uint32_t instance_rate_inputs;
68 
69    nir_instr *compact_arg_stores[4];
70    nir_intrinsic_instr *overwrite_args;
71 } lower_ngg_nogs_state;
72 
73 typedef struct
74 {
75    /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
76    uint8_t components_mask : 4;
77    /* output stream index  */
78    uint8_t stream : 2;
79 } gs_output_info;
80 
81 typedef struct
82 {
83    nir_variable *output_vars[VARYING_SLOT_MAX][4];
84    nir_variable *current_clear_primflag_idx_var;
85    int const_out_vtxcnt[4];
86    int const_out_prmcnt[4];
87    unsigned wave_size;
88    unsigned max_num_waves;
89    unsigned num_vertices_per_primitive;
90    unsigned lds_addr_gs_out_vtx;
91    unsigned lds_addr_gs_scratch;
92    unsigned lds_bytes_per_gs_out_vertex;
93    unsigned lds_offs_primflags;
94    bool found_out_vtxcnt[4];
95    bool output_compile_time_known;
96    bool provoking_vertex_last;
97    gs_output_info output_info[VARYING_SLOT_MAX];
98 } lower_ngg_gs_state;
99 
100 /* LDS layout of Mesh Shader workgroup info. */
101 enum {
102    /* DW0: number of primitives */
103    lds_ms_num_prims = 0,
104    /* DW1: reserved for future use */
105    lds_ms_dw1_reserved = 4,
106    /* DW2: workgroup index within the current dispatch */
107    lds_ms_wg_index = 8,
108    /* DW3: number of API workgroups in flight */
109    lds_ms_num_api_waves = 12,
110 };
111 
112 /* Potential locations for Mesh Shader outputs. */
113 typedef enum {
114    ms_out_mode_lds,
115    ms_out_mode_vram,
116    ms_out_mode_var,
117 } ms_out_mode;
118 
119 typedef struct
120 {
121    uint64_t mask; /* Mask of output locations */
122    uint32_t addr; /* Base address */
123 } ms_out_part;
124 
125 typedef struct
126 {
127    /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
128    struct {
129       uint32_t workgroup_info_addr;
130       ms_out_part vtx_attr;
131       ms_out_part prm_attr;
132       uint32_t indices_addr;
133       uint32_t total_size;
134    } lds;
135    /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS. */
136    struct {
137       ms_out_part vtx_attr;
138       ms_out_part prm_attr;
139    } vram;
140    /* Outputs without cross-invocation access can be stored in variables. */
141    struct {
142       ms_out_part vtx_attr;
143       ms_out_part prm_attr;
144    } var;
145 } ms_out_mem_layout;
146 
147 typedef struct
148 {
149    ms_out_mem_layout layout;
150    uint64_t per_vertex_outputs;
151    uint64_t per_primitive_outputs;
152    unsigned vertices_per_prim;
153 
154    unsigned wave_size;
155    unsigned api_workgroup_size;
156    unsigned hw_workgroup_size;
157 
158    nir_ssa_def *workgroup_index;
159    nir_variable *out_variables[VARYING_SLOT_MAX * 4];
160 
161    /* True if the lowering needs to insert the layer output. */
162    bool insert_layer_output;
163 
164    struct {
165       /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
166       uint32_t components_mask;
167    } output_info[VARYING_SLOT_MAX];
168 } lower_ngg_ms_state;
169 
170 typedef struct {
171    nir_variable *pre_cull_position_value_var;
172 } remove_culling_shader_outputs_state;
173 
174 typedef struct {
175    nir_variable *pos_value_replacement;
176 } remove_extra_position_output_state;
177 
178 /* Per-vertex LDS layout of culling shaders */
179 enum {
180    /* Position of the ES vertex (at the beginning for alignment reasons) */
181    lds_es_pos_x = 0,
182    lds_es_pos_y = 4,
183    lds_es_pos_z = 8,
184    lds_es_pos_w = 12,
185 
186    /* 1 when the vertex is accepted, 0 if it should be culled */
187    lds_es_vertex_accepted = 16,
188    /* ID of the thread which will export the current thread's vertex */
189    lds_es_exporter_tid = 17,
190 
191    /* Repacked arguments - also listed separately for VS and TES */
192    lds_es_arg_0 = 20,
193 
194    /* VS arguments which need to be repacked */
195    lds_es_vs_vertex_id = 20,
196    lds_es_vs_instance_id = 24,
197 
198    /* TES arguments which need to be repacked */
199    lds_es_tes_u = 20,
200    lds_es_tes_v = 24,
201    lds_es_tes_rel_patch_id = 28,
202    lds_es_tes_patch_id = 32,
203 };
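/* Rough per-vertex LDS footprint, derived from add_deferred_attribute_culling below,
 * which computes pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4:
 *
 *   VS reading VertexID only:           20 + 1 * 4 = 24 bytes per vertex
 *   VS reading VertexID + InstanceID:   20 + 2 * 4 = 28 bytes per vertex
 *   TES without PrimitiveID:            20 + 3 * 4 = 32 bytes per vertex
 *   TES with PrimitiveID:               20 + 4 * 4 = 36 bytes per vertex
 */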
204 
205 typedef struct {
206    nir_ssa_def *num_repacked_invocations;
207    nir_ssa_def *repacked_invocation_index;
208 } wg_repack_result;
209 
210 /**
211  * Computes a horizontal sum of 8-bit packed values loaded from LDS.
212  *
213  * Each lane N will sum packed bytes 0 to N-1.
214  * We only care about the results from up to wave_id+1 lanes.
215  * (Other lanes are not deactivated but their calculation is not used.)
216  */
217 static nir_ssa_def *
218 summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords)
219 {
220    /* We'll use shift to filter out the bytes not needed by the current lane.
221     *
222     * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
223     * However, two shifts are needed because one can't go all the way,
224     * so the shift amount is half that (and in bits).
225     *
226     * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
227     * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
228     * therefore v_dot can get rid of the unneeded values.
229     * This sequence is preferable because it better hides the latency of the LDS.
230     *
231     * If the v_dot instruction can't be used, we left-shift the packed bytes.
232     * This will shift out the unneeded bytes and shift in zeroes instead,
233     * then we sum them using v_sad_u8.
234     */
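   /* Worked example, assuming num_lds_dwords == 1 and hypothetical per-wave
    * counts 0x0D, 0x0C, 0x0B, 0x0A packed into LDS as 0x0A0B0C0D:
    *
    *   Lane 2 needs the sum of waves 0..1, so it discards 4 - 2 = 2 bytes,
    *   i.e. shifts by 16 bits in total (two shifts of 8 bits each).
    *   - SAD path: 0x0A0B0C0D << 16 = 0x0C0D0000, then
    *     v_sad_u8(0x0C0D0000, 0, 0) = 0x0C + 0x0D = 0x19.
    *   - DOT path: 0x01010101 >> 16 = 0x00000101, then
    *     v_dot4_u32_u8(0x0A0B0C0D, 0x00000101, 0) = 0x0D + 0x0C = 0x19.
    */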
235 
236    nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
237    nir_ssa_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
238    bool use_dot = b->shader->options->has_udot_4x8;
239 
240    if (num_lds_dwords == 1) {
241       nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
242 
243       /* Broadcast the packed data we read from LDS (to the first 16 lanes, though only the first num_waves lanes matter). */
244       nir_ssa_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
245 
246       /* Horizontally add the packed bytes. */
247       if (use_dot) {
248          return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
249       } else {
250          nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
251          return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
252       }
253    } else if (num_lds_dwords == 2) {
254       nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
255 
256       /* Broadcast the packed data we read from LDS (to the first 16 lanes, though only the first num_waves lanes matter). */
257       nir_ssa_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
258       nir_ssa_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
259 
260       /* Horizontally add the packed bytes. */
261       if (use_dot) {
262          nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
263          return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
264       } else {
265          nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
266          nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
267          return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
268       }
269    } else {
270       unreachable("Unimplemented NGG wave count");
271    }
272 }
273 
274 /**
275  * Repacks invocations in the current workgroup to eliminate gaps between them.
276  *
277  * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
278  * Assumes that all invocations in the workgroup are active (exec = -1).
279  */
280 static wg_repack_result
281 repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
282                                 unsigned lds_addr_base, unsigned max_num_waves,
283                                 unsigned wave_size)
284 {
285    /* Input boolean: 1 if the current invocation should survive the repack. */
286    assert(input_bool->bit_size == 1);
287 
288    /* STEP 1. Count surviving invocations in the current wave.
289     *
290     * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
291     */
292 
293    nir_ssa_def *input_mask = nir_ballot(b, 1, wave_size, input_bool);
294    nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
295 
296    /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
297    if (max_num_waves == 1) {
298       wg_repack_result r = {
299          .num_repacked_invocations = surviving_invocations_in_current_wave,
300          .repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
301       };
302       return r;
303    }
304 
305    /* STEP 2. Waves tell each other their number of surviving invocations.
306     *
307     * Each wave activates only its first lane (exec = 1), which stores the number of surviving
308     * invocations in that wave into the LDS, then reads the numbers from every wave.
309     *
310     * The workgroup size of NGG shaders is at most 256, which means
311     * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
312     * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
313     */
314 
315    const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
316    assert(num_lds_dwords <= 2);
317 
318    nir_ssa_def *wave_id = nir_load_subgroup_id(b);
319    nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32);
320    nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1));
321 
322    nir_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base);
323 
324    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
325                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
326 
327    nir_ssa_def *packed_counts = nir_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u);
328 
329    nir_pop_if(b, if_first_lane);
330 
331    packed_counts = nir_if_phi(b, packed_counts, dont_care);
332 
333    /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
334     *
335     * By now, every wave knows the number of surviving invocations in all waves.
336     * Each number is 1 byte, and they are packed into up to 2 dwords.
337     *
338     * Each lane N will sum the number of surviving invocations from waves 0 to N-1.
339     * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
340     * (Other lanes are not deactivated but their calculation is not used.)
341     *
342     * - We read the sum from the lane whose id is the current wave's id.
343     *   Add the masked bitcount to this, and we get the repacked invocation index.
344     * - We read the sum from the lane whose id is the number of waves in the workgroup.
345     *   This is the total number of surviving invocations in the workgroup.
346     */
347 
348    nir_ssa_def *num_waves = nir_load_num_subgroups(b);
349    nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
350 
351    nir_ssa_def *wg_repacked_index_base = nir_read_invocation(b, sum, wave_id);
352    nir_ssa_def *wg_num_repacked_invocations = nir_read_invocation(b, sum, num_waves);
353    nir_ssa_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
354 
355    wg_repack_result r = {
356       .num_repacked_invocations = wg_num_repacked_invocations,
357       .repacked_invocation_index = wg_repacked_index,
358    };
359 
360    return r;
361 }
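/* Example use (mirroring the culling path further below): repack the ES threads
 * that survived culling so that surviving vertices get contiguous indices:
 *
 *    wg_repack_result rep =
 *       repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr,
 *                                       max_num_waves, wave_size);
 *    nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
 *    nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
 */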
362 
363 static nir_ssa_def *
364 pervertex_lds_addr(nir_builder *b, nir_ssa_def *vertex_idx, unsigned per_vtx_bytes)
365 {
366    return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
367 }
368 
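/* Packs the NGG primitive export argument.
 *
 * As built here, vertex index i is shifted to bit offset 10 * i,
 * the initial edge flags (when used) come from nir_load_initial_edgeflags_amd,
 * and bit 31 marks a null (culled) primitive. The exact field widths are
 * dictated by the hardware's primitive export encoding.
 */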
369 static nir_ssa_def *
370 emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
371                            nir_ssa_def *vertex_indices[3], nir_ssa_def *is_null_prim,
372                            bool use_edgeflags)
373 {
374    nir_ssa_def *arg = use_edgeflags
375                       ? nir_load_initial_edgeflags_amd(b)
376                       : nir_imm_int(b, 0);
377 
378    for (unsigned i = 0; i < num_vertices_per_primitives; ++i) {
379       assert(vertex_indices[i]);
380       arg = nir_ior(b, arg, nir_ishl(b, vertex_indices[i], nir_imm_int(b, 10u * i)));
381    }
382 
383    if (is_null_prim) {
384       if (is_null_prim->bit_size == 1)
385          is_null_prim = nir_b2i32(b, is_null_prim);
386       assert(is_null_prim->bit_size == 32);
387       arg = nir_ior(b, arg, nir_ishl(b, is_null_prim, nir_imm_int(b, 31u)));
388    }
389 
390    return arg;
391 }
392 
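/* Initializes the variables holding the GS vertex indices of the input primitive.
 *
 * The nir_load_gs_vertex_offset_amd arguments pack two 16-bit ES vertex indices
 * per dword, which is why index v is extracted from dword v / 2 at bit offset
 * (v & 1) * 16.
 */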
393 static void
394 ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *st)
395 {
396    for (unsigned v = 0; v < st->num_vertices_per_primitives; ++v) {
397       st->gs_vtx_indices_vars[v] = nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx_addr");
398 
399       nir_ssa_def *vtx = nir_ubfe(b, nir_load_gs_vertex_offset_amd(b, .base = v / 2u),
400                          nir_imm_int(b, (v & 1u) * 16u), nir_imm_int(b, 16u));
401       nir_store_var(b, st->gs_vtx_indices_vars[v], vtx, 0x1);
402    }
403 }
404 
405 static nir_ssa_def *
406 emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st)
407 {
408    if (st->passthrough) {
409       assert(!st->export_prim_id || b->shader->info.stage != MESA_SHADER_VERTEX);
410       return nir_load_packed_passthrough_primitive_amd(b);
411    } else {
412       nir_ssa_def *vtx_idx[3] = {0};
413 
414       for (unsigned v = 0; v < st->num_vertices_per_primitives; ++v)
415          vtx_idx[v] = nir_load_var(b, st->gs_vtx_indices_vars[v]);
416 
417       return emit_pack_ngg_prim_exp_arg(b, st->num_vertices_per_primitives, vtx_idx, NULL, st->use_edgeflags);
418    }
419 }
420 
421 static void
422 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def *arg)
423 {
424    nir_ssa_def *gs_thread = st->gs_accepted_var
425                             ? nir_load_var(b, st->gs_accepted_var)
426                             : nir_has_input_primitive_amd(b);
427 
428    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
429    {
430       if (!arg)
431          arg = emit_ngg_nogs_prim_exp_arg(b, st);
432 
433       if (st->has_prim_query) {
434          nir_if *if_shader_query = nir_push_if(b, nir_load_shader_query_enabled_amd(b));
435          {
436             /* Number of active GS threads. Each has 1 output primitive. */
437             nir_ssa_def *num_gs_threads = nir_bit_count(b, nir_ballot(b, 1, st->wave_size, nir_imm_bool(b, true)));
438             /* Activate only 1 lane and add the number of primitives to GDS. */
439             nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
440             {
441                /* Use a different GDS offset than NGG GS to ensure that pipeline statistics
442                 * queries won't return the number of primitives generated by VS/TES.
443                 */
444                nir_gds_atomic_add_amd(b, 32, num_gs_threads, nir_imm_int(b, 4), nir_imm_int(b, 0x100));
445             }
446             nir_pop_if(b, if_elected);
447          }
448          nir_pop_if(b, if_shader_query);
449       }
450 
451       nir_export_primitive_amd(b, arg);
452    }
453    nir_pop_if(b, if_gs_thread);
454 }
455 
456 static void
457 emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *st)
458 {
459    nir_ssa_def *gs_thread = st->gs_accepted_var ?
460       nir_load_var(b, st->gs_accepted_var) : nir_has_input_primitive_amd(b);
461 
462    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
463    {
464       /* Copy Primitive IDs from GS threads to the LDS address
465        * corresponding to the ES thread of the provoking vertex.
466        * It will be exported as a per-vertex attribute.
467        */
468       nir_ssa_def *prim_id = nir_load_primitive_id(b);
469       nir_ssa_def *provoking_vtx_idx = nir_load_var(b, st->gs_vtx_indices_vars[st->provoking_vtx_idx]);
470       nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, 4u);
471 
472       nir_store_shared(b, prim_id, addr);
473    }
474    nir_pop_if(b, if_gs_thread);
475 }
476 
477 static void
478 emit_store_ngg_nogs_es_primitive_id(nir_builder *b)
479 {
480    nir_ssa_def *prim_id = NULL;
481 
482    if (b->shader->info.stage == MESA_SHADER_VERTEX) {
483       /* LDS address where the primitive ID is stored */
484       nir_ssa_def *thread_id_in_threadgroup = nir_load_local_invocation_index(b);
485       nir_ssa_def *addr =  pervertex_lds_addr(b, thread_id_in_threadgroup, 4u);
486 
487       /* Load primitive ID from LDS */
488       prim_id = nir_load_shared(b, 1, 32, addr);
489    } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
490       /* Just use tess eval primitive ID, which is the same as the patch ID. */
491       prim_id = nir_load_primitive_id(b);
492    }
493 
494    nir_io_semantics io_sem = {
495       .location = VARYING_SLOT_PRIMITIVE_ID,
496       .num_slots = 1,
497    };
498 
499    nir_store_output(b, prim_id, nir_imm_zero(b, 1, 32),
500                     .base = io_sem.location,
501                     .src_type = nir_type_uint32, .io_semantics = io_sem);
502 }
503 
504 static bool
505 remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
506 {
507    remove_culling_shader_outputs_state *s = (remove_culling_shader_outputs_state *) state;
508 
509    if (instr->type != nir_instr_type_intrinsic)
510       return false;
511 
512    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
513 
514    /* These are not allowed in VS / TES */
515    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
516           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
517 
518    /* We are only interested in output stores now */
519    if (intrin->intrinsic != nir_intrinsic_store_output)
520       return false;
521 
522    b->cursor = nir_before_instr(instr);
523 
524    /* Position output - store the value to a variable, remove output store */
525    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
526    if (io_sem.location == VARYING_SLOT_POS) {
527       /* TODO: check if it's indirect, etc? */
528       unsigned writemask = nir_intrinsic_write_mask(intrin);
529       nir_ssa_def *store_val = intrin->src[0].ssa;
530       nir_store_var(b, s->pre_cull_position_value_var, store_val, writemask);
531    }
532 
533    /* Remove all output stores */
534    nir_instr_remove(instr);
535    return true;
536 }
537 
538 static void
539 remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *nogs_state, nir_variable *pre_cull_position_value_var)
540 {
541    remove_culling_shader_outputs_state s = {
542       .pre_cull_position_value_var = pre_cull_position_value_var,
543    };
544 
545    nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
546                                 nir_metadata_block_index | nir_metadata_dominance, &s);
547 
548    /* Remove dead code resulting from the deleted outputs. */
549    bool progress;
550    do {
551       progress = false;
552       NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
553       NIR_PASS(progress, culling_shader, nir_opt_dce);
554       NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
555    } while (progress);
556 }
557 
558 static void
559 rewrite_uses_to_var(nir_builder *b, nir_ssa_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
560 {
561    if (old_def->parent_instr->type == nir_instr_type_load_const)
562       return;
563 
564    b->cursor = nir_after_instr(old_def->parent_instr);
565    if (b->cursor.instr->type == nir_instr_type_phi)
566       b->cursor = nir_after_phis(old_def->parent_instr->block);
567 
568    nir_ssa_def *pos_val_rep = nir_load_var(b, replacement_var);
569    nir_ssa_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
570 
571    if (old_def->num_components > 1) {
572       /* old_def uses a swizzled vector component.
573        * There is no way to replace the uses of just a single vector component,
574        * so instead create a new vector and replace all uses of the old vector.
575        */
576       nir_ssa_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
577       for (unsigned j = 0; j < old_def->num_components; ++j)
578          old_def_elements[j] = nir_channel(b, old_def, j);
579       replacement = nir_vec(b, old_def_elements, old_def->num_components);
580    }
581 
582    nir_ssa_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
583 }
584 
585 static bool
586 remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
587 {
588    remove_extra_position_output_state *s = (remove_extra_position_output_state *) state;
589 
590    if (instr->type != nir_instr_type_intrinsic)
591       return false;
592 
593    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
594 
595    /* These are not allowed in VS / TES */
596    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
597           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
598 
599    /* We are only interested in output stores now */
600    if (intrin->intrinsic != nir_intrinsic_store_output)
601       return false;
602 
603    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
604    if (io_sem.location != VARYING_SLOT_POS)
605       return false;
606 
607    b->cursor = nir_before_instr(instr);
608 
609    /* In case other outputs use what we calculated for pos,
610     * try to avoid calculating it again by rewriting the usages
611     * of the store components here.
612     */
613    nir_ssa_def *store_val = intrin->src[0].ssa;
614    unsigned store_pos_component = nir_intrinsic_component(intrin);
615 
616    nir_instr_remove(instr);
617 
618    if (store_val->parent_instr->type == nir_instr_type_alu) {
619       nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
620       if (nir_op_is_vec(alu->op)) {
621          /* Output store uses a vector, we can easily rewrite uses of each vector element. */
622 
623          unsigned num_vec_src = 0;
624          if (alu->op == nir_op_mov)
625             num_vec_src = 1;
626          else if (alu->op == nir_op_vec2)
627             num_vec_src = 2;
628          else if (alu->op == nir_op_vec3)
629             num_vec_src = 3;
630          else if (alu->op == nir_op_vec4)
631             num_vec_src = 4;
632          assert(num_vec_src);
633 
634          /* Remember the current components whose uses we wish to replace.
635           * This is needed because rewriting one source can affect the others too.
636           */
637          nir_ssa_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
638          for (unsigned i = 0; i < num_vec_src; i++)
639             vec_comps[i] = alu->src[i].src.ssa;
640 
641          for (unsigned i = 0; i < num_vec_src; i++)
642             rewrite_uses_to_var(b, vec_comps[i], s->pos_value_replacement, store_pos_component + i);
643       } else {
644          rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
645       }
646    } else {
647       rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
648    }
649 
650    return true;
651 }
652 
653 static void
654 remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
655 {
656    remove_extra_position_output_state s = {
657       .pos_value_replacement = nogs_state->position_value_var,
658    };
659 
660    nir_shader_instructions_pass(shader, remove_extra_pos_output,
661                                 nir_metadata_block_index | nir_metadata_dominance, &s);
662 }
663 
664 static bool
665 remove_compacted_arg(lower_ngg_nogs_state *state, nir_builder *b, unsigned idx)
666 {
667    nir_instr *store_instr = state->compact_arg_stores[idx];
668    if (!store_instr)
669       return false;
670 
671    /* Simply remove the store. */
672    nir_instr_remove(store_instr);
673 
674    /* Find the intrinsic that overwrites the shader arguments,
675     * and change its corresponding source.
676     * This will cause NIR's DCE to recognize the load and its phis as dead.
677     */
678    b->cursor = nir_before_instr(&state->overwrite_args->instr);
679    nir_ssa_def *undef_arg = nir_ssa_undef(b, 1, 32);
680    nir_ssa_def_rewrite_uses(state->overwrite_args->src[idx].ssa, undef_arg);
681 
682    state->compact_arg_stores[idx] = NULL;
683    return true;
684 }
685 
686 static bool
687 cleanup_culling_shader_after_dce(nir_shader *shader,
688                                  nir_function_impl *function_impl,
689                                  lower_ngg_nogs_state *state)
690 {
691    bool uses_vs_vertex_id = false;
692    bool uses_vs_instance_id = false;
693    bool uses_tes_u = false;
694    bool uses_tes_v = false;
695    bool uses_tes_rel_patch_id = false;
696    bool uses_tes_patch_id = false;
697 
698    bool progress = false;
699    nir_builder b;
700    nir_builder_init(&b, function_impl);
701 
702    nir_foreach_block_reverse_safe(block, function_impl) {
703       nir_foreach_instr_reverse_safe(instr, block) {
704          if (instr->type != nir_instr_type_intrinsic)
705             continue;
706 
707          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
708 
709          switch (intrin->intrinsic) {
710          case nir_intrinsic_alloc_vertices_and_primitives_amd:
711             goto cleanup_culling_shader_after_dce_done;
712          case nir_intrinsic_load_vertex_id:
713          case nir_intrinsic_load_vertex_id_zero_base:
714             uses_vs_vertex_id = true;
715             break;
716          case nir_intrinsic_load_instance_id:
717             uses_vs_instance_id = true;
718             break;
719          case nir_intrinsic_load_input:
720             if (state->instance_rate_inputs &
721                 (1u << (nir_intrinsic_base(intrin) - VERT_ATTRIB_GENERIC0)))
722                uses_vs_instance_id = true;
723             else
724                uses_vs_vertex_id = true;
725             break;
726          case nir_intrinsic_load_tess_coord:
727             uses_tes_u = uses_tes_v = true;
728             break;
729          case nir_intrinsic_load_tess_rel_patch_id_amd:
730             uses_tes_rel_patch_id = true;
731             break;
732          case nir_intrinsic_load_primitive_id:
733             if (shader->info.stage == MESA_SHADER_TESS_EVAL)
734                uses_tes_patch_id = true;
735             break;
736          default:
737             break;
738          }
739       }
740    }
741 
742    cleanup_culling_shader_after_dce_done:
743 
744    if (shader->info.stage == MESA_SHADER_VERTEX) {
745       if (!uses_vs_vertex_id)
746          progress |= remove_compacted_arg(state, &b, 0);
747       if (!uses_vs_instance_id)
748          progress |= remove_compacted_arg(state, &b, 1);
749    } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) {
750       if (!uses_tes_u)
751          progress |= remove_compacted_arg(state, &b, 0);
752       if (!uses_tes_v)
753          progress |= remove_compacted_arg(state, &b, 1);
754       if (!uses_tes_rel_patch_id)
755          progress |= remove_compacted_arg(state, &b, 2);
756       if (!uses_tes_patch_id)
757          progress |= remove_compacted_arg(state, &b, 3);
758    }
759 
760    return progress;
761 }
762 
763 /**
764  * Perform vertex compaction after culling.
765  *
766  * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
767  * 2. Surviving ES vertex invocations store their data to LDS
768  * 3. Emit GS_ALLOC_REQ
769  * 4. Repacked invocations load the vertex data from LDS
770  * 5. GS threads update their vertex indices
771  */
772 static void
773 compact_vertices_after_culling(nir_builder *b,
774                                lower_ngg_nogs_state *nogs_state,
775                                nir_variable **repacked_arg_vars,
776                                nir_variable **gs_vtxaddr_vars,
777                                nir_ssa_def *invocation_index,
778                                nir_ssa_def *es_vertex_lds_addr,
779                                nir_ssa_def *es_exporter_tid,
780                                nir_ssa_def *num_live_vertices_in_workgroup,
781                                nir_ssa_def *fully_culled,
782                                unsigned ngg_scratch_lds_base_addr,
783                                unsigned pervertex_lds_bytes,
784                                unsigned max_exported_args)
785 {
786    nir_variable *es_accepted_var = nogs_state->es_accepted_var;
787    nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
788    nir_variable *position_value_var = nogs_state->position_value_var;
789    nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
790 
791    nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var));
792    {
793       nir_ssa_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
794 
795       /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
796       nir_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid);
797 
798       /* Store the current thread's position output to the exporter thread's LDS space */
799       nir_ssa_def *pos = nir_load_var(b, position_value_var);
800       nir_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x);
801 
802       /* Store the current thread's repackable arguments to the exporter thread's LDS space */
803       for (unsigned i = 0; i < max_exported_args; ++i) {
804          nir_ssa_def *arg_val = nir_load_var(b, repacked_arg_vars[i]);
805          nir_intrinsic_instr *store = nir_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i);
806 
807          nogs_state->compact_arg_stores[i] = &store->instr;
808       }
809    }
810    nir_pop_if(b, if_es_accepted);
811 
812    /* TODO: Consider adding a shortcut exit.
813     * Waves that have no vertices and primitives left can s_endpgm right here.
814     */
815 
816    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
817                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
818 
819    nir_ssa_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
820    nir_if *if_packed_es_thread = nir_push_if(b, es_survived);
821    {
822       /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
823       nir_ssa_def *exported_pos = nir_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x);
824       nir_store_var(b, position_value_var, exported_pos, 0xfu);
825 
826       /* Read the repacked arguments */
827       for (unsigned i = 0; i < max_exported_args; ++i) {
828          nir_ssa_def *arg_val = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i);
829          nir_store_var(b, repacked_arg_vars[i], arg_val, 0x1u);
830       }
831    }
832    nir_push_else(b, if_packed_es_thread);
833    {
834       nir_store_var(b, position_value_var, nir_ssa_undef(b, 4, 32), 0xfu);
835       for (unsigned i = 0; i < max_exported_args; ++i)
836          nir_store_var(b, repacked_arg_vars[i], nir_ssa_undef(b, 1, 32), 0x1u);
837    }
838    nir_pop_if(b, if_packed_es_thread);
839 
840    nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
841    {
842       nir_ssa_def *exporter_vtx_indices[3] = {0};
843 
844       /* Load the index of the ES threads that will export the current GS thread's vertices */
845       for (unsigned v = 0; v < 3; ++v) {
846          nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
847          nir_ssa_def *exporter_vtx_idx = nir_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid);
848          exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
849          nir_store_var(b, nogs_state->gs_vtx_indices_vars[v], exporter_vtx_indices[v], 0x1);
850       }
851 
852       nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, 3, exporter_vtx_indices, NULL, nogs_state->use_edgeflags);
853       nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
854    }
855    nir_pop_if(b, if_gs_accepted);
856 
857    nir_store_var(b, es_accepted_var, es_survived, 0x1u);
858    nir_store_var(b, gs_accepted_var, nir_bcsel(b, fully_culled, nir_imm_false(b), nir_has_input_primitive_amd(b)), 0x1u);
859 }
860 
861 static void
862 analyze_shader_before_culling_walk(nir_ssa_def *ssa,
863                                    uint8_t flag,
864                                    lower_ngg_nogs_state *nogs_state)
865 {
866    nir_instr *instr = ssa->parent_instr;
867    uint8_t old_pass_flags = instr->pass_flags;
868    instr->pass_flags |= flag;
869 
870    if (instr->pass_flags == old_pass_flags)
871       return; /* Already visited. */
872 
873    switch (instr->type) {
874    case nir_instr_type_intrinsic: {
875       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
876 
877       /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
878       switch (intrin->intrinsic) {
879       case nir_intrinsic_load_input: {
880          nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
881          uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
882          if (instr->pass_flags & nggc_passflag_used_by_pos)
883             nogs_state->inputs_needed_by_pos |= in_mask;
884          else if (instr->pass_flags & nggc_passflag_used_by_other)
885             nogs_state->inputs_needed_by_others |= in_mask;
886          break;
887       }
888       default:
889          break;
890       }
891 
892       break;
893    }
894    case nir_instr_type_alu: {
895       nir_alu_instr *alu = nir_instr_as_alu(instr);
896       unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
897 
898       for (unsigned i = 0; i < num_srcs; ++i) {
899          analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, nogs_state);
900       }
901 
902       break;
903    }
904    case nir_instr_type_phi: {
905       nir_phi_instr *phi = nir_instr_as_phi(instr);
906       nir_foreach_phi_src_safe(phi_src, phi) {
907          analyze_shader_before_culling_walk(phi_src->src.ssa, flag, nogs_state);
908       }
909 
910       break;
911    }
912    default:
913       break;
914    }
915 }
916 
917 static void
918 analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
919 {
920    nir_foreach_function(func, shader) {
921       nir_foreach_block(block, func->impl) {
922          nir_foreach_instr(instr, block) {
923             instr->pass_flags = 0;
924 
925             if (instr->type != nir_instr_type_intrinsic)
926                continue;
927 
928             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
929             if (intrin->intrinsic != nir_intrinsic_store_output)
930                continue;
931 
932             nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
933             nir_ssa_def *store_val = intrin->src[0].ssa;
934             uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
935             analyze_shader_before_culling_walk(store_val, flag, nogs_state);
936          }
937       }
938    }
939 }
940 
941 /**
942  * Save the reusable SSA definitions to variables so that the
943  * bottom shader part can reuse them from the top part.
944  *
945  * 1. We create a new function temporary variable for reusables,
946  *    and insert a store+load.
947  * 2. The shader is cloned (the top part is created), then the
948  *    control flow is reinserted (for the bottom part.)
949  * 3. For reusables, we delete the variable stores from the
950  *    bottom part. This will make them use the variables from
951  *    the top part and DCE the redundant instructions.
952  */
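/* For example, a sketch of what step 1 produces for a suitable uniform value
 * (ssa_20 and ssa_21 are hypothetical names):
 *
 *    ssa_20 = some reorderable, non-divergent intrinsic
 *    store_var(saved_uniform_var, ssa_20)
 *    ssa_21 = load_var(saved_uniform_var)
 *    ... all further uses of ssa_20 are rewritten to use ssa_21 ...
 *
 * After the shader is split, the bottom part drops the store
 * (see apply_reusable_variables) and keeps only the load, so it reuses the
 * value computed in the top part while DCE removes the recomputation.
 */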
953 static void
954 save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
955 {
956    ASSERTED int vec_ok = u_vector_init(&nogs_state->saved_uniforms, 4, sizeof(saved_uniform));
957    assert(vec_ok);
958 
959    nir_block *block = nir_start_block(b->impl);
960    while (block) {
961       /* Process the instructions in the current block. */
962       nir_foreach_instr_safe(instr, block) {
963          /* Find instructions whose SSA definitions are used by both
964           * the top and bottom parts of the shader (before and after culling).
965           * Only in this case, it makes sense for the bottom part
966           * to try to reuse these from the top part.
967           */
968          if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
969             continue;
970 
971          /* Determine if we can reuse the current SSA value.
972           * When vertex compaction is used, it is possible that the same shader invocation
973           * processes a different vertex in the top and bottom part of the shader.
974           * Therefore, we only reuse uniform values.
975           */
976          nir_ssa_def *ssa = NULL;
977          switch (instr->type) {
978          case nir_instr_type_alu: {
979             nir_alu_instr *alu = nir_instr_as_alu(instr);
980             if (alu->dest.dest.ssa.divergent)
981                continue;
982             /* Ignore uniform floats because they regress VGPR usage too much */
983             if (nir_op_infos[alu->op].output_type & nir_type_float)
984                continue;
985             ssa = &alu->dest.dest.ssa;
986             break;
987          }
988          case nir_instr_type_intrinsic: {
989             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
990             if (!nir_intrinsic_can_reorder(intrin) ||
991                 !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
992                 intrin->dest.ssa.divergent)
993                continue;
994             ssa = &intrin->dest.ssa;
995             break;
996          }
997          case nir_instr_type_phi: {
998             nir_phi_instr *phi = nir_instr_as_phi(instr);
999             if (phi->dest.ssa.divergent)
1000                continue;
1001             ssa = &phi->dest.ssa;
1002             break;
1003          }
1004          default:
1005             continue;
1006          }
1007 
1008          assert(ssa);
1009 
1010          /* Determine a suitable type for the SSA value. */
1011          enum glsl_base_type base_type = GLSL_TYPE_UINT;
1012          switch (ssa->bit_size) {
1013          case 8: base_type = GLSL_TYPE_UINT8; break;
1014          case 16: base_type = GLSL_TYPE_UINT16; break;
1015          case 32: base_type = GLSL_TYPE_UINT; break;
1016          case 64: base_type = GLSL_TYPE_UINT64; break;
1017          default: continue;
1018          }
1019 
1020          const struct glsl_type *t = ssa->num_components == 1
1021                                      ? glsl_scalar_type(base_type)
1022                                      : glsl_vector_type(base_type, ssa->num_components);
1023 
1024          saved_uniform *saved = (saved_uniform *) u_vector_add(&nogs_state->saved_uniforms);
1025          assert(saved);
1026 
1027          /* Create a new NIR variable where we store the reusable value.
1028           * Then, we reload the variable and replace the uses of the value
1029           * with the reloaded variable.
1030           */
1031          saved->var = nir_local_variable_create(b->impl, t, NULL);
1032          saved->ssa = ssa;
1033 
1034          b->cursor = instr->type == nir_instr_type_phi
1035                      ? nir_after_instr_and_phis(instr)
1036                      : nir_after_instr(instr);
1037          nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
1038          nir_ssa_def *reloaded = nir_load_var(b, saved->var);
1039          nir_ssa_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
1040       }
1041 
1042       /* Look at the next CF node. */
1043       nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
1044       if (next_cf_node) {
1045          /* It makes no sense to try to reuse things from within loops. */
1046          bool next_is_loop = next_cf_node->type == nir_cf_node_loop;
1047 
1048          /* Don't reuse if we're in divergent control flow.
1049           *
1050           * Thanks to vertex repacking, the same shader invocation may process a different vertex
1051           * in the top and bottom part, and it's even possible that this different vertex was initially
1052           * processed in a different wave. So the two parts may take a different divergent code path.
1053           * Therefore, these variables in divergent control flow may stay undefined.
1054           *
1055           * Note that this problem doesn't exist if vertices are not repacked or if the
1056           * workgroup only has a single wave.
1057           */
1058          bool next_is_divergent_if =
1059             next_cf_node->type == nir_cf_node_if &&
1060             nir_cf_node_as_if(next_cf_node)->condition.ssa->divergent;
1061 
1062          if (next_is_loop || next_is_divergent_if) {
1063             block = nir_cf_node_cf_tree_next(next_cf_node);
1064             continue;
1065          }
1066       }
1067 
1068       /* Go to the next block. */
1069       block = nir_block_cf_tree_next(block);
1070    }
1071 }
1072 
1073 /**
1074  * Reuses suitable variables from the top part of the shader,
1075  * by deleting their stores from the bottom part.
1076  */
1077 static void
1078 apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
1079 {
1080    if (!u_vector_length(&nogs_state->saved_uniforms)) {
1081       u_vector_finish(&nogs_state->saved_uniforms);
1082       return;
1083    }
1084 
1085    nir_foreach_block_reverse_safe(block, b->impl) {
1086       nir_foreach_instr_reverse_safe(instr, block) {
1087          if (instr->type != nir_instr_type_intrinsic)
1088             continue;
1089          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1090 
1091          /* When we find this intrinsic, it means
1092           * we have reached the top part and we must stop.
1093           */
1094          if (intrin->intrinsic == nir_intrinsic_alloc_vertices_and_primitives_amd)
1095             goto done;
1096 
1097          if (intrin->intrinsic != nir_intrinsic_store_deref)
1098             continue;
1099          nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1100          if (deref->deref_type != nir_deref_type_var)
1101             continue;
1102 
1103          saved_uniform *saved;
1104          u_vector_foreach(saved, &nogs_state->saved_uniforms) {
1105             if (saved->var == deref->var) {
1106                nir_instr_remove(instr);
1107             }
1108          }
1109       }
1110    }
1111 
1112    done:
1113    u_vector_finish(&nogs_state->saved_uniforms);
1114 }
1115 
1116 static void
1117 add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state)
1118 {
1119    bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
1120    bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
1121 
1122    unsigned max_exported_args = b->shader->info.stage == MESA_SHADER_VERTEX ? 2 : 4;
1123    if (b->shader->info.stage == MESA_SHADER_VERTEX && !uses_instance_id)
1124       max_exported_args--;
1125    else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL && !uses_tess_primitive_id)
1126       max_exported_args--;
1127 
1128    unsigned pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4u;
1129    unsigned total_es_lds_bytes = pervertex_lds_bytes * nogs_state->max_es_num_vertices;
1130    unsigned max_num_waves = nogs_state->max_num_waves;
1131    unsigned ngg_scratch_lds_base_addr = ALIGN(total_es_lds_bytes, 8u);
1132    unsigned ngg_scratch_lds_bytes = ALIGN(max_num_waves, 4u);
1133    nogs_state->total_lds_bytes = ngg_scratch_lds_base_addr + ngg_scratch_lds_bytes;
1134 
1135    nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
1136 
1137    /* Create some helper variables. */
1138    nir_variable *position_value_var = nogs_state->position_value_var;
1139    nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
1140    nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
1141    nir_variable *es_accepted_var = nogs_state->es_accepted_var;
1142    nir_variable *gs_vtxaddr_vars[3] = {
1143       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
1144       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
1145       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
1146    };
1147    nir_variable *repacked_arg_vars[4] = {
1148       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_0"),
1149       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_1"),
1150       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_2"),
1151       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"),
1152    };
1153 
1154    /* Top part of the culling shader (aka. position shader part)
1155     *
1156     * We clone the full ES shader and emit it here, but we only really care
1157     * about its position output, so we delete every other output from this part.
1158     * The position output is stored into a temporary variable, and reloaded later.
1159     */
1160 
1161    b->cursor = nir_before_cf_list(&impl->body);
1162 
1163    nir_ssa_def *es_thread = nir_has_input_vertex_amd(b);
1164    nir_if *if_es_thread = nir_push_if(b, es_thread);
1165    {
1166       /* Initialize the position output variable, in case not all VS/TES invocations store the output.
1167        * The spec doesn't require any particular value, but we use (0, 0, 0, 1) because some games rely on that.
1168        */
1169       nir_store_var(b, position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
1170 
1171       /* Now reinsert a clone of the shader code */
1172       struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
1173       nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
1174       _mesa_hash_table_destroy(remap_table, NULL);
1175       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1176 
1177       /* Remember the current thread's shader arguments */
1178       if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1179          nir_store_var(b, repacked_arg_vars[0], nir_load_vertex_id_zero_base(b), 0x1u);
1180          if (uses_instance_id)
1181             nir_store_var(b, repacked_arg_vars[1], nir_load_instance_id(b), 0x1u);
1182       } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1183          nir_ssa_def *tess_coord = nir_load_tess_coord(b);
1184          nir_store_var(b, repacked_arg_vars[0], nir_channel(b, tess_coord, 0), 0x1u);
1185          nir_store_var(b, repacked_arg_vars[1], nir_channel(b, tess_coord, 1), 0x1u);
1186          nir_store_var(b, repacked_arg_vars[2], nir_load_tess_rel_patch_id_amd(b), 0x1u);
1187          if (uses_tess_primitive_id)
1188             nir_store_var(b, repacked_arg_vars[3], nir_load_primitive_id(b), 0x1u);
1189       } else {
1190          unreachable("Should be VS or TES.");
1191       }
1192    }
1193    nir_pop_if(b, if_es_thread);
1194 
1195    nir_store_var(b, es_accepted_var, es_thread, 0x1u);
1196    nir_store_var(b, gs_accepted_var, nir_has_input_primitive_amd(b), 0x1u);
1197 
1198    /* Remove all non-position outputs, and put the position output into the variable. */
1199    nir_metadata_preserve(impl, nir_metadata_none);
1200    remove_culling_shader_outputs(b->shader, nogs_state, position_value_var);
1201    b->cursor = nir_after_cf_list(&impl->body);
1202 
1203    /* Run culling algorithms if culling is enabled.
1204     *
1205     * NGG culling can be enabled or disabled at runtime.
1206     * This is determined by an SGPR shader argument which is accessed
1207     * by the following NIR intrinsic.
1208     */
1209 
1210    nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
1211    {
1212       nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
1213       nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
1214 
1215       /* ES invocations store their vertex data to LDS for GS threads to read. */
1216       if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b));
1217       {
1218          /* Store position components that are relevant to culling in LDS */
1219          nir_ssa_def *pre_cull_pos = nir_load_var(b, position_value_var);
1220          nir_ssa_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
1221          nir_store_shared(b, pre_cull_w, es_vertex_lds_addr, .base = lds_es_pos_w);
1222          nir_ssa_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
1223          nir_ssa_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
1224          nir_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .base = lds_es_pos_x);
1225 
1226          /* Clear out the ES accepted flag in LDS */
1227          nir_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .align_mul = 4, .base = lds_es_vertex_accepted);
1228       }
1229       nir_pop_if(b, if_es_thread);
1230 
1231       nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1232                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1233 
1234       nir_store_var(b, gs_accepted_var, nir_imm_bool(b, false), 0x1u);
1235       nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
1236 
1237       /* GS invocations load the vertex data and perform the culling. */
1238       nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b));
1239       {
1240          /* Load vertex indices from input VGPRs */
1241          nir_ssa_def *vtx_idx[3] = {0};
1242          for (unsigned vertex = 0; vertex < 3; ++vertex)
1243             vtx_idx[vertex] = nir_load_var(b, nogs_state->gs_vtx_indices_vars[vertex]);
1244 
1245          nir_ssa_def *vtx_addr[3] = {0};
1246          nir_ssa_def *pos[3][4] = {0};
1247 
1248          /* Load W positions of vertices first because the culling code will use these first */
1249          for (unsigned vtx = 0; vtx < 3; ++vtx) {
1250             vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1251             pos[vtx][3] = nir_load_shared(b, 1, 32, vtx_addr[vtx], .base = lds_es_pos_w);
1252             nir_store_var(b, gs_vtxaddr_vars[vtx], vtx_addr[vtx], 0x1u);
1253          }
1254 
1255          /* Load the X/W, Y/W positions of vertices */
1256          for (unsigned vtx = 0; vtx < 3; ++vtx) {
1257             nir_ssa_def *xy = nir_load_shared(b, 2, 32, vtx_addr[vtx], .base = lds_es_pos_x);
1258             pos[vtx][0] = nir_channel(b, xy, 0);
1259             pos[vtx][1] = nir_channel(b, xy, 1);
1260          }
1261 
1262          /* See if the current primitive is accepted */
1263          nir_ssa_def *accepted = ac_nir_cull_triangle(b, nir_imm_bool(b, true), pos);
1264          nir_store_var(b, gs_accepted_var, accepted, 0x1u);
1265 
1266          nir_if *if_gs_accepted = nir_push_if(b, accepted);
1267          {
1268             /* Store the accepted state to LDS for ES threads */
1269             for (unsigned vtx = 0; vtx < 3; ++vtx)
1270                nir_store_shared(b, nir_imm_intN_t(b, 0xff, 8), vtx_addr[vtx], .base = lds_es_vertex_accepted, .align_mul = 4u);
1271          }
1272          nir_pop_if(b, if_gs_accepted);
1273       }
1274       nir_pop_if(b, if_gs_thread);
1275 
1276       nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1277                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1278 
1279       nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);
1280 
1281       /* ES invocations load their accepted flag from LDS. */
1282       if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b));
1283       {
1284          nir_ssa_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1285          nir_ssa_def *accepted_bool = nir_ine(b, accepted, nir_imm_intN_t(b, 0, 8));
1286          nir_store_var(b, es_accepted_var, accepted_bool, 0x1u);
1287       }
1288       nir_pop_if(b, if_es_thread);
1289 
1290       nir_ssa_def *es_accepted = nir_load_var(b, es_accepted_var);
1291 
1292       /* Repack the vertices that survived the culling. */
1293       wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr,
1294                                                             nogs_state->max_num_waves, nogs_state->wave_size);
1295       nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
1296       nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
1297 
1298       /* If all vertices are culled, set primitive count to 0 as well. */
1299       nir_ssa_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b);
1300       nir_ssa_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
1301       num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), num_exported_prims);
1302 
1303       nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0)));
1304       {
1305          /* Tell the HW the final vertex and primitive counts. */
1306          nir_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims);
1307       }
1308       nir_pop_if(b, if_wave_0);
1309 
1310       /* Vertex compaction. */
1311       compact_vertices_after_culling(b, nogs_state,
1312                                      repacked_arg_vars, gs_vtxaddr_vars,
1313                                      invocation_index, es_vertex_lds_addr,
1314                                      es_exporter_tid, num_live_vertices_in_workgroup, fully_culled,
1315                                      ngg_scratch_lds_base_addr, pervertex_lds_bytes, max_exported_args);
1316    }
1317    nir_push_else(b, if_cull_en);
1318    {
1319       /* When culling is disabled at runtime, behave like a shader compiled without culling. */
1320       nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0)));
1321       {
1322          nir_ssa_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
1323          nir_ssa_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
1324          nir_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1325       }
1326       nir_pop_if(b, if_wave_0);
1327       nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u);
1328    }
1329    nir_pop_if(b, if_cull_en);
1330 
1331    /* Update shader arguments.
1332     *
1333     * The registers which hold information about the subgroup's
1334     * vertices and primitives are updated here, so the rest of the shader
1335     * doesn't need to worry about the culling.
1336     *
1337     * These "overwrite" intrinsics must be in top-level control flow,
1338     * otherwise they can mess up the backend (e.g. ACO's SSA).
1339     *
1340     * TODO:
1341     * A cleaner solution would be to simply replace all usages of these args
1342     * with the load of the variables.
1343     * However, this wouldn't work right now because the backend uses the arguments
1344     * for purposes not expressed in NIR, such as VS input loads.
1345     * This can change once VS input loads and other such operations are lowered to e.g. load_buffer_amd.
1346     */
1347 
1348    if (b->shader->info.stage == MESA_SHADER_VERTEX)
1349       nogs_state->overwrite_args =
1350          nir_overwrite_vs_arguments_amd(b,
1351             nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]));
1352    else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1353       nogs_state->overwrite_args =
1354          nir_overwrite_tes_arguments_amd(b,
1355             nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]),
1356             nir_load_var(b, repacked_arg_vars[2]), nir_load_var(b, repacked_arg_vars[3]));
1357    else
1358       unreachable("Should be VS or TES.");
1359 }
1360 
1361 void
1362 ac_nir_lower_ngg_nogs(nir_shader *shader,
1363                       enum radeon_family family,
1364                       unsigned max_num_es_vertices,
1365                       unsigned num_vertices_per_primitives,
1366                       unsigned max_workgroup_size,
1367                       unsigned wave_size,
1368                       bool can_cull,
1369                       bool early_prim_export,
1370                       bool passthrough,
1371                       bool export_prim_id,
1372                       bool provoking_vtx_last,
1373                       bool use_edgeflags,
1374                       bool has_prim_query,
1375                       uint32_t instance_rate_inputs)
1376 {
1377    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1378    assert(impl);
1379    assert(max_num_es_vertices && max_workgroup_size && wave_size);
1380    assert(!(can_cull && passthrough));
1381 
1382    nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
1383    nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
1384    nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
1385    nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
1386 
1387    lower_ngg_nogs_state state = {
1388       .passthrough = passthrough,
1389       .export_prim_id = export_prim_id,
1390       .early_prim_export = early_prim_export,
1391       .use_edgeflags = use_edgeflags,
1392       .has_prim_query = has_prim_query,
1393       .can_cull = can_cull,
1394       .num_vertices_per_primitives = num_vertices_per_primitives,
1395       .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
1396       .position_value_var = position_value_var,
1397       .prim_exp_arg_var = prim_exp_arg_var,
1398       .es_accepted_var = es_accepted_var,
1399       .gs_accepted_var = gs_accepted_var,
1400       .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
1401       .max_es_num_vertices = max_num_es_vertices,
1402       .wave_size = wave_size,
1403       .instance_rate_inputs = instance_rate_inputs,
1404    };
1405 
1406    const bool need_prim_id_store_shared =
1407       export_prim_id && shader->info.stage == MESA_SHADER_VERTEX;
1408 
1409    if (export_prim_id) {
1410       nir_variable *prim_id_var = nir_variable_create(shader, nir_var_shader_out, glsl_uint_type(), "ngg_prim_id");
1411       prim_id_var->data.location = VARYING_SLOT_PRIMITIVE_ID;
1412       prim_id_var->data.driver_location = VARYING_SLOT_PRIMITIVE_ID;
1413       prim_id_var->data.interpolation = INTERP_MODE_NONE;
1414       shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
1415    }
1416 
1417    nir_builder builder;
1418    nir_builder *b = &builder; /* This is to avoid having to write &builder everywhere. */
1419    nir_builder_init(b, impl);
1420 
1421    if (can_cull) {
1422       /* We need divergence info for culling shaders. */
1423       nir_divergence_analysis(shader);
1424       analyze_shader_before_culling(shader, &state);
1425       save_reusable_variables(b, &state);
1426    }
1427 
1428    nir_cf_list extracted;
1429    nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
1430    b->cursor = nir_before_cf_list(&impl->body);
1431 
1432    ngg_nogs_init_vertex_indices_vars(b, impl, &state);
1433 
1434    if (!can_cull) {
1435       /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
1436       if (!(passthrough && family >= CHIP_NAVI23)) {
1437          /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
1438          nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0)));
1439          {
1440             nir_ssa_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
1441             nir_ssa_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
1442             nir_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1443          }
1444          nir_pop_if(b, if_wave_0);
1445       }
1446 
1447       /* Take care of early primitive export, otherwise just pack the primitive export argument */
1448       if (state.early_prim_export)
1449          emit_ngg_nogs_prim_export(b, &state, NULL);
1450       else
1451          nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
1452    } else {
1453       add_deferred_attribute_culling(b, &extracted, &state);
1454       b->cursor = nir_after_cf_list(&impl->body);
1455 
1456       if (state.early_prim_export)
1457          emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
1458    }
1459 
1460    if (need_prim_id_store_shared) {
1461       /* We need LDS space when VS needs to export the primitive ID. */
1462       state.total_lds_bytes = MAX2(state.total_lds_bytes, max_num_es_vertices * 4u);
1463 
1464       /* The LDS space aliases with what is used by culling, so we need a barrier. */
1465       if (can_cull) {
1466          nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
1467                                .memory_scope = NIR_SCOPE_WORKGROUP,
1468                                .memory_semantics = NIR_MEMORY_ACQ_REL,
1469                                .memory_modes = nir_var_mem_shared);
1470       }
1471 
1472       emit_ngg_nogs_prim_id_store_shared(b, &state);
1473 
1474       /* Wait for GS threads to store primitive ID in LDS. */
1475       nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, .memory_scope = NIR_SCOPE_WORKGROUP,
1476                             .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
1477    }
1478 
1479    nir_intrinsic_instr *export_vertex_instr;
1480    nir_ssa_def *es_thread = can_cull ? nir_load_var(b, es_accepted_var) : nir_has_input_vertex_amd(b);
1481 
1482    nir_if *if_es_thread = nir_push_if(b, es_thread);
1483    {
1484       /* Run the actual shader */
1485       nir_cf_reinsert(&extracted, b->cursor);
1486       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1487 
1488       if (state.export_prim_id)
1489          emit_store_ngg_nogs_es_primitive_id(b);
1490 
1491       /* Export all vertex attributes (including the primitive ID) */
1492       export_vertex_instr = nir_export_vertex_amd(b);
1493    }
1494    nir_pop_if(b, if_es_thread);
1495 
1496    /* Take care of late primitive export */
1497    if (!state.early_prim_export) {
1498       emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
1499    }
1500 
1501    if (can_cull) {
1502       /* Replace uniforms. */
1503       apply_reusable_variables(b, &state);
1504 
1505       /* Remove the redundant position output. */
1506       remove_extra_pos_outputs(shader, &state);
1507 
1508       /* After looking at the performance in apps such as Doom Eternal and The Witcher 3,
1509        * it seems best to always put the position export at the end, and
1510        * then let ACO schedule it up (slightly) only when early prim export is used.
1511        */
1512       b->cursor = nir_before_instr(&export_vertex_instr->instr);
1513 
1514       nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var);
1515       nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 };
1516       nir_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem);
1517    }
1518 
1519    nir_metadata_preserve(impl, nir_metadata_none);
1520    nir_validate_shader(shader, "after emitting NGG VS/TES");
1521 
1522    /* Cleanup */
1523    nir_opt_dead_write_vars(shader);
1524    nir_lower_vars_to_ssa(shader);
1525    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1526    nir_lower_alu_to_scalar(shader, NULL, NULL);
1527    nir_lower_phis_to_scalar(shader, true);
1528 
1529    if (can_cull) {
1530       /* It's beneficial to redo these opts after splitting the shader. */
1531       nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
1532       nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
1533    }
1534 
1535    bool progress;
1536    do {
1537       progress = false;
1538       NIR_PASS(progress, shader, nir_opt_undef);
1539       NIR_PASS(progress, shader, nir_opt_dce);
1540       NIR_PASS(progress, shader, nir_opt_dead_cf);
1541 
1542       if (can_cull)
1543          progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
1544    } while (progress);
1545 
1546    shader->info.shared_size = state.total_lds_bytes;
1547 }
1548 
1549 /**
1550  * Return the address of the LDS storage reserved for the N'th vertex,
1551  * where N is in emit order, meaning:
1552  * - during the finale, N is the invocation_index (within the workgroup)
1553  * - during vertex emit, i.e. while the API GS shader invocation is running,
1554  *   N = invocation_index * gs_max_out_vertices + emit_idx
1555  *   where emit_idx is the vertex index in the current API GS invocation.
1556  *
1557  * Goals of the LDS memory layout:
1558  * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
1559  *    in uniform control flow
1560  * 2. Eliminate bank conflicts on read for export if, additionally, there is no
1561  *    culling
1562  * 3. Agnostic to the number of waves (since we don't know it before compiling)
1563  * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
1564  * 5. Avoid wasting memory.
1565  *
1566  * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
1567  * layout, elimination of bank conflicts requires that each vertex occupy an
1568  * odd number of dwords. We use the additional dword to store the output stream
1569  * index as well as a flag to indicate whether this vertex ends a primitive
1570  * for rasterization.
1571  *
1572  * Swizzling is required to satisfy points 1 and 2 simultaneously.
1573  *
1574  * Vertices are stored in export order (gsthread * gs_max_out_vertices + emitidx).
1575  * Indices are swizzled in groups of 32, which ensures point 1 without
1576  * disturbing point 2.
1577  *
1578  * \return an LDS pointer to type {[N x i32], [4 x i8]}
1579  */
1580 static nir_ssa_def *
1581 ngg_gs_out_vertex_addr(nir_builder *b, nir_ssa_def *out_vtx_idx, lower_ngg_gs_state *s)
1582 {
1583    unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
1584 
1585    /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
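   /* Worked example: gs.vertices_out = 6 gives write_stride_2exp = 1, so
    * out_vtx_idx 37 (row 1) becomes 37 ^ 1 = 36, shifting the layout between
    * consecutive groups of 32 vertices.
    */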
1586    if (write_stride_2exp) {
1587       nir_ssa_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
1588       nir_ssa_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
1589       out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
1590    }
1591 
1592    nir_ssa_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
1593    return nir_iadd_imm_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
1594 }
1595 
1596 static nir_ssa_def *
1597 ngg_gs_emit_vertex_addr(nir_builder *b, nir_ssa_def *gs_vtx_idx, lower_ngg_gs_state *s)
1598 {
1599    nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
1600    nir_ssa_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
1601    nir_ssa_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
1602 
1603    return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
1604 }
1605 
1606 static void
1607 ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
1608 {
1609    nir_ssa_def *zero_u8 = nir_imm_zero(b, 1, 8);
1610    nir_store_var(b, s->current_clear_primflag_idx_var, num_vertices, 0x1u);
1611 
1612    nir_loop *loop = nir_push_loop(b);
1613    {
1614       nir_ssa_def *current_clear_primflag_idx = nir_load_var(b, s->current_clear_primflag_idx_var);
1615       nir_if *if_break = nir_push_if(b, nir_uge(b, current_clear_primflag_idx, nir_imm_int(b, b->shader->info.gs.vertices_out)));
1616       {
1617          nir_jump(b, nir_jump_break);
1618       }
1619       nir_push_else(b, if_break);
1620       {
1621          nir_ssa_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, current_clear_primflag_idx, s);
1622          nir_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream);
1623          nir_store_var(b, s->current_clear_primflag_idx_var, nir_iadd_imm_nuw(b, current_clear_primflag_idx, 1), 0x1u);
1624       }
1625       nir_pop_if(b, if_break);
1626    }
1627    nir_pop_loop(b, loop);
1628 }
1629 
1630 static void
1631 ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1632 {
1633    nir_if *if_shader_query = nir_push_if(b, nir_load_shader_query_enabled_amd(b));
1634    nir_ssa_def *num_prims_in_wave = NULL;
1635 
1636    /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
1637     * GS emits points, line strips or triangle strips.
1638     * Real primitives are points, lines or triangles.
1639     */
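   /* E.g. a single triangle strip of 5 vertices counts as 1 emitted primitive
    * but 5 - 1 * (3 - 1) = 3 real triangles.
    */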
1640    if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) {
1641       unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
1642       unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
1643       unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
1644       nir_ssa_def *num_threads = nir_bit_count(b, nir_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
1645       num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
1646    } else {
1647       nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
1648       nir_ssa_def *prm_cnt = intrin->src[1].ssa;
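      /* Same formula as the constant case:
       * prm_cnt = gs_vtx_cnt - prm_cnt * (num_vertices_per_primitive - 1)
       */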
1649       if (s->num_vertices_per_primitive > 1)
1650          prm_cnt = nir_iadd(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
1651       num_prims_in_wave = nir_reduce(b, prm_cnt, .reduction_op = nir_op_iadd);
1652    }
1653 
1654    /* Store the query result to GDS using an atomic add. */
1655    nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1));
1656    nir_gds_atomic_add_amd(b, 32, num_prims_in_wave, nir_imm_int(b, 0), nir_imm_int(b, 0x100));
1657    nir_pop_if(b, if_first_lane);
1658 
1659    nir_pop_if(b, if_shader_query);
1660 }
1661 
1662 static bool
1663 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1664 {
1665    assert(nir_src_is_const(intrin->src[1]));
1666    b->cursor = nir_before_instr(&intrin->instr);
1667 
1668    unsigned writemask = nir_intrinsic_write_mask(intrin);
1669    unsigned base = nir_intrinsic_base(intrin);
1670    unsigned component_offset = nir_intrinsic_component(intrin);
1671    unsigned base_offset = nir_src_as_uint(intrin->src[1]);
1672    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
1673 
1674    assert((base + base_offset) < VARYING_SLOT_MAX);
1675 
1676    nir_ssa_def *store_val = intrin->src[0].ssa;
1677 
1678    for (unsigned comp = 0; comp < 4; ++comp) {
1679       if (!(writemask & (1 << comp)))
1680          continue;
1681       unsigned stream = (io_sem.gs_streams >> (comp * 2)) & 0x3;
1682       if (!(b->shader->info.gs.active_stream_mask & (1 << stream)))
1683          continue;
1684 
1685       /* Small bitsize components consume the same amount of space as 32-bit components,
1686           * but 64-bit ones consume twice as much. (Vulkan spec 15.1.5)
1687        */
1688       unsigned num_consumed_components = DIV_ROUND_UP(store_val->bit_size, 32);
1689       nir_ssa_def *element = nir_channel(b, store_val, comp);
1690       if (num_consumed_components > 1)
1691          element = nir_extract_bits(b, &element, 1, 0, num_consumed_components, 32);
1692 
1693       /* Save output usage info. */
1694       gs_output_info *info = &s->output_info[io_sem.location];
1695       /* The same output should always belong to the same stream. */
1696       assert(!info->components_mask || info->stream == stream);
1697       info->stream = stream;
1698       info->components_mask |= BITFIELD_BIT(component_offset + comp * num_consumed_components);
1699 
1700       for (unsigned c = 0; c < num_consumed_components; ++c) {
1701          unsigned component_index = (comp * num_consumed_components) + c + component_offset;
1702          unsigned base_index = base + base_offset + component_index / 4;
1703          component_index %= 4;
1704 
1705          /* Store the current component element */
1706          nir_ssa_def *component_element = element;
1707          if (num_consumed_components > 1)
1708             component_element = nir_channel(b, component_element, c);
1709          if (component_element->bit_size != 32)
1710             component_element = nir_u2u32(b, component_element);
1711 
1712          nir_store_var(b, s->output_vars[base_index][component_index], component_element, 0x1u);
1713       }
1714    }
1715 
1716    nir_instr_remove(&intrin->instr);
1717    return true;
1718 }
1719 
1720 static bool
1721 lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1722 {
1723    b->cursor = nir_before_instr(&intrin->instr);
1724 
1725    unsigned stream = nir_intrinsic_stream_id(intrin);
1726    if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
1727       nir_instr_remove(&intrin->instr);
1728       return true;
1729    }
1730 
1731    nir_ssa_def *gs_emit_vtx_idx = intrin->src[0].ssa;
1732    nir_ssa_def *current_vtx_per_prim = intrin->src[1].ssa;
1733    nir_ssa_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);
1734 
1735    for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1736       unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
1737       gs_output_info *info = &s->output_info[slot];
1738       if (info->stream != stream || !info->components_mask)
1739          continue;
1740 
1741       unsigned mask = info->components_mask;
1742       while (mask) {
1743          int start, count;
1744          u_bit_scan_consecutive_range(&mask, &start, &count);
1745          nir_ssa_def *values[4] = {0};
1746          for (int c = start; c < start + count; ++c) {
1747             /* Load output from variable. */
1748             values[c - start] = nir_load_var(b, s->output_vars[slot][c]);
1749             /* Clear the variable (it is undefined after emit_vertex) */
1750             nir_store_var(b, s->output_vars[slot][c], nir_ssa_undef(b, 1, 32), 0x1);
1751          }
1752 
1753          nir_ssa_def *store_val = nir_vec(b, values, (unsigned)count);
1754          nir_store_shared(b, store_val, gs_emit_vtx_addr,
1755                           .base = packed_location * 16 + start * 4,
1756                           .align_mul = 4);
1757       }
1758    }
1759 
1760    /* Calculate and store per-vertex primitive flags based on vertex counts:
1761     * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
1762     * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
1763     * - bit 2: always 1 (so that we can use it for determining vertex liveness)
1764     */
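   /* E.g. for a triangle strip, the vertex with per-primitive vertex count 2
    * completes triangle 0 (even), count 3 completes triangle 1 (odd), etc.
    */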
1765 
1766    nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
1767    nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));
1768 
1769    if (s->num_vertices_per_primitive == 3) {
1770       nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
1771       prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
1772    }
1773 
1774    nir_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 4u);
1775    nir_instr_remove(&intrin->instr);
1776    return true;
1777 }
1778 
1779 static bool
1780 lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
1781 {
1782    b->cursor = nir_before_instr(&intrin->instr);
1783 
1784    /* These are not needed, we can simply remove them */
1785    nir_instr_remove(&intrin->instr);
1786    return true;
1787 }
1788 
1789 static bool
1790 lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1791 {
1792    b->cursor = nir_before_instr(&intrin->instr);
1793 
1794    unsigned stream = nir_intrinsic_stream_id(intrin);
1795    if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
1796       nir_instr_remove(&intrin->instr);
1797       return true;
1798    }
1799 
1800    s->found_out_vtxcnt[stream] = true;
1801 
1802    /* Clear the primitive flags of non-emitted vertices */
1803    if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
1804       ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
1805 
1806    ngg_gs_shader_query(b, intrin, s);
1807    nir_instr_remove(&intrin->instr);
1808    return true;
1809 }
1810 
1811 static bool
1812 lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
1813 {
1814    lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;
1815 
1816    if (instr->type != nir_instr_type_intrinsic)
1817       return false;
1818 
1819    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1820 
1821    if (intrin->intrinsic == nir_intrinsic_store_output)
1822       return lower_ngg_gs_store_output(b, intrin, s);
1823    else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
1824       return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
1825    else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
1826       return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
1827    else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
1828       return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);
1829 
1830    return false;
1831 }
1832 
1833 static void
1834 lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
1835 {
1836    nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
1837 }
1838 
1839 static void
1840 ngg_gs_export_primitives(nir_builder *b, nir_ssa_def *max_num_out_prims, nir_ssa_def *tid_in_tg,
1841                          nir_ssa_def *exporter_tid_in_tg, nir_ssa_def *primflag_0,
1842                          lower_ngg_gs_state *s)
1843 {
1844    nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
1845 
1846    /* Only bit 0 matters here - set it to 1 when the primitive should be null */
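   /* primflag_0 bit 0 is set when a primitive ends at this vertex, so inverting it
    * yields 1 (null) for slots where no primitive should be rasterized.
    */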
1847    nir_ssa_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));
1848 
1849    nir_ssa_def *vtx_indices[3] = {0};
1850    vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
1851    if (s->num_vertices_per_primitive >= 2)
1852       vtx_indices[s->num_vertices_per_primitive - 2] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 1));
1853    if (s->num_vertices_per_primitive == 3)
1854       vtx_indices[s->num_vertices_per_primitive - 3] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 2));
1855 
1856    if (s->num_vertices_per_primitive == 3) {
1857       /* API GS outputs triangle strips, but NGG HW understands triangles.
1858        * We already know the triangles due to how we set the primitive flags, but we need to
1859        * make sure the vertex order is such that the front/back face is correct and the provoking vertex is kept.
1860        */
1861 
1862       nir_ssa_def *is_odd = nir_ubfe(b, primflag_0, nir_imm_int(b, 1), nir_imm_int(b, 1));
1863       if (!s->provoking_vertex_last) {
1864          vtx_indices[1] = nir_iadd(b, vtx_indices[1], is_odd);
1865          vtx_indices[2] = nir_isub(b, vtx_indices[2], is_odd);
1866       } else {
1867          vtx_indices[0] = nir_iadd(b, vtx_indices[0], is_odd);
1868          vtx_indices[1] = nir_isub(b, vtx_indices[1], is_odd);
1869       }
1870    }
1871 
1872    nir_ssa_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim, false);
1873    nir_export_primitive_amd(b, arg);
1874    nir_pop_if(b, if_prim_export_thread);
1875 }
1876 
1877 static void
1878 ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def *tid_in_tg,
1879                        nir_ssa_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
1880 {
1881    nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
1882    nir_ssa_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;
1883 
1884    if (!s->output_compile_time_known) {
1885       /* Vertex compaction.
1886        * The current thread will export a vertex that was live in another invocation.
1887        * Load the index of the vertex that the current thread will have to export.
1888        */
1889       nir_ssa_def *exported_vtx_idx = nir_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1);
1890       exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
1891    }
1892 
1893    /* Remember proper bit sizes of output variables. */
1894    uint8_t out_bitsizes[VARYING_SLOT_MAX];
1895    memset(out_bitsizes, 32, VARYING_SLOT_MAX);
1896    nir_foreach_shader_out_variable(var, b->shader) {
1897       /* Check 8/16-bit. All others should be lowered to 32-bit already. */
1898       unsigned bit_size = glsl_base_type_bit_size(glsl_get_base_type(glsl_without_array(var->type)));
1899       if (bit_size == 8 || bit_size == 16)
1900          out_bitsizes[var->data.location] = bit_size;
1901    }
1902 
1903    for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1904       if (!(b->shader->info.outputs_written & BITFIELD64_BIT(slot)))
1905          continue;
1906 
1907       gs_output_info *info = &s->output_info[slot];
1908       if (!info->components_mask || info->stream != 0)
1909          continue;
1910 
1911       unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
1912       nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };
1913 
1914       unsigned mask = info->components_mask;
1915       while (mask) {
1916          int start, count;
1917          u_bit_scan_consecutive_range(&mask, &start, &count);
1918          nir_ssa_def *load =
1919             nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
1920                             .base = packed_location * 16 + start * 4,
1921                             .align_mul = 4);
1922 
1923          /* Convert to the expected bit size of the output variable. */
1924          if (out_bitsizes[slot] != 32)
1925             load = nir_u2u(b, load, out_bitsizes[slot]);
1926 
1927          nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .io_semantics = io_sem,
1928                           .component = start, .write_mask = BITFIELD_MASK(count));
1929       }
1930    }
1931 
1932    nir_export_vertex_amd(b);
1933    nir_pop_if(b, if_vtx_export_thread);
1934 }
1935 
1936 static void
1937 ngg_gs_setup_vertex_compaction(nir_builder *b, nir_ssa_def *vertex_live, nir_ssa_def *tid_in_tg,
1938                                nir_ssa_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
1939 {
1940    assert(vertex_live->bit_size == 1);
1941    nir_if *if_vertex_live = nir_push_if(b, vertex_live);
1942    {
1943       /* Setup the vertex compaction.
1944        * Save the current thread's id for the thread which will export the current vertex.
1945        * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this.
1946        */
1947 
1948       nir_ssa_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
1949       nir_ssa_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
1950       nir_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1);
1951    }
1952    nir_pop_if(b, if_vertex_live);
1953 }
1954 
1955 static nir_ssa_def *
1956 ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *vtx_lds_addr,
1957                                nir_ssa_def *max_num_out_vtx, lower_ngg_gs_state *s)
1958 {
1959    nir_ssa_def *zero = nir_imm_int(b, 0);
1960 
1961    nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
1962    nir_ssa_def *primflag_0 = nir_load_shared(b, 1, 8, vtx_lds_addr, .base = s->lds_offs_primflags, .align_mul = 4u);
1963    primflag_0 = nir_u2u32(b, primflag_0);
1964    nir_pop_if(b, if_outvtx_thread);
1965 
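   /* Threads without an output vertex didn't load a primflag; the phi gives them 0. */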
1966    return nir_if_phi(b, primflag_0, zero);
1967 }
1968 
1969 static void
1970 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
1971 {
1972    nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
1973    nir_ssa_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
1974    nir_ssa_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
1975    nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
1976 
1977    if (s->output_compile_time_known) {
1978       /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
1979        * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
1980        */
1981       nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
1982       nir_alloc_vertices_and_primitives_amd(b, max_vtxcnt, max_prmcnt);
1983       nir_pop_if(b, if_wave_0);
1984    }
1985 
1986    /* Workgroup barrier: wait for all GS threads to finish */
1987    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1988                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1989 
1990    nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag_0(b, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
1991 
1992    if (s->output_compile_time_known) {
1993       ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
1994       ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
1995       return;
1996    }
1997 
1998    /* When the output vertex count is not known at compile time:
1999     * There may be gaps between invocations that have live vertices, but NGG hardware
2000     * requires that the invocations that export vertices are packed (i.e. compact).
2001     * To ensure this, we need to repack invocations that have a live vertex.
2002     */
2003    nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
2004    wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);
2005 
2006    nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
2007    nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;
2008 
2009    /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
2010    nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0));
2011    max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));
2012 
2013    /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
2014    nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
2015    nir_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt);
2016    nir_pop_if(b, if_wave_0);
2017 
2018    /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
2019    ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);
2020 
2021    /* Workgroup barrier: wait for all LDS stores to finish. */
2022    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
2023                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
2024 
2025    ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
2026    ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
2027 }
2028 
2029 void
2030 ac_nir_lower_ngg_gs(nir_shader *shader,
2031                     unsigned wave_size,
2032                     unsigned max_workgroup_size,
2033                     unsigned esgs_ring_lds_bytes,
2034                     unsigned gs_out_vtx_bytes,
2035                     unsigned gs_total_out_vtx_bytes,
2036                     bool provoking_vertex_last)
2037 {
2038    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
2039    assert(impl);
2040 
2041    lower_ngg_gs_state state = {
2042       .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
2043       .wave_size = wave_size,
2044       .lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
2045       .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
2046       .lds_offs_primflags = gs_out_vtx_bytes,
2047       .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
2048       .provoking_vertex_last = provoking_vertex_last,
2049    };
2050 
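   /* LDS scratch for the repacking code: one byte per wave, padded to a multiple of 4 bytes. */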
2051    unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
2052    unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
2053    shader->info.shared_size = total_lds_bytes;
2054 
2055    nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
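   /* The output is compile-time known only if stream 0 always emits exactly
    * gs.vertices_out vertices and a constant number of primitives.
    */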
2056    state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
2057                                      state.const_out_prmcnt[0] != -1;
2058 
2059    if (!state.output_compile_time_known)
2060       state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");
2061 
2062    if (shader->info.gs.output_primitive == SHADER_PRIM_POINTS)
2063       state.num_vertices_per_primitive = 1;
2064    else if (shader->info.gs.output_primitive == SHADER_PRIM_LINE_STRIP)
2065       state.num_vertices_per_primitive = 2;
2066    else if (shader->info.gs.output_primitive == SHADER_PRIM_TRIANGLE_STRIP)
2067       state.num_vertices_per_primitive = 3;
2068    else
2069       unreachable("Invalid GS output primitive.");
2070 
2071    /* Extract the full control flow. It is going to be wrapped in an if statement. */
2072    nir_cf_list extracted;
2073    nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
2074 
2075    nir_builder builder;
2076    nir_builder *b = &builder; /* This is to avoid having to write &builder everywhere. */
2077    nir_builder_init(b, impl);
2078    b->cursor = nir_before_cf_list(&impl->body);
2079 
2080    /* Workgroup barrier: wait for ES threads */
2081    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
2082                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
2083 
2084    /* Wrap the GS control flow. */
2085    nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b));
2086 
2087    /* Create and initialize output variables */
2088    for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
2089       for (unsigned comp = 0; comp < 4; ++comp) {
2090          state.output_vars[slot][comp] = nir_local_variable_create(impl, glsl_uint_type(), "output");
2091       }
2092    }
2093 
2094    nir_cf_reinsert(&extracted, b->cursor);
2095    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
2096    nir_pop_if(b, if_gs_thread);
2097 
2098    /* Lower the GS intrinsics */
2099    lower_ngg_gs_intrinsics(shader, &state);
2100    b->cursor = nir_after_cf_list(&impl->body);
2101 
2102    if (!state.found_out_vtxcnt[0]) {
2103       fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.\n");
2104       abort();
2105    }
2106 
2107    /* Emit the finale sequence */
2108    ngg_gs_finale(b, &state);
2109    nir_validate_shader(shader, "after emitting NGG GS");
2110 
2111    /* Cleanup */
2112    nir_lower_vars_to_ssa(shader);
2113    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
2114    nir_metadata_preserve(impl, nir_metadata_none);
2115 }
2116 
2117 static void
2118 ms_store_prim_indices(nir_builder *b,
2119                       nir_ssa_def *val,
2120                       nir_ssa_def *offset_src,
2121                       lower_ngg_ms_state *s)
2122 {
2123    assert(val->num_components <= 3);
2124 
2125    if (!offset_src)
2126       offset_src = nir_imm_int(b, 0);
2127 
2128    nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->layout.lds.indices_addr);
2129 }
2130 
2131 static nir_ssa_def *
2132 ms_load_prim_indices(nir_builder *b,
2133                      nir_ssa_def *offset_src,
2134                      lower_ngg_ms_state *s)
2135 {
2136    if (!offset_src)
2137       offset_src = nir_imm_int(b, 0);
2138 
2139    return nir_load_shared(b, 1, 8, offset_src, .base = s->layout.lds.indices_addr);
2140 }
2141 
2142 static void
2143 ms_store_num_prims(nir_builder *b,
2144                    nir_ssa_def *store_val,
2145                    lower_ngg_ms_state *s)
2146 {
2147    nir_ssa_def *addr = nir_imm_int(b, 0);
2148    nir_store_shared(b, nir_u2u32(b, store_val), addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
2149 }
2150 
2151 static nir_ssa_def *
2152 ms_load_num_prims(nir_builder *b,
2153                   lower_ngg_ms_state *s)
2154 {
2155    nir_ssa_def *addr = nir_imm_int(b, 0);
2156    return nir_load_shared(b, 1, 32, addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
2157 }
2158 
2159 static nir_ssa_def *
2160 lower_ms_store_output(nir_builder *b,
2161                       nir_intrinsic_instr *intrin,
2162                       lower_ngg_ms_state *s)
2163 {
2164    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
2165    nir_ssa_def *store_val = intrin->src[0].ssa;
2166 
2167    /* Component makes no sense here. */
2168    assert(nir_intrinsic_component(intrin) == 0);
2169 
2170    if (io_sem.location == VARYING_SLOT_PRIMITIVE_COUNT) {
2171       /* Total number of primitives output by the mesh shader workgroup.
2172        * This can be read and written by any invocation any number of times.
2173        */
2174 
2175       /* Base, offset and component make no sense here. */
2176       assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
2177 
2178       ms_store_num_prims(b, store_val, s);
2179    } else if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
2180       /* Contrary to the name, these are not primitive indices, but
2181        * vertex indices for each vertex of the output primitives.
2182        * The Mesh NV API has these stored in a flat array.
2183        */
2184 
2185       nir_ssa_def *offset_src = nir_get_io_offset_src(intrin)->ssa;
2186       ms_store_prim_indices(b, store_val, offset_src, s);
2187    } else {
2188       unreachable("Invalid mesh shader output");
2189    }
2190 
2191    return NIR_LOWER_INSTR_PROGRESS_REPLACE;
2192 }
2193 
2194 static nir_ssa_def *
2195 lower_ms_load_output(nir_builder *b,
2196                      nir_intrinsic_instr *intrin,
2197                      lower_ngg_ms_state *s)
2198 {
2199    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
2200 
2201    /* Component makes no sense here. */
2202    assert(nir_intrinsic_component(intrin) == 0);
2203 
2204    if (io_sem.location == VARYING_SLOT_PRIMITIVE_COUNT) {
2205       /* Base, offset and component make no sense here. */
2206       assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
2207 
2208       return ms_load_num_prims(b, s);
2209    } else if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
2210       nir_ssa_def *offset_src = nir_get_io_offset_src(intrin)->ssa;
2211       nir_ssa_def *index = ms_load_prim_indices(b, offset_src, s);
2212       return nir_u2u(b, index, intrin->dest.ssa.bit_size);
2213    }
2214 
2215    unreachable("Invalid mesh shader output");
2216 }
2217 
2218 static nir_ssa_def *
2219 ms_arrayed_output_base_addr(nir_builder *b,
2220                             nir_ssa_def *arr_index,
2221                             unsigned driver_location,
2222                             unsigned num_arrayed_outputs)
2223 {
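   /* Each output slot occupies 16 bytes: a vec4 of 32-bit components. */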
2224    /* Address offset of the array item (vertex or primitive). */
2225    unsigned arr_index_stride = num_arrayed_outputs * 16u;
2226    nir_ssa_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);
2227 
2228    /* IO address offset within the vertex or primitive data. */
2229    unsigned io_offset = driver_location * 16u;
2230    nir_ssa_def *io_off = nir_imm_int(b, io_offset);
2231 
2232    return nir_iadd_nuw(b, arr_index_off, io_off);
2233 }
2234 
2235 static void
2236 update_ms_output_info_slot(lower_ngg_ms_state *s,
2237                            unsigned slot, unsigned base_off,
2238                            uint32_t components_mask)
2239 {
2240    while (components_mask) {
2241       s->output_info[slot + base_off].components_mask |= components_mask & 0xF;
2242 
2243       components_mask >>= 4;
2244       base_off++;
2245    }
2246 }
2247 
2248 static void
2249 update_ms_output_info(nir_intrinsic_instr *intrin,
2250                       const ms_out_part *out,
2251                       lower_ngg_ms_state *s)
2252 {
2253    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
2254    nir_src *base_offset_src = nir_get_io_offset_src(intrin);
2255    uint32_t write_mask = nir_intrinsic_write_mask(intrin);
2256    unsigned component_offset = nir_intrinsic_component(intrin);
2257 
2258    nir_ssa_def *store_val = intrin->src[0].ssa;
2259    write_mask = util_widen_mask(write_mask, DIV_ROUND_UP(store_val->bit_size, 32));
2260    uint32_t components_mask = write_mask << component_offset;
2261 
2262    if (nir_src_is_const(*base_offset_src)) {
2263       /* Simply mark the components of the current slot as used. */
2264       unsigned base_off = nir_src_as_uint(*base_offset_src);
2265       update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
2266    } else {
2267       /* Indirect offset: mark the components of all slots as used. */
2268       for (unsigned base_off = 0; base_off < io_sem.num_slots; ++base_off)
2269          update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
2270    }
2271 }
2272 
2273 static nir_ssa_def *
2274 regroup_store_val(nir_builder *b, nir_ssa_def *store_val)
2275 {
2276    /* Vulkan spec 15.1.4-15.1.5:
2277     *
2278     * The shader interface consists of output slots with 4x 32-bit components.
2279     * Small bitsize components consume the same space as 32-bit components,
2280     * but 64-bit ones consume twice as much.
2281     *
2282     * The same output slot may consist of components of different bit sizes.
2283     * Therefore for simplicity we don't store small bitsize components
2284     * contiguously, but pad them instead. In practice, they are converted to
2285     * 32-bit and then stored contiguously.
2286     */
2287 
2288    if (store_val->bit_size < 32) {
2289       assert(store_val->num_components <= 4);
2290       nir_ssa_def *comps[4] = {0};
2291       for (unsigned c = 0; c < store_val->num_components; ++c)
2292          comps[c] = nir_u2u32(b, nir_channel(b, store_val, c));
2293       return nir_vec(b, comps, store_val->num_components);
2294    }
2295 
2296    return store_val;
2297 }
2298 
2299 static nir_ssa_def *
2300 regroup_load_val(nir_builder *b, nir_ssa_def *load, unsigned dest_bit_size)
2301 {
2302    if (dest_bit_size == load->bit_size)
2303       return load;
2304 
2305    /* Small bitsize components are not stored contiguously, take care of that here. */
2306    unsigned num_components = load->num_components;
2307    assert(num_components <= 4);
2308    nir_ssa_def *components[4] = {0};
2309    for (unsigned i = 0; i < num_components; ++i)
2310       components[i] = nir_u2u(b, nir_channel(b, load, i), dest_bit_size);
2311 
2312    return nir_vec(b, components, num_components);
2313 }
2314 
2315 static const ms_out_part *
2316 ms_get_out_layout_part(unsigned location,
2317                        shader_info *info,
2318                        ms_out_mode *out_mode,
2319                        lower_ngg_ms_state *s)
2320 {
2321    uint64_t mask = BITFIELD64_BIT(location);
2322 
2323    if (info->per_primitive_outputs & mask) {
2324       if (mask & s->layout.lds.prm_attr.mask) {
2325          *out_mode = ms_out_mode_lds;
2326          return &s->layout.lds.prm_attr;
2327       } else if (mask & s->layout.vram.prm_attr.mask) {
2328          *out_mode = ms_out_mode_vram;
2329          return &s->layout.vram.prm_attr;
2330       } else if (mask & s->layout.var.prm_attr.mask) {
2331          *out_mode = ms_out_mode_var;
2332          return &s->layout.var.prm_attr;
2333       }
2334    } else {
2335       if (mask & s->layout.lds.vtx_attr.mask) {
2336          *out_mode = ms_out_mode_lds;
2337          return &s->layout.lds.vtx_attr;
2338       } else if (mask & s->layout.vram.vtx_attr.mask) {
2339          *out_mode = ms_out_mode_vram;
2340          return &s->layout.vram.vtx_attr;
2341       } else if (mask & s->layout.var.vtx_attr.mask) {
2342          *out_mode = ms_out_mode_var;
2343          return &s->layout.var.vtx_attr;
2344       }
2345    }
2346 
2347    unreachable("Couldn't figure out mesh shader output mode.");
2348 }
2349 
2350 static void
2351 ms_store_arrayed_output_intrin(nir_builder *b,
2352                                nir_intrinsic_instr *intrin,
2353                                lower_ngg_ms_state *s)
2354 {
2355    ms_out_mode out_mode;
2356    unsigned location = nir_intrinsic_io_semantics(intrin).location;
2357    const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
2358    update_ms_output_info(intrin, out, s);
2359 
2360    /* We compact the LDS size (we don't reserve LDS space for outputs which can
2361     * be stored in variables), so we can't rely on the original driver_location.
2362     * Instead, we compute the first free location based on the output mask.
2363     */
2364    unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
2365    unsigned component_offset = nir_intrinsic_component(intrin);
2366    unsigned write_mask = nir_intrinsic_write_mask(intrin);
2367    unsigned num_outputs = util_bitcount64(out->mask);
2368    unsigned const_off = out->addr + component_offset * 4;
2369 
2370    nir_ssa_def *store_val = regroup_store_val(b, intrin->src[0].ssa);
2371    nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
2372    nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
2373    nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
2374    nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
2375    nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
2376 
2377    if (out_mode == ms_out_mode_lds) {
2378       nir_store_shared(b, store_val, addr, .base = const_off,
2379                      .write_mask = write_mask, .align_mul = 16,
2380                      .align_offset = const_off % 16);
2381    } else if (out_mode == ms_out_mode_vram) {
2382       nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b);
2383       nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
2384       nir_store_buffer_amd(b, store_val, ring, addr, off,
2385                            .base = const_off,
2386                            .write_mask = write_mask,
2387                            .memory_modes = nir_var_shader_out);
2388    } else if (out_mode == ms_out_mode_var) {
2389       if (store_val->bit_size > 32) {
2390          /* Widen the write mask so it is in 32-bit components; do this before the bitcast, while bit_size is still 64. */
2391          write_mask = util_widen_mask(write_mask, store_val->bit_size / 32);
2392          /* Split 64-bit store values to 32-bit components. */
2393          store_val = nir_bitcast_vector(b, store_val, 32);
2394       }
2395 
2396       u_foreach_bit(comp, write_mask) {
2397          nir_ssa_def *val = nir_channel(b, store_val, comp);
2398          unsigned idx = location * 4 + comp + component_offset;
2399          nir_store_var(b, s->out_variables[idx], val, 0x1);
2400       }
2401    } else {
2402       unreachable("Invalid MS output mode for store");
2403    }
2404 }
2405 
2406 static nir_ssa_def *
2407 ms_load_arrayed_output(nir_builder *b,
2408                        nir_ssa_def *arr_index,
2409                        nir_ssa_def *base_offset,
2410                        unsigned location,
2411                        unsigned component_offset,
2412                        unsigned num_components,
2413                        unsigned load_bit_size,
2414                        lower_ngg_ms_state *s)
2415 {
2416    ms_out_mode out_mode;
2417    const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
2418 
2419    unsigned component_addr_off = component_offset * 4;
2420    unsigned num_outputs = util_bitcount64(out->mask);
2421    unsigned const_off = out->addr + component_offset * 4;
2422 
2423    /* Use compacted driver location instead of the original. */
2424    unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
2425 
2426    nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
2427    nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
2428    nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
2429 
2430    if (out_mode == ms_out_mode_lds) {
2431       return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
2432                              .align_offset = component_addr_off % 16,
2433                              .base = const_off);
2434    } else if (out_mode == ms_out_mode_vram) {
2435       nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b);
2436       nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
2437       return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off,
2438                                  .base = const_off,
2439                                  .memory_modes = nir_var_shader_out);
2440    } else if (out_mode == ms_out_mode_var) {
2441       nir_ssa_def *arr[8] = {0};
2442       unsigned num_32bit_components = num_components * load_bit_size / 32;
2443       for (unsigned comp = 0; comp < num_32bit_components; ++comp) {
2444          unsigned idx = location * 4 + comp + component_addr_off;
2445          arr[comp] = nir_load_var(b, s->out_variables[idx]);
2446       }
2447       if (load_bit_size > 32)
2448          return nir_extract_bits(b, arr, 1, 0, num_components, load_bit_size);
2449       return nir_vec(b, arr, num_components);
2450    } else {
2451       unreachable("Invalid MS output mode for load");
2452    }
2453 }
2454 
2455 static nir_ssa_def *
2456 ms_load_arrayed_output_intrin(nir_builder *b,
2457                               nir_intrinsic_instr *intrin,
2458                               lower_ngg_ms_state *s)
2459 {
2460    nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
2461    nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
2462 
2463    unsigned location = nir_intrinsic_io_semantics(intrin).location;
2464    unsigned component_offset = nir_intrinsic_component(intrin);
2465    unsigned bit_size = intrin->dest.ssa.bit_size;
2466    unsigned num_components = intrin->dest.ssa.num_components;
2467    unsigned load_bit_size = MAX2(bit_size, 32);
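   /* Outputs narrower than 32 bits are loaded as full 32-bit values here;
    * regroup_load_val() below converts the result back to the requested
    * bit size.
    */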
2468 
2469    nir_ssa_def *load =
2470       ms_load_arrayed_output(b, arr_index, base_offset, location, component_offset,
2471                              num_components, load_bit_size, s);
2472 
2473    return regroup_load_val(b, load, bit_size);
2474 }
2475 
2476 static nir_ssa_def *
2477 lower_ms_load_workgroup_index(nir_builder *b,
2478                               UNUSED nir_intrinsic_instr *intrin,
2479                               lower_ngg_ms_state *s)
2480 {
2481    return s->workgroup_index;
2482 }
2483 
2484 static nir_ssa_def *
2485 update_ms_scoped_barrier(nir_builder *b,
2486                          nir_intrinsic_instr *intrin,
2487                          lower_ngg_ms_state *s)
2488 {
2489    /* Output loads and stores are lowered to shared memory access,
2490     * so we have to update the barriers to also reflect this.
2491     */
2492    unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
2493    if (mem_modes & nir_var_shader_out)
2494       mem_modes |= nir_var_mem_shared;
2495    else
2496       return NULL;
2497 
2498    nir_intrinsic_set_memory_modes(intrin, mem_modes);
2499 
2500    return NIR_LOWER_INSTR_PROGRESS;
2501 }
2502 
2503 static nir_ssa_def *
2504 lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
2505 {
2506    lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;
2507 
2508    if (instr->type != nir_instr_type_intrinsic)
2509       return NULL;
2510 
2511    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2512 
2513    switch (intrin->intrinsic) {
2514    case nir_intrinsic_store_output:
2515       return lower_ms_store_output(b, intrin, s);
2516    case nir_intrinsic_load_output:
2517       return lower_ms_load_output(b, intrin, s);
2518    case nir_intrinsic_store_per_vertex_output:
2519    case nir_intrinsic_store_per_primitive_output:
2520       ms_store_arrayed_output_intrin(b, intrin, s);
2521       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
2522    case nir_intrinsic_load_per_vertex_output:
2523    case nir_intrinsic_load_per_primitive_output:
2524       return ms_load_arrayed_output_intrin(b, intrin, s);
2525    case nir_intrinsic_scoped_barrier:
2526       return update_ms_scoped_barrier(b, intrin, s);
2527    case nir_intrinsic_load_workgroup_index:
2528       return lower_ms_load_workgroup_index(b, intrin, s);
2529    default:
2530       unreachable("Not a lowerable mesh shader intrinsic.");
2531    }
2532 }
2533 
2534 static bool
2535 filter_ms_intrinsic(const nir_instr *instr,
2536                     UNUSED const void *st)
2537 {
2538    if (instr->type != nir_instr_type_intrinsic)
2539       return false;
2540 
2541    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2542    return intrin->intrinsic == nir_intrinsic_store_output ||
2543           intrin->intrinsic == nir_intrinsic_load_output ||
2544           intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
2545           intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
2546           intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
2547           intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
2548           intrin->intrinsic == nir_intrinsic_scoped_barrier ||
2549           intrin->intrinsic == nir_intrinsic_load_workgroup_index;
2550 }
2551 
2552 static void
2553 lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
2554 {
2555    nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
2556 }
2557 
2558 static void
2559 ms_emit_arrayed_outputs(nir_builder *b,
2560                         nir_ssa_def *invocation_index,
2561                         uint64_t mask,
2562                         lower_ngg_ms_state *s)
2563 {
2564    nir_ssa_def *zero = nir_imm_int(b, 0);
2565 
2566    u_foreach_bit64(slot, mask) {
2567       /* Should not occur here, handled separately. */
2568       assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
2569 
2570       const nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };
2571       unsigned component_mask = s->output_info[slot].components_mask;
2572 
2573       while (component_mask) {
2574          int start_comp = 0, num_components = 1;
2575          u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
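         /* Illustrative example: a components_mask of 0b1101 yields the runs
          * (start 0, count 1) and (start 2, count 2), i.e. one load is emitted
          * per consecutive run of written components.
          */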
2576 
2577          nir_ssa_def *load =
2578             ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
2579                                    num_components, 32, s);
2580 
2581          nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .component = start_comp,
2582                           .io_semantics = io_sem);
2583       }
2584    }
2585 }
2586 
2587 static void
2588 emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s)
2589 {
2590    b->cursor = nir_before_cf_list(&b->impl->body);
2591 
2592    /* Initialize NIR variables for same-invocation outputs. */
2593    uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;
2594 
2595    u_foreach_bit64(slot, same_invocation_output_mask) {
2596       for (unsigned comp = 0; comp < 4; ++comp) {
2597          unsigned idx = slot * 4 + comp;
2598          s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
2599          nir_store_var(b, s->out_variables[idx], nir_imm_int(b, 0), 0x1);
2600       }
2601    }
2602 
2603    bool uses_workgroup_id =
2604       BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID) ||
2605       BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX);
2606 
2607    if (!uses_workgroup_id)
2608       return;
2609 
2610    /* The HW doesn't support a proper workgroup index for vertex processing stages,
2611     * so we use the vertex ID which is equivalent to the index of the current workgroup
2612     * within the current dispatch.
2613     *
2614     * Due to the register programming of mesh shaders, this value is only filled for
2615     * the first invocation of the first wave. To let other waves know, we use LDS.
2616     */
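   /* The code below implements this: one elected lane of wave 0 stores the
    * value to LDS, a workgroup barrier orders the access, the other waves
    * load it back, and each wave then broadcasts it via read_first_invocation.
    */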
2617    nir_ssa_def *workgroup_index = nir_load_vertex_id_zero_base(b);
2618 
2619    if (s->api_workgroup_size <= s->wave_size) {
2620       /* API workgroup is small, so we don't need to use LDS. */
2621       s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
2622       return;
2623    }
2624 
2625    unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
2626 
2627    nir_ssa_def *zero = nir_imm_int(b, 0);
2628    nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32);
2629    nir_ssa_def *loaded_workgroup_index = NULL;
2630 
2631    /* Use elect to make sure only 1 invocation uses LDS. */
2632    nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
2633    {
2634       nir_ssa_def *wave_id = nir_load_subgroup_id(b);
2635       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
2636       {
2637          nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
2638          nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2639                                .memory_scope = NIR_SCOPE_WORKGROUP,
2640                                .memory_semantics = NIR_MEMORY_ACQ_REL,
2641                                .memory_modes = nir_var_mem_shared);
2642       }
2643       nir_push_else(b, if_wave_0);
2644       {
2645          nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2646                                .memory_scope = NIR_SCOPE_WORKGROUP,
2647                                .memory_semantics = NIR_MEMORY_ACQ_REL,
2648                                .memory_modes = nir_var_mem_shared);
2649          loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
2650       }
2651       nir_pop_if(b, if_wave_0);
2652 
2653       workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
2654    }
2655    nir_pop_if(b, if_elected);
2656 
2657    workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
2658    s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
2659 }
2660 
2661 static void
2662 set_nv_ms_final_output_counts(nir_builder *b,
2663                                lower_ngg_ms_state *s,
2664                                nir_ssa_def **out_num_prm,
2665                                nir_ssa_def **out_num_vtx)
2666 {
2667    /* Limitations of the NV extension:
2668     * - Number of primitives can be written and read by any invocation,
2669     *   so we have to store/load it to/from LDS to make sure the general case works.
2670     * - Number of vertices is not actually known, so we just always use the
2671     *   maximum number here.
2672     */
2673    nir_ssa_def *loaded_num_prm;
2674    nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32);
2675    nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
2676    {
2677       loaded_num_prm = ms_load_num_prims(b, s);
2678    }
2679    nir_pop_if(b, if_elected);
2680    loaded_num_prm = nir_if_phi(b, loaded_num_prm, dont_care);
2681    nir_ssa_def *num_prm = nir_read_first_invocation(b, loaded_num_prm);
2682    nir_ssa_def *num_vtx = nir_imm_int(b, b->shader->info.mesh.max_vertices_out);
2683    num_prm = nir_umin(b, num_prm, nir_imm_int(b, b->shader->info.mesh.max_primitives_out));
2684 
2685    /* If the shader doesn't actually create any primitives, don't allocate any output. */
2686    num_vtx = nir_bcsel(b, nir_ieq_imm(b, num_prm, 0), nir_imm_int(b, 0), num_vtx);
2687 
2688    /* Emit GS_ALLOC_REQ on Wave 0 to let the HW know the output size. */
2689    nir_ssa_def *wave_id = nir_load_subgroup_id(b);
2690    nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
2691    {
2692       nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
2693    }
2694    nir_pop_if(b, if_wave_0);
2695 
2696    *out_num_prm = num_prm;
2697    *out_num_vtx = num_vtx;
2698 }
2699 
2700 static void
2701 emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
2702 {
2703    /* We assume there is always a single end block in the shader. */
2704    nir_block *last_block = nir_impl_last_block(b->impl);
2705    b->cursor = nir_after_block(last_block);
2706 
2707    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
2708                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);
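   /* Wait for all invocations to finish storing their outputs before they
    * are read back and exported below.
    */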
2709 
2710    nir_ssa_def *num_prm;
2711    nir_ssa_def *num_vtx;
2712 
2713    set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx);
2714 
2715    nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
2716 
2717    /* Load vertex/primitive attributes from shared memory and
2718     * emit store_output intrinsics for them.
2719     *
2720     * Contrary to the semantics of the API mesh shader, these are now
2721     * compliant with NGG HW semantics, meaning that these store the
2722     * current thread's vertex attributes in a way the HW can export.
2723     */
2724 
2725    /* Export vertices. */
2726    nir_ssa_def *has_output_vertex = nir_ilt(b, invocation_index, num_vtx);
2727    nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
2728    {
2729       /* All per-vertex attributes. */
2730       ms_emit_arrayed_outputs(b, invocation_index, s->per_vertex_outputs, s);
2731       nir_export_vertex_amd(b);
2732    }
2733    nir_pop_if(b, if_has_output_vertex);
2734 
2735    /* Export primitives. */
2736    nir_ssa_def *has_output_primitive = nir_ilt(b, invocation_index, num_prm);
2737    nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
2738    {
2739       /* Generic per-primitive attributes. */
2740       ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs, s);
2741 
2742       /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
2743       if (s->insert_layer_output) {
2744          nir_ssa_def *layer = nir_load_view_index(b);
2745          const nir_io_semantics io_sem = { .location = VARYING_SLOT_LAYER, .num_slots = 1 };
2746          nir_store_output(b, layer, nir_imm_int(b, 0), .base = VARYING_SLOT_LAYER, .component = 0, .io_semantics = io_sem);
2747          b->shader->info.outputs_written |= VARYING_BIT_LAYER;
2748          b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
2749       }
2750 
2751       /* Primitive connectivity data: describes which vertices the primitive uses. */
2752       nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
2753       nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
2754       nir_ssa_def *indices[3];
2755       nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
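      /* Clamp each vertex index so it never exceeds the number of vertices
       * allocated for this workgroup.
       */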
2756 
2757       for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
2758          indices[i] = nir_u2u32(b, nir_channel(b, indices_loaded, i));
2759          indices[i] = nir_umin(b, indices[i], max_vtx_idx);
2760       }
2761 
2762       nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, NULL, false);
2763       nir_export_primitive_amd(b, prim_exp_arg);
2764    }
2765    nir_pop_if(b, if_has_output_primitive);
2766 }
2767 
2768 static void
2769 handle_smaller_ms_api_workgroup(nir_builder *b,
2770                                 lower_ngg_ms_state *s)
2771 {
2772    if (s->api_workgroup_size >= s->hw_workgroup_size)
2773       return;
2774 
2775    /* Handle barriers manually when the API workgroup
2776     * size is less than the HW workgroup size.
2777     *
2778     * The problem is that the real workgroup launched on NGG HW
2779     * will be larger than the size specified by the API, and the
2780     * extra waves need to keep up with barriers in the API waves.
2781     *
2782     * There are 2 different cases:
2783     * 1. The whole API workgroup fits in a single wave.
2784     *    We can shrink the barriers to subgroup scope and
2785     *    don't need to insert any extra ones.
2786     * 2. The API workgroup occupies multiple waves, but not
2787     *    all. In this case, we emit code that consumes every
2788     *    barrier on the extra waves.
2789     */
2790    assert(s->hw_workgroup_size % s->wave_size == 0);
2791    bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
2792    bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
2793    bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
2794 
2795    unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
2796    unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
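   /* Illustrative example (assumed sizes): with an API workgroup of 40
    * invocations, a wave size of 32 and a HW workgroup of 128 invocations,
    * the API shader occupies num_api_waves = 2 waves, the barriers can't be
    * shrunk to subgroup scope, and the 2 extra waves take the path below.
    */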
2797 
2798    /* Scan the shader for workgroup barriers. */
2799    if (scan_barriers) {
2800       bool has_any_workgroup_barriers = false;
2801 
2802       nir_foreach_block(block, b->impl) {
2803          nir_foreach_instr_safe(instr, block) {
2804             if (instr->type != nir_instr_type_intrinsic)
2805                continue;
2806 
2807             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2808             bool is_workgroup_barrier =
2809                intrin->intrinsic == nir_intrinsic_scoped_barrier &&
2810                nir_intrinsic_execution_scope(intrin) == NIR_SCOPE_WORKGROUP;
2811 
2812             if (!is_workgroup_barrier)
2813                continue;
2814 
2815             if (can_shrink_barriers) {
2816                /* Every API invocation runs in the first wave.
2817                 * In this case, we can change the barriers to subgroup scope
2818                 * and avoid adding additional barriers.
2819                 */
2820                nir_intrinsic_set_memory_scope(intrin, NIR_SCOPE_SUBGROUP);
2821                nir_intrinsic_set_execution_scope(intrin, NIR_SCOPE_SUBGROUP);
2822             } else {
2823                has_any_workgroup_barriers = true;
2824             }
2825          }
2826       }
2827 
2828       need_additional_barriers &= has_any_workgroup_barriers;
2829    }
2830 
2831    /* Extract the full control flow of the shader. */
2832    nir_cf_list extracted;
2833    nir_cf_extract(&extracted, nir_before_cf_list(&b->impl->body), nir_after_cf_list(&b->impl->body));
2834    b->cursor = nir_before_cf_list(&b->impl->body);
2835 
2836    /* Wrap the shader in an if to ensure that only the necessary number of lanes run it. */
2837    nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
2838    nir_ssa_def *zero = nir_imm_int(b, 0);
2839 
2840    if (need_additional_barriers) {
2841       /* First invocation stores 0 to number of API waves in flight. */
2842       nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
2843       {
2844          nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
2845       }
2846       nir_pop_if(b, if_first_in_workgroup);
2847 
2848       nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2849                             .memory_scope = NIR_SCOPE_WORKGROUP,
2850                             .memory_semantics = NIR_MEMORY_ACQ_REL,
2851                             .memory_modes = nir_var_shader_out | nir_var_mem_shared);
2852    }
2853 
2854    nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, s->api_workgroup_size));
2855    nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
2856    {
2857       nir_cf_reinsert(&extracted, b->cursor);
2858       b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);
2859 
2860       if (need_additional_barriers) {
2861          /* One invocation in each API wave decrements the number of API waves in flight. */
2862          nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
2863          {
2864             nir_shared_atomic_add(b, 32, zero, nir_imm_int(b, -1u), .base = api_waves_in_flight_addr);
2865          }
2866          nir_pop_if(b, if_elected_again);
2867 
2868          nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2869                                .memory_scope = NIR_SCOPE_WORKGROUP,
2870                                .memory_semantics = NIR_MEMORY_ACQ_REL,
2871                                .memory_modes = nir_var_shader_out | nir_var_mem_shared);
2872       }
2873    }
2874    nir_pop_if(b, if_has_api_ms_invocation);
2875 
2876    if (need_additional_barriers) {
2877       /* Make sure that waves that don't run any API invocations execute
2878        * the same number of barriers as those that do.
2879        *
2880        * We do this by executing a barrier until the number of API waves
2881        * in flight becomes zero.
2882        */
2883       nir_ssa_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
2884       nir_ssa_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
2885       nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
2886       {
2887          nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
2888          {
2889             nir_loop *loop = nir_push_loop(b);
2890             {
2891                nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2892                                      .memory_scope = NIR_SCOPE_WORKGROUP,
2893                                      .memory_semantics = NIR_MEMORY_ACQ_REL,
2894                                      .memory_modes = nir_var_shader_out | nir_var_mem_shared);
2895 
2896                nir_ssa_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
2897                nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
2898                {
2899                   nir_jump(b, nir_jump_break);
2900                }
2901                nir_pop_if(b, if_break);
2902             }
2903             nir_pop_loop(b, loop);
2904          }
2905          nir_pop_if(b, if_elected);
2906       }
2907       nir_pop_if(b, if_wave_has_no_api_ms);
2908    }
2909 }
2910 
2911 static void
2912 ms_move_output(ms_out_part *from, ms_out_part *to)
2913 {
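   /* Moves the highest-numbered output location of "from" into "to".
    * Used below to spill attributes from LDS to the VRAM scratch ring.
    */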
2914    uint64_t loc = util_logbase2_64(from->mask);
2915    uint64_t bit = BITFIELD64_BIT(loc);
2916    from->mask ^= bit;
2917    to->mask |= bit;
2918 }
2919 
2920 static void
2921 ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
2922                                    unsigned max_vertices,
2923                                    unsigned max_primitives)
2924 {
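   /* Each output slot occupies one vec4 (16 bytes) per vertex or primitive;
    * e.g. 3 LDS vertex attributes for 64 vertices take 3 * 64 * 16 bytes.
    */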
2925    uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
2926    uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
2927    l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
2928    l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
2929 
2930    uint32_t vram_vtx_attr_size = util_bitcount64(l->vram.vtx_attr.mask) * max_vertices * 16;
2931    l->vram.prm_attr.addr = ALIGN(l->vram.vtx_attr.addr + vram_vtx_attr_size, 16);
2932 }
2933 
2934 static ms_out_mem_layout
2935 ms_calculate_output_layout(unsigned api_shared_size,
2936                            uint64_t per_vertex_output_mask,
2937                            uint64_t per_primitive_output_mask,
2938                            uint64_t cross_invocation_output_access,
2939                            unsigned max_vertices,
2940                            unsigned max_primitives,
2941                            unsigned vertices_per_prim)
2942 {
2943    uint64_t lds_per_vertex_output_mask = per_vertex_output_mask & cross_invocation_output_access;
2944    uint64_t lds_per_primitive_output_mask = per_primitive_output_mask & cross_invocation_output_access;
2945 
2946    /* Shared memory used by the API shader. */
2947    ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
2948 
2949    /* Outputs without cross-invocation access can be stored in variables. */
2950    l.var.vtx_attr.mask = per_vertex_output_mask & ~lds_per_vertex_output_mask;
2951    l.var.prm_attr.mask = per_primitive_output_mask & ~lds_per_primitive_output_mask;
2952 
2953    /* Workgroup information, see ms_workgroup_* for the layout. */
2954    l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
2955    l.lds.total_size = l.lds.workgroup_info_addr + 16;
2956 
2957    /* Per-vertex and per-primitive output attributes.
2958     * Outputs without cross-invocation access are not included here.
2959     * First, try to put all outputs into LDS (shared memory).
2960     * If they don't fit, try to move them to VRAM one by one.
2961     */
2962    l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
2963    l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
2964    l.lds.prm_attr.mask = lds_per_primitive_output_mask;
2965    ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
2966 
2967    /* NGG shaders can only address up to 32K LDS memory.
2968     * The spec requires us to allow the application to use at least up to 28K
2969     * shared memory. Additionally, we reserve 2K for driver internal use
2970     * (e.g. primitive indices and such, see below).
2971     *
2972     * Move the outputs that do not fit LDS, to VRAM.
2973     * Start with per-primitive attributes, because those are grouped at the end.
2974     */
2975    while (l.lds.total_size >= 30 * 1024) {
2976       if (l.lds.prm_attr.mask)
2977          ms_move_output(&l.lds.prm_attr, &l.vram.prm_attr);
2978       else if (l.lds.vtx_attr.mask)
2979          ms_move_output(&l.lds.vtx_attr, &l.vram.vtx_attr);
2980       else
2981          unreachable("API shader uses too much shared memory.");
2982 
2983       ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
2984    }
2985 
2986    /* Indices: flat array of 8-bit vertex indices for each primitive. */
2987    l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
2988    l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
2989 
2990    /* NGG is only allowed to address up to 32K of LDS. */
2991    assert(l.lds.total_size <= 32 * 1024);
2992    return l;
2993 }
2994 
2995 void
2996 ac_nir_lower_ngg_ms(nir_shader *shader,
2997                     bool *out_needs_scratch_ring,
2998                     unsigned wave_size,
2999                     bool multiview)
3000 {
3001    unsigned vertices_per_prim =
3002       num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
3003 
3004    uint64_t special_outputs =
3005       BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
3006    uint64_t per_vertex_outputs =
3007       shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~special_outputs;
3008    uint64_t per_primitive_outputs =
3009       shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
3010 
3011    /* Can't handle indirect register addressing, pretend as if they were cross-invocation. */
3012    uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
3013                                       shader->info.outputs_accessed_indirectly;
3014 
3015    unsigned max_vertices = shader->info.mesh.max_vertices_out;
3016    unsigned max_primitives = shader->info.mesh.max_primitives_out;
3017 
3018    ms_out_mem_layout layout =
3019       ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
3020                                  cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
3021 
3022    shader->info.shared_size = layout.lds.total_size;
3023    *out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask;
3024 
3025    /* The workgroup size that is specified by the API shader may be different
3026     * from the size of the workgroup that actually runs on the HW, due to the
3027     * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed.
3028     *
3029     * Therefore, we must make sure that when the API workgroup size is smaller,
3030     * we don't run the API shader on more HW invocations than is necessary.
3031     */
3032    unsigned api_workgroup_size = shader->info.workgroup_size[0] *
3033                                  shader->info.workgroup_size[1] *
3034                                  shader->info.workgroup_size[2];
3035 
3036    unsigned hw_workgroup_size =
3037       ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size);
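   /* Illustrative example (assumed limits): an API workgroup of 32 invocations
    * that emits at most 81 vertices and 128 primitives needs a HW workgroup of
    * ALIGN(MAX3(32, 128, 81), 64) = 128 invocations with a wave size of 64.
    */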
3038 
3039    lower_ngg_ms_state state = {
3040       .layout = layout,
3041       .wave_size = wave_size,
3042       .per_vertex_outputs = per_vertex_outputs,
3043       .per_primitive_outputs = per_primitive_outputs,
3044       .vertices_per_prim = vertices_per_prim,
3045       .api_workgroup_size = api_workgroup_size,
3046       .hw_workgroup_size = hw_workgroup_size,
3047       .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
3048    };
3049 
3050    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
3051    assert(impl);
3052 
3053    nir_builder builder;
3054    nir_builder *b = &builder; /* This is to avoid the & */
3055    nir_builder_init(b, impl);
3056    b->cursor = nir_before_cf_list(&impl->body);
3057 
3058    handle_smaller_ms_api_workgroup(b, &state);
3059    emit_ms_prelude(b, &state);
3060    nir_metadata_preserve(impl, nir_metadata_none);
3061 
3062    lower_ms_intrinsics(shader, &state);
3063 
3064    emit_ms_finale(b, &state);
3065    nir_metadata_preserve(impl, nir_metadata_none);
3066 
3067    /* Cleanup */
3068    nir_lower_vars_to_ssa(shader);
3069    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
3070    nir_lower_alu_to_scalar(shader, NULL, NULL);
3071    nir_lower_phis_to_scalar(shader, true);
3072 
3073    nir_validate_shader(shader, "after emitting NGG MS");
3074 }
3075