1 /*
2 * Copyright 2024 Advanced Micro Devices, Inc.
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7 /* This is a pre-link lowering and optimization pass that modifies the shader for the purpose
8 * of gathering accurate shader_info and determining hw registers. It should be run before
9 * linking passes and it doesn't produce AMD intrinsics that would break linking passes.
10 * Some of the options come from dynamic state.
11 */
12
13 #include "ac_nir.h"
14 #include "sid.h"
15 #include "nir_builder.h"
16 #include "nir_builtin_builder.h"
17
18 typedef struct {
19 const ac_nir_lower_ps_early_options *options;
20
21 nir_variable *persp_center;
22 nir_variable *persp_centroid;
23 nir_variable *persp_sample;
24 nir_variable *linear_center;
25 nir_variable *linear_centroid;
26 nir_variable *linear_sample;
27 bool lower_load_barycentric;
28
29 bool seen_color0_alpha;
30 } lower_ps_early_state;
31
32 static void
create_interp_param(nir_builder * b,lower_ps_early_state * s)33 create_interp_param(nir_builder *b, lower_ps_early_state *s)
34 {
35 if (s->options->force_persp_sample_interp) {
36 s->persp_center =
37 nir_local_variable_create(b->impl, glsl_vec_type(2), "persp_center");
38 }
39
40 if (s->options->force_persp_sample_interp ||
41 s->options->force_persp_center_interp) {
42 s->persp_centroid =
43 nir_local_variable_create(b->impl, glsl_vec_type(2), "persp_centroid");
44 }
45
46 if (s->options->force_persp_center_interp) {
47 s->persp_sample =
48 nir_local_variable_create(b->impl, glsl_vec_type(2), "persp_sample");
49 }
50
51 if (s->options->force_linear_sample_interp) {
52 s->linear_center =
53 nir_local_variable_create(b->impl, glsl_vec_type(2), "linear_center");
54 }
55
56 if (s->options->force_linear_sample_interp ||
57 s->options->force_linear_center_interp) {
58 s->linear_centroid =
59 nir_local_variable_create(b->impl, glsl_vec_type(2), "linear_centroid");
60 }
61
62 if (s->options->force_linear_center_interp) {
63 s->linear_sample =
64 nir_local_variable_create(b->impl, glsl_vec_type(2), "linear_sample");
65 }
66
67 s->lower_load_barycentric =
68 s->persp_center || s->persp_centroid || s->persp_sample ||
69 s->linear_center || s->linear_centroid || s->linear_sample;
70 }
71
72 static void
init_interp_param(nir_builder * b,lower_ps_early_state * s)73 init_interp_param(nir_builder *b, lower_ps_early_state *s)
74 {
75 b->cursor = nir_before_cf_list(&b->impl->body);
76
77 if (s->options->force_persp_sample_interp) {
78 nir_def *sample =
79 nir_load_barycentric_sample(b, 32, .interp_mode = INTERP_MODE_SMOOTH);
80 nir_store_var(b, s->persp_center, sample, 0x3);
81 nir_store_var(b, s->persp_centroid, sample, 0x3);
82 }
83
84 if (s->options->force_linear_sample_interp) {
85 nir_def *sample =
86 nir_load_barycentric_sample(b, 32, .interp_mode = INTERP_MODE_NOPERSPECTIVE);
87 nir_store_var(b, s->linear_center, sample, 0x3);
88 nir_store_var(b, s->linear_centroid, sample, 0x3);
89 }
90
91 if (s->options->force_persp_center_interp) {
92 nir_def *center =
93 nir_load_barycentric_pixel(b, 32, .interp_mode = INTERP_MODE_SMOOTH);
94 nir_store_var(b, s->persp_sample, center, 0x3);
95 nir_store_var(b, s->persp_centroid, center, 0x3);
96 }
97
98 if (s->options->force_linear_center_interp) {
99 nir_def *center =
100 nir_load_barycentric_pixel(b, 32, .interp_mode = INTERP_MODE_NOPERSPECTIVE);
101 nir_store_var(b, s->linear_sample, center, 0x3);
102 nir_store_var(b, s->linear_centroid, center, 0x3);
103 }
104 }
105
106 static bool
rewrite_ps_load_barycentric(nir_builder * b,nir_intrinsic_instr * intrin,lower_ps_early_state * s)107 rewrite_ps_load_barycentric(nir_builder *b, nir_intrinsic_instr *intrin, lower_ps_early_state *s)
108 {
109 enum glsl_interp_mode mode = nir_intrinsic_interp_mode(intrin);
110 nir_variable *var = NULL;
111
112 switch (mode) {
113 case INTERP_MODE_NONE:
114 case INTERP_MODE_SMOOTH:
115 switch (intrin->intrinsic) {
116 case nir_intrinsic_load_barycentric_pixel:
117 var = s->persp_center;
118 break;
119 case nir_intrinsic_load_barycentric_centroid:
120 var = s->persp_centroid;
121 break;
122 case nir_intrinsic_load_barycentric_sample:
123 var = s->persp_sample;
124 break;
125 default:
126 break;
127 }
128 break;
129
130 case INTERP_MODE_NOPERSPECTIVE:
131 switch (intrin->intrinsic) {
132 case nir_intrinsic_load_barycentric_pixel:
133 var = s->linear_center;
134 break;
135 case nir_intrinsic_load_barycentric_centroid:
136 var = s->linear_centroid;
137 break;
138 case nir_intrinsic_load_barycentric_sample:
139 var = s->linear_sample;
140 break;
141 default:
142 break;
143 }
144 break;
145
146 default:
147 break;
148 }
149
150 if (!var)
151 return false;
152
153 b->cursor = nir_before_instr(&intrin->instr);
154
155 nir_def *replacement = nir_load_var(b, var);
156 nir_def_replace(&intrin->def, replacement);
157 return true;
158 }
159
160 static bool
optimize_lower_ps_outputs(nir_builder * b,nir_intrinsic_instr * intrin,lower_ps_early_state * s)161 optimize_lower_ps_outputs(nir_builder *b, nir_intrinsic_instr *intrin, lower_ps_early_state *s)
162 {
163 unsigned slot = nir_intrinsic_io_semantics(intrin).location;
164
165 switch (slot) {
166 case FRAG_RESULT_DEPTH:
167 if (!s->options->kill_z)
168 return false;
169 nir_instr_remove(&intrin->instr);
170 return true;
171
172 case FRAG_RESULT_STENCIL:
173 if (!s->options->kill_stencil)
174 return false;
175 nir_instr_remove(&intrin->instr);
176 return true;
177
178 case FRAG_RESULT_SAMPLE_MASK:
179 if (!s->options->kill_samplemask)
180 return false;
181 nir_instr_remove(&intrin->instr);
182 return true;
183 }
184
185 unsigned writemask = nir_intrinsic_write_mask(intrin);
186 unsigned component = nir_intrinsic_component(intrin);
187 unsigned color_index = (slot >= FRAG_RESULT_DATA0 ? slot - FRAG_RESULT_DATA0 : 0) +
188 nir_intrinsic_io_semantics(intrin).dual_source_blend_index;
189 nir_def *value = intrin->src[0].ssa;
190 bool progress = false;
191
192 b->cursor = nir_before_instr(&intrin->instr);
193
194 /* Clamp color. */
195 if (s->options->clamp_color) {
196 value = nir_fsat(b, value);
197 progress = true;
198 }
199
200 /* Alpha test. */
201 if (color_index == 0 && s->options->alpha_func != COMPARE_FUNC_ALWAYS &&
202 (writemask << component) & BITFIELD_BIT(3)) {
203 assert(!s->seen_color0_alpha);
204 s->seen_color0_alpha = true;
205
206 if (s->options->alpha_func == COMPARE_FUNC_NEVER) {
207 nir_discard(b);
208 } else {
209 nir_def *ref = nir_load_alpha_reference_amd(b);
210 ref = nir_convert_to_bit_size(b, ref, nir_type_float, value->bit_size);
211 nir_def *alpha = s->options->alpha_test_alpha_to_one ?
212 nir_imm_floatN_t(b, 1, value->bit_size) :
213 nir_channel(b, value, 3 - component);
214 nir_def *cond = nir_compare_func(b, s->options->alpha_func, alpha, ref);
215 nir_discard_if(b, nir_inot(b, cond));
216 }
217 progress = true;
218 }
219
220 /* Trim the src according to the format and writemask. */
221 unsigned cb_shader_mask = ac_get_cb_shader_mask(s->options->spi_shader_col_format_hint);
222 unsigned format_mask;
223
224 if (slot == FRAG_RESULT_COLOR) {
225 /* cb_shader_mask is 0 for disabled color buffers, so combine all of them. */
226 format_mask = 0;
227 for (unsigned i = 0; i < 8; i++)
228 format_mask |= (cb_shader_mask >> (i * 4)) & 0xf;
229 } else {
230 format_mask = (cb_shader_mask >> (color_index * 4)) & 0xf;
231 }
232
233 if (s->options->keep_alpha_for_mrtz && color_index == 0)
234 format_mask |= BITFIELD_BIT(3);
235
236 writemask = (format_mask >> component) & writemask;
237 nir_intrinsic_set_write_mask(intrin, writemask);
238
239 /* Empty writemask. */
240 if (!writemask) {
241 nir_instr_remove(&intrin->instr);
242 return true;
243 }
244
245 /* Trim the src to the last bit of writemask. */
246 unsigned num_components = util_last_bit(writemask);
247
248 if (num_components != value->num_components) {
249 assert(num_components < value->num_components);
250 value = nir_trim_vector(b, value, num_components);
251 progress = true;
252 }
253
254 /* Replace disabled channels in a non-contiguous writemask with undef. */
255 if (!util_is_power_of_two_nonzero(writemask + 1)) {
256 u_foreach_bit(i, BITFIELD_MASK(num_components) & ~writemask) {
257 value = nir_vector_insert_imm(b, value, nir_undef(b, 1, value->bit_size), i);
258 progress = true;
259 }
260 }
261
262 if (progress && intrin->src[0].ssa != value) {
263 nir_src_rewrite(&intrin->src[0], value);
264 intrin->num_components = value->num_components;
265 } else {
266 assert(intrin->src[0].ssa == value);
267 }
268
269 return progress;
270 }
271
272 static bool
lower_ps_load_sample_mask_in(nir_builder * b,nir_intrinsic_instr * intrin,lower_ps_early_state * s)273 lower_ps_load_sample_mask_in(nir_builder *b, nir_intrinsic_instr *intrin, lower_ps_early_state *s)
274 {
275 /* Section 15.2.2 (Shader Inputs) of the OpenGL 4.5 (Core Profile) spec
276 * says:
277 *
278 * "When per-sample shading is active due to the use of a fragment
279 * input qualified by sample or due to the use of the gl_SampleID
280 * or gl_SamplePosition variables, only the bit for the current
281 * sample is set in gl_SampleMaskIn. When state specifies multiple
282 * fragment shader invocations for a given fragment, the sample
283 * mask for any single fragment shader invocation may specify a
284 * subset of the covered samples for the fragment. In this case,
285 * the bit corresponding to each covered sample will be set in
286 * exactly one fragment shader invocation."
287 *
288 * The samplemask loaded by hardware is always the coverage of the
289 * entire pixel/fragment, so mask bits out based on the sample ID.
290 */
291
292 b->cursor = nir_before_instr(&intrin->instr);
293
294 uint32_t ps_iter_mask = ac_get_ps_iter_mask(s->options->ps_iter_samples);
295 nir_def *sampleid = nir_load_sample_id(b);
296 nir_def *submask = nir_ishl(b, nir_imm_int(b, ps_iter_mask), sampleid);
297
298 nir_def *sample_mask = nir_load_sample_mask_in(b);
299 nir_def *replacement = nir_iand(b, sample_mask, submask);
300
301 nir_def_replace(&intrin->def, replacement);
302 return true;
303 }
304
305 static bool
lower_ps_intrinsic(nir_builder * b,nir_instr * instr,void * state)306 lower_ps_intrinsic(nir_builder *b, nir_instr *instr, void *state)
307 {
308 lower_ps_early_state *s = (lower_ps_early_state *)state;
309
310 if (instr->type != nir_instr_type_intrinsic)
311 return false;
312
313 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
314
315 switch (intrin->intrinsic) {
316 case nir_intrinsic_store_output:
317 return optimize_lower_ps_outputs(b, intrin, s);
318 case nir_intrinsic_load_barycentric_pixel:
319 case nir_intrinsic_load_barycentric_centroid:
320 case nir_intrinsic_load_barycentric_sample:
321 if (s->lower_load_barycentric)
322 return rewrite_ps_load_barycentric(b, intrin, s);
323 break;
324 case nir_intrinsic_load_sample_mask_in:
325 if (s->options->ps_iter_samples > 1)
326 return lower_ps_load_sample_mask_in(b, intrin, s);
327 break;
328 default:
329 break;
330 }
331
332 return false;
333 }
334
335 void
ac_nir_lower_ps_early(nir_shader * nir,const ac_nir_lower_ps_early_options * options)336 ac_nir_lower_ps_early(nir_shader *nir, const ac_nir_lower_ps_early_options *options)
337 {
338 assert(nir->info.stage == MESA_SHADER_FRAGMENT);
339 nir_function_impl *impl = nir_shader_get_entrypoint(nir);
340
341 nir_builder builder = nir_builder_create(impl);
342 nir_builder *b = &builder;
343
344 lower_ps_early_state state = {
345 .options = options,
346 };
347
348 create_interp_param(b, &state);
349
350 nir_shader_instructions_pass(nir, lower_ps_intrinsic,
351 nir_metadata_control_flow,
352 &state);
353
354 /* This must be after lower_ps_intrinsic. */
355 init_interp_param(b, &state);
356
357 /* Cleanup local variables, as RADV won't do this. */
358 if (state.lower_load_barycentric)
359 nir_lower_vars_to_ssa(nir);
360 }
361