1 /*
2  * Copyright © 2024 Imagination Technologies Ltd.
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 /**
8  * \file pco_nir.c
9  *
10  * \brief NIR-specific functions.
11  */
12 
13 #include "nir/nir_builder.h"
14 #include "pco.h"
15 #include "pco_internal.h"
16 
17 #include <stdio.h>
18 
19 /** Base/common SPIR-V to NIR options. */
20 static const struct spirv_to_nir_options pco_base_spirv_options = {
21    .environment = NIR_SPIRV_VULKAN,
22 };
23 
24 /** Base/common NIR options. */
25 static const nir_shader_compiler_options pco_base_nir_options = {
26    .fuse_ffma32 = true,
27 
28    .lower_fquantize2f16 = true,
29    .lower_layer_fs_input_to_sysval = true,
30    .compact_arrays = true,
31 };
32 
33 /**
34  * \brief Sets up device/core-specific SPIR-V to NIR options.
35  *
36  * \param[in] dev_info Device info.
37  * \param[out] spirv_options SPIR-V to NIR options.
38  */
pco_setup_spirv_options(const struct pvr_device_info * dev_info,struct spirv_to_nir_options * spirv_options)39 void pco_setup_spirv_options(const struct pvr_device_info *dev_info,
40                              struct spirv_to_nir_options *spirv_options)
41 {
42    memcpy(spirv_options, &pco_base_spirv_options, sizeof(*spirv_options));
43 
44    /* TODO: Device/core-dependent options. */
45    puts("finishme: pco_setup_spirv_options");
46 }
47 
48 /**
49  * \brief Sets up device/core-specific NIR options.
50  *
51  * \param[in] dev_info Device info.
52  * \param[out] nir_options NIR options.
53  */
pco_setup_nir_options(const struct pvr_device_info * dev_info,nir_shader_compiler_options * nir_options)54 void pco_setup_nir_options(const struct pvr_device_info *dev_info,
55                            nir_shader_compiler_options *nir_options)
56 {
57    memcpy(nir_options, &pco_base_nir_options, sizeof(*nir_options));
58 
59    /* TODO: Device/core-dependent options. */
60    puts("finishme: pco_setup_nir_options");
61 }
62 
63 /**
64  * \brief Runs pre-processing passes on a NIR shader.
65  *
66  * \param[in] ctx PCO compiler context.
67  * \param[in,out] nir NIR shader.
68  */
pco_preprocess_nir(pco_ctx * ctx,nir_shader * nir)69 void pco_preprocess_nir(pco_ctx *ctx, nir_shader *nir)
70 {
71    if (nir->info.internal)
72       NIR_PASS(_, nir, nir_lower_returns);
73 
74    NIR_PASS(_, nir, nir_lower_global_vars_to_local);
75    NIR_PASS(_, nir, nir_lower_vars_to_ssa);
76    NIR_PASS(_, nir, nir_split_var_copies);
77    NIR_PASS(_, nir, nir_lower_var_copies);
78    NIR_PASS(_, nir, nir_split_per_member_structs);
79    NIR_PASS(_,
80             nir,
81             nir_split_struct_vars,
82             nir_var_function_temp | nir_var_shader_temp);
83    NIR_PASS(_,
84             nir,
85             nir_split_array_vars,
86             nir_var_function_temp | nir_var_shader_temp);
87    NIR_PASS(_,
88             nir,
89             nir_lower_indirect_derefs,
90             nir_var_shader_in | nir_var_shader_out,
91             UINT32_MAX);
92 
93    NIR_PASS(_,
94             nir,
95             nir_remove_dead_variables,
96             nir_var_function_temp | nir_var_shader_temp,
97             NULL);
98    NIR_PASS(_, nir, nir_opt_dce);
99 
100    if (pco_should_print_nir(nir)) {
101       puts("after pco_preprocess_nir:");
102       nir_print_shader(nir, stdout);
103    }
104 }
105 
106 /**
107  * \brief Returns the GLSL type size.
108  *
109  * \param[in] type Type.
110  * \param[in] bindless Whether the access is bindless.
111  * \return The size.
112  */
glsl_type_size(const struct glsl_type * type,UNUSED bool bindless)113 static int glsl_type_size(const struct glsl_type *type, UNUSED bool bindless)
114 {
115    return glsl_count_attribute_slots(type, false);
116 }
117 
118 /**
119  * \brief Returns the vectorization with for a given instruction.
120  *
121  * \param[in] instr Instruction.
122  * \param[in] data User data.
123  * \return The vectorization width.
124  */
vectorize_filter(const nir_instr * instr,UNUSED const void * data)125 static uint8_t vectorize_filter(const nir_instr *instr, UNUSED const void *data)
126 {
127    if (instr->type == nir_instr_type_load_const)
128       return 1;
129 
130    if (instr->type != nir_instr_type_alu)
131       return 0;
132 
133    /* TODO */
134    nir_alu_instr *alu = nir_instr_as_alu(instr);
135    switch (alu->op) {
136    default:
137       break;
138    }
139 
140    /* Basic for now. */
141    return 2;
142 }
143 
144 /**
145  * \brief Filters for a varying position load_input in frag shaders.
146  *
147  * \param[in] instr Instruction.
148  * \param[in] data User data.
149  * \return True if the instruction was found.
150  */
frag_pos_filter(const nir_instr * instr,UNUSED const void * data)151 static bool frag_pos_filter(const nir_instr *instr, UNUSED const void *data)
152 {
153    assert(instr->type == nir_instr_type_intrinsic);
154 
155    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
156    if (intr->intrinsic != nir_intrinsic_load_input)
157       return false;
158 
159    return nir_intrinsic_io_semantics(intr).location == VARYING_SLOT_POS;
160 }
161 
162 /**
163  * \brief Lowers a NIR shader.
164  *
165  * \param[in] ctx PCO compiler context.
166  * \param[in,out] nir NIR shader.
167  * \param[in,out] data Shader data.
168  */
pco_lower_nir(pco_ctx * ctx,nir_shader * nir,pco_data * data)169 void pco_lower_nir(pco_ctx *ctx, nir_shader *nir, pco_data *data)
170 {
171    NIR_PASS(_,
172             nir,
173             nir_lower_io,
174             nir_var_shader_in | nir_var_shader_out,
175             glsl_type_size,
176             nir_lower_io_lower_64bit_to_32);
177 
178    NIR_PASS(_, nir, nir_opt_dce);
179    NIR_PASS(_, nir, nir_opt_constant_folding);
180    NIR_PASS(_,
181             nir,
182             nir_io_add_const_offset_to_base,
183             nir_var_shader_in | nir_var_shader_out);
184 
185    if (nir->info.stage == MESA_SHADER_FRAGMENT) {
186       NIR_PASS(_, nir, pco_nir_pfo, &data->fs);
187    } else if (nir->info.stage == MESA_SHADER_VERTEX) {
188       NIR_PASS(_, nir, pco_nir_pvi, &data->vs);
189    }
190 
191    /* TODO: this should happen in the linking stage to cull unused I/O. */
192    NIR_PASS(_,
193             nir,
194             nir_lower_io_to_scalar,
195             nir_var_shader_in | nir_var_shader_out,
196             NULL,
197             NULL);
198 
199    NIR_PASS(_, nir, nir_lower_vars_to_ssa);
200    NIR_PASS(_, nir, nir_opt_copy_prop_vars);
201    NIR_PASS(_, nir, nir_opt_dead_write_vars);
202    NIR_PASS(_, nir, nir_opt_combine_stores, nir_var_all);
203 
204    bool progress;
205    NIR_PASS(_, nir, nir_lower_alu);
206    NIR_PASS(_, nir, nir_lower_pack);
207    NIR_PASS(_, nir, nir_opt_algebraic);
208    do {
209       progress = false;
210 
211       NIR_PASS(progress, nir, nir_opt_algebraic_late);
212       NIR_PASS(_, nir, nir_opt_constant_folding);
213       NIR_PASS(_, nir, nir_lower_load_const_to_scalar);
214       NIR_PASS(_, nir, nir_copy_prop);
215       NIR_PASS(_, nir, nir_opt_dce);
216       NIR_PASS(_, nir, nir_opt_cse);
217    } while (progress);
218 
219    nir_variable_mode vec_modes = nir_var_shader_in;
220    /* Fragment shader needs scalar writes after pfo. */
221    if (nir->info.stage != MESA_SHADER_FRAGMENT)
222       vec_modes |= nir_var_shader_out;
223 
224    NIR_PASS(_, nir, nir_opt_vectorize_io, vec_modes);
225 
226    /* Special case for frag coords:
227     * - x,y come from (non-consecutive) special regs - always scalar.
228     * - z,w are iterated and driver will make sure they're consecutive.
229     *   - TODO: keep scalar for now, but add pass to vectorize.
230     */
231    if (nir->info.stage == MESA_SHADER_FRAGMENT) {
232       NIR_PASS(_,
233                nir,
234                nir_lower_io_to_scalar,
235                nir_var_shader_in,
236                frag_pos_filter,
237                NULL);
238    }
239 
240    NIR_PASS(_, nir, nir_lower_alu_to_scalar, NULL, NULL);
241 
242    do {
243       progress = false;
244 
245       NIR_PASS(progress, nir, nir_copy_prop);
246       NIR_PASS(progress, nir, nir_opt_dce);
247       NIR_PASS(progress, nir, nir_opt_cse);
248       NIR_PASS(progress, nir, nir_opt_constant_folding);
249       NIR_PASS(progress, nir, nir_opt_undef);
250    } while (progress);
251 
252    if (pco_should_print_nir(nir)) {
253       puts("after pco_lower_nir:");
254       nir_print_shader(nir, stdout);
255    }
256 }
257 
258 /**
259  * \brief Gather fragment shader data pass.
260  *
261  * \param[in] b NIR builder.
262  * \param[in] intr NIR intrinsic instruction.
263  * \param[in,out] cb_data Callback data.
264  * \return True if the shader was modified (always return false).
265  */
gather_fs_data_pass(UNUSED struct nir_builder * b,nir_intrinsic_instr * intr,void * cb_data)266 static bool gather_fs_data_pass(UNUSED struct nir_builder *b,
267                                 nir_intrinsic_instr *intr,
268                                 void *cb_data)
269 {
270    /* Check whether the shader accesses z/w. */
271    if (intr->intrinsic != nir_intrinsic_load_input)
272       return false;
273 
274    struct nir_io_semantics io_semantics = nir_intrinsic_io_semantics(intr);
275    if (io_semantics.location != VARYING_SLOT_POS)
276       return false;
277 
278    unsigned component = nir_intrinsic_component(intr);
279    unsigned chans = intr->def.num_components;
280 
281    pco_data *data = cb_data;
282 
283    data->fs.uses.z |= (component + chans > 2);
284    data->fs.uses.w |= (component + chans > 3);
285 
286    return false;
287 }
288 
289 /**
290  * \brief Gathers fragment shader data.
291  *
292  * \param[in] nir NIR shader.
293  * \param[in,out] data Shader data.
294  */
gather_fs_data(nir_shader * nir,pco_data * data)295 static void gather_fs_data(nir_shader *nir, pco_data *data)
296 {
297    nir_shader_intrinsics_pass(nir, gather_fs_data_pass, nir_metadata_all, data);
298 
299    /* If any inputs use smooth shading, then w is needed. */
300    if (!data->fs.uses.w) {
301       nir_foreach_shader_in_variable (var, nir) {
302          if (var->data.interpolation > INTERP_MODE_SMOOTH)
303             continue;
304 
305          data->fs.uses.w = true;
306          break;
307       }
308    }
309 }
310 
311 /**
312  * \brief Gathers shader data.
313  *
314  * \param[in] nir NIR shader.
315  * \param[in,out] data Shader data.
316  */
gather_data(nir_shader * nir,pco_data * data)317 static void gather_data(nir_shader *nir, pco_data *data)
318 {
319    switch (nir->info.stage) {
320    case MESA_SHADER_FRAGMENT:
321       return gather_fs_data(nir, data);
322 
323    case MESA_SHADER_VERTEX:
324       /* TODO */
325       break;
326 
327    default:
328       unreachable();
329    }
330 }
331 
332 /**
333  * \brief Runs post-processing passes on a NIR shader.
334  *
335  * \param[in] ctx PCO compiler context.
336  * \param[in,out] nir NIR shader.
337  * \param[in,out] data Shader data.
338  */
pco_postprocess_nir(pco_ctx * ctx,nir_shader * nir,pco_data * data)339 void pco_postprocess_nir(pco_ctx *ctx, nir_shader *nir, pco_data *data)
340 {
341    NIR_PASS(_, nir, nir_move_vec_src_uses_to_dest, false);
342 
343    /* Re-index everything. */
344    nir_foreach_function_with_impl (_, impl, nir) {
345       nir_index_blocks(impl);
346       nir_index_instrs(impl);
347       nir_index_ssa_defs(impl);
348    }
349 
350    nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
351 
352    gather_data(nir, data);
353 
354    if (pco_should_print_nir(nir)) {
355       puts("after pco_postprocess_nir:");
356       nir_print_shader(nir, stdout);
357    }
358 }
359 
360 /**
361  * \brief Performs linking optimizations on consecutive NIR shader stages.
362  *
363  * \param[in] ctx PCO compiler context.
364  * \param[in,out] producer NIR producer shader.
365  * \param[in,out] consumer NIR consumer shader.
366  */
pco_link_nir(pco_ctx * ctx,nir_shader * producer,nir_shader * consumer)367 void pco_link_nir(pco_ctx *ctx, nir_shader *producer, nir_shader *consumer)
368 {
369    /* TODO */
370    puts("finishme: pco_link_nir");
371 
372    if (pco_should_print_nir(producer)) {
373       puts("producer after pco_link_nir:");
374       nir_print_shader(producer, stdout);
375    }
376 
377    if (pco_should_print_nir(consumer)) {
378       puts("consumer after pco_link_nir:");
379       nir_print_shader(consumer, stdout);
380    }
381 }
382 
383 /**
384  * \brief Checks whether two varying variables are the same.
385  *
386  * \param[in] out_var The first varying being compared.
387  * \param[in] in_var The second varying being compared.
388  * \return True if the varyings match.
389  */
varyings_match(nir_variable * out_var,nir_variable * in_var)390 static bool varyings_match(nir_variable *out_var, nir_variable *in_var)
391 {
392    return in_var->data.location == out_var->data.location &&
393           in_var->data.location_frac == out_var->data.location_frac &&
394           in_var->type == out_var->type;
395 }
396 
397 /**
398  * \brief Performs reverse linking optimizations on consecutive NIR shader
399  * stages.
400  *
401  * \param[in] ctx PCO compiler context.
402  * \param[in,out] producer NIR producer shader.
403  * \param[in,out] consumer NIR consumer shader.
404  */
pco_rev_link_nir(pco_ctx * ctx,nir_shader * producer,nir_shader * consumer)405 void pco_rev_link_nir(pco_ctx *ctx, nir_shader *producer, nir_shader *consumer)
406 {
407    /* TODO */
408    puts("finishme: pco_rev_link_nir");
409 
410    /* Propagate back/adjust the interpolation qualifiers. */
411    nir_foreach_shader_in_variable (in_var, consumer) {
412       if (in_var->data.location == VARYING_SLOT_POS ||
413           in_var->data.location == VARYING_SLOT_PNTC) {
414          in_var->data.interpolation = INTERP_MODE_NOPERSPECTIVE;
415       } else if (in_var->data.interpolation == INTERP_MODE_NONE) {
416          in_var->data.interpolation = INTERP_MODE_SMOOTH;
417       }
418 
419       nir_foreach_shader_out_variable (out_var, producer) {
420          if (!varyings_match(out_var, in_var))
421             continue;
422 
423          out_var->data.interpolation = in_var->data.interpolation;
424          break;
425       }
426    }
427 
428    if (pco_should_print_nir(producer)) {
429       puts("producer after pco_rev_link_nir:");
430       nir_print_shader(producer, stdout);
431    }
432 
433    if (pco_should_print_nir(consumer)) {
434       puts("consumer after pco_rev_link_nir:");
435       nir_print_shader(consumer, stdout);
436    }
437 }
438