/*
 * Copyright © 2017 Red Hat
 *
 * SPDX-License-Identifier: MIT
 */
#include "radv_shader_info.h"
#include "nir/nir.h"
#include "nir/nir_xfb_info.h"
#include "nir/radv_nir.h"
#include "radv_device.h"
#include "radv_physical_device.h"
#include "radv_pipeline_graphics.h"
#include "radv_shader.h"

#include "ac_nir.h"

static void
mark_sampler_desc(const nir_variable *var, struct radv_shader_info *info)
{
   info->desc_set_used_mask |= (1u << var->data.descriptor_set);
}

static bool
radv_use_vs_prolog(const nir_shader *nir,
                   const struct radv_graphics_state_key *gfx_state)
{
   return gfx_state->vs.has_prolog && nir->info.inputs_read;
}

static bool
radv_use_per_attribute_vb_descs(const nir_shader *nir,
                                const struct radv_graphics_state_key *gfx_state,
                                const struct radv_shader_stage_key *stage_key)
{
   return stage_key->vertex_robustness1 || radv_use_vs_prolog(nir, gfx_state);
}

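/* Record how the VS consumes a vertex input load: which vertex buffer
 * descriptors are needed, whether instance ID and base instance are required
 * for instance-rate inputs, and which input slots are used.
 */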
static void
gather_load_vs_input_info(const nir_shader *nir, const nir_intrinsic_instr *intrin, struct radv_shader_info *info,
                          const struct radv_graphics_state_key *gfx_state,
                          const struct radv_shader_stage_key *stage_key)
{
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   const unsigned location = io_sem.location;
   const unsigned component = nir_intrinsic_component(intrin);
   unsigned mask = nir_def_components_read(&intrin->def);
   mask = (intrin->def.bit_size == 64 ? util_widen_mask(mask, 2) : mask) << component;

   if (location >= VERT_ATTRIB_GENERIC0) {
      const unsigned generic_loc = location - VERT_ATTRIB_GENERIC0;

      if (gfx_state->vi.instance_rate_inputs & BITFIELD_BIT(generic_loc)) {
         info->vs.needs_instance_id = true;
         info->vs.needs_base_instance = true;
      }

      if (radv_use_per_attribute_vb_descs(nir, gfx_state, stage_key))
         info->vs.vb_desc_usage_mask |= BITFIELD_BIT(generic_loc);
      else
         info->vs.vb_desc_usage_mask |= BITFIELD_BIT(gfx_state->vi.vertex_attribute_bindings[generic_loc]);

      info->vs.input_slot_usage_mask |= BITFIELD_RANGE(generic_loc, io_sem.num_slots);
   }
}

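/* Record which FS inputs are read and how each mapped slot is shaded
 * (explicit/strict vertex fetch, float16 low/high, or float32 interpolation),
 * plus the clip/cull distance input mask.
 */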
static void
gather_load_fs_input_info(const nir_shader *nir, const nir_intrinsic_instr *intrin, struct radv_shader_info *info,
                          const struct radv_graphics_state_key *gfx_state)
{
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   const unsigned location = io_sem.location;
   const unsigned mapped_location = nir_intrinsic_base(intrin);
   const unsigned attrib_count = io_sem.num_slots;
   const unsigned component = nir_intrinsic_component(intrin);

   switch (location) {
   case VARYING_SLOT_CLIP_DIST0:
      info->ps.input_clips_culls_mask |= BITFIELD_RANGE(component, intrin->num_components);
      break;
   case VARYING_SLOT_CLIP_DIST1:
      info->ps.input_clips_culls_mask |= BITFIELD_RANGE(component, intrin->num_components) << 4;
      break;
   default:
      break;
   }

   const uint32_t mapped_mask = BITFIELD_RANGE(mapped_location, attrib_count);
   const bool per_primitive = nir->info.per_primitive_inputs & BITFIELD64_BIT(location);

   if (!per_primitive) {
      if (intrin->intrinsic == nir_intrinsic_load_input_vertex) {
         if (io_sem.interp_explicit_strict)
            info->ps.explicit_strict_shaded_mask |= mapped_mask;
         else
            info->ps.explicit_shaded_mask |= mapped_mask;
      } else if (intrin->intrinsic == nir_intrinsic_load_interpolated_input && intrin->def.bit_size == 16) {
         if (io_sem.high_16bits)
            info->ps.float16_hi_shaded_mask |= mapped_mask;
         else
            info->ps.float16_shaded_mask |= mapped_mask;
      } else if (intrin->intrinsic == nir_intrinsic_load_interpolated_input) {
         info->ps.float32_shaded_mask |= mapped_mask;
      }
   }

   if (location >= VARYING_SLOT_VAR0) {
      const uint32_t var_mask = BITFIELD_RANGE(location - VARYING_SLOT_VAR0, attrib_count);

      if (per_primitive)
         info->ps.input_per_primitive_mask |= var_mask;
      else
         info->ps.input_mask |= var_mask;
   }
}

static void
gather_intrinsic_load_input_info(const nir_shader *nir, const nir_intrinsic_instr *instr, struct radv_shader_info *info,
                                 const struct radv_graphics_state_key *gfx_state,
                                 const struct radv_shader_stage_key *stage_key)
{
   switch (nir->info.stage) {
   case MESA_SHADER_VERTEX:
      gather_load_vs_input_info(nir, instr, info, gfx_state, stage_key);
      break;
   case MESA_SHADER_FRAGMENT:
      gather_load_fs_input_info(nir, instr, info, gfx_state);
      break;
   default:
      break;
   }
}

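/* Record output stores: per-component output usage for VS/TES/GS, color
 * writes for FS, GS output streams, clip/cull distance masks, and whether
 * per-vertex VRS should be forced based on the written Pos.W value.
 */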
static void
gather_intrinsic_store_output_info(const nir_shader *nir, const nir_intrinsic_instr *instr,
                                   struct radv_shader_info *info, bool consider_force_vrs)
{
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(instr);
   const unsigned location = io_sem.location;
   const unsigned num_slots = io_sem.num_slots;
   const unsigned component = nir_intrinsic_component(instr);
   const unsigned write_mask = nir_intrinsic_write_mask(instr);
   uint8_t *output_usage_mask = NULL;

   switch (nir->info.stage) {
   case MESA_SHADER_VERTEX:
      output_usage_mask = info->vs.output_usage_mask;
      break;
   case MESA_SHADER_TESS_EVAL:
      output_usage_mask = info->tes.output_usage_mask;
      break;
   case MESA_SHADER_GEOMETRY:
      output_usage_mask = info->gs.output_usage_mask;
      break;
   case MESA_SHADER_FRAGMENT:
      if (location >= FRAG_RESULT_DATA0) {
         const unsigned fs_semantic = location + io_sem.dual_source_blend_index;
         info->ps.colors_written |= 0xfu << (4 * (fs_semantic - FRAG_RESULT_DATA0));

         if (fs_semantic == FRAG_RESULT_DATA0)
            info->ps.color0_written = write_mask;
      }
      break;
   default:
      break;
   }

   if (output_usage_mask) {
      for (unsigned i = 0; i < num_slots; i++) {
         output_usage_mask[location + i] |= ((write_mask >> (i * 4)) & 0xf) << component;
      }
   }

   if (consider_force_vrs && location == VARYING_SLOT_POS) {
      unsigned pos_w_chan = 3 - component;

      if (write_mask & BITFIELD_BIT(pos_w_chan)) {
         nir_scalar pos_w = nir_scalar_resolved(instr->src[0].ssa, pos_w_chan);
         /* Use coarse shading if the value of Pos.W can't be determined or if its value is != 1
          * (typical for non-GUI elements).
          */
         if (!nir_scalar_is_const(pos_w) || nir_scalar_as_uint(pos_w) != 0x3f800000u)
            info->force_vrs_per_vertex = true;
      }
   }

   if (nir->info.stage == MESA_SHADER_GEOMETRY) {
      const uint8_t gs_streams = nir_intrinsic_io_semantics(instr).gs_streams;
      info->gs.output_streams[location] |= gs_streams << (component * 2);
   }

   if ((location == VARYING_SLOT_CLIP_DIST0 || location == VARYING_SLOT_CLIP_DIST1) && !io_sem.no_sysval_output) {
      unsigned base = (location == VARYING_SLOT_CLIP_DIST1 ? 4 : 0) + component;
      unsigned clip_array_mask = BITFIELD_MASK(nir->info.clip_distance_array_size);
      info->outinfo.clip_dist_mask |= (write_mask << base) & clip_array_mask;
      info->outinfo.cull_dist_mask |= (write_mask << base) & ~clip_array_mask;
   }
}

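/* Track push constant usage: loads at a constant offset with 32-bit or wider
 * results that fit in MAX_PUSH_CONSTANTS_SIZE are added to
 * inline_push_constant_mask; any other load means not all push constants can
 * be inlined.
 */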
static void
gather_push_constant_info(const nir_shader *nir, const nir_intrinsic_instr *instr, struct radv_shader_info *info)
{
   info->loads_push_constants = true;

   if (nir_src_is_const(instr->src[0]) && instr->def.bit_size >= 32) {
      uint32_t start = (nir_intrinsic_base(instr) + nir_src_as_uint(instr->src[0])) / 4u;
      uint32_t size = instr->num_components * (instr->def.bit_size / 32u);

      if (start + size <= (MAX_PUSH_CONSTANTS_SIZE / 4u)) {
         info->inline_push_constant_mask |= u_bit_consecutive64(start, size);
         return;
      }
   }

   info->can_inline_all_push_constants = false;
}

static void
gather_intrinsic_info(const nir_shader *nir, const nir_intrinsic_instr *instr, struct radv_shader_info *info,
                      const struct radv_graphics_state_key *gfx_state, const struct radv_shader_stage_key *stage_key,
                      bool consider_force_vrs)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_barycentric_sample:
   case nir_intrinsic_load_barycentric_pixel:
   case nir_intrinsic_load_barycentric_centroid:
   case nir_intrinsic_load_barycentric_at_sample:
   case nir_intrinsic_load_barycentric_at_offset: {
      enum glsl_interp_mode mode = nir_intrinsic_interp_mode(instr);
      switch (mode) {
      case INTERP_MODE_SMOOTH:
      case INTERP_MODE_NONE:
         if (instr->intrinsic == nir_intrinsic_load_barycentric_pixel ||
             instr->intrinsic == nir_intrinsic_load_barycentric_at_sample ||
             instr->intrinsic == nir_intrinsic_load_barycentric_at_offset)
            info->ps.reads_persp_center = true;
         else if (instr->intrinsic == nir_intrinsic_load_barycentric_centroid)
            info->ps.reads_persp_centroid = true;
         else if (instr->intrinsic == nir_intrinsic_load_barycentric_sample)
            info->ps.reads_persp_sample = true;
         break;
      case INTERP_MODE_NOPERSPECTIVE:
         if (instr->intrinsic == nir_intrinsic_load_barycentric_pixel ||
             instr->intrinsic == nir_intrinsic_load_barycentric_at_sample ||
             instr->intrinsic == nir_intrinsic_load_barycentric_at_offset)
            info->ps.reads_linear_center = true;
         else if (instr->intrinsic == nir_intrinsic_load_barycentric_centroid)
            info->ps.reads_linear_centroid = true;
         else if (instr->intrinsic == nir_intrinsic_load_barycentric_sample)
            info->ps.reads_linear_sample = true;
         break;
      default:
         break;
      }
      if (instr->intrinsic == nir_intrinsic_load_barycentric_at_sample)
         info->ps.needs_sample_positions = true;
      break;
   }
   case nir_intrinsic_load_provoking_vtx_amd:
      info->ps.load_provoking_vtx = true;
      break;
   case nir_intrinsic_load_sample_positions_amd:
      info->ps.needs_sample_positions = true;
      break;
   case nir_intrinsic_load_rasterization_primitive_amd:
      info->ps.load_rasterization_prim = true;
      break;
   case nir_intrinsic_load_local_invocation_id:
   case nir_intrinsic_load_workgroup_id: {
      unsigned mask = nir_def_components_read(&instr->def);
      while (mask) {
         unsigned i = u_bit_scan(&mask);

         if (instr->intrinsic == nir_intrinsic_load_workgroup_id)
            info->cs.uses_block_id[i] = true;
         else
            info->cs.uses_thread_id[i] = true;
      }
      break;
   }
   case nir_intrinsic_load_pixel_coord:
      info->ps.reads_pixel_coord = true;
      break;
   case nir_intrinsic_load_frag_coord:
      info->ps.reads_frag_coord_mask |= nir_def_components_read(&instr->def);
      break;
   case nir_intrinsic_load_sample_pos:
      info->ps.reads_sample_pos_mask |= nir_def_components_read(&instr->def);
      break;
   case nir_intrinsic_load_push_constant:
      gather_push_constant_info(nir, instr, info);
      break;
   case nir_intrinsic_vulkan_resource_index:
      info->desc_set_used_mask |= (1u << nir_intrinsic_desc_set(instr));
      break;
   case nir_intrinsic_image_deref_load:
   case nir_intrinsic_image_deref_sparse_load:
   case nir_intrinsic_image_deref_store:
   case nir_intrinsic_image_deref_atomic:
   case nir_intrinsic_image_deref_atomic_swap:
   case nir_intrinsic_image_deref_size:
   case nir_intrinsic_image_deref_samples: {
      nir_variable *var = nir_deref_instr_get_variable(nir_instr_as_deref(instr->src[0].ssa->parent_instr));
      mark_sampler_desc(var, info);
      break;
   }
   case nir_intrinsic_load_input:
   case nir_intrinsic_load_per_primitive_input:
   case nir_intrinsic_load_interpolated_input:
   case nir_intrinsic_load_input_vertex:
      gather_intrinsic_load_input_info(nir, instr, info, gfx_state, stage_key);
      break;
   case nir_intrinsic_store_output:
   case nir_intrinsic_store_per_vertex_output:
      gather_intrinsic_store_output_info(nir, instr, info, consider_force_vrs);
      break;
   case nir_intrinsic_bvh64_intersect_ray_amd:
      info->cs.uses_rt = true;
      break;
   case nir_intrinsic_load_poly_line_smooth_enabled:
      info->ps.needs_poly_line_smooth = true;
      break;
   case nir_intrinsic_begin_invocation_interlock:
      info->ps.pops = true;
      break;
   default:
      break;
   }
}

static void
gather_tex_info(const nir_shader *nir, const nir_tex_instr *instr, struct radv_shader_info *info)
{
   for (unsigned i = 0; i < instr->num_srcs; i++) {
      switch (instr->src[i].src_type) {
      case nir_tex_src_texture_deref:
         mark_sampler_desc(nir_deref_instr_get_variable(nir_src_as_deref(instr->src[i].src)), info);
         break;
      case nir_tex_src_sampler_deref:
         mark_sampler_desc(nir_deref_instr_get_variable(nir_src_as_deref(instr->src[i].src)), info);
         break;
      default:
         break;
      }
   }
}

static void
gather_info_block(const nir_shader *nir, const nir_block *block, struct radv_shader_info *info,
                  const struct radv_graphics_state_key *gfx_state, const struct radv_shader_stage_key *stage_key,
                  bool consider_force_vrs)
{
   nir_foreach_instr (instr, block) {
      switch (instr->type) {
      case nir_instr_type_intrinsic:
         gather_intrinsic_info(nir, nir_instr_as_intrinsic(instr), info, gfx_state, stage_key, consider_force_vrs);
         break;
      case nir_instr_type_tex:
         gather_tex_info(nir, nir_instr_as_tex(instr), info);
         break;
      default:
         break;
      }
   }
}

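/* Copy the NIR transform feedback info into the streamout info: the number of
 * XFB outputs, the enabled buffers per stream, and the buffer strides in
 * dwords.
 */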
static void
gather_xfb_info(const nir_shader *nir, struct radv_shader_info *info)
{
   struct radv_streamout_info *so = &info->so;

   if (!nir->xfb_info)
      return;

   const nir_xfb_info *xfb = nir->xfb_info;
   assert(xfb->output_count <= MAX_SO_OUTPUTS);
   so->num_outputs = xfb->output_count;

   for (unsigned i = 0; i < xfb->output_count; i++) {
      unsigned output_buffer = xfb->outputs[i].buffer;
      unsigned stream = xfb->buffer_to_stream[xfb->outputs[i].buffer];
      so->enabled_stream_buffers_mask |= (1 << output_buffer) << (stream * 4);
   }

   for (unsigned i = 0; i < NIR_MAX_XFB_BUFFERS; i++) {
      so->strides[i] = xfb->buffers[i].stride / 4;
   }
}

static void
assign_outinfo_param(struct radv_vs_output_info *outinfo, gl_varying_slot idx, unsigned *total_param_exports,
                     unsigned extra_offset)
{
   if (outinfo->vs_output_param_offset[idx] == AC_EXP_PARAM_UNDEFINED)
      outinfo->vs_output_param_offset[idx] = extra_offset + (*total_param_exports)++;
}

static void
assign_outinfo_params(struct radv_vs_output_info *outinfo, uint64_t mask, unsigned *total_param_exports,
                      unsigned extra_offset)
{
   u_foreach_bit64 (idx, mask) {
      if (idx >= VARYING_SLOT_VAR0 || idx == VARYING_SLOT_LAYER || idx == VARYING_SLOT_PRIMITIVE_ID ||
          idx == VARYING_SLOT_VIEWPORT)
         assign_outinfo_param(outinfo, idx, total_param_exports, extra_offset);
   }
}

static void
radv_get_output_masks(const struct nir_shader *nir, const struct radv_graphics_state_key *gfx_state,
                      uint64_t *per_vtx_mask, uint64_t *per_prim_mask)
{
   /* These are compiled into neither output param nor position exports. */
   const uint64_t special_mask = BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) |
                                 BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) |
                                 BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);

   *per_prim_mask = nir->info.outputs_written & nir->info.per_primitive_outputs & ~special_mask;
   *per_vtx_mask = nir->info.outputs_written & ~nir->info.per_primitive_outputs & ~special_mask;

   /* Mesh multiview is only lowered in ac_nir_lower_ngg, so we have to fake it here. */
   if (nir->info.stage == MESA_SHADER_MESH && gfx_state->has_multiview_view_index)
      *per_prim_mask |= VARYING_BIT_LAYER;
}

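/* Assign export parameter offsets for the last pre-rasterization stage:
 * per-vertex params first, then the implicit primitive ID and clip/cull
 * distances when requested, and per-primitive params last (the HW needs
 * these to be last).
 */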
static void
radv_set_vs_output_param(struct radv_device *device, const struct nir_shader *nir,
                         const struct radv_graphics_state_key *gfx_state, struct radv_shader_info *info,
                         bool export_prim_id, bool export_clip_cull_dists)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   struct radv_vs_output_info *outinfo = &info->outinfo;
   uint64_t per_vtx_mask, per_prim_mask;

   radv_get_output_masks(nir, gfx_state, &per_vtx_mask, &per_prim_mask);

   memset(outinfo->vs_output_param_offset, AC_EXP_PARAM_UNDEFINED, sizeof(outinfo->vs_output_param_offset));

   /* Implicit primitive ID for VS and TES is added by ac_nir_lower_legacy_vs / ac_nir_lower_ngg;
    * it can be configured as either a per-vertex or per-primitive output depending on the GPU.
    */
   const bool implicit_prim_id_per_prim =
      export_prim_id && info->is_ngg && pdev->info.gfx_level >= GFX10_3 && nir->info.stage == MESA_SHADER_VERTEX;
   const bool implicit_prim_id_per_vertex =
      export_prim_id && !implicit_prim_id_per_prim &&
      (nir->info.stage == MESA_SHADER_VERTEX || nir->info.stage == MESA_SHADER_TESS_EVAL);

   unsigned total_param_exports = 0;

   /* Per-vertex outputs */
   assign_outinfo_params(outinfo, per_vtx_mask, &total_param_exports, 0);

   if (implicit_prim_id_per_vertex) {
      /* Mark the primitive ID as output when it's implicitly exported by VS or TES. */
      if (outinfo->vs_output_param_offset[VARYING_SLOT_PRIMITIVE_ID] == AC_EXP_PARAM_UNDEFINED)
         outinfo->vs_output_param_offset[VARYING_SLOT_PRIMITIVE_ID] = total_param_exports++;

      outinfo->export_prim_id = true;
   }

   if (export_clip_cull_dists) {
      if (nir->info.outputs_written & VARYING_BIT_CLIP_DIST0)
         outinfo->vs_output_param_offset[VARYING_SLOT_CLIP_DIST0] = total_param_exports++;
      if (nir->info.outputs_written & VARYING_BIT_CLIP_DIST1)
         outinfo->vs_output_param_offset[VARYING_SLOT_CLIP_DIST1] = total_param_exports++;
   }

   outinfo->param_exports = total_param_exports;

   /* The HW always assumes that there is at least 1 per-vertex param,
    * so if there aren't any, we have to offset per-primitive params by 1.
    */
   const unsigned extra_offset = !!(total_param_exports == 0 && pdev->info.gfx_level >= GFX11);

   if (implicit_prim_id_per_prim) {
      /* Mark the primitive ID as output when it's implicitly exported by VS. */
      if (outinfo->vs_output_param_offset[VARYING_SLOT_PRIMITIVE_ID] == AC_EXP_PARAM_UNDEFINED)
         outinfo->vs_output_param_offset[VARYING_SLOT_PRIMITIVE_ID] = extra_offset + total_param_exports++;

      outinfo->export_prim_id_per_primitive = true;
   }

   /* Per-primitive outputs: the HW needs these to be last. */
   assign_outinfo_params(outinfo, per_prim_mask, &total_param_exports, extra_offset);

   outinfo->prim_param_exports = total_param_exports - outinfo->param_exports;
}

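/* Select the wave size for a stage: an application-required subgroup size
 * takes precedence, otherwise per-stage defaults from the physical device are
 * used (legacy GS is always wave64).
 */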
static uint8_t
radv_get_wave_size(struct radv_device *device, gl_shader_stage stage, const struct radv_shader_info *info,
                   const struct radv_shader_stage_key *stage_key)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);

   if (stage_key->subgroup_required_size)
      return stage_key->subgroup_required_size * 32;

   if (stage == MESA_SHADER_GEOMETRY && !info->is_ngg)
      return 64;
   else if (stage == MESA_SHADER_COMPUTE || stage == MESA_SHADER_TASK)
      return info->wave_size;
   else if (stage == MESA_SHADER_FRAGMENT)
      return pdev->ps_wave_size;
   else if (gl_shader_stage_is_rt(stage))
      return pdev->rt_wave_size;
   else
      return pdev->ge_wave_size;
}

static uint8_t
radv_get_ballot_bit_size(struct radv_device *device, gl_shader_stage stage, const struct radv_shader_info *info,
                         const struct radv_shader_stage_key *stage_key)
{
   if (stage_key->subgroup_required_size)
      return stage_key->subgroup_required_size * 32;

   return 64;
}

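/* Size in bytes of one ES vertex in the ESGS ring: 16 bytes (one vec4) per
 * varying, plus 1 dword on GFX9+ for the LDS padding described below.
 */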
static uint32_t
radv_compute_esgs_itemsize(const struct radv_device *device, uint32_t num_varyings)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   uint32_t esgs_itemsize;

   esgs_itemsize = num_varyings * 16;

   /* For the ESGS ring in LDS, add 1 dword to reduce LDS bank
    * conflicts, i.e. each vertex will start on a different bank.
    */
   if (pdev->info.gfx_level >= GFX9 && esgs_itemsize)
      esgs_itemsize += 4;

   return esgs_itemsize;
}

static void
gather_shader_info_ngg_query(struct radv_device *device, struct radv_shader_info *info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);

   info->gs.has_pipeline_stat_query = pdev->emulate_ngg_gs_query_pipeline_stat && info->stage == MESA_SHADER_GEOMETRY;
   info->has_xfb_query = info->so.num_outputs > 0;
   info->has_prim_query = device->cache_key.primitives_generated_query || info->has_xfb_query;
}

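/* For separately compiled shaders, map a NIR IO semantic mask to a mask of
 * driver locations. Outputs that always use fixed slots (layer, viewport,
 * primitive ID, primitive shading rate) are skipped.
 */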
uint64_t
radv_gather_unlinked_io_mask(const uint64_t nir_io_mask)
{
   /* Create a mask of driver locations mapped from NIR semantics. */
   uint64_t radv_io_mask = 0;
   u_foreach_bit64 (semantic, nir_io_mask) {
      /* These outputs are not used when fixed output slots are needed. */
      if (semantic == VARYING_SLOT_LAYER || semantic == VARYING_SLOT_VIEWPORT ||
          semantic == VARYING_SLOT_PRIMITIVE_ID || semantic == VARYING_SLOT_PRIMITIVE_SHADING_RATE)
         continue;

      radv_io_mask |= BITFIELD64_BIT(radv_map_io_driver_location(semantic));
   }

   return radv_io_mask;
}

uint64_t
radv_gather_unlinked_patch_io_mask(const uint64_t nir_io_mask, const uint32_t nir_patch_io_mask)
{
   uint64_t radv_io_mask = 0;
   u_foreach_bit64 (semantic, nir_patch_io_mask) {
      radv_io_mask |= BITFIELD64_BIT(radv_map_io_driver_location(semantic + VARYING_SLOT_PATCH0));
   }

   /* Tess levels need to be handled separately because they are not part of patch_outputs_written. */
   if (nir_io_mask & VARYING_BIT_TESS_LEVEL_OUTER)
      radv_io_mask |= BITFIELD64_BIT(radv_map_io_driver_location(VARYING_SLOT_TESS_LEVEL_OUTER));
   if (nir_io_mask & VARYING_BIT_TESS_LEVEL_INNER)
      radv_io_mask |= BITFIELD64_BIT(radv_map_io_driver_location(VARYING_SLOT_TESS_LEVEL_INNER));

   return radv_io_mask;
}

static void
gather_shader_info_vs(struct radv_device *device, const nir_shader *nir,
                      const struct radv_graphics_state_key *gfx_state, const struct radv_shader_stage_key *stage_key,
                      struct radv_shader_info *info)
{
   if (radv_use_vs_prolog(nir, gfx_state)) {
      info->vs.has_prolog = true;
      info->vs.dynamic_inputs = true;
   }

   info->gs_inputs_read = ~0ULL;
   info->vs.tcs_inputs_via_lds = ~0ULL;

   /* Use per-attribute vertex descriptors to prevent faults and for correct bounds checking. */
   info->vs.use_per_attribute_vb_descs = radv_use_per_attribute_vb_descs(nir, gfx_state, stage_key);

   /* We have to ensure consistent input register assignments between the main shader and the
    * prolog.
    */
   info->vs.needs_instance_id |= info->vs.has_prolog;
   info->vs.needs_base_instance |= info->vs.has_prolog;
   info->vs.needs_draw_id |= info->vs.has_prolog;

   if (info->vs.dynamic_inputs)
      info->vs.vb_desc_usage_mask = BITFIELD_MASK(util_last_bit(info->vs.vb_desc_usage_mask));

   /* When the topology is unknown (with GPL), the number of vertices per primitive needs to be
    * passed through a user SGPR for NGG streamout with VS. Otherwise, the XFB offset would be
    * computed incorrectly because assuming the maximum number of vertices can't work.
    */
   info->vs.dynamic_num_verts_per_prim = gfx_state->ia.topology == V_008958_DI_PT_NONE && info->is_ngg && nir->xfb_info;

   if (!info->outputs_linked)
      info->vs.num_linked_outputs = util_last_bit64(radv_gather_unlinked_io_mask(nir->info.outputs_written));

   if (info->next_stage == MESA_SHADER_TESS_CTRL) {
      info->vs.as_ls = true;
   } else if (info->next_stage == MESA_SHADER_GEOMETRY) {
      info->vs.as_es = true;
      info->esgs_itemsize = radv_compute_esgs_itemsize(device, info->vs.num_linked_outputs);
   }

   if (info->is_ngg) {
      info->vs.num_outputs = nir->num_outputs;

      if (info->next_stage == MESA_SHADER_FRAGMENT || info->next_stage == MESA_SHADER_NONE) {
         gather_shader_info_ngg_query(device, info);
      }
   }
}

static void
gather_shader_info_tcs(struct radv_device *device, const nir_shader *nir,
                       const struct radv_graphics_state_key *gfx_state, struct radv_shader_info *info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);

   nir_gather_tcs_info(nir, &info->tcs.info, nir->info.tess._primitive_mode, nir->info.tess.spacing);

   info->tcs.tcs_outputs_read = nir->info.outputs_read;
   info->tcs.tcs_outputs_written = nir->info.outputs_written;
   info->tcs.tcs_patch_outputs_read = nir->info.patch_outputs_read;
   info->tcs.tcs_patch_outputs_written = nir->info.patch_outputs_written;
   info->tcs.tcs_vertices_out = nir->info.tess.tcs_vertices_out;
   info->tcs.tes_inputs_read = ~0ULL;
   info->tcs.tes_patch_inputs_read = ~0ULL;

   if (!info->inputs_linked)
      info->tcs.num_linked_inputs = util_last_bit64(radv_gather_unlinked_io_mask(nir->info.inputs_read));
   if (!info->outputs_linked) {
      info->tcs.num_linked_outputs = util_last_bit64(radv_gather_unlinked_io_mask(
         nir->info.outputs_written & ~(VARYING_BIT_TESS_LEVEL_OUTER | VARYING_BIT_TESS_LEVEL_INNER)));
      info->tcs.num_linked_patch_outputs = util_last_bit64(
         radv_gather_unlinked_patch_io_mask(nir->info.outputs_written, nir->info.patch_outputs_written));
   }

   if (gfx_state->ts.patch_control_points) {
      radv_get_tess_wg_info(pdev, &nir->info, gfx_state->ts.patch_control_points,
                            /* TODO: This should be only inputs in LDS (not VGPR inputs) to reduce LDS usage */
                            info->tcs.num_linked_inputs, info->tcs.num_linked_outputs,
                            info->tcs.num_linked_patch_outputs, info->tcs.info.all_invocations_define_tess_levels,
                            &info->num_tess_patches, &info->tcs.num_lds_blocks);
   }
}

static void
gather_shader_info_tes(struct radv_device *device, const nir_shader *nir, struct radv_shader_info *info)
{
   info->gs_inputs_read = ~0ULL;
   info->tes._primitive_mode = nir->info.tess._primitive_mode;
   info->tes.spacing = nir->info.tess.spacing;
   info->tes.ccw = nir->info.tess.ccw;
   info->tes.point_mode = nir->info.tess.point_mode;
   info->tes.tcs_vertices_out = nir->info.tess.tcs_vertices_out;
   info->tes.reads_tess_factors =
      !!(nir->info.inputs_read & (VARYING_BIT_TESS_LEVEL_INNER | VARYING_BIT_TESS_LEVEL_OUTER));

   if (!info->inputs_linked) {
      info->tes.num_linked_inputs = util_last_bit64(radv_gather_unlinked_io_mask(
         nir->info.inputs_read & ~(VARYING_BIT_TESS_LEVEL_OUTER | VARYING_BIT_TESS_LEVEL_INNER)));
      info->tes.num_linked_patch_inputs = util_last_bit64(
         radv_gather_unlinked_patch_io_mask(nir->info.inputs_read, nir->info.patch_inputs_read));
   }
   if (!info->outputs_linked)
      info->tes.num_linked_outputs = util_last_bit64(radv_gather_unlinked_io_mask(nir->info.outputs_written));

   if (info->next_stage == MESA_SHADER_GEOMETRY) {
      info->tes.as_es = true;
      info->esgs_itemsize = radv_compute_esgs_itemsize(device, info->tes.num_linked_outputs);
   }

   if (info->is_ngg) {
      info->tes.num_outputs = nir->num_outputs;

      if (info->next_stage == MESA_SHADER_FRAGMENT || info->next_stage == MESA_SHADER_NONE) {
         gather_shader_info_ngg_query(device, info);
      }
   }
}

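/* Compute the ESGS and GSVS ring sizes for legacy (non-NGG) GS, clamping the
 * recommended sizes to the minimum required size and the HW maximum.
 */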
static void
radv_init_legacy_gs_ring_info(const struct radv_device *device, struct radv_shader_info *gs_info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   struct radv_legacy_gs_info *gs_ring_info = &gs_info->gs_ring_info;
   unsigned num_se = pdev->info.max_se;
   unsigned wave_size = 64;
   unsigned max_gs_waves = 32 * num_se; /* max 32 per SE on GCN */
   /* On GFX6-GFX7, the value comes from VGT_GS_VERTEX_REUSE = 16.
    * On GFX8+, the value comes from VGT_VERTEX_REUSE_BLOCK_CNTL = 30 (+2).
    */
   unsigned gs_vertex_reuse = (pdev->info.gfx_level >= GFX8 ? 32 : 16) * num_se;
   unsigned alignment = 256 * num_se;
   /* The maximum size is 63.999 MB per SE. */
   unsigned max_size = ((unsigned)(63.999 * 1024 * 1024) & ~255) * num_se;

   /* Calculate the minimum size. */
   unsigned min_esgs_ring_size = align(gs_ring_info->esgs_itemsize * 4 * gs_vertex_reuse * wave_size, alignment);
   /* These are recommended sizes, not minimum sizes. */
   unsigned esgs_ring_size = max_gs_waves * 2 * wave_size * gs_ring_info->esgs_itemsize * 4 * gs_info->gs.vertices_in;
   unsigned gsvs_ring_size = max_gs_waves * 2 * wave_size * gs_info->gs.max_gsvs_emit_size;

   min_esgs_ring_size = align(min_esgs_ring_size, alignment);
   esgs_ring_size = align(esgs_ring_size, alignment);
   gsvs_ring_size = align(gsvs_ring_size, alignment);

   if (pdev->info.gfx_level <= GFX8)
      gs_ring_info->esgs_ring_size = CLAMP(esgs_ring_size, min_esgs_ring_size, max_size);

   gs_ring_info->gsvs_ring_size = MIN2(gsvs_ring_size, max_size);
}

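/* Partition a legacy GS subgroup: choose ES vertex and GS primitive counts
 * per subgroup so that the ESGS data fits in the LDS budget, then derive the
 * LDS allocation size in HW granules.
 */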
static void
radv_get_legacy_gs_info(const struct radv_device *device, struct radv_shader_info *gs_info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   struct radv_legacy_gs_info *out = &gs_info->gs_ring_info;
   const unsigned gs_num_invocations = MAX2(gs_info->gs.invocations, 1);
   const bool uses_adjacency =
      gs_info->gs.input_prim == MESA_PRIM_LINES_ADJACENCY || gs_info->gs.input_prim == MESA_PRIM_TRIANGLES_ADJACENCY;

   /* All these are in dwords: */
   /* We can't allow using the whole LDS, because GS waves compete with
    * other shader stages for LDS space. */
   const unsigned max_lds_size = 8 * 1024;
   const unsigned esgs_itemsize = radv_compute_esgs_itemsize(device, gs_info->gs.num_linked_inputs) / 4;
   unsigned esgs_lds_size;

   /* All these are per subgroup: */
   const unsigned max_out_prims = 32 * 1024;
   const unsigned max_es_verts = 255;
   const unsigned ideal_gs_prims = 64;
   unsigned max_gs_prims, gs_prims;
   unsigned min_es_verts, es_verts, worst_case_es_verts;

   if (uses_adjacency || gs_num_invocations > 1)
      max_gs_prims = 127 / gs_num_invocations;
   else
      max_gs_prims = 255;

   /* MAX_PRIMS_PER_SUBGROUP = gs_prims * max_vert_out * gs_invocations.
    * Make sure we don't go over the maximum value.
    */
   if (gs_info->gs.vertices_out > 0) {
      max_gs_prims = MIN2(max_gs_prims, max_out_prims / (gs_info->gs.vertices_out * gs_num_invocations));
   }
   assert(max_gs_prims > 0);

   /* If the primitive has adjacency, halve the number of vertices
    * that will be reused in multiple primitives.
    */
   min_es_verts = gs_info->gs.vertices_in / (uses_adjacency ? 2 : 1);

   gs_prims = MIN2(ideal_gs_prims, max_gs_prims);
   worst_case_es_verts = MIN2(min_es_verts * gs_prims, max_es_verts);

   /* Compute ESGS LDS size based on the worst case number of ES vertices
    * needed to create the target number of GS prims per subgroup.
    */
   esgs_lds_size = esgs_itemsize * worst_case_es_verts;

   /* If total LDS usage is too big, refactor partitions based on ratio
    * of ESGS item sizes.
    */
   if (esgs_lds_size > max_lds_size) {
      /* Our target GS Prims Per Subgroup was too large. Calculate
       * the maximum number of GS Prims Per Subgroup that will fit
       * into LDS, capped by the maximum that the hardware can support.
       */
      gs_prims = MIN2((max_lds_size / (esgs_itemsize * min_es_verts)), max_gs_prims);
      assert(gs_prims > 0);
      worst_case_es_verts = MIN2(min_es_verts * gs_prims, max_es_verts);

      esgs_lds_size = esgs_itemsize * worst_case_es_verts;
      assert(esgs_lds_size <= max_lds_size);
   }

   /* Now calculate remaining ESGS information. */
   if (esgs_lds_size)
      es_verts = MIN2(esgs_lds_size / esgs_itemsize, max_es_verts);
   else
      es_verts = max_es_verts;

   /* Vertices for adjacency primitives are not always reused, so restore
    * it for ES_VERTS_PER_SUBGRP.
    */
   min_es_verts = gs_info->gs.vertices_in;

   /* For normal primitives, the VGT only checks if they are past the ES
    * verts per subgroup after allocating a full GS primitive and if they
    * are, kick off a new subgroup.  But if those additional ES verts are
    * unique (e.g. not reused) we need to make sure there is enough LDS
    * space to account for those ES verts beyond ES_VERTS_PER_SUBGRP.
    */
   es_verts -= min_es_verts - 1;

   const uint32_t es_verts_per_subgroup = es_verts;
   const uint32_t gs_prims_per_subgroup = gs_prims;
   const uint32_t gs_inst_prims_in_subgroup = gs_prims * gs_num_invocations;
   const uint32_t max_prims_per_subgroup = gs_inst_prims_in_subgroup * gs_info->gs.vertices_out;
   const uint32_t lds_granularity = pdev->info.lds_encode_granularity;
   const uint32_t total_lds_bytes = align(esgs_lds_size * 4, lds_granularity);

   out->gs_inst_prims_in_subgroup = gs_inst_prims_in_subgroup;
   out->es_verts_per_subgroup = es_verts_per_subgroup;
   out->gs_prims_per_subgroup = gs_prims_per_subgroup;
   out->esgs_itemsize = esgs_itemsize;
   out->lds_size = total_lds_bytes / lds_granularity;
   assert(max_prims_per_subgroup <= max_out_prims);

   radv_init_legacy_gs_ring_info(device, gs_info);
}

static void
gather_shader_info_gs(struct radv_device *device, const nir_shader *nir, struct radv_shader_info *info)
{
   unsigned add_clip = nir->info.clip_distance_array_size + nir->info.cull_distance_array_size > 4;
   info->gs.gsvs_vertex_size = (util_bitcount64(nir->info.outputs_written) + add_clip) * 16;
   info->gs.max_gsvs_emit_size = info->gs.gsvs_vertex_size * nir->info.gs.vertices_out;

   info->gs.vertices_in = nir->info.gs.vertices_in;
   info->gs.vertices_out = nir->info.gs.vertices_out;
   info->gs.input_prim = nir->info.gs.input_primitive;
   info->gs.output_prim = nir->info.gs.output_primitive;
   info->gs.invocations = nir->info.gs.invocations;
   info->gs.max_stream = nir->info.gs.active_stream_mask ? util_last_bit(nir->info.gs.active_stream_mask) - 1 : 0;

   for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
      const uint8_t usage_mask = info->gs.output_usage_mask[slot];
      const uint8_t gs_streams = info->gs.output_streams[slot];

      for (unsigned component = 0; component < 4; ++component) {
         if (!(usage_mask & BITFIELD_BIT(component)))
            continue;

         const uint8_t stream = (gs_streams >> (component * 2)) & 0x3;
         info->gs.num_stream_output_components[stream]++;
      }
   }

   if (!info->inputs_linked)
      info->gs.num_linked_inputs = util_last_bit64(radv_gather_unlinked_io_mask(nir->info.inputs_read));

   if (info->is_ngg) {
      gather_shader_info_ngg_query(device, info);
   } else {
      radv_get_legacy_gs_info(device, info);
   }
}

static void
gather_shader_info_mesh(struct radv_device *device, const nir_shader *nir,
                        const struct radv_shader_stage_key *stage_key, struct radv_shader_info *info)
{
   struct gfx10_ngg_info *ngg_info = &info->ngg_info;

   info->ms.output_prim = nir->info.mesh.primitive_type;

   /* Special case for mesh shader workgroups.
    *
    * Mesh shaders don't have any real vertex input, but they can produce
    * an arbitrary number of vertices and primitives (up to 256).
    * We need to precisely control the number of mesh shader workgroups
    * that are launched from draw calls.
    *
    * To achieve that, we set:
    * - input primitive topology to point list
    * - input vertex and primitive count to 1
    * - max output vertex count and primitive amplification factor
    *   to the boundaries of the shader
    *
    * With that, in the draw call:
    * - drawing 1 input vertex ~ launching 1 mesh shader workgroup
    *
    * In the shader:
    * - input vertex id ~ workgroup id (in 1D - shader needs to calculate in 3D)
    *
    * Notes:
    * - without GS_EN=1 PRIM_AMP_FACTOR and MAX_VERTS_PER_SUBGROUP don't seem to work
    * - with GS_EN=1 we must also set VGT_GS_MAX_VERT_OUT (otherwise the GPU hangs)
    * - with GS_FAST_LAUNCH=1 every lane's VGPRs are initialized to the same input vertex index
    */
   ngg_info->esgs_ring_size = 1;
   ngg_info->hw_max_esverts = 1;
   ngg_info->max_gsprims = 1;
   ngg_info->max_out_verts = nir->info.mesh.max_vertices_out;
   ngg_info->max_vert_out_per_gs_instance = false;
   ngg_info->ngg_emit_size = 0;
   ngg_info->prim_amp_factor = nir->info.mesh.max_primitives_out;
   ngg_info->vgt_esgs_ring_itemsize = 1;

   info->ms.has_query = device->cache_key.mesh_shader_queries;
   info->ms.has_task = stage_key->has_task_shader;
}

static void
calc_mesh_workgroup_size(const struct radv_device *device, const nir_shader *nir, struct radv_shader_info *info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   unsigned api_workgroup_size = ac_compute_cs_workgroup_size(nir->info.workgroup_size, false, UINT32_MAX);

   if (pdev->mesh_fast_launch_2) {
      /* Use multi-row export. It is also necessary to use the API workgroup size for non-emulated queries. */
      info->workgroup_size = api_workgroup_size;
   } else {
      struct gfx10_ngg_info *ngg_info = &info->ngg_info;
      unsigned min_ngg_workgroup_size = ac_compute_ngg_workgroup_size(
         ngg_info->hw_max_esverts, ngg_info->max_gsprims, ngg_info->max_out_verts, ngg_info->prim_amp_factor);

      info->workgroup_size = MAX2(min_ngg_workgroup_size, api_workgroup_size);
   }
}

static void
gather_shader_info_fs(const struct radv_device *device, const nir_shader *nir,
                      const struct radv_graphics_state_key *gfx_state, struct radv_shader_info *info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);

   info->ps.num_inputs = util_bitcount64(nir->info.inputs_read);
   info->ps.can_discard = nir->info.fs.uses_discard;
   info->ps.early_fragment_test =
      nir->info.fs.early_fragment_tests ||
      (nir->info.fs.early_and_late_fragment_tests && nir->info.fs.depth_layout == FRAG_DEPTH_LAYOUT_NONE &&
       nir->info.fs.stencil_front_layout == FRAG_STENCIL_LAYOUT_NONE &&
       nir->info.fs.stencil_back_layout == FRAG_STENCIL_LAYOUT_NONE);
   info->ps.post_depth_coverage = nir->info.fs.post_depth_coverage;
   info->ps.depth_layout = nir->info.fs.depth_layout;
   info->ps.uses_sample_shading = nir->info.fs.uses_sample_shading;
   info->ps.uses_fbfetch_output = nir->info.fs.uses_fbfetch_output;
   info->ps.writes_memory = nir->info.writes_memory;
   info->ps.has_pcoord = nir->info.inputs_read & VARYING_BIT_PNTC;
   info->ps.prim_id_input = nir->info.inputs_read & VARYING_BIT_PRIMITIVE_ID;
   info->ps.reads_layer = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_LAYER_ID);
   info->ps.viewport_index_input = nir->info.inputs_read & VARYING_BIT_VIEWPORT;
   info->ps.writes_z = nir->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH);
   info->ps.writes_stencil = nir->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_STENCIL);
   info->ps.writes_sample_mask = nir->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_SAMPLE_MASK);
   info->ps.reads_sample_mask_in = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_SAMPLE_MASK_IN);
   info->ps.reads_sample_id = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_SAMPLE_ID);
   info->ps.reads_frag_shading_rate = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_FRAG_SHADING_RATE);
   info->ps.reads_front_face = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_FRONT_FACE) |
                               BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_FRONT_FACE_FSIGN);
   info->ps.reads_barycentric_model = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_BARYCENTRIC_PULL_MODEL);
   info->ps.reads_fully_covered = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_FULLY_COVERED);

   bool uses_persp_or_linear_interp = info->ps.reads_persp_center || info->ps.reads_persp_centroid ||
                                      info->ps.reads_persp_sample || info->ps.reads_linear_center ||
                                      info->ps.reads_linear_centroid || info->ps.reads_linear_sample;

   info->ps.allow_flat_shading =
      !(uses_persp_or_linear_interp || info->ps.needs_sample_positions || info->ps.reads_frag_shading_rate ||
        info->ps.writes_memory || nir->info.fs.needs_quad_helper_invocations ||
        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_FRAG_COORD) ||
        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_PIXEL_COORD) ||
        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_POINT_COORD) ||
        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_SAMPLE_ID) ||
        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_SAMPLE_POS) ||
        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_SAMPLE_MASK_IN) ||
        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_HELPER_INVOCATION));

   info->ps.pops_is_per_sample =
      info->ps.pops && (nir->info.fs.sample_interlock_ordered || nir->info.fs.sample_interlock_unordered);

   info->ps.spi_ps_input_ena = radv_compute_spi_ps_input(pdev, gfx_state, info);
   info->ps.spi_ps_input_addr = info->ps.spi_ps_input_ena;
   if (pdev->info.gfx_level >= GFX12) {
      /* Only SPI_PS_INPUT_ENA has this bit on GFX12. */
      info->ps.spi_ps_input_addr &= C_02865C_COVERAGE_TO_SHADER_SELECT;
   }

   info->ps.has_epilog = gfx_state->ps.has_epilog && info->ps.colors_written;

   const bool export_alpha = !!(info->ps.color0_written & 0x8);

   if (info->ps.has_epilog) {
      info->ps.exports_mrtz_via_epilog = gfx_state->ps.exports_mrtz_via_epilog && export_alpha;
   } else {
      info->ps.mrt0_is_dual_src = gfx_state->ps.epilog.mrt0_is_dual_src;
      info->ps.spi_shader_col_format = gfx_state->ps.epilog.spi_shader_col_format;

      /* Clear color attachments that aren't exported by the FS to match IO shader arguments. */
      info->ps.spi_shader_col_format &= info->ps.colors_written;

      info->ps.cb_shader_mask = ac_get_cb_shader_mask(info->ps.spi_shader_col_format);
   }

   if (!info->ps.exports_mrtz_via_epilog) {
      info->ps.writes_mrt0_alpha = gfx_state->ms.alpha_to_coverage_via_mrtz && export_alpha;
   }

   /* Disable VRS and use the rates from PS_ITER_SAMPLES if:
    *
    * - The fragment shader reads gl_SampleMaskIn because the 16-bit sample coverage mask isn't enough for MSAA8x and
    *   2x2 coarse shading.
    * - On GFX10.3, if the fragment shader requests a fragment interlock execution mode even if the ordered section was
    *   optimized out, to consistently implement fragmentShadingRateWithFragmentShaderInterlock = VK_FALSE.
    */
   info->ps.force_sample_iter_shading_rate =
      (info->ps.reads_sample_mask_in && !info->ps.needs_poly_line_smooth) ||
      (pdev->info.gfx_level == GFX10_3 &&
       (nir->info.fs.sample_interlock_ordered || nir->info.fs.sample_interlock_unordered ||
        nir->info.fs.pixel_interlock_ordered || nir->info.fs.pixel_interlock_unordered));
}

static void
gather_shader_info_rt(const nir_shader *nir, struct radv_shader_info *info)
{
   // TODO: inline push_constants again
   info->loads_dynamic_offsets = true;
   info->loads_push_constants = true;
   info->can_inline_all_push_constants = false;
   info->inline_push_constant_mask = 0;
   info->desc_set_used_mask = -1u;
}

static void
gather_shader_info_cs(struct radv_device *device, const nir_shader *nir, const struct radv_shader_stage_key *stage_key,
                      struct radv_shader_info *info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   unsigned default_wave_size = pdev->cs_wave_size;
   if (info->cs.uses_rt)
      default_wave_size = pdev->rt_wave_size;

   unsigned local_size = nir->info.workgroup_size[0] * nir->info.workgroup_size[1] * nir->info.workgroup_size[2];

   /* Games don't always request full subgroups when they should, which can cause bugs if cswave32
    * is enabled. Furthermore, if cooperative matrices or subgroup info are used, we can't transparently change
    * the subgroup size.
    */
   const bool require_full_subgroups =
      stage_key->subgroup_require_full || nir->info.cs.has_cooperative_matrix ||
      (default_wave_size == 32 && nir->info.uses_wide_subgroup_intrinsics && local_size % RADV_SUBGROUP_SIZE == 0);

   const unsigned required_subgroup_size = stage_key->subgroup_required_size * 32;

   if (required_subgroup_size) {
      info->wave_size = required_subgroup_size;
   } else if (require_full_subgroups) {
      info->wave_size = RADV_SUBGROUP_SIZE;
   } else if (pdev->info.gfx_level >= GFX10 && local_size <= 32) {
      /* Use wave32 for small workgroups. */
      info->wave_size = 32;
   } else {
      info->wave_size = default_wave_size;
   }

   if (pdev->info.has_cs_regalloc_hang_bug) {
      info->cs.regalloc_hang_bug = info->cs.block_size[0] * info->cs.block_size[1] * info->cs.block_size[2] > 256;
   }
}

static void
gather_shader_info_task(struct radv_device *device, const nir_shader *nir,
                        const struct radv_shader_stage_key *stage_key, struct radv_shader_info *info)
{
   gather_shader_info_cs(device, nir, stage_key, info);

   /* Task shaders always need these for the I/O lowering even if the API shader doesn't actually
    * use them.
    */

   /* Needed to address the task draw/payload rings. */
   info->cs.uses_block_id[0] = true;
   info->cs.uses_block_id[1] = true;
   info->cs.uses_block_id[2] = true;
   info->cs.uses_grid_size = true;

   /* Needed for storing draw ready only on the 1st thread. */
   info->cs.uses_local_invocation_idx = true;

   /* Task->Mesh dispatch is linear when Y = Z = 1.
    * GFX11 CP can optimize this case with a field in its draw packets.
    */
   info->cs.linear_taskmesh_dispatch =
      nir->info.mesh.ts_mesh_dispatch_dimensions[1] == 1 && nir->info.mesh.ts_mesh_dispatch_dimensions[2] == 1;

   info->cs.has_query = device->cache_key.mesh_shader_queries;
}

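/* Return the base user data register (SPI_SHADER_USER_DATA_*_0 or
 * COMPUTE_USER_DATA_0) for the HW stage this shader runs as, which depends on
 * the GFX level, the next stage, and whether NGG is used.
 */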
static uint32_t
radv_get_user_data_0(const struct radv_device *device, struct radv_shader_info *info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   const enum amd_gfx_level gfx_level = pdev->info.gfx_level;

   switch (info->stage) {
   case MESA_SHADER_VERTEX:
   case MESA_SHADER_TESS_EVAL:
   case MESA_SHADER_MESH:
      if (info->next_stage == MESA_SHADER_TESS_CTRL) {
         assert(info->stage == MESA_SHADER_VERTEX);

         if (gfx_level >= GFX10) {
            return R_00B430_SPI_SHADER_USER_DATA_HS_0;
         } else if (gfx_level == GFX9) {
            return R_00B430_SPI_SHADER_USER_DATA_LS_0;
         } else {
            return R_00B530_SPI_SHADER_USER_DATA_LS_0;
         }
      }

      if (info->next_stage == MESA_SHADER_GEOMETRY) {
         assert(info->stage == MESA_SHADER_VERTEX || info->stage == MESA_SHADER_TESS_EVAL);

         if (gfx_level >= GFX10) {
            return R_00B230_SPI_SHADER_USER_DATA_GS_0;
         } else {
            return R_00B330_SPI_SHADER_USER_DATA_ES_0;
         }
      }

      if (info->is_ngg)
         return R_00B230_SPI_SHADER_USER_DATA_GS_0;

      assert(info->stage != MESA_SHADER_MESH);
      return R_00B130_SPI_SHADER_USER_DATA_VS_0;
   case MESA_SHADER_TESS_CTRL:
      return gfx_level == GFX9 ? R_00B430_SPI_SHADER_USER_DATA_LS_0 : R_00B430_SPI_SHADER_USER_DATA_HS_0;
   case MESA_SHADER_GEOMETRY:
      return gfx_level == GFX9 ? R_00B330_SPI_SHADER_USER_DATA_ES_0 : R_00B230_SPI_SHADER_USER_DATA_GS_0;
   case MESA_SHADER_FRAGMENT:
      return R_00B030_SPI_SHADER_USER_DATA_PS_0;
   case MESA_SHADER_COMPUTE:
   case MESA_SHADER_TASK:
   case MESA_SHADER_RAYGEN:
   case MESA_SHADER_CALLABLE:
   case MESA_SHADER_CLOSEST_HIT:
   case MESA_SHADER_MISS:
   case MESA_SHADER_INTERSECTION:
   case MESA_SHADER_ANY_HIT:
      return R_00B900_COMPUTE_USER_DATA_0;
   default:
      unreachable("invalid shader stage");
   }
}

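/* On GFX9+, VS+TCS and VS/TES+GS are merged into a single HW shader. Report
 * whether the two halves of such a merged shader are compiled separately,
 * i.e. their I/O hasn't been linked (e.g. with graphics pipeline libraries
 * or unlinked ESO).
 */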
static bool
radv_is_merged_shader_compiled_separately(const struct radv_device *device, const struct radv_shader_info *info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   const enum amd_gfx_level gfx_level = pdev->info.gfx_level;

   if (gfx_level >= GFX9) {
      switch (info->stage) {
      case MESA_SHADER_VERTEX:
         if (info->next_stage == MESA_SHADER_TESS_CTRL || info->next_stage == MESA_SHADER_GEOMETRY)
            return !info->outputs_linked;
         break;
      case MESA_SHADER_TESS_EVAL:
         if (info->next_stage == MESA_SHADER_GEOMETRY)
            return !info->outputs_linked;
         break;
      case MESA_SHADER_TESS_CTRL:
      case MESA_SHADER_GEOMETRY:
         return !info->inputs_linked;
      default:
         break;
      }
   }

   return false;
}

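/* Initialize the shader info with defaults before the gathering pass runs. */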
void
radv_nir_shader_info_init(gl_shader_stage stage, gl_shader_stage next_stage, struct radv_shader_info *info)
{
   memset(info, 0, sizeof(*info));

   /* Assume that shaders can inline all push constants by default. */
   info->can_inline_all_push_constants = true;

   info->stage = stage;
   info->next_stage = next_stage;
}

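/* Main info-gathering pass: walk the NIR shader and fill radv_shader_info
 * with everything the backend needs (system values, I/O masks, position
 * exports, wave size, workgroup size, user-data registers, etc.).
 */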
void
radv_nir_shader_info_pass(struct radv_device *device, const struct nir_shader *nir,
                          const struct radv_shader_layout *layout, const struct radv_shader_stage_key *stage_key,
                          const struct radv_graphics_state_key *gfx_state, const enum radv_pipeline_type pipeline_type,
                          bool consider_force_vrs, struct radv_shader_info *info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&nir->functions);

   if (layout->use_dynamic_descriptors) {
      info->loads_push_constants = true;
      info->loads_dynamic_offsets = true;
   }

   nir_foreach_block (block, func->impl) {
      gather_info_block(nir, block, info, gfx_state, stage_key, consider_force_vrs);
   }

   if (nir->info.stage == MESA_SHADER_VERTEX || nir->info.stage == MESA_SHADER_TESS_EVAL ||
       nir->info.stage == MESA_SHADER_GEOMETRY)
      gather_xfb_info(nir, info);

   if (nir->info.stage == MESA_SHADER_VERTEX || nir->info.stage == MESA_SHADER_TESS_EVAL ||
       nir->info.stage == MESA_SHADER_GEOMETRY || nir->info.stage == MESA_SHADER_MESH) {
      struct radv_vs_output_info *outinfo = &info->outinfo;
      uint64_t per_vtx_mask, per_prim_mask;

      radv_get_output_masks(nir, gfx_state, &per_vtx_mask, &per_prim_mask);

      /* Mesh multiview is only lowered in ac_nir_lower_ngg, so we have to fake it here. */
      if (nir->info.stage == MESA_SHADER_MESH && gfx_state->has_multiview_view_index)
         info->uses_view_index = true;

      /* Per-vertex outputs. */
      outinfo->writes_pointsize = per_vtx_mask & VARYING_BIT_PSIZ;
      outinfo->writes_viewport_index = per_vtx_mask & VARYING_BIT_VIEWPORT;
      outinfo->writes_layer = per_vtx_mask & VARYING_BIT_LAYER;
      outinfo->writes_primitive_shading_rate =
         (per_vtx_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) || info->force_vrs_per_vertex;

      /* Per-primitive outputs. */
      outinfo->writes_viewport_index_per_primitive = per_prim_mask & VARYING_BIT_VIEWPORT;
      outinfo->writes_layer_per_primitive = per_prim_mask & VARYING_BIT_LAYER;
      outinfo->writes_primitive_shading_rate_per_primitive = per_prim_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE;
      outinfo->export_prim_id_per_primitive = per_prim_mask & VARYING_BIT_PRIMITIVE_ID;

      outinfo->pos_exports = 1;

      if (outinfo->writes_pointsize || outinfo->writes_viewport_index || outinfo->writes_layer ||
          outinfo->writes_primitive_shading_rate)
         outinfo->pos_exports++;

      unsigned clip_cull_mask = outinfo->clip_dist_mask | outinfo->cull_dist_mask;

      if (clip_cull_mask & 0x0f)
         outinfo->pos_exports++;
      if (clip_cull_mask & 0xf0)
         outinfo->pos_exports++;
   }

   info->vs.needs_draw_id |= BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_DRAW_ID);
   info->vs.needs_base_instance |= BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_BASE_INSTANCE);
   info->vs.needs_instance_id |= BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
   info->uses_view_index |= BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_VIEW_INDEX);
   info->uses_invocation_id |= BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_INVOCATION_ID);
   info->uses_prim_id |= BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);

   /* Used by compute and mesh shaders. Mesh shaders must always declare this before GFX11. */
   info->cs.uses_grid_size = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_NUM_WORKGROUPS) ||
                             (nir->info.stage == MESA_SHADER_MESH && pdev->info.gfx_level < GFX11);
   info->cs.uses_local_invocation_idx = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_LOCAL_INVOCATION_INDEX) |
                                        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_SUBGROUP_ID) |
                                        BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_NUM_SUBGROUPS) |
                                        radv_shader_should_clear_lds(device, nir);

   if (nir->info.stage == MESA_SHADER_COMPUTE || nir->info.stage == MESA_SHADER_TASK ||
       nir->info.stage == MESA_SHADER_MESH) {
      for (int i = 0; i < 3; ++i)
         info->cs.block_size[i] = nir->info.workgroup_size[i];
   }

   info->user_data_0 = radv_get_user_data_0(device, info);
   info->merged_shader_compiled_separately = radv_is_merged_shader_compiled_separately(device, info);
   info->force_indirect_desc_sets = info->merged_shader_compiled_separately || stage_key->indirect_bindable;

   switch (nir->info.stage) {
   case MESA_SHADER_COMPUTE:
      gather_shader_info_cs(device, nir, stage_key, info);
      break;
   case MESA_SHADER_TASK:
      gather_shader_info_task(device, nir, stage_key, info);
      break;
   case MESA_SHADER_FRAGMENT:
      gather_shader_info_fs(device, nir, gfx_state, info);
      break;
   case MESA_SHADER_GEOMETRY:
      gather_shader_info_gs(device, nir, info);
      break;
   case MESA_SHADER_TESS_EVAL:
      gather_shader_info_tes(device, nir, info);
      break;
   case MESA_SHADER_TESS_CTRL:
      gather_shader_info_tcs(device, nir, gfx_state, info);
      break;
   case MESA_SHADER_VERTEX:
      gather_shader_info_vs(device, nir, gfx_state, stage_key, info);
      break;
   case MESA_SHADER_MESH:
      gather_shader_info_mesh(device, nir, stage_key, info);
      break;
   default:
      if (gl_shader_stage_is_rt(nir->info.stage))
         gather_shader_info_rt(nir, info);
      break;
   }

   info->wave_size = radv_get_wave_size(device, nir->info.stage, info, stage_key);
   info->ballot_bit_size = radv_get_ballot_bit_size(device, nir->info.stage, info, stage_key);

   switch (nir->info.stage) {
   case MESA_SHADER_COMPUTE:
   case MESA_SHADER_TASK:
      info->workgroup_size = ac_compute_cs_workgroup_size(nir->info.workgroup_size, false, UINT32_MAX);

      /* Allow the compiler to assume that the shader always has full subgroups,
       * meaning that the initial EXEC mask is -1 in all waves (all lanes enabled).
       * This assumption is incorrect for ray tracing and internal (meta) shaders
       * because they can use unaligned dispatch.
       */
      info->cs.uses_full_subgroups = pipeline_type != RADV_PIPELINE_RAY_TRACING && !nir->info.internal &&
                                     (info->workgroup_size % info->wave_size) == 0;
      break;
   case MESA_SHADER_VERTEX:
      if (info->vs.as_ls || info->vs.as_es) {
         /* Set the maximum possible value by default; this will be optimized during linking if
          * possible.
          */
         info->workgroup_size = 256;
      } else {
         info->workgroup_size = info->wave_size;
      }
      break;
   case MESA_SHADER_TESS_CTRL:
      if (gfx_state->ts.patch_control_points) {
         info->workgroup_size =
            ac_compute_lshs_workgroup_size(pdev->info.gfx_level, MESA_SHADER_TESS_CTRL, info->num_tess_patches,
                                           gfx_state->ts.patch_control_points, info->tcs.tcs_vertices_out);
      } else {
         /* Set the maximum possible value when the workgroup size can't be determined. */
         info->workgroup_size = 256;
      }
      break;
   case MESA_SHADER_TESS_EVAL:
      if (info->tes.as_es) {
         /* Set the maximum possible value by default; this will be optimized during linking if
          * possible.
          */
         info->workgroup_size = 256;
      } else {
         info->workgroup_size = info->wave_size;
      }
      break;
   case MESA_SHADER_GEOMETRY:
      if (!info->is_ngg) {
         unsigned es_verts_per_subgroup = info->gs_ring_info.es_verts_per_subgroup;
         unsigned gs_inst_prims_in_subgroup = info->gs_ring_info.gs_inst_prims_in_subgroup;

         info->workgroup_size = ac_compute_esgs_workgroup_size(pdev->info.gfx_level, info->wave_size,
                                                               es_verts_per_subgroup, gs_inst_prims_in_subgroup);
      } else {
         /* Set the maximum possible value by default; this will be optimized during linking if
          * possible.
          */
         info->workgroup_size = 256;
      }
      break;
   case MESA_SHADER_MESH:
      calc_mesh_workgroup_size(device, nir, info);
      break;
   default:
      /* FS always operates without workgroups. Other stages are computed during linking but assume
       * no workgroups by default.
       */
      info->workgroup_size = info->wave_size;
      break;
   }
}

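/* Clamp the number of GS input primitives per subgroup to what max_esverts
 * ES vertices can supply under maximal vertex reuse: the first primitive
 * needs min_verts_per_prim vertices, and each further one adds at least one
 * more (two with adjacency).
 */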
static void
clamp_gsprims_to_esverts(unsigned *max_gsprims, unsigned max_esverts, unsigned min_verts_per_prim, bool use_adjacency)
{
   unsigned max_reuse = max_esverts - min_verts_per_prim;
   if (use_adjacency)
      max_reuse /= 2;
   *max_gsprims = MIN2(*max_gsprims, 1 + max_reuse);
}

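/* Number of input vertices per primitive for the last pre-rasterization
 * stage: the GS input primitive if a GS is present, otherwise derived from
 * the tessellation mode, defaulting to triangles.
 */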
static unsigned
radv_get_num_input_vertices(const struct radv_shader_info *es_info, const struct radv_shader_info *gs_info)
{
   if (gs_info) {
      return gs_info->gs.vertices_in;
   }

   if (es_info->stage == MESA_SHADER_TESS_EVAL) {
      if (es_info->tes.point_mode)
         return 1;
      if (es_info->tes._primitive_mode == TESS_PRIMITIVE_ISOLINES)
         return 2;
      return 3;
   }

   return 3;
}

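/* Input primitive topology seen by the pre-rasterization HW stage, using the
 * same GS/TES/default priority as radv_get_num_input_vertices().
 */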
static unsigned
radv_get_pre_rast_input_topology(const struct radv_shader_info *es_info, const struct radv_shader_info *gs_info)
{
   if (gs_info) {
      return gs_info->gs.input_prim;
   }

   if (es_info->stage == MESA_SHADER_TESS_EVAL) {
      if (es_info->tes.point_mode)
         return MESA_PRIM_POINTS;
      if (es_info->tes._primitive_mode == TESS_PRIMITIVE_ISOLINES)
         return MESA_PRIM_LINES;
      return MESA_PRIM_TRIANGLES;
   }

   return MESA_PRIM_TRIANGLES;
}

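/* Compute the LDS offset where the NGG scratch space starts, i.e. how much
 * LDS the shader already uses below it: the ESGS ring plus GS output
 * vertices when a GS is present, otherwise the per-vertex ES space; aligned
 * to 8 bytes for the repacking code.
 */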
static unsigned
gfx10_get_ngg_scratch_lds_base(const struct radv_device *device, const struct radv_shader_info *es_info,
                               const struct radv_shader_info *gs_info, const struct gfx10_ngg_info *ngg_info)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   uint32_t scratch_lds_base;

   if (gs_info) {
      const unsigned esgs_ring_lds_bytes = ngg_info->esgs_ring_size;
      const unsigned gs_total_out_vtx_bytes = ngg_info->ngg_emit_size * 4u;

      scratch_lds_base = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */);
   } else {
      const bool uses_instanceid = es_info->vs.needs_instance_id;
      const bool uses_primitive_id = es_info->uses_prim_id;
      const bool streamout_enabled = es_info->so.num_outputs && pdev->use_ngg_streamout;
      const uint32_t num_outputs =
         es_info->stage == MESA_SHADER_VERTEX ? es_info->vs.num_outputs : es_info->tes.num_outputs;
      unsigned pervertex_lds_bytes = ac_ngg_nogs_get_pervertex_lds_size(
         es_info->stage, num_outputs, streamout_enabled, es_info->outinfo.export_prim_id, false, /* user edge flag */
         es_info->has_ngg_culling, uses_instanceid, uses_primitive_id);

      assert(ngg_info->hw_max_esverts <= 256);
      unsigned total_es_lds_bytes = pervertex_lds_bytes * ngg_info->hw_max_esverts;

      scratch_lds_base = ALIGN(total_es_lds_bytes, 8u);
   }

   return scratch_lds_base;
}

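/* Compute the NGG subgroup layout (GFX10+): how many ES vertices and GS
 * primitives fit in one subgroup given the LDS budget and the HW minimums
 * and maximums, plus the derived ring sizes, LDS layout and workgroup size.
 */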
void
gfx10_get_ngg_info(const struct radv_device *device, struct radv_shader_info *es_info, struct radv_shader_info *gs_info,
                   struct gfx10_ngg_info *out)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   const enum amd_gfx_level gfx_level = pdev->info.gfx_level;
   const unsigned max_verts_per_prim = radv_get_num_input_vertices(es_info, gs_info);
   const unsigned min_verts_per_prim = gs_info ? max_verts_per_prim : 1;

   const unsigned gs_num_invocations = gs_info ? MAX2(gs_info->gs.invocations, 1) : 1;

   const unsigned input_prim = radv_get_pre_rast_input_topology(es_info, gs_info);
   const bool uses_adjacency = input_prim == MESA_PRIM_LINES_ADJACENCY || input_prim == MESA_PRIM_TRIANGLES_ADJACENCY;

   /* All these are in dwords: */
   /* We can't allow using the whole LDS, because GS waves compete with
    * other shader stages for LDS space.
    *
    * TODO: We should really take the shader's internal LDS use into
    *       account. The linker will fail if the size is greater than
    *       8K dwords.
    */
   const unsigned max_lds_size = 8 * 1024 - 768;
   const unsigned target_lds_size = max_lds_size;
   unsigned esvert_lds_size = 0;
   unsigned gsprim_lds_size = 0;

   /* All these are per subgroup: */
   const unsigned min_esverts = gfx_level >= GFX11 ? max_verts_per_prim /* gfx11 requires at least 1 primitive per TG */
                                : gfx_level >= GFX10_3 ? 29
                                                       : (24 - 1 + max_verts_per_prim);
   bool max_vert_out_per_gs_instance = false;
   unsigned max_esverts_base = 128;
   unsigned max_gsprims_base = 128; /* default prim group size clamp */

   /* Hardware has the following non-natural restrictions on the value
    * of GE_CNTL.VERT_GRP_SIZE based on the primitive type of the draw:
    *  - at most 252 for any line input primitive type
    *  - at most 251 for any quad input primitive type
    *  - at most 251 for triangle strips with adjacency (this happens to
    *    be the natural limit for triangle *lists* with adjacency)
    */
   max_esverts_base = MIN2(max_esverts_base, 251 + max_verts_per_prim - 1);

   if (gs_info) {
      unsigned max_out_verts_per_gsprim = gs_info->gs.vertices_out * gs_num_invocations;

      if (max_out_verts_per_gsprim <= 256) {
         if (max_out_verts_per_gsprim) {
            max_gsprims_base = MIN2(max_gsprims_base, 256 / max_out_verts_per_gsprim);
         }
      } else {
         /* Use special multi-cycling mode in which each GS
          * instance gets its own subgroup. Does not work with
          * tessellation. */
         max_vert_out_per_gs_instance = true;
         max_gsprims_base = 1;
         max_out_verts_per_gsprim = gs_info->gs.vertices_out;
      }

      esvert_lds_size = es_info->esgs_itemsize / 4;
      gsprim_lds_size = (gs_info->gs.gsvs_vertex_size / 4 + 1) * max_out_verts_per_gsprim;
   } else {
      /* VS and TES. */
      /* LDS size for passing data from GS to ES. */
      struct radv_streamout_info *so_info = &es_info->so;

      if (so_info->num_outputs) {
         /* Compute the same per-vertex LDS size as the NGG streamout lowering pass, which
          * allocates space for all outputs.
          * TODO: only allocate space for outputs that really need streamout.
          */
         const uint32_t num_outputs =
            es_info->stage == MESA_SHADER_VERTEX ? es_info->vs.num_outputs : es_info->tes.num_outputs;
         esvert_lds_size = 4 * num_outputs + 1;
      }

      /* GS stores Primitive IDs (one DWORD) into LDS at the address
       * corresponding to the ES thread of the provoking vertex. All
       * ES threads load and export PrimitiveID for their thread.
       */
      if (es_info->stage == MESA_SHADER_VERTEX && es_info->outinfo.export_prim_id)
         esvert_lds_size = MAX2(esvert_lds_size, 1);
   }

   unsigned max_gsprims = max_gsprims_base;
   unsigned max_esverts = max_esverts_base;

   if (esvert_lds_size)
      max_esverts = MIN2(max_esverts, target_lds_size / esvert_lds_size);
   if (gsprim_lds_size)
      max_gsprims = MIN2(max_gsprims, target_lds_size / gsprim_lds_size);

   max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
   clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, uses_adjacency);
   assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);

   if (esvert_lds_size || gsprim_lds_size) {
      /* Now that we have a rough proportionality between esverts
       * and gsprims based on the primitive type, scale both of them
       * down simultaneously based on required LDS space.
       *
       * We could be smarter about this if we knew how much vertex
       * reuse to expect.
       */
      unsigned lds_total = max_esverts * esvert_lds_size + max_gsprims * gsprim_lds_size;
      if (lds_total > target_lds_size) {
         max_esverts = max_esverts * target_lds_size / lds_total;
         max_gsprims = max_gsprims * target_lds_size / lds_total;

         max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
         clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, uses_adjacency);
         assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
      }
   }

   /* Round up towards full wave sizes for better ALU utilization. */
   if (!max_vert_out_per_gs_instance) {
      unsigned orig_max_esverts;
      unsigned orig_max_gsprims;
      unsigned wavesize;

      if (gs_info) {
         wavesize = gs_info->wave_size;
      } else {
         wavesize = es_info->wave_size;
      }

      do {
         orig_max_esverts = max_esverts;
         orig_max_gsprims = max_gsprims;

         max_esverts = align(max_esverts, wavesize);
         max_esverts = MIN2(max_esverts, max_esverts_base);
         if (esvert_lds_size)
            max_esverts = MIN2(max_esverts, (max_lds_size - max_gsprims * gsprim_lds_size) / esvert_lds_size);
         max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);

         /* Hardware restriction: minimum value of max_esverts */
         if (gfx_level == GFX10)
            max_esverts = MAX2(max_esverts, min_esverts - 1 + max_verts_per_prim);
         else
            max_esverts = MAX2(max_esverts, min_esverts);

         max_gsprims = align(max_gsprims, wavesize);
         max_gsprims = MIN2(max_gsprims, max_gsprims_base);
         if (gsprim_lds_size) {
            /* Don't count unusable vertices to the LDS
             * size. Those are vertices above the maximum
             * number of vertices that can occur in the
             * workgroup, which is e.g. max_gsprims * 3
             * for triangles.
             */
            unsigned usable_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
            max_gsprims = MIN2(max_gsprims, (max_lds_size - usable_esverts * esvert_lds_size) / gsprim_lds_size);
         }
         clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, uses_adjacency);
         assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
      } while (orig_max_esverts != max_esverts || orig_max_gsprims != max_gsprims);

      /* Verify the restriction. */
      if (gfx_level == GFX10)
         assert(max_esverts >= min_esverts - 1 + max_verts_per_prim);
      else
         assert(max_esverts >= min_esverts);
   } else {
      /* Hardware restriction: minimum value of max_esverts */
      if (gfx_level == GFX10)
         max_esverts = MAX2(max_esverts, min_esverts - 1 + max_verts_per_prim);
      else
         max_esverts = MAX2(max_esverts, min_esverts);
   }

   unsigned max_out_vertices = max_vert_out_per_gs_instance ? gs_info->gs.vertices_out
                               : gs_info ? max_gsprims * gs_num_invocations * gs_info->gs.vertices_out
                                         : max_esverts;
   assert(max_out_vertices <= 256);

   unsigned prim_amp_factor = 1;
   if (gs_info) {
      /* Number of output primitives per GS input primitive after
       * GS instancing. */
      prim_amp_factor = gs_info->gs.vertices_out;
   }

   /* On Gfx10, the GE only checks against the maximum number of ES verts
    * after allocating a full GS primitive. So we need to ensure that
    * whenever this check passes, there is enough space for a full
    * primitive without vertex reuse.
    */
   if (gfx_level == GFX10)
      out->hw_max_esverts = max_esverts - max_verts_per_prim + 1;
   else
      out->hw_max_esverts = max_esverts;

   out->max_gsprims = max_gsprims;
   out->max_out_verts = max_out_vertices;
   out->prim_amp_factor = prim_amp_factor;
   out->max_vert_out_per_gs_instance = max_vert_out_per_gs_instance;
   out->ngg_emit_size = max_gsprims * gsprim_lds_size;

   /* Don't count unusable vertices. */
   out->esgs_ring_size = MIN2(max_esverts, max_gsprims * max_verts_per_prim) * esvert_lds_size * 4;

   if (gs_info) {
      out->vgt_esgs_ring_itemsize = es_info->esgs_itemsize / 4;
   } else {
      out->vgt_esgs_ring_itemsize = 1;
   }

   assert(out->hw_max_esverts >= min_esverts); /* HW limitation */

   out->scratch_lds_base = gfx10_get_ngg_scratch_lds_base(device, es_info, gs_info, out);

   /* Get scratch LDS usage. */
   const struct radv_shader_info *info = gs_info ? gs_info : es_info;
   const unsigned scratch_lds_size = ac_ngg_get_scratch_lds_size(info->stage, info->workgroup_size, info->wave_size,
                                                                 pdev->use_ngg_streamout, info->has_ngg_culling, false);
   out->lds_size = out->scratch_lds_base + scratch_lds_size;

   unsigned workgroup_size =
      ac_compute_ngg_workgroup_size(max_esverts, max_gsprims * gs_num_invocations, max_out_vertices, prim_amp_factor);
   if (gs_info) {
      gs_info->workgroup_size = workgroup_size;
   }
   es_info->workgroup_size = workgroup_size;
}

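/* Decide the remaining NGG settings for a VS or TES without GS: whether to
 * enable NGG culling (based on the FS inputs, approximated from the ES
 * outputs when the FS is unknown), early primitive export (pre-GFX11,
 * single-block shaders only), and passthrough mode.
 */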
static void
radv_determine_ngg_settings(struct radv_device *device, struct radv_shader_stage *es_stage,
                            struct radv_shader_stage *fs_stage, const struct radv_graphics_state_key *gfx_state)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   uint64_t ps_inputs_read;

   assert(es_stage->stage == MESA_SHADER_VERTEX || es_stage->stage == MESA_SHADER_TESS_EVAL);
   assert(!fs_stage || fs_stage->stage == MESA_SHADER_FRAGMENT);

   if (fs_stage) {
      ps_inputs_read = fs_stage->nir->info.inputs_read;
   } else {
      /* Rely on the number of VS/TES outputs when the FS is unknown (for fast-link or unlinked ESO)
       * because this should be a good approximation of the number of FS inputs.
       */
      ps_inputs_read = es_stage->nir->info.outputs_written;

      /* Clear varyings that can't be PS inputs. */
      ps_inputs_read &= ~(VARYING_BIT_POS | VARYING_BIT_PSIZ);
   }

   unsigned num_vertices_per_prim = 0;
   if (es_stage->stage == MESA_SHADER_VERTEX) {
      num_vertices_per_prim = radv_get_num_vertices_per_prim(gfx_state);
   } else if (es_stage->stage == MESA_SHADER_TESS_EVAL) {
      num_vertices_per_prim = es_stage->nir->info.tess.point_mode                                   ? 1
                              : es_stage->nir->info.tess._primitive_mode == TESS_PRIMITIVE_ISOLINES ? 2
                                                                                                    : 3;
   }

   es_stage->info.has_ngg_culling =
      radv_consider_culling(pdev, es_stage->nir, ps_inputs_read, num_vertices_per_prim, &es_stage->info);

   nir_function_impl *impl = nir_shader_get_entrypoint(es_stage->nir);
   es_stage->info.has_ngg_early_prim_export = pdev->info.gfx_level < GFX11 && exec_list_is_singular(&impl->body);

   /* NGG passthrough mode should be disabled when culling and when the vertex shader
    * exports the primitive ID.
    */
   es_stage->info.is_ngg_passthrough = !es_stage->info.has_ngg_culling && !(es_stage->stage == MESA_SHADER_VERTEX &&
                                                                            es_stage->info.outinfo.export_prim_id);
}

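/* Link shader info between a producer stage and its consumer (NULL when the
 * next stage is unknown): fix up output params, compute NGG/GS subgroup
 * info, and propagate TCS/TES requirements across the stage interface.
 */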
static void
radv_link_shaders_info(struct radv_device *device, struct radv_shader_stage *producer,
                       struct radv_shader_stage *consumer, const struct radv_graphics_state_key *gfx_state)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);

   /* Export primitive ID and clip/cull distances if read by the FS, or export unconditionally when
    * the next stage is unknown (with graphics pipeline library).
    */
   if (producer->info.next_stage == MESA_SHADER_FRAGMENT ||
       !(gfx_state->lib_flags & VK_GRAPHICS_PIPELINE_LIBRARY_FRAGMENT_SHADER_BIT_EXT)) {
      const bool ps_prim_id_in = !consumer || consumer->info.ps.prim_id_input;
      const bool ps_clip_dists_in = !consumer || !!consumer->info.ps.input_clips_culls_mask;

      radv_set_vs_output_param(device, producer->nir, gfx_state, &producer->info, ps_prim_id_in, ps_clip_dists_in);
   }

   if (producer->stage == MESA_SHADER_VERTEX || producer->stage == MESA_SHADER_TESS_EVAL) {
      /* Compute NGG info (GFX10+) or GS info. */
      if (producer->info.is_ngg) {
         struct radv_shader_stage *gs_stage = consumer && consumer->stage == MESA_SHADER_GEOMETRY ? consumer : NULL;
         struct gfx10_ngg_info *out = gs_stage ? &gs_stage->info.ngg_info : &producer->info.ngg_info;

         /* Determine other NGG settings like culling for VS or TES without GS. */
         if (!gs_stage) {
            radv_determine_ngg_settings(device, producer, consumer, gfx_state);
         }

         gfx10_get_ngg_info(device, &producer->info, gs_stage ? &gs_stage->info : NULL, out);
      } else if (consumer && consumer->stage == MESA_SHADER_GEOMETRY) {
         struct radv_shader_info *gs_info = &consumer->info;
         struct radv_shader_info *es_info = &producer->info;

         es_info->workgroup_size = gs_info->workgroup_size;
      }

      if (consumer && consumer->stage == MESA_SHADER_GEOMETRY) {
         producer->info.gs_inputs_read = consumer->nir->info.inputs_read;
      }
   }

   if (producer->stage == MESA_SHADER_VERTEX && consumer && consumer->stage == MESA_SHADER_TESS_CTRL) {
      struct radv_shader_stage *vs_stage = producer;
      struct radv_shader_stage *tcs_stage = consumer;

      vs_stage->info.vs.tcs_inputs_via_lds = tcs_stage->nir->info.inputs_read;

      if (gfx_state->ts.patch_control_points) {
         vs_stage->info.workgroup_size =
            ac_compute_lshs_workgroup_size(pdev->info.gfx_level, MESA_SHADER_VERTEX, tcs_stage->info.num_tess_patches,
                                           gfx_state->ts.patch_control_points, tcs_stage->info.tcs.tcs_vertices_out);

         if (!radv_use_llvm_for_stage(pdev, MESA_SHADER_VERTEX)) {
            /* When the number of TCS input and output vertices is the same (typically 3):
             * - There is an equal number of LS and HS invocations
             * - In case of merged LSHS shaders, the LS and HS halves of the shader always process
             *   the exact same vertex. We can use this knowledge to optimize them.
             *
             * We don't set tcs_in_out_eq if the float controls differ because that might involve
             * different float modes for the same block and our optimizer doesn't handle an
             * instruction dominating another with a different mode.
             */
            vs_stage->info.vs.tcs_in_out_eq =
               pdev->info.gfx_level >= GFX9 &&
               gfx_state->ts.patch_control_points == tcs_stage->info.tcs.tcs_vertices_out &&
               vs_stage->nir->info.float_controls_execution_mode == tcs_stage->nir->info.float_controls_execution_mode;

            if (vs_stage->info.vs.tcs_in_out_eq) {
               vs_stage->info.vs.tcs_inputs_via_temp = vs_stage->nir->info.outputs_written &
                                                       ~vs_stage->nir->info.outputs_accessed_indirectly &
                                                       tcs_stage->nir->info.tess.tcs_same_invocation_inputs_read;
               vs_stage->info.vs.tcs_inputs_via_lds = tcs_stage->nir->info.tess.tcs_cross_invocation_inputs_read |
                                                      (tcs_stage->nir->info.tess.tcs_same_invocation_inputs_read &
                                                       tcs_stage->nir->info.inputs_read_indirectly) |
                                                      (tcs_stage->nir->info.tess.tcs_same_invocation_inputs_read &
                                                       vs_stage->nir->info.outputs_accessed_indirectly);
            }
         }
      }
   }

   /* Copy shader info between TCS<->TES. */
   if (producer->stage == MESA_SHADER_TESS_CTRL && consumer && consumer->stage == MESA_SHADER_TESS_EVAL) {
      struct radv_shader_stage *tcs_stage = producer;
      struct radv_shader_stage *tes_stage = consumer;

      tcs_stage->info.tcs.tes_reads_tess_factors = tes_stage->info.tes.reads_tess_factors;
      tcs_stage->info.tcs.tes_inputs_read = tes_stage->nir->info.inputs_read;
      tcs_stage->info.tcs.tes_patch_inputs_read = tes_stage->nir->info.patch_inputs_read;
      tcs_stage->info.tes._primitive_mode = tes_stage->nir->info.tess._primitive_mode;

      if (gfx_state->ts.patch_control_points)
         tes_stage->info.num_tess_patches = tcs_stage->info.num_tess_patches;
   }
}

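/* Merge the info of the first half of a GFX9+ merged shader (VS or TES) into
 * the HW stage that contains it (TCS or GS): combine resource usage and copy
 * the per-stage info.
 */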
static void
radv_nir_shader_info_merge(const struct radv_shader_stage *src, struct radv_shader_stage *dst)
{
   const struct radv_shader_info *src_info = &src->info;
   struct radv_shader_info *dst_info = &dst->info;

   assert((src->stage == MESA_SHADER_VERTEX && dst->stage == MESA_SHADER_TESS_CTRL) ||
          (src->stage == MESA_SHADER_VERTEX && dst->stage == MESA_SHADER_GEOMETRY) ||
          (src->stage == MESA_SHADER_TESS_EVAL && dst->stage == MESA_SHADER_GEOMETRY));

   dst_info->loads_push_constants |= src_info->loads_push_constants;
   dst_info->loads_dynamic_offsets |= src_info->loads_dynamic_offsets;
   dst_info->desc_set_used_mask |= src_info->desc_set_used_mask;
   dst_info->uses_view_index |= src_info->uses_view_index;
   dst_info->uses_prim_id |= src_info->uses_prim_id;
   dst_info->inline_push_constant_mask |= src_info->inline_push_constant_mask;

   /* Only inline all push constants if both stages allow it. */
   dst_info->can_inline_all_push_constants &= src_info->can_inline_all_push_constants;

   if (src->stage == MESA_SHADER_VERTEX) {
      dst_info->vs = src_info->vs;
   } else {
      dst_info->tes = src_info->tes;
   }

   if (dst->stage == MESA_SHADER_GEOMETRY)
      dst_info->gs.es_type = src->stage;
}

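/* Graphics stages in pipeline order; radv_nir_shader_info_link() walks this
 * backwards so each stage is linked after its consumer.
 */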
static const gl_shader_stage graphics_shader_order[] = {
   MESA_SHADER_VERTEX, MESA_SHADER_TESS_CTRL, MESA_SHADER_TESS_EVAL, MESA_SHADER_GEOMETRY,

   MESA_SHADER_TASK,   MESA_SHADER_MESH,
};

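/* Entry point for info linking: link every present stage with the next one
 * in pipeline order, then merge the first halves into their merged HW
 * shaders on GFX9+.
 */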
void
radv_nir_shader_info_link(struct radv_device *device, const struct radv_graphics_state_key *gfx_state,
                          struct radv_shader_stage *stages)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);

   /* Walk backwards to link. */
   struct radv_shader_stage *next_stage = stages[MESA_SHADER_FRAGMENT].nir ? &stages[MESA_SHADER_FRAGMENT] : NULL;

   for (int i = ARRAY_SIZE(graphics_shader_order) - 1; i >= 0; i--) {
      gl_shader_stage s = graphics_shader_order[i];
      if (!stages[s].nir)
         continue;

      radv_link_shaders_info(device, &stages[s], next_stage, gfx_state);
      next_stage = &stages[s];
   }

   if (pdev->info.gfx_level >= GFX9) {
      /* Merge shader info for VS+TCS. */
      if (stages[MESA_SHADER_VERTEX].nir && stages[MESA_SHADER_TESS_CTRL].nir) {
         radv_nir_shader_info_merge(&stages[MESA_SHADER_VERTEX], &stages[MESA_SHADER_TESS_CTRL]);
      }

      /* Merge shader info for VS+GS or TES+GS. */
      if ((stages[MESA_SHADER_VERTEX].nir || stages[MESA_SHADER_TESS_EVAL].nir) && stages[MESA_SHADER_GEOMETRY].nir) {
         gl_shader_stage pre_stage = stages[MESA_SHADER_TESS_EVAL].nir ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX;

         radv_nir_shader_info_merge(&stages[pre_stage], &stages[MESA_SHADER_GEOMETRY]);
      }
   }
}

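/* Map an API shader stage (plus its as_es/as_ls/NGG configuration) to the HW
 * stage it executes as on the given GFX level.
 */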
enum ac_hw_stage
radv_select_hw_stage(const struct radv_shader_info *const info, const enum amd_gfx_level gfx_level)
{
   switch (info->stage) {
   case MESA_SHADER_VERTEX:
      if (info->is_ngg)
         return AC_HW_NEXT_GEN_GEOMETRY_SHADER;
      else if (info->vs.as_es)
         return gfx_level >= GFX9 ? AC_HW_LEGACY_GEOMETRY_SHADER : AC_HW_EXPORT_SHADER;
      else if (info->vs.as_ls)
         return gfx_level >= GFX9 ? AC_HW_HULL_SHADER : AC_HW_LOCAL_SHADER;
      else
         return AC_HW_VERTEX_SHADER;
   case MESA_SHADER_TESS_EVAL:
      if (info->is_ngg)
         return AC_HW_NEXT_GEN_GEOMETRY_SHADER;
      else if (info->tes.as_es)
         return gfx_level >= GFX9 ? AC_HW_LEGACY_GEOMETRY_SHADER : AC_HW_EXPORT_SHADER;
      else
         return AC_HW_VERTEX_SHADER;
   case MESA_SHADER_TESS_CTRL:
      return AC_HW_HULL_SHADER;
   case MESA_SHADER_GEOMETRY:
      if (info->is_ngg)
         return AC_HW_NEXT_GEN_GEOMETRY_SHADER;
      else
         return AC_HW_LEGACY_GEOMETRY_SHADER;
   case MESA_SHADER_MESH:
      return AC_HW_NEXT_GEN_GEOMETRY_SHADER;
   case MESA_SHADER_FRAGMENT:
      return AC_HW_PIXEL_SHADER;
   case MESA_SHADER_COMPUTE:
   case MESA_SHADER_KERNEL:
   case MESA_SHADER_TASK:
   case MESA_SHADER_RAYGEN:
   case MESA_SHADER_ANY_HIT:
   case MESA_SHADER_CLOSEST_HIT:
   case MESA_SHADER_MISS:
   case MESA_SHADER_INTERSECTION:
   case MESA_SHADER_CALLABLE:
      return AC_HW_COMPUTE_SHADER;
   default:
      unreachable("Unsupported HW stage");
   }
}