/*
 * Copyright © 2021 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 */

#include "ac_nir.h"
#include "nir_builder.h"

/*
 * These NIR passes are used to lower NIR cross-stage I/O intrinsics into the
 * memory accesses that actually happen on the HW.
 *
 * Each input and output has a 16-byte (4 dwords) slot reserved for it, and
 * can have up to 4 components. Each component is 32 bits.
 *
 * ## VS-TCS-TES I/O - Terminology:
 *
 * * patch - Group of vertices, used instead of primitives in tessellation
 * * per-vertex - input or output which can be different for every vertex.
 * * per-patch - input or output which applies to a patch (a group of vertices)
 *
 * ## VS-TCS-TES I/O - How it works:
 *
 * ```
 * SW model:    SW VS         SW TCS    tessellator    SW TES
 *                ┊             ┊             ┊          ┊
 *              ┌────┐        ┌────┐        ┌────┐    ┌─────┐
 * HW pipeline: │ LS │─╮   ╭─>│ HS │─╮   ╭─>│ FF │ ╭─>│VS/ES│
 *              └────┘ │   │  └────┘ │   │  └────┘ │  └─────┘
 * Memory:             ╰─>LDS<──╯    ╰─>VRAM───────╯
 * ```
 *
 * * SW VS runs as a HW LS (Local Shader, merged into HS on GFX9+),
 *   and SW TCS runs as HW HS (Hull Shader).
 *   SW TES runs as either HW VS or HW ES (Export Shader).
 * * LS and HS share the same LDS space.
 * * LS (SW VS) stores outputs to LDS to be read by HS (SW TCS).
 * * HS (SW TCS) stores outputs in LDS if the HS (SW TCS) reads them.
 * * HS (SW TCS) stores outputs in VRAM if the next stage (SW TES) reads them.
 *
 * Side note: some old HW supports having TES read from the same LDS space
 * where LS/HS write, but Mesa always stores HS outputs to VRAM to avoid
 * forcing TES waves to run on the same CU as the LS/HS waves.
 *
 * ### Passing VS-TCS I/O in registers
 *
 * On GPUs that run SW VS and SW TCS on the same HW stage (HS on GFX9+),
 * I/O can be passed through registers instead of LDS when the following
 * conditions are met:
 *
 * 1. TCS input and output patch size match
 * 2. Floating point execution modes in SW VS and SW TCS match
 * 3. The SW VS output is not written indirectly, and the corresponding
 *    SW TCS input is not read indirectly
 *
 * Some HS outputs could be passed through registers too, but this is a TODO.
 *
 * ### LDS layout used by VS-TCS:
 *
 * ```
 * TCS per-vertex inputs for patch 0  <─── 0
 * TCS per-vertex inputs for patch 1
 * TCS per-vertex inputs for patch 2  <─── hs_per_vertex_input_lds_offset (rel_patch_id = 2)
 * ...
 * TCS per-vertex outputs for patch 0 <─── output_patch0_offset
 * TCS per-patch outputs for patch 0  <─── output_patch0_patch_data_offset
 * TCS per-vertex outputs for patch 1
 * TCS per-patch outputs for patch 1
 * TCS per-vertex outputs for patch 2 <─── hs_output_lds_offset (rel_patch_id = 2, per-vertex)
 * TCS per-patch outputs for patch 2  <─── hs_output_lds_offset (rel_patch_id = 2, per-patch)
 * ...
 * ```
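 *
 * For reference, this is roughly the address calculation implemented by
 * hs_per_vertex_input_lds_offset below (a sketch in byte units, where
 * io_offset is the 16-byte slot offset of the accessed attribute):
 *
 * ```
 * lds_off = rel_patch_id * (tcs_in_vtxcnt * lshs_vertex_stride)
 *         + vertex_index * lshs_vertex_stride
 *         + io_offset
 * ```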
 *
 * ### VRAM layout used by TCS-TES I/O:
 *
 * ```
 * attr 0 of patch 0 vertex 0   <─── "off-chip LDS" offset
 * attr 0 of patch 0 vertex 1
 * attr 0 of patch 0 vertex 2
 * ...
 * attr 0 of patch 1 vertex 0
 * attr 0 of patch 1 vertex 1
 * attr 0 of patch 1 vertex 2   <─── hs_per_vertex_output_vmem_offset (attribute slot = 0, rel_patch_id = 1, vertex index = 1)
 * ...
 * attr 0 of patch 2 vertex 0
 * attr 0 of patch 2 vertex 1
 * attr 0 of patch 2 vertex 2
 * ...
 * attr 1 of patch 0 vertex 0
 * attr 1 of patch 0 vertex 1
 * attr 1 of patch 0 vertex 2
 * ...
 * ...
 * per-patch attr 0 of patch 0  <─── hs_out_patch_data_offset_amd
 * per-patch attr 0 of patch 1
 * per-patch attr 0 of patch 2  <─── hs_per_patch_output_vmem_offset (attribute slot = 0, rel_patch_id = 2)
 * ...
 * per-patch attr 1 of patch 0
 * per-patch attr 1 of patch 1
 * per-patch attr 1 of patch 2
 * ...
 * ```
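 *
 * For reference, this is roughly the address calculation implemented by
 * hs_per_vertex_output_vmem_offset below (a sketch in byte units; the layout
 * is attribute-major, so one attribute slot spans all patches):
 *
 * ```
 * attr_stride = tcs_num_patches * out_vertices_per_patch * 16
 * vmem_off = attr_slot * attr_stride
 *          + rel_patch_id * (out_vertices_per_patch * 16)
 *          + vertex_index * 16
 * ```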
 *
 */

typedef struct {
   /* Which hardware generation we're dealing with */
   enum amd_gfx_level gfx_level;

   /* I/O semantic -> real location used by lowering. */
   ac_nir_map_io_driver_location map_io;

   /* True if merged VS+TCS (on GFX9+) have the same input and output patch size. */
   bool tcs_in_out_eq;

   /* Bit mask of TCS per-vertex inputs (VS outputs) which
    * are passed between the two stages only in temporaries (registers).
    */
   uint64_t tcs_temp_only_inputs;

   /* Bit mask of TCS outputs read by TES. */
   uint64_t tes_inputs_read;
   uint64_t tes_patch_inputs_read;

   /* Whether TES reads the tess factors. */
   bool tes_reads_tessfactors;

   unsigned tcs_num_reserved_outputs;
   unsigned tcs_num_reserved_patch_outputs;

   /* Location (slot) where tessellation levels are stored. */
   unsigned tcs_tess_lvl_in_loc;
   unsigned tcs_tess_lvl_out_loc;

   /* True if the output patch fits the subgroup, so all TCS outputs are
    * always written in the same subgroup that reads them.
    */
   bool tcs_out_patch_fits_subgroup;

   /* Set if all invocations will write to all tess factors, so tess factors
    * can be passed by register.
    */
   bool tcs_pass_tessfactors_by_reg;

   /* Whether all TCS inputs are accessed using gl_InvocationID and passed via VGPRs.
    * In that case, no LDS is allocated for TCS inputs.
    */
   bool tcs_no_inputs_in_lds;
} lower_tess_io_state;

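/* Check whether the slot accessed by an I/O intrinsic is present in the given
 * bit mask. TCS patch output slots are counted relative to VARYING_SLOT_PATCH0.
 * Indirectly indexed I/O matches if and only if match_indirect is set.
 */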
static bool
match_mask(gl_shader_stage stage,
           nir_intrinsic_instr *intrin,
           uint64_t mask,
           bool match_indirect)
{
   bool indirect = !nir_src_is_const(*nir_get_io_offset_src(intrin));
   if (indirect)
      return match_indirect;

   uint64_t slot = nir_intrinsic_io_semantics(intrin).location;
   if (stage == MESA_SHADER_TESS_CTRL &&
       intrin->intrinsic != nir_intrinsic_load_per_vertex_input &&
       intrin->intrinsic != nir_intrinsic_store_per_vertex_output)
      slot -= VARYING_SLOT_PATCH0;

   return (UINT64_C(1) << slot) & mask;
}

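/* A TCS output needs a VMEM (off-chip ring) store iff the corresponding
 * TES input is read. Indirect accesses conservatively match.
 */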
static bool
tcs_output_needs_vmem(nir_intrinsic_instr *intrin,
                      lower_tess_io_state *st)
{
   uint64_t mask = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
                   ? st->tes_inputs_read
                   : st->tes_patch_inputs_read;

   return match_mask(MESA_SHADER_TESS_CTRL, intrin, mask, true);
}

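/* A TCS output needs an LDS store iff the TCS itself reads it back. */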
static bool
tcs_output_needs_lds(nir_intrinsic_instr *intrin,
                     nir_shader *shader)
{
   uint64_t mask = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
                   ? shader->info.outputs_read
                   : shader->info.patch_outputs_read;

   return match_mask(MESA_SHADER_TESS_CTRL, intrin, mask, true);
}

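/* Lower an LS (SW VS) output store to an LDS store at
 * local_invocation_index * lshs_vertex_stride + io_offset,
 * so that the HS (SW TCS) can read it.
 */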
static bool
lower_ls_output_store(nir_builder *b,
                      nir_instr *instr,
                      void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic != nir_intrinsic_store_output)
      return false;

   /* The ARB_shader_viewport_layer_array spec contains the
    * following issue:
    *
    *    2) What happens if gl_ViewportIndex or gl_Layer is
    *    written in the vertex shader and a geometry shader is
    *    present?
    *
    *    RESOLVED: The value written by the last vertex processing
    *    stage is used. If the last vertex processing stage
    *    (vertex, tessellation evaluation or geometry) does not
    *    statically assign to gl_ViewportIndex or gl_Layer, index
    *    or layer zero is assumed.
    *
    * So writes to those outputs in VS-as-LS are simply ignored.
    */
   unsigned semantic = nir_intrinsic_io_semantics(intrin).location;
   if (semantic == VARYING_SLOT_LAYER || semantic == VARYING_SLOT_VIEWPORT) {
      nir_instr_remove(instr);
      return true;
   }

   lower_tess_io_state *st = (lower_tess_io_state *) state;

   /* If this is a temp-only TCS input, we don't need to use shared memory at all. */
   if (match_mask(MESA_SHADER_VERTEX, intrin, st->tcs_temp_only_inputs, false))
      return false;

   b->cursor = nir_before_instr(instr);

   nir_ssa_def *vertex_idx = nir_load_local_invocation_index(b);
   nir_ssa_def *base_off_var = nir_imul(b, vertex_idx, nir_load_lshs_vertex_stride_amd(b));

   nir_ssa_def *io_off = ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io);
   unsigned write_mask = nir_intrinsic_write_mask(intrin);

   nir_ssa_def *off = nir_iadd_nuw(b, base_off_var, io_off);
   nir_store_shared(b, intrin->src[0].ssa, off, .write_mask = write_mask,
                    .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);

   /* NOTE: don't remove the store_output intrinsic on GFX9+ when tcs_in_out_eq,
    * it will be used by same-invocation TCS input loads.
    */
   if (!st->tcs_in_out_eq)
      nir_instr_remove(instr);

   return true;
}

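/* Select the TCS per-vertex input loads that actually need to read LDS.
 * With tcs_in_out_eq, same-invocation loads with constant offsets are
 * served from temporaries instead.
 */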
static bool
filter_load_tcs_per_vertex_input(const nir_instr *instr,
                                 const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic != nir_intrinsic_load_per_vertex_input)
      return false;
   if (!st->tcs_in_out_eq)
      return true;

   /* tcs_in_out_eq: a same-invocation input load, without indirect offset,
    * can use temporaries, no need to use shared memory.
    */
   nir_src *off_src = nir_get_io_offset_src(intrin);
   nir_src *vertex_index_src = nir_get_io_arrayed_index_src(intrin);
   nir_instr *vertex_index_instr = vertex_index_src->ssa->parent_instr;

   bool can_use_temps = nir_src_is_const(*off_src) &&
                        vertex_index_instr->type == nir_instr_type_intrinsic &&
                        nir_instr_as_intrinsic(vertex_index_instr)->intrinsic == nir_intrinsic_load_invocation_id;

   return !can_use_temps;
}

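/* LDS offset of a TCS per-vertex input:
 * rel_patch_id * input patch stride + vertex_index * vertex stride + io offset.
 */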
static nir_ssa_def *
hs_per_vertex_input_lds_offset(nir_builder *b,
                               lower_tess_io_state *st,
                               nir_intrinsic_instr *instr)
{
   nir_ssa_def *tcs_in_vtxcnt = nir_load_patch_vertices_in(b);
   nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *vertex_index = nir_get_io_arrayed_index_src(instr)->ssa;

   nir_ssa_def *stride = nir_load_lshs_vertex_stride_amd(b);
   nir_ssa_def *tcs_in_patch_stride = nir_imul(b, tcs_in_vtxcnt, stride);
   nir_ssa_def *vertex_index_off = nir_imul(b, vertex_index, stride);

   nir_ssa_def *tcs_in_current_patch_offset = nir_imul(b, rel_patch_id, tcs_in_patch_stride);

   nir_ssa_def *io_offset = ac_nir_calc_io_offset(b, instr, nir_imm_int(b, 16u), 4u, st->map_io);

   return nir_iadd_nuw(b, nir_iadd_nuw(b, tcs_in_current_patch_offset, vertex_index_off), io_offset);
}

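/* LDS offset of a TCS output. Per-patch outputs are placed after the
 * per-vertex outputs of each patch. When intrin is NULL, this returns the
 * base address of the current patch's per-patch output area.
 */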
static nir_ssa_def *
hs_output_lds_offset(nir_builder *b,
                     lower_tess_io_state *st,
                     nir_intrinsic_instr *intrin)
{
   bool per_vertex = intrin &&
                     (intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
                      intrin->intrinsic == nir_intrinsic_load_per_vertex_output);

   unsigned output_vertex_size = st->tcs_num_reserved_outputs * 16u;
   unsigned pervertex_output_patch_size = b->shader->info.tess.tcs_vertices_out * output_vertex_size;
   unsigned output_patch_stride = pervertex_output_patch_size + st->tcs_num_reserved_patch_outputs * 16u;

   nir_ssa_def *off = intrin
                    ? ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io)
                    : nir_imm_int(b, 0);

   nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *patch_offset = nir_imul_imm(b, rel_patch_id, output_patch_stride);

   nir_ssa_def *output_patch_offset;
   if (st->tcs_no_inputs_in_lds)
      output_patch_offset = patch_offset;
   else {
      nir_ssa_def *tcs_in_vtxcnt = nir_load_patch_vertices_in(b);
      nir_ssa_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
      nir_ssa_def *input_patch_size =
         nir_imul(b, tcs_in_vtxcnt, nir_load_lshs_vertex_stride_amd(b));
      nir_ssa_def *output_patch0_offset = nir_imul(b, input_patch_size, tcs_num_patches);
      output_patch_offset = nir_iadd_nuw(b, patch_offset, output_patch0_offset);
   }

   if (per_vertex) {
      nir_ssa_def *vertex_index = nir_ssa_for_src(b, *nir_get_io_arrayed_index_src(intrin), 1);
      nir_ssa_def *vertex_index_off = nir_imul_imm(b, vertex_index, output_vertex_size);

      off = nir_iadd_nuw(b, off, vertex_index_off);
      return nir_iadd_nuw(b, off, output_patch_offset);
   } else {
      off = nir_iadd_imm_nuw(b, off, pervertex_output_patch_size);
      return nir_iadd_nuw(b, off, output_patch_offset);
   }
}

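/* Off-chip ring (VMEM) offset of a TCS per-vertex output. Also used to
 * address the matching TES per-vertex input loads.
 */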
static nir_ssa_def *
hs_per_vertex_output_vmem_offset(nir_builder *b,
                                 lower_tess_io_state *st,
                                 nir_intrinsic_instr *intrin)
{
   nir_ssa_def *out_vertices_per_patch = b->shader->info.stage == MESA_SHADER_TESS_CTRL
                                         ? nir_imm_int(b, b->shader->info.tess.tcs_vertices_out)
                                         : nir_load_patch_vertices_in(b);

   nir_ssa_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
   nir_ssa_def *attr_stride = nir_imul(b, tcs_num_patches, nir_imul_imm(b, out_vertices_per_patch, 16u));
   nir_ssa_def *io_offset = ac_nir_calc_io_offset(b, intrin, attr_stride, 4u, st->map_io);

   nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *patch_offset = nir_imul(b, rel_patch_id, nir_imul_imm(b, out_vertices_per_patch, 16u));

   nir_ssa_def *vertex_index = nir_ssa_for_src(b, *nir_get_io_arrayed_index_src(intrin), 1);
   nir_ssa_def *vertex_index_off = nir_imul_imm(b, vertex_index, 16u);

   return nir_iadd_nuw(b, nir_iadd_nuw(b, patch_offset, vertex_index_off), io_offset);
}

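/* Off-chip ring (VMEM) offset of a TCS per-patch output. const_base_offset is
 * a byte offset within a single patch's data; it is scaled by the number of
 * patches because the layout is attribute-major.
 */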
static nir_ssa_def *
hs_per_patch_output_vmem_offset(nir_builder *b,
                                lower_tess_io_state *st,
                                nir_intrinsic_instr *intrin,
                                unsigned const_base_offset)
{
   nir_ssa_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
   nir_ssa_def *per_patch_data_offset = nir_load_hs_out_patch_data_offset_amd(b);

   nir_ssa_def *off = intrin
                    ? ac_nir_calc_io_offset(b, intrin, nir_imul_imm(b, tcs_num_patches, 16u), 4u, st->map_io)
                    : nir_imm_int(b, 0);

   if (const_base_offset)
      off = nir_iadd_nuw(b, off, nir_imul_imm(b, tcs_num_patches, const_base_offset));

   nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *patch_offset = nir_imul_imm(b, rel_patch_id, 16u);
   off = nir_iadd_nuw(b, off, per_patch_data_offset);
   return nir_iadd_nuw(b, off, patch_offset);
}

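/* Replace a TCS per-vertex input load with a load from LDS. */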
static nir_ssa_def *
lower_hs_per_vertex_input_load(nir_builder *b,
                               nir_instr *instr,
                               void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   nir_ssa_def *off = hs_per_vertex_input_lds_offset(b, st, intrin);
   return nir_load_shared(b, intrin->dest.ssa.num_components, intrin->dest.ssa.bit_size, off,
                          .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);
}

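/* Lower a TCS output store to a VMEM store, an LDS store, or both,
 * depending on which later stage reads the output.
 */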
static nir_ssa_def *
lower_hs_output_store(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      lower_tess_io_state *st)
{
   assert(intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_store_output);

   nir_io_semantics semantics = nir_intrinsic_io_semantics(intrin);
   nir_ssa_def *store_val = intrin->src[0].ssa;
   unsigned write_mask = nir_intrinsic_write_mask(intrin);
   bool is_tess_factor = semantics.location == VARYING_SLOT_TESS_LEVEL_INNER ||
                         semantics.location == VARYING_SLOT_TESS_LEVEL_OUTER;
   bool write_to_vmem = !is_tess_factor && tcs_output_needs_vmem(intrin, st);
   bool write_to_lds = (is_tess_factor && !st->tcs_pass_tessfactors_by_reg) ||
      tcs_output_needs_lds(intrin, b->shader);

   if (write_to_vmem) {
      nir_ssa_def *vmem_off = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
                            ? hs_per_vertex_output_vmem_offset(b, st, intrin)
                            : hs_per_patch_output_vmem_offset(b, st, intrin, 0);

      nir_ssa_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
      nir_ssa_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
      nir_store_buffer_amd(b, store_val, hs_ring_tess_offchip, vmem_off, offchip_offset,
                           .write_mask = write_mask, .memory_modes = nir_var_shader_out);
   }

   if (write_to_lds) {
      /* Remember the driver location of the tess factors, so we can read them later. */
      if (semantics.location == VARYING_SLOT_TESS_LEVEL_INNER)
         st->tcs_tess_lvl_in_loc = nir_intrinsic_base(intrin) * 16u;
      else if (semantics.location == VARYING_SLOT_TESS_LEVEL_OUTER)
         st->tcs_tess_lvl_out_loc = nir_intrinsic_base(intrin) * 16u;

      nir_ssa_def *lds_off = hs_output_lds_offset(b, st, intrin);
      nir_store_shared(b, store_val, lds_off, .write_mask = write_mask,
                       .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);
   }

   /* Keep the tess factor nir_store_output instruction if it's going to be
    * passed by reg instead of LDS, because the radeonsi LLVM backend uses it
    * to generate an LLVM variable which is read by the final tess factor
    * write epilog.
    */
   return is_tess_factor && st->tcs_pass_tessfactors_by_reg ?
      NIR_LOWER_INSTR_PROGRESS : NIR_LOWER_INSTR_PROGRESS_REPLACE;
}

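/* Replace a TCS output load with a load from LDS. */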
static nir_ssa_def *
lower_hs_output_load(nir_builder *b,
                     nir_intrinsic_instr *intrin,
                     lower_tess_io_state *st)
{
   nir_ssa_def *off = hs_output_lds_offset(b, st, intrin);
   return nir_load_shared(b, intrin->dest.ssa.num_components, intrin->dest.ssa.bit_size, off,
                          .align_mul = 16u, .align_offset = (nir_intrinsic_component(intrin) * 4u) % 16u);
}

static void
update_hs_scoped_barrier(nir_intrinsic_instr *intrin, lower_tess_io_state *st)
{
   /* Output loads and stores are lowered to shared memory access,
    * so we have to update the barriers to also reflect this.
    */
   unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
   if (mem_modes & nir_var_shader_out)
      mem_modes |= nir_var_mem_shared;
   nir_intrinsic_set_memory_modes(intrin, mem_modes);

   nir_scope exec_scope = nir_intrinsic_execution_scope(intrin);
   if (exec_scope == NIR_SCOPE_WORKGROUP && st->tcs_out_patch_fits_subgroup)
      nir_intrinsic_set_execution_scope(intrin, NIR_SCOPE_SUBGROUP);

   nir_scope mem_scope = nir_intrinsic_memory_scope(intrin);
   if (mem_scope == NIR_SCOPE_WORKGROUP && st->tcs_out_patch_fits_subgroup)
      nir_intrinsic_set_memory_scope(intrin, NIR_SCOPE_SUBGROUP);
}

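/* Dispatch TCS output intrinsics: stores and loads are lowered to memory
 * accesses, and barriers are updated to also cover LDS.
 */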
static nir_ssa_def *
lower_hs_output_access(nir_builder *b,
                       nir_instr *instr,
                       void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic == nir_intrinsic_store_output ||
       intrin->intrinsic == nir_intrinsic_store_per_vertex_output) {
      return lower_hs_output_store(b, intrin, st);
   } else if (intrin->intrinsic == nir_intrinsic_load_output ||
              intrin->intrinsic == nir_intrinsic_load_per_vertex_output) {
      return lower_hs_output_load(b, intrin, st);
   } else if (intrin->intrinsic == nir_intrinsic_scoped_barrier) {
      update_hs_scoped_barrier(intrin, st);
      return NIR_LOWER_INSTR_PROGRESS;
   } else {
      unreachable("intrinsic not supported by lower_hs_output_access");
   }
}

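/* Emit code at the end of the TCS that reads the tess factors from LDS and
 * writes them to the tess factor ring buffer, and to the off-chip ring when
 * TES reads them. On GFX8 and older, the first patch also writes the dynamic
 * HS control word.
 */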
static void
hs_emit_write_tess_factors(nir_shader *shader,
                           lower_tess_io_state *st)
{
   unsigned outer_comps;
   unsigned inner_comps;

   switch (shader->info.tess._primitive_mode) {
   case TESS_PRIMITIVE_ISOLINES:
      outer_comps = 2;
      inner_comps = 0;
      break;
   case TESS_PRIMITIVE_TRIANGLES:
      outer_comps = 3;
      inner_comps = 1;
      break;
   case TESS_PRIMITIVE_QUADS:
      outer_comps = 4;
      inner_comps = 2;
      break;
   default:
      unreachable("invalid primitive mode");
      return;
   }

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);
   nir_block *last_block = nir_impl_last_block(impl);
   assert(last_block);

   /* We assume there is always a single end block in the shader. */

   nir_builder builder;
   nir_builder *b = &builder; /* This is to avoid the & */
   nir_builder_init(b, impl);
   b->cursor = nir_after_block(last_block);

   nir_scope scope =
      st->tcs_out_patch_fits_subgroup ? NIR_SCOPE_SUBGROUP : NIR_SCOPE_WORKGROUP;
   nir_scoped_barrier(b, .execution_scope = scope, .memory_scope = scope,
                      .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);

   nir_ssa_def *invocation_id = nir_load_invocation_id(b);

   /* Only the 1st invocation of each patch needs to do this. */
   nir_if *invocation_id_zero = nir_push_if(b, nir_ieq_imm(b, invocation_id, 0));

   /* The descriptor where tess factors have to be stored by the shader. */
   nir_ssa_def *tessfactor_ring = nir_load_ring_tess_factors_amd(b);

   /* Base LDS address of per-patch outputs in the current patch. */
   nir_ssa_def *lds_base = hs_output_lds_offset(b, st, NULL);

   /* Load all tessellation factors (aka. tess levels) from LDS. */
   nir_ssa_def *tessfactors_outer = nir_load_shared(b, outer_comps, 32, lds_base, .base = st->tcs_tess_lvl_out_loc,
                                                    .align_mul = 16u, .align_offset = st->tcs_tess_lvl_out_loc % 16u);
   nir_ssa_def *tessfactors_inner = inner_comps
                                    ? nir_load_shared(b, inner_comps, 32, lds_base, .base = st->tcs_tess_lvl_in_loc,
                                                      .align_mul = 16u, .align_offset = st->tcs_tess_lvl_in_loc % 16u)
                                    : NULL;

   nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
   nir_ssa_def *tess_factors_base = nir_load_ring_tess_factors_offset_amd(b);
   nir_ssa_def *tess_factors_offset = nir_imul_imm(b, rel_patch_id, (inner_comps + outer_comps) * 4u);
   unsigned tess_factors_const_offset = 0;

   if (st->gfx_level <= GFX8) {
      /* Store the dynamic HS control word. */
      nir_if *rel_patch_id_zero = nir_push_if(b, nir_ieq_imm(b, rel_patch_id, 0));
      nir_ssa_def *ctrlw = nir_imm_int(b, 0x80000000u);
      nir_store_buffer_amd(b, ctrlw, tessfactor_ring, nir_imm_zero(b, 1, 32), tess_factors_base);
      tess_factors_const_offset += 4;
      nir_pop_if(b, rel_patch_id_zero);
   }

   /* Store tess factors for the tessellator */
   if (shader->info.tess._primitive_mode == TESS_PRIMITIVE_ISOLINES) {
      /* LINES reversal */
      nir_ssa_def *t = nir_vec2(b, nir_channel(b, tessfactors_outer, 1), nir_channel(b, tessfactors_outer, 0));
      nir_store_buffer_amd(b, t, tessfactor_ring, tess_factors_offset, tess_factors_base, .base = tess_factors_const_offset);
   } else if (shader->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES) {
      nir_ssa_def *t = nir_vec4(b, nir_channel(b, tessfactors_outer, 0), nir_channel(b, tessfactors_outer, 1),
                                nir_channel(b, tessfactors_outer, 2), nir_channel(b, tessfactors_inner, 0));
      nir_store_buffer_amd(b, t, tessfactor_ring, tess_factors_offset, tess_factors_base, .base = tess_factors_const_offset);
   } else {
      nir_store_buffer_amd(b, tessfactors_outer, tessfactor_ring, tess_factors_offset, tess_factors_base, .base = tess_factors_const_offset);
      nir_store_buffer_amd(b, tessfactors_inner, tessfactor_ring, tess_factors_offset, tess_factors_base, .base = tess_factors_const_offset + 4u * outer_comps);
   }

   if (st->tes_reads_tessfactors) {
      /* Store to offchip for TES to read - only if TES actually reads them */
      nir_ssa_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
      nir_ssa_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);

      nir_ssa_def *vmem_off_outer = hs_per_patch_output_vmem_offset(b, st, NULL, st->tcs_tess_lvl_out_loc);
      nir_store_buffer_amd(b, tessfactors_outer, hs_ring_tess_offchip, vmem_off_outer, offchip_offset, .memory_modes = nir_var_shader_out);

      if (inner_comps) {
         nir_ssa_def *vmem_off_inner = hs_per_patch_output_vmem_offset(b, st, NULL, st->tcs_tess_lvl_in_loc);
         nir_store_buffer_amd(b, tessfactors_inner, hs_ring_tess_offchip, vmem_off_inner, offchip_offset, .memory_modes = nir_var_shader_out);
      }
   }

   nir_pop_if(b, invocation_id_zero);

   nir_metadata_preserve(impl, nir_metadata_none);
}

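/* Replace a TES input load with a load from the off-chip ring, using the
 * same addressing as the corresponding HS output stores.
 */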
static nir_ssa_def *
lower_tes_input_load(nir_builder *b,
                     nir_instr *instr,
                     void *state)
{
   lower_tess_io_state *st = (lower_tess_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   nir_ssa_def *offchip_ring = nir_load_ring_tess_offchip_amd(b);
   nir_ssa_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
   nir_ssa_def *off = intrin->intrinsic == nir_intrinsic_load_per_vertex_input
                    ? hs_per_vertex_output_vmem_offset(b, st, intrin)
                    : hs_per_patch_output_vmem_offset(b, st, intrin, 0);

   return nir_load_buffer_amd(b, intrin->dest.ssa.num_components, intrin->dest.ssa.bit_size, offchip_ring, off, offchip_offset);
}

static bool
filter_hs_output_access(const nir_instr *instr,
                        UNUSED const void *st)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_store_output ||
          intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_load_output ||
          intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
          intrin->intrinsic == nir_intrinsic_scoped_barrier;
}

static bool
filter_any_input_access(const nir_instr *instr,
                        UNUSED const void *st)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_load_input ||
          intrin->intrinsic == nir_intrinsic_load_per_vertex_input;
}

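/* Lower LS (SW VS) output stores to LDS stores, so the HS (SW TCS)
 * can read them.
 */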
void
ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
                               ac_nir_map_io_driver_location map,
                               bool tcs_in_out_eq,
                               uint64_t tcs_temp_only_inputs)
{
   assert(shader->info.stage == MESA_SHADER_VERTEX);

   lower_tess_io_state state = {
      .tcs_in_out_eq = tcs_in_out_eq,
      .tcs_temp_only_inputs = tcs_in_out_eq ? tcs_temp_only_inputs : 0,
      .map_io = map,
   };

   nir_shader_instructions_pass(shader,
                                lower_ls_output_store,
                                nir_metadata_block_index | nir_metadata_dominance,
                                &state);
}

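/* Lower HS (SW TCS) per-vertex input loads to LDS loads. */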
void
ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
                              ac_nir_map_io_driver_location map,
                              bool tcs_in_out_eq)
{
   assert(shader->info.stage == MESA_SHADER_TESS_CTRL);

   lower_tess_io_state state = {
      .tcs_in_out_eq = tcs_in_out_eq,
      .map_io = map,
   };

   nir_shader_lower_instructions(shader,
                                 filter_load_tcs_per_vertex_input,
                                 lower_hs_per_vertex_input_load,
                                 &state);
}

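/* Lower HS (SW TCS) output accesses to LDS and/or off-chip ring (VMEM)
 * accesses, and optionally emit the tess factor writes at the end of the
 * shader.
 */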
void
ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
                               ac_nir_map_io_driver_location map,
                               enum amd_gfx_level gfx_level,
                               bool tes_reads_tessfactors,
                               uint64_t tes_inputs_read,
                               uint64_t tes_patch_inputs_read,
                               unsigned num_reserved_tcs_outputs,
                               unsigned num_reserved_tcs_patch_outputs,
                               unsigned wave_size,
                               bool no_inputs_in_lds,
                               bool pass_tessfactors_by_reg,
                               bool emit_tess_factor_write)
{
   assert(shader->info.stage == MESA_SHADER_TESS_CTRL);

   lower_tess_io_state state = {
      .gfx_level = gfx_level,
      .tes_reads_tessfactors = tes_reads_tessfactors,
      .tes_inputs_read = tes_inputs_read,
      .tes_patch_inputs_read = tes_patch_inputs_read,
      .tcs_num_reserved_outputs = num_reserved_tcs_outputs,
      .tcs_num_reserved_patch_outputs = num_reserved_tcs_patch_outputs,
      .tcs_out_patch_fits_subgroup = wave_size % shader->info.tess.tcs_vertices_out == 0,
      .tcs_pass_tessfactors_by_reg = pass_tessfactors_by_reg,
      .tcs_no_inputs_in_lds = no_inputs_in_lds,
      .map_io = map,
   };

   nir_shader_lower_instructions(shader,
                                 filter_hs_output_access,
                                 lower_hs_output_access,
                                 &state);

   if (emit_tess_factor_write)
      hs_emit_write_tess_factors(shader, &state);
}

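/* Lower TES input loads to off-chip ring (VMEM) loads. */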
void
ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
                               ac_nir_map_io_driver_location map)
{
   assert(shader->info.stage == MESA_SHADER_TESS_EVAL);

   lower_tess_io_state state = {
      .map_io = map,
   };

   nir_shader_lower_instructions(shader,
                                 filter_any_input_access,
                                 lower_tes_input_load,
                                 &state);
}