/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "ac_nir_helpers.h"

#include "nir_builder.h"

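/* State for the lowering pass: the shader argument layout and HW configuration,
 * plus values for VS/TES system values that are preloaded at the top of the shader
 * or overwritten by overwrite_{vs,tes}_arguments_amd.
 */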
typedef struct {
   const struct ac_shader_args *const args;
   const enum amd_gfx_level gfx_level;
   bool has_ls_vgpr_init_bug;
   unsigned wave_size;
   unsigned workgroup_size;
   const enum ac_hw_stage hw_stage;

   nir_def *vertex_id;
   nir_def *instance_id;
   nir_def *vs_rel_patch_id;
   nir_def *tes_u;
   nir_def *tes_v;
   nir_def *tes_patch_id;
   nir_def *tes_rel_patch_id;
} lower_intrinsics_to_args_state;

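/* Load an argument at the start of the shader. With the LS VGPR init bug, when there
 * are no HS threads, the SPI loads the LS VGPRs starting at VGPR 0, so the value is
 * read from the argument slot it actually landed in (ls_buggy_arg) instead.
 */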
static nir_def *
preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct ac_arg arg,
            struct ac_arg ls_buggy_arg, unsigned upper_bound)
{
   nir_builder start_b = nir_builder_at(nir_before_impl(impl));
   nir_def *value = ac_nir_load_arg_upper_bound(&start_b, s->args, arg, upper_bound);

   /* If there are no HS threads, SPI mistakenly loads the LS VGPRs starting at VGPR 0. */
   if ((s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER) &&
       s->has_ls_vgpr_init_bug) {
      nir_def *count = ac_nir_unpack_arg(&start_b, s->args, s->args->merged_wave_info, 8, 8);
      nir_def *hs_empty = nir_ieq_imm(&start_b, count, 0);
      value = nir_bcsel(&start_b, hs_empty,
                        ac_nir_load_arg_upper_bound(&start_b, s->args, ls_buggy_arg, upper_bound),
                        value);
   }
   return value;
}

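/* Return the wave (subgroup) index within the workgroup, unpacked from whichever
 * argument carries it for the given HW stage, or 0 when the workgroup fits in one wave.
 */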
static nir_def *
load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
{
   if (s->workgroup_size <= s->wave_size) {
      return nir_imm_int(b, 0);
   } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
      assert(s->gfx_level < GFX12 && s->args->tg_size.used);

      if (s->gfx_level >= GFX10_3) {
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
      } else {
         /* GFX6-10 don't actually support a wave id, but we can
          * use the ordered id because ORDERED_APPEND_* is set to
          * zero in the compute dispatch initiator.
          */
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 6, 6);
      }
   } else if (s->hw_stage == AC_HW_HULL_SHADER && s->gfx_level >= GFX11) {
      assert(s->args->tcs_wave_id.used);
      return ac_nir_unpack_arg(b, s->args, s->args->tcs_wave_id, 0, 3);
   } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
              s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
      assert(s->args->merged_wave_info.used);
      return ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 24, 4);
   } else {
      return nir_imm_int(b, 0);
   }
}

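/* nir_shader_instructions_pass callback: replace one intrinsic with a load of the
 * corresponding shader argument (or a small expression computed from it).
 */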
static bool
lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_intrinsics_to_args_state *s = (lower_intrinsics_to_args_state *)state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   nir_def *replacement = NULL;
   b->cursor = nir_after_instr(&intrin->instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_subgroup_id:
      if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER)
         return false; /* Lowered in backend compilers. */
      replacement = load_subgroup_id_lowered(s, b);
      break;
   case nir_intrinsic_load_num_subgroups: {
      if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
         assert(s->args->tg_size.used);
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tg_size, 0, 6);
      } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
                 s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
         assert(s->args->merged_wave_info.used);
         replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 28, 4);
      } else {
         replacement = nir_imm_int(b, 1);
      }

      break;
   }
   case nir_intrinsic_load_workgroup_id:
      if (b->shader->info.stage == MESA_SHADER_MESH) {
         /* This lowering is only valid with fast_launch = 2, otherwise we assume that
          * lower_workgroup_id_to_index removed any uses of the workgroup id by this point.
          */
         assert(s->gfx_level >= GFX11);
         nir_def *xy = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
         nir_def *z = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
         replacement = nir_vec3(b, nir_extract_u16(b, xy, nir_imm_int(b, 0)),
                                nir_extract_u16(b, xy, nir_imm_int(b, 1)),
                                nir_extract_u16(b, z, nir_imm_int(b, 1)));
      } else {
         return false;
      }
      break;
   case nir_intrinsic_load_pixel_coord:
      replacement = nir_unpack_32_2x16(b, ac_nir_load_arg(b, s->args, s->args->pos_fixed_pt));
      break;
   case nir_intrinsic_load_frag_coord:
      replacement = nir_vec4(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[0]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[1]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[2]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[3]));
      break;
   case nir_intrinsic_load_local_invocation_id: {
      unsigned num_bits[3];
      nir_def *vec[3];

      for (unsigned i = 0; i < 3; i++) {
         bool has_chan = b->shader->info.workgroup_size_variable ||
                         b->shader->info.workgroup_size[i] > 1;
         /* Extract as few bits as possible - we want the constant to be an inline constant
          * instead of a literal.
          */
         num_bits[i] = !has_chan ? 0 :
                       b->shader->info.workgroup_size_variable ?
                                   10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
      }

      if (s->args->local_invocation_ids_packed.used) {
         unsigned extract_bits[3];
         memcpy(extract_bits, num_bits, sizeof(num_bits));

         /* Thread IDs are packed in VGPR0, 10 bits per component.
          * If the later ID components are always 0, extract all remaining bits, so the
          * extract turns into a simple bit shift.
          */
         if (num_bits[2]) {
            extract_bits[2] = 12; /* Z > 0 */
         } else if (num_bits[1])
            extract_bits[1] = 22; /* Y > 0, Z == 0 */
         else if (num_bits[0])
            extract_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */

         nir_def *ids_packed =
            ac_nir_load_arg_upper_bound(b, s->args, s->args->local_invocation_ids_packed,
                                        b->shader->info.workgroup_size_variable ?
                                           0 : ((b->shader->info.workgroup_size[0] - 1) |
                                                ((b->shader->info.workgroup_size[1] - 1) << 10) |
                                                ((b->shader->info.workgroup_size[2] - 1) << 20)));

         for (unsigned i = 0; i < 3; i++) {
            vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
                                    ac_nir_unpack_value(b, ids_packed, i * 10, extract_bits[i]);
         }
      } else {
         const struct ac_arg ids[] = {
            s->args->local_invocation_id_x,
            s->args->local_invocation_id_y,
            s->args->local_invocation_id_z,
         };

         for (unsigned i = 0; i < 3; i++) {
            unsigned max = b->shader->info.workgroup_size_variable ?
                              1023 : (b->shader->info.workgroup_size[i] - 1);
            vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
                                    ac_nir_load_arg_upper_bound(b, s->args, ids[i], max);
         }
      }
      replacement = nir_vec(b, vec, 3);
      break;
   }
   case nir_intrinsic_load_merged_wave_info_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->merged_wave_info);
      break;
   case nir_intrinsic_load_workgroup_num_input_vertices_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 12, 9);
      break;
   case nir_intrinsic_load_workgroup_num_input_primitives_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 22, 9);
      break;
   case nir_intrinsic_load_packed_passthrough_primitive_amd:
      /* NGG passthrough mode: the HW already packs the primitive export value to a single register.
       */
      replacement = ac_nir_load_arg(b, s->args, s->args->gs_vtx_offset[0]);
      break;
   case nir_intrinsic_load_ordered_id_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 0, 12);
      break;
   case nir_intrinsic_load_ring_tess_offchip_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
      break;
   case nir_intrinsic_load_ring_tess_factors_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->tcs_factor_offset);
      break;
   case nir_intrinsic_load_ring_es2gs_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->es2gs_offset);
      break;
   case nir_intrinsic_load_ring_gs2vs_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->gs2vs_offset);
      break;
   case nir_intrinsic_load_gs_vertex_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->gs_vtx_offset[nir_intrinsic_base(intrin)]);
      break;
   case nir_intrinsic_load_streamout_config_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_config);
      break;
   case nir_intrinsic_load_streamout_write_index_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_write_index);
      break;
   case nir_intrinsic_load_streamout_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_offset[nir_intrinsic_base(intrin)]);
      break;
   case nir_intrinsic_load_ring_attr_offset_amd: {
      nir_def *ring_attr_offset = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
      replacement = nir_ishl_imm(b, nir_ubfe_imm(b, ring_attr_offset, 0, 15), 9); /* 512b increments. */
      break;
   }
   case nir_intrinsic_load_first_vertex:
      replacement = ac_nir_load_arg(b, s->args, s->args->base_vertex);
      break;
   case nir_intrinsic_load_base_instance:
      replacement = ac_nir_load_arg(b, s->args, s->args->start_instance);
      break;
   case nir_intrinsic_load_draw_id:
      replacement = ac_nir_load_arg(b, s->args, s->args->draw_id);
      break;
   case nir_intrinsic_load_view_index:
      replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->view_index, 1);
      break;
   case nir_intrinsic_load_invocation_id:
      if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 8, 5);
      } else if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
         if (s->gfx_level >= GFX12) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5);
         } else if (s->gfx_level >= GFX10) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 5);
         } else {
            replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->gs_invocation_id, 31);
         }
      } else {
         unreachable("unexpected shader stage");
      }
      break;
   case nir_intrinsic_load_sample_id:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 8, 4);
      break;
   case nir_intrinsic_load_sample_pos:
      replacement = nir_vec2(b, nir_ffract(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[0])),
                             nir_ffract(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[1])));
      break;
   case nir_intrinsic_load_frag_shading_rate: {
      /* VRS Rate X = Ancillary[2:3]
       * VRS Rate Y = Ancillary[4:5]
       */
      nir_def *x_rate = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 2, 2);
      nir_def *y_rate = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 4, 2);

      /* xRate = xRate == 0x1 ? Horizontal2Pixels : None. */
      x_rate = nir_bcsel(b, nir_ieq_imm(b, x_rate, 1), nir_imm_int(b, 4), nir_imm_int(b, 0));

      /* yRate = yRate == 0x1 ? Vertical2Pixels : None. */
      y_rate = nir_bcsel(b, nir_ieq_imm(b, y_rate, 1), nir_imm_int(b, 1), nir_imm_int(b, 0));
      replacement = nir_ior(b, x_rate, y_rate);
      break;
   }
   case nir_intrinsic_load_front_face:
      replacement = nir_fgt_imm(b, ac_nir_load_arg(b, s->args, s->args->front_face), 0);
      break;
   case nir_intrinsic_load_front_face_fsign:
      replacement = ac_nir_load_arg(b, s->args, s->args->front_face);
      break;
   case nir_intrinsic_load_layer_id:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary,
                                      16, s->gfx_level >= GFX12 ? 14 : 13);
      break;
   case nir_intrinsic_load_barycentric_optimize_amd: {
      nir_def *prim_mask = ac_nir_load_arg(b, s->args, s->args->prim_mask);
      /* enabled when bit 31 is set */
      replacement = nir_ilt_imm(b, prim_mask, 0);
      break;
   }
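   /* For the barycentric loads below, select the linear or perspective argument based on
    * the interp mode and record the interp mode as a flag on the argument load.
    */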
   case nir_intrinsic_load_barycentric_pixel:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_center);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_center);
      nir_intrinsic_set_flags(nir_instr_as_intrinsic(replacement->parent_instr),
                              AC_VECTOR_ARG_FLAG(AC_VECTOR_ARG_INTERP_MODE,
                                                 nir_intrinsic_interp_mode(intrin)));
      break;
   case nir_intrinsic_load_barycentric_centroid:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_centroid);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_centroid);
      nir_intrinsic_set_flags(nir_instr_as_intrinsic(replacement->parent_instr),
                              AC_VECTOR_ARG_FLAG(AC_VECTOR_ARG_INTERP_MODE,
                                                 nir_intrinsic_interp_mode(intrin)));
      break;
   case nir_intrinsic_load_barycentric_sample:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_sample);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_sample);
      nir_intrinsic_set_flags(nir_instr_as_intrinsic(replacement->parent_instr),
                              AC_VECTOR_ARG_FLAG(AC_VECTOR_ARG_INTERP_MODE,
                                                 nir_intrinsic_interp_mode(intrin)));
      break;
   case nir_intrinsic_load_barycentric_model:
      replacement = ac_nir_load_arg(b, s->args, s->args->pull_model);
      break;
   case nir_intrinsic_load_barycentric_at_offset: {
      nir_def *baryc = nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE ?
                          ac_nir_load_arg(b, s->args, s->args->linear_center) :
                          ac_nir_load_arg(b, s->args, s->args->persp_center);
      nir_def *i = nir_channel(b, baryc, 0);
      nir_def *j = nir_channel(b, baryc, 1);
      nir_def *offset_x = nir_channel(b, intrin->src[0].ssa, 0);
      nir_def *offset_y = nir_channel(b, intrin->src[0].ssa, 1);
      nir_def *ddx_i = nir_ddx(b, i);
      nir_def *ddx_j = nir_ddx(b, j);
      nir_def *ddy_i = nir_ddy(b, i);
      nir_def *ddy_j = nir_ddy(b, j);

      /* Interpolate standard barycentrics by offset. */
      nir_def *offset_i = nir_ffma(b, ddy_i, offset_y, nir_ffma(b, ddx_i, offset_x, i));
      nir_def *offset_j = nir_ffma(b, ddy_j, offset_y, nir_ffma(b, ddx_j, offset_x, j));
      replacement = nir_vec2(b, offset_i, offset_j);
      break;
   }
   case nir_intrinsic_load_gs_wave_id_amd:
      if (s->args->merged_wave_info.used)
         replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 16, 8);
      else if (s->args->gs_wave_id.used)
         replacement = ac_nir_load_arg(b, s->args, s->args->gs_wave_id);
      else
         unreachable("Shader doesn't have GS wave ID.");
      break;
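   /* The overwrite intrinsics don't load anything: they record replacement values that
    * later loads of the VS/TES system values in this shader use instead of the arguments.
    */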
   case nir_intrinsic_overwrite_vs_arguments_amd:
      s->vertex_id = intrin->src[0].ssa;
      s->instance_id = intrin->src[1].ssa;
      nir_instr_remove(instr);
      return true;
   case nir_intrinsic_overwrite_tes_arguments_amd:
      s->tes_u = intrin->src[0].ssa;
      s->tes_v = intrin->src[1].ssa;
      s->tes_patch_id = intrin->src[2].ssa;
      s->tes_rel_patch_id = intrin->src[3].ssa;
      nir_instr_remove(instr);
      return true;
   case nir_intrinsic_load_vertex_id_zero_base:
      if (!s->vertex_id)
         s->vertex_id = preload_arg(s, b->impl, s->args->vertex_id, s->args->tcs_patch_id, 0);
      replacement = s->vertex_id;
      break;
   case nir_intrinsic_load_instance_id:
      if (!s->instance_id)
         s->instance_id = preload_arg(s, b->impl, s->args->instance_id, s->args->vertex_id, 0);
      replacement = s->instance_id;
      break;
   case nir_intrinsic_load_tess_rel_patch_id_amd:
      if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 0, 8);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         if (s->tes_rel_patch_id) {
            replacement = s->tes_rel_patch_id;
         } else {
            replacement = ac_nir_load_arg(b, s->args, s->args->tes_rel_patch_id);
            if (b->shader->info.tess.tcs_vertices_out) {
               /* Setting an upper bound like this makes it possible to optimize some
                * multiplications (in address calculations) so that constant additions can be
                * folded into the const offset of memory load instructions.
                */
               nir_intrinsic_set_arg_upper_bound_u32_amd(nir_instr_as_intrinsic(replacement->parent_instr),
                                                         2048 / b->shader->info.tess.tcs_vertices_out);
            }
         }
      } else {
         unreachable("invalid stage");
      }
      break;
   case nir_intrinsic_load_primitive_id:
      if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
         replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_load_arg(b, s->args, s->args->tcs_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         replacement = s->tes_patch_id ? s->tes_patch_id :
                                         ac_nir_load_arg(b, s->args, s->args->tes_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_VERTEX) {
         if (s->hw_stage == AC_HW_VERTEX_SHADER)
            replacement = ac_nir_load_arg(b, s->args, s->args->vs_prim_id); /* legacy */
         else
            replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id); /* NGG */
      } else {
         unreachable("invalid stage");
      }
      break;
   case nir_intrinsic_load_tess_coord: {
      nir_def *coord[3] = {
         s->tes_u ? s->tes_u : ac_nir_load_arg(b, s->args, s->args->tes_u),
         s->tes_v ? s->tes_v : ac_nir_load_arg(b, s->args, s->args->tes_v),
         nir_imm_float(b, 0),
      };

      /* For triangles, the vector should be (u, v, 1-u-v). */
      if (b->shader->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES)
         coord[2] = nir_fsub(b, nir_imm_float(b, 1), nir_fadd(b, coord[0], coord[1]));
      replacement = nir_vec(b, coord, 3);
      break;
   }
   case nir_intrinsic_load_local_invocation_index:
      /* GFX11 HS has subgroup_id, so use it instead of vs_rel_patch_id. */
      if (s->gfx_level < GFX11 &&
          (s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER)) {
         if (!s->vs_rel_patch_id) {
            s->vs_rel_patch_id = preload_arg(s, b->impl, s->args->vs_rel_patch_id,
                                             s->args->tcs_rel_ids, 255);
         }
         replacement = s->vs_rel_patch_id;
      } else if (s->workgroup_size <= s->wave_size) {
         /* Just a subgroup invocation ID. */
         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
      } else if (s->gfx_level < GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->wave_size == 64) {
         /* After the AND, the bits are already multiplied by 64 (left-shifted by 6), so we
          * can feed that to mbcnt directly. (GFX12 doesn't have tg_size.)
          */
         nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0);
         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64);
      } else {
         nir_def *subgroup_id;

         if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) {
            subgroup_id = nir_load_subgroup_id(b);
         } else {
            subgroup_id = load_subgroup_id_lowered(s, b);
         }

         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size),
                                     nir_imul_imm(b, subgroup_id, s->wave_size));
      }
      break;
   case nir_intrinsic_load_subgroup_invocation:
      replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
      break;
   default:
      return false;
   }

   assert(replacement);
   nir_def_replace(&intrin->def, replacement);
   return true;
}

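/* Lower system-value and AMD-specific intrinsics to loads of the shader arguments
 * described by ac_args for the given HW stage and GFX level.
 */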
bool
ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
                                bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
                                unsigned wave_size, unsigned workgroup_size,
                                const struct ac_shader_args *ac_args)
{
   lower_intrinsics_to_args_state state = {
      .gfx_level = gfx_level,
      .hw_stage = hw_stage,
      .has_ls_vgpr_init_bug = has_ls_vgpr_init_bug,
      .wave_size = wave_size,
      .workgroup_size = workgroup_size,
      .args = ac_args,
   };

   return nir_shader_instructions_pass(shader, lower_intrinsic_to_arg,
                                       nir_metadata_control_flow, &state);
}