• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2024 The Khronos Group Inc.
6  * Copyright (c) 2024 Valve Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Device Generated Commands EXT Ray Tracing Tests
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktDGCRayTracingTestsExt.hpp"
26 #include "vktTestCase.hpp"
27 #include "vkRayTracingUtil.hpp"
28 #include "vktDGCUtilExt.hpp"
29 #include "vktDGCUtilCommon.hpp"
30 
31 #include "tcuVectorUtil.hpp"
32 
33 #include "deUniquePtr.hpp"
34 #include "deRandom.hpp"
35 
36 #include <string>
37 #include <sstream>
38 #include <vector>
39 #include <cstdlib>
40 
41 namespace vkt
42 {
43 namespace DGC
44 {
45 
46 namespace
47 {
48 
49 using namespace vk;
50 
51 // Place geometry in the XY [0, N] range, with one horizontal and vertical unit per instance.
52 //
53 // In the Z coordinate, geometry will be located around +10. Inactive geometries, part of the same bottom level AS, will
54 // be located in negative Z ranges to make sure rays do not hit them.
55 //
56 // Rays will be cast from the middle X+0.5, Y+0.5 points, towards +Z.
57 //
58 struct BottomLevelASParams
59 {
60     static constexpr uint32_t kTriangles = 0u;
61     static constexpr uint32_t kAABBs     = 1u;
62 
63     static constexpr uint32_t kCounterClockwise = 0u;
64     static constexpr uint32_t kClockwise        = 1u;
65 
66     static constexpr uint32_t kPrimitiveCount = 4u;
67     static constexpr uint32_t kGeometryCount  = 2u;
68     static constexpr float kBaseZ             = 10.0f;
69 
70     uint32_t geometryType;        // 0: triangles, 1: AABBs.
71     uint32_t activeGeometryIndex; // Other geometries will be located such that the ray doesn't hit them.
72     uint32_t windingDirection;    // 0: counter clockwise, 1: clockwise.
73     uint32_t closestPrimitive;    // [0,kPrimitiveCount)
74 
BottomLevelASParamsvkt::DGC::__anonba7c4f2e0111::BottomLevelASParams75     BottomLevelASParams(de::Random &rnd)
76     {
77         geometryType        = (rnd.getBool() ? kTriangles : kAABBs);
78         activeGeometryIndex = static_cast<uint32_t>(rnd.getInt(0, static_cast<int>(kGeometryCount - 1u)));
79         windingDirection    = (rnd.getBool() ? kCounterClockwise : kClockwise);
80         closestPrimitive    = static_cast<uint32_t>(rnd.getInt(0, static_cast<int>(kPrimitiveCount - 1u)));
81     }
82 };
83 
// The traced area is a grid of kWidth x kHeight cells, split into kSBTCount dispatches of kDispHeight rows each.
constexpr uint32_t kWidth      = 16u;
constexpr uint32_t kHeight     = 16u;
constexpr uint32_t kBLASCount  = 16u; // Number of distinct bottom level acceleration structures.
constexpr uint32_t kSBTCount   = 2u;  // Number of shader binding table variants.
constexpr uint32_t kDispHeight = kHeight / kSBTCount; // Rows handled by each dispatch.

// Host-side copies of GLSL_EXT_ray_tracing ray flag and hit kind values.
constexpr uint32_t kRayFlagsNoneEXT                     = 0u;
constexpr uint32_t kRayFlagsCullBackFacingTrianglesEXT  = (1u << 4); // 16
constexpr uint32_t kRayFlagsCullFrontFacingTrianglesEXT = (1u << 5); // 32
constexpr uint32_t kRayFlagsCullOpaqueEXT               = (1u << 6); // 64
// Unused here, listed for reference: OpaqueEXT = 1, NoOpaqueEXT = 2, TerminateOnFirstHitEXT = 4,
// SkipClosestHitShaderEXT = 8, CullNoOpaqueEXT = 128.
constexpr uint32_t kHitKindFrontFacingTriangleEXT = 0xFEu;
constexpr uint32_t kHitKindBackFacingTriangleEXT  = 0xFFu;

// Presumably the tolerance for floating point result comparisons; confirm against its users.
constexpr float kFloatThreshold = 1.0f / 256.0f;
105 
106 constexpr VkShaderStageFlags kStageFlags =
107     (VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR |
108      VK_SHADER_STAGE_MISS_BIT_KHR | VK_SHADER_STAGE_CALLABLE_BIT_KHR);
109 
110 // What to do in each XY 1-unit square where we trace rays.
111 struct CellParams
112 {
113     tcu::Vec4 origin;
114     VkTransformMatrixKHR transformMatrix;
115     uint32_t closestPrimitive;    // This is a copy of the bottom level AS param. Needed in the isec shader.
116     float zDirection;             // +1 or +2.
117     float minT;                   // Appropriate so the ray starts at [4,8]
118     float maxT;                   // Appropriate so the ray ends at [20,40]
119     uint32_t blasIndex;           // [0, kBLASCount)
120     uint32_t instanceCustomIndex; // [100 to 150], peudorandomly and without specific meaning.
121     VkBool32 opaque;
122     uint32_t rayFlags;  // One of: None, CullBackFacingTri, CullFrontFacingTri, CullOpaque.
123     uint32_t missIndex; // 0 or 1.
124 
125     uint32_t padding0[3]; // Padding to match std430.
126 
CellParamsvkt::DGC::__anonba7c4f2e0111::CellParams127     CellParams(uint32_t x, uint32_t y, de::Random &rnd)
128     {
129         const auto fx = static_cast<float>(x);
130         const auto fy = static_cast<float>(y);
131 
132         origin           = tcu::Vec4(fx + 0.5f, fy + 0.5f, 0.0f, 1.0f);
133         transformMatrix  = VkTransformMatrixKHR{{
134             {1.0f, 0.0f, 0.0f, fx},
135             {0.0f, 1.0f, 0.0f, fy},
136             {0.0f, 0.0f, 1.0f, 0.0f},
137         }};
138         closestPrimitive = 0u; // This needs to be copied later, after blasIndex is set in this constructor.
139 
140         zDirection          = (rnd.getBool() ? 1.0f : 2.0f);
141         minT                = (rnd.getBool() ? 4.0f : 8.0f) / zDirection;
142         maxT                = (rnd.getBool() ? 20.0f : 40.0f) / zDirection;
143         blasIndex           = static_cast<uint32_t>(rnd.getInt(0, static_cast<int>(kBLASCount - 1u)));
144         instanceCustomIndex = static_cast<uint32_t>(rnd.getInt(100, 150)); // Just an ID.
145         opaque              = (rnd.getBool() ? VK_TRUE : VK_FALSE);
146 
147         static const std::vector<uint32_t> kFlagCatalogue{
148             kRayFlagsNoneEXT,
149             kRayFlagsCullBackFacingTrianglesEXT,
150             kRayFlagsCullFrontFacingTrianglesEXT,
151             kRayFlagsCullOpaqueEXT,
152         };
153         rayFlags  = kFlagCatalogue.at(rnd.getInt(0, static_cast<int>(kFlagCatalogue.size()) - 1));
154         missIndex = static_cast<uint32_t>(rnd.getInt(0, 1));
155     }
156 };
157 
158 // Information to be filled from shaders.
159 struct CellOutput
160 {
161     // I/O Data.
162     tcu::Vec4 rgenInitialPayload;
163     tcu::Vec4 rgenFinalPayload;
164     tcu::Vec4 chitPayload;
165     tcu::Vec4 missPayload;
166     tcu::Vec4 chitIncomingPayload;
167     tcu::Vec4 missIncomingPayload;
168     tcu::Vec4 isecAttribute;
169     tcu::Vec4 chitAttribute;
170     tcu::Vec4 rgenSRB;
171     tcu::Vec4 isecSRB;
172     tcu::Vec4 chitSRB;
173     tcu::Vec4 missSRB;
174     tcu::Vec4 call0SRB;
175     tcu::Vec4 call1SRB;
176 
177     // Built-ins.
178     tcu::UVec4 rgenLaunchIDEXT;
179     tcu::UVec4 rgenLaunchSizeEXT;
180 
181     tcu::UVec4 chitLaunchIDEXT;
182     tcu::UVec4 chitLaunchSizeEXT;
183 
184     int32_t chitPrimitiveID;
185     int32_t chitInstanceID;
186     int32_t chitInstanceCustomIndexEXT;
187     int32_t chitGeometryIndexEXT;
188 
189     tcu::Vec4 chitWorldRayOriginEXT;
190     tcu::Vec4 chitWorldRayDirectionEXT;
191     tcu::Vec4 chitObjectRayOriginEXT;
192     tcu::Vec4 chitObjectRayDirectionEXT;
193 
194     float chitRayTminEXT;
195     float chitRayTmaxEXT;
196     uint32_t chitIncomingRayFlagsEXT;
197 
198     float chitHitTEXT;
199     uint32_t chitHitKindEXT;
200 
201     uint32_t padding0[3]; // To match the GLSL alignment.
202 
203     tcu::Vec4 chitObjectToWorldEXT[3];
204     tcu::Vec4 chitObjectToWorld3x4EXT[4];
205     tcu::Vec4 chitWorldToObjectEXT[3];
206     tcu::Vec4 chitWorldToObject3x4EXT[4];
207 
208     tcu::UVec4 isecLaunchIDEXT;
209     tcu::UVec4 isecLaunchSizeEXT;
210 
211     int32_t isecPrimitiveID;
212     int32_t isecInstanceID;
213     int32_t isecInstanceCustomIndexEXT;
214     int32_t isecGeometryIndexEXT;
215 
216     tcu::Vec4 isecWorldRayOriginEXT;
217     tcu::Vec4 isecWorldRayDirectionEXT;
218     tcu::Vec4 isecObjectRayOriginEXT;
219     tcu::Vec4 isecObjectRayDirectionEXT;
220 
221     float isecRayTminEXT;
222     float isecRayTmaxEXT;
223     uint32_t isecIncomingRayFlagsEXT;
224 
225     uint32_t padding1[1]; // To match the GLSL alignment.
226 
227     tcu::Vec4 isecObjectToWorldEXT[3];
228     tcu::Vec4 isecObjectToWorld3x4EXT[4];
229     tcu::Vec4 isecWorldToObjectEXT[3];
230     tcu::Vec4 isecWorldToObject3x4EXT[4];
231 
232     tcu::UVec4 missLaunchIDEXT;
233     tcu::UVec4 missLaunchSizeEXT;
234 
235     tcu::Vec4 missWorldRayOriginEXT;
236     tcu::Vec4 missWorldRayDirectionEXT;
237 
238     float missRayTminEXT;
239     float missRayTmaxEXT;
240     uint32_t missIncomingRayFlagsEXT;
241 
242     uint32_t padding2[1]; // To match the GLSL alignment.
243 
244     tcu::UVec4 callLaunchIDEXT;
245     tcu::UVec4 callLaunchSizeEXT;
246 
CellOutputvkt::DGC::__anonba7c4f2e0111::CellOutput247     CellOutput(void)
248     {
249         deMemset(this, 0, sizeof(*this));
250     }
251 };
252 
253 using BLASPtr = de::SharedPtr<BottomLevelAccelerationStructure>;
254 using TLASPtr = de::SharedPtr<TopLevelAccelerationStructure>;
255 
makeBottomLevelASWithParams(const BottomLevelASParams & params)256 BLASPtr makeBottomLevelASWithParams(const BottomLevelASParams &params)
257 {
258     auto blas = makeBottomLevelAccelerationStructure();
259 
260     if (params.geometryType == BottomLevelASParams::kTriangles)
261     {
262         static constexpr uint32_t kTriangleVertices = 3u;
263         const bool clockwise                        = (params.windingDirection == BottomLevelASParams::kClockwise);
264 
265         for (uint32_t geometryIdx = 0u; geometryIdx < BottomLevelASParams::kGeometryCount; ++geometryIdx)
266         {
267             std::vector<tcu::Vec3> vertices;
268             vertices.reserve(kTriangleVertices * BottomLevelASParams::kPrimitiveCount);
269 
270             const float zFactor = (geometryIdx == params.activeGeometryIndex ? 1.0f : -1.0f);
271 
272             for (uint32_t primIdx = 0u; primIdx < BottomLevelASParams::kPrimitiveCount; ++primIdx)
273             {
274                 const float zOffset = (primIdx == params.closestPrimitive ? 0.0f : static_cast<float>(primIdx + 1u));
275                 const float zCoord  = zFactor * BottomLevelASParams::kBaseZ + zOffset;
276 
277                 const tcu::Vec3 vertA(0.25f, 0.25f, zCoord);
278                 const tcu::Vec3 vertB(0.75f, 0.25f, zCoord);
279                 const tcu::Vec3 vertC(0.50f, 0.75f, zCoord);
280 
281                 vertices.push_back(clockwise ? vertB : vertA);
282                 vertices.push_back(clockwise ? vertA : vertB);
283                 vertices.push_back(vertC);
284             }
285 
286             blas->addGeometry(vertices, true /*triangles*/, 0u);
287         }
288     }
289     else
290     {
291         static constexpr uint32_t kAABBVertices = 2u;
292 
293         for (uint32_t geometryIdx = 0u; geometryIdx < BottomLevelASParams::kGeometryCount; ++geometryIdx)
294         {
295             std::vector<tcu::Vec3> vertices;
296             vertices.reserve(kAABBVertices * BottomLevelASParams::kPrimitiveCount);
297 
298             const float zFactor = (geometryIdx == params.activeGeometryIndex ? 1.0f : -1.0f);
299 
300             for (uint32_t primIdx = 0u; primIdx < BottomLevelASParams::kPrimitiveCount; ++primIdx)
301             {
302                 const float zOffset = (primIdx == params.closestPrimitive ? 0.0f : static_cast<float>(primIdx + 1u));
303                 const float zCoord  = zFactor * BottomLevelASParams::kBaseZ + zFactor * zOffset;
304 
305                 const tcu::Vec3 vertA(0.0f, 0.0f, zCoord);
306                 const tcu::Vec3 vertB(1.0f, 1.0f, zCoord + 0.5f);
307 
308                 vertices.push_back(vertA);
309                 vertices.push_back(vertB);
310             }
311 
312             blas->addGeometry(vertices, false /*triangles*/, 0u);
313         }
314     }
315 
316     return BLASPtr(blas.release());
317 }
318 
makeTopLevelASWithParams(const std::vector<BLASPtr> & blas,const std::vector<CellParams> & cellParams)319 TLASPtr makeTopLevelASWithParams(const std::vector<BLASPtr> &blas, const std::vector<CellParams> &cellParams)
320 {
321     const auto fixedGeometryFlags = static_cast<VkGeometryInstanceFlagsKHR>(VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR);
322 
323     auto topLevelAS = makeTopLevelAccelerationStructure();
324     topLevelAS->setInstanceCount(cellParams.size());
325 
326     for (const auto &cp : cellParams)
327         topLevelAS->addInstance(blas.at(cp.blasIndex), cp.transformMatrix, cp.instanceCustomIndex, 0xFFu, 0u,
328                                 fixedGeometryFlags);
329 
330     return TLASPtr(topLevelAS.release());
331 }
332 
333 class RayTracingInstance : public vkt::TestInstance
334 {
335 public:
336     struct Params
337     {
338         bool useExecutionSet;
339         bool preprocess;
340         bool unordered;
341         bool computeQueue;
342 
getRandomSeedvkt::DGC::__anonba7c4f2e0111::RayTracingInstance::Params343         uint32_t getRandomSeed(void) const
344         {
345             return 1720182500u;
346         }
347     };
348 
RayTracingInstance(Context & context,const Params & params)349     RayTracingInstance(Context &context, const Params &params) : vkt::TestInstance(context), m_params(params)
350     {
351     }
~RayTracingInstance(void)352     virtual ~RayTracingInstance(void)
353     {
354     }
355 
356     tcu::TestStatus iterate(void) override;
357 
358 protected:
359     const Params m_params;
360 };
361 
362 class RayTracingCase : public vkt::TestCase
363 {
364 public:
RayTracingCase(tcu::TestContext & testCtx,const std::string & name,const RayTracingInstance::Params & params)365     RayTracingCase(tcu::TestContext &testCtx, const std::string &name, const RayTracingInstance::Params &params)
366         : vkt::TestCase(testCtx, name)
367         , m_params(params)
368     {
369     }
~RayTracingCase(void)370     virtual ~RayTracingCase(void)
371     {
372     }
373 
374     void checkSupport(Context &context) const override;
375     void initPrograms(vk::SourceCollections &programCollection) const override;
376     TestInstance *createInstance(Context &context) const override;
377 
378 protected:
379     const RayTracingInstance::Params m_params;
380 };
381 
checkSupport(Context & context) const382 void RayTracingCase::checkSupport(Context &context) const
383 {
384     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
385     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
386     context.requireDeviceFunctionality("VK_KHR_ray_tracing_maintenance1");
387 
388     const auto bindStages = (m_params.useExecutionSet ? kStageFlags : 0u);
389     checkDGCExtSupport(context, kStageFlags, bindStages);
390 
391     if (m_params.computeQueue)
392         context.getComputeQueue(); // Will throw NotSupportedError if not available.
393 }
394 
// Value the miss shader with the given index adds to each payload component: 1000000 * (missIndex + 1).
uint32_t getMissIndexOffset(uint32_t missIndex)
{
    const uint32_t unit = 1000000u;
    const uint32_t slot = missIndex + 1u;
    return slot * unit;
}
400 
// Offset that the closest-hit shader with the given index applies to payload values: 100000 * (chitIndex + 1).
uint32_t getChitIndexOffset(uint32_t chitIndex)
{
    constexpr uint32_t kChitOffsetUnit = 100000u;
    return kChitOffsetUnit * (chitIndex + 1u);
}
406 
// Offset that the intersection shader with the given index stores in the hit attribute: 10000 * (isecIndex + 1).
uint32_t getIsecIndexOffset(uint32_t isecIndex)
{
    const uint32_t position = isecIndex + 1u;
    return position * 10000u;
}
412 
// Offset that the callable shader with the given index applies to the callable data: 1000 * (callIndex + 1).
uint32_t getCallIndexOffset(uint32_t callIndex)
{
    constexpr uint32_t kCallOffsetUnit = 1000u;
    const uint32_t ordinal             = callIndex + 1u;
    return kCallOffsetUnit * ordinal;
}
418 
initPrograms(vk::SourceCollections & programCollection) const419 void RayTracingCase::initPrograms(vk::SourceCollections &programCollection) const
420 {
421     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
422 
423     std::string cellParamsDecl;
424     {
425         // Note this must roughly match the CellParams struct declared above.
426         std::ostringstream cellParamsStream;
427         cellParamsStream << "struct CellParams\n"
428                          << "{\n"
429                          << "    vec4  origin;\n"
430                          << "    float transformMatrix[12];\n"
431                          << "    uint  closestPrimitive;\n"
432                          << "    float zDirection;\n"
433                          << "    float minT;\n"
434                          << "    float maxT;\n"
435                          << "    uint  blasIndex;\n"
436                          << "    uint  instanceCustomIndex;\n"
437                          << "    uint  opaque;\n"
438                          << "    uint  rayFlags;\n"
439                          << "    uint  missIndex;\n"
440                          << "};\n";
441         cellParamsDecl = cellParamsStream.str();
442     }
443 
444     std::string cellOutputDecl;
445     {
446         std::ostringstream cellOutputStream;
447         cellOutputStream << "struct CellOutput\n"
448                          << "{\n"
449                          << "    vec4 rgenInitialPayload;\n"
450                          << "    vec4 rgenFinalPayload;\n"
451                          << "    vec4 chitPayload;\n"
452                          << "    vec4 missPayload;\n"
453                          << "    vec4 chitIncomingPayload;\n"
454                          << "    vec4 missIncomingPayload;\n"
455                          << "    vec4 isecAttribute;\n"
456                          << "    vec4 chitAttribute;\n"
457                          << "    vec4 rgenSRB;\n"
458                          << "    vec4 isecSRB;\n"
459                          << "    vec4 chitSRB;\n"
460                          << "    vec4 missSRB;\n"
461                          << "    vec4 call0SRB;\n"
462                          << "    vec4 call1SRB;\n"
463                          << "\n"
464                          << "    uvec4 rgenLaunchIDEXT;\n"
465                          << "    uvec4 rgenLaunchSizeEXT;\n"
466                          << "\n"
467                          << "    uvec4 chitLaunchIDEXT;\n"
468                          << "    uvec4 chitLaunchSizeEXT;\n"
469                          << "\n"
470                          << "    int chitPrimitiveID;\n"
471                          << "    int chitInstanceID;\n"
472                          << "    int chitInstanceCustomIndexEXT;\n"
473                          << "    int chitGeometryIndexEXT;\n"
474                          << "\n"
475                          << "    vec4 chitWorldRayOriginEXT;\n"
476                          << "    vec4 chitWorldRayDirectionEXT;\n"
477                          << "    vec4 chitObjectRayOriginEXT;\n"
478                          << "    vec4 chitObjectRayDirectionEXT;\n"
479                          << "\n"
480                          << "    float chitRayTminEXT;\n"
481                          << "    float chitRayTmaxEXT;\n"
482                          << "    uint  chitIncomingRayFlagsEXT;\n"
483                          << "\n"
484                          << "    float chitHitTEXT;\n"
485                          << "    uint  chitHitKindEXT;\n"
486                          << "\n"
487                          << "    vec4 chitObjectToWorldEXT[3];\n"
488                          << "    vec4 chitObjectToWorld3x4EXT[4];\n"
489                          << "    vec4 chitWorldToObjectEXT[3];\n"
490                          << "    vec4 chitWorldToObject3x4EXT[4];\n"
491                          << "\n"
492                          << "    uvec4 isecLaunchIDEXT;\n"
493                          << "    uvec4 isecLaunchSizeEXT;\n"
494                          << "\n"
495                          << "    int isecPrimitiveID;\n"
496                          << "    int isecInstanceID;\n"
497                          << "    int isecInstanceCustomIndexEXT;\n"
498                          << "    int isecGeometryIndexEXT;\n"
499                          << "\n"
500                          << "    vec4 isecWorldRayOriginEXT;\n"
501                          << "    vec4 isecWorldRayDirectionEXT;\n"
502                          << "    vec4 isecObjectRayOriginEXT;\n"
503                          << "    vec4 isecObjectRayDirectionEXT;\n"
504                          << "\n"
505                          << "    float isecRayTminEXT;\n"
506                          << "    float isecRayTmaxEXT;\n"
507                          << "    uint  isecIncomingRayFlagsEXT;\n"
508                          << "\n"
509                          << "    vec4 isecObjectToWorldEXT[3];\n"
510                          << "    vec4 isecObjectToWorld3x4EXT[4];\n"
511                          << "    vec4 isecWorldToObjectEXT[3];\n"
512                          << "    vec4 isecWorldToObject3x4EXT[4];\n"
513                          << "\n"
514                          << "    uvec4 missLaunchIDEXT;\n"
515                          << "    uvec4 missLaunchSizeEXT;\n"
516                          << "\n"
517                          << "    vec4 missWorldRayOriginEXT;\n"
518                          << "    vec4 missWorldRayDirectionEXT;\n"
519                          << "\n"
520                          << "    float missRayTminEXT;\n"
521                          << "    float missRayTmaxEXT;\n"
522                          << "    uint  missIncomingRayFlagsEXT;\n"
523                          << "\n"
524                          << "    uvec4 callLaunchIDEXT;\n"
525                          << "    uvec4 callLaunchSizeEXT;\n"
526                          << "};\n";
527         cellOutputDecl = cellOutputStream.str();
528     }
529 
530     const uint32_t cellCount = kWidth * kHeight;
531 
532     std::string descDecl;
533     {
534         std::ostringstream descStream;
535         descStream << cellParamsDecl << cellOutputDecl
536                    << "layout (set=0, binding=0) uniform accelerationStructureEXT topLevelAS;\n"
537                    << "layout (set=0, binding=1, std430) readonly buffer InputBlock {\n"
538                    << "    CellParams params[" << cellCount << "];\n"
539                    << "} ib;\n"
540                    << "layout (set=0, binding=2, std430) buffer OutputBlock {\n"
541                    << "    CellOutput values[" << cellCount << "];\n"
542                    << "} ob;\n"
543                    << "layout (push_constant, std430) uniform PCBlock { uint offsetY; } pc;\n";
544         descDecl = descStream.str();
545     }
546 
547     std::string cellIdxFuncDecl;
548     {
549         std::ostringstream cellIdxFuncStream;
550         cellIdxFuncStream
551             << "uint getCellIndex(bool print) {\n"
552             << "    const uint row = gl_LaunchIDEXT.y + pc.offsetY;\n"
553             << "    const uint cellIndex = row * gl_LaunchSizeEXT.x + gl_LaunchIDEXT.x;\n"
554             << "    if (print)"
555             << "        debugPrintfEXT(\"pc.offsetY=%u gl_LaunchIDEXT.x=%u gl_LaunchIDEXT.y=%u gl_LaunchSizeEXT.x=%u "
556                "gl_LaunchSizeEXT.y=%u row=%u cellIndex=%u\\n\", pc.offsetY, gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, "
557                "gl_LaunchSizeEXT.x, gl_LaunchSizeEXT.y, row, cellIndex);\n"
558             << "    return cellIndex;\n"
559             << "}\n";
560         cellIdxFuncDecl = cellIdxFuncStream.str();
561     }
562 
563     std::string shaderRecordDecl;
564     {
565         std::ostringstream shaderRecordStream;
566         shaderRecordStream << "layout(shaderRecordEXT, std430) buffer SRBBlock {\n"
567                            << "    vec4 data;\n"
568                            << "} srb;\n";
569         shaderRecordDecl = shaderRecordStream.str();
570     }
571 
572     // 2 ray-gen shaders: one without SRB and one with it.
573     for (uint32_t rgenIdx = 0u; rgenIdx < kSBTCount; ++rgenIdx)
574     {
575         const bool withSRB = (rgenIdx > 0u);
576         const auto suffix  = (withSRB ? "-srb" : "");
577 
578         std::ostringstream rgen;
579         rgen << "#version 460 core\n"
580              << "#extension GL_EXT_debug_printf : enable\n"
581              << "#extension GL_EXT_ray_tracing : require\n"
582              << "layout (location=0) rayPayloadEXT vec4 payload;\n"
583              << descDecl << (withSRB ? shaderRecordDecl : "") << cellIdxFuncDecl << "void main()\n"
584              << "{\n"
585              << "    const uint cellIdx = getCellIndex(false);\n"
586              << "\n"
587              << "    ob.values[cellIdx].rgenLaunchIDEXT = uvec4(gl_LaunchIDEXT.xyz, 0u);\n"
588              << "    ob.values[cellIdx].rgenLaunchSizeEXT = uvec4(gl_LaunchSizeEXT.xyz, 0u);\n"
589              << "\n"
590              << "    const uint  rayFlags  = ib.params[cellIdx].rayFlags;\n"
591              << "    const vec3  origin    = ib.params[cellIdx].origin.xyz;\n"
592              << "    const vec3  direction = vec3(0, 0, ib.params[cellIdx].zDirection);\n"
593              << "    const float tMin      = ib.params[cellIdx].minT;\n"
594              << "    const float tMax      = ib.params[cellIdx].maxT;\n"
595              << "    const uint  missIndex = ib.params[cellIdx].missIndex;\n"
596              << "    const uint  cullMask  = 0xFF;\n"
597              << "    const uint  sbtOffset = 0u;\n"
598              << "    const uint  sbtStride = 1u;\n"
599              << "\n"
600              << "    const vec4 payloadValue = vec4(gl_LaunchIDEXT.xyz, 0.0);\n"
601              << "    payload = payloadValue;\n"
602              << "    ob.values[cellIdx].rgenInitialPayload = payload;\n"
603              << "    traceRayEXT(topLevelAS, rayFlags, cullMask, sbtOffset, sbtStride, missIndex, origin, tMin, "
604                 "direction, tMax, 0);\n"
605              << "    ob.values[cellIdx].rgenFinalPayload = payload;\n"
606              << (withSRB ? "    ob.values[cellIdx].rgenSRB = srb.data;\n" : "") << "}\n";
607         const auto shaderName = std::string("rgen") + suffix;
608         programCollection.glslSources.add(shaderName) << glu::RaygenSource(rgen.str()) << buildOptions;
609     }
610 
611     // 2 miss shaders, and variants with/without SRB for each.
612     for (uint32_t missIdx = 0u; missIdx < 2u; ++missIdx)
613         for (uint32_t srbIdx = 0u; srbIdx < kSBTCount; ++srbIdx)
614         {
615             const bool withSRB = (srbIdx > 0u);
616             const auto suffix  = (withSRB ? "-srb" : "");
617 
618             std::ostringstream miss;
619             miss << "#version 460 core\n"
620                  << "#extension GL_EXT_debug_printf : enable\n"
621                  << "#extension GL_EXT_ray_tracing : require\n"
622                  << "layout (location = 0) rayPayloadInEXT vec4 payload;\n"
623                  << descDecl << (withSRB ? shaderRecordDecl : "") << cellIdxFuncDecl << "void main()\n"
624                  << "{\n"
625                  << "    const uint cellIdx = getCellIndex(false);\n"
626                  << "\n"
627                  << "    ob.values[cellIdx].missLaunchIDEXT = uvec4(gl_LaunchIDEXT, 0u);\n"
628                  << "    ob.values[cellIdx].missLaunchSizeEXT = uvec4(gl_LaunchSizeEXT, 0u);\n"
629                  << "    ob.values[cellIdx].missWorldRayOriginEXT = vec4(gl_WorldRayOriginEXT, 1.0);\n"
630                  << "    ob.values[cellIdx].missWorldRayDirectionEXT = vec4(gl_WorldRayDirectionEXT, 0.0);\n"
631                  << "    ob.values[cellIdx].missRayTminEXT = gl_RayTminEXT;\n"
632                  << "    ob.values[cellIdx].missRayTmaxEXT = gl_RayTmaxEXT;\n"
633                  << "    ob.values[cellIdx].missIncomingRayFlagsEXT = gl_IncomingRayFlagsEXT;\n"
634                  << "\n"
635                  << "    ob.values[cellIdx].missIncomingPayload = payload;\n"
636                  << "    const float valueOffset = " << getMissIndexOffset(missIdx) << ";\n"
637                  << "    const vec4 vecOffset = vec4(valueOffset, valueOffset, valueOffset, valueOffset);\n"
638                  << "    payload = payload + vecOffset;\n"
639                  << "    ob.values[cellIdx].missPayload = payload;\n"
640                  << (withSRB ? "    ob.values[cellIdx].missSRB = srb.data;\n" : "") << "}\n";
641             const auto shaderName = std::string("miss") + std::to_string(missIdx) + suffix;
642             programCollection.glslSources.add(shaderName) << glu::MissSource(miss.str()) << buildOptions;
643         }
644 
645     // 2 closest-hit shaders and variants with/without SRB for each.
646     for (uint32_t chitIdx = 0u; chitIdx < 2u; ++chitIdx)
647         for (uint32_t srbIdx = 0u; srbIdx < kSBTCount; ++srbIdx)
648         {
649             const bool withSRB = (srbIdx > 0u);
650             const auto suffix  = (withSRB ? "-srb" : "");
651 
652             std::ostringstream chit;
653             chit << "#version 460 core\n"
654                  << "#extension GL_EXT_debug_printf : enable\n"
655                  << "#extension GL_EXT_ray_tracing : require\n"
656                  << "layout (location = 0) rayPayloadInEXT vec4 payload;\n"
657                  << "layout (location = 0) callableDataEXT vec4 callData;\n"
658                  << "hitAttributeEXT vec2 hitAttrib;\n"
659                  << descDecl << (withSRB ? shaderRecordDecl : "") << cellIdxFuncDecl << "void main()\n"
660                  << "{\n"
661                  << "    const uint cellIdx = getCellIndex(false);\n"
662                  << "\n"
663                  << "    ob.values[cellIdx].chitLaunchIDEXT = uvec4(gl_LaunchIDEXT, 0u);\n"
664                  << "    ob.values[cellIdx].chitLaunchSizeEXT = uvec4(gl_LaunchSizeEXT, 0u);\n"
665                  << "    ob.values[cellIdx].chitPrimitiveID = gl_PrimitiveID;\n"
666                  << "    ob.values[cellIdx].chitInstanceID = gl_InstanceID;\n"
667                  << "    ob.values[cellIdx].chitInstanceCustomIndexEXT = gl_InstanceCustomIndexEXT;\n"
668                  << "    ob.values[cellIdx].chitGeometryIndexEXT = gl_GeometryIndexEXT;\n"
669                  << "    ob.values[cellIdx].chitWorldRayOriginEXT = vec4(gl_WorldRayOriginEXT, 1.0);\n"
670                  << "    ob.values[cellIdx].chitWorldRayDirectionEXT = vec4(gl_WorldRayDirectionEXT, 0.0);\n"
671                  << "    ob.values[cellIdx].chitObjectRayOriginEXT = vec4(gl_ObjectRayOriginEXT, 1.0);\n"
672                  << "    ob.values[cellIdx].chitObjectRayDirectionEXT = vec4(gl_ObjectRayDirectionEXT, 0.0);\n"
673                  << "    ob.values[cellIdx].chitRayTminEXT = gl_RayTminEXT;\n"
674                  << "    ob.values[cellIdx].chitRayTmaxEXT = gl_RayTmaxEXT;\n"
675                  << "    ob.values[cellIdx].chitIncomingRayFlagsEXT = gl_IncomingRayFlagsEXT;\n"
676                  << "    ob.values[cellIdx].chitHitTEXT = gl_HitTEXT;\n"
677                  << "    ob.values[cellIdx].chitHitKindEXT = gl_HitKindEXT;\n"
678                  << "    ob.values[cellIdx].chitObjectToWorldEXT[0] = vec4(gl_ObjectToWorldEXT[0][0], "
679                     "gl_ObjectToWorldEXT[1][0], gl_ObjectToWorldEXT[2][0], gl_ObjectToWorldEXT[3][0]);\n"
680                  << "    ob.values[cellIdx].chitObjectToWorldEXT[1] = vec4(gl_ObjectToWorldEXT[0][1], "
681                     "gl_ObjectToWorldEXT[1][1], gl_ObjectToWorldEXT[2][1], gl_ObjectToWorldEXT[3][1]);\n"
682                  << "    ob.values[cellIdx].chitObjectToWorldEXT[2] = vec4(gl_ObjectToWorldEXT[0][2], "
683                     "gl_ObjectToWorldEXT[1][2], gl_ObjectToWorldEXT[2][2], gl_ObjectToWorldEXT[3][2]);\n"
684                  << "    ob.values[cellIdx].chitObjectToWorld3x4EXT[0] = vec4(gl_ObjectToWorld3x4EXT[0][0], "
685                     "gl_ObjectToWorld3x4EXT[1][0], gl_ObjectToWorld3x4EXT[2][0], 0.0);\n"
686                  << "    ob.values[cellIdx].chitObjectToWorld3x4EXT[1] = vec4(gl_ObjectToWorld3x4EXT[0][1], "
687                     "gl_ObjectToWorld3x4EXT[1][1], gl_ObjectToWorld3x4EXT[2][1], 0.0);\n"
688                  << "    ob.values[cellIdx].chitObjectToWorld3x4EXT[2] = vec4(gl_ObjectToWorld3x4EXT[0][2], "
689                     "gl_ObjectToWorld3x4EXT[1][2], gl_ObjectToWorld3x4EXT[2][2], 0.0);\n"
690                  << "    ob.values[cellIdx].chitObjectToWorld3x4EXT[3] = vec4(gl_ObjectToWorld3x4EXT[0][3], "
691                     "gl_ObjectToWorld3x4EXT[1][3], gl_ObjectToWorld3x4EXT[2][3], 0.0);\n"
692                  << "    ob.values[cellIdx].chitWorldToObjectEXT[0] = vec4(gl_WorldToObjectEXT[0][0], "
693                     "gl_WorldToObjectEXT[1][0], gl_WorldToObjectEXT[2][0], gl_WorldToObjectEXT[3][0]);\n"
694                  << "    ob.values[cellIdx].chitWorldToObjectEXT[1] = vec4(gl_WorldToObjectEXT[0][1], "
695                     "gl_WorldToObjectEXT[1][1], gl_WorldToObjectEXT[2][1], gl_WorldToObjectEXT[3][1]);\n"
696                  << "    ob.values[cellIdx].chitWorldToObjectEXT[2] = vec4(gl_WorldToObjectEXT[0][2], "
697                     "gl_WorldToObjectEXT[1][2], gl_WorldToObjectEXT[2][2], gl_WorldToObjectEXT[3][2]);\n"
698                  << "    ob.values[cellIdx].chitWorldToObject3x4EXT[0] = vec4(gl_WorldToObject3x4EXT[0][0], "
699                     "gl_WorldToObject3x4EXT[1][0], gl_WorldToObject3x4EXT[2][0], 0.0);\n"
700                  << "    ob.values[cellIdx].chitWorldToObject3x4EXT[1] = vec4(gl_WorldToObject3x4EXT[0][1], "
701                     "gl_WorldToObject3x4EXT[1][1], gl_WorldToObject3x4EXT[2][1], 0.0);\n"
702                  << "    ob.values[cellIdx].chitWorldToObject3x4EXT[2] = vec4(gl_WorldToObject3x4EXT[0][2], "
703                     "gl_WorldToObject3x4EXT[1][2], gl_WorldToObject3x4EXT[2][2], 0.0);\n"
704                  << "    ob.values[cellIdx].chitWorldToObject3x4EXT[3] = vec4(gl_WorldToObject3x4EXT[0][3], "
705                     "gl_WorldToObject3x4EXT[1][3], gl_WorldToObject3x4EXT[2][3], 0.0);\n"
706                  << "\n"
707                  << "    ob.values[cellIdx].chitIncomingPayload = payload;\n"
708                  << "    const float valueOffset = " << getChitIndexOffset(chitIdx) << ";\n"
709                  << "    const vec4 vecOffset = vec4(valueOffset, valueOffset, valueOffset, valueOffset);\n"
710                  << "    payload = payload + vecOffset;\n"
711                  << "    callData = payload;\n"
712                  << "    executeCallableEXT(1, 0); // Callable shader 1, callable data 0\n"
713                  << "    payload = callData;\n"
714                  << "    ob.values[cellIdx].chitPayload = payload;\n"
715                  << "    ob.values[cellIdx].chitAttribute = ((gl_HitKindEXT < 0xF0u) ? vec4(hitAttrib.xy, 0, 0) : "
716                     "vec4(0, 0, 0, 0));\n"
717                  << (withSRB ? "    ob.values[cellIdx].chitSRB = srb.data;\n" : "") << "}\n";
718             const auto shaderName = std::string("chit") + std::to_string(chitIdx) + suffix;
719             programCollection.glslSources.add(shaderName) << glu::ClosestHitSource(chit.str()) << buildOptions;
720         }
721 
722     // 2 intersection shaders and variants with/without SRB for each.
723     for (uint32_t isecIdx = 0u; isecIdx < 2u; ++isecIdx)
724         for (uint32_t srbIdx = 0u; srbIdx < kSBTCount; ++srbIdx)
725         {
726             const bool withSRB = (srbIdx > 0u);
727             const auto suffix  = (withSRB ? "-srb" : "");
728 
729             std::ostringstream isec;
730             isec << "#version 460 core\n"
731                  << "#extension GL_EXT_debug_printf : enable\n"
732                  << "#extension GL_EXT_ray_tracing : require\n"
733                  << "hitAttributeEXT vec2 hitAttrib;\n"
734                  << descDecl << (withSRB ? shaderRecordDecl : "") << cellIdxFuncDecl << "void main()\n"
735                  << "{\n"
736                  << "    const uint cellIdx = getCellIndex(false);\n"
737                  << "\n"
738                  << "    if (gl_PrimitiveID == ib.params[cellIdx].closestPrimitive) {\n"
739                  << "        ob.values[cellIdx].isecLaunchIDEXT = uvec4(gl_LaunchIDEXT, 0u);\n"
740                  << "        ob.values[cellIdx].isecLaunchSizeEXT = uvec4(gl_LaunchSizeEXT, 0u);\n"
741                  << "        ob.values[cellIdx].isecPrimitiveID = gl_PrimitiveID;\n"
742                  << "        ob.values[cellIdx].isecInstanceID = gl_InstanceID;\n"
743                  << "        ob.values[cellIdx].isecInstanceCustomIndexEXT = gl_InstanceCustomIndexEXT;\n"
744                  << "        ob.values[cellIdx].isecGeometryIndexEXT = gl_GeometryIndexEXT;\n"
745                  << "        ob.values[cellIdx].isecWorldRayOriginEXT = vec4(gl_WorldRayOriginEXT, 1.0);\n"
746                  << "        ob.values[cellIdx].isecWorldRayDirectionEXT = vec4(gl_WorldRayDirectionEXT, 0.0);\n"
747                  << "        ob.values[cellIdx].isecObjectRayOriginEXT = vec4(gl_ObjectRayOriginEXT, 1.0);\n"
748                  << "        ob.values[cellIdx].isecObjectRayDirectionEXT = vec4(gl_ObjectRayDirectionEXT, 0.0);\n"
749                  << "        ob.values[cellIdx].isecRayTminEXT = gl_RayTminEXT;\n"
750                  << "        ob.values[cellIdx].isecRayTmaxEXT = gl_RayTmaxEXT;\n"
751                  << "        ob.values[cellIdx].isecIncomingRayFlagsEXT = gl_IncomingRayFlagsEXT;\n"
752                  << "        ob.values[cellIdx].isecObjectToWorldEXT[0] = vec4(gl_ObjectToWorldEXT[0][0], "
753                     "gl_ObjectToWorldEXT[1][0], gl_ObjectToWorldEXT[2][0], gl_ObjectToWorldEXT[3][0]);\n"
754                  << "        ob.values[cellIdx].isecObjectToWorldEXT[1] = vec4(gl_ObjectToWorldEXT[0][1], "
755                     "gl_ObjectToWorldEXT[1][1], gl_ObjectToWorldEXT[2][1], gl_ObjectToWorldEXT[3][1]);\n"
756                  << "        ob.values[cellIdx].isecObjectToWorldEXT[2] = vec4(gl_ObjectToWorldEXT[0][2], "
757                     "gl_ObjectToWorldEXT[1][2], gl_ObjectToWorldEXT[2][2], gl_ObjectToWorldEXT[3][2]);\n"
758                  << "        ob.values[cellIdx].isecObjectToWorld3x4EXT[0] = vec4(gl_ObjectToWorld3x4EXT[0][0], "
759                     "gl_ObjectToWorld3x4EXT[1][0], gl_ObjectToWorld3x4EXT[2][0], 0.0);\n"
760                  << "        ob.values[cellIdx].isecObjectToWorld3x4EXT[1] = vec4(gl_ObjectToWorld3x4EXT[0][1], "
761                     "gl_ObjectToWorld3x4EXT[1][1], gl_ObjectToWorld3x4EXT[2][1], 0.0);\n"
762                  << "        ob.values[cellIdx].isecObjectToWorld3x4EXT[2] = vec4(gl_ObjectToWorld3x4EXT[0][2], "
763                     "gl_ObjectToWorld3x4EXT[1][2], gl_ObjectToWorld3x4EXT[2][2], 0.0);\n"
764                  << "        ob.values[cellIdx].isecObjectToWorld3x4EXT[3] = vec4(gl_ObjectToWorld3x4EXT[0][3], "
765                     "gl_ObjectToWorld3x4EXT[1][3], gl_ObjectToWorld3x4EXT[2][3], 0.0);\n"
766                  << "        ob.values[cellIdx].isecWorldToObjectEXT[0] = vec4(gl_WorldToObjectEXT[0][0], "
767                     "gl_WorldToObjectEXT[1][0], gl_WorldToObjectEXT[2][0], gl_WorldToObjectEXT[3][0]);\n"
768                  << "        ob.values[cellIdx].isecWorldToObjectEXT[1] = vec4(gl_WorldToObjectEXT[0][1], "
769                     "gl_WorldToObjectEXT[1][1], gl_WorldToObjectEXT[2][1], gl_WorldToObjectEXT[3][1]);\n"
770                  << "        ob.values[cellIdx].isecWorldToObjectEXT[2] = vec4(gl_WorldToObjectEXT[0][2], "
771                     "gl_WorldToObjectEXT[1][2], gl_WorldToObjectEXT[2][2], gl_WorldToObjectEXT[3][2]);\n"
772                  << "        ob.values[cellIdx].isecWorldToObject3x4EXT[0] = vec4(gl_WorldToObject3x4EXT[0][0], "
773                     "gl_WorldToObject3x4EXT[1][0], gl_WorldToObject3x4EXT[2][0], 0.0);\n"
774                  << "        ob.values[cellIdx].isecWorldToObject3x4EXT[1] = vec4(gl_WorldToObject3x4EXT[0][1], "
775                     "gl_WorldToObject3x4EXT[1][1], gl_WorldToObject3x4EXT[2][1], 0.0);\n"
776                  << "        ob.values[cellIdx].isecWorldToObject3x4EXT[2] = vec4(gl_WorldToObject3x4EXT[0][2], "
777                     "gl_WorldToObject3x4EXT[1][2], gl_WorldToObject3x4EXT[2][2], 0.0);\n"
778                  << "        ob.values[cellIdx].isecWorldToObject3x4EXT[3] = vec4(gl_WorldToObject3x4EXT[0][3], "
779                     "gl_WorldToObject3x4EXT[1][3], gl_WorldToObject3x4EXT[2][3], 0.0);\n"
780                  << "\n"
781                  << "        const float valueOffset = " << getIsecIndexOffset(isecIdx) << ";\n"
782                  << "        hitAttrib = vec2(valueOffset, valueOffset);\n"
783                  << "        ob.values[cellIdx].isecAttribute = vec4(hitAttrib, 0.0, 0.0);\n"
784                  << (withSRB ? "        ob.values[cellIdx].isecSRB = srb.data;\n" : "")
785                  << "        const float hitT = " << BottomLevelASParams::kBaseZ
786                  << " / ib.params[cellIdx].zDirection;\n"
787                  << "        reportIntersectionEXT(hitT, 0u);\n"
788                  << "    }\n"
789                  << "}\n";
790             const auto shaderName = std::string("isec") + std::to_string(isecIdx) + suffix;
791             programCollection.glslSources.add(shaderName) << glu::IntersectionSource(isec.str()) << buildOptions;
792         }
793 
794     // Callable shader 0, at the top of the stack and storing the built-ins.
795     for (uint32_t srbIdx = 0u; srbIdx < kSBTCount; ++srbIdx)
796     {
797         const bool withSRB = (srbIdx > 0u);
798         const auto suffix  = (withSRB ? "-srb" : "");
799 
800         std::ostringstream call;
801         call << "#version 460 core\n"
802              << "#extension GL_EXT_debug_printf : enable\n"
803              << "#extension GL_EXT_ray_tracing : require\n"
804              << descDecl << (withSRB ? shaderRecordDecl : "") << cellIdxFuncDecl
805              << "layout(location = 1) callableDataInEXT vec4 callData;\n"
806              << "void main (void) {\n"
807              << "    const uint cellIdx = getCellIndex(false);\n"
808              << "\n"
809              << "    ob.values[cellIdx].callLaunchIDEXT = uvec4(gl_LaunchIDEXT.xyz, 0u);\n"
810              << "    ob.values[cellIdx].callLaunchSizeEXT = uvec4(gl_LaunchSizeEXT.xyz, 0u);\n"
811              << "\n"
812              << "    const float valueOffset = " << getCallIndexOffset(0u) << ";\n"
813              << "    const vec4 vecOffset = vec4(valueOffset, valueOffset, valueOffset, valueOffset);\n"
814              << "    callData = callData + vecOffset;\n"
815              << (withSRB ? "    ob.values[cellIdx].call0SRB = srb.data;\n" : "") << "}\n";
816         const auto shaderName = std::string("call0") + suffix;
817         programCollection.glslSources.add(shaderName) << glu::CallableSource(call.str()) << buildOptions;
818     }
819 
820     // Callable shader 1, intermediary.
821     for (uint32_t srbIdx = 0u; srbIdx < kSBTCount; ++srbIdx)
822     {
823         const bool withSRB = (srbIdx > 0u);
824         const auto suffix  = (withSRB ? "-srb" : "");
825 
826         std::ostringstream call;
827         call << "#version 460 core\n"
828              << "#extension GL_EXT_debug_printf : enable\n"
829              << "#extension GL_EXT_ray_tracing : require\n"
830              << descDecl << (withSRB ? shaderRecordDecl : "") << cellIdxFuncDecl
831              << "layout(location = 0) callableDataInEXT vec4 callDataIn;\n"
832              << "layout(location = 1) callableDataEXT vec4 callData;\n"
833              << "void main (void) {\n"
834              << "    const uint cellIdx = getCellIndex(false);\n"
835              << "\n"
836              << "    const float valueOffset = " << getCallIndexOffset(1u) << ";\n"
837              << "    const vec4 vecOffset = vec4(valueOffset, valueOffset, valueOffset, valueOffset);\n"
838              << "    callData = callDataIn + vecOffset;\n"
839              << "    executeCallableEXT(0, 1); // Callable shader 0, callable data 1\n"
840              << "    callDataIn = callData;\n"
841              << (withSRB ? "    ob.values[cellIdx].call1SRB = srb.data;\n" : "") << "}\n";
842         const auto shaderName = std::string("call1") + suffix;
843         programCollection.glslSources.add(shaderName) << glu::CallableSource(call.str()) << buildOptions;
844     }
845 }
846 
createInstance(Context & context) const847 TestInstance *RayTracingCase::createInstance(Context &context) const
848 {
849     return new RayTracingInstance(context, m_params);
850 }
851 
852 using BufferWithMemoryPtr = de::MovePtr<BufferWithMemory>;
853 
854 struct SBTSet
855 {
856     uint32_t shaderGroupHandleSize;
857     uint32_t srbSize;
858 
859     BufferWithMemoryPtr rgenSBT;
860     BufferWithMemoryPtr missSBT;
861     BufferWithMemoryPtr hitsSBT;
862     BufferWithMemoryPtr callSBT;
863 
864     void setRgenSRB(const tcu::Vec4 &data) const;
865     void setMissSRB(uint32_t index, const tcu::Vec4 &data) const;
866     void setCallSRB(uint32_t index, const tcu::Vec4 &data) const;
867     void setHitsSRB(uint32_t index, const tcu::Vec4 &data) const;
868 
869     const tcu::Vec4 &getRgenSRB() const;
870     const tcu::Vec4 &getMissSRB(uint32_t index) const;
871     const tcu::Vec4 &getCallSRB(uint32_t index) const;
872     const tcu::Vec4 &getHitsSRB(uint32_t index) const;
873 
874     uint32_t getStride(void) const;
875 
876 protected:
877     char *getDataPtr(const BufferWithMemory &buffer, uint32_t index) const;
878     void storeDataAt(const BufferWithMemory &buffer, uint32_t index, const tcu::Vec4 &data) const;
879     const tcu::Vec4 &getDataAt(const BufferWithMemory &buffer, uint32_t index) const;
880 };
881 
getDataPtr(const BufferWithMemory & buffer,uint32_t index) const882 char *SBTSet::getDataPtr(const BufferWithMemory &buffer, uint32_t index) const
883 {
884     DE_ASSERT(srbSize > 0u);
885 
886     const uint32_t stride = shaderGroupHandleSize + srbSize;
887     const uint32_t offset = index * stride + shaderGroupHandleSize;
888     char *bufferData      = reinterpret_cast<char *>(buffer.getAllocation().getHostPtr());
889     return bufferData + offset;
890 }
891 
storeDataAt(const BufferWithMemory & buffer,uint32_t index,const tcu::Vec4 & data) const892 void SBTSet::storeDataAt(const BufferWithMemory &buffer, uint32_t index, const tcu::Vec4 &data) const
893 {
894     char *bufferData = getDataPtr(buffer, index);
895     deMemcpy(bufferData, &data, sizeof(data));
896 }
897 
getDataAt(const BufferWithMemory & buffer,uint32_t index) const898 const tcu::Vec4 &SBTSet::getDataAt(const BufferWithMemory &buffer, uint32_t index) const
899 {
900     const char *bufferData  = getDataPtr(buffer, index);
901     const tcu::Vec4 *retPtr = reinterpret_cast<const tcu::Vec4 *>(bufferData);
902     return *retPtr;
903 }
904 
setRgenSRB(const tcu::Vec4 & data) const905 void SBTSet::setRgenSRB(const tcu::Vec4 &data) const
906 {
907     storeDataAt(*rgenSBT, 0u, data);
908 }
909 
setMissSRB(uint32_t index,const tcu::Vec4 & data) const910 void SBTSet::setMissSRB(uint32_t index, const tcu::Vec4 &data) const
911 {
912     storeDataAt(*missSBT, index, data);
913 }
914 
setCallSRB(uint32_t index,const tcu::Vec4 & data) const915 void SBTSet::setCallSRB(uint32_t index, const tcu::Vec4 &data) const
916 {
917     storeDataAt(*callSBT, index, data);
918 }
919 
setHitsSRB(uint32_t index,const tcu::Vec4 & data) const920 void SBTSet::setHitsSRB(uint32_t index, const tcu::Vec4 &data) const
921 {
922     storeDataAt(*hitsSBT, index, data);
923 }
924 
getRgenSRB() const925 const tcu::Vec4 &SBTSet::getRgenSRB() const
926 {
927     return getDataAt(*rgenSBT, 0u);
928 }
929 
getMissSRB(uint32_t index) const930 const tcu::Vec4 &SBTSet::getMissSRB(uint32_t index) const
931 {
932     return getDataAt(*missSBT, index);
933 }
934 
getCallSRB(uint32_t index) const935 const tcu::Vec4 &SBTSet::getCallSRB(uint32_t index) const
936 {
937     return getDataAt(*callSBT, index);
938 }
939 
getHitsSRB(uint32_t index) const940 const tcu::Vec4 &SBTSet::getHitsSRB(uint32_t index) const
941 {
942     return getDataAt(*hitsSBT, index);
943 }
944 
getStride(void) const945 uint32_t SBTSet::getStride(void) const
946 {
947     return (shaderGroupHandleSize + srbSize);
948 }
949 
950 struct ShaderSet
951 {
952     uint32_t baseGroupIndex;
953     VkShaderModule rgen;
954     VkShaderModule miss0;
955     VkShaderModule miss1;
956     VkShaderModule call0;
957     VkShaderModule call1;
958     VkShaderModule chit0;
959     VkShaderModule chit1;
960     VkShaderModule isec0;
961     VkShaderModule isec1;
962 };
963 
genSRBData(de::Random & rnd)964 tcu::Vec4 genSRBData(de::Random &rnd)
965 {
966     static const int minVal = 0;
967     static const int maxVal = 9;
968 
969     tcu::Vec4 data(static_cast<float>(rnd.getInt(minVal, maxVal)), static_cast<float>(rnd.getInt(minVal, maxVal)),
970                    static_cast<float>(rnd.getInt(minVal, maxVal)), static_cast<float>(rnd.getInt(minVal, maxVal)));
971 
972     return data;
973 }
974 
floatEqual(const tcu::Vec4 & a,const tcu::Vec4 & b)975 bool floatEqual(const tcu::Vec4 &a, const tcu::Vec4 &b)
976 {
977     static const tcu::Vec4 thresholdVec(kFloatThreshold, kFloatThreshold, kFloatThreshold, kFloatThreshold);
978 
979     const auto diffs   = tcu::absDiff(a, b);
980     const auto inRange = tcu::lessThan(diffs, thresholdVec);
981     return tcu::boolAll(inRange);
982 }
983 
floatEqual(float a,float b)984 bool floatEqual(float a, float b)
985 {
986     const float diff = std::abs(a - b);
987     return (diff < kFloatThreshold);
988 }
989 
iterate(void)990 tcu::TestStatus RayTracingInstance::iterate(void)
991 {
992     const auto ctx     = m_context.getContextCommonData();
993     const auto qfIndex = (m_params.computeQueue ? m_context.getComputeQueueFamilyIndex() : ctx.qfIndex);
994     const auto queue   = (m_params.computeQueue ? m_context.getComputeQueue() : ctx.queue);
995 
996     const CommandPoolWithBuffer cmd(ctx.vkd, ctx.device, qfIndex);
997     const auto cmdBuffer = *cmd.cmdBuffer;
998 
999     de::Random rnd(m_params.getRandomSeed());
1000     beginCommandBuffer(ctx.vkd, cmdBuffer);
1001 
1002     // Bottom level AS and their parameters.
1003     std::vector<BottomLevelASParams> blasParams;
1004     std::vector<BLASPtr> blas;
1005 
1006     blasParams.reserve(kBLASCount);
1007     blas.reserve(kBLASCount);
1008 
1009     for (uint32_t i = 0u; i < kBLASCount; ++i)
1010     {
1011         blasParams.emplace_back(rnd);
1012         blas.emplace_back(makeBottomLevelASWithParams(blasParams.back()));
1013         blas.back()->createAndBuild(ctx.vkd, ctx.device, cmdBuffer, ctx.allocator);
1014     }
1015 
1016     // Top level acceleration structure using instances of the previous BLASes.
1017     const uint32_t cellCount = kWidth * kHeight;
1018     std::vector<CellParams> cellParams;
1019     cellParams.reserve(cellCount);
1020 
1021     for (uint32_t y = 0u; y < kHeight; ++y)
1022         for (uint32_t x = 0u; x < kWidth; ++x)
1023         {
1024             cellParams.emplace_back(x, y, rnd);
1025             auto &cp            = cellParams.back();
1026             cp.closestPrimitive = blasParams.at(cp.blasIndex).closestPrimitive;
1027         }
1028 
1029     auto topLevelAS = makeTopLevelASWithParams(blas, cellParams);
1030     topLevelAS->createAndBuild(ctx.vkd, ctx.device, cmdBuffer, ctx.allocator);
1031 
1032     // Input and output buffer.
1033     std::vector<CellOutput> cellOutputs(cellCount);
1034 
1035     const auto inputBufferSize = static_cast<VkDeviceSize>(de::dataSize(cellParams));
1036     const auto inputBufferInfo = makeBufferCreateInfo(inputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1037     BufferWithMemory inputBuffer(ctx.vkd, ctx.device, ctx.allocator, inputBufferInfo, MemoryRequirement::HostVisible);
1038     auto &inputBufferAlloc = inputBuffer.getAllocation();
1039     void *inputBufferPtr   = inputBufferAlloc.getHostPtr();
1040     deMemcpy(inputBufferPtr, de::dataOrNull(cellParams), de::dataSize(cellParams));
1041 
1042     const auto outputBufferSize = static_cast<VkDeviceSize>(de::dataSize(cellOutputs));
1043     const auto outputBufferInfo = makeBufferCreateInfo(outputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1044     BufferWithMemory outputBuffer(ctx.vkd, ctx.device, ctx.allocator, outputBufferInfo, MemoryRequirement::HostVisible);
1045     auto &outputBufferAlloc = outputBuffer.getAllocation();
1046     void *outputBufferPtr   = outputBufferAlloc.getHostPtr();
1047     deMemset(outputBufferPtr, 0, de::dataSize(cellOutputs));
1048 
1049     // Descriptor pool and set.
1050     DescriptorPoolBuilder poolBuilder;
1051     poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
1052     poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 2u /*input and output buffers*/);
1053     const auto descriptorPool =
1054         poolBuilder.build(ctx.vkd, ctx.device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
1055 
1056     DescriptorSetLayoutBuilder setLayoutBuilder;
1057     setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, kStageFlags);
1058     setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, kStageFlags);
1059     setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, kStageFlags);
1060     const auto setLayout = setLayoutBuilder.build(ctx.vkd, ctx.device);
1061 
1062     const auto descriptorSet = makeDescriptorSet(ctx.vkd, ctx.device, *descriptorPool, *setLayout);
1063 
1064     const auto pcSize         = DE_SIZEOF32(uint32_t);
1065     const auto pcRange        = makePushConstantRange(kStageFlags, 0u, pcSize);
1066     const auto pipelineLayout = makePipelineLayout(ctx.vkd, ctx.device, *setLayout, &pcRange);
1067 
1068     DescriptorSetUpdateBuilder setUpdateBuilder;
1069     {
1070         using Location                                            = DescriptorSetUpdateBuilder::Location;
1071         const VkWriteDescriptorSetAccelerationStructureKHR asDesc = {
1072             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, // VkStructureType                       sType;
1073             nullptr,                    // const void*                           pNext;
1074             1u,                         // uint32_t                          accelerationStructureCount;
1075             topLevelAS.get()->getPtr(), // const VkAccelerationStructureKHR* pAccelerationStructures;
1076         };
1077         const auto inputBufferDescInfo  = makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
1078         const auto outputBufferDescInfo = makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
1079 
1080         setUpdateBuilder.writeSingle(*descriptorSet, Location::binding(0u),
1081                                      VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &asDesc);
1082         setUpdateBuilder.writeSingle(*descriptorSet, Location::binding(1u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
1083                                      &inputBufferDescInfo);
1084         setUpdateBuilder.writeSingle(*descriptorSet, Location::binding(2u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
1085                                      &outputBufferDescInfo);
1086     }
1087     setUpdateBuilder.update(ctx.vkd, ctx.device);
1088 
1089     // Create indirect commands layout.
1090     VkIndirectCommandsLayoutUsageFlagsEXT cmdsLayoutFlags = 0u;
1091     if (m_params.preprocess)
1092         cmdsLayoutFlags |= VK_INDIRECT_COMMANDS_LAYOUT_USAGE_EXPLICIT_PREPROCESS_BIT_EXT;
1093     if (m_params.unordered)
1094         cmdsLayoutFlags |= VK_INDIRECT_COMMANDS_LAYOUT_USAGE_UNORDERED_SEQUENCES_BIT_EXT;
1095     IndirectCommandsLayoutBuilderExt cmdsLayoutBuilder(cmdsLayoutFlags, kStageFlags, *pipelineLayout);
1096     if (m_params.useExecutionSet)
1097         cmdsLayoutBuilder.addExecutionSetToken(cmdsLayoutBuilder.getStreamRange(),
1098                                                VK_INDIRECT_EXECUTION_SET_INFO_TYPE_PIPELINES_EXT, kStageFlags);
1099     cmdsLayoutBuilder.addPushConstantToken(cmdsLayoutBuilder.getStreamRange(), pcRange);
1100     cmdsLayoutBuilder.addTraceRays2Token(cmdsLayoutBuilder.getStreamRange());
1101     const auto cmdsLayout = cmdsLayoutBuilder.build(ctx.vkd, ctx.device);
1102 
1103     // Shaders.
1104     const auto &binaries = m_context.getBinaryCollection();
1105 
1106     const auto rgenMod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("rgen"));
1107     const auto rgenSRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("rgen-srb"));
1108 
1109     const auto miss0Mod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("miss0"));
1110     const auto miss1Mod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("miss1"));
1111     const auto miss0SRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("miss0-srb"));
1112     const auto miss1SRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("miss1-srb"));
1113 
1114     const auto chit0Mod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("chit0"));
1115     const auto chit1Mod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("chit1"));
1116     const auto chit0SRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("chit0-srb"));
1117     const auto chit1SRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("chit1-srb"));
1118 
1119     const auto isec0Mod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("isec0"));
1120     const auto isec1Mod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("isec1"));
1121     const auto isec0SRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("isec0-srb"));
1122     const auto isec1SRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("isec1-srb"));
1123 
1124     const auto call0Mod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("call0"));
1125     const auto call1Mod    = createShaderModule(ctx.vkd, ctx.device, binaries.get("call1"));
1126     const auto call0SRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("call0-srb"));
1127     const auto call1SRBMod = createShaderModule(ctx.vkd, ctx.device, binaries.get("call1-srb"));
1128 
1129     const auto rayTracingPropertiesKHR   = makeRayTracingProperties(ctx.vki, ctx.physicalDevice);
1130     const auto &shaderGroupHandleSize    = rayTracingPropertiesKHR->getShaderGroupHandleSize();
1131     const auto &shaderGroupBaseAlignment = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
1132 
1133     // SBTs. We need 2 because we'll divide shaders by the absence or presence of the SRBs.
1134     std::vector<SBTSet> sbts(kSBTCount);
1135 
1136     const bool multiplePipelines = (m_params.useExecutionSet);
1137     const auto pipelineCount     = (multiplePipelines ? 2u : 1u);
1138 
1139     using RTPipelinePtr = de::MovePtr<RayTracingPipeline>;
1140     std::vector<RTPipelinePtr> rayTracingPipelines;
1141     std::vector<Move<VkPipeline>> pipelines;
1142 
1143     // These are higher than what will be used.
1144     const auto recursionDepth = 5u;
1145     const auto size2Vec4      = DE_SIZEOF32(tcu::Vec4) * 2u;
1146 
1147     rayTracingPipelines.reserve(pipelineCount);
1148     pipelines.reserve(pipelineCount);
1149 
1150     for (uint32_t i = 0u; i < pipelineCount; ++i)
1151     {
1152         rayTracingPipelines.emplace_back(de::newMovePtr<RayTracingPipeline>());
1153         auto &rtPipeline = rayTracingPipelines.back();
1154         rtPipeline->setCreateFlags2(VK_PIPELINE_CREATE_2_INDIRECT_BINDABLE_BIT_EXT);
1155         rtPipeline->setMaxAttributeSize(size2Vec4);
1156         rtPipeline->setMaxPayloadSize(size2Vec4);
1157         rtPipeline->setMaxRecursionDepth(recursionDepth);
1158     }
1159 
1160     // Base shader group numbers.
1161     const uint32_t rgenGroup     = 0u; // Just one group.
1162     const uint32_t missGroupBase = 1u; // 2 groups for the rest.
1163     const uint32_t callGroupBase = 3u;
1164     const uint32_t hitsGroupBase = 5u;
1165     const uint32_t groupCount    = 7u;
1166 
1167     std::vector<ShaderSet> shaderSets;
1168     shaderSets.reserve(kSBTCount);
1169 
1170     shaderSets.push_back(ShaderSet{
1171         0u,
1172         *rgenMod,
1173         *miss0Mod,
1174         *miss1Mod,
1175         *call0Mod,
1176         *call1Mod,
1177         *chit0Mod,
1178         *chit1Mod,
1179         *isec0Mod,
1180         *isec1Mod,
1181     });
1182     shaderSets.push_back(ShaderSet{
1183         (multiplePipelines ? 0u : groupCount),
1184         *rgenSRBMod,
1185         *miss0SRBMod,
1186         *miss1SRBMod,
1187         *call0SRBMod,
1188         *call1SRBMod,
1189         *chit0SRBMod,
1190         *chit1SRBMod,
1191         *isec0SRBMod,
1192         *isec1SRBMod,
1193     });
1194 
1195     for (uint32_t i = 0u; i < kSBTCount; ++i)
1196     {
1197         const auto pipelineIdx = (multiplePipelines ? i : 0u);
1198         auto &rtPipeline       = rayTracingPipelines.at(pipelineIdx);
1199 
1200         const auto &shaderSet = shaderSets.at(i);
1201 
1202         rtPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, shaderSet.rgen, shaderSet.baseGroupIndex + rgenGroup);
1203 
1204         rtPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, shaderSet.miss0,
1205                               shaderSet.baseGroupIndex + missGroupBase + 0u);
1206         rtPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, shaderSet.miss1,
1207                               shaderSet.baseGroupIndex + missGroupBase + 1u);
1208 
1209         rtPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, shaderSet.call0,
1210                               shaderSet.baseGroupIndex + callGroupBase + 0u);
1211         rtPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, shaderSet.call1,
1212                               shaderSet.baseGroupIndex + callGroupBase + 1u);
1213 
1214         rtPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, shaderSet.chit0,
1215                               shaderSet.baseGroupIndex + hitsGroupBase + 0u);
1216         rtPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR, shaderSet.isec0,
1217                               shaderSet.baseGroupIndex + hitsGroupBase + 0u);
1218 
1219         rtPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, shaderSet.chit1,
1220                               shaderSet.baseGroupIndex + hitsGroupBase + 1u);
1221         rtPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR, shaderSet.isec1,
1222                               shaderSet.baseGroupIndex + hitsGroupBase + 1u);
1223     }
1224 
1225     for (uint32_t i = 0u; i < pipelineCount; ++i)
1226         pipelines.emplace_back(rayTracingPipelines.at(i)->createPipeline(ctx.vkd, ctx.device, *pipelineLayout));
1227 
1228     // Indirect execution set if used.
1229     VkIndirectExecutionSetEXT iesHandle = VK_NULL_HANDLE;
1230     ExecutionSetManagerPtr iesManager;
1231     if (m_params.useExecutionSet)
1232     {
1233         // Note we insert the back pipeline at index 0, but we'll overwrite both entries.
1234         iesManager = makeExecutionSetManagerPipeline(ctx.vkd, ctx.device, *pipelines.back(), pipelineCount);
1235         for (uint32_t i = 0u; i < pipelineCount; ++i)
1236             iesManager->addPipeline(i, *pipelines.at(i));
1237         iesManager->update();
1238         iesHandle = iesManager->get();
1239     }
1240 
1241     for (uint32_t i = 0u; i < kSBTCount; ++i)
1242     {
1243         const auto withSRB     = (i > 0u);
1244         const auto srbSize     = (withSRB ? shaderGroupHandleSize : 0u);
1245         const auto pipelineIdx = (multiplePipelines ? i : 0u);
1246 
1247         auto &rtPipeline    = rayTracingPipelines.at(pipelineIdx);
1248         const auto pipeline = pipelines.at(pipelineIdx).get();
1249 
1250         auto &sbt = sbts.at(i);
1251 
1252         sbt.shaderGroupHandleSize = shaderGroupHandleSize;
1253         sbt.srbSize               = srbSize;
1254 
1255         sbt.rgenSBT = rtPipeline->createShaderBindingTable(
1256             ctx.vkd, ctx.device, pipeline, ctx.allocator, shaderGroupHandleSize, shaderGroupBaseAlignment,
1257             shaderSets.at(i).baseGroupIndex + rgenGroup, 1u, 0u, 0u, MemoryRequirement::Any, 0u, 0u, srbSize);
1258 
1259         sbt.missSBT = rtPipeline->createShaderBindingTable(
1260             ctx.vkd, ctx.device, pipeline, ctx.allocator, shaderGroupHandleSize, shaderGroupBaseAlignment,
1261             shaderSets.at(i).baseGroupIndex + missGroupBase, 2u, 0u, 0u, MemoryRequirement::Any, 0u, 0u, srbSize);
1262 
1263         sbt.callSBT = rtPipeline->createShaderBindingTable(
1264             ctx.vkd, ctx.device, pipeline, ctx.allocator, shaderGroupHandleSize, shaderGroupBaseAlignment,
1265             shaderSets.at(i).baseGroupIndex + callGroupBase, 2u, 0u, 0u, MemoryRequirement::Any, 0u, 0u, srbSize);
1266 
1267         sbt.hitsSBT = rtPipeline->createShaderBindingTable(
1268             ctx.vkd, ctx.device, pipeline, ctx.allocator, shaderGroupHandleSize, shaderGroupBaseAlignment,
1269             shaderSets.at(i).baseGroupIndex + hitsGroupBase, 2u, 0u, 0u, MemoryRequirement::Any, 0u, 0u, srbSize);
1270 
1271         if (withSRB)
1272         {
1273             sbt.setRgenSRB(genSRBData(rnd));
1274             sbt.setMissSRB(0u, genSRBData(rnd));
1275             sbt.setMissSRB(1u, genSRBData(rnd));
1276             sbt.setCallSRB(0u, genSRBData(rnd));
1277             sbt.setCallSRB(1u, genSRBData(rnd));
1278             sbt.setHitsSRB(0u, genSRBData(rnd));
1279             sbt.setHitsSRB(1u, genSRBData(rnd));
1280         }
1281     }
1282 
1283     ctx.vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0u, 1u,
1284                                   &descriptorSet.get(), 0u, nullptr);
1285 
1286     DE_ASSERT(kHeight % kSBTCount == 0u);
1287 
1288     // DGC buffer with device-generated commands.
1289     const auto dgcDataSize = kSBTCount * cmdsLayoutBuilder.getStreamStride();
1290     std::vector<uint32_t> dgcData;
1291     dgcData.reserve(dgcDataSize / DE_SIZEOF32(uint32_t));
1292 
1293     DGCBuffer dgcBuffer(ctx.vkd, ctx.device, ctx.allocator, dgcDataSize);
1294     auto &dgcBufferAlloc      = dgcBuffer.getAllocation();
1295     void *dgcBufferPtr        = dgcBufferAlloc.getHostPtr();
1296     const auto dgcBaseAddress = dgcBuffer.getDeviceAddress();
1297 
1298     // Fill DGC data and copy it to the buffer.
1299     for (uint32_t i = 0u; i < kSBTCount; ++i)
1300     {
1301         if (m_params.useExecutionSet)
1302             dgcData.push_back(i);
1303         const uint32_t offsetY = i * kDispHeight;
1304         dgcData.push_back(offsetY);
1305 
1306         const auto pipelineIdx = (multiplePipelines ? i : 0u);
1307         auto &sbt              = sbts.at(i);
1308 
1309         const auto stride      = sbt.getStride();
1310         const auto twiceStride = stride * 2u; // Size for those SBTs with 2 entries (miss, call, hits).
1311 
1312 //#define USE_NON_DGC_PATH 1
1313 #undef USE_NON_DGC_PATH
1314 
1315 #ifdef USE_NON_DGC_PATH
1316         // Non-DGC version.
1317         ctx.vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelines.at(pipelineIdx));
1318 #else
1319         // For DGC we need the initial shader state bound.
1320         // For the single pipeline case, this will also be the pipeline in use.
1321         if (i == 0u)
1322             ctx.vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelines.at(pipelineIdx));
1323 #endif
1324 
1325         const auto rgenAddress = getBufferDeviceAddress(ctx.vkd, ctx.device, sbt.rgenSBT->get(), 0u);
1326         const auto missAddress = getBufferDeviceAddress(ctx.vkd, ctx.device, sbt.missSBT->get(), 0u);
1327         const auto callAddress = getBufferDeviceAddress(ctx.vkd, ctx.device, sbt.callSBT->get(), 0u);
1328         const auto hitsAddress = getBufferDeviceAddress(ctx.vkd, ctx.device, sbt.hitsSBT->get(), 0u);
1329 
1330         const auto rgenRegion = makeStridedDeviceAddressRegionKHR(rgenAddress, stride, stride);
1331         const auto missRegion = makeStridedDeviceAddressRegionKHR(missAddress, stride, twiceStride);
1332         const auto callRegion = makeStridedDeviceAddressRegionKHR(callAddress, stride, twiceStride);
1333         const auto hitsRegion = makeStridedDeviceAddressRegionKHR(hitsAddress, stride, twiceStride);
1334 
1335         const VkTraceRaysIndirectCommand2KHR traceRaysCmd{
1336             rgenRegion.deviceAddress, //  VkDeviceAddress   raygenShaderRecordAddress;
1337             rgenRegion.size,          //  VkDeviceSize      raygenShaderRecordSize;
1338             missRegion.deviceAddress, //  VkDeviceAddress   missShaderBindingTableAddress;
1339             missRegion.size,          //  VkDeviceSize      missShaderBindingTableSize;
1340             missRegion.stride,        //  VkDeviceSize      missShaderBindingTableStride;
1341             hitsRegion.deviceAddress, //  VkDeviceAddress   hitShaderBindingTableAddress;
1342             hitsRegion.size,          //  VkDeviceSize      hitShaderBindingTableSize;
1343             hitsRegion.stride,        //  VkDeviceSize      hitShaderBindingTableStride;
1344             callRegion.deviceAddress, //  VkDeviceAddress   callableShaderBindingTableAddress;
1345             callRegion.size,          //  VkDeviceSize      callableShaderBindingTableSize;
1346             callRegion.stride,        //  VkDeviceSize      callableShaderBindingTableStride;
1347             kWidth,                   //  uint32_t          width;
1348             kDispHeight,              //  uint32_t          height;
1349             1u,                       //  uint32_t          depth;
1350         };
1351 
1352         // This is interesting for the non-DGC path below, so we can have indirect ray trace commands.
1353         // We pick the command offset before adding it to the dgcData vector.
1354         const auto cmdOffset = static_cast<uint32_t>(de::dataSize(dgcData));
1355         DE_UNREF(cmdOffset);
1356 
1357         pushBackElement(dgcData, traceRaysCmd);
1358 #ifdef USE_NON_DGC_PATH
1359         // Non-DGC version.
1360         ctx.vkd.cmdPushConstants(cmdBuffer, *pipelineLayout, kStageFlags, 0u, pcSize, &offsetY);
1361         ctx.vkd.cmdTraceRaysIndirect2KHR(cmdBuffer, dgcBaseAddress + cmdOffset);
1362         //ctx.vkd.cmdTraceRaysKHR(cmdBuffer, &rgenRegion, &missRegion, &hitsRegion, &callRegion, kWidth, kDispHeight, 1u);
1363 #endif
1364     }
1365 
1366     DE_ASSERT(dgcDataSize == de::dataSize(dgcData));
1367     deMemcpy(dgcBufferPtr, de::dataOrNull(dgcData), de::dataSize(dgcData));
1368     flushAlloc(ctx.vkd, ctx.device, dgcBufferAlloc);
1369 
1370     // Create preprocess buffer and execute commands.
1371     const auto fixedPipeline = (m_params.useExecutionSet ? VK_NULL_HANDLE : *pipelines.front());
1372     PreprocessBufferExt preprocessBuffer(ctx.vkd, ctx.device, ctx.allocator, iesHandle, *cmdsLayout, kSBTCount, 0u,
1373                                          fixedPipeline);
1374 
1375 #ifndef USE_NON_DGC_PATH
1376     {
1377         DGCGenCmdsInfo cmdsInfo(kStageFlags, iesHandle, *cmdsLayout, dgcBaseAddress, dgcBuffer.getSize(),
1378                                 preprocessBuffer.getDeviceAddress(), preprocessBuffer.getSize(), kSBTCount, 0ull, 0u,
1379                                 fixedPipeline);
1380 
1381         if (m_params.preprocess)
1382         {
1383             ctx.vkd.cmdPreprocessGeneratedCommandsEXT(cmdBuffer, &cmdsInfo.get(), cmdBuffer);
1384             preprocessToExecuteBarrierExt(ctx.vkd, cmdBuffer);
1385         }
1386         {
1387             const auto isPreprocessed = makeVkBool(m_params.preprocess);
1388             ctx.vkd.cmdExecuteGeneratedCommandsEXT(cmdBuffer, isPreprocessed, &cmdsInfo.get());
1389         }
1390     }
1391 #endif
1392 
1393     // Sync shader writes to host reads for the output buffer.
1394     {
1395         const auto barrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
1396         cmdPipelineMemoryBarrier(ctx.vkd, cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR,
1397                                  VK_PIPELINE_STAGE_HOST_BIT, &barrier);
1398     }
1399 
1400     endCommandBuffer(ctx.vkd, cmdBuffer);
1401     submitCommandsAndWait(ctx.vkd, ctx.device, queue, cmdBuffer);
1402     //ctx.vkd.deviceWaitIdle(ctx.device); // For debugPrintf.
1403 
1404     invalidateAlloc(ctx.vkd, ctx.device, outputBuffer.getAllocation());
1405     deMemcpy(de::dataOrNull(cellOutputs), outputBuffer.getAllocation().getHostPtr(), de::dataSize(cellOutputs));
1406 
1407     // Verify cell outputs.
1408     bool fail = false;
1409     auto &log = m_context.getTestContext().getLog();
1410 
1411     for (uint32_t y = 0u; y < kHeight; ++y)
1412         for (uint32_t x = 0u; x < kWidth; ++x)
1413         {
1414             const auto cellIdx     = y * kWidth + x;
1415             const auto &cellOut    = cellOutputs.at(cellIdx);
1416             const auto &cellIn     = cellParams.at(cellIdx);
1417             const auto &blasInfo   = blasParams.at(cellIn.blasIndex);
1418             const auto isTriangles = (blasInfo.geometryType == BottomLevelASParams::kTriangles);
1419             const auto sbtIndex    = y / kDispHeight;
1420             const auto &sbt        = sbts.at(sbtIndex);
1421             const auto withSRB     = (sbtIndex > 0u);
1422 
1423             bool miss = false;
1424             if (cellIn.rayFlags != kRayFlagsNoneEXT)
1425             {
1426 
1427                 if (isTriangles)
1428                 {
1429                     // Front face is clockwise by default.
1430                     if ((cellIn.rayFlags & kRayFlagsCullBackFacingTrianglesEXT) != 0u &&
1431                         blasInfo.windingDirection == BottomLevelASParams::kCounterClockwise)
1432                         miss = true;
1433                     else if ((cellIn.rayFlags & kRayFlagsCullFrontFacingTrianglesEXT) != 0u &&
1434                              blasInfo.windingDirection == BottomLevelASParams::kClockwise)
1435                         miss = true;
1436                 }
1437                 if ((cellIn.rayFlags & kRayFlagsCullOpaqueEXT) != 0u)
1438                     miss = true;
1439             }
1440 
1441             const auto launchID            = tcu::UVec4(x, y % kDispHeight, 0u, 0u);
1442             const auto launchSize          = tcu::UVec4(kWidth, kDispHeight, 1u, 0u);
1443             const auto rgenInitialPayload  = launchID.asFloat();
1444             const auto &origin             = cellIn.origin;
1445             const auto direction           = tcu::Vec4(0.0f, 0.0f, cellIn.zDirection, 0.0f);
1446             const auto primitiveID         = static_cast<int32_t>(blasInfo.closestPrimitive);
1447             const auto instanceID          = static_cast<int32_t>(cellIdx);
1448             const auto instanceCustomIndex = static_cast<int32_t>(cellIn.instanceCustomIndex);
1449             const auto geometryIndex       = static_cast<int32_t>(blasInfo.activeGeometryIndex);
1450             const auto objectRayOrigin     = tcu::Vec4(0.5f, 0.5f, 0.0, 1.0f);
1451 
1452             if (cellOut.rgenLaunchIDEXT != launchID)
1453             {
1454                 log << tcu::TestLog::Message << "Bad rgenLaunchIDEXT at (" << x << ", " << y << "): expected "
1455                     << launchID << " found " << cellOut.rgenLaunchIDEXT << tcu::TestLog::EndMessage;
1456                 fail = true;
1457             }
1458 
1459             if (cellOut.rgenLaunchSizeEXT != launchSize)
1460             {
1461                 log << tcu::TestLog::Message << "Bad rgenLaunchSizeEXT at (" << x << ", " << y << "): expected "
1462                     << launchSize << " found " << cellOut.rgenLaunchSizeEXT << tcu::TestLog::EndMessage;
1463                 fail = true;
1464             }
1465 
1466             if (cellOut.rgenInitialPayload != rgenInitialPayload)
1467             {
1468                 log << tcu::TestLog::Message << "Bad rgenInitialPayload at (" << x << ", " << y << "): expected "
1469                     << rgenInitialPayload << " found " << cellOut.rgenInitialPayload << tcu::TestLog::EndMessage;
1470                 fail = true;
1471             }
1472 
1473             if (withSRB)
1474             {
1475                 const auto srb = sbt.getRgenSRB();
1476                 if (cellOut.rgenSRB != srb)
1477                 {
1478                     log << tcu::TestLog::Message << "Bad rgenSRB at (" << x << ", " << y << "): expected " << srb
1479                         << " found " << cellOut.rgenSRB << tcu::TestLog::EndMessage;
1480                     fail = true;
1481                 }
1482             }
1483 
1484             tcu::Vec4 payload = rgenInitialPayload;
1485 
1486             if (miss)
1487             {
1488                 const auto missOffset = static_cast<float>(getMissIndexOffset(cellIn.missIndex));
1489                 const tcu::Vec4 missVecOffset(missOffset, missOffset, missOffset, missOffset);
1490                 payload += missVecOffset;
1491 
1492                 // Miss payload verification.
1493                 if (cellOut.missIncomingPayload != rgenInitialPayload)
1494                 {
1495                     log << tcu::TestLog::Message << "Bad missIncomingPayload at (" << x << ", " << y << "): expected "
1496                         << rgenInitialPayload << " found " << cellOut.missIncomingPayload << tcu::TestLog::EndMessage;
1497                     fail = true;
1498                 }
1499                 if (cellOut.missPayload != payload)
1500                 {
1501                     log << tcu::TestLog::Message << "Bad missPayload at (" << x << ", " << y << "): expected "
1502                         << payload << " found " << cellOut.missPayload << tcu::TestLog::EndMessage;
1503                     fail = true;
1504                 }
1505 
1506                 if (cellOut.missLaunchIDEXT != launchID)
1507                 {
1508                     log << tcu::TestLog::Message << "Bad missLaunchIDEXT at (" << x << ", " << y << "): expected "
1509                         << launchID << " found " << cellOut.missLaunchIDEXT << tcu::TestLog::EndMessage;
1510                     fail = true;
1511                 }
1512                 if (cellOut.missLaunchSizeEXT != launchSize)
1513                 {
1514                     log << tcu::TestLog::Message << "Bad missLaunchSizeEXT at (" << x << ", " << y << "): expected "
1515                         << launchSize << " found " << cellOut.missLaunchSizeEXT << tcu::TestLog::EndMessage;
1516                     fail = true;
1517                 }
1518                 if (cellOut.missWorldRayOriginEXT != origin)
1519                 {
1520                     log << tcu::TestLog::Message << "Bad missWorldRayOriginEXT at (" << x << ", " << y << "): expected "
1521                         << origin << " found " << cellOut.missWorldRayOriginEXT << tcu::TestLog::EndMessage;
1522                     fail = true;
1523                 }
1524                 if (cellOut.missWorldRayDirectionEXT != direction)
1525                 {
1526                     log << tcu::TestLog::Message << "Bad missWorldRayDirectionEXT at (" << x << ", " << y
1527                         << "): expected " << direction << " found " << cellOut.missWorldRayDirectionEXT
1528                         << tcu::TestLog::EndMessage;
1529                     fail = true;
1530                 }
1531                 if (cellOut.missRayTminEXT != cellIn.minT)
1532                 {
1533                     log << tcu::TestLog::Message << "Bad missRayTminEXT at (" << x << ", " << y << "): expected "
1534                         << cellIn.minT << " found " << cellOut.missRayTminEXT << tcu::TestLog::EndMessage;
1535                     fail = true;
1536                 }
1537                 if (cellOut.missRayTmaxEXT != cellIn.maxT)
1538                 {
1539                     log << tcu::TestLog::Message << "Bad missRayTmaxEXT at (" << x << ", " << y << "): expected "
1540                         << cellIn.maxT << " found " << cellOut.missRayTmaxEXT << tcu::TestLog::EndMessage;
1541                     fail = true;
1542                 }
1543                 if (cellOut.missIncomingRayFlagsEXT != cellIn.rayFlags)
1544                 {
1545                     log << tcu::TestLog::Message << "Bad missIncomingRayFlagsEXT at (" << x << ", " << y
1546                         << "): expected " << cellIn.rayFlags << " found " << cellOut.missIncomingRayFlagsEXT
1547                         << tcu::TestLog::EndMessage;
1548                     fail = true;
1549                 }
1550 
1551                 if (withSRB)
1552                 {
1553                     const auto srb = sbt.getMissSRB(cellIn.missIndex);
1554                     if (cellOut.missSRB != srb)
1555                     {
1556                         log << tcu::TestLog::Message << "Bad missSRB at (" << x << ", " << y << "): expected " << srb
1557                             << " found " << cellOut.missSRB << tcu::TestLog::EndMessage;
1558                         fail = true;
1559                     }
1560                 }
1561             }
1562             else
1563             {
1564                 const auto isecOffset  = static_cast<float>(getIsecIndexOffset(blasInfo.activeGeometryIndex));
1565                 const auto chitOffset  = static_cast<float>(getChitIndexOffset(blasInfo.activeGeometryIndex));
1566                 const auto call0Offset = static_cast<float>(getCallIndexOffset(0u));
1567                 const auto call1Offset = static_cast<float>(getCallIndexOffset(1u));
1568 
1569                 const tcu::Vec4 chitVecOffset(chitOffset, chitOffset, chitOffset, chitOffset);
1570                 const tcu::Vec4 call0VecOffset(call0Offset, call0Offset, call0Offset, call0Offset);
1571                 const tcu::Vec4 call1VecOffset(call1Offset, call1Offset, call1Offset, call1Offset);
1572 
1573                 const auto chitIncomingPayload = payload;
1574 
1575                 payload += call0VecOffset;
1576                 payload += call1VecOffset;
1577                 payload += chitVecOffset;
1578 
1579                 const tcu::Vec4 hitAttribute(isecOffset, isecOffset, 0.0f, 0.0f);
1580 
1581                 const float tMaxAtIsec = BottomLevelASParams::kBaseZ / cellIn.zDirection;
1582                 uint32_t hitKind       = 0u;
1583 
1584                 if (blasInfo.geometryType == BottomLevelASParams::kTriangles)
1585                 {
1586                     hitKind = ((blasInfo.windingDirection == BottomLevelASParams::kClockwise) ?
1587                                    kHitKindFrontFacingTriangleEXT :
1588                                    kHitKindBackFacingTriangleEXT);
1589                 }
1590 
1591                 if (blasInfo.geometryType == BottomLevelASParams::kAABBs)
1592                 {
1593                     // Intersection shader.
1594                     if (cellOut.isecLaunchIDEXT != launchID)
1595                     {
1596                         log << tcu::TestLog::Message << "Bad isecLaunchIDEXT at (" << x << ", " << y << "): expected "
1597                             << launchID << " found " << cellOut.isecLaunchIDEXT << tcu::TestLog::EndMessage;
1598                         fail = true;
1599                     }
1600                     if (cellOut.isecLaunchSizeEXT != launchSize)
1601                     {
1602                         log << tcu::TestLog::Message << "Bad isecLaunchSizeEXT at (" << x << ", " << y << "): expected "
1603                             << launchSize << " found " << cellOut.isecLaunchSizeEXT << tcu::TestLog::EndMessage;
1604                         fail = true;
1605                     }
1606 
1607                     if (cellOut.isecPrimitiveID != primitiveID)
1608                     {
1609                         log << tcu::TestLog::Message << "Bad isecPrimitiveID at (" << x << ", " << y << "): expected "
1610                             << primitiveID << " found " << cellOut.isecPrimitiveID << tcu::TestLog::EndMessage;
1611                         fail = true;
1612                     }
1613                     if (cellOut.isecInstanceID != instanceID)
1614                     {
1615                         log << tcu::TestLog::Message << "Bad isecInstanceID at (" << x << ", " << y << "): expected "
1616                             << instanceID << " found " << cellOut.isecInstanceID << tcu::TestLog::EndMessage;
1617                         fail = true;
1618                     }
1619                     if (cellOut.isecInstanceCustomIndexEXT != instanceCustomIndex)
1620                     {
1621                         log << tcu::TestLog::Message << "Bad isecInstanceCustomIndexEXT at (" << x << ", " << y
1622                             << "): expected " << instanceCustomIndex << " found " << cellOut.isecInstanceCustomIndexEXT
1623                             << tcu::TestLog::EndMessage;
1624                         fail = true;
1625                     }
1626                     if (cellOut.isecGeometryIndexEXT != geometryIndex)
1627                     {
1628                         log << tcu::TestLog::Message << "Bad isecGeometryIndexEXT at (" << x << ", " << y
1629                             << "): expected " << geometryIndex << " found " << cellOut.isecGeometryIndexEXT
1630                             << tcu::TestLog::EndMessage;
1631                         fail = true;
1632                     }
1633                     if (cellOut.isecWorldRayOriginEXT != origin)
1634                     {
1635                         log << tcu::TestLog::Message << "Bad isecWorldRayOriginEXT at (" << x << ", " << y
1636                             << "): expected " << origin << " found " << cellOut.isecWorldRayOriginEXT
1637                             << tcu::TestLog::EndMessage;
1638                         fail = true;
1639                     }
1640                     if (cellOut.isecWorldRayDirectionEXT != direction)
1641                     {
1642                         log << tcu::TestLog::Message << "Bad isecWorldRayDirectionEXT at (" << x << ", " << y
1643                             << "): expected " << direction << " found " << cellOut.isecWorldRayDirectionEXT
1644                             << tcu::TestLog::EndMessage;
1645                         fail = true;
1646                     }
1647                     if (!floatEqual(cellOut.isecObjectRayOriginEXT, objectRayOrigin))
1648                     {
1649                         log << tcu::TestLog::Message << "Bad isecObjectRayOriginEXT at (" << x << ", " << y
1650                             << "): expected " << objectRayOrigin << " found " << cellOut.isecObjectRayOriginEXT
1651                             << tcu::TestLog::EndMessage;
1652                         fail = true;
1653                     }
1654                     if (!floatEqual(cellOut.isecObjectRayDirectionEXT, direction))
1655                     {
1656                         log << tcu::TestLog::Message << "Bad isecObjectRayDirectionEXT at (" << x << ", " << y
1657                             << "): expected " << direction << " found " << cellOut.isecObjectRayDirectionEXT
1658                             << tcu::TestLog::EndMessage;
1659                         fail = true;
1660                     }
1661                     if (cellOut.isecRayTminEXT != cellIn.minT)
1662                     {
1663                         log << tcu::TestLog::Message << "Bad isecRayTminEXT at (" << x << ", " << y << "): expected "
1664                             << cellIn.minT << " found " << cellOut.isecRayTminEXT << tcu::TestLog::EndMessage;
1665                         fail = true;
1666                     }
1667                     if (cellOut.isecRayTmaxEXT != cellIn.maxT)
1668                     {
1669                         log << tcu::TestLog::Message << "Bad isecRayTmaxEXT at (" << x << ", " << y << "): expected "
1670                             << cellIn.maxT << " found " << cellOut.isecRayTmaxEXT << tcu::TestLog::EndMessage;
1671                         fail = true;
1672                     }
1673                     if (cellOut.isecIncomingRayFlagsEXT != cellIn.rayFlags)
1674                     {
1675                         log << tcu::TestLog::Message << "Bad isecIncomingRayFlagsEXT at (" << x << ", " << y
1676                             << "): expected " << cellIn.rayFlags << " found " << cellOut.isecIncomingRayFlagsEXT
1677                             << tcu::TestLog::EndMessage;
1678                         fail = true;
1679                     }
1680                     for (uint32_t i = 0u; i < de::arrayLength(cellIn.transformMatrix.matrix); ++i)
1681                     {
1682                         const tcu::Vec4 row(cellIn.transformMatrix.matrix[i][0], cellIn.transformMatrix.matrix[i][1],
1683                                             cellIn.transformMatrix.matrix[i][2], cellIn.transformMatrix.matrix[i][3]);
1684                         if (!floatEqual(row, cellOut.isecObjectToWorldEXT[i]))
1685                         {
1686                             log << tcu::TestLog::Message << "Bad isecObjectToWorldEXT[" << i << "] at (" << x << ", "
1687                                 << y << "): expected " << row << " found " << cellOut.isecObjectToWorldEXT[i]
1688                                 << tcu::TestLog::EndMessage;
1689                             fail = true;
1690                         }
1691                     }
1692                     for (uint32_t i = 0u; i < de::arrayLength(cellIn.transformMatrix.matrix); ++i)
1693                     {
1694                         const tcu::Vec4 expected(
1695                             cellIn.transformMatrix.matrix[i][0], cellIn.transformMatrix.matrix[i][1],
1696                             cellIn.transformMatrix.matrix[i][2], cellIn.transformMatrix.matrix[i][3]);
1697                         const tcu::Vec4 result(
1698                             cellOut.isecObjectToWorld3x4EXT[0][i], cellOut.isecObjectToWorld3x4EXT[1][i],
1699                             cellOut.isecObjectToWorld3x4EXT[2][i], cellOut.isecObjectToWorld3x4EXT[3][i]);
1700                         if (!floatEqual(expected, result))
1701                         {
1702                             log << tcu::TestLog::Message << "Bad isecObjectToWorld3x4EXT[][" << i << "] at (" << x
1703                                 << ", " << y << "): expected " << expected << " found " << result
1704                                 << tcu::TestLog::EndMessage;
1705                             fail = true;
1706                         }
1707                     }
1708                     for (uint32_t i = 0u; i < de::arrayLength(cellIn.transformMatrix.matrix); ++i)
1709                     {
1710                         // Note W column is negative to undo the translation.
1711                         const tcu::Vec4 row(cellIn.transformMatrix.matrix[i][0], cellIn.transformMatrix.matrix[i][1],
1712                                             cellIn.transformMatrix.matrix[i][2], -cellIn.transformMatrix.matrix[i][3]);
1713                         if (!floatEqual(row, cellOut.isecWorldToObjectEXT[i]))
1714                         {
1715                             log << tcu::TestLog::Message << "Bad isecWorldToObjectEXT[" << i << "] at (" << x << ", "
1716                                 << y << "): expected " << row << " found " << cellOut.isecWorldToObjectEXT[i]
1717                                 << tcu::TestLog::EndMessage;
1718                             fail = true;
1719                         }
1720                     }
1721                     for (uint32_t i = 0u; i < de::arrayLength(cellIn.transformMatrix.matrix); ++i)
1722                     {
1723                         // Note W column is negative to undo the translation.
1724                         const tcu::Vec4 expected(
1725                             cellIn.transformMatrix.matrix[i][0], cellIn.transformMatrix.matrix[i][1],
1726                             cellIn.transformMatrix.matrix[i][2], -cellIn.transformMatrix.matrix[i][3]);
1727                         const tcu::Vec4 result(
1728                             cellOut.isecWorldToObject3x4EXT[0][i], cellOut.isecWorldToObject3x4EXT[1][i],
1729                             cellOut.isecWorldToObject3x4EXT[2][i], cellOut.isecWorldToObject3x4EXT[3][i]);
1730                         if (!floatEqual(expected, result))
1731                         {
1732                             log << tcu::TestLog::Message << "Bad isecWorldToObject3x4EXT[][" << i << "] at (" << x
1733                                 << ", " << y << "): expected " << expected << " found " << result
1734                                 << tcu::TestLog::EndMessage;
1735                             fail = true;
1736                         }
1737                     }
1738 
1739                     if (cellOut.isecAttribute != hitAttribute)
1740                     {
1741                         log << tcu::TestLog::Message << "Bad isecAttribute at (" << x << ", " << y << "): expected "
1742                             << hitAttribute << " found " << cellOut.isecAttribute << tcu::TestLog::EndMessage;
1743                         fail = true;
1744                     }
1745                     if (cellOut.chitAttribute != hitAttribute)
1746                     {
1747                         log << tcu::TestLog::Message << "Bad chitAttribute at (" << x << ", " << y << "): expected "
1748                             << hitAttribute << " found " << cellOut.chitAttribute << tcu::TestLog::EndMessage;
1749                         fail = true;
1750                     }
1751 
1752                     if (withSRB)
1753                     {
1754                         const auto srb = sbt.getHitsSRB(blasInfo.activeGeometryIndex);
1755                         if (cellOut.isecSRB != srb)
1756                         {
1757                             log << tcu::TestLog::Message << "Bad isecSRB at (" << x << ", " << y << "): expected "
1758                                 << srb << " found " << cellOut.isecSRB << tcu::TestLog::EndMessage;
1759                             fail = true;
1760                         }
1761                     }
1762                 }
1763 
1764                 // Closest-hit shader.
1765                 if (cellOut.chitLaunchIDEXT != launchID)
1766                 {
1767                     log << tcu::TestLog::Message << "Bad chitLaunchIDEXT at (" << x << ", " << y << "): expected "
1768                         << launchID << " found " << cellOut.chitLaunchIDEXT << tcu::TestLog::EndMessage;
1769                     fail = true;
1770                 }
1771                 if (cellOut.chitLaunchSizeEXT != launchSize)
1772                 {
1773                     log << tcu::TestLog::Message << "Bad chitLaunchSizeEXT at (" << x << ", " << y << "): expected "
1774                         << launchSize << " found " << cellOut.chitLaunchSizeEXT << tcu::TestLog::EndMessage;
1775                     fail = true;
1776                 }
1777 
1778                 if (cellOut.chitPrimitiveID != primitiveID)
1779                 {
1780                     log << tcu::TestLog::Message << "Bad chitPrimitiveID at (" << x << ", " << y << "): expected "
1781                         << primitiveID << " found " << cellOut.chitPrimitiveID << tcu::TestLog::EndMessage;
1782                     fail = true;
1783                 }
1784                 if (cellOut.chitInstanceID != instanceID)
1785                 {
1786                     log << tcu::TestLog::Message << "Bad chitInstanceID at (" << x << ", " << y << "): expected "
1787                         << instanceID << " found " << cellOut.chitInstanceID << tcu::TestLog::EndMessage;
1788                     fail = true;
1789                 }
1790                 if (cellOut.chitInstanceCustomIndexEXT != instanceCustomIndex)
1791                 {
1792                     log << tcu::TestLog::Message << "Bad chitInstanceCustomIndexEXT at (" << x << ", " << y
1793                         << "): expected " << instanceCustomIndex << " found " << cellOut.chitInstanceCustomIndexEXT
1794                         << tcu::TestLog::EndMessage;
1795                     fail = true;
1796                 }
1797                 if (cellOut.chitGeometryIndexEXT != geometryIndex)
1798                 {
1799                     log << tcu::TestLog::Message << "Bad chitGeometryIndexEXT at (" << x << ", " << y << "): expected "
1800                         << geometryIndex << " found " << cellOut.chitGeometryIndexEXT << tcu::TestLog::EndMessage;
1801                     fail = true;
1802                 }
1803                 if (cellOut.chitWorldRayOriginEXT != origin)
1804                 {
1805                     log << tcu::TestLog::Message << "Bad chitWorldRayOriginEXT at (" << x << ", " << y << "): expected "
1806                         << origin << " found " << cellOut.chitWorldRayOriginEXT << tcu::TestLog::EndMessage;
1807                     fail = true;
1808                 }
1809                 if (cellOut.chitWorldRayDirectionEXT != direction)
1810                 {
1811                     log << tcu::TestLog::Message << "Bad chitWorldRayDirectionEXT at (" << x << ", " << y
1812                         << "): expected " << direction << " found " << cellOut.chitWorldRayDirectionEXT
1813                         << tcu::TestLog::EndMessage;
1814                     fail = true;
1815                 }
1816                 if (!floatEqual(cellOut.chitObjectRayOriginEXT, objectRayOrigin))
1817                 {
1818                     log << tcu::TestLog::Message << "Bad chitObjectRayOriginEXT at (" << x << ", " << y
1819                         << "): expected " << objectRayOrigin << " found " << cellOut.chitObjectRayOriginEXT
1820                         << tcu::TestLog::EndMessage;
1821                     fail = true;
1822                 }
1823                 if (!floatEqual(cellOut.chitObjectRayDirectionEXT, direction))
1824                 {
1825                     log << tcu::TestLog::Message << "Bad chitObjectRayDirectionEXT at (" << x << ", " << y
1826                         << "): expected " << direction << " found " << cellOut.chitObjectRayDirectionEXT
1827                         << tcu::TestLog::EndMessage;
1828                     fail = true;
1829                 }
1830                 if (cellOut.chitRayTminEXT != cellIn.minT)
1831                 {
1832                     log << tcu::TestLog::Message << "Bad chitRayTminEXT at (" << x << ", " << y << "): expected "
1833                         << cellIn.minT << " found " << cellOut.chitRayTminEXT << tcu::TestLog::EndMessage;
1834                     fail = true;
1835                 }
1836                 if (!floatEqual(cellOut.chitRayTmaxEXT, tMaxAtIsec))
1837                 {
1838                     log << tcu::TestLog::Message << "Bad chitRayTmaxEXT at (" << x << ", " << y << "): expected "
1839                         << tMaxAtIsec << " found " << cellOut.chitRayTmaxEXT << tcu::TestLog::EndMessage;
1840                     fail = true;
1841                 }
1842                 if (cellOut.chitIncomingRayFlagsEXT != cellIn.rayFlags)
1843                 {
1844                     log << tcu::TestLog::Message << "Bad chitIncomingRayFlagsEXT at (" << x << ", " << y
1845                         << "): expected " << cellIn.rayFlags << " found " << cellOut.chitIncomingRayFlagsEXT
1846                         << tcu::TestLog::EndMessage;
1847                     fail = true;
1848                 }
1849                 if (!floatEqual(cellOut.chitHitTEXT, tMaxAtIsec))
1850                 {
1851                     log << tcu::TestLog::Message << "Bad chitHitTEXT at (" << x << ", " << y << "): expected "
1852                         << tMaxAtIsec << " found " << cellOut.chitHitTEXT << tcu::TestLog::EndMessage;
1853                     fail = true;
1854                 }
1855                 if (cellOut.chitHitKindEXT != hitKind)
1856                 {
1857                     log << tcu::TestLog::Message << "Bad chitHitKindEXT at (" << x << ", " << y << "): expected "
1858                         << hitKind << " found " << cellOut.chitHitKindEXT << tcu::TestLog::EndMessage;
1859                     fail = true;
1860                 }
1861                 for (uint32_t i = 0u; i < de::arrayLength(cellIn.transformMatrix.matrix); ++i)
1862                 {
1863                     const tcu::Vec4 row(cellIn.transformMatrix.matrix[i][0], cellIn.transformMatrix.matrix[i][1],
1864                                         cellIn.transformMatrix.matrix[i][2], cellIn.transformMatrix.matrix[i][3]);
1865                     if (!floatEqual(row, cellOut.chitObjectToWorldEXT[i]))
1866                     {
1867                         log << tcu::TestLog::Message << "Bad chitObjectToWorldEXT[" << i << "] at (" << x << ", " << y
1868                             << "): expected " << row << " found " << cellOut.chitObjectToWorldEXT[i]
1869                             << tcu::TestLog::EndMessage;
1870                         fail = true;
1871                     }
1872                 }
1873                 for (uint32_t i = 0u; i < de::arrayLength(cellIn.transformMatrix.matrix); ++i)
1874                 {
1875                     const tcu::Vec4 expected(cellIn.transformMatrix.matrix[i][0], cellIn.transformMatrix.matrix[i][1],
1876                                              cellIn.transformMatrix.matrix[i][2], cellIn.transformMatrix.matrix[i][3]);
1877                     const tcu::Vec4 result(cellOut.chitObjectToWorld3x4EXT[0][i], cellOut.chitObjectToWorld3x4EXT[1][i],
1878                                            cellOut.chitObjectToWorld3x4EXT[2][i],
1879                                            cellOut.chitObjectToWorld3x4EXT[3][i]);
1880                     if (!floatEqual(expected, result))
1881                     {
1882                         log << tcu::TestLog::Message << "Bad chitObjectToWorld3x4EXT[][" << i << "] at (" << x << ", "
1883                             << y << "): expected " << expected << " found " << result << tcu::TestLog::EndMessage;
1884                         fail = true;
1885                     }
1886                 }
1887                 for (uint32_t i = 0u; i < de::arrayLength(cellIn.transformMatrix.matrix); ++i)
1888                 {
1889                     // Note W column is negative to undo the translation.
1890                     const tcu::Vec4 row(cellIn.transformMatrix.matrix[i][0], cellIn.transformMatrix.matrix[i][1],
1891                                         cellIn.transformMatrix.matrix[i][2], -cellIn.transformMatrix.matrix[i][3]);
1892                     if (!floatEqual(row, cellOut.chitWorldToObjectEXT[i]))
1893                     {
1894                         log << tcu::TestLog::Message << "Bad chitWorldToObjectEXT[" << i << "] at (" << x << ", " << y
1895                             << "): expected " << row << " found " << cellOut.chitWorldToObjectEXT[i]
1896                             << tcu::TestLog::EndMessage;
1897                         fail = true;
1898                     }
1899                 }
1900                 for (uint32_t i = 0u; i < de::arrayLength(cellIn.transformMatrix.matrix); ++i)
1901                 {
1902                     // Note W column is negative to undo the translation.
1903                     const tcu::Vec4 expected(cellIn.transformMatrix.matrix[i][0], cellIn.transformMatrix.matrix[i][1],
1904                                              cellIn.transformMatrix.matrix[i][2], -cellIn.transformMatrix.matrix[i][3]);
1905                     const tcu::Vec4 result(cellOut.chitWorldToObject3x4EXT[0][i], cellOut.chitWorldToObject3x4EXT[1][i],
1906                                            cellOut.chitWorldToObject3x4EXT[2][i],
1907                                            cellOut.chitWorldToObject3x4EXT[3][i]);
1908                     if (!floatEqual(expected, result))
1909                     {
1910                         log << tcu::TestLog::Message << "Bad chitWorldToObject3x4EXT[][" << i << "] at (" << x << ", "
1911                             << y << "): expected " << expected << " found " << result << tcu::TestLog::EndMessage;
1912                         fail = true;
1913                     }
1914                 }
1915 
1916                 if (withSRB)
1917                 {
1918                     const auto srb = sbt.getHitsSRB(blasInfo.activeGeometryIndex);
1919                     if (cellOut.chitSRB != srb)
1920                     {
1921                         log << tcu::TestLog::Message << "Bad chitSRB at (" << x << ", " << y << "): expected " << srb
1922                             << " found " << cellOut.chitSRB << tcu::TestLog::EndMessage;
1923                         fail = true;
1924                     }
1925                 }
1926 
1927                 // Call shaders.
1928                 if (cellOut.callLaunchIDEXT != launchID)
1929                 {
1930                     log << tcu::TestLog::Message << "Bad callLaunchIDEXT at (" << x << ", " << y << "): expected "
1931                         << launchID << " found " << cellOut.callLaunchIDEXT << tcu::TestLog::EndMessage;
1932                     fail = true;
1933                 }
1934                 if (cellOut.callLaunchSizeEXT != launchSize)
1935                 {
1936                     log << tcu::TestLog::Message << "Bad callLaunchSizeEXT at (" << x << ", " << y << "): expected "
1937                         << launchSize << " found " << cellOut.callLaunchSizeEXT << tcu::TestLog::EndMessage;
1938                     fail = true;
1939                 }
1940 
1941                 if (cellOut.chitIncomingPayload != chitIncomingPayload)
1942                 {
1943                     log << tcu::TestLog::Message << "Bad chitIncomingPayload at (" << x << ", " << y << "): expected "
1944                         << chitIncomingPayload << " found " << cellOut.chitIncomingPayload << tcu::TestLog::EndMessage;
1945                     fail = true;
1946                 }
1947 
1948                 if (cellOut.chitPayload != payload)
1949                 {
1950                     log << tcu::TestLog::Message << "Bad chitPayload at (" << x << ", " << y << "): expected "
1951                         << payload << " found " << cellOut.chitPayload << tcu::TestLog::EndMessage;
1952                     fail = true;
1953                 }
1954 
1955                 if (withSRB)
1956                 {
1957                     const auto srb0 = sbt.getCallSRB(0u);
1958                     if (cellOut.call0SRB != srb0)
1959                     {
1960                         log << tcu::TestLog::Message << "Bad call0SRB at (" << x << ", " << y << "): expected " << srb0
1961                             << " found " << cellOut.call0SRB << tcu::TestLog::EndMessage;
1962                         fail = true;
1963                     }
1964 
1965                     const auto srb1 = sbt.getCallSRB(1u);
1966                     if (cellOut.call1SRB != srb1)
1967                     {
1968                         log << tcu::TestLog::Message << "Bad call1SRB at (" << x << ", " << y << "): expected " << srb1
1969                             << " found " << cellOut.call1SRB << tcu::TestLog::EndMessage;
1970                         fail = true;
1971                     }
1972                 }
1973             }
1974 
1975             if (cellOut.rgenFinalPayload != payload)
1976             {
1977                 log << tcu::TestLog::Message << "Bad rgenFinalPayload at (" << x << ", " << y << "): expected "
1978                     << payload << " found " << cellOut.rgenFinalPayload << tcu::TestLog::EndMessage;
1979                 fail = true;
1980             }
1981         }
1982 
1983     if (fail)
1984         return tcu::TestStatus::fail("Fail; check log for details");
1985     return tcu::TestStatus::pass("Pass");
1986 }
1987 
1988 } // namespace
1989 
createDGCRayTracingTestsExt(tcu::TestContext & testCtx)1990 tcu::TestCaseGroup *createDGCRayTracingTestsExt(tcu::TestContext &testCtx)
1991 {
1992     using GroupPtr = de::MovePtr<tcu::TestCaseGroup>;
1993     GroupPtr mainGroup(new tcu::TestCaseGroup(testCtx, "ray_tracing", ""));
1994 
1995     for (const bool useExecutionSet : {false, true})
1996         for (const bool preprocess : {false, true})
1997             for (const bool unordered : {false, true})
1998                 for (const bool computeQueue : {false, true})
1999                 {
2000                     const RayTracingInstance::Params params{useExecutionSet, preprocess, unordered, computeQueue};
2001                     const auto testName = std::string(useExecutionSet ? "with_execution_set" : "no_execution_set") +
2002                                           (preprocess ? "_preprocess" : "") + (unordered ? "_unordered" : "") +
2003                                           (computeQueue ? "_cq" : "");
2004                     mainGroup->addChild(new RayTracingCase(testCtx, testName, params));
2005                 }
2006 
2007     return mainGroup.release();
2008 }
2009 
2010 } // namespace DGC
2011 } // namespace vkt
2012