/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "ac_nir_helpers.h"

#include "nir_builder.h"

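/* State shared by every lowered intrinsic: immutable hardware and launch
 * parameters, plus system-value defs that are preloaded or overwritten once
 * and then reused.
 */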
typedef struct {
   const struct ac_shader_args *const args;
   const enum amd_gfx_level gfx_level;
   bool has_ls_vgpr_init_bug;
   unsigned wave_size;
   unsigned workgroup_size;
   const enum ac_hw_stage hw_stage;

   nir_def *vertex_id;
   nir_def *instance_id;
   nir_def *vs_rel_patch_id;
   nir_def *tes_u;
   nir_def *tes_v;
   nir_def *tes_patch_id;
   nir_def *tes_rel_patch_id;
} lower_intrinsics_to_args_state;

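/* Load an input argument at the start of the shader, falling back to the
 * argument the hardware actually provides when the LS VGPR init bug is in
 * effect and there are no HS threads.
 */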
static nir_def *
preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct ac_arg arg,
            struct ac_arg ls_buggy_arg, unsigned upper_bound)
{
   nir_builder start_b = nir_builder_at(nir_before_impl(impl));
   nir_def *value = ac_nir_load_arg_upper_bound(&start_b, s->args, arg, upper_bound);

   /* If there are no HS threads, SPI mistakenly loads the LS VGPRs starting at VGPR 0. */
   if ((s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER) &&
       s->has_ls_vgpr_init_bug) {
      nir_def *count = ac_nir_unpack_arg(&start_b, s->args, s->args->merged_wave_info, 8, 8);
      nir_def *hs_empty = nir_ieq_imm(&start_b, count, 0);
      value = nir_bcsel(&start_b, hs_empty,
                        ac_nir_load_arg_upper_bound(&start_b, s->args, ls_buggy_arg, upper_bound),
                        value);
   }
   return value;
}

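/* Compute the subgroup (wave) ID within the workgroup from the stage-specific
 * SGPR that carries it.
 */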
static nir_def *
load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
{
   if (s->workgroup_size <= s->wave_size) {
      return nir_imm_int(b, 0);
   } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
      assert(s->gfx_level < GFX12 && s->args->tg_size.used);

      if (s->gfx_level >= GFX10_3) {
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
      } else {
         /* GFX6-10 don't actually support a wave id, but we can
          * use the ordered id because ORDERED_APPEND_* is set to
          * zero in the compute dispatch initiator.
          */
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 6, 6);
      }
   } else if (s->hw_stage == AC_HW_HULL_SHADER && s->gfx_level >= GFX11) {
      assert(s->args->tcs_wave_id.used);
      return ac_nir_unpack_arg(b, s->args, s->args->tcs_wave_id, 0, 3);
   } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
              s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
      assert(s->args->merged_wave_info.used);
      return ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 24, 4);
   } else {
      return nir_imm_int(b, 0);
   }
}

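/* Replace a single system-value intrinsic with a value derived from the shader
 * input arguments. Returns true if the intrinsic was lowered.
 */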
static bool
lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
{
   lower_intrinsics_to_args_state *s = (lower_intrinsics_to_args_state *)state;
   nir_def *replacement = NULL;
   b->cursor = nir_after_instr(&intrin->instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_subgroup_id:
      if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER)
         return false; /* Lowered in backend compilers. */
      replacement = load_subgroup_id_lowered(s, b);
      break;
   case nir_intrinsic_load_num_subgroups: {
      if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
         assert(s->args->tg_size.used);
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tg_size, 0, 6);
      } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
                 s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
         assert(s->args->merged_wave_info.used);
         replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 28, 4);
      } else {
         replacement = nir_imm_int(b, 1);
      }

      break;
   }
   case nir_intrinsic_load_workgroup_id:
      if (b->shader->info.stage == MESA_SHADER_MESH) {
         /* This lowering is only valid with fast_launch = 2, otherwise we assume that
          * lower_workgroup_id_to_index removed any uses of the workgroup id by this point.
          */
         assert(s->gfx_level >= GFX11);
         nir_def *xy = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
         nir_def *z = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
         replacement = nir_vec3(b, nir_extract_u16(b, xy, nir_imm_int(b, 0)),
                                nir_extract_u16(b, xy, nir_imm_int(b, 1)),
                                nir_extract_u16(b, z, nir_imm_int(b, 1)));
      } else {
         return false;
      }
      break;
   case nir_intrinsic_load_pixel_coord:
      replacement = nir_unpack_32_2x16(b, ac_nir_load_arg(b, s->args, s->args->pos_fixed_pt));
      break;
   case nir_intrinsic_load_frag_coord:
      replacement = nir_vec4(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[0]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[1]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[2]),
                             ac_nir_load_arg(b, s->args, s->args->frag_pos[3]));
      break;
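   /* Local invocation IDs arrive either packed in a single VGPR (10 bits per
    * component) or in three separate VGPRs; extract only the bits that can be
    * non-zero for the given workgroup size.
    */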
   case nir_intrinsic_load_local_invocation_id: {
      unsigned num_bits[3];
      nir_def *vec[3];

      for (unsigned i = 0; i < 3; i++) {
         bool has_chan = b->shader->info.workgroup_size_variable ||
                         b->shader->info.workgroup_size[i] > 1;
         /* Extract as few bits as possible - we want the constant to be an inline constant
          * instead of a literal.
          */
         num_bits[i] = !has_chan ? 0 :
                       b->shader->info.workgroup_size_variable ?
                          10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
      }

      if (s->args->local_invocation_ids_packed.used) {
         unsigned extract_bits[3];
         memcpy(extract_bits, num_bits, sizeof(num_bits));

         /* Thread IDs are packed in VGPR0, 10 bits per component.
          * Always extract all remaining bits if later ID components are always 0, which will
          * translate to a bit shift.
          */
         if (num_bits[2]) {
            extract_bits[2] = 12; /* Z > 0 */
         } else if (num_bits[1])
            extract_bits[1] = 22; /* Y > 0, Z == 0 */
         else if (num_bits[0])
            extract_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */

         nir_def *ids_packed =
            ac_nir_load_arg_upper_bound(b, s->args, s->args->local_invocation_ids_packed,
                                        b->shader->info.workgroup_size_variable ?
                                           0 : ((b->shader->info.workgroup_size[0] - 1) |
                                                ((b->shader->info.workgroup_size[1] - 1) << 10) |
                                                ((b->shader->info.workgroup_size[2] - 1) << 20)));

         for (unsigned i = 0; i < 3; i++) {
            vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
                     ac_nir_unpack_value(b, ids_packed, i * 10, extract_bits[i]);
         }
      } else {
         const struct ac_arg ids[] = {
            s->args->local_invocation_id_x,
            s->args->local_invocation_id_y,
            s->args->local_invocation_id_z,
         };

         for (unsigned i = 0; i < 3; i++) {
            unsigned max = b->shader->info.workgroup_size_variable ?
                              1023 : (b->shader->info.workgroup_size[i] - 1);
            vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
                     ac_nir_load_arg_upper_bound(b, s->args, ids[i], max);
         }
      }
      replacement = nir_vec(b, vec, 3);
      break;
   }
   case nir_intrinsic_load_merged_wave_info_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->merged_wave_info);
      break;
   case nir_intrinsic_load_workgroup_num_input_vertices_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 12, 9);
      break;
   case nir_intrinsic_load_workgroup_num_input_primitives_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 22, 9);
      break;
   case nir_intrinsic_load_packed_passthrough_primitive_amd:
      /* NGG passthrough mode: the HW already packs the primitive export value into a single register.
       */
      replacement = ac_nir_load_arg(b, s->args, s->args->gs_vtx_offset[0]);
      break;
   case nir_intrinsic_load_ordered_id_amd:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_tg_info, 0, 12);
      break;
   case nir_intrinsic_load_ring_tess_offchip_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
      break;
   case nir_intrinsic_load_ring_tess_factors_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->tcs_factor_offset);
      break;
   case nir_intrinsic_load_ring_es2gs_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->es2gs_offset);
      break;
   case nir_intrinsic_load_ring_gs2vs_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->gs2vs_offset);
      break;
   case nir_intrinsic_load_gs_vertex_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->gs_vtx_offset[nir_intrinsic_base(intrin)]);
      break;
   case nir_intrinsic_load_streamout_config_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_config);
      break;
   case nir_intrinsic_load_streamout_write_index_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_write_index);
      break;
   case nir_intrinsic_load_streamout_offset_amd:
      replacement = ac_nir_load_arg(b, s->args, s->args->streamout_offset[nir_intrinsic_base(intrin)]);
      break;
   case nir_intrinsic_load_ring_attr_offset_amd: {
      nir_def *ring_attr_offset = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
      replacement = nir_ishl_imm(b, nir_ubfe_imm(b, ring_attr_offset, 0, 15), 9); /* 512b increments. */
      break;
   }
   case nir_intrinsic_load_first_vertex:
      replacement = ac_nir_load_arg(b, s->args, s->args->base_vertex);
      break;
   case nir_intrinsic_load_base_instance:
      replacement = ac_nir_load_arg(b, s->args, s->args->start_instance);
      break;
   case nir_intrinsic_load_draw_id:
      replacement = ac_nir_load_arg(b, s->args, s->args->draw_id);
      break;
   case nir_intrinsic_load_view_index:
      replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->view_index, 1);
      break;
   case nir_intrinsic_load_invocation_id:
      if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 8, 5);
      } else if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
         if (s->gfx_level >= GFX12) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5);
         } else if (s->gfx_level >= GFX10) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 5);
         } else {
            replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->gs_invocation_id, 31);
         }
      } else {
         unreachable("unexpected shader stage");
      }
      break;
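   /* The sample ID is stored in bits [11:8] of the ancillary VGPR. */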
   case nir_intrinsic_load_sample_id:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 8, 4);
      break;
   case nir_intrinsic_load_sample_pos:
      replacement = nir_vec2(b, nir_ffract(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[0])),
                             nir_ffract(b, ac_nir_load_arg(b, s->args, s->args->frag_pos[1])));
      break;
   case nir_intrinsic_load_frag_shading_rate: {
      /* VRS Rate X = Ancillary[2:3]
       * VRS Rate Y = Ancillary[4:5]
       */
      nir_def *x_rate = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 2, 2);
      nir_def *y_rate = ac_nir_unpack_arg(b, s->args, s->args->ancillary, 4, 2);

      /* xRate = xRate == 0x1 ? Horizontal2Pixels : None. */
      x_rate = nir_bcsel(b, nir_ieq_imm(b, x_rate, 1), nir_imm_int(b, 4), nir_imm_int(b, 0));

      /* yRate = yRate == 0x1 ? Vertical2Pixels : None. */
      y_rate = nir_bcsel(b, nir_ieq_imm(b, y_rate, 1), nir_imm_int(b, 1), nir_imm_int(b, 0));
      replacement = nir_ior(b, x_rate, y_rate);
      break;
   }
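   /* The front face input is a float that is positive for front-facing
    * fragments; load_front_face_fsign returns it unmodified.
    */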
   case nir_intrinsic_load_front_face:
      replacement = nir_fgt_imm(b, ac_nir_load_arg(b, s->args, s->args->front_face), 0);
      break;
   case nir_intrinsic_load_front_face_fsign:
      replacement = ac_nir_load_arg(b, s->args, s->args->front_face);
      break;
   case nir_intrinsic_load_layer_id:
      replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary,
                                      16, s->gfx_level >= GFX12 ? 14 : 13);
      break;
   case nir_intrinsic_load_barycentric_optimize_amd: {
      nir_def *prim_mask = ac_nir_load_arg(b, s->args, s->args->prim_mask);
      /* enabled when bit 31 is set */
      replacement = nir_ilt_imm(b, prim_mask, 0);
      break;
   }
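   /* Barycentric coordinates are preloaded in VGPR pairs; select the linear or
    * perspective-correct set based on the interpolation mode.
    */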
   case nir_intrinsic_load_barycentric_pixel:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_center);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_center);
      break;
   case nir_intrinsic_load_barycentric_centroid:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_centroid);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_centroid);
      break;
   case nir_intrinsic_load_barycentric_sample:
      if (nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE)
         replacement = ac_nir_load_arg(b, s->args, s->args->linear_sample);
      else
         replacement = ac_nir_load_arg(b, s->args, s->args->persp_sample);
      break;
   case nir_intrinsic_load_barycentric_model:
      replacement = ac_nir_load_arg(b, s->args, s->args->pull_model);
      break;
   case nir_intrinsic_load_barycentric_at_offset: {
      nir_def *baryc = nir_intrinsic_interp_mode(intrin) == INTERP_MODE_NOPERSPECTIVE ?
                          ac_nir_load_arg(b, s->args, s->args->linear_center) :
                          ac_nir_load_arg(b, s->args, s->args->persp_center);
      nir_def *i = nir_channel(b, baryc, 0);
      nir_def *j = nir_channel(b, baryc, 1);
      nir_def *offset_x = nir_channel(b, intrin->src[0].ssa, 0);
      nir_def *offset_y = nir_channel(b, intrin->src[0].ssa, 1);
      nir_def *ddx_i = nir_ddx(b, i);
      nir_def *ddx_j = nir_ddx(b, j);
      nir_def *ddy_i = nir_ddy(b, i);
      nir_def *ddy_j = nir_ddy(b, j);

      /* Interpolate standard barycentrics by offset. */
      nir_def *offset_i = nir_ffma(b, ddy_i, offset_y, nir_ffma(b, ddx_i, offset_x, i));
      nir_def *offset_j = nir_ffma(b, ddy_j, offset_y, nir_ffma(b, ddx_j, offset_x, j));
      replacement = nir_vec2(b, offset_i, offset_j);
      break;
   }
   case nir_intrinsic_load_gs_wave_id_amd:
      if (s->args->merged_wave_info.used)
         replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 16, 8);
      else if (s->args->gs_wave_id.used)
         replacement = ac_nir_load_arg(b, s->args, s->args->gs_wave_id);
      else
         unreachable("Shader doesn't have GS wave ID.");
      break;
   case nir_intrinsic_overwrite_vs_arguments_amd:
      s->vertex_id = intrin->src[0].ssa;
      s->instance_id = intrin->src[1].ssa;
      nir_instr_remove(&intrin->instr);
      return true;
   case nir_intrinsic_overwrite_tes_arguments_amd:
      s->tes_u = intrin->src[0].ssa;
      s->tes_v = intrin->src[1].ssa;
      s->tes_patch_id = intrin->src[2].ssa;
      s->tes_rel_patch_id = intrin->src[3].ssa;
      nir_instr_remove(&intrin->instr);
      return true;
   case nir_intrinsic_load_vertex_id_zero_base:
      if (!s->vertex_id)
         s->vertex_id = preload_arg(s, b->impl, s->args->vertex_id, s->args->tcs_patch_id, 0);
      replacement = s->vertex_id;
      break;
   case nir_intrinsic_load_instance_id:
      if (!s->instance_id)
         s->instance_id = preload_arg(s, b->impl, s->args->instance_id, s->args->vertex_id, 0);
      replacement = s->instance_id;
      break;
   case nir_intrinsic_load_tess_rel_patch_id_amd:
      if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 0, 8);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         if (s->tes_rel_patch_id) {
            replacement = s->tes_rel_patch_id;
         } else {
            replacement = ac_nir_load_arg(b, s->args, s->args->tes_rel_patch_id);
            if (b->shader->info.tess.tcs_vertices_out) {
               /* Setting an upper bound like this will actually make it possible
                * to optimize some multiplications (in address calculations) so that
                * constant additions can be added to the const offset in memory load instructions.
                */
               nir_intrinsic_set_arg_upper_bound_u32_amd(nir_instr_as_intrinsic(replacement->parent_instr),
                                                         2048 / b->shader->info.tess.tcs_vertices_out);
            }
         }
      } else {
         unreachable("invalid stage");
      }
      break;
   case nir_intrinsic_load_primitive_id:
      if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
         replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
         replacement = ac_nir_load_arg(b, s->args, s->args->tcs_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         replacement = s->tes_patch_id ? s->tes_patch_id :
                       ac_nir_load_arg(b, s->args, s->args->tes_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_VERTEX) {
         if (s->hw_stage == AC_HW_VERTEX_SHADER)
            replacement = ac_nir_load_arg(b, s->args, s->args->vs_prim_id); /* legacy */
         else
            replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id); /* NGG */
      } else {
         unreachable("invalid stage");
      }
      break;
   case nir_intrinsic_load_tess_coord: {
      nir_def *coord[3] = {
         s->tes_u ? s->tes_u : ac_nir_load_arg(b, s->args, s->args->tes_u),
         s->tes_v ? s->tes_v : ac_nir_load_arg(b, s->args, s->args->tes_v),
         nir_imm_float(b, 0),
      };

      /* For triangles, the vector should be (u, v, 1-u-v). */
      if (b->shader->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES)
         coord[2] = nir_fsub(b, nir_imm_float(b, 1), nir_fadd(b, coord[0], coord[1]));
      replacement = nir_vec(b, coord, 3);
      break;
   }
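   /* The local invocation index is subgroup_id * wave_size + lane_id, with the
    * lane ID computed via mbcnt; LS/HS before GFX11 use vs_rel_patch_id instead.
    */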
   case nir_intrinsic_load_local_invocation_index:
      /* GFX11 HS has subgroup_id, so use it instead of vs_rel_patch_id. */
      if (s->gfx_level < GFX11 &&
          (s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER)) {
         if (!s->vs_rel_patch_id) {
            s->vs_rel_patch_id = preload_arg(s, b->impl, s->args->vs_rel_patch_id,
                                             s->args->tcs_rel_ids, 255);
         }
         replacement = s->vs_rel_patch_id;
      } else if (s->workgroup_size <= s->wave_size) {
         /* Just a subgroup invocation ID. */
         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
      } else if (s->gfx_level < GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->wave_size == 64) {
         /* After the AND the bits are already multiplied by 64 (left shifted by 6) so we can just
          * feed that to mbcnt. (GFX12 doesn't have tg_size)
          */
         nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0);
         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64);
      } else {
         nir_def *subgroup_id;

         if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) {
            subgroup_id = nir_load_subgroup_id(b);
         } else {
            subgroup_id = load_subgroup_id_lowered(s, b);
         }

         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size),
                                     nir_imul_imm(b, subgroup_id, s->wave_size));
      }
      break;
   case nir_intrinsic_load_subgroup_invocation:
      replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
      break;
   default:
      return false;
   }

   assert(replacement);
   nir_def_replace(&intrin->def, replacement);
   return true;
}

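/* Lower NIR system-value and AMD-specific intrinsics to loads of the shader
 * input arguments (SGPRs/VGPRs) described by ac_args.
 */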
bool
ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
                                bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
                                unsigned wave_size, unsigned workgroup_size,
                                const struct ac_shader_args *ac_args)
{
   lower_intrinsics_to_args_state state = {
      .gfx_level = gfx_level,
      .hw_stage = hw_stage,
      .has_ls_vgpr_init_bug = has_ls_vgpr_init_bug,
      .wave_size = wave_size,
      .workgroup_size = workgroup_size,
      .args = ac_args,
   };

   return nir_shader_intrinsics_pass(shader, lower_intrinsic_to_arg,
                                     nir_metadata_control_flow, &state);
}