1 /*
2  * Copyright © 2021 Google
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir/nir.h"
25 #include "nir/nir_builder.h"
26 
27 #include "bvh/bvh.h"
28 #include "meta/radv_meta.h"
29 #include "nir/radv_nir.h"
30 #include "nir/radv_nir_rt_common.h"
31 #include "ac_nir.h"
32 #include "radv_private.h"
33 #include "radv_shader.h"
34 
35 #include "vk_pipeline.h"
36 
37 /* Traversal stack size. This stack is put in LDS and, experimentally, 16 entries give the best
38  * performance. */
39 #define MAX_STACK_ENTRY_COUNT 16
40 
41 #define RADV_RT_SWITCH_NULL_CHECK_THRESHOLD 3
42 
43 /* Minimum number of inlined shaders for which a binary search is used to select the shader to run. */
44 #define INLINED_SHADER_BSEARCH_THRESHOLD 16
45 
46 struct radv_rt_case_data {
47    struct radv_device *device;
48    struct radv_ray_tracing_pipeline *pipeline;
49    struct rt_variables *vars;
50 };
51 
52 typedef void (*radv_get_group_info)(struct radv_ray_tracing_group *, uint32_t *, uint32_t *,
53                                     struct radv_rt_case_data *);
54 typedef void (*radv_insert_shader_case)(nir_builder *, nir_def *, struct radv_ray_tracing_group *,
55                                         struct radv_rt_case_data *);
56 
57 struct inlined_shader_case {
58    struct radv_ray_tracing_group *group;
59    uint32_t call_idx;
60 };
61 
62 static int
63 compare_inlined_shader_case(const void *a, const void *b)
64 {
65    const struct inlined_shader_case *visit_a = a;
66    const struct inlined_shader_case *visit_b = b;
67    return visit_a->call_idx > visit_b->call_idx ? 1 : visit_a->call_idx < visit_b->call_idx ? -1 : 0;
68 }
69 
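/* Emit the dispatch for a sorted range of inlined shader cases. Large ranges are split
 * recursively around the median call index, producing a binary-search if/else tree;
 * small ranges fall back to a linear chain of per-case checks emitted by shader_case. */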
70 static void
71 insert_inlined_range(nir_builder *b, nir_def *sbt_idx, radv_insert_shader_case shader_case,
72                      struct radv_rt_case_data *data, struct inlined_shader_case *cases, uint32_t length)
73 {
74    if (length >= INLINED_SHADER_BSEARCH_THRESHOLD) {
75       nir_push_if(b, nir_ige_imm(b, sbt_idx, cases[length / 2].call_idx));
76       {
77          insert_inlined_range(b, sbt_idx, shader_case, data, cases + (length / 2), length - (length / 2));
78       }
79       nir_push_else(b, NULL);
80       {
81          insert_inlined_range(b, sbt_idx, shader_case, data, cases, length / 2);
82       }
83       nir_pop_if(b, NULL);
84    } else {
85       for (uint32_t i = 0; i < length; ++i)
86          shader_case(b, sbt_idx, cases[i].group, data);
87    }
88 }
89 
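/* Gather the unique (shader, handle) cases of this pipeline, sort them by call index and emit
 * the selection logic. A null-SBT-entry check is only emitted once there are enough cases for
 * it to be worthwhile (RADV_RT_SWITCH_NULL_CHECK_THRESHOLD). */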
90 static void
91 radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_shaders, struct radv_rt_case_data *data,
92                            radv_get_group_info group_info, radv_insert_shader_case shader_case)
93 {
94    struct inlined_shader_case *cases = calloc(data->pipeline->group_count, sizeof(struct inlined_shader_case));
95    uint32_t case_count = 0;
96 
97    for (unsigned i = 0; i < data->pipeline->group_count; i++) {
98       struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
99 
100       uint32_t shader_index = VK_SHADER_UNUSED_KHR;
101       uint32_t handle_index = VK_SHADER_UNUSED_KHR;
102       group_info(group, &shader_index, &handle_index, data);
103       if (shader_index == VK_SHADER_UNUSED_KHR)
104          continue;
105 
106       /* Avoid emitting stages with the same shaders/handles multiple times. */
107       bool duplicate = false;
108       for (unsigned j = 0; j < i; j++) {
109          uint32_t other_shader_index = VK_SHADER_UNUSED_KHR;
110          uint32_t other_handle_index = VK_SHADER_UNUSED_KHR;
111          group_info(&data->pipeline->groups[j], &other_shader_index, &other_handle_index, data);
112 
113          if (handle_index == other_handle_index) {
114             duplicate = true;
115             break;
116          }
117       }
118 
119       if (!duplicate) {
120          cases[case_count++] = (struct inlined_shader_case){
121             .group = group,
122             .call_idx = handle_index,
123          };
124       }
125    }
126 
127    qsort(cases, case_count, sizeof(struct inlined_shader_case), compare_inlined_shader_case);
128 
129    /* Do not emit 'if (sbt_idx != 0) { ... }' if there are only a few cases. */
130    can_have_null_shaders &= case_count >= RADV_RT_SWITCH_NULL_CHECK_THRESHOLD;
131 
132    if (can_have_null_shaders)
133       nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
134 
135    insert_inlined_range(b, sbt_idx, shader_case, data, cases, case_count);
136 
137    if (can_have_null_shaders)
138       nir_pop_if(b, NULL);
139 
140    free(cases);
141 }
142 
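/* Retarget shader_call_data derefs to function_temp so the incoming payload lives in scratch,
 * rooted at the caller's rt_arg_scratch_offset. */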
143 static bool
144 lower_rt_derefs(nir_shader *shader)
145 {
146    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
147 
148    bool progress = false;
149 
150    nir_builder b = nir_builder_at(nir_before_impl(impl));
151 
152    nir_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
153 
154    nir_foreach_block (block, impl) {
155       nir_foreach_instr_safe (instr, block) {
156          if (instr->type != nir_instr_type_deref)
157             continue;
158 
159          nir_deref_instr *deref = nir_instr_as_deref(instr);
160          if (!nir_deref_mode_is(deref, nir_var_shader_call_data))
161             continue;
162 
163          deref->modes = nir_var_function_temp;
164          progress = true;
165 
166          if (deref->deref_type == nir_deref_type_var) {
167             b.cursor = nir_before_instr(&deref->instr);
168             nir_deref_instr *replacement =
169                nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
170             nir_def_rewrite_uses(&deref->def, &replacement->def);
171             nir_instr_remove(&deref->instr);
172          }
173       }
174    }
175 
176    if (progress)
177       nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
178    else
179       nir_metadata_preserve(impl, nir_metadata_all);
180 
181    return progress;
182 }
183 
184 /*
185  * Global variables for an RT pipeline
186  */
187 struct rt_variables {
188    struct radv_device *device;
189    const VkPipelineCreateFlags2KHR flags;
190    bool monolithic;
191 
192    /* idx of the next shader to run in the next iteration of the main loop.
193     * During traversal, idx is used to store the SBT index and will contain
194     * the correct resume index upon returning.
195     */
196    nir_variable *idx;
197    nir_variable *shader_addr;
198    nir_variable *traversal_addr;
199 
200    /* scratch offset of the argument area relative to stack_ptr */
201    nir_variable *arg;
202    uint32_t payload_offset;
203 
204    nir_variable *stack_ptr;
205 
206    nir_variable *ahit_isec_count;
207 
208    /* global address of the SBT entry used for the shader */
209    nir_variable *shader_record_ptr;
210 
211    /* trace_ray arguments */
212    nir_variable *accel_struct;
213    nir_variable *cull_mask_and_flags;
214    nir_variable *sbt_offset;
215    nir_variable *sbt_stride;
216    nir_variable *miss_index;
217    nir_variable *origin;
218    nir_variable *tmin;
219    nir_variable *direction;
220    nir_variable *tmax;
221 
222    /* Properties of the primitive currently being visited. */
223    nir_variable *primitive_id;
224    nir_variable *geometry_id_and_flags;
225    nir_variable *instance_addr;
226    nir_variable *hit_kind;
227    nir_variable *opaque;
228 
229    /* Output variables for intersection & anyhit shaders. */
230    nir_variable *ahit_accept;
231    nir_variable *ahit_terminate;
232 
233    unsigned stack_size;
234 };
235 
236 static struct rt_variables
237 create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2KHR flags,
238                     bool monolithic)
239 {
240    struct rt_variables vars = {
241       .device = device,
242       .flags = flags,
243       .monolithic = monolithic,
244    };
245    vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
246    vars.shader_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_addr");
247    vars.traversal_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_addr");
248    vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
249    vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
250    vars.shader_record_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
251 
252    if (device->rra_trace.ray_history_addr)
253       vars.ahit_isec_count = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ahit_isec_count");
254 
255    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
256    vars.accel_struct = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
257    vars.cull_mask_and_flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask_and_flags");
258    vars.sbt_offset = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
259    vars.sbt_stride = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
260    vars.miss_index = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index");
261    vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin");
262    vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin");
263    vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
264    vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");
265 
266    vars.primitive_id = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
267    vars.geometry_id_and_flags =
268       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
269    vars.instance_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
270    vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
271    vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque");
272 
273    vars.ahit_accept = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_accept");
274    vars.ahit_terminate = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_terminate");
275 
276    return vars;
277 }
278 
279 /*
280  * Remap all the variables between the two rt_variables structs for inlining.
281  */
282 static void
283 map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, const struct rt_variables *dst)
284 {
285    _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
286    _mesa_hash_table_insert(var_remap, src->shader_addr, dst->shader_addr);
287    _mesa_hash_table_insert(var_remap, src->traversal_addr, dst->traversal_addr);
288    _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
289    _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
290    _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
291 
292    if (dst->ahit_isec_count)
293       _mesa_hash_table_insert(var_remap, src->ahit_isec_count, dst->ahit_isec_count);
294 
295    _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
296    _mesa_hash_table_insert(var_remap, src->cull_mask_and_flags, dst->cull_mask_and_flags);
297    _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset);
298    _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride);
299    _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
300    _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
301    _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin);
302    _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
303    _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
304 
305    _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
306    _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
307    _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
308    _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind);
309    _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque);
310    _mesa_hash_table_insert(var_remap, src->ahit_accept, dst->ahit_accept);
311    _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate);
312 }
313 
314 /*
315  * Create a copy of the global rt variables where the primitive/instance related variables are
316  * independent. This is needed because we need to keep the old values of the global variables around
317  * in case e.g. an any-hit shader rejects the intersection. So there are inner variables that get copied
318  * to the outer variables once we commit to a better hit.
319  */
320 static struct rt_variables
321 create_inner_vars(nir_builder *b, const struct rt_variables *vars)
322 {
323    struct rt_variables inner_vars = *vars;
324    inner_vars.idx = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
325    inner_vars.shader_record_ptr =
326       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
327    inner_vars.primitive_id =
328       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
329    inner_vars.geometry_id_and_flags =
330       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags");
331    inner_vars.tmax = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax");
332    inner_vars.instance_addr =
333       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_instance_addr");
334    inner_vars.hit_kind = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind");
335 
336    return inner_vars;
337 }
338 
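/* Pop the caller's resume address off the scratch stack and make it the next shader to run. */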
339 static void
340 insert_rt_return(nir_builder *b, const struct rt_variables *vars)
341 {
342    nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -16), 1);
343    nir_store_var(b, vars->shader_addr, nir_load_scratch(b, 1, 64, nir_load_var(b, vars->stack_ptr), .align_mul = 16),
344                  1);
345 }
346 
347 enum sbt_type {
348    SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
349    SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
350    SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
351    SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
352 };
353 
354 enum sbt_entry {
355    SBT_RECURSIVE_PTR = offsetof(struct radv_pipeline_group_handle, recursive_shader_ptr),
356    SBT_GENERAL_IDX = offsetof(struct radv_pipeline_group_handle, general_index),
357    SBT_CLOSEST_HIT_IDX = offsetof(struct radv_pipeline_group_handle, closest_hit_index),
358    SBT_INTERSECTION_IDX = offsetof(struct radv_pipeline_group_handle, intersection_index),
359    SBT_ANY_HIT_IDX = offsetof(struct radv_pipeline_group_handle, any_hit_index),
360 };
361 
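/* Compute the address of an SBT record: the table base address is read from the
 * VkTraceRaysIndirectCommand2KHR data, the stride from the table's stride field (or the record
 * size field for the raygen table), and the record index is then applied. */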
362 static nir_def *
363 get_sbt_ptr(nir_builder *b, nir_def *idx, enum sbt_type binding)
364 {
365    nir_def *desc_base_addr = nir_load_sbt_base_amd(b);
366 
367    nir_def *desc = nir_pack_64_2x32(b, nir_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding)));
368 
369    nir_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
370    nir_def *stride = nir_pack_64_2x32(b, nir_load_smem_amd(b, 2, desc_base_addr, stride_offset));
371 
372    return nir_iadd(b, desc, nir_imul(b, nir_u2u64(b, idx), stride));
373 }
374 
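/* Read one field of an SBT record: recursive shader pointers are stored in shader_addr, all other
 * entries store an index in idx. shader_record_ptr is set to the application data that follows the
 * RADV_RT_HANDLE_SIZE-byte handle. */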
375 static void
376 load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_def *idx, enum sbt_type binding,
377                enum sbt_entry offset)
378 {
379    nir_def *addr = get_sbt_ptr(b, idx, binding);
380    nir_def *load_addr = nir_iadd_imm(b, addr, offset);
381 
382    if (offset == SBT_RECURSIVE_PTR) {
383       nir_store_var(b, vars->shader_addr, nir_build_load_global(b, 1, 64, load_addr), 1);
384    } else {
385       nir_store_var(b, vars->idx, nir_build_load_global(b, 1, 32, load_addr), 1);
386    }
387 
388    nir_def *record_addr = nir_iadd_imm(b, addr, RADV_RT_HANDLE_SIZE);
389    nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
390 }
391 
392 struct radv_lower_rt_instruction_data {
393    struct rt_variables *vars;
394    bool apply_stack_ptr;
395 };
396 
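/* Lower a single RT jump/intrinsic to operations on the rt_variables: calls and traces push a
 * resume address onto the scratch stack and select the next shader, while the various load_ray_*
 * intrinsics simply read back the corresponding variable. */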
397 static bool
398 radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
399 {
400    if (instr->type == nir_instr_type_jump) {
401       nir_jump_instr *jump = nir_instr_as_jump(instr);
402       if (jump->type == nir_jump_halt) {
403          jump->type = nir_jump_return;
404          return true;
405       }
406       return false;
407    } else if (instr->type != nir_instr_type_intrinsic) {
408       return false;
409    }
410 
411    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
412 
413    struct radv_lower_rt_instruction_data *data = _data;
414    struct rt_variables *vars = data->vars;
415    bool apply_stack_ptr = data->apply_stack_ptr;
416 
417    b->cursor = nir_before_instr(&intr->instr);
418 
419    nir_def *ret = NULL;
420    switch (intr->intrinsic) {
421    case nir_intrinsic_rt_execute_callable: {
422       uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
423       nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
424       ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
425 
426       nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
427       nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
428 
429       nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
430       load_sbt_entry(b, vars, intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR);
431 
432       nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[1].ssa, -size - 16), 1);
433 
434       vars->stack_size = MAX2(vars->stack_size, size + 16);
435       break;
436    }
437    case nir_intrinsic_rt_trace_ray: {
438       uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
439       nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
440       ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
441 
442       nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
443       nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
444 
445       nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
446 
447       nir_store_var(b, vars->shader_addr, nir_load_var(b, vars->traversal_addr), 1);
448       nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[10].ssa, -size - 16), 1);
449 
450       vars->stack_size = MAX2(vars->stack_size, size + 16);
451 
452       /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
453       nir_store_var(b, vars->accel_struct, intr->src[0].ssa, 0x1);
454       nir_store_var(b, vars->cull_mask_and_flags, nir_ior(b, nir_ishl_imm(b, intr->src[2].ssa, 24), intr->src[1].ssa),
455                     0x1);
456       nir_store_var(b, vars->sbt_offset, nir_iand_imm(b, intr->src[3].ssa, 0xf), 0x1);
457       nir_store_var(b, vars->sbt_stride, nir_iand_imm(b, intr->src[4].ssa, 0xf), 0x1);
458       nir_store_var(b, vars->miss_index, nir_iand_imm(b, intr->src[5].ssa, 0xffff), 0x1);
459       nir_store_var(b, vars->origin, intr->src[6].ssa, 0x7);
460       nir_store_var(b, vars->tmin, intr->src[7].ssa, 0x1);
461       nir_store_var(b, vars->direction, intr->src[8].ssa, 0x7);
462       nir_store_var(b, vars->tmax, intr->src[9].ssa, 0x1);
463       break;
464    }
465    case nir_intrinsic_rt_resume: {
466       uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
467 
468       nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -size), 1);
469       break;
470    }
471    case nir_intrinsic_rt_return_amd: {
472       if (b->shader->info.stage == MESA_SHADER_RAYGEN) {
473          nir_terminate(b);
474          break;
475       }
476       insert_rt_return(b, vars);
477       break;
478    }
479    case nir_intrinsic_load_scratch: {
480       if (apply_stack_ptr)
481          nir_src_rewrite(&intr->src[0], nir_iadd_nuw(b, nir_load_var(b, vars->stack_ptr), intr->src[0].ssa));
482       return true;
483    }
484    case nir_intrinsic_store_scratch: {
485       if (apply_stack_ptr)
486          nir_src_rewrite(&intr->src[1], nir_iadd_nuw(b, nir_load_var(b, vars->stack_ptr), intr->src[1].ssa));
487       return true;
488    }
489    case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
490       ret = nir_load_var(b, vars->arg);
491       break;
492    }
493    case nir_intrinsic_load_shader_record_ptr: {
494       ret = nir_load_var(b, vars->shader_record_ptr);
495       break;
496    }
497    case nir_intrinsic_load_ray_t_min: {
498       ret = nir_load_var(b, vars->tmin);
499       break;
500    }
501    case nir_intrinsic_load_ray_t_max: {
502       ret = nir_load_var(b, vars->tmax);
503       break;
504    }
505    case nir_intrinsic_load_ray_world_origin: {
506       ret = nir_load_var(b, vars->origin);
507       break;
508    }
509    case nir_intrinsic_load_ray_world_direction: {
510       ret = nir_load_var(b, vars->direction);
511       break;
512    }
513    case nir_intrinsic_load_ray_instance_custom_index: {
514       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
515       nir_def *custom_instance_and_mask = nir_build_load_global(
516          b, 1, 32,
517          nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, custom_instance_and_mask)));
518       ret = nir_iand_imm(b, custom_instance_and_mask, 0xFFFFFF);
519       break;
520    }
521    case nir_intrinsic_load_primitive_id: {
522       ret = nir_load_var(b, vars->primitive_id);
523       break;
524    }
525    case nir_intrinsic_load_ray_geometry_index: {
526       ret = nir_load_var(b, vars->geometry_id_and_flags);
527       ret = nir_iand_imm(b, ret, 0xFFFFFFF);
528       break;
529    }
530    case nir_intrinsic_load_instance_id: {
531       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
532       ret = nir_build_load_global(
533          b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id)));
534       break;
535    }
536    case nir_intrinsic_load_ray_flags: {
537       ret = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFFFFFF);
538       break;
539    }
540    case nir_intrinsic_load_ray_hit_kind: {
541       ret = nir_load_var(b, vars->hit_kind);
542       break;
543    }
544    case nir_intrinsic_load_ray_world_to_object: {
545       unsigned c = nir_intrinsic_column(intr);
546       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
547       nir_def *wto_matrix[3];
548       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
549 
550       nir_def *vals[3];
551       for (unsigned i = 0; i < 3; ++i)
552          vals[i] = nir_channel(b, wto_matrix[i], c);
553 
554       ret = nir_vec(b, vals, 3);
555       break;
556    }
557    case nir_intrinsic_load_ray_object_to_world: {
558       unsigned c = nir_intrinsic_column(intr);
559       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
560       nir_def *rows[3];
561       for (unsigned r = 0; r < 3; ++r)
562          rows[r] = nir_build_load_global(
563             b, 4, 32,
564             nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
565       ret = nir_vec3(b, nir_channel(b, rows[0], c), nir_channel(b, rows[1], c), nir_channel(b, rows[2], c));
566       break;
567    }
568    case nir_intrinsic_load_ray_object_origin: {
569       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
570       nir_def *wto_matrix[3];
571       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
572       ret = nir_build_vec3_mat_mult(b, nir_load_var(b, vars->origin), wto_matrix, true);
573       break;
574    }
575    case nir_intrinsic_load_ray_object_direction: {
576       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
577       nir_def *wto_matrix[3];
578       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
579       ret = nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false);
580       break;
581    }
582    case nir_intrinsic_load_intersection_opaque_amd: {
583       ret = nir_load_var(b, vars->opaque);
584       break;
585    }
586    case nir_intrinsic_load_cull_mask: {
587       ret = nir_ushr_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 24);
588       break;
589    }
590    case nir_intrinsic_ignore_ray_intersection: {
591       nir_store_var(b, vars->ahit_accept, nir_imm_false(b), 0x1);
592 
593       /* The if is a workaround to avoid having to fix up control flow manually */
594       nir_push_if(b, nir_imm_true(b));
595       nir_jump(b, nir_jump_return);
596       nir_pop_if(b, NULL);
597       break;
598    }
599    case nir_intrinsic_terminate_ray: {
600       nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
601       nir_store_var(b, vars->ahit_terminate, nir_imm_true(b), 0x1);
602 
603       /* The if is a workaround to avoid having to fix up control flow manually */
604       nir_push_if(b, nir_imm_true(b));
605       nir_jump(b, nir_jump_return);
606       nir_pop_if(b, NULL);
607       break;
608    }
609    case nir_intrinsic_report_ray_intersection: {
610       nir_push_if(b, nir_iand(b, nir_fge(b, nir_load_var(b, vars->tmax), intr->src[0].ssa),
611                               nir_fge(b, intr->src[0].ssa, nir_load_var(b, vars->tmin))));
612       {
613          nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
614          nir_store_var(b, vars->tmax, intr->src[0].ssa, 1);
615          nir_store_var(b, vars->hit_kind, intr->src[1].ssa, 1);
616       }
617       nir_pop_if(b, NULL);
618       break;
619    }
620    case nir_intrinsic_load_sbt_offset_amd: {
621       ret = nir_load_var(b, vars->sbt_offset);
622       break;
623    }
624    case nir_intrinsic_load_sbt_stride_amd: {
625       ret = nir_load_var(b, vars->sbt_stride);
626       break;
627    }
628    case nir_intrinsic_load_accel_struct_amd: {
629       ret = nir_load_var(b, vars->accel_struct);
630       break;
631    }
632    case nir_intrinsic_load_cull_mask_and_flags_amd: {
633       ret = nir_load_var(b, vars->cull_mask_and_flags);
634       break;
635    }
636    case nir_intrinsic_execute_closest_hit_amd: {
637       nir_store_var(b, vars->tmax, intr->src[1].ssa, 0x1);
638       nir_store_var(b, vars->primitive_id, intr->src[2].ssa, 0x1);
639       nir_store_var(b, vars->instance_addr, intr->src[3].ssa, 0x1);
640       nir_store_var(b, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
641       nir_store_var(b, vars->hit_kind, intr->src[5].ssa, 0x1);
642       load_sbt_entry(b, vars, intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR);
643 
644       nir_def *should_return =
645          nir_test_mask(b, nir_load_var(b, vars->cull_mask_and_flags), SpvRayFlagsSkipClosestHitShaderKHRMask);
646 
647       if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) {
648          should_return = nir_ior(b, should_return, nir_ieq_imm(b, nir_load_var(b, vars->shader_addr), 0));
649       }
650 
651       /* should_return is set if we had a hit but we won't be calling the closest hit
652        * shader and hence need to return immediately to the calling shader. */
653       nir_push_if(b, should_return);
654       insert_rt_return(b, vars);
655       nir_pop_if(b, NULL);
656       break;
657    }
658    case nir_intrinsic_execute_miss_amd: {
659       nir_store_var(b, vars->tmax, intr->src[0].ssa, 0x1);
660       nir_def *undef = nir_undef(b, 1, 32);
661       nir_store_var(b, vars->primitive_id, undef, 0x1);
662       nir_store_var(b, vars->instance_addr, nir_undef(b, 1, 64), 0x1);
663       nir_store_var(b, vars->geometry_id_and_flags, undef, 0x1);
664       nir_store_var(b, vars->hit_kind, undef, 0x1);
665       nir_def *miss_index = nir_load_var(b, vars->miss_index);
666       load_sbt_entry(b, vars, miss_index, SBT_MISS, SBT_RECURSIVE_PTR);
667 
668       if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) {
669          /* In case of a NULL miss shader, do nothing and just return. */
670          nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->shader_addr), 0));
671          insert_rt_return(b, vars);
672          nir_pop_if(b, NULL);
673       }
674 
675       break;
676    }
677    case nir_intrinsic_load_ray_triangle_vertex_positions: {
678       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
679       nir_def *primitive_id = nir_load_var(b, vars->primitive_id);
680       ret = radv_load_vertex_position(vars->device, b, instance_node_addr, primitive_id, nir_intrinsic_column(intr));
681       break;
682    }
683    default:
684       return false;
685    }
686 
687    if (ret)
688       nir_def_rewrite_uses(&intr->def, ret);
689    nir_instr_remove(&intr->instr);
690 
691    return true;
692 }
693 
694 /* This lowers all the RT instructions that we do not want to pass on to the combined shader and
695  * that we can implement using the variables from the shader we are going to inline into. */
696 static void
697 lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool apply_stack_ptr)
698 {
699    struct radv_lower_rt_instruction_data data = {
700       .vars = vars,
701       .apply_stack_ptr = apply_stack_ptr,
702    };
703    nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data);
704 }
705 
706 /* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are
707  * lowered to shared memory. */
708 static void
709 lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workgroup_size)
710 {
711    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
712 
713    nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib)
714       attrib->data.mode = nir_var_shader_temp;
715 
716    nir_builder b = nir_builder_create(impl);
717 
718    nir_foreach_block (block, impl) {
719       nir_foreach_instr_safe (instr, block) {
720          if (instr->type != nir_instr_type_intrinsic)
721             continue;
722 
723          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
724          if (intrin->intrinsic != nir_intrinsic_load_hit_attrib_amd &&
725              intrin->intrinsic != nir_intrinsic_store_hit_attrib_amd)
726             continue;
727 
728          b.cursor = nir_after_instr(instr);
729 
730          nir_def *offset;
731          if (!hit_attribs)
732             offset = nir_imul_imm(
733                &b, nir_iadd_imm(&b, nir_load_local_invocation_index(&b), nir_intrinsic_base(intrin) * workgroup_size),
734                sizeof(uint32_t));
735 
736          if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) {
737             nir_def *ret;
738             if (hit_attribs)
739                ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]);
740             else
741                ret = nir_load_shared(&b, 1, 32, offset, .base = 0, .align_mul = 4);
742             nir_def_rewrite_uses(nir_instr_def(instr), ret);
743          } else {
744             if (hit_attribs)
745                nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1);
746             else
747                nir_store_shared(&b, intrin->src->ssa, offset, .base = 0, .align_mul = 4);
748          }
749          nir_instr_remove(instr);
750       }
751    }
752 
753    if (!hit_attribs)
754       shader->info.shared_size = MAX2(shader->info.shared_size, workgroup_size * RADV_MAX_HIT_ATTRIB_SIZE);
755 }
756 
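/* Append the constant data of src to dst (respecting the largest load_constant alignment used by
 * src) and rebase src's load_constant offsets accordingly. */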
757 static void
758 inline_constants(nir_shader *dst, nir_shader *src)
759 {
760    if (!src->constant_data_size)
761       return;
762 
763    uint32_t align_mul = 1;
764    if (dst->constant_data_size) {
765       nir_foreach_block (block, nir_shader_get_entrypoint(src)) {
766          nir_foreach_instr (instr, block) {
767             if (instr->type != nir_instr_type_intrinsic)
768                continue;
769 
770             nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
771             if (intrinsic->intrinsic == nir_intrinsic_load_constant)
772                align_mul = MAX2(align_mul, nir_intrinsic_align_mul(intrinsic));
773          }
774       }
775    }
776 
777    uint32_t old_constant_data_size = dst->constant_data_size;
778    uint32_t base_offset = align(dst->constant_data_size, align_mul);
779    dst->constant_data_size = base_offset + src->constant_data_size;
780    dst->constant_data = rerzalloc_size(dst, dst->constant_data, old_constant_data_size, dst->constant_data_size);
781    memcpy((char *)dst->constant_data + base_offset, src->constant_data, src->constant_data_size);
782 
783    if (!base_offset)
784       return;
785 
786    nir_foreach_block (block, nir_shader_get_entrypoint(src)) {
787       nir_foreach_instr (instr, block) {
788          if (instr->type != nir_instr_type_intrinsic)
789             continue;
790 
791          nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
792          if (intrinsic->intrinsic == nir_intrinsic_load_constant)
793             nir_intrinsic_set_base(intrinsic, base_offset + nir_intrinsic_base(intrinsic));
794       }
795    }
796 }
797 
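/* Inline one candidate shader, guarded by 'if (idx == call_idx)', after remapping its
 * rt_variables onto the caller's and lowering its RT instructions. */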
798 static void
799 insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_def *idx, uint32_t call_idx)
800 {
801    struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
802 
803    nir_opt_dead_cf(shader);
804 
805    struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags, vars->monolithic);
806    map_rt_variables(var_remap, &src_vars, vars);
807 
808    NIR_PASS_V(shader, lower_rt_instructions, &src_vars, false);
809 
810    NIR_PASS(_, shader, nir_lower_returns);
811    NIR_PASS(_, shader, nir_opt_dce);
812 
813    inline_constants(b->shader, shader);
814 
815    nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
816    nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
817    nir_pop_if(b, NULL);
818 
819    ralloc_free(var_remap);
820 }
821 
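/* For monolithic raygen shaders, replace the payload variable passed to trace_ray with the scratch
 * offset assigned to it (its driver_location). */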
822 static bool
823 radv_lower_payload_arg_to_offset(nir_builder *b, nir_intrinsic_instr *instr, void *data)
824 {
825    if (instr->intrinsic != nir_intrinsic_trace_ray)
826       return false;
827 
828    nir_deref_instr *payload = nir_src_as_deref(instr->src[10]);
829    assert(payload->deref_type == nir_deref_type_var);
830 
831    b->cursor = nir_before_instr(&instr->instr);
832    nir_def *offset = nir_imm_int(b, payload->var->data.driver_location);
833 
834    nir_src_rewrite(&instr->src[10], offset);
835 
836    return true;
837 }
838 
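/* Lower RT I/O: for separately compiled shaders, shader call data is moved to scratch with
 * explicit offsets; in monolithic mode, payload derefs are rewritten to explicit offsets instead. */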
839 void
840 radv_nir_lower_rt_io(nir_shader *nir, bool monolithic, uint32_t payload_offset)
841 {
842    if (!monolithic) {
843       NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
844                glsl_get_natural_size_align_bytes);
845 
846       NIR_PASS(_, nir, lower_rt_derefs);
847 
848       NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
849    } else {
850       if (nir->info.stage == MESA_SHADER_RAYGEN) {
851          /* Use nir_lower_vars_to_explicit_types to assign the payload locations. We call
852           * nir_lower_vars_to_explicit_types later after splitting the payloads.
853           */
854          uint32_t scratch_size = nir->scratch_size;
855          nir_lower_vars_to_explicit_types(nir, nir_var_function_temp, glsl_get_natural_size_align_bytes);
856          nir->scratch_size = scratch_size;
857 
858          nir_shader_intrinsics_pass(nir, radv_lower_payload_arg_to_offset,
859                                     nir_metadata_block_index | nir_metadata_dominance, NULL);
860       }
861 
862       NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, payload_offset);
863    }
864 }
865 
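/* Begin writing a ray-history token for RRA tracing: atomically reserve token_size bytes in the
 * history buffer, store the dword holding the launch index, hit flag and token type, and return
 * the address where the rest of the token should be written. Writes are skipped for invocations
 * filtered out by the resolution scale or when the buffer would overflow. */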
866 static nir_def *
867 radv_build_token_begin(nir_builder *b, struct rt_variables *vars, nir_def *hit, enum radv_packed_token_type token_type,
868                        nir_def *token_size, uint32_t max_token_size)
869 {
870    struct radv_rra_trace_data *rra_trace = &vars->device->rra_trace;
871    assert(rra_trace->ray_history_addr);
872    assert(rra_trace->ray_history_buffer_size >= max_token_size);
873 
874    nir_def *ray_history_addr = nir_imm_int64(b, rra_trace->ray_history_addr);
875 
876    nir_def *launch_id = nir_load_ray_launch_id(b);
877 
878    nir_def *trace = nir_imm_true(b);
879    for (uint32_t i = 0; i < 3; i++) {
880       nir_def *remainder = nir_umod_imm(b, nir_channel(b, launch_id, i), rra_trace->ray_history_resolution_scale);
881       trace = nir_iand(b, trace, nir_ieq_imm(b, remainder, 0));
882    }
883    nir_push_if(b, trace);
884 
885    static_assert(offsetof(struct radv_ray_history_header, offset) == 0, "Unexpected offset");
886    nir_def *base_offset = nir_global_atomic(b, 32, ray_history_addr, token_size, .atomic_op = nir_atomic_op_iadd);
887 
888    /* Abuse the dword alignment of token_size to add an invalid bit to offset. */
889    trace = nir_ieq_imm(b, nir_iand_imm(b, base_offset, 1), 0);
890 
891    nir_def *in_bounds = nir_ule_imm(b, base_offset, rra_trace->ray_history_buffer_size - max_token_size);
892    /* Make sure we don't overwrite the header in case of an overflow. */
893    in_bounds = nir_iand(b, in_bounds, nir_uge_imm(b, base_offset, sizeof(struct radv_ray_history_header)));
894 
895    nir_push_if(b, nir_iand(b, trace, in_bounds));
896 
897    nir_def *dst_addr = nir_iadd(b, ray_history_addr, nir_u2u64(b, base_offset));
898 
899    nir_def *launch_size = nir_load_ray_launch_size(b);
900 
901    nir_def *launch_id_comps[3];
902    nir_def *launch_size_comps[3];
903    for (uint32_t i = 0; i < 3; i++) {
904       launch_id_comps[i] = nir_udiv_imm(b, nir_channel(b, launch_id, i), rra_trace->ray_history_resolution_scale);
905       launch_size_comps[i] = nir_udiv_imm(b, nir_channel(b, launch_size, i), rra_trace->ray_history_resolution_scale);
906    }
907 
908    nir_def *global_index =
909       nir_iadd(b, launch_id_comps[0],
910                nir_iadd(b, nir_imul(b, launch_id_comps[1], launch_size_comps[0]),
911                         nir_imul(b, launch_id_comps[2], nir_imul(b, launch_size_comps[0], launch_size_comps[1]))));
912    nir_def *launch_index_and_hit = nir_bcsel(b, hit, nir_ior_imm(b, global_index, 1u << 29u), global_index);
913    nir_build_store_global(b, nir_ior_imm(b, launch_index_and_hit, token_type << 30), dst_addr, .align_mul = 4);
914 
915    return nir_iadd_imm(b, dst_addr, 4);
916 }
917 
918 static void
919 radv_build_token_end(nir_builder *b)
920 {
921    nir_pop_if(b, NULL);
922    nir_pop_if(b, NULL);
923 }
924 
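/* Emit an end-of-trace ray-history token describing the ray (origin, direction, t-range, flags,
 * counters) and, if something was hit, the primitive/geometry/instance information. */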
925 static void
926 radv_build_end_trace_token(nir_builder *b, struct rt_variables *vars, nir_def *tmax, nir_def *hit,
927                            nir_def *iteration_instance_count)
928 {
929    nir_def *token_size = nir_bcsel(b, hit, nir_imm_int(b, sizeof(struct radv_packed_end_trace_token)),
930                                    nir_imm_int(b, offsetof(struct radv_packed_end_trace_token, primitive_id)));
931 
932    nir_def *dst_addr = radv_build_token_begin(b, vars, hit, radv_packed_token_end_trace, token_size,
933                                               sizeof(struct radv_packed_end_trace_token));
934    {
935       nir_build_store_global(b, nir_load_var(b, vars->accel_struct), dst_addr, .align_mul = 4);
936       dst_addr = nir_iadd_imm(b, dst_addr, 8);
937 
938       nir_def *dispatch_indices =
939          nir_load_smem_amd(b, 2, nir_imm_int64(b, vars->device->rra_trace.ray_history_addr),
940                            nir_imm_int(b, offsetof(struct radv_ray_history_header, dispatch_index)), .align_mul = 4);
941       nir_def *dispatch_index = nir_iadd(b, nir_channel(b, dispatch_indices, 0), nir_channel(b, dispatch_indices, 1));
942       nir_def *dispatch_and_flags = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFFFF);
943       dispatch_and_flags = nir_ior(b, dispatch_and_flags, dispatch_index);
944       nir_build_store_global(b, dispatch_and_flags, dst_addr, .align_mul = 4);
945       dst_addr = nir_iadd_imm(b, dst_addr, 4);
946 
947       nir_def *shifted_cull_mask = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFF000000);
948 
949       nir_def *packed_args = nir_load_var(b, vars->sbt_offset);
950       packed_args = nir_ior(b, packed_args, nir_ishl_imm(b, nir_load_var(b, vars->sbt_stride), 4));
951       packed_args = nir_ior(b, packed_args, nir_ishl_imm(b, nir_load_var(b, vars->miss_index), 8));
952       packed_args = nir_ior(b, packed_args, shifted_cull_mask);
953       nir_build_store_global(b, packed_args, dst_addr, .align_mul = 4);
954       dst_addr = nir_iadd_imm(b, dst_addr, 4);
955 
956       nir_build_store_global(b, nir_load_var(b, vars->origin), dst_addr, .align_mul = 4);
957       dst_addr = nir_iadd_imm(b, dst_addr, 12);
958 
959       nir_build_store_global(b, nir_load_var(b, vars->tmin), dst_addr, .align_mul = 4);
960       dst_addr = nir_iadd_imm(b, dst_addr, 4);
961 
962       nir_build_store_global(b, nir_load_var(b, vars->direction), dst_addr, .align_mul = 4);
963       dst_addr = nir_iadd_imm(b, dst_addr, 12);
964 
965       nir_build_store_global(b, tmax, dst_addr, .align_mul = 4);
966       dst_addr = nir_iadd_imm(b, dst_addr, 4);
967 
968       nir_build_store_global(b, iteration_instance_count, dst_addr, .align_mul = 4);
969       dst_addr = nir_iadd_imm(b, dst_addr, 4);
970 
971       nir_build_store_global(b, nir_load_var(b, vars->ahit_isec_count), dst_addr, .align_mul = 4);
972       dst_addr = nir_iadd_imm(b, dst_addr, 4);
973 
974       nir_push_if(b, hit);
975       {
976          nir_build_store_global(b, nir_load_var(b, vars->primitive_id), dst_addr, .align_mul = 4);
977          dst_addr = nir_iadd_imm(b, dst_addr, 4);
978 
979          nir_def *geometry_id = nir_iand_imm(b, nir_load_var(b, vars->geometry_id_and_flags), 0xFFFFFFF);
980          nir_build_store_global(b, geometry_id, dst_addr, .align_mul = 4);
981          dst_addr = nir_iadd_imm(b, dst_addr, 4);
982 
983          nir_def *instance_id_and_hit_kind =
984             nir_build_load_global(b, 1, 32,
985                                   nir_iadd_imm(b, nir_load_var(b, vars->instance_addr),
986                                                offsetof(struct radv_bvh_instance_node, instance_id)));
987          instance_id_and_hit_kind =
988             nir_ior(b, instance_id_and_hit_kind, nir_ishl_imm(b, nir_load_var(b, vars->hit_kind), 24));
989          nir_build_store_global(b, instance_id_and_hit_kind, dst_addr, .align_mul = 4);
990          dst_addr = nir_iadd_imm(b, dst_addr, 4);
991 
992          nir_build_store_global(b, nir_load_var(b, vars->tmax), dst_addr, .align_mul = 4);
993          dst_addr = nir_iadd_imm(b, dst_addr, 4);
994       }
995       nir_pop_if(b, NULL);
996    }
997    radv_build_token_end(b);
998 }
999 
1000 static nir_function_impl *
1001 lower_any_hit_for_intersection(nir_shader *any_hit)
1002 {
1003    nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);
1004 
1005    /* Any-hit shaders need four parameters */
1006    assert(impl->function->num_params == 0);
1007    nir_parameter params[] = {
1008       {
1009          /* A pointer to a boolean value for whether or not the hit was
1010           * accepted.
1011           */
1012          .num_components = 1,
1013          .bit_size = 32,
1014       },
1015       {
1016          /* The hit T value */
1017          .num_components = 1,
1018          .bit_size = 32,
1019       },
1020       {
1021          /* The hit kind */
1022          .num_components = 1,
1023          .bit_size = 32,
1024       },
1025       {
1026          /* Scratch offset */
1027          .num_components = 1,
1028          .bit_size = 32,
1029       },
1030    };
1031    impl->function->num_params = ARRAY_SIZE(params);
1032    impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
1033    memcpy(impl->function->params, params, sizeof(params));
1034 
1035    nir_builder build = nir_builder_at(nir_before_impl(impl));
1036    nir_builder *b = &build;
1037 
1038    nir_def *commit_ptr = nir_load_param(b, 0);
1039    nir_def *hit_t = nir_load_param(b, 1);
1040    nir_def *hit_kind = nir_load_param(b, 2);
1041    nir_def *scratch_offset = nir_load_param(b, 3);
1042 
1043    nir_deref_instr *commit = nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);
1044 
1045    nir_foreach_block_safe (block, impl) {
1046       nir_foreach_instr_safe (instr, block) {
1047          switch (instr->type) {
1048          case nir_instr_type_intrinsic: {
1049             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1050             switch (intrin->intrinsic) {
1051             case nir_intrinsic_ignore_ray_intersection:
1052                b->cursor = nir_instr_remove(&intrin->instr);
1053                /* We put the newly emitted code inside a dummy if because it's
1054                 * going to contain a jump instruction and we don't want to
1055                 * deal with that mess here.  It'll get dealt with by our
1056                 * control-flow optimization passes.
1057                 */
1058                nir_store_deref(b, commit, nir_imm_false(b), 0x1);
1059                nir_push_if(b, nir_imm_true(b));
1060                nir_jump(b, nir_jump_return);
1061                nir_pop_if(b, NULL);
1062                break;
1063 
1064             case nir_intrinsic_terminate_ray:
1065                /* The "normal" handling of terminateRay works fine in
1066                 * intersection shaders.
1067                 */
1068                break;
1069 
1070             case nir_intrinsic_load_ray_t_max:
1071                nir_def_rewrite_uses(&intrin->def, hit_t);
1072                nir_instr_remove(&intrin->instr);
1073                break;
1074 
1075             case nir_intrinsic_load_ray_hit_kind:
1076                nir_def_rewrite_uses(&intrin->def, hit_kind);
1077                nir_instr_remove(&intrin->instr);
1078                break;
1079 
1080             /* We place all any_hit scratch variables after intersection scratch variables.
1081              * For that reason, we increment the scratch offset by the intersection scratch
1082              * size. For call_data, we have to subtract the offset again.
1083              *
1084              * Note that we don't increase the scratch size as it is already reflected via
1085              * the any_hit stack_size.
1086              */
1087             case nir_intrinsic_load_scratch:
1088                b->cursor = nir_before_instr(instr);
1089                nir_src_rewrite(&intrin->src[0], nir_iadd_nuw(b, scratch_offset, intrin->src[0].ssa));
1090                break;
1091             case nir_intrinsic_store_scratch:
1092                b->cursor = nir_before_instr(instr);
1093                nir_src_rewrite(&intrin->src[1], nir_iadd_nuw(b, scratch_offset, intrin->src[1].ssa));
1094                break;
1095             case nir_intrinsic_load_rt_arg_scratch_offset_amd:
1096                b->cursor = nir_after_instr(instr);
1097                nir_def *arg_offset = nir_isub(b, &intrin->def, scratch_offset);
1098                nir_def_rewrite_uses_after(&intrin->def, arg_offset, arg_offset->parent_instr);
1099                break;
1100 
1101             default:
1102                break;
1103             }
1104             break;
1105          }
1106          case nir_instr_type_jump: {
1107             nir_jump_instr *jump = nir_instr_as_jump(instr);
1108             if (jump->type == nir_jump_halt) {
1109                b->cursor = nir_instr_remove(instr);
1110                nir_jump(b, nir_jump_return);
1111             }
1112             break;
1113          }
1114 
1115          default:
1116             break;
1117          }
1118       }
1119    }
1120 
1121    nir_validate_shader(any_hit, "after initial any-hit lowering");
1122 
1123    nir_lower_returns_impl(impl);
1124 
1125    nir_validate_shader(any_hit, "after lowering returns");
1126 
1127    return impl;
1128 }
1129 
1130 /* Inline the any_hit shader into the intersection shader so we don't have
1131  * to implement yet another shader call interface here, nor deal with any recursion.
1132  */
1133 static void
1134 nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
1135 {
1136    void *dead_ctx = ralloc_context(intersection);
1137 
1138    nir_function_impl *any_hit_impl = NULL;
1139    struct hash_table *any_hit_var_remap = NULL;
1140    if (any_hit) {
1141       any_hit = nir_shader_clone(dead_ctx, any_hit);
1142       NIR_PASS(_, any_hit, nir_opt_dce);
1143 
1144       inline_constants(intersection, any_hit);
1145 
1146       any_hit_impl = lower_any_hit_for_intersection(any_hit);
1147       any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
1148    }
1149 
1150    nir_function_impl *impl = nir_shader_get_entrypoint(intersection);
1151 
1152    nir_builder build = nir_builder_create(impl);
1153    nir_builder *b = &build;
1154 
1155    b->cursor = nir_before_impl(impl);
1156 
1157    nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
1158    nir_store_var(b, commit, nir_imm_false(b), 0x1);
1159 
1160    nir_foreach_block_safe (block, impl) {
1161       nir_foreach_instr_safe (instr, block) {
1162          if (instr->type != nir_instr_type_intrinsic)
1163             continue;
1164 
1165          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1166          if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
1167             continue;
1168 
1169          b->cursor = nir_instr_remove(&intrin->instr);
1170          nir_def *hit_t = intrin->src[0].ssa;
1171          nir_def *hit_kind = intrin->src[1].ssa;
1172          nir_def *min_t = nir_load_ray_t_min(b);
1173          nir_def *max_t = nir_load_ray_t_max(b);
1174 
1175          /* bool commit_tmp = false; */
1176          nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
1177          nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);
1178 
1179          nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
1180          {
1181             /* Any-hit defaults to commit */
1182             nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);
1183 
1184             if (any_hit_impl != NULL) {
1185                nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
1186                {
1187                   nir_def *params[] = {
1188                      &nir_build_deref_var(b, commit_tmp)->def,
1189                      hit_t,
1190                      hit_kind,
1191                      nir_imm_int(b, intersection->scratch_size),
1192                   };
1193                   nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
1194                }
1195                nir_pop_if(b, NULL);
1196             }
1197 
1198             nir_push_if(b, nir_load_var(b, commit_tmp));
1199             {
1200                nir_report_ray_intersection(b, 1, hit_t, hit_kind);
1201             }
1202             nir_pop_if(b, NULL);
1203          }
1204          nir_pop_if(b, NULL);
1205 
1206          nir_def *accepted = nir_load_var(b, commit_tmp);
1207          nir_def_rewrite_uses(&intrin->def, accepted);
1208       }
1209    }
1210    nir_metadata_preserve(impl, nir_metadata_none);
1211 
1212    /* We did some inlining; have to re-index SSA defs */
1213    nir_index_ssa_defs(impl);
1214 
1215    /* Eliminate the casts introduced for the commit return of the any-hit shader. */
1216    NIR_PASS(_, intersection, nir_opt_deref);
1217 
1218    ralloc_free(dead_ctx);
1219 }
1220 
1221 /* Variables only used internally to ray traversal. This is data that describes
1222  * the current state of the traversal vs. what we'd give to a shader.  e.g. what
1223  * is the instance we're currently visiting vs. what is the instance of the
1224  * closest hit. */
1225 struct rt_traversal_vars {
1226    nir_variable *origin;
1227    nir_variable *dir;
1228    nir_variable *inv_dir;
1229    nir_variable *sbt_offset_and_flags;
1230    nir_variable *instance_addr;
1231    nir_variable *hit;
1232    nir_variable *bvh_base;
1233    nir_variable *stack;
1234    nir_variable *top_stack;
1235    nir_variable *stack_low_watermark;
1236    nir_variable *current_node;
1237    nir_variable *previous_node;
1238    nir_variable *instance_top_node;
1239    nir_variable *instance_bottom_node;
1240 };
1241 
1242 static struct rt_traversal_vars
1243 init_traversal_vars(nir_builder *b)
1244 {
1245    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1246    struct rt_traversal_vars ret;
1247 
1248    ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
1249    ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
1250    ret.inv_dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
1251    ret.sbt_offset_and_flags =
1252       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_sbt_offset_and_flags");
1253    ret.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
1254    ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit");
1255    ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_bvh_base");
1256    ret.stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
1257    ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_top_stack_ptr");
1258    ret.stack_low_watermark =
1259       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_low_watermark");
1260    ret.current_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "current_node");
1261    ret.previous_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "previous_node");
1262    ret.instance_top_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_top_node");
1263    ret.instance_bottom_node =
1264       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_bottom_node");
1265    return ret;
1266 }
1267 
1268 struct traversal_data {
1269    struct radv_device *device;
1270    struct rt_variables *vars;
1271    struct rt_traversal_vars *trav_vars;
1272    nir_variable *barycentrics;
1273 
1274    struct radv_ray_tracing_pipeline *pipeline;
1275 };
1276 
1277 static void
1278 radv_ray_tracing_group_ahit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
1279                                  struct radv_rt_case_data *data)
1280 {
1281    if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR) {
1282       *shader_index = group->any_hit_shader;
1283       *handle_index = group->handle.any_hit_index;
1284    }
1285 }
1286 
1287 static void
1288 radv_build_ahit_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
1289                      struct radv_rt_case_data *data)
1290 {
1291    nir_shader *nir_stage =
1292       radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
1293    assert(nir_stage);
1294 
1295    radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1296 
1297    insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.any_hit_index);
1298    ralloc_free(nir_stage);
1299 }
1300 
1301 static void
1302 radv_ray_tracing_group_isec_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
1303                                  struct radv_rt_case_data *data)
1304 {
1305    if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR) {
1306       *shader_index = group->intersection_shader;
1307       *handle_index = group->handle.intersection_index;
1308    }
1309 }
1310 
1311 static void
1312 radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
1313                      struct radv_rt_case_data *data)
1314 {
1315    nir_shader *nir_stage =
1316       radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->intersection_shader].nir);
1317    assert(nir_stage);
1318 
1319    radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1320 
1321    nir_shader *any_hit_stage = NULL;
1322    if (group->any_hit_shader != VK_SHADER_UNUSED_KHR) {
1323       any_hit_stage =
1324          radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
1325       assert(any_hit_stage);
1326 
1327       radv_nir_lower_rt_io(any_hit_stage, data->vars->monolithic, data->vars->payload_offset);
1328 
1329       /* reserve stack size for any_hit before it is inlined */
1330       data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size;
1331 
1332       nir_lower_intersection_shader(nir_stage, any_hit_stage);
1333       ralloc_free(any_hit_stage);
1334    }
1335 
1336    insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.intersection_index);
1337    ralloc_free(nir_stage);
1338 }
1339 
1340 static void
1341 radv_ray_tracing_group_chit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
1342                                  struct radv_rt_case_data *data)
1343 {
1344    if (group->type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
1345       *shader_index = group->recursive_shader;
1346       *handle_index = group->handle.closest_hit_index;
1347    }
1348 }
1349 
1350 static void
1351 radv_ray_tracing_group_miss_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
1352                                  struct radv_rt_case_data *data)
1353 {
1354    if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
1355       if (data->pipeline->stages[group->recursive_shader].stage != MESA_SHADER_MISS)
1356          return;
1357 
1358       *shader_index = group->recursive_shader;
1359       *handle_index = group->handle.general_index;
1360    }
1361 }
1362 
1363 static void
1364 radv_build_recursive_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
1365                           struct radv_rt_case_data *data)
1366 {
1367    nir_shader *nir_stage =
1368       radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->recursive_shader].nir);
1369    assert(nir_stage);
1370 
1371    radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1372 
1373    insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.general_index);
1374    ralloc_free(nir_stage);
1375 }
1376 
1377 static void
1378 handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
1379                           const struct radv_ray_traversal_args *args, const struct radv_ray_flags *ray_flags)
1380 {
1381    struct traversal_data *data = args->data;
1382 
1383    nir_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
1384    nir_def *sbt_idx =
1385       nir_iadd(b,
1386                nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
1387                         nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
1388                nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
1389 
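   /* SPIR-V hit kinds for triangles: 0xFE = HitKindFrontFacingTriangleKHR,
    * 0xFF = HitKindBackFacingTriangleKHR. */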
1390    nir_def *hit_kind = nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
1391 
1392    nir_def *prev_barycentrics = nir_load_var(b, data->barycentrics);
1393    nir_store_var(b, data->barycentrics, intersection->barycentrics, 0x3);
1394 
1395    nir_store_var(b, data->vars->ahit_accept, nir_imm_true(b), 0x1);
1396    nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
1397 
1398    nir_push_if(b, nir_inot(b, intersection->base.opaque));
1399    {
1400       struct rt_variables inner_vars = create_inner_vars(b, data->vars);
1401 
1402       nir_store_var(b, inner_vars.primitive_id, intersection->base.primitive_id, 1);
1403       nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
1404       nir_store_var(b, inner_vars.tmax, intersection->t, 0x1);
1405       nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1406       nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
1407 
1408       load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
1409 
1410       struct radv_rt_case_data case_data = {
1411          .device = data->device,
1412          .pipeline = data->pipeline,
1413          .vars = &inner_vars,
1414       };
1415 
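      /* Ray-history counters for RRA traces: any-hit invocations are counted in the low 16 bits,
       * intersection-shader invocations in the high 16 bits (see handle_candidate_aabb). */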
1416       if (data->vars->ahit_isec_count)
1417          nir_store_var(b, data->vars->ahit_isec_count, nir_iadd_imm(b, nir_load_var(b, data->vars->ahit_isec_count), 1),
1418                        0x1);
1419 
1420       radv_visit_inlined_shaders(
1421          b, nir_load_var(b, inner_vars.idx),
1422          !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR), &case_data,
1423          radv_ray_tracing_group_ahit_info, radv_build_ahit_case);
1424 
1425       nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
1426       {
1427          nir_store_var(b, data->barycentrics, prev_barycentrics, 0x3);
1428          nir_jump(b, nir_jump_continue);
1429       }
1430       nir_pop_if(b, NULL);
1431    }
1432    nir_pop_if(b, NULL);
1433 
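   /* Commit the candidate: copy its state into the result variables and shrink the ray
    * interval by updating tmax to the candidate's t. */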
1434    nir_store_var(b, data->vars->primitive_id, intersection->base.primitive_id, 1);
1435    nir_store_var(b, data->vars->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
1436    nir_store_var(b, data->vars->tmax, intersection->t, 0x1);
1437    nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1438    nir_store_var(b, data->vars->hit_kind, hit_kind, 0x1);
1439 
1440    nir_store_var(b, data->vars->idx, sbt_idx, 1);
1441    nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
1442 
1443    nir_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
1444    nir_push_if(b, nir_ior(b, ray_flags->terminate_on_first_hit, ray_terminated));
1445    {
1446       nir_jump(b, nir_jump_break);
1447    }
1448    nir_pop_if(b, NULL);
1449 }
1450 
1451 static void
1452 handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersection,
1453                       const struct radv_ray_traversal_args *args)
1454 {
1455    struct traversal_data *data = args->data;
1456 
1457    nir_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
1458    nir_def *sbt_idx =
1459       nir_iadd(b,
1460                nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
1461                         nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
1462                nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
1463 
1464    struct rt_variables inner_vars = create_inner_vars(b, data->vars);
1465 
1466    /* For AABBs, the intersection shader writes the hit kind, and only does so if the
1467     * candidate is the next closest hit. */
1468    inner_vars.hit_kind = data->vars->hit_kind;
1469 
1470    nir_store_var(b, inner_vars.primitive_id, intersection->primitive_id, 1);
1471    nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
1472    nir_store_var(b, inner_vars.tmax, nir_load_var(b, data->vars->tmax), 0x1);
1473    nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1474    nir_store_var(b, inner_vars.opaque, intersection->opaque, 1);
1475 
1476    load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX);
1477 
1478    nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1);
1479    nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
1480 
1481    if (data->vars->ahit_isec_count)
1482       nir_store_var(b, data->vars->ahit_isec_count,
1483                     nir_iadd_imm(b, nir_load_var(b, data->vars->ahit_isec_count), 1 << 16), 0x1);
1484 
1485    struct radv_rt_case_data case_data = {
1486       .device = data->device,
1487       .pipeline = data->pipeline,
1488       .vars = &inner_vars,
1489    };
1490 
1491    radv_visit_inlined_shaders(
1492       b, nir_load_var(b, inner_vars.idx),
1493       !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR), &case_data,
1494       radv_ray_tracing_group_isec_info, radv_build_isec_case);
1495 
1496    nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
1497    {
1498       nir_store_var(b, data->vars->primitive_id, intersection->primitive_id, 1);
1499       nir_store_var(b, data->vars->geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
1500       nir_store_var(b, data->vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
1501       nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1502 
1503       nir_store_var(b, data->vars->idx, sbt_idx, 1);
1504       nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
1505 
1506       nir_def *terminate_on_first_hit = nir_test_mask(b, args->flags, SpvRayFlagsTerminateOnFirstHitKHRMask);
1507       nir_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
1508       nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
1509       {
1510          nir_jump(b, nir_jump_break);
1511       }
1512       nir_pop_if(b, NULL);
1513    }
1514    nir_pop_if(b, NULL);
1515 }
1516 
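/* The traversal short-stack lives in LDS. The entries of the lanes of a wave are interleaved
 * (stack_stride is wave_size * 4 bytes, see radv_build_traversal), so the index passed in here
 * is already a per-lane byte offset into shared memory. */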
1517 static void
1518 store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct radv_ray_traversal_args *args)
1519 {
1520    nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
1521 }
1522 
1523 static nir_def *
1524 load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal_args *args)
1525 {
1526    return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
1527 }
1528 
1529 static void
1530 radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1531                      const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, bool monolithic, nir_builder *b,
1532                      struct rt_variables *vars, bool ignore_cull_mask)
1533 {
1534    nir_variable *barycentrics =
1535       nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
1536    barycentrics->data.driver_location = 0;
1537 
1538    struct rt_traversal_vars trav_vars = init_traversal_vars(b);
1539 
1540    nir_store_var(b, trav_vars.hit, nir_imm_false(b), 1);
1541 
1542    nir_def *accel_struct = nir_load_var(b, vars->accel_struct);
1543    nir_def *bvh_offset = nir_build_load_global(
1544       b, 1, 32, nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
1545       .access = ACCESS_NON_WRITEABLE);
1546    nir_def *root_bvh_base = nir_iadd(b, accel_struct, nir_u2u64(b, bvh_offset));
1547    root_bvh_base = build_addr_to_node(b, root_bvh_base);
1548 
1549    nir_store_var(b, trav_vars.bvh_base, root_bvh_base, 1);
1550 
1551    nir_def *vec3ones = nir_imm_vec3(b, 1.0, 1.0, 1.0);
1552 
1553    nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
1554    nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
1555    nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
1556    nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
1557    nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
1558 
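   /* Per-lane interleaved LDS stack: entry i of lane l sits at byte offset
    * (i * wave_size + l) * sizeof(uint32_t); the base below is the lane's slot 0. */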
1559    nir_store_var(b, trav_vars.stack, nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t)), 1);
1560    nir_store_var(b, trav_vars.stack_low_watermark, nir_load_var(b, trav_vars.stack), 1);
1561    nir_store_var(b, trav_vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
1562    nir_store_var(b, trav_vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
1563    nir_store_var(b, trav_vars.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
1564    nir_store_var(b, trav_vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
1565 
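   /* top_stack marks the stack level where traversal switched from the top-level BVH into an
    * instance; -1 means no instance has been entered yet. */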
1566    nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, -1), 1);
1567 
1568    struct radv_ray_traversal_vars trav_vars_args = {
1569       .tmax = nir_build_deref_var(b, vars->tmax),
1570       .origin = nir_build_deref_var(b, trav_vars.origin),
1571       .dir = nir_build_deref_var(b, trav_vars.dir),
1572       .inv_dir = nir_build_deref_var(b, trav_vars.inv_dir),
1573       .bvh_base = nir_build_deref_var(b, trav_vars.bvh_base),
1574       .stack = nir_build_deref_var(b, trav_vars.stack),
1575       .top_stack = nir_build_deref_var(b, trav_vars.top_stack),
1576       .stack_low_watermark = nir_build_deref_var(b, trav_vars.stack_low_watermark),
1577       .current_node = nir_build_deref_var(b, trav_vars.current_node),
1578       .previous_node = nir_build_deref_var(b, trav_vars.previous_node),
1579       .instance_top_node = nir_build_deref_var(b, trav_vars.instance_top_node),
1580       .instance_bottom_node = nir_build_deref_var(b, trav_vars.instance_bottom_node),
1581       .instance_addr = nir_build_deref_var(b, trav_vars.instance_addr),
1582       .sbt_offset_and_flags = nir_build_deref_var(b, trav_vars.sbt_offset_and_flags),
1583    };
1584 
1585    nir_variable *iteration_instance_count = NULL;
1586    if (vars->device->rra_trace.ray_history_addr) {
1587       iteration_instance_count =
1588          nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "iteration_instance_count");
1589       nir_store_var(b, iteration_instance_count, nir_imm_int(b, 0), 0x1);
1590       trav_vars_args.iteration_instance_count = nir_build_deref_var(b, iteration_instance_count);
1591 
1592       nir_store_var(b, vars->ahit_isec_count, nir_imm_int(b, 0), 0x1);
1593    }
1594 
1595    struct traversal_data data = {
1596       .device = device,
1597       .vars = vars,
1598       .trav_vars = &trav_vars,
1599       .barycentrics = barycentrics,
1600       .pipeline = pipeline,
1601    };
1602 
1603    nir_def *cull_mask_and_flags = nir_load_var(b, vars->cull_mask_and_flags);
1604    struct radv_ray_traversal_args args = {
1605       .root_bvh_base = root_bvh_base,
1606       .flags = cull_mask_and_flags,
1607       .cull_mask = cull_mask_and_flags,
1608       .origin = nir_load_var(b, vars->origin),
1609       .tmin = nir_load_var(b, vars->tmin),
1610       .dir = nir_load_var(b, vars->direction),
1611       .vars = trav_vars_args,
1612       .stack_stride = device->physical_device->rt_wave_size * sizeof(uint32_t),
1613       .stack_entries = MAX_STACK_ENTRY_COUNT,
1614       .stack_base = 0,
1615       .ignore_cull_mask = ignore_cull_mask,
1616       .stack_store_cb = store_stack_entry,
1617       .stack_load_cb = load_stack_entry,
1618       .aabb_cb = (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_AABBS_BIT_KHR)
1619                     ? NULL
1620                     : handle_candidate_aabb,
1621       .triangle_cb = (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)
1622                         ? NULL
1623                         : handle_candidate_triangle,
1624       .data = &data,
1625    };
1626 
1627    nir_def *original_tmax = nir_load_var(b, vars->tmax);
1628 
1629    radv_build_ray_traversal(device, b, &args);
1630 
1631    if (vars->device->rra_trace.ray_history_addr)
1632       radv_build_end_trace_token(b, vars, original_tmax, nir_load_var(b, trav_vars.hit),
1633                                  nir_load_var(b, iteration_instance_count));
1634 
1635    nir_metadata_preserve(nir_shader_get_entrypoint(b->shader), nir_metadata_none);
1636    radv_nir_lower_hit_attrib_derefs(b->shader);
1637 
1638    /* Register storage for hit attributes */
1639    nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_DWORDS];
1640 
1641    if (!monolithic) {
1642       for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
1643          hit_attribs[i] =
1644             nir_local_variable_create(nir_shader_get_entrypoint(b->shader), glsl_uint_type(), "ahit_attrib");
1645 
1646       lower_hit_attribs(b->shader, hit_attribs, device->physical_device->rt_wave_size);
1647    }
1648 
1649    /* Initialize follow-up shader. */
1650    nir_push_if(b, nir_load_var(b, trav_vars.hit));
1651    {
1652       if (monolithic) {
1653          load_sbt_entry(b, vars, nir_load_var(b, vars->idx), SBT_HIT, SBT_CLOSEST_HIT_IDX);
1654 
1655          nir_def *should_return =
1656             nir_test_mask(b, nir_load_var(b, vars->cull_mask_and_flags), SpvRayFlagsSkipClosestHitShaderKHRMask);
1657 
1658          /* should_return is set if we had a hit but we won't be calling the closest hit
1659           * shader and hence need to return immediately to the calling shader. */
1660          nir_push_if(b, nir_inot(b, should_return));
1661 
1662          struct radv_rt_case_data case_data = {
1663             .device = device,
1664             .pipeline = pipeline,
1665             .vars = vars,
1666          };
1667 
1668          radv_visit_inlined_shaders(
1669             b, nir_load_var(b, vars->idx),
1670             !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR), &case_data,
1671             radv_ray_tracing_group_chit_info, radv_build_recursive_case);
1672 
1673          nir_pop_if(b, NULL);
1674       } else {
1675          for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
1676             nir_store_hit_attrib_amd(b, nir_load_var(b, hit_attribs[i]), .base = i);
1677          nir_execute_closest_hit_amd(b, nir_load_var(b, vars->idx), nir_load_var(b, vars->tmax),
1678                                      nir_load_var(b, vars->primitive_id), nir_load_var(b, vars->instance_addr),
1679                                      nir_load_var(b, vars->geometry_id_and_flags), nir_load_var(b, vars->hit_kind));
1680       }
1681    }
1682    nir_push_else(b, NULL);
1683    {
1684       if (monolithic) {
1685          load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, SBT_GENERAL_IDX);
1686 
1687          struct radv_rt_case_data case_data = {
1688             .device = device,
1689             .pipeline = pipeline,
1690             .vars = vars,
1691          };
1692 
1693          radv_visit_inlined_shaders(b, nir_load_var(b, vars->idx),
1694                                     !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR),
1695                                     &case_data, radv_ray_tracing_group_miss_info, radv_build_recursive_case);
1696       } else {
1697          /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
1698           * for miss shaders if none of the rays miss. */
1699          nir_execute_miss_amd(b, nir_load_var(b, vars->tmax));
1700       }
1701    }
1702    nir_pop_if(b, NULL);
1703 }
1704 
1705 nir_shader *
1706 radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1707                             const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
1708 {
1709    const VkPipelineCreateFlagBits2KHR create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
1710 
1711    /* Create the traversal shader as an intersection shader to prevent validation failures due to
1712     * invalid variable modes. */
1713    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
1714    b.shader->info.internal = false;
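   /* One wave per workgroup (8x8 = 64 invocations for wave64, 8x4 = 32 for wave32); LDS provides
    * MAX_STACK_ENTRY_COUNT traversal stack entries per lane. */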
1715    b.shader->info.workgroup_size[0] = 8;
1716    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
1717    b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
1718    struct rt_variables vars = create_rt_variables(b.shader, device, create_flags, false);
1719 
1720    /* initialize trace_ray arguments */
1721    nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1);
1722    nir_store_var(&b, vars.cull_mask_and_flags, nir_load_cull_mask_and_flags_amd(&b), 0x1);
1723    nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
1724    nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
1725    nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
1726    nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
1727    nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
1728    nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
1729    nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
1730    nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
1731 
1732    radv_build_traversal(device, pipeline, pCreateInfo, false, &b, &vars, false);
1733 
1734    /* Deal with all the inline functions. */
1735    nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
1736    nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
1737 
1738    /* Lower and cleanup variables */
1739    NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
1740    NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
1741 
1742    return b.shader;
1743 }
1744 
1745 struct lower_rt_instruction_monolithic_state {
1746    struct radv_device *device;
1747    struct radv_ray_tracing_pipeline *pipeline;
1748    const VkRayTracingPipelineCreateInfoKHR *pCreateInfo;
1749 
1750    struct rt_variables *vars;
1751 };
1752 
1753 static bool
1754 lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data)
1755 {
1756    if (instr->type != nir_instr_type_intrinsic)
1757       return false;
1758 
1759    b->cursor = nir_after_instr(instr);
1760 
1761    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1762 
1763    struct lower_rt_instruction_monolithic_state *state = data;
1764    struct rt_variables *vars = state->vars;
1765 
1766    switch (intr->intrinsic) {
1767    case nir_intrinsic_execute_callable:
1768       unreachable("nir_intrinsic_execute_callable");
1769    case nir_intrinsic_trace_ray: {
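      /* traceRayEXT operands, in the order consumed below: src[0] acceleration structure,
       * src[1] ray flags, src[2] cull mask, src[3] SBT offset, src[4] SBT stride,
       * src[5] miss index, src[6] origin, src[7] tmin, src[8] direction, src[9] tmax,
       * src[10] payload. */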
1770       vars->payload_offset = nir_src_as_uint(intr->src[10]);
1771 
1772       nir_src cull_mask = intr->src[2];
1773       bool ignore_cull_mask = nir_src_is_const(cull_mask) && (nir_src_as_uint(cull_mask) & 0xFF) == 0xFF;
1774 
1775       /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
1776       nir_store_var(b, vars->accel_struct, intr->src[0].ssa, 0x1);
1777       nir_store_var(b, vars->cull_mask_and_flags, nir_ior(b, nir_ishl_imm(b, cull_mask.ssa, 24), intr->src[1].ssa),
1778                     0x1);
1779       nir_store_var(b, vars->sbt_offset, nir_iand_imm(b, intr->src[3].ssa, 0xf), 0x1);
1780       nir_store_var(b, vars->sbt_stride, nir_iand_imm(b, intr->src[4].ssa, 0xf), 0x1);
1781       nir_store_var(b, vars->miss_index, nir_iand_imm(b, intr->src[5].ssa, 0xffff), 0x1);
1782       nir_store_var(b, vars->origin, intr->src[6].ssa, 0x7);
1783       nir_store_var(b, vars->tmin, intr->src[7].ssa, 0x1);
1784       nir_store_var(b, vars->direction, intr->src[8].ssa, 0x7);
1785       nir_store_var(b, vars->tmax, intr->src[9].ssa, 0x1);
1786 
1787       nir_def *stack_ptr = nir_load_var(b, vars->stack_ptr);
1788       nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, stack_ptr, b->shader->scratch_size), 0x1);
1789 
1790       radv_build_traversal(state->device, state->pipeline, state->pCreateInfo, true, b, vars, ignore_cull_mask);
1791       b->shader->info.shared_size = MAX2(b->shader->info.shared_size, state->device->physical_device->rt_wave_size *
1792                                                                          MAX_STACK_ENTRY_COUNT * sizeof(uint32_t));
1793 
1794       nir_store_var(b, vars->stack_ptr, stack_ptr, 0x1);
1795 
1796       nir_instr_remove(instr);
1797       return true;
1798    }
1799    case nir_intrinsic_rt_resume:
1800       unreachable("nir_intrinsic_rt_resume");
1801    case nir_intrinsic_rt_return_amd:
1802       unreachable("nir_intrinsic_rt_return_amd");
1803    case nir_intrinsic_execute_closest_hit_amd:
1804       unreachable("nir_intrinsic_execute_closest_hit_amd");
1805    case nir_intrinsic_execute_miss_amd:
1806       unreachable("nir_intrinsic_execute_miss_amd");
1807    default:
1808       return false;
1809    }
1810 }
1811 
1812 static bool
1813 radv_count_hit_attrib_slots(nir_builder *b, nir_intrinsic_instr *instr, void *data)
1814 {
1815    uint32_t *count = data;
1816    if (instr->intrinsic == nir_intrinsic_load_hit_attrib_amd || instr->intrinsic == nir_intrinsic_store_hit_attrib_amd)
1817       *count = MAX2(*count, nir_intrinsic_base(instr) + 1);
1818 
1819    return false;
1820 }
1821 
1822 static void
1823 lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device,
1824                                  struct radv_ray_tracing_pipeline *pipeline,
1825                                  const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct rt_variables *vars)
1826 {
1827    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1828 
1829    struct lower_rt_instruction_monolithic_state state = {
1830       .device = device,
1831       .pipeline = pipeline,
1832       .pCreateInfo = pCreateInfo,
1833       .vars = vars,
1834    };
1835 
1836    nir_shader_instructions_pass(shader, lower_rt_instruction_monolithic, nir_metadata_none, &state);
1837    nir_index_ssa_defs(impl);
1838 
1839    uint32_t hit_attrib_count = 0;
1840    nir_shader_intrinsics_pass(shader, radv_count_hit_attrib_slots, nir_metadata_all, &hit_attrib_count);
1841 
1842    /* Register storage for hit attributes */
1843    STACK_ARRAY(nir_variable *, hit_attribs, hit_attrib_count);
1844    for (uint32_t i = 0; i < hit_attrib_count; i++)
1845       hit_attribs[i] = nir_local_variable_create(impl, glsl_uint_type(), "ahit_attrib");
1846 
1847    lower_hit_attribs(shader, hit_attribs, 0);
1848 }
1849 
1850 /** Select the next shader based on priorities:
1851  *
1852  * Detect the priority of the shader stage by the lowest bits in the address (low to high):
1853  *  - Raygen              - idx 0
1854  *  - Traversal           - idx 1
1855  *  - Closest Hit / Miss  - idx 2
1856  *  - Callable            - idx 3
1857  *
1858  *
1859  * This gives us the following priorities:
1860  * Raygen       :  Callable  >               >  Traversal  >  Raygen
1861  * Traversal    :            >  Chit / Miss  >             >  Raygen
1862  * CHit / Miss  :  Callable  >  Chit / Miss  >  Traversal  >  Raygen
1863  * Callable     :  Callable  >  Chit / Miss  >             >  Raygen
1864  */
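/* Each lane ballots on the priority class of its next shader; the wave then uniformly jumps to
 * the shader address of the first lane in the highest-priority class present, so that as many
 * lanes as possible execute the same shader next. */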
1865 static nir_def *
1866 select_next_shader(nir_builder *b, nir_def *shader_addr, unsigned wave_size)
1867 {
1868    gl_shader_stage stage = b->shader->info.stage;
1869    nir_def *prio = nir_iand_imm(b, shader_addr, radv_rt_priority_mask);
1870    nir_def *ballot = nir_ballot(b, 1, wave_size, nir_imm_bool(b, true));
1871    nir_def *ballot_traversal = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_traversal));
1872    nir_def *ballot_hit_miss = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_hit_miss));
1873    nir_def *ballot_callable = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_callable));
1874 
1875    if (stage != MESA_SHADER_CALLABLE && stage != MESA_SHADER_INTERSECTION)
1876       ballot = nir_bcsel(b, nir_ine_imm(b, ballot_traversal, 0), ballot_traversal, ballot);
1877    if (stage != MESA_SHADER_RAYGEN)
1878       ballot = nir_bcsel(b, nir_ine_imm(b, ballot_hit_miss, 0), ballot_hit_miss, ballot);
1879    if (stage != MESA_SHADER_INTERSECTION)
1880       ballot = nir_bcsel(b, nir_ine_imm(b, ballot_callable, 0), ballot_callable, ballot);
1881 
1882    nir_def *lsb = nir_find_lsb(b, ballot);
1883    nir_def *next = nir_read_invocation(b, shader_addr, lsb);
1884    return nir_iand_imm(b, next, ~radv_rt_priority_mask);
1885 }
1886 
1887 void
1888 radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1889                       const struct radv_shader_args *args, const struct radv_shader_info *info, uint32_t *stack_size,
1890                       bool resume_shader, struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1891                       bool monolithic)
1892 {
1893    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1894 
1895    const VkPipelineCreateFlagBits2KHR create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
1896 
1897    struct rt_variables vars = create_rt_variables(shader, device, create_flags, monolithic);
1898 
1899    if (monolithic)
1900       lower_rt_instructions_monolithic(shader, device, pipeline, pCreateInfo, &vars);
1901 
1902    lower_rt_instructions(shader, &vars, true);
1903 
1904    if (stack_size) {
1905       vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);
1906       *stack_size = MAX2(*stack_size, vars.stack_size);
1907    }
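   /* The shader's scratch requirement has been folded into the reported RT stack size above,
    * so the NIR-level scratch size is cleared here. */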
1908    shader->scratch_size = 0;
1909 
1910    NIR_PASS(_, shader, nir_lower_returns);
1911 
1912    nir_cf_list list;
1913    nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1914 
1915    /* initialize variables */
1916    nir_builder b = nir_builder_at(nir_before_impl(impl));
1917 
1918    nir_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr);
1919    nir_store_var(&b, vars.traversal_addr, nir_pack_64_2x32(&b, traversal_addr), 1);
1920 
1921    nir_def *shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_addr);
1922    shader_addr = nir_pack_64_2x32(&b, shader_addr);
1923    nir_store_var(&b, vars.shader_addr, shader_addr, 1);
1924 
1925    nir_store_var(&b, vars.stack_ptr, ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
1926    nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
1927    nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
1928    nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
1929 
1930    nir_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
1931    nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
1932    nir_store_var(&b, vars.cull_mask_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1);
1933    nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1);
1934    nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1);
1935    nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 1);
1936    nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7);
1937    nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1);
1938    nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction), 0x7);
1939    nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1);
1940 
1941    nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id), 1);
1942    nir_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
1943    nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
1944    nir_store_var(&b, vars.geometry_id_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1);
1945    nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1);
1946 
1947    /* guard the shader, so that only the correct invocations execute it */
1948    nir_if *shader_guard = NULL;
1949    if (shader->info.stage != MESA_SHADER_RAYGEN || resume_shader) {
1950       nir_def *uniform_shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr);
1951       uniform_shader_addr = nir_pack_64_2x32(&b, uniform_shader_addr);
1952       uniform_shader_addr = nir_ior_imm(&b, uniform_shader_addr, radv_get_rt_priority(shader->info.stage));
1953 
1954       shader_guard = nir_push_if(&b, nir_ieq(&b, uniform_shader_addr, shader_addr));
1955       shader_guard->control = nir_selection_control_divergent_always_taken;
1956    }
1957 
1958    nir_cf_reinsert(&list, b.cursor);
1959 
1960    if (shader_guard)
1961       nir_pop_if(&b, shader_guard);
1962 
1963    b.cursor = nir_after_impl(impl);
1964 
1965    if (monolithic) {
1966       nir_terminate(&b);
1967    } else {
1968       /* select next shader */
1969       shader_addr = nir_load_var(&b, vars.shader_addr);
1970       nir_def *next = select_next_shader(&b, shader_addr, info->wave_size);
1971       ac_nir_store_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr, next);
1972 
1973       /* store back all variables to registers */
1974       ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base, nir_load_var(&b, vars.stack_ptr));
1975       ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_addr, shader_addr);
1976       ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_record, nir_load_var(&b, vars.shader_record_ptr));
1977       ac_nir_store_arg(&b, &args->ac, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg));
1978       ac_nir_store_arg(&b, &args->ac, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct));
1979       ac_nir_store_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags, nir_load_var(&b, vars.cull_mask_and_flags));
1980       ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset));
1981       ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride));
1982       ac_nir_store_arg(&b, &args->ac, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index));
1983       ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin));
1984       ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin));
1985       ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
1986       ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));
1987 
1988       ac_nir_store_arg(&b, &args->ac, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
1989       ac_nir_store_arg(&b, &args->ac, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
1990       ac_nir_store_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags, nir_load_var(&b, vars.geometry_id_and_flags));
1991       ac_nir_store_arg(&b, &args->ac, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind));
1992    }
1993 
1994    nir_metadata_preserve(impl, nir_metadata_none);
1995 
1996    /* cleanup passes */
1997    NIR_PASS_V(shader, nir_lower_global_vars_to_local);
1998    NIR_PASS_V(shader, nir_lower_vars_to_ssa);
1999    if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_INTERSECTION)
2000       NIR_PASS_V(shader, lower_hit_attribs, NULL, info->wave_size);
2001 }
2002