1 /*
2 * Copyright 2023 Alyssa Rosenzweig
3 * Copyright 2023 Valve Corporation
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "agx_nir_lower_gs.h"
8 #include "asahi/compiler/agx_compile.h"
9 #include "compiler/nir/nir_builder.h"
10 #include "gallium/include/pipe/p_defines.h"
11 #include "shaders/geometry.h"
12 #include "util/bitscan.h"
13 #include "util/list.h"
14 #include "util/macros.h"
15 #include "util/ralloc.h"
16 #include "util/u_math.h"
17 #include "libagx_shaders.h"
18 #include "nir.h"
19 #include "nir_builder_opcodes.h"
20 #include "nir_intrinsics.h"
21 #include "nir_intrinsics_indices.h"
22 #include "nir_xfb_info.h"
23 #include "shader_enums.h"
24
/* Marks a transform feedback store, which must not be stripped from the
 * prepass since that's where transform feedback happens. Chosen as a vendored
 * flag so that it does not alias other flags we'll see.
 */
29 #define ACCESS_XFB (ACCESS_IS_SWIZZLED_AMD)
30
31 enum gs_counter {
32 GS_COUNTER_VERTICES = 0,
33 GS_COUNTER_PRIMITIVES,
34 GS_COUNTER_XFB_PRIMITIVES,
35 GS_NUM_COUNTERS
36 };
37
38 #define MAX_PRIM_OUT_SIZE 3
39
40 struct lower_gs_state {
41 int static_count[GS_NUM_COUNTERS][MAX_VERTEX_STREAMS];
42 nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE];
43
44 /* The count buffer contains `count_stride_el` 32-bit words in a row for each
45 * input primitive, for `input_primitives * count_stride_el * 4` total bytes.
46 */
47 unsigned count_stride_el;
48
49 /* The index of each counter in the count buffer, or -1 if it's not in the
50 * count buffer.
51 *
52 * Invariant: count_stride_el == sum(count_index[i][j] >= 0).
53 */
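   /* For example, if only stream 0 has dynamic vertex and primitive counts,
    * count_stride_el is 2, count_index[0] is {0, 1, -1}, every other entry is
    * -1, and input primitive P's records occupy 32-bit words 2*P and 2*P + 1
    * of the count buffer.
    */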
54 int count_index[MAX_VERTEX_STREAMS][GS_NUM_COUNTERS];
55
56 bool rasterizer_discard;
57 };
58
59 /* Helpers for loading from the geometry state buffer */
60 static nir_def *
load_geometry_param_offset(nir_builder *b, uint32_t offset, uint8_t bytes)
62 {
63 nir_def *base = nir_load_geometry_param_buffer_agx(b);
64 nir_def *addr = nir_iadd_imm(b, base, offset);
65
66 assert((offset % bytes) == 0 && "must be naturally aligned");
67
68 return nir_load_global_constant(b, addr, bytes, 1, bytes * 8);
69 }
70
71 static void
store_geometry_param_offset(nir_builder *b, nir_def *def, uint32_t offset,
                            uint8_t bytes)
74 {
75 nir_def *base = nir_load_geometry_param_buffer_agx(b);
76 nir_def *addr = nir_iadd_imm(b, base, offset);
77
78 assert((offset % bytes) == 0 && "must be naturally aligned");
79
80 nir_store_global(b, addr, 4, def, nir_component_mask(def->num_components));
81 }
82
83 #define store_geometry_param(b, field, def) \
84 store_geometry_param_offset( \
85 b, def, offsetof(struct agx_geometry_params, field), \
86 sizeof(((struct agx_geometry_params *)0)->field))
87
88 #define load_geometry_param(b, field) \
89 load_geometry_param_offset( \
90 b, offsetof(struct agx_geometry_params, field), \
91 sizeof(((struct agx_geometry_params *)0)->field))
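
/* For example, load_geometry_param(b, input_vertices) loads the
 * agx_geometry_params::input_vertices field at its natural offset and size,
 * and store_geometry_param(b, xfb_prims[i], def) writes a computed value back
 * to the corresponding field.
 */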
92
93 /* Helper for updating counters */
94 static void
add_counter(nir_builder *b, nir_def *counter, nir_def *increment)
96 {
97 /* If the counter is NULL, the counter is disabled. Skip the update. */
98 nir_if *nif = nir_push_if(b, nir_ine_imm(b, counter, 0));
99 {
100 nir_def *old = nir_load_global(b, counter, 4, 1, 32);
101 nir_def *new_ = nir_iadd(b, old, increment);
102 nir_store_global(b, counter, 4, new_, nir_component_mask(1));
103 }
104 nir_pop_if(b, nif);
105 }
106
107 /* Helpers for lowering I/O to variables */
108 static void
lower_store_to_var(nir_builder *b, nir_intrinsic_instr *intr,
                   struct agx_lower_output_to_var_state *state)
111 {
112 b->cursor = nir_instr_remove(&intr->instr);
113 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
114 unsigned component = nir_intrinsic_component(intr);
115 nir_def *value = intr->src[0].ssa;
116
117 assert(nir_src_is_const(intr->src[1]) && "no indirect outputs");
118 assert(nir_intrinsic_write_mask(intr) == nir_component_mask(1) &&
119 "should be scalarized");
120
121 nir_variable *var =
122 state->outputs[sem.location + nir_src_as_uint(intr->src[1])];
123 if (!var) {
124 assert(sem.location == VARYING_SLOT_PSIZ &&
125 "otherwise in outputs_written");
126 return;
127 }
128
129 unsigned nr_components = glsl_get_components(glsl_without_array(var->type));
130 assert(component < nr_components);
131
132 /* Turn it into a vec4 write like NIR expects */
133 value = nir_vector_insert_imm(b, nir_undef(b, nr_components, 32), value,
134 component);
135
136 nir_store_var(b, var, value, BITFIELD_BIT(component));
137 }
138
139 bool
agx_lower_output_to_var(nir_builder *b, nir_instr *instr, void *data)
141 {
142 if (instr->type != nir_instr_type_intrinsic)
143 return false;
144
145 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
146 if (intr->intrinsic != nir_intrinsic_store_output)
147 return false;
148
149 lower_store_to_var(b, intr, data);
150 return true;
151 }
152
153 /*
154 * Geometry shader invocations are compute-like:
155 *
156 * (primitive ID, instance ID, 1)
157 */
158 static nir_def *
load_primitive_id(nir_builder *b)
160 {
161 return nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
162 }
163
164 static nir_def *
load_instance_id(nir_builder *b)
166 {
167 return nir_channel(b, nir_load_global_invocation_id(b, 32), 1);
168 }
169
170 static bool
lower_gs_inputs(nir_builder *b, nir_intrinsic_instr *intr, void *_)
172 {
173 if (intr->intrinsic != nir_intrinsic_load_per_vertex_input)
174 return false;
175
176 b->cursor = nir_instr_remove(&intr->instr);
177 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
178
179 nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
180
181 /* Calculate the vertex ID we're pulling, based on the topology class */
182 nir_def *vert_in_prim = intr->src[0].ssa;
183 nir_def *vertex = agx_vertex_id_for_topology_class(
184 b, vert_in_prim, b->shader->info.gs.input_primitive);
185
   /* The unrolled vertex ID is based on input_vertices, which differs from
    * what our load_num_vertices returns (vertices vs. primitives).
    */
189 nir_def *unrolled =
190 nir_iadd(b,
191 nir_imul(b, nir_load_instance_id(b),
192 load_geometry_param(b, input_vertices)),
193 vertex);
194
195 /* Calculate the address of the input given the unrolled vertex ID */
196 nir_def *addr = libagx_vertex_output_address(
197 b, nir_load_geometry_param_buffer_agx(b), unrolled, location,
198 load_geometry_param(b, vs_outputs));
199
200 assert(intr->def.bit_size == 32);
201 addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
202
203 nir_def *val = nir_load_global_constant(b, addr, 4, intr->def.num_components,
204 intr->def.bit_size);
205 nir_def_rewrite_uses(&intr->def, val);
206 return true;
207 }
208
209 /*
210 * Unrolled ID is the index of the primitive in the count buffer, given as
211 * (instance ID * # vertices/instance) + vertex ID
212 */
213 static nir_def *
calc_unrolled_id(nir_builder *b)
215 {
216 return nir_iadd(b,
217 nir_imul(b, load_instance_id(b), nir_load_num_vertices(b)),
218 load_primitive_id(b));
219 }
220
221 static unsigned
output_vertex_id_stride(nir_shader *gs)
223 {
224 /* round up to power of two for cheap multiply/division */
225 return util_next_power_of_two(MAX2(gs->info.gs.vertices_out, 1));
226 }
227
228 /* Variant of calc_unrolled_id that uses a power-of-two stride for indices. This
229 * is sparser (acceptable for index buffer values, not for count buffer
230 * indices). It has the nice property of being cheap to invert, unlike
231 * calc_unrolled_id. So, we use calc_unrolled_id for count buffers and
232 * calc_unrolled_index_id for index values.
233 *
234 * This also multiplies by the appropriate stride to calculate the final index
235 * base value.
236 */
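/* For example, with vertices_out = 6 the stride rounds up to 8, so invocation
 * (instance i, primitive p) produces index values starting at
 * ((i << primitives_log2) + p) * 8, which the rasterization shader inverts
 * with a division/modulo by 8 followed by a shift and mask (see
 * agx_nir_create_gs_rast_shader).
 */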
237 static nir_def *
calc_unrolled_index_id(nir_builder *b)
239 {
240 unsigned vertex_stride = output_vertex_id_stride(b->shader);
241 nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
242
243 nir_def *instance = nir_ishl(b, load_instance_id(b), primitives_log2);
244 nir_def *prim = nir_iadd(b, instance, load_primitive_id(b));
245
246 return nir_imul_imm(b, prim, vertex_stride);
247 }
248
249 static nir_def *
load_count_address(nir_builder *b, struct lower_gs_state *state,
                   nir_def *unrolled_id, unsigned stream,
                   enum gs_counter counter)
253 {
254 int index = state->count_index[stream][counter];
255 if (index < 0)
256 return NULL;
257
258 nir_def *prim_offset_el =
259 nir_imul_imm(b, unrolled_id, state->count_stride_el);
260
261 nir_def *offset_el = nir_iadd_imm(b, prim_offset_el, index);
262
263 return nir_iadd(b, load_geometry_param(b, count_buffer),
264 nir_u2u64(b, nir_imul_imm(b, offset_el, 4)));
265 }
266
267 static void
write_counts(nir_builder *b, nir_intrinsic_instr *intr,
             struct lower_gs_state *state)
270 {
271 /* Store each required counter */
272 nir_def *counts[GS_NUM_COUNTERS] = {
273 [GS_COUNTER_VERTICES] = intr->src[0].ssa,
274 [GS_COUNTER_PRIMITIVES] = intr->src[1].ssa,
275 [GS_COUNTER_XFB_PRIMITIVES] = intr->src[2].ssa,
276 };
277
278 for (unsigned i = 0; i < GS_NUM_COUNTERS; ++i) {
279 nir_def *addr = load_count_address(b, state, calc_unrolled_id(b),
280 nir_intrinsic_stream_id(intr), i);
281
282 if (addr)
283 nir_store_global(b, addr, 4, counts[i], nir_component_mask(1));
284 }
285 }
286
287 static bool
lower_gs_count_instr(nir_builder *b, nir_intrinsic_instr *intr, void *data)
289 {
290 switch (intr->intrinsic) {
291 case nir_intrinsic_emit_vertex_with_counter:
292 case nir_intrinsic_end_primitive_with_counter:
293 case nir_intrinsic_store_output:
294 /* These are for the main shader, just remove them */
295 nir_instr_remove(&intr->instr);
296 return true;
297
298 case nir_intrinsic_set_vertex_and_primitive_count:
299 b->cursor = nir_instr_remove(&intr->instr);
300 write_counts(b, intr, data);
301 return true;
302
303 default:
304 return false;
305 }
306 }
307
308 static bool
lower_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
310 {
311 b->cursor = nir_before_instr(&intr->instr);
312
313 nir_def *id;
314 if (intr->intrinsic == nir_intrinsic_load_primitive_id)
315 id = load_primitive_id(b);
316 else if (intr->intrinsic == nir_intrinsic_load_instance_id)
317 id = load_instance_id(b);
318 else if (intr->intrinsic == nir_intrinsic_load_num_vertices)
319 id = nir_channel(b, nir_load_num_workgroups(b), 0);
320 else if (intr->intrinsic == nir_intrinsic_load_flat_mask)
321 id = load_geometry_param(b, flat_outputs);
322 else if (intr->intrinsic == nir_intrinsic_load_input_topology_agx)
323 id = load_geometry_param(b, input_topology);
324 else if (intr->intrinsic == nir_intrinsic_load_provoking_last) {
325 id = nir_b2b32(
326 b, libagx_is_provoking_last(b, nir_load_input_assembly_buffer_agx(b)));
327 } else
328 return false;
329
330 b->cursor = nir_instr_remove(&intr->instr);
331 nir_def_rewrite_uses(&intr->def, id);
332 return true;
333 }
334
/*
 * Create a "Geometry count" shader. This is a stripped down geometry shader
 * that just writes its number of emitted vertices / primitives / transform
 * feedback primitives to a count buffer. That count buffer will be prefix
 * summed prior to running the real geometry shader. This is skipped if the
 * counts are statically known.
 */
342 static nir_shader *
agx_nir_create_geometry_count_shader(nir_shader *gs, const nir_shader *libagx,
                                     struct lower_gs_state *state)
345 {
346 /* Don't muck up the original shader */
347 nir_shader *shader = nir_shader_clone(NULL, gs);
348
349 if (shader->info.name) {
350 shader->info.name =
351 ralloc_asprintf(shader, "%s_count", shader->info.name);
352 } else {
353 shader->info.name = "count";
354 }
355
356 NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_gs_count_instr,
357 nir_metadata_block_index | nir_metadata_dominance, state);
358
359 NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_id,
360 nir_metadata_block_index | nir_metadata_dominance, NULL);
361
362 /* Preprocess it */
363 UNUSED struct agx_uncompiled_shader_info info;
364 agx_preprocess_nir(shader, libagx, false, &info);
365
366 return shader;
367 }
368
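/* State for building the rasterization shader: the IDs decoded from the raw
 * vertex ID, plus two sets of output temporaries. `outputs` is written by
 * every store_output, while `selected` latches those values only when the
 * emitted vertex ID matches the vertex currently being shaded (see
 * select_rast_output).
 */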
369 struct lower_gs_rast_state {
370 nir_def *instance_id, *primitive_id, *output_id;
371 struct agx_lower_output_to_var_state outputs;
372 struct agx_lower_output_to_var_state selected;
373 };
374
375 static void
select_rast_output(nir_builder *b, nir_intrinsic_instr *intr,
                   struct lower_gs_rast_state *state)
378 {
379 b->cursor = nir_instr_remove(&intr->instr);
380
381 /* We only care about the rasterization stream in the rasterization
382 * shader, so just ignore emits from other streams.
383 */
384 if (nir_intrinsic_stream_id(intr) != 0)
385 return;
386
387 u_foreach_bit64(slot, b->shader->info.outputs_written) {
388 nir_def *orig = nir_load_var(b, state->selected.outputs[slot]);
389 nir_def *data = nir_load_var(b, state->outputs.outputs[slot]);
390
391 nir_def *value = nir_bcsel(
392 b, nir_ieq(b, intr->src[0].ssa, state->output_id), data, orig);
393
394 nir_store_var(b, state->selected.outputs[slot], value,
395 nir_component_mask(value->num_components));
396 }
397 }
398
399 static bool
lower_to_gs_rast(nir_builder *b, nir_intrinsic_instr *intr, void *data)
401 {
402 struct lower_gs_rast_state *state = data;
403
404 switch (intr->intrinsic) {
405 case nir_intrinsic_store_output:
406 lower_store_to_var(b, intr, &state->outputs);
407 return true;
408
409 case nir_intrinsic_emit_vertex_with_counter:
410 select_rast_output(b, intr, state);
411 return true;
412
413 case nir_intrinsic_load_primitive_id:
414 nir_def_rewrite_uses(&intr->def, state->primitive_id);
415 return true;
416
417 case nir_intrinsic_load_instance_id:
418 nir_def_rewrite_uses(&intr->def, state->instance_id);
419 return true;
420
421 case nir_intrinsic_load_num_vertices: {
422 b->cursor = nir_before_instr(&intr->instr);
423 nir_def_rewrite_uses(&intr->def, load_geometry_param(b, gs_grid[0]));
424 return true;
425 }
426
427 case nir_intrinsic_load_flat_mask:
428 case nir_intrinsic_load_provoking_last:
429 case nir_intrinsic_load_input_topology_agx:
430 /* Lowering the same in both GS variants */
431 return lower_id(b, intr, data);
432
433 case nir_intrinsic_end_primitive_with_counter:
434 case nir_intrinsic_set_vertex_and_primitive_count:
435 nir_instr_remove(&intr->instr);
436 return true;
437
438 default:
439 return false;
440 }
441 }
442
443 /*
444 * Create a GS rasterization shader. This is a hardware vertex shader that
445 * shades each rasterized output vertex in parallel.
446 */
447 static nir_shader *
agx_nir_create_gs_rast_shader(const nir_shader *gs, const nir_shader *libagx)
449 {
450 /* Don't muck up the original shader */
451 nir_shader *shader = nir_shader_clone(NULL, gs);
452
453 unsigned max_verts = output_vertex_id_stride(shader);
454
455 /* Turn into a vertex shader run only for rasterization. Transform feedback
456 * was handled in the prepass.
457 */
458 shader->info.stage = MESA_SHADER_VERTEX;
459 shader->info.has_transform_feedback_varyings = false;
460 memset(&shader->info.vs, 0, sizeof(shader->info.vs));
461 shader->xfb_info = NULL;
462
463 if (shader->info.name) {
464 shader->info.name = ralloc_asprintf(shader, "%s_rast", shader->info.name);
465 } else {
466 shader->info.name = "gs rast";
467 }
468
469 nir_builder b_ =
470 nir_builder_at(nir_before_impl(nir_shader_get_entrypoint(shader)));
471 nir_builder *b = &b_;
472
473 /* Optimize out pointless gl_PointSize outputs. Bizarrely, these occur. */
474 if (shader->info.gs.output_primitive != MESA_PRIM_POINTS)
475 shader->info.outputs_written &= ~VARYING_BIT_PSIZ;
476
477 /* See calc_unrolled_index_id */
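   /* The index value is ((instance << primitives_log2) + primitive) *
    * max_verts plus the output vertex within the invocation, so a
    * power-of-two divide/modulo followed by a shift and mask recovers all
    * three.
    */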
478 nir_def *raw_id = nir_load_vertex_id(b);
479 nir_def *output_id = nir_umod_imm(b, raw_id, max_verts);
480 nir_def *unrolled = nir_udiv_imm(b, raw_id, max_verts);
481
482 nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
483 nir_def *instance_id = nir_ushr(b, unrolled, primitives_log2);
484 nir_def *primitive_id = nir_iand(
485 b, unrolled,
486 nir_iadd_imm(b, nir_ishl(b, nir_imm_int(b, 1), primitives_log2), -1));
487
488 struct lower_gs_rast_state rast_state = {
489 .instance_id = instance_id,
490 .primitive_id = primitive_id,
491 .output_id = output_id,
492 };
493
494 u_foreach_bit64(slot, shader->info.outputs_written) {
495 const char *slot_name =
496 gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
497
498 rast_state.outputs.outputs[slot] = nir_variable_create(
499 shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, 4),
500 ralloc_asprintf(shader, "%s-temp", slot_name));
501
502 rast_state.selected.outputs[slot] = nir_variable_create(
503 shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, 4),
504 ralloc_asprintf(shader, "%s-selected", slot_name));
505 }
506
507 nir_shader_intrinsics_pass(shader, lower_to_gs_rast,
508 nir_metadata_block_index | nir_metadata_dominance,
509 &rast_state);
510
511 b->cursor = nir_after_impl(b->impl);
512
513 /* Forward each selected output to the rasterizer */
514 u_foreach_bit64(slot, shader->info.outputs_written) {
515 assert(rast_state.selected.outputs[slot] != NULL);
516 nir_def *value = nir_load_var(b, rast_state.selected.outputs[slot]);
517
      /* We set NIR_COMPACT_ARRAYS, so clip/cull distances must all land in
       * DIST0. Undo the slot offset if needed.
       */
521 unsigned offset = 0;
522 if (slot == VARYING_SLOT_CULL_DIST1 || slot == VARYING_SLOT_CLIP_DIST1)
523 offset = 1;
524
525 nir_store_output(b, value, nir_imm_int(b, offset),
526 .io_semantics.location = slot - offset,
527 .io_semantics.num_slots = 1,
528 .write_mask = nir_component_mask(value->num_components));
529 }
530
531 /* In OpenGL ES, it is legal to omit the point size write from the geometry
532 * shader when drawing points. In this case, the point size is
533 * implicitly 1.0. We implement this by inserting this synthetic
534 * `gl_PointSize = 1.0` write into the GS copy shader, if the GS does not
535 * export a point size while drawing points.
536 *
537 * This should not be load bearing for other APIs, but should be harmless.
538 */
539 bool is_points = gs->info.gs.output_primitive == MESA_PRIM_POINTS;
540
541 if (!(shader->info.outputs_written & VARYING_BIT_PSIZ) && is_points) {
542 nir_store_output(b, nir_imm_float(b, 1.0), nir_imm_int(b, 0),
543 .io_semantics.location = VARYING_SLOT_PSIZ,
544 .io_semantics.num_slots = 1,
545 .write_mask = nir_component_mask(1));
546
547 shader->info.outputs_written |= VARYING_BIT_PSIZ;
548 }
549
550 nir_opt_idiv_const(shader, 16);
551
552 /* Preprocess it */
553 UNUSED struct agx_uncompiled_shader_info info;
554 agx_preprocess_nir(shader, libagx, false, &info);
555
556 return shader;
557 }
558
559 static nir_def *
previous_count(nir_builder *b, struct lower_gs_state *state, unsigned stream,
               nir_def *unrolled_id, enum gs_counter counter)
562 {
563 assert(stream < MAX_VERTEX_STREAMS);
564 assert(counter < GS_NUM_COUNTERS);
565 int static_count = state->static_count[counter][stream];
566
567 if (static_count >= 0) {
      /* If the count per invocation is known statically, we can calculate the
       * base directly.
       */
571 return nir_imul_imm(b, unrolled_id, static_count);
572 } else {
573 /* Otherwise, we need to load from the prefix sum buffer. Note that the
574 * sums are inclusive, so index 0 is nonzero. This requires a little
575 * fixup here. We use a saturating unsigned subtraction so we don't read
576 * out-of-bounds for zero.
577 *
578 * TODO: Optimize this.
579 */
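      /* For example, if unrolled primitives 0..2 produce counts 3, 1 and 2,
       * the summed buffer holds {3, 4, 6}: the base for unrolled_id 2 is
       * element 1 = 4, and the base for unrolled_id 0 is simply 0.
       */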
580 nir_def *prim_minus_1 = nir_usub_sat(b, unrolled_id, nir_imm_int(b, 1));
581 nir_def *addr =
582 load_count_address(b, state, prim_minus_1, stream, counter);
583
584 return nir_bcsel(b, nir_ieq_imm(b, unrolled_id, 0), nir_imm_int(b, 0),
585 nir_load_global_constant(b, addr, 4, 1, 32));
586 }
587 }
588
589 static nir_def *
previous_vertices(nir_builder *b, struct lower_gs_state *state, unsigned stream,
                  nir_def *unrolled_id)
592 {
593 return previous_count(b, state, stream, unrolled_id, GS_COUNTER_VERTICES);
594 }
595
596 static nir_def *
previous_primitives(nir_builder *b, struct lower_gs_state *state,
                    unsigned stream, nir_def *unrolled_id)
599 {
600 return previous_count(b, state, stream, unrolled_id, GS_COUNTER_PRIMITIVES);
601 }
602
603 static nir_def *
previous_xfb_primitives(nir_builder *b, struct lower_gs_state *state,
                        unsigned stream, nir_def *unrolled_id)
606 {
607 return previous_count(b, state, stream, unrolled_id,
608 GS_COUNTER_XFB_PRIMITIVES);
609 }
610
611 static void
lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr,
                    struct lower_gs_state *state)
614 {
615 assert((intr->intrinsic == nir_intrinsic_set_vertex_and_primitive_count ||
616 b->shader->info.gs.output_primitive != MESA_PRIM_POINTS) &&
617 "endprimitive for points should've been removed");
618
   /* The GS is the last stage before rasterization, so if rasterization is
    * discarded, we don't output an index buffer; nothing would read it. The
    * index buffer is only for the rasterization stream.
    */
623 unsigned stream = nir_intrinsic_stream_id(intr);
624 if (state->rasterizer_discard || stream != 0)
625 return;
626
627 libagx_end_primitive(
628 b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa,
629 intr->src[1].ssa, intr->src[2].ssa,
630 previous_vertices(b, state, 0, calc_unrolled_id(b)),
631 previous_primitives(b, state, 0, calc_unrolled_id(b)),
632 calc_unrolled_index_id(b),
633 nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS));
634 }
635
636 static unsigned
verts_in_output_prim(nir_shader *gs)
638 {
639 return mesa_vertices_per_prim(gs->info.gs.output_primitive);
640 }
641
642 static void
write_xfb(nir_builder *b, struct lower_gs_state *state, unsigned stream,
          nir_def *index_in_strip, nir_def *prim_id_in_invocation)
645 {
646 struct nir_xfb_info *xfb = b->shader->xfb_info;
647 unsigned verts = verts_in_output_prim(b->shader);
648
649 /* Get the index of this primitive in the XFB buffer. That is, the base for
650 * this invocation for the stream plus the offset within this invocation.
651 */
652 nir_def *invocation_base =
653 previous_xfb_primitives(b, state, stream, calc_unrolled_id(b));
654
655 nir_def *prim_index = nir_iadd(b, invocation_base, prim_id_in_invocation);
656 nir_def *base_index = nir_imul_imm(b, prim_index, verts);
657
658 nir_def *xfb_prims = load_geometry_param(b, xfb_prims[stream]);
659 nir_push_if(b, nir_ult(b, prim_index, xfb_prims));
660
661 /* Write XFB for each output */
662 for (unsigned i = 0; i < xfb->output_count; ++i) {
663 nir_xfb_output_info output = xfb->outputs[i];
664
665 /* Only write to the selected stream */
666 if (xfb->buffer_to_stream[output.buffer] != stream)
667 continue;
668
669 unsigned buffer = output.buffer;
670 unsigned stride = xfb->buffers[buffer].stride;
671 unsigned count = util_bitcount(output.component_mask);
672
673 for (unsigned vert = 0; vert < verts; ++vert) {
674 /* We write out the vertices backwards, since 0 is the current
675 * emitted vertex (which is actually the last vertex).
676 *
677 * We handle NULL var for
678 * KHR-Single-GL44.enhanced_layouts.xfb_capture_struct.
679 */
680 unsigned v = (verts - 1) - vert;
681 nir_variable *var = state->outputs[output.location][v];
682 nir_def *value = var ? nir_load_var(b, var) : nir_undef(b, 4, 32);
683
684 /* In case output.component_mask contains invalid components, write
685 * out zeroes instead of blowing up validation.
686 *
687 * KHR-Single-GL44.enhanced_layouts.xfb_capture_inactive_output_component
688 * hits this.
689 */
690 value = nir_pad_vector_imm_int(b, value, 0, 4);
691
692 nir_def *rotated_vert = nir_imm_int(b, vert);
693 if (verts == 3) {
694 /* Map vertices for output so we get consistent winding order. For
695 * the primitive index, we use the index_in_strip. This is actually
696 * the vertex index in the strip, hence
697 * offset by 2 relative to the true primitive index (#2 for the
698 * first triangle in the strip, #3 for the second). That's ok
699 * because only the parity matters.
700 */
701 rotated_vert = libagx_map_vertex_in_tri_strip(
702 b, index_in_strip, rotated_vert,
703 nir_inot(b, nir_i2b(b, nir_load_provoking_last(b))));
704 }
705
706 nir_def *addr = libagx_xfb_vertex_address(
707 b, nir_load_geometry_param_buffer_agx(b), base_index, rotated_vert,
708 nir_imm_int(b, buffer), nir_imm_int(b, stride),
709 nir_imm_int(b, output.offset));
710
711 nir_build_store_global(
712 b, nir_channels(b, value, output.component_mask), addr,
713 .align_mul = 4, .write_mask = nir_component_mask(count),
714 .access = ACCESS_XFB);
715 }
716 }
717
718 nir_pop_if(b, NULL);
719 }
720
721 /* Handle transform feedback for a given emit_vertex_with_counter */
722 static void
lower_emit_vertex_xfb(nir_builder *b, nir_intrinsic_instr *intr,
                      struct lower_gs_state *state)
725 {
726 /* Transform feedback is written for each decomposed output primitive. Since
727 * we're writing strips, that means we output XFB for each vertex after the
728 * first complete primitive is formed.
729 */
730 unsigned first_prim = verts_in_output_prim(b->shader) - 1;
731 nir_def *index_in_strip = intr->src[1].ssa;
732
733 nir_push_if(b, nir_uge_imm(b, index_in_strip, first_prim));
734 {
735 write_xfb(b, state, nir_intrinsic_stream_id(intr), index_in_strip,
736 intr->src[3].ssa);
737 }
738 nir_pop_if(b, NULL);
739
740 /* Transform feedback writes out entire primitives during the emit_vertex. To
741 * do that, we store the values at all vertices in the strip in a little ring
742 * buffer. Index #0 is always the most recent primitive (so non-XFB code can
743 * just grab index #0 without any checking). Index #1 is the previous vertex,
744 * and index #2 is the vertex before that. Now that we've written XFB, since
745 * we've emitted a vertex we need to cycle the ringbuffer, freeing up index
746 * #0 for the next vertex that we are about to emit. We do that by copying
747 * the first n - 1 vertices forward one slot, which has to happen with a
748 * backwards copy implemented here.
749 *
750 * If we're lucky, all of these copies will be propagated away. If we're
751 * unlucky, this involves at most 2 copies per component per XFB output per
752 * vertex.
753 */
754 u_foreach_bit64(slot, b->shader->info.outputs_written) {
755 /* Note: if we're outputting points, verts_in_output_prim will be 1, so
756 * this loop will not execute. This is intended: points are self-contained
757 * primitives and do not need these copies.
758 */
759 for (int v = verts_in_output_prim(b->shader) - 1; v >= 1; --v) {
760 nir_def *value = nir_load_var(b, state->outputs[slot][v - 1]);
761
762 nir_store_var(b, state->outputs[slot][v], value,
763 nir_component_mask(value->num_components));
764 }
765 }
766 }
767
768 static bool
lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
770 {
771 b->cursor = nir_before_instr(&intr->instr);
772
773 switch (intr->intrinsic) {
774 case nir_intrinsic_set_vertex_and_primitive_count:
      /* This instruction is mostly for the count shader, so just remove it.
       * For points, though, we write the index buffer here so the rast shader
       * can map them.
       */
778 if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) {
779 lower_end_primitive(b, intr, state);
780 }
781
782 break;
783
784 case nir_intrinsic_end_primitive_with_counter: {
785 unsigned min = verts_in_output_prim(b->shader);
786
787 /* We only write out complete primitives */
788 nir_push_if(b, nir_uge_imm(b, intr->src[1].ssa, min));
789 {
790 lower_end_primitive(b, intr, state);
791 }
792 nir_pop_if(b, NULL);
793 break;
794 }
795
796 case nir_intrinsic_emit_vertex_with_counter:
797 /* emit_vertex triggers transform feedback but is otherwise a no-op. */
798 if (b->shader->xfb_info)
799 lower_emit_vertex_xfb(b, intr, state);
800 break;
801
802 default:
803 return false;
804 }
805
806 nir_instr_remove(&intr->instr);
807 return true;
808 }
809
810 static bool
collect_components(nir_builder *b, nir_intrinsic_instr *intr, void *data)
812 {
813 uint8_t *counts = data;
814 if (intr->intrinsic != nir_intrinsic_store_output)
815 return false;
816
817 unsigned count = nir_intrinsic_component(intr) +
818 util_last_bit(nir_intrinsic_write_mask(intr));
819
820 unsigned loc =
821 nir_intrinsic_io_semantics(intr).location + nir_src_as_uint(intr->src[1]);
822
823 uint8_t *total_count = &counts[loc];
824
825 *total_count = MAX2(*total_count, count);
826 return true;
827 }
828
/*
 * Create the pre-GS shader. This is a small 1x1x1 compute kernel that patches
 * up the VDM Index List command from the draw to read the produced geometry,
 * as well as updating transform feedback offsets and counters as applicable
 * (TODO).
 */
834 static nir_shader *
agx_nir_create_pre_gs(struct lower_gs_state *state, const nir_shader *libagx,
                      bool indexed, bool restart, struct nir_xfb_info *xfb,
                      unsigned vertices_per_prim, uint8_t streams,
                      unsigned invocations)
839 {
840 nir_builder b_ = nir_builder_init_simple_shader(
841 MESA_SHADER_COMPUTE, &agx_nir_options, "Pre-GS patch up");
842 nir_builder *b = &b_;
843
844 /* Load the number of primitives input to the GS */
845 nir_def *unrolled_in_prims = load_geometry_param(b, input_primitives);
846
847 /* Setup the draw from the rasterization stream (0). */
848 if (!state->rasterizer_discard) {
849 libagx_build_gs_draw(
850 b, nir_load_geometry_param_buffer_agx(b), nir_imm_bool(b, indexed),
851 previous_vertices(b, state, 0, unrolled_in_prims),
852 restart ? previous_primitives(b, state, 0, unrolled_in_prims)
853 : nir_imm_int(b, 0));
854 }
855
856 /* Determine the number of primitives generated in each stream */
857 nir_def *in_prims[MAX_VERTEX_STREAMS], *prims[MAX_VERTEX_STREAMS];
858
859 u_foreach_bit(i, streams) {
860 in_prims[i] = previous_xfb_primitives(b, state, i, unrolled_in_prims);
861 prims[i] = in_prims[i];
862
863 add_counter(b, load_geometry_param(b, prims_generated_counter[i]),
864 prims[i]);
865 }
866
867 if (xfb) {
868 /* Write XFB addresses */
869 nir_def *offsets[4] = {NULL};
870 u_foreach_bit(i, xfb->buffers_written) {
871 offsets[i] = libagx_setup_xfb_buffer(
872 b, nir_load_geometry_param_buffer_agx(b), nir_imm_int(b, i));
873 }
874
875 /* Now clamp to the number that XFB captures */
876 for (unsigned i = 0; i < xfb->output_count; ++i) {
877 nir_xfb_output_info output = xfb->outputs[i];
878
879 unsigned buffer = output.buffer;
880 unsigned stream = xfb->buffer_to_stream[buffer];
881 unsigned stride = xfb->buffers[buffer].stride;
882 unsigned words_written = util_bitcount(output.component_mask);
883 unsigned bytes_written = words_written * 4;
884
885 /* Primitive P will write up to (but not including) offset:
886 *
887 * xfb_offset + ((P - 1) * (verts_per_prim * stride))
888 * + ((verts_per_prim - 1) * stride)
889 * + output_offset
890 * + output_size
891 *
892 * Given an XFB buffer of size xfb_size, we get the inequality:
893 *
894 * floor(P) <= (stride + xfb_size - xfb_offset - output_offset -
895 * output_size) // (stride * verts_per_prim)
896 */
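         /* For example: stride 16, 3 vertices per primitive, a 4-byte output
          * at offset 0, xfb_size 1000 and xfb_offset 0 gives
          * (16 + 1000 - 0 - 0 - 4) / 48 = 21 whole primitives of capacity.
          */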
897 nir_def *size = load_geometry_param(b, xfb_size[buffer]);
898 size = nir_iadd_imm(b, size, stride - output.offset - bytes_written);
899 size = nir_isub(b, size, offsets[buffer]);
900 size = nir_imax(b, size, nir_imm_int(b, 0));
901 nir_def *max_prims = nir_udiv_imm(b, size, stride * vertices_per_prim);
902
903 prims[stream] = nir_umin(b, prims[stream], max_prims);
904 }
905
906 nir_def *any_overflow = nir_imm_false(b);
907
908 u_foreach_bit(i, streams) {
909 nir_def *overflow = nir_ult(b, prims[i], in_prims[i]);
910 any_overflow = nir_ior(b, any_overflow, overflow);
911
912 store_geometry_param(b, xfb_prims[i], prims[i]);
913
914 add_counter(b, load_geometry_param(b, xfb_overflow[i]),
915 nir_b2i32(b, overflow));
916
917 add_counter(b, load_geometry_param(b, xfb_prims_generated_counter[i]),
918 prims[i]);
919 }
920
921 add_counter(b, load_geometry_param(b, xfb_any_overflow),
922 nir_b2i32(b, any_overflow));
923
924 /* Update XFB counters */
925 u_foreach_bit(i, xfb->buffers_written) {
926 uint32_t prim_stride_B = xfb->buffers[i].stride * vertices_per_prim;
927 unsigned stream = xfb->buffer_to_stream[i];
928
929 nir_def *off_ptr = load_geometry_param(b, xfb_offs_ptrs[i]);
930 nir_def *size = nir_imul_imm(b, prims[stream], prim_stride_B);
931 add_counter(b, off_ptr, size);
932 }
933 }
934
   /* The geometry shader receives a number of input primitives. The driver
    * should disable this counter when tessellation is active (TODO) and count
    * patches separately.
    */
939 add_counter(
940 b,
941 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_IA_PRIMITIVES),
942 unrolled_in_prims);
943
944 /* The geometry shader is invoked once per primitive (after unrolling
945 * primitive restart). From the spec:
946 *
947 * In case of instanced geometry shaders (see section 11.3.4.2) the
948 * geometry shader invocations count is incremented for each separate
949 * instanced invocation.
950 */
951 add_counter(b,
952 nir_load_stat_query_address_agx(
953 b, .base = PIPE_STAT_QUERY_GS_INVOCATIONS),
954 nir_imul_imm(b, unrolled_in_prims, invocations));
955
956 nir_def *emitted_prims = nir_imm_int(b, 0);
957 u_foreach_bit(i, streams) {
958 emitted_prims =
959 nir_iadd(b, emitted_prims,
960 previous_xfb_primitives(b, state, i, unrolled_in_prims));
961 }
962
963 add_counter(
964 b,
965 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_GS_PRIMITIVES),
966 emitted_prims);
967
968 /* Clipper queries are not well-defined, so we can emulate them in lots of
969 * silly ways. We need the hardware counters to implement them properly. For
970 * now, just consider all primitives emitted as passing through the clipper.
971 * This satisfies spec text:
972 *
973 * The number of primitives that reach the primitive clipping stage.
974 *
975 * and
976 *
977 * If at least one vertex of the primitive lies inside the clipping
978 * volume, the counter is incremented by one or more. Otherwise, the
979 * counter is incremented by zero or more.
980 */
981 add_counter(
982 b,
983 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_PRIMITIVES),
984 emitted_prims);
985
986 add_counter(
987 b,
988 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_INVOCATIONS),
989 emitted_prims);
990
991 /* Preprocess it */
992 UNUSED struct agx_uncompiled_shader_info info;
993 agx_preprocess_nir(b->shader, libagx, false, &info);
994
995 return b->shader;
996 }
997
998 static bool
rewrite_invocation_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1000 {
1001 if (intr->intrinsic != nir_intrinsic_load_invocation_id)
1002 return false;
1003
1004 b->cursor = nir_instr_remove(&intr->instr);
1005 nir_def_rewrite_uses(&intr->def, nir_u2uN(b, data, intr->def.bit_size));
1006 return true;
1007 }
1008
1009 /*
1010 * Geometry shader instancing allows a GS to run multiple times. The number of
1011 * times is statically known and small. It's easiest to turn this into a loop
1012 * inside the GS, to avoid the feature "leaking" outside and affecting e.g. the
1013 * counts.
1014 */
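/* Roughly, the body becomes:
 *
 *    for (u16 i = 0; i < invocations; ++i) {
 *       <original shader body, with gl_InvocationID replaced by i>
 *       EndPrimitive();
 *    }
 */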
1015 static void
agx_nir_lower_gs_instancing(nir_shader *gs)
1017 {
1018 unsigned nr_invocations = gs->info.gs.invocations;
1019 nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1020
   /* Each invocation can produce up to the shader-declared max_vertices, so
    * multiply it up for a proper bounds check. Emitting more than the declared
    * max_vertices per invocation results in undefined behaviour, so
    * erroneously emitting more than asked on early invocations is a perfectly
    * cromulent behaviour.
    */
1027 gs->info.gs.vertices_out *= gs->info.gs.invocations;
1028
1029 /* Get the original function */
1030 nir_cf_list list;
1031 nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1032
1033 /* Create a builder for the wrapped function */
1034 nir_builder b = nir_builder_at(nir_after_block(nir_start_block(impl)));
1035
1036 nir_variable *i =
1037 nir_local_variable_create(impl, glsl_uintN_t_type(16), NULL);
1038 nir_store_var(&b, i, nir_imm_intN_t(&b, 0, 16), ~0);
1039 nir_def *index = NULL;
1040
1041 /* Create a loop in the wrapped function */
1042 nir_loop *loop = nir_push_loop(&b);
1043 {
1044 index = nir_load_var(&b, i);
1045 nir_push_if(&b, nir_uge_imm(&b, index, nr_invocations));
1046 {
1047 nir_jump(&b, nir_jump_break);
1048 }
1049 nir_pop_if(&b, NULL);
1050
1051 b.cursor = nir_cf_reinsert(&list, b.cursor);
1052 nir_store_var(&b, i, nir_iadd_imm(&b, index, 1), ~0);
1053
1054 /* Make sure we end the primitive between invocations. If the geometry
1055 * shader already ended the primitive, this will get optimized out.
1056 */
1057 nir_end_primitive(&b);
1058 }
1059 nir_pop_loop(&b, loop);
1060
1061 /* We've mucked about with control flow */
1062 nir_metadata_preserve(impl, nir_metadata_none);
1063
1064 /* Use the loop counter as the invocation ID each iteration */
1065 nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1066 nir_metadata_block_index | nir_metadata_dominance,
1067 index);
1068 }
1069
1070 static bool
strip_side_effects(nir_builder *b, nir_intrinsic_instr *intr, void *_)
1072 {
1073 switch (intr->intrinsic) {
1074 case nir_intrinsic_store_global:
1075 case nir_intrinsic_global_atomic:
1076 case nir_intrinsic_global_atomic_swap:
1077 break;
1078 default:
1079 return false;
1080 }
1081
1082 /* If there's a side effect that's actually required for the prepass, we have
1083 * to keep it in.
1084 */
1085 if (nir_intrinsic_infos[intr->intrinsic].has_dest &&
1086 !list_is_empty(&intr->def.uses))
1087 return false;
1088
   /* Do not strip transform feedback stores; the rasterization shader doesn't
    * execute them.
    */
1092 if (intr->intrinsic == nir_intrinsic_store_global &&
1093 nir_intrinsic_access(intr) & ACCESS_XFB)
1094 return false;
1095
1096 /* Otherwise, remove the dead instruction. The rasterization shader will
1097 * execute the side effect so the side effect still happens at least once.
1098 */
1099 nir_instr_remove(&intr->instr);
1100 return true;
1101 }
1102
1103 static void
link_libagx(nir_shader *nir, const nir_shader *libagx)
1105 {
1106 nir_link_shader_functions(nir, libagx);
1107 NIR_PASS(_, nir, nir_inline_functions);
1108 nir_remove_non_entrypoints(nir);
1109 NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
1110 NIR_PASS(_, nir, nir_opt_dce);
1111 NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
1112 nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
1113 nir_var_mem_global,
1114 glsl_get_cl_type_size_align);
1115 NIR_PASS(_, nir, nir_opt_deref);
1116 NIR_PASS(_, nir, nir_lower_vars_to_ssa);
1117 NIR_PASS(_, nir, nir_lower_explicit_io,
1118 nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
1119 nir_var_mem_global,
1120 nir_address_format_62bit_generic);
1121 }
1122
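/*
 * Lower a geometry shader for the software GS path. In addition to rewriting
 * the GS itself into a compute-like prepass that writes the index buffer and
 * transform feedback, this produces up to three auxiliary programs: an
 * optional count shader (*gs_count) when counts are not statically known, a
 * hardware vertex shader (*gs_copy) that rasterizes the results, and a small
 * pre-GS kernel (*pre_gs) that patches up the draw and the XFB/query counters.
 */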
1123 bool
agx_nir_lower_gs(nir_shader *gs, const nir_shader *libagx,
                 bool rasterizer_discard, nir_shader **gs_count,
                 nir_shader **gs_copy, nir_shader **pre_gs,
                 enum mesa_prim *out_mode, unsigned *out_count_words)
1128 {
1129 /* Collect output component counts so we can size the geometry output buffer
1130 * appropriately, instead of assuming everything is vec4.
1131 */
1132 uint8_t component_counts[NUM_TOTAL_VARYING_SLOTS] = {0};
1133 nir_shader_intrinsics_pass(gs, collect_components, nir_metadata_all,
1134 component_counts);
1135
1136 /* If geometry shader instancing is used, lower it away before linking
1137 * anything. Otherwise, smash the invocation ID to zero.
1138 */
1139 if (gs->info.gs.invocations != 1) {
1140 agx_nir_lower_gs_instancing(gs);
1141 } else {
1142 nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1143 nir_builder b = nir_builder_at(nir_before_impl(impl));
1144
1145 nir_shader_intrinsics_pass(
1146 gs, rewrite_invocation_id,
1147 nir_metadata_block_index | nir_metadata_dominance, nir_imm_int(&b, 0));
1148 }
1149
1150 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_inputs,
1151 nir_metadata_block_index | nir_metadata_dominance, NULL);
1152
1153 /* Lower geometry shader writes to contain all of the required counts, so we
1154 * know where in the various buffers we should write vertices.
1155 */
1156 NIR_PASS(_, gs, nir_lower_gs_intrinsics,
1157 nir_lower_gs_intrinsics_count_primitives |
1158 nir_lower_gs_intrinsics_per_stream |
1159 nir_lower_gs_intrinsics_count_vertices_per_primitive |
1160 nir_lower_gs_intrinsics_overwrite_incomplete |
1161 nir_lower_gs_intrinsics_always_end_primitive |
1162 nir_lower_gs_intrinsics_count_decomposed_primitives);
1163
1164 /* Clean up after all that lowering we did */
1165 bool progress = false;
1166 do {
1167 progress = false;
1168 NIR_PASS(progress, gs, nir_lower_var_copies);
1169 NIR_PASS(progress, gs, nir_lower_variable_initializers,
1170 nir_var_shader_temp);
1171 NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1172 NIR_PASS(progress, gs, nir_copy_prop);
1173 NIR_PASS(progress, gs, nir_opt_constant_folding);
1174 NIR_PASS(progress, gs, nir_opt_algebraic);
1175 NIR_PASS(progress, gs, nir_opt_cse);
1176 NIR_PASS(progress, gs, nir_opt_dead_cf);
1177 NIR_PASS(progress, gs, nir_opt_dce);
1178
1179 /* Unrolling lets us statically determine counts more often, which
1180 * otherwise would not be possible with multiple invocations even in the
1181 * simplest of cases.
1182 */
1183 NIR_PASS(progress, gs, nir_opt_loop_unroll);
1184 } while (progress);
1185
1186 /* If we know counts at compile-time we can simplify, so try to figure out
1187 * the counts statically.
1188 */
1189 struct lower_gs_state gs_state = {
1190 .rasterizer_discard = rasterizer_discard,
1191 };
1192
1193 nir_gs_count_vertices_and_primitives(
1194 gs, gs_state.static_count[GS_COUNTER_VERTICES],
1195 gs_state.static_count[GS_COUNTER_PRIMITIVES],
1196 gs_state.static_count[GS_COUNTER_XFB_PRIMITIVES], 4);
1197
1198 /* Anything we don't know statically will be tracked by the count buffer.
1199 * Determine the layout for it.
1200 */
1201 for (unsigned i = 0; i < MAX_VERTEX_STREAMS; ++i) {
1202 for (unsigned c = 0; c < GS_NUM_COUNTERS; ++c) {
1203 gs_state.count_index[i][c] =
1204 (gs_state.static_count[c][i] < 0) ? gs_state.count_stride_el++ : -1;
1205 }
1206 }
1207
1208 *gs_copy = agx_nir_create_gs_rast_shader(gs, libagx);
1209
1210 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1211 nir_metadata_block_index | nir_metadata_dominance, NULL);
1212
1213 link_libagx(gs, libagx);
1214
1215 NIR_PASS(_, gs, nir_lower_idiv,
1216 &(const nir_lower_idiv_options){.allow_fp16 = true});
1217
1218 /* All those variables we created should've gone away by now */
1219 NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1220
1221 /* If there is any unknown count, we need a geometry count shader */
1222 if (gs_state.count_stride_el > 0)
1223 *gs_count = agx_nir_create_geometry_count_shader(gs, libagx, &gs_state);
1224 else
1225 *gs_count = NULL;
1226
1227 /* Geometry shader outputs are staged to temporaries */
1228 struct agx_lower_output_to_var_state state = {0};
1229
1230 u_foreach_bit64(slot, gs->info.outputs_written) {
1231 const char *slot_name =
1232 gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
1233
1234 for (unsigned i = 0; i < MAX_PRIM_OUT_SIZE; ++i) {
1235 gs_state.outputs[slot][i] = nir_variable_create(
1236 gs, nir_var_shader_temp,
1237 glsl_vector_type(GLSL_TYPE_UINT, component_counts[slot]),
1238 ralloc_asprintf(gs, "%s-%u", slot_name, i));
1239 }
1240
1241 state.outputs[slot] = gs_state.outputs[slot][0];
1242 }
1243
1244 NIR_PASS(_, gs, nir_shader_instructions_pass, agx_lower_output_to_var,
1245 nir_metadata_block_index | nir_metadata_dominance, &state);
1246
1247 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_instr,
1248 nir_metadata_none, &gs_state);
1249
1250 /* Determine if we are guaranteed to rasterize at least one vertex, so that
1251 * we can strip the prepass of side effects knowing they will execute in the
1252 * rasterization shader.
1253 */
1254 bool rasterizes_at_least_one_vertex =
1255 !rasterizer_discard && gs_state.static_count[0][0] > 0;
1256
1257 /* Clean up after all that lowering we did */
1258 nir_lower_global_vars_to_local(gs);
1259 do {
1260 progress = false;
1261 NIR_PASS(progress, gs, nir_lower_var_copies);
1262 NIR_PASS(progress, gs, nir_lower_variable_initializers,
1263 nir_var_shader_temp);
1264 NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1265 NIR_PASS(progress, gs, nir_copy_prop);
1266 NIR_PASS(progress, gs, nir_opt_constant_folding);
1267 NIR_PASS(progress, gs, nir_opt_algebraic);
1268 NIR_PASS(progress, gs, nir_opt_cse);
1269 NIR_PASS(progress, gs, nir_opt_dead_cf);
1270 NIR_PASS(progress, gs, nir_opt_dce);
1271 NIR_PASS(progress, gs, nir_opt_loop_unroll);
1272
1273 /* When rasterizing, we try to move side effects to the rasterizer shader
1274 * and strip the prepass of the dead side effects. Run this in the opt
1275 * loop because it interacts with nir_opt_dce.
1276 */
1277 if (rasterizes_at_least_one_vertex) {
1278 NIR_PASS(progress, gs, nir_shader_intrinsics_pass, strip_side_effects,
1279 nir_metadata_block_index | nir_metadata_dominance, NULL);
1280 }
1281 } while (progress);
1282
1283 /* All those variables we created should've gone away by now */
1284 NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1285
1286 NIR_PASS(_, gs, nir_opt_sink, ~0);
1287 NIR_PASS(_, gs, nir_opt_move, ~0);
1288 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1289 nir_metadata_block_index | nir_metadata_dominance, NULL);
1290
1291 /* Create auxiliary programs */
1292 *pre_gs = agx_nir_create_pre_gs(
1293 &gs_state, libagx, true, gs->info.gs.output_primitive != MESA_PRIM_POINTS,
1294 gs->xfb_info, verts_in_output_prim(gs), gs->info.gs.active_stream_mask,
1295 gs->info.gs.invocations);
1296
1297 /* Signal what primitive we want to draw the GS Copy VS with */
1298 *out_mode = gs->info.gs.output_primitive;
1299 *out_count_words = gs_state.count_stride_el;
1300 return true;
1301 }
1302
/*
 * Vertex shaders (or tessellation evaluation shaders) that run before a
 * geometry shader execute as a dedicated compute prepass. They are invoked as
 * (count, instances, 1), equivalent to a geometry shader inputting POINTS, so
 * the vertex output buffer is indexed according to calc_unrolled_id.
 *
 * This function lowers their vertex shader I/O to compute.
 *
 * Vertex ID becomes an index buffer pull (without applying the topology).
 * Store output becomes a store into the global vertex output buffer.
 */
1314 static bool
lower_vs_before_gs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1316 {
1317 if (intr->intrinsic != nir_intrinsic_store_output)
1318 return false;
1319
1320 b->cursor = nir_instr_remove(&intr->instr);
1321 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
1322 nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
1323
1324 nir_def *addr = libagx_vertex_output_address(
1325 b, nir_load_geometry_param_buffer_agx(b), calc_unrolled_id(b), location,
1326 nir_imm_int64(b, b->shader->info.outputs_written));
1327
1328 assert(nir_src_bit_size(intr->src[0]) == 32);
1329 addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
1330
1331 nir_store_global(b, addr, 4, intr->src[0].ssa,
1332 nir_intrinsic_write_mask(intr));
1333 return true;
1334 }
1335
1336 bool
agx_nir_lower_vs_before_gs(struct nir_shader *vs,
                           const struct nir_shader *libagx,
                           unsigned index_size_B, uint64_t *outputs)
1340 {
1341 bool progress = false;
1342
1343 /* Lower vertex ID to an index buffer pull without a topology applied */
1344 progress |= agx_nir_lower_index_buffer(vs, index_size_B, false);
1345
1346 /* Lower vertex stores to memory stores */
1347 progress |= nir_shader_intrinsics_pass(
1348 vs, lower_vs_before_gs, nir_metadata_block_index | nir_metadata_dominance,
1349 &index_size_B);
1350
1351 /* Lower instance ID and num vertices */
1352 progress |= nir_shader_intrinsics_pass(
1353 vs, lower_id, nir_metadata_block_index | nir_metadata_dominance, NULL);
1354
1355 /* Link libagx, used in lower_vs_before_gs */
1356 if (progress)
1357 link_libagx(vs, libagx);
1358
1359 /* Turn into a compute shader now that we're free of vertexisms */
1360 vs->info.stage = MESA_SHADER_COMPUTE;
1361 memset(&vs->info.cs, 0, sizeof(vs->info.cs));
1362 vs->xfb_info = NULL;
1363 *outputs = vs->info.outputs_written;
1364 return true;
1365 }
1366
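/* Build the kernel that prefix sums the count buffer. The workgroup is sized
 * 32 x count_words, so each count word gets a 32-wide row of invocations
 * scanning across the input primitives.
 */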
1367 void
agx_nir_prefix_sum_gs(nir_builder *b, const void *data)
1369 {
1370 const unsigned *words = data;
1371
1372 uint32_t subgroup_size = 32;
1373 b->shader->info.workgroup_size[0] = subgroup_size;
1374 b->shader->info.workgroup_size[1] = *words;
1375
1376 libagx_prefix_sum(b, load_geometry_param(b, count_buffer),
1377 load_geometry_param(b, input_primitives),
1378 nir_imm_int(b, *words),
1379 nir_trim_vector(b, nir_load_local_invocation_id(b), 2));
1380 }
1381
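/* Build the kernel that sets up an indirect draw for the GS pipeline,
 * deferring to libagx_gs_setup_indirect with the input primitive type.
 */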
1382 void
agx_nir_gs_setup_indirect(nir_builder *b, const void *data)
1384 {
1385 const struct agx_gs_setup_indirect_key *key = data;
1386
1387 libagx_gs_setup_indirect(b, nir_load_geometry_param_buffer_agx(b),
1388 nir_load_input_assembly_buffer_agx(b),
1389 nir_imm_int(b, key->prim),
1390 nir_channel(b, nir_load_local_invocation_id(b), 0));
1391 }
1392
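/* Build the kernel that unrolls primitive restart: one workgroup per draw,
 * dispatching the libagx unroll variant matching the index size.
 */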
1393 void
agx_nir_unroll_restart(nir_builder *b, const void *data)
1395 {
1396 const struct agx_unroll_restart_key *key = data;
1397 nir_def *ia = nir_load_input_assembly_buffer_agx(b);
1398 nir_def *draw = nir_channel(b, nir_load_workgroup_id(b), 0);
1399 nir_def *mode = nir_imm_int(b, key->prim);
1400
1401 if (key->index_size_B == 1)
1402 libagx_unroll_restart_u8(b, ia, mode, draw);
1403 else if (key->index_size_B == 2)
1404 libagx_unroll_restart_u16(b, ia, mode, draw);
1405 else if (key->index_size_B == 4)
1406 libagx_unroll_restart_u32(b, ia, mode, draw);
1407 else
1408 unreachable("invalid index size");
1409 }
1410