• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2023 Alyssa Rosenzweig
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "libagx/geometry.h"
7 #include "libagx/libagx.h"
8 #include "util/bitscan.h"
9 #include "util/macros.h"
10 #include "agx_nir_lower_gs.h"
11 #include "nir.h"
12 #include "nir_builder.h"
13 #include "nir_builder_opcodes.h"
14 #include "nir_intrinsics.h"
15 #include "nir_intrinsics_indices.h"
16 #include "shader_enums.h"
17 
18 static nir_def *
tcs_patch_id(nir_builder * b)19 tcs_patch_id(nir_builder *b)
20 {
21    return nir_channel(b, nir_load_workgroup_id(b), 0);
22 }
23 
24 static nir_def *
tcs_instance_id(nir_builder * b)25 tcs_instance_id(nir_builder *b)
26 {
27    return nir_channel(b, nir_load_workgroup_id(b), 1);
28 }
29 
30 static nir_def *
tcs_unrolled_id(nir_builder * b)31 tcs_unrolled_id(nir_builder *b)
32 {
33    return libagx_tcs_unrolled_id(b, nir_load_tess_param_buffer_agx(b),
34                                  nir_load_workgroup_id(b));
35 }
36 
37 uint64_t
agx_tcs_per_vertex_outputs(const nir_shader * nir)38 agx_tcs_per_vertex_outputs(const nir_shader *nir)
39 {
40    return nir->info.outputs_written &
41           ~(VARYING_BIT_TESS_LEVEL_INNER | VARYING_BIT_TESS_LEVEL_OUTER |
42             VARYING_BIT_BOUNDING_BOX0 | VARYING_BIT_BOUNDING_BOX1);
43 }
44 
45 unsigned
agx_tcs_output_stride(const nir_shader * nir)46 agx_tcs_output_stride(const nir_shader *nir)
47 {
48    return libagx_tcs_out_stride(util_last_bit(nir->info.patch_outputs_written),
49                                 nir->info.tess.tcs_vertices_out,
50                                 agx_tcs_per_vertex_outputs(nir));
51 }
52 
53 static nir_def *
tcs_out_addr(nir_builder * b,nir_intrinsic_instr * intr,nir_def * vertex_id)54 tcs_out_addr(nir_builder *b, nir_intrinsic_instr *intr, nir_def *vertex_id)
55 {
56    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
57 
58    nir_def *offset = nir_get_io_offset_src(intr)->ssa;
59    nir_def *addr = libagx_tcs_out_address(
60       b, nir_load_tess_param_buffer_agx(b), tcs_unrolled_id(b), vertex_id,
61       nir_iadd_imm(b, offset, sem.location),
62       nir_imm_int(b, util_last_bit(b->shader->info.patch_outputs_written)),
63       nir_imm_int(b, b->shader->info.tess.tcs_vertices_out),
64       nir_imm_int64(b, agx_tcs_per_vertex_outputs(b->shader)));
65 
66    addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
67 
68    return addr;
69 }
70 
71 static nir_def *
lower_tes_load(nir_builder * b,nir_intrinsic_instr * intr)72 lower_tes_load(nir_builder *b, nir_intrinsic_instr *intr)
73 {
74    gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
75    nir_src *offset_src = nir_get_io_offset_src(intr);
76 
77    nir_def *vertex = nir_imm_int(b, 0);
78    nir_def *offset = offset_src ? offset_src->ssa : nir_imm_int(b, 0);
79 
80    if (intr->intrinsic == nir_intrinsic_load_per_vertex_input)
81       vertex = intr->src[0].ssa;
82 
83    nir_def *addr = libagx_tes_in_address(b, nir_load_tess_param_buffer_agx(b),
84                                          nir_load_vertex_id(b), vertex,
85                                          nir_iadd_imm(b, offset, location));
86 
87    if (nir_intrinsic_has_component(intr))
88       addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
89 
90    return nir_load_global_constant(b, addr, 4, intr->def.num_components,
91                                    intr->def.bit_size);
92 }
93 
94 static nir_def *
tcs_load_input(nir_builder * b,nir_intrinsic_instr * intr)95 tcs_load_input(nir_builder *b, nir_intrinsic_instr *intr)
96 {
97    nir_def *base = nir_imul(
98       b, tcs_unrolled_id(b),
99       libagx_tcs_patch_vertices_in(b, nir_load_tess_param_buffer_agx(b)));
100    nir_def *vertex = nir_iadd(b, base, intr->src[0].ssa);
101 
102    return agx_load_per_vertex_input(b, intr, vertex);
103 }
104 
105 static nir_def *
lower_tcs_impl(nir_builder * b,nir_intrinsic_instr * intr)106 lower_tcs_impl(nir_builder *b, nir_intrinsic_instr *intr)
107 {
108    switch (intr->intrinsic) {
109    case nir_intrinsic_barrier:
110       /* A patch fits in a subgroup, so the barrier is unnecessary. */
111       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
112 
113    case nir_intrinsic_load_primitive_id:
114       return tcs_patch_id(b);
115 
116    case nir_intrinsic_load_instance_id:
117       return tcs_instance_id(b);
118 
119    case nir_intrinsic_load_invocation_id:
120       if (b->shader->info.tess.tcs_vertices_out == 1)
121          return nir_imm_int(b, 0);
122       else
123          return nir_channel(b, nir_load_local_invocation_id(b), 0);
124 
125    case nir_intrinsic_load_per_vertex_input:
126       return tcs_load_input(b, intr);
127 
128    case nir_intrinsic_load_patch_vertices_in:
129       return libagx_tcs_patch_vertices_in(b, nir_load_tess_param_buffer_agx(b));
130 
131    case nir_intrinsic_load_tess_level_outer_default:
132       return libagx_tess_level_outer_default(b,
133                                              nir_load_tess_param_buffer_agx(b));
134 
135    case nir_intrinsic_load_tess_level_inner_default:
136       return libagx_tess_level_inner_default(b,
137                                              nir_load_tess_param_buffer_agx(b));
138 
139    case nir_intrinsic_load_output: {
140       nir_def *addr = tcs_out_addr(b, intr, nir_undef(b, 1, 32));
141       return nir_load_global(b, addr, 4, intr->def.num_components,
142                              intr->def.bit_size);
143    }
144 
145    case nir_intrinsic_load_per_vertex_output: {
146       nir_def *addr = tcs_out_addr(b, intr, intr->src[0].ssa);
147       return nir_load_global(b, addr, 4, intr->def.num_components,
148                              intr->def.bit_size);
149    }
150 
151    case nir_intrinsic_store_output: {
152       /* Only vec2, make sure we can't overwrite */
153       assert(intr->src[0].ssa->num_components <= 2 ||
154              nir_intrinsic_io_semantics(intr).location !=
155                 VARYING_SLOT_TESS_LEVEL_INNER);
156 
157       nir_store_global(b, tcs_out_addr(b, intr, nir_undef(b, 1, 32)), 4,
158                        intr->src[0].ssa, nir_intrinsic_write_mask(intr));
159       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
160    }
161 
162    case nir_intrinsic_store_per_vertex_output: {
163       nir_store_global(b, tcs_out_addr(b, intr, intr->src[1].ssa), 4,
164                        intr->src[0].ssa, nir_intrinsic_write_mask(intr));
165       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
166    }
167 
168    default:
169       return NULL;
170    }
171 }
172 
173 static bool
lower_tcs(nir_builder * b,nir_intrinsic_instr * intr,void * data)174 lower_tcs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
175 {
176    b->cursor = nir_before_instr(&intr->instr);
177 
178    nir_def *repl = lower_tcs_impl(b, intr);
179    if (!repl)
180       return false;
181 
182    if (repl != NIR_LOWER_INSTR_PROGRESS_REPLACE)
183       nir_def_rewrite_uses(&intr->def, repl);
184 
185    nir_instr_remove(&intr->instr);
186    return true;
187 }
188 
189 static void
link_libagx(nir_shader * nir,const nir_shader * libagx)190 link_libagx(nir_shader *nir, const nir_shader *libagx)
191 {
192    nir_link_shader_functions(nir, libagx);
193    NIR_PASS(_, nir, nir_inline_functions);
194    nir_remove_non_entrypoints(nir);
195    NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
196    NIR_PASS(_, nir, nir_opt_dce);
197    NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp,
198             glsl_get_cl_type_size_align);
199    NIR_PASS(_, nir, nir_opt_deref);
200    NIR_PASS(_, nir, nir_lower_vars_to_ssa);
201    NIR_PASS(_, nir, nir_lower_explicit_io,
202             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
203                nir_var_mem_global,
204             nir_address_format_62bit_generic);
205 }
206 
207 bool
agx_nir_lower_tcs(nir_shader * tcs,const struct nir_shader * libagx)208 agx_nir_lower_tcs(nir_shader *tcs, const struct nir_shader *libagx)
209 {
210    nir_shader_intrinsics_pass(tcs, lower_tcs, nir_metadata_control_flow, NULL);
211 
212    link_libagx(tcs, libagx);
213    return true;
214 }
215 
216 static nir_def *
lower_tes_impl(nir_builder * b,nir_intrinsic_instr * intr,void * data)217 lower_tes_impl(nir_builder *b, nir_intrinsic_instr *intr, void *data)
218 {
219    switch (intr->intrinsic) {
220    case nir_intrinsic_load_tess_coord_xy:
221       return libagx_load_tess_coord(b, nir_load_tess_param_buffer_agx(b),
222                                     nir_load_vertex_id(b));
223 
224    case nir_intrinsic_load_primitive_id:
225       return libagx_tes_patch_id(b, nir_load_tess_param_buffer_agx(b),
226                                  nir_load_vertex_id(b));
227 
228    case nir_intrinsic_load_input:
229    case nir_intrinsic_load_per_vertex_input:
230    case nir_intrinsic_load_tess_level_inner:
231    case nir_intrinsic_load_tess_level_outer:
232       return lower_tes_load(b, intr);
233 
234    case nir_intrinsic_load_patch_vertices_in:
235       return libagx_tes_patch_vertices_in(b, nir_load_tess_param_buffer_agx(b));
236 
237    default:
238       return NULL;
239    }
240 }
241 
242 static bool
lower_tes(nir_builder * b,nir_intrinsic_instr * intr,void * data)243 lower_tes(nir_builder *b, nir_intrinsic_instr *intr, void *data)
244 {
245    b->cursor = nir_before_instr(&intr->instr);
246    nir_def *repl = lower_tes_impl(b, intr, data);
247 
248    if (repl) {
249       nir_def_replace(&intr->def, repl);
250       return true;
251    } else {
252       return false;
253    }
254 }
255 
256 static bool
lower_tes_indexing(nir_builder * b,nir_intrinsic_instr * intr,void * data)257 lower_tes_indexing(nir_builder *b, nir_intrinsic_instr *intr, void *data)
258 {
259    if (intr->intrinsic == nir_intrinsic_load_instance_id)
260       unreachable("todo");
261 
262    if (intr->intrinsic != nir_intrinsic_load_vertex_id)
263       return false;
264 
265    b->cursor = nir_before_instr(&intr->instr);
266    nir_def *p = nir_load_tess_param_buffer_agx(b);
267    nir_def *id = nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
268    nir_def_replace(&intr->def, libagx_load_tes_index(b, p, id));
269    return true;
270 }
271 
272 bool
agx_nir_lower_tes(nir_shader * tes,const nir_shader * libagx,bool to_hw_vs)273 agx_nir_lower_tes(nir_shader *tes, const nir_shader *libagx, bool to_hw_vs)
274 {
275    nir_lower_tess_coord_z(
276       tes, tes->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES);
277 
278    nir_shader_intrinsics_pass(tes, lower_tes, nir_metadata_control_flow, NULL);
279 
280    /* Points mode renders as points, make sure we write point size for the HW */
281    if (tes->info.tess.point_mode &&
282        !(tes->info.outputs_written & VARYING_BIT_PSIZ) && to_hw_vs) {
283 
284       nir_function_impl *impl = nir_shader_get_entrypoint(tes);
285       nir_builder b = nir_builder_at(nir_after_impl(impl));
286 
287       nir_store_output(&b, nir_imm_float(&b, 1.0), nir_imm_int(&b, 0),
288                        .io_semantics.location = VARYING_SLOT_PSIZ,
289                        .write_mask = nir_component_mask(1), .range = 1,
290                        .src_type = nir_type_float32);
291 
292       tes->info.outputs_written |= VARYING_BIT_PSIZ;
293    }
294 
295    if (to_hw_vs) {
296       /* We lower to a HW VS, so update the shader info so the compiler does the
297        * right thing.
298        */
299       tes->info.stage = MESA_SHADER_VERTEX;
300       memset(&tes->info.vs, 0, sizeof(tes->info.vs));
301       tes->info.vs.tes_agx = true;
302    } else {
303       /* If we're running as a compute shader, we need to load from the index
304        * buffer manually. Fortunately, this doesn't require a shader key:
305        * tess-as-compute always use U32 index buffers.
306        */
307       nir_shader_intrinsics_pass(tes, lower_tes_indexing,
308                                  nir_metadata_control_flow, NULL);
309    }
310 
311    link_libagx(tes, libagx);
312    nir_lower_idiv(tes, &(nir_lower_idiv_options){.allow_fp16 = true});
313    nir_metadata_preserve(nir_shader_get_entrypoint(tes), nir_metadata_none);
314    return true;
315 }
316