/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_gpu_info.h"
#include "ac_nir.h"
#include "ac_nir_helpers.h"
#include "nir_builder.h"
#include "util/u_math.h"

/*
 * These NIR passes are used to lower NIR cross-stage I/O intrinsics into the
 * memory accesses that actually happen on the HW.
 *
 * Each input and output has a 16-byte (4 dwords) slot reserved for it, and
 * can have up to 4 components. Each component is 32 bits.
 *
 * ## VS-TCS-TES I/O - Terminology:
 *
 * * patch - group of vertices, used instead of primitives in tessellation
 * * per-vertex - input or output which can be different for every vertex
 * * per-patch - input or output which applies to a patch (a group of vertices)
 *
 * ## VS-TCS-TES I/O - How it works:
 *
 * ```
 * SW model:    SW VS         SW TCS    tessellator    SW TES
 *                ┊             ┊             ┊          ┊
 *              ┌────┐        ┌────┐        ┌────┐    ┌─────┐
 * HW pipeline: │ LS │─╮   ╭─>│ HS │─╮   ╭─>│ FF │ ╭─>│VS/ES│
 *              └────┘ │   │  └────┘ │   │  └────┘ │  └─────┘
 * Memory:             ╰─>LDS<──╯    ╰─>VRAM───────╯
 * ```
 *
 * * SW VS runs as a HW LS (Local Shader, merged into HS on GFX9+),
 *   and SW TCS runs as HW HS (Hull Shader).
 *   SW TES runs as either HW VS or HW ES (Export Shader).
 * * LS and HS share the same LDS space.
 * * LS (SW VS) stores outputs to LDS to be read by HS (SW TCS).
 * * HS (SW TCS) stores outputs in LDS if the HS (SW TCS) reads them.
 * * HS (SW TCS) stores outputs in VRAM if the next stage (SW TES) reads them.
 *
 * Side note: some old HW supports having TES read from the same LDS space where LS/HS write,
 * but Mesa always stores HS outputs to VRAM to avoid forcing TES waves to run on the same CU
 * as the LS/HS waves.
 *
 * ### Passing VS-TCS I/O in registers
 *
 * On GPUs that run SW VS and SW TCS on the same HW stage (HS on GFX9+),
 * I/O can be passed through registers instead of LDS when the following conditions are met:
 *
 * 1. TCS input and output patch size match
 * 2. Floating point execution modes in SW VS and SW TCS match
 * 3. The SW VS output is not written indirectly, and the corresponding SW TCS input is not read indirectly
 *
 * Some HS outputs could be passed through registers too, but this is a TODO.
 *
 * ### LDS layout used by VS-TCS:
 *
 * ```
 * TCS per-vertex inputs for patch 0  <─── 0
 * TCS per-vertex inputs for patch 1
 * TCS per-vertex inputs for patch 2  <─── hs_per_vertex_input_lds_offset (rel_patch_id = 2)
 * ...
 * TCS per-vertex outputs for patch 0 <─── hs_output_lds_offset (rel_patch_id = 0, per-vertex)
 * TCS per-patch outputs for patch 0  <─── hs_output_lds_offset (rel_patch_id = 0, per-patch)
 * TCS per-vertex outputs for patch 1
 * TCS per-patch outputs for patch 1
 * TCS per-vertex outputs for patch 2 <─── hs_output_lds_offset (rel_patch_id = 2, per-vertex)
 * TCS per-patch outputs for patch 2  <─── hs_output_lds_offset (rel_patch_id = 2, per-patch)
 * ...
 * ```
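 *
 * As a formula sketch of the diagram above (variable names here are illustrative,
 * not the exact builder values; the real computation is in
 * hs_per_vertex_input_lds_offset() below, which on GFX11+ also adds
 * AC_HS_MSG_VOTE_LDS_BYTES for the tf0/1 vote slot):
 *
 * ```
 * input_patch_stride = tcs_in_vtxcnt * lshs_vertex_stride;
 * hs_per_vertex_input_lds_offset = rel_patch_id * input_patch_stride +
 *                                  vertex_index * lshs_vertex_stride +
 *                                  mapped_slot * 16;
 * ```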
 *
 * ### VRAM layout used by TCS-TES I/O:
 *
 * ```
 * attr 0 of patch 0 vertex 0   <─── "off-chip LDS" offset
 * attr 0 of patch 0 vertex 1
 * attr 0 of patch 0 vertex 2
 * ...
 * attr 0 of patch 1 vertex 0
 * attr 0 of patch 1 vertex 1
 * attr 0 of patch 1 vertex 2   <─── hs_per_vertex_output_vmem_offset (attribute slot = 0, rel_patch_id = 1, vertex index = 1)
 * ...
 * attr 0 of patch 2 vertex 0
 * attr 0 of patch 2 vertex 1
 * attr 0 of patch 2 vertex 2
 * ...
 * attr 1 of patch 0 vertex 0
 * attr 1 of patch 0 vertex 1
 * attr 1 of patch 0 vertex 2
 * ...
 * ...
 * per-patch attr 0 of patch 0  <─── hs_out_patch_data_offset_amd
 * per-patch attr 0 of patch 1
 * per-patch attr 0 of patch 2  <─── hs_per_patch_output_vmem_offset (attribute slot = 0, rel_patch_id = 2)
 * ...
 * per-patch attr 1 of patch 0
 * per-patch attr 1 of patch 1
 * per-patch attr 1 of patch 2
 * ...
 * ```
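 *
 * In formula form (an illustrative sketch with made-up names; the actual helpers
 * are hs_per_vertex_output_vmem_offset() and hs_per_patch_output_vmem_offset() below):
 *
 * ```
 * attr_stride = tcs_num_patches * out_vertices_per_patch * 16;
 * hs_per_vertex_output_vmem_offset = attr_slot * attr_stride +
 *                                    rel_patch_id * out_vertices_per_patch * 16 +
 *                                    vertex_index * 16;
 * hs_per_patch_output_vmem_offset = hs_out_patch_data_offset_amd +
 *                                   attr_slot * tcs_num_patches * 16 +
 *                                   rel_patch_id * 16;
 * ```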
 *
 */

typedef struct {
   /* Which hardware generation we're dealing with */
   enum amd_gfx_level gfx_level;
   nir_tcs_info tcs_info;

   /* I/O semantic -> real location used by lowering. */
   ac_nir_map_io_driver_location map_io;

   /* Bit mask of TCS per-vertex inputs (VS outputs) which are passed via temporaries (VGPRs)
    * from VS to TCS because they are read using gl_InvocationID as the vertex index.
    *
    * If TCS cross-invocation reads or indirect reads of these inputs are present, they don't
    * prevent fast access via gl_InvocationID because those are just different ways of reading
    * the same values.
    *
    * An example where a TCS input is indexed by gl_InvocationID and some other index is
    * Unigine Heaven, where the position input is used for patch culling (with cross-invocation
    * access) and also read with gl_InvocationID to forward it to TES.
    *
    * Passing TCS inputs in VGPRs is only possible when:
    * - VS+TCS are merged (GFX9+).
    * - Input and output patch sizes are the same.
    */
   uint64_t tcs_inputs_via_temp;

   /* Bit mask of TCS per-vertex inputs (VS outputs) which are passed via LDS for cross-invocation
    * reads or indirect reads.
    */
   uint64_t tcs_inputs_via_lds;

   /* Bit mask of TCS outputs read by TES. */
   uint64_t tes_inputs_read;
   uint32_t tes_patch_inputs_read;

   /* True if the output patch fits the subgroup, so all TCS outputs are always written in the same
    * subgroup that reads them.
    */
   bool tcs_out_patch_fits_subgroup;

   /* Saved TCS tess factors for the tess factor writer. */
   nir_variable *tcs_tess_level_outer;
   nir_variable *tcs_tess_level_inner;
   unsigned tcs_tess_level_outer_base;
   unsigned tcs_tess_level_outer_mask;
   unsigned tcs_tess_level_inner_base;
   unsigned tcs_tess_level_inner_mask;
} lower_tess_io_state;

typedef struct {
   nir_def *outer;
   nir_def *inner;
} tess_levels;

#define TESS_LVL_MASK (VARYING_BIT_TESS_LEVEL_OUTER | VARYING_BIT_TESS_LEVEL_INNER)

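/* Masks of TCS outputs that have to be stored to VRAM, i.e. outputs that TES reads.
 * Tess levels are handled separately from the other per-patch outputs.
 */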
static uint64_t
tcs_vram_per_vtx_out_mask(nir_shader *shader, lower_tess_io_state *st)
{
   return st->tes_inputs_read & ~TESS_LVL_MASK;
}

static uint64_t
tcs_vram_tf_out_mask(nir_shader *shader, lower_tess_io_state *st)
{
   return st->tes_inputs_read & TESS_LVL_MASK;
}

static uint32_t
tcs_vram_per_patch_out_mask(nir_shader *shader, lower_tess_io_state *st)
{
   return st->tes_patch_inputs_read;
}

static bool
tcs_output_needs_vmem(nir_intrinsic_instr *intrin,
                      nir_shader *shader,
                      lower_tess_io_state *st)
{
   /* no_varying indicates that TES doesn't read the output. */
   if (nir_intrinsic_io_semantics(intrin).no_varying)
      return false;

   const unsigned loc = nir_intrinsic_io_semantics(intrin).location;
   const bool per_vertex = intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
                           intrin->intrinsic == nir_intrinsic_load_per_vertex_output;

   if (per_vertex) {
      return tcs_vram_per_vtx_out_mask(shader, st) & BITFIELD64_BIT(loc);
   } else if (loc == VARYING_SLOT_TESS_LEVEL_OUTER || loc == VARYING_SLOT_TESS_LEVEL_INNER) {
      return false;
   } else {
      return tcs_vram_per_patch_out_mask(shader, st) & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0);
   }
}

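/* Masks of TCS outputs that have to go through LDS: outputs that the TCS itself reads
 * back, plus tess levels unless every invocation is known to define them.
 */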
static uint64_t
tcs_lds_per_vtx_out_mask(nir_shader *shader)
{
   return shader->info.outputs_read & shader->info.outputs_written & ~TESS_LVL_MASK;
}

static uint64_t
tcs_lds_tf_out_mask(nir_shader *shader, lower_tess_io_state *st)
{
   return st->tcs_info.all_invocations_define_tess_levels ?
            0ull : (shader->info.outputs_written & TESS_LVL_MASK);
}

static uint32_t
tcs_lds_per_patch_out_mask(nir_shader *shader)
{
   return shader->info.patch_outputs_read & shader->info.patch_outputs_written;
}

static bool
tcs_output_needs_lds(nir_intrinsic_instr *intrin,
                     nir_shader *shader,
                     lower_tess_io_state *st)
{
   const unsigned loc = nir_intrinsic_io_semantics(intrin).location;
   const bool per_vertex = intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
                           intrin->intrinsic == nir_intrinsic_load_per_vertex_output;

   if (per_vertex) {
      return tcs_lds_per_vtx_out_mask(shader) & BITFIELD64_BIT(loc);
   } else if (loc == VARYING_SLOT_TESS_LEVEL_OUTER || loc == VARYING_SLOT_TESS_LEVEL_INNER) {
      return tcs_lds_tf_out_mask(shader, st) & BITFIELD64_BIT(loc);
   } else {
      return tcs_lds_per_patch_out_mask(shader) & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0);
   }
}

static bool
lower_ls_output_store(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      void *state)
{
   if (intrin->intrinsic != nir_intrinsic_store_output)
      return false;

   /* The ARB_shader_viewport_layer_array spec contains the
    * following issue:
    *
    *    2) What happens if gl_ViewportIndex or gl_Layer is
    *    written in the vertex shader and a geometry shader is
    *    present?
    *
    *    RESOLVED: The value written by the last vertex processing
    *    stage is used. If the last vertex processing stage
    *    (vertex, tessellation evaluation or geometry) does not
    *    statically assign to gl_ViewportIndex or gl_Layer, index
    *    or layer zero is assumed.
    *
    * So writes to those outputs in VS-as-LS are simply ignored.
    */
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   if (io_sem.location == VARYING_SLOT_LAYER || io_sem.location == VARYING_SLOT_VIEWPORT) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   lower_tess_io_state *st = (lower_tess_io_state *) state;

   /* When a VS output isn't read by TCS, don't emit anything. */
   if (io_sem.no_varying ||
       !((st->tcs_inputs_via_temp | st->tcs_inputs_via_lds) & BITFIELD64_BIT(io_sem.location))) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   if (st->tcs_inputs_via_lds & BITFIELD64_BIT(io_sem.location)) {
      b->cursor = nir_before_instr(&intrin->instr);

      nir_def *vertex_idx = nir_load_local_invocation_index(b);
      nir_def *base_off_var = nir_imul(b, vertex_idx, nir_load_lshs_vertex_stride_amd(b));

      unsigned mapped = ac_nir_map_io_location(io_sem.location, st->tcs_inputs_via_lds, st->map_io);
      nir_def *io_off = ac_nir_calc_io_off(b, intrin, nir_imm_int(b, 16u), 4u, mapped);
      unsigned write_mask = nir_intrinsic_write_mask(intrin);

      nir_def *off = nir_iadd_nuw(b, base_off_var, io_off);

      /* The first vec4 is reserved for the tf0/1 shader message group vote. */
      if (st->gfx_level >= GFX11)
         off = nir_iadd_imm_nuw(b, off, AC_HS_MSG_VOTE_LDS_BYTES);

      AC_NIR_STORE_IO(b, intrin->src[0].ssa, 0, write_mask, io_sem.high_16bits,
                      nir_store_shared, off, .write_mask = store_write_mask, .base = store_const_offset);
   }

   /* The store_output intrinsic on GFX9+ is used to pass the output to TCS via VGPRs. */
   if (!(st->tcs_inputs_via_temp & BITFIELD64_BIT(io_sem.location)))
      nir_instr_remove(&intrin->instr);

   return true;
}

static bool
filter_load_tcs_per_vertex_input(const nir_instr *instr,
                                 const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic != nir_intrinsic_load_per_vertex_input)
      return false;

   nir_src *off_src = nir_get_io_offset_src(intrin);
   nir_src *vertex_index_src = nir_get_io_arrayed_index_src(intrin);
   nir_instr *vertex_index_instr = vertex_index_src->ssa->parent_instr;
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);

   /* If this is accessed via gl_InvocationID, don't use LDS if tcs_inputs_via_temp is also set,
    * which indicates that VS and TCS have the same number of patch vertices and the input can be
    * read from VGPRs.
    */
   if (st->tcs_inputs_via_temp & BITFIELD64_BIT(io_sem.location) &&
       nir_src_is_const(*off_src) && /* array indexing */
       vertex_index_instr->type == nir_instr_type_intrinsic &&
       nir_instr_as_intrinsic(vertex_index_instr)->intrinsic == nir_intrinsic_load_invocation_id)
      return false;

   return true;
}

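/* LDS address of a TCS per-vertex input: input patches are laid out consecutively at the
 * start of LDS (after the tf0/1 vote slot on GFX11+), each holding tcs_in_vtxcnt vertices
 * of lshs_vertex_stride bytes.
 */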
static nir_def *
hs_per_vertex_input_lds_offset(nir_builder *b,
                               lower_tess_io_state *st,
                               nir_intrinsic_instr *instr)
{
   nir_def *tcs_in_vtxcnt = nir_load_patch_vertices_in(b);
   nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_def *vertex_index = nir_get_io_arrayed_index_src(instr)->ssa;

   nir_def *stride = nir_load_lshs_vertex_stride_amd(b);
   nir_def *tcs_in_patch_stride = nir_imul(b, tcs_in_vtxcnt, stride);
   nir_def *vertex_index_off = nir_imul(b, vertex_index, stride);

   nir_def *tcs_in_current_patch_offset = nir_imul(b, rel_patch_id, tcs_in_patch_stride);

   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(instr);
   const unsigned mapped = ac_nir_map_io_location(io_sem.location, st->tcs_inputs_via_lds, st->map_io);
   nir_def *io_offset = ac_nir_calc_io_off(b, instr, nir_imm_int(b, 16u), 4u, mapped);
   nir_def *lds_offset = nir_iadd_nuw(b, nir_iadd_nuw(b, tcs_in_current_patch_offset, vertex_index_off), io_offset);

   /* The first LDS vec4 is reserved for the tf0/1 shader message group vote. */
   return st->gfx_level >= GFX11 ? nir_iadd_imm_nuw(b, lds_offset, AC_HS_MSG_VOTE_LDS_BYTES) : lds_offset;
}

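/* Map an output location to its slot index within the LDS output area of a patch:
 * a prefix sum over the masks of outputs that actually live in LDS, with tess levels
 * placed before the other per-patch outputs.
 */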
static unsigned
hs_output_lds_map_io_location(nir_shader *shader,
                              const bool per_vertex,
                              const unsigned loc,
                              lower_tess_io_state *st)
{
   if (!per_vertex) {
      const uint64_t tf_mask = tcs_lds_tf_out_mask(shader, st);
      if (loc == VARYING_SLOT_TESS_LEVEL_INNER || loc == VARYING_SLOT_TESS_LEVEL_OUTER) {
         assert(tf_mask & BITFIELD64_BIT(loc));
         return util_bitcount64(tf_mask & BITFIELD64_MASK(loc));
      }

      const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(shader);
      assert(patch_out_mask & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0));
      return util_bitcount64(tf_mask) +
             util_bitcount(patch_out_mask & BITFIELD_MASK(loc - VARYING_SLOT_PATCH0));
   } else {
      const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(shader);
      assert(per_vertex_mask & BITFIELD64_BIT(loc));
      return util_bitcount64(per_vertex_mask & BITFIELD64_MASK(loc));
   }
}

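/* LDS address of a TCS output. The output area starts after all input patches; within each
 * patch, per-vertex outputs come first, followed by per-patch outputs. When intrin is NULL,
 * this returns the base address of the current patch's per-patch outputs.
 */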
static nir_def *
hs_output_lds_offset(nir_builder *b,
                     lower_tess_io_state *st,
                     nir_intrinsic_instr *intrin)
{
   bool per_vertex = intrin &&
                     (intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
                      intrin->intrinsic == nir_intrinsic_load_per_vertex_output);

   const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(b->shader);
   const uint64_t tf_mask = tcs_lds_tf_out_mask(b->shader, st);
   const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(b->shader);

   unsigned tcs_num_reserved_outputs = util_bitcount64(per_vertex_mask);
   unsigned tcs_num_reserved_patch_outputs = util_bitcount64(tf_mask) + util_bitcount(patch_out_mask);
   unsigned output_vertex_size = tcs_num_reserved_outputs * 16u;
   unsigned pervertex_output_patch_size = b->shader->info.tess.tcs_vertices_out * output_vertex_size;
   unsigned output_patch_stride = pervertex_output_patch_size + tcs_num_reserved_patch_outputs * 16u;

   nir_def *off = NULL;

   if (intrin) {
      const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
      const unsigned mapped = hs_output_lds_map_io_location(b->shader, per_vertex, io_sem.location, st);
      off = ac_nir_calc_io_off(b, intrin, nir_imm_int(b, 16u), 4, mapped);
   } else {
      off = nir_imm_int(b, 0);
   }

   nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_def *patch_offset = nir_imul_imm(b, rel_patch_id, output_patch_stride);

   nir_def *tcs_in_vtxcnt = nir_load_patch_vertices_in(b);
   nir_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
   nir_def *input_patch_size = nir_imul(b, tcs_in_vtxcnt, nir_load_lshs_vertex_stride_amd(b));
   nir_def *output_patch0_offset = nir_imul(b, input_patch_size, tcs_num_patches);
   nir_def *output_patch_offset = nir_iadd_nuw(b, patch_offset, output_patch0_offset);
   nir_def *lds_offset;

   if (per_vertex) {
      nir_def *vertex_index = nir_get_io_arrayed_index_src(intrin)->ssa;
      nir_def *vertex_index_off = nir_imul_imm(b, vertex_index, output_vertex_size);

      off = nir_iadd_nuw(b, off, vertex_index_off);
      lds_offset = nir_iadd_nuw(b, off, output_patch_offset);
   } else {
      off = nir_iadd_imm_nuw(b, off, pervertex_output_patch_size);
      lds_offset = nir_iadd_nuw(b, off, output_patch_offset);
   }

   /* The first LDS vec4 is reserved for the tf0/1 shader message group vote. */
   return st->gfx_level >= GFX11 ? nir_iadd_imm_nuw(b, lds_offset, AC_HS_MSG_VOTE_LDS_BYTES) : lds_offset;
}

static unsigned
hs_output_vram_map_io_location(nir_shader *shader,
                               const bool per_vertex,
                               const unsigned loc,
                               lower_tess_io_state *st)
{
   /* Unlinked shaders:
    * We are unaware of TES inputs while lowering TCS outputs.
    * The driver needs to pass a callback to map varyings to a fixed location.
    */
   if (st->map_io)
      return st->map_io(loc);

   /* Linked shaders:
    * Take advantage of having knowledge of TES inputs while lowering TCS outputs.
    * Map varyings to a prefix sum of the I/O mask to save space in VRAM.
    */
   if (!per_vertex) {
      const uint64_t tf_mask = tcs_vram_tf_out_mask(shader, st);
      if (loc == VARYING_SLOT_TESS_LEVEL_INNER || loc == VARYING_SLOT_TESS_LEVEL_OUTER) {
         assert(tf_mask & BITFIELD64_BIT(loc));
         return util_bitcount64(tf_mask & BITFIELD64_MASK(loc));
      }

      const uint32_t patch_out_mask = tcs_vram_per_patch_out_mask(shader, st);
      assert(patch_out_mask & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0));
      return util_bitcount64(tf_mask) +
             util_bitcount(patch_out_mask & BITFIELD_MASK(loc - VARYING_SLOT_PATCH0));
   } else {
      const uint64_t per_vertex_mask = tcs_vram_per_vtx_out_mask(shader, st);
      assert(per_vertex_mask & BITFIELD64_BIT(loc));
      return util_bitcount64(per_vertex_mask & BITFIELD64_MASK(loc));
   }
}

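/* Off-chip memory address of a TCS per-vertex output (see the VRAM layout diagram above):
 * each attribute slot spans all vertices of all patches.
 */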
static nir_def *
hs_per_vertex_output_vmem_offset(nir_builder *b,
                                 lower_tess_io_state *st,
                                 nir_intrinsic_instr *intrin)
{
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);

   nir_def *out_vertices_per_patch = b->shader->info.stage == MESA_SHADER_TESS_CTRL
                                         ? nir_imm_int(b, b->shader->info.tess.tcs_vertices_out)
                                         : nir_load_patch_vertices_in(b);

   nir_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
   nir_def *attr_stride = nir_imul(b, tcs_num_patches, nir_imul_imm(b, out_vertices_per_patch, 16u));
   nir_def *io_offset =
      ac_nir_calc_io_off(b, intrin, attr_stride, 4u,
                         hs_output_vram_map_io_location(b->shader, true, io_sem.location, st));

   nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_def *patch_offset = nir_imul(b, rel_patch_id, nir_imul_imm(b, out_vertices_per_patch, 16u));

   nir_def *vertex_index = nir_get_io_arrayed_index_src(intrin)->ssa;
   nir_def *vertex_index_off = nir_imul_imm(b, vertex_index, 16u);

   return nir_iadd_nuw(b, nir_iadd_nuw(b, patch_offset, vertex_index_off), io_offset);
}

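/* Off-chip memory address of a TCS per-patch output: the per-patch area starts at
 * hs_out_patch_data_offset_amd, and each attribute slot has a stride of
 * tcs_num_patches * 16 bytes.
 */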
static nir_def *
hs_per_patch_output_vmem_offset(nir_builder *b,
                                lower_tess_io_state *st,
                                nir_intrinsic_instr *intrin,
                                unsigned const_base_offset)
{
   nir_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
   nir_def *per_patch_data_offset = nir_load_hs_out_patch_data_offset_amd(b);

   nir_def *off =
      intrin
      ? ac_nir_calc_io_off(b, intrin, nir_imul_imm(b, tcs_num_patches, 16u), 4u,
                           hs_output_vram_map_io_location(b->shader, false, nir_intrinsic_io_semantics(intrin).location, st))
      : nir_imm_int(b, 0);

   if (const_base_offset)
      off = nir_iadd_nuw(b, off, nir_imul_imm(b, tcs_num_patches, const_base_offset));

   nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_def *patch_offset = nir_imul_imm(b, rel_patch_id, 16u);
   off = nir_iadd_nuw(b, off, per_patch_data_offset);
   return nir_iadd_nuw(b, off, patch_offset);
}

static nir_def *
lower_hs_per_vertex_input_load(nir_builder *b,
                               nir_instr *instr,
                               void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   nir_def *off = hs_per_vertex_input_lds_offset(b, st, intrin);
   nir_def *load = NULL;

   AC_NIR_LOAD_IO(load, b, intrin->def.num_components, intrin->def.bit_size, io_sem.high_16bits,
                  nir_load_shared, off);

   return load;
}

static nir_def *
lower_hs_output_store(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      lower_tess_io_state *st)
{
   assert(intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_store_output);

   nir_io_semantics semantics = nir_intrinsic_io_semantics(intrin);
   nir_def *store_val = intrin->src[0].ssa;
   const unsigned write_mask = nir_intrinsic_write_mask(intrin);
   const bool write_to_vmem = tcs_output_needs_vmem(intrin, b->shader, st);
   const bool write_to_lds = tcs_output_needs_lds(intrin, b->shader, st);

   if (write_to_vmem) {
      nir_def *vmem_off = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
                            ? hs_per_vertex_output_vmem_offset(b, st, intrin)
                            : hs_per_patch_output_vmem_offset(b, st, intrin, 0);

      nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
      nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
      nir_def *zero = nir_imm_int(b, 0);
      AC_NIR_STORE_IO(b, store_val, 0, write_mask, semantics.high_16bits,
                      nir_store_buffer_amd, hs_ring_tess_offchip, vmem_off, offchip_offset, zero,
                      .write_mask = store_write_mask, .base = store_const_offset,
                      .memory_modes = nir_var_shader_out, .access = ACCESS_COHERENT);
   }

   if (write_to_lds) {
      nir_def *lds_off = hs_output_lds_offset(b, st, intrin);
      AC_NIR_STORE_IO(b, store_val, 0, write_mask, semantics.high_16bits,
                      nir_store_shared, lds_off, .write_mask = store_write_mask, .base = store_const_offset);
   }

   /* Save the tess factors so that the tess factor writer can use them, or so that
    * the store_output instruction can be reconstructed later.
    */
   if (semantics.location == VARYING_SLOT_TESS_LEVEL_INNER ||
       semantics.location == VARYING_SLOT_TESS_LEVEL_OUTER) {
      const unsigned base = nir_intrinsic_base(intrin);
      const unsigned component = nir_intrinsic_component(intrin);

      if (semantics.location == VARYING_SLOT_TESS_LEVEL_INNER) {
         st->tcs_tess_level_inner_base = base;
         st->tcs_tess_level_inner_mask |= write_mask << component;

         if (st->tcs_info.all_invocations_define_tess_levels)
            ac_nir_store_var_components(b, st->tcs_tess_level_inner, store_val,
                                        component, write_mask);
      } else {
         st->tcs_tess_level_outer_base = base;
         st->tcs_tess_level_outer_mask |= write_mask << component;

         if (st->tcs_info.all_invocations_define_tess_levels)
            ac_nir_store_var_components(b, st->tcs_tess_level_outer, store_val,
                                        component, write_mask);
      }
   }

   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
}

static nir_def *
lower_hs_output_load(nir_builder *b,
                     nir_intrinsic_instr *intrin,
                     lower_tess_io_state *st)
{
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   const bool is_tess_factor = io_sem.location == VARYING_SLOT_TESS_LEVEL_INNER ||
                               io_sem.location == VARYING_SLOT_TESS_LEVEL_OUTER;

   if (is_tess_factor && st->tcs_info.all_invocations_define_tess_levels) {
      const unsigned component = nir_intrinsic_component(intrin);
      const unsigned num_components = intrin->def.num_components;
      const unsigned bit_size = intrin->def.bit_size;

      nir_def *var =
         io_sem.location == VARYING_SLOT_TESS_LEVEL_OUTER
            ? nir_load_var(b, st->tcs_tess_level_outer)
            : nir_load_var(b, st->tcs_tess_level_inner);

      return nir_extract_bits(b, &var, 1, component * bit_size, num_components, bit_size);
   }

   /* If an output is not stored by the shader, replace the output load by undef. */
   if (!tcs_output_needs_lds(intrin, b->shader, st))
      return nir_undef(b, intrin->def.num_components, intrin->def.bit_size);

   nir_def *off = hs_output_lds_offset(b, st, intrin);
   nir_def *load = NULL;

   AC_NIR_LOAD_IO(load, b, intrin->def.num_components, intrin->def.bit_size, io_sem.high_16bits,
                  nir_load_shared, off);

   return load;
}

static void
update_hs_barrier(nir_intrinsic_instr *intrin, lower_tess_io_state *st)
{
   /* Output loads and stores are lowered to shared memory access,
    * so we have to update the barriers to also reflect this.
    */
   unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
   if (mem_modes & nir_var_shader_out) {
      mem_modes |= nir_var_mem_shared;
      mem_modes &= ~nir_var_shader_out;
   }
   nir_intrinsic_set_memory_modes(intrin, mem_modes);

   mesa_scope exec_scope = nir_intrinsic_execution_scope(intrin);
   if (exec_scope == SCOPE_WORKGROUP && st->tcs_out_patch_fits_subgroup)
      nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP);

   mesa_scope mem_scope = nir_intrinsic_memory_scope(intrin);
   if (mem_scope == SCOPE_WORKGROUP && st->tcs_out_patch_fits_subgroup)
      nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP);
}

static nir_def *
lower_hs_output_access(nir_builder *b,
                       nir_instr *instr,
                       void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic == nir_intrinsic_store_output ||
       intrin->intrinsic == nir_intrinsic_store_per_vertex_output) {
      return lower_hs_output_store(b, intrin, st);
   } else if (intrin->intrinsic == nir_intrinsic_load_output ||
              intrin->intrinsic == nir_intrinsic_load_per_vertex_output) {
      return lower_hs_output_load(b, intrin, st);
   } else if (intrin->intrinsic == nir_intrinsic_barrier) {
      update_hs_barrier(intrin, st);
      return NIR_LOWER_INSTR_PROGRESS;
   } else {
      unreachable("intrinsic not supported by lower_hs_output_access");
   }
}

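/* Load the tess levels written by the current patch, either from the temporary variables
 * (when all invocations are known to define them) or from LDS. Levels the shader never
 * wrote default to zero.
 */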
static tess_levels
hs_load_tess_levels(nir_builder *b,
                    lower_tess_io_state *st)
{
   unsigned outer_comps, inner_comps;
   mesa_count_tess_level_components(b->shader->info.tess._primitive_mode,
                                    &outer_comps, &inner_comps);

   nir_def *outer = NULL;
   nir_def *inner = NULL;

   if (st->tcs_info.all_invocations_define_tess_levels) {
      if (st->tcs_tess_level_outer_mask) {
         outer = nir_load_var(b, st->tcs_tess_level_outer);
         outer = nir_trim_vector(b, outer, outer_comps);
      }

      if (inner_comps && st->tcs_tess_level_inner_mask) {
         inner = nir_load_var(b, st->tcs_tess_level_inner);
         inner = nir_trim_vector(b, inner, inner_comps);
      }
   } else {
      /* Base LDS address of per-patch outputs in the current patch. */
      nir_def *lds_base = hs_output_lds_offset(b, st, NULL);

      /* Load all tessellation factors (aka. tess levels) from LDS. */
      if (st->tcs_tess_level_outer_mask) {
         const unsigned mapped = hs_output_lds_map_io_location(b->shader, false, VARYING_SLOT_TESS_LEVEL_OUTER, st);
         outer = nir_load_shared(b, outer_comps, 32, lds_base, .base = mapped * 16);
      }

      if (inner_comps && st->tcs_tess_level_inner_mask) {
         const unsigned mapped = hs_output_lds_map_io_location(b->shader, false, VARYING_SLOT_TESS_LEVEL_INNER, st);
         inner = nir_load_shared(b, inner_comps, 32, lds_base, .base = mapped * 16);
      }
   }

   /* Set tess factors to zero if the shader did not write them. */
   if (!outer)
      outer = nir_imm_zero(b, outer_comps, 32);
   if (inner_comps && !inner)
      inner = nir_imm_zero(b, inner_comps, 32);

   tess_levels r = {
      .outer = outer,
      .inner = inner,
   };

   return r;
}

static void
hs_store_dynamic_control_word_gfx6(nir_builder *b)
{
   nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_def *tessfactor_ring = nir_load_ring_tess_factors_amd(b);
   nir_def *tess_factors_base = nir_load_ring_tess_factors_offset_amd(b);

   /* Store the dynamic HS control word. */
   nir_if *rel_patch_id_zero = nir_push_if(b, nir_ieq_imm(b, rel_patch_id, 0));
   nir_def *zero = nir_imm_int(b, 0);
   nir_def *ctrlw = nir_imm_int(b, 0x80000000u);
   nir_store_buffer_amd(b, ctrlw, tessfactor_ring, zero, tess_factors_base, zero,
                        .access = ACCESS_COHERENT);
   nir_pop_if(b, rel_patch_id_zero);
}

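/* Pad or trim a tess factor vector to exactly "comps" components, zero-filling components
 * the shader did not write.
 */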
static nir_def *
hs_resize_tess_factor(nir_builder *b, nir_def *tf, unsigned comps)
{
   if (!comps)
      return NULL;
   else if (!tf)
      return nir_imm_zero(b, comps, 32);
   else if (comps > tf->num_components)
      return nir_pad_vector_imm_int(b, tf, 0, comps);
   else if (comps < tf->num_components)
      return nir_trim_vector(b, tf, comps);
   else
      return tf;
}

static nir_if *
hs_if_invocation_id_zero(nir_builder *b)
{
   nir_def *invocation_id = nir_load_invocation_id(b);

   /* Only the 1st invocation of each patch needs to do this. */
   nir_if *invocation_id_zero = nir_push_if(b, nir_ieq_imm(b, invocation_id, 0));

   /* When the output patch size is <= 32, we can flatten the branch here
    * because we know for sure that at least 1 invocation in all waves will
    * take the branch.
    */
   if (b->shader->info.tess.tcs_vertices_out <= 32)
      invocation_id_zero->control = nir_selection_control_divergent_always_taken;

   return invocation_id_zero;
}

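/* Whether the given tess level component affects tessellation in the given primitive mode:
 * isolines use only outer[0..1], triangles use outer[0..2] and inner[0], and quads use all
 * outer and inner components.
 */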
static nir_def *
tess_level_has_effect(nir_builder *b, nir_def *prim_mode, unsigned comp, bool outer)
{
   if (outer && comp <= 1)
      return nir_imm_true(b);
   else if ((outer && comp == 2) || (!outer && comp == 0))
      return nir_ine_imm(b, prim_mode, TESS_PRIMITIVE_ISOLINES);
   else if ((outer && comp == 3) || (!outer && comp == 1))
      return nir_ieq_imm(b, prim_mode, TESS_PRIMITIVE_QUADS);
   else
      unreachable("invalid comp");
}

/* Return true if memory should be used. If false is returned, the shader message has been used. */
static nir_def *
hs_msg_group_vote_use_memory(nir_builder *b, lower_tess_io_state *st,
                             tess_levels *tessfactors, nir_def *prim_mode)
{
   /* Don't do the group vote and send the message directly if tess level values were determined
    * by nir_gather_tcs_info at compile time.
    *
    * Disable the shader cache if you change the AMD_FAST_HS_MSG environment variable.
    */
   if (debug_get_bool_option("AMD_FAST_HS_MSG", true) &&
       (st->tcs_info.all_tess_levels_are_effectively_zero ||
        st->tcs_info.all_tess_levels_are_effectively_one)) {
      nir_if *if_subgroup0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
      {
         /* m0[0] == 0 means all TF are 0 in the workgroup.
          * m0[0] == 1 means all TF are 1 in the workgroup.
          */
         nir_def *m0 = nir_imm_int(b, st->tcs_info.all_tess_levels_are_effectively_zero ? 0 : 1);
         nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_HS_TESSFACTOR);
      }
      nir_pop_if(b, if_subgroup0);
      return nir_imm_false(b);
   }

   /* Initialize the first LDS dword for the tf0/1 group vote at the beginning of TCS. */
   nir_block *start_block = nir_start_block(nir_shader_get_entrypoint(b->shader));
   nir_builder top_b = nir_builder_at(nir_before_block(start_block));

   nir_if *thread0 = nir_push_if(&top_b,
                                 nir_iand(&top_b, nir_ieq_imm(&top_b, nir_load_subgroup_id(&top_b), 0),
                                          nir_inverse_ballot(&top_b, 1, nir_imm_ivec4(&top_b, 0x1, 0, 0, 0))));
   {
      /* 0x3 is the initial bitmask (tf0 | tf1). Each subgroup will do atomic iand on it for the vote. */
      nir_store_shared(&top_b, nir_imm_int(&top_b, 0x3), nir_imm_int(&top_b, 0),
                       .write_mask = 0x1, .align_mul = 4);
   }
   nir_pop_if(&top_b, thread0);

   /* Insert a barrier to wait for the initialization above if there hasn't been any other barrier
    * in the shader.
    */
   if (!st->tcs_info.always_executes_barrier) {
      nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
                  .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
   }

   /* Use s_sendmsg to tell the hw whether the whole workgroup has either of these cases:
    *
    * tf0: All patches in the workgroup have at least one outer tess level component either
    *      in the [-inf, 0] range or equal to NaN, causing them to be discarded. Inner tess levels
    *      have no effect.
    *
    * tf1: All patches in the workgroup have the values of tess levels set to 1 or equivalent numbers,
    *      which doesn't discard any patches. Each spacing interprets different tess level ranges as 1:
    *
    *      1) equal_spacing, fractional_odd_spacing, and unknown spacing
    *      For undiscarded patches, the tessellator clamps all tess levels to 1. If all tess levels
    *      are in the (0, 1] range, which is effectively 1, untessellated patches are
    *      drawn.
    *
    *      2) fractional_even_spacing
    *      For undiscarded patches, the tessellator clamps all tess levels to 2 (both outer and inner)
    *      except isolines, which clamp the first outer tess level component to 1. If all outer tess
    *      levels are in the (0, 2] or (0, 1] range (for outer[0] of isolines) and all inner tess levels
    *      are in the [-inf, 2] range, the tf1 message can be used. The tessellator will receive 1 via
    *      the message, but will clamp them to 2 or keep 1 (for outer[0] of isolines).
    *
    *      If we make this mutually exclusive with tf0, we only have to compare against the upper bound.
    */

   /* Determine tf0/tf1 for the subgroup at the end of TCS. */
   nir_if *if_invocation_id_zero = hs_if_invocation_id_zero(b);
   {
      *tessfactors = hs_load_tess_levels(b, st);

      nir_def *lane_tf_effectively_0 = nir_imm_false(b);
      for (unsigned i = 0; i < tessfactors->outer->num_components; i++) {
         nir_def *valid = tess_level_has_effect(b, prim_mode, i, true);
         /* fgeu returns true for NaN */
         nir_def *le0 = nir_fgeu(b, nir_imm_float(b, 0), nir_channel(b, tessfactors->outer, i));
         lane_tf_effectively_0 = nir_ior(b, lane_tf_effectively_0, nir_iand(b, le0, valid));
      }

      /* Use case 1: unknown spacing */
      nir_def *lane_tf_effectively_1 = nir_imm_true(b);
      for (unsigned i = 0; i < tessfactors->outer->num_components; i++) {
         nir_def *valid = tess_level_has_effect(b, prim_mode, i, true);
         nir_def *le1 = nir_fle_imm(b, nir_channel(b, tessfactors->outer, i), 1);
         lane_tf_effectively_1 = nir_iand(b, lane_tf_effectively_1, nir_ior(b, le1, nir_inot(b, valid)));
      }

      if (tessfactors->inner) {
         for (unsigned i = 0; i < tessfactors->inner->num_components; i++) {
            nir_def *valid = tess_level_has_effect(b, prim_mode, i, false);
            nir_def *le1 = nir_fle_imm(b, nir_channel(b, tessfactors->inner, i), 1);
            lane_tf_effectively_1 = nir_iand(b, lane_tf_effectively_1, nir_ior(b, le1, nir_inot(b, valid)));
         }
      }

      /* Make them mutually exclusive. */
      lane_tf_effectively_1 = nir_iand(b, lane_tf_effectively_1, nir_inot(b, lane_tf_effectively_0));

      nir_def *subgroup_uses_tf0 = nir_b2i32(b, nir_vote_all(b, 1, lane_tf_effectively_0));
      nir_def *subgroup_uses_tf1 = nir_b2i32(b, nir_vote_all(b, 1, lane_tf_effectively_1));

      /* Pack the value for LDS. Encoding:
       *    0 = none of the below
       *    1 = all tess factors are effectively 0
       *    2 = all tess factors are effectively 1
       *    3 = invalid
       *
       * Since we will do a bitwise AND reduction across all waves, 3 can never occur.
       */
      nir_def *packed_tf01_mask = nir_ior(b, subgroup_uses_tf0,
                                          nir_ishl_imm(b, subgroup_uses_tf1, 1));

      /* This function is only called within a block that only executes for patch invocation 0, so we
       * only need to mask out invocation 0 of other patches in the subgroup to execute on only 1 lane.
       *
       * Since patch invocations are placed sequentially in the subgroup, we know that invocation 0
       * of the lowest patch must be somewhere in BITFIELD_MASK(tcs_vertices_out) lanes.
       */
      const unsigned tcs_vertices_out = b->shader->info.tess.tcs_vertices_out;
      assert(tcs_vertices_out <= 32);
      nir_def *is_first_active_lane =
         nir_inverse_ballot(b, 1, nir_imm_ivec4(b, BITFIELD_MASK(tcs_vertices_out), 0, 0, 0));

      /* Only the first active invocation in each subgroup performs the AND reduction through LDS. */
      nir_if *if_first_active_lane = nir_push_if(b, is_first_active_lane);
      if_first_active_lane->control = nir_selection_control_divergent_always_taken;
      {
         /* Use atomic iand to combine results from all subgroups. */
         nir_shared_atomic(b, 32, nir_imm_int(b, 0), packed_tf01_mask,
                           .atomic_op = nir_atomic_op_iand);
      }
      nir_pop_if(b, if_first_active_lane);
   }
   nir_pop_if(b, if_invocation_id_zero);
   /* The caller will reuse these. */
   tessfactors->outer = nir_if_phi(b, tessfactors->outer, nir_undef(b, tessfactors->outer->num_components, 32));
   if (tessfactors->inner) /* Isolines don't have inner tess levels. */
      tessfactors->inner = nir_if_phi(b, tessfactors->inner, nir_undef(b, tessfactors->inner->num_components, 32));

   /* Wait for all waves to execute the LDS atomic. */
   nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
               .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);

   /* Read the result from LDS. Only 1 lane should load it to prevent LDS bank conflicts. */
   nir_def *lds_result;
   nir_if *if_lane0 = nir_push_if(b, nir_inverse_ballot(b, 1, nir_imm_ivec4(b, 0x1, 0, 0, 0)));
   if_lane0->control = nir_selection_control_divergent_always_taken;
   {
      lds_result = nir_load_shared(b, 1, 32, nir_imm_int(b, 0), .align_mul = 4);
   }
   nir_pop_if(b, if_lane0);
   lds_result = nir_if_phi(b, lds_result, nir_undef(b, 1, 32));
   lds_result = nir_read_invocation(b, lds_result, nir_imm_int(b, 0));

   /* Determine the vote value and send the message. */
   nir_def *use_memory = nir_ieq_imm(b, lds_result, 0);

   nir_if *if_subgroup0_sendmsg = nir_push_if(b, nir_iand(b, nir_inot(b, use_memory),
                                                          nir_ieq_imm(b, nir_load_subgroup_id(b), 0)));
   {
      /* m0[0] == 0 means all TF are 0 in the workgroup.
       * m0[0] == 1 means all TF are 1 in the workgroup.
       */
      nir_def *m0 = nir_iadd_imm(b, lds_result, -1);
      nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_HS_TESSFACTOR);
   }
   nir_pop_if(b, if_subgroup0_sendmsg);

   return use_memory;
}

static void
hs_store_tess_factors_for_tessellator(nir_builder *b, enum amd_gfx_level gfx_level,
                                      enum tess_primitive_mode prim_mode,
                                      tess_levels tessfactors)
{
   nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_def *tessfactor_ring = nir_load_ring_tess_factors_amd(b);
   nir_def *tess_factors_base = nir_load_ring_tess_factors_offset_amd(b);
   nir_def *zero = nir_imm_int(b, 0);

   const unsigned tess_factors_const_offset = gfx_level <= GFX8 ? 4 : 0;
   unsigned outer_comps, inner_comps;

   mesa_count_tess_level_components(prim_mode, &outer_comps, &inner_comps);

   nir_def *tess_factors_offset =
      nir_imul_imm(b, rel_patch_id, (inner_comps + outer_comps) * 4u);

   nir_def *tf_outer = hs_resize_tess_factor(b, tessfactors.outer, outer_comps);
   nir_def *tf_inner = hs_resize_tess_factor(b, tessfactors.inner, inner_comps);

   /* Store tess factors for the tessellator */
   if (prim_mode == TESS_PRIMITIVE_ISOLINES) {
      /* LINES reversal */
      nir_def *t = nir_vec2(b, nir_channel(b, tf_outer, 1), nir_channel(b, tf_outer, 0));
      nir_store_buffer_amd(b, t, tessfactor_ring, tess_factors_offset, tess_factors_base, zero,
                           .base = tess_factors_const_offset, .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
   } else if (prim_mode == TESS_PRIMITIVE_TRIANGLES) {
      nir_def *t = nir_vec4(b, nir_channel(b, tf_outer, 0), nir_channel(b, tf_outer, 1),
                               nir_channel(b, tf_outer, 2), nir_channel(b, tf_inner, 0));
      nir_store_buffer_amd(b, t, tessfactor_ring, tess_factors_offset, tess_factors_base, zero,
                           .base = tess_factors_const_offset, .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
   } else {
      nir_store_buffer_amd(b, tf_outer, tessfactor_ring, tess_factors_offset, tess_factors_base, zero,
                           .base = tess_factors_const_offset, .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
      nir_store_buffer_amd(b, tf_inner, tessfactor_ring, tess_factors_offset, tess_factors_base, zero,
                           .base = tess_factors_const_offset + 4u * outer_comps,
                           .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
   }
}

static void
hs_store_tess_factors_for_tes(nir_builder *b, tess_levels tessfactors, lower_tess_io_state *st)
{
   nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
   nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
   nir_def *zero = nir_imm_int(b, 0);

   /* For linked shaders, we must only write the tess factors that the TES actually reads,
    * otherwise we would write to a memory location reserved for another per-patch output.
    */
   const bool tes_reads_outer = st->tes_inputs_read & VARYING_BIT_TESS_LEVEL_OUTER;
   const bool tes_reads_inner = st->tes_inputs_read & VARYING_BIT_TESS_LEVEL_INNER;

   if (st->tcs_tess_level_outer_mask && tes_reads_outer) {
      const unsigned tf_outer_loc = hs_output_vram_map_io_location(b->shader, false, VARYING_SLOT_TESS_LEVEL_OUTER, st);
      nir_def *vmem_off_outer = hs_per_patch_output_vmem_offset(b, st, NULL, tf_outer_loc * 16);

      nir_store_buffer_amd(b, tessfactors.outer, hs_ring_tess_offchip,
                           vmem_off_outer, offchip_offset, zero,
                           .memory_modes = nir_var_shader_out,
                           .access = ACCESS_COHERENT);
   }

   if (tessfactors.inner && st->tcs_tess_level_inner_mask && tes_reads_inner) {
      const unsigned tf_inner_loc = hs_output_vram_map_io_location(b->shader, false, VARYING_SLOT_TESS_LEVEL_INNER, st);
      nir_def *vmem_off_inner = hs_per_patch_output_vmem_offset(b, st, NULL, tf_inner_loc * 16);

      nir_store_buffer_amd(b, tessfactors.inner, hs_ring_tess_offchip,
                           vmem_off_inner, offchip_offset, zero,
                           .memory_modes = nir_var_shader_out,
                           .access = ACCESS_COHERENT);
   }
}

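/* TCS epilogue, appended to the end of the shader: waits for the tess level LDS stores if
 * needed, performs the tf0/1 group vote on GFX11+, and makes invocation 0 of each patch
 * store the tess factors to the tess factor ring and, when TES reads them, to the
 * off-chip ring.
 */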
static void
hs_finale(nir_shader *shader, lower_tess_io_state *st)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);
   nir_block *last_block = nir_impl_last_block(impl);
   assert(last_block);

   nir_builder builder = nir_builder_at(nir_after_block(last_block));
   nir_builder *b = &builder; /* This is to avoid writing &builder everywhere. */

   /* If tess factors are loaded from LDS, wait for their LDS stores. */
   if (!st->tcs_info.all_invocations_define_tess_levels) {
      mesa_scope scope = st->tcs_out_patch_fits_subgroup ? SCOPE_SUBGROUP : SCOPE_WORKGROUP;
      nir_barrier(b, .execution_scope = scope, .memory_scope = scope,
                     .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
      st->tcs_info.always_executes_barrier = true;
   }

   nir_def *prim_mode = nir_load_tcs_primitive_mode_amd(b);
   nir_def *use_memory = NULL;
   tess_levels tessfactors = {0};

   /* This also loads tess levels for patch invocation 0. */
   if (st->gfx_level >= GFX11)
      use_memory = hs_msg_group_vote_use_memory(b, st, &tessfactors, prim_mode);

   /* Only the 1st invocation of each patch needs to access VRAM and/or LDS. */
   nir_if *if_invocation_id_zero = hs_if_invocation_id_zero(b);
   {
      if (!tessfactors.outer)
         tessfactors = hs_load_tess_levels(b, st);

      nir_if *if_use_memory = NULL;
      if (use_memory != NULL)
         if_use_memory = nir_push_if(b, use_memory);

      if (st->gfx_level <= GFX8)
         hs_store_dynamic_control_word_gfx6(b);

      nir_if *if_triangles = nir_push_if(b, nir_ieq_imm(b, prim_mode, TESS_PRIMITIVE_TRIANGLES));
      {
         hs_store_tess_factors_for_tessellator(b, st->gfx_level, TESS_PRIMITIVE_TRIANGLES, tessfactors);
      }
      nir_push_else(b, if_triangles);
      {
         nir_if *if_isolines = nir_push_if(b, nir_ieq_imm(b, prim_mode, TESS_PRIMITIVE_ISOLINES));
         {
            hs_store_tess_factors_for_tessellator(b, st->gfx_level, TESS_PRIMITIVE_ISOLINES, tessfactors);
         }
         nir_push_else(b, if_isolines);
         {
            hs_store_tess_factors_for_tessellator(b, st->gfx_level, TESS_PRIMITIVE_QUADS, tessfactors);
         }
         nir_pop_if(b, if_isolines);
      }
      nir_pop_if(b, if_triangles);

      if (use_memory != NULL)
         nir_pop_if(b, if_use_memory);

      nir_if *if_tes_reads_tf = nir_push_if(b, nir_load_tcs_tess_levels_to_tes_amd(b));
      {
         hs_store_tess_factors_for_tes(b, tessfactors, st);
      }
      nir_pop_if(b, if_tes_reads_tf);
   }
   nir_pop_if(b, if_invocation_id_zero);

   nir_metadata_preserve(impl, nir_metadata_none);
}

static nir_def *
lower_tes_input_load(nir_builder *b,
                     nir_instr *instr,
                     void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   nir_def *offchip_ring = nir_load_ring_tess_offchip_amd(b);
   nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
   nir_def *off = intrin->intrinsic == nir_intrinsic_load_per_vertex_input
                    ? hs_per_vertex_output_vmem_offset(b, st, intrin)
                    : hs_per_patch_output_vmem_offset(b, st, intrin, 0);

   nir_def *zero = nir_imm_int(b, 0);
   nir_def *load = NULL;

   AC_NIR_LOAD_IO(load, b, intrin->def.num_components, intrin->def.bit_size, io_sem.high_16bits,
                  nir_load_buffer_amd, offchip_ring, off, offchip_offset, zero, .access = ACCESS_COHERENT);

   return load;
}

static bool
filter_hs_output_access(const nir_instr *instr,
                        UNUSED const void *st)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_store_output ||
          intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_load_output ||
          intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_barrier;
}

static bool
filter_any_input_access(const nir_instr *instr,
                        UNUSED const void *st)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_load_input ||
          intrin->intrinsic == nir_intrinsic_load_per_vertex_input;
}

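/* Lower LS (SW VS) output stores to LDS stores, keeping the store_output intrinsics that
 * pass outputs to TCS via VGPRs on GFX9+.
 */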
void
ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
                               ac_nir_map_io_driver_location map,
                               enum amd_gfx_level gfx_level,
                               bool tcs_in_out_eq,
                               uint64_t tcs_inputs_via_temp,
                               uint64_t tcs_inputs_via_lds)
{
   assert(shader->info.stage == MESA_SHADER_VERTEX);
   assert(gfx_level >= GFX9 || !tcs_in_out_eq);

   lower_tess_io_state state = {
      .gfx_level = gfx_level,
      .map_io = map,
   };

   if (tcs_in_out_eq) {
      state.tcs_inputs_via_temp = tcs_inputs_via_temp;
      state.tcs_inputs_via_lds = tcs_inputs_via_lds;
   } else {
      state.tcs_inputs_via_lds = tcs_inputs_via_lds | tcs_inputs_via_temp;
   }

   nir_shader_intrinsics_pass(shader, lower_ls_output_store,
                              nir_metadata_control_flow,
                              &state);
}

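/* Lower TCS per-vertex input loads to LDS loads, or to reads of the
 * temporaries written by the LS when tcs_in_out_eq allows it. Without
 * tcs_in_out_eq, every input the shader reads is assumed to live in LDS.
 */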
void
ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
                              ac_nir_map_io_driver_location map,
                              enum amd_gfx_level gfx_level,
                              bool tcs_in_out_eq,
                              uint64_t tcs_inputs_via_temp,
                              uint64_t tcs_inputs_via_lds)
{
   assert(shader->info.stage == MESA_SHADER_TESS_CTRL);
   assert(gfx_level >= GFX9 || !tcs_in_out_eq);

   lower_tess_io_state state = {
      .gfx_level = gfx_level,
      .map_io = map,
   };

   if (tcs_in_out_eq) {
      state.tcs_inputs_via_temp = tcs_inputs_via_temp;
      state.tcs_inputs_via_lds = tcs_inputs_via_lds;
   } else {
      state.tcs_inputs_via_lds = shader->info.inputs_read;
   }

   nir_shader_lower_instructions(shader,
                                 filter_load_tcs_per_vertex_input,
                                 lower_hs_per_vertex_input_load,
                                 &state);
}

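/* Lower TCS output stores and loads to LDS and/or the offchip ring buffer,
 * and append the epilogue that writes the tess factors (hs_finale).
 * tes_inputs_read and tes_patch_inputs_read tell the pass which outputs TES
 * consumes and therefore must reach the ring buffer.
 */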
void
ac_nir_lower_hs_outputs_to_mem(nir_shader *shader, const nir_tcs_info *info,
                               ac_nir_map_io_driver_location map,
                               enum amd_gfx_level gfx_level,
                               uint64_t tes_inputs_read,
                               uint32_t tes_patch_inputs_read,
                               unsigned wave_size)
{
   assert(shader->info.stage == MESA_SHADER_TESS_CTRL);

   lower_tess_io_state state = {
      .gfx_level = gfx_level,
      .tcs_info = *info,
      .tes_inputs_read = tes_inputs_read,
      .tes_patch_inputs_read = tes_patch_inputs_read,
      .tcs_out_patch_fits_subgroup = wave_size % shader->info.tess.tcs_vertices_out == 0,
      .map_io = map,
   };

   if (state.tcs_info.all_invocations_define_tess_levels) {
      nir_function_impl *impl = nir_shader_get_entrypoint(shader);
      state.tcs_tess_level_outer =
         nir_local_variable_create(impl, glsl_vec4_type(), "tess outer");
      state.tcs_tess_level_inner =
         nir_local_variable_create(impl, glsl_vec4_type(), "tess inner");
   }

   nir_shader_lower_instructions(shader,
                                 filter_hs_output_access,
                                 lower_hs_output_access,
                                 &state);

   hs_finale(shader, &state);

   /* Clean up the local variables created for the tess levels. */
   if (state.tcs_info.all_invocations_define_tess_levels) {
      NIR_PASS(_, shader, nir_lower_vars_to_ssa);
      NIR_PASS(_, shader, nir_remove_dead_variables, nir_var_function_temp, NULL);
      NIR_PASS(_, shader, nir_lower_alu_to_scalar, NULL, NULL);
      NIR_PASS(_, shader, nir_lower_phis_to_scalar, true);
   }
}

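/* Lower TES input loads to loads from the offchip ring buffer where the TCS
 * stored its outputs.
 */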
void
ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
                               ac_nir_map_io_driver_location map)
{
   assert(shader->info.stage == MESA_SHADER_TESS_EVAL);

   lower_tess_io_state state = {
      .map_io = map,
      .tes_inputs_read = shader->info.inputs_read,
      .tes_patch_inputs_read = shader->info.patch_inputs_read,
   };

   nir_shader_lower_instructions(shader,
                                 filter_any_input_access,
                                 lower_tes_input_load,
                                 &state);
}

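/* Compute how many tessellation patches fit in one HS workgroup and how much
 * LDS the workgroup needs. Sizes are in bytes (each I/O slot takes 16 bytes),
 * except *hw_lds_size, which is returned in lds_encode_granularity units.
 */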
void
ac_nir_compute_tess_wg_info(const struct radeon_info *info, const struct shader_info *tcs_info,
                            unsigned wave_size, bool tess_uses_primid, bool all_invocations_define_tess_levels,
                            unsigned num_tcs_input_cp, unsigned lds_input_vertex_size,
                            unsigned num_mem_tcs_outputs, unsigned num_mem_tcs_patch_outputs,
                            unsigned *num_patches_per_wg, unsigned *hw_lds_size)
{
   unsigned num_tcs_output_cp = tcs_info->tess.tcs_vertices_out;
   unsigned lds_output_vertex_size =
      util_bitcount64(tcs_info->outputs_read & tcs_info->outputs_written & ~TESS_LVL_MASK) * 16;
   unsigned lds_perpatch_output_patch_size =
      (util_bitcount64(all_invocations_define_tess_levels ?
                          0 : tcs_info->outputs_written & TESS_LVL_MASK) +
       util_bitcount(tcs_info->patch_outputs_read & tcs_info->patch_outputs_written)) * 16;

   unsigned lds_per_patch = num_tcs_input_cp * lds_input_vertex_size +
                            num_tcs_output_cp * lds_output_vertex_size +
                            lds_perpatch_output_patch_size;
   unsigned mem_per_patch = (num_tcs_output_cp * num_mem_tcs_outputs + num_mem_tcs_patch_outputs) * 16;
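   /* Worked example with hypothetical numbers: 3 input CPs, 4 output CPs,
    * 2 input slots in LDS (lds_input_vertex_size = 32), 1 per-vertex output
    * read back (lds_output_vertex_size = 16), no per-patch LDS outputs:
    * lds_per_patch = 3*32 + 4*16 + 0 = 160 bytes. If TES then reads 2
    * per-vertex and 1 per-patch output, mem_per_patch = (4*2 + 1)*16 = 144 bytes.
    */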
   unsigned num_patches = ac_compute_num_tess_patches(info, num_tcs_input_cp, num_tcs_output_cp, mem_per_patch,
                                                      lds_per_patch, wave_size, tess_uses_primid);
   unsigned lds_size = lds_per_patch * num_patches;
   unsigned mem_size = mem_per_patch * num_patches;

   /* The first vec4 is reserved for the tf0/1 shader message group vote. */
   if (info->gfx_level >= GFX11)
      lds_size += AC_HS_MSG_VOTE_LDS_BYTES;

   /* SPI_SHADER_PGM_RSRC2_HS.LDS_SIZE specifies the allocation size for both LDS and the HS
    * offchip ring buffer. LDS is only used for TCS inputs (those with cross-invocation or
    * indirect access, or all inputs if the TCS in/out vertex counts differ) and for TCS
    * outputs that are read back (including tess level outputs if they need to be re-read in
    * invocation 0), while the HS ring buffer is only used for TCS outputs consumed by TES.
    */
   unsigned merged_size = MAX2(lds_size, mem_size);
   assert(merged_size <= (info->gfx_level >= GFX9 ? 65536 : 32768));

   *num_patches_per_wg = num_patches;
   *hw_lds_size = DIV_ROUND_UP(merged_size, info->lds_encode_granularity);
}