1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2021 The Khronos Group Inc.
6 * Copyright (c) 2021 Valve Corporation.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
 * \brief Tests using non-uniform arguments with traceRayEXT().
23 *//*--------------------------------------------------------------------*/
24
25 #include "vktRayTracingNonUniformArgsTests.hpp"
26 #include "vktTestCase.hpp"
27
28 #include "vkRayTracingUtil.hpp"
29 #include "vkObjUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34
35 #include "tcuTestLog.hpp"
36
37 #include <vector>
38 #include <iostream>
39
40 namespace vkt
41 {
42 namespace RayTracing
43 {
44 namespace
45 {
46
47 using namespace vk;
48
// Reasons for a ray to end up in the miss shader. Every value other than NONE names the
// traceRayEXT() argument that receives a "bad" value so the ray does not hit the geometry.
enum class MissCause
{
	NONE = 0,		// No argument is altered.
	FLAGS,			// Bad ray flags.
	CULL_MASK,		// Bad cull mask.
	ORIGIN,			// Bad ray origin.
	TMIN,			// Bad minimum ray distance.
	DIRECTION,		// Bad ray direction.
	TMAX,			// Bad maximum ray distance.
	CAUSE_COUNT,	// Number of values above; used as an iteration bound, not as a real cause.
};
61
62 struct NonUniformParams
63 {
64 bool miss;
65
66 struct
67 {
68 deUint32 rayTypeCount;
69 deUint32 rayType;
70 } hitParams;
71
72 struct
73 {
74 MissCause missCause;
75 deUint32 missIndex;
76 } missParams;
77 };
78
79 class NonUniformArgsCase : public TestCase
80 {
81 public:
82 NonUniformArgsCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const NonUniformParams& params);
~NonUniformArgsCase(void)83 virtual ~NonUniformArgsCase (void) {}
84
85 virtual void checkSupport (Context& context) const;
86 virtual void initPrograms (vk::SourceCollections& programCollection) const;
87 virtual TestInstance* createInstance (Context& context) const;
88
89 protected:
90 NonUniformParams m_params;
91 };
92
93 class NonUniformArgsInstance : public TestInstance
94 {
95 public:
96 NonUniformArgsInstance (Context& context, const NonUniformParams& params);
~NonUniformArgsInstance(void)97 virtual ~NonUniformArgsInstance (void) {}
98
99 virtual tcu::TestStatus iterate (void);
100
101 protected:
102 NonUniformParams m_params;
103 };
104
NonUniformArgsCase(tcu::TestContext & testCtx,const std::string & name,const std::string & description,const NonUniformParams & params)105 NonUniformArgsCase::NonUniformArgsCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const NonUniformParams& params)
106 : TestCase (testCtx, name, description)
107 , m_params (params)
108 {}
109
checkSupport(Context & context) const110 void NonUniformArgsCase::checkSupport (Context& context) const
111 {
112 context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
113 context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
114 }
115
116 struct ArgsBufferData
117 {
118 tcu::Vec4 origin;
119 tcu::Vec4 direction;
120 float Tmin;
121 float Tmax;
122 deUint32 rayFlags;
123 deUint32 cullMask;
124 deUint32 sbtRecordOffset;
125 deUint32 sbtRecordStride;
126 deUint32 missIndex;
127 };
128
initPrograms(vk::SourceCollections & programCollection) const129 void NonUniformArgsCase::initPrograms (vk::SourceCollections& programCollection) const
130 {
131 const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
132
133 std::ostringstream descriptors;
134 descriptors
135 << "layout(set=0, binding=0) uniform accelerationStructureEXT topLevelAS;\n"
136 << "layout(set=0, binding=1, std430) buffer ArgumentsBlock {\n" // Must match ArgsBufferData.
137 << " vec4 origin;\n"
138 << " vec4 direction;\n"
139 << " float Tmin;\n"
140 << " float Tmax;\n"
141 << " uint rayFlags;\n"
142 << " uint cullMask;\n"
143 << " uint sbtRecordOffset;\n"
144 << " uint sbtRecordStride;\n"
145 << " uint missIndex;\n"
146 << "} args;\n"
147 << "layout(set=0, binding=2, std430) buffer ResultBlock {\n"
148 << " uint shaderId;\n"
149 << "} result;\n"
150 ;
151 const auto descriptorsStr = descriptors.str();
152
153 std::ostringstream rgen;
154 rgen
155 << "#version 460 core\n"
156 << "#extension GL_EXT_ray_tracing : require\n"
157 << "\n"
158 << descriptorsStr
159 << "layout(location=0) rayPayloadEXT vec4 unused;\n"
160 << "\n"
161 << "void main()\n"
162 << "{\n"
163 << " traceRayEXT(topLevelAS,\n"
164 << " args.rayFlags,\n"
165 << " args.cullMask,\n"
166 << " args.sbtRecordOffset,\n"
167 << " args.sbtRecordStride,\n"
168 << " args.missIndex,\n"
169 << " args.origin.xyz,\n"
170 << " args.Tmin,\n"
171 << " args.direction.xyz,\n"
172 << " args.Tmax,\n"
173 << " 0);\n"
174 << "}\n"
175 ;
176
177 std::ostringstream chit;
178 chit
179 << "#version 460 core\n"
180 << "#extension GL_EXT_ray_tracing : require\n"
181 << "\n"
182 << descriptorsStr
183 << "layout(constant_id=0) const uint chitShaderId = 0;\n"
184 << "layout(location=0) rayPayloadInEXT vec4 unused;\n"
185 << "\n"
186 << "void main()\n"
187 << "{\n"
188 << " result.shaderId = chitShaderId;\n"
189 << "}\n"
190 ;
191
192 std::ostringstream miss;
193 miss
194 << "#version 460 core\n"
195 << "#extension GL_EXT_ray_tracing : require\n"
196 << "\n"
197 << descriptorsStr
198 << "layout(constant_id=0) const uint missShaderId = 0;\n"
199 << "layout(location=0) rayPayloadInEXT vec4 unused;\n"
200 << "\n"
201 << "void main()\n"
202 << "{\n"
203 << " result.shaderId = missShaderId;\n"
204 << "}\n"
205 ;
206
207 programCollection.glslSources.add("rgen") << glu::RaygenSource(rgen.str()) << buildOptions;
208 programCollection.glslSources.add("chit") << glu::ClosestHitSource(chit.str()) << buildOptions;
209 programCollection.glslSources.add("miss") << glu::MissSource(miss.str()) << buildOptions;
210 }
211
createInstance(Context & context) const212 TestInstance* NonUniformArgsCase::createInstance (Context& context) const
213 {
214 return new NonUniformArgsInstance(context, m_params);
215 }
216
NonUniformArgsInstance(Context & context,const NonUniformParams & params)217 NonUniformArgsInstance::NonUniformArgsInstance (Context& context, const NonUniformParams& params)
218 : TestInstance (context)
219 , m_params (params)
220 {}
221
joinMostLeast(deUint32 most,deUint32 least)222 deUint32 joinMostLeast (deUint32 most, deUint32 least)
223 {
224 constexpr auto kMaxUint16 = static_cast<deUint32>(std::numeric_limits<deUint16>::max());
225 DE_UNREF(kMaxUint16); // For release builds.
226 DE_ASSERT(most <= kMaxUint16 && least <= kMaxUint16);
227 return ((most << 16) | least);
228 }
229
makeMissId(deUint32 missIndex)230 deUint32 makeMissId (deUint32 missIndex)
231 {
232 // 1 on the highest 16 bits for miss shaders.
233 return joinMostLeast(1u, missIndex);
234 }
235
makeChitId(deUint32 chitIndex)236 deUint32 makeChitId (deUint32 chitIndex)
237 {
238 // 2 on the highest 16 bits for closest hit shaders.
239 return joinMostLeast(2u, chitIndex);
240 }
241
iterate(void)242 tcu::TestStatus NonUniformArgsInstance::iterate (void)
243 {
244 const auto& vki = m_context.getInstanceInterface();
245 const auto physDev = m_context.getPhysicalDevice();
246 const auto& vkd = m_context.getDeviceInterface();
247 const auto device = m_context.getDevice();
248 auto& alloc = m_context.getDefaultAllocator();
249 const auto qIndex = m_context.getUniversalQueueFamilyIndex();
250 const auto queue = m_context.getUniversalQueue();
251 const auto stages = (VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR);
252
253 // Geometry data constants.
254 const std::vector<tcu::Vec3> kOffscreenTriangle =
255 {
256 // Triangle around (x=0, y=2) z=-5
257 tcu::Vec3( 0.0f, 2.5f, -5.0f),
258 tcu::Vec3(-0.5f, 1.5f, -5.0f),
259 tcu::Vec3( 0.5f, 1.5f, -5.0f),
260 };
261 const std::vector<tcu::Vec3> kOnscreenTriangle =
262 {
263 // Triangle around (x=0, y=2) z=5
264 tcu::Vec3( 0.0f, 2.5f, 5.0f),
265 tcu::Vec3(-0.5f, 1.5f, 5.0f),
266 tcu::Vec3( 0.5f, 1.5f, 5.0f),
267 };
268 const tcu::Vec4 kGoodOrigin (0.0f, 2.0f, 0.0f, 0.0f); // Around (x=0, y=2) z=0.
269 const tcu::Vec4 kBadOrigin (0.0f, 8.0f, 0.0f, 0.0f); // Too high, around (x=0, y=8) depth 0.
270 const tcu::Vec4 kGoodDirection (0.0f, 0.0f, 1.0f, 0.0f); // Towards +z.
271 const tcu::Vec4 kBadDirection (1.0f, 0.0f, 0.0f, 0.0f); // Towards +x.
272 const float kGoodTmin = 4.0f; // Good to travel from z=0 to z=5.
273 const float kGoodTmax = 6.0f; // Ditto.
274 const float kBadTmin = 5.5f; // Tmin after triangle.
275 const float kBadTmax = 4.5f; // Tmax before triangle.
276 const deUint32 kGoodFlags = 0u; // MaskNone
277 const deUint32 kBadFlags = 256u; // SkipTrianglesKHR
278 const deUint32 kGoodCullMask = 0x0Fu; // Matches instance.
279 const deUint32 kBadCullMask = 0xF0u; // Does not match instance.
280
281 // Command pool and buffer.
282 const auto cmdPool = makeCommandPool(vkd, device, qIndex);
283 const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
284 const auto cmdBuffer = cmdBufferPtr.get();
285
286 beginCommandBuffer(vkd, cmdBuffer);
287
288 // Build acceleration structures.
289 auto topLevelAS = makeTopLevelAccelerationStructure();
290 auto bottomLevelAS = makeBottomLevelAccelerationStructure();
291
292 // Putting the offscreen triangle first makes sure hits have a geometryIndex=1, meaning sbtRecordStride matters.
293 std::vector<const std::vector<tcu::Vec3>*> geometries;
294 geometries.push_back(&kOffscreenTriangle);
295 geometries.push_back(&kOnscreenTriangle);
296
297 for (const auto& geometryPtr : geometries)
298 bottomLevelAS->addGeometry(*geometryPtr, true /* is triangles */);
299
300 bottomLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
301
302 de::SharedPtr<BottomLevelAccelerationStructure> blasSharedPtr (bottomLevelAS.release());
303 topLevelAS->setInstanceCount(1);
304 topLevelAS->addInstance(blasSharedPtr, identityMatrix3x4, 0u, kGoodCullMask, 0u, VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR);
305 topLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
306
307 // Input storage buffer.
308 const auto inputBufferSize = static_cast<VkDeviceSize>(sizeof(ArgsBufferData));
309 const auto inputBufferInfo = makeBufferCreateInfo(inputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
310 BufferWithMemory inputBuffer (vkd, device, alloc, inputBufferInfo, MemoryRequirement::HostVisible);
311 auto& inputBufferAlloc = inputBuffer.getAllocation();
312
313 // Output storage buffer.
314 const auto outputBufferSize = static_cast<VkDeviceSize>(sizeof(deUint32));
315 const auto outputBufferInfo = makeBufferCreateInfo(outputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
316 BufferWithMemory outputBuffer (vkd, device, alloc, outputBufferInfo, MemoryRequirement::HostVisible);
317 auto& outputBufferAlloc = outputBuffer.getAllocation();
318
319 // Fill output buffer with an initial value.
320 deMemset(outputBufferAlloc.getHostPtr(), 0, static_cast<size_t>(outputBufferSize));
321 flushAlloc(vkd, device, outputBufferAlloc);
322
323 // Descriptor set layout and pipeline layout.
324 DescriptorSetLayoutBuilder setLayoutBuilder;
325 setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, stages);
326 setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
327 setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
328 const auto setLayout = setLayoutBuilder.build(vkd, device);
329 const auto pipelineLayout = makePipelineLayout(vkd, device, setLayout.get());
330
331 // Descriptor pool and set.
332 DescriptorPoolBuilder poolBuilder;
333 poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
334 poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 2u);
335 const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
336 const auto descriptorSet = makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());
337
338 // Update descriptor set.
339 {
340 const VkWriteDescriptorSetAccelerationStructureKHR accelDescInfo =
341 {
342 VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
343 nullptr,
344 1u,
345 topLevelAS.get()->getPtr(),
346 };
347
348 const auto inputBufferDescInfo = makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
349 const auto outputBufferDescInfo = makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
350
351 DescriptorSetUpdateBuilder updateBuilder;
352 updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelDescInfo);
353 updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescInfo);
354 updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(2u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescInfo);
355 updateBuilder.update(vkd, device);
356 }
357
358 // Shader modules.
359 auto rgenModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0));
360 auto missModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("miss"), 0));
361 auto chitModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0));
362
363 // Get some ray tracing properties.
364 deUint32 shaderGroupHandleSize = 0u;
365 deUint32 shaderGroupBaseAlignment = 1u;
366 {
367 const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physDev);
368 shaderGroupHandleSize = rayTracingPropertiesKHR->getShaderGroupHandleSize();
369 shaderGroupBaseAlignment = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
370 }
371
372 // Create raytracing pipeline and shader binding tables.
373 Move<VkPipeline> pipeline;
374
375 de::MovePtr<BufferWithMemory> raygenSBT;
376 de::MovePtr<BufferWithMemory> missSBT;
377 de::MovePtr<BufferWithMemory> hitSBT;
378 de::MovePtr<BufferWithMemory> callableSBT;
379
380 VkStridedDeviceAddressRegionKHR raygenSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
381 VkStridedDeviceAddressRegionKHR missSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
382 VkStridedDeviceAddressRegionKHR hitSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
383 VkStridedDeviceAddressRegionKHR callableSBTRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
384
385 // Generate ids for the closest hit and miss shaders according to the test parameters.
386 DE_ASSERT(m_params.hitParams.rayTypeCount > 0u);
387 DE_ASSERT(m_params.hitParams.rayType < m_params.hitParams.rayTypeCount);
388 DE_ASSERT(geometries.size() > 0u);
389
390 std::vector<deUint32> missShaderIds;
391 for (deUint32 missIdx = 0; missIdx <= m_params.missParams.missIndex; ++missIdx)
392 missShaderIds.push_back(makeMissId(missIdx));
393
394 deUint32 chitCounter = 0u;
395 std::vector<deUint32> chitShaderIds;
396
397 for (size_t geoIdx = 0; geoIdx < geometries.size(); ++geoIdx)
398 for (deUint32 rayIdx = 0; rayIdx < m_params.hitParams.rayTypeCount; ++rayIdx)
399 chitShaderIds.push_back(makeChitId(chitCounter++));
400
401 {
402 const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
403 const VkSpecializationMapEntry specializationMapEntry =
404 {
405 0u, // deUint32 constantID;
406 0u, // deUint32 offset;
407 static_cast<deUintptr>(sizeof(deUint32)), // deUintptr size;
408 };
409 VkSpecializationInfo specInfo =
410 {
411 1u, // deUint32 mapEntryCount;
412 &specializationMapEntry, // const VkSpecializationMapEntry* pMapEntries;
413 static_cast<deUintptr>(sizeof(deUint32)), // deUintptr dataSize;
414 nullptr, // const void* pData;
415 };
416
417 std::vector<VkSpecializationInfo> specInfos;
418 specInfos.reserve(missShaderIds.size() + chitShaderIds.size());
419
420 deUint32 shaderGroupIdx = 0u;
421 rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, rgenModule, shaderGroupIdx++);
422
423 for (size_t missIdx = 0; missIdx < missShaderIds.size(); ++missIdx)
424 {
425 specInfo.pData = &missShaderIds.at(missIdx);
426 specInfos.push_back(specInfo);
427 rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, missModule, shaderGroupIdx++, &specInfos.back());
428 }
429
430 const auto firstChitGroup = shaderGroupIdx;
431
432 for (size_t chitIdx = 0; chitIdx < chitShaderIds.size(); ++chitIdx)
433 {
434 specInfo.pData = &chitShaderIds.at(chitIdx);
435 specInfos.push_back(specInfo);
436 rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, chitModule, shaderGroupIdx++, &specInfos.back());
437 }
438
439 pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
440
441 raygenSBT = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0u, 1u);
442 raygenSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenSBT->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
443
444 missSBT = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1u, static_cast<deUint32>(missShaderIds.size()));
445 missSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missSBT->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize * missShaderIds.size());
446
447 hitSBT = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, firstChitGroup, static_cast<deUint32>(chitShaderIds.size()));
448 hitSBTRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitSBT->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize * chitShaderIds.size());
449 }
450
451 // Fill input buffer values.
452 {
453 DE_ASSERT(!(m_params.miss && m_params.missParams.missCause == MissCause::NONE));
454
455 const ArgsBufferData argsBufferData =
456 {
457 ((m_params.miss && m_params.missParams.missCause == MissCause::ORIGIN) ? kBadOrigin : kGoodOrigin),
458 ((m_params.miss && m_params.missParams.missCause == MissCause::DIRECTION) ? kBadDirection : kGoodDirection),
459 ((m_params.miss && m_params.missParams.missCause == MissCause::TMIN) ? kBadTmin : kGoodTmin),
460 ((m_params.miss && m_params.missParams.missCause == MissCause::TMAX) ? kBadTmax : kGoodTmax),
461 ((m_params.miss && m_params.missParams.missCause == MissCause::FLAGS) ? kBadFlags : kGoodFlags),
462 ((m_params.miss && m_params.missParams.missCause == MissCause::CULL_MASK) ? kBadCullMask : kGoodCullMask),
463 m_params.hitParams.rayType,
464 m_params.hitParams.rayTypeCount,
465 m_params.missParams.missIndex,
466 };
467
468 deMemcpy(inputBufferAlloc.getHostPtr(), &argsBufferData, sizeof(argsBufferData));
469 flushAlloc(vkd, device, inputBufferAlloc);
470 }
471
472 // Trace rays.
473 vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
474 vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u, &descriptorSet.get(), 0u, nullptr);
475 vkd.cmdTraceRaysKHR(cmdBuffer, &raygenSBTRegion, &missSBTRegion, &hitSBTRegion, &callableSBTRegion, 1u, 1u, 1u);
476
477 // Barrier for the output buffer.
478 const auto bufferBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
479 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u, &bufferBarrier, 0u, nullptr, 0u, nullptr);
480
481 endCommandBuffer(vkd, cmdBuffer);
482 submitCommandsAndWait(vkd, device, queue, cmdBuffer);
483
484 // Check output value.
485 invalidateAlloc(vkd, device, outputBufferAlloc);
486 deUint32 outputVal = std::numeric_limits<deUint32>::max();
487 deMemcpy(&outputVal, outputBufferAlloc.getHostPtr(), sizeof(outputVal));
488 const auto expectedVal = (m_params.miss ? makeMissId(m_params.missParams.missIndex) : makeChitId(m_params.hitParams.rayTypeCount + m_params.hitParams.rayType));
489
490 std::ostringstream msg;
491 msg << "Output value: 0x" << std::hex << outputVal << " (expected 0x" << expectedVal << ")";
492
493 if (outputVal != expectedVal)
494 return tcu::TestStatus::fail(msg.str());
495
496 auto& log = m_context.getTestContext().getLog();
497 log << tcu::TestLog::Message << msg.str() << tcu::TestLog::EndMessage;
498
499 return tcu::TestStatus::pass("Pass");
500 }
501
502 } // anonymous
503
createNonUniformArgsTests(tcu::TestContext & testCtx)504 tcu::TestCaseGroup* createNonUniformArgsTests (tcu::TestContext& testCtx)
505 {
506 de::MovePtr<tcu::TestCaseGroup> nonUniformGroup(new tcu::TestCaseGroup(testCtx, "non_uniform_args", "Test non-uniform arguments in traceRayExt()"));
507
508 // Closest hit cases.
509 {
510 NonUniformParams params;
511 params.miss = false;
512 params.missParams.missIndex = 0u;
513 params.missParams.missCause = MissCause::NONE;
514
515 for (deUint32 typeCount = 1u; typeCount <= 4u; ++typeCount)
516 {
517 params.hitParams.rayTypeCount = typeCount;
518 for (deUint32 rayType = 0u; rayType < typeCount; ++rayType)
519 {
520 params.hitParams.rayType = rayType;
521 nonUniformGroup->addChild(new NonUniformArgsCase(testCtx, "chit_" + de::toString(typeCount) + "_types_" + de::toString(rayType), "", params));
522 }
523 }
524 }
525
526 // Miss cases.
527 {
528 NonUniformParams params;
529 params.miss = true;
530 params.hitParams.rayTypeCount = 1u;
531 params.hitParams.rayType = 0u;
532
533 for (int causeIdx = static_cast<int>(MissCause::NONE) + 1; causeIdx < static_cast<int>(MissCause::CAUSE_COUNT); ++causeIdx)
534 {
535 params.missParams.missCause = static_cast<MissCause>(causeIdx);
536 params.missParams.missIndex = static_cast<deUint32>(causeIdx-1);
537 nonUniformGroup->addChild(new NonUniformArgsCase(testCtx, "miss_cause_" + de::toString(causeIdx), "", params));
538 }
539 }
540
541 return nonUniformGroup.release();
542 }
543
544 } // RayTracing
545 } // vkt
546