/*
 * Copyright © 2024 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "lvp_private.h"
#include "lvp_acceleration_structure.h"
#include "lvp_nir_ray_tracing.h"

#include "vk_pipeline.h"

#include "nir.h"
#include "nir_builder.h"

#include "spirv/spirv.h"

#include "util/mesa-sha1.h"
#include "util/simple_mtx.h"
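/* Build pipeline->rt.groups from the create info: record the recursive
 * (general/closest-hit), any-hit and intersection stage indices of each group
 * and assign a device-unique group handle.  Groups imported from pipeline
 * libraries are appended afterwards with their stage indices rebased past
 * this pipeline's own stages.
 */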
static void
lvp_init_ray_tracing_groups(struct lvp_pipeline *pipeline,
                            const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   uint32_t i = 0;
   for (; i < create_info->groupCount; i++) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = create_info->pGroups + i;
      struct lvp_ray_tracing_group *dst = pipeline->rt.groups + i;

      dst->recursive_index = VK_SHADER_UNUSED_KHR;
      dst->ahit_index = VK_SHADER_UNUSED_KHR;
      dst->isec_index = VK_SHADER_UNUSED_KHR;

      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         if (group_info->generalShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->generalShader;
         }
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->closestHitShader;
         }
         if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR) {
            dst->ahit_index = group_info->anyHitShader;
         }
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->closestHitShader;
         }
         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR) {
            dst->isec_index = group_info->intersectionShader;

            if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
               dst->ahit_index = group_info->anyHitShader;
         }
         break;
      default:
         unreachable("Unimplemented VkRayTracingShaderGroupTypeKHR");
      }

      dst->handle.index = p_atomic_inc_return(&pipeline->device->group_handle_alloc);
   }

   if (!create_info->pLibraryInfo)
      return;

   uint32_t stage_base_index = create_info->stageCount;
   for (uint32_t library_index = 0; library_index < create_info->pLibraryInfo->libraryCount; library_index++) {
      VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[library_index]);
      for (uint32_t group_index = 0; group_index < library->rt.group_count; group_index++) {
         const struct lvp_ray_tracing_group *src = library->rt.groups + group_index;
         struct lvp_ray_tracing_group *dst = pipeline->rt.groups + i;

         dst->handle = src->handle;

         if (src->recursive_index != VK_SHADER_UNUSED_KHR)
            dst->recursive_index = stage_base_index + src->recursive_index;
         else
            dst->recursive_index = VK_SHADER_UNUSED_KHR;

         if (src->ahit_index != VK_SHADER_UNUSED_KHR)
            dst->ahit_index = stage_base_index + src->ahit_index;
         else
            dst->ahit_index = VK_SHADER_UNUSED_KHR;

         if (src->isec_index != VK_SHADER_UNUSED_KHR)
            dst->isec_index = stage_base_index + src->isec_index;
         else
            dst->isec_index = VK_SHADER_UNUSED_KHR;

         i++;
      }
      stage_base_index += library->rt.stage_count;
   }
}
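/* Rewrite shader-call-data and ray-hit-attribute derefs as function_temp
 * derefs so they can later be lowered to the scratch-based call stack:
 * call data lives at the offset passed in via
 * load_shader_call_data_offset_lvp, hit attributes at offset 0.
 */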
static bool
lvp_lower_ray_tracing_derefs(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;

   nir_builder _b = nir_builder_at(nir_before_impl(impl));
   nir_builder *b = &_b;

   nir_def *arg_offset = nir_load_shader_call_data_offset_lvp(b);

   nir_foreach_block (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (!nir_deref_mode_is_one_of(deref, nir_var_shader_call_data |
                                              nir_var_ray_hit_attrib))
            continue;

         bool is_shader_call_data = nir_deref_mode_is(deref, nir_var_shader_call_data);

         deref->modes = nir_var_function_temp;
         progress = true;

         if (deref->deref_type == nir_deref_type_var) {
            b->cursor = nir_before_instr(&deref->instr);
            nir_def *offset = is_shader_call_data ? arg_offset : nir_imm_int(b, 0);
            nir_deref_instr *replacement =
               nir_build_deref_cast(b, offset, nir_var_function_temp, deref->var->type, 0);
            nir_def_replace(&deref->def, &replacement->def);
         }
      }
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   else
      nir_metadata_preserve(impl, nir_metadata_all);

   return progress;
}
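/* Hoist loads of the incoming ray state to the top of the entrypoint so that,
 * once the stage is inlined, they observe the values on entry rather than
 * state overwritten by a nested trace_ray/executeCallable.
 */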
static bool
lvp_move_ray_tracing_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_shader_record_ptr:
   case nir_intrinsic_load_ray_flags:
   case nir_intrinsic_load_ray_object_origin:
   case nir_intrinsic_load_ray_world_origin:
   case nir_intrinsic_load_ray_t_min:
   case nir_intrinsic_load_ray_object_direction:
   case nir_intrinsic_load_ray_world_direction:
   case nir_intrinsic_load_ray_t_max:
      nir_instr_move(nir_before_impl(b->impl), &instr->instr);
      return true;
   default:
      return false;
   }
}
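/* Translate all stages of this pipeline to NIR and apply the ray tracing
 * specific lowering (hit attribute/call data derefs, scratch addressing).
 * Stages owned by pipeline libraries are shared by reference.
 */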
static VkResult
lvp_compile_ray_tracing_stages(struct lvp_pipeline *pipeline,
                               const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   VkResult result = VK_SUCCESS;

   uint32_t i = 0;
   for (; i < create_info->stageCount; i++) {
      nir_shader *nir;
      result = lvp_spirv_to_nir(pipeline, create_info->pStages + i, &nir);
      if (result != VK_SUCCESS)
         return result;

      assert(!nir->scratch_size);
      if (nir->info.stage == MESA_SHADER_ANY_HIT ||
          nir->info.stage == MESA_SHADER_CLOSEST_HIT ||
          nir->info.stage == MESA_SHADER_INTERSECTION)
         nir->scratch_size = LVP_RAY_HIT_ATTRIBS_SIZE;

      NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
               nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
               glsl_get_natural_size_align_bytes);

      NIR_PASS(_, nir, lvp_lower_ray_tracing_derefs);

      NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);

      NIR_PASS(_, nir, nir_shader_intrinsics_pass, lvp_move_ray_tracing_intrinsic,
               nir_metadata_control_flow, NULL);

      pipeline->rt.stages[i] = lvp_create_pipeline_nir(nir);
      if (!pipeline->rt.stages[i]) {
         result = VK_ERROR_OUT_OF_HOST_MEMORY;
         ralloc_free(nir);
         return result;
      }
      if (pipeline->layout)
         pipeline->shaders[nir->info.stage].push_constant_size = pipeline->layout->push_constant_size;
   }

   if (!create_info->pLibraryInfo)
      return result;

   for (uint32_t library_index = 0; library_index < create_info->pLibraryInfo->libraryCount; library_index++) {
      VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[library_index]);
      for (uint32_t stage_index = 0; stage_index < library->rt.stage_count; stage_index++) {
         lvp_pipeline_nir_ref(pipeline->rt.stages + i, library->rt.stages[stage_index]);
         i++;
      }
   }

   return result;
}
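/* The VkTraceRaysIndirectCommand2KHR of the current dispatch is expected to
 * be bound as SSBO 0; read one of its fields at the given byte offset.
 */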
static nir_def *
lvp_load_trace_ray_command_field(nir_builder *b, uint32_t command_offset,
                                 uint32_t num_components, uint32_t bit_size)
{
   return nir_load_ssbo(b, num_components, bit_size, nir_imm_int(b, 0),
                        nir_imm_int(b, command_offset));
}

struct lvp_sbt_entry {
   nir_def *value;
   nir_def *shader_record_ptr;
};
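/* Read one shader binding table record: the 32-bit value at index_offset
 * within the group handle (used to select the stage to run) and the address
 * of the shader record data that follows the handle.  If an index is given,
 * the SBT stride from the indirect command is applied first.
 */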
static struct lvp_sbt_entry
lvp_load_sbt_entry(nir_builder *b, nir_def *index,
                   uint32_t command_offset, uint32_t index_offset)
{
   nir_def *addr = lvp_load_trace_ray_command_field(b, command_offset, 1, 64);

   if (index) {
      /* The 32 high bits of stride can be ignored. */
      nir_def *stride = lvp_load_trace_ray_command_field(
         b, command_offset + sizeof(VkDeviceSize) * 2, 1, 32);
      addr = nir_iadd(b, addr, nir_u2u64(b, nir_imul(b, index, stride)));
   }

   return (struct lvp_sbt_entry) {
      .value = nir_build_load_global(b, 1, 32, nir_iadd_imm(b, addr, index_offset)),
      .shader_record_ptr = nir_iadd_imm(b, addr, LVP_RAY_TRACING_GROUP_HANDLE_SIZE),
   };
}
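/* Variables that are local to a single trace_ray call and drive the BVH
 * traversal loop (see lvp_trace_ray and lvp_ray_traversal_state_init).
 */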
struct lvp_ray_traversal_state {
   nir_variable *origin;
   nir_variable *dir;
   nir_variable *inv_dir;
   nir_variable *bvh_base;
   nir_variable *current_node;
   nir_variable *stack_base;
   nir_variable *stack_ptr;
   nir_variable *stack;
   nir_variable *hit;

   nir_variable *instance_addr;
   nir_variable *sbt_offset_and_flags;
};

struct lvp_ray_tracing_state {
   nir_variable *bvh_base;
   nir_variable *flags;
   nir_variable *cull_mask;
   nir_variable *sbt_offset;
   nir_variable *sbt_stride;
   nir_variable *miss_index;
   nir_variable *origin;
   nir_variable *tmin;
   nir_variable *dir;
   nir_variable *tmax;

   nir_variable *instance_addr;
   nir_variable *primitive_id;
   nir_variable *geometry_id_and_flags;
   nir_variable *hit_kind;
   nir_variable *sbt_index;

   nir_variable *shader_record_ptr;
   nir_variable *stack_ptr;
   nir_variable *shader_call_data_offset;

   nir_variable *accept;
   nir_variable *terminate;
   nir_variable *opaque;

   struct lvp_ray_traversal_state traversal;
};

struct lvp_ray_tracing_pipeline_compiler {
   struct lvp_pipeline *pipeline;
   VkPipelineCreateFlags2KHR flags;

   struct lvp_ray_tracing_state state;

   struct hash_table *functions;

   uint32_t raygen_size;
   uint32_t ahit_size;
   uint32_t chit_size;
   uint32_t miss_size;
   uint32_t isec_size;
   uint32_t callable_size;
};
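/* Return the scratch size of the stage that was inlined as the given
 * function, so callers can advance the stack pointer past its frame.
 */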
static uint32_t
lvp_ray_tracing_pipeline_compiler_get_stack_size(
   struct lvp_ray_tracing_pipeline_compiler *compiler, nir_function *function)
{
   hash_table_foreach(compiler->functions, entry) {
      if (entry->data == function) {
         const nir_shader *shader = entry->key;
         return shader->scratch_size;
      }
   }
   return 0;
}

static void
lvp_ray_tracing_state_init(nir_shader *nir, struct lvp_ray_tracing_state *state)
{
   state->bvh_base = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "bvh_base");
   state->flags = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "flags");
   state->cull_mask = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
   state->sbt_offset = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
   state->sbt_stride = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
   state->miss_index = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "miss_index");
   state->origin = nir_variable_create(nir, nir_var_shader_temp, glsl_vec_type(3), "origin");
   state->tmin = nir_variable_create(nir, nir_var_shader_temp, glsl_float_type(), "tmin");
   state->dir = nir_variable_create(nir, nir_var_shader_temp, glsl_vec_type(3), "dir");
   state->tmax = nir_variable_create(nir, nir_var_shader_temp, glsl_float_type(), "tmax");

   state->instance_addr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
   state->primitive_id = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
   state->geometry_id_and_flags = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
   state->hit_kind = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
   state->sbt_index = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_index");

   state->shader_record_ptr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
   state->stack_ptr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
   state->shader_call_data_offset = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "shader_call_data_offset");

   state->accept = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "accept");
   state->terminate = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "terminate");
   state->opaque = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "opaque");
}

static void
lvp_ray_traversal_state_init(nir_function_impl *impl, struct lvp_ray_traversal_state *state)
{
   state->origin = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.origin");
   state->dir = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.dir");
   state->inv_dir = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.inv_dir");
   state->bvh_base = nir_local_variable_create(impl, glsl_uint64_t_type(), "traversal.bvh_base");
   state->current_node = nir_local_variable_create(impl, glsl_uint_type(), "traversal.current_node");
   state->stack_base = nir_local_variable_create(impl, glsl_uint_type(), "traversal.stack_base");
   state->stack_ptr = nir_local_variable_create(impl, glsl_uint_type(), "traversal.stack_ptr");
   state->stack = nir_local_variable_create(impl, glsl_array_type(glsl_uint_type(), 24 * 2, 0), "traversal.stack");
   state->hit = nir_local_variable_create(impl, glsl_bool_type(), "traversal.hit");

   state->instance_addr = nir_local_variable_create(impl, glsl_uint64_t_type(), "traversal.instance_addr");
   state->sbt_offset_and_flags = nir_local_variable_create(impl, glsl_uint_type(), "traversal.sbt_offset_and_flags");
}
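/* Emit a call to the given stage from the monolithic shader.  The stage's
 * entrypoint is cloned into a nir_function on first use (with its global
 * variables re-created in the destination shader) and reused afterwards.
 * Per-stage scratch sizes are tracked for the pipeline stack size estimate.
 */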
static void
lvp_call_ray_tracing_stage(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler, nir_shader *stage)
{
   nir_function *function;

   struct hash_entry *entry = _mesa_hash_table_search(compiler->functions, stage);
   if (entry) {
      function = entry->data;
   } else {
      nir_function_impl *stage_entrypoint = nir_shader_get_entrypoint(stage);
      nir_function_impl *copy = nir_function_impl_clone(b->shader, stage_entrypoint);

      struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);

      nir_foreach_block(block, copy) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (deref->deref_type != nir_deref_type_var ||
                deref->var->data.mode == nir_var_function_temp)
               continue;

            struct hash_entry *entry =
               _mesa_hash_table_search(var_remap, deref->var);
            if (!entry) {
               nir_variable *new_var = nir_variable_clone(deref->var, b->shader);
               nir_shader_add_variable(b->shader, new_var);
               entry = _mesa_hash_table_insert(var_remap,
                                               deref->var, new_var);
            }
            deref->var = entry->data;
         }
      }

      function = nir_function_create(
         b->shader, _mesa_shader_stage_to_string(stage->info.stage));
      nir_function_set_impl(function, copy);

      ralloc_free(var_remap);

      _mesa_hash_table_insert(compiler->functions, stage, function);
   }

   nir_build_call(b, function, 0, NULL);

   switch (stage->info.stage) {
   case MESA_SHADER_RAYGEN:
      compiler->raygen_size = MAX2(compiler->raygen_size, stage->scratch_size);
      break;
   case MESA_SHADER_ANY_HIT:
      compiler->ahit_size = MAX2(compiler->ahit_size, stage->scratch_size);
      break;
   case MESA_SHADER_CLOSEST_HIT:
      compiler->chit_size = MAX2(compiler->chit_size, stage->scratch_size);
      break;
   case MESA_SHADER_MISS:
      compiler->miss_size = MAX2(compiler->miss_size, stage->scratch_size);
      break;
   case MESA_SHADER_INTERSECTION:
      compiler->isec_size = MAX2(compiler->isec_size, stage->scratch_size);
      break;
   case MESA_SHADER_CALLABLE:
      compiler->callable_size = MAX2(compiler->callable_size, stage->scratch_size);
      break;
   default:
      unreachable("Invalid ray tracing stage");
      break;
   }
}
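/* Lower executeCallableEXT: look up the callable SBT record, open a new stack
 * frame, publish the payload location via shader_call_data_offset and branch
 * to the callable stage whose group handle matches the record.
 */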
static void
lvp_execute_callable(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler,
                     nir_intrinsic_instr *instr)
{
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_def *sbt_index = instr->src[0].ssa;
   nir_def *payload = instr->src[1].ssa;

   struct lvp_sbt_entry callable_entry = lvp_load_sbt_entry(
      b,
      sbt_index,
      offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler->state.shader_record_ptr, callable_entry.shader_record_ptr, 0x1);

   uint32_t stack_size =
      lvp_ray_tracing_pipeline_compiler_get_stack_size(compiler, b->impl->function);
   nir_def *stack_ptr = nir_load_var(b, state->stack_ptr);
   nir_store_var(b, state->stack_ptr, nir_iadd_imm(b, stack_ptr, stack_size), 0x1);

   nir_store_var(b, state->shader_call_data_offset, nir_iadd_imm(b, payload, -stack_size), 0x1);

   for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
      if (group->recursive_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
      if (stage->info.stage != MESA_SHADER_CALLABLE)
         continue;

      nir_push_if(b, nir_ieq_imm(b, callable_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, compiler, stage);
      nir_pop_if(b, NULL);
   }

   nir_store_var(b, state->stack_ptr, stack_ptr, 0x1);
}

struct lvp_lower_isec_intrinsic_state {
   struct lvp_ray_tracing_pipeline_compiler *compiler;
   nir_shader *ahit;
};
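/* Lower reportIntersectionEXT inside an inlined intersection shader: accept
 * the candidate hit if it lies within [tmin, tmax], run the group's any-hit
 * shader for non-opaque hits, and either commit the hit or restore the
 * previous closest hit if it was ignored.
 */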
static bool
lvp_lower_isec_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   if (instr->intrinsic != nir_intrinsic_report_ray_intersection)
      return false;

   struct lvp_lower_isec_intrinsic_state *isec_state = data;
   struct lvp_ray_tracing_pipeline_compiler *compiler = isec_state->compiler;
   struct lvp_ray_tracing_state *state = &compiler->state;

   b->cursor = nir_after_instr(&instr->instr);

   nir_def *t = instr->src[0].ssa;
   nir_def *hit_kind = instr->src[1].ssa;

   nir_def *prev_accept = nir_load_var(b, state->accept);
   nir_def *prev_tmax = nir_load_var(b, state->tmax);
   nir_def *prev_hit_kind = nir_load_var(b, state->hit_kind);

   nir_variable *commit = nir_local_variable_create(b->impl, glsl_bool_type(), "commit");
   nir_store_var(b, commit, nir_imm_false(b), 0x1);

   nir_def *in_range = nir_iand(b, nir_fge(b, t, nir_load_var(b, state->tmin)),
                                nir_fge(b, nir_load_var(b, state->tmax), t));
   nir_def *terminated = nir_iand(b, nir_load_var(b, state->terminate), nir_load_var(b, state->accept));
   nir_push_if(b, nir_iand(b, in_range, nir_inot(b, terminated)));
   {
      nir_store_var(b, state->accept, nir_imm_true(b), 0x1);

      nir_store_var(b, state->tmax, t, 0x1);
      nir_store_var(b, state->hit_kind, hit_kind, 0x1);

      if (isec_state->ahit) {
         nir_def *prev_terminate = nir_load_var(b, state->terminate);
         nir_store_var(b, state->terminate, nir_imm_false(b), 0x1);

         nir_push_if(b, nir_inot(b, nir_load_var(b, state->opaque)));
         {
            lvp_call_ray_tracing_stage(b, compiler, isec_state->ahit);
         }
         nir_pop_if(b, NULL);

         nir_def *terminate = nir_load_var(b, state->terminate);
         nir_store_var(b, state->terminate, nir_ior(b, terminate, prev_terminate), 0x1);

         nir_push_if(b, terminate);
         nir_jump(b, nir_jump_return);
         nir_pop_if(b, NULL);
      }

      nir_push_if(b, nir_load_var(b, state->accept));
      {
         nir_store_var(b, commit, nir_imm_true(b), 0x1);
      }
      nir_push_else(b, NULL);
      {
         nir_store_var(b, state->accept, prev_accept, 0x1);
         nir_store_var(b, state->tmax, prev_tmax, 0x1);
         nir_store_var(b, state->hit_kind, prev_hit_kind, 0x1);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);

   nir_def_replace(&instr->def, nir_load_var(b, commit));

   return true;
}
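/* Ray traversal callback for AABB leaf nodes: compute the SBT index for the
 * hit group, run its intersection shader (with reportIntersectionEXT lowered
 * against the group's any-hit shader) and record the committed hit, breaking
 * out of traversal when the ray flags request terminate-on-first-hit.
 */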
static void
lvp_handle_aabb_intersection(nir_builder *b, struct lvp_leaf_intersection *intersection,
                             const struct lvp_ray_traversal_args *args,
                             const struct lvp_ray_flags *ray_flags)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = args->data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_store_var(b, state->accept, nir_imm_false(b), 0x1);
   nir_store_var(b, state->terminate, ray_flags->terminate_on_first_hit, 0x1);
   nir_store_var(b, state->opaque, intersection->opaque, 0x1);

   nir_def *prev_instance_addr = nir_load_var(b, state->instance_addr);
   nir_def *prev_primitive_id = nir_load_var(b, state->primitive_id);
   nir_def *prev_geometry_id_and_flags = nir_load_var(b, state->geometry_id_and_flags);

   nir_store_var(b, state->instance_addr, nir_load_var(b, state->traversal.instance_addr), 0x1);
   nir_store_var(b, state->primitive_id, intersection->primitive_id, 0x1);
   nir_store_var(b, state->geometry_id_and_flags, intersection->geometry_id_and_flags, 0x1);

   nir_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
   nir_def *sbt_index =
      nir_iadd(b,
               nir_iadd(b, nir_load_var(b, state->sbt_offset),
                        nir_iand_imm(b, nir_load_var(b, state->traversal.sbt_offset_and_flags), 0xffffff)),
               nir_imul(b, nir_load_var(b, state->sbt_stride), geometry_id));

   struct lvp_sbt_entry isec_entry = lvp_load_sbt_entry(
      b,
      sbt_index,
      offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler->state.shader_record_ptr, isec_entry.shader_record_ptr, 0x1);

   for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
      if (group->isec_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = compiler->pipeline->rt.stages[group->isec_index]->nir;

      nir_push_if(b, nir_ieq_imm(b, isec_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, compiler, stage);
      nir_pop_if(b, NULL);

      nir_shader *ahit_stage = NULL;
      if (group->ahit_index != VK_SHADER_UNUSED_KHR)
         ahit_stage = compiler->pipeline->rt.stages[group->ahit_index]->nir;

      struct lvp_lower_isec_intrinsic_state isec_state = {
         .compiler = compiler,
         .ahit = ahit_stage,
      };
      nir_shader_intrinsics_pass(b->shader, lvp_lower_isec_intrinsic,
                                 nir_metadata_none, &isec_state);
   }

   nir_push_if(b, nir_load_var(b, state->accept));
   {
      nir_store_var(b, state->sbt_index, sbt_index, 0x1);
      nir_store_var(b, state->traversal.hit, nir_imm_true(b), 0x1);

      nir_break_if(b, nir_load_var(b, state->terminate));
   }
   nir_push_else(b, NULL);
   {
      nir_store_var(b, state->instance_addr, prev_instance_addr, 0x1);
      nir_store_var(b, state->primitive_id, prev_primitive_id, 0x1);
      nir_store_var(b, state->geometry_id_and_flags, prev_geometry_id_and_flags, 0x1);
   }
   nir_pop_if(b, NULL);
}
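/* Ray traversal callback for triangle leaf nodes: tentatively commit the hit
 * (including barycentrics spilled to scratch), run the any-hit shader for
 * non-opaque geometry and either keep the new closest hit or roll back to
 * the previous one if it was ignored.
 */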
static void
lvp_handle_triangle_intersection(nir_builder *b,
                                 struct lvp_triangle_intersection *intersection,
                                 const struct lvp_ray_traversal_args *args,
                                 const struct lvp_ray_flags *ray_flags)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = args->data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_store_var(b, state->accept, nir_imm_true(b), 0x1);
   nir_store_var(b, state->terminate, ray_flags->terminate_on_first_hit, 0x1);

   nir_def *barycentrics_offset = nir_load_var(b, state->stack_ptr);

   nir_def *prev_tmax = nir_load_var(b, state->tmax);
   nir_def *prev_instance_addr = nir_load_var(b, state->instance_addr);
   nir_def *prev_primitive_id = nir_load_var(b, state->primitive_id);
   nir_def *prev_geometry_id_and_flags = nir_load_var(b, state->geometry_id_and_flags);
   nir_def *prev_hit_kind = nir_load_var(b, state->hit_kind);
   nir_def *prev_barycentrics = nir_load_scratch(b, 2, 32, barycentrics_offset);

   nir_store_var(b, state->tmax, intersection->t, 0x1);
   nir_store_var(b, state->instance_addr, nir_load_var(b, state->traversal.instance_addr), 0x1);
   nir_store_var(b, state->primitive_id, intersection->base.primitive_id, 0x1);
   nir_store_var(b, state->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 0x1);
   nir_store_var(b, state->hit_kind,
                 nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF)), 0x1);

   nir_store_scratch(b, intersection->barycentrics, barycentrics_offset);

   nir_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
   nir_def *sbt_index =
      nir_iadd(b,
               nir_iadd(b, nir_load_var(b, state->sbt_offset),
                        nir_iand_imm(b, nir_load_var(b, state->traversal.sbt_offset_and_flags), 0xffffff)),
               nir_imul(b, nir_load_var(b, state->sbt_stride), geometry_id));

   nir_push_if(b, nir_inot(b, intersection->base.opaque));
   {
      struct lvp_sbt_entry ahit_entry = lvp_load_sbt_entry(
         b,
         sbt_index,
         offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, ahit_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->ahit_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->ahit_index]->nir;

         nir_push_if(b, nir_ieq_imm(b, ahit_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }
   }
   nir_pop_if(b, NULL);

   nir_push_if(b, nir_load_var(b, state->accept));
   {
      nir_store_var(b, state->sbt_index, sbt_index, 0x1);
      nir_store_var(b, state->traversal.hit, nir_imm_true(b), 0x1);

      nir_break_if(b, nir_load_var(b, state->terminate));
   }
   nir_push_else(b, NULL);
   {
      nir_store_var(b, state->tmax, prev_tmax, 0x1);
      nir_store_var(b, state->instance_addr, prev_instance_addr, 0x1);
      nir_store_var(b, state->primitive_id, prev_primitive_id, 0x1);
      nir_store_var(b, state->geometry_id_and_flags, prev_geometry_id_and_flags, 0x1);
      nir_store_var(b, state->hit_kind, prev_hit_kind, 0x1);
      nir_store_scratch(b, prev_barycentrics, barycentrics_offset);
   }
   nir_pop_if(b, NULL);
}
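/* Lower traceRayEXT: allocate a stack frame, initialize the global ray state
 * and the per-call traversal state, walk the BVH with lvp_build_ray_traversal
 * and finally run the closest-hit or miss shader selected through the SBT.
 */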
static void
lvp_trace_ray(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler,
              nir_intrinsic_instr *instr)
{
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_def *accel_struct = instr->src[0].ssa;
   nir_def *flags = instr->src[1].ssa;
   nir_def *cull_mask = instr->src[2].ssa;
   nir_def *sbt_offset = nir_iand_imm(b, instr->src[3].ssa, 0xF);
   nir_def *sbt_stride = nir_iand_imm(b, instr->src[4].ssa, 0xF);
   nir_def *miss_index = nir_iand_imm(b, instr->src[5].ssa, 0xFFFF);
   nir_def *origin = instr->src[6].ssa;
   nir_def *tmin = instr->src[7].ssa;
   nir_def *dir = instr->src[8].ssa;
   nir_def *tmax = instr->src[9].ssa;
   nir_def *payload = instr->src[10].ssa;

   uint32_t stack_size =
      lvp_ray_tracing_pipeline_compiler_get_stack_size(compiler, b->impl->function);
   nir_def *stack_ptr = nir_load_var(b, state->stack_ptr);
   nir_store_var(b, state->stack_ptr, nir_iadd_imm(b, stack_ptr, stack_size), 0x1);

   nir_store_var(b, state->shader_call_data_offset, nir_iadd_imm(b, payload, -stack_size), 0x1);

   nir_def *bvh_base = accel_struct;
   if (bvh_base->bit_size != 64) {
      assert(bvh_base->num_components >= 2);
      bvh_base = nir_load_ubo(
         b, 1, 64, nir_channel(b, accel_struct, 0),
         nir_imul_imm(b, nir_channel(b, accel_struct, 1), sizeof(struct lp_descriptor)), .range = ~0);
   }

   lvp_ray_traversal_state_init(b->impl, &state->traversal);

   nir_store_var(b, state->bvh_base, bvh_base, 0x1);
   nir_store_var(b, state->flags, flags, 0x1);
   nir_store_var(b, state->cull_mask, cull_mask, 0x1);
   nir_store_var(b, state->sbt_offset, sbt_offset, 0x1);
   nir_store_var(b, state->sbt_stride, sbt_stride, 0x1);
   nir_store_var(b, state->miss_index, miss_index, 0x1);
   nir_store_var(b, state->origin, origin, 0x7);
   nir_store_var(b, state->tmin, tmin, 0x1);
   nir_store_var(b, state->dir, dir, 0x7);
   nir_store_var(b, state->tmax, tmax, 0x1);

   nir_store_var(b, state->traversal.bvh_base, bvh_base, 0x1);
   nir_store_var(b, state->traversal.origin, origin, 0x7);
   nir_store_var(b, state->traversal.dir, dir, 0x7);
   nir_store_var(b, state->traversal.inv_dir, nir_frcp(b, dir), 0x7);
   nir_store_var(b, state->traversal.current_node, nir_imm_int(b, LVP_BVH_ROOT_NODE), 0x1);
   nir_store_var(b, state->traversal.stack_base, nir_imm_int(b, -1), 0x1);
   nir_store_var(b, state->traversal.stack_ptr, nir_imm_int(b, 0), 0x1);

   nir_store_var(b, state->traversal.hit, nir_imm_false(b), 0x1);

   struct lvp_ray_traversal_vars vars = {
      .tmax = nir_build_deref_var(b, state->tmax),
      .origin = nir_build_deref_var(b, state->traversal.origin),
      .dir = nir_build_deref_var(b, state->traversal.dir),
      .inv_dir = nir_build_deref_var(b, state->traversal.inv_dir),
      .bvh_base = nir_build_deref_var(b, state->traversal.bvh_base),
      .current_node = nir_build_deref_var(b, state->traversal.current_node),
      .stack_base = nir_build_deref_var(b, state->traversal.stack_base),
      .stack_ptr = nir_build_deref_var(b, state->traversal.stack_ptr),
      .stack = nir_build_deref_var(b, state->traversal.stack),
      .instance_addr = nir_build_deref_var(b, state->traversal.instance_addr),
      .sbt_offset_and_flags = nir_build_deref_var(b, state->traversal.sbt_offset_and_flags),
   };

   struct lvp_ray_traversal_args args = {
      .root_bvh_base = bvh_base,
      .flags = flags,
      .cull_mask = nir_ishl_imm(b, cull_mask, 24),
      .origin = origin,
      .tmin = tmin,
      .dir = dir,
      .vars = vars,
      .aabb_cb = (compiler->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_AABBS_BIT_KHR) ?
         NULL : lvp_handle_aabb_intersection,
      .triangle_cb = (compiler->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR) ?
         NULL : lvp_handle_triangle_intersection,
      .data = compiler,
   };

   nir_push_if(b, nir_ine_imm(b, bvh_base, 0));
   lvp_build_ray_traversal(b, &args);
   nir_pop_if(b, NULL);

   nir_push_if(b, nir_load_var(b, state->traversal.hit));
   {
      nir_def *skip_chit = nir_test_mask(b, flags, SpvRayFlagsSkipClosestHitShaderKHRMask);
      nir_push_if(b, nir_inot(b, skip_chit));

      struct lvp_sbt_entry chit_entry = lvp_load_sbt_entry(
         b,
         nir_load_var(b, state->sbt_index),
         offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, chit_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->recursive_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
         if (stage->info.stage != MESA_SHADER_CLOSEST_HIT)
            continue;

         nir_push_if(b, nir_ieq_imm(b, chit_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }

      nir_pop_if(b, NULL);
   }
   nir_push_else(b, NULL);
   {
      struct lvp_sbt_entry miss_entry = lvp_load_sbt_entry(
         b,
         miss_index,
         offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, miss_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->recursive_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
         if (stage->info.stage != MESA_SHADER_MISS)
            continue;

         nir_push_if(b, nir_ieq_imm(b, miss_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }
   }
   nir_pop_if(b, NULL);

   nir_store_var(b, state->stack_ptr, stack_ptr, 0x1);
}
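/* Lower the remaining ray tracing intrinsics and system values to reads of
 * the global ray state, rebase scratch access onto the software stack pointer
 * and turn halt jumps into returns (every stage is a function by now).
 */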
static bool
lvp_lower_ray_tracing_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   if (instr->type == nir_instr_type_jump) {
      nir_jump_instr *jump = nir_instr_as_jump(instr);
      if (jump->type == nir_jump_halt) {
         jump->type = nir_jump_return;
         return true;
      }
      return false;
   } else if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   nir_def *def = NULL;

   b->cursor = nir_before_instr(instr);

   switch (intr->intrinsic) {
   /* Ray tracing instructions */
   case nir_intrinsic_execute_callable:
      lvp_execute_callable(b, compiler, intr);
      break;
   case nir_intrinsic_trace_ray:
      lvp_trace_ray(b, compiler, intr);
      break;
   case nir_intrinsic_ignore_ray_intersection: {
      nir_store_var(b, state->accept, nir_imm_false(b), 0x1);

      nir_push_if(b, nir_imm_true(b));
      nir_jump(b, nir_jump_return);
      nir_pop_if(b, NULL);
      break;
   }
   case nir_intrinsic_terminate_ray: {
      nir_store_var(b, state->accept, nir_imm_true(b), 0x1);
      nir_store_var(b, state->terminate, nir_imm_true(b), 0x1);

      nir_push_if(b, nir_imm_true(b));
      nir_jump(b, nir_jump_return);
      nir_pop_if(b, NULL);
      break;
   }
   /* Ray tracing system values */
   case nir_intrinsic_load_ray_launch_id:
      def = nir_load_global_invocation_id(b, 32);
      break;
   case nir_intrinsic_load_ray_launch_size:
      def = lvp_load_trace_ray_command_field(
         b, offsetof(VkTraceRaysIndirectCommand2KHR, width), 3, 32);
      break;
   case nir_intrinsic_load_shader_record_ptr:
      def = nir_load_var(b, state->shader_record_ptr);
      break;
   case nir_intrinsic_load_ray_t_min:
      def = nir_load_var(b, state->tmin);
      break;
   case nir_intrinsic_load_ray_t_max:
      def = nir_load_var(b, state->tmax);
      break;
   case nir_intrinsic_load_ray_world_origin:
      def = nir_load_var(b, state->origin);
      break;
   case nir_intrinsic_load_ray_world_direction:
      def = nir_load_var(b, state->dir);
      break;
   case nir_intrinsic_load_ray_instance_custom_index: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *custom_instance_and_mask = nir_build_load_global(
         b, 1, 32,
         nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, custom_instance_and_mask)));
      def = nir_iand_imm(b, custom_instance_and_mask, 0xFFFFFF);
      break;
   }
   case nir_intrinsic_load_primitive_id:
      def = nir_load_var(b, state->primitive_id);
      break;
   case nir_intrinsic_load_ray_geometry_index:
      def = nir_load_var(b, state->geometry_id_and_flags);
      def = nir_iand_imm(b, def, 0xFFFFFFF);
      break;
   case nir_intrinsic_load_instance_id: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      def = nir_build_load_global(
         b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, instance_id)));
      break;
   }
   case nir_intrinsic_load_ray_flags:
      def = nir_load_var(b, state->flags);
      break;
   case nir_intrinsic_load_ray_hit_kind:
      def = nir_load_var(b, state->hit_kind);
      break;
   case nir_intrinsic_load_ray_world_to_object: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);

      nir_def *vals[3];
      for (unsigned i = 0; i < 3; ++i)
         vals[i] = nir_channel(b, wto_matrix[i], c);

      def = nir_vec(b, vals, 3);
      break;
   }
   case nir_intrinsic_load_ray_object_to_world: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = nir_build_load_global(
            b, 4, 32,
            nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, otw_matrix) + r * 16));
      def = nir_vec3(b, nir_channel(b, rows[0], c), nir_channel(b, rows[1], c), nir_channel(b, rows[2], c));
      break;
   }
   case nir_intrinsic_load_ray_object_origin: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
      def = lvp_mul_vec3_mat(b, nir_load_var(b, state->origin), wto_matrix, true);
      break;
   }
   case nir_intrinsic_load_ray_object_direction: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
      def = lvp_mul_vec3_mat(b, nir_load_var(b, state->dir), wto_matrix, false);
      break;
   }
   case nir_intrinsic_load_cull_mask:
      def = nir_iand_imm(b, nir_load_var(b, state->cull_mask), 0xFF);
      break;
   /* Ray tracing stack lowering */
   case nir_intrinsic_load_scratch: {
      nir_src_rewrite(&intr->src[0], nir_iadd(b, nir_load_var(b, state->stack_ptr), intr->src[0].ssa));
      return true;
   }
   case nir_intrinsic_store_scratch: {
      nir_src_rewrite(&intr->src[1], nir_iadd(b, nir_load_var(b, state->stack_ptr), intr->src[1].ssa));
      return true;
   }
   case nir_intrinsic_load_ray_triangle_vertex_positions: {
      def = lvp_load_vertex_position(
         b, nir_load_var(b, state->instance_addr), nir_load_var(b, state->primitive_id),
         nir_intrinsic_column(intr));
      break;
   }
   /* Internal system values */
   case nir_intrinsic_load_shader_call_data_offset_lvp:
      def = nir_load_var(b, state->shader_call_data_offset);
      break;
   default:
      return false;
   }

   if (def)
      nir_def_rewrite_uses(&intr->def, def);
   nir_instr_remove(instr);

   return true;
}
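/* The software ray tracing stack lives in scratch, right after the statically
 * allocated scratch of the combined shader.
 */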
static bool
lvp_lower_ray_tracing_stack_base(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   if (instr->intrinsic != nir_intrinsic_load_ray_tracing_stack_base_lvp)
      return false;

   b->cursor = nir_after_instr(&instr->instr);

   nir_def_replace(&instr->def, nir_imm_int(b, b->shader->scratch_size));

   return true;
}
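/* Build the monolithic compute shader implementing the whole pipeline: guard
 * against out-of-bounds launch IDs, select and call the raygen shader through
 * the SBT, lower the result and size the scratch stack for the worst case
 * since dynamic stack sizes are not supported.
 */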
static void
lvp_compile_ray_tracing_pipeline(struct lvp_pipeline *pipeline,
                                 const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   nir_builder _b = nir_builder_init_simple_shader(
      MESA_SHADER_COMPUTE,
      pipeline->device->pscreen->get_compiler_options(pipeline->device->pscreen, PIPE_SHADER_IR_NIR, MESA_SHADER_COMPUTE),
      "ray tracing pipeline");
   nir_builder *b = &_b;

   b->shader->info.workgroup_size[0] = 8;

   struct lvp_ray_tracing_pipeline_compiler compiler = {
      .pipeline = pipeline,
      .flags = vk_rt_pipeline_create_flags(create_info),
   };
   lvp_ray_tracing_state_init(b->shader, &compiler.state);
   compiler.functions = _mesa_pointer_hash_table_create(NULL);

   nir_def *launch_id = nir_load_ray_launch_id(b);
   nir_def *launch_size = nir_load_ray_launch_size(b);
   nir_def *oob = nir_ige(b, nir_channel(b, launch_id, 0), nir_channel(b, launch_size, 0));
   oob = nir_ior(b, oob, nir_ige(b, nir_channel(b, launch_id, 1), nir_channel(b, launch_size, 1)));
   oob = nir_ior(b, oob, nir_ige(b, nir_channel(b, launch_id, 2), nir_channel(b, launch_size, 2)));

   nir_push_if(b, oob);
   nir_jump(b, nir_jump_return);
   nir_pop_if(b, NULL);

   nir_store_var(b, compiler.state.stack_ptr, nir_load_ray_tracing_stack_base_lvp(b), 0x1);

   struct lvp_sbt_entry raygen_entry = lvp_load_sbt_entry(
      b,
      NULL,
      offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler.state.shader_record_ptr, raygen_entry.shader_record_ptr, 0x1);

   for (uint32_t i = 0; i < pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = pipeline->rt.groups + i;
      if (group->recursive_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = pipeline->rt.stages[group->recursive_index]->nir;

      if (stage->info.stage != MESA_SHADER_RAYGEN)
         continue;

      nir_push_if(b, nir_ieq_imm(b, raygen_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, &compiler, stage);
      nir_pop_if(b, NULL);
   }

   nir_shader_instructions_pass(b->shader, lvp_lower_ray_tracing_instr, nir_metadata_none, &compiler);

   NIR_PASS(_, b->shader, nir_lower_returns);

   const struct nir_lower_compute_system_values_options compute_system_values = {0};
   NIR_PASS(_, b->shader, nir_lower_compute_system_values, &compute_system_values);
   NIR_PASS(_, b->shader, nir_lower_global_vars_to_local);
   NIR_PASS(_, b->shader, nir_lower_vars_to_ssa);

   NIR_PASS(_, b->shader, nir_lower_vars_to_explicit_types,
            nir_var_shader_temp,
            glsl_get_natural_size_align_bytes);

   NIR_PASS(_, b->shader, nir_lower_explicit_io, nir_var_shader_temp,
            nir_address_format_32bit_offset);

   NIR_PASS(_, b->shader, nir_shader_intrinsics_pass, lvp_lower_ray_tracing_stack_base,
            nir_metadata_control_flow, NULL);

   /* We cannot support dynamic stack sizes, assume the worst. */
   b->shader->scratch_size +=
      compiler.raygen_size +
      MIN2(create_info->maxPipelineRayRecursionDepth, 1) * MAX3(compiler.chit_size, compiler.miss_size, compiler.isec_size + compiler.ahit_size) +
      MAX2(0, (int)create_info->maxPipelineRayRecursionDepth - 1) * MAX2(compiler.chit_size, compiler.miss_size) + 31 * compiler.callable_size;

   struct lvp_shader *shader = &pipeline->shaders[MESA_SHADER_RAYGEN];
   lvp_shader_init(shader, b->shader);
   shader->shader_cso = lvp_shader_compile(pipeline->device, shader, nir_shader_clone(NULL, shader->pipeline_nir->nir), false);

   _mesa_hash_table_destroy(compiler.functions, NULL);
}
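/* Create the pipeline object: gather stages and groups (including those of
 * pipeline libraries), compile the stages to NIR and, unless this pipeline is
 * itself a library, build the final monolithic shader.
 */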
static VkResult
lvp_create_ray_tracing_pipeline(VkDevice _device, const VkAllocationCallbacks *allocator,
                                const VkRayTracingPipelineCreateInfoKHR *create_info,
                                VkPipeline *out_pipeline)
{
   VK_FROM_HANDLE(lvp_device, device, _device);
   VK_FROM_HANDLE(lvp_pipeline_layout, layout, create_info->layout);

   VkResult result = VK_SUCCESS;

   struct lvp_pipeline *pipeline = vk_zalloc2(&device->vk.alloc, allocator, sizeof(struct lvp_pipeline), 8,
                                              VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
   if (!pipeline)
      return VK_ERROR_OUT_OF_HOST_MEMORY;

   vk_object_base_init(&device->vk, &pipeline->base,
                       VK_OBJECT_TYPE_PIPELINE);

   vk_pipeline_layout_ref(&layout->vk);

   pipeline->device = device;
   pipeline->layout = layout;
   pipeline->type = LVP_PIPELINE_RAY_TRACING;
   pipeline->flags = vk_rt_pipeline_create_flags(create_info);

   pipeline->rt.stage_count = create_info->stageCount;
   pipeline->rt.group_count = create_info->groupCount;
   if (create_info->pLibraryInfo) {
      for (uint32_t i = 0; i < create_info->pLibraryInfo->libraryCount; i++) {
         VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[i]);
         pipeline->rt.stage_count += library->rt.stage_count;
         pipeline->rt.group_count += library->rt.group_count;
      }
   }

   pipeline->rt.stages = calloc(pipeline->rt.stage_count, sizeof(struct lvp_pipeline_nir *));
   pipeline->rt.groups = calloc(pipeline->rt.group_count, sizeof(struct lvp_ray_tracing_group));
   if (!pipeline->rt.stages || !pipeline->rt.groups) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto fail;
   }

   result = lvp_compile_ray_tracing_stages(pipeline, create_info);
   if (result != VK_SUCCESS)
      goto fail;

   lvp_init_ray_tracing_groups(pipeline, create_info);

   if (!(pipeline->flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR)) {
      lvp_compile_ray_tracing_pipeline(pipeline, create_info);
   }

   *out_pipeline = lvp_pipeline_to_handle(pipeline);

   return VK_SUCCESS;

fail:
   lvp_pipeline_destroy(device, pipeline, false);
   return result;
}

VKAPI_ATTR VkResult VKAPI_CALL
lvp_CreateRayTracingPipelinesKHR(
   VkDevice device,
   VkDeferredOperationKHR deferredOperation,
   VkPipelineCache pipelineCache,
   uint32_t createInfoCount,
   const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
   const VkAllocationCallbacks *pAllocator,
   VkPipeline *pPipelines)
{
   VkResult result = VK_SUCCESS;

   uint32_t i = 0;
   for (; i < createInfoCount; i++) {
      VkResult tmp_result = lvp_create_ray_tracing_pipeline(
         device, pAllocator, pCreateInfos + i, pPipelines + i);

      if (tmp_result != VK_SUCCESS) {
         result = tmp_result;
         pPipelines[i] = VK_NULL_HANDLE;

         if (vk_rt_pipeline_create_flags(&pCreateInfos[i]) &
             VK_PIPELINE_CREATE_2_EARLY_RETURN_ON_FAILURE_BIT_KHR)
            break;
      }
   }

   for (; i < createInfoCount; i++)
      pPipelines[i] = VK_NULL_HANDLE;

   return result;
}

VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetRayTracingShaderGroupHandlesKHR(
   VkDevice _device,
   VkPipeline _pipeline,
   uint32_t firstGroup,
   uint32_t groupCount,
   size_t dataSize,
   void *pData)
{
   VK_FROM_HANDLE(lvp_pipeline, pipeline, _pipeline);

   uint8_t *data = pData;
   memset(data, 0, dataSize);

   for (uint32_t i = 0; i < groupCount; i++) {
      memcpy(data + i * LVP_RAY_TRACING_GROUP_HANDLE_SIZE,
             pipeline->rt.groups + firstGroup + i,
             sizeof(struct lvp_ray_tracing_group_handle));
   }

   return VK_SUCCESS;
}

VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetRayTracingCaptureReplayShaderGroupHandlesKHR(
   VkDevice device,
   VkPipeline pipeline,
   uint32_t firstGroup,
   uint32_t groupCount,
   size_t dataSize,
   void *pData)
{
   return VK_SUCCESS;
}

VKAPI_ATTR VkDeviceSize VKAPI_CALL
lvp_GetRayTracingShaderGroupStackSizeKHR(
   VkDevice device,
   VkPipeline pipeline,
   uint32_t group,
   VkShaderGroupShaderKHR groupShader)
{
   return 4;
}