1 /*
2 * Copyright © 2016 Red Hat.
3 * Copyright © 2016 Bas Nieuwenhuizen
4 * Copyright © 2023 Valve Corporation
5 *
6 * SPDX-License-Identifier: MIT
7 */
8
9 #include "ac_nir.h"
10 #include "nir.h"
11 #include "nir_builder.h"
12 #include "radv_device.h"
13 #include "radv_nir.h"
14 #include "radv_physical_device.h"
15 #include "radv_shader.h"
16
17 static int
type_size_vec4(const struct glsl_type * type,bool bindless)18 type_size_vec4(const struct glsl_type *type, bool bindless)
19 {
20 return glsl_count_attribute_slots(type, false);
21 }
22
/* Scalarize shader I/O in the modes given by "mask" before linking, then
 * clean up the shader so the subsequent varying-linking optimizations
 * (nir_link_opt_varyings) see the simplest possible code.
 */
void
radv_nir_lower_io_to_scalar_early(nir_shader *nir, nir_variable_mode mask)
{
   bool progress = false;

   NIR_PASS(progress, nir, nir_lower_io_to_scalar_early, mask);
   if (progress) {
      /* Optimize the new vector code and then remove dead vars */
      NIR_PASS(_, nir, nir_copy_prop);
      NIR_PASS(_, nir, nir_opt_shrink_vectors, true);

      if (mask & nir_var_shader_out) {
         /* Optimize swizzled movs of load_const for nir_link_opt_varyings's constant propagation. */
         NIR_PASS(_, nir, nir_opt_constant_folding);

         /* For nir_link_opt_varyings's duplicate input opt */
         NIR_PASS(_, nir, nir_opt_cse);
      }

      /* Run copy-propagation to help remove dead output variables (some shaders have useless copies
       * to/from an output), so compaction later will be more effective.
       *
       * This will have been done earlier but it might not have worked because the outputs were
       * vector.
       */
      if (nir->info.stage == MESA_SHADER_TESS_CTRL)
         NIR_PASS(_, nir, nir_opt_copy_prop_vars);

      /* Drop now-unused code and I/O variables made dead by scalarization. */
      NIR_PASS(_, nir, nir_opt_dce);
      NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_in | nir_var_shader_out, NULL);
   }
}
55
/* State for radv_recompute_fs_input_bases: FS input locations partitioned
 * into three disjoint bitmasks (bit N = VARYING_SLOT N), plus the popcounts
 * needed to offset the later groups.
 */
typedef struct {
   uint64_t always_per_vertex;         /* inputs_read minus per-primitive minus PRIMITIVE_ID/VIEWPORT */
   uint64_t potentially_per_primitive; /* inputs_read & (PRIMITIVE_ID | VIEWPORT) */
   uint64_t always_per_primitive;      /* per-primitive inputs minus PRIMITIVE_ID/VIEWPORT */
   unsigned num_always_per_vertex;     /* popcount of always_per_vertex */
   unsigned num_potentially_per_primitive; /* popcount of potentially_per_primitive */
} radv_recompute_fs_input_bases_state;
63
64 static bool
radv_recompute_fs_input_bases_callback(UNUSED nir_builder * b,nir_intrinsic_instr * intrin,void * data)65 radv_recompute_fs_input_bases_callback(UNUSED nir_builder *b, nir_intrinsic_instr *intrin, void *data)
66 {
67 const radv_recompute_fs_input_bases_state *s = (const radv_recompute_fs_input_bases_state *)data;
68
69 /* Filter possible FS input intrinsics */
70 switch (intrin->intrinsic) {
71 case nir_intrinsic_load_input:
72 case nir_intrinsic_load_per_primitive_input:
73 case nir_intrinsic_load_interpolated_input:
74 case nir_intrinsic_load_input_vertex:
75 break;
76 default:
77 return false;
78 }
79
80 const nir_io_semantics sem = nir_intrinsic_io_semantics(intrin);
81 const uint64_t location_bit = BITFIELD64_BIT(sem.location);
82 const uint64_t location_mask = BITFIELD64_MASK(sem.location);
83 const unsigned old_base = nir_intrinsic_base(intrin);
84 unsigned new_base = 0;
85
86 if (location_bit & s->always_per_vertex) {
87 new_base = util_bitcount64(s->always_per_vertex & location_mask);
88 } else if (location_bit & s->potentially_per_primitive) {
89 new_base = s->num_always_per_vertex;
90
91 switch (location_bit) {
92 case VARYING_BIT_VIEWPORT:
93 break;
94 case VARYING_BIT_PRIMITIVE_ID:
95 new_base += !!(s->potentially_per_primitive & VARYING_BIT_VIEWPORT);
96 break;
97 }
98 } else if (location_bit & s->always_per_primitive) {
99 new_base = s->num_always_per_vertex + s->num_potentially_per_primitive +
100 util_bitcount64(s->always_per_primitive & location_mask);
101 } else {
102 unreachable("invalid FS input");
103 }
104
105 if (new_base != old_base) {
106 nir_intrinsic_set_base(intrin, new_base);
107 return true;
108 }
109
110 return false;
111 }
112
113 bool
radv_recompute_fs_input_bases(nir_shader * nir)114 radv_recompute_fs_input_bases(nir_shader *nir)
115 {
116 const uint64_t always_per_vertex =
117 nir->info.inputs_read & ~nir->info.per_primitive_inputs & ~(VARYING_BIT_PRIMITIVE_ID | VARYING_BIT_VIEWPORT);
118
119 const uint64_t potentially_per_primitive = nir->info.inputs_read & (VARYING_BIT_PRIMITIVE_ID | VARYING_BIT_VIEWPORT);
120
121 const uint64_t always_per_primitive =
122 nir->info.inputs_read & nir->info.per_primitive_inputs & ~(VARYING_BIT_PRIMITIVE_ID | VARYING_BIT_VIEWPORT);
123
124 radv_recompute_fs_input_bases_state s = {
125 .always_per_vertex = always_per_vertex,
126 .potentially_per_primitive = potentially_per_primitive,
127 .always_per_primitive = always_per_primitive,
128 .num_always_per_vertex = util_bitcount64(always_per_vertex),
129 .num_potentially_per_primitive = util_bitcount64(potentially_per_primitive),
130 };
131
132 return nir_shader_intrinsics_pass(nir, radv_recompute_fs_input_bases_callback, nir_metadata_control_flow, &s);
133 }
134
135 void
radv_nir_lower_io(struct radv_device * device,nir_shader * nir)136 radv_nir_lower_io(struct radv_device *device, nir_shader *nir)
137 {
138 const struct radv_physical_device *pdev = radv_device_physical(device);
139
140 if (nir->info.stage == MESA_SHADER_VERTEX) {
141 NIR_PASS(_, nir, nir_lower_io, nir_var_shader_in, type_size_vec4, 0);
142 NIR_PASS(_, nir, nir_lower_io, nir_var_shader_out, type_size_vec4, nir_lower_io_lower_64bit_to_32);
143 } else {
144 NIR_PASS(_, nir, nir_lower_io, nir_var_shader_in | nir_var_shader_out, type_size_vec4,
145 nir_lower_io_lower_64bit_to_32 | nir_lower_io_use_interpolated_input_intrinsics);
146 }
147
148 /* This pass needs actual constants */
149 NIR_PASS(_, nir, nir_opt_constant_folding);
150
151 NIR_PASS(_, nir, nir_io_add_const_offset_to_base, nir_var_shader_in | nir_var_shader_out);
152
153 if (nir->xfb_info) {
154 NIR_PASS(_, nir, nir_io_add_intrinsic_xfb_info);
155
156 if (pdev->use_ngg_streamout) {
157 /* The total number of shader outputs is required for computing the pervertex LDS size for
158 * VS/TES when lowering NGG streamout.
159 */
160 nir_assign_io_var_locations(nir, nir_var_shader_out, &nir->num_outputs, nir->info.stage);
161 }
162 }
163
164 if (nir->info.stage == MESA_SHADER_FRAGMENT) {
165 /* Lower explicit input load intrinsics to sysvals for the layer ID. */
166 NIR_PASS(_, nir, nir_lower_system_values);
167
168 /* Recompute FS input intrinsic bases to assign a location to each FS input.
169 * The computed base will match the index of each input in SPI_PS_INPUT_CNTL_n.
170 */
171 radv_recompute_fs_input_bases(nir);
172 }
173
174 NIR_PASS_V(nir, nir_opt_dce);
175 NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_shader_in | nir_var_shader_out, NULL);
176 }
177
/* IO slot layout for stages that aren't linked. These are the driver-location
 * indices returned by radv_map_io_driver_location for non-patch varyings.
 */
enum {
   RADV_IO_SLOT_POS = 0,
   RADV_IO_SLOT_CLIP_DIST0,
   RADV_IO_SLOT_CLIP_DIST1,
   RADV_IO_SLOT_PSIZ,
   RADV_IO_SLOT_VAR0, /* 0..31 */
};
186
187 unsigned
radv_map_io_driver_location(unsigned semantic)188 radv_map_io_driver_location(unsigned semantic)
189 {
190 if ((semantic >= VARYING_SLOT_PATCH0 && semantic < VARYING_SLOT_TESS_MAX) ||
191 semantic == VARYING_SLOT_TESS_LEVEL_INNER || semantic == VARYING_SLOT_TESS_LEVEL_OUTER)
192 return ac_shader_io_get_unique_index_patch(semantic);
193
194 switch (semantic) {
195 case VARYING_SLOT_POS:
196 return RADV_IO_SLOT_POS;
197 case VARYING_SLOT_CLIP_DIST0:
198 return RADV_IO_SLOT_CLIP_DIST0;
199 case VARYING_SLOT_CLIP_DIST1:
200 return RADV_IO_SLOT_CLIP_DIST1;
201 case VARYING_SLOT_PSIZ:
202 return RADV_IO_SLOT_PSIZ;
203 default:
204 assert(semantic >= VARYING_SLOT_VAR0 && semantic <= VARYING_SLOT_VAR31);
205 return RADV_IO_SLOT_VAR0 + (semantic - VARYING_SLOT_VAR0);
206 }
207 }
208
209 bool
radv_nir_lower_io_to_mem(struct radv_device * device,struct radv_shader_stage * stage)210 radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_stage *stage)
211 {
212 const struct radv_physical_device *pdev = radv_device_physical(device);
213 const struct radv_shader_info *info = &stage->info;
214 ac_nir_map_io_driver_location map_input = info->inputs_linked ? NULL : radv_map_io_driver_location;
215 ac_nir_map_io_driver_location map_output = info->outputs_linked ? NULL : radv_map_io_driver_location;
216 nir_shader *nir = stage->nir;
217
218 if (nir->info.stage == MESA_SHADER_VERTEX) {
219 if (info->vs.as_ls) {
220 NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, map_output, pdev->info.gfx_level, info->vs.tcs_in_out_eq,
221 info->vs.tcs_inputs_via_temp, info->vs.tcs_inputs_via_lds);
222 return true;
223 } else if (info->vs.as_es) {
224 NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, map_output, pdev->info.gfx_level, info->esgs_itemsize, info->gs_inputs_read);
225 return true;
226 }
227 } else if (nir->info.stage == MESA_SHADER_TESS_CTRL) {
228 NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, map_input, pdev->info.gfx_level, info->vs.tcs_in_out_eq,
229 info->vs.tcs_inputs_via_temp, info->vs.tcs_inputs_via_lds);
230 NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, &info->tcs.info, map_output, pdev->info.gfx_level,
231 info->tcs.tes_inputs_read, info->tcs.tes_patch_inputs_read, info->wave_size);
232
233 return true;
234 } else if (nir->info.stage == MESA_SHADER_TESS_EVAL) {
235 NIR_PASS_V(nir, ac_nir_lower_tes_inputs_to_mem, map_input);
236
237 if (info->tes.as_es) {
238 NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, map_output, pdev->info.gfx_level, info->esgs_itemsize, info->gs_inputs_read);
239 }
240
241 return true;
242 } else if (nir->info.stage == MESA_SHADER_GEOMETRY) {
243 NIR_PASS_V(nir, ac_nir_lower_gs_inputs_to_mem, map_input, pdev->info.gfx_level, false);
244 return true;
245 } else if (nir->info.stage == MESA_SHADER_TASK) {
246 ac_nir_lower_task_outputs_to_mem(nir, AC_TASK_PAYLOAD_ENTRY_BYTES, pdev->task_info.num_entries,
247 info->cs.has_query);
248 return true;
249 } else if (nir->info.stage == MESA_SHADER_MESH) {
250 ac_nir_lower_mesh_inputs_to_mem(nir, AC_TASK_PAYLOAD_ENTRY_BYTES, pdev->task_info.num_entries);
251 return true;
252 }
253
254 return false;
255 }
256
257 static bool
radv_nir_lower_draw_id_to_zero_callback(struct nir_builder * b,nir_intrinsic_instr * intrin,UNUSED void * state)258 radv_nir_lower_draw_id_to_zero_callback(struct nir_builder *b, nir_intrinsic_instr *intrin, UNUSED void *state)
259 {
260 if (intrin->intrinsic != nir_intrinsic_load_draw_id)
261 return false;
262
263 nir_def *replacement = nir_imm_zero(b, intrin->def.num_components, intrin->def.bit_size);
264 nir_def_replace(&intrin->def, replacement);
265 nir_instr_free(&intrin->instr);
266
267 return true;
268 }
269
270 bool
radv_nir_lower_draw_id_to_zero(nir_shader * shader)271 radv_nir_lower_draw_id_to_zero(nir_shader *shader)
272 {
273 return nir_shader_intrinsics_pass(shader, radv_nir_lower_draw_id_to_zero_callback, nir_metadata_control_flow, NULL);
274 }
275