• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2023 Alyssa Rosenzweig
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "shaders/geometry.h"
7 #include "util/bitscan.h"
8 #include "util/macros.h"
9 #include "agx_nir_lower_gs.h"
10 #include "glsl_types.h"
11 #include "libagx_shaders.h"
12 #include "nir.h"
13 #include "nir_builder.h"
14 #include "nir_builder_opcodes.h"
15 #include "nir_intrinsics.h"
16 #include "nir_intrinsics_indices.h"
17 #include "shader_enums.h"
18 
/* Per-shader state threaded through the TCS lowering callbacks. */
struct tcs_state {
   /* Temporaries standing in for the inlined vertex shader's outputs */
   struct agx_lower_output_to_var_state vs_vars;

   /* Bitmask of VS output slots that the TCS actually reads */
   uint64_t vs_outputs_written;
};
23 
24 static nir_def *
tcs_patch_id(nir_builder * b)25 tcs_patch_id(nir_builder *b)
26 {
27    return nir_channel(b, nir_load_workgroup_id(b), 0);
28 }
29 
30 static nir_def *
tcs_instance_id(nir_builder * b)31 tcs_instance_id(nir_builder *b)
32 {
33    return nir_channel(b, nir_load_workgroup_id(b), 1);
34 }
35 
36 static nir_def *
tcs_unrolled_id(nir_builder * b)37 tcs_unrolled_id(nir_builder *b)
38 {
39    nir_def *stride = nir_channel(b, nir_load_num_workgroups(b), 0);
40 
41    return nir_iadd(b, nir_imul(b, tcs_instance_id(b), stride), tcs_patch_id(b));
42 }
43 
44 uint64_t
agx_tcs_per_vertex_outputs(const nir_shader * nir)45 agx_tcs_per_vertex_outputs(const nir_shader *nir)
46 {
47    return nir->info.outputs_written &
48           ~(VARYING_BIT_TESS_LEVEL_INNER | VARYING_BIT_TESS_LEVEL_OUTER |
49             VARYING_BIT_BOUNDING_BOX0 | VARYING_BIT_BOUNDING_BOX1);
50 }
51 
52 unsigned
agx_tcs_output_stride(const nir_shader * nir)53 agx_tcs_output_stride(const nir_shader *nir)
54 {
55    return libagx_tcs_out_stride(util_last_bit(nir->info.patch_outputs_written),
56                                 nir->info.tess.tcs_vertices_out,
57                                 agx_tcs_per_vertex_outputs(nir));
58 }
59 
60 static nir_def *
tcs_out_addr(nir_builder * b,nir_intrinsic_instr * intr,nir_def * vertex_id)61 tcs_out_addr(nir_builder *b, nir_intrinsic_instr *intr, nir_def *vertex_id)
62 {
63    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
64 
65    nir_def *offset = nir_get_io_offset_src(intr)->ssa;
66    nir_def *addr = libagx_tcs_out_address(
67       b, nir_load_tess_param_buffer_agx(b), tcs_unrolled_id(b), vertex_id,
68       nir_iadd_imm(b, offset, sem.location),
69       nir_imm_int(b, util_last_bit(b->shader->info.patch_outputs_written)),
70       nir_imm_int(b, b->shader->info.tess.tcs_vertices_out),
71       nir_imm_int64(b, agx_tcs_per_vertex_outputs(b->shader)));
72 
73    addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
74 
75    return addr;
76 }
77 
78 static nir_def *
lower_tes_load(nir_builder * b,nir_intrinsic_instr * intr)79 lower_tes_load(nir_builder *b, nir_intrinsic_instr *intr)
80 {
81    gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
82    nir_src *offset_src = nir_get_io_offset_src(intr);
83 
84    nir_def *vertex = nir_imm_int(b, 0);
85    nir_def *offset = offset_src ? offset_src->ssa : nir_imm_int(b, 0);
86 
87    if (intr->intrinsic == nir_intrinsic_load_per_vertex_input)
88       vertex = intr->src[0].ssa;
89 
90    nir_def *addr = libagx_tes_in_address(b, nir_load_tess_param_buffer_agx(b),
91                                          nir_load_vertex_id(b), vertex,
92                                          nir_iadd_imm(b, offset, location));
93 
94    if (nir_intrinsic_has_component(intr))
95       addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
96 
97    return nir_load_global_constant(b, addr, 4, intr->def.num_components,
98                                    intr->def.bit_size);
99 }
100 
101 static nir_def *
tcs_load_input(nir_builder * b,nir_intrinsic_instr * intr,struct tcs_state * state)102 tcs_load_input(nir_builder *b, nir_intrinsic_instr *intr,
103                struct tcs_state *state)
104 {
105    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
106 
107    nir_def *off = libagx_tcs_in_offset(
108       b, intr->src[0].ssa, nir_iadd_imm(b, intr->src[1].ssa, sem.location),
109       nir_imm_int64(b, state->vs_outputs_written));
110 
111    off = nir_iadd_imm(b, off, 4 * nir_intrinsic_component(intr));
112 
113    return nir_load_shared(b, intr->def.num_components, 32, off);
114 }
115 
/*
 * Lower a single TCS intrinsic. Returns the replacement value,
 * NIR_LOWER_INSTR_PROGRESS_REPLACE when the instruction is consumed without
 * producing a value (stores, barriers), or NULL to leave it untouched.
 */
static nir_def *
lower_tcs_impl(nir_builder *b, nir_intrinsic_instr *intr,
               struct tcs_state *state)
{
   switch (intr->intrinsic) {
   case nir_intrinsic_barrier:
      /* A patch fits in a subgroup, so the barrier is unnecessary. */
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_load_primitive_id:
      return tcs_patch_id(b);

   case nir_intrinsic_load_instance_id:
      return tcs_instance_id(b);

   case nir_intrinsic_load_invocation_id:
      return nir_channel(b, nir_load_local_invocation_id(b), 0);

   case nir_intrinsic_load_per_vertex_input:
      /* VS outputs were staged to shared memory; read them back from there */
      return tcs_load_input(b, intr, state);

   case nir_intrinsic_load_patch_vertices_in:
      return libagx_tcs_patch_vertices_in(b, nir_load_tess_param_buffer_agx(b));

   case nir_intrinsic_load_tess_level_outer_default:
      return libagx_tess_level_outer_default(b,
                                             nir_load_tess_param_buffer_agx(b));

   case nir_intrinsic_load_tess_level_inner_default:
      return libagx_tess_level_inner_default(b,
                                             nir_load_tess_param_buffer_agx(b));

   case nir_intrinsic_load_output: {
      /* Patch-level output: no vertex index, so pass undef */
      nir_def *addr = tcs_out_addr(b, intr, nir_undef(b, 1, 32));
      return nir_load_global(b, addr, 4, intr->def.num_components,
                             intr->def.bit_size);
   }

   case nir_intrinsic_load_per_vertex_output: {
      nir_def *addr = tcs_out_addr(b, intr, intr->src[0].ssa);
      return nir_load_global(b, addr, 4, intr->def.num_components,
                             intr->def.bit_size);
   }

   case nir_intrinsic_store_output: {
      nir_store_global(b, tcs_out_addr(b, intr, nir_undef(b, 1, 32)), 4,
                       intr->src[0].ssa, nir_intrinsic_write_mask(intr));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_store_per_vertex_output: {
      /* src[1] is the vertex index for per-vertex stores */
      nir_store_global(b, tcs_out_addr(b, intr, intr->src[1].ssa), 4,
                       intr->src[0].ssa, nir_intrinsic_write_mask(intr));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   default:
      return NULL;
   }
}
176 
177 static bool
lower_tcs(nir_builder * b,nir_intrinsic_instr * intr,void * data)178 lower_tcs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
179 {
180    b->cursor = nir_before_instr(&intr->instr);
181 
182    nir_def *repl = lower_tcs_impl(b, intr, data);
183    if (!repl)
184       return false;
185 
186    if (repl != NIR_LOWER_INSTR_PROGRESS_REPLACE)
187       nir_def_rewrite_uses(&intr->def, repl);
188 
189    nir_instr_remove(&intr->instr);
190    return true;
191 }
192 
/*
 * Link the libagx helper library into a shader and lower the result. The pass
 * order matters: helpers are inlined before derefs/variables are lowered, and
 * explicit I/O lowering (to the 62-bit generic address format) runs last.
 */
static void
link_libagx(nir_shader *nir, const nir_shader *libagx)
{
   nir_link_shader_functions(nir, libagx);
   NIR_PASS(_, nir, nir_inline_functions);
   nir_remove_non_entrypoints(nir);
   NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
   NIR_PASS(_, nir, nir_opt_dce);
   NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp,
            glsl_get_cl_type_size_align);
   NIR_PASS(_, nir, nir_opt_deref);
   NIR_PASS(_, nir, nir_lower_vars_to_ssa);
   NIR_PASS(_, nir, nir_lower_explicit_io,
            nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
               nir_var_mem_global,
            nir_address_format_62bit_generic);
}
210 
/*
 * Predicate the TCS so the merged shader works when input patch size > output
 * patch size: the entire original body is wrapped in
 * "if (invocation_id < tcs_vertices_out)" so surplus invocations are no-ops.
 * Always returns false (metadata is invalidated explicitly).
 */
static bool
agx_nir_predicate_tcs(nir_shader *tcs)
{
   nir_function_impl *entry = nir_shader_get_entrypoint(tcs);

   /* Pull the whole body out of the entrypoint... */
   nir_cf_list list;
   nir_cf_extract(&list, nir_before_impl(entry), nir_after_impl(entry));

   nir_builder b = nir_builder_at(nir_after_block(nir_start_block(entry)));
   nir_def *input_vtx_id = nir_load_invocation_id(&b);
   unsigned verts = tcs->info.tess.tcs_vertices_out;

   /* ...and reinsert it under the predicate */
   nir_push_if(&b, nir_ult_imm(&b, input_vtx_id, verts));
   {
      nir_cf_reinsert(&list, b.cursor);
   }
   nir_pop_if(&b, NULL);

   nir_metadata_preserve(entry, nir_metadata_none);
   return false;
}
235 
/*
 * Lower a TCS into a merged VS+TCS compute-style shader: the given vertex
 * shader is cloned into the TCS, run for each input-patch vertex with its
 * outputs staged to shared memory, and all TCS I/O is lowered to explicit
 * memory access via libagx helpers. Always returns true (shader modified).
 */
bool
agx_nir_lower_tcs(nir_shader *tcs, const nir_shader *vs,
                  const struct nir_shader *libagx, uint8_t index_size_B)
{
   agx_nir_predicate_tcs(tcs);

   nir_function_impl *tcs_entry = nir_shader_get_entrypoint(tcs);

   /* Link the vertex shader with the TCS. This assumes that all functions have
    * been inlined in the vertex shader.
    */
   nir_function_impl *vs_entry = nir_shader_get_entrypoint(vs);
   nir_function *vs_function = nir_function_create(tcs, "vertex");
   vs_function->impl = nir_function_impl_clone(tcs, vs_entry);
   vs_function->impl->function = vs_function;

   /* Vertex shader outputs are staged to temporaries; only track the VS
    * outputs that the TCS actually reads.
    */
   struct tcs_state state = {
      .vs_outputs_written = vs->info.outputs_written & tcs->info.inputs_read,
   };

   u_foreach_bit64(slot, vs->info.outputs_written) {
      const char *slot_name =
         gl_varying_slot_name_for_stage(slot, MESA_SHADER_VERTEX);

      state.vs_vars.outputs[slot] = nir_variable_create(
         tcs, nir_var_shader_temp, glsl_uvec4_type(), slot_name);
   }

   nir_function_instructions_pass(
      vs_function->impl, agx_lower_output_to_var,
      nir_metadata_block_index | nir_metadata_dominance, &state.vs_vars);

   /* Invoke the VS first for each vertex in the input patch */
   nir_builder b_ = nir_builder_at(nir_before_impl(tcs_entry));
   nir_builder *b = &b_;

   nir_def *input_vtx_id = nir_load_invocation_id(b);
   nir_push_if(b, nir_ult(b, input_vtx_id, nir_load_patch_vertices_in(b)));
   {
      nir_inline_function_impl(b, vs_function->impl, NULL, NULL);

      /* To handle cross-invocation VS output reads, dump everything in
       * shared local memory.
       *
       * TODO: Optimize to registers.
       */
      u_foreach_bit64(slot, state.vs_outputs_written) {
         nir_def *off =
            libagx_tcs_in_offset(b, input_vtx_id, nir_imm_int(b, slot),
                                 nir_imm_int64(b, state.vs_outputs_written));

         nir_store_shared(b, nir_load_var(b, state.vs_vars.outputs[slot]), off,
                          .write_mask = nir_component_mask(4));
      }
   }
   nir_pop_if(b, NULL);

   /* Clean up after inlining VS into TCS */
   exec_node_remove(&vs_function->node);
   nir_lower_global_vars_to_local(tcs);

   /* Lower I/A. TODO: Indirect multidraws */
   agx_nir_lower_index_buffer(tcs, index_size_B, true);

   /* Lower TCS outputs */
   nir_shader_intrinsics_pass(tcs, lower_tcs,
                              nir_metadata_block_index | nir_metadata_dominance,
                              &state);
   link_libagx(tcs, libagx);
   nir_metadata_preserve(b->impl, nir_metadata_none);
   return true;
}
309 
310 static nir_def *
lower_tes_impl(nir_builder * b,nir_intrinsic_instr * intr,void * data)311 lower_tes_impl(nir_builder *b, nir_intrinsic_instr *intr, void *data)
312 {
313    switch (intr->intrinsic) {
314    case nir_intrinsic_load_tess_coord_xy:
315       return libagx_load_tess_coord(b, nir_load_tess_param_buffer_agx(b),
316                                     nir_load_vertex_id(b));
317 
318    case nir_intrinsic_load_primitive_id:
319       return libagx_tes_patch_id(b, nir_load_tess_param_buffer_agx(b),
320                                  nir_load_vertex_id(b));
321 
322    case nir_intrinsic_load_input:
323    case nir_intrinsic_load_per_vertex_input:
324    case nir_intrinsic_load_tess_level_inner:
325    case nir_intrinsic_load_tess_level_outer:
326       return lower_tes_load(b, intr);
327 
328    case nir_intrinsic_load_patch_vertices_in:
329       return libagx_tes_patch_vertices_in(b, nir_load_tess_param_buffer_agx(b));
330 
331    default:
332       return NULL;
333    }
334 }
335 
336 static bool
lower_tes(nir_builder * b,nir_intrinsic_instr * intr,void * data)337 lower_tes(nir_builder *b, nir_intrinsic_instr *intr, void *data)
338 {
339    b->cursor = nir_before_instr(&intr->instr);
340    nir_def *repl = lower_tes_impl(b, intr, data);
341 
342    if (repl) {
343       nir_def_rewrite_uses(&intr->def, repl);
344       nir_instr_remove(&intr->instr);
345       return true;
346    } else {
347       return false;
348    }
349 }
350 
351 static int
glsl_type_size(const struct glsl_type * type,bool bindless)352 glsl_type_size(const struct glsl_type *type, bool bindless)
353 {
354    return glsl_count_attribute_slots(type, false);
355 }
356 
/*
 * Lower a TES into a hardware vertex shader: tess coordinates and tessellation
 * inputs are fetched through the parameter buffer, and the shader info is
 * rewritten to MESA_SHADER_VERTEX. Always returns true (shader modified).
 */
bool
agx_nir_lower_tes(nir_shader *tes, const nir_shader *libagx)
{
   /* Only triangle domains derive Z from the X/Y tess coordinates */
   nir_lower_tess_coord_z(
      tes, tes->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES);

   nir_shader_intrinsics_pass(
      tes, lower_tes, nir_metadata_block_index | nir_metadata_dominance, NULL);

   /* Points mode renders as points, make sure we write point size for the HW */
   if (tes->info.tess.point_mode &&
       !(tes->info.outputs_written & VARYING_BIT_PSIZ)) {

      nir_function_impl *impl = nir_shader_get_entrypoint(tes);
      nir_builder b = nir_builder_at(nir_after_impl(impl));

      /* Default point size of 1.0 appended at the end of the shader */
      nir_store_output(&b, nir_imm_float(&b, 1.0), nir_imm_int(&b, 0),
                       .io_semantics.location = VARYING_SLOT_PSIZ,
                       .write_mask = nir_component_mask(1), .range = 1);

      tes->info.outputs_written |= VARYING_BIT_PSIZ;
   }

   /* We lower to a HW VS, so update the shader info so the compiler does the
    * right thing.
    */
   tes->info.stage = MESA_SHADER_VERTEX;
   memset(&tes->info.vs, 0, sizeof(tes->info.vs));
   tes->info.vs.tes_agx = true;

   link_libagx(tes, libagx);
   nir_lower_idiv(tes, &(nir_lower_idiv_options){.allow_fp16 = true});
   nir_metadata_preserve(nir_shader_get_entrypoint(tes), nir_metadata_none);
   return true;
}
392