1 /*
2 * Copyright © 2021 Valve Corporation
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "ac_nir.h"
8 #include "ac_nir_helpers.h"
9
10 #include "nir_builder.h"
11
12 #define SPECIAL_MS_OUT_MASK \
13 (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \
14 BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
15 BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
16
17 #define MS_PRIM_ARG_EXP_MASK \
18 (VARYING_BIT_LAYER | \
19 VARYING_BIT_VIEWPORT | \
20 VARYING_BIT_PRIMITIVE_SHADING_RATE)
21
22 #define MS_VERT_ARG_EXP_MASK \
23 (VARYING_BIT_CULL_DIST0 | \
24 VARYING_BIT_CULL_DIST1 | \
25 VARYING_BIT_CLIP_DIST0 | \
26 VARYING_BIT_CLIP_DIST1 | \
27 VARYING_BIT_PSIZ)
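
/* The outputs in the two masks above are passed in the position/primitive export
 * arguments, but on GFX11+ they are also stored to the attribute ring so that
 * the fragment shader can read them (see ms_emit_attribute_ring_output_stores).
 */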
28
29 /* LDS layout of Mesh Shader workgroup info. */
30 enum {
31 /* DW0: number of primitives */
32 lds_ms_num_prims = 0,
33 /* DW1: number of vertices */
34 lds_ms_num_vtx = 4,
35 /* DW2: workgroup index within the current dispatch */
36 lds_ms_wg_index = 8,
   /* DW3: number of API waves in flight */
38 lds_ms_num_api_waves = 12,
39 };
40
41 /* Potential location for Mesh Shader outputs. */
42 typedef enum {
43 ms_out_mode_lds,
44 ms_out_mode_scratch_ring,
45 ms_out_mode_attr_ring,
46 ms_out_mode_var,
47 } ms_out_mode;
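
/* ms_out_mode_lds: shared memory, for outputs that need cross-invocation access.
 * ms_out_mode_scratch_ring: VRAM ring for outputs that don't fit into LDS.
 * ms_out_mode_attr_ring: GFX11+ attribute ring for non-position outputs.
 * ms_out_mode_var: plain variables for outputs without cross-invocation access.
 */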
48
49 typedef struct
50 {
51 uint64_t mask; /* Mask of output locations */
52 uint32_t addr; /* Base address */
53 } ms_out_part;
54
55 typedef struct
56 {
57 /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
58 struct {
59 uint32_t workgroup_info_addr;
60 ms_out_part vtx_attr;
61 ms_out_part prm_attr;
62 uint32_t indices_addr;
63 uint32_t cull_flags_addr;
64 uint32_t total_size;
65 } lds;
66
67 /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS.
68 * Not to be confused with scratch memory.
69 */
70 struct {
71 ms_out_part vtx_attr;
72 ms_out_part prm_attr;
73 } scratch_ring;
74
   /* VRAM attribute ring (GFX11+ only) for all non-position outputs.
    * Unlike the scratch ring, these attributes don't need to be reloaded
    * and re-exported at the end of the shader.
    */
78 struct {
79 ms_out_part vtx_attr;
80 ms_out_part prm_attr;
81 } attr_ring;
82
83 /* Outputs without cross-invocation access can be stored in variables. */
84 struct {
85 ms_out_part vtx_attr;
86 ms_out_part prm_attr;
87 } var;
88 } ms_out_mem_layout;
89
90 typedef struct
91 {
92 enum amd_gfx_level gfx_level;
93 bool fast_launch_2;
94 bool vert_multirow_export;
95 bool prim_multirow_export;
96
97 ms_out_mem_layout layout;
98 uint64_t per_vertex_outputs;
99 uint64_t per_primitive_outputs;
100 unsigned vertices_per_prim;
101
102 unsigned wave_size;
103 unsigned api_workgroup_size;
104 unsigned hw_workgroup_size;
105
106 nir_def *workgroup_index;
107 nir_variable *out_variables[VARYING_SLOT_MAX * 4];
108 nir_variable *primitive_count_var;
109 nir_variable *vertex_count_var;
110
111 ac_nir_prerast_out out;
112
113 /* True if the lowering needs to insert the layer output. */
114 bool insert_layer_output;
115 /* True if cull flags are used */
116 bool uses_cull_flags;
117
118 uint32_t clipdist_enable_mask;
119 const uint8_t *vs_output_param_offset;
120 bool has_param_exports;
121
122 /* True if the lowering needs to insert shader query. */
123 bool has_query;
124 } lower_ngg_ms_state;
125
static bool must_wait_attr_ring(enum amd_gfx_level gfx_level, bool has_param_exports)
127 {
128 return (gfx_level == GFX11 || gfx_level == GFX11_5) && has_param_exports;
129 }
130
131 static void
ms_store_prim_indices(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      lower_ngg_ms_state *s)
135 {
136 /* EXT_mesh_shader primitive indices: array of vectors.
137 * They don't count as per-primitive outputs, but the array is indexed
138 * by the primitive index, so they are practically per-primitive.
139 */
140 assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
141 assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
142
143 const unsigned component_offset = nir_intrinsic_component(intrin);
144 nir_def *store_val = intrin->src[0].ssa;
145 assert(store_val->num_components <= 3);
146
147 if (store_val->num_components > s->vertices_per_prim)
148 store_val = nir_trim_vector(b, store_val, s->vertices_per_prim);
149
150 if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
151 for (unsigned c = 0; c < store_val->num_components; ++c) {
152 const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset;
153 nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1);
154 }
155 return;
156 }
157
158 nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
159 nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
160
161 /* The max vertex count is 256, so these indices always fit 8 bits.
162 * To reduce LDS use, store these as a flat array of 8-bit values.
163 */
164 nir_store_shared(b, nir_u2u8(b, store_val), offset, .base = s->layout.lds.indices_addr + component_offset);
165 }
166
167 static void
ms_store_cull_flag(nir_builder *b,
                   nir_intrinsic_instr *intrin,
                   lower_ngg_ms_state *s)
171 {
172 /* EXT_mesh_shader cull primitive: per-primitive bool. */
173 assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
174 assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
175 assert(nir_intrinsic_component(intrin) == 0);
176 assert(nir_intrinsic_write_mask(intrin) == 1);
177
178 nir_def *store_val = intrin->src[0].ssa;
179
180 assert(store_val->num_components == 1);
181 assert(store_val->bit_size == 1);
182
183 if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) {
184 nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, store_val), 0x1);
185 return;
186 }
187
188 nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
189 nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
190
191 /* To reduce LDS use, store these as an array of 8-bit values. */
192 nir_store_shared(b, nir_b2i8(b, store_val), offset, .base = s->layout.lds.cull_flags_addr);
193 }
194
195 static nir_def *
ms_arrayed_output_base_addr(nir_builder *b,
                            nir_def *arr_index,
                            unsigned mapped_location,
                            unsigned num_arrayed_outputs)
200 {
201 /* Address offset of the array item (vertex or primitive). */
202 unsigned arr_index_stride = num_arrayed_outputs * 16u;
203 nir_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);
204
205 /* IO address offset within the vertex or primitive data. */
206 unsigned io_offset = mapped_location * 16u;
207 nir_def *io_off = nir_imm_int(b, io_offset);
208
209 return nir_iadd_nuw(b, arr_index_off, io_off);
210 }
211
212 static void
update_ms_output_info(const nir_io_semantics io_sem,
                      const nir_src *base_offset_src,
                      const uint32_t write_mask,
                      const unsigned component_offset,
                      const unsigned bit_size,
                      const ms_out_part *out,
                      lower_ngg_ms_state *s)
220 {
221 const uint32_t components_mask = write_mask << component_offset;
222
223 /* 64-bit outputs should have already been lowered to 32-bit. */
224 assert(bit_size <= 32);
225 assert(components_mask <= 0xf);
226
227 /* When the base offset is constant, only mark the components of the current slot as used.
228 * Otherwise, mark the components of all possibly affected slots as used.
229 */
230 const unsigned base_off_start = nir_src_is_const(*base_offset_src) ? nir_src_as_uint(*base_offset_src) : 0;
231 const unsigned num_slots = nir_src_is_const(*base_offset_src) ? 1 : io_sem.num_slots;
232
233 for (unsigned base_off = base_off_start; base_off < num_slots; ++base_off) {
234 ac_nir_prerast_per_output_info *info = &s->out.infos[io_sem.location + base_off];
235 info->components_mask |= components_mask;
236
237 if (!io_sem.no_sysval_output)
238 info->as_sysval_mask |= components_mask;
239 if (!io_sem.no_varying)
240 info->as_varying_mask |= components_mask;
241 }
242 }
243
244 static const ms_out_part *
ms_get_out_layout_part(unsigned location,
                       shader_info *info,
                       ms_out_mode *out_mode,
                       lower_ngg_ms_state *s)
249 {
250 uint64_t mask = BITFIELD64_BIT(location);
251
252 if (info->per_primitive_outputs & mask) {
253 if (mask & s->layout.lds.prm_attr.mask) {
254 *out_mode = ms_out_mode_lds;
255 return &s->layout.lds.prm_attr;
256 } else if (mask & s->layout.scratch_ring.prm_attr.mask) {
257 *out_mode = ms_out_mode_scratch_ring;
258 return &s->layout.scratch_ring.prm_attr;
259 } else if (mask & s->layout.attr_ring.prm_attr.mask) {
260 *out_mode = ms_out_mode_attr_ring;
261 return &s->layout.attr_ring.prm_attr;
262 } else if (mask & s->layout.var.prm_attr.mask) {
263 *out_mode = ms_out_mode_var;
264 return &s->layout.var.prm_attr;
265 }
266 } else {
267 if (mask & s->layout.lds.vtx_attr.mask) {
268 *out_mode = ms_out_mode_lds;
269 return &s->layout.lds.vtx_attr;
270 } else if (mask & s->layout.scratch_ring.vtx_attr.mask) {
271 *out_mode = ms_out_mode_scratch_ring;
272 return &s->layout.scratch_ring.vtx_attr;
273 } else if (mask & s->layout.attr_ring.vtx_attr.mask) {
274 *out_mode = ms_out_mode_attr_ring;
275 return &s->layout.attr_ring.vtx_attr;
276 } else if (mask & s->layout.var.vtx_attr.mask) {
277 *out_mode = ms_out_mode_var;
278 return &s->layout.var.vtx_attr;
279 }
280 }
281
282 unreachable("Couldn't figure out mesh shader output mode.");
283 }
284
285 static void
ms_store_arrayed_output(nir_builder *b,
                        nir_src *base_off_src,
                        nir_def *store_val,
                        nir_def *arr_index,
                        const nir_io_semantics io_sem,
                        const unsigned component_offset,
                        const unsigned write_mask,
                        lower_ngg_ms_state *s)
294 {
295 ms_out_mode out_mode;
296 const ms_out_part *out = ms_get_out_layout_part(io_sem.location, &b->shader->info, &out_mode, s);
297 update_ms_output_info(io_sem, base_off_src, write_mask, component_offset, store_val->bit_size, out, s);
298
299 bool hi_16b = io_sem.high_16bits;
300 bool lo_16b = !hi_16b && store_val->bit_size == 16;
301
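   /* Each output slot occupies 16 bytes (a full vec4) per vertex/primitive, and
    * mapped_location is the slot's index among the outputs that live in the same
    * memory region, so the data is tightly packed.
    */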
302 unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, io_sem.location));
303 unsigned num_outputs = util_bitcount64(out->mask);
304 unsigned const_off = out->addr + component_offset * 4 + (hi_16b ? 2 : 0);
305
306 nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
307 nir_def *base_offset = base_off_src->ssa;
308 nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
309 nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
310
311 if (out_mode == ms_out_mode_lds) {
312 nir_store_shared(b, store_val, addr, .base = const_off,
313 .write_mask = write_mask, .align_mul = 16,
314 .align_offset = const_off % 16);
315 } else if (out_mode == ms_out_mode_scratch_ring) {
316 nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
317 nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
318 nir_def *zero = nir_imm_int(b, 0);
319 nir_store_buffer_amd(b, store_val, ring, addr, off, zero,
320 .base = const_off,
321 .write_mask = write_mask,
322 .memory_modes = nir_var_shader_out,
323 .access = ACCESS_COHERENT);
324 } else if (out_mode == ms_out_mode_attr_ring) {
325 /* GFX11+: Store params straight to the attribute ring.
326 *
327 * Even though the access pattern may not be the most optimal,
328 * this is still much better than reserving LDS and losing waves.
329 * (Also much better than storing and reloading from the scratch ring.)
330 */
331 unsigned param_offset = s->vs_output_param_offset[io_sem.location];
332 nir_def *ring = nir_load_ring_attr_amd(b);
333 nir_def *soffset = nir_load_ring_attr_offset_amd(b);
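      /* The ring access is swizzled; each parameter slot adds 16 bytes to the
       * constant offset and the array index selects the element within it.
       */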
334 nir_store_buffer_amd(b, store_val, ring, base_addr_off, soffset, arr_index,
335 .base = const_off + param_offset * 16,
336 .write_mask = write_mask,
337 .memory_modes = nir_var_shader_out,
338 .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
339 } else if (out_mode == ms_out_mode_var) {
340 unsigned write_mask_32 = write_mask;
341 if (store_val->bit_size > 32) {
342 /* Split 64-bit store values to 32-bit components. */
343 store_val = nir_bitcast_vector(b, store_val, 32);
344 /* Widen the write mask so it is in 32-bit components. */
345 write_mask_32 = util_widen_mask(write_mask, store_val->bit_size / 32);
346 }
347
348 u_foreach_bit(comp, write_mask_32) {
349 unsigned idx = io_sem.location * 4 + comp + component_offset;
350 nir_def *val = nir_channel(b, store_val, comp);
351 nir_def *v = nir_load_var(b, s->out_variables[idx]);
352
353 if (lo_16b) {
354 nir_def *var_hi = nir_unpack_32_2x16_split_y(b, v);
355 val = nir_pack_32_2x16_split(b, val, var_hi);
356 } else if (hi_16b) {
357 nir_def *var_lo = nir_unpack_32_2x16_split_x(b, v);
358 val = nir_pack_32_2x16_split(b, var_lo, val);
359 }
360
361 nir_store_var(b, s->out_variables[idx], val, 0x1);
362 }
363 } else {
364 unreachable("Invalid MS output mode for store");
365 }
366 }
367
368 static void
ms_store_arrayed_output_intrin(nir_builder *b,
                               nir_intrinsic_instr *intrin,
                               lower_ngg_ms_state *s)
372 {
373 const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
374
375 if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
376 ms_store_prim_indices(b, intrin, s);
377 return;
378 } else if (io_sem.location == VARYING_SLOT_CULL_PRIMITIVE) {
379 ms_store_cull_flag(b, intrin, s);
380 return;
381 }
382
383 unsigned component_offset = nir_intrinsic_component(intrin);
384 unsigned write_mask = nir_intrinsic_write_mask(intrin);
385
386 nir_def *store_val = intrin->src[0].ssa;
387 nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
388 nir_src *base_off_src = nir_get_io_offset_src(intrin);
389
390 if (store_val->bit_size < 32) {
391 /* Split 16-bit output stores to ensure each 16-bit component is stored
392 * in the correct location, without overwriting the other 16 bits there.
393 */
394 u_foreach_bit(c, write_mask) {
395 nir_def *store_component = nir_channel(b, store_val, c);
396 ms_store_arrayed_output(b, base_off_src, store_component, arr_index, io_sem, c + component_offset, 1, s);
397 }
398 } else {
399 ms_store_arrayed_output(b, base_off_src, store_val, arr_index, io_sem, component_offset, write_mask, s);
400 }
401 }
402
403 static nir_def *
ms_load_arrayed_output(nir_builder *b,
                       nir_def *arr_index,
                       nir_def *base_offset,
                       unsigned location,
                       unsigned component_offset,
                       unsigned num_components,
                       unsigned load_bit_size,
                       lower_ngg_ms_state *s)
412 {
413 ms_out_mode out_mode;
414 const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
415
416 unsigned component_addr_off = component_offset * 4;
417 unsigned num_outputs = util_bitcount64(out->mask);
418 unsigned const_off = out->addr + component_offset * 4;
419
420 /* Use compacted location instead of the original semantic location. */
421 unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
422
423 nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
424 nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
425 nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
426
427 if (out_mode == ms_out_mode_lds) {
428 return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
429 .align_offset = component_addr_off % 16,
430 .base = const_off);
431 } else if (out_mode == ms_out_mode_scratch_ring) {
432 nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
433 nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
434 nir_def *zero = nir_imm_int(b, 0);
435 return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, zero,
436 .base = const_off,
437 .memory_modes = nir_var_shader_out,
438 .access = ACCESS_COHERENT);
439 } else if (out_mode == ms_out_mode_var) {
440 assert(load_bit_size == 32);
441 nir_def *arr[8] = {0};
442 for (unsigned comp = 0; comp < num_components; ++comp) {
443 unsigned idx = location * 4 + comp + component_addr_off;
444 arr[comp] = nir_load_var(b, s->out_variables[idx]);
445 }
446 return nir_vec(b, arr, num_components);
447 } else {
448 unreachable("Invalid MS output mode for load");
449 }
450 }
451
452 static nir_def *
lower_ms_load_workgroup_index(nir_builder *b,
                              UNUSED nir_intrinsic_instr *intrin,
                              lower_ngg_ms_state *s)
456 {
457 return s->workgroup_index;
458 }
459
460 static nir_def *
lower_ms_set_vertex_and_primitive_count(nir_builder *b,
                                        nir_intrinsic_instr *intrin,
                                        lower_ngg_ms_state *s)
464 {
465 /* If either the number of vertices or primitives is zero, set both of them to zero. */
466 nir_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
467 nir_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
468 nir_def *zero = nir_imm_int(b, 0);
469 nir_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
470 num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
471 num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);
472
473 nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
474 nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
475
476 return NIR_LOWER_INSTR_PROGRESS_REPLACE;
477 }
478
479 static nir_def *
update_ms_barrier(nir_builder *b,
                  nir_intrinsic_instr *intrin,
                  lower_ngg_ms_state *s)
483 {
484 /* Output loads and stores are lowered to shared memory access,
485 * so we have to update the barriers to also reflect this.
486 */
487 unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
488 if (mem_modes & nir_var_shader_out)
489 mem_modes |= nir_var_mem_shared;
490 else
491 return NULL;
492
493 nir_intrinsic_set_memory_modes(intrin, mem_modes);
494
495 return NIR_LOWER_INSTR_PROGRESS;
496 }
497
498 static nir_def *
lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
500 {
501 lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;
502
503 if (instr->type != nir_instr_type_intrinsic)
504 return NULL;
505
506 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
507
508 switch (intrin->intrinsic) {
509 case nir_intrinsic_store_per_vertex_output:
510 case nir_intrinsic_store_per_primitive_output:
511 ms_store_arrayed_output_intrin(b, intrin, s);
512 return NIR_LOWER_INSTR_PROGRESS_REPLACE;
513 case nir_intrinsic_barrier:
514 return update_ms_barrier(b, intrin, s);
515 case nir_intrinsic_load_workgroup_index:
516 return lower_ms_load_workgroup_index(b, intrin, s);
517 case nir_intrinsic_set_vertex_and_primitive_count:
518 return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
519 default:
520 unreachable("Not a lowerable mesh shader intrinsic.");
521 }
522 }
523
524 static bool
filter_ms_intrinsic(const nir_instr *instr,
                    UNUSED const void *s)
527 {
528 if (instr->type != nir_instr_type_intrinsic)
529 return false;
530
531 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
532 return intrin->intrinsic == nir_intrinsic_store_output ||
533 intrin->intrinsic == nir_intrinsic_load_output ||
534 intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
535 intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
536 intrin->intrinsic == nir_intrinsic_barrier ||
537 intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
538 intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
539 }
540
541 static void
lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
543 {
544 nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
545 }
546
547 static void
ms_emit_arrayed_outputs(nir_builder *b,
                        nir_def *invocation_index,
                        uint64_t mask,
                        lower_ngg_ms_state *s)
552 {
553 nir_def *zero = nir_imm_int(b, 0);
554
555 u_foreach_bit64(slot, mask) {
556 /* Should not occur here, handled separately. */
557 assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
558
559 unsigned component_mask = s->out.infos[slot].components_mask;
560
561 while (component_mask) {
562 int start_comp = 0, num_components = 1;
563 u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
564
565 nir_def *load =
566 ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
567 num_components, 32, s);
568
569 for (int i = 0; i < num_components; i++)
570 s->out.outputs[slot][start_comp + i] = nir_channel(b, load, i);
571 }
572 }
573 }
574
575 static void
ms_create_same_invocation_vars(nir_builder *b, lower_ngg_ms_state *s)
577 {
578 /* Initialize NIR variables for same-invocation outputs. */
579 uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;
580
581 u_foreach_bit64(slot, same_invocation_output_mask) {
582 for (unsigned comp = 0; comp < 4; ++comp) {
583 unsigned idx = slot * 4 + comp;
584 s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
585 }
586 }
587 }
588
589 static void
ms_emit_legacy_workgroup_index(nir_builder *b, lower_ngg_ms_state *s)
591 {
592 /* Workgroup ID should have been lowered to workgroup index. */
593 assert(!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID));
594
595 /* No need to do anything if the shader doesn't use the workgroup index. */
596 if (!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX))
597 return;
598
599 b->cursor = nir_before_impl(b->impl);
600
601 /* Legacy fast launch mode (FAST_LAUNCH=1):
602 *
603 * The HW doesn't support a proper workgroup index for vertex processing stages,
604 * so we use the vertex ID which is equivalent to the index of the current workgroup
605 * within the current dispatch.
606 *
607 * Due to the register programming of mesh shaders, this value is only filled for
608 * the first invocation of the first wave. To let other waves know, we use LDS.
609 */
610 nir_def *workgroup_index = nir_load_vertex_id_zero_base(b);
611
612 if (s->api_workgroup_size <= s->wave_size) {
613 /* API workgroup is small, so we don't need to use LDS. */
614 s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
615 return;
616 }
617
618 unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
619
620 nir_def *zero = nir_imm_int(b, 0);
621 nir_def *dont_care = nir_undef(b, 1, 32);
622 nir_def *loaded_workgroup_index = NULL;
623
624 /* Use elect to make sure only 1 invocation uses LDS. */
625 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
626 {
627 nir_def *wave_id = nir_load_subgroup_id(b);
628 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
629 {
630 nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
631 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
632 .memory_scope = SCOPE_WORKGROUP,
633 .memory_semantics = NIR_MEMORY_ACQ_REL,
634 .memory_modes = nir_var_mem_shared);
635 }
636 nir_push_else(b, if_wave_0);
637 {
638 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
639 .memory_scope = SCOPE_WORKGROUP,
640 .memory_semantics = NIR_MEMORY_ACQ_REL,
641 .memory_modes = nir_var_mem_shared);
642 loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
643 }
644 nir_pop_if(b, if_wave_0);
645
646 workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
647 }
648 nir_pop_if(b, if_elected);
649
650 workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
651 s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
652 }
653
654 static void
set_ms_final_output_counts(nir_builder *b,
                           lower_ngg_ms_state *s,
                           nir_def **out_num_prm,
                           nir_def **out_num_vtx)
659 {
660 /* The spec allows the numbers to be divergent, and in that case we need to
661 * use the values from the first invocation. Also the HW requires us to set
662 * both to 0 if either was 0.
663 *
664 * These are already done by the lowering.
665 */
666 nir_def *num_prm = nir_load_var(b, s->primitive_count_var);
667 nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
668
669 if (s->hw_workgroup_size <= s->wave_size) {
670 /* Single-wave mesh shader workgroup. */
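      /* m0 of GS_ALLOC_REQ: vertex count in the low bits,
       * primitive count starting at bit 12.
       */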
671 nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prm, 12), num_vtx);
672 nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ);
673
674 *out_num_prm = num_prm;
675 *out_num_vtx = num_vtx;
676 return;
677 }
678
679 /* Multi-wave mesh shader workgroup:
680 * We need to use LDS to distribute the correct values to the other waves.
681 *
682 * TODO:
683 * If we can prove that the values are workgroup-uniform, we can skip this
684 * and just use whatever the current wave has. However, NIR divergence analysis
685 * currently doesn't support this.
686 */
687
688 nir_def *zero = nir_imm_int(b, 0);
689
690 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
691 {
692 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
693 {
694 nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
695 .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
696 }
697 nir_pop_if(b, if_elected);
698
699 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
700 .memory_scope = SCOPE_WORKGROUP,
701 .memory_semantics = NIR_MEMORY_ACQ_REL,
702 .memory_modes = nir_var_mem_shared);
703
704 nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prm, 12), num_vtx);
705 nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ);
706 }
707 nir_push_else(b, if_wave_0);
708 {
709 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
710 .memory_scope = SCOPE_WORKGROUP,
711 .memory_semantics = NIR_MEMORY_ACQ_REL,
712 .memory_modes = nir_var_mem_shared);
713
714 nir_def *prm_vtx = NULL;
715 nir_def *dont_care_2x32 = nir_undef(b, 2, 32);
716 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
717 {
718 prm_vtx = nir_load_shared(b, 2, 32, zero,
719 .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
720 }
721 nir_pop_if(b, if_elected);
722
723 prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
724 num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
725 num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));
726
727 nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
728 nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
729 }
730 nir_pop_if(b, if_wave_0);
731
732 *out_num_prm = nir_load_var(b, s->primitive_count_var);
733 *out_num_vtx = nir_load_var(b, s->vertex_count_var);
734 }
735
736 static void
ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask,
                                     nir_def *idx, lower_ngg_ms_state *s)
739 {
740 if (!outputs_mask)
741 return;
742
743 nir_def *ring = nir_load_ring_attr_amd(b);
744 nir_def *off = nir_load_ring_attr_offset_amd(b);
745 nir_def *zero = nir_imm_int(b, 0);
746
747 u_foreach_bit64 (slot, outputs_mask) {
748 if (s->vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31)
749 continue;
750
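      /* Consecutive parameters are 16 * 32 bytes apart in the swizzled
       * attribute ring, hence the soffset stride between parameter slots.
       */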
751 nir_def *soffset = nir_iadd_imm(b, off, s->vs_output_param_offset[slot] * 16 * 32);
752 nir_def *store_val = nir_undef(b, 4, 32);
753 unsigned store_val_components = 0;
754 for (unsigned c = 0; c < 4; ++c) {
755 if (s->out.outputs[slot][c]) {
756 store_val = nir_vector_insert_imm(b, store_val, s->out.outputs[slot][c], c);
757 store_val_components = c + 1;
758 }
759 }
760
761 store_val = nir_trim_vector(b, store_val, store_val_components);
762 nir_store_buffer_amd(b, store_val, ring, zero, soffset, idx,
763 .memory_modes = nir_var_shader_out,
764 .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
765 }
766 }
767
768 static nir_def *
ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s)
770 {
771 /* Primitive connectivity data: describes which vertices the primitive uses. */
772 nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
773 nir_def *indices_loaded = NULL;
774 nir_def *cull_flag = NULL;
775
776 if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
777 nir_def *indices[3] = {0};
778 for (unsigned c = 0; c < s->vertices_per_prim; ++c)
779 indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
780 indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
781 } else {
782 indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
783 indices_loaded = nir_u2u32(b, indices_loaded);
784 }
785
786 if (s->uses_cull_flags) {
787 nir_def *loaded_cull_flag = NULL;
788 if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
789 loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
790 else
791 loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
792
793 cull_flag = nir_i2b(b, loaded_cull_flag);
794 }
795
796 nir_def *indices[3];
797 nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
798
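   /* Clamp the indices so they never exceed the number of vertices
    * actually emitted by this workgroup.
    */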
799 for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
800 indices[i] = nir_channel(b, indices_loaded, i);
801 indices[i] = nir_umin(b, indices[i], max_vtx_idx);
802 }
803
804 return ac_nir_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, s->gfx_level);
805 }
806
807 static nir_def *
ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s)
809 {
810 nir_def *prim_exp_arg_ch2 = NULL;
811
812 if (outputs_mask) {
813 /* When layer, viewport etc. are per-primitive, they need to be encoded in
814 * the primitive export instruction's second channel. The encoding is:
815 *
816 * --- GFX10.3 ---
817 * bits 31..30: VRS rate Y
818 * bits 29..28: VRS rate X
819 * bits 23..20: viewport
820 * bits 19..17: layer
821 *
822 * --- GFX11 ---
823 * bits 31..28: VRS rate enum
824 * bits 23..20: viewport
825 * bits 12..00: layer
826 */
827 prim_exp_arg_ch2 = nir_imm_int(b, 0);
828
829 if (outputs_mask & VARYING_BIT_LAYER) {
830 nir_def *layer =
831 nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_LAYER][0], s->gfx_level >= GFX11 ? 0 : 17);
832 prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer);
833 }
834
835 if (outputs_mask & VARYING_BIT_VIEWPORT) {
836 nir_def *view = nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_VIEWPORT][0], 20);
837 prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view);
838 }
839
840 if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) {
841 nir_def *rate = s->out.outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0];
842 prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate);
843 }
844 }
845
846 return prim_exp_arg_ch2;
847 }
848
849 static void
ms_prim_gen_query(nir_builder *b,
                  nir_def *invocation_index,
                  nir_def *num_prm,
                  lower_ngg_ms_state *s)
854 {
855 if (!s->has_query)
856 return;
857
858 nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
859 {
860 nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
861 {
862 nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0);
863 }
864 nir_pop_if(b, if_shader_query);
865 }
866 nir_pop_if(b, if_invocation_index_zero);
867 }
868
869 static void
ms_invocation_query(nir_builder *b,
                    nir_def *invocation_index,
                    lower_ngg_ms_state *s)
873 {
874 if (!s->has_query)
875 return;
876
877 nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
878 {
879 nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b));
880 {
881 nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size));
882 }
883 nir_pop_if(b, if_pipeline_query);
884 }
885 nir_pop_if(b, if_invocation_index_zero);
886 }
887
888 static void
emit_ms_vertex(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
               uint64_t per_vertex_outputs, lower_ngg_ms_state *s)
891 {
892 ms_emit_arrayed_outputs(b, index, per_vertex_outputs, s);
893
894 if (exports) {
895 ac_nir_export_position(b, s->gfx_level, s->clipdist_enable_mask,
896 !s->has_param_exports, false, true,
897 s->per_vertex_outputs | VARYING_BIT_POS, &s->out, row);
898 }
899
900 if (parameters) {
901 /* Export generic attributes on GFX10.3
902 * (On GFX11 they are already stored in the attribute ring.)
903 */
904 if (s->has_param_exports && s->gfx_level == GFX10_3) {
905 ac_nir_export_parameters(b, s->vs_output_param_offset, per_vertex_outputs, 0, &s->out);
906 }
907
908 /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
909 if (s->gfx_level >= GFX11 && (per_vertex_outputs & MS_VERT_ARG_EXP_MASK))
910 ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & MS_VERT_ARG_EXP_MASK, index, s);
911 }
912 }
913
914 static void
emit_ms_primitive(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
                  uint64_t per_primitive_outputs, lower_ngg_ms_state *s)
917 {
918 ms_emit_arrayed_outputs(b, index, per_primitive_outputs, s);
919
920 /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
921 if (s->insert_layer_output) {
922 s->out.outputs[VARYING_SLOT_LAYER][0] = nir_load_view_index(b);
923 s->out.infos[VARYING_SLOT_LAYER].as_sysval_mask |= 1;
924 }
925
926 if (exports) {
927 const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK;
928 nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
929 nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, index, num_vtx, s);
930 nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s);
931
932 nir_def *prim_exp_arg = prim_exp_arg_ch2 ?
933 nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1;
934
935 ac_nir_export_primitive(b, prim_exp_arg, row);
936 }
937
938 if (parameters) {
939 /* Export generic attributes on GFX10.3
940 * (On GFX11 they are already stored in the attribute ring.)
941 */
942 if (s->has_param_exports && s->gfx_level == GFX10_3) {
943 ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0, &s->out);
944 }
945
946 /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
947 if (s->gfx_level >= GFX11)
948 ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, index, s);
949 }
950 }
951
952 static void
emit_ms_outputs(nir_builder *b, nir_def *invocation_index, nir_def *row_start,
                nir_def *count, bool exports, bool parameters, uint64_t mask,
                void (*cb)(nir_builder *, nir_def *, nir_def *, bool, bool,
                           uint64_t, lower_ngg_ms_state *),
                lower_ngg_ms_state *s)
958 {
959 if (cb == &emit_ms_primitive ? s->prim_multirow_export : s->vert_multirow_export) {
960 assert(s->hw_workgroup_size % s->wave_size == 0);
961 const unsigned num_waves = s->hw_workgroup_size / s->wave_size;
962
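      /* Multi-row export: there are more vertices/primitives than HW invocations,
       * so loop, advancing the index by the HW workgroup size and the export row
       * by the number of waves on each iteration.
       */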
963 nir_loop *row_loop = nir_push_loop(b);
964 {
965 nir_block *preheader = nir_cf_node_as_block(nir_cf_node_prev(&row_loop->cf_node));
966
967 nir_phi_instr *index = nir_phi_instr_create(b->shader);
968 nir_phi_instr *row = nir_phi_instr_create(b->shader);
969 nir_def_init(&index->instr, &index->def, 1, 32);
970 nir_def_init(&row->instr, &row->def, 1, 32);
971
972 nir_phi_instr_add_src(index, preheader, invocation_index);
973 nir_phi_instr_add_src(row, preheader, row_start);
974
975 nir_if *if_break = nir_push_if(b, nir_uge(b, &index->def, count));
976 {
977 nir_jump(b, nir_jump_break);
978 }
979 nir_pop_if(b, if_break);
980
981 cb(b, &index->def, &row->def, exports, parameters, mask, s);
982
983 nir_block *body = nir_cursor_current_block(b->cursor);
984 nir_phi_instr_add_src(index, body,
985 nir_iadd_imm(b, &index->def, s->hw_workgroup_size));
986 nir_phi_instr_add_src(row, body,
987 nir_iadd_imm(b, &row->def, num_waves));
988
989 nir_instr_insert_before_cf_list(&row_loop->body, &row->instr);
990 nir_instr_insert_before_cf_list(&row_loop->body, &index->instr);
991 }
992 nir_pop_loop(b, row_loop);
993 } else {
994 nir_def *has_output = nir_ilt(b, invocation_index, count);
995 nir_if *if_has_output = nir_push_if(b, has_output);
996 {
997 cb(b, invocation_index, row_start, exports, parameters, mask, s);
998 }
999 nir_pop_if(b, if_has_output);
1000 }
1001 }
1002
1003 static void
emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
1005 {
1006 /* We assume there is always a single end block in the shader. */
1007 nir_block *last_block = nir_impl_last_block(b->impl);
1008 b->cursor = nir_after_block(last_block);
1009
1010 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1011 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);
1012
1013 nir_def *num_prm;
1014 nir_def *num_vtx;
1015
1016 set_ms_final_output_counts(b, s, &num_prm, &num_vtx);
1017
1018 nir_def *invocation_index = nir_load_local_invocation_index(b);
1019
1020 ms_prim_gen_query(b, invocation_index, num_prm, s);
1021
1022 nir_def *row_start = NULL;
1023 if (s->fast_launch_2)
1024 row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b);
1025
1026 /* Load vertex/primitive attributes from shared memory and
1027 * emit store_output intrinsics for them.
1028 *
1029 * Contrary to the semantics of the API mesh shader, these are now
1030 * compliant with NGG HW semantics, meaning that these store the
1031 * current thread's vertex attributes in a way the HW can export.
1032 */
1033
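   /* Outputs that live in the attribute ring were already stored there when the
    * API shader wrote them, so exclude them (and the special outputs) from the
    * reload and re-export below.
    */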
1034 uint64_t per_vertex_outputs =
1035 s->per_vertex_outputs & ~s->layout.attr_ring.vtx_attr.mask;
1036 uint64_t per_primitive_outputs =
1037 s->per_primitive_outputs & ~s->layout.attr_ring.prm_attr.mask & ~SPECIAL_MS_OUT_MASK;
1038
1039 /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
1040 if (s->insert_layer_output) {
1041 b->shader->info.outputs_written |= VARYING_BIT_LAYER;
1042 b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
1043 per_primitive_outputs |= VARYING_BIT_LAYER;
1044 }
1045
1046 const bool has_special_param_exports =
1047 (per_vertex_outputs & MS_VERT_ARG_EXP_MASK) ||
1048 (per_primitive_outputs & MS_PRIM_ARG_EXP_MASK);
1049
1050 const bool wait_attr_ring = must_wait_attr_ring(s->gfx_level, has_special_param_exports);
1051
1052 /* Export vertices. */
1053 if ((per_vertex_outputs & ~VARYING_BIT_POS) || !wait_attr_ring) {
1054 emit_ms_outputs(b, invocation_index, row_start, num_vtx, !wait_attr_ring, true,
1055 per_vertex_outputs, &emit_ms_vertex, s);
1056 }
1057
1058 /* Export primitives. */
1059 if (per_primitive_outputs || !wait_attr_ring) {
1060 emit_ms_outputs(b, invocation_index, row_start, num_prm, !wait_attr_ring, true,
1061 per_primitive_outputs, &emit_ms_primitive, s);
1062 }
1063
1064 /* When we need to wait for attribute ring stores, we emit both position and primitive
1065 * export instructions after a barrier to make sure both per-vertex and per-primitive
1066 * attribute ring stores are finished before the GPU starts rasterization.
1067 */
1068 if (wait_attr_ring) {
1069 /* Wait for attribute stores to finish. */
1070 nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
1071 .memory_scope = SCOPE_DEVICE,
1072 .memory_semantics = NIR_MEMORY_RELEASE,
1073 .memory_modes = nir_var_shader_out);
1074
1075 /* Position/primitive export only */
1076 emit_ms_outputs(b, invocation_index, row_start, num_vtx, true, false,
1077 per_vertex_outputs, &emit_ms_vertex, s);
1078 emit_ms_outputs(b, invocation_index, row_start, num_prm, true, false,
1079 per_primitive_outputs, &emit_ms_primitive, s);
1080 }
1081 }
1082
1083 static void
handle_smaller_ms_api_workgroup(nir_builder *b,
                                lower_ngg_ms_state *s)
1086 {
1087 if (s->api_workgroup_size >= s->hw_workgroup_size)
1088 return;
1089
1090 /* Handle barriers manually when the API workgroup
1091 * size is less than the HW workgroup size.
1092 *
1093 * The problem is that the real workgroup launched on NGG HW
1094 * will be larger than the size specified by the API, and the
1095 * extra waves need to keep up with barriers in the API waves.
1096 *
1097 * There are 2 different cases:
1098 * 1. The whole API workgroup fits in a single wave.
1099 * We can shrink the barriers to subgroup scope and
1100 * don't need to insert any extra ones.
1101 * 2. The API workgroup occupies multiple waves, but not
1102 * all. In this case, we emit code that consumes every
1103 * barrier on the extra waves.
1104 */
1105 assert(s->hw_workgroup_size % s->wave_size == 0);
1106 bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
1107 bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
1108 bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
1109
1110 unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
1111 unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
1112
1113 /* Scan the shader for workgroup barriers. */
1114 if (scan_barriers) {
1115 bool has_any_workgroup_barriers = false;
1116
1117 nir_foreach_block(block, b->impl) {
1118 nir_foreach_instr_safe(instr, block) {
1119 if (instr->type != nir_instr_type_intrinsic)
1120 continue;
1121
1122 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1123 bool is_workgroup_barrier =
1124 intrin->intrinsic == nir_intrinsic_barrier &&
1125 nir_intrinsic_execution_scope(intrin) == SCOPE_WORKGROUP;
1126
1127 if (!is_workgroup_barrier)
1128 continue;
1129
1130 if (can_shrink_barriers) {
1131 /* Every API invocation runs in the first wave.
1132 * In this case, we can change the barriers to subgroup scope
1133 * and avoid adding additional barriers.
1134 */
1135 nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP);
1136 nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP);
1137 } else {
1138 has_any_workgroup_barriers = true;
1139 }
1140 }
1141 }
1142
1143 need_additional_barriers &= has_any_workgroup_barriers;
1144 }
1145
1146 /* Extract the full control flow of the shader. */
1147 nir_cf_list extracted;
1148 nir_cf_extract(&extracted, nir_before_impl(b->impl),
1149 nir_after_cf_list(&b->impl->body));
1150 b->cursor = nir_before_impl(b->impl);
1151
1152 /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */
1153 nir_def *invocation_index = nir_load_local_invocation_index(b);
1154 nir_def *zero = nir_imm_int(b, 0);
1155
1156 if (need_additional_barriers) {
      /* The first invocation stores the number of API waves in flight. */
1158 nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
1159 {
1160 nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
1161 }
1162 nir_pop_if(b, if_first_in_workgroup);
1163
1164 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1165 .memory_scope = SCOPE_WORKGROUP,
1166 .memory_semantics = NIR_MEMORY_ACQ_REL,
1167 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1168 }
1169
1170 nir_def *has_api_ms_invocation = nir_ult_imm(b, invocation_index, s->api_workgroup_size);
1171 nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
1172 {
1173 nir_cf_reinsert(&extracted, b->cursor);
1174 b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);
1175
1176 if (need_additional_barriers) {
1177 /* One invocation in each API wave decrements the number of API waves in flight. */
1178 nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
1179 {
1180 nir_shared_atomic(b, 32, zero, nir_imm_int(b, -1u),
1181 .base = api_waves_in_flight_addr,
1182 .atomic_op = nir_atomic_op_iadd);
1183 }
1184 nir_pop_if(b, if_elected_again);
1185
1186 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1187 .memory_scope = SCOPE_WORKGROUP,
1188 .memory_semantics = NIR_MEMORY_ACQ_REL,
1189 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1190 }
1191
1192 ms_invocation_query(b, invocation_index, s);
1193 }
1194 nir_pop_if(b, if_has_api_ms_invocation);
1195
1196 if (need_additional_barriers) {
1197 /* Make sure that waves that don't run any API invocations execute
1198 * the same amount of barriers as those that do.
1199 *
1200 * We do this by executing a barrier until the number of API waves
1201 * in flight becomes zero.
1202 */
1203 nir_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
1204 nir_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
1205 nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
1206 {
1207 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
1208 {
1209 nir_loop *loop = nir_push_loop(b);
1210 {
1211 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1212 .memory_scope = SCOPE_WORKGROUP,
1213 .memory_semantics = NIR_MEMORY_ACQ_REL,
1214 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1215
1216 nir_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
1217 nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
1218 {
1219 nir_jump(b, nir_jump_break);
1220 }
1221 nir_pop_if(b, if_break);
1222 }
1223 nir_pop_loop(b, loop);
1224 }
1225 nir_pop_if(b, if_elected);
1226 }
1227 nir_pop_if(b, if_wave_has_no_api_ms);
1228 }
1229 }
1230
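/* Move the output at the highest set location from one memory region to another.
 * Used to spill outputs from LDS to the VRAM scratch ring when LDS runs out.
 */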
1231 static void
ms_move_output(ms_out_part *from, ms_out_part *to)
1233 {
1234 uint64_t loc = util_logbase2_64(from->mask);
1235 uint64_t bit = BITFIELD64_BIT(loc);
1236 from->mask ^= bit;
1237 to->mask |= bit;
1238 }
1239
1240 static void
ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
                                   unsigned max_vertices,
                                   unsigned max_primitives)
1244 {
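   /* Every output slot occupies 16 bytes (one vec4) per vertex or primitive. */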
1245 uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
1246 uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
1247 l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
1248 l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
1249
1250 uint32_t scratch_ring_vtx_attr_size =
1251 util_bitcount64(l->scratch_ring.vtx_attr.mask) * max_vertices * 16;
1252 l->scratch_ring.prm_attr.addr =
1253 ALIGN(l->scratch_ring.vtx_attr.addr + scratch_ring_vtx_attr_size, 16);
1254 }
1255
1256 static ms_out_mem_layout
ms_calculate_output_layout(enum amd_gfx_level gfx_level, unsigned api_shared_size,
                           uint64_t per_vertex_output_mask, uint64_t per_primitive_output_mask,
                           uint64_t cross_invocation_output_access, unsigned max_vertices,
                           unsigned max_primitives, unsigned vertices_per_prim)
1261 {
1262 /* These outputs always need export instructions and can't use the attributes ring. */
1263 const uint64_t always_export_mask =
1264 VARYING_BIT_POS | VARYING_BIT_CULL_DIST0 | VARYING_BIT_CULL_DIST1 | VARYING_BIT_CLIP_DIST0 |
1265 VARYING_BIT_CLIP_DIST1 | VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT |
1266 VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_LAYER |
1267 BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) |
1268 BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1269
1270 const bool use_attr_ring = gfx_level >= GFX11;
1271 const uint64_t attr_ring_per_vertex_output_mask =
1272 use_attr_ring ? per_vertex_output_mask & ~always_export_mask : 0;
1273 const uint64_t attr_ring_per_primitive_output_mask =
1274 use_attr_ring ? per_primitive_output_mask & ~always_export_mask : 0;
1275
1276 const uint64_t lds_per_vertex_output_mask =
1277 per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & cross_invocation_output_access &
1278 ~SPECIAL_MS_OUT_MASK;
1279 const uint64_t lds_per_primitive_output_mask =
1280 per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
1281 cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
1282
1283 const bool cross_invocation_indices =
1284 cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
1285 const bool cross_invocation_cull_primitive =
1286 cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1287
1288 /* Shared memory used by the API shader. */
1289 ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
1290
1291 /* GFX11+: use attribute ring for all generic attributes. */
1292 l.attr_ring.vtx_attr.mask = attr_ring_per_vertex_output_mask;
1293 l.attr_ring.prm_attr.mask = attr_ring_per_primitive_output_mask;
1294
1295 /* Outputs without cross-invocation access can be stored in variables. */
1296 l.var.vtx_attr.mask =
1297 per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & ~cross_invocation_output_access;
1298 l.var.prm_attr.mask = per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
1299 ~cross_invocation_output_access;
1300
1301 /* Workgroup information, see ms_workgroup_* for the layout. */
1302 l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
1303 l.lds.total_size = l.lds.workgroup_info_addr + 16;
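
   /* Resulting LDS layout, in increasing address order:
    * API shared memory | workgroup info (16 bytes) | per-vertex attributes |
    * per-primitive attributes | primitive indices | cull flags,
    * where the last two are only allocated when they need cross-invocation access.
    * Each section is aligned to 16 bytes.
    */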
1304
1305 /* Per-vertex and per-primitive output attributes.
1306 * Outputs without cross-invocation access are not included here.
1307 * First, try to put all outputs into LDS (shared memory).
1308 * If they don't fit, try to move them to VRAM one by one.
1309 */
1310 l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
1311 l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
1312 l.lds.prm_attr.mask = lds_per_primitive_output_mask;
1313 ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
1314
1315 /* NGG shaders can only address up to 32K LDS memory.
1316 * The spec requires us to allow the application to use at least up to 28K
1317 * shared memory. Additionally, we reserve 2K for driver internal use
    * (e.g. primitive indices and such, see below).
1319 *
1320 * Move the outputs that do not fit LDS, to VRAM.
1321 * Start with per-primitive attributes, because those are grouped at the end.
1322 */
1323 const unsigned usable_lds_kbytes =
1324 (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31;
1325 while (l.lds.total_size >= usable_lds_kbytes * 1024) {
1326 if (l.lds.prm_attr.mask)
1327 ms_move_output(&l.lds.prm_attr, &l.scratch_ring.prm_attr);
1328 else if (l.lds.vtx_attr.mask)
1329 ms_move_output(&l.lds.vtx_attr, &l.scratch_ring.vtx_attr);
1330 else
1331 unreachable("API shader uses too much shared memory.");
1332
1333 ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
1334 }
1335
1336 if (cross_invocation_indices) {
1337 /* Indices: flat array of 8-bit vertex indices for each primitive. */
1338 l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
1339 l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
1340 }
1341
1342 if (cross_invocation_cull_primitive) {
1343 /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
1344 l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
1345 l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
1346 }
1347
1348 /* NGG is only allowed to address up to 32K of LDS. */
1349 assert(l.lds.total_size <= 32 * 1024);
1350 return l;
1351 }
1352
1353 void
ac_nir_lower_ngg_mesh(nir_shader *shader,
                      enum amd_gfx_level gfx_level,
                      uint32_t clipdist_enable_mask,
                      const uint8_t *vs_output_param_offset,
                      bool has_param_exports,
                      bool *out_needs_scratch_ring,
                      unsigned wave_size,
                      unsigned hw_workgroup_size,
                      bool multiview,
                      bool has_query,
                      bool fast_launch_2)
1365 {
1366 unsigned vertices_per_prim =
1367 mesa_vertices_per_prim(shader->info.mesh.primitive_type);
1368
1369 uint64_t per_vertex_outputs =
1370 shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK;
1371 uint64_t per_primitive_outputs =
1372 shader->info.per_primitive_outputs & shader->info.outputs_written;
1373
1374 /* Whether the shader uses CullPrimitiveEXT */
1375 bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1376 /* Can't handle indirect register addressing, pretend as if they were cross-invocation. */
1377 uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
1378 shader->info.outputs_accessed_indirectly;
1379
1380 unsigned max_vertices = shader->info.mesh.max_vertices_out;
1381 unsigned max_primitives = shader->info.mesh.max_primitives_out;
1382
1383 ms_out_mem_layout layout = ms_calculate_output_layout(
1384 gfx_level, shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
1385 cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
1386
1387 shader->info.shared_size = layout.lds.total_size;
1388 *out_needs_scratch_ring = layout.scratch_ring.vtx_attr.mask || layout.scratch_ring.prm_attr.mask;
1389
1390 /* The workgroup size that is specified by the API shader may be different
1391 * from the size of the workgroup that actually runs on the HW, due to the
    * limitations of NGG: each lane can emit at most one vertex and one primitive.
1393 *
1394 * Therefore, we must make sure that when the API workgroup size is smaller,
1395 * we don't run the API shader on more HW invocations than is necessary.
1396 */
1397 unsigned api_workgroup_size = shader->info.workgroup_size[0] *
1398 shader->info.workgroup_size[1] *
1399 shader->info.workgroup_size[2];
1400
1401 lower_ngg_ms_state state = {
1402 .layout = layout,
1403 .wave_size = wave_size,
1404 .per_vertex_outputs = per_vertex_outputs,
1405 .per_primitive_outputs = per_primitive_outputs,
1406 .vertices_per_prim = vertices_per_prim,
1407 .api_workgroup_size = api_workgroup_size,
1408 .hw_workgroup_size = hw_workgroup_size,
1409 .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
1410 .uses_cull_flags = uses_cull,
1411 .gfx_level = gfx_level,
1412 .fast_launch_2 = fast_launch_2,
1413 .vert_multirow_export = fast_launch_2 && max_vertices > hw_workgroup_size,
1414 .prim_multirow_export = fast_launch_2 && max_primitives > hw_workgroup_size,
1415 .clipdist_enable_mask = clipdist_enable_mask,
1416 .vs_output_param_offset = vs_output_param_offset,
1417 .has_param_exports = has_param_exports,
1418 .has_query = has_query,
1419 };
1420
1421 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1422 assert(impl);
1423
1424 state.vertex_count_var =
1425 nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
1426 state.primitive_count_var =
1427 nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");
1428
1429 nir_builder builder = nir_builder_at(nir_before_impl(impl));
1430 nir_builder *b = &builder; /* This is to avoid the & */
1431
1432 handle_smaller_ms_api_workgroup(b, &state);
1433 if (!fast_launch_2)
1434 ms_emit_legacy_workgroup_index(b, &state);
1435 ms_create_same_invocation_vars(b, &state);
1436 nir_metadata_preserve(impl, nir_metadata_none);
1437
1438 lower_ms_intrinsics(shader, &state);
1439
1440 emit_ms_finale(b, &state);
1441 nir_metadata_preserve(impl, nir_metadata_none);
1442
1443 /* Cleanup */
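   /* The lowering introduced local variables, vector ops and phis;
    * clean these up so that later passes see plain scalar SSA.
    */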
1444 nir_lower_vars_to_ssa(shader);
1445 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1446 nir_lower_alu_to_scalar(shader, NULL, NULL);
1447 nir_lower_phis_to_scalar(shader, true);
1448
   /* Optimize load_local_invocation_index. When the API workgroup is smaller than
    * the HW workgroup, local_invocation_id isn't initialized for all lanes, so the
    * optimization can't be applied to load_local_invocation_index in that case.
    */
1453 if (fast_launch_2 && api_workgroup_size == hw_workgroup_size &&
1454 ((shader->info.workgroup_size[0] == 1) + (shader->info.workgroup_size[1] == 1) +
1455 (shader->info.workgroup_size[2] == 1)) == 2) {
1456 nir_lower_compute_system_values_options csv_options = {
1457 .lower_local_invocation_index = true,
1458 };
1459 nir_lower_compute_system_values(shader, &csv_options);
1460 }
1461
1462 nir_validate_shader(shader, "after emitting NGG MS");
1463 }
1464