/*
 * Copyright © 2022 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 */

#include "ac_nir.h"
#include "nir_builder.h"
#include "amdgfxregs.h"
#include "u_math.h"

/*
 * These NIR passes are used to lower NIR cross-stage I/O intrinsics
 * between task and mesh shader stages into the memory accesses
 * that actually happen on the HW.
 */

typedef struct {
   unsigned payload_entry_bytes; /* Size of one task payload ring entry. */
   unsigned draw_entry_bytes;    /* Size of one draw ring entry. */
   unsigned num_entries;         /* Number of ring entries; must be a power of two. */
} lower_tsms_io_state;

typedef struct {
   nir_ssa_def *hw_workgroup_id;
   nir_ssa_def *api_workgroup_id;
} add_first_task_to_workgroup_id_state;

static bool filter_workgroup_id(const nir_instr *instr,
                                UNUSED const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_load_workgroup_id;
}

static nir_ssa_def *
replace_workgroup_id_use_first_task(nir_builder *b,
                                    nir_instr *instr,
                                    void *state)
{
   add_first_task_to_workgroup_id_state *s = (add_first_task_to_workgroup_id_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   assert(s->hw_workgroup_id);

   if (s->hw_workgroup_id == &intrin->dest.ssa)
      return NULL;

   return s->api_workgroup_id;
}

void
ac_nir_apply_first_task_to_task_shader(nir_shader *shader)
{
   /* The draw packets on RDNA2 GPUs don't support adding an offset to the task shader
    * workgroups, so we have to emulate the firstTask feature for NV_mesh_shader.
    *
    * 1. Pass the address of the IB (indirect buffer) from the NV_mesh_shader draw call
    *    to the shader in an SGPR argument (2 SGPRs for address, 1 SGPR for stride).
    * 2. Create a descriptor for the IB in the shader.
    * 3. Load the firstTask value from the IB.
    * 4. Add the firstTask value to the workgroup ID and use the result instead of the
    *    workgroup ID generated by the HW.
    *
    * NOTE: This pass must run _before_ lowering the task shader outputs to memory
    *       accesses. The lowering uses the workgroup ID and that must be unchanged
    *       because it has to be the real HW workgroup ID.
    */
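
   /* For reference, the IB entries follow the Vulkan NV indirect command layout,
    * which is why the step-3 load below reads from byte offset 4:
    *
    *    typedef struct VkDrawMeshTasksIndirectCommandNV {
    *       uint32_t taskCount;   // byte offset 0
    *       uint32_t firstTask;   // byte offset 4
    *    } VkDrawMeshTasksIndirectCommandNV;
    */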

   /* If the shader doesn't use workgroup ID, nothing to do here. */
   if (!BITSET_TEST(shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID))
      return;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder builder;
   nir_builder *b = &builder; /* This is to avoid the & */
   nir_builder_init(b, impl);
   b->cursor = nir_before_cf_list(&impl->body);

   /* This is the stride passed to vkCmdDrawMeshTasksIndirectNV. */
   nir_ssa_def *ib_stride = nir_load_task_ib_stride(b);
   nir_ssa_def *zero = nir_imm_int(b, 0);
   nir_ssa_def *first_task = NULL;

   /* If the stride is zero, we assume that firstTask is also 0. */
   nir_if *if_stride = nir_push_if(b, nir_ine(b, ib_stride, zero));
   {
      /* Address of the IB (indirect buffer) used by the current draw call. */
      nir_ssa_def *ib_addr = nir_load_task_ib_addr(b);

      /* Compose a 64-bit address from the IB address. */
      nir_ssa_def *addr = nir_pack_64_2x32_split(b, nir_channel(b, ib_addr, 0), nir_channel(b, ib_addr, 1));
      /* The IB needs to be addressed by draw ID * stride. */
      addr = nir_iadd(b, addr, nir_u2u64(b, nir_imul(b, nir_load_draw_id(b), ib_stride)));
      /* Byte offset of the firstTask field in VkDrawMeshTasksIndirectCommandNV. */
      addr = nir_iadd_imm(b, addr, 4);

      first_task = nir_build_load_global(b, 1, 32, addr, .access = ACCESS_NON_WRITEABLE | ACCESS_COHERENT);
   }
   nir_pop_if(b, if_stride);
   first_task = nir_if_phi(b, first_task, zero);

   /* NV_mesh_shader workgroups are 1 dimensional.
    * Apply firstTask to the X dimension, but leave Y and Z intact.
    */
   nir_ssa_def *hw_workgroup_id = nir_load_workgroup_id(b, 32);
   nir_ssa_def *api_workgroup_id_x = nir_iadd(b, nir_channel(b, hw_workgroup_id, 0), first_task);
   nir_ssa_def *api_workgroup_id = nir_vector_insert_imm(b, hw_workgroup_id, api_workgroup_id_x, 0);

   add_first_task_to_workgroup_id_state state = {
      .hw_workgroup_id = hw_workgroup_id,
      .api_workgroup_id = api_workgroup_id,
   };
   nir_shader_lower_instructions(shader,
                                 filter_workgroup_id,
                                 replace_workgroup_id_use_first_task,
                                 &state);

   nir_validate_shader(shader, "after including firstTask in the task shader workgroup ID");
}

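/* Flattens the 3-dimensional workgroup ID into a single linear index:
 * index = x + grid_size_x * y + grid_size_x * grid_size_y * z
 * (the X dimension varies fastest across the dispatch grid).
 */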
static nir_ssa_def *
task_workgroup_index(nir_builder *b,
                     lower_tsms_io_state *s)
{
   nir_ssa_def *id = nir_load_workgroup_id(b, 32);

   nir_ssa_def *x = nir_channel(b, id, 0);
   nir_ssa_def *y = nir_channel(b, id, 1);
   nir_ssa_def *z = nir_channel(b, id, 2);

   nir_ssa_def *grid_size = nir_load_num_workgroups(b, 32);
   nir_ssa_def *grid_size_x = nir_channel(b, grid_size, 0);
   nir_ssa_def *grid_size_y = nir_channel(b, grid_size, 1);

   return nir_iadd(b, nir_imul(b, nir_imul(b, grid_size_x, grid_size_y), z),
                      nir_iadd(b, nir_imul(b, grid_size_x, y), x));
}

static nir_ssa_def *
task_ring_entry_index(nir_builder *b,
                      lower_tsms_io_state *s)
{
   /* Task shader ring_entry shader argument:
    *
    * - It's a copy of write_ptr[31:0] from the task control buffer.
    * - The same value (which is the initial value at dispatch)
    *   seems to be copied to all workgroups in the same dispatch,
    *   therefore a workgroup index needs to be added.
    * - write_ptr must be initialized to num_entries, so ring_entry needs
    *   to be ANDed with num_entries - 1 to get the correct ring index.
    *   Note that num_entries must be a power of two.
    */
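   /* Illustrative example (hypothetical values): with num_entries = 256,
    * ring_entry = 300 and workgroup index = 5, the result is
    * (300 + 5) & 255 = 49, i.e. the index wraps around the ring.
    */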
   nir_ssa_def *ring_entry = nir_load_task_ring_entry_amd(b);
   nir_ssa_def *idx = nir_iadd_nuw(b, ring_entry, task_workgroup_index(b, s));
   return nir_iand_imm(b, idx, s->num_entries - 1);
}

static nir_ssa_def *
task_draw_ready_bit(nir_builder *b,
                    lower_tsms_io_state *s)
{
   /* The value of the ready bit is 1 for odd and 0 for even passes through the draw ring.
    *
    * The ring_entry is a copy of the write_ptr. We use that to determine whether
    * the current pass through the draw ring is odd or even, so we can write the
    * correct value to the draw ready bit.
    *
    * This tells the firmware that it can now start launching mesh shader workgroups.
    * The encoding of the last dword of the draw ring entry is:
    * - bit 0: Draw ready bit.
    *          Its meaning flips on every pass through the entry.
    * - bit 1: Packet end bit.
    *          The firmware uses this to mark the entry after the last one
    *          used by the current task dispatch.
    * - bits [2:31]: unused.
    *
    * Task shaders MUST write the draw ready bit to the draw ring
    * before they finish. The firmware waits for the shader to write
    * this bit before it reads the mesh dispatch size to launch the
    * mesh shader workgroups.
    *
    * If the task shader doesn't write this bit, the HW hangs.
    */
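
   /* A sketch of why the extraction below works (derived from this code):
    * since num_entries is a power of two, util_bitcount(num_entries - 1) is
    * log2(num_entries), and bit log2(num_entries) of the unwrapped index
    * (write_ptr + workgroup index) flips each time the ring wraps around,
    * so extracting that single bit yields 0 on even and 1 on odd passes.
    */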

   nir_ssa_def *ring_entry = nir_load_task_ring_entry_amd(b);
   nir_ssa_def *workgroup_index = task_workgroup_index(b, s);

   nir_ssa_def *idx = nir_iadd_nuw(b, ring_entry, workgroup_index);
   return nir_ubfe(b, idx, nir_imm_int(b, util_bitcount(s->num_entries - 1)), nir_imm_int(b, 1));
}

static nir_ssa_def *
mesh_ring_entry_index(nir_builder *b,
                      lower_tsms_io_state *s)
{
   /* Mesh shader ring_entry shader argument:
    *
    * - It's a copy of the read_ptr[31:0] from the task control buffer.
    * - All workgroups in the same task->mesh dispatch get the same value,
    *   which is fine because they need to read the same entry.
    * - read_ptr must be initialized to num_entries, so ring_entry needs
    *   to be ANDed with num_entries - 1 to get the correct ring index.
    *   Note that num_entries must be a power of two.
    */
   return nir_iand_imm(b, nir_load_task_ring_entry_amd(b), s->num_entries - 1);
}

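/* Writes a value into the current workgroup's draw ring entry.
 *
 * Note (derived from the code in this file): each draw ring entry is
 * draw_entry_bytes (16 bytes) and holds the mesh dispatch dimensions X, Y, Z
 * plus a dword containing the draw ready / packet end bits.
 */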
static void
task_write_draw_ring(nir_builder *b,
                     nir_ssa_def *store_val,
                     unsigned const_off,
                     lower_tsms_io_state *s)
{
   nir_ssa_def *ptr = task_ring_entry_index(b, s);
   nir_ssa_def *ring = nir_load_ring_task_draw_amd(b);
   nir_ssa_def *scalar_off = nir_imul_imm(b, ptr, s->draw_entry_bytes);
   nir_ssa_def *vector_off = nir_imm_int(b, 0);

   nir_store_buffer_amd(b, store_val, ring, vector_off, scalar_off,
                        .base = const_off, .memory_modes = nir_var_shader_out);
}

static bool
filter_task_intrinsics(const nir_instr *instr,
                       UNUSED const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_launch_mesh_workgroups ||
          intrin->intrinsic == nir_intrinsic_store_task_payload ||
          intrin->intrinsic == nir_intrinsic_load_task_payload;
}

static nir_ssa_def *
lower_task_launch_mesh_workgroups(nir_builder *b,
                                  nir_intrinsic_instr *intrin,
                                  lower_tsms_io_state *s)
{
   /* This intrinsic must always be in uniform control flow,
    * so we assume that all invocations are active here.
    */

   /* Wait for all necessary stores to finish. */
   nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
                         .memory_scope = NIR_SCOPE_WORKGROUP,
                         .memory_semantics = NIR_MEMORY_ACQ_REL,
                         .memory_modes = nir_var_mem_task_payload | nir_var_shader_out |
                                         nir_var_mem_ssbo | nir_var_mem_global);

   /* On the first invocation, write the full draw ring entry. */
   nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
   nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
   {
      nir_ssa_def *dimensions = intrin->src[0].ssa;
      nir_ssa_def *x = nir_channel(b, dimensions, 0);
      nir_ssa_def *y = nir_channel(b, dimensions, 1);
      nir_ssa_def *z = nir_channel(b, dimensions, 2);
      nir_ssa_def *rdy = task_draw_ready_bit(b, s);
      nir_ssa_def *store_val = nir_vec4(b, x, y, z, rdy);
      task_write_draw_ring(b, store_val, 0, s);
   }
   nir_pop_if(b, if_invocation_index_zero);

   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
}

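/* Lowers store_task_payload to a buffer store into the task payload ring.
 *
 * Note (derived from the code below): the intrinsic's offset source addresses
 * bytes within the payload entry, while ring_off (entry index *
 * payload_entry_bytes) selects the entry belonging to the current task workgroup.
 */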
static nir_ssa_def *
lower_task_payload_store(nir_builder *b,
                         nir_intrinsic_instr *intrin,
                         lower_tsms_io_state *s)
{
   unsigned write_mask = nir_intrinsic_write_mask(intrin);
   unsigned base = nir_intrinsic_base(intrin);

   nir_ssa_def *store_val = intrin->src[0].ssa;
   nir_ssa_def *addr = intrin->src[1].ssa;
   nir_ssa_def *ring = nir_load_ring_task_payload_amd(b);
   nir_ssa_def *ptr = task_ring_entry_index(b, s);
   nir_ssa_def *ring_off = nir_imul_imm(b, ptr, s->payload_entry_bytes);

   nir_store_buffer_amd(b, store_val, ring, addr, ring_off, .base = base,
                        .write_mask = write_mask,
                        .memory_modes = nir_var_mem_task_payload);

   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
}

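/* Lowers load_task_payload to a buffer load from the task payload ring.
 *
 * This is shared between the two stages: a task shader reads back its own
 * entry (based on write_ptr), while a mesh shader reads the entry written by
 * the task workgroup that launched it (based on read_ptr).
 */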
static nir_ssa_def *
lower_taskmesh_payload_load(nir_builder *b,
                            nir_intrinsic_instr *intrin,
                            lower_tsms_io_state *s)
{
   unsigned base = nir_intrinsic_base(intrin);
   unsigned num_components = intrin->dest.ssa.num_components;
   unsigned bit_size = intrin->dest.ssa.bit_size;

   nir_ssa_def *ptr =
      b->shader->info.stage == MESA_SHADER_TASK ?
      task_ring_entry_index(b, s) :
      mesh_ring_entry_index(b, s);

   nir_ssa_def *addr = intrin->src[0].ssa;
   nir_ssa_def *ring = nir_load_ring_task_payload_amd(b);
   nir_ssa_def *ring_off = nir_imul_imm(b, ptr, s->payload_entry_bytes);

   return nir_load_buffer_amd(b, num_components, bit_size, ring, addr, ring_off, .base = base,
                              .memory_modes = nir_var_mem_task_payload);
}

static nir_ssa_def *
lower_task_intrinsics(nir_builder *b,
                      nir_instr *instr,
                      void *state)
{
   assert(instr->type == nir_instr_type_intrinsic);
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   lower_tsms_io_state *s = (lower_tsms_io_state *)state;

   switch (intrin->intrinsic) {
      case nir_intrinsic_store_task_payload:
         return lower_task_payload_store(b, intrin, s);
      case nir_intrinsic_load_task_payload:
         return lower_taskmesh_payload_load(b, intrin, s);
      case nir_intrinsic_launch_mesh_workgroups:
         return lower_task_launch_mesh_workgroups(b, intrin, s);
      default:
         unreachable("unsupported task shader intrinsic");
   }
}

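/* Entry point for lowering task shader outputs to memory accesses.
 *
 * Illustrative call (hypothetical values; the caller chooses the actual
 * payload entry size and ring entry count, the latter a power of two):
 *
 *    ac_nir_lower_task_outputs_to_mem(task_nir, 16384, 256);
 */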
void
ac_nir_lower_task_outputs_to_mem(nir_shader *shader,
                                 unsigned task_payload_entry_bytes,
                                 unsigned task_num_entries)
{
   assert(util_is_power_of_two_nonzero(task_num_entries));

   nir_lower_task_shader_options lower_ts_opt = {
      .payload_to_shared_for_atomics = true,
   };
   NIR_PASS(_, shader, nir_lower_task_shader, lower_ts_opt);

   lower_tsms_io_state state = {
      .draw_entry_bytes = 16,
      .payload_entry_bytes = task_payload_entry_bytes,
      .num_entries = task_num_entries,
   };

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_builder builder;
   nir_builder *b = &builder; /* This is to avoid the & */
   nir_builder_init(b, impl);

   nir_shader_lower_instructions(shader,
                                 filter_task_intrinsics,
                                 lower_task_intrinsics,
                                 &state);

   nir_metadata_preserve(impl, nir_metadata_none);
   nir_validate_shader(shader, "after lowering task shader outputs to memory stores");
}

static bool
filter_mesh_input_load(const nir_instr *instr,
                       UNUSED const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_load_task_payload;
}

static nir_ssa_def *
lower_mesh_intrinsics(nir_builder *b,
                      nir_instr *instr,
                      void *state)
{
   assert(instr->type == nir_instr_type_intrinsic);
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   lower_tsms_io_state *s = (lower_tsms_io_state *)state;

   if (intrin->intrinsic == nir_intrinsic_load_task_payload)
      return lower_taskmesh_payload_load(b, intrin, s);
   else
      unreachable("unsupported mesh shader intrinsic");
}

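/* Entry point for lowering mesh shader inputs to memory accesses.
 *
 * Note (assumption about intended use): the caller is expected to pass the
 * same task_payload_entry_bytes and task_num_entries here as to
 * ac_nir_lower_task_outputs_to_mem, so that both stages compute matching
 * payload ring offsets.
 */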
void
ac_nir_lower_mesh_inputs_to_mem(nir_shader *shader,
                                unsigned task_payload_entry_bytes,
                                unsigned task_num_entries)
{
   assert(util_is_power_of_two_nonzero(task_num_entries));

   lower_tsms_io_state state = {
      .draw_entry_bytes = 16,
      .payload_entry_bytes = task_payload_entry_bytes,
      .num_entries = task_num_entries,
   };

   nir_shader_lower_instructions(shader,
                                 filter_mesh_input_load,
                                 lower_mesh_intrinsics,
                                 &state);
}
434