• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*-------------------------------------------------------------------------
2  * Vulkan CTS Framework
3  * --------------------
4  *
5  * Copyright (c) 2020 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Utilities for creating commonly used Vulkan objects
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vkRayTracingUtil.hpp"
25 
26 #include "vkRefUtil.hpp"
27 #include "vkQueryUtil.hpp"
28 #include "vkObjUtil.hpp"
29 #include "vkBarrierUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 
32 #include "deStringUtil.hpp"
33 #include "deSTLUtil.hpp"
34 
#include <algorithm>
#include <limits>
#include <map>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>
41 
42 #include "SPIRV/spirv.hpp"
43 
44 namespace vk
45 {
46 
47 #ifndef CTS_USES_VULKANSC
48 
// Long-running loops are expected to touch the test watchdog once every
// WATCHDOG_INTERVAL iterations to avoid spurious timeouts (usage is outside
// this chunk).
static const uint32_t WATCHDOG_INTERVAL = 16384; // Touch watchDog every N iterations.
50 
// Per-thread parameter bundle for joining a deferred operation: one instance
// per worker thread, filled in by the spawning code and read back after join.
struct DeferredThreadParams
{
    const DeviceInterface &vk;                // Device-level entry points used to join the operation.
    VkDevice device;                          // Device that owns the deferred operation.
    VkDeferredOperationKHR deferredOperation; // The operation to join and finish.
    VkResult result;                          // Outcome written by finishDeferredOperationThreaded().
};
58 
getFormatSimpleName(vk::VkFormat format)59 std::string getFormatSimpleName(vk::VkFormat format)
60 {
61     constexpr size_t kPrefixLen = 10; // strlen("VK_FORMAT_")
62     return de::toLower(de::toString(format).substr(kPrefixLen));
63 }
64 
pointInTriangle2D(const tcu::Vec3 & p,const tcu::Vec3 & p0,const tcu::Vec3 & p1,const tcu::Vec3 & p2)65 bool pointInTriangle2D(const tcu::Vec3 &p, const tcu::Vec3 &p0, const tcu::Vec3 &p1, const tcu::Vec3 &p2)
66 {
67     float s = p0.y() * p2.x() - p0.x() * p2.y() + (p2.y() - p0.y()) * p.x() + (p0.x() - p2.x()) * p.y();
68     float t = p0.x() * p1.y() - p0.y() * p1.x() + (p0.y() - p1.y()) * p.x() + (p1.x() - p0.x()) * p.y();
69 
70     if ((s < 0) != (t < 0))
71         return false;
72 
73     float a = -p1.y() * p2.x() + p0.y() * (p2.x() - p1.x()) + p0.x() * (p1.y() - p2.y()) + p1.x() * p2.y();
74 
75     return a < 0 ? (s <= 0 && s + t >= a) : (s >= 0 && s + t <= a);
76 }
77 
78 // Returns true if VK_FORMAT_FEATURE_ACCELERATION_STRUCTURE_VERTEX_BUFFER_BIT_KHR needs to be supported for the given format.
isMandatoryAccelerationStructureVertexBufferFormat(vk::VkFormat format)79 static bool isMandatoryAccelerationStructureVertexBufferFormat(vk::VkFormat format)
80 {
81     bool mandatory = false;
82 
83     switch (format)
84     {
85     case VK_FORMAT_R32G32_SFLOAT:
86     case VK_FORMAT_R32G32B32_SFLOAT:
87     case VK_FORMAT_R16G16_SFLOAT:
88     case VK_FORMAT_R16G16B16A16_SFLOAT:
89     case VK_FORMAT_R16G16_SNORM:
90     case VK_FORMAT_R16G16B16A16_SNORM:
91         mandatory = true;
92         break;
93     default:
94         break;
95     }
96 
97     return mandatory;
98 }
99 
checkAccelerationStructureVertexBufferFormat(const vk::InstanceInterface & vki,vk::VkPhysicalDevice physicalDevice,vk::VkFormat format)100 void checkAccelerationStructureVertexBufferFormat(const vk::InstanceInterface &vki, vk::VkPhysicalDevice physicalDevice,
101                                                   vk::VkFormat format)
102 {
103     const vk::VkFormatProperties formatProperties = getPhysicalDeviceFormatProperties(vki, physicalDevice, format);
104 
105     if ((formatProperties.bufferFeatures & vk::VK_FORMAT_FEATURE_ACCELERATION_STRUCTURE_VERTEX_BUFFER_BIT_KHR) == 0u)
106     {
107         const std::string errorMsg = "Format not supported for acceleration structure vertex buffers";
108         if (isMandatoryAccelerationStructureVertexBufferFormat(format))
109             TCU_FAIL(errorMsg);
110         TCU_THROW(NotSupportedError, errorMsg);
111     }
112 }
113 
getCommonRayGenerationShader(void)114 std::string getCommonRayGenerationShader(void)
115 {
116     return "#version 460 core\n"
117            "#extension GL_EXT_ray_tracing : require\n"
118            "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
119            "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
120            "\n"
121            "void main()\n"
122            "{\n"
123            "  uint  rayFlags = 0;\n"
124            "  uint  cullMask = 0xFF;\n"
125            "  float tmin     = 0.0;\n"
126            "  float tmax     = 9.0;\n"
127            "  vec3  origin   = vec3((float(gl_LaunchIDEXT.x) + 0.5f) / float(gl_LaunchSizeEXT.x), "
128            "(float(gl_LaunchIDEXT.y) + 0.5f) / float(gl_LaunchSizeEXT.y), 0.0);\n"
129            "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
130            "  traceRayEXT(topLevelAS, rayFlags, cullMask, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
131            "}\n";
132 }
133 
// Base-class constructor: records the geometry/vertex/index configuration.
// Geometry flags start cleared and no opacity micromap is attached.
RaytracedGeometryBase::RaytracedGeometryBase(VkGeometryTypeKHR geometryType, VkFormat vertexFormat,
                                             VkIndexType indexType)
    : m_geometryType(geometryType)
    , m_vertexFormat(vertexFormat)
    , m_indexType(indexType)
    , m_geometryFlags((VkGeometryFlagsKHR)0u) // no flags by default
    , m_hasOpacityMicromap(false)
{
    // AABB geometry is only accepted with 32-bit float triplets.
    if (m_geometryType == VK_GEOMETRY_TYPE_AABBS_KHR)
        DE_ASSERT(m_vertexFormat == VK_FORMAT_R32G32B32_SFLOAT);
}
145 
// Empty destructor, defined out of line intentionally.
RaytracedGeometryBase::~RaytracedGeometryBase()
{
}
149 
// Parameters forwarded to buildRaytracedGeometry(): which geometry type to
// create and whether the vertex data should be padded.
struct GeometryBuilderParams
{
    VkGeometryTypeKHR geometryType; // Geometry type for the new RaytracedGeometry.
    bool usePadding;                // Converted to the padding argument (1u/0u) of the constructor.
};
155 
156 template <typename V, typename I>
buildRaytracedGeometry(const GeometryBuilderParams & params)157 RaytracedGeometryBase *buildRaytracedGeometry(const GeometryBuilderParams &params)
158 {
159     return new RaytracedGeometry<V, I>(params.geometryType, (params.usePadding ? 1u : 0u));
160 }
161 
makeRaytracedGeometry(VkGeometryTypeKHR geometryType,VkFormat vertexFormat,VkIndexType indexType,bool padVertices)162 de::SharedPtr<RaytracedGeometryBase> makeRaytracedGeometry(VkGeometryTypeKHR geometryType, VkFormat vertexFormat,
163                                                            VkIndexType indexType, bool padVertices)
164 {
165     const GeometryBuilderParams builderParams{geometryType, padVertices};
166 
167     switch (vertexFormat)
168     {
169     case VK_FORMAT_R32G32_SFLOAT:
170         switch (indexType)
171         {
172         case VK_INDEX_TYPE_UINT16:
173             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, uint16_t>(builderParams));
174         case VK_INDEX_TYPE_UINT32:
175             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, uint32_t>(builderParams));
176         case VK_INDEX_TYPE_NONE_KHR:
177             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, EmptyIndex>(builderParams));
178         default:
179             TCU_THROW(InternalError, "Wrong index type");
180         }
181     case VK_FORMAT_R32G32B32_SFLOAT:
182         switch (indexType)
183         {
184         case VK_INDEX_TYPE_UINT16:
185             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, uint16_t>(builderParams));
186         case VK_INDEX_TYPE_UINT32:
187             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, uint32_t>(builderParams));
188         case VK_INDEX_TYPE_NONE_KHR:
189             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, EmptyIndex>(builderParams));
190         default:
191             TCU_THROW(InternalError, "Wrong index type");
192         }
193     case VK_FORMAT_R32G32B32A32_SFLOAT:
194         switch (indexType)
195         {
196         case VK_INDEX_TYPE_UINT16:
197             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, uint16_t>(builderParams));
198         case VK_INDEX_TYPE_UINT32:
199             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, uint32_t>(builderParams));
200         case VK_INDEX_TYPE_NONE_KHR:
201             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, EmptyIndex>(builderParams));
202         default:
203             TCU_THROW(InternalError, "Wrong index type");
204         }
205     case VK_FORMAT_R16G16_SFLOAT:
206         switch (indexType)
207         {
208         case VK_INDEX_TYPE_UINT16:
209             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, uint16_t>(builderParams));
210         case VK_INDEX_TYPE_UINT32:
211             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, uint32_t>(builderParams));
212         case VK_INDEX_TYPE_NONE_KHR:
213             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, EmptyIndex>(builderParams));
214         default:
215             TCU_THROW(InternalError, "Wrong index type");
216         }
217     case VK_FORMAT_R16G16B16_SFLOAT:
218         switch (indexType)
219         {
220         case VK_INDEX_TYPE_UINT16:
221             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, uint16_t>(builderParams));
222         case VK_INDEX_TYPE_UINT32:
223             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, uint32_t>(builderParams));
224         case VK_INDEX_TYPE_NONE_KHR:
225             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, EmptyIndex>(builderParams));
226         default:
227             TCU_THROW(InternalError, "Wrong index type");
228         }
229     case VK_FORMAT_R16G16B16A16_SFLOAT:
230         switch (indexType)
231         {
232         case VK_INDEX_TYPE_UINT16:
233             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, uint16_t>(builderParams));
234         case VK_INDEX_TYPE_UINT32:
235             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, uint32_t>(builderParams));
236         case VK_INDEX_TYPE_NONE_KHR:
237             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, EmptyIndex>(builderParams));
238         default:
239             TCU_THROW(InternalError, "Wrong index type");
240         }
241     case VK_FORMAT_R16G16_SNORM:
242         switch (indexType)
243         {
244         case VK_INDEX_TYPE_UINT16:
245             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16SNorm, uint16_t>(builderParams));
246         case VK_INDEX_TYPE_UINT32:
247             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16SNorm, uint32_t>(builderParams));
248         case VK_INDEX_TYPE_NONE_KHR:
249             return de::SharedPtr<RaytracedGeometryBase>(
250                 buildRaytracedGeometry<Vec2_16SNorm, EmptyIndex>(builderParams));
251         default:
252             TCU_THROW(InternalError, "Wrong index type");
253         }
254     case VK_FORMAT_R16G16B16_SNORM:
255         switch (indexType)
256         {
257         case VK_INDEX_TYPE_UINT16:
258             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16SNorm, uint16_t>(builderParams));
259         case VK_INDEX_TYPE_UINT32:
260             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16SNorm, uint32_t>(builderParams));
261         case VK_INDEX_TYPE_NONE_KHR:
262             return de::SharedPtr<RaytracedGeometryBase>(
263                 buildRaytracedGeometry<Vec3_16SNorm, EmptyIndex>(builderParams));
264         default:
265             TCU_THROW(InternalError, "Wrong index type");
266         }
267     case VK_FORMAT_R16G16B16A16_SNORM:
268         switch (indexType)
269         {
270         case VK_INDEX_TYPE_UINT16:
271             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16SNorm, uint16_t>(builderParams));
272         case VK_INDEX_TYPE_UINT32:
273             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16SNorm, uint32_t>(builderParams));
274         case VK_INDEX_TYPE_NONE_KHR:
275             return de::SharedPtr<RaytracedGeometryBase>(
276                 buildRaytracedGeometry<Vec4_16SNorm, EmptyIndex>(builderParams));
277         default:
278             TCU_THROW(InternalError, "Wrong index type");
279         }
280     case VK_FORMAT_R64G64_SFLOAT:
281         switch (indexType)
282         {
283         case VK_INDEX_TYPE_UINT16:
284             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, uint16_t>(builderParams));
285         case VK_INDEX_TYPE_UINT32:
286             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, uint32_t>(builderParams));
287         case VK_INDEX_TYPE_NONE_KHR:
288             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, EmptyIndex>(builderParams));
289         default:
290             TCU_THROW(InternalError, "Wrong index type");
291         }
292     case VK_FORMAT_R64G64B64_SFLOAT:
293         switch (indexType)
294         {
295         case VK_INDEX_TYPE_UINT16:
296             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, uint16_t>(builderParams));
297         case VK_INDEX_TYPE_UINT32:
298             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, uint32_t>(builderParams));
299         case VK_INDEX_TYPE_NONE_KHR:
300             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, EmptyIndex>(builderParams));
301         default:
302             TCU_THROW(InternalError, "Wrong index type");
303         }
304     case VK_FORMAT_R64G64B64A64_SFLOAT:
305         switch (indexType)
306         {
307         case VK_INDEX_TYPE_UINT16:
308             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, uint16_t>(builderParams));
309         case VK_INDEX_TYPE_UINT32:
310             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, uint32_t>(builderParams));
311         case VK_INDEX_TYPE_NONE_KHR:
312             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, EmptyIndex>(builderParams));
313         default:
314             TCU_THROW(InternalError, "Wrong index type");
315         }
316     case VK_FORMAT_R8G8_SNORM:
317         switch (indexType)
318         {
319         case VK_INDEX_TYPE_UINT16:
320             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, uint16_t>(builderParams));
321         case VK_INDEX_TYPE_UINT32:
322             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, uint32_t>(builderParams));
323         case VK_INDEX_TYPE_NONE_KHR:
324             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, EmptyIndex>(builderParams));
325         default:
326             TCU_THROW(InternalError, "Wrong index type");
327         }
328     case VK_FORMAT_R8G8B8_SNORM:
329         switch (indexType)
330         {
331         case VK_INDEX_TYPE_UINT16:
332             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, uint16_t>(builderParams));
333         case VK_INDEX_TYPE_UINT32:
334             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, uint32_t>(builderParams));
335         case VK_INDEX_TYPE_NONE_KHR:
336             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, EmptyIndex>(builderParams));
337         default:
338             TCU_THROW(InternalError, "Wrong index type");
339         }
340     case VK_FORMAT_R8G8B8A8_SNORM:
341         switch (indexType)
342         {
343         case VK_INDEX_TYPE_UINT16:
344             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, uint16_t>(builderParams));
345         case VK_INDEX_TYPE_UINT32:
346             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, uint32_t>(builderParams));
347         case VK_INDEX_TYPE_NONE_KHR:
348             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, EmptyIndex>(builderParams));
349         default:
350             TCU_THROW(InternalError, "Wrong index type");
351         }
352     default:
353         TCU_THROW(InternalError, "Wrong vertex format");
354     }
355 }
356 
getBufferDeviceAddress(const DeviceInterface & vk,const VkDevice device,const VkBuffer buffer,VkDeviceSize offset)357 VkDeviceAddress getBufferDeviceAddress(const DeviceInterface &vk, const VkDevice device, const VkBuffer buffer,
358                                        VkDeviceSize offset)
359 {
360 
361     if (buffer == VK_NULL_HANDLE)
362         return 0;
363 
364     VkBufferDeviceAddressInfo deviceAddressInfo{
365         VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType    sType
366         nullptr,                                      // const void*        pNext
367         buffer                                        // VkBuffer           buffer;
368     };
369     return vk.getBufferDeviceAddress(device, &deviceAddressInfo) + offset;
370 }
371 
makeQueryPool(const DeviceInterface & vk,const VkDevice device,const VkQueryType queryType,uint32_t queryCount)372 static inline Move<VkQueryPool> makeQueryPool(const DeviceInterface &vk, const VkDevice device,
373                                               const VkQueryType queryType, uint32_t queryCount)
374 {
375     const VkQueryPoolCreateInfo queryPoolCreateInfo = {
376         VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO, // sType
377         nullptr,                                  // pNext
378         (VkQueryPoolCreateFlags)0,                // flags
379         queryType,                                // queryType
380         queryCount,                               // queryCount
381         0u,                                       // pipelineStatistics
382     };
383     return createQueryPool(vk, device, &queryPoolCreateInfo);
384 }
385 
makeVkAccelerationStructureGeometryDataKHR(const VkAccelerationStructureGeometryTrianglesDataKHR & triangles)386 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureGeometryDataKHR(
387     const VkAccelerationStructureGeometryTrianglesDataKHR &triangles)
388 {
389     VkAccelerationStructureGeometryDataKHR result;
390 
391     deMemset(&result, 0, sizeof(result));
392 
393     result.triangles = triangles;
394 
395     return result;
396 }
397 
makeVkAccelerationStructureGeometryDataKHR(const VkAccelerationStructureGeometryAabbsDataKHR & aabbs)398 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureGeometryDataKHR(
399     const VkAccelerationStructureGeometryAabbsDataKHR &aabbs)
400 {
401     VkAccelerationStructureGeometryDataKHR result;
402 
403     deMemset(&result, 0, sizeof(result));
404 
405     result.aabbs = aabbs;
406 
407     return result;
408 }
409 
makeVkAccelerationStructureInstancesDataKHR(const VkAccelerationStructureGeometryInstancesDataKHR & instances)410 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureInstancesDataKHR(
411     const VkAccelerationStructureGeometryInstancesDataKHR &instances)
412 {
413     VkAccelerationStructureGeometryDataKHR result;
414 
415     deMemset(&result, 0, sizeof(result));
416 
417     result.instances = instances;
418 
419     return result;
420 }
421 
// Convenience constructor for VkAccelerationStructureInstanceKHR. The custom
// index, mask, SBT record offset and flags are packed fields, so each input is
// masked to its width (24 / 8 / 24 / 8 bits) before assignment.
static inline VkAccelerationStructureInstanceKHR makeVkAccelerationStructureInstanceKHR(
    const VkTransformMatrixKHR &transform, uint32_t instanceCustomIndex, uint32_t mask,
    uint32_t instanceShaderBindingTableRecordOffset, VkGeometryInstanceFlagsKHR flags,
    uint64_t accelerationStructureReference)
{
    VkAccelerationStructureInstanceKHR instance     = {transform, 0, 0, 0, 0, accelerationStructureReference};
    instance.instanceCustomIndex                    = instanceCustomIndex & 0xFFFFFF;
    instance.mask                                   = mask & 0xFF;
    instance.instanceShaderBindingTableRecordOffset = instanceShaderBindingTableRecordOffset & 0xFFFFFF;
    instance.flags                                  = flags & 0xFF;
    return instance;
}
434 
// Thin wrapper over vkGetRayTracingShaderGroupHandlesKHR.
VkResult getRayTracingShaderGroupHandlesKHR(const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline,
                                            const uint32_t firstGroup, const uint32_t groupCount,
                                            const uintptr_t dataSize, void *pData)
{
    return vk.getRayTracingShaderGroupHandlesKHR(device, pipeline, firstGroup, groupCount, dataSize, pData);
}
441 
// Dispatches to the KHR variant; kept as a separate entry point so callers do
// not need to pick an extension suffix.
VkResult getRayTracingShaderGroupHandles(const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline,
                                         const uint32_t firstGroup, const uint32_t groupCount, const uintptr_t dataSize,
                                         void *pData)
{
    return getRayTracingShaderGroupHandlesKHR(vk, device, pipeline, firstGroup, groupCount, dataSize, pData);
}
448 
// Thin wrapper over vkGetRayTracingCaptureReplayShaderGroupHandlesKHR.
VkResult getRayTracingCaptureReplayShaderGroupHandles(const DeviceInterface &vk, const VkDevice device,
                                                      const VkPipeline pipeline, const uint32_t firstGroup,
                                                      const uint32_t groupCount, const uintptr_t dataSize, void *pData)
{
    return vk.getRayTracingCaptureReplayShaderGroupHandlesKHR(device, pipeline, firstGroup, groupCount, dataSize,
                                                              pData);
}
456 
finishDeferredOperation(const DeviceInterface & vk,VkDevice device,VkDeferredOperationKHR deferredOperation)457 VkResult finishDeferredOperation(const DeviceInterface &vk, VkDevice device, VkDeferredOperationKHR deferredOperation)
458 {
459     VkResult result = vk.deferredOperationJoinKHR(device, deferredOperation);
460 
461     while (result == VK_THREAD_IDLE_KHR)
462     {
463         std::this_thread::yield();
464         result = vk.deferredOperationJoinKHR(device, deferredOperation);
465     }
466 
467     switch (result)
468     {
469     case VK_SUCCESS:
470     {
471         // Deferred operation has finished. Query its result
472         result = vk.getDeferredOperationResultKHR(device, deferredOperation);
473 
474         break;
475     }
476 
477     case VK_THREAD_DONE_KHR:
478     {
479         // Deferred operation is being wrapped up by another thread
480         // wait for that thread to finish
481         do
482         {
483             std::this_thread::yield();
484             result = vk.getDeferredOperationResultKHR(device, deferredOperation);
485         } while (result == VK_NOT_READY);
486 
487         break;
488     }
489 
490     default:
491     {
492         DE_ASSERT(false);
493 
494         break;
495     }
496     }
497 
498     return result;
499 }
500 
// Thread entry point: joins the deferred operation described by the params
// and stores the per-thread outcome in params->result.
void finishDeferredOperationThreaded(DeferredThreadParams *deferredThreadParams)
{
    deferredThreadParams->result = finishDeferredOperation(deferredThreadParams->vk, deferredThreadParams->device,
                                                           deferredThreadParams->deferredOperation);
}
506 
finishDeferredOperation(const DeviceInterface & vk,VkDevice device,VkDeferredOperationKHR deferredOperation,const uint32_t workerThreadCount,const bool operationNotDeferred)507 void finishDeferredOperation(const DeviceInterface &vk, VkDevice device, VkDeferredOperationKHR deferredOperation,
508                              const uint32_t workerThreadCount, const bool operationNotDeferred)
509 {
510 
511     if (operationNotDeferred)
512     {
513         // when the operation deferral returns VK_OPERATION_NOT_DEFERRED_KHR,
514         // the deferred operation should act as if no command was deferred
515         VK_CHECK(vk.getDeferredOperationResultKHR(device, deferredOperation));
516 
517         // there is not need to join any threads to the deferred operation,
518         // so below can be skipped.
519         return;
520     }
521 
522     if (workerThreadCount == 0)
523     {
524         VK_CHECK(finishDeferredOperation(vk, device, deferredOperation));
525     }
526     else
527     {
528         const uint32_t maxThreadCountSupported =
529             deMinu32(256u, vk.getDeferredOperationMaxConcurrencyKHR(device, deferredOperation));
530         const uint32_t requestedThreadCount = workerThreadCount;
531         const uint32_t testThreadCount      = requestedThreadCount == std::numeric_limits<uint32_t>::max() ?
532                                                   maxThreadCountSupported :
533                                                   requestedThreadCount;
534 
535         if (maxThreadCountSupported == 0)
536             TCU_FAIL("vkGetDeferredOperationMaxConcurrencyKHR must not return 0");
537 
538         const DeferredThreadParams deferredThreadParams = {
539             vk,                 //  const DeviceInterface& vk;
540             device,             //  VkDevice device;
541             deferredOperation,  //  VkDeferredOperationKHR deferredOperation;
542             VK_RESULT_MAX_ENUM, //  VResult result;
543         };
544         std::vector<DeferredThreadParams> threadParams(testThreadCount, deferredThreadParams);
545         std::vector<de::MovePtr<std::thread>> threads(testThreadCount);
546         bool executionResult = false;
547 
548         DE_ASSERT(threads.size() > 0 && threads.size() == testThreadCount);
549 
550         for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
551             threads[threadNdx] =
552                 de::MovePtr<std::thread>(new std::thread(finishDeferredOperationThreaded, &threadParams[threadNdx]));
553 
554         for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
555             threads[threadNdx]->join();
556 
557         for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
558             if (threadParams[threadNdx].result == VK_SUCCESS)
559                 executionResult = true;
560 
561         if (!executionResult)
562             TCU_FAIL("Neither reported VK_SUCCESS");
563     }
564 }
565 
// Allocates a host-visible serialization buffer of 'storageSize' bytes,
// usable as acceleration structure storage with a device address. Cached
// memory is requested first; if that combination is not supported, the
// allocation is retried without the Cached flag.
SerialStorage::SerialStorage(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
                             const VkAccelerationStructureBuildTypeKHR buildType, const VkDeviceSize storageSize)
    : m_buildType(buildType)
    , m_storageSize(storageSize)
    , m_serialInfo() // empty: this storage was not created from a SerialInfo
{
    const VkBufferCreateInfo bufferCreateInfo =
        makeBufferCreateInfo(storageSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
                                              VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
    try
    {
        m_buffer = de::MovePtr<BufferWithMemory>(
            new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
                                 MemoryRequirement::Cached | MemoryRequirement::HostVisible |
                                     MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
    }
    catch (const tcu::NotSupportedError &)
    {
        // retry without Cached flag
        m_buffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
            vk, device, allocator, bufferCreateInfo,
            MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
    }
}
590 
SerialStorage(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkAccelerationStructureBuildTypeKHR buildType,const SerialInfo & serialInfo)591 SerialStorage::SerialStorage(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
592                              const VkAccelerationStructureBuildTypeKHR buildType, const SerialInfo &serialInfo)
593     : m_buildType(buildType)
594     , m_storageSize(serialInfo.sizes()[0]) // raise assertion if serialInfo is empty
595     , m_serialInfo(serialInfo)
596 {
597     DE_ASSERT(serialInfo.sizes().size() >= 2u);
598 
599     // create buffer for top-level acceleration structure
600     {
601         const VkBufferCreateInfo bufferCreateInfo =
602             makeBufferCreateInfo(m_storageSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
603                                                     VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
604         m_buffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
605             vk, device, allocator, bufferCreateInfo,
606             MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
607     }
608 
609     // create buffers for bottom-level acceleration structures
610     {
611         std::vector<uint64_t> addrs;
612 
613         for (std::size_t i = 1; i < serialInfo.addresses().size(); ++i)
614         {
615             const uint64_t &lookAddr = serialInfo.addresses()[i];
616             auto end                 = addrs.end();
617             auto match = std::find_if(addrs.begin(), end, [&](const uint64_t &item) { return item == lookAddr; });
618             if (match == end)
619             {
620                 addrs.emplace_back(lookAddr);
621                 m_bottoms.emplace_back(de::SharedPtr<SerialStorage>(
622                     new SerialStorage(vk, device, allocator, buildType, serialInfo.sizes()[i])));
623             }
624         }
625     }
626 }
627 
getAddress(const DeviceInterface & vk,const VkDevice device,const VkAccelerationStructureBuildTypeKHR buildType)628 VkDeviceOrHostAddressKHR SerialStorage::getAddress(const DeviceInterface &vk, const VkDevice device,
629                                                    const VkAccelerationStructureBuildTypeKHR buildType)
630 {
631     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
632         return makeDeviceOrHostAddressKHR(vk, device, m_buffer->get(), 0);
633     else
634         return makeDeviceOrHostAddressKHR(m_buffer->getAllocation().getHostPtr());
635 }
636 
// Views the beginning of the storage buffer as an AccelerationStructureHeader
// through the host-visible mapping.
SerialStorage::AccelerationStructureHeader *SerialStorage::getASHeader()
{
    return reinterpret_cast<AccelerationStructureHeader *>(getHostAddress().hostAddress);
}
641 
// True when the serialization info carries sizes beyond the top-level structure,
// i.e. bottom-level structures were serialized as well ("deep" format).
bool SerialStorage::hasDeepFormat() const
{
    return (m_serialInfo.sizes().size() >= 2u);
}
646 
getBottomStorage(uint32_t index) const647 de::SharedPtr<SerialStorage> SerialStorage::getBottomStorage(uint32_t index) const
648 {
649     return m_bottoms[index];
650 }
651 
getHostAddress(VkDeviceSize offset)652 VkDeviceOrHostAddressKHR SerialStorage::getHostAddress(VkDeviceSize offset)
653 {
654     DE_ASSERT(offset < m_storageSize);
655     return makeDeviceOrHostAddressKHR(static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr()) + offset);
656 }
657 
getHostAddressConst(VkDeviceSize offset)658 VkDeviceOrHostAddressConstKHR SerialStorage::getHostAddressConst(VkDeviceSize offset)
659 {
660     return makeDeviceOrHostAddressConstKHR(static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr()) + offset);
661 }
662 
getAddressConst(const DeviceInterface & vk,const VkDevice device,const VkAccelerationStructureBuildTypeKHR buildType)663 VkDeviceOrHostAddressConstKHR SerialStorage::getAddressConst(const DeviceInterface &vk, const VkDevice device,
664                                                              const VkAccelerationStructureBuildTypeKHR buildType)
665 {
666     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
667         return makeDeviceOrHostAddressConstKHR(vk, device, m_buffer->get(), 0);
668     else
669         return getHostAddressConst();
670 }
671 
// Size in bytes of the backing storage buffer.
inline VkDeviceSize SerialStorage::getStorageSize() const
{
    return m_storageSize;
}
676 
// Serialization info (addresses and sizes) this storage was created from.
inline const SerialInfo &SerialStorage::getSerialInfo() const
{
    return m_serialInfo;
}
681 
getDeserializedSize()682 uint64_t SerialStorage::getDeserializedSize()
683 {
684     uint64_t result         = 0;
685     const uint8_t *startPtr = static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr());
686 
687     DE_ASSERT(sizeof(result) == DESERIALIZED_SIZE_SIZE);
688 
689     deMemcpy(&result, startPtr + DESERIALIZED_SIZE_OFFSET, sizeof(result));
690 
691     return result;
692 }
693 
~BottomLevelAccelerationStructure()694 BottomLevelAccelerationStructure::~BottomLevelAccelerationStructure()
695 {
696 }
697 
// All sizes start at zero; they are filled in when the structure is created
// (see the derived implementation's create()).
BottomLevelAccelerationStructure::BottomLevelAccelerationStructure()
    : m_structureSize(0u)
    , m_updateScratchSize(0u)
    , m_buildScratchSize(0u)
{
}
704 
setGeometryData(const std::vector<tcu::Vec3> & geometryData,const bool triangles,const VkGeometryFlagsKHR geometryFlags)705 void BottomLevelAccelerationStructure::setGeometryData(const std::vector<tcu::Vec3> &geometryData, const bool triangles,
706                                                        const VkGeometryFlagsKHR geometryFlags)
707 {
708     if (triangles)
709         DE_ASSERT((geometryData.size() % 3) == 0);
710     else
711         DE_ASSERT((geometryData.size() % 2) == 0);
712 
713     setGeometryCount(1u);
714 
715     addGeometry(geometryData, triangles, geometryFlags);
716 }
717 
setDefaultGeometryData(const VkShaderStageFlagBits testStage,const VkGeometryFlagsKHR geometryFlags)718 void BottomLevelAccelerationStructure::setDefaultGeometryData(const VkShaderStageFlagBits testStage,
719                                                               const VkGeometryFlagsKHR geometryFlags)
720 {
721     bool trianglesData = false;
722     float z            = 0.0f;
723     std::vector<tcu::Vec3> geometryData;
724 
725     switch (testStage)
726     {
727     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
728         z             = -1.0f;
729         trianglesData = true;
730         break;
731     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
732         z             = -1.0f;
733         trianglesData = true;
734         break;
735     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
736         z             = -1.0f;
737         trianglesData = true;
738         break;
739     case VK_SHADER_STAGE_MISS_BIT_KHR:
740         z             = -9.9f;
741         trianglesData = true;
742         break;
743     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
744         z             = -1.0f;
745         trianglesData = false;
746         break;
747     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
748         z             = -1.0f;
749         trianglesData = true;
750         break;
751     default:
752         TCU_THROW(InternalError, "Unacceptable stage");
753     }
754 
755     if (trianglesData)
756     {
757         geometryData.reserve(6);
758 
759         geometryData.push_back(tcu::Vec3(-1.0f, -1.0f, z));
760         geometryData.push_back(tcu::Vec3(-1.0f, +1.0f, z));
761         geometryData.push_back(tcu::Vec3(+1.0f, -1.0f, z));
762         geometryData.push_back(tcu::Vec3(+1.0f, -1.0f, z));
763         geometryData.push_back(tcu::Vec3(-1.0f, +1.0f, z));
764         geometryData.push_back(tcu::Vec3(+1.0f, +1.0f, z));
765     }
766     else
767     {
768         geometryData.reserve(2);
769 
770         geometryData.push_back(tcu::Vec3(-1.0f, -1.0f, z));
771         geometryData.push_back(tcu::Vec3(+1.0f, +1.0f, z));
772     }
773 
774     setGeometryCount(1u);
775 
776     addGeometry(geometryData, trianglesData, geometryFlags);
777 }
778 
// Discards any previously added geometries and reserves room for the expected
// number of geometries to be added next.
void BottomLevelAccelerationStructure::setGeometryCount(const size_t geometryCount)
{
    m_geometriesData.clear();

    m_geometriesData.reserve(geometryCount);
}
785 
// Appends an already-constructed geometry to the structure's geometry list.
void BottomLevelAccelerationStructure::addGeometry(de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry)
{
    m_geometriesData.push_back(raytracedGeometry);
}
790 
addGeometry(const std::vector<tcu::Vec3> & geometryData,const bool triangles,const VkGeometryFlagsKHR geometryFlags,const VkAccelerationStructureTrianglesOpacityMicromapEXT * opacityGeometryMicromap)791 void BottomLevelAccelerationStructure::addGeometry(
792     const std::vector<tcu::Vec3> &geometryData, const bool triangles, const VkGeometryFlagsKHR geometryFlags,
793     const VkAccelerationStructureTrianglesOpacityMicromapEXT *opacityGeometryMicromap)
794 {
795     DE_ASSERT(geometryData.size() > 0);
796     DE_ASSERT((triangles && geometryData.size() % 3 == 0) || (!triangles && geometryData.size() % 2 == 0));
797 
798     if (!triangles)
799         for (size_t posNdx = 0; posNdx < geometryData.size() / 2; ++posNdx)
800         {
801             DE_ASSERT(geometryData[2 * posNdx].x() <= geometryData[2 * posNdx + 1].x());
802             DE_ASSERT(geometryData[2 * posNdx].y() <= geometryData[2 * posNdx + 1].y());
803             DE_ASSERT(geometryData[2 * posNdx].z() <= geometryData[2 * posNdx + 1].z());
804         }
805 
806     de::SharedPtr<RaytracedGeometryBase> geometry =
807         makeRaytracedGeometry(triangles ? VK_GEOMETRY_TYPE_TRIANGLES_KHR : VK_GEOMETRY_TYPE_AABBS_KHR,
808                               VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
809     for (auto it = begin(geometryData), eit = end(geometryData); it != eit; ++it)
810         geometry->addVertex(*it);
811 
812     geometry->setGeometryFlags(geometryFlags);
813     if (opacityGeometryMicromap)
814         geometry->setOpacityMicromap(opacityGeometryMicromap);
815     addGeometry(geometry);
816 }
817 
getStructureBuildSizes() const818 VkAccelerationStructureBuildSizesInfoKHR BottomLevelAccelerationStructure::getStructureBuildSizes() const
819 {
820     return {
821         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
822         nullptr,                                                       //  const void* pNext;
823         m_structureSize,                                               //  VkDeviceSize accelerationStructureSize;
824         m_updateScratchSize,                                           //  VkDeviceSize updateScratchSize;
825         m_buildScratchSize                                             //  VkDeviceSize buildScratchSize;
826     };
827 };
828 
getVertexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)829 VkDeviceSize getVertexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
830 {
831     DE_ASSERT(geometriesData.size() != 0);
832     VkDeviceSize bufferSizeBytes = 0;
833     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
834         bufferSizeBytes += deAlignSize(geometriesData[geometryNdx]->getVertexByteSize(), 8);
835     return bufferSizeBytes;
836 }
837 
createVertexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkDeviceSize bufferSizeBytes)838 BufferWithMemory *createVertexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
839                                      const VkDeviceSize bufferSizeBytes)
840 {
841     const VkBufferCreateInfo bufferCreateInfo =
842         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
843                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
844     return new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
845                                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
846                                     MemoryRequirement::DeviceAddress);
847 }
848 
// Convenience overload: sizes the buffer from the geometries' vertex data.
BufferWithMemory *createVertexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
                                     const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
{
    return createVertexBuffer(vk, device, allocator, getVertexBufferSize(geometriesData));
}
854 
updateVertexBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData,BufferWithMemory * vertexBuffer,VkDeviceSize geometriesOffset=0)855 void updateVertexBuffer(const DeviceInterface &vk, const VkDevice device,
856                         const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData,
857                         BufferWithMemory *vertexBuffer, VkDeviceSize geometriesOffset = 0)
858 {
859     const Allocation &geometryAlloc = vertexBuffer->getAllocation();
860     uint8_t *bufferStart            = static_cast<uint8_t *>(geometryAlloc.getHostPtr());
861     VkDeviceSize bufferOffset       = geometriesOffset;
862 
863     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
864     {
865         const void *geometryPtr      = geometriesData[geometryNdx]->getVertexPointer();
866         const size_t geometryPtrSize = geometriesData[geometryNdx]->getVertexByteSize();
867 
868         deMemcpy(&bufferStart[bufferOffset], geometryPtr, geometryPtrSize);
869 
870         bufferOffset += deAlignSize(geometryPtrSize, 8);
871     }
872 
873     // Flush the whole allocation. We could flush only the interesting range, but we'd need to be sure both the offset and size
874     // align to VkPhysicalDeviceLimits::nonCoherentAtomSize, which we are not considering. Also note most code uses Coherent memory
875     // for the vertex and index buffers, so flushing is actually not needed.
876     flushAlloc(vk, device, geometryAlloc);
877 }
878 
getIndexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)879 VkDeviceSize getIndexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
880 {
881     DE_ASSERT(!geometriesData.empty());
882 
883     VkDeviceSize bufferSizeBytes = 0;
884     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
885         if (geometriesData[geometryNdx]->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
886             bufferSizeBytes += deAlignSize(geometriesData[geometryNdx]->getIndexByteSize(), 8);
887     return bufferSizeBytes;
888 }
889 
createIndexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkDeviceSize bufferSizeBytes)890 BufferWithMemory *createIndexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
891                                     const VkDeviceSize bufferSizeBytes)
892 {
893     DE_ASSERT(bufferSizeBytes);
894     const VkBufferCreateInfo bufferCreateInfo =
895         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
896                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
897     return new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
898                                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
899                                     MemoryRequirement::DeviceAddress);
900 }
901 
createIndexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)902 BufferWithMemory *createIndexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
903                                     const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
904 {
905     const VkDeviceSize bufferSizeBytes = getIndexBufferSize(geometriesData);
906     return bufferSizeBytes ? createIndexBuffer(vk, device, allocator, bufferSizeBytes) : nullptr;
907 }
908 
updateIndexBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData,BufferWithMemory * indexBuffer,VkDeviceSize geometriesOffset)909 void updateIndexBuffer(const DeviceInterface &vk, const VkDevice device,
910                        const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData,
911                        BufferWithMemory *indexBuffer, VkDeviceSize geometriesOffset)
912 {
913     const Allocation &indexAlloc = indexBuffer->getAllocation();
914     uint8_t *bufferStart         = static_cast<uint8_t *>(indexAlloc.getHostPtr());
915     VkDeviceSize bufferOffset    = geometriesOffset;
916 
917     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
918     {
919         if (geometriesData[geometryNdx]->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
920         {
921             const void *indexPtr      = geometriesData[geometryNdx]->getIndexPointer();
922             const size_t indexPtrSize = geometriesData[geometryNdx]->getIndexByteSize();
923 
924             deMemcpy(&bufferStart[bufferOffset], indexPtr, indexPtrSize);
925 
926             bufferOffset += deAlignSize(indexPtrSize, 8);
927         }
928     }
929 
930     // Flush the whole allocation. We could flush only the interesting range, but we'd need to be sure both the offset and size
931     // align to VkPhysicalDeviceLimits::nonCoherentAtomSize, which we are not considering. Also note most code uses Coherent memory
932     // for the vertex and index buffers, so flushing is actually not needed.
933     flushAlloc(vk, device, indexAlloc);
934 }
935 
// KHR (VK_KHR_acceleration_structure) implementation of the bottom-level
// acceleration structure abstraction. Owns the backing buffers (structure
// storage, vertex/index inputs, scratch) and the VkAccelerationStructureKHR
// handle, and supports host or device builds, deferred operations, indirect
// builds and (de)serialization.
class BottomLevelAccelerationStructureKHR : public BottomLevelAccelerationStructure
{
public:
    static uint32_t getRequiredAllocationCount(void);

    BottomLevelAccelerationStructureKHR();
    BottomLevelAccelerationStructureKHR(const BottomLevelAccelerationStructureKHR &other) = delete;
    virtual ~BottomLevelAccelerationStructureKHR();

    void setBuildType(const VkAccelerationStructureBuildTypeKHR buildType) override;
    VkAccelerationStructureBuildTypeKHR getBuildType() const override;
    void setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags) override;
    void setCreateGeneric(bool createGeneric) override;
    void setCreationBufferUnbounded(bool creationBufferUnbounded) override;
    void setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags) override;
    void setBuildWithoutGeometries(bool buildWithoutGeometries) override;
    void setBuildWithoutPrimitives(bool buildWithoutPrimitives) override;
    void setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount) override;
    void setUseArrayOfPointers(const bool useArrayOfPointers) override;
    void setUseMaintenance5(const bool useMaintenance5) override;
    void setIndirectBuildParameters(const VkBuffer indirectBuffer, const VkDeviceSize indirectBufferOffset,
                                    const uint32_t indirectBufferStride) override;
    VkBuildAccelerationStructureFlagsKHR getBuildFlags() const override;

    void create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator, VkDeviceSize structureSize,
                VkDeviceAddress deviceAddress = 0u, const void *pNext = nullptr,
                const MemoryRequirement &addMemoryRequirement = MemoryRequirement::Any,
                const VkBuffer creationBuffer = VK_NULL_HANDLE, const VkDeviceSize creationBufferSize = 0u) override;
    void build(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
               BottomLevelAccelerationStructure *srcAccelerationStructure = nullptr,
               VkPipelineStageFlags barrierDstStages =
                   static_cast<VkPipelineStageFlags>(VK_PIPELINE_STAGE_ALL_COMMANDS_BIT)) override;
    void copyFrom(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                  BottomLevelAccelerationStructure *accelerationStructure, bool compactCopy) override;

    void serialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                   SerialStorage *storage) override;
    void deserialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                     SerialStorage *storage) override;

    const VkAccelerationStructureKHR *getPtr(void) const override;
    void updateGeometry(size_t geometryIndex, de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry) override;

protected:
    // Build configuration captured by the setters above and consumed by
    // create()/build().
    VkAccelerationStructureBuildTypeKHR m_buildType;
    VkAccelerationStructureCreateFlagsKHR m_createFlags;
    bool m_createGeneric;
    bool m_creationBufferUnbounded;
    VkBuildAccelerationStructureFlagsKHR m_buildFlags;
    bool m_buildWithoutGeometries;
    bool m_buildWithoutPrimitives;
    bool m_deferredOperation;
    uint32_t m_workerThreadCount;
    bool m_useArrayOfPointers;
    bool m_useMaintenance5;
    // Owned buffers backing the structure and its build inputs.
    de::MovePtr<BufferWithMemory> m_accelerationStructureBuffer;
    de::MovePtr<BufferWithMemory> m_vertexBuffer;
    de::MovePtr<BufferWithMemory> m_indexBuffer;
    de::MovePtr<BufferWithMemory> m_deviceScratchBuffer;
    de::MovePtr<std::vector<uint8_t>> m_hostScratchBuffer;
    Move<VkAccelerationStructureKHR> m_accelerationStructureKHR;
    // Parameters for vkCmdBuildAccelerationStructuresIndirectKHR (unused when
    // m_indirectBuffer is VK_NULL_HANDLE).
    VkBuffer m_indirectBuffer;
    VkDeviceSize m_indirectBufferOffset;
    uint32_t m_indirectBufferStride;

    // Translates m_geometriesData into the Vulkan geometry/range structures
    // consumed by size queries and build commands.
    void prepareGeometries(
        const DeviceInterface &vk, const VkDevice device,
        std::vector<VkAccelerationStructureGeometryKHR> &accelerationStructureGeometriesKHR,
        std::vector<VkAccelerationStructureGeometryKHR *> &accelerationStructureGeometriesKHRPointers,
        std::vector<VkAccelerationStructureBuildRangeInfoKHR> &accelerationStructureBuildRangeInfoKHR,
        std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> &accelerationStructureGeometryMicromapsEXT,
        std::vector<uint32_t> &maxPrimitiveCounts, VkDeviceSize vertexBufferOffset = 0,
        VkDeviceSize indexBufferOffset = 0) const;

    // Buffer/offset accessors are virtual so derived classes can substitute
    // shared or sub-allocated buffers; the defaults use the owned buffers at
    // offset 0.
    virtual BufferWithMemory *getAccelerationStructureBuffer() const
    {
        return m_accelerationStructureBuffer.get();
    }
    virtual BufferWithMemory *getDeviceScratchBuffer() const
    {
        return m_deviceScratchBuffer.get();
    }
    virtual std::vector<uint8_t> *getHostScratchBuffer() const
    {
        return m_hostScratchBuffer.get();
    }
    virtual BufferWithMemory *getVertexBuffer() const
    {
        return m_vertexBuffer.get();
    }
    virtual BufferWithMemory *getIndexBuffer() const
    {
        return m_indexBuffer.get();
    }

    virtual VkDeviceSize getAccelerationStructureBufferOffset() const
    {
        return 0;
    }
    virtual VkDeviceSize getDeviceScratchBufferOffset() const
    {
        return 0;
    }
    virtual VkDeviceSize getVertexBufferOffset() const
    {
        return 0;
    }
    virtual VkDeviceSize getIndexBufferOffset() const
    {
        return 0;
    }
};
1048 
// Upper bound on the number of device-memory allocations one instance may
// perform. NOTE(review): the original comment referenced an m_geometryBuffer
// member that no longer exists; the likely contributors today are the
// acceleration structure buffer, the geometry (vertex/index) buffers and the
// device scratch buffer — confirm against create()/build().
uint32_t BottomLevelAccelerationStructureKHR::getRequiredAllocationCount(void)
{
    /*
        de::MovePtr<BufferWithMemory>                            m_geometryBuffer; // but only when m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR
        de::MovePtr<Allocation>                                    m_accelerationStructureAlloc;
        de::MovePtr<BufferWithMemory>                            m_deviceScratchBuffer;
    */
    return 3u;
}
1058 
~BottomLevelAccelerationStructureKHR()1059 BottomLevelAccelerationStructureKHR::~BottomLevelAccelerationStructureKHR()
1060 {
1061 }
1062 
// Defaults: device-side build, no special create/build flags, no deferred or
// indirect operation; buffers are allocated lazily in create()/build().
BottomLevelAccelerationStructureKHR::BottomLevelAccelerationStructureKHR()
    : BottomLevelAccelerationStructure()
    , m_buildType(VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    , m_createFlags(0u)
    , m_createGeneric(false)
    , m_creationBufferUnbounded(false)
    , m_buildFlags(0u)
    , m_buildWithoutGeometries(false)
    , m_buildWithoutPrimitives(false)
    , m_deferredOperation(false)
    , m_workerThreadCount(0)
    , m_useArrayOfPointers(false)
    , m_useMaintenance5(false)
    , m_accelerationStructureBuffer()
    , m_vertexBuffer()
    , m_indexBuffer()
    , m_deviceScratchBuffer()
    , m_hostScratchBuffer(new std::vector<uint8_t>)
    , m_accelerationStructureKHR()
    , m_indirectBuffer(VK_NULL_HANDLE)
    , m_indirectBufferOffset(0)
    , m_indirectBufferStride(0)
{
}
1087 
// Selects host- or device-side building; consumed by later create()/build() calls.
void BottomLevelAccelerationStructureKHR::setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)
{
    m_buildType = buildType;
}
1092 
// Currently configured build type (host or device).
VkAccelerationStructureBuildTypeKHR BottomLevelAccelerationStructureKHR::getBuildType() const
{
    return m_buildType;
}
1097 
// Flags passed to VkAccelerationStructureCreateInfoKHR at creation time.
void BottomLevelAccelerationStructureKHR::setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)
{
    m_createFlags = createFlags;
}
1102 
// When set, the structure is created with the GENERIC type instead of
// BOTTOM_LEVEL (presumably exercised by dedicated tests — see create()).
void BottomLevelAccelerationStructureKHR::setCreateGeneric(bool createGeneric)
{
    m_createGeneric = createGeneric;
}
1107 
// When set, the creation buffer's memory is not bound at buffer-creation time
// (create() skips bind-on-creation for the backing buffer).
void BottomLevelAccelerationStructureKHR::setCreationBufferUnbounded(bool creationBufferUnbounded)
{
    m_creationBufferUnbounded = creationBufferUnbounded;
}
1112 
// Flags used in VkAccelerationStructureBuildGeometryInfoKHR for size queries
// and builds.
void BottomLevelAccelerationStructureKHR::setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)
{
    m_buildFlags = buildFlags;
}
1117 
// When set, the build is issued with geometryCount forced to zero (see create()).
void BottomLevelAccelerationStructureKHR::setBuildWithoutGeometries(bool buildWithoutGeometries)
{
    m_buildWithoutGeometries = buildWithoutGeometries;
}
1122 
// When set, the build uses zero primitive counts for its geometries
// (presumably applied in prepareGeometries/build — not visible here).
void BottomLevelAccelerationStructureKHR::setBuildWithoutPrimitives(bool buildWithoutPrimitives)
{
    m_buildWithoutPrimitives = buildWithoutPrimitives;
}
1127 
// Enables building through a VkDeferredOperationKHR, optionally joined by the
// given number of worker threads.
void BottomLevelAccelerationStructureKHR::setDeferredOperation(const bool deferredOperation,
                                                               const uint32_t workerThreadCount)
{
    m_deferredOperation = deferredOperation;
    m_workerThreadCount = workerThreadCount;
}
1134 
// Chooses ppGeometries (array of pointers) over pGeometries in the build info
// (see create()).
void BottomLevelAccelerationStructureKHR::setUseArrayOfPointers(const bool useArrayOfPointers)
{
    m_useArrayOfPointers = useArrayOfPointers;
}
1139 
// When set, buffer usage is supplied via VkBufferUsageFlags2CreateInfoKHR
// (VK_KHR_maintenance5 path) instead of the legacy usage field (see create()).
void BottomLevelAccelerationStructureKHR::setUseMaintenance5(const bool useMaintenance5)
{
    m_useMaintenance5 = useMaintenance5;
}
1144 
// Configures the buffer/offset/stride for indirect builds
// (vkCmdBuildAccelerationStructuresIndirectKHR). A VK_NULL_HANDLE buffer
// means direct builds are used.
void BottomLevelAccelerationStructureKHR::setIndirectBuildParameters(const VkBuffer indirectBuffer,
                                                                     const VkDeviceSize indirectBufferOffset,
                                                                     const uint32_t indirectBufferStride)
{
    m_indirectBuffer       = indirectBuffer;
    m_indirectBufferOffset = indirectBufferOffset;
    m_indirectBufferStride = indirectBufferStride;
}
1153 
// Currently configured build flags.
VkBuildAccelerationStructureFlagsKHR BottomLevelAccelerationStructureKHR::getBuildFlags() const
{
    return m_buildFlags;
}
1158 
create(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,VkDeviceSize structureSize,VkDeviceAddress deviceAddress,const void * pNext,const MemoryRequirement & addMemoryRequirement,const VkBuffer creationBuffer,const VkDeviceSize creationBufferSize)1159 void BottomLevelAccelerationStructureKHR::create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
1160                                                  VkDeviceSize structureSize, VkDeviceAddress deviceAddress,
1161                                                  const void *pNext, const MemoryRequirement &addMemoryRequirement,
1162                                                  const VkBuffer creationBuffer, const VkDeviceSize creationBufferSize)
1163 {
1164     // AS may be built from geometries using vkCmdBuildAccelerationStructuresKHR / vkBuildAccelerationStructuresKHR
1165     // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
1166     DE_ASSERT(!m_geometriesData.empty() != !(structureSize == 0)); // logical xor
1167 
1168     if (structureSize == 0)
1169     {
1170         std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
1171         std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
1172         std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
1173         std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
1174         std::vector<uint32_t> maxPrimitiveCounts;
1175         prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
1176                           accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
1177                           maxPrimitiveCounts);
1178 
1179         const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
1180             accelerationStructureGeometriesKHR.data();
1181         const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
1182             accelerationStructureGeometriesKHRPointers.data();
1183 
1184         const uint32_t geometryCount =
1185             (m_buildWithoutGeometries ? 0u : static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()));
1186         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
1187             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
1188             nullptr,                                                          //  const void* pNext;
1189             VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
1190             m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
1191             VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
1192             VK_NULL_HANDLE,                                 //  VkAccelerationStructureKHR srcAccelerationStructure;
1193             VK_NULL_HANDLE,                                 //  VkAccelerationStructureKHR dstAccelerationStructure;
1194             geometryCount,                                  //  uint32_t geometryCount;
1195             m_useArrayOfPointers ?
1196                 nullptr :
1197                 accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
1198             m_useArrayOfPointers ? accelerationStructureGeometry :
1199                                    nullptr,     //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
1200             makeDeviceOrHostAddressKHR(nullptr) //  VkDeviceOrHostAddressKHR scratchData;
1201         };
1202         VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
1203             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
1204             nullptr,                                                       //  const void* pNext;
1205             0,                                                             //  VkDeviceSize accelerationStructureSize;
1206             0,                                                             //  VkDeviceSize updateScratchSize;
1207             0                                                              //  VkDeviceSize buildScratchSize;
1208         };
1209 
1210         vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
1211                                                  maxPrimitiveCounts.data(), &sizeInfo);
1212 
1213         m_structureSize     = sizeInfo.accelerationStructureSize;
1214         m_updateScratchSize = sizeInfo.updateScratchSize;
1215         m_buildScratchSize  = sizeInfo.buildScratchSize;
1216     }
1217     else
1218     {
1219         m_structureSize     = structureSize;
1220         m_updateScratchSize = 0u;
1221         m_buildScratchSize  = 0u;
1222     }
1223 
1224     const bool externalCreationBuffer = (creationBuffer != VK_NULL_HANDLE);
1225 
1226     if (externalCreationBuffer)
1227     {
1228         DE_UNREF(creationBufferSize); // For release builds.
1229         DE_ASSERT(creationBufferSize >= m_structureSize);
1230     }
1231 
1232     if (!externalCreationBuffer)
1233     {
1234         VkBufferCreateInfo bufferCreateInfo =
1235             makeBufferCreateInfo(m_structureSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
1236                                                       VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1237         VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2 = vk::initVulkanStructure();
1238 
1239         if (m_useMaintenance5)
1240         {
1241             bufferUsageFlags2.usage = VK_BUFFER_USAGE_2_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
1242                                       VK_BUFFER_USAGE_2_SHADER_DEVICE_ADDRESS_BIT_KHR;
1243             bufferCreateInfo.pNext = &bufferUsageFlags2;
1244             bufferCreateInfo.usage = 0;
1245         }
1246 
1247         const MemoryRequirement memoryRequirement = addMemoryRequirement | MemoryRequirement::HostVisible |
1248                                                     MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
1249         const bool bindMemOnCreation = (!m_creationBufferUnbounded);
1250 
1251         try
1252         {
1253             m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
1254                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
1255                                      (MemoryRequirement::Cached | memoryRequirement), bindMemOnCreation));
1256         }
1257         catch (const tcu::NotSupportedError &)
1258         {
1259             // retry without Cached flag
1260             m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
1261                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement, bindMemOnCreation));
1262         }
1263     }
1264 
1265     const auto createInfoBuffer = (externalCreationBuffer ? creationBuffer : getAccelerationStructureBuffer()->get());
1266     const auto createInfoOffset =
1267         (externalCreationBuffer ? static_cast<VkDeviceSize>(0) : getAccelerationStructureBufferOffset());
1268     {
1269         const VkAccelerationStructureTypeKHR structureType =
1270             (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
1271                                VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);
1272         const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR{
1273             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
1274             pNext,                                                    //  const void* pNext;
1275             m_createFlags,    //  VkAccelerationStructureCreateFlagsKHR createFlags;
1276             createInfoBuffer, //  VkBuffer buffer;
1277             createInfoOffset, //  VkDeviceSize offset;
1278             m_structureSize,  //  VkDeviceSize size;
1279             structureType,    //  VkAccelerationStructureTypeKHR type;
1280             deviceAddress     //  VkDeviceAddress deviceAddress;
1281         };
1282 
1283         m_accelerationStructureKHR =
1284             createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, nullptr);
1285 
1286         // Make sure buffer memory is always bound after creation.
1287         if (!externalCreationBuffer)
1288             m_accelerationStructureBuffer->bindMemory();
1289     }
1290 
1291     if (m_buildScratchSize > 0u || m_updateScratchSize > 0u)
1292     {
1293         VkDeviceSize scratch_size = de::max(m_buildScratchSize, m_updateScratchSize);
1294         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1295         {
1296             const VkBufferCreateInfo bufferCreateInfo = makeBufferCreateInfo(
1297                 scratch_size, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1298             m_deviceScratchBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1299                 vk, device, allocator, bufferCreateInfo,
1300                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
1301         }
1302         else
1303         {
1304             m_hostScratchBuffer->resize(static_cast<size_t>(scratch_size));
1305         }
1306     }
1307 
1308     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR && !m_geometriesData.empty())
1309     {
1310         VkBufferCreateInfo bufferCreateInfo =
1311             makeBufferCreateInfo(getVertexBufferSize(m_geometriesData),
1312                                  VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
1313                                      VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1314         VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2 = vk::initVulkanStructure();
1315 
1316         if (m_useMaintenance5)
1317         {
1318             bufferUsageFlags2.usage = vk::VK_BUFFER_USAGE_2_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
1319                                       VK_BUFFER_USAGE_2_SHADER_DEVICE_ADDRESS_BIT_KHR;
1320             bufferCreateInfo.pNext = &bufferUsageFlags2;
1321             bufferCreateInfo.usage = 0;
1322         }
1323 
1324         const vk::MemoryRequirement memoryRequirement =
1325             MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
1326         m_vertexBuffer = de::MovePtr<BufferWithMemory>(
1327             new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement));
1328 
1329         bufferCreateInfo.size = getIndexBufferSize(m_geometriesData);
1330         if (bufferCreateInfo.size)
1331             m_indexBuffer = de::MovePtr<BufferWithMemory>(
1332                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement));
1333         else
1334             m_indexBuffer = de::MovePtr<BufferWithMemory>(nullptr);
1335     }
1336 }
1337 
// Records (device builds) or performs (host builds) the build of this BLAS from
// m_geometriesData. Must be called after create(). When srcAccelerationStructure is
// non-null the build runs in UPDATE (refit) mode with it as the source structure;
// otherwise a full BUILD is done. For device builds, a memory barrier up to
// barrierDstStages is recorded so later commands observe the finished structure.
void BottomLevelAccelerationStructureKHR::build(const DeviceInterface &vk, const VkDevice device,
                                                const VkCommandBuffer cmdBuffer,
                                                BottomLevelAccelerationStructure *srcAccelerationStructure,
                                                VkPipelineStageFlags barrierDstStages)
{
    DE_ASSERT(!m_geometriesData.empty());
    DE_ASSERT(m_accelerationStructureKHR.get() != VK_NULL_HANDLE);
    DE_ASSERT(m_buildScratchSize != 0);

    // For device builds, first upload the geometry data into the device-visible
    // vertex/index buffers that were allocated in create().
    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        updateVertexBuffer(vk, device, m_geometriesData, getVertexBuffer(), getVertexBufferOffset());
        if (getIndexBuffer() != VK_NULL_HANDLE)
            updateIndexBuffer(vk, device, m_geometriesData, getIndexBuffer(), getIndexBufferOffset());
    }

    {
        std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
        std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
        std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
        std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
        std::vector<uint32_t> maxPrimitiveCounts;

        // Translate m_geometriesData into the Vulkan geometry/build-range structures.
        prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
                          accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
                          maxPrimitiveCounts, getVertexBufferOffset(), getIndexBufferOffset());

        const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
            accelerationStructureGeometriesKHR.data();
        const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
            accelerationStructureGeometriesKHRPointers.data();
        // Device builds address the scratch buffer by device address; host builds use host memory.
        VkDeviceOrHostAddressKHR scratchData =
            (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
                makeDeviceOrHostAddressKHR(vk, device, getDeviceScratchBuffer()->get(),
                                           getDeviceScratchBufferOffset()) :
                makeDeviceOrHostAddressKHR(getHostScratchBuffer()->data());
        // Some tests deliberately submit a build with geometryCount == 0.
        const uint32_t geometryCount =
            (m_buildWithoutGeometries ? 0u : static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()));

        // UPDATE (refit) when a source structure is supplied, full BUILD otherwise.
        VkAccelerationStructureKHR srcStructure =
            (srcAccelerationStructure != nullptr) ? *(srcAccelerationStructure->getPtr()) : VK_NULL_HANDLE;
        VkBuildAccelerationStructureModeKHR mode = (srcAccelerationStructure != nullptr) ?
                                                       VK_BUILD_ACCELERATION_STRUCTURE_MODE_UPDATE_KHR :
                                                       VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR;

        // Exactly one of pGeometries/ppGeometries may be non-null, chosen by m_useArrayOfPointers.
        VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
            nullptr,                                                          //  const void* pNext;
            VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
            m_buildFlags,                     //  VkBuildAccelerationStructureFlagsKHR flags;
            mode,                             //  VkBuildAccelerationStructureModeKHR mode;
            srcStructure,                     //  VkAccelerationStructureKHR srcAccelerationStructure;
            m_accelerationStructureKHR.get(), //  VkAccelerationStructureKHR dstAccelerationStructure;
            geometryCount,                    //  uint32_t geometryCount;
            m_useArrayOfPointers ?
                nullptr :
                accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
            m_useArrayOfPointers ? accelerationStructureGeometry :
                                   nullptr, //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
            scratchData                     //  VkDeviceOrHostAddressKHR scratchData;
        };

        VkAccelerationStructureBuildRangeInfoKHR *accelerationStructureBuildRangeInfoKHRPtr =
            accelerationStructureBuildRangeInfoKHR.data();

        if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
        {
            // Device path: direct or indirect build command, executed at submit time.
            if (m_indirectBuffer == VK_NULL_HANDLE)
                vk.cmdBuildAccelerationStructuresKHR(
                    cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
                    (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);
            else
            {
                // Indirect build: range infos are read from m_indirectBuffer on the device.
                VkDeviceAddress indirectDeviceAddress =
                    getBufferDeviceAddress(vk, device, m_indirectBuffer, m_indirectBufferOffset);
                uint32_t *pMaxPrimitiveCounts = maxPrimitiveCounts.data();
                vk.cmdBuildAccelerationStructuresIndirectKHR(cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
                                                             &indirectDeviceAddress, &m_indirectBufferStride,
                                                             &pMaxPrimitiveCounts);
            }
        }
        else if (!m_deferredOperation)
        {
            // Immediate host build.
            VK_CHECK(vk.buildAccelerationStructuresKHR(
                device, VK_NULL_HANDLE, 1u, &accelerationStructureBuildGeometryInfoKHR,
                (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr));
        }
        else
        {
            // Host build via a deferred operation, optionally joined by worker threads.
            const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
            const auto deferredOperation    = deferredOperationPtr.get();

            VkResult result = vk.buildAccelerationStructuresKHR(
                device, deferredOperation, 1u, &accelerationStructureBuildGeometryInfoKHR,
                (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);

            DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
                      result == VK_SUCCESS);

            finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
                                    result == VK_OPERATION_NOT_DEFERRED_KHR);
        }
    }

    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        // Make the freshly built structure visible to the requested destination stages.
        const VkAccessFlags accessMasks =
            VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
        const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);

        cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
                                 barrierDstStages, &memBarrier);
    }
}
1452 
copyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,BottomLevelAccelerationStructure * accelerationStructure,bool compactCopy)1453 void BottomLevelAccelerationStructureKHR::copyFrom(const DeviceInterface &vk, const VkDevice device,
1454                                                    const VkCommandBuffer cmdBuffer,
1455                                                    BottomLevelAccelerationStructure *accelerationStructure,
1456                                                    bool compactCopy)
1457 {
1458     DE_ASSERT(m_accelerationStructureKHR.get() != VK_NULL_HANDLE);
1459     DE_ASSERT(accelerationStructure != nullptr);
1460 
1461     VkCopyAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
1462         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
1463         nullptr,                                                // const void* pNext;
1464         *(accelerationStructure->getPtr()),                     // VkAccelerationStructureKHR src;
1465         *(getPtr()),                                            // VkAccelerationStructureKHR dst;
1466         compactCopy ? VK_COPY_ACCELERATION_STRUCTURE_MODE_COMPACT_KHR :
1467                       VK_COPY_ACCELERATION_STRUCTURE_MODE_CLONE_KHR // VkCopyAccelerationStructureModeKHR mode;
1468     };
1469 
1470     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1471     {
1472         vk.cmdCopyAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
1473     }
1474     else if (!m_deferredOperation)
1475     {
1476         VK_CHECK(vk.copyAccelerationStructureKHR(device, VK_NULL_HANDLE, &copyAccelerationStructureInfo));
1477     }
1478     else
1479     {
1480         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1481         const auto deferredOperation    = deferredOperationPtr.get();
1482 
1483         VkResult result = vk.copyAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1484 
1485         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1486                   result == VK_SUCCESS);
1487 
1488         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1489                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1490     }
1491 
1492     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1493     {
1494         const VkAccessFlags accessMasks =
1495             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
1496         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
1497 
1498         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
1499                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
1500     }
1501 }
1502 
serialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)1503 void BottomLevelAccelerationStructureKHR::serialize(const DeviceInterface &vk, const VkDevice device,
1504                                                     const VkCommandBuffer cmdBuffer, SerialStorage *storage)
1505 {
1506     DE_ASSERT(m_accelerationStructureKHR.get() != VK_NULL_HANDLE);
1507     DE_ASSERT(storage != nullptr);
1508 
1509     const VkCopyAccelerationStructureToMemoryInfoKHR copyAccelerationStructureInfo = {
1510         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_TO_MEMORY_INFO_KHR, // VkStructureType sType;
1511         nullptr,                                                          // const void* pNext;
1512         *(getPtr()),                                                      // VkAccelerationStructureKHR src;
1513         storage->getAddress(vk, device, m_buildType),                     // VkDeviceOrHostAddressKHR dst;
1514         VK_COPY_ACCELERATION_STRUCTURE_MODE_SERIALIZE_KHR                 // VkCopyAccelerationStructureModeKHR mode;
1515     };
1516 
1517     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1518     {
1519         vk.cmdCopyAccelerationStructureToMemoryKHR(cmdBuffer, &copyAccelerationStructureInfo);
1520     }
1521     else if (!m_deferredOperation)
1522     {
1523         VK_CHECK(vk.copyAccelerationStructureToMemoryKHR(device, VK_NULL_HANDLE, &copyAccelerationStructureInfo));
1524     }
1525     else
1526     {
1527         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1528         const auto deferredOperation    = deferredOperationPtr.get();
1529 
1530         const VkResult result =
1531             vk.copyAccelerationStructureToMemoryKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1532 
1533         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1534                   result == VK_SUCCESS);
1535 
1536         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1537                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1538     }
1539 }
1540 
deserialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)1541 void BottomLevelAccelerationStructureKHR::deserialize(const DeviceInterface &vk, const VkDevice device,
1542                                                       const VkCommandBuffer cmdBuffer, SerialStorage *storage)
1543 {
1544     DE_ASSERT(m_accelerationStructureKHR.get() != VK_NULL_HANDLE);
1545     DE_ASSERT(storage != nullptr);
1546 
1547     const VkCopyMemoryToAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
1548         VK_STRUCTURE_TYPE_COPY_MEMORY_TO_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
1549         nullptr,                                                          // const void* pNext;
1550         storage->getAddressConst(vk, device, m_buildType),                // VkDeviceOrHostAddressConstKHR src;
1551         *(getPtr()),                                                      // VkAccelerationStructureKHR dst;
1552         VK_COPY_ACCELERATION_STRUCTURE_MODE_DESERIALIZE_KHR               // VkCopyAccelerationStructureModeKHR mode;
1553     };
1554 
1555     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1556     {
1557         vk.cmdCopyMemoryToAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
1558     }
1559     else if (!m_deferredOperation)
1560     {
1561         VK_CHECK(vk.copyMemoryToAccelerationStructureKHR(device, VK_NULL_HANDLE, &copyAccelerationStructureInfo));
1562     }
1563     else
1564     {
1565         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1566         const auto deferredOperation    = deferredOperationPtr.get();
1567 
1568         const VkResult result =
1569             vk.copyMemoryToAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1570 
1571         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1572                   result == VK_SUCCESS);
1573 
1574         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1575                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1576     }
1577 
1578     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1579     {
1580         const VkAccessFlags accessMasks =
1581             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
1582         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
1583 
1584         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
1585                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
1586     }
1587 }
1588 
getPtr(void) const1589 const VkAccelerationStructureKHR *BottomLevelAccelerationStructureKHR::getPtr(void) const
1590 {
1591     return &m_accelerationStructureKHR.get();
1592 }
1593 
// Fills the output vectors with one VkAccelerationStructureGeometryKHR (plus matching
// pointer, build-range info and max primitive count) per entry of m_geometriesData.
// vertexBufferOffset / indexBufferOffset are the starting offsets into the shared
// vertex/index buffers for device builds and are advanced per geometry inside the loop.
// Note: accelerationStructureGeometryMicromapsEXT is only resized here, not populated;
// opacity micromap data is chained into the triangles struct via pNext instead.
void BottomLevelAccelerationStructureKHR::prepareGeometries(
    const DeviceInterface &vk, const VkDevice device,
    std::vector<VkAccelerationStructureGeometryKHR> &accelerationStructureGeometriesKHR,
    std::vector<VkAccelerationStructureGeometryKHR *> &accelerationStructureGeometriesKHRPointers,
    std::vector<VkAccelerationStructureBuildRangeInfoKHR> &accelerationStructureBuildRangeInfoKHR,
    std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> &accelerationStructureGeometryMicromapsEXT,
    std::vector<uint32_t> &maxPrimitiveCounts, VkDeviceSize vertexBufferOffset, VkDeviceSize indexBufferOffset) const
{
    accelerationStructureGeometriesKHR.resize(m_geometriesData.size());
    accelerationStructureGeometriesKHRPointers.resize(m_geometriesData.size());
    accelerationStructureBuildRangeInfoKHR.resize(m_geometriesData.size());
    accelerationStructureGeometryMicromapsEXT.resize(m_geometriesData.size());
    maxPrimitiveCounts.resize(m_geometriesData.size());

    for (size_t geometryNdx = 0; geometryNdx < m_geometriesData.size(); ++geometryNdx)
    {
        const de::SharedPtr<RaytracedGeometryBase> &geometryData = m_geometriesData[geometryNdx];
        VkDeviceOrHostAddressConstKHR vertexData, indexData;
        if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
        {
            // Device build: geometries address the shared vertex/index buffers by device address.
            if (getVertexBuffer() != nullptr)
            {
                vertexData = makeDeviceOrHostAddressConstKHR(vk, device, getVertexBuffer()->get(), vertexBufferOffset);
                // For indirect builds the offset is not advanced here.
                if (m_indirectBuffer == VK_NULL_HANDLE)
                {
                    vertexBufferOffset += deAlignSize(geometryData->getVertexByteSize(), 8);
                }
            }
            else
                vertexData = makeDeviceOrHostAddressConstKHR(nullptr);

            if (getIndexBuffer() != nullptr && geometryData->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
            {
                indexData = makeDeviceOrHostAddressConstKHR(vk, device, getIndexBuffer()->get(), indexBufferOffset);
                indexBufferOffset += deAlignSize(geometryData->getIndexByteSize(), 8);
            }
            else
                indexData = makeDeviceOrHostAddressConstKHR(nullptr);
        }
        else
        {
            // Host build: geometries point straight at the host-side geometry arrays.
            vertexData = makeDeviceOrHostAddressConstKHR(geometryData->getVertexPointer());
            if (geometryData->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
                indexData = makeDeviceOrHostAddressConstKHR(geometryData->getIndexPointer());
            else
                indexData = makeDeviceOrHostAddressConstKHR(nullptr);
        }

        VkAccelerationStructureGeometryTrianglesDataKHR accelerationStructureGeometryTrianglesDataKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_TRIANGLES_DATA_KHR, //  VkStructureType sType;
            nullptr,                                                              //  const void* pNext;
            geometryData->getVertexFormat(),                                      //  VkFormat vertexFormat;
            vertexData,                                            //  VkDeviceOrHostAddressConstKHR vertexData;
            geometryData->getVertexStride(),                       //  VkDeviceSize vertexStride;
            static_cast<uint32_t>(geometryData->getVertexCount()), //  uint32_t maxVertex;
            geometryData->getIndexType(),                          //  VkIndexType indexType;
            indexData,                                             //  VkDeviceOrHostAddressConstKHR indexData;
            makeDeviceOrHostAddressConstKHR(nullptr),              //  VkDeviceOrHostAddressConstKHR transformData;
        };

        // Chain the optional opacity micromap extension struct onto the triangles data.
        if (geometryData->getHasOpacityMicromap())
            accelerationStructureGeometryTrianglesDataKHR.pNext = &geometryData->getOpacityMicromap();

        const VkAccelerationStructureGeometryAabbsDataKHR accelerationStructureGeometryAabbsDataKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_AABBS_DATA_KHR, //  VkStructureType sType;
            nullptr,                                                          //  const void* pNext;
            vertexData,                                                       //  VkDeviceOrHostAddressConstKHR data;
            geometryData->getAABBStride()                                     //  VkDeviceSize stride;
        };
        // Pick the triangles or AABBs member of the union depending on the geometry type.
        const VkAccelerationStructureGeometryDataKHR geometry =
            (geometryData->isTrianglesType()) ?
                makeVkAccelerationStructureGeometryDataKHR(accelerationStructureGeometryTrianglesDataKHR) :
                makeVkAccelerationStructureGeometryDataKHR(accelerationStructureGeometryAabbsDataKHR);
        const VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR, //  VkStructureType sType;
            nullptr,                                               //  const void* pNext;
            geometryData->getGeometryType(),                       //  VkGeometryTypeKHR geometryType;
            geometry,                                              //  VkAccelerationStructureGeometryDataKHR geometry;
            geometryData->getGeometryFlags()                       //  VkGeometryFlagsKHR flags;
        };

        // Some tests build with zero primitives while still supplying the geometry layout.
        const uint32_t primitiveCount = (m_buildWithoutPrimitives ? 0u : geometryData->getPrimitiveCount());

        const VkAccelerationStructureBuildRangeInfoKHR accelerationStructureBuildRangeInfosKHR = {
            primitiveCount, //  uint32_t primitiveCount;
            0,              //  uint32_t primitiveOffset;
            0,              //  uint32_t firstVertex;
            0               //  uint32_t firstTransform;
        };

        accelerationStructureGeometriesKHR[geometryNdx]         = accelerationStructureGeometryKHR;
        accelerationStructureGeometriesKHRPointers[geometryNdx] = &accelerationStructureGeometriesKHR[geometryNdx];
        accelerationStructureBuildRangeInfoKHR[geometryNdx]     = accelerationStructureBuildRangeInfosKHR;
        // maxPrimitiveCounts always carries the real primitive count, even when
        // the range info above was forced to zero primitives.
        maxPrimitiveCounts[geometryNdx]                         = geometryData->getPrimitiveCount();
    }
}
1690 
getRequiredAllocationCount(void)1691 uint32_t BottomLevelAccelerationStructure::getRequiredAllocationCount(void)
1692 {
1693     return BottomLevelAccelerationStructureKHR::getRequiredAllocationCount();
1694 }
1695 
createAndBuild(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,VkDeviceAddress deviceAddress)1696 void BottomLevelAccelerationStructure::createAndBuild(const DeviceInterface &vk, const VkDevice device,
1697                                                       const VkCommandBuffer cmdBuffer, Allocator &allocator,
1698                                                       VkDeviceAddress deviceAddress)
1699 {
1700     create(vk, device, allocator, 0u, deviceAddress);
1701     build(vk, device, cmdBuffer);
1702 }
1703 
createAndCopyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,BottomLevelAccelerationStructure * accelerationStructure,VkDeviceSize compactCopySize,VkDeviceAddress deviceAddress)1704 void BottomLevelAccelerationStructure::createAndCopyFrom(const DeviceInterface &vk, const VkDevice device,
1705                                                          const VkCommandBuffer cmdBuffer, Allocator &allocator,
1706                                                          BottomLevelAccelerationStructure *accelerationStructure,
1707                                                          VkDeviceSize compactCopySize, VkDeviceAddress deviceAddress)
1708 {
1709     DE_ASSERT(accelerationStructure != NULL);
1710     VkDeviceSize copiedSize = compactCopySize > 0u ?
1711                                   compactCopySize :
1712                                   accelerationStructure->getStructureBuildSizes().accelerationStructureSize;
1713     DE_ASSERT(copiedSize != 0u);
1714 
1715     create(vk, device, allocator, copiedSize, deviceAddress);
1716     copyFrom(vk, device, cmdBuffer, accelerationStructure, compactCopySize > 0u);
1717 }
1718 
createAndDeserializeFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,SerialStorage * storage,VkDeviceAddress deviceAddress)1719 void BottomLevelAccelerationStructure::createAndDeserializeFrom(const DeviceInterface &vk, const VkDevice device,
1720                                                                 const VkCommandBuffer cmdBuffer, Allocator &allocator,
1721                                                                 SerialStorage *storage, VkDeviceAddress deviceAddress)
1722 {
1723     DE_ASSERT(storage != NULL);
1724     DE_ASSERT(storage->getStorageSize() >= SerialStorage::SERIAL_STORAGE_SIZE_MIN);
1725     create(vk, device, allocator, storage->getDeserializedSize(), deviceAddress);
1726     deserialize(vk, device, cmdBuffer, storage);
1727 }
1728 
// Replaces the geometry recorded at 'geometryIndex'; the slot must already exist.
void BottomLevelAccelerationStructureKHR::updateGeometry(size_t geometryIndex,
                                                         de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry)
{
    DE_ASSERT(geometryIndex < m_geometriesData.size());
    m_geometriesData[geometryIndex] = raytracedGeometry;
}
1735 
makeBottomLevelAccelerationStructure()1736 de::MovePtr<BottomLevelAccelerationStructure> makeBottomLevelAccelerationStructure()
1737 {
1738     return de::MovePtr<BottomLevelAccelerationStructure>(new BottomLevelAccelerationStructureKHR);
1739 }
1740 
// Forward declaration
struct BottomLevelAccelerationStructurePoolImpl;

// A bottom-level acceleration structure that does not own its buffers: AS storage,
// scratch, vertex and index data all live in buffers owned by a shared pool, and
// each member only records which pooled buffer and which byte offset it occupies.
class BottomLevelAccelerationStructurePoolMember : public BottomLevelAccelerationStructureKHR
{
public:
    friend class BottomLevelAccelerationStructurePool;

    BottomLevelAccelerationStructurePoolMember(BottomLevelAccelerationStructurePoolImpl &pool);
    BottomLevelAccelerationStructurePoolMember(const BottomLevelAccelerationStructurePoolMember &) = delete;
    BottomLevelAccelerationStructurePoolMember(BottomLevelAccelerationStructurePoolMember &&)      = delete;
    virtual ~BottomLevelAccelerationStructurePoolMember()                                          = default;

    // Pool members must be created through the pool, never via the base create().
    virtual void create(const DeviceInterface &, const VkDevice, Allocator &, VkDeviceSize, VkDeviceAddress,
                        const void *, const MemoryRequirement &, const VkBuffer, const VkDeviceSize) override
    {
        DE_ASSERT(0); // Silent this method
    }
    // Queries the sizes this structure needs; 'strSize' is used as the structure size
    // when no geometry data is attached.
    virtual auto computeBuildSize(const DeviceInterface &vk, const VkDevice device, const VkDeviceSize strSize) const
        //              accStrSize,updateScratch, buildScratch, vertexSize,   indexSize
        -> std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize>;

protected:
    struct Info;
    // Records the pooled buffer indices/offsets computed by the pool prior to creation.
    virtual void preCreateSetSizesAndOffsets(const Info &info, const VkDeviceSize accStrSize,
                                             const VkDeviceSize updateScratchSize, const VkDeviceSize buildScratchSize);
    virtual void createAccellerationStructure(const DeviceInterface &vk, const VkDevice device,
                                              VkDeviceAddress deviceAddress);

    // Buffer accessors redirect to the pool-owned buffers selected via m_info.
    virtual BufferWithMemory *getAccelerationStructureBuffer() const override;
    virtual BufferWithMemory *getDeviceScratchBuffer() const override;
    virtual std::vector<uint8_t> *getHostScratchBuffer() const override;
    virtual BufferWithMemory *getVertexBuffer() const override;
    virtual BufferWithMemory *getIndexBuffer() const override;

    virtual VkDeviceSize getAccelerationStructureBufferOffset() const override
    {
        return m_info.accStrOffset;
    }
    virtual VkDeviceSize getDeviceScratchBufferOffset() const override
    {
        return m_info.buildScratchBuffOffset;
    }
    virtual VkDeviceSize getVertexBufferOffset() const override
    {
        return m_info.vertBuffOffset;
    }
    virtual VkDeviceSize getIndexBufferOffset() const override
    {
        return m_info.indexBuffOffset;
    }

    BottomLevelAccelerationStructurePoolImpl &m_pool;

    // Placement of this member inside the pool's shared buffers: for each buffer
    // category, an index into the pool's buffer array plus a byte offset within it.
    struct Info
    {
        uint32_t accStrIndex;
        VkDeviceSize accStrOffset;
        uint32_t vertBuffIndex;
        VkDeviceSize vertBuffOffset;
        uint32_t indexBuffIndex;
        VkDeviceSize indexBuffOffset;
        uint32_t buildScratchBuffIndex; // always 0 here: the pool keeps a single scratch buffer
        VkDeviceSize buildScratchBuffOffset;
    } m_info;
};
1807 
// Returns the all-bits-set value of X; used as an "invalid / unlimited" sentinel.
template <class X>
inline X negz(const X &)
{
    const X zero{};
    return static_cast<X>(~zero);
}
// True when 'x' holds the all-bits-set sentinel produced by negz().
template <class X>
inline bool isnegz(const X &x)
{
    return negz(x) == x;
}
// Converts 'y' to the unsigned counterpart of its type.
template <class Y>
inline auto make_unsigned(const Y &y) -> typename std::make_unsigned<Y>::type
{
    using U = typename std::make_unsigned<Y>::type;
    return static_cast<U>(y);
}
1823 
// Members start with a zeroed Info block; the real buffer indices and offsets are
// assigned later via preCreateSetSizesAndOffsets() during pool batch creation.
BottomLevelAccelerationStructurePoolMember::BottomLevelAccelerationStructurePoolMember(
    BottomLevelAccelerationStructurePoolImpl &pool)
    : m_pool(pool)
    , m_info{}
{
}
1830 
// Shared storage behind a BottomLevelAccelerationStructurePool: the pooled AS,
// vertex and index buffers, plus one device-side and one host-side scratch area
// that all pool members use.
struct BottomLevelAccelerationStructurePoolImpl
{
    BottomLevelAccelerationStructurePoolImpl(BottomLevelAccelerationStructurePoolImpl &&)      = delete;
    BottomLevelAccelerationStructurePoolImpl(const BottomLevelAccelerationStructurePoolImpl &) = delete;
    BottomLevelAccelerationStructurePoolImpl(BottomLevelAccelerationStructurePool &pool);

    BottomLevelAccelerationStructurePool &m_pool;
    std::vector<de::SharedPtr<BufferWithMemory>> m_accellerationStructureBuffers;
    de::SharedPtr<BufferWithMemory> m_deviceScratchBuffer;  // single scratch buffer shared by device builds
    de::UniquePtr<std::vector<uint8_t>> m_hostScratchBuffer; // scratch memory for host builds
    std::vector<de::SharedPtr<BufferWithMemory>> m_vertexBuffers;
    std::vector<de::SharedPtr<BufferWithMemory>> m_indexBuffers;
};
// Buffers start empty; they are populated by batchCreateAdjust(). The host scratch
// vector is allocated up front so members can always obtain a pointer to it.
BottomLevelAccelerationStructurePoolImpl::BottomLevelAccelerationStructurePoolImpl(
    BottomLevelAccelerationStructurePool &pool)
    : m_pool(pool)
    , m_accellerationStructureBuffers()
    , m_deviceScratchBuffer()
    , m_hostScratchBuffer(new std::vector<uint8_t>)
    , m_vertexBuffers()
    , m_indexBuffers()
{
}
getAccelerationStructureBuffer() const1854 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getAccelerationStructureBuffer() const
1855 {
1856     BufferWithMemory *result = nullptr;
1857     if (m_pool.m_accellerationStructureBuffers.size())
1858     {
1859         DE_ASSERT(!isnegz(m_info.accStrIndex));
1860         result = m_pool.m_accellerationStructureBuffers[m_info.accStrIndex].get();
1861     }
1862     return result;
1863 }
getDeviceScratchBuffer() const1864 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getDeviceScratchBuffer() const
1865 {
1866     DE_ASSERT(m_info.buildScratchBuffIndex == 0);
1867     return m_pool.m_deviceScratchBuffer.get();
1868 }
getHostScratchBuffer() const1869 std::vector<uint8_t> *BottomLevelAccelerationStructurePoolMember::getHostScratchBuffer() const
1870 {
1871     return this->m_buildScratchSize ? m_pool.m_hostScratchBuffer.get() : nullptr;
1872 }
1873 
getVertexBuffer() const1874 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getVertexBuffer() const
1875 {
1876     BufferWithMemory *result = nullptr;
1877     if (m_pool.m_vertexBuffers.size())
1878     {
1879         DE_ASSERT(!isnegz(m_info.vertBuffIndex));
1880         result = m_pool.m_vertexBuffers[m_info.vertBuffIndex].get();
1881     }
1882     return result;
1883 }
getIndexBuffer() const1884 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getIndexBuffer() const
1885 {
1886     BufferWithMemory *result = nullptr;
1887     if (m_pool.m_indexBuffers.size())
1888     {
1889         DE_ASSERT(!isnegz(m_info.indexBuffIndex));
1890         result = m_pool.m_indexBuffers[m_info.indexBuffIndex].get();
1891     }
1892     return result;
1893 }
1894 
// Concrete pimpl type of the pool; it only forwards construction to the shared
// BottomLevelAccelerationStructurePoolImpl base.
struct BottomLevelAccelerationStructurePool::Impl : BottomLevelAccelerationStructurePoolImpl
{
    friend class BottomLevelAccelerationStructurePool;
    friend class BottomLevelAccelerationStructurePoolMember;

    Impl(BottomLevelAccelerationStructurePool &pool) : BottomLevelAccelerationStructurePoolImpl(pool)
    {
    }
};
1904 
// The pool starts empty: the batch size defaults to 4 structures per pooled buffer,
// cached memory is attempted first for AS buffers, and the aggregate size fields
// remain zero until batchCreateAdjust() computes them.
BottomLevelAccelerationStructurePool::BottomLevelAccelerationStructurePool()
    : m_batchStructCount(4)
    , m_batchGeomCount(0)
    , m_infos()
    , m_structs()
    , m_createOnce(false)
    , m_tryCachedMemory(true)
    , m_structsBuffSize(0)
    , m_updatesScratchSize(0)
    , m_buildsScratchSize(0)
    , m_verticesSize(0)
    , m_indicesSize(0)
    , m_impl(new Impl(*this))
{
}
1920 
BottomLevelAccelerationStructurePool::~BottomLevelAccelerationStructurePool()
{
    // m_impl is a raw owning pointer (pimpl idiom); released here.
    delete m_impl;
}
1925 
batchStructCount(const uint32_t & value)1926 void BottomLevelAccelerationStructurePool::batchStructCount(const uint32_t &value)
1927 {
1928     DE_ASSERT(value >= 1);
1929     m_batchStructCount = value;
1930 }
1931 
add(VkDeviceSize structureSize,VkDeviceAddress deviceAddress)1932 auto BottomLevelAccelerationStructurePool::add(VkDeviceSize structureSize, VkDeviceAddress deviceAddress)
1933     -> BottomLevelAccelerationStructurePool::BlasPtr
1934 {
1935     // Prevent a programmer from calling this method after batchCreate(...) method has been called.
1936     if (m_createOnce)
1937         DE_ASSERT(0);
1938 
1939     auto blas = new BottomLevelAccelerationStructurePoolMember(*m_impl);
1940     m_infos.push_back({structureSize, deviceAddress});
1941     m_structs.emplace_back(blas);
1942     return m_structs.back();
1943 }
1944 
// Computes, for each of four pooled buffer categories (0: AS storage, 1: build
// scratch, 2: vertices, 3: indices), the largest number of structures that can
// share one buffer without its total size exceeding maxBufferSize; results are
// written into 'result' (each at least 1).
void adjustBatchCount(const DeviceInterface &vkd, const VkDevice device,
                      const std::vector<BottomLevelAccelerationStructurePool::BlasPtr> &structs,
                      const std::vector<BottomLevelAccelerationStructurePool::BlasInfo> &infos,
                      const VkDeviceSize maxBufferSize, uint32_t (&result)[4])
{
    tcu::Vector<VkDeviceSize, 4> sizes(0);   // size of the current structure, per category
    tcu::Vector<VkDeviceSize, 4> sums(0);    // running size of the batch being filled, per category
    tcu::Vector<uint32_t, 4> tmps(0);        // element count of the batch being filled, per category
    tcu::Vector<uint32_t, 4> batches(0);     // largest batch seen so far, per category

    VkDeviceSize updateScratchSize = 0;
    static_cast<void>(updateScratchSize); // not used yet, disabled for future implementation

    auto updateIf = [&](uint32_t c)
    {
        if (sums[c] + sizes[c] <= maxBufferSize)
        {
            sums[c] += sizes[c];
            tmps[c] += 1;

            batches[c] = std::max(tmps[c], batches[c]);
        }
        else
        {
            // NOTE(review): when the current element overflows the limit, the counters
            // reset without counting that element toward the next batch — verify this
            // matches the intended batching policy.
            sums[c] = 0;
            tmps[c] = 0;
        }
    };

    const uint32_t maxIter = static_cast<uint32_t>(structs.size());
    for (uint32_t i = 0; i < maxIter; ++i)
    {
        auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(structs[i].get());
        std::tie(sizes[0], updateScratchSize, sizes[1], sizes[2], sizes[3]) =
            str.computeBuildSize(vkd, device, infos[i].structureSize);

        updateIf(0);
        updateIf(1);
        updateIf(2);
        updateIf(3);
    }

    // Every category needs room for at least one structure per buffer.
    result[0] = std::max(batches[0], 1u);
    result[1] = std::max(batches[1], 1u);
    result[2] = std::max(batches[2], 1u);
    result[3] = std::max(batches[3], 1u);
}
1992 
getAllocationCount() const1993 size_t BottomLevelAccelerationStructurePool::getAllocationCount() const
1994 {
1995     return m_impl->m_accellerationStructureBuffers.size() + m_impl->m_vertexBuffers.size() +
1996            m_impl->m_indexBuffers.size() + 1 /* for scratch buffer */;
1997 }
1998 
// Predicts how many separate buffer allocations batchCreateAdjust() would make when
// each pooled buffer is limited to maxBufferSize bytes (pass the all-bits-set
// sentinel for "unlimited"). Structures are grouped into batches per category and
// each batch maps to one buffer; the count of distinct batch indices is returned.
size_t BottomLevelAccelerationStructurePool::getAllocationCount(const DeviceInterface &vk, const VkDevice device,
                                                                const VkDeviceSize maxBufferSize) const
{
    DE_ASSERT(m_structs.size() != 0);

    // Accumulated byte size per buffer index; the maps' sizes are the buffer counts.
    std::map<uint32_t, VkDeviceSize> accStrSizes;
    std::map<uint32_t, VkDeviceSize> vertBuffSizes;
    std::map<uint32_t, VkDeviceSize> indexBuffSizes;
    std::map<uint32_t, VkDeviceSize> scratchBuffSizes;

    const uint32_t allStructsCount = structCount();

    uint32_t batchStructCount  = m_batchStructCount;
    uint32_t batchScratchCount = m_batchStructCount;
    uint32_t batchVertexCount  = m_batchGeomCount ? m_batchGeomCount : m_batchStructCount;
    uint32_t batchIndexCount   = batchVertexCount;

    if (!isnegz(maxBufferSize))
    {
        // A real size limit was given: shrink batch counts so no buffer exceeds it.
        uint32_t batches[4];
        adjustBatchCount(vk, device, m_structs, m_infos, maxBufferSize, batches);
        batchStructCount  = batches[0];
        batchScratchCount = batches[1];
        batchVertexCount  = batches[2];
        batchIndexCount   = batches[3];
    }

    // Per-category element counters; scratch/vertex/index only advance when the
    // structure actually needs that kind of data.
    uint32_t iStr     = 0;
    uint32_t iScratch = 0;
    uint32_t iVertex  = 0;
    uint32_t iIndex   = 0;

    VkDeviceSize strSize           = 0;
    VkDeviceSize updateScratchSize = 0;
    VkDeviceSize buildScratchSize  = 0;
    VkDeviceSize vertexSize        = 0;
    VkDeviceSize indexSize         = 0;

    for (; iStr < allStructsCount; ++iStr)
    {
        auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iStr].get());
        std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
            str.computeBuildSize(vk, device, m_infos[iStr].structureSize);

        {
            // AS storage and scratch are aligned to 256 bytes; vertex/index to 8.
            const VkDeviceSize alignedStrSize = deAlign64(strSize, 256);
            const uint32_t accStrIndex        = (iStr / batchStructCount);
            accStrSizes[accStrIndex] += alignedStrSize;
        }

        if (buildScratchSize != 0)
        {
            const VkDeviceSize alignedBuilsScratchSize = deAlign64(buildScratchSize, 256);
            const uint32_t scratchBuffIndex            = (iScratch / batchScratchCount);
            scratchBuffSizes[scratchBuffIndex] += alignedBuilsScratchSize;
            iScratch += 1;
        }

        if (vertexSize != 0)
        {
            const VkDeviceSize alignedVertBuffSize = deAlign64(vertexSize, 8);
            const uint32_t vertBuffIndex           = (iVertex / batchVertexCount);
            vertBuffSizes[vertBuffIndex] += alignedVertBuffSize;
            iVertex += 1;
        }

        if (indexSize != 0)
        {
            const VkDeviceSize alignedIndexBuffSize = deAlign64(indexSize, 8);
            const uint32_t indexBuffIndex           = (iIndex / batchIndexCount);
            indexBuffSizes[indexBuffIndex] += alignedIndexBuffSize;
            iIndex += 1;
        }
    }

    return accStrSizes.size() + vertBuffSizes.size() + indexBuffSizes.size() + scratchBuffSizes.size();
}
2076 
getAllocationSizes(const DeviceInterface & vk,const VkDevice device) const2077 tcu::Vector<VkDeviceSize, 4> BottomLevelAccelerationStructurePool::getAllocationSizes(const DeviceInterface &vk,
2078                                                                                       const VkDevice device) const
2079 {
2080     if (m_structsBuffSize)
2081     {
2082         return tcu::Vector<VkDeviceSize, 4>(m_structsBuffSize, m_buildsScratchSize, m_verticesSize, m_indicesSize);
2083     }
2084 
2085     VkDeviceSize strSize           = 0;
2086     VkDeviceSize updateScratchSize = 0;
2087     static_cast<void>(updateScratchSize); // not used yet, disabled for future implementation
2088     VkDeviceSize buildScratchSize     = 0;
2089     VkDeviceSize vertexSize           = 0;
2090     VkDeviceSize indexSize            = 0;
2091     VkDeviceSize sumStrSize           = 0;
2092     VkDeviceSize sumUpdateScratchSize = 0;
2093     static_cast<void>(sumUpdateScratchSize); // not used yet, disabled for future implementation
2094     VkDeviceSize sumBuildScratchSize = 0;
2095     VkDeviceSize sumVertexSize       = 0;
2096     VkDeviceSize sumIndexSize        = 0;
2097     for (size_t i = 0; i < structCount(); ++i)
2098     {
2099         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[i].get());
2100         std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
2101             str.computeBuildSize(vk, device, m_infos[i].structureSize);
2102         sumStrSize += deAlign64(strSize, 256);
2103         //sumUpdateScratchSize    += deAlign64(updateScratchSize, 256);    not used yet, disabled for future implementation
2104         sumBuildScratchSize += deAlign64(buildScratchSize, 256);
2105         sumVertexSize += deAlign64(vertexSize, 8);
2106         sumIndexSize += deAlign64(indexSize, 8);
2107     }
2108     return tcu::Vector<VkDeviceSize, 4>(sumStrSize, sumBuildScratchSize, sumVertexSize, sumIndexSize);
2109 }
2110 
batchCreate(const DeviceInterface & vkd,const VkDevice device,Allocator & allocator)2111 void BottomLevelAccelerationStructurePool::batchCreate(const DeviceInterface &vkd, const VkDevice device,
2112                                                        Allocator &allocator)
2113 {
2114     batchCreateAdjust(vkd, device, allocator, negz<VkDeviceSize>(0));
2115 }
2116 
// Creates all pooled buffers and assigns every pool member its buffer index and
// byte offset. Structures are grouped into batches (optionally shrunk so no buffer
// exceeds maxBufferSize); each batch gets one buffer. A single scratch buffer sized
// for the largest build is shared by all members. Must be called exactly once.
void BottomLevelAccelerationStructurePool::batchCreateAdjust(const DeviceInterface &vkd, const VkDevice device,
                                                             Allocator &allocator, const VkDeviceSize maxBufferSize,
                                                             bool scratchIsHostVisible)
{
    // Prevent a programmer from calling this method more than once.
    if (m_createOnce)
        DE_ASSERT(0);

    m_createOnce = true;
    DE_ASSERT(m_structs.size() != 0);

    // AS storage buffer: prefer cached host-visible memory when requested, falling
    // back to plain host-visible coherent memory if the cached combination is
    // unsupported.
    auto createAccellerationStructureBuffer = [&](VkDeviceSize bufferSize) ->
        typename std::add_pointer<BufferWithMemory>::type
    {
        BufferWithMemory *res = nullptr;
        const VkBufferCreateInfo bci =
            makeBufferCreateInfo(bufferSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
                                                 VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);

        if (m_tryCachedMemory)
            try
            {
                res = new BufferWithMemory(vkd, device, allocator, bci,
                                           MemoryRequirement::Cached | MemoryRequirement::HostVisible |
                                               MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress);
            }
            catch (const tcu::NotSupportedError &)
            {
                res = nullptr;
            }

        return (nullptr != res) ? res :
                                  (new BufferWithMemory(vkd, device, allocator, bci,
                                                        MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
                                                            MemoryRequirement::DeviceAddress));
    };

    // Scratch buffer used for device builds; optionally host-visible.
    auto createDeviceScratchBuffer = [&](VkDeviceSize bufferSize) -> de::SharedPtr<BufferWithMemory>
    {
        const auto extraMemReqs =
            (scratchIsHostVisible ? (MemoryRequirement::HostVisible | MemoryRequirement::Coherent) :
                                    MemoryRequirement::Any);
        const auto memReqs = (MemoryRequirement::DeviceAddress | extraMemReqs);

        const VkBufferCreateInfo bci = makeBufferCreateInfo(bufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
                                                                            VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
        BufferWithMemory *p          = new BufferWithMemory(vkd, device, allocator, bci, memReqs);
        return de::SharedPtr<BufferWithMemory>(p);
    };

    // Accumulated byte size of each buffer, keyed by buffer index.
    std::map<uint32_t, VkDeviceSize> accStrSizes;
    std::map<uint32_t, VkDeviceSize> vertBuffSizes;
    std::map<uint32_t, VkDeviceSize> indexBuffSizes;

    const uint32_t allStructsCount = structCount();
    uint32_t iterKey               = 0;

    uint32_t batchStructCount = m_batchStructCount;
    uint32_t batchVertexCount = m_batchGeomCount ? m_batchGeomCount : m_batchStructCount;
    uint32_t batchIndexCount  = batchVertexCount;

    if (!isnegz(maxBufferSize))
    {
        // A real size limit was given: shrink batch counts so no buffer exceeds it.
        uint32_t batches[4];
        adjustBatchCount(vkd, device, m_structs, m_infos, maxBufferSize, batches);
        batchStructCount = batches[0];
        // batches[1]: batchScratchCount
        batchVertexCount = batches[2];
        batchIndexCount  = batches[3];
    }

    uint32_t iStr    = 0;
    uint32_t iVertex = 0;
    uint32_t iIndex  = 0;

    VkDeviceSize strSize             = 0;
    VkDeviceSize updateScratchSize   = 0;
    VkDeviceSize buildScratchSize    = 0;
    VkDeviceSize maxBuildScratchSize = 0;
    VkDeviceSize vertexSize          = 0;
    VkDeviceSize indexSize           = 0;

    // Running byte offsets within the buffer currently being filled; reset whenever
    // a new batch (and hence a new buffer) starts.
    VkDeviceSize strOffset    = 0;
    VkDeviceSize vertexOffset = 0;
    VkDeviceSize indexOffset  = 0;

    uint32_t hostStructCount   = 0;
    uint32_t deviceStructCount = 0;

    for (; iStr < allStructsCount; ++iStr)
    {
        BottomLevelAccelerationStructurePoolMember::Info info{};
        auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iStr].get());
        std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
            str.computeBuildSize(vkd, device, m_infos[iStr].structureSize);

        // Count host vs device builds to decide which scratch storage to allocate.
        ++(str.getBuildType() == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR ? hostStructCount : deviceStructCount);

        {
            const VkDeviceSize alignedStrSize = deAlign64(strSize, 256);
            const uint32_t accStrIndex        = (iStr / batchStructCount);
            if (iStr != 0 && (iStr % batchStructCount) == 0)
            {
                strOffset = 0;
            }

            info.accStrIndex  = accStrIndex;
            info.accStrOffset = strOffset;
            accStrSizes[accStrIndex] += alignedStrSize;
            strOffset += alignedStrSize;
            m_structsBuffSize += alignedStrSize;
        }

        if (buildScratchSize != 0)
        {
            // One shared scratch buffer, sized for the largest single build.
            maxBuildScratchSize = std::max(maxBuildScratchSize, make_unsigned(deAlign64(buildScratchSize, 256u)));

            info.buildScratchBuffIndex  = 0;
            info.buildScratchBuffOffset = 0;
        }

        if (vertexSize != 0)
        {
            const VkDeviceSize alignedVertBuffSize = deAlign64(vertexSize, 8);
            const uint32_t vertBuffIndex           = (iVertex / batchVertexCount);
            if (iVertex != 0 && (iVertex % batchVertexCount) == 0)
            {
                vertexOffset = 0;
            }

            info.vertBuffIndex  = vertBuffIndex;
            info.vertBuffOffset = vertexOffset;
            vertBuffSizes[vertBuffIndex] += alignedVertBuffSize;
            vertexOffset += alignedVertBuffSize;
            m_verticesSize += alignedVertBuffSize;
            iVertex += 1;
        }

        if (indexSize != 0)
        {
            const VkDeviceSize alignedIndexBuffSize = deAlign64(indexSize, 8);
            const uint32_t indexBuffIndex           = (iIndex / batchIndexCount);
            if (iIndex != 0 && (iIndex % batchIndexCount) == 0)
            {
                indexOffset = 0;
            }

            info.indexBuffIndex  = indexBuffIndex;
            info.indexBuffOffset = indexOffset;
            indexBuffSizes[indexBuffIndex] += alignedIndexBuffSize;
            indexOffset += alignedIndexBuffSize;
            m_indicesSize += alignedIndexBuffSize;
            iIndex += 1;
        }

        str.preCreateSetSizesAndOffsets(info, strSize, updateScratchSize, buildScratchSize);
    }

    // Allocate one buffer per accumulated batch size.
    for (iterKey = 0; iterKey < static_cast<uint32_t>(accStrSizes.size()); ++iterKey)
    {
        m_impl->m_accellerationStructureBuffers.emplace_back(
            createAccellerationStructureBuffer(accStrSizes.at(iterKey)));
    }
    for (iterKey = 0; iterKey < static_cast<uint32_t>(vertBuffSizes.size()); ++iterKey)
    {
        m_impl->m_vertexBuffers.emplace_back(createVertexBuffer(vkd, device, allocator, vertBuffSizes.at(iterKey)));
    }
    for (iterKey = 0; iterKey < static_cast<uint32_t>(indexBuffSizes.size()); ++iterKey)
    {
        m_impl->m_indexBuffers.emplace_back(createIndexBuffer(vkd, device, allocator, indexBuffSizes.at(iterKey)));
    }

    if (maxBuildScratchSize)
    {
        // Host and device builds use separate scratch storage; allocate only what is needed.
        if (hostStructCount)
            m_impl->m_hostScratchBuffer->resize(static_cast<size_t>(maxBuildScratchSize));
        if (deviceStructCount)
            m_impl->m_deviceScratchBuffer = createDeviceScratchBuffer(maxBuildScratchSize);

        m_buildsScratchSize = maxBuildScratchSize;
    }

    // Finally create the Vulkan acceleration-structure objects on top of the buffers.
    for (iterKey = 0; iterKey < allStructsCount; ++iterKey)
    {
        auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iterKey].get());
        str.createAccellerationStructure(vkd, device, m_infos[iterKey].deviceAddress);
    }
}
2305 
batchBuild(const DeviceInterface & vk,const VkDevice device,VkCommandBuffer cmdBuffer)2306 void BottomLevelAccelerationStructurePool::batchBuild(const DeviceInterface &vk, const VkDevice device,
2307                                                       VkCommandBuffer cmdBuffer)
2308 {
2309     for (size_t i = 0u; i < m_structs.size(); ++i)
2310     {
2311         const bool last = (i == m_structs.size() - 1u);
2312         const VkPipelineStageFlags barrierDst =
2313             (last ? VK_PIPELINE_STAGE_ALL_COMMANDS_BIT : VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR);
2314         m_structs.at(i)->build(vk, device, cmdBuffer, nullptr, barrierDst);
2315     }
2316 }
2317 
// Builds every structure in the pool without an externally supplied command buffer:
// host-built structures are built directly, device-built ones are collected (up to
// 'limit' per submission) into internally allocated command buffers that are
// submitted and waited on. 'watchDog', when given, is touched periodically so long
// runs are not flagged as hung.
void BottomLevelAccelerationStructurePool::batchBuild(const DeviceInterface &vk, const VkDevice device,
                                                      VkCommandPool cmdPool, VkQueue queue, qpWatchDog *watchDog)
{
    const uint32_t limit = 10000u; // max device builds per submission
    const uint32_t count = structCount();
    std::vector<BlasPtr> buildingOnDevice;

    // Records builds for the currently collected structures, submits and waits,
    // then recycles the command pool.
    auto buildOnDevice = [&]() -> void
    {
        Move<VkCommandBuffer> cmd = allocateCommandBuffer(vk, device, cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);

        beginCommandBuffer(vk, *cmd, 0u);
        for (const auto &str : buildingOnDevice)
            str->build(vk, device, *cmd);
        endCommandBuffer(vk, *cmd);

        submitCommandsAndWait(vk, device, queue, *cmd);
        vk.resetCommandPool(device, cmdPool, VK_COMMAND_POOL_RESET_RELEASE_RESOURCES_BIT);
    };

    buildingOnDevice.reserve(limit);
    for (uint32_t i = 0; i < count; ++i)
    {
        auto str = m_structs[i];

        if (str->getBuildType() == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR)
            str->build(vk, device, nullptr);
        else
            buildingOnDevice.emplace_back(str);

        // Flush when the batch is full or this was the final structure.
        // NOTE(review): on the final iteration this submits even when the batch is
        // empty (all-host pools), producing an empty submission; harmless but wasteful.
        if (buildingOnDevice.size() == limit || (count - 1) == i)
        {
            buildOnDevice();
            buildingOnDevice.clear();
        }

        if ((i % WATCHDOG_INTERVAL) == 0 && watchDog)
            qpWatchDog_touch(watchDog);
    }
}
2358 
computeBuildSize(const DeviceInterface & vk,const VkDevice device,const VkDeviceSize strSize) const2359 auto BottomLevelAccelerationStructurePoolMember::computeBuildSize(const DeviceInterface &vk, const VkDevice device,
2360                                                                   const VkDeviceSize strSize) const
2361     //              accStrSize,updateScratch,buildScratch, vertexSize, indexSize
2362     -> std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize>
2363 {
2364     DE_ASSERT(!m_geometriesData.empty() != !(strSize == 0)); // logical xor
2365 
2366     std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize> result(deAlign64(strSize, 256), 0,
2367                                                                                             0, 0, 0);
2368 
2369     if (!m_geometriesData.empty())
2370     {
2371         std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
2372         std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
2373         std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
2374         std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
2375         std::vector<uint32_t> maxPrimitiveCounts;
2376         prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
2377                           accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
2378                           maxPrimitiveCounts);
2379 
2380         const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
2381             accelerationStructureGeometriesKHR.data();
2382         const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
2383             accelerationStructureGeometriesKHRPointers.data();
2384 
2385         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
2386             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
2387             nullptr,                                                          //  const void* pNext;
2388             VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
2389             m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
2390             VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
2391             VK_NULL_HANDLE,                                 //  VkAccelerationStructureKHR srcAccelerationStructure;
2392             VK_NULL_HANDLE,                                 //  VkAccelerationStructureKHR dstAccelerationStructure;
2393             static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()), //  uint32_t geometryCount;
2394             m_useArrayOfPointers ?
2395                 nullptr :
2396                 accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
2397             m_useArrayOfPointers ? accelerationStructureGeometry :
2398                                    nullptr,     //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
2399             makeDeviceOrHostAddressKHR(nullptr) //  VkDeviceOrHostAddressKHR scratchData;
2400         };
2401 
2402         VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
2403             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
2404             nullptr,                                                       //  const void* pNext;
2405             0,                                                             //  VkDeviceSize accelerationStructureSize;
2406             0,                                                             //  VkDeviceSize updateScratchSize;
2407             0                                                              //  VkDeviceSize buildScratchSize;
2408         };
2409 
2410         vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
2411                                                  maxPrimitiveCounts.data(), &sizeInfo);
2412 
2413         std::get<0>(result) = sizeInfo.accelerationStructureSize;
2414         std::get<1>(result) = sizeInfo.updateScratchSize;
2415         std::get<2>(result) = sizeInfo.buildScratchSize;
2416         std::get<3>(result) = getVertexBufferSize(m_geometriesData);
2417         std::get<4>(result) = getIndexBufferSize(m_geometriesData);
2418     }
2419 
2420     return result;
2421 }
2422 
preCreateSetSizesAndOffsets(const Info & info,const VkDeviceSize accStrSize,const VkDeviceSize updateScratchSize,const VkDeviceSize buildScratchSize)2423 void BottomLevelAccelerationStructurePoolMember::preCreateSetSizesAndOffsets(const Info &info,
2424                                                                              const VkDeviceSize accStrSize,
2425                                                                              const VkDeviceSize updateScratchSize,
2426                                                                              const VkDeviceSize buildScratchSize)
2427 {
2428     m_info              = info;
2429     m_structureSize     = accStrSize;
2430     m_updateScratchSize = updateScratchSize;
2431     m_buildScratchSize  = buildScratchSize;
2432 }
2433 
createAccellerationStructure(const DeviceInterface & vk,const VkDevice device,VkDeviceAddress deviceAddress)2434 void BottomLevelAccelerationStructurePoolMember::createAccellerationStructure(const DeviceInterface &vk,
2435                                                                               const VkDevice device,
2436                                                                               VkDeviceAddress deviceAddress)
2437 {
2438     const VkAccelerationStructureTypeKHR structureType =
2439         (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
2440                            VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);
2441     const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR{
2442         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
2443         nullptr,                                                  //  const void* pNext;
2444         m_createFlags,                                            //  VkAccelerationStructureCreateFlagsKHR createFlags;
2445         getAccelerationStructureBuffer()->get(),                  //  VkBuffer buffer;
2446         getAccelerationStructureBufferOffset(),                   //  VkDeviceSize offset;
2447         m_structureSize,                                          //  VkDeviceSize size;
2448         structureType,                                            //  VkAccelerationStructureTypeKHR type;
2449         deviceAddress                                             //  VkDeviceAddress deviceAddress;
2450     };
2451 
2452     m_accelerationStructureKHR =
2453         createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, nullptr);
2454 }
2455 
// Intentionally empty: members release their own resources.
TopLevelAccelerationStructure::~TopLevelAccelerationStructure()
{
}
2459 
// All cached sizes start at zero; they are filled in once build sizes are queried/created.
TopLevelAccelerationStructure::TopLevelAccelerationStructure()
    : m_structureSize(0u)
    , m_updateScratchSize(0u)
    , m_buildScratchSize(0u)
{
}
2466 
setInstanceCount(const size_t instanceCount)2467 void TopLevelAccelerationStructure::setInstanceCount(const size_t instanceCount)
2468 {
2469     m_bottomLevelInstances.reserve(instanceCount);
2470     m_instanceData.reserve(instanceCount);
2471 }
2472 
addInstance(de::SharedPtr<BottomLevelAccelerationStructure> bottomLevelStructure,const VkTransformMatrixKHR & matrix,uint32_t instanceCustomIndex,uint32_t mask,uint32_t instanceShaderBindingTableRecordOffset,VkGeometryInstanceFlagsKHR flags)2473 void TopLevelAccelerationStructure::addInstance(de::SharedPtr<BottomLevelAccelerationStructure> bottomLevelStructure,
2474                                                 const VkTransformMatrixKHR &matrix, uint32_t instanceCustomIndex,
2475                                                 uint32_t mask, uint32_t instanceShaderBindingTableRecordOffset,
2476                                                 VkGeometryInstanceFlagsKHR flags)
2477 {
2478     m_bottomLevelInstances.push_back(bottomLevelStructure);
2479     m_instanceData.push_back(
2480         InstanceData(matrix, instanceCustomIndex, mask, instanceShaderBindingTableRecordOffset, flags));
2481 }
2482 
getStructureBuildSizes() const2483 VkAccelerationStructureBuildSizesInfoKHR TopLevelAccelerationStructure::getStructureBuildSizes() const
2484 {
2485     return {
2486         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
2487         nullptr,                                                       //  const void* pNext;
2488         m_structureSize,                                               //  VkDeviceSize accelerationStructureSize;
2489         m_updateScratchSize,                                           //  VkDeviceSize updateScratchSize;
2490         m_buildScratchSize                                             //  VkDeviceSize buildScratchSize;
2491     };
2492 }
2493 
// Convenience helper: create the structure letting the implementation compute
// the size (structureSize == 0) and immediately build it into cmdBuffer.
void TopLevelAccelerationStructure::createAndBuild(const DeviceInterface &vk, const VkDevice device,
                                                   const VkCommandBuffer cmdBuffer, Allocator &allocator,
                                                   VkDeviceAddress deviceAddress)
{
    create(vk, device, allocator, 0u, deviceAddress);
    build(vk, device, cmdBuffer);
}
2501 
createAndCopyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,TopLevelAccelerationStructure * accelerationStructure,VkDeviceSize compactCopySize,VkDeviceAddress deviceAddress)2502 void TopLevelAccelerationStructure::createAndCopyFrom(const DeviceInterface &vk, const VkDevice device,
2503                                                       const VkCommandBuffer cmdBuffer, Allocator &allocator,
2504                                                       TopLevelAccelerationStructure *accelerationStructure,
2505                                                       VkDeviceSize compactCopySize, VkDeviceAddress deviceAddress)
2506 {
2507     DE_ASSERT(accelerationStructure != NULL);
2508     VkDeviceSize copiedSize = compactCopySize > 0u ?
2509                                   compactCopySize :
2510                                   accelerationStructure->getStructureBuildSizes().accelerationStructureSize;
2511     DE_ASSERT(copiedSize != 0u);
2512 
2513     create(vk, device, allocator, copiedSize, deviceAddress);
2514     copyFrom(vk, device, cmdBuffer, accelerationStructure, compactCopySize > 0u);
2515 }
2516 
createAndDeserializeFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,SerialStorage * storage,VkDeviceAddress deviceAddress)2517 void TopLevelAccelerationStructure::createAndDeserializeFrom(const DeviceInterface &vk, const VkDevice device,
2518                                                              const VkCommandBuffer cmdBuffer, Allocator &allocator,
2519                                                              SerialStorage *storage, VkDeviceAddress deviceAddress)
2520 {
2521     DE_ASSERT(storage != NULL);
2522     DE_ASSERT(storage->getStorageSize() >= SerialStorage::SERIAL_STORAGE_SIZE_MIN);
2523     create(vk, device, allocator, storage->getDeserializedSize(), deviceAddress);
2524     if (storage->hasDeepFormat())
2525         createAndDeserializeBottoms(vk, device, cmdBuffer, allocator, storage);
2526     deserialize(vk, device, cmdBuffer, storage);
2527 }
2528 
createInstanceBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelInstances,std::vector<InstanceData> instanceData,const bool tryCachedMemory)2529 BufferWithMemory *createInstanceBuffer(
2530     const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
2531     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelInstances,
2532     std::vector<InstanceData> instanceData, const bool tryCachedMemory)
2533 {
2534     DE_ASSERT(bottomLevelInstances.size() != 0);
2535     DE_ASSERT(bottomLevelInstances.size() == instanceData.size());
2536     DE_UNREF(instanceData);
2537 
2538     BufferWithMemory *result           = nullptr;
2539     const VkDeviceSize bufferSizeBytes = bottomLevelInstances.size() * sizeof(VkAccelerationStructureInstanceKHR);
2540     const VkBufferCreateInfo bufferCreateInfo =
2541         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
2542                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2543     if (tryCachedMemory)
2544         try
2545         {
2546             result = new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
2547                                           MemoryRequirement::Cached | MemoryRequirement::HostVisible |
2548                                               MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress);
2549         }
2550         catch (const tcu::NotSupportedError &)
2551         {
2552             result = nullptr;
2553         }
2554     return result ? result :
2555                     new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
2556                                          MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
2557                                              MemoryRequirement::DeviceAddress);
2558 }
2559 
updateSingleInstance(const DeviceInterface & vk,const VkDevice device,const BottomLevelAccelerationStructure & bottomLevelAccelerationStructure,const InstanceData & instanceData,uint8_t * bufferLocation,VkAccelerationStructureBuildTypeKHR buildType,bool inactiveInstances)2560 void updateSingleInstance(const DeviceInterface &vk, const VkDevice device,
2561                           const BottomLevelAccelerationStructure &bottomLevelAccelerationStructure,
2562                           const InstanceData &instanceData, uint8_t *bufferLocation,
2563                           VkAccelerationStructureBuildTypeKHR buildType, bool inactiveInstances)
2564 {
2565     const VkAccelerationStructureKHR accelerationStructureKHR = *bottomLevelAccelerationStructure.getPtr();
2566 
2567     // This part needs to be fixed once a new version of the VkAccelerationStructureInstanceKHR will be added to vkStructTypes.inl
2568     VkDeviceAddress accelerationStructureAddress;
2569     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
2570     {
2571         VkAccelerationStructureDeviceAddressInfoKHR asDeviceAddressInfo = {
2572             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_DEVICE_ADDRESS_INFO_KHR, // VkStructureType sType;
2573             nullptr,                                                          // const void* pNext;
2574             accelerationStructureKHR // VkAccelerationStructureKHR accelerationStructure;
2575         };
2576         accelerationStructureAddress = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
2577     }
2578 
2579     uint64_t structureReference;
2580     if (inactiveInstances)
2581     {
2582         // Instances will be marked inactive by making their references VK_NULL_HANDLE or having address zero.
2583         structureReference = 0ull;
2584     }
2585     else
2586     {
2587         structureReference = (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
2588                                  uint64_t(accelerationStructureAddress) :
2589                                  uint64_t(accelerationStructureKHR.getInternal());
2590     }
2591 
2592     VkAccelerationStructureInstanceKHR accelerationStructureInstanceKHR = makeVkAccelerationStructureInstanceKHR(
2593         instanceData.matrix,                                 //  VkTransformMatrixKHR transform;
2594         instanceData.instanceCustomIndex,                    //  uint32_t instanceCustomIndex:24;
2595         instanceData.mask,                                   //  uint32_t mask:8;
2596         instanceData.instanceShaderBindingTableRecordOffset, //  uint32_t instanceShaderBindingTableRecordOffset:24;
2597         instanceData.flags,                                  //  VkGeometryInstanceFlagsKHR flags:8;
2598         structureReference                                   //  uint64_t accelerationStructureReference;
2599     );
2600 
2601     deMemcpy(bufferLocation, &accelerationStructureInstanceKHR, sizeof(VkAccelerationStructureInstanceKHR));
2602 }
2603 
updateInstanceBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelInstances,const std::vector<InstanceData> & instanceData,const BufferWithMemory * instanceBuffer,VkAccelerationStructureBuildTypeKHR buildType,bool inactiveInstances)2604 void updateInstanceBuffer(const DeviceInterface &vk, const VkDevice device,
2605                           const std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelInstances,
2606                           const std::vector<InstanceData> &instanceData, const BufferWithMemory *instanceBuffer,
2607                           VkAccelerationStructureBuildTypeKHR buildType, bool inactiveInstances)
2608 {
2609     DE_ASSERT(bottomLevelInstances.size() != 0);
2610     DE_ASSERT(bottomLevelInstances.size() == instanceData.size());
2611 
2612     auto &instancesAlloc      = instanceBuffer->getAllocation();
2613     auto bufferStart          = reinterpret_cast<uint8_t *>(instancesAlloc.getHostPtr());
2614     VkDeviceSize bufferOffset = 0ull;
2615 
2616     for (size_t instanceNdx = 0; instanceNdx < bottomLevelInstances.size(); ++instanceNdx)
2617     {
2618         const auto &blas = *bottomLevelInstances[instanceNdx];
2619         updateSingleInstance(vk, device, blas, instanceData[instanceNdx], bufferStart + bufferOffset, buildType,
2620                              inactiveInstances);
2621         bufferOffset += sizeof(VkAccelerationStructureInstanceKHR);
2622     }
2623 
2624     flushMappedMemoryRange(vk, device, instancesAlloc.getMemory(), instancesAlloc.getOffset(), VK_WHOLE_SIZE);
2625 }
2626 
// KHR implementation of the top-level acceleration structure abstraction.
// Collects bottom-level instances, queries/creates the required buffers, and
// performs build, copy and (de)serialization either on the device or on the
// host depending on the configured build type.
class TopLevelAccelerationStructureKHR : public TopLevelAccelerationStructure
{
public:
    // Number of separate device-memory allocations one instance of this class needs.
    static uint32_t getRequiredAllocationCount(void);

    TopLevelAccelerationStructureKHR();
    TopLevelAccelerationStructureKHR(const TopLevelAccelerationStructureKHR &other) = delete;
    virtual ~TopLevelAccelerationStructureKHR();

    // Configuration setters; intended to be called before create()/build().
    void setBuildType(const VkAccelerationStructureBuildTypeKHR buildType) override;
    void setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags) override;
    void setCreateGeneric(bool createGeneric) override;
    void setCreationBufferUnbounded(bool creationBufferUnbounded) override;
    void setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags) override;
    void setBuildWithoutPrimitives(bool buildWithoutPrimitives) override;
    void setInactiveInstances(bool inactiveInstances) override;
    void setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount) override;
    void setUseArrayOfPointers(const bool useArrayOfPointers) override;
    void setIndirectBuildParameters(const VkBuffer indirectBuffer, const VkDeviceSize indirectBufferOffset,
                                    const uint32_t indirectBufferStride) override;
    void setUsePPGeometries(const bool usePPGeometries) override;
    void setTryCachedMemory(const bool tryCachedMemory) override;
    VkBuildAccelerationStructureFlagsKHR getBuildFlags() const override;

    // Creation / build / copy / (de)serialization operations.
    void getCreationSizes(const DeviceInterface &vk, const VkDevice device, const VkDeviceSize structureSize,
                          CreationSizes &sizes) override;
    void create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator, VkDeviceSize structureSize,
                VkDeviceAddress deviceAddress = 0u, const void *pNext = nullptr,
                const MemoryRequirement &addMemoryRequirement = MemoryRequirement::Any,
                const VkBuffer creationBuffer = VK_NULL_HANDLE, const VkDeviceSize creationBufferSize = 0u) override;
    void build(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
               TopLevelAccelerationStructure *srcAccelerationStructure = nullptr) override;
    void copyFrom(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                  TopLevelAccelerationStructure *accelerationStructure, bool compactCopy) override;
    void serialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                   SerialStorage *storage) override;
    void deserialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                     SerialStorage *storage) override;

    std::vector<VkDeviceSize> getSerializingSizes(const DeviceInterface &vk, const VkDevice device, const VkQueue queue,
                                                  const uint32_t queueFamilyIndex) override;

    std::vector<uint64_t> getSerializingAddresses(const DeviceInterface &vk, const VkDevice device) const override;

    const VkAccelerationStructureKHR *getPtr(void) const override;

    void updateInstanceMatrix(const DeviceInterface &vk, const VkDevice device, size_t instanceIndex,
                              const VkTransformMatrixKHR &matrix) override;

protected:
    // Configuration state (set via the setters above).
    VkAccelerationStructureBuildTypeKHR m_buildType;
    VkAccelerationStructureCreateFlagsKHR m_createFlags;
    bool m_createGeneric;
    bool m_creationBufferUnbounded;
    VkBuildAccelerationStructureFlagsKHR m_buildFlags;
    bool m_buildWithoutPrimitives;
    bool m_inactiveInstances;
    bool m_deferredOperation;
    uint32_t m_workerThreadCount;
    bool m_useArrayOfPointers;
    // Backing resources created on demand.
    de::MovePtr<BufferWithMemory> m_accelerationStructureBuffer;
    de::MovePtr<BufferWithMemory> m_instanceBuffer;
    de::MovePtr<BufferWithMemory> m_instanceAddressBuffer;
    de::MovePtr<BufferWithMemory> m_deviceScratchBuffer;
    std::vector<uint8_t> m_hostScratchBuffer; // scratch space for host builds
    Move<VkAccelerationStructureKHR> m_accelerationStructureKHR;
    // Indirect-build parameters (see setIndirectBuildParameters).
    VkBuffer m_indirectBuffer;
    VkDeviceSize m_indirectBufferOffset;
    uint32_t m_indirectBufferStride;
    bool m_usePPGeometries;
    bool m_tryCachedMemory;

    // Fills in the instances geometry and per-geometry max primitive counts.
    void prepareInstances(const DeviceInterface &vk, const VkDevice device,
                          VkAccelerationStructureGeometryKHR &accelerationStructureGeometryKHR,
                          std::vector<uint32_t> &maxPrimitiveCounts);

    void serializeBottoms(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                          SerialStorage *storage, VkDeferredOperationKHR deferredOperation);

    void createAndDeserializeBottoms(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                                     Allocator &allocator, SerialStorage *storage) override;
};
2709 
getRequiredAllocationCount(void)2710 uint32_t TopLevelAccelerationStructureKHR::getRequiredAllocationCount(void)
2711 {
2712     /*
2713         de::MovePtr<BufferWithMemory>                            m_instanceBuffer;
2714         de::MovePtr<Allocation>                                    m_accelerationStructureAlloc;
2715         de::MovePtr<BufferWithMemory>                            m_deviceScratchBuffer;
2716     */
2717     return 3u;
2718 }
2719 
// Defaults: device builds, no special flags, buffers created lazily, cached
// memory attempted for the instance buffer (m_tryCachedMemory = true).
TopLevelAccelerationStructureKHR::TopLevelAccelerationStructureKHR()
    : TopLevelAccelerationStructure()
    , m_buildType(VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    , m_createFlags(0u)
    , m_createGeneric(false)
    , m_creationBufferUnbounded(false)
    , m_buildFlags(0u)
    , m_buildWithoutPrimitives(false)
    , m_inactiveInstances(false)
    , m_deferredOperation(false)
    , m_workerThreadCount(0)
    , m_useArrayOfPointers(false)
    , m_accelerationStructureBuffer(nullptr)
    , m_instanceBuffer(nullptr)
    , m_instanceAddressBuffer(nullptr)
    , m_deviceScratchBuffer(nullptr)
    , m_accelerationStructureKHR()
    , m_indirectBuffer(VK_NULL_HANDLE)
    , m_indirectBufferOffset(0)
    , m_indirectBufferStride(0)
    , m_usePPGeometries(false)
    , m_tryCachedMemory(true)
{
}
2744 
// Intentionally empty: RAII members (Move/MovePtr) release their own resources.
TopLevelAccelerationStructureKHR::~TopLevelAccelerationStructureKHR()
{
}
2748 
// Stores the build type (device or host) used by later size queries and builds.
void TopLevelAccelerationStructureKHR::setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)
{
    m_buildType = buildType;
}
2753 
// Stores the flags passed in VkAccelerationStructureCreateInfoKHR at creation time.
void TopLevelAccelerationStructureKHR::setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)
{
    m_createFlags = createFlags;
}
2758 
// When true, the structure is created with the GENERIC type instead of TOP_LEVEL.
void TopLevelAccelerationStructureKHR::setCreateGeneric(bool createGeneric)
{
    m_createGeneric = createGeneric;
}
2763 
// Stores whether the creation buffer may be treated as unbounded (used at create time).
void TopLevelAccelerationStructureKHR::setCreationBufferUnbounded(bool creationBufferUnbounded)
{
    m_creationBufferUnbounded = creationBufferUnbounded;
}
2768 
// When true, instances are written with a zero structure reference, marking them inactive.
void TopLevelAccelerationStructureKHR::setInactiveInstances(bool inactiveInstances)
{
    m_inactiveInstances = inactiveInstances;
}
2773 
// Stores the flags used in VkAccelerationStructureBuildGeometryInfoKHR for builds.
void TopLevelAccelerationStructureKHR::setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)
{
    m_buildFlags = buildFlags;
}
2778 
// When true, the structure is built with zero primitives (empty build).
void TopLevelAccelerationStructureKHR::setBuildWithoutPrimitives(bool buildWithoutPrimitives)
{
    m_buildWithoutPrimitives = buildWithoutPrimitives;
}
2783 
setDeferredOperation(const bool deferredOperation,const uint32_t workerThreadCount)2784 void TopLevelAccelerationStructureKHR::setDeferredOperation(const bool deferredOperation,
2785                                                             const uint32_t workerThreadCount)
2786 {
2787     m_deferredOperation = deferredOperation;
2788     m_workerThreadCount = workerThreadCount;
2789 }
2790 
// When true, instances are referenced through an array of pointers/addresses
// rather than a tightly packed array.
void TopLevelAccelerationStructureKHR::setUseArrayOfPointers(const bool useArrayOfPointers)
{
    m_useArrayOfPointers = useArrayOfPointers;
}
2795 
// When true, builds pass geometries via ppGeometries instead of pGeometries.
void TopLevelAccelerationStructureKHR::setUsePPGeometries(const bool usePPGeometries)
{
    m_usePPGeometries = usePPGeometries;
}
2800 
// When true, cached host-visible memory is attempted first for the instance buffer.
void TopLevelAccelerationStructureKHR::setTryCachedMemory(const bool tryCachedMemory)
{
    m_tryCachedMemory = tryCachedMemory;
}
2805 
setIndirectBuildParameters(const VkBuffer indirectBuffer,const VkDeviceSize indirectBufferOffset,const uint32_t indirectBufferStride)2806 void TopLevelAccelerationStructureKHR::setIndirectBuildParameters(const VkBuffer indirectBuffer,
2807                                                                   const VkDeviceSize indirectBufferOffset,
2808                                                                   const uint32_t indirectBufferStride)
2809 {
2810     m_indirectBuffer       = indirectBuffer;
2811     m_indirectBufferOffset = indirectBufferOffset;
2812     m_indirectBufferStride = indirectBufferStride;
2813 }
2814 
// Returns the build flags previously configured via setBuildFlags().
VkBuildAccelerationStructureFlagsKHR TopLevelAccelerationStructureKHR::getBuildFlags() const
{
    return m_buildFlags;
}
2819 
sum() const2820 VkDeviceSize TopLevelAccelerationStructure::CreationSizes::sum() const
2821 {
2822     return structure + updateScratch + buildScratch + instancePointers + instancesBuffer;
2823 }
2824 
// Computes every size needed to create and build this TLAS: the structure and
// scratch sizes (queried from the driver when structureSize == 0), plus the
// instance-pointer and instance-buffer sizes derived from the instance count.
void TopLevelAccelerationStructureKHR::getCreationSizes(const DeviceInterface &vk, const VkDevice device,
                                                        const VkDeviceSize structureSize, CreationSizes &sizes)
{
    // AS may be built from geometries using vkCmdBuildAccelerationStructureKHR / vkBuildAccelerationStructureKHR
    // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
    DE_ASSERT(!m_bottomLevelInstances.empty() != !(structureSize == 0)); // logical xor

    if (structureSize == 0)
    {
        // Ask the implementation for sizes based on the instances geometry.
        VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
        const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
        std::vector<uint32_t> maxPrimitiveCounts;
        prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);

        // Either pGeometries or ppGeometries must be used, never both (see m_usePPGeometries).
        VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
            nullptr,                                                          //  const void* pNext;
            VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
            m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
            VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
            VK_NULL_HANDLE,                                 //  VkAccelerationStructureKHR srcAccelerationStructure;
            VK_NULL_HANDLE,                                 //  VkAccelerationStructureKHR dstAccelerationStructure;
            1u,                                             //  uint32_t geometryCount;
            (m_usePPGeometries ?
                 nullptr :
                 &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
            (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
                                 nullptr),      //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
            makeDeviceOrHostAddressKHR(nullptr) //  VkDeviceOrHostAddressKHR scratchData;
        };

        VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
            nullptr,                                                       //  const void* pNext;
            0,                                                             //  VkDeviceSize accelerationStructureSize;
            0,                                                             //  VkDeviceSize updateScratchSize;
            0                                                              //  VkDeviceSize buildScratchSize;
        };

        vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
                                                 maxPrimitiveCounts.data(), &sizeInfo);

        sizes.structure     = sizeInfo.accelerationStructureSize;
        sizes.updateScratch = sizeInfo.updateScratchSize;
        sizes.buildScratch  = sizeInfo.buildScratchSize;
    }
    else
    {
        // Size provided by the caller (copy/compact/deserialize path); no scratch needed.
        sizes.structure     = structureSize;
        sizes.updateScratch = 0u;
        sizes.buildScratch  = 0u;
    }

    sizes.instancePointers = 0u;
    if (m_useArrayOfPointers)
    {
        // Pointer size differs between device builds (device addresses) and host builds (host pointers).
        const size_t pointerSize = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
                                       sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress) :
                                       sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
        sizes.instancePointers   = static_cast<VkDeviceSize>(m_bottomLevelInstances.size() * pointerSize);
    }

    sizes.instancesBuffer = m_bottomLevelInstances.empty() ?
                                0u :
                                m_bottomLevelInstances.size() * sizeof(VkAccelerationStructureInstanceKHR);
}
2891 
void TopLevelAccelerationStructureKHR::create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
                                              VkDeviceSize structureSize, VkDeviceAddress deviceAddress,
                                              const void *pNext, const MemoryRequirement &addMemoryRequirement,
                                              const VkBuffer creationBuffer, const VkDeviceSize creationBufferSize)
{
    // AS may be built from geometries using vkCmdBuildAccelerationStructureKHR / vkBuildAccelerationStructureKHR
    // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
    DE_ASSERT(!m_bottomLevelInstances.empty() != !(structureSize == 0)); // logical xor

    if (structureSize == 0)
    {
        // No explicit size supplied: query the required structure/scratch sizes from the
        // driver based on the instance geometry (vkGetAccelerationStructureBuildSizesKHR).
        VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
        const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
        std::vector<uint32_t> maxPrimitiveCounts;
        prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);

        VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
            nullptr,                                                          //  const void* pNext;
            VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
            m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
            VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
            VK_NULL_HANDLE,                                 //  VkAccelerationStructureKHR srcAccelerationStructure;
            VK_NULL_HANDLE,                                 //  VkAccelerationStructureKHR dstAccelerationStructure;
            1u,                                             //  uint32_t geometryCount;
            // Exactly one of pGeometries / ppGeometries is non-null, selected by the test mode.
            (m_usePPGeometries ?
                 nullptr :
                 &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
            (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
                                 nullptr),      //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
            makeDeviceOrHostAddressKHR(nullptr) //  VkDeviceOrHostAddressKHR scratchData;
        };

        VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
            nullptr,                                                       //  const void* pNext;
            0,                                                             //  VkDeviceSize accelerationStructureSize;
            0,                                                             //  VkDeviceSize updateScratchSize;
            0                                                              //  VkDeviceSize buildScratchSize;
        };

        vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
                                                 maxPrimitiveCounts.data(), &sizeInfo);

        m_structureSize     = sizeInfo.accelerationStructureSize;
        m_updateScratchSize = sizeInfo.updateScratchSize;
        m_buildScratchSize  = sizeInfo.buildScratchSize;
    }
    else
    {
        // Size was provided by the caller (copy/compact/deserialize path); no build
        // is planned, so no scratch space is needed.
        m_structureSize     = structureSize;
        m_updateScratchSize = 0u;
        m_buildScratchSize  = 0u;
    }

    // If the caller supplied a buffer, the AS is created inside it instead of an
    // internally allocated one; the caller-provided size must be large enough.
    const bool externalCreationBuffer = (creationBuffer != VK_NULL_HANDLE);

    if (externalCreationBuffer)
    {
        DE_UNREF(creationBufferSize); // For release builds.
        DE_ASSERT(creationBufferSize >= m_structureSize);
    }

    if (!externalCreationBuffer)
    {
        const VkBufferCreateInfo bufferCreateInfo =
            makeBufferCreateInfo(m_structureSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
                                                      VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
        const MemoryRequirement memoryRequirement = addMemoryRequirement | MemoryRequirement::HostVisible |
                                                    MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
        // When the "unbounded" mode is requested, memory binding is delayed until after
        // AS creation (see bindMemory() call below).
        const bool bindMemOnCreation = (!m_creationBufferUnbounded);

        try
        {
            // Prefer cached host-visible memory for speed; not all implementations offer it.
            m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
                new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
                                     (MemoryRequirement::Cached | memoryRequirement), bindMemOnCreation));
        }
        catch (const tcu::NotSupportedError &)
        {
            // retry without Cached flag
            m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
                new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement, bindMemOnCreation));
        }
    }

    const auto createInfoBuffer = (externalCreationBuffer ? creationBuffer : m_accelerationStructureBuffer->get());
    {
        // Generic AS type is used by tests that exercise VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR.
        const VkAccelerationStructureTypeKHR structureType =
            (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
                               VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR);
        const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
            pNext,                                                    //  const void* pNext;
            m_createFlags,    //  VkAccelerationStructureCreateFlagsKHR createFlags;
            createInfoBuffer, //  VkBuffer buffer;
            0u,               //  VkDeviceSize offset;
            m_structureSize,  //  VkDeviceSize size;
            structureType,    //  VkAccelerationStructureTypeKHR type;
            deviceAddress     //  VkDeviceAddress deviceAddress;
        };

        m_accelerationStructureKHR =
            createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, nullptr);

        // Make sure buffer memory is always bound after creation.
        if (!externalCreationBuffer)
            m_accelerationStructureBuffer->bindMemory();
    }

    // Allocate one scratch area sized for the larger of build/update needs.
    if (m_buildScratchSize > 0u || m_updateScratchSize > 0u)
    {
        VkDeviceSize scratch_size = de::max(m_buildScratchSize, m_updateScratchSize);
        if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
        {
            const VkBufferCreateInfo bufferCreateInfo = makeBufferCreateInfo(
                scratch_size, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
            m_deviceScratchBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
                vk, device, allocator, bufferCreateInfo,
                MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
        }
        else
        {
            // Host builds only need plain CPU memory for scratch.
            m_hostScratchBuffer.resize(static_cast<size_t>(scratch_size));
        }
    }

    // Optional buffer of per-instance pointers (array-of-pointers instance layout);
    // pointer width depends on whether addresses are device or host addresses.
    if (m_useArrayOfPointers)
    {
        const size_t pointerSize = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
                                       sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress) :
                                       sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
        const VkBufferCreateInfo bufferCreateInfo =
            makeBufferCreateInfo(static_cast<VkDeviceSize>(m_bottomLevelInstances.size() * pointerSize),
                                 VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
                                     VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
        m_instanceAddressBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
            vk, device, allocator, bufferCreateInfo,
            MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
    }

    if (!m_bottomLevelInstances.empty())
        m_instanceBuffer = de::MovePtr<BufferWithMemory>(
            createInstanceBuffer(vk, device, allocator, m_bottomLevelInstances, m_instanceData, m_tryCachedMemory));
}
3037 
updateInstanceMatrix(const DeviceInterface & vk,const VkDevice device,size_t instanceIndex,const VkTransformMatrixKHR & matrix)3038 void TopLevelAccelerationStructureKHR::updateInstanceMatrix(const DeviceInterface &vk, const VkDevice device,
3039                                                             size_t instanceIndex, const VkTransformMatrixKHR &matrix)
3040 {
3041     DE_ASSERT(instanceIndex < m_bottomLevelInstances.size());
3042     DE_ASSERT(instanceIndex < m_instanceData.size());
3043 
3044     const auto &blas          = *m_bottomLevelInstances[instanceIndex];
3045     auto &instanceData        = m_instanceData[instanceIndex];
3046     auto &instancesAlloc      = m_instanceBuffer->getAllocation();
3047     auto bufferStart          = reinterpret_cast<uint8_t *>(instancesAlloc.getHostPtr());
3048     VkDeviceSize bufferOffset = sizeof(VkAccelerationStructureInstanceKHR) * instanceIndex;
3049 
3050     instanceData.matrix = matrix;
3051     updateSingleInstance(vk, device, blas, instanceData, bufferStart + bufferOffset, m_buildType, m_inactiveInstances);
3052     flushMappedMemoryRange(vk, device, instancesAlloc.getMemory(), instancesAlloc.getOffset(), VK_WHOLE_SIZE);
3053 }
3054 
void TopLevelAccelerationStructureKHR::build(const DeviceInterface &vk, const VkDevice device,
                                             const VkCommandBuffer cmdBuffer,
                                             TopLevelAccelerationStructure *srcAccelerationStructure)
{
    DE_ASSERT(!m_bottomLevelInstances.empty());
    DE_ASSERT(m_accelerationStructureKHR.get() != VK_NULL_HANDLE);
    DE_ASSERT(m_buildScratchSize != 0);

    // Refresh the instance buffer contents from m_instanceData before building.
    updateInstanceBuffer(vk, device, m_bottomLevelInstances, m_instanceData, m_instanceBuffer.get(), m_buildType,
                         m_inactiveInstances);

    VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
    const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
    std::vector<uint32_t> maxPrimitiveCounts;
    prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);

    // Device builds address the device scratch buffer; host builds use the host scratch vector.
    VkDeviceOrHostAddressKHR scratchData = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
                                               makeDeviceOrHostAddressKHR(vk, device, m_deviceScratchBuffer->get(), 0) :
                                               makeDeviceOrHostAddressKHR(m_hostScratchBuffer.data());

    // A non-null source AS turns this into an incremental update instead of a full build.
    VkAccelerationStructureKHR srcStructure =
        (srcAccelerationStructure != nullptr) ? *(srcAccelerationStructure->getPtr()) : VK_NULL_HANDLE;
    VkBuildAccelerationStructureModeKHR mode = (srcAccelerationStructure != nullptr) ?
                                                   VK_BUILD_ACCELERATION_STRUCTURE_MODE_UPDATE_KHR :
                                                   VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR;

    VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
        VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
        nullptr,                                                          //  const void* pNext;
        VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
        m_buildFlags,                     //  VkBuildAccelerationStructureFlagsKHR flags;
        mode,                             //  VkBuildAccelerationStructureModeKHR mode;
        srcStructure,                     //  VkAccelerationStructureKHR srcAccelerationStructure;
        m_accelerationStructureKHR.get(), //  VkAccelerationStructureKHR dstAccelerationStructure;
        1u,                               //  uint32_t geometryCount;
        // Exactly one of pGeometries / ppGeometries is non-null, selected by the test mode.
        (m_usePPGeometries ?
             nullptr :
             &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
        (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
                             nullptr), //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
        scratchData                    //  VkDeviceOrHostAddressKHR scratchData;
    };

    // Some tests deliberately build with zero primitives to exercise empty builds.
    const uint32_t primitiveCount =
        (m_buildWithoutPrimitives ? 0u : static_cast<uint32_t>(m_bottomLevelInstances.size()));

    VkAccelerationStructureBuildRangeInfoKHR accelerationStructureBuildRangeInfoKHR = {
        primitiveCount, //  uint32_t primitiveCount;
        0,              //  uint32_t primitiveOffset;
        0,              //  uint32_t firstVertex;
        0               //  uint32_t transformOffset;
    };
    VkAccelerationStructureBuildRangeInfoKHR *accelerationStructureBuildRangeInfoKHRPtr =
        &accelerationStructureBuildRangeInfoKHR;

    // Dispatch the build: (1) device direct, (2) device indirect, (3) host immediate,
    // or (4) host via a deferred operation.
    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        if (m_indirectBuffer == VK_NULL_HANDLE)
            vk.cmdBuildAccelerationStructuresKHR(
                cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
                (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);
        else
        {
            // Indirect build: the range info is read from m_indirectBuffer at the given offset/stride.
            VkDeviceAddress indirectDeviceAddress =
                getBufferDeviceAddress(vk, device, m_indirectBuffer, m_indirectBufferOffset);
            uint32_t *pMaxPrimitiveCounts = maxPrimitiveCounts.data();
            vk.cmdBuildAccelerationStructuresIndirectKHR(cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
                                                         &indirectDeviceAddress, &m_indirectBufferStride,
                                                         &pMaxPrimitiveCounts);
        }
    }
    else if (!m_deferredOperation)
    {
        VK_CHECK(vk.buildAccelerationStructuresKHR(
            device, VK_NULL_HANDLE, 1u, &accelerationStructureBuildGeometryInfoKHR,
            (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr));
    }
    else
    {
        // Deferred host build: start the operation, then drive it to completion on
        // m_workerThreadCount threads.
        const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
        const auto deferredOperation    = deferredOperationPtr.get();

        VkResult result = vk.buildAccelerationStructuresKHR(
            device, deferredOperation, 1u, &accelerationStructureBuildGeometryInfoKHR,
            (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);

        DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
                  result == VK_SUCCESS);

        finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
                                result == VK_OPERATION_NOT_DEFERRED_KHR);

        // NOTE(review): pNext is already nullptr at this point, and the struct is not
        // used again below — this assignment appears to be a leftover no-op; confirm.
        accelerationStructureBuildGeometryInfoKHR.pNext = nullptr;
    }

    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        // Make the freshly built AS visible to all subsequent commands in the buffer.
        const VkAccessFlags accessMasks =
            VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
        const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);

        cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
                                 VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
    }
}
3160 
copyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,TopLevelAccelerationStructure * accelerationStructure,bool compactCopy)3161 void TopLevelAccelerationStructureKHR::copyFrom(const DeviceInterface &vk, const VkDevice device,
3162                                                 const VkCommandBuffer cmdBuffer,
3163                                                 TopLevelAccelerationStructure *accelerationStructure, bool compactCopy)
3164 {
3165     DE_ASSERT(m_accelerationStructureKHR.get() != VK_NULL_HANDLE);
3166     DE_ASSERT(accelerationStructure != nullptr);
3167 
3168     VkCopyAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
3169         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
3170         nullptr,                                                // const void* pNext;
3171         *(accelerationStructure->getPtr()),                     // VkAccelerationStructureKHR src;
3172         *(getPtr()),                                            // VkAccelerationStructureKHR dst;
3173         compactCopy ? VK_COPY_ACCELERATION_STRUCTURE_MODE_COMPACT_KHR :
3174                       VK_COPY_ACCELERATION_STRUCTURE_MODE_CLONE_KHR // VkCopyAccelerationStructureModeKHR mode;
3175     };
3176 
3177     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3178     {
3179         vk.cmdCopyAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
3180     }
3181     else if (!m_deferredOperation)
3182     {
3183         VK_CHECK(vk.copyAccelerationStructureKHR(device, VK_NULL_HANDLE, &copyAccelerationStructureInfo));
3184     }
3185     else
3186     {
3187         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
3188         const auto deferredOperation    = deferredOperationPtr.get();
3189 
3190         VkResult result = vk.copyAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
3191 
3192         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3193                   result == VK_SUCCESS);
3194 
3195         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
3196                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3197     }
3198 
3199     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3200     {
3201         const VkAccessFlags accessMasks =
3202             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
3203         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
3204 
3205         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
3206                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
3207     }
3208 }
3209 
void TopLevelAccelerationStructureKHR::serialize(const DeviceInterface &vk, const VkDevice device,
                                                 const VkCommandBuffer cmdBuffer, SerialStorage *storage)
{
    DE_ASSERT(m_accelerationStructureKHR.get() != VK_NULL_HANDLE);
    DE_ASSERT(storage != nullptr);

    // Serialize this top-level AS into the storage's destination address. When the
    // storage uses the "deep" format, the bottom-level structures are serialized too.
    const VkCopyAccelerationStructureToMemoryInfoKHR copyAccelerationStructureInfo = {
        VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_TO_MEMORY_INFO_KHR, // VkStructureType sType;
        nullptr,                                                          // const void* pNext;
        *(getPtr()),                                                      // VkAccelerationStructureKHR src;
        storage->getAddress(vk, device, m_buildType),                     // VkDeviceOrHostAddressKHR dst;
        VK_COPY_ACCELERATION_STRUCTURE_MODE_SERIALIZE_KHR                 // VkCopyAccelerationStructureModeKHR mode;
    };

    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        vk.cmdCopyAccelerationStructureToMemoryKHR(cmdBuffer, &copyAccelerationStructureInfo);
        if (storage->hasDeepFormat())
            serializeBottoms(vk, device, cmdBuffer, storage, VK_NULL_HANDLE);
    }
    else if (!m_deferredOperation)
    {
        VK_CHECK(vk.copyAccelerationStructureToMemoryKHR(device, VK_NULL_HANDLE, &copyAccelerationStructureInfo));
        if (storage->hasDeepFormat())
            serializeBottoms(vk, device, cmdBuffer, storage, VK_NULL_HANDLE);
    }
    else
    {
        // Deferred host path: start the operation, serialize the bottoms (passing the
        // deferred-operation handle through), then drive the operation to completion.
        const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
        const auto deferredOperation    = deferredOperationPtr.get();

        const VkResult result =
            vk.copyAccelerationStructureToMemoryKHR(device, deferredOperation, &copyAccelerationStructureInfo);

        DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
                  result == VK_SUCCESS);
        if (storage->hasDeepFormat())
            serializeBottoms(vk, device, cmdBuffer, storage, deferredOperation);

        finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
                                result == VK_OPERATION_NOT_DEFERRED_KHR);
    }
}
3253 
deserialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)3254 void TopLevelAccelerationStructureKHR::deserialize(const DeviceInterface &vk, const VkDevice device,
3255                                                    const VkCommandBuffer cmdBuffer, SerialStorage *storage)
3256 {
3257     DE_ASSERT(m_accelerationStructureKHR.get() != VK_NULL_HANDLE);
3258     DE_ASSERT(storage != nullptr);
3259 
3260     const VkCopyMemoryToAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
3261         VK_STRUCTURE_TYPE_COPY_MEMORY_TO_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
3262         nullptr,                                                          // const void* pNext;
3263         storage->getAddressConst(vk, device, m_buildType),                // VkDeviceOrHostAddressConstKHR src;
3264         *(getPtr()),                                                      // VkAccelerationStructureKHR dst;
3265         VK_COPY_ACCELERATION_STRUCTURE_MODE_DESERIALIZE_KHR               // VkCopyAccelerationStructureModeKHR mode;
3266     };
3267 
3268     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3269     {
3270         vk.cmdCopyMemoryToAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
3271     }
3272     else if (!m_deferredOperation)
3273     {
3274         VK_CHECK(vk.copyMemoryToAccelerationStructureKHR(device, VK_NULL_HANDLE, &copyAccelerationStructureInfo));
3275     }
3276     else
3277     {
3278         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
3279         const auto deferredOperation    = deferredOperationPtr.get();
3280 
3281         const VkResult result =
3282             vk.copyMemoryToAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
3283 
3284         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3285                   result == VK_SUCCESS);
3286 
3287         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
3288                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3289     }
3290 
3291     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3292     {
3293         const VkAccessFlags accessMasks =
3294             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
3295         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
3296 
3297         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
3298                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
3299     }
3300 }
3301 
serializeBottoms(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage,VkDeferredOperationKHR deferredOperation)3302 void TopLevelAccelerationStructureKHR::serializeBottoms(const DeviceInterface &vk, const VkDevice device,
3303                                                         const VkCommandBuffer cmdBuffer, SerialStorage *storage,
3304                                                         VkDeferredOperationKHR deferredOperation)
3305 {
3306     DE_UNREF(deferredOperation);
3307     DE_ASSERT(storage->hasDeepFormat());
3308 
3309     const std::vector<uint64_t> &addresses = storage->getSerialInfo().addresses();
3310     const std::size_t cbottoms             = m_bottomLevelInstances.size();
3311 
3312     uint32_t storageIndex = 0;
3313     std::vector<uint64_t> matches;
3314 
3315     for (std::size_t i = 0; i < cbottoms; ++i)
3316     {
3317         const uint64_t &lookAddr = addresses[i + 1];
3318         auto end                 = matches.end();
3319         auto match = std::find_if(matches.begin(), end, [&](const uint64_t &item) { return item == lookAddr; });
3320         if (match == end)
3321         {
3322             matches.emplace_back(lookAddr);
3323             m_bottomLevelInstances[i].get()->serialize(vk, device, cmdBuffer,
3324                                                        storage->getBottomStorage(storageIndex).get());
3325             storageIndex += 1;
3326         }
3327     }
3328 }
3329 
void TopLevelAccelerationStructureKHR::createAndDeserializeBottoms(const DeviceInterface &vk, const VkDevice device,
                                                                   const VkCommandBuffer cmdBuffer,
                                                                   Allocator &allocator, SerialStorage *storage)
{
    // Recreate the bottom-level structures from deep-format storage, then rewrite the
    // top-level serialization header so its handle array points at the new bottoms.
    DE_ASSERT(storage->hasDeepFormat());
    DE_ASSERT(m_bottomLevelInstances.size() == 0);

    // addresses[0] is the top-level AS itself; bottom-level entries start at index 1.
    const std::vector<uint64_t> &addresses = storage->getSerialInfo().addresses();
    const std::size_t cbottoms             = addresses.size() - 1;
    uint32_t storageIndex                  = 0;
    // Pairs of (serialized address, index of the instance that owns the deserialized AS),
    // used to share one bottom-level AS between instances that had the same address.
    std::vector<std::pair<uint64_t, std::size_t>> matches;

    for (std::size_t i = 0; i < cbottoms; ++i)
    {
        const uint64_t &lookAddr = addresses[i + 1];
        auto end                 = matches.end();
        auto match               = std::find_if(matches.begin(), end,
                                                [&](const std::pair<uint64_t, std::size_t> &item) { return item.first == lookAddr; });
        if (match != end)
        {
            // Same address seen before: alias the already-deserialized bottom-level AS.
            m_bottomLevelInstances.emplace_back(m_bottomLevelInstances[match->second]);
        }
        else
        {
            // First occurrence: create a new bottom-level AS and deserialize it from the
            // next bottom-storage slot.
            de::MovePtr<BottomLevelAccelerationStructure> blas = makeBottomLevelAccelerationStructure();
            blas->createAndDeserializeFrom(vk, device, cmdBuffer, allocator,
                                           storage->getBottomStorage(storageIndex).get());
            m_bottomLevelInstances.emplace_back(de::SharedPtr<BottomLevelAccelerationStructure>(blas.release()));
            matches.emplace_back(lookAddr, i);
            storageIndex += 1;
        }
    }

    std::vector<uint64_t> newAddresses = getSerializingAddresses(vk, device);
    DE_ASSERT(addresses.size() == newAddresses.size());

    SerialStorage::AccelerationStructureHeader *header = storage->getASHeader();
    DE_ASSERT(cbottoms == header->handleCount);

    // finally update bottom-level AS addresses before top-level AS deserialization
    for (std::size_t i = 0; i < cbottoms; ++i)
    {
        header->handleArray[i] = newAddresses[i + 1];
    }
}
3375 
getSerializingSizes(const DeviceInterface & vk,const VkDevice device,const VkQueue queue,const uint32_t queueFamilyIndex)3376 std::vector<VkDeviceSize> TopLevelAccelerationStructureKHR::getSerializingSizes(const DeviceInterface &vk,
3377                                                                                 const VkDevice device,
3378                                                                                 const VkQueue queue,
3379                                                                                 const uint32_t queueFamilyIndex)
3380 {
3381     const uint32_t queryCount(uint32_t(m_bottomLevelInstances.size()) + 1);
3382     std::vector<VkAccelerationStructureKHR> handles(queryCount);
3383     std::vector<VkDeviceSize> sizes(queryCount);
3384 
3385     handles[0] = m_accelerationStructureKHR.get();
3386 
3387     for (uint32_t h = 1; h < queryCount; ++h)
3388         handles[h] = *m_bottomLevelInstances[h - 1].get()->getPtr();
3389 
3390     if (VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR == m_buildType)
3391         queryAccelerationStructureSize(vk, device, nullptr, handles, m_buildType, VK_NULL_HANDLE,
3392                                        VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, 0u, sizes);
3393     else
3394     {
3395         const Move<VkCommandPool> cmdPool = createCommandPool(vk, device, 0, queueFamilyIndex);
3396         const Move<VkCommandBuffer> cmdBuffer =
3397             allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
3398         const Move<VkQueryPool> queryPool =
3399             makeQueryPool(vk, device, VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, queryCount);
3400 
3401         beginCommandBuffer(vk, *cmdBuffer);
3402         queryAccelerationStructureSize(vk, device, *cmdBuffer, handles, m_buildType, *queryPool,
3403                                        VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, 0u, sizes);
3404         endCommandBuffer(vk, *cmdBuffer);
3405         submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
3406 
3407         VK_CHECK(vk.getQueryPoolResults(device, *queryPool, 0u, queryCount, queryCount * sizeof(VkDeviceSize),
3408                                         sizes.data(), sizeof(VkDeviceSize),
3409                                         VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT));
3410     }
3411 
3412     return sizes;
3413 }
3414 
getSerializingAddresses(const DeviceInterface & vk,const VkDevice device) const3415 std::vector<uint64_t> TopLevelAccelerationStructureKHR::getSerializingAddresses(const DeviceInterface &vk,
3416                                                                                 const VkDevice device) const
3417 {
3418     std::vector<uint64_t> result(m_bottomLevelInstances.size() + 1);
3419 
3420     VkAccelerationStructureDeviceAddressInfoKHR asDeviceAddressInfo = {
3421         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_DEVICE_ADDRESS_INFO_KHR, // VkStructureType sType;
3422         nullptr,                                                          // const void* pNext;
3423         VK_NULL_HANDLE, // VkAccelerationStructureKHR accelerationStructure;
3424     };
3425 
3426     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3427     {
3428         asDeviceAddressInfo.accelerationStructure = m_accelerationStructureKHR.get();
3429         result[0] = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
3430     }
3431     else
3432     {
3433         result[0] = uint64_t(getPtr()->getInternal());
3434     }
3435 
3436     for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
3437     {
3438         const BottomLevelAccelerationStructure &bottomLevelAccelerationStructure = *m_bottomLevelInstances[instanceNdx];
3439         const VkAccelerationStructureKHR accelerationStructureKHR = *bottomLevelAccelerationStructure.getPtr();
3440 
3441         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3442         {
3443             asDeviceAddressInfo.accelerationStructure = accelerationStructureKHR;
3444             result[instanceNdx + 1] = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
3445         }
3446         else
3447         {
3448             result[instanceNdx + 1] = uint64_t(accelerationStructureKHR.getInternal());
3449         }
3450     }
3451 
3452     return result;
3453 }
3454 
getPtr(void) const3455 const VkAccelerationStructureKHR *TopLevelAccelerationStructureKHR::getPtr(void) const
3456 {
3457     return &m_accelerationStructureKHR.get();
3458 }
3459 
// Fills in the instances geometry used to build this top-level acceleration
// structure. The instance data pointer is selected according to the build type
// (device address vs. host pointer) and, when m_useArrayOfPointers is set, an
// auxiliary buffer of per-instance addresses is populated first.
void TopLevelAccelerationStructureKHR::prepareInstances(
    const DeviceInterface &vk, const VkDevice device,
    VkAccelerationStructureGeometryKHR &accelerationStructureGeometryKHR, std::vector<uint32_t> &maxPrimitiveCounts)
{
    // A TLAS has exactly one instances geometry; its primitive count is the
    // number of bottom-level instances referenced.
    maxPrimitiveCounts.resize(1);
    maxPrimitiveCounts[0] = static_cast<uint32_t>(m_bottomLevelInstances.size());

    VkDeviceOrHostAddressConstKHR instancesData;
    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        if (m_instanceBuffer.get() != VK_NULL_HANDLE)
        {
            if (m_useArrayOfPointers)
            {
                // arrayOfPointers mode: write the device address of each
                // VkAccelerationStructureInstanceKHR record (packed inside
                // m_instanceBuffer) into m_instanceAddressBuffer.
                uint8_t *bufferStart = static_cast<uint8_t *>(m_instanceAddressBuffer->getAllocation().getHostPtr());
                VkDeviceSize bufferOffset = 0;
                VkDeviceOrHostAddressConstKHR firstInstance =
                    makeDeviceOrHostAddressConstKHR(vk, device, m_instanceBuffer->get(), 0);
                for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
                {
                    VkDeviceOrHostAddressConstKHR currentInstance;
                    currentInstance.deviceAddress =
                        firstInstance.deviceAddress + instanceNdx * sizeof(VkAccelerationStructureInstanceKHR);

                    // Only the address member of the union is stored in the array.
                    deMemcpy(&bufferStart[bufferOffset], &currentInstance,
                             sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress));
                    bufferOffset += sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress);
                }
                // Make the CPU-written addresses visible to the device before the build.
                flushMappedMemoryRange(vk, device, m_instanceAddressBuffer->getAllocation().getMemory(),
                                       m_instanceAddressBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);

                instancesData = makeDeviceOrHostAddressConstKHR(vk, device, m_instanceAddressBuffer->get(), 0);
            }
            else
                instancesData = makeDeviceOrHostAddressConstKHR(vk, device, m_instanceBuffer->get(), 0);
        }
        else
            instancesData = makeDeviceOrHostAddressConstKHR(nullptr);
    }
    else
    {
        // Host build: same layout, but host pointers are stored instead of
        // device addresses.
        if (m_instanceBuffer.get() != VK_NULL_HANDLE)
        {
            if (m_useArrayOfPointers)
            {
                uint8_t *bufferStart = static_cast<uint8_t *>(m_instanceAddressBuffer->getAllocation().getHostPtr());
                VkDeviceSize bufferOffset = 0;
                for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
                {
                    VkDeviceOrHostAddressConstKHR currentInstance;
                    currentInstance.hostAddress = (uint8_t *)m_instanceBuffer->getAllocation().getHostPtr() +
                                                  instanceNdx * sizeof(VkAccelerationStructureInstanceKHR);

                    deMemcpy(&bufferStart[bufferOffset], &currentInstance,
                             sizeof(VkDeviceOrHostAddressConstKHR::hostAddress));
                    bufferOffset += sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
                }
                instancesData = makeDeviceOrHostAddressConstKHR(m_instanceAddressBuffer->getAllocation().getHostPtr());
            }
            else
                instancesData = makeDeviceOrHostAddressConstKHR(m_instanceBuffer->getAllocation().getHostPtr());
        }
        else
            instancesData = makeDeviceOrHostAddressConstKHR(nullptr);
    }

    VkAccelerationStructureGeometryInstancesDataKHR accelerationStructureGeometryInstancesDataKHR = {
        VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_INSTANCES_DATA_KHR, //  VkStructureType sType;
        nullptr,                                                              //  const void* pNext;
        (VkBool32)(m_useArrayOfPointers ? true : false),                      //  VkBool32 arrayOfPointers;
        instancesData                                                         //  VkDeviceOrHostAddressConstKHR data;
    };

    accelerationStructureGeometryKHR = {
        VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR, //  VkStructureType sType;
        nullptr,                                               //  const void* pNext;
        VK_GEOMETRY_TYPE_INSTANCES_KHR,                        //  VkGeometryTypeKHR geometryType;
        makeVkAccelerationStructureInstancesDataKHR(
            accelerationStructureGeometryInstancesDataKHR), //  VkAccelerationStructureGeometryDataKHR geometry;
        (VkGeometryFlagsKHR)0u                              //  VkGeometryFlagsKHR flags;
    };
}
3542 
// Number of device-memory allocations a TLAS may need; forwards to the KHR
// implementation so callers can budget allocation-count test limits.
uint32_t TopLevelAccelerationStructure::getRequiredAllocationCount(void)
{
    return TopLevelAccelerationStructureKHR::getRequiredAllocationCount();
}
3547 
// Factory for the abstract TLAS interface; currently always returns the
// KHR-extension-backed implementation.
de::MovePtr<TopLevelAccelerationStructure> makeTopLevelAccelerationStructure()
{
    return de::MovePtr<TopLevelAccelerationStructure>(new TopLevelAccelerationStructureKHR);
}
3552 
queryAccelerationStructureSizeKHR(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,const std::vector<VkAccelerationStructureKHR> & accelerationStructureHandles,VkAccelerationStructureBuildTypeKHR buildType,const VkQueryPool queryPool,VkQueryType queryType,uint32_t firstQuery,std::vector<VkDeviceSize> & results)3553 bool queryAccelerationStructureSizeKHR(const DeviceInterface &vk, const VkDevice device,
3554                                        const VkCommandBuffer cmdBuffer,
3555                                        const std::vector<VkAccelerationStructureKHR> &accelerationStructureHandles,
3556                                        VkAccelerationStructureBuildTypeKHR buildType, const VkQueryPool queryPool,
3557                                        VkQueryType queryType, uint32_t firstQuery, std::vector<VkDeviceSize> &results)
3558 {
3559     DE_ASSERT(queryType == VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR ||
3560               queryType == VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR);
3561 
3562     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3563     {
3564         // queryPool must be large enough to contain at least (firstQuery + accelerationStructureHandles.size()) queries
3565         vk.cmdResetQueryPool(cmdBuffer, queryPool, firstQuery, uint32_t(accelerationStructureHandles.size()));
3566         vk.cmdWriteAccelerationStructuresPropertiesKHR(cmdBuffer, uint32_t(accelerationStructureHandles.size()),
3567                                                        accelerationStructureHandles.data(), queryType, queryPool,
3568                                                        firstQuery);
3569         // results cannot be retrieved to CPU at the moment - you need to do it using getQueryPoolResults after cmdBuffer is executed. Meanwhile function returns a vector of 0s.
3570         results.resize(accelerationStructureHandles.size(), 0u);
3571         return false;
3572     }
3573     // buildType != VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR
3574     results.resize(accelerationStructureHandles.size(), 0u);
3575     vk.writeAccelerationStructuresPropertiesKHR(
3576         device, uint32_t(accelerationStructureHandles.size()), accelerationStructureHandles.data(), queryType,
3577         sizeof(VkDeviceSize) * accelerationStructureHandles.size(), results.data(), sizeof(VkDeviceSize));
3578     // results will contain proper values
3579     return true;
3580 }
3581 
// Public entry point; forwards to the KHR implementation. See
// queryAccelerationStructureSizeKHR for the meaning of the return value.
bool queryAccelerationStructureSize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                                    const std::vector<VkAccelerationStructureKHR> &accelerationStructureHandles,
                                    VkAccelerationStructureBuildTypeKHR buildType, const VkQueryPool queryPool,
                                    VkQueryType queryType, uint32_t firstQuery, std::vector<VkDeviceSize> &results)
{
    return queryAccelerationStructureSizeKHR(vk, device, cmdBuffer, accelerationStructureHandles, buildType, queryPool,
                                             queryType, firstQuery, results);
}
3590 
// Starts with an empty pipeline description: no shaders, groups or libraries,
// recursion depth 1, zero payload/attribute interface sizes, and synchronous
// (non-deferred) compilation.
RayTracingPipeline::RayTracingPipeline()
    : m_shadersModules()
    , m_pipelineLibraries()
    , m_shaderCreateInfos()
    , m_shadersGroupCreateInfos()
    , m_pipelineCreateFlags(0U)
    , m_pipelineCreateFlags2(0U)
    , m_maxRecursionDepth(1U)
    , m_maxPayloadSize(0U)
    , m_maxAttributeSize(0U)
    , m_deferredOperation(false)
    , m_workerThreadCount(0)
{
}
3605 
// Nothing to do explicitly; member containers and their shared handles clean up
// through their own destructors.
RayTracingPipeline::~RayTracingPipeline()
{
}
3609 
// Assigns STAGE (a shader-stage index) to SHADER only if that group slot is
// still unused; assigning the same slot twice is a test-internal error.
#define CHECKED_ASSIGN_SHADER(SHADER, STAGE) \
    if (SHADER == VK_SHADER_UNUSED_KHR)      \
        SHADER = STAGE;                      \
    else                                     \
        TCU_THROW(InternalError, "Attempt to reassign shader")
3615 
// Convenience overload: wraps the Move<>'d module in a shared pointer (so the
// pipeline keeps it alive) and forwards to the SharedPtr overload.
void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, Move<VkShaderModule> shaderModule, uint32_t group,
                                   const VkSpecializationInfo *specializationInfo,
                                   const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
                                   const void *pipelineShaderStageCreateInfopNext)
{
    addShader(shaderStage, makeVkSharedPtr(shaderModule), group, specializationInfo, pipelineShaderStageCreateFlags,
              pipelineShaderStageCreateInfopNext);
}
3624 
// Registers the raw handle with the given group, then retains the shared
// module so it is not destroyed while this pipeline still references it.
void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, de::SharedPtr<Move<VkShaderModule>> shaderModule,
                                   uint32_t group, const VkSpecializationInfo *specializationInfoPtr,
                                   const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
                                   const void *pipelineShaderStageCreateInfopNext)
{
    addShader(shaderStage, **shaderModule, group, specializationInfoPtr, pipelineShaderStageCreateFlags,
              pipelineShaderStageCreateInfopNext);
    m_shadersModules.push_back(shaderModule);
}
3634 
addShader(VkShaderStageFlagBits shaderStage,VkShaderModule shaderModule,uint32_t group,const VkSpecializationInfo * specializationInfoPtr,const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,const void * pipelineShaderStageCreateInfopNext)3635 void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, VkShaderModule shaderModule, uint32_t group,
3636                                    const VkSpecializationInfo *specializationInfoPtr,
3637                                    const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
3638                                    const void *pipelineShaderStageCreateInfopNext)
3639 {
3640     if (group >= m_shadersGroupCreateInfos.size())
3641     {
3642         for (size_t groupNdx = m_shadersGroupCreateInfos.size(); groupNdx <= group; ++groupNdx)
3643         {
3644             VkRayTracingShaderGroupCreateInfoKHR shaderGroupCreateInfo = {
3645                 VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR, //  VkStructureType sType;
3646                 nullptr,                                                    //  const void* pNext;
3647                 VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR,              //  VkRayTracingShaderGroupTypeKHR type;
3648                 VK_SHADER_UNUSED_KHR,                                       //  uint32_t generalShader;
3649                 VK_SHADER_UNUSED_KHR,                                       //  uint32_t closestHitShader;
3650                 VK_SHADER_UNUSED_KHR,                                       //  uint32_t anyHitShader;
3651                 VK_SHADER_UNUSED_KHR,                                       //  uint32_t intersectionShader;
3652                 nullptr, //  const void* pShaderGroupCaptureReplayHandle;
3653             };
3654 
3655             m_shadersGroupCreateInfos.push_back(shaderGroupCreateInfo);
3656         }
3657     }
3658 
3659     const uint32_t shaderStageNdx                               = (uint32_t)m_shaderCreateInfos.size();
3660     VkRayTracingShaderGroupCreateInfoKHR &shaderGroupCreateInfo = m_shadersGroupCreateInfos[group];
3661 
3662     switch (shaderStage)
3663     {
3664     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
3665         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
3666         break;
3667     case VK_SHADER_STAGE_MISS_BIT_KHR:
3668         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
3669         break;
3670     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
3671         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
3672         break;
3673     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
3674         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.anyHitShader, shaderStageNdx);
3675         break;
3676     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
3677         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.closestHitShader, shaderStageNdx);
3678         break;
3679     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
3680         CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.intersectionShader, shaderStageNdx);
3681         break;
3682     default:
3683         TCU_THROW(InternalError, "Unacceptable stage");
3684     }
3685 
3686     switch (shaderStage)
3687     {
3688     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
3689     case VK_SHADER_STAGE_MISS_BIT_KHR:
3690     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
3691     {
3692         DE_ASSERT(shaderGroupCreateInfo.type == VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR);
3693         shaderGroupCreateInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
3694 
3695         break;
3696     }
3697 
3698     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
3699     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
3700     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
3701     {
3702         DE_ASSERT(shaderGroupCreateInfo.type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR);
3703         shaderGroupCreateInfo.type = (shaderGroupCreateInfo.intersectionShader == VK_SHADER_UNUSED_KHR) ?
3704                                          VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR :
3705                                          VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
3706 
3707         break;
3708     }
3709 
3710     default:
3711         TCU_THROW(InternalError, "Unacceptable stage");
3712     }
3713 
3714     {
3715         const VkPipelineShaderStageCreateInfo shaderCreateInfo = {
3716             VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, //  VkStructureType sType;
3717             pipelineShaderStageCreateInfopNext,                  //  const void* pNext;
3718             pipelineShaderStageCreateFlags,                      //  VkPipelineShaderStageCreateFlags flags;
3719             shaderStage,                                         //  VkShaderStageFlagBits stage;
3720             shaderModule,                                        //  VkShaderModule module;
3721             "main",                                              //  const char* pName;
3722             specializationInfoPtr,                               //  const VkSpecializationInfo* pSpecializationInfo;
3723         };
3724 
3725         m_shaderCreateInfos.push_back(shaderCreateInfo);
3726     }
3727 }
3728 
// Sets the capture/replay handle on an already-registered group entry
// (VkRayTracingShaderGroupCreateInfoKHR::pShaderGroupCaptureReplayHandle).
void RayTracingPipeline::setGroupCaptureReplayHandle(uint32_t group, const void *pShaderGroupCaptureReplayHandle)
{
    DE_ASSERT(static_cast<size_t>(group) < m_shadersGroupCreateInfos.size());
    m_shadersGroupCreateInfos[group].pShaderGroupCaptureReplayHandle = pShaderGroupCaptureReplayHandle;
}
3734 
// Registers a pipeline library to be linked into this pipeline at creation time.
void RayTracingPipeline::addLibrary(de::SharedPtr<de::MovePtr<RayTracingPipeline>> pipelineLibrary)
{
    m_pipelineLibraries.push_back(pipelineLibrary);
}
3739 
// Number of shader groups declared directly on this pipeline (libraries excluded;
// see getFullShaderGroupCount for the recursive total).
uint32_t RayTracingPipeline::getShaderGroupCount(void)
{
    return de::sizeU32(m_shadersGroupCreateInfos);
}
3744 
getFullShaderGroupCount(void)3745 uint32_t RayTracingPipeline::getFullShaderGroupCount(void)
3746 {
3747     uint32_t totalCount = getShaderGroupCount();
3748 
3749     for (const auto &lib : m_pipelineLibraries)
3750         totalCount += lib->get()->getFullShaderGroupCount();
3751 
3752     return totalCount;
3753 }
3754 
// Creates the VkPipeline from the accumulated stages, groups and libraries.
// Handles optional deferred compilation (m_deferredOperation, possibly on
// m_workerThreadCount worker threads), extended create flags, and the
// fail-on-compile-required mode.
Move<VkPipeline> RayTracingPipeline::createPipelineKHR(const DeviceInterface &vk, const VkDevice device,
                                                       const VkPipelineLayout pipelineLayout,
                                                       const std::vector<VkPipeline> &pipelineLibraries,
                                                       const VkPipelineCache pipelineCache)
{
    // Every group entry must have been initialized through addShader().
    for (size_t groupNdx = 0; groupNdx < m_shadersGroupCreateInfos.size(); ++groupNdx)
        DE_ASSERT(m_shadersGroupCreateInfos[groupNdx].sType ==
                  VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR);

    VkPipelineLibraryCreateInfoKHR librariesCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR, //  VkStructureType sType;
        nullptr,                                            //  const void* pNext;
        de::sizeU32(pipelineLibraries),                     //  uint32_t libraryCount;
        de::dataOrNull(pipelineLibraries)                   //  VkPipeline* pLibraries;
    };
    const VkRayTracingPipelineInterfaceCreateInfoKHR pipelineInterfaceCreateInfo = {
        VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_INTERFACE_CREATE_INFO_KHR, //  VkStructureType sType;
        nullptr,                                                          //  const void* pNext;
        m_maxPayloadSize,                                                 //  uint32_t maxPayloadSize;
        m_maxAttributeSize                                                //  uint32_t maxAttributeSize;
    };
    // Only chain the interface info when payload/attribute sizes were actually
    // configured, and the libraries info only when libraries are present.
    const bool addPipelineInterfaceCreateInfo = m_maxPayloadSize != 0 || m_maxAttributeSize != 0;
    const VkRayTracingPipelineInterfaceCreateInfoKHR *pipelineInterfaceCreateInfoPtr =
        addPipelineInterfaceCreateInfo ? &pipelineInterfaceCreateInfo : nullptr;
    const VkPipelineLibraryCreateInfoKHR *librariesCreateInfoPtr =
        (pipelineLibraries.empty() ? nullptr : &librariesCreateInfo);

    Move<VkDeferredOperationKHR> deferredOperation;
    if (m_deferredOperation)
        deferredOperation = createDeferredOperationKHR(vk, device);

    VkPipelineDynamicStateCreateInfo dynamicStateCreateInfo = {
        VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO, // VkStructureType sType;
        nullptr,                                              // const void* pNext;
        0,                                                    // VkPipelineDynamicStateCreateFlags flags;
        static_cast<uint32_t>(m_dynamicStates.size()),        // uint32_t dynamicStateCount;
        m_dynamicStates.data(),                               // const VkDynamicState* pDynamicStates;
    };

    VkRayTracingPipelineCreateInfoKHR pipelineCreateInfo{
        VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR, //  VkStructureType sType;
        nullptr,                                                //  const void* pNext;
        m_pipelineCreateFlags,                                  //  VkPipelineCreateFlags flags;
        de::sizeU32(m_shaderCreateInfos),                       //  uint32_t stageCount;
        de::dataOrNull(m_shaderCreateInfos),                    //  const VkPipelineShaderStageCreateInfo* pStages;
        de::sizeU32(m_shadersGroupCreateInfos),                 //  uint32_t groupCount;
        de::dataOrNull(m_shadersGroupCreateInfos),              //  const VkRayTracingShaderGroupCreateInfoKHR* pGroups;
        m_maxRecursionDepth,                                    //  uint32_t maxRecursionDepth;
        librariesCreateInfoPtr,                                 //  VkPipelineLibraryCreateInfoKHR* pLibraryInfo;
        pipelineInterfaceCreateInfoPtr, //  VkRayTracingPipelineInterfaceCreateInfoKHR* pLibraryInterface;
        &dynamicStateCreateInfo,        //  const VkPipelineDynamicStateCreateInfo* pDynamicState;
        pipelineLayout,                 //  VkPipelineLayout layout;
        VK_NULL_HANDLE,                 //  VkPipeline basePipelineHandle;
        0,                              //  int32_t basePipelineIndex;
    };
    // When extended create flags were set, they are passed through the flags2
    // extension structure and the legacy flags field is cleared.
    VkPipelineCreateFlags2CreateInfoKHR pipelineFlags2CreateInfo = initVulkanStructure();
    if (m_pipelineCreateFlags2)
    {
        pipelineFlags2CreateInfo.flags = m_pipelineCreateFlags2;
        pipelineCreateInfo.pNext       = &pipelineFlags2CreateInfo;
        pipelineCreateInfo.flags       = 0;
    }

    VkPipeline object = VK_NULL_HANDLE;
    VkResult result   = vk.createRayTracingPipelinesKHR(device, deferredOperation.get(), pipelineCache, 1u,
                                                        &pipelineCreateInfo, nullptr, &object);
    const bool allowCompileRequired =
        ((m_pipelineCreateFlags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT_EXT) != 0);

    // Drive a deferred compilation to completion before the pipeline is used.
    if (m_deferredOperation)
    {
        DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
                  result == VK_SUCCESS || (allowCompileRequired && result == VK_PIPELINE_COMPILE_REQUIRED));
        finishDeferredOperation(vk, device, deferredOperation.get(), m_workerThreadCount,
                                result == VK_OPERATION_NOT_DEFERRED_KHR);
    }

    // Surface VK_PIPELINE_COMPILE_REQUIRED as a dedicated exception so tests
    // can treat it as an expected outcome rather than a failure.
    if (allowCompileRequired && result == VK_PIPELINE_COMPILE_REQUIRED)
        throw CompileRequiredError("createRayTracingPipelinesKHR returned VK_PIPELINE_COMPILE_REQUIRED");

    Move<VkPipeline> pipeline(check<VkPipeline>(object), Deleter<VkPipeline>(vk, device, nullptr));
    return pipeline;
}
3838 
createPipeline(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout,const std::vector<de::SharedPtr<Move<VkPipeline>>> & pipelineLibraries)3839 Move<VkPipeline> RayTracingPipeline::createPipeline(
3840     const DeviceInterface &vk, const VkDevice device, const VkPipelineLayout pipelineLayout,
3841     const std::vector<de::SharedPtr<Move<VkPipeline>>> &pipelineLibraries)
3842 {
3843     std::vector<VkPipeline> rawPipelines;
3844     rawPipelines.reserve(pipelineLibraries.size());
3845     for (const auto &lib : pipelineLibraries)
3846         rawPipelines.push_back(lib.get()->get());
3847 
3848     return createPipelineKHR(vk, device, pipelineLayout, rawPipelines);
3849 }
3850 
// Thin forwarding overload for callers that already hold raw VkPipeline library
// handles; createPipelineKHR contains the actual creation logic.
Move<VkPipeline> RayTracingPipeline::createPipeline(const DeviceInterface &vk, const VkDevice device,
                                                    const VkPipelineLayout pipelineLayout,
                                                    const std::vector<VkPipeline> &pipelineLibraries,
                                                    const VkPipelineCache pipelineCache)
{
    return createPipelineKHR(vk, device, pipelineLayout, pipelineLibraries, pipelineCache);
}
3858 
// Recursively creates this pipeline together with all of its library pipelines.
// The returned vector's first element is the pipeline for this level; the
// remaining elements are every pipeline created below it, kept so the caller
// holds them all alive for as long as the linked pipeline is in use.
std::vector<de::SharedPtr<Move<VkPipeline>>> RayTracingPipeline::createPipelineWithLibraries(
    const DeviceInterface &vk, const VkDevice device, const VkPipelineLayout pipelineLayout)
{
    // Every group must have been initialized through addShader() before linking.
    for (size_t groupNdx = 0; groupNdx < m_shadersGroupCreateInfos.size(); ++groupNdx)
        DE_ASSERT(m_shadersGroupCreateInfos[groupNdx].sType ==
                  VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR);

    DE_ASSERT(m_shaderCreateInfos.size() > 0);
    DE_ASSERT(m_shadersGroupCreateInfos.size() > 0);

    // Build each child library first; this level links against the first
    // (top) pipeline of each child, while allLibraries accumulates everything
    // created further down.
    std::vector<de::SharedPtr<Move<VkPipeline>>> result, allLibraries, firstLibraries;
    for (auto it = begin(m_pipelineLibraries), eit = end(m_pipelineLibraries); it != eit; ++it)
    {
        auto childLibraries = (*it)->get()->createPipelineWithLibraries(vk, device, pipelineLayout);
        DE_ASSERT(childLibraries.size() > 0);
        firstLibraries.push_back(childLibraries[0]);
        std::copy(begin(childLibraries), end(childLibraries), std::back_inserter(allLibraries));
    }
    result.push_back(makeVkSharedPtr(createPipeline(vk, device, pipelineLayout, firstLibraries)));
    std::copy(begin(allLibraries), end(allLibraries), std::back_inserter(result));
    return result;
}
3881 
getShaderGroupHandles(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t shaderGroupHandleSize,const uint32_t firstGroup,const uint32_t groupCount) const3882 std::vector<uint8_t> RayTracingPipeline::getShaderGroupHandles(const DeviceInterface &vk, const VkDevice device,
3883                                                                const VkPipeline pipeline,
3884                                                                const uint32_t shaderGroupHandleSize,
3885                                                                const uint32_t firstGroup,
3886                                                                const uint32_t groupCount) const
3887 {
3888     const auto handleArraySizeBytes = groupCount * shaderGroupHandleSize;
3889     std::vector<uint8_t> shaderHandles(handleArraySizeBytes);
3890 
3891     VK_CHECK(getRayTracingShaderGroupHandles(vk, device, pipeline, firstGroup, groupCount,
3892                                              static_cast<uintptr_t>(shaderHandles.size()),
3893                                              de::dataOrNull(shaderHandles)));
3894 
3895     return shaderHandles;
3896 }
3897 
getShaderGroupReplayHandles(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t shaderGroupHandleReplaySize,const uint32_t firstGroup,const uint32_t groupCount) const3898 std::vector<uint8_t> RayTracingPipeline::getShaderGroupReplayHandles(const DeviceInterface &vk, const VkDevice device,
3899                                                                      const VkPipeline pipeline,
3900                                                                      const uint32_t shaderGroupHandleReplaySize,
3901                                                                      const uint32_t firstGroup,
3902                                                                      const uint32_t groupCount) const
3903 {
3904     const auto handleArraySizeBytes = groupCount * shaderGroupHandleReplaySize;
3905     std::vector<uint8_t> shaderHandles(handleArraySizeBytes);
3906 
3907     VK_CHECK(getRayTracingCaptureReplayShaderGroupHandles(vk, device, pipeline, firstGroup, groupCount,
3908                                                           static_cast<uintptr_t>(shaderHandles.size()),
3909                                                           de::dataOrNull(shaderHandles)));
3910 
3911     return shaderHandles;
3912 }
3913 
// Convenience overload: queries the group handles for
// [firstGroup, firstGroup + groupCount) from the pipeline, then delegates to
// the handle-based overload that actually creates and fills the SBT buffer.
de::MovePtr<BufferWithMemory> RayTracingPipeline::createShaderBindingTable(
    const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline, Allocator &allocator,
    const uint32_t &shaderGroupHandleSize, const uint32_t shaderGroupBaseAlignment, const uint32_t &firstGroup,
    const uint32_t &groupCount, const VkBufferCreateFlags &additionalBufferCreateFlags,
    const VkBufferUsageFlags &additionalBufferUsageFlags, const MemoryRequirement &additionalMemoryRequirement,
    const VkDeviceAddress &opaqueCaptureAddress, const uint32_t shaderBindingTableOffset,
    const uint32_t shaderRecordSize, const void **shaderGroupDataPtrPerGroup, const bool autoAlignRecords)
{
    const auto shaderHandles =
        getShaderGroupHandles(vk, device, pipeline, shaderGroupHandleSize, firstGroup, groupCount);
    return createShaderBindingTable(vk, device, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment,
                                    shaderHandles, additionalBufferCreateFlags, additionalBufferUsageFlags,
                                    additionalMemoryRequirement, opaqueCaptureAddress, shaderBindingTableOffset,
                                    shaderRecordSize, shaderGroupDataPtrPerGroup, autoAlignRecords);
}
3929 
createShaderBindingTable(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const uint32_t shaderGroupHandleSize,const uint32_t shaderGroupBaseAlignment,const std::vector<uint8_t> & shaderHandles,const VkBufferCreateFlags additionalBufferCreateFlags,const VkBufferUsageFlags additionalBufferUsageFlags,const MemoryRequirement & additionalMemoryRequirement,const VkDeviceAddress opaqueCaptureAddress,const uint32_t shaderBindingTableOffset,const uint32_t shaderRecordSize,const void ** shaderGroupDataPtrPerGroup,const bool autoAlignRecords)3930 de::MovePtr<BufferWithMemory> RayTracingPipeline::createShaderBindingTable(
3931     const DeviceInterface &vk, const VkDevice device, Allocator &allocator, const uint32_t shaderGroupHandleSize,
3932     const uint32_t shaderGroupBaseAlignment, const std::vector<uint8_t> &shaderHandles,
3933     const VkBufferCreateFlags additionalBufferCreateFlags, const VkBufferUsageFlags additionalBufferUsageFlags,
3934     const MemoryRequirement &additionalMemoryRequirement, const VkDeviceAddress opaqueCaptureAddress,
3935     const uint32_t shaderBindingTableOffset, const uint32_t shaderRecordSize, const void **shaderGroupDataPtrPerGroup,
3936     const bool autoAlignRecords)
3937 {
3938     DE_ASSERT(shaderGroupBaseAlignment != 0u);
3939     DE_ASSERT((shaderBindingTableOffset % shaderGroupBaseAlignment) == 0);
3940     DE_UNREF(shaderGroupBaseAlignment);
3941 
3942     const auto groupCount = de::sizeU32(shaderHandles) / shaderGroupHandleSize;
3943     const auto totalEntrySize =
3944         (autoAlignRecords ? (deAlign32(shaderGroupHandleSize + shaderRecordSize, shaderGroupHandleSize)) :
3945                             (shaderGroupHandleSize + shaderRecordSize));
3946     const uint32_t sbtSize            = shaderBindingTableOffset + groupCount * totalEntrySize;
3947     const VkBufferUsageFlags sbtFlags = VK_BUFFER_USAGE_TRANSFER_DST_BIT |
3948                                         VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR |
3949                                         VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | additionalBufferUsageFlags;
3950     VkBufferCreateInfo sbtCreateInfo = makeBufferCreateInfo(sbtSize, sbtFlags);
3951     sbtCreateInfo.flags |= additionalBufferCreateFlags;
3952     VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2           = vk::initVulkanStructure();
3953     VkBufferOpaqueCaptureAddressCreateInfo sbtCaptureAddressInfo = {
3954         VK_STRUCTURE_TYPE_BUFFER_OPAQUE_CAPTURE_ADDRESS_CREATE_INFO, // VkStructureType sType;
3955         nullptr,                                                     // const void* pNext;
3956         uint64_t(opaqueCaptureAddress)                               // uint64_t opaqueCaptureAddress;
3957     };
3958 
3959     // when maintenance5 is tested then m_pipelineCreateFlags2 is non-zero
3960     if (m_pipelineCreateFlags2)
3961     {
3962         bufferUsageFlags2.usage = (VkBufferUsageFlags2KHR)sbtFlags;
3963         sbtCreateInfo.pNext     = &bufferUsageFlags2;
3964         sbtCreateInfo.usage     = 0;
3965     }
3966 
3967     if (opaqueCaptureAddress != 0u)
3968     {
3969         sbtCreateInfo.pNext = &sbtCaptureAddressInfo;
3970         sbtCreateInfo.flags |= VK_BUFFER_CREATE_DEVICE_ADDRESS_CAPTURE_REPLAY_BIT;
3971     }
3972     const MemoryRequirement sbtMemRequirements = MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
3973                                                  MemoryRequirement::DeviceAddress | additionalMemoryRequirement;
3974     de::MovePtr<BufferWithMemory> sbtBuffer =
3975         de::MovePtr<BufferWithMemory>(new BufferWithMemory(vk, device, allocator, sbtCreateInfo, sbtMemRequirements));
3976     vk::Allocation &sbtAlloc = sbtBuffer->getAllocation();
3977 
3978     // Copy handles to table, leaving space for ShaderRecordKHR after each handle.
3979     uint8_t *shaderBegin = (uint8_t *)sbtAlloc.getHostPtr() + shaderBindingTableOffset;
3980     for (uint32_t idx = 0; idx < groupCount; ++idx)
3981     {
3982         const uint8_t *shaderSrcPos = shaderHandles.data() + idx * shaderGroupHandleSize;
3983         uint8_t *shaderDstPos       = shaderBegin + idx * totalEntrySize;
3984         deMemcpy(shaderDstPos, shaderSrcPos, shaderGroupHandleSize);
3985 
3986         if (shaderGroupDataPtrPerGroup != nullptr && shaderGroupDataPtrPerGroup[idx] != nullptr)
3987         {
3988             DE_ASSERT(sbtSize >= static_cast<uint32_t>(shaderDstPos - shaderBegin) + shaderGroupHandleSize);
3989 
3990             deMemcpy(shaderDstPos + shaderGroupHandleSize, shaderGroupDataPtrPerGroup[idx], shaderRecordSize);
3991         }
3992     }
3993 
3994     flushMappedMemoryRange(vk, device, sbtAlloc.getMemory(), sbtAlloc.getOffset(), VK_WHOLE_SIZE);
3995 
3996     return sbtBuffer;
3997 }
3998 
// Sets the legacy VkPipelineCreateFlags stored for later pipeline creation.
void RayTracingPipeline::setCreateFlags(const VkPipelineCreateFlags &pipelineCreateFlags)
{
    m_pipelineCreateFlags = pipelineCreateFlags;
}
4003 
// Sets the 64-bit VkPipelineCreateFlags2KHR (maintenance5). A non-zero value
// also switches createShaderBindingTable() to the VkBufferUsageFlags2 path.
void RayTracingPipeline::setCreateFlags2(const VkPipelineCreateFlags2KHR &pipelineCreateFlags2)
{
    m_pipelineCreateFlags2 = pipelineCreateFlags2;
}
4008 
// Sets the ray recursion depth requested for the pipeline.
void RayTracingPipeline::setMaxRecursionDepth(const uint32_t &maxRecursionDepth)
{
    m_maxRecursionDepth = maxRecursionDepth;
}
4013 
// Sets the maximum ray payload size recorded for pipeline creation.
void RayTracingPipeline::setMaxPayloadSize(const uint32_t &maxPayloadSize)
{
    m_maxPayloadSize = maxPayloadSize;
}
4018 
// Sets the maximum hit attribute size recorded for pipeline creation.
void RayTracingPipeline::setMaxAttributeSize(const uint32_t &maxAttributeSize)
{
    m_maxAttributeSize = maxAttributeSize;
}
4023 
// Enables or disables building the pipeline through a VkDeferredOperationKHR.
// NOTE(review): workerThreadCount presumably only takes effect when
// deferredOperation is true — confirm at the pipeline-build site.
void RayTracingPipeline::setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount)
{
    m_deferredOperation = deferredOperation;
    m_workerThreadCount = workerThreadCount;
}
4029 
// Appends a dynamic state to the list used when creating the pipeline.
void RayTracingPipeline::addDynamicState(const VkDynamicState &dynamicState)
{
    m_dynamicStates.push_back(dynamicState);
}
4034 
// KHR implementation of the RayTracingProperties interface. Caches the
// VK_KHR_acceleration_structure and VK_KHR_ray_tracing_pipeline property
// structures (plus maxMemoryAllocationCount from the core limits) at
// construction time; each getter simply returns the cached field.
class RayTracingPropertiesKHR : public RayTracingProperties
{
public:
    RayTracingPropertiesKHR() = delete;
    RayTracingPropertiesKHR(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice);
    virtual ~RayTracingPropertiesKHR();

    // -- VkPhysicalDeviceRayTracingPipelinePropertiesKHR accessors --
    uint32_t getShaderGroupHandleSize(void) override
    {
        return m_rayTracingPipelineProperties.shaderGroupHandleSize;
    }
    uint32_t getShaderGroupHandleAlignment(void) override
    {
        return m_rayTracingPipelineProperties.shaderGroupHandleAlignment;
    }
    uint32_t getShaderGroupHandleCaptureReplaySize(void) override
    {
        return m_rayTracingPipelineProperties.shaderGroupHandleCaptureReplaySize;
    }
    uint32_t getMaxRecursionDepth(void) override
    {
        return m_rayTracingPipelineProperties.maxRayRecursionDepth;
    }
    uint32_t getMaxShaderGroupStride(void) override
    {
        return m_rayTracingPipelineProperties.maxShaderGroupStride;
    }
    uint32_t getShaderGroupBaseAlignment(void) override
    {
        return m_rayTracingPipelineProperties.shaderGroupBaseAlignment;
    }
    // -- VkPhysicalDeviceAccelerationStructurePropertiesKHR accessors --
    uint64_t getMaxGeometryCount(void) override
    {
        return m_accelerationStructureProperties.maxGeometryCount;
    }
    uint64_t getMaxInstanceCount(void) override
    {
        return m_accelerationStructureProperties.maxInstanceCount;
    }
    uint64_t getMaxPrimitiveCount(void) override
    {
        return m_accelerationStructureProperties.maxPrimitiveCount;
    }
    uint32_t getMaxDescriptorSetAccelerationStructures(void) override
    {
        return m_accelerationStructureProperties.maxDescriptorSetAccelerationStructures;
    }
    uint32_t getMaxRayDispatchInvocationCount(void) override
    {
        return m_rayTracingPipelineProperties.maxRayDispatchInvocationCount;
    }
    uint32_t getMaxRayHitAttributeSize(void) override
    {
        return m_rayTracingPipelineProperties.maxRayHitAttributeSize;
    }
    // -- core device limit --
    uint32_t getMaxMemoryAllocationCount(void) override
    {
        return m_maxMemoryAllocationCount;
    }

protected:
    VkPhysicalDeviceAccelerationStructurePropertiesKHR m_accelerationStructureProperties;
    VkPhysicalDeviceRayTracingPipelinePropertiesKHR m_rayTracingPipelineProperties;
    uint32_t m_maxMemoryAllocationCount;
};
4100 
~RayTracingPropertiesKHR()4101 RayTracingPropertiesKHR::~RayTracingPropertiesKHR()
4102 {
4103 }
4104 
// Queries and caches the extension property structures and the core
// maxMemoryAllocationCount limit for the given physical device.
RayTracingPropertiesKHR::RayTracingPropertiesKHR(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
    : RayTracingProperties(vki, physicalDevice)
{
    // NOTE(review): getPhysicalDeviceExtensionProperties() appears to deduce the
    // queried structure from the assignment target's type (conversion-operator
    // idiom from vkQueryUtil) — keep these as direct member assignments.
    m_accelerationStructureProperties = getPhysicalDeviceExtensionProperties(vki, physicalDevice);
    m_rayTracingPipelineProperties    = getPhysicalDeviceExtensionProperties(vki, physicalDevice);
    m_maxMemoryAllocationCount = getPhysicalDeviceProperties(vki, physicalDevice).limits.maxMemoryAllocationCount;
}
4112 
makeRayTracingProperties(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)4113 de::MovePtr<RayTracingProperties> makeRayTracingProperties(const InstanceInterface &vki,
4114                                                            const VkPhysicalDevice physicalDevice)
4115 {
4116     return de::MovePtr<RayTracingProperties>(new RayTracingPropertiesKHR(vki, physicalDevice));
4117 }
4118 
cmdTraceRaysKHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,uint32_t width,uint32_t height,uint32_t depth)4119 static inline void cmdTraceRaysKHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4120                                    const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4121                                    const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4122                                    const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4123                                    const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4124                                    uint32_t width, uint32_t height, uint32_t depth)
4125 {
4126     return vk.cmdTraceRaysKHR(commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4127                               hitShaderBindingTableRegion, callableShaderBindingTableRegion, width, height, depth);
4128 }
4129 
cmdTraceRays(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,uint32_t width,uint32_t height,uint32_t depth)4130 void cmdTraceRays(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4131                   const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4132                   const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4133                   const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4134                   const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion, uint32_t width,
4135                   uint32_t height, uint32_t depth)
4136 {
4137     DE_ASSERT(raygenShaderBindingTableRegion != nullptr);
4138     DE_ASSERT(missShaderBindingTableRegion != nullptr);
4139     DE_ASSERT(hitShaderBindingTableRegion != nullptr);
4140     DE_ASSERT(callableShaderBindingTableRegion != nullptr);
4141 
4142     return cmdTraceRaysKHR(vk, commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4143                            hitShaderBindingTableRegion, callableShaderBindingTableRegion, width, height, depth);
4144 }
4145 
cmdTraceRaysIndirectKHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,VkDeviceAddress indirectDeviceAddress)4146 static inline void cmdTraceRaysIndirectKHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4147                                            const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4148                                            const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4149                                            const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4150                                            const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4151                                            VkDeviceAddress indirectDeviceAddress)
4152 {
4153     DE_ASSERT(raygenShaderBindingTableRegion != nullptr);
4154     DE_ASSERT(missShaderBindingTableRegion != nullptr);
4155     DE_ASSERT(hitShaderBindingTableRegion != nullptr);
4156     DE_ASSERT(callableShaderBindingTableRegion != nullptr);
4157     DE_ASSERT(indirectDeviceAddress != 0);
4158 
4159     return vk.cmdTraceRaysIndirectKHR(commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4160                                       hitShaderBindingTableRegion, callableShaderBindingTableRegion,
4161                                       indirectDeviceAddress);
4162 }
4163 
cmdTraceRaysIndirect(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,VkDeviceAddress indirectDeviceAddress)4164 void cmdTraceRaysIndirect(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4165                           const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4166                           const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4167                           const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4168                           const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4169                           VkDeviceAddress indirectDeviceAddress)
4170 {
4171     return cmdTraceRaysIndirectKHR(vk, commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4172                                    hitShaderBindingTableRegion, callableShaderBindingTableRegion,
4173                                    indirectDeviceAddress);
4174 }
4175 
cmdTraceRaysIndirect2KHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,VkDeviceAddress indirectDeviceAddress)4176 static inline void cmdTraceRaysIndirect2KHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4177                                             VkDeviceAddress indirectDeviceAddress)
4178 {
4179     DE_ASSERT(indirectDeviceAddress != 0);
4180 
4181     return vk.cmdTraceRaysIndirect2KHR(commandBuffer, indirectDeviceAddress);
4182 }
4183 
cmdTraceRaysIndirect2(const DeviceInterface & vk,VkCommandBuffer commandBuffer,VkDeviceAddress indirectDeviceAddress)4184 void cmdTraceRaysIndirect2(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4185                            VkDeviceAddress indirectDeviceAddress)
4186 {
4187     return cmdTraceRaysIndirect2KHR(vk, commandBuffer, indirectDeviceAddress);
4188 }
4189 
// Sentinel for "no intersection" in ray query test results; uses the SPIR-V
// RayQueryCommittedIntersectionType Max enumerant so it cannot collide with a
// real committed intersection type value.
constexpr uint32_t NO_INT_VALUE = spv::RayQueryCommittedIntersectionTypeMax;
4191 
generateRayQueryShaders(SourceCollections & programCollection,RayQueryTestParams params,std::string rayQueryPart,float max_t)4192 void generateRayQueryShaders(SourceCollections &programCollection, RayQueryTestParams params, std::string rayQueryPart,
4193                              float max_t)
4194 {
4195     std::stringstream genericMiss;
4196     genericMiss << "#version 460\n"
4197                    "#extension GL_EXT_ray_tracing : require\n"
4198                    "#extension GL_EXT_ray_query : require\n"
4199                    "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4200                    "void main()\n"
4201                    "{\n"
4202                    "  payload.x = 2000;\n"
4203                    "  payload.y = 2000;\n"
4204                    "  payload.z = 2000;\n"
4205                    "  payload.w = 2000;\n"
4206                    "}\n";
4207 
4208     std::stringstream genericIsect;
4209     genericIsect << "#version 460\n"
4210                     "#extension GL_EXT_ray_tracing : require\n"
4211                     "hitAttributeEXT uvec4 hitValue;\n"
4212                     "void main()\n"
4213                     "{\n"
4214                     "  reportIntersectionEXT(0.5f, 0);\n"
4215                     "}\n";
4216 
4217     std::stringstream rtChit;
4218     rtChit << "#version 460    \n"
4219               "#extension GL_EXT_ray_tracing : require\n"
4220               "#extension GL_EXT_ray_query : require\n"
4221               "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4222               "void main()\n"
4223               "{\n"
4224               "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y * "
4225               "gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4226               "  payload.x = index;\n"
4227               "  payload.y = gl_HitTEXT;\n"
4228               "  payload.z = 1000;\n"
4229               "  payload.w = 1000;\n"
4230               "}\n";
4231 
4232     std::stringstream genericChit;
4233     genericChit << "#version 460    \n"
4234                    "#extension GL_EXT_ray_tracing : require\n"
4235                    "#extension GL_EXT_ray_query : require\n"
4236                    "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4237                    "void main()\n"
4238                    "{\n"
4239                    "  payload.x = 1000;\n"
4240                    "  payload.y = 1000;\n"
4241                    "  payload.z = 1000;\n"
4242                    "  payload.w = 1000;\n"
4243                    "}\n";
4244 
4245     std::stringstream genericRayTracingSetResultsShader;
4246     genericRayTracingSetResultsShader << "#version 460    \n"
4247                                          "#extension GL_EXT_ray_tracing : require\n"
4248                                          "#extension GL_EXT_ray_query : require\n"
4249                                          "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4250                                          "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4251                                          "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4252                                          "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4253                                       << params.shaderFunctions
4254                                       << "void main()\n"
4255                                          "{\n"
4256                                          "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) "
4257                                          "+ (gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4258                                       << rayQueryPart
4259                                       << "  payload.x = x;\n"
4260                                          "  payload.y = y;\n"
4261                                          "  payload.z = z;\n"
4262                                          "  payload.w = w;\n"
4263                                          "}\n";
4264 
4265     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_5, 0u, true);
4266 
4267     switch (params.pipelineType)
4268     {
4269     case RayQueryShaderSourcePipeline::COMPUTE:
4270     {
4271         std::ostringstream compute;
4272         compute << "#version 460\n"
4273                    "#extension GL_EXT_ray_tracing : enable\n"
4274                    "#extension GL_EXT_ray_query : require\n"
4275                    "\n"
4276                    "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4277                    "struct ResultType { float x; float y; float z; float w; };\n"
4278                    "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4279                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4280                    "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4281                    "layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in;\n"
4282                 << params.shaderFunctions
4283                 << "void main() {\n"
4284                    "   uint index = (gl_NumWorkGroups.x * gl_WorkGroupSize.x) * gl_GlobalInvocationID.y + "
4285                    "gl_GlobalInvocationID.x;\n"
4286                 << rayQueryPart
4287                 << "   results[index].x = x;\n"
4288                    "   results[index].y = y;\n"
4289                    "   results[index].z = z;\n"
4290                    "   results[index].w = w;\n"
4291                    "}";
4292 
4293         programCollection.glslSources.add("comp", &buildOptions) << glu::ComputeSource(compute.str());
4294 
4295         break;
4296     }
4297     case RayQueryShaderSourcePipeline::GRAPHICS:
4298     {
4299         std::ostringstream vertex;
4300 
4301         if (params.shaderSourceType == RayQueryShaderSourceType::VERTEX)
4302         {
4303             vertex << "#version 460\n"
4304                       "#extension GL_EXT_ray_tracing : enable\n"
4305                       "#extension GL_EXT_ray_query : require\n"
4306                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4307                       "layout(location = 0) in vec4 in_position;\n"
4308                       "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4309                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4310                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4311                    << params.shaderFunctions
4312                    << "void main(void)\n"
4313                       "{\n"
4314                       "  const int  vertId = int(gl_VertexIndex % 3);\n"
4315                       "  if (vertId == 0)\n"
4316                       "  {\n"
4317                       "    ivec3 sz = imageSize(resultImage);\n"
4318                       "    int index = int(in_position.z);\n"
4319                       "    int idx = int(index % sz.x);\n"
4320                       "    int idy = int(index / sz.y);\n"
4321                    << rayQueryPart
4322                    << "     imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4323                       "  }\n"
4324                       "}\n";
4325         }
4326         else
4327         {
4328             vertex << "#version 460\n"
4329                       "layout(location = 0) in highp vec3 position;\n"
4330                       "\n"
4331                       "out gl_PerVertex {\n"
4332                       "   vec4 gl_Position;\n"
4333                       "};\n"
4334                       "\n"
4335                       "void main (void)\n"
4336                       "{\n"
4337                       "    gl_Position = vec4(position, 1.0);\n"
4338                       "}\n";
4339         }
4340 
4341         programCollection.glslSources.add("vert", &buildOptions) << glu::VertexSource(vertex.str());
4342 
4343         if (params.shaderSourceType == RayQueryShaderSourceType::FRAGMENT)
4344         {
4345             std::ostringstream frag;
4346             frag << "#version 460\n"
4347                     "#extension GL_EXT_ray_tracing : enable\n"
4348                     "#extension GL_EXT_ray_query : require\n"
4349                     "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4350                     "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4351                     "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4352                     "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4353                  << params.shaderFunctions
4354                  << "void main() {\n"
4355                     "    ivec3 sz = imageSize(resultImage);\n"
4356                     "    uint index = uint(gl_FragCoord.x) + sz.x * uint(gl_FragCoord.y);\n"
4357                  << rayQueryPart
4358                  << "    imageStore(resultImage, ivec3(gl_FragCoord.xy, 0), vec4(x, y, z, w));\n"
4359                     "}";
4360 
4361             programCollection.glslSources.add("frag", &buildOptions) << glu::FragmentSource(frag.str());
4362         }
4363         else if (params.shaderSourceType == RayQueryShaderSourceType::GEOMETRY)
4364         {
4365             std::stringstream geom;
4366             geom << "#version 460\n"
4367                     "#extension GL_EXT_ray_tracing : enable\n"
4368                     "#extension GL_EXT_ray_query : require\n"
4369                     "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4370                     "layout(triangles) in;\n"
4371                     "layout (triangle_strip, max_vertices = 3) out;\n"
4372                     "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4373                     "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4374                     "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4375                     "\n"
4376                     "in gl_PerVertex {\n"
4377                     "  vec4  gl_Position;\n"
4378                     "} gl_in[];\n"
4379                     "out gl_PerVertex {\n"
4380                     "  vec4 gl_Position;\n"
4381                     "};\n"
4382                  << params.shaderFunctions
4383                  << "void main (void)\n"
4384                     "{\n"
4385                     "  ivec3 sz = imageSize(resultImage);\n"
4386                     "  int index = int(gl_in[0].gl_Position.z);\n"
4387                     "  int idx = int(index % sz.x);\n"
4388                     "  int idy = int(index / sz.y);\n"
4389                  << rayQueryPart
4390                  << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4391                     "  for (int i = 0; i < gl_in.length(); ++i)\n"
4392                     "  {\n"
4393                     "        gl_Position      = gl_in[i].gl_Position;\n"
4394                     "        EmitVertex();\n"
4395                     "  }\n"
4396                     "  EndPrimitive();\n"
4397                     "}\n";
4398 
4399             programCollection.glslSources.add("geom", &buildOptions) << glu::GeometrySource(geom.str());
4400         }
4401         else if (params.shaderSourceType == RayQueryShaderSourceType::TESSELLATION_EVALUATION)
4402         {
4403             {
4404                 std::stringstream tesc;
4405                 tesc << "#version 460\n"
4406                         "#extension GL_EXT_tessellation_shader : require\n"
4407                         "in gl_PerVertex\n"
4408                         "{\n"
4409                         "  vec4 gl_Position;\n"
4410                         "} gl_in[];\n"
4411                         "layout(vertices = 4) out;\n"
4412                         "out gl_PerVertex\n"
4413                         "{\n"
4414                         "  vec4 gl_Position;\n"
4415                         "} gl_out[];\n"
4416                         "\n"
4417                         "void main (void)\n"
4418                         "{\n"
4419                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
4420                         "  gl_TessLevelInner[0] = 1;\n"
4421                         "  gl_TessLevelInner[1] = 1;\n"
4422                         "  gl_TessLevelOuter[gl_InvocationID] = 1;\n"
4423                         "}\n";
4424                 programCollection.glslSources.add("tesc", &buildOptions) << glu::TessellationControlSource(tesc.str());
4425             }
4426 
4427             {
4428                 std::ostringstream tese;
4429                 tese << "#version 460\n"
4430                         "#extension GL_EXT_ray_tracing : enable\n"
4431                         "#extension GL_EXT_tessellation_shader : require\n"
4432                         "#extension GL_EXT_ray_query : require\n"
4433                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4434                         "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4435                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4436                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4437                         "layout(quads, equal_spacing, ccw) in;\n"
4438                         "in gl_PerVertex\n"
4439                         "{\n"
4440                         "  vec4 gl_Position;\n"
4441                         "} gl_in[];\n"
4442                      << params.shaderFunctions
4443                      << "void main(void)\n"
4444                         "{\n"
4445                         "  ivec3 sz = imageSize(resultImage);\n"
4446                         "  int index = int(gl_in[0].gl_Position.z);\n"
4447                         "  int idx = int(index % sz.x);\n"
4448                         "  int idy = int(index / sz.y);\n"
4449                      << rayQueryPart
4450                      << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4451                         "  gl_Position = gl_in[0].gl_Position;\n"
4452                         "}\n";
4453 
4454                 programCollection.glslSources.add("tese", &buildOptions)
4455                     << glu::TessellationEvaluationSource(tese.str());
4456             }
4457         }
4458         else if (params.shaderSourceType == RayQueryShaderSourceType::TESSELLATION_CONTROL)
4459         {
4460             {
4461                 std::ostringstream tesc;
4462                 tesc << "#version 460\n"
4463                         "#extension GL_EXT_ray_tracing : enable\n"
4464                         "#extension GL_EXT_tessellation_shader : require\n"
4465                         "#extension GL_EXT_ray_query : require\n"
4466                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4467                         "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4468                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4469                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4470                         "in gl_PerVertex\n"
4471                         "{\n"
4472                         "  vec4 gl_Position;\n"
4473                         "} gl_in[];\n"
4474                         "layout(vertices = 3) out;\n"
4475                         "out gl_PerVertex\n"
4476                         "{\n"
4477                         "  vec4 gl_Position;\n"
4478                         "} gl_out[];\n"
4479                         "\n"
4480                      << params.shaderFunctions
4481                      << "void main(void)\n"
4482                         "{\n"
4483                         "  ivec3 sz = imageSize(resultImage);\n"
4484                         "  int index = int(gl_in[0].gl_Position.z);\n"
4485                         "  int idx = int(index % sz.x);\n"
4486                         "  int idy = int(index / sz.y);\n"
4487                      << rayQueryPart
4488                      << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4489                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
4490                         "  gl_TessLevelInner[0] = 1;\n"
4491                         "  gl_TessLevelInner[1] = 1;\n"
4492                         "  gl_TessLevelOuter[gl_InvocationID] = 1;\n"
4493                         "}\n";
4494 
4495                 programCollection.glslSources.add("tesc", &buildOptions) << glu::TessellationControlSource(tesc.str());
4496             }
4497 
4498             {
4499                 std::ostringstream tese;
4500                 tese << "#version 460\n"
4501                         "#extension GL_EXT_tessellation_shader : require\n"
4502                         "layout(quads, equal_spacing, ccw) in;\n"
4503                         "in gl_PerVertex\n"
4504                         "{\n"
4505                         "  vec4 gl_Position;\n"
4506                         "} gl_in[];\n"
4507                         "\n"
4508                         "void main(void)\n"
4509                         "{\n"
4510                         "  gl_Position = gl_in[0].gl_Position;\n"
4511                         "}\n";
4512 
4513                 programCollection.glslSources.add("tese", &buildOptions)
4514                     << glu::TessellationEvaluationSource(tese.str());
4515             }
4516         }
4517 
4518         break;
4519     }
4520     case RayQueryShaderSourcePipeline::RAYTRACING:
4521     {
4522         std::stringstream rayGen;
4523 
4524         if (params.shaderSourceType == RayQueryShaderSourceType::RAY_GENERATION_RT)
4525         {
4526             rayGen << "#version 460\n"
4527                       "#extension GL_EXT_ray_tracing : enable\n"
4528                       "#extension GL_EXT_ray_query : require\n"
4529                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4530                       "struct ResultType { float x; float y; float z; float w; };\n"
4531                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4532                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4533                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4534                       "layout(location = 0) rayPayloadEXT vec4 payload;\n"
4535                    << params.shaderFunctions
4536                    << "void main() {\n"
4537                       "   payload = vec4("
4538                    << NO_INT_VALUE << "," << max_t * 2
4539                    << ",0,0);\n"
4540                       "   uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4541                       "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4542                    << rayQueryPart
4543                    << "   results[index].x = x;\n"
4544                       "   results[index].y = y;\n"
4545                       "   results[index].z = z;\n"
4546                       "   results[index].w = w;\n"
4547                       "}";
4548 
4549             programCollection.glslSources.add("isect_rt", &buildOptions)
4550                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4551             programCollection.glslSources.add("chit_rt", &buildOptions) << glu::ClosestHitSource(rtChit.str());
4552             programCollection.glslSources.add("ahit_rt", &buildOptions) << glu::AnyHitSource(genericChit.str());
4553             programCollection.glslSources.add("miss_rt", &buildOptions) << glu::MissSource(genericMiss.str());
4554         }
4555         else if (params.shaderSourceType == RayQueryShaderSourceType::RAY_GENERATION)
4556         {
4557             rayGen << "#version 460\n"
4558                       "#extension GL_EXT_ray_tracing : enable\n"
4559                       "#extension GL_EXT_ray_query : require\n"
4560                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4561                       "struct ResultType { float x; float y; float z; float w; };\n"
4562                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4563                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4564                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4565                    << params.shaderFunctions
4566                    << "void main() {\n"
4567                       "   uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4568                       "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4569                    << rayQueryPart
4570                    << "   results[index].x = x;\n"
4571                       "   results[index].y = y;\n"
4572                       "   results[index].z = z;\n"
4573                       "   results[index].w = w;\n"
4574                       "}";
4575         }
4576         else if (params.shaderSourceType == RayQueryShaderSourceType::CALLABLE)
4577         {
4578             rayGen << "#version 460\n"
4579                       "#extension GL_EXT_ray_tracing : require\n"
4580                       "struct CallValue\n{\n"
4581                       "  uint index;\n"
4582                       "  vec4 hitAttrib;\n"
4583                       "};\n"
4584                       "layout(location = 0) callableDataEXT CallValue param;\n"
4585                       "struct ResultType { float x; float y; float z; float w; };\n"
4586                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4587                       "void main()\n"
4588                       "{\n"
4589                       "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y "
4590                       "* gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4591                       "  param.index = index;\n"
4592                       "  param.hitAttrib = vec4(0, 0, 0, 0);\n"
4593                       "  executeCallableEXT(0, 0);\n"
4594                       "  results[index].x = param.hitAttrib.x;\n"
4595                       "  results[index].y = param.hitAttrib.y;\n"
4596                       "  results[index].z = param.hitAttrib.z;\n"
4597                       "  results[index].w = param.hitAttrib.w;\n"
4598                       "}\n";
4599         }
4600         else
4601         {
4602             rayGen << "#version 460\n"
4603                       "#extension GL_EXT_ray_tracing : require\n"
4604                       "#extension GL_EXT_ray_query : require\n"
4605                       "layout(location = 0) rayPayloadEXT vec4 payload;\n"
4606                       "struct ResultType { float x; float y; float z; float w; };\n"
4607                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4608                       "layout(set = 0, binding = 3) uniform accelerationStructureEXT traceEXTAccel;\n"
4609                       "void main()\n"
4610                       "{\n"
4611                       "  payload = vec4("
4612                    << NO_INT_VALUE << "," << max_t * 2
4613                    << ",0,0);\n"
4614                       "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y "
4615                       "* gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4616                       "  traceRayEXT(traceEXTAccel, 0, 0xFF, 0, 0, 0, vec3(0.1, 0.1, 0.0), 0.0, vec3(0.0, 0.0, 1.0), "
4617                       "500.0, 0);\n"
4618                       "  results[index].x = payload.x;\n"
4619                       "  results[index].y = payload.y;\n"
4620                       "  results[index].z = payload.z;\n"
4621                       "  results[index].w = payload.w;\n"
4622                       "}\n";
4623         }
4624 
4625         programCollection.glslSources.add("rgen", &buildOptions) << glu::RaygenSource(rayGen.str());
4626 
4627         if (params.shaderSourceType == RayQueryShaderSourceType::CLOSEST_HIT)
4628         {
4629             programCollection.glslSources.add("chit", &buildOptions)
4630                 << glu::ClosestHitSource(genericRayTracingSetResultsShader.str());
4631             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4632             programCollection.glslSources.add("isect", &buildOptions)
4633                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4634         }
4635         else if (params.shaderSourceType == RayQueryShaderSourceType::ANY_HIT)
4636         {
4637             programCollection.glslSources.add("ahit", &buildOptions)
4638                 << glu::AnyHitSource(genericRayTracingSetResultsShader.str());
4639             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4640             programCollection.glslSources.add("isect", &buildOptions)
4641                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4642         }
4643         else if (params.shaderSourceType == RayQueryShaderSourceType::MISS)
4644         {
4645 
4646             programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(genericChit.str());
4647             programCollection.glslSources.add("miss_1", &buildOptions)
4648                 << glu::MissSource(genericRayTracingSetResultsShader.str());
4649             programCollection.glslSources.add("isect", &buildOptions)
4650                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4651         }
4652         else if (params.shaderSourceType == RayQueryShaderSourceType::INTERSECTION)
4653         {
4654             {
4655                 std::stringstream chit;
4656                 chit << "#version 460    \n"
4657                         "#extension GL_EXT_ray_tracing : require\n"
4658                         "#extension GL_EXT_ray_query : require\n"
4659                         "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4660                         "hitAttributeEXT vec4 hitAttrib;\n"
4661                         "void main()\n"
4662                         "{\n"
4663                         "  payload = hitAttrib;\n"
4664                         "}\n";
4665 
4666                 programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(chit.str());
4667             }
4668 
4669             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4670 
4671             {
4672                 std::stringstream isect;
4673                 isect << "#version 460\n"
4674                          "#extension GL_EXT_ray_tracing : require\n"
4675                          "#extension GL_EXT_ray_query : require\n"
4676                          "hitAttributeEXT vec4 hitValue;\n"
4677                          "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4678                          "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4679                          "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4680                       << params.shaderFunctions
4681                       << "void main()\n"
4682                          "{\n"
4683                          "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4684                          "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4685                       << rayQueryPart
4686                       << "  hitValue.x = x;\n"
4687                          "  hitValue.y = y;\n"
4688                          "  hitValue.z = z;\n"
4689                          "  hitValue.w = w;\n"
4690                          "  reportIntersectionEXT(0.5f, 0);\n"
4691                          "}\n";
4692 
4693                 programCollection.glslSources.add("isect_1", &buildOptions)
4694                     << glu::IntersectionSource(updateRayTracingGLSL(isect.str()));
4695             }
4696         }
4697         else if (params.shaderSourceType == RayQueryShaderSourceType::CALLABLE)
4698         {
4699             {
4700                 std::stringstream call;
4701                 call << "#version 460\n"
4702                         "#extension GL_EXT_ray_tracing : require\n"
4703                         "#extension GL_EXT_ray_query : require\n"
4704                         "struct CallValue\n{\n"
4705                         "  uint index;\n"
4706                         "  vec4 hitAttrib;\n"
4707                         "};\n"
4708                         "layout(location = 0) callableDataInEXT CallValue result;\n"
4709                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4710                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4711                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4712                      << params.shaderFunctions
4713                      << "void main()\n"
4714                         "{\n"
4715                         "  uint index = result.index;\n"
4716                      << rayQueryPart
4717                      << "  result.hitAttrib.x = x;\n"
4718                         "  result.hitAttrib.y = y;\n"
4719                         "  result.hitAttrib.z = z;\n"
4720                         "  result.hitAttrib.w = w;\n"
4721                         "}\n";
4722 
4723                 programCollection.glslSources.add("call", &buildOptions)
4724                     << glu::CallableSource(updateRayTracingGLSL(call.str()));
4725             }
4726 
4727             programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(genericChit.str());
4728             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4729         }
4730 
4731         break;
4732     }
4733     default:
4734     {
4735         TCU_FAIL("Shader type not valid.");
4736     }
4737     }
4738 }
4739 
4740 #else
4741 
// Placeholder symbol for Vulkan SC builds: when the ray tracing utilities
// above are compiled out (CTS_USES_VULKANSC), this keeps the translation
// unit from being empty. Always returns 0; callers ignore the value.
uint32_t rayTracingDefineAnything()
{
    const uint32_t placeholder = 0u;
    return placeholder;
}
4746 
4747 #endif // CTS_USES_VULKANSC
4748 
4749 } // namespace vk
4750