1 /*
2 * Copyright © 2021 Valve Corporation
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "ac_nir.h"
8 #include "ac_nir_helpers.h"
9 #include "ac_gpu_info.h"
10
11 #include "nir_builder.h"
12
/* Mask of the special mesh shader output slots: the primitive count,
 * the primitive indices and the per-primitive cull flag.
 */
#define SPECIAL_MS_OUT_MASK                           \
   (BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE) |     \
    BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) |  \
    BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT))
17
/* Output slots (layer, viewport, primitive shading rate) that
 * ms_prim_exp_arg_ch2 can encode into the primitive export argument.
 */
#define MS_PRIM_ARG_EXP_MASK               \
   (VARYING_BIT_PRIMITIVE_SHADING_RATE |   \
    VARYING_BIT_VIEWPORT |                 \
    VARYING_BIT_LAYER)
22
/* Mask of the clip distance, cull distance and point size output slots.
 * NOTE(review): presumably these are the per-vertex outputs passed as export
 * arguments; confirm against the users of this mask.
 */
#define MS_VERT_ARG_EXP_MASK    \
   (VARYING_BIT_PSIZ |          \
    VARYING_BIT_CLIP_DIST0 |    \
    VARYING_BIT_CLIP_DIST1 |    \
    VARYING_BIT_CULL_DIST0 |    \
    VARYING_BIT_CULL_DIST1)
29
/* LDS layout of the Mesh Shader workgroup info.
 * Each constant is a byte offset relative to the workgroup info base address;
 * every entry occupies one dword.
 */
enum {
   lds_ms_num_prims = 0,      /* DW0: number of primitives */
   lds_ms_num_vtx = 4,        /* DW1: number of vertices */
   lds_ms_wg_index = 8,       /* DW2: workgroup index within the current dispatch */
   lds_ms_num_api_waves = 12, /* DW3: number of API workgroups in flight */
};
41
/* Potential storage location of a Mesh Shader output. */
typedef enum {
   ms_out_mode_lds,          /* Stored in LDS. */
   ms_out_mode_scratch_ring, /* Stored in the VRAM mesh shader scratch ring. */
   ms_out_mode_attr_ring,    /* Stored in the VRAM attribute ring. */
   ms_out_mode_var,          /* Stored in a shader-local variable. */
} ms_out_mode;
49
/* One part of the output layout: a set of output slots
 * and the base address where their storage begins.
 */
typedef struct {
   /* Mask of output locations */
   uint64_t mask;
   /* Base address */
   uint32_t addr;
} ms_out_part;
55
56 typedef struct
57 {
58 /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
59 struct {
60 uint32_t workgroup_info_addr;
61 ms_out_part vtx_attr;
62 ms_out_part prm_attr;
63 uint32_t indices_addr;
64 uint32_t cull_flags_addr;
65 uint32_t total_size;
66 } lds;
67
68 /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS.
69 * Not to be confused with scratch memory.
70 */
71 struct {
72 ms_out_part vtx_attr;
73 ms_out_part prm_attr;
74 } scratch_ring;
75
76 /* VRAM attributes ring (supported GPUs only) for all non-position outputs.
77 * We don't have to reload attributes from this ring at the end of the shader.
78 */
79 struct {
80 ms_out_part vtx_attr;
81 ms_out_part prm_attr;
82 } attr_ring;
83
84 /* Outputs without cross-invocation access can be stored in variables. */
85 struct {
86 ms_out_part vtx_attr;
87 ms_out_part prm_attr;
88 } var;
89 } ms_out_mem_layout;
90
91 typedef struct
92 {
93 const struct radeon_info *hw_info;
94 bool fast_launch_2;
95 bool vert_multirow_export;
96 bool prim_multirow_export;
97
98 ms_out_mem_layout layout;
99 uint64_t per_vertex_outputs;
100 uint64_t per_primitive_outputs;
101 unsigned vertices_per_prim;
102
103 unsigned wave_size;
104 unsigned api_workgroup_size;
105 unsigned hw_workgroup_size;
106
107 nir_def *workgroup_index;
108 nir_variable *out_variables[VARYING_SLOT_MAX * 4];
109 nir_variable *primitive_count_var;
110 nir_variable *vertex_count_var;
111
112 ac_nir_prerast_out out;
113
114 /* True if the lowering needs to insert the layer output. */
115 bool insert_layer_output;
116 /* True if cull flags are used */
117 bool uses_cull_flags;
118
119 uint32_t clipdist_enable_mask;
120 const uint8_t *vs_output_param_offset;
121 bool has_param_exports;
122
123 /* True if the lowering needs to insert shader query. */
124 bool has_query;
125 } lower_ngg_ms_state;
126
127 static void
ms_store_prim_indices(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)128 ms_store_prim_indices(nir_builder *b,
129 nir_intrinsic_instr *intrin,
130 lower_ngg_ms_state *s)
131 {
132 /* EXT_mesh_shader primitive indices: array of vectors.
133 * They don't count as per-primitive outputs, but the array is indexed
134 * by the primitive index, so they are practically per-primitive.
135 */
136 assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
137 assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
138
139 const unsigned component_offset = nir_intrinsic_component(intrin);
140 nir_def *store_val = intrin->src[0].ssa;
141 assert(store_val->num_components <= 3);
142
143 if (store_val->num_components > s->vertices_per_prim)
144 store_val = nir_trim_vector(b, store_val, s->vertices_per_prim);
145
146 if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
147 for (unsigned c = 0; c < store_val->num_components; ++c) {
148 const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset;
149 nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1);
150 }
151 return;
152 }
153
154 nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
155 nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
156
157 /* The max vertex count is 256, so these indices always fit 8 bits.
158 * To reduce LDS use, store these as a flat array of 8-bit values.
159 */
160 nir_store_shared(b, nir_u2u8(b, store_val), offset, .base = s->layout.lds.indices_addr + component_offset);
161 }
162
163 static void
ms_store_cull_flag(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)164 ms_store_cull_flag(nir_builder *b,
165 nir_intrinsic_instr *intrin,
166 lower_ngg_ms_state *s)
167 {
168 /* EXT_mesh_shader cull primitive: per-primitive bool. */
169 assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
170 assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
171 assert(nir_intrinsic_component(intrin) == 0);
172 assert(nir_intrinsic_write_mask(intrin) == 1);
173
174 nir_def *store_val = intrin->src[0].ssa;
175
176 assert(store_val->num_components == 1);
177 assert(store_val->bit_size == 1);
178
179 if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) {
180 nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, store_val), 0x1);
181 return;
182 }
183
184 nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
185 nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
186
187 /* To reduce LDS use, store these as an array of 8-bit values. */
188 nir_store_shared(b, nir_b2i8(b, store_val), offset, .base = s->layout.lds.cull_flags_addr);
189 }
190
191 static nir_def *
ms_arrayed_output_base_addr(nir_builder * b,nir_def * arr_index,unsigned mapped_location,unsigned num_arrayed_outputs)192 ms_arrayed_output_base_addr(nir_builder *b,
193 nir_def *arr_index,
194 unsigned mapped_location,
195 unsigned num_arrayed_outputs)
196 {
197 /* Address offset of the array item (vertex or primitive). */
198 unsigned arr_index_stride = num_arrayed_outputs * 16u;
199 nir_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);
200
201 /* IO address offset within the vertex or primitive data. */
202 unsigned io_offset = mapped_location * 16u;
203 nir_def *io_off = nir_imm_int(b, io_offset);
204
205 return nir_iadd_nuw(b, arr_index_off, io_off);
206 }
207
208 static void
update_ms_output_info(const nir_io_semantics io_sem,const nir_src * base_offset_src,const uint32_t write_mask,const unsigned component_offset,const unsigned bit_size,const ms_out_part * out,lower_ngg_ms_state * s)209 update_ms_output_info(const nir_io_semantics io_sem,
210 const nir_src *base_offset_src,
211 const uint32_t write_mask,
212 const unsigned component_offset,
213 const unsigned bit_size,
214 const ms_out_part *out,
215 lower_ngg_ms_state *s)
216 {
217 const uint32_t components_mask = write_mask << component_offset;
218
219 /* 64-bit outputs should have already been lowered to 32-bit. */
220 assert(bit_size <= 32);
221 assert(components_mask <= 0xf);
222
223 /* When the base offset is constant, only mark the components of the current slot as used.
224 * Otherwise, mark the components of all possibly affected slots as used.
225 */
226 const unsigned base_off_start = nir_src_is_const(*base_offset_src) ? nir_src_as_uint(*base_offset_src) : 0;
227 const unsigned num_slots = nir_src_is_const(*base_offset_src) ? 1 : io_sem.num_slots;
228
229 for (unsigned base_off = base_off_start; base_off < num_slots; ++base_off) {
230 ac_nir_prerast_per_output_info *info = &s->out.infos[io_sem.location + base_off];
231 info->components_mask |= components_mask;
232
233 if (!io_sem.no_sysval_output)
234 info->as_sysval_mask |= components_mask;
235 if (!io_sem.no_varying)
236 info->as_varying_mask |= components_mask;
237 }
238 }
239
240 static const ms_out_part *
ms_get_out_layout_part(unsigned location,shader_info * info,ms_out_mode * out_mode,lower_ngg_ms_state * s)241 ms_get_out_layout_part(unsigned location,
242 shader_info *info,
243 ms_out_mode *out_mode,
244 lower_ngg_ms_state *s)
245 {
246 uint64_t mask = BITFIELD64_BIT(location);
247
248 if (info->per_primitive_outputs & mask) {
249 if (mask & s->layout.lds.prm_attr.mask) {
250 *out_mode = ms_out_mode_lds;
251 return &s->layout.lds.prm_attr;
252 } else if (mask & s->layout.scratch_ring.prm_attr.mask) {
253 *out_mode = ms_out_mode_scratch_ring;
254 return &s->layout.scratch_ring.prm_attr;
255 } else if (mask & s->layout.attr_ring.prm_attr.mask) {
256 *out_mode = ms_out_mode_attr_ring;
257 return &s->layout.attr_ring.prm_attr;
258 } else if (mask & s->layout.var.prm_attr.mask) {
259 *out_mode = ms_out_mode_var;
260 return &s->layout.var.prm_attr;
261 }
262 } else {
263 if (mask & s->layout.lds.vtx_attr.mask) {
264 *out_mode = ms_out_mode_lds;
265 return &s->layout.lds.vtx_attr;
266 } else if (mask & s->layout.scratch_ring.vtx_attr.mask) {
267 *out_mode = ms_out_mode_scratch_ring;
268 return &s->layout.scratch_ring.vtx_attr;
269 } else if (mask & s->layout.attr_ring.vtx_attr.mask) {
270 *out_mode = ms_out_mode_attr_ring;
271 return &s->layout.attr_ring.vtx_attr;
272 } else if (mask & s->layout.var.vtx_attr.mask) {
273 *out_mode = ms_out_mode_var;
274 return &s->layout.var.vtx_attr;
275 }
276 }
277
278 unreachable("Couldn't figure out mesh shader output mode.");
279 }
280
281 static void
ms_store_arrayed_output(nir_builder * b,nir_src * base_off_src,nir_def * store_val,nir_def * arr_index,const nir_io_semantics io_sem,const unsigned component_offset,const unsigned write_mask,lower_ngg_ms_state * s)282 ms_store_arrayed_output(nir_builder *b,
283 nir_src *base_off_src,
284 nir_def *store_val,
285 nir_def *arr_index,
286 const nir_io_semantics io_sem,
287 const unsigned component_offset,
288 const unsigned write_mask,
289 lower_ngg_ms_state *s)
290 {
291 ms_out_mode out_mode;
292 const ms_out_part *out = ms_get_out_layout_part(io_sem.location, &b->shader->info, &out_mode, s);
293 update_ms_output_info(io_sem, base_off_src, write_mask, component_offset, store_val->bit_size, out, s);
294
295 bool hi_16b = io_sem.high_16bits;
296 bool lo_16b = !hi_16b && store_val->bit_size == 16;
297
298 unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, io_sem.location));
299 unsigned num_outputs = util_bitcount64(out->mask);
300 unsigned const_off = out->addr + component_offset * 4 + (hi_16b ? 2 : 0);
301
302 nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
303 nir_def *base_offset = base_off_src->ssa;
304 nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
305 nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
306
307 if (out_mode == ms_out_mode_lds) {
308 nir_store_shared(b, store_val, addr, .base = const_off,
309 .write_mask = write_mask, .align_mul = 16,
310 .align_offset = const_off % 16);
311 } else if (out_mode == ms_out_mode_scratch_ring) {
312 nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
313 nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
314 nir_def *zero = nir_imm_int(b, 0);
315 nir_store_buffer_amd(b, store_val, ring, addr, off, zero,
316 .base = const_off,
317 .write_mask = write_mask,
318 .memory_modes = nir_var_shader_out,
319 .access = ACCESS_COHERENT);
320 } else if (out_mode == ms_out_mode_attr_ring) {
321 /* Store params straight to the attribute ring.
322 * Even though the access pattern may not be the most optimal,
323 * this is still much better than reserving LDS and losing waves.
324 * (Also much better than storing and reloading from the scratch ring.)
325 */
326 unsigned param_offset = s->vs_output_param_offset[io_sem.location];
327 nir_def *ring = nir_load_ring_attr_amd(b);
328 nir_def *soffset = nir_load_ring_attr_offset_amd(b);
329 nir_store_buffer_amd(b, store_val, ring, base_addr_off, soffset, arr_index,
330 .base = const_off + param_offset * 16,
331 .write_mask = write_mask,
332 .memory_modes = nir_var_shader_out,
333 .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
334 } else if (out_mode == ms_out_mode_var) {
335 unsigned write_mask_32 = write_mask;
336 if (store_val->bit_size > 32) {
337 /* Split 64-bit store values to 32-bit components. */
338 store_val = nir_bitcast_vector(b, store_val, 32);
339 /* Widen the write mask so it is in 32-bit components. */
340 write_mask_32 = util_widen_mask(write_mask, store_val->bit_size / 32);
341 }
342
343 u_foreach_bit(comp, write_mask_32) {
344 unsigned idx = io_sem.location * 4 + comp + component_offset;
345 nir_def *val = nir_channel(b, store_val, comp);
346 nir_def *v = nir_load_var(b, s->out_variables[idx]);
347
348 if (lo_16b) {
349 nir_def *var_hi = nir_unpack_32_2x16_split_y(b, v);
350 val = nir_pack_32_2x16_split(b, val, var_hi);
351 } else if (hi_16b) {
352 nir_def *var_lo = nir_unpack_32_2x16_split_x(b, v);
353 val = nir_pack_32_2x16_split(b, var_lo, val);
354 }
355
356 nir_store_var(b, s->out_variables[idx], val, 0x1);
357 }
358 } else {
359 unreachable("Invalid MS output mode for store");
360 }
361 }
362
363 static void
ms_store_arrayed_output_intrin(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)364 ms_store_arrayed_output_intrin(nir_builder *b,
365 nir_intrinsic_instr *intrin,
366 lower_ngg_ms_state *s)
367 {
368 const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
369
370 if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
371 ms_store_prim_indices(b, intrin, s);
372 return;
373 } else if (io_sem.location == VARYING_SLOT_CULL_PRIMITIVE) {
374 ms_store_cull_flag(b, intrin, s);
375 return;
376 }
377
378 unsigned component_offset = nir_intrinsic_component(intrin);
379 unsigned write_mask = nir_intrinsic_write_mask(intrin);
380
381 nir_def *store_val = intrin->src[0].ssa;
382 nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
383 nir_src *base_off_src = nir_get_io_offset_src(intrin);
384
385 if (store_val->bit_size < 32) {
386 /* Split 16-bit output stores to ensure each 16-bit component is stored
387 * in the correct location, without overwriting the other 16 bits there.
388 */
389 u_foreach_bit(c, write_mask) {
390 nir_def *store_component = nir_channel(b, store_val, c);
391 ms_store_arrayed_output(b, base_off_src, store_component, arr_index, io_sem, c + component_offset, 1, s);
392 }
393 } else {
394 ms_store_arrayed_output(b, base_off_src, store_val, arr_index, io_sem, component_offset, write_mask, s);
395 }
396 }
397
398 static nir_def *
ms_load_arrayed_output(nir_builder * b,nir_def * arr_index,nir_def * base_offset,unsigned location,unsigned component_offset,unsigned num_components,unsigned load_bit_size,lower_ngg_ms_state * s)399 ms_load_arrayed_output(nir_builder *b,
400 nir_def *arr_index,
401 nir_def *base_offset,
402 unsigned location,
403 unsigned component_offset,
404 unsigned num_components,
405 unsigned load_bit_size,
406 lower_ngg_ms_state *s)
407 {
408 ms_out_mode out_mode;
409 const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
410
411 unsigned component_addr_off = component_offset * 4;
412 unsigned num_outputs = util_bitcount64(out->mask);
413 unsigned const_off = out->addr + component_offset * 4;
414
415 /* Use compacted location instead of the original semantic location. */
416 unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
417
418 nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
419 nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
420 nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
421
422 if (out_mode == ms_out_mode_lds) {
423 return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
424 .align_offset = component_addr_off % 16,
425 .base = const_off);
426 } else if (out_mode == ms_out_mode_scratch_ring) {
427 nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
428 nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
429 nir_def *zero = nir_imm_int(b, 0);
430 return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, zero,
431 .base = const_off,
432 .memory_modes = nir_var_shader_out,
433 .access = ACCESS_COHERENT);
434 } else if (out_mode == ms_out_mode_var) {
435 assert(load_bit_size == 32);
436 nir_def *arr[8] = {0};
437 for (unsigned comp = 0; comp < num_components; ++comp) {
438 unsigned idx = location * 4 + comp + component_addr_off;
439 arr[comp] = nir_load_var(b, s->out_variables[idx]);
440 }
441 return nir_vec(b, arr, num_components);
442 } else {
443 unreachable("Invalid MS output mode for load");
444 }
445 }
446
447 static nir_def *
lower_ms_load_workgroup_index(nir_builder * b,UNUSED nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)448 lower_ms_load_workgroup_index(nir_builder *b,
449 UNUSED nir_intrinsic_instr *intrin,
450 lower_ngg_ms_state *s)
451 {
452 return s->workgroup_index;
453 }
454
455 static nir_def *
lower_ms_set_vertex_and_primitive_count(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)456 lower_ms_set_vertex_and_primitive_count(nir_builder *b,
457 nir_intrinsic_instr *intrin,
458 lower_ngg_ms_state *s)
459 {
460 /* If either the number of vertices or primitives is zero, set both of them to zero. */
461 nir_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
462 nir_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
463 nir_def *zero = nir_imm_int(b, 0);
464 nir_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
465 num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
466 num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);
467
468 nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
469 nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
470
471 return NIR_LOWER_INSTR_PROGRESS_REPLACE;
472 }
473
474 static nir_def *
update_ms_barrier(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)475 update_ms_barrier(nir_builder *b,
476 nir_intrinsic_instr *intrin,
477 lower_ngg_ms_state *s)
478 {
479 /* Output loads and stores are lowered to shared memory access,
480 * so we have to update the barriers to also reflect this.
481 */
482 unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
483 if (mem_modes & nir_var_shader_out)
484 mem_modes |= nir_var_mem_shared;
485 else
486 return NULL;
487
488 nir_intrinsic_set_memory_modes(intrin, mem_modes);
489
490 return NIR_LOWER_INSTR_PROGRESS;
491 }
492
493 static nir_def *
lower_ms_intrinsic(nir_builder * b,nir_instr * instr,void * state)494 lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
495 {
496 lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;
497
498 if (instr->type != nir_instr_type_intrinsic)
499 return NULL;
500
501 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
502
503 switch (intrin->intrinsic) {
504 case nir_intrinsic_store_per_vertex_output:
505 case nir_intrinsic_store_per_primitive_output:
506 ms_store_arrayed_output_intrin(b, intrin, s);
507 return NIR_LOWER_INSTR_PROGRESS_REPLACE;
508 case nir_intrinsic_barrier:
509 return update_ms_barrier(b, intrin, s);
510 case nir_intrinsic_load_workgroup_index:
511 return lower_ms_load_workgroup_index(b, intrin, s);
512 case nir_intrinsic_set_vertex_and_primitive_count:
513 return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
514 default:
515 unreachable("Not a lowerable mesh shader intrinsic.");
516 }
517 }
518
519 static bool
filter_ms_intrinsic(const nir_instr * instr,UNUSED const void * s)520 filter_ms_intrinsic(const nir_instr *instr,
521 UNUSED const void *s)
522 {
523 if (instr->type != nir_instr_type_intrinsic)
524 return false;
525
526 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
527 return intrin->intrinsic == nir_intrinsic_store_output ||
528 intrin->intrinsic == nir_intrinsic_load_output ||
529 intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
530 intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
531 intrin->intrinsic == nir_intrinsic_barrier ||
532 intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
533 intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
534 }
535
536 static void
lower_ms_intrinsics(nir_shader * shader,lower_ngg_ms_state * s)537 lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
538 {
539 nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
540 }
541
542 static void
ms_emit_arrayed_outputs(nir_builder * b,nir_def * invocation_index,uint64_t mask,lower_ngg_ms_state * s)543 ms_emit_arrayed_outputs(nir_builder *b,
544 nir_def *invocation_index,
545 uint64_t mask,
546 lower_ngg_ms_state *s)
547 {
548 nir_def *zero = nir_imm_int(b, 0);
549
550 u_foreach_bit64(slot, mask) {
551 /* Should not occur here, handled separately. */
552 assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
553
554 unsigned component_mask = s->out.infos[slot].components_mask;
555
556 while (component_mask) {
557 int start_comp = 0, num_components = 1;
558 u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
559
560 nir_def *load =
561 ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
562 num_components, 32, s);
563
564 for (int i = 0; i < num_components; i++)
565 s->out.outputs[slot][start_comp + i] = nir_channel(b, load, i);
566 }
567 }
568 }
569
570 static void
ms_create_same_invocation_vars(nir_builder * b,lower_ngg_ms_state * s)571 ms_create_same_invocation_vars(nir_builder *b, lower_ngg_ms_state *s)
572 {
573 /* Initialize NIR variables for same-invocation outputs. */
574 uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;
575
576 u_foreach_bit64(slot, same_invocation_output_mask) {
577 for (unsigned comp = 0; comp < 4; ++comp) {
578 unsigned idx = slot * 4 + comp;
579 s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
580 }
581 }
582 }
583
584 static void
ms_emit_legacy_workgroup_index(nir_builder * b,lower_ngg_ms_state * s)585 ms_emit_legacy_workgroup_index(nir_builder *b, lower_ngg_ms_state *s)
586 {
587 /* Workgroup ID should have been lowered to workgroup index. */
588 assert(!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID));
589
590 /* No need to do anything if the shader doesn't use the workgroup index. */
591 if (!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX))
592 return;
593
594 b->cursor = nir_before_impl(b->impl);
595
596 /* Legacy fast launch mode (FAST_LAUNCH=1):
597 *
598 * The HW doesn't support a proper workgroup index for vertex processing stages,
599 * so we use the vertex ID which is equivalent to the index of the current workgroup
600 * within the current dispatch.
601 *
602 * Due to the register programming of mesh shaders, this value is only filled for
603 * the first invocation of the first wave. To let other waves know, we use LDS.
604 */
605 nir_def *workgroup_index = nir_load_vertex_id_zero_base(b);
606
607 if (s->api_workgroup_size <= s->wave_size) {
608 /* API workgroup is small, so we don't need to use LDS. */
609 s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
610 return;
611 }
612
613 unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
614
615 nir_def *zero = nir_imm_int(b, 0);
616 nir_def *dont_care = nir_undef(b, 1, 32);
617 nir_def *loaded_workgroup_index = NULL;
618
619 /* Use elect to make sure only 1 invocation uses LDS. */
620 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
621 {
622 nir_def *wave_id = nir_load_subgroup_id(b);
623 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
624 {
625 nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
626 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
627 .memory_scope = SCOPE_WORKGROUP,
628 .memory_semantics = NIR_MEMORY_ACQ_REL,
629 .memory_modes = nir_var_mem_shared);
630 }
631 nir_push_else(b, if_wave_0);
632 {
633 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
634 .memory_scope = SCOPE_WORKGROUP,
635 .memory_semantics = NIR_MEMORY_ACQ_REL,
636 .memory_modes = nir_var_mem_shared);
637 loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
638 }
639 nir_pop_if(b, if_wave_0);
640
641 workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
642 }
643 nir_pop_if(b, if_elected);
644
645 workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
646 s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
647 }
648
649 static void
set_ms_final_output_counts(nir_builder * b,lower_ngg_ms_state * s,nir_def ** out_num_prm,nir_def ** out_num_vtx)650 set_ms_final_output_counts(nir_builder *b,
651 lower_ngg_ms_state *s,
652 nir_def **out_num_prm,
653 nir_def **out_num_vtx)
654 {
655 /* The spec allows the numbers to be divergent, and in that case we need to
656 * use the values from the first invocation. Also the HW requires us to set
657 * both to 0 if either was 0.
658 *
659 * These are already done by the lowering.
660 */
661 nir_def *num_prm = nir_load_var(b, s->primitive_count_var);
662 nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
663
664 if (s->hw_workgroup_size <= s->wave_size) {
665 /* Single-wave mesh shader workgroup. */
666 ac_nir_ngg_alloc_vertices_and_primitives(b, num_vtx, num_prm, false);
667
668 *out_num_prm = num_prm;
669 *out_num_vtx = num_vtx;
670 return;
671 }
672
673 /* Multi-wave mesh shader workgroup:
674 * We need to use LDS to distribute the correct values to the other waves.
675 *
676 * TODO:
677 * If we can prove that the values are workgroup-uniform, we can skip this
678 * and just use whatever the current wave has. However, NIR divergence analysis
679 * currently doesn't support this.
680 */
681
682 nir_def *zero = nir_imm_int(b, 0);
683
684 nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
685 {
686 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
687 {
688 nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
689 .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
690 }
691 nir_pop_if(b, if_elected);
692
693 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
694 .memory_scope = SCOPE_WORKGROUP,
695 .memory_semantics = NIR_MEMORY_ACQ_REL,
696 .memory_modes = nir_var_mem_shared);
697
698 ac_nir_ngg_alloc_vertices_and_primitives(b, num_vtx, num_prm, false);
699 }
700 nir_push_else(b, if_wave_0);
701 {
702 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
703 .memory_scope = SCOPE_WORKGROUP,
704 .memory_semantics = NIR_MEMORY_ACQ_REL,
705 .memory_modes = nir_var_mem_shared);
706
707 nir_def *prm_vtx = NULL;
708 nir_def *dont_care_2x32 = nir_undef(b, 2, 32);
709 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
710 {
711 prm_vtx = nir_load_shared(b, 2, 32, zero,
712 .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
713 }
714 nir_pop_if(b, if_elected);
715
716 prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
717 num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
718 num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));
719
720 nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
721 nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
722 }
723 nir_pop_if(b, if_wave_0);
724
725 *out_num_prm = nir_load_var(b, s->primitive_count_var);
726 *out_num_vtx = nir_load_var(b, s->vertex_count_var);
727 }
728
729 static void
ms_emit_attribute_ring_output_stores(nir_builder * b,const uint64_t outputs_mask,nir_def * idx,lower_ngg_ms_state * s)730 ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask,
731 nir_def *idx, lower_ngg_ms_state *s)
732 {
733 if (!outputs_mask)
734 return;
735
736 nir_def *ring = nir_load_ring_attr_amd(b);
737 nir_def *off = nir_load_ring_attr_offset_amd(b);
738 nir_def *zero = nir_imm_int(b, 0);
739
740 u_foreach_bit64 (slot, outputs_mask) {
741 if (s->vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31)
742 continue;
743
744 nir_def *soffset = nir_iadd_imm(b, off, s->vs_output_param_offset[slot] * 16 * 32);
745 nir_def *store_val = nir_undef(b, 4, 32);
746 unsigned store_val_components = 0;
747 for (unsigned c = 0; c < 4; ++c) {
748 if (s->out.outputs[slot][c]) {
749 store_val = nir_vector_insert_imm(b, store_val, s->out.outputs[slot][c], c);
750 store_val_components = c + 1;
751 }
752 }
753
754 store_val = nir_trim_vector(b, store_val, store_val_components);
755 nir_store_buffer_amd(b, store_val, ring, zero, soffset, idx,
756 .memory_modes = nir_var_shader_out,
757 .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
758 }
759 }
760
761 static nir_def *
ms_prim_exp_arg_ch1(nir_builder * b,nir_def * invocation_index,nir_def * num_vtx,lower_ngg_ms_state * s)762 ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s)
763 {
764 /* Primitive connectivity data: describes which vertices the primitive uses. */
765 nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
766 nir_def *indices_loaded = NULL;
767 nir_def *cull_flag = NULL;
768
769 if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
770 nir_def *indices[3] = {0};
771 for (unsigned c = 0; c < s->vertices_per_prim; ++c)
772 indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
773 indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
774 } else {
775 indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
776 indices_loaded = nir_u2u32(b, indices_loaded);
777 }
778
779 if (s->uses_cull_flags) {
780 nir_def *loaded_cull_flag = NULL;
781 if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
782 loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
783 else
784 loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
785
786 cull_flag = nir_i2b(b, loaded_cull_flag);
787 }
788
789 nir_def *indices[3];
790 nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
791
792 for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
793 indices[i] = nir_channel(b, indices_loaded, i);
794 indices[i] = nir_umin(b, indices[i], max_vtx_idx);
795 }
796
797 return ac_nir_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, s->hw_info->gfx_level);
798 }
799
800 static nir_def *
ms_prim_exp_arg_ch2(nir_builder * b,uint64_t outputs_mask,lower_ngg_ms_state * s)801 ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s)
802 {
803 nir_def *prim_exp_arg_ch2 = NULL;
804
805 if (outputs_mask) {
806 /* When layer, viewport etc. are per-primitive, they need to be encoded in
807 * the primitive export instruction's second channel. The encoding is:
808 *
809 * --- GFX10.3 ---
810 * bits 31..30: VRS rate Y
811 * bits 29..28: VRS rate X
812 * bits 23..20: viewport
813 * bits 19..17: layer
814 *
815 * --- GFX11 ---
816 * bits 31..28: VRS rate enum
817 * bits 23..20: viewport
818 * bits 12..00: layer
819 */
820 prim_exp_arg_ch2 = nir_imm_int(b, 0);
821
822 if (outputs_mask & VARYING_BIT_LAYER) {
823 nir_def *layer =
824 nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_LAYER][0], s->hw_info->gfx_level >= GFX11 ? 0 : 17);
825 prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer);
826 }
827
828 if (outputs_mask & VARYING_BIT_VIEWPORT) {
829 nir_def *view = nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_VIEWPORT][0], 20);
830 prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view);
831 }
832
833 if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) {
834 nir_def *rate = s->out.outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0];
835 prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate);
836 }
837 }
838
839 return prim_exp_arg_ch2;
840 }
841
842 static void
ms_prim_gen_query(nir_builder * b,nir_def * invocation_index,nir_def * num_prm,lower_ngg_ms_state * s)843 ms_prim_gen_query(nir_builder *b,
844 nir_def *invocation_index,
845 nir_def *num_prm,
846 lower_ngg_ms_state *s)
847 {
848 if (!s->has_query)
849 return;
850
851 nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
852 {
853 nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
854 {
855 nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0);
856 }
857 nir_pop_if(b, if_shader_query);
858 }
859 nir_pop_if(b, if_invocation_index_zero);
860 }
861
862 static void
ms_invocation_query(nir_builder * b,nir_def * invocation_index,lower_ngg_ms_state * s)863 ms_invocation_query(nir_builder *b,
864 nir_def *invocation_index,
865 lower_ngg_ms_state *s)
866 {
867 if (!s->has_query)
868 return;
869
870 nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
871 {
872 nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b));
873 {
874 nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size));
875 }
876 nir_pop_if(b, if_pipeline_query);
877 }
878 nir_pop_if(b, if_invocation_index_zero);
879 }
880
881 static void
emit_ms_vertex(nir_builder * b,nir_def * index,nir_def * row,bool exports,bool parameters,uint64_t per_vertex_outputs,lower_ngg_ms_state * s)882 emit_ms_vertex(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
883 uint64_t per_vertex_outputs, lower_ngg_ms_state *s)
884 {
885 ms_emit_arrayed_outputs(b, index, per_vertex_outputs, s);
886
887 if (exports) {
888 ac_nir_export_position(b, s->hw_info->gfx_level, s->clipdist_enable_mask,
889 !s->has_param_exports, false, true,
890 s->per_vertex_outputs | VARYING_BIT_POS, &s->out, row);
891 }
892
893 if (parameters) {
894 /* Export generic attributes when there is no attribute ring. */
895 if (s->has_param_exports && !s->hw_info->has_attr_ring) {
896 ac_nir_export_parameters(b, s->vs_output_param_offset, per_vertex_outputs, 0, &s->out);
897 }
898
899 /* Also store special outputs to the attribute ring so PS can load them. */
900 if (s->hw_info->has_attr_ring && (per_vertex_outputs & MS_VERT_ARG_EXP_MASK))
901 ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & MS_VERT_ARG_EXP_MASK, index, s);
902 }
903 }
904
905 static void
emit_ms_primitive(nir_builder * b,nir_def * index,nir_def * row,bool exports,bool parameters,uint64_t per_primitive_outputs,lower_ngg_ms_state * s)906 emit_ms_primitive(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
907 uint64_t per_primitive_outputs, lower_ngg_ms_state *s)
908 {
909 ms_emit_arrayed_outputs(b, index, per_primitive_outputs, s);
910
911 /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
912 if (s->insert_layer_output) {
913 s->out.outputs[VARYING_SLOT_LAYER][0] = nir_load_view_index(b);
914 s->out.infos[VARYING_SLOT_LAYER].as_sysval_mask |= 1;
915 }
916
917 if (exports) {
918 const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK;
919 nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
920 nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, index, num_vtx, s);
921 nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s);
922
923 nir_def *prim_exp_arg = prim_exp_arg_ch2 ?
924 nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1;
925
926 ac_nir_export_primitive(b, prim_exp_arg, row);
927 }
928
929 if (parameters) {
930 /* Export generic attributes when there is no attribute ring. */
931 if (s->has_param_exports && !s->hw_info->has_attr_ring) {
932 ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0, &s->out);
933 }
934
935 /* Also store special outputs to the attribute ring so PS can load them. */
936 if (s->hw_info->has_attr_ring)
937 ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, index, s);
938 }
939 }
940
941 static void
emit_ms_outputs(nir_builder * b,nir_def * invocation_index,nir_def * row_start,nir_def * count,bool exports,bool parameters,uint64_t mask,void (* cb)(nir_builder *,nir_def *,nir_def *,bool,bool,uint64_t,lower_ngg_ms_state *),lower_ngg_ms_state * s)942 emit_ms_outputs(nir_builder *b, nir_def *invocation_index, nir_def *row_start,
943 nir_def *count, bool exports, bool parameters, uint64_t mask,
944 void (*cb)(nir_builder *, nir_def *, nir_def *, bool, bool,
945 uint64_t, lower_ngg_ms_state *),
946 lower_ngg_ms_state *s)
947 {
948 if (cb == &emit_ms_primitive ? s->prim_multirow_export : s->vert_multirow_export) {
949 assert(s->hw_workgroup_size % s->wave_size == 0);
950 const unsigned num_waves = s->hw_workgroup_size / s->wave_size;
951
952 nir_loop *row_loop = nir_push_loop(b);
953 {
954 nir_block *preheader = nir_cf_node_as_block(nir_cf_node_prev(&row_loop->cf_node));
955
956 nir_phi_instr *index = nir_phi_instr_create(b->shader);
957 nir_phi_instr *row = nir_phi_instr_create(b->shader);
958 nir_def_init(&index->instr, &index->def, 1, 32);
959 nir_def_init(&row->instr, &row->def, 1, 32);
960
961 nir_phi_instr_add_src(index, preheader, invocation_index);
962 nir_phi_instr_add_src(row, preheader, row_start);
963
964 nir_if *if_break = nir_push_if(b, nir_uge(b, &index->def, count));
965 {
966 nir_jump(b, nir_jump_break);
967 }
968 nir_pop_if(b, if_break);
969
970 cb(b, &index->def, &row->def, exports, parameters, mask, s);
971
972 nir_block *body = nir_cursor_current_block(b->cursor);
973 nir_phi_instr_add_src(index, body,
974 nir_iadd_imm(b, &index->def, s->hw_workgroup_size));
975 nir_phi_instr_add_src(row, body,
976 nir_iadd_imm(b, &row->def, num_waves));
977
978 nir_instr_insert_before_cf_list(&row_loop->body, &row->instr);
979 nir_instr_insert_before_cf_list(&row_loop->body, &index->instr);
980 }
981 nir_pop_loop(b, row_loop);
982 } else {
983 nir_def *has_output = nir_ilt(b, invocation_index, count);
984 nir_if *if_has_output = nir_push_if(b, has_output);
985 {
986 cb(b, invocation_index, row_start, exports, parameters, mask, s);
987 }
988 nir_pop_if(b, if_has_output);
989 }
990 }
991
992 static void
emit_ms_finale(nir_builder * b,lower_ngg_ms_state * s)993 emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
994 {
995 /* We assume there is always a single end block in the shader. */
996 nir_block *last_block = nir_impl_last_block(b->impl);
997 b->cursor = nir_after_block(last_block);
998
999 nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1000 .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);
1001
1002 nir_def *num_prm;
1003 nir_def *num_vtx;
1004
1005 set_ms_final_output_counts(b, s, &num_prm, &num_vtx);
1006
1007 nir_def *invocation_index = nir_load_local_invocation_index(b);
1008
1009 ms_prim_gen_query(b, invocation_index, num_prm, s);
1010
1011 nir_def *row_start = NULL;
1012 if (s->fast_launch_2)
1013 row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b);
1014
1015 /* Load vertex/primitive attributes from shared memory and
1016 * emit store_output intrinsics for them.
1017 *
1018 * Contrary to the semantics of the API mesh shader, these are now
1019 * compliant with NGG HW semantics, meaning that these store the
1020 * current thread's vertex attributes in a way the HW can export.
1021 */
1022
1023 uint64_t per_vertex_outputs =
1024 s->per_vertex_outputs & ~s->layout.attr_ring.vtx_attr.mask;
1025 uint64_t per_primitive_outputs =
1026 s->per_primitive_outputs & ~s->layout.attr_ring.prm_attr.mask & ~SPECIAL_MS_OUT_MASK;
1027
1028 /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
1029 if (s->insert_layer_output) {
1030 b->shader->info.outputs_written |= VARYING_BIT_LAYER;
1031 b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
1032 per_primitive_outputs |= VARYING_BIT_LAYER;
1033 }
1034
1035 const bool has_special_param_exports =
1036 (per_vertex_outputs & MS_VERT_ARG_EXP_MASK) ||
1037 (per_primitive_outputs & MS_PRIM_ARG_EXP_MASK);
1038 const bool wait_attr_ring = has_special_param_exports && s->hw_info->has_attr_ring_wait_bug;
1039
1040 /* Export vertices. */
1041 if ((per_vertex_outputs & ~VARYING_BIT_POS) || !wait_attr_ring) {
1042 emit_ms_outputs(b, invocation_index, row_start, num_vtx, !wait_attr_ring, true,
1043 per_vertex_outputs, &emit_ms_vertex, s);
1044 }
1045
1046 /* Export primitives. */
1047 if (per_primitive_outputs || !wait_attr_ring) {
1048 emit_ms_outputs(b, invocation_index, row_start, num_prm, !wait_attr_ring, true,
1049 per_primitive_outputs, &emit_ms_primitive, s);
1050 }
1051
1052 /* When we need to wait for attribute ring stores, we emit both position and primitive
1053 * export instructions after a barrier to make sure both per-vertex and per-primitive
1054 * attribute ring stores are finished before the GPU starts rasterization.
1055 */
1056 if (wait_attr_ring) {
1057 /* Wait for attribute stores to finish. */
1058 nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
1059 .memory_scope = SCOPE_DEVICE,
1060 .memory_semantics = NIR_MEMORY_RELEASE,
1061 .memory_modes = nir_var_shader_out);
1062
1063 /* Position/primitive export only */
1064 emit_ms_outputs(b, invocation_index, row_start, num_vtx, true, false,
1065 per_vertex_outputs, &emit_ms_vertex, s);
1066 emit_ms_outputs(b, invocation_index, row_start, num_prm, true, false,
1067 per_primitive_outputs, &emit_ms_primitive, s);
1068 }
1069 }
1070
1071 static void
handle_smaller_ms_api_workgroup(nir_builder * b,lower_ngg_ms_state * s)1072 handle_smaller_ms_api_workgroup(nir_builder *b,
1073 lower_ngg_ms_state *s)
1074 {
1075 if (s->api_workgroup_size >= s->hw_workgroup_size)
1076 return;
1077
1078 /* Handle barriers manually when the API workgroup
1079 * size is less than the HW workgroup size.
1080 *
1081 * The problem is that the real workgroup launched on NGG HW
1082 * will be larger than the size specified by the API, and the
1083 * extra waves need to keep up with barriers in the API waves.
1084 *
1085 * There are 2 different cases:
1086 * 1. The whole API workgroup fits in a single wave.
1087 * We can shrink the barriers to subgroup scope and
1088 * don't need to insert any extra ones.
1089 * 2. The API workgroup occupies multiple waves, but not
1090 * all. In this case, we emit code that consumes every
1091 * barrier on the extra waves.
1092 */
1093 assert(s->hw_workgroup_size % s->wave_size == 0);
1094 bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
1095 bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
1096 bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
1097
1098 unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
1099 unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
1100
1101 /* Scan the shader for workgroup barriers. */
1102 if (scan_barriers) {
1103 bool has_any_workgroup_barriers = false;
1104
1105 nir_foreach_block(block, b->impl) {
1106 nir_foreach_instr_safe(instr, block) {
1107 if (instr->type != nir_instr_type_intrinsic)
1108 continue;
1109
1110 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1111 bool is_workgroup_barrier =
1112 intrin->intrinsic == nir_intrinsic_barrier &&
1113 nir_intrinsic_execution_scope(intrin) == SCOPE_WORKGROUP;
1114
1115 if (!is_workgroup_barrier)
1116 continue;
1117
1118 if (can_shrink_barriers) {
1119 /* Every API invocation runs in the first wave.
1120 * In this case, we can change the barriers to subgroup scope
1121 * and avoid adding additional barriers.
1122 */
1123 nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP);
1124 nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP);
1125 } else {
1126 has_any_workgroup_barriers = true;
1127 }
1128 }
1129 }
1130
1131 need_additional_barriers &= has_any_workgroup_barriers;
1132 }
1133
1134 /* Extract the full control flow of the shader. */
1135 nir_cf_list extracted;
1136 nir_cf_extract(&extracted, nir_before_impl(b->impl),
1137 nir_after_cf_list(&b->impl->body));
1138 b->cursor = nir_before_impl(b->impl);
1139
1140 /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */
1141 nir_def *invocation_index = nir_load_local_invocation_index(b);
1142 nir_def *zero = nir_imm_int(b, 0);
1143
1144 if (need_additional_barriers) {
1145 /* First invocation stores 0 to number of API waves in flight. */
1146 nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
1147 {
1148 nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
1149 }
1150 nir_pop_if(b, if_first_in_workgroup);
1151
1152 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1153 .memory_scope = SCOPE_WORKGROUP,
1154 .memory_semantics = NIR_MEMORY_ACQ_REL,
1155 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1156 }
1157
1158 nir_def *has_api_ms_invocation = nir_ult_imm(b, invocation_index, s->api_workgroup_size);
1159 nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
1160 {
1161 nir_cf_reinsert(&extracted, b->cursor);
1162 b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);
1163
1164 if (need_additional_barriers) {
1165 /* One invocation in each API wave decrements the number of API waves in flight. */
1166 nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
1167 {
1168 nir_shared_atomic(b, 32, zero, nir_imm_int(b, -1u),
1169 .base = api_waves_in_flight_addr,
1170 .atomic_op = nir_atomic_op_iadd);
1171 }
1172 nir_pop_if(b, if_elected_again);
1173
1174 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1175 .memory_scope = SCOPE_WORKGROUP,
1176 .memory_semantics = NIR_MEMORY_ACQ_REL,
1177 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1178 }
1179
1180 ms_invocation_query(b, invocation_index, s);
1181 }
1182 nir_pop_if(b, if_has_api_ms_invocation);
1183
1184 if (need_additional_barriers) {
1185 /* Make sure that waves that don't run any API invocations execute
1186 * the same amount of barriers as those that do.
1187 *
1188 * We do this by executing a barrier until the number of API waves
1189 * in flight becomes zero.
1190 */
1191 nir_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
1192 nir_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
1193 nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
1194 {
1195 nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
1196 {
1197 nir_loop *loop = nir_push_loop(b);
1198 {
1199 nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
1200 .memory_scope = SCOPE_WORKGROUP,
1201 .memory_semantics = NIR_MEMORY_ACQ_REL,
1202 .memory_modes = nir_var_shader_out | nir_var_mem_shared);
1203
1204 nir_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
1205 nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
1206 {
1207 nir_jump(b, nir_jump_break);
1208 }
1209 nir_pop_if(b, if_break);
1210 }
1211 nir_pop_loop(b, loop);
1212 }
1213 nir_pop_if(b, if_elected);
1214 }
1215 nir_pop_if(b, if_wave_has_no_api_ms);
1216 }
1217 }
1218
1219 static void
ms_move_output(ms_out_part * from,ms_out_part * to)1220 ms_move_output(ms_out_part *from, ms_out_part *to)
1221 {
1222 uint64_t loc = util_logbase2_64(from->mask);
1223 uint64_t bit = BITFIELD64_BIT(loc);
1224 from->mask ^= bit;
1225 to->mask |= bit;
1226 }
1227
1228 static void
ms_calculate_arrayed_output_layout(ms_out_mem_layout * l,unsigned max_vertices,unsigned max_primitives)1229 ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
1230 unsigned max_vertices,
1231 unsigned max_primitives)
1232 {
1233 uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
1234 uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
1235 l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
1236 l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
1237
1238 uint32_t scratch_ring_vtx_attr_size =
1239 util_bitcount64(l->scratch_ring.vtx_attr.mask) * max_vertices * 16;
1240 l->scratch_ring.prm_attr.addr =
1241 ALIGN(l->scratch_ring.vtx_attr.addr + scratch_ring_vtx_attr_size, 16);
1242 }
1243
1244 static ms_out_mem_layout
ms_calculate_output_layout(const struct radeon_info * hw_info,unsigned api_shared_size,uint64_t per_vertex_output_mask,uint64_t per_primitive_output_mask,uint64_t cross_invocation_output_access,unsigned max_vertices,unsigned max_primitives,unsigned vertices_per_prim)1245 ms_calculate_output_layout(const struct radeon_info *hw_info, unsigned api_shared_size,
1246 uint64_t per_vertex_output_mask, uint64_t per_primitive_output_mask,
1247 uint64_t cross_invocation_output_access, unsigned max_vertices,
1248 unsigned max_primitives, unsigned vertices_per_prim)
1249 {
1250 /* These outputs always need export instructions and can't use the attributes ring. */
1251 const uint64_t always_export_mask =
1252 VARYING_BIT_POS | VARYING_BIT_CULL_DIST0 | VARYING_BIT_CULL_DIST1 | VARYING_BIT_CLIP_DIST0 |
1253 VARYING_BIT_CLIP_DIST1 | VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT |
1254 VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_LAYER |
1255 BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) |
1256 BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1257
1258 const bool use_attr_ring = hw_info->has_attr_ring;
1259 const uint64_t attr_ring_per_vertex_output_mask =
1260 use_attr_ring ? per_vertex_output_mask & ~always_export_mask : 0;
1261 const uint64_t attr_ring_per_primitive_output_mask =
1262 use_attr_ring ? per_primitive_output_mask & ~always_export_mask : 0;
1263
1264 const uint64_t lds_per_vertex_output_mask =
1265 per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & cross_invocation_output_access &
1266 ~SPECIAL_MS_OUT_MASK;
1267 const uint64_t lds_per_primitive_output_mask =
1268 per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
1269 cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
1270
1271 const bool cross_invocation_indices =
1272 cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
1273 const bool cross_invocation_cull_primitive =
1274 cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1275
1276 /* Shared memory used by the API shader. */
1277 ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
1278
1279 /* Use attribute ring for all generic attributes (on GPUs with an attribute ring). */
1280 l.attr_ring.vtx_attr.mask = attr_ring_per_vertex_output_mask;
1281 l.attr_ring.prm_attr.mask = attr_ring_per_primitive_output_mask;
1282
1283 /* Outputs without cross-invocation access can be stored in variables. */
1284 l.var.vtx_attr.mask =
1285 per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & ~cross_invocation_output_access;
1286 l.var.prm_attr.mask = per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
1287 ~cross_invocation_output_access;
1288
1289 /* Workgroup information, see ms_workgroup_* for the layout. */
1290 l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
1291 l.lds.total_size = l.lds.workgroup_info_addr + 16;
1292
1293 /* Per-vertex and per-primitive output attributes.
1294 * Outputs without cross-invocation access are not included here.
1295 * First, try to put all outputs into LDS (shared memory).
1296 * If they don't fit, try to move them to VRAM one by one.
1297 */
1298 l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
1299 l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
1300 l.lds.prm_attr.mask = lds_per_primitive_output_mask;
1301 ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
1302
1303 /* NGG shaders can only address up to 32K LDS memory.
1304 * The spec requires us to allow the application to use at least up to 28K
1305 * shared memory. Additionally, we reserve 2K for driver internal use
1306 * (eg. primitive indices and such, see below).
1307 *
1308 * Move the outputs that do not fit LDS, to VRAM.
1309 * Start with per-primitive attributes, because those are grouped at the end.
1310 */
1311 const unsigned usable_lds_kbytes =
1312 (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31;
1313 while (l.lds.total_size >= usable_lds_kbytes * 1024) {
1314 if (l.lds.prm_attr.mask)
1315 ms_move_output(&l.lds.prm_attr, &l.scratch_ring.prm_attr);
1316 else if (l.lds.vtx_attr.mask)
1317 ms_move_output(&l.lds.vtx_attr, &l.scratch_ring.vtx_attr);
1318 else
1319 unreachable("API shader uses too much shared memory.");
1320
1321 ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
1322 }
1323
1324 if (cross_invocation_indices) {
1325 /* Indices: flat array of 8-bit vertex indices for each primitive. */
1326 l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
1327 l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
1328 }
1329
1330 if (cross_invocation_cull_primitive) {
1331 /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
1332 l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
1333 l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
1334 }
1335
1336 /* NGG is only allowed to address up to 32K of LDS. */
1337 assert(l.lds.total_size <= 32 * 1024);
1338 return l;
1339 }
1340
1341 void
ac_nir_lower_ngg_mesh(nir_shader * shader,const struct radeon_info * hw_info,uint32_t clipdist_enable_mask,const uint8_t * vs_output_param_offset,bool has_param_exports,bool * out_needs_scratch_ring,unsigned wave_size,unsigned hw_workgroup_size,bool multiview,bool has_query,bool fast_launch_2)1342 ac_nir_lower_ngg_mesh(nir_shader *shader,
1343 const struct radeon_info *hw_info,
1344 uint32_t clipdist_enable_mask,
1345 const uint8_t *vs_output_param_offset,
1346 bool has_param_exports,
1347 bool *out_needs_scratch_ring,
1348 unsigned wave_size,
1349 unsigned hw_workgroup_size,
1350 bool multiview,
1351 bool has_query,
1352 bool fast_launch_2)
1353 {
1354 unsigned vertices_per_prim =
1355 mesa_vertices_per_prim(shader->info.mesh.primitive_type);
1356
1357 uint64_t per_vertex_outputs =
1358 shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK;
1359 uint64_t per_primitive_outputs =
1360 shader->info.per_primitive_outputs & shader->info.outputs_written;
1361
1362 /* Whether the shader uses CullPrimitiveEXT */
1363 bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
1364 /* Can't handle indirect register addressing, pretend as if they were cross-invocation. */
1365 uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
1366 shader->info.outputs_accessed_indirectly;
1367
1368 unsigned max_vertices = shader->info.mesh.max_vertices_out;
1369 unsigned max_primitives = shader->info.mesh.max_primitives_out;
1370
1371 ms_out_mem_layout layout = ms_calculate_output_layout(
1372 hw_info, shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
1373 cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
1374
1375 shader->info.shared_size = layout.lds.total_size;
1376 *out_needs_scratch_ring = layout.scratch_ring.vtx_attr.mask || layout.scratch_ring.prm_attr.mask;
1377
1378 /* The workgroup size that is specified by the API shader may be different
1379 * from the size of the workgroup that actually runs on the HW, due to the
1380 * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed.
1381 *
1382 * Therefore, we must make sure that when the API workgroup size is smaller,
1383 * we don't run the API shader on more HW invocations than is necessary.
1384 */
1385 unsigned api_workgroup_size = shader->info.workgroup_size[0] *
1386 shader->info.workgroup_size[1] *
1387 shader->info.workgroup_size[2];
1388
1389 lower_ngg_ms_state state = {
1390 .layout = layout,
1391 .wave_size = wave_size,
1392 .per_vertex_outputs = per_vertex_outputs,
1393 .per_primitive_outputs = per_primitive_outputs,
1394 .vertices_per_prim = vertices_per_prim,
1395 .api_workgroup_size = api_workgroup_size,
1396 .hw_workgroup_size = hw_workgroup_size,
1397 .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
1398 .uses_cull_flags = uses_cull,
1399 .hw_info = hw_info,
1400 .fast_launch_2 = fast_launch_2,
1401 .vert_multirow_export = fast_launch_2 && max_vertices > hw_workgroup_size,
1402 .prim_multirow_export = fast_launch_2 && max_primitives > hw_workgroup_size,
1403 .clipdist_enable_mask = clipdist_enable_mask,
1404 .vs_output_param_offset = vs_output_param_offset,
1405 .has_param_exports = has_param_exports,
1406 .has_query = has_query,
1407 };
1408
1409 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1410 assert(impl);
1411
1412 state.vertex_count_var =
1413 nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
1414 state.primitive_count_var =
1415 nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");
1416
1417 nir_builder builder = nir_builder_at(nir_before_impl(impl));
1418 nir_builder *b = &builder; /* This is to avoid the & */
1419
1420 handle_smaller_ms_api_workgroup(b, &state);
1421 if (!fast_launch_2)
1422 ms_emit_legacy_workgroup_index(b, &state);
1423 ms_create_same_invocation_vars(b, &state);
1424 nir_metadata_preserve(impl, nir_metadata_none);
1425
1426 lower_ms_intrinsics(shader, &state);
1427
1428 emit_ms_finale(b, &state);
1429 nir_metadata_preserve(impl, nir_metadata_none);
1430
1431 /* Cleanup */
1432 nir_lower_vars_to_ssa(shader);
1433 nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1434 nir_lower_alu_to_scalar(shader, NULL, NULL);
1435 nir_lower_phis_to_scalar(shader, true);
1436
1437 /* Optimize load_local_invocation_index. When the API workgroup is smaller than the HW workgroup,
1438 * local_invocation_id isn't initialized for all lanes and we can't perform this optimization for
1439 * all load_local_invocation_index.
1440 */
1441 if (fast_launch_2 && api_workgroup_size == hw_workgroup_size &&
1442 ((shader->info.workgroup_size[0] == 1) + (shader->info.workgroup_size[1] == 1) +
1443 (shader->info.workgroup_size[2] == 1)) == 2) {
1444 nir_lower_compute_system_values_options csv_options = {
1445 .lower_local_invocation_index = true,
1446 };
1447 nir_lower_compute_system_values(shader, &csv_options);
1448 }
1449
1450 nir_validate_shader(shader, "after emitting NGG MS");
1451 }
1452