• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*-------------------------------------------------------------------------
2  * Vulkan CTS Framework
3  * --------------------
4  *
5  * Copyright (c) 2020 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Utilities for creating commonly used Vulkan objects
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vkRayTracingUtil.hpp"
25 
26 #include "vkRefUtil.hpp"
27 #include "vkQueryUtil.hpp"
28 #include "vkObjUtil.hpp"
29 #include "vkBarrierUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 
32 #include "deStringUtil.hpp"
33 #include "deSTLUtil.hpp"
34 
35 #include <vector>
36 #include <string>
37 #include <thread>
38 #include <limits>
39 #include <type_traits>
40 #include <map>
41 
42 #include "SPIRV/spirv.hpp"
43 
44 namespace vk
45 {
46 
47 #ifndef CTS_USES_VULKANSC
48 
49 static const uint32_t WATCHDOG_INTERVAL = 16384; // Touch watchDog every N iterations.
50 
51 struct DeferredThreadParams
52 {
53     const DeviceInterface &vk;
54     VkDevice device;
55     VkDeferredOperationKHR deferredOperation;
56     VkResult result;
57 };
58 
getFormatSimpleName(vk::VkFormat format)59 std::string getFormatSimpleName(vk::VkFormat format)
60 {
61     constexpr size_t kPrefixLen = 10; // strlen("VK_FORMAT_")
62     return de::toLower(de::toString(format).substr(kPrefixLen));
63 }
64 
pointInTriangle2D(const tcu::Vec3 & p,const tcu::Vec3 & p0,const tcu::Vec3 & p1,const tcu::Vec3 & p2)65 bool pointInTriangle2D(const tcu::Vec3 &p, const tcu::Vec3 &p0, const tcu::Vec3 &p1, const tcu::Vec3 &p2)
66 {
67     float s = p0.y() * p2.x() - p0.x() * p2.y() + (p2.y() - p0.y()) * p.x() + (p0.x() - p2.x()) * p.y();
68     float t = p0.x() * p1.y() - p0.y() * p1.x() + (p0.y() - p1.y()) * p.x() + (p1.x() - p0.x()) * p.y();
69 
70     if ((s < 0) != (t < 0))
71         return false;
72 
73     float a = -p1.y() * p2.x() + p0.y() * (p2.x() - p1.x()) + p0.x() * (p1.y() - p2.y()) + p1.x() * p2.y();
74 
75     return a < 0 ? (s <= 0 && s + t >= a) : (s >= 0 && s + t <= a);
76 }
77 
78 // Returns true if VK_FORMAT_FEATURE_ACCELERATION_STRUCTURE_VERTEX_BUFFER_BIT_KHR needs to be supported for the given format.
isMandatoryAccelerationStructureVertexBufferFormat(vk::VkFormat format)79 static bool isMandatoryAccelerationStructureVertexBufferFormat(vk::VkFormat format)
80 {
81     bool mandatory = false;
82 
83     switch (format)
84     {
85     case VK_FORMAT_R32G32_SFLOAT:
86     case VK_FORMAT_R32G32B32_SFLOAT:
87     case VK_FORMAT_R16G16_SFLOAT:
88     case VK_FORMAT_R16G16B16A16_SFLOAT:
89     case VK_FORMAT_R16G16_SNORM:
90     case VK_FORMAT_R16G16B16A16_SNORM:
91         mandatory = true;
92         break;
93     default:
94         break;
95     }
96 
97     return mandatory;
98 }
99 
checkAccelerationStructureVertexBufferFormat(const vk::InstanceInterface & vki,vk::VkPhysicalDevice physicalDevice,vk::VkFormat format)100 void checkAccelerationStructureVertexBufferFormat(const vk::InstanceInterface &vki, vk::VkPhysicalDevice physicalDevice,
101                                                   vk::VkFormat format)
102 {
103     const vk::VkFormatProperties formatProperties = getPhysicalDeviceFormatProperties(vki, physicalDevice, format);
104 
105     if ((formatProperties.bufferFeatures & vk::VK_FORMAT_FEATURE_ACCELERATION_STRUCTURE_VERTEX_BUFFER_BIT_KHR) == 0u)
106     {
107         const std::string errorMsg = "Format not supported for acceleration structure vertex buffers";
108         if (isMandatoryAccelerationStructureVertexBufferFormat(format))
109             TCU_FAIL(errorMsg);
110         TCU_THROW(NotSupportedError, errorMsg);
111     }
112 }
113 
// GLSL source of a generic ray generation shader. It traces one ray per launch cell, starting at the
// cell centre in normalized XY coordinates at z = 0 and heading towards -Z. The ray is traced against
// the acceleration structure bound at set 0, binding 1, with tmin = 0 and tmax = 9.
std::string getCommonRayGenerationShader(void)
{
    static const char *const kRayGenSource =
        "#version 460 core\n"
        "#extension GL_EXT_ray_tracing : require\n"
        "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
        "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
        "\n"
        "void main()\n"
        "{\n"
        "  uint  rayFlags = 0;\n"
        "  uint  cullMask = 0xFF;\n"
        "  float tmin     = 0.0;\n"
        "  float tmax     = 9.0;\n"
        "  vec3  origin   = vec3((float(gl_LaunchIDEXT.x) + 0.5f) / float(gl_LaunchSizeEXT.x), "
        "(float(gl_LaunchIDEXT.y) + 0.5f) / float(gl_LaunchSizeEXT.y), 0.0);\n"
        "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
        "  traceRayEXT(topLevelAS, rayFlags, cullMask, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
        "}\n";

    return std::string(kRayGenSource);
}
133 
RaytracedGeometryBase(VkGeometryTypeKHR geometryType,VkFormat vertexFormat,VkIndexType indexType)134 RaytracedGeometryBase::RaytracedGeometryBase(VkGeometryTypeKHR geometryType, VkFormat vertexFormat,
135                                              VkIndexType indexType)
136     : m_geometryType(geometryType)
137     , m_vertexFormat(vertexFormat)
138     , m_indexType(indexType)
139     , m_geometryFlags((VkGeometryFlagsKHR)0u)
140     , m_hasOpacityMicromap(false)
141 {
142     if (m_geometryType == VK_GEOMETRY_TYPE_AABBS_KHR)
143         DE_ASSERT(m_vertexFormat == VK_FORMAT_R32G32B32_SFLOAT);
144 }
145 
~RaytracedGeometryBase()146 RaytracedGeometryBase::~RaytracedGeometryBase()
147 {
148 }
149 
150 struct GeometryBuilderParams
151 {
152     VkGeometryTypeKHR geometryType;
153     bool usePadding;
154 };
155 
156 template <typename V, typename I>
buildRaytracedGeometry(const GeometryBuilderParams & params)157 RaytracedGeometryBase *buildRaytracedGeometry(const GeometryBuilderParams &params)
158 {
159     return new RaytracedGeometry<V, I>(params.geometryType, (params.usePadding ? 1u : 0u));
160 }
161 
makeRaytracedGeometry(VkGeometryTypeKHR geometryType,VkFormat vertexFormat,VkIndexType indexType,bool padVertices)162 de::SharedPtr<RaytracedGeometryBase> makeRaytracedGeometry(VkGeometryTypeKHR geometryType, VkFormat vertexFormat,
163                                                            VkIndexType indexType, bool padVertices)
164 {
165     const GeometryBuilderParams builderParams{geometryType, padVertices};
166 
167     switch (vertexFormat)
168     {
169     case VK_FORMAT_R32G32_SFLOAT:
170         switch (indexType)
171         {
172         case VK_INDEX_TYPE_UINT16:
173             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, uint16_t>(builderParams));
174         case VK_INDEX_TYPE_UINT32:
175             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, uint32_t>(builderParams));
176         case VK_INDEX_TYPE_NONE_KHR:
177             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, EmptyIndex>(builderParams));
178         default:
179             TCU_THROW(InternalError, "Wrong index type");
180         }
181     case VK_FORMAT_R32G32B32_SFLOAT:
182         switch (indexType)
183         {
184         case VK_INDEX_TYPE_UINT16:
185             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, uint16_t>(builderParams));
186         case VK_INDEX_TYPE_UINT32:
187             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, uint32_t>(builderParams));
188         case VK_INDEX_TYPE_NONE_KHR:
189             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, EmptyIndex>(builderParams));
190         default:
191             TCU_THROW(InternalError, "Wrong index type");
192         }
193     case VK_FORMAT_R32G32B32A32_SFLOAT:
194         switch (indexType)
195         {
196         case VK_INDEX_TYPE_UINT16:
197             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, uint16_t>(builderParams));
198         case VK_INDEX_TYPE_UINT32:
199             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, uint32_t>(builderParams));
200         case VK_INDEX_TYPE_NONE_KHR:
201             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, EmptyIndex>(builderParams));
202         default:
203             TCU_THROW(InternalError, "Wrong index type");
204         }
205     case VK_FORMAT_R16G16_SFLOAT:
206         switch (indexType)
207         {
208         case VK_INDEX_TYPE_UINT16:
209             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, uint16_t>(builderParams));
210         case VK_INDEX_TYPE_UINT32:
211             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, uint32_t>(builderParams));
212         case VK_INDEX_TYPE_NONE_KHR:
213             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, EmptyIndex>(builderParams));
214         default:
215             TCU_THROW(InternalError, "Wrong index type");
216         }
217     case VK_FORMAT_R16G16B16_SFLOAT:
218         switch (indexType)
219         {
220         case VK_INDEX_TYPE_UINT16:
221             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, uint16_t>(builderParams));
222         case VK_INDEX_TYPE_UINT32:
223             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, uint32_t>(builderParams));
224         case VK_INDEX_TYPE_NONE_KHR:
225             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, EmptyIndex>(builderParams));
226         default:
227             TCU_THROW(InternalError, "Wrong index type");
228         }
229     case VK_FORMAT_R16G16B16A16_SFLOAT:
230         switch (indexType)
231         {
232         case VK_INDEX_TYPE_UINT16:
233             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, uint16_t>(builderParams));
234         case VK_INDEX_TYPE_UINT32:
235             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, uint32_t>(builderParams));
236         case VK_INDEX_TYPE_NONE_KHR:
237             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, EmptyIndex>(builderParams));
238         default:
239             TCU_THROW(InternalError, "Wrong index type");
240         }
241     case VK_FORMAT_R16G16_SNORM:
242         switch (indexType)
243         {
244         case VK_INDEX_TYPE_UINT16:
245             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16SNorm, uint16_t>(builderParams));
246         case VK_INDEX_TYPE_UINT32:
247             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16SNorm, uint32_t>(builderParams));
248         case VK_INDEX_TYPE_NONE_KHR:
249             return de::SharedPtr<RaytracedGeometryBase>(
250                 buildRaytracedGeometry<Vec2_16SNorm, EmptyIndex>(builderParams));
251         default:
252             TCU_THROW(InternalError, "Wrong index type");
253         }
254     case VK_FORMAT_R16G16B16_SNORM:
255         switch (indexType)
256         {
257         case VK_INDEX_TYPE_UINT16:
258             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16SNorm, uint16_t>(builderParams));
259         case VK_INDEX_TYPE_UINT32:
260             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16SNorm, uint32_t>(builderParams));
261         case VK_INDEX_TYPE_NONE_KHR:
262             return de::SharedPtr<RaytracedGeometryBase>(
263                 buildRaytracedGeometry<Vec3_16SNorm, EmptyIndex>(builderParams));
264         default:
265             TCU_THROW(InternalError, "Wrong index type");
266         }
267     case VK_FORMAT_R16G16B16A16_SNORM:
268         switch (indexType)
269         {
270         case VK_INDEX_TYPE_UINT16:
271             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16SNorm, uint16_t>(builderParams));
272         case VK_INDEX_TYPE_UINT32:
273             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16SNorm, uint32_t>(builderParams));
274         case VK_INDEX_TYPE_NONE_KHR:
275             return de::SharedPtr<RaytracedGeometryBase>(
276                 buildRaytracedGeometry<Vec4_16SNorm, EmptyIndex>(builderParams));
277         default:
278             TCU_THROW(InternalError, "Wrong index type");
279         }
280     case VK_FORMAT_R64G64_SFLOAT:
281         switch (indexType)
282         {
283         case VK_INDEX_TYPE_UINT16:
284             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, uint16_t>(builderParams));
285         case VK_INDEX_TYPE_UINT32:
286             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, uint32_t>(builderParams));
287         case VK_INDEX_TYPE_NONE_KHR:
288             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, EmptyIndex>(builderParams));
289         default:
290             TCU_THROW(InternalError, "Wrong index type");
291         }
292     case VK_FORMAT_R64G64B64_SFLOAT:
293         switch (indexType)
294         {
295         case VK_INDEX_TYPE_UINT16:
296             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, uint16_t>(builderParams));
297         case VK_INDEX_TYPE_UINT32:
298             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, uint32_t>(builderParams));
299         case VK_INDEX_TYPE_NONE_KHR:
300             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, EmptyIndex>(builderParams));
301         default:
302             TCU_THROW(InternalError, "Wrong index type");
303         }
304     case VK_FORMAT_R64G64B64A64_SFLOAT:
305         switch (indexType)
306         {
307         case VK_INDEX_TYPE_UINT16:
308             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, uint16_t>(builderParams));
309         case VK_INDEX_TYPE_UINT32:
310             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, uint32_t>(builderParams));
311         case VK_INDEX_TYPE_NONE_KHR:
312             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, EmptyIndex>(builderParams));
313         default:
314             TCU_THROW(InternalError, "Wrong index type");
315         }
316     case VK_FORMAT_R8G8_SNORM:
317         switch (indexType)
318         {
319         case VK_INDEX_TYPE_UINT16:
320             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, uint16_t>(builderParams));
321         case VK_INDEX_TYPE_UINT32:
322             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, uint32_t>(builderParams));
323         case VK_INDEX_TYPE_NONE_KHR:
324             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, EmptyIndex>(builderParams));
325         default:
326             TCU_THROW(InternalError, "Wrong index type");
327         }
328     case VK_FORMAT_R8G8B8_SNORM:
329         switch (indexType)
330         {
331         case VK_INDEX_TYPE_UINT16:
332             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, uint16_t>(builderParams));
333         case VK_INDEX_TYPE_UINT32:
334             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, uint32_t>(builderParams));
335         case VK_INDEX_TYPE_NONE_KHR:
336             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, EmptyIndex>(builderParams));
337         default:
338             TCU_THROW(InternalError, "Wrong index type");
339         }
340     case VK_FORMAT_R8G8B8A8_SNORM:
341         switch (indexType)
342         {
343         case VK_INDEX_TYPE_UINT16:
344             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, uint16_t>(builderParams));
345         case VK_INDEX_TYPE_UINT32:
346             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, uint32_t>(builderParams));
347         case VK_INDEX_TYPE_NONE_KHR:
348             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, EmptyIndex>(builderParams));
349         default:
350             TCU_THROW(InternalError, "Wrong index type");
351         }
352     default:
353         TCU_THROW(InternalError, "Wrong vertex format");
354     }
355 }
356 
getBufferDeviceAddress(const DeviceInterface & vk,const VkDevice device,const VkBuffer buffer,VkDeviceSize offset)357 VkDeviceAddress getBufferDeviceAddress(const DeviceInterface &vk, const VkDevice device, const VkBuffer buffer,
358                                        VkDeviceSize offset)
359 {
360 
361     if (buffer == DE_NULL)
362         return 0;
363 
364     VkBufferDeviceAddressInfo deviceAddressInfo{
365         VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType    sType
366         DE_NULL,                                      // const void*        pNext
367         buffer                                        // VkBuffer           buffer;
368     };
369     return vk.getBufferDeviceAddress(device, &deviceAddressInfo) + offset;
370 }
371 
makeQueryPool(const DeviceInterface & vk,const VkDevice device,const VkQueryType queryType,uint32_t queryCount)372 static inline Move<VkQueryPool> makeQueryPool(const DeviceInterface &vk, const VkDevice device,
373                                               const VkQueryType queryType, uint32_t queryCount)
374 {
375     const VkQueryPoolCreateInfo queryPoolCreateInfo = {
376         VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO, // sType
377         DE_NULL,                                  // pNext
378         (VkQueryPoolCreateFlags)0,                // flags
379         queryType,                                // queryType
380         queryCount,                               // queryCount
381         0u,                                       // pipelineStatistics
382     };
383     return createQueryPool(vk, device, &queryPoolCreateInfo);
384 }
385 
makeVkAccelerationStructureGeometryDataKHR(const VkAccelerationStructureGeometryTrianglesDataKHR & triangles)386 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureGeometryDataKHR(
387     const VkAccelerationStructureGeometryTrianglesDataKHR &triangles)
388 {
389     VkAccelerationStructureGeometryDataKHR result;
390 
391     deMemset(&result, 0, sizeof(result));
392 
393     result.triangles = triangles;
394 
395     return result;
396 }
397 
makeVkAccelerationStructureGeometryDataKHR(const VkAccelerationStructureGeometryAabbsDataKHR & aabbs)398 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureGeometryDataKHR(
399     const VkAccelerationStructureGeometryAabbsDataKHR &aabbs)
400 {
401     VkAccelerationStructureGeometryDataKHR result;
402 
403     deMemset(&result, 0, sizeof(result));
404 
405     result.aabbs = aabbs;
406 
407     return result;
408 }
409 
makeVkAccelerationStructureInstancesDataKHR(const VkAccelerationStructureGeometryInstancesDataKHR & instances)410 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureInstancesDataKHR(
411     const VkAccelerationStructureGeometryInstancesDataKHR &instances)
412 {
413     VkAccelerationStructureGeometryDataKHR result;
414 
415     deMemset(&result, 0, sizeof(result));
416 
417     result.instances = instances;
418 
419     return result;
420 }
421 
makeVkAccelerationStructureInstanceKHR(const VkTransformMatrixKHR & transform,uint32_t instanceCustomIndex,uint32_t mask,uint32_t instanceShaderBindingTableRecordOffset,VkGeometryInstanceFlagsKHR flags,uint64_t accelerationStructureReference)422 static inline VkAccelerationStructureInstanceKHR makeVkAccelerationStructureInstanceKHR(
423     const VkTransformMatrixKHR &transform, uint32_t instanceCustomIndex, uint32_t mask,
424     uint32_t instanceShaderBindingTableRecordOffset, VkGeometryInstanceFlagsKHR flags,
425     uint64_t accelerationStructureReference)
426 {
427     VkAccelerationStructureInstanceKHR instance     = {transform, 0, 0, 0, 0, accelerationStructureReference};
428     instance.instanceCustomIndex                    = instanceCustomIndex & 0xFFFFFF;
429     instance.mask                                   = mask & 0xFF;
430     instance.instanceShaderBindingTableRecordOffset = instanceShaderBindingTableRecordOffset & 0xFFFFFF;
431     instance.flags                                  = flags & 0xFF;
432     return instance;
433 }
434 
getRayTracingShaderGroupHandlesKHR(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t firstGroup,const uint32_t groupCount,const uintptr_t dataSize,void * pData)435 VkResult getRayTracingShaderGroupHandlesKHR(const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline,
436                                             const uint32_t firstGroup, const uint32_t groupCount,
437                                             const uintptr_t dataSize, void *pData)
438 {
439     return vk.getRayTracingShaderGroupHandlesKHR(device, pipeline, firstGroup, groupCount, dataSize, pData);
440 }
441 
getRayTracingShaderGroupHandles(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t firstGroup,const uint32_t groupCount,const uintptr_t dataSize,void * pData)442 VkResult getRayTracingShaderGroupHandles(const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline,
443                                          const uint32_t firstGroup, const uint32_t groupCount, const uintptr_t dataSize,
444                                          void *pData)
445 {
446     return getRayTracingShaderGroupHandlesKHR(vk, device, pipeline, firstGroup, groupCount, dataSize, pData);
447 }
448 
getRayTracingCaptureReplayShaderGroupHandles(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t firstGroup,const uint32_t groupCount,const uintptr_t dataSize,void * pData)449 VkResult getRayTracingCaptureReplayShaderGroupHandles(const DeviceInterface &vk, const VkDevice device,
450                                                       const VkPipeline pipeline, const uint32_t firstGroup,
451                                                       const uint32_t groupCount, const uintptr_t dataSize, void *pData)
452 {
453     return vk.getRayTracingCaptureReplayShaderGroupHandlesKHR(device, pipeline, firstGroup, groupCount, dataSize,
454                                                               pData);
455 }
456 
finishDeferredOperation(const DeviceInterface & vk,VkDevice device,VkDeferredOperationKHR deferredOperation)457 VkResult finishDeferredOperation(const DeviceInterface &vk, VkDevice device, VkDeferredOperationKHR deferredOperation)
458 {
459     VkResult result = vk.deferredOperationJoinKHR(device, deferredOperation);
460 
461     while (result == VK_THREAD_IDLE_KHR)
462     {
463         std::this_thread::yield();
464         result = vk.deferredOperationJoinKHR(device, deferredOperation);
465     }
466 
467     switch (result)
468     {
469     case VK_SUCCESS:
470     {
471         // Deferred operation has finished. Query its result
472         result = vk.getDeferredOperationResultKHR(device, deferredOperation);
473 
474         break;
475     }
476 
477     case VK_THREAD_DONE_KHR:
478     {
479         // Deferred operation is being wrapped up by another thread
480         // wait for that thread to finish
481         do
482         {
483             std::this_thread::yield();
484             result = vk.getDeferredOperationResultKHR(device, deferredOperation);
485         } while (result == VK_NOT_READY);
486 
487         break;
488     }
489 
490     default:
491     {
492         DE_ASSERT(false);
493 
494         break;
495     }
496     }
497 
498     return result;
499 }
500 
finishDeferredOperationThreaded(DeferredThreadParams * deferredThreadParams)501 void finishDeferredOperationThreaded(DeferredThreadParams *deferredThreadParams)
502 {
503     deferredThreadParams->result = finishDeferredOperation(deferredThreadParams->vk, deferredThreadParams->device,
504                                                            deferredThreadParams->deferredOperation);
505 }
506 
finishDeferredOperation(const DeviceInterface & vk,VkDevice device,VkDeferredOperationKHR deferredOperation,const uint32_t workerThreadCount,const bool operationNotDeferred)507 void finishDeferredOperation(const DeviceInterface &vk, VkDevice device, VkDeferredOperationKHR deferredOperation,
508                              const uint32_t workerThreadCount, const bool operationNotDeferred)
509 {
510 
511     if (operationNotDeferred)
512     {
513         // when the operation deferral returns VK_OPERATION_NOT_DEFERRED_KHR,
514         // the deferred operation should act as if no command was deferred
515         VK_CHECK(vk.getDeferredOperationResultKHR(device, deferredOperation));
516 
517         // there is not need to join any threads to the deferred operation,
518         // so below can be skipped.
519         return;
520     }
521 
522     if (workerThreadCount == 0)
523     {
524         VK_CHECK(finishDeferredOperation(vk, device, deferredOperation));
525     }
526     else
527     {
528         const uint32_t maxThreadCountSupported =
529             deMinu32(256u, vk.getDeferredOperationMaxConcurrencyKHR(device, deferredOperation));
530         const uint32_t requestedThreadCount = workerThreadCount;
531         const uint32_t testThreadCount      = requestedThreadCount == std::numeric_limits<uint32_t>::max() ?
532                                                   maxThreadCountSupported :
533                                                   requestedThreadCount;
534 
535         if (maxThreadCountSupported == 0)
536             TCU_FAIL("vkGetDeferredOperationMaxConcurrencyKHR must not return 0");
537 
538         const DeferredThreadParams deferredThreadParams = {
539             vk,                 //  const DeviceInterface& vk;
540             device,             //  VkDevice device;
541             deferredOperation,  //  VkDeferredOperationKHR deferredOperation;
542             VK_RESULT_MAX_ENUM, //  VResult result;
543         };
544         std::vector<DeferredThreadParams> threadParams(testThreadCount, deferredThreadParams);
545         std::vector<de::MovePtr<std::thread>> threads(testThreadCount);
546         bool executionResult = false;
547 
548         DE_ASSERT(threads.size() > 0 && threads.size() == testThreadCount);
549 
550         for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
551             threads[threadNdx] =
552                 de::MovePtr<std::thread>(new std::thread(finishDeferredOperationThreaded, &threadParams[threadNdx]));
553 
554         for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
555             threads[threadNdx]->join();
556 
557         for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
558             if (threadParams[threadNdx].result == VK_SUCCESS)
559                 executionResult = true;
560 
561         if (!executionResult)
562             TCU_FAIL("Neither reported VK_SUCCESS");
563     }
564 }
565 
SerialStorage(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkAccelerationStructureBuildTypeKHR buildType,const VkDeviceSize storageSize)566 SerialStorage::SerialStorage(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
567                              const VkAccelerationStructureBuildTypeKHR buildType, const VkDeviceSize storageSize)
568     : m_buildType(buildType)
569     , m_storageSize(storageSize)
570     , m_serialInfo()
571 {
572     const VkBufferCreateInfo bufferCreateInfo =
573         makeBufferCreateInfo(storageSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
574                                               VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
575     try
576     {
577         m_buffer = de::MovePtr<BufferWithMemory>(
578             new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
579                                  MemoryRequirement::Cached | MemoryRequirement::HostVisible |
580                                      MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
581     }
582     catch (const tcu::NotSupportedError &)
583     {
584         // retry without Cached flag
585         m_buffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
586             vk, device, allocator, bufferCreateInfo,
587             MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
588     }
589 }
590 
SerialStorage(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkAccelerationStructureBuildTypeKHR buildType,const SerialInfo & serialInfo)591 SerialStorage::SerialStorage(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
592                              const VkAccelerationStructureBuildTypeKHR buildType, const SerialInfo &serialInfo)
593     : m_buildType(buildType)
594     , m_storageSize(serialInfo.sizes()[0]) // raise assertion if serialInfo is empty
595     , m_serialInfo(serialInfo)
596 {
597     DE_ASSERT(serialInfo.sizes().size() >= 2u);
598 
599     // create buffer for top-level acceleration structure
600     {
601         const VkBufferCreateInfo bufferCreateInfo =
602             makeBufferCreateInfo(m_storageSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
603                                                     VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
604         m_buffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
605             vk, device, allocator, bufferCreateInfo,
606             MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
607     }
608 
609     // create buffers for bottom-level acceleration structures
610     {
611         std::vector<uint64_t> addrs;
612 
613         for (std::size_t i = 1; i < serialInfo.addresses().size(); ++i)
614         {
615             const uint64_t &lookAddr = serialInfo.addresses()[i];
616             auto end                 = addrs.end();
617             auto match = std::find_if(addrs.begin(), end, [&](const uint64_t &item) { return item == lookAddr; });
618             if (match == end)
619             {
620                 addrs.emplace_back(lookAddr);
621                 m_bottoms.emplace_back(de::SharedPtr<SerialStorage>(
622                     new SerialStorage(vk, device, allocator, buildType, serialInfo.sizes()[i])));
623             }
624         }
625     }
626 }
627 
getAddress(const DeviceInterface & vk,const VkDevice device,const VkAccelerationStructureBuildTypeKHR buildType)628 VkDeviceOrHostAddressKHR SerialStorage::getAddress(const DeviceInterface &vk, const VkDevice device,
629                                                    const VkAccelerationStructureBuildTypeKHR buildType)
630 {
631     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
632         return makeDeviceOrHostAddressKHR(vk, device, m_buffer->get(), 0);
633     else
634         return makeDeviceOrHostAddressKHR(m_buffer->getAllocation().getHostPtr());
635 }
636 
getASHeader()637 SerialStorage::AccelerationStructureHeader *SerialStorage::getASHeader()
638 {
639     return reinterpret_cast<AccelerationStructureHeader *>(getHostAddress().hostAddress);
640 }
641 
hasDeepFormat() const642 bool SerialStorage::hasDeepFormat() const
643 {
644     return (m_serialInfo.sizes().size() >= 2u);
645 }
646 
getBottomStorage(uint32_t index) const647 de::SharedPtr<SerialStorage> SerialStorage::getBottomStorage(uint32_t index) const
648 {
649     return m_bottoms[index];
650 }
651 
getHostAddress(VkDeviceSize offset)652 VkDeviceOrHostAddressKHR SerialStorage::getHostAddress(VkDeviceSize offset)
653 {
654     DE_ASSERT(offset < m_storageSize);
655     return makeDeviceOrHostAddressKHR(static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr()) + offset);
656 }
657 
getHostAddressConst(VkDeviceSize offset)658 VkDeviceOrHostAddressConstKHR SerialStorage::getHostAddressConst(VkDeviceSize offset)
659 {
660     return makeDeviceOrHostAddressConstKHR(static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr()) + offset);
661 }
662 
getAddressConst(const DeviceInterface & vk,const VkDevice device,const VkAccelerationStructureBuildTypeKHR buildType)663 VkDeviceOrHostAddressConstKHR SerialStorage::getAddressConst(const DeviceInterface &vk, const VkDevice device,
664                                                              const VkAccelerationStructureBuildTypeKHR buildType)
665 {
666     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
667         return makeDeviceOrHostAddressConstKHR(vk, device, m_buffer->get(), 0);
668     else
669         return getHostAddressConst();
670 }
671 
getStorageSize() const672 inline VkDeviceSize SerialStorage::getStorageSize() const
673 {
674     return m_storageSize;
675 }
676 
getSerialInfo() const677 inline const SerialInfo &SerialStorage::getSerialInfo() const
678 {
679     return m_serialInfo;
680 }
681 
getDeserializedSize()682 uint64_t SerialStorage::getDeserializedSize()
683 {
684     uint64_t result         = 0;
685     const uint8_t *startPtr = static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr());
686 
687     DE_ASSERT(sizeof(result) == DESERIALIZED_SIZE_SIZE);
688 
689     deMemcpy(&result, startPtr + DESERIALIZED_SIZE_OFFSET, sizeof(result));
690 
691     return result;
692 }
693 
~BottomLevelAccelerationStructure()694 BottomLevelAccelerationStructure::~BottomLevelAccelerationStructure()
695 {
696 }
697 
BottomLevelAccelerationStructure()698 BottomLevelAccelerationStructure::BottomLevelAccelerationStructure()
699     : m_structureSize(0u)
700     , m_updateScratchSize(0u)
701     , m_buildScratchSize(0u)
702 {
703 }
704 
setGeometryData(const std::vector<tcu::Vec3> & geometryData,const bool triangles,const VkGeometryFlagsKHR geometryFlags)705 void BottomLevelAccelerationStructure::setGeometryData(const std::vector<tcu::Vec3> &geometryData, const bool triangles,
706                                                        const VkGeometryFlagsKHR geometryFlags)
707 {
708     if (triangles)
709         DE_ASSERT((geometryData.size() % 3) == 0);
710     else
711         DE_ASSERT((geometryData.size() % 2) == 0);
712 
713     setGeometryCount(1u);
714 
715     addGeometry(geometryData, triangles, geometryFlags);
716 }
717 
setDefaultGeometryData(const VkShaderStageFlagBits testStage,const VkGeometryFlagsKHR geometryFlags)718 void BottomLevelAccelerationStructure::setDefaultGeometryData(const VkShaderStageFlagBits testStage,
719                                                               const VkGeometryFlagsKHR geometryFlags)
720 {
721     bool trianglesData = false;
722     float z            = 0.0f;
723     std::vector<tcu::Vec3> geometryData;
724 
725     switch (testStage)
726     {
727     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
728         z             = -1.0f;
729         trianglesData = true;
730         break;
731     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
732         z             = -1.0f;
733         trianglesData = true;
734         break;
735     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
736         z             = -1.0f;
737         trianglesData = true;
738         break;
739     case VK_SHADER_STAGE_MISS_BIT_KHR:
740         z             = -9.9f;
741         trianglesData = true;
742         break;
743     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
744         z             = -1.0f;
745         trianglesData = false;
746         break;
747     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
748         z             = -1.0f;
749         trianglesData = true;
750         break;
751     default:
752         TCU_THROW(InternalError, "Unacceptable stage");
753     }
754 
755     if (trianglesData)
756     {
757         geometryData.reserve(6);
758 
759         geometryData.push_back(tcu::Vec3(-1.0f, -1.0f, z));
760         geometryData.push_back(tcu::Vec3(-1.0f, +1.0f, z));
761         geometryData.push_back(tcu::Vec3(+1.0f, -1.0f, z));
762         geometryData.push_back(tcu::Vec3(+1.0f, -1.0f, z));
763         geometryData.push_back(tcu::Vec3(-1.0f, +1.0f, z));
764         geometryData.push_back(tcu::Vec3(+1.0f, +1.0f, z));
765     }
766     else
767     {
768         geometryData.reserve(2);
769 
770         geometryData.push_back(tcu::Vec3(-1.0f, -1.0f, z));
771         geometryData.push_back(tcu::Vec3(+1.0f, +1.0f, z));
772     }
773 
774     setGeometryCount(1u);
775 
776     addGeometry(geometryData, trianglesData, geometryFlags);
777 }
778 
setGeometryCount(const size_t geometryCount)779 void BottomLevelAccelerationStructure::setGeometryCount(const size_t geometryCount)
780 {
781     m_geometriesData.clear();
782 
783     m_geometriesData.reserve(geometryCount);
784 }
785 
addGeometry(de::SharedPtr<RaytracedGeometryBase> & raytracedGeometry)786 void BottomLevelAccelerationStructure::addGeometry(de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry)
787 {
788     m_geometriesData.push_back(raytracedGeometry);
789 }
790 
addGeometry(const std::vector<tcu::Vec3> & geometryData,const bool triangles,const VkGeometryFlagsKHR geometryFlags,const VkAccelerationStructureTrianglesOpacityMicromapEXT * opacityGeometryMicromap)791 void BottomLevelAccelerationStructure::addGeometry(
792     const std::vector<tcu::Vec3> &geometryData, const bool triangles, const VkGeometryFlagsKHR geometryFlags,
793     const VkAccelerationStructureTrianglesOpacityMicromapEXT *opacityGeometryMicromap)
794 {
795     DE_ASSERT(geometryData.size() > 0);
796     DE_ASSERT((triangles && geometryData.size() % 3 == 0) || (!triangles && geometryData.size() % 2 == 0));
797 
798     if (!triangles)
799         for (size_t posNdx = 0; posNdx < geometryData.size() / 2; ++posNdx)
800         {
801             DE_ASSERT(geometryData[2 * posNdx].x() <= geometryData[2 * posNdx + 1].x());
802             DE_ASSERT(geometryData[2 * posNdx].y() <= geometryData[2 * posNdx + 1].y());
803             DE_ASSERT(geometryData[2 * posNdx].z() <= geometryData[2 * posNdx + 1].z());
804         }
805 
806     de::SharedPtr<RaytracedGeometryBase> geometry =
807         makeRaytracedGeometry(triangles ? VK_GEOMETRY_TYPE_TRIANGLES_KHR : VK_GEOMETRY_TYPE_AABBS_KHR,
808                               VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
809     for (auto it = begin(geometryData), eit = end(geometryData); it != eit; ++it)
810         geometry->addVertex(*it);
811 
812     geometry->setGeometryFlags(geometryFlags);
813     if (opacityGeometryMicromap)
814         geometry->setOpacityMicromap(opacityGeometryMicromap);
815     addGeometry(geometry);
816 }
817 
getStructureBuildSizes() const818 VkAccelerationStructureBuildSizesInfoKHR BottomLevelAccelerationStructure::getStructureBuildSizes() const
819 {
820     return {
821         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
822         DE_NULL,                                                       //  const void* pNext;
823         m_structureSize,                                               //  VkDeviceSize accelerationStructureSize;
824         m_updateScratchSize,                                           //  VkDeviceSize updateScratchSize;
825         m_buildScratchSize                                             //  VkDeviceSize buildScratchSize;
826     };
827 };
828 
getVertexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)829 VkDeviceSize getVertexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
830 {
831     DE_ASSERT(geometriesData.size() != 0);
832     VkDeviceSize bufferSizeBytes = 0;
833     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
834         bufferSizeBytes += deAlignSize(geometriesData[geometryNdx]->getVertexByteSize(), 8);
835     return bufferSizeBytes;
836 }
837 
createVertexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkDeviceSize bufferSizeBytes)838 BufferWithMemory *createVertexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
839                                      const VkDeviceSize bufferSizeBytes)
840 {
841     const VkBufferCreateInfo bufferCreateInfo =
842         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
843                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
844     return new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
845                                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
846                                     MemoryRequirement::DeviceAddress);
847 }
848 
createVertexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)849 BufferWithMemory *createVertexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
850                                      const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
851 {
852     return createVertexBuffer(vk, device, allocator, getVertexBufferSize(geometriesData));
853 }
854 
updateVertexBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData,BufferWithMemory * vertexBuffer,VkDeviceSize geometriesOffset=0)855 void updateVertexBuffer(const DeviceInterface &vk, const VkDevice device,
856                         const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData,
857                         BufferWithMemory *vertexBuffer, VkDeviceSize geometriesOffset = 0)
858 {
859     const Allocation &geometryAlloc = vertexBuffer->getAllocation();
860     uint8_t *bufferStart            = static_cast<uint8_t *>(geometryAlloc.getHostPtr());
861     VkDeviceSize bufferOffset       = geometriesOffset;
862 
863     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
864     {
865         const void *geometryPtr      = geometriesData[geometryNdx]->getVertexPointer();
866         const size_t geometryPtrSize = geometriesData[geometryNdx]->getVertexByteSize();
867 
868         deMemcpy(&bufferStart[bufferOffset], geometryPtr, geometryPtrSize);
869 
870         bufferOffset += deAlignSize(geometryPtrSize, 8);
871     }
872 
873     // Flush the whole allocation. We could flush only the interesting range, but we'd need to be sure both the offset and size
874     // align to VkPhysicalDeviceLimits::nonCoherentAtomSize, which we are not considering. Also note most code uses Coherent memory
875     // for the vertex and index buffers, so flushing is actually not needed.
876     flushAlloc(vk, device, geometryAlloc);
877 }
878 
getIndexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)879 VkDeviceSize getIndexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
880 {
881     DE_ASSERT(!geometriesData.empty());
882 
883     VkDeviceSize bufferSizeBytes = 0;
884     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
885         if (geometriesData[geometryNdx]->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
886             bufferSizeBytes += deAlignSize(geometriesData[geometryNdx]->getIndexByteSize(), 8);
887     return bufferSizeBytes;
888 }
889 
createIndexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkDeviceSize bufferSizeBytes)890 BufferWithMemory *createIndexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
891                                     const VkDeviceSize bufferSizeBytes)
892 {
893     DE_ASSERT(bufferSizeBytes);
894     const VkBufferCreateInfo bufferCreateInfo =
895         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
896                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
897     return new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
898                                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
899                                     MemoryRequirement::DeviceAddress);
900 }
901 
createIndexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)902 BufferWithMemory *createIndexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
903                                     const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
904 {
905     const VkDeviceSize bufferSizeBytes = getIndexBufferSize(geometriesData);
906     return bufferSizeBytes ? createIndexBuffer(vk, device, allocator, bufferSizeBytes) : nullptr;
907 }
908 
updateIndexBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData,BufferWithMemory * indexBuffer,VkDeviceSize geometriesOffset)909 void updateIndexBuffer(const DeviceInterface &vk, const VkDevice device,
910                        const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData,
911                        BufferWithMemory *indexBuffer, VkDeviceSize geometriesOffset)
912 {
913     const Allocation &indexAlloc = indexBuffer->getAllocation();
914     uint8_t *bufferStart         = static_cast<uint8_t *>(indexAlloc.getHostPtr());
915     VkDeviceSize bufferOffset    = geometriesOffset;
916 
917     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
918     {
919         if (geometriesData[geometryNdx]->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
920         {
921             const void *indexPtr      = geometriesData[geometryNdx]->getIndexPointer();
922             const size_t indexPtrSize = geometriesData[geometryNdx]->getIndexByteSize();
923 
924             deMemcpy(&bufferStart[bufferOffset], indexPtr, indexPtrSize);
925 
926             bufferOffset += deAlignSize(indexPtrSize, 8);
927         }
928     }
929 
930     // Flush the whole allocation. We could flush only the interesting range, but we'd need to be sure both the offset and size
931     // align to VkPhysicalDeviceLimits::nonCoherentAtomSize, which we are not considering. Also note most code uses Coherent memory
932     // for the vertex and index buffers, so flushing is actually not needed.
933     flushAlloc(vk, device, indexAlloc);
934 }
935 
936 class BottomLevelAccelerationStructureKHR : public BottomLevelAccelerationStructure
937 {
938 public:
939     static uint32_t getRequiredAllocationCount(void);
940 
941     BottomLevelAccelerationStructureKHR();
942     BottomLevelAccelerationStructureKHR(const BottomLevelAccelerationStructureKHR &other) = delete;
943     virtual ~BottomLevelAccelerationStructureKHR();
944 
945     void setBuildType(const VkAccelerationStructureBuildTypeKHR buildType) override;
946     VkAccelerationStructureBuildTypeKHR getBuildType() const override;
947     void setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags) override;
948     void setCreateGeneric(bool createGeneric) override;
949     void setCreationBufferUnbounded(bool creationBufferUnbounded) override;
950     void setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags) override;
951     void setBuildWithoutGeometries(bool buildWithoutGeometries) override;
952     void setBuildWithoutPrimitives(bool buildWithoutPrimitives) override;
953     void setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount) override;
954     void setUseArrayOfPointers(const bool useArrayOfPointers) override;
955     void setUseMaintenance5(const bool useMaintenance5) override;
956     void setIndirectBuildParameters(const VkBuffer indirectBuffer, const VkDeviceSize indirectBufferOffset,
957                                     const uint32_t indirectBufferStride) override;
958     VkBuildAccelerationStructureFlagsKHR getBuildFlags() const override;
959 
960     void create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator, VkDeviceSize structureSize,
961                 VkDeviceAddress deviceAddress = 0u, const void *pNext = DE_NULL,
962                 const MemoryRequirement &addMemoryRequirement = MemoryRequirement::Any,
963                 const VkBuffer creationBuffer = VK_NULL_HANDLE, const VkDeviceSize creationBufferSize = 0u) override;
964     void build(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
965                BottomLevelAccelerationStructure *srcAccelerationStructure = DE_NULL) override;
966     void copyFrom(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
967                   BottomLevelAccelerationStructure *accelerationStructure, bool compactCopy) override;
968 
969     void serialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
970                    SerialStorage *storage) override;
971     void deserialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
972                      SerialStorage *storage) override;
973 
974     const VkAccelerationStructureKHR *getPtr(void) const override;
975     void updateGeometry(size_t geometryIndex, de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry) override;
976 
977 protected:
978     VkAccelerationStructureBuildTypeKHR m_buildType;
979     VkAccelerationStructureCreateFlagsKHR m_createFlags;
980     bool m_createGeneric;
981     bool m_creationBufferUnbounded;
982     VkBuildAccelerationStructureFlagsKHR m_buildFlags;
983     bool m_buildWithoutGeometries;
984     bool m_buildWithoutPrimitives;
985     bool m_deferredOperation;
986     uint32_t m_workerThreadCount;
987     bool m_useArrayOfPointers;
988     bool m_useMaintenance5;
989     de::MovePtr<BufferWithMemory> m_accelerationStructureBuffer;
990     de::MovePtr<BufferWithMemory> m_vertexBuffer;
991     de::MovePtr<BufferWithMemory> m_indexBuffer;
992     de::MovePtr<BufferWithMemory> m_deviceScratchBuffer;
993     de::UniquePtr<std::vector<uint8_t>> m_hostScratchBuffer;
994     Move<VkAccelerationStructureKHR> m_accelerationStructureKHR;
995     VkBuffer m_indirectBuffer;
996     VkDeviceSize m_indirectBufferOffset;
997     uint32_t m_indirectBufferStride;
998 
999     void prepareGeometries(
1000         const DeviceInterface &vk, const VkDevice device,
1001         std::vector<VkAccelerationStructureGeometryKHR> &accelerationStructureGeometriesKHR,
1002         std::vector<VkAccelerationStructureGeometryKHR *> &accelerationStructureGeometriesKHRPointers,
1003         std::vector<VkAccelerationStructureBuildRangeInfoKHR> &accelerationStructureBuildRangeInfoKHR,
1004         std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> &accelerationStructureGeometryMicromapsEXT,
1005         std::vector<uint32_t> &maxPrimitiveCounts, VkDeviceSize vertexBufferOffset = 0,
1006         VkDeviceSize indexBufferOffset = 0) const;
1007 
getAccelerationStructureBuffer() const1008     virtual BufferWithMemory *getAccelerationStructureBuffer() const
1009     {
1010         return m_accelerationStructureBuffer.get();
1011     }
getDeviceScratchBuffer() const1012     virtual BufferWithMemory *getDeviceScratchBuffer() const
1013     {
1014         return m_deviceScratchBuffer.get();
1015     }
getHostScratchBuffer() const1016     virtual std::vector<uint8_t> *getHostScratchBuffer() const
1017     {
1018         return m_hostScratchBuffer.get();
1019     }
getVertexBuffer() const1020     virtual BufferWithMemory *getVertexBuffer() const
1021     {
1022         return m_vertexBuffer.get();
1023     }
getIndexBuffer() const1024     virtual BufferWithMemory *getIndexBuffer() const
1025     {
1026         return m_indexBuffer.get();
1027     }
1028 
getAccelerationStructureBufferOffset() const1029     virtual VkDeviceSize getAccelerationStructureBufferOffset() const
1030     {
1031         return 0;
1032     }
getDeviceScratchBufferOffset() const1033     virtual VkDeviceSize getDeviceScratchBufferOffset() const
1034     {
1035         return 0;
1036     }
getVertexBufferOffset() const1037     virtual VkDeviceSize getVertexBufferOffset() const
1038     {
1039         return 0;
1040     }
getIndexBufferOffset() const1041     virtual VkDeviceSize getIndexBufferOffset() const
1042     {
1043         return 0;
1044     }
1045 };
1046 
getRequiredAllocationCount(void)1047 uint32_t BottomLevelAccelerationStructureKHR::getRequiredAllocationCount(void)
1048 {
1049     /*
1050         de::MovePtr<BufferWithMemory>                            m_geometryBuffer; // but only when m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR
1051         de::MovePtr<Allocation>                                    m_accelerationStructureAlloc;
1052         de::MovePtr<BufferWithMemory>                            m_deviceScratchBuffer;
1053     */
1054     return 3u;
1055 }
1056 
~BottomLevelAccelerationStructureKHR()1057 BottomLevelAccelerationStructureKHR::~BottomLevelAccelerationStructureKHR()
1058 {
1059 }
1060 
BottomLevelAccelerationStructureKHR()1061 BottomLevelAccelerationStructureKHR::BottomLevelAccelerationStructureKHR()
1062     : BottomLevelAccelerationStructure()
1063     , m_buildType(VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1064     , m_createFlags(0u)
1065     , m_createGeneric(false)
1066     , m_creationBufferUnbounded(false)
1067     , m_buildFlags(0u)
1068     , m_buildWithoutGeometries(false)
1069     , m_buildWithoutPrimitives(false)
1070     , m_deferredOperation(false)
1071     , m_workerThreadCount(0)
1072     , m_useArrayOfPointers(false)
1073     , m_useMaintenance5(false)
1074     , m_accelerationStructureBuffer(DE_NULL)
1075     , m_vertexBuffer(DE_NULL)
1076     , m_indexBuffer(DE_NULL)
1077     , m_deviceScratchBuffer(DE_NULL)
1078     , m_hostScratchBuffer(new std::vector<uint8_t>)
1079     , m_accelerationStructureKHR()
1080     , m_indirectBuffer(DE_NULL)
1081     , m_indirectBufferOffset(0)
1082     , m_indirectBufferStride(0)
1083 {
1084 }
1085 
setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)1086 void BottomLevelAccelerationStructureKHR::setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)
1087 {
1088     m_buildType = buildType;
1089 }
1090 
getBuildType() const1091 VkAccelerationStructureBuildTypeKHR BottomLevelAccelerationStructureKHR::getBuildType() const
1092 {
1093     return m_buildType;
1094 }
1095 
setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)1096 void BottomLevelAccelerationStructureKHR::setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)
1097 {
1098     m_createFlags = createFlags;
1099 }
1100 
setCreateGeneric(bool createGeneric)1101 void BottomLevelAccelerationStructureKHR::setCreateGeneric(bool createGeneric)
1102 {
1103     m_createGeneric = createGeneric;
1104 }
1105 
setCreationBufferUnbounded(bool creationBufferUnbounded)1106 void BottomLevelAccelerationStructureKHR::setCreationBufferUnbounded(bool creationBufferUnbounded)
1107 {
1108     m_creationBufferUnbounded = creationBufferUnbounded;
1109 }
1110 
setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)1111 void BottomLevelAccelerationStructureKHR::setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)
1112 {
1113     m_buildFlags = buildFlags;
1114 }
1115 
setBuildWithoutGeometries(bool buildWithoutGeometries)1116 void BottomLevelAccelerationStructureKHR::setBuildWithoutGeometries(bool buildWithoutGeometries)
1117 {
1118     m_buildWithoutGeometries = buildWithoutGeometries;
1119 }
1120 
setBuildWithoutPrimitives(bool buildWithoutPrimitives)1121 void BottomLevelAccelerationStructureKHR::setBuildWithoutPrimitives(bool buildWithoutPrimitives)
1122 {
1123     m_buildWithoutPrimitives = buildWithoutPrimitives;
1124 }
1125 
setDeferredOperation(const bool deferredOperation,const uint32_t workerThreadCount)1126 void BottomLevelAccelerationStructureKHR::setDeferredOperation(const bool deferredOperation,
1127                                                                const uint32_t workerThreadCount)
1128 {
1129     m_deferredOperation = deferredOperation;
1130     m_workerThreadCount = workerThreadCount;
1131 }
1132 
setUseArrayOfPointers(const bool useArrayOfPointers)1133 void BottomLevelAccelerationStructureKHR::setUseArrayOfPointers(const bool useArrayOfPointers)
1134 {
1135     m_useArrayOfPointers = useArrayOfPointers;
1136 }
1137 
setUseMaintenance5(const bool useMaintenance5)1138 void BottomLevelAccelerationStructureKHR::setUseMaintenance5(const bool useMaintenance5)
1139 {
1140     m_useMaintenance5 = useMaintenance5;
1141 }
1142 
setIndirectBuildParameters(const VkBuffer indirectBuffer,const VkDeviceSize indirectBufferOffset,const uint32_t indirectBufferStride)1143 void BottomLevelAccelerationStructureKHR::setIndirectBuildParameters(const VkBuffer indirectBuffer,
1144                                                                      const VkDeviceSize indirectBufferOffset,
1145                                                                      const uint32_t indirectBufferStride)
1146 {
1147     m_indirectBuffer       = indirectBuffer;
1148     m_indirectBufferOffset = indirectBufferOffset;
1149     m_indirectBufferStride = indirectBufferStride;
1150 }
1151 
getBuildFlags() const1152 VkBuildAccelerationStructureFlagsKHR BottomLevelAccelerationStructureKHR::getBuildFlags() const
1153 {
1154     return m_buildFlags;
1155 }
1156 
create(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,VkDeviceSize structureSize,VkDeviceAddress deviceAddress,const void * pNext,const MemoryRequirement & addMemoryRequirement,const VkBuffer creationBuffer,const VkDeviceSize creationBufferSize)1157 void BottomLevelAccelerationStructureKHR::create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
1158                                                  VkDeviceSize structureSize, VkDeviceAddress deviceAddress,
1159                                                  const void *pNext, const MemoryRequirement &addMemoryRequirement,
1160                                                  const VkBuffer creationBuffer, const VkDeviceSize creationBufferSize)
1161 {
1162     // AS may be built from geometries using vkCmdBuildAccelerationStructuresKHR / vkBuildAccelerationStructuresKHR
1163     // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
1164     DE_ASSERT(!m_geometriesData.empty() != !(structureSize == 0)); // logical xor
1165 
1166     if (structureSize == 0)
1167     {
1168         std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
1169         std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
1170         std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
1171         std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
1172         std::vector<uint32_t> maxPrimitiveCounts;
1173         prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
1174                           accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
1175                           maxPrimitiveCounts);
1176 
1177         const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
1178             accelerationStructureGeometriesKHR.data();
1179         const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
1180             accelerationStructureGeometriesKHRPointers.data();
1181 
1182         const uint32_t geometryCount =
1183             (m_buildWithoutGeometries ? 0u : static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()));
1184         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
1185             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
1186             DE_NULL,                                                          //  const void* pNext;
1187             VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
1188             m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
1189             VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
1190             DE_NULL,                                        //  VkAccelerationStructureKHR srcAccelerationStructure;
1191             DE_NULL,                                        //  VkAccelerationStructureKHR dstAccelerationStructure;
1192             geometryCount,                                  //  uint32_t geometryCount;
1193             m_useArrayOfPointers ?
1194                 DE_NULL :
1195                 accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
1196             m_useArrayOfPointers ? accelerationStructureGeometry :
1197                                    DE_NULL,     //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
1198             makeDeviceOrHostAddressKHR(DE_NULL) //  VkDeviceOrHostAddressKHR scratchData;
1199         };
1200         VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
1201             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
1202             DE_NULL,                                                       //  const void* pNext;
1203             0,                                                             //  VkDeviceSize accelerationStructureSize;
1204             0,                                                             //  VkDeviceSize updateScratchSize;
1205             0                                                              //  VkDeviceSize buildScratchSize;
1206         };
1207 
1208         vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
1209                                                  maxPrimitiveCounts.data(), &sizeInfo);
1210 
1211         m_structureSize     = sizeInfo.accelerationStructureSize;
1212         m_updateScratchSize = sizeInfo.updateScratchSize;
1213         m_buildScratchSize  = sizeInfo.buildScratchSize;
1214     }
1215     else
1216     {
1217         m_structureSize     = structureSize;
1218         m_updateScratchSize = 0u;
1219         m_buildScratchSize  = 0u;
1220     }
1221 
1222     const bool externalCreationBuffer = (creationBuffer != VK_NULL_HANDLE);
1223 
1224     if (externalCreationBuffer)
1225     {
1226         DE_UNREF(creationBufferSize); // For release builds.
1227         DE_ASSERT(creationBufferSize >= m_structureSize);
1228     }
1229 
1230     if (!externalCreationBuffer)
1231     {
1232         VkBufferCreateInfo bufferCreateInfo =
1233             makeBufferCreateInfo(m_structureSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
1234                                                       VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1235         VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2 = vk::initVulkanStructure();
1236 
1237         if (m_useMaintenance5)
1238         {
1239             bufferUsageFlags2.usage = VK_BUFFER_USAGE_2_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
1240                                       VK_BUFFER_USAGE_2_SHADER_DEVICE_ADDRESS_BIT_KHR;
1241             bufferCreateInfo.pNext = &bufferUsageFlags2;
1242             bufferCreateInfo.usage = 0;
1243         }
1244 
1245         const MemoryRequirement memoryRequirement = addMemoryRequirement | MemoryRequirement::HostVisible |
1246                                                     MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
1247         const bool bindMemOnCreation = (!m_creationBufferUnbounded);
1248 
1249         try
1250         {
1251             m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
1252                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
1253                                      (MemoryRequirement::Cached | memoryRequirement), bindMemOnCreation));
1254         }
1255         catch (const tcu::NotSupportedError &)
1256         {
1257             // retry without Cached flag
1258             m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
1259                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement, bindMemOnCreation));
1260         }
1261     }
1262 
1263     const auto createInfoBuffer = (externalCreationBuffer ? creationBuffer : getAccelerationStructureBuffer()->get());
1264     const auto createInfoOffset =
1265         (externalCreationBuffer ? static_cast<VkDeviceSize>(0) : getAccelerationStructureBufferOffset());
1266     {
1267         const VkAccelerationStructureTypeKHR structureType =
1268             (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
1269                                VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);
1270         const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR{
1271             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
1272             pNext,                                                    //  const void* pNext;
1273             m_createFlags,    //  VkAccelerationStructureCreateFlagsKHR createFlags;
1274             createInfoBuffer, //  VkBuffer buffer;
1275             createInfoOffset, //  VkDeviceSize offset;
1276             m_structureSize,  //  VkDeviceSize size;
1277             structureType,    //  VkAccelerationStructureTypeKHR type;
1278             deviceAddress     //  VkDeviceAddress deviceAddress;
1279         };
1280 
1281         m_accelerationStructureKHR =
1282             createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, DE_NULL);
1283 
1284         // Make sure buffer memory is always bound after creation.
1285         if (!externalCreationBuffer)
1286             m_accelerationStructureBuffer->bindMemory();
1287     }
1288 
1289     if (m_buildScratchSize > 0u)
1290     {
1291         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1292         {
1293             const VkBufferCreateInfo bufferCreateInfo = makeBufferCreateInfo(
1294                 m_buildScratchSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1295             m_deviceScratchBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1296                 vk, device, allocator, bufferCreateInfo,
1297                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
1298         }
1299         else
1300         {
1301             m_hostScratchBuffer->resize(static_cast<size_t>(m_buildScratchSize));
1302         }
1303     }
1304 
1305     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR && !m_geometriesData.empty())
1306     {
1307         VkBufferCreateInfo bufferCreateInfo =
1308             makeBufferCreateInfo(getVertexBufferSize(m_geometriesData),
1309                                  VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
1310                                      VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1311         VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2 = vk::initVulkanStructure();
1312 
1313         if (m_useMaintenance5)
1314         {
1315             bufferUsageFlags2.usage = vk::VK_BUFFER_USAGE_2_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
1316                                       VK_BUFFER_USAGE_2_SHADER_DEVICE_ADDRESS_BIT_KHR;
1317             bufferCreateInfo.pNext = &bufferUsageFlags2;
1318             bufferCreateInfo.usage = 0;
1319         }
1320 
1321         const vk::MemoryRequirement memoryRequirement =
1322             MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
1323         m_vertexBuffer = de::MovePtr<BufferWithMemory>(
1324             new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement));
1325 
1326         bufferCreateInfo.size = getIndexBufferSize(m_geometriesData);
1327         if (bufferCreateInfo.size)
1328             m_indexBuffer = de::MovePtr<BufferWithMemory>(
1329                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement));
1330         else
1331             m_indexBuffer = de::MovePtr<BufferWithMemory>(nullptr);
1332     }
1333 }
1334 
build(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,BottomLevelAccelerationStructure * srcAccelerationStructure)1335 void BottomLevelAccelerationStructureKHR::build(const DeviceInterface &vk, const VkDevice device,
1336                                                 const VkCommandBuffer cmdBuffer,
1337                                                 BottomLevelAccelerationStructure *srcAccelerationStructure)
1338 {
1339     DE_ASSERT(!m_geometriesData.empty());
1340     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
1341     DE_ASSERT(m_buildScratchSize != 0);
1342 
1343     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1344     {
1345         updateVertexBuffer(vk, device, m_geometriesData, getVertexBuffer(), getVertexBufferOffset());
1346         if (getIndexBuffer() != DE_NULL)
1347             updateIndexBuffer(vk, device, m_geometriesData, getIndexBuffer(), getIndexBufferOffset());
1348     }
1349 
1350     {
1351         std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
1352         std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
1353         std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
1354         std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
1355         std::vector<uint32_t> maxPrimitiveCounts;
1356 
1357         prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
1358                           accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
1359                           maxPrimitiveCounts, getVertexBufferOffset(), getIndexBufferOffset());
1360 
1361         const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
1362             accelerationStructureGeometriesKHR.data();
1363         const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
1364             accelerationStructureGeometriesKHRPointers.data();
1365         VkDeviceOrHostAddressKHR scratchData =
1366             (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
1367                 makeDeviceOrHostAddressKHR(vk, device, getDeviceScratchBuffer()->get(),
1368                                            getDeviceScratchBufferOffset()) :
1369                 makeDeviceOrHostAddressKHR(getHostScratchBuffer()->data());
1370         const uint32_t geometryCount =
1371             (m_buildWithoutGeometries ? 0u : static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()));
1372 
1373         VkAccelerationStructureKHR srcStructure =
1374             (srcAccelerationStructure != DE_NULL) ? *(srcAccelerationStructure->getPtr()) : DE_NULL;
1375         VkBuildAccelerationStructureModeKHR mode = (srcAccelerationStructure != DE_NULL) ?
1376                                                        VK_BUILD_ACCELERATION_STRUCTURE_MODE_UPDATE_KHR :
1377                                                        VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR;
1378 
1379         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
1380             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
1381             DE_NULL,                                                          //  const void* pNext;
1382             VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
1383             m_buildFlags,                     //  VkBuildAccelerationStructureFlagsKHR flags;
1384             mode,                             //  VkBuildAccelerationStructureModeKHR mode;
1385             srcStructure,                     //  VkAccelerationStructureKHR srcAccelerationStructure;
1386             m_accelerationStructureKHR.get(), //  VkAccelerationStructureKHR dstAccelerationStructure;
1387             geometryCount,                    //  uint32_t geometryCount;
1388             m_useArrayOfPointers ?
1389                 DE_NULL :
1390                 accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
1391             m_useArrayOfPointers ? accelerationStructureGeometry :
1392                                    DE_NULL, //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
1393             scratchData                     //  VkDeviceOrHostAddressKHR scratchData;
1394         };
1395 
1396         VkAccelerationStructureBuildRangeInfoKHR *accelerationStructureBuildRangeInfoKHRPtr =
1397             accelerationStructureBuildRangeInfoKHR.data();
1398 
1399         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1400         {
1401             if (m_indirectBuffer == DE_NULL)
1402                 vk.cmdBuildAccelerationStructuresKHR(
1403                     cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
1404                     (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);
1405             else
1406             {
1407                 VkDeviceAddress indirectDeviceAddress =
1408                     getBufferDeviceAddress(vk, device, m_indirectBuffer, m_indirectBufferOffset);
1409                 uint32_t *pMaxPrimitiveCounts = maxPrimitiveCounts.data();
1410                 vk.cmdBuildAccelerationStructuresIndirectKHR(cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
1411                                                              &indirectDeviceAddress, &m_indirectBufferStride,
1412                                                              &pMaxPrimitiveCounts);
1413             }
1414         }
1415         else if (!m_deferredOperation)
1416         {
1417             VK_CHECK(vk.buildAccelerationStructuresKHR(
1418                 device, DE_NULL, 1u, &accelerationStructureBuildGeometryInfoKHR,
1419                 (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr));
1420         }
1421         else
1422         {
1423             const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1424             const auto deferredOperation    = deferredOperationPtr.get();
1425 
1426             VkResult result = vk.buildAccelerationStructuresKHR(
1427                 device, deferredOperation, 1u, &accelerationStructureBuildGeometryInfoKHR,
1428                 (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);
1429 
1430             DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1431                       result == VK_SUCCESS);
1432 
1433             finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1434                                     result == VK_OPERATION_NOT_DEFERRED_KHR);
1435         }
1436     }
1437 
1438     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1439     {
1440         const VkAccessFlags accessMasks =
1441             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
1442         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
1443 
1444         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
1445                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
1446     }
1447 }
1448 
copyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,BottomLevelAccelerationStructure * accelerationStructure,bool compactCopy)1449 void BottomLevelAccelerationStructureKHR::copyFrom(const DeviceInterface &vk, const VkDevice device,
1450                                                    const VkCommandBuffer cmdBuffer,
1451                                                    BottomLevelAccelerationStructure *accelerationStructure,
1452                                                    bool compactCopy)
1453 {
1454     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
1455     DE_ASSERT(accelerationStructure != DE_NULL);
1456 
1457     VkCopyAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
1458         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
1459         DE_NULL,                                                // const void* pNext;
1460         *(accelerationStructure->getPtr()),                     // VkAccelerationStructureKHR src;
1461         *(getPtr()),                                            // VkAccelerationStructureKHR dst;
1462         compactCopy ? VK_COPY_ACCELERATION_STRUCTURE_MODE_COMPACT_KHR :
1463                       VK_COPY_ACCELERATION_STRUCTURE_MODE_CLONE_KHR // VkCopyAccelerationStructureModeKHR mode;
1464     };
1465 
1466     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1467     {
1468         vk.cmdCopyAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
1469     }
1470     else if (!m_deferredOperation)
1471     {
1472         VK_CHECK(vk.copyAccelerationStructureKHR(device, DE_NULL, &copyAccelerationStructureInfo));
1473     }
1474     else
1475     {
1476         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1477         const auto deferredOperation    = deferredOperationPtr.get();
1478 
1479         VkResult result = vk.copyAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1480 
1481         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1482                   result == VK_SUCCESS);
1483 
1484         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1485                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1486     }
1487 
1488     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1489     {
1490         const VkAccessFlags accessMasks =
1491             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
1492         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
1493 
1494         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
1495                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
1496     }
1497 }
1498 
serialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)1499 void BottomLevelAccelerationStructureKHR::serialize(const DeviceInterface &vk, const VkDevice device,
1500                                                     const VkCommandBuffer cmdBuffer, SerialStorage *storage)
1501 {
1502     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
1503     DE_ASSERT(storage != DE_NULL);
1504 
1505     const VkCopyAccelerationStructureToMemoryInfoKHR copyAccelerationStructureInfo = {
1506         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_TO_MEMORY_INFO_KHR, // VkStructureType sType;
1507         DE_NULL,                                                          // const void* pNext;
1508         *(getPtr()),                                                      // VkAccelerationStructureKHR src;
1509         storage->getAddress(vk, device, m_buildType),                     // VkDeviceOrHostAddressKHR dst;
1510         VK_COPY_ACCELERATION_STRUCTURE_MODE_SERIALIZE_KHR                 // VkCopyAccelerationStructureModeKHR mode;
1511     };
1512 
1513     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1514     {
1515         vk.cmdCopyAccelerationStructureToMemoryKHR(cmdBuffer, &copyAccelerationStructureInfo);
1516     }
1517     else if (!m_deferredOperation)
1518     {
1519         VK_CHECK(vk.copyAccelerationStructureToMemoryKHR(device, DE_NULL, &copyAccelerationStructureInfo));
1520     }
1521     else
1522     {
1523         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1524         const auto deferredOperation    = deferredOperationPtr.get();
1525 
1526         const VkResult result =
1527             vk.copyAccelerationStructureToMemoryKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1528 
1529         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1530                   result == VK_SUCCESS);
1531 
1532         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1533                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1534     }
1535 }
1536 
deserialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)1537 void BottomLevelAccelerationStructureKHR::deserialize(const DeviceInterface &vk, const VkDevice device,
1538                                                       const VkCommandBuffer cmdBuffer, SerialStorage *storage)
1539 {
1540     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
1541     DE_ASSERT(storage != DE_NULL);
1542 
1543     const VkCopyMemoryToAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
1544         VK_STRUCTURE_TYPE_COPY_MEMORY_TO_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
1545         DE_NULL,                                                          // const void* pNext;
1546         storage->getAddressConst(vk, device, m_buildType),                // VkDeviceOrHostAddressConstKHR src;
1547         *(getPtr()),                                                      // VkAccelerationStructureKHR dst;
1548         VK_COPY_ACCELERATION_STRUCTURE_MODE_DESERIALIZE_KHR               // VkCopyAccelerationStructureModeKHR mode;
1549     };
1550 
1551     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1552     {
1553         vk.cmdCopyMemoryToAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
1554     }
1555     else if (!m_deferredOperation)
1556     {
1557         VK_CHECK(vk.copyMemoryToAccelerationStructureKHR(device, DE_NULL, &copyAccelerationStructureInfo));
1558     }
1559     else
1560     {
1561         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1562         const auto deferredOperation    = deferredOperationPtr.get();
1563 
1564         const VkResult result =
1565             vk.copyMemoryToAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1566 
1567         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1568                   result == VK_SUCCESS);
1569 
1570         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1571                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1572     }
1573 
1574     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1575     {
1576         const VkAccessFlags accessMasks =
1577             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
1578         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
1579 
1580         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
1581                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
1582     }
1583 }
1584 
getPtr(void) const1585 const VkAccelerationStructureKHR *BottomLevelAccelerationStructureKHR::getPtr(void) const
1586 {
1587     return &m_accelerationStructureKHR.get();
1588 }
1589 
prepareGeometries(const DeviceInterface & vk,const VkDevice device,std::vector<VkAccelerationStructureGeometryKHR> & accelerationStructureGeometriesKHR,std::vector<VkAccelerationStructureGeometryKHR * > & accelerationStructureGeometriesKHRPointers,std::vector<VkAccelerationStructureBuildRangeInfoKHR> & accelerationStructureBuildRangeInfoKHR,std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> & accelerationStructureGeometryMicromapsEXT,std::vector<uint32_t> & maxPrimitiveCounts,VkDeviceSize vertexBufferOffset,VkDeviceSize indexBufferOffset) const1590 void BottomLevelAccelerationStructureKHR::prepareGeometries(
1591     const DeviceInterface &vk, const VkDevice device,
1592     std::vector<VkAccelerationStructureGeometryKHR> &accelerationStructureGeometriesKHR,
1593     std::vector<VkAccelerationStructureGeometryKHR *> &accelerationStructureGeometriesKHRPointers,
1594     std::vector<VkAccelerationStructureBuildRangeInfoKHR> &accelerationStructureBuildRangeInfoKHR,
1595     std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> &accelerationStructureGeometryMicromapsEXT,
1596     std::vector<uint32_t> &maxPrimitiveCounts, VkDeviceSize vertexBufferOffset, VkDeviceSize indexBufferOffset) const
1597 {
1598     accelerationStructureGeometriesKHR.resize(m_geometriesData.size());
1599     accelerationStructureGeometriesKHRPointers.resize(m_geometriesData.size());
1600     accelerationStructureBuildRangeInfoKHR.resize(m_geometriesData.size());
1601     accelerationStructureGeometryMicromapsEXT.resize(m_geometriesData.size());
1602     maxPrimitiveCounts.resize(m_geometriesData.size());
1603 
1604     for (size_t geometryNdx = 0; geometryNdx < m_geometriesData.size(); ++geometryNdx)
1605     {
1606         const de::SharedPtr<RaytracedGeometryBase> &geometryData = m_geometriesData[geometryNdx];
1607         VkDeviceOrHostAddressConstKHR vertexData, indexData;
1608         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1609         {
1610             if (getVertexBuffer() != DE_NULL)
1611             {
1612                 vertexData = makeDeviceOrHostAddressConstKHR(vk, device, getVertexBuffer()->get(), vertexBufferOffset);
1613                 if (m_indirectBuffer == DE_NULL)
1614                 {
1615                     vertexBufferOffset += deAlignSize(geometryData->getVertexByteSize(), 8);
1616                 }
1617             }
1618             else
1619                 vertexData = makeDeviceOrHostAddressConstKHR(DE_NULL);
1620 
1621             if (getIndexBuffer() != DE_NULL && geometryData->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
1622             {
1623                 indexData = makeDeviceOrHostAddressConstKHR(vk, device, getIndexBuffer()->get(), indexBufferOffset);
1624                 indexBufferOffset += deAlignSize(geometryData->getIndexByteSize(), 8);
1625             }
1626             else
1627                 indexData = makeDeviceOrHostAddressConstKHR(DE_NULL);
1628         }
1629         else
1630         {
1631             vertexData = makeDeviceOrHostAddressConstKHR(geometryData->getVertexPointer());
1632             if (geometryData->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
1633                 indexData = makeDeviceOrHostAddressConstKHR(geometryData->getIndexPointer());
1634             else
1635                 indexData = makeDeviceOrHostAddressConstKHR(DE_NULL);
1636         }
1637 
1638         VkAccelerationStructureGeometryTrianglesDataKHR accelerationStructureGeometryTrianglesDataKHR = {
1639             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_TRIANGLES_DATA_KHR, //  VkStructureType sType;
1640             DE_NULL,                                                              //  const void* pNext;
1641             geometryData->getVertexFormat(),                                      //  VkFormat vertexFormat;
1642             vertexData,                                            //  VkDeviceOrHostAddressConstKHR vertexData;
1643             geometryData->getVertexStride(),                       //  VkDeviceSize vertexStride;
1644             static_cast<uint32_t>(geometryData->getVertexCount()), //  uint32_t maxVertex;
1645             geometryData->getIndexType(),                          //  VkIndexType indexType;
1646             indexData,                                             //  VkDeviceOrHostAddressConstKHR indexData;
1647             makeDeviceOrHostAddressConstKHR(DE_NULL),              //  VkDeviceOrHostAddressConstKHR transformData;
1648         };
1649 
1650         if (geometryData->getHasOpacityMicromap())
1651             accelerationStructureGeometryTrianglesDataKHR.pNext = &geometryData->getOpacityMicromap();
1652 
1653         const VkAccelerationStructureGeometryAabbsDataKHR accelerationStructureGeometryAabbsDataKHR = {
1654             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_AABBS_DATA_KHR, //  VkStructureType sType;
1655             DE_NULL,                                                          //  const void* pNext;
1656             vertexData,                                                       //  VkDeviceOrHostAddressConstKHR data;
1657             geometryData->getAABBStride()                                     //  VkDeviceSize stride;
1658         };
1659         const VkAccelerationStructureGeometryDataKHR geometry =
1660             (geometryData->isTrianglesType()) ?
1661                 makeVkAccelerationStructureGeometryDataKHR(accelerationStructureGeometryTrianglesDataKHR) :
1662                 makeVkAccelerationStructureGeometryDataKHR(accelerationStructureGeometryAabbsDataKHR);
1663         const VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR = {
1664             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR, //  VkStructureType sType;
1665             DE_NULL,                                               //  const void* pNext;
1666             geometryData->getGeometryType(),                       //  VkGeometryTypeKHR geometryType;
1667             geometry,                                              //  VkAccelerationStructureGeometryDataKHR geometry;
1668             geometryData->getGeometryFlags()                       //  VkGeometryFlagsKHR flags;
1669         };
1670 
1671         const uint32_t primitiveCount = (m_buildWithoutPrimitives ? 0u : geometryData->getPrimitiveCount());
1672 
1673         const VkAccelerationStructureBuildRangeInfoKHR accelerationStructureBuildRangeInfosKHR = {
1674             primitiveCount, //  uint32_t primitiveCount;
1675             0,              //  uint32_t primitiveOffset;
1676             0,              //  uint32_t firstVertex;
1677             0               //  uint32_t firstTransform;
1678         };
1679 
1680         accelerationStructureGeometriesKHR[geometryNdx]         = accelerationStructureGeometryKHR;
1681         accelerationStructureGeometriesKHRPointers[geometryNdx] = &accelerationStructureGeometriesKHR[geometryNdx];
1682         accelerationStructureBuildRangeInfoKHR[geometryNdx]     = accelerationStructureBuildRangeInfosKHR;
1683         maxPrimitiveCounts[geometryNdx]                         = geometryData->getPrimitiveCount();
1684     }
1685 }
1686 
getRequiredAllocationCount(void)1687 uint32_t BottomLevelAccelerationStructure::getRequiredAllocationCount(void)
1688 {
1689     return BottomLevelAccelerationStructureKHR::getRequiredAllocationCount();
1690 }
1691 
createAndBuild(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,VkDeviceAddress deviceAddress)1692 void BottomLevelAccelerationStructure::createAndBuild(const DeviceInterface &vk, const VkDevice device,
1693                                                       const VkCommandBuffer cmdBuffer, Allocator &allocator,
1694                                                       VkDeviceAddress deviceAddress)
1695 {
1696     create(vk, device, allocator, 0u, deviceAddress);
1697     build(vk, device, cmdBuffer);
1698 }
1699 
createAndCopyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,BottomLevelAccelerationStructure * accelerationStructure,VkDeviceSize compactCopySize,VkDeviceAddress deviceAddress)1700 void BottomLevelAccelerationStructure::createAndCopyFrom(const DeviceInterface &vk, const VkDevice device,
1701                                                          const VkCommandBuffer cmdBuffer, Allocator &allocator,
1702                                                          BottomLevelAccelerationStructure *accelerationStructure,
1703                                                          VkDeviceSize compactCopySize, VkDeviceAddress deviceAddress)
1704 {
1705     DE_ASSERT(accelerationStructure != NULL);
1706     VkDeviceSize copiedSize = compactCopySize > 0u ?
1707                                   compactCopySize :
1708                                   accelerationStructure->getStructureBuildSizes().accelerationStructureSize;
1709     DE_ASSERT(copiedSize != 0u);
1710 
1711     create(vk, device, allocator, copiedSize, deviceAddress);
1712     copyFrom(vk, device, cmdBuffer, accelerationStructure, compactCopySize > 0u);
1713 }
1714 
createAndDeserializeFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,SerialStorage * storage,VkDeviceAddress deviceAddress)1715 void BottomLevelAccelerationStructure::createAndDeserializeFrom(const DeviceInterface &vk, const VkDevice device,
1716                                                                 const VkCommandBuffer cmdBuffer, Allocator &allocator,
1717                                                                 SerialStorage *storage, VkDeviceAddress deviceAddress)
1718 {
1719     DE_ASSERT(storage != NULL);
1720     DE_ASSERT(storage->getStorageSize() >= SerialStorage::SERIAL_STORAGE_SIZE_MIN);
1721     create(vk, device, allocator, storage->getDeserializedSize(), deviceAddress);
1722     deserialize(vk, device, cmdBuffer, storage);
1723 }
1724 
updateGeometry(size_t geometryIndex,de::SharedPtr<RaytracedGeometryBase> & raytracedGeometry)1725 void BottomLevelAccelerationStructureKHR::updateGeometry(size_t geometryIndex,
1726                                                          de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry)
1727 {
1728     DE_ASSERT(geometryIndex < m_geometriesData.size());
1729     m_geometriesData[geometryIndex] = raytracedGeometry;
1730 }
1731 
makeBottomLevelAccelerationStructure()1732 de::MovePtr<BottomLevelAccelerationStructure> makeBottomLevelAccelerationStructure()
1733 {
1734     return de::MovePtr<BottomLevelAccelerationStructure>(new BottomLevelAccelerationStructureKHR);
1735 }
1736 
1737 // Forward declaration
1738 struct BottomLevelAccelerationStructurePoolImpl;
1739 
1740 class BottomLevelAccelerationStructurePoolMember : public BottomLevelAccelerationStructureKHR
1741 {
1742 public:
1743     friend class BottomLevelAccelerationStructurePool;
1744 
1745     BottomLevelAccelerationStructurePoolMember(BottomLevelAccelerationStructurePoolImpl &pool);
1746     BottomLevelAccelerationStructurePoolMember(const BottomLevelAccelerationStructurePoolMember &) = delete;
1747     BottomLevelAccelerationStructurePoolMember(BottomLevelAccelerationStructurePoolMember &&)      = delete;
1748     virtual ~BottomLevelAccelerationStructurePoolMember()                                          = default;
1749 
create(const DeviceInterface &,const VkDevice,Allocator &,VkDeviceSize,VkDeviceAddress,const void *,const MemoryRequirement &,const VkBuffer,const VkDeviceSize)1750     virtual void create(const DeviceInterface &, const VkDevice, Allocator &, VkDeviceSize, VkDeviceAddress,
1751                         const void *, const MemoryRequirement &, const VkBuffer, const VkDeviceSize) override
1752     {
1753         DE_ASSERT(0); // Silent this method
1754     }
1755     virtual auto computeBuildSize(const DeviceInterface &vk, const VkDevice device, const VkDeviceSize strSize) const
1756         //              accStrSize,updateScratch, buildScratch, vertexSize,   indexSize
1757         -> std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize>;
1758 
1759 protected:
1760     struct Info;
1761     virtual void preCreateSetSizesAndOffsets(const Info &info, const VkDeviceSize accStrSize,
1762                                              const VkDeviceSize updateScratchSize, const VkDeviceSize buildScratchSize);
1763     virtual void createAccellerationStructure(const DeviceInterface &vk, const VkDevice device,
1764                                               VkDeviceAddress deviceAddress);
1765 
1766     virtual BufferWithMemory *getAccelerationStructureBuffer() const override;
1767     virtual BufferWithMemory *getDeviceScratchBuffer() const override;
1768     virtual std::vector<uint8_t> *getHostScratchBuffer() const override;
1769     virtual BufferWithMemory *getVertexBuffer() const override;
1770     virtual BufferWithMemory *getIndexBuffer() const override;
1771 
getAccelerationStructureBufferOffset() const1772     virtual VkDeviceSize getAccelerationStructureBufferOffset() const override
1773     {
1774         return m_info.accStrOffset;
1775     }
getDeviceScratchBufferOffset() const1776     virtual VkDeviceSize getDeviceScratchBufferOffset() const override
1777     {
1778         return m_info.buildScratchBuffOffset;
1779     }
getVertexBufferOffset() const1780     virtual VkDeviceSize getVertexBufferOffset() const override
1781     {
1782         return m_info.vertBuffOffset;
1783     }
getIndexBufferOffset() const1784     virtual VkDeviceSize getIndexBufferOffset() const override
1785     {
1786         return m_info.indexBuffOffset;
1787     }
1788 
1789     BottomLevelAccelerationStructurePoolImpl &m_pool;
1790 
1791     struct Info
1792     {
1793         uint32_t accStrIndex;
1794         VkDeviceSize accStrOffset;
1795         uint32_t vertBuffIndex;
1796         VkDeviceSize vertBuffOffset;
1797         uint32_t indexBuffIndex;
1798         VkDeviceSize indexBuffOffset;
1799         uint32_t buildScratchBuffIndex;
1800         VkDeviceSize buildScratchBuffOffset;
1801     } m_info;
1802 };
1803 
// Returns the all-bits-set value of X (~0, "negated zero"); the argument only selects the type.
template <class X>
inline X negz(const X &)
{
    const X zero = static_cast<X>(0);
    return static_cast<X>(~zero);
}
// True when x equals the all-bits-set value produced by negz(). This file uses that value as a
// sentinel, e.g. batchCreate() passes it as "no maximum buffer size".
template <class X>
inline bool isnegz(const X &x)
{
    const X sentinel = negz(x);
    return sentinel == x;
}
// Converts y to the unsigned integer type of the same width (std::make_unsigned<Y>::type).
template <class Y>
inline auto make_unsigned(const Y &y) -> typename std::make_unsigned<Y>::type
{
    using Unsigned = typename std::make_unsigned<Y>::type;
    return static_cast<Unsigned>(y);
}
1819 
BottomLevelAccelerationStructurePoolMember(BottomLevelAccelerationStructurePoolImpl & pool)1820 BottomLevelAccelerationStructurePoolMember::BottomLevelAccelerationStructurePoolMember(
1821     BottomLevelAccelerationStructurePoolImpl &pool)
1822     : m_pool(pool)
1823     , m_info{}
1824 {
1825 }
1826 
1827 struct BottomLevelAccelerationStructurePoolImpl
1828 {
1829     BottomLevelAccelerationStructurePoolImpl(BottomLevelAccelerationStructurePoolImpl &&)      = delete;
1830     BottomLevelAccelerationStructurePoolImpl(const BottomLevelAccelerationStructurePoolImpl &) = delete;
1831     BottomLevelAccelerationStructurePoolImpl(BottomLevelAccelerationStructurePool &pool);
1832 
1833     BottomLevelAccelerationStructurePool &m_pool;
1834     std::vector<de::SharedPtr<BufferWithMemory>> m_accellerationStructureBuffers;
1835     de::SharedPtr<BufferWithMemory> m_deviceScratchBuffer;
1836     de::UniquePtr<std::vector<uint8_t>> m_hostScratchBuffer;
1837     std::vector<de::SharedPtr<BufferWithMemory>> m_vertexBuffers;
1838     std::vector<de::SharedPtr<BufferWithMemory>> m_indexBuffers;
1839 };
BottomLevelAccelerationStructurePoolImpl(BottomLevelAccelerationStructurePool & pool)1840 BottomLevelAccelerationStructurePoolImpl::BottomLevelAccelerationStructurePoolImpl(
1841     BottomLevelAccelerationStructurePool &pool)
1842     : m_pool(pool)
1843     , m_accellerationStructureBuffers()
1844     , m_deviceScratchBuffer()
1845     , m_hostScratchBuffer(new std::vector<uint8_t>)
1846     , m_vertexBuffers()
1847     , m_indexBuffers()
1848 {
1849 }
getAccelerationStructureBuffer() const1850 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getAccelerationStructureBuffer() const
1851 {
1852     BufferWithMemory *result = nullptr;
1853     if (m_pool.m_accellerationStructureBuffers.size())
1854     {
1855         DE_ASSERT(!isnegz(m_info.accStrIndex));
1856         result = m_pool.m_accellerationStructureBuffers[m_info.accStrIndex].get();
1857     }
1858     return result;
1859 }
getDeviceScratchBuffer() const1860 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getDeviceScratchBuffer() const
1861 {
1862     DE_ASSERT(m_info.buildScratchBuffIndex == 0);
1863     return m_pool.m_deviceScratchBuffer.get();
1864 }
getHostScratchBuffer() const1865 std::vector<uint8_t> *BottomLevelAccelerationStructurePoolMember::getHostScratchBuffer() const
1866 {
1867     return this->m_buildScratchSize ? m_pool.m_hostScratchBuffer.get() : nullptr;
1868 }
1869 
getVertexBuffer() const1870 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getVertexBuffer() const
1871 {
1872     BufferWithMemory *result = nullptr;
1873     if (m_pool.m_vertexBuffers.size())
1874     {
1875         DE_ASSERT(!isnegz(m_info.vertBuffIndex));
1876         result = m_pool.m_vertexBuffers[m_info.vertBuffIndex].get();
1877     }
1878     return result;
1879 }
getIndexBuffer() const1880 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getIndexBuffer() const
1881 {
1882     BufferWithMemory *result = nullptr;
1883     if (m_pool.m_indexBuffers.size())
1884     {
1885         DE_ASSERT(!isnegz(m_info.indexBuffIndex));
1886         result = m_pool.m_indexBuffers[m_info.indexBuffIndex].get();
1887     }
1888     return result;
1889 }
1890 
1891 struct BottomLevelAccelerationStructurePool::Impl : BottomLevelAccelerationStructurePoolImpl
1892 {
1893     friend class BottomLevelAccelerationStructurePool;
1894     friend class BottomLevelAccelerationStructurePoolMember;
1895 
Implvk::BottomLevelAccelerationStructurePool::Impl1896     Impl(BottomLevelAccelerationStructurePool &pool) : BottomLevelAccelerationStructurePoolImpl(pool)
1897     {
1898     }
1899 };
1900 
BottomLevelAccelerationStructurePool()1901 BottomLevelAccelerationStructurePool::BottomLevelAccelerationStructurePool()
1902     : m_batchStructCount(4)
1903     , m_batchGeomCount(0)
1904     , m_infos()
1905     , m_structs()
1906     , m_createOnce(false)
1907     , m_tryCachedMemory(true)
1908     , m_structsBuffSize(0)
1909     , m_updatesScratchSize(0)
1910     , m_buildsScratchSize(0)
1911     , m_verticesSize(0)
1912     , m_indicesSize(0)
1913     , m_impl(new Impl(*this))
1914 {
1915 }
1916 
~BottomLevelAccelerationStructurePool()1917 BottomLevelAccelerationStructurePool::~BottomLevelAccelerationStructurePool()
1918 {
1919     delete m_impl;
1920 }
1921 
batchStructCount(const uint32_t & value)1922 void BottomLevelAccelerationStructurePool::batchStructCount(const uint32_t &value)
1923 {
1924     DE_ASSERT(value >= 1);
1925     m_batchStructCount = value;
1926 }
1927 
add(VkDeviceSize structureSize,VkDeviceAddress deviceAddress)1928 auto BottomLevelAccelerationStructurePool::add(VkDeviceSize structureSize, VkDeviceAddress deviceAddress)
1929     -> BottomLevelAccelerationStructurePool::BlasPtr
1930 {
1931     // Prevent a programmer from calling this method after batchCreate(...) method has been called.
1932     if (m_createOnce)
1933         DE_ASSERT(0);
1934 
1935     auto blas = new BottomLevelAccelerationStructurePoolMember(*m_impl);
1936     m_infos.push_back({structureSize, deviceAddress});
1937     m_structs.emplace_back(blas);
1938     return m_structs.back();
1939 }
1940 
adjustBatchCount(const DeviceInterface & vkd,const VkDevice device,const std::vector<BottomLevelAccelerationStructurePool::BlasPtr> & structs,const std::vector<BottomLevelAccelerationStructurePool::BlasInfo> & infos,const VkDeviceSize maxBufferSize,uint32_t (& result)[4])1941 void adjustBatchCount(const DeviceInterface &vkd, const VkDevice device,
1942                       const std::vector<BottomLevelAccelerationStructurePool::BlasPtr> &structs,
1943                       const std::vector<BottomLevelAccelerationStructurePool::BlasInfo> &infos,
1944                       const VkDeviceSize maxBufferSize, uint32_t (&result)[4])
1945 {
1946     tcu::Vector<VkDeviceSize, 4> sizes(0);
1947     tcu::Vector<VkDeviceSize, 4> sums(0);
1948     tcu::Vector<uint32_t, 4> tmps(0);
1949     tcu::Vector<uint32_t, 4> batches(0);
1950 
1951     VkDeviceSize updateScratchSize = 0;
1952     static_cast<void>(updateScratchSize); // not used yet, disabled for future implementation
1953 
1954     auto updateIf = [&](uint32_t c)
1955     {
1956         if (sums[c] + sizes[c] <= maxBufferSize)
1957         {
1958             sums[c] += sizes[c];
1959             tmps[c] += 1;
1960 
1961             batches[c] = std::max(tmps[c], batches[c]);
1962         }
1963         else
1964         {
1965             sums[c] = 0;
1966             tmps[c] = 0;
1967         }
1968     };
1969 
1970     const uint32_t maxIter = static_cast<uint32_t>(structs.size());
1971     for (uint32_t i = 0; i < maxIter; ++i)
1972     {
1973         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(structs[i].get());
1974         std::tie(sizes[0], updateScratchSize, sizes[1], sizes[2], sizes[3]) =
1975             str.computeBuildSize(vkd, device, infos[i].structureSize);
1976 
1977         updateIf(0);
1978         updateIf(1);
1979         updateIf(2);
1980         updateIf(3);
1981     }
1982 
1983     result[0] = std::max(batches[0], 1u);
1984     result[1] = std::max(batches[1], 1u);
1985     result[2] = std::max(batches[2], 1u);
1986     result[3] = std::max(batches[3], 1u);
1987 }
1988 
getAllocationCount() const1989 size_t BottomLevelAccelerationStructurePool::getAllocationCount() const
1990 {
1991     return m_impl->m_accellerationStructureBuffers.size() + m_impl->m_vertexBuffers.size() +
1992            m_impl->m_indexBuffers.size() + 1 /* for scratch buffer */;
1993 }
1994 
getAllocationCount(const DeviceInterface & vk,const VkDevice device,const VkDeviceSize maxBufferSize) const1995 size_t BottomLevelAccelerationStructurePool::getAllocationCount(const DeviceInterface &vk, const VkDevice device,
1996                                                                 const VkDeviceSize maxBufferSize) const
1997 {
1998     DE_ASSERT(m_structs.size() != 0);
1999 
2000     std::map<uint32_t, VkDeviceSize> accStrSizes;
2001     std::map<uint32_t, VkDeviceSize> vertBuffSizes;
2002     std::map<uint32_t, VkDeviceSize> indexBuffSizes;
2003     std::map<uint32_t, VkDeviceSize> scratchBuffSizes;
2004 
2005     const uint32_t allStructsCount = structCount();
2006 
2007     uint32_t batchStructCount  = m_batchStructCount;
2008     uint32_t batchScratchCount = m_batchStructCount;
2009     uint32_t batchVertexCount  = m_batchGeomCount ? m_batchGeomCount : m_batchStructCount;
2010     uint32_t batchIndexCount   = batchVertexCount;
2011 
2012     if (!isnegz(maxBufferSize))
2013     {
2014         uint32_t batches[4];
2015         adjustBatchCount(vk, device, m_structs, m_infos, maxBufferSize, batches);
2016         batchStructCount  = batches[0];
2017         batchScratchCount = batches[1];
2018         batchVertexCount  = batches[2];
2019         batchIndexCount   = batches[3];
2020     }
2021 
2022     uint32_t iStr     = 0;
2023     uint32_t iScratch = 0;
2024     uint32_t iVertex  = 0;
2025     uint32_t iIndex   = 0;
2026 
2027     VkDeviceSize strSize           = 0;
2028     VkDeviceSize updateScratchSize = 0;
2029     VkDeviceSize buildScratchSize  = 0;
2030     VkDeviceSize vertexSize        = 0;
2031     VkDeviceSize indexSize         = 0;
2032 
2033     for (; iStr < allStructsCount; ++iStr)
2034     {
2035         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iStr].get());
2036         std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
2037             str.computeBuildSize(vk, device, m_infos[iStr].structureSize);
2038 
2039         {
2040             const VkDeviceSize alignedStrSize = deAlign64(strSize, 256);
2041             const uint32_t accStrIndex        = (iStr / batchStructCount);
2042             accStrSizes[accStrIndex] += alignedStrSize;
2043         }
2044 
2045         if (buildScratchSize != 0)
2046         {
2047             const VkDeviceSize alignedBuilsScratchSize = deAlign64(buildScratchSize, 256);
2048             const uint32_t scratchBuffIndex            = (iScratch / batchScratchCount);
2049             scratchBuffSizes[scratchBuffIndex] += alignedBuilsScratchSize;
2050             iScratch += 1;
2051         }
2052 
2053         if (vertexSize != 0)
2054         {
2055             const VkDeviceSize alignedVertBuffSize = deAlign64(vertexSize, 8);
2056             const uint32_t vertBuffIndex           = (iVertex / batchVertexCount);
2057             vertBuffSizes[vertBuffIndex] += alignedVertBuffSize;
2058             iVertex += 1;
2059         }
2060 
2061         if (indexSize != 0)
2062         {
2063             const VkDeviceSize alignedIndexBuffSize = deAlign64(indexSize, 8);
2064             const uint32_t indexBuffIndex           = (iIndex / batchIndexCount);
2065             indexBuffSizes[indexBuffIndex] += alignedIndexBuffSize;
2066             iIndex += 1;
2067         }
2068     }
2069 
2070     return accStrSizes.size() + vertBuffSizes.size() + indexBuffSizes.size() + scratchBuffSizes.size();
2071 }
2072 
getAllocationSizes(const DeviceInterface & vk,const VkDevice device) const2073 tcu::Vector<VkDeviceSize, 4> BottomLevelAccelerationStructurePool::getAllocationSizes(const DeviceInterface &vk,
2074                                                                                       const VkDevice device) const
2075 {
2076     if (m_structsBuffSize)
2077     {
2078         return tcu::Vector<VkDeviceSize, 4>(m_structsBuffSize, m_buildsScratchSize, m_verticesSize, m_indicesSize);
2079     }
2080 
2081     VkDeviceSize strSize           = 0;
2082     VkDeviceSize updateScratchSize = 0;
2083     static_cast<void>(updateScratchSize); // not used yet, disabled for future implementation
2084     VkDeviceSize buildScratchSize     = 0;
2085     VkDeviceSize vertexSize           = 0;
2086     VkDeviceSize indexSize            = 0;
2087     VkDeviceSize sumStrSize           = 0;
2088     VkDeviceSize sumUpdateScratchSize = 0;
2089     static_cast<void>(sumUpdateScratchSize); // not used yet, disabled for future implementation
2090     VkDeviceSize sumBuildScratchSize = 0;
2091     VkDeviceSize sumVertexSize       = 0;
2092     VkDeviceSize sumIndexSize        = 0;
2093     for (size_t i = 0; i < structCount(); ++i)
2094     {
2095         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[i].get());
2096         std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
2097             str.computeBuildSize(vk, device, m_infos[i].structureSize);
2098         sumStrSize += deAlign64(strSize, 256);
2099         //sumUpdateScratchSize    += deAlign64(updateScratchSize, 256);    not used yet, disabled for future implementation
2100         sumBuildScratchSize += deAlign64(buildScratchSize, 256);
2101         sumVertexSize += deAlign64(vertexSize, 8);
2102         sumIndexSize += deAlign64(indexSize, 8);
2103     }
2104     return tcu::Vector<VkDeviceSize, 4>(sumStrSize, sumBuildScratchSize, sumVertexSize, sumIndexSize);
2105 }
2106 
batchCreate(const DeviceInterface & vkd,const VkDevice device,Allocator & allocator)2107 void BottomLevelAccelerationStructurePool::batchCreate(const DeviceInterface &vkd, const VkDevice device,
2108                                                        Allocator &allocator)
2109 {
2110     batchCreateAdjust(vkd, device, allocator, negz<VkDeviceSize>(0));
2111 }
2112 
batchCreateAdjust(const DeviceInterface & vkd,const VkDevice device,Allocator & allocator,const VkDeviceSize maxBufferSize)2113 void BottomLevelAccelerationStructurePool::batchCreateAdjust(const DeviceInterface &vkd, const VkDevice device,
2114                                                              Allocator &allocator, const VkDeviceSize maxBufferSize)
2115 {
2116     // Prevent a programmer from calling this method more than once.
2117     if (m_createOnce)
2118         DE_ASSERT(0);
2119 
2120     m_createOnce = true;
2121     DE_ASSERT(m_structs.size() != 0);
2122 
2123     auto createAccellerationStructureBuffer = [&](VkDeviceSize bufferSize) ->
2124         typename std::add_pointer<BufferWithMemory>::type
2125     {
2126         BufferWithMemory *res = nullptr;
2127         const VkBufferCreateInfo bci =
2128             makeBufferCreateInfo(bufferSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
2129                                                  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2130 
2131         if (m_tryCachedMemory)
2132             try
2133             {
2134                 res = new BufferWithMemory(vkd, device, allocator, bci,
2135                                            MemoryRequirement::Cached | MemoryRequirement::HostVisible |
2136                                                MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress);
2137             }
2138             catch (const tcu::NotSupportedError &)
2139             {
2140                 res = nullptr;
2141             }
2142 
2143         return (nullptr != res) ? res :
2144                                   (new BufferWithMemory(vkd, device, allocator, bci,
2145                                                         MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
2146                                                             MemoryRequirement::DeviceAddress));
2147     };
2148 
2149     auto createDeviceScratchBuffer = [&](VkDeviceSize bufferSize) -> de::SharedPtr<BufferWithMemory>
2150     {
2151         const VkBufferCreateInfo bci = makeBufferCreateInfo(bufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
2152                                                                             VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2153         BufferWithMemory *p          = new BufferWithMemory(vkd, device, allocator, bci,
2154                                                             MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
2155                                                                 MemoryRequirement::DeviceAddress);
2156         return de::SharedPtr<BufferWithMemory>(p);
2157     };
2158 
2159     std::map<uint32_t, VkDeviceSize> accStrSizes;
2160     std::map<uint32_t, VkDeviceSize> vertBuffSizes;
2161     std::map<uint32_t, VkDeviceSize> indexBuffSizes;
2162 
2163     const uint32_t allStructsCount = structCount();
2164     uint32_t iterKey               = 0;
2165 
2166     uint32_t batchStructCount = m_batchStructCount;
2167     uint32_t batchVertexCount = m_batchGeomCount ? m_batchGeomCount : m_batchStructCount;
2168     uint32_t batchIndexCount  = batchVertexCount;
2169 
2170     if (!isnegz(maxBufferSize))
2171     {
2172         uint32_t batches[4];
2173         adjustBatchCount(vkd, device, m_structs, m_infos, maxBufferSize, batches);
2174         batchStructCount = batches[0];
2175         // batches[1]: batchScratchCount
2176         batchVertexCount = batches[2];
2177         batchIndexCount  = batches[3];
2178     }
2179 
2180     uint32_t iStr    = 0;
2181     uint32_t iVertex = 0;
2182     uint32_t iIndex  = 0;
2183 
2184     VkDeviceSize strSize             = 0;
2185     VkDeviceSize updateScratchSize   = 0;
2186     VkDeviceSize buildScratchSize    = 0;
2187     VkDeviceSize maxBuildScratchSize = 0;
2188     VkDeviceSize vertexSize          = 0;
2189     VkDeviceSize indexSize           = 0;
2190 
2191     VkDeviceSize strOffset    = 0;
2192     VkDeviceSize vertexOffset = 0;
2193     VkDeviceSize indexOffset  = 0;
2194 
2195     uint32_t hostStructCount   = 0;
2196     uint32_t deviceStructCount = 0;
2197 
2198     for (; iStr < allStructsCount; ++iStr)
2199     {
2200         BottomLevelAccelerationStructurePoolMember::Info info{};
2201         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iStr].get());
2202         std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
2203             str.computeBuildSize(vkd, device, m_infos[iStr].structureSize);
2204 
2205         ++(str.getBuildType() == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR ? hostStructCount : deviceStructCount);
2206 
2207         {
2208             const VkDeviceSize alignedStrSize = deAlign64(strSize, 256);
2209             const uint32_t accStrIndex        = (iStr / batchStructCount);
2210             if (iStr != 0 && (iStr % batchStructCount) == 0)
2211             {
2212                 strOffset = 0;
2213             }
2214 
2215             info.accStrIndex  = accStrIndex;
2216             info.accStrOffset = strOffset;
2217             accStrSizes[accStrIndex] += alignedStrSize;
2218             strOffset += alignedStrSize;
2219             m_structsBuffSize += alignedStrSize;
2220         }
2221 
2222         if (buildScratchSize != 0)
2223         {
2224             maxBuildScratchSize = std::max(maxBuildScratchSize, make_unsigned(deAlign64(buildScratchSize, 256u)));
2225 
2226             info.buildScratchBuffIndex  = 0;
2227             info.buildScratchBuffOffset = 0;
2228         }
2229 
2230         if (vertexSize != 0)
2231         {
2232             const VkDeviceSize alignedVertBuffSize = deAlign64(vertexSize, 8);
2233             const uint32_t vertBuffIndex           = (iVertex / batchVertexCount);
2234             if (iVertex != 0 && (iVertex % batchVertexCount) == 0)
2235             {
2236                 vertexOffset = 0;
2237             }
2238 
2239             info.vertBuffIndex  = vertBuffIndex;
2240             info.vertBuffOffset = vertexOffset;
2241             vertBuffSizes[vertBuffIndex] += alignedVertBuffSize;
2242             vertexOffset += alignedVertBuffSize;
2243             m_verticesSize += alignedVertBuffSize;
2244             iVertex += 1;
2245         }
2246 
2247         if (indexSize != 0)
2248         {
2249             const VkDeviceSize alignedIndexBuffSize = deAlign64(indexSize, 8);
2250             const uint32_t indexBuffIndex           = (iIndex / batchIndexCount);
2251             if (iIndex != 0 && (iIndex % batchIndexCount) == 0)
2252             {
2253                 indexOffset = 0;
2254             }
2255 
2256             info.indexBuffIndex  = indexBuffIndex;
2257             info.indexBuffOffset = indexOffset;
2258             indexBuffSizes[indexBuffIndex] += alignedIndexBuffSize;
2259             indexOffset += alignedIndexBuffSize;
2260             m_indicesSize += alignedIndexBuffSize;
2261             iIndex += 1;
2262         }
2263 
2264         str.preCreateSetSizesAndOffsets(info, strSize, updateScratchSize, buildScratchSize);
2265     }
2266 
2267     for (iterKey = 0; iterKey < static_cast<uint32_t>(accStrSizes.size()); ++iterKey)
2268     {
2269         m_impl->m_accellerationStructureBuffers.emplace_back(
2270             createAccellerationStructureBuffer(accStrSizes.at(iterKey)));
2271     }
2272     for (iterKey = 0; iterKey < static_cast<uint32_t>(vertBuffSizes.size()); ++iterKey)
2273     {
2274         m_impl->m_vertexBuffers.emplace_back(createVertexBuffer(vkd, device, allocator, vertBuffSizes.at(iterKey)));
2275     }
2276     for (iterKey = 0; iterKey < static_cast<uint32_t>(indexBuffSizes.size()); ++iterKey)
2277     {
2278         m_impl->m_indexBuffers.emplace_back(createIndexBuffer(vkd, device, allocator, indexBuffSizes.at(iterKey)));
2279     }
2280 
2281     if (maxBuildScratchSize)
2282     {
2283         if (hostStructCount)
2284             m_impl->m_hostScratchBuffer->resize(static_cast<size_t>(maxBuildScratchSize));
2285         if (deviceStructCount)
2286             m_impl->m_deviceScratchBuffer = createDeviceScratchBuffer(maxBuildScratchSize);
2287 
2288         m_buildsScratchSize = maxBuildScratchSize;
2289     }
2290 
2291     for (iterKey = 0; iterKey < allStructsCount; ++iterKey)
2292     {
2293         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iterKey].get());
2294         str.createAccellerationStructure(vkd, device, m_infos[iterKey].deviceAddress);
2295     }
2296 }
2297 
batchBuild(const DeviceInterface & vk,const VkDevice device,VkCommandBuffer cmdBuffer)2298 void BottomLevelAccelerationStructurePool::batchBuild(const DeviceInterface &vk, const VkDevice device,
2299                                                       VkCommandBuffer cmdBuffer)
2300 {
2301     for (const auto &str : m_structs)
2302     {
2303         str->build(vk, device, cmdBuffer);
2304     }
2305 }
2306 
batchBuild(const DeviceInterface & vk,const VkDevice device,VkCommandPool cmdPool,VkQueue queue,qpWatchDog * watchDog)2307 void BottomLevelAccelerationStructurePool::batchBuild(const DeviceInterface &vk, const VkDevice device,
2308                                                       VkCommandPool cmdPool, VkQueue queue, qpWatchDog *watchDog)
2309 {
2310     const uint32_t limit = 10000u;
2311     const uint32_t count = structCount();
2312     std::vector<BlasPtr> buildingOnDevice;
2313 
2314     auto buildOnDevice = [&]() -> void
2315     {
2316         Move<VkCommandBuffer> cmd = allocateCommandBuffer(vk, device, cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
2317 
2318         beginCommandBuffer(vk, *cmd, 0u);
2319         for (const auto &str : buildingOnDevice)
2320             str->build(vk, device, *cmd);
2321         endCommandBuffer(vk, *cmd);
2322 
2323         submitCommandsAndWait(vk, device, queue, *cmd);
2324         vk.resetCommandPool(device, cmdPool, VK_COMMAND_POOL_RESET_RELEASE_RESOURCES_BIT);
2325     };
2326 
2327     buildingOnDevice.reserve(limit);
2328     for (uint32_t i = 0; i < count; ++i)
2329     {
2330         auto str = m_structs[i];
2331 
2332         if (str->getBuildType() == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR)
2333             str->build(vk, device, DE_NULL);
2334         else
2335             buildingOnDevice.emplace_back(str);
2336 
2337         if (buildingOnDevice.size() == limit || (count - 1) == i)
2338         {
2339             buildOnDevice();
2340             buildingOnDevice.clear();
2341         }
2342 
2343         if ((i % WATCHDOG_INTERVAL) == 0 && watchDog)
2344             qpWatchDog_touch(watchDog);
2345     }
2346 }
2347 
computeBuildSize(const DeviceInterface & vk,const VkDevice device,const VkDeviceSize strSize) const2348 auto BottomLevelAccelerationStructurePoolMember::computeBuildSize(const DeviceInterface &vk, const VkDevice device,
2349                                                                   const VkDeviceSize strSize) const
2350     //              accStrSize,updateScratch,buildScratch, vertexSize, indexSize
2351     -> std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize>
2352 {
2353     DE_ASSERT(!m_geometriesData.empty() != !(strSize == 0)); // logical xor
2354 
2355     std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize> result(deAlign64(strSize, 256), 0,
2356                                                                                             0, 0, 0);
2357 
2358     if (!m_geometriesData.empty())
2359     {
2360         std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
2361         std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
2362         std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
2363         std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
2364         std::vector<uint32_t> maxPrimitiveCounts;
2365         prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
2366                           accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
2367                           maxPrimitiveCounts);
2368 
2369         const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
2370             accelerationStructureGeometriesKHR.data();
2371         const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
2372             accelerationStructureGeometriesKHRPointers.data();
2373 
2374         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
2375             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
2376             DE_NULL,                                                          //  const void* pNext;
2377             VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
2378             m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
2379             VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
2380             DE_NULL,                                        //  VkAccelerationStructureKHR srcAccelerationStructure;
2381             DE_NULL,                                        //  VkAccelerationStructureKHR dstAccelerationStructure;
2382             static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()), //  uint32_t geometryCount;
2383             m_useArrayOfPointers ?
2384                 DE_NULL :
2385                 accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
2386             m_useArrayOfPointers ? accelerationStructureGeometry :
2387                                    DE_NULL,     //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
2388             makeDeviceOrHostAddressKHR(DE_NULL) //  VkDeviceOrHostAddressKHR scratchData;
2389         };
2390 
2391         VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
2392             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
2393             DE_NULL,                                                       //  const void* pNext;
2394             0,                                                             //  VkDeviceSize accelerationStructureSize;
2395             0,                                                             //  VkDeviceSize updateScratchSize;
2396             0                                                              //  VkDeviceSize buildScratchSize;
2397         };
2398 
2399         vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
2400                                                  maxPrimitiveCounts.data(), &sizeInfo);
2401 
2402         std::get<0>(result) = sizeInfo.accelerationStructureSize;
2403         std::get<1>(result) = sizeInfo.updateScratchSize;
2404         std::get<2>(result) = sizeInfo.buildScratchSize;
2405         std::get<3>(result) = getVertexBufferSize(m_geometriesData);
2406         std::get<4>(result) = getIndexBufferSize(m_geometriesData);
2407     }
2408 
2409     return result;
2410 }
2411 
preCreateSetSizesAndOffsets(const Info & info,const VkDeviceSize accStrSize,const VkDeviceSize updateScratchSize,const VkDeviceSize buildScratchSize)2412 void BottomLevelAccelerationStructurePoolMember::preCreateSetSizesAndOffsets(const Info &info,
2413                                                                              const VkDeviceSize accStrSize,
2414                                                                              const VkDeviceSize updateScratchSize,
2415                                                                              const VkDeviceSize buildScratchSize)
2416 {
2417     m_info              = info;
2418     m_structureSize     = accStrSize;
2419     m_updateScratchSize = updateScratchSize;
2420     m_buildScratchSize  = buildScratchSize;
2421 }
2422 
createAccellerationStructure(const DeviceInterface & vk,const VkDevice device,VkDeviceAddress deviceAddress)2423 void BottomLevelAccelerationStructurePoolMember::createAccellerationStructure(const DeviceInterface &vk,
2424                                                                               const VkDevice device,
2425                                                                               VkDeviceAddress deviceAddress)
2426 {
2427     const VkAccelerationStructureTypeKHR structureType =
2428         (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
2429                            VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);
2430     const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR{
2431         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
2432         DE_NULL,                                                  //  const void* pNext;
2433         m_createFlags,                                            //  VkAccelerationStructureCreateFlagsKHR createFlags;
2434         getAccelerationStructureBuffer()->get(),                  //  VkBuffer buffer;
2435         getAccelerationStructureBufferOffset(),                   //  VkDeviceSize offset;
2436         m_structureSize,                                          //  VkDeviceSize size;
2437         structureType,                                            //  VkAccelerationStructureTypeKHR type;
2438         deviceAddress                                             //  VkDeviceAddress deviceAddress;
2439     };
2440 
2441     m_accelerationStructureKHR =
2442         createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, DE_NULL);
2443 }
2444 
~TopLevelAccelerationStructure()2445 TopLevelAccelerationStructure::~TopLevelAccelerationStructure()
2446 {
2447 }
2448 
TopLevelAccelerationStructure()2449 TopLevelAccelerationStructure::TopLevelAccelerationStructure()
2450     : m_structureSize(0u)
2451     , m_updateScratchSize(0u)
2452     , m_buildScratchSize(0u)
2453 {
2454 }
2455 
setInstanceCount(const size_t instanceCount)2456 void TopLevelAccelerationStructure::setInstanceCount(const size_t instanceCount)
2457 {
2458     m_bottomLevelInstances.reserve(instanceCount);
2459     m_instanceData.reserve(instanceCount);
2460 }
2461 
addInstance(de::SharedPtr<BottomLevelAccelerationStructure> bottomLevelStructure,const VkTransformMatrixKHR & matrix,uint32_t instanceCustomIndex,uint32_t mask,uint32_t instanceShaderBindingTableRecordOffset,VkGeometryInstanceFlagsKHR flags)2462 void TopLevelAccelerationStructure::addInstance(de::SharedPtr<BottomLevelAccelerationStructure> bottomLevelStructure,
2463                                                 const VkTransformMatrixKHR &matrix, uint32_t instanceCustomIndex,
2464                                                 uint32_t mask, uint32_t instanceShaderBindingTableRecordOffset,
2465                                                 VkGeometryInstanceFlagsKHR flags)
2466 {
2467     m_bottomLevelInstances.push_back(bottomLevelStructure);
2468     m_instanceData.push_back(
2469         InstanceData(matrix, instanceCustomIndex, mask, instanceShaderBindingTableRecordOffset, flags));
2470 }
2471 
getStructureBuildSizes() const2472 VkAccelerationStructureBuildSizesInfoKHR TopLevelAccelerationStructure::getStructureBuildSizes() const
2473 {
2474     return {
2475         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
2476         DE_NULL,                                                       //  const void* pNext;
2477         m_structureSize,                                               //  VkDeviceSize accelerationStructureSize;
2478         m_updateScratchSize,                                           //  VkDeviceSize updateScratchSize;
2479         m_buildScratchSize                                             //  VkDeviceSize buildScratchSize;
2480     };
2481 }
2482 
createAndBuild(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,VkDeviceAddress deviceAddress)2483 void TopLevelAccelerationStructure::createAndBuild(const DeviceInterface &vk, const VkDevice device,
2484                                                    const VkCommandBuffer cmdBuffer, Allocator &allocator,
2485                                                    VkDeviceAddress deviceAddress)
2486 {
2487     create(vk, device, allocator, 0u, deviceAddress);
2488     build(vk, device, cmdBuffer);
2489 }
2490 
createAndCopyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,TopLevelAccelerationStructure * accelerationStructure,VkDeviceSize compactCopySize,VkDeviceAddress deviceAddress)2491 void TopLevelAccelerationStructure::createAndCopyFrom(const DeviceInterface &vk, const VkDevice device,
2492                                                       const VkCommandBuffer cmdBuffer, Allocator &allocator,
2493                                                       TopLevelAccelerationStructure *accelerationStructure,
2494                                                       VkDeviceSize compactCopySize, VkDeviceAddress deviceAddress)
2495 {
2496     DE_ASSERT(accelerationStructure != NULL);
2497     VkDeviceSize copiedSize = compactCopySize > 0u ?
2498                                   compactCopySize :
2499                                   accelerationStructure->getStructureBuildSizes().accelerationStructureSize;
2500     DE_ASSERT(copiedSize != 0u);
2501 
2502     create(vk, device, allocator, copiedSize, deviceAddress);
2503     copyFrom(vk, device, cmdBuffer, accelerationStructure, compactCopySize > 0u);
2504 }
2505 
createAndDeserializeFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,SerialStorage * storage,VkDeviceAddress deviceAddress)2506 void TopLevelAccelerationStructure::createAndDeserializeFrom(const DeviceInterface &vk, const VkDevice device,
2507                                                              const VkCommandBuffer cmdBuffer, Allocator &allocator,
2508                                                              SerialStorage *storage, VkDeviceAddress deviceAddress)
2509 {
2510     DE_ASSERT(storage != NULL);
2511     DE_ASSERT(storage->getStorageSize() >= SerialStorage::SERIAL_STORAGE_SIZE_MIN);
2512     create(vk, device, allocator, storage->getDeserializedSize(), deviceAddress);
2513     if (storage->hasDeepFormat())
2514         createAndDeserializeBottoms(vk, device, cmdBuffer, allocator, storage);
2515     deserialize(vk, device, cmdBuffer, storage);
2516 }
2517 
createInstanceBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelInstances,std::vector<InstanceData> instanceData,const bool tryCachedMemory)2518 BufferWithMemory *createInstanceBuffer(
2519     const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
2520     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelInstances,
2521     std::vector<InstanceData> instanceData, const bool tryCachedMemory)
2522 {
2523     DE_ASSERT(bottomLevelInstances.size() != 0);
2524     DE_ASSERT(bottomLevelInstances.size() == instanceData.size());
2525     DE_UNREF(instanceData);
2526 
2527     BufferWithMemory *result           = nullptr;
2528     const VkDeviceSize bufferSizeBytes = bottomLevelInstances.size() * sizeof(VkAccelerationStructureInstanceKHR);
2529     const VkBufferCreateInfo bufferCreateInfo =
2530         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
2531                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2532     if (tryCachedMemory)
2533         try
2534         {
2535             result = new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
2536                                           MemoryRequirement::Cached | MemoryRequirement::HostVisible |
2537                                               MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress);
2538         }
2539         catch (const tcu::NotSupportedError &)
2540         {
2541             result = nullptr;
2542         }
2543     return result ? result :
2544                     new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
2545                                          MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
2546                                              MemoryRequirement::DeviceAddress);
2547 }
2548 
updateSingleInstance(const DeviceInterface & vk,const VkDevice device,const BottomLevelAccelerationStructure & bottomLevelAccelerationStructure,const InstanceData & instanceData,uint8_t * bufferLocation,VkAccelerationStructureBuildTypeKHR buildType,bool inactiveInstances)2549 void updateSingleInstance(const DeviceInterface &vk, const VkDevice device,
2550                           const BottomLevelAccelerationStructure &bottomLevelAccelerationStructure,
2551                           const InstanceData &instanceData, uint8_t *bufferLocation,
2552                           VkAccelerationStructureBuildTypeKHR buildType, bool inactiveInstances)
2553 {
2554     const VkAccelerationStructureKHR accelerationStructureKHR = *bottomLevelAccelerationStructure.getPtr();
2555 
2556     // This part needs to be fixed once a new version of the VkAccelerationStructureInstanceKHR will be added to vkStructTypes.inl
2557     VkDeviceAddress accelerationStructureAddress;
2558     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
2559     {
2560         VkAccelerationStructureDeviceAddressInfoKHR asDeviceAddressInfo = {
2561             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_DEVICE_ADDRESS_INFO_KHR, // VkStructureType sType;
2562             DE_NULL,                                                          // const void* pNext;
2563             accelerationStructureKHR // VkAccelerationStructureKHR accelerationStructure;
2564         };
2565         accelerationStructureAddress = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
2566     }
2567 
2568     uint64_t structureReference;
2569     if (inactiveInstances)
2570     {
2571         // Instances will be marked inactive by making their references VK_NULL_HANDLE or having address zero.
2572         structureReference = 0ull;
2573     }
2574     else
2575     {
2576         structureReference = (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
2577                                  uint64_t(accelerationStructureAddress) :
2578                                  uint64_t(accelerationStructureKHR.getInternal());
2579     }
2580 
2581     VkAccelerationStructureInstanceKHR accelerationStructureInstanceKHR = makeVkAccelerationStructureInstanceKHR(
2582         instanceData.matrix,                                 //  VkTransformMatrixKHR transform;
2583         instanceData.instanceCustomIndex,                    //  uint32_t instanceCustomIndex:24;
2584         instanceData.mask,                                   //  uint32_t mask:8;
2585         instanceData.instanceShaderBindingTableRecordOffset, //  uint32_t instanceShaderBindingTableRecordOffset:24;
2586         instanceData.flags,                                  //  VkGeometryInstanceFlagsKHR flags:8;
2587         structureReference                                   //  uint64_t accelerationStructureReference;
2588     );
2589 
2590     deMemcpy(bufferLocation, &accelerationStructureInstanceKHR, sizeof(VkAccelerationStructureInstanceKHR));
2591 }
2592 
updateInstanceBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelInstances,const std::vector<InstanceData> & instanceData,const BufferWithMemory * instanceBuffer,VkAccelerationStructureBuildTypeKHR buildType,bool inactiveInstances)2593 void updateInstanceBuffer(const DeviceInterface &vk, const VkDevice device,
2594                           const std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelInstances,
2595                           const std::vector<InstanceData> &instanceData, const BufferWithMemory *instanceBuffer,
2596                           VkAccelerationStructureBuildTypeKHR buildType, bool inactiveInstances)
2597 {
2598     DE_ASSERT(bottomLevelInstances.size() != 0);
2599     DE_ASSERT(bottomLevelInstances.size() == instanceData.size());
2600 
2601     auto &instancesAlloc      = instanceBuffer->getAllocation();
2602     auto bufferStart          = reinterpret_cast<uint8_t *>(instancesAlloc.getHostPtr());
2603     VkDeviceSize bufferOffset = 0ull;
2604 
2605     for (size_t instanceNdx = 0; instanceNdx < bottomLevelInstances.size(); ++instanceNdx)
2606     {
2607         const auto &blas = *bottomLevelInstances[instanceNdx];
2608         updateSingleInstance(vk, device, blas, instanceData[instanceNdx], bufferStart + bufferOffset, buildType,
2609                              inactiveInstances);
2610         bufferOffset += sizeof(VkAccelerationStructureInstanceKHR);
2611     }
2612 
2613     flushMappedMemoryRange(vk, device, instancesAlloc.getMemory(), instancesAlloc.getOffset(), VK_WHOLE_SIZE);
2614 }
2615 
2616 class TopLevelAccelerationStructureKHR : public TopLevelAccelerationStructure
2617 {
2618 public:
2619     static uint32_t getRequiredAllocationCount(void);
2620 
2621     TopLevelAccelerationStructureKHR();
2622     TopLevelAccelerationStructureKHR(const TopLevelAccelerationStructureKHR &other) = delete;
2623     virtual ~TopLevelAccelerationStructureKHR();
2624 
2625     void setBuildType(const VkAccelerationStructureBuildTypeKHR buildType) override;
2626     void setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags) override;
2627     void setCreateGeneric(bool createGeneric) override;
2628     void setCreationBufferUnbounded(bool creationBufferUnbounded) override;
2629     void setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags) override;
2630     void setBuildWithoutPrimitives(bool buildWithoutPrimitives) override;
2631     void setInactiveInstances(bool inactiveInstances) override;
2632     void setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount) override;
2633     void setUseArrayOfPointers(const bool useArrayOfPointers) override;
2634     void setIndirectBuildParameters(const VkBuffer indirectBuffer, const VkDeviceSize indirectBufferOffset,
2635                                     const uint32_t indirectBufferStride) override;
2636     void setUsePPGeometries(const bool usePPGeometries) override;
2637     void setTryCachedMemory(const bool tryCachedMemory) override;
2638     VkBuildAccelerationStructureFlagsKHR getBuildFlags() const override;
2639 
2640     void getCreationSizes(const DeviceInterface &vk, const VkDevice device, const VkDeviceSize structureSize,
2641                           CreationSizes &sizes) override;
2642     void create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator, VkDeviceSize structureSize,
2643                 VkDeviceAddress deviceAddress = 0u, const void *pNext = DE_NULL,
2644                 const MemoryRequirement &addMemoryRequirement = MemoryRequirement::Any,
2645                 const VkBuffer creationBuffer = VK_NULL_HANDLE, const VkDeviceSize creationBufferSize = 0u) override;
2646     void build(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
2647                TopLevelAccelerationStructure *srcAccelerationStructure = DE_NULL) override;
2648     void copyFrom(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
2649                   TopLevelAccelerationStructure *accelerationStructure, bool compactCopy) override;
2650     void serialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
2651                    SerialStorage *storage) override;
2652     void deserialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
2653                      SerialStorage *storage) override;
2654 
2655     std::vector<VkDeviceSize> getSerializingSizes(const DeviceInterface &vk, const VkDevice device, const VkQueue queue,
2656                                                   const uint32_t queueFamilyIndex) override;
2657 
2658     std::vector<uint64_t> getSerializingAddresses(const DeviceInterface &vk, const VkDevice device) const override;
2659 
2660     const VkAccelerationStructureKHR *getPtr(void) const override;
2661 
2662     void updateInstanceMatrix(const DeviceInterface &vk, const VkDevice device, size_t instanceIndex,
2663                               const VkTransformMatrixKHR &matrix) override;
2664 
2665 protected:
2666     VkAccelerationStructureBuildTypeKHR m_buildType;
2667     VkAccelerationStructureCreateFlagsKHR m_createFlags;
2668     bool m_createGeneric;
2669     bool m_creationBufferUnbounded;
2670     VkBuildAccelerationStructureFlagsKHR m_buildFlags;
2671     bool m_buildWithoutPrimitives;
2672     bool m_inactiveInstances;
2673     bool m_deferredOperation;
2674     uint32_t m_workerThreadCount;
2675     bool m_useArrayOfPointers;
2676     de::MovePtr<BufferWithMemory> m_accelerationStructureBuffer;
2677     de::MovePtr<BufferWithMemory> m_instanceBuffer;
2678     de::MovePtr<BufferWithMemory> m_instanceAddressBuffer;
2679     de::MovePtr<BufferWithMemory> m_deviceScratchBuffer;
2680     std::vector<uint8_t> m_hostScratchBuffer;
2681     Move<VkAccelerationStructureKHR> m_accelerationStructureKHR;
2682     VkBuffer m_indirectBuffer;
2683     VkDeviceSize m_indirectBufferOffset;
2684     uint32_t m_indirectBufferStride;
2685     bool m_usePPGeometries;
2686     bool m_tryCachedMemory;
2687 
2688     void prepareInstances(const DeviceInterface &vk, const VkDevice device,
2689                           VkAccelerationStructureGeometryKHR &accelerationStructureGeometryKHR,
2690                           std::vector<uint32_t> &maxPrimitiveCounts);
2691 
2692     void serializeBottoms(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
2693                           SerialStorage *storage, VkDeferredOperationKHR deferredOperation);
2694 
2695     void createAndDeserializeBottoms(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
2696                                      Allocator &allocator, SerialStorage *storage) override;
2697 };
2698 
getRequiredAllocationCount(void)2699 uint32_t TopLevelAccelerationStructureKHR::getRequiredAllocationCount(void)
2700 {
2701     /*
2702         de::MovePtr<BufferWithMemory>                            m_instanceBuffer;
2703         de::MovePtr<Allocation>                                    m_accelerationStructureAlloc;
2704         de::MovePtr<BufferWithMemory>                            m_deviceScratchBuffer;
2705     */
2706     return 3u;
2707 }
2708 
TopLevelAccelerationStructureKHR()2709 TopLevelAccelerationStructureKHR::TopLevelAccelerationStructureKHR()
2710     : TopLevelAccelerationStructure()
2711     , m_buildType(VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
2712     , m_createFlags(0u)
2713     , m_createGeneric(false)
2714     , m_creationBufferUnbounded(false)
2715     , m_buildFlags(0u)
2716     , m_buildWithoutPrimitives(false)
2717     , m_inactiveInstances(false)
2718     , m_deferredOperation(false)
2719     , m_workerThreadCount(0)
2720     , m_useArrayOfPointers(false)
2721     , m_accelerationStructureBuffer(DE_NULL)
2722     , m_instanceBuffer(DE_NULL)
2723     , m_instanceAddressBuffer(DE_NULL)
2724     , m_deviceScratchBuffer(DE_NULL)
2725     , m_accelerationStructureKHR()
2726     , m_indirectBuffer(DE_NULL)
2727     , m_indirectBufferOffset(0)
2728     , m_indirectBufferStride(0)
2729     , m_usePPGeometries(false)
2730     , m_tryCachedMemory(true)
2731 {
2732 }
2733 
~TopLevelAccelerationStructureKHR()2734 TopLevelAccelerationStructureKHR::~TopLevelAccelerationStructureKHR()
2735 {
2736 }
2737 
setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)2738 void TopLevelAccelerationStructureKHR::setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)
2739 {
2740     m_buildType = buildType;
2741 }
2742 
setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)2743 void TopLevelAccelerationStructureKHR::setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)
2744 {
2745     m_createFlags = createFlags;
2746 }
2747 
setCreateGeneric(bool createGeneric)2748 void TopLevelAccelerationStructureKHR::setCreateGeneric(bool createGeneric)
2749 {
2750     m_createGeneric = createGeneric;
2751 }
2752 
setCreationBufferUnbounded(bool creationBufferUnbounded)2753 void TopLevelAccelerationStructureKHR::setCreationBufferUnbounded(bool creationBufferUnbounded)
2754 {
2755     m_creationBufferUnbounded = creationBufferUnbounded;
2756 }
2757 
setInactiveInstances(bool inactiveInstances)2758 void TopLevelAccelerationStructureKHR::setInactiveInstances(bool inactiveInstances)
2759 {
2760     m_inactiveInstances = inactiveInstances;
2761 }
2762 
setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)2763 void TopLevelAccelerationStructureKHR::setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)
2764 {
2765     m_buildFlags = buildFlags;
2766 }
2767 
setBuildWithoutPrimitives(bool buildWithoutPrimitives)2768 void TopLevelAccelerationStructureKHR::setBuildWithoutPrimitives(bool buildWithoutPrimitives)
2769 {
2770     m_buildWithoutPrimitives = buildWithoutPrimitives;
2771 }
2772 
setDeferredOperation(const bool deferredOperation,const uint32_t workerThreadCount)2773 void TopLevelAccelerationStructureKHR::setDeferredOperation(const bool deferredOperation,
2774                                                             const uint32_t workerThreadCount)
2775 {
2776     m_deferredOperation = deferredOperation;
2777     m_workerThreadCount = workerThreadCount;
2778 }
2779 
setUseArrayOfPointers(const bool useArrayOfPointers)2780 void TopLevelAccelerationStructureKHR::setUseArrayOfPointers(const bool useArrayOfPointers)
2781 {
2782     m_useArrayOfPointers = useArrayOfPointers;
2783 }
2784 
setUsePPGeometries(const bool usePPGeometries)2785 void TopLevelAccelerationStructureKHR::setUsePPGeometries(const bool usePPGeometries)
2786 {
2787     m_usePPGeometries = usePPGeometries;
2788 }
2789 
setTryCachedMemory(const bool tryCachedMemory)2790 void TopLevelAccelerationStructureKHR::setTryCachedMemory(const bool tryCachedMemory)
2791 {
2792     m_tryCachedMemory = tryCachedMemory;
2793 }
2794 
setIndirectBuildParameters(const VkBuffer indirectBuffer,const VkDeviceSize indirectBufferOffset,const uint32_t indirectBufferStride)2795 void TopLevelAccelerationStructureKHR::setIndirectBuildParameters(const VkBuffer indirectBuffer,
2796                                                                   const VkDeviceSize indirectBufferOffset,
2797                                                                   const uint32_t indirectBufferStride)
2798 {
2799     m_indirectBuffer       = indirectBuffer;
2800     m_indirectBufferOffset = indirectBufferOffset;
2801     m_indirectBufferStride = indirectBufferStride;
2802 }
2803 
getBuildFlags() const2804 VkBuildAccelerationStructureFlagsKHR TopLevelAccelerationStructureKHR::getBuildFlags() const
2805 {
2806     return m_buildFlags;
2807 }
2808 
sum() const2809 VkDeviceSize TopLevelAccelerationStructure::CreationSizes::sum() const
2810 {
2811     return structure + updateScratch + buildScratch + instancePointers + instancesBuffer;
2812 }
2813 
getCreationSizes(const DeviceInterface & vk,const VkDevice device,const VkDeviceSize structureSize,CreationSizes & sizes)2814 void TopLevelAccelerationStructureKHR::getCreationSizes(const DeviceInterface &vk, const VkDevice device,
2815                                                         const VkDeviceSize structureSize, CreationSizes &sizes)
2816 {
2817     // AS may be built from geometries using vkCmdBuildAccelerationStructureKHR / vkBuildAccelerationStructureKHR
2818     // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
2819     DE_ASSERT(!m_bottomLevelInstances.empty() != !(structureSize == 0)); // logical xor
2820 
2821     if (structureSize == 0)
2822     {
2823         VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
2824         const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
2825         std::vector<uint32_t> maxPrimitiveCounts;
2826         prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);
2827 
2828         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
2829             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
2830             DE_NULL,                                                          //  const void* pNext;
2831             VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
2832             m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
2833             VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
2834             DE_NULL,                                        //  VkAccelerationStructureKHR srcAccelerationStructure;
2835             DE_NULL,                                        //  VkAccelerationStructureKHR dstAccelerationStructure;
2836             1u,                                             //  uint32_t geometryCount;
2837             (m_usePPGeometries ?
2838                  nullptr :
2839                  &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
2840             (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
2841                                  nullptr),      //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
2842             makeDeviceOrHostAddressKHR(DE_NULL) //  VkDeviceOrHostAddressKHR scratchData;
2843         };
2844 
2845         VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
2846             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
2847             DE_NULL,                                                       //  const void* pNext;
2848             0,                                                             //  VkDeviceSize accelerationStructureSize;
2849             0,                                                             //  VkDeviceSize updateScratchSize;
2850             0                                                              //  VkDeviceSize buildScratchSize;
2851         };
2852 
2853         vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
2854                                                  maxPrimitiveCounts.data(), &sizeInfo);
2855 
2856         sizes.structure     = sizeInfo.accelerationStructureSize;
2857         sizes.updateScratch = sizeInfo.updateScratchSize;
2858         sizes.buildScratch  = sizeInfo.buildScratchSize;
2859     }
2860     else
2861     {
2862         sizes.structure     = structureSize;
2863         sizes.updateScratch = 0u;
2864         sizes.buildScratch  = 0u;
2865     }
2866 
2867     sizes.instancePointers = 0u;
2868     if (m_useArrayOfPointers)
2869     {
2870         const size_t pointerSize = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
2871                                        sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress) :
2872                                        sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
2873         sizes.instancePointers   = static_cast<VkDeviceSize>(m_bottomLevelInstances.size() * pointerSize);
2874     }
2875 
2876     sizes.instancesBuffer = m_bottomLevelInstances.empty() ?
2877                                 0u :
2878                                 m_bottomLevelInstances.size() * sizeof(VkAccelerationStructureInstanceKHR);
2879 }
2880 
create(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,VkDeviceSize structureSize,VkDeviceAddress deviceAddress,const void * pNext,const MemoryRequirement & addMemoryRequirement,const VkBuffer creationBuffer,const VkDeviceSize creationBufferSize)2881 void TopLevelAccelerationStructureKHR::create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
2882                                               VkDeviceSize structureSize, VkDeviceAddress deviceAddress,
2883                                               const void *pNext, const MemoryRequirement &addMemoryRequirement,
2884                                               const VkBuffer creationBuffer, const VkDeviceSize creationBufferSize)
2885 {
2886     // AS may be built from geometries using vkCmdBuildAccelerationStructureKHR / vkBuildAccelerationStructureKHR
2887     // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
2888     DE_ASSERT(!m_bottomLevelInstances.empty() != !(structureSize == 0)); // logical xor
2889 
2890     if (structureSize == 0)
2891     {
2892         VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
2893         const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
2894         std::vector<uint32_t> maxPrimitiveCounts;
2895         prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);
2896 
2897         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
2898             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
2899             DE_NULL,                                                          //  const void* pNext;
2900             VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
2901             m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
2902             VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
2903             DE_NULL,                                        //  VkAccelerationStructureKHR srcAccelerationStructure;
2904             DE_NULL,                                        //  VkAccelerationStructureKHR dstAccelerationStructure;
2905             1u,                                             //  uint32_t geometryCount;
2906             (m_usePPGeometries ?
2907                  nullptr :
2908                  &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
2909             (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
2910                                  nullptr),      //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
2911             makeDeviceOrHostAddressKHR(DE_NULL) //  VkDeviceOrHostAddressKHR scratchData;
2912         };
2913 
2914         VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
2915             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
2916             DE_NULL,                                                       //  const void* pNext;
2917             0,                                                             //  VkDeviceSize accelerationStructureSize;
2918             0,                                                             //  VkDeviceSize updateScratchSize;
2919             0                                                              //  VkDeviceSize buildScratchSize;
2920         };
2921 
2922         vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
2923                                                  maxPrimitiveCounts.data(), &sizeInfo);
2924 
2925         m_structureSize     = sizeInfo.accelerationStructureSize;
2926         m_updateScratchSize = sizeInfo.updateScratchSize;
2927         m_buildScratchSize  = sizeInfo.buildScratchSize;
2928     }
2929     else
2930     {
2931         m_structureSize     = structureSize;
2932         m_updateScratchSize = 0u;
2933         m_buildScratchSize  = 0u;
2934     }
2935 
2936     const bool externalCreationBuffer = (creationBuffer != VK_NULL_HANDLE);
2937 
2938     if (externalCreationBuffer)
2939     {
2940         DE_UNREF(creationBufferSize); // For release builds.
2941         DE_ASSERT(creationBufferSize >= m_structureSize);
2942     }
2943 
2944     if (!externalCreationBuffer)
2945     {
2946         const VkBufferCreateInfo bufferCreateInfo =
2947             makeBufferCreateInfo(m_structureSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
2948                                                       VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2949         const MemoryRequirement memoryRequirement = addMemoryRequirement | MemoryRequirement::HostVisible |
2950                                                     MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
2951         const bool bindMemOnCreation = (!m_creationBufferUnbounded);
2952 
2953         try
2954         {
2955             m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
2956                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
2957                                      (MemoryRequirement::Cached | memoryRequirement), bindMemOnCreation));
2958         }
2959         catch (const tcu::NotSupportedError &)
2960         {
2961             // retry without Cached flag
2962             m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
2963                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement, bindMemOnCreation));
2964         }
2965     }
2966 
2967     const auto createInfoBuffer = (externalCreationBuffer ? creationBuffer : m_accelerationStructureBuffer->get());
2968     {
2969         const VkAccelerationStructureTypeKHR structureType =
2970             (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
2971                                VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR);
2972         const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR = {
2973             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
2974             pNext,                                                    //  const void* pNext;
2975             m_createFlags,    //  VkAccelerationStructureCreateFlagsKHR createFlags;
2976             createInfoBuffer, //  VkBuffer buffer;
2977             0u,               //  VkDeviceSize offset;
2978             m_structureSize,  //  VkDeviceSize size;
2979             structureType,    //  VkAccelerationStructureTypeKHR type;
2980             deviceAddress     //  VkDeviceAddress deviceAddress;
2981         };
2982 
2983         m_accelerationStructureKHR =
2984             createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, DE_NULL);
2985 
2986         // Make sure buffer memory is always bound after creation.
2987         if (!externalCreationBuffer)
2988             m_accelerationStructureBuffer->bindMemory();
2989     }
2990 
2991     if (m_buildScratchSize > 0u)
2992     {
2993         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
2994         {
2995             const VkBufferCreateInfo bufferCreateInfo = makeBufferCreateInfo(
2996                 m_buildScratchSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2997             m_deviceScratchBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
2998                 vk, device, allocator, bufferCreateInfo,
2999                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
3000         }
3001         else
3002         {
3003             m_hostScratchBuffer.resize(static_cast<size_t>(m_buildScratchSize));
3004         }
3005     }
3006 
3007     if (m_useArrayOfPointers)
3008     {
3009         const size_t pointerSize = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
3010                                        sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress) :
3011                                        sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
3012         const VkBufferCreateInfo bufferCreateInfo =
3013             makeBufferCreateInfo(static_cast<VkDeviceSize>(m_bottomLevelInstances.size() * pointerSize),
3014                                  VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
3015                                      VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
3016         m_instanceAddressBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
3017             vk, device, allocator, bufferCreateInfo,
3018             MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
3019     }
3020 
3021     if (!m_bottomLevelInstances.empty())
3022         m_instanceBuffer = de::MovePtr<BufferWithMemory>(
3023             createInstanceBuffer(vk, device, allocator, m_bottomLevelInstances, m_instanceData, m_tryCachedMemory));
3024 }
3025 
updateInstanceMatrix(const DeviceInterface & vk,const VkDevice device,size_t instanceIndex,const VkTransformMatrixKHR & matrix)3026 void TopLevelAccelerationStructureKHR::updateInstanceMatrix(const DeviceInterface &vk, const VkDevice device,
3027                                                             size_t instanceIndex, const VkTransformMatrixKHR &matrix)
3028 {
3029     DE_ASSERT(instanceIndex < m_bottomLevelInstances.size());
3030     DE_ASSERT(instanceIndex < m_instanceData.size());
3031 
3032     const auto &blas          = *m_bottomLevelInstances[instanceIndex];
3033     auto &instanceData        = m_instanceData[instanceIndex];
3034     auto &instancesAlloc      = m_instanceBuffer->getAllocation();
3035     auto bufferStart          = reinterpret_cast<uint8_t *>(instancesAlloc.getHostPtr());
3036     VkDeviceSize bufferOffset = sizeof(VkAccelerationStructureInstanceKHR) * instanceIndex;
3037 
3038     instanceData.matrix = matrix;
3039     updateSingleInstance(vk, device, blas, instanceData, bufferStart + bufferOffset, m_buildType, m_inactiveInstances);
3040     flushMappedMemoryRange(vk, device, instancesAlloc.getMemory(), instancesAlloc.getOffset(), VK_WHOLE_SIZE);
3041 }
3042 
build(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,TopLevelAccelerationStructure * srcAccelerationStructure)3043 void TopLevelAccelerationStructureKHR::build(const DeviceInterface &vk, const VkDevice device,
3044                                              const VkCommandBuffer cmdBuffer,
3045                                              TopLevelAccelerationStructure *srcAccelerationStructure)
3046 {
3047     DE_ASSERT(!m_bottomLevelInstances.empty());
3048     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
3049     DE_ASSERT(m_buildScratchSize != 0);
3050 
3051     updateInstanceBuffer(vk, device, m_bottomLevelInstances, m_instanceData, m_instanceBuffer.get(), m_buildType,
3052                          m_inactiveInstances);
3053 
3054     VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
3055     const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
3056     std::vector<uint32_t> maxPrimitiveCounts;
3057     prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);
3058 
3059     VkDeviceOrHostAddressKHR scratchData = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
3060                                                makeDeviceOrHostAddressKHR(vk, device, m_deviceScratchBuffer->get(), 0) :
3061                                                makeDeviceOrHostAddressKHR(m_hostScratchBuffer.data());
3062 
3063     VkAccelerationStructureKHR srcStructure =
3064         (srcAccelerationStructure != DE_NULL) ? *(srcAccelerationStructure->getPtr()) : DE_NULL;
3065     VkBuildAccelerationStructureModeKHR mode = (srcAccelerationStructure != DE_NULL) ?
3066                                                    VK_BUILD_ACCELERATION_STRUCTURE_MODE_UPDATE_KHR :
3067                                                    VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR;
3068 
3069     VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
3070         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
3071         DE_NULL,                                                          //  const void* pNext;
3072         VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
3073         m_buildFlags,                     //  VkBuildAccelerationStructureFlagsKHR flags;
3074         mode,                             //  VkBuildAccelerationStructureModeKHR mode;
3075         srcStructure,                     //  VkAccelerationStructureKHR srcAccelerationStructure;
3076         m_accelerationStructureKHR.get(), //  VkAccelerationStructureKHR dstAccelerationStructure;
3077         1u,                               //  uint32_t geometryCount;
3078         (m_usePPGeometries ?
3079              nullptr :
3080              &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
3081         (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
3082                              nullptr), //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
3083         scratchData                    //  VkDeviceOrHostAddressKHR scratchData;
3084     };
3085 
3086     const uint32_t primitiveCount =
3087         (m_buildWithoutPrimitives ? 0u : static_cast<uint32_t>(m_bottomLevelInstances.size()));
3088 
3089     VkAccelerationStructureBuildRangeInfoKHR accelerationStructureBuildRangeInfoKHR = {
3090         primitiveCount, //  uint32_t primitiveCount;
3091         0,              //  uint32_t primitiveOffset;
3092         0,              //  uint32_t firstVertex;
3093         0               //  uint32_t transformOffset;
3094     };
3095     VkAccelerationStructureBuildRangeInfoKHR *accelerationStructureBuildRangeInfoKHRPtr =
3096         &accelerationStructureBuildRangeInfoKHR;
3097 
3098     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3099     {
3100         if (m_indirectBuffer == DE_NULL)
3101             vk.cmdBuildAccelerationStructuresKHR(
3102                 cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
3103                 (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);
3104         else
3105         {
3106             VkDeviceAddress indirectDeviceAddress =
3107                 getBufferDeviceAddress(vk, device, m_indirectBuffer, m_indirectBufferOffset);
3108             uint32_t *pMaxPrimitiveCounts = maxPrimitiveCounts.data();
3109             vk.cmdBuildAccelerationStructuresIndirectKHR(cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
3110                                                          &indirectDeviceAddress, &m_indirectBufferStride,
3111                                                          &pMaxPrimitiveCounts);
3112         }
3113     }
3114     else if (!m_deferredOperation)
3115     {
3116         VK_CHECK(vk.buildAccelerationStructuresKHR(
3117             device, DE_NULL, 1u, &accelerationStructureBuildGeometryInfoKHR,
3118             (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr));
3119     }
3120     else
3121     {
3122         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
3123         const auto deferredOperation    = deferredOperationPtr.get();
3124 
3125         VkResult result = vk.buildAccelerationStructuresKHR(
3126             device, deferredOperation, 1u, &accelerationStructureBuildGeometryInfoKHR,
3127             (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);
3128 
3129         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3130                   result == VK_SUCCESS);
3131 
3132         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
3133                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3134 
3135         accelerationStructureBuildGeometryInfoKHR.pNext = DE_NULL;
3136     }
3137 
3138     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3139     {
3140         const VkAccessFlags accessMasks =
3141             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
3142         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
3143 
3144         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
3145                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
3146     }
3147 }
3148 
copyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,TopLevelAccelerationStructure * accelerationStructure,bool compactCopy)3149 void TopLevelAccelerationStructureKHR::copyFrom(const DeviceInterface &vk, const VkDevice device,
3150                                                 const VkCommandBuffer cmdBuffer,
3151                                                 TopLevelAccelerationStructure *accelerationStructure, bool compactCopy)
3152 {
3153     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
3154     DE_ASSERT(accelerationStructure != DE_NULL);
3155 
3156     VkCopyAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
3157         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
3158         DE_NULL,                                                // const void* pNext;
3159         *(accelerationStructure->getPtr()),                     // VkAccelerationStructureKHR src;
3160         *(getPtr()),                                            // VkAccelerationStructureKHR dst;
3161         compactCopy ? VK_COPY_ACCELERATION_STRUCTURE_MODE_COMPACT_KHR :
3162                       VK_COPY_ACCELERATION_STRUCTURE_MODE_CLONE_KHR // VkCopyAccelerationStructureModeKHR mode;
3163     };
3164 
3165     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3166     {
3167         vk.cmdCopyAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
3168     }
3169     else if (!m_deferredOperation)
3170     {
3171         VK_CHECK(vk.copyAccelerationStructureKHR(device, DE_NULL, &copyAccelerationStructureInfo));
3172     }
3173     else
3174     {
3175         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
3176         const auto deferredOperation    = deferredOperationPtr.get();
3177 
3178         VkResult result = vk.copyAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
3179 
3180         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3181                   result == VK_SUCCESS);
3182 
3183         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
3184                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3185     }
3186 
3187     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3188     {
3189         const VkAccessFlags accessMasks =
3190             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
3191         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
3192 
3193         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
3194                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
3195     }
3196 }
3197 
serialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)3198 void TopLevelAccelerationStructureKHR::serialize(const DeviceInterface &vk, const VkDevice device,
3199                                                  const VkCommandBuffer cmdBuffer, SerialStorage *storage)
3200 {
3201     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
3202     DE_ASSERT(storage != DE_NULL);
3203 
3204     const VkCopyAccelerationStructureToMemoryInfoKHR copyAccelerationStructureInfo = {
3205         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_TO_MEMORY_INFO_KHR, // VkStructureType sType;
3206         DE_NULL,                                                          // const void* pNext;
3207         *(getPtr()),                                                      // VkAccelerationStructureKHR src;
3208         storage->getAddress(vk, device, m_buildType),                     // VkDeviceOrHostAddressKHR dst;
3209         VK_COPY_ACCELERATION_STRUCTURE_MODE_SERIALIZE_KHR                 // VkCopyAccelerationStructureModeKHR mode;
3210     };
3211 
3212     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3213     {
3214         vk.cmdCopyAccelerationStructureToMemoryKHR(cmdBuffer, &copyAccelerationStructureInfo);
3215         if (storage->hasDeepFormat())
3216             serializeBottoms(vk, device, cmdBuffer, storage, DE_NULL);
3217     }
3218     else if (!m_deferredOperation)
3219     {
3220         VK_CHECK(vk.copyAccelerationStructureToMemoryKHR(device, DE_NULL, &copyAccelerationStructureInfo));
3221         if (storage->hasDeepFormat())
3222             serializeBottoms(vk, device, cmdBuffer, storage, DE_NULL);
3223     }
3224     else
3225     {
3226         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
3227         const auto deferredOperation    = deferredOperationPtr.get();
3228 
3229         const VkResult result =
3230             vk.copyAccelerationStructureToMemoryKHR(device, deferredOperation, &copyAccelerationStructureInfo);
3231 
3232         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3233                   result == VK_SUCCESS);
3234         if (storage->hasDeepFormat())
3235             serializeBottoms(vk, device, cmdBuffer, storage, deferredOperation);
3236 
3237         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
3238                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3239     }
3240 }
3241 
deserialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)3242 void TopLevelAccelerationStructureKHR::deserialize(const DeviceInterface &vk, const VkDevice device,
3243                                                    const VkCommandBuffer cmdBuffer, SerialStorage *storage)
3244 {
3245     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
3246     DE_ASSERT(storage != DE_NULL);
3247 
3248     const VkCopyMemoryToAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
3249         VK_STRUCTURE_TYPE_COPY_MEMORY_TO_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
3250         DE_NULL,                                                          // const void* pNext;
3251         storage->getAddressConst(vk, device, m_buildType),                // VkDeviceOrHostAddressConstKHR src;
3252         *(getPtr()),                                                      // VkAccelerationStructureKHR dst;
3253         VK_COPY_ACCELERATION_STRUCTURE_MODE_DESERIALIZE_KHR               // VkCopyAccelerationStructureModeKHR mode;
3254     };
3255 
3256     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3257     {
3258         vk.cmdCopyMemoryToAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
3259     }
3260     else if (!m_deferredOperation)
3261     {
3262         VK_CHECK(vk.copyMemoryToAccelerationStructureKHR(device, DE_NULL, &copyAccelerationStructureInfo));
3263     }
3264     else
3265     {
3266         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
3267         const auto deferredOperation    = deferredOperationPtr.get();
3268 
3269         const VkResult result =
3270             vk.copyMemoryToAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
3271 
3272         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3273                   result == VK_SUCCESS);
3274 
3275         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
3276                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3277     }
3278 
3279     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3280     {
3281         const VkAccessFlags accessMasks =
3282             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
3283         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
3284 
3285         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
3286                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
3287     }
3288 }
3289 
serializeBottoms(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage,VkDeferredOperationKHR deferredOperation)3290 void TopLevelAccelerationStructureKHR::serializeBottoms(const DeviceInterface &vk, const VkDevice device,
3291                                                         const VkCommandBuffer cmdBuffer, SerialStorage *storage,
3292                                                         VkDeferredOperationKHR deferredOperation)
3293 {
3294     DE_UNREF(deferredOperation);
3295     DE_ASSERT(storage->hasDeepFormat());
3296 
3297     const std::vector<uint64_t> &addresses = storage->getSerialInfo().addresses();
3298     const std::size_t cbottoms             = m_bottomLevelInstances.size();
3299 
3300     uint32_t storageIndex = 0;
3301     std::vector<uint64_t> matches;
3302 
3303     for (std::size_t i = 0; i < cbottoms; ++i)
3304     {
3305         const uint64_t &lookAddr = addresses[i + 1];
3306         auto end                 = matches.end();
3307         auto match = std::find_if(matches.begin(), end, [&](const uint64_t &item) { return item == lookAddr; });
3308         if (match == end)
3309         {
3310             matches.emplace_back(lookAddr);
3311             m_bottomLevelInstances[i].get()->serialize(vk, device, cmdBuffer,
3312                                                        storage->getBottomStorage(storageIndex).get());
3313             storageIndex += 1;
3314         }
3315     }
3316 }
3317 
createAndDeserializeBottoms(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,SerialStorage * storage)3318 void TopLevelAccelerationStructureKHR::createAndDeserializeBottoms(const DeviceInterface &vk, const VkDevice device,
3319                                                                    const VkCommandBuffer cmdBuffer,
3320                                                                    Allocator &allocator, SerialStorage *storage)
3321 {
3322     DE_ASSERT(storage->hasDeepFormat());
3323     DE_ASSERT(m_bottomLevelInstances.size() == 0);
3324 
3325     const std::vector<uint64_t> &addresses = storage->getSerialInfo().addresses();
3326     const std::size_t cbottoms             = addresses.size() - 1;
3327     uint32_t storageIndex                  = 0;
3328     std::vector<std::pair<uint64_t, std::size_t>> matches;
3329 
3330     for (std::size_t i = 0; i < cbottoms; ++i)
3331     {
3332         const uint64_t &lookAddr = addresses[i + 1];
3333         auto end                 = matches.end();
3334         auto match               = std::find_if(matches.begin(), end,
3335                                                 [&](const std::pair<uint64_t, std::size_t> &item) { return item.first == lookAddr; });
3336         if (match != end)
3337         {
3338             m_bottomLevelInstances.emplace_back(m_bottomLevelInstances[match->second]);
3339         }
3340         else
3341         {
3342             de::MovePtr<BottomLevelAccelerationStructure> blas = makeBottomLevelAccelerationStructure();
3343             blas->createAndDeserializeFrom(vk, device, cmdBuffer, allocator,
3344                                            storage->getBottomStorage(storageIndex).get());
3345             m_bottomLevelInstances.emplace_back(de::SharedPtr<BottomLevelAccelerationStructure>(blas.release()));
3346             matches.emplace_back(lookAddr, i);
3347             storageIndex += 1;
3348         }
3349     }
3350 
3351     std::vector<uint64_t> newAddresses = getSerializingAddresses(vk, device);
3352     DE_ASSERT(addresses.size() == newAddresses.size());
3353 
3354     SerialStorage::AccelerationStructureHeader *header = storage->getASHeader();
3355     DE_ASSERT(cbottoms == header->handleCount);
3356 
3357     // finally update bottom-level AS addresses before top-level AS deserialization
3358     for (std::size_t i = 0; i < cbottoms; ++i)
3359     {
3360         header->handleArray[i] = newAddresses[i + 1];
3361     }
3362 }
3363 
getSerializingSizes(const DeviceInterface & vk,const VkDevice device,const VkQueue queue,const uint32_t queueFamilyIndex)3364 std::vector<VkDeviceSize> TopLevelAccelerationStructureKHR::getSerializingSizes(const DeviceInterface &vk,
3365                                                                                 const VkDevice device,
3366                                                                                 const VkQueue queue,
3367                                                                                 const uint32_t queueFamilyIndex)
3368 {
3369     const uint32_t queryCount(uint32_t(m_bottomLevelInstances.size()) + 1);
3370     std::vector<VkAccelerationStructureKHR> handles(queryCount);
3371     std::vector<VkDeviceSize> sizes(queryCount);
3372 
3373     handles[0] = m_accelerationStructureKHR.get();
3374 
3375     for (uint32_t h = 1; h < queryCount; ++h)
3376         handles[h] = *m_bottomLevelInstances[h - 1].get()->getPtr();
3377 
3378     if (VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR == m_buildType)
3379         queryAccelerationStructureSize(vk, device, DE_NULL, handles, m_buildType, DE_NULL,
3380                                        VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, 0u, sizes);
3381     else
3382     {
3383         const Move<VkCommandPool> cmdPool = createCommandPool(vk, device, 0, queueFamilyIndex);
3384         const Move<VkCommandBuffer> cmdBuffer =
3385             allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
3386         const Move<VkQueryPool> queryPool =
3387             makeQueryPool(vk, device, VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, queryCount);
3388 
3389         beginCommandBuffer(vk, *cmdBuffer);
3390         queryAccelerationStructureSize(vk, device, *cmdBuffer, handles, m_buildType, *queryPool,
3391                                        VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, 0u, sizes);
3392         endCommandBuffer(vk, *cmdBuffer);
3393         submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
3394 
3395         VK_CHECK(vk.getQueryPoolResults(device, *queryPool, 0u, queryCount, queryCount * sizeof(VkDeviceSize),
3396                                         sizes.data(), sizeof(VkDeviceSize),
3397                                         VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT));
3398     }
3399 
3400     return sizes;
3401 }
3402 
getSerializingAddresses(const DeviceInterface & vk,const VkDevice device) const3403 std::vector<uint64_t> TopLevelAccelerationStructureKHR::getSerializingAddresses(const DeviceInterface &vk,
3404                                                                                 const VkDevice device) const
3405 {
3406     std::vector<uint64_t> result(m_bottomLevelInstances.size() + 1);
3407 
3408     VkAccelerationStructureDeviceAddressInfoKHR asDeviceAddressInfo = {
3409         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_DEVICE_ADDRESS_INFO_KHR, // VkStructureType sType;
3410         DE_NULL,                                                          // const void* pNext;
3411         DE_NULL // VkAccelerationStructureKHR accelerationStructure;
3412     };
3413 
3414     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3415     {
3416         asDeviceAddressInfo.accelerationStructure = m_accelerationStructureKHR.get();
3417         result[0] = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
3418     }
3419     else
3420     {
3421         result[0] = uint64_t(getPtr()->getInternal());
3422     }
3423 
3424     for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
3425     {
3426         const BottomLevelAccelerationStructure &bottomLevelAccelerationStructure = *m_bottomLevelInstances[instanceNdx];
3427         const VkAccelerationStructureKHR accelerationStructureKHR = *bottomLevelAccelerationStructure.getPtr();
3428 
3429         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3430         {
3431             asDeviceAddressInfo.accelerationStructure = accelerationStructureKHR;
3432             result[instanceNdx + 1] = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
3433         }
3434         else
3435         {
3436             result[instanceNdx + 1] = uint64_t(accelerationStructureKHR.getInternal());
3437         }
3438     }
3439 
3440     return result;
3441 }
3442 
getPtr(void) const3443 const VkAccelerationStructureKHR *TopLevelAccelerationStructureKHR::getPtr(void) const
3444 {
3445     return &m_accelerationStructureKHR.get();
3446 }
3447 
prepareInstances(const DeviceInterface & vk,const VkDevice device,VkAccelerationStructureGeometryKHR & accelerationStructureGeometryKHR,std::vector<uint32_t> & maxPrimitiveCounts)3448 void TopLevelAccelerationStructureKHR::prepareInstances(
3449     const DeviceInterface &vk, const VkDevice device,
3450     VkAccelerationStructureGeometryKHR &accelerationStructureGeometryKHR, std::vector<uint32_t> &maxPrimitiveCounts)
3451 {
3452     maxPrimitiveCounts.resize(1);
3453     maxPrimitiveCounts[0] = static_cast<uint32_t>(m_bottomLevelInstances.size());
3454 
3455     VkDeviceOrHostAddressConstKHR instancesData;
3456     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3457     {
3458         if (m_instanceBuffer.get() != DE_NULL)
3459         {
3460             if (m_useArrayOfPointers)
3461             {
3462                 uint8_t *bufferStart = static_cast<uint8_t *>(m_instanceAddressBuffer->getAllocation().getHostPtr());
3463                 VkDeviceSize bufferOffset = 0;
3464                 VkDeviceOrHostAddressConstKHR firstInstance =
3465                     makeDeviceOrHostAddressConstKHR(vk, device, m_instanceBuffer->get(), 0);
3466                 for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
3467                 {
3468                     VkDeviceOrHostAddressConstKHR currentInstance;
3469                     currentInstance.deviceAddress =
3470                         firstInstance.deviceAddress + instanceNdx * sizeof(VkAccelerationStructureInstanceKHR);
3471 
3472                     deMemcpy(&bufferStart[bufferOffset], &currentInstance,
3473                              sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress));
3474                     bufferOffset += sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress);
3475                 }
3476                 flushMappedMemoryRange(vk, device, m_instanceAddressBuffer->getAllocation().getMemory(),
3477                                        m_instanceAddressBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
3478 
3479                 instancesData = makeDeviceOrHostAddressConstKHR(vk, device, m_instanceAddressBuffer->get(), 0);
3480             }
3481             else
3482                 instancesData = makeDeviceOrHostAddressConstKHR(vk, device, m_instanceBuffer->get(), 0);
3483         }
3484         else
3485             instancesData = makeDeviceOrHostAddressConstKHR(DE_NULL);
3486     }
3487     else
3488     {
3489         if (m_instanceBuffer.get() != DE_NULL)
3490         {
3491             if (m_useArrayOfPointers)
3492             {
3493                 uint8_t *bufferStart = static_cast<uint8_t *>(m_instanceAddressBuffer->getAllocation().getHostPtr());
3494                 VkDeviceSize bufferOffset = 0;
3495                 for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
3496                 {
3497                     VkDeviceOrHostAddressConstKHR currentInstance;
3498                     currentInstance.hostAddress = (uint8_t *)m_instanceBuffer->getAllocation().getHostPtr() +
3499                                                   instanceNdx * sizeof(VkAccelerationStructureInstanceKHR);
3500 
3501                     deMemcpy(&bufferStart[bufferOffset], &currentInstance,
3502                              sizeof(VkDeviceOrHostAddressConstKHR::hostAddress));
3503                     bufferOffset += sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
3504                 }
3505                 instancesData = makeDeviceOrHostAddressConstKHR(m_instanceAddressBuffer->getAllocation().getHostPtr());
3506             }
3507             else
3508                 instancesData = makeDeviceOrHostAddressConstKHR(m_instanceBuffer->getAllocation().getHostPtr());
3509         }
3510         else
3511             instancesData = makeDeviceOrHostAddressConstKHR(DE_NULL);
3512     }
3513 
3514     VkAccelerationStructureGeometryInstancesDataKHR accelerationStructureGeometryInstancesDataKHR = {
3515         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_INSTANCES_DATA_KHR, //  VkStructureType sType;
3516         DE_NULL,                                                              //  const void* pNext;
3517         (VkBool32)(m_useArrayOfPointers ? true : false),                      //  VkBool32 arrayOfPointers;
3518         instancesData                                                         //  VkDeviceOrHostAddressConstKHR data;
3519     };
3520 
3521     accelerationStructureGeometryKHR = {
3522         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR, //  VkStructureType sType;
3523         DE_NULL,                                               //  const void* pNext;
3524         VK_GEOMETRY_TYPE_INSTANCES_KHR,                        //  VkGeometryTypeKHR geometryType;
3525         makeVkAccelerationStructureInstancesDataKHR(
3526             accelerationStructureGeometryInstancesDataKHR), //  VkAccelerationStructureGeometryDataKHR geometry;
3527         (VkGeometryFlagsKHR)0u                              //  VkGeometryFlagsKHR flags;
3528     };
3529 }
3530 
getRequiredAllocationCount(void)3531 uint32_t TopLevelAccelerationStructure::getRequiredAllocationCount(void)
3532 {
3533     return TopLevelAccelerationStructureKHR::getRequiredAllocationCount();
3534 }
3535 
makeTopLevelAccelerationStructure()3536 de::MovePtr<TopLevelAccelerationStructure> makeTopLevelAccelerationStructure()
3537 {
3538     return de::MovePtr<TopLevelAccelerationStructure>(new TopLevelAccelerationStructureKHR);
3539 }
3540 
queryAccelerationStructureSizeKHR(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,const std::vector<VkAccelerationStructureKHR> & accelerationStructureHandles,VkAccelerationStructureBuildTypeKHR buildType,const VkQueryPool queryPool,VkQueryType queryType,uint32_t firstQuery,std::vector<VkDeviceSize> & results)3541 bool queryAccelerationStructureSizeKHR(const DeviceInterface &vk, const VkDevice device,
3542                                        const VkCommandBuffer cmdBuffer,
3543                                        const std::vector<VkAccelerationStructureKHR> &accelerationStructureHandles,
3544                                        VkAccelerationStructureBuildTypeKHR buildType, const VkQueryPool queryPool,
3545                                        VkQueryType queryType, uint32_t firstQuery, std::vector<VkDeviceSize> &results)
3546 {
3547     DE_ASSERT(queryType == VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR ||
3548               queryType == VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR);
3549 
3550     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3551     {
3552         // queryPool must be large enough to contain at least (firstQuery + accelerationStructureHandles.size()) queries
3553         vk.cmdResetQueryPool(cmdBuffer, queryPool, firstQuery, uint32_t(accelerationStructureHandles.size()));
3554         vk.cmdWriteAccelerationStructuresPropertiesKHR(cmdBuffer, uint32_t(accelerationStructureHandles.size()),
3555                                                        accelerationStructureHandles.data(), queryType, queryPool,
3556                                                        firstQuery);
3557         // results cannot be retrieved to CPU at the moment - you need to do it using getQueryPoolResults after cmdBuffer is executed. Meanwhile function returns a vector of 0s.
3558         results.resize(accelerationStructureHandles.size(), 0u);
3559         return false;
3560     }
3561     // buildType != VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR
3562     results.resize(accelerationStructureHandles.size(), 0u);
3563     vk.writeAccelerationStructuresPropertiesKHR(
3564         device, uint32_t(accelerationStructureHandles.size()), accelerationStructureHandles.data(), queryType,
3565         sizeof(VkDeviceSize) * accelerationStructureHandles.size(), results.data(), sizeof(VkDeviceSize));
3566     // results will contain proper values
3567     return true;
3568 }
3569 
queryAccelerationStructureSize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,const std::vector<VkAccelerationStructureKHR> & accelerationStructureHandles,VkAccelerationStructureBuildTypeKHR buildType,const VkQueryPool queryPool,VkQueryType queryType,uint32_t firstQuery,std::vector<VkDeviceSize> & results)3570 bool queryAccelerationStructureSize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
3571                                     const std::vector<VkAccelerationStructureKHR> &accelerationStructureHandles,
3572                                     VkAccelerationStructureBuildTypeKHR buildType, const VkQueryPool queryPool,
3573                                     VkQueryType queryType, uint32_t firstQuery, std::vector<VkDeviceSize> &results)
3574 {
3575     return queryAccelerationStructureSizeKHR(vk, device, cmdBuffer, accelerationStructureHandles, buildType, queryPool,
3576                                              queryType, firstQuery, results);
3577 }
3578 
RayTracingPipeline()3579 RayTracingPipeline::RayTracingPipeline()
3580     : m_shadersModules()
3581     , m_pipelineLibraries()
3582     , m_shaderCreateInfos()
3583     , m_shadersGroupCreateInfos()
3584     , m_pipelineCreateFlags(0U)
3585     , m_pipelineCreateFlags2(0U)
3586     , m_maxRecursionDepth(1U)
3587     , m_maxPayloadSize(0U)
3588     , m_maxAttributeSize(0U)
3589     , m_deferredOperation(false)
3590     , m_workerThreadCount(0)
3591 {
3592 }
3593 
~RayTracingPipeline()3594 RayTracingPipeline::~RayTracingPipeline()
3595 {
3596 }
3597 
// Stores stage index STAGE into the shader-group slot SHADER, throwing InternalError if that slot was already
// assigned (i.e. is no longer VK_SHADER_UNUSED_KHR).
// The do/while wrapper makes the macro a single statement, so `CHECKED_ASSIGN_SHADER(a, b);` is safe inside an
// unbraced if/else; arguments are parenthesised to avoid precedence surprises.
#define CHECKED_ASSIGN_SHADER(SHADER, STAGE)                        \
    do                                                              \
    {                                                               \
        if ((SHADER) == VK_SHADER_UNUSED_KHR)                       \
            (SHADER) = (STAGE);                                     \
        else                                                        \
            TCU_THROW(InternalError, "Attempt to reassign shader"); \
    } while (deGetFalse())
3603 
addShader(VkShaderStageFlagBits shaderStage,Move<VkShaderModule> shaderModule,uint32_t group,const VkSpecializationInfo * specializationInfo,const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,const void * pipelineShaderStageCreateInfopNext)3604 void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, Move<VkShaderModule> shaderModule, uint32_t group,
3605                                    const VkSpecializationInfo *specializationInfo,
3606                                    const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
3607                                    const void *pipelineShaderStageCreateInfopNext)
3608 {
3609     addShader(shaderStage, makeVkSharedPtr(shaderModule), group, specializationInfo, pipelineShaderStageCreateFlags,
3610               pipelineShaderStageCreateInfopNext);
3611 }
3612 
addShader(VkShaderStageFlagBits shaderStage,de::SharedPtr<Move<VkShaderModule>> shaderModule,uint32_t group,const VkSpecializationInfo * specializationInfoPtr,const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,const void * pipelineShaderStageCreateInfopNext)3613 void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, de::SharedPtr<Move<VkShaderModule>> shaderModule,
3614                                    uint32_t group, const VkSpecializationInfo *specializationInfoPtr,
3615                                    const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
3616                                    const void *pipelineShaderStageCreateInfopNext)
3617 {
3618     addShader(shaderStage, **shaderModule, group, specializationInfoPtr, pipelineShaderStageCreateFlags,
3619               pipelineShaderStageCreateInfopNext);
3620     m_shadersModules.push_back(shaderModule);
3621 }
3622 
addShader(VkShaderStageFlagBits shaderStage,VkShaderModule shaderModule,uint32_t group,const VkSpecializationInfo * specializationInfoPtr,const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,const void * pipelineShaderStageCreateInfopNext)3623 void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, VkShaderModule shaderModule, uint32_t group,
3624                                    const VkSpecializationInfo *specializationInfoPtr,
3625                                    const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
3626                                    const void *pipelineShaderStageCreateInfopNext)
3627 {
3628     if (group >= m_shadersGroupCreateInfos.size())
3629     {
3630         for (size_t groupNdx = m_shadersGroupCreateInfos.size(); groupNdx <= group; ++groupNdx)
3631         {
3632             VkRayTracingShaderGroupCreateInfoKHR shaderGroupCreateInfo = {
3633                 VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR, //  VkStructureType sType;
3634                 DE_NULL,                                                    //  const void* pNext;
3635                 VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR,              //  VkRayTracingShaderGroupTypeKHR type;
3636                 VK_SHADER_UNUSED_KHR,                                       //  uint32_t generalShader;
3637                 VK_SHADER_UNUSED_KHR,                                       //  uint32_t closestHitShader;
3638                 VK_SHADER_UNUSED_KHR,                                       //  uint32_t anyHitShader;
3639                 VK_SHADER_UNUSED_KHR,                                       //  uint32_t intersectionShader;
3640                 DE_NULL, //  const void* pShaderGroupCaptureReplayHandle;
3641             };
3642 
3643             m_shadersGroupCreateInfos.push_back(shaderGroupCreateInfo);
3644         }
3645     }
3646 
3647     const uint32_t shaderStageNdx                               = (uint32_t)m_shaderCreateInfos.size();
3648     VkRayTracingShaderGroupCreateInfoKHR &shaderGroupCreateInfo = m_shadersGroupCreateInfos[group];
3649 
3650     switch (shaderStage)
3651     {
3652     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
3653         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
3654         break;
3655     case VK_SHADER_STAGE_MISS_BIT_KHR:
3656         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
3657         break;
3658     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
3659         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
3660         break;
3661     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
3662         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.anyHitShader, shaderStageNdx);
3663         break;
3664     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
3665         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.closestHitShader, shaderStageNdx);
3666         break;
3667     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
3668         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.intersectionShader, shaderStageNdx);
3669         break;
3670     default:
3671         TCU_THROW(InternalError, "Unacceptable stage");
3672     }
3673 
3674     switch (shaderStage)
3675     {
3676     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
3677     case VK_SHADER_STAGE_MISS_BIT_KHR:
3678     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
3679     {
3680         DE_ASSERT(shaderGroupCreateInfo.type == VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR);
3681         shaderGroupCreateInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
3682 
3683         break;
3684     }
3685 
3686     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
3687     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
3688     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
3689     {
3690         DE_ASSERT(shaderGroupCreateInfo.type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR);
3691         shaderGroupCreateInfo.type = (shaderGroupCreateInfo.intersectionShader == VK_SHADER_UNUSED_KHR) ?
3692                                          VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR :
3693                                          VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
3694 
3695         break;
3696     }
3697 
3698     default:
3699         TCU_THROW(InternalError, "Unacceptable stage");
3700     }
3701 
3702     {
3703         const VkPipelineShaderStageCreateInfo shaderCreateInfo = {
3704             VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, //  VkStructureType sType;
3705             pipelineShaderStageCreateInfopNext,                  //  const void* pNext;
3706             pipelineShaderStageCreateFlags,                      //  VkPipelineShaderStageCreateFlags flags;
3707             shaderStage,                                         //  VkShaderStageFlagBits stage;
3708             shaderModule,                                        //  VkShaderModule module;
3709             "main",                                              //  const char* pName;
3710             specializationInfoPtr,                               //  const VkSpecializationInfo* pSpecializationInfo;
3711         };
3712 
3713         m_shaderCreateInfos.push_back(shaderCreateInfo);
3714     }
3715 }
3716 
setGroupCaptureReplayHandle(uint32_t group,const void * pShaderGroupCaptureReplayHandle)3717 void RayTracingPipeline::setGroupCaptureReplayHandle(uint32_t group, const void *pShaderGroupCaptureReplayHandle)
3718 {
3719     DE_ASSERT(static_cast<size_t>(group) < m_shadersGroupCreateInfos.size());
3720     m_shadersGroupCreateInfos[group].pShaderGroupCaptureReplayHandle = pShaderGroupCaptureReplayHandle;
3721 }
3722 
addLibrary(de::SharedPtr<de::MovePtr<RayTracingPipeline>> pipelineLibrary)3723 void RayTracingPipeline::addLibrary(de::SharedPtr<de::MovePtr<RayTracingPipeline>> pipelineLibrary)
3724 {
3725     m_pipelineLibraries.push_back(pipelineLibrary);
3726 }
3727 
getShaderGroupCount(void)3728 uint32_t RayTracingPipeline::getShaderGroupCount(void)
3729 {
3730     return de::sizeU32(m_shadersGroupCreateInfos);
3731 }
3732 
getFullShaderGroupCount(void)3733 uint32_t RayTracingPipeline::getFullShaderGroupCount(void)
3734 {
3735     uint32_t totalCount = getShaderGroupCount();
3736 
3737     for (const auto &lib : m_pipelineLibraries)
3738         totalCount += lib->get()->getFullShaderGroupCount();
3739 
3740     return totalCount;
3741 }
3742 
createPipelineKHR(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout,const std::vector<VkPipeline> & pipelineLibraries,const VkPipelineCache pipelineCache)3743 Move<VkPipeline> RayTracingPipeline::createPipelineKHR(const DeviceInterface &vk, const VkDevice device,
3744                                                        const VkPipelineLayout pipelineLayout,
3745                                                        const std::vector<VkPipeline> &pipelineLibraries,
3746                                                        const VkPipelineCache pipelineCache)
3747 {
3748     for (size_t groupNdx = 0; groupNdx < m_shadersGroupCreateInfos.size(); ++groupNdx)
3749         DE_ASSERT(m_shadersGroupCreateInfos[groupNdx].sType ==
3750                   VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR);
3751 
3752     VkPipelineLibraryCreateInfoKHR librariesCreateInfo = {
3753         VK_STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR, //  VkStructureType sType;
3754         DE_NULL,                                            //  const void* pNext;
3755         de::sizeU32(pipelineLibraries),                     //  uint32_t libraryCount;
3756         de::dataOrNull(pipelineLibraries)                   //  VkPipeline* pLibraries;
3757     };
3758     const VkRayTracingPipelineInterfaceCreateInfoKHR pipelineInterfaceCreateInfo = {
3759         VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_INTERFACE_CREATE_INFO_KHR, //  VkStructureType sType;
3760         DE_NULL,                                                          //  const void* pNext;
3761         m_maxPayloadSize,                                                 //  uint32_t maxPayloadSize;
3762         m_maxAttributeSize                                                //  uint32_t maxAttributeSize;
3763     };
3764     const bool addPipelineInterfaceCreateInfo = m_maxPayloadSize != 0 || m_maxAttributeSize != 0;
3765     const VkRayTracingPipelineInterfaceCreateInfoKHR *pipelineInterfaceCreateInfoPtr =
3766         addPipelineInterfaceCreateInfo ? &pipelineInterfaceCreateInfo : DE_NULL;
3767     const VkPipelineLibraryCreateInfoKHR *librariesCreateInfoPtr =
3768         (pipelineLibraries.empty() ? nullptr : &librariesCreateInfo);
3769 
3770     Move<VkDeferredOperationKHR> deferredOperation;
3771     if (m_deferredOperation)
3772         deferredOperation = createDeferredOperationKHR(vk, device);
3773 
3774     VkPipelineDynamicStateCreateInfo dynamicStateCreateInfo = {
3775         VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO, // VkStructureType sType;
3776         DE_NULL,                                              // const void* pNext;
3777         0,                                                    // VkPipelineDynamicStateCreateFlags flags;
3778         static_cast<uint32_t>(m_dynamicStates.size()),        // uint32_t dynamicStateCount;
3779         m_dynamicStates.data(),                               // const VkDynamicState* pDynamicStates;
3780     };
3781 
3782     VkRayTracingPipelineCreateInfoKHR pipelineCreateInfo{
3783         VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR, //  VkStructureType sType;
3784         DE_NULL,                                                //  const void* pNext;
3785         m_pipelineCreateFlags,                                  //  VkPipelineCreateFlags flags;
3786         de::sizeU32(m_shaderCreateInfos),                       //  uint32_t stageCount;
3787         de::dataOrNull(m_shaderCreateInfos),                    //  const VkPipelineShaderStageCreateInfo* pStages;
3788         de::sizeU32(m_shadersGroupCreateInfos),                 //  uint32_t groupCount;
3789         de::dataOrNull(m_shadersGroupCreateInfos),              //  const VkRayTracingShaderGroupCreateInfoKHR* pGroups;
3790         m_maxRecursionDepth,                                    //  uint32_t maxRecursionDepth;
3791         librariesCreateInfoPtr,                                 //  VkPipelineLibraryCreateInfoKHR* pLibraryInfo;
3792         pipelineInterfaceCreateInfoPtr, //  VkRayTracingPipelineInterfaceCreateInfoKHR* pLibraryInterface;
3793         &dynamicStateCreateInfo,        //  const VkPipelineDynamicStateCreateInfo* pDynamicState;
3794         pipelineLayout,                 //  VkPipelineLayout layout;
3795         (VkPipeline)DE_NULL,            //  VkPipeline basePipelineHandle;
3796         0,                              //  int32_t basePipelineIndex;
3797     };
3798     VkPipeline object = DE_NULL;
3799     VkResult result   = vk.createRayTracingPipelinesKHR(device, deferredOperation.get(), pipelineCache, 1u,
3800                                                         &pipelineCreateInfo, DE_NULL, &object);
3801     const bool allowCompileRequired =
3802         ((m_pipelineCreateFlags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT_EXT) != 0);
3803 
3804     VkPipelineCreateFlags2CreateInfoKHR pipelineFlags2CreateInfo = initVulkanStructure();
3805     if (m_pipelineCreateFlags2)
3806     {
3807         pipelineFlags2CreateInfo.flags = m_pipelineCreateFlags2;
3808         pipelineCreateInfo.pNext       = &pipelineFlags2CreateInfo;
3809         pipelineCreateInfo.flags       = 0;
3810     }
3811 
3812     if (m_deferredOperation)
3813     {
3814         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3815                   result == VK_SUCCESS || (allowCompileRequired && result == VK_PIPELINE_COMPILE_REQUIRED));
3816         finishDeferredOperation(vk, device, deferredOperation.get(), m_workerThreadCount,
3817                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3818     }
3819 
3820     if (allowCompileRequired && result == VK_PIPELINE_COMPILE_REQUIRED)
3821         throw CompileRequiredError("createRayTracingPipelinesKHR returned VK_PIPELINE_COMPILE_REQUIRED");
3822 
3823     Move<VkPipeline> pipeline(check<VkPipeline>(object), Deleter<VkPipeline>(vk, device, DE_NULL));
3824     return pipeline;
3825 }
3826 
createPipeline(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout,const std::vector<de::SharedPtr<Move<VkPipeline>>> & pipelineLibraries)3827 Move<VkPipeline> RayTracingPipeline::createPipeline(
3828     const DeviceInterface &vk, const VkDevice device, const VkPipelineLayout pipelineLayout,
3829     const std::vector<de::SharedPtr<Move<VkPipeline>>> &pipelineLibraries)
3830 {
3831     std::vector<VkPipeline> rawPipelines;
3832     rawPipelines.reserve(pipelineLibraries.size());
3833     for (const auto &lib : pipelineLibraries)
3834         rawPipelines.push_back(lib.get()->get());
3835 
3836     return createPipelineKHR(vk, device, pipelineLayout, rawPipelines);
3837 }
3838 
createPipeline(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout,const std::vector<VkPipeline> & pipelineLibraries,const VkPipelineCache pipelineCache)3839 Move<VkPipeline> RayTracingPipeline::createPipeline(const DeviceInterface &vk, const VkDevice device,
3840                                                     const VkPipelineLayout pipelineLayout,
3841                                                     const std::vector<VkPipeline> &pipelineLibraries,
3842                                                     const VkPipelineCache pipelineCache)
3843 {
3844     return createPipelineKHR(vk, device, pipelineLayout, pipelineLibraries, pipelineCache);
3845 }
3846 
createPipelineWithLibraries(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout)3847 std::vector<de::SharedPtr<Move<VkPipeline>>> RayTracingPipeline::createPipelineWithLibraries(
3848     const DeviceInterface &vk, const VkDevice device, const VkPipelineLayout pipelineLayout)
3849 {
3850     for (size_t groupNdx = 0; groupNdx < m_shadersGroupCreateInfos.size(); ++groupNdx)
3851         DE_ASSERT(m_shadersGroupCreateInfos[groupNdx].sType ==
3852                   VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR);
3853 
3854     DE_ASSERT(m_shaderCreateInfos.size() > 0);
3855     DE_ASSERT(m_shadersGroupCreateInfos.size() > 0);
3856 
3857     std::vector<de::SharedPtr<Move<VkPipeline>>> result, allLibraries, firstLibraries;
3858     for (auto it = begin(m_pipelineLibraries), eit = end(m_pipelineLibraries); it != eit; ++it)
3859     {
3860         auto childLibraries = (*it)->get()->createPipelineWithLibraries(vk, device, pipelineLayout);
3861         DE_ASSERT(childLibraries.size() > 0);
3862         firstLibraries.push_back(childLibraries[0]);
3863         std::copy(begin(childLibraries), end(childLibraries), std::back_inserter(allLibraries));
3864     }
3865     result.push_back(makeVkSharedPtr(createPipeline(vk, device, pipelineLayout, firstLibraries)));
3866     std::copy(begin(allLibraries), end(allLibraries), std::back_inserter(result));
3867     return result;
3868 }
3869 
getShaderGroupHandles(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t shaderGroupHandleSize,const uint32_t firstGroup,const uint32_t groupCount) const3870 std::vector<uint8_t> RayTracingPipeline::getShaderGroupHandles(const DeviceInterface &vk, const VkDevice device,
3871                                                                const VkPipeline pipeline,
3872                                                                const uint32_t shaderGroupHandleSize,
3873                                                                const uint32_t firstGroup,
3874                                                                const uint32_t groupCount) const
3875 {
3876     const auto handleArraySizeBytes = groupCount * shaderGroupHandleSize;
3877     std::vector<uint8_t> shaderHandles(handleArraySizeBytes);
3878 
3879     VK_CHECK(getRayTracingShaderGroupHandles(vk, device, pipeline, firstGroup, groupCount,
3880                                              static_cast<uintptr_t>(shaderHandles.size()),
3881                                              de::dataOrNull(shaderHandles)));
3882 
3883     return shaderHandles;
3884 }
3885 
getShaderGroupReplayHandles(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t shaderGroupHandleReplaySize,const uint32_t firstGroup,const uint32_t groupCount) const3886 std::vector<uint8_t> RayTracingPipeline::getShaderGroupReplayHandles(const DeviceInterface &vk, const VkDevice device,
3887                                                                      const VkPipeline pipeline,
3888                                                                      const uint32_t shaderGroupHandleReplaySize,
3889                                                                      const uint32_t firstGroup,
3890                                                                      const uint32_t groupCount) const
3891 {
3892     const auto handleArraySizeBytes = groupCount * shaderGroupHandleReplaySize;
3893     std::vector<uint8_t> shaderHandles(handleArraySizeBytes);
3894 
3895     VK_CHECK(getRayTracingCaptureReplayShaderGroupHandles(vk, device, pipeline, firstGroup, groupCount,
3896                                                           static_cast<uintptr_t>(shaderHandles.size()),
3897                                                           de::dataOrNull(shaderHandles)));
3898 
3899     return shaderHandles;
3900 }
3901 
createShaderBindingTable(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,Allocator & allocator,const uint32_t & shaderGroupHandleSize,const uint32_t shaderGroupBaseAlignment,const uint32_t & firstGroup,const uint32_t & groupCount,const VkBufferCreateFlags & additionalBufferCreateFlags,const VkBufferUsageFlags & additionalBufferUsageFlags,const MemoryRequirement & additionalMemoryRequirement,const VkDeviceAddress & opaqueCaptureAddress,const uint32_t shaderBindingTableOffset,const uint32_t shaderRecordSize,const void ** shaderGroupDataPtrPerGroup,const bool autoAlignRecords)3902 de::MovePtr<BufferWithMemory> RayTracingPipeline::createShaderBindingTable(
3903     const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline, Allocator &allocator,
3904     const uint32_t &shaderGroupHandleSize, const uint32_t shaderGroupBaseAlignment, const uint32_t &firstGroup,
3905     const uint32_t &groupCount, const VkBufferCreateFlags &additionalBufferCreateFlags,
3906     const VkBufferUsageFlags &additionalBufferUsageFlags, const MemoryRequirement &additionalMemoryRequirement,
3907     const VkDeviceAddress &opaqueCaptureAddress, const uint32_t shaderBindingTableOffset,
3908     const uint32_t shaderRecordSize, const void **shaderGroupDataPtrPerGroup, const bool autoAlignRecords)
3909 {
3910     const auto shaderHandles =
3911         getShaderGroupHandles(vk, device, pipeline, shaderGroupHandleSize, firstGroup, groupCount);
3912     return createShaderBindingTable(vk, device, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment,
3913                                     shaderHandles, additionalBufferCreateFlags, additionalBufferUsageFlags,
3914                                     additionalMemoryRequirement, opaqueCaptureAddress, shaderBindingTableOffset,
3915                                     shaderRecordSize, shaderGroupDataPtrPerGroup, autoAlignRecords);
3916 }
3917 
createShaderBindingTable(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const uint32_t shaderGroupHandleSize,const uint32_t shaderGroupBaseAlignment,const std::vector<uint8_t> & shaderHandles,const VkBufferCreateFlags additionalBufferCreateFlags,const VkBufferUsageFlags additionalBufferUsageFlags,const MemoryRequirement & additionalMemoryRequirement,const VkDeviceAddress opaqueCaptureAddress,const uint32_t shaderBindingTableOffset,const uint32_t shaderRecordSize,const void ** shaderGroupDataPtrPerGroup,const bool autoAlignRecords)3918 de::MovePtr<BufferWithMemory> RayTracingPipeline::createShaderBindingTable(
3919     const DeviceInterface &vk, const VkDevice device, Allocator &allocator, const uint32_t shaderGroupHandleSize,
3920     const uint32_t shaderGroupBaseAlignment, const std::vector<uint8_t> &shaderHandles,
3921     const VkBufferCreateFlags additionalBufferCreateFlags, const VkBufferUsageFlags additionalBufferUsageFlags,
3922     const MemoryRequirement &additionalMemoryRequirement, const VkDeviceAddress opaqueCaptureAddress,
3923     const uint32_t shaderBindingTableOffset, const uint32_t shaderRecordSize, const void **shaderGroupDataPtrPerGroup,
3924     const bool autoAlignRecords)
3925 {
3926     DE_ASSERT(shaderGroupBaseAlignment != 0u);
3927     DE_ASSERT((shaderBindingTableOffset % shaderGroupBaseAlignment) == 0);
3928     DE_UNREF(shaderGroupBaseAlignment);
3929 
3930     const auto groupCount = de::sizeU32(shaderHandles) / shaderGroupHandleSize;
3931     const auto totalEntrySize =
3932         (autoAlignRecords ? (deAlign32(shaderGroupHandleSize + shaderRecordSize, shaderGroupHandleSize)) :
3933                             (shaderGroupHandleSize + shaderRecordSize));
3934     const uint32_t sbtSize            = shaderBindingTableOffset + groupCount * totalEntrySize;
3935     const VkBufferUsageFlags sbtFlags = VK_BUFFER_USAGE_TRANSFER_DST_BIT |
3936                                         VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR |
3937                                         VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | additionalBufferUsageFlags;
3938     VkBufferCreateInfo sbtCreateInfo = makeBufferCreateInfo(sbtSize, sbtFlags);
3939     sbtCreateInfo.flags |= additionalBufferCreateFlags;
3940     VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2           = vk::initVulkanStructure();
3941     VkBufferOpaqueCaptureAddressCreateInfo sbtCaptureAddressInfo = {
3942         VK_STRUCTURE_TYPE_BUFFER_OPAQUE_CAPTURE_ADDRESS_CREATE_INFO, // VkStructureType sType;
3943         DE_NULL,                                                     // const void* pNext;
3944         uint64_t(opaqueCaptureAddress)                               // uint64_t opaqueCaptureAddress;
3945     };
3946 
3947     // when maintenance5 is tested then m_pipelineCreateFlags2 is non-zero
3948     if (m_pipelineCreateFlags2)
3949     {
3950         bufferUsageFlags2.usage = (VkBufferUsageFlags2KHR)sbtFlags;
3951         sbtCreateInfo.pNext     = &bufferUsageFlags2;
3952         sbtCreateInfo.usage     = 0;
3953     }
3954 
3955     if (opaqueCaptureAddress != 0u)
3956     {
3957         sbtCreateInfo.pNext = &sbtCaptureAddressInfo;
3958         sbtCreateInfo.flags |= VK_BUFFER_CREATE_DEVICE_ADDRESS_CAPTURE_REPLAY_BIT;
3959     }
3960     const MemoryRequirement sbtMemRequirements = MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
3961                                                  MemoryRequirement::DeviceAddress | additionalMemoryRequirement;
3962     de::MovePtr<BufferWithMemory> sbtBuffer =
3963         de::MovePtr<BufferWithMemory>(new BufferWithMemory(vk, device, allocator, sbtCreateInfo, sbtMemRequirements));
3964     vk::Allocation &sbtAlloc = sbtBuffer->getAllocation();
3965 
3966     // Copy handles to table, leaving space for ShaderRecordKHR after each handle.
3967     uint8_t *shaderBegin = (uint8_t *)sbtAlloc.getHostPtr() + shaderBindingTableOffset;
3968     for (uint32_t idx = 0; idx < groupCount; ++idx)
3969     {
3970         const uint8_t *shaderSrcPos = shaderHandles.data() + idx * shaderGroupHandleSize;
3971         uint8_t *shaderDstPos       = shaderBegin + idx * totalEntrySize;
3972         deMemcpy(shaderDstPos, shaderSrcPos, shaderGroupHandleSize);
3973 
3974         if (shaderGroupDataPtrPerGroup != nullptr && shaderGroupDataPtrPerGroup[idx] != nullptr)
3975         {
3976             DE_ASSERT(sbtSize >= static_cast<uint32_t>(shaderDstPos - shaderBegin) + shaderGroupHandleSize);
3977 
3978             deMemcpy(shaderDstPos + shaderGroupHandleSize, shaderGroupDataPtrPerGroup[idx], shaderRecordSize);
3979         }
3980     }
3981 
3982     flushMappedMemoryRange(vk, device, sbtAlloc.getMemory(), sbtAlloc.getOffset(), VK_WHOLE_SIZE);
3983 
3984     return sbtBuffer;
3985 }
3986 
setCreateFlags(const VkPipelineCreateFlags & pipelineCreateFlags)3987 void RayTracingPipeline::setCreateFlags(const VkPipelineCreateFlags &pipelineCreateFlags)
3988 {
3989     m_pipelineCreateFlags = pipelineCreateFlags;
3990 }
3991 
setCreateFlags2(const VkPipelineCreateFlags2KHR & pipelineCreateFlags2)3992 void RayTracingPipeline::setCreateFlags2(const VkPipelineCreateFlags2KHR &pipelineCreateFlags2)
3993 {
3994     m_pipelineCreateFlags2 = pipelineCreateFlags2;
3995 }
3996 
setMaxRecursionDepth(const uint32_t & maxRecursionDepth)3997 void RayTracingPipeline::setMaxRecursionDepth(const uint32_t &maxRecursionDepth)
3998 {
3999     m_maxRecursionDepth = maxRecursionDepth;
4000 }
4001 
setMaxPayloadSize(const uint32_t & maxPayloadSize)4002 void RayTracingPipeline::setMaxPayloadSize(const uint32_t &maxPayloadSize)
4003 {
4004     m_maxPayloadSize = maxPayloadSize;
4005 }
4006 
setMaxAttributeSize(const uint32_t & maxAttributeSize)4007 void RayTracingPipeline::setMaxAttributeSize(const uint32_t &maxAttributeSize)
4008 {
4009     m_maxAttributeSize = maxAttributeSize;
4010 }
4011 
setDeferredOperation(const bool deferredOperation,const uint32_t workerThreadCount)4012 void RayTracingPipeline::setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount)
4013 {
4014     m_deferredOperation = deferredOperation;
4015     m_workerThreadCount = workerThreadCount;
4016 }
4017 
addDynamicState(const VkDynamicState & dynamicState)4018 void RayTracingPipeline::addDynamicState(const VkDynamicState &dynamicState)
4019 {
4020     m_dynamicStates.push_back(dynamicState);
4021 }
4022 
4023 class RayTracingPropertiesKHR : public RayTracingProperties
4024 {
4025 public:
4026     RayTracingPropertiesKHR() = delete;
4027     RayTracingPropertiesKHR(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice);
4028     virtual ~RayTracingPropertiesKHR();
4029 
getShaderGroupHandleSize(void)4030     uint32_t getShaderGroupHandleSize(void) override
4031     {
4032         return m_rayTracingPipelineProperties.shaderGroupHandleSize;
4033     }
getShaderGroupHandleAlignment(void)4034     uint32_t getShaderGroupHandleAlignment(void) override
4035     {
4036         return m_rayTracingPipelineProperties.shaderGroupHandleAlignment;
4037     }
getShaderGroupHandleCaptureReplaySize(void)4038     uint32_t getShaderGroupHandleCaptureReplaySize(void) override
4039     {
4040         return m_rayTracingPipelineProperties.shaderGroupHandleCaptureReplaySize;
4041     }
getMaxRecursionDepth(void)4042     uint32_t getMaxRecursionDepth(void) override
4043     {
4044         return m_rayTracingPipelineProperties.maxRayRecursionDepth;
4045     }
getMaxShaderGroupStride(void)4046     uint32_t getMaxShaderGroupStride(void) override
4047     {
4048         return m_rayTracingPipelineProperties.maxShaderGroupStride;
4049     }
getShaderGroupBaseAlignment(void)4050     uint32_t getShaderGroupBaseAlignment(void) override
4051     {
4052         return m_rayTracingPipelineProperties.shaderGroupBaseAlignment;
4053     }
getMaxGeometryCount(void)4054     uint64_t getMaxGeometryCount(void) override
4055     {
4056         return m_accelerationStructureProperties.maxGeometryCount;
4057     }
getMaxInstanceCount(void)4058     uint64_t getMaxInstanceCount(void) override
4059     {
4060         return m_accelerationStructureProperties.maxInstanceCount;
4061     }
getMaxPrimitiveCount(void)4062     uint64_t getMaxPrimitiveCount(void) override
4063     {
4064         return m_accelerationStructureProperties.maxPrimitiveCount;
4065     }
getMaxDescriptorSetAccelerationStructures(void)4066     uint32_t getMaxDescriptorSetAccelerationStructures(void) override
4067     {
4068         return m_accelerationStructureProperties.maxDescriptorSetAccelerationStructures;
4069     }
getMaxRayDispatchInvocationCount(void)4070     uint32_t getMaxRayDispatchInvocationCount(void) override
4071     {
4072         return m_rayTracingPipelineProperties.maxRayDispatchInvocationCount;
4073     }
getMaxRayHitAttributeSize(void)4074     uint32_t getMaxRayHitAttributeSize(void) override
4075     {
4076         return m_rayTracingPipelineProperties.maxRayHitAttributeSize;
4077     }
getMaxMemoryAllocationCount(void)4078     uint32_t getMaxMemoryAllocationCount(void) override
4079     {
4080         return m_maxMemoryAllocationCount;
4081     }
4082 
4083 protected:
4084     VkPhysicalDeviceAccelerationStructurePropertiesKHR m_accelerationStructureProperties;
4085     VkPhysicalDeviceRayTracingPipelinePropertiesKHR m_rayTracingPipelineProperties;
4086     uint32_t m_maxMemoryAllocationCount;
4087 };
4088 
~RayTracingPropertiesKHR()4089 RayTracingPropertiesKHR::~RayTracingPropertiesKHR()
4090 {
4091 }
4092 
RayTracingPropertiesKHR(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)4093 RayTracingPropertiesKHR::RayTracingPropertiesKHR(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
4094     : RayTracingProperties(vki, physicalDevice)
4095 {
4096     m_accelerationStructureProperties = getPhysicalDeviceExtensionProperties(vki, physicalDevice);
4097     m_rayTracingPipelineProperties    = getPhysicalDeviceExtensionProperties(vki, physicalDevice);
4098     m_maxMemoryAllocationCount = getPhysicalDeviceProperties(vki, physicalDevice).limits.maxMemoryAllocationCount;
4099 }
4100 
makeRayTracingProperties(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)4101 de::MovePtr<RayTracingProperties> makeRayTracingProperties(const InstanceInterface &vki,
4102                                                            const VkPhysicalDevice physicalDevice)
4103 {
4104     return de::MovePtr<RayTracingProperties>(new RayTracingPropertiesKHR(vki, physicalDevice));
4105 }
4106 
cmdTraceRaysKHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,uint32_t width,uint32_t height,uint32_t depth)4107 static inline void cmdTraceRaysKHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4108                                    const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4109                                    const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4110                                    const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4111                                    const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4112                                    uint32_t width, uint32_t height, uint32_t depth)
4113 {
4114     return vk.cmdTraceRaysKHR(commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4115                               hitShaderBindingTableRegion, callableShaderBindingTableRegion, width, height, depth);
4116 }
4117 
cmdTraceRays(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,uint32_t width,uint32_t height,uint32_t depth)4118 void cmdTraceRays(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4119                   const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4120                   const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4121                   const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4122                   const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion, uint32_t width,
4123                   uint32_t height, uint32_t depth)
4124 {
4125     DE_ASSERT(raygenShaderBindingTableRegion != DE_NULL);
4126     DE_ASSERT(missShaderBindingTableRegion != DE_NULL);
4127     DE_ASSERT(hitShaderBindingTableRegion != DE_NULL);
4128     DE_ASSERT(callableShaderBindingTableRegion != DE_NULL);
4129 
4130     return cmdTraceRaysKHR(vk, commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4131                            hitShaderBindingTableRegion, callableShaderBindingTableRegion, width, height, depth);
4132 }
4133 
cmdTraceRaysIndirectKHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,VkDeviceAddress indirectDeviceAddress)4134 static inline void cmdTraceRaysIndirectKHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4135                                            const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4136                                            const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4137                                            const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4138                                            const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4139                                            VkDeviceAddress indirectDeviceAddress)
4140 {
4141     DE_ASSERT(raygenShaderBindingTableRegion != DE_NULL);
4142     DE_ASSERT(missShaderBindingTableRegion != DE_NULL);
4143     DE_ASSERT(hitShaderBindingTableRegion != DE_NULL);
4144     DE_ASSERT(callableShaderBindingTableRegion != DE_NULL);
4145     DE_ASSERT(indirectDeviceAddress != 0);
4146 
4147     return vk.cmdTraceRaysIndirectKHR(commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4148                                       hitShaderBindingTableRegion, callableShaderBindingTableRegion,
4149                                       indirectDeviceAddress);
4150 }
4151 
cmdTraceRaysIndirect(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,VkDeviceAddress indirectDeviceAddress)4152 void cmdTraceRaysIndirect(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4153                           const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4154                           const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4155                           const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4156                           const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4157                           VkDeviceAddress indirectDeviceAddress)
4158 {
4159     return cmdTraceRaysIndirectKHR(vk, commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4160                                    hitShaderBindingTableRegion, callableShaderBindingTableRegion,
4161                                    indirectDeviceAddress);
4162 }
4163 
cmdTraceRaysIndirect2KHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,VkDeviceAddress indirectDeviceAddress)4164 static inline void cmdTraceRaysIndirect2KHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4165                                             VkDeviceAddress indirectDeviceAddress)
4166 {
4167     DE_ASSERT(indirectDeviceAddress != 0);
4168 
4169     return vk.cmdTraceRaysIndirect2KHR(commandBuffer, indirectDeviceAddress);
4170 }
4171 
cmdTraceRaysIndirect2(const DeviceInterface & vk,VkCommandBuffer commandBuffer,VkDeviceAddress indirectDeviceAddress)4172 void cmdTraceRaysIndirect2(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4173                            VkDeviceAddress indirectDeviceAddress)
4174 {
4175     return cmdTraceRaysIndirect2KHR(vk, commandBuffer, indirectDeviceAddress);
4176 }
4177 
4178 constexpr uint32_t NO_INT_VALUE = spv::RayQueryCommittedIntersectionTypeMax;
4179 
generateRayQueryShaders(SourceCollections & programCollection,RayQueryTestParams params,std::string rayQueryPart,float max_t)4180 void generateRayQueryShaders(SourceCollections &programCollection, RayQueryTestParams params, std::string rayQueryPart,
4181                              float max_t)
4182 {
4183     std::stringstream genericMiss;
4184     genericMiss << "#version 460\n"
4185                    "#extension GL_EXT_ray_tracing : require\n"
4186                    "#extension GL_EXT_ray_query : require\n"
4187                    "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4188                    "void main()\n"
4189                    "{\n"
4190                    "  payload.x = 2000;\n"
4191                    "  payload.y = 2000;\n"
4192                    "  payload.z = 2000;\n"
4193                    "  payload.w = 2000;\n"
4194                    "}\n";
4195 
4196     std::stringstream genericIsect;
4197     genericIsect << "#version 460\n"
4198                     "#extension GL_EXT_ray_tracing : require\n"
4199                     "hitAttributeEXT uvec4 hitValue;\n"
4200                     "void main()\n"
4201                     "{\n"
4202                     "  reportIntersectionEXT(0.5f, 0);\n"
4203                     "}\n";
4204 
4205     std::stringstream rtChit;
4206     rtChit << "#version 460    \n"
4207               "#extension GL_EXT_ray_tracing : require\n"
4208               "#extension GL_EXT_ray_query : require\n"
4209               "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4210               "void main()\n"
4211               "{\n"
4212               "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y * "
4213               "gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4214               "  payload.x = index;\n"
4215               "  payload.y = gl_HitTEXT;\n"
4216               "  payload.z = 1000;\n"
4217               "  payload.w = 1000;\n"
4218               "}\n";
4219 
4220     std::stringstream genericChit;
4221     genericChit << "#version 460    \n"
4222                    "#extension GL_EXT_ray_tracing : require\n"
4223                    "#extension GL_EXT_ray_query : require\n"
4224                    "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4225                    "void main()\n"
4226                    "{\n"
4227                    "  payload.x = 1000;\n"
4228                    "  payload.y = 1000;\n"
4229                    "  payload.z = 1000;\n"
4230                    "  payload.w = 1000;\n"
4231                    "}\n";
4232 
4233     std::stringstream genericRayTracingSetResultsShader;
4234     genericRayTracingSetResultsShader << "#version 460    \n"
4235                                          "#extension GL_EXT_ray_tracing : require\n"
4236                                          "#extension GL_EXT_ray_query : require\n"
4237                                          "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4238                                          "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4239                                          "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4240                                          "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4241                                       << params.shaderFunctions
4242                                       << "void main()\n"
4243                                          "{\n"
4244                                          "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) "
4245                                          "+ (gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4246                                       << rayQueryPart
4247                                       << "  payload.x = x;\n"
4248                                          "  payload.y = y;\n"
4249                                          "  payload.z = z;\n"
4250                                          "  payload.w = w;\n"
4251                                          "}\n";
4252 
4253     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_5, 0u, true);
4254 
4255     switch (params.pipelineType)
4256     {
4257     case RayQueryShaderSourcePipeline::COMPUTE:
4258     {
4259         std::ostringstream compute;
4260         compute << "#version 460\n"
4261                    "#extension GL_EXT_ray_tracing : enable\n"
4262                    "#extension GL_EXT_ray_query : require\n"
4263                    "\n"
4264                    "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4265                    "struct ResultType { float x; float y; float z; float w; };\n"
4266                    "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4267                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4268                    "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4269                    "layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in;\n"
4270                 << params.shaderFunctions
4271                 << "void main() {\n"
4272                    "   uint index = (gl_NumWorkGroups.x * gl_WorkGroupSize.x) * gl_GlobalInvocationID.y + "
4273                    "gl_GlobalInvocationID.x;\n"
4274                 << rayQueryPart
4275                 << "   results[index].x = x;\n"
4276                    "   results[index].y = y;\n"
4277                    "   results[index].z = z;\n"
4278                    "   results[index].w = w;\n"
4279                    "}";
4280 
4281         programCollection.glslSources.add("comp", &buildOptions) << glu::ComputeSource(compute.str());
4282 
4283         break;
4284     }
4285     case RayQueryShaderSourcePipeline::GRAPHICS:
4286     {
4287         std::ostringstream vertex;
4288 
4289         if (params.shaderSourceType == RayQueryShaderSourceType::VERTEX)
4290         {
4291             vertex << "#version 460\n"
4292                       "#extension GL_EXT_ray_tracing : enable\n"
4293                       "#extension GL_EXT_ray_query : require\n"
4294                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4295                       "layout(location = 0) in vec4 in_position;\n"
4296                       "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4297                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4298                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4299                    << params.shaderFunctions
4300                    << "void main(void)\n"
4301                       "{\n"
4302                       "  const int  vertId = int(gl_VertexIndex % 3);\n"
4303                       "  if (vertId == 0)\n"
4304                       "  {\n"
4305                       "    ivec3 sz = imageSize(resultImage);\n"
4306                       "    int index = int(in_position.z);\n"
4307                       "    int idx = int(index % sz.x);\n"
4308                       "    int idy = int(index / sz.y);\n"
4309                    << rayQueryPart
4310                    << "     imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4311                       "  }\n"
4312                       "}\n";
4313         }
4314         else
4315         {
4316             vertex << "#version 460\n"
4317                       "layout(location = 0) in highp vec3 position;\n"
4318                       "\n"
4319                       "out gl_PerVertex {\n"
4320                       "   vec4 gl_Position;\n"
4321                       "};\n"
4322                       "\n"
4323                       "void main (void)\n"
4324                       "{\n"
4325                       "    gl_Position = vec4(position, 1.0);\n"
4326                       "}\n";
4327         }
4328 
4329         programCollection.glslSources.add("vert", &buildOptions) << glu::VertexSource(vertex.str());
4330 
4331         if (params.shaderSourceType == RayQueryShaderSourceType::FRAGMENT)
4332         {
4333             std::ostringstream frag;
4334             frag << "#version 460\n"
4335                     "#extension GL_EXT_ray_tracing : enable\n"
4336                     "#extension GL_EXT_ray_query : require\n"
4337                     "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4338                     "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4339                     "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4340                     "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4341                  << params.shaderFunctions
4342                  << "void main() {\n"
4343                     "    ivec3 sz = imageSize(resultImage);\n"
4344                     "    uint index = uint(gl_FragCoord.x) + sz.x * uint(gl_FragCoord.y);\n"
4345                  << rayQueryPart
4346                  << "    imageStore(resultImage, ivec3(gl_FragCoord.xy, 0), vec4(x, y, z, w));\n"
4347                     "}";
4348 
4349             programCollection.glslSources.add("frag", &buildOptions) << glu::FragmentSource(frag.str());
4350         }
4351         else if (params.shaderSourceType == RayQueryShaderSourceType::GEOMETRY)
4352         {
4353             std::stringstream geom;
4354             geom << "#version 460\n"
4355                     "#extension GL_EXT_ray_tracing : enable\n"
4356                     "#extension GL_EXT_ray_query : require\n"
4357                     "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4358                     "layout(triangles) in;\n"
4359                     "layout (triangle_strip, max_vertices = 3) out;\n"
4360                     "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4361                     "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4362                     "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4363                     "\n"
4364                     "in gl_PerVertex {\n"
4365                     "  vec4  gl_Position;\n"
4366                     "} gl_in[];\n"
4367                     "out gl_PerVertex {\n"
4368                     "  vec4 gl_Position;\n"
4369                     "};\n"
4370                  << params.shaderFunctions
4371                  << "void main (void)\n"
4372                     "{\n"
4373                     "  ivec3 sz = imageSize(resultImage);\n"
4374                     "  int index = int(gl_in[0].gl_Position.z);\n"
4375                     "  int idx = int(index % sz.x);\n"
4376                     "  int idy = int(index / sz.y);\n"
4377                  << rayQueryPart
4378                  << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4379                     "  for (int i = 0; i < gl_in.length(); ++i)\n"
4380                     "  {\n"
4381                     "        gl_Position      = gl_in[i].gl_Position;\n"
4382                     "        EmitVertex();\n"
4383                     "  }\n"
4384                     "  EndPrimitive();\n"
4385                     "}\n";
4386 
4387             programCollection.glslSources.add("geom", &buildOptions) << glu::GeometrySource(geom.str());
4388         }
4389         else if (params.shaderSourceType == RayQueryShaderSourceType::TESSELLATION_EVALUATION)
4390         {
4391             {
4392                 std::stringstream tesc;
4393                 tesc << "#version 460\n"
4394                         "#extension GL_EXT_tessellation_shader : require\n"
4395                         "in gl_PerVertex\n"
4396                         "{\n"
4397                         "  vec4 gl_Position;\n"
4398                         "} gl_in[];\n"
4399                         "layout(vertices = 4) out;\n"
4400                         "out gl_PerVertex\n"
4401                         "{\n"
4402                         "  vec4 gl_Position;\n"
4403                         "} gl_out[];\n"
4404                         "\n"
4405                         "void main (void)\n"
4406                         "{\n"
4407                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
4408                         "  gl_TessLevelInner[0] = 1;\n"
4409                         "  gl_TessLevelInner[1] = 1;\n"
4410                         "  gl_TessLevelOuter[gl_InvocationID] = 1;\n"
4411                         "}\n";
4412                 programCollection.glslSources.add("tesc", &buildOptions) << glu::TessellationControlSource(tesc.str());
4413             }
4414 
4415             {
4416                 std::ostringstream tese;
4417                 tese << "#version 460\n"
4418                         "#extension GL_EXT_ray_tracing : enable\n"
4419                         "#extension GL_EXT_tessellation_shader : require\n"
4420                         "#extension GL_EXT_ray_query : require\n"
4421                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4422                         "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4423                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4424                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4425                         "layout(quads, equal_spacing, ccw) in;\n"
4426                         "in gl_PerVertex\n"
4427                         "{\n"
4428                         "  vec4 gl_Position;\n"
4429                         "} gl_in[];\n"
4430                      << params.shaderFunctions
4431                      << "void main(void)\n"
4432                         "{\n"
4433                         "  ivec3 sz = imageSize(resultImage);\n"
4434                         "  int index = int(gl_in[0].gl_Position.z);\n"
4435                         "  int idx = int(index % sz.x);\n"
4436                         "  int idy = int(index / sz.y);\n"
4437                      << rayQueryPart
4438                      << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4439                         "  gl_Position = gl_in[0].gl_Position;\n"
4440                         "}\n";
4441 
4442                 programCollection.glslSources.add("tese", &buildOptions)
4443                     << glu::TessellationEvaluationSource(tese.str());
4444             }
4445         }
4446         else if (params.shaderSourceType == RayQueryShaderSourceType::TESSELLATION_CONTROL)
4447         {
4448             {
4449                 std::ostringstream tesc;
4450                 tesc << "#version 460\n"
4451                         "#extension GL_EXT_ray_tracing : enable\n"
4452                         "#extension GL_EXT_tessellation_shader : require\n"
4453                         "#extension GL_EXT_ray_query : require\n"
4454                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4455                         "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4456                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4457                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4458                         "in gl_PerVertex\n"
4459                         "{\n"
4460                         "  vec4 gl_Position;\n"
4461                         "} gl_in[];\n"
4462                         "layout(vertices = 4) out;\n"
4463                         "out gl_PerVertex\n"
4464                         "{\n"
4465                         "  vec4 gl_Position;\n"
4466                         "} gl_out[];\n"
4467                         "\n"
4468                      << params.shaderFunctions
4469                      << "void main(void)\n"
4470                         "{\n"
4471                         "  ivec3 sz = imageSize(resultImage);\n"
4472                         "  int index = int(gl_in[0].gl_Position.z);\n"
4473                         "  int idx = int(index % sz.x);\n"
4474                         "  int idy = int(index / sz.y);\n"
4475                      << rayQueryPart
4476                      << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4477                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
4478                         "  gl_TessLevelInner[0] = 1;\n"
4479                         "  gl_TessLevelInner[1] = 1;\n"
4480                         "  gl_TessLevelOuter[gl_InvocationID] = 1;\n"
4481                         "}\n";
4482 
4483                 programCollection.glslSources.add("tesc", &buildOptions) << glu::TessellationControlSource(tesc.str());
4484             }
4485 
4486             {
4487                 std::ostringstream tese;
4488                 tese << "#version 460\n"
4489                         "#extension GL_EXT_tessellation_shader : require\n"
4490                         "layout(quads, equal_spacing, ccw) in;\n"
4491                         "in gl_PerVertex\n"
4492                         "{\n"
4493                         "  vec4 gl_Position;\n"
4494                         "} gl_in[];\n"
4495                         "\n"
4496                         "void main(void)\n"
4497                         "{\n"
4498                         "  gl_Position = gl_in[0].gl_Position;\n"
4499                         "}\n";
4500 
4501                 programCollection.glslSources.add("tese", &buildOptions)
4502                     << glu::TessellationEvaluationSource(tese.str());
4503             }
4504         }
4505 
4506         break;
4507     }
4508     case RayQueryShaderSourcePipeline::RAYTRACING:
4509     {
4510         std::stringstream rayGen;
4511 
4512         if (params.shaderSourceType == RayQueryShaderSourceType::RAY_GENERATION_RT)
4513         {
4514             rayGen << "#version 460\n"
4515                       "#extension GL_EXT_ray_tracing : enable\n"
4516                       "#extension GL_EXT_ray_query : require\n"
4517                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4518                       "struct ResultType { float x; float y; float z; float w; };\n"
4519                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4520                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4521                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4522                       "layout(location = 0) rayPayloadEXT vec4 payload;\n"
4523                    << params.shaderFunctions
4524                    << "void main() {\n"
4525                       "   payload = vec4("
4526                    << NO_INT_VALUE << "," << max_t * 2
4527                    << ",0,0);\n"
4528                       "   uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4529                       "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4530                    << rayQueryPart
4531                    << "   results[index].x = x;\n"
4532                       "   results[index].y = y;\n"
4533                       "   results[index].z = z;\n"
4534                       "   results[index].w = w;\n"
4535                       "}";
4536 
4537             programCollection.glslSources.add("isect_rt", &buildOptions)
4538                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4539             programCollection.glslSources.add("chit_rt", &buildOptions) << glu::ClosestHitSource(rtChit.str());
4540             programCollection.glslSources.add("ahit_rt", &buildOptions) << glu::AnyHitSource(genericChit.str());
4541             programCollection.glslSources.add("miss_rt", &buildOptions) << glu::MissSource(genericMiss.str());
4542         }
4543         else if (params.shaderSourceType == RayQueryShaderSourceType::RAY_GENERATION)
4544         {
4545             rayGen << "#version 460\n"
4546                       "#extension GL_EXT_ray_tracing : enable\n"
4547                       "#extension GL_EXT_ray_query : require\n"
4548                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4549                       "struct ResultType { float x; float y; float z; float w; };\n"
4550                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4551                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4552                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4553                    << params.shaderFunctions
4554                    << "void main() {\n"
4555                       "   uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4556                       "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4557                    << rayQueryPart
4558                    << "   results[index].x = x;\n"
4559                       "   results[index].y = y;\n"
4560                       "   results[index].z = z;\n"
4561                       "   results[index].w = w;\n"
4562                       "}";
4563         }
4564         else if (params.shaderSourceType == RayQueryShaderSourceType::CALLABLE)
4565         {
4566             rayGen << "#version 460\n"
4567                       "#extension GL_EXT_ray_tracing : require\n"
4568                       "struct CallValue\n{\n"
4569                       "  uint index;\n"
4570                       "  vec4 hitAttrib;\n"
4571                       "};\n"
4572                       "layout(location = 0) callableDataEXT CallValue param;\n"
4573                       "struct ResultType { float x; float y; float z; float w; };\n"
4574                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4575                       "void main()\n"
4576                       "{\n"
4577                       "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y "
4578                       "* gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4579                       "  param.index = index;\n"
4580                       "  param.hitAttrib = vec4(0, 0, 0, 0);\n"
4581                       "  executeCallableEXT(0, 0);\n"
4582                       "  results[index].x = param.hitAttrib.x;\n"
4583                       "  results[index].y = param.hitAttrib.y;\n"
4584                       "  results[index].z = param.hitAttrib.z;\n"
4585                       "  results[index].w = param.hitAttrib.w;\n"
4586                       "}\n";
4587         }
4588         else
4589         {
4590             rayGen << "#version 460\n"
4591                       "#extension GL_EXT_ray_tracing : require\n"
4592                       "#extension GL_EXT_ray_query : require\n"
4593                       "layout(location = 0) rayPayloadEXT vec4 payload;\n"
4594                       "struct ResultType { float x; float y; float z; float w; };\n"
4595                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4596                       "layout(set = 0, binding = 3) uniform accelerationStructureEXT traceEXTAccel;\n"
4597                       "void main()\n"
4598                       "{\n"
4599                       "  payload = vec4("
4600                    << NO_INT_VALUE << "," << max_t * 2
4601                    << ",0,0);\n"
4602                       "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y "
4603                       "* gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4604                       "  traceRayEXT(traceEXTAccel, 0, 0xFF, 0, 0, 0, vec3(0.1, 0.1, 0.0), 0.0, vec3(0.0, 0.0, 1.0), "
4605                       "500.0, 0);\n"
4606                       "  results[index].x = payload.x;\n"
4607                       "  results[index].y = payload.y;\n"
4608                       "  results[index].z = payload.z;\n"
4609                       "  results[index].w = payload.w;\n"
4610                       "}\n";
4611         }
4612 
4613         programCollection.glslSources.add("rgen", &buildOptions) << glu::RaygenSource(rayGen.str());
4614 
4615         if (params.shaderSourceType == RayQueryShaderSourceType::CLOSEST_HIT)
4616         {
4617             programCollection.glslSources.add("chit", &buildOptions)
4618                 << glu::ClosestHitSource(genericRayTracingSetResultsShader.str());
4619             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4620             programCollection.glslSources.add("isect", &buildOptions)
4621                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4622         }
4623         else if (params.shaderSourceType == RayQueryShaderSourceType::ANY_HIT)
4624         {
4625             programCollection.glslSources.add("ahit", &buildOptions)
4626                 << glu::AnyHitSource(genericRayTracingSetResultsShader.str());
4627             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4628             programCollection.glslSources.add("isect", &buildOptions)
4629                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4630         }
4631         else if (params.shaderSourceType == RayQueryShaderSourceType::MISS)
4632         {
4633 
4634             programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(genericChit.str());
4635             programCollection.glslSources.add("miss_1", &buildOptions)
4636                 << glu::MissSource(genericRayTracingSetResultsShader.str());
4637             programCollection.glslSources.add("isect", &buildOptions)
4638                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4639         }
4640         else if (params.shaderSourceType == RayQueryShaderSourceType::INTERSECTION)
4641         {
4642             {
4643                 std::stringstream chit;
4644                 chit << "#version 460    \n"
4645                         "#extension GL_EXT_ray_tracing : require\n"
4646                         "#extension GL_EXT_ray_query : require\n"
4647                         "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4648                         "hitAttributeEXT vec4 hitAttrib;\n"
4649                         "void main()\n"
4650                         "{\n"
4651                         "  payload = hitAttrib;\n"
4652                         "}\n";
4653 
4654                 programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(chit.str());
4655             }
4656 
4657             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4658 
4659             {
4660                 std::stringstream isect;
4661                 isect << "#version 460\n"
4662                          "#extension GL_EXT_ray_tracing : require\n"
4663                          "#extension GL_EXT_ray_query : require\n"
4664                          "hitAttributeEXT vec4 hitValue;\n"
4665                          "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4666                          "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4667                          "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4668                       << params.shaderFunctions
4669                       << "void main()\n"
4670                          "{\n"
4671                          "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4672                          "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4673                       << rayQueryPart
4674                       << "  hitValue.x = x;\n"
4675                          "  hitValue.y = y;\n"
4676                          "  hitValue.z = z;\n"
4677                          "  hitValue.w = w;\n"
4678                          "  reportIntersectionEXT(0.5f, 0);\n"
4679                          "}\n";
4680 
4681                 programCollection.glslSources.add("isect_1", &buildOptions)
4682                     << glu::IntersectionSource(updateRayTracingGLSL(isect.str()));
4683             }
4684         }
4685         else if (params.shaderSourceType == RayQueryShaderSourceType::CALLABLE)
4686         {
4687             {
4688                 std::stringstream call;
4689                 call << "#version 460\n"
4690                         "#extension GL_EXT_ray_tracing : require\n"
4691                         "#extension GL_EXT_ray_query : require\n"
4692                         "struct CallValue\n{\n"
4693                         "  uint index;\n"
4694                         "  vec4 hitAttrib;\n"
4695                         "};\n"
4696                         "layout(location = 0) callableDataInEXT CallValue result;\n"
4697                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4698                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4699                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4700                      << params.shaderFunctions
4701                      << "void main()\n"
4702                         "{\n"
4703                         "  uint index = result.index;\n"
4704                      << rayQueryPart
4705                      << "  result.hitAttrib.x = x;\n"
4706                         "  result.hitAttrib.y = y;\n"
4707                         "  result.hitAttrib.z = z;\n"
4708                         "  result.hitAttrib.w = w;\n"
4709                         "}\n";
4710 
4711                 programCollection.glslSources.add("call", &buildOptions)
4712                     << glu::CallableSource(updateRayTracingGLSL(call.str()));
4713             }
4714 
4715             programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(genericChit.str());
4716             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4717         }
4718 
4719         break;
4720     }
4721     default:
4722     {
4723         TCU_FAIL("Shader type not valid.");
4724     }
4725     }
4726 }
4727 
4728 #else
4729 
4730 uint32_t rayTracingDefineAnything()
4731 {
4732     return 0;
4733 }
4734 
4735 #endif // CTS_USES_VULKANSC
4736 
4737 } // namespace vk
4738