1 /*
2  * Copyright © 2021 Bas Nieuwenhuizen
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 #include "radv_acceleration_structure.h"
24 #include "radv_private.h"
25 
26 #include "util/format/format_utils.h"
27 #include "util/half_float.h"
28 #include "nir_builder.h"
29 #include "radv_cs.h"
30 #include "radv_meta.h"
31 
32 #include "radix_sort/radv_radix_sort.h"
33 
34 /* Min and max bounds of the bvh used to compute morton codes */
35 #define SCRATCH_TOTAL_BOUNDS_SIZE (6 * sizeof(float))
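/* Six floats: the min.xyz and max.xyz over all leaves. For LBVH device builds
 * this block sits in the scratch buffer ahead of the per-leaf node id array
 * (build_leaf_shader() offsets its scratch writes by this amount), and the
 * accumulated bounds are what the Morton-code quantization is normalized to. */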
36 
37 enum accel_struct_build {
38    accel_struct_build_unoptimized,
39    accel_struct_build_lbvh,
40 };
41 
42 static enum accel_struct_build
43 get_accel_struct_build(const struct radv_physical_device *pdevice,
44                        VkAccelerationStructureBuildTypeKHR buildType)
45 {
46    return buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR
47              ? accel_struct_build_lbvh
48              : accel_struct_build_unoptimized;
49 }
50 
51 static uint32_t
52 get_node_id_stride(enum accel_struct_build build_mode)
53 {
54    switch (build_mode) {
55    case accel_struct_build_unoptimized:
56       return 4;
57    case accel_struct_build_lbvh:
58       return 8;
59    default:
60       unreachable("Unhandled accel_struct_build!");
61    }
62 }
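/* Unoptimized (host) builds store only a 4-byte node id per leaf in scratch.
 * LBVH device builds use 8-byte entries so each node id can be paired with a
 * 32-bit Morton code for the radix sort (see id_to_morton_offset() below). */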
63 
64 VKAPI_ATTR void VKAPI_CALL
65 radv_GetAccelerationStructureBuildSizesKHR(
66    VkDevice _device, VkAccelerationStructureBuildTypeKHR buildType,
67    const VkAccelerationStructureBuildGeometryInfoKHR *pBuildInfo,
68    const uint32_t *pMaxPrimitiveCounts, VkAccelerationStructureBuildSizesInfoKHR *pSizeInfo)
69 {
70    RADV_FROM_HANDLE(radv_device, device, _device);
71 
72    uint64_t triangles = 0, boxes = 0, instances = 0;
73 
74    STATIC_ASSERT(sizeof(struct radv_bvh_triangle_node) == 64);
75    STATIC_ASSERT(sizeof(struct radv_bvh_aabb_node) == 64);
76    STATIC_ASSERT(sizeof(struct radv_bvh_instance_node) == 128);
77    STATIC_ASSERT(sizeof(struct radv_bvh_box16_node) == 64);
78    STATIC_ASSERT(sizeof(struct radv_bvh_box32_node) == 128);
79 
80    for (uint32_t i = 0; i < pBuildInfo->geometryCount; ++i) {
81       const VkAccelerationStructureGeometryKHR *geometry;
82       if (pBuildInfo->pGeometries)
83          geometry = &pBuildInfo->pGeometries[i];
84       else
85          geometry = pBuildInfo->ppGeometries[i];
86 
87       switch (geometry->geometryType) {
88       case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
89          triangles += pMaxPrimitiveCounts[i];
90          break;
91       case VK_GEOMETRY_TYPE_AABBS_KHR:
92          boxes += pMaxPrimitiveCounts[i];
93          break;
94       case VK_GEOMETRY_TYPE_INSTANCES_KHR:
95          instances += pMaxPrimitiveCounts[i];
96          break;
97       case VK_GEOMETRY_TYPE_MAX_ENUM_KHR:
98          unreachable("VK_GEOMETRY_TYPE_MAX_ENUM_KHR unhandled");
99       }
100    }
101 
102    uint64_t children = boxes + instances + triangles;
103    /* Initialize to 1 to have enough space for the root node. */
104    uint64_t internal_nodes = 1;
105    while (children > 1) {
106       children = DIV_ROUND_UP(children, 4);
107       internal_nodes += children;
108    }
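   /* Worked example: 100 leaves need ceil(100/4) + ceil(25/4) + ceil(7/4) +
    * ceil(2/4) = 25 + 7 + 2 + 1 = 35 internal nodes from the loop, plus the
    * initial 1, so 36 box32 nodes of 128 bytes each are budgeted below. */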
109 
110    uint64_t size = boxes * 128 + instances * 128 + triangles * 64 + internal_nodes * 128 +
111                    ALIGN(sizeof(struct radv_accel_struct_header), 64);
112 
113    pSizeInfo->accelerationStructureSize = size;
114 
115    /* 2x the max number of nodes in a BVH layer for node ids and, when using LBVH,
116     * order information for sorting (one uint32_t each, two buffers), plus space to store the bounds.
117     * LBVH is only supported for device builds and hardware that supports global atomics.
118     */
119    enum accel_struct_build build_mode = get_accel_struct_build(device->physical_device, buildType);
120    uint32_t node_id_stride = get_node_id_stride(build_mode);
121 
122    uint32_t leaf_count = boxes + instances + triangles;
123    VkDeviceSize scratchSize = 2 * leaf_count * node_id_stride;
124 
125    if (build_mode == accel_struct_build_lbvh) {
126       radix_sort_vk_memory_requirements_t requirements;
127       radix_sort_vk_get_memory_requirements(device->meta_state.accel_struct_build.radix_sort,
128                                             leaf_count, &requirements);
129 
130       /* Make sure we have the space required by the radix sort. */
131       scratchSize = MAX2(scratchSize, requirements.keyvals_size * 2);
132 
133       scratchSize += requirements.internal_size + SCRATCH_TOTAL_BOUNDS_SIZE;
134    }
135 
136    scratchSize = MAX2(4096, scratchSize);
137    pSizeInfo->updateScratchSize = scratchSize;
138    pSizeInfo->buildScratchSize = scratchSize;
139 }
140 
141 VKAPI_ATTR VkResult VKAPI_CALL
142 radv_CreateAccelerationStructureKHR(VkDevice _device,
143                                     const VkAccelerationStructureCreateInfoKHR *pCreateInfo,
144                                     const VkAllocationCallbacks *pAllocator,
145                                     VkAccelerationStructureKHR *pAccelerationStructure)
146 {
147    RADV_FROM_HANDLE(radv_device, device, _device);
148    RADV_FROM_HANDLE(radv_buffer, buffer, pCreateInfo->buffer);
149    struct radv_acceleration_structure *accel;
150 
151    accel = vk_alloc2(&device->vk.alloc, pAllocator, sizeof(*accel), 8,
152                      VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
153    if (accel == NULL)
154       return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
155 
156    vk_object_base_init(&device->vk, &accel->base, VK_OBJECT_TYPE_ACCELERATION_STRUCTURE_KHR);
157 
158    accel->mem_offset = buffer->offset + pCreateInfo->offset;
159    accel->size = pCreateInfo->size;
160    accel->bo = buffer->bo;
161 
162    *pAccelerationStructure = radv_acceleration_structure_to_handle(accel);
163    return VK_SUCCESS;
164 }
165 
166 VKAPI_ATTR void VKAPI_CALL
167 radv_DestroyAccelerationStructureKHR(VkDevice _device,
168                                      VkAccelerationStructureKHR accelerationStructure,
169                                      const VkAllocationCallbacks *pAllocator)
170 {
171    RADV_FROM_HANDLE(radv_device, device, _device);
172    RADV_FROM_HANDLE(radv_acceleration_structure, accel, accelerationStructure);
173 
174    if (!accel)
175       return;
176 
177    vk_object_base_finish(&accel->base);
178    vk_free2(&device->vk.alloc, pAllocator, accel);
179 }
180 
181 VKAPI_ATTR VkDeviceAddress VKAPI_CALL
182 radv_GetAccelerationStructureDeviceAddressKHR(
183    VkDevice _device, const VkAccelerationStructureDeviceAddressInfoKHR *pInfo)
184 {
185    RADV_FROM_HANDLE(radv_acceleration_structure, accel, pInfo->accelerationStructure);
186    return radv_accel_struct_get_va(accel);
187 }
188 
189 VKAPI_ATTR VkResult VKAPI_CALL
190 radv_WriteAccelerationStructuresPropertiesKHR(
191    VkDevice _device, uint32_t accelerationStructureCount,
192    const VkAccelerationStructureKHR *pAccelerationStructures, VkQueryType queryType,
193    size_t dataSize, void *pData, size_t stride)
194 {
195    RADV_FROM_HANDLE(radv_device, device, _device);
196    char *data_out = (char *)pData;
197 
198    for (uint32_t i = 0; i < accelerationStructureCount; ++i) {
199       RADV_FROM_HANDLE(radv_acceleration_structure, accel, pAccelerationStructures[i]);
200       const char *base_ptr = (const char *)device->ws->buffer_map(accel->bo);
201       if (!base_ptr)
202          return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
203 
204       const struct radv_accel_struct_header *header = (const void *)(base_ptr + accel->mem_offset);
205       if (stride * i + sizeof(VkDeviceSize) <= dataSize) {
206          uint64_t value;
207          switch (queryType) {
208          case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR:
209             value = header->compacted_size;
210             break;
211          case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR:
212             value = header->serialization_size;
213             break;
214          case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_BOTTOM_LEVEL_POINTERS_KHR:
215             value = header->instance_count;
216             break;
217          case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SIZE_KHR:
218             value = header->size;
219             break;
220          default:
221             unreachable("Unhandled acceleration structure query");
222          }
223          *(VkDeviceSize *)(data_out + stride * i) = value;
224       }
225       device->ws->buffer_unmap(accel->bo);
226    }
227    return VK_SUCCESS;
228 }
229 
230 struct radv_bvh_build_ctx {
231    uint32_t *write_scratch;
232    char *base;
233    char *curr_ptr;
234 };
235 
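/* Node ids are byte offsets from the BVH base divided by 8, with the node type
 * stored in the low 3 bits (nodes are at least 64-byte aligned, so those bits
 * are otherwise zero). compute_bounds() recovers the type with `node_id & 7`
 * and the byte offset with `node_id / 8 * 64`. */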
236 static void
237 build_triangles(struct radv_bvh_build_ctx *ctx, const VkAccelerationStructureGeometryKHR *geom,
238                 const VkAccelerationStructureBuildRangeInfoKHR *range, unsigned geometry_id)
239 {
240    const VkAccelerationStructureGeometryTrianglesDataKHR *tri_data = &geom->geometry.triangles;
241    VkTransformMatrixKHR matrix;
242    const char *index_data = (const char *)tri_data->indexData.hostAddress;
243    const char *v_data_base = (const char *)tri_data->vertexData.hostAddress;
244 
245    if (tri_data->indexType == VK_INDEX_TYPE_NONE_KHR)
246       v_data_base += range->primitiveOffset;
247    else
248       index_data += range->primitiveOffset;
249 
250    if (tri_data->transformData.hostAddress) {
251       matrix = *(const VkTransformMatrixKHR *)((const char *)tri_data->transformData.hostAddress +
252                                                range->transformOffset);
253    } else {
254       matrix = (VkTransformMatrixKHR){
255          .matrix = {{1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0}}};
256    }
257 
258    for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 64) {
259       struct radv_bvh_triangle_node *node = (void *)ctx->curr_ptr;
260       uint32_t node_offset = ctx->curr_ptr - ctx->base;
261       uint32_t node_id = node_offset >> 3;
262       *ctx->write_scratch++ = node_id;
263 
264       for (unsigned v = 0; v < 3; ++v) {
265          uint32_t v_index = range->firstVertex;
266          switch (tri_data->indexType) {
267          case VK_INDEX_TYPE_NONE_KHR:
268             v_index += p * 3 + v;
269             break;
270          case VK_INDEX_TYPE_UINT8_EXT:
271             v_index += *(const uint8_t *)index_data;
272             index_data += 1;
273             break;
274          case VK_INDEX_TYPE_UINT16:
275             v_index += *(const uint16_t *)index_data;
276             index_data += 2;
277             break;
278          case VK_INDEX_TYPE_UINT32:
279             v_index += *(const uint32_t *)index_data;
280             index_data += 4;
281             break;
282          case VK_INDEX_TYPE_MAX_ENUM:
283             unreachable("Unhandled VK_INDEX_TYPE_MAX_ENUM");
284             break;
285          }
286 
287          const char *v_data = v_data_base + v_index * tri_data->vertexStride;
288          float coords[4];
289          switch (tri_data->vertexFormat) {
290          case VK_FORMAT_R32G32_SFLOAT:
291             coords[0] = *(const float *)(v_data + 0);
292             coords[1] = *(const float *)(v_data + 4);
293             coords[2] = 0.0f;
294             coords[3] = 1.0f;
295             break;
296          case VK_FORMAT_R32G32B32_SFLOAT:
297             coords[0] = *(const float *)(v_data + 0);
298             coords[1] = *(const float *)(v_data + 4);
299             coords[2] = *(const float *)(v_data + 8);
300             coords[3] = 1.0f;
301             break;
302          case VK_FORMAT_R32G32B32A32_SFLOAT:
303             coords[0] = *(const float *)(v_data + 0);
304             coords[1] = *(const float *)(v_data + 4);
305             coords[2] = *(const float *)(v_data + 8);
306             coords[3] = *(const float *)(v_data + 12);
307             break;
308          case VK_FORMAT_R16G16_SFLOAT:
309             coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
310             coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
311             coords[2] = 0.0f;
312             coords[3] = 1.0f;
313             break;
314          case VK_FORMAT_R16G16B16_SFLOAT:
315             coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
316             coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
317             coords[2] = _mesa_half_to_float(*(const uint16_t *)(v_data + 4));
318             coords[3] = 1.0f;
319             break;
320          case VK_FORMAT_R16G16B16A16_SFLOAT:
321             coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
322             coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
323             coords[2] = _mesa_half_to_float(*(const uint16_t *)(v_data + 4));
324             coords[3] = _mesa_half_to_float(*(const uint16_t *)(v_data + 6));
325             break;
326          case VK_FORMAT_R16G16_SNORM:
327             coords[0] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 0), 16);
328             coords[1] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 2), 16);
329             coords[2] = 0.0f;
330             coords[3] = 1.0f;
331             break;
332          case VK_FORMAT_R16G16_UNORM:
333             coords[0] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 0), 16);
334             coords[1] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 2), 16);
335             coords[2] = 0.0f;
336             coords[3] = 1.0f;
337             break;
338          case VK_FORMAT_R16G16B16A16_SNORM:
339             coords[0] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 0), 16);
340             coords[1] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 2), 16);
341             coords[2] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 4), 16);
342             coords[3] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 6), 16);
343             break;
344          case VK_FORMAT_R16G16B16A16_UNORM:
345             coords[0] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 0), 16);
346             coords[1] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 2), 16);
347             coords[2] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 4), 16);
348             coords[3] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 6), 16);
349             break;
350          case VK_FORMAT_R8G8_SNORM:
351             coords[0] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 0), 8);
352             coords[1] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 1), 8);
353             coords[2] = 0.0f;
354             coords[3] = 1.0f;
355             break;
356          case VK_FORMAT_R8G8_UNORM:
357             coords[0] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 0), 8);
358             coords[1] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 1), 8);
359             coords[2] = 0.0f;
360             coords[3] = 1.0f;
361             break;
362          case VK_FORMAT_R8G8B8A8_SNORM:
363             coords[0] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 0), 8);
364             coords[1] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 1), 8);
365             coords[2] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 2), 8);
366             coords[3] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 3), 8);
367             break;
368          case VK_FORMAT_R8G8B8A8_UNORM:
369             coords[0] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 0), 8);
370             coords[1] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 1), 8);
371             coords[2] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 2), 8);
372             coords[3] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 3), 8);
373             break;
374          case VK_FORMAT_A2B10G10R10_UNORM_PACK32: {
375             uint32_t val = *(const uint32_t *)v_data;
376             coords[0] = _mesa_unorm_to_float((val >> 0) & 0x3FF, 10);
377             coords[1] = _mesa_unorm_to_float((val >> 10) & 0x3FF, 10);
378             coords[2] = _mesa_unorm_to_float((val >> 20) & 0x3FF, 10);
379             coords[3] = _mesa_unorm_to_float((val >> 30) & 0x3, 2);
380          } break;
381          default:
382             unreachable("Unhandled vertex format in BVH build");
383          }
384 
385          for (unsigned j = 0; j < 3; ++j) {
386             float r = 0;
387             for (unsigned k = 0; k < 4; ++k)
388                r += matrix.matrix[j][k] * coords[k];
389             node->coords[v][j] = r;
390          }
391 
392          node->triangle_id = p;
393          node->geometry_id_and_flags = geometry_id | (geom->flags << 28);
394 
395          /* Seems to be needed for IJ, otherwise I = J = ? */
396          node->id = 9;
397       }
398    }
399 }
400 
401 static VkResult
402 build_instances(struct radv_device *device, struct radv_bvh_build_ctx *ctx,
403                 const VkAccelerationStructureGeometryKHR *geom,
404                 const VkAccelerationStructureBuildRangeInfoKHR *range)
405 {
406    const VkAccelerationStructureGeometryInstancesDataKHR *inst_data = &geom->geometry.instances;
407 
408    for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 128) {
409       const char *instance_data =
410          (const char *)inst_data->data.hostAddress + range->primitiveOffset;
411       const VkAccelerationStructureInstanceKHR *instance =
412          inst_data->arrayOfPointers
413             ? (((const VkAccelerationStructureInstanceKHR *const *)instance_data)[p])
414             : &((const VkAccelerationStructureInstanceKHR *)instance_data)[p];
415       if (!instance->accelerationStructureReference) {
416          continue;
417       }
418 
419       struct radv_bvh_instance_node *node = (void *)ctx->curr_ptr;
420       uint32_t node_offset = ctx->curr_ptr - ctx->base;
421       uint32_t node_id = (node_offset >> 3) | radv_bvh_node_instance;
422       *ctx->write_scratch++ = node_id;
423 
424       float transform[16], inv_transform[16];
425       memcpy(transform, &instance->transform.matrix, sizeof(instance->transform.matrix));
426       transform[12] = transform[13] = transform[14] = 0.0f;
427       transform[15] = 1.0f;
428 
429       util_invert_mat4x4(inv_transform, transform);
430       memcpy(node->wto_matrix, inv_transform, sizeof(node->wto_matrix));
431       node->wto_matrix[3] = transform[3];
432       node->wto_matrix[7] = transform[7];
433       node->wto_matrix[11] = transform[11];
434       node->custom_instance_and_mask = instance->instanceCustomIndex | (instance->mask << 24);
435       node->sbt_offset_and_flags =
436          instance->instanceShaderBindingTableRecordOffset | (instance->flags << 24);
437       node->instance_id = p;
438 
439       for (unsigned i = 0; i < 3; ++i)
440          for (unsigned j = 0; j < 3; ++j)
441             node->otw_matrix[i * 3 + j] = instance->transform.matrix[j][i];
442 
443       RADV_FROM_HANDLE(radv_acceleration_structure, src_accel_struct,
444                        (VkAccelerationStructureKHR)instance->accelerationStructureReference);
445       const void *src_base = device->ws->buffer_map(src_accel_struct->bo);
446       if (!src_base)
447          return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
448 
449       src_base = (const char *)src_base + src_accel_struct->mem_offset;
450       const struct radv_accel_struct_header *src_header = src_base;
451       node->base_ptr = radv_accel_struct_get_va(src_accel_struct) | src_header->root_node_offset;
452 
453       for (unsigned j = 0; j < 3; ++j) {
454          node->aabb[0][j] = instance->transform.matrix[j][3];
455          node->aabb[1][j] = instance->transform.matrix[j][3];
456          for (unsigned k = 0; k < 3; ++k) {
457             node->aabb[0][j] += MIN2(instance->transform.matrix[j][k] * src_header->aabb[0][k],
458                                      instance->transform.matrix[j][k] * src_header->aabb[1][k]);
459             node->aabb[1][j] += MAX2(instance->transform.matrix[j][k] * src_header->aabb[0][k],
460                                      instance->transform.matrix[j][k] * src_header->aabb[1][k]);
461          }
462       }
463       device->ws->buffer_unmap(src_accel_struct->bo);
464    }
465    return VK_SUCCESS;
466 }
467 
468 static void
469 build_aabbs(struct radv_bvh_build_ctx *ctx, const VkAccelerationStructureGeometryKHR *geom,
470             const VkAccelerationStructureBuildRangeInfoKHR *range, unsigned geometry_id)
471 {
472    const VkAccelerationStructureGeometryAabbsDataKHR *aabb_data = &geom->geometry.aabbs;
473 
474    for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 64) {
475       struct radv_bvh_aabb_node *node = (void *)ctx->curr_ptr;
476       uint32_t node_offset = ctx->curr_ptr - ctx->base;
477       uint32_t node_id = (node_offset >> 3) | radv_bvh_node_aabb;
478       *ctx->write_scratch++ = node_id;
479 
480       const VkAabbPositionsKHR *aabb =
481          (const VkAabbPositionsKHR *)((const char *)aabb_data->data.hostAddress +
482                                       range->primitiveOffset + p * aabb_data->stride);
483 
484       node->aabb[0][0] = aabb->minX;
485       node->aabb[0][1] = aabb->minY;
486       node->aabb[0][2] = aabb->minZ;
487       node->aabb[1][0] = aabb->maxX;
488       node->aabb[1][1] = aabb->maxY;
489       node->aabb[1][2] = aabb->maxZ;
490       node->primitive_id = p;
491       node->geometry_id_and_flags = geometry_id;
492    }
493 }
494 
495 static uint32_t
496 leaf_node_count(const VkAccelerationStructureBuildGeometryInfoKHR *info,
497                 const VkAccelerationStructureBuildRangeInfoKHR *ranges)
498 {
499    uint32_t count = 0;
500    for (uint32_t i = 0; i < info->geometryCount; ++i) {
501       count += ranges[i].primitiveCount;
502    }
503    return count;
504 }
505 
506 static void
507 compute_bounds(const char *base_ptr, uint32_t node_id, float *bounds)
508 {
509    for (unsigned i = 0; i < 3; ++i)
510       bounds[i] = INFINITY;
511    for (unsigned i = 0; i < 3; ++i)
512       bounds[3 + i] = -INFINITY;
513 
514    switch (node_id & 7) {
515    case radv_bvh_node_triangle: {
516       const struct radv_bvh_triangle_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
517       for (unsigned v = 0; v < 3; ++v) {
518          for (unsigned j = 0; j < 3; ++j) {
519             bounds[j] = MIN2(bounds[j], node->coords[v][j]);
520             bounds[3 + j] = MAX2(bounds[3 + j], node->coords[v][j]);
521          }
522       }
523       break;
524    }
525    case radv_bvh_node_internal: {
526       const struct radv_bvh_box32_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
527       for (unsigned c2 = 0; c2 < 4; ++c2) {
528          if (isnan(node->coords[c2][0][0]))
529             continue;
530          for (unsigned j = 0; j < 3; ++j) {
531             bounds[j] = MIN2(bounds[j], node->coords[c2][0][j]);
532             bounds[3 + j] = MAX2(bounds[3 + j], node->coords[c2][1][j]);
533          }
534       }
535       break;
536    }
537    case radv_bvh_node_instance: {
538       const struct radv_bvh_instance_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
539       for (unsigned j = 0; j < 3; ++j) {
540          bounds[j] = MIN2(bounds[j], node->aabb[0][j]);
541          bounds[3 + j] = MAX2(bounds[3 + j], node->aabb[1][j]);
542       }
543       break;
544    }
545    case radv_bvh_node_aabb: {
546       const struct radv_bvh_aabb_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
547       for (unsigned j = 0; j < 3; ++j) {
548          bounds[j] = MIN2(bounds[j], node->aabb[0][j]);
549          bounds[3 + j] = MAX2(bounds[3 + j], node->aabb[1][j]);
550       }
551       break;
552    }
553    }
554 }
555 
556 struct bvh_opt_entry {
557    uint64_t key;
558    uint32_t node_id;
559 };
560 
561 static int
562 bvh_opt_compare(const void *_a, const void *_b)
563 {
564    const struct bvh_opt_entry *a = _a;
565    const struct bvh_opt_entry *b = _b;
566 
567    if (a->key < b->key)
568       return -1;
569    if (a->key > b->key)
570       return 1;
571    if (a->node_id < b->node_id)
572       return -1;
573    if (a->node_id > b->node_id)
574       return 1;
575    return 0;
576 }
577 
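/* Sort leaf node ids along a Morton (Z-order) curve: each leaf centroid is
 * quantized to 21 bits per axis relative to the overall bounds, and the bits of
 * x, y and z are interleaved (x into key bits 0,3,6,..., y into 1,4,7,...,
 * z into 2,5,8,...), so sorting by the resulting 63-bit key places spatially
 * nearby leaves next to each other before the levels above are built. */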
578 static void
579 optimize_bvh(const char *base_ptr, uint32_t *node_ids, uint32_t node_count)
580 {
581    if (node_count == 0)
582       return;
583 
584    float bounds[6];
585    for (unsigned i = 0; i < 3; ++i)
586       bounds[i] = INFINITY;
587    for (unsigned i = 0; i < 3; ++i)
588       bounds[3 + i] = -INFINITY;
589 
590    for (uint32_t i = 0; i < node_count; ++i) {
591       float node_bounds[6];
592       compute_bounds(base_ptr, node_ids[i], node_bounds);
593       for (unsigned j = 0; j < 3; ++j)
594          bounds[j] = MIN2(bounds[j], node_bounds[j]);
595       for (unsigned j = 0; j < 3; ++j)
596          bounds[3 + j] = MAX2(bounds[3 + j], node_bounds[3 + j]);
597    }
598 
599    struct bvh_opt_entry *entries = calloc(node_count, sizeof(struct bvh_opt_entry));
600    if (!entries)
601       return;
602 
603    for (uint32_t i = 0; i < node_count; ++i) {
604       float node_bounds[6];
605       compute_bounds(base_ptr, node_ids[i], node_bounds);
606       float node_coords[3];
607       for (unsigned j = 0; j < 3; ++j)
608          node_coords[j] = (node_bounds[j] + node_bounds[3 + j]) * 0.5;
609       int32_t coords[3];
610       for (unsigned j = 0; j < 3; ++j)
611          coords[j] = MAX2(
612             MIN2((int32_t)((node_coords[j] - bounds[j]) / (bounds[3 + j] - bounds[j]) * (1 << 21)),
613                  (1 << 21) - 1),
614             0);
615       uint64_t key = 0;
616       for (unsigned j = 0; j < 21; ++j)
617          for (unsigned k = 0; k < 3; ++k)
618             key |= (uint64_t)((coords[k] >> j) & 1) << (j * 3 + k);
619       entries[i].key = key;
620       entries[i].node_id = node_ids[i];
621    }
622 
623    qsort(entries, node_count, sizeof(entries[0]), bvh_opt_compare);
624    for (unsigned i = 0; i < node_count; ++i)
625       node_ids[i] = entries[i].node_id;
626 
627    free(entries);
628 }
629 
630 static void
631 fill_accel_struct_header(struct radv_accel_struct_header *header)
632 {
633    /* 16 bytes per invocation, 64 invocations per workgroup */
634    header->copy_dispatch_size[0] = DIV_ROUND_UP(header->compacted_size, 16 * 64);
635    header->copy_dispatch_size[1] = 1;
636    header->copy_dispatch_size[2] = 1;
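   /* i.e. one workgroup per 1024 bytes of compacted_size: a 1 MiB acceleration
    * structure is copied with a (1024, 1, 1) dispatch. */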
637 
638    header->serialization_size =
639       header->compacted_size + align(sizeof(struct radv_accel_struct_serialization_header) +
640                                         sizeof(uint64_t) * header->instance_count,
641                                      128);
642 
643    header->size = header->serialization_size -
644                   sizeof(struct radv_accel_struct_serialization_header) -
645                   sizeof(uint64_t) * header->instance_count;
646 }
647 
648 static VkResult
649 build_bvh(struct radv_device *device, const VkAccelerationStructureBuildGeometryInfoKHR *info,
650           const VkAccelerationStructureBuildRangeInfoKHR *ranges)
651 {
652    RADV_FROM_HANDLE(radv_acceleration_structure, accel, info->dstAccelerationStructure);
653    VkResult result = VK_SUCCESS;
654 
655    uint32_t *scratch[2];
656    scratch[0] = info->scratchData.hostAddress;
657    scratch[1] = scratch[0] + leaf_node_count(info, ranges);
658 
659    char *base_ptr = (char *)device->ws->buffer_map(accel->bo);
660    if (!base_ptr)
661       return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
662 
663    base_ptr = base_ptr + accel->mem_offset;
664    struct radv_accel_struct_header *header = (void *)base_ptr;
665    void *first_node_ptr = (char *)base_ptr + ALIGN(sizeof(*header), 64);
666 
667    struct radv_bvh_build_ctx ctx = {.write_scratch = scratch[0],
668                                     .base = base_ptr,
669                                     .curr_ptr = (char *)first_node_ptr + 128};
670 
671    uint64_t instance_offset = (const char *)ctx.curr_ptr - (const char *)base_ptr;
672    uint64_t instance_count = 0;
673 
674    /* This initializes the leaf nodes of the BVH all at the same level. */
675    for (int inst = 1; inst >= 0; --inst) {
676       for (uint32_t i = 0; i < info->geometryCount; ++i) {
677          const VkAccelerationStructureGeometryKHR *geom =
678             info->pGeometries ? &info->pGeometries[i] : info->ppGeometries[i];
679 
680          if ((inst && geom->geometryType != VK_GEOMETRY_TYPE_INSTANCES_KHR) ||
681              (!inst && geom->geometryType == VK_GEOMETRY_TYPE_INSTANCES_KHR))
682             continue;
683 
684          switch (geom->geometryType) {
685          case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
686             build_triangles(&ctx, geom, ranges + i, i);
687             break;
688          case VK_GEOMETRY_TYPE_AABBS_KHR:
689             build_aabbs(&ctx, geom, ranges + i, i);
690             break;
691          case VK_GEOMETRY_TYPE_INSTANCES_KHR: {
692             result = build_instances(device, &ctx, geom, ranges + i);
693             if (result != VK_SUCCESS)
694                goto fail;
695 
696             instance_count += ranges[i].primitiveCount;
697             break;
698          }
699          case VK_GEOMETRY_TYPE_MAX_ENUM_KHR:
700             unreachable("VK_GEOMETRY_TYPE_MAX_ENUM_KHR unhandled");
701          }
702       }
703    }
704 
705    uint32_t node_counts[2] = {ctx.write_scratch - scratch[0], 0};
706    optimize_bvh(base_ptr, scratch[0], node_counts[0]);
707    unsigned d;
708 
709    /*
710     * This is the most naive BVH building algorithm I could think of:
711     * it just iteratively builds each level from bottom to top, with
712     * the children of each node kept in order and tightly packed.
713     *
714     * It is probably terrible for traversal, but it should be easy to
715     * build an equivalent GPU version.
716     */
717    for (d = 0; node_counts[d & 1] > 1 || d == 0; ++d) {
718       uint32_t child_count = node_counts[d & 1];
719       const uint32_t *children = scratch[d & 1];
720       uint32_t *dst_ids = scratch[(d & 1) ^ 1];
721       unsigned dst_count;
722       unsigned child_idx = 0;
723       for (dst_count = 0; child_idx < MAX2(1, child_count); ++dst_count, child_idx += 4) {
724          unsigned local_child_count = MIN2(4, child_count - child_idx);
725          uint32_t child_ids[4];
726          float bounds[4][6];
727 
728          for (unsigned c = 0; c < local_child_count; ++c) {
729             uint32_t id = children[child_idx + c];
730             child_ids[c] = id;
731 
732             compute_bounds(base_ptr, id, bounds[c]);
733          }
734 
735          struct radv_bvh_box32_node *node;
736 
737          /* Put the root node at base_ptr so that its node id is 0, which
738           * allows some traversal optimizations. */
739          if (child_idx == 0 && local_child_count == child_count) {
740             node = first_node_ptr;
741             header->root_node_offset = ((char *)first_node_ptr - (char *)base_ptr) / 64 * 8 + 5;
742          } else {
743             uint32_t dst_id = (ctx.curr_ptr - base_ptr) / 64;
744             dst_ids[dst_count] = dst_id * 8 + 5;
745 
746             node = (void *)ctx.curr_ptr;
747             ctx.curr_ptr += 128;
748          }
749 
750          for (unsigned c = 0; c < local_child_count; ++c) {
751             node->children[c] = child_ids[c];
752             for (unsigned i = 0; i < 2; ++i)
753                for (unsigned j = 0; j < 3; ++j)
754                   node->coords[c][i][j] = bounds[c][i * 3 + j];
755          }
756          for (unsigned c = local_child_count; c < 4; ++c) {
757             for (unsigned i = 0; i < 2; ++i)
758                for (unsigned j = 0; j < 3; ++j)
759                   node->coords[c][i][j] = NAN;
760          }
761       }
762 
763       node_counts[(d & 1) ^ 1] = dst_count;
764    }
765 
766    compute_bounds(base_ptr, header->root_node_offset, &header->aabb[0][0]);
767 
768    header->instance_offset = instance_offset;
769    header->instance_count = instance_count;
770    header->compacted_size = (char *)ctx.curr_ptr - base_ptr;
771 
772    fill_accel_struct_header(header);
773 
774 fail:
775    device->ws->buffer_unmap(accel->bo);
776    return result;
777 }
778 
779 VKAPI_ATTR VkResult VKAPI_CALL
780 radv_BuildAccelerationStructuresKHR(
781    VkDevice _device, VkDeferredOperationKHR deferredOperation, uint32_t infoCount,
782    const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
783    const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos)
784 {
785    RADV_FROM_HANDLE(radv_device, device, _device);
786    VkResult result = VK_SUCCESS;
787 
788    for (uint32_t i = 0; i < infoCount; ++i) {
789       result = build_bvh(device, pInfos + i, ppBuildRangeInfos[i]);
790       if (result != VK_SUCCESS)
791          break;
792    }
793    return result;
794 }
795 
796 VKAPI_ATTR VkResult VKAPI_CALL
797 radv_CopyAccelerationStructureKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
798                                   const VkCopyAccelerationStructureInfoKHR *pInfo)
799 {
800    RADV_FROM_HANDLE(radv_device, device, _device);
801    RADV_FROM_HANDLE(radv_acceleration_structure, src_struct, pInfo->src);
802    RADV_FROM_HANDLE(radv_acceleration_structure, dst_struct, pInfo->dst);
803 
804    char *src_ptr = (char *)device->ws->buffer_map(src_struct->bo);
805    if (!src_ptr)
806       return VK_ERROR_OUT_OF_HOST_MEMORY;
807 
808    char *dst_ptr = (char *)device->ws->buffer_map(dst_struct->bo);
809    if (!dst_ptr) {
810       device->ws->buffer_unmap(src_struct->bo);
811       return VK_ERROR_OUT_OF_HOST_MEMORY;
812    }
813 
814    src_ptr += src_struct->mem_offset;
815    dst_ptr += dst_struct->mem_offset;
816 
817    const struct radv_accel_struct_header *header = (const void *)src_ptr;
818    memcpy(dst_ptr, src_ptr, header->compacted_size);
819 
820    device->ws->buffer_unmap(src_struct->bo);
821    device->ws->buffer_unmap(dst_struct->bo);
822    return VK_SUCCESS;
823 }
824 
825 static nir_builder
826 create_accel_build_shader(struct radv_device *device, const char *name)
827 {
828    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "%s", name);
829    b.shader->info.workgroup_size[0] = 64;
830 
831    assert(b.shader->info.workgroup_size[1] == 1);
832    assert(b.shader->info.workgroup_size[2] == 1);
833    assert(!b.shader->info.workgroup_size_variable);
834 
835    return b;
836 }
837 
838 static nir_ssa_def *
839 get_indices(nir_builder *b, nir_ssa_def *addr, nir_ssa_def *type, nir_ssa_def *id)
840 {
841    const struct glsl_type *uvec3_type = glsl_vector_type(GLSL_TYPE_UINT, 3);
842    nir_variable *result =
843       nir_variable_create(b->shader, nir_var_shader_temp, uvec3_type, "indices");
844 
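   /* VK_INDEX_TYPE_UINT16 and VK_INDEX_TYPE_UINT32 are 0 and 1, so `type < 2`
    * selects the 16-/32-bit index paths; the outer else handles
    * VK_INDEX_TYPE_NONE_KHR (no index buffer) and VK_INDEX_TYPE_UINT8_EXT. */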
845    nir_push_if(b, nir_ult(b, type, nir_imm_int(b, 2)));
846    nir_push_if(b, nir_ieq_imm(b, type, VK_INDEX_TYPE_UINT16));
847    {
848       nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 6));
849       nir_ssa_def *indices[3];
850       for (unsigned i = 0; i < 3; ++i) {
851          indices[i] = nir_build_load_global(
852             b, 1, 16, nir_iadd(b, addr, nir_u2u64(b, nir_iadd_imm(b, index_id, 2 * i))));
853       }
854       nir_store_var(b, result, nir_u2u32(b, nir_vec(b, indices, 3)), 7);
855    }
856    nir_push_else(b, NULL);
857    {
858       nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 12));
859       nir_ssa_def *indices =
860          nir_build_load_global(b, 3, 32, nir_iadd(b, addr, nir_u2u64(b, index_id)));
861       nir_store_var(b, result, indices, 7);
862    }
863    nir_pop_if(b, NULL);
864    nir_push_else(b, NULL);
865    {
866       nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 3));
867       nir_ssa_def *indices[] = {
868          index_id,
869          nir_iadd_imm(b, index_id, 1),
870          nir_iadd_imm(b, index_id, 2),
871       };
872 
873       nir_push_if(b, nir_ieq_imm(b, type, VK_INDEX_TYPE_NONE_KHR));
874       {
875          nir_store_var(b, result, nir_vec(b, indices, 3), 7);
876       }
877       nir_push_else(b, NULL);
878       {
879          for (unsigned i = 0; i < 3; ++i) {
880             indices[i] =
881                nir_build_load_global(b, 1, 8, nir_iadd(b, addr, nir_u2u64(b, indices[i])));
882          }
883          nir_store_var(b, result, nir_u2u32(b, nir_vec(b, indices, 3)), 7);
884       }
885       nir_pop_if(b, NULL);
886    }
887    nir_pop_if(b, NULL);
888    return nir_load_var(b, result);
889 }
890 
891 static void
892 get_vertices(nir_builder *b, nir_ssa_def *addresses, nir_ssa_def *format, nir_ssa_def *positions[3])
893 {
894    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
895    nir_variable *results[3] = {
896       nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex0"),
897       nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex1"),
898       nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex2")};
899 
900    VkFormat formats[] = {
901       VK_FORMAT_R32G32B32_SFLOAT,
902       VK_FORMAT_R32G32B32A32_SFLOAT,
903       VK_FORMAT_R16G16B16_SFLOAT,
904       VK_FORMAT_R16G16B16A16_SFLOAT,
905       VK_FORMAT_R16G16_SFLOAT,
906       VK_FORMAT_R32G32_SFLOAT,
907       VK_FORMAT_R16G16_SNORM,
908       VK_FORMAT_R16G16_UNORM,
909       VK_FORMAT_R16G16B16A16_SNORM,
910       VK_FORMAT_R16G16B16A16_UNORM,
911       VK_FORMAT_R8G8_SNORM,
912       VK_FORMAT_R8G8_UNORM,
913       VK_FORMAT_R8G8B8A8_SNORM,
914       VK_FORMAT_R8G8B8A8_UNORM,
915       VK_FORMAT_A2B10G10R10_UNORM_PACK32,
916    };
917 
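   /* The formats are handled in a chain of nested if/else blocks: iteration f
    * pushes `if (format == formats[f])`, the last format becomes the trailing
    * else (the fallback), and the ARRAY_SIZE(formats) - 1 nir_pop_if() calls
    * after the loop close the chain. */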
918    for (unsigned f = 0; f < ARRAY_SIZE(formats); ++f) {
919       if (f + 1 < ARRAY_SIZE(formats))
920          nir_push_if(b, nir_ieq_imm(b, format, formats[f]));
921 
922       for (unsigned i = 0; i < 3; ++i) {
923          switch (formats[f]) {
924          case VK_FORMAT_R32G32B32_SFLOAT:
925          case VK_FORMAT_R32G32B32A32_SFLOAT:
926             nir_store_var(b, results[i],
927                           nir_build_load_global(b, 3, 32, nir_channel(b, addresses, i)), 7);
928             break;
929          case VK_FORMAT_R32G32_SFLOAT:
930          case VK_FORMAT_R16G16_SFLOAT:
931          case VK_FORMAT_R16G16B16_SFLOAT:
932          case VK_FORMAT_R16G16B16A16_SFLOAT:
933          case VK_FORMAT_R16G16_SNORM:
934          case VK_FORMAT_R16G16_UNORM:
935          case VK_FORMAT_R16G16B16A16_SNORM:
936          case VK_FORMAT_R16G16B16A16_UNORM:
937          case VK_FORMAT_R8G8_SNORM:
938          case VK_FORMAT_R8G8_UNORM:
939          case VK_FORMAT_R8G8B8A8_SNORM:
940          case VK_FORMAT_R8G8B8A8_UNORM:
941          case VK_FORMAT_A2B10G10R10_UNORM_PACK32: {
942             unsigned components = MIN2(3, vk_format_get_nr_components(formats[f]));
943             unsigned comp_bits =
944                vk_format_get_blocksizebits(formats[f]) / vk_format_get_nr_components(formats[f]);
945             unsigned comp_bytes = comp_bits / 8;
946             nir_ssa_def *values[3];
947             nir_ssa_def *addr = nir_channel(b, addresses, i);
948 
949             if (formats[f] == VK_FORMAT_A2B10G10R10_UNORM_PACK32) {
950                comp_bits = 10;
951                nir_ssa_def *val = nir_build_load_global(b, 1, 32, addr);
952                for (unsigned j = 0; j < 3; ++j)
953                   values[j] = nir_ubfe(b, val, nir_imm_int(b, j * 10), nir_imm_int(b, 10));
954             } else {
955                for (unsigned j = 0; j < components; ++j)
956                   values[j] =
957                      nir_build_load_global(b, 1, comp_bits, nir_iadd_imm(b, addr, j * comp_bytes));
958 
959                for (unsigned j = components; j < 3; ++j)
960                   values[j] = nir_imm_intN_t(b, 0, comp_bits);
961             }
962 
963             nir_ssa_def *vec;
964             if (util_format_is_snorm(vk_format_to_pipe_format(formats[f]))) {
965                for (unsigned j = 0; j < 3; ++j) {
966                   values[j] = nir_fdiv(b, nir_i2f32(b, values[j]),
967                                        nir_imm_float(b, (1u << (comp_bits - 1)) - 1));
968                   values[j] = nir_fmax(b, values[j], nir_imm_float(b, -1.0));
969                }
970                vec = nir_vec(b, values, 3);
971             } else if (util_format_is_unorm(vk_format_to_pipe_format(formats[f]))) {
972                for (unsigned j = 0; j < 3; ++j) {
973                   values[j] =
974                      nir_fdiv(b, nir_u2f32(b, values[j]), nir_imm_float(b, (1u << comp_bits) - 1));
975                   values[j] = nir_fmin(b, values[j], nir_imm_float(b, 1.0));
976                }
977                vec = nir_vec(b, values, 3);
978             } else if (comp_bits == 16)
979                vec = nir_f2f32(b, nir_vec(b, values, 3));
980             else
981                vec = nir_vec(b, values, 3);
982             nir_store_var(b, results[i], vec, 7);
983             break;
984          }
985          default:
986             unreachable("Unhandled format");
987          }
988       }
989       if (f + 1 < ARRAY_SIZE(formats))
990          nir_push_else(b, NULL);
991    }
992    for (unsigned f = 1; f < ARRAY_SIZE(formats); ++f) {
993       nir_pop_if(b, NULL);
994    }
995 
996    for (unsigned i = 0; i < 3; ++i)
997       positions[i] = nir_load_var(b, results[i]);
998 }
999 
1000 struct build_primitive_constants {
1001    uint64_t node_dst_addr;
1002    uint64_t scratch_addr;
1003    uint32_t dst_offset;
1004    uint32_t dst_scratch_offset;
1005    uint32_t geometry_type;
1006    uint32_t geometry_id;
1007 
1008    union {
1009       struct {
1010          uint64_t vertex_addr;
1011          uint64_t index_addr;
1012          uint64_t transform_addr;
1013          uint32_t vertex_stride;
1014          uint32_t vertex_format;
1015          uint32_t index_format;
1016       };
1017       struct {
1018          uint64_t instance_data;
1019          uint32_t array_of_pointers;
1020       };
1021       struct {
1022          uint64_t aabb_addr;
1023          uint32_t aabb_stride;
1024       };
1025    };
1026 };
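/* This mirrors the push-constant layout loaded in build_leaf_shader(): pconst0
 * packs node_dst_addr/scratch_addr, pconst1 holds dst_offset, dst_scratch_offset,
 * geometry_type and geometry_id, and pconst2/pconst3 cover the per-geometry
 * union (e.g. vertex/index/transform addresses for triangles, with index_format
 * at byte offset 64). */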
1027 
1028 struct bounds_constants {
1029    uint64_t node_addr;
1030    uint64_t scratch_addr;
1031 };
1032 
1033 struct morton_constants {
1034    uint64_t node_addr;
1035    uint64_t scratch_addr;
1036 };
1037 
1038 struct fill_constants {
1039    uint64_t addr;
1040    uint32_t value;
1041 };
1042 
1043 struct build_internal_constants {
1044    uint64_t node_dst_addr;
1045    uint64_t scratch_addr;
1046    uint32_t dst_offset;
1047    uint32_t dst_scratch_offset;
1048    uint32_t src_scratch_offset;
1049    uint32_t fill_header;
1050 };
1051 
1052 /* This inverts a 3x3 matrix using cofactors, as in e.g.
1053  * https://www.mathsisfun.com/algebra/matrix-inverse-minors-cofactors-adjugate.html */
1054 static void
1055 nir_invert_3x3(nir_builder *b, nir_ssa_def *in[3][3], nir_ssa_def *out[3][3])
1056 {
1057    nir_ssa_def *cofactors[3][3];
1058    for (unsigned i = 0; i < 3; ++i) {
1059       for (unsigned j = 0; j < 3; ++j) {
1060          cofactors[i][j] =
1061             nir_fsub(b, nir_fmul(b, in[(i + 1) % 3][(j + 1) % 3], in[(i + 2) % 3][(j + 2) % 3]),
1062                      nir_fmul(b, in[(i + 1) % 3][(j + 2) % 3], in[(i + 2) % 3][(j + 1) % 3]));
1063       }
1064    }
1065 
1066    nir_ssa_def *det = NULL;
1067    for (unsigned i = 0; i < 3; ++i) {
1068       nir_ssa_def *det_part = nir_fmul(b, in[0][i], cofactors[0][i]);
1069       det = det ? nir_fadd(b, det, det_part) : det_part;
1070    }
1071 
1072    nir_ssa_def *det_inv = nir_frcp(b, det);
1073    for (unsigned i = 0; i < 3; ++i) {
1074       for (unsigned j = 0; j < 3; ++j) {
1075          out[i][j] = nir_fmul(b, cofactors[j][i], det_inv);
1076       }
1077    }
1078 }
1079 
1080 static nir_ssa_def *
1081 id_to_node_id_offset(nir_builder *b, nir_ssa_def *global_id,
1082                      const struct radv_physical_device *pdevice)
1083 {
1084    uint32_t stride = get_node_id_stride(
1085       get_accel_struct_build(pdevice, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR));
1086 
1087    return nir_imul_imm(b, global_id, stride);
1088 }
1089 
1090 static nir_ssa_def *
1091 id_to_morton_offset(nir_builder *b, nir_ssa_def *global_id,
1092                     const struct radv_physical_device *pdevice)
1093 {
1094    enum accel_struct_build build_mode =
1095       get_accel_struct_build(pdevice, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
1096    assert(build_mode == accel_struct_build_lbvh);
1097 
1098    uint32_t stride = get_node_id_stride(build_mode);
1099 
1100    return nir_iadd_imm(b, nir_imul_imm(b, global_id, stride), sizeof(uint32_t));
1101 }
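/* For LBVH, each 8-byte scratch entry therefore holds the node id at offset 0
 * and its 32-bit Morton code at offset 4; the radix sort orders the entries by
 * Morton code so the leaves end up in Z-order. */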
1102 
1103 static void
1104 atomic_fminmax(struct radv_device *dev, nir_builder *b, nir_ssa_def *addr, bool is_max,
1105                nir_ssa_def *val)
1106 {
1107    if (radv_has_shader_buffer_float_minmax(dev->physical_device)) {
1108       if (is_max)
1109          nir_global_atomic_fmax(b, 32, addr, val);
1110       else
1111          nir_global_atomic_fmin(b, 32, addr, val);
1112       return;
1113    }
1114 
1115    /* Use an integer comparison to work correctly with negative zero. */
1116    val = nir_bcsel(b, nir_ilt(b, val, nir_imm_int(b, 0)),
1117                    nir_isub(b, nir_imm_int(b, -2147483648), val), val);
1118 
1119    if (is_max)
1120       nir_global_atomic_imax(b, 32, addr, val);
1121    else
1122       nir_global_atomic_imin(b, 32, addr, val);
1123 }
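/* The ordering trick above: non-negative IEEE-754 floats already compare
 * correctly when their bit patterns are treated as signed integers, while
 * negative floats compare in reverse. Remapping negative patterns with
 * INT32_MIN - val restores a total order that matches float ordering (and sends
 * -0.0 to the +0.0 pattern), so plain integer imin/imax atomics behave like
 * fmin/fmax. read_fminmax_atomic() below applies the same remapping, which is
 * its own inverse, to decode the stored value. */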
1124 
1125 static nir_ssa_def *
1126 read_fminmax_atomic(struct radv_device *dev, nir_builder *b, unsigned channels, nir_ssa_def *addr)
1127 {
1128    nir_ssa_def *val = nir_build_load_global(b, channels, 32, addr,
1129                                             .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER);
1130 
1131    if (radv_has_shader_buffer_float_minmax(dev->physical_device))
1132       return val;
1133 
1134    return nir_bcsel(b, nir_ilt(b, val, nir_imm_int(b, 0)),
1135                     nir_isub(b, nir_imm_int(b, -2147483648), val), val);
1136 }
1137 
1138 static nir_shader *
1139 build_leaf_shader(struct radv_device *dev)
1140 {
1141    enum accel_struct_build build_mode =
1142       get_accel_struct_build(dev->physical_device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
1143 
1144    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1145    nir_builder b = create_accel_build_shader(dev, "accel_build_leaf_shader");
1146 
1147    nir_ssa_def *pconst0 =
1148       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1149    nir_ssa_def *pconst1 =
1150       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
1151    nir_ssa_def *pconst2 =
1152       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 32, .range = 16);
1153    nir_ssa_def *pconst3 =
1154       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 48, .range = 16);
1155    nir_ssa_def *index_format =
1156       nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 64, .range = 4);
1157 
1158    nir_ssa_def *node_dst_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
1159    nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
1160    nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
1161    nir_ssa_def *scratch_offset = nir_channel(&b, pconst1, 1);
1162    nir_ssa_def *geom_type = nir_channel(&b, pconst1, 2);
1163    nir_ssa_def *geometry_id = nir_channel(&b, pconst1, 3);
1164 
1165    nir_ssa_def *global_id =
1166       nir_iadd(&b,
1167                nir_imul_imm(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
1168                             b.shader->info.workgroup_size[0]),
1169                nir_channels(&b, nir_load_local_invocation_id(&b), 1));
1170    nir_ssa_def *scratch_dst_addr =
1171       nir_iadd(&b, scratch_addr,
1172                nir_u2u64(&b, nir_iadd(&b, scratch_offset,
1173                                       id_to_node_id_offset(&b, global_id, dev->physical_device))));
1174    if (build_mode != accel_struct_build_unoptimized)
1175       scratch_dst_addr = nir_iadd_imm(&b, scratch_dst_addr, SCRATCH_TOTAL_BOUNDS_SIZE);
1176 
1177    nir_variable *bounds[2] = {
1178       nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1179       nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1180    };
1181 
1182    nir_push_if(&b, nir_ieq_imm(&b, geom_type, VK_GEOMETRY_TYPE_TRIANGLES_KHR));
1183    { /* Triangles */
1184       nir_ssa_def *vertex_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011));
1185       nir_ssa_def *index_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b1100));
1186       nir_ssa_def *transform_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst3, 3));
1187       nir_ssa_def *vertex_stride = nir_channel(&b, pconst3, 2);
1188       nir_ssa_def *vertex_format = nir_channel(&b, pconst3, 3);
1189       unsigned repl_swizzle[4] = {0, 0, 0, 0};
1190 
1191       nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 64));
1192       nir_ssa_def *triangle_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
1193 
1194       nir_ssa_def *indices = get_indices(&b, index_addr, index_format, global_id);
1195       nir_ssa_def *vertex_addresses = nir_iadd(
1196          &b, nir_u2u64(&b, nir_imul(&b, indices, nir_swizzle(&b, vertex_stride, repl_swizzle, 3))),
1197          nir_swizzle(&b, vertex_addr, repl_swizzle, 3));
1198       nir_ssa_def *positions[3];
1199       get_vertices(&b, vertex_addresses, vertex_format, positions);
1200 
1201       nir_ssa_def *node_data[16];
1202       memset(node_data, 0, sizeof(node_data));
1203 
1204       nir_variable *transform[] = {
1205          nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform0"),
1206          nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform1"),
1207          nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform2"),
1208       };
1209       nir_store_var(&b, transform[0], nir_imm_vec4(&b, 1.0, 0.0, 0.0, 0.0), 0xf);
1210       nir_store_var(&b, transform[1], nir_imm_vec4(&b, 0.0, 1.0, 0.0, 0.0), 0xf);
1211       nir_store_var(&b, transform[2], nir_imm_vec4(&b, 0.0, 0.0, 1.0, 0.0), 0xf);
1212 
1213       nir_push_if(&b, nir_ine_imm(&b, transform_addr, 0));
1214       nir_store_var(&b, transform[0],
1215                     nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, transform_addr, 0),
1216                                           .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER),
1217                     0xf);
1218       nir_store_var(&b, transform[1],
1219                     nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, transform_addr, 16),
1220                                           .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER),
1221                     0xf);
1222       nir_store_var(&b, transform[2],
1223                     nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, transform_addr, 32),
1224                                           .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER),
1225                     0xf);
1226       nir_pop_if(&b, NULL);
1227 
1228       for (unsigned i = 0; i < 3; ++i)
1229          for (unsigned j = 0; j < 3; ++j)
1230             node_data[i * 3 + j] = nir_fdph(&b, positions[i], nir_load_var(&b, transform[j]));
1231 
1232       nir_ssa_def *min_bound = NULL;
1233       nir_ssa_def *max_bound = NULL;
1234       for (unsigned i = 0; i < 3; ++i) {
1235          nir_ssa_def *position = nir_vec(&b, node_data + i * 3, 3);
1236          if (min_bound) {
1237             min_bound = nir_fmin(&b, min_bound, position);
1238             max_bound = nir_fmax(&b, max_bound, position);
1239          } else {
1240             min_bound = position;
1241             max_bound = position;
1242          }
1243       }
1244 
1245       nir_store_var(&b, bounds[0], min_bound, 7);
1246       nir_store_var(&b, bounds[1], max_bound, 7);
1247 
1248       node_data[12] = global_id;
1249       node_data[13] = geometry_id;
1250       node_data[15] = nir_imm_int(&b, 9);
1251       for (unsigned i = 0; i < ARRAY_SIZE(node_data); ++i)
1252          if (!node_data[i])
1253             node_data[i] = nir_imm_int(&b, 0);
1254 
1255       for (unsigned i = 0; i < 4; ++i) {
1256          nir_build_store_global(&b, nir_vec(&b, node_data + i * 4, 4),
1257                                 nir_iadd_imm(&b, triangle_node_dst_addr, i * 16), .align_mul = 16);
1258       }
1259 
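      /* Node ids encode the node's byte offset divided by 8 with the node type
       * in the low 3 bits; 64-byte node alignment keeps those bits free.
       */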
1260       nir_ssa_def *node_id =
1261          nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), radv_bvh_node_triangle);
1262       nir_build_store_global(&b, node_id, scratch_dst_addr);
1263    }
1264    nir_push_else(&b, NULL);
1265    nir_push_if(&b, nir_ieq_imm(&b, geom_type, VK_GEOMETRY_TYPE_AABBS_KHR));
1266    { /* AABBs */
1267       nir_ssa_def *aabb_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011));
1268       nir_ssa_def *aabb_stride = nir_channel(&b, pconst2, 2);
1269 
1270       nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 64));
1271       nir_ssa_def *aabb_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
1272 
1273       nir_ssa_def *node_id = nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), radv_bvh_node_aabb);
1274       nir_build_store_global(&b, node_id, scratch_dst_addr);
1275 
1276       aabb_addr = nir_iadd(&b, aabb_addr, nir_u2u64(&b, nir_imul(&b, aabb_stride, global_id)));
1277 
1278       nir_ssa_def *min_bound =
1279          nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, aabb_addr, 0),
1280                                .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER);
1281       nir_ssa_def *max_bound =
1282          nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, aabb_addr, 12),
1283                                .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER);
1284 
1285       nir_store_var(&b, bounds[0], min_bound, 7);
1286       nir_store_var(&b, bounds[1], max_bound, 7);
1287 
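      /* AABB leaf layout: min.xyz, max.xyz, primitive index, geometry id. */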
1288       nir_ssa_def *values[] = {nir_channel(&b, min_bound, 0),
1289                                nir_channel(&b, min_bound, 1),
1290                                nir_channel(&b, min_bound, 2),
1291                                nir_channel(&b, max_bound, 0),
1292                                nir_channel(&b, max_bound, 1),
1293                                nir_channel(&b, max_bound, 2),
1294                                global_id,
1295                                geometry_id};
1296 
1297       nir_build_store_global(&b, nir_vec(&b, values + 0, 4),
1298                              nir_iadd_imm(&b, aabb_node_dst_addr, 0), .align_mul = 16);
1299       nir_build_store_global(&b, nir_vec(&b, values + 4, 4),
1300                              nir_iadd_imm(&b, aabb_node_dst_addr, 16), .align_mul = 16);
1301    }
1302    nir_push_else(&b, NULL);
1303    { /* Instances */
1304 
1305       nir_variable *instance_addr_var =
1306          nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
1307       nir_push_if(&b, nir_ine_imm(&b, nir_channel(&b, pconst2, 2), 0));
1308       {
1309          nir_ssa_def *ptr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011)),
1310                                      nir_u2u64(&b, nir_imul_imm(&b, global_id, 8)));
1311          nir_ssa_def *addr =
1312             nir_pack_64_2x32(&b, nir_build_load_global(&b, 2, 32, ptr, .align_mul = 8));
1313          nir_store_var(&b, instance_addr_var, addr, 1);
1314       }
1315       nir_push_else(&b, NULL);
1316       {
1317          nir_ssa_def *addr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011)),
1318                                       nir_u2u64(&b, nir_imul_imm(&b, global_id, 64)));
1319          nir_store_var(&b, instance_addr_var, addr, 1);
1320       }
1321       nir_pop_if(&b, NULL);
1322       nir_ssa_def *instance_addr = nir_load_var(&b, instance_addr_var);
1323 
1324       nir_ssa_def *inst_transform[] = {
1325          nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 0)),
1326          nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 16)),
1327          nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 32))};
1328       nir_ssa_def *inst3 = nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 48));
1329 
1330       nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 128));
1331       node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
1332 
1333       nir_ssa_def *node_id =
1334          nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), radv_bvh_node_instance);
1335       nir_build_store_global(&b, node_id, scratch_dst_addr);
1336 
1337       nir_ssa_def *header_addr = nir_pack_64_2x32(&b, nir_channels(&b, inst3, 12));
1338       nir_push_if(&b, nir_ine_imm(&b, header_addr, 0));
1339       nir_ssa_def *header_root_offset =
1340          nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, header_addr, 0));
1341       nir_ssa_def *header_min = nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, header_addr, 8));
1342       nir_ssa_def *header_max = nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, header_addr, 20));
1343 
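      /* Transform the referenced BLAS bounds into world space: for each output
       * axis, take the min/max of the row-wise products with the source min
       * and max and add the translation component.
       */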
1344       nir_ssa_def *bound_defs[2][3];
1345       for (unsigned i = 0; i < 3; ++i) {
1346          bound_defs[0][i] = bound_defs[1][i] = nir_channel(&b, inst_transform[i], 3);
1347 
1348          nir_ssa_def *mul_a = nir_fmul(&b, nir_channels(&b, inst_transform[i], 7), header_min);
1349          nir_ssa_def *mul_b = nir_fmul(&b, nir_channels(&b, inst_transform[i], 7), header_max);
1350          nir_ssa_def *mi = nir_fmin(&b, mul_a, mul_b);
1351          nir_ssa_def *ma = nir_fmax(&b, mul_a, mul_b);
1352          for (unsigned j = 0; j < 3; ++j) {
1353             bound_defs[0][i] = nir_fadd(&b, bound_defs[0][i], nir_channel(&b, mi, j));
1354             bound_defs[1][i] = nir_fadd(&b, bound_defs[1][i], nir_channel(&b, ma, j));
1355          }
1356       }
1357 
1358       nir_store_var(&b, bounds[0], nir_vec(&b, bound_defs[0], 3), 7);
1359       nir_store_var(&b, bounds[1], nir_vec(&b, bound_defs[1], 3), 7);
1360 
1361       /* Store object to world matrix */
1362       for (unsigned i = 0; i < 3; ++i) {
1363          nir_ssa_def *vals[3];
1364          for (unsigned j = 0; j < 3; ++j)
1365             vals[j] = nir_channel(&b, inst_transform[j], i);
1366 
1367          nir_build_store_global(&b, nir_vec(&b, vals, 3),
1368                                 nir_iadd_imm(&b, node_dst_addr, 92 + 12 * i));
1369       }
1370 
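      /* Invert the 3x3 part of the instance transform; the inverted rows, with
       * the original translation kept in the fourth column, are stored at
       * offset 16, which appears to be the world-to-object matrix used during
       * traversal.
       */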
1371       nir_ssa_def *m_in[3][3], *m_out[3][3], *m_vec[3][4];
1372       for (unsigned i = 0; i < 3; ++i)
1373          for (unsigned j = 0; j < 3; ++j)
1374             m_in[i][j] = nir_channel(&b, inst_transform[i], j);
1375       nir_invert_3x3(&b, m_in, m_out);
1376       for (unsigned i = 0; i < 3; ++i) {
1377          for (unsigned j = 0; j < 3; ++j)
1378             m_vec[i][j] = m_out[i][j];
1379          m_vec[i][3] = nir_channel(&b, inst_transform[i], 3);
1380       }
1381 
1382       for (unsigned i = 0; i < 3; ++i) {
1383          nir_build_store_global(&b, nir_vec(&b, m_vec[i], 4),
1384                                 nir_iadd_imm(&b, node_dst_addr, 16 + 16 * i));
1385       }
1386 
1387       nir_ssa_def *out0[4] = {
1388          nir_ior(&b, nir_channel(&b, nir_unpack_64_2x32(&b, header_addr), 0), header_root_offset),
1389          nir_channel(&b, nir_unpack_64_2x32(&b, header_addr), 1), nir_channel(&b, inst3, 0),
1390          nir_channel(&b, inst3, 1)};
1391       nir_build_store_global(&b, nir_vec(&b, out0, 4), nir_iadd_imm(&b, node_dst_addr, 0));
1392       nir_build_store_global(&b, global_id, nir_iadd_imm(&b, node_dst_addr, 88));
1393       nir_pop_if(&b, NULL);
1394       nir_build_store_global(&b, nir_load_var(&b, bounds[0]), nir_iadd_imm(&b, node_dst_addr, 64));
1395       nir_build_store_global(&b, nir_load_var(&b, bounds[1]), nir_iadd_imm(&b, node_dst_addr, 76));
1396    }
1397    nir_pop_if(&b, NULL);
1398    nir_pop_if(&b, NULL);
1399 
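   /* For optimized (LBVH) builds, reduce the leaf bounds across the subgroup
    * and let a single lane atomically min/max the global bounds in scratch;
    * the Morton shader later uses them to normalize node centroids.
    */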
1400    if (build_mode != accel_struct_build_unoptimized) {
1401       nir_ssa_def *min = nir_load_var(&b, bounds[0]);
1402       nir_ssa_def *max = nir_load_var(&b, bounds[1]);
1403 
1404       nir_ssa_def *min_reduced = nir_reduce(&b, min, .reduction_op = nir_op_fmin);
1405       nir_ssa_def *max_reduced = nir_reduce(&b, max, .reduction_op = nir_op_fmax);
1406 
1407       nir_push_if(&b, nir_elect(&b, 1));
1408 
1409       atomic_fminmax(dev, &b, scratch_addr, false, nir_channel(&b, min_reduced, 0));
1410       atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 4), false,
1411                      nir_channel(&b, min_reduced, 1));
1412       atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 8), false,
1413                      nir_channel(&b, min_reduced, 2));
1414 
1415       atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 12), true,
1416                      nir_channel(&b, max_reduced, 0));
1417       atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 16), true,
1418                      nir_channel(&b, max_reduced, 1));
1419       atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 20), true,
1420                      nir_channel(&b, max_reduced, 2));
1421    }
1422 
1423    return b.shader;
1424 }
1425 
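/* Compute the bounding box of an arbitrary node by decoding its type from the
 * node id and reading the corresponding per-type layout.
 */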
1426 static void
1427 determine_bounds(nir_builder *b, nir_ssa_def *node_addr, nir_ssa_def *node_id,
1428                  nir_variable *bounds_vars[2])
1429 {
1430    nir_ssa_def *node_type = nir_iand_imm(b, node_id, 7);
1431    node_addr =
1432       nir_iadd(b, node_addr, nir_u2u64(b, nir_ishl_imm(b, nir_iand_imm(b, node_id, ~7u), 3)));
1433 
1434    nir_push_if(b, nir_ieq_imm(b, node_type, radv_bvh_node_triangle));
1435    {
1436       nir_ssa_def *positions[3];
1437       for (unsigned i = 0; i < 3; ++i)
1438          positions[i] = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, i * 12));
1439       nir_ssa_def *bounds[] = {positions[0], positions[0]};
1440       for (unsigned i = 1; i < 3; ++i) {
1441          bounds[0] = nir_fmin(b, bounds[0], positions[i]);
1442          bounds[1] = nir_fmax(b, bounds[1], positions[i]);
1443       }
1444       nir_store_var(b, bounds_vars[0], bounds[0], 7);
1445       nir_store_var(b, bounds_vars[1], bounds[1], 7);
1446    }
1447    nir_push_else(b, NULL);
1448    nir_push_if(b, nir_ieq_imm(b, node_type, radv_bvh_node_internal));
1449    {
1450       nir_ssa_def *input_bounds[4][2];
1451       for (unsigned i = 0; i < 4; ++i)
1452          for (unsigned j = 0; j < 2; ++j)
1453             input_bounds[i][j] =
1454                nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 16 + i * 24 + j * 12));
1455       nir_ssa_def *bounds[] = {input_bounds[0][0], input_bounds[0][1]};
1456       for (unsigned i = 1; i < 4; ++i) {
1457          bounds[0] = nir_fmin(b, bounds[0], input_bounds[i][0]);
1458          bounds[1] = nir_fmax(b, bounds[1], input_bounds[i][1]);
1459       }
1460 
1461       nir_store_var(b, bounds_vars[0], bounds[0], 7);
1462       nir_store_var(b, bounds_vars[1], bounds[1], 7);
1463    }
1464    nir_push_else(b, NULL);
1465    nir_push_if(b, nir_ieq_imm(b, node_type, radv_bvh_node_instance));
1466    { /* Instances */
1467       nir_ssa_def *bounds[2];
1468       for (unsigned i = 0; i < 2; ++i)
1469          bounds[i] = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 64 + i * 12));
1470       nir_store_var(b, bounds_vars[0], bounds[0], 7);
1471       nir_store_var(b, bounds_vars[1], bounds[1], 7);
1472    }
1473    nir_push_else(b, NULL);
1474    { /* AABBs */
1475       nir_ssa_def *bounds[2];
1476       for (unsigned i = 0; i < 2; ++i)
1477          bounds[i] = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, i * 12));
1478       nir_store_var(b, bounds_vars[0], bounds[0], 7);
1479       nir_store_var(b, bounds_vars[1], bounds[1], 7);
1480    }
1481    nir_pop_if(b, NULL);
1482    nir_pop_if(b, NULL);
1483    nir_pop_if(b, NULL);
1484 }
1485 
1486 /* https://developer.nvidia.com/blog/thinking-parallel-part-iii-tree-construction-gpu/ */
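/* Spread the low 8 bits of x so that each input bit lands in every third
 * output bit, allowing the x, y and z components to be interleaved into a
 * single Morton code.
 */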
1487 static nir_ssa_def *
1488 build_morton_component(nir_builder *b, nir_ssa_def *x)
1489 {
1490    x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000101u), 0x0F00F00Fu);
1491    x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000011u), 0xC30C30C3u);
1492    x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000005u), 0x49249249u);
1493    return x;
1494 }
1495 
1496 static nir_shader *
1497 build_morton_shader(struct radv_device *dev)
1498 {
1499    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1500 
1501    nir_builder b = create_accel_build_shader(dev, "accel_build_morton_shader");
1502 
1503    /*
1504     * push constants:
1505     *   i32 x 2: node address
1506     *   i32 x 2: scratch address
1507     */
1508    nir_ssa_def *pconst0 =
1509       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1510 
1511    nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
1512    nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
1513 
1514    nir_ssa_def *global_id =
1515       nir_iadd(&b,
1516                nir_imul_imm(&b, nir_channel(&b, nir_load_workgroup_id(&b, 32), 0),
1517                             b.shader->info.workgroup_size[0]),
1518                nir_load_local_invocation_index(&b));
1519 
1520    nir_ssa_def *node_id_addr =
1521       nir_iadd(&b, nir_iadd_imm(&b, scratch_addr, SCRATCH_TOTAL_BOUNDS_SIZE),
1522                nir_u2u64(&b, id_to_node_id_offset(&b, global_id, dev->physical_device)));
1523    nir_ssa_def *node_id =
1524       nir_build_load_global(&b, 1, 32, node_id_addr, .align_mul = 4, .align_offset = 0);
1525 
1526    nir_variable *node_bounds[2] = {
1527       nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1528       nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1529    };
1530 
1531    determine_bounds(&b, node_addr, node_id, node_bounds);
1532 
1533    nir_ssa_def *node_min = nir_load_var(&b, node_bounds[0]);
1534    nir_ssa_def *node_max = nir_load_var(&b, node_bounds[1]);
1535    nir_ssa_def *node_pos =
1536       nir_fmul(&b, nir_fadd(&b, node_min, node_max), nir_imm_vec3(&b, 0.5, 0.5, 0.5));
1537 
1538    nir_ssa_def *bvh_min = read_fminmax_atomic(dev, &b, 3, scratch_addr);
1539    nir_ssa_def *bvh_max = read_fminmax_atomic(dev, &b, 3, nir_iadd_imm(&b, scratch_addr, 12));
1540    nir_ssa_def *bvh_size = nir_fsub(&b, bvh_max, bvh_min);
1541 
1542    nir_ssa_def *normalized_node_pos = nir_fdiv(&b, nir_fsub(&b, node_pos, bvh_min), bvh_size);
1543 
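   /* Quantize each normalized centroid coordinate to 8 bits per axis. */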
1544    nir_ssa_def *x_int =
1545       nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 0), 255.0));
1546    nir_ssa_def *x_morton = build_morton_component(&b, x_int);
1547 
1548    nir_ssa_def *y_int =
1549       nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 1), 255.0));
1550    nir_ssa_def *y_morton = build_morton_component(&b, y_int);
1551 
1552    nir_ssa_def *z_int =
1553       nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 2), 255.0));
1554    nir_ssa_def *z_morton = build_morton_component(&b, z_int);
1555 
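   /* Interleave the quantized x/y/z bits into a 24-bit Morton code used as the
    * sort key.
    */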
1556    nir_ssa_def *morton_code = nir_iadd(
1557       &b, nir_iadd(&b, nir_ishl_imm(&b, x_morton, 2), nir_ishl_imm(&b, y_morton, 1)), z_morton);
1558    nir_ssa_def *key = nir_ishl_imm(&b, morton_code, 8);
1559 
1560    nir_ssa_def *dst_addr =
1561       nir_iadd(&b, nir_iadd_imm(&b, scratch_addr, SCRATCH_TOTAL_BOUNDS_SIZE),
1562                nir_u2u64(&b, id_to_morton_offset(&b, global_id, dev->physical_device)));
1563    nir_build_store_global(&b, key, dst_addr, .align_mul = 4);
1564 
1565    return b.shader;
1566 }
1567 
1568 static nir_shader *
1569 build_internal_shader(struct radv_device *dev)
1570 {
1571    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1572    nir_builder b = create_accel_build_shader(dev, "accel_build_internal_shader");
1573 
1574    /*
1575     * push constants:
1576     *   i32 x 2: node dst address
1577     *   i32 x 2: scratch address
1578     *   i32: dst offset
1579     *   i32: dst scratch offset
1580     *   i32: src scratch offset
1581     *   i32: src_node_count | (fill_header << 31)
1582     */
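   /* Each invocation collapses up to four node ids from the previous level
    * into one 128-byte internal node, writing the child ids and their bounds
    * and accumulating the total bounds for the optional header fill.
    */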
1583    nir_ssa_def *pconst0 =
1584       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1585    nir_ssa_def *pconst1 =
1586       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
1587 
1588    nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
1589    nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
1590    nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
1591    nir_ssa_def *dst_scratch_offset = nir_channel(&b, pconst1, 1);
1592    nir_ssa_def *src_scratch_offset = nir_channel(&b, pconst1, 2);
1593    nir_ssa_def *src_node_count = nir_iand_imm(&b, nir_channel(&b, pconst1, 3), 0x7FFFFFFFU);
1594    nir_ssa_def *fill_header =
1595       nir_ine_imm(&b, nir_iand_imm(&b, nir_channel(&b, pconst1, 3), 0x80000000U), 0);
1596 
1597    nir_ssa_def *global_id =
1598       nir_iadd(&b,
1599                nir_imul_imm(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
1600                             b.shader->info.workgroup_size[0]),
1601                nir_channels(&b, nir_load_local_invocation_id(&b), 1));
1602    nir_ssa_def *src_idx = nir_imul_imm(&b, global_id, 4);
1603    nir_ssa_def *src_count = nir_umin(&b, nir_imm_int(&b, 4), nir_isub(&b, src_node_count, src_idx));
1604 
1605    nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_ishl_imm(&b, global_id, 7));
1606    nir_ssa_def *node_dst_addr = nir_iadd(&b, node_addr, nir_u2u64(&b, node_offset));
1607 
1608    nir_ssa_def *src_base_addr =
1609       nir_iadd(&b, scratch_addr,
1610                nir_u2u64(&b, nir_iadd(&b, src_scratch_offset,
1611                                       id_to_node_id_offset(&b, src_idx, dev->physical_device))));
1612 
1613    enum accel_struct_build build_mode =
1614       get_accel_struct_build(dev->physical_device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
1615    uint32_t node_id_stride = get_node_id_stride(build_mode);
1616 
1617    nir_ssa_def *src_nodes[4];
1618    for (uint32_t i = 0; i < 4; i++) {
1619       src_nodes[i] =
1620          nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, src_base_addr, i * node_id_stride));
1621       nir_build_store_global(&b, src_nodes[i], nir_iadd_imm(&b, node_dst_addr, i * 4));
1622    }
1623 
1624    nir_ssa_def *total_bounds[2] = {
1625       nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
1626       nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
1627    };
1628 
1629    for (unsigned i = 0; i < 4; ++i) {
1630       nir_variable *bounds[2] = {
1631          nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1632          nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1633       };
1634       nir_store_var(&b, bounds[0], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1635       nir_store_var(&b, bounds[1], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1636 
1637       nir_push_if(&b, nir_ilt(&b, nir_imm_int(&b, i), src_count));
1638       determine_bounds(&b, node_addr, src_nodes[i], bounds);
1639       nir_pop_if(&b, NULL);
1640       nir_build_store_global(&b, nir_load_var(&b, bounds[0]),
1641                              nir_iadd_imm(&b, node_dst_addr, 16 + 24 * i));
1642       nir_build_store_global(&b, nir_load_var(&b, bounds[1]),
1643                              nir_iadd_imm(&b, node_dst_addr, 28 + 24 * i));
1644       total_bounds[0] = nir_fmin(&b, total_bounds[0], nir_load_var(&b, bounds[0]));
1645       total_bounds[1] = nir_fmax(&b, total_bounds[1], nir_load_var(&b, bounds[1]));
1646    }
1647 
1648    nir_ssa_def *node_id =
1649       nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), radv_bvh_node_internal);
1650    nir_ssa_def *dst_scratch_addr =
1651       nir_iadd(&b, scratch_addr,
1652                nir_u2u64(&b, nir_iadd(&b, dst_scratch_offset,
1653                                       id_to_node_id_offset(&b, global_id, dev->physical_device))));
1654    nir_build_store_global(&b, node_id, dst_scratch_addr);
1655 
1656    nir_push_if(&b, fill_header);
1657    nir_build_store_global(&b, node_id, node_addr);
1658    nir_build_store_global(&b, total_bounds[0], nir_iadd_imm(&b, node_addr, 8));
1659    nir_build_store_global(&b, total_bounds[1], nir_iadd_imm(&b, node_addr, 20));
1660    nir_pop_if(&b, NULL);
1661    return b.shader;
1662 }
1663 
1664 enum copy_mode {
1665    COPY_MODE_COPY,
1666    COPY_MODE_SERIALIZE,
1667    COPY_MODE_DESERIALIZE,
1668 };
1669 
1670 struct copy_constants {
1671    uint64_t src_addr;
1672    uint64_t dst_addr;
1673    uint32_t mode;
1674 };
1675 
1676 static nir_shader *
1677 build_copy_shader(struct radv_device *dev)
1678 {
1679    nir_builder b = create_accel_build_shader(dev, "accel_copy");
1680 
1681    nir_ssa_def *invoc_id = nir_load_local_invocation_id(&b);
1682    nir_ssa_def *wg_id = nir_load_workgroup_id(&b, 32);
1683    nir_ssa_def *block_size =
1684       nir_imm_ivec4(&b, b.shader->info.workgroup_size[0], b.shader->info.workgroup_size[1],
1685                     b.shader->info.workgroup_size[2], 0);
1686 
1687    nir_ssa_def *global_id =
1688       nir_channel(&b, nir_iadd(&b, nir_imul(&b, wg_id, block_size), invoc_id), 0);
1689 
1690    nir_variable *offset_var =
1691       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "offset");
1692    nir_ssa_def *offset = nir_imul_imm(&b, global_id, 16);
1693    nir_store_var(&b, offset_var, offset, 1);
1694 
1695    nir_ssa_def *increment = nir_imul_imm(&b, nir_channel(&b, nir_load_num_workgroups(&b, 32), 0),
1696                                          b.shader->info.workgroup_size[0] * 16);
1697 
1698    nir_ssa_def *pconst0 =
1699       nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1700    nir_ssa_def *pconst1 =
1701       nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 16, .range = 4);
1702    nir_ssa_def *src_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
1703    nir_ssa_def *dst_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
1704    nir_ssa_def *mode = nir_channel(&b, pconst1, 0);
1705 
1706    nir_variable *compacted_size_var =
1707       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "compacted_size");
1708    nir_variable *src_offset_var =
1709       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "src_offset");
1710    nir_variable *dst_offset_var =
1711       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "dst_offset");
1712    nir_variable *instance_offset_var =
1713       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "instance_offset");
1714    nir_variable *instance_count_var =
1715       nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "instance_count");
1716    nir_variable *value_var =
1717       nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "value");
1718 
1719    nir_push_if(&b, nir_ieq_imm(&b, mode, COPY_MODE_SERIALIZE));
1720    {
1721       nir_ssa_def *instance_count = nir_build_load_global(
1722          &b, 1, 32,
1723          nir_iadd_imm(&b, src_base_addr,
1724                       offsetof(struct radv_accel_struct_header, instance_count)));
1725       nir_ssa_def *compacted_size = nir_build_load_global(
1726          &b, 1, 64,
1727          nir_iadd_imm(&b, src_base_addr,
1728                       offsetof(struct radv_accel_struct_header, compacted_size)));
1729       nir_ssa_def *serialization_size = nir_build_load_global(
1730          &b, 1, 64,
1731          nir_iadd_imm(&b, src_base_addr,
1732                       offsetof(struct radv_accel_struct_header, serialization_size)));
1733 
1734       nir_store_var(&b, compacted_size_var, compacted_size, 1);
1735       nir_store_var(&b, instance_offset_var,
1736                     nir_build_load_global(
1737                        &b, 1, 32,
1738                        nir_iadd_imm(&b, src_base_addr,
1739                                     offsetof(struct radv_accel_struct_header, instance_offset))),
1740                     1);
1741       nir_store_var(&b, instance_count_var, instance_count, 1);
1742 
1743       nir_ssa_def *dst_offset = nir_iadd_imm(&b, nir_imul_imm(&b, instance_count, sizeof(uint64_t)),
1744                                              sizeof(struct radv_accel_struct_serialization_header));
1745       nir_store_var(&b, src_offset_var, nir_imm_int(&b, 0), 1);
1746       nir_store_var(&b, dst_offset_var, dst_offset, 1);
1747 
1748       nir_push_if(&b, nir_ieq_imm(&b, global_id, 0));
1749       {
1750          nir_build_store_global(&b, serialization_size,
1751                                 nir_iadd_imm(&b, dst_base_addr,
1752                                              offsetof(struct radv_accel_struct_serialization_header,
1753                                                       serialization_size)));
1754          nir_build_store_global(
1755             &b, compacted_size,
1756             nir_iadd_imm(&b, dst_base_addr,
1757                          offsetof(struct radv_accel_struct_serialization_header, compacted_size)));
1758          nir_build_store_global(
1759             &b, nir_u2u64(&b, instance_count),
1760             nir_iadd_imm(&b, dst_base_addr,
1761                          offsetof(struct radv_accel_struct_serialization_header, instance_count)));
1762       }
1763       nir_pop_if(&b, NULL);
1764    }
1765    nir_push_else(&b, NULL);
1766    nir_push_if(&b, nir_ieq_imm(&b, mode, COPY_MODE_DESERIALIZE));
1767    {
1768       nir_ssa_def *instance_count = nir_build_load_global(
1769          &b, 1, 32,
1770          nir_iadd_imm(&b, src_base_addr,
1771                       offsetof(struct radv_accel_struct_serialization_header, instance_count)));
1772       nir_ssa_def *src_offset = nir_iadd_imm(&b, nir_imul_imm(&b, instance_count, sizeof(uint64_t)),
1773                                              sizeof(struct radv_accel_struct_serialization_header));
1774 
1775       nir_ssa_def *header_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, src_offset));
1776       nir_store_var(&b, compacted_size_var,
1777                     nir_build_load_global(
1778                        &b, 1, 64,
1779                        nir_iadd_imm(&b, header_addr,
1780                                     offsetof(struct radv_accel_struct_header, compacted_size))),
1781                     1);
1782       nir_store_var(&b, instance_offset_var,
1783                     nir_build_load_global(
1784                        &b, 1, 32,
1785                        nir_iadd_imm(&b, header_addr,
1786                                     offsetof(struct radv_accel_struct_header, instance_offset))),
1787                     1);
1788       nir_store_var(&b, instance_count_var, instance_count, 1);
1789       nir_store_var(&b, src_offset_var, src_offset, 1);
1790       nir_store_var(&b, dst_offset_var, nir_imm_int(&b, 0), 1);
1791    }
1792    nir_push_else(&b, NULL); /* COPY_MODE_COPY */
1793    {
1794       nir_store_var(&b, compacted_size_var,
1795                     nir_build_load_global(
1796                        &b, 1, 64,
1797                        nir_iadd_imm(&b, src_base_addr,
1798                                     offsetof(struct radv_accel_struct_header, compacted_size))),
1799                     1);
1800 
1801       nir_store_var(&b, src_offset_var, nir_imm_int(&b, 0), 1);
1802       nir_store_var(&b, dst_offset_var, nir_imm_int(&b, 0), 1);
1803       nir_store_var(&b, instance_offset_var, nir_imm_int(&b, 0), 1);
1804       nir_store_var(&b, instance_count_var, nir_imm_int(&b, 0), 1);
1805    }
1806    nir_pop_if(&b, NULL);
1807    nir_pop_if(&b, NULL);
1808 
1809    nir_ssa_def *instance_bound =
1810       nir_imul_imm(&b, nir_load_var(&b, instance_count_var), sizeof(struct radv_bvh_instance_node));
1811    nir_ssa_def *compacted_size = nir_build_load_global(
1812       &b, 1, 32,
1813       nir_iadd_imm(&b, src_base_addr, offsetof(struct radv_accel_struct_header, compacted_size)));
1814 
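   /* Grid-stride copy loop: every invocation copies 16-byte chunks, advancing
    * by the total number of invocations until compacted_size is covered.
    */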
1815    nir_push_loop(&b);
1816    {
1817       offset = nir_load_var(&b, offset_var);
1818       nir_push_if(&b, nir_ilt(&b, offset, compacted_size));
1819       {
1820          nir_ssa_def *src_offset = nir_iadd(&b, offset, nir_load_var(&b, src_offset_var));
1821          nir_ssa_def *dst_offset = nir_iadd(&b, offset, nir_load_var(&b, dst_offset_var));
1822          nir_ssa_def *src_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, src_offset));
1823          nir_ssa_def *dst_addr = nir_iadd(&b, dst_base_addr, nir_u2u64(&b, dst_offset));
1824 
1825          nir_ssa_def *value = nir_build_load_global(&b, 4, 32, src_addr, .align_mul = 16);
1826          nir_store_var(&b, value_var, value, 0xf);
1827 
1828          nir_ssa_def *instance_offset = nir_isub(&b, offset, nir_load_var(&b, instance_offset_var));
1829          nir_ssa_def *in_instance_bound =
1830             nir_iand(&b, nir_uge(&b, offset, nir_load_var(&b, instance_offset_var)),
1831                      nir_ult(&b, instance_offset, instance_bound));
1832          nir_ssa_def *instance_start = nir_ieq_imm(
1833             &b, nir_iand_imm(&b, instance_offset, sizeof(struct radv_bvh_instance_node) - 1), 0);
1834 
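         /* The first 8 bytes of an instance node hold the referenced
          * acceleration structure pointer: serialization writes it out to the
          * instance table in the serialization header, deserialization patches
          * it back in from there.
          */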
1835          nir_push_if(&b, nir_iand(&b, in_instance_bound, instance_start));
1836          {
1837             nir_ssa_def *instance_id = nir_ushr_imm(&b, instance_offset, 7);
1838 
1839             nir_push_if(&b, nir_ieq_imm(&b, mode, COPY_MODE_SERIALIZE));
1840             {
1841                nir_ssa_def *instance_addr = nir_imul_imm(&b, instance_id, sizeof(uint64_t));
1842                instance_addr = nir_iadd_imm(&b, instance_addr,
1843                                             sizeof(struct radv_accel_struct_serialization_header));
1844                instance_addr = nir_iadd(&b, dst_base_addr, nir_u2u64(&b, instance_addr));
1845 
1846                nir_build_store_global(&b, nir_channels(&b, value, 3), instance_addr,
1847                                       .align_mul = 8);
1848             }
1849             nir_push_else(&b, NULL);
1850             {
1851                nir_ssa_def *instance_addr = nir_imul_imm(&b, instance_id, sizeof(uint64_t));
1852                instance_addr = nir_iadd_imm(&b, instance_addr,
1853                                             sizeof(struct radv_accel_struct_serialization_header));
1854                instance_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, instance_addr));
1855 
1856                nir_ssa_def *instance_value =
1857                   nir_build_load_global(&b, 2, 32, instance_addr, .align_mul = 8);
1858 
1859                nir_ssa_def *values[] = {
1860                   nir_channel(&b, instance_value, 0),
1861                   nir_channel(&b, instance_value, 1),
1862                   nir_channel(&b, value, 2),
1863                   nir_channel(&b, value, 3),
1864                };
1865 
1866                nir_store_var(&b, value_var, nir_vec(&b, values, 4), 0xf);
1867             }
1868             nir_pop_if(&b, NULL);
1869          }
1870          nir_pop_if(&b, NULL);
1871 
1872          nir_store_var(&b, offset_var, nir_iadd(&b, offset, increment), 1);
1873 
1874          nir_build_store_global(&b, nir_load_var(&b, value_var), dst_addr, .align_mul = 16);
1875       }
1876       nir_push_else(&b, NULL);
1877       {
1878          nir_jump(&b, nir_jump_break);
1879       }
1880       nir_pop_if(&b, NULL);
1881    }
1882    nir_pop_loop(&b, NULL);
1883    return b.shader;
1884 }
1885 
1886 void
1887 radv_device_finish_accel_struct_build_state(struct radv_device *device)
1888 {
1889    struct radv_meta_state *state = &device->meta_state;
1890    radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.copy_pipeline,
1891                         &state->alloc);
1892    radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.internal_pipeline,
1893                         &state->alloc);
1894    radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.leaf_pipeline,
1895                         &state->alloc);
1896    radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.morton_pipeline,
1897                         &state->alloc);
1898    radv_DestroyPipelineLayout(radv_device_to_handle(device),
1899                               state->accel_struct_build.copy_p_layout, &state->alloc);
1900    radv_DestroyPipelineLayout(radv_device_to_handle(device),
1901                               state->accel_struct_build.internal_p_layout, &state->alloc);
1902    radv_DestroyPipelineLayout(radv_device_to_handle(device),
1903                               state->accel_struct_build.leaf_p_layout, &state->alloc);
1904    radv_DestroyPipelineLayout(radv_device_to_handle(device),
1905                               state->accel_struct_build.morton_p_layout, &state->alloc);
1906 
1907    if (state->accel_struct_build.radix_sort)
1908       radix_sort_vk_destroy(state->accel_struct_build.radix_sort, radv_device_to_handle(device),
1909                             &state->alloc);
1910 }
1911 
1912 static VkResult
1913 create_build_pipeline(struct radv_device *device, nir_shader *shader, unsigned push_constant_size,
1914                       VkPipeline *pipeline, VkPipelineLayout *layout)
1915 {
1916    const VkPipelineLayoutCreateInfo pl_create_info = {
1917       .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
1918       .setLayoutCount = 0,
1919       .pushConstantRangeCount = 1,
1920       .pPushConstantRanges =
1921          &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0, push_constant_size},
1922    };
1923 
1924    VkResult result = radv_CreatePipelineLayout(radv_device_to_handle(device), &pl_create_info,
1925                                                &device->meta_state.alloc, layout);
1926    if (result != VK_SUCCESS) {
1927       ralloc_free(shader);
1928       return result;
1929    }
1930 
1931    VkPipelineShaderStageCreateInfo shader_stage = {
1932       .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1933       .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1934       .module = vk_shader_module_handle_from_nir(shader),
1935       .pName = "main",
1936       .pSpecializationInfo = NULL,
1937    };
1938 
1939    VkComputePipelineCreateInfo pipeline_info = {
1940       .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1941       .stage = shader_stage,
1942       .flags = 0,
1943       .layout = *layout,
1944    };
1945 
1946    result = radv_CreateComputePipelines(radv_device_to_handle(device),
1947                                         radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
1948                                         &pipeline_info, &device->meta_state.alloc, pipeline);
1949 
1950    if (result != VK_SUCCESS) {
1951       ralloc_free(shader);
1952       return result;
1953    }
1954 
1955    return VK_SUCCESS;
1956 }
1957 
1958 static void
1959 radix_sort_fill_buffer(VkCommandBuffer commandBuffer,
1960                        radix_sort_vk_buffer_info_t const *buffer_info, VkDeviceSize offset,
1961                        VkDeviceSize size, uint32_t data)
1962 {
1963    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
1964 
1965    assert(size != VK_WHOLE_SIZE);
1966 
1967    radv_fill_buffer(cmd_buffer, NULL, NULL, buffer_info->devaddr + buffer_info->offset + offset,
1968                     size, data);
1969 }
1970 
1971 VkResult
1972 radv_device_init_accel_struct_build_state(struct radv_device *device)
1973 {
1974    VkResult result;
1975    nir_shader *leaf_cs = build_leaf_shader(device);
1976    nir_shader *internal_cs = build_internal_shader(device);
1977    nir_shader *copy_cs = build_copy_shader(device);
1978 
1979    result = create_build_pipeline(device, leaf_cs, sizeof(struct build_primitive_constants),
1980                                   &device->meta_state.accel_struct_build.leaf_pipeline,
1981                                   &device->meta_state.accel_struct_build.leaf_p_layout);
1982    if (result != VK_SUCCESS)
1983       return result;
1984 
1985    result = create_build_pipeline(device, internal_cs, sizeof(struct build_internal_constants),
1986                                   &device->meta_state.accel_struct_build.internal_pipeline,
1987                                   &device->meta_state.accel_struct_build.internal_p_layout);
1988    if (result != VK_SUCCESS)
1989       return result;
1990 
1991    result = create_build_pipeline(device, copy_cs, sizeof(struct copy_constants),
1992                                   &device->meta_state.accel_struct_build.copy_pipeline,
1993                                   &device->meta_state.accel_struct_build.copy_p_layout);
1994 
1995    if (result != VK_SUCCESS)
1996       return result;
1997 
1998    if (get_accel_struct_build(device->physical_device,
1999                               VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ==
2000        accel_struct_build_lbvh) {
2001       nir_shader *morton_cs = build_morton_shader(device);
2002 
2003       result = create_build_pipeline(device, morton_cs, sizeof(struct morton_constants),
2004                                      &device->meta_state.accel_struct_build.morton_pipeline,
2005                                      &device->meta_state.accel_struct_build.morton_p_layout);
2006       if (result != VK_SUCCESS)
2007          return result;
2008 
2009       device->meta_state.accel_struct_build.radix_sort =
2010          radv_create_radix_sort_u64(radv_device_to_handle(device), &device->meta_state.alloc,
2011                                     radv_pipeline_cache_to_handle(&device->meta_state.cache));
2012 
2013       struct radix_sort_vk_sort_devaddr_info *radix_sort_info =
2014          &device->meta_state.accel_struct_build.radix_sort_info;
2015       radix_sort_info->ext = NULL;
2016       radix_sort_info->key_bits = 24;
2017       radix_sort_info->fill_buffer = radix_sort_fill_buffer;
2018    }
2019 
2020    return result;
2021 }
2022 
2023 struct bvh_state {
2024    uint32_t node_offset;
2025    uint32_t node_count;
2026    uint32_t scratch_offset;
2027    uint32_t buffer_1_offset;
2028    uint32_t buffer_2_offset;
2029 
2030    uint32_t instance_offset;
2031    uint32_t instance_count;
2032 };
2033 
2034 VKAPI_ATTR void VKAPI_CALL
2035 radv_CmdBuildAccelerationStructuresKHR(
2036    VkCommandBuffer commandBuffer, uint32_t infoCount,
2037    const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
2038    const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos)
2039 {
2040    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2041    struct radv_meta_saved_state saved_state;
2042 
2043    enum radv_cmd_flush_bits flush_bits =
2044       RADV_CMD_FLAG_CS_PARTIAL_FLUSH |
2045       radv_src_access_flush(cmd_buffer, VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT,
2046                             NULL) |
2047       radv_dst_access_flush(cmd_buffer, VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT,
2048                             NULL);
2049 
2050    enum accel_struct_build build_mode = get_accel_struct_build(
2051       cmd_buffer->device->physical_device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
2052    uint32_t node_id_stride = get_node_id_stride(build_mode);
2053 
2054    radv_meta_save(
2055       &saved_state, cmd_buffer,
2056       RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2057    struct bvh_state *bvh_states = calloc(infoCount, sizeof(struct bvh_state));
2058 
2059    if (build_mode != accel_struct_build_unoptimized) {
2060       for (uint32_t i = 0; i < infoCount; ++i) {
2061          if (radv_has_shader_buffer_float_minmax(cmd_buffer->device->physical_device)) {
2062             /* Clear the bvh bounds with nan. */
2063             si_cp_dma_clear_buffer(cmd_buffer, pInfos[i].scratchData.deviceAddress,
2064                                    6 * sizeof(float), 0x7FC00000);
2065          } else {
2066             /* Clear the bvh bounds with int max/min. */
2067             si_cp_dma_clear_buffer(cmd_buffer, pInfos[i].scratchData.deviceAddress,
2068                                    3 * sizeof(float), 0x7fffffff);
2069             si_cp_dma_clear_buffer(cmd_buffer,
2070                                    pInfos[i].scratchData.deviceAddress + 3 * sizeof(float),
2071                                    3 * sizeof(float), 0x80000000);
2072          }
2073       }
2074 
2075       cmd_buffer->state.flush_bits |= flush_bits;
2076    }
2077 
2078    radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
2079                         cmd_buffer->device->meta_state.accel_struct_build.leaf_pipeline);
2080 
2081    for (uint32_t i = 0; i < infoCount; ++i) {
2082       RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
2083                        pInfos[i].dstAccelerationStructure);
2084 
2085       struct build_primitive_constants prim_consts = {
2086          .node_dst_addr = radv_accel_struct_get_va(accel_struct),
2087          .scratch_addr = pInfos[i].scratchData.deviceAddress,
2088          .dst_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64) + 128,
2089          .dst_scratch_offset = 0,
2090       };
2091       bvh_states[i].node_offset = prim_consts.dst_offset;
2092       bvh_states[i].instance_offset = prim_consts.dst_offset;
2093 
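      /* Instance geometries are processed in the first pass (inst == 1) and
       * all other geometry types in the second, so instance nodes end up
       * contiguous at instance_offset as the serialization code expects.
       */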
2094       for (int inst = 1; inst >= 0; --inst) {
2095          for (unsigned j = 0; j < pInfos[i].geometryCount; ++j) {
2096             const VkAccelerationStructureGeometryKHR *geom =
2097                pInfos[i].pGeometries ? &pInfos[i].pGeometries[j] : pInfos[i].ppGeometries[j];
2098 
2099             if (!inst == (geom->geometryType == VK_GEOMETRY_TYPE_INSTANCES_KHR))
2100                continue;
2101 
2102             const VkAccelerationStructureBuildRangeInfoKHR *buildRangeInfo =
2103                &ppBuildRangeInfos[i][j];
2104 
2105             prim_consts.geometry_type = geom->geometryType;
2106             prim_consts.geometry_id = j | (geom->flags << 28);
2107             unsigned prim_size;
2108             switch (geom->geometryType) {
2109             case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
2110                prim_consts.vertex_addr =
2111                   geom->geometry.triangles.vertexData.deviceAddress +
2112                   buildRangeInfo->firstVertex * geom->geometry.triangles.vertexStride;
2113                prim_consts.index_addr = geom->geometry.triangles.indexData.deviceAddress;
2114 
2115                if (geom->geometry.triangles.indexType == VK_INDEX_TYPE_NONE_KHR)
2116                   prim_consts.vertex_addr += buildRangeInfo->primitiveOffset;
2117                else
2118                   prim_consts.index_addr += buildRangeInfo->primitiveOffset;
2119 
2120                prim_consts.transform_addr = geom->geometry.triangles.transformData.deviceAddress;
2121                if (prim_consts.transform_addr)
2122                   prim_consts.transform_addr += buildRangeInfo->transformOffset;
2123 
2124                prim_consts.vertex_stride = geom->geometry.triangles.vertexStride;
2125                prim_consts.vertex_format = geom->geometry.triangles.vertexFormat;
2126                prim_consts.index_format = geom->geometry.triangles.indexType;
2127                prim_size = 64;
2128                break;
2129             case VK_GEOMETRY_TYPE_AABBS_KHR:
2130                prim_consts.aabb_addr =
2131                   geom->geometry.aabbs.data.deviceAddress + buildRangeInfo->primitiveOffset;
2132                prim_consts.aabb_stride = geom->geometry.aabbs.stride;
2133                prim_size = 64;
2134                break;
2135             case VK_GEOMETRY_TYPE_INSTANCES_KHR:
2136                prim_consts.instance_data =
2137                   geom->geometry.instances.data.deviceAddress + buildRangeInfo->primitiveOffset;
2138                prim_consts.array_of_pointers = geom->geometry.instances.arrayOfPointers ? 1 : 0;
2139                prim_size = 128;
2140                bvh_states[i].instance_count += buildRangeInfo->primitiveCount;
2141                break;
2142             default:
2143                unreachable("Unknown geometryType");
2144             }
2145 
2146             radv_CmdPushConstants(
2147                commandBuffer, cmd_buffer->device->meta_state.accel_struct_build.leaf_p_layout,
2148                VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(prim_consts), &prim_consts);
2149             radv_unaligned_dispatch(cmd_buffer, buildRangeInfo->primitiveCount, 1, 1);
2150             prim_consts.dst_offset += prim_size * buildRangeInfo->primitiveCount;
2151             prim_consts.dst_scratch_offset += node_id_stride * buildRangeInfo->primitiveCount;
2152          }
2153       }
2154       bvh_states[i].node_offset = prim_consts.dst_offset;
2155       bvh_states[i].node_count = prim_consts.dst_scratch_offset / node_id_stride;
2156    }
2157 
2158    if (build_mode == accel_struct_build_lbvh) {
2159       cmd_buffer->state.flush_bits |= flush_bits;
2160 
2161       radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
2162                            cmd_buffer->device->meta_state.accel_struct_build.morton_pipeline);
2163 
2164       for (uint32_t i = 0; i < infoCount; ++i) {
2165          RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
2166                           pInfos[i].dstAccelerationStructure);
2167 
2168          const struct morton_constants consts = {
2169             .node_addr = radv_accel_struct_get_va(accel_struct),
2170             .scratch_addr = pInfos[i].scratchData.deviceAddress,
2171          };
2172 
2173          radv_CmdPushConstants(commandBuffer,
2174                                cmd_buffer->device->meta_state.accel_struct_build.morton_p_layout,
2175                                VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2176          radv_unaligned_dispatch(cmd_buffer, bvh_states[i].node_count, 1, 1);
2177       }
2178 
2179       cmd_buffer->state.flush_bits |= flush_bits;
2180 
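      /* Sort the 64-bit keyval entries (node id and Morton key). The sort can
       * finish in either keyval buffer, so record which buffer holds the
       * result and use the other one as the first ping-pong destination.
       */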
2181       for (uint32_t i = 0; i < infoCount; ++i) {
2182          struct radix_sort_vk_memory_requirements requirements;
2183          radix_sort_vk_get_memory_requirements(
2184             cmd_buffer->device->meta_state.accel_struct_build.radix_sort, bvh_states[i].node_count,
2185             &requirements);
2186 
2187          struct radix_sort_vk_sort_devaddr_info info =
2188             cmd_buffer->device->meta_state.accel_struct_build.radix_sort_info;
2189          info.count = bvh_states[i].node_count;
2190 
2191          VkDeviceAddress base_addr =
2192             pInfos[i].scratchData.deviceAddress + SCRATCH_TOTAL_BOUNDS_SIZE;
2193 
2194          info.keyvals_even.buffer = VK_NULL_HANDLE;
2195          info.keyvals_even.offset = 0;
2196          info.keyvals_even.devaddr = base_addr;
2197 
2198          info.keyvals_odd = base_addr + requirements.keyvals_size;
2199 
2200          info.internal.buffer = VK_NULL_HANDLE;
2201          info.internal.offset = 0;
2202          info.internal.devaddr = base_addr + requirements.keyvals_size * 2;
2203 
2204          VkDeviceAddress result_addr;
2205          radix_sort_vk_sort_devaddr(cmd_buffer->device->meta_state.accel_struct_build.radix_sort,
2206                                     &info, radv_device_to_handle(cmd_buffer->device), commandBuffer,
2207                                     &result_addr);
2208 
2209          assert(result_addr == info.keyvals_even.devaddr || result_addr == info.keyvals_odd);
2210 
2211          if (result_addr == info.keyvals_even.devaddr) {
2212             bvh_states[i].buffer_1_offset = SCRATCH_TOTAL_BOUNDS_SIZE;
2213             bvh_states[i].buffer_2_offset = SCRATCH_TOTAL_BOUNDS_SIZE + requirements.keyvals_size;
2214          } else {
2215             bvh_states[i].buffer_1_offset = SCRATCH_TOTAL_BOUNDS_SIZE + requirements.keyvals_size;
2216             bvh_states[i].buffer_2_offset = SCRATCH_TOTAL_BOUNDS_SIZE;
2217          }
2218          bvh_states[i].scratch_offset = bvh_states[i].buffer_1_offset;
2219       }
2220 
2221       cmd_buffer->state.flush_bits |= flush_bits;
2222    } else {
2223       for (uint32_t i = 0; i < infoCount; ++i) {
2224          bvh_states[i].buffer_1_offset = 0;
2225          bvh_states[i].buffer_2_offset = bvh_states[i].node_count * 4;
2226       }
2227    }
2228 
2229    radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
2230                         cmd_buffer->device->meta_state.accel_struct_build.internal_pipeline);
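   /* Build internal nodes bottom-up: each pass packs up to four nodes from the
    * previous level into one internal node, ping-ponging between the two
    * scratch buffers, until a single root remains. The final pass places the
    * root right after the header and fills in the header bounds.
    */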
2231    bool progress = true;
2232    for (unsigned iter = 0; progress; ++iter) {
2233       progress = false;
2234       for (uint32_t i = 0; i < infoCount; ++i) {
2235          RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
2236                           pInfos[i].dstAccelerationStructure);
2237 
2238          if (iter && bvh_states[i].node_count == 1)
2239             continue;
2240 
2241          if (!progress)
2242             cmd_buffer->state.flush_bits |= flush_bits;
2243 
2244          progress = true;
2245 
2246          uint32_t dst_node_count = MAX2(1, DIV_ROUND_UP(bvh_states[i].node_count, 4));
2247          bool final_iter = dst_node_count == 1;
2248 
2249          uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
2250          uint32_t buffer_1_offset = bvh_states[i].buffer_1_offset;
2251          uint32_t buffer_2_offset = bvh_states[i].buffer_2_offset;
2252          uint32_t dst_scratch_offset =
2253             (src_scratch_offset == buffer_1_offset) ? buffer_2_offset : buffer_1_offset;
2254 
2255          uint32_t dst_node_offset = bvh_states[i].node_offset;
2256          if (final_iter)
2257             dst_node_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64);
2258 
2259          const struct build_internal_constants consts = {
2260             .node_dst_addr = radv_accel_struct_get_va(accel_struct),
2261             .scratch_addr = pInfos[i].scratchData.deviceAddress,
2262             .dst_offset = dst_node_offset,
2263             .dst_scratch_offset = dst_scratch_offset,
2264             .src_scratch_offset = src_scratch_offset,
2265             .fill_header = bvh_states[i].node_count | (final_iter ? 0x80000000U : 0),
2266          };
2267 
2268          radv_CmdPushConstants(commandBuffer,
2269                                cmd_buffer->device->meta_state.accel_struct_build.internal_p_layout,
2270                                VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2271          radv_unaligned_dispatch(cmd_buffer, dst_node_count, 1, 1);
2272          if (!final_iter)
2273             bvh_states[i].node_offset += dst_node_count * 128;
2274          bvh_states[i].node_count = dst_node_count;
2275          bvh_states[i].scratch_offset = dst_scratch_offset;
2276       }
2277    }
2278    for (uint32_t i = 0; i < infoCount; ++i) {
2279       RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
2280                        pInfos[i].dstAccelerationStructure);
2281       const size_t base = offsetof(struct radv_accel_struct_header, compacted_size);
2282       struct radv_accel_struct_header header;
2283 
2284       header.instance_offset = bvh_states[i].instance_offset;
2285       header.instance_count = bvh_states[i].instance_count;
2286       header.compacted_size = bvh_states[i].node_offset;
2287 
2288       fill_accel_struct_header(&header);
2289 
2290       radv_update_buffer_cp(cmd_buffer,
2291                             radv_buffer_get_va(accel_struct->bo) + accel_struct->mem_offset + base,
2292                             (const char *)&header + base, sizeof(header) - base);
2293    }
2294    free(bvh_states);
2295    radv_meta_restore(&saved_state, cmd_buffer);
2296 }
2297 
2298 VKAPI_ATTR void VKAPI_CALL
2299 radv_CmdCopyAccelerationStructureKHR(VkCommandBuffer commandBuffer,
2300                                      const VkCopyAccelerationStructureInfoKHR *pInfo)
2301 {
2302    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2303    RADV_FROM_HANDLE(radv_acceleration_structure, src, pInfo->src);
2304    RADV_FROM_HANDLE(radv_acceleration_structure, dst, pInfo->dst);
2305    struct radv_meta_saved_state saved_state;
2306 
2307    radv_meta_save(
2308       &saved_state, cmd_buffer,
2309       RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2310 
2311    uint64_t src_addr = radv_accel_struct_get_va(src);
2312    uint64_t dst_addr = radv_accel_struct_get_va(dst);
2313 
2314    radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
2315                         cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
2316 
2317    const struct copy_constants consts = {
2318       .src_addr = src_addr,
2319       .dst_addr = dst_addr,
2320       .mode = COPY_MODE_COPY,
2321    };
2322 
2323    radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
2324                          cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
2325                          VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2326 
2327    cmd_buffer->state.flush_bits |=
2328       radv_dst_access_flush(cmd_buffer, VK_ACCESS_2_INDIRECT_COMMAND_READ_BIT, NULL);
2329 
2330    radv_indirect_dispatch(cmd_buffer, src->bo,
2331                           src_addr + offsetof(struct radv_accel_struct_header, copy_dispatch_size));
2332    radv_meta_restore(&saved_state, cmd_buffer);
2333 }
2334 
2335 VKAPI_ATTR void VKAPI_CALL
2336 radv_GetDeviceAccelerationStructureCompatibilityKHR(
2337    VkDevice _device, const VkAccelerationStructureVersionInfoKHR *pVersionInfo,
2338    VkAccelerationStructureCompatibilityKHR *pCompatibility)
2339 {
2340    RADV_FROM_HANDLE(radv_device, device, _device);
2341    uint8_t zero[VK_UUID_SIZE] = {
2342       0,
2343    };
2344    bool compat =
2345       memcmp(pVersionInfo->pVersionData, device->physical_device->driver_uuid, VK_UUID_SIZE) == 0 &&
2346       memcmp(pVersionInfo->pVersionData + VK_UUID_SIZE, zero, VK_UUID_SIZE) == 0;
2347    *pCompatibility = compat ? VK_ACCELERATION_STRUCTURE_COMPATIBILITY_COMPATIBLE_KHR
2348                             : VK_ACCELERATION_STRUCTURE_COMPATIBILITY_INCOMPATIBLE_KHR;
2349 }
2350 
2351 VKAPI_ATTR VkResult VKAPI_CALL
2352 radv_CopyMemoryToAccelerationStructureKHR(VkDevice _device,
2353                                           VkDeferredOperationKHR deferredOperation,
2354                                           const VkCopyMemoryToAccelerationStructureInfoKHR *pInfo)
2355 {
2356    RADV_FROM_HANDLE(radv_device, device, _device);
2357    RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct, pInfo->dst);
2358 
2359    char *base = device->ws->buffer_map(accel_struct->bo);
2360    if (!base)
2361       return VK_ERROR_OUT_OF_HOST_MEMORY;
2362 
2363    base += accel_struct->mem_offset;
2364    const struct radv_accel_struct_header *header = (const struct radv_accel_struct_header *)base;
2365 
2366    const char *src = pInfo->src.hostAddress;
2367    struct radv_accel_struct_serialization_header *src_header = (void *)src;
2368    src += sizeof(*src_header) + sizeof(uint64_t) * src_header->instance_count;
2369 
2370    memcpy(base, src, src_header->compacted_size);
2371 
2372    for (unsigned i = 0; i < src_header->instance_count; ++i) {
2373       uint64_t *p = (uint64_t *)(base + i * 128 + header->instance_offset);
2374       *p = (*p & 63) | src_header->instances[i];
2375    }
2376 
2377    device->ws->buffer_unmap(accel_struct->bo);
2378    return VK_SUCCESS;
2379 }
2380 
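/* Host-side serialization, the inverse of the function above: fill in the
 * serialization header (driver UUID, a zeroed compatibility UUID and the sizes
 * from the acceleration structure header), record each instance node's BVH
 * address with its low 6 non-address bits cleared, then append the compacted
 * BVH data.
 */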
VKAPI_ATTR VkResult VKAPI_CALL
radv_CopyAccelerationStructureToMemoryKHR(VkDevice _device,
                                          VkDeferredOperationKHR deferredOperation,
                                          const VkCopyAccelerationStructureToMemoryInfoKHR *pInfo)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct, pInfo->src);

   const char *base = device->ws->buffer_map(accel_struct->bo);
   if (!base)
      return VK_ERROR_OUT_OF_HOST_MEMORY;

   base += accel_struct->mem_offset;
   const struct radv_accel_struct_header *header = (const struct radv_accel_struct_header *)base;

   char *dst = pInfo->dst.hostAddress;
   struct radv_accel_struct_serialization_header *dst_header = (void *)dst;
   dst += sizeof(*dst_header) + sizeof(uint64_t) * header->instance_count;

   memcpy(dst_header->driver_uuid, device->physical_device->driver_uuid, VK_UUID_SIZE);
   memset(dst_header->accel_struct_compat, 0, VK_UUID_SIZE);

   dst_header->serialization_size = header->serialization_size;
   dst_header->compacted_size = header->compacted_size;
   dst_header->instance_count = header->instance_count;

   memcpy(dst, base, header->compacted_size);

   for (unsigned i = 0; i < header->instance_count; ++i) {
      dst_header->instances[i] =
         *(const uint64_t *)(base + i * 128 + header->instance_offset) & ~63ull;
   }

   device->ws->buffer_unmap(accel_struct->bo);
   return VK_SUCCESS;
}

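/* Device-side deserialization: the same meta copy_pipeline as for plain copies,
 * but in COPY_MODE_DESERIALIZE and with a fixed 512-workgroup direct dispatch
 * instead of the indirect dispatch used for the other device-side paths.
 */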
VKAPI_ATTR void VKAPI_CALL
radv_CmdCopyMemoryToAccelerationStructureKHR(
   VkCommandBuffer commandBuffer, const VkCopyMemoryToAccelerationStructureInfoKHR *pInfo)
{
   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
   RADV_FROM_HANDLE(radv_acceleration_structure, dst, pInfo->dst);
   struct radv_meta_saved_state saved_state;

   radv_meta_save(
      &saved_state, cmd_buffer,
      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);

   uint64_t dst_addr = radv_accel_struct_get_va(dst);

   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
                        cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);

   const struct copy_constants consts = {
      .src_addr = pInfo->src.deviceAddress,
      .dst_addr = dst_addr,
      .mode = COPY_MODE_DESERIALIZE,
   };

   radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
                         cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
                         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);

   radv_CmdDispatch(commandBuffer, 512, 1, 1);
   radv_meta_restore(&saved_state, cmd_buffer);
}

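/* Device-side serialization: dispatch the copy shader in COPY_MODE_SERIALIZE
 * (sized indirectly from copy_dispatch_size in the source header), then write
 * the first part of the serialization header, i.e. the driver UUID followed by
 * a zeroed compatibility UUID, directly with the CP via radv_update_buffer_cp.
 */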
VKAPI_ATTR void VKAPI_CALL
radv_CmdCopyAccelerationStructureToMemoryKHR(
   VkCommandBuffer commandBuffer, const VkCopyAccelerationStructureToMemoryInfoKHR *pInfo)
{
   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
   RADV_FROM_HANDLE(radv_acceleration_structure, src, pInfo->src);
   struct radv_meta_saved_state saved_state;

   radv_meta_save(
      &saved_state, cmd_buffer,
      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);

   uint64_t src_addr = radv_accel_struct_get_va(src);

   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
                        cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);

   const struct copy_constants consts = {
      .src_addr = src_addr,
      .dst_addr = pInfo->dst.deviceAddress,
      .mode = COPY_MODE_SERIALIZE,
   };

   radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
                         cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
                         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);

   cmd_buffer->state.flush_bits |=
      radv_dst_access_flush(cmd_buffer, VK_ACCESS_2_INDIRECT_COMMAND_READ_BIT, NULL);

   radv_indirect_dispatch(cmd_buffer, src->bo,
                          src_addr + offsetof(struct radv_accel_struct_header, copy_dispatch_size));
   radv_meta_restore(&saved_state, cmd_buffer);

   /* Set the header of the serialized data. */
   uint8_t header_data[2 * VK_UUID_SIZE] = {0};
   memcpy(header_data, cmd_buffer->device->physical_device->driver_uuid, VK_UUID_SIZE);

   radv_update_buffer_cp(cmd_buffer, pInfo->dst.deviceAddress, header_data, sizeof(header_data));
}

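/* Indirect acceleration structure builds are not implemented; presumably the
 * accelerationStructureIndirectBuild feature is not advertised, so this entry
 * point should never be reached.
 */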
VKAPI_ATTR void VKAPI_CALL
radv_CmdBuildAccelerationStructuresIndirectKHR(
   VkCommandBuffer commandBuffer, uint32_t infoCount,
   const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
   const VkDeviceAddress *pIndirectDeviceAddresses, const uint32_t *pIndirectStrides,
   const uint32_t *const *ppMaxPrimitiveCounts)
{
   unreachable("Unimplemented");
}