/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "ac_nir_helpers.h"

#include "nir_builder.h"

typedef struct {
   const struct ac_shader_args *const args;
   const enum amd_gfx_level gfx_level;
   bool has_ls_vgpr_init_bug;
   unsigned wave_size;
   unsigned workgroup_size;
   const enum ac_hw_stage hw_stage;

   nir_def *vertex_id;
   nir_def *instance_id;
   nir_def *vs_rel_patch_id;
   nir_def *tes_u;
   nir_def *tes_v;
   nir_def *tes_patch_id;
   nir_def *tes_rel_patch_id;
} lower_intrinsics_to_args_state;

static nir_def *
preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct ac_arg arg,
            struct ac_arg ls_buggy_arg, unsigned upper_bound)
{
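   /* Load the argument at the top of the entry block so that all uses share a single definition. */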
   nir_builder start_b = nir_builder_at(nir_before_impl(impl));
   nir_def *value = ac_nir_load_arg_upper_bound(&start_b, s->args, arg, upper_bound);

   /* If there are no HS threads, SPI mistakenly loads the LS VGPRs starting at VGPR 0. */
   if ((s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER) &&
       s->has_ls_vgpr_init_bug) {
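      /* Bits [8:15] of merged_wave_info contain the number of HS threads. */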
      nir_def *count = ac_nir_unpack_arg(&start_b, s->args, s->args->merged_wave_info, 8, 8);
      nir_def *hs_empty = nir_ieq_imm(&start_b, count, 0);
      value = nir_bcsel(&start_b, hs_empty,
                        ac_nir_load_arg_upper_bound(&start_b, s->args, ls_buggy_arg, upper_bound),
                        value);
   }
   return value;
}

static nir_def *
load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
{
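   /* If the whole workgroup fits in one wave, the subgroup ID is always 0. */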
   if (s->workgroup_size <= s->wave_size) {
      return nir_imm_int(b, 0);
   } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
      assert(s->gfx_level < GFX12 && s->args->tg_size.used);

      if (s->gfx_level >= GFX10_3) {
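         /* GFX10.3+ has a dedicated wave ID in bits [20:24] of tg_size. */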
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
      } else {
         /* GFX6-10 don't actually support a wave id, but we can
          * use the ordered id because ORDERED_APPEND_* is set to
          * zero in the compute dispatch initiator.
          */
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 6, 6);
      }
   } else if (s->hw_stage == AC_HW_HULL_SHADER && s->gfx_level >= GFX11) {
      assert(s->args->tcs_wave_id.used);
      return ac_nir_unpack_arg(b, s->args, s->args->tcs_wave_id, 0, 3);
   } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
              s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
      assert(s->args->merged_wave_info.used);
      return ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 24, 4);
   } else {
      return nir_imm_int(b, 0);
   }
}

static bool
lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
{
   lower_intrinsics_to_args_state *s = (lower_intrinsics_to_args_state *)state;
   nir_def *replacement = NULL;
   b->cursor = nir_after_instr(&intrin->instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_subgroup_id:
      if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER)
         return false; /* Lowered in backend compilers. */
      replacement = load_subgroup_id_lowered(s, b);
      break;
   case nir_intrinsic_load_num_subgroups: {
      if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
         assert(s->args->tg_size.used);
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tg_size, 0, 6);
      } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
                 s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
         assert(s->args->merged_wave_info.used);
         replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 28, 4);
      } else {
         replacement = nir_imm_int(b, 1);
      }

      break;
   }
   case nir_intrinsic_load_workgroup_id:
      if (b->shader->info.stage == MESA_SHADER_MESH) {
         /* This lowering is only valid with fast_launch = 2, otherwise we assume that
          * lower_workgroup_id_to_index removed any uses of the workgroup id by this point.
          */
         assert(s->gfx_level >= GFX11);
         nir_def *xy = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
         nir_def *z = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
         replacement = nir_vec3(b, nir_extract_u16(b, xy, nir_imm_int(b, 0)),
                                nir_extract_u16(b, xy, nir_imm_int(b, 1)),
                                nir_extract_u16(b, z, nir_imm_int(b, 1)));
      } else {
         return false;
      }
      break;
   case nir_intrinsic_load_pixel_coord:
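      /* pos_fixed_pt packs the pixel x/y coordinates as two 16-bit values in one VGPR. */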
      replacement = nir_unpack_32_2x16(b, ac_nir_load_arg(b, s->args, s->args->pos_fixed_pt));
      break;
   case nir_intrinsic_load_frag_coord:
      replacement = nir_vec4(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[0]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[1]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[2]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[3]));
      break;
   case nir_intrinsic_load_local_invocation_id: {
      unsigned num_bits[3];
      nir_def *vec[3];

      for (unsigned i = 0; i < 3; i++) {
         bool has_chan = b->shader->info.workgroup_size_variable ||
                         b->shader->info.workgroup_size[i] > 1;
         /* Extract as few bits as possible - we want the constant to be an inline constant
          * instead of a literal.
          */
         num_bits[i] = !has_chan ? 0 :
                       b->shader->info.workgroup_size_variable ?
                                   10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
      }

      if (s->args->local_invocation_ids_packed.used) {
         unsigned extract_bits[3];
         memcpy(extract_bits, num_bits, sizeof(num_bits));

         /* Thread IDs are packed in VGPR0, 10 bits per component.
          * Always extract all remaining bits if later ID components are always 0, which will
          * translate to a bit shift.
          */
         if (num_bits[2]) {
            extract_bits[2] = 12; /* Z > 0 */
         } else if (num_bits[1])
            extract_bits[1] = 22; /* Y > 0, Z == 0 */
         else if (num_bits[0])
            extract_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */

         nir_def *ids_packed =
            ac_nir_load_arg_upper_bound(b, s->args, s->args->local_invocation_ids_packed,
                                        b->shader->info.workgroup_size_variable ?
                                           0 : ((b->shader->info.workgroup_size[0] - 1) |
                                                ((b->shader->info.workgroup_size[1] - 1) << 10) |
                                                ((b->shader->info.workgroup_size[2] - 1) << 20)));

         for (unsigned i = 0; i < 3; i++) {
            vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
                                    ac_nir_unpack_value(b, ids_packed, i * 10, extract_bits[i]);
         }
      } else {
         const struct ac_arg ids[] = {
            s->args->local_invocation_id_x,
            s->args->local_invocation_id_y,
            s->args->local_invocation_id_z,
         };

         for (unsigned i = 0; i < 3; i++) {
            unsigned max = b->shader->info.workgroup_size_variable ?
                              1023 : (b->shader->info.workgroup_size[i] - 1);
            vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
                                    ac_nir_load_arg_upper_bound(b, s->args, ids[i], max);
         }
      }
      replacement = nir_vec(b, vec, 3);
      break;
   }
   case nir_intrinsic_load_merged_wave_info_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->merged_wave_info);
      break;
   case nir_intrinsic_load_workgroup_num_input_vertices_amd:
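      /* gs_tg_info bits [12:20] hold the threadgroup input vertex count. */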
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 12, 9);
      break;
   case nir_intrinsic_load_workgroup_num_input_primitives_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 22, 9);
      break;
   case nir_intrinsic_load_packed_passthrough_primitive_amd:
      /* NGG passthrough mode: the HW already packs the primitive export value to a single register.
       */
      replacement = ac_nir_load_arg(b, s->args, s->args->gs_vtx_offset[0]);
      break;
   case nir_intrinsic_load_ordered_id_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 0, 12);
      break;
   case nir_intrinsic_load_ring_tess_offchip_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
      break;
   case nir_intrinsic_load_ring_tess_factors_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->tcs_factor_offset);
      break;
   case nir_intrinsic_load_ring_es2gs_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->es2gs_offset);
      break;
   case nir_intrinsic_load_ring_gs2vs_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->gs2vs_offset);
      break;
   case nir_intrinsic_load_gs_vertex_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->gs_vtx_offset[nir_intrinsic_base(intrin)]);
      break;
   case nir_intrinsic_load_streamout_config_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_config);
      break;
   case nir_intrinsic_load_streamout_write_index_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_write_index);
      break;
   case nir_intrinsic_load_streamout_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_offset[nir_intrinsic_base(intrin)]);
      break;
   case nir_intrinsic_load_ring_attr_offset_amd: {
      nir_def *ring_attr_offset = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
      replacement = nir_ishl_imm(b, nir_ubfe_imm(b, ring_attr_offset, 0, 15), 9); /* 512b increments. */
      break;
   }
   case nir_intrinsic_load_first_vertex:
      replacement = ac_nir_load_arg(b, s->args, s->args->base_vertex);
      break;
   case nir_intrinsic_load_base_instance:
      replacement = ac_nir_load_arg(b, s->args, s->args->start_instance);
      break;
   case nir_intrinsic_load_draw_id:
      replacement = ac_nir_load_arg(b, s->args, s->args->draw_id);
      break;
   case nir_intrinsic_load_view_index:
      replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->view_index, 1);
      break;
   case nir_intrinsic_load_invocation_id:
      if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 8, 5);
      } else if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
         if (s->gfx_level >= GFX12) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5);
         } else if (s->gfx_level >= GFX10) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 5);
         } else {
            replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->gs_invocation_id, 31);
         }
      } else {
         unreachable("unexpected shader stage");
      }
      break;
   case nir_intrinsic_load_sample_id:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 8, 4);
      break;
   case nir_intrinsic_load_sample_pos:
      replacement = nir_vec2(b, nir_ffract(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[0])),
                             nir_ffract(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[1])));
      break;
   case nir_intrinsic_load_frag_shading_rate: {
      /* VRS Rate X = Ancillary[2:3]
       * VRS Rate Y = Ancillary[4:5]
       */
      nir_def *x_rate = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 2, 2);
      nir_def *y_rate = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 4, 2);

      /* xRate = xRate == 0x1 ? Horizontal2Pixels : None. */
      x_rate = nir_bcsel(b, nir_ieq_imm(b, x_rate, 1), nir_imm_int(b, 4), nir_imm_int(b, 0));

      /* yRate = yRate == 0x1 ? Vertical2Pixels : None. */
      y_rate = nir_bcsel(b, nir_ieq_imm(b, y_rate, 1), nir_imm_int(b, 1), nir_imm_int(b, 0));
      replacement = nir_ior(b, x_rate, y_rate);
      break;
   }
   case nir_intrinsic_load_front_face:
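      /* The front_face argument is +/-1.0; the fragment is front-facing iff it is positive. */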
      replacement = nir_fgt_imm(b, ac_nir_load_arg(b, s->args, s->args->front_face), 0);
      break;
   case nir_intrinsic_load_front_face_fsign:
      replacement = ac_nir_load_arg(b, s->args, s->args->front_face);
      break;
   case nir_intrinsic_load_layer_id:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary,
                                      16, s->gfx_level >= GFX12 ? 14 : 13);
      break;
   case nir_intrinsic_load_barycentric_optimize_amd: {
      nir_def *prim_mask = ac_nir_load_arg(b, s->args, s->args->prim_mask);
      /* enabled when bit 31 is set */
      replacement = nir_ilt_imm(b, prim_mask, 0);
      break;
   }
   case nir_intrinsic_load_barycentric_pixel:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_center);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_center);
      break;
   case nir_intrinsic_load_barycentric_centroid:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_centroid);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_centroid);
      break;
   case nir_intrinsic_load_barycentric_sample:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_sample);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_sample);
      break;
   case nir_intrinsic_load_barycentric_model:
      replacement = ac_nir_load_arg(b, s->args, s->args->pull_model);
      break;
   case nir_intrinsic_load_barycentric_at_offset: {
      nir_def *baryc = nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE ?
                          ac_nir_load_arg(b, s->args, s->args->linear_center) :
                          ac_nir_load_arg(b, s->args, s->args->persp_center);
      nir_def *i = nir_channel(b, baryc, 0);
      nir_def *j = nir_channel(b, baryc, 1);
      nir_def *offset_x = nir_channel(b, intrin->src[0].ssa, 0);
      nir_def *offset_y = nir_channel(b, intrin->src[0].ssa, 1);
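      /* Extrapolate i/j by the offset using their screen-space derivatives:
       * baryc(x + dx, y + dy) ~= baryc + ddx(baryc) * dx + ddy(baryc) * dy.
       */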
      nir_def *ddx_i = nir_ddx(b, i);
      nir_def *ddx_j = nir_ddx(b, j);
      nir_def *ddy_i = nir_ddy(b, i);
      nir_def *ddy_j = nir_ddy(b, j);

      /* Interpolate standard barycentrics by offset. */
      nir_def *offset_i = nir_ffma(b, ddy_i, offset_y, nir_ffma(b, ddx_i, offset_x, i));
      nir_def *offset_j = nir_ffma(b, ddy_j, offset_y, nir_ffma(b, ddx_j, offset_x, j));
      replacement = nir_vec2(b, offset_i, offset_j);
      break;
   }
   case nir_intrinsic_load_gs_wave_id_amd:
      if (s->args->merged_wave_info.used)
         replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 16, 8);
      else if (s->args->gs_wave_id.used)
         replacement = ac_nir_load_arg(b, s->args, s->args->gs_wave_id);
      else
         unreachable("Shader doesn't have GS wave ID.");
      break;
   case nir_intrinsic_overwrite_vs_arguments_amd:
      s->vertex_id = intrin->src[0].ssa;
      s->instance_id = intrin->src[1].ssa;
      nir_instr_remove(&intrin->instr);
      return true;
   case nir_intrinsic_overwrite_tes_arguments_amd:
      s->tes_u = intrin->src[0].ssa;
      s->tes_v = intrin->src[1].ssa;
      s->tes_patch_id = intrin->src[2].ssa;
      s->tes_rel_patch_id = intrin->src[3].ssa;
      nir_instr_remove(&intrin->instr);
      return true;
   case nir_intrinsic_load_vertex_id_zero_base:
      if (!s->vertex_id)
         s->vertex_id = preload_arg(s, b->impl, s->args->vertex_id, s->args->tcs_patch_id, 0);
      replacement = s->vertex_id;
      break;
   case nir_intrinsic_load_instance_id:
      if (!s->instance_id)
         s->instance_id = preload_arg(s, b->impl, s->args->instance_id, s->args->vertex_id, 0);
      replacement = s->instance_id;
      break;
   case nir_intrinsic_load_tess_rel_patch_id_amd:
      if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 0, 8);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         if (s->tes_rel_patch_id) {
            replacement = s->tes_rel_patch_id;
         } else {
            replacement = ac_nir_load_arg(b, s->args, s->args->tes_rel_patch_id);
            if (b->shader->info.tess.tcs_vertices_out) {
               /* Setting an upper bound like this will actually make it possible
                * to optimize some multiplications (in address calculations) so that
                * constant additions can be added to the const offset in memory load instructions.
                */
               nir_intrinsic_set_arg_upper_bound_u32_amd(nir_instr_as_intrinsic(replacement->parent_instr),
                                                         2048 / b->shader->info.tess.tcs_vertices_out);
            }
         }
      } else {
         unreachable("invalid stage");
      }
      break;
   case nir_intrinsic_load_primitive_id:
      if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
         replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_load_arg(b, s->args, s->args->tcs_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         replacement = s->tes_patch_id ? s->tes_patch_id :
                                         ac_nir_load_arg(b, s->args, s->args->tes_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_VERTEX) {
         if (s->hw_stage == AC_HW_VERTEX_SHADER)
            replacement = ac_nir_load_arg(b, s->args, s->args->vs_prim_id); /* legacy */
         else
            replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id); /* NGG */
      } else {
         unreachable("invalid stage");
      }
      break;
   case nir_intrinsic_load_tess_coord: {
      nir_def *coord[3] = {
         s->tes_u ? s->tes_u : ac_nir_load_arg(b, s->args, s->args->tes_u),
         s->tes_v ? s->tes_v : ac_nir_load_arg(b, s->args, s->args->tes_v),
         nir_imm_float(b, 0),
      };

      /* For triangles, the vector should be (u, v, 1-u-v). */
      if (b->shader->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES)
         coord[2] = nir_fsub(b, nir_imm_float(b, 1), nir_fadd(b, coord[0], coord[1]));
      replacement = nir_vec(b, coord, 3);
      break;
   }
   case nir_intrinsic_load_local_invocation_index:
      /* GFX11 HS has subgroup_id, so use it instead of vs_rel_patch_id. */
      if (s->gfx_level < GFX11 &&
          (s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER)) {
         if (!s->vs_rel_patch_id) {
            s->vs_rel_patch_id = preload_arg(s, b->impl, s->args->vs_rel_patch_id,
                                             s->args->tcs_rel_ids, 255);
         }
         replacement = s->vs_rel_patch_id;
      } else if (s->workgroup_size <= s->wave_size) {
         /* Just a subgroup invocation ID. */
         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
      } else if (s->gfx_level < GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->wave_size == 64) {
         /* After the AND the bits are already multiplied by 64 (left shifted by 6) so we can just
          * feed that to mbcnt. (GFX12 doesn't have tg_size)
          */
         nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0);
         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64);
      } else {
         nir_def *subgroup_id;

         if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) {
            subgroup_id = nir_load_subgroup_id(b);
         } else {
            subgroup_id = load_subgroup_id_lowered(s, b);
         }

         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size),
                                     nir_imul_imm(b, subgroup_id, s->wave_size));
      }
      break;
   case nir_intrinsic_load_subgroup_invocation:
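      /* mbcnt_amd counts the bits of the mask below the current lane; with a full mask
       * and base 0 this yields the lane index within the wave.
       */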
      replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
      break;
   default:
      return false;
   }

   assert(replacement);
   nir_def_replace(&intrin->def, replacement);
   return true;
}

bool
ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
                                bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
                                unsigned wave_size, unsigned workgroup_size,
                                const struct ac_shader_args *ac_args)
{
   lower_intrinsics_to_args_state state = {
      .gfx_level = gfx_level,
      .hw_stage = hw_stage,
      .has_ls_vgpr_init_bug = has_ls_vgpr_init_bug,
      .wave_size = wave_size,
      .workgroup_size = workgroup_size,
      .args = ac_args,
   };

   return nir_shader_intrinsics_pass(shader, lower_intrinsic_to_arg,
                                     nir_metadata_control_flow, &state);
}