1 /*
2  * Copyright © 2021 Google
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "nir/nir.h"
8 #include "nir/nir_builder.h"
9 #include "nir/nir_serialize.h"
10 
11 #include "vk_shader_module.h"
12 
13 #include "nir/radv_nir.h"
14 #include "radv_debug.h"
15 #include "radv_descriptor_set.h"
16 #include "radv_entrypoints.h"
17 #include "radv_pipeline_binary.h"
18 #include "radv_pipeline_cache.h"
19 #include "radv_pipeline_rt.h"
20 #include "radv_rmv.h"
21 #include "radv_shader.h"
22 #include "ac_nir.h"
23 
24 struct rt_handle_hash_entry {
25    uint32_t key;
26    char hash[20];
27 };
28 
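/* Allocate a 32-bit shader group handle for a stage, derived from its SHA-1.
 * Bit 31 is always set (the low half of the handle space is left for resume
 * shaders) and bit 30 selects the capture/replay namespace. Collisions in the
 * device-wide handle table are resolved by incrementing the handle.
 */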
29 static uint32_t
30 handle_from_stages(struct radv_device *device, const unsigned char *shader_sha1, bool replay_namespace)
31 {
32    uint32_t ret;
33 
34    memcpy(&ret, shader_sha1, sizeof(ret));
35 
36    /* Leave the low half for resume shaders etc. */
37    ret |= 1u << 31;
38 
39    /* Ensure we have dedicated space for replayable shaders */
40    ret &= ~(1u << 30);
41    ret |= replay_namespace << 30;
42 
43    simple_mtx_lock(&device->rt_handles_mtx);
44 
45    struct hash_entry *he = NULL;
46    for (;;) {
47       he = _mesa_hash_table_search(device->rt_handles, &ret);
48       if (!he)
49          break;
50 
51       if (memcmp(he->data, shader_sha1, SHA1_DIGEST_LENGTH) == 0)
52          break;
53 
54       ++ret;
55    }
56 
57    if (!he) {
58       struct rt_handle_hash_entry *e = ralloc(device->rt_handles, struct rt_handle_hash_entry);
59       e->key = ret;
60       memcpy(e->hash, shader_sha1, SHA1_DIGEST_LENGTH);
61       _mesa_hash_table_insert(device->rt_handles, &e->key, &e->hash);
62    }
63 
64    simple_mtx_unlock(&device->rt_handles_mtx);
65 
66    return ret;
67 }
68 
69 static void
70 radv_generate_rt_shaders_key(const struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
71                              struct radv_shader_stage_key *stage_keys)
72 {
73    VkPipelineCreateFlags2 create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
74 
75    for (uint32_t i = 0; i < pCreateInfo->stageCount; i++) {
76       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i];
77       gl_shader_stage s = vk_to_mesa_shader_stage(stage->stage);
78 
79       stage_keys[s] = radv_pipeline_get_shader_key(device, stage, create_flags, pCreateInfo->pNext);
80    }
81 
82    if (pCreateInfo->pLibraryInfo) {
83       for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
84          VK_FROM_HANDLE(radv_pipeline, pipeline_lib, pCreateInfo->pLibraryInfo->pLibraries[i]);
85          struct radv_ray_tracing_pipeline *library_pipeline = radv_pipeline_to_ray_tracing(pipeline_lib);
86          /* apply shader robustness from merged shaders */
87          if (library_pipeline->traversal_storage_robustness2)
88             stage_keys[MESA_SHADER_INTERSECTION].storage_robustness2 = true;
89 
90          if (library_pipeline->traversal_uniform_robustness2)
91             stage_keys[MESA_SHADER_INTERSECTION].uniform_robustness2 = true;
92       }
93    }
94 }
95 
96 static VkResult
97 radv_create_group_handles(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
98                           const struct radv_ray_tracing_stage *stages, struct radv_ray_tracing_group *groups)
99 {
100    VkPipelineCreateFlags2 create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
101    bool capture_replay = create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR;
102    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
103       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
104       switch (group_info->type) {
105       case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
106          if (group_info->generalShader != VK_SHADER_UNUSED_KHR) {
107             const struct radv_ray_tracing_stage *stage = &stages[group_info->generalShader];
108             groups[i].handle.general_index = handle_from_stages(device, stage->sha1, capture_replay);
109          }
110          break;
111       case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
112          if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
113             const struct radv_ray_tracing_stage *stage = &stages[group_info->closestHitShader];
114             groups[i].handle.closest_hit_index = handle_from_stages(device, stage->sha1, capture_replay);
115          }
116 
117          if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR) {
118             unsigned char sha1[SHA1_DIGEST_LENGTH];
119             struct mesa_sha1 ctx;
120 
121             _mesa_sha1_init(&ctx);
122             _mesa_sha1_update(&ctx, stages[group_info->intersectionShader].sha1, SHA1_DIGEST_LENGTH);
123             if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
124                _mesa_sha1_update(&ctx, stages[group_info->anyHitShader].sha1, SHA1_DIGEST_LENGTH);
125             _mesa_sha1_final(&ctx, sha1);
126 
127             groups[i].handle.intersection_index = handle_from_stages(device, sha1, capture_replay);
128          }
129          break;
130       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
131          if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
132             const struct radv_ray_tracing_stage *stage = &stages[group_info->closestHitShader];
133             groups[i].handle.closest_hit_index = handle_from_stages(device, stage->sha1, capture_replay);
134          }
135 
136          if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR) {
137             const struct radv_ray_tracing_stage *stage = &stages[group_info->anyHitShader];
138             groups[i].handle.any_hit_index = handle_from_stages(device, stage->sha1, capture_replay);
139          }
140          break;
141       case VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR:
142          unreachable("VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR");
143       }
144 
145       if (group_info->pShaderGroupCaptureReplayHandle) {
146          const struct radv_rt_capture_replay_handle *handle = group_info->pShaderGroupCaptureReplayHandle;
147          if (memcmp(&handle->non_recursive_idx, &groups[i].handle.any_hit_index, sizeof(uint32_t)) != 0) {
148             return VK_ERROR_INVALID_OPAQUE_CAPTURE_ADDRESS;
149          }
150       }
151    }
152 
153    return VK_SUCCESS;
154 }
155 
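/* For pipelines created with capture/replay handles, record the serialized
 * shader arena blocks for pipeline-local recursive shaders and relocate
 * already-compiled library shaders into the replayed arena block so their VAs
 * match the captured ones.
 */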
156 static VkResult
157 radv_rt_init_capture_replay(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
158                             const struct radv_ray_tracing_stage *stages, const struct radv_ray_tracing_group *groups,
159                             struct radv_serialized_shader_arena_block *capture_replay_blocks)
160 {
161    VkResult result = VK_SUCCESS;
162    uint32_t idx;
163 
164    for (idx = 0; idx < pCreateInfo->groupCount; idx++) {
165       if (!pCreateInfo->pGroups[idx].pShaderGroupCaptureReplayHandle)
166          continue;
167 
168       const struct radv_rt_capture_replay_handle *handle =
169          (const struct radv_rt_capture_replay_handle *)pCreateInfo->pGroups[idx].pShaderGroupCaptureReplayHandle;
170 
171       if (groups[idx].recursive_shader < pCreateInfo->stageCount) {
172          capture_replay_blocks[groups[idx].recursive_shader] = handle->recursive_shader_alloc;
173       } else if (groups[idx].recursive_shader != VK_SHADER_UNUSED_KHR) {
174          struct radv_shader *library_shader = stages[groups[idx].recursive_shader].shader;
175          simple_mtx_lock(&library_shader->replay_mtx);
176          /* If arena_va is 0, the pipeline is monolithic and the shader was inlined into raygen */
177          if (!library_shader->has_replay_alloc && handle->recursive_shader_alloc.arena_va) {
178             union radv_shader_arena_block *new_block =
179                radv_replay_shader_arena_block(device, &handle->recursive_shader_alloc, library_shader);
180             if (!new_block) {
181                result = VK_ERROR_INVALID_OPAQUE_CAPTURE_ADDRESS;
182                goto reloc_out;
183             }
184 
185             radv_shader_wait_for_upload(device, library_shader->upload_seq);
186             radv_free_shader_memory(device, library_shader->alloc);
187 
188             library_shader->alloc = new_block;
189             library_shader->has_replay_alloc = true;
190 
191             library_shader->bo = library_shader->alloc->arena->bo;
192             library_shader->va = radv_buffer_get_va(library_shader->bo) + library_shader->alloc->offset;
193 
194             if (!radv_shader_reupload(device, library_shader)) {
195                result = VK_ERROR_UNKNOWN;
196                goto reloc_out;
197             }
198          }
199 
200          reloc_out:
201             simple_mtx_unlock(&library_shader->replay_mtx);
202             if (result != VK_SUCCESS)
203                return result;
204          }
205    }
206 
207    return result;
208 }
209 
210 static VkResult
211 radv_rt_fill_group_info(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
212                         const struct radv_ray_tracing_stage *stages, struct radv_ray_tracing_group *groups)
213 {
214    VkResult result = radv_create_group_handles(device, pCreateInfo, stages, groups);
215 
216    uint32_t idx;
217    for (idx = 0; idx < pCreateInfo->groupCount; idx++) {
218       groups[idx].type = pCreateInfo->pGroups[idx].type;
219       if (groups[idx].type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
220          groups[idx].recursive_shader = pCreateInfo->pGroups[idx].generalShader;
221       else
222          groups[idx].recursive_shader = pCreateInfo->pGroups[idx].closestHitShader;
223       groups[idx].any_hit_shader = pCreateInfo->pGroups[idx].anyHitShader;
224       groups[idx].intersection_shader = pCreateInfo->pGroups[idx].intersectionShader;
225    }
226 
227    /* copy and adjust library groups (incl. handles) */
228    if (pCreateInfo->pLibraryInfo) {
229       unsigned stage_count = pCreateInfo->stageCount;
230       for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
231          VK_FROM_HANDLE(radv_pipeline, pipeline_lib, pCreateInfo->pLibraryInfo->pLibraries[i]);
232          struct radv_ray_tracing_pipeline *library_pipeline = radv_pipeline_to_ray_tracing(pipeline_lib);
233 
234          for (unsigned j = 0; j < library_pipeline->group_count; ++j) {
235             struct radv_ray_tracing_group *dst = &groups[idx + j];
236             *dst = library_pipeline->groups[j];
237             if (dst->recursive_shader != VK_SHADER_UNUSED_KHR)
238                dst->recursive_shader += stage_count;
239             if (dst->any_hit_shader != VK_SHADER_UNUSED_KHR)
240                dst->any_hit_shader += stage_count;
241             if (dst->intersection_shader != VK_SHADER_UNUSED_KHR)
242                dst->intersection_shader += stage_count;
243             /* Don't set the shader VA since the handles are part of the pipeline hash */
244             dst->handle.recursive_shader_ptr = 0;
245          }
246          idx += library_pipeline->group_count;
247          stage_count += library_pipeline->stage_count;
248       }
249    }
250 
251    return result;
252 }
253 
254 static void
255 radv_rt_fill_stage_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_ray_tracing_stage *stages)
256 {
257    uint32_t idx;
258    for (idx = 0; idx < pCreateInfo->stageCount; idx++)
259       stages[idx].stage = vk_to_mesa_shader_stage(pCreateInfo->pStages[idx].stage);
260 
261    if (pCreateInfo->pLibraryInfo) {
262       for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
263          VK_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]);
264          struct radv_ray_tracing_pipeline *library_pipeline = radv_pipeline_to_ray_tracing(pipeline);
265          for (unsigned j = 0; j < library_pipeline->stage_count; ++j) {
266             if (library_pipeline->stages[j].nir)
267                stages[idx].nir = vk_pipeline_cache_object_ref(library_pipeline->stages[j].nir);
268             if (library_pipeline->stages[j].shader)
269                stages[idx].shader = radv_shader_ref(library_pipeline->stages[j].shader);
270 
271             stages[idx].stage = library_pipeline->stages[j].stage;
272             stages[idx].stack_size = library_pipeline->stages[j].stack_size;
273             stages[idx].info = library_pipeline->stages[j].info;
274             memcpy(stages[idx].sha1, library_pipeline->stages[j].sha1, SHA1_DIGEST_LENGTH);
275             idx++;
276          }
277       }
278    }
279 }
280 
281 static void
282 radv_init_rt_stage_hashes(const struct radv_device *device, VkPipelineCreateFlags2 pipeline_flags,
283                           const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_ray_tracing_stage *stages,
284                           const struct radv_shader_stage_key *stage_keys)
285 {
286    const VkPipelineBinaryInfoKHR *binary_info = vk_find_struct_const(pCreateInfo->pNext, PIPELINE_BINARY_INFO_KHR);
287    if (binary_info && binary_info->binaryCount > 0) {
288       for (uint32_t i = 0; i < binary_info->binaryCount; i++) {
289          VK_FROM_HANDLE(radv_pipeline_binary, pipeline_binary, binary_info->pPipelineBinaries[i]);
290          struct blob_reader blob;
291 
292          blob_reader_init(&blob, pipeline_binary->data, pipeline_binary->size);
293 
294          const struct radv_ray_tracing_binary_header *header =
295             (const struct radv_ray_tracing_binary_header *)blob_read_bytes(&blob, sizeof(*header));
296 
297          if (header->is_traversal_shader)
298             continue;
299 
300          memcpy(stages[i].sha1, header->stage_sha1, SHA1_DIGEST_LENGTH);
301       }
302    } else {
303       for (uint32_t idx = 0; idx < pCreateInfo->stageCount; idx++) {
304          const VkPipelineShaderStageCreateInfo *sinfo = &pCreateInfo->pStages[idx];
305          gl_shader_stage s = vk_to_mesa_shader_stage(sinfo->stage);
306          struct mesa_sha1 ctx;
307 
308          _mesa_sha1_init(&ctx);
309          radv_pipeline_hash_shader_stage(pipeline_flags, sinfo, &stage_keys[s], &ctx);
310          _mesa_sha1_final(&ctx, stages[idx].sha1);
311       }
312    }
313 }
314 
315 static bool
316 should_move_rt_instruction(nir_intrinsic_instr *instr)
317 {
318    switch (instr->intrinsic) {
319    case nir_intrinsic_load_hit_attrib_amd:
320       return nir_intrinsic_base(instr) < RADV_MAX_HIT_ATTRIB_DWORDS;
321    case nir_intrinsic_load_rt_arg_scratch_offset_amd:
322    case nir_intrinsic_load_ray_flags:
323    case nir_intrinsic_load_ray_object_origin:
324    case nir_intrinsic_load_ray_world_origin:
325    case nir_intrinsic_load_ray_t_min:
326    case nir_intrinsic_load_ray_object_direction:
327    case nir_intrinsic_load_ray_world_direction:
328    case nir_intrinsic_load_ray_t_max:
329       return true;
330    default:
331       return false;
332    }
333 }
334 
335 static void
336 move_rt_instructions(nir_shader *shader)
337 {
338    nir_cursor target = nir_before_impl(nir_shader_get_entrypoint(shader));
339 
340    nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
341       nir_foreach_instr_safe (instr, block) {
342          if (instr->type != nir_instr_type_intrinsic)
343             continue;
344 
345          nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
346 
347          if (!should_move_rt_instruction(intrinsic))
348             continue;
349 
350          nir_instr_move(target, instr);
351       }
352    }
353 
354    nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_all & (~nir_metadata_instr_index));
355 }
356 
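/* Compile the NIR of a single RT stage (or the traversal shader) to an AMD
 * shader: lower RT I/O and the RT ABI, split out resume shaders with
 * nir_lower_shader_calls for non-monolithic recursive stages, then run the
 * backend compiler and optionally insert the result into the shaders cache.
 */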
357 static VkResult
358 radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
359                    const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_ray_tracing_pipeline *pipeline,
360                    bool monolithic, struct radv_shader_stage *stage, uint32_t *stack_size,
361                    struct radv_ray_tracing_stage_info *stage_info,
362                    const struct radv_ray_tracing_stage_info *traversal_stage_info,
363                    struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache,
364                    struct radv_shader **out_shader)
365 {
366    struct radv_physical_device *pdev = radv_device_physical(device);
367    struct radv_instance *instance = radv_physical_device_instance(pdev);
368 
369    struct radv_shader_binary *binary;
370    bool keep_executable_info = radv_pipeline_capture_shaders(device, pipeline->base.base.create_flags);
371    bool keep_statistic_info = radv_pipeline_capture_shader_stats(device, pipeline->base.base.create_flags);
372 
373    radv_nir_lower_rt_io(stage->nir, monolithic, 0);
374 
375    /* Gather shader info. */
376    nir_shader_gather_info(stage->nir, nir_shader_get_entrypoint(stage->nir));
377    radv_nir_shader_info_init(stage->stage, MESA_SHADER_NONE, &stage->info);
378    radv_nir_shader_info_pass(device, stage->nir, &stage->layout, &stage->key, NULL, RADV_PIPELINE_RAY_TRACING, false,
379                              &stage->info);
380 
381    /* Declare shader arguments. */
382    radv_declare_shader_args(device, NULL, &stage->info, stage->stage, MESA_SHADER_NONE, &stage->args);
383 
384    stage->info.user_sgprs_locs = stage->args.user_sgprs_locs;
385    stage->info.inline_push_constant_mask = stage->args.ac.inline_push_const_mask;
386 
387    /* Move the ray tracing system values that are set by rt_trace_ray to the
388     * top, so they are not overwritten by subsequent rt_trace_ray calls.
389     */
390    NIR_PASS_V(stage->nir, move_rt_instructions);
391 
392    uint32_t num_resume_shaders = 0;
393    nir_shader **resume_shaders = NULL;
394 
395    if (stage->stage != MESA_SHADER_INTERSECTION && !monolithic) {
396       nir_builder b = nir_builder_at(nir_after_impl(nir_shader_get_entrypoint(stage->nir)));
397       nir_rt_return_amd(&b);
398 
399       const nir_lower_shader_calls_options opts = {
400          .address_format = nir_address_format_32bit_offset,
401          .stack_alignment = 16,
402          .localized_loads = true,
403          .vectorizer_callback = ac_nir_mem_vectorize_callback,
404          .vectorizer_data = &(struct ac_nir_config){pdev->info.gfx_level, !radv_use_llvm_for_stage(pdev, stage->stage)},
405       };
406       nir_lower_shader_calls(stage->nir, &opts, &resume_shaders, &num_resume_shaders, stage->nir);
407    }
408 
409    unsigned num_shaders = num_resume_shaders + 1;
410    nir_shader **shaders = ralloc_array(stage->nir, nir_shader *, num_shaders);
411    if (!shaders)
412       return VK_ERROR_OUT_OF_HOST_MEMORY;
413 
414    shaders[0] = stage->nir;
415    for (uint32_t i = 0; i < num_resume_shaders; i++)
416       shaders[i + 1] = resume_shaders[i];
417 
418    if (stage_info)
419       memset(stage_info->unused_args, 0xFF, sizeof(stage_info->unused_args));
420 
421    /* Postprocess shader parts. */
422    for (uint32_t i = 0; i < num_shaders; i++) {
423       struct radv_shader_stage temp_stage = *stage;
424       temp_stage.nir = shaders[i];
425       radv_nir_lower_rt_abi(temp_stage.nir, pCreateInfo, &temp_stage.args, &stage->info, stack_size, i > 0, device,
426                             pipeline, monolithic, traversal_stage_info);
427 
428       /* Info might be out-of-date after inlining in radv_nir_lower_rt_abi(). */
429       nir_shader_gather_info(temp_stage.nir, nir_shader_get_entrypoint(temp_stage.nir));
430 
431       radv_optimize_nir(temp_stage.nir, stage->key.optimisations_disabled);
432       radv_postprocess_nir(device, NULL, &temp_stage);
433 
434       if (stage_info)
435          radv_gather_unused_args(stage_info, shaders[i]);
436    }
437 
438    bool dump_shader = radv_can_dump_shader(device, shaders[0]);
439    bool dump_nir = dump_shader && (instance->debug_flags & RADV_DEBUG_DUMP_NIR);
440    bool replayable = (pipeline->base.base.create_flags &
441                       VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) &&
442                      stage->stage != MESA_SHADER_INTERSECTION;
443 
444    if (dump_shader) {
445       simple_mtx_lock(&instance->shader_dump_mtx);
446 
447       if (dump_nir) {
448          for (uint32_t i = 0; i < num_shaders; i++)
449             nir_print_shader(shaders[i], stderr);
450       }
451    }
452 
453    char *nir_string = NULL;
454    if (keep_executable_info || dump_shader)
455       nir_string = radv_dump_nir_shaders(instance, shaders, num_shaders);
456 
457    /* Compile NIR shader to AMD assembly. */
458    binary =
459       radv_shader_nir_to_asm(device, stage, shaders, num_shaders, NULL, keep_executable_info, keep_statistic_info);
460    struct radv_shader *shader;
461    if (replay_block || replayable) {
462       VkResult result = radv_shader_create_uncached(device, binary, replayable, replay_block, &shader);
463       if (result != VK_SUCCESS) {
464          if (dump_shader)
465             simple_mtx_unlock(&instance->shader_dump_mtx);
466 
467          free(binary);
468          return result;
469       }
470    } else
471       shader = radv_shader_create(device, cache, binary, skip_shaders_cache || dump_shader);
472 
473    if (shader) {
474       shader->nir_string = nir_string;
475 
476       radv_shader_dump_debug_info(device, dump_shader, binary, shader, shaders, num_shaders, &stage->info);
477 
478       if (shader && keep_executable_info && stage->spirv.size) {
479          shader->spirv = malloc(stage->spirv.size);
480          memcpy(shader->spirv, stage->spirv.data, stage->spirv.size);
481          shader->spirv_size = stage->spirv.size;
482       }
483    }
484 
485    if (dump_shader)
486       simple_mtx_unlock(&instance->shader_dump_mtx);
487 
488    free(binary);
489 
490    *out_shader = shader;
491 
492    if (radv_can_dump_shader_stats(device, stage->nir))
493       radv_dump_shader_stats(device, &pipeline->base.base, shader, stage->nir->info.stage, stderr);
494 
495    return shader ? VK_SUCCESS : VK_ERROR_OUT_OF_HOST_MEMORY;
496 }
497 
498 static void
499 radv_update_const_info(enum radv_rt_const_arg_state *state, bool equal)
500 {
501    if (*state == RADV_RT_CONST_ARG_STATE_UNINITIALIZED)
502       *state = RADV_RT_CONST_ARG_STATE_VALID;
503    else if (*state == RADV_RT_CONST_ARG_STATE_VALID && !equal)
504       *state = RADV_RT_CONST_ARG_STATE_INVALID;
505 }
506 
507 static void
508 radv_gather_trace_ray_src(struct radv_rt_const_arg_info *info, nir_src src)
509 {
510    if (nir_src_is_const(src)) {
511       radv_update_const_info(&info->state, info->value == nir_src_as_uint(src));
512       info->value = nir_src_as_uint(src);
513    } else {
514       info->state = RADV_RT_CONST_ARG_STATE_INVALID;
515    }
516 }
517 
518 static void
519 radv_rt_const_arg_info_combine(struct radv_rt_const_arg_info *dst, const struct radv_rt_const_arg_info *src)
520 {
521    if (src->state != RADV_RT_CONST_ARG_STATE_UNINITIALIZED) {
522       radv_update_const_info(&dst->state, dst->value == src->value);
523       if (src->state == RADV_RT_CONST_ARG_STATE_INVALID)
524          dst->state = RADV_RT_CONST_ARG_STATE_INVALID;
525       dst->value = src->value;
526    }
527 }
528 
529 static struct radv_ray_tracing_stage_info
530 radv_gather_ray_tracing_stage_info(nir_shader *nir)
531 {
532    struct radv_ray_tracing_stage_info info = {
533       .can_inline = true,
534       .set_flags = 0xFFFFFFFF,
535       .unset_flags = 0xFFFFFFFF,
536    };
537 
538    nir_function_impl *impl = nir_shader_get_entrypoint(nir);
539    nir_foreach_block (block, impl) {
540       nir_foreach_instr (instr, block) {
541          if (instr->type != nir_instr_type_intrinsic)
542             continue;
543 
544          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
545          if (intr->intrinsic != nir_intrinsic_trace_ray)
546             continue;
547 
548          info.can_inline = false;
549 
550          radv_gather_trace_ray_src(&info.tmin, intr->src[7]);
551          radv_gather_trace_ray_src(&info.tmax, intr->src[9]);
552          radv_gather_trace_ray_src(&info.sbt_offset, intr->src[3]);
553          radv_gather_trace_ray_src(&info.sbt_stride, intr->src[4]);
554          radv_gather_trace_ray_src(&info.miss_index, intr->src[5]);
555 
556          nir_src flags = intr->src[1];
557          if (nir_src_is_const(flags)) {
558             info.set_flags &= nir_src_as_uint(flags);
559             info.unset_flags &= ~nir_src_as_uint(flags);
560          } else {
561             info.set_flags = 0;
562             info.unset_flags = 0;
563          }
564       }
565    }
566 
567    if (nir->info.stage == MESA_SHADER_RAYGEN || nir->info.stage == MESA_SHADER_ANY_HIT ||
568        nir->info.stage == MESA_SHADER_INTERSECTION)
569       info.can_inline = true;
570    else if (nir->info.stage == MESA_SHADER_CALLABLE)
571       info.can_inline = false;
572 
573    return info;
574 }
575 
576 static inline bool
577 radv_ray_tracing_stage_is_always_inlined(struct radv_ray_tracing_stage *stage)
578 {
579    return stage->stage == MESA_SHADER_ANY_HIT || stage->stage == MESA_SHADER_INTERSECTION;
580 }
581 
582 static VkResult
583 radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *cache,
584                         const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
585                         const VkPipelineCreationFeedbackCreateInfo *creation_feedback,
586                         const struct radv_shader_stage_key *stage_keys, struct radv_ray_tracing_pipeline *pipeline,
587                         struct radv_serialized_shader_arena_block *capture_replay_handles, bool skip_shaders_cache)
588 {
589    VK_FROM_HANDLE(radv_pipeline_layout, pipeline_layout, pCreateInfo->layout);
590 
591    if (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT)
592       return VK_PIPELINE_COMPILE_REQUIRED;
593    VkResult result = VK_SUCCESS;
594 
595    struct radv_ray_tracing_stage *rt_stages = pipeline->stages;
596 
597    struct radv_shader_stage *stages = calloc(pCreateInfo->stageCount, sizeof(struct radv_shader_stage));
598    if (!stages)
599       return VK_ERROR_OUT_OF_HOST_MEMORY;
600 
601    bool library = pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR;
602 
603    bool monolithic = !library;
604    for (uint32_t i = 0; i < pCreateInfo->stageCount; i++) {
605       if (rt_stages[i].shader || rt_stages[i].nir)
606          continue;
607 
608       int64_t stage_start = os_time_get_nano();
609 
610       struct radv_shader_stage *stage = &stages[i];
611       gl_shader_stage s = vk_to_mesa_shader_stage(pCreateInfo->pStages[i].stage);
612       radv_pipeline_stage_init(pipeline->base.base.create_flags, &pCreateInfo->pStages[i],
613                                pipeline_layout, &stage_keys[s], stage);
614 
615       /* precompile the shader */
616       stage->nir = radv_shader_spirv_to_nir(device, stage, NULL, false);
617 
618       NIR_PASS(_, stage->nir, radv_nir_lower_hit_attrib_derefs);
619 
620       rt_stages[i].info = radv_gather_ray_tracing_stage_info(stage->nir);
621 
622       stage->feedback.duration = os_time_get_nano() - stage_start;
623    }
624 
625    bool has_callable = false;
626    /* TODO: Recompile recursive raygen shaders instead. */
627    bool raygen_imported = false;
628    for (uint32_t i = 0; i < pipeline->stage_count; i++) {
629       has_callable |= rt_stages[i].stage == MESA_SHADER_CALLABLE;
630       monolithic &= rt_stages[i].info.can_inline;
631 
632       if (i >= pCreateInfo->stageCount)
633          raygen_imported |= rt_stages[i].stage == MESA_SHADER_RAYGEN;
634    }
635 
636    for (uint32_t idx = 0; idx < pCreateInfo->stageCount; idx++) {
637       if (rt_stages[idx].shader || rt_stages[idx].nir)
638          continue;
639 
640       int64_t stage_start = os_time_get_nano();
641 
642       struct radv_shader_stage *stage = &stages[idx];
643 
644       /* Cases in which we need to keep around the NIR:
645        *    - pipeline library: The final pipeline might be monolithic in which case it will need every NIR shader.
646        *                        If there is a callable shader, we can be sure that the final pipeline won't be
647        *                        monolithic.
648        *    - non-recursive:    Non-recursive shaders are inlined into the traversal shader.
649        *    - monolithic:       Callable shaders (chit/miss) are inlined into the raygen shader.
650        */
651       bool always_inlined = radv_ray_tracing_stage_is_always_inlined(&rt_stages[idx]);
652       bool nir_needed =
653          (library && !has_callable) || always_inlined || (monolithic && rt_stages[idx].stage != MESA_SHADER_RAYGEN);
654       nir_needed &= !rt_stages[idx].nir;
655       if (nir_needed) {
656          const bool cached = !stage->key.optimisations_disabled &&
657                              !(pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_CAPTURE_DATA_BIT_KHR);
658          rt_stages[idx].stack_size = stage->nir->scratch_size;
659          rt_stages[idx].nir = radv_pipeline_cache_nir_to_handle(device, cache, stage->nir, rt_stages[idx].sha1, cached);
660       }
661 
662       stage->feedback.duration += os_time_get_nano() - stage_start;
663    }
664 
665    for (uint32_t idx = 0; idx < pCreateInfo->stageCount; idx++) {
666       int64_t stage_start = os_time_get_nano();
667       struct radv_shader_stage *stage = &stages[idx];
668 
669       /* Cases in which we need to compile the shader (raygen/callable/chit/miss):
670        *    TODO: - monolithic: Extend the loop to cover imported stages and force compilation of imported raygen
671        *                        shaders since pipeline library shaders use separate compilation.
672        *    - separate:   Compile any recursive stage if it wasn't compiled yet.
673        */
674       bool shader_needed = !radv_ray_tracing_stage_is_always_inlined(&rt_stages[idx]) && !rt_stages[idx].shader;
675       if (rt_stages[idx].stage == MESA_SHADER_CLOSEST_HIT || rt_stages[idx].stage == MESA_SHADER_MISS)
676          shader_needed &= !monolithic || raygen_imported;
677 
678       if (shader_needed) {
679          uint32_t stack_size = 0;
680          struct radv_serialized_shader_arena_block *replay_block =
681             capture_replay_handles[idx].arena_va ? &capture_replay_handles[idx] : NULL;
682 
683          bool monolithic_raygen = monolithic && stage->stage == MESA_SHADER_RAYGEN;
684 
685          result =
686             radv_rt_nir_to_asm(device, cache, pCreateInfo, pipeline, monolithic_raygen, stage, &stack_size,
687                                &rt_stages[idx].info, NULL, replay_block, skip_shaders_cache, &rt_stages[idx].shader);
688          if (result != VK_SUCCESS)
689             goto cleanup;
690 
691          assert(rt_stages[idx].stack_size <= stack_size);
692          rt_stages[idx].stack_size = stack_size;
693       }
694 
695       if (creation_feedback && creation_feedback->pipelineStageCreationFeedbackCount) {
696          assert(idx < creation_feedback->pipelineStageCreationFeedbackCount);
697          stage->feedback.duration += os_time_get_nano() - stage_start;
698          creation_feedback->pPipelineStageCreationFeedbacks[idx] = stage->feedback;
699       }
700    }
701 
702    /* Monolithic raygen shaders do not need a traversal shader. Skip compiling one if there are only monolithic raygen
703     * shaders.
704     */
705    bool traversal_needed = !library && (!monolithic || raygen_imported);
706    if (!traversal_needed) {
707       result = VK_SUCCESS;
708       goto cleanup;
709    }
710 
711    struct radv_ray_tracing_stage_info traversal_info = {
712       .set_flags = 0xFFFFFFFF,
713       .unset_flags = 0xFFFFFFFF,
714    };
715 
716    memset(traversal_info.unused_args, 0xFF, sizeof(traversal_info.unused_args));
717 
718    for (uint32_t i = 0; i < pipeline->stage_count; i++) {
719       if (!pipeline->stages[i].shader)
720          continue;
721 
722       struct radv_ray_tracing_stage_info *info = &pipeline->stages[i].info;
723 
724       BITSET_AND(traversal_info.unused_args, traversal_info.unused_args, info->unused_args);
725 
726       radv_rt_const_arg_info_combine(&traversal_info.tmin, &info->tmin);
727       radv_rt_const_arg_info_combine(&traversal_info.tmax, &info->tmax);
728       radv_rt_const_arg_info_combine(&traversal_info.sbt_offset, &info->sbt_offset);
729       radv_rt_const_arg_info_combine(&traversal_info.sbt_stride, &info->sbt_stride);
730       radv_rt_const_arg_info_combine(&traversal_info.miss_index, &info->miss_index);
731 
732       traversal_info.set_flags &= info->set_flags;
733       traversal_info.unset_flags &= info->unset_flags;
734    }
735 
736    /* create traversal shader */
737    nir_shader *traversal_nir = radv_build_traversal_shader(device, pipeline, pCreateInfo, &traversal_info);
738    struct radv_shader_stage traversal_stage = {
739       .stage = MESA_SHADER_INTERSECTION,
740       .nir = traversal_nir,
741       .key = stage_keys[MESA_SHADER_INTERSECTION],
742    };
743    radv_shader_layout_init(pipeline_layout, MESA_SHADER_INTERSECTION, &traversal_stage.layout);
744    result =
745       radv_rt_nir_to_asm(device, cache, pCreateInfo, pipeline, false, &traversal_stage, NULL, NULL, &traversal_info,
746                          NULL, skip_shaders_cache, &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]);
747    ralloc_free(traversal_nir);
748 
749 cleanup:
750    for (uint32_t i = 0; i < pCreateInfo->stageCount; i++)
751       ralloc_free(stages[i].nir);
752    free(stages);
753    return result;
754 }
755 
756 static bool
757 radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
758 {
759    if (!pCreateInfo->pDynamicState)
760       return false;
761 
762    for (unsigned i = 0; i < pCreateInfo->pDynamicState->dynamicStateCount; ++i) {
763       if (pCreateInfo->pDynamicState->pDynamicStates[i] == VK_DYNAMIC_STATE_RAY_TRACING_PIPELINE_STACK_SIZE_KHR)
764          return true;
765    }
766 
767    return false;
768 }
769 
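/* Derive the pipeline's default scratch stack size from the per-stage stack
 * sizes, following the default stack size formula from the Vulkan spec, or
 * mark it as dynamic (-1u) when
 * VK_DYNAMIC_STATE_RAY_TRACING_PIPELINE_STACK_SIZE_KHR is used.
 */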
770 static void
771 compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_ray_tracing_pipeline *pipeline)
772 {
773    if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) {
774       pipeline->stack_size = -1u;
775       return;
776    }
777 
778    unsigned raygen_size = 0;
779    unsigned callable_size = 0;
780    unsigned chit_miss_size = 0;
781    unsigned intersection_size = 0;
782    unsigned any_hit_size = 0;
783 
784    for (unsigned i = 0; i < pipeline->stage_count; ++i) {
785       uint32_t size = pipeline->stages[i].stack_size;
786       switch (pipeline->stages[i].stage) {
787       case MESA_SHADER_RAYGEN:
788          raygen_size = MAX2(raygen_size, size);
789          break;
790       case MESA_SHADER_CLOSEST_HIT:
791       case MESA_SHADER_MISS:
792          chit_miss_size = MAX2(chit_miss_size, size);
793          break;
794       case MESA_SHADER_CALLABLE:
795          callable_size = MAX2(callable_size, size);
796          break;
797       case MESA_SHADER_INTERSECTION:
798          intersection_size = MAX2(intersection_size, size);
799          break;
800       case MESA_SHADER_ANY_HIT:
801          any_hit_size = MAX2(any_hit_size, size);
802          break;
803       default:
804          unreachable("Invalid stage type in RT shader");
805       }
806    }
807    pipeline->stack_size =
808       raygen_size +
809       MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) * MAX2(chit_miss_size, intersection_size + any_hit_size) +
810       MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) * chit_miss_size + 2 * callable_size;
811 }
812 
813 static void
814 combine_config(struct ac_shader_config *config, struct ac_shader_config *other)
815 {
816    config->num_sgprs = MAX2(config->num_sgprs, other->num_sgprs);
817    config->num_vgprs = MAX2(config->num_vgprs, other->num_vgprs);
818    config->num_shared_vgprs = MAX2(config->num_shared_vgprs, other->num_shared_vgprs);
819    config->spilled_sgprs = MAX2(config->spilled_sgprs, other->spilled_sgprs);
820    config->spilled_vgprs = MAX2(config->spilled_vgprs, other->spilled_vgprs);
821    config->lds_size = MAX2(config->lds_size, other->lds_size);
822    config->scratch_bytes_per_wave = MAX2(config->scratch_bytes_per_wave, other->scratch_bytes_per_wave);
823 
824    assert(config->float_mode == other->float_mode);
825 }
826 
827 static void
828 postprocess_rt_config(struct ac_shader_config *config, enum amd_gfx_level gfx_level, unsigned wave_size)
829 {
830    config->rsrc1 =
831       (config->rsrc1 & C_00B848_VGPRS) | S_00B848_VGPRS((config->num_vgprs - 1) / (wave_size == 32 ? 8 : 4));
832    if (gfx_level < GFX10)
833       config->rsrc1 = (config->rsrc1 & C_00B848_SGPRS) | S_00B848_SGPRS((config->num_sgprs - 1) / 8);
834 
835    config->rsrc2 = (config->rsrc2 & C_00B84C_LDS_SIZE) | S_00B84C_LDS_SIZE(config->lds_size);
836    config->rsrc3 = (config->rsrc3 & C_00B8A0_SHARED_VGPR_CNT) | S_00B8A0_SHARED_VGPR_CNT(config->num_shared_vgprs / 8);
837 }
838 
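/* Create the ray tracing prolog and give it a combined hardware config
 * (maximum SGPR/VGPR counts, LDS and scratch usage) covering every compiled
 * stage, so the prolog is launched with enough resources for any shader in
 * the pipeline.
 */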
839 static void
840 compile_rt_prolog(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline)
841 {
842    const struct radv_physical_device *pdev = radv_device_physical(device);
843 
844    pipeline->prolog = radv_create_rt_prolog(device);
845 
846    /* create combined config */
847    struct ac_shader_config *config = &pipeline->prolog->config;
848    for (unsigned i = 0; i < pipeline->stage_count; i++)
849       if (pipeline->stages[i].shader)
850          combine_config(config, &pipeline->stages[i].shader->config);
851 
852    if (pipeline->base.base.shaders[MESA_SHADER_INTERSECTION])
853       combine_config(config, &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]->config);
854 
855    postprocess_rt_config(config, pdev->info.gfx_level, pdev->rt_wave_size);
856 
857    pipeline->prolog->max_waves = radv_get_max_waves(device, config, &pipeline->prolog->info);
858 }
859 
860 void
861 radv_ray_tracing_pipeline_hash(const struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
862                                const struct radv_ray_tracing_state_key *rt_state, unsigned char *hash)
863 {
864    VK_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
865    struct mesa_sha1 ctx;
866 
867    _mesa_sha1_init(&ctx);
868    radv_pipeline_hash(device, layout, &ctx);
869 
870    for (uint32_t i = 0; i < pCreateInfo->stageCount; i++) {
871       _mesa_sha1_update(&ctx, rt_state->stages[i].sha1, sizeof(rt_state->stages[i].sha1));
872    }
873 
874    for (uint32_t i = 0; i < pCreateInfo->groupCount; i++) {
875       _mesa_sha1_update(&ctx, &pCreateInfo->pGroups[i].type, sizeof(pCreateInfo->pGroups[i].type));
876       _mesa_sha1_update(&ctx, &pCreateInfo->pGroups[i].generalShader, sizeof(pCreateInfo->pGroups[i].generalShader));
877       _mesa_sha1_update(&ctx, &pCreateInfo->pGroups[i].anyHitShader, sizeof(pCreateInfo->pGroups[i].anyHitShader));
878       _mesa_sha1_update(&ctx, &pCreateInfo->pGroups[i].closestHitShader,
879                         sizeof(pCreateInfo->pGroups[i].closestHitShader));
880       _mesa_sha1_update(&ctx, &pCreateInfo->pGroups[i].intersectionShader,
881                         sizeof(pCreateInfo->pGroups[i].intersectionShader));
882       _mesa_sha1_update(&ctx, &rt_state->groups[i].handle, sizeof(struct radv_pipeline_group_handle));
883    }
884 
885    if (pCreateInfo->pLibraryInfo) {
886       for (uint32_t i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
887          VK_FROM_HANDLE(radv_pipeline, lib_pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]);
888          struct radv_ray_tracing_pipeline *lib = radv_pipeline_to_ray_tracing(lib_pipeline);
889          _mesa_sha1_update(&ctx, lib->base.base.sha1, SHA1_DIGEST_LENGTH);
890       }
891    }
892 
893    const uint64_t pipeline_flags =
894       vk_rt_pipeline_create_flags(pCreateInfo) &
895       (VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR | VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_AABBS_BIT_KHR |
896        VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR |
897        VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR |
898        VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR |
899        VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR | VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR);
900    _mesa_sha1_update(&ctx, &pipeline_flags, sizeof(pipeline_flags));
901 
902    _mesa_sha1_final(&ctx, hash);
903 }
904 
905 static VkResult
906 radv_rt_pipeline_compile(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
907                          struct radv_ray_tracing_pipeline *pipeline, struct vk_pipeline_cache *cache,
908                          const struct radv_ray_tracing_state_key *rt_state,
909                          struct radv_serialized_shader_arena_block *capture_replay_blocks,
910                          const VkPipelineCreationFeedbackCreateInfo *creation_feedback)
911 {
912    bool skip_shaders_cache = radv_pipeline_skip_shaders_cache(device, &pipeline->base.base);
913    const bool emit_ray_history = !!device->rra_trace.ray_history_buffer;
914    VkPipelineCreationFeedback pipeline_feedback = {
915       .flags = VK_PIPELINE_CREATION_FEEDBACK_VALID_BIT,
916    };
917    VkResult result = VK_SUCCESS;
918 
919    int64_t pipeline_start = os_time_get_nano();
920 
921    radv_ray_tracing_pipeline_hash(device, pCreateInfo, rt_state, pipeline->base.base.sha1);
922    pipeline->base.base.pipeline_hash = *(uint64_t *)pipeline->base.base.sha1;
923 
924    /* Skip the shaders cache when any of the below are true:
925     * - ray history is enabled
926     * - group handles are saved and reused on a subsequent run (i.e. capture/replay)
927     */
928    if (emit_ray_history || (pipeline->base.base.create_flags &
929                             VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR)) {
930       skip_shaders_cache = true;
931    }
932 
933    bool found_in_application_cache = true;
934    if (!skip_shaders_cache &&
935        radv_ray_tracing_pipeline_cache_search(device, cache, pipeline, &found_in_application_cache)) {
936       if (found_in_application_cache)
937          pipeline_feedback.flags |= VK_PIPELINE_CREATION_FEEDBACK_APPLICATION_PIPELINE_CACHE_HIT_BIT;
938       result = VK_SUCCESS;
939       goto done;
940    }
941 
942    result = radv_rt_compile_shaders(device, cache, pCreateInfo, creation_feedback, rt_state->stage_keys, pipeline,
943                                     capture_replay_blocks, skip_shaders_cache);
944 
945    if (result != VK_SUCCESS)
946       return result;
947 
948    if (!skip_shaders_cache)
949       radv_ray_tracing_pipeline_cache_insert(device, cache, pipeline, pCreateInfo->stageCount);
950 
951 done:
952    pipeline_feedback.duration = os_time_get_nano() - pipeline_start;
953 
954    if (creation_feedback)
955       *creation_feedback->pPipelineCreationFeedback = pipeline_feedback;
956 
957    return result;
958 }
959 
960 void
961 radv_ray_tracing_state_key_finish(struct radv_ray_tracing_state_key *rt_state)
962 {
963    free(rt_state->stages);
964    free(rt_state->groups);
965 }
966 
967 VkResult
968 radv_generate_ray_tracing_state_key(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
969                                     struct radv_ray_tracing_state_key *rt_state)
970 {
971    VkResult result;
972 
973    memset(rt_state, 0, sizeof(*rt_state));
974 
975    /* Count the total number of stages/groups. */
976    rt_state->stage_count = pCreateInfo->stageCount;
977    rt_state->group_count = pCreateInfo->groupCount;
978 
979    if (pCreateInfo->pLibraryInfo) {
980       for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
981          VK_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]);
982          struct radv_ray_tracing_pipeline *library_pipeline = radv_pipeline_to_ray_tracing(pipeline);
983 
984          rt_state->stage_count += library_pipeline->stage_count;
985          rt_state->group_count += library_pipeline->group_count;
986       }
987    }
988 
989    rt_state->stages = calloc(rt_state->stage_count, sizeof(*rt_state->stages));
990    if (!rt_state->stages)
991       return VK_ERROR_OUT_OF_HOST_MEMORY;
992 
993    rt_state->groups = calloc(rt_state->group_count, sizeof(*rt_state->groups));
994    if (!rt_state->groups) {
995       result = VK_ERROR_OUT_OF_HOST_MEMORY;
996       goto fail;
997    }
998 
999    /* Initialize stages/stage_keys/groups info. */
1000    radv_rt_fill_stage_info(pCreateInfo, rt_state->stages);
1001 
1002    radv_generate_rt_shaders_key(device, pCreateInfo, rt_state->stage_keys);
1003 
1004    VkPipelineCreateFlags2 create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
1005    radv_init_rt_stage_hashes(device, create_flags, pCreateInfo, rt_state->stages, rt_state->stage_keys);
1006 
1007    result = radv_rt_fill_group_info(device, pCreateInfo, rt_state->stages, rt_state->groups);
1008    if (result != VK_SUCCESS)
1009       goto fail;
1010 
1011    return VK_SUCCESS;
1012 
1013 fail:
1014    radv_ray_tracing_state_key_finish(rt_state);
1015    return result;
1016 }
1017 
1018 static VkResult
1019 radv_ray_tracing_pipeline_import_binary(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1020                                         const VkPipelineBinaryInfoKHR *binary_info)
1021 {
1022    blake3_hash pipeline_hash;
1023    struct mesa_blake3 ctx;
1024 
1025    _mesa_blake3_init(&ctx);
1026 
1027    for (uint32_t i = 0; i < binary_info->binaryCount; i++) {
1028       VK_FROM_HANDLE(radv_pipeline_binary, pipeline_binary, binary_info->pPipelineBinaries[i]);
1029       struct radv_shader *shader;
1030       struct blob_reader blob;
1031 
1032       blob_reader_init(&blob, pipeline_binary->data, pipeline_binary->size);
1033 
1034       const struct radv_ray_tracing_binary_header *header =
1035          (const struct radv_ray_tracing_binary_header *)blob_read_bytes(&blob, sizeof(*header));
1036 
1037       if (header->is_traversal_shader) {
1038          shader = radv_shader_deserialize(device, pipeline_binary->key, sizeof(pipeline_binary->key), &blob);
1039          if (!shader)
1040             return VK_ERROR_OUT_OF_DEVICE_MEMORY;
1041 
1042          pipeline->base.base.shaders[MESA_SHADER_INTERSECTION] = shader;
1043 
1044          _mesa_blake3_update(&ctx, pipeline_binary->key, sizeof(pipeline_binary->key));
1045          continue;
1046       }
1047 
1048       memcpy(&pipeline->stages[i].info, &header->stage_info, sizeof(pipeline->stages[i].info));
1049       pipeline->stages[i].stack_size = header->stack_size;
1050 
1051       if (header->has_shader) {
1052          shader = radv_shader_deserialize(device, pipeline_binary->key, sizeof(pipeline_binary->key), &blob);
1053          if (!shader)
1054             return VK_ERROR_OUT_OF_DEVICE_MEMORY;
1055 
1056          pipeline->stages[i].shader = shader;
1057 
1058          _mesa_blake3_update(&ctx, pipeline_binary->key, sizeof(pipeline_binary->key));
1059       }
1060 
1061       if (header->has_nir) {
1062          nir_shader *nir = nir_deserialize(NULL, NULL, &blob);
1063 
1064          pipeline->stages[i].nir = radv_pipeline_cache_nir_to_handle(device, NULL, nir, header->stage_sha1, false);
1065          ralloc_free(nir);
1066 
1067          if (!pipeline->stages[i].nir)
1068             return VK_ERROR_OUT_OF_HOST_MEMORY;
1069       }
1070    }
1071 
1072    _mesa_blake3_final(&ctx, pipeline_hash);
1073 
1074    pipeline->base.base.pipeline_hash = *(uint64_t *)pipeline_hash;
1075 
1076    return VK_SUCCESS;
1077 }
1078 
1079 static VkResult
1080 radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1081                         const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)
1082 {
1083    VK_FROM_HANDLE(radv_device, device, _device);
1084    VK_FROM_HANDLE(vk_pipeline_cache, cache, _cache);
1085    VK_FROM_HANDLE(radv_pipeline_layout, pipeline_layout, pCreateInfo->layout);
1086    struct radv_ray_tracing_state_key rt_state;
1087    VkResult result;
1088    const VkPipelineCreationFeedbackCreateInfo *creation_feedback =
1089       vk_find_struct_const(pCreateInfo->pNext, PIPELINE_CREATION_FEEDBACK_CREATE_INFO);
1090 
1091    result = radv_generate_ray_tracing_state_key(device, pCreateInfo, &rt_state);
1092    if (result != VK_SUCCESS)
1093       return result;
1094 
1095    VK_MULTIALLOC(ma);
1096    VK_MULTIALLOC_DECL(&ma, struct radv_ray_tracing_pipeline, pipeline, 1);
1097    VK_MULTIALLOC_DECL(&ma, struct radv_ray_tracing_stage, stages, rt_state.stage_count);
1098    VK_MULTIALLOC_DECL(&ma, struct radv_ray_tracing_group, groups, rt_state.group_count);
1099    VK_MULTIALLOC_DECL(&ma, struct radv_serialized_shader_arena_block, capture_replay_blocks, pCreateInfo->stageCount);
1100    if (!vk_multialloc_zalloc2(&ma, &device->vk.alloc, pAllocator, VK_SYSTEM_ALLOCATION_SCOPE_OBJECT)) {
1101       radv_ray_tracing_state_key_finish(&rt_state);
1102       return VK_ERROR_OUT_OF_HOST_MEMORY;
1103    }
1104 
1105    radv_pipeline_init(device, &pipeline->base.base, RADV_PIPELINE_RAY_TRACING);
1106    pipeline->base.base.create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
1107    pipeline->stage_count = rt_state.stage_count;
1108    pipeline->non_imported_stage_count = pCreateInfo->stageCount;
1109    pipeline->group_count = rt_state.group_count;
1110    pipeline->stages = stages;
1111    pipeline->groups = groups;
1112 
1113    memcpy(pipeline->stages, rt_state.stages, rt_state.stage_count * sizeof(struct radv_ray_tracing_stage));
1114    memcpy(pipeline->groups, rt_state.groups, rt_state.group_count * sizeof(struct radv_ray_tracing_group));
1115 
1116    /* cache robustness state for making merged shaders */
1117    if (rt_state.stage_keys[MESA_SHADER_INTERSECTION].storage_robustness2)
1118       pipeline->traversal_storage_robustness2 = true;
1119 
1120    if (rt_state.stage_keys[MESA_SHADER_INTERSECTION].uniform_robustness2)
1121       pipeline->traversal_uniform_robustness2 = true;
1122 
1123    result = radv_rt_init_capture_replay(device, pCreateInfo, stages, pipeline->groups, capture_replay_blocks);
1124    if (result != VK_SUCCESS)
1125       goto fail;
1126 
1127    const VkPipelineBinaryInfoKHR *binary_info = vk_find_struct_const(pCreateInfo->pNext, PIPELINE_BINARY_INFO_KHR);
1128 
1129    if (binary_info && binary_info->binaryCount > 0) {
1130       result = radv_ray_tracing_pipeline_import_binary(device, pipeline, binary_info);
1131    } else {
1132       result = radv_rt_pipeline_compile(device, pCreateInfo, pipeline, cache, &rt_state, capture_replay_blocks,
1133                                         creation_feedback);
1134       if (result != VK_SUCCESS)
1135          goto fail;
1136    }
1137 
1138    if (!(pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR)) {
1139       compute_rt_stack_size(pCreateInfo, pipeline);
1140       compile_rt_prolog(device, pipeline);
1141 
1142       radv_compute_pipeline_init(&pipeline->base, pipeline_layout, pipeline->prolog);
1143    }
1144 
1145    /* write shader VAs into group handles */
1146    for (unsigned i = 0; i < pipeline->group_count; i++) {
1147       if (pipeline->groups[i].recursive_shader != VK_SHADER_UNUSED_KHR) {
1148          struct radv_shader *shader = pipeline->stages[pipeline->groups[i].recursive_shader].shader;
1149          if (shader)
1150             pipeline->groups[i].handle.recursive_shader_ptr = shader->va | radv_get_rt_priority(shader->info.stage);
1151       }
1152    }
1153 
1154    *pPipeline = radv_pipeline_to_handle(&pipeline->base.base);
1155    radv_rmv_log_rt_pipeline_create(device, pipeline);
1156 
1157    radv_ray_tracing_state_key_finish(&rt_state);
1158    return result;
1159 
1160 fail:
1161    radv_ray_tracing_state_key_finish(&rt_state);
1162    radv_pipeline_destroy(device, &pipeline->base.base, pAllocator);
1163    return result;
1164 }
1165 
1166 void
1167 radv_destroy_ray_tracing_pipeline(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline)
1168 {
1169    for (unsigned i = 0; i < pipeline->stage_count; i++) {
1170       if (pipeline->stages[i].nir)
1171          vk_pipeline_cache_object_unref(&device->vk, pipeline->stages[i].nir);
1172       if (pipeline->stages[i].shader)
1173          radv_shader_unref(device, pipeline->stages[i].shader);
1174    }
1175 
1176    if (pipeline->prolog)
1177       radv_shader_unref(device, pipeline->prolog);
1178    if (pipeline->base.base.shaders[MESA_SHADER_INTERSECTION])
1179       radv_shader_unref(device, pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]);
1180 }
1181 
1182 VKAPI_ATTR VkResult VKAPI_CALL
1183 radv_CreateRayTracingPipelinesKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
1184                                   VkPipelineCache pipelineCache, uint32_t count,
1185                                   const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
1186                                   const VkAllocationCallbacks *pAllocator, VkPipeline *pPipelines)
1187 {
1188    VkResult result = VK_SUCCESS;
1189 
1190    unsigned i = 0;
1191    for (; i < count; i++) {
1192       VkResult r;
1193       r = radv_rt_pipeline_create(_device, pipelineCache, &pCreateInfos[i], pAllocator, &pPipelines[i]);
1194       if (r != VK_SUCCESS) {
1195          result = r;
1196          pPipelines[i] = VK_NULL_HANDLE;
1197 
1198          const VkPipelineCreateFlagBits2 create_flags = vk_rt_pipeline_create_flags(&pCreateInfos[i]);
1199          if (create_flags & VK_PIPELINE_CREATE_2_EARLY_RETURN_ON_FAILURE_BIT)
1200             break;
1201       }
1202    }
1203 
1204    for (; i < count; ++i)
1205       pPipelines[i] = VK_NULL_HANDLE;
1206 
1207    if (result != VK_SUCCESS)
1208       return result;
1209 
1210    /* Work around Portal RTX not handling VK_OPERATION_NOT_DEFERRED_KHR correctly. */
1211    if (deferredOperation != VK_NULL_HANDLE)
1212       return VK_OPERATION_DEFERRED_KHR;
1213 
1214    return result;
1215 }
1216 
1217 VKAPI_ATTR VkResult VKAPI_CALL
1218 radv_GetRayTracingShaderGroupHandlesKHR(VkDevice device, VkPipeline _pipeline, uint32_t firstGroup, uint32_t groupCount,
1219                                         size_t dataSize, void *pData)
1220 {
1221    VK_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
1222    struct radv_ray_tracing_group *groups = radv_pipeline_to_ray_tracing(pipeline)->groups;
1223    char *data = pData;
1224 
1225    STATIC_ASSERT(sizeof(struct radv_pipeline_group_handle) <= RADV_RT_HANDLE_SIZE);
1226 
1227    memset(data, 0, groupCount * RADV_RT_HANDLE_SIZE);
1228 
1229    for (uint32_t i = 0; i < groupCount; ++i) {
1230       memcpy(data + i * RADV_RT_HANDLE_SIZE, &groups[firstGroup + i].handle, sizeof(struct radv_pipeline_group_handle));
1231    }
1232 
1233    return VK_SUCCESS;
1234 }
1235 
1236 VKAPI_ATTR VkDeviceSize VKAPI_CALL
1237 radv_GetRayTracingShaderGroupStackSizeKHR(VkDevice device, VkPipeline _pipeline, uint32_t group,
1238                                           VkShaderGroupShaderKHR groupShader)
1239 {
1240    VK_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
1241    struct radv_ray_tracing_pipeline *rt_pipeline = radv_pipeline_to_ray_tracing(pipeline);
1242    struct radv_ray_tracing_group *rt_group = &rt_pipeline->groups[group];
1243    switch (groupShader) {
1244    case VK_SHADER_GROUP_SHADER_GENERAL_KHR:
1245    case VK_SHADER_GROUP_SHADER_CLOSEST_HIT_KHR:
1246       return rt_pipeline->stages[rt_group->recursive_shader].stack_size;
1247    case VK_SHADER_GROUP_SHADER_ANY_HIT_KHR:
1248       return rt_pipeline->stages[rt_group->any_hit_shader].stack_size;
1249    case VK_SHADER_GROUP_SHADER_INTERSECTION_KHR:
1250       return rt_pipeline->stages[rt_group->intersection_shader].stack_size;
1251    default:
1252       return 0;
1253    }
1254 }
1255 
1256 VKAPI_ATTR VkResult VKAPI_CALL
1257 radv_GetRayTracingCaptureReplayShaderGroupHandlesKHR(VkDevice device, VkPipeline _pipeline, uint32_t firstGroup,
1258                                                      uint32_t groupCount, size_t dataSize, void *pData)
1259 {
1260    VK_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
1261    struct radv_ray_tracing_pipeline *rt_pipeline = radv_pipeline_to_ray_tracing(pipeline);
1262    struct radv_rt_capture_replay_handle *data = pData;
1263 
1264    memset(data, 0, groupCount * sizeof(struct radv_rt_capture_replay_handle));
1265 
1266    for (uint32_t i = 0; i < groupCount; ++i) {
1267       uint32_t recursive_shader = rt_pipeline->groups[firstGroup + i].recursive_shader;
1268       if (recursive_shader != VK_SHADER_UNUSED_KHR) {
1269          struct radv_shader *shader = rt_pipeline->stages[recursive_shader].shader;
1270          if (shader) {
1271             data[i].recursive_shader_alloc.offset = shader->alloc->offset;
1272             data[i].recursive_shader_alloc.size = shader->alloc->size;
1273             data[i].recursive_shader_alloc.arena_va = shader->alloc->arena->bo->va;
1274             data[i].recursive_shader_alloc.arena_size = shader->alloc->arena->size;
1275          }
1276       }
1277       data[i].non_recursive_idx = rt_pipeline->groups[firstGroup + i].handle.any_hit_index;
1278    }
1279 
1280    return VK_SUCCESS;
1281 }
1282