/*
 * Copyright © 2021 Bas Nieuwenhuizen
 * Copyright © 2023 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vk_acceleration_structure.h"

#include "vk_alloc.h"
#include "vk_common_entrypoints.h"
#include "vk_device.h"
#include "vk_command_buffer.h"
#include "vk_log.h"
#include "vk_meta.h"

#include "bvh/vk_build_interface.h"
#include "bvh/vk_bvh.h"

#include "radix_sort/common/vk/barrier.h"
#include "radix_sort/shaders/push.h"

#include "util/u_string.h"

static const uint32_t leaf_spv[] = {
#include "bvh/leaf.spv.h"
};

static const uint32_t leaf_always_active_spv[] = {
#include "bvh/leaf_always_active.spv.h"
};

static const uint32_t morton_spv[] = {
#include "bvh/morton.spv.h"
};

static const uint32_t lbvh_main_spv[] = {
#include "bvh/lbvh_main.spv.h"
};

static const uint32_t lbvh_generate_ir_spv[] = {
#include "bvh/lbvh_generate_ir.spv.h"
};

static const uint32_t ploc_spv[] = {
#include "bvh/ploc_internal.spv.h"
};

VkDeviceAddress
vk_acceleration_structure_get_va(struct vk_acceleration_structure *accel_struct)
{
   VkBufferDeviceAddressInfo info = {
      .sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,
      .buffer = accel_struct->buffer,
   };

   VkDeviceAddress base_addr = accel_struct->base.device->dispatch_table.GetBufferDeviceAddress(
      vk_device_to_handle(accel_struct->base.device), &info);

   return base_addr + accel_struct->offset;
}
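
/* Illustrative usage (a sketch, not part of the original file): a driver
 * that needs the destination address of a build would typically do
 *
 *    VK_FROM_HANDLE(vk_acceleration_structure, accel_struct,
 *                   pInfos[i].dstAccelerationStructure);
 *    VkDeviceAddress va = vk_acceleration_structure_get_va(accel_struct);
 *
 * i.e. the structure's address is its buffer's base device address plus its
 * offset within that buffer.
 */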

VKAPI_ATTR VkResult VKAPI_CALL
vk_common_CreateAccelerationStructureKHR(VkDevice _device,
                                         const VkAccelerationStructureCreateInfoKHR *pCreateInfo,
                                         const VkAllocationCallbacks *pAllocator,
                                         VkAccelerationStructureKHR *pAccelerationStructure)
{
   VK_FROM_HANDLE(vk_device, device, _device);

   struct vk_acceleration_structure *accel_struct = vk_object_alloc(
      device, pAllocator, sizeof(struct vk_acceleration_structure),
      VK_OBJECT_TYPE_ACCELERATION_STRUCTURE_KHR);

   if (!accel_struct)
      return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

   accel_struct->buffer = pCreateInfo->buffer;
   accel_struct->offset = pCreateInfo->offset;
   accel_struct->size = pCreateInfo->size;

   if (pCreateInfo->deviceAddress &&
       vk_acceleration_structure_get_va(accel_struct) != pCreateInfo->deviceAddress) {
      /* Don't leak the object on the capture-replay failure path. */
      vk_object_free(device, pAllocator, accel_struct);
      return vk_error(device, VK_ERROR_INVALID_OPAQUE_CAPTURE_ADDRESS);
   }

   *pAccelerationStructure = vk_acceleration_structure_to_handle(accel_struct);
   return VK_SUCCESS;
}

VKAPI_ATTR void VKAPI_CALL
vk_common_DestroyAccelerationStructureKHR(VkDevice _device,
                                          VkAccelerationStructureKHR accelerationStructure,
                                          const VkAllocationCallbacks *pAllocator)
{
   VK_FROM_HANDLE(vk_device, device, _device);
   VK_FROM_HANDLE(vk_acceleration_structure, accel_struct, accelerationStructure);

   if (!accel_struct)
      return;

   vk_object_free(device, pAllocator, accel_struct);
}

VKAPI_ATTR VkDeviceAddress VKAPI_CALL
vk_common_GetAccelerationStructureDeviceAddressKHR(
   VkDevice _device, const VkAccelerationStructureDeviceAddressInfoKHR *pInfo)
{
   VK_FROM_HANDLE(vk_acceleration_structure, accel_struct, pInfo->accelerationStructure);
   return vk_acceleration_structure_get_va(accel_struct);
}

/* Size in bytes of one sort entry: a 32-bit Morton code paired with a
 * 32-bit leaf ID.
 */
#define KEY_ID_PAIR_SIZE 8
/* Number of Morton-code bits that participate in the radix sort. */
#define MORTON_BIT_SIZE  24

enum internal_build_type {
   INTERNAL_BUILD_TYPE_LBVH,
   INTERNAL_BUILD_TYPE_PLOC,
   INTERNAL_BUILD_TYPE_UPDATE,
};

struct build_config {
   enum internal_build_type internal_type;
   bool updateable;
   uint32_t encode_key[MAX_ENCODE_PASSES];
};

struct scratch_layout {
   uint32_t size;
   uint32_t update_size;

   uint32_t header_offset;

   /* Used for BUILD only. */

   uint32_t sort_buffer_offset[2];
   uint32_t sort_internal_offset;

   uint32_t ploc_prefix_sum_partition_offset;
   uint32_t lbvh_node_offset;

   uint32_t ir_offset;
   uint32_t internal_node_offset;
};
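
/* Rough map of the build-scratch region laid out by get_scratch_layout()
 * below (a sketch for orientation; exact sizes come from the radix-sort
 * requirements and the IR node structs):
 *
 *   header_offset                          struct vk_ir_header
 *   sort_buffer_offset[0]                  keyval ping buffer
 *   sort_buffer_offset[1]                  keyval pong buffer
 *   sort_internal_offset                   radix-sort histograms/partitions,
 *     = ploc_prefix_sum_partition_offset     aliased with the PLOC partitions
 *     = lbvh_node_offset                     and LBVH node info, since sorting
 *                                            and PLOC/LBVH never run at the
 *                                            same time
 *   ir_offset                              leaf IR nodes
 *   internal_node_offset                   internal vk_ir_box_node array
 */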

static struct build_config
build_config(uint32_t leaf_count,
             const VkAccelerationStructureBuildGeometryInfoKHR *build_info,
             const struct vk_acceleration_structure_build_ops *ops)
{
   struct build_config config = {0};

   if (leaf_count <= 4)
      config.internal_type = INTERNAL_BUILD_TYPE_LBVH;
   else if (build_info->type == VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR)
      config.internal_type = INTERNAL_BUILD_TYPE_PLOC;
   else if (!(build_info->flags & VK_BUILD_ACCELERATION_STRUCTURE_PREFER_FAST_BUILD_BIT_KHR) &&
            !(build_info->flags & VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_UPDATE_BIT_KHR))
      config.internal_type = INTERNAL_BUILD_TYPE_PLOC;
   else
      config.internal_type = INTERNAL_BUILD_TYPE_LBVH;

   if (build_info->mode == VK_BUILD_ACCELERATION_STRUCTURE_MODE_UPDATE_KHR &&
       build_info->type == VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR &&
       ops->update_as[0])
      config.internal_type = INTERNAL_BUILD_TYPE_UPDATE;

   if ((build_info->flags & VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_UPDATE_BIT_KHR) &&
       build_info->type == VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR &&
       ops->update_as[0])
      config.updateable = true;

   for (unsigned i = 0; i < ARRAY_SIZE(config.encode_key); i++) {
      if (!ops->get_encode_key[i])
         break;
      config.encode_key[i] = ops->get_encode_key[i](leaf_count, build_info->flags);
   }

   return config;
}
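
/* Worked examples of the selection above (illustrative): a TLAS with more
 * than 4 instances always takes the PLOC path; a BLAS built with
 * PREFER_FAST_BUILD or ALLOW_UPDATE takes the cheaper LBVH path; and a BLAS
 * with mode == UPDATE uses the driver's dedicated update (refit) pass when
 * one is implemented.
 */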

static void
get_scratch_layout(struct vk_device *device,
                   uint32_t leaf_count,
                   const VkAccelerationStructureBuildGeometryInfoKHR *build_info,
                   const struct vk_acceleration_structure_build_args *args,
                   struct scratch_layout *scratch)
{
   uint32_t internal_count = MAX2(leaf_count, 2) - 1;

   radix_sort_vk_memory_requirements_t requirements = {
      0,
   };
   radix_sort_vk_get_memory_requirements(args->radix_sort, leaf_count,
                                         &requirements);

   uint32_t ir_leaf_size;
   switch (vk_get_as_geometry_type(build_info)) {
   case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
      ir_leaf_size = sizeof(struct vk_ir_triangle_node);
      break;
   case VK_GEOMETRY_TYPE_AABBS_KHR:
      ir_leaf_size = sizeof(struct vk_ir_aabb_node);
      break;
   case VK_GEOMETRY_TYPE_INSTANCES_KHR:
      ir_leaf_size = sizeof(struct vk_ir_instance_node);
      break;
   default:
      unreachable("Unknown VkGeometryTypeKHR");
   }

   uint32_t offset = 0;

   uint32_t ploc_scratch_space = 0;
   uint32_t lbvh_node_space = 0;

   struct build_config config = build_config(leaf_count, build_info,
                                             device->as_build_ops);

   if (config.internal_type == INTERNAL_BUILD_TYPE_PLOC)
      ploc_scratch_space = DIV_ROUND_UP(leaf_count, PLOC_WORKGROUP_SIZE) * sizeof(struct ploc_prefix_scan_partition);
   else
      lbvh_node_space = sizeof(struct lbvh_node_info) * internal_count;

   scratch->header_offset = offset;
   offset += sizeof(struct vk_ir_header);

   scratch->sort_buffer_offset[0] = offset;
   offset += requirements.keyvals_size;

   scratch->sort_buffer_offset[1] = offset;
   offset += requirements.keyvals_size;

   scratch->sort_internal_offset = offset;
   /* Internal sorting data is not needed when PLOC/LBVH are invoked,
    * save space by aliasing them */
   scratch->ploc_prefix_sum_partition_offset = offset;
   scratch->lbvh_node_offset = offset;
   offset += MAX3(requirements.internal_size, ploc_scratch_space, lbvh_node_space);

   scratch->ir_offset = offset;
   offset += ir_leaf_size * leaf_count;

   scratch->internal_node_offset = offset;
   offset += sizeof(struct vk_ir_box_node) * internal_count;

   scratch->size = offset;

   if (build_info->type == VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR &&
       device->as_build_ops->update_as[0]) {
      scratch->update_size =
         device->as_build_ops->get_update_scratch_size(device, leaf_count);
   } else {
      scratch->update_size = offset;
   }
}

struct bvh_state {
   uint32_t scratch_offset;

   uint32_t leaf_node_count;
   uint32_t internal_node_count;
   uint32_t leaf_node_size;

   struct scratch_layout scratch;
   struct build_config config;

   /* Radix sort state */
   uint32_t scatter_blocks;
   uint32_t count_ru_scatter;
   uint32_t histo_blocks;
   uint32_t count_ru_histo;
   struct rs_push_scatter push_scatter;

   uint32_t last_encode_pass;
};

struct bvh_batch_state {
   bool any_updateable;
   bool any_non_updateable;
   bool any_ploc;
   bool any_lbvh;
   bool any_update;
};

static VkResult
get_pipeline_spv(struct vk_device *device, struct vk_meta_device *meta,
                 const char *name, const uint32_t *spv, uint32_t spv_size,
                 unsigned push_constant_size,
                 const struct vk_acceleration_structure_build_args *args,
                 VkPipeline *pipeline, VkPipelineLayout *layout)
{
   size_t key_size = strlen(name);

   VkResult result = vk_meta_get_pipeline_layout(
         device, meta, NULL,
         &(VkPushConstantRange){
            VK_SHADER_STAGE_COMPUTE_BIT, 0, push_constant_size
         },
         name, key_size, layout);

   if (result != VK_SUCCESS)
      return result;

   VkPipeline pipeline_from_cache = vk_meta_lookup_pipeline(meta, name, key_size);
   if (pipeline_from_cache != VK_NULL_HANDLE) {
      *pipeline = pipeline_from_cache;
      return VK_SUCCESS;
   }

   VkShaderModuleCreateInfo module_info = {
      .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
      .pNext = NULL,
      .flags = 0,
      .codeSize = spv_size,
      .pCode = spv,
   };

   VkSpecializationMapEntry spec_map[2] = {
      {
         .constantID = SUBGROUP_SIZE_ID,
         .offset = 0,
         .size = sizeof(args->subgroup_size),
      },
      {
         .constantID = BVH_BOUNDS_OFFSET_ID,
         .offset = sizeof(args->subgroup_size),
         .size = sizeof(args->bvh_bounds_offset),
      },
   };

   uint32_t spec_constants[2] = {
      args->subgroup_size,
      args->bvh_bounds_offset
   };

   VkSpecializationInfo spec_info = {
      .mapEntryCount = ARRAY_SIZE(spec_map),
      .pMapEntries = spec_map,
      .dataSize = sizeof(spec_constants),
      .pData = spec_constants,
   };

   VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT rssci = {
      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT,
      .pNext = &module_info,
      .requiredSubgroupSize = args->subgroup_size,
   };

   VkPipelineShaderStageCreateInfo shader_stage = {
      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
      .pNext = &rssci,
      .flags = VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT,
      .stage = VK_SHADER_STAGE_COMPUTE_BIT,
      .pName = "main",
      .pSpecializationInfo = &spec_info,
   };

   VkComputePipelineCreateInfo pipeline_info = {
      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
      .stage = shader_stage,
      .flags = 0,
      .layout = *layout,
   };

   return vk_meta_create_compute_pipeline(device, meta, &pipeline_info,
                                          name, key_size, pipeline);
}

static uint32_t
pack_geometry_id_and_flags(uint32_t geometry_id, uint32_t flags)
{
   uint32_t geometry_id_and_flags = geometry_id;
   if (flags & VK_GEOMETRY_OPAQUE_BIT_KHR)
      geometry_id_and_flags |= VK_GEOMETRY_OPAQUE;

   return geometry_id_and_flags;
}

struct vk_bvh_geometry_data
vk_fill_geometry_data(VkAccelerationStructureTypeKHR type, uint32_t first_id, uint32_t geom_index,
                      const VkAccelerationStructureGeometryKHR *geometry,
                      const VkAccelerationStructureBuildRangeInfoKHR *build_range_info)
{
   struct vk_bvh_geometry_data data = {
      .first_id = first_id,
      .geometry_id = pack_geometry_id_and_flags(geom_index, geometry->flags),
      .geometry_type = geometry->geometryType,
   };

   switch (geometry->geometryType) {
   case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
      assert(type == VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);

      data.data = geometry->geometry.triangles.vertexData.deviceAddress +
                  build_range_info->firstVertex * geometry->geometry.triangles.vertexStride;
      data.indices = geometry->geometry.triangles.indexData.deviceAddress;

      if (geometry->geometry.triangles.indexType == VK_INDEX_TYPE_NONE_KHR)
         data.data += build_range_info->primitiveOffset;
      else
         data.indices += build_range_info->primitiveOffset;

      data.transform = geometry->geometry.triangles.transformData.deviceAddress;
      if (data.transform)
         data.transform += build_range_info->transformOffset;

      data.stride = geometry->geometry.triangles.vertexStride;
      data.vertex_format = geometry->geometry.triangles.vertexFormat;
      data.index_format = geometry->geometry.triangles.indexType;
      break;
   case VK_GEOMETRY_TYPE_AABBS_KHR:
      assert(type == VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);

      data.data = geometry->geometry.aabbs.data.deviceAddress + build_range_info->primitiveOffset;
      data.stride = geometry->geometry.aabbs.stride;
      break;
   case VK_GEOMETRY_TYPE_INSTANCES_KHR:
      assert(type == VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR);

      data.data = geometry->geometry.instances.data.deviceAddress + build_range_info->primitiveOffset;

      if (geometry->geometry.instances.arrayOfPointers)
         data.stride = 8;
      else
         data.stride = sizeof(VkAccelerationStructureInstanceKHR);
      break;
   default:
      unreachable("Unknown geometryType");
   }

   return data;
}
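
/* Note on the offset rules above (they follow
 * VkAccelerationStructureBuildRangeInfoKHR): for indexed triangles,
 * primitiveOffset is applied to the index buffer and firstVertex scales into
 * the vertex buffer; for non-indexed triangles, primitiveOffset is applied to
 * the vertex data itself; AABB and instance geometry add primitiveOffset to
 * their single data pointer.
 */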

void
vk_accel_struct_cmd_begin_debug_marker(VkCommandBuffer commandBuffer,
                                       enum vk_acceleration_structure_build_step step,
                                       const char *format, ...)
{
   VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, commandBuffer);
   struct vk_device *device = cmd_buffer->base.device;

   va_list ap;
   va_start(ap, format);

   char *name;
   if (vasprintf(&name, format, ap) == -1) {
      va_end(ap);
      return;
   }

   va_end(ap);

   VkDebugMarkerMarkerInfoEXT marker = {
      .sType = VK_STRUCTURE_TYPE_DEBUG_MARKER_MARKER_INFO_EXT,
      .pMarkerName = name,
   };

   device->dispatch_table.CmdDebugMarkerBeginEXT(commandBuffer, &marker);

   /* The marker name only needs to live for the duration of the call, so the
    * vasprintf allocation can be released immediately.
    */
   free(name);
}

void
vk_accel_struct_cmd_end_debug_marker(VkCommandBuffer commandBuffer)
{
   VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, commandBuffer);
   struct vk_device *device = cmd_buffer->base.device;

   device->dispatch_table.CmdDebugMarkerEndEXT(commandBuffer);
}

static VkResult
build_leaves(VkCommandBuffer commandBuffer,
             struct vk_device *device, struct vk_meta_device *meta,
             const struct vk_acceleration_structure_build_args *args,
             uint32_t infoCount,
             const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
             const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos,
             struct bvh_state *bvh_states,
             bool updateable)
{
   VkPipeline pipeline;
   VkPipelineLayout layout;

   /* Many apps are broken and will make inactive primitives active when
    * updating, even though this is disallowed by the spec.  To handle this,
    * we use a different variant for updateable acceleration structures when
    * the driver implements an update pass. This passes through inactive leaf
    * nodes as if they were active, with an empty bounding box. It's then the
    * driver or HW's responsibility to filter out inactive nodes.
    */
   VkResult result;
   if (updateable) {
      result = get_pipeline_spv(device, meta, "leaves_always_active",
                                leaf_always_active_spv,
                                sizeof(leaf_always_active_spv),
                                sizeof(struct leaf_args), args, &pipeline, &layout);
   } else {
      result = get_pipeline_spv(device, meta, "leaves", leaf_spv, sizeof(leaf_spv),
                                sizeof(struct leaf_args), args, &pipeline, &layout);
   }

   if (result != VK_SUCCESS)
      return result;

   if (args->emit_markers) {
      device->as_build_ops->begin_debug_marker(commandBuffer,
                                               VK_ACCELERATION_STRUCTURE_BUILD_STEP_BUILD_LEAVES,
                                               "build_leaves");
   }

   const struct vk_device_dispatch_table *disp = &device->dispatch_table;
   disp->CmdBindPipeline(
      commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);

   for (uint32_t i = 0; i < infoCount; ++i) {
      if (bvh_states[i].config.internal_type == INTERNAL_BUILD_TYPE_UPDATE)
         continue;
      if (bvh_states[i].config.updateable != updateable)
         continue;

      struct leaf_args leaf_consts = {
         .bvh = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.ir_offset,
         .header = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.header_offset,
         .ids = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0],
      };

      for (unsigned j = 0; j < pInfos[i].geometryCount; ++j) {
         const VkAccelerationStructureGeometryKHR *geom =
            pInfos[i].pGeometries ? &pInfos[i].pGeometries[j] : pInfos[i].ppGeometries[j];

         const VkAccelerationStructureBuildRangeInfoKHR *build_range_info = &ppBuildRangeInfos[i][j];

         if (build_range_info->primitiveCount == 0)
            continue;

         leaf_consts.geom_data = vk_fill_geometry_data(pInfos[i].type, bvh_states[i].leaf_node_count, j, geom, build_range_info);

         disp->CmdPushConstants(commandBuffer, layout,
                                VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(leaf_consts), &leaf_consts);
         device->cmd_dispatch_unaligned(commandBuffer, build_range_info->primitiveCount, 1, 1);

         bvh_states[i].leaf_node_count += build_range_info->primitiveCount;
      }
   }

   if (args->emit_markers)
      device->as_build_ops->end_debug_marker(commandBuffer);

   return VK_SUCCESS;
}
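
/* After this pass each recorded leaf has a globally increasing first_id, and
 * bvh_states[i].leaf_node_count holds the total primitive count across all
 * geometries, so the later passes can treat the leaves as one flat array.
 * (Summary comment, not original.)
 */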

static VkResult
morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device,
                struct vk_meta_device *meta,
                const struct vk_acceleration_structure_build_args *args,
                uint32_t infoCount,
                const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
                struct bvh_state *bvh_states)
{
   VkPipeline pipeline;
   VkPipelineLayout layout;

   VkResult result =
      get_pipeline_spv(device, meta, "morton", morton_spv, sizeof(morton_spv),
                       sizeof(struct morton_args), args, &pipeline, &layout);

   if (result != VK_SUCCESS)
      return result;

   if (args->emit_markers) {
      device->as_build_ops->begin_debug_marker(commandBuffer,
                                               VK_ACCELERATION_STRUCTURE_BUILD_STEP_MORTON_GENERATE,
                                               "morton_generate");
   }

   const struct vk_device_dispatch_table *disp = &device->dispatch_table;
   disp->CmdBindPipeline(
      commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);

   for (uint32_t i = 0; i < infoCount; ++i) {
      if (bvh_states[i].config.internal_type == INTERNAL_BUILD_TYPE_UPDATE)
         continue;
      const struct morton_args consts = {
         .bvh = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.ir_offset,
         .header = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.header_offset,
         .ids = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0],
      };

      disp->CmdPushConstants(commandBuffer, layout,
                             VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
      device->cmd_dispatch_unaligned(commandBuffer, bvh_states[i].leaf_node_count, 1, 1);
   }

   if (args->emit_markers)
      device->as_build_ops->end_debug_marker(commandBuffer);

   return VK_SUCCESS;
}

static void
morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
            const struct vk_acceleration_structure_build_args *args,
            uint32_t infoCount,
            const VkAccelerationStructureBuildGeometryInfoKHR *pInfos, struct bvh_state *bvh_states)
{
   const struct vk_device_dispatch_table *disp = &device->dispatch_table;

   if (args->emit_markers) {
      device->as_build_ops->begin_debug_marker(commandBuffer,
                                               VK_ACCELERATION_STRUCTURE_BUILD_STEP_MORTON_SORT,
                                               "morton_sort");
   }

   /* Copyright 2019 The Fuchsia Authors. */
   const radix_sort_vk_t *rs = args->radix_sort;

   /*
    * OVERVIEW
    *
    *   1. Pad the keyvals in `scatter_even`.
    *   2. Zero the `histograms` and `partitions`.
    *      --- BARRIER ---
    *   3. HISTOGRAM is dispatched before PREFIX.
    *      --- BARRIER ---
    *   4. PREFIX is dispatched before the first SCATTER.
    *      --- BARRIER ---
    *   5. One or more SCATTER dispatches.
    *
    * Note that the `partitions` buffer can be zeroed anytime before the first
    * scatter.
    */

   /* How many passes? */
   uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
   uint32_t keyval_bits = keyval_bytes * 8;
   uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
   uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
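
   /* Worked example (a sketch assuming the sorter's usual 8-bit radix, i.e.
    * RS_RADIX_LOG2 == 8): with 2-dword keyvals, keyval_bytes == 8 and
    * keyval_bits == 64, so key_bits == MIN2(24, 64) == 24 and
    * passes == ceil(24 / 8) == 3; only the Morton-code bits ever need
    * sorting.
    */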

   /* Pick where the sorted keyvals will end up: after an odd number of
    * scatter passes the result lands in the second sort buffer, after an
    * even number in the first.
    */
   for (uint32_t i = 0; i < infoCount; ++i) {
      if (bvh_states[i].leaf_node_count)
         bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
      else
         bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[0];
   }

   /*
    * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
    *
    * Pad fractional blocks with max-valued keyvals.
    *
    * Zero the histograms and partitions buffer.
    *
    * This assumes the partitions follow the histograms.
    */

   /* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */

   /* How many scatter blocks? */
   uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
   uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;

   /*
    * How many histogram blocks?
    *
    * Note that it's OK to have more max-valued digits counted by the histogram
    * than sorted by the scatters because the sort is stable.
    */
   uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
   uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;

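   /* pass_idx counts 8-bit digits from the least significant end. Sorting
    * starts at the lowest digit that contains Morton-code bits and walks up
    * to the top digit; the digits below it only carry the leaf-id payload and
    * never need a scatter pass. (Explanatory note, not original commentary.)
    */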
   uint32_t pass_idx = (keyval_bytes - passes);

   for (uint32_t i = 0; i < infoCount; ++i) {
      if (!bvh_states[i].leaf_node_count)
         continue;
      if (bvh_states[i].config.internal_type == INTERNAL_BUILD_TYPE_UPDATE)
         continue;

      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;

      bvh_states[i].scatter_blocks = (bvh_states[i].leaf_node_count + scatter_block_kvs - 1) / scatter_block_kvs;
      bvh_states[i].count_ru_scatter = bvh_states[i].scatter_blocks * scatter_block_kvs;

      bvh_states[i].histo_blocks = (bvh_states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
      bvh_states[i].count_ru_histo = bvh_states[i].histo_blocks * histo_block_kvs;

      /* Fill with max values */
      if (bvh_states[i].count_ru_histo > bvh_states[i].leaf_node_count) {
         device->cmd_fill_buffer_addr(commandBuffer, keyvals_even_addr +
                                      bvh_states[i].leaf_node_count * keyval_bytes,
                                      (bvh_states[i].count_ru_histo - bvh_states[i].leaf_node_count) * keyval_bytes,
                                      0xFFFFFFFF);
      }

      /*
       * Zero histograms and invalidate partitions.
       *
       * Note that the partition invalidation only needs to be performed once
       * because the even/odd scatter dispatches rely on the previous pass to
       * leave the partitions in an invalid state.
       *
       * Note that the last workgroup doesn't read/write a partition so it doesn't
       * need to be initialized.
       */
      uint32_t histo_partition_count = passes + bvh_states[i].scatter_blocks - 1;

      uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));

      device->cmd_fill_buffer_addr(commandBuffer,
                                   internal_addr + rs->internal.histograms.offset + fill_base,
                                   histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)) + keyval_bytes * sizeof(uint32_t), 0);
   }

   /*
    * Pipeline: HISTOGRAM
    *
    * TODO(allanmac): All subgroups should try to process approximately the same
    * number of blocks in order to minimize tail effects.  This was implemented
    * and reverted but should be reimplemented and benchmarked later.
    */
   vk_barrier_transfer_w_to_compute_r(commandBuffer);

   disp->CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                         rs->pipelines.named.histogram);

   for (uint32_t i = 0; i < infoCount; ++i) {
      if (!bvh_states[i].leaf_node_count)
         continue;
      if (bvh_states[i].config.internal_type == INTERNAL_BUILD_TYPE_UPDATE)
         continue;

      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;

      /* Dispatch histogram */
      struct rs_push_histogram push_histogram = {
         .devaddr_histograms = internal_addr + rs->internal.histograms.offset,
         .devaddr_keyvals = keyvals_even_addr,
         .passes = passes,
      };

      disp->CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0,
                             sizeof(push_histogram), &push_histogram);

      disp->CmdDispatch(commandBuffer, bvh_states[i].histo_blocks, 1, 1);
   }

   /*
    * Pipeline: PREFIX
    *
    * Launch one workgroup per pass.
    */
   vk_barrier_compute_w_to_compute_r(commandBuffer);

   disp->CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                         rs->pipelines.named.prefix);

   for (uint32_t i = 0; i < infoCount; ++i) {
      if (!bvh_states[i].leaf_node_count)
         continue;
      if (bvh_states[i].config.internal_type == INTERNAL_BUILD_TYPE_UPDATE)
         continue;

      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;

      struct rs_push_prefix push_prefix = {
         .devaddr_histograms = internal_addr + rs->internal.histograms.offset,
      };

      disp->CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.prefix, VK_SHADER_STAGE_COMPUTE_BIT, 0,
                             sizeof(push_prefix), &push_prefix);

      disp->CmdDispatch(commandBuffer, passes, 1, 1);
   }

   /* Pipeline: SCATTER */
   vk_barrier_compute_w_to_compute_r(commandBuffer);

   uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));

   for (uint32_t i = 0; i < infoCount; i++) {
      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
      uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;

      bvh_states[i].push_scatter = (struct rs_push_scatter){
         .devaddr_keyvals_even = keyvals_even_addr,
         .devaddr_keyvals_odd = keyvals_odd_addr,
         .devaddr_partitions = internal_addr + rs->internal.partitions.offset,
         .devaddr_histograms = internal_addr + rs->internal.histograms.offset + histogram_offset,
      };
   }

   bool is_even = true;

   while (true) {
      uint32_t pass_dword = pass_idx / 4;

      /* Bind new pipeline */
      VkPipeline p =
         is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd;
      disp->CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p);

      /* Update push constants that changed */
      VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even
                                    : rs->pipeline_layouts.named.scatter[pass_dword].odd;

      for (uint32_t i = 0; i < infoCount; i++) {
         if (!bvh_states[i].leaf_node_count)
            continue;
         if (bvh_states[i].config.internal_type == INTERNAL_BUILD_TYPE_UPDATE)
            continue;

         bvh_states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;

         disp->CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct rs_push_scatter),
                                &bvh_states[i].push_scatter);

         disp->CmdDispatch(commandBuffer, bvh_states[i].scatter_blocks, 1, 1);

         bvh_states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
      }

      /* Continue? */
      if (++pass_idx >= keyval_bytes)
         break;

      vk_barrier_compute_w_to_compute_r(commandBuffer);

      is_even ^= true;
   }

   if (args->emit_markers)
      device->as_build_ops->end_debug_marker(commandBuffer);
}

static VkResult
lbvh_build_internal(VkCommandBuffer commandBuffer,
                    struct vk_device *device, struct vk_meta_device *meta,
                    const struct vk_acceleration_structure_build_args *args,
                    uint32_t infoCount,
                    const VkAccelerationStructureBuildGeometryInfoKHR *pInfos, struct bvh_state *bvh_states)
{
   VkPipeline pipeline;
   VkPipelineLayout layout;

   VkResult result =
      get_pipeline_spv(device, meta, "lbvh_main", lbvh_main_spv,
                       sizeof(lbvh_main_spv),
                       sizeof(struct lbvh_main_args), args, &pipeline, &layout);

   if (result != VK_SUCCESS)
      return result;

   if (args->emit_markers) {
      device->as_build_ops->begin_debug_marker(commandBuffer,
                                               VK_ACCELERATION_STRUCTURE_BUILD_STEP_LBVH_BUILD_INTERNAL,
                                               "lbvh_build_internal");
   }

   const struct vk_device_dispatch_table *disp = &device->dispatch_table;
   disp->CmdBindPipeline(
      commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);

   for (uint32_t i = 0; i < infoCount; ++i) {
      if (bvh_states[i].config.internal_type != INTERNAL_BUILD_TYPE_LBVH)
         continue;

      uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
      uint32_t internal_node_count = MAX2(bvh_states[i].leaf_node_count, 2) - 1;

      const struct lbvh_main_args consts = {
         .bvh = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.ir_offset,
         .src_ids = pInfos[i].scratchData.deviceAddress + src_scratch_offset,
         .node_info = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.lbvh_node_offset,
         .id_count = bvh_states[i].leaf_node_count,
         .internal_node_base = bvh_states[i].scratch.internal_node_offset - bvh_states[i].scratch.ir_offset,
      };

      disp->CmdPushConstants(commandBuffer, layout,
                             VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
      device->cmd_dispatch_unaligned(commandBuffer, internal_node_count, 1, 1);
      bvh_states[i].internal_node_count = internal_node_count;
   }

   vk_barrier_compute_w_to_compute_r(commandBuffer);

   result =
      get_pipeline_spv(device, meta, "lbvh_generate_ir", lbvh_generate_ir_spv,
                       sizeof(lbvh_generate_ir_spv),
                       sizeof(struct lbvh_generate_ir_args), args, &pipeline, &layout);

   if (result != VK_SUCCESS)
      return result;

   disp->CmdBindPipeline(
      commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);

   for (uint32_t i = 0; i < infoCount; ++i) {
      if (bvh_states[i].config.internal_type != INTERNAL_BUILD_TYPE_LBVH)
         continue;

      const struct lbvh_generate_ir_args consts = {
         .bvh = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.ir_offset,
         .node_info = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.lbvh_node_offset,
         .header = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.header_offset,
         .internal_node_base = bvh_states[i].scratch.internal_node_offset - bvh_states[i].scratch.ir_offset,
      };

      disp->CmdPushConstants(commandBuffer, layout,
                             VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
      device->cmd_dispatch_unaligned(commandBuffer, bvh_states[i].internal_node_count, 1, 1);
   }

   if (args->emit_markers)
      device->as_build_ops->end_debug_marker(commandBuffer);

   return VK_SUCCESS;
}

static VkResult
ploc_build_internal(VkCommandBuffer commandBuffer,
                    struct vk_device *device, struct vk_meta_device *meta,
                    const struct vk_acceleration_structure_build_args *args,
                    uint32_t infoCount,
                    const VkAccelerationStructureBuildGeometryInfoKHR *pInfos, struct bvh_state *bvh_states)
{
   VkPipeline pipeline;
   VkPipelineLayout layout;

   VkResult result =
      get_pipeline_spv(device, meta, "ploc", ploc_spv,
                       sizeof(ploc_spv),
                       sizeof(struct ploc_args), args, &pipeline, &layout);

   if (result != VK_SUCCESS)
      return result;

   if (args->emit_markers) {
      device->as_build_ops->begin_debug_marker(commandBuffer,
                                               VK_ACCELERATION_STRUCTURE_BUILD_STEP_PLOC_BUILD_INTERNAL,
                                               "ploc_build_internal");
   }

   const struct vk_device_dispatch_table *disp = &device->dispatch_table;
   disp->CmdBindPipeline(
      commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);

   for (uint32_t i = 0; i < infoCount; ++i) {
      if (bvh_states[i].config.internal_type != INTERNAL_BUILD_TYPE_PLOC)
         continue;

      uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
      uint32_t dst_scratch_offset = (src_scratch_offset == bvh_states[i].scratch.sort_buffer_offset[0])
                                       ? bvh_states[i].scratch.sort_buffer_offset[1]
                                       : bvh_states[i].scratch.sort_buffer_offset[0];

      const struct ploc_args consts = {
         .bvh = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.ir_offset,
         .header = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.header_offset,
         .ids_0 = pInfos[i].scratchData.deviceAddress + src_scratch_offset,
         .ids_1 = pInfos[i].scratchData.deviceAddress + dst_scratch_offset,
         .prefix_scan_partitions =
            pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.ploc_prefix_sum_partition_offset,
         .internal_node_offset = bvh_states[i].scratch.internal_node_offset - bvh_states[i].scratch.ir_offset,
      };

      disp->CmdPushConstants(commandBuffer, layout,
                             VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
      disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(bvh_states[i].leaf_node_count, PLOC_WORKGROUP_SIZE), 1), 1, 1);
   }

   if (args->emit_markers)
      device->as_build_ops->end_debug_marker(commandBuffer);

   return VK_SUCCESS;
}
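
/* Unlike the LBVH path, PLOC is recorded as a single dispatch; the
 * task_counts/phase counters written into struct vk_ir_header appear to let
 * that one kernel iterate and synchronize across workgroups internally.
 * (Observation based on the header setup below, not original commentary.)
 */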

void
vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
                                     struct vk_device *device,
                                     struct vk_meta_device *meta,
                                     uint32_t infoCount,
                                     const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
                                     const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos,
                                     const struct vk_acceleration_structure_build_args *args)
{
   VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, commandBuffer);
   const struct vk_acceleration_structure_build_ops *ops = device->as_build_ops;

   struct bvh_batch_state batch_state = {0};

   struct bvh_state *bvh_states = calloc(infoCount, sizeof(struct bvh_state));
   if (!bvh_states) {
      vk_command_buffer_set_error(cmd_buffer, VK_ERROR_OUT_OF_HOST_MEMORY);
      return;
   }

   if (args->emit_markers) {
      device->as_build_ops->begin_debug_marker(commandBuffer,
                                               VK_ACCELERATION_STRUCTURE_BUILD_STEP_TOP,
                                               "vkCmdBuildAccelerationStructuresKHR(%u)",
                                               infoCount);
   }

   for (uint32_t i = 0; i < infoCount; ++i) {
      uint32_t leaf_node_count = 0;
      for (uint32_t j = 0; j < pInfos[i].geometryCount; ++j) {
         leaf_node_count += ppBuildRangeInfos[i][j].primitiveCount;
      }

      get_scratch_layout(device, leaf_node_count, pInfos + i, args, &bvh_states[i].scratch);

      struct build_config config = build_config(leaf_node_count, pInfos + i,
                                                device->as_build_ops);
      bvh_states[i].config = config;

      if (config.updateable)
         batch_state.any_updateable = true;
      else
         batch_state.any_non_updateable = true;

      if (config.internal_type == INTERNAL_BUILD_TYPE_PLOC) {
         batch_state.any_ploc = true;
      } else if (config.internal_type == INTERNAL_BUILD_TYPE_LBVH) {
         batch_state.any_lbvh = true;
      } else if (config.internal_type == INTERNAL_BUILD_TYPE_UPDATE) {
         batch_state.any_update = true;
         /* For updates, the leaf node pass never runs, so set leaf_node_count here. */
         bvh_states[i].leaf_node_count = leaf_node_count;
      } else {
         unreachable("Unknown internal_build_type");
      }

      if (bvh_states[i].config.internal_type != INTERNAL_BUILD_TYPE_UPDATE) {
         /* The internal node count is updated in lbvh_build_internal for LBVH
          * and from the PLOC shader for PLOC. */
         struct vk_ir_header header = {
            .min_bounds = {0x7fffffff, 0x7fffffff, 0x7fffffff},
            .max_bounds = {0x80000000, 0x80000000, 0x80000000},
            .dispatch_size_y = 1,
            .dispatch_size_z = 1,
            .sync_data =
               {
                  .current_phase_end_counter = TASK_INDEX_INVALID,
                  /* Will be updated by the first PLOC shader invocation */
                  .task_counts = {TASK_INDEX_INVALID, TASK_INDEX_INVALID},
               },
         };

         device->write_buffer_cp(commandBuffer, pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.header_offset,
                                 &header, sizeof(header));
      } else {
         VK_FROM_HANDLE(vk_acceleration_structure, src_as, pInfos[i].srcAccelerationStructure);
         VK_FROM_HANDLE(vk_acceleration_structure, dst_as, pInfos[i].dstAccelerationStructure);

         ops->init_update_scratch(commandBuffer, pInfos[i].scratchData.deviceAddress,
                                  leaf_node_count, src_as, dst_as);
      }
   }

   /* Wait for the write_buffer_cp to land before using in compute shaders */
   device->flush_buffer_write_cp(commandBuffer);
   device->dispatch_table.CmdPipelineBarrier(commandBuffer,
                                             VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                                             VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                                             0, /* dependencyFlags */
                                             1,
                                             &(VkMemoryBarrier) {
                                                .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER,
                                                .srcAccessMask = 0,
                                                .dstAccessMask = VK_ACCESS_SHADER_READ_BIT,
                                             }, 0, NULL, 0, NULL);

   if (batch_state.any_lbvh || batch_state.any_ploc) {
      VkResult result;

      if (batch_state.any_non_updateable) {
         result =
            build_leaves(commandBuffer, device, meta, args, infoCount, pInfos,
                         ppBuildRangeInfos, bvh_states, false);

         if (result != VK_SUCCESS) {
            free(bvh_states);
            vk_command_buffer_set_error(cmd_buffer, result);
            return;
         }
      }

      if (batch_state.any_updateable) {
         result =
            build_leaves(commandBuffer, device, meta, args, infoCount, pInfos,
                         ppBuildRangeInfos, bvh_states, true);

         if (result != VK_SUCCESS) {
            free(bvh_states);
            vk_command_buffer_set_error(cmd_buffer, result);
            return;
         }
      }

      vk_barrier_compute_w_to_compute_r(commandBuffer);

      result =
         morton_generate(commandBuffer, device, meta, args, infoCount, pInfos, bvh_states);

      if (result != VK_SUCCESS) {
         free(bvh_states);
         vk_command_buffer_set_error(cmd_buffer, result);
         return;
      }

      vk_barrier_compute_w_to_compute_r(commandBuffer);

      morton_sort(commandBuffer, device, args, infoCount, pInfos, bvh_states);

      vk_barrier_compute_w_to_compute_r(commandBuffer);

      if (batch_state.any_lbvh) {
         result =
            lbvh_build_internal(commandBuffer, device, meta, args, infoCount, pInfos, bvh_states);

         if (result != VK_SUCCESS) {
            free(bvh_states);
            vk_command_buffer_set_error(cmd_buffer, result);
            return;
         }
      }

      if (batch_state.any_ploc) {
         result =
            ploc_build_internal(commandBuffer, device, meta, args, infoCount, pInfos, bvh_states);

         if (result != VK_SUCCESS) {
            /* Match the other error paths: release the state array. */
            free(bvh_states);
            vk_command_buffer_set_error(cmd_buffer, result);
            return;
         }
      }

      vk_barrier_compute_w_to_compute_r(commandBuffer);
      vk_barrier_compute_w_to_indirect_compute_r(commandBuffer);
   }

   if (args->emit_markers) {
      device->as_build_ops->begin_debug_marker(commandBuffer,
                                               VK_ACCELERATION_STRUCTURE_BUILD_STEP_ENCODE,
                                               "encode");
   }

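   /* The encode/update loop below batches acceleration structures by
    * (update-vs-encode, encode_key) so each driver pipeline is bound once per
    * group rather than once per structure; the do/while exits once a full
    * sweep over the infos makes no progress, i.e. every structure has been
    * handled for this pass. (Summary, not original commentary.)
    */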
   for (unsigned pass = 0; pass < ARRAY_SIZE(ops->encode_as); pass++) {
      if (!ops->encode_as[pass] && !ops->update_as[pass])
         break;

      bool progress;
      do {
         progress = false;

         bool update;
         uint32_t encode_key = 0;
         for (uint32_t i = 0; i < infoCount; ++i) {
            if (bvh_states[i].last_encode_pass == pass + 1)
               continue;

            if (!progress) {
               update = (bvh_states[i].config.internal_type ==
                         INTERNAL_BUILD_TYPE_UPDATE);
               if (update && !ops->update_as[pass])
                  continue;
               if (!update && !ops->encode_as[pass])
                  continue;
               encode_key = bvh_states[i].config.encode_key[pass];
               progress = true;
               if (update)
                  ops->update_bind_pipeline[pass](commandBuffer);
               else
                  ops->encode_bind_pipeline[pass](commandBuffer, encode_key);
            } else {
               if (update != (bvh_states[i].config.internal_type ==
                              INTERNAL_BUILD_TYPE_UPDATE) ||
                   encode_key != bvh_states[i].config.encode_key[pass])
                  continue;
            }

            VK_FROM_HANDLE(vk_acceleration_structure, accel_struct, pInfos[i].dstAccelerationStructure);

            if (update) {
               VK_FROM_HANDLE(vk_acceleration_structure, src, pInfos[i].srcAccelerationStructure);
               ops->update_as[pass](commandBuffer,
                                    &pInfos[i],
                                    ppBuildRangeInfos[i],
                                    bvh_states[i].leaf_node_count,
                                    src,
                                    accel_struct);

            } else {
               ops->encode_as[pass](commandBuffer,
                                    &pInfos[i],
                                    ppBuildRangeInfos[i],
                                    pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.ir_offset,
                                    pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.header_offset,
                                    bvh_states[i].leaf_node_count,
                                    encode_key,
                                    accel_struct);
            }

            bvh_states[i].last_encode_pass = pass + 1;
         }
      } while (progress);
   }

   /* End the "encode" marker. */
   if (args->emit_markers)
      device->as_build_ops->end_debug_marker(commandBuffer);

   /* End the top-level vkCmdBuildAccelerationStructuresKHR marker. */
   if (args->emit_markers)
      device->as_build_ops->end_debug_marker(commandBuffer);

   free(bvh_states);
}

void
vk_get_as_build_sizes(VkDevice _device, VkAccelerationStructureBuildTypeKHR buildType,
                      const VkAccelerationStructureBuildGeometryInfoKHR *pBuildInfo,
                      const uint32_t *pMaxPrimitiveCounts,
                      VkAccelerationStructureBuildSizesInfoKHR *pSizeInfo,
                      const struct vk_acceleration_structure_build_args *args)
{
   VK_FROM_HANDLE(vk_device, device, _device);

   uint32_t leaf_count = 0;
   for (uint32_t i = 0; i < pBuildInfo->geometryCount; i++)
      leaf_count += pMaxPrimitiveCounts[i];

   struct scratch_layout scratch;

   get_scratch_layout(device, leaf_count, pBuildInfo, args, &scratch);

   pSizeInfo->accelerationStructureSize =
      device->as_build_ops->get_as_size(_device, pBuildInfo, leaf_count);
   pSizeInfo->updateScratchSize = scratch.update_size;
   pSizeInfo->buildScratchSize = scratch.size;
}
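
/* Typical flow these sizes feed (a sketch, not part of the original file):
 * the application queries vkGetAccelerationStructureBuildSizesKHR, allocates
 * a buffer of accelerationStructureSize for the acceleration structure and a
 * scratch buffer of buildScratchSize (updateScratchSize for refits), then
 * records vkCmdBuildAccelerationStructuresKHR with scratchData pointing at
 * the scratch buffer's device address.
 */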

/* Return true if the common framework supports using this format for loading
 * vertices. Must match the formats handled by load_vertices() on the GPU.
 */
bool
vk_acceleration_struct_vtx_format_supported(VkFormat format)
{
   switch (format) {
   case VK_FORMAT_R32G32_SFLOAT:
   case VK_FORMAT_R32G32B32_SFLOAT:
   case VK_FORMAT_R32G32B32A32_SFLOAT:
   case VK_FORMAT_R16G16_SFLOAT:
   case VK_FORMAT_R16G16B16_SFLOAT:
   case VK_FORMAT_R16G16B16A16_SFLOAT:
   case VK_FORMAT_R16G16_SNORM:
   case VK_FORMAT_R16G16_UNORM:
   case VK_FORMAT_R16G16B16A16_SNORM:
   case VK_FORMAT_R16G16B16A16_UNORM:
   case VK_FORMAT_R8G8_SNORM:
   case VK_FORMAT_R8G8_UNORM:
   case VK_FORMAT_R8G8B8A8_SNORM:
   case VK_FORMAT_R8G8B8A8_UNORM:
   case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
      return true;
   default:
      return false;
   }
}