• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *	  http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Ray Tracing Complex Control Flow tests
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vktRayTracingComplexControlFlowTests.hpp"
25 
26 #include "vkDefs.hpp"
27 
28 #include "vktTestCase.hpp"
29 #include "vkCmdUtil.hpp"
30 #include "vkObjUtil.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkBarrierUtil.hpp"
33 #include "vkBufferWithMemory.hpp"
34 #include "vkImageWithMemory.hpp"
35 #include "vkTypeUtil.hpp"
36 
37 #include "vkRayTracingUtil.hpp"
38 
39 #include "tcuTestLog.hpp"
40 
41 #include "deRandom.hpp"
42 
43 namespace vkt
44 {
45 namespace RayTracing
46 {
47 namespace
48 {
49 using namespace vk;
50 using namespace std;
51 
52 static const VkFlags	ALL_RAY_TRACING_STAGES	= VK_SHADER_STAGE_RAYGEN_BIT_KHR
53 												| VK_SHADER_STAGE_ANY_HIT_BIT_KHR
54 												| VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
55 												| VK_SHADER_STAGE_MISS_BIT_KHR
56 												| VK_SHADER_STAGE_INTERSECTION_BIT_KHR
57 												| VK_SHADER_STAGE_CALLABLE_BIT_KHR;
58 
59 #if defined(DE_DEBUG)
60 static const deUint32	PUSH_CONSTANTS_COUNT	= 6;
61 #endif
62 static const deUint32	DEFAULT_CLEAR_VALUE		= 999999;
63 
// Control flow construct exercised by a test case.
enum TestType
{
	TEST_TYPE_IF						= 0,
	TEST_TYPE_LOOP						= 1,
	TEST_TYPE_SWITCH					= 2,
	TEST_TYPE_LOOP_DOUBLE_CALL			= 3,
	TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE	= 4,
	TEST_TYPE_NESTED_LOOP				= 5,
	TEST_TYPE_NESTED_LOOP_BEFORE		= 6,
	TEST_TYPE_NESTED_LOOP_AFTER			= 7,
	TEST_TYPE_FUNCTION_CALL				= 8,
	TEST_TYPE_NESTED_FUNCTION_CALL		= 9,
};
77 
// Operation selected for a test case.
enum TestOp
{
	TEST_OP_EXECUTE_CALLABLE	= 0,
	TEST_OP_TRACE_RAY			= 1,
	TEST_OP_REPORT_INTERSECTION	= 2,
};
84 
// Shader group indices; RAYGEN_GROUP shares the value of FIRST_GROUP.
enum ShaderGroups
{
	FIRST_GROUP		= 0,
	RAYGEN_GROUP	= FIRST_GROUP,
	MISS_GROUP		= RAYGEN_GROUP + 1,
	HIT_GROUP		= MISS_GROUP + 1,
	GROUP_COUNT		= HIT_GROUP + 1
};
93 
94 struct CaseDef
95 {
96 	TestType				testType;
97 	TestOp					testOp;
98 	VkShaderStageFlagBits	stage;
99 	deUint32				width;
100 	deUint32				height;
101 };
102 
103 struct PushConstants
104 {
105 	deUint32	a;
106 	deUint32	b;
107 	deUint32	c;
108 	deUint32	d;
109 	deUint32	hitOfs;
110 	deUint32	miss;
111 };
112 
getShaderGroupSize(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)113 deUint32 getShaderGroupSize (const InstanceInterface&	vki,
114 							 const VkPhysicalDevice		physicalDevice)
115 {
116 	de::MovePtr<RayTracingProperties>	rayTracingPropertiesKHR;
117 
118 	rayTracingPropertiesKHR	= makeRayTracingProperties(vki, physicalDevice);
119 	return rayTracingPropertiesKHR->getShaderGroupHandleSize();
120 }
121 
getShaderGroupBaseAlignment(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)122 deUint32 getShaderGroupBaseAlignment (const InstanceInterface&	vki,
123 									  const VkPhysicalDevice	physicalDevice)
124 {
125 	de::MovePtr<RayTracingProperties>	rayTracingPropertiesKHR;
126 
127 	rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
128 	return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
129 }
130 
makeImageCreateInfo(deUint32 width,deUint32 height,deUint32 depth,VkFormat format)131 VkImageCreateInfo makeImageCreateInfo (deUint32 width, deUint32 height, deUint32 depth, VkFormat format)
132 {
133 	const VkImageUsageFlags	usage			= VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT;
134 	const VkImageCreateInfo	imageCreateInfo	=
135 	{
136 		VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO,	// VkStructureType			sType;
137 		DE_NULL,								// const void*				pNext;
138 		(VkImageCreateFlags)0u,					// VkImageCreateFlags		flags;
139 		VK_IMAGE_TYPE_3D,						// VkImageType				imageType;
140 		format,									// VkFormat					format;
141 		makeExtent3D(width, height, depth),		// VkExtent3D				extent;
142 		1u,										// deUint32					mipLevels;
143 		1u,										// deUint32					arrayLayers;
144 		VK_SAMPLE_COUNT_1_BIT,					// VkSampleCountFlagBits	samples;
145 		VK_IMAGE_TILING_OPTIMAL,				// VkImageTiling			tiling;
146 		usage,									// VkImageUsageFlags		usage;
147 		VK_SHARING_MODE_EXCLUSIVE,				// VkSharingMode			sharingMode;
148 		0u,										// deUint32					queueFamilyIndexCount;
149 		DE_NULL,								// const deUint32*			pQueueFamilyIndices;
150 		VK_IMAGE_LAYOUT_UNDEFINED				// VkImageLayout			initialLayout;
151 	};
152 
153 	return imageCreateInfo;
154 }
155 
makePipelineLayout(const DeviceInterface & vk,const VkDevice device,const VkDescriptorSetLayout descriptorSetLayout,const deUint32 pushConstantsSize)156 Move<VkPipelineLayout> makePipelineLayout (const DeviceInterface&		vk,
157 										   const VkDevice				device,
158 										   const VkDescriptorSetLayout	descriptorSetLayout,
159 										   const deUint32				pushConstantsSize)
160 {
161 	const VkDescriptorSetLayout*		descriptorSetLayoutPtr	= (descriptorSetLayout == DE_NULL) ? DE_NULL : &descriptorSetLayout;
162 	const deUint32						setLayoutCount			= (descriptorSetLayout == DE_NULL) ? 0u : 1u;
163 	const VkPushConstantRange			pushConstantRange		=
164 	{
165 		ALL_RAY_TRACING_STAGES,		//  VkShaderStageFlags	stageFlags;
166 		0u,							//  deUint32			offset;
167 		pushConstantsSize,			//  deUint32			size;
168 	};
169 	const VkPushConstantRange*			pPushConstantRanges		= (pushConstantsSize == 0) ? DE_NULL : &pushConstantRange;
170 	const deUint32						pushConstantRangeCount	= (pushConstantsSize == 0) ? 0 : 1u;
171 	const VkPipelineLayoutCreateInfo	pipelineLayoutParams	=
172 	{
173 		VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,	// VkStructureType					sType;
174 		DE_NULL,										// const void*						pNext;
175 		0u,												// VkPipelineLayoutCreateFlags		flags;
176 		setLayoutCount,									// deUint32							setLayoutCount;
177 		descriptorSetLayoutPtr,							// const VkDescriptorSetLayout*		pSetLayouts;
178 		pushConstantRangeCount,							// deUint32							pushConstantRangeCount;
179 		pPushConstantRanges,							// const VkPushConstantRange*		pPushConstantRanges;
180 	};
181 
182 	return createPipelineLayout(vk, device, &pipelineLayoutParams);
183 }
184 
getVkBuffer(const de::MovePtr<BufferWithMemory> & buffer)185 VkBuffer getVkBuffer (const de::MovePtr<BufferWithMemory>& buffer)
186 {
187 	VkBuffer result = (buffer.get() == DE_NULL) ? DE_NULL : buffer->get();
188 
189 	return result;
190 }
191 
makeStridedDeviceAddressRegion(const DeviceInterface & vkd,const VkDevice device,VkBuffer buffer,deUint32 stride,deUint32 count)192 VkStridedDeviceAddressRegionKHR makeStridedDeviceAddressRegion (const DeviceInterface& vkd, const VkDevice device, VkBuffer buffer, deUint32 stride, deUint32 count)
193 {
194 	if (buffer == DE_NULL)
195 	{
196 		return makeStridedDeviceAddressRegionKHR(0, 0, 0);
197 	}
198 	else
199 	{
200 		return makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, buffer, 0), stride, stride * count);
201 	}
202 }
203 
// Returns a copy of `str` in which every non-overlapping occurrence of `from` is replaced with `to`.
// Text inserted by a replacement is never rescanned, so `to` may itself contain `from`.
// An empty `from` leaves the string unchanged.
static inline std::string replace(const std::string& str, const std::string& from, const std::string& to)
{
	std::string result(str);

	// An empty pattern matches at every position; without this guard the search below never terminates.
	if (from.empty())
		return result;

	for (size_t pos = result.find(from); pos != std::string::npos; pos = result.find(from, pos + to.length()))
		result.replace(pos, from.length(), to);

	return result;
}
218 
219 
220 class RayTracingComplexControlFlowInstance : public TestInstance
221 {
222 public:
223 																RayTracingComplexControlFlowInstance	(Context& context, const CaseDef& data);
224 																~RayTracingComplexControlFlowInstance	(void);
225 	tcu::TestStatus												iterate									(void);
226 
227 protected:
228 	void														calcShaderGroup							(deUint32&					shaderGroupCounter,
229 																										 const VkShaderStageFlags	shaders1,
230 																										 const VkShaderStageFlags	shaders2,
231 																										 const VkShaderStageFlags	shaderStageFlags,
232 																										 deUint32&					shaderGroup,
233 																										 deUint32&					shaderGroupCount) const;
234 	PushConstants												getPushConstants						(void) const;
235 	std::vector<deUint32>										getExpectedValues						(void) const;
236 	de::MovePtr<BufferWithMemory>								runTest									(void);
237 	Move<VkPipeline>											makePipeline							(de::MovePtr<RayTracingPipeline>&							rayTracingPipeline,
238 																										 VkPipelineLayout											pipelineLayout);
239 	de::MovePtr<BufferWithMemory>								createShaderBindingTable				 (const InstanceInterface&									vki,
240 																										 const DeviceInterface&										vkd,
241 																										 const VkDevice												device,
242 																										 const VkPhysicalDevice										physicalDevice,
243 																										 const VkPipeline											pipeline,
244 																										 Allocator&													allocator,
245 																										 de::MovePtr<RayTracingPipeline>&							rayTracingPipeline,
246 																										 const deUint32												group,
247 																										 const deUint32												groupCount = 1u);
248 	de::MovePtr<TopLevelAccelerationStructure>					initTopAccelerationStructure			(VkCommandBuffer											cmdBuffer,
249 																										 vector<de::SharedPtr<BottomLevelAccelerationStructure> >&	bottomLevelAccelerationStructures);
250 	vector<de::SharedPtr<BottomLevelAccelerationStructure>	>	initBottomAccelerationStructures		(VkCommandBuffer											cmdBuffer);
251 	de::MovePtr<BottomLevelAccelerationStructure>				initBottomAccelerationStructure			(VkCommandBuffer											cmdBuffer,
252 																										 tcu::UVec2&												startPos);
253 
254 private:
255 	CaseDef														m_data;
256 	VkShaderStageFlags											m_shaders;
257 	VkShaderStageFlags											m_shaders2;
258 	deUint32													m_raygenShaderGroup;
259 	deUint32													m_missShaderGroup;
260 	deUint32													m_hitShaderGroup;
261 	deUint32													m_callableShaderGroup;
262 	deUint32													m_raygenShaderGroupCount;
263 	deUint32													m_missShaderGroupCount;
264 	deUint32													m_hitShaderGroupCount;
265 	deUint32													m_callableShaderGroupCount;
266 	deUint32													m_shaderGroupCount;
267 	deUint32													m_depth;
268 	PushConstants												m_pushConstants;
269 };
270 
RayTracingComplexControlFlowInstance(Context & context,const CaseDef & data)271 RayTracingComplexControlFlowInstance::RayTracingComplexControlFlowInstance (Context& context, const CaseDef& data)
272 	: vkt::TestInstance				(context)
273 	, m_data						(data)
274 	, m_shaders						(0)
275 	, m_shaders2					(0)
276 	, m_raygenShaderGroup			(~0u)
277 	, m_missShaderGroup				(~0u)
278 	, m_hitShaderGroup				(~0u)
279 	, m_callableShaderGroup			(~0u)
280 	, m_raygenShaderGroupCount		(0)
281 	, m_missShaderGroupCount		(0)
282 	, m_hitShaderGroupCount			(0)
283 	, m_callableShaderGroupCount	(0)
284 	, m_shaderGroupCount			(0)
285 	, m_depth						(16)
286 	, m_pushConstants				(getPushConstants())
287 {
288 	const VkShaderStageFlags	hitStages	= VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
289 	BinaryCollection&			collection	= m_context.getBinaryCollection();
290 	deUint32					shaderCount	= 0;
291 
292 	if (collection.contains("rgen")) m_shaders |= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
293 	if (collection.contains("ahit")) m_shaders |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
294 	if (collection.contains("chit")) m_shaders |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
295 	if (collection.contains("miss")) m_shaders |= VK_SHADER_STAGE_MISS_BIT_KHR;
296 	if (collection.contains("sect")) m_shaders |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
297 	if (collection.contains("call")) m_shaders |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
298 
299 	if (collection.contains("ahit2")) m_shaders2 |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
300 	if (collection.contains("chit2")) m_shaders2 |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
301 	if (collection.contains("miss2")) m_shaders2 |= VK_SHADER_STAGE_MISS_BIT_KHR;
302 	if (collection.contains("sect2")) m_shaders2 |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
303 
304 	if (collection.contains("cal0")) m_shaders2 |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
305 
306 	for (BinaryCollection::Iterator it = collection.begin(); it != collection.end(); ++it)
307 		shaderCount++;
308 
309 	if (shaderCount != (deUint32)dePop32(m_shaders) + (deUint32)dePop32(m_shaders2))
310 		TCU_THROW(InternalError, "Unused shaders detected in the collection");
311 
312 	calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_RAYGEN_BIT_KHR,   m_raygenShaderGroup,   m_raygenShaderGroupCount);
313 	calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_MISS_BIT_KHR,     m_missShaderGroup,     m_missShaderGroupCount);
314 	calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, hitStages,                        m_hitShaderGroup,      m_hitShaderGroupCount);
315 	calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_CALLABLE_BIT_KHR, m_callableShaderGroup, m_callableShaderGroupCount);
316 }
317 
~RayTracingComplexControlFlowInstance(void)318 RayTracingComplexControlFlowInstance::~RayTracingComplexControlFlowInstance (void)
319 {
320 }
321 
calcShaderGroup(deUint32 & shaderGroupCounter,const VkShaderStageFlags shaders1,const VkShaderStageFlags shaders2,const VkShaderStageFlags shaderStageFlags,deUint32 & shaderGroup,deUint32 & shaderGroupCount) const322 void RayTracingComplexControlFlowInstance::calcShaderGroup (deUint32&					shaderGroupCounter,
323 															const VkShaderStageFlags	shaders1,
324 															const VkShaderStageFlags	shaders2,
325 															const VkShaderStageFlags	shaderStageFlags,
326 															deUint32&					shaderGroup,
327 															deUint32&					shaderGroupCount) const
328 {
329 	const deUint32	shader1Count = ((shaders1 & shaderStageFlags) != 0) ? 1 : 0;
330 	const deUint32	shader2Count = ((shaders2 & shaderStageFlags) != 0) ? 1 : 0;
331 
332 	shaderGroupCount = shader1Count + shader2Count;
333 
334 	if (shaderGroupCount != 0)
335 	{
336 		shaderGroup			= shaderGroupCounter;
337 		shaderGroupCounter += shaderGroupCount;
338 	}
339 }
340 
makePipeline(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,VkPipelineLayout pipelineLayout)341 Move<VkPipeline> RayTracingComplexControlFlowInstance::makePipeline (de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
342 																	  VkPipelineLayout					pipelineLayout)
343 {
344 	const DeviceInterface&	vkd			= m_context.getDeviceInterface();
345 	const VkDevice			device		= m_context.getDevice();
346 	vk::BinaryCollection&	collection	= m_context.getBinaryCollection();
347 
348 	if (0 != (m_shaders & VK_SHADER_STAGE_RAYGEN_BIT_KHR))			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR		, createShaderModule(vkd, device, collection.get("rgen"), 0), m_raygenShaderGroup);
349 	if (0 != (m_shaders & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))			rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR		, createShaderModule(vkd, device, collection.get("ahit"), 0), m_hitShaderGroup);
350 	if (0 != (m_shaders & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))		rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR	, createShaderModule(vkd, device, collection.get("chit"), 0), m_hitShaderGroup);
351 	if (0 != (m_shaders & VK_SHADER_STAGE_MISS_BIT_KHR))			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR			, createShaderModule(vkd, device, collection.get("miss"), 0), m_missShaderGroup);
352 	if (0 != (m_shaders & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))	rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR	, createShaderModule(vkd, device, collection.get("sect"), 0), m_hitShaderGroup);
353 	if (0 != (m_shaders & VK_SHADER_STAGE_CALLABLE_BIT_KHR))		rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR		, createShaderModule(vkd, device, collection.get("call"), 0), m_callableShaderGroup + 1);
354 
355 	if (0 != (m_shaders2 & VK_SHADER_STAGE_CALLABLE_BIT_KHR))		rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR		, createShaderModule(vkd, device, collection.get("cal0"), 0), m_callableShaderGroup);
356 	if (0 != (m_shaders2 & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))		rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR		, createShaderModule(vkd, device, collection.get("ahit2"), 0), m_hitShaderGroup + 1);
357 	if (0 != (m_shaders2 & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))	rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR	, createShaderModule(vkd, device, collection.get("chit2"), 0), m_hitShaderGroup + 1);
358 	if (0 != (m_shaders2 & VK_SHADER_STAGE_MISS_BIT_KHR))			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR			, createShaderModule(vkd, device, collection.get("miss2"), 0), m_missShaderGroup + 1);
359 	if (0 != (m_shaders2 & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))	rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR	, createShaderModule(vkd, device, collection.get("sect2"), 0), m_hitShaderGroup + 1);
360 
361 	if (m_data.testOp == TEST_OP_TRACE_RAY && m_data.stage != VK_SHADER_STAGE_RAYGEN_BIT_KHR)
362 		rayTracingPipeline->setMaxRecursionDepth(2);
363 
364 	Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout);
365 
366 	return pipeline;
367 }
368 
createShaderBindingTable(const InstanceInterface & vki,const DeviceInterface & vkd,const VkDevice device,const VkPhysicalDevice physicalDevice,const VkPipeline pipeline,Allocator & allocator,de::MovePtr<RayTracingPipeline> & rayTracingPipeline,const deUint32 group,const deUint32 groupCount)369 de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::createShaderBindingTable (const InstanceInterface&			vki,
370 																							  const DeviceInterface&			vkd,
371 																							  const VkDevice					device,
372 																							  const VkPhysicalDevice			physicalDevice,
373 																							  const VkPipeline					pipeline,
374 																							  Allocator&						allocator,
375 																							  de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
376 																							  const deUint32					group,
377 																							  const deUint32					groupCount)
378 {
379 	de::MovePtr<BufferWithMemory>	shaderBindingTable;
380 
381 	if (group < m_shaderGroupCount)
382 	{
383 		const deUint32	shaderGroupHandleSize		= getShaderGroupSize(vki, physicalDevice);
384 		const deUint32	shaderGroupBaseAlignment	= getShaderGroupBaseAlignment(vki, physicalDevice);
385 
386 		shaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, group, groupCount);
387 	}
388 
389 	return shaderBindingTable;
390 }
391 
392 
initTopAccelerationStructure(VkCommandBuffer cmdBuffer,vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelAccelerationStructures)393 de::MovePtr<TopLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initTopAccelerationStructure (VkCommandBuffer												cmdBuffer,
394 																											   vector<de::SharedPtr<BottomLevelAccelerationStructure> >&	bottomLevelAccelerationStructures)
395 {
396 	const DeviceInterface&						vkd			= m_context.getDeviceInterface();
397 	const VkDevice								device		= m_context.getDevice();
398 	Allocator&									allocator	= m_context.getDefaultAllocator();
399 	de::MovePtr<TopLevelAccelerationStructure>	result		= makeTopLevelAccelerationStructure();
400 
401 	result->setInstanceCount(bottomLevelAccelerationStructures.size());
402 
403 	for (size_t structNdx = 0; structNdx < bottomLevelAccelerationStructures.size(); ++structNdx)
404 		result->addInstance(bottomLevelAccelerationStructures[structNdx]);
405 
406 	result->createAndBuild(vkd, device, cmdBuffer, allocator);
407 
408 	return result;
409 }
410 
initBottomAccelerationStructure(VkCommandBuffer cmdBuffer,tcu::UVec2 & startPos)411 de::MovePtr<BottomLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initBottomAccelerationStructure (VkCommandBuffer	cmdBuffer,
412 																													 tcu::UVec2&		startPos)
413 {
414 	const DeviceInterface&							vkd				= m_context.getDeviceInterface();
415 	const VkDevice									device			= m_context.getDevice();
416 	Allocator&										allocator		= m_context.getDefaultAllocator();
417 	de::MovePtr<BottomLevelAccelerationStructure>	result			= makeBottomLevelAccelerationStructure();
418 	const float										z				= (m_data.stage == VK_SHADER_STAGE_MISS_BIT_KHR) ? +1.0f : -1.0f;
419 	std::vector<tcu::Vec3>							geometryData;
420 
421 	DE_UNREF(startPos);
422 
423 	result->setGeometryCount(1);
424 	geometryData.push_back(tcu::Vec3(0.0f, 0.0f, z));
425 	geometryData.push_back(tcu::Vec3(1.0f, 1.0f, z));
426 	result->addGeometry(geometryData, false);
427 	result->createAndBuild(vkd, device, cmdBuffer, allocator);
428 
429 	return result;
430 }
431 
initBottomAccelerationStructures(VkCommandBuffer cmdBuffer)432 vector<de::SharedPtr<BottomLevelAccelerationStructure> > RayTracingComplexControlFlowInstance::initBottomAccelerationStructures (VkCommandBuffer	cmdBuffer)
433 {
434 	tcu::UVec2													startPos;
435 	vector<de::SharedPtr<BottomLevelAccelerationStructure> >	result;
436 	de::MovePtr<BottomLevelAccelerationStructure>				bottomLevelAccelerationStructure	= initBottomAccelerationStructure(cmdBuffer, startPos);
437 
438 	result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
439 
440 	return result;
441 }
442 
getPushConstants(void) const443 PushConstants RayTracingComplexControlFlowInstance::getPushConstants (void) const
444 {
445 	const			deUint32	hitOfs	= 1;
446 	const			deUint32	miss	= 1;
447 	PushConstants	result;
448 
449 	switch (m_data.testType)
450 	{
451 		case TEST_TYPE_IF:
452 		{
453 			result = { 32 | 8 | 1, 10000, 0x0F, 0xF0, hitOfs, miss };
454 
455 			break;
456 		}
457 		case TEST_TYPE_LOOP:
458 		{
459 			result = { 8, 10000, 0x0F, 100000, hitOfs, miss };
460 
461 			break;
462 		}
463 		case TEST_TYPE_SWITCH:
464 		{
465 			result = { 3, 10000, 0x07, 100000, hitOfs, miss };
466 
467 			break;
468 		}
469 		case TEST_TYPE_LOOP_DOUBLE_CALL:
470 		{
471 			result = { 7, 10000, 0x0F, 0xF0, hitOfs, miss };
472 
473 			break;
474 		}
475 		case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
476 		{
477 			result = { 16, 5, 0x0F, 0xF0, hitOfs, miss };
478 
479 			break;
480 		}
481 		case TEST_TYPE_NESTED_LOOP:
482 		{
483 			result = { 8, 5, 0x0F, 0x09, hitOfs, miss };
484 
485 			break;
486 		}
487 		case TEST_TYPE_NESTED_LOOP_BEFORE:
488 		{
489 			result = { 9, 16, 0x0F, 10, hitOfs, miss };
490 
491 			break;
492 		}
493 		case TEST_TYPE_NESTED_LOOP_AFTER:
494 		{
495 			result = { 9, 16, 0x0F, 10, hitOfs, miss };
496 
497 			break;
498 		}
499 		case TEST_TYPE_FUNCTION_CALL:
500 		{
501 			result = { 0xFFB, 16, 10, 100000, hitOfs, miss };
502 
503 			break;
504 		}
505 		case TEST_TYPE_NESTED_FUNCTION_CALL:
506 		{
507 			result = { 0xFFB, 16, 10, 100000, hitOfs, miss };
508 
509 			break;
510 		}
511 
512 		default:
513 			TCU_THROW(InternalError, "Unknown testType");
514 	}
515 
516 	return result;
517 }
518 
runTest(void)519 de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::runTest (void)
520 {
521 	const InstanceInterface&				vki									= m_context.getInstanceInterface();
522 	const DeviceInterface&					vkd									= m_context.getDeviceInterface();
523 	const VkDevice							device								= m_context.getDevice();
524 	const VkPhysicalDevice					physicalDevice						= m_context.getPhysicalDevice();
525 	const deUint32							queueFamilyIndex					= m_context.getUniversalQueueFamilyIndex();
526 	const VkQueue							queue								= m_context.getUniversalQueue();
527 	Allocator&								allocator							= m_context.getDefaultAllocator();
528 	const VkFormat							format								= VK_FORMAT_R32_UINT;
529 	const deUint32							pushConstants[]						= { m_pushConstants.a, m_pushConstants.b, m_pushConstants.c, m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss };
530 	const deUint32							pushConstantsSize					= sizeof(pushConstants);
531 	const deUint32							pixelCount							= m_data.width * m_data.height * m_depth;
532 	const deUint32							shaderGroupHandleSize				= getShaderGroupSize(vki, physicalDevice);
533 
534 	const Move<VkDescriptorSetLayout>		descriptorSetLayout					= DescriptorSetLayoutBuilder()
535 																						.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
536 																						.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
537 																						.build(vkd, device);
538 	const Move<VkDescriptorPool>			descriptorPool						= DescriptorPoolBuilder()
539 																						.addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
540 																						.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
541 																						.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
542 	const Move<VkDescriptorSet>				descriptorSet						= makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
543 	const Move<VkPipelineLayout>			pipelineLayout						= makePipelineLayout(vkd, device, descriptorSetLayout.get(), pushConstantsSize);
544 	const Move<VkCommandPool>				cmdPool								= createCommandPool(vkd, device, 0, queueFamilyIndex);
545 	const Move<VkCommandBuffer>				cmdBuffer							= allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
546 
547 	de::MovePtr<RayTracingPipeline>			rayTracingPipeline					= de::newMovePtr<RayTracingPipeline>();
548 	const Move<VkPipeline>					pipeline							= makePipeline(rayTracingPipeline, *pipelineLayout);
549 	const de::MovePtr<BufferWithMemory>		raygenShaderBindingTable			= createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_raygenShaderGroup, m_raygenShaderGroupCount);
550 	const de::MovePtr<BufferWithMemory>		missShaderBindingTable				= createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_missShaderGroup, m_missShaderGroupCount);
551 	const de::MovePtr<BufferWithMemory>		hitShaderBindingTable				= createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_hitShaderGroup, m_hitShaderGroupCount);
552 	const de::MovePtr<BufferWithMemory>		callableShaderBindingTable			= createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_callableShaderGroup, m_callableShaderGroupCount);
553 
554 	const VkStridedDeviceAddressRegionKHR	raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(raygenShaderBindingTable),   shaderGroupHandleSize, m_raygenShaderGroupCount);
555 	const VkStridedDeviceAddressRegionKHR	missShaderBindingTableRegion		= makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(missShaderBindingTable),     shaderGroupHandleSize, m_missShaderGroupCount);
556 	const VkStridedDeviceAddressRegionKHR	hitShaderBindingTableRegion			= makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(hitShaderBindingTable),      shaderGroupHandleSize, m_hitShaderGroupCount);
557 	const VkStridedDeviceAddressRegionKHR	callableShaderBindingTableRegion	= makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(callableShaderBindingTable), shaderGroupHandleSize, m_callableShaderGroupCount);
558 
559 	const VkImageCreateInfo					imageCreateInfo						= makeImageCreateInfo(m_data.width, m_data.height, m_depth, format);
560 	const VkImageSubresourceRange			imageSubresourceRange				= makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0, 1u);
561 	const de::MovePtr<ImageWithMemory>		image								= de::MovePtr<ImageWithMemory>(new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
562 	const Move<VkImageView>					imageView							= makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_3D, format, imageSubresourceRange);
563 
564 	const VkBufferCreateInfo				bufferCreateInfo					= makeBufferCreateInfo(pixelCount*sizeof(deUint32), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
565 	const VkImageSubresourceLayers			bufferImageSubresourceLayers		= makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
566 	const VkBufferImageCopy					bufferImageRegion					= makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, m_depth), bufferImageSubresourceLayers);
567 	de::MovePtr<BufferWithMemory>			buffer								= de::MovePtr<BufferWithMemory>(new BufferWithMemory(vkd, device, allocator, bufferCreateInfo, MemoryRequirement::HostVisible));
568 
569 	const VkDescriptorImageInfo				descriptorImageInfo					= makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
570 
571 	const VkImageMemoryBarrier				preImageBarrier						= makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT,
572 																					VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
573 																					**image, imageSubresourceRange);
574 	const VkImageMemoryBarrier				postImageBarrier					= makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT,
575 																					VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL,
576 																					**image, imageSubresourceRange);
577 	const VkMemoryBarrier					preTraceMemoryBarrier				= makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
578 	const VkMemoryBarrier					postTraceMemoryBarrier				= makeMemoryBarrier(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
579 	const VkMemoryBarrier					postCopyMemoryBarrier				= makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
580 	const VkClearValue						clearValue							= makeClearValueColorU32(DEFAULT_CLEAR_VALUE, 0u, 0u, 255u);
581 
582 	vector<de::SharedPtr<BottomLevelAccelerationStructure> >	bottomLevelAccelerationStructures;
583 	de::MovePtr<TopLevelAccelerationStructure>					topLevelAccelerationStructure;
584 
585 	DE_ASSERT(DE_LENGTH_OF_ARRAY(pushConstants) == PUSH_CONSTANTS_COUNT);
586 
587 	beginCommandBuffer(vkd, *cmdBuffer, 0u);
588 	{
589 		vkd.cmdPushConstants(*cmdBuffer, *pipelineLayout, ALL_RAY_TRACING_STAGES, 0, pushConstantsSize, &m_pushConstants);
590 
591 		cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
592 		vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1, &imageSubresourceRange);
593 		cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, ALL_RAY_TRACING_STAGES, &postImageBarrier);
594 
595 		bottomLevelAccelerationStructures = initBottomAccelerationStructures(*cmdBuffer);
596 		topLevelAccelerationStructure = initTopAccelerationStructure(*cmdBuffer, bottomLevelAccelerationStructures);
597 
598 		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, ALL_RAY_TRACING_STAGES, &preTraceMemoryBarrier);
599 
600 		const TopLevelAccelerationStructure*			topLevelAccelerationStructurePtr		= topLevelAccelerationStructure.get();
601 		VkWriteDescriptorSetAccelerationStructureKHR	accelerationStructureWriteDescriptorSet	=
602 		{
603 			VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,	//  VkStructureType						sType;
604 			DE_NULL,															//  const void*							pNext;
605 			1u,																	//  deUint32							accelerationStructureCount;
606 			topLevelAccelerationStructurePtr->getPtr(),							//  const VkAccelerationStructureKHR*	pAccelerationStructures;
607 		};
608 
609 		DescriptorSetUpdateBuilder()
610 			.writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
611 			.writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
612 			.update(vkd, device);
613 
614 		vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1, &descriptorSet.get(), 0, DE_NULL);
615 
616 		vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
617 
618 		cmdTraceRays(vkd,
619 			*cmdBuffer,
620 			&raygenShaderBindingTableRegion,
621 			&missShaderBindingTableRegion,
622 			&hitShaderBindingTableRegion,
623 			&callableShaderBindingTableRegion,
624 			m_data.width, m_data.height, 1);
625 
626 		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, ALL_RAY_TRACING_STAGES, VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
627 
628 		vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **buffer, 1u, &bufferImageRegion);
629 
630 		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT, &postCopyMemoryBarrier);
631 	}
632 	endCommandBuffer(vkd, *cmdBuffer);
633 
634 	submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
635 
636 	invalidateMappedMemoryRange(vkd, device, buffer->getAllocation().getMemory(), buffer->getAllocation().getOffset(), pixelCount * sizeof(deUint32));
637 
638 	return buffer;
639 }
640 
getExpectedValues(void) const641 std::vector<deUint32> RayTracingComplexControlFlowInstance::getExpectedValues (void) const
642 {
643 	const deUint32				plainSize		= m_data.width * m_data.height;
644 	const deUint32				plain8Ofs		= 8 * plainSize;
645 	const struct PushConstants&	p				= m_pushConstants;
646 	const deUint32				pushConstants[]	= { 0, m_pushConstants.a, m_pushConstants.b, m_pushConstants.c, m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss };
647 	const deUint32				resultSize		= plainSize * m_depth;
648 	const bool					fixed			= m_data.testOp == TEST_OP_REPORT_INTERSECTION;
649 	std::vector<deUint32>		result			(resultSize, DEFAULT_CLEAR_VALUE);
650 	deUint32					v0;
651 	deUint32					v1;
652 	deUint32					v2;
653 	deUint32					v3;
654 
655 	switch (m_data.testType)
656 	{
657 		case TEST_TYPE_IF:
658 		{
659 			for (deUint32 id = 0; id < plainSize; ++id)
660 			{
661 				v2 = v3 = p.b;
662 
663 				if ((p.a & id) != 0)
664 				{
665 					v0 = p.c & id;
666 					v1 = (p.d & id) + 1;
667 
668 					result[plain8Ofs + id] = v0;
669 					if (!fixed) v0++;
670 				}
671 				else
672 				{
673 					v0 = p.d & id;
674 					v1 = (p.c & id) + 1;
675 
676 					if (!fixed)
677 					{
678 						result[plain8Ofs + id] = v1;
679 						v1++;
680 					}
681 					else
682 						result[plain8Ofs + id] = v0;
683 				}
684 
685 				result[id] = v0 + v1 + v2 + v3;
686 			}
687 
688 			break;
689 		}
690 		case TEST_TYPE_LOOP:
691 		{
692 			for (deUint32 id = 0; id < plainSize; ++id)
693 			{
694 				result[id] = 0;
695 
696 				v1 = v3 = p.b;
697 
698 				for (deUint32 n = 0; n < p.a; n++)
699 				{
700 					v0 = (p.c & id) + n;
701 
702 					result[((n % 8) + 8) * plainSize + id] = v0;
703 					if (!fixed) v0++;
704 
705 					result[id] += v0 + v1 + v3;
706 				}
707 			}
708 
709 			break;
710 		}
711 		case TEST_TYPE_SWITCH:
712 		{
713 			for (deUint32 id = 0; id < plainSize; ++id)
714 			{
715 				switch (p.a & id)
716 				{
717 					case 0: { v1 = v2 = v3 = p.b; v0 = p.c & id; break; }
718 					case 1: { v0 = v2 = v3 = p.b; v1 = p.c & id; break; }
719 					case 2: { v0 = v1 = v3 = p.b; v2 = p.c & id; break; }
720 					case 3: { v0 = v1 = v2 = p.b; v3 = p.c & id; break; }
721 					default: { v0 = v1 = v2 = v3 = 0; break; }
722 				}
723 
724 				if (!fixed)
725 					result[plain8Ofs + id] = p.c & id;
726 				else
727 					result[plain8Ofs + id] = v0;
728 
729 				result[id] = v0 + v1 + v2 + v3;
730 
731 				if (!fixed) result[id]++;
732 			}
733 
734 			break;
735 		}
736 		case TEST_TYPE_LOOP_DOUBLE_CALL:
737 		{
738 			for (deUint32 id = 0; id < plainSize; ++id)
739 			{
740 				result[id] = 0;
741 
742 				v3 = p.b;
743 
744 				for (deUint32 x = 0; x < p.a; x++)
745 				{
746 					v0 = (p.c & id) + x;
747 					v1 = (p.d & id) + x + 1;
748 
749 					result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
750 					if (!fixed) v0++;
751 
752 					if (!fixed)
753 					{
754 						result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
755 						v1++;
756 					}
757 
758 					result[id] += v0 + v1 + v3;
759 				}
760 			}
761 
762 			break;
763 		}
764 		case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
765 		{
766 			for (deUint32 id = 0; id < plainSize; ++id)
767 			{
768 				result[id] = 0;
769 
770 				v3 = p.a + p.b;
771 
772 				for (deUint32 x = 0; x < p.a; x++)
773 				{
774 					if ((x & p.b) != 0)
775 					{
776 						v0 = (p.c & id) + x;
777 						v1 = (p.d & id) + x + 1;
778 
779 						result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
780 						if (!fixed) v0++;
781 
782 						if (!fixed)
783 						{
784 							result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
785 							v1++;
786 						}
787 
788 						result[id] += v0 + v1 + v3;
789 					}
790 				}
791 			}
792 
793 			break;
794 		}
795 		case TEST_TYPE_NESTED_LOOP:
796 		{
797 			for (deUint32 id = 0; id < plainSize; ++id)
798 			{
799 				result[id] = 0;
800 
801 				v1 = v3 = p.b;
802 
803 				for (deUint32 y = 0; y < p.a; y++)
804 				for (deUint32 x = 0; x < p.a; x++)
805 				{
806 					const deUint32 n = x + y * p.a;
807 
808 					if ((n & p.d) != 0)
809 					{
810 						v0 = (p.c & id) + n;
811 
812 						result[((n % 8) + 8) * plainSize + id] = v0;
813 						if (!fixed) v0++;
814 
815 						result[id] += v0 + v1 + v3;
816 					}
817 				}
818 			}
819 
820 			break;
821 		}
822 		case TEST_TYPE_NESTED_LOOP_BEFORE:
823 		{
824 			for (deUint32 id = 0; id < plainSize; ++id)
825 			{
826 				result[id] = 0;
827 
828 				for (deUint32 y = 0; y < p.d; y++)
829 				for (deUint32 x = 0; x < p.d; x++)
830 				{
831 					if (((x + y * p.a) & p.b) != 0)
832 						result[id] += (x + y);
833 				}
834 
835 				v1 = v3 = p.a;
836 
837 				for (deUint32 x = 0; x < p.b; x++)
838 				{
839 					if ((x & p.a) != 0)
840 					{
841 						v0 = p.c & id;
842 
843 						result[((x % 8) + 8) * plainSize + id] = v0;
844 						if (!fixed) v0++;
845 
846 						result[id] += v0 + v1 + v3;
847 					}
848 				}
849 			}
850 
851 			break;
852 		}
853 		case TEST_TYPE_NESTED_LOOP_AFTER:
854 		{
855 			for (deUint32 id = 0; id < plainSize; ++id)
856 			{
857 				result[id] = 0;
858 
859 				v1 = v3 = p.a;
860 
861 				for (deUint32 x = 0; x < p.b; x++)
862 				{
863 					if ((x & p.a) != 0)
864 					{
865 						v0 = p.c & id;
866 
867 						result[((x % 8) + 8) * plainSize + id] = v0;
868 						if (!fixed) v0++;
869 
870 						result[id] += v0 + v1 + v3;
871 					}
872 				}
873 
874 				for (deUint32 y = 0; y < p.d; y++)
875 				for (deUint32 x = 0; x < p.d; x++)
876 				{
877 					if (((x + y * p.a) & p.b) != 0)
878 						result[id] += (x + y);
879 				}
880 			}
881 
882 			break;
883 		}
884 		case TEST_TYPE_FUNCTION_CALL:
885 		{
886 			deUint32 a[42];
887 
888 			for (deUint32 id = 0; id < plainSize; ++id)
889 			{
890 				deUint32 r = 0;
891 				deUint32 i;
892 
893 				v0 = p.a & id;
894 				v1 = v3 = p.d;
895 
896 				for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
897 					a[i] = p.c * i;
898 
899 				result[plain8Ofs + id] = v0;
900 				if (!fixed) v0++;
901 
902 				for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
903 					r += a[i];
904 
905 				result[id] = (r + i) + v0 + v1 + v3;
906 			}
907 
908 			break;
909 		}
910 		case TEST_TYPE_NESTED_FUNCTION_CALL:
911 		{
912 			deUint32 a[14];
913 			deUint32 b[256];
914 
915 			for (deUint32 id = 0; id < plainSize; ++id)
916 			{
917 				deUint32 r = 0;
918 				deUint32 i;
919 				deUint32 t = 0;
920 				deUint32 j;
921 
922 				v0 = p.a & id;
923 				v3 = p.d;
924 
925 				for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
926 					b[j] = p.c * j;
927 
928 				v1 = p.b;
929 
930 				for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
931 					a[i] = p.c * i;
932 
933 				result[plain8Ofs + id] = v0;
934 				if (!fixed) v0++;
935 
936 				for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
937 					r += a[i];
938 
939 				for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
940 					t += b[j];
941 
942 				result[id] = (r + i) + (t + j) + v0 + v1 + v3;
943 			}
944 
945 			break;
946 		}
947 
948 		default:
949 			TCU_THROW(InternalError, "Unknown testType");
950 	}
951 
952 	{
953 		const deUint32	startOfs	= 7 * plainSize;
954 
955 		for (deUint32 n = 0; n < plainSize; ++n)
956 			result[startOfs + n] = n;
957 	}
958 
959 	for (deUint32 z = 1; z < DE_LENGTH_OF_ARRAY(pushConstants); ++z)
960 	{
961 		const deUint32	startOfs		= z * plainSize;
962 		const deUint32	pushConstant	= pushConstants[z];
963 
964 		for (deUint32 n = 0; n < plainSize; ++n)
965 			result[startOfs + n] = pushConstant;
966 	}
967 
968 	return result;
969 }
970 
iterate(void)971 tcu::TestStatus RayTracingComplexControlFlowInstance::iterate (void)
972 {
973 	const de::MovePtr<BufferWithMemory>	buffer		= runTest();
974 	const deUint32*						bufferPtr	= (deUint32*)buffer->getAllocation().getHostPtr();
975 	const vector<deUint32>				expected	= getExpectedValues();
976 	tcu::TestLog&						log			= m_context.getTestContext().getLog();
977 	deUint32							failures	= 0;
978 	deUint32							pos			= 0;
979 
980 	for (deUint32 z = 0; z < m_depth; ++z)
981 	for (deUint32 y = 0; y < m_data.height; ++y)
982 	for (deUint32 x = 0; x < m_data.width; ++x)
983 	{
984 		if (bufferPtr[pos] != expected[pos])
985 			failures++;
986 
987 		++pos;
988 	}
989 
990 	if (failures != 0)
991 	{
992 		deUint32			pos0	= 0;
993 		deUint32			pos1	= 0;
994 		std::stringstream	css;
995 
996 		for (deUint32 z = 0; z < m_depth; ++z)
997 		{
998 			css << "z=" << z << std::endl;
999 
1000 			for (deUint32 y = 0; y < m_data.height; ++y)
1001 			{
1002 				for (deUint32 x = 0; x < m_data.width; ++x)
1003 					css << std::setw(6) << bufferPtr[pos0++] << ' ';
1004 
1005 				css << "    ";
1006 
1007 				for (deUint32 x = 0; x < m_data.width; ++x)
1008 					css << std::setw(6) << expected[pos1++] << ' ';
1009 
1010 				css << std::endl;
1011 			}
1012 
1013 			css << std::endl;
1014 		}
1015 
1016 		log << tcu::TestLog::Message << css.str() << tcu::TestLog::EndMessage;
1017 	}
1018 
1019 	if (failures == 0)
1020 		return tcu::TestStatus::pass("Pass");
1021 	else
1022 		return tcu::TestStatus::fail("failures=" + de::toString(failures));
1023 }
1024 
1025 class ComplexControlFlowTestCase : public TestCase
1026 {
1027 	public:
1028 										ComplexControlFlowTestCase	(tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
1029 										~ComplexControlFlowTestCase	(void);
1030 
1031 	virtual	void						initPrograms				(SourceCollections& programCollection) const;
1032 	virtual TestInstance*				createInstance				(Context& context) const;
1033 	virtual void						checkSupport				(Context& context) const;
1034 
1035 private:
1036 	static inline const std::string		getIntersectionPassthrough	(void);
1037 	static inline const std::string		getMissPassthrough			(void);
1038 	static inline const std::string		getHitPassthrough			(void);
1039 
1040 	CaseDef								m_data;
1041 };
1042 
ComplexControlFlowTestCase(tcu::TestContext & context,const char * name,const char * desc,const CaseDef data)1043 ComplexControlFlowTestCase::ComplexControlFlowTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
1044 	: vkt::TestCase	(context, name, desc)
1045 	, m_data		(data)
1046 {
1047 }
1048 
~ComplexControlFlowTestCase(void)1049 ComplexControlFlowTestCase::~ComplexControlFlowTestCase	(void)
1050 {
1051 }
1052 
checkSupport(Context & context) const1053 void ComplexControlFlowTestCase::checkSupport (Context& context) const
1054 {
1055 	context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
1056 
1057 	const VkPhysicalDeviceAccelerationStructureFeaturesKHR&	accelerationStructureFeaturesKHR = context.getAccelerationStructureFeatures();
1058 
1059 	if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
1060 		TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
1061 
1062 	context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
1063 
1064 	const VkPhysicalDeviceRayTracingPipelineFeaturesKHR&	rayTracingPipelineFeaturesKHR = context.getRayTracingPipelineFeatures();
1065 
1066 	if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE)
1067 		TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1068 
1069 	const VkPhysicalDeviceRayTracingPipelinePropertiesKHR&	rayTracingPipelinePropertiesKHR = context.getRayTracingPipelineProperties();
1070 
1071 	if (m_data.testOp == TEST_OP_TRACE_RAY && m_data.stage != VK_SHADER_STAGE_RAYGEN_BIT_KHR)
1072 	{
1073 		if (rayTracingPipelinePropertiesKHR.maxRayRecursionDepth < 2)
1074 			TCU_THROW(NotSupportedError, "rayTracingPipelinePropertiesKHR.maxRayRecursionDepth is smaller than required");
1075 	}
1076 }
1077 
1078 
getIntersectionPassthrough(void)1079 const std::string ComplexControlFlowTestCase::getIntersectionPassthrough (void)
1080 {
1081 	const std::string intersectionPassthrough =
1082 		"#version 460 core\n"
1083 		"#extension GL_EXT_nonuniform_qualifier : enable\n"
1084 		"#extension GL_EXT_ray_tracing : require\n"
1085 		"hitAttributeEXT vec3 hitAttribute;\n"
1086 		"\n"
1087 		"void main()\n"
1088 		"{\n"
1089 		"  reportIntersectionEXT(0.95f, 0u);\n"
1090 		"}\n";
1091 
1092 	return intersectionPassthrough;
1093 }
1094 
getMissPassthrough(void)1095 const std::string ComplexControlFlowTestCase::getMissPassthrough (void)
1096 {
1097 	const std::string missPassthrough =
1098 		"#version 460 core\n"
1099 		"#extension GL_EXT_nonuniform_qualifier : enable\n"
1100 		"#extension GL_EXT_ray_tracing : require\n"
1101 		"layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1102 		"\n"
1103 		"void main()\n"
1104 		"{\n"
1105 		"}\n";
1106 
1107 	return missPassthrough;
1108 }
1109 
getHitPassthrough(void)1110 const std::string ComplexControlFlowTestCase::getHitPassthrough (void)
1111 {
1112 	const std::string hitPassthrough =
1113 		"#version 460 core\n"
1114 		"#extension GL_EXT_nonuniform_qualifier : enable\n"
1115 		"#extension GL_EXT_ray_tracing : require\n"
1116 		"hitAttributeEXT vec3 attribs;\n"
1117 		"layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1118 		"\n"
1119 		"void main()\n"
1120 		"{\n"
1121 		"}\n";
1122 
1123 	return hitPassthrough;
1124 }
1125 
initPrograms(SourceCollections & programCollection) const1126 void ComplexControlFlowTestCase::initPrograms (SourceCollections& programCollection) const
1127 {
1128 	const vk::ShaderBuildOptions	buildOptions			(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1129 	const std::string				calleeMainPart			=
1130 		"  uint z = (inValue.x % 8) + 8;\n"
1131 		"  uint v = inValue.y;\n"
1132 		"  uint n = gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y;\n"
1133 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, z), uvec4(v, 0, 0, 1));\n"
1134 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 7), uvec4(n, 0, 0, 1));\n";
1135 	const std::string				idTemplate				= "$";
1136 	const std::string				shaderCallInstruction	= (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)    ? "executeCallableEXT(0, " + idTemplate + ")"
1137 															: (m_data.testOp == TEST_OP_TRACE_RAY)           ? "traceRayEXT(as, 0, 0xFF, p.hitOfs, 0, p.miss, vec3((gl_LaunchIDEXT.x) + vec3(0.5f)) / vec3(gl_LaunchSizeEXT), 1.0f, vec3(0.0f, 0.0f, 1.0f), 100.0f, " + idTemplate + ")"
1138 															: (m_data.testOp == TEST_OP_REPORT_INTERSECTION) ? "reportIntersectionEXT(1.0f, 0u)"
1139 															: "TEST_OP_NOT_IMPLEMENTED_FAILURE";
1140 	std::string						declsPreMain			=
1141 		"#version 460 core\n"
1142 		"#extension GL_EXT_nonuniform_qualifier : enable\n"
1143 		"#extension GL_EXT_ray_tracing : require\n"
1144 		"\n"
1145 		"layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1146 		"layout(set = 0, binding = 1) uniform accelerationStructureEXT as;\n"
1147 		"\n"
1148 		"layout(push_constant) uniform TestParams\n"
1149 		"{\n"
1150 		"    uint a;\n"
1151 		"    uint b;\n"
1152 		"    uint c;\n"
1153 		"    uint d;\n"
1154 		"    uint hitOfs;\n"
1155 		"    uint miss;\n"
1156 		"} p;\n";
1157 	std::string						declsInMainBeforeOp		=
1158 		"  uint result = 0;\n"
1159 		"  uint id = uint(gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y);\n";
1160 	std::string						declsInMainAfterOp		=
1161 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 0), uvec4(result, 0, 0, 1));\n"
1162 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 1), uvec4(p.a, 0, 0, 1));\n"
1163 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 2), uvec4(p.b, 0, 0, 1));\n"
1164 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 3), uvec4(p.c, 0, 0, 1));\n"
1165 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 4), uvec4(p.d, 0, 0, 1));\n"
1166 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 5), uvec4(p.hitOfs, 0, 0, 1));\n"
1167 		"  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 6), uvec4(p.miss, 0, 0, 1));\n";
1168 	std::string						opInMain				= "";
1169 	std::string						opPreMain				= "";
1170 
1171 	DE_ASSERT(!declsPreMain.empty() && PUSH_CONSTANTS_COUNT == 6);
1172 
1173 	switch (m_data.testType)
1174 	{
1175 		case TEST_TYPE_IF:
1176 		{
1177 			opInMain =
1178 				"  v2 = v3 = uvec2(0, p.b);\n"
1179 				"\n"
1180 				"  if ((p.a & id) != 0)\n"
1181 				"      { v0 = uvec2(0, p.c & id); v1 = uvec2(0, (p.d & id) + 1);" + replace(shaderCallInstruction, idTemplate, "0") + "; }\n"
1182 				"  else\n"
1183 				"      { v0 = uvec2(0, p.d & id); v1 = uvec2(0, (p.c & id) + 1);" + replace(shaderCallInstruction, idTemplate, "1") + "; }\n"
1184 				"\n"
1185 				"  result = v0.y + v1.y + v2.y + v3.y;\n";
1186 
1187 			break;
1188 		}
1189 		case TEST_TYPE_LOOP:
1190 		{
1191 			opInMain =
1192 				"  v1 = v3 = uvec2(0, p.b);\n"
1193 				"\n"
1194 				"  for (uint x = 0; x < p.a; x++)\n"
1195 				"  {\n"
1196 				"    v0 = uvec2(x, (p.c & id) + x);\n"
1197 				"    " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1198 				"    result += v0.y + v1.y + v3.y;\n"
1199 				"  }\n";
1200 
1201 			break;
1202 		}
1203 		case TEST_TYPE_SWITCH:
1204 		{
1205 			opInMain =
1206 				"  switch (p.a & id)\n"
1207 				"  {\n"
1208 				"    case 0: { v1 = v2 = v3 = uvec2(0, p.b); v0 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "0") + "; break; }\n"
1209 				"    case 1: { v0 = v2 = v3 = uvec2(0, p.b); v1 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "1") + "; break; }\n"
1210 				"    case 2: { v0 = v1 = v3 = uvec2(0, p.b); v2 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "2") + "; break; }\n"
1211 				"    case 3: { v0 = v1 = v2 = uvec2(0, p.b); v3 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "3") + "; break; }\n"
1212 				"    default: break;\n"
1213 				"  }\n"
1214 				"\n"
1215 				"  result = v0.y + v1.y + v2.y + v3.y;\n";
1216 
1217 			break;
1218 		}
1219 		case TEST_TYPE_LOOP_DOUBLE_CALL:
1220 		{
1221 			opInMain =
1222 				"  v3 = uvec2(0, p.b);\n"
1223 				"  for (uint x = 0; x < p.a; x++)\n"
1224 				"  {\n"
1225 				"    v0 = uvec2(2 * x + 0, (p.c & id) + x);\n"
1226 				"    v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1227 				"    " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1228 				"    " + replace(shaderCallInstruction, idTemplate, "1") + ";\n"
1229 				"    result += v0.y + v1.y + v3.y;\n"
1230 				"  }\n";
1231 
1232 			break;
1233 		}
1234 		case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
1235 		{
1236 			opInMain =
1237 				"  v3 = uvec2(0, p.a + p.b);\n"
1238 				"  for (uint x = 0; x < p.a; x++)\n"
1239 				"    if ((x & p.b) != 0)\n"
1240 				"    {\n"
1241 				"      v0 = uvec2(2 * x + 0, (p.c & id) + x + 0);\n"
1242 				"      v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1243 				"      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1244 				"      " + replace(shaderCallInstruction, idTemplate, "1") + ";\n"
1245 				"      result += v0.y + v1.y + v3.y;\n"
1246 				"    }\n"
1247 				"\n";
1248 
1249 			break;
1250 		}
1251 		case TEST_TYPE_NESTED_LOOP:
1252 		{
1253 			opInMain =
1254 				"  v1 = v3 = uvec2(0, p.b);\n"
1255 				"  for (uint y = 0; y < p.a; y++)\n"
1256 				"  for (uint x = 0; x < p.a; x++)\n"
1257 				"  {\n"
1258 				"    uint n = x + y * p.a;\n"
1259 				"    if ((n & p.d) != 0)\n"
1260 				"    {\n"
1261 				"      v0 = uvec2(n, (p.c & id) + (x + y * p.a));\n"
1262 				"      "+ replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1263 				"      result += v0.y + v1.y + v3.y;\n"
1264 				"    }\n"
1265 				"  }\n"
1266 				"\n";
1267 
1268 			break;
1269 		}
1270 		case TEST_TYPE_NESTED_LOOP_BEFORE:
1271 		{
1272 			opInMain =
1273 				"  for (uint y = 0; y < p.d; y++)\n"
1274 				"  for (uint x = 0; x < p.d; x++)\n"
1275 				"    if (((x + y * p.a) & p.b) != 0)\n"
1276 				"      result += (x + y);\n"
1277 				"\n"
1278 				"  v1 = v3 = uvec2(0, p.a);\n"
1279 				"\n"
1280 				"  for (uint x = 0; x < p.b; x++)\n"
1281 				"    if ((x & p.a) != 0)\n"
1282 				"    {\n"
1283 				"      v0 = uvec2(x, p.c & id);\n"
1284 				"      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1285 				"      result += v0.y + v1.y + v3.y;\n"
1286 				"    }\n";
1287 
1288 			break;
1289 		}
1290 		case TEST_TYPE_NESTED_LOOP_AFTER:
1291 		{
1292 			opInMain =
1293 				"  v1 = v3 = uvec2(0, p.a); \n"
1294 				"  for (uint x = 0; x < p.b; x++)\n"
1295 				"    if ((x & p.a) != 0)\n"
1296 				"    {\n"
1297 				"      v0 = uvec2(x, p.c & id);\n"
1298 				"      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1299 				"      result += v0.y + v1.y + v3.y;\n"
1300 				"    }\n"
1301 				"\n"
1302 				"  for (uint y = 0; y < p.d; y++)\n"
1303 				"  for (uint x = 0; x < p.d; x++)\n"
1304 				"    if (((x + y * p.a) & p.b) != 0)\n"
1305 				"      result += x + y;\n";
1306 
1307 			break;
1308 		}
1309 		case TEST_TYPE_FUNCTION_CALL:
1310 		{
1311 			opPreMain =
1312 				"uint f1(void)\n"
1313 				"{\n"
1314 				"  uint i, r = 0;\n"
1315 				"  uint a[42];\n"
1316 				"\n"
1317 				"  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1318 				"\n"
1319 				"  " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1320 				"\n"
1321 				"  for (i = 0; i < a.length(); i++) r += a[i];\n"
1322 				"\n"
1323 				"  return r + i;\n"
1324 				"}\n";
1325 			opInMain =
1326 				"  v0 = uvec2(0, p.a & id); v1 = v3 = uvec2(0, p.d);\n"
1327 				"  result = f1() + v0.y + v1.y + v3.y;\n";
1328 
1329 			break;
1330 		}
1331 		case TEST_TYPE_NESTED_FUNCTION_CALL:
1332 		{
1333 			opPreMain =
1334 				"uint f0(void)\n"
1335 				"{\n"
1336 				"  uint i, r = 0;\n"
1337 				"  uint a[14];\n"
1338 				"\n"
1339 				"  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1340 				"\n"
1341 				"  " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1342 				"\n"
1343 				"  for (i = 0; i < a.length(); i++) r += a[i];\n"
1344 				"\n"
1345 				"  return r + i;\n"
1346 				"}\n"
1347 				"\n"
1348 				"uint f1(void)\n"
1349 				"{\n"
1350 				"  uint j, t = 0;\n"
1351 				"  uint b[256];\n"
1352 				"\n"
1353 				"  for (j = 0; j < b.length(); j++) b[j] = p.c * j;\n"
1354 				"\n"
1355 				"  v1 = uvec2(0, p.b);\n"
1356 				"\n"
1357 				"  t += f0();\n"
1358 				"\n"
1359 				"  for (j = 0; j < b.length(); j++) t += b[j];\n"
1360 				"\n"
1361 				"  return t + j;\n"
1362 				"}\n";
1363 			opInMain =
1364 				"  v0 = uvec2(0, p.a & id); v3 = uvec2(0, p.d);\n"
1365 				"  result = f1() + v0.y + v1.y + v3.y;\n";
1366 
1367 			break;
1368 		}
1369 
1370 		default:
1371 			TCU_THROW(InternalError, "Unknown testType");
1372 	}
1373 
1374 	if (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)
1375 	{
1376 		const std::string	calleeShader			=
1377 			"#version 460 core\n"
1378 			"#extension GL_EXT_nonuniform_qualifier : enable\n"
1379 			"#extension GL_EXT_ray_tracing : require\n"
1380 			"\n"
1381 			"layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1382 			"layout(location = 0) callableDataInEXT uvec2 inValue;\n"
1383 			"\n"
1384 			"void main()\n"
1385 			"{\n"
1386 			+ calleeMainPart +
1387 			"  inValue.y++;\n"
1388 			"}\n";
1389 
1390 		declsPreMain +=
1391 			"layout(location = 0) callableDataEXT uvec2 v0;\n"
1392 			"layout(location = 1) callableDataEXT uvec2 v1;\n"
1393 			"layout(location = 2) callableDataEXT uvec2 v2;\n"
1394 			"layout(location = 3) callableDataEXT uvec2 v3;\n"
1395 			"\n";
1396 
1397 		switch (m_data.stage)
1398 		{
1399 			case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1400 			{
1401 				std::stringstream css;
1402 				css << declsPreMain
1403 					<< opPreMain
1404 					<< "\n"
1405 					<< "void main()\n"
1406 					<< "{\n"
1407 					<< declsInMainBeforeOp
1408 					<< opInMain // executeCallableEXT
1409 					<< declsInMainAfterOp
1410 					<< "}\n";
1411 
1412 				programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1413 				programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1414 
1415 				break;
1416 			}
1417 
1418 			case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1419 			{
1420 				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1421 
1422 				std::stringstream css;
1423 				css << declsPreMain
1424 					<< "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1425 					<< "hitAttributeEXT vec3 attribs;\n"
1426 					<< "\n"
1427 					<< opPreMain
1428 					<< "\n"
1429 					<< "void main()\n"
1430 					<< "{\n"
1431 					<< declsInMainBeforeOp
1432 					<< opInMain // executeCallableEXT
1433 					<< declsInMainAfterOp
1434 					<< "}\n";
1435 
1436 				programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1437 				programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1438 
1439 				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1440 				programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1441 				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1442 
1443 				break;
1444 			}
1445 
1446 			case VK_SHADER_STAGE_MISS_BIT_KHR:
1447 			{
1448 				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1449 
1450 				std::stringstream css;
1451 				css << declsPreMain
1452 					<< opPreMain
1453 					<< "\n"
1454 					<< "void main()\n"
1455 					<< "{\n"
1456 					<< declsInMainBeforeOp
1457 					<< opInMain // executeCallableEXT
1458 					<< declsInMainAfterOp
1459 					<< "}\n";
1460 
1461 				programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1462 				programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1463 
1464 				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1465 				programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1466 				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1467 
1468 				break;
1469 			}
1470 
1471 			case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
1472 			{
1473 				{
1474 					std::stringstream css;
1475 					css << "#version 460 core\n"
1476 						<< "#extension GL_EXT_nonuniform_qualifier : enable\n"
1477 						<< "#extension GL_EXT_ray_tracing : require\n"
1478 						<< "\n"
1479 						<< "layout(location = 4) callableDataEXT float dummy;\n"
1480 						<< "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1481 						<< "\n"
1482 						<< "void main()\n"
1483 						<< "{\n"
1484 						<< "  executeCallableEXT(1, 4);\n"
1485 						<< "}\n";
1486 
1487 					programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1488 				}
1489 
1490 				{
1491 					std::stringstream css;
1492 					css << declsPreMain
1493 						<< "layout(location = 4) callableDataInEXT float dummyIn;\n"
1494 						<< opPreMain
1495 						<< "\n"
1496 						<< "void main()\n"
1497 						<< "{\n"
1498 						<< declsInMainBeforeOp
1499 						<< opInMain // executeCallableEXT
1500 						<< declsInMainAfterOp
1501 						<< "}\n";
1502 
1503 					programCollection.glslSources.add("call") << glu::CallableSource(css.str()) << buildOptions;
1504 				}
1505 
1506 				programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1507 
1508 				break;
1509 			}
1510 
1511 			default:
1512 				TCU_THROW(InternalError, "Unknown stage");
1513 		}
1514 	}
1515 	else if (m_data.testOp == TEST_OP_TRACE_RAY)
1516 	{
1517 		const std::string	missShader	=
1518 			"#version 460 core\n"
1519 			"#extension GL_EXT_nonuniform_qualifier : enable\n"
1520 			"#extension GL_EXT_ray_tracing : require\n"
1521 			"\n"
1522 			"layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1523 			"layout(location = 0) rayPayloadInEXT uvec2 inValue;\n"
1524 			"\n"
1525 			"void main()\n"
1526 			"{\n"
1527 			+ calleeMainPart +
1528 			"  inValue.y++;\n"
1529 			"}\n";
1530 
1531 		declsPreMain +=
1532 			"layout(location = 0) rayPayloadEXT uvec2 v0;\n"
1533 			"layout(location = 1) rayPayloadEXT uvec2 v1;\n"
1534 			"layout(location = 2) rayPayloadEXT uvec2 v2;\n"
1535 			"layout(location = 3) rayPayloadEXT uvec2 v3;\n";
1536 
1537 		switch (m_data.stage)
1538 		{
1539 			case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1540 			{
1541 				std::stringstream css;
1542 				css << declsPreMain
1543 					<< opPreMain
1544 					<< "\n"
1545 					<< "void main()\n"
1546 					<< "{\n"
1547 					<< declsInMainBeforeOp
1548 					<< opInMain // traceRayEXT
1549 					<< declsInMainAfterOp
1550 					<< "}\n";
1551 
1552 				programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1553 
1554 				programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1555 				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1556 				programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1557 				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1558 
1559 				programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1560 				programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1561 				programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1562 				programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1563 
1564 				break;
1565 			}
1566 
1567 			case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1568 			{
1569 				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1570 
1571 				std::stringstream css;
1572 				css << declsPreMain
1573 					<< opPreMain
1574 					<< "\n"
1575 					<< "void main()\n"
1576 					<< "{\n"
1577 					<< declsInMainBeforeOp
1578 					<< opInMain // traceRayEXT
1579 					<< declsInMainAfterOp
1580 					<< "}\n";
1581 
1582 				programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1583 
1584 				programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1585 				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1586 				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1587 
1588 				programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1589 				programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1590 				programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1591 				programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1592 
1593 				break;
1594 			}
1595 
1596 			case VK_SHADER_STAGE_MISS_BIT_KHR:
1597 			{
1598 				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1599 
1600 				std::stringstream css;
1601 				css << declsPreMain
1602 					<< opPreMain
1603 					<< "\n"
1604 					<< "void main()\n"
1605 					<< "{\n"
1606 					<< declsInMainBeforeOp
1607 					<< opInMain // traceRayEXT
1608 					<< declsInMainAfterOp
1609 					<< "}\n";
1610 
1611 				programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1612 
1613 				programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1614 				programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1615 				programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1616 
1617 				programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1618 				programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1619 				programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1620 				programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1621 
1622 				break;
1623 			}
1624 
1625 			default:
1626 				TCU_THROW(InternalError, "Unknown stage");
1627 		}
1628 	}
1629 	else if (m_data.testOp == TEST_OP_REPORT_INTERSECTION)
1630 	{
1631 		const std::string	anyHitShader		=
1632 			"#version 460 core\n"
1633 			"#extension GL_EXT_nonuniform_qualifier : enable\n"
1634 			"#extension GL_EXT_ray_tracing : require\n"
1635 			"\n"
1636 			"layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1637 			"hitAttributeEXT block { uvec2 inValue; };\n"
1638 			"\n"
1639 			"void main()\n"
1640 			"{\n"
1641 			+ calleeMainPart +
1642 			"}\n";
1643 
1644 		declsPreMain +=
1645 			"hitAttributeEXT block { uvec2 v0; };\n"
1646 			"uvec2 v1;\n"
1647 			"uvec2 v2;\n"
1648 			"uvec2 v3;\n";
1649 
1650 		switch (m_data.stage)
1651 		{
1652 			case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
1653 			{
1654 				programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1655 
1656 				std::stringstream css;
1657 				css << declsPreMain
1658 					<< opPreMain
1659 					<< "\n"
1660 					<< "void main()\n"
1661 					<< "{\n"
1662 					<< declsInMainBeforeOp
1663 					<< opInMain // reportIntersectionEXT
1664 					<< declsInMainAfterOp
1665 					<< "}\n";
1666 
1667 				programCollection.glslSources.add("sect") << glu::IntersectionSource(css.str()) << buildOptions;
1668 				programCollection.glslSources.add("ahit") << glu::AnyHitSource(anyHitShader) << buildOptions;
1669 
1670 				programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1671 				programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1672 
1673 				break;
1674 			}
1675 
1676 			default:
1677 				TCU_THROW(InternalError, "Unknown stage");
1678 		}
1679 	}
1680 	else
1681 	{
1682 		TCU_THROW(InternalError, "Unknown operation");
1683 	}
1684 }
1685 
createInstance(Context & context) const1686 TestInstance* ComplexControlFlowTestCase::createInstance (Context& context) const
1687 {
1688 	return new RayTracingComplexControlFlowInstance(context, m_data);
1689 }
1690 
1691 }	// anonymous
1692 
createComplexControlFlowTests(tcu::TestContext & testCtx)1693 tcu::TestCaseGroup*	createComplexControlFlowTests (tcu::TestContext& testCtx)
1694 {
1695 	const VkShaderStageFlagBits	R	= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
1696 	const VkShaderStageFlagBits	A	= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
1697 	const VkShaderStageFlagBits	C	= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
1698 	const VkShaderStageFlagBits	M	= VK_SHADER_STAGE_MISS_BIT_KHR;
1699 	const VkShaderStageFlagBits	I	= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
1700 	const VkShaderStageFlagBits	L	= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
1701 
1702 	DE_UNREF(A);
1703 
1704 	static const struct
1705 	{
1706 		const char*				name;
1707 		VkShaderStageFlagBits	stage;
1708 	}
1709 	testStages[]
1710 	{
1711 		{ "rgen", VK_SHADER_STAGE_RAYGEN_BIT_KHR		},
1712 		{ "chit", VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR	},
1713 		{ "ahit", VK_SHADER_STAGE_ANY_HIT_BIT_KHR		},
1714 		{ "sect", VK_SHADER_STAGE_INTERSECTION_BIT_KHR	},
1715 		{ "miss", VK_SHADER_STAGE_MISS_BIT_KHR			},
1716 		{ "call", VK_SHADER_STAGE_CALLABLE_BIT_KHR		},
1717 	};
1718 	static const struct
1719 	{
1720 		const char*			name;
1721 		TestOp				op;
1722 		VkShaderStageFlags	applicableInStages;
1723 	}
1724 	testOps[]
1725 	{
1726 		{ "execute_callable",		TEST_OP_EXECUTE_CALLABLE,		R |    C | M     | L },
1727 		{ "trace_ray",				TEST_OP_TRACE_RAY,				R |    C | M         },
1728 		{ "report_intersection",	TEST_OP_REPORT_INTERSECTION,	               I     },
1729 	};
1730 	static const struct
1731 	{
1732 		const char*	name;
1733 		TestType	testType;
1734 	}
1735 	testTypes[]
1736 	{
1737 		{ "if",							TEST_TYPE_IF						},
1738 		{ "loop",						TEST_TYPE_LOOP						},
1739 		{ "switch",						TEST_TYPE_SWITCH					},
1740 		{ "loop_double_call",			TEST_TYPE_LOOP_DOUBLE_CALL			},
1741 		{ "loop_double_call_sparse",	TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE	},
1742 		{ "nested_loop",				TEST_TYPE_NESTED_LOOP				},
1743 		{ "nested_loop_loop_before",	TEST_TYPE_NESTED_LOOP_BEFORE		},
1744 		{ "nested_loop_loop_after",		TEST_TYPE_NESTED_LOOP_AFTER			},
1745 		{ "function_call",				TEST_TYPE_FUNCTION_CALL				},
1746 		{ "nested_function_call",		TEST_TYPE_NESTED_FUNCTION_CALL		},
1747 	};
1748 
1749 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "complexcontrolflow", "Ray tracing complex control flow tests"));
1750 
1751 	for (size_t testTypeNdx = 0; testTypeNdx < DE_LENGTH_OF_ARRAY(testTypes); ++testTypeNdx)
1752 	{
1753 		const TestType					testType		= testTypes[testTypeNdx].testType;
1754 		de::MovePtr<tcu::TestCaseGroup> testTypeGroup	(new tcu::TestCaseGroup(testCtx, testTypes[testTypeNdx].name, ""));
1755 
1756 		for (size_t testOpNdx = 0; testOpNdx < DE_LENGTH_OF_ARRAY(testOps); ++testOpNdx)
1757 		{
1758 			const TestOp					testOp		= testOps[testOpNdx].op;
1759 			de::MovePtr<tcu::TestCaseGroup> testOpGroup	(new tcu::TestCaseGroup(testCtx, testOps[testOpNdx].name, ""));
1760 
1761 			for (size_t testStagesNdx = 0; testStagesNdx < DE_LENGTH_OF_ARRAY(testStages); ++testStagesNdx)
1762 			{
1763 				const VkShaderStageFlagBits	testStage				= testStages[testStagesNdx].stage;
1764 				const std::string			testName				= de::toString(testStages[testStagesNdx].name);
1765 				const deUint32				width					= 4u;
1766 				const deUint32				height					= 4u;
1767 				const CaseDef				caseDef					=
1768 				{
1769 					testType,				//  TestType				testType;
1770 					testOp,					//  TestOp					testOp;
1771 					testStage,				//  VkShaderStageFlagBits	stage;
1772 					width,					//  deUint32				width;
1773 					height,					//  deUint32				height;
1774 				};
1775 
1776 				if ((testOps[testOpNdx].applicableInStages & static_cast<VkShaderStageFlags>(testStage)) == 0)
1777 					continue;
1778 
1779 				testOpGroup->addChild(new ComplexControlFlowTestCase(testCtx, testName.c_str(), "", caseDef));
1780 			}
1781 
1782 			testTypeGroup->addChild(testOpGroup.release());
1783 		}
1784 
1785 		group->addChild(testTypeGroup.release());
1786 	}
1787 
1788 	return group.release();
1789 }
1790 
1791 }	// RayTracing
1792 }	// vkt
1793