• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2020 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *	  http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Ray Tracing Callable Shader tests
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vktRayTracingCallableShadersTests.hpp"
25 
26 #include "vkDefs.hpp"
27 
28 #include "vktTestCase.hpp"
29 #include "vktTestGroupUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkObjUtil.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkBufferWithMemory.hpp"
35 #include "vkImageWithMemory.hpp"
36 #include "vkTypeUtil.hpp"
37 #include "vkImageUtil.hpp"
38 #include "deRandom.hpp"
39 #include "tcuTexture.hpp"
40 #include "tcuTextureUtil.hpp"
41 #include "tcuTestLog.hpp"
42 #include "tcuImageCompare.hpp"
43 
44 #include "vkRayTracingUtil.hpp"
45 
46 namespace vkt
47 {
48 namespace RayTracing
49 {
50 namespace
51 {
52 using namespace vk;
53 using namespace vkt;
54 
55 static const VkFlags	ALL_RAY_TRACING_STAGES	= VK_SHADER_STAGE_RAYGEN_BIT_KHR
56 												| VK_SHADER_STAGE_ANY_HIT_BIT_KHR
57 												| VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
58 												| VK_SHADER_STAGE_MISS_BIT_KHR
59 												| VK_SHADER_STAGE_INTERSECTION_BIT_KHR
60 												| VK_SHADER_STAGE_CALLABLE_BIT_KHR;
61 
// Selects which shaders issue executeCallableEXT() and how many callable shaders take part.
enum CallableShaderTestType
{
	CSTT_RGEN_CALL		= 0,	// raygen calls a single callable shader
	CSTT_RGEN_CALL_CALL	= 1,	// raygen calls "call_call", which presumably chains to a second callable ("call_0")
	CSTT_HIT_CALL		= 2,	// closest-hit and miss shaders each call a callable shader
	CSTT_RGEN_MULTICALL	= 3,	// raygen calls four callables, each with a different payload type
	CSTT_COUNT					// sentinel: number of test types, not a valid test type
};
70 
71 const deUint32			TEST_WIDTH			= 8;
72 const deUint32			TEST_HEIGHT			= 8;
73 
74 struct TestParams;
75 
76 class TestConfiguration
77 {
78 public:
79 	virtual std::vector<de::SharedPtr<BottomLevelAccelerationStructure>>	initBottomAccelerationStructures	(Context&							context,
80 																												 TestParams&						testParams) = 0;
81 	virtual de::MovePtr<TopLevelAccelerationStructure>						initTopAccelerationStructure		(Context&							context,
82 																												 TestParams&						testParams,
83 																												 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >&	bottomLevelAccelerationStructures) = 0;
84 	virtual void															initRayTracingShaders				(de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
85 																												 Context&							context,
86 																												TestParams&							testParams) = 0;
87 	virtual void															initShaderBindingTables				(de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
88 																												 Context&							context,
89 																												 TestParams&						testParams,
90 																												 VkPipeline							pipeline,
91 																												 deUint32							shaderGroupHandleSize,
92 																												 deUint32							shaderGroupBaseAlignment,
93 																												 de::MovePtr<BufferWithMemory>&		raygenShaderBindingTable,
94 																												 de::MovePtr<BufferWithMemory>&		hitShaderBindingTable,
95 																												 de::MovePtr<BufferWithMemory>&		missShaderBindingTable,
96 																												 de::MovePtr<BufferWithMemory>&		callableShaderBindingTable,
97 																												 VkStridedDeviceAddressRegionKHR&	raygenShaderBindingTableRegion,
98 																												 VkStridedDeviceAddressRegionKHR&	hitShaderBindingTableRegion,
99 																												 VkStridedDeviceAddressRegionKHR&	missShaderBindingTableRegion,
100 																												 VkStridedDeviceAddressRegionKHR&	callableShaderBindingTableRegion) = 0;
101 	virtual bool															verifyImage							(BufferWithMemory*					resultBuffer,
102 																												 Context&							context,
103 																												 TestParams&						testParams) = 0;
104 	virtual VkFormat														getResultImageFormat				() = 0;
105 	virtual size_t															getResultImageFormatSize			() = 0;
106 	virtual VkClearValue													getClearValue						() = 0;
107 };
108 
109 struct TestParams
110 {
111 	deUint32							width;
112 	deUint32							height;
113 	CallableShaderTestType				callableShaderTestType;
114 	de::SharedPtr<TestConfiguration>	testConfiguration;
115 
116 };
117 
getShaderGroupHandleSize(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)118 deUint32 getShaderGroupHandleSize (const InstanceInterface&	vki,
119 								   const VkPhysicalDevice	physicalDevice)
120 {
121 	de::MovePtr<RayTracingProperties>	rayTracingPropertiesKHR;
122 
123 	rayTracingPropertiesKHR	= makeRayTracingProperties(vki, physicalDevice);
124 	return rayTracingPropertiesKHR->getShaderGroupHandleSize();
125 }
126 
getShaderGroupBaseAlignment(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)127 deUint32 getShaderGroupBaseAlignment (const InstanceInterface&	vki,
128 									  const VkPhysicalDevice	physicalDevice)
129 {
130 	de::MovePtr<RayTracingProperties>	rayTracingPropertiesKHR;
131 
132 	rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
133 	return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
134 }
135 
makeImageCreateInfo(deUint32 width,deUint32 height,VkFormat format)136 VkImageCreateInfo makeImageCreateInfo (deUint32 width, deUint32 height, VkFormat format)
137 {
138 	const VkImageCreateInfo			imageCreateInfo			=
139 	{
140 		VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO,																// VkStructureType			sType;
141 		DE_NULL,																							// const void*				pNext;
142 		(VkImageCreateFlags)0u,																				// VkImageCreateFlags		flags;
143 		VK_IMAGE_TYPE_2D,																					// VkImageType				imageType;
144 		format,																								// VkFormat					format;
145 		makeExtent3D(width, height, 1),																		// VkExtent3D				extent;
146 		1u,																									// deUint32					mipLevels;
147 		1u,																									// deUint32					arrayLayers;
148 		VK_SAMPLE_COUNT_1_BIT,																				// VkSampleCountFlagBits	samples;
149 		VK_IMAGE_TILING_OPTIMAL,																			// VkImageTiling			tiling;
150 		VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT,		// VkImageUsageFlags		usage;
151 		VK_SHARING_MODE_EXCLUSIVE,																			// VkSharingMode			sharingMode;
152 		0u,																									// deUint32					queueFamilyIndexCount;
153 		DE_NULL,																							// const deUint32*			pQueueFamilyIndices;
154 		VK_IMAGE_LAYOUT_UNDEFINED																			// VkImageLayout			initialLayout;
155 	};
156 
157 	return imageCreateInfo;
158 }
159 
160 class SingleSquareConfiguration : public TestConfiguration
161 {
162 public:
163 	std::vector<de::SharedPtr<BottomLevelAccelerationStructure>>	initBottomAccelerationStructures	(Context&							context,
164 																										 TestParams&						testParams) override;
165 	de::MovePtr<TopLevelAccelerationStructure>						initTopAccelerationStructure		(Context&							context,
166 																										 TestParams&						testParams,
167 																										 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >&	bottomLevelAccelerationStructures) override;
168 	void															initRayTracingShaders				(de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
169 																										 Context&							context,
170 																										 TestParams&						testParams) override;
171 	void															initShaderBindingTables				(de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
172 																										 Context&							context,
173 																										 TestParams&						testParams,
174 																										 VkPipeline							pipeline,
175 																										 deUint32							shaderGroupHandleSize,
176 																										 deUint32							shaderGroupBaseAlignment,
177 																										 de::MovePtr<BufferWithMemory>&		raygenShaderBindingTable,
178 																										 de::MovePtr<BufferWithMemory>&		hitShaderBindingTable,
179 																										 de::MovePtr<BufferWithMemory>&		missShaderBindingTable,
180 																										 de::MovePtr<BufferWithMemory>&		callableShaderBindingTable,
181 																										 VkStridedDeviceAddressRegionKHR&	raygenShaderBindingTableRegion,
182 																										 VkStridedDeviceAddressRegionKHR&	hitShaderBindingTableRegion,
183 																										 VkStridedDeviceAddressRegionKHR&	missShaderBindingTableRegion,
184 																										 VkStridedDeviceAddressRegionKHR&	callableShaderBindingTableRegion) override;
185 	bool															verifyImage							(BufferWithMemory*					resultBuffer,
186 																										 Context&							context,
187 																										 TestParams&						testParams) override;
188 	VkFormat														getResultImageFormat				() override;
189 	size_t															getResultImageFormatSize			() override;
190 	VkClearValue													getClearValue						() override;
191 };
192 
initBottomAccelerationStructures(Context & context,TestParams & testParams)193 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> > SingleSquareConfiguration::initBottomAccelerationStructures (Context&			context,
194 																														   TestParams&		testParams)
195 {
196 	DE_UNREF(context);
197 
198 	tcu::Vec3 v0(1.0, float(testParams.height) - 1.0f, 0.0);
199 	tcu::Vec3 v1(1.0, 1.0, 0.0);
200 	tcu::Vec3 v2(float(testParams.width) - 1.0f, float(testParams.height) - 1.0f, 0.0);
201 	tcu::Vec3 v3(float(testParams.width) - 1.0f, 1.0, 0.0);
202 
203 	std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >	result;
204 	de::MovePtr<BottomLevelAccelerationStructure>					bottomLevelAccelerationStructure	= makeBottomLevelAccelerationStructure();
205 	bottomLevelAccelerationStructure->setGeometryCount(1);
206 
207 	de::SharedPtr<RaytracedGeometryBase> geometry = makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
208 	geometry->addVertex(v0);
209 	geometry->addVertex(v1);
210 	geometry->addVertex(v2);
211 	geometry->addVertex(v2);
212 	geometry->addVertex(v1);
213 	geometry->addVertex(v3);
214 	bottomLevelAccelerationStructure->addGeometry(geometry);
215 
216 	result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
217 
218 	return result;
219 }
220 
initTopAccelerationStructure(Context & context,TestParams & testParams,std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelAccelerationStructures)221 de::MovePtr<TopLevelAccelerationStructure> SingleSquareConfiguration::initTopAccelerationStructure (Context&		context,
222 																									TestParams&		testParams,
223 																									std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >& bottomLevelAccelerationStructures)
224 {
225 	DE_UNREF(context);
226 	DE_UNREF(testParams);
227 
228 	de::MovePtr<TopLevelAccelerationStructure>	result						= makeTopLevelAccelerationStructure();
229 	result->setInstanceCount(1);
230 	result->addInstance(bottomLevelAccelerationStructures[0]);
231 
232 	return result;
233 }
234 
initRayTracingShaders(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,Context & context,TestParams & testParams)235 void SingleSquareConfiguration::initRayTracingShaders (de::MovePtr<RayTracingPipeline>&		rayTracingPipeline,
236 													   Context&								context,
237 													   TestParams&							testParams)
238 {
239 	const DeviceInterface&						vkd						= context.getDeviceInterface();
240 	const VkDevice								device					= context.getDevice();
241 
242 	switch (testParams.callableShaderTestType)
243 	{
244 		case CSTT_RGEN_CALL:
245 		{
246 			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_call"), 0), 0);
247 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
248 			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
249 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
250 			break;
251 		}
252 		case CSTT_RGEN_CALL_CALL:
253 		{
254 			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_call"), 0), 0);
255 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
256 			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
257 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_call"), 0), 3);
258 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 4);
259 			break;
260 		}
261 		case CSTT_HIT_CALL:
262 		{
263 			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("rgen"), 0), 0);
264 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vkd, device, context.getBinaryCollection().get("chit_call"), 0), 1);
265 			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vkd, device, context.getBinaryCollection().get("miss_call"), 0), 2);
266 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
267 			break;
268 		}
269 		case CSTT_RGEN_MULTICALL:
270 		{
271 			rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_multicall"), 0), 0);
272 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,	createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
273 			rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,			createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
274 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
275 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_1"), 0), 4);
276 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_2"), 0), 5);
277 			rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,		createShaderModule(vkd, device, context.getBinaryCollection().get("call_3"), 0), 6);
278 			break;
279 		}
280 		default:
281 			TCU_THROW(InternalError, "Wrong shader test type");
282 	}
283 }
284 
initShaderBindingTables(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,Context & context,TestParams & testParams,VkPipeline pipeline,deUint32 shaderGroupHandleSize,deUint32 shaderGroupBaseAlignment,de::MovePtr<BufferWithMemory> & raygenShaderBindingTable,de::MovePtr<BufferWithMemory> & hitShaderBindingTable,de::MovePtr<BufferWithMemory> & missShaderBindingTable,de::MovePtr<BufferWithMemory> & callableShaderBindingTable,VkStridedDeviceAddressRegionKHR & raygenShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & hitShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & missShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & callableShaderBindingTableRegion)285 void SingleSquareConfiguration::initShaderBindingTables (de::MovePtr<RayTracingPipeline>&	rayTracingPipeline,
286 														 Context&							context,
287 														 TestParams&						testParams,
288 														 VkPipeline							pipeline,
289 														 deUint32							shaderGroupHandleSize,
290 														 deUint32							shaderGroupBaseAlignment,
291 														 de::MovePtr<BufferWithMemory>&		raygenShaderBindingTable,
292 														 de::MovePtr<BufferWithMemory>&		hitShaderBindingTable,
293 														 de::MovePtr<BufferWithMemory>&		missShaderBindingTable,
294 														 de::MovePtr<BufferWithMemory>&		callableShaderBindingTable,
295 														 VkStridedDeviceAddressRegionKHR&	raygenShaderBindingTableRegion,
296 														 VkStridedDeviceAddressRegionKHR&	hitShaderBindingTableRegion,
297 														 VkStridedDeviceAddressRegionKHR&	missShaderBindingTableRegion,
298 														 VkStridedDeviceAddressRegionKHR&	callableShaderBindingTableRegion)
299 {
300 	const DeviceInterface&						vkd							= context.getDeviceInterface();
301 	const VkDevice								device						= context.getDevice();
302 	Allocator&									allocator					= context.getDefaultAllocator();
303 
304 	switch (testParams.callableShaderTestType)
305 	{
306 		case CSTT_RGEN_CALL:
307 		{
308 			raygenShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
309 			hitShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
310 			missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
311 			callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
312 
313 			raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
314 			hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
315 			missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
316 			callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
317 			break;
318 		}
319 		case CSTT_RGEN_CALL_CALL:
320 		{
321 			raygenShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
322 			hitShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
323 			missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
324 			callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 2);
325 
326 			raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
327 			hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
328 			missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
329 			callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, 2*shaderGroupHandleSize);
330 			break;
331 		}
332 		case CSTT_HIT_CALL:
333 		{
334 			raygenShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
335 			hitShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
336 			missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
337 			callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
338 
339 			raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
340 			hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
341 			missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
342 			callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
343 			break;
344 		}
345 		case CSTT_RGEN_MULTICALL:
346 		{
347 			raygenShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
348 			hitShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
349 			missShaderBindingTable				= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
350 			callableShaderBindingTable			= rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 4);
351 
352 			raygenShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
353 			hitShaderBindingTableRegion			= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
354 			missShaderBindingTableRegion		= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
355 			callableShaderBindingTableRegion	= makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, 4*shaderGroupHandleSize);
356 			break;
357 		}
358 		default:
359 			TCU_THROW(InternalError, "Wrong shader test type");
360 	}
361 }
362 
verifyImage(BufferWithMemory * resultBuffer,Context & context,TestParams & testParams)363 bool SingleSquareConfiguration::verifyImage (BufferWithMemory* resultBuffer, Context& context, TestParams& testParams)
364 {
365 	// create result image
366 	tcu::TextureFormat			imageFormat						= vk::mapVkFormat(getResultImageFormat());
367 	tcu::ConstPixelBufferAccess	resultAccess(imageFormat, testParams.width, testParams.height, 1, resultBuffer->getAllocation().getHostPtr());
368 
369 	// create reference image
370 	std::vector<deUint32>		reference(testParams.width * testParams.height);
371 	tcu::PixelBufferAccess		referenceAccess(imageFormat, testParams.width, testParams.height, 1, reference.data());
372 
373 	tcu::UVec4 missValue, hitValue;
374 
375 	// clear reference image with hit and miss values ( hit works only for tests calling traceRayEXT in rgen shader )
376 	switch (testParams.callableShaderTestType)
377 	{
378 		case CSTT_RGEN_CALL:
379 			missValue	= tcu::UVec4(1, 0, 0, 0);
380 			hitValue	= tcu::UVec4(1, 0, 0, 0);
381 			break;
382 		case CSTT_RGEN_CALL_CALL:
383 			missValue	= tcu::UVec4(1, 0, 0, 0);
384 			hitValue	= tcu::UVec4(1, 0, 0, 0);
385 			break;
386 		case CSTT_HIT_CALL:
387 			missValue	= tcu::UVec4(1, 0, 0, 0);
388 			hitValue	= tcu::UVec4(2, 0, 0, 0);
389 			break;
390 		case CSTT_RGEN_MULTICALL:
391 			missValue	= tcu::UVec4(16, 0, 0, 0);
392 			hitValue	= tcu::UVec4(16, 0, 0, 0);
393 			break;
394 		default:
395 			TCU_THROW(InternalError, "Wrong shader test type");
396 	}
397 
398 	tcu::clear(referenceAccess, missValue);
399 	for (deUint32 y = 1; y < testParams.width - 1; ++y)
400 	for (deUint32 x = 1; x < testParams.height - 1; ++x)
401 		referenceAccess.setPixel(hitValue, x, y);
402 
403 	// compare result and reference
404 	return tcu::intThresholdCompare(context.getTestContext().getLog(), "Result comparison", "", referenceAccess, resultAccess, tcu::UVec4(0), tcu::COMPARE_LOG_RESULT);
405 }
406 
getResultImageFormat()407 VkFormat SingleSquareConfiguration::getResultImageFormat ()
408 {
409 	return VK_FORMAT_R32_UINT;
410 }
411 
getResultImageFormatSize()412 size_t SingleSquareConfiguration::getResultImageFormatSize ()
413 {
414 	return sizeof(deUint32);
415 }
416 
getClearValue()417 VkClearValue SingleSquareConfiguration::getClearValue ()
418 {
419 	return makeClearValueColorU32(0xFF, 0u, 0u, 0u);
420 }
421 
422 class CallableShaderTestCase : public TestCase
423 {
424 	public:
425 							CallableShaderTestCase			(tcu::TestContext& context, const char* name, const char* desc, const TestParams data);
426 							~CallableShaderTestCase			(void);
427 
428 	virtual void			checkSupport								(Context& context) const;
429 	virtual	void			initPrograms								(SourceCollections& programCollection) const;
430 	virtual TestInstance*	createInstance								(Context& context) const;
431 private:
432 	TestParams				m_data;
433 };
434 
435 class CallableShaderTestInstance : public TestInstance
436 {
437 public:
438 																	CallableShaderTestInstance	(Context& context, const TestParams& data);
439 																	~CallableShaderTestInstance	(void);
440 	tcu::TestStatus													iterate									(void);
441 
442 protected:
443 	de::MovePtr<BufferWithMemory>									runTest									();
444 private:
445 	TestParams														m_data;
446 };
447 
CallableShaderTestCase(tcu::TestContext & context,const char * name,const char * desc,const TestParams data)448 CallableShaderTestCase::CallableShaderTestCase (tcu::TestContext& context, const char* name, const char* desc, const TestParams data)
449 	: vkt::TestCase	(context, name, desc)
450 	, m_data		(data)
451 {
452 }
453 
~CallableShaderTestCase(void)454 CallableShaderTestCase::~CallableShaderTestCase (void)
455 {
456 }
457 
checkSupport(Context & context) const458 void CallableShaderTestCase::checkSupport (Context& context) const
459 {
460 	context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
461 	context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
462 
463 	const VkPhysicalDeviceRayTracingPipelineFeaturesKHR&	rayTracingPipelineFeaturesKHR		= context.getRayTracingPipelineFeatures();
464 	if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE )
465 		TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
466 
467 	const VkPhysicalDeviceAccelerationStructureFeaturesKHR&	accelerationStructureFeaturesKHR	= context.getAccelerationStructureFeatures();
468 	if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
469 		TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
470 }
471 
initPrograms(SourceCollections & programCollection) const472 void CallableShaderTestCase::initPrograms (SourceCollections& programCollection) const
473 {
474 	const vk::ShaderBuildOptions	buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
475 	{
476 		std::stringstream css;
477 		css <<
478 			"#version 460 core\n"
479 			"#extension GL_EXT_ray_tracing : require\n"
480 			"layout(location = 0) rayPayloadEXT uvec4 hitValue;\n"
481 			"layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
482 			"layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
483 			"\n"
484 			"void main()\n"
485 			"{\n"
486 			"  float tmin     = 0.0;\n"
487 			"  float tmax     = 1.0;\n"
488 			"  vec3  origin   = vec3(float(gl_LaunchIDEXT.x) + 0.5f, float(gl_LaunchIDEXT.y) + 0.5f, 0.5f);\n"
489 			"  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
490 			"  hitValue       = uvec4(0,0,0,0);\n"
491 			"  traceRayEXT(topLevelAS, 0, 0xFF, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
492 			"  imageStore(result, ivec2(gl_LaunchIDEXT.xy), hitValue);\n"
493 			"}\n";
494 		programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
495 	}
496 
497 	{
498 		std::stringstream css;
499 		css <<
500 			"#version 460 core\n"
501 			"#extension GL_EXT_ray_tracing : require\n"
502 			"layout(location = 0) callableDataEXT uvec4 value;\n"
503 			"layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
504 			"layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
505 			"\n"
506 			"void main()\n"
507 			"{\n"
508 			"  executeCallableEXT(0, 0);\n"
509 			"  imageStore(result, ivec2(gl_LaunchIDEXT.xy), value);\n"
510 			"}\n";
511 		programCollection.glslSources.add("rgen_call") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
512 	}
513 
514 	{
515 		std::stringstream css;
516 		css <<
517 			"#version 460 core\n"
518 			"#extension GL_EXT_ray_tracing : require\n"
519 			"struct CallValue\n"
520 			"{\n"
521 			"  ivec4 a;\n"
522 			"  vec4  b;\n"
523 			"};\n"
524 			"layout(location = 0) callableDataEXT uvec4 value0;\n"
525 			"layout(location = 1) callableDataEXT uint value1;\n"
526 			"layout(location = 2) callableDataEXT CallValue value2;\n"
527 			"layout(location = 4) callableDataEXT vec3 value3;\n"
528 			"layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
529 			"layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
530 			"\n"
531 			"void main()\n"
532 			"{\n"
533 			"  executeCallableEXT(0, 0);\n"
534 			"  executeCallableEXT(1, 1);\n"
535 			"  executeCallableEXT(2, 2);\n"
536 			"  executeCallableEXT(3, 4);\n"
537 			"  uint resultValue = value0.x + value1 + value2.a.x * uint(floor(value2.b.y)) + uint(floor(value3.z));\n"
538 			"  imageStore(result, ivec2(gl_LaunchIDEXT.xy), uvec4(resultValue, 0, 0, 0));\n"
539 			"}\n";
540 		programCollection.glslSources.add("rgen_multicall") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
541 	}
542 
543 	{
544 		std::stringstream css;
545 		css <<
546 			"#version 460 core\n"
547 			"#extension GL_EXT_ray_tracing : require\n"
548 			"layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
549 			"void main()\n"
550 			"{\n"
551 			"  hitValue = uvec4(1,0,0,1);\n"
552 			"}\n";
553 
554 		programCollection.glslSources.add("chit") << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
555 	}
556 
557 	{
558 		std::stringstream css;
559 		css <<
560 			"#version 460 core\n"
561 			"#extension GL_EXT_ray_tracing : require\n"
562 			"layout(location = 0) callableDataEXT uvec4 value;\n"
563 			"layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
564 			"void main()\n"
565 			"{\n"
566 			"  executeCallableEXT(0, 0);\n"
567 			"  hitValue = value;\n"
568 			"  hitValue.x = hitValue.x + 1;\n"
569 			"}\n";
570 
571 		programCollection.glslSources.add("chit_call") << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
572 	}
573 
574 	{
575 		std::stringstream css;
576 		css <<
577 			"#version 460 core\n"
578 			"#extension GL_EXT_ray_tracing : require\n"
579 			"layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
580 			"void main()\n"
581 			"{\n"
582 			"  hitValue = uvec4(0,0,0,1);\n"
583 			"}\n";
584 
585 		programCollection.glslSources.add("miss") << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
586 	}
587 
588 	{
589 		std::stringstream css;
590 		css <<
591 			"#version 460 core\n"
592 			"#extension GL_EXT_ray_tracing : require\n"
593 			"layout(location = 0) callableDataEXT uvec4 value;\n"
594 			"layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
595 			"void main()\n"
596 			"{\n"
597 			"  executeCallableEXT(0, 0);\n"
598 			"  hitValue = value;\n"
599 			"}\n";
600 
601 		programCollection.glslSources.add("miss_call") << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
602 	}
603 
604 	std::vector<std::string> callableDataDefinition =
605 	{
606 		"layout(location = 0) callableDataInEXT uvec4 result;\n",
607 		"layout(location = 1) callableDataInEXT uint result;\n",
608 		"struct CallValue\n{\n  ivec4 a;\n  vec4  b;\n};\nlayout(location = 2) callableDataInEXT CallValue result;\n",
609 		"layout(location = 4) callableDataInEXT vec3 result;\n"
610 	};
611 
612 	std::vector<std::string> callableDataComputation =
613 	{
614 		"  result = uvec4(1,0,0,1);\n",
615 		"  result = 2;\n",
616 		"  result.a = ivec4(3,0,0,1);\n  result.b = vec4(1.0, 3.2, 0.0, 1);\n",
617 		"  result = vec3(0.0, 0.0, 4.3);\n",
618 	};
619 
620 	for (deUint32 idx = 0; idx < callableDataDefinition.size(); ++idx)
621 	{
622 		std::stringstream css;
623 		css <<
624 			"#version 460 core\n"
625 			"#extension GL_EXT_ray_tracing : require\n"
626 			<< callableDataDefinition[idx] <<
627 			"void main()\n"
628 			"{\n"
629 			<< callableDataComputation[idx] <<
630 			"}\n";
631 		std::stringstream csname;
632 		csname << "call_" << idx;
633 
634 		programCollection.glslSources.add(csname.str()) << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
635 	}
636 
637 	{
638 		std::stringstream css;
639 		css <<
640 			"#version 460 core\n"
641 			"#extension GL_EXT_ray_tracing : require\n"
642 			"layout(location = 0) callableDataInEXT uvec4 result;\n"
643 			"layout(location = 1) callableDataEXT uvec4 info;\n"
644 			"void main()\n"
645 			"{\n"
646 			"  executeCallableEXT(1, 1);\n"
647 			"  result = info;\n"
648 			"}\n";
649 
650 		programCollection.glslSources.add("call_call") << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
651 	}
652 }
653 
createInstance(Context & context) const654 TestInstance* CallableShaderTestCase::createInstance (Context& context) const
655 {
656 	return new CallableShaderTestInstance(context, m_data);
657 }
658 
CallableShaderTestInstance(Context & context,const TestParams & data)659 CallableShaderTestInstance::CallableShaderTestInstance (Context& context, const TestParams& data)
660 	: vkt::TestInstance		(context)
661 	, m_data				(data)
662 {
663 }
664 
~CallableShaderTestInstance(void)665 CallableShaderTestInstance::~CallableShaderTestInstance (void)
666 {
667 }
668 
runTest()669 de::MovePtr<BufferWithMemory> CallableShaderTestInstance::runTest ()
670 {
671 	const InstanceInterface&			vki									= m_context.getInstanceInterface();
672 	const DeviceInterface&				vkd									= m_context.getDeviceInterface();
673 	const VkDevice						device								= m_context.getDevice();
674 	const VkPhysicalDevice				physicalDevice						= m_context.getPhysicalDevice();
675 	const deUint32						queueFamilyIndex					= m_context.getUniversalQueueFamilyIndex();
676 	const VkQueue						queue								= m_context.getUniversalQueue();
677 	Allocator&							allocator							= m_context.getDefaultAllocator();
678 	const deUint32						pixelCount							= m_data.width * m_data.height * 1;
679 
680 	const Move<VkDescriptorSetLayout>	descriptorSetLayout					= DescriptorSetLayoutBuilder()
681 																					.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
682 																					.addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
683 																					.build(vkd, device);
684 	const Move<VkDescriptorPool>		descriptorPool						= DescriptorPoolBuilder()
685 																					.addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
686 																					.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
687 																					.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
688 	const Move<VkDescriptorSet>			descriptorSet						= makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
689 	const Move<VkPipelineLayout>		pipelineLayout						= makePipelineLayout(vkd, device, descriptorSetLayout.get());
690 
691 	de::MovePtr<RayTracingPipeline>		rayTracingPipeline					= de::newMovePtr<RayTracingPipeline>();
692 	m_data.testConfiguration->initRayTracingShaders(rayTracingPipeline, m_context, m_data);
693 	Move<VkPipeline>					pipeline							= rayTracingPipeline->createPipeline(vkd, device, *pipelineLayout);
694 
695 	de::MovePtr<BufferWithMemory>		raygenShaderBindingTable;
696 	de::MovePtr<BufferWithMemory>		hitShaderBindingTable;
697 	de::MovePtr<BufferWithMemory>		missShaderBindingTable;
698 	de::MovePtr<BufferWithMemory>		callableShaderBindingTable;
699 	VkStridedDeviceAddressRegionKHR		raygenShaderBindingTableRegion;
700 	VkStridedDeviceAddressRegionKHR		hitShaderBindingTableRegion;
701 	VkStridedDeviceAddressRegionKHR		missShaderBindingTableRegion;
702 	VkStridedDeviceAddressRegionKHR		callableShaderBindingTableRegion;
703 	m_data.testConfiguration->initShaderBindingTables(rayTracingPipeline, m_context, m_data, *pipeline, getShaderGroupHandleSize(vki, physicalDevice), getShaderGroupBaseAlignment(vki, physicalDevice), raygenShaderBindingTable, hitShaderBindingTable, missShaderBindingTable, callableShaderBindingTable, raygenShaderBindingTableRegion, hitShaderBindingTableRegion, missShaderBindingTableRegion, callableShaderBindingTableRegion);
704 
705 	const VkFormat						imageFormat							= m_data.testConfiguration->getResultImageFormat();
706 	const VkImageCreateInfo				imageCreateInfo						= makeImageCreateInfo(m_data.width, m_data.height, imageFormat);
707 	const VkImageSubresourceRange		imageSubresourceRange				= makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
708 	const de::MovePtr<ImageWithMemory>	image								= de::MovePtr<ImageWithMemory>(new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
709 	const Move<VkImageView>				imageView							= makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_2D, imageFormat, imageSubresourceRange);
710 
711 	const VkBufferCreateInfo			resultBufferCreateInfo				= makeBufferCreateInfo(pixelCount*m_data.testConfiguration->getResultImageFormatSize(), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
712 	const VkImageSubresourceLayers		resultBufferImageSubresourceLayers	= makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
713 	const VkBufferImageCopy				resultBufferImageRegion				= makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, 1), resultBufferImageSubresourceLayers);
714 	de::MovePtr<BufferWithMemory>		resultBuffer						= de::MovePtr<BufferWithMemory>(new BufferWithMemory(vkd, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));
715 
716 	const VkDescriptorImageInfo			descriptorImageInfo					= makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
717 
718 	const Move<VkCommandPool>			cmdPool								= createCommandPool(vkd, device, 0, queueFamilyIndex);
719 	const Move<VkCommandBuffer>			cmdBuffer							= allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
720 
721 	std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >	bottomLevelAccelerationStructures;
722 	de::MovePtr<TopLevelAccelerationStructure>						topLevelAccelerationStructure;
723 
724 	beginCommandBuffer(vkd, *cmdBuffer, 0u);
725 	{
726 		const VkImageMemoryBarrier			preImageBarrier						= makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT,
727 																					VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
728 																					**image, imageSubresourceRange);
729 		cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
730 
731 		const VkClearValue					clearValue							= m_data.testConfiguration->getClearValue();
732 		vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1, &imageSubresourceRange);
733 
734 		const VkImageMemoryBarrier			postImageBarrier					= makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR,
735 																					VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL,
736 																					**image, imageSubresourceRange);
737 		cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR, &postImageBarrier);
738 
739 		bottomLevelAccelerationStructures										= m_data.testConfiguration->initBottomAccelerationStructures(m_context, m_data);
740 		for (auto& blas : bottomLevelAccelerationStructures)
741 			blas->createAndBuild(vkd, device, *cmdBuffer, allocator);
742 		topLevelAccelerationStructure											= m_data.testConfiguration->initTopAccelerationStructure(m_context, m_data, bottomLevelAccelerationStructures);
743 		topLevelAccelerationStructure->createAndBuild(vkd, device, *cmdBuffer, allocator);
744 
745 		const TopLevelAccelerationStructure*			topLevelAccelerationStructurePtr		= topLevelAccelerationStructure.get();
746 		VkWriteDescriptorSetAccelerationStructureKHR	accelerationStructureWriteDescriptorSet	=
747 		{
748 			VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,	//  VkStructureType						sType;
749 			DE_NULL,															//  const void*							pNext;
750 			1u,																	//  deUint32							accelerationStructureCount;
751 			topLevelAccelerationStructurePtr->getPtr(),							//  const VkAccelerationStructureKHR*	pAccelerationStructures;
752 		};
753 
754 		DescriptorSetUpdateBuilder()
755 			.writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
756 			.writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
757 			.update(vkd, device);
758 
759 		vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1, &descriptorSet.get(), 0, DE_NULL);
760 
761 		vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
762 
763 		cmdTraceRays(vkd,
764 			*cmdBuffer,
765 			&raygenShaderBindingTableRegion,
766 			&missShaderBindingTableRegion,
767 			&hitShaderBindingTableRegion,
768 			&callableShaderBindingTableRegion,
769 			m_data.width, m_data.height, 1);
770 
771 		const VkMemoryBarrier							postTraceMemoryBarrier					= makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
772 		const VkMemoryBarrier							postCopyMemoryBarrier					= makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
773 		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
774 
775 		vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **resultBuffer, 1u, &resultBufferImageRegion);
776 
777 		cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT, &postCopyMemoryBarrier);
778 	}
779 	endCommandBuffer(vkd, *cmdBuffer);
780 
781 	submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
782 
783 	invalidateMappedMemoryRange(vkd, device, resultBuffer->getAllocation().getMemory(), resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
784 
785 	return resultBuffer;
786 }
787 
iterate(void)788 tcu::TestStatus CallableShaderTestInstance::iterate (void)
789 {
790 	// run test using arrays of pointers
791 	const de::MovePtr<BufferWithMemory>	buffer		= runTest();
792 
793 	if (!m_data.testConfiguration->verifyImage(buffer.get(), m_context, m_data))
794 		return tcu::TestStatus::fail("Fail");
795 	return tcu::TestStatus::pass("Pass");
796 }
797 
798 }	// anonymous
799 
createCallableShadersTests(tcu::TestContext & testCtx)800 tcu::TestCaseGroup*	createCallableShadersTests (tcu::TestContext& testCtx)
801 {
802 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "callable_shader", "Tests veryfying callable shaders"));
803 
804 	struct CallableShaderTestTypeData
805 	{
806 		CallableShaderTestType					shaderTestType;
807 		const char*								name;
808 	} callableShaderTestTypes[] =
809 	{
810 		{ CSTT_RGEN_CALL,		"rgen_call"			},
811 		{ CSTT_RGEN_CALL_CALL,	"rgen_call_call"	},
812 		{ CSTT_HIT_CALL,		"hit_call"			},
813 		{ CSTT_RGEN_MULTICALL,	"rgen_multicall"	},
814 	};
815 
816 	for (size_t shaderTestNdx = 0; shaderTestNdx < DE_LENGTH_OF_ARRAY(callableShaderTestTypes); ++shaderTestNdx)
817 	{
818 		TestParams testParams
819 		{
820 			TEST_WIDTH,
821 			TEST_HEIGHT,
822 			callableShaderTestTypes[shaderTestNdx].shaderTestType,
823 			de::SharedPtr<TestConfiguration>(new SingleSquareConfiguration())
824 		};
825 		group->addChild(new CallableShaderTestCase(group->getTestContext(), callableShaderTestTypes[shaderTestNdx].name, "", testParams));
826 	}
827 
828 	return group.release();
829 }
830 
831 }	// RayTracing
832 
833 }	// vkt
834