1 /*
2 * Copyright 2023 Alyssa Rosenzweig
3 * Copyright 2023 Valve Corporation
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "agx_nir_lower_gs.h"
8 #include "asahi/compiler/agx_compile.h"
9 #include "compiler/nir/nir_builder.h"
10 #include "gallium/include/pipe/p_defines.h"
11 #include "libagx/geometry.h"
12 #include "libagx/libagx.h"
13 #include "util/bitscan.h"
14 #include "util/list.h"
15 #include "util/macros.h"
16 #include "util/ralloc.h"
17 #include "util/u_math.h"
18 #include "nir.h"
19 #include "nir_builder_opcodes.h"
20 #include "nir_intrinsics.h"
21 #include "nir_intrinsics_indices.h"
22 #include "nir_xfb_info.h"
23 #include "shader_enums.h"
24
25 enum gs_counter {
26 GS_COUNTER_VERTICES = 0,
27 GS_COUNTER_PRIMITIVES,
28 GS_COUNTER_XFB_PRIMITIVES,
29 GS_NUM_COUNTERS
30 };
31
32 #define MAX_PRIM_OUT_SIZE 3
33
34 struct lower_gs_state {
35 int static_count[GS_NUM_COUNTERS][MAX_VERTEX_STREAMS];
36 nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE];
37
38 /* The count buffer contains `count_stride_el` 32-bit words in a row for each
39 * input primitive, for `input_primitives * count_stride_el * 4` total bytes.
40 */
41 unsigned count_stride_el;
42
43 /* The index of each counter in the count buffer, or -1 if it's not in the
44 * count buffer.
45 *
46 * Invariant: count_stride_el == sum(count_index[i][j] >= 0).
47 */
48 int count_index[MAX_VERTEX_STREAMS][GS_NUM_COUNTERS];
49
50 bool rasterizer_discard;
51 };
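/* Illustrative layout (hypothetical counts): if stream 0 has a statically
 * known vertex count but dynamic primitive and XFB primitive counts, then
 * count_stride_el == 2 and each input primitive owns the 32-bit words
 * [primitives, xfb primitives] starting at word (unrolled_id * 2).
 */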
52
53 /* Helpers for loading from the geometry state buffer */
54 static nir_def *
55 load_geometry_param_offset(nir_builder *b, uint32_t offset, uint8_t bytes)
56 {
57 nir_def *base = nir_load_geometry_param_buffer_agx(b);
58 nir_def *addr = nir_iadd_imm(b, base, offset);
59
60 assert((offset % bytes) == 0 && "must be naturally aligned");
61
62 return nir_load_global_constant(b, addr, bytes, 1, bytes * 8);
63 }
64
65 static void
66 store_geometry_param_offset(nir_builder *b, nir_def *def, uint32_t offset,
67 uint8_t bytes)
68 {
69 nir_def *base = nir_load_geometry_param_buffer_agx(b);
70 nir_def *addr = nir_iadd_imm(b, base, offset);
71
72 assert((offset % bytes) == 0 && "must be naturally aligned");
73
74 nir_store_global(b, addr, 4, def, nir_component_mask(def->num_components));
75 }
76
77 #define store_geometry_param(b, field, def) \
78 store_geometry_param_offset( \
79 b, def, offsetof(struct agx_geometry_params, field), \
80 sizeof(((struct agx_geometry_params *)0)->field))
81
82 #define load_geometry_param(b, field) \
83 load_geometry_param_offset( \
84 b, offsetof(struct agx_geometry_params, field), \
85 sizeof(((struct agx_geometry_params *)0)->field))
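/* For example, load_geometry_param(b, input_topology) expands to a global
 * constant load of that field, sized and aligned by the sizeof() of the field
 * within struct agx_geometry_params.
 */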
86
87 /* Helper for updating counters */
88 static void
89 add_counter(nir_builder *b, nir_def *counter, nir_def *increment)
90 {
91 /* If the counter is NULL, the counter is disabled. Skip the update. */
92 nir_if *nif = nir_push_if(b, nir_ine_imm(b, counter, 0));
93 {
94 nir_def *old = nir_load_global(b, counter, 4, 1, 32);
95 nir_def *new_ = nir_iadd(b, old, increment);
96 nir_store_global(b, counter, 4, new_, nir_component_mask(1));
97 }
98 nir_pop_if(b, nif);
99 }
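/* Note: this read-modify-write is not atomic. That is fine because, in this
 * file, add_counter is only used from the pre-GS shader, a 1x1x1 compute
 * kernel with a single invocation.
 */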
100
101 /* Helpers for lowering I/O to variables */
102 static void
103 lower_store_to_var(nir_builder *b, nir_intrinsic_instr *intr,
104 struct agx_lower_output_to_var_state *state)
105 {
106 b->cursor = nir_instr_remove(&intr->instr);
107 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
108 unsigned component = nir_intrinsic_component(intr);
109 nir_def *value = intr->src[0].ssa;
110
111 assert(nir_src_is_const(intr->src[1]) && "no indirect outputs");
112 assert(nir_intrinsic_write_mask(intr) == nir_component_mask(1) &&
113 "should be scalarized");
114
115 nir_variable *var =
116 state->outputs[sem.location + nir_src_as_uint(intr->src[1])];
117 if (!var) {
118 assert(sem.location == VARYING_SLOT_PSIZ &&
119 "otherwise in outputs_written");
120 return;
121 }
122
123 unsigned nr_components = glsl_get_components(glsl_without_array(var->type));
124 assert(component < nr_components);
125
126 /* Turn it into a vec4 write like NIR expects */
127 value = nir_vector_insert_imm(b, nir_undef(b, nr_components, 32), value,
128 component);
129
130 nir_store_var(b, var, value, BITFIELD_BIT(component));
131 }
132
133 bool
134 agx_lower_output_to_var(nir_builder *b, nir_instr *instr, void *data)
135 {
136 if (instr->type != nir_instr_type_intrinsic)
137 return false;
138
139 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
140 if (intr->intrinsic != nir_intrinsic_store_output)
141 return false;
142
143 lower_store_to_var(b, intr, data);
144 return true;
145 }
146
147 /*
148 * Geometry shader invocations are compute-like:
149 *
150 * (primitive ID, instance ID, 1)
151 */
152 static nir_def *
153 load_primitive_id(nir_builder *b)
154 {
155 return nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
156 }
157
158 static nir_def *
159 load_instance_id(nir_builder *b)
160 {
161 return nir_channel(b, nir_load_global_invocation_id(b, 32), 1);
162 }
163
164 /* Geometry shaders use software input assembly: the software vertex shader
165 * is invoked once per index, and the geometry shader itself applies the
166 * topology. This helper maps a vertex within an input primitive to a vertex ID.
167 */
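/* As an illustration, plain triangle-list input assembly maps (primitive, vert)
 * to vertex ID 3 * primitive + vert; the libagx helpers below handle lists,
 * strips, fans, and adjacency for each topology class.
 */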
168 static nir_def *
169 vertex_id_for_topology_class(nir_builder *b, nir_def *vert, enum mesa_prim cls)
170 {
171 nir_def *prim = nir_load_primitive_id(b);
172 nir_def *flatshade_first = nir_ieq_imm(b, nir_load_provoking_last(b), 0);
173 nir_def *nr = load_geometry_param(b, gs_grid[0]);
174 nir_def *topology = nir_load_input_topology_agx(b);
175
176 switch (cls) {
177 case MESA_PRIM_POINTS:
178 return prim;
179
180 case MESA_PRIM_LINES:
181 return libagx_vertex_id_for_line_class(b, topology, prim, vert, nr);
182
183 case MESA_PRIM_TRIANGLES:
184 return libagx_vertex_id_for_tri_class(b, topology, prim, vert,
185 flatshade_first);
186
187 case MESA_PRIM_LINES_ADJACENCY:
188 return libagx_vertex_id_for_line_adj_class(b, topology, prim, vert);
189
190 case MESA_PRIM_TRIANGLES_ADJACENCY:
191 return libagx_vertex_id_for_tri_adj_class(b, topology, prim, vert, nr,
192 flatshade_first);
193
194 default:
195 unreachable("invalid topology class");
196 }
197 }
198
199 nir_def *
200 agx_load_per_vertex_input(nir_builder *b, nir_intrinsic_instr *intr,
201 nir_def *vertex)
202 {
203 assert(intr->intrinsic == nir_intrinsic_load_per_vertex_input);
204 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
205
206 nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
207 nir_def *addr;
208
209 if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
210 /* GS may be preceded by VS or TES so specified as param */
211 addr = libagx_geometry_input_address(
212 b, nir_load_geometry_param_buffer_agx(b), vertex, location);
213 } else {
214 assert(b->shader->info.stage == MESA_SHADER_TESS_CTRL);
215
216 /* TCS always preceded by VS so we use the VS state directly */
217 addr = libagx_vertex_output_address(b, nir_load_vs_output_buffer_agx(b),
218 nir_load_vs_outputs_agx(b), vertex,
219 location);
220 }
221
222 addr = nir_iadd_imm(b, addr, 4 * nir_intrinsic_component(intr));
223 return nir_load_global_constant(b, addr, 4, intr->def.num_components,
224 intr->def.bit_size);
225 }
226
227 static bool
228 lower_gs_inputs(nir_builder *b, nir_intrinsic_instr *intr, void *_)
229 {
230 if (intr->intrinsic != nir_intrinsic_load_per_vertex_input)
231 return false;
232
233 b->cursor = nir_instr_remove(&intr->instr);
234
235 /* Calculate the vertex ID we're pulling, based on the topology class */
236 nir_def *vert_in_prim = intr->src[0].ssa;
237 nir_def *vertex = vertex_id_for_topology_class(
238 b, vert_in_prim, b->shader->info.gs.input_primitive);
239
240 nir_def *verts = load_geometry_param(b, vs_grid[0]);
241 nir_def *unrolled =
242 nir_iadd(b, nir_imul(b, nir_load_instance_id(b), verts), vertex);
243
244 nir_def *val = agx_load_per_vertex_input(b, intr, unrolled);
245 nir_def_rewrite_uses(&intr->def, val);
246 return true;
247 }
248
249 /*
250 * Unrolled ID is the index of the input primitive in the count buffer, given
251 * as (instance ID * # primitives/instance) + primitive ID
252 */
253 static nir_def *
254 calc_unrolled_id(nir_builder *b)
255 {
256 return nir_iadd(
257 b, nir_imul(b, load_instance_id(b), load_geometry_param(b, gs_grid[0])),
258 load_primitive_id(b));
259 }
260
261 static unsigned
262 output_vertex_id_stride(nir_shader *gs)
263 {
264 /* round up to power of two for cheap multiply/division */
265 return util_next_power_of_two(MAX2(gs->info.gs.vertices_out, 1));
266 }
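/* For example (hypothetical GS), vertices_out == 6 rounds up to a stride of 8,
 * reserving an aligned block of 8 output vertex IDs per unrolled invocation.
 */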
267
268 /* Variant of calc_unrolled_id that uses a power-of-two stride for indices. This
269 * is sparser (acceptable for index buffer values, not for count buffer
270 * indices). It has the nice property of being cheap to invert, unlike
271 * calc_unrolled_id. So, we use calc_unrolled_id for count buffers and
272 * calc_unrolled_index_id for index values.
273 *
274 * This also multiplies by the appropriate stride to calculate the final index
275 * base value.
276 */
277 static nir_def *
278 calc_unrolled_index_id(nir_builder *b)
279 {
280 unsigned vertex_stride = output_vertex_id_stride(b->shader);
281 nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
282
283 nir_def *instance = nir_ishl(b, load_instance_id(b), primitives_log2);
284 nir_def *prim = nir_iadd(b, instance, load_primitive_id(b));
285
286 return nir_imul_imm(b, prim, vertex_stride);
287 }
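/* The GS rasterization shader inverts this mapping with a umod/udiv by the
 * same power-of-two stride; see agx_nir_create_gs_rast_shader.
 */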
288
289 static nir_def *
290 load_count_address(nir_builder *b, struct lower_gs_state *state,
291 nir_def *unrolled_id, unsigned stream,
292 enum gs_counter counter)
293 {
294 int index = state->count_index[stream][counter];
295 if (index < 0)
296 return NULL;
297
298 nir_def *prim_offset_el =
299 nir_imul_imm(b, unrolled_id, state->count_stride_el);
300
301 nir_def *offset_el = nir_iadd_imm(b, prim_offset_el, index);
302
303 return nir_iadd(b, load_geometry_param(b, count_buffer),
304 nir_u2u64(b, nir_imul_imm(b, offset_el, 4)));
305 }
306
307 static void
308 write_counts(nir_builder *b, nir_intrinsic_instr *intr,
309 struct lower_gs_state *state)
310 {
311 /* Store each required counter */
312 nir_def *counts[GS_NUM_COUNTERS] = {
313 [GS_COUNTER_VERTICES] = intr->src[0].ssa,
314 [GS_COUNTER_PRIMITIVES] = intr->src[1].ssa,
315 [GS_COUNTER_XFB_PRIMITIVES] = intr->src[2].ssa,
316 };
317
318 for (unsigned i = 0; i < GS_NUM_COUNTERS; ++i) {
319 nir_def *addr = load_count_address(b, state, calc_unrolled_id(b),
320 nir_intrinsic_stream_id(intr), i);
321
322 if (addr)
323 nir_store_global(b, addr, 4, counts[i], nir_component_mask(1));
324 }
325 }
326
327 static bool
328 lower_gs_count_instr(nir_builder *b, nir_intrinsic_instr *intr, void *data)
329 {
330 switch (intr->intrinsic) {
331 case nir_intrinsic_emit_vertex_with_counter:
332 case nir_intrinsic_end_primitive_with_counter:
333 case nir_intrinsic_store_output:
334 /* These are for the main shader, just remove them */
335 nir_instr_remove(&intr->instr);
336 return true;
337
338 case nir_intrinsic_set_vertex_and_primitive_count:
339 b->cursor = nir_instr_remove(&intr->instr);
340 write_counts(b, intr, data);
341 return true;
342
343 default:
344 return false;
345 }
346 }
347
348 static bool
349 lower_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
350 {
351 b->cursor = nir_before_instr(&intr->instr);
352
353 nir_def *id;
354 if (intr->intrinsic == nir_intrinsic_load_primitive_id)
355 id = load_primitive_id(b);
356 else if (intr->intrinsic == nir_intrinsic_load_instance_id)
357 id = load_instance_id(b);
358 else if (intr->intrinsic == nir_intrinsic_load_flat_mask)
359 id = load_geometry_param(b, flat_outputs);
360 else if (intr->intrinsic == nir_intrinsic_load_input_topology_agx)
361 id = load_geometry_param(b, input_topology);
362 else
363 return false;
364
365 b->cursor = nir_instr_remove(&intr->instr);
366 nir_def_rewrite_uses(&intr->def, id);
367 return true;
368 }
369
370 /*
371 * Create a "Geometry count" shader. This is a stripped down geometry shader
372 * that just write its number of emitted vertices / primitives / transform
373 * feedback primitives to a count buffer. That count buffer will be prefix
374 * summed prior to running the real geometry shader. This is skipped if the
375 * counts are statically known.
376 */
377 static nir_shader *
378 agx_nir_create_geometry_count_shader(nir_shader *gs, const nir_shader *libagx,
379 struct lower_gs_state *state)
380 {
381 /* Don't muck up the original shader */
382 nir_shader *shader = nir_shader_clone(NULL, gs);
383
384 if (shader->info.name) {
385 shader->info.name =
386 ralloc_asprintf(shader, "%s_count", shader->info.name);
387 } else {
388 shader->info.name = "count";
389 }
390
391 NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_gs_count_instr,
392 nir_metadata_control_flow, state);
393
394 NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_id,
395 nir_metadata_control_flow, NULL);
396
397 agx_preprocess_nir(shader, libagx);
398 return shader;
399 }
400
401 struct lower_gs_rast_state {
402 nir_def *instance_id, *primitive_id, *output_id;
403 struct agx_lower_output_to_var_state outputs;
404 struct agx_lower_output_to_var_state selected;
405 };
406
407 static void
408 select_rast_output(nir_builder *b, nir_intrinsic_instr *intr,
409 struct lower_gs_rast_state *state)
410 {
411 b->cursor = nir_instr_remove(&intr->instr);
412
413 /* We only care about the rasterization stream in the rasterization
414 * shader, so just ignore emits from other streams.
415 */
416 if (nir_intrinsic_stream_id(intr) != 0)
417 return;
418
419 u_foreach_bit64(slot, b->shader->info.outputs_written) {
420 nir_def *orig = nir_load_var(b, state->selected.outputs[slot]);
421 nir_def *data = nir_load_var(b, state->outputs.outputs[slot]);
422
423 nir_def *value = nir_bcsel(
424 b, nir_ieq(b, intr->src[0].ssa, state->output_id), data, orig);
425
426 nir_store_var(b, state->selected.outputs[slot], value,
427 nir_component_mask(value->num_components));
428 }
429 }
430
431 static bool
432 lower_to_gs_rast(nir_builder *b, nir_intrinsic_instr *intr, void *data)
433 {
434 struct lower_gs_rast_state *state = data;
435
436 switch (intr->intrinsic) {
437 case nir_intrinsic_store_output:
438 lower_store_to_var(b, intr, &state->outputs);
439 return true;
440
441 case nir_intrinsic_emit_vertex_with_counter:
442 select_rast_output(b, intr, state);
443 return true;
444
445 case nir_intrinsic_load_primitive_id:
446 nir_def_rewrite_uses(&intr->def, state->primitive_id);
447 return true;
448
449 case nir_intrinsic_load_instance_id:
450 nir_def_rewrite_uses(&intr->def, state->instance_id);
451 return true;
452
453 case nir_intrinsic_load_flat_mask:
454 case nir_intrinsic_load_provoking_last:
455 case nir_intrinsic_load_input_topology_agx: {
456 /* Lowering the same in both GS variants */
457 return lower_id(b, intr, NULL);
458 }
459
460 case nir_intrinsic_end_primitive_with_counter:
461 case nir_intrinsic_set_vertex_and_primitive_count:
462 nir_instr_remove(&intr->instr);
463 return true;
464
465 default:
466 return false;
467 }
468 }
469
470 /*
471 * Side effects in geometry shaders are problematic with our "GS rasterization
472 * shader" implementation. Where does the side effect happen? In the prepass?
473 * In the rast shader? In both?
474 *
475 * A perfect solution is impossible with rast shaders. Since the spec is loose
476 * here, we follow the principle of "least surprise":
477 *
478 * 1. Prefer side effects in the prepass over the rast shader. The prepass runs
479 * once per API GS invocation so will match the expectations of buggy apps
480 * not written for tilers.
481 *
482 * 2. If we must execute any side effect in the rast shader, try to execute all
483 * side effects only in the rast shader. If some side effects must happen in
484 * the rast shader and others don't, this gets consistent counts
485 * (i.e. if the app expects plain stores and atomics to match up).
486 *
487 * 3. If we must execute side effects in both rast and the prepass,
488 * execute all side effects in the rast shader and strip what we can from
489 * the prepass. This gets the "unsurprising" behaviour from #2 without
490 * falling over for ridiculous uses of atomics.
491 */
492 static bool
493 strip_side_effect_from_rast(nir_builder *b, nir_intrinsic_instr *intr,
494 void *data)
495 {
496 switch (intr->intrinsic) {
497 case nir_intrinsic_store_global:
498 case nir_intrinsic_global_atomic:
499 case nir_intrinsic_global_atomic_swap:
500 break;
501 default:
502 return false;
503 }
504
505 /* If there's a side effect that's actually required, keep it. */
506 if (nir_intrinsic_infos[intr->intrinsic].has_dest &&
507 !list_is_empty(&intr->def.uses)) {
508
509 bool *any = data;
510 *any = true;
511 return false;
512 }
513
514 /* Otherwise, remove the dead instruction. */
515 nir_instr_remove(&intr->instr);
516 return true;
517 }
518
519 static bool
520 strip_side_effects_from_rast(nir_shader *s, bool *side_effects_for_rast)
521 {
522 bool progress, any;
523
524 /* Rather than complex analysis, clone and try to remove as many side effects
525 * as possible. Then we check if we removed them all. We need to loop to
526 * handle complex control flow with side effects, where we can strip
527 * everything but can't figure that out with a simple one-shot analysis.
528 */
529 nir_shader *clone = nir_shader_clone(NULL, s);
530
531 /* Drop as much as we can */
532 do {
533 progress = false;
534 any = false;
535 NIR_PASS(progress, clone, nir_shader_intrinsics_pass,
536 strip_side_effect_from_rast, nir_metadata_control_flow, &any);
537
538 NIR_PASS(progress, clone, nir_opt_dce);
539 NIR_PASS(progress, clone, nir_opt_dead_cf);
540 } while (progress);
541
542 ralloc_free(clone);
543
544 /* If we need atomics, leave them in */
545 if (any) {
546 *side_effects_for_rast = true;
547 return false;
548 }
549
550 /* Else strip it all */
551 do {
552 progress = false;
553 any = false;
554 NIR_PASS(progress, s, nir_shader_intrinsics_pass,
555 strip_side_effect_from_rast, nir_metadata_control_flow, &any);
556
557 NIR_PASS(progress, s, nir_opt_dce);
558 NIR_PASS(progress, s, nir_opt_dead_cf);
559 } while (progress);
560
561 assert(!any);
562 return progress;
563 }
564
565 static bool
566 strip_side_effect_from_main(nir_builder *b, nir_intrinsic_instr *intr,
567 void *data)
568 {
569 switch (intr->intrinsic) {
570 case nir_intrinsic_global_atomic:
571 case nir_intrinsic_global_atomic_swap:
572 break;
573 default:
574 return false;
575 }
576
577 if (list_is_empty(&intr->def.uses)) {
578 nir_instr_remove(&intr->instr);
579 return true;
580 }
581
582 return false;
583 }
584
585 /*
586 * Create a GS rasterization shader. This is a hardware vertex shader that
587 * shades each rasterized output vertex in parallel.
588 */
589 static nir_shader *
590 agx_nir_create_gs_rast_shader(const nir_shader *gs, const nir_shader *libagx,
591 bool *side_effects_for_rast)
592 {
593 /* Don't muck up the original shader */
594 nir_shader *shader = nir_shader_clone(NULL, gs);
595
596 unsigned max_verts = output_vertex_id_stride(shader);
597
598 /* Turn into a vertex shader run only for rasterization. Transform feedback
599 * was handled in the prepass.
600 */
601 shader->info.stage = MESA_SHADER_VERTEX;
602 shader->info.has_transform_feedback_varyings = false;
603 memset(&shader->info.vs, 0, sizeof(shader->info.vs));
604 shader->xfb_info = NULL;
605
606 if (shader->info.name) {
607 shader->info.name = ralloc_asprintf(shader, "%s_rast", shader->info.name);
608 } else {
609 shader->info.name = "gs rast";
610 }
611
612 nir_builder b_ =
613 nir_builder_at(nir_before_impl(nir_shader_get_entrypoint(shader)));
614 nir_builder *b = &b_;
615
616 NIR_PASS(_, shader, strip_side_effects_from_rast, side_effects_for_rast);
617
618 /* Optimize out pointless gl_PointSize outputs. Bizarrely, these occur. */
619 if (shader->info.gs.output_primitive != MESA_PRIM_POINTS)
620 shader->info.outputs_written &= ~VARYING_BIT_PSIZ;
621
622 /* See calc_unrolled_index_id */
623 nir_def *raw_id = nir_load_vertex_id(b);
624 nir_def *output_id = nir_umod_imm(b, raw_id, max_verts);
625 nir_def *unrolled = nir_udiv_imm(b, raw_id, max_verts);
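/* unrolled == (instance ID << primitives_log2) + primitive ID, so the shift
 * and mask below recover the instance ID and primitive ID respectively.
 */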
626
627 nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
628 nir_def *instance_id = nir_ushr(b, unrolled, primitives_log2);
629 nir_def *primitive_id = nir_iand(
630 b, unrolled,
631 nir_iadd_imm(b, nir_ishl(b, nir_imm_int(b, 1), primitives_log2), -1));
632
633 struct lower_gs_rast_state rast_state = {
634 .instance_id = instance_id,
635 .primitive_id = primitive_id,
636 .output_id = output_id,
637 };
638
639 u_foreach_bit64(slot, shader->info.outputs_written) {
640 const char *slot_name =
641 gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
642
643 bool scalar = (slot == VARYING_SLOT_PSIZ) ||
644 (slot == VARYING_SLOT_LAYER) ||
645 (slot == VARYING_SLOT_VIEWPORT);
646 unsigned comps = scalar ? 1 : 4;
647
648 rast_state.outputs.outputs[slot] = nir_variable_create(
649 shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, comps),
650 ralloc_asprintf(shader, "%s-temp", slot_name));
651
652 rast_state.selected.outputs[slot] = nir_variable_create(
653 shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, comps),
654 ralloc_asprintf(shader, "%s-selected", slot_name));
655 }
656
657 nir_shader_intrinsics_pass(shader, lower_to_gs_rast,
658 nir_metadata_control_flow, &rast_state);
659
660 b->cursor = nir_after_impl(b->impl);
661
662 /* Forward each selected output to the rasterizer */
663 u_foreach_bit64(slot, shader->info.outputs_written) {
664 assert(rast_state.selected.outputs[slot] != NULL);
665 nir_def *value = nir_load_var(b, rast_state.selected.outputs[slot]);
666
667 /* We set NIR_COMPACT_ARRAYS so clip/cull distance needs to come all in
668 * DIST0. Undo the offset if we need to.
669 */
670 assert(slot != VARYING_SLOT_CULL_DIST1);
671 unsigned offset = 0;
672 if (slot == VARYING_SLOT_CLIP_DIST1)
673 offset = 1;
674
675 nir_store_output(b, value, nir_imm_int(b, offset),
676 .io_semantics.location = slot - offset,
677 .io_semantics.num_slots = 1,
678 .write_mask = nir_component_mask(value->num_components),
679 .src_type = nir_type_uint32);
680 }
681
682 /* It is legal to omit the point size write from the geometry shader when
683 * drawing points. In this case, the point size is implicitly 1.0. To
684 * implement, insert a synthetic `gl_PointSize = 1.0` write into the GS copy
685 * shader, if the GS does not export a point size while drawing points.
686 */
687 bool is_points = gs->info.gs.output_primitive == MESA_PRIM_POINTS;
688
689 if (!(shader->info.outputs_written & VARYING_BIT_PSIZ) && is_points) {
690 nir_store_output(b, nir_imm_float(b, 1.0), nir_imm_int(b, 0),
691 .io_semantics.location = VARYING_SLOT_PSIZ,
692 .io_semantics.num_slots = 1,
693 .write_mask = nir_component_mask(1),
694 .src_type = nir_type_float32);
695
696 shader->info.outputs_written |= VARYING_BIT_PSIZ;
697 }
698
699 nir_opt_idiv_const(shader, 16);
700
701 agx_preprocess_nir(shader, libagx);
702 return shader;
703 }
704
705 static nir_def *
706 previous_count(nir_builder *b, struct lower_gs_state *state, unsigned stream,
707 nir_def *unrolled_id, enum gs_counter counter)
708 {
709 assert(stream < MAX_VERTEX_STREAMS);
710 assert(counter < GS_NUM_COUNTERS);
711 int static_count = state->static_count[counter][stream];
712
713 if (static_count >= 0) {
714 /* If the number of outputted vertices per invocation is known statically,
715 * we can calculate the base.
716 */
717 return nir_imul_imm(b, unrolled_id, static_count);
718 } else {
719 /* Otherwise, we need to load from the prefix sum buffer. Note that the
720 * sums are inclusive, so index 0 is nonzero. This requires a little
721 * fixup here. We use a saturating unsigned subtraction so we don't read
722 * out-of-bounds for zero.
723 *
724 * TODO: Optimize this.
725 */
726 nir_def *prim_minus_1 = nir_usub_sat(b, unrolled_id, nir_imm_int(b, 1));
727 nir_def *addr =
728 load_count_address(b, state, prim_minus_1, stream, counter);
729
730 return nir_bcsel(b, nir_ieq_imm(b, unrolled_id, 0), nir_imm_int(b, 0),
731 nir_load_global_constant(b, addr, 4, 1, 32));
732 }
733 }
734
735 static nir_def *
736 previous_vertices(nir_builder *b, struct lower_gs_state *state, unsigned stream,
737 nir_def *unrolled_id)
738 {
739 return previous_count(b, state, stream, unrolled_id, GS_COUNTER_VERTICES);
740 }
741
742 static nir_def *
743 previous_primitives(nir_builder *b, struct lower_gs_state *state,
744 unsigned stream, nir_def *unrolled_id)
745 {
746 return previous_count(b, state, stream, unrolled_id, GS_COUNTER_PRIMITIVES);
747 }
748
749 static nir_def *
750 previous_xfb_primitives(nir_builder *b, struct lower_gs_state *state,
751 unsigned stream, nir_def *unrolled_id)
752 {
753 return previous_count(b, state, stream, unrolled_id,
754 GS_COUNTER_XFB_PRIMITIVES);
755 }
756
757 static void
758 lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr,
759 struct lower_gs_state *state)
760 {
761 assert((intr->intrinsic == nir_intrinsic_set_vertex_and_primitive_count ||
762 b->shader->info.gs.output_primitive != MESA_PRIM_POINTS) &&
763 "endprimitive for points should've been removed");
764
765 /* The GS is the last stage before rasterization, so if rasterization is
766 * discarded, we don't output an index buffer: nothing would read it. The
767 * index buffer is only for the rasterization stream.
768 */
769 unsigned stream = nir_intrinsic_stream_id(intr);
770 if (state->rasterizer_discard || stream != 0)
771 return;
772
773 libagx_end_primitive(
774 b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa,
775 intr->src[1].ssa, intr->src[2].ssa,
776 previous_vertices(b, state, 0, calc_unrolled_id(b)),
777 previous_primitives(b, state, 0, calc_unrolled_id(b)),
778 calc_unrolled_index_id(b),
779 nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS));
780 }
781
782 static unsigned
783 verts_in_output_prim(nir_shader *gs)
784 {
785 return mesa_vertices_per_prim(gs->info.gs.output_primitive);
786 }
787
788 static void
789 write_xfb(nir_builder *b, struct lower_gs_state *state, unsigned stream,
790 nir_def *index_in_strip, nir_def *prim_id_in_invocation)
791 {
792 struct nir_xfb_info *xfb = b->shader->xfb_info;
793 unsigned verts = verts_in_output_prim(b->shader);
794
795 /* Get the index of this primitive in the XFB buffer. That is, the base for
796 * this invocation for the stream plus the offset within this invocation.
797 */
798 nir_def *invocation_base =
799 previous_xfb_primitives(b, state, stream, calc_unrolled_id(b));
800
801 nir_def *prim_index = nir_iadd(b, invocation_base, prim_id_in_invocation);
802 nir_def *base_index = nir_imul_imm(b, prim_index, verts);
803
804 nir_def *xfb_prims = load_geometry_param(b, xfb_prims[stream]);
805 nir_push_if(b, nir_ult(b, prim_index, xfb_prims));
806
807 /* Write XFB for each output */
808 for (unsigned i = 0; i < xfb->output_count; ++i) {
809 nir_xfb_output_info output = xfb->outputs[i];
810
811 /* Only write to the selected stream */
812 if (xfb->buffer_to_stream[output.buffer] != stream)
813 continue;
814
815 unsigned buffer = output.buffer;
816 unsigned stride = xfb->buffers[buffer].stride;
817 unsigned count = util_bitcount(output.component_mask);
818
819 for (unsigned vert = 0; vert < verts; ++vert) {
820 /* We write out the vertices backwards, since 0 is the current
821 * emitted vertex (which is actually the last vertex).
822 *
823 * We handle NULL var for
824 * KHR-Single-GL44.enhanced_layouts.xfb_capture_struct.
825 */
826 unsigned v = (verts - 1) - vert;
827 nir_variable *var = state->outputs[output.location][v];
828 nir_def *value = var ? nir_load_var(b, var) : nir_undef(b, 4, 32);
829
830 /* In case output.component_mask contains invalid components, write
831 * out zeroes instead of blowing up validation.
832 *
833 * KHR-Single-GL44.enhanced_layouts.xfb_capture_inactive_output_component
834 * hits this.
835 */
836 value = nir_pad_vector_imm_int(b, value, 0, 4);
837
838 nir_def *rotated_vert = nir_imm_int(b, vert);
839 if (verts == 3) {
840 /* Map vertices for output so we get consistent winding order. For
841 * the primitive index, we use the index_in_strip. This is actually
842 * the vertex index in the strip, hence
843 * offset by 2 relative to the true primitive index (#2 for the
844 * first triangle in the strip, #3 for the second). That's ok
845 * because only the parity matters.
846 */
847 rotated_vert = libagx_map_vertex_in_tri_strip(
848 b, index_in_strip, rotated_vert,
849 nir_inot(b, nir_i2b(b, nir_load_provoking_last(b))));
850 }
851
852 nir_def *addr = libagx_xfb_vertex_address(
853 b, nir_load_geometry_param_buffer_agx(b), base_index, rotated_vert,
854 nir_imm_int(b, buffer), nir_imm_int(b, stride),
855 nir_imm_int(b, output.offset));
856
857 nir_store_global(b, addr, 4,
858 nir_channels(b, value, output.component_mask),
859 nir_component_mask(count));
860 }
861 }
862
863 nir_pop_if(b, NULL);
864 }
865
866 /* Handle transform feedback for a given emit_vertex_with_counter */
867 static void
868 lower_emit_vertex_xfb(nir_builder *b, nir_intrinsic_instr *intr,
869 struct lower_gs_state *state)
870 {
871 /* Transform feedback is written for each decomposed output primitive. Since
872 * we're writing strips, that means we output XFB for each vertex after the
873 * first complete primitive is formed.
874 */
875 unsigned first_prim = verts_in_output_prim(b->shader) - 1;
876 nir_def *index_in_strip = intr->src[1].ssa;
877
878 nir_push_if(b, nir_uge_imm(b, index_in_strip, first_prim));
879 {
880 write_xfb(b, state, nir_intrinsic_stream_id(intr), index_in_strip,
881 intr->src[3].ssa);
882 }
883 nir_pop_if(b, NULL);
884
885 /* Transform feedback writes out entire primitives during the emit_vertex. To
886 * do that, we store the values at all vertices in the strip in a little ring
887 * buffer. Index #0 is always the most recent primitive (so non-XFB code can
888 * just grab index #0 without any checking). Index #1 is the previous vertex,
889 * and index #2 is the vertex before that. Now that we've written XFB, since
890 * we've emitted a vertex we need to cycle the ringbuffer, freeing up index
891 * #0 for the next vertex that we are about to emit. We do that by copying
892 * the first n - 1 vertices forward one slot, which has to happen with a
893 * backwards copy implemented here.
894 *
895 * If we're lucky, all of these copies will be propagated away. If we're
896 * unlucky, this involves at most 2 copies per component per XFB output per
897 * vertex.
898 */
899 u_foreach_bit64(slot, b->shader->info.outputs_written) {
900 /* Note: if we're outputting points, verts_in_output_prim will be 1, so
901 * this loop will not execute. This is intended: points are self-contained
902 * primitives and do not need these copies.
903 */
904 for (int v = verts_in_output_prim(b->shader) - 1; v >= 1; --v) {
905 nir_def *value = nir_load_var(b, state->outputs[slot][v - 1]);
906
907 nir_store_var(b, state->outputs[slot][v], value,
908 nir_component_mask(value->num_components));
909 }
910 }
911 }
912
913 static bool
914 lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
915 {
916 b->cursor = nir_before_instr(&intr->instr);
917
918 switch (intr->intrinsic) {
919 case nir_intrinsic_set_vertex_and_primitive_count:
920 /* This instruction is mostly for the count shader, so just remove. But
921 * for points, we write the index buffer here so the rast shader can map.
922 */
923 if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) {
924 lower_end_primitive(b, intr, state);
925 }
926
927 break;
928
929 case nir_intrinsic_end_primitive_with_counter: {
930 unsigned min = verts_in_output_prim(b->shader);
931
932 /* We only write out complete primitives */
933 nir_push_if(b, nir_uge_imm(b, intr->src[1].ssa, min));
934 {
935 lower_end_primitive(b, intr, state);
936 }
937 nir_pop_if(b, NULL);
938 break;
939 }
940
941 case nir_intrinsic_emit_vertex_with_counter:
942 /* emit_vertex triggers transform feedback but is otherwise a no-op. */
943 if (b->shader->xfb_info)
944 lower_emit_vertex_xfb(b, intr, state);
945 break;
946
947 default:
948 return false;
949 }
950
951 nir_instr_remove(&intr->instr);
952 return true;
953 }
954
955 static bool
956 collect_components(nir_builder *b, nir_intrinsic_instr *intr, void *data)
957 {
958 uint8_t *counts = data;
959 if (intr->intrinsic != nir_intrinsic_store_output)
960 return false;
961
962 unsigned count = nir_intrinsic_component(intr) +
963 util_last_bit(nir_intrinsic_write_mask(intr));
964
965 unsigned loc =
966 nir_intrinsic_io_semantics(intr).location + nir_src_as_uint(intr->src[1]);
967
968 uint8_t *total_count = &counts[loc];
969
970 *total_count = MAX2(*total_count, count);
971 return true;
972 }
973
974 /*
975 * Create the pre-GS shader. This is a small 1x1x1 compute kernel that produces
976 * an indirect draw to rasterize the produced geometry, and updates transform
977 * feedback offsets and counters as applicable.
978 */
979 static nir_shader *
980 agx_nir_create_pre_gs(struct lower_gs_state *state, const nir_shader *libagx,
981 bool indexed, bool restart, struct nir_xfb_info *xfb,
982 unsigned vertices_per_prim, uint8_t streams,
983 unsigned invocations)
984 {
985 nir_builder b_ = nir_builder_init_simple_shader(
986 MESA_SHADER_COMPUTE, &agx_nir_options, "Pre-GS patch up");
987 nir_builder *b = &b_;
988
989 /* Load the number of primitives input to the GS */
990 nir_def *unrolled_in_prims = load_geometry_param(b, input_primitives);
991
992 /* Setup the draw from the rasterization stream (0). */
993 if (!state->rasterizer_discard) {
994 libagx_build_gs_draw(
995 b, nir_load_geometry_param_buffer_agx(b),
996 previous_vertices(b, state, 0, unrolled_in_prims),
997 restart ? previous_primitives(b, state, 0, unrolled_in_prims)
998 : nir_imm_int(b, 0));
999 }
1000
1001 /* Determine the number of primitives generated in each stream */
1002 nir_def *in_prims[MAX_VERTEX_STREAMS], *prims[MAX_VERTEX_STREAMS];
1003
1004 u_foreach_bit(i, streams) {
1005 in_prims[i] = previous_xfb_primitives(b, state, i, unrolled_in_prims);
1006 prims[i] = in_prims[i];
1007
1008 add_counter(b, load_geometry_param(b, prims_generated_counter[i]),
1009 prims[i]);
1010 }
1011
1012 if (xfb) {
1013 /* Write XFB addresses */
1014 nir_def *offsets[4] = {NULL};
1015 u_foreach_bit(i, xfb->buffers_written) {
1016 offsets[i] = libagx_setup_xfb_buffer(
1017 b, nir_load_geometry_param_buffer_agx(b), nir_imm_int(b, i));
1018 }
1019
1020 /* Now clamp to the number that XFB captures */
1021 for (unsigned i = 0; i < xfb->output_count; ++i) {
1022 nir_xfb_output_info output = xfb->outputs[i];
1023
1024 unsigned buffer = output.buffer;
1025 unsigned stream = xfb->buffer_to_stream[buffer];
1026 unsigned stride = xfb->buffers[buffer].stride;
1027 unsigned words_written = util_bitcount(output.component_mask);
1028 unsigned bytes_written = words_written * 4;
1029
1030 /* Primitive P will write up to (but not including) offset:
1031 *
1032 * xfb_offset + ((P - 1) * (verts_per_prim * stride))
1033 * + ((verts_per_prim - 1) * stride)
1034 * + output_offset
1035 * + output_size
1036 *
1037 * Given an XFB buffer of size xfb_size, we get the inequality:
1038 *
1039 * floor(P) <= (stride + xfb_size - xfb_offset - output_offset -
1040 * output_size) // (stride * verts_per_prim)
1041 */
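/* Hypothetical numbers for illustration: stride = 16 B, 3 verts/prim, a vec2
 * output at offset 4 (8 bytes written), xfb_size = 1000 and a current offset
 * of 0 give max_prims = (16 + 1000 - 4 - 8 - 0) / 48 = 20.
 */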
1042 nir_def *size = load_geometry_param(b, xfb_size[buffer]);
1043 size = nir_iadd_imm(b, size, stride - output.offset - bytes_written);
1044 size = nir_isub(b, size, offsets[buffer]);
1045 size = nir_imax(b, size, nir_imm_int(b, 0));
1046 nir_def *max_prims = nir_udiv_imm(b, size, stride * vertices_per_prim);
1047
1048 prims[stream] = nir_umin(b, prims[stream], max_prims);
1049 }
1050
1051 nir_def *any_overflow = nir_imm_false(b);
1052
1053 u_foreach_bit(i, streams) {
1054 nir_def *overflow = nir_ult(b, prims[i], in_prims[i]);
1055 any_overflow = nir_ior(b, any_overflow, overflow);
1056
1057 store_geometry_param(b, xfb_prims[i], prims[i]);
1058
1059 add_counter(b, load_geometry_param(b, xfb_overflow[i]),
1060 nir_b2i32(b, overflow));
1061
1062 add_counter(b, load_geometry_param(b, xfb_prims_generated_counter[i]),
1063 prims[i]);
1064 }
1065
1066 add_counter(b, load_geometry_param(b, xfb_any_overflow),
1067 nir_b2i32(b, any_overflow));
1068
1069 /* Update XFB counters */
1070 u_foreach_bit(i, xfb->buffers_written) {
1071 uint32_t prim_stride_B = xfb->buffers[i].stride * vertices_per_prim;
1072 unsigned stream = xfb->buffer_to_stream[i];
1073
1074 nir_def *off_ptr = load_geometry_param(b, xfb_offs_ptrs[i]);
1075 nir_def *size = nir_imul_imm(b, prims[stream], prim_stride_B);
1076 add_counter(b, off_ptr, size);
1077 }
1078 }
1079
1080 /* The geometry shader is invoked once per primitive (after unrolling
1081 * primitive restart). From the spec:
1082 *
1083 * In case of instanced geometry shaders (see section 11.3.4.2) the
1084 * geometry shader invocations count is incremented for each separate
1085 * instanced invocation.
1086 */
1087 add_counter(b,
1088 nir_load_stat_query_address_agx(
1089 b, .base = PIPE_STAT_QUERY_GS_INVOCATIONS),
1090 nir_imul_imm(b, unrolled_in_prims, invocations));
1091
1092 nir_def *emitted_prims = nir_imm_int(b, 0);
1093 u_foreach_bit(i, streams) {
1094 emitted_prims =
1095 nir_iadd(b, emitted_prims,
1096 previous_xfb_primitives(b, state, i, unrolled_in_prims));
1097 }
1098
1099 add_counter(
1100 b,
1101 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_GS_PRIMITIVES),
1102 emitted_prims);
1103
1104 /* Clipper queries are not well-defined, so we can emulate them in lots of
1105 * silly ways. We need the hardware counters to implement them properly. For
1106 * now, just consider all primitives emitted as passing through the clipper.
1107 * This satisfies spec text:
1108 *
1109 * The number of primitives that reach the primitive clipping stage.
1110 *
1111 * and
1112 *
1113 * If at least one vertex of the primitive lies inside the clipping
1114 * volume, the counter is incremented by one or more. Otherwise, the
1115 * counter is incremented by zero or more.
1116 */
1117 add_counter(
1118 b,
1119 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_PRIMITIVES),
1120 emitted_prims);
1121
1122 add_counter(
1123 b,
1124 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_INVOCATIONS),
1125 emitted_prims);
1126
1127 agx_preprocess_nir(b->shader, libagx);
1128 return b->shader;
1129 }
1130
1131 static bool
1132 rewrite_invocation_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1133 {
1134 if (intr->intrinsic != nir_intrinsic_load_invocation_id)
1135 return false;
1136
1137 b->cursor = nir_instr_remove(&intr->instr);
1138 nir_def_rewrite_uses(&intr->def, nir_u2uN(b, data, intr->def.bit_size));
1139 return true;
1140 }
1141
1142 /*
1143 * Geometry shader instancing allows a GS to run multiple times. The number of
1144 * times is statically known and small. It's easiest to turn this into a loop
1145 * inside the GS, to avoid the feature "leaking" outside and affecting e.g. the
1146 * counts.
1147 */
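/* For example, with invocations == 4 the original body is wrapped in a loop of
 * four iterations, and load_invocation_id is rewritten to the loop counter.
 */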
1148 static void
1149 agx_nir_lower_gs_instancing(nir_shader *gs)
1150 {
1151 unsigned nr_invocations = gs->info.gs.invocations;
1152 nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1153
1154 /* Each invocation can produce up to the shader-declared max_vertices, so
1155 * multiply it up for a proper bounds check. Emitting more than the declared
1156 * max_vertices per invocation results in undefined behaviour, so letting an
1157 * early invocation emit more than its per-invocation share is perfectly
1158 * cromulent behaviour.
1159 */
1160 gs->info.gs.vertices_out *= gs->info.gs.invocations;
1161
1162 /* Get the original function */
1163 nir_cf_list list;
1164 nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1165
1166 /* Create a builder for the wrapped function */
1167 nir_builder b = nir_builder_at(nir_after_block(nir_start_block(impl)));
1168
1169 nir_variable *i =
1170 nir_local_variable_create(impl, glsl_uintN_t_type(16), NULL);
1171 nir_store_var(&b, i, nir_imm_intN_t(&b, 0, 16), ~0);
1172 nir_def *index = NULL;
1173
1174 /* Create a loop in the wrapped function */
1175 nir_loop *loop = nir_push_loop(&b);
1176 {
1177 index = nir_load_var(&b, i);
1178 nir_push_if(&b, nir_uge_imm(&b, index, nr_invocations));
1179 {
1180 nir_jump(&b, nir_jump_break);
1181 }
1182 nir_pop_if(&b, NULL);
1183
1184 b.cursor = nir_cf_reinsert(&list, b.cursor);
1185 nir_store_var(&b, i, nir_iadd_imm(&b, index, 1), ~0);
1186
1187 /* Make sure we end the primitive between invocations. If the geometry
1188 * shader already ended the primitive, this will get optimized out.
1189 */
1190 nir_end_primitive(&b);
1191 }
1192 nir_pop_loop(&b, loop);
1193
1194 /* We've mucked about with control flow */
1195 nir_metadata_preserve(impl, nir_metadata_none);
1196
1197 /* Use the loop counter as the invocation ID each iteration */
1198 nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1199 nir_metadata_control_flow, index);
1200 }
1201
1202 static void
1203 link_libagx(nir_shader *nir, const nir_shader *libagx)
1204 {
1205 nir_link_shader_functions(nir, libagx);
1206 NIR_PASS(_, nir, nir_inline_functions);
1207 nir_remove_non_entrypoints(nir);
1208 NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
1209 NIR_PASS(_, nir, nir_opt_dce);
1210 NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
1211 nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared,
1212 glsl_get_cl_type_size_align);
1213 NIR_PASS(_, nir, nir_opt_deref);
1214 NIR_PASS(_, nir, nir_lower_vars_to_ssa);
1215 NIR_PASS(_, nir, nir_lower_explicit_io,
1216 nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
1217 nir_var_mem_global,
1218 nir_address_format_62bit_generic);
1219 }
1220
1221 bool
1222 agx_nir_lower_gs(nir_shader *gs, const nir_shader *libagx,
1223 bool rasterizer_discard, nir_shader **gs_count,
1224 nir_shader **gs_copy, nir_shader **pre_gs,
1225 enum mesa_prim *out_mode, unsigned *out_count_words)
1226 {
1227 /* Lower I/O as assumed by the rest of GS lowering */
1228 if (gs->xfb_info != NULL) {
1229 NIR_PASS(_, gs, nir_io_add_const_offset_to_base,
1230 nir_var_shader_in | nir_var_shader_out);
1231 NIR_PASS(_, gs, nir_io_add_intrinsic_xfb_info);
1232 }
1233
1234 NIR_PASS(_, gs, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
1235
1236 /* Collect output component counts so we can size the geometry output buffer
1237 * appropriately, instead of assuming everything is vec4.
1238 */
1239 uint8_t component_counts[NUM_TOTAL_VARYING_SLOTS] = {0};
1240 nir_shader_intrinsics_pass(gs, collect_components, nir_metadata_all,
1241 component_counts);
1242
1243 /* If geometry shader instancing is used, lower it away before linking
1244 * anything. Otherwise, smash the invocation ID to zero.
1245 */
1246 if (gs->info.gs.invocations != 1) {
1247 agx_nir_lower_gs_instancing(gs);
1248 } else {
1249 nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1250 nir_builder b = nir_builder_at(nir_before_impl(impl));
1251
1252 nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1253 nir_metadata_control_flow, nir_imm_int(&b, 0));
1254 }
1255
1256 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_inputs,
1257 nir_metadata_control_flow, NULL);
1258
1259 /* Lower geometry shader writes to contain all of the required counts, so we
1260 * know where in the various buffers we should write vertices.
1261 */
1262 NIR_PASS(_, gs, nir_lower_gs_intrinsics,
1263 nir_lower_gs_intrinsics_count_primitives |
1264 nir_lower_gs_intrinsics_per_stream |
1265 nir_lower_gs_intrinsics_count_vertices_per_primitive |
1266 nir_lower_gs_intrinsics_overwrite_incomplete |
1267 nir_lower_gs_intrinsics_always_end_primitive |
1268 nir_lower_gs_intrinsics_count_decomposed_primitives);
1269
1270 /* Clean up after all that lowering we did */
1271 bool progress = false;
1272 do {
1273 progress = false;
1274 NIR_PASS(progress, gs, nir_lower_var_copies);
1275 NIR_PASS(progress, gs, nir_lower_variable_initializers,
1276 nir_var_shader_temp);
1277 NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1278 NIR_PASS(progress, gs, nir_copy_prop);
1279 NIR_PASS(progress, gs, nir_opt_constant_folding);
1280 NIR_PASS(progress, gs, nir_opt_algebraic);
1281 NIR_PASS(progress, gs, nir_opt_cse);
1282 NIR_PASS(progress, gs, nir_opt_dead_cf);
1283 NIR_PASS(progress, gs, nir_opt_dce);
1284
1285 /* Unrolling lets us statically determine counts more often, which
1286 * otherwise would not be possible with multiple invocations even in the
1287 * simplest of cases.
1288 */
1289 NIR_PASS(progress, gs, nir_opt_loop_unroll);
1290 } while (progress);
1291
1292 /* If we know counts at compile-time we can simplify, so try to figure out
1293 * the counts statically.
1294 */
1295 struct lower_gs_state gs_state = {
1296 .rasterizer_discard = rasterizer_discard,
1297 };
1298
1299 nir_gs_count_vertices_and_primitives(
1300 gs, gs_state.static_count[GS_COUNTER_VERTICES],
1301 gs_state.static_count[GS_COUNTER_PRIMITIVES],
1302 gs_state.static_count[GS_COUNTER_XFB_PRIMITIVES], 4);
1303
1304 /* Anything we don't know statically will be tracked by the count buffer.
1305 * Determine the layout for it.
1306 */
1307 for (unsigned i = 0; i < MAX_VERTEX_STREAMS; ++i) {
1308 for (unsigned c = 0; c < GS_NUM_COUNTERS; ++c) {
1309 gs_state.count_index[i][c] =
1310 (gs_state.static_count[c][i] < 0) ? gs_state.count_stride_el++ : -1;
1311 }
1312 }
1313
1314 bool side_effects_for_rast = false;
1315 *gs_copy = agx_nir_create_gs_rast_shader(gs, libagx, &side_effects_for_rast);
1316
1317 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1318 nir_metadata_control_flow, NULL);
1319
1320 link_libagx(gs, libagx);
1321
1322 NIR_PASS(_, gs, nir_lower_idiv,
1323 &(const nir_lower_idiv_options){.allow_fp16 = true});
1324
1325 /* All those variables we created should've gone away by now */
1326 NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1327
1328 /* If there is any unknown count, we need a geometry count shader */
1329 if (gs_state.count_stride_el > 0)
1330 *gs_count = agx_nir_create_geometry_count_shader(gs, libagx, &gs_state);
1331 else
1332 *gs_count = NULL;
1333
1334 /* Geometry shader outputs are staged to temporaries */
1335 struct agx_lower_output_to_var_state state = {0};
1336
1337 u_foreach_bit64(slot, gs->info.outputs_written) {
1338 /* After enough optimizations, the shader metadata can go out of sync; fix it
1339 * up with our gathered info. Otherwise glsl_vector_type will assert-fail.
1340 */
1341 if (component_counts[slot] == 0) {
1342 gs->info.outputs_written &= ~BITFIELD64_BIT(slot);
1343 continue;
1344 }
1345
1346 const char *slot_name =
1347 gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
1348
1349 for (unsigned i = 0; i < MAX_PRIM_OUT_SIZE; ++i) {
1350 gs_state.outputs[slot][i] = nir_variable_create(
1351 gs, nir_var_shader_temp,
1352 glsl_vector_type(GLSL_TYPE_UINT, component_counts[slot]),
1353 ralloc_asprintf(gs, "%s-%u", slot_name, i));
1354 }
1355
1356 state.outputs[slot] = gs_state.outputs[slot][0];
1357 }
1358
1359 NIR_PASS(_, gs, nir_shader_instructions_pass, agx_lower_output_to_var,
1360 nir_metadata_control_flow, &state);
1361
1362 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_instr,
1363 nir_metadata_none, &gs_state);
1364
1365 /* Determine if we are guaranteed to rasterize at least one vertex, so that
1366 * we can strip the prepass of side effects knowing they will execute in the
1367 * rasterization shader.
1368 */
1369 bool rasterizes_at_least_one_vertex =
1370 !rasterizer_discard && gs_state.static_count[0][0] > 0;
1371
1372 /* Clean up after all that lowering we did */
1373 nir_lower_global_vars_to_local(gs);
1374 do {
1375 progress = false;
1376 NIR_PASS(progress, gs, nir_lower_var_copies);
1377 NIR_PASS(progress, gs, nir_lower_variable_initializers,
1378 nir_var_shader_temp);
1379 NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1380 NIR_PASS(progress, gs, nir_copy_prop);
1381 NIR_PASS(progress, gs, nir_opt_constant_folding);
1382 NIR_PASS(progress, gs, nir_opt_algebraic);
1383 NIR_PASS(progress, gs, nir_opt_cse);
1384 NIR_PASS(progress, gs, nir_opt_dead_cf);
1385 NIR_PASS(progress, gs, nir_opt_dce);
1386 NIR_PASS(progress, gs, nir_opt_loop_unroll);
1387
1388 } while (progress);
1389
1390 /* When rasterizing, we try to handle side effects sensibly. */
1391 if (rasterizes_at_least_one_vertex && side_effects_for_rast) {
1392 do {
1393 progress = false;
1394 NIR_PASS(progress, gs, nir_shader_intrinsics_pass,
1395 strip_side_effect_from_main, nir_metadata_control_flow, NULL);
1396
1397 NIR_PASS(progress, gs, nir_opt_dce);
1398 NIR_PASS(progress, gs, nir_opt_dead_cf);
1399 } while (progress);
1400 }
1401
1402 /* All those variables we created should've gone away by now */
1403 NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1404
1405 NIR_PASS(_, gs, nir_opt_sink, ~0);
1406 NIR_PASS(_, gs, nir_opt_move, ~0);
1407
1408 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1409 nir_metadata_control_flow, NULL);
1410
1411 /* Create auxiliary programs */
1412 *pre_gs = agx_nir_create_pre_gs(
1413 &gs_state, libagx, true, gs->info.gs.output_primitive != MESA_PRIM_POINTS,
1414 gs->xfb_info, verts_in_output_prim(gs), gs->info.gs.active_stream_mask,
1415 gs->info.gs.invocations);
1416
1417 /* Signal what primitive we want to draw the GS Copy VS with */
1418 *out_mode = gs->info.gs.output_primitive;
1419 *out_count_words = gs_state.count_stride_el;
1420 return true;
1421 }
1422
1423 /*
1424 * Vertex shaders (tessellation evaluation shaders) before a geometry shader run
1425 * as a dedicated compute prepass. They are invoked as (count, instances, 1).
1426 * Their linear ID is therefore (instance ID * num vertices) + vertex ID.
1427 *
1428 * This function lowers their vertex shader I/O to compute.
1429 *
1430 * Vertex ID becomes an index buffer pull (without applying the topology). Store
1431 * output becomes a store into the global vertex output buffer.
1432 */
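/* Illustrative example: a 32-bit store to VARYING_SLOT_VAR0, component 2, for
 * linear vertex i becomes a 4-byte global store at
 * libagx_vertex_output_address(buffer, outputs_written, i, VAR0) + 8.
 */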
1433 static bool
1434 lower_vs_before_gs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1435 {
1436 if (intr->intrinsic != nir_intrinsic_store_output)
1437 return false;
1438
1439 b->cursor = nir_instr_remove(&intr->instr);
1440 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
1441 nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
1442
1443 /* We inline the outputs_written because it's known at compile-time, even
1444 * with shader objects. This lets us constant fold a bit of address math.
1445 */
1446 nir_def *mask = nir_imm_int64(b, b->shader->info.outputs_written);
1447
1448 nir_def *buffer;
1449 nir_def *nr_verts;
1450 if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1451 buffer = nir_load_vs_output_buffer_agx(b);
1452 nr_verts =
1453 libagx_input_vertices(b, nir_load_input_assembly_buffer_agx(b));
1454 } else {
1455 assert(b->shader->info.stage == MESA_SHADER_TESS_EVAL);
1456
1457 /* Instancing is unrolled during tessellation so nr_verts is ignored. */
1458 nr_verts = nir_imm_int(b, 0);
1459 buffer = libagx_tes_buffer(b, nir_load_tess_param_buffer_agx(b));
1460 }
1461
1462 nir_def *linear_id = nir_iadd(b, nir_imul(b, load_instance_id(b), nr_verts),
1463 load_primitive_id(b));
1464
1465 nir_def *addr =
1466 libagx_vertex_output_address(b, buffer, mask, linear_id, location);
1467
1468 assert(nir_src_bit_size(intr->src[0]) == 32);
1469 addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
1470
1471 nir_store_global(b, addr, 4, intr->src[0].ssa,
1472 nir_intrinsic_write_mask(intr));
1473 return true;
1474 }
1475
1476 bool
1477 agx_nir_lower_vs_before_gs(struct nir_shader *vs,
1478 const struct nir_shader *libagx)
1479 {
1480 bool progress = false;
1481
1482 /* Lower vertex stores to memory stores */
1483 progress |= nir_shader_intrinsics_pass(vs, lower_vs_before_gs,
1484 nir_metadata_control_flow, NULL);
1485
1486 /* Link libagx, used in lower_vs_before_gs */
1487 if (progress)
1488 link_libagx(vs, libagx);
1489
1490 return progress;
1491 }
1492