1 /*
2 * Copyright © 2021 Valve Corporation
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "ac_nir.h"
8 #include "ac_nir_helpers.h"
9 #include "amdgfxregs.h"
10 #include "nir_builder.h"
11 #include "nir_xfb_info.h"
12 #include "util/u_math.h"
13 #include "util/u_vector.h"
14
15 enum {
16 nggc_passflag_used_by_pos = 1,
17 nggc_passflag_used_by_other = 2,
18 nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
19 };
20
21 typedef struct
22 {
23 nir_def *ssa;
24 nir_variable *var;
25 } reusable_nondeferred_variable;
26
27 typedef struct
28 {
29 gl_varying_slot slot;
30 nir_def *chan[4];
31 } vs_output;
32
33 typedef struct
34 {
35 const ac_nir_lower_ngg_options *options;
36
37 nir_variable *position_value_var;
38 nir_variable *prim_exp_arg_var;
39 nir_variable *es_accepted_var;
40 nir_variable *gs_accepted_var;
41 nir_variable *gs_exported_var;
42 nir_variable *gs_vtx_indices_vars[3];
43
44 nir_def *vtx_addr[3];
45
46 struct u_vector reusable_nondeferred_variables;
47
48 bool early_prim_export;
49 bool streamout_enabled;
50 bool has_user_edgeflags;
51 bool skip_primitive_id;
52 unsigned max_num_waves;
53
54 /* LDS params */
55 unsigned pervertex_lds_bytes;
56
57 uint64_t inputs_needed_by_pos;
58 uint64_t inputs_needed_by_others;
59
60 nir_instr *compact_arg_stores[4];
61 nir_intrinsic_instr *overwrite_args;
62 nir_variable *repacked_rel_patch_id;
63
64 /* clip distance */
65 nir_variable *clip_vertex_var;
66 nir_variable *clipdist_neg_mask_var;
67 bool has_clipdist;
68
69 /* outputs */
70 ac_nir_prerast_out out;
71 } lower_ngg_nogs_state;
72
73 typedef struct
74 {
75 const ac_nir_lower_ngg_options *options;
76
77 nir_function_impl *impl;
78 int const_out_vtxcnt[4];
79 int const_out_prmcnt[4];
80 unsigned max_num_waves;
81 unsigned num_vertices_per_primitive;
82 nir_def *lds_addr_gs_out_vtx;
83 nir_def *lds_addr_gs_scratch;
84 unsigned lds_bytes_per_gs_out_vertex;
85 unsigned lds_offs_primflags;
86 bool output_compile_time_known;
87 bool streamout_enabled;
88 /* Outputs */
89 ac_nir_prerast_out out;
90 /* Count per stream. */
91 nir_def *vertex_count[4];
92 nir_def *primitive_count[4];
93 } lower_ngg_gs_state;
94
95 /* Per-vertex LDS layout of culling shaders */
96 enum {
97 /* Position of the ES vertex (at the beginning for alignment reasons) */
98 lds_es_pos_x = 0,
99 lds_es_pos_y = 4,
100 lds_es_pos_z = 8,
101 lds_es_pos_w = 12,
102
103 /* 1 when the vertex is accepted, 0 if it should be culled */
104 lds_es_vertex_accepted = 16,
105 /* ID of the thread which will export the current thread's vertex */
106 lds_es_exporter_tid = 17,
107 /* bit i is set when the i'th clip distance of a vertex is negative */
108 lds_es_clipdist_neg_mask = 18,
109 /* TES only, relative patch ID, less than max workgroup size */
110 lds_es_tes_rel_patch_id = 19,
111
112 /* Repacked arguments - also listed separately for VS and TES */
113 lds_es_arg_0 = 20,
114 };
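/* The total per-vertex stride of this layout (including the repacked arguments
 * starting at lds_es_arg_0) is computed by ngg_nogs_get_culling_pervertex_lds_size().
 */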
115
116 typedef struct {
117 nir_def *num_repacked_invocations;
118 nir_def *repacked_invocation_index;
119 } wg_repack_result;
120
121 /**
122 * Computes a horizontal sum of 8-bit packed values loaded from LDS.
123 *
124 * Each lane N will sum packed bytes 0 to N.
125 * We only care about the results from up to wave_id lanes.
126 * (Other lanes are not deactivated but their calculation is not used.)
127 */
128 static nir_def *
129 summarize_repack(nir_builder *b, nir_def *packed_counts, bool mask_lane_id, unsigned num_lds_dwords)
130 {
131 /* We'll use shift to filter out the bytes not needed by the current lane.
132 *
133 * For each row:
134 * Need to shift by: `num_lds_dwords * 4 - 1 - lane_id_in_row` (in bytes)
135 * in order to implement an inclusive scan.
136 *
137 * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
138 * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
139 * therefore v_dot can get rid of the unneeded values.
140 *
141 * If the v_dot instruction can't be used, we left-shift the packed bytes
142 * in order to shift out the unneeded bytes and shift in zeroes instead,
143 * then we sum them using v_msad_u8.
144 */
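/* For example, with num_lds_dwords == 1 the shift is 24 - 8 * lane_id (in bits):
 * lane 0 shifts by 24 so only byte 0 contributes to its sum, while lane 3 shifts
 * by 0 so bytes 0..3 are all summed.
 */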
145
146 nir_def *lane_id = nir_load_subgroup_invocation(b);
147
148 /* Mask lane ID so that lanes 16...31 also have the ID 0...15,
149 * in order to perform a second horizontal sum in parallel when needed.
150 */
151 if (mask_lane_id)
152 lane_id = nir_iand_imm(b, lane_id, 0xf);
153
154 nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8);
155 assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8);
156 bool use_dot = b->shader->options->has_udot_4x8;
157
158 if (num_lds_dwords == 1) {
159 /* Broadcast the packed data we read from LDS
160 * (to the first 16 lanes of the row, but we only care up to num_waves).
161 */
162 nir_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
163
164 /* Horizontally add the packed bytes. */
165 if (use_dot) {
166 nir_def *dot_op = nir_ushr(b, nir_imm_int(b, 0x01010101), shift);
167 return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
168 } else {
169 nir_def *sad_op = nir_ishl(b, packed, shift);
170 return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
171 }
172 } else if (num_lds_dwords == 2) {
173 /* Broadcast the packed data we read from LDS
174 * (to the first 16 lanes of the row, but we only care up to num_waves).
175 */
176 nir_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
177 nir_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
178
179 /* Horizontally add the packed bytes. */
180 if (use_dot) {
181 nir_def *dot_op = nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift);
182 nir_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
183 return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
184 } else {
185 nir_def *sad_op = nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift);
186 nir_def *sum = nir_msad_4x8(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
187 return nir_msad_4x8(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
188 }
189 } else {
190 unreachable("Unimplemented NGG wave count");
191 }
192 }
193
194 /**
195 * Repacks invocations in the current workgroup to eliminate gaps between them.
196 *
197 * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave) for each repack.
198 * Assumes that all invocations in the workgroup are active (exec = -1).
199 */
200 static void
201 repack_invocations_in_workgroup(nir_builder *b, nir_def **input_bool,
202 wg_repack_result *results, const unsigned num_repacks,
203 nir_def *lds_addr_base, unsigned max_num_waves,
204 unsigned wave_size)
205 {
206 /* We can currently only do up to 2 repacks at a time. */
207 assert(num_repacks <= 2);
208
209 /* STEP 1. Count surviving invocations in the current wave.
210 *
211 * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
212 */
213
214 nir_def *input_mask[2];
215 nir_def *surviving_invocations_in_current_wave[2];
216
217 for (unsigned i = 0; i < num_repacks; ++i) {
218 /* Input should be boolean: 1 if the current invocation should survive the repack. */
219 assert(input_bool[i]->bit_size == 1);
220
221 input_mask[i] = nir_ballot(b, 1, wave_size, input_bool[i]);
222 surviving_invocations_in_current_wave[i] = nir_bit_count(b, input_mask[i]);
223 }
224
225 /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
226 if (max_num_waves == 1) {
227 for (unsigned i = 0; i < num_repacks; ++i) {
228 results[i].num_repacked_invocations = surviving_invocations_in_current_wave[i];
229 results[i].repacked_invocation_index = nir_mbcnt_amd(b, input_mask[i], nir_imm_int(b, 0));
230 }
231 return;
232 }
233
234 /* STEP 2. Waves tell each other their number of surviving invocations.
235 *
236 * Row 0 (lanes 0-15) performs the first repack, and Row 1 (lanes 16-31) the second in parallel.
237 * Each wave activates only its first lane per row, which stores the number of surviving
238 * invocations in that wave into the LDS for that repack, then reads the numbers from every wave.
239 *
240 * The workgroup size of NGG shaders is at most 256, which means
241 * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
242 * For each repack:
243 * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
244 * (The maximum is 4 dwords for 2 repacks in Wave32 mode.)
245 */
246
247 const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
248 assert(num_lds_dwords <= 2);
249
250 /* The first lane of each row (per repack) needs to access the LDS. */
251 const unsigned ballot = num_repacks == 1 ? 1 : 0x10001;
252
253 nir_def *wave_id = nir_load_subgroup_id(b);
254 nir_def *dont_care = nir_undef(b, 1, num_lds_dwords * 32);
255 nir_def *packed_counts = NULL;
256
257 nir_if *if_use_lds = nir_push_if(b, nir_inverse_ballot(b, 1, nir_imm_intN_t(b, ballot, wave_size)));
258 {
259 nir_def *store_val = surviving_invocations_in_current_wave[0];
260
261 if (num_repacks == 2) {
262 nir_def *lane_id_0 = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 1, wave_size));
263 nir_def *off = nir_bcsel(b, lane_id_0, nir_imm_int(b, 0), nir_imm_int(b, num_lds_dwords * 4));
264 lds_addr_base = nir_iadd_nuw(b, lds_addr_base, off);
265 store_val = nir_bcsel(b, lane_id_0, store_val, surviving_invocations_in_current_wave[1]);
266 }
267
268 nir_def *store_byte = nir_u2u8(b, store_val);
269 nir_def *lds_offset = nir_iadd(b, lds_addr_base, wave_id);
270 nir_store_shared(b, store_byte, lds_offset);
271
272 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
273 .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
274
275 packed_counts = nir_load_shared(b, 1, num_lds_dwords * 32, lds_addr_base, .align_mul = 8u);
276 }
277 nir_pop_if(b, if_use_lds);
278
279 packed_counts = nir_if_phi(b, packed_counts, dont_care);
280
281 /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
282 *
283 * By now, every wave knows the number of surviving invocations in all waves.
284 * Each number is 1 byte, and they are packed into up to 2 dwords.
285 *
286 * For each row (of 16 lanes):
287 * Each lane N (in the row) will sum the number of surviving invocations inclusively from waves 0 to N.
288 * If the workgroup has M waves, then each row will use only its first M lanes for this.
289 * (Other lanes are not deactivated but their calculation is not used.)
290 *
291 * - We read the sum from the lane whose id (in the row) is the current wave's id,
292 * and subtract the number of its own surviving invocations.
293 * Add the masked bitcount to this, and we get the repacked invocation index.
294 * - We read the sum from the lane whose id (in the row) is the number of waves in the workgroup minus 1.
295 * This is the total number of surviving invocations in the workgroup.
296 */
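/* For example, with 4 waves whose surviving counts are {5, 0, 7, 3}, the packed
 * bytes are 05 00 07 03 (LSB first) and the per-lane inclusive sums are 5, 5, 12, 15.
 * Wave 2 reads lane 2 (12) and subtracts its own count (7) to get its index base of 5,
 * and every wave reads lane 3 (num_waves - 1) to get the workgroup total of 15.
 */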
297
298 nir_def *num_waves = nir_load_num_subgroups(b);
299 nir_def *sum = summarize_repack(b, packed_counts, num_repacks == 2, num_lds_dwords);
300
301 for (unsigned i = 0; i < num_repacks; ++i) {
302 nir_def *index_base_lane = nir_iadd_imm_nuw(b, wave_id, i * 16);
303 nir_def *num_invocations_lane = nir_iadd_imm(b, num_waves, i * 16 - 1);
304 nir_def *wg_repacked_index_base =
305 nir_isub(b, nir_read_invocation(b, sum, index_base_lane), surviving_invocations_in_current_wave[i]);
306 results[i].num_repacked_invocations =
307 nir_read_invocation(b, sum, num_invocations_lane);
308 results[i].repacked_invocation_index =
309 nir_mbcnt_amd(b, input_mask[i], wg_repacked_index_base);
310 }
311 }
312
313 static nir_def *
314 pervertex_lds_addr(nir_builder *b, nir_def *vertex_idx, unsigned per_vtx_bytes)
315 {
316 return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
317 }
318
319 static void
320 alloc_vertices_and_primitives(nir_builder *b,
321 nir_def *num_vtx,
322 nir_def *num_prim)
323 {
324 /* The caller should only call this conditionally on wave 0.
325 *
326 * Send GS Alloc Request message from the first wave of the group to SPI.
327 * Message payload (in the m0 register) is:
328 * - bits 0..10: number of vertices in group
329 * - bits 12..22: number of primitives in group
330 */
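/* For example, a group with 3 vertices and 1 primitive sends m0 = (1 << 12) | 3 = 0x1003. */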
331
332 nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prim, 12), num_vtx);
333 nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ);
334 }
335
336 static void
337 alloc_vertices_and_primitives_gfx10_workaround(nir_builder *b,
338 nir_def *num_vtx,
339 nir_def *num_prim)
340 {
341 /* HW workaround for a GPU hang with 100% culling on GFX10.
342 * We always have to export at least 1 primitive.
343 * Export a degenerate triangle using vertex 0 for all 3 vertices.
344 *
345 * NOTE: We rely on the caller to set the vertex count also to 0 when the primitive count is 0.
346 */
347 nir_def *is_prim_cnt_0 = nir_ieq_imm(b, num_prim, 0);
348 nir_if *if_prim_cnt_0 = nir_push_if(b, is_prim_cnt_0);
349 {
350 nir_def *one = nir_imm_int(b, 1);
351 alloc_vertices_and_primitives(b, one, one);
352
353 nir_def *tid = nir_load_subgroup_invocation(b);
354 nir_def *is_thread_0 = nir_ieq_imm(b, tid, 0);
355 nir_if *if_thread_0 = nir_push_if(b, is_thread_0);
356 {
357 /* The vertex indices are 0, 0, 0. */
358 nir_export_amd(b, nir_imm_zero(b, 4, 32),
359 .base = V_008DFC_SQ_EXP_PRIM,
360 .flags = AC_EXP_FLAG_DONE,
361 .write_mask = 1);
362
363 /* The HW culls primitives with NaN outputs. -1 is also NaN and saves
364 * a dword in the binary because it can be inlined as a constant.
365 */
366 nir_export_amd(b, nir_imm_ivec4(b, -1, -1, -1, -1),
367 .base = V_008DFC_SQ_EXP_POS,
368 .flags = AC_EXP_FLAG_DONE,
369 .write_mask = 0xf);
370 }
371 nir_pop_if(b, if_thread_0);
372 }
373 nir_push_else(b, if_prim_cnt_0);
374 {
375 alloc_vertices_and_primitives(b, num_vtx, num_prim);
376 }
377 nir_pop_if(b, if_prim_cnt_0);
378 }
379
380 static void
381 ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *s)
382 {
383 for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
384 s->gs_vtx_indices_vars[v] = nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx_addr");
385
386 nir_def *vtx;
387
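/* The input vertex indices are packed differently depending on the mode:
 * - GFX12: 8-bit indices at bit offsets 0, 9 and 18 of the packed passthrough primitive.
 * - Passthrough before GFX12: 9-bit indices at bit offsets 0, 10 and 20.
 * - Otherwise: 16-bit indices, two per gs_vertex_offset argument.
 */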
388 if (s->options->gfx_level >= GFX12) {
389 vtx = nir_ubfe_imm(b, nir_load_packed_passthrough_primitive_amd(b), 9 * v, 8);
390 } else if (s->options->passthrough) {
391 vtx = nir_ubfe_imm(b, nir_load_packed_passthrough_primitive_amd(b), 10 * v, 9);
392 } else {
393 vtx = nir_ubfe_imm(b, nir_load_gs_vertex_offset_amd(b, .base = v / 2u),
394 (v & 1u) * 16u, 16u);
395 }
396
397 nir_store_var(b, s->gs_vtx_indices_vars[v], vtx, 0x1);
398 }
399 }
400
401 static nir_def *
402 emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *s)
403 {
404 if (s->options->gfx_level >= GFX12 || s->options->passthrough) {
405 return nir_load_packed_passthrough_primitive_amd(b);
406 } else {
407 nir_def *vtx_idx[3] = {0};
408
409 for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v)
410 vtx_idx[v] = nir_load_var(b, s->gs_vtx_indices_vars[v]);
411
412 return ac_nir_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive, vtx_idx, NULL,
413 s->options->gfx_level);
414 }
415 }
416
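/* The merged wave info argument packs the current wave's number of input vertices
 * (ES threads) in bits 0..7 and its number of input primitives (GS threads) in bits 8..15.
 */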
417 static nir_def *
418 has_input_vertex(nir_builder *b)
419 {
420 return nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b));
421 }
422
423 static nir_def *
424 has_input_primitive(nir_builder *b)
425 {
426 return nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b), .base = 8);
427 }
428
429 static void
430 nogs_prim_gen_query(nir_builder *b, lower_ngg_nogs_state *s)
431 {
432 if (!s->options->has_gen_prim_query)
433 return;
434
435 nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
436 {
437 /* Activate only 1 lane and add the number of primitives to the query result. */
438 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
439 {
440 /* Number of input primitives in the current wave. */
441 nir_def *num_input_prims = nir_ubfe_imm(b, nir_load_merged_wave_info_amd(b),
442 8, 8);
443
444 /* Add to stream 0 primitive generated counter. */
445 nir_atomic_add_gen_prim_count_amd(b, num_input_prims, .stream_id = 0);
446 }
447 nir_pop_if(b, if_elected);
448 }
449 nir_pop_if(b, if_shader_query);
450 }
451
452 static void
453 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *s, nir_def *arg)
454 {
455 nir_if *if_gs_thread = nir_push_if(b, nir_load_var(b, s->gs_exported_var));
456 {
457 if (!arg)
458 arg = emit_ngg_nogs_prim_exp_arg(b, s);
459
460 /* pack user edge flag info into arg */
461 if (s->has_user_edgeflags) {
462 /* Workgroup barrier: wait for ES threads to store user edge flags to LDS. */
463 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
464 .memory_scope = SCOPE_WORKGROUP,
465 .memory_semantics = NIR_MEMORY_ACQ_REL,
466 .memory_modes = nir_var_mem_shared);
467
468 unsigned edge_flag_bits = ac_get_all_edge_flag_bits(s->options->gfx_level);
469 nir_def *mask = nir_imm_intN_t(b, ~edge_flag_bits, 32);
470
471 unsigned edge_flag_offset = 0;
472 if (s->streamout_enabled) {
473 unsigned packed_location =
474 util_bitcount64(b->shader->info.outputs_written &
475 BITFIELD64_MASK(VARYING_SLOT_EDGE));
476 edge_flag_offset = packed_location * 16;
477 }
478
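/* Insert each vertex's edge flag into the primitive export argument:
 * GFX12 stores it at bit 8 of each 9-bit vertex index field,
 * older generations at bit 9 of each 10-bit field.
 */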
479 for (int i = 0; i < s->options->num_vertices_per_primitive; i++) {
480 nir_def *vtx_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
481 nir_def *addr = pervertex_lds_addr(b, vtx_idx, s->pervertex_lds_bytes);
482 nir_def *edge = nir_load_shared(b, 1, 32, addr, .base = edge_flag_offset);
483
484 if (s->options->gfx_level >= GFX12)
485 mask = nir_ior(b, mask, nir_ishl_imm(b, edge, 8 + i * 9));
486 else
487 mask = nir_ior(b, mask, nir_ishl_imm(b, edge, 9 + i * 10));
488 }
489 arg = nir_iand(b, arg, mask);
490 }
491
492 ac_nir_export_primitive(b, arg, NULL);
493
494 /* Store implicit primitive ID when configured as a per-primitive output on GFX10.3.
495 * Because this uses the export space, do it together with the primitive export.
496 */
497 if (s->options->gfx_level == GFX10_3 && s->options->export_primitive_id_per_prim) {
498 const uint8_t offset = s->options->vs_output_param_offset[VARYING_SLOT_PRIMITIVE_ID];
499 nir_def *prim_id = nir_load_primitive_id(b);
500 nir_def *undef = nir_undef(b, 1, 32);
501 ac_nir_prerast_out out = {
502 .infos = {{.components_mask = 1, .as_varying_mask = 1}},
503 .outputs = {{prim_id, undef, undef, undef}}
504 };
505
506 ac_nir_export_parameters(b, &offset, 1, 0, &out);
507 }
508 }
509 nir_pop_if(b, if_gs_thread);
510 }
511
512 static void
513 emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *s)
514 {
515 nir_def *gs_thread =
516 s->gs_accepted_var ? nir_load_var(b, s->gs_accepted_var) : has_input_primitive(b);
517
518 nir_if *if_gs_thread = nir_push_if(b, gs_thread);
519 {
520 /* Copy Primitive IDs from GS threads to the LDS address
521 * corresponding to the ES thread of the provoking vertex.
522 * It will be exported as a per-vertex attribute.
523 */
524 nir_def *gs_vtx_indices[3];
525 for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++)
526 gs_vtx_indices[i] = nir_load_var(b, s->gs_vtx_indices_vars[i]);
527
528 nir_def *provoking_vertex = nir_load_provoking_vtx_in_prim_amd(b);
529 nir_def *provoking_vtx_idx = nir_select_from_ssa_def_array(
530 b, gs_vtx_indices, s->options->num_vertices_per_primitive, provoking_vertex);
531
532 nir_def *prim_id = nir_load_primitive_id(b);
533 nir_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, s->pervertex_lds_bytes);
534
535 /* The primitive ID is always stored at the end of a vertex's LDS space. */
536 nir_store_shared(b, prim_id, addr, .base = s->pervertex_lds_bytes - 4);
537 }
538 nir_pop_if(b, if_gs_thread);
539 }
540
541 /* Store implicit primitive ID when configured as a per-primitive output on GFX11+.
542 * This is done separately from the primitive export on GFX11 in order to
543 * optimize attribute ring access.
544 */
545 static void
546 emit_ngg_nogs_prim_id_store_per_prim_to_attr_ring(nir_builder *b, lower_ngg_nogs_state *s)
547 {
548 assert(s->options->gfx_level >= GFX11);
549
550 nir_def *is_gs_thread = nir_load_var(b, s->gs_exported_var);
551 nir_def *highest_gs_thread = nir_ufind_msb(b, nir_ballot(b, 1, s->options->wave_size, is_gs_thread));
552 nir_def *max_num_gs_threads = nir_iadd_imm_nuw(b, highest_gs_thread, 1);
553
554 const uint8_t offset = s->options->vs_output_param_offset[VARYING_SLOT_PRIMITIVE_ID];
555 ac_nir_prerast_out out = {
556 .infos = {{.components_mask = 1, .as_varying_mask = 1}},
557 .outputs = {{nir_load_primitive_id(b), NULL, NULL, NULL}}
558 };
559
560 ac_nir_store_parameters_to_attr_ring(b, &offset, 1, 0, &out, NULL, max_num_gs_threads);
561 }
562
563 static void
564 emit_store_ngg_nogs_es_primitive_id(nir_builder *b, lower_ngg_nogs_state *s)
565 {
566 nir_def *prim_id = NULL;
567
568 if (b->shader->info.stage == MESA_SHADER_VERTEX) {
569 /* LDS address where the primitive ID is stored */
570 nir_def *thread_id_in_threadgroup = nir_load_local_invocation_index(b);
571 nir_def *addr =
572 pervertex_lds_addr(b, thread_id_in_threadgroup, s->pervertex_lds_bytes);
573
574 /* Load primitive ID from LDS */
575 prim_id = nir_load_shared(b, 1, 32, addr, .base = s->pervertex_lds_bytes - 4);
576 } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
577 /* Just use tess eval primitive ID, which is the same as the patch ID. */
578 prim_id = nir_load_primitive_id(b);
579 }
580
581 s->out.outputs[VARYING_SLOT_PRIMITIVE_ID][0] = prim_id;
582 s->out.infos[VARYING_SLOT_PRIMITIVE_ID].as_varying_mask |= 1;
583
584 /* Update outputs_written to reflect that the pass added a new output. */
585 b->shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
586 }
587
588 static void
589 add_clipdist_bit(nir_builder *b, nir_def *dist, unsigned index, nir_variable *mask)
590 {
591 nir_def *is_neg = nir_flt_imm(b, dist, 0);
592 nir_def *neg_mask = nir_ishl_imm(b, nir_b2i32(b, is_neg), index);
593 neg_mask = nir_ior(b, neg_mask, nir_load_var(b, mask));
594 nir_store_var(b, mask, neg_mask, 1);
595 }
596
597 static bool
598 remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
599 {
600 lower_ngg_nogs_state *s = (lower_ngg_nogs_state *) state;
601
602 if (instr->type != nir_instr_type_intrinsic)
603 return false;
604
605 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
606
607 /* These are not allowed in VS / TES */
608 assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
609 intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
610
611 /* We are only interested in output stores now */
612 if (intrin->intrinsic != nir_intrinsic_store_output)
613 return false;
614
615 b->cursor = nir_before_instr(instr);
616
617 /* no indirect output */
618 assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
619
620 unsigned writemask = nir_intrinsic_write_mask(intrin);
621 unsigned component = nir_intrinsic_component(intrin);
622 nir_def *store_val = intrin->src[0].ssa;
623
624 /* Position output - store the value to a variable, remove output store */
625 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
626 switch (io_sem.location) {
627 case VARYING_SLOT_POS:
628 ac_nir_store_var_components(b, s->position_value_var, store_val, component, writemask);
629 break;
630 case VARYING_SLOT_CLIP_DIST0:
631 case VARYING_SLOT_CLIP_DIST1: {
632 unsigned base = io_sem.location == VARYING_SLOT_CLIP_DIST1 ? 4 : 0;
633 base += component;
634
635 /* valid clipdist component mask */
636 unsigned mask = (s->options->clip_cull_dist_mask >> base) & writemask;
637 u_foreach_bit(i, mask) {
638 add_clipdist_bit(b, nir_channel(b, store_val, i), base + i,
639 s->clipdist_neg_mask_var);
640 s->has_clipdist = true;
641 }
642 break;
643 }
644 case VARYING_SLOT_CLIP_VERTEX:
645 ac_nir_store_var_components(b, s->clip_vertex_var, store_val, component, writemask);
646 break;
647 default:
648 break;
649 }
650
651 /* Remove all output stores */
652 nir_instr_remove(instr);
653 return true;
654 }
655
656 static void
657 remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *s)
658 {
659 nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
660 nir_metadata_control_flow, s);
661
662 /* Remove dead code resulting from the deleted outputs. */
663 bool progress;
664 do {
665 progress = false;
666 NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
667 NIR_PASS(progress, culling_shader, nir_opt_dce);
668 NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
669 } while (progress);
670 }
671
672 static void
673 rewrite_uses_to_var(nir_builder *b, nir_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
674 {
675 if (old_def->parent_instr->type == nir_instr_type_load_const)
676 return;
677
678 b->cursor = nir_after_instr(old_def->parent_instr);
679 if (b->cursor.instr->type == nir_instr_type_phi)
680 b->cursor = nir_after_phis(old_def->parent_instr->block);
681
682 nir_def *pos_val_rep = nir_load_var(b, replacement_var);
683 nir_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
684
685 if (old_def->num_components > 1) {
686 /* old_def uses a swizzled vector component.
687 * There is no way to replace the uses of just a single vector component,
688 * so instead create a new vector and replace all uses of the old vector.
689 */
690 nir_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
691 for (unsigned j = 0; j < old_def->num_components; ++j)
692 old_def_elements[j] = nir_channel(b, old_def, j);
693 replacement = nir_vec(b, old_def_elements, old_def->num_components);
694 }
695
696 nir_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
697 }
698
699 static bool
700 remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
701 {
702 lower_ngg_nogs_state *s = (lower_ngg_nogs_state *) state;
703
704 if (instr->type != nir_instr_type_intrinsic)
705 return false;
706
707 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
708
709 /* These are not allowed in VS / TES */
710 assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
711 intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
712
713 /* We are only interested in output stores now */
714 if (intrin->intrinsic != nir_intrinsic_store_output)
715 return false;
716
717 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
718 if (io_sem.location != VARYING_SLOT_POS)
719 return false;
720
721 b->cursor = nir_before_instr(instr);
722
723 /* In case other outputs use what we calculated for pos,
724 * try to avoid calculating it again by rewriting the usages
725 * of the store components here.
726 */
727 nir_def *store_val = intrin->src[0].ssa;
728 unsigned store_pos_component = nir_intrinsic_component(intrin);
729
730 nir_instr_remove(instr);
731
732 if (store_val->parent_instr->type == nir_instr_type_alu) {
733 nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
734 if (nir_op_is_vec_or_mov(alu->op)) {
735 /* Output store uses a vector, we can easily rewrite uses of each vector element. */
736
737 unsigned num_vec_src = 0;
738 if (alu->op == nir_op_mov)
739 num_vec_src = 1;
740 else if (alu->op == nir_op_vec2)
741 num_vec_src = 2;
742 else if (alu->op == nir_op_vec3)
743 num_vec_src = 3;
744 else if (alu->op == nir_op_vec4)
745 num_vec_src = 4;
746 assert(num_vec_src);
747
748 /* Remember the current components whose uses we wish to replace.
749 * This is needed because rewriting one source can affect the others too.
750 */
751 nir_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
752 for (unsigned i = 0; i < num_vec_src; i++)
753 vec_comps[i] = alu->src[i].src.ssa;
754
755 for (unsigned i = 0; i < num_vec_src; i++)
756 rewrite_uses_to_var(b, vec_comps[i], s->position_value_var, store_pos_component + i);
757 } else {
758 rewrite_uses_to_var(b, store_val, s->position_value_var, store_pos_component);
759 }
760 } else {
761 rewrite_uses_to_var(b, store_val, s->position_value_var, store_pos_component);
762 }
763
764 return true;
765 }
766
767 static void
768 remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *s)
769 {
770 nir_shader_instructions_pass(shader, remove_extra_pos_output,
771 nir_metadata_control_flow,
772 s);
773 }
774
775 static bool
776 remove_compacted_arg(lower_ngg_nogs_state *s, nir_builder *b, unsigned idx)
777 {
778 nir_instr *store_instr = s->compact_arg_stores[idx];
779 if (!store_instr)
780 return false;
781
782 /* Simply remove the store. */
783 nir_instr_remove(store_instr);
784
785 /* Find the intrinsic that overwrites the shader arguments,
786 * and change its corresponding source.
787 * This will cause NIR's DCE to recognize the load and its phis as dead.
788 */
789 b->cursor = nir_before_instr(&s->overwrite_args->instr);
790 nir_def *undef_arg = nir_undef(b, 1, 32);
791 nir_def_rewrite_uses(s->overwrite_args->src[idx].ssa, undef_arg);
792
793 s->compact_arg_stores[idx] = NULL;
794 return true;
795 }
796
797 static bool
798 cleanup_culling_shader_after_dce(nir_shader *shader,
799 nir_function_impl *function_impl,
800 lower_ngg_nogs_state *s)
801 {
802 bool uses_vs_vertex_id = false;
803 bool uses_vs_instance_id = false;
804 bool uses_tes_u = false;
805 bool uses_tes_v = false;
806 bool uses_tes_rel_patch_id = false;
807 bool uses_tes_patch_id = false;
808
809 bool progress = false;
810 nir_builder b = nir_builder_create(function_impl);
811
812 nir_foreach_block_reverse_safe(block, function_impl) {
813 nir_foreach_instr_reverse_safe(instr, block) {
814 if (instr->type != nir_instr_type_intrinsic)
815 continue;
816
817 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
818
819 switch (intrin->intrinsic) {
820 case nir_intrinsic_sendmsg_amd:
821 goto cleanup_culling_shader_after_dce_done;
822 case nir_intrinsic_load_vertex_id:
823 case nir_intrinsic_load_vertex_id_zero_base:
824 uses_vs_vertex_id = true;
825 break;
826 case nir_intrinsic_load_instance_id:
827 uses_vs_instance_id = true;
828 break;
829 case nir_intrinsic_load_input: {
830 const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
831 if (s->options->instance_rate_inputs & BITFIELD_BIT(io_sem.location))
832 uses_vs_instance_id = true;
833 else
834 uses_vs_vertex_id = true;
835 break;
836 }
837 case nir_intrinsic_load_tess_coord:
838 uses_tes_u = uses_tes_v = true;
839 break;
840 case nir_intrinsic_load_tess_rel_patch_id_amd:
841 uses_tes_rel_patch_id = true;
842 break;
843 case nir_intrinsic_load_primitive_id:
844 if (shader->info.stage == MESA_SHADER_TESS_EVAL)
845 uses_tes_patch_id = true;
846 break;
847 default:
848 break;
849 }
850 }
851 }
852
853 cleanup_culling_shader_after_dce_done:
854
855 if (shader->info.stage == MESA_SHADER_VERTEX) {
856 if (!uses_vs_vertex_id)
857 progress |= remove_compacted_arg(s, &b, 0);
858 if (!uses_vs_instance_id)
859 progress |= remove_compacted_arg(s, &b, 1);
860 } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) {
861 if (!uses_tes_u)
862 progress |= remove_compacted_arg(s, &b, 0);
863 if (!uses_tes_v)
864 progress |= remove_compacted_arg(s, &b, 1);
865 if (!uses_tes_rel_patch_id)
866 progress |= remove_compacted_arg(s, &b, 3);
867 if (!uses_tes_patch_id)
868 progress |= remove_compacted_arg(s, &b, 2);
869 }
870
871 return progress;
872 }
873
874 /**
875 * Perform vertex compaction after culling.
876 *
877 * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
878 * 2. Surviving ES vertex invocations store their data to LDS
879 * 3. Emit GS_ALLOC_REQ
880 * 4. Repacked invocations load the vertex data from LDS
881 * 5. GS threads update their vertex indices
882 * 6. Optionally, do the same for primitives.
883 */
884 static void
885 compact_vertices_after_culling(nir_builder *b,
886 lower_ngg_nogs_state *s,
887 nir_variable **repacked_variables,
888 nir_variable **gs_vtxaddr_vars,
889 nir_def *invocation_index,
890 nir_def *es_vertex_lds_addr,
891 nir_def *es_exporter_tid,
892 nir_def *num_live_vertices_in_workgroup,
893 nir_def *gs_exporter_tid,
894 nir_def *num_live_primitives_in_workgroup,
895 unsigned pervertex_lds_bytes,
896 unsigned num_repacked_variables)
897 {
898 nir_variable *es_accepted_var = s->es_accepted_var;
899 nir_variable *gs_accepted_var = s->gs_accepted_var;
900 nir_variable *position_value_var = s->position_value_var;
901 nir_variable *prim_exp_arg_var = s->prim_exp_arg_var;
902
903 nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var));
904 {
905 nir_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
906
907 /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
908 nir_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid);
909
910 /* Store the current thread's position output to the exporter thread's LDS space */
911 nir_def *pos = nir_load_var(b, position_value_var);
912 nir_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x);
913
914 /* Store the current thread's repackable arguments to the exporter thread's LDS space */
915 for (unsigned i = 0; i < num_repacked_variables; ++i) {
916 nir_def *arg_val = nir_load_var(b, repacked_variables[i]);
917 nir_intrinsic_instr *store = nir_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i);
918
919 s->compact_arg_stores[i] = &store->instr;
920 }
921
922 /* The TES relative patch ID is stored as a byte, so it doesn't cost an extra dword. */
923 if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
924 nir_def *arg_val = nir_load_var(b, s->repacked_rel_patch_id);
925 nir_intrinsic_instr *store =
926 nir_store_shared(b, nir_u2u8(b, arg_val), exporter_addr,
927 .base = lds_es_tes_rel_patch_id);
928
929 s->compact_arg_stores[3] = &store->instr;
930 }
931 }
932 nir_pop_if(b, if_es_accepted);
933
934 /* TODO: Consider adding a shortcut exit.
935 * Waves that have no vertices and primitives left can s_endpgm right here.
936 */
937
938 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
939 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
940
941 nir_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
942 nir_if *if_packed_es_thread = nir_push_if(b, es_survived);
943 {
944 /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
945 nir_def *exported_pos = nir_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x);
946 nir_store_var(b, position_value_var, exported_pos, 0xfu);
947
948 /* Read the repacked arguments */
949 for (unsigned i = 0; i < num_repacked_variables; ++i) {
950 nir_def *arg_val = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i);
951 nir_store_var(b, repacked_variables[i], arg_val, 0x1u);
952 }
953
954 if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
955 nir_def *arg_val = nir_load_shared(b, 1, 8, es_vertex_lds_addr,
956 .base = lds_es_tes_rel_patch_id);
957 nir_store_var(b, s->repacked_rel_patch_id, nir_u2u32(b, arg_val), 0x1u);
958 }
959 }
960 nir_push_else(b, if_packed_es_thread);
961 {
962 nir_store_var(b, position_value_var, nir_undef(b, 4, 32), 0xfu);
963 for (unsigned i = 0; i < num_repacked_variables; ++i)
964 nir_store_var(b, repacked_variables[i], nir_undef(b, 1, 32), 0x1u);
965 }
966 nir_pop_if(b, if_packed_es_thread);
967
968 nir_def *gs_accepted = nir_load_var(b, gs_accepted_var);
969 nir_if *if_gs_accepted = nir_push_if(b, gs_accepted);
970 {
971 nir_def *exporter_vtx_indices[3] = {0};
972
973 /* Load the index of the ES threads that will export the current GS thread's vertices */
974 for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
975 nir_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
976 nir_def *exporter_vtx_idx = nir_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid);
977 exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
978 nir_store_var(b, s->gs_vtx_indices_vars[v], exporter_vtx_indices[v], 0x1);
979 }
980
981 nir_def *prim_exp_arg =
982 ac_nir_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive,
983 exporter_vtx_indices, NULL, s->options->gfx_level);
984 nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
985 }
986 nir_pop_if(b, if_gs_accepted);
987
988 nir_store_var(b, es_accepted_var, es_survived, 0x1u);
989
990 if (s->options->compact_primitives) {
991 /* For primitive compaction, re-use the same LDS space that we used for
992 * vertex compaction, so we need to wait until vertex threads are finished reading it.
993 * Considering we only need 1 DWORD per primitive, let's assume we always have enough space,
994 * since vertex compaction requires at least 5 DWORDs per vertex.
995 */
996 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
997 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
998
999 if_gs_accepted = nir_push_if(b, gs_accepted);
1000 {
1001 nir_def *exporter_addr = pervertex_lds_addr(b, gs_exporter_tid, pervertex_lds_bytes);
1002 nir_def *prim_exp_arg = nir_load_var(b, prim_exp_arg_var);
1003
1004 /* Store the primitive export argument into the address of the exporter thread. */
1005 nir_store_shared(b, prim_exp_arg, exporter_addr, .base = lds_es_pos_x);
1006 }
1007 nir_pop_if(b, if_gs_accepted);
1008
1009 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1010 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1011
1012 nir_def *gs_survived = nir_ilt(b, invocation_index, num_live_primitives_in_workgroup);
1013 nir_if *if_packed_gs_thread = nir_push_if(b, gs_survived);
1014 {
1015 /* Load the primitive export argument that the current thread will export. */
1016 nir_def *prim_exp_arg = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_pos_x);
1017
1018 nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
1019 }
1020 nir_push_else(b, if_packed_gs_thread);
1021 {
1022 nir_store_var(b, prim_exp_arg_var, nir_undef(b, 1, 32), 0x1u);
1023 }
1024 nir_pop_if(b, if_packed_gs_thread);
1025
1026 nir_store_var(b, gs_accepted_var, gs_survived, 0x1u);
1027 nir_store_var(b, s->gs_exported_var, gs_survived, 0x1u);
1028 }
1029 }
1030
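/**
 * Walks the SSA chain feeding an output store and marks every contributing
 * instruction with the given pass flag. VS input loads are also recorded in
 * inputs_needed_by_pos or inputs_needed_by_others, depending on which kind
 * of output they feed.
 */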
1031 static void
1032 analyze_shader_before_culling_walk(nir_def *ssa,
1033 uint8_t flag,
1034 lower_ngg_nogs_state *s)
1035 {
1036 nir_instr *instr = ssa->parent_instr;
1037 uint8_t old_pass_flags = instr->pass_flags;
1038 instr->pass_flags |= flag;
1039
1040 if (instr->pass_flags == old_pass_flags)
1041 return; /* Already visited. */
1042
1043 switch (instr->type) {
1044 case nir_instr_type_intrinsic: {
1045 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1046
1047 /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
1048 switch (intrin->intrinsic) {
1049 case nir_intrinsic_load_input: {
1050 nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
1051 uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
1052 if (instr->pass_flags & nggc_passflag_used_by_pos)
1053 s->inputs_needed_by_pos |= in_mask;
1054 else if (instr->pass_flags & nggc_passflag_used_by_other)
1055 s->inputs_needed_by_others |= in_mask;
1056 break;
1057 }
1058 default:
1059 break;
1060 }
1061
1062 break;
1063 }
1064 case nir_instr_type_alu: {
1065 nir_alu_instr *alu = nir_instr_as_alu(instr);
1066 unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
1067
1068 for (unsigned i = 0; i < num_srcs; ++i) {
1069 analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, s);
1070 }
1071
1072 break;
1073 }
1074 case nir_instr_type_tex: {
1075 nir_tex_instr *tex = nir_instr_as_tex(instr);
1076 unsigned num_srcs = tex->num_srcs;
1077
1078 for (unsigned i = 0; i < num_srcs; ++i) {
1079 analyze_shader_before_culling_walk(tex->src[i].src.ssa, flag, s);
1080 }
1081
1082 break;
1083 }
1084 case nir_instr_type_phi: {
1085 nir_phi_instr *phi = nir_instr_as_phi(instr);
1086 nir_foreach_phi_src_safe(phi_src, phi) {
1087 analyze_shader_before_culling_walk(phi_src->src.ssa, flag, s);
1088 }
1089
1090 break;
1091 }
1092 default:
1093 break;
1094 }
1095 }
1096
1097 static void
1098 analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *s)
1099 {
1100 /* We need divergence info for culling shaders. */
1101 nir_divergence_analysis(shader);
1102
1103 nir_foreach_function_impl(impl, shader) {
1104 nir_foreach_block(block, impl) {
1105 nir_foreach_instr(instr, block) {
1106 instr->pass_flags = 0;
1107
1108 if (instr->type != nir_instr_type_intrinsic)
1109 continue;
1110
1111 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1112 if (intrin->intrinsic != nir_intrinsic_store_output)
1113 continue;
1114
1115 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
1116 nir_def *store_val = intrin->src[0].ssa;
1117 uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
1118 analyze_shader_before_culling_walk(store_val, flag, s);
1119 }
1120 }
1121 }
1122 }
1123
1124 static nir_def *
1125 find_reusable_ssa_def(nir_instr *instr)
1126 {
1127 /* Find instructions whose SSA definitions are used by both
1128 * the top and bottom parts of the shader (before and after culling).
1129 * Only in this case, it makes sense for the bottom part
1130 * to try to reuse these from the top part.
1131 */
1132 if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
1133 return NULL;
1134
1135 switch (instr->type) {
1136 case nir_instr_type_alu: {
1137 nir_alu_instr *alu = nir_instr_as_alu(instr);
1138 if (alu->def.divergent)
1139 return NULL;
1140 /* Ignore uniform floats because they regress VGPR usage too much */
1141 if (nir_op_infos[alu->op].output_type & nir_type_float)
1142 return NULL;
1143 return &alu->def;
1144 }
1145 case nir_instr_type_intrinsic: {
1146 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1147 if (!nir_intrinsic_can_reorder(intrin) ||
1148 !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
1149 intrin->def.divergent)
1150 return NULL;
1151 return &intrin->def;
1152 }
1153 case nir_instr_type_phi: {
1154 nir_phi_instr *phi = nir_instr_as_phi(instr);
1155 if (phi->def.divergent)
1156 return NULL;
1157 return &phi->def;
1158 }
1159 default:
1160 return NULL;
1161 }
1162 }
1163
1164 static const struct glsl_type *
1165 glsl_uint_type_for_ssa(nir_def *ssa)
1166 {
1167 enum glsl_base_type base_type = GLSL_TYPE_UINT;
1168 switch (ssa->bit_size) {
1169 case 8: base_type = GLSL_TYPE_UINT8; break;
1170 case 16: base_type = GLSL_TYPE_UINT16; break;
1171 case 32: base_type = GLSL_TYPE_UINT; break;
1172 case 64: base_type = GLSL_TYPE_UINT64; break;
1173 default: return NULL;
1174 }
1175
1176 return ssa->num_components == 1
1177 ? glsl_scalar_type(base_type)
1178 : glsl_vector_type(base_type, ssa->num_components);
1179 }
1180
1181 /**
1182 * Save the reusable SSA definitions to variables so that the
1183 * bottom shader part can reuse them from the top part.
1184 *
1185 * 1. We create a new function temporary variable for each reusable,
1186 * and insert a store+load.
1187 * 2. The shader is cloned (the top part is created), then the
1188 * control flow is reinserted (for the bottom part).
1189 * 3. For reusables, we delete the variable stores from the
1190 * bottom part. This makes the bottom part use the variables from
1191 * the top part and lets DCE remove the redundant instructions.
1192 */
1193 static void
1194 save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *s)
1195 {
1196 ASSERTED int vec_ok = u_vector_init(&s->reusable_nondeferred_variables, 4, sizeof(reusable_nondeferred_variable));
1197 assert(vec_ok);
1198
1199 /* Upper limit on reusable uniforms in order to reduce SGPR spilling. */
1200 unsigned remaining_reusable_uniforms = 48;
1201
1202 nir_block *block = nir_start_block(b->impl);
1203 while (block) {
1204 /* Process the instructions in the current block. */
1205 nir_foreach_instr_safe(instr, block) {
1206 /* Determine if we can reuse the current SSA value.
1207 * When vertex compaction is used, it is possible that the same shader invocation
1208 * processes a different vertex in the top and bottom part of the shader.
1209 * Therefore, we only reuse uniform values.
1210 */
1211 nir_def *ssa = find_reusable_ssa_def(instr);
1212 if (!ssa)
1213 continue;
1214
1215 /* Determine a suitable type for the SSA value. */
1216 const struct glsl_type *t = glsl_uint_type_for_ssa(ssa);
1217 if (!t)
1218 continue;
1219
1220 if (!ssa->divergent) {
1221 if (remaining_reusable_uniforms < ssa->num_components)
1222 continue;
1223
1224 remaining_reusable_uniforms -= ssa->num_components;
1225 }
1226
1227 reusable_nondeferred_variable *saved = (reusable_nondeferred_variable *) u_vector_add(&s->reusable_nondeferred_variables);
1228 assert(saved);
1229
1230 /* Create a new NIR variable where we store the reusable value.
1231 * Then, we reload the variable and replace the uses of the value
1232 * with the reloaded variable.
1233 */
1234 saved->var = nir_local_variable_create(b->impl, t, NULL);
1235 saved->ssa = ssa;
1236
1237 b->cursor = instr->type == nir_instr_type_phi
1238 ? nir_after_instr_and_phis(instr)
1239 : nir_after_instr(instr);
1240 nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
1241 nir_def *reloaded = nir_load_var(b, saved->var);
1242 nir_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
1243 }
1244
1245 /* Look at the next CF node. */
1246 nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
1247 if (next_cf_node) {
1248 /* It makes no sense to try to reuse things from within loops. */
1249 bool next_is_loop = next_cf_node->type == nir_cf_node_loop;
1250
1251 /* Don't reuse if we're in divergent control flow.
1252 *
1253 * Thanks to vertex repacking, the same shader invocation may process a different vertex
1254 * in the top and bottom part, and it's even possible that this different vertex was initially
1255 * processed in a different wave. So the two parts may take a different divergent code path.
1256 * Therefore, these variables in divergent control flow may stay undefined.
1257 *
1258 * Note that this problem doesn't exist if vertices are not repacked or if the
1259 * workgroup only has a single wave.
1260 */
1261 bool next_is_divergent_if =
1262 next_cf_node->type == nir_cf_node_if &&
1263 nir_src_is_divergent(&nir_cf_node_as_if(next_cf_node)->condition);
1264
1265 if (next_is_loop || next_is_divergent_if) {
1266 block = nir_cf_node_cf_tree_next(next_cf_node);
1267 continue;
1268 }
1269 }
1270
1271 /* Go to the next block. */
1272 block = nir_block_cf_tree_next(block);
1273 }
1274 }
1275
1276 /**
1277 * Reuses suitable variables from the top part of the shader,
1278 * by deleting their stores from the bottom part.
1279 */
1280 static void
1281 apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *s)
1282 {
1283 if (!u_vector_length(&s->reusable_nondeferred_variables)) {
1284 u_vector_finish(&s->reusable_nondeferred_variables);
1285 return;
1286 }
1287
1288 nir_foreach_block_reverse_safe(block, b->impl) {
1289 nir_foreach_instr_reverse_safe(instr, block) {
1290 if (instr->type != nir_instr_type_intrinsic)
1291 continue;
1292 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1293
1294 /* When we find this intrinsic, it means
1295 * we have reached the top part and must stop.
1296 */
1297 if (intrin->intrinsic == nir_intrinsic_sendmsg_amd)
1298 goto done;
1299
1300 if (intrin->intrinsic != nir_intrinsic_store_deref)
1301 continue;
1302 nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1303 if (deref->deref_type != nir_deref_type_var)
1304 continue;
1305
1306 reusable_nondeferred_variable *saved;
1307 u_vector_foreach(saved, &s->reusable_nondeferred_variables) {
1308 if (saved->var == deref->var) {
1309 nir_instr_remove(instr);
1310 }
1311 }
1312 }
1313 }
1314
1315 done:
1316 u_vector_finish(&s->reusable_nondeferred_variables);
1317 }
1318
1319 static void
1320 cull_primitive_accepted(nir_builder *b, void *state)
1321 {
1322 lower_ngg_nogs_state *s = (lower_ngg_nogs_state *)state;
1323
1324 nir_store_var(b, s->gs_accepted_var, nir_imm_true(b), 0x1u);
1325
1326 /* Store the accepted state to LDS for ES threads */
1327 for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx)
1328 nir_store_shared(b, nir_imm_intN_t(b, 1, 8), s->vtx_addr[vtx], .base = lds_es_vertex_accepted);
1329 }
1330
1331 static void
1332 clipdist_culling_es_part(nir_builder *b, lower_ngg_nogs_state *s,
1333 nir_def *es_vertex_lds_addr)
1334 {
1335 /* No gl_ClipDistance is written, but user-defined clip planes are enabled. */
1336 if (s->options->user_clip_plane_enable_mask && !s->has_clipdist) {
1337 /* use gl_ClipVertex if defined */
1338 nir_variable *clip_vertex_var =
1339 b->shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CLIP_VERTEX) ?
1340 s->clip_vertex_var : s->position_value_var;
1341 nir_def *clip_vertex = nir_load_var(b, clip_vertex_var);
1342
1343 /* clip against user defined clip planes */
1344 for (unsigned i = 0; i < 8; i++) {
1345 if (!(s->options->user_clip_plane_enable_mask & BITFIELD_BIT(i)))
1346 continue;
1347
1348 nir_def *plane = nir_load_user_clip_plane(b, .ucp_id = i);
1349 nir_def *dist = nir_fdot(b, clip_vertex, plane);
1350 add_clipdist_bit(b, dist, i, s->clipdist_neg_mask_var);
1351 }
1352
1353 s->has_clipdist = true;
1354 }
1355
1356 /* Store clipdist_neg_mask to LDS for culling later in the GS threads. */
1357 if (s->has_clipdist) {
1358 nir_def *mask = nir_load_var(b, s->clipdist_neg_mask_var);
1359 nir_store_shared(b, nir_u2u8(b, mask), es_vertex_lds_addr,
1360 .base = lds_es_clipdist_neg_mask);
1361 }
1362 }
1363
1364 static unsigned
1365 ngg_nogs_get_culling_pervertex_lds_size(gl_shader_stage stage,
1366 bool uses_instance_id,
1367 bool uses_primitive_id,
1368 unsigned *num_repacked_variables)
1369 {
1370 /* Culling shaders must repack some variables because
1371 * the same shader invocation may process different vertices
1372 * before and after the culling algorithm.
1373 */
1374
1375 unsigned num_repacked;
1376 if (stage == MESA_SHADER_VERTEX) {
1377 /* Vertex shaders repack:
1378 * - Vertex ID
1379 * - Instance ID (only if used)
1380 */
1381 num_repacked = uses_instance_id ? 2 : 1;
1382 } else {
1383 /* Tess eval shaders repack:
1384 * - U, V coordinates
1385 * - primitive ID (aka the patch ID, only if used)
1386 * - relative patch ID (not included here because it doesn't need a full dword)
1387 */
1388 assert(stage == MESA_SHADER_TESS_EVAL);
1389 num_repacked = uses_primitive_id ? 3 : 2;
1390 }
1391
1392 if (num_repacked_variables)
1393 *num_repacked_variables = num_repacked;
1394
1395 /* Make the stride an odd number of dwords to reduce LDS bank conflicts. */
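/* For example, a VS that doesn't use the instance ID repacks 1 dword,
 * giving (20 + 4) | 4 = 28 bytes per vertex, while a TES that uses the
 * primitive ID repacks 3 dwords, giving (20 + 12) | 4 = 36 bytes.
 */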
1396 return (lds_es_arg_0 + num_repacked * 4u) | 4u;
1397 }
1398
1399 static void
1400 add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *s)
1401 {
1402 bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
1403 bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
1404
1405 unsigned num_repacked_variables;
1406 unsigned pervertex_lds_bytes =
1407 ngg_nogs_get_culling_pervertex_lds_size(b->shader->info.stage,
1408 uses_instance_id,
1409 uses_tess_primitive_id,
1410 &num_repacked_variables);
1411
1412 nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
1413
1414 /* Create some helper variables. */
1415 nir_variable *gs_vtxaddr_vars[3] = {
1416 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
1417 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
1418 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
1419 };
1420
1421 nir_variable *repacked_variables[3] = {
1422 nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_0"),
1423 nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_1"),
1424 nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_2"),
1425 };
1426
1427 /* Relative patch ID is a special case because it doesn't need an extra dword, repack separately. */
1428 s->repacked_rel_patch_id = nir_local_variable_create(impl, glsl_uint_type(), "repacked_rel_patch_id");
1429
1430 if (s->options->clip_cull_dist_mask ||
1431 s->options->user_clip_plane_enable_mask) {
1432 s->clip_vertex_var =
1433 nir_local_variable_create(impl, glsl_vec4_type(), "clip_vertex");
1434 s->clipdist_neg_mask_var =
1435 nir_local_variable_create(impl, glsl_uint_type(), "clipdist_neg_mask");
1436
1437 /* init mask to 0 */
1438 nir_store_var(b, s->clipdist_neg_mask_var, nir_imm_int(b, 0), 1);
1439 }
1440
1441 /* Top part of the culling shader (aka. position shader part)
1442 *
1443 * We clone the full ES shader and emit it here, but we only really care
1444 * about its position output, so we delete every other output from this part.
1445 * The position output is stored into a temporary variable, and reloaded later.
1446 */
1447
1448 nir_def *es_thread = has_input_vertex(b);
1449 nir_if *if_es_thread = nir_push_if(b, es_thread);
1450 {
1451 /* Initialize the position output variable in case not all VS/TES invocations store the output.
1452 * The spec doesn't require any particular value, but we use (0, 0, 0, 1) because some games rely on that.
1453 */
1454 nir_store_var(b, s->position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
1455
1456 /* Now reinsert a clone of the shader code */
1457 struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
1458 nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
1459 _mesa_hash_table_destroy(remap_table, NULL);
1460 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1461
1462 /* Remember the current thread's shader arguments */
1463 if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1464 nir_store_var(b, repacked_variables[0], nir_load_vertex_id_zero_base(b), 0x1u);
1465 if (uses_instance_id)
1466 nir_store_var(b, repacked_variables[1], nir_load_instance_id(b), 0x1u);
1467 } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1468 nir_store_var(b, s->repacked_rel_patch_id, nir_load_tess_rel_patch_id_amd(b), 0x1u);
1469 nir_def *tess_coord = nir_load_tess_coord(b);
1470 nir_store_var(b, repacked_variables[0], nir_channel(b, tess_coord, 0), 0x1u);
1471 nir_store_var(b, repacked_variables[1], nir_channel(b, tess_coord, 1), 0x1u);
1472 if (uses_tess_primitive_id)
1473 nir_store_var(b, repacked_variables[2], nir_load_primitive_id(b), 0x1u);
1474 } else {
1475 unreachable("Should be VS or TES.");
1476 }
1477 }
1478 nir_pop_if(b, if_es_thread);
1479
1480 nir_store_var(b, s->es_accepted_var, es_thread, 0x1u);
1481 nir_def *gs_thread = has_input_primitive(b);
1482 nir_store_var(b, s->gs_accepted_var, gs_thread, 0x1u);
1483
1484 /* Remove all non-position outputs, and put the position output into the variable. */
1485 nir_metadata_preserve(impl, nir_metadata_none);
1486 remove_culling_shader_outputs(b->shader, s);
1487 b->cursor = nir_after_impl(impl);
1488
1489 nir_def *lds_scratch_base = nir_load_lds_ngg_scratch_base_amd(b);
1490
1491 /* Run culling algorithms if culling is enabled.
1492 *
1493 * NGG culling can be enabled or disabled at runtime.
1494 * This is determined by an SGPR shader argument which is accessed
1495 * by the following NIR intrinsic.
1496 */
1497
1498 nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
1499 {
1500 nir_def *invocation_index = nir_load_local_invocation_index(b);
1501 nir_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
1502
1503 /* ES invocations store their vertex data to LDS for GS threads to read. */
1504 if_es_thread = nir_push_if(b, es_thread);
1505 if_es_thread->control = nir_selection_control_divergent_always_taken;
1506 {
1507 /* Store position components that are relevant to culling in LDS */
1508 nir_def *pre_cull_pos = nir_load_var(b, s->position_value_var);
1509 nir_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
1510 nir_store_shared(b, pre_cull_w, es_vertex_lds_addr, .base = lds_es_pos_w);
1511 nir_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
1512 nir_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
1513 nir_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .base = lds_es_pos_x);
1514
1515 /* Clear out the ES accepted flag in LDS */
1516 nir_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .align_mul = 4, .base = lds_es_vertex_accepted);
1517
1518 /* For clipdist culling */
1519 clipdist_culling_es_part(b, s, es_vertex_lds_addr);
1520 }
1521 nir_pop_if(b, if_es_thread);
1522
1523 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1524 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1525
1526 nir_store_var(b, s->gs_accepted_var, nir_imm_false(b), 0x1u);
1527 nir_store_var(b, s->prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
1528
1529 /* GS invocations load the vertex data and perform the culling. */
1530 nir_if *if_gs_thread = nir_push_if(b, gs_thread);
1531 {
1532 /* Load vertex indices from input VGPRs */
1533 nir_def *vtx_idx[3] = {0};
1534 for (unsigned vertex = 0; vertex < s->options->num_vertices_per_primitive;
1535 ++vertex)
1536 vtx_idx[vertex] = nir_load_var(b, s->gs_vtx_indices_vars[vertex]);
1537
1538 nir_def *pos[3][4] = {0};
1539
1540 /* Load W positions of vertices first because the culling code will use these first */
1541 for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1542 s->vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1543 pos[vtx][3] = nir_load_shared(b, 1, 32, s->vtx_addr[vtx], .base = lds_es_pos_w);
1544 nir_store_var(b, gs_vtxaddr_vars[vtx], s->vtx_addr[vtx], 0x1u);
1545 }
1546
1547 /* Load the X/W, Y/W positions of vertices */
1548 for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1549 nir_def *xy = nir_load_shared(b, 2, 32, s->vtx_addr[vtx], .base = lds_es_pos_x);
1550 pos[vtx][0] = nir_channel(b, xy, 0);
1551 pos[vtx][1] = nir_channel(b, xy, 1);
1552 }
1553
1554 nir_def *accepted_by_clipdist;
1555 if (s->has_clipdist) {
1556 nir_def *clipdist_neg_mask = nir_imm_intN_t(b, 0xff, 8);
1557 for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1558 nir_def *mask =
1559 nir_load_shared(b, 1, 8, s->vtx_addr[vtx],
1560 .base = lds_es_clipdist_neg_mask);
1561 clipdist_neg_mask = nir_iand(b, clipdist_neg_mask, mask);
1562 }
1563 /* The primitive is culled when all of its vertices have a negative clip distance for the same plane. */
1564 accepted_by_clipdist = nir_ieq_imm(b, clipdist_neg_mask, 0);
1565 } else {
1566 accepted_by_clipdist = nir_imm_true(b);
1567 }
1568
1569 /* See if the current primitive is accepted */
1570 ac_nir_cull_primitive(b, accepted_by_clipdist, pos,
1571 s->options->num_vertices_per_primitive,
1572 cull_primitive_accepted, s);
1573 }
1574 nir_pop_if(b, if_gs_thread);
1575
1576 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1577 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1578
1579 nir_store_var(b, s->es_accepted_var, nir_imm_false(b), 0x1u);
1580
1581 /* ES invocations load their accepted flag from LDS. */
1582 if_es_thread = nir_push_if(b, es_thread);
1583 if_es_thread->control = nir_selection_control_divergent_always_taken;
1584 {
1585 nir_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1586 nir_def *accepted_bool = nir_ine_imm(b, nir_u2u32(b, accepted), 0);
1587 nir_store_var(b, s->es_accepted_var, accepted_bool, 0x1u);
1588 }
1589 nir_pop_if(b, if_es_thread);
1590
1591 nir_def *es_accepted = nir_load_var(b, s->es_accepted_var);
1592 nir_def *gs_accepted = nir_load_var(b, s->gs_accepted_var);
1593
1594 /* Repack the vertices (always) and primitives (optional) that survived the culling. */
1595 nir_def *accepted[] = { es_accepted, gs_accepted };
1596 wg_repack_result rep[2] = {0};
1597 const unsigned num_rep = s->options->compact_primitives ? 2 : 1;
1598 repack_invocations_in_workgroup(b, accepted, rep, num_rep, lds_scratch_base,
1599 s->max_num_waves, s->options->wave_size);
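/* rep[0] holds the repacked surviving ES vertices; rep[1] (only when compact_primitives
* is enabled) holds the repacked GS primitives.
*/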
1600 nir_def *num_live_vertices_in_workgroup = rep[0].num_repacked_invocations;
1601 nir_def *es_exporter_tid = rep[0].repacked_invocation_index;
1602 nir_def *num_exported_prims = NULL;
1603 nir_def *gs_exporter_tid = NULL;
1604
1605 if (s->options->compact_primitives) {
1606 num_exported_prims = rep[1].num_repacked_invocations;
1607 gs_exporter_tid = rep[1].repacked_invocation_index;
1608 } else {
1609 /* If all vertices are culled, set primitive count to 0 as well. */
1610 nir_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
1611 num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), nir_load_workgroup_num_input_primitives_amd(b));
1612 nir_store_var(b, s->gs_exported_var, nir_iand(b, nir_inot(b, fully_culled), has_input_primitive(b)), 0x1u);
1613 }
1614
1615 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
1616 {
1617 /* Tell the final vertex and primitive count to the HW. */
1618 if (s->options->gfx_level == GFX10) {
1619 alloc_vertices_and_primitives_gfx10_workaround(
1620 b, num_live_vertices_in_workgroup, num_exported_prims);
1621 } else {
1622 alloc_vertices_and_primitives(
1623 b, num_live_vertices_in_workgroup, num_exported_prims);
1624 }
1625 }
1626 nir_pop_if(b, if_wave_0);
1627
1628 /* Vertex compaction. */
1629 compact_vertices_after_culling(b, s,
1630 repacked_variables, gs_vtxaddr_vars,
1631 invocation_index, es_vertex_lds_addr,
1632 es_exporter_tid, num_live_vertices_in_workgroup,
1633 gs_exporter_tid, num_exported_prims,
1634 pervertex_lds_bytes, num_repacked_variables);
1635 }
1636 nir_push_else(b, if_cull_en);
1637 {
1638 /* When culling is disabled, we do the same as we would without culling. */
1639 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
1640 {
1641 nir_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
1642 nir_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
1643 alloc_vertices_and_primitives(b, vtx_cnt, prim_cnt);
1644 }
1645 nir_pop_if(b, if_wave_0);
1646 nir_store_var(b, s->prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, s), 0x1u);
1647 }
1648 nir_pop_if(b, if_cull_en);
1649
1650 /* Update shader arguments.
1651 *
1652 * The registers which hold information about the subgroup's
1653 * vertices and primitives are updated here, so the rest of the shader
1654 * doesn't need to worry about the culling.
1655 *
1656 * These "overwrite" intrinsics must be in top-level control flow,
1657 * otherwise they can mess up the backend (eg. ACO's SSA).
1658 *
1659 * TODO:
1660 * A cleaner solution would be to simply replace all usages of these args
1661 * with the load of the variables.
1662 * However, this wouldn't work right now because the backend uses the arguments
1663 * for purposes not expressed in NIR, eg. VS input loads, etc.
1664 * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
1665 */
1666
1667 if (b->shader->info.stage == MESA_SHADER_VERTEX)
1668 s->overwrite_args =
1669 nir_overwrite_vs_arguments_amd(b,
1670 nir_load_var(b, repacked_variables[0]), nir_load_var(b, repacked_variables[1]));
1671 else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1672 s->overwrite_args =
1673 nir_overwrite_tes_arguments_amd(b,
1674 nir_load_var(b, repacked_variables[0]), nir_load_var(b, repacked_variables[1]),
1675 nir_load_var(b, repacked_variables[2]), nir_load_var(b, s->repacked_rel_patch_id));
1676 else
1677 unreachable("Should be VS or TES.");
1678 }
1679
1680 static void
1681 ngg_nogs_store_edgeflag_to_lds(nir_builder *b, lower_ngg_nogs_state *s)
1682 {
1683 if (!s->out.outputs[VARYING_SLOT_EDGE][0])
1684 return;
1685
1686 /* Clamp the user edge flag to 1 for later bit operations. */
1687 nir_def *edgeflag = s->out.outputs[VARYING_SLOT_EDGE][0];
1688 edgeflag = nir_umin(b, edgeflag, nir_imm_int(b, 1));
1689
1690 /* The user edge flag is stored at the beginning of the vertex's LDS space if streamout is not enabled. */
1691 unsigned offset = 0;
1692 if (s->streamout_enabled) {
1693 unsigned packed_location =
1694 util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(VARYING_SLOT_EDGE));
1695 offset = packed_location * 16;
1696 }
1697
1698 nir_def *tid = nir_load_local_invocation_index(b);
1699 nir_def *addr = pervertex_lds_addr(b, tid, s->pervertex_lds_bytes);
1700
1701 nir_store_shared(b, edgeflag, addr, .base = offset);
1702 }
1703
1704 static void
1705 ngg_nogs_store_xfb_outputs_to_lds(nir_builder *b, lower_ngg_nogs_state *s)
1706 {
1707 nir_xfb_info *info = ac_nir_get_sorted_xfb_info(b->shader);
1708
1709 uint64_t xfb_outputs = 0;
1710 unsigned xfb_outputs_16bit = 0;
1711 uint8_t xfb_mask[VARYING_SLOT_MAX] = {0};
1712 uint8_t xfb_mask_16bit_lo[16] = {0};
1713 uint8_t xfb_mask_16bit_hi[16] = {0};
1714
1715 /* Get XFB output mask for each slot. */
1716 for (int i = 0; i < info->output_count; i++) {
1717 nir_xfb_output_info *out = info->outputs + i;
1718
1719 if (out->location < VARYING_SLOT_VAR0_16BIT) {
1720 xfb_outputs |= BITFIELD64_BIT(out->location);
1721 xfb_mask[out->location] |= out->component_mask;
1722 } else {
1723 unsigned index = out->location - VARYING_SLOT_VAR0_16BIT;
1724 xfb_outputs_16bit |= BITFIELD_BIT(index);
1725
1726 if (out->high_16bits)
1727 xfb_mask_16bit_hi[index] |= out->component_mask;
1728 else
1729 xfb_mask_16bit_lo[index] |= out->component_mask;
1730 }
1731 }
1732
1733 nir_def *tid = nir_load_local_invocation_index(b);
1734 nir_def *addr = pervertex_lds_addr(b, tid, s->pervertex_lds_bytes);
1735
1736 u_foreach_bit64(slot, xfb_outputs) {
1737 uint64_t outputs_written = b->shader->info.outputs_written;
1738 if (s->skip_primitive_id)
1739 outputs_written &= ~VARYING_BIT_PRIMITIVE_ID;
1740 unsigned packed_location =
1741 util_bitcount64(outputs_written & BITFIELD64_MASK(slot));
1742
1743 unsigned mask = xfb_mask[slot];
1744
1745 /* Clear unused components. */
1746 for (unsigned i = 0; i < 4; i++) {
1747 if (!s->out.outputs[slot][i])
1748 mask &= ~BITFIELD_BIT(i);
1749 }
1750
1751 while (mask) {
1752 int start, count;
1753 u_bit_scan_consecutive_range(&mask, &start, &count);
1754 /* Outputs here are guaranteed to be 32-bit.
1755 *
1756 * 64-bit outputs have already been lowered to two 32-bit outputs. As for 16-bit outputs:
1757 * Vulkan does not allow streamout outputs smaller than 32 bits, and
1758 * OpenGL puts 16-bit outputs in VARYING_SLOT_VAR0_16BIT.
1759 */
1760 nir_def *store_val = nir_vec(b, &s->out.outputs[slot][start], (unsigned)count);
1761 nir_store_shared(b, store_val, addr, .base = packed_location * 16 + start * 4);
1762 }
1763 }
1764
1765 unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
1766 u_foreach_bit64(slot, xfb_outputs_16bit) {
1767 unsigned packed_location = num_32bit_outputs +
1768 util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));
1769
1770 unsigned mask_lo = xfb_mask_16bit_lo[slot];
1771 unsigned mask_hi = xfb_mask_16bit_hi[slot];
1772
1773 /* Clear unused components. */
1774 for (unsigned i = 0; i < 4; i++) {
1775 if (!s->out.outputs_16bit_lo[slot][i])
1776 mask_lo &= ~BITFIELD_BIT(i);
1777 if (!s->out.outputs_16bit_hi[slot][i])
1778 mask_hi &= ~BITFIELD_BIT(i);
1779 }
1780
1781 nir_def **outputs_lo = s->out.outputs_16bit_lo[slot];
1782 nir_def **outputs_hi = s->out.outputs_16bit_hi[slot];
1783 nir_def *undef = nir_undef(b, 1, 16);
1784
1785 unsigned mask = mask_lo | mask_hi;
1786 while (mask) {
1787 int start, count;
1788 u_bit_scan_consecutive_range(&mask, &start, &count);
1789
1790 nir_def *values[4] = {0};
1791 for (int c = start; c < start + count; ++c) {
1792 nir_def *lo = mask_lo & BITFIELD_BIT(c) ? outputs_lo[c] : undef;
1793 nir_def *hi = mask_hi & BITFIELD_BIT(c) ? outputs_hi[c] : undef;
1794
1795 /* extend 8/16 bit to 32 bit, 64 bit has been lowered */
1796 values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
1797 }
1798
1799 nir_def *store_val = nir_vec(b, values, (unsigned)count);
1800 nir_store_shared(b, store_val, addr, .base = packed_location * 16 + start * 4);
1801 }
1802 }
1803 }
1804
1805 static nir_def *
1806 write_values_to_lanes(nir_builder *b, nir_def **values, unsigned lane_mask)
1807 {
1808 nir_def *lanes = nir_imm_int(b, 0);
1809
1810 u_foreach_bit(i, lane_mask) {
1811 lanes = nir_write_invocation_amd(b, lanes, values[i], nir_imm_int(b, i));
1812 }
1813 return lanes;
1814 }
1815
1816 static nir_def *
1817 read_values_from_4_lanes(nir_builder *b, nir_def *values, unsigned lane_mask)
1818 {
1819 nir_def *undef = nir_undef(b, 1, 32);
1820 nir_def *per_lane[4] = {undef, undef, undef, undef};
1821
1822 u_foreach_bit(i, lane_mask) {
1823 per_lane[i] = nir_read_invocation(b, values, nir_imm_int(b, i));
1824 }
1825 return nir_vec(b, per_lane, 4);
1826 }
1827
1828 static void
1829 ngg_build_streamout_buffer_info(nir_builder *b,
1830 nir_xfb_info *info,
1831 enum amd_gfx_level gfx_level,
1832 bool has_xfb_prim_query,
1833 bool use_gfx12_xfb_intrinsic,
1834 nir_def *scratch_base,
1835 nir_def *tid_in_tg,
1836 nir_def *gen_prim[4],
1837 nir_def *so_buffer_ret[4],
1838 nir_def *buffer_offsets_ret[4],
1839 nir_def *emit_prim_ret[4])
1840 {
1841 nir_def *prim_stride[4] = {0};
1842 nir_def *undef = nir_undef(b, 1, 32);
1843
1844 /* For radeonsi, which passes this value as a shader argument for VS. Streamout needs an
1845 * accurate number of vertices per primitive to write the correct amount of data to the buffer.
1846 */
1847 nir_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
1848 for (unsigned buffer = 0; buffer < 4; buffer++) {
1849 if (!(info->buffers_written & BITFIELD_BIT(buffer)))
1850 continue;
1851
1852 assert(info->buffers[buffer].stride);
1853
1854 prim_stride[buffer] =
1855 nir_imul_imm(b, num_vert_per_prim, info->buffers[buffer].stride);
1856 so_buffer_ret[buffer] = nir_load_streamout_buffer_amd(b, .base = buffer);
1857 }
1858
1859 nir_if *if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
1860 {
1861 nir_def *any_buffer_valid = nir_imm_false(b);
1862 nir_def *workgroup_buffer_sizes[4];
1863
1864 for (unsigned buffer = 0; buffer < 4; buffer++) {
1865 if (info->buffers_written & BITFIELD_BIT(buffer)) {
1866 nir_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
1867 /* In radeonsi, we may not know at compile time whether a feedback buffer has been
1868 * bound, so we have to check the buffer size at runtime and disable the GDS update
1869 * for unbound buffers. Otherwise, a previous draw that was compiled with streamout
1870 * but didn't bind a feedback buffer would miss the GDS update, which would affect
1871 * the current draw's streamout.
1872 */
1873 nir_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
1874 nir_def *inc_buffer_size =
1875 nir_imul(b, gen_prim[info->buffer_to_stream[buffer]], prim_stride[buffer]);
1876 workgroup_buffer_sizes[buffer] =
1877 nir_bcsel(b, buffer_valid, inc_buffer_size, nir_imm_int(b, 0));
1878 any_buffer_valid = nir_ior(b, any_buffer_valid, buffer_valid);
1879 } else
1880 workgroup_buffer_sizes[buffer] = undef;
1881 }
1882
1883 nir_def *buffer_offsets = NULL, *xfb_state_address = NULL, *xfb_voffset = NULL;
1884
1885 /* Get the current global offset of each buffer and increase it by the
1886 * workgroup's buffer size. This is an ordered operation sorted by
1887 * ordered_id; each buffer's info is in a channel of a vec4.
1888 */
1889 if (gfx_level >= GFX12) {
1890 nir_pop_if(b, if_invocation_0);
1891
1892 for (unsigned buffer = 0; buffer < 4; buffer++)
1893 workgroup_buffer_sizes[buffer] = nir_if_phi(b, workgroup_buffer_sizes[buffer], undef);
1894 any_buffer_valid = nir_if_phi(b, any_buffer_valid, nir_undef(b, 1, 1));
1895
1896 /* These must be set after nir_pop_if and phis. */
1897 xfb_state_address = nir_load_xfb_state_address_gfx12_amd(b);
1898 xfb_voffset = nir_imul_imm(b, tid_in_tg, 8);
1899
1900 nir_if *if_4lanes = nir_push_if(b, nir_iand(b, any_buffer_valid, nir_ult_imm(b, tid_in_tg, 4)));
1901 {
1902 /* Move workgroup buffer sizes from SGPRs to the first 4 lanes. */
1903 nir_def *workgroup_buffer_size_per_lane =
1904 write_values_to_lanes(b, workgroup_buffer_sizes, info->buffers_written);
1905 nir_def *ordered_id = nir_load_ordered_id_amd(b);
1906
1907 /* The atomic value for the 4 lanes is:
1908 * lane 0: uvec2(ordered_id, workgroup_buffer_size0)
1909 * lane 1: uvec2(ordered_id, workgroup_buffer_size1)
1910 * lane 2: uvec2(ordered_id, workgroup_buffer_size2)
1911 * lane 3: uvec2(ordered_id, workgroup_buffer_size3)
1912 */
1913 nir_def *atomic_src = nir_pack_64_2x32_split(b, ordered_id,
1914 workgroup_buffer_size_per_lane);
1915
1916 /* The memory layout of the xfb state is:
1917 * struct {
1918 * unsigned ordered_id;
1919 * unsigned dwords_written0;
1920 * unsigned ordered_id;
1921 * unsigned dwords_written1;
1922 * unsigned ordered_id;
1923 * unsigned dwords_written2;
1924 * unsigned ordered_id;
1925 * unsigned dwords_written3;
1926 * };
1927 *
1928 * Notes:
1929 * - global_atomic_ordered_add_b64 is semantically a 64-bit atomic, requiring 8-byte
1930 * address alignment, even though it operates on a pair of 32-bit values.
1931 * - The whole structure is updated at once by issuing the atomic from 4 lanes
1932 * with 8-byte address increments.
1933 * - The whole structure should be entirely within one 64B block of memory
1934 * for performance. (the address bits above 64B should not differ between lanes)
1935 */
1936 nir_def *buffer_offset_per_lane;
1937
1938 /* The gfx12 intrinsic inserts hand-written assembly producing better code than current
1939 * LLVM.
1940 */
1941 if (use_gfx12_xfb_intrinsic) {
1942 buffer_offset_per_lane =
1943 nir_ordered_add_loop_gfx12_amd(b, xfb_state_address, xfb_voffset, ordered_id,
1944 atomic_src);
1945
1946 /* Move the buffer offsets from the 4 lanes to lane 0. */
1947 buffer_offsets = read_values_from_4_lanes(b, buffer_offset_per_lane, info->buffers_written);
1948 } else {
1949 /* The NIR version of the above using nir_atomic_op_ordered_add_gfx12_amd. */
1950 enum { NUM_ATOMICS_IN_FLIGHT = 6 };
1951
1952 nir_variable *result_ring[NUM_ATOMICS_IN_FLIGHT] = {0};
1953 for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT; i++)
1954 result_ring[i] = nir_local_variable_create(b->impl, glsl_uint64_t_type(), "result");
1955
1956 /* Issue the first N-1 atomics. The shader must not wait because we want them to be
1957 * pipelined. It will only wait for the oldest atomic in the NIR loop.
1958 */
1959 for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT - 1; i++) {
1960 nir_store_var(b, result_ring[i],
1961 nir_global_atomic_amd(b, 64, xfb_state_address, atomic_src, xfb_voffset,
1962 .atomic_op = nir_atomic_op_ordered_add_gfx12_amd), 0x1);
1963 ac_nir_sleep(b, 24);
1964 }
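/* Illustrative: with NUM_ATOMICS_IN_FLIGHT = 6, five atomics are already in flight when
* the loop below starts; each iteration issues one more while checking the oldest result.
*/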
1965
1966 nir_variable *buffer_offsets_var =
1967 nir_local_variable_create(b->impl, glsl_vec4_type(), "buffer_offset_per_lane");
1968
1969 nir_loop *loop = nir_push_loop(b);
1970 {
1971 for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT; i++) {
1972 int issue_index = (NUM_ATOMICS_IN_FLIGHT - 1 + i) % NUM_ATOMICS_IN_FLIGHT;
1973 int read_index = i;
1974
1975 /* Issue (or repeat) the atomic. */
1976 nir_store_var(b, result_ring[issue_index],
1977 nir_global_atomic_amd(b, 64, xfb_state_address, atomic_src, xfb_voffset,
1978 .atomic_op = nir_atomic_op_ordered_add_gfx12_amd), 0x1);
1979
1980 /* Break if the oldest atomic succeeded in incrementing the offsets. */
1981 nir_def *oldest_result = nir_load_var(b, result_ring[read_index]);
1982 nir_def *loaded_ordered_id = nir_unpack_64_2x32_split_x(b, oldest_result);
1983
1984 /* Debug: Write the vec4 into a shader log ring buffer. */
1985 #if 0
1986 nir_def *loaded_dwords_written = nir_unpack_64_2x32_split_y(b, oldest_result);
1987 ac_nir_store_debug_log_amd(b, nir_vec4(b, nir_u2u32(b, xfb_state_address),
1988 ordered_id, loaded_ordered_id,
1989 loaded_dwords_written));
1990 #endif
1991
1992 nir_def *continue_if = nir_ieq(b, loaded_ordered_id, ordered_id);
1993 continue_if = nir_inot(b, nir_vote_any(b, 1, continue_if));
1994 nir_push_if(b, continue_if);
1995 }
1996 nir_jump(b, nir_jump_continue);
1997
1998 for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT; i++) {
1999 int read_index = NUM_ATOMICS_IN_FLIGHT - 1 - i;
2000 nir_push_else(b, NULL);
2001 {
2002 nir_def *result = nir_load_var(b, result_ring[read_index]);
2003 buffer_offset_per_lane = nir_unpack_64_2x32_split_y(b, result);
2004 buffer_offsets = read_values_from_4_lanes(b, buffer_offset_per_lane, info->buffers_written);
2005 nir_store_var(b, buffer_offsets_var, buffer_offsets, info->buffers_written);
2006 }
2007 nir_pop_if(b, NULL);
2008 }
2009 nir_jump(b, nir_jump_break);
2010 }
2011 nir_pop_loop(b, loop);
2012 buffer_offsets = nir_load_var(b, buffer_offsets_var);
2013 }
2014 }
2015 nir_pop_if(b, if_4lanes);
2016 buffer_offsets = nir_if_phi(b, buffer_offsets, nir_undef(b, 4, 32));
2017
2018 if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
2019 } else {
2020 nir_def *ordered_id = nir_load_ordered_id_amd(b);
2021 buffer_offsets =
2022 nir_ordered_xfb_counter_add_gfx11_amd(b, ordered_id,
2023 nir_vec(b, workgroup_buffer_sizes, 4),
2024 /* mask of buffers to update */
2025 .write_mask = info->buffers_written);
2026 }
2027
2028 nir_def *emit_prim[4];
2029 memcpy(emit_prim, gen_prim, 4 * sizeof(nir_def *));
2030
2031 nir_def *any_overflow = nir_imm_false(b);
2032 nir_def *overflow_amount[4] = {undef, undef, undef, undef};
2033
2034 for (unsigned buffer = 0; buffer < 4; buffer++) {
2035 if (!(info->buffers_written & BITFIELD_BIT(buffer)))
2036 continue;
2037
2038 nir_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
2039
2040 /* Only consider overflow for valid feedback buffers because
2041 * otherwise the ordered operation above (GDS atomic return) might
2042 * return non-zero offsets for invalid buffers.
2043 */
2044 nir_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
2045 nir_def *buffer_offset = nir_channel(b, buffer_offsets, buffer);
2046 buffer_offset = nir_bcsel(b, buffer_valid, buffer_offset, nir_imm_int(b, 0));
2047
2048 nir_def *remain_size = nir_isub(b, buffer_size, buffer_offset);
2049 nir_def *remain_prim = nir_idiv(b, remain_size, prim_stride[buffer]);
2050 nir_def *overflow = nir_ilt(b, buffer_size, buffer_offset);
2051
2052 any_overflow = nir_ior(b, any_overflow, overflow);
2053 overflow_amount[buffer] = nir_imax(b, nir_imm_int(b, 0),
2054 nir_isub(b, buffer_offset, buffer_size));
2055
2056 unsigned stream = info->buffer_to_stream[buffer];
2057 /* When a previous workgroup overflowed, we can't emit any primitives. */
2058 emit_prim[stream] = nir_bcsel(
2059 b, overflow, nir_imm_int(b, 0),
2060 /* We can emit partial primitives, limited by the smallest buffer. */
2061 nir_imin(b, emit_prim[stream], remain_prim));
2062
2063 /* Save to LDS so that other waves in this workgroup can access it. */
2064 nir_store_shared(b, buffer_offset, scratch_base, .base = buffer * 4);
2065 }
2066
2067 /* We have to fix up the streamout offsets if we overflowed because they determine
2068 * the vertex count for DrawTransformFeedback.
2069 */
2070 if (gfx_level >= GFX12) {
2071 nir_pop_if(b, if_invocation_0);
2072
2073 any_overflow = nir_if_phi(b, any_overflow, nir_undef(b, 1, 1));
2074 for (unsigned buffer = 0; buffer < 4; buffer++)
2075 overflow_amount[buffer] = nir_if_phi(b, overflow_amount[buffer], undef);
2076 for (unsigned stream = 0; stream < 4; stream++) {
2077 if (emit_prim[stream])
2078 emit_prim[stream] = nir_if_phi(b, emit_prim[stream], undef);
2079 }
2080
2081 nir_if *if_any_overflow_4_lanes =
2082 nir_push_if(b, nir_iand(b, any_overflow, nir_ult_imm(b, tid_in_tg, 4)));
2083 {
2084 /* Move overflow amounts from SGPRs to the first 4 lanes. */
2085 nir_def *overflow_amount_per_lane =
2086 write_values_to_lanes(b, overflow_amount, info->buffers_written);
2087
2088 nir_global_atomic_amd(b, 32, xfb_state_address, nir_ineg(b, overflow_amount_per_lane),
2089 xfb_voffset, .base = 4, .atomic_op = nir_atomic_op_iadd);
2090 }
2091 nir_pop_if(b, if_any_overflow_4_lanes);
2092
2093 if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
2094 } else {
2095 nir_if *if_any_overflow = nir_push_if(b, any_overflow);
2096 nir_xfb_counter_sub_gfx11_amd(b, nir_vec(b, overflow_amount, 4),
2097 /* mask of buffers to update */
2098 .write_mask = info->buffers_written);
2099 nir_pop_if(b, if_any_overflow);
2100 }
2101
2102 /* Save to LDS so that other waves in this workgroup can access it. */
2103 for (unsigned stream = 0; stream < 4; stream++) {
2104 if (!(info->streams_written & BITFIELD_BIT(stream)))
2105 continue;
2106
2107 nir_store_shared(b, emit_prim[stream], scratch_base, .base = 16 + stream * 4);
2108 }
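/* Scratch LDS layout used here and re-read below (derived from the .base offsets):
* dwords 0-3: per-buffer global offsets, dwords 4-7: per-stream emitted primitive counts.
*/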
2109
2110 /* Update shader query. */
2111 if (has_xfb_prim_query) {
2112 nir_if *if_shader_query = nir_push_if(b, nir_load_prim_xfb_query_enabled_amd(b));
2113 {
2114 for (unsigned stream = 0; stream < 4; stream++) {
2115 if (info->streams_written & BITFIELD_BIT(stream))
2116 nir_atomic_add_xfb_prim_count_amd(b, emit_prim[stream], .stream_id = stream);
2117 }
2118 }
2119 nir_pop_if(b, if_shader_query);
2120 }
2121 }
2122 nir_pop_if(b, if_invocation_0);
2123
2124 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
2125 .memory_scope = SCOPE_WORKGROUP,
2126 .memory_semantics = NIR_MEMORY_ACQ_REL,
2127 .memory_modes = nir_var_mem_shared);
2128
2129 /* Fetch the per-buffer offsets in all waves. */
2130 for (unsigned buffer = 0; buffer < 4; buffer++) {
2131 if (!(info->buffers_written & BITFIELD_BIT(buffer)))
2132 continue;
2133
2134 buffer_offsets_ret[buffer] =
2135 nir_load_shared(b, 1, 32, scratch_base, .base = buffer * 4);
2136 }
2137
2138 /* Fetch the per-stream emit prim in all waves. */
2139 for (unsigned stream = 0; stream < 4; stream++) {
2140 if (!(info->streams_written & BITFIELD_BIT(stream)))
2141 continue;
2142
2143 emit_prim_ret[stream] =
2144 nir_load_shared(b, 1, 32, scratch_base, .base = 16 + stream * 4);
2145 }
2146 }
2147
2148 static void
2149 ngg_build_streamout_vertex(nir_builder *b, nir_xfb_info *info,
2150 unsigned stream, nir_def *so_buffer[4],
2151 nir_def *buffer_offsets[4],
2152 unsigned vertex_index, nir_def *vtx_lds_addr,
2153 ac_nir_prerast_out *pr_out,
2154 bool skip_primitive_id)
2155 {
2156 unsigned vertex_offset[NIR_MAX_XFB_BUFFERS] = {0};
2157
2158 u_foreach_bit(buffer, info->buffers_written) {
2159 /* We use imm_offset for the vertex offset within a primitive, and GFX11 only supports
2160 * 12-bit unsigned imm_offset. (GFX12 supports 24-bit signed imm_offset)
2161 */
2162 assert(info->buffers[buffer].stride * 3 < 4096);
2163 vertex_offset[buffer] = vertex_index * info->buffers[buffer].stride;
2164 }
2165
2166 nir_def *zero = nir_imm_int(b, 0);
2167 unsigned num_values = 0, store_offset = 0, store_buffer_index = 0;
2168 nir_def *values[4];
2169
2170 for (unsigned i = 0; i < info->output_count; i++) {
2171 nir_xfb_output_info *out = info->outputs + i;
2172 if (!out->component_mask || info->buffer_to_stream[out->buffer] != stream)
2173 continue;
2174
2175 unsigned base;
2176 if (out->location >= VARYING_SLOT_VAR0_16BIT) {
2177 base =
2178 util_bitcount64(b->shader->info.outputs_written) +
2179 util_bitcount(b->shader->info.outputs_written_16bit &
2180 BITFIELD_MASK(out->location - VARYING_SLOT_VAR0_16BIT));
2181 } else {
2182 uint64_t outputs_written = b->shader->info.outputs_written;
2183 if (skip_primitive_id)
2184 outputs_written &= ~VARYING_BIT_PRIMITIVE_ID;
2185
2186 base =
2187 util_bitcount64(outputs_written &
2188 BITFIELD64_MASK(out->location));
2189 }
2190
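/* Each packed output slot occupies 16 bytes (4 dwords) of per-vertex LDS, matching the
* packed_location * 16 + component * 4 addressing used when the outputs were stored to LDS.
*/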
2191 unsigned offset = (base * 4 + out->component_offset) * 4;
2192 unsigned count = util_bitcount(out->component_mask);
2193
2194 assert(u_bit_consecutive(out->component_offset, count) == out->component_mask);
2195
2196 nir_def *out_data =
2197 nir_load_shared(b, count, 32, vtx_lds_addr, .base = offset);
2198
2199 for (unsigned comp = 0; comp < count; comp++) {
2200 nir_def *data = nir_channel(b, out_data, comp);
2201
2202 /* Convert 16-bit outputs to 32-bit.
2203 *
2204 * OpenGL ES puts 16-bit medium-precision varyings in VARYING_SLOT_VAR0_16BIT.
2205 * We need to convert them to 32-bit for streamout.
2206 *
2207 * Vulkan does not allow 8/16-bit varyings for streamout.
2208 */
2209 if (out->location >= VARYING_SLOT_VAR0_16BIT) {
2210 unsigned index = out->location - VARYING_SLOT_VAR0_16BIT;
2211 unsigned c = out->component_offset + comp;
2212 nir_def *v;
2213 nir_alu_type t;
2214
2215 if (out->high_16bits) {
2216 v = nir_unpack_32_2x16_split_y(b, data);
2217 t = pr_out->types_16bit_hi[index][c];
2218 } else {
2219 v = nir_unpack_32_2x16_split_x(b, data);
2220 t = pr_out->types_16bit_lo[index][c];
2221 }
2222
2223 t = nir_alu_type_get_base_type(t);
2224 data = nir_convert_to_bit_size(b, v, t, 32);
2225 }
2226
2227 const unsigned store_comp_offset = out->offset + comp * 4;
2228 const bool has_hole = store_offset + num_values * 4 != store_comp_offset;
2229
2230 /* Flush the gathered components to memory as an at-most-vec4 store when the vector is full, the buffer changes, or there is a hole. */
2231 if (num_values && (num_values == 4 || store_buffer_index != out->buffer || has_hole)) {
2232 nir_store_buffer_amd(b, nir_vec(b, values, num_values), so_buffer[store_buffer_index],
2233 buffer_offsets[store_buffer_index], zero, zero,
2234 .base = vertex_offset[store_buffer_index] + store_offset,
2235 .access = ACCESS_NON_TEMPORAL);
2236 num_values = 0;
2237 }
2238
2239 /* Initialize the buffer index and offset if we are beginning a new vec4 store. */
2240 if (num_values == 0) {
2241 store_buffer_index = out->buffer;
2242 store_offset = store_comp_offset;
2243 }
2244
2245 values[num_values++] = data;
2246 }
2247 }
2248
2249 if (num_values) {
2250 /* Flush the remaining components to memory (as an up to vec4 store) */
2251 nir_store_buffer_amd(b, nir_vec(b, values, num_values), so_buffer[store_buffer_index],
2252 buffer_offsets[store_buffer_index], zero, zero,
2253 .base = vertex_offset[store_buffer_index] + store_offset,
2254 .access = ACCESS_NON_TEMPORAL);
2255 }
2256 }
2257
2258 static void
2259 ngg_nogs_build_streamout(nir_builder *b, lower_ngg_nogs_state *s)
2260 {
2261 nir_xfb_info *info = ac_nir_get_sorted_xfb_info(b->shader);
2262
2263 nir_def *lds_scratch_base = nir_load_lds_ngg_scratch_base_amd(b);
2264
2265 /* Get global buffer offset where this workgroup will stream out data to. */
2266 nir_def *generated_prim = nir_load_workgroup_num_input_primitives_amd(b);
2267 nir_def *gen_prim_per_stream[4] = {generated_prim, 0, 0, 0};
2268 nir_def *emit_prim_per_stream[4] = {0};
2269 nir_def *buffer_offsets[4] = {0};
2270 nir_def *so_buffer[4] = {0};
2271 nir_def *tid_in_tg = nir_load_local_invocation_index(b);
2272 ngg_build_streamout_buffer_info(b, info, s->options->gfx_level, s->options->has_xfb_prim_query,
2273 s->options->use_gfx12_xfb_intrinsic, lds_scratch_base, tid_in_tg,
2274 gen_prim_per_stream,
2275 so_buffer, buffer_offsets,
2276 emit_prim_per_stream);
2277
2278 /* Write out primitive data */
2279 nir_if *if_emit = nir_push_if(b, nir_ilt(b, tid_in_tg, emit_prim_per_stream[0]));
2280 {
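/* Per-vertex LDS stride for the xfb data; this mirrors the streamout branch of
* ngg_nogs_get_pervertex_lds_size(): (num_outputs * 4 + 1) * 4 bytes.
*/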
2281 unsigned vtx_lds_stride = (b->shader->num_outputs * 4 + 1) * 4;
2282 nir_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
2283 nir_def *first_vertex_idx = nir_imul(b, tid_in_tg, num_vert_per_prim);
2284
2285 u_foreach_bit(buffer, info->buffers_written) {
2286 buffer_offsets[buffer] = nir_iadd(b, buffer_offsets[buffer],
2287 nir_imul_imm(b, first_vertex_idx,
2288 info->buffers[buffer].stride));
2289 }
2290
2291 for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++) {
2292 nir_if *if_valid_vertex =
2293 nir_push_if(b, nir_igt_imm(b, num_vert_per_prim, i));
2294 {
2295 nir_def *vtx_lds_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
2296 nir_def *vtx_lds_addr = pervertex_lds_addr(b, vtx_lds_idx, vtx_lds_stride);
2297 ngg_build_streamout_vertex(b, info, 0, so_buffer, buffer_offsets, i,
2298 vtx_lds_addr, &s->out, s->skip_primitive_id);
2299 }
2300 nir_pop_if(b, if_valid_vertex);
2301 }
2302 }
2303 nir_pop_if(b, if_emit);
2304
2305 /* Wait for streamout memory operations to finish before exporting primitives; otherwise
2306 * they may not finish before the shader ends.
2307 *
2308 * If a shader has no param exports, rasterization can start before
2309 * the shader finishes and thus memory stores might not finish before
2310 * the pixel shader starts.
2311 *
2312 * TODO: we only need this when there are no param exports.
2313 *
2314 * TODO: not sure if we need this barrier with late prim export, as I
2315 * can't observe any test failing without this barrier.
2316 */
2317 nir_scoped_memory_barrier(b, SCOPE_DEVICE, NIR_MEMORY_RELEASE, nir_var_mem_ssbo);
2318 }
2319
2320 static unsigned
2321 ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
2322 unsigned shader_num_outputs,
2323 bool streamout_enabled,
2324 bool export_prim_id,
2325 bool has_user_edgeflags)
2326 {
2327 unsigned pervertex_lds_bytes = 0;
2328
2329 if (streamout_enabled) {
2330 /* The extra dword is used to avoid LDS bank conflicts and store the primitive id.
2331 * TODO: only alloc space for outputs that really need streamout.
2332 */
2333 pervertex_lds_bytes = (shader_num_outputs * 4 + 1) * 4;
2334 }
2335
2336 bool need_prim_id_store_shared = export_prim_id && stage == MESA_SHADER_VERTEX;
2337 if (need_prim_id_store_shared || has_user_edgeflags) {
2338 unsigned size = 0;
2339 if (need_prim_id_store_shared)
2340 size += 4;
2341 if (has_user_edgeflags)
2342 size += 4;
2343
2344 /* pad to an odd number of dwords to avoid LDS bank conflicts */
2345 size |= 4;
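/* For example: prim id + edge flags = 8 bytes, padded to 12 bytes (3 dwords, odd). */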
2346
2347 pervertex_lds_bytes = MAX2(pervertex_lds_bytes, size);
2348 }
2349
2350 return pervertex_lds_bytes;
2351 }
2352
2353 static void
2354 ngg_nogs_gather_outputs(nir_builder *b, struct exec_list *cf_list, lower_ngg_nogs_state *s)
2355 {
2356 /* Assume:
2357 * - the shader used nir_lower_io_to_temporaries
2358 * - 64-bit outputs are lowered
2359 * - no indirect indexing is present
2360 */
2361 struct nir_cf_node *first_node =
2362 exec_node_data(nir_cf_node, exec_list_get_head(cf_list), node);
2363
2364 for (nir_block *block = nir_cf_node_cf_tree_first(first_node); block != NULL;
2365 block = nir_block_cf_tree_next(block)) {
2366 nir_foreach_instr_safe (instr, block) {
2367 if (instr->type != nir_instr_type_intrinsic)
2368 continue;
2369
2370 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2371 if (intrin->intrinsic != nir_intrinsic_store_output)
2372 continue;
2373
2374 ac_nir_gather_prerast_store_output_info(b, intrin, &s->out);
2375 nir_instr_remove(instr);
2376 }
2377 }
2378 }
2379
2380 static void
2381 create_output_phis(nir_builder *b, const uint64_t outputs_written, const uint64_t outputs_written_16bit, ac_nir_prerast_out *out)
2382 {
2383 nir_def *undef = nir_undef(b, 1, 32); /* inserted at the start of the shader */
2384
2385 u_foreach_bit64(slot, outputs_written) {
2386 for (unsigned j = 0; j < 4; j++) {
2387 if (out->outputs[slot][j])
2388 out->outputs[slot][j] = nir_if_phi(b, out->outputs[slot][j], undef);
2389 }
2390 }
2391
2392 u_foreach_bit64(i, outputs_written_16bit) {
2393 for (unsigned j = 0; j < 4; j++) {
2394 if (out->outputs_16bit_hi[i][j])
2395 out->outputs_16bit_hi[i][j] = nir_if_phi(b, out->outputs_16bit_hi[i][j], undef);
2396
2397 if (out->outputs_16bit_lo[i][j])
2398 out->outputs_16bit_lo[i][j] = nir_if_phi(b, out->outputs_16bit_lo[i][j], undef);
2399 }
2400 }
2401 }
2402
2403 static bool must_wait_attr_ring(enum amd_gfx_level gfx_level, bool has_param_exports)
2404 {
2405 return (gfx_level == GFX11 || gfx_level == GFX11_5) && has_param_exports;
2406 }
2407
2408 static void
2409 export_pos0_wait_attr_ring(nir_builder *b, nir_if *if_es_thread, nir_def *outputs[VARYING_SLOT_MAX][4], const ac_nir_lower_ngg_options *options)
2410 {
2411 b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2412
2413 /* Create phi for the position output values. */
2414 ac_nir_prerast_out out = {
2415 .outputs = {{outputs[VARYING_SLOT_POS][0], outputs[VARYING_SLOT_POS][1], outputs[VARYING_SLOT_POS][2], outputs[VARYING_SLOT_POS][3]}},
2416 .infos = {{.components_mask = 0xf, .as_sysval_mask = 0xf}},
2417 };
2418
2419 b->cursor = nir_after_cf_list(&b->impl->body);
2420
2421 /* Wait for attribute stores to finish. */
2422 nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
2423 .memory_scope = SCOPE_DEVICE,
2424 .memory_semantics = NIR_MEMORY_RELEASE,
2425 .memory_modes = nir_var_mem_ssbo | nir_var_shader_out | nir_var_mem_global | nir_var_image);
2426
2427 /* Export just the pos0 output. */
2428 nir_if *if_export_empty_pos = nir_push_if(b, if_es_thread->condition.ssa);
2429 {
2430 ac_nir_export_position(b, options->gfx_level,
2431 options->clip_cull_dist_mask,
2432 !options->has_param_exports,
2433 options->force_vrs, true,
2434 VARYING_BIT_POS, &out, NULL);
2435 }
2436 nir_pop_if(b, if_export_empty_pos);
2437 }
2438
2439
2440 static void
2441 nogs_export_vertex_params(nir_builder *b, nir_function_impl *impl,
2442 nir_if *if_es_thread, nir_def *num_es_threads,
2443 lower_ngg_nogs_state *s)
2444 {
2445 if (!s->options->has_param_exports)
2446 return;
2447
2448 if (s->options->gfx_level >= GFX11) {
2449 /* Export varyings for GFX11+ */
2450 b->cursor = nir_after_impl(impl);
2451 if (!num_es_threads)
2452 num_es_threads = nir_load_merged_wave_info_amd(b);
2453
2454 ac_nir_store_parameters_to_attr_ring(b, s->options->vs_output_param_offset,
2455 b->shader->info.outputs_written,
2456 b->shader->info.outputs_written_16bit,
2457 &s->out, NULL, num_es_threads);
2458 } else {
2459 ac_nir_export_parameters(b, s->options->vs_output_param_offset,
2460 b->shader->info.outputs_written,
2461 b->shader->info.outputs_written_16bit,
2462 &s->out);
2463 }
2464 }
2465
2466 void
2467 ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
2468 {
2469 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
2470 assert(impl);
2471 assert(options->max_workgroup_size && options->wave_size);
2472 assert(!(options->can_cull && options->passthrough));
2473
2474 nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
2475 nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
2476 nir_variable *es_accepted_var =
2477 options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
2478 nir_variable *gs_accepted_var =
2479 options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
2480 nir_variable *gs_exported_var = nir_local_variable_create(impl, glsl_bool_type(), "gs_exported");
2481
2482 bool streamout_enabled = shader->xfb_info && !options->disable_streamout;
2483 bool has_user_edgeflags =
2484 options->use_edgeflags && (shader->info.outputs_written & VARYING_BIT_EDGE);
2485 /* Streamout must be done before either the prim or the vertex export, because when
2486 * there are no param exports, rasterization can start right after the prim and vertex
2487 * exports, which would leave the streamout buffer writes unfinished.
2488 *
2489 * Always use late prim export when user edge flags are enabled.
2490 * This is because edge flags are written by ES threads but they
2491 * are exported by GS threads as part of the primitive export.
2492 */
2493 bool early_prim_export =
2494 options->early_prim_export && !(streamout_enabled || has_user_edgeflags);
2495
2496 lower_ngg_nogs_state state = {
2497 .options = options,
2498 .early_prim_export = early_prim_export,
2499 .streamout_enabled = streamout_enabled,
2500 .position_value_var = position_value_var,
2501 .prim_exp_arg_var = prim_exp_arg_var,
2502 .es_accepted_var = es_accepted_var,
2503 .gs_accepted_var = gs_accepted_var,
2504 .gs_exported_var = gs_exported_var,
2505 .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
2506 .has_user_edgeflags = has_user_edgeflags,
2507 .skip_primitive_id = streamout_enabled && (options->export_primitive_id || options->export_primitive_id_per_prim),
2508 };
2509
2510 /* Can't export the primitive ID both as per-vertex and per-primitive. */
2511 assert(!options->export_primitive_id || !options->export_primitive_id_per_prim);
2512
2513 const bool need_prim_id_store_shared =
2514 options->export_primitive_id && shader->info.stage == MESA_SHADER_VERTEX;
2515
2516 if (options->export_primitive_id) {
2517 shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
2518 }
2519
2520 if (options->export_primitive_id_per_prim) {
2521 /* The HW preloads the primitive ID to VGPRs of GS threads for VS, but not for TES. */
2522 assert(shader->info.stage == MESA_SHADER_VERTEX);
2523 assert(options->gfx_level >= GFX10_3);
2524 }
2525
2526 nir_builder builder = nir_builder_create(impl);
2527 nir_builder *b = &builder; /* This is to avoid the & */
2528
2529 if (options->can_cull) {
2530 analyze_shader_before_culling(shader, &state);
2531 save_reusable_variables(b, &state);
2532 }
2533
2534 nir_cf_list extracted;
2535 nir_cf_extract(&extracted, nir_before_impl(impl),
2536 nir_after_impl(impl));
2537 b->cursor = nir_before_impl(impl);
2538
2539 ngg_nogs_init_vertex_indices_vars(b, impl, &state);
2540
2541 /* Emit the primitives-generated query code here, so that
2542 * it executes before culling and isn't part of the extracted CF.
2543 */
2544 nogs_prim_gen_query(b, &state);
2545
2546 /* Whether a shader invocation should export a primitive;
2547 * initialized to all invocations that have an input primitive.
2548 */
2549 nir_store_var(b, gs_exported_var, has_input_primitive(b), 0x1u);
2550
2551 if (!options->can_cull) {
2552 /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
2553 if (!(options->passthrough && options->family >= CHIP_NAVI23)) {
2554 /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
2555 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
2556 {
2557 nir_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
2558 nir_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
2559 alloc_vertices_and_primitives(b, vtx_cnt, prim_cnt);
2560 }
2561 nir_pop_if(b, if_wave_0);
2562 }
2563
2564 /* Take care of early primitive export, otherwise just pack the primitive export argument */
2565 if (state.early_prim_export)
2566 emit_ngg_nogs_prim_export(b, &state, NULL);
2567 else
2568 nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
2569 } else {
2570 add_deferred_attribute_culling(b, &extracted, &state);
2571 b->cursor = nir_after_impl(impl);
2572
2573 if (state.early_prim_export)
2574 emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
2575
2576 /* Wait for culling to finish using LDS. */
2577 if (need_prim_id_store_shared || has_user_edgeflags) {
2578 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
2579 .memory_scope = SCOPE_WORKGROUP,
2580 .memory_semantics = NIR_MEMORY_ACQ_REL,
2581 .memory_modes = nir_var_mem_shared);
2582 }
2583 }
2584
2585 /* determine the LDS vertex stride */
2586 state.pervertex_lds_bytes =
2587 ngg_nogs_get_pervertex_lds_size(shader->info.stage,
2588 shader->num_outputs,
2589 state.streamout_enabled,
2590 options->export_primitive_id,
2591 state.has_user_edgeflags);
2592
2593 if (need_prim_id_store_shared) {
2594 emit_ngg_nogs_prim_id_store_shared(b, &state);
2595
2596 /* Wait for GS threads to store primitive ID in LDS. */
2597 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
2598 .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
2599 } else if (options->export_primitive_id_per_prim && options->gfx_level >= GFX11) {
2600 emit_ngg_nogs_prim_id_store_per_prim_to_attr_ring(b, &state);
2601 }
2602
2603 nir_def *es_thread =
2604 options->can_cull ? nir_load_var(b, es_accepted_var) : has_input_vertex(b);
2605
2606 /* Calculate the bit count here instead of below for lower SGPR usage and better ALU
2607 * scheduling.
2608 */
2609 nir_def *num_es_threads = NULL;
2610 if (state.options->gfx_level >= GFX11 && options->can_cull) {
2611 nir_def *es_accepted_mask =
2612 nir_ballot(b, 1, options->wave_size, nir_load_var(b, es_accepted_var));
2613 num_es_threads = nir_bit_count(b, es_accepted_mask);
2614 }
2615
2616 nir_if *if_es_thread = nir_push_if(b, es_thread);
2617 {
2618 /* Run the actual shader */
2619 nir_cf_reinsert(&extracted, b->cursor);
2620 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2621
2622 if (options->export_primitive_id)
2623 emit_store_ngg_nogs_es_primitive_id(b, &state);
2624 }
2625 nir_pop_if(b, if_es_thread);
2626
2627 if (options->can_cull) {
2628 /* Replace uniforms. */
2629 apply_reusable_variables(b, &state);
2630
2631 /* Remove the redundant position output. */
2632 remove_extra_pos_outputs(shader, &state);
2633
2634 /* After looking at the performance in apps eg. Doom Eternal, and The Witcher 3,
2635 * it seems that it's best to put the position export always at the end, and
2636 * then let ACO schedule it up (slightly) only when early prim export is used.
2637 */
2638 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2639
2640 nir_def *pos_val = nir_load_var(b, state.position_value_var);
2641 for (int i = 0; i < 4; i++)
2642 state.out.outputs[VARYING_SLOT_POS][i] = nir_channel(b, pos_val, i);
2643 }
2644
2645 /* Gather outputs data and types */
2646 ngg_nogs_gather_outputs(b, &if_es_thread->then_list, &state);
2647 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2648
2649 if (state.has_user_edgeflags)
2650 ngg_nogs_store_edgeflag_to_lds(b, &state);
2651
2652 if (state.streamout_enabled) {
2653 /* TODO: support culling after streamout. */
2654 assert(!options->can_cull);
2655
2656 ngg_nogs_store_xfb_outputs_to_lds(b, &state);
2657
2658 b->cursor = nir_after_impl(impl);
2659 ngg_nogs_build_streamout(b, &state);
2660 }
2661
2662 /* Take care of late primitive export */
2663 if (!state.early_prim_export) {
2664 b->cursor = nir_after_impl(impl);
2665 emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
2666 }
2667
2668 uint64_t export_outputs = shader->info.outputs_written | VARYING_BIT_POS;
2669 if (options->kill_pointsize)
2670 export_outputs &= ~VARYING_BIT_PSIZ;
2671 if (options->kill_layer)
2672 export_outputs &= ~VARYING_BIT_LAYER;
2673
2674 const bool wait_attr_ring = must_wait_attr_ring(options->gfx_level, options->has_param_exports);
2675 if (wait_attr_ring)
2676 export_outputs &= ~VARYING_BIT_POS;
2677
2678 bool phis_created = false;
2679
2680 /* Add position exports.
2681 *
2682 * If streamout is enabled, export positions after streamout. This increases streamout performance
2683 * for up to 4 vec4 xfb outputs on GFX12 because the streamout code doesn't have to go through
2684 * the export allocation bottleneck. Adding more xfb outputs starts to be limited by the memory
2685 * bandwidth.
2686 */
2687 nir_if *if_pos_exports = NULL;
2688 if (state.streamout_enabled) {
2689 b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2690 create_output_phis(b, b->shader->info.outputs_written, b->shader->info.outputs_written_16bit,
2691 &state.out);
2692 phis_created = true;
2693
2694 b->cursor = nir_after_impl(impl);
2695 if_pos_exports = nir_push_if(b, es_thread);
2696 } else {
2697 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2698 }
2699
2700 ac_nir_export_position(b, options->gfx_level,
2701 options->clip_cull_dist_mask,
2702 !options->has_param_exports,
2703 options->force_vrs, !wait_attr_ring,
2704 export_outputs, &state.out, NULL);
2705
2706 if (if_pos_exports)
2707 nir_pop_if(b, if_pos_exports);
2708
2709 if (options->has_param_exports && options->gfx_level >= GFX11 && !phis_created) {
2710 b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2711 create_output_phis(b, b->shader->info.outputs_written, b->shader->info.outputs_written_16bit,
2712 &state.out);
2713 }
2714
2715 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2716 nogs_export_vertex_params(b, impl, if_es_thread, num_es_threads, &state);
2717
2718 if (wait_attr_ring)
2719 export_pos0_wait_attr_ring(b, if_es_thread, state.out.outputs, options);
2720
2721 nir_metadata_preserve(impl, nir_metadata_none);
2722 nir_validate_shader(shader, "after emitting NGG VS/TES");
2723
2724 /* Cleanup */
2725 nir_opt_dead_write_vars(shader);
2726 nir_lower_vars_to_ssa(shader);
2727 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
2728 nir_lower_alu_to_scalar(shader, NULL, NULL);
2729 nir_lower_phis_to_scalar(shader, true);
2730
2731 if (options->can_cull) {
2732 /* It's beneficial to redo these opts after splitting the shader. */
2733 nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
2734 nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
2735 }
2736
2737 bool progress;
2738 do {
2739 progress = false;
2740 NIR_PASS(progress, shader, nir_opt_undef);
2741 NIR_PASS(progress, shader, nir_opt_dce);
2742 NIR_PASS(progress, shader, nir_opt_dead_cf);
2743
2744 if (options->can_cull)
2745 progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
2746 } while (progress);
2747 }
2748
2749 /**
2750 * Return the address of the LDS storage reserved for the N'th vertex,
2751 * where N is in emit order, meaning:
2752 * - during the finale, N is the invocation_index (within the workgroup)
2753 * - during vertex emit, i.e. while the API GS shader invocation is running,
2754 * N = invocation_index * gs_max_out_vertices + emit_idx
2755 * where emit_idx is the vertex index in the current API GS invocation.
2756 *
2757 * Goals of the LDS memory layout:
2758 * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
2759 * in uniform control flow
2760 * 2. Eliminate bank conflicts on read for export if, additionally, there is no
2761 * culling
2762 * 3. Agnostic to the number of waves (since we don't know it before compiling)
2763 * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
2764 * 5. Avoid wasting memory.
2765 *
2766 * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
2767 * layout, elimination of bank conflicts requires that each vertex occupy an
2768 * odd number of dwords. We use the additional dword to store the output stream
2769 * index as well as a flag to indicate whether this vertex ends a primitive
2770 * for rasterization.
2771 *
2772 * Swizzling is required to satisfy points 1 and 2 simultaneously.
2773 *
2774 * Vertices are stored in export order (gsthread * gs_max_out_vertices + emitidx).
2775 * Indices are swizzled in groups of 32, which ensures point 1 without
2776 * disturbing point 2.
2777 *
2778 * \return an LDS pointer to type {[N x i32], [4 x i8]}
2779 */
2780 static nir_def *
2781 ngg_gs_out_vertex_addr(nir_builder *b, nir_def *out_vtx_idx, lower_ngg_gs_state *s)
2782 {
2783 unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
2784
2785 /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
2786 if (write_stride_2exp) {
2787 nir_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
2788 nir_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
2789 out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
2790 }
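/* Example (a sketch): with gs.vertices_out = 4, write_stride_2exp = 2, so a vertex index in
* row r (each row is 32 indices) is XORed with (r & 3), shifting which LDS banks the
* vertices of consecutive rows land in.
*/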
2791
2792 nir_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
2793 return nir_iadd_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
2794 }
2795
2796 static nir_def *
2797 ngg_gs_emit_vertex_addr(nir_builder *b, nir_def *gs_vtx_idx, lower_ngg_gs_state *s)
2798 {
2799 nir_def *tid_in_tg = nir_load_local_invocation_index(b);
2800 nir_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
2801 nir_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
2802
2803 return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
2804 }
2805
2806 static void
2807 ngg_gs_clear_primflags(nir_builder *b, nir_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
2808 {
2809 char name[32];
2810 snprintf(name, sizeof(name), "clear_primflag_idx_%u", stream);
2811 nir_variable *clear_primflag_idx_var = nir_local_variable_create(b->impl, glsl_uint_type(), name);
2812
2813 nir_def *zero_u8 = nir_imm_zero(b, 1, 8);
2814 nir_store_var(b, clear_primflag_idx_var, num_vertices, 0x1u);
2815
2816 nir_loop *loop = nir_push_loop(b);
2817 {
2818 nir_def *clear_primflag_idx = nir_load_var(b, clear_primflag_idx_var);
2819 nir_if *if_break = nir_push_if(b, nir_uge_imm(b, clear_primflag_idx, b->shader->info.gs.vertices_out));
2820 {
2821 nir_jump(b, nir_jump_break);
2822 }
2823 nir_push_else(b, if_break);
2824 {
2825 nir_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, clear_primflag_idx, s);
2826 nir_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream);
2827 nir_store_var(b, clear_primflag_idx_var, nir_iadd_imm_nuw(b, clear_primflag_idx, 1), 0x1u);
2828 }
2829 nir_pop_if(b, if_break);
2830 }
2831 nir_pop_loop(b, loop);
2832 }
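/* Illustrative example: if gs.vertices_out == 6 but the invocation emitted only
 * 4 vertices, the loop above zeroes the primflag byte of vertices 4 and 5 so
 * that stale LDS contents cannot be misread later as live vertices/primitives.
 */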
2833
2834 static bool
2835 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2836 {
2837 ac_nir_gather_prerast_store_output_info(b, intrin, &s->out);
2838 nir_instr_remove(&intrin->instr);
2839 return true;
2840 }
2841
2842 static unsigned
2843 gs_output_component_mask_with_stream(ac_nir_prerast_per_output_info *info, unsigned stream)
2844 {
2845 unsigned mask = info->components_mask;
2846 if (!mask)
2847 return 0;
2848
2849 /* Clear components that don't belong to the requested stream. */
2850 for (int i = 0; i < 4; i++) {
2851 if (((info->stream >> (i * 2)) & 3) != stream)
2852 mask &= ~(1 << i);
2853 }
2854
2855 return mask;
2856 }
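/* Example (illustrative values): info->stream packs a 2-bit stream ID per
 * component. With components_mask == 0xf and info->stream == 0b01000100
 * (components 1 and 3 on stream 1, components 0 and 2 on stream 0),
 * requesting stream 1 yields 0b1010 and requesting stream 0 yields 0b0101.
 */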
2857
2858 static bool
2859 lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2860 {
2861 b->cursor = nir_before_instr(&intrin->instr);
2862
2863 unsigned stream = nir_intrinsic_stream_id(intrin);
2864 if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
2865 nir_instr_remove(&intrin->instr);
2866 return true;
2867 }
2868
2869 nir_def *gs_emit_vtx_idx = intrin->src[0].ssa;
2870 nir_def *current_vtx_per_prim = intrin->src[1].ssa;
2871 nir_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);
2872
2873 /* Store generic 32-bit outputs to LDS.
2874 * In case of packed 16-bit outputs, we assume they have already been packed into 32-bit slots by now.
2875 */
2876 u_foreach_bit64(slot, b->shader->info.outputs_written) {
2877 const unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
2878 unsigned mask = gs_output_component_mask_with_stream(&s->out.infos[slot], stream);
2879
2880 nir_def **output = s->out.outputs[slot];
2881 nir_def *undef = nir_undef(b, 1, 32);
2882
2883 while (mask) {
2884 int start, count;
2885 u_bit_scan_consecutive_range(&mask, &start, &count);
2886 nir_def *values[4] = {0};
2887 for (int c = start; c < start + count; ++c) {
2888 if (!output[c]) {
2889 /* The shader hasn't written this output. */
2890 values[c - start] = undef;
2891 } else {
2892 assert(output[c]->bit_size == 32);
2893 values[c - start] = output[c];
2894 }
2895 }
2896
2897 nir_def *store_val = nir_vec(b, values, (unsigned)count);
2898 nir_store_shared(b, store_val, gs_emit_vtx_addr,
2899 .base = packed_location * 16 + start * 4,
2900 .align_mul = 4);
2901 }
2902
2903 /* Clear all outputs (they are undefined after emit_vertex) */
2904 memset(s->out.outputs[slot], 0, sizeof(s->out.outputs[slot]));
2905 }
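   /* Note on the addressing above: each packed output location occupies
    * 16 bytes (4 dwords) per vertex, so component c of packed location p
    * lives at byte offset p * 16 + c * 4 from the vertex's base address.
    */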
2906
2907 const unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
2908
2909 /* Store dedicated 16-bit outputs to LDS. */
2910 u_foreach_bit(slot, b->shader->info.outputs_written_16bit) {
2911 const unsigned packed_location = num_32bit_outputs +
2912 util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));
2913
2914 const unsigned mask_lo = gs_output_component_mask_with_stream(s->out.infos_16bit_lo + slot, stream);
2915 const unsigned mask_hi = gs_output_component_mask_with_stream(s->out.infos_16bit_hi + slot, stream);
2916 unsigned mask = mask_lo | mask_hi;
2917
2918 nir_def **output_lo = s->out.outputs_16bit_lo[slot];
2919 nir_def **output_hi = s->out.outputs_16bit_hi[slot];
2920 nir_def *undef = nir_undef(b, 1, 16);
2921
2922 while (mask) {
2923 int start, count;
2924 u_bit_scan_consecutive_range(&mask, &start, &count);
2925 nir_def *values[4] = {0};
2926 for (int c = start; c < start + count; ++c) {
2927 nir_def *lo = output_lo[c] ? output_lo[c] : undef;
2928 nir_def *hi = output_hi[c] ? output_hi[c] : undef;
2929
2930 values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
2931 }
2932
2933 nir_def *store_val = nir_vec(b, values, (unsigned)count);
2934 nir_store_shared(b, store_val, gs_emit_vtx_addr,
2935 .base = packed_location * 16 + start * 4,
2936 .align_mul = 4);
2937 }
2938
2939 /* Clear all outputs (they are undefined after emit_vertex) */
2940 memset(s->out.outputs_16bit_lo[slot], 0, sizeof(s->out.outputs_16bit_lo[slot]));
2941 memset(s->out.outputs_16bit_hi[slot], 0, sizeof(s->out.outputs_16bit_hi[slot]));
2942 }
2943
2944 /* Calculate and store per-vertex primitive flags based on vertex counts:
2945 * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
2946 * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
2947 * only set when the vertex also finishes the primitive
2948 * - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
2949 */
2950
2951 nir_def *vertex_live_flag =
2952 !stream && s->options->can_cull
2953 ? nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2)
2954 : nir_imm_int(b, 0b100);
2955
2956 nir_def *completes_prim = nir_ige_imm(b, current_vtx_per_prim, s->num_vertices_per_primitive - 1);
2957 nir_def *complete_flag = nir_b2i32(b, completes_prim);
2958
2959 nir_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
2960 if (s->num_vertices_per_primitive == 3) {
2961 nir_def *odd = nir_iand(b, current_vtx_per_prim, complete_flag);
2962 nir_def *odd_flag = nir_ishl_imm(b, odd, 1);
2963 prim_flag = nir_ior(b, prim_flag, odd_flag);
2964 }
2965
2966 nir_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr,
2967 .base = s->lds_offs_primflags + stream,
2968 .align_mul = 4, .align_offset = stream);
2969
2970 nir_instr_remove(&intrin->instr);
2971 return true;
2972 }
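/* Primflag example for a triangle strip (assuming culling is disabled, so the
 * live bit is always set): emits with current_vtx_per_prim == 0, 1, 2, 3, 4
 * store primflags 0b100, 0b100, 0b101, 0b111, 0b101 respectively. The third
 * emit completes the first (even) triangle, the fourth completes the second
 * (odd) triangle, and so on.
 */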
2973
2974 static bool
2975 lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
2976 {
2977 b->cursor = nir_before_instr(&intrin->instr);
2978
2979 /* These are not needed, we can simply remove them */
2980 nir_instr_remove(&intrin->instr);
2981 return true;
2982 }
2983
2984 static bool
2985 lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2986 {
2987 b->cursor = nir_before_instr(&intrin->instr);
2988
2989 unsigned stream = nir_intrinsic_stream_id(intrin);
2990 if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
2991 nir_instr_remove(&intrin->instr);
2992 return true;
2993 }
2994
2995 s->vertex_count[stream] = intrin->src[0].ssa;
2996 s->primitive_count[stream] = intrin->src[1].ssa;
2997
2998 /* Clear the primitive flags of non-emitted vertices */
2999 if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
3000 ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
3001
3002 nir_instr_remove(&intrin->instr);
3003 return true;
3004 }
3005
3006 static bool
3007 lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
3008 {
3009 lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;
3010
3011 if (instr->type != nir_instr_type_intrinsic)
3012 return false;
3013
3014 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
3015
3016 if (intrin->intrinsic == nir_intrinsic_store_output)
3017 return lower_ngg_gs_store_output(b, intrin, s);
3018 else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
3019 return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
3020 else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
3021 return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
3022 else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
3023 return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);
3024
3025 return false;
3026 }
3027
3028 static void
3029 lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
3030 {
3031 nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
3032 }
3033
3034 static void
3035 ngg_gs_export_primitives(nir_builder *b, nir_def *max_num_out_prims, nir_def *tid_in_tg,
3036 nir_def *exporter_tid_in_tg, nir_def *primflag_0,
3037 lower_ngg_gs_state *s)
3038 {
3039 nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
3040
3041 /* Only bit 0 matters here - set it to 1 when the primitive should be null */
3042 nir_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));
3043
3044 nir_def *vtx_indices[3] = {0};
3045 vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
3046 if (s->num_vertices_per_primitive >= 2)
3047 vtx_indices[s->num_vertices_per_primitive - 2] = nir_iadd_imm(b, exporter_tid_in_tg, -1);
3048 if (s->num_vertices_per_primitive == 3)
3049 vtx_indices[s->num_vertices_per_primitive - 3] = nir_iadd_imm(b, exporter_tid_in_tg, -2);
3050
3051 if (s->num_vertices_per_primitive == 3) {
3052 /* API GS outputs triangle strips, but NGG HW understands triangles.
3053 * We already know the triangles due to how we set the primitive flags, but we need to
3054 * make sure the vertex order is such that the front/back face orientation is correct and the provoking vertex is preserved.
3055 */
3056
3057 nir_def *is_odd = nir_ubfe_imm(b, primflag_0, 1, 1);
3058 nir_def *provoking_vertex_index = nir_load_provoking_vtx_in_prim_amd(b);
3059 nir_def *provoking_vertex_first = nir_ieq_imm(b, provoking_vertex_index, 0);
3060
3061 vtx_indices[0] = nir_bcsel(b, provoking_vertex_first, vtx_indices[0],
3062 nir_iadd(b, vtx_indices[0], is_odd));
3063 vtx_indices[1] = nir_bcsel(b, provoking_vertex_first,
3064 nir_iadd(b, vtx_indices[1], is_odd),
3065 nir_isub(b, vtx_indices[1], is_odd));
3066 vtx_indices[2] = nir_bcsel(b, provoking_vertex_first,
3067 nir_isub(b, vtx_indices[2], is_odd), vtx_indices[2]);
3068 }
3069
3070 nir_def *arg = ac_nir_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices,
3071 is_null_prim, s->options->gfx_level);
3072 ac_nir_export_primitive(b, arg, NULL);
3073 nir_pop_if(b, if_prim_export_thread);
3074 }
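/* Illustrative example of the odd-triangle fixup above: if exporter_tid_in_tg
 * is 5, the initial indices are {3, 4, 5}. For an odd triangle they become
 * {3, 5, 4} when the provoking vertex is first, or {4, 3, 5} when it is last.
 * Either way the winding is corrected while the provoking vertex stays put.
 */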
3075
3076 static void
3077 ngg_gs_export_vertices(nir_builder *b, nir_def *max_num_out_vtx, nir_def *tid_in_tg,
3078 nir_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
3079 {
3080 nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
3081 nir_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;
3082
3083 if (!s->output_compile_time_known) {
3084 /* Vertex compaction.
3085 * The current thread will export a vertex that was live in another invocation.
3086 * Load the index of the vertex that the current thread will have to export.
3087 */
3088 nir_def *exported_vtx_idx = nir_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1);
3089 exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
3090 }
3091
3092 u_foreach_bit64(slot, b->shader->info.outputs_written) {
3093 const unsigned packed_location =
3094 util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
3095
3096 unsigned mask = gs_output_component_mask_with_stream(&s->out.infos[slot], 0);
3097
3098 while (mask) {
3099 int start, count;
3100 u_bit_scan_consecutive_range(&mask, &start, &count);
3101 nir_def *load =
3102 nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
3103 .base = packed_location * 16 + start * 4,
3104 .align_mul = 4);
3105
3106 for (int i = 0; i < count; i++)
3107 s->out.outputs[slot][start + i] = nir_channel(b, load, i);
3108 }
3109 }
3110
3111 const unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
3112
3113 /* Dedicated 16-bit outputs. */
3114 u_foreach_bit(i, b->shader->info.outputs_written_16bit) {
3115 const unsigned packed_location = num_32bit_outputs +
3116 util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(i));
3117
3118 const unsigned mask_lo = gs_output_component_mask_with_stream(&s->out.infos_16bit_lo[i], 0);
3119 const unsigned mask_hi = gs_output_component_mask_with_stream(&s->out.infos_16bit_hi[i], 0);
3120 unsigned mask = mask_lo | mask_hi;
3121
3122 while (mask) {
3123 int start, count;
3124 u_bit_scan_consecutive_range(&mask, &start, &count);
3125 nir_def *load =
3126 nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
3127 .base = packed_location * 16 + start * 4,
3128 .align_mul = 4);
3129
3130 for (int j = 0; j < count; j++) {
3131 nir_def *val = nir_channel(b, load, j);
3132 unsigned comp = start + j;
3133
3134 if (mask_lo & BITFIELD_BIT(comp))
3135 s->out.outputs_16bit_lo[i][comp] = nir_unpack_32_2x16_split_x(b, val);
3136
3137 if (mask_hi & BITFIELD_BIT(comp))
3138 s->out.outputs_16bit_hi[i][comp] = nir_unpack_32_2x16_split_y(b, val);
3139 }
3140 }
3141 }
3142
3143 uint64_t export_outputs = b->shader->info.outputs_written | VARYING_BIT_POS;
3144 if (s->options->kill_pointsize)
3145 export_outputs &= ~VARYING_BIT_PSIZ;
3146 if (s->options->kill_layer)
3147 export_outputs &= ~VARYING_BIT_LAYER;
3148
3149 const bool wait_attr_ring = must_wait_attr_ring(s->options->gfx_level, s->options->has_param_exports);
3150 if (wait_attr_ring)
3151 export_outputs &= ~VARYING_BIT_POS;
3152
3153 ac_nir_export_position(b, s->options->gfx_level,
3154 s->options->clip_cull_dist_mask,
3155 !s->options->has_param_exports,
3156 s->options->force_vrs, !wait_attr_ring,
3157 export_outputs, &s->out, NULL);
3158
3159 if (s->options->has_param_exports && s->options->gfx_level < GFX11) {
3160 /* Emit vertex parameter exports.
3161 * Only the vertex export threads should do this.
3162 */
3163 ac_nir_export_parameters(b, s->options->vs_output_param_offset,
3164 b->shader->info.outputs_written,
3165 b->shader->info.outputs_written_16bit,
3166 &s->out);
3167 }
3168
3169 nir_pop_if(b, if_vtx_export_thread);
3170
3171 if (s->options->has_param_exports && s->options->gfx_level >= GFX11) {
3172 /* Store vertex parameters to attribute ring.
3173 * For optimal attribute ring access, this should happen in top level CF.
3174 */
3175 create_output_phis(b, b->shader->info.outputs_written, b->shader->info.outputs_written_16bit, &s->out);
3176 ac_nir_store_parameters_to_attr_ring(b, s->options->vs_output_param_offset,
3177 b->shader->info.outputs_written,
3178 b->shader->info.outputs_written_16bit,
3179 &s->out, tid_in_tg, max_num_out_vtx);
3180
3181 if (wait_attr_ring)
3182 export_pos0_wait_attr_ring(b, if_vtx_export_thread, s->out.outputs, s->options);
3183 }
3184 }
3185
3186 static void
3187 ngg_gs_setup_vertex_compaction(nir_builder *b, nir_def *vertex_live, nir_def *tid_in_tg,
3188 nir_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
3189 {
3190 assert(vertex_live->bit_size == 1);
3191 nir_if *if_vertex_live = nir_push_if(b, vertex_live);
3192 {
3193 /* Set up vertex compaction.
3194 * Save the current thread's ID for the thread that will export the current vertex.
3195 * We reuse the stream-1 primitive flag byte of the exporter's vertex to store it.
3196 */
3197
3198 nir_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
3199 nir_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
3200 nir_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1);
3201 }
3202 nir_pop_if(b, if_vertex_live);
3203 }
3204
3205 static nir_def *
3206 ngg_gs_load_out_vtx_primflag(nir_builder *b, unsigned stream, nir_def *tid_in_tg,
3207 nir_def *vtx_lds_addr, nir_def *max_num_out_vtx,
3208 lower_ngg_gs_state *s)
3209 {
3210 nir_def *zero = nir_imm_int(b, 0);
3211
3212 nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
3213 nir_def *primflag = nir_load_shared(b, 1, 8, vtx_lds_addr,
3214 .base = s->lds_offs_primflags + stream);
3215 primflag = nir_u2u32(b, primflag);
3216 nir_pop_if(b, if_outvtx_thread);
3217
3218 return nir_if_phi(b, primflag, zero);
3219 }
3220
3221 static void
3222 ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_def *last_vtxidx, nir_def *last_vtxptr,
3223 nir_def *last_vtx_primflag, lower_ngg_gs_state *s,
3224 nir_def *vtxptr[3])
3225 {
3226 unsigned last_vtx = s->num_vertices_per_primitive - 1;
3227 vtxptr[last_vtx] = last_vtxptr;
3228
3229 bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
3230 nir_def *is_odd = primitive_is_triangle ?
3231 nir_ubfe_imm(b, last_vtx_primflag, 1, 1) : NULL;
3232
3233 for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
3234 nir_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));
3235
3236 /* Need to swap vertex 0 and vertex 1 when vertex 2 index is odd to keep
3237 * CW/CCW order for correct front/back face culling.
3238 */
3239 if (primitive_is_triangle)
3240 vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);
3241
3242 vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
3243 }
3244 }
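/* Example (illustrative): for a triangle whose last vertex index is 7 and whose
 * parity bit is set (is_odd == 1), the computed vertex indices are {6, 5, 7}
 * instead of {5, 6, 7}, i.e. vertices 0 and 1 are swapped as described above.
 */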
3245
3246 static nir_def *
3247 ngg_gs_cull_primitive(nir_builder *b, nir_def *tid_in_tg, nir_def *max_vtxcnt,
3248 nir_def *out_vtx_lds_addr, nir_def *out_vtx_primflag_0,
3249 lower_ngg_gs_state *s)
3250 {
3251 /* Point culling is not enabled; if it were, this function could be further optimized. */
3252 assert(s->num_vertices_per_primitive > 1);
3253
3254 /* save the primflag so that we don't need to load it from LDS again */
3255 nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
3256 nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);
3257
3258 /* Bit 0 of the primflag indicates whether this is the final vertex of a primitive. */
3259 nir_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
3260 nir_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
3261 nir_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);
3262
3263 nir_if *if_prim_enable = nir_push_if(b, prim_enable);
3264 {
3265 /* Calculate the LDS address of every vertex in the current primitive. */
3266 nir_def *vtxptr[3];
3267 ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);
3268
3269 /* Load the positions from LDS. */
3270 nir_def *pos[3][4];
3271 for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
3272 /* VARYING_SLOT_POS == 0, so its packed location is 0 and adds nothing to the base offset. */
3273 pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
3274 nir_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
3275 pos[i][0] = nir_channel(b, xy, 0);
3276 pos[i][1] = nir_channel(b, xy, 1);
3277
3278 pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
3279 pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
3280 }
3281
3282 /* TODO: support clipdist culling in GS */
3283 nir_def *accepted_by_clipdist = nir_imm_true(b);
3284
3285 nir_def *accepted = ac_nir_cull_primitive(
3286 b, accepted_by_clipdist, pos, s->num_vertices_per_primitive, NULL, NULL);
3287
3288 nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
3289 {
3290 /* clear the primflag if rejected */
3291 nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
3292 .base = s->lds_offs_primflags);
3293
3294 nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
3295 }
3296 nir_pop_if(b, if_rejected);
3297 }
3298 nir_pop_if(b, if_prim_enable);
3299
3300 /* Wait for the LDS primflag accesses to finish. */
3301 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
3302 .memory_scope = SCOPE_WORKGROUP,
3303 .memory_semantics = NIR_MEMORY_ACQ_REL,
3304 .memory_modes = nir_var_mem_shared);
3305
3306 /* Only dead vertices need a chance to be revived. */
3307 nir_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
3308 nir_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
3309 nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
3310 {
3311 /* Check the succeeding vertices' primflags to determine whether this vertex is live. */
3312 for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
3313 nir_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
3314 nir_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
3315 nir_if *if_not_overflow = nir_push_if(b, not_overflow);
3316 {
3317 nir_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
3318 nir_def *vtx_primflag =
3319 nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
3320 vtx_primflag = nir_u2u32(b, vtx_primflag);
3321
3322 /* If a succeeding vertex is a live end-of-primitive vertex, the current
3323 * thread's vertex must be marked live by setting its liveness flag (bit 2).
3324 */
3325 nir_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
3326 nir_def *vtx_live_flag =
3327 nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));
3328
3329 /* update this vertex's primflag */
3330 nir_def *primflag = nir_load_var(b, primflag_var);
3331 primflag = nir_ior(b, primflag, vtx_live_flag);
3332 nir_store_var(b, primflag_var, primflag, 1);
3333 }
3334 nir_pop_if(b, if_not_overflow);
3335 }
3336 }
3337 nir_pop_if(b, if_update_primflag);
3338
3339 return nir_load_var(b, primflag_var);
3340 }
3341
3342 static void
3343 ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s)
3344 {
3345 nir_xfb_info *info = ac_nir_get_sorted_xfb_info(b->shader);
3346
3347 nir_def *tid_in_tg = nir_load_local_invocation_index(b);
3348 nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
3349 nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
3350 nir_def *prim_live[4] = {0};
3351 nir_def *gen_prim[4] = {0};
3352 nir_def *export_seq[4] = {0};
3353 nir_def *out_vtx_primflag[4] = {0};
3354 for (unsigned stream = 0; stream < 4; stream++) {
3355 if (!(info->streams_written & BITFIELD_BIT(stream)))
3356 continue;
3357
3358 out_vtx_primflag[stream] =
3359 ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
3360
3361 /* Check bit 0 of the primflag: it marks a live primitive and is set on the
3362 * last vertex of every primitive.
3363 */
3364 prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));
3365
3366 unsigned scratch_stride = ALIGN(s->max_num_waves, 4);
3367 nir_def *scratch_base =
3368 nir_iadd_imm(b, s->lds_addr_gs_scratch, stream * scratch_stride);
3369
3370 /* We want to export primitives to the streamout buffer in sequence,
3371 * but not every vertex is alive or marks the end of a primitive, so
3372 * there are "holes". Unlike the final vertex export, we don't need
3373 * contiguous invocations to write primitives to the streamout buffer,
3374 * so repacking to obtain the sequence (export_seq) is enough; no
3375 * compaction is needed.
3376 *
3377 * Use separate scratch space for each stream to avoid a barrier.
3378 * TODO: we could further reduce barriers by writing all streams'
3379 * LDS at once, then we would only need one barrier instead of one
3380 * per stream.
3381 */
3382 wg_repack_result rep = {0};
3383 repack_invocations_in_workgroup(b, &prim_live[stream], &rep, 1, scratch_base,
3384 s->max_num_waves, s->options->wave_size);
3385
3386 /* nir_intrinsic_set_vertex_and_primitive_count also provides the per-wave primitive
3387 * count, but we would still need LDS to sum all waves' counts into the workgroup
3388 * count. Since we need the repack to export primitives to the streamout buffer anyway,
3389 * we do it here.
3390 */
3390 gen_prim[stream] = rep.num_repacked_invocations;
3391 export_seq[stream] = rep.repacked_invocation_index;
3392 }
3393
3394 /* Workgroup barrier: wait for the LDS scratch reads to finish. */
3395 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
3396 .memory_scope = SCOPE_WORKGROUP,
3397 .memory_semantics = NIR_MEMORY_ACQ_REL,
3398 .memory_modes = nir_var_mem_shared);
3399
3400 /* Get global buffer offset where this workgroup will stream out data to. */
3401 nir_def *emit_prim[4] = {0};
3402 nir_def *buffer_offsets[4] = {0};
3403 nir_def *so_buffer[4] = {0};
3404 ngg_build_streamout_buffer_info(b, info, s->options->gfx_level, s->options->has_xfb_prim_query,
3405 s->options->use_gfx12_xfb_intrinsic, s->lds_addr_gs_scratch, tid_in_tg,
3406 gen_prim, so_buffer, buffer_offsets, emit_prim);
3407
3408 for (unsigned stream = 0; stream < 4; stream++) {
3409 if (!(info->streams_written & BITFIELD_BIT(stream)))
3410 continue;
3411
3412 nir_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
3413 nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
3414 {
3415 /* Get streamout buffer vertex index for the first vertex of this primitive. */
3416 nir_def *first_vertex_idx =
3417 nir_imul_imm(b, export_seq[stream], s->num_vertices_per_primitive);
3418 nir_def *stream_buffer_offsets[NIR_MAX_XFB_BUFFERS];
3419
3420 u_foreach_bit(buffer, info->buffers_written) {
3421 stream_buffer_offsets[buffer] = nir_iadd(b, buffer_offsets[buffer],
3422 nir_imul_imm(b, first_vertex_idx,
3423 info->buffers[buffer].stride));
3424 }
3425
3426 /* Get the LDS addresses of all vertices of this primitive. */
3427 nir_def *exported_vtx_lds_addr[3];
3428 ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
3429 out_vtx_primflag[stream], s,
3430 exported_vtx_lds_addr);
3431
3432 /* Write all vertices of this primitive to streamout buffer. */
3433 for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
3434 ngg_build_streamout_vertex(b, info, stream, so_buffer,
3435 stream_buffer_offsets, i,
3436 exported_vtx_lds_addr[i],
3437 &s->out, false);
3438 }
3439 }
3440 nir_pop_if(b, if_emit);
3441 }
3442 }
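/* Scratch layout sketch (illustrative numbers): with wave_size == 64 and a
 * 256-thread workgroup, max_num_waves == 4 and scratch_stride == ALIGN(4, 4)
 * == 4 bytes, so the per-stream repack scratch for streams 0..3 starts at
 * byte offsets 0, 4, 8 and 12 from lds_addr_gs_scratch.
 */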
3443
3444 static void
3445 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
3446 {
3447 nir_def *tid_in_tg = nir_load_local_invocation_index(b);
3448 nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
3449 nir_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
3450 nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
3451
3452 if (s->output_compile_time_known) {
3453 /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
3454 * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
3455 */
3456 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
3457 alloc_vertices_and_primitives(b, max_vtxcnt, max_prmcnt);
3458 nir_pop_if(b, if_wave_0);
3459 }
3460
3461 /* Workgroup barrier already emitted, we can assume all GS output stores are done by now. */
3462
3463 nir_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
3464
3465 if (s->output_compile_time_known) {
3466 ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
3467 ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
3468 return;
3469 }
3470
3471 /* cull primitives */
3472 if (s->options->can_cull) {
3473 nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
3474
3475 /* culling code will update the primflag */
3476 nir_def *updated_primflag =
3477 ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
3478 out_vtx_primflag_0, s);
3479
3480 nir_pop_if(b, if_cull_en);
3481
3482 out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
3483 }
3484
3485 /* When the output vertex count is not known at compile time:
3486 * There may be gaps between invocations that have live vertices, but NGG hardware
3487 * requires that the invocations that export vertices are packed (i.e. compact).
3488 * To ensure this, we need to repack invocations that have a live vertex.
3489 */
3490 nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0);
3491 wg_repack_result rep = {0};
3492
3493 repack_invocations_in_workgroup(b, &vertex_live, &rep, 1, s->lds_addr_gs_scratch,
3494 s->max_num_waves, s->options->wave_size);
3495
3496 nir_def *workgroup_num_vertices = rep.num_repacked_invocations;
3497 nir_def *exporter_tid_in_tg = rep.repacked_invocation_index;
3498
3499 /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
3500 nir_def *any_output = nir_ine_imm(b, workgroup_num_vertices, 0);
3501 max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));
3502
3503 /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
3504 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
3505 {
3506 if (s->options->gfx_level == GFX10)
3507 alloc_vertices_and_primitives_gfx10_workaround(b, workgroup_num_vertices, max_prmcnt);
3508 else
3509 alloc_vertices_and_primitives(b, workgroup_num_vertices, max_prmcnt);
3510 }
3511 nir_pop_if(b, if_wave_0);
3512
3513 /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
3514 ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);
3515
3516 /* Workgroup barrier: wait for all LDS stores to finish. */
3517 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3518 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3519
3520 ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
3521 ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
3522 }
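/* Repack example (illustrative): if vertex_live per thread is {1, 0, 1, 1, 0},
 * the repack yields num_repacked_invocations == 3 and repacked_invocation_index
 * == 0, 1, 2 for the three live threads. Threads 0..2 then become the vertex
 * exporters, and the compaction step above tells each of them which original
 * vertex to read back from LDS.
 */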
3523
3524 void
3525 ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
3526 {
3527 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
3528 assert(impl);
3529
3530 lower_ngg_gs_state state = {
3531 .options = options,
3532 .impl = impl,
3533 .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
3534 .lds_offs_primflags = options->gs_out_vtx_bytes,
3535 .lds_bytes_per_gs_out_vertex = options->gs_out_vtx_bytes + 4u,
3536 .streamout_enabled = shader->xfb_info && !options->disable_streamout,
3537 };
3538
3539 if (!options->can_cull) {
3540 nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
3541 state.const_out_prmcnt, NULL, 4u);
3542 state.output_compile_time_known =
3543 state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
3544 state.const_out_prmcnt[0] != -1;
3545 }
3546
3547 if (shader->info.gs.output_primitive == MESA_PRIM_POINTS)
3548 state.num_vertices_per_primitive = 1;
3549 else if (shader->info.gs.output_primitive == MESA_PRIM_LINE_STRIP)
3550 state.num_vertices_per_primitive = 2;
3551 else if (shader->info.gs.output_primitive == MESA_PRIM_TRIANGLE_STRIP)
3552 state.num_vertices_per_primitive = 3;
3553 else
3554 unreachable("Invalid GS output primitive.");
3555
3556 /* Extract the full control flow. It is going to be wrapped in an if statement. */
3557 nir_cf_list extracted;
3558 nir_cf_extract(&extracted, nir_before_impl(impl),
3559 nir_after_impl(impl));
3560
3561 nir_builder builder = nir_builder_at(nir_before_impl(impl));
3562 nir_builder *b = &builder; /* This is to avoid the & */
3563
3564 /* Workgroup barrier: wait for ES threads */
3565 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3566 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3567
3568 state.lds_addr_gs_out_vtx = nir_load_lds_ngg_gs_out_vertex_base_amd(b);
3569 state.lds_addr_gs_scratch = nir_load_lds_ngg_scratch_base_amd(b);
3570
3571 /* Wrap the GS control flow. */
3572 nir_if *if_gs_thread = nir_push_if(b, has_input_primitive(b));
3573
3574 nir_cf_reinsert(&extracted, b->cursor);
3575 b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
3576 nir_pop_if(b, if_gs_thread);
3577
3578 /* Workgroup barrier: wait for all GS threads to finish */
3579 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3580 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3581
3582 if (state.streamout_enabled)
3583 ngg_gs_build_streamout(b, &state);
3584
3585 /* Lower the GS intrinsics */
3586 lower_ngg_gs_intrinsics(shader, &state);
3587
3588 if (!state.vertex_count[0]) {
3589 fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.");
3590 abort();
3591 }
3592
3593 /* Emit shader queries */
3594 b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
3595 ac_nir_gs_shader_query(b,
3596 state.options->has_gen_prim_query,
3597 state.options->has_gs_invocations_query,
3598 state.options->has_gs_primitives_query,
3599 state.num_vertices_per_primitive,
3600 state.options->wave_size,
3601 state.vertex_count,
3602 state.primitive_count);
3603
3604 b->cursor = nir_after_impl(impl);
3605
3606 /* Emit the finale sequence */
3607 ngg_gs_finale(b, &state);
3608 nir_validate_shader(shader, "after emitting NGG GS");
3609
3610 /* Cleanup */
3611 nir_lower_vars_to_ssa(shader);
3612 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
3613 nir_metadata_preserve(impl, nir_metadata_none);
3614 }
3615
3616 unsigned
3617 ac_ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
3618 unsigned shader_num_outputs,
3619 bool streamout_enabled,
3620 bool export_prim_id,
3621 bool has_user_edgeflags,
3622 bool can_cull,
3623 bool uses_instance_id,
3624 bool uses_primitive_id)
3625 {
3626 /* LDS layout used only during culling */
3627 unsigned culling_pervertex_lds_bytes = can_cull ?
3628 ngg_nogs_get_culling_pervertex_lds_size(
3629 stage, uses_instance_id, uses_primitive_id, NULL) : 0;
3630
3631 unsigned pervertex_lds_bytes =
3632 ngg_nogs_get_pervertex_lds_size(stage, shader_num_outputs, streamout_enabled,
3633 export_prim_id, has_user_edgeflags);
3634
3635 return MAX2(culling_pervertex_lds_bytes, pervertex_lds_bytes);
3636 }
3637
3638 unsigned
3639 ac_ngg_get_scratch_lds_size(gl_shader_stage stage,
3640 unsigned workgroup_size,
3641 unsigned wave_size,
3642 bool streamout_enabled,
3643 bool can_cull,
3644 bool compact_primitives)
3645 {
3646 unsigned scratch_lds_size = 0;
3647 unsigned max_num_waves = DIV_ROUND_UP(workgroup_size, wave_size);
3648
3649 if (stage == MESA_SHADER_VERTEX || stage == MESA_SHADER_TESS_EVAL) {
3650 if (streamout_enabled) {
3651 /* 4 dwords for the 4 streamout buffer offsets, 1 dword for the emitted primitive count */
3652 scratch_lds_size = 20;
3653 } else if (can_cull) {
3654 /* 1 byte per wave per repack, max 8 waves */
3655 unsigned num_rep = compact_primitives ? 2 : 1;
3656 scratch_lds_size = ALIGN(max_num_waves, 4u) * num_rep;
3657 }
3658 } else {
3659 assert(stage == MESA_SHADER_GEOMETRY);
3660
3661 scratch_lds_size = ALIGN(max_num_waves, 4u);
3662 /* Streamout takes 8 dwords for the buffer offsets and the per-stream emit counts. */
3663 if (streamout_enabled)
3664 scratch_lds_size = MAX2(scratch_lds_size, 32);
3665 }
3666
3667 return scratch_lds_size;
3668 }
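/* Example results (illustrative inputs): a GS with streamout enabled needs
 * MAX2(ALIGN(max_num_waves, 4), 32) bytes. A culling VS/TES with primitive
 * compaction, a 256-thread workgroup and wave_size == 32 has
 * max_num_waves == 8, so it needs ALIGN(8, 4) * 2 == 16 bytes of scratch.
 */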
3669