• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2017 Advanced Micro Devices, Inc.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * on the rights to use, copy, modify, merge, publish, distribute, sub
8  * license, and/or sell copies of the Software, and to permit persons to whom
9  * the Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18  * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19  * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20  * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21  * USE OR OTHER DEALINGS IN THE SOFTWARE.
22  */
23 
24 #include "ac_llvm_cull.h"
25 #include "si_pipe.h"
26 #include "si_query.h"
27 #include "si_shader_internal.h"
28 #include "sid.h"
29 #include "util/u_memory.h"
30 #include "util/u_prim.h"
31 
get_wave_id_in_tg(struct si_shader_context * ctx)32 static LLVMValueRef get_wave_id_in_tg(struct si_shader_context *ctx)
33 {
34    return si_unpack_param(ctx, ctx->args.merged_wave_info, 24, 4);
35 }
36 
get_tgsize(struct si_shader_context * ctx)37 static LLVMValueRef get_tgsize(struct si_shader_context *ctx)
38 {
39    return si_unpack_param(ctx, ctx->args.merged_wave_info, 28, 4);
40 }
41 
gfx10_get_thread_id_in_tg(struct si_shader_context * ctx)42 LLVMValueRef gfx10_get_thread_id_in_tg(struct si_shader_context *ctx)
43 {
44    LLVMBuilderRef builder = ctx->ac.builder;
45    LLVMValueRef tmp;
46    tmp = LLVMBuildMul(builder, get_wave_id_in_tg(ctx),
47                       LLVMConstInt(ctx->ac.i32, ctx->ac.wave_size, false), "");
48    return LLVMBuildAdd(builder, tmp, ac_get_thread_id(&ctx->ac), "");
49 }
50 
ngg_get_vtx_cnt(struct si_shader_context * ctx)51 static LLVMValueRef ngg_get_vtx_cnt(struct si_shader_context *ctx)
52 {
53    return si_unpack_param(ctx, ctx->args.gs_tg_info, 12, 9);
54 }
55 
ngg_get_prim_cnt(struct si_shader_context * ctx)56 static LLVMValueRef ngg_get_prim_cnt(struct si_shader_context *ctx)
57 {
58    return si_unpack_param(ctx, ctx->args.gs_tg_info, 22, 9);
59 }
60 
ngg_get_ordered_id(struct si_shader_context * ctx)61 static LLVMValueRef ngg_get_ordered_id(struct si_shader_context *ctx)
62 {
63    return si_unpack_param(ctx, ctx->args.gs_tg_info, 0, 12);
64 }
65 
ngg_get_query_buf(struct si_shader_context * ctx)66 static LLVMValueRef ngg_get_query_buf(struct si_shader_context *ctx)
67 {
68    LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->internal_bindings);
69 
70    return ac_build_load_to_sgpr(&ctx->ac, buf_ptr,
71                                 LLVMConstInt(ctx->ac.i32, SI_GS_QUERY_BUF, false));
72 }
73 
ngg_get_emulated_counters_buf(struct si_shader_context * ctx)74 static LLVMValueRef ngg_get_emulated_counters_buf(struct si_shader_context *ctx)
75 {
76    LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->internal_bindings);
77 
78    return ac_build_load_to_sgpr(&ctx->ac, buf_ptr,
79                                 LLVMConstInt(ctx->ac.i32, SI_GS_QUERY_EMULATED_COUNTERS_BUF, false));
80 }
81 
82 /**
83  * Return the number of vertices as a constant in \p num_vertices,
84  * and return a more precise value as LLVMValueRef from the function.
85  */
ngg_get_vertices_per_prim(struct si_shader_context * ctx,unsigned * num_vertices)86 static LLVMValueRef ngg_get_vertices_per_prim(struct si_shader_context *ctx, unsigned *num_vertices)
87 {
88    const struct si_shader_info *info = &ctx->shader->selector->info;
89 
90    if (ctx->stage == MESA_SHADER_GEOMETRY) {
91       *num_vertices = u_vertices_per_prim(info->base.gs.output_primitive);
92       return LLVMConstInt(ctx->ac.i32, *num_vertices, false);
93    } else if (ctx->stage == MESA_SHADER_VERTEX) {
94       if (info->base.vs.blit_sgprs_amd) {
95          /* Blits always use axis-aligned rectangles with 3 vertices. */
96          *num_vertices = 3;
97          return LLVMConstInt(ctx->ac.i32, 3, 0);
98       } else if (ctx->shader->key.ge.opt.ngg_culling & SI_NGG_CULL_LINES) {
99          *num_vertices = 2;
100          return LLVMConstInt(ctx->ac.i32, 2, 0);
101       } else {
102          /* We always build up all three indices for the prim export
103           * independent of the primitive type. The additional garbage
104           * data shouldn't hurt. This is used by exports and streamout.
105           */
106          *num_vertices = 3;
107 
108          /* Extract OUTPRIM field. */
109          LLVMValueRef num = GET_FIELD(ctx, GS_STATE_OUTPRIM);
110          return LLVMBuildAdd(ctx->ac.builder, num, ctx->ac.i32_1, "");
111       }
112    } else {
113       assert(ctx->stage == MESA_SHADER_TESS_EVAL);
114 
115       if (info->base.tess.point_mode)
116          *num_vertices = 1;
117       else if (info->base.tess._primitive_mode == TESS_PRIMITIVE_ISOLINES)
118          *num_vertices = 2;
119       else
120          *num_vertices = 3;
121 
122       return LLVMConstInt(ctx->ac.i32, *num_vertices, false);
123    }
124 }
125 
gfx10_ngg_export_prim_early(struct si_shader * shader)126 bool gfx10_ngg_export_prim_early(struct si_shader *shader)
127 {
128    struct si_shader_selector *sel = shader->selector;
129 
130    assert(shader->key.ge.as_ngg && !shader->key.ge.as_es);
131 
132    return sel->stage != MESA_SHADER_GEOMETRY &&
133           !gfx10_ngg_writes_user_edgeflags(shader);
134 }
135 
gfx10_ngg_build_sendmsg_gs_alloc_req(struct si_shader_context * ctx)136 void gfx10_ngg_build_sendmsg_gs_alloc_req(struct si_shader_context *ctx)
137 {
138    /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
139    if (gfx10_is_ngg_passthrough(ctx->shader) &&
140        ctx->screen->info.family >= CHIP_NAVI23)
141       return;
142 
143    ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx), ngg_get_vtx_cnt(ctx),
144                                  ngg_get_prim_cnt(ctx));
145 }
146 
/*
 * Return \p mask with the user edge flag of each vertex v < num_vertices
 * OR'ed in at bit 9 + 10 * v. Each flag is loaded as an i1 from
 * user_edgeflags[v] and zero-extended to i32.
 */
static LLVMValueRef ngg_export_or_user_edgeflags(struct si_shader_context *ctx,
                                                 LLVMValueRef user_edgeflags[3],
                                                 unsigned num_vertices, LLVMValueRef mask)
{
   LLVMBuilderRef builder = ctx->ac.builder;

   for (unsigned v = 0; v < num_vertices; v++) {
      LLVMValueRef bit = LLVMBuildLoad2(builder, ctx->ac.i1, user_edgeflags[v], "");

      bit = LLVMBuildZExt(builder, bit, ctx->ac.i32, "");
      bit = LLVMBuildShl(builder, bit, LLVMConstInt(ctx->ac.i32, 9 + v * 10, 0), "");
      mask = LLVMBuildOr(builder, mask, bit, "");
   }
   return mask;
}

/*
 * Export the primitive. The export executes only in threads where
 * si_is_gs_thread() is true.
 *
 * With NGG passthrough or NGG culling, the passthrough encoding is exported:
 * \p prim_passthrough when non-NULL, else the gs_vtx_offset[0] argument.
 * Otherwise the primitive is assembled from vertex indices unpacked from
 * gs_vtx_offset[] plus edge flags.
 *
 * \p user_edgeflags is read only when the shader writes user edge flags.
 * Each entry is a pointer that is loaded as an i1.
 */
void gfx10_ngg_build_export_prim(struct si_shader_context *ctx, LLVMValueRef user_edgeflags[3],
                                 LLVMValueRef prim_passthrough)
{
   bool passthrough_encoding =
      gfx10_is_ngg_passthrough(ctx->shader) || ctx->shader->key.ge.opt.ngg_culling;
   struct ac_ngg_prim prim = {};

   ac_build_ifcc(&ctx->ac, si_is_gs_thread(ctx), 6001);

   if (passthrough_encoding) {
      prim.passthrough = prim_passthrough ? prim_passthrough
                                          : ac_get_arg(&ctx->ac, ctx->args.gs_vtx_offset[0]);

      /* This is only used with NGG culling, which returns the NGG passthrough
       * prim export encoding. AND-ing with (~SI_NGG_PRIM_EDGE_FLAG_BITS | user
       * flags) keeps an edge-flag bit only where the user edge flag is set.
       */
      if (gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
         unsigned keep_bits = ~SI_NGG_PRIM_EDGE_FLAG_BITS;
         LLVMValueRef keep = LLVMConstInt(ctx->ac.i32, keep_bits, 0);
         unsigned nverts;

         ngg_get_vertices_per_prim(ctx, &nverts);

         LLVMValueRef mask = ngg_export_or_user_edgeflags(ctx, user_edgeflags, nverts, keep);
         prim.passthrough = LLVMBuildAnd(ctx->ac.builder, prim.passthrough, mask, "");
      }
   } else {
      ngg_get_vertices_per_prim(ctx, &prim.num_vertices);

      prim.isnull = ctx->ac.i1false;
      prim.edgeflags = gfx10_edgeflags_have_effect(ctx->shader)
                          ? ac_pack_edgeflags_for_export(&ctx->ac, &ctx->args)
                          : ctx->ac.i32_0;

      /* Two 16-bit vertex indices are packed in each gs_vtx_offset argument. */
      for (unsigned v = 0; v < prim.num_vertices; v++)
         prim.index[v] = si_unpack_param(ctx, ctx->args.gs_vtx_offset[v / 2], (v & 1) * 16, 16);

      if (gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
         LLVMValueRef mask = ngg_export_or_user_edgeflags(ctx, user_edgeflags, prim.num_vertices,
                                                          ctx->ac.i32_0);
         prim.edgeflags = LLVMBuildAnd(ctx->ac.builder, prim.edgeflags, mask, "");
      }
   }

   ac_build_export_prim(&ctx->ac, &prim);
   ac_build_endif(&ctx->ac, 6001);
}
224 
/**
 * Store the streamout outputs of one vertex that belong to \p stream.
 *
 * \p vertexptr is the vertex's `[N x i32] addrspace(LDS)*` storage, with 4
 * i32 components per output register. \p offset_vtx is the vertex index that
 * is scaled by each buffer's stride and added to \p wg_offset_dw. Buffers
 * whose wg_offset_dw entry is NULL are skipped.
 */
static void build_streamout_vertex(struct si_shader_context *ctx, LLVMValueRef *so_buffer,
                                   LLVMValueRef *wg_offset_dw, unsigned stream,
                                   LLVMValueRef offset_vtx, LLVMValueRef vertexptr)
{
   struct si_shader_info *info = &ctx->shader->selector->info;
   struct pipe_stream_output_info *so = &ctx->so;
   LLVMBuilderRef builder = ctx->ac.builder;
   LLVMValueRef offset[4] = {};
   LLVMValueRef tmp;

   /* Byte offset per buffer: (wg_offset_dw + offset_vtx * stride) << 2. */
   for (unsigned buffer = 0; buffer < 4; ++buffer) {
      if (!wg_offset_dw[buffer])
         continue;

      tmp = LLVMBuildMul(builder, offset_vtx, LLVMConstInt(ctx->ac.i32, so->stride[buffer], false),
                         "");
      tmp = LLVMBuildAdd(builder, wg_offset_dw[buffer], tmp, "");
      offset[buffer] = LLVMBuildShl(builder, tmp, LLVMConstInt(ctx->ac.i32, 2, false), "");
   }

   for (unsigned i = 0; i < so->num_outputs; ++i) {
      if (so->output[i].stream != stream)
         continue;

      unsigned reg = so->output[i].register_index;
      struct si_shader_output_values out;

      out.semantic = info->output_semantic[reg];
      out.vertex_streams = info->output_streams[reg];

      for (unsigned comp = 0; comp < 4; comp++) {
         tmp = ac_build_gep0(&ctx->ac, vertexptr, LLVMConstInt(ctx->ac.i32, 4 * reg + comp, false));
         /* The LDS vertex storage is an i32 array. Load i32 explicitly instead of
          * relying on the deprecated pointee-type inference of LLVMBuildLoad.
          */
         out.values[comp] = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
      }

      si_llvm_streamout_store_output(ctx, so_buffer, offset, &so->output[i], &out);
   }
}
262 
/* Inputs and outputs of build_streamout(). */
struct ngg_streamout {
   /* Vertices per primitive. build_streamout() multiplies each buffer stride by it. */
   LLVMValueRef num_vertices;

   /* per-thread data */
   LLVMValueRef prim_enable[4]; /* i1 per stream */
   /* [N x i32] addrspace(LDS)*. build_streamout() counts the non-NULL entries
    * of vertices[1..2], so unused trailing entries are left NULL.
    */
   LLVMValueRef vertices[3];

   /* Output */
   LLVMValueRef emit[4]; /* per-stream emitted primitives (only valid for used streams) */
};
273 
274 /**
275  * Build streamout logic.
276  *
277  * Implies a barrier.
278  *
279  * Writes number of emitted primitives to gs_ngg_scratch[4:8].
280  *
281  * Clobbers gs_ngg_scratch[8:].
282  */
build_streamout(struct si_shader_context * ctx,struct ngg_streamout * nggso)283 static void build_streamout(struct si_shader_context *ctx, struct ngg_streamout *nggso)
284 {
285    struct si_shader_info *info = &ctx->shader->selector->info;
286    struct pipe_stream_output_info *so = &ctx->so;
287    LLVMBuilderRef builder = ctx->ac.builder;
288    LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->internal_bindings);
289    LLVMValueRef tid = gfx10_get_thread_id_in_tg(ctx);
290    LLVMValueRef tmp, tmp2;
291    LLVMValueRef i32_2 = LLVMConstInt(ctx->ac.i32, 2, false);
292    LLVMValueRef i32_4 = LLVMConstInt(ctx->ac.i32, 4, false);
293    LLVMValueRef i32_8 = LLVMConstInt(ctx->ac.i32, 8, false);
294    LLVMValueRef so_buffer[4] = {};
295    unsigned max_num_vertices = 1 + (nggso->vertices[1] ? 1 : 0) + (nggso->vertices[2] ? 1 : 0);
296    LLVMValueRef prim_stride_dw[4] = {};
297    LLVMValueRef prim_stride_dw_vgpr = LLVMGetUndef(ctx->ac.i32);
298    int stream_for_buffer[4] = {-1, -1, -1, -1};
299    unsigned bufmask_for_stream[4] = {};
300    bool isgs = ctx->stage == MESA_SHADER_GEOMETRY;
301    unsigned scratch_emit_base = isgs ? 4 : 0;
302    LLVMValueRef scratch_emit_basev = isgs ? i32_4 : ctx->ac.i32_0;
303    unsigned scratch_offset_base = isgs ? 8 : 4;
304    LLVMValueRef scratch_offset_basev = isgs ? i32_8 : i32_4;
305 
306    /* Determine the mapping of streamout buffers to vertex streams. */
307    for (unsigned i = 0; i < so->num_outputs; ++i) {
308       unsigned buf = so->output[i].output_buffer;
309       unsigned stream = so->output[i].stream;
310       assert(stream_for_buffer[buf] < 0 || stream_for_buffer[buf] == stream);
311       stream_for_buffer[buf] = stream;
312       bufmask_for_stream[stream] |= 1 << buf;
313    }
314 
315    for (unsigned buffer = 0; buffer < 4; ++buffer) {
316       if (stream_for_buffer[buffer] == -1)
317          continue;
318 
319       assert(so->stride[buffer]);
320 
321       tmp = LLVMConstInt(ctx->ac.i32, so->stride[buffer], false);
322       prim_stride_dw[buffer] = LLVMBuildMul(builder, tmp, nggso->num_vertices, "");
323       prim_stride_dw_vgpr =
324          ac_build_writelane(&ctx->ac, prim_stride_dw_vgpr, prim_stride_dw[buffer],
325                             LLVMConstInt(ctx->ac.i32, buffer, false));
326 
327       so_buffer[buffer] = ac_build_load_to_sgpr(
328          &ctx->ac, buf_ptr, LLVMConstInt(ctx->ac.i32, SI_VS_STREAMOUT_BUF0 + buffer, false));
329    }
330 
331    tmp = LLVMBuildICmp(builder, LLVMIntEQ, get_wave_id_in_tg(ctx), ctx->ac.i32_0, "");
332    ac_build_ifcc(&ctx->ac, tmp, 5200);
333    {
334       LLVMTypeRef gdsptr = LLVMPointerType(ctx->ac.i32, AC_ADDR_SPACE_GDS);
335       LLVMValueRef gdsbase = LLVMBuildIntToPtr(builder, ctx->ac.i32_0, gdsptr, "");
336 
337       /* Advance the streamout offsets in GDS. */
338       LLVMValueRef offsets_vgpr = ac_build_alloca_undef(&ctx->ac, ctx->ac.i32, "");
339       LLVMValueRef generated_by_stream_vgpr = ac_build_alloca_undef(&ctx->ac, ctx->ac.i32, "");
340 
341       tmp = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), i32_4, "");
342       ac_build_ifcc(&ctx->ac, tmp, 5210);
343       {
344          if (isgs) {
345             tmp = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, tid);
346             tmp = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
347          } else {
348             tmp = ac_build_writelane(&ctx->ac, ctx->ac.i32_0, ngg_get_prim_cnt(ctx), ctx->ac.i32_0);
349          }
350          LLVMBuildStore(builder, tmp, generated_by_stream_vgpr);
351 
352          unsigned swizzle[4];
353          int unused_stream = -1;
354          for (unsigned stream = 0; stream < 4; ++stream) {
355             if (!info->num_stream_output_components[stream]) {
356                unused_stream = stream;
357                break;
358             }
359          }
360          for (unsigned buffer = 0; buffer < 4; ++buffer) {
361             if (stream_for_buffer[buffer] >= 0) {
362                swizzle[buffer] = stream_for_buffer[buffer];
363             } else {
364                assert(unused_stream >= 0);
365                swizzle[buffer] = unused_stream;
366             }
367          }
368 
369          tmp = ac_build_quad_swizzle(&ctx->ac, tmp, swizzle[0], swizzle[1], swizzle[2], swizzle[3]);
370          tmp = LLVMBuildMul(builder, tmp, prim_stride_dw_vgpr, "");
371 
372          LLVMValueRef args[8] = {
373             LLVMBuildIntToPtr(builder, ngg_get_ordered_id(ctx), gdsptr, ""),
374             ctx->ac.i32_0,                             /* value to add */
375             ctx->ac.i32_0,                             /* ordering */
376             ctx->ac.i32_0,                             /* scope */
377             ctx->ac.i1false,                           /* isVolatile */
378             LLVMConstInt(ctx->ac.i32, 1 << 24, false), /* OA index, bits 24+: lane count */
379             ctx->ac.i1true,                            /* wave release */
380             ctx->ac.i1true,                            /* wave done */
381          };
382 
383          if (ctx->screen->info.gfx_level >= GFX11) {
384             /* Gfx11 GDS instructions only operate on the first active lane. All other lanes are
385              * ignored. So are their EXEC bits. This uses the mutex feature of ds_ordered_count
386              * to emulate a multi-dword atomic.
387              *
388              * This is the expected code:
389              *    ds_ordered_count release=0 done=0   // lock mutex
390              *    ds_add_rtn_u32 dwords_written0
391              *    ds_add_rtn_u32 dwords_written1
392              *    ds_add_rtn_u32 dwords_written2
393              *    ds_add_rtn_u32 dwords_written3
394              *    ds_ordered_count release=1 done=1   // unlock mutex
395              *
396              * TODO: Increment GDS_STRMOUT registers instead of GDS memory.
397              */
398             LLVMValueRef dwords_written[4] = {tmp, tmp, tmp, tmp};
399 
400             /* Move all 4 VGPRs from other lanes to lane 0. */
401             for (unsigned i = 1; i < 4; i++) {
402                if (ctx->shader->selector->info.base.xfb_stride[i])
403                   dwords_written[i] = ac_build_quad_swizzle(&ctx->ac, tmp, i, i, i, i);
404             }
405 
406             /* Set release=0 to start a GDS mutex. Set done=0 because it's not the last one. */
407             args[6] = args[7] = ctx->ac.i1false;
408             ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.ds.ordered.add", ctx->ac.i32,
409                                args, ARRAY_SIZE(args), 0);
410             ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
411 
412             for (unsigned i = 0; i < 4; i++) {
413                if (ctx->shader->selector->info.base.xfb_stride[i]) {
414                   LLVMValueRef gds_ptr =
415                      ac_build_gep_ptr(&ctx->ac, gdsbase, LLVMConstInt(ctx->ac.i32, i, 0));
416 
417                   dwords_written[i] = LLVMBuildAtomicRMW(builder, LLVMAtomicRMWBinOpAdd,
418                                                          gds_ptr, dwords_written[i],
419                                                          LLVMAtomicOrderingMonotonic, false);
420                }
421             }
422 
423             /* TODO: This might not be needed if GDS executes instructions in order. */
424             ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
425 
426             /* Set release=1 to end a GDS mutex. Set done=1 because it's the last one. */
427             args[6] = args[7] = ctx->ac.i1true;
428             ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.ds.ordered.add", ctx->ac.i32,
429                                args, ARRAY_SIZE(args), 0);
430 
431             tmp = dwords_written[0];
432             for (unsigned i = 1; i < 4; i++) {
433                if (ctx->shader->selector->info.base.xfb_stride[i]) {
434                   dwords_written[i] = ac_build_readlane(&ctx->ac, dwords_written[i], ctx->ac.i32_0);
435                   tmp = ac_build_writelane(&ctx->ac, tmp, dwords_written[i], LLVMConstInt(ctx->ac.i32, i, 0));
436                }
437             }
438          } else {
439             args[1] = tmp; /* value to add */
440             args[5] = LLVMConstInt(ctx->ac.i32, 4 << 24, false), /* bits 24+: lane count */
441 
442             tmp = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.ds.ordered.add", ctx->ac.i32,
443                                      args, ARRAY_SIZE(args), 0);
444          }
445 
446          /* Keep offsets in a VGPR for quick retrieval via readlane by
447           * the first wave for bounds checking, and also store in LDS
448           * for retrieval by all waves later. */
449          LLVMBuildStore(builder, tmp, offsets_vgpr);
450 
451          tmp2 = LLVMBuildAdd(builder, ac_get_thread_id(&ctx->ac), scratch_offset_basev, "");
452          tmp2 = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, tmp2);
453          LLVMBuildStore(builder, tmp, tmp2);
454       }
455       ac_build_endif(&ctx->ac, 5210);
456 
457       /* Determine the max emit per buffer. This is done via the SALU, in part
458        * because LLVM can't generate divide-by-multiply if we try to do this
459        * via VALU with one lane per buffer.
460        */
461       LLVMValueRef max_emit[4] = {};
462       for (unsigned buffer = 0; buffer < 4; ++buffer) {
463          if (stream_for_buffer[buffer] == -1)
464             continue;
465 
466          LLVMValueRef bufsize_dw = LLVMBuildLShr(
467             builder, LLVMBuildExtractElement(builder, so_buffer[buffer], i32_2, ""), i32_2, "");
468 
469          tmp = LLVMBuildLoad2(builder, ctx->ac.i32, offsets_vgpr, "");
470          LLVMValueRef offset_dw =
471             ac_build_readlane(&ctx->ac, tmp, LLVMConstInt(ctx->ac.i32, buffer, false));
472 
473          tmp = LLVMBuildSub(builder, bufsize_dw, offset_dw, "");
474          tmp = LLVMBuildUDiv(builder, tmp, prim_stride_dw[buffer], "");
475 
476          tmp2 = LLVMBuildICmp(builder, LLVMIntULT, bufsize_dw, offset_dw, "");
477          max_emit[buffer] = LLVMBuildSelect(builder, tmp2, ctx->ac.i32_0, tmp, "");
478       }
479 
480       /* Determine the number of emitted primitives per stream and fixup the
481        * GDS counter if necessary.
482        *
483        * This is complicated by the fact that a single stream can emit to
484        * multiple buffers (but luckily not vice versa).
485        */
486       LLVMValueRef emit_vgpr = ctx->ac.i32_0;
487 
488       for (unsigned stream = 0; stream < 4; ++stream) {
489          if (!info->num_stream_output_components[stream])
490             continue;
491 
492          tmp = LLVMBuildLoad2(builder, ctx->ac.i32, generated_by_stream_vgpr, "");
493          LLVMValueRef generated =
494             ac_build_readlane(&ctx->ac, tmp, LLVMConstInt(ctx->ac.i32, stream, false));
495 
496          LLVMValueRef emit = generated;
497          for (unsigned buffer = 0; buffer < 4; ++buffer) {
498             if (stream_for_buffer[buffer] == stream)
499                emit = ac_build_umin(&ctx->ac, emit, max_emit[buffer]);
500          }
501 
502          emit_vgpr =
503             ac_build_writelane(&ctx->ac, emit_vgpr, emit, LLVMConstInt(ctx->ac.i32, stream, false));
504 
505          /* Fixup the offset using a plain GDS atomic if we overflowed. */
506          tmp = LLVMBuildICmp(builder, LLVMIntULT, emit, generated, "");
507          ac_build_ifcc(&ctx->ac, tmp, 5221); /* scalar branch */
508          tmp = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i32, bufmask_for_stream[stream], false),
509                              ac_get_thread_id(&ctx->ac), "");
510          tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
511          ac_build_ifcc(&ctx->ac, tmp, 5222);
512          {
513             tmp = LLVMBuildSub(builder, generated, emit, "");
514             tmp = LLVMBuildMul(builder, tmp, prim_stride_dw_vgpr, "");
515 
516             if (ctx->screen->info.gfx_level >= GFX11) {
517                /* Gfx11 GDS instructions only operate on the first active lane.
518                 * This is an unrolled waterfall loop. We only get here when we overflow,
519                 * so it doesn't have to be fast.
520                 */
521                for (unsigned i = 0; i < 4; i++) {
522                   if (bufmask_for_stream[stream] & BITFIELD_BIT(i)) {
523                      LLVMValueRef index = LLVMConstInt(ctx->ac.i32, i, 0);
524 
525                      ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, index, ""), 0);
526                      LLVMBuildAtomicRMW(builder, LLVMAtomicRMWBinOpSub,
527                                         LLVMBuildGEP(builder, gdsbase, &index, 1, ""),
528                                         tmp, LLVMAtomicOrderingMonotonic, false);
529                      ac_build_endif(&ctx->ac, 0);
530                   }
531                }
532             } else {
533                LLVMBuildAtomicRMW(builder, LLVMAtomicRMWBinOpSub,
534                                   LLVMBuildGEP(builder, gdsbase, &tid, 1, ""),
535                                   tmp, LLVMAtomicOrderingMonotonic, false);
536             }
537          }
538          ac_build_endif(&ctx->ac, 5222);
539          ac_build_endif(&ctx->ac, 5221);
540       }
541 
542       tmp = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), i32_4, "");
543       ac_build_ifcc(&ctx->ac, tmp, 5225);
544       {
545          tmp = LLVMBuildAdd(builder, ac_get_thread_id(&ctx->ac), scratch_emit_basev, "");
546          tmp = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, tmp);
547          LLVMBuildStore(builder, emit_vgpr, tmp);
548       }
549       ac_build_endif(&ctx->ac, 5225);
550    }
551    ac_build_endif(&ctx->ac, 5200);
552 
553    /* Determine the workgroup-relative per-thread / primitive offset into
554     * the streamout buffers */
555    struct ac_wg_scan primemit_scan[4] = {};
556 
557    if (isgs) {
558       for (unsigned stream = 0; stream < 4; ++stream) {
559          if (!info->num_stream_output_components[stream])
560             continue;
561 
562          primemit_scan[stream].stage = ctx->stage;
563          primemit_scan[stream].enable_exclusive = true;
564          primemit_scan[stream].op = nir_op_iadd;
565          primemit_scan[stream].src = nggso->prim_enable[stream];
566          primemit_scan[stream].scratch = ac_build_gep0(
567             &ctx->ac, ctx->gs_ngg_scratch, LLVMConstInt(ctx->ac.i32, 12 + 8 * stream, false));
568          primemit_scan[stream].waveidx = get_wave_id_in_tg(ctx);
569          primemit_scan[stream].numwaves = get_tgsize(ctx);
570          if (ctx->stage == MESA_SHADER_GEOMETRY) {
571             /* ngg_subgroup_size is only the input size. GS can always generate up to 256 vertices. */
572             primemit_scan[stream].maxwaves = DIV_ROUND_UP(256, ctx->ac.wave_size);
573          } else {
574             primemit_scan[stream].maxwaves = DIV_ROUND_UP(ctx->screen->ngg_subgroup_size,
575                                                           ctx->ac.wave_size);
576          }
577          ac_build_wg_scan_top(&ctx->ac, &primemit_scan[stream]);
578       }
579    }
580 
581    ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
582    ac_build_s_barrier(&ctx->ac, ctx->stage);
583 
584    /* Fetch the per-buffer offsets and per-stream emit counts in all waves. */
585    LLVMValueRef wgoffset_dw[4] = {};
586 
587    {
588       LLVMValueRef scratch_vgpr;
589 
590       tmp = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, ac_get_thread_id(&ctx->ac));
591       scratch_vgpr = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
592 
593       for (unsigned buffer = 0; buffer < 4; ++buffer) {
594          if (stream_for_buffer[buffer] >= 0) {
595             wgoffset_dw[buffer] =
596                ac_build_readlane(&ctx->ac, scratch_vgpr,
597                                  LLVMConstInt(ctx->ac.i32, scratch_offset_base + buffer, false));
598          }
599       }
600 
601       for (unsigned stream = 0; stream < 4; ++stream) {
602          if (info->num_stream_output_components[stream]) {
603             nggso->emit[stream] =
604                ac_build_readlane(&ctx->ac, scratch_vgpr,
605                                  LLVMConstInt(ctx->ac.i32, scratch_emit_base + stream, false));
606          }
607       }
608    }
609 
610    /* Write out primitive data */
611    for (unsigned stream = 0; stream < 4; ++stream) {
612       if (!info->num_stream_output_components[stream])
613          continue;
614 
615       if (isgs) {
616          ac_build_wg_scan_bottom(&ctx->ac, &primemit_scan[stream]);
617       } else {
618          primemit_scan[stream].result_exclusive = tid;
619       }
620 
621       tmp = LLVMBuildICmp(builder, LLVMIntULT, primemit_scan[stream].result_exclusive,
622                           nggso->emit[stream], "");
623       tmp = LLVMBuildAnd(builder, tmp, nggso->prim_enable[stream], "");
624       ac_build_ifcc(&ctx->ac, tmp, 5240);
625       {
626          LLVMValueRef offset_vtx =
627             LLVMBuildMul(builder, primemit_scan[stream].result_exclusive, nggso->num_vertices, "");
628 
629          for (unsigned i = 0; i < max_num_vertices; ++i) {
630             tmp = LLVMBuildICmp(builder, LLVMIntULT, LLVMConstInt(ctx->ac.i32, i, false),
631                                 nggso->num_vertices, "");
632             ac_build_ifcc(&ctx->ac, tmp, 5241);
633             build_streamout_vertex(ctx, so_buffer, wgoffset_dw, stream, offset_vtx,
634                                    nggso->vertices[i]);
635             ac_build_endif(&ctx->ac, 5241);
636             offset_vtx = LLVMBuildAdd(builder, offset_vtx, ctx->ac.i32_1, "");
637          }
638       }
639       ac_build_endif(&ctx->ac, 5240);
640    }
641 }
642 
/* LDS layout of ES vertex data for NGG culling: indices of the i32 elements
 * of one vertex's LDS storage.
 *
 * Element 0 (lds_packed_data) holds four bytes, addressed by the lds_byteN_*
 * byte offsets:
 *    Byte 0: Boolean ES thread accepted (unculled) flag.
 *    Byte 1: New ES thread ID, loaded by GS to prepare the prim export value.
 *    Byte 2: TES rel patch ID
 *    Byte 3: 8-bit clip distance mask: 1 means the clip distance is negative.
 *            The mask from all vertices is AND'ed. If the result is non-zero,
 *            the primitive is culled.
 *
 * Elements 1-3 are shared by the lds_pos_cull_* names and lds_pos_x..z.
 */
enum
{
   lds_packed_data = 0, /* lds_byteN_... */

   /* Byte offsets within lds_packed_data. */
   lds_byte0_accept_flag = 0,
   lds_byte1_new_thread_id = 1,
   lds_byte2_tes_rel_patch_id = 2,
   lds_byte3_clipdist_neg_mask = 3,

   lds_pos_cull_x_div_w = lds_packed_data + 1,
   lds_pos_cull_y_div_w = lds_packed_data + 2,
   lds_pos_cull_w = lds_packed_data + 3,

   lds_pos_x = lds_packed_data + 1,
   lds_pos_y = lds_pos_x + 1,
   lds_pos_z = lds_pos_x + 2,
   lds_pos_w = lds_pos_x + 3,

   /* If VS: */
   lds_vertex_id = lds_pos_w + 1,
   lds_instance_id = lds_vertex_id + 1, /* optional */

   /* If TES: */
   lds_tes_u = lds_vertex_id,
   lds_tes_v = lds_instance_id,
   lds_tes_patch_id = lds_tes_v + 1, /* optional */
};
675 
/*
 * Cast \p ptr to an LDS i8 pointer and offset it by \p index bytes.
 *
 * Uses the typed LLVMBuildGEP2 with element type i8 (the type \p ptr was just
 * cast to). The untyped LLVMBuildGEP is deprecated, and the file already uses
 * the typed builder entry points (LLVMBuildLoad2).
 */
static LLVMValueRef si_build_gep_i8_var(struct si_shader_context *ctx, LLVMValueRef ptr,
                                        LLVMValueRef index)
{
   LLVMTypeRef pi8 = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS);
   LLVMValueRef byte_ptr = LLVMBuildPointerCast(ctx->ac.builder, ptr, pi8, "");

   return LLVMBuildGEP2(ctx->ac.builder, ctx->ac.i8, byte_ptr, &index, 1, "");
}
684 
/* Constant-index variant of si_build_gep_i8_var(). \p byte_index must be < 4. */
static LLVMValueRef si_build_gep_i8(struct si_shader_context *ctx, LLVMValueRef ptr,
                                    unsigned byte_index)
{
   LLVMValueRef index;

   assert(byte_index < 4);
   index = LLVMConstInt(ctx->ac.i32, byte_index, 0);
   return si_build_gep_i8_var(ctx, ptr, index);
}
691 
ngg_nogs_vertex_size(struct si_shader * shader)692 static unsigned ngg_nogs_vertex_size(struct si_shader *shader)
693 {
694    unsigned lds_vertex_size = 0;
695 
696    /* The edgeflag is always stored in the last element that's also
697     * used for padding to reduce LDS bank conflicts. */
698    if (si_shader_uses_streamout(shader))
699       lds_vertex_size = 4 * shader->selector->info.num_outputs + 1;
700    if (gfx10_ngg_writes_user_edgeflags(shader))
701       lds_vertex_size = MAX2(lds_vertex_size, 1);
702 
703    /* LDS size for passing data from GS to ES.
704     * GS stores Primitive IDs into LDS at the address corresponding
705     * to the ES thread of the provoking vertex. All ES threads
706     * load and export PrimitiveID for their thread.
707     */
708    if (shader->selector->stage == MESA_SHADER_VERTEX && shader->key.ge.mono.u.vs_export_prim_id)
709       lds_vertex_size = MAX2(lds_vertex_size, 1);
710 
711    if (shader->key.ge.opt.ngg_culling) {
712       if (shader->selector->stage == MESA_SHADER_VERTEX) {
713          STATIC_ASSERT(lds_instance_id + 1 == 7);
714          lds_vertex_size = MAX2(lds_vertex_size, 7);
715       } else {
716          assert(shader->selector->stage == MESA_SHADER_TESS_EVAL);
717 
718          if (shader->selector->info.uses_primid || shader->key.ge.mono.u.vs_export_prim_id) {
719             STATIC_ASSERT(lds_tes_patch_id + 2 == 9); /* +1 for LDS padding */
720             lds_vertex_size = MAX2(lds_vertex_size, 9);
721          } else {
722             STATIC_ASSERT(lds_tes_v + 1 == 7);
723             lds_vertex_size = MAX2(lds_vertex_size, 7);
724          }
725       }
726    }
727 
728    return lds_vertex_size;
729 }
730 
731 /**
732  * Returns an `[N x i32] addrspace(LDS)*` pointing at contiguous LDS storage
733  * for the vertex outputs.
734  */
ngg_nogs_vertex_ptr(struct si_shader_context * ctx,LLVMValueRef vtxid)735 static LLVMValueRef ngg_nogs_vertex_ptr(struct si_shader_context *ctx, LLVMValueRef vtxid)
736 {
737    /* The extra dword is used to avoid LDS bank conflicts. */
738    unsigned vertex_size = ngg_nogs_vertex_size(ctx->shader);
739    LLVMTypeRef ai32 = LLVMArrayType(ctx->ac.i32, vertex_size);
740    LLVMTypeRef pai32 = LLVMPointerType(ai32, AC_ADDR_SPACE_LDS);
741    LLVMValueRef tmp = LLVMBuildBitCast(ctx->ac.builder, ctx->esgs_ring, pai32, "");
742    return LLVMBuildGEP(ctx->ac.builder, tmp, &vtxid, 1, "");
743 }
744 
/* Insert the 4 elements of the vector argument `param` into the aggregate `ret`
 * at member indices return_index .. return_index + 3, and return the updated
 * aggregate. */
static LLVMValueRef si_insert_input_v4i32(struct si_shader_context *ctx, LLVMValueRef ret,
                                          struct ac_arg param, unsigned return_index)
{
   LLVMValueRef vec = ac_get_arg(&ctx->ac, param);
   unsigned chan = 0;

   while (chan < 4) {
      LLVMValueRef elem = ac_llvm_extract_elem(&ctx->ac, vec, chan);
      ret = LLVMBuildInsertValue(ctx->ac.builder, ret, elem, return_index + chan, "");
      chan++;
   }

   return ret;
}
756 
/**
 * Load the 8-bit per-wave surviving-vertex counts from LDS and reduce them.
 *
 * The counts are packed one byte per wave (byte offset = wave ID in the
 * threadgroup, as stored by gfx10_ngg_culling_build_end). They are read back as
 * i32 words, each holding 4 counts.
 *
 * \param lds          LDS pointer to the packed 8-bit per-wave counts.
 * \param max_waves    maximum number of waves per threadgroup. It determines how
 *                     many i32 words are read and the width of the byte masks.
 * \param tid          thread ID within the wave; lane i loads word i.
 * \param total_count  out: sum of the counts of all waves that exist in the
 *                     threadgroup, made wave-uniform by a final readlane.
 * \param prefix_sum   out: sum of the counts of all waves whose wave ID is lower
 *                     than the current wave's.
 *
 * NOTE(review): the byte masks are at most 64 bits wide (i64 when max_waves > 4),
 * so this appears to assume max_waves <= 8. Confirm against the NGG subgroup size.
 */
load_vertex_counts(struct si_shader_context * ctx,LLVMValueRef lds,unsigned max_waves,LLVMValueRef tid,LLVMValueRef * total_count,LLVMValueRef * prefix_sum)757 static void load_vertex_counts(struct si_shader_context *ctx, LLVMValueRef lds,
758                                unsigned max_waves, LLVMValueRef tid,
759                                LLVMValueRef *total_count,
760                                LLVMValueRef *prefix_sum)
761 {
762    LLVMBuilderRef builder = ctx->ac.builder;
   /* Per-lane i32 word of 4 packed counts; only lanes < num_i8vec4 initialize it. */
763    LLVMValueRef i8vec4_lane = ac_build_alloca_undef(&ctx->ac, ctx->ac.i32, "");
764    unsigned num_i8vec4 = DIV_ROUND_UP(max_waves, 4);
765 
766    /* If all threads loaded the vertex counts, it would cause many LDS bank conflicts
767     * and the performance could decrease up to WaveSize times (32x or 64x).
768     *
769     * Therefore, only load the i-th tuple of vertex counts in the i-th thread. Other threads will
770     * get them through readlane. 4 8-bit vertex counts are loaded per thread.
771     */
772    ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntULT, tid,
773                                          LLVMConstInt(ctx->ac.i32, num_i8vec4, 0), ""), 17771);
774    LLVMBuildStore(builder, LLVMBuildLoad2(builder, ctx->ac.i32, ac_build_gep0(&ctx->ac, lds, tid), ""), i8vec4_lane);
775    ac_build_endif(&ctx->ac, 17771);
776 
777    /* Compute the number of ES waves. */
778    LLVMValueRef num_waves = get_tgsize(ctx);
779 
780    /* Compute a byte mask where each byte is either 0 or 0xff depending on whether the wave
781     * exists. We need the mask to clear uninitialized bytes in LDS and to compute the prefix sum.
782     *
783     * 8 waves: valid_mask = ~0ull >> (64 - num_waves * 8)
784     * 4 waves: valid_mask = ~0 >> (32 - num_waves * 8)
785     */
786    LLVMValueRef num_waves8 = LLVMBuildShl(builder, num_waves, LLVMConstInt(ctx->ac.i32, 3, 0), "");
787    LLVMValueRef valid_mask;
788 
789    if (max_waves > 4) {
790       LLVMValueRef num_waves8_rev = LLVMBuildSub(builder, LLVMConstInt(ctx->ac.i32, 64, 0),
791                                                  num_waves8, "");
792       valid_mask = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i64, ~0ull, 0),
793                                  LLVMBuildZExt(builder, num_waves8_rev, ctx->ac.i64, ""), "");
794    } else {
795       LLVMValueRef num_waves8_rev = LLVMBuildSub(builder, LLVMConstInt(ctx->ac.i32, 32, 0),
796                                                  num_waves8, "");
797       valid_mask = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i32, ~0, 0), num_waves8_rev, "");
798    }
799 
800    /* Compute a byte mask where bytes below wave_id are 0xff, else they are 0.
801     *
802     * prefix_mask = ~(~0 << (wave_id * 8))
803     */
804    LLVMTypeRef type = max_waves > 4 ? ctx->ac.i64 : ctx->ac.i32;
805    LLVMValueRef wave_id8 = LLVMBuildShl(builder, get_wave_id_in_tg(ctx),
806                                         LLVMConstInt(ctx->ac.i32, 3, 0), "");
807    LLVMValueRef prefix_mask =
808       LLVMBuildNot(builder, LLVMBuildShl(builder, LLVMConstInt(type, ~0ull, 0),
809                                          LLVMBuildZExt(builder, wave_id8, type, ""), ""), "");
810 
811    /* Compute the total vertex count and the vertex count of previous waves (prefix). */
812    *total_count = ctx->ac.i32_0;
813    *prefix_sum = ctx->ac.i32_0;
814 
815    for (unsigned i = 0; i < num_i8vec4; i++) {
816       LLVMValueRef i8vec4;
817 
      /* Lane i holds word i, so readlane(i) broadcasts counts of waves 4*i .. 4*i+3. */
818       i8vec4 = ac_build_readlane_no_opt_barrier(&ctx->ac, LLVMBuildLoad2(builder, ctx->ac.i32, i8vec4_lane, ""),
819                                                 LLVMConstInt(ctx->ac.i32, i, 0));
820       /* Inactive waves have uninitialized vertex counts. Set them to 0 using this. */
821       i8vec4 = LLVMBuildAnd(builder, i8vec4,
822                             ac_unpack_param(&ctx->ac, valid_mask, 32 * i, 32), "");
823       /* Compute the sum of all i8vec4 components and add it to the result.
          * sad.u8(x, 0, acc) adds |byte - 0| for each byte of x to acc, i.e. the byte sum. */
824       *total_count = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.sad.u8", ctx->ac.i32,
825                                         (LLVMValueRef[]){i8vec4, ctx->ac.i32_0, *total_count},
826                                         3, AC_FUNC_ATTR_READNONE);
827       ac_set_range_metadata(&ctx->ac, *total_count, 0, 64*4 + 1); /* the result is at most 64*4 */
828 
829       /* Compute the sum of the vertex counts of all previous waves. */
830       i8vec4 = LLVMBuildAnd(builder, i8vec4,
831                                 ac_unpack_param(&ctx->ac, prefix_mask, 32 * i, 32), "");
832       *prefix_sum = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.sad.u8", ctx->ac.i32,
833                                        (LLVMValueRef[]){i8vec4, ctx->ac.i32_0, *prefix_sum},
834                                        3, AC_FUNC_ATTR_READNONE);
835       ac_set_range_metadata(&ctx->ac, *prefix_sum, 0, 64*4 + 1); /* the result is at most 64*4 */
836    }
   /* Make the total explicitly wave-uniform. */
837    *total_count = ac_build_readlane_no_opt_barrier(&ctx->ac, *total_count, NULL);
838 }
839 
840 /**
841  * Given a total thread count, update total and per-wave thread counts in input SGPRs
842  * and return the per-wave thread count.
843  *
844  * \param new_num_threads    Total thread count on the input, per-wave thread count on the output.
845  * \param tg_info            tg_info SGPR value
846  * \param tg_info_num_bits   the bit size of thread count field in tg_info
847  * \param tg_info_shift      the bit offset of the thread count field in tg_info
848  * \param wave_info          merged_wave_info SGPR value
849  * \param wave_info_num_bits the bit size of thread count field in merged_wave_info
850  * \param wave_info_shift    the bit offset of the thread count field in merged_wave_info
851  */
update_thread_counts(struct si_shader_context * ctx,LLVMValueRef * new_num_threads,LLVMValueRef * tg_info,unsigned tg_info_num_bits,unsigned tg_info_shift,LLVMValueRef * wave_info,unsigned wave_info_num_bits,unsigned wave_info_shift)852 static void update_thread_counts(struct si_shader_context *ctx, LLVMValueRef *new_num_threads,
853                                  LLVMValueRef *tg_info, unsigned tg_info_num_bits,
854                                  unsigned tg_info_shift, LLVMValueRef *wave_info,
855                                  unsigned wave_info_num_bits, unsigned wave_info_shift)
856 {
857    LLVMBuilderRef builder = ctx->ac.builder;
858 
859    /* Update the total thread count. */
860    unsigned tg_info_mask = ~(u_bit_consecutive(0, tg_info_num_bits) << tg_info_shift);
861    *tg_info = LLVMBuildAnd(builder, *tg_info, LLVMConstInt(ctx->ac.i32, tg_info_mask, 0), "");
862    *tg_info = LLVMBuildOr(
863       builder, *tg_info,
864       LLVMBuildShl(builder, *new_num_threads, LLVMConstInt(ctx->ac.i32, tg_info_shift, 0), ""), "");
865 
866    /* Update the per-wave thread count. */
867    LLVMValueRef prev_threads = LLVMBuildMul(builder, get_wave_id_in_tg(ctx),
868                                             LLVMConstInt(ctx->ac.i32, ctx->ac.wave_size, 0), "");
869    *new_num_threads = LLVMBuildSub(builder, *new_num_threads, prev_threads, "");
870    *new_num_threads = ac_build_imax(&ctx->ac, *new_num_threads, ctx->ac.i32_0);
871    *new_num_threads =
872       ac_build_imin(&ctx->ac, *new_num_threads, LLVMConstInt(ctx->ac.i32, ctx->ac.wave_size, 0));
873    unsigned wave_info_mask = ~(u_bit_consecutive(0, wave_info_num_bits) << wave_info_shift);
874    *wave_info = LLVMBuildAnd(builder, *wave_info, LLVMConstInt(ctx->ac.i32, wave_info_mask, 0), "");
875    *wave_info = LLVMBuildOr(
876       builder, *wave_info,
877       LLVMBuildShl(builder, *new_num_threads, LLVMConstInt(ctx->ac.i32, wave_info_shift, 0), ""),
878       "");
879 }
880 
/**
 * Callback for ac_cull_primitive(), invoked with the i1 `accepted` result of the
 * culling tests for the current primitive.
 *
 * `userdata` is a two-element LLVMValueRef array:
 *   [0] an i32 alloca that receives 1 when the primitive is accepted;
 *   [1] a C pointer (carried in an LLVMValueRef slot) to an array of per-vertex
 *       LDS pointers, or NULL. When non-NULL, the accept-flag byte of every
 *       vertex of an accepted primitive is set to 1.
 */
static void gfx10_build_primitive_accepted(struct ac_llvm_context *ac, LLVMValueRef accepted,
                                           void *userdata)
{
   struct si_shader_context *ctx = container_of(ac, struct si_shader_context, ac);
   LLVMValueRef *cb_args = (LLVMValueRef *)userdata;
   LLVMValueRef prim_accepted_var = cb_args[0];
   LLVMValueRef *vertex_ptrs = (LLVMValueRef *)cb_args[1];
   unsigned vertex_count;

   ngg_get_vertices_per_prim(ctx, &vertex_count);

   ac_build_ifcc(&ctx->ac, accepted, 0);
   LLVMBuildStore(ctx->ac.builder, ctx->ac.i32_1, prim_accepted_var);

   /* Mark every vertex of the accepted primitive; ES threads read this flag later. */
   for (unsigned v = 0; vertex_ptrs && v < vertex_count; ++v) {
      LLVMValueRef flag_ptr = si_build_gep_i8(ctx, vertex_ptrs[v], lds_byte0_accept_flag);
      LLVMBuildStore(ctx->ac.builder, ctx->ac.i8_1, flag_ptr);
   }
   ac_build_endif(&ctx->ac, 0);
}
903 
/* OR bit (24 + i) into *packed_data when `distance` < 0.0, i.e. set bit i of byte 3
 * (lds_byte3_clipdist_neg_mask). The comparison is ordered, so NaN does not set it. */
static void add_clipdist_bit(struct si_shader_context *ctx, LLVMValueRef distance, unsigned i,
                             LLVMValueRef *packed_data)
{
   LLVMBuilderRef b = ctx->ac.builder;
   LLVMValueRef is_negative = LLVMBuildFCmp(b, LLVMRealOLT, distance, ctx->ac.f32_0, "");
   LLVMValueRef bit = LLVMBuildZExt(b, is_negative, ctx->ac.i32, "");
   LLVMValueRef bit_pos = LLVMConstInt(ctx->ac.i32, 24 + i, 0);

   bit = LLVMBuildShl(b, bit, bit_pos, "");
   *packed_data = LLVMBuildOr(b, *packed_data, bit, "");
}
913 
/* Convert `clipvertex` into 8 clip distances and add the negative-distance bit of
 * every distance enabled in `clipdist_enable` to *packed_data.
 * Returns true if at least one distance (any of enable bits 0..7) was added. */
static bool add_clipdist_bits_for_clipvertex(struct si_shader_context *ctx,
                                             unsigned clipdist_enable,
                                             LLVMValueRef clipvertex[4],
                                             LLVMValueRef *packed_data)
{
   struct ac_export_args clipdist[2];

   si_llvm_clipvertex_to_clipdist(ctx, clipdist, clipvertex);

   for (unsigned plane = 0; plane < 8; plane++) {
      if (clipdist_enable & BITFIELD_BIT(plane))
         add_clipdist_bit(ctx, clipdist[plane / 4].out[plane % 4], plane, packed_data);
   }

   /* Only enable bits 0..7 are visited above. */
   return (clipdist_enable & 0xff) != 0;
}
934 
/**
 * Run the primitive culling tests for the current GS thread's primitive and report
 * the result through gfx10_build_primitive_accepted().
 *
 * \param pos                per-vertex positions. The caller fills channels 0, 1 and 3
 *                           (X/W, Y/W, W) and leaves channel 2 NULL.
 *                           NOTE(review): confirm ac_cull_primitive never reads channel 2.
 * \param clipdist_accepted  i1; false when clip/cull distances reject the primitive.
 * \param out_prim_accepted  i32 alloca, set to 1 if the primitive is accepted.
 * \param gs_vtxptr_accept   per-vertex LDS pointers whose accept-flag byte is set to 1
 *                           if the primitive is accepted.
 */
cull_primitive(struct si_shader_context * ctx,LLVMValueRef pos[3][4],LLVMValueRef clipdist_accepted,LLVMValueRef out_prim_accepted,LLVMValueRef gs_vtxptr_accept[3])935 static void cull_primitive(struct si_shader_context *ctx,
936                            LLVMValueRef pos[3][4], LLVMValueRef clipdist_accepted,
937                            LLVMValueRef out_prim_accepted, LLVMValueRef gs_vtxptr_accept[3])
938 {
939    struct si_shader *shader = ctx->shader;
940    LLVMBuilderRef builder = ctx->ac.builder;
941 
942    LLVMValueRef vp_scale[2] = {}, vp_translate[2] = {}, small_prim_precision = NULL;
943    LLVMValueRef clip_half_line_width[2] = {};
944 
945    /* Load the viewport state for small prim culling.
       * Element 0 (triangles) or 1 (lines) of small_prim_cull_info is read as 4 floats:
       * scale.xy, translate.xy. */
946    bool prim_is_lines = shader->key.ge.opt.ngg_culling & SI_NGG_CULL_LINES;
947    LLVMValueRef ptr = ac_get_arg(&ctx->ac, ctx->small_prim_cull_info);
948    /* Lines will always use the non-AA viewport transformation. */
949    LLVMValueRef vp = ac_build_load_to_sgpr(&ctx->ac, ptr,
950                                            prim_is_lines ? ctx->ac.i32_1 : ctx->ac.i32_0);
951    vp = LLVMBuildBitCast(builder, vp, ctx->ac.v4f32, "");
952    vp_scale[0] = ac_llvm_extract_elem(&ctx->ac, vp, 0);
953    vp_scale[1] = ac_llvm_extract_elem(&ctx->ac, vp, 1);
954    vp_translate[0] = ac_llvm_extract_elem(&ctx->ac, vp, 2);
955    vp_translate[1] = ac_llvm_extract_elem(&ctx->ac, vp, 3);
956 
957    /* Execute culling code. */
958    struct ac_cull_options options = {};
959    options.cull_view_xy = true;
960    options.cull_w = true;
961 
962    if (prim_is_lines) {
      /* The two clip-space half line widths are read as v2f32 element 4 of the same
       * buffer. NOTE(review): presumably byte offset 32, just past two 16-byte viewport
       * entries, if ac_build_load_to_sgpr indexes in pointee-type elements. Confirm. */
963       ptr = LLVMBuildPointerCast(ctx->ac.builder, ptr,
964                                  LLVMPointerType(ctx->ac.v2f32, AC_ADDR_SPACE_CONST_32BIT), "");
965       LLVMValueRef terms = ac_build_load_to_sgpr(&ctx->ac, ptr, LLVMConstInt(ctx->ac.i32, 4, 0));
966       terms = LLVMBuildBitCast(builder, terms, ctx->ac.v2f32, "");
967       clip_half_line_width[0] = ac_llvm_extract_elem(&ctx->ac, terms, 0);
968       clip_half_line_width[1] = ac_llvm_extract_elem(&ctx->ac, terms, 1);
969       small_prim_precision = GET_FIELD(ctx, GS_STATE_SMALL_PRIM_PRECISION_NO_AA);
970 
971       options.num_vertices = 2;
972       options.cull_small_prims = shader->key.ge.opt.ngg_culling & SI_NGG_CULL_SMALL_LINES_DIAMOND_EXIT;
973 
974       assert(!(shader->key.ge.opt.ngg_culling & SI_NGG_CULL_BACK_FACE));
975       assert(!(shader->key.ge.opt.ngg_culling & SI_NGG_CULL_FRONT_FACE));
976    } else {
977       /* Get the small prim filter precision. */
978       small_prim_precision = GET_FIELD(ctx, GS_STATE_SMALL_PRIM_PRECISION);
979 
980       options.num_vertices = 3;
981       options.cull_front = shader->key.ge.opt.ngg_culling & SI_NGG_CULL_FRONT_FACE;
982       options.cull_back = shader->key.ge.opt.ngg_culling & SI_NGG_CULL_BACK_FACE;
983       options.cull_small_prims = true; /* this would only be false with conservative rasterization */
984       options.cull_zero_area = options.cull_front || options.cull_back;
985    }
986 
987    /* Turn the small prim precision field into a float. (field | 0x70) << 23 places the
       * value in the f32 exponent with zero sign and mantissa, giving 2^((field | 0x70) - 127).
       * For field values below 16 this is 2^(field - 15). */
988    small_prim_precision =
989       LLVMBuildOr(builder, small_prim_precision, LLVMConstInt(ctx->ac.i32, 0x70, 0), "");
990    small_prim_precision =
991       LLVMBuildShl(builder, small_prim_precision, LLVMConstInt(ctx->ac.i32, 23, 0), "");
992    small_prim_precision = LLVMBuildBitCast(builder, small_prim_precision, ctx->ac.f32, "");
993 
994    /* Tell ES threads whether their vertex survived. params[1] carries a C pointer to the
       * gs_vtxptr_accept array; gfx10_build_primitive_accepted casts it back. */
995    LLVMValueRef params[] = {
996       out_prim_accepted,
997       (void*)gs_vtxptr_accept,
998    };
999    ac_cull_primitive(&ctx->ac, pos, clipdist_accepted, vp_scale, vp_translate,
1000                      small_prim_precision, clip_half_line_width,
1001                      &options, gfx10_build_primitive_accepted, params);
1002 }
1003 
1004 /**
1005  * Cull primitives for NGG VS or TES, then compact vertices, which happens
1006  * before the VS or TES main function. Return values for the main function.
1007  * Also return the position, which is passed to the shader as an input,
1008  * so that we don't compute it twice.
1009  */
gfx10_ngg_culling_build_end(struct si_shader_context * ctx)1010 void gfx10_ngg_culling_build_end(struct si_shader_context *ctx)
1011 {
1012    struct si_shader *shader = ctx->shader;
1013    struct si_shader_selector *sel = shader->selector;
1014    struct si_shader_info *info = &sel->info;
1015    LLVMBuilderRef builder = ctx->ac.builder;
1016    LLVMValueRef *addrs = ctx->abi.outputs;
1017    unsigned max_waves = DIV_ROUND_UP(ctx->screen->ngg_subgroup_size, ctx->ac.wave_size);
1018 
1019    assert(shader->key.ge.opt.ngg_culling);
1020    assert(shader->key.ge.as_ngg);
1021    assert(sel->stage == MESA_SHADER_VERTEX ||
1022           (sel->stage == MESA_SHADER_TESS_EVAL && !shader->key.ge.as_es));
1023 
1024    LLVMValueRef es_vtxptr = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));
1025    LLVMValueRef packed_data = ctx->ac.i32_0;
1026    LLVMValueRef position[4] = {};
1027    unsigned pos_index = 0;
1028    unsigned clip_plane_enable = SI_NGG_CULL_GET_CLIP_PLANE_ENABLE(shader->key.ge.opt.ngg_culling);
1029    unsigned clipdist_enable = (sel->info.clipdist_mask & clip_plane_enable) | sel->info.culldist_mask;
1030    bool has_clipdist_mask = false;
1031 
1032    for (unsigned i = 0; i < info->num_outputs; i++) {
1033       LLVMValueRef clipvertex[4];
1034       unsigned base;
1035 
1036       switch (info->output_semantic[i]) {
1037       case VARYING_SLOT_POS:
1038          /* If we are going to cull everything (rasterizer_discard), discard
1039           * the position. This is useful for analyzing maximum theoretical
1040           * performance without VS input loads.
1041           */
1042          if (shader->key.ge.opt.ngg_culling & SI_NGG_CULL_FRONT_FACE &&
1043              shader->key.ge.opt.ngg_culling & SI_NGG_CULL_BACK_FACE) {
1044             for (unsigned j = 0; j < 4; j++)
1045                LLVMBuildStore(builder, LLVMGetUndef(ctx->ac.f32), addrs[4 * i + j]);
1046             break;
1047          }
1048 
1049          pos_index = i;
1050          for (unsigned j = 0; j < 4; j++) {
1051             position[j] = LLVMBuildLoad2(ctx->ac.builder, ctx->ac.f32, addrs[4 * i + j], "");
1052          }
1053 
1054          /* Store Position.W into LDS. */
1055          LLVMBuildStore(
1056             builder, ac_to_integer(&ctx->ac, position[3]),
1057             ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_pos_cull_w, 0)));
1058 
1059          /* Store Position.XY / W into LDS. */
1060          for (unsigned chan = 0; chan < 2; chan++) {
1061             LLVMValueRef val = ac_build_fdiv(&ctx->ac, position[chan], position[3]);
1062             LLVMBuildStore(
1063                builder, ac_to_integer(&ctx->ac, val),
1064                ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_pos_cull_x_div_w + chan, 0)));
1065          }
1066          break;
1067 
1068       case VARYING_SLOT_CLIP_DIST0:
1069       case VARYING_SLOT_CLIP_DIST1:
1070          base = info->output_semantic[i] == VARYING_SLOT_CLIP_DIST1 ? 4 : 0;
1071 
1072          for (unsigned j = 0; j < 4; j++) {
1073             unsigned index = base + j;
1074 
1075             if (!(clipdist_enable & BITFIELD_BIT(index)))
1076                continue;
1077 
1078             LLVMValueRef distance = LLVMBuildLoad2(ctx->ac.builder, ctx->ac.f32, addrs[4 * i + j], "");
1079             add_clipdist_bit(ctx, distance, index, &packed_data);
1080             has_clipdist_mask = true;
1081          }
1082          break;
1083 
1084       case VARYING_SLOT_CLIP_VERTEX:
1085          for (unsigned j = 0; j < 4; j++)
1086             clipvertex[j] = LLVMBuildLoad2(ctx->ac.builder, ctx->ac.f32, addrs[4 * i + j], "");
1087 
1088          if (add_clipdist_bits_for_clipvertex(ctx, clipdist_enable, clipvertex, &packed_data))
1089             has_clipdist_mask = true;
1090          break;
1091       }
1092    }
1093 
1094    if (clip_plane_enable && !sel->info.clipdist_mask) {
1095       /* When clip planes are enabled and there are no clip distance outputs,
1096        * we should use user clip planes and cull against the position.
1097        */
1098       assert(!has_clipdist_mask);
1099       if (add_clipdist_bits_for_clipvertex(ctx, clipdist_enable, position, &packed_data))
1100          has_clipdist_mask = true;
1101    }
1102 
1103    /* Initialize the packed data. */
1104    LLVMBuildStore(
1105       builder, packed_data,
1106       ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_packed_data, 0)));
1107    ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);
1108 
1109    ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1110    ac_build_s_barrier(&ctx->ac, ctx->stage);
1111 
1112    LLVMValueRef tid = ac_get_thread_id(&ctx->ac);
1113 
1114    unsigned num_vertices;
1115    ngg_get_vertices_per_prim(ctx, &num_vertices);
1116 
1117    /* The hardware requires that there are no holes between unculled vertices,
1118     * which means we have to pack ES threads, i.e. reduce the ES thread count
1119     * and move ES input VGPRs to lower threads. The upside is that varyings
1120     * are only fetched and computed for unculled vertices.
1121     *
1122     * Vertex compaction:
1123     *
1124     * Part 1: Store the surviving vertex count for each wave in LDS.
1125     *   - The GS culling code notifies ES threads which vertices were accepted.
1126     *   - Barrier
1127     *   - ES threads will compute the vertex count and store it in LDS.
1128     * - Barrier
1129     * - Each wave loads the vertex counts from LDS.
1130     *
1131     * Part 2: Compact ES threads:
1132     * - Compute the prefix sum for each surviving vertex. This is the new thread ID
1133     *   of the vertex.
1134     * - Write input VGPRs and vertex positions for each surviving vertex into the LDS
1135     *   address of the new thread ID.
1136     * - Now kill all waves that have inactive threads.
1137     * - Barrier
1138     * - Update vertex indices and null flag in the GS input VGPRs.
1139     *
1140     * Part 3: Update inputs GPRs
1141     * - For all waves, update per-wave thread counts in input SGPRs.
1142     * - In ES threads, update the ES input VGPRs (VertexID, InstanceID, TES inputs).
1143     */
1144 
1145    LLVMValueRef vtxindex[3];
1146    for (unsigned i = 0; i < num_vertices; ++i)
1147       vtxindex[i] = si_unpack_param(ctx, ctx->args.gs_vtx_offset[i / 2], (i & 1) * 16, 16);
1148 
1149    LLVMValueRef gs_vtxptr[3];
1150    for (unsigned i = 0; i < num_vertices; i++)
1151       gs_vtxptr[i] = ngg_nogs_vertex_ptr(ctx, vtxindex[i]);
1152 
1153    es_vtxptr = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));
1154 
1155    /* Adding these optimization barriers improves the generated code as follows. Crazy right?
1156     *
1157     * - s_mov_b32 s4, 0xffff
1158     * - v_lshrrev_b32_e32 v10, 16, v0
1159     * - v_and_b32_e32 v12, s4, v0
1160     * - v_and_b32_e32 v11, s4, v1
1161     *   s_bfe_u32 s4, s3, 0x80008
1162     * - s_mov_b64 s[8:9], 0
1163     * - v_mul_u32_u24_e32 v0, 28, v10
1164     * - v_mul_u32_u24_e32 v9, 28, v12
1165     * - v_mul_u32_u24_e32 v1, 28, v11
1166     * + v_mov_b32_e32 v11, 28
1167     *   v_cmp_gt_u32_e32 vcc, s4, v2
1168     * + s_mov_b64 s[8:9], 0
1169     *   s_waitcnt lgkmcnt(0)
1170     *   s_barrier
1171     * + v_mul_u32_u24_sdwa v10, v0, v11 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_0 src1_sel:DWORD
1172     * + v_mul_u32_u24_sdwa v23, v0, v11 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
1173     * + v_mul_u32_u24_sdwa v0, v1, v11 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_0 src1_sel:DWORD
1174     *   s_and_saveexec_b64 s[44:45], vcc
1175     *   s_cbranch_execz BB2_8
1176     * - v_mul_u32_u24_e32 v16, 28, v12
1177     * - v_mul_u32_u24_e32 v17, 28, v11
1178     * - v_mul_u32_u24_e32 v18, 28, v10
1179     */
1180    for (unsigned i = 0; i < num_vertices; i++)
1181       ac_build_optimization_barrier(&ctx->ac, &gs_vtxptr[i], false);
1182 
1183    LLVMValueRef gs_accepted = ac_build_alloca(&ctx->ac, ctx->ac.i32, "");
1184 
1185    /* Do culling in GS threads. */
1186    ac_build_ifcc(&ctx->ac, si_is_gs_thread(ctx), 16002);
1187    {
1188       /* Load positions. */
1189       LLVMValueRef pos[3][4] = {};
1190       LLVMValueRef clipdist_neg_mask = NULL;
1191 
1192       for (unsigned vtx = 0; vtx < num_vertices; vtx++) {
1193          for (unsigned chan = 0; chan < 4; chan++) {
1194             unsigned index;
1195             if (chan == 0 || chan == 1)
1196                index = lds_pos_cull_x_div_w + chan;
1197             else if (chan == 3)
1198                index = lds_pos_cull_w;
1199             else
1200                continue;
1201 
1202             LLVMValueRef addr =
1203                ac_build_gep0(&ctx->ac, gs_vtxptr[vtx], LLVMConstInt(ctx->ac.i32, index, 0));
1204             pos[vtx][chan] = LLVMBuildLoad(builder, addr, "");
1205             pos[vtx][chan] = ac_to_float(&ctx->ac, pos[vtx][chan]);
1206          }
1207 
1208          if (has_clipdist_mask) {
1209             /* Load and AND clip distance masks. Each bit means whether that clip distance is
1210              * negative. If all masks are AND'ed and the result is 0, the primitive isn't culled
1211              * by clip distances.
1212              */
1213             LLVMValueRef addr = si_build_gep_i8(ctx, gs_vtxptr[vtx], lds_byte3_clipdist_neg_mask);
1214             LLVMValueRef mask = LLVMBuildLoad2(builder, ctx->ac.i8, addr, "");
1215             if (!clipdist_neg_mask)
1216                clipdist_neg_mask = mask;
1217             else
1218                clipdist_neg_mask = LLVMBuildAnd(builder, clipdist_neg_mask, mask, "");
1219          }
1220       }
1221 
1222       LLVMValueRef clipdist_accepted =
1223          has_clipdist_mask ? LLVMBuildICmp(builder, LLVMIntEQ, clipdist_neg_mask, ctx->ac.i8_0, "")
1224                            : ctx->ac.i1true;
1225 
1226       cull_primitive(ctx, pos, clipdist_accepted, gs_accepted, gs_vtxptr);
1227    }
1228    ac_build_endif(&ctx->ac, 16002);
1229 
1230    ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1231    ac_build_s_barrier(&ctx->ac, ctx->stage);
1232 
1233    gs_accepted = LLVMBuildLoad2(builder, ctx->ac.i32, gs_accepted, "");
1234 
1235    LLVMValueRef vertex_accepted = ac_build_alloca(&ctx->ac, ctx->ac.i1, "");
1236    LLVMValueRef vertex_mask = ac_build_alloca(&ctx->ac, ctx->ac.iN_wavemask, "");
1237 
1238    /* Convert the per-vertex accept flag to a vertex thread mask, store it in registers. */
1239    ac_build_ifcc(&ctx->ac, si_is_es_thread(ctx), 16007);
1240    {
1241       LLVMValueRef accepted =
1242          LLVMBuildLoad2(builder, ctx->ac.i8, si_build_gep_i8(ctx, es_vtxptr, lds_byte0_accept_flag), "");
1243       accepted = LLVMBuildICmp(builder, LLVMIntNE, accepted, ctx->ac.i8_0, "");
1244       LLVMValueRef mask = ac_get_i1_sgpr_mask(&ctx->ac, accepted);
1245 
1246       LLVMBuildStore(builder, accepted, vertex_accepted);
1247       LLVMBuildStore(builder, mask, vertex_mask);
1248    }
1249    ac_build_endif(&ctx->ac, 16007);
1250 
1251    /* Store the per-wave vertex count to LDS. Non-ES waves store 0. */
1252    vertex_mask = LLVMBuildLoad2(builder, ctx->ac.iN_wavemask, vertex_mask, "");
1253    ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, ctx->ac.i32_0, ""), 16008);
1254    {
1255       LLVMValueRef vertex_count = ac_build_bit_count(&ctx->ac, vertex_mask);
1256       LLVMBuildStore(builder, LLVMBuildTrunc(builder, vertex_count, ctx->ac.i8, ""),
1257                      si_build_gep_i8_var(ctx, ctx->gs_ngg_scratch, get_wave_id_in_tg(ctx)));
1258    }
1259    ac_build_endif(&ctx->ac, 16008);
1260 
1261    ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1262    ac_build_s_barrier(&ctx->ac, ctx->stage);
1263 
1264    /* Load the vertex masks and compute the new ES thread count. */
1265    LLVMValueRef new_num_es_threads, prefix_sum, kill_wave;
1266    load_vertex_counts(ctx, ctx->gs_ngg_scratch, max_waves, tid, &new_num_es_threads,
1267                       &prefix_sum);
1268 
1269    bool uses_instance_id = ctx->stage == MESA_SHADER_VERTEX &&
1270                            (sel->info.uses_instanceid ||
1271                             shader->key.ge.part.vs.prolog.instance_divisor_is_one ||
1272                             shader->key.ge.part.vs.prolog.instance_divisor_is_fetched);
1273    bool uses_tes_prim_id = ctx->stage == MESA_SHADER_TESS_EVAL &&
1274                            (sel->info.uses_primid || shader->key.ge.mono.u.vs_export_prim_id);
1275 
1276    /* ES threads compute their prefix sum, which is the new ES thread ID.
1277     * Then they write the vertex position and input VGPRs into the LDS address
1278     * of the new thread ID. It will be used to load input VGPRs by compacted
1279     * threads.
1280     */
1281    vertex_accepted = LLVMBuildLoad2(builder, ctx->ac.i1, vertex_accepted, "");
1282    ac_build_ifcc(&ctx->ac, vertex_accepted, 16009);
1283    {
1284       /* Add the number of bits set in vertex_mask up to the current thread ID - 1
1285        * to get the prefix sum.
1286        */
1287       prefix_sum = LLVMBuildAdd(builder, prefix_sum, ac_build_mbcnt(&ctx->ac, vertex_mask), "");
1288 
1289       LLVMValueRef new_id = prefix_sum;
1290       LLVMValueRef new_vtx = ngg_nogs_vertex_ptr(ctx, new_id);
1291 
1292       LLVMBuildStore(builder, LLVMBuildTrunc(builder, new_id, ctx->ac.i8, ""),
1293                      si_build_gep_i8(ctx, es_vtxptr, lds_byte1_new_thread_id));
1294 
1295       /* Store Position.XYZW into LDS. */
1296       for (unsigned chan = 0; chan < 4; chan++) {
1297          LLVMBuildStore(
1298             builder, ac_to_integer(&ctx->ac,
1299                                    LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * pos_index + chan], "")),
1300             ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_pos_x + chan, 0)));
1301       }
1302 
1303       /* Store VertexID and InstanceID into LDS. ES threads will have to load them
1304        * from LDS after vertex compaction and use them instead of their own
1305        * system values.
1306        */
1307       if (ctx->stage == MESA_SHADER_VERTEX) {
1308          LLVMBuildStore(
1309             builder, ctx->abi.vertex_id,
1310             ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_vertex_id, 0)));
1311          if (uses_instance_id) {
1312             LLVMBuildStore(
1313                builder, ctx->abi.instance_id,
1314                ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_instance_id, 0)));
1315          }
1316       } else {
1317          assert(ctx->stage == MESA_SHADER_TESS_EVAL);
1318          LLVMBuildStore(builder, ac_to_integer(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args.tes_u)),
1319                         ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_tes_u, 0)));
1320          LLVMBuildStore(builder, ac_to_integer(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args.tes_v)),
1321                         ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_tes_v, 0)));
1322          LLVMBuildStore(builder, LLVMBuildTrunc(builder, ac_get_arg(&ctx->ac, ctx->args.tes_rel_patch_id), ctx->ac.i8, ""),
1323                         si_build_gep_i8(ctx, new_vtx, lds_byte2_tes_rel_patch_id));
1324          if (uses_tes_prim_id) {
1325             LLVMBuildStore(
1326                builder, ac_get_arg(&ctx->ac, ctx->args.tes_patch_id),
1327                ac_build_gep0(&ctx->ac, new_vtx, LLVMConstInt(ctx->ac.i32, lds_tes_patch_id, 0)));
1328          }
1329       }
1330    }
1331    ac_build_endif(&ctx->ac, 16009);
1332 
1333    /* If all vertices are culled, set the primitive count to 0, so that all waves are culled here. */
1334    LLVMValueRef num_primitives = ngg_get_prim_cnt(ctx);
1335    num_primitives = LLVMBuildSelect(builder,
1336                                     LLVMBuildICmp(builder, LLVMIntEQ, new_num_es_threads,
1337                                                   ctx->ac.i32_0, ""),
1338                                     ctx->ac.i32_0, num_primitives, "");
1339    /* Kill waves that have inactive threads. */
1340    kill_wave = LLVMBuildICmp(builder, LLVMIntULE,
1341                              ac_build_imax(&ctx->ac, new_num_es_threads, num_primitives),
1342                              LLVMBuildMul(builder, get_wave_id_in_tg(ctx),
1343                                           LLVMConstInt(ctx->ac.i32, ctx->ac.wave_size, 0), ""),
1344                              "");
1345    ac_build_ifcc(&ctx->ac, kill_wave, 19202);
1346    {
1347       /* If we are killing wave 0, send that there are no primitives
1348        * in this threadgroup.
1349        */
1350       ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx), ctx->ac.i32_0, ctx->ac.i32_0);
1351       ac_build_s_endpgm(&ctx->ac);
1352    }
1353    ac_build_endif(&ctx->ac, 19202);
1354 
1355    ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1356    ac_build_s_barrier(&ctx->ac, ctx->stage);
1357 
1358    /* Send the final vertex and primitive counts. */
1359    ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx), new_num_es_threads,
1360                                  ngg_get_prim_cnt(ctx));
1361 
1362    /* Update thread counts in SGPRs. */
1363    LLVMValueRef new_gs_tg_info = ac_get_arg(&ctx->ac, ctx->args.gs_tg_info);
1364    LLVMValueRef new_merged_wave_info = ac_get_arg(&ctx->ac, ctx->args.merged_wave_info);
1365 
1366    /* This also converts the thread count from the total count to the per-wave count. */
1367    update_thread_counts(ctx, &new_num_es_threads, &new_gs_tg_info, 9, 12, &new_merged_wave_info, 8,
1368                         0);
1369 
1370    /* Update vertex indices in VGPR0 (same format as NGG passthrough).
1371     *
1372     * Set the null flag at the beginning (culled), and then
1373     * overwrite it for accepted primitives.
1374     */
1375    LLVMValueRef new_vgpr0 =
1376       ac_build_alloca_init(&ctx->ac, LLVMConstInt(ctx->ac.i32, 1u << 31, 0), "");
1377 
1378    /* Get vertex indices after vertex compaction. */
1379    ac_build_ifcc(&ctx->ac, LLVMBuildTrunc(builder, gs_accepted, ctx->ac.i1, ""), 16011);
1380    {
1381       struct ac_ngg_prim prim = {};
1382       prim.num_vertices = num_vertices;
1383       prim.isnull = ctx->ac.i1false;
1384 
1385       if (gfx10_edgeflags_have_effect(shader))
1386          prim.edgeflags = ac_pack_edgeflags_for_export(&ctx->ac, &ctx->args);
1387       else
1388          prim.edgeflags = ctx->ac.i32_0;
1389 
1390       for (unsigned vtx = 0; vtx < num_vertices; vtx++) {
1391          prim.index[vtx] = LLVMBuildLoad2(
1392             builder, ctx->ac.i8, si_build_gep_i8(ctx, gs_vtxptr[vtx], lds_byte1_new_thread_id), "");
1393          prim.index[vtx] = LLVMBuildZExt(builder, prim.index[vtx], ctx->ac.i32, "");
1394       }
1395 
1396       /* Set the new GS input VGPR. */
1397       LLVMBuildStore(builder, ac_pack_prim_export(&ctx->ac, &prim), new_vgpr0);
1398    }
1399    ac_build_endif(&ctx->ac, 16011);
1400 
1401    if (gfx10_ngg_export_prim_early(shader))
1402       gfx10_ngg_build_export_prim(ctx, NULL, LLVMBuildLoad2(builder, ctx->ac.i32, new_vgpr0, ""));
1403 
1404    /* Prepare LDS addresses of the new ES input VGPRs. */
1405    LLVMValueRef input_vgpr_addresses[4] = {
1406       ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_vertex_id, 0)),
1407       ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_instance_id, 0)),
1408    };
1409    if (ctx->stage == MESA_SHADER_TESS_EVAL) {
1410       input_vgpr_addresses[2] = si_build_gep_i8(ctx, es_vtxptr, lds_byte2_tes_rel_patch_id);
1411       if (uses_tes_prim_id) {
1412          input_vgpr_addresses[3] = ac_build_gep0(&ctx->ac, es_vtxptr,
1413                                                  LLVMConstInt(ctx->ac.i32, lds_tes_patch_id, 0));
1414       }
1415    }
1416 
1417    /* Return values for the main function. */
1418    LLVMValueRef ret = ctx->return_value;
1419    LLVMValueRef val;
1420 
1421    ret = LLVMBuildInsertValue(ctx->ac.builder, ret, new_gs_tg_info, 2, "");
1422    ret = LLVMBuildInsertValue(ctx->ac.builder, ret, new_merged_wave_info, 3, "");
1423    if (ctx->stage == MESA_SHADER_TESS_EVAL)
1424       ret = si_insert_input_ret(ctx, ret, ctx->args.tess_offchip_offset, 4);
1425    if (ctx->ac.gfx_level >= GFX11)
1426       ret = si_insert_input_ret(ctx, ret, ctx->args.gs_attr_offset, 5);
1427 
1428    ret = si_insert_input_ptr(ctx, ret, ctx->internal_bindings, 8 + SI_SGPR_INTERNAL_BINDINGS);
1429    ret = si_insert_input_ptr(ctx, ret, ctx->bindless_samplers_and_images,
1430                              8 + SI_SGPR_BINDLESS_SAMPLERS_AND_IMAGES);
1431    ret = si_insert_input_ptr(ctx, ret, ctx->const_and_shader_buffers,
1432                              8 + SI_SGPR_CONST_AND_SHADER_BUFFERS);
1433    ret = si_insert_input_ptr(ctx, ret, ctx->samplers_and_images, 8 + SI_SGPR_SAMPLERS_AND_IMAGES);
1434    ret = si_insert_input_ptr(ctx, ret, ctx->vs_state_bits, 8 + SI_SGPR_VS_STATE_BITS);
1435    if (ctx->ac.gfx_level >= GFX11)
1436       ret = si_insert_input_ptr(ctx, ret, ctx->gs_attr_address, 8 + GFX9_SGPR_ATTRIBUTE_RING_ADDR);
1437 
1438    if (ctx->stage == MESA_SHADER_VERTEX) {
1439       ret = si_insert_input_ptr(ctx, ret, ctx->args.base_vertex, 8 + SI_SGPR_BASE_VERTEX);
1440       ret = si_insert_input_ptr(ctx, ret, ctx->args.draw_id, 8 + SI_SGPR_DRAWID);
1441       ret = si_insert_input_ptr(ctx, ret, ctx->args.start_instance, 8 + SI_SGPR_START_INSTANCE);
1442       ret = si_insert_input_ptr(ctx, ret, ctx->args.vertex_buffers, 8 + GFX9_GS_NUM_USER_SGPR);
1443 
1444       for (unsigned i = 0; i < shader->selector->info.num_vbos_in_user_sgprs; i++) {
1445          ret = si_insert_input_v4i32(ctx, ret, ctx->vb_descriptors[i],
1446                                      8 + SI_SGPR_VS_VB_DESCRIPTOR_FIRST + i * 4);
1447       }
1448    } else {
1449       assert(ctx->stage == MESA_SHADER_TESS_EVAL);
1450       ret = si_insert_input_ptr(ctx, ret, ctx->tcs_offchip_layout, 8 + SI_SGPR_TES_OFFCHIP_LAYOUT);
1451       ret = si_insert_input_ptr(ctx, ret, ctx->tes_offchip_addr, 8 + SI_SGPR_TES_OFFCHIP_ADDR);
1452    }
1453 
1454    unsigned vgpr;
1455    if (ctx->stage == MESA_SHADER_VERTEX) {
1456       if (shader->selector->info.num_vbos_in_user_sgprs) {
1457          vgpr = 8 + SI_SGPR_VS_VB_DESCRIPTOR_FIRST + shader->selector->info.num_vbos_in_user_sgprs * 4;
1458       } else {
1459          vgpr = 8 + GFX9_GS_NUM_USER_SGPR + 1;
1460       }
1461    } else {
1462       vgpr = 8 + GFX9_GS_NUM_USER_SGPR;
1463    }
1464 
1465    val = LLVMBuildLoad2(builder, ctx->ac.i32, new_vgpr0, "");
1466    ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++, "");
1467    vgpr++; /* gs_vtx_offset[1] = offsets of vertices 2-3  */
1468 
1469    ret = si_insert_input_ret_float(ctx, ret, ctx->args.gs_prim_id, vgpr++);
1470    ret = si_insert_input_ret_float(ctx, ret, ctx->args.gs_invocation_id, vgpr++);
1471    vgpr++; /* gs_vtx_offset[2] = offsets of vertices 4-5 */
1472 
1473    /* Set the input VPGRs to the corresponding LDS addresses where the VGPR values are
1474     * stored. The VS prolog will load them.
1475     */
1476    if (ctx->stage == MESA_SHADER_VERTEX) {
1477       val = LLVMBuildPtrToInt(builder, input_vgpr_addresses[0], ctx->ac.i32, "");
1478       ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++,
1479                                  ""); /* VGPR5 - VertexID */
1480       vgpr += 2;
1481       if (uses_instance_id) {
1482          val = LLVMBuildPtrToInt(builder, input_vgpr_addresses[1], ctx->ac.i32, "");
1483          ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++,
1484                                     ""); /* VGPR8 - InstanceID */
1485       } else {
1486          vgpr++;
1487       }
1488    } else {
1489       assert(ctx->stage == MESA_SHADER_TESS_EVAL);
1490       unsigned num_vgprs = uses_tes_prim_id ? 4 : 3;
1491       for (unsigned i = 0; i < num_vgprs; i++) {
1492          val = LLVMBuildPtrToInt(builder, input_vgpr_addresses[i], ctx->ac.i32, "");
1493          ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++, "");
1494       }
1495       if (num_vgprs == 3)
1496          vgpr++;
1497    }
1498 
1499    /* These two also use LDS. */
1500    if (gfx10_ngg_writes_user_edgeflags(shader) ||
1501        (ctx->stage == MESA_SHADER_VERTEX && shader->key.ge.mono.u.vs_export_prim_id)) {
1502       ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1503       ac_build_s_barrier(&ctx->ac, ctx->stage);
1504    }
1505 
1506    ctx->return_value = ret;
1507 }
1508 
1509 /**
1510  * Emit the end of an API VS or TES shader compiled as ESGS shader.
      *
      * Stores streamout outputs and user edge flags of ES threads into LDS, runs
      * streamout, gathers edge flags and (VS only) primitive IDs for GS threads,
      * updates the streamout query buffer, builds the primitive export unless it
      * was already exported early, and finally exports the per-vertex outputs.
1511  */
gfx10_ngg_build_end(struct si_shader_context * ctx)1512 void gfx10_ngg_build_end(struct si_shader_context *ctx)
1513 {
1514    struct si_shader_selector *sel = ctx->shader->selector;
1515    struct si_shader_info *info = &sel->info;
1516    struct si_shader_output_values outputs[PIPE_MAX_SHADER_OUTPUTS];
1517    LLVMBuilderRef builder = ctx->ac.builder;
1518    LLVMValueRef *addrs = ctx->abi.outputs;
1519    LLVMValueRef tmp, tmp2;
1520 
1521    assert(!ctx->shader->is_gs_copy_shader);
1522    assert(info->num_outputs <= AC_LLVM_MAX_OUTPUTS);
1523 
1524    LLVMValueRef vertex_ptr = NULL;
1525 
        /* The LDS slot of this thread's vertex is only written for streamout
         * outputs and user edge flags. */
1526    if (ctx->so.num_outputs || gfx10_ngg_writes_user_edgeflags(ctx->shader))
1527       vertex_ptr = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));
1528 
1529    for (unsigned i = 0; i < info->num_outputs; i++) {
1530       outputs[i].semantic = info->output_semantic[i];
1531 
1532       for (unsigned j = 0; j < 4; j++) {
              /* NOTE(review): loop-invariant; could be assigned once outside the j loop. */
1533          outputs[i].vertex_streams = info->output_streams[i];
1534 
1535          /* TODO: we may store more outputs than streamout needs,
1536           * but streamout performance isn't that important.
1537           */
1538          if (ctx->so.num_outputs) {
1539             tmp = ac_build_gep0(&ctx->ac, vertex_ptr, LLVMConstInt(ctx->ac.i32, 4 * i + j, false));
1540             tmp2 = LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * i + j], "");
1541             LLVMTypeRef type = ac_to_integer_type(&ctx->ac, ctx->ac.f32);
1542             tmp2 = LLVMBuildBitCast(ctx->ac.builder, tmp2, type, "");
1543             LLVMBuildStore(builder, tmp2, tmp);
1544          }
1545       }
1546 
1547       /* Store the edge flag in the last element of this vertex's LDS slot (only when
              * the shader writes user edge flags; this does not depend on streamout). */
1548       if (info->output_semantic[i] == VARYING_SLOT_EDGE && gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
1549          LLVMValueRef edgeflag = LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * i], "");
1550          /* The output is a float, but the hw expects a 1-bit integer. */
1551          edgeflag = LLVMBuildFPToUI(ctx->ac.builder, edgeflag, ctx->ac.i32, "");
1552          edgeflag = ac_build_umin(&ctx->ac, edgeflag, ctx->ac.i32_1);
1553 
1554          tmp = LLVMConstInt(ctx->ac.i32, ngg_nogs_vertex_size(ctx->shader) - 1, 0);
1555          tmp = ac_build_gep0(&ctx->ac, vertex_ptr, tmp);
1556          LLVMBuildStore(builder, edgeflag, tmp);
1557       }
1558    }
1559 
        /* When nothing below uses LDS, streamout or the query buffer, the if-block
         * opened with ctx->merged_wrap_if_label (outside this function) is left open
         * here. In that case no ac_build_ifcc(6002) is emitted for the vertex exports,
         * so the unconditional ac_build_endif(6002) at the end closes that block.
         */
1560    bool unterminated_es_if_block =
1561       !ctx->so.num_outputs && !gfx10_ngg_writes_user_edgeflags(ctx->shader) &&
1562       !ctx->screen->use_ngg_streamout && /* no query buffer */
1563       (ctx->stage != MESA_SHADER_VERTEX || !ctx->shader->key.ge.mono.u.vs_export_prim_id);
1564 
1565    if (!unterminated_es_if_block)
1566       ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);
1567 
1568    LLVMValueRef is_gs_thread = si_is_gs_thread(ctx);
1569    LLVMValueRef is_es_thread = si_is_es_thread(ctx);
1570    LLVMValueRef vtxindex[3];
1571 
        /* Unpack the three vertex indices of the primitive: 9-bit fields at a 10-bit
         * stride of gs_vtx_offset[0] with culling or passthrough, otherwise 16-bit
         * halves of gs_vtx_offset[0] and gs_vtx_offset[1].
         */
1572    if (ctx->shader->key.ge.opt.ngg_culling || gfx10_is_ngg_passthrough(ctx->shader)) {
1573       for (unsigned i = 0; i < 3; ++i)
1574          vtxindex[i] = si_unpack_param(ctx, ctx->args.gs_vtx_offset[0], 10 * i, 9);
1575    } else {
1576       for (unsigned i = 0; i < 3; ++i)
1577          vtxindex[i] = si_unpack_param(ctx, ctx->args.gs_vtx_offset[i / 2], (i & 1) * 16, 16);
1578    }
1579 
1580    /* Determine the number of vertices per primitive. */
1581    unsigned num_vertices;
1582    LLVMValueRef num_vertices_val = ngg_get_vertices_per_prim(ctx, &num_vertices);
1583 
1584    /* Streamout */
        /* emitted_prims (stream 0) feeds the query-buffer update below. */
1585    LLVMValueRef emitted_prims = NULL;
1586 
1587    if (ctx->so.num_outputs) {
1588       assert(!unterminated_es_if_block);
1589 
1590       struct ngg_streamout nggso = {};
1591       nggso.num_vertices = num_vertices_val;
1592       nggso.prim_enable[0] = is_gs_thread;
1593 
1594       for (unsigned i = 0; i < num_vertices; ++i)
1595          nggso.vertices[i] = ngg_nogs_vertex_ptr(ctx, vtxindex[i]);
1596 
1597       build_streamout(ctx, &nggso);
1598       emitted_prims = nggso.emit[0];
1599    }
1600 
        /* Stays zero-initialized (NULL entries) unless user edge flags are written. */
1601    LLVMValueRef user_edgeflags[3] = {};
1602 
1603    if (gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
1604       assert(!unterminated_es_if_block);
1605 
1606       /* Streamout already inserted the barrier, so don't insert it again. */
1607       if (!ctx->so.num_outputs) {
1608          ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1609          ac_build_s_barrier(&ctx->ac, ctx->stage);
1610       }
1611 
1612       ac_build_ifcc(&ctx->ac, is_gs_thread, 5400);
1613       /* Load edge flags from ES threads and store them into VGPRs in GS threads. */
           /* Each edge flag is read from the last element of the vertex's LDS slot. */
1614       for (unsigned i = 0; i < num_vertices; i++) {
1615          tmp = ngg_nogs_vertex_ptr(ctx, vtxindex[i]);
1616          tmp2 = LLVMConstInt(ctx->ac.i32, ngg_nogs_vertex_size(ctx->shader) - 1, 0);
1617          tmp = ac_build_gep0(&ctx->ac, tmp, tmp2);
1618          tmp = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
1619          tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
1620 
1621          user_edgeflags[i] = ac_build_alloca_init(&ctx->ac, tmp, "");
1622       }
1623       ac_build_endif(&ctx->ac, 5400);
1624    }
1625 
1626    /* Copy Primitive IDs from GS threads to the LDS address corresponding
1627     * to the ES thread of the provoking vertex.
1628     */
1629    if (ctx->stage == MESA_SHADER_VERTEX && ctx->shader->key.ge.mono.u.vs_export_prim_id) {
1630       assert(!unterminated_es_if_block);
1631 
1632       /* Streamout and edge flags use LDS. Make it idle, so that we can reuse it. */
1633       if (ctx->so.num_outputs || gfx10_ngg_writes_user_edgeflags(ctx->shader)) {
1634          ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1635          ac_build_s_barrier(&ctx->ac, ctx->stage);
1636       }
1637 
1638       ac_build_ifcc(&ctx->ac, is_gs_thread, 5400);
1639       /* Extract the PROVOKING_VTX_INDEX field. */
1640       LLVMValueRef provoking_vtx_in_prim = GET_FIELD(ctx, GS_STATE_PROVOKING_VTX_INDEX);
1641 
1642       /* provoking_vtx_index = vtxindex[provoking_vtx_in_prim]; */
1643       LLVMValueRef indices = ac_build_gather_values(&ctx->ac, vtxindex, 3);
1644       LLVMValueRef provoking_vtx_index =
1645          LLVMBuildExtractElement(builder, indices, provoking_vtx_in_prim, "");
1646       LLVMValueRef vertex_ptr = ngg_nogs_vertex_ptr(ctx, provoking_vtx_index);
1647 
           /* Written to element 0 of the provoking vertex's LDS slot; ES threads read
            * it back from there before the per-vertex export below. */
1648       LLVMBuildStore(builder, ac_get_arg(&ctx->ac, ctx->args.gs_prim_id),
1649                      ac_build_gep0(&ctx->ac, vertex_ptr, ctx->ac.i32_0));
1650       ac_build_endif(&ctx->ac, 5400);
1651    }
1652 
1653    /* Update query buffer */
1654    if (ctx->screen->use_ngg_streamout && !info->base.vs.blit_sgprs_amd) {
1655       assert(!unterminated_es_if_block);
1656 
           /* Only wave 0 takes part. Lane 0 atomically adds the primitive count at byte
            * offset 16. With streamout, lane 1 also adds emitted_prims at byte offset 24
            * (presumably stream[0].emitted_primitives; verify against the query layout).
            */
1657       tmp = GET_FIELD(ctx, GS_STATE_STREAMOUT_QUERY_ENABLED);
1658       tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
1659       ac_build_ifcc(&ctx->ac, tmp, 5029); /* if (STREAMOUT_QUERY_ENABLED) */
1660       tmp = LLVMBuildICmp(builder, LLVMIntEQ, get_wave_id_in_tg(ctx), ctx->ac.i32_0, "");
1661       ac_build_ifcc(&ctx->ac, tmp, 5030);
1662       tmp = LLVMBuildICmp(builder, LLVMIntULE, ac_get_thread_id(&ctx->ac),
1663                           ctx->so.num_outputs ? ctx->ac.i32_1 : ctx->ac.i32_0, "");
1664       ac_build_ifcc(&ctx->ac, tmp, 5031);
1665       {
1666          LLVMValueRef args[] = {
1667             ngg_get_prim_cnt(ctx),
1668             ngg_get_query_buf(ctx),
1669             LLVMConstInt(ctx->ac.i32, 16, false), /* offset of stream[0].generated_primitives */
1670             ctx->ac.i32_0,                        /* soffset */
1671             ctx->ac.i32_0,                        /* cachepolicy */
1672          };
1673 
1674          if (ctx->so.num_outputs) {
1675             args[0] = ac_build_writelane(&ctx->ac, args[0], emitted_prims, ctx->ac.i32_1);
1676             args[2] = ac_build_writelane(&ctx->ac, args[2], LLVMConstInt(ctx->ac.i32, 24, false),
1677                                          ctx->ac.i32_1);
1678          }
1679 
1680          /* TODO: should this be 64-bit atomics? */
1681          ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.raw.buffer.atomic.add.i32", ctx->ac.i32, args, 5,
1682                             0);
1683       }
1684       ac_build_endif(&ctx->ac, 5031);
1685       ac_build_endif(&ctx->ac, 5030);
1686       ac_build_endif(&ctx->ac, 5029);
1687    }
1688 
1689    /* Build the primitive export. */
1690    if (!gfx10_ngg_export_prim_early(ctx->shader)) {
1691       assert(!unterminated_es_if_block);
1692       gfx10_ngg_build_export_prim(ctx, user_edgeflags, NULL);
1693    }
1694 
1695    /* Export per-vertex data (positions and parameters). */
        /* If unterminated_es_if_block is set, no new if-block is opened here; the
         * endif(6002) below then closes the block still open from
         * ctx->merged_wrap_if_label.
         */
1696    if (!unterminated_es_if_block)
1697       ac_build_ifcc(&ctx->ac, is_es_thread, 6002);
1698    {
1699       unsigned i;
1700 
1701       /* Unconditionally (re-)load the values for proper SSA form. */
1702       for (i = 0; i < info->num_outputs; i++) {
1703          /* If the NGG cull shader part computed the position, don't
1704           * use the position from the current shader part. Instead,
1705           * load it from LDS.
1706           */
1707          if (info->output_semantic[i] == VARYING_SLOT_POS &&
1708              ctx->shader->key.ge.opt.ngg_culling) {
1709             vertex_ptr = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));
1710 
1711             for (unsigned j = 0; j < 4; j++) {
1712                tmp = LLVMConstInt(ctx->ac.i32, lds_pos_x + j, 0);
1713                tmp = ac_build_gep0(&ctx->ac, vertex_ptr, tmp);
1714                tmp = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
1715                outputs[i].values[j] = LLVMBuildBitCast(ctx->ac.builder, tmp,
1716                                                        ac_to_float_type(&ctx->ac, ctx->ac.i32), "");
1717             }
1718          } else {
1719             for (unsigned j = 0; j < 4; j++) {
1720                outputs[i].values[j] = LLVMBuildLoad2(builder, ctx->ac.f32, addrs[4 * i + j], "");
1721             }
1722          }
1723       }
1724 
           /* Append PRIMITIVE_ID as one extra output after the shader's own outputs. */
1725       if (ctx->shader->key.ge.mono.u.vs_export_prim_id) {
1726          outputs[i].semantic = VARYING_SLOT_PRIMITIVE_ID;
1727          outputs[i].vertex_streams = 0;
1728 
1729          if (ctx->stage == MESA_SHADER_VERTEX) {
1730             /* Wait for LDS stores to finish. */
1731             ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1732             ac_build_s_barrier(&ctx->ac, ctx->stage);
1733 
1734             tmp = ngg_nogs_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx));
1735             tmp = ac_build_gep0(&ctx->ac, tmp, ctx->ac.i32_0);
1736             outputs[i].values[0] = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
1737          } else {
1738             assert(ctx->stage == MESA_SHADER_TESS_EVAL);
1739             outputs[i].values[0] = si_get_primitive_id(ctx, 0);
1740          }
1741 
1742          outputs[i].values[0] = LLVMBuildBitCast(ctx->ac.builder, outputs[i].values[0], ctx->ac.f32, "");
1743          for (unsigned j = 1; j < 4; j++)
1744             outputs[i].values[j] = LLVMGetUndef(ctx->ac.f32);
1745          i++;
1746       }
1747 
1748       si_llvm_build_vs_exports(ctx, NULL, outputs, i);
1749    }
1750    ac_build_endif(&ctx->ac, 6002);
1751 }
1752 
ngg_gs_get_vertex_storage(struct si_shader_context * ctx)1753 static LLVMValueRef ngg_gs_get_vertex_storage(struct si_shader_context *ctx)
1754 {
1755    const struct si_shader_selector *sel = ctx->shader->selector;
1756    const struct si_shader_info *info = &sel->info;
1757 
1758    LLVMTypeRef elements[2] = {
1759       LLVMArrayType(ctx->ac.i32, 4 * info->num_outputs),
1760       LLVMArrayType(ctx->ac.i8, 4),
1761    };
1762    LLVMTypeRef type = LLVMStructTypeInContext(ctx->ac.context, elements, 2, false);
1763    type = LLVMPointerType(LLVMArrayType(type, 0), AC_ADDR_SPACE_LDS);
1764    return LLVMBuildBitCast(ctx->ac.builder, ctx->gs_ngg_emit, type, "");
1765 }
1766 
1767 /**
1768  * Return a pointer to the LDS storage reserved for the N'th vertex, where N
1769  * is in emit order; that is:
1770  * - at the shader end, N is the threadidx (relative to the entire threadgroup)
1771  * - during vertex emit, i.e. while the API GS shader invocation is running,
1772  *   N = threadidx * gs.vertices_out + emitidx
1773  *
1774  * Goals of the LDS memory layout:
1775  * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
1776  *    in uniform control flow
1777  * 2. Eliminate bank conflicts on read for export if, additionally, there is no
1778  *    culling
1779  * 3. Agnostic to the number of waves (since we don't know it before compiling)
1780  * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
1781  * 5. Avoid wasting memory.
1782  *
1783  * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
1784  * layout, elimination of bank conflicts requires that each vertex occupy an
1785  * odd number of dwords. We use the additional dword to store the output stream
1786  * index as well as a flag to indicate whether this vertex ends a primitive
1787  * for rasterization.
1788  *
1789  * Swizzling is required to satisfy points 1 and 2 simultaneously.
1790  *
1791  * Vertices are stored in export order (gsthread * gs.vertices_out + emitidx).
1792  * Indices are swizzled in groups of 32, which ensures point 1 without
1793  * disturbing point 2.
1794  *
1795  * \return an LDS pointer to type {[N x i32], [4 x i8]}
1796  */
ngg_gs_vertex_ptr(struct si_shader_context * ctx,LLVMValueRef vertexidx)1797 static LLVMValueRef ngg_gs_vertex_ptr(struct si_shader_context *ctx, LLVMValueRef vertexidx)
1798 {
1799    struct si_shader_selector *sel = ctx->shader->selector;
1800    LLVMBuilderRef builder = ctx->ac.builder;
1801    LLVMValueRef storage = ngg_gs_get_vertex_storage(ctx);
1802 
1803    /* gs.vertices_out = 2^(write_stride_2exp) * some odd number */
1804    unsigned write_stride_2exp = ffs(sel->info.base.gs.vertices_out) - 1;
1805    if (write_stride_2exp) {
1806       LLVMValueRef row = LLVMBuildLShr(builder, vertexidx, LLVMConstInt(ctx->ac.i32, 5, false), "");
1807       LLVMValueRef swizzle = LLVMBuildAnd(
1808          builder, row, LLVMConstInt(ctx->ac.i32, (1u << write_stride_2exp) - 1, false), "");
1809       vertexidx = LLVMBuildXor(builder, vertexidx, swizzle, "");
1810    }
1811 
1812    return ac_build_gep0(&ctx->ac, storage, vertexidx);
1813 }
1814 
/* Return the LDS record of the emitidx'th vertex emitted by GS thread
 * gsthread, i.e. of vertex number gsthread * gs.vertices_out + emitidx.
 */
static LLVMValueRef ngg_gs_emit_vertex_ptr(struct si_shader_context *ctx, LLVMValueRef gsthread,
                                           LLVMValueRef emitidx)
{
   LLVMBuilderRef b = ctx->ac.builder;
   LLVMValueRef per_thread_verts =
      LLVMConstInt(ctx->ac.i32, ctx->shader->selector->info.base.gs.vertices_out, false);
   LLVMValueRef first_of_thread = LLVMBuildMul(b, per_thread_verts, gsthread, "");

   return ngg_gs_vertex_ptr(ctx, LLVMBuildAdd(b, first_of_thread, emitidx, ""));
}
1827 
/* Return &vertexptr[0].<field 0>[out_idx]: a pointer to output dword out_idx
 * of a GS vertex record (field 0 of the record is the output dword array).
 */
static LLVMValueRef ngg_gs_get_emit_output_ptr(struct si_shader_context *ctx,
                                               LLVMValueRef vertexptr, unsigned out_idx)
{
   LLVMValueRef dword = LLVMConstInt(ctx->ac.i32, out_idx, false);
   LLVMValueRef path[] = {ctx->ac.i32_0 /* the pointed-to record */,
                          ctx->ac.i32_0 /* field 0: output dwords */,
                          dword};

   return LLVMBuildGEP(ctx->ac.builder, vertexptr, path, 3, "");
}
1838 
/* Return &vertexptr[0].<field 1>[stream]: a pointer to the primitive-flag byte
 * of the given output stream in a GS vertex record (field 1 is the [4 x i8]
 * flag array).
 */
static LLVMValueRef ngg_gs_get_emit_primflag_ptr(struct si_shader_context *ctx,
                                                 LLVMValueRef vertexptr, unsigned stream)
{
   LLVMValueRef flag_byte = LLVMConstInt(ctx->ac.i32, stream, false);
   LLVMValueRef path[] = {ctx->ac.i32_0 /* the pointed-to record */,
                          ctx->ac.i32_1 /* field 1: per-stream flags */,
                          flag_byte};

   return LLVMBuildGEP(ctx->ac.builder, vertexptr, path, 3, "");
}
1849 
/**
 * Emit one vertex of output stream `stream` from an NGG GS invocation.
 *
 * Vertices past gs.vertices_out are dropped, and the per-stream vertex counter
 * is not advanced for them. For an accepted vertex, every output component
 * that is written and routed to `stream` is stored into the vertex's LDS
 * record, followed by the stream's primitive-flag byte:
 *   bit 0: this vertex completes a primitive
 *   bit 1: the completed primitive is odd (stream 0 triangle strips only),
 *          so its vertex order needs to be swapped
 * The per-stream generated-primitive counter is incremented when bit 0 is set.
 */
void gfx10_ngg_gs_emit_vertex(struct si_shader_context *ctx, unsigned stream, LLVMValueRef *addrs)
{
   const struct si_shader_info *info = &ctx->shader->selector->info;
   const unsigned max_vertices = info->base.gs.vertices_out;
   const unsigned prim_size = u_vertices_per_prim(info->base.gs.output_primitive);
   LLVMBuilderRef b = ctx->ac.builder;

   LLVMValueRef emit_idx = LLVMBuildLoad2(b, ctx->ac.i32, ctx->gs_next_vertex[stream], "");

   /* Excess vertex emissions must not have any effect. */
   LLVMValueRef in_range =
      LLVMBuildICmp(b, LLVMIntULT, emit_idx, LLVMConstInt(ctx->ac.i32, max_vertices, false), "");

   /* Only advance the counter for vertices that are actually written. */
   LLVMValueRef next_idx = LLVMBuildAdd(b, emit_idx, ctx->ac.i32_1, "");
   LLVMBuildStore(b, LLVMBuildSelect(b, in_range, next_idx, emit_idx, ""),
                  ctx->gs_next_vertex[stream]);

   ac_build_ifcc(&ctx->ac, in_range, 9001);

   LLVMValueRef record = ngg_gs_emit_vertex_ptr(ctx, gfx10_get_thread_id_in_tg(ctx), emit_idx);

   /* Output component (slot, comp) occupies record dword 4 * slot + comp. */
   for (unsigned slot = 0; slot < info->num_outputs; slot++) {
      for (unsigned comp = 0; comp < 4; comp++) {
         const bool written = info->output_usagemask[slot] & (1 << comp);
         const unsigned comp_stream = (info->output_streams[slot] >> (2 * comp)) & 3;

         if (!written || comp_stream != stream)
            continue;

         const unsigned dw = 4 * slot + comp;
         LLVMValueRef bits = LLVMBuildLoad2(b, ctx->ac.f32, addrs[dw], "");
         bits = LLVMBuildBitCast(b, bits, ac_to_integer_type(&ctx->ac, ctx->ac.f32), "");
         LLVMBuildStore(b, bits, ngg_gs_get_emit_output_ptr(ctx, record, dw));
      }
   }
   /* Each record holds 4 dwords (16 bytes) per output slot. */
   assert(info->num_outputs * 16u == info->gsvs_vertex_size);

   /* Does this vertex complete a primitive? */
   LLVMValueRef prim_verts = LLVMBuildLoad2(b, ctx->ac.i32, ctx->gs_curprim_verts[stream], "");
   LLVMValueRef ends_prim = LLVMBuildICmp(b, LLVMIntUGE, prim_verts,
                                          LLVMConstInt(ctx->ac.i32, prim_size - 1, false), "");

   /* Triangle strips alternate winding: remember which primitives are odd so
    * that their vertex indices can be swapped later.
    */
   LLVMValueRef odd_prim = ctx->ac.i1false;
   if (stream == 0 && prim_size == 3) {
      LLVMValueRef low_bit = LLVMBuildAnd(b, prim_verts, ctx->ac.i32_1, "");
      odd_prim = LLVMBuildICmp(b, LLVMIntEQ, low_bit, ctx->ac.i32_1, "");
   }

   LLVMBuildStore(b, LLVMBuildAdd(b, prim_verts, ctx->ac.i32_1, ""), ctx->gs_curprim_verts[stream]);

   LLVMValueRef flag_bit0 = LLVMBuildZExt(b, ends_prim, ctx->ac.i8, "");
   LLVMValueRef flag_bit1 =
      LLVMBuildShl(b, LLVMBuildZExt(b, odd_prim, ctx->ac.i8, ""), ctx->ac.i8_1, "");
   LLVMValueRef flags = LLVMBuildOr(b, flag_bit0, flag_bit1, "");
   LLVMBuildStore(b, flags, ngg_gs_get_emit_primflag_ptr(ctx, record, stream));

   LLVMValueRef prim_count = LLVMBuildLoad2(b, ctx->ac.i32, ctx->gs_generated_prims[stream], "");
   prim_count = LLVMBuildAdd(b, prim_count, LLVMBuildZExt(b, ends_prim, ctx->ac.i32, ""), "");
   LLVMBuildStore(b, prim_count, ctx->gs_generated_prims[stream]);

   ac_build_endif(&ctx->ac, 9001);
}
1923 
gfx10_ngg_gs_emit_begin(struct si_shader_context * ctx)1924 void gfx10_ngg_gs_emit_begin(struct si_shader_context *ctx)
1925 {
1926    /* Zero out the part of LDS scratch that is used to accumulate the
1927     * per-stream generated primitive count.
1928     */
1929    LLVMBuilderRef builder = ctx->ac.builder;
1930    LLVMValueRef scratchptr = ctx->gs_ngg_scratch;
1931    LLVMValueRef tid = gfx10_get_thread_id_in_tg(ctx);
1932    LLVMValueRef tmp;
1933 
1934    tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, LLVMConstInt(ctx->ac.i32, 4, false), "");
1935    ac_build_ifcc(&ctx->ac, tmp, 5090);
1936    {
1937       LLVMValueRef ptr = ac_build_gep0(&ctx->ac, scratchptr, tid);
1938       LLVMBuildStore(builder, ctx->ac.i32_0, ptr);
1939    }
1940    ac_build_endif(&ctx->ac, 5090);
1941 
1942    if (ctx->screen->info.gfx_level < GFX11) {
1943       tmp = si_is_gs_thread(ctx);
1944       ac_build_ifcc(&ctx->ac, tmp, 15090);
1945          {
1946             tmp = GET_FIELD(ctx, GS_STATE_PIPELINE_STATS_EMU);
1947             tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
1948             ac_build_ifcc(&ctx->ac, tmp, 5109); /* if (GS_PIPELINE_STATS_EMU) */
1949             LLVMValueRef args[] = {
1950                ctx->ac.i32_1,
1951                ngg_get_emulated_counters_buf(ctx),
1952                LLVMConstInt(ctx->ac.i32,
1953                             si_query_pipestat_end_dw_offset(ctx->screen, PIPE_STAT_QUERY_GS_INVOCATIONS) * 4,
1954                             false),
1955                ctx->ac.i32_0,                            /* soffset */
1956                ctx->ac.i32_0,                            /* cachepolicy */
1957             };
1958 
1959             ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.raw.buffer.atomic.add.i32", ctx->ac.i32, args, 5, 0);
1960             ac_build_endif(&ctx->ac, 5109);
1961          }
1962       ac_build_endif(&ctx->ac, 15090);
1963    }
1964 
1965    ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
1966    ac_build_s_barrier(&ctx->ac, ctx->stage);
1967 }
1968 
gfx10_ngg_gs_build_end(struct si_shader_context * ctx)1969 void gfx10_ngg_gs_build_end(struct si_shader_context *ctx)
1970 {
1971    const struct si_shader_selector *sel = ctx->shader->selector;
1972    const struct si_shader_info *info = &sel->info;
1973    const unsigned verts_per_prim = u_vertices_per_prim(sel->info.base.gs.output_primitive);
1974    LLVMBuilderRef builder = ctx->ac.builder;
1975    LLVMValueRef i8_0 = LLVMConstInt(ctx->ac.i8, 0, false);
1976    LLVMValueRef tmp, tmp2;
1977 
1978    /* Zero out remaining (non-emitted) primitive flags.
1979     *
1980     * Note: Alternatively, we could pass the relevant gs_next_vertex to
1981     *       the emit threads via LDS. This is likely worse in the expected
1982     *       typical case where each GS thread emits the full set of
1983     *       vertices.
1984     */
1985    for (unsigned stream = 0; stream < 4; ++stream) {
1986       if (!info->num_stream_output_components[stream])
1987          continue;
1988 
1989       const LLVMValueRef gsthread = gfx10_get_thread_id_in_tg(ctx);
1990 
1991       ac_build_bgnloop(&ctx->ac, 5100);
1992 
1993       const LLVMValueRef vertexidx = LLVMBuildLoad2(builder, ctx->ac.i32, ctx->gs_next_vertex[stream], "");
1994       tmp = LLVMBuildICmp(builder, LLVMIntUGE, vertexidx,
1995                           LLVMConstInt(ctx->ac.i32, sel->info.base.gs.vertices_out, false), "");
1996       ac_build_ifcc(&ctx->ac, tmp, 5101);
1997       ac_build_break(&ctx->ac);
1998       ac_build_endif(&ctx->ac, 5101);
1999 
2000       tmp = LLVMBuildAdd(builder, vertexidx, ctx->ac.i32_1, "");
2001       LLVMBuildStore(builder, tmp, ctx->gs_next_vertex[stream]);
2002 
2003       tmp = ngg_gs_emit_vertex_ptr(ctx, gsthread, vertexidx);
2004       LLVMBuildStore(builder, i8_0, ngg_gs_get_emit_primflag_ptr(ctx, tmp, stream));
2005 
2006       ac_build_endloop(&ctx->ac, 5100);
2007    }
2008 
2009    /* Accumulate generated primitives counts across the entire threadgroup. */
2010    for (unsigned stream = 0; stream < 4; ++stream) {
2011       if (!info->num_stream_output_components[stream])
2012          continue;
2013 
2014       LLVMValueRef numprims = LLVMBuildLoad2(builder, ctx->ac.i32, ctx->gs_generated_prims[stream], "");
2015       numprims = ac_build_reduce(&ctx->ac, numprims, nir_op_iadd, ctx->ac.wave_size);
2016 
2017       tmp = LLVMBuildICmp(builder, LLVMIntEQ, ac_get_thread_id(&ctx->ac), ctx->ac.i32_0, "");
2018       ac_build_ifcc(&ctx->ac, tmp, 5105);
2019       {
2020          LLVMBuildAtomicRMW(
2021             builder, LLVMAtomicRMWBinOpAdd,
2022             ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, LLVMConstInt(ctx->ac.i32, stream, false)),
2023             numprims, LLVMAtomicOrderingMonotonic, false);
2024       }
2025       ac_build_endif(&ctx->ac, 5105);
2026    }
2027 
2028    ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);
2029 
2030    ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
2031    ac_build_s_barrier(&ctx->ac, ctx->stage);
2032 
2033    const LLVMValueRef tid = gfx10_get_thread_id_in_tg(ctx);
2034    LLVMValueRef num_emit_threads = ngg_get_prim_cnt(ctx);
2035 
2036    /* Streamout */
2037    if (ctx->so.num_outputs) {
2038       struct ngg_streamout nggso = {};
2039 
2040       nggso.num_vertices = LLVMConstInt(ctx->ac.i32, verts_per_prim, false);
2041 
2042       LLVMValueRef vertexptr = ngg_gs_vertex_ptr(ctx, tid);
2043       for (unsigned stream = 0; stream < 4; ++stream) {
2044          if (!info->num_stream_output_components[stream])
2045             continue;
2046 
2047          tmp = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, vertexptr, stream), "");
2048          tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
2049          tmp2 = LLVMBuildICmp(builder, LLVMIntULT, tid, num_emit_threads, "");
2050          nggso.prim_enable[stream] = LLVMBuildAnd(builder, tmp, tmp2, "");
2051       }
2052 
2053       for (unsigned i = 0; i < verts_per_prim; ++i) {
2054          tmp = LLVMBuildSub(builder, tid, LLVMConstInt(ctx->ac.i32, verts_per_prim - i - 1, false),
2055                             "");
2056          tmp = ngg_gs_vertex_ptr(ctx, tmp);
2057          nggso.vertices[i] = ac_build_gep0(&ctx->ac, tmp, ctx->ac.i32_0);
2058       }
2059 
2060       build_streamout(ctx, &nggso);
2061    }
2062 
2063    /* Write shader query data. */
2064    if (ctx->screen->use_ngg_streamout) {
2065       tmp = GET_FIELD(ctx, GS_STATE_STREAMOUT_QUERY_ENABLED);
2066       tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
2067       ac_build_ifcc(&ctx->ac, tmp, 5109); /* if (STREAMOUT_QUERY_ENABLED) */
2068       unsigned num_query_comps = ctx->so.num_outputs ? 8 : 4;
2069       tmp = LLVMBuildICmp(builder, LLVMIntULT, tid,
2070                           LLVMConstInt(ctx->ac.i32, num_query_comps, false), "");
2071       ac_build_ifcc(&ctx->ac, tmp, 5110);
2072       {
2073          LLVMValueRef offset;
2074          tmp = tid;
2075          if (ctx->so.num_outputs)
2076             tmp = LLVMBuildAnd(builder, tmp, LLVMConstInt(ctx->ac.i32, 3, false), "");
2077          offset = LLVMBuildNUWMul(builder, tmp, LLVMConstInt(ctx->ac.i32, 32, false), "");
2078          if (ctx->so.num_outputs) {
2079             tmp = LLVMBuildLShr(builder, tid, LLVMConstInt(ctx->ac.i32, 2, false), "");
2080             tmp = LLVMBuildNUWMul(builder, tmp, LLVMConstInt(ctx->ac.i32, 8, false), "");
2081             offset = LLVMBuildAdd(builder, offset, tmp, "");
2082          }
2083 
2084          tmp = LLVMBuildLoad2(builder, ctx->ac.i32, ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, tid), "");
2085          LLVMValueRef args[] = {
2086             tmp,           ngg_get_query_buf(ctx),
2087             offset,        LLVMConstInt(ctx->ac.i32, 16, false), /* soffset */
2088             ctx->ac.i32_0,                                       /* cachepolicy */
2089          };
2090          ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.raw.buffer.atomic.add.i32", ctx->ac.i32, args, 5,
2091                             0);
2092       }
2093       ac_build_endif(&ctx->ac, 5110);
2094       ac_build_endif(&ctx->ac, 5109);
2095    }
2096 
2097    /* Cull primitives. */
2098    if (ctx->shader->key.ge.opt.ngg_culling) {
2099       assert(info->num_stream_output_components[0]);
2100 
2101       LLVMValueRef gs_vtxptr = ngg_gs_vertex_ptr(ctx, tid);
2102       LLVMValueRef live = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, gs_vtxptr, 0), "");
2103       live = LLVMBuildTrunc(builder, live, ctx->ac.i1, "");
2104       LLVMValueRef is_emit = LLVMBuildICmp(builder, LLVMIntULT, tid, num_emit_threads, "");
2105       LLVMValueRef prim_enable = LLVMBuildAnd(builder, live, is_emit, "");
2106 
2107       /* Wait for streamout to finish before we kill primitives. */
2108       if (ctx->so.num_outputs) {
2109          ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
2110          ac_build_s_barrier(&ctx->ac, ctx->stage);
2111       }
2112 
2113       ac_build_ifcc(&ctx->ac, prim_enable, 0);
2114       {
2115          LLVMValueRef vtxptr[3] = {};
2116          LLVMValueRef pos[3][4] = {};
2117 
2118          for (unsigned i = 0; i < verts_per_prim; i++) {
2119             tmp = LLVMBuildSub(builder, tid, LLVMConstInt(ctx->ac.i32, verts_per_prim - i - 1, false), "");
2120             vtxptr[i] = ac_build_gep0(&ctx->ac, ngg_gs_vertex_ptr(ctx, tmp), ctx->ac.i32_0);
2121          }
2122 
2123          for (unsigned i = 0; i < info->num_outputs; i++) {
2124             /* If the stream index is non-zero for all channels, skip the output. */
2125             if (info->output_streams[i] & 0x3 &&
2126                 (info->output_streams[i] >> 2) & 0x3 &&
2127                 (info->output_streams[i] >> 4) & 0x3 &&
2128                 (info->output_streams[i] >> 6) & 0x3)
2129                continue;
2130 
2131             switch (info->output_semantic[i]) {
2132             case VARYING_SLOT_POS:
2133                /* Load the positions from LDS. */
2134                for (unsigned vert = 0; vert < verts_per_prim; vert++) {
2135                   for (unsigned comp = 0; comp < 4; comp++) {
2136                      /* Z is not needed. */
2137                      if (comp == 2)
2138                         continue;
2139 
2140                      tmp = ac_build_gep0(&ctx->ac, vtxptr[vert],
2141                                          LLVMConstInt(ctx->ac.i32, 4 * i + comp, false));
2142                      pos[vert][comp] = LLVMBuildLoad(builder, tmp, "");
2143                      pos[vert][comp] = ac_to_float(&ctx->ac, pos[vert][comp]);
2144                   }
2145                }
2146 
2147                /* Divide XY by W. */
2148                for (unsigned vert = 0; vert < verts_per_prim; vert++) {
2149                   for (unsigned comp = 0; comp < 2; comp++)
2150                      pos[vert][comp] = ac_build_fdiv(&ctx->ac, pos[vert][comp], pos[vert][3]);
2151                }
2152                break;
2153             }
2154          }
2155 
2156          LLVMValueRef clipdist_accepted = ctx->ac.i1true; /* TODO */
2157          LLVMValueRef accepted = ac_build_alloca(&ctx->ac, ctx->ac.i32, "");
2158 
2159          cull_primitive(ctx, pos, clipdist_accepted, accepted, NULL);
2160 
2161          accepted = LLVMBuildLoad2(builder, ctx->ac.i32, accepted, "");
2162          LLVMValueRef rejected = LLVMBuildNot(builder, LLVMBuildTrunc(builder, accepted, ctx->ac.i1, ""), "");
2163 
2164          ac_build_ifcc(&ctx->ac, rejected, 0);
2165          LLVMBuildStore(builder, ctx->ac.i8_0, ngg_gs_get_emit_primflag_ptr(ctx, gs_vtxptr, 0));
2166          ac_build_endif(&ctx->ac, 0);
2167       }
2168       ac_build_endif(&ctx->ac, 0);
2169 
2170       ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
2171       ac_build_s_barrier(&ctx->ac, ctx->stage);
2172    }
2173 
2174    /* Determine vertex liveness. */
2175    LLVMValueRef vertliveptr = ac_build_alloca(&ctx->ac, ctx->ac.i1, "vertexlive");
2176 
2177    tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, num_emit_threads, "");
2178    ac_build_ifcc(&ctx->ac, tmp, 5120);
2179    {
2180       for (unsigned i = 0; i < verts_per_prim; ++i) {
2181          const LLVMValueRef primidx =
2182             LLVMBuildAdd(builder, tid, LLVMConstInt(ctx->ac.i32, i, false), "");
2183 
2184          if (i > 0) {
2185             tmp = LLVMBuildICmp(builder, LLVMIntULT, primidx, num_emit_threads, "");
2186             ac_build_ifcc(&ctx->ac, tmp, 5121 + i);
2187          }
2188 
2189          /* Load primitive liveness */
2190          tmp = ngg_gs_vertex_ptr(ctx, primidx);
2191          tmp = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, tmp, 0), "");
2192          const LLVMValueRef primlive = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
2193 
2194          tmp = LLVMBuildLoad2(builder, ctx->ac.i1, vertliveptr, "");
2195          tmp = LLVMBuildOr(builder, tmp, primlive, ""), LLVMBuildStore(builder, tmp, vertliveptr);
2196 
2197          if (i > 0)
2198             ac_build_endif(&ctx->ac, 5121 + i);
2199       }
2200    }
2201    ac_build_endif(&ctx->ac, 5120);
2202 
2203    /* Inclusive scan addition across the current wave. */
2204    LLVMValueRef vertlive = LLVMBuildLoad2(builder, ctx->ac.i1, vertliveptr, "");
2205    struct ac_wg_scan vertlive_scan = {};
2206    vertlive_scan.stage = ctx->stage;
2207    vertlive_scan.op = nir_op_iadd;
2208    vertlive_scan.enable_reduce = true;
2209    vertlive_scan.enable_exclusive = true;
2210    vertlive_scan.src = vertlive;
2211    vertlive_scan.scratch = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, ctx->ac.i32_0);
2212    vertlive_scan.waveidx = get_wave_id_in_tg(ctx);
2213    vertlive_scan.numwaves = get_tgsize(ctx);
2214    vertlive_scan.maxwaves = DIV_ROUND_UP(256, ctx->ac.wave_size);
2215 
2216    ac_build_wg_scan(&ctx->ac, &vertlive_scan);
2217 
2218    /* Skip all exports (including index exports) when possible. */
2219    LLVMValueRef have_exports =
2220       LLVMBuildICmp(builder, LLVMIntNE, vertlive_scan.result_reduce, ctx->ac.i32_0, "");
2221    num_emit_threads = LLVMBuildSelect(builder, have_exports, num_emit_threads, ctx->ac.i32_0, "");
2222 
2223    /* Allocate export space. Send this message as early as possible, to
2224     * hide the latency of the SQ <-> SPI roundtrip.
2225     */
2226    ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx), vertlive_scan.result_reduce,
2227                                  num_emit_threads);
2228 
2229    /* Setup the reverse vertex compaction permutation. We re-use stream 1
2230     * of the primitive liveness flags, relying on the fact that each
2231     * threadgroup can have at most 256 threads. */
2232    ac_build_ifcc(&ctx->ac, vertlive, 5130);
2233    {
2234       tmp = ngg_gs_vertex_ptr(ctx, vertlive_scan.result_exclusive);
2235       tmp2 = LLVMBuildTrunc(builder, tid, ctx->ac.i8, "");
2236       LLVMBuildStore(builder, tmp2, ngg_gs_get_emit_primflag_ptr(ctx, tmp, 1));
2237    }
2238    ac_build_endif(&ctx->ac, 5130);
2239 
2240    ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM);
2241    ac_build_s_barrier(&ctx->ac, ctx->stage);
2242 
2243    /* Export primitive data */
2244    tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, num_emit_threads, "");
2245    ac_build_ifcc(&ctx->ac, tmp, 5140);
2246    {
2247       LLVMValueRef flags;
2248       struct ac_ngg_prim prim = {};
2249       prim.num_vertices = verts_per_prim;
2250 
2251       tmp = ngg_gs_vertex_ptr(ctx, tid);
2252       flags = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, tmp, 0), "");
2253       prim.isnull = LLVMBuildNot(builder, LLVMBuildTrunc(builder, flags, ctx->ac.i1, ""), "");
2254       prim.edgeflags = ctx->ac.i32_0;
2255 
2256       for (unsigned i = 0; i < verts_per_prim; ++i) {
2257          prim.index[i] = LLVMBuildSub(builder, vertlive_scan.result_exclusive,
2258                                       LLVMConstInt(ctx->ac.i32, verts_per_prim - i - 1, false), "");
2259       }
2260 
2261       /* Geometry shaders output triangle strips, but NGG expects triangles. */
2262       if (verts_per_prim == 3) {
2263          LLVMValueRef is_odd = LLVMBuildLShr(builder, flags, ctx->ac.i8_1, "");
2264          is_odd = LLVMBuildTrunc(builder, is_odd, ctx->ac.i1, "");
2265          LLVMValueRef flatshade_first = LLVMBuildICmp(
2266             builder, LLVMIntEQ, GET_FIELD(ctx, GS_STATE_PROVOKING_VTX_INDEX), ctx->ac.i32_0, "");
2267 
2268          ac_build_triangle_strip_indices_to_triangle(&ctx->ac, is_odd, flatshade_first, prim.index);
2269       }
2270 
2271       ac_build_export_prim(&ctx->ac, &prim);
2272 
2273       if (ctx->screen->info.gfx_level < GFX11) {
2274          tmp = GET_FIELD(ctx, GS_STATE_PIPELINE_STATS_EMU);
2275          tmp = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
2276          ac_build_ifcc(&ctx->ac, tmp, 5229); /* if (GS_PIPELINE_STATS_EMU) */
2277          ac_build_ifcc(&ctx->ac, LLVMBuildNot(builder, prim.isnull, ""), 5237);
2278          {
2279             LLVMValueRef args[] = {
2280                ctx->ac.i32_1,
2281                ngg_get_emulated_counters_buf(ctx),
2282                LLVMConstInt(ctx->ac.i32,
2283                             si_query_pipestat_end_dw_offset(ctx->screen, PIPE_STAT_QUERY_GS_PRIMITIVES) * 4,
2284                             false),
2285                ctx->ac.i32_0,                            /* soffset */
2286                ctx->ac.i32_0,                            /* cachepolicy */
2287             };
2288 
2289             ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.raw.buffer.atomic.add.i32", ctx->ac.i32, args, 5, 0);
2290          }
2291          ac_build_endif(&ctx->ac, 5237);
2292          ac_build_endif(&ctx->ac, 5229);
2293       }
2294    }
2295    ac_build_endif(&ctx->ac, 5140);
2296 
2297    /* Export position and parameter data */
2298    LLVMValueRef num_export_threads = vertlive_scan.result_reduce;
2299    tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, num_export_threads, "");
2300    ac_build_ifcc(&ctx->ac, tmp, 5145);
2301    {
2302       struct si_shader_output_values outputs[PIPE_MAX_SHADER_OUTPUTS];
2303 
2304       tmp = ngg_gs_vertex_ptr(ctx, tid);
2305       tmp = LLVMBuildLoad2(builder, ctx->ac.i8, ngg_gs_get_emit_primflag_ptr(ctx, tmp, 1), "");
2306       tmp = LLVMBuildZExt(builder, tmp, ctx->ac.i32, "");
2307       const LLVMValueRef vertexptr = ngg_gs_vertex_ptr(ctx, tmp);
2308 
2309       unsigned out_idx = 0;
2310       for (unsigned i = 0; i < info->num_outputs; i++) {
2311          outputs[i].semantic = info->output_semantic[i];
2312 
2313          for (unsigned j = 0; j < 4; j++, out_idx++) {
2314             tmp = ngg_gs_get_emit_output_ptr(ctx, vertexptr, out_idx);
2315             tmp = LLVMBuildLoad2(builder, ctx->ac.i32, tmp, "");
2316             assert(LLVMGetTypeKind(LLVMTypeOf(tmp)) != LLVMPointerTypeKind);
2317             outputs[i].values[j] = ac_to_float(&ctx->ac, tmp);
2318             outputs[i].vertex_streams = info->output_streams[i];
2319          }
2320       }
2321 
2322       si_llvm_build_vs_exports(ctx, num_export_threads, outputs, info->num_outputs);
2323    }
2324    ac_build_endif(&ctx->ac, 5145);
2325 }
2326 
/**
 * Clamp the per-subgroup GS primitive budget to what max_esverts ES vertices
 * can feed.
 *
 * The bound is 1 + (max_esverts - min_verts_per_prim), halved (rounding down)
 * when adjacency primitives are used. In other words, each primitive after
 * the first is assumed to need at least one (two with adjacency) new vertex.
 *
 * @param max_gsprims        in/out: primitive budget, only ever lowered
 * @param max_esverts        ES vertex budget of the subgroup
 * @param min_verts_per_prim vertices consumed by the first primitive
 * @param use_adjacency      whether the input primitive type has adjacency
 *
 * If max_esverts < min_verts_per_prim, the reuse term is saturated to 0
 * instead of wrapping around in unsigned arithmetic. Wrapping would silently
 * disable the clamp. With saturation the budget is limited to one primitive.
 */
static void clamp_gsprims_to_esverts(unsigned *max_gsprims, unsigned max_esverts,
                                     unsigned min_verts_per_prim, bool use_adjacency)
{
   unsigned max_reuse = max_esverts > min_verts_per_prim ? max_esverts - min_verts_per_prim : 0;
   if (use_adjacency)
      max_reuse /= 2;

   const unsigned limit = 1 + max_reuse;
   if (*max_gsprims > limit)
      *max_gsprims = limit;
}
2335 
gfx10_ngg_get_scratch_dw_size(struct si_shader * shader)2336 unsigned gfx10_ngg_get_scratch_dw_size(struct si_shader *shader)
2337 {
2338    const struct si_shader_selector *sel = shader->selector;
2339 
2340    if (sel->stage == MESA_SHADER_GEOMETRY && si_shader_uses_streamout(shader))
2341       return 44;
2342 
2343    return 8;
2344 }
2345 
2346 /**
2347  * Determine subgroup information like maximum number of vertices and prims.
2348  *
2349  * This happens before the shader is uploaded, since LDS relocations during
2350  * upload depend on the subgroup size.
2351  */
bool gfx10_ngg_calculate_subgroup_info(struct si_shader *shader)
{
   /* Computes max ES vertices / GS primitives per subgroup under the LDS
    * budget and stores them in shader->ngg and shader->gs_info.
    *
    * Returns false when the final limits violate the same constraints the
    * asserts below check (enough ES vertices for one primitive, at least one
    * primitive, at most 256 output vertices, and the HW minimum esverts).
    *
    * es_sel is the previous stage when one exists; otherwise the selector
    * itself plays both the ES and GS roles.
    */
   const struct si_shader_selector *gs_sel = shader->selector;
   const struct si_shader_selector *es_sel =
      shader->previous_stage_sel ? shader->previous_stage_sel : gs_sel;
   const gl_shader_stage gs_stage = gs_sel->stage;
   const unsigned gs_num_invocations = MAX2(gs_sel->info.base.gs.invocations, 1);
   const unsigned input_prim = si_get_input_prim(gs_sel, &shader->key);
   const bool use_adjacency =
      input_prim >= PIPE_PRIM_LINES_ADJACENCY && input_prim <= PIPE_PRIM_TRIANGLE_STRIP_ADJACENCY;
   const unsigned max_verts_per_prim = u_vertices_per_prim(input_prim);
   /* Passed to clamp_gsprims_to_esverts(): the full input primitive size for
    * a GS, 1 for VS/TES. */
   const unsigned min_verts_per_prim = gs_stage == MESA_SHADER_GEOMETRY ? max_verts_per_prim : 1;

   /* All these are in dwords: */
   /* GE can only use 8K dwords (32KB) of LDS per workgroup.
    * The NGG scratch area is carved out of that budget.
    */
   const unsigned max_lds_size = 8 * 1024 - gfx10_ngg_get_scratch_dw_size(shader);
   const unsigned target_lds_size = max_lds_size;
   unsigned esvert_lds_size = 0;
   unsigned gsprim_lds_size = 0;

   /* All these are per subgroup: */
   const unsigned min_esverts =
      gs_sel->screen->info.gfx_level >= GFX11 ? 3 : /* gfx11 requires at least 1 primitive per TG */
      gs_sel->screen->info.gfx_level >= GFX10_3 ? 29 : (24 - 1 + max_verts_per_prim);
   bool max_vert_out_per_gs_instance = false;
   unsigned max_gsprims_base = gs_sel->screen->ngg_subgroup_size; /* default prim group size clamp */
   unsigned max_esverts_base = gs_sel->screen->ngg_subgroup_size;

   if (gs_stage == MESA_SHADER_GEOMETRY) {
      bool force_multi_cycling = false;
      unsigned max_out_verts_per_gsprim = gs_sel->info.base.gs.vertices_out * gs_num_invocations;

      /* Re-entered at most once, when the per-primitive GS output does not
       * fit in LDS and multi-cycling is allowed (see below). */
retry_select_mode:
      if (max_out_verts_per_gsprim <= 256 && !force_multi_cycling) {
         if (max_out_verts_per_gsprim) {
            max_gsprims_base = MIN2(max_gsprims_base, 256 / max_out_verts_per_gsprim);
         }
      } else {
         /* Use special multi-cycling mode in which each GS
          * instance gets its own subgroup. Does not work with
          * tessellation. */
         max_vert_out_per_gs_instance = true;
         max_gsprims_base = 1;
         max_out_verts_per_gsprim = gs_sel->info.base.gs.vertices_out;
      }

      /* Per-vertex ES->GS data, and per-primitive GS output (one extra dword
       * per output vertex on top of gsvs_vertex_size). */
      esvert_lds_size = es_sel->info.esgs_itemsize / 4;
      gsprim_lds_size = (gs_sel->info.gsvs_vertex_size / 4 + 1) * max_out_verts_per_gsprim;

      if (gsprim_lds_size > target_lds_size && !force_multi_cycling) {
         if (gs_sel->tess_turns_off_ngg || es_sel->stage != MESA_SHADER_TESS_EVAL) {
            force_multi_cycling = true;
            goto retry_select_mode;
         }
      }
   } else {
      /* VS and TES. */
      /* LDS size for passing data from ES to GS. */
      esvert_lds_size = ngg_nogs_vertex_size(shader);
   }

   unsigned max_gsprims = max_gsprims_base;
   unsigned max_esverts = max_esverts_base;

   if (esvert_lds_size)
      max_esverts = MIN2(max_esverts, target_lds_size / esvert_lds_size);
   if (gsprim_lds_size)
      max_gsprims = MIN2(max_gsprims, target_lds_size / gsprim_lds_size);

   max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
   clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, use_adjacency);
   assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);

   if (esvert_lds_size || gsprim_lds_size) {
      /* Now that we have a rough proportionality between esverts
       * and gsprims based on the primitive type, scale both of them
       * down simultaneously based on required LDS space.
       *
       * We could be smarter about this if we knew how much vertex
       * reuse to expect.
       */
      unsigned lds_total = max_esverts * esvert_lds_size + max_gsprims * gsprim_lds_size;
      if (lds_total > target_lds_size) {
         max_esverts = max_esverts * target_lds_size / lds_total;
         max_gsprims = max_gsprims * target_lds_size / lds_total;

         max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
         clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, use_adjacency);
         assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
      }
   }

   /* Round up towards full wave sizes for better ALU utilization. */
   if (!max_vert_out_per_gs_instance) {
      unsigned orig_max_esverts;
      unsigned orig_max_gsprims;
      /* Iterate until neither value changes: each pass rounds both counts up
       * to the wave size and then re-applies the base, LDS and reuse clamps. */
      do {
         orig_max_esverts = max_esverts;
         orig_max_gsprims = max_gsprims;

         max_esverts = align(max_esverts, shader->wave_size);
         max_esverts = MIN2(max_esverts, max_esverts_base);
         if (esvert_lds_size)
            max_esverts =
               MIN2(max_esverts, (max_lds_size - max_gsprims * gsprim_lds_size) / esvert_lds_size);
         max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);

         /* Hardware restriction: minimum value of max_esverts */
         max_esverts = MAX2(max_esverts, min_esverts);

         max_gsprims = align(max_gsprims, shader->wave_size);
         max_gsprims = MIN2(max_gsprims, max_gsprims_base);
         if (gsprim_lds_size) {
            /* Don't count unusable vertices to the LDS size. Those are vertices above
             * the maximum number of vertices that can occur in the workgroup,
             * which is e.g. max_gsprims * 3 for triangles.
             *
             * NOTE(review): max_esverts was just raised to at least min_esverts,
             * so usable_esverts * esvert_lds_size could in principle exceed
             * max_lds_size; the unsigned subtraction below would then wrap and
             * the clamp would not apply. Confirm this is unreachable for real
             * ES item sizes.
             */
            unsigned usable_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
            max_gsprims =
               MIN2(max_gsprims, (max_lds_size - usable_esverts * esvert_lds_size) / gsprim_lds_size);
         }
         clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, use_adjacency);
         assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
      } while (orig_max_esverts != max_esverts || orig_max_gsprims != max_gsprims);

      /* Verify the restriction. */
      assert(max_esverts >= min_esverts);
   } else {
      max_esverts = MAX2(max_esverts, min_esverts);
   }

   unsigned max_out_vertices =
      max_vert_out_per_gs_instance
         ? gs_sel->info.base.gs.vertices_out
         : gs_stage == MESA_SHADER_GEOMETRY
              ? max_gsprims * gs_num_invocations * gs_sel->info.base.gs.vertices_out
              : max_esverts;
   assert(max_out_vertices <= 256);

   unsigned prim_amp_factor = 1;
   if (gs_stage == MESA_SHADER_GEOMETRY) {
      /* Number of output primitives per GS input primitive after
       * GS instancing. */
      prim_amp_factor = gs_sel->info.base.gs.vertices_out;
   }

   shader->ngg.hw_max_esverts = max_esverts;
   shader->ngg.max_gsprims = max_gsprims;
   shader->ngg.max_out_verts = max_out_vertices;
   shader->ngg.prim_amp_factor = prim_amp_factor;
   shader->ngg.max_vert_out_per_gs_instance = max_vert_out_per_gs_instance;

   /* Don't count unusable vertices. */
   shader->gs_info.esgs_ring_size = MIN2(max_esverts, max_gsprims * max_verts_per_prim) *
                                    esvert_lds_size;
   shader->ngg.ngg_emit_size = max_gsprims * gsprim_lds_size;

   assert(shader->ngg.hw_max_esverts >= min_esverts); /* HW limitation */

   /* If asserts are disabled, we use the same conditions to return false */
   return max_esverts >= max_verts_per_prim && max_gsprims >= 1 &&
          max_out_vertices <= 256 &&
          shader->ngg.hw_max_esverts >= min_esverts;
}
2517