/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "ac_nir_helpers.h"
#include "ac_gpu_info.h"
#include "amdgfxregs.h"
#include "nir_builder.h"
#include "nir_xfb_info.h"
#include "util/u_math.h"
#include "util/u_vector.h"

typedef struct
{
   const ac_nir_lower_ngg_options *options;

   nir_function_impl *impl;
   int const_out_vtxcnt[4];
   int const_out_prmcnt[4];
   unsigned max_num_waves;
   unsigned num_vertices_per_primitive;
   nir_def *lds_addr_gs_out_vtx;
   nir_def *lds_addr_gs_scratch;
   unsigned lds_bytes_per_gs_out_vertex;
   unsigned lds_offs_primflags;
   bool output_compile_time_known;
   bool streamout_enabled;
   /* Outputs */
   ac_nir_prerast_out out;
   /* Count per stream. */
   nir_def *vertex_count[4];
   nir_def *primitive_count[4];
} lower_ngg_gs_state;

/**
 * Return the address of the LDS storage reserved for the N'th vertex,
 * where N is in emit order, meaning:
 * - during the finale, N is the invocation_index (within the workgroup)
 * - during vertex emit, i.e. while the API GS shader invocation is running,
 *   N = invocation_index * gs_max_out_vertices + emit_idx
 *   where emit_idx is the vertex index in the current API GS invocation.
 *
 * Goals of the LDS memory layout:
 * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
 *    in uniform control flow
 * 2. Eliminate bank conflicts on read for export if, additionally, there is no
 *    culling
 * 3. Agnostic to the number of waves (since we don't know it before compiling)
 * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
 * 5. Avoid wasting memory.
 *
 * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
 * layout, elimination of bank conflicts requires that each vertex occupy an
 * odd number of dwords. We use the additional dword to store the output stream
 * index as well as a flag to indicate whether this vertex ends a primitive
 * for rasterization.
 *
 * Swizzling is required to satisfy points 1 and 2 simultaneously.
 *
 * Vertices are stored in export order (gsthread * gs_max_out_vertices + emitidx).
 * Indices are swizzled in groups of 32, which ensures point 1 without
 * disturbing point 2.
 *
 * \return an LDS pointer to type {[N x i32], [4 x i8]}
 */
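/* Illustrative sizing example (assuming the driver reserves 16 bytes per
 * written output slot; the real value of gs_out_vtx_bytes comes from the
 * driver): a GS writing two vec4 outputs would have gs_out_vtx_bytes = 32,
 * and with the extra primflag dword each vertex occupies 36 bytes = 9 dwords
 * in LDS, an odd number of dwords as required above.
 */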
static nir_def *
ngg_gs_out_vertex_addr(nir_builder *b, nir_def *out_vtx_idx, lower_ngg_gs_state *s)
{
   unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;

   /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
   if (write_stride_2exp) {
      nir_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
      nir_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
      out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
   }

   nir_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
   return nir_iadd_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
}
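
/* Worked example of the swizzle above (illustrative numbers): with
 * gs.vertices_out = 4, write_stride_2exp = ffs(4) - 1 = 2. For
 * out_vtx_idx = 37: row = 37 >> 5 = 1, swizzle = 1 & 0b11 = 1, so the vertex
 * is stored at index 37 ^ 1 = 36. Each group of 32 consecutive indices shares
 * one swizzle value, and successive groups cycle through different swizzles,
 * spreading the writes across LDS banks.
 */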

static nir_def *
ngg_gs_emit_vertex_addr(nir_builder *b, nir_def *gs_vtx_idx, lower_ngg_gs_state *s)
{
   nir_def *tid_in_tg = nir_load_local_invocation_index(b);
   nir_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
   nir_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);

   return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
}

static void
ngg_gs_clear_primflags(nir_builder *b, nir_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
{
   char name[32];
   snprintf(name, sizeof(name), "clear_primflag_idx_%u", stream);
   nir_variable *clear_primflag_idx_var = nir_local_variable_create(b->impl, glsl_uint_type(), name);

   nir_def *zero_u8 = nir_imm_zero(b, 1, 8);
   nir_store_var(b, clear_primflag_idx_var, num_vertices, 0x1u);

   nir_loop *loop = nir_push_loop(b);
   {
      nir_def *clear_primflag_idx = nir_load_var(b, clear_primflag_idx_var);
      nir_if *if_break = nir_push_if(b, nir_uge_imm(b, clear_primflag_idx, b->shader->info.gs.vertices_out));
      {
         nir_jump(b, nir_jump_break);
      }
      nir_push_else(b, if_break);
      {
         nir_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, clear_primflag_idx, s);
         nir_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream);
         nir_store_var(b, clear_primflag_idx_var, nir_iadd_imm_nuw(b, clear_primflag_idx, 1), 0x1u);
      }
      nir_pop_if(b, if_break);
   }
   nir_pop_loop(b, loop);
}

static bool
lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   ac_nir_gather_prerast_store_output_info(b, intrin, &s->out);
   nir_instr_remove(&intrin->instr);
   return true;
}

static unsigned
gs_output_component_mask_with_stream(ac_nir_prerast_per_output_info *info, unsigned stream)
{
   unsigned mask = info->components_mask;
   if (!mask)
      return 0;

   /* Clear the bits of components that don't belong to the requested stream. */
   for (int i = 0; i < 4; i++) {
      if (((info->stream >> (i * 2)) & 3) != stream)
         mask &= ~(1 << i);
   }

   return mask;
}
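
/* Example of the per-component stream decode above (illustrative values):
 * each component owns 2 bits of info->stream. With info->stream = 0b01000100
 * (components 1 and 3 on stream 1, components 0 and 2 on stream 0) and
 * components_mask = 0b1111, requesting stream 1 yields mask = 0b1010.
 */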

static bool
lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   unsigned stream = nir_intrinsic_stream_id(intrin);
   if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   nir_def *gs_emit_vtx_idx = intrin->src[0].ssa;
   nir_def *current_vtx_per_prim = intrin->src[1].ssa;
   nir_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);

   /* Store generic 32-bit outputs to LDS.
    * In case of packed 16-bit outputs, we assume they have already been packed
    * into 32-bit slots by now.
    */
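   /* LDS addressing example for the stores below (illustrative): each written
    * slot occupies 16 bytes (a full vec4). If outputs_written = POS | VAR0,
    * then POS has packed_location 0 (bytes 0..15 of the vertex) and VAR0 has
    * packed_location 1, so component 2 of VAR0 is stored at
    * base = 1 * 16 + 2 * 4 = 24.
    */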
   u_foreach_bit64(slot, b->shader->info.outputs_written) {
      const unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
      unsigned mask = gs_output_component_mask_with_stream(&s->out.infos[slot], stream);

      nir_def **output = s->out.outputs[slot];
      nir_def *undef = nir_undef(b, 1, 32);

      while (mask) {
         int start, count;
         u_bit_scan_consecutive_range(&mask, &start, &count);
         nir_def *values[4] = {0};
         for (int c = start; c < start + count; ++c) {
            if (!output[c]) {
               /* The shader hasn't written this output. */
               values[c - start] = undef;
            } else {
               assert(output[c]->bit_size == 32);
               values[c - start] = output[c];
            }
         }

         nir_def *store_val = nir_vec(b, values, (unsigned)count);
         nir_store_shared(b, store_val, gs_emit_vtx_addr,
                          .base = packed_location * 16 + start * 4,
                          .align_mul = 4);
      }

      /* Clear all outputs (they are undefined after emit_vertex) */
      memset(s->out.outputs[slot], 0, sizeof(s->out.outputs[slot]));
   }

   const unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);

   /* Store dedicated 16-bit outputs to LDS. */
   u_foreach_bit(slot, b->shader->info.outputs_written_16bit) {
      const unsigned packed_location = num_32bit_outputs +
         util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));

      const unsigned mask_lo = gs_output_component_mask_with_stream(s->out.infos_16bit_lo + slot, stream);
      const unsigned mask_hi = gs_output_component_mask_with_stream(s->out.infos_16bit_hi + slot, stream);
      unsigned mask = mask_lo | mask_hi;

      nir_def **output_lo = s->out.outputs_16bit_lo[slot];
      nir_def **output_hi = s->out.outputs_16bit_hi[slot];
      nir_def *undef = nir_undef(b, 1, 16);

      while (mask) {
         int start, count;
         u_bit_scan_consecutive_range(&mask, &start, &count);
         nir_def *values[4] = {0};
         for (int c = start; c < start + count; ++c) {
            nir_def *lo = output_lo[c] ? output_lo[c] : undef;
            nir_def *hi = output_hi[c] ? output_hi[c] : undef;

            values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
         }

         nir_def *store_val = nir_vec(b, values, (unsigned)count);
         nir_store_shared(b, store_val, gs_emit_vtx_addr,
                          .base = packed_location * 16 + start * 4,
                          .align_mul = 4);
      }

      /* Clear all outputs (they are undefined after emit_vertex) */
      memset(s->out.outputs_16bit_lo[slot], 0, sizeof(s->out.outputs_16bit_lo[slot]));
      memset(s->out.outputs_16bit_hi[slot], 0, sizeof(s->out.outputs_16bit_hi[slot]));
   }

   /* Calculate and store per-vertex primitive flags based on vertex counts:
    * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
    * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
    *          only set when the vertex also finishes the primitive
    * - bit 2: whether the vertex is live (if culling is enabled: set after culling, otherwise always 1)
    */
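   /* Example (illustrative, assuming culling is disabled so the live bit is
    * set): in a triangle strip, the 3rd emitted vertex
    * (current_vtx_per_prim = 2) completes the 1st (even) triangle, so its
    * primflag is 0b101; the 4th vertex (current_vtx_per_prim = 3) completes
    * the 2nd (odd) triangle, so its primflag is 0b111.
    */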

   nir_def *vertex_live_flag =
      !stream && s->options->can_cull
         ? nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2)
         : nir_imm_int(b, 0b100);

   nir_def *completes_prim = nir_ige_imm(b, current_vtx_per_prim, s->num_vertices_per_primitive - 1);
   nir_def *complete_flag = nir_b2i32(b, completes_prim);

   nir_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
   if (s->num_vertices_per_primitive == 3) {
      nir_def *odd = nir_iand(b, current_vtx_per_prim, complete_flag);
      nir_def *odd_flag = nir_ishl_imm(b, odd, 1);
      prim_flag = nir_ior(b, prim_flag, odd_flag);
   }

   nir_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr,
                    .base = s->lds_offs_primflags + stream,
                    .align_mul = 4, .align_offset = stream);

   nir_instr_remove(&intrin->instr);
   return true;
}

static bool
lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   /* These are not needed, we can simply remove them */
   nir_instr_remove(&intrin->instr);
   return true;
}

static bool
lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   unsigned stream = nir_intrinsic_stream_id(intrin);
   if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   s->vertex_count[stream] = intrin->src[0].ssa;
   s->primitive_count[stream] = intrin->src[1].ssa;

   /* Clear the primitive flags of non-emitted vertices */
   if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
      ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);

   nir_instr_remove(&intrin->instr);
   return true;
}

static bool
lower_ngg_gs_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
{
   lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;

   if (intrin->intrinsic == nir_intrinsic_store_output)
      return lower_ngg_gs_store_output(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
      return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
      return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
      return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);

   return false;
}

static void
lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
{
   nir_shader_intrinsics_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
}

static nir_def *
ngg_gs_process_out_primitive(nir_builder *b,
                             nir_def *exporter_tid_in_tg, nir_def *primflag_0,
                             lower_ngg_gs_state *s)
{
   /* Only bit 0 matters here - set it to 1 when the primitive should be null */
   nir_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));

   nir_def *vtx_indices[3] = {0};
   vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
   if (s->num_vertices_per_primitive >= 2)
      vtx_indices[s->num_vertices_per_primitive - 2] = nir_iadd_imm(b, exporter_tid_in_tg, -1);
   if (s->num_vertices_per_primitive == 3)
      vtx_indices[s->num_vertices_per_primitive - 3] = nir_iadd_imm(b, exporter_tid_in_tg, -2);

   if (s->num_vertices_per_primitive == 3) {
      /* The API GS outputs triangle strips, but NGG HW understands triangles.
       * We already know the triangles thanks to how we set the primitive flags,
       * but we need to make sure the vertex order is such that front/back face
       * determination is correct, and the provoking vertex is preserved.
       */
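      /* Example of the reordering below (illustrative): for an odd triangle
       * (is_odd = 1) whose exporter thread is t, the initial indices are
       * (t-2, t-1, t). With the provoking vertex first they become
       * (t-2, t, t-1); with the provoking vertex last, (t-1, t-2, t).
       * Either way, two vertices are swapped to restore the winding order
       * while the provoking vertex stays in place.
       */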

      nir_def *is_odd = nir_ubfe_imm(b, primflag_0, 1, 1);
      nir_def *provoking_vertex_index = nir_load_provoking_vtx_in_prim_amd(b);
      nir_def *provoking_vertex_first = nir_ieq_imm(b, provoking_vertex_index, 0);

      vtx_indices[0] = nir_bcsel(b, provoking_vertex_first, vtx_indices[0],
                                 nir_iadd(b, vtx_indices[0], is_odd));
      vtx_indices[1] = nir_bcsel(b, provoking_vertex_first,
                                 nir_iadd(b, vtx_indices[1], is_odd),
                                 nir_isub(b, vtx_indices[1], is_odd));
      vtx_indices[2] = nir_bcsel(b, provoking_vertex_first,
                                 nir_isub(b, vtx_indices[2], is_odd), vtx_indices[2]);
   }

   return ac_nir_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices,
                                       is_null_prim, s->options->hw_info->gfx_level);
}

static void
ngg_gs_process_out_vertex(nir_builder *b, nir_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
{
   nir_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;

   if (!s->output_compile_time_known) {
      /* Vertex compaction.
       * The current thread will export a vertex that was live in another invocation.
       * Load the index of the vertex that the current thread will have to export.
       */
      nir_def *exported_vtx_idx = nir_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1);
      exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
   }

   u_foreach_bit64(slot, b->shader->info.outputs_written) {
      const unsigned packed_location =
         util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));

      unsigned mask = gs_output_component_mask_with_stream(&s->out.infos[slot], 0);

      while (mask) {
         int start, count;
         u_bit_scan_consecutive_range(&mask, &start, &count);
         nir_def *load =
            nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
                            .base = packed_location * 16 + start * 4,
                            .align_mul = 4);

         for (int i = 0; i < count; i++)
            s->out.outputs[slot][start + i] = nir_channel(b, load, i);
      }
   }

   const unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);

   /* Dedicated 16-bit outputs. */
   u_foreach_bit(i, b->shader->info.outputs_written_16bit) {
      const unsigned packed_location = num_32bit_outputs +
         util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(i));

      const unsigned mask_lo = gs_output_component_mask_with_stream(&s->out.infos_16bit_lo[i], 0);
      const unsigned mask_hi = gs_output_component_mask_with_stream(&s->out.infos_16bit_hi[i], 0);
      unsigned mask = mask_lo | mask_hi;

      while (mask) {
         int start, count;
         u_bit_scan_consecutive_range(&mask, &start, &count);
         nir_def *load =
            nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
                            .base = packed_location * 16 + start * 4,
                            .align_mul = 4);

         for (int j = 0; j < count; j++) {
            nir_def *val = nir_channel(b, load, j);
            unsigned comp = start + j;

            if (mask_lo & BITFIELD_BIT(comp))
               s->out.outputs_16bit_lo[i][comp] = nir_unpack_32_2x16_split_x(b, val);

            if (mask_hi & BITFIELD_BIT(comp))
               s->out.outputs_16bit_hi[i][comp] = nir_unpack_32_2x16_split_y(b, val);
         }
      }
   }

   /* This should be after streamout and before exports. */
   ac_nir_clamp_vertex_color_outputs(b, &s->out);
}

/**
 * Emit NGG GS output, including vertex and primitive exports and attribute ring stores (if any).
 * The exact sequence emitted depends on the current GPU and its workarounds.
 *
 * The order mainly depends on whether the current GPU has an attribute ring, and
 * whether it has the bug that requires us to emit a wait for the attribute ring stores.
 *
 * The basic structure looks like this:
 *
 * if (has primitive) {
 *    <per-primitive processing: calculation of the primitive export argument>
 *
 *    if (!(wait for attr ring)) {
 *       <primitive export>
 *    }
 * }
 * if (has vertex) {
 *    <per-vertex processing: load each output from LDS, and perform necessary adjustments>
 *
 *    if (!(wait for attr ring)) {
 *       <vertex position exports>
 *       <vertex parameter exports>
 *    }
 * }
 * <per-vertex attribute ring stores, if the current GPU has an attribute ring>
 * if (wait for attr ring) {
 *    <barrier to wait for attribute ring stores>
 *    if (has primitive) {
 *       <primitive export>
 *    }
 *    if (has vertex) {
 *       <vertex position exports>
 *       <vertex parameter exports>
 *    }
 * }
 */
static void
ngg_gs_emit_output(nir_builder *b, nir_def *max_num_out_vtx, nir_def *max_num_out_prims,
                   nir_def *tid_in_tg, nir_def *out_vtx_lds_addr, nir_def *prim_exporter_tid_in_tg,
                   nir_def *primflag_0, lower_ngg_gs_state *s)
{
   nir_def *undef = nir_undef(b, 1, 32);

   /* Primitive processing */
   nir_def *prim_exp_arg = NULL;
   nir_if *if_process_primitive = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
   {
      prim_exp_arg = ngg_gs_process_out_primitive(b, prim_exporter_tid_in_tg, primflag_0, s);
   }
   nir_pop_if(b, if_process_primitive);
   prim_exp_arg = nir_if_phi(b, prim_exp_arg, undef);

   /* Vertex processing */
   nir_if *if_process_vertex = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
   {
      ngg_gs_process_out_vertex(b, out_vtx_lds_addr, s);
   }
   nir_pop_if(b, if_process_vertex);
   ac_nir_create_output_phis(b, b->shader->info.outputs_written, b->shader->info.outputs_written_16bit, &s->out);

   nir_if *if_export_primitive = nir_push_if(b, if_process_primitive->condition.ssa);
   {
      ac_nir_export_primitive(b, prim_exp_arg, NULL);
   }
   nir_pop_if(b, if_export_primitive);

   nir_if *if_export_vertex = nir_push_if(b, if_process_vertex->condition.ssa);
   {
      uint64_t export_outputs = b->shader->info.outputs_written | VARYING_BIT_POS;
      if (s->options->kill_pointsize)
         export_outputs &= ~VARYING_BIT_PSIZ;
      if (s->options->kill_layer)
         export_outputs &= ~VARYING_BIT_LAYER;

      ac_nir_export_position(b, s->options->hw_info->gfx_level,
                             s->options->clip_cull_dist_mask,
                             !s->options->has_param_exports,
                             s->options->force_vrs, true,
                             export_outputs, &s->out, NULL);

      if (s->options->has_param_exports && !s->options->hw_info->has_attr_ring)
         ac_nir_export_parameters(b, s->options->vs_output_param_offset,
                                  b->shader->info.outputs_written,
                                  b->shader->info.outputs_written_16bit,
                                  &s->out);
   }
   nir_pop_if(b, if_export_vertex);

   if (s->options->has_param_exports && s->options->hw_info->has_attr_ring) {
      if (s->options->hw_info->has_attr_ring_wait_bug)
         b->cursor = nir_after_cf_node_and_phis(&if_export_primitive->cf_node);

      nir_def *vertices_in_wave = nir_bit_count(b, nir_ballot(b, 1, s->options->wave_size, if_process_vertex->condition.ssa));

      ac_nir_store_parameters_to_attr_ring(b, s->options->vs_output_param_offset,
                                           b->shader->info.outputs_written,
                                           b->shader->info.outputs_written_16bit,
                                           &s->out, vertices_in_wave);

      if (s->options->hw_info->has_attr_ring_wait_bug) {
         /* Wait for attribute ring stores to finish. */
         nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
                        .memory_scope = SCOPE_DEVICE,
                        .memory_semantics = NIR_MEMORY_RELEASE,
                        .memory_modes = nir_var_mem_ssbo | nir_var_shader_out | nir_var_mem_global | nir_var_image);
      }
   }
}

static void
ngg_gs_setup_vertex_compaction(nir_builder *b, nir_def *vertex_live, nir_def *tid_in_tg,
                               nir_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
{
   assert(vertex_live->bit_size == 1);
   nir_if *if_vertex_live = nir_push_if(b, vertex_live);
   {
      /* Set up vertex compaction.
       * Save the current thread's id for the thread which will export the current vertex.
       * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this.
       */

      nir_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
      nir_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
      nir_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1);
   }
   nir_pop_if(b, if_vertex_live);
}

static nir_def *
ngg_gs_load_out_vtx_primflag(nir_builder *b, unsigned stream, nir_def *tid_in_tg,
                             nir_def *vtx_lds_addr, nir_def *max_num_out_vtx,
                             lower_ngg_gs_state *s)
{
   nir_def *zero = nir_imm_int(b, 0);

   nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
   nir_def *primflag = nir_load_shared(b, 1, 8, vtx_lds_addr,
                                       .base = s->lds_offs_primflags + stream);
   primflag = nir_u2u32(b, primflag);
   nir_pop_if(b, if_outvtx_thread);

   return nir_if_phi(b, primflag, zero);
}

static void
ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_def *last_vtxidx, nir_def *last_vtxptr,
                           nir_def *last_vtx_primflag, lower_ngg_gs_state *s,
                           nir_def *vtxptr[3])
{
   unsigned last_vtx = s->num_vertices_per_primitive - 1;
   vtxptr[last_vtx] = last_vtxptr;

   bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
   nir_def *is_odd = primitive_is_triangle ?
      nir_ubfe_imm(b, last_vtx_primflag, 1, 1) : NULL;

   for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
      nir_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));

      /* Need to swap vertex 0 and vertex 1 when the vertex 2 index is odd, to keep
       * the CW/CCW order for correct front/back face culling.
       */
      if (primitive_is_triangle)
         vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);

      vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
   }
}

static nir_def *
ngg_gs_cull_primitive(nir_builder *b, nir_def *tid_in_tg, nir_def *max_vtxcnt,
                      nir_def *out_vtx_lds_addr, nir_def *out_vtx_primflag_0,
                      lower_ngg_gs_state *s)
{
   /* We haven't enabled point culling; if we had, this function could be further optimized. */
   assert(s->num_vertices_per_primitive > 1);

   /* Save the primflag so that we don't need to load it from LDS again. */
   nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
   nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);

   /* Bit 0 of the primflag indicates whether this is the final vertex of a primitive. */
   nir_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
   nir_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
   nir_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);

   nir_if *if_prim_enable = nir_push_if(b, prim_enable);
   {
      /* Calculate the LDS address of every vertex in the current primitive. */
      nir_def *vtxptr[3];
      ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);

      /* Load the positions from LDS. */
      nir_def *pos[3][4];
      for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
         /* VARYING_SLOT_POS == 0, so base won't count the packed location */
         pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
         nir_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
         pos[i][0] = nir_channel(b, xy, 0);
         pos[i][1] = nir_channel(b, xy, 1);

         pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
         pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
      }

      /* TODO: support clipdist culling in GS */
      nir_def *accepted_by_clipdist = nir_imm_true(b);

      nir_def *accepted = ac_nir_cull_primitive(
         b, accepted_by_clipdist, pos, s->num_vertices_per_primitive, NULL, NULL);

      nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
      {
         /* Clear the primflag if the primitive is rejected. */
         nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
                          .base = s->lds_offs_primflags);

         nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
      }
      nir_pop_if(b, if_rejected);
   }
   nir_pop_if(b, if_prim_enable);

   /* Wait for the LDS primflag accesses to finish. */
   nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                  .memory_scope = SCOPE_WORKGROUP,
                  .memory_semantics = NIR_MEMORY_ACQ_REL,
                  .memory_modes = nir_var_mem_shared);

   /* Only dead vertices need a chance to be revived. */
   nir_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
   nir_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
   nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
   {
      /* Check the primflag of succeeding vertices to detect whether this vertex is live. */
      for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
         nir_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
         nir_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
         nir_if *if_not_overflow = nir_push_if(b, not_overflow);
         {
            nir_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
            nir_def *vtx_primflag =
               nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
            vtx_primflag = nir_u2u32(b, vtx_primflag);

            /* If a succeeding vertex is a live end-of-primitive vertex, we need to set
             * the current vertex's liveness flag (bit 2).
             */
            nir_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
            nir_def *vtx_live_flag =
               nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));

            /* Update this vertex's primflag. */
            nir_def *primflag = nir_load_var(b, primflag_var);
            primflag = nir_ior(b, primflag, vtx_live_flag);
            nir_store_var(b, primflag_var, primflag, 1);
         }
         nir_pop_if(b, if_not_overflow);
      }
   }
   nir_pop_if(b, if_update_primflag);

   return nir_load_var(b, primflag_var);
}

static void
ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s)
{
   nir_xfb_info *info = ac_nir_get_sorted_xfb_info(b->shader);

   nir_def *tid_in_tg = nir_load_local_invocation_index(b);
   nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
   nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
   nir_def *prim_live[4] = {0};
   nir_def *gen_prim[4] = {0};
   nir_def *export_seq[4] = {0};
   nir_def *out_vtx_primflag[4] = {0};
   for (unsigned stream = 0; stream < 4; stream++) {
      if (!(info->streams_written & BITFIELD_BIT(stream)))
         continue;

      out_vtx_primflag[stream] =
         ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);

      /* Check bit 0 of the primflag for primitive liveness; it's set for every last
       * vertex of a primitive.
       */
      prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));

      unsigned scratch_stride = ALIGN(s->max_num_waves, 4);
      nir_def *scratch_base =
         nir_iadd_imm(b, s->lds_addr_gs_scratch, stream * scratch_stride);

      /* We want to export primitives to the streamout buffer in sequence,
       * but not all vertices are alive or mark the end of a primitive, so
       * there are "holes". We don't need continuous invocations to write
       * primitives to the streamout buffer like the final vertex export does,
       * so just repacking to get the sequence (export_seq) is enough; there
       * is no need to do compaction.
       *
       * Use separate scratch space for each stream to avoid a barrier.
       * TODO: we may further reduce barriers by writing to the LDS of all
       * streams at once; then we would only need one barrier in total instead
       * of one per stream.
       */
      ac_nir_wg_repack_result rep = {0};
      ac_nir_repack_invocations_in_workgroup(b, &prim_live[stream], &rep, 1, scratch_base,
                                             s->max_num_waves, s->options->wave_size);

      /* nir_intrinsic_set_vertex_and_primitive_count can also provide the primitive count
       * of the current wave, but we would still need LDS to sum the counts of all waves to
       * get the workgroup count. And we need the repack to export primitives to the
       * streamout buffer anyway, so do it here.
       */
      gen_prim[stream] = rep.num_repacked_invocations;
      export_seq[stream] = rep.repacked_invocation_index;
   }

   /* Workgroup barrier: wait for the LDS scratch reads to finish. */
   nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                  .memory_scope = SCOPE_WORKGROUP,
                  .memory_semantics = NIR_MEMORY_ACQ_REL,
                  .memory_modes = nir_var_mem_shared);

   /* Get the global buffer offset where this workgroup will stream out data to. */
   nir_def *emit_prim[4] = {0};
   nir_def *buffer_offsets[4] = {0};
   nir_def *so_buffer[4] = {0};
   ac_nir_ngg_build_streamout_buffer_info(b, info, s->options->hw_info->gfx_level, s->options->has_xfb_prim_query,
                                          s->options->use_gfx12_xfb_intrinsic, s->lds_addr_gs_scratch, tid_in_tg,
                                          gen_prim, so_buffer, buffer_offsets, emit_prim);

   for (unsigned stream = 0; stream < 4; stream++) {
      if (!(info->streams_written & BITFIELD_BIT(stream)))
         continue;

      nir_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
      nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
      {
         /* Get the streamout buffer vertex index for the first vertex of this primitive. */
         nir_def *first_vertex_idx =
            nir_imul_imm(b, export_seq[stream], s->num_vertices_per_primitive);
         nir_def *stream_buffer_offsets[NIR_MAX_XFB_BUFFERS];

         u_foreach_bit(buffer, info->buffers_written) {
            stream_buffer_offsets[buffer] = nir_iadd(b, buffer_offsets[buffer],
                                                     nir_imul_imm(b, first_vertex_idx,
                                                                  info->buffers[buffer].stride));
         }

         /* Get the LDS address of every vertex of this primitive. */
         nir_def *exported_vtx_lds_addr[3];
         ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
                                    out_vtx_primflag[stream], s,
                                    exported_vtx_lds_addr);

         /* Write all vertices of this primitive to the streamout buffer. */
         for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
            ac_nir_ngg_build_streamout_vertex(b, info, stream, so_buffer,
                                              stream_buffer_offsets, i,
                                              exported_vtx_lds_addr[i],
                                              &s->out, false);
         }
      }
      nir_pop_if(b, if_emit);
   }
}

static void
ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
{
   nir_def *tid_in_tg = nir_load_local_invocation_index(b);
   nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
   nir_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
   nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);

   if (s->output_compile_time_known) {
      /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
       * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
       */
      nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
      {
         /* When the GS outputs 0 vertices, make the vertex and primitive count compile-time zero. */
         if (b->shader->info.gs.vertices_out == 0)
            max_vtxcnt = max_prmcnt = nir_imm_int(b, 0);

         ac_nir_ngg_alloc_vertices_and_primitives(b, max_vtxcnt, max_prmcnt,
                                                  b->shader->info.gs.vertices_out == 0 &&
                                                  s->options->hw_info->has_ngg_fully_culled_bug);
      }
      nir_pop_if(b, if_wave_0);
   }

   /* The workgroup barrier has already been emitted, so we can assume all GS output stores are done by now. */

   nir_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);

   if (s->output_compile_time_known && b->shader->info.gs.vertices_out) {
      ngg_gs_emit_output(b, max_vtxcnt, max_prmcnt, tid_in_tg, out_vtx_lds_addr, tid_in_tg, out_vtx_primflag_0, s);
      return;
   }

   /* Cull primitives. */
   if (s->options->can_cull) {
      nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));

      /* The culling code will update the primflag. */
      nir_def *updated_primflag =
         ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
                               out_vtx_primflag_0, s);

      nir_pop_if(b, if_cull_en);

      out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
   }

   /* When the output vertex count is not known at compile time:
    * There may be gaps between invocations that have live vertices, but NGG hardware
    * requires that the invocations that export vertices are packed (i.e. compact).
    * To ensure this, we need to repack invocations that have a live vertex.
    */
   nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0);
   ac_nir_wg_repack_result rep = {0};

   ac_nir_repack_invocations_in_workgroup(b, &vertex_live, &rep, 1, s->lds_addr_gs_scratch,
                                          s->max_num_waves, s->options->wave_size);

   nir_def *workgroup_num_vertices = rep.num_repacked_invocations;
   nir_def *exporter_tid_in_tg = rep.repacked_invocation_index;

   /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
   nir_def *any_output = nir_ine_imm(b, workgroup_num_vertices, 0);
   max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));

   /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
   nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
   {
      ac_nir_ngg_alloc_vertices_and_primitives(b, workgroup_num_vertices, max_prmcnt, s->options->hw_info->has_ngg_fully_culled_bug);
   }
   nir_pop_if(b, if_wave_0);

   /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
   ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);

   /* Workgroup barrier: wait for all LDS stores to finish. */
   nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
                  .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);

   ngg_gs_emit_output(b, workgroup_num_vertices, max_prmcnt, tid_in_tg, out_vtx_lds_addr, exporter_tid_in_tg, out_vtx_primflag_0, s);
}

void
ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   lower_ngg_gs_state state = {
      .options = options,
      .impl = impl,
      .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
      .lds_offs_primflags = options->gs_out_vtx_bytes,
      .lds_bytes_per_gs_out_vertex = options->gs_out_vtx_bytes + 4u,
      .streamout_enabled = shader->xfb_info && !options->disable_streamout,
   };

   if (!options->can_cull) {
      nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
                                           state.const_out_prmcnt, NULL, 4u);
      state.output_compile_time_known =
         state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
         state.const_out_prmcnt[0] != -1;
   }

   if (shader->info.gs.output_primitive == MESA_PRIM_POINTS)
      state.num_vertices_per_primitive = 1;
   else if (shader->info.gs.output_primitive == MESA_PRIM_LINE_STRIP)
      state.num_vertices_per_primitive = 2;
   else if (shader->info.gs.output_primitive == MESA_PRIM_TRIANGLE_STRIP)
      state.num_vertices_per_primitive = 3;
   else
      unreachable("Invalid GS output primitive.");

   /* Extract the full control flow. It is going to be wrapped in an if statement. */
   nir_cf_list extracted;
   nir_cf_extract(&extracted, nir_before_impl(impl),
                  nir_after_impl(impl));

   nir_builder builder = nir_builder_at(nir_before_impl(impl));
   nir_builder *b = &builder; /* This is to avoid the & */

   /* Workgroup barrier: wait for ES threads */
   nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
                  .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);

   state.lds_addr_gs_out_vtx = nir_load_lds_ngg_gs_out_vertex_base_amd(b);
   state.lds_addr_gs_scratch = nir_load_lds_ngg_scratch_base_amd(b);

   /* Wrap the GS control flow. */
   nir_if *if_gs_thread = nir_push_if(b, nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b), .base = 8));

   nir_cf_reinsert(&extracted, b->cursor);
   b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
   nir_pop_if(b, if_gs_thread);

   /* Workgroup barrier: wait for all GS threads to finish */
   nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
                  .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);

   if (state.streamout_enabled)
      ngg_gs_build_streamout(b, &state);

   /* Lower the GS intrinsics */
   lower_ngg_gs_intrinsics(shader, &state);

   if (!state.vertex_count[0]) {
      fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.");
      abort();
   }

   /* Emit shader queries */
   b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
   ac_nir_gs_shader_query(b,
                          state.options->has_gen_prim_query,
                          state.options->has_gs_invocations_query,
                          state.options->has_gs_primitives_query,
                          state.num_vertices_per_primitive,
                          state.options->wave_size,
                          state.vertex_count,
                          state.primitive_count);

   b->cursor = nir_after_impl(impl);

   /* Emit the finale sequence */
   ngg_gs_finale(b, &state);
   nir_validate_shader(shader, "after emitting NGG GS");

   /* Cleanup */
   nir_lower_vars_to_ssa(shader);
   nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
   nir_metadata_preserve(impl, nir_metadata_none);
}