1 /*
2  * Copyright 2023 Alyssa Rosenzweig
3  * Copyright 2023 Valve Corporation
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "agx_nir_lower_gs.h"
8 #include "asahi/compiler/agx_compile.h"
9 #include "compiler/nir/nir_builder.h"
10 #include "gallium/include/pipe/p_defines.h"
11 #include "shaders/geometry.h"
12 #include "util/bitscan.h"
13 #include "util/list.h"
14 #include "util/macros.h"
15 #include "util/ralloc.h"
16 #include "util/u_math.h"
17 #include "libagx_shaders.h"
18 #include "nir.h"
19 #include "nir_builder_opcodes.h"
20 #include "nir_intrinsics.h"
21 #include "nir_intrinsics_indices.h"
22 #include "nir_xfb_info.h"
23 #include "shader_enums.h"
24 
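/* Overview of the lowering implemented below: the API geometry shader is split
 * into several internal shaders. An optional "count" shader writes per-input-
 * primitive vertex/primitive/XFB counts into a count buffer, which is prefix
 * summed (agx_nir_prefix_sum_gs) before the real geometry shader runs. The
 * main GS then executes as a compute-like prepass that writes the index buffer
 * and transform feedback, a hardware vertex shader (the "rast" shader) re-runs
 * the GS per rasterized output vertex, and a small pre-GS kernel patches up
 * the draw and the XFB/query counters.
 */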
25 /* Marks a transform feedback store, which must not be stripped from the
26  * prepass since that's where the transform feedback happens. Chosen as a
27  * vendored flag not to alias other flags we'll see.
28  */
29 #define ACCESS_XFB (ACCESS_IS_SWIZZLED_AMD)
30 
31 enum gs_counter {
32    GS_COUNTER_VERTICES = 0,
33    GS_COUNTER_PRIMITIVES,
34    GS_COUNTER_XFB_PRIMITIVES,
35    GS_NUM_COUNTERS
36 };
37 
38 #define MAX_PRIM_OUT_SIZE 3
39 
40 struct lower_gs_state {
41    int static_count[GS_NUM_COUNTERS][MAX_VERTEX_STREAMS];
42    nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE];
43 
44    /* The count buffer contains `count_stride_el` 32-bit words in a row for each
45     * input primitive, for `input_primitives * count_stride_el * 4` total bytes.
46     */
47    unsigned count_stride_el;
48 
49    /* The index of each counter in the count buffer, or -1 if it's not in the
50     * count buffer.
51     *
52     * Invariant: count_stride_el == sum(count_index[i][j] >= 0).
53     */
54    int count_index[MAX_VERTEX_STREAMS][GS_NUM_COUNTERS];
55 
56    bool rasterizer_discard;
57 };
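/* Illustrative example of the layout above: if stream 0's vertex and primitive
 * counts are dynamic but its XFB primitive count is known statically,
 * count_index[0] is {0, 1, -1} and count_stride_el is 2, so each input
 * primitive owns two consecutive 32-bit words in the count buffer.
 */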
58 
59 /* Helpers for loading from the geometry state buffer */
60 static nir_def *
61 load_geometry_param_offset(nir_builder *b, uint32_t offset, uint8_t bytes)
62 {
63    nir_def *base = nir_load_geometry_param_buffer_agx(b);
64    nir_def *addr = nir_iadd_imm(b, base, offset);
65 
66    assert((offset % bytes) == 0 && "must be naturally aligned");
67 
68    return nir_load_global_constant(b, addr, bytes, 1, bytes * 8);
69 }
70 
71 static void
72 store_geometry_param_offset(nir_builder *b, nir_def *def, uint32_t offset,
73                             uint8_t bytes)
74 {
75    nir_def *base = nir_load_geometry_param_buffer_agx(b);
76    nir_def *addr = nir_iadd_imm(b, base, offset);
77 
78    assert((offset % bytes) == 0 && "must be naturally aligned");
79 
80    nir_store_global(b, addr, 4, def, nir_component_mask(def->num_components));
81 }
82 
83 #define store_geometry_param(b, field, def)                                    \
84    store_geometry_param_offset(                                                \
85       b, def, offsetof(struct agx_geometry_params, field),                     \
86       sizeof(((struct agx_geometry_params *)0)->field))
87 
88 #define load_geometry_param(b, field)                                          \
89    load_geometry_param_offset(                                                 \
90       b, offsetof(struct agx_geometry_params, field),                          \
91       sizeof(((struct agx_geometry_params *)0)->field))
92 
93 /* Helper for updating counters */
94 static void
95 add_counter(nir_builder *b, nir_def *counter, nir_def *increment)
96 {
97    /* If the counter is NULL, the counter is disabled. Skip the update. */
98    nir_if *nif = nir_push_if(b, nir_ine_imm(b, counter, 0));
99    {
100       nir_def *old = nir_load_global(b, counter, 4, 1, 32);
101       nir_def *new_ = nir_iadd(b, old, increment);
102       nir_store_global(b, counter, 4, new_, nir_component_mask(1));
103    }
104    nir_pop_if(b, nif);
105 }
106 
107 /* Helpers for lowering I/O to variables */
108 static void
109 lower_store_to_var(nir_builder *b, nir_intrinsic_instr *intr,
110                    struct agx_lower_output_to_var_state *state)
111 {
112    b->cursor = nir_instr_remove(&intr->instr);
113    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
114    unsigned component = nir_intrinsic_component(intr);
115    nir_def *value = intr->src[0].ssa;
116 
117    assert(nir_src_is_const(intr->src[1]) && "no indirect outputs");
118    assert(nir_intrinsic_write_mask(intr) == nir_component_mask(1) &&
119           "should be scalarized");
120 
121    nir_variable *var =
122       state->outputs[sem.location + nir_src_as_uint(intr->src[1])];
123    if (!var) {
124       assert(sem.location == VARYING_SLOT_PSIZ &&
125              "otherwise in outputs_written");
126       return;
127    }
128 
129    unsigned nr_components = glsl_get_components(glsl_without_array(var->type));
130    assert(component < nr_components);
131 
132    /* Turn it into a vec4 write like NIR expects */
133    value = nir_vector_insert_imm(b, nir_undef(b, nr_components, 32), value,
134                                  component);
135 
136    nir_store_var(b, var, value, BITFIELD_BIT(component));
137 }
138 
139 bool
140 agx_lower_output_to_var(nir_builder *b, nir_instr *instr, void *data)
141 {
142    if (instr->type != nir_instr_type_intrinsic)
143       return false;
144 
145    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
146    if (intr->intrinsic != nir_intrinsic_store_output)
147       return false;
148 
149    lower_store_to_var(b, intr, data);
150    return true;
151 }
152 
153 /*
154  * Geometry shader invocations are compute-like:
155  *
156  * (primitive ID, instance ID, 1)
157  */
158 static nir_def *
159 load_primitive_id(nir_builder *b)
160 {
161    return nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
162 }
163 
164 static nir_def *
165 load_instance_id(nir_builder *b)
166 {
167    return nir_channel(b, nir_load_global_invocation_id(b, 32), 1);
168 }
169 
170 static bool
171 lower_gs_inputs(nir_builder *b, nir_intrinsic_instr *intr, void *_)
172 {
173    if (intr->intrinsic != nir_intrinsic_load_per_vertex_input)
174       return false;
175 
176    b->cursor = nir_instr_remove(&intr->instr);
177    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
178 
179    nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
180 
181    /* Calculate the vertex ID we're pulling, based on the topology class */
182    nir_def *vert_in_prim = intr->src[0].ssa;
183    nir_def *vertex = agx_vertex_id_for_topology_class(
184       b, vert_in_prim, b->shader->info.gs.input_primitive);
185 
186    /* The unrolled vertex ID uses the input_vertices, which differs from what
187     * our load_num_vertices will return (vertices vs primitives).
188     */
189    nir_def *unrolled =
190       nir_iadd(b,
191                nir_imul(b, nir_load_instance_id(b),
192                         load_geometry_param(b, input_vertices)),
193                vertex);
194 
195    /* Calculate the address of the input given the unrolled vertex ID */
196    nir_def *addr = libagx_vertex_output_address(
197       b, nir_load_geometry_param_buffer_agx(b), unrolled, location,
198       load_geometry_param(b, vs_outputs));
199 
200    assert(intr->def.bit_size == 32);
201    addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
202 
203    nir_def *val = nir_load_global_constant(b, addr, 4, intr->def.num_components,
204                                            intr->def.bit_size);
205    nir_def_rewrite_uses(&intr->def, val);
206    return true;
207 }
208 
209 /*
210  * Unrolled ID is the index of the primitive in the count buffer, given as
211  * (instance ID * # vertices/instance) + vertex ID
212  */
213 static nir_def *
214 calc_unrolled_id(nir_builder *b)
215 {
216    return nir_iadd(b,
217                    nir_imul(b, load_instance_id(b), nir_load_num_vertices(b)),
218                    load_primitive_id(b));
219 }
220 
221 static unsigned
222 output_vertex_id_stride(nir_shader *gs)
223 {
224    /* round up to power of two for cheap multiply/division */
225    return util_next_power_of_two(MAX2(gs->info.gs.vertices_out, 1));
226 }
227 
228 /* Variant of calc_unrolled_id that uses a power-of-two stride for indices. This
229  * is sparser (acceptable for index buffer values, not for count buffer
230  * indices). It has the nice property of being cheap to invert, unlike
231  * calc_unrolled_id. So, we use calc_unrolled_id for count buffers and
232  * calc_unrolled_index_id for index values.
233  *
234  * This also multiplies by the appropriate stride to calculate the final index
235  * base value.
236  */
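/* Worked example: with vertices_out = 6, output_vertex_id_stride rounds up to
 * 8, so the vertices of an unrolled primitive P start at index P * 8. The
 * rasterization shader inverts this cheaply with raw_id % 8 (vertex within the
 * primitive) and raw_id / 8 (unrolled primitive), see
 * agx_nir_create_gs_rast_shader.
 */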
237 static nir_def *
238 calc_unrolled_index_id(nir_builder *b)
239 {
240    unsigned vertex_stride = output_vertex_id_stride(b->shader);
241    nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
242 
243    nir_def *instance = nir_ishl(b, load_instance_id(b), primitives_log2);
244    nir_def *prim = nir_iadd(b, instance, load_primitive_id(b));
245 
246    return nir_imul_imm(b, prim, vertex_stride);
247 }
248 
249 static nir_def *
250 load_count_address(nir_builder *b, struct lower_gs_state *state,
251                    nir_def *unrolled_id, unsigned stream,
252                    enum gs_counter counter)
253 {
254    int index = state->count_index[stream][counter];
255    if (index < 0)
256       return NULL;
257 
258    nir_def *prim_offset_el =
259       nir_imul_imm(b, unrolled_id, state->count_stride_el);
260 
261    nir_def *offset_el = nir_iadd_imm(b, prim_offset_el, index);
262 
263    return nir_iadd(b, load_geometry_param(b, count_buffer),
264                    nir_u2u64(b, nir_imul_imm(b, offset_el, 4)));
265 }
266 
267 static void
268 write_counts(nir_builder *b, nir_intrinsic_instr *intr,
269              struct lower_gs_state *state)
270 {
271    /* Store each required counter */
272    nir_def *counts[GS_NUM_COUNTERS] = {
273       [GS_COUNTER_VERTICES] = intr->src[0].ssa,
274       [GS_COUNTER_PRIMITIVES] = intr->src[1].ssa,
275       [GS_COUNTER_XFB_PRIMITIVES] = intr->src[2].ssa,
276    };
277 
278    for (unsigned i = 0; i < GS_NUM_COUNTERS; ++i) {
279       nir_def *addr = load_count_address(b, state, calc_unrolled_id(b),
280                                          nir_intrinsic_stream_id(intr), i);
281 
282       if (addr)
283          nir_store_global(b, addr, 4, counts[i], nir_component_mask(1));
284    }
285 }
286 
287 static bool
288 lower_gs_count_instr(nir_builder *b, nir_intrinsic_instr *intr, void *data)
289 {
290    switch (intr->intrinsic) {
291    case nir_intrinsic_emit_vertex_with_counter:
292    case nir_intrinsic_end_primitive_with_counter:
293    case nir_intrinsic_store_output:
294       /* These are for the main shader, just remove them */
295       nir_instr_remove(&intr->instr);
296       return true;
297 
298    case nir_intrinsic_set_vertex_and_primitive_count:
299       b->cursor = nir_instr_remove(&intr->instr);
300       write_counts(b, intr, data);
301       return true;
302 
303    default:
304       return false;
305    }
306 }
307 
308 static bool
309 lower_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
310 {
311    b->cursor = nir_before_instr(&intr->instr);
312 
313    nir_def *id;
314    if (intr->intrinsic == nir_intrinsic_load_primitive_id)
315       id = load_primitive_id(b);
316    else if (intr->intrinsic == nir_intrinsic_load_instance_id)
317       id = load_instance_id(b);
318    else if (intr->intrinsic == nir_intrinsic_load_num_vertices)
319       id = nir_channel(b, nir_load_num_workgroups(b), 0);
320    else if (intr->intrinsic == nir_intrinsic_load_flat_mask)
321       id = load_geometry_param(b, flat_outputs);
322    else if (intr->intrinsic == nir_intrinsic_load_input_topology_agx)
323       id = load_geometry_param(b, input_topology);
324    else if (intr->intrinsic == nir_intrinsic_load_provoking_last) {
325       id = nir_b2b32(
326          b, libagx_is_provoking_last(b, nir_load_input_assembly_buffer_agx(b)));
327    } else
328       return false;
329 
330    b->cursor = nir_instr_remove(&intr->instr);
331    nir_def_rewrite_uses(&intr->def, id);
332    return true;
333 }
334 
335 /*
336  * Create a "Geometry count" shader. This is a stripped down geometry shader
337  * that just writes its number of emitted vertices / primitives / transform
338  * feedback primitives to a count buffer. That count buffer will be prefix
339  * summed prior to running the real geometry shader. This is skipped if the
340  * counts are statically known.
341  */
342 static nir_shader *
343 agx_nir_create_geometry_count_shader(nir_shader *gs, const nir_shader *libagx,
344                                      struct lower_gs_state *state)
345 {
346    /* Don't muck up the original shader */
347    nir_shader *shader = nir_shader_clone(NULL, gs);
348 
349    if (shader->info.name) {
350       shader->info.name =
351          ralloc_asprintf(shader, "%s_count", shader->info.name);
352    } else {
353       shader->info.name = "count";
354    }
355 
356    NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_gs_count_instr,
357             nir_metadata_block_index | nir_metadata_dominance, state);
358 
359    NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_id,
360             nir_metadata_block_index | nir_metadata_dominance, NULL);
361 
362    /* Preprocess it */
363    UNUSED struct agx_uncompiled_shader_info info;
364    agx_preprocess_nir(shader, libagx, false, &info);
365 
366    return shader;
367 }
368 
369 struct lower_gs_rast_state {
370    nir_def *instance_id, *primitive_id, *output_id;
371    struct agx_lower_output_to_var_state outputs;
372    struct agx_lower_output_to_var_state selected;
373 };
374 
375 static void
376 select_rast_output(nir_builder *b, nir_intrinsic_instr *intr,
377                    struct lower_gs_rast_state *state)
378 {
379    b->cursor = nir_instr_remove(&intr->instr);
380 
381    /* We only care about the rasterization stream in the rasterization
382     * shader, so just ignore emits from other streams.
383     */
384    if (nir_intrinsic_stream_id(intr) != 0)
385       return;
386 
387    u_foreach_bit64(slot, b->shader->info.outputs_written) {
388       nir_def *orig = nir_load_var(b, state->selected.outputs[slot]);
389       nir_def *data = nir_load_var(b, state->outputs.outputs[slot]);
390 
391       nir_def *value = nir_bcsel(
392          b, nir_ieq(b, intr->src[0].ssa, state->output_id), data, orig);
393 
394       nir_store_var(b, state->selected.outputs[slot], value,
395                     nir_component_mask(value->num_components));
396    }
397 }
398 
399 static bool
400 lower_to_gs_rast(nir_builder *b, nir_intrinsic_instr *intr, void *data)
401 {
402    struct lower_gs_rast_state *state = data;
403 
404    switch (intr->intrinsic) {
405    case nir_intrinsic_store_output:
406       lower_store_to_var(b, intr, &state->outputs);
407       return true;
408 
409    case nir_intrinsic_emit_vertex_with_counter:
410       select_rast_output(b, intr, state);
411       return true;
412 
413    case nir_intrinsic_load_primitive_id:
414       nir_def_rewrite_uses(&intr->def, state->primitive_id);
415       return true;
416 
417    case nir_intrinsic_load_instance_id:
418       nir_def_rewrite_uses(&intr->def, state->instance_id);
419       return true;
420 
421    case nir_intrinsic_load_num_vertices: {
422       b->cursor = nir_before_instr(&intr->instr);
423       nir_def_rewrite_uses(&intr->def, load_geometry_param(b, gs_grid[0]));
424       return true;
425    }
426 
427    case nir_intrinsic_load_flat_mask:
428    case nir_intrinsic_load_provoking_last:
429    case nir_intrinsic_load_input_topology_agx:
430       /* Lowering the same in both GS variants */
431       return lower_id(b, intr, data);
432 
433    case nir_intrinsic_end_primitive_with_counter:
434    case nir_intrinsic_set_vertex_and_primitive_count:
435       nir_instr_remove(&intr->instr);
436       return true;
437 
438    default:
439       return false;
440    }
441 }
442 
443 /*
444  * Create a GS rasterization shader. This is a hardware vertex shader that
445  * shades each rasterized output vertex in parallel.
446  */
447 static nir_shader *
448 agx_nir_create_gs_rast_shader(const nir_shader *gs, const nir_shader *libagx)
449 {
450    /* Don't muck up the original shader */
451    nir_shader *shader = nir_shader_clone(NULL, gs);
452 
453    unsigned max_verts = output_vertex_id_stride(shader);
454 
455    /* Turn into a vertex shader run only for rasterization. Transform feedback
456     * was handled in the prepass.
457     */
458    shader->info.stage = MESA_SHADER_VERTEX;
459    shader->info.has_transform_feedback_varyings = false;
460    memset(&shader->info.vs, 0, sizeof(shader->info.vs));
461    shader->xfb_info = NULL;
462 
463    if (shader->info.name) {
464       shader->info.name = ralloc_asprintf(shader, "%s_rast", shader->info.name);
465    } else {
466       shader->info.name = "gs rast";
467    }
468 
469    nir_builder b_ =
470       nir_builder_at(nir_before_impl(nir_shader_get_entrypoint(shader)));
471    nir_builder *b = &b_;
472 
473    /* Optimize out pointless gl_PointSize outputs. Bizarrely, these occur. */
474    if (shader->info.gs.output_primitive != MESA_PRIM_POINTS)
475       shader->info.outputs_written &= ~VARYING_BIT_PSIZ;
476 
477    /* See calc_unrolled_index_id */
478    nir_def *raw_id = nir_load_vertex_id(b);
479    nir_def *output_id = nir_umod_imm(b, raw_id, max_verts);
480    nir_def *unrolled = nir_udiv_imm(b, raw_id, max_verts);
481 
482    nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
483    nir_def *instance_id = nir_ushr(b, unrolled, primitives_log2);
484    nir_def *primitive_id = nir_iand(
485       b, unrolled,
486       nir_iadd_imm(b, nir_ishl(b, nir_imm_int(b, 1), primitives_log2), -1));
487 
488    struct lower_gs_rast_state rast_state = {
489       .instance_id = instance_id,
490       .primitive_id = primitive_id,
491       .output_id = output_id,
492    };
493 
494    u_foreach_bit64(slot, shader->info.outputs_written) {
495       const char *slot_name =
496          gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
497 
498       rast_state.outputs.outputs[slot] = nir_variable_create(
499          shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, 4),
500          ralloc_asprintf(shader, "%s-temp", slot_name));
501 
502       rast_state.selected.outputs[slot] = nir_variable_create(
503          shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, 4),
504          ralloc_asprintf(shader, "%s-selected", slot_name));
505    }
506 
507    nir_shader_intrinsics_pass(shader, lower_to_gs_rast,
508                               nir_metadata_block_index | nir_metadata_dominance,
509                               &rast_state);
510 
511    b->cursor = nir_after_impl(b->impl);
512 
513    /* Forward each selected output to the rasterizer */
514    u_foreach_bit64(slot, shader->info.outputs_written) {
515       assert(rast_state.selected.outputs[slot] != NULL);
516       nir_def *value = nir_load_var(b, rast_state.selected.outputs[slot]);
517 
518       /* We set NIR_COMPACT_ARRAYS so clip/cull distances all need to come in
519        * DIST0. Undo the offset if we need to.
520        */
521       unsigned offset = 0;
522       if (slot == VARYING_SLOT_CULL_DIST1 || slot == VARYING_SLOT_CLIP_DIST1)
523          offset = 1;
524 
525       nir_store_output(b, value, nir_imm_int(b, offset),
526                        .io_semantics.location = slot - offset,
527                        .io_semantics.num_slots = 1,
528                        .write_mask = nir_component_mask(value->num_components));
529    }
530 
531    /* In OpenGL ES, it is legal to omit the point size write from the geometry
532     * shader when drawing points. In this case, the point size is
533     * implicitly 1.0. We implement this by inserting this synthetic
534     * `gl_PointSize = 1.0` write into the GS copy shader, if the GS does not
535     * export a point size while drawing points.
536     *
537     * This should not be load bearing for other APIs, but should be harmless.
538     */
539    bool is_points = gs->info.gs.output_primitive == MESA_PRIM_POINTS;
540 
541    if (!(shader->info.outputs_written & VARYING_BIT_PSIZ) && is_points) {
542       nir_store_output(b, nir_imm_float(b, 1.0), nir_imm_int(b, 0),
543                        .io_semantics.location = VARYING_SLOT_PSIZ,
544                        .io_semantics.num_slots = 1,
545                        .write_mask = nir_component_mask(1));
546 
547       shader->info.outputs_written |= VARYING_BIT_PSIZ;
548    }
549 
550    nir_opt_idiv_const(shader, 16);
551 
552    /* Preprocess it */
553    UNUSED struct agx_uncompiled_shader_info info;
554    agx_preprocess_nir(shader, libagx, false, &info);
555 
556    return shader;
557 }
558 
559 static nir_def *
560 previous_count(nir_builder *b, struct lower_gs_state *state, unsigned stream,
561                nir_def *unrolled_id, enum gs_counter counter)
562 {
563    assert(stream < MAX_VERTEX_STREAMS);
564    assert(counter < GS_NUM_COUNTERS);
565    int static_count = state->static_count[counter][stream];
566 
567    if (static_count >= 0) {
568       /* If the number of outputted vertices per invocation is known statically,
569        * we can calculate the base.
570        */
571       return nir_imul_imm(b, unrolled_id, static_count);
572    } else {
573       /* Otherwise, we need to load from the prefix sum buffer. Note that the
574        * sums are inclusive, so index 0 is nonzero. This requires a little
575        * fixup here. We use a saturating unsigned subtraction so we don't read
576        * out-of-bounds for zero.
577        *
578        * TODO: Optimize this.
579        */
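      /* Concretely: unrolled_id == 0 maps to a count of 0, and unrolled_id ==
       * N (for N > 0) maps to the inclusive sum stored at element N - 1.
       */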
580       nir_def *prim_minus_1 = nir_usub_sat(b, unrolled_id, nir_imm_int(b, 1));
581       nir_def *addr =
582          load_count_address(b, state, prim_minus_1, stream, counter);
583 
584       return nir_bcsel(b, nir_ieq_imm(b, unrolled_id, 0), nir_imm_int(b, 0),
585                        nir_load_global_constant(b, addr, 4, 1, 32));
586    }
587 }
588 
589 static nir_def *
590 previous_vertices(nir_builder *b, struct lower_gs_state *state, unsigned stream,
591                   nir_def *unrolled_id)
592 {
593    return previous_count(b, state, stream, unrolled_id, GS_COUNTER_VERTICES);
594 }
595 
596 static nir_def *
597 previous_primitives(nir_builder *b, struct lower_gs_state *state,
598                     unsigned stream, nir_def *unrolled_id)
599 {
600    return previous_count(b, state, stream, unrolled_id, GS_COUNTER_PRIMITIVES);
601 }
602 
603 static nir_def *
604 previous_xfb_primitives(nir_builder *b, struct lower_gs_state *state,
605                         unsigned stream, nir_def *unrolled_id)
606 {
607    return previous_count(b, state, stream, unrolled_id,
608                          GS_COUNTER_XFB_PRIMITIVES);
609 }
610 
611 static void
612 lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr,
613                     struct lower_gs_state *state)
614 {
615    assert((intr->intrinsic == nir_intrinsic_set_vertex_and_primitive_count ||
616            b->shader->info.gs.output_primitive != MESA_PRIM_POINTS) &&
617           "endprimitive for points should've been removed");
618 
619    /* The GS is the last stage before rasterization, so if we discard the
620     * rasterization, we don't output an index buffer; nothing will read it.
621     * Index buffer is only for the rasterization stream.
622     */
623    unsigned stream = nir_intrinsic_stream_id(intr);
624    if (state->rasterizer_discard || stream != 0)
625       return;
626 
627    libagx_end_primitive(
628       b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa,
629       intr->src[1].ssa, intr->src[2].ssa,
630       previous_vertices(b, state, 0, calc_unrolled_id(b)),
631       previous_primitives(b, state, 0, calc_unrolled_id(b)),
632       calc_unrolled_index_id(b),
633       nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS));
634 }
635 
636 static unsigned
637 verts_in_output_prim(nir_shader *gs)
638 {
639    return mesa_vertices_per_prim(gs->info.gs.output_primitive);
640 }
641 
642 static void
643 write_xfb(nir_builder *b, struct lower_gs_state *state, unsigned stream,
644           nir_def *index_in_strip, nir_def *prim_id_in_invocation)
645 {
646    struct nir_xfb_info *xfb = b->shader->xfb_info;
647    unsigned verts = verts_in_output_prim(b->shader);
648 
649    /* Get the index of this primitive in the XFB buffer. That is, the base for
650     * this invocation for the stream plus the offset within this invocation.
651     */
652    nir_def *invocation_base =
653       previous_xfb_primitives(b, state, stream, calc_unrolled_id(b));
654 
655    nir_def *prim_index = nir_iadd(b, invocation_base, prim_id_in_invocation);
656    nir_def *base_index = nir_imul_imm(b, prim_index, verts);
657 
658    nir_def *xfb_prims = load_geometry_param(b, xfb_prims[stream]);
659    nir_push_if(b, nir_ult(b, prim_index, xfb_prims));
660 
661    /* Write XFB for each output */
662    for (unsigned i = 0; i < xfb->output_count; ++i) {
663       nir_xfb_output_info output = xfb->outputs[i];
664 
665       /* Only write to the selected stream */
666       if (xfb->buffer_to_stream[output.buffer] != stream)
667          continue;
668 
669       unsigned buffer = output.buffer;
670       unsigned stride = xfb->buffers[buffer].stride;
671       unsigned count = util_bitcount(output.component_mask);
672 
673       for (unsigned vert = 0; vert < verts; ++vert) {
674          /* We write out the vertices backwards, since 0 is the current
675           * emitted vertex (which is actually the last vertex).
676           *
677           * We handle NULL var for
678           * KHR-Single-GL44.enhanced_layouts.xfb_capture_struct.
679           */
680          unsigned v = (verts - 1) - vert;
681          nir_variable *var = state->outputs[output.location][v];
682          nir_def *value = var ? nir_load_var(b, var) : nir_undef(b, 4, 32);
683 
684          /* In case output.component_mask contains invalid components, write
685           * out zeroes instead of blowing up validation.
686           *
687           * KHR-Single-GL44.enhanced_layouts.xfb_capture_inactive_output_component
688           * hits this.
689           */
690          value = nir_pad_vector_imm_int(b, value, 0, 4);
691 
692          nir_def *rotated_vert = nir_imm_int(b, vert);
693          if (verts == 3) {
694             /* Map vertices for output so we get consistent winding order. For
695              * the primitive index, we use the index_in_strip. This is actually
696              * the vertex index in the strip, hence
697              * offset by 2 relative to the true primitive index (#2 for the
698              * first triangle in the strip, #3 for the second). That's ok
699              * because only the parity matters.
700              */
701             rotated_vert = libagx_map_vertex_in_tri_strip(
702                b, index_in_strip, rotated_vert,
703                nir_inot(b, nir_i2b(b, nir_load_provoking_last(b))));
704          }
705 
706          nir_def *addr = libagx_xfb_vertex_address(
707             b, nir_load_geometry_param_buffer_agx(b), base_index, rotated_vert,
708             nir_imm_int(b, buffer), nir_imm_int(b, stride),
709             nir_imm_int(b, output.offset));
710 
711          nir_build_store_global(
712             b, nir_channels(b, value, output.component_mask), addr,
713             .align_mul = 4, .write_mask = nir_component_mask(count),
714             .access = ACCESS_XFB);
715       }
716    }
717 
718    nir_pop_if(b, NULL);
719 }
720 
721 /* Handle transform feedback for a given emit_vertex_with_counter */
722 static void
723 lower_emit_vertex_xfb(nir_builder *b, nir_intrinsic_instr *intr,
724                       struct lower_gs_state *state)
725 {
726    /* Transform feedback is written for each decomposed output primitive. Since
727     * we're writing strips, that means we output XFB for each vertex after the
728     * first complete primitive is formed.
729     */
730    unsigned first_prim = verts_in_output_prim(b->shader) - 1;
731    nir_def *index_in_strip = intr->src[1].ssa;
732 
733    nir_push_if(b, nir_uge_imm(b, index_in_strip, first_prim));
734    {
735       write_xfb(b, state, nir_intrinsic_stream_id(intr), index_in_strip,
736                 intr->src[3].ssa);
737    }
738    nir_pop_if(b, NULL);
739 
740    /* Transform feedback writes out entire primitives during the emit_vertex. To
741     * do that, we store the values at all vertices in the strip in a little ring
742     * buffer. Index #0 is always the most recent primitive (so non-XFB code can
743     * just grab index #0 without any checking). Index #1 is the previous vertex,
744     * and index #2 is the vertex before that. Now that we've written XFB, since
745     * we've emitted a vertex we need to cycle the ringbuffer, freeing up index
746     * #0 for the next vertex that we are about to emit. We do that by copying
747     * the first n - 1 vertices forward one slot, which has to happen with a
748     * backwards copy implemented here.
749     *
750     * If we're lucky, all of these copies will be propagated away. If we're
751     * unlucky, this involves at most 2 copies per component per XFB output per
752     * vertex.
753     */
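   /* For triangle output (3 vertices), this copies slot 1 into slot 2 and then
    * slot 0 into slot 1, leaving slot 0 free for the vertex about to be
    * emitted.
    */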
754    u_foreach_bit64(slot, b->shader->info.outputs_written) {
755       /* Note: if we're outputting points, verts_in_output_prim will be 1, so
756        * this loop will not execute. This is intended: points are self-contained
757        * primitives and do not need these copies.
758        */
759       for (int v = verts_in_output_prim(b->shader) - 1; v >= 1; --v) {
760          nir_def *value = nir_load_var(b, state->outputs[slot][v - 1]);
761 
762          nir_store_var(b, state->outputs[slot][v], value,
763                        nir_component_mask(value->num_components));
764       }
765    }
766 }
767 
768 static bool
769 lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
770 {
771    b->cursor = nir_before_instr(&intr->instr);
772 
773    switch (intr->intrinsic) {
774    case nir_intrinsic_set_vertex_and_primitive_count:
775       /* This instruction is mostly for the count shader, so just remove. But
776        * for points, we write the index buffer here so the rast shader can map.
777        */
778       if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) {
779          lower_end_primitive(b, intr, state);
780       }
781 
782       break;
783 
784    case nir_intrinsic_end_primitive_with_counter: {
785       unsigned min = verts_in_output_prim(b->shader);
786 
787       /* We only write out complete primitives */
788       nir_push_if(b, nir_uge_imm(b, intr->src[1].ssa, min));
789       {
790          lower_end_primitive(b, intr, state);
791       }
792       nir_pop_if(b, NULL);
793       break;
794    }
795 
796    case nir_intrinsic_emit_vertex_with_counter:
797       /* emit_vertex triggers transform feedback but is otherwise a no-op. */
798       if (b->shader->xfb_info)
799          lower_emit_vertex_xfb(b, intr, state);
800       break;
801 
802    default:
803       return false;
804    }
805 
806    nir_instr_remove(&intr->instr);
807    return true;
808 }
809 
810 static bool
811 collect_components(nir_builder *b, nir_intrinsic_instr *intr, void *data)
812 {
813    uint8_t *counts = data;
814    if (intr->intrinsic != nir_intrinsic_store_output)
815       return false;
816 
817    unsigned count = nir_intrinsic_component(intr) +
818                     util_last_bit(nir_intrinsic_write_mask(intr));
819 
820    unsigned loc =
821       nir_intrinsic_io_semantics(intr).location + nir_src_as_uint(intr->src[1]);
822 
823    uint8_t *total_count = &counts[loc];
824 
825    *total_count = MAX2(*total_count, count);
826    return true;
827 }
828 
829 /*
830  * Create the pre-GS shader. This is a small compute 1x1x1 kernel that patches
831  * up the VDM Index List command from the draw to read the produced geometry, as
832  * well as updating transform feedback offsets and counters as applicable (TODO).
833  */
834 static nir_shader *
835 agx_nir_create_pre_gs(struct lower_gs_state *state, const nir_shader *libagx,
836                       bool indexed, bool restart, struct nir_xfb_info *xfb,
837                       unsigned vertices_per_prim, uint8_t streams,
838                       unsigned invocations)
839 {
840    nir_builder b_ = nir_builder_init_simple_shader(
841       MESA_SHADER_COMPUTE, &agx_nir_options, "Pre-GS patch up");
842    nir_builder *b = &b_;
843 
844    /* Load the number of primitives input to the GS */
845    nir_def *unrolled_in_prims = load_geometry_param(b, input_primitives);
846 
847    /* Set up the draw from the rasterization stream (0). */
848    if (!state->rasterizer_discard) {
849       libagx_build_gs_draw(
850          b, nir_load_geometry_param_buffer_agx(b), nir_imm_bool(b, indexed),
851          previous_vertices(b, state, 0, unrolled_in_prims),
852          restart ? previous_primitives(b, state, 0, unrolled_in_prims)
853                  : nir_imm_int(b, 0));
854    }
855 
856    /* Determine the number of primitives generated in each stream */
857    nir_def *in_prims[MAX_VERTEX_STREAMS], *prims[MAX_VERTEX_STREAMS];
858 
859    u_foreach_bit(i, streams) {
860       in_prims[i] = previous_xfb_primitives(b, state, i, unrolled_in_prims);
861       prims[i] = in_prims[i];
862 
863       add_counter(b, load_geometry_param(b, prims_generated_counter[i]),
864                   prims[i]);
865    }
866 
867    if (xfb) {
868       /* Write XFB addresses */
869       nir_def *offsets[4] = {NULL};
870       u_foreach_bit(i, xfb->buffers_written) {
871          offsets[i] = libagx_setup_xfb_buffer(
872             b, nir_load_geometry_param_buffer_agx(b), nir_imm_int(b, i));
873       }
874 
875       /* Now clamp to the number that XFB captures */
876       for (unsigned i = 0; i < xfb->output_count; ++i) {
877          nir_xfb_output_info output = xfb->outputs[i];
878 
879          unsigned buffer = output.buffer;
880          unsigned stream = xfb->buffer_to_stream[buffer];
881          unsigned stride = xfb->buffers[buffer].stride;
882          unsigned words_written = util_bitcount(output.component_mask);
883          unsigned bytes_written = words_written * 4;
884 
885          /* Primitive P will write up to (but not including) offset:
886           *
887           *    xfb_offset + ((P - 1) * (verts_per_prim * stride))
888           *               + ((verts_per_prim - 1) * stride)
889           *               + output_offset
890           *               + output_size
891           *
892           * Given an XFB buffer of size xfb_size, we get the inequality:
893           *
894           *    floor(P) <= (stride + xfb_size - xfb_offset - output_offset -
895           *                     output_size) // (stride * verts_per_prim)
896           */
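         /* Worked example with illustrative numbers: stride = 16 B, 3 vertices
          * per primitive, a single vec4 output at offset 0 (16 B written),
          * xfb_offset = 0 and xfb_size = 96 B give
          * (96 + 16 - 0 - 16 - 0) / 48 = 2 primitives, i.e. exactly the
          * 6 vertices * 16 B that fit in the buffer.
          */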
897          nir_def *size = load_geometry_param(b, xfb_size[buffer]);
898          size = nir_iadd_imm(b, size, stride - output.offset - bytes_written);
899          size = nir_isub(b, size, offsets[buffer]);
900          size = nir_imax(b, size, nir_imm_int(b, 0));
901          nir_def *max_prims = nir_udiv_imm(b, size, stride * vertices_per_prim);
902 
903          prims[stream] = nir_umin(b, prims[stream], max_prims);
904       }
905 
906       nir_def *any_overflow = nir_imm_false(b);
907 
908       u_foreach_bit(i, streams) {
909          nir_def *overflow = nir_ult(b, prims[i], in_prims[i]);
910          any_overflow = nir_ior(b, any_overflow, overflow);
911 
912          store_geometry_param(b, xfb_prims[i], prims[i]);
913 
914          add_counter(b, load_geometry_param(b, xfb_overflow[i]),
915                      nir_b2i32(b, overflow));
916 
917          add_counter(b, load_geometry_param(b, xfb_prims_generated_counter[i]),
918                      prims[i]);
919       }
920 
921       add_counter(b, load_geometry_param(b, xfb_any_overflow),
922                   nir_b2i32(b, any_overflow));
923 
924       /* Update XFB counters */
925       u_foreach_bit(i, xfb->buffers_written) {
926          uint32_t prim_stride_B = xfb->buffers[i].stride * vertices_per_prim;
927          unsigned stream = xfb->buffer_to_stream[i];
928 
929          nir_def *off_ptr = load_geometry_param(b, xfb_offs_ptrs[i]);
930          nir_def *size = nir_imul_imm(b, prims[stream], prim_stride_B);
931          add_counter(b, off_ptr, size);
932       }
933    }
934 
935    /* The geometry shader receives a number of input primitives. The driver
936     * should disable this counter when tessellation is active TODO and count
937     * patches separately.
938     */
939    add_counter(
940       b,
941       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_IA_PRIMITIVES),
942       unrolled_in_prims);
943 
944    /* The geometry shader is invoked once per primitive (after unrolling
945     * primitive restart). From the spec:
946     *
947     *    In case of instanced geometry shaders (see section 11.3.4.2) the
948     *    geometry shader invocations count is incremented for each separate
949     *    instanced invocation.
950     */
951    add_counter(b,
952                nir_load_stat_query_address_agx(
953                   b, .base = PIPE_STAT_QUERY_GS_INVOCATIONS),
954                nir_imul_imm(b, unrolled_in_prims, invocations));
955 
956    nir_def *emitted_prims = nir_imm_int(b, 0);
957    u_foreach_bit(i, streams) {
958       emitted_prims =
959          nir_iadd(b, emitted_prims,
960                   previous_xfb_primitives(b, state, i, unrolled_in_prims));
961    }
962 
963    add_counter(
964       b,
965       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_GS_PRIMITIVES),
966       emitted_prims);
967 
968    /* Clipper queries are not well-defined, so we can emulate them in lots of
969     * silly ways. We need the hardware counters to implement them properly. For
970     * now, just consider all primitives emitted as passing through the clipper.
971     * This satisfies spec text:
972     *
973     *    The number of primitives that reach the primitive clipping stage.
974     *
975     * and
976     *
977     *    If at least one vertex of the primitive lies inside the clipping
978     *    volume, the counter is incremented by one or more. Otherwise, the
979     *    counter is incremented by zero or more.
980     */
981    add_counter(
982       b,
983       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_PRIMITIVES),
984       emitted_prims);
985 
986    add_counter(
987       b,
988       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_INVOCATIONS),
989       emitted_prims);
990 
991    /* Preprocess it */
992    UNUSED struct agx_uncompiled_shader_info info;
993    agx_preprocess_nir(b->shader, libagx, false, &info);
994 
995    return b->shader;
996 }
997 
998 static bool
999 rewrite_invocation_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1000 {
1001    if (intr->intrinsic != nir_intrinsic_load_invocation_id)
1002       return false;
1003 
1004    b->cursor = nir_instr_remove(&intr->instr);
1005    nir_def_rewrite_uses(&intr->def, nir_u2uN(b, data, intr->def.bit_size));
1006    return true;
1007 }
1008 
1009 /*
1010  * Geometry shader instancing allows a GS to run multiple times. The number of
1011  * times is statically known and small. It's easiest to turn this into a loop
1012  * inside the GS, to avoid the feature "leaking" outside and affecting e.g. the
1013  * counts.
1014  */
1015 static void
1016 agx_nir_lower_gs_instancing(nir_shader *gs)
1017 {
1018    unsigned nr_invocations = gs->info.gs.invocations;
1019    nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1020 
1021    /* Each invocation can produce up to the shader-declared max_vertices, so
1022     * multiply it up for a proper bounds check. Emitting more than the declared
1023     * max_vertices per invocation results in undefined behaviour, so erroneously
1024     * emitting more than asked on early invocations is a perfectly cromulent
1025     * behaviour.
1026     */
1027    gs->info.gs.vertices_out *= gs->info.gs.invocations;
1028 
1029    /* Get the original function */
1030    nir_cf_list list;
1031    nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1032 
1033    /* Create a builder for the wrapped function */
1034    nir_builder b = nir_builder_at(nir_after_block(nir_start_block(impl)));
1035 
1036    nir_variable *i =
1037       nir_local_variable_create(impl, glsl_uintN_t_type(16), NULL);
1038    nir_store_var(&b, i, nir_imm_intN_t(&b, 0, 16), ~0);
1039    nir_def *index = NULL;
1040 
1041    /* Create a loop in the wrapped function */
1042    nir_loop *loop = nir_push_loop(&b);
1043    {
1044       index = nir_load_var(&b, i);
1045       nir_push_if(&b, nir_uge_imm(&b, index, nr_invocations));
1046       {
1047          nir_jump(&b, nir_jump_break);
1048       }
1049       nir_pop_if(&b, NULL);
1050 
1051       b.cursor = nir_cf_reinsert(&list, b.cursor);
1052       nir_store_var(&b, i, nir_iadd_imm(&b, index, 1), ~0);
1053 
1054       /* Make sure we end the primitive between invocations. If the geometry
1055        * shader already ended the primitive, this will get optimized out.
1056        */
1057       nir_end_primitive(&b);
1058    }
1059    nir_pop_loop(&b, loop);
1060 
1061    /* We've mucked about with control flow */
1062    nir_metadata_preserve(impl, nir_metadata_none);
1063 
1064    /* Use the loop counter as the invocation ID each iteration */
1065    nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1066                               nir_metadata_block_index | nir_metadata_dominance,
1067                               index);
1068 }
1069 
1070 static bool
1071 strip_side_effects(nir_builder *b, nir_intrinsic_instr *intr, void *_)
1072 {
1073    switch (intr->intrinsic) {
1074    case nir_intrinsic_store_global:
1075    case nir_intrinsic_global_atomic:
1076    case nir_intrinsic_global_atomic_swap:
1077       break;
1078    default:
1079       return false;
1080    }
1081 
1082    /* If there's a side effect that's actually required for the prepass, we have
1083     * to keep it in.
1084     */
1085    if (nir_intrinsic_infos[intr->intrinsic].has_dest &&
1086        !list_is_empty(&intr->def.uses))
1087       return false;
1088 
1089    /* Do not strip transform feedback stores; the rasterization shader doesn't
1090     * execute them.
1091     */
1092    if (intr->intrinsic == nir_intrinsic_store_global &&
1093        nir_intrinsic_access(intr) & ACCESS_XFB)
1094       return false;
1095 
1096    /* Otherwise, remove the dead instruction. The rasterization shader will
1097     * execute the side effect so the side effect still happens at least once.
1098     */
1099    nir_instr_remove(&intr->instr);
1100    return true;
1101 }
1102 
1103 static void
1104 link_libagx(nir_shader *nir, const nir_shader *libagx)
1105 {
1106    nir_link_shader_functions(nir, libagx);
1107    NIR_PASS(_, nir, nir_inline_functions);
1108    nir_remove_non_entrypoints(nir);
1109    NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
1110    NIR_PASS(_, nir, nir_opt_dce);
1111    NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
1112             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
1113                nir_var_mem_global,
1114             glsl_get_cl_type_size_align);
1115    NIR_PASS(_, nir, nir_opt_deref);
1116    NIR_PASS(_, nir, nir_lower_vars_to_ssa);
1117    NIR_PASS(_, nir, nir_lower_explicit_io,
1118             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
1119                nir_var_mem_global,
1120             nir_address_format_62bit_generic);
1121 }
1122 
1123 bool
1124 agx_nir_lower_gs(nir_shader *gs, const nir_shader *libagx,
1125                  bool rasterizer_discard, nir_shader **gs_count,
1126                  nir_shader **gs_copy, nir_shader **pre_gs,
1127                  enum mesa_prim *out_mode, unsigned *out_count_words)
1128 {
1129    /* Collect output component counts so we can size the geometry output buffer
1130     * appropriately, instead of assuming everything is vec4.
1131     */
1132    uint8_t component_counts[NUM_TOTAL_VARYING_SLOTS] = {0};
1133    nir_shader_intrinsics_pass(gs, collect_components, nir_metadata_all,
1134                               component_counts);
1135 
1136    /* If geometry shader instancing is used, lower it away before linking
1137     * anything. Otherwise, smash the invocation ID to zero.
1138     */
1139    if (gs->info.gs.invocations != 1) {
1140       agx_nir_lower_gs_instancing(gs);
1141    } else {
1142       nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1143       nir_builder b = nir_builder_at(nir_before_impl(impl));
1144 
1145       nir_shader_intrinsics_pass(
1146          gs, rewrite_invocation_id,
1147          nir_metadata_block_index | nir_metadata_dominance, nir_imm_int(&b, 0));
1148    }
1149 
1150    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_inputs,
1151             nir_metadata_block_index | nir_metadata_dominance, NULL);
1152 
1153    /* Lower geometry shader writes to contain all of the required counts, so we
1154     * know where in the various buffers we should write vertices.
1155     */
1156    NIR_PASS(_, gs, nir_lower_gs_intrinsics,
1157             nir_lower_gs_intrinsics_count_primitives |
1158                nir_lower_gs_intrinsics_per_stream |
1159                nir_lower_gs_intrinsics_count_vertices_per_primitive |
1160                nir_lower_gs_intrinsics_overwrite_incomplete |
1161                nir_lower_gs_intrinsics_always_end_primitive |
1162                nir_lower_gs_intrinsics_count_decomposed_primitives);
1163 
1164    /* Clean up after all that lowering we did */
1165    bool progress = false;
1166    do {
1167       progress = false;
1168       NIR_PASS(progress, gs, nir_lower_var_copies);
1169       NIR_PASS(progress, gs, nir_lower_variable_initializers,
1170                nir_var_shader_temp);
1171       NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1172       NIR_PASS(progress, gs, nir_copy_prop);
1173       NIR_PASS(progress, gs, nir_opt_constant_folding);
1174       NIR_PASS(progress, gs, nir_opt_algebraic);
1175       NIR_PASS(progress, gs, nir_opt_cse);
1176       NIR_PASS(progress, gs, nir_opt_dead_cf);
1177       NIR_PASS(progress, gs, nir_opt_dce);
1178 
1179       /* Unrolling lets us statically determine counts more often, which
1180        * otherwise would not be possible with multiple invocations even in the
1181        * simplest of cases.
1182        */
1183       NIR_PASS(progress, gs, nir_opt_loop_unroll);
1184    } while (progress);
1185 
1186    /* If we know counts at compile-time we can simplify, so try to figure out
1187     * the counts statically.
1188     */
1189    struct lower_gs_state gs_state = {
1190       .rasterizer_discard = rasterizer_discard,
1191    };
1192 
1193    nir_gs_count_vertices_and_primitives(
1194       gs, gs_state.static_count[GS_COUNTER_VERTICES],
1195       gs_state.static_count[GS_COUNTER_PRIMITIVES],
1196       gs_state.static_count[GS_COUNTER_XFB_PRIMITIVES], 4);
1197 
1198    /* Anything we don't know statically will be tracked by the count buffer.
1199     * Determine the layout for it.
1200     */
1201    for (unsigned i = 0; i < MAX_VERTEX_STREAMS; ++i) {
1202       for (unsigned c = 0; c < GS_NUM_COUNTERS; ++c) {
1203          gs_state.count_index[i][c] =
1204             (gs_state.static_count[c][i] < 0) ? gs_state.count_stride_el++ : -1;
1205       }
1206    }
1207 
1208    *gs_copy = agx_nir_create_gs_rast_shader(gs, libagx);
1209 
1210    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1211             nir_metadata_block_index | nir_metadata_dominance, NULL);
1212 
1213    link_libagx(gs, libagx);
1214 
1215    NIR_PASS(_, gs, nir_lower_idiv,
1216             &(const nir_lower_idiv_options){.allow_fp16 = true});
1217 
1218    /* All those variables we created should've gone away by now */
1219    NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1220 
1221    /* If there is any unknown count, we need a geometry count shader */
1222    if (gs_state.count_stride_el > 0)
1223       *gs_count = agx_nir_create_geometry_count_shader(gs, libagx, &gs_state);
1224    else
1225       *gs_count = NULL;
1226 
1227    /* Geometry shader outputs are staged to temporaries */
1228    struct agx_lower_output_to_var_state state = {0};
1229 
1230    u_foreach_bit64(slot, gs->info.outputs_written) {
1231       const char *slot_name =
1232          gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
1233 
1234       for (unsigned i = 0; i < MAX_PRIM_OUT_SIZE; ++i) {
1235          gs_state.outputs[slot][i] = nir_variable_create(
1236             gs, nir_var_shader_temp,
1237             glsl_vector_type(GLSL_TYPE_UINT, component_counts[slot]),
1238             ralloc_asprintf(gs, "%s-%u", slot_name, i));
1239       }
1240 
1241       state.outputs[slot] = gs_state.outputs[slot][0];
1242    }
1243 
1244    NIR_PASS(_, gs, nir_shader_instructions_pass, agx_lower_output_to_var,
1245             nir_metadata_block_index | nir_metadata_dominance, &state);
1246 
1247    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_instr,
1248             nir_metadata_none, &gs_state);
1249 
1250    /* Determine if we are guaranteed to rasterize at least one vertex, so that
1251     * we can strip the prepass of side effects knowing they will execute in the
1252     * rasterization shader.
1253     */
1254    bool rasterizes_at_least_one_vertex =
1255       !rasterizer_discard && gs_state.static_count[0][0] > 0;
1256 
1257    /* Clean up after all that lowering we did */
1258    nir_lower_global_vars_to_local(gs);
1259    do {
1260       progress = false;
1261       NIR_PASS(progress, gs, nir_lower_var_copies);
1262       NIR_PASS(progress, gs, nir_lower_variable_initializers,
1263                nir_var_shader_temp);
1264       NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1265       NIR_PASS(progress, gs, nir_copy_prop);
1266       NIR_PASS(progress, gs, nir_opt_constant_folding);
1267       NIR_PASS(progress, gs, nir_opt_algebraic);
1268       NIR_PASS(progress, gs, nir_opt_cse);
1269       NIR_PASS(progress, gs, nir_opt_dead_cf);
1270       NIR_PASS(progress, gs, nir_opt_dce);
1271       NIR_PASS(progress, gs, nir_opt_loop_unroll);
1272 
1273       /* When rasterizing, we try to move side effects to the rasterizer shader
1274        * and strip the prepass of the dead side effects. Run this in the opt
1275        * loop because it interacts with nir_opt_dce.
1276        */
1277       if (rasterizes_at_least_one_vertex) {
1278          NIR_PASS(progress, gs, nir_shader_intrinsics_pass, strip_side_effects,
1279                   nir_metadata_block_index | nir_metadata_dominance, NULL);
1280       }
1281    } while (progress);
1282 
1283    /* All those variables we created should've gone away by now */
1284    NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1285 
1286    NIR_PASS(_, gs, nir_opt_sink, ~0);
1287    NIR_PASS(_, gs, nir_opt_move, ~0);
1288    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1289             nir_metadata_block_index | nir_metadata_dominance, NULL);
1290 
1291    /* Create auxiliary programs */
1292    *pre_gs = agx_nir_create_pre_gs(
1293       &gs_state, libagx, true, gs->info.gs.output_primitive != MESA_PRIM_POINTS,
1294       gs->xfb_info, verts_in_output_prim(gs), gs->info.gs.active_stream_mask,
1295       gs->info.gs.invocations);
1296 
1297    /* Signal what primitive we want to draw the GS Copy VS with */
1298    *out_mode = gs->info.gs.output_primitive;
1299    *out_count_words = gs_state.count_stride_el;
1300    return true;
1301 }
1302 
1303 /*
1304  * Vertex shaders (tessellation evaluation shaders) before a geometry shader run
1305  * as a dedicated compute prepass. They are invoked as (count, instances, 1),
1306  * equivalent to a geometry shader inputting POINTS, so the vertex output buffer
1307  * is indexed according to calc_unrolled_id.
1308  *
1309  * This function lowers their vertex shader I/O to compute.
1310  *
1311  * Vertex ID becomes an index buffer pull (without applying the topology). Store
1312  * output becomes a store into the global vertex output buffer.
1313  */
1314 static bool
1315 lower_vs_before_gs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1316 {
1317    if (intr->intrinsic != nir_intrinsic_store_output)
1318       return false;
1319 
1320    b->cursor = nir_instr_remove(&intr->instr);
1321    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
1322    nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
1323 
1324    nir_def *addr = libagx_vertex_output_address(
1325       b, nir_load_geometry_param_buffer_agx(b), calc_unrolled_id(b), location,
1326       nir_imm_int64(b, b->shader->info.outputs_written));
1327 
1328    assert(nir_src_bit_size(intr->src[0]) == 32);
1329    addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
1330 
1331    nir_store_global(b, addr, 4, intr->src[0].ssa,
1332                     nir_intrinsic_write_mask(intr));
1333    return true;
1334 }
1335 
1336 bool
1337 agx_nir_lower_vs_before_gs(struct nir_shader *vs,
1338                            const struct nir_shader *libagx,
1339                            unsigned index_size_B, uint64_t *outputs)
1340 {
1341    bool progress = false;
1342 
1343    /* Lower vertex ID to an index buffer pull without a topology applied */
1344    progress |= agx_nir_lower_index_buffer(vs, index_size_B, false);
1345 
1346    /* Lower vertex stores to memory stores */
1347    progress |= nir_shader_intrinsics_pass(
1348       vs, lower_vs_before_gs, nir_metadata_block_index | nir_metadata_dominance,
1349       &index_size_B);
1350 
1351    /* Lower instance ID and num vertices */
1352    progress |= nir_shader_intrinsics_pass(
1353       vs, lower_id, nir_metadata_block_index | nir_metadata_dominance, NULL);
1354 
1355    /* Link libagx, used in lower_vs_before_gs */
1356    if (progress)
1357       link_libagx(vs, libagx);
1358 
1359    /* Turn into a compute shader now that we're free of vertexisms */
1360    vs->info.stage = MESA_SHADER_COMPUTE;
1361    memset(&vs->info.cs, 0, sizeof(vs->info.cs));
1362    vs->xfb_info = NULL;
1363    *outputs = vs->info.outputs_written;
1364    return true;
1365 }
1366 
1367 void
1368 agx_nir_prefix_sum_gs(nir_builder *b, const void *data)
1369 {
1370    const unsigned *words = data;
1371 
1372    uint32_t subgroup_size = 32;
1373    b->shader->info.workgroup_size[0] = subgroup_size;
1374    b->shader->info.workgroup_size[1] = *words;
1375 
1376    libagx_prefix_sum(b, load_geometry_param(b, count_buffer),
1377                      load_geometry_param(b, input_primitives),
1378                      nir_imm_int(b, *words),
1379                      nir_trim_vector(b, nir_load_local_invocation_id(b), 2));
1380 }
1381 
1382 void
1383 agx_nir_gs_setup_indirect(nir_builder *b, const void *data)
1384 {
1385    const struct agx_gs_setup_indirect_key *key = data;
1386 
1387    libagx_gs_setup_indirect(b, nir_load_geometry_param_buffer_agx(b),
1388                             nir_load_input_assembly_buffer_agx(b),
1389                             nir_imm_int(b, key->prim),
1390                             nir_channel(b, nir_load_local_invocation_id(b), 0));
1391 }
1392 
1393 void
1394 agx_nir_unroll_restart(nir_builder *b, const void *data)
1395 {
1396    const struct agx_unroll_restart_key *key = data;
1397    nir_def *ia = nir_load_input_assembly_buffer_agx(b);
1398    nir_def *draw = nir_channel(b, nir_load_workgroup_id(b), 0);
1399    nir_def *mode = nir_imm_int(b, key->prim);
1400 
1401    if (key->index_size_B == 1)
1402       libagx_unroll_restart_u8(b, ia, mode, draw);
1403    else if (key->index_size_B == 2)
1404       libagx_unroll_restart_u16(b, ia, mode, draw);
1405    else if (key->index_size_B == 4)
1406       libagx_unroll_restart_u32(b, ia, mode, draw);
1407    else
1408       unreachable("invalid index size");
1409 }
1410