1 /*
2 * Copyright © 2021 Valve Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 */
24
25 #include "ac_nir.h"
26 #include "nir_builder.h"
27 #include "u_math.h"
28 #include "u_vector.h"
29
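/* Flag values stored in nir_instr::pass_flags by the culling analysis.
 * They mark whether an instruction's result contributes to the position output,
 * to other outputs, or to both (see analyze_shader_before_culling).
 */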
30 enum {
31 nggc_passflag_used_by_pos = 1,
32 nggc_passflag_used_by_other = 2,
33 nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
34 };
35
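/* Pairs a reusable (uniform) SSA value with the local variable that caches it,
 * so the bottom shader part can reuse the value computed by the top part
 * (see save_reusable_variables).
 */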
36 typedef struct
37 {
38 nir_ssa_def *ssa;
39 nir_variable *var;
40 } saved_uniform;
41
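/* State of the lowering for VS and TES without a geometry shader (aka. NGG "no GS"). */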
42 typedef struct
43 {
44 nir_variable *position_value_var;
45 nir_variable *prim_exp_arg_var;
46 nir_variable *es_accepted_var;
47 nir_variable *gs_accepted_var;
48 nir_variable *gs_vtx_indices_vars[3];
49
50 struct u_vector saved_uniforms;
51
52 bool passthrough;
53 bool export_prim_id;
54 bool early_prim_export;
55 bool use_edgeflags;
56 bool has_prim_query;
57 bool can_cull;
58 unsigned wave_size;
59 unsigned max_num_waves;
60 unsigned num_vertices_per_primitives;
61 unsigned provoking_vtx_idx;
62 unsigned max_es_num_vertices;
63 unsigned total_lds_bytes;
64
65 uint64_t inputs_needed_by_pos;
66 uint64_t inputs_needed_by_others;
67 uint32_t instance_rate_inputs;
68
69 nir_instr *compact_arg_stores[4];
70 nir_intrinsic_instr *overwrite_args;
71 } lower_ngg_nogs_state;
72
73 typedef struct
74 {
/* Bitmask of the used components of this output (1 bit per component). */
76 uint8_t components_mask : 4;
77 /* output stream index */
78 uint8_t stream : 2;
79 } gs_output_info;
80
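/* State of the lowering for geometry shaders. */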
81 typedef struct
82 {
83 nir_variable *output_vars[VARYING_SLOT_MAX][4];
84 nir_variable *current_clear_primflag_idx_var;
85 int const_out_vtxcnt[4];
86 int const_out_prmcnt[4];
87 unsigned wave_size;
88 unsigned max_num_waves;
89 unsigned num_vertices_per_primitive;
90 unsigned lds_addr_gs_out_vtx;
91 unsigned lds_addr_gs_scratch;
92 unsigned lds_bytes_per_gs_out_vertex;
93 unsigned lds_offs_primflags;
94 bool found_out_vtxcnt[4];
95 bool output_compile_time_known;
96 bool provoking_vertex_last;
97 gs_output_info output_info[VARYING_SLOT_MAX];
98 } lower_ngg_gs_state;
99
100 /* LDS layout of Mesh Shader workgroup info. */
101 enum {
102 /* DW0: number of primitives */
103 lds_ms_num_prims = 0,
104 /* DW1: reserved for future use */
105 lds_ms_dw1_reserved = 4,
106 /* DW2: workgroup index within the current dispatch */
107 lds_ms_wg_index = 8,
108 /* DW3: number of API workgroups in flight */
109 lds_ms_num_api_waves = 12,
110 };
111
112 /* Potential location for Mesh Shader outputs. */
113 typedef enum {
114 ms_out_mode_lds,
115 ms_out_mode_vram,
116 ms_out_mode_var,
117 } ms_out_mode;
118
119 typedef struct
120 {
121 uint64_t mask; /* Mask of output locations */
122 uint32_t addr; /* Base address */
123 } ms_out_part;
124
125 typedef struct
126 {
127 /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
128 struct {
129 uint32_t workgroup_info_addr;
130 ms_out_part vtx_attr;
131 ms_out_part prm_attr;
132 uint32_t indices_addr;
133 uint32_t total_size;
134 } lds;
135 /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS. */
136 struct {
137 ms_out_part vtx_attr;
138 ms_out_part prm_attr;
139 } vram;
140 /* Outputs without cross-invocation access can be stored in variables. */
141 struct {
142 ms_out_part vtx_attr;
143 ms_out_part prm_attr;
144 } var;
145 } ms_out_mem_layout;
146
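/* State of the lowering for mesh shaders. */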
147 typedef struct
148 {
149 ms_out_mem_layout layout;
150 uint64_t per_vertex_outputs;
151 uint64_t per_primitive_outputs;
152 unsigned vertices_per_prim;
153
154 unsigned wave_size;
155 unsigned api_workgroup_size;
156 unsigned hw_workgroup_size;
157
158 nir_ssa_def *workgroup_index;
159 nir_variable *out_variables[VARYING_SLOT_MAX * 4];
160
161 /* True if the lowering needs to insert the layer output. */
162 bool insert_layer_output;
163
164 struct {
/* Bitmask of the used components of this output (1 bit per component). */
166 uint32_t components_mask;
167 } output_info[VARYING_SLOT_MAX];
168 } lower_ngg_ms_state;
169
170 typedef struct {
171 nir_variable *pre_cull_position_value_var;
172 } remove_culling_shader_outputs_state;
173
174 typedef struct {
175 nir_variable *pos_value_replacement;
176 } remove_extra_position_output_state;
177
178 /* Per-vertex LDS layout of culling shaders */
179 enum {
180 /* Position of the ES vertex (at the beginning for alignment reasons) */
181 lds_es_pos_x = 0,
182 lds_es_pos_y = 4,
183 lds_es_pos_z = 8,
184 lds_es_pos_w = 12,
185
186 /* 1 when the vertex is accepted, 0 if it should be culled */
187 lds_es_vertex_accepted = 16,
188 /* ID of the thread which will export the current thread's vertex */
189 lds_es_exporter_tid = 17,
190
191 /* Repacked arguments - also listed separately for VS and TES */
192 lds_es_arg_0 = 20,
193
194 /* VS arguments which need to be repacked */
195 lds_es_vs_vertex_id = 20,
196 lds_es_vs_instance_id = 24,
197
198 /* TES arguments which need to be repacked */
199 lds_es_tes_u = 20,
200 lds_es_tes_v = 24,
201 lds_es_tes_rel_patch_id = 28,
202 lds_es_tes_patch_id = 32,
203 };
204
205 typedef struct {
206 nir_ssa_def *num_repacked_invocations;
207 nir_ssa_def *repacked_invocation_index;
208 } wg_repack_result;
209
210 /**
211 * Computes a horizontal sum of 8-bit packed values loaded from LDS.
212 *
213 * Each lane N will sum packed bytes 0 to N-1.
214 * We only care about the results from up to wave_id+1 lanes.
215 * (Other lanes are not deactivated but their calculation is not used.)
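 *
 * Example, assuming 3 waves with 5, 7 and 2 surviving invocations (packed as 0x00020705):
 * lane 0 computes 0, lane 1 computes 5, lane 2 computes 12 and lane 3 computes 14,
 * which is the exclusive prefix sum of the per-wave counts.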
216 */
217 static nir_ssa_def *
summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords)
219 {
220 /* We'll use shift to filter out the bytes not needed by the current lane.
221 *
222 * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
223 * However, two shifts are needed because one can't go all the way,
224 * so the shift amount is half that (and in bits).
225 *
226 * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
227 * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
228 * therefore v_dot can get rid of the unneeded values.
229 * This sequence is preferable because it better hides the latency of the LDS.
230 *
231 * If the v_dot instruction can't be used, we left-shift the packed bytes.
232 * This will shift out the unneeded bytes and shift in zeroes instead,
233 * then we sum them using v_sad_u8.
234 */
235
236 nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
237 nir_ssa_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
238 bool use_dot = b->shader->options->has_udot_4x8;
239
240 if (num_lds_dwords == 1) {
241 nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
242
243 /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
244 nir_ssa_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
245
246 /* Horizontally add the packed bytes. */
247 if (use_dot) {
248 return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
249 } else {
250 nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
251 return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
252 }
253 } else if (num_lds_dwords == 2) {
254 nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
255
256 /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
257 nir_ssa_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
258 nir_ssa_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
259
260 /* Horizontally add the packed bytes. */
261 if (use_dot) {
262 nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
263 return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
264 } else {
265 nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
266 nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
267 return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
268 }
269 } else {
270 unreachable("Unimplemented NGG wave count");
271 }
272 }
273
274 /**
275 * Repacks invocations in the current workgroup to eliminate gaps between them.
276 *
277 * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
278 * Assumes that all invocations in the workgroup are active (exec = -1).
279 */
280 static wg_repack_result
repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
                                unsigned lds_addr_base, unsigned max_num_waves,
                                unsigned wave_size)
284 {
285 /* Input boolean: 1 if the current invocation should survive the repack. */
286 assert(input_bool->bit_size == 1);
287
288 /* STEP 1. Count surviving invocations in the current wave.
289 *
290 * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
291 */
292
293 nir_ssa_def *input_mask = nir_ballot(b, 1, wave_size, input_bool);
294 nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
295
296 /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
297 if (max_num_waves == 1) {
298 wg_repack_result r = {
299 .num_repacked_invocations = surviving_invocations_in_current_wave,
300 .repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
301 };
302 return r;
303 }
304
305 /* STEP 2. Waves tell each other their number of surviving invocations.
306 *
307 * Each wave activates only its first lane (exec = 1), which stores the number of surviving
308 * invocations in that wave into the LDS, then reads the numbers from every wave.
309 *
310 * The workgroup size of NGG shaders is at most 256, which means
311 * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
312 * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
313 */
314
315 const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
316 assert(num_lds_dwords <= 2);
317
318 nir_ssa_def *wave_id = nir_load_subgroup_id(b);
319 nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32);
320 nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1));
321
322 nir_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base);
323
324 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
325 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
326
327 nir_ssa_def *packed_counts = nir_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u);
328
329 nir_pop_if(b, if_first_lane);
330
331 packed_counts = nir_if_phi(b, packed_counts, dont_care);
332
333 /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
334 *
335 * By now, every wave knows the number of surviving invocations in all waves.
336 * Each number is 1 byte, and they are packed into up to 2 dwords.
337 *
338 * Each lane N will sum the number of surviving invocations from waves 0 to N-1.
339 * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
340 * (Other lanes are not deactivated but their calculation is not used.)
341 *
342 * - We read the sum from the lane whose id is the current wave's id.
343 * Add the masked bitcount to this, and we get the repacked invocation index.
344 * - We read the sum from the lane whose id is the number of waves in the workgroup.
345 * This is the total number of surviving invocations in the workgroup.
346 */
347
348 nir_ssa_def *num_waves = nir_load_num_subgroups(b);
349 nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
350
351 nir_ssa_def *wg_repacked_index_base = nir_read_invocation(b, sum, wave_id);
352 nir_ssa_def *wg_num_repacked_invocations = nir_read_invocation(b, sum, num_waves);
353 nir_ssa_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
354
355 wg_repack_result r = {
356 .num_repacked_invocations = wg_num_repacked_invocations,
357 .repacked_invocation_index = wg_repacked_index,
358 };
359
360 return r;
361 }
362
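/* Computes the LDS address of the per-vertex data of the given vertex index. */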
363 static nir_ssa_def *
pervertex_lds_addr(nir_builder *b, nir_ssa_def *vertex_idx, unsigned per_vtx_bytes)
365 {
366 return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
367 }
368
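/* Packs the NGG primitive export argument:
 * the vertex indices are placed 10 bits apart, optional edge flags are taken from
 * the initial HW-provided value, and bit 31 marks a NULL primitive.
 */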
369 static nir_ssa_def *
emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
                           nir_ssa_def *vertex_indices[3], nir_ssa_def *is_null_prim,
                           bool use_edgeflags)
373 {
374 nir_ssa_def *arg = use_edgeflags
375 ? nir_load_initial_edgeflags_amd(b)
376 : nir_imm_int(b, 0);
377
378 for (unsigned i = 0; i < num_vertices_per_primitives; ++i) {
379 assert(vertex_indices[i]);
380 arg = nir_ior(b, arg, nir_ishl(b, vertex_indices[i], nir_imm_int(b, 10u * i)));
381 }
382
383 if (is_null_prim) {
384 if (is_null_prim->bit_size == 1)
385 is_null_prim = nir_b2i32(b, is_null_prim);
386 assert(is_null_prim->bit_size == 32);
387 arg = nir_ior(b, arg, nir_ishl(b, is_null_prim, nir_imm_int(b, 31u)));
388 }
389
390 return arg;
391 }
392
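/* Creates the GS vertex index variables and initializes them by unpacking
 * the 16-bit vertex indices from the HW-provided GS vertex offset arguments.
 */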
393 static void
ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *st)
395 {
396 for (unsigned v = 0; v < st->num_vertices_per_primitives; ++v) {
397 st->gs_vtx_indices_vars[v] = nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx_addr");
398
399 nir_ssa_def *vtx = nir_ubfe(b, nir_load_gs_vertex_offset_amd(b, .base = v / 2u),
400 nir_imm_int(b, (v & 1u) * 16u), nir_imm_int(b, 16u));
401 nir_store_var(b, st->gs_vtx_indices_vars[v], vtx, 0x1);
402 }
403 }
404
405 static nir_ssa_def *
emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st)
407 {
408 if (st->passthrough) {
409 assert(!st->export_prim_id || b->shader->info.stage != MESA_SHADER_VERTEX);
410 return nir_load_packed_passthrough_primitive_amd(b);
411 } else {
412 nir_ssa_def *vtx_idx[3] = {0};
413
414 for (unsigned v = 0; v < st->num_vertices_per_primitives; ++v)
415 vtx_idx[v] = nir_load_var(b, st->gs_vtx_indices_vars[v]);
416
417 return emit_pack_ngg_prim_exp_arg(b, st->num_vertices_per_primitives, vtx_idx, NULL, st->use_edgeflags);
418 }
419 }
420
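/* Emits the primitive export on GS threads.
 * When the primitive query is enabled, it also counts the output primitives in GDS
 * for pipeline statistics.
 */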
421 static void
emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def *arg)
423 {
424 nir_ssa_def *gs_thread = st->gs_accepted_var
425 ? nir_load_var(b, st->gs_accepted_var)
426 : nir_has_input_primitive_amd(b);
427
428 nir_if *if_gs_thread = nir_push_if(b, gs_thread);
429 {
430 if (!arg)
431 arg = emit_ngg_nogs_prim_exp_arg(b, st);
432
433 if (st->has_prim_query) {
434 nir_if *if_shader_query = nir_push_if(b, nir_load_shader_query_enabled_amd(b));
435 {
436 /* Number of active GS threads. Each has 1 output primitive. */
437 nir_ssa_def *num_gs_threads = nir_bit_count(b, nir_ballot(b, 1, st->wave_size, nir_imm_bool(b, true)));
438 /* Activate only 1 lane and add the number of primitives to GDS. */
439 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
440 {
441 /* Use a different GDS offset than NGG GS to ensure that pipeline statistics
442 * queries won't return the number of primitives generated by VS/TES.
443 */
444 nir_gds_atomic_add_amd(b, 32, num_gs_threads, nir_imm_int(b, 4), nir_imm_int(b, 0x100));
445 }
446 nir_pop_if(b, if_elected);
447 }
448 nir_pop_if(b, if_shader_query);
449 }
450
451 nir_export_primitive_amd(b, arg);
452 }
453 nir_pop_if(b, if_gs_thread);
454 }
455
456 static void
emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *st)
458 {
459 nir_ssa_def *gs_thread = st->gs_accepted_var ?
460 nir_load_var(b, st->gs_accepted_var) : nir_has_input_primitive_amd(b);
461
462 nir_if *if_gs_thread = nir_push_if(b, gs_thread);
463 {
464 /* Copy Primitive IDs from GS threads to the LDS address
465 * corresponding to the ES thread of the provoking vertex.
466 * It will be exported as a per-vertex attribute.
467 */
468 nir_ssa_def *prim_id = nir_load_primitive_id(b);
469 nir_ssa_def *provoking_vtx_idx = nir_load_var(b, st->gs_vtx_indices_vars[st->provoking_vtx_idx]);
470 nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, 4u);
471
472 nir_store_shared(b, prim_id, addr);
473 }
474 nir_pop_if(b, if_gs_thread);
475 }
476
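/* Emits the store of the primitive ID output:
 * VS loads the value that GS threads stored to LDS, TES uses its own primitive (patch) ID.
 */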
477 static void
emit_store_ngg_nogs_es_primitive_id(nir_builder *b)
479 {
480 nir_ssa_def *prim_id = NULL;
481
482 if (b->shader->info.stage == MESA_SHADER_VERTEX) {
483 /* LDS address where the primitive ID is stored */
484 nir_ssa_def *thread_id_in_threadgroup = nir_load_local_invocation_index(b);
485 nir_ssa_def *addr = pervertex_lds_addr(b, thread_id_in_threadgroup, 4u);
486
487 /* Load primitive ID from LDS */
488 prim_id = nir_load_shared(b, 1, 32, addr);
489 } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
490 /* Just use tess eval primitive ID, which is the same as the patch ID. */
491 prim_id = nir_load_primitive_id(b);
492 }
493
494 nir_io_semantics io_sem = {
495 .location = VARYING_SLOT_PRIMITIVE_ID,
496 .num_slots = 1,
497 };
498
499 nir_store_output(b, prim_id, nir_imm_zero(b, 1, 32),
500 .base = io_sem.location,
501 .src_type = nir_type_uint32, .io_semantics = io_sem);
502 }
503
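/* Callback of remove_culling_shader_outputs:
 * saves the position value to a variable and deletes every output store.
 */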
504 static bool
remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
506 {
507 remove_culling_shader_outputs_state *s = (remove_culling_shader_outputs_state *) state;
508
509 if (instr->type != nir_instr_type_intrinsic)
510 return false;
511
512 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
513
514 /* These are not allowed in VS / TES */
515 assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
516 intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
517
518 /* We are only interested in output stores now */
519 if (intrin->intrinsic != nir_intrinsic_store_output)
520 return false;
521
522 b->cursor = nir_before_instr(instr);
523
524 /* Position output - store the value to a variable, remove output store */
525 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
526 if (io_sem.location == VARYING_SLOT_POS) {
527 /* TODO: check if it's indirect, etc? */
528 unsigned writemask = nir_intrinsic_write_mask(intrin);
529 nir_ssa_def *store_val = intrin->src[0].ssa;
530 nir_store_var(b, s->pre_cull_position_value_var, store_val, writemask);
531 }
532
533 /* Remove all output stores */
534 nir_instr_remove(instr);
535 return true;
536 }
537
538 static void
remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *nogs_state, nir_variable *pre_cull_position_value_var)
540 {
541 remove_culling_shader_outputs_state s = {
542 .pre_cull_position_value_var = pre_cull_position_value_var,
543 };
544
545 nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
546 nir_metadata_block_index | nir_metadata_dominance, &s);
547
548 /* Remove dead code resulting from the deleted outputs. */
549 bool progress;
550 do {
551 progress = false;
552 NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
553 NIR_PASS(progress, culling_shader, nir_opt_dce);
554 NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
555 } while (progress);
556 }
557
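/* Replaces the uses of an SSA value with a load of the given channel of the
 * replacement variable, so that the already computed position value can be reused.
 */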
558 static void
rewrite_uses_to_var(nir_builder *b, nir_ssa_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
560 {
561 if (old_def->parent_instr->type == nir_instr_type_load_const)
562 return;
563
564 b->cursor = nir_after_instr(old_def->parent_instr);
565 if (b->cursor.instr->type == nir_instr_type_phi)
566 b->cursor = nir_after_phis(old_def->parent_instr->block);
567
568 nir_ssa_def *pos_val_rep = nir_load_var(b, replacement_var);
569 nir_ssa_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
570
571 if (old_def->num_components > 1) {
572 /* old_def uses a swizzled vector component.
573 * There is no way to replace the uses of just a single vector component,
574 * so instead create a new vector and replace all uses of the old vector.
575 */
576 nir_ssa_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
577 for (unsigned j = 0; j < old_def->num_components; ++j)
578 old_def_elements[j] = nir_channel(b, old_def, j);
579 replacement = nir_vec(b, old_def_elements, old_def->num_components);
580 }
581
582 nir_ssa_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
583 }
584
585 static bool
remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
587 {
588 remove_extra_position_output_state *s = (remove_extra_position_output_state *) state;
589
590 if (instr->type != nir_instr_type_intrinsic)
591 return false;
592
593 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
594
595 /* These are not allowed in VS / TES */
596 assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
597 intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
598
599 /* We are only interested in output stores now */
600 if (intrin->intrinsic != nir_intrinsic_store_output)
601 return false;
602
603 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
604 if (io_sem.location != VARYING_SLOT_POS)
605 return false;
606
607 b->cursor = nir_before_instr(instr);
608
609 /* In case other outputs use what we calculated for pos,
610 * try to avoid calculating it again by rewriting the usages
611 * of the store components here.
612 */
613 nir_ssa_def *store_val = intrin->src[0].ssa;
614 unsigned store_pos_component = nir_intrinsic_component(intrin);
615
616 nir_instr_remove(instr);
617
618 if (store_val->parent_instr->type == nir_instr_type_alu) {
619 nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
620 if (nir_op_is_vec(alu->op)) {
621 /* Output store uses a vector, we can easily rewrite uses of each vector element. */
622
623 unsigned num_vec_src = 0;
624 if (alu->op == nir_op_mov)
625 num_vec_src = 1;
626 else if (alu->op == nir_op_vec2)
627 num_vec_src = 2;
628 else if (alu->op == nir_op_vec3)
629 num_vec_src = 3;
630 else if (alu->op == nir_op_vec4)
631 num_vec_src = 4;
632 assert(num_vec_src);
633
634 /* Remember the current components whose uses we wish to replace.
635 * This is needed because rewriting one source can affect the others too.
636 */
637 nir_ssa_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
638 for (unsigned i = 0; i < num_vec_src; i++)
639 vec_comps[i] = alu->src[i].src.ssa;
640
641 for (unsigned i = 0; i < num_vec_src; i++)
642 rewrite_uses_to_var(b, vec_comps[i], s->pos_value_replacement, store_pos_component + i);
643 } else {
644 rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
645 }
646 } else {
647 rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
648 }
649
650 return true;
651 }
652
653 static void
remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
655 {
656 remove_extra_position_output_state s = {
657 .pos_value_replacement = nogs_state->position_value_var,
658 };
659
660 nir_shader_instructions_pass(shader, remove_extra_pos_output,
661 nir_metadata_block_index | nir_metadata_dominance, &s);
662 }
663
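/* Removes the vertex compaction code of a repacked argument that turned out to be
 * unused, so that NIR's DCE can clean up the rest.
 */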
664 static bool
remove_compacted_arg(lower_ngg_nogs_state *state, nir_builder *b, unsigned idx)
666 {
667 nir_instr *store_instr = state->compact_arg_stores[idx];
668 if (!store_instr)
669 return false;
670
671 /* Simply remove the store. */
672 nir_instr_remove(store_instr);
673
674 /* Find the intrinsic that overwrites the shader arguments,
675 * and change its corresponding source.
676 * This will cause NIR's DCE to recognize the load and its phis as dead.
677 */
678 b->cursor = nir_before_instr(&state->overwrite_args->instr);
679 nir_ssa_def *undef_arg = nir_ssa_undef(b, 1, 32);
680 nir_ssa_def_rewrite_uses(state->overwrite_args->src[idx].ssa, undef_arg);
681
682 state->compact_arg_stores[idx] = NULL;
683 return true;
684 }
685
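/* Checks which repacked arguments the bottom shader part still uses after DCE,
 * and removes the compaction code of the unused ones.
 */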
686 static bool
cleanup_culling_shader_after_dce(nir_shader *shader,
                                 nir_function_impl *function_impl,
                                 lower_ngg_nogs_state *state)
690 {
691 bool uses_vs_vertex_id = false;
692 bool uses_vs_instance_id = false;
693 bool uses_tes_u = false;
694 bool uses_tes_v = false;
695 bool uses_tes_rel_patch_id = false;
696 bool uses_tes_patch_id = false;
697
698 bool progress = false;
699 nir_builder b;
700 nir_builder_init(&b, function_impl);
701
702 nir_foreach_block_reverse_safe(block, function_impl) {
703 nir_foreach_instr_reverse_safe(instr, block) {
704 if (instr->type != nir_instr_type_intrinsic)
705 continue;
706
707 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
708
709 switch (intrin->intrinsic) {
710 case nir_intrinsic_alloc_vertices_and_primitives_amd:
711 goto cleanup_culling_shader_after_dce_done;
712 case nir_intrinsic_load_vertex_id:
713 case nir_intrinsic_load_vertex_id_zero_base:
714 uses_vs_vertex_id = true;
715 break;
716 case nir_intrinsic_load_instance_id:
717 uses_vs_instance_id = true;
718 break;
719 case nir_intrinsic_load_input:
720 if (state->instance_rate_inputs &
721 (1u << (nir_intrinsic_base(intrin) - VERT_ATTRIB_GENERIC0)))
722 uses_vs_instance_id = true;
723 else
724 uses_vs_vertex_id = true;
725 break;
726 case nir_intrinsic_load_tess_coord:
727 uses_tes_u = uses_tes_v = true;
728 break;
729 case nir_intrinsic_load_tess_rel_patch_id_amd:
730 uses_tes_rel_patch_id = true;
731 break;
732 case nir_intrinsic_load_primitive_id:
733 if (shader->info.stage == MESA_SHADER_TESS_EVAL)
734 uses_tes_patch_id = true;
735 break;
736 default:
737 break;
738 }
739 }
740 }
741
742 cleanup_culling_shader_after_dce_done:
743
744 if (shader->info.stage == MESA_SHADER_VERTEX) {
745 if (!uses_vs_vertex_id)
746 progress |= remove_compacted_arg(state, &b, 0);
747 if (!uses_vs_instance_id)
748 progress |= remove_compacted_arg(state, &b, 1);
749 } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) {
750 if (!uses_tes_u)
751 progress |= remove_compacted_arg(state, &b, 0);
752 if (!uses_tes_v)
753 progress |= remove_compacted_arg(state, &b, 1);
754 if (!uses_tes_rel_patch_id)
755 progress |= remove_compacted_arg(state, &b, 2);
756 if (!uses_tes_patch_id)
757 progress |= remove_compacted_arg(state, &b, 3);
758 }
759
760 return progress;
761 }
762
763 /**
764 * Perform vertex compaction after culling.
765 *
766 * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
767 * 2. Surviving ES vertex invocations store their data to LDS
768 * 3. Emit GS_ALLOC_REQ
769 * 4. Repacked invocations load the vertex data from LDS
770 * 5. GS threads update their vertex indices
771 */
772 static void
compact_vertices_after_culling(nir_builder *b,
                               lower_ngg_nogs_state *nogs_state,
                               nir_variable **repacked_arg_vars,
                               nir_variable **gs_vtxaddr_vars,
                               nir_ssa_def *invocation_index,
                               nir_ssa_def *es_vertex_lds_addr,
                               nir_ssa_def *es_exporter_tid,
                               nir_ssa_def *num_live_vertices_in_workgroup,
                               nir_ssa_def *fully_culled,
                               unsigned ngg_scratch_lds_base_addr,
                               unsigned pervertex_lds_bytes,
                               unsigned max_exported_args)
785 {
786 nir_variable *es_accepted_var = nogs_state->es_accepted_var;
787 nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
788 nir_variable *position_value_var = nogs_state->position_value_var;
789 nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
790
791 nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var));
792 {
793 nir_ssa_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
794
795 /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
796 nir_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid);
797
798 /* Store the current thread's position output to the exporter thread's LDS space */
799 nir_ssa_def *pos = nir_load_var(b, position_value_var);
800 nir_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x);
801
802 /* Store the current thread's repackable arguments to the exporter thread's LDS space */
803 for (unsigned i = 0; i < max_exported_args; ++i) {
804 nir_ssa_def *arg_val = nir_load_var(b, repacked_arg_vars[i]);
805 nir_intrinsic_instr *store = nir_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i);
806
807 nogs_state->compact_arg_stores[i] = &store->instr;
808 }
809 }
810 nir_pop_if(b, if_es_accepted);
811
812 /* TODO: Consider adding a shortcut exit.
813 * Waves that have no vertices and primitives left can s_endpgm right here.
814 */
815
816 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
817 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
818
819 nir_ssa_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
820 nir_if *if_packed_es_thread = nir_push_if(b, es_survived);
821 {
822 /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
823 nir_ssa_def *exported_pos = nir_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x);
824 nir_store_var(b, position_value_var, exported_pos, 0xfu);
825
826 /* Read the repacked arguments */
827 for (unsigned i = 0; i < max_exported_args; ++i) {
828 nir_ssa_def *arg_val = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i);
829 nir_store_var(b, repacked_arg_vars[i], arg_val, 0x1u);
830 }
831 }
832 nir_push_else(b, if_packed_es_thread);
833 {
834 nir_store_var(b, position_value_var, nir_ssa_undef(b, 4, 32), 0xfu);
835 for (unsigned i = 0; i < max_exported_args; ++i)
836 nir_store_var(b, repacked_arg_vars[i], nir_ssa_undef(b, 1, 32), 0x1u);
837 }
838 nir_pop_if(b, if_packed_es_thread);
839
840 nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
841 {
842 nir_ssa_def *exporter_vtx_indices[3] = {0};
843
844 /* Load the index of the ES threads that will export the current GS thread's vertices */
845 for (unsigned v = 0; v < 3; ++v) {
846 nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
847 nir_ssa_def *exporter_vtx_idx = nir_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid);
848 exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
849 nir_store_var(b, nogs_state->gs_vtx_indices_vars[v], exporter_vtx_indices[v], 0x1);
850 }
851
852 nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, 3, exporter_vtx_indices, NULL, nogs_state->use_edgeflags);
853 nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
854 }
855 nir_pop_if(b, if_gs_accepted);
856
857 nir_store_var(b, es_accepted_var, es_survived, 0x1u);
858 nir_store_var(b, gs_accepted_var, nir_bcsel(b, fully_culled, nir_imm_false(b), nir_has_input_primitive_amd(b)), 0x1u);
859 }
860
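/* Recursively walks the sources of an SSA value and marks every visited instruction
 * with the given flag. Also records which shader inputs are needed by the position
 * output and which are needed by the other outputs.
 */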
861 static void
analyze_shader_before_culling_walk(nir_ssa_def *ssa,
                                   uint8_t flag,
                                   lower_ngg_nogs_state *nogs_state)
865 {
866 nir_instr *instr = ssa->parent_instr;
867 uint8_t old_pass_flags = instr->pass_flags;
868 instr->pass_flags |= flag;
869
870 if (instr->pass_flags == old_pass_flags)
871 return; /* Already visited. */
872
873 switch (instr->type) {
874 case nir_instr_type_intrinsic: {
875 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
876
877 /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
878 switch (intrin->intrinsic) {
879 case nir_intrinsic_load_input: {
880 nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
881 uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
882 if (instr->pass_flags & nggc_passflag_used_by_pos)
883 nogs_state->inputs_needed_by_pos |= in_mask;
884 else if (instr->pass_flags & nggc_passflag_used_by_other)
885 nogs_state->inputs_needed_by_others |= in_mask;
886 break;
887 }
888 default:
889 break;
890 }
891
892 break;
893 }
894 case nir_instr_type_alu: {
895 nir_alu_instr *alu = nir_instr_as_alu(instr);
896 unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
897
898 for (unsigned i = 0; i < num_srcs; ++i) {
899 analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, nogs_state);
900 }
901
902 break;
903 }
904 case nir_instr_type_phi: {
905 nir_phi_instr *phi = nir_instr_as_phi(instr);
906 nir_foreach_phi_src_safe(phi_src, phi) {
907 analyze_shader_before_culling_walk(phi_src->src.ssa, flag, nogs_state);
908 }
909
910 break;
911 }
912 default:
913 break;
914 }
915 }
916
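/* Flags every instruction according to whether its result is needed by the
 * position output, by other outputs, or by both.
 */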
917 static void
analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
919 {
920 nir_foreach_function(func, shader) {
921 nir_foreach_block(block, func->impl) {
922 nir_foreach_instr(instr, block) {
923 instr->pass_flags = 0;
924
925 if (instr->type != nir_instr_type_intrinsic)
926 continue;
927
928 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
929 if (intrin->intrinsic != nir_intrinsic_store_output)
930 continue;
931
932 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
933 nir_ssa_def *store_val = intrin->src[0].ssa;
934 uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
935 analyze_shader_before_culling_walk(store_val, flag, nogs_state);
936 }
937 }
938 }
939 }
940
941 /**
942 * Save the reusable SSA definitions to variables so that the
943 * bottom shader part can reuse them from the top part.
944 *
945 * 1. We create a new function temporary variable for reusables,
946 * and insert a store+load.
947 * 2. The shader is cloned (the top part is created), then the
948 * control flow is reinserted (for the bottom part.)
949 * 3. For reusables, we delete the variable stores from the
950 * bottom part. This will make them use the variables from
951 * the top part and DCE the redundant instructions.
952 */
953 static void
save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
955 {
956 ASSERTED int vec_ok = u_vector_init(&nogs_state->saved_uniforms, 4, sizeof(saved_uniform));
957 assert(vec_ok);
958
959 nir_block *block = nir_start_block(b->impl);
960 while (block) {
961 /* Process the instructions in the current block. */
962 nir_foreach_instr_safe(instr, block) {
963 /* Find instructions whose SSA definitions are used by both
964 * the top and bottom parts of the shader (before and after culling).
965 * Only in this case, it makes sense for the bottom part
966 * to try to reuse these from the top part.
967 */
968 if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
969 continue;
970
971 /* Determine if we can reuse the current SSA value.
972 * When vertex compaction is used, it is possible that the same shader invocation
973 * processes a different vertex in the top and bottom part of the shader.
974 * Therefore, we only reuse uniform values.
975 */
976 nir_ssa_def *ssa = NULL;
977 switch (instr->type) {
978 case nir_instr_type_alu: {
979 nir_alu_instr *alu = nir_instr_as_alu(instr);
980 if (alu->dest.dest.ssa.divergent)
981 continue;
982 /* Ignore uniform floats because they regress VGPR usage too much */
983 if (nir_op_infos[alu->op].output_type & nir_type_float)
984 continue;
985 ssa = &alu->dest.dest.ssa;
986 break;
987 }
988 case nir_instr_type_intrinsic: {
989 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
990 if (!nir_intrinsic_can_reorder(intrin) ||
991 !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
992 intrin->dest.ssa.divergent)
993 continue;
994 ssa = &intrin->dest.ssa;
995 break;
996 }
997 case nir_instr_type_phi: {
998 nir_phi_instr *phi = nir_instr_as_phi(instr);
999 if (phi->dest.ssa.divergent)
1000 continue;
1001 ssa = &phi->dest.ssa;
1002 break;
1003 }
1004 default:
1005 continue;
1006 }
1007
1008 assert(ssa);
1009
1010 /* Determine a suitable type for the SSA value. */
1011 enum glsl_base_type base_type = GLSL_TYPE_UINT;
1012 switch (ssa->bit_size) {
1013 case 8: base_type = GLSL_TYPE_UINT8; break;
1014 case 16: base_type = GLSL_TYPE_UINT16; break;
1015 case 32: base_type = GLSL_TYPE_UINT; break;
1016 case 64: base_type = GLSL_TYPE_UINT64; break;
1017 default: continue;
1018 }
1019
1020 const struct glsl_type *t = ssa->num_components == 1
1021 ? glsl_scalar_type(base_type)
1022 : glsl_vector_type(base_type, ssa->num_components);
1023
1024 saved_uniform *saved = (saved_uniform *) u_vector_add(&nogs_state->saved_uniforms);
1025 assert(saved);
1026
1027 /* Create a new NIR variable where we store the reusable value.
1028 * Then, we reload the variable and replace the uses of the value
1029 * with the reloaded variable.
1030 */
1031 saved->var = nir_local_variable_create(b->impl, t, NULL);
1032 saved->ssa = ssa;
1033
1034 b->cursor = instr->type == nir_instr_type_phi
1035 ? nir_after_instr_and_phis(instr)
1036 : nir_after_instr(instr);
1037 nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
1038 nir_ssa_def *reloaded = nir_load_var(b, saved->var);
1039 nir_ssa_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
1040 }
1041
1042 /* Look at the next CF node. */
1043 nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
1044 if (next_cf_node) {
1045 /* It makes no sense to try to reuse things from within loops. */
1046 bool next_is_loop = next_cf_node->type == nir_cf_node_loop;
1047
1048 /* Don't reuse if we're in divergent control flow.
1049 *
1050 * Thanks to vertex repacking, the same shader invocation may process a different vertex
1051 * in the top and bottom part, and it's even possible that this different vertex was initially
1052 * processed in a different wave. So the two parts may take a different divergent code path.
1053 * Therefore, these variables in divergent control flow may stay undefined.
1054 *
1055 * Note that this problem doesn't exist if vertices are not repacked or if the
1056 * workgroup only has a single wave.
1057 */
1058 bool next_is_divergent_if =
1059 next_cf_node->type == nir_cf_node_if &&
1060 nir_cf_node_as_if(next_cf_node)->condition.ssa->divergent;
1061
1062 if (next_is_loop || next_is_divergent_if) {
1063 block = nir_cf_node_cf_tree_next(next_cf_node);
1064 continue;
1065 }
1066 }
1067
1068 /* Go to the next block. */
1069 block = nir_block_cf_tree_next(block);
1070 }
1071 }
1072
1073 /**
1074 * Reuses suitable variables from the top part of the shader,
1075 * by deleting their stores from the bottom part.
1076 */
1077 static void
apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
1079 {
1080 if (!u_vector_length(&nogs_state->saved_uniforms)) {
1081 u_vector_finish(&nogs_state->saved_uniforms);
1082 return;
1083 }
1084
1085 nir_foreach_block_reverse_safe(block, b->impl) {
1086 nir_foreach_instr_reverse_safe(instr, block) {
1087 if (instr->type != nir_instr_type_intrinsic)
1088 continue;
1089 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1090
/* When we find this intrinsic, it means
 * we reached the top part and we must stop.
 */
1094 if (intrin->intrinsic == nir_intrinsic_alloc_vertices_and_primitives_amd)
1095 goto done;
1096
1097 if (intrin->intrinsic != nir_intrinsic_store_deref)
1098 continue;
1099 nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1100 if (deref->deref_type != nir_deref_type_var)
1101 continue;
1102
1103 saved_uniform *saved;
1104 u_vector_foreach(saved, &nogs_state->saved_uniforms) {
1105 if (saved->var == deref->var) {
1106 nir_instr_remove(instr);
1107 }
1108 }
1109 }
1110 }
1111
1112 done:
1113 u_vector_finish(&nogs_state->saved_uniforms);
1114 }
1115
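/* Emits the deferred attribute culling code:
 * clones the shader to create the top (position-only) part, culls triangles based on
 * the positions it computes, compacts the surviving ES invocations, and overwrites
 * the shader arguments so that the bottom part only processes surviving vertices.
 */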
1116 static void
add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state)
1118 {
1119 bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
1120 bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
1121
1122 unsigned max_exported_args = b->shader->info.stage == MESA_SHADER_VERTEX ? 2 : 4;
1123 if (b->shader->info.stage == MESA_SHADER_VERTEX && !uses_instance_id)
1124 max_exported_args--;
1125 else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL && !uses_tess_primitive_id)
1126 max_exported_args--;
1127
1128 unsigned pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4u;
1129 unsigned total_es_lds_bytes = pervertex_lds_bytes * nogs_state->max_es_num_vertices;
1130 unsigned max_num_waves = nogs_state->max_num_waves;
1131 unsigned ngg_scratch_lds_base_addr = ALIGN(total_es_lds_bytes, 8u);
1132 unsigned ngg_scratch_lds_bytes = ALIGN(max_num_waves, 4u);
1133 nogs_state->total_lds_bytes = ngg_scratch_lds_base_addr + ngg_scratch_lds_bytes;
1134
1135 nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
1136
1137 /* Create some helper variables. */
1138 nir_variable *position_value_var = nogs_state->position_value_var;
1139 nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
1140 nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
1141 nir_variable *es_accepted_var = nogs_state->es_accepted_var;
1142 nir_variable *gs_vtxaddr_vars[3] = {
1143 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
1144 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
1145 nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
1146 };
1147 nir_variable *repacked_arg_vars[4] = {
1148 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_0"),
1149 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_1"),
1150 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_2"),
1151 nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"),
1152 };
1153
1154 /* Top part of the culling shader (aka. position shader part)
1155 *
1156 * We clone the full ES shader and emit it here, but we only really care
1157 * about its position output, so we delete every other output from this part.
1158 * The position output is stored into a temporary variable, and reloaded later.
1159 */
1160
1161 b->cursor = nir_before_cf_list(&impl->body);
1162
1163 nir_ssa_def *es_thread = nir_has_input_vertex_amd(b);
1164 nir_if *if_es_thread = nir_push_if(b, es_thread);
1165 {
/* Initialize the position output variable in case not all VS/TES invocations store the output.
 * The spec doesn't require a default value, but we use (0, 0, 0, 1) because some games rely on that.
 */
1169 nir_store_var(b, position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
1170
1171 /* Now reinsert a clone of the shader code */
1172 struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
1173 nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
1174 _mesa_hash_table_destroy(remap_table, NULL);
1175 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1176
1177 /* Remember the current thread's shader arguments */
1178 if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1179 nir_store_var(b, repacked_arg_vars[0], nir_load_vertex_id_zero_base(b), 0x1u);
1180 if (uses_instance_id)
1181 nir_store_var(b, repacked_arg_vars[1], nir_load_instance_id(b), 0x1u);
1182 } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1183 nir_ssa_def *tess_coord = nir_load_tess_coord(b);
1184 nir_store_var(b, repacked_arg_vars[0], nir_channel(b, tess_coord, 0), 0x1u);
1185 nir_store_var(b, repacked_arg_vars[1], nir_channel(b, tess_coord, 1), 0x1u);
1186 nir_store_var(b, repacked_arg_vars[2], nir_load_tess_rel_patch_id_amd(b), 0x1u);
1187 if (uses_tess_primitive_id)
1188 nir_store_var(b, repacked_arg_vars[3], nir_load_primitive_id(b), 0x1u);
1189 } else {
1190 unreachable("Should be VS or TES.");
1191 }
1192 }
1193 nir_pop_if(b, if_es_thread);
1194
1195 nir_store_var(b, es_accepted_var, es_thread, 0x1u);
1196 nir_store_var(b, gs_accepted_var, nir_has_input_primitive_amd(b), 0x1u);
1197
1198 /* Remove all non-position outputs, and put the position output into the variable. */
1199 nir_metadata_preserve(impl, nir_metadata_none);
1200 remove_culling_shader_outputs(b->shader, nogs_state, position_value_var);
1201 b->cursor = nir_after_cf_list(&impl->body);
1202
/* Run the culling algorithms if culling is enabled.
 *
 * NGG culling can be enabled or disabled at runtime.
 * This is determined by an SGPR shader argument which is accessed
 * by the following NIR intrinsic.
 */
1209
1210 nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
1211 {
1212 nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
1213 nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
1214
1215 /* ES invocations store their vertex data to LDS for GS threads to read. */
1216 if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b));
1217 {
1218 /* Store position components that are relevant to culling in LDS */
1219 nir_ssa_def *pre_cull_pos = nir_load_var(b, position_value_var);
1220 nir_ssa_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
1221 nir_store_shared(b, pre_cull_w, es_vertex_lds_addr, .base = lds_es_pos_w);
1222 nir_ssa_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
1223 nir_ssa_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
1224 nir_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .base = lds_es_pos_x);
1225
1226 /* Clear out the ES accepted flag in LDS */
1227 nir_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .align_mul = 4, .base = lds_es_vertex_accepted);
1228 }
1229 nir_pop_if(b, if_es_thread);
1230
1231 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1232 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1233
1234 nir_store_var(b, gs_accepted_var, nir_imm_bool(b, false), 0x1u);
1235 nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
1236
1237 /* GS invocations load the vertex data and perform the culling. */
1238 nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b));
1239 {
1240 /* Load vertex indices from input VGPRs */
1241 nir_ssa_def *vtx_idx[3] = {0};
1242 for (unsigned vertex = 0; vertex < 3; ++vertex)
1243 vtx_idx[vertex] = nir_load_var(b, nogs_state->gs_vtx_indices_vars[vertex]);
1244
1245 nir_ssa_def *vtx_addr[3] = {0};
1246 nir_ssa_def *pos[3][4] = {0};
1247
1248 /* Load W positions of vertices first because the culling code will use these first */
1249 for (unsigned vtx = 0; vtx < 3; ++vtx) {
1250 vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1251 pos[vtx][3] = nir_load_shared(b, 1, 32, vtx_addr[vtx], .base = lds_es_pos_w);
1252 nir_store_var(b, gs_vtxaddr_vars[vtx], vtx_addr[vtx], 0x1u);
1253 }
1254
1255 /* Load the X/W, Y/W positions of vertices */
1256 for (unsigned vtx = 0; vtx < 3; ++vtx) {
1257 nir_ssa_def *xy = nir_load_shared(b, 2, 32, vtx_addr[vtx], .base = lds_es_pos_x);
1258 pos[vtx][0] = nir_channel(b, xy, 0);
1259 pos[vtx][1] = nir_channel(b, xy, 1);
1260 }
1261
1262 /* See if the current primitive is accepted */
1263 nir_ssa_def *accepted = ac_nir_cull_triangle(b, nir_imm_bool(b, true), pos);
1264 nir_store_var(b, gs_accepted_var, accepted, 0x1u);
1265
1266 nir_if *if_gs_accepted = nir_push_if(b, accepted);
1267 {
1268 /* Store the accepted state to LDS for ES threads */
1269 for (unsigned vtx = 0; vtx < 3; ++vtx)
1270 nir_store_shared(b, nir_imm_intN_t(b, 0xff, 8), vtx_addr[vtx], .base = lds_es_vertex_accepted, .align_mul = 4u);
1271 }
1272 nir_pop_if(b, if_gs_accepted);
1273 }
1274 nir_pop_if(b, if_gs_thread);
1275
1276 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1277 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1278
1279 nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);
1280
1281 /* ES invocations load their accepted flag from LDS. */
1282 if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b));
1283 {
1284 nir_ssa_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1285 nir_ssa_def *accepted_bool = nir_ine(b, accepted, nir_imm_intN_t(b, 0, 8));
1286 nir_store_var(b, es_accepted_var, accepted_bool, 0x1u);
1287 }
1288 nir_pop_if(b, if_es_thread);
1289
1290 nir_ssa_def *es_accepted = nir_load_var(b, es_accepted_var);
1291
1292 /* Repack the vertices that survived the culling. */
1293 wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr,
1294 nogs_state->max_num_waves, nogs_state->wave_size);
1295 nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
1296 nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
1297
1298 /* If all vertices are culled, set primitive count to 0 as well. */
1299 nir_ssa_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b);
1300 nir_ssa_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
1301 num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), num_exported_prims);
1302
1303 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0)));
1304 {
1305 /* Tell the final vertex and primitive count to the HW. */
1306 nir_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims);
1307 }
1308 nir_pop_if(b, if_wave_0);
1309
1310 /* Vertex compaction. */
1311 compact_vertices_after_culling(b, nogs_state,
1312 repacked_arg_vars, gs_vtxaddr_vars,
1313 invocation_index, es_vertex_lds_addr,
1314 es_exporter_tid, num_live_vertices_in_workgroup, fully_culled,
1315 ngg_scratch_lds_base_addr, pervertex_lds_bytes, max_exported_args);
1316 }
1317 nir_push_else(b, if_cull_en);
1318 {
1319 /* When culling is disabled, we do the same as we would without culling. */
1320 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0)));
1321 {
1322 nir_ssa_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
1323 nir_ssa_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
1324 nir_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1325 }
1326 nir_pop_if(b, if_wave_0);
1327 nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u);
1328 }
1329 nir_pop_if(b, if_cull_en);
1330
1331 /* Update shader arguments.
1332 *
1333 * The registers which hold information about the subgroup's
1334 * vertices and primitives are updated here, so the rest of the shader
1335 * doesn't need to worry about the culling.
1336 *
1337 * These "overwrite" intrinsics must be at top level control flow,
1338 * otherwise they can mess up the backend (eg. ACO's SSA).
1339 *
1340 * TODO:
1341 * A cleaner solution would be to simply replace all usages of these args
1342 * with the load of the variables.
1343 * However, this wouldn't work right now because the backend uses the arguments
1344 * for purposes not expressed in NIR, eg. VS input loads, etc.
1345 * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
1346 */
1347
1348 if (b->shader->info.stage == MESA_SHADER_VERTEX)
1349 nogs_state->overwrite_args =
1350 nir_overwrite_vs_arguments_amd(b,
1351 nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]));
1352 else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1353 nogs_state->overwrite_args =
1354 nir_overwrite_tes_arguments_amd(b,
1355 nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]),
1356 nir_load_var(b, repacked_arg_vars[2]), nir_load_var(b, repacked_arg_vars[3]));
1357 else
1358 unreachable("Should be VS or TES.");
1359 }
1360
1361 void
ac_nir_lower_ngg_nogs(nir_shader *shader,
                      enum radeon_family family,
                      unsigned max_num_es_vertices,
                      unsigned num_vertices_per_primitives,
                      unsigned max_workgroup_size,
                      unsigned wave_size,
                      bool can_cull,
                      bool early_prim_export,
                      bool passthrough,
                      bool export_prim_id,
                      bool provoking_vtx_last,
                      bool use_edgeflags,
                      bool has_prim_query,
                      uint32_t instance_rate_inputs)
1376 {
1377 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1378 assert(impl);
1379 assert(max_num_es_vertices && max_workgroup_size && wave_size);
1380 assert(!(can_cull && passthrough));
1381
1382 nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
1383 nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
1384 nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
1385 nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
1386
1387 lower_ngg_nogs_state state = {
1388 .passthrough = passthrough,
1389 .export_prim_id = export_prim_id,
1390 .early_prim_export = early_prim_export,
1391 .use_edgeflags = use_edgeflags,
1392 .has_prim_query = has_prim_query,
1393 .can_cull = can_cull,
1394 .num_vertices_per_primitives = num_vertices_per_primitives,
1395 .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
1396 .position_value_var = position_value_var,
1397 .prim_exp_arg_var = prim_exp_arg_var,
1398 .es_accepted_var = es_accepted_var,
1399 .gs_accepted_var = gs_accepted_var,
1400 .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
1401 .max_es_num_vertices = max_num_es_vertices,
1402 .wave_size = wave_size,
1403 .instance_rate_inputs = instance_rate_inputs,
1404 };
1405
1406 const bool need_prim_id_store_shared =
1407 export_prim_id && shader->info.stage == MESA_SHADER_VERTEX;
1408
1409 if (export_prim_id) {
1410 nir_variable *prim_id_var = nir_variable_create(shader, nir_var_shader_out, glsl_uint_type(), "ngg_prim_id");
1411 prim_id_var->data.location = VARYING_SLOT_PRIMITIVE_ID;
1412 prim_id_var->data.driver_location = VARYING_SLOT_PRIMITIVE_ID;
1413 prim_id_var->data.interpolation = INTERP_MODE_NONE;
1414 shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
1415 }
1416
1417 nir_builder builder;
1418 nir_builder *b = &builder; /* Shorthand to avoid writing &builder everywhere. */
1419 nir_builder_init(b, impl);
1420
1421 if (can_cull) {
1422 /* We need divergence info for culling shaders. */
1423 nir_divergence_analysis(shader);
1424 analyze_shader_before_culling(shader, &state);
1425 save_reusable_variables(b, &state);
1426 }
1427
1428 nir_cf_list extracted;
1429 nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
1430 b->cursor = nir_before_cf_list(&impl->body);
1431
1432 ngg_nogs_init_vertex_indices_vars(b, impl, &state);
1433
1434 if (!can_cull) {
1435 /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
1436 if (!(passthrough && family >= CHIP_NAVI23)) {
1437 /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
1438 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0)));
1439 {
1440 nir_ssa_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
1441 nir_ssa_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
1442 nir_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1443 }
1444 nir_pop_if(b, if_wave_0);
1445 }
1446
1447 /* Take care of early primitive export, otherwise just pack the primitive export argument */
1448 if (state.early_prim_export)
1449 emit_ngg_nogs_prim_export(b, &state, NULL);
1450 else
1451 nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
1452 } else {
1453 add_deferred_attribute_culling(b, &extracted, &state);
1454 b->cursor = nir_after_cf_list(&impl->body);
1455
1456 if (state.early_prim_export)
1457 emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
1458 }
1459
1460 if (need_prim_id_store_shared) {
1461 /* We need LDS space when VS needs to export the primitive ID. */
1462 state.total_lds_bytes = MAX2(state.total_lds_bytes, max_num_es_vertices * 4u);
1463
1464 /* The LDS space aliases with what is used by culling, so we need a barrier. */
1465 if (can_cull) {
1466 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
1467 .memory_scope = NIR_SCOPE_WORKGROUP,
1468 .memory_semantics = NIR_MEMORY_ACQ_REL,
1469 .memory_modes = nir_var_mem_shared);
1470 }
1471
1472 emit_ngg_nogs_prim_id_store_shared(b, &state);
1473
1474 /* Wait for GS threads to store primitive ID in LDS. */
1475 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, .memory_scope = NIR_SCOPE_WORKGROUP,
1476 .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
1477 }
1478
1479 nir_intrinsic_instr *export_vertex_instr;
1480 nir_ssa_def *es_thread = can_cull ? nir_load_var(b, es_accepted_var) : nir_has_input_vertex_amd(b);
1481
1482 nir_if *if_es_thread = nir_push_if(b, es_thread);
1483 {
1484 /* Run the actual shader */
1485 nir_cf_reinsert(&extracted, b->cursor);
1486 b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1487
1488 if (state.export_prim_id)
1489 emit_store_ngg_nogs_es_primitive_id(b);
1490
1491 /* Export all vertex attributes (including the primitive ID) */
1492 export_vertex_instr = nir_export_vertex_amd(b);
1493 }
1494 nir_pop_if(b, if_es_thread);
1495
1496 /* Take care of late primitive export */
1497 if (!state.early_prim_export) {
1498 emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
1499 }
1500
1501 if (can_cull) {
1502 /* Replace uniforms. */
1503 apply_reusable_variables(b, &state);
1504
1505 /* Remove the redundant position output. */
1506 remove_extra_pos_outputs(shader, &state);
1507
1508 /* After looking at the performance of apps such as Doom Eternal and The Witcher 3,
1509 * it seems best to always put the position export at the end, and
1510 * then let ACO schedule it up (slightly) only when early prim export is used.
1511 */
1512 b->cursor = nir_before_instr(&export_vertex_instr->instr);
1513
1514 nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var);
1515 nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 };
1516 nir_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem);
1517 }
1518
1519 nir_metadata_preserve(impl, nir_metadata_none);
1520 nir_validate_shader(shader, "after emitting NGG VS/TES");
1521
1522 /* Cleanup */
1523 nir_opt_dead_write_vars(shader);
1524 nir_lower_vars_to_ssa(shader);
1525 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1526 nir_lower_alu_to_scalar(shader, NULL, NULL);
1527 nir_lower_phis_to_scalar(shader, true);
1528
1529 if (can_cull) {
1530 /* It's beneficial to redo these opts after splitting the shader. */
1531 nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
1532 nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
1533 }
1534
1535 bool progress;
1536 do {
1537 progress = false;
1538 NIR_PASS(progress, shader, nir_opt_undef);
1539 NIR_PASS(progress, shader, nir_opt_dce);
1540 NIR_PASS(progress, shader, nir_opt_dead_cf);
1541
1542 if (can_cull)
1543 progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
1544 } while (progress);
1545
1546 shader->info.shared_size = state.total_lds_bytes;
1547 }
1548
1549 /**
1550 * Return the address of the LDS storage reserved for the N'th vertex,
1551 * where N is in emit order, meaning:
1552 * - during the finale, N is the invocation_index (within the workgroup)
1553 * - during vertex emit, i.e. while the API GS shader invocation is running,
1554 * N = invocation_index * gs_max_out_vertices + emit_idx
1555 * where emit_idx is the vertex index in the current API GS invocation.
1556 *
1557 * Goals of the LDS memory layout:
1558 * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
1559 * in uniform control flow
1560 * 2. Eliminate bank conflicts on read for export if, additionally, there is no
1561 * culling
1562 * 3. Agnostic to the number of waves (since we don't know it before compiling)
1563 * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
1564 * 5. Avoid wasting memory.
1565 *
1566 * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
1567 * layout, elimination of bank conflicts requires that each vertex occupy an
1568 * odd number of dwords. We use the additional dword to store the output stream
1569 * index as well as a flag to indicate whether this vertex ends a primitive
1570 * for rasterization.
1571 *
1572 * Swizzling is required to satisfy points 1 and 2 simultaneously.
1573 *
1574 * Vertices are stored in export order (gsthread * gs_max_out_vertices + emitidx).
1575 * Indices are swizzled in groups of 32, which ensures point 1 without
1576 * disturbing point 2.
1577 *
1578 * \return an LDS pointer to type {[N x i32], [4 x i8]}
1579 */
1580 static nir_ssa_def *
1581 ngg_gs_out_vertex_addr(nir_builder *b, nir_ssa_def *out_vtx_idx, lower_ngg_gs_state *s)
1582 {
1583 unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
1584
1585 /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
1586 if (write_stride_2exp) {
1587 nir_ssa_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
1588 nir_ssa_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
1589 out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
1590 }
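   /* Worked example of the swizzle above (illustrative only):
    * with gs.vertices_out == 4, write_stride_2exp == 2, so
    *   out_vtx_idx in [ 0, 31] -> row 0 -> XOR'd with 0 (unchanged)
    *   out_vtx_idx in [32, 63] -> row 1 -> XOR'd with 1
    *   out_vtx_idx in [64, 95] -> row 2 -> XOR'd with 2
    * In this example the swizzle value is at most 3, so every index stays
    * within its own group of 32 while the low bits are rotated per group.
    */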
1591
1592 nir_ssa_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
1593 return nir_iadd_imm_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
1594 }
1595
1596 static nir_ssa_def *
1597 ngg_gs_emit_vertex_addr(nir_builder *b, nir_ssa_def *gs_vtx_idx, lower_ngg_gs_state *s)
1598 {
1599 nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
1600 nir_ssa_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
1601 nir_ssa_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
1602
1603 return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
1604 }
1605
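/* Zero the primitive flags of the vertices that the current GS invocation did
 * not emit. The finale reads the primitive flag of every possible output
 * vertex from LDS, so slots that were never written by emit_vertex must not
 * contain stale data (bit 2 of the flag doubles as the vertex liveness bit).
 */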
1606 static void
1607 ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
1608 {
1609 nir_ssa_def *zero_u8 = nir_imm_zero(b, 1, 8);
1610 nir_store_var(b, s->current_clear_primflag_idx_var, num_vertices, 0x1u);
1611
1612 nir_loop *loop = nir_push_loop(b);
1613 {
1614 nir_ssa_def *current_clear_primflag_idx = nir_load_var(b, s->current_clear_primflag_idx_var);
1615 nir_if *if_break = nir_push_if(b, nir_uge(b, current_clear_primflag_idx, nir_imm_int(b, b->shader->info.gs.vertices_out)));
1616 {
1617 nir_jump(b, nir_jump_break);
1618 }
1619 nir_push_else(b, if_break);
1620 {
1621 nir_ssa_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, current_clear_primflag_idx, s);
1622 nir_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream);
1623 nir_store_var(b, s->current_clear_primflag_idx_var, nir_iadd_imm_nuw(b, current_clear_primflag_idx, 1), 0x1u);
1624 }
1625 nir_pop_if(b, if_break);
1626 }
1627 nir_pop_loop(b, loop);
1628 }
1629
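/* Accumulate the number of "real" primitives generated by this wave into the
 * shader query result in GDS.
 *
 * Worked example (illustrative): for triangle strips, a strip with V emitted
 * vertices yields V - 2 triangles, so with P strips and V vertices in total
 * the invocation generated V - P * (num_vertices_per_primitive - 1) triangles,
 * e.g. one strip of 5 vertices gives 5 - 1 * 2 = 3 triangles.
 */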
1630 static void
1631 ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1632 {
1633 nir_if *if_shader_query = nir_push_if(b, nir_load_shader_query_enabled_amd(b));
1634 nir_ssa_def *num_prims_in_wave = NULL;
1635
1636 /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
1637 * GS emits points, line strips or triangle strips.
1638 * Real primitives are points, lines or triangles.
1639 */
1640 if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) {
1641 unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
1642 unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
1643 unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
1644 nir_ssa_def *num_threads = nir_bit_count(b, nir_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
1645 num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
1646 } else {
1647 nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
1648 nir_ssa_def *prm_cnt = intrin->src[1].ssa;
1649 if (s->num_vertices_per_primitive > 1)
1650 prm_cnt = nir_iadd(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
1651 num_prims_in_wave = nir_reduce(b, prm_cnt, .reduction_op = nir_op_iadd);
1652 }
1653
1654 /* Store the query result to GDS using an atomic add. */
1655 nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1));
1656 nir_gds_atomic_add_amd(b, 32, num_prims_in_wave, nir_imm_int(b, 0), nir_imm_int(b, 0x100));
1657 nir_pop_if(b, if_first_lane);
1658
1659 nir_pop_if(b, if_shader_query);
1660 }
1661
1662 static bool
1663 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1664 {
1665 assert(nir_src_is_const(intrin->src[1]));
1666 b->cursor = nir_before_instr(&intrin->instr);
1667
1668 unsigned writemask = nir_intrinsic_write_mask(intrin);
1669 unsigned base = nir_intrinsic_base(intrin);
1670 unsigned component_offset = nir_intrinsic_component(intrin);
1671 unsigned base_offset = nir_src_as_uint(intrin->src[1]);
1672 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
1673
1674 assert((base + base_offset) < VARYING_SLOT_MAX);
1675
1676 nir_ssa_def *store_val = intrin->src[0].ssa;
1677
1678 for (unsigned comp = 0; comp < 4; ++comp) {
1679 if (!(writemask & (1 << comp)))
1680 continue;
1681 unsigned stream = (io_sem.gs_streams >> (comp * 2)) & 0x3;
1682 if (!(b->shader->info.gs.active_stream_mask & (1 << stream)))
1683 continue;
1684
1685 /* Small bitsize components consume the same amount of space as 32-bit components,
1686 * but 64-bit ones consume twice as much. (Vulkan spec 15.1.5)
1687 */
1688 unsigned num_consumed_components = DIV_ROUND_UP(store_val->bit_size, 32);
1689 nir_ssa_def *element = nir_channel(b, store_val, comp);
1690 if (num_consumed_components > 1)
1691 element = nir_extract_bits(b, &element, 1, 0, num_consumed_components, 32);
1692
1693 /* Save output usage info. */
1694 gs_output_info *info = &s->output_info[io_sem.location];
1695 /* The same output should always belong to the same stream. */
1696 assert(!info->components_mask || info->stream == stream);
1697 info->stream = stream;
1698 info->components_mask |= BITFIELD_BIT(component_offset + comp * num_consumed_components);
1699
1700 for (unsigned c = 0; c < num_consumed_components; ++c) {
1701 unsigned component_index = (comp * num_consumed_components) + c + component_offset;
1702 unsigned base_index = base + base_offset + component_index / 4;
1703 component_index %= 4;
1704
1705 /* Store the current component element */
1706 nir_ssa_def *component_element = element;
1707 if (num_consumed_components > 1)
1708 component_element = nir_channel(b, component_element, c);
1709 if (component_element->bit_size != 32)
1710 component_element = nir_u2u32(b, component_element);
1711
1712 nir_store_var(b, s->output_vars[base_index][component_index], component_element, 0x1u);
1713 }
1714 }
1715
1716 nir_instr_remove(&intrin->instr);
1717 return true;
1718 }
1719
1720 static bool
1721 lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1722 {
1723 b->cursor = nir_before_instr(&intrin->instr);
1724
1725 unsigned stream = nir_intrinsic_stream_id(intrin);
1726 if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
1727 nir_instr_remove(&intrin->instr);
1728 return true;
1729 }
1730
1731 nir_ssa_def *gs_emit_vtx_idx = intrin->src[0].ssa;
1732 nir_ssa_def *current_vtx_per_prim = intrin->src[1].ssa;
1733 nir_ssa_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);
1734
1735 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1736 unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
1737 gs_output_info *info = &s->output_info[slot];
1738 if (info->stream != stream || !info->components_mask)
1739 continue;
1740
1741 unsigned mask = info->components_mask;
1742 while (mask) {
1743 int start, count;
1744 u_bit_scan_consecutive_range(&mask, &start, &count);
1745 nir_ssa_def *values[4] = {0};
1746 for (int c = start; c < start + count; ++c) {
1747 /* Load output from variable. */
1748 values[c - start] = nir_load_var(b, s->output_vars[slot][c]);
1749 /* Clear the variable (it is undefined after emit_vertex) */
1750 nir_store_var(b, s->output_vars[slot][c], nir_ssa_undef(b, 1, 32), 0x1);
1751 }
1752
1753 nir_ssa_def *store_val = nir_vec(b, values, (unsigned)count);
1754 nir_store_shared(b, store_val, gs_emit_vtx_addr,
1755 .base = packed_location * 16 + start * 4,
1756 .align_mul = 4);
1757 }
1758 }
1759
1760 /* Calculate and store per-vertex primitive flags based on vertex counts:
1761 * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
1762 * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
1763 * - bit 2: always 1 (so that we can use it for determining vertex liveness)
1764 */
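   /* Worked example (triangle strip): for emitted vertices v0, v1, v2, v3, ...
    * current_vtx_per_prim is 0, 1, 2, 3, ... so v0 and v1 get flag 0b100
    * (alive, no primitive), v2 gets 0b101 (completes triangle 0-1-2, even),
    * v3 gets 0b111 (completes triangle 1-2-3, odd, so the winding must be
    * flipped at export), and so on.
    */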
1765
1766 nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
1767 nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));
1768
1769 if (s->num_vertices_per_primitive == 3) {
1770 nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
1771 prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
1772 }
1773
1774 nir_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 4u);
1775 nir_instr_remove(&intrin->instr);
1776 return true;
1777 }
1778
1779 static bool
1780 lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
1781 {
1782 b->cursor = nir_before_instr(&intrin->instr);
1783
1784 /* These are not needed, we can simply remove them */
1785 nir_instr_remove(&intrin->instr);
1786 return true;
1787 }
1788
1789 static bool
1790 lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1791 {
1792 b->cursor = nir_before_instr(&intrin->instr);
1793
1794 unsigned stream = nir_intrinsic_stream_id(intrin);
1795 if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
1796 nir_instr_remove(&intrin->instr);
1797 return true;
1798 }
1799
1800 s->found_out_vtxcnt[stream] = true;
1801
1802 /* Clear the primitive flags of non-emitted vertices */
1803 if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
1804 ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
1805
1806 ngg_gs_shader_query(b, intrin, s);
1807 nir_instr_remove(&intrin->instr);
1808 return true;
1809 }
1810
1811 static bool
1812 lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
1813 {
1814 lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;
1815
1816 if (instr->type != nir_instr_type_intrinsic)
1817 return false;
1818
1819 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1820
1821 if (intrin->intrinsic == nir_intrinsic_store_output)
1822 return lower_ngg_gs_store_output(b, intrin, s);
1823 else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
1824 return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
1825 else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
1826 return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
1827 else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
1828 return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);
1829
1830 return false;
1831 }
1832
1833 static void
1834 lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
1835 {
1836 nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
1837 }
1838
1839 static void
1840 ngg_gs_export_primitives(nir_builder *b, nir_ssa_def *max_num_out_prims, nir_ssa_def *tid_in_tg,
1841 nir_ssa_def *exporter_tid_in_tg, nir_ssa_def *primflag_0,
1842 lower_ngg_gs_state *s)
1843 {
1844 nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
1845
1846 /* Only bit 0 matters here - set it to 1 when the primitive should be null */
1847 nir_ssa_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));
1848
1849 nir_ssa_def *vtx_indices[3] = {0};
1850 vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
1851 if (s->num_vertices_per_primitive >= 2)
1852 vtx_indices[s->num_vertices_per_primitive - 2] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 1));
1853 if (s->num_vertices_per_primitive == 3)
1854 vtx_indices[s->num_vertices_per_primitive - 3] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 2));
1855
1856 if (s->num_vertices_per_primitive == 3) {
1857 /* API GS outputs triangle strips, but NGG HW understands triangles.
1858 * We already know the triangles due to how we set the primitive flags, but we need to
1859 * make sure the vertex order is correct so that front/back faces and the provoking vertex are preserved.
1860 */
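   /* Illustrative example: the triangle completed by emit index 3 of a strip
    * is "odd" and its vertices are (1, 2, 3). With the provoking vertex first,
    * the indices become (1, 3, 2): the winding is flipped to undo the strip's
    * alternating orientation while vertex 1 remains the provoking vertex.
    * With the provoking vertex last, they become (2, 1, 3) instead.
    */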
1861
1862 nir_ssa_def *is_odd = nir_ubfe(b, primflag_0, nir_imm_int(b, 1), nir_imm_int(b, 1));
1863 if (!s->provoking_vertex_last) {
1864 vtx_indices[1] = nir_iadd(b, vtx_indices[1], is_odd);
1865 vtx_indices[2] = nir_isub(b, vtx_indices[2], is_odd);
1866 } else {
1867 vtx_indices[0] = nir_iadd(b, vtx_indices[0], is_odd);
1868 vtx_indices[1] = nir_isub(b, vtx_indices[1], is_odd);
1869 }
1870 }
1871
1872 nir_ssa_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim, false);
1873 nir_export_primitive_amd(b, arg);
1874 nir_pop_if(b, if_prim_export_thread);
1875 }
1876
1877 static void
1878 ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def *tid_in_tg,
1879 nir_ssa_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
1880 {
1881 nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
1882 nir_ssa_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;
1883
1884 if (!s->output_compile_time_known) {
1885 /* Vertex compaction.
1886 * The current thread will export a vertex that was live in another invocation.
1887 * Load the index of the vertex that the current thread will have to export.
1888 */
1889 nir_ssa_def *exported_vtx_idx = nir_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1);
1890 exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
1891 }
1892
1893 /* Remember proper bit sizes of output variables. */
1894 uint8_t out_bitsizes[VARYING_SLOT_MAX];
1895 memset(out_bitsizes, 32, VARYING_SLOT_MAX);
1896 nir_foreach_shader_out_variable(var, b->shader) {
1897 /* Check 8/16-bit. All others should be lowered to 32-bit already. */
1898 unsigned bit_size = glsl_base_type_bit_size(glsl_get_base_type(glsl_without_array(var->type)));
1899 if (bit_size == 8 || bit_size == 16)
1900 out_bitsizes[var->data.location] = bit_size;
1901 }
1902
1903 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
1904 if (!(b->shader->info.outputs_written & BITFIELD64_BIT(slot)))
1905 continue;
1906
1907 gs_output_info *info = &s->output_info[slot];
1908 if (!info->components_mask || info->stream != 0)
1909 continue;
1910
1911 unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
1912 nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };
1913
1914 unsigned mask = info->components_mask;
1915 while (mask) {
1916 int start, count;
1917 u_bit_scan_consecutive_range(&mask, &start, &count);
1918 nir_ssa_def *load =
1919 nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
1920 .base = packed_location * 16 + start * 4,
1921 .align_mul = 4);
1922
1923 /* Convert to the expected bit size of the output variable. */
1924 if (out_bitsizes[slot] != 32)
1925 load = nir_u2u(b, load, out_bitsizes[slot]);
1926
1927 nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .io_semantics = io_sem,
1928 .component = start, .write_mask = BITFIELD_MASK(count));
1929 }
1930 }
1931
1932 nir_export_vertex_amd(b);
1933 nir_pop_if(b, if_vtx_export_thread);
1934 }
1935
1936 static void
1937 ngg_gs_setup_vertex_compaction(nir_builder *b, nir_ssa_def *vertex_live, nir_ssa_def *tid_in_tg,
1938 nir_ssa_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
1939 {
1940 assert(vertex_live->bit_size == 1);
1941 nir_if *if_vertex_live = nir_push_if(b, vertex_live);
1942 {
1943 /* Setup the vertex compaction.
1944 * Save the current thread's id for the thread which will export the current vertex.
1945 * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this.
1946 */
1947
1948 nir_ssa_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
1949 nir_ssa_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
1950 nir_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1);
1951 }
1952 nir_pop_if(b, if_vertex_live);
1953 }
1954
1955 static nir_ssa_def *
1956 ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *vtx_lds_addr,
1957 nir_ssa_def *max_num_out_vtx, lower_ngg_gs_state *s)
1958 {
1959 nir_ssa_def *zero = nir_imm_int(b, 0);
1960
1961 nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
1962 nir_ssa_def *primflag_0 = nir_load_shared(b, 1, 8, vtx_lds_addr, .base = s->lds_offs_primflags, .align_mul = 4u);
1963 primflag_0 = nir_u2u32(b, primflag_0);
1964 nir_pop_if(b, if_outvtx_thread);
1965
1966 return nir_if_phi(b, primflag_0, zero);
1967 }
1968
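/* The "finale" of the NGG GS: runs after the API GS invocations are done.
 *
 * Rough outline: allocate the export space on wave 0 (either with the
 * compile-time known counts, or after repacking the invocations that produced
 * a live vertex), wait for all GS threads, read back the primitive flags from
 * LDS, set up vertex compaction when the counts are not known at compile
 * time, and finally export the primitives and vertices.
 */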
1969 static void
1970 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
1971 {
1972 nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
1973 nir_ssa_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
1974 nir_ssa_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
1975 nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
1976
1977 if (s->output_compile_time_known) {
1978 /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
1979 * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
1980 */
1981 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
1982 nir_alloc_vertices_and_primitives_amd(b, max_vtxcnt, max_prmcnt);
1983 nir_pop_if(b, if_wave_0);
1984 }
1985
1986 /* Workgroup barrier: wait for all GS threads to finish */
1987 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1988 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1989
1990 nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag_0(b, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
1991
1992 if (s->output_compile_time_known) {
1993 ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
1994 ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
1995 return;
1996 }
1997
1998 /* When the output vertex count is not known at compile time:
1999 * There may be gaps between invocations that have live vertices, but NGG hardware
2000 * requires that the invocations that export vertices are packed (i.e. compacted).
2001 * To ensure this, we need to repack invocations that have a live vertex.
2002 */
2003 nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
2004 wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);
2005
2006 nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
2007 nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;
2008
2009 /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
2010 nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0));
2011 max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));
2012
2013 /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
2014 nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
2015 nir_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt);
2016 nir_pop_if(b, if_wave_0);
2017
2018 /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
2019 ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);
2020
2021 /* Workgroup barrier: wait for all LDS stores to finish. */
2022 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
2023 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
2024
2025 ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
2026 ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
2027 }
2028
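/* Lower a legacy (API) geometry shader so that it can run as an NGG GS.
 *
 * The original GS code is wrapped in an "if (has input primitive)" block and
 * its outputs are redirected to LDS (see lower_ngg_gs_intrinsics); the finale
 * emitted at the end of the shader then compacts the output vertices and
 * performs the actual exports.
 */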
2029 void
2030 ac_nir_lower_ngg_gs(nir_shader *shader,
2031 unsigned wave_size,
2032 unsigned max_workgroup_size,
2033 unsigned esgs_ring_lds_bytes,
2034 unsigned gs_out_vtx_bytes,
2035 unsigned gs_total_out_vtx_bytes,
2036 bool provoking_vertex_last)
2037 {
2038 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
2039 assert(impl);
2040
2041 lower_ngg_gs_state state = {
2042 .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
2043 .wave_size = wave_size,
2044 .lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
2045 .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
2046 .lds_offs_primflags = gs_out_vtx_bytes,
2047 .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
2048 .provoking_vertex_last = provoking_vertex_last,
2049 };
2050
2051 unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
2052 unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
2053 shader->info.shared_size = total_lds_bytes;
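   /* Resulting LDS layout (byte offsets, plus any alignment padding):
    * [0, esgs_ring_lds_bytes)                   ES -> GS ring
    * [lds_addr_gs_out_vtx, ...)                 GS output vertices incl. primflags
    * [lds_addr_gs_scratch, +lds_scratch_bytes)  scratch for invocation repacking
    *                                            (enough for one byte per wave)
    */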
2054
2055 nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
2056 state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
2057 state.const_out_prmcnt[0] != -1;
2058
2059 if (!state.output_compile_time_known)
2060 state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");
2061
2062 if (shader->info.gs.output_primitive == SHADER_PRIM_POINTS)
2063 state.num_vertices_per_primitive = 1;
2064 else if (shader->info.gs.output_primitive == SHADER_PRIM_LINE_STRIP)
2065 state.num_vertices_per_primitive = 2;
2066 else if (shader->info.gs.output_primitive == SHADER_PRIM_TRIANGLE_STRIP)
2067 state.num_vertices_per_primitive = 3;
2068 else
2069 unreachable("Invalid GS output primitive.");
2070
2071 /* Extract the full control flow. It is going to be wrapped in an if statement. */
2072 nir_cf_list extracted;
2073 nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
2074
2075 nir_builder builder;
2076 nir_builder *b = &builder; /* Shorthand to avoid writing &builder everywhere. */
2077 nir_builder_init(b, impl);
2078 b->cursor = nir_before_cf_list(&impl->body);
2079
2080 /* Workgroup barrier: wait for ES threads */
2081 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
2082 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
2083
2084 /* Wrap the GS control flow. */
2085 nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b));
2086
2087 /* Create and initialize output variables */
2088 for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
2089 for (unsigned comp = 0; comp < 4; ++comp) {
2090 state.output_vars[slot][comp] = nir_local_variable_create(impl, glsl_uint_type(), "output");
2091 }
2092 }
2093
2094 nir_cf_reinsert(&extracted, b->cursor);
2095 b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
2096 nir_pop_if(b, if_gs_thread);
2097
2098 /* Lower the GS intrinsics */
2099 lower_ngg_gs_intrinsics(shader, &state);
2100 b->cursor = nir_after_cf_list(&impl->body);
2101
2102 if (!state.found_out_vtxcnt[0]) {
2103 fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.\n");
2104 abort();
2105 }
2106
2107 /* Emit the finale sequence */
2108 ngg_gs_finale(b, &state);
2109 nir_validate_shader(shader, "after emitting NGG GS");
2110
2111 /* Cleanup */
2112 nir_lower_vars_to_ssa(shader);
2113 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
2114 nir_metadata_preserve(impl, nir_metadata_none);
2115 }
2116
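/*
 * Mesh shader (NV_mesh_shader style) lowering.
 *
 * The helpers below lower mesh shader outputs: the primitive count and the
 * primitive indices go to dedicated LDS slots, while the arrayed per-vertex
 * and per-primitive attributes go to LDS, to the VRAM mesh scratch ring, or
 * to plain NIR variables, depending on the precomputed output layout stored
 * in lower_ngg_ms_state (see ms_get_out_layout_part).
 */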
2117 static void
2118 ms_store_prim_indices(nir_builder *b,
2119 nir_ssa_def *val,
2120 nir_ssa_def *offset_src,
2121 lower_ngg_ms_state *s)
2122 {
2123 assert(val->num_components <= 3);
2124
2125 if (!offset_src)
2126 offset_src = nir_imm_int(b, 0);
2127
2128 nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->layout.lds.indices_addr);
2129 }
2130
2131 static nir_ssa_def *
2132 ms_load_prim_indices(nir_builder *b,
2133 nir_ssa_def *offset_src,
2134 lower_ngg_ms_state *s)
2135 {
2136 if (!offset_src)
2137 offset_src = nir_imm_int(b, 0);
2138
2139 return nir_load_shared(b, 1, 8, offset_src, .base = s->layout.lds.indices_addr);
2140 }
2141
2142 static void
2143 ms_store_num_prims(nir_builder *b,
2144 nir_ssa_def *store_val,
2145 lower_ngg_ms_state *s)
2146 {
2147 nir_ssa_def *addr = nir_imm_int(b, 0);
2148 nir_store_shared(b, nir_u2u32(b, store_val), addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
2149 }
2150
2151 static nir_ssa_def *
2152 ms_load_num_prims(nir_builder *b,
2153 lower_ngg_ms_state *s)
2154 {
2155 nir_ssa_def *addr = nir_imm_int(b, 0);
2156 return nir_load_shared(b, 1, 32, addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
2157 }
2158
2159 static nir_ssa_def *
2160 lower_ms_store_output(nir_builder *b,
2161 nir_intrinsic_instr *intrin,
2162 lower_ngg_ms_state *s)
2163 {
2164 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
2165 nir_ssa_def *store_val = intrin->src[0].ssa;
2166
2167 /* Component makes no sense here. */
2168 assert(nir_intrinsic_component(intrin) == 0);
2169
2170 if (io_sem.location == VARYING_SLOT_PRIMITIVE_COUNT) {
2171 /* Total number of primitives output by the mesh shader workgroup.
2172 * This can be read and written by any invocation any number of times.
2173 */
2174
2175 /* Base, offset and component make no sense here. */
2176 assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
2177
2178 ms_store_num_prims(b, store_val, s);
2179 } else if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
2180 /* Contrary to the name, these are not primitive indices, but
2181 * vertex indices for each vertex of the output primitives.
2182 * The Mesh NV API has these stored in a flat array.
2183 */
2184
2185 nir_ssa_def *offset_src = nir_get_io_offset_src(intrin)->ssa;
2186 ms_store_prim_indices(b, store_val, offset_src, s);
2187 } else {
2188 unreachable("Invalid mesh shader output");
2189 }
2190
2191 return NIR_LOWER_INSTR_PROGRESS_REPLACE;
2192 }
2193
2194 static nir_ssa_def *
2195 lower_ms_load_output(nir_builder *b,
2196 nir_intrinsic_instr *intrin,
2197 lower_ngg_ms_state *s)
2198 {
2199 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
2200
2201 /* Component makes no sense here. */
2202 assert(nir_intrinsic_component(intrin) == 0);
2203
2204 if (io_sem.location == VARYING_SLOT_PRIMITIVE_COUNT) {
2205 /* Base, offset and component make no sense here. */
2206 assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
2207
2208 return ms_load_num_prims(b, s);
2209 } else if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
2210 nir_ssa_def *offset_src = nir_get_io_offset_src(intrin)->ssa;
2211 nir_ssa_def *index = ms_load_prim_indices(b, offset_src, s);
2212 return nir_u2u(b, index, intrin->dest.ssa.bit_size);
2213 }
2214
2215 unreachable("Invalid mesh shader output");
2216 }
2217
2218 static nir_ssa_def *
2219 ms_arrayed_output_base_addr(nir_builder *b,
2220 nir_ssa_def *arr_index,
2221 unsigned driver_location,
2222 unsigned num_arrayed_outputs)
2223 {
2224 /* Address offset of the array item (vertex or primitive). */
2225 unsigned arr_index_stride = num_arrayed_outputs * 16u;
2226 nir_ssa_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);
2227
2228 /* IO address offset within the vertex or primitive data. */
2229 unsigned io_offset = driver_location * 16u;
2230 nir_ssa_def *io_off = nir_imm_int(b, io_offset);
2231
2232 return nir_iadd_nuw(b, arr_index_off, io_off);
2233 }
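/* Illustrative example of the addressing above: with 10 arrayed outputs in a
 * region, each array item (vertex or primitive) occupies 10 * 16 = 160 bytes,
 * so the attribute at driver_location 3 of array item 2 starts at
 * 2 * 160 + 3 * 16 = 368 bytes from the start of the region (the region's
 * base address and the component offset are added by the callers).
 */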
2234
2235 static void
2236 update_ms_output_info_slot(lower_ngg_ms_state *s,
2237 unsigned slot, unsigned base_off,
2238 uint32_t components_mask)
2239 {
2240 while (components_mask) {
2241 s->output_info[slot + base_off].components_mask |= components_mask & 0xF;
2242
2243 components_mask >>= 4;
2244 base_off++;
2245 }
2246 }
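/* Note: components_mask above uses 4 bits per slot, so when a store spans
 * more than one slot (e.g. a 64-bit store whose widened write mask has more
 * than 4 bits), the loop spills the extra bits into the following slots.
 */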
2247
2248 static void
2249 update_ms_output_info(nir_intrinsic_instr *intrin,
2250 const ms_out_part *out,
2251 lower_ngg_ms_state *s)
2252 {
2253 nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
2254 nir_src *base_offset_src = nir_get_io_offset_src(intrin);
2255 uint32_t write_mask = nir_intrinsic_write_mask(intrin);
2256 unsigned component_offset = nir_intrinsic_component(intrin);
2257
2258 nir_ssa_def *store_val = intrin->src[0].ssa;
2259 write_mask = util_widen_mask(write_mask, DIV_ROUND_UP(store_val->bit_size, 32));
2260 uint32_t components_mask = write_mask << component_offset;
2261
2262 if (nir_src_is_const(*base_offset_src)) {
2263 /* Simply mark the components of the current slot as used. */
2264 unsigned base_off = nir_src_as_uint(*base_offset_src);
2265 update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
2266 } else {
2267 /* Indirect offset: mark the components of all slots as used. */
2268 for (unsigned base_off = 0; base_off < io_sem.num_slots; ++base_off)
2269 update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
2270 }
2271 }
2272
2273 static nir_ssa_def *
2274 regroup_store_val(nir_builder *b, nir_ssa_def *store_val)
2275 {
2276 /* Vulkan spec 15.1.4-15.1.5:
2277 *
2278 * The shader interface consists of output slots with 4x 32-bit components.
2279 * Small bitsize components consume the same space as 32-bit components,
2280 * but 64-bit ones consume twice as much.
2281 *
2282 * The same output slot may consist of components of different bit sizes.
2283 * Therefore for simplicity we don't store small bitsize components
2284 * contiguously, but pad them instead. In practice, they are converted to
2285 * 32-bit and then stored contiguously.
2286 */
2287
2288 if (store_val->bit_size < 32) {
2289 assert(store_val->num_components <= 4);
2290 nir_ssa_def *comps[4] = {0};
2291 for (unsigned c = 0; c < store_val->num_components; ++c)
2292 comps[c] = nir_u2u32(b, nir_channel(b, store_val, c));
2293 return nir_vec(b, comps, store_val->num_components);
2294 }
2295
2296 return store_val;
2297 }
2298
2299 static nir_ssa_def *
2300 regroup_load_val(nir_builder *b, nir_ssa_def *load, unsigned dest_bit_size)
2301 {
2302 if (dest_bit_size == load->bit_size)
2303 return load;
2304
2305 /* Small bitsize components are not stored contiguously, take care of that here. */
2306 unsigned num_components = load->num_components;
2307 assert(num_components <= 4);
2308 nir_ssa_def *components[4] = {0};
2309 for (unsigned i = 0; i < num_components; ++i)
2310 components[i] = nir_u2u(b, nir_channel(b, load, i), dest_bit_size);
2311
2312 return nir_vec(b, components, num_components);
2313 }
2314
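/* Find which part of the output layout the given location belongs to.
 *
 * Each output is stored in exactly one of three places:
 * - ms_out_mode_lds:  shared memory (read back by the finale),
 * - ms_out_mode_vram: the mesh scratch ring buffer,
 * - ms_out_mode_var:  plain NIR variables, used for outputs that are only
 *                     read by the same invocation that wrote them.
 */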
2315 static const ms_out_part *
2316 ms_get_out_layout_part(unsigned location,
2317 shader_info *info,
2318 ms_out_mode *out_mode,
2319 lower_ngg_ms_state *s)
2320 {
2321 uint64_t mask = BITFIELD64_BIT(location);
2322
2323 if (info->per_primitive_outputs & mask) {
2324 if (mask & s->layout.lds.prm_attr.mask) {
2325 *out_mode = ms_out_mode_lds;
2326 return &s->layout.lds.prm_attr;
2327 } else if (mask & s->layout.vram.prm_attr.mask) {
2328 *out_mode = ms_out_mode_vram;
2329 return &s->layout.vram.prm_attr;
2330 } else if (mask & s->layout.var.prm_attr.mask) {
2331 *out_mode = ms_out_mode_var;
2332 return &s->layout.var.prm_attr;
2333 }
2334 } else {
2335 if (mask & s->layout.lds.vtx_attr.mask) {
2336 *out_mode = ms_out_mode_lds;
2337 return &s->layout.lds.vtx_attr;
2338 } else if (mask & s->layout.vram.vtx_attr.mask) {
2339 *out_mode = ms_out_mode_vram;
2340 return &s->layout.vram.vtx_attr;
2341 } else if (mask & s->layout.var.vtx_attr.mask) {
2342 *out_mode = ms_out_mode_var;
2343 return &s->layout.var.vtx_attr;
2344 }
2345 }
2346
2347 unreachable("Couldn't figure out mesh shader output mode.");
2348 }
2349
2350 static void
2351 ms_store_arrayed_output_intrin(nir_builder *b,
2352 nir_intrinsic_instr *intrin,
2353 lower_ngg_ms_state *s)
2354 {
2355 ms_out_mode out_mode;
2356 unsigned location = nir_intrinsic_io_semantics(intrin).location;
2357 const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
2358 update_ms_output_info(intrin, out, s);
2359
2360 /* We compact the LDS size (we don't reserve LDS space for outputs which can
2361 * be stored in variables), so we can't rely on the original driver_location.
2362 * Instead, we compute the first free location based on the output mask.
2363 */
2364 unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
2365 unsigned component_offset = nir_intrinsic_component(intrin);
2366 unsigned write_mask = nir_intrinsic_write_mask(intrin);
2367 unsigned num_outputs = util_bitcount64(out->mask);
2368 unsigned const_off = out->addr + component_offset * 4;
2369
2370 nir_ssa_def *store_val = regroup_store_val(b, intrin->src[0].ssa);
2371 nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
2372 nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
2373 nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
2374 nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
2375 nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
2376
2377 if (out_mode == ms_out_mode_lds) {
2378 nir_store_shared(b, store_val, addr, .base = const_off,
2379 .write_mask = write_mask, .align_mul = 16,
2380 .align_offset = const_off % 16);
2381 } else if (out_mode == ms_out_mode_vram) {
2382 nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b);
2383 nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
2384 nir_store_buffer_amd(b, store_val, ring, addr, off,
2385 .base = const_off,
2386 .write_mask = write_mask,
2387 .memory_modes = nir_var_shader_out);
2388 } else if (out_mode == ms_out_mode_var) {
2389 if (store_val->bit_size > 32) {
2390 /* Split 64-bit store values to 32-bit components. */
2391 store_val = nir_bitcast_vector(b, store_val, 32);
2392 /* Widen the write mask so it is in 32-bit components. */
2393 write_mask = util_widen_mask(write_mask, store_val->bit_size / 32);
2394 }
2395
2396 u_foreach_bit(comp, write_mask) {
2397 nir_ssa_def *val = nir_channel(b, store_val, comp);
2398 unsigned idx = location * 4 + comp + component_offset;
2399 nir_store_var(b, s->out_variables[idx], val, 0x1);
2400 }
2401 } else {
2402 unreachable("Invalid MS output mode for store");
2403 }
2404 }
2405
2406 static nir_ssa_def *
2407 ms_load_arrayed_output(nir_builder *b,
2408 nir_ssa_def *arr_index,
2409 nir_ssa_def *base_offset,
2410 unsigned location,
2411 unsigned component_offset,
2412 unsigned num_components,
2413 unsigned load_bit_size,
2414 lower_ngg_ms_state *s)
2415 {
2416 ms_out_mode out_mode;
2417 const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
2418
2419 unsigned component_addr_off = component_offset * 4;
2420 unsigned num_outputs = util_bitcount64(out->mask);
2421 unsigned const_off = out->addr + component_offset * 4;
2422
2423 /* Use compacted driver location instead of the original. */
2424 unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
2425
2426 nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
2427 nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
2428 nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
2429
2430 if (out_mode == ms_out_mode_lds) {
2431 return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
2432 .align_offset = component_addr_off % 16,
2433 .base = const_off);
2434 } else if (out_mode == ms_out_mode_vram) {
2435 nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b);
2436 nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
2437 return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off,
2438 .base = const_off,
2439 .memory_modes = nir_var_shader_out);
2440 } else if (out_mode == ms_out_mode_var) {
2441 nir_ssa_def *arr[8] = {0};
2442 unsigned num_32bit_components = num_components * load_bit_size / 32;
2443 for (unsigned comp = 0; comp < num_32bit_components; ++comp) {
2444 unsigned idx = location * 4 + comp + component_addr_off;
2445 arr[comp] = nir_load_var(b, s->out_variables[idx]);
2446 }
2447 if (load_bit_size > 32)
2448 return nir_extract_bits(b, arr, 1, 0, num_components, load_bit_size);
2449 return nir_vec(b, arr, num_components);
2450 } else {
2451 unreachable("Invalid MS output mode for load");
2452 }
2453 }
2454
2455 static nir_ssa_def *
2456 ms_load_arrayed_output_intrin(nir_builder *b,
2457 nir_intrinsic_instr *intrin,
2458 lower_ngg_ms_state *s)
2459 {
2460 nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
2461 nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
2462
2463 unsigned location = nir_intrinsic_io_semantics(intrin).location;
2464 unsigned component_offset = nir_intrinsic_component(intrin);
2465 unsigned bit_size = intrin->dest.ssa.bit_size;
2466 unsigned num_components = intrin->dest.ssa.num_components;
2467 unsigned load_bit_size = MAX2(bit_size, 32);
2468
2469 nir_ssa_def *load =
2470 ms_load_arrayed_output(b, arr_index, base_offset, location, component_offset,
2471 num_components, load_bit_size, s);
2472
2473 return regroup_load_val(b, load, bit_size);
2474 }
2475
2476 static nir_ssa_def *
2477 lower_ms_load_workgroup_index(nir_builder *b,
2478 UNUSED nir_intrinsic_instr *intrin,
2479 lower_ngg_ms_state *s)
2480 {
2481 return s->workgroup_index;
2482 }
2483
2484 static nir_ssa_def *
2485 update_ms_scoped_barrier(nir_builder *b,
2486 nir_intrinsic_instr *intrin,
2487 lower_ngg_ms_state *s)
2488 {
2489 /* Output loads and stores are lowered to shared memory access,
2490 * so we have to update the barriers to also reflect this.
2491 */
2492 unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
2493 if (mem_modes & nir_var_shader_out)
2494 mem_modes |= nir_var_mem_shared;
2495 else
2496 return NULL;
2497
2498 nir_intrinsic_set_memory_modes(intrin, mem_modes);
2499
2500 return NIR_LOWER_INSTR_PROGRESS;
2501 }
2502
2503 static nir_ssa_def *
2504 lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
2505 {
2506 lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;
2507
2508 if (instr->type != nir_instr_type_intrinsic)
2509 return NULL;
2510
2511 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2512
2513 switch (intrin->intrinsic) {
2514 case nir_intrinsic_store_output:
2515 return lower_ms_store_output(b, intrin, s);
2516 case nir_intrinsic_load_output:
2517 return lower_ms_load_output(b, intrin, s);
2518 case nir_intrinsic_store_per_vertex_output:
2519 case nir_intrinsic_store_per_primitive_output:
2520 ms_store_arrayed_output_intrin(b, intrin, s);
2521 return NIR_LOWER_INSTR_PROGRESS_REPLACE;
2522 case nir_intrinsic_load_per_vertex_output:
2523 case nir_intrinsic_load_per_primitive_output:
2524 return ms_load_arrayed_output_intrin(b, intrin, s);
2525 case nir_intrinsic_scoped_barrier:
2526 return update_ms_scoped_barrier(b, intrin, s);
2527 case nir_intrinsic_load_workgroup_index:
2528 return lower_ms_load_workgroup_index(b, intrin, s);
2529 default:
2530 unreachable("Not a lowerable mesh shader intrinsic.");
2531 }
2532 }
2533
2534 static bool
2535 filter_ms_intrinsic(const nir_instr *instr,
2536 UNUSED const void *st)
2537 {
2538 if (instr->type != nir_instr_type_intrinsic)
2539 return false;
2540
2541 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2542 return intrin->intrinsic == nir_intrinsic_store_output ||
2543 intrin->intrinsic == nir_intrinsic_load_output ||
2544 intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
2545 intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
2546 intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
2547 intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
2548 intrin->intrinsic == nir_intrinsic_scoped_barrier ||
2549 intrin->intrinsic == nir_intrinsic_load_workgroup_index;
2550 }
2551
2552 static void
2553 lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
2554 {
2555 nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
2556 }
2557
2558 static void
2559 ms_emit_arrayed_outputs(nir_builder *b,
2560 nir_ssa_def *invocation_index,
2561 uint64_t mask,
2562 lower_ngg_ms_state *s)
2563 {
2564 nir_ssa_def *zero = nir_imm_int(b, 0);
2565
2566 u_foreach_bit64(slot, mask) {
2567 /* Should not occur here; handled separately. */
2568 assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
2569
2570 const nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };
2571 unsigned component_mask = s->output_info[slot].components_mask;
2572
2573 while (component_mask) {
2574 int start_comp = 0, num_components = 1;
2575 u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
2576
2577 nir_ssa_def *load =
2578 ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
2579 num_components, 32, s);
2580
2581 nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .component = start_comp,
2582 .io_semantics = io_sem);
2583 }
2584 }
2585 }
2586
2587 static void
2588 emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s)
2589 {
2590 b->cursor = nir_before_cf_list(&b->impl->body);
2591
2592 /* Initialize NIR variables for same-invocation outputs. */
2593 uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;
2594
2595 u_foreach_bit64(slot, same_invocation_output_mask) {
2596 for (unsigned comp = 0; comp < 4; ++comp) {
2597 unsigned idx = slot * 4 + comp;
2598 s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
2599 nir_store_var(b, s->out_variables[idx], nir_imm_int(b, 0), 0x1);
2600 }
2601 }
2602
2603 bool uses_workgroup_id =
2604 BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID) ||
2605 BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX);
2606
2607 if (!uses_workgroup_id)
2608 return;
2609
2610 /* The HW doesn't support a proper workgroup index for vertex processing stages,
2611 * so we use the vertex ID which is equivalent to the index of the current workgroup
2612 * within the current dispatch.
2613 *
2614 * Due to the register programming of mesh shaders, this value is only filled for
2615 * the first invocation of the first wave. To let other waves know, we use LDS.
2616 */
2617 nir_ssa_def *workgroup_index = nir_load_vertex_id_zero_base(b);
2618
2619 if (s->api_workgroup_size <= s->wave_size) {
2620 /* API workgroup is small, so we don't need to use LDS. */
2621 s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
2622 return;
2623 }
2624
2625 unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
2626
2627 nir_ssa_def *zero = nir_imm_int(b, 0);
2628 nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32);
2629 nir_ssa_def *loaded_workgroup_index = NULL;
2630
2631 /* Use elect to make sure only 1 invocation uses LDS. */
2632 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
2633 {
2634 nir_ssa_def *wave_id = nir_load_subgroup_id(b);
2635 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
2636 {
2637 nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
2638 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2639 .memory_scope = NIR_SCOPE_WORKGROUP,
2640 .memory_semantics = NIR_MEMORY_ACQ_REL,
2641 .memory_modes = nir_var_mem_shared);
2642 }
2643 nir_push_else(b, if_wave_0);
2644 {
2645 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2646 .memory_scope = NIR_SCOPE_WORKGROUP,
2647 .memory_semantics = NIR_MEMORY_ACQ_REL,
2648 .memory_modes = nir_var_mem_shared);
2649 loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
2650 }
2651 nir_pop_if(b, if_wave_0);
2652
2653 workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
2654 }
2655 nir_pop_if(b, if_elected);
2656
2657 workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
2658 s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
2659 }
2660
2661 static void
2662 set_nv_ms_final_output_counts(nir_builder *b,
2663 lower_ngg_ms_state *s,
2664 nir_ssa_def **out_num_prm,
2665 nir_ssa_def **out_num_vtx)
2666 {
2667 /* Limitations of the NV extension:
2668 * - Number of primitives can be written and read by any invocation,
2669 * so we have to store/load it to/from LDS to make sure the general case works.
2670 * - Number of vertices is not actually known, so we just always use the
2671 * maximum number here.
2672 */
2673 nir_ssa_def *loaded_num_prm;
2674 nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32);
2675 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
2676 {
2677 loaded_num_prm = ms_load_num_prims(b, s);
2678 }
2679 nir_pop_if(b, if_elected);
2680 loaded_num_prm = nir_if_phi(b, loaded_num_prm, dont_care);
2681 nir_ssa_def *num_prm = nir_read_first_invocation(b, loaded_num_prm);
2682 nir_ssa_def *num_vtx = nir_imm_int(b, b->shader->info.mesh.max_vertices_out);
2683 num_prm = nir_umin(b, num_prm, nir_imm_int(b, b->shader->info.mesh.max_primitives_out));
2684
2685 /* If the shader doesn't actually create any primitives, don't allocate any output. */
2686 num_vtx = nir_bcsel(b, nir_ieq_imm(b, num_prm, 0), nir_imm_int(b, 0), num_vtx);
2687
2688 /* Emit GS_ALLOC_REQ on Wave 0 to let the HW know the output size. */
2689 nir_ssa_def *wave_id = nir_load_subgroup_id(b);
2690 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
2691 {
2692 nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
2693 }
2694 nir_pop_if(b, if_wave_0);
2695
2696 *out_num_prm = num_prm;
2697 *out_num_vtx = num_vtx;
2698 }
2699
2700 static void
2701 emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
2702 {
2703 /* We assume there is always a single end block in the shader. */
2704 nir_block *last_block = nir_impl_last_block(b->impl);
2705 b->cursor = nir_after_block(last_block);
2706
2707 nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
2708 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);
2709
2710 nir_ssa_def *num_prm;
2711 nir_ssa_def *num_vtx;
2712
2713 set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx);
2714
2715 nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
2716
2717 /* Load vertex/primitive attributes from shared memory and
2718 * emit store_output intrinsics for them.
2719 *
2720 * Contrary to the semantics of the API mesh shader, these are now
2721 * compliant with NGG HW semantics, meaning that these store the
2722 * current thread's vertex attributes in a way the HW can export.
2723 */
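   /* Illustration of the mapping below (hypothetical sizes): with num_vtx = 32
    * and num_prm = 48, invocation i exports vertex i when i < 32 and
    * primitive i when i < 48; invocations beyond both limits export nothing.
    */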
2724
2725 /* Export vertices. */
2726 nir_ssa_def *has_output_vertex = nir_ilt(b, invocation_index, num_vtx);
2727 nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
2728 {
2729 /* All per-vertex attributes. */
2730 ms_emit_arrayed_outputs(b, invocation_index, s->per_vertex_outputs, s);
2731 nir_export_vertex_amd(b);
2732 }
2733 nir_pop_if(b, if_has_output_vertex);
2734
2735 /* Export primitives. */
2736 nir_ssa_def *has_output_primitive = nir_ilt(b, invocation_index, num_prm);
2737 nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
2738 {
2739 /* Generic per-primitive attributes. */
2740 ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs, s);
2741
2742 /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
2743 if (s->insert_layer_output) {
2744 nir_ssa_def *layer = nir_load_view_index(b);
2745 const nir_io_semantics io_sem = { .location = VARYING_SLOT_LAYER, .num_slots = 1 };
2746 nir_store_output(b, layer, nir_imm_int(b, 0), .base = VARYING_SLOT_LAYER, .component = 0, .io_semantics = io_sem);
2747 b->shader->info.outputs_written |= VARYING_BIT_LAYER;
2748 b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
2749 }
2750
2751 /* Primitive connectivity data: describes which vertices the primitive uses. */
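      /* Addressing example (hypothetical sizes): for triangles
       * (vertices_per_prim = 3), primitive 5 reads its three 8-bit vertex
       * indices from bytes 15..17 relative to layout.lds.indices_addr,
       * i.e. prim_idx_addr = 5 * 3 = 15.
       */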
2752 nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
2753 nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
2754 nir_ssa_def *indices[3];
2755 nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
2756
2757 for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
2758 indices[i] = nir_u2u32(b, nir_channel(b, indices_loaded, i));
2759 indices[i] = nir_umin(b, indices[i], max_vtx_idx);
2760 }
2761
2762 nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, NULL, false);
2763 nir_export_primitive_amd(b, prim_exp_arg);
2764 }
2765 nir_pop_if(b, if_has_output_primitive);
2766 }
2767
2768 static void
2769 handle_smaller_ms_api_workgroup(nir_builder *b,
2770 lower_ngg_ms_state *s)
2771 {
2772 if (s->api_workgroup_size >= s->hw_workgroup_size)
2773 return;
2774
2775 /* Handle barriers manually when the API workgroup
2776 * size is less than the HW workgroup size.
2777 *
2778 * The problem is that the real workgroup launched on NGG HW
2779 * will be larger than the size specified by the API, and the
2780 * extra waves need to keep up with barriers in the API waves.
2781 *
2782 * There are 2 different cases:
2783 * 1. The whole API workgroup fits in a single wave.
2784 * We can shrink the barriers to subgroup scope and
2785 * don't need to insert any extra ones.
2786 * 2. The API workgroup occupies multiple waves, but not
2787 * all. In this case, we emit code that consumes every
2788 * barrier on the extra waves.
2789 */
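   /* Worked example of the two cases (hypothetical sizes, wave_size = 64),
    * matching the booleans computed below:
    * - api_workgroup_size = 32, hw_workgroup_size = 128: the API workgroup
    *   fits in wave 0, so can_shrink_barriers is true and all workgroup
    *   barriers are demoted to subgroup scope (case 1).
    * - api_workgroup_size = 100, hw_workgroup_size = 192: waves 0-1 run API
    *   invocations but wave 2 does not, so if the shader contains any
    *   workgroup barriers, need_additional_barriers is set (case 2).
    */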
2790 assert(s->hw_workgroup_size % s->wave_size == 0);
2791 bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
2792 bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
2793 bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
2794
2795 unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
2796 unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
2797
2798 /* Scan the shader for workgroup barriers. */
2799 if (scan_barriers) {
2800 bool has_any_workgroup_barriers = false;
2801
2802 nir_foreach_block(block, b->impl) {
2803 nir_foreach_instr_safe(instr, block) {
2804 if (instr->type != nir_instr_type_intrinsic)
2805 continue;
2806
2807 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2808 bool is_workgroup_barrier =
2809 intrin->intrinsic == nir_intrinsic_scoped_barrier &&
2810 nir_intrinsic_execution_scope(intrin) == NIR_SCOPE_WORKGROUP;
2811
2812 if (!is_workgroup_barrier)
2813 continue;
2814
2815 if (can_shrink_barriers) {
2816 /* Every API invocation runs in the first wave.
2817 * In this case, we can change the barriers to subgroup scope
2818 * and avoid adding additional barriers.
2819 */
2820 nir_intrinsic_set_memory_scope(intrin, NIR_SCOPE_SUBGROUP);
2821 nir_intrinsic_set_execution_scope(intrin, NIR_SCOPE_SUBGROUP);
2822 } else {
2823 has_any_workgroup_barriers = true;
2824 }
2825 }
2826 }
2827
2828 need_additional_barriers &= has_any_workgroup_barriers;
2829 }
2830
2831 /* Extract the full control flow of the shader. */
2832 nir_cf_list extracted;
2833 nir_cf_extract(&extracted, nir_before_cf_list(&b->impl->body), nir_after_cf_list(&b->impl->body));
2834 b->cursor = nir_before_cf_list(&b->impl->body);
2835
2836 /* Wrap the shader in an if to ensure that only the necessary number of lanes run it. */
2837 nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
2838 nir_ssa_def *zero = nir_imm_int(b, 0);
2839
2840 if (need_additional_barriers) {
2841 /* First invocation initializes the number of API waves in flight. */
2842 nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
2843 {
2844 nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
2845 }
2846 nir_pop_if(b, if_first_in_workgroup);
2847
2848 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2849 .memory_scope = NIR_SCOPE_WORKGROUP,
2850 .memory_semantics = NIR_MEMORY_ACQ_REL,
2851 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
2852 }
2853
2854 nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, s->api_workgroup_size));
2855 nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
2856 {
2857 nir_cf_reinsert(&extracted, b->cursor);
2858 b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);
2859
2860 if (need_additional_barriers) {
2861 /* One invocation in each API wave decrements the number of API waves in flight. */
2862 nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
2863 {
2864 nir_shared_atomic_add(b, 32, zero, nir_imm_int(b, -1u), .base = api_waves_in_flight_addr);
2865 }
2866 nir_pop_if(b, if_elected_again);
2867
2868 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2869 .memory_scope = NIR_SCOPE_WORKGROUP,
2870 .memory_semantics = NIR_MEMORY_ACQ_REL,
2871 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
2872 }
2873 }
2874 nir_pop_if(b, if_has_api_ms_invocation);
2875
2876 if (need_additional_barriers) {
2877 /* Make sure that waves that don't run any API invocations execute
2878 * the same number of barriers as those that do.
2879 *
2880 * We do this by executing a barrier until the number of API waves
2881 * in flight becomes zero.
2882 */
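      /* Sketch of the counting scheme implemented below: the counter at
       * api_waves_in_flight_addr starts at num_api_waves; each API wave
       * decrements it once after its API code finishes, while each wave that
       * runs no API invocations elects one lane that loops on a workgroup
       * barrier and an LDS load until the counter reaches zero, so it keeps
       * consuming barriers until the API waves are done.
       */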
2883 nir_ssa_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
2884 nir_ssa_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
2885 nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
2886 {
2887 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
2888 {
2889 nir_loop *loop = nir_push_loop(b);
2890 {
2891 nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
2892 .memory_scope = NIR_SCOPE_WORKGROUP,
2893 .memory_semantics = NIR_MEMORY_ACQ_REL,
2894 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
2895
2896 nir_ssa_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
2897 nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
2898 {
2899 nir_jump(b, nir_jump_break);
2900 }
2901 nir_pop_if(b, if_break);
2902 }
2903 nir_pop_loop(b, loop);
2904 }
2905 nir_pop_if(b, if_elected);
2906 }
2907 nir_pop_if(b, if_wave_has_no_api_ms);
2908 }
2909 }
2910
2911 static void
2912 ms_move_output(ms_out_part *from, ms_out_part *to)
2913 {
2914 uint64_t loc = util_logbase2_64(from->mask);
2915 uint64_t bit = BITFIELD64_BIT(loc);
2916 from->mask ^= bit;
2917 to->mask |= bit;
2918 }
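/* Example: with from->mask = 0x28 (slots 3 and 5 set), util_logbase2_64 picks
 * the highest set bit, so loc = 5; afterwards from->mask = 0x08 and to->mask
 * gains bit 5. Repeated calls thus migrate outputs one slot at a time,
 * highest slot first.
 */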
2919
2920 static void
2921 ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
2922 unsigned max_vertices,
2923 unsigned max_primitives)
2924 {
2925 uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
2926 uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
2927 l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
2928 l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
2929
2930 uint32_t vram_vtx_attr_size = util_bitcount64(l->vram.vtx_attr.mask) * max_vertices * 16;
2931 l->vram.prm_attr.addr = ALIGN(l->vram.vtx_attr.addr + vram_vtx_attr_size, 16);
2932 }
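/* Size math illustration (hypothetical limits): each output slot occupies
 * 4 components * 4 bytes = 16 bytes per vertex (or per primitive). With 5
 * vertex-attribute slots kept in LDS and max_vertices = 64, lds_vtx_attr_size
 * is 5 * 64 * 16 = 5120 bytes, and the per-primitive block starts at the next
 * 16-byte-aligned address after it.
 */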
2933
2934 static ms_out_mem_layout
2935 ms_calculate_output_layout(unsigned api_shared_size,
2936 uint64_t per_vertex_output_mask,
2937 uint64_t per_primitive_output_mask,
2938 uint64_t cross_invocation_output_access,
2939 unsigned max_vertices,
2940 unsigned max_primitives,
2941 unsigned vertices_per_prim)
2942 {
2943 uint64_t lds_per_vertex_output_mask = per_vertex_output_mask & cross_invocation_output_access;
2944 uint64_t lds_per_primitive_output_mask = per_primitive_output_mask & cross_invocation_output_access;
2945
2946 /* Shared memory used by the API shader. */
2947 ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
2948
2949 /* Outputs without cross-invocation access can be stored in variables. */
2950 l.var.vtx_attr.mask = per_vertex_output_mask & ~lds_per_vertex_output_mask;
2951 l.var.prm_attr.mask = per_primitive_output_mask & ~lds_per_primitive_output_mask;
2952
2953 /* Workgroup information, see ms_workgroup_* for the layout. */
2954 l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
2955 l.lds.total_size = l.lds.workgroup_info_addr + 16;
2956
2957 /* Per-vertex and per-primitive output attributes.
2958 * Outputs without cross-invocation access are not included here.
2959 * First, try to put all outputs into LDS (shared memory).
2960 * If they don't fit, try to move them to VRAM one by one.
2961 */
2962 l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
2963 l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
2964 l.lds.prm_attr.mask = lds_per_primitive_output_mask;
2965 ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
2966
2967 /* NGG shaders can only address up to 32K LDS memory.
2968 * The spec requires us to allow the application to use at least up to 28K
2969 * shared memory. Additionally, we reserve 2K for driver internal use
2970 * (e.g. primitive indices and such, see below).
2971 *
2972 * Move the outputs that do not fit in LDS to VRAM.
2973 * Start with per-primitive attributes, because those are grouped at the end.
2974 */
2975 while (l.lds.total_size >= 30 * 1024) {
2976 if (l.lds.prm_attr.mask)
2977 ms_move_output(&l.lds.prm_attr, &l.vram.prm_attr);
2978 else if (l.lds.vtx_attr.mask)
2979 ms_move_output(&l.lds.vtx_attr, &l.vram.vtx_attr);
2980 else
2981 unreachable("API shader uses too much shared memory.");
2982
2983 ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
2984 }
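   /* Example of the loop above (hypothetical sizes): if the initial layout
    * needs ~31K of LDS while 3 per-primitive slots are still in LDS, the loop
    * moves per-primitive attributes to VRAM one slot at a time, recomputing
    * the layout after each move, until the total drops below 30K
    * (28K for the application plus the 2K driver reserve).
    */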
2985
2986 /* Indices: flat array of 8-bit vertex indices for each primitive. */
2987 l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
2988 l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
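   /* E.g. (hypothetical limits): 128 max primitives of triangles need
    * 128 * 3 = 384 bytes of 8-bit indices here.
    */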
2989
2990 /* NGG is only allowed to address up to 32K of LDS. */
2991 assert(l.lds.total_size <= 32 * 1024);
2992 return l;
2993 }
2994
2995 void
2996 ac_nir_lower_ngg_ms(nir_shader *shader,
2997 bool *out_needs_scratch_ring,
2998 unsigned wave_size,
2999 bool multiview)
3000 {
3001 unsigned vertices_per_prim =
3002 num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
3003
3004 uint64_t special_outputs =
3005 BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
3006 uint64_t per_vertex_outputs =
3007 shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~special_outputs;
3008 uint64_t per_primitive_outputs =
3009 shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
3010
3011 /* We can't handle indirect register addressing, so treat indirectly accessed outputs as if they were cross-invocation. */
3012 uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
3013 shader->info.outputs_accessed_indirectly;
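   /* For example, an output array indexed with a non-constant value cannot be
    * kept in per-invocation temporaries because the slot isn't known at
    * compile time, so such outputs are treated like cross-invocation accesses
    * and kept in LDS (or VRAM). This is only an illustration of the rule
    * above.
    */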
3014
3015 unsigned max_vertices = shader->info.mesh.max_vertices_out;
3016 unsigned max_primitives = shader->info.mesh.max_primitives_out;
3017
3018 ms_out_mem_layout layout =
3019 ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
3020 cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
3021
3022 shader->info.shared_size = layout.lds.total_size;
3023 *out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask;
3024
3025 /* The workgroup size that is specified by the API shader may be different
3026 * from the size of the workgroup that actually runs on the HW, due to the
3027 * limitations of NGG: at most 1 vertex and 1 primitive per lane are allowed.
3028 *
3029 * Therefore, we must make sure that when the API workgroup size is smaller,
3030 * we don't run the API shader on more HW invocations than is necessary.
3031 */
3032 unsigned api_workgroup_size = shader->info.workgroup_size[0] *
3033 shader->info.workgroup_size[1] *
3034 shader->info.workgroup_size[2];
3035
3036 unsigned hw_workgroup_size =
3037 ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size);
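   /* E.g. (hypothetical shader): api_workgroup_size = 32, max_vertices = 81,
    * max_primitives = 126 and wave_size = 64 give
    * hw_workgroup_size = ALIGN(126, 64) = 128, so 96 of the 128 HW invocations
    * run no API code and only participate in the code added by this pass
    * (exports and barrier handling).
    */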
3038
3039 lower_ngg_ms_state state = {
3040 .layout = layout,
3041 .wave_size = wave_size,
3042 .per_vertex_outputs = per_vertex_outputs,
3043 .per_primitive_outputs = per_primitive_outputs,
3044 .vertices_per_prim = vertices_per_prim,
3045 .api_workgroup_size = api_workgroup_size,
3046 .hw_workgroup_size = hw_workgroup_size,
3047 .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
3048 };
3049
3050 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
3051 assert(impl);
3052
3053 nir_builder builder;
3054 nir_builder *b = &builder; /* Shorthand so we don't have to write &builder everywhere. */
3055 nir_builder_init(b, impl);
3056 b->cursor = nir_before_cf_list(&impl->body);
3057
3058 handle_smaller_ms_api_workgroup(b, &state);
3059 emit_ms_prelude(b, &state);
3060 nir_metadata_preserve(impl, nir_metadata_none);
3061
3062 lower_ms_intrinsics(shader, &state);
3063
3064 emit_ms_finale(b, &state);
3065 nir_metadata_preserve(impl, nir_metadata_none);
3066
3067 /* Cleanup */
3068 nir_lower_vars_to_ssa(shader);
3069 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
3070 nir_lower_alu_to_scalar(shader, NULL, NULL);
3071 nir_lower_phis_to_scalar(shader, true);
3072
3073 nir_validate_shader(shader, "after emitting NGG MS");
3074 }
3075