/*
 * Copyright © 2024 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "lvp_private.h"
#include "lvp_acceleration_structure.h"
#include "lvp_nir_ray_tracing.h"

#include "vk_pipeline.h"

#include "nir.h"
#include "nir_builder.h"

#include "spirv/spirv.h"

#include "util/mesa-sha1.h"
#include "util/simple_mtx.h"

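/* Translate the VkRayTracingShaderGroupCreateInfoKHR array into
 * lvp_ray_tracing_group records, then append the groups of any linked pipeline
 * libraries with their stage indices rebased past this pipeline's own stages.
 * Each new group gets a unique device-wide handle index, used later to match
 * SBT records against groups.
 */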
static void
lvp_init_ray_tracing_groups(struct lvp_pipeline *pipeline,
                            const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   uint32_t i = 0;
   for (; i < create_info->groupCount; i++) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = create_info->pGroups + i;
      struct lvp_ray_tracing_group *dst = pipeline->rt.groups + i;

      dst->recursive_index = VK_SHADER_UNUSED_KHR;
      dst->ahit_index = VK_SHADER_UNUSED_KHR;
      dst->isec_index = VK_SHADER_UNUSED_KHR;

      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         if (group_info->generalShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->generalShader;
         }
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->closestHitShader;
         }
         if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR) {
            dst->ahit_index = group_info->anyHitShader;
         }
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->closestHitShader;
         }
         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR) {
            dst->isec_index = group_info->intersectionShader;

            if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
               dst->ahit_index = group_info->anyHitShader;
         }
         break;
      default:
         unreachable("Unimplemented VkRayTracingShaderGroupTypeKHR");
      }

      dst->handle.index = p_atomic_inc_return(&pipeline->device->group_handle_alloc);
   }

   if (!create_info->pLibraryInfo)
      return;

   uint32_t stage_base_index = create_info->stageCount;
   for (uint32_t library_index = 0; library_index < create_info->pLibraryInfo->libraryCount; library_index++) {
      VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[library_index]);
      for (uint32_t group_index = 0; group_index < library->rt.group_count; group_index++) {
         const struct lvp_ray_tracing_group *src = library->rt.groups + group_index;
         struct lvp_ray_tracing_group *dst = pipeline->rt.groups + i;

         dst->handle = src->handle;

         if (src->recursive_index != VK_SHADER_UNUSED_KHR)
            dst->recursive_index = stage_base_index + src->recursive_index;
         else
            dst->recursive_index = VK_SHADER_UNUSED_KHR;

         if (src->ahit_index != VK_SHADER_UNUSED_KHR)
            dst->ahit_index = stage_base_index + src->ahit_index;
         else
            dst->ahit_index = VK_SHADER_UNUSED_KHR;

         if (src->isec_index != VK_SHADER_UNUSED_KHR)
            dst->isec_index = stage_base_index + src->isec_index;
         else
            dst->isec_index = VK_SHADER_UNUSED_KHR;

         i++;
      }
      stage_base_index += library->rt.stage_count;
   }
}

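/* Rewrite nir_var_shader_call_data and nir_var_ray_hit_attrib derefs into
 * function_temp derefs so both end up on the scratch stack: call data lives at
 * the offset handed in by the caller (load_shader_call_data_offset_lvp) and
 * hit attributes at offset 0 of the current stage's scratch area.
 */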
static bool
lvp_lower_ray_tracing_derefs(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;

   nir_builder _b = nir_builder_at(nir_before_impl(impl));
   nir_builder *b = &_b;

   nir_def *arg_offset = nir_load_shader_call_data_offset_lvp(b);

   nir_foreach_block (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (!nir_deref_mode_is_one_of(deref, nir_var_shader_call_data |
                                       nir_var_ray_hit_attrib))
            continue;

         bool is_shader_call_data = nir_deref_mode_is(deref, nir_var_shader_call_data);

         deref->modes = nir_var_function_temp;
         progress = true;

         if (deref->deref_type == nir_deref_type_var) {
            b->cursor = nir_before_instr(&deref->instr);
            nir_def *offset = is_shader_call_data ? arg_offset : nir_imm_int(b, 0);
            nir_deref_instr *replacement =
               nir_build_deref_cast(b, offset, nir_var_function_temp, deref->var->type, 0);
            nir_def_replace(&deref->def, &replacement->def);
         }
      }
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   else
      nir_metadata_preserve(impl, nir_metadata_all);

   return progress;
}

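/* Hoist ray system value loads to the top of the entrypoint. After inlining,
 * these are lowered to reads of shared state variables, so hoisting keeps them
 * from observing state that nested trace_ray/execute_callable lowering mutates
 * later in the function.
 */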
static bool
lvp_move_ray_tracing_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_shader_record_ptr:
   case nir_intrinsic_load_ray_flags:
   case nir_intrinsic_load_ray_object_origin:
   case nir_intrinsic_load_ray_world_origin:
   case nir_intrinsic_load_ray_t_min:
   case nir_intrinsic_load_ray_object_direction:
   case nir_intrinsic_load_ray_world_direction:
   case nir_intrinsic_load_ray_t_max:
      nir_instr_move(nir_before_impl(b->impl), &instr->instr);
      return true;
   default:
      return false;
   }
}

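/* Compile every stage of the pipeline to NIR, run the ray-tracing-specific
 * lowering passes on it, then take references on the already-compiled stages
 * of any linked pipeline libraries.
 */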
static VkResult
lvp_compile_ray_tracing_stages(struct lvp_pipeline *pipeline,
                               const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   VkResult result = VK_SUCCESS;

   uint32_t i = 0;
   for (; i < create_info->stageCount; i++) {
      nir_shader *nir;
      result = lvp_spirv_to_nir(pipeline, create_info->pStages + i, &nir);
      if (result != VK_SUCCESS)
         return result;

      assert(!nir->scratch_size);
      if (nir->info.stage == MESA_SHADER_ANY_HIT ||
          nir->info.stage == MESA_SHADER_CLOSEST_HIT ||
          nir->info.stage == MESA_SHADER_INTERSECTION)
         nir->scratch_size = LVP_RAY_HIT_ATTRIBS_SIZE;

      NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
               nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
               glsl_get_natural_size_align_bytes);

      NIR_PASS(_, nir, lvp_lower_ray_tracing_derefs);

      NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);

      NIR_PASS(_, nir, nir_shader_intrinsics_pass, lvp_move_ray_tracing_intrinsic,
               nir_metadata_control_flow, NULL);

      pipeline->rt.stages[i] = lvp_create_pipeline_nir(nir);
      if (!pipeline->rt.stages[i]) {
         result = VK_ERROR_OUT_OF_HOST_MEMORY;
         ralloc_free(nir);
         return result;
      }
      if (pipeline->layout)
         pipeline->shaders[nir->info.stage].push_constant_size = pipeline->layout->push_constant_size;
   }

   if (!create_info->pLibraryInfo)
      return result;

   for (uint32_t library_index = 0; library_index < create_info->pLibraryInfo->libraryCount; library_index++) {
      VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[library_index]);
      for (uint32_t stage_index = 0; stage_index < library->rt.stage_count; stage_index++) {
         lvp_pipeline_nir_ref(pipeline->rt.stages + i, library->rt.stages[stage_index]);
         i++;
      }
   }

   return result;
}

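/* The vkCmdTraceRays* parameters are visible to the shader as a
 * VkTraceRaysIndirectCommand2KHR in the SSBO bound at index 0; fields are read
 * by their byte offset within that struct.
 */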
static nir_def *
lvp_load_trace_ray_command_field(nir_builder *b, uint32_t command_offset,
                                 uint32_t num_components, uint32_t bit_size)
{
   return nir_load_ssbo(b, num_components, bit_size, nir_imm_int(b, 0),
                        nir_imm_int(b, command_offset));
}

struct lvp_sbt_entry {
   nir_def *value;
   nir_def *shader_record_ptr;
};

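/* Load one shader binding table record:
 *
 *    addr = table_base + index * stride   (if indexed)
 *    value             <- 32-bit group handle index at addr + index_offset
 *    shader_record_ptr <- addr + LVP_RAY_TRACING_GROUP_HANDLE_SIZE
 *
 * command_offset selects the SBT region (raygen/miss/hit/callable) within
 * VkTraceRaysIndirectCommand2KHR.
 */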
static struct lvp_sbt_entry
lvp_load_sbt_entry(nir_builder *b, nir_def *index,
                   uint32_t command_offset, uint32_t index_offset)
{
   nir_def *addr = lvp_load_trace_ray_command_field(b, command_offset, 1, 64);

   if (index) {
      /* The 32 high bits of stride can be ignored. */
      nir_def *stride = lvp_load_trace_ray_command_field(
         b, command_offset + sizeof(VkDeviceSize) * 2, 1, 32);
      addr = nir_iadd(b, addr, nir_u2u64(b, nir_imul(b, index, stride)));
   }

   return (struct lvp_sbt_entry) {
      .value = nir_build_load_global(b, 1, 32, nir_iadd_imm(b, addr, index_offset)),
      .shader_record_ptr = nir_iadd_imm(b, addr, LVP_RAY_TRACING_GROUP_HANDLE_SIZE),
   };
}

struct lvp_ray_traversal_state {
   nir_variable *origin;
   nir_variable *dir;
   nir_variable *inv_dir;
   nir_variable *bvh_base;
   nir_variable *current_node;
   nir_variable *stack_base;
   nir_variable *stack_ptr;
   nir_variable *stack;
   nir_variable *hit;

   nir_variable *instance_addr;
   nir_variable *sbt_offset_and_flags;
};

struct lvp_ray_tracing_state {
   nir_variable *bvh_base;
   nir_variable *flags;
   nir_variable *cull_mask;
   nir_variable *sbt_offset;
   nir_variable *sbt_stride;
   nir_variable *miss_index;
   nir_variable *origin;
   nir_variable *tmin;
   nir_variable *dir;
   nir_variable *tmax;

   nir_variable *instance_addr;
   nir_variable *primitive_id;
   nir_variable *geometry_id_and_flags;
   nir_variable *hit_kind;
   nir_variable *sbt_index;

   nir_variable *shader_record_ptr;
   nir_variable *stack_ptr;
   nir_variable *shader_call_data_offset;

   nir_variable *accept;
   nir_variable *terminate;
   nir_variable *opaque;

   struct lvp_ray_traversal_state traversal;
};

struct lvp_ray_tracing_pipeline_compiler {
   struct lvp_pipeline *pipeline;
   VkPipelineCreateFlags2KHR flags;

   struct lvp_ray_tracing_state state;

   struct hash_table *functions;

   uint32_t raygen_size;
   uint32_t ahit_size;
   uint32_t chit_size;
   uint32_t miss_size;
   uint32_t isec_size;
   uint32_t callable_size;
};

static uint32_t
lvp_ray_tracing_pipeline_compiler_get_stack_size(
   struct lvp_ray_tracing_pipeline_compiler *compiler, nir_function *function)
{
   hash_table_foreach(compiler->functions, entry) {
      if (entry->data == function) {
         const nir_shader *shader = entry->key;
         return shader->scratch_size;
      }
   }
   return 0;
}

static void
lvp_ray_tracing_state_init(nir_shader *nir, struct lvp_ray_tracing_state *state)
{
   state->bvh_base = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "bvh_base");
   state->flags = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "flags");
   state->cull_mask = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
   state->sbt_offset = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
   state->sbt_stride = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
   state->miss_index = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "miss_index");
   state->origin = nir_variable_create(nir, nir_var_shader_temp, glsl_vec_type(3), "origin");
   state->tmin = nir_variable_create(nir, nir_var_shader_temp, glsl_float_type(), "tmin");
   state->dir = nir_variable_create(nir, nir_var_shader_temp, glsl_vec_type(3), "dir");
   state->tmax = nir_variable_create(nir, nir_var_shader_temp, glsl_float_type(), "tmax");

   state->instance_addr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
   state->primitive_id = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
   state->geometry_id_and_flags = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
   state->hit_kind = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
   state->sbt_index = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_index");

   state->shader_record_ptr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
   state->stack_ptr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
   state->shader_call_data_offset = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "shader_call_data_offset");

   state->accept = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "accept");
   state->terminate = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "terminate");
   state->opaque = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "opaque");
}

static void
lvp_ray_traversal_state_init(nir_function_impl *impl, struct lvp_ray_traversal_state *state)
{
   state->origin = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.origin");
   state->dir = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.dir");
   state->inv_dir = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.inv_dir");
   state->bvh_base = nir_local_variable_create(impl, glsl_uint64_t_type(), "traversal.bvh_base");
   state->current_node = nir_local_variable_create(impl, glsl_uint_type(), "traversal.current_node");
   state->stack_base = nir_local_variable_create(impl, glsl_uint_type(), "traversal.stack_base");
   state->stack_ptr = nir_local_variable_create(impl, glsl_uint_type(), "traversal.stack_ptr");
   state->stack = nir_local_variable_create(impl, glsl_array_type(glsl_uint_type(), 24 * 2, 0), "traversal.stack");
   state->hit = nir_local_variable_create(impl, glsl_bool_type(), "traversal.hit");

   state->instance_addr = nir_local_variable_create(impl, glsl_uint64_t_type(), "traversal.instance_addr");
   state->sbt_offset_and_flags = nir_local_variable_create(impl, glsl_uint_type(), "traversal.sbt_offset_and_flags");
}

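/* Inline a ray tracing stage into the monolithic compute shader as a NIR
 * function call. Each distinct nir_shader is cloned only once: its entrypoint
 * impl is copied, non-local variables are re-created in the destination
 * shader, and the resulting nir_function is cached in compiler->functions.
 * The per-stage-type scratch maximum is tracked so the final stack size can
 * be computed.
 */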
static void
lvp_call_ray_tracing_stage(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler, nir_shader *stage)
{
   nir_function *function;

   struct hash_entry *entry = _mesa_hash_table_search(compiler->functions, stage);
   if (entry) {
      function = entry->data;
   } else {
      nir_function_impl *stage_entrypoint = nir_shader_get_entrypoint(stage);
      nir_function_impl *copy = nir_function_impl_clone(b->shader, stage_entrypoint);

      struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);

      nir_foreach_block(block, copy) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (deref->deref_type != nir_deref_type_var ||
                deref->var->data.mode == nir_var_function_temp)
               continue;

            struct hash_entry *entry =
               _mesa_hash_table_search(var_remap, deref->var);
            if (!entry) {
               nir_variable *new_var = nir_variable_clone(deref->var, b->shader);
               nir_shader_add_variable(b->shader, new_var);
               entry = _mesa_hash_table_insert(var_remap,
                                               deref->var, new_var);
            }
            deref->var = entry->data;
         }
      }

      function = nir_function_create(
         b->shader, _mesa_shader_stage_to_string(stage->info.stage));
      nir_function_set_impl(function, copy);

      ralloc_free(var_remap);

      _mesa_hash_table_insert(compiler->functions, stage, function);
   }

   nir_build_call(b, function, 0, NULL);

   switch (stage->info.stage) {
   case MESA_SHADER_RAYGEN:
      compiler->raygen_size = MAX2(compiler->raygen_size, stage->scratch_size);
      break;
   case MESA_SHADER_ANY_HIT:
      compiler->ahit_size = MAX2(compiler->ahit_size, stage->scratch_size);
      break;
   case MESA_SHADER_CLOSEST_HIT:
      compiler->chit_size = MAX2(compiler->chit_size, stage->scratch_size);
      break;
   case MESA_SHADER_MISS:
      compiler->miss_size = MAX2(compiler->miss_size, stage->scratch_size);
      break;
   case MESA_SHADER_INTERSECTION:
      compiler->isec_size = MAX2(compiler->isec_size, stage->scratch_size);
      break;
   case MESA_SHADER_CALLABLE:
      compiler->callable_size = MAX2(compiler->callable_size, stage->scratch_size);
      break;
   default:
      unreachable("Invalid ray tracing stage");
      break;
   }
}

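/* Lower executeCallableEXT: look up the callable SBT record, advance the stack
 * pointer past the caller's scratch frame, and dispatch to the matching
 * callable stage through a chain of handle-index comparisons.
 */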
static void
lvp_execute_callable(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler,
                     nir_intrinsic_instr *instr)
{
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_def *sbt_index = instr->src[0].ssa;
   nir_def *payload = instr->src[1].ssa;

   struct lvp_sbt_entry callable_entry = lvp_load_sbt_entry(
      b,
      sbt_index,
      offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler->state.shader_record_ptr, callable_entry.shader_record_ptr, 0x1);

   uint32_t stack_size =
      lvp_ray_tracing_pipeline_compiler_get_stack_size(compiler, b->impl->function);
   nir_def *stack_ptr = nir_load_var(b, state->stack_ptr);
   nir_store_var(b, state->stack_ptr, nir_iadd_imm(b, stack_ptr, stack_size), 0x1);

   nir_store_var(b, state->shader_call_data_offset, nir_iadd_imm(b, payload, -stack_size), 0x1);

   for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
      if (group->recursive_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
      if (stage->info.stage != MESA_SHADER_CALLABLE)
         continue;

      nir_push_if(b, nir_ieq_imm(b, callable_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, compiler, stage);
      nir_pop_if(b, NULL);
   }

   nir_store_var(b, state->stack_ptr, stack_ptr, 0x1);
}

struct lvp_lower_isec_intrinsic_state {
   struct lvp_ray_tracing_pipeline_compiler *compiler;
   nir_shader *ahit;
};

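/* Lower reportIntersectionEXT inside an intersection shader: if t lies within
 * [tmin, tmax] and the ray has not already been terminated, tentatively commit
 * the hit, run the group's any-hit shader for non-opaque hits, and then either
 * keep or roll back the committed t/hit_kind depending on whether the hit was
 * accepted. The intrinsic's result is whether the hit was committed.
 */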
static bool
lvp_lower_isec_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   if (instr->intrinsic != nir_intrinsic_report_ray_intersection)
      return false;

   struct lvp_lower_isec_intrinsic_state *isec_state = data;
   struct lvp_ray_tracing_pipeline_compiler *compiler = isec_state->compiler;
   struct lvp_ray_tracing_state *state = &compiler->state;

   b->cursor = nir_after_instr(&instr->instr);

   nir_def *t = instr->src[0].ssa;
   nir_def *hit_kind = instr->src[1].ssa;

   nir_def *prev_accept = nir_load_var(b, state->accept);
   nir_def *prev_tmax = nir_load_var(b, state->tmax);
   nir_def *prev_hit_kind = nir_load_var(b, state->hit_kind);

   nir_variable *commit = nir_local_variable_create(b->impl, glsl_bool_type(), "commit");
   nir_store_var(b, commit, nir_imm_false(b), 0x1);

   nir_def *in_range = nir_iand(b, nir_fge(b, t, nir_load_var(b, state->tmin)), nir_fge(b, nir_load_var(b, state->tmax), t));
   nir_def *terminated = nir_iand(b, nir_load_var(b, state->terminate), nir_load_var(b, state->accept));
   nir_push_if(b, nir_iand(b, in_range, nir_inot(b, terminated)));
   {
      nir_store_var(b, state->accept, nir_imm_true(b), 0x1);

      nir_store_var(b, state->tmax, t, 0x1);
      nir_store_var(b, state->hit_kind, hit_kind, 0x1);

      if (isec_state->ahit) {
         nir_def *prev_terminate = nir_load_var(b, state->terminate);
         nir_store_var(b, state->terminate, nir_imm_false(b), 0x1);

         nir_push_if(b, nir_inot(b, nir_load_var(b, state->opaque)));
         {
            lvp_call_ray_tracing_stage(b, compiler, isec_state->ahit);
         }
         nir_pop_if(b, NULL);

         nir_def *terminate = nir_load_var(b, state->terminate);
         nir_store_var(b, state->terminate, nir_ior(b, terminate, prev_terminate), 0x1);

         nir_push_if(b, terminate);
         nir_jump(b, nir_jump_return);
         nir_pop_if(b, NULL);
      }

      nir_push_if(b, nir_load_var(b, state->accept));
      {
         nir_store_var(b, commit, nir_imm_true(b), 0x1);
      }
      nir_push_else(b, NULL);
      {
         nir_store_var(b, state->accept, prev_accept, 0x1);
         nir_store_var(b, state->tmax, prev_tmax, 0x1);
         nir_store_var(b, state->hit_kind, prev_hit_kind, 0x1);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);

   nir_def_replace(&instr->def, nir_load_var(b, commit));

   return true;
}

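/* Traversal callback for AABB leaves: compute the hit-group SBT index for the
 * geometry, inline every procedural hit group's intersection shader behind a
 * handle-index compare, and lower the report_ray_intersection intrinsics this
 * just emitted. On accept, record the SBT index and break out of traversal if
 * the ray terminates on first hit; otherwise restore the previous hit state.
 */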
static void
lvp_handle_aabb_intersection(nir_builder *b, struct lvp_leaf_intersection *intersection,
                             const struct lvp_ray_traversal_args *args,
                             const struct lvp_ray_flags *ray_flags)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = args->data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_store_var(b, state->accept, nir_imm_false(b), 0x1);
   nir_store_var(b, state->terminate, ray_flags->terminate_on_first_hit, 0x1);
   nir_store_var(b, state->opaque, intersection->opaque, 0x1);

   nir_def *prev_instance_addr = nir_load_var(b, state->instance_addr);
   nir_def *prev_primitive_id = nir_load_var(b, state->primitive_id);
   nir_def *prev_geometry_id_and_flags = nir_load_var(b, state->geometry_id_and_flags);

   nir_store_var(b, state->instance_addr, nir_load_var(b, state->traversal.instance_addr), 0x1);
   nir_store_var(b, state->primitive_id, intersection->primitive_id, 0x1);
   nir_store_var(b, state->geometry_id_and_flags, intersection->geometry_id_and_flags, 0x1);

   nir_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
   nir_def *sbt_index =
      nir_iadd(b,
               nir_iadd(b, nir_load_var(b, state->sbt_offset),
                        nir_iand_imm(b, nir_load_var(b, state->traversal.sbt_offset_and_flags), 0xffffff)),
               nir_imul(b, nir_load_var(b, state->sbt_stride), geometry_id));

   struct lvp_sbt_entry isec_entry = lvp_load_sbt_entry(
      b,
      sbt_index,
      offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler->state.shader_record_ptr, isec_entry.shader_record_ptr, 0x1);

   for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
      if (group->isec_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = compiler->pipeline->rt.stages[group->isec_index]->nir;

      nir_push_if(b, nir_ieq_imm(b, isec_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, compiler, stage);
      nir_pop_if(b, NULL);

      nir_shader *ahit_stage = NULL;
      if (group->ahit_index != VK_SHADER_UNUSED_KHR)
         ahit_stage = compiler->pipeline->rt.stages[group->ahit_index]->nir;

      struct lvp_lower_isec_intrinsic_state isec_state = {
         .compiler = compiler,
         .ahit = ahit_stage,
      };
      nir_shader_intrinsics_pass(b->shader, lvp_lower_isec_intrinsic,
                                 nir_metadata_none, &isec_state);
   }

   nir_push_if(b, nir_load_var(b, state->accept));
   {
      nir_store_var(b, state->sbt_index, sbt_index, 0x1);
      nir_store_var(b, state->traversal.hit, nir_imm_true(b), 0x1);

      nir_break_if(b, nir_load_var(b, state->terminate));
   }
   nir_push_else(b, NULL);
   {
      nir_store_var(b, state->instance_addr, prev_instance_addr, 0x1);
      nir_store_var(b, state->primitive_id, prev_primitive_id, 0x1);
      nir_store_var(b, state->geometry_id_and_flags, prev_geometry_id_and_flags, 0x1);
   }
   nir_pop_if(b, NULL);
}

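/* Traversal callback for triangle leaves: tentatively commit the hit
 * (t, instance, primitive, hit kind, barycentrics), run any-hit shaders for
 * non-opaque hits, then either keep the commit (recording the SBT index and
 * possibly breaking out of traversal) or restore the previous hit state.
 */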
static void
lvp_handle_triangle_intersection(nir_builder *b,
                                 struct lvp_triangle_intersection *intersection,
                                 const struct lvp_ray_traversal_args *args,
                                 const struct lvp_ray_flags *ray_flags)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = args->data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_store_var(b, state->accept, nir_imm_true(b), 0x1);
   nir_store_var(b, state->terminate, ray_flags->terminate_on_first_hit, 0x1);

   nir_def *barycentrics_offset = nir_load_var(b, state->stack_ptr);

   nir_def *prev_tmax = nir_load_var(b, state->tmax);
   nir_def *prev_instance_addr = nir_load_var(b, state->instance_addr);
   nir_def *prev_primitive_id = nir_load_var(b, state->primitive_id);
   nir_def *prev_geometry_id_and_flags = nir_load_var(b, state->geometry_id_and_flags);
   nir_def *prev_hit_kind = nir_load_var(b, state->hit_kind);
   nir_def *prev_barycentrics = nir_load_scratch(b, 2, 32, barycentrics_offset);

   nir_store_var(b, state->tmax, intersection->t, 0x1);
   nir_store_var(b, state->instance_addr, nir_load_var(b, state->traversal.instance_addr), 0x1);
   nir_store_var(b, state->primitive_id, intersection->base.primitive_id, 0x1);
   nir_store_var(b, state->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 0x1);
   nir_store_var(b, state->hit_kind,
                 nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF)), 0x1);

   nir_store_scratch(b, intersection->barycentrics, barycentrics_offset);

   nir_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
   nir_def *sbt_index =
      nir_iadd(b,
               nir_iadd(b, nir_load_var(b, state->sbt_offset),
                        nir_iand_imm(b, nir_load_var(b, state->traversal.sbt_offset_and_flags), 0xffffff)),
               nir_imul(b, nir_load_var(b, state->sbt_stride), geometry_id));

   nir_push_if(b, nir_inot(b, intersection->base.opaque));
   {
      struct lvp_sbt_entry ahit_entry = lvp_load_sbt_entry(
         b,
         sbt_index,
         offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, ahit_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->ahit_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->ahit_index]->nir;

         nir_push_if(b, nir_ieq_imm(b, ahit_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }
   }
   nir_pop_if(b, NULL);

   nir_push_if(b, nir_load_var(b, state->accept));
   {
      nir_store_var(b, state->sbt_index, sbt_index, 0x1);
      nir_store_var(b, state->traversal.hit, nir_imm_true(b), 0x1);

      nir_break_if(b, nir_load_var(b, state->terminate));
   }
   nir_push_else(b, NULL);
   {
      nir_store_var(b, state->tmax, prev_tmax, 0x1);
      nir_store_var(b, state->instance_addr, prev_instance_addr, 0x1);
      nir_store_var(b, state->primitive_id, prev_primitive_id, 0x1);
      nir_store_var(b, state->geometry_id_and_flags, prev_geometry_id_and_flags, 0x1);
      nir_store_var(b, state->hit_kind, prev_hit_kind, 0x1);
      nir_store_scratch(b, prev_barycentrics, barycentrics_offset);
   }
   nir_pop_if(b, NULL);
}

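/* Lower traceRayEXT: save and advance the stack pointer, seed the ray and
 * traversal state, walk the BVH with lvp_build_ray_traversal(), then dispatch
 * either the closest-hit shader of the committed hit (unless skipped by ray
 * flags) or the selected miss shader.
 */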
static void
lvp_trace_ray(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler,
              nir_intrinsic_instr *instr)
{
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_def *accel_struct = instr->src[0].ssa;
   nir_def *flags = instr->src[1].ssa;
   nir_def *cull_mask = instr->src[2].ssa;
   nir_def *sbt_offset = nir_iand_imm(b, instr->src[3].ssa, 0xF);
   nir_def *sbt_stride = nir_iand_imm(b, instr->src[4].ssa, 0xF);
   nir_def *miss_index = nir_iand_imm(b, instr->src[5].ssa, 0xFFFF);
   nir_def *origin = instr->src[6].ssa;
   nir_def *tmin = instr->src[7].ssa;
   nir_def *dir = instr->src[8].ssa;
   nir_def *tmax = instr->src[9].ssa;
   nir_def *payload = instr->src[10].ssa;

   uint32_t stack_size =
      lvp_ray_tracing_pipeline_compiler_get_stack_size(compiler, b->impl->function);
   nir_def *stack_ptr = nir_load_var(b, state->stack_ptr);
   nir_store_var(b, state->stack_ptr, nir_iadd_imm(b, stack_ptr, stack_size), 0x1);

   nir_store_var(b, state->shader_call_data_offset, nir_iadd_imm(b, payload, -stack_size), 0x1);

   nir_def *bvh_base = accel_struct;
   if (bvh_base->bit_size != 64) {
      assert(bvh_base->num_components >= 2);
      bvh_base = nir_load_ubo(
         b, 1, 64, nir_channel(b, accel_struct, 0),
         nir_imul_imm(b, nir_channel(b, accel_struct, 1), sizeof(struct lp_descriptor)), .range = ~0);
   }

   lvp_ray_traversal_state_init(b->impl, &state->traversal);

   nir_store_var(b, state->bvh_base, bvh_base, 0x1);
   nir_store_var(b, state->flags, flags, 0x1);
   nir_store_var(b, state->cull_mask, cull_mask, 0x1);
   nir_store_var(b, state->sbt_offset, sbt_offset, 0x1);
   nir_store_var(b, state->sbt_stride, sbt_stride, 0x1);
   nir_store_var(b, state->miss_index, miss_index, 0x1);
   nir_store_var(b, state->origin, origin, 0x7);
   nir_store_var(b, state->tmin, tmin, 0x1);
   nir_store_var(b, state->dir, dir, 0x7);
   nir_store_var(b, state->tmax, tmax, 0x1);

   nir_store_var(b, state->traversal.bvh_base, bvh_base, 0x1);
   nir_store_var(b, state->traversal.origin, origin, 0x7);
   nir_store_var(b, state->traversal.dir, dir, 0x7);
   nir_store_var(b, state->traversal.inv_dir, nir_frcp(b, dir), 0x7);
   nir_store_var(b, state->traversal.current_node, nir_imm_int(b, LVP_BVH_ROOT_NODE), 0x1);
   nir_store_var(b, state->traversal.stack_base, nir_imm_int(b, -1), 0x1);
   nir_store_var(b, state->traversal.stack_ptr, nir_imm_int(b, 0), 0x1);

   nir_store_var(b, state->traversal.hit, nir_imm_false(b), 0x1);

   struct lvp_ray_traversal_vars vars = {
      .tmax = nir_build_deref_var(b, state->tmax),
      .origin = nir_build_deref_var(b, state->traversal.origin),
      .dir = nir_build_deref_var(b, state->traversal.dir),
      .inv_dir = nir_build_deref_var(b, state->traversal.inv_dir),
      .bvh_base = nir_build_deref_var(b, state->traversal.bvh_base),
      .current_node = nir_build_deref_var(b, state->traversal.current_node),
      .stack_base = nir_build_deref_var(b, state->traversal.stack_base),
      .stack_ptr = nir_build_deref_var(b, state->traversal.stack_ptr),
      .stack = nir_build_deref_var(b, state->traversal.stack),
      .instance_addr = nir_build_deref_var(b, state->traversal.instance_addr),
      .sbt_offset_and_flags = nir_build_deref_var(b, state->traversal.sbt_offset_and_flags),
   };

   struct lvp_ray_traversal_args args = {
      .root_bvh_base = bvh_base,
      .flags = flags,
      .cull_mask = nir_ishl_imm(b, cull_mask, 24),
      .origin = origin,
      .tmin = tmin,
      .dir = dir,
      .vars = vars,
      .aabb_cb = (compiler->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_AABBS_BIT_KHR) ?
                 NULL : lvp_handle_aabb_intersection,
      .triangle_cb = (compiler->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR) ?
                     NULL : lvp_handle_triangle_intersection,
      .data = compiler,
   };

   nir_push_if(b, nir_ine_imm(b, bvh_base, 0));
   lvp_build_ray_traversal(b, &args);
   nir_pop_if(b, NULL);

   nir_push_if(b, nir_load_var(b, state->traversal.hit));
   {
      nir_def *skip_chit = nir_test_mask(b, flags, SpvRayFlagsSkipClosestHitShaderKHRMask);
      nir_push_if(b, nir_inot(b, skip_chit));

      struct lvp_sbt_entry chit_entry = lvp_load_sbt_entry(
         b,
         nir_load_var(b, state->sbt_index),
         offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, chit_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->recursive_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
         if (stage->info.stage != MESA_SHADER_CLOSEST_HIT)
            continue;

         nir_push_if(b, nir_ieq_imm(b, chit_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }

      nir_pop_if(b, NULL);
   }
   nir_push_else(b, NULL);
   {
      struct lvp_sbt_entry miss_entry = lvp_load_sbt_entry(
         b,
         miss_index,
         offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, miss_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->recursive_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
         if (stage->info.stage != MESA_SHADER_MISS)
            continue;

         nir_push_if(b, nir_ieq_imm(b, miss_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }
   }
   nir_pop_if(b, NULL);

   nir_store_var(b, state->stack_ptr, stack_ptr, 0x1);
}

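/* Main lowering callback: turn halt jumps into returns, expand trace_ray /
 * execute_callable / terminate_ray / ignore_ray_intersection, map the ray
 * tracing system values onto the shared state variables, and rebase scratch
 * accesses by the current stack pointer.
 */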
static bool
lvp_lower_ray_tracing_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   if (instr->type == nir_instr_type_jump) {
      nir_jump_instr *jump = nir_instr_as_jump(instr);
      if (jump->type == nir_jump_halt) {
         jump->type = nir_jump_return;
         return true;
      }
      return false;
   } else if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   nir_def *def = NULL;

   b->cursor = nir_before_instr(instr);

   switch (intr->intrinsic) {
   /* Ray tracing instructions */
   case nir_intrinsic_execute_callable:
      lvp_execute_callable(b, compiler, intr);
      break;
   case nir_intrinsic_trace_ray:
      lvp_trace_ray(b, compiler, intr);
      break;
   case nir_intrinsic_ignore_ray_intersection: {
      nir_store_var(b, state->accept, nir_imm_false(b), 0x1);

      nir_push_if(b, nir_imm_true(b));
      nir_jump(b, nir_jump_return);
      nir_pop_if(b, NULL);
      break;
   }
   case nir_intrinsic_terminate_ray: {
      nir_store_var(b, state->accept, nir_imm_true(b), 0x1);
      nir_store_var(b, state->terminate, nir_imm_true(b), 0x1);

      nir_push_if(b, nir_imm_true(b));
      nir_jump(b, nir_jump_return);
      nir_pop_if(b, NULL);
      break;
   }
   /* Ray tracing system values */
   case nir_intrinsic_load_ray_launch_id:
      def = nir_load_global_invocation_id(b, 32);
      break;
   case nir_intrinsic_load_ray_launch_size:
      def = lvp_load_trace_ray_command_field(
         b, offsetof(VkTraceRaysIndirectCommand2KHR, width), 3, 32);
      break;
   case nir_intrinsic_load_shader_record_ptr:
      def = nir_load_var(b, state->shader_record_ptr);
      break;
   case nir_intrinsic_load_ray_t_min:
      def = nir_load_var(b, state->tmin);
      break;
   case nir_intrinsic_load_ray_t_max:
      def = nir_load_var(b, state->tmax);
      break;
   case nir_intrinsic_load_ray_world_origin:
      def = nir_load_var(b, state->origin);
      break;
   case nir_intrinsic_load_ray_world_direction:
      def = nir_load_var(b, state->dir);
      break;
   case nir_intrinsic_load_ray_instance_custom_index: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *custom_instance_and_mask = nir_build_load_global(
         b, 1, 32,
         nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, custom_instance_and_mask)));
      def = nir_iand_imm(b, custom_instance_and_mask, 0xFFFFFF);
      break;
   }
   case nir_intrinsic_load_primitive_id:
      def = nir_load_var(b, state->primitive_id);
      break;
   case nir_intrinsic_load_ray_geometry_index:
      def = nir_load_var(b, state->geometry_id_and_flags);
      def = nir_iand_imm(b, def, 0xFFFFFFF);
      break;
   case nir_intrinsic_load_instance_id: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      def = nir_build_load_global(
         b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, instance_id)));
      break;
   }
   case nir_intrinsic_load_ray_flags:
      def = nir_load_var(b, state->flags);
      break;
   case nir_intrinsic_load_ray_hit_kind:
      def = nir_load_var(b, state->hit_kind);
      break;
   case nir_intrinsic_load_ray_world_to_object: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);

      nir_def *vals[3];
      for (unsigned i = 0; i < 3; ++i)
         vals[i] = nir_channel(b, wto_matrix[i], c);

      def = nir_vec(b, vals, 3);
      break;
   }
   case nir_intrinsic_load_ray_object_to_world: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = nir_build_load_global(
            b, 4, 32,
            nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, otw_matrix) + r * 16));
      def = nir_vec3(b, nir_channel(b, rows[0], c), nir_channel(b, rows[1], c), nir_channel(b, rows[2], c));
      break;
   }
   case nir_intrinsic_load_ray_object_origin: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
      def = lvp_mul_vec3_mat(b, nir_load_var(b, state->origin), wto_matrix, true);
      break;
   }
   case nir_intrinsic_load_ray_object_direction: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
      def = lvp_mul_vec3_mat(b, nir_load_var(b, state->dir), wto_matrix, false);
      break;
   }
   case nir_intrinsic_load_cull_mask:
      def = nir_iand_imm(b, nir_load_var(b, state->cull_mask), 0xFF);
      break;
   /* Ray tracing stack lowering */
   case nir_intrinsic_load_scratch: {
      nir_src_rewrite(&intr->src[0], nir_iadd(b, nir_load_var(b, state->stack_ptr), intr->src[0].ssa));
      return true;
   }
   case nir_intrinsic_store_scratch: {
      nir_src_rewrite(&intr->src[1], nir_iadd(b, nir_load_var(b, state->stack_ptr), intr->src[1].ssa));
      return true;
   }
   case nir_intrinsic_load_ray_triangle_vertex_positions: {
      def = lvp_load_vertex_position(
         b, nir_load_var(b, state->instance_addr), nir_load_var(b, state->primitive_id),
         nir_intrinsic_column(intr));
      break;
   }
   /* Internal system values */
   case nir_intrinsic_load_shader_call_data_offset_lvp:
      def = nir_load_var(b, state->shader_call_data_offset);
      break;
   default:
      return false;
   }

   if (def)
      nir_def_rewrite_uses(&intr->def, def);
   nir_instr_remove(instr);

   return true;
}

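/* Resolve load_ray_tracing_stack_base_lvp to the shader's scratch size. This
 * runs after all stages are inlined and shader_temp variables are lowered, so
 * the value marks where the ray tracing call stack starts above regular
 * scratch.
 */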
static bool
lvp_lower_ray_tracing_stack_base(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   if (instr->intrinsic != nir_intrinsic_load_ray_tracing_stack_base_lvp)
      return false;

   b->cursor = nir_after_instr(&instr->instr);

   nir_def_replace(&instr->def, nir_imm_int(b, b->shader->scratch_size));

   return true;
}

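/* Build the monolithic compute shader that implements the whole ray tracing
 * pipeline: bounds-check the launch ID against the launch size, load the
 * raygen SBT record, inline the matching raygen stage, then lower all ray
 * tracing intrinsics and finalize the scratch-based call stack.
 */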
static void
lvp_compile_ray_tracing_pipeline(struct lvp_pipeline *pipeline,
                                 const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   nir_builder _b = nir_builder_init_simple_shader(
      MESA_SHADER_COMPUTE,
      pipeline->device->pscreen->get_compiler_options(pipeline->device->pscreen, PIPE_SHADER_IR_NIR, MESA_SHADER_COMPUTE),
      "ray tracing pipeline");
   nir_builder *b = &_b;

   b->shader->info.workgroup_size[0] = 8;

   struct lvp_ray_tracing_pipeline_compiler compiler = {
      .pipeline = pipeline,
      .flags = vk_rt_pipeline_create_flags(create_info),
   };
   lvp_ray_tracing_state_init(b->shader, &compiler.state);
   compiler.functions = _mesa_pointer_hash_table_create(NULL);

   nir_def *launch_id = nir_load_ray_launch_id(b);
   nir_def *launch_size = nir_load_ray_launch_size(b);
   nir_def *oob = nir_ige(b, nir_channel(b, launch_id, 0), nir_channel(b, launch_size, 0));
   oob = nir_ior(b, oob, nir_ige(b, nir_channel(b, launch_id, 1), nir_channel(b, launch_size, 1)));
   oob = nir_ior(b, oob, nir_ige(b, nir_channel(b, launch_id, 2), nir_channel(b, launch_size, 2)));

   nir_push_if(b, oob);
   nir_jump(b, nir_jump_return);
   nir_pop_if(b, NULL);

   nir_store_var(b, compiler.state.stack_ptr, nir_load_ray_tracing_stack_base_lvp(b), 0x1);

   struct lvp_sbt_entry raygen_entry = lvp_load_sbt_entry(
      b,
      NULL,
      offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler.state.shader_record_ptr, raygen_entry.shader_record_ptr, 0x1);

   for (uint32_t i = 0; i < pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = pipeline->rt.groups + i;
      if (group->recursive_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = pipeline->rt.stages[group->recursive_index]->nir;

      if (stage->info.stage != MESA_SHADER_RAYGEN)
         continue;

      nir_push_if(b, nir_ieq_imm(b, raygen_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, &compiler, stage);
      nir_pop_if(b, NULL);
   }

   nir_shader_instructions_pass(b->shader, lvp_lower_ray_tracing_instr, nir_metadata_none, &compiler);

   NIR_PASS(_, b->shader, nir_lower_returns);

   const struct nir_lower_compute_system_values_options compute_system_values = {0};
   NIR_PASS(_, b->shader, nir_lower_compute_system_values, &compute_system_values);
   NIR_PASS(_, b->shader, nir_lower_global_vars_to_local);
   NIR_PASS(_, b->shader, nir_lower_vars_to_ssa);

   NIR_PASS(_, b->shader, nir_lower_vars_to_explicit_types,
            nir_var_shader_temp,
            glsl_get_natural_size_align_bytes);

   NIR_PASS(_, b->shader, nir_lower_explicit_io, nir_var_shader_temp,
            nir_address_format_32bit_offset);

   NIR_PASS(_, b->shader, nir_shader_intrinsics_pass, lvp_lower_ray_tracing_stack_base,
            nir_metadata_control_flow, NULL);

   /* We cannot support dynamic stack sizes, so assume the worst: one raygen
    * frame, the largest frame per recursion level, and up to 31 nested
    * callable frames.
    */
   b->shader->scratch_size +=
      compiler.raygen_size +
      MIN2(create_info->maxPipelineRayRecursionDepth, 1) * MAX3(compiler.chit_size, compiler.miss_size, compiler.isec_size + compiler.ahit_size) +
      MAX2(0, (int)create_info->maxPipelineRayRecursionDepth - 1) * MAX2(compiler.chit_size, compiler.miss_size) + 31 * compiler.callable_size;

   struct lvp_shader *shader = &pipeline->shaders[MESA_SHADER_RAYGEN];
   lvp_shader_init(shader, b->shader);
   shader->shader_cso = lvp_shader_compile(pipeline->device, shader, nir_shader_clone(NULL, shader->pipeline_nir->nir), false);

   _mesa_hash_table_destroy(compiler.functions, NULL);
}

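/* Create a single ray tracing pipeline: allocate the stage/group arrays
 * (including room for linked libraries), compile the stages, initialize the
 * groups and, unless this is a pipeline library, emit the monolithic shader.
 */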
static VkResult
lvp_create_ray_tracing_pipeline(VkDevice _device, const VkAllocationCallbacks *allocator,
                                const VkRayTracingPipelineCreateInfoKHR *create_info,
                                VkPipeline *out_pipeline)
{
   VK_FROM_HANDLE(lvp_device, device, _device);
   VK_FROM_HANDLE(lvp_pipeline_layout, layout, create_info->layout);

   VkResult result = VK_SUCCESS;

   struct lvp_pipeline *pipeline = vk_zalloc2(&device->vk.alloc, allocator, sizeof(struct lvp_pipeline), 8,
                                              VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
   if (!pipeline)
      return VK_ERROR_OUT_OF_HOST_MEMORY;

   vk_object_base_init(&device->vk, &pipeline->base,
                       VK_OBJECT_TYPE_PIPELINE);

   vk_pipeline_layout_ref(&layout->vk);

   pipeline->device = device;
   pipeline->layout = layout;
   pipeline->type = LVP_PIPELINE_RAY_TRACING;
   pipeline->flags = vk_rt_pipeline_create_flags(create_info);

   pipeline->rt.stage_count = create_info->stageCount;
   pipeline->rt.group_count = create_info->groupCount;
   if (create_info->pLibraryInfo) {
      for (uint32_t i = 0; i < create_info->pLibraryInfo->libraryCount; i++) {
         VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[i]);
         pipeline->rt.stage_count += library->rt.stage_count;
         pipeline->rt.group_count += library->rt.group_count;
      }
   }

   pipeline->rt.stages = calloc(pipeline->rt.stage_count, sizeof(struct lvp_pipeline_nir *));
   pipeline->rt.groups = calloc(pipeline->rt.group_count, sizeof(struct lvp_ray_tracing_group));
   if (!pipeline->rt.stages || !pipeline->rt.groups) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto fail;
   }

   result = lvp_compile_ray_tracing_stages(pipeline, create_info);
   if (result != VK_SUCCESS)
      goto fail;

   lvp_init_ray_tracing_groups(pipeline, create_info);

   if (!(pipeline->flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR)) {
      lvp_compile_ray_tracing_pipeline(pipeline, create_info);
   }

   *out_pipeline = lvp_pipeline_to_handle(pipeline);

   return VK_SUCCESS;

fail:
   lvp_pipeline_destroy(device, pipeline, false);
   return result;
}

VKAPI_ATTR VkResult VKAPI_CALL
lvp_CreateRayTracingPipelinesKHR(
   VkDevice device,
   VkDeferredOperationKHR deferredOperation,
   VkPipelineCache pipelineCache,
   uint32_t createInfoCount,
   const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
   const VkAllocationCallbacks *pAllocator,
   VkPipeline *pPipelines)
{
   VkResult result = VK_SUCCESS;

   uint32_t i = 0;
   for (; i < createInfoCount; i++) {
      VkResult tmp_result = lvp_create_ray_tracing_pipeline(
         device, pAllocator, pCreateInfos + i, pPipelines + i);

      if (tmp_result != VK_SUCCESS) {
         result = tmp_result;
         pPipelines[i] = VK_NULL_HANDLE;

         if (vk_rt_pipeline_create_flags(&pCreateInfos[i]) &
             VK_PIPELINE_CREATE_2_EARLY_RETURN_ON_FAILURE_BIT_KHR)
            break;
      }
   }

   for (; i < createInfoCount; i++)
      pPipelines[i] = VK_NULL_HANDLE;

   return result;
}

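/* Group handles are LVP_RAY_TRACING_GROUP_HANDLE_SIZE bytes each; only the
 * leading lvp_ray_tracing_group_handle (the device-wide group index) carries
 * data, the remaining bytes are zeroed.
 */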
VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetRayTracingShaderGroupHandlesKHR(
    VkDevice _device,
    VkPipeline _pipeline,
    uint32_t firstGroup,
    uint32_t groupCount,
    size_t dataSize,
    void *pData)
{
   VK_FROM_HANDLE(lvp_pipeline, pipeline, _pipeline);

   uint8_t *data = pData;
   memset(data, 0, dataSize);

   for (uint32_t i = 0; i < groupCount; i++) {
      memcpy(data + i * LVP_RAY_TRACING_GROUP_HANDLE_SIZE,
             pipeline->rt.groups + firstGroup + i,
             sizeof(struct lvp_ray_tracing_group_handle));
   }

   return VK_SUCCESS;
}

VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetRayTracingCaptureReplayShaderGroupHandlesKHR(
   VkDevice device,
   VkPipeline pipeline,
   uint32_t firstGroup,
   uint32_t groupCount,
   size_t dataSize,
   void *pData)
{
   /* Capture/replay of group handles is not supported; nothing is written. */
   return VK_SUCCESS;
}

VKAPI_ATTR VkDeviceSize VKAPI_CALL
lvp_GetRayTracingShaderGroupStackSizeKHR(
   VkDevice device,
   VkPipeline pipeline,
   uint32_t group,
   VkShaderGroupShaderKHR groupShader)
{
   /* Per-group stack sizes are not tracked; the compiled pipeline already
    * reserves a worst-case scratch stack, so report a token value.
    */
   return 4;
}