1 /*
2 * Copyright © 2021 Google
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "nir/nir.h"
8 #include "nir/nir_builder.h"
9
10 #include "bvh/bvh.h"
11 #include "meta/radv_meta.h"
12 #include "nir/radv_nir.h"
13 #include "nir/radv_nir_rt_common.h"
14 #include "ac_nir.h"
15 #include "radv_pipeline_cache.h"
16 #include "radv_pipeline_rt.h"
17 #include "radv_shader.h"
18
19 #include "vk_pipeline.h"
20
/* Traversal stack size. This stack is put in LDS and, experimentally, 16 entries results in the best
 * performance. */
#define MAX_STACK_ENTRY_COUNT 16
24
25 #define RADV_RT_SWITCH_NULL_CHECK_THRESHOLD 3
26
27 /* Minimum number of inlined shaders to use binary search to select which shader to run. */
28 #define INLINED_SHADER_BSEARCH_THRESHOLD 16
29
30 struct radv_rt_case_data {
31 struct radv_device *device;
32 struct radv_ray_tracing_pipeline *pipeline;
33 struct rt_variables *vars;
34 };
35
36 typedef void (*radv_get_group_info)(struct radv_ray_tracing_group *, uint32_t *, uint32_t *,
37 struct radv_rt_case_data *);
38 typedef void (*radv_insert_shader_case)(nir_builder *, nir_def *, struct radv_ray_tracing_group *,
39 struct radv_rt_case_data *);
40
41 struct inlined_shader_case {
42 struct radv_ray_tracing_group *group;
43 uint32_t call_idx;
44 };
45
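/* qsort comparator: orders inlined shader cases by ascending SBT call index so they can be binary-searched. */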
46 static int
compare_inlined_shader_case(const void *a, const void *b)
48 {
49 const struct inlined_shader_case *visit_a = a;
50 const struct inlined_shader_case *visit_b = b;
51 return visit_a->call_idx > visit_b->call_idx ? 1 : visit_a->call_idx < visit_b->call_idx ? -1 : 0;
52 }
53
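/* Emit the shader cases in cases[0, length), using a binary search over the sorted call indices once the
 * range is large enough. */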
54 static void
insert_inlined_range(nir_builder *b, nir_def *sbt_idx, radv_insert_shader_case shader_case,
                     struct radv_rt_case_data *data, struct inlined_shader_case *cases, uint32_t length)
57 {
58 if (length >= INLINED_SHADER_BSEARCH_THRESHOLD) {
59 nir_push_if(b, nir_ige_imm(b, sbt_idx, cases[length / 2].call_idx));
60 {
61 insert_inlined_range(b, sbt_idx, shader_case, data, cases + (length / 2), length - (length / 2));
62 }
63 nir_push_else(b, NULL);
64 {
65 insert_inlined_range(b, sbt_idx, shader_case, data, cases, length / 2);
66 }
67 nir_pop_if(b, NULL);
68 } else {
69 for (uint32_t i = 0; i < length; ++i)
70 shader_case(b, sbt_idx, cases[i].group, data);
71 }
72 }
73
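/* Collect the unique shaders referenced by the pipeline groups (as reported by group_info), sort them by
 * call index and emit one inlined case per shader via shader_case. Optionally skips all cases when the
 * SBT entry is NULL (sbt_idx == 0). */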
74 static void
radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_shaders, struct radv_rt_case_data *data,
                           radv_get_group_info group_info, radv_insert_shader_case shader_case)
77 {
78 struct inlined_shader_case *cases = calloc(data->pipeline->group_count, sizeof(struct inlined_shader_case));
79 uint32_t case_count = 0;
80
81 for (unsigned i = 0; i < data->pipeline->group_count; i++) {
82 struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
83
84 uint32_t shader_index = VK_SHADER_UNUSED_KHR;
85 uint32_t handle_index = VK_SHADER_UNUSED_KHR;
86 group_info(group, &shader_index, &handle_index, data);
87 if (shader_index == VK_SHADER_UNUSED_KHR)
88 continue;
89
90 /* Avoid emitting stages with the same shaders/handles multiple times. */
91 bool duplicate = false;
92 for (unsigned j = 0; j < i; j++) {
93 uint32_t other_shader_index = VK_SHADER_UNUSED_KHR;
94 uint32_t other_handle_index = VK_SHADER_UNUSED_KHR;
95 group_info(&data->pipeline->groups[j], &other_shader_index, &other_handle_index, data);
96
97 if (handle_index == other_handle_index) {
98 duplicate = true;
99 break;
100 }
101 }
102
103 if (!duplicate) {
104 cases[case_count++] = (struct inlined_shader_case){
105 .group = group,
106 .call_idx = handle_index,
107 };
108 }
109 }
110
111 qsort(cases, case_count, sizeof(struct inlined_shader_case), compare_inlined_shader_case);
112
/* Do not emit 'if (sbt_idx != 0) { ... }' if there are only a few cases. */
114 can_have_null_shaders &= case_count >= RADV_RT_SWITCH_NULL_CHECK_THRESHOLD;
115
116 if (can_have_null_shaders)
117 nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
118
119 insert_inlined_range(b, sbt_idx, shader_case, data, cases, case_count);
120
121 if (can_have_null_shaders)
122 nir_pop_if(b, NULL);
123
124 free(cases);
125 }
126
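/* Rewrite shader-call-data derefs into function_temp derefs rooted at the RT argument scratch offset. */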
127 static bool
lower_rt_derefs(nir_shader *shader)
129 {
130 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
131
132 bool progress = false;
133
134 nir_builder b = nir_builder_at(nir_before_impl(impl));
135
136 nir_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
137
138 nir_foreach_block (block, impl) {
139 nir_foreach_instr_safe (instr, block) {
140 if (instr->type != nir_instr_type_deref)
141 continue;
142
143 nir_deref_instr *deref = nir_instr_as_deref(instr);
144 if (!nir_deref_mode_is(deref, nir_var_shader_call_data))
145 continue;
146
147 deref->modes = nir_var_function_temp;
148 progress = true;
149
150 if (deref->deref_type == nir_deref_type_var) {
151 b.cursor = nir_before_instr(&deref->instr);
152 nir_deref_instr *replacement =
153 nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
154 nir_def_replace(&deref->def, &replacement->def);
155 }
156 }
157 }
158
159 if (progress)
160 nir_metadata_preserve(impl, nir_metadata_control_flow);
161 else
162 nir_metadata_preserve(impl, nir_metadata_all);
163
164 return progress;
165 }
166
167 /*
168 * Global variables for an RT pipeline
169 */
170 struct rt_variables {
171 struct radv_device *device;
172 const VkPipelineCreateFlags2 flags;
173 bool monolithic;
174
175 /* idx of the next shader to run in the next iteration of the main loop.
176 * During traversal, idx is used to store the SBT index and will contain
177 * the correct resume index upon returning.
178 */
179 nir_variable *idx;
180 nir_variable *shader_addr;
181 nir_variable *traversal_addr;
182
183 /* scratch offset of the argument area relative to stack_ptr */
184 nir_variable *arg;
185 uint32_t payload_offset;
186
187 nir_variable *stack_ptr;
188
189 nir_variable *ahit_isec_count;
190
191 nir_variable *launch_sizes[3];
192 nir_variable *launch_ids[3];
193
194 /* global address of the SBT entry used for the shader */
195 nir_variable *shader_record_ptr;
196
197 /* trace_ray arguments */
198 nir_variable *accel_struct;
199 nir_variable *cull_mask_and_flags;
200 nir_variable *sbt_offset;
201 nir_variable *sbt_stride;
202 nir_variable *miss_index;
203 nir_variable *origin;
204 nir_variable *tmin;
205 nir_variable *direction;
206 nir_variable *tmax;
207
208 /* Properties of the primitive currently being visited. */
209 nir_variable *primitive_id;
210 nir_variable *geometry_id_and_flags;
211 nir_variable *instance_addr;
212 nir_variable *hit_kind;
213 nir_variable *opaque;
214
215 /* Output variables for intersection & anyhit shaders. */
216 nir_variable *ahit_accept;
217 nir_variable *ahit_terminate;
218 nir_variable *terminated;
219
220 unsigned stack_size;
221 };
222
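/* Create the shader_temp variables backing the RT dispatch and traversal state of this shader. */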
223 static struct rt_variables
create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2 flags, bool monolithic)
225 {
226 struct rt_variables vars = {
227 .device = device,
228 .flags = flags,
229 .monolithic = monolithic,
230 };
231 vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
232 vars.shader_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_addr");
233 vars.traversal_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_addr");
234 vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
235 vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
236 vars.shader_record_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
237
238 vars.launch_sizes[0] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_size_x");
239 vars.launch_sizes[1] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_size_y");
240 vars.launch_sizes[2] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_size_z");
241
242 vars.launch_ids[0] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_id_x");
243 vars.launch_ids[1] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_id_y");
244 vars.launch_ids[2] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_id_z");
245
246 if (device->rra_trace.ray_history_addr)
247 vars.ahit_isec_count = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ahit_isec_count");
248
249 const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
250 vars.accel_struct = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
251 vars.cull_mask_and_flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask_and_flags");
252 vars.sbt_offset = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
253 vars.sbt_stride = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
254 vars.miss_index = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index");
255 vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin");
256 vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin");
257 vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
258 vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");
259
260 vars.primitive_id = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
261 vars.geometry_id_and_flags =
262 nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
263 vars.instance_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
264 vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
265 vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque");
266
267 vars.ahit_accept = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_accept");
268 vars.ahit_terminate = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_terminate");
269 vars.terminated = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "terminated");
270
271 return vars;
272 }
273
274 /*
275 * Remap all the variables between the two rt_variables struct for inlining.
276 */
277 static void
map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, const struct rt_variables *dst)
279 {
280 _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
281 _mesa_hash_table_insert(var_remap, src->shader_addr, dst->shader_addr);
282 _mesa_hash_table_insert(var_remap, src->traversal_addr, dst->traversal_addr);
283 _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
284 _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
285 _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
286
287 for (uint32_t i = 0; i < ARRAY_SIZE(src->launch_sizes); i++)
288 _mesa_hash_table_insert(var_remap, src->launch_sizes[i], dst->launch_sizes[i]);
289
290 for (uint32_t i = 0; i < ARRAY_SIZE(src->launch_ids); i++)
291 _mesa_hash_table_insert(var_remap, src->launch_ids[i], dst->launch_ids[i]);
292
293 if (dst->ahit_isec_count)
294 _mesa_hash_table_insert(var_remap, src->ahit_isec_count, dst->ahit_isec_count);
295
296 _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
297 _mesa_hash_table_insert(var_remap, src->cull_mask_and_flags, dst->cull_mask_and_flags);
298 _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset);
299 _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride);
300 _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
301 _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
302 _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin);
303 _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
304 _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
305
306 _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
307 _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
308 _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
309 _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind);
310 _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque);
311 _mesa_hash_table_insert(var_remap, src->ahit_accept, dst->ahit_accept);
312 _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate);
313 _mesa_hash_table_insert(var_remap, src->terminated, dst->terminated);
314 }
315
/*
 * Create a copy of the global rt variables where the primitive/instance related variables are
 * independent. This is needed because we have to keep the old values of the global variables
 * around in case e.g. an any-hit shader rejects the intersection. So there are inner variables
 * that get copied to the outer variables once we commit to a closer hit.
 */
322 static struct rt_variables
create_inner_vars(nir_builder *b, const struct rt_variables *vars)
324 {
325 struct rt_variables inner_vars = *vars;
326 inner_vars.idx = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
327 inner_vars.shader_record_ptr =
328 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
329 inner_vars.primitive_id =
330 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
331 inner_vars.geometry_id_and_flags =
332 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags");
333 inner_vars.tmax = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax");
334 inner_vars.instance_addr =
335 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_instance_addr");
336 inner_vars.hit_kind = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind");
337
338 return inner_vars;
339 }
340
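/* Pop the 16-byte return slot off the stack and load the resume shader address from it. */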
341 static void
insert_rt_return(nir_builder *b, const struct rt_variables *vars)
343 {
344 nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -16), 1);
345 nir_store_var(b, vars->shader_addr, nir_load_scratch(b, 1, 64, nir_load_var(b, vars->stack_ptr), .align_mul = 16),
346 1);
347 }
348
349 enum sbt_type {
350 SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
351 SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
352 SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
353 SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
354 };
355
356 enum sbt_entry {
357 SBT_RECURSIVE_PTR = offsetof(struct radv_pipeline_group_handle, recursive_shader_ptr),
358 SBT_GENERAL_IDX = offsetof(struct radv_pipeline_group_handle, general_index),
359 SBT_CLOSEST_HIT_IDX = offsetof(struct radv_pipeline_group_handle, closest_hit_index),
360 SBT_INTERSECTION_IDX = offsetof(struct radv_pipeline_group_handle, intersection_index),
361 SBT_ANY_HIT_IDX = offsetof(struct radv_pipeline_group_handle, any_hit_index),
362 };
363
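/* Load an SBT entry: read the table base and stride from the indirect dispatch descriptor, compute the
 * record address for idx, then store either the recursive shader pointer or the requested index together
 * with the shader record pointer. */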
364 static void
load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_def *idx, enum sbt_type binding,
               enum sbt_entry offset)
367 {
368 nir_def *desc_base_addr = nir_load_sbt_base_amd(b);
369
370 nir_def *desc = nir_pack_64_2x32(b, nir_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding)));
371
372 nir_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
373 nir_def *stride = nir_load_smem_amd(b, 1, desc_base_addr, stride_offset);
374
375 nir_def *addr = nir_iadd(b, desc, nir_u2u64(b, nir_iadd_imm(b, nir_imul(b, idx, stride), offset)));
376
377 if (offset == SBT_RECURSIVE_PTR) {
378 nir_store_var(b, vars->shader_addr, nir_build_load_global(b, 1, 64, addr), 1);
379 } else {
380 nir_store_var(b, vars->idx, nir_build_load_global(b, 1, 32, addr), 1);
381 }
382
383 nir_def *record_addr = nir_iadd_imm(b, addr, RADV_RT_HANDLE_SIZE - offset);
384 nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
385 }
386
387 struct radv_rt_shader_info {
388 bool uses_launch_id;
389 bool uses_launch_size;
390 };
391
392 struct radv_lower_rt_instruction_data {
393 struct rt_variables *vars;
394 bool late_lowering;
395
396 struct radv_rt_shader_info *out_info;
397 };
398
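/* Lower a single RT intrinsic (or halt jump) into loads/stores of the rt_variables and stack/scratch
 * operations. Returns true if the instruction was rewritten. */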
399 static bool
radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
401 {
402 if (instr->type == nir_instr_type_jump) {
403 nir_jump_instr *jump = nir_instr_as_jump(instr);
404 if (jump->type == nir_jump_halt) {
405 jump->type = nir_jump_return;
406 return true;
407 }
408 return false;
409 } else if (instr->type != nir_instr_type_intrinsic) {
410 return false;
411 }
412
413 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
414
415 struct radv_lower_rt_instruction_data *data = _data;
416 struct rt_variables *vars = data->vars;
417
418 b->cursor = nir_before_instr(&intr->instr);
419
420 nir_def *ret = NULL;
421 switch (intr->intrinsic) {
422 case nir_intrinsic_rt_execute_callable: {
423 uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
424 nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
425 ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
426
427 nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
428 nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
429
430 nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
431 load_sbt_entry(b, vars, intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR);
432
433 nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[1].ssa, -size - 16), 1);
434
435 vars->stack_size = MAX2(vars->stack_size, size + 16);
436 break;
437 }
438 case nir_intrinsic_rt_trace_ray: {
439 uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
440 nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
441 ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
442
443 nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
444 nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
445
446 nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
447
448 nir_store_var(b, vars->shader_addr, nir_load_var(b, vars->traversal_addr), 1);
449 nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[10].ssa, -size - 16), 1);
450
451 vars->stack_size = MAX2(vars->stack_size, size + 16);
452
453 /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
454 nir_store_var(b, vars->accel_struct, intr->src[0].ssa, 0x1);
455 nir_store_var(b, vars->cull_mask_and_flags, nir_ior(b, nir_ishl_imm(b, intr->src[2].ssa, 24), intr->src[1].ssa),
456 0x1);
457 nir_store_var(b, vars->sbt_offset, nir_iand_imm(b, intr->src[3].ssa, 0xf), 0x1);
458 nir_store_var(b, vars->sbt_stride, nir_iand_imm(b, intr->src[4].ssa, 0xf), 0x1);
459 nir_store_var(b, vars->miss_index, nir_iand_imm(b, intr->src[5].ssa, 0xffff), 0x1);
460 nir_store_var(b, vars->origin, intr->src[6].ssa, 0x7);
461 nir_store_var(b, vars->tmin, intr->src[7].ssa, 0x1);
462 nir_store_var(b, vars->direction, intr->src[8].ssa, 0x7);
463 nir_store_var(b, vars->tmax, intr->src[9].ssa, 0x1);
464 break;
465 }
466 case nir_intrinsic_rt_resume: {
467 uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
468
469 nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -size), 1);
470 break;
471 }
472 case nir_intrinsic_rt_return_amd: {
473 if (b->shader->info.stage == MESA_SHADER_RAYGEN) {
474 nir_terminate(b);
475 break;
476 }
477 insert_rt_return(b, vars);
478 break;
479 }
480 case nir_intrinsic_load_scratch: {
481 if (data->late_lowering)
482 nir_src_rewrite(&intr->src[0], nir_iadd_nuw(b, nir_load_var(b, vars->stack_ptr), intr->src[0].ssa));
483 return true;
484 }
485 case nir_intrinsic_store_scratch: {
486 if (data->late_lowering)
487 nir_src_rewrite(&intr->src[1], nir_iadd_nuw(b, nir_load_var(b, vars->stack_ptr), intr->src[1].ssa));
488 return true;
489 }
490 case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
491 ret = nir_load_var(b, vars->arg);
492 break;
493 }
494 case nir_intrinsic_load_shader_record_ptr: {
495 ret = nir_load_var(b, vars->shader_record_ptr);
496 break;
497 }
498 case nir_intrinsic_load_ray_launch_size: {
499 if (data->out_info)
500 data->out_info->uses_launch_size = true;
501
502 if (!data->late_lowering)
503 return false;
504
505 ret = nir_vec3(b, nir_load_var(b, vars->launch_sizes[0]), nir_load_var(b, vars->launch_sizes[1]),
506 nir_load_var(b, vars->launch_sizes[2]));
507 break;
508 };
509 case nir_intrinsic_load_ray_launch_id: {
510 if (data->out_info)
511 data->out_info->uses_launch_id = true;
512
513 if (!data->late_lowering)
514 return false;
515
516 ret = nir_vec3(b, nir_load_var(b, vars->launch_ids[0]), nir_load_var(b, vars->launch_ids[1]),
517 nir_load_var(b, vars->launch_ids[2]));
518 break;
519 }
520 case nir_intrinsic_load_ray_t_min: {
521 ret = nir_load_var(b, vars->tmin);
522 break;
523 }
524 case nir_intrinsic_load_ray_t_max: {
525 ret = nir_load_var(b, vars->tmax);
526 break;
527 }
528 case nir_intrinsic_load_ray_world_origin: {
529 ret = nir_load_var(b, vars->origin);
530 break;
531 }
532 case nir_intrinsic_load_ray_world_direction: {
533 ret = nir_load_var(b, vars->direction);
534 break;
535 }
536 case nir_intrinsic_load_ray_instance_custom_index: {
537 nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
538 nir_def *custom_instance_and_mask = nir_build_load_global(
539 b, 1, 32,
540 nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, custom_instance_and_mask)));
541 ret = nir_iand_imm(b, custom_instance_and_mask, 0xFFFFFF);
542 break;
543 }
544 case nir_intrinsic_load_primitive_id: {
545 ret = nir_load_var(b, vars->primitive_id);
546 break;
547 }
548 case nir_intrinsic_load_ray_geometry_index: {
549 ret = nir_load_var(b, vars->geometry_id_and_flags);
550 ret = nir_iand_imm(b, ret, 0xFFFFFFF);
551 break;
552 }
553 case nir_intrinsic_load_instance_id: {
554 nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
555 ret = nir_build_load_global(
556 b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id)));
557 break;
558 }
559 case nir_intrinsic_load_ray_flags: {
560 ret = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFFFFFF);
561 break;
562 }
563 case nir_intrinsic_load_ray_hit_kind: {
564 ret = nir_load_var(b, vars->hit_kind);
565 break;
566 }
567 case nir_intrinsic_load_ray_world_to_object: {
568 unsigned c = nir_intrinsic_column(intr);
569 nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
570 nir_def *wto_matrix[3];
571 nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
572
573 nir_def *vals[3];
574 for (unsigned i = 0; i < 3; ++i)
575 vals[i] = nir_channel(b, wto_matrix[i], c);
576
577 ret = nir_vec(b, vals, 3);
578 break;
579 }
580 case nir_intrinsic_load_ray_object_to_world: {
581 unsigned c = nir_intrinsic_column(intr);
582 nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
583 nir_def *rows[3];
584 for (unsigned r = 0; r < 3; ++r)
585 rows[r] = nir_build_load_global(
586 b, 4, 32,
587 nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
588 ret = nir_vec3(b, nir_channel(b, rows[0], c), nir_channel(b, rows[1], c), nir_channel(b, rows[2], c));
589 break;
590 }
591 case nir_intrinsic_load_ray_object_origin: {
592 nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
593 nir_def *wto_matrix[3];
594 nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
595 ret = nir_build_vec3_mat_mult(b, nir_load_var(b, vars->origin), wto_matrix, true);
596 break;
597 }
598 case nir_intrinsic_load_ray_object_direction: {
599 nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
600 nir_def *wto_matrix[3];
601 nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
602 ret = nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false);
603 break;
604 }
605 case nir_intrinsic_load_intersection_opaque_amd: {
606 ret = nir_load_var(b, vars->opaque);
607 break;
608 }
609 case nir_intrinsic_load_cull_mask: {
610 ret = nir_ushr_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 24);
611 break;
612 }
613 case nir_intrinsic_ignore_ray_intersection: {
614 nir_store_var(b, vars->ahit_accept, nir_imm_false(b), 0x1);
615
616 /* The if is a workaround to avoid having to fix up control flow manually */
617 nir_push_if(b, nir_imm_true(b));
618 nir_jump(b, nir_jump_return);
619 nir_pop_if(b, NULL);
620 break;
621 }
622 case nir_intrinsic_terminate_ray: {
623 nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
624 nir_store_var(b, vars->ahit_terminate, nir_imm_true(b), 0x1);
625
626 /* The if is a workaround to avoid having to fix up control flow manually */
627 nir_push_if(b, nir_imm_true(b));
628 nir_jump(b, nir_jump_return);
629 nir_pop_if(b, NULL);
630 break;
631 }
632 case nir_intrinsic_report_ray_intersection: {
633 nir_def *in_range = nir_iand(b, nir_fge(b, nir_load_var(b, vars->tmax), intr->src[0].ssa),
634 nir_fge(b, intr->src[0].ssa, nir_load_var(b, vars->tmin)));
635 nir_def *terminated = nir_load_var(b, vars->terminated);
636 nir_push_if(b, nir_iand(b, in_range, nir_inot(b, terminated)));
637 {
638 nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
639 nir_store_var(b, vars->tmax, intr->src[0].ssa, 1);
640 nir_store_var(b, vars->hit_kind, intr->src[1].ssa, 1);
641 nir_def *terminate_on_first_hit =
642 nir_test_mask(b, nir_load_var(b, vars->cull_mask_and_flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
643 nir_store_var(b, vars->terminated, nir_ior(b, terminate_on_first_hit, nir_load_var(b, vars->ahit_terminate)),
644 1);
645 }
646 nir_pop_if(b, NULL);
647 break;
648 }
649 case nir_intrinsic_load_sbt_offset_amd: {
650 ret = nir_load_var(b, vars->sbt_offset);
651 break;
652 }
653 case nir_intrinsic_load_sbt_stride_amd: {
654 ret = nir_load_var(b, vars->sbt_stride);
655 break;
656 }
657 case nir_intrinsic_load_accel_struct_amd: {
658 ret = nir_load_var(b, vars->accel_struct);
659 break;
660 }
661 case nir_intrinsic_load_cull_mask_and_flags_amd: {
662 ret = nir_load_var(b, vars->cull_mask_and_flags);
663 break;
664 }
665 case nir_intrinsic_execute_closest_hit_amd: {
666 nir_store_var(b, vars->tmax, intr->src[1].ssa, 0x1);
667 nir_store_var(b, vars->primitive_id, intr->src[2].ssa, 0x1);
668 nir_store_var(b, vars->instance_addr, intr->src[3].ssa, 0x1);
669 nir_store_var(b, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
670 nir_store_var(b, vars->hit_kind, intr->src[5].ssa, 0x1);
671 load_sbt_entry(b, vars, intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR);
672
673 nir_def *should_return =
674 nir_test_mask(b, nir_load_var(b, vars->cull_mask_and_flags), SpvRayFlagsSkipClosestHitShaderKHRMask);
675
676 if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) {
677 should_return = nir_ior(b, should_return, nir_ieq_imm(b, nir_load_var(b, vars->shader_addr), 0));
678 }
679
680 /* should_return is set if we had a hit but we won't be calling the closest hit
681 * shader and hence need to return immediately to the calling shader. */
682 nir_push_if(b, should_return);
683 insert_rt_return(b, vars);
684 nir_pop_if(b, NULL);
685 break;
686 }
687 case nir_intrinsic_execute_miss_amd: {
688 nir_store_var(b, vars->tmax, intr->src[0].ssa, 0x1);
689 nir_def *undef = nir_undef(b, 1, 32);
690 nir_store_var(b, vars->primitive_id, undef, 0x1);
691 nir_store_var(b, vars->instance_addr, nir_undef(b, 1, 64), 0x1);
692 nir_store_var(b, vars->geometry_id_and_flags, undef, 0x1);
693 nir_store_var(b, vars->hit_kind, undef, 0x1);
694 nir_def *miss_index = nir_load_var(b, vars->miss_index);
695 load_sbt_entry(b, vars, miss_index, SBT_MISS, SBT_RECURSIVE_PTR);
696
697 if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) {
698 /* In case of a NULL miss shader, do nothing and just return. */
699 nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->shader_addr), 0));
700 insert_rt_return(b, vars);
701 nir_pop_if(b, NULL);
702 }
703
704 break;
705 }
706 case nir_intrinsic_load_ray_triangle_vertex_positions: {
707 nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
708 nir_def *primitive_id = nir_load_var(b, vars->primitive_id);
709 ret = radv_load_vertex_position(vars->device, b, instance_node_addr, primitive_id, nir_intrinsic_column(intr));
710 break;
711 }
712 default:
713 return false;
714 }
715
716 if (ret)
717 nir_def_rewrite_uses(&intr->def, ret);
718 nir_instr_remove(&intr->instr);
719
720 return true;
721 }
722
723 /* This lowers all the RT instructions that we do not want to pass on to the combined shader and
724 * that we can implement using the variables from the shader we are going to inline into. */
725 static void
lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool late_lowering,
                      struct radv_rt_shader_info *out_info)
728 {
729 struct radv_lower_rt_instruction_data data = {
730 .vars = vars,
731 .late_lowering = late_lowering,
732 .out_info = out_info,
733 };
734 nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data);
735 }
736
737 /* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are
738 * lowered to shared memory. */
739 static void
lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workgroup_size)
741 {
742 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
743
744 nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib)
745 attrib->data.mode = nir_var_shader_temp;
746
747 nir_builder b = nir_builder_create(impl);
748
749 nir_foreach_block (block, impl) {
750 nir_foreach_instr_safe (instr, block) {
751 if (instr->type != nir_instr_type_intrinsic)
752 continue;
753
754 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
755 if (intrin->intrinsic != nir_intrinsic_load_hit_attrib_amd &&
756 intrin->intrinsic != nir_intrinsic_store_hit_attrib_amd)
757 continue;
758
759 b.cursor = nir_after_instr(instr);
760
761 nir_def *offset;
762 if (!hit_attribs)
763 offset = nir_imul_imm(
764 &b, nir_iadd_imm(&b, nir_load_local_invocation_index(&b), nir_intrinsic_base(intrin) * workgroup_size),
765 sizeof(uint32_t));
766
767 if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) {
768 nir_def *ret;
769 if (hit_attribs)
770 ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]);
771 else
772 ret = nir_load_shared(&b, 1, 32, offset, .base = 0, .align_mul = 4);
773 nir_def_rewrite_uses(nir_instr_def(instr), ret);
774 } else {
775 if (hit_attribs)
776 nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1);
777 else
778 nir_store_shared(&b, intrin->src->ssa, offset, .base = 0, .align_mul = 4);
779 }
780 nir_instr_remove(instr);
781 }
782 }
783
784 if (!hit_attribs)
785 shader->info.shared_size = MAX2(shader->info.shared_size, workgroup_size * RADV_MAX_HIT_ATTRIB_SIZE);
786 }
787
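/* Append src's constant data to dst and rebase src's load_constant intrinsics accordingly. */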
788 static void
inline_constants(nir_shader *dst, nir_shader *src)
790 {
791 if (!src->constant_data_size)
792 return;
793
794 uint32_t old_constant_data_size = dst->constant_data_size;
795 uint32_t base_offset = align(dst->constant_data_size, 64);
796 dst->constant_data_size = base_offset + src->constant_data_size;
797 dst->constant_data = rerzalloc_size(dst, dst->constant_data, old_constant_data_size, dst->constant_data_size);
798 memcpy((char *)dst->constant_data + base_offset, src->constant_data, src->constant_data_size);
799
800 if (!base_offset)
801 return;
802
803 uint32_t base_align_mul = base_offset ? 1 << (ffs(base_offset) - 1) : NIR_ALIGN_MUL_MAX;
804 nir_foreach_block (block, nir_shader_get_entrypoint(src)) {
805 nir_foreach_instr (instr, block) {
806 if (instr->type != nir_instr_type_intrinsic)
807 continue;
808
809 nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
810 if (intrinsic->intrinsic == nir_intrinsic_load_constant) {
811 nir_intrinsic_set_base(intrinsic, base_offset + nir_intrinsic_base(intrinsic));
812
813 uint32_t align_mul = nir_intrinsic_align_mul(intrinsic);
814 uint32_t align_offset = nir_intrinsic_align_offset(intrinsic);
815 align_mul = MIN2(align_mul, base_align_mul);
816 nir_intrinsic_set_align(intrinsic, align_mul, align_offset % align_mul);
817 }
818 }
819 }
820 }
821
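/* Inline 'shader' into the builder under an 'idx == call_idx' guard, remapping its RT variables onto 'vars'. */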
822 static void
insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_def *idx, uint32_t call_idx)
824 {
825 struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
826
827 nir_opt_dead_cf(shader);
828
829 struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags, vars->monolithic);
830 map_rt_variables(var_remap, &src_vars, vars);
831
832 NIR_PASS_V(shader, lower_rt_instructions, &src_vars, false, NULL);
833
834 NIR_PASS(_, shader, nir_lower_returns);
835 NIR_PASS(_, shader, nir_opt_dce);
836
837 inline_constants(b->shader, shader);
838
839 nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
840 nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
841 nir_pop_if(b, NULL);
842
843 ralloc_free(var_remap);
844 }
845
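/* Lower RT I/O: for separately compiled shaders, lower shader call data to explicit scratch offsets;
 * for monolithic shaders, lower the ray payload derefs directly. */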
846 void
radv_nir_lower_rt_io(nir_shader *nir, bool monolithic, uint32_t payload_offset)
848 {
849 if (!monolithic) {
850 NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
851 glsl_get_natural_size_align_bytes);
852
853 NIR_PASS(_, nir, lower_rt_derefs);
854
855 NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
856 } else {
857 NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, payload_offset);
858 }
859 }
860
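/* Begin a ray history token: atomically reserve token_size bytes in the ray history buffer and write the
 * packed launch-index/hit/token-type dword. Returns the address right after that dword. Must be paired
 * with radv_build_token_end. */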
861 static nir_def *
radv_build_token_begin(nir_builder *b, struct rt_variables *vars, nir_def *hit, enum radv_packed_token_type token_type,
                       nir_def *token_size, uint32_t max_token_size)
864 {
865 struct radv_rra_trace_data *rra_trace = &vars->device->rra_trace;
866 assert(rra_trace->ray_history_addr);
867 assert(rra_trace->ray_history_buffer_size >= max_token_size);
868
869 nir_def *ray_history_addr = nir_imm_int64(b, rra_trace->ray_history_addr);
870
871 nir_def *launch_id = nir_load_ray_launch_id(b);
872
873 nir_def *trace = nir_imm_true(b);
874 for (uint32_t i = 0; i < 3; i++) {
875 nir_def *remainder = nir_umod_imm(b, nir_channel(b, launch_id, i), rra_trace->ray_history_resolution_scale);
876 trace = nir_iand(b, trace, nir_ieq_imm(b, remainder, 0));
877 }
878 nir_push_if(b, trace);
879
880 static_assert(offsetof(struct radv_ray_history_header, offset) == 0, "Unexpected offset");
881 nir_def *base_offset = nir_global_atomic(b, 32, ray_history_addr, token_size, .atomic_op = nir_atomic_op_iadd);
882
883 /* Abuse the dword alignment of token_size to add an invalid bit to offset. */
884 trace = nir_ieq_imm(b, nir_iand_imm(b, base_offset, 1), 0);
885
886 nir_def *in_bounds = nir_ule_imm(b, base_offset, rra_trace->ray_history_buffer_size - max_token_size);
887 /* Make sure we don't overwrite the header in case of an overflow. */
888 in_bounds = nir_iand(b, in_bounds, nir_uge_imm(b, base_offset, sizeof(struct radv_ray_history_header)));
889
890 nir_push_if(b, nir_iand(b, trace, in_bounds));
891
892 nir_def *dst_addr = nir_iadd(b, ray_history_addr, nir_u2u64(b, base_offset));
893
894 nir_def *launch_size = nir_load_ray_launch_size(b);
895
896 nir_def *launch_id_comps[3];
897 nir_def *launch_size_comps[3];
898 for (uint32_t i = 0; i < 3; i++) {
899 launch_id_comps[i] = nir_udiv_imm(b, nir_channel(b, launch_id, i), rra_trace->ray_history_resolution_scale);
900 launch_size_comps[i] = nir_udiv_imm(b, nir_channel(b, launch_size, i), rra_trace->ray_history_resolution_scale);
901 }
902
903 nir_def *global_index =
904 nir_iadd(b, launch_id_comps[0],
905 nir_iadd(b, nir_imul(b, launch_id_comps[1], launch_size_comps[0]),
906 nir_imul(b, launch_id_comps[2], nir_imul(b, launch_size_comps[0], launch_size_comps[1]))));
907 nir_def *launch_index_and_hit = nir_bcsel(b, hit, nir_ior_imm(b, global_index, 1u << 29u), global_index);
908 nir_build_store_global(b, nir_ior_imm(b, launch_index_and_hit, token_type << 30), dst_addr, .align_mul = 4);
909
910 return nir_iadd_imm(b, dst_addr, 4);
911 }
912
913 static void
radv_build_token_end(nir_builder *b)
915 {
916 nir_pop_if(b, NULL);
917 nir_pop_if(b, NULL);
918 }
919
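/* Emit an end-of-trace ray history token with the ray parameters and, if there was a hit, the hit data. */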
920 static void
radv_build_end_trace_token(nir_builder *b, struct rt_variables *vars, nir_def *tmax, nir_def *hit,
                           nir_def *iteration_instance_count)
923 {
924 nir_def *token_size = nir_bcsel(b, hit, nir_imm_int(b, sizeof(struct radv_packed_end_trace_token)),
925 nir_imm_int(b, offsetof(struct radv_packed_end_trace_token, primitive_id)));
926
927 nir_def *dst_addr = radv_build_token_begin(b, vars, hit, radv_packed_token_end_trace, token_size,
928 sizeof(struct radv_packed_end_trace_token));
929 {
930 nir_build_store_global(b, nir_load_var(b, vars->accel_struct), dst_addr, .align_mul = 4);
931 dst_addr = nir_iadd_imm(b, dst_addr, 8);
932
933 nir_def *dispatch_indices =
934 nir_load_smem_amd(b, 2, nir_imm_int64(b, vars->device->rra_trace.ray_history_addr),
935 nir_imm_int(b, offsetof(struct radv_ray_history_header, dispatch_index)), .align_mul = 4);
936 nir_def *dispatch_index = nir_iadd(b, nir_channel(b, dispatch_indices, 0), nir_channel(b, dispatch_indices, 1));
937 nir_def *dispatch_and_flags = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFFFF);
938 dispatch_and_flags = nir_ior(b, dispatch_and_flags, dispatch_index);
939 nir_build_store_global(b, dispatch_and_flags, dst_addr, .align_mul = 4);
940 dst_addr = nir_iadd_imm(b, dst_addr, 4);
941
942 nir_def *shifted_cull_mask = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFF000000);
943
944 nir_def *packed_args = nir_load_var(b, vars->sbt_offset);
945 packed_args = nir_ior(b, packed_args, nir_ishl_imm(b, nir_load_var(b, vars->sbt_stride), 4));
946 packed_args = nir_ior(b, packed_args, nir_ishl_imm(b, nir_load_var(b, vars->miss_index), 8));
947 packed_args = nir_ior(b, packed_args, shifted_cull_mask);
948 nir_build_store_global(b, packed_args, dst_addr, .align_mul = 4);
949 dst_addr = nir_iadd_imm(b, dst_addr, 4);
950
951 nir_build_store_global(b, nir_load_var(b, vars->origin), dst_addr, .align_mul = 4);
952 dst_addr = nir_iadd_imm(b, dst_addr, 12);
953
954 nir_build_store_global(b, nir_load_var(b, vars->tmin), dst_addr, .align_mul = 4);
955 dst_addr = nir_iadd_imm(b, dst_addr, 4);
956
957 nir_build_store_global(b, nir_load_var(b, vars->direction), dst_addr, .align_mul = 4);
958 dst_addr = nir_iadd_imm(b, dst_addr, 12);
959
960 nir_build_store_global(b, tmax, dst_addr, .align_mul = 4);
961 dst_addr = nir_iadd_imm(b, dst_addr, 4);
962
963 nir_build_store_global(b, iteration_instance_count, dst_addr, .align_mul = 4);
964 dst_addr = nir_iadd_imm(b, dst_addr, 4);
965
966 nir_build_store_global(b, nir_load_var(b, vars->ahit_isec_count), dst_addr, .align_mul = 4);
967 dst_addr = nir_iadd_imm(b, dst_addr, 4);
968
969 nir_push_if(b, hit);
970 {
971 nir_build_store_global(b, nir_load_var(b, vars->primitive_id), dst_addr, .align_mul = 4);
972 dst_addr = nir_iadd_imm(b, dst_addr, 4);
973
974 nir_def *geometry_id = nir_iand_imm(b, nir_load_var(b, vars->geometry_id_and_flags), 0xFFFFFFF);
975 nir_build_store_global(b, geometry_id, dst_addr, .align_mul = 4);
976 dst_addr = nir_iadd_imm(b, dst_addr, 4);
977
978 nir_def *instance_id_and_hit_kind =
979 nir_build_load_global(b, 1, 32,
980 nir_iadd_imm(b, nir_load_var(b, vars->instance_addr),
981 offsetof(struct radv_bvh_instance_node, instance_id)));
982 instance_id_and_hit_kind =
983 nir_ior(b, instance_id_and_hit_kind, nir_ishl_imm(b, nir_load_var(b, vars->hit_kind), 24));
984 nir_build_store_global(b, instance_id_and_hit_kind, dst_addr, .align_mul = 4);
985 dst_addr = nir_iadd_imm(b, dst_addr, 4);
986
987 nir_build_store_global(b, nir_load_var(b, vars->tmax), dst_addr, .align_mul = 4);
988 dst_addr = nir_iadd_imm(b, dst_addr, 4);
989 }
990 nir_pop_if(b, NULL);
991 }
992 radv_build_token_end(b);
993 }
994
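/* Rewrite an any-hit shader so it can be inlined into an intersection shader: replace the shader-call
 * interface with explicit parameters (commit pointer, hit T, hit kind, scratch offset) and turn halts into
 * returns. */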
995 static nir_function_impl *
lower_any_hit_for_intersection(nir_shader *any_hit)
997 {
998 nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);
999
/* Any-hit shaders need four parameters */
1001 assert(impl->function->num_params == 0);
1002 nir_parameter params[] = {
1003 {
1004 /* A pointer to a boolean value for whether or not the hit was
1005 * accepted.
1006 */
1007 .num_components = 1,
1008 .bit_size = 32,
1009 },
1010 {
1011 /* The hit T value */
1012 .num_components = 1,
1013 .bit_size = 32,
1014 },
1015 {
1016 /* The hit kind */
1017 .num_components = 1,
1018 .bit_size = 32,
1019 },
1020 {
1021 /* Scratch offset */
1022 .num_components = 1,
1023 .bit_size = 32,
1024 },
1025 };
1026 impl->function->num_params = ARRAY_SIZE(params);
1027 impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
1028 memcpy(impl->function->params, params, sizeof(params));
1029
1030 nir_builder build = nir_builder_at(nir_before_impl(impl));
1031 nir_builder *b = &build;
1032
1033 nir_def *commit_ptr = nir_load_param(b, 0);
1034 nir_def *hit_t = nir_load_param(b, 1);
1035 nir_def *hit_kind = nir_load_param(b, 2);
1036 nir_def *scratch_offset = nir_load_param(b, 3);
1037
1038 nir_deref_instr *commit = nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);
1039
1040 nir_foreach_block_safe (block, impl) {
1041 nir_foreach_instr_safe (instr, block) {
1042 switch (instr->type) {
1043 case nir_instr_type_intrinsic: {
1044 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1045 switch (intrin->intrinsic) {
1046 case nir_intrinsic_ignore_ray_intersection:
1047 b->cursor = nir_instr_remove(&intrin->instr);
1048 /* We put the newly emitted code inside a dummy if because it's
1049 * going to contain a jump instruction and we don't want to
1050 * deal with that mess here. It'll get dealt with by our
1051 * control-flow optimization passes.
1052 */
1053 nir_store_deref(b, commit, nir_imm_false(b), 0x1);
1054 nir_push_if(b, nir_imm_true(b));
1055 nir_jump(b, nir_jump_return);
1056 nir_pop_if(b, NULL);
1057 break;
1058
1059 case nir_intrinsic_terminate_ray:
1060 /* The "normal" handling of terminateRay works fine in
1061 * intersection shaders.
1062 */
1063 break;
1064
1065 case nir_intrinsic_load_ray_t_max:
1066 nir_def_replace(&intrin->def, hit_t);
1067 break;
1068
1069 case nir_intrinsic_load_ray_hit_kind:
1070 nir_def_replace(&intrin->def, hit_kind);
1071 break;
1072
1073 /* We place all any_hit scratch variables after intersection scratch variables.
1074 * For that reason, we increment the scratch offset by the intersection scratch
1075 * size. For call_data, we have to subtract the offset again.
1076 *
1077 * Note that we don't increase the scratch size as it is already reflected via
1078 * the any_hit stack_size.
1079 */
1080 case nir_intrinsic_load_scratch:
1081 b->cursor = nir_before_instr(instr);
1082 nir_src_rewrite(&intrin->src[0], nir_iadd_nuw(b, scratch_offset, intrin->src[0].ssa));
1083 break;
1084 case nir_intrinsic_store_scratch:
1085 b->cursor = nir_before_instr(instr);
1086 nir_src_rewrite(&intrin->src[1], nir_iadd_nuw(b, scratch_offset, intrin->src[1].ssa));
1087 break;
1088 case nir_intrinsic_load_rt_arg_scratch_offset_amd:
1089 b->cursor = nir_after_instr(instr);
1090 nir_def *arg_offset = nir_isub(b, &intrin->def, scratch_offset);
1091 nir_def_rewrite_uses_after(&intrin->def, arg_offset, arg_offset->parent_instr);
1092 break;
1093
1094 default:
1095 break;
1096 }
1097 break;
1098 }
1099 case nir_instr_type_jump: {
1100 nir_jump_instr *jump = nir_instr_as_jump(instr);
1101 if (jump->type == nir_jump_halt) {
1102 b->cursor = nir_instr_remove(instr);
1103 nir_jump(b, nir_jump_return);
1104 }
1105 break;
1106 }
1107
1108 default:
1109 break;
1110 }
1111 }
1112 }
1113
1114 nir_validate_shader(any_hit, "after initial any-hit lowering");
1115
1116 nir_lower_returns_impl(impl);
1117
1118 nir_validate_shader(any_hit, "after lowering returns");
1119
1120 return impl;
1121 }
1122
/* Inline the any_hit shader into the intersection shader so we don't have
 * to implement yet another shader call interface here, nor deal with any recursion.
 */
1126 static void
nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
1128 {
1129 void *dead_ctx = ralloc_context(intersection);
1130
1131 nir_function_impl *any_hit_impl = NULL;
1132 struct hash_table *any_hit_var_remap = NULL;
1133 if (any_hit) {
1134 any_hit = nir_shader_clone(dead_ctx, any_hit);
1135 NIR_PASS(_, any_hit, nir_opt_dce);
1136
1137 inline_constants(intersection, any_hit);
1138
1139 any_hit_impl = lower_any_hit_for_intersection(any_hit);
1140 any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
1141 }
1142
1143 nir_function_impl *impl = nir_shader_get_entrypoint(intersection);
1144
1145 nir_builder build = nir_builder_create(impl);
1146 nir_builder *b = &build;
1147
1148 b->cursor = nir_before_impl(impl);
1149
1150 nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
1151 nir_store_var(b, commit, nir_imm_false(b), 0x1);
1152
1153 nir_foreach_block_safe (block, impl) {
1154 nir_foreach_instr_safe (instr, block) {
1155 if (instr->type != nir_instr_type_intrinsic)
1156 continue;
1157
1158 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1159 if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
1160 continue;
1161
1162 b->cursor = nir_instr_remove(&intrin->instr);
1163 nir_def *hit_t = intrin->src[0].ssa;
1164 nir_def *hit_kind = intrin->src[1].ssa;
1165 nir_def *min_t = nir_load_ray_t_min(b);
1166 nir_def *max_t = nir_load_ray_t_max(b);
1167
1168 /* bool commit_tmp = false; */
1169 nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
1170 nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);
1171
1172 nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
1173 {
1174 /* Any-hit defaults to commit */
1175 nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);
1176
1177 if (any_hit_impl != NULL) {
1178 nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
1179 {
1180 nir_def *params[] = {
1181 &nir_build_deref_var(b, commit_tmp)->def,
1182 hit_t,
1183 hit_kind,
1184 nir_imm_int(b, intersection->scratch_size),
1185 };
1186 nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
1187 }
1188 nir_pop_if(b, NULL);
1189 }
1190
1191 nir_push_if(b, nir_load_var(b, commit_tmp));
1192 {
1193 nir_report_ray_intersection(b, 1, hit_t, hit_kind);
1194 }
1195 nir_pop_if(b, NULL);
1196 }
1197 nir_pop_if(b, NULL);
1198
1199 nir_def *accepted = nir_load_var(b, commit_tmp);
1200 nir_def_rewrite_uses(&intrin->def, accepted);
1201 }
1202 }
1203 nir_metadata_preserve(impl, nir_metadata_none);
1204
1205 /* We did some inlining; have to re-index SSA defs */
1206 nir_index_ssa_defs(impl);
1207
1208 /* Eliminate the casts introduced for the commit return of the any-hit shader. */
1209 NIR_PASS(_, intersection, nir_opt_deref);
1210
1211 ralloc_free(dead_ctx);
1212 }
1213
1214 /* Variables only used internally to ray traversal. This is data that describes
1215 * the current state of the traversal vs. what we'd give to a shader. e.g. what
1216 * is the instance we're currently visiting vs. what is the instance of the
1217 * closest hit. */
1218 struct rt_traversal_vars {
1219 nir_variable *origin;
1220 nir_variable *dir;
1221 nir_variable *inv_dir;
1222 nir_variable *sbt_offset_and_flags;
1223 nir_variable *instance_addr;
1224 nir_variable *hit;
1225 nir_variable *bvh_base;
1226 nir_variable *stack;
1227 nir_variable *top_stack;
1228 nir_variable *stack_low_watermark;
1229 nir_variable *current_node;
1230 nir_variable *previous_node;
1231 nir_variable *instance_top_node;
1232 nir_variable *instance_bottom_node;
1233 };
1234
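/* Create the shader_temp variables holding the internal BVH traversal state. */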
1235 static struct rt_traversal_vars
init_traversal_vars(nir_builder *b)
1237 {
1238 const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1239 struct rt_traversal_vars ret;
1240
1241 ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
1242 ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
1243 ret.inv_dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
1244 ret.sbt_offset_and_flags =
1245 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_sbt_offset_and_flags");
1246 ret.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
1247 ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit");
1248 ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_bvh_base");
1249 ret.stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
1250 ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_top_stack_ptr");
1251 ret.stack_low_watermark =
1252 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_low_watermark");
ret.current_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "current_node");
1254 ret.previous_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "previous_node");
1255 ret.instance_top_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_top_node");
1256 ret.instance_bottom_node =
1257 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_bottom_node");
1258 return ret;
1259 }
1260
1261 struct traversal_data {
1262 struct radv_device *device;
1263 struct rt_variables *vars;
1264 struct rt_traversal_vars *trav_vars;
1265 nir_variable *barycentrics;
1266
1267 struct radv_ray_tracing_pipeline *pipeline;
1268 };
1269
1270 static void
radv_ray_tracing_group_ahit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
                                 struct radv_rt_case_data *data)
1273 {
1274 if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR) {
1275 *shader_index = group->any_hit_shader;
1276 *handle_index = group->handle.any_hit_index;
1277 }
1278 }
1279
1280 static void
radv_build_ahit_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
                     struct radv_rt_case_data *data)
1283 {
1284 nir_shader *nir_stage =
1285 radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
1286 assert(nir_stage);
1287
1288 radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1289
1290 insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.any_hit_index);
1291 ralloc_free(nir_stage);
1292 }
1293
1294 static void
radv_ray_tracing_group_isec_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
                                 struct radv_rt_case_data *data)
1297 {
1298 if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR) {
1299 *shader_index = group->intersection_shader;
1300 *handle_index = group->handle.intersection_index;
1301 }
1302 }
1303
1304 static void
radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
                     struct radv_rt_case_data *data)
1307 {
1308 nir_shader *nir_stage =
1309 radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->intersection_shader].nir);
1310 assert(nir_stage);
1311
1312 radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1313
1314 nir_shader *any_hit_stage = NULL;
1315 if (group->any_hit_shader != VK_SHADER_UNUSED_KHR) {
1316 any_hit_stage =
1317 radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
1318 assert(any_hit_stage);
1319
1320 radv_nir_lower_rt_io(any_hit_stage, data->vars->monolithic, data->vars->payload_offset);
1321
1322 /* reserve stack size for any_hit before it is inlined */
1323 data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size;
1324
1325 nir_lower_intersection_shader(nir_stage, any_hit_stage);
1326 ralloc_free(any_hit_stage);
1327 }
1328
1329 insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.intersection_index);
1330 ralloc_free(nir_stage);
1331 }
1332
1333 static void
radv_ray_tracing_group_chit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
                                 struct radv_rt_case_data *data)
1336 {
1337 if (group->type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
1338 *shader_index = group->recursive_shader;
1339 *handle_index = group->handle.closest_hit_index;
1340 }
1341 }
1342
1343 static void
radv_ray_tracing_group_miss_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
                                 struct radv_rt_case_data *data)
1346 {
1347 if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
1348 if (data->pipeline->stages[group->recursive_shader].stage != MESA_SHADER_MISS)
1349 return;
1350
1351 *shader_index = group->recursive_shader;
1352 *handle_index = group->handle.general_index;
1353 }
1354 }
1355
1356 static void
radv_build_recursive_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
                          struct radv_rt_case_data *data)
1359 {
1360 nir_shader *nir_stage =
1361 radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->recursive_shader].nir);
1362 assert(nir_stage);
1363
1364 radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1365
1366 insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.general_index);
1367 ralloc_free(nir_stage);
1368 }
1369
1370 static void
1371 handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
1372 const struct radv_ray_traversal_args *args, const struct radv_ray_flags *ray_flags)
1373 {
1374 struct traversal_data *data = args->data;
1375
1376 nir_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
1377 nir_def *sbt_idx =
1378 nir_iadd(b,
1379 nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
1380 nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
1381 nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
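/* This is the standard Vulkan hit-group SBT record index:
 *
 *   instanceSbtOffset + sbtRecordOffset + sbtRecordStride * geometryIndex
 *
 * where sbt_offset/sbt_stride come from the TraceRay call and the low 24 bits of
 * sbt_offset_and_flags hold the instance's SBT record offset. */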
1382
1383 nir_def *hit_kind = nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
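/* 0xFE/0xFF are the spec-defined triangle hit kinds: HitKindFrontFacingTriangleKHR and
 * HitKindBackFacingTriangleKHR, respectively. */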
1384
1385 nir_def *prev_barycentrics = nir_load_var(b, data->barycentrics);
1386 nir_store_var(b, data->barycentrics, intersection->barycentrics, 0x3);
1387
1388 nir_store_var(b, data->vars->ahit_accept, nir_imm_true(b), 0x1);
1389 nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
1390
1391 nir_push_if(b, nir_inot(b, intersection->base.opaque));
1392 {
1393 struct rt_variables inner_vars = create_inner_vars(b, data->vars);
1394
1395 nir_store_var(b, inner_vars.primitive_id, intersection->base.primitive_id, 1);
1396 nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
1397 nir_store_var(b, inner_vars.tmax, intersection->t, 0x1);
1398 nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1399 nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
1400
1401 load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
1402
1403 struct radv_rt_case_data case_data = {
1404 .device = data->device,
1405 .pipeline = data->pipeline,
1406 .vars = &inner_vars,
1407 };
1408
1409 if (data->vars->ahit_isec_count)
1410 nir_store_var(b, data->vars->ahit_isec_count, nir_iadd_imm(b, nir_load_var(b, data->vars->ahit_isec_count), 1),
1411 0x1);
1412
1413 radv_visit_inlined_shaders(
1414 b, nir_load_var(b, inner_vars.idx),
1415 !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR), &case_data,
1416 radv_ray_tracing_group_ahit_info, radv_build_ahit_case);
1417
1418 nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
1419 {
1420 nir_store_var(b, data->barycentrics, prev_barycentrics, 0x3);
1421 nir_jump(b, nir_jump_continue);
1422 }
1423 nir_pop_if(b, NULL);
1424 }
1425 nir_pop_if(b, NULL);
1426
1427 nir_store_var(b, data->vars->primitive_id, intersection->base.primitive_id, 1);
1428 nir_store_var(b, data->vars->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
1429 nir_store_var(b, data->vars->tmax, intersection->t, 0x1);
1430 nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1431 nir_store_var(b, data->vars->hit_kind, hit_kind, 0x1);
1432
1433 nir_store_var(b, data->vars->idx, sbt_idx, 1);
1434 nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
1435
1436 nir_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
1437 nir_break_if(b, nir_ior(b, ray_flags->terminate_on_first_hit, ray_terminated));
1438 }
1439
1440 static void
1441 handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersection,
1442 const struct radv_ray_traversal_args *args)
1443 {
1444 struct traversal_data *data = args->data;
1445
1446 nir_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
1447 nir_def *sbt_idx =
1448 nir_iadd(b,
1449 nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
1450 nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
1451 nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
1452
1453 struct rt_variables inner_vars = create_inner_vars(b, data->vars);
1454
1455 /* For AABBs, the intersection shader writes the hit kind itself, and only does so if the
1456 * candidate is the next closest hit. */
1457 inner_vars.hit_kind = data->vars->hit_kind;
1458
1459 nir_store_var(b, inner_vars.primitive_id, intersection->primitive_id, 1);
1460 nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
1461 nir_store_var(b, inner_vars.tmax, nir_load_var(b, data->vars->tmax), 0x1);
1462 nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1463 nir_store_var(b, inner_vars.opaque, intersection->opaque, 1);
1464
1465 load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX);
1466
1467 nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1);
1468 nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
1469 nir_store_var(b, data->vars->terminated, nir_imm_false(b), 0x1);
1470
1471 if (data->vars->ahit_isec_count)
1472 nir_store_var(b, data->vars->ahit_isec_count,
1473 nir_iadd_imm(b, nir_load_var(b, data->vars->ahit_isec_count), 1 << 16), 0x1);
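/* ahit_isec_count appears to pack two ray-history counters into one 32-bit value:
 * any-hit invocations in the low 16 bits (incremented by 1 in handle_candidate_triangle)
 * and intersection invocations in the high 16 bits (incremented by 1 << 16 here). */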
1474
1475 struct radv_rt_case_data case_data = {
1476 .device = data->device,
1477 .pipeline = data->pipeline,
1478 .vars = &inner_vars,
1479 };
1480
1481 radv_visit_inlined_shaders(
1482 b, nir_load_var(b, inner_vars.idx),
1483 !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR), &case_data,
1484 radv_ray_tracing_group_isec_info, radv_build_isec_case);
1485
1486 nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
1487 {
1488 nir_store_var(b, data->vars->primitive_id, intersection->primitive_id, 1);
1489 nir_store_var(b, data->vars->geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
1490 nir_store_var(b, data->vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
1491 nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1492
1493 nir_store_var(b, data->vars->idx, sbt_idx, 1);
1494 nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
1495
1496 nir_break_if(b, nir_load_var(b, data->vars->terminated));
1497 }
1498 nir_pop_if(b, NULL);
1499 }
1500
1501 static void
1502 store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct radv_ray_traversal_args *args)
1503 {
1504 nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
1505 }
1506
1507 static nir_def *
1508 load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal_args *args)
1509 {
1510 return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
1511 }
1512
1513 static void
1514 radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1515 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, bool monolithic, nir_builder *b,
1516 struct rt_variables *vars, bool ignore_cull_mask, struct radv_ray_tracing_stage_info *info)
1517 {
1518 const struct radv_physical_device *pdev = radv_device_physical(device);
1519 nir_variable *barycentrics =
1520 nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
1521 barycentrics->data.driver_location = 0;
1522
1523 struct rt_traversal_vars trav_vars = init_traversal_vars(b);
1524
1525 nir_store_var(b, trav_vars.hit, nir_imm_false(b), 1);
1526
1527 nir_def *accel_struct = nir_load_var(b, vars->accel_struct);
1528 nir_def *bvh_offset = nir_build_load_global(
1529 b, 1, 32, nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
1530 .access = ACCESS_NON_WRITEABLE);
1531 nir_def *root_bvh_base = nir_iadd(b, accel_struct, nir_u2u64(b, bvh_offset));
1532 root_bvh_base = build_addr_to_node(b, root_bvh_base);
1533
1534 nir_store_var(b, trav_vars.bvh_base, root_bvh_base, 1);
1535
1536 nir_def *vec3ones = nir_imm_vec3(b, 1.0, 1.0, 1.0);
1537
1538 nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
1539 nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
1540 nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
1541 nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
1542 nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
1543
1544 nir_store_var(b, trav_vars.stack, nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t)), 1);
1545 nir_store_var(b, trav_vars.stack_low_watermark, nir_load_var(b, trav_vars.stack), 1);
1546 nir_store_var(b, trav_vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
1547 nir_store_var(b, trav_vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
1548 nir_store_var(b, trav_vars.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
1549 nir_store_var(b, trav_vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
1550
1551 nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, -1), 1);
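/* The traversal stack lives in LDS and is interleaved per lane: lane l starts at byte
 * offset l * 4 and presumably advances in steps of stack_stride = wave_size * 4, so entry
 * i of lane l ends up at (i * wave_size + l) * 4, filling exactly the
 * wave_size * MAX_STACK_ENTRY_COUNT * 4 bytes of shared memory reserved for the stack. */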
1552
1553 struct radv_ray_traversal_vars trav_vars_args = {
1554 .tmax = nir_build_deref_var(b, vars->tmax),
1555 .origin = nir_build_deref_var(b, trav_vars.origin),
1556 .dir = nir_build_deref_var(b, trav_vars.dir),
1557 .inv_dir = nir_build_deref_var(b, trav_vars.inv_dir),
1558 .bvh_base = nir_build_deref_var(b, trav_vars.bvh_base),
1559 .stack = nir_build_deref_var(b, trav_vars.stack),
1560 .top_stack = nir_build_deref_var(b, trav_vars.top_stack),
1561 .stack_low_watermark = nir_build_deref_var(b, trav_vars.stack_low_watermark),
1562 .current_node = nir_build_deref_var(b, trav_vars.current_node),
1563 .previous_node = nir_build_deref_var(b, trav_vars.previous_node),
1564 .instance_top_node = nir_build_deref_var(b, trav_vars.instance_top_node),
1565 .instance_bottom_node = nir_build_deref_var(b, trav_vars.instance_bottom_node),
1566 .instance_addr = nir_build_deref_var(b, trav_vars.instance_addr),
1567 .sbt_offset_and_flags = nir_build_deref_var(b, trav_vars.sbt_offset_and_flags),
1568 };
1569
1570 nir_variable *iteration_instance_count = NULL;
1571 if (vars->device->rra_trace.ray_history_addr) {
1572 iteration_instance_count =
1573 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "iteration_instance_count");
1574 nir_store_var(b, iteration_instance_count, nir_imm_int(b, 0), 0x1);
1575 trav_vars_args.iteration_instance_count = nir_build_deref_var(b, iteration_instance_count);
1576
1577 nir_store_var(b, vars->ahit_isec_count, nir_imm_int(b, 0), 0x1);
1578 }
1579
1580 struct traversal_data data = {
1581 .device = device,
1582 .vars = vars,
1583 .trav_vars = &trav_vars,
1584 .barycentrics = barycentrics,
1585 .pipeline = pipeline,
1586 };
1587
1588 nir_def *cull_mask_and_flags = nir_load_var(b, vars->cull_mask_and_flags);
1589 struct radv_ray_traversal_args args = {
1590 .root_bvh_base = root_bvh_base,
1591 .flags = cull_mask_and_flags,
1592 .cull_mask = cull_mask_and_flags,
1593 .origin = nir_load_var(b, vars->origin),
1594 .tmin = nir_load_var(b, vars->tmin),
1595 .dir = nir_load_var(b, vars->direction),
1596 .vars = trav_vars_args,
1597 .stack_stride = pdev->rt_wave_size * sizeof(uint32_t),
1598 .stack_entries = MAX_STACK_ENTRY_COUNT,
1599 .stack_base = 0,
1600 .ignore_cull_mask = ignore_cull_mask,
1601 .set_flags = info ? info->set_flags : 0,
1602 .unset_flags = info ? info->unset_flags : 0,
1603 .stack_store_cb = store_stack_entry,
1604 .stack_load_cb = load_stack_entry,
1605 .aabb_cb = (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_AABBS_BIT_KHR)
1606 ? NULL
1607 : handle_candidate_aabb,
1608 .triangle_cb = (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)
1609 ? NULL
1610 : handle_candidate_triangle,
1611 .data = &data,
1612 };
1613
1614 nir_def *original_tmax = nir_load_var(b, vars->tmax);
1615
1616 radv_build_ray_traversal(device, b, &args);
1617
1618 if (vars->device->rra_trace.ray_history_addr)
1619 radv_build_end_trace_token(b, vars, original_tmax, nir_load_var(b, trav_vars.hit),
1620 nir_load_var(b, iteration_instance_count));
1621
1622 nir_metadata_preserve(nir_shader_get_entrypoint(b->shader), nir_metadata_none);
1623 radv_nir_lower_hit_attrib_derefs(b->shader);
1624
1625 /* Register storage for hit attributes */
1626 nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_DWORDS];
1627
1628 if (!monolithic) {
1629 for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
1630 hit_attribs[i] =
1631 nir_local_variable_create(nir_shader_get_entrypoint(b->shader), glsl_uint_type(), "ahit_attrib");
1632
1633 lower_hit_attribs(b->shader, hit_attribs, pdev->rt_wave_size);
1634 }
1635
1636 /* Initialize follow-up shader. */
1637 nir_push_if(b, nir_load_var(b, trav_vars.hit));
1638 {
1639 if (monolithic) {
1640 load_sbt_entry(b, vars, nir_load_var(b, vars->idx), SBT_HIT, SBT_CLOSEST_HIT_IDX);
1641
1642 nir_def *should_return =
1643 nir_test_mask(b, nir_load_var(b, vars->cull_mask_and_flags), SpvRayFlagsSkipClosestHitShaderKHRMask);
1644
1645 /* should_return is set if we had a hit but we won't be calling the closest hit
1646 * shader and hence need to return immediately to the calling shader. */
1647 nir_push_if(b, nir_inot(b, should_return));
1648
1649 struct radv_rt_case_data case_data = {
1650 .device = device,
1651 .pipeline = pipeline,
1652 .vars = vars,
1653 };
1654
1655 radv_visit_inlined_shaders(
1656 b, nir_load_var(b, vars->idx),
1657 !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR), &case_data,
1658 radv_ray_tracing_group_chit_info, radv_build_recursive_case);
1659
1660 nir_pop_if(b, NULL);
1661 } else {
1662 for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
1663 nir_store_hit_attrib_amd(b, nir_load_var(b, hit_attribs[i]), .base = i);
1664 nir_execute_closest_hit_amd(b, nir_load_var(b, vars->idx), nir_load_var(b, vars->tmax),
1665 nir_load_var(b, vars->primitive_id), nir_load_var(b, vars->instance_addr),
1666 nir_load_var(b, vars->geometry_id_and_flags), nir_load_var(b, vars->hit_kind));
1667 }
1668 }
1669 nir_push_else(b, NULL);
1670 {
1671 if (monolithic) {
1672 load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, SBT_GENERAL_IDX);
1673
1674 struct radv_rt_case_data case_data = {
1675 .device = device,
1676 .pipeline = pipeline,
1677 .vars = vars,
1678 };
1679
1680 radv_visit_inlined_shaders(b, nir_load_var(b, vars->idx),
1681 !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR),
1682 &case_data, radv_ray_tracing_group_miss_info, radv_build_recursive_case);
1683 } else {
1684 /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
1685 * for miss shaders if none of the rays miss. */
1686 nir_execute_miss_amd(b, nir_load_var(b, vars->tmax));
1687 }
1688 }
1689 nir_pop_if(b, NULL);
1690 }
1691
1692 nir_shader *
1693 radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1694 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1695 struct radv_ray_tracing_stage_info *info)
1696 {
1697 const struct radv_physical_device *pdev = radv_device_physical(device);
1698 const VkPipelineCreateFlagBits2 create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
1699
1700 /* Create the traversal shader as an intersection shader to prevent validation failures due to
1701 * invalid variable modes. */
1702 nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
1703 b.shader->info.internal = false;
1704 b.shader->info.workgroup_size[0] = 8;
1705 b.shader->info.workgroup_size[1] = pdev->rt_wave_size == 64 ? 8 : 4;
1706 b.shader->info.shared_size = pdev->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
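/* The workgroup is sized to exactly one wave (8x8 lanes for wave64, 8x4 for wave32), so
 * nir_load_local_invocation_index() in radv_build_traversal() doubles as the lane index
 * used to address the per-lane LDS traversal stack. */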
1707 struct rt_variables vars = create_rt_variables(b.shader, device, create_flags, false);
1708
1709 if (info->tmin.state == RADV_RT_CONST_ARG_STATE_VALID)
1710 nir_store_var(&b, vars.tmin, nir_imm_int(&b, info->tmin.value), 0x1);
1711 else
1712 nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
1713
1714 if (info->tmax.state == RADV_RT_CONST_ARG_STATE_VALID)
1715 nir_store_var(&b, vars.tmax, nir_imm_int(&b, info->tmax.value), 0x1);
1716 else
1717 nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
1718
1719 if (info->sbt_offset.state == RADV_RT_CONST_ARG_STATE_VALID)
1720 nir_store_var(&b, vars.sbt_offset, nir_imm_int(&b, info->sbt_offset.value), 0x1);
1721 else
1722 nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
1723
1724 if (info->sbt_stride.state == RADV_RT_CONST_ARG_STATE_VALID)
1725 nir_store_var(&b, vars.sbt_stride, nir_imm_int(&b, info->sbt_stride.value), 0x1);
1726 else
1727 nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
1728
1729 /* initialize trace_ray arguments */
1730 nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1);
1731 nir_store_var(&b, vars.cull_mask_and_flags, nir_load_cull_mask_and_flags_amd(&b), 0x1);
1732 nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
1733 nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
1734 nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
1735 nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
1736
1737 radv_build_traversal(device, pipeline, pCreateInfo, false, &b, &vars, false, info);
1738
1739 /* Deal with all the inline functions. */
1740 nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
1741 nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
1742
1743 /* Lower and cleanup variables */
1744 NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
1745 NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
1746
1747 return b.shader;
1748 }
1749
1750 struct lower_rt_instruction_monolithic_state {
1751 struct radv_device *device;
1752 struct radv_ray_tracing_pipeline *pipeline;
1753 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo;
1754
1755 struct rt_variables *vars;
1756 };
1757
1758 static bool
1759 lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data)
1760 {
1761 if (instr->type != nir_instr_type_intrinsic)
1762 return false;
1763
1764 b->cursor = nir_after_instr(instr);
1765
1766 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1767
1768 struct lower_rt_instruction_monolithic_state *state = data;
1769 const struct radv_physical_device *pdev = radv_device_physical(state->device);
1770 struct rt_variables *vars = state->vars;
1771
1772 switch (intr->intrinsic) {
1773 case nir_intrinsic_execute_callable:
1774 /* OpExecuteCallableKHR may appear in a SPIR-V module even if the RT pipeline doesn't contain
1775 * any callable shaders. In that case the instruction can never be executed in a valid way, so just remove
1776 * any nir_intrinsic_execute_callable we encounter.
1777 */
1778 nir_instr_remove(instr);
1779 return true;
1780 case nir_intrinsic_trace_ray: {
1781 vars->payload_offset = nir_src_as_uint(intr->src[10]);
1782
1783 nir_src cull_mask = intr->src[2];
1784 bool ignore_cull_mask = nir_src_is_const(cull_mask) && (nir_src_as_uint(cull_mask) & 0xFF) == 0xFF;
1785
1786 /* Per the SPIR-V ray tracing spec, only the low bits of some of these operands are used, so mask the rest off. */
1787 nir_store_var(b, vars->accel_struct, intr->src[0].ssa, 0x1);
1788 nir_store_var(b, vars->cull_mask_and_flags, nir_ior(b, nir_ishl_imm(b, cull_mask.ssa, 24), intr->src[1].ssa),
1789 0x1);
1790 nir_store_var(b, vars->sbt_offset, nir_iand_imm(b, intr->src[3].ssa, 0xf), 0x1);
1791 nir_store_var(b, vars->sbt_stride, nir_iand_imm(b, intr->src[4].ssa, 0xf), 0x1);
1792 nir_store_var(b, vars->miss_index, nir_iand_imm(b, intr->src[5].ssa, 0xffff), 0x1);
1793 nir_store_var(b, vars->origin, intr->src[6].ssa, 0x7);
1794 nir_store_var(b, vars->tmin, intr->src[7].ssa, 0x1);
1795 nir_store_var(b, vars->direction, intr->src[8].ssa, 0x7);
1796 nir_store_var(b, vars->tmax, intr->src[9].ssa, 0x1);
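/* Resulting cull_mask_and_flags layout, as implied by the shift above:
 *   bits 31..24: cull mask   bits 23..0: ray flags
 * The 0xf/0xffff masks follow the SPIR-V rule that only the 4 least-significant bits of
 * the SBT offset/stride and the 16 least-significant bits of the miss index are used. */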
1797
1798 nir_def *stack_ptr = nir_load_var(b, vars->stack_ptr);
1799 nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, stack_ptr, b->shader->scratch_size), 0x1);
1800
1801 radv_build_traversal(state->device, state->pipeline, state->pCreateInfo, true, b, vars, ignore_cull_mask, NULL);
1802 b->shader->info.shared_size =
1803 MAX2(b->shader->info.shared_size, pdev->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t));
1804
1805 nir_store_var(b, vars->stack_ptr, stack_ptr, 0x1);
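/* The inlined traversal (and everything inlined into it) gets scratch space above the
 * caller's frame: stack_ptr is bumped by the caller's scratch_size for the duration of
 * the call and restored afterwards. */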
1806
1807 nir_instr_remove(instr);
1808 return true;
1809 }
1810 case nir_intrinsic_rt_resume:
1811 unreachable("nir_intrinsic_rt_resume");
1812 case nir_intrinsic_rt_return_amd:
1813 unreachable("nir_intrinsic_rt_return_amd");
1814 case nir_intrinsic_execute_closest_hit_amd:
1815 unreachable("nir_intrinsic_execute_closest_hit_amd");
1816 case nir_intrinsic_execute_miss_amd:
1817 unreachable("nir_intrinsic_execute_miss_amd");
1818 default:
1819 return false;
1820 }
1821 }
1822
1823 static bool
1824 radv_count_hit_attrib_slots(nir_builder *b, nir_intrinsic_instr *instr, void *data)
1825 {
1826 uint32_t *count = data;
1827 if (instr->intrinsic == nir_intrinsic_load_hit_attrib_amd || instr->intrinsic == nir_intrinsic_store_hit_attrib_amd)
1828 *count = MAX2(*count, nir_intrinsic_base(instr) + 1);
1829
1830 return false;
1831 }
1832
1833 static void
1834 lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device,
1835 struct radv_ray_tracing_pipeline *pipeline,
1836 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct rt_variables *vars)
1837 {
1838 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1839
1840 struct lower_rt_instruction_monolithic_state state = {
1841 .device = device,
1842 .pipeline = pipeline,
1843 .pCreateInfo = pCreateInfo,
1844 .vars = vars,
1845 };
1846
1847 nir_shader_instructions_pass(shader, lower_rt_instruction_monolithic, nir_metadata_none, &state);
1848 nir_index_ssa_defs(impl);
1849
1850 uint32_t hit_attrib_count = 0;
1851 nir_shader_intrinsics_pass(shader, radv_count_hit_attrib_slots, nir_metadata_all, &hit_attrib_count);
1852
1853 /* Register storage for hit attributes */
1854 STACK_ARRAY(nir_variable *, hit_attribs, hit_attrib_count);
1855 for (uint32_t i = 0; i < hit_attrib_count; i++)
1856 hit_attribs[i] = nir_local_variable_create(impl, glsl_uint_type(), "ahit_attrib");
1857
1858 lower_hit_attribs(shader, hit_attribs, 0);
1859 }
1860
1861 /** Select the next shader based on priorities:
1862 *
1863 * The priority of a shader stage is encoded in the lowest bits of its address (low to high):
1864 * - Raygen - idx 0
1865 * - Traversal - idx 1
1866 * - Closest Hit / Miss - idx 2
1867 * - Callable - idx 3
1868 *
1869 *
1870 * This gives us the following priorities:
1871 * Raygen      : Callable >             > Traversal > Raygen
1872 * Traversal   :          > Chit / Miss >           > Raygen
1873 * CHit / Miss : Callable > Chit / Miss > Traversal > Raygen
1874 * Callable    : Callable > Chit / Miss >           > Raygen
1875 */
1876 static nir_def *
1877 select_next_shader(nir_builder *b, nir_def *shader_addr, unsigned wave_size)
1878 {
1879 gl_shader_stage stage = b->shader->info.stage;
1880 nir_def *prio = nir_iand_imm(b, shader_addr, radv_rt_priority_mask);
1881 nir_def *ballot = nir_ballot(b, 1, wave_size, nir_imm_bool(b, true));
1882 nir_def *ballot_traversal = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_traversal));
1883 nir_def *ballot_hit_miss = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_hit_miss));
1884 nir_def *ballot_callable = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_callable));
1885
1886 if (stage != MESA_SHADER_CALLABLE && stage != MESA_SHADER_INTERSECTION)
1887 ballot = nir_bcsel(b, nir_ine_imm(b, ballot_traversal, 0), ballot_traversal, ballot);
1888 if (stage != MESA_SHADER_RAYGEN)
1889 ballot = nir_bcsel(b, nir_ine_imm(b, ballot_hit_miss, 0), ballot_hit_miss, ballot);
1890 if (stage != MESA_SHADER_INTERSECTION)
1891 ballot = nir_bcsel(b, nir_ine_imm(b, ballot_callable, 0), ballot_callable, ballot);
1892
1893 nir_def *lsb = nir_find_lsb(b, ballot);
1894 nir_def *next = nir_read_invocation(b, shader_addr, lsb);
1895 return nir_iand_imm(b, next, ~radv_rt_priority_mask);
1896 }
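/* Example (for a closest-hit/miss wave): if it holds pending callable, closest-hit/miss
 * and raygen work, the callable ballot is non-zero and wins, nir_find_lsb picks the first
 * such lane, and every invocation reads that lane's shader address. The whole wave
 * therefore agrees on a single next shader; lower-priority invocations simply skip their
 * guarded body until their shader gets selected. */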
1897
1898 static void
1899 radv_store_arg(nir_builder *b, const struct radv_shader_args *args, const struct radv_ray_tracing_stage_info *info,
1900 struct ac_arg arg, nir_def *value)
1901 {
1902 /* Do not pass unused data to the next stage. */
1903 if (!info || !BITSET_TEST(info->unused_args, arg.arg_index))
1904 ac_nir_store_arg(b, &args->ac, arg, value);
1905 }
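/* unused_args is the bitset filled in by radv_gather_unused_args() at the end of this
 * file: bits are cleared for args the target shader really consumes, so a set bit means
 * the value would only be forwarded and the store can be skipped. */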
1906
1907 void
1908 radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1909 const struct radv_shader_args *args, const struct radv_shader_info *info, uint32_t *stack_size,
1910 bool resume_shader, struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1911 bool monolithic, const struct radv_ray_tracing_stage_info *traversal_info)
1912 {
1913 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1914
1915 const VkPipelineCreateFlagBits2 create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
1916
1917 struct rt_variables vars = create_rt_variables(shader, device, create_flags, monolithic);
1918
1919 if (monolithic)
1920 lower_rt_instructions_monolithic(shader, device, pipeline, pCreateInfo, &vars);
1921
1922 struct radv_rt_shader_info rt_info = {0};
1923
1924 lower_rt_instructions(shader, &vars, true, &rt_info);
1925
1926 if (stack_size) {
1927 vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);
1928 *stack_size = MAX2(*stack_size, vars.stack_size);
1929 }
1930 shader->scratch_size = 0;
1931
1932 NIR_PASS(_, shader, nir_lower_returns);
1933
1934 nir_cf_list list;
1935 nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1936
1937 /* initialize variables */
1938 nir_builder b = nir_builder_at(nir_before_impl(impl));
1939
1940 nir_def *descriptor_sets = ac_nir_load_arg(&b, &args->ac, args->descriptor_sets[0]);
1941 nir_def *push_constants = ac_nir_load_arg(&b, &args->ac, args->ac.push_constants);
1942 nir_def *sbt_descriptors = ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_descriptors);
1943
1944 nir_def *launch_sizes[3];
1945 for (uint32_t i = 0; i < ARRAY_SIZE(launch_sizes); i++) {
1946 launch_sizes[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_sizes[i]);
1947 nir_store_var(&b, vars.launch_sizes[i], launch_sizes[i], 1);
1948 }
1949
1950 nir_def *scratch_offset = NULL;
1951 if (args->ac.scratch_offset.used)
1952 scratch_offset = ac_nir_load_arg(&b, &args->ac, args->ac.scratch_offset);
1953 nir_def *ring_offsets = NULL;
1954 if (args->ac.ring_offsets.used)
1955 ring_offsets = ac_nir_load_arg(&b, &args->ac, args->ac.ring_offsets);
1956
1957 nir_def *launch_ids[3];
1958 for (uint32_t i = 0; i < ARRAY_SIZE(launch_ids); i++) {
1959 launch_ids[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_ids[i]);
1960 nir_store_var(&b, vars.launch_ids[i], launch_ids[i], 1);
1961 }
1962
1963 nir_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr);
1964 nir_store_var(&b, vars.traversal_addr, nir_pack_64_2x32(&b, traversal_addr), 1);
1965
1966 nir_def *shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_addr);
1967 shader_addr = nir_pack_64_2x32(&b, shader_addr);
1968 nir_store_var(&b, vars.shader_addr, shader_addr, 1);
1969
1970 nir_store_var(&b, vars.stack_ptr, ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
1971 nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
1972 nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
1973 nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
1974
1975 nir_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
1976 nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
1977 nir_store_var(&b, vars.cull_mask_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1);
1978 nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1);
1979 nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1);
1980 nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7);
1981 nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1);
1982 nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction), 0x7);
1983 nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1);
1984
1985 if (traversal_info && traversal_info->miss_index.state == RADV_RT_CONST_ARG_STATE_VALID)
1986 nir_store_var(&b, vars.miss_index, nir_imm_int(&b, traversal_info->miss_index.value), 0x1);
1987 else
1988 nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 0x1);
1989
1990 nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id), 1);
1991 nir_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
1992 nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
1993 nir_store_var(&b, vars.geometry_id_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1);
1994 nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1);
1995
1996 /* Guard the shader so that only the correct invocations execute it. */
1997 nir_if *shader_guard = NULL;
1998 if (shader->info.stage != MESA_SHADER_RAYGEN || resume_shader) {
1999 nir_def *uniform_shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr);
2000 uniform_shader_addr = nir_pack_64_2x32(&b, uniform_shader_addr);
2001 uniform_shader_addr = nir_ior_imm(&b, uniform_shader_addr, radv_get_rt_priority(shader->info.stage));
2002
2003 shader_guard = nir_push_if(&b, nir_ieq(&b, uniform_shader_addr, shader_addr));
2004 shader_guard->control = nir_selection_control_divergent_always_taken;
2005 }
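/* A wave can carry invocations destined for different shaders; only those whose own
 * shader_addr matches the wave-uniform address chosen by select_next_shader() execute
 * the body, while the rest fall through, store their state back below and presumably get
 * picked up once their shader is selected. */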
2006
2007 nir_cf_reinsert(&list, b.cursor);
2008
2009 if (shader_guard)
2010 nir_pop_if(&b, shader_guard);
2011
2012 b.cursor = nir_after_impl(impl);
2013
2014 if (monolithic) {
2015 nir_terminate(&b);
2016 } else {
2017 /* select next shader */
2018 shader_addr = nir_load_var(&b, vars.shader_addr);
2019 nir_def *next = select_next_shader(&b, shader_addr, info->wave_size);
2020 ac_nir_store_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr, next);
2021
2022 ac_nir_store_arg(&b, &args->ac, args->descriptor_sets[0], descriptor_sets);
2023 ac_nir_store_arg(&b, &args->ac, args->ac.push_constants, push_constants);
2024 ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_descriptors, sbt_descriptors);
2025 ac_nir_store_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr, traversal_addr);
2026
2027 for (uint32_t i = 0; i < ARRAY_SIZE(launch_sizes); i++) {
2028 if (rt_info.uses_launch_size)
2029 ac_nir_store_arg(&b, &args->ac, args->ac.rt.launch_sizes[i], launch_sizes[i]);
2030 else
2031 radv_store_arg(&b, args, traversal_info, args->ac.rt.launch_sizes[i], launch_sizes[i]);
2032 }
2033
2034 if (scratch_offset)
2035 ac_nir_store_arg(&b, &args->ac, args->ac.scratch_offset, scratch_offset);
2036 if (ring_offsets)
2037 ac_nir_store_arg(&b, &args->ac, args->ac.ring_offsets, ring_offsets);
2038
2039 for (uint32_t i = 0; i < ARRAY_SIZE(launch_ids); i++) {
2040 if (rt_info.uses_launch_id)
2041 ac_nir_store_arg(&b, &args->ac, args->ac.rt.launch_ids[i], launch_ids[i]);
2042 else
2043 radv_store_arg(&b, args, traversal_info, args->ac.rt.launch_ids[i], launch_ids[i]);
2044 }
2045
2046 /* store back all variables to registers */
2047 ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base, nir_load_var(&b, vars.stack_ptr));
2048 ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_addr, shader_addr);
2049 radv_store_arg(&b, args, traversal_info, args->ac.rt.shader_record, nir_load_var(&b, vars.shader_record_ptr));
2050 radv_store_arg(&b, args, traversal_info, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg));
2051 radv_store_arg(&b, args, traversal_info, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct));
2052 radv_store_arg(&b, args, traversal_info, args->ac.rt.cull_mask_and_flags,
2053 nir_load_var(&b, vars.cull_mask_and_flags));
2054 radv_store_arg(&b, args, traversal_info, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset));
2055 radv_store_arg(&b, args, traversal_info, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride));
2056 radv_store_arg(&b, args, traversal_info, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index));
2057 radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin));
2058 radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin));
2059 radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
2060 radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));
2061
2062 radv_store_arg(&b, args, traversal_info, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
2063 radv_store_arg(&b, args, traversal_info, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
2064 radv_store_arg(&b, args, traversal_info, args->ac.rt.geometry_id_and_flags,
2065 nir_load_var(&b, vars.geometry_id_and_flags));
2066 radv_store_arg(&b, args, traversal_info, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind));
2067 }
2068
2069 nir_metadata_preserve(impl, nir_metadata_none);
2070
2071 /* cleanup passes */
2072 NIR_PASS_V(shader, nir_lower_global_vars_to_local);
2073 NIR_PASS_V(shader, nir_lower_vars_to_ssa);
2074 if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_INTERSECTION)
2075 NIR_PASS_V(shader, lower_hit_attribs, NULL, info->wave_size);
2076 }
2077
2078 static bool
2079 radv_arg_def_is_unused(nir_def *def)
2080 {
2081 nir_foreach_use (use, def) {
2082 nir_instr *use_instr = nir_src_parent_instr(use);
2083 if (use_instr->type == nir_instr_type_intrinsic) {
2084 nir_intrinsic_instr *use_intr = nir_instr_as_intrinsic(use_instr);
2085 if (use_intr->intrinsic == nir_intrinsic_store_scalar_arg_amd ||
2086 use_intr->intrinsic == nir_intrinsic_store_vector_arg_amd)
2087 continue;
2088 } else if (use_instr->type == nir_instr_type_phi) {
2089 nir_cf_node *prev_node = nir_cf_node_prev(&use_instr->block->cf_node);
2090 if (!prev_node)
2091 return false;
2092
2093 nir_phi_instr *phi = nir_instr_as_phi(use_instr);
2094 if (radv_arg_def_is_unused(&phi->def))
2095 continue;
2096 }
2097
2098 return false;
2099 }
2100
2101 return true;
2102 }
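/* An argument is considered "unused" if every use merely forwards it to the next stage,
 * i.e. feeds a store_scalar/vector_arg_amd directly or through a chain of phis; any other
 * use means the shader really consumes the value. */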
2103
2104 static bool
2105 radv_gather_unused_args_instr(nir_builder *b, nir_intrinsic_instr *instr, void *data)
2106 {
2107 if (instr->intrinsic != nir_intrinsic_load_scalar_arg_amd && instr->intrinsic != nir_intrinsic_load_vector_arg_amd)
2108 return false;
2109
2110 if (!radv_arg_def_is_unused(&instr->def)) {
2111 /* This arg is used for more than passing data to the next stage. */
2112 struct radv_ray_tracing_stage_info *info = data;
2113 BITSET_CLEAR(info->unused_args, nir_intrinsic_base(instr));
2114 }
2115
2116 return false;
2117 }
2118
2119 void
2120 radv_gather_unused_args(struct radv_ray_tracing_stage_info *info, nir_shader *nir)
2121 {
2122 nir_shader_intrinsics_pass(nir, radv_gather_unused_args_instr, nir_metadata_all, info);
2123 }
2124