/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "ac_nir_helpers.h"

#include "nir_builder.h"
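
/* Lowers NIR system-value and AMD-specific intrinsics to loads of the shader
 * arguments (SGPRs/VGPRs) described by ac_shader_args, so later passes and
 * backends only see explicit argument loads.
 */
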
typedef struct {
   const struct ac_shader_args *const args;
   const enum amd_gfx_level gfx_level;
   bool has_ls_vgpr_init_bug;
   unsigned wave_size;
   unsigned workgroup_size;
   const enum ac_hw_stage hw_stage;

   nir_def *vertex_id;
   nir_def *instance_id;
   nir_def *vs_rel_patch_id;
   nir_def *tes_u;
   nir_def *tes_v;
   nir_def *tes_patch_id;
   nir_def *tes_rel_patch_id;
} lower_intrinsics_to_args_state;
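
/* Load an argument at the beginning of the shader, applying the LS VGPR init
 * bug workaround if needed: when the merged HS part has zero waves, the LS
 * input VGPRs start at VGPR 0, so the value has to be taken from ls_buggy_arg
 * instead.
 */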
static nir_def *
preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct ac_arg arg,
            struct ac_arg ls_buggy_arg, unsigned upper_bound)
{
   nir_builder start_b = nir_builder_at(nir_before_impl(impl));
   nir_def *value = ac_nir_load_arg_upper_bound(&start_b, s->args, arg, upper_bound);

   /* If there are no HS threads, SPI mistakenly loads the LS VGPRs starting at VGPR 0. */
   if ((s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER) &&
       s->has_ls_vgpr_init_bug) {
      nir_def *count = ac_nir_unpack_arg(&start_b, s->args, s->args->merged_wave_info, 8, 8);
      nir_def *hs_empty = nir_ieq_imm(&start_b, count, 0);
      value = nir_bcsel(&start_b, hs_empty,
                        ac_nir_load_arg_upper_bound(&start_b, s->args, ls_buggy_arg, upper_bound),
                        value);
   }
   return value;
}
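
/* Return the subgroup (wave) ID within the workgroup. Its location depends on
 * the HW stage and gfx level: tg_size for compute shaders, tcs_wave_id for
 * GFX11+ hull shaders, merged_wave_info for legacy/NGG geometry shaders;
 * single-wave workgroups simply get 0.
 */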
static nir_def *
load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
{
   if (s->workgroup_size <= s->wave_size) {
      return nir_imm_int(b, 0);
   } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
      assert(s->gfx_level < GFX12 && s->args->tg_size.used);

      if (s->gfx_level >= GFX10_3) {
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
      } else {
         /* GFX6-10 don't actually support a wave id, but we can
          * use the ordered id because ORDERED_APPEND_* is set to
          * zero in the compute dispatch initiator.
          */
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 6, 6);
      }
   } else if (s->hw_stage == AC_HW_HULL_SHADER && s->gfx_level >= GFX11) {
      assert(s->args->tcs_wave_id.used);
      return ac_nir_unpack_arg(b, s->args, s->args->tcs_wave_id, 0, 3);
   } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
              s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
      assert(s->args->merged_wave_info.used);
      return ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 24, 4);
   } else {
      return nir_imm_int(b, 0);
   }
}
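
/* nir_shader_instructions_pass callback: if the intrinsic maps to a shader
 * argument, build a replacement value and rewrite all of its uses.
 */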
static bool
lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_intrinsics_to_args_state *s = (lower_intrinsics_to_args_state *)state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   nir_def *replacement = NULL;
   b->cursor = nir_after_instr(&intrin->instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_subgroup_id:
      if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER)
         return false; /* Lowered in backend compilers. */
      replacement = load_subgroup_id_lowered(s, b);
      break;
   case nir_intrinsic_load_num_subgroups: {
      if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
         assert(s->args->tg_size.used);
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tg_size, 0, 6);
      } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
                 s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
         assert(s->args->merged_wave_info.used);
         replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 28, 4);
      } else {
         replacement = nir_imm_int(b, 1);
      }

      break;
   }
   case nir_intrinsic_load_workgroup_id:
      if (b->shader->info.stage == MESA_SHADER_MESH) {
         /* This lowering is only valid with fast_launch = 2, otherwise we assume that
          * lower_workgroup_id_to_index removed any uses of the workgroup id by this point.
          */
         assert(s->gfx_level >= GFX11);
         nir_def *xy = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
         nir_def *z = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
         replacement = nir_vec3(b, nir_extract_u16(b, xy, nir_imm_int(b, 0)),
                                nir_extract_u16(b, xy, nir_imm_int(b, 1)),
                                nir_extract_u16(b, z, nir_imm_int(b, 1)));
      } else {
         return false;
      }
      break;
   case nir_intrinsic_load_pixel_coord:
      replacement = nir_unpack_32_2x16(b, ac_nir_load_arg(b, s->args, s->args->pos_fixed_pt));
      break;
   case nir_intrinsic_load_frag_coord:
      replacement = nir_vec4(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[0]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[1]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[2]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[3]));
      break;
   case nir_intrinsic_load_local_invocation_id: {
      unsigned num_bits[3];
      nir_def *vec[3];

      for (unsigned i = 0; i < 3; i++) {
         bool has_chan = b->shader->info.workgroup_size_variable ||
                         b->shader->info.workgroup_size[i] > 1;
         /* Extract as few bits as possible - we want the constant to be an inline constant
          * instead of a literal.
          */
         num_bits[i] = !has_chan ? 0 :
                       b->shader->info.workgroup_size_variable ?
                          10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
      }

      if (s->args->local_invocation_ids_packed.used) {
         unsigned extract_bits[3];
         memcpy(extract_bits, num_bits, sizeof(num_bits));

         /* Thread IDs are packed in VGPR0, 10 bits per component.
          * Always extract all remaining bits if later ID components are always 0, which will
          * translate to a bit shift.
          */
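         /* For example, a constant workgroup size of (8, 8, 1) gives
          * num_bits = {3, 3, 0}: x is extracted from bits [2:0], y from bits
          * [31:10] (all remaining bits, i.e. just a shift), and z is the
          * constant 0.
          */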
         if (num_bits[2]) {
            extract_bits[2] = 12; /* Z > 0 */
         } else if (num_bits[1])
            extract_bits[1] = 22; /* Y > 0, Z == 0 */
         else if (num_bits[0])
            extract_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */

         nir_def *ids_packed =
            ac_nir_load_arg_upper_bound(b, s->args, s->args->local_invocation_ids_packed,
                                        b->shader->info.workgroup_size_variable ?
                                           0 : ((b->shader->info.workgroup_size[0] - 1) |
                                                ((b->shader->info.workgroup_size[1] - 1) << 10) |
                                                ((b->shader->info.workgroup_size[2] - 1) << 20)));

         for (unsigned i = 0; i < 3; i++) {
            vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
                     ac_nir_unpack_value(b, ids_packed, i * 10, extract_bits[i]);
         }
      } else {
         const struct ac_arg ids[] = {
            s->args->local_invocation_id_x,
            s->args->local_invocation_id_y,
            s->args->local_invocation_id_z,
         };

         for (unsigned i = 0; i < 3; i++) {
            unsigned max = b->shader->info.workgroup_size_variable ?
                              1023 : (b->shader->info.workgroup_size[i] - 1);
            vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
                     ac_nir_load_arg_upper_bound(b, s->args, ids[i], max);
         }
      }
      replacement = nir_vec(b, vec, 3);
      break;
   }
   case nir_intrinsic_load_merged_wave_info_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->merged_wave_info);
      break;
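   /* The gs_tg_info argument packs, per workgroup: the ordered ID in bits
    * [11:0], the input vertex count in bits [20:12] and the input primitive
    * count in bits [30:22] (see the unpacks below and in load_ordered_id_amd).
    */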
   case nir_intrinsic_load_workgroup_num_input_vertices_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 12, 9);
      break;
   case nir_intrinsic_load_workgroup_num_input_primitives_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 22, 9);
      break;
   case nir_intrinsic_load_packed_passthrough_primitive_amd:
      /* NGG passthrough mode: the HW already packs the primitive export value to a single register.
       */
      replacement = ac_nir_load_arg(b, s->args, s->args->gs_vtx_offset[0]);
      break;
   case nir_intrinsic_load_ordered_id_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 0, 12);
      break;
   case nir_intrinsic_load_ring_tess_offchip_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
      break;
   case nir_intrinsic_load_ring_tess_factors_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->tcs_factor_offset);
      break;
   case nir_intrinsic_load_ring_es2gs_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->es2gs_offset);
      break;
   case nir_intrinsic_load_ring_gs2vs_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->gs2vs_offset);
      break;
   case nir_intrinsic_load_gs_vertex_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->gs_vtx_offset[nir_intrinsic_base(intrin)]);
      break;
   case nir_intrinsic_load_streamout_config_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_config);
      break;
   case nir_intrinsic_load_streamout_write_index_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_write_index);
      break;
   case nir_intrinsic_load_streamout_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_offset[nir_intrinsic_base(intrin)]);
      break;
   case nir_intrinsic_load_ring_attr_offset_amd: {
      nir_def *ring_attr_offset = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
      replacement = nir_ishl_imm(b, nir_ubfe_imm(b, ring_attr_offset, 0, 15), 9); /* 512b increments. */
      break;
   }
   case nir_intrinsic_load_first_vertex:
      replacement = ac_nir_load_arg(b, s->args, s->args->base_vertex);
      break;
   case nir_intrinsic_load_base_instance:
      replacement = ac_nir_load_arg(b, s->args, s->args->start_instance);
      break;
   case nir_intrinsic_load_draw_id:
      replacement = ac_nir_load_arg(b, s->args, s->args->draw_id);
      break;
   case nir_intrinsic_load_view_index:
      replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->view_index, 1);
      break;
   case nir_intrinsic_load_invocation_id:
      if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 8, 5);
      } else if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
         if (s->gfx_level >= GFX12) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5);
         } else if (s->gfx_level >= GFX10) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 5);
         } else {
            replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->gs_invocation_id, 31);
         }
      } else {
         unreachable("unexpected shader stage");
      }
      break;
   case nir_intrinsic_load_sample_id:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 8, 4);
      break;
   case nir_intrinsic_load_sample_pos:
      replacement = nir_vec2(b, nir_ffract(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[0])),
                             nir_ffract(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[1])));
      break;
   case nir_intrinsic_load_frag_shading_rate: {
      /* VRS Rate X = Ancillary[2:3]
       * VRS Rate Y = Ancillary[4:5]
       */
      nir_def *x_rate = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 2, 2);
      nir_def *y_rate = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 4, 2);

      /* xRate = xRate == 0x1 ? Horizontal2Pixels : None. */
      x_rate = nir_bcsel(b, nir_ieq_imm(b, x_rate, 1), nir_imm_int(b, 4), nir_imm_int(b, 0));

      /* yRate = yRate == 0x1 ? Vertical2Pixels : None. */
      y_rate = nir_bcsel(b, nir_ieq_imm(b, y_rate, 1), nir_imm_int(b, 1), nir_imm_int(b, 0));
      replacement = nir_ior(b, x_rate, y_rate);
      break;
   }
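   /* The front_face argument is a float whose sign encodes facedness
    * (positive for front-facing); load_front_face_fsign returns the raw
    * value so the sign can be used directly.
    */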
   case nir_intrinsic_load_front_face:
      replacement = nir_fgt_imm(b, ac_nir_load_arg(b, s->args, s->args->front_face), 0);
      break;
   case nir_intrinsic_load_front_face_fsign:
      replacement = ac_nir_load_arg(b, s->args, s->args->front_face);
      break;
   case nir_intrinsic_load_layer_id:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary,
                                      16, s->gfx_level >= GFX12 ? 14 : 13);
      break;
   case nir_intrinsic_load_barycentric_optimize_amd: {
      nir_def *prim_mask = ac_nir_load_arg(b, s->args, s->args->prim_mask);
      /* enabled when bit 31 is set */
      replacement = nir_ilt_imm(b, prim_mask, 0);
      break;
   }
   case nir_intrinsic_load_barycentric_pixel:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_center);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_center);
      nir_intrinsic_set_flags(nir_instr_as_intrinsic(replacement->parent_instr),
                              AC_VECTOR_ARG_FLAG(AC_VECTOR_ARG_INTERP_MODE,
                                                 nir_intrinsic_interp_mode(intrin)));
      break;
   case nir_intrinsic_load_barycentric_centroid:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_centroid);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_centroid);
      nir_intrinsic_set_flags(nir_instr_as_intrinsic(replacement->parent_instr),
                              AC_VECTOR_ARG_FLAG(AC_VECTOR_ARG_INTERP_MODE,
                                                 nir_intrinsic_interp_mode(intrin)));
      break;
   case nir_intrinsic_load_barycentric_sample:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_sample);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_sample);
      nir_intrinsic_set_flags(nir_instr_as_intrinsic(replacement->parent_instr),
                              AC_VECTOR_ARG_FLAG(AC_VECTOR_ARG_INTERP_MODE,
                                                 nir_intrinsic_interp_mode(intrin)));
      break;
   case nir_intrinsic_load_barycentric_model:
      replacement = ac_nir_load_arg(b, s->args, s->args->pull_model);
      break;
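   /* load_barycentric_at_offset: start from the per-pixel center barycentrics
    * (i, j) and extrapolate by the given offset using their screen-space
    * derivatives, i.e. i' = i + ddx(i) * offset_x + ddy(i) * offset_y (and
    * likewise for j).
    */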
   case nir_intrinsic_load_barycentric_at_offset: {
      nir_def *baryc = nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE ?
                          ac_nir_load_arg(b, s->args, s->args->linear_center) :
                          ac_nir_load_arg(b, s->args, s->args->persp_center);
      nir_def *i = nir_channel(b, baryc, 0);
      nir_def *j = nir_channel(b, baryc, 1);
      nir_def *offset_x = nir_channel(b, intrin->src[0].ssa, 0);
      nir_def *offset_y = nir_channel(b, intrin->src[0].ssa, 1);
      nir_def *ddx_i = nir_ddx(b, i);
      nir_def *ddx_j = nir_ddx(b, j);
      nir_def *ddy_i = nir_ddy(b, i);
      nir_def *ddy_j = nir_ddy(b, j);

      /* Interpolate standard barycentrics by offset. */
      nir_def *offset_i = nir_ffma(b, ddy_i, offset_y, nir_ffma(b, ddx_i, offset_x, i));
      nir_def *offset_j = nir_ffma(b, ddy_j, offset_y, nir_ffma(b, ddx_j, offset_x, j));
      replacement = nir_vec2(b, offset_i, offset_j);
      break;
   }
   case nir_intrinsic_load_gs_wave_id_amd:
      if (s->args->merged_wave_info.used)
         replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 16, 8);
      else if (s->args->gs_wave_id.used)
         replacement = ac_nir_load_arg(b, s->args, s->args->gs_wave_id);
      else
         unreachable("Shader doesn't have GS wave ID.");
      break;
   case nir_intrinsic_overwrite_vs_arguments_amd:
      s->vertex_id = intrin->src[0].ssa;
      s->instance_id = intrin->src[1].ssa;
      nir_instr_remove(instr);
      return true;
   case nir_intrinsic_overwrite_tes_arguments_amd:
      s->tes_u = intrin->src[0].ssa;
      s->tes_v = intrin->src[1].ssa;
      s->tes_patch_id = intrin->src[2].ssa;
      s->tes_rel_patch_id = intrin->src[3].ssa;
      nir_instr_remove(instr);
      return true;
   case nir_intrinsic_load_vertex_id_zero_base:
      if (!s->vertex_id)
         s->vertex_id = preload_arg(s, b->impl, s->args->vertex_id, s->args->tcs_patch_id, 0);
      replacement = s->vertex_id;
      break;
   case nir_intrinsic_load_instance_id:
      if (!s->instance_id)
         s->instance_id = preload_arg(s, b->impl, s->args->instance_id, s->args->vertex_id, 0);
      replacement = s->instance_id;
      break;
   case nir_intrinsic_load_tess_rel_patch_id_amd:
      if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 0, 8);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         if (s->tes_rel_patch_id) {
            replacement = s->tes_rel_patch_id;
         } else {
            replacement = ac_nir_load_arg(b, s->args, s->args->tes_rel_patch_id);
            if (b->shader->info.tess.tcs_vertices_out) {
               /* Setting an upper bound like this will actually make it possible
                * to optimize some multiplications (in address calculations) so that
                * constant additions can be added to the const offset in memory load instructions.
                */
               nir_intrinsic_set_arg_upper_bound_u32_amd(nir_instr_as_intrinsic(replacement->parent_instr),
                                                         2048 / b->shader->info.tess.tcs_vertices_out);
            }
         }
      } else {
         unreachable("invalid stage");
      }
      break;
   case nir_intrinsic_load_primitive_id:
      if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
         replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_load_arg(b, s->args, s->args->tcs_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         replacement = s->tes_patch_id ? s->tes_patch_id :
                       ac_nir_load_arg(b, s->args, s->args->tes_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_VERTEX) {
         if (s->hw_stage == AC_HW_VERTEX_SHADER)
            replacement = ac_nir_load_arg(b, s->args, s->args->vs_prim_id); /* legacy */
         else
            replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id); /* NGG */
      } else {
         unreachable("invalid stage");
      }
      break;
   case nir_intrinsic_load_tess_coord: {
      nir_def *coord[3] = {
         s->tes_u ? s->tes_u : ac_nir_load_arg(b, s->args, s->args->tes_u),
         s->tes_v ? s->tes_v : ac_nir_load_arg(b, s->args, s->args->tes_v),
         nir_imm_float(b, 0),
      };

      /* For triangles, the vector should be (u, v, 1-u-v). */
      if (b->shader->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES)
         coord[2] = nir_fsub(b, nir_imm_float(b, 1), nir_fadd(b, coord[0], coord[1]));
      replacement = nir_vec(b, coord, 3);
      break;
   }
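   /* load_local_invocation_index: the flat index is computed as
    * subgroup_id * wave_size plus the lane ID from mbcnt(~0, ...). Pre-GFX11
    * LS/HS reuse the preloaded vs_rel_patch_id instead, and single-wave
    * workgroups only need the lane ID.
    */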
   case nir_intrinsic_load_local_invocation_index:
      /* GFX11 HS has subgroup_id, so use it instead of vs_rel_patch_id. */
      if (s->gfx_level < GFX11 &&
          (s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER)) {
         if (!s->vs_rel_patch_id) {
            s->vs_rel_patch_id = preload_arg(s, b->impl, s->args->vs_rel_patch_id,
                                             s->args->tcs_rel_ids, 255);
         }
         replacement = s->vs_rel_patch_id;
      } else if (s->workgroup_size <= s->wave_size) {
         /* Just a subgroup invocation ID. */
         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
      } else if (s->gfx_level < GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->wave_size == 64) {
         /* After the AND the bits are already multiplied by 64 (left shifted by 6) so we can just
          * feed that to mbcnt. (GFX12 doesn't have tg_size)
          */
         nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0);
         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64);
      } else {
         nir_def *subgroup_id;

         if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) {
            subgroup_id = nir_load_subgroup_id(b);
         } else {
            subgroup_id = load_subgroup_id_lowered(s, b);
         }

         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size),
                                     nir_imul_imm(b, subgroup_id, s->wave_size));
      }
      break;
   case nir_intrinsic_load_subgroup_invocation:
      replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
      break;
   default:
      return false;
   }

   assert(replacement);
   nir_def_replace(&intrin->def, replacement);
   return true;
}
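
/* Entry point: run the lowering over the whole shader. gfx_level, hw_stage,
 * wave_size and workgroup_size select the per-stage encodings above;
 * has_ls_vgpr_init_bug enables the workaround in preload_arg; ac_args
 * describes where each argument lives.
 */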
bool
ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
                                bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
                                unsigned wave_size, unsigned workgroup_size,
                                const struct ac_shader_args *ac_args)
{
   lower_intrinsics_to_args_state state = {
      .gfx_level = gfx_level,
      .hw_stage = hw_stage,
      .has_ls_vgpr_init_bug = has_ls_vgpr_init_bug,
      .wave_size = wave_size,
      .workgroup_size = workgroup_size,
      .args = ac_args,
   };

   return nir_shader_instructions_pass(shader, lower_intrinsic_to_arg,
                                       nir_metadata_control_flow, &state);
}