/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "nir_builder.h"

/*
 * Lower NIR cross-stage I/O intrinsics into the memory accesses that actually happen on the HW.
 *
 * These HW stages are used only when a Geometry Shader is present.
 * The Export Shader (ES) is the HW stage of the SW stage that runs before the GS;
 * it can be either a VS or a TES.
 *
 * * GFX6-8:
 *   ES and GS are separate HW stages.
 *   I/O is passed between them through VRAM.
 * * GFX9+:
 *   ES and GS are merged into a single HW stage.
 *   I/O is passed between them through LDS.
 */

typedef struct {
   /* Which hardware generation we're dealing with. */
   enum amd_gfx_level gfx_level;

   /* I/O semantic -> real location used by lowering. */
   ac_nir_map_io_driver_location map_io;

   /* Stride of one ES invocation's outputs in the ESGS ring, in bytes. */
   unsigned esgs_itemsize;

   /* Enable the fix for triangle strip adjacency in the geometry shader. */
   bool gs_triangle_strip_adjacency_fix;
} lower_esgs_io_state;

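/* Load "num_components" components of "bit_size" bits from a buffer descriptor.
 * The load is split into 32-bit loads (each "component_stride" bytes apart) plus
 * at most one smaller load for the trailing bytes, and the resulting bits are
 * recombined into a vector of the requested size.
 */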
static nir_def *
emit_split_buffer_load(nir_builder *b, nir_def *desc, nir_def *v_off, nir_def *s_off,
                       unsigned component_stride, unsigned num_components, unsigned bit_size)
{
   unsigned total_bytes = num_components * bit_size / 8u;
   unsigned full_dwords = total_bytes / 4u;
   unsigned remaining_bytes = total_bytes - full_dwords * 4u;

   /* Accommodate max number of split 64-bit loads */
   nir_def *comps[NIR_MAX_VEC_COMPONENTS * 2u];

   /* Assume that 1x32-bit load is better than 1x16-bit + 1x8-bit */
   if (remaining_bytes == 3) {
      remaining_bytes = 0;
      full_dwords++;
   }

   nir_def *zero = nir_imm_int(b, 0);

   for (unsigned i = 0; i < full_dwords; ++i)
      comps[i] = nir_load_buffer_amd(b, 1, 32, desc, v_off, s_off, zero,
                                     .base = component_stride * i, .memory_modes = nir_var_shader_in,
                                     .access = ACCESS_COHERENT);

   if (remaining_bytes)
      comps[full_dwords] = nir_load_buffer_amd(b, 1, remaining_bytes * 8, desc, v_off, s_off, zero,
                                               .base = component_stride * full_dwords,
                                               .memory_modes = nir_var_shader_in,
                                               .access = ACCESS_COHERENT);

   return nir_extract_bits(b, comps, full_dwords + !!remaining_bytes, 0, num_components, bit_size);
}

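/* Store a value to a buffer descriptor according to "writemask".
 * Each consecutive range of enabled components is further split into 1-, 2- or
 * 4-byte stores so that no single store crosses a 32-bit (dword) boundary.
 */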
static void
emit_split_buffer_store(nir_builder *b, nir_def *d, nir_def *desc, nir_def *v_off, nir_def *s_off,
                        unsigned component_stride, unsigned num_components, unsigned bit_size,
                        unsigned writemask, bool swizzled, bool slc)
{
   nir_def *zero = nir_imm_int(b, 0);

   while (writemask) {
      int start, count;
      u_bit_scan_consecutive_range(&writemask, &start, &count);
      assert(start >= 0 && count >= 0);

      unsigned bytes = count * bit_size / 8u;
      unsigned start_byte = start * bit_size / 8u;

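      /* Pick a store size (1, 2 or 4 bytes) that keeps each store within the current dword. */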
      while (bytes) {
         unsigned store_bytes = MIN2(bytes, 4u);
         if ((start_byte % 4) == 1 || (start_byte % 4) == 3)
            store_bytes = MIN2(store_bytes, 1);
         else if ((start_byte % 4) == 2)
            store_bytes = MIN2(store_bytes, 2);

         nir_def *store_val = nir_extract_bits(b, &d, 1, start_byte * 8u, 1, store_bytes * 8u);
         nir_store_buffer_amd(b, store_val, desc, v_off, s_off, zero,
                              .base = start_byte, .memory_modes = nir_var_shader_out,
                              .access = ACCESS_COHERENT |
                                        (slc ? ACCESS_NON_TEMPORAL : 0) |
                                        (swizzled ? ACCESS_IS_SWIZZLED_AMD : 0));

         start_byte += store_bytes;
         bytes -= store_bytes;
      }
   }
}

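/* Lower an ES output store:
 * - GFX6-8: store to the ESGS ring buffer in VRAM.
 * - GFX9+: store to the current invocation's slot in LDS.
 * Stores to gl_Layer and gl_ViewportIndex are removed entirely (see the spec
 * quotes below).
 */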
static bool
lower_es_output_store(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      void *state)
{
   if (intrin->intrinsic != nir_intrinsic_store_output)
      return false;

   /* The ARB_shader_viewport_layer_array spec contains the
    * following issue:
    *
    *    2) What happens if gl_ViewportIndex or gl_Layer is
    *    written in the vertex shader and a geometry shader is
    *    present?
    *
    *    RESOLVED: The value written by the last vertex processing
    *    stage is used. If the last vertex processing stage
    *    (vertex, tessellation evaluation or geometry) does not
    *    statically assign to gl_ViewportIndex or gl_Layer, index
    *    or layer zero is assumed.
    *
    * Vulkan spec 15.7 Built-In Variables:
    *
    *    The last active pre-rasterization shader stage (in pipeline order)
    *    controls the Layer that is used. Outputs in previous shader stages
    *    are not used, even if the last stage fails to write the Layer.
    *
    *    The last active pre-rasterization shader stage (in pipeline order)
    *    controls the ViewportIndex that is used. Outputs in previous shader
    *    stages are not used, even if the last stage fails to write the
    *    ViewportIndex.
    *
    * So writes to those outputs in ES are simply ignored.
    */
   unsigned semantic = nir_intrinsic_io_semantics(intrin).location;
   if (semantic == VARYING_SLOT_LAYER || semantic == VARYING_SLOT_VIEWPORT) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   lower_esgs_io_state *st = (lower_esgs_io_state *) state;
   unsigned write_mask = nir_intrinsic_write_mask(intrin);

   b->cursor = nir_before_instr(&intrin->instr);
   nir_def *io_off = ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io);

   if (st->gfx_level <= GFX8) {
      /* GFX6-8: ES is a separate HW stage, data is passed from ES to GS in VRAM. */
      nir_def *ring = nir_load_ring_esgs_amd(b);
      nir_def *es2gs_off = nir_load_ring_es2gs_offset_amd(b);
      emit_split_buffer_store(b, intrin->src[0].ssa, ring, io_off, es2gs_off, 4u,
                              intrin->src[0].ssa->num_components, intrin->src[0].ssa->bit_size,
                              write_mask, true, true);
   } else {
      /* GFX9+: ES is merged into GS, data is passed through LDS. */
      nir_def *vertex_idx = nir_load_local_invocation_index(b);
      nir_def *off = nir_iadd(b, nir_imul_imm(b, vertex_idx, st->esgs_itemsize), io_off);
      nir_store_shared(b, intrin->src[0].ssa, off, .write_mask = write_mask);
   }

   nir_instr_remove(&intrin->instr);
   return true;
}

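/* Return the ES vertex offset for the given GS input vertex index.
 * When the triangle strip adjacency fix is enabled, odd primitives read the
 * offset from a rotated index instead.
 */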
static nir_def *
gs_get_vertex_offset(nir_builder *b, lower_esgs_io_state *st, unsigned vertex_index)
{
   nir_def *origin = nir_load_gs_vertex_offset_amd(b, .base = vertex_index);
   if (!st->gs_triangle_strip_adjacency_fix)
      return origin;

   unsigned fixed_index;
   if (st->gfx_level < GFX9) {
      /* Rotate the vertex index by 2. */
      fixed_index = (vertex_index + 4) % 6;
   } else {
      /* This issue has been fixed for GFX10+. */
      assert(st->gfx_level == GFX9);
      /* The 6 vertex offsets are packed into 3 VGPRs on GFX9+. */
      fixed_index = (vertex_index + 2) % 3;
   }
   nir_def *fixed = nir_load_gs_vertex_offset_amd(b, .base = fixed_index);

   nir_def *prim_id = nir_load_primitive_id(b);
   /* Odd primitive IDs use the fixed offset. */
   nir_def *cond = nir_i2b(b, nir_iand_imm(b, prim_id, 1));
   return nir_bcsel(b, cond, fixed, origin);
}

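/* GFX6-8: select the vertex offset for a (possibly non-constant) vertex index.
 * Non-constant indices are handled with a chain of bcsel instructions.
 */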
static nir_def *
gs_per_vertex_input_vertex_offset_gfx6(nir_builder *b, lower_esgs_io_state *st,
                                       nir_src *vertex_src)
{
   if (nir_src_is_const(*vertex_src))
      return gs_get_vertex_offset(b, st, nir_src_as_uint(*vertex_src));

   nir_def *vertex_offset = gs_get_vertex_offset(b, st, 0);

   for (unsigned i = 1; i < b->shader->info.gs.vertices_in; ++i) {
      nir_def *cond = nir_ieq_imm(b, vertex_src->ssa, i);
      nir_def *elem = gs_get_vertex_offset(b, st, i);
      vertex_offset = nir_bcsel(b, cond, elem, vertex_offset);
   }

   return vertex_offset;
}

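/* GFX9+: two 16-bit vertex offsets are packed into each 32-bit value,
 * so extract the 16-bit half that belongs to the requested vertex.
 */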
static nir_def *
gs_per_vertex_input_vertex_offset_gfx9(nir_builder *b, lower_esgs_io_state *st,
                                       nir_src *vertex_src)
{
   if (nir_src_is_const(*vertex_src)) {
      unsigned vertex = nir_src_as_uint(*vertex_src);
      return nir_ubfe_imm(b, gs_get_vertex_offset(b, st, vertex / 2u),
                          (vertex & 1u) * 16u, 16u);
   }

   nir_def *vertex_offset = gs_get_vertex_offset(b, st, 0);

   for (unsigned i = 1; i < b->shader->info.gs.vertices_in; i++) {
      nir_def *cond = nir_ieq_imm(b, vertex_src->ssa, i);
      nir_def *elem = gs_get_vertex_offset(b, st, i / 2u * 2u);
      if (i % 2u)
         elem = nir_ishr_imm(b, elem, 16u);

      vertex_offset = nir_bcsel(b, cond, elem, vertex_offset);
   }

   return nir_iand_imm(b, vertex_offset, 0xffffu);
}

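/* Compute the byte offset of a per-vertex GS input:
 * the selected vertex offset (scaled by the ESGS vertex stride on GFX9+)
 * plus the offset of the addressed I/O slot.
 */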
static nir_def *
gs_per_vertex_input_offset(nir_builder *b,
                           lower_esgs_io_state *st,
                           nir_intrinsic_instr *instr)
{
   nir_src *vertex_src = nir_get_io_arrayed_index_src(instr);
   nir_def *vertex_offset = st->gfx_level >= GFX9
      ? gs_per_vertex_input_vertex_offset_gfx9(b, st, vertex_src)
      : gs_per_vertex_input_vertex_offset_gfx6(b, st, vertex_src);

   /* GFX6-8 can't emulate VGT_ESGS_RING_ITEMSIZE because it uses the register to determine
    * the allocation size of the ESGS ring buffer in memory.
    */
   if (st->gfx_level >= GFX9)
      vertex_offset = nir_imul(b, vertex_offset, nir_load_esgs_vertex_stride_amd(b));

   unsigned base_stride = st->gfx_level >= GFX9 ? 1 : 64 /* Wave size on GFX6-8 */;
   nir_def *io_off = ac_nir_calc_io_offset(b, instr, nir_imm_int(b, base_stride * 4u), base_stride, st->map_io);
   nir_def *off = nir_iadd(b, io_off, vertex_offset);
   return nir_imul_imm(b, off, 4u);
}

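/* Lower a GS per-vertex input load:
 * - GFX9+: load from LDS.
 * - GFX6-8: load from the ESGS ring buffer in VRAM, where the components of
 *   each invocation are strided by the wave size.
 */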
static nir_def *
lower_gs_per_vertex_input_load(nir_builder *b,
                               nir_instr *instr,
                               void *state)
{
   lower_esgs_io_state *st = (lower_esgs_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   nir_def *off = gs_per_vertex_input_offset(b, st, intrin);

   if (st->gfx_level >= GFX9)
      return nir_load_shared(b, intrin->def.num_components, intrin->def.bit_size, off);

   unsigned wave_size = 64u; /* GFX6-8 only supports Wave64. */
   nir_def *ring = nir_load_ring_esgs_amd(b);
   return emit_split_buffer_load(b, ring, off, nir_imm_zero(b, 1, 32), 4u * wave_size,
                                 intrin->def.num_components, intrin->def.bit_size);
}

static bool
filter_load_per_vertex_input(const nir_instr *instr, UNUSED const void *state)
{
   return instr->type == nir_instr_type_intrinsic &&
          nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_load_per_vertex_input;
}

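/* Lower all ES (VS or TES before GS) output stores to ESGS ring / LDS stores. */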
void
ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
                               ac_nir_map_io_driver_location map,
                               enum amd_gfx_level gfx_level,
                               unsigned esgs_itemsize)
{
   lower_esgs_io_state state = {
      .gfx_level = gfx_level,
      .esgs_itemsize = esgs_itemsize,
      .map_io = map,
   };

   nir_shader_intrinsics_pass(shader, lower_es_output_store,
                              nir_metadata_block_index | nir_metadata_dominance,
                              &state);
}

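/* Lower all GS per-vertex input loads to ESGS ring / LDS loads. */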
void
ac_nir_lower_gs_inputs_to_mem(nir_shader *shader,
                              ac_nir_map_io_driver_location map,
                              enum amd_gfx_level gfx_level,
                              bool triangle_strip_adjacency_fix)
{
   lower_esgs_io_state state = {
      .gfx_level = gfx_level,
      .map_io = map,
      .gs_triangle_strip_adjacency_fix = triangle_strip_adjacency_fix,
   };

   nir_shader_lower_instructions(shader,
                                 filter_load_per_vertex_input,
                                 lower_gs_per_vertex_input_load,
                                 &state);
}