/*
 * Copyright 2023 Alyssa Rosenzweig
 * SPDX-License-Identifier: MIT
 */

#include "shaders/geometry.h"
#include "util/bitscan.h"
#include "util/macros.h"
#include "agx_nir_lower_gs.h"
#include "glsl_types.h"
#include "libagx_shaders.h"
#include "nir.h"
#include "nir_builder.h"
#include "nir_builder_opcodes.h"
#include "nir_intrinsics.h"
#include "nir_intrinsics_indices.h"
#include "shader_enums.h"

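/* State for lowering a merged VS/TCS. Vertex shader outputs are staged to
 * temporary variables (vs_vars) and then copied to shared local memory, so
 * the TCS can read the inputs of any invocation in the patch.
 */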
struct tcs_state {
   struct agx_lower_output_to_var_state vs_vars;
   uint64_t vs_outputs_written;
};

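/* The TCS is dispatched with one workgroup per patch along x and one
 * instance per workgroup along y, so the workgroup ID encodes both the patch
 * and instance indices.
 */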
static nir_def *
tcs_patch_id(nir_builder *b)
{
   return nir_channel(b, nir_load_workgroup_id(b), 0);
}

static nir_def *
tcs_instance_id(nir_builder *b)
{
   return nir_channel(b, nir_load_workgroup_id(b), 1);
}

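/* Flat patch index across all instances. Patches are laid out
 * instance-major, with num_workgroups.x patches per instance.
 */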
static nir_def *
tcs_unrolled_id(nir_builder *b)
{
   nir_def *stride = nir_channel(b, nir_load_num_workgroups(b), 0);

   return nir_iadd(b, nir_imul(b, tcs_instance_id(b), stride), tcs_patch_id(b));
}

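/* Tess levels and bounding boxes are per-patch rather than per-vertex, so
 * they are excluded from the per-vertex output mask.
 */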
uint64_t
agx_tcs_per_vertex_outputs(const nir_shader *nir)
{
   return nir->info.outputs_written &
          ~(VARYING_BIT_TESS_LEVEL_INNER | VARYING_BIT_TESS_LEVEL_OUTER |
            VARYING_BIT_BOUNDING_BOX0 | VARYING_BIT_BOUNDING_BOX1);
}

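/* Size of one patch's worth of TCS outputs, as computed by the libagx
 * helper from the patch output count, output vertex count, and per-vertex
 * output mask.
 */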
unsigned
agx_tcs_output_stride(const nir_shader *nir)
{
   return libagx_tcs_out_stride(util_last_bit(nir->info.patch_outputs_written),
                                nir->info.tess.tcs_vertices_out,
                                agx_tcs_per_vertex_outputs(nir));
}

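/* Address of a TCS output in the global tessellation output buffer.
 * vertex_id is undefined for per-patch outputs.
 */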
static nir_def *
tcs_out_addr(nir_builder *b, nir_intrinsic_instr *intr, nir_def *vertex_id)
{
   nir_io_semantics sem = nir_intrinsic_io_semantics(intr);

   nir_def *offset = nir_get_io_offset_src(intr)->ssa;
   nir_def *addr = libagx_tcs_out_address(
      b, nir_load_tess_param_buffer_agx(b), tcs_unrolled_id(b), vertex_id,
      nir_iadd_imm(b, offset, sem.location),
      nir_imm_int(b, util_last_bit(b->shader->info.patch_outputs_written)),
      nir_imm_int(b, b->shader->info.tess.tcs_vertices_out),
      nir_imm_int64(b, agx_tcs_per_vertex_outputs(b->shader)));

   addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);

   return addr;
}

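/* TES inputs are read back from the TCS output buffer in global memory.
 * Per-patch loads use a vertex index of zero.
 */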
static nir_def *
lower_tes_load(nir_builder *b, nir_intrinsic_instr *intr)
{
   gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
   nir_src *offset_src = nir_get_io_offset_src(intr);

   nir_def *vertex = nir_imm_int(b, 0);
   nir_def *offset = offset_src ? offset_src->ssa : nir_imm_int(b, 0);

   if (intr->intrinsic == nir_intrinsic_load_per_vertex_input)
      vertex = intr->src[0].ssa;

   nir_def *addr = libagx_tes_in_address(b, nir_load_tess_param_buffer_agx(b),
                                         nir_load_vertex_id(b), vertex,
                                         nir_iadd_imm(b, offset, location));

   if (nir_intrinsic_has_component(intr))
      addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);

   return nir_load_global_constant(b, addr, 4, intr->def.num_components,
                                   intr->def.bit_size);
}

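/* TCS inputs are VS outputs, staged through shared local memory by the
 * prologue built in agx_nir_lower_tcs.
 */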
static nir_def *
tcs_load_input(nir_builder *b, nir_intrinsic_instr *intr,
               struct tcs_state *state)
{
   nir_io_semantics sem = nir_intrinsic_io_semantics(intr);

   nir_def *off = libagx_tcs_in_offset(
      b, intr->src[0].ssa, nir_iadd_imm(b, intr->src[1].ssa, sem.location),
      nir_imm_int64(b, state->vs_outputs_written));

   off = nir_iadd_imm(b, off, 4 * nir_intrinsic_component(intr));

   return nir_load_shared(b, intr->def.num_components, 32, off);
}

static nir_def *
lower_tcs_impl(nir_builder *b, nir_intrinsic_instr *intr,
               struct tcs_state *state)
{
   switch (intr->intrinsic) {
   case nir_intrinsic_barrier:
      /* A patch fits in a subgroup, so the barrier is unnecessary. */
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_load_primitive_id:
      return tcs_patch_id(b);

   case nir_intrinsic_load_instance_id:
      return tcs_instance_id(b);

   case nir_intrinsic_load_invocation_id:
      return nir_channel(b, nir_load_local_invocation_id(b), 0);

   case nir_intrinsic_load_per_vertex_input:
      return tcs_load_input(b, intr, state);

   case nir_intrinsic_load_patch_vertices_in:
      return libagx_tcs_patch_vertices_in(b, nir_load_tess_param_buffer_agx(b));

   case nir_intrinsic_load_tess_level_outer_default:
      return libagx_tess_level_outer_default(b,
                                             nir_load_tess_param_buffer_agx(b));

   case nir_intrinsic_load_tess_level_inner_default:
      return libagx_tess_level_inner_default(b,
                                             nir_load_tess_param_buffer_agx(b));

   case nir_intrinsic_load_output: {
      nir_def *addr = tcs_out_addr(b, intr, nir_undef(b, 1, 32));
      return nir_load_global(b, addr, 4, intr->def.num_components,
                             intr->def.bit_size);
   }

   case nir_intrinsic_load_per_vertex_output: {
      nir_def *addr = tcs_out_addr(b, intr, intr->src[0].ssa);
      return nir_load_global(b, addr, 4, intr->def.num_components,
                             intr->def.bit_size);
   }

   case nir_intrinsic_store_output: {
      nir_store_global(b, tcs_out_addr(b, intr, nir_undef(b, 1, 32)), 4,
                       intr->src[0].ssa, nir_intrinsic_write_mask(intr));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_store_per_vertex_output: {
      nir_store_global(b, tcs_out_addr(b, intr, intr->src[1].ssa), 4,
                       intr->src[0].ssa, nir_intrinsic_write_mask(intr));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   default:
      return NULL;
   }
}

static bool
lower_tcs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   b->cursor = nir_before_instr(&intr->instr);

   nir_def *repl = lower_tcs_impl(b, intr, data);
   if (!repl)
      return false;

   if (repl != NIR_LOWER_INSTR_PROGRESS_REPLACE)
      nir_def_rewrite_uses(&intr->def, repl);

   nir_instr_remove(&intr->instr);
   return true;
}

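/* Link the libagx helper library into the shader, inline it, and lower the
 * resulting generic-pointer memory accesses.
 */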
static void
link_libagx(nir_shader *nir, const nir_shader *libagx)
{
   nir_link_shader_functions(nir, libagx);
   NIR_PASS(_, nir, nir_inline_functions);
   nir_remove_non_entrypoints(nir);
   NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
   NIR_PASS(_, nir, nir_opt_dce);
   NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp,
            glsl_get_cl_type_size_align);
   NIR_PASS(_, nir, nir_opt_deref);
   NIR_PASS(_, nir, nir_lower_vars_to_ssa);
   NIR_PASS(_, nir, nir_lower_explicit_io,
            nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
               nir_var_mem_global,
            nir_address_format_62bit_generic);
}

/*
 * Predicate the TCS so the merged shader works when the input patch size is
 * larger than the output patch size: only the first tcs_vertices_out
 * invocations may execute the TCS body.
 */
static bool
agx_nir_predicate_tcs(nir_shader *tcs)
{
   nir_function_impl *entry = nir_shader_get_entrypoint(tcs);
   nir_cf_list list;
   nir_cf_extract(&list, nir_before_impl(entry), nir_after_impl(entry));

   nir_builder b = nir_builder_at(nir_after_block(nir_start_block(entry)));
   nir_def *input_vtx_id = nir_load_invocation_id(&b);
   unsigned verts = tcs->info.tess.tcs_vertices_out;

   nir_push_if(&b, nir_ult_imm(&b, input_vtx_id, verts));
   {
      nir_cf_reinsert(&list, b.cursor);
   }
   nir_pop_if(&b, NULL);

   nir_metadata_preserve(entry, nir_metadata_none);
   return false;
}

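/*
 * Lower a TCS into the merged VS/TCS form: run the linked VS first for each
 * vertex in the input patch, stage its outputs through shared local memory,
 * then run the TCS body with its I/O rewritten to the global tessellation
 * buffers.
 */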
bool
agx_nir_lower_tcs(nir_shader *tcs, const nir_shader *vs,
                  const struct nir_shader *libagx, uint8_t index_size_B)
{
   agx_nir_predicate_tcs(tcs);

   nir_function_impl *tcs_entry = nir_shader_get_entrypoint(tcs);

   /* Link the vertex shader with the TCS. This assumes that all functions
    * have been inlined in the vertex shader.
    */
   nir_function_impl *vs_entry = nir_shader_get_entrypoint(vs);
   nir_function *vs_function = nir_function_create(tcs, "vertex");
   vs_function->impl = nir_function_impl_clone(tcs, vs_entry);
   vs_function->impl->function = vs_function;

   /* Vertex shader outputs are staged to temporaries */
   struct tcs_state state = {
      .vs_outputs_written = vs->info.outputs_written & tcs->info.inputs_read,
   };

   u_foreach_bit64(slot, vs->info.outputs_written) {
      const char *slot_name =
         gl_varying_slot_name_for_stage(slot, MESA_SHADER_VERTEX);

      state.vs_vars.outputs[slot] = nir_variable_create(
         tcs, nir_var_shader_temp, glsl_uvec4_type(), slot_name);
   }

   nir_function_instructions_pass(
      vs_function->impl, agx_lower_output_to_var,
      nir_metadata_block_index | nir_metadata_dominance, &state.vs_vars);

   /* Invoke the VS first for each vertex in the input patch */
   nir_builder b_ = nir_builder_at(nir_before_impl(tcs_entry));
   nir_builder *b = &b_;

   nir_def *input_vtx_id = nir_load_invocation_id(b);
   nir_push_if(b, nir_ult(b, input_vtx_id, nir_load_patch_vertices_in(b)));
   {
      nir_inline_function_impl(b, vs_function->impl, NULL, NULL);

      /* To handle cross-invocation VS output reads, dump everything in
       * shared local memory.
       *
       * TODO: Optimize to registers.
       */
      u_foreach_bit64(slot, state.vs_outputs_written) {
         nir_def *off =
            libagx_tcs_in_offset(b, input_vtx_id, nir_imm_int(b, slot),
                                 nir_imm_int64(b, state.vs_outputs_written));

         nir_store_shared(b, nir_load_var(b, state.vs_vars.outputs[slot]), off,
                          .write_mask = nir_component_mask(4));
      }
   }
   nir_pop_if(b, NULL);

   /* Clean up after inlining VS into TCS */
   exec_node_remove(&vs_function->node);
   nir_lower_global_vars_to_local(tcs);

   /* Lower input assembly. TODO: Indirect multidraws */
   agx_nir_lower_index_buffer(tcs, index_size_B, true);

   /* Lower TCS outputs */
   nir_shader_intrinsics_pass(tcs, lower_tcs,
                              nir_metadata_block_index | nir_metadata_dominance,
                              &state);
   link_libagx(tcs, libagx);
   nir_metadata_preserve(b->impl, nir_metadata_none);
   return true;
}

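/* The lowered TES runs as a hardware VS, so tess coords and the patch ID are
 * derived from the vertex ID via libagx helpers.
 */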
static nir_def *
lower_tes_impl(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   switch (intr->intrinsic) {
   case nir_intrinsic_load_tess_coord_xy:
      return libagx_load_tess_coord(b, nir_load_tess_param_buffer_agx(b),
                                    nir_load_vertex_id(b));

   case nir_intrinsic_load_primitive_id:
      return libagx_tes_patch_id(b, nir_load_tess_param_buffer_agx(b),
                                 nir_load_vertex_id(b));

   case nir_intrinsic_load_input:
   case nir_intrinsic_load_per_vertex_input:
   case nir_intrinsic_load_tess_level_inner:
   case nir_intrinsic_load_tess_level_outer:
      return lower_tes_load(b, intr);

   case nir_intrinsic_load_patch_vertices_in:
      return libagx_tes_patch_vertices_in(b, nir_load_tess_param_buffer_agx(b));

   default:
      return NULL;
   }
}

static bool
lower_tes(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   b->cursor = nir_before_instr(&intr->instr);
   nir_def *repl = lower_tes_impl(b, intr, data);

   if (repl) {
      nir_def_rewrite_uses(&intr->def, repl);
      nir_instr_remove(&intr->instr);
      return true;
   } else {
      return false;
   }
}

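/* Type-size callback counting attribute slots. The bindless parameter is
 * unused.
 */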
static int
glsl_type_size(const struct glsl_type *type, bool bindless)
{
   return glsl_count_attribute_slots(type, false);
}

bool
agx_nir_lower_tes(nir_shader *tes, const nir_shader *libagx)
{
   nir_lower_tess_coord_z(
      tes, tes->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES);

   nir_shader_intrinsics_pass(
      tes, lower_tes, nir_metadata_block_index | nir_metadata_dominance, NULL);

   /* Point mode renders as points; make sure we write a point size for the
    * hardware.
    */
   if (tes->info.tess.point_mode &&
       !(tes->info.outputs_written & VARYING_BIT_PSIZ)) {

      nir_function_impl *impl = nir_shader_get_entrypoint(tes);
      nir_builder b = nir_builder_at(nir_after_impl(impl));

      nir_store_output(&b, nir_imm_float(&b, 1.0), nir_imm_int(&b, 0),
                       .io_semantics.location = VARYING_SLOT_PSIZ,
                       .write_mask = nir_component_mask(1), .range = 1);

      tes->info.outputs_written |= VARYING_BIT_PSIZ;
   }

   /* We lower to a HW VS, so update the shader info so the compiler does the
    * right thing.
    */
   tes->info.stage = MESA_SHADER_VERTEX;
   memset(&tes->info.vs, 0, sizeof(tes->info.vs));
   tes->info.vs.tes_agx = true;

   link_libagx(tes, libagx);
   nir_lower_idiv(tes, &(nir_lower_idiv_options){.allow_fp16 = true});
   nir_metadata_preserve(nir_shader_get_entrypoint(tes), nir_metadata_none);
   return true;
}