1 /*
2  * Copyright 2023 Alyssa Rosenzweig
3  * Copyright 2023 Valve Corporation
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "agx_nir_lower_gs.h"
8 #include "asahi/compiler/agx_compile.h"
9 #include "compiler/nir/nir_builder.h"
10 #include "gallium/include/pipe/p_defines.h"
11 #include "libagx/geometry.h"
12 #include "libagx/libagx.h"
13 #include "util/bitscan.h"
14 #include "util/list.h"
15 #include "util/macros.h"
16 #include "util/ralloc.h"
17 #include "util/u_math.h"
18 #include "nir.h"
19 #include "nir_builder_opcodes.h"
20 #include "nir_intrinsics.h"
21 #include "nir_intrinsics_indices.h"
22 #include "nir_xfb_info.h"
23 #include "shader_enums.h"
24 
25 enum gs_counter {
26    GS_COUNTER_VERTICES = 0,
27    GS_COUNTER_PRIMITIVES,
28    GS_COUNTER_XFB_PRIMITIVES,
29    GS_NUM_COUNTERS
30 };
31 
32 #define MAX_PRIM_OUT_SIZE 3
33 
34 struct lower_gs_state {
35    int static_count[GS_NUM_COUNTERS][MAX_VERTEX_STREAMS];
36    nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE];
37 
38    /* The count buffer contains `count_stride_el` 32-bit words in a row for each
39     * input primitive, for `input_primitives * count_stride_el * 4` total bytes.
40     */
41    unsigned count_stride_el;
42 
43    /* The index of each counter in the count buffer, or -1 if it's not in the
44     * count buffer.
45     *
46     * Invariant: count_stride_el == sum(count_index[i][j] >= 0).
47     */
48    int count_index[MAX_VERTEX_STREAMS][GS_NUM_COUNTERS];
49 
50    bool rasterizer_discard;
51 };
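/* Illustrative layout example (hypothetical counts): if stream 0 emits a
 * statically known number of vertices but dynamic primitive and XFB-primitive
 * counts, then count_index[0][GS_COUNTER_VERTICES] = -1,
 * count_index[0][GS_COUNTER_PRIMITIVES] = 0 and
 * count_index[0][GS_COUNTER_XFB_PRIMITIVES] = 1, so count_stride_el = 2 and
 * each input primitive owns two consecutive 32-bit words in the count buffer.
 */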
52 
53 /* Helpers for loading from the geometry state buffer */
54 static nir_def *
55 load_geometry_param_offset(nir_builder *b, uint32_t offset, uint8_t bytes)
56 {
57    nir_def *base = nir_load_geometry_param_buffer_agx(b);
58    nir_def *addr = nir_iadd_imm(b, base, offset);
59 
60    assert((offset % bytes) == 0 && "must be naturally aligned");
61 
62    return nir_load_global_constant(b, addr, bytes, 1, bytes * 8);
63 }
64 
65 static void
66 store_geometry_param_offset(nir_builder *b, nir_def *def, uint32_t offset,
67                             uint8_t bytes)
68 {
69    nir_def *base = nir_load_geometry_param_buffer_agx(b);
70    nir_def *addr = nir_iadd_imm(b, base, offset);
71 
72    assert((offset % bytes) == 0 && "must be naturally aligned");
73 
74    nir_store_global(b, addr, 4, def, nir_component_mask(def->num_components));
75 }
76 
77 #define store_geometry_param(b, field, def)                                    \
78    store_geometry_param_offset(                                                \
79       b, def, offsetof(struct agx_geometry_params, field),                     \
80       sizeof(((struct agx_geometry_params *)0)->field))
81 
82 #define load_geometry_param(b, field)                                          \
83    load_geometry_param_offset(                                                 \
84       b, offsetof(struct agx_geometry_params, field),                          \
85       sizeof(((struct agx_geometry_params *)0)->field))
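/* Usage sketch, mirroring calls made later in this file:
 *
 *    nir_def *verts = load_geometry_param(b, vs_grid[0]);
 *    store_geometry_param(b, xfb_prims[i], prims[i]);
 *
 * The offset and access size are derived from the agx_geometry_params field,
 * so the loads/stores are naturally sized and aligned.
 */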
86 
87 /* Helper for updating counters */
88 static void
89 add_counter(nir_builder *b, nir_def *counter, nir_def *increment)
90 {
91    /* If the counter is NULL, the counter is disabled. Skip the update. */
92    nir_if *nif = nir_push_if(b, nir_ine_imm(b, counter, 0));
93    {
94       nir_def *old = nir_load_global(b, counter, 4, 1, 32);
95       nir_def *new_ = nir_iadd(b, old, increment);
96       nir_store_global(b, counter, 4, new_, nir_component_mask(1));
97    }
98    nir_pop_if(b, nif);
99 }
100 
101 /* Helpers for lowering I/O to variables */
102 static void
103 lower_store_to_var(nir_builder *b, nir_intrinsic_instr *intr,
104                    struct agx_lower_output_to_var_state *state)
105 {
106    b->cursor = nir_instr_remove(&intr->instr);
107    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
108    unsigned component = nir_intrinsic_component(intr);
109    nir_def *value = intr->src[0].ssa;
110 
111    assert(nir_src_is_const(intr->src[1]) && "no indirect outputs");
112    assert(nir_intrinsic_write_mask(intr) == nir_component_mask(1) &&
113           "should be scalarized");
114 
115    nir_variable *var =
116       state->outputs[sem.location + nir_src_as_uint(intr->src[1])];
117    if (!var) {
118       assert(sem.location == VARYING_SLOT_PSIZ &&
119              "otherwise in outputs_written");
120       return;
121    }
122 
123    unsigned nr_components = glsl_get_components(glsl_without_array(var->type));
124    assert(component < nr_components);
125 
126    /* Turn it into a vec4 write like NIR expects */
127    value = nir_vector_insert_imm(b, nir_undef(b, nr_components, 32), value,
128                                  component);
129 
130    nir_store_var(b, var, value, BITFIELD_BIT(component));
131 }
132 
133 bool
134 agx_lower_output_to_var(nir_builder *b, nir_instr *instr, void *data)
135 {
136    if (instr->type != nir_instr_type_intrinsic)
137       return false;
138 
139    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
140    if (intr->intrinsic != nir_intrinsic_store_output)
141       return false;
142 
143    lower_store_to_var(b, intr, data);
144    return true;
145 }
146 
147 /*
148  * Geometry shader invocations are compute-like:
149  *
150  * (primitive ID, instance ID, 1)
151  */
152 static nir_def *
153 load_primitive_id(nir_builder *b)
154 {
155    return nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
156 }
157 
158 static nir_def *
159 load_instance_id(nir_builder *b)
160 {
161    return nir_channel(b, nir_load_global_invocation_id(b, 32), 1);
162 }
163 
164 /* Geometry shaders use software input assembly. The software vertex shader
165  * is invoked for each index, and the geometry shader applies the topology. This
166  * helper applies the topology.
167  */
168 static nir_def *
169 vertex_id_for_topology_class(nir_builder *b, nir_def *vert, enum mesa_prim cls)
170 {
171    nir_def *prim = nir_load_primitive_id(b);
172    nir_def *flatshade_first = nir_ieq_imm(b, nir_load_provoking_last(b), 0);
173    nir_def *nr = load_geometry_param(b, gs_grid[0]);
174    nir_def *topology = nir_load_input_topology_agx(b);
175 
176    switch (cls) {
177    case MESA_PRIM_POINTS:
178       return prim;
179 
180    case MESA_PRIM_LINES:
181       return libagx_vertex_id_for_line_class(b, topology, prim, vert, nr);
182 
183    case MESA_PRIM_TRIANGLES:
184       return libagx_vertex_id_for_tri_class(b, topology, prim, vert,
185                                             flatshade_first);
186 
187    case MESA_PRIM_LINES_ADJACENCY:
188       return libagx_vertex_id_for_line_adj_class(b, topology, prim, vert);
189 
190    case MESA_PRIM_TRIANGLES_ADJACENCY:
191       return libagx_vertex_id_for_tri_adj_class(b, topology, prim, vert, nr,
192                                                 flatshade_first);
193 
194    default:
195       unreachable("invalid topology class");
196    }
197 }
198 
199 nir_def *
200 agx_load_per_vertex_input(nir_builder *b, nir_intrinsic_instr *intr,
201                           nir_def *vertex)
202 {
203    assert(intr->intrinsic == nir_intrinsic_load_per_vertex_input);
204    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
205 
206    nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
207    nir_def *addr;
208 
209    if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
210       /* GS may be preceded by VS or TES so specified as param */
211       addr = libagx_geometry_input_address(
212          b, nir_load_geometry_param_buffer_agx(b), vertex, location);
213    } else {
214       assert(b->shader->info.stage == MESA_SHADER_TESS_CTRL);
215 
216       /* TCS always preceded by VS so we use the VS state directly */
217       addr = libagx_vertex_output_address(b, nir_load_vs_output_buffer_agx(b),
218                                           nir_load_vs_outputs_agx(b), vertex,
219                                           location);
220    }
221 
222    addr = nir_iadd_imm(b, addr, 4 * nir_intrinsic_component(intr));
223    return nir_load_global_constant(b, addr, 4, intr->def.num_components,
224                                    intr->def.bit_size);
225 }
226 
227 static bool
228 lower_gs_inputs(nir_builder *b, nir_intrinsic_instr *intr, void *_)
229 {
230    if (intr->intrinsic != nir_intrinsic_load_per_vertex_input)
231       return false;
232 
233    b->cursor = nir_instr_remove(&intr->instr);
234 
235    /* Calculate the vertex ID we're pulling, based on the topology class */
236    nir_def *vert_in_prim = intr->src[0].ssa;
237    nir_def *vertex = vertex_id_for_topology_class(
238       b, vert_in_prim, b->shader->info.gs.input_primitive);
239 
240    nir_def *verts = load_geometry_param(b, vs_grid[0]);
241    nir_def *unrolled =
242       nir_iadd(b, nir_imul(b, nir_load_instance_id(b), verts), vertex);
243 
244    nir_def *val = agx_load_per_vertex_input(b, intr, unrolled);
245    nir_def_rewrite_uses(&intr->def, val);
246    return true;
247 }
248 
249 /*
250  * Unrolled ID is the index of the primitive in the count buffer, given as
251  * (instance ID * # vertices/instance) + vertex ID
252  */
253 static nir_def *
254 calc_unrolled_id(nir_builder *b)
255 {
256    return nir_iadd(
257       b, nir_imul(b, load_instance_id(b), load_geometry_param(b, gs_grid[0])),
258       load_primitive_id(b));
259 }
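/* Illustrative numbers: with load_geometry_param(b, gs_grid[0]) = 100,
 * instance 2 and primitive 5 get unrolled ID 2 * 100 + 5 = 205.
 */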
260 
261 static unsigned
262 output_vertex_id_stride(nir_shader *gs)
263 {
264    /* round up to power of two for cheap multiply/division */
265    return util_next_power_of_two(MAX2(gs->info.gs.vertices_out, 1));
266 }
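/* e.g. vertices_out = 6 rounds up to a stride of 8, vertices_out = 4 stays 4,
 * and vertices_out = 0 is clamped to 1 (illustrative values).
 */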
267 
268 /* Variant of calc_unrolled_id that uses a power-of-two stride for indices. This
269  * is sparser (acceptable for index buffer values, not for count buffer
270  * indices). It has the nice property of being cheap to invert, unlike
271  * calc_unrolled_id. So, we use calc_unrolled_id for count buffers and
272  * calc_unrolled_index_id for index values.
273  *
274  * This also multiplies by the appropriate stride to calculate the final index
275  * base value.
276  */
277 static nir_def *
278 calc_unrolled_index_id(nir_builder *b)
279 {
280    unsigned vertex_stride = output_vertex_id_stride(b->shader);
281    nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
282 
283    nir_def *instance = nir_ishl(b, load_instance_id(b), primitives_log2);
284    nir_def *prim = nir_iadd(b, instance, load_primitive_id(b));
285 
286    return nir_imul_imm(b, prim, vertex_stride);
287 }
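/* Worked example (hypothetical values): with max_vertices = 6 the stride
 * rounds up to 8, and primitives_log2 = 4 gives 16 primitives per instance.
 * Instance 1, primitive 3 then yields ((1 << 4) + 3) * 8 = 152. The rast
 * shader inverts this with a umod/udiv by 8 plus a shift/mask by
 * primitives_log2; see agx_nir_create_gs_rast_shader.
 */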
288 
289 static nir_def *
290 load_count_address(nir_builder *b, struct lower_gs_state *state,
291                    nir_def *unrolled_id, unsigned stream,
292                    enum gs_counter counter)
293 {
294    int index = state->count_index[stream][counter];
295    if (index < 0)
296       return NULL;
297 
298    nir_def *prim_offset_el =
299       nir_imul_imm(b, unrolled_id, state->count_stride_el);
300 
301    nir_def *offset_el = nir_iadd_imm(b, prim_offset_el, index);
302 
303    return nir_iadd(b, load_geometry_param(b, count_buffer),
304                    nir_u2u64(b, nir_imul_imm(b, offset_el, 4)));
305 }
306 
307 static void
308 write_counts(nir_builder *b, nir_intrinsic_instr *intr,
309              struct lower_gs_state *state)
310 {
311    /* Store each required counter */
312    nir_def *counts[GS_NUM_COUNTERS] = {
313       [GS_COUNTER_VERTICES] = intr->src[0].ssa,
314       [GS_COUNTER_PRIMITIVES] = intr->src[1].ssa,
315       [GS_COUNTER_XFB_PRIMITIVES] = intr->src[2].ssa,
316    };
317 
318    for (unsigned i = 0; i < GS_NUM_COUNTERS; ++i) {
319       nir_def *addr = load_count_address(b, state, calc_unrolled_id(b),
320                                          nir_intrinsic_stream_id(intr), i);
321 
322       if (addr)
323          nir_store_global(b, addr, 4, counts[i], nir_component_mask(1));
324    }
325 }
326 
327 static bool
328 lower_gs_count_instr(nir_builder *b, nir_intrinsic_instr *intr, void *data)
329 {
330    switch (intr->intrinsic) {
331    case nir_intrinsic_emit_vertex_with_counter:
332    case nir_intrinsic_end_primitive_with_counter:
333    case nir_intrinsic_store_output:
334       /* These are for the main shader, just remove them */
335       nir_instr_remove(&intr->instr);
336       return true;
337 
338    case nir_intrinsic_set_vertex_and_primitive_count:
339       b->cursor = nir_instr_remove(&intr->instr);
340       write_counts(b, intr, data);
341       return true;
342 
343    default:
344       return false;
345    }
346 }
347 
348 static bool
349 lower_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
350 {
351    b->cursor = nir_before_instr(&intr->instr);
352 
353    nir_def *id;
354    if (intr->intrinsic == nir_intrinsic_load_primitive_id)
355       id = load_primitive_id(b);
356    else if (intr->intrinsic == nir_intrinsic_load_instance_id)
357       id = load_instance_id(b);
358    else if (intr->intrinsic == nir_intrinsic_load_flat_mask)
359       id = load_geometry_param(b, flat_outputs);
360    else if (intr->intrinsic == nir_intrinsic_load_input_topology_agx)
361       id = load_geometry_param(b, input_topology);
362    else
363       return false;
364 
365    b->cursor = nir_instr_remove(&intr->instr);
366    nir_def_rewrite_uses(&intr->def, id);
367    return true;
368 }
369 
370 /*
371  * Create a "Geometry count" shader. This is a stripped down geometry shader
372  * that just writes its number of emitted vertices / primitives / transform
373  * feedback primitives to a count buffer. That count buffer will be prefix
374  * summed prior to running the real geometry shader. This is skipped if the
375  * counts are statically known.
376  */
377 static nir_shader *
378 agx_nir_create_geometry_count_shader(nir_shader *gs, const nir_shader *libagx,
379                                      struct lower_gs_state *state)
380 {
381    /* Don't muck up the original shader */
382    nir_shader *shader = nir_shader_clone(NULL, gs);
383 
384    if (shader->info.name) {
385       shader->info.name =
386          ralloc_asprintf(shader, "%s_count", shader->info.name);
387    } else {
388       shader->info.name = "count";
389    }
390 
391    NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_gs_count_instr,
392             nir_metadata_control_flow, state);
393 
394    NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_id,
395             nir_metadata_control_flow, NULL);
396 
397    agx_preprocess_nir(shader, libagx);
398    return shader;
399 }
400 
401 struct lower_gs_rast_state {
402    nir_def *instance_id, *primitive_id, *output_id;
403    struct agx_lower_output_to_var_state outputs;
404    struct agx_lower_output_to_var_state selected;
405 };
406 
407 static void
408 select_rast_output(nir_builder *b, nir_intrinsic_instr *intr,
409                    struct lower_gs_rast_state *state)
410 {
411    b->cursor = nir_instr_remove(&intr->instr);
412 
413    /* We only care about the rasterization stream in the rasterization
414     * shader, so just ignore emits from other streams.
415     */
416    if (nir_intrinsic_stream_id(intr) != 0)
417       return;
418 
419    u_foreach_bit64(slot, b->shader->info.outputs_written) {
420       nir_def *orig = nir_load_var(b, state->selected.outputs[slot]);
421       nir_def *data = nir_load_var(b, state->outputs.outputs[slot]);
422 
423       nir_def *value = nir_bcsel(
424          b, nir_ieq(b, intr->src[0].ssa, state->output_id), data, orig);
425 
426       nir_store_var(b, state->selected.outputs[slot], value,
427                     nir_component_mask(value->num_components));
428    }
429 }
430 
431 static bool
432 lower_to_gs_rast(nir_builder *b, nir_intrinsic_instr *intr, void *data)
433 {
434    struct lower_gs_rast_state *state = data;
435 
436    switch (intr->intrinsic) {
437    case nir_intrinsic_store_output:
438       lower_store_to_var(b, intr, &state->outputs);
439       return true;
440 
441    case nir_intrinsic_emit_vertex_with_counter:
442       select_rast_output(b, intr, state);
443       return true;
444 
445    case nir_intrinsic_load_primitive_id:
446       nir_def_rewrite_uses(&intr->def, state->primitive_id);
447       return true;
448 
449    case nir_intrinsic_load_instance_id:
450       nir_def_rewrite_uses(&intr->def, state->instance_id);
451       return true;
452 
453    case nir_intrinsic_load_flat_mask:
454    case nir_intrinsic_load_provoking_last:
455    case nir_intrinsic_load_input_topology_agx: {
456       /* Lowering the same in both GS variants */
457       return lower_id(b, intr, NULL);
458    }
459 
460    case nir_intrinsic_end_primitive_with_counter:
461    case nir_intrinsic_set_vertex_and_primitive_count:
462       nir_instr_remove(&intr->instr);
463       return true;
464 
465    default:
466       return false;
467    }
468 }
469 
470 /*
471  * Side effects in geometry shaders are problematic with our "GS rasterization
472  * shader" implementation. Where does the side effect happen? In the prepass?
473  * In the rast shader? In both?
474  *
475  * A perfect solution is impossible with rast shaders. Since the spec is loose
476  * here, we follow the principle of "least surprise":
477  *
478  * 1. Prefer side effects in the prepass over the rast shader. The prepass runs
479  *    once per API GS invocation so will match the expectations of buggy apps
480  *    not written for tilers.
481  *
482  * 2. If we must execute any side effect in the rast shader, try to execute all
483  *    side effects only in the rast shader. If some side effects must happen in
484  *    the rast shader and others don't, this gets consistent counts
485  *    (i.e. if the app expects plain stores and atomics to match up).
486  *
487  * 3. If we must execute side effects in both rast and the prepass,
488  *    execute all side effects in the rast shader and strip what we can from
489  *    the prepass. This gets the "unsurprising" behaviour from #2 without
490  *    falling over for ridiculous uses of atomics.
491  */
492 static bool
493 strip_side_effect_from_rast(nir_builder *b, nir_intrinsic_instr *intr,
494                             void *data)
495 {
496    switch (intr->intrinsic) {
497    case nir_intrinsic_store_global:
498    case nir_intrinsic_global_atomic:
499    case nir_intrinsic_global_atomic_swap:
500       break;
501    default:
502       return false;
503    }
504 
505    /* If there's a side effect that's actually required, keep it. */
506    if (nir_intrinsic_infos[intr->intrinsic].has_dest &&
507        !list_is_empty(&intr->def.uses)) {
508 
509       bool *any = data;
510       *any = true;
511       return false;
512    }
513 
514    /* Otherwise, remove the dead instruction. */
515    nir_instr_remove(&intr->instr);
516    return true;
517 }
518 
519 static bool
520 strip_side_effects_from_rast(nir_shader *s, bool *side_effects_for_rast)
521 {
522    bool progress, any;
523 
524    /* Rather than complex analysis, clone and try to remove as many side effects
525     * as possible. Then we check if we removed them all. We need to loop to
526     * handle complex control flow with side effects, where we can strip
527     * everything but can't figure that out with a simple one-shot analysis.
528     */
529    nir_shader *clone = nir_shader_clone(NULL, s);
530 
531    /* Drop as much as we can */
532    do {
533       progress = false;
534       any = false;
535       NIR_PASS(progress, clone, nir_shader_intrinsics_pass,
536                strip_side_effect_from_rast, nir_metadata_control_flow, &any);
537 
538       NIR_PASS(progress, clone, nir_opt_dce);
539       NIR_PASS(progress, clone, nir_opt_dead_cf);
540    } while (progress);
541 
542    ralloc_free(clone);
543 
544    /* If we need atomics, leave them in */
545    if (any) {
546       *side_effects_for_rast = true;
547       return false;
548    }
549 
550    /* Else strip it all */
551    do {
552       progress = false;
553       any = false;
554       NIR_PASS(progress, s, nir_shader_intrinsics_pass,
555                strip_side_effect_from_rast, nir_metadata_control_flow, &any);
556 
557       NIR_PASS(progress, s, nir_opt_dce);
558       NIR_PASS(progress, s, nir_opt_dead_cf);
559    } while (progress);
560 
561    assert(!any);
562    return progress;
563 }
564 
565 static bool
566 strip_side_effect_from_main(nir_builder *b, nir_intrinsic_instr *intr,
567                             void *data)
568 {
569    switch (intr->intrinsic) {
570    case nir_intrinsic_global_atomic:
571    case nir_intrinsic_global_atomic_swap:
572       break;
573    default:
574       return false;
575    }
576 
577    if (list_is_empty(&intr->def.uses)) {
578       nir_instr_remove(&intr->instr);
579       return true;
580    }
581 
582    return false;
583 }
584 
585 /*
586  * Create a GS rasterization shader. This is a hardware vertex shader that
587  * shades each rasterized output vertex in parallel.
588  */
589 static nir_shader *
590 agx_nir_create_gs_rast_shader(const nir_shader *gs, const nir_shader *libagx,
591                               bool *side_effects_for_rast)
592 {
593    /* Don't muck up the original shader */
594    nir_shader *shader = nir_shader_clone(NULL, gs);
595 
596    unsigned max_verts = output_vertex_id_stride(shader);
597 
598    /* Turn into a vertex shader run only for rasterization. Transform feedback
599     * was handled in the prepass.
600     */
601    shader->info.stage = MESA_SHADER_VERTEX;
602    shader->info.has_transform_feedback_varyings = false;
603    memset(&shader->info.vs, 0, sizeof(shader->info.vs));
604    shader->xfb_info = NULL;
605 
606    if (shader->info.name) {
607       shader->info.name = ralloc_asprintf(shader, "%s_rast", shader->info.name);
608    } else {
609       shader->info.name = "gs rast";
610    }
611 
612    nir_builder b_ =
613       nir_builder_at(nir_before_impl(nir_shader_get_entrypoint(shader)));
614    nir_builder *b = &b_;
615 
616    NIR_PASS(_, shader, strip_side_effects_from_rast, side_effects_for_rast);
617 
618    /* Optimize out pointless gl_PointSize outputs. Bizarrely, these occur. */
619    if (shader->info.gs.output_primitive != MESA_PRIM_POINTS)
620       shader->info.outputs_written &= ~VARYING_BIT_PSIZ;
621 
622    /* See calc_unrolled_index_id */
623    nir_def *raw_id = nir_load_vertex_id(b);
624    nir_def *output_id = nir_umod_imm(b, raw_id, max_verts);
625    nir_def *unrolled = nir_udiv_imm(b, raw_id, max_verts);
626 
627    nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
628    nir_def *instance_id = nir_ushr(b, unrolled, primitives_log2);
629    nir_def *primitive_id = nir_iand(
630       b, unrolled,
631       nir_iadd_imm(b, nir_ishl(b, nir_imm_int(b, 1), primitives_log2), -1));
632 
633    struct lower_gs_rast_state rast_state = {
634       .instance_id = instance_id,
635       .primitive_id = primitive_id,
636       .output_id = output_id,
637    };
638 
639    u_foreach_bit64(slot, shader->info.outputs_written) {
640       const char *slot_name =
641          gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
642 
643       bool scalar = (slot == VARYING_SLOT_PSIZ) ||
644                     (slot == VARYING_SLOT_LAYER) ||
645                     (slot == VARYING_SLOT_VIEWPORT);
646       unsigned comps = scalar ? 1 : 4;
647 
648       rast_state.outputs.outputs[slot] = nir_variable_create(
649          shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, comps),
650          ralloc_asprintf(shader, "%s-temp", slot_name));
651 
652       rast_state.selected.outputs[slot] = nir_variable_create(
653          shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, comps),
654          ralloc_asprintf(shader, "%s-selected", slot_name));
655    }
656 
657    nir_shader_intrinsics_pass(shader, lower_to_gs_rast,
658                               nir_metadata_control_flow, &rast_state);
659 
660    b->cursor = nir_after_impl(b->impl);
661 
662    /* Forward each selected output to the rasterizer */
663    u_foreach_bit64(slot, shader->info.outputs_written) {
664       assert(rast_state.selected.outputs[slot] != NULL);
665       nir_def *value = nir_load_var(b, rast_state.selected.outputs[slot]);
666 
667       /* We set NIR_COMPACT_ARRAYS so clip/cull distance needs to come all in
668        * DIST0. Undo the offset if we need to.
669        */
670       assert(slot != VARYING_SLOT_CULL_DIST1);
671       unsigned offset = 0;
672       if (slot == VARYING_SLOT_CLIP_DIST1)
673          offset = 1;
674 
675       nir_store_output(b, value, nir_imm_int(b, offset),
676                        .io_semantics.location = slot - offset,
677                        .io_semantics.num_slots = 1,
678                        .write_mask = nir_component_mask(value->num_components),
679                        .src_type = nir_type_uint32);
680    }
681 
682    /* It is legal to omit the point size write from the geometry shader when
683     * drawing points. In this case, the point size is implicitly 1.0. To
684     * implement, insert a synthetic `gl_PointSize = 1.0` write into the GS copy
685     * shader, if the GS does not export a point size while drawing points.
686     */
687    bool is_points = gs->info.gs.output_primitive == MESA_PRIM_POINTS;
688 
689    if (!(shader->info.outputs_written & VARYING_BIT_PSIZ) && is_points) {
690       nir_store_output(b, nir_imm_float(b, 1.0), nir_imm_int(b, 0),
691                        .io_semantics.location = VARYING_SLOT_PSIZ,
692                        .io_semantics.num_slots = 1,
693                        .write_mask = nir_component_mask(1),
694                        .src_type = nir_type_float32);
695 
696       shader->info.outputs_written |= VARYING_BIT_PSIZ;
697    }
698 
699    nir_opt_idiv_const(shader, 16);
700 
701    agx_preprocess_nir(shader, libagx);
702    return shader;
703 }
704 
705 static nir_def *
706 previous_count(nir_builder *b, struct lower_gs_state *state, unsigned stream,
707                nir_def *unrolled_id, enum gs_counter counter)
708 {
709    assert(stream < MAX_VERTEX_STREAMS);
710    assert(counter < GS_NUM_COUNTERS);
711    int static_count = state->static_count[counter][stream];
712 
713    if (static_count >= 0) {
714       /* If the number of outputted vertices per invocation is known statically,
715        * we can calculate the base.
716        */
717       return nir_imul_imm(b, unrolled_id, static_count);
718    } else {
719       /* Otherwise, we need to load from the prefix sum buffer. Note that the
720        * sums are inclusive, so index 0 is nonzero. This requires a little
721        * fixup here. We use a saturating unsigned subtraction so we don't read
722        * out-of-bounds for zero.
723        *
724        * TODO: Optimize this.
725        */
726       nir_def *prim_minus_1 = nir_usub_sat(b, unrolled_id, nir_imm_int(b, 1));
727       nir_def *addr =
728          load_count_address(b, state, prim_minus_1, stream, counter);
729 
730       return nir_bcsel(b, nir_ieq_imm(b, unrolled_id, 0), nir_imm_int(b, 0),
731                        nir_load_global_constant(b, addr, 4, 1, 32));
732    }
733 }
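/* Illustrative example: if three input primitives emit 3, 1 and 2 vertices,
 * the inclusive sums in the count buffer are 3, 4 and 6. The base for
 * unrolled_id = 2 is then sum[1] = 4, while unrolled_id = 0 takes the bcsel
 * path and returns 0.
 */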
734 
735 static nir_def *
736 previous_vertices(nir_builder *b, struct lower_gs_state *state, unsigned stream,
737                   nir_def *unrolled_id)
738 {
739    return previous_count(b, state, stream, unrolled_id, GS_COUNTER_VERTICES);
740 }
741 
742 static nir_def *
743 previous_primitives(nir_builder *b, struct lower_gs_state *state,
744                     unsigned stream, nir_def *unrolled_id)
745 {
746    return previous_count(b, state, stream, unrolled_id, GS_COUNTER_PRIMITIVES);
747 }
748 
749 static nir_def *
750 previous_xfb_primitives(nir_builder *b, struct lower_gs_state *state,
751                         unsigned stream, nir_def *unrolled_id)
752 {
753    return previous_count(b, state, stream, unrolled_id,
754                          GS_COUNTER_XFB_PRIMITIVES);
755 }
756 
757 static void
758 lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr,
759                     struct lower_gs_state *state)
760 {
761    assert((intr->intrinsic == nir_intrinsic_set_vertex_and_primitive_count ||
762            b->shader->info.gs.output_primitive != MESA_PRIM_POINTS) &&
763           "endprimitive for points should've been removed");
764 
765    /* The GS is the last stage before rasterization, so if we discard the
766     * rasterization, we don't output an index buffer; nothing will read it.
767     * The index buffer is only for the rasterization stream.
768     */
769    unsigned stream = nir_intrinsic_stream_id(intr);
770    if (state->rasterizer_discard || stream != 0)
771       return;
772 
773    libagx_end_primitive(
774       b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa,
775       intr->src[1].ssa, intr->src[2].ssa,
776       previous_vertices(b, state, 0, calc_unrolled_id(b)),
777       previous_primitives(b, state, 0, calc_unrolled_id(b)),
778       calc_unrolled_index_id(b),
779       nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS));
780 }
781 
782 static unsigned
783 verts_in_output_prim(nir_shader *gs)
784 {
785    return mesa_vertices_per_prim(gs->info.gs.output_primitive);
786 }
787 
788 static void
789 write_xfb(nir_builder *b, struct lower_gs_state *state, unsigned stream,
790           nir_def *index_in_strip, nir_def *prim_id_in_invocation)
791 {
792    struct nir_xfb_info *xfb = b->shader->xfb_info;
793    unsigned verts = verts_in_output_prim(b->shader);
794 
795    /* Get the index of this primitive in the XFB buffer. That is, the base for
796     * this invocation for the stream plus the offset within this invocation.
797     */
798    nir_def *invocation_base =
799       previous_xfb_primitives(b, state, stream, calc_unrolled_id(b));
800 
801    nir_def *prim_index = nir_iadd(b, invocation_base, prim_id_in_invocation);
802    nir_def *base_index = nir_imul_imm(b, prim_index, verts);
803 
804    nir_def *xfb_prims = load_geometry_param(b, xfb_prims[stream]);
805    nir_push_if(b, nir_ult(b, prim_index, xfb_prims));
806 
807    /* Write XFB for each output */
808    for (unsigned i = 0; i < xfb->output_count; ++i) {
809       nir_xfb_output_info output = xfb->outputs[i];
810 
811       /* Only write to the selected stream */
812       if (xfb->buffer_to_stream[output.buffer] != stream)
813          continue;
814 
815       unsigned buffer = output.buffer;
816       unsigned stride = xfb->buffers[buffer].stride;
817       unsigned count = util_bitcount(output.component_mask);
818 
819       for (unsigned vert = 0; vert < verts; ++vert) {
820          /* We write out the vertices backwards, since 0 is the current
821           * emitted vertex (which is actually the last vertex).
822           *
823           * We handle NULL var for
824           * KHR-Single-GL44.enhanced_layouts.xfb_capture_struct.
825           */
826          unsigned v = (verts - 1) - vert;
827          nir_variable *var = state->outputs[output.location][v];
828          nir_def *value = var ? nir_load_var(b, var) : nir_undef(b, 4, 32);
829 
830          /* In case output.component_mask contains invalid components, write
831           * out zeroes instead of blowing up validation.
832           *
833           * KHR-Single-GL44.enhanced_layouts.xfb_capture_inactive_output_component
834           * hits this.
835           */
836          value = nir_pad_vector_imm_int(b, value, 0, 4);
837 
838          nir_def *rotated_vert = nir_imm_int(b, vert);
839          if (verts == 3) {
840             /* Map vertices for output so we get consistent winding order. For
841              * the primitive index, we use the index_in_strip. This is actually
842              * the vertex index in the strip, hence
843              * offset by 2 relative to the true primitive index (#2 for the
844              * first triangle in the strip, #3 for the second). That's ok
845              * because only the parity matters.
846              */
847             rotated_vert = libagx_map_vertex_in_tri_strip(
848                b, index_in_strip, rotated_vert,
849                nir_inot(b, nir_i2b(b, nir_load_provoking_last(b))));
850          }
851 
852          nir_def *addr = libagx_xfb_vertex_address(
853             b, nir_load_geometry_param_buffer_agx(b), base_index, rotated_vert,
854             nir_imm_int(b, buffer), nir_imm_int(b, stride),
855             nir_imm_int(b, output.offset));
856 
857          nir_store_global(b, addr, 4,
858                           nir_channels(b, value, output.component_mask),
859                           nir_component_mask(count));
860       }
861    }
862 
863    nir_pop_if(b, NULL);
864 }
865 
866 /* Handle transform feedback for a given emit_vertex_with_counter */
867 static void
868 lower_emit_vertex_xfb(nir_builder *b, nir_intrinsic_instr *intr,
869                       struct lower_gs_state *state)
870 {
871    /* Transform feedback is written for each decomposed output primitive. Since
872     * we're writing strips, that means we output XFB for each vertex after the
873     * first complete primitive is formed.
874     */
875    unsigned first_prim = verts_in_output_prim(b->shader) - 1;
876    nir_def *index_in_strip = intr->src[1].ssa;
877 
878    nir_push_if(b, nir_uge_imm(b, index_in_strip, first_prim));
879    {
880       write_xfb(b, state, nir_intrinsic_stream_id(intr), index_in_strip,
881                 intr->src[3].ssa);
882    }
883    nir_pop_if(b, NULL);
884 
885    /* Transform feedback writes out entire primitives during the emit_vertex. To
886     * do that, we store the values at all vertices in the strip in a little ring
887     * buffer. Index #0 is always the most recent primitive (so non-XFB code can
888     * just grab index #0 without any checking). Index #1 is the previous vertex,
889     * and index #2 is the vertex before that. Now that we've written XFB, since
890     * we've emitted a vertex we need to cycle the ringbuffer, freeing up index
891     * #0 for the next vertex that we are about to emit. We do that by copying
892     * the first n - 1 vertices forward one slot, which has to happen with a
893     * backwards copy implemented here.
894     *
895     * If we're lucky, all of these copies will be propagated away. If we're
896     * unlucky, this involves at most 2 copies per component per XFB output per
897     * vertex.
898     */
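   /* Concretely, for triangles (3 vertices per output primitive) the loop
    * below copies outputs[slot][1] to outputs[slot][2] and then
    * outputs[slot][0] to outputs[slot][1], leaving index #0 free for the next
    * emitted vertex.
    */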
899    u_foreach_bit64(slot, b->shader->info.outputs_written) {
900       /* Note: if we're outputting points, verts_in_output_prim will be 1, so
901        * this loop will not execute. This is intended: points are self-contained
902        * primitives and do not need these copies.
903        */
904       for (int v = verts_in_output_prim(b->shader) - 1; v >= 1; --v) {
905          nir_def *value = nir_load_var(b, state->outputs[slot][v - 1]);
906 
907          nir_store_var(b, state->outputs[slot][v], value,
908                        nir_component_mask(value->num_components));
909       }
910    }
911 }
912 
913 static bool
914 lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
915 {
916    b->cursor = nir_before_instr(&intr->instr);
917 
918    switch (intr->intrinsic) {
919    case nir_intrinsic_set_vertex_and_primitive_count:
920       /* This instruction is mostly for the count shader, so just remove. But
921        * for points, we write the index buffer here so the rast shader can map.
922        */
923       if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) {
924          lower_end_primitive(b, intr, state);
925       }
926 
927       break;
928 
929    case nir_intrinsic_end_primitive_with_counter: {
930       unsigned min = verts_in_output_prim(b->shader);
931 
932       /* We only write out complete primitives */
933       nir_push_if(b, nir_uge_imm(b, intr->src[1].ssa, min));
934       {
935          lower_end_primitive(b, intr, state);
936       }
937       nir_pop_if(b, NULL);
938       break;
939    }
940 
941    case nir_intrinsic_emit_vertex_with_counter:
942       /* emit_vertex triggers transform feedback but is otherwise a no-op. */
943       if (b->shader->xfb_info)
944          lower_emit_vertex_xfb(b, intr, state);
945       break;
946 
947    default:
948       return false;
949    }
950 
951    nir_instr_remove(&intr->instr);
952    return true;
953 }
954 
955 static bool
956 collect_components(nir_builder *b, nir_intrinsic_instr *intr, void *data)
957 {
958    uint8_t *counts = data;
959    if (intr->intrinsic != nir_intrinsic_store_output)
960       return false;
961 
962    unsigned count = nir_intrinsic_component(intr) +
963                     util_last_bit(nir_intrinsic_write_mask(intr));
964 
965    unsigned loc =
966       nir_intrinsic_io_semantics(intr).location + nir_src_as_uint(intr->src[1]);
967 
968    uint8_t *total_count = &counts[loc];
969 
970    *total_count = MAX2(*total_count, count);
971    return true;
972 }
973 
974 /*
975  * Create the pre-GS shader. This is a small compute 1x1x1 kernel that produces
976  * an indirect draw to rasterize the produced geometry, as well as updates
977  * transform feedback offsets and counters as applicable.
978  */
979 static nir_shader *
980 agx_nir_create_pre_gs(struct lower_gs_state *state, const nir_shader *libagx,
981                       bool indexed, bool restart, struct nir_xfb_info *xfb,
982                       unsigned vertices_per_prim, uint8_t streams,
983                       unsigned invocations)
984 {
985    nir_builder b_ = nir_builder_init_simple_shader(
986       MESA_SHADER_COMPUTE, &agx_nir_options, "Pre-GS patch up");
987    nir_builder *b = &b_;
988 
989    /* Load the number of primitives input to the GS */
990    nir_def *unrolled_in_prims = load_geometry_param(b, input_primitives);
991 
992    /* Setup the draw from the rasterization stream (0). */
993    if (!state->rasterizer_discard) {
994       libagx_build_gs_draw(
995          b, nir_load_geometry_param_buffer_agx(b),
996          previous_vertices(b, state, 0, unrolled_in_prims),
997          restart ? previous_primitives(b, state, 0, unrolled_in_prims)
998                  : nir_imm_int(b, 0));
999    }
1000 
1001    /* Determine the number of primitives generated in each stream */
1002    nir_def *in_prims[MAX_VERTEX_STREAMS], *prims[MAX_VERTEX_STREAMS];
1003 
1004    u_foreach_bit(i, streams) {
1005       in_prims[i] = previous_xfb_primitives(b, state, i, unrolled_in_prims);
1006       prims[i] = in_prims[i];
1007 
1008       add_counter(b, load_geometry_param(b, prims_generated_counter[i]),
1009                   prims[i]);
1010    }
1011 
1012    if (xfb) {
1013       /* Write XFB addresses */
1014       nir_def *offsets[4] = {NULL};
1015       u_foreach_bit(i, xfb->buffers_written) {
1016          offsets[i] = libagx_setup_xfb_buffer(
1017             b, nir_load_geometry_param_buffer_agx(b), nir_imm_int(b, i));
1018       }
1019 
1020       /* Now clamp to the number that XFB captures */
1021       for (unsigned i = 0; i < xfb->output_count; ++i) {
1022          nir_xfb_output_info output = xfb->outputs[i];
1023 
1024          unsigned buffer = output.buffer;
1025          unsigned stream = xfb->buffer_to_stream[buffer];
1026          unsigned stride = xfb->buffers[buffer].stride;
1027          unsigned words_written = util_bitcount(output.component_mask);
1028          unsigned bytes_written = words_written * 4;
1029 
1030          /* Primitive P will write up to (but not including) offset:
1031           *
1032           *    xfb_offset + ((P - 1) * (verts_per_prim * stride))
1033           *               + ((verts_per_prim - 1) * stride)
1034           *               + output_offset
1035           *               + output_size
1036           *
1037           * Given an XFB buffer of size xfb_size, we get the inequality:
1038           *
1039           *    floor(P) <= (stride + xfb_size - xfb_offset - output_offset -
1040           *                     output_size) // (stride * verts_per_prim)
1041           */
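         /* Numeric sketch with hypothetical values: stride = 16 B,
          * verts_per_prim = 3, xfb_size = 1024, xfb_offset = 0,
          * output_offset = 4, output_size = 8 gives
          * max_prims = (16 + 1024 - 0 - 4 - 8) / (16 * 3) = 21.
          */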
1042          nir_def *size = load_geometry_param(b, xfb_size[buffer]);
1043          size = nir_iadd_imm(b, size, stride - output.offset - bytes_written);
1044          size = nir_isub(b, size, offsets[buffer]);
1045          size = nir_imax(b, size, nir_imm_int(b, 0));
1046          nir_def *max_prims = nir_udiv_imm(b, size, stride * vertices_per_prim);
1047 
1048          prims[stream] = nir_umin(b, prims[stream], max_prims);
1049       }
1050 
1051       nir_def *any_overflow = nir_imm_false(b);
1052 
1053       u_foreach_bit(i, streams) {
1054          nir_def *overflow = nir_ult(b, prims[i], in_prims[i]);
1055          any_overflow = nir_ior(b, any_overflow, overflow);
1056 
1057          store_geometry_param(b, xfb_prims[i], prims[i]);
1058 
1059          add_counter(b, load_geometry_param(b, xfb_overflow[i]),
1060                      nir_b2i32(b, overflow));
1061 
1062          add_counter(b, load_geometry_param(b, xfb_prims_generated_counter[i]),
1063                      prims[i]);
1064       }
1065 
1066       add_counter(b, load_geometry_param(b, xfb_any_overflow),
1067                   nir_b2i32(b, any_overflow));
1068 
1069       /* Update XFB counters */
1070       u_foreach_bit(i, xfb->buffers_written) {
1071          uint32_t prim_stride_B = xfb->buffers[i].stride * vertices_per_prim;
1072          unsigned stream = xfb->buffer_to_stream[i];
1073 
1074          nir_def *off_ptr = load_geometry_param(b, xfb_offs_ptrs[i]);
1075          nir_def *size = nir_imul_imm(b, prims[stream], prim_stride_B);
1076          add_counter(b, off_ptr, size);
1077       }
1078    }
1079 
1080    /* The geometry shader is invoked once per primitive (after unrolling
1081     * primitive restart). From the spec:
1082     *
1083     *    In case of instanced geometry shaders (see section 11.3.4.2) the
1084     *    geometry shader invocations count is incremented for each separate
1085     *    instanced invocation.
1086     */
1087    add_counter(b,
1088                nir_load_stat_query_address_agx(
1089                   b, .base = PIPE_STAT_QUERY_GS_INVOCATIONS),
1090                nir_imul_imm(b, unrolled_in_prims, invocations));
1091 
1092    nir_def *emitted_prims = nir_imm_int(b, 0);
1093    u_foreach_bit(i, streams) {
1094       emitted_prims =
1095          nir_iadd(b, emitted_prims,
1096                   previous_xfb_primitives(b, state, i, unrolled_in_prims));
1097    }
1098 
1099    add_counter(
1100       b,
1101       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_GS_PRIMITIVES),
1102       emitted_prims);
1103 
1104    /* Clipper queries are not well-defined, so we can emulate them in lots of
1105     * silly ways. We need the hardware counters to implement them properly. For
1106     * now, just consider all primitives emitted as passing through the clipper.
1107     * This satisfies spec text:
1108     *
1109     *    The number of primitives that reach the primitive clipping stage.
1110     *
1111     * and
1112     *
1113     *    If at least one vertex of the primitive lies inside the clipping
1114     *    volume, the counter is incremented by one or more. Otherwise, the
1115     *    counter is incremented by zero or more.
1116     */
1117    add_counter(
1118       b,
1119       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_PRIMITIVES),
1120       emitted_prims);
1121 
1122    add_counter(
1123       b,
1124       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_INVOCATIONS),
1125       emitted_prims);
1126 
1127    agx_preprocess_nir(b->shader, libagx);
1128    return b->shader;
1129 }
1130 
1131 static bool
1132 rewrite_invocation_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1133 {
1134    if (intr->intrinsic != nir_intrinsic_load_invocation_id)
1135       return false;
1136 
1137    b->cursor = nir_instr_remove(&intr->instr);
1138    nir_def_rewrite_uses(&intr->def, nir_u2uN(b, data, intr->def.bit_size));
1139    return true;
1140 }
1141 
1142 /*
1143  * Geometry shader instancing allows a GS to run multiple times. The number of
1144  * times is statically known and small. It's easiest to turn this into a loop
1145  * inside the GS, to avoid the feature "leaking" outside and affecting e.g. the
1146  * counts.
1147  */
1148 static void
1149 agx_nir_lower_gs_instancing(nir_shader *gs)
1150 {
1151    unsigned nr_invocations = gs->info.gs.invocations;
1152    nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1153 
1154    /* Each invocation can produce up to the shader-declared max_vertices, so
1155     * multiply it up for proper bounds check. Emitting more than the declared
1156     * max_vertices per invocation results in undefined behaviour, so erroneously
1157     * emitting more than asked on early invocations is a perfectly cromulent
1158     * behaviour.
1159     */
1160    gs->info.gs.vertices_out *= gs->info.gs.invocations;
1161 
1162    /* Get the original function */
1163    nir_cf_list list;
1164    nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1165 
1166    /* Create a builder for the wrapped function */
1167    nir_builder b = nir_builder_at(nir_after_block(nir_start_block(impl)));
1168 
1169    nir_variable *i =
1170       nir_local_variable_create(impl, glsl_uintN_t_type(16), NULL);
1171    nir_store_var(&b, i, nir_imm_intN_t(&b, 0, 16), ~0);
1172    nir_def *index = NULL;
1173 
1174    /* Create a loop in the wrapped function */
1175    nir_loop *loop = nir_push_loop(&b);
1176    {
1177       index = nir_load_var(&b, i);
1178       nir_push_if(&b, nir_uge_imm(&b, index, nr_invocations));
1179       {
1180          nir_jump(&b, nir_jump_break);
1181       }
1182       nir_pop_if(&b, NULL);
1183 
1184       b.cursor = nir_cf_reinsert(&list, b.cursor);
1185       nir_store_var(&b, i, nir_iadd_imm(&b, index, 1), ~0);
1186 
1187       /* Make sure we end the primitive between invocations. If the geometry
1188        * shader already ended the primitive, this will get optimized out.
1189        */
1190       nir_end_primitive(&b);
1191    }
1192    nir_pop_loop(&b, loop);
1193 
1194    /* We've mucked about with control flow */
1195    nir_metadata_preserve(impl, nir_metadata_none);
1196 
1197    /* Use the loop counter as the invocation ID each iteration */
1198    nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1199                               nir_metadata_control_flow, index);
1200 }
1201 
1202 static void
1203 link_libagx(nir_shader *nir, const nir_shader *libagx)
1204 {
1205    nir_link_shader_functions(nir, libagx);
1206    NIR_PASS(_, nir, nir_inline_functions);
1207    nir_remove_non_entrypoints(nir);
1208    NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
1209    NIR_PASS(_, nir, nir_opt_dce);
1210    NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
1211             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared,
1212             glsl_get_cl_type_size_align);
1213    NIR_PASS(_, nir, nir_opt_deref);
1214    NIR_PASS(_, nir, nir_lower_vars_to_ssa);
1215    NIR_PASS(_, nir, nir_lower_explicit_io,
1216             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
1217                nir_var_mem_global,
1218             nir_address_format_62bit_generic);
1219 }
1220 
1221 bool
1222 agx_nir_lower_gs(nir_shader *gs, const nir_shader *libagx,
1223                  bool rasterizer_discard, nir_shader **gs_count,
1224                  nir_shader **gs_copy, nir_shader **pre_gs,
1225                  enum mesa_prim *out_mode, unsigned *out_count_words)
1226 {
1227    /* Lower I/O as assumed by the rest of GS lowering */
1228    if (gs->xfb_info != NULL) {
1229       NIR_PASS(_, gs, nir_io_add_const_offset_to_base,
1230                nir_var_shader_in | nir_var_shader_out);
1231       NIR_PASS(_, gs, nir_io_add_intrinsic_xfb_info);
1232    }
1233 
1234    NIR_PASS(_, gs, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
1235 
1236    /* Collect output component counts so we can size the geometry output buffer
1237     * appropriately, instead of assuming everything is vec4.
1238     */
1239    uint8_t component_counts[NUM_TOTAL_VARYING_SLOTS] = {0};
1240    nir_shader_intrinsics_pass(gs, collect_components, nir_metadata_all,
1241                               component_counts);
1242 
1243    /* If geometry shader instancing is used, lower it away before linking
1244     * anything. Otherwise, smash the invocation ID to zero.
1245     */
1246    if (gs->info.gs.invocations != 1) {
1247       agx_nir_lower_gs_instancing(gs);
1248    } else {
1249       nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1250       nir_builder b = nir_builder_at(nir_before_impl(impl));
1251 
1252       nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1253                                  nir_metadata_control_flow, nir_imm_int(&b, 0));
1254    }
1255 
1256    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_inputs,
1257             nir_metadata_control_flow, NULL);
1258 
1259    /* Lower geometry shader writes to contain all of the required counts, so we
1260     * know where in the various buffers we should write vertices.
1261     */
1262    NIR_PASS(_, gs, nir_lower_gs_intrinsics,
1263             nir_lower_gs_intrinsics_count_primitives |
1264                nir_lower_gs_intrinsics_per_stream |
1265                nir_lower_gs_intrinsics_count_vertices_per_primitive |
1266                nir_lower_gs_intrinsics_overwrite_incomplete |
1267                nir_lower_gs_intrinsics_always_end_primitive |
1268                nir_lower_gs_intrinsics_count_decomposed_primitives);
1269 
1270    /* Clean up after all that lowering we did */
1271    bool progress = false;
1272    do {
1273       progress = false;
1274       NIR_PASS(progress, gs, nir_lower_var_copies);
1275       NIR_PASS(progress, gs, nir_lower_variable_initializers,
1276                nir_var_shader_temp);
1277       NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1278       NIR_PASS(progress, gs, nir_copy_prop);
1279       NIR_PASS(progress, gs, nir_opt_constant_folding);
1280       NIR_PASS(progress, gs, nir_opt_algebraic);
1281       NIR_PASS(progress, gs, nir_opt_cse);
1282       NIR_PASS(progress, gs, nir_opt_dead_cf);
1283       NIR_PASS(progress, gs, nir_opt_dce);
1284 
1285       /* Unrolling lets us statically determine counts more often, which
1286        * otherwise would not be possible with multiple invocations even in the
1287        * simplest of cases.
1288        */
1289       NIR_PASS(progress, gs, nir_opt_loop_unroll);
1290    } while (progress);
1291 
1292    /* If we know counts at compile-time we can simplify, so try to figure out
1293     * the counts statically.
1294     */
1295    struct lower_gs_state gs_state = {
1296       .rasterizer_discard = rasterizer_discard,
1297    };
1298 
1299    nir_gs_count_vertices_and_primitives(
1300       gs, gs_state.static_count[GS_COUNTER_VERTICES],
1301       gs_state.static_count[GS_COUNTER_PRIMITIVES],
1302       gs_state.static_count[GS_COUNTER_XFB_PRIMITIVES], 4);
1303 
1304    /* Anything we don't know statically will be tracked by the count buffer.
1305     * Determine the layout for it.
1306     */
1307    for (unsigned i = 0; i < MAX_VERTEX_STREAMS; ++i) {
1308       for (unsigned c = 0; c < GS_NUM_COUNTERS; ++c) {
1309          gs_state.count_index[i][c] =
1310             (gs_state.static_count[c][i] < 0) ? gs_state.count_stride_el++ : -1;
1311       }
1312    }
1313 
1314    bool side_effects_for_rast = false;
1315    *gs_copy = agx_nir_create_gs_rast_shader(gs, libagx, &side_effects_for_rast);
1316 
1317    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1318             nir_metadata_control_flow, NULL);
1319 
1320    link_libagx(gs, libagx);
1321 
1322    NIR_PASS(_, gs, nir_lower_idiv,
1323             &(const nir_lower_idiv_options){.allow_fp16 = true});
1324 
1325    /* All those variables we created should've gone away by now */
1326    NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1327 
1328    /* If there is any unknown count, we need a geometry count shader */
1329    if (gs_state.count_stride_el > 0)
1330       *gs_count = agx_nir_create_geometry_count_shader(gs, libagx, &gs_state);
1331    else
1332       *gs_count = NULL;
1333 
1334    /* Geometry shader outputs are staged to temporaries */
1335    struct agx_lower_output_to_var_state state = {0};
1336 
1337    u_foreach_bit64(slot, gs->info.outputs_written) {
1338       /* After enough optimizations, the shader metadata can go out of sync, fix
1339        * with our gathered info. Otherwise glsl_vector_type will assert fail.
1340        */
1341       if (component_counts[slot] == 0) {
1342          gs->info.outputs_written &= ~BITFIELD64_BIT(slot);
1343          continue;
1344       }
1345 
1346       const char *slot_name =
1347          gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
1348 
1349       for (unsigned i = 0; i < MAX_PRIM_OUT_SIZE; ++i) {
1350          gs_state.outputs[slot][i] = nir_variable_create(
1351             gs, nir_var_shader_temp,
1352             glsl_vector_type(GLSL_TYPE_UINT, component_counts[slot]),
1353             ralloc_asprintf(gs, "%s-%u", slot_name, i));
1354       }
1355 
1356       state.outputs[slot] = gs_state.outputs[slot][0];
1357    }
1358 
1359    NIR_PASS(_, gs, nir_shader_instructions_pass, agx_lower_output_to_var,
1360             nir_metadata_control_flow, &state);
1361 
1362    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_instr,
1363             nir_metadata_none, &gs_state);
1364 
1365    /* Determine if we are guaranteed to rasterize at least one vertex, so that
1366     * we can strip the prepass of side effects knowing they will execute in the
1367     * rasterization shader.
1368     */
1369    bool rasterizes_at_least_one_vertex =
1370       !rasterizer_discard && gs_state.static_count[0][0] > 0;
1371 
1372    /* Clean up after all that lowering we did */
1373    nir_lower_global_vars_to_local(gs);
1374    do {
1375       progress = false;
1376       NIR_PASS(progress, gs, nir_lower_var_copies);
1377       NIR_PASS(progress, gs, nir_lower_variable_initializers,
1378                nir_var_shader_temp);
1379       NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1380       NIR_PASS(progress, gs, nir_copy_prop);
1381       NIR_PASS(progress, gs, nir_opt_constant_folding);
1382       NIR_PASS(progress, gs, nir_opt_algebraic);
1383       NIR_PASS(progress, gs, nir_opt_cse);
1384       NIR_PASS(progress, gs, nir_opt_dead_cf);
1385       NIR_PASS(progress, gs, nir_opt_dce);
1386       NIR_PASS(progress, gs, nir_opt_loop_unroll);
1387 
1388    } while (progress);
1389 
1390    /* When rasterizing, we try to handle side effects sensibly. */
1391    if (rasterizes_at_least_one_vertex && side_effects_for_rast) {
1392       do {
1393          progress = false;
1394          NIR_PASS(progress, gs, nir_shader_intrinsics_pass,
1395                   strip_side_effect_from_main, nir_metadata_control_flow, NULL);
1396 
1397          NIR_PASS(progress, gs, nir_opt_dce);
1398          NIR_PASS(progress, gs, nir_opt_dead_cf);
1399       } while (progress);
1400    }
1401 
1402    /* All those variables we created should've gone away by now */
1403    NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1404 
1405    NIR_PASS(_, gs, nir_opt_sink, ~0);
1406    NIR_PASS(_, gs, nir_opt_move, ~0);
1407 
1408    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1409             nir_metadata_control_flow, NULL);
1410 
1411    /* Create auxiliary programs */
1412    *pre_gs = agx_nir_create_pre_gs(
1413       &gs_state, libagx, true, gs->info.gs.output_primitive != MESA_PRIM_POINTS,
1414       gs->xfb_info, verts_in_output_prim(gs), gs->info.gs.active_stream_mask,
1415       gs->info.gs.invocations);
1416 
1417    /* Signal what primitive we want to draw the GS Copy VS with */
1418    *out_mode = gs->info.gs.output_primitive;
1419    *out_count_words = gs_state.count_stride_el;
1420    return true;
1421 }
1422 
1423 /*
1424  * Vertex shaders (tessellation evaluation shaders) before a geometry shader run
1425  * as a dedicated compute prepass. They are invoked as (count, instances, 1).
1426  * Their linear ID is therefore (instances * num vertices) + vertex ID.
1427  *
1428  * This function lowers their vertex shader I/O to compute.
1429  *
1430  * Vertex ID becomes an index buffer pull (without applying the topology). Store
1431  * output becomes a store into the global vertex output buffer.
1432  */
1433 static bool
1434 lower_vs_before_gs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1435 {
1436    if (intr->intrinsic != nir_intrinsic_store_output)
1437       return false;
1438 
1439    b->cursor = nir_instr_remove(&intr->instr);
1440    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
1441    nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
1442 
1443    /* We inline the outputs_written because it's known at compile-time, even
1444     * with shader objects. This lets us constant fold a bit of address math.
1445     */
1446    nir_def *mask = nir_imm_int64(b, b->shader->info.outputs_written);
1447 
1448    nir_def *buffer;
1449    nir_def *nr_verts;
1450    if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1451       buffer = nir_load_vs_output_buffer_agx(b);
1452       nr_verts =
1453          libagx_input_vertices(b, nir_load_input_assembly_buffer_agx(b));
1454    } else {
1455       assert(b->shader->info.stage == MESA_SHADER_TESS_EVAL);
1456 
1457       /* Instancing is unrolled during tessellation so nr_verts is ignored. */
1458       nr_verts = nir_imm_int(b, 0);
1459       buffer = libagx_tes_buffer(b, nir_load_tess_param_buffer_agx(b));
1460    }
1461 
1462    nir_def *linear_id = nir_iadd(b, nir_imul(b, load_instance_id(b), nr_verts),
1463                                  load_primitive_id(b));
1464 
1465    nir_def *addr =
1466       libagx_vertex_output_address(b, buffer, mask, linear_id, location);
1467 
1468    assert(nir_src_bit_size(intr->src[0]) == 32);
1469    addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
1470 
1471    nir_store_global(b, addr, 4, intr->src[0].ssa,
1472                     nir_intrinsic_write_mask(intr));
1473    return true;
1474 }
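/* Illustrative numbers for the linear ID above: with 100 input vertices per
 * instance, instance 3 / vertex 7 stores its outputs at linear slot
 * 3 * 100 + 7 = 307 in the vertex output buffer.
 */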
1475 
1476 bool
1477 agx_nir_lower_vs_before_gs(struct nir_shader *vs,
1478                            const struct nir_shader *libagx)
1479 {
1480    bool progress = false;
1481 
1482    /* Lower vertex stores to memory stores */
1483    progress |= nir_shader_intrinsics_pass(vs, lower_vs_before_gs,
1484                                           nir_metadata_control_flow, NULL);
1485 
1486    /* Link libagx, used in lower_vs_before_gs */
1487    if (progress)
1488       link_libagx(vs, libagx);
1489 
1490    return progress;
1491 }
1492