1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2020 The Khronos Group Inc.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 *//*!
20 * \file
21 * \brief Ray Tracing Callable Shader tests
22 *//*--------------------------------------------------------------------*/
23
24 #include "vktRayTracingCallableShadersTests.hpp"
25
26 #include "vkDefs.hpp"
27
28 #include "vktTestCase.hpp"
29 #include "vktTestGroupUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkObjUtil.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkBufferWithMemory.hpp"
35 #include "vkImageWithMemory.hpp"
36 #include "vkTypeUtil.hpp"
37 #include "vkImageUtil.hpp"
38 #include "deRandom.hpp"
39 #include "tcuTexture.hpp"
40 #include "tcuTextureUtil.hpp"
41 #include "tcuTestLog.hpp"
42 #include "tcuImageCompare.hpp"
43
44 #include "vkRayTracingUtil.hpp"
45
46 namespace vkt
47 {
48 namespace RayTracing
49 {
50 namespace
51 {
52 using namespace vk;
53 using namespace vkt;
54
55 static const VkFlags ALL_RAY_TRACING_STAGES = VK_SHADER_STAGE_RAYGEN_BIT_KHR
56 | VK_SHADER_STAGE_ANY_HIT_BIT_KHR
57 | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
58 | VK_SHADER_STAGE_MISS_BIT_KHR
59 | VK_SHADER_STAGE_INTERSECTION_BIT_KHR
60 | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
61
// Shader combinations exercised by the callable shader tests.
enum CallableShaderTestType
{
	CSTT_RGEN_CALL		= 0,	// raygen executes one callable shader and stores its result
	CSTT_RGEN_CALL_CALL	= 1,	// raygen executes a callable shader that itself executes another callable shader
	CSTT_HIT_CALL		= 2,	// raygen traces a ray; the closest hit and miss shaders execute a callable shader
	CSTT_RGEN_MULTICALL	= 3,	// raygen executes four callable shaders, each with a differently typed callable data
	CSTT_COUNT					// number of test types
};
70
71 const deUint32 TEST_WIDTH = 8;
72 const deUint32 TEST_HEIGHT = 8;
73
74 struct TestParams;
75
76 class TestConfiguration
77 {
78 public:
79 virtual std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> initBottomAccelerationStructures (Context& context,
80 TestParams& testParams) = 0;
81 virtual de::MovePtr<TopLevelAccelerationStructure> initTopAccelerationStructure (Context& context,
82 TestParams& testParams,
83 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >& bottomLevelAccelerationStructures) = 0;
84 virtual void initRayTracingShaders (de::MovePtr<RayTracingPipeline>& rayTracingPipeline,
85 Context& context,
86 TestParams& testParams) = 0;
87 virtual void initShaderBindingTables (de::MovePtr<RayTracingPipeline>& rayTracingPipeline,
88 Context& context,
89 TestParams& testParams,
90 VkPipeline pipeline,
91 deUint32 shaderGroupHandleSize,
92 deUint32 shaderGroupBaseAlignment,
93 de::MovePtr<BufferWithMemory>& raygenShaderBindingTable,
94 de::MovePtr<BufferWithMemory>& hitShaderBindingTable,
95 de::MovePtr<BufferWithMemory>& missShaderBindingTable,
96 de::MovePtr<BufferWithMemory>& callableShaderBindingTable,
97 VkStridedDeviceAddressRegionKHR& raygenShaderBindingTableRegion,
98 VkStridedDeviceAddressRegionKHR& hitShaderBindingTableRegion,
99 VkStridedDeviceAddressRegionKHR& missShaderBindingTableRegion,
100 VkStridedDeviceAddressRegionKHR& callableShaderBindingTableRegion) = 0;
101 virtual bool verifyImage (BufferWithMemory* resultBuffer,
102 Context& context,
103 TestParams& testParams) = 0;
104 virtual VkFormat getResultImageFormat () = 0;
105 virtual size_t getResultImageFormatSize () = 0;
106 virtual VkClearValue getClearValue () = 0;
107 };
108
109 struct TestParams
110 {
111 deUint32 width;
112 deUint32 height;
113 CallableShaderTestType callableShaderTestType;
114 de::SharedPtr<TestConfiguration> testConfiguration;
115
116 };
117
getShaderGroupHandleSize(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)118 deUint32 getShaderGroupHandleSize (const InstanceInterface& vki,
119 const VkPhysicalDevice physicalDevice)
120 {
121 de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
122
123 rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
124 return rayTracingPropertiesKHR->getShaderGroupHandleSize();
125 }
126
getShaderGroupBaseAlignment(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)127 deUint32 getShaderGroupBaseAlignment (const InstanceInterface& vki,
128 const VkPhysicalDevice physicalDevice)
129 {
130 de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
131
132 rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
133 return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
134 }
135
makeImageCreateInfo(deUint32 width,deUint32 height,VkFormat format)136 VkImageCreateInfo makeImageCreateInfo (deUint32 width, deUint32 height, VkFormat format)
137 {
138 const VkImageCreateInfo imageCreateInfo =
139 {
140 VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
141 DE_NULL, // const void* pNext;
142 (VkImageCreateFlags)0u, // VkImageCreateFlags flags;
143 VK_IMAGE_TYPE_2D, // VkImageType imageType;
144 format, // VkFormat format;
145 makeExtent3D(width, height, 1), // VkExtent3D extent;
146 1u, // deUint32 mipLevels;
147 1u, // deUint32 arrayLayers;
148 VK_SAMPLE_COUNT_1_BIT, // VkSampleCountFlagBits samples;
149 VK_IMAGE_TILING_OPTIMAL, // VkImageTiling tiling;
150 VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT, // VkImageUsageFlags usage;
151 VK_SHARING_MODE_EXCLUSIVE, // VkSharingMode sharingMode;
152 0u, // deUint32 queueFamilyIndexCount;
153 DE_NULL, // const deUint32* pQueueFamilyIndices;
154 VK_IMAGE_LAYOUT_UNDEFINED // VkImageLayout initialLayout;
155 };
156
157 return imageCreateInfo;
158 }
159
160 class SingleSquareConfiguration : public TestConfiguration
161 {
162 public:
163 std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> initBottomAccelerationStructures (Context& context,
164 TestParams& testParams) override;
165 de::MovePtr<TopLevelAccelerationStructure> initTopAccelerationStructure (Context& context,
166 TestParams& testParams,
167 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >& bottomLevelAccelerationStructures) override;
168 void initRayTracingShaders (de::MovePtr<RayTracingPipeline>& rayTracingPipeline,
169 Context& context,
170 TestParams& testParams) override;
171 void initShaderBindingTables (de::MovePtr<RayTracingPipeline>& rayTracingPipeline,
172 Context& context,
173 TestParams& testParams,
174 VkPipeline pipeline,
175 deUint32 shaderGroupHandleSize,
176 deUint32 shaderGroupBaseAlignment,
177 de::MovePtr<BufferWithMemory>& raygenShaderBindingTable,
178 de::MovePtr<BufferWithMemory>& hitShaderBindingTable,
179 de::MovePtr<BufferWithMemory>& missShaderBindingTable,
180 de::MovePtr<BufferWithMemory>& callableShaderBindingTable,
181 VkStridedDeviceAddressRegionKHR& raygenShaderBindingTableRegion,
182 VkStridedDeviceAddressRegionKHR& hitShaderBindingTableRegion,
183 VkStridedDeviceAddressRegionKHR& missShaderBindingTableRegion,
184 VkStridedDeviceAddressRegionKHR& callableShaderBindingTableRegion) override;
185 bool verifyImage (BufferWithMemory* resultBuffer,
186 Context& context,
187 TestParams& testParams) override;
188 VkFormat getResultImageFormat () override;
189 size_t getResultImageFormatSize () override;
190 VkClearValue getClearValue () override;
191 };
192
initBottomAccelerationStructures(Context & context,TestParams & testParams)193 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> > SingleSquareConfiguration::initBottomAccelerationStructures (Context& context,
194 TestParams& testParams)
195 {
196 DE_UNREF(context);
197
198 tcu::Vec3 v0(1.0, float(testParams.height) - 1.0f, 0.0);
199 tcu::Vec3 v1(1.0, 1.0, 0.0);
200 tcu::Vec3 v2(float(testParams.width) - 1.0f, float(testParams.height) - 1.0f, 0.0);
201 tcu::Vec3 v3(float(testParams.width) - 1.0f, 1.0, 0.0);
202
203 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> > result;
204 de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure = makeBottomLevelAccelerationStructure();
205 bottomLevelAccelerationStructure->setGeometryCount(1);
206
207 de::SharedPtr<RaytracedGeometryBase> geometry = makeRaytracedGeometry(VK_GEOMETRY_TYPE_TRIANGLES_KHR, VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
208 geometry->addVertex(v0);
209 geometry->addVertex(v1);
210 geometry->addVertex(v2);
211 geometry->addVertex(v2);
212 geometry->addVertex(v1);
213 geometry->addVertex(v3);
214 bottomLevelAccelerationStructure->addGeometry(geometry);
215
216 result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
217
218 return result;
219 }
220
initTopAccelerationStructure(Context & context,TestParams & testParams,std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelAccelerationStructures)221 de::MovePtr<TopLevelAccelerationStructure> SingleSquareConfiguration::initTopAccelerationStructure (Context& context,
222 TestParams& testParams,
223 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >& bottomLevelAccelerationStructures)
224 {
225 DE_UNREF(context);
226 DE_UNREF(testParams);
227
228 de::MovePtr<TopLevelAccelerationStructure> result = makeTopLevelAccelerationStructure();
229 result->setInstanceCount(1);
230 result->addInstance(bottomLevelAccelerationStructures[0]);
231
232 return result;
233 }
234
initRayTracingShaders(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,Context & context,TestParams & testParams)235 void SingleSquareConfiguration::initRayTracingShaders (de::MovePtr<RayTracingPipeline>& rayTracingPipeline,
236 Context& context,
237 TestParams& testParams)
238 {
239 const DeviceInterface& vkd = context.getDeviceInterface();
240 const VkDevice device = context.getDevice();
241
242 switch (testParams.callableShaderTestType)
243 {
244 case CSTT_RGEN_CALL:
245 {
246 rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_call"), 0), 0);
247 rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
248 rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
249 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
250 break;
251 }
252 case CSTT_RGEN_CALL_CALL:
253 {
254 rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_call"), 0), 0);
255 rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
256 rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
257 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("call_call"), 0), 3);
258 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 4);
259 break;
260 }
261 case CSTT_HIT_CALL:
262 {
263 rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("rgen"), 0), 0);
264 rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("chit_call"), 0), 1);
265 rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("miss_call"), 0), 2);
266 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
267 break;
268 }
269 case CSTT_RGEN_MULTICALL:
270 {
271 rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("rgen_multicall"), 0), 0);
272 rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("chit"), 0), 1);
273 rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("miss"), 0), 2);
274 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("call_0"), 0), 3);
275 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("call_1"), 0), 4);
276 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("call_2"), 0), 5);
277 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, context.getBinaryCollection().get("call_3"), 0), 6);
278 break;
279 }
280 default:
281 TCU_THROW(InternalError, "Wrong shader test type");
282 }
283 }
284
initShaderBindingTables(de::MovePtr<RayTracingPipeline> & rayTracingPipeline,Context & context,TestParams & testParams,VkPipeline pipeline,deUint32 shaderGroupHandleSize,deUint32 shaderGroupBaseAlignment,de::MovePtr<BufferWithMemory> & raygenShaderBindingTable,de::MovePtr<BufferWithMemory> & hitShaderBindingTable,de::MovePtr<BufferWithMemory> & missShaderBindingTable,de::MovePtr<BufferWithMemory> & callableShaderBindingTable,VkStridedDeviceAddressRegionKHR & raygenShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & hitShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & missShaderBindingTableRegion,VkStridedDeviceAddressRegionKHR & callableShaderBindingTableRegion)285 void SingleSquareConfiguration::initShaderBindingTables (de::MovePtr<RayTracingPipeline>& rayTracingPipeline,
286 Context& context,
287 TestParams& testParams,
288 VkPipeline pipeline,
289 deUint32 shaderGroupHandleSize,
290 deUint32 shaderGroupBaseAlignment,
291 de::MovePtr<BufferWithMemory>& raygenShaderBindingTable,
292 de::MovePtr<BufferWithMemory>& hitShaderBindingTable,
293 de::MovePtr<BufferWithMemory>& missShaderBindingTable,
294 de::MovePtr<BufferWithMemory>& callableShaderBindingTable,
295 VkStridedDeviceAddressRegionKHR& raygenShaderBindingTableRegion,
296 VkStridedDeviceAddressRegionKHR& hitShaderBindingTableRegion,
297 VkStridedDeviceAddressRegionKHR& missShaderBindingTableRegion,
298 VkStridedDeviceAddressRegionKHR& callableShaderBindingTableRegion)
299 {
300 const DeviceInterface& vkd = context.getDeviceInterface();
301 const VkDevice device = context.getDevice();
302 Allocator& allocator = context.getDefaultAllocator();
303
304 switch (testParams.callableShaderTestType)
305 {
306 case CSTT_RGEN_CALL:
307 {
308 raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
309 hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
310 missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
311 callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
312
313 raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
314 hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
315 missShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
316 callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
317 break;
318 }
319 case CSTT_RGEN_CALL_CALL:
320 {
321 raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
322 hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
323 missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
324 callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 2);
325
326 raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
327 hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
328 missShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
329 callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, 2*shaderGroupHandleSize);
330 break;
331 }
332 case CSTT_HIT_CALL:
333 {
334 raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
335 hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
336 missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
337 callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 1);
338
339 raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
340 hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
341 missShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
342 callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
343 break;
344 }
345 case CSTT_RGEN_MULTICALL:
346 {
347 raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
348 hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
349 missShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
350 callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, 3, 4);
351
352 raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
353 hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
354 missShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, missShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
355 callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, 4*shaderGroupHandleSize);
356 break;
357 }
358 default:
359 TCU_THROW(InternalError, "Wrong shader test type");
360 }
361 }
362
verifyImage(BufferWithMemory * resultBuffer,Context & context,TestParams & testParams)363 bool SingleSquareConfiguration::verifyImage (BufferWithMemory* resultBuffer, Context& context, TestParams& testParams)
364 {
365 // create result image
366 tcu::TextureFormat imageFormat = vk::mapVkFormat(getResultImageFormat());
367 tcu::ConstPixelBufferAccess resultAccess(imageFormat, testParams.width, testParams.height, 1, resultBuffer->getAllocation().getHostPtr());
368
369 // create reference image
370 std::vector<deUint32> reference(testParams.width * testParams.height);
371 tcu::PixelBufferAccess referenceAccess(imageFormat, testParams.width, testParams.height, 1, reference.data());
372
373 tcu::UVec4 missValue, hitValue;
374
375 // clear reference image with hit and miss values ( hit works only for tests calling traceRayEXT in rgen shader )
376 switch (testParams.callableShaderTestType)
377 {
378 case CSTT_RGEN_CALL:
379 missValue = tcu::UVec4(1, 0, 0, 0);
380 hitValue = tcu::UVec4(1, 0, 0, 0);
381 break;
382 case CSTT_RGEN_CALL_CALL:
383 missValue = tcu::UVec4(1, 0, 0, 0);
384 hitValue = tcu::UVec4(1, 0, 0, 0);
385 break;
386 case CSTT_HIT_CALL:
387 missValue = tcu::UVec4(1, 0, 0, 0);
388 hitValue = tcu::UVec4(2, 0, 0, 0);
389 break;
390 case CSTT_RGEN_MULTICALL:
391 missValue = tcu::UVec4(16, 0, 0, 0);
392 hitValue = tcu::UVec4(16, 0, 0, 0);
393 break;
394 default:
395 TCU_THROW(InternalError, "Wrong shader test type");
396 }
397
398 tcu::clear(referenceAccess, missValue);
399 for (deUint32 y = 1; y < testParams.width - 1; ++y)
400 for (deUint32 x = 1; x < testParams.height - 1; ++x)
401 referenceAccess.setPixel(hitValue, x, y);
402
403 // compare result and reference
404 return tcu::intThresholdCompare(context.getTestContext().getLog(), "Result comparison", "", referenceAccess, resultAccess, tcu::UVec4(0), tcu::COMPARE_LOG_RESULT);
405 }
406
getResultImageFormat()407 VkFormat SingleSquareConfiguration::getResultImageFormat ()
408 {
409 return VK_FORMAT_R32_UINT;
410 }
411
getResultImageFormatSize()412 size_t SingleSquareConfiguration::getResultImageFormatSize ()
413 {
414 return sizeof(deUint32);
415 }
416
getClearValue()417 VkClearValue SingleSquareConfiguration::getClearValue ()
418 {
419 return makeClearValueColorU32(0xFF, 0u, 0u, 0u);
420 }
421
422 class CallableShaderTestCase : public TestCase
423 {
424 public:
425 CallableShaderTestCase (tcu::TestContext& context, const char* name, const char* desc, const TestParams data);
426 ~CallableShaderTestCase (void);
427
428 virtual void checkSupport (Context& context) const;
429 virtual void initPrograms (SourceCollections& programCollection) const;
430 virtual TestInstance* createInstance (Context& context) const;
431 private:
432 TestParams m_data;
433 };
434
435 class CallableShaderTestInstance : public TestInstance
436 {
437 public:
438 CallableShaderTestInstance (Context& context, const TestParams& data);
439 ~CallableShaderTestInstance (void);
440 tcu::TestStatus iterate (void);
441
442 protected:
443 de::MovePtr<BufferWithMemory> runTest ();
444 private:
445 TestParams m_data;
446 };
447
CallableShaderTestCase(tcu::TestContext & context,const char * name,const char * desc,const TestParams data)448 CallableShaderTestCase::CallableShaderTestCase (tcu::TestContext& context, const char* name, const char* desc, const TestParams data)
449 : vkt::TestCase (context, name, desc)
450 , m_data (data)
451 {
452 }
453
~CallableShaderTestCase(void)454 CallableShaderTestCase::~CallableShaderTestCase (void)
455 {
456 }
457
checkSupport(Context & context) const458 void CallableShaderTestCase::checkSupport (Context& context) const
459 {
460 context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
461 context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
462
463 const VkPhysicalDeviceRayTracingPipelineFeaturesKHR& rayTracingPipelineFeaturesKHR = context.getRayTracingPipelineFeatures();
464 if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE )
465 TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
466
467 const VkPhysicalDeviceAccelerationStructureFeaturesKHR& accelerationStructureFeaturesKHR = context.getAccelerationStructureFeatures();
468 if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
469 TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
470 }
471
initPrograms(SourceCollections & programCollection) const472 void CallableShaderTestCase::initPrograms (SourceCollections& programCollection) const
473 {
474 const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
475 {
476 std::stringstream css;
477 css <<
478 "#version 460 core\n"
479 "#extension GL_EXT_ray_tracing : require\n"
480 "layout(location = 0) rayPayloadEXT uvec4 hitValue;\n"
481 "layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
482 "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
483 "\n"
484 "void main()\n"
485 "{\n"
486 " float tmin = 0.0;\n"
487 " float tmax = 1.0;\n"
488 " vec3 origin = vec3(float(gl_LaunchIDEXT.x) + 0.5f, float(gl_LaunchIDEXT.y) + 0.5f, 0.5f);\n"
489 " vec3 direct = vec3(0.0, 0.0, -1.0);\n"
490 " hitValue = uvec4(0,0,0,0);\n"
491 " traceRayEXT(topLevelAS, 0, 0xFF, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
492 " imageStore(result, ivec2(gl_LaunchIDEXT.xy), hitValue);\n"
493 "}\n";
494 programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
495 }
496
497 {
498 std::stringstream css;
499 css <<
500 "#version 460 core\n"
501 "#extension GL_EXT_ray_tracing : require\n"
502 "layout(location = 0) callableDataEXT uvec4 value;\n"
503 "layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
504 "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
505 "\n"
506 "void main()\n"
507 "{\n"
508 " executeCallableEXT(0, 0);\n"
509 " imageStore(result, ivec2(gl_LaunchIDEXT.xy), value);\n"
510 "}\n";
511 programCollection.glslSources.add("rgen_call") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
512 }
513
514 {
515 std::stringstream css;
516 css <<
517 "#version 460 core\n"
518 "#extension GL_EXT_ray_tracing : require\n"
519 "struct CallValue\n"
520 "{\n"
521 " ivec4 a;\n"
522 " vec4 b;\n"
523 "};\n"
524 "layout(location = 0) callableDataEXT uvec4 value0;\n"
525 "layout(location = 1) callableDataEXT uint value1;\n"
526 "layout(location = 2) callableDataEXT CallValue value2;\n"
527 "layout(location = 4) callableDataEXT vec3 value3;\n"
528 "layout(r32ui, set = 0, binding = 0) uniform uimage2D result;\n"
529 "layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;\n"
530 "\n"
531 "void main()\n"
532 "{\n"
533 " executeCallableEXT(0, 0);\n"
534 " executeCallableEXT(1, 1);\n"
535 " executeCallableEXT(2, 2);\n"
536 " executeCallableEXT(3, 4);\n"
537 " uint resultValue = value0.x + value1 + value2.a.x * uint(floor(value2.b.y)) + uint(floor(value3.z));\n"
538 " imageStore(result, ivec2(gl_LaunchIDEXT.xy), uvec4(resultValue, 0, 0, 0));\n"
539 "}\n";
540 programCollection.glslSources.add("rgen_multicall") << glu::RaygenSource(updateRayTracingGLSL(css.str())) << buildOptions;
541 }
542
543 {
544 std::stringstream css;
545 css <<
546 "#version 460 core\n"
547 "#extension GL_EXT_ray_tracing : require\n"
548 "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
549 "void main()\n"
550 "{\n"
551 " hitValue = uvec4(1,0,0,1);\n"
552 "}\n";
553
554 programCollection.glslSources.add("chit") << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
555 }
556
557 {
558 std::stringstream css;
559 css <<
560 "#version 460 core\n"
561 "#extension GL_EXT_ray_tracing : require\n"
562 "layout(location = 0) callableDataEXT uvec4 value;\n"
563 "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
564 "void main()\n"
565 "{\n"
566 " executeCallableEXT(0, 0);\n"
567 " hitValue = value;\n"
568 " hitValue.x = hitValue.x + 1;\n"
569 "}\n";
570
571 programCollection.glslSources.add("chit_call") << glu::ClosestHitSource(updateRayTracingGLSL(css.str())) << buildOptions;
572 }
573
574 {
575 std::stringstream css;
576 css <<
577 "#version 460 core\n"
578 "#extension GL_EXT_ray_tracing : require\n"
579 "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
580 "void main()\n"
581 "{\n"
582 " hitValue = uvec4(0,0,0,1);\n"
583 "}\n";
584
585 programCollection.glslSources.add("miss") << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
586 }
587
588 {
589 std::stringstream css;
590 css <<
591 "#version 460 core\n"
592 "#extension GL_EXT_ray_tracing : require\n"
593 "layout(location = 0) callableDataEXT uvec4 value;\n"
594 "layout(location = 0) rayPayloadInEXT uvec4 hitValue;\n"
595 "void main()\n"
596 "{\n"
597 " executeCallableEXT(0, 0);\n"
598 " hitValue = value;\n"
599 "}\n";
600
601 programCollection.glslSources.add("miss_call") << glu::MissSource(updateRayTracingGLSL(css.str())) << buildOptions;
602 }
603
604 std::vector<std::string> callableDataDefinition =
605 {
606 "layout(location = 0) callableDataInEXT uvec4 result;\n",
607 "layout(location = 1) callableDataInEXT uint result;\n",
608 "struct CallValue\n{\n ivec4 a;\n vec4 b;\n};\nlayout(location = 2) callableDataInEXT CallValue result;\n",
609 "layout(location = 4) callableDataInEXT vec3 result;\n"
610 };
611
612 std::vector<std::string> callableDataComputation =
613 {
614 " result = uvec4(1,0,0,1);\n",
615 " result = 2;\n",
616 " result.a = ivec4(3,0,0,1);\n result.b = vec4(1.0, 3.2, 0.0, 1);\n",
617 " result = vec3(0.0, 0.0, 4.3);\n",
618 };
619
620 for (deUint32 idx = 0; idx < callableDataDefinition.size(); ++idx)
621 {
622 std::stringstream css;
623 css <<
624 "#version 460 core\n"
625 "#extension GL_EXT_ray_tracing : require\n"
626 << callableDataDefinition[idx] <<
627 "void main()\n"
628 "{\n"
629 << callableDataComputation[idx] <<
630 "}\n";
631 std::stringstream csname;
632 csname << "call_" << idx;
633
634 programCollection.glslSources.add(csname.str()) << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
635 }
636
637 {
638 std::stringstream css;
639 css <<
640 "#version 460 core\n"
641 "#extension GL_EXT_ray_tracing : require\n"
642 "layout(location = 0) callableDataInEXT uvec4 result;\n"
643 "layout(location = 1) callableDataEXT uvec4 info;\n"
644 "void main()\n"
645 "{\n"
646 " executeCallableEXT(1, 1);\n"
647 " result = info;\n"
648 "}\n";
649
650 programCollection.glslSources.add("call_call") << glu::CallableSource(updateRayTracingGLSL(css.str())) << buildOptions;
651 }
652 }
653
createInstance(Context & context) const654 TestInstance* CallableShaderTestCase::createInstance (Context& context) const
655 {
656 return new CallableShaderTestInstance(context, m_data);
657 }
658
CallableShaderTestInstance(Context & context,const TestParams & data)659 CallableShaderTestInstance::CallableShaderTestInstance (Context& context, const TestParams& data)
660 : vkt::TestInstance (context)
661 , m_data (data)
662 {
663 }
664
~CallableShaderTestInstance(void)665 CallableShaderTestInstance::~CallableShaderTestInstance (void)
666 {
667 }
668
runTest()669 de::MovePtr<BufferWithMemory> CallableShaderTestInstance::runTest ()
670 {
671 const InstanceInterface& vki = m_context.getInstanceInterface();
672 const DeviceInterface& vkd = m_context.getDeviceInterface();
673 const VkDevice device = m_context.getDevice();
674 const VkPhysicalDevice physicalDevice = m_context.getPhysicalDevice();
675 const deUint32 queueFamilyIndex = m_context.getUniversalQueueFamilyIndex();
676 const VkQueue queue = m_context.getUniversalQueue();
677 Allocator& allocator = m_context.getDefaultAllocator();
678 const deUint32 pixelCount = m_data.width * m_data.height * 1;
679
680 const Move<VkDescriptorSetLayout> descriptorSetLayout = DescriptorSetLayoutBuilder()
681 .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
682 .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
683 .build(vkd, device);
684 const Move<VkDescriptorPool> descriptorPool = DescriptorPoolBuilder()
685 .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
686 .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
687 .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
688 const Move<VkDescriptorSet> descriptorSet = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
689 const Move<VkPipelineLayout> pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
690
691 de::MovePtr<RayTracingPipeline> rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
692 m_data.testConfiguration->initRayTracingShaders(rayTracingPipeline, m_context, m_data);
693 Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vkd, device, *pipelineLayout);
694
695 de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
696 de::MovePtr<BufferWithMemory> hitShaderBindingTable;
697 de::MovePtr<BufferWithMemory> missShaderBindingTable;
698 de::MovePtr<BufferWithMemory> callableShaderBindingTable;
699 VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion;
700 VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion;
701 VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion;
702 VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion;
703 m_data.testConfiguration->initShaderBindingTables(rayTracingPipeline, m_context, m_data, *pipeline, getShaderGroupHandleSize(vki, physicalDevice), getShaderGroupBaseAlignment(vki, physicalDevice), raygenShaderBindingTable, hitShaderBindingTable, missShaderBindingTable, callableShaderBindingTable, raygenShaderBindingTableRegion, hitShaderBindingTableRegion, missShaderBindingTableRegion, callableShaderBindingTableRegion);
704
705 const VkFormat imageFormat = m_data.testConfiguration->getResultImageFormat();
706 const VkImageCreateInfo imageCreateInfo = makeImageCreateInfo(m_data.width, m_data.height, imageFormat);
707 const VkImageSubresourceRange imageSubresourceRange = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
708 const de::MovePtr<ImageWithMemory> image = de::MovePtr<ImageWithMemory>(new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
709 const Move<VkImageView> imageView = makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_2D, imageFormat, imageSubresourceRange);
710
711 const VkBufferCreateInfo resultBufferCreateInfo = makeBufferCreateInfo(pixelCount*m_data.testConfiguration->getResultImageFormatSize(), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
712 const VkImageSubresourceLayers resultBufferImageSubresourceLayers = makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
713 const VkBufferImageCopy resultBufferImageRegion = makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, 1), resultBufferImageSubresourceLayers);
714 de::MovePtr<BufferWithMemory> resultBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(vkd, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));
715
716 const VkDescriptorImageInfo descriptorImageInfo = makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
717
718 const Move<VkCommandPool> cmdPool = createCommandPool(vkd, device, 0, queueFamilyIndex);
719 const Move<VkCommandBuffer> cmdBuffer = allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
720
721 std::vector<de::SharedPtr<BottomLevelAccelerationStructure> > bottomLevelAccelerationStructures;
722 de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
723
724 beginCommandBuffer(vkd, *cmdBuffer, 0u);
725 {
726 const VkImageMemoryBarrier preImageBarrier = makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT,
727 VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
728 **image, imageSubresourceRange);
729 cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
730
731 const VkClearValue clearValue = m_data.testConfiguration->getClearValue();
732 vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1, &imageSubresourceRange);
733
734 const VkImageMemoryBarrier postImageBarrier = makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR,
735 VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL,
736 **image, imageSubresourceRange);
737 cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR, &postImageBarrier);
738
739 bottomLevelAccelerationStructures = m_data.testConfiguration->initBottomAccelerationStructures(m_context, m_data);
740 for (auto& blas : bottomLevelAccelerationStructures)
741 blas->createAndBuild(vkd, device, *cmdBuffer, allocator);
742 topLevelAccelerationStructure = m_data.testConfiguration->initTopAccelerationStructure(m_context, m_data, bottomLevelAccelerationStructures);
743 topLevelAccelerationStructure->createAndBuild(vkd, device, *cmdBuffer, allocator);
744
745 const TopLevelAccelerationStructure* topLevelAccelerationStructurePtr = topLevelAccelerationStructure.get();
746 VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet =
747 {
748 VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, // VkStructureType sType;
749 DE_NULL, // const void* pNext;
750 1u, // deUint32 accelerationStructureCount;
751 topLevelAccelerationStructurePtr->getPtr(), // const VkAccelerationStructureKHR* pAccelerationStructures;
752 };
753
754 DescriptorSetUpdateBuilder()
755 .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
756 .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
757 .update(vkd, device);
758
759 vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1, &descriptorSet.get(), 0, DE_NULL);
760
761 vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
762
763 cmdTraceRays(vkd,
764 *cmdBuffer,
765 &raygenShaderBindingTableRegion,
766 &missShaderBindingTableRegion,
767 &hitShaderBindingTableRegion,
768 &callableShaderBindingTableRegion,
769 m_data.width, m_data.height, 1);
770
771 const VkMemoryBarrier postTraceMemoryBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
772 const VkMemoryBarrier postCopyMemoryBarrier = makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
773 cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
774
775 vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **resultBuffer, 1u, &resultBufferImageRegion);
776
777 cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT, &postCopyMemoryBarrier);
778 }
779 endCommandBuffer(vkd, *cmdBuffer);
780
781 submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
782
783 invalidateMappedMemoryRange(vkd, device, resultBuffer->getAllocation().getMemory(), resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
784
785 return resultBuffer;
786 }
787
iterate(void)788 tcu::TestStatus CallableShaderTestInstance::iterate (void)
789 {
790 // run test using arrays of pointers
791 const de::MovePtr<BufferWithMemory> buffer = runTest();
792
793 if (!m_data.testConfiguration->verifyImage(buffer.get(), m_context, m_data))
794 return tcu::TestStatus::fail("Fail");
795 return tcu::TestStatus::pass("Pass");
796 }
797
798 } // anonymous
799
createCallableShadersTests(tcu::TestContext & testCtx)800 tcu::TestCaseGroup* createCallableShadersTests (tcu::TestContext& testCtx)
801 {
802 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "callable_shader", "Tests veryfying callable shaders"));
803
804 struct CallableShaderTestTypeData
805 {
806 CallableShaderTestType shaderTestType;
807 const char* name;
808 } callableShaderTestTypes[] =
809 {
810 { CSTT_RGEN_CALL, "rgen_call" },
811 { CSTT_RGEN_CALL_CALL, "rgen_call_call" },
812 { CSTT_HIT_CALL, "hit_call" },
813 { CSTT_RGEN_MULTICALL, "rgen_multicall" },
814 };
815
816 for (size_t shaderTestNdx = 0; shaderTestNdx < DE_LENGTH_OF_ARRAY(callableShaderTestTypes); ++shaderTestNdx)
817 {
818 TestParams testParams
819 {
820 TEST_WIDTH,
821 TEST_HEIGHT,
822 callableShaderTestTypes[shaderTestNdx].shaderTestType,
823 de::SharedPtr<TestConfiguration>(new SingleSquareConfiguration())
824 };
825 group->addChild(new CallableShaderTestCase(group->getTestContext(), callableShaderTestTypes[shaderTestNdx].name, "", testParams));
826 }
827
828 return group.release();
829 }
830
831 } // RayTracing
832
833 } // vkt
834