• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2021 The Khronos Group Inc.
6  * Copyright (c) 2021 Valve Corporation.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *	  http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Tests using non-uniform arguments with traceRayExt().
23  *//*--------------------------------------------------------------------*/
24 
#include "vktRayTracingNonUniformArgsTests.hpp"
#include "vktTestCase.hpp"

#include "vkRayTracingUtil.hpp"
#include "vkObjUtil.hpp"
#include "vkCmdUtil.hpp"
#include "vkBuilderUtil.hpp"
#include "vkTypeUtil.hpp"
#include "vkBarrierUtil.hpp"

#include "tcuTestLog.hpp"

#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
39 
40 namespace vkt
41 {
42 namespace RayTracing
43 {
44 namespace
45 {
46 
47 using namespace vk;
48 
49 // Causes for hitting the miss shader due to argument values.
50 enum class MissCause
51 {
52 	NONE = 0,
53 	FLAGS,
54 	CULL_MASK,
55 	ORIGIN,
56 	TMIN,
57 	DIRECTION,
58 	TMAX,
59 	CAUSE_COUNT,
60 };
61 
62 struct NonUniformParams
63 {
64 	bool miss;
65 
66 	struct
67 	{
68 		deUint32	rayTypeCount;
69 		deUint32	rayType;
70 	} hitParams;
71 
72 	struct
73 	{
74 		MissCause	missCause;
75 		deUint32	missIndex;
76 	} missParams;
77 };
78 
79 class NonUniformArgsCase : public TestCase
80 {
81 public:
82 							NonUniformArgsCase		(tcu::TestContext& testCtx, const std::string& name, const std::string& description, const NonUniformParams& params);
~NonUniformArgsCase(void)83 	virtual					~NonUniformArgsCase		(void) {}
84 
85 	virtual void			checkSupport			(Context& context) const;
86 	virtual void			initPrograms			(vk::SourceCollections& programCollection) const;
87 	virtual TestInstance*	createInstance			(Context& context) const;
88 
89 protected:
90 	NonUniformParams		m_params;
91 };
92 
93 class NonUniformArgsInstance : public TestInstance
94 {
95 public:
96 								NonUniformArgsInstance	(Context& context, const NonUniformParams& params);
~NonUniformArgsInstance(void)97 	virtual						~NonUniformArgsInstance	(void) {}
98 
99 	virtual tcu::TestStatus		iterate					(void);
100 
101 protected:
102 	NonUniformParams			m_params;
103 };
104 
NonUniformArgsCase(tcu::TestContext & testCtx,const std::string & name,const std::string & description,const NonUniformParams & params)105 NonUniformArgsCase::NonUniformArgsCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const NonUniformParams& params)
106 	: TestCase	(testCtx, name, description)
107 	, m_params	(params)
108 {}
109 
checkSupport(Context & context) const110 void NonUniformArgsCase::checkSupport (Context& context) const
111 {
112 	context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
113 	context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
114 }
115 
116 struct ArgsBufferData
117 {
118 	tcu::Vec4	origin;
119 	tcu::Vec4	direction;
120 	float		Tmin;
121 	float		Tmax;
122 	deUint32	rayFlags;
123 	deUint32	cullMask;
124 	deUint32	sbtRecordOffset;
125 	deUint32	sbtRecordStride;
126 	deUint32	missIndex;
127 };
128 
initPrograms(vk::SourceCollections & programCollection) const129 void NonUniformArgsCase::initPrograms (vk::SourceCollections& programCollection) const
130 {
131 	const ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
132 
133 	std::ostringstream descriptors;
134 	descriptors
135 		<< "layout(set=0, binding=0) uniform accelerationStructureEXT topLevelAS;\n"
136 		<< "layout(set=0, binding=1, std430) buffer ArgumentsBlock {\n" // Must match ArgsBufferData.
137 		<< "  vec4  origin;\n"
138 		<< "  vec4  direction;\n"
139 		<< "  float Tmin;\n"
140 		<< "  float Tmax;\n"
141 		<< "  uint  rayFlags;\n"
142 		<< "  uint  cullMask;\n"
143 		<< "  uint  sbtRecordOffset;\n"
144 		<< "  uint  sbtRecordStride;\n"
145 		<< "  uint  missIndex;\n"
146 		<< "} args;\n"
147 		<< "layout(set=0, binding=2, std430) buffer ResultBlock {\n"
148 		<< "  uint shaderId;\n"
149 		<< "} result;\n"
150 		;
151 	const auto descriptorsStr = descriptors.str();
152 
153 	std::ostringstream rgen;
154 	rgen
155 		<< "#version 460 core\n"
156 		<< "#extension GL_EXT_ray_tracing : require\n"
157 		<< "\n"
158 		<< descriptorsStr
159 		<< "layout(location=0) rayPayloadEXT vec4 unused;\n"
160 		<< "\n"
161 		<< "void main()\n"
162 		<< "{\n"
163 		<< "  traceRayEXT(topLevelAS,\n"
164 		<< "    args.rayFlags,\n"
165 		<< "    args.cullMask,\n"
166 		<< "    args.sbtRecordOffset,\n"
167 		<< "    args.sbtRecordStride,\n"
168 		<< "    args.missIndex,\n"
169 		<< "    args.origin.xyz,\n"
170 		<< "    args.Tmin,\n"
171 		<< "    args.direction.xyz,\n"
172 		<< "    args.Tmax,\n"
173 		<< "    0);\n"
174 		<< "}\n"
175 		;
176 
177 	std::ostringstream chit;
178 	chit
179 		<< "#version 460 core\n"
180 		<< "#extension GL_EXT_ray_tracing : require\n"
181 		<< "\n"
182 		<< descriptorsStr
183 		<< "layout(constant_id=0) const uint chitShaderId = 0;\n"
184 		<< "layout(location=0) rayPayloadInEXT vec4 unused;\n"
185 		<< "\n"
186 		<< "void main()\n"
187 		<< "{\n"
188 		<< "  result.shaderId = chitShaderId;\n"
189 		<< "}\n"
190 		;
191 
192 	std::ostringstream miss;
193 	miss
194 		<< "#version 460 core\n"
195 		<< "#extension GL_EXT_ray_tracing : require\n"
196 		<< "\n"
197 		<< descriptorsStr
198 		<< "layout(constant_id=0) const uint missShaderId = 0;\n"
199 		<< "layout(location=0) rayPayloadInEXT vec4 unused;\n"
200 		<< "\n"
201 		<< "void main()\n"
202 		<< "{\n"
203 		<< "  result.shaderId = missShaderId;\n"
204 		<< "}\n"
205 		;
206 
207 	programCollection.glslSources.add("rgen") << glu::RaygenSource(rgen.str()) << buildOptions;
208 	programCollection.glslSources.add("chit") << glu::ClosestHitSource(chit.str()) << buildOptions;
209 	programCollection.glslSources.add("miss") << glu::MissSource(miss.str()) << buildOptions;
210 }
211 
createInstance(Context & context) const212 TestInstance* NonUniformArgsCase::createInstance (Context& context) const
213 {
214 	return new NonUniformArgsInstance(context, m_params);
215 }
216 
NonUniformArgsInstance(Context & context,const NonUniformParams & params)217 NonUniformArgsInstance::NonUniformArgsInstance (Context& context, const NonUniformParams& params)
218 	: TestInstance	(context)
219 	, m_params		(params)
220 {}
221 
joinMostLeast(deUint32 most,deUint32 least)222 deUint32 joinMostLeast (deUint32 most, deUint32 least)
223 {
224 	constexpr auto kMaxUint16 = static_cast<deUint32>(std::numeric_limits<deUint16>::max());
225 	DE_UNREF(kMaxUint16); // For release builds.
226 	DE_ASSERT(most <= kMaxUint16 && least <= kMaxUint16);
227 	return ((most << 16) | least);
228 }
229 
makeMissId(deUint32 missIndex)230 deUint32 makeMissId (deUint32 missIndex)
231 {
232 	// 1 on the highest 16 bits for miss shaders.
233 	return joinMostLeast(1u, missIndex);
234 }
235 
makeChitId(deUint32 chitIndex)236 deUint32 makeChitId (deUint32 chitIndex)
237 {
238 	// 2 on the highest 16 bits for closest hit shaders.
239 	return joinMostLeast(2u, chitIndex);
240 }
241 
iterate(void)242 tcu::TestStatus NonUniformArgsInstance::iterate (void)
243 {
244 	const auto&	vki		= m_context.getInstanceInterface();
245 	const auto	physDev	= m_context.getPhysicalDevice();
246 	const auto&	vkd		= m_context.getDeviceInterface();
247 	const auto	device	= m_context.getDevice();
248 	auto&		alloc	= m_context.getDefaultAllocator();
249 	const auto	qIndex	= m_context.getUniversalQueueFamilyIndex();
250 	const auto	queue	= m_context.getUniversalQueue();
251 	const auto	stages	= (VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR);
252 
253 	// Geometry data constants.
254 	const std::vector<tcu::Vec3> kOffscreenTriangle =
255 	{
256 		// Triangle around (x=0, y=2) z=-5
257 		tcu::Vec3( 0.0f, 2.5f, -5.0f),
258 		tcu::Vec3(-0.5f, 1.5f, -5.0f),
259 		tcu::Vec3( 0.5f, 1.5f, -5.0f),
260 	};
261 	const std::vector<tcu::Vec3> kOnscreenTriangle =
262 	{
263 		// Triangle around (x=0, y=2) z=5
264 		tcu::Vec3( 0.0f, 2.5f, 5.0f),
265 		tcu::Vec3(-0.5f, 1.5f, 5.0f),
266 		tcu::Vec3( 0.5f, 1.5f, 5.0f),
267 	};
268 	const tcu::Vec4		kGoodOrigin		(0.0f, 2.0f, 0.0f, 0.0f);	// Around (x=0, y=2) z=0.
269 	const tcu::Vec4		kBadOrigin		(0.0f, 8.0f, 0.0f, 0.0f);	// Too high, around (x=0, y=8) depth 0.
270 	const tcu::Vec4		kGoodDirection	(0.0f, 0.0f, 1.0f, 0.0f);	// Towards +z.
271 	const tcu::Vec4		kBadDirection	(1.0f, 0.0f, 0.0f, 0.0f);	// Towards +x.
272 	const float			kGoodTmin		= 4.0f;						// Good to travel from z=0 to z=5.
273 	const float			kGoodTmax		= 6.0f;						// Ditto.
274 	const float			kBadTmin		= 5.5f;						// Tmin after triangle.
275 	const float			kBadTmax		= 4.5f;						// Tmax before triangle.
276 	const deUint32		kGoodFlags		= 0u;						// MaskNone
277 	const deUint32		kBadFlags		= 256u;						// SkipTrianglesKHR
278 	const deUint32		kGoodCullMask	= 0x0Fu;					// Matches instance.
279 	const deUint32		kBadCullMask	= 0xF0u;					// Does not match instance.
280 
281 	// Command pool and buffer.
282 	const auto cmdPool		= makeCommandPool(vkd, device, qIndex);
283 	const auto cmdBufferPtr	= allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
284 	const auto cmdBuffer	= cmdBufferPtr.get();
285 
286 	beginCommandBuffer(vkd, cmdBuffer);
287 
288 	// Build acceleration structures.
289 	auto topLevelAS		= makeTopLevelAccelerationStructure();
290 	auto bottomLevelAS	= makeBottomLevelAccelerationStructure();
291 
292 	// Putting the offscreen triangle first makes sure hits have a geometryIndex=1, meaning sbtRecordStride matters.
293 	std::vector<const std::vector<tcu::Vec3>*> geometries;
294 	geometries.push_back(&kOffscreenTriangle);
295 	geometries.push_back(&kOnscreenTriangle);
296 
297 	for (const auto& geometryPtr : geometries)
298 		bottomLevelAS->addGeometry(*geometryPtr, true /* is triangles */);
299 
300 	bottomLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
301 
302 	de::SharedPtr<BottomLevelAccelerationStructure> blasSharedPtr (bottomLevelAS.release());
303 	topLevelAS->setInstanceCount(1);
304 	topLevelAS->addInstance(blasSharedPtr, identityMatrix3x4, 0u, kGoodCullMask, 0u, VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR);
305 	topLevelAS->createAndBuild(vkd, device, cmdBuffer, alloc);
306 
307 	// Input storage buffer.
308 	const auto			inputBufferSize		= static_cast<VkDeviceSize>(sizeof(ArgsBufferData));
309 	const auto			inputBufferInfo		= makeBufferCreateInfo(inputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
310 	BufferWithMemory	inputBuffer			(vkd, device, alloc, inputBufferInfo, MemoryRequirement::HostVisible);
311 	auto&				inputBufferAlloc	= inputBuffer.getAllocation();
312 
313 	// Output storage buffer.
314 	const auto			outputBufferSize	= static_cast<VkDeviceSize>(sizeof(deUint32));
315 	const auto			outputBufferInfo	= makeBufferCreateInfo(outputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
316 	BufferWithMemory	outputBuffer		(vkd, device, alloc, outputBufferInfo, MemoryRequirement::HostVisible);
317 	auto&				outputBufferAlloc	= outputBuffer.getAllocation();
318 
319 	// Fill output buffer with an initial value.
320 	deMemset(outputBufferAlloc.getHostPtr(), 0, static_cast<size_t>(outputBufferSize));
321 	flushAlloc(vkd, device, outputBufferAlloc);
322 
323 	// Descriptor set layout and pipeline layout.
324 	DescriptorSetLayoutBuilder setLayoutBuilder;
325 	setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, stages);
326 	setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
327 	setLayoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, stages);
328 	const auto setLayout		= setLayoutBuilder.build(vkd, device);
329 	const auto pipelineLayout	= makePipelineLayout(vkd, device, setLayout.get());
330 
331 	// Descriptor pool and set.
332 	DescriptorPoolBuilder poolBuilder;
333 	poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
334 	poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 2u);
335 	const auto descriptorPool	= poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
336 	const auto descriptorSet	= makeDescriptorSet(vkd, device, descriptorPool.get(), setLayout.get());
337 
338 	// Update descriptor set.
339 	{
340 		const VkWriteDescriptorSetAccelerationStructureKHR accelDescInfo =
341 		{
342 			VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
343 			nullptr,
344 			1u,
345 			topLevelAS.get()->getPtr(),
346 		};
347 
348 		const auto inputBufferDescInfo	= makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
349 		const auto outputBufferDescInfo	= makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
350 
351 		DescriptorSetUpdateBuilder updateBuilder;
352 		updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelDescInfo);
353 		updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescInfo);
354 		updateBuilder.writeSingle(descriptorSet.get(), DescriptorSetUpdateBuilder::Location::binding(2u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescInfo);
355 		updateBuilder.update(vkd, device);
356 	}
357 
358 	// Shader modules.
359 	auto rgenModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0));
360 	auto missModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("miss"), 0));
361 	auto chitModule = makeVkSharedPtr(createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0));
362 
363 	// Get some ray tracing properties.
364 	deUint32 shaderGroupHandleSize		= 0u;
365 	deUint32 shaderGroupBaseAlignment	= 1u;
366 	{
367 		const auto rayTracingPropertiesKHR	= makeRayTracingProperties(vki, physDev);
368 		shaderGroupHandleSize				= rayTracingPropertiesKHR->getShaderGroupHandleSize();
369 		shaderGroupBaseAlignment			= rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
370 	}
371 
372 	// Create raytracing pipeline and shader binding tables.
373 	Move<VkPipeline>				pipeline;
374 
375 	de::MovePtr<BufferWithMemory>	raygenSBT;
376 	de::MovePtr<BufferWithMemory>	missSBT;
377 	de::MovePtr<BufferWithMemory>	hitSBT;
378 	de::MovePtr<BufferWithMemory>	callableSBT;
379 
380 	VkStridedDeviceAddressRegionKHR	raygenSBTRegion		= makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
381 	VkStridedDeviceAddressRegionKHR	missSBTRegion		= makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
382 	VkStridedDeviceAddressRegionKHR	hitSBTRegion		= makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
383 	VkStridedDeviceAddressRegionKHR	callableSBTRegion	= makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
384 
385 	// Generate ids for the closest hit and miss shaders according to the test parameters.
386 	DE_ASSERT(m_params.hitParams.rayTypeCount > 0u);
387 	DE_ASSERT(m_params.hitParams.rayType < m_params.hitParams.rayTypeCount);
388 	DE_ASSERT(geometries.size() > 0u);
389 
390 	std::vector<deUint32> missShaderIds;
391 	for (deUint32 missIdx = 0; missIdx <= m_params.missParams.missIndex; ++missIdx)
392 		missShaderIds.push_back(makeMissId(missIdx));
393 
394 	deUint32				chitCounter		= 0u;
395 	std::vector<deUint32>	chitShaderIds;
396 
397 	for (size_t geoIdx = 0; geoIdx < geometries.size(); ++geoIdx)
398 	for (deUint32 rayIdx = 0; rayIdx < m_params.hitParams.rayTypeCount; ++rayIdx)
399 		chitShaderIds.push_back(makeChitId(chitCounter++));
400 
401 	{
402 		const auto						rayTracingPipeline		= de::newMovePtr<RayTracingPipeline>();
403 		const VkSpecializationMapEntry	specializationMapEntry	=
404 		{
405 			0u,											//	deUint32	constantID;
406 			0u,											//	deUint32	offset;
407 			static_cast<deUintptr>(sizeof(deUint32)),	//	deUintptr	size;
408 		};
409 		VkSpecializationInfo			specInfo				=
410 		{
411 			1u,											//	deUint32						mapEntryCount;
412 			&specializationMapEntry,					//	const VkSpecializationMapEntry*	pMapEntries;
413 			static_cast<deUintptr>(sizeof(deUint32)),	//	deUintptr						dataSize;
414 			nullptr,									//	const void*						pData;
415 		};
416 
417 		std::vector<VkSpecializationInfo> specInfos;
418 		specInfos.reserve(missShaderIds.size() + chitShaderIds.size());
419 
420 		deUint32 shaderGroupIdx = 0u;
421 		rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, rgenModule, shaderGroupIdx++);
422 
423 		for (size_t missIdx = 0; missIdx < missShaderIds.size(); ++missIdx)
424 		{
425 			specInfo.pData = &missShaderIds.at(missIdx);
426 			specInfos.push_back(specInfo);
427 			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, missModule, shaderGroupIdx++, &specInfos.back());
428 		}
429 
430 		const auto firstChitGroup = shaderGroupIdx;
431 
432 		for (size_t chitIdx = 0; chitIdx < chitShaderIds.size(); ++chitIdx)
433 		{
434 			specInfo.pData = &chitShaderIds.at(chitIdx);
435 			specInfos.push_back(specInfo);
436 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, chitModule, shaderGroupIdx++, &specInfos.back());
437 		}
438 
439 		pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
440 
441 		raygenSBT		= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0u, 1u);
442 		raygenSBTRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenSBT->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
443 
444 		missSBT			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1u, static_cast<deUint32>(missShaderIds.size()));
445 		missSBTRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missSBT->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize * missShaderIds.size());
446 
447 		hitSBT			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, firstChitGroup, static_cast<deUint32>(chitShaderIds.size()));
448 		hitSBTRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitSBT->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize * chitShaderIds.size());
449 	}
450 
451 	// Fill input buffer values.
452 	{
453 		DE_ASSERT(!(m_params.miss && m_params.missParams.missCause == MissCause::NONE));
454 
455 		const ArgsBufferData argsBufferData =
456 		{
457 			((m_params.miss && m_params.missParams.missCause == MissCause::ORIGIN)		? kBadOrigin	: kGoodOrigin),
458 			((m_params.miss && m_params.missParams.missCause == MissCause::DIRECTION)	? kBadDirection	: kGoodDirection),
459 			((m_params.miss && m_params.missParams.missCause == MissCause::TMIN)		? kBadTmin		: kGoodTmin),
460 			((m_params.miss && m_params.missParams.missCause == MissCause::TMAX)		? kBadTmax		: kGoodTmax),
461 			((m_params.miss && m_params.missParams.missCause == MissCause::FLAGS)		? kBadFlags		: kGoodFlags),
462 			((m_params.miss && m_params.missParams.missCause == MissCause::CULL_MASK)	? kBadCullMask	: kGoodCullMask),
463 			m_params.hitParams.rayType,
464 			m_params.hitParams.rayTypeCount,
465 			m_params.missParams.missIndex,
466 		};
467 
468 		deMemcpy(inputBufferAlloc.getHostPtr(), &argsBufferData, sizeof(argsBufferData));
469 		flushAlloc(vkd, device, inputBufferAlloc);
470 	}
471 
472 	// Trace rays.
473 	vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
474 	vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u, &descriptorSet.get(), 0u, nullptr);
475 	vkd.cmdTraceRaysKHR(cmdBuffer, &raygenSBTRegion, &missSBTRegion, &hitSBTRegion, &callableSBTRegion, 1u, 1u, 1u);
476 
477 	// Barrier for the output buffer.
478 	const auto bufferBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
479 	vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u, &bufferBarrier, 0u, nullptr, 0u, nullptr);
480 
481 	endCommandBuffer(vkd, cmdBuffer);
482 	submitCommandsAndWait(vkd, device, queue, cmdBuffer);
483 
484 	// Check output value.
485 	invalidateAlloc(vkd, device, outputBufferAlloc);
486 	deUint32 outputVal = std::numeric_limits<deUint32>::max();
487 	deMemcpy(&outputVal, outputBufferAlloc.getHostPtr(), sizeof(outputVal));
488 	const auto expectedVal = (m_params.miss ? makeMissId(m_params.missParams.missIndex) : makeChitId(m_params.hitParams.rayTypeCount + m_params.hitParams.rayType));
489 
490 	std::ostringstream msg;
491 	msg << "Output value: 0x" << std::hex << outputVal << " (expected 0x" << expectedVal << ")";
492 
493 	if (outputVal != expectedVal)
494 		return tcu::TestStatus::fail(msg.str());
495 
496 	auto& log = m_context.getTestContext().getLog();
497 	log << tcu::TestLog::Message << msg.str() << tcu::TestLog::EndMessage;
498 
499 	return tcu::TestStatus::pass("Pass");
500 }
501 
502 } // anonymous
503 
createNonUniformArgsTests(tcu::TestContext & testCtx)504 tcu::TestCaseGroup*	createNonUniformArgsTests (tcu::TestContext& testCtx)
505 {
506 	de::MovePtr<tcu::TestCaseGroup> nonUniformGroup(new tcu::TestCaseGroup(testCtx, "non_uniform_args", "Test non-uniform arguments in traceRayExt()"));
507 
508 	// Closest hit cases.
509 	{
510 		NonUniformParams params;
511 		params.miss = false;
512 		params.missParams.missIndex = 0u;
513 		params.missParams.missCause = MissCause::NONE;
514 
515 		for (deUint32 typeCount = 1u; typeCount <= 4u; ++typeCount)
516 		{
517 			params.hitParams.rayTypeCount = typeCount;
518 			for (deUint32 rayType = 0u; rayType < typeCount; ++rayType)
519 			{
520 				params.hitParams.rayType = rayType;
521 				nonUniformGroup->addChild(new NonUniformArgsCase(testCtx, "chit_" + de::toString(typeCount) + "_types_" + de::toString(rayType), "", params));
522 			}
523 		}
524 	}
525 
526 	// Miss cases.
527 	{
528 		NonUniformParams params;
529 		params.miss = true;
530 		params.hitParams.rayTypeCount = 1u;
531 		params.hitParams.rayType = 0u;
532 
533 		for (int causeIdx = static_cast<int>(MissCause::NONE) + 1; causeIdx < static_cast<int>(MissCause::CAUSE_COUNT); ++causeIdx)
534 		{
535 			params.missParams.missCause = static_cast<MissCause>(causeIdx);
536 			params.missParams.missIndex = static_cast<deUint32>(causeIdx-1);
537 			nonUniformGroup->addChild(new NonUniformArgsCase(testCtx, "miss_cause_" + de::toString(causeIdx), "", params));
538 		}
539 	}
540 
541 	return nonUniformGroup.release();
542 }
543 
544 }	// RayTracing
545 }	// vkt
546