1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2019 The Khronos Group Inc.
6 * Copyright (c) 2019 Google Inc.
7 * Copyright (c) 2017 Codeplay Software Ltd.
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 * http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 */ /*!
22 * \file
23 * \brief Subgroups Tests
24 */ /*--------------------------------------------------------------------*/
25
26 #include "vktSubgroupsQuadTests.hpp"
27 #include "vktSubgroupsTestsUtils.hpp"
28
29 #include <string>
30 #include <vector>
31
32 using namespace tcu;
33 using namespace std;
34 using namespace vk;
35 using namespace vkt;
36
37 namespace
38 {
//! Quad operation exercised by a test case. Values are contiguous from 0 so they can
//! index per-op tables; OPTYPE_LAST is the count and is not a valid operation.
enum OpType
{
	OPTYPE_QUAD_BROADCAST			= 0,	//!< subgroupQuadBroadcast with a constant lane index
	OPTYPE_QUAD_BROADCAST_NONCONST	= 1,	//!< subgroupQuadBroadcast with a non-constant lane index
	OPTYPE_QUAD_SWAP_HORIZONTAL		= 2,	//!< subgroupQuadSwapHorizontal
	OPTYPE_QUAD_SWAP_VERTICAL		= 3,	//!< subgroupQuadSwapVertical
	OPTYPE_QUAD_SWAP_DIAGONAL		= 4,	//!< subgroupQuadSwapDiagonal
	OPTYPE_LAST						= 5		//!< Number of op types
};
48
//! Parameters of one generated test case.
//! NOTE: member order matters - instances are created with positional aggregate initialization.
struct CaseDefinition
{
	OpType					opType;						//!< Quad operation under test.
	VkShaderStageFlags		shaderStage;				//!< Single stage or stage set the case runs in.
	VkFormat				format;						//!< Format of the input data operated on.
	de::SharedPtr<bool>		geometryPointSizeSupported;	//!< Written by supportedCheck(); read when the programs are built.
	deBool					requiredSubgroupSize;		//!< Compute only: if true, the case is re-run for each required subgroup size.
};
57
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,deUint32 width,deUint32)58 static bool checkVertexPipelineStages (const void* internalData,
59 vector<const void*> datas,
60 deUint32 width,
61 deUint32)
62 {
63 DE_UNREF(internalData);
64
65 return subgroups::check(datas, width, 1);
66 }
67
checkCompute(const void * internalData,vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)68 static bool checkCompute (const void* internalData,
69 vector<const void*> datas,
70 const deUint32 numWorkgroups[3],
71 const deUint32 localSize[3],
72 deUint32)
73 {
74 DE_UNREF(internalData);
75
76 return subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
77 }
78
getOpTypeName(OpType opType)79 string getOpTypeName (OpType opType)
80 {
81 switch (opType)
82 {
83 case OPTYPE_QUAD_BROADCAST: return "subgroupQuadBroadcast";
84 case OPTYPE_QUAD_BROADCAST_NONCONST: return "subgroupQuadBroadcast";
85 case OPTYPE_QUAD_SWAP_HORIZONTAL: return "subgroupQuadSwapHorizontal";
86 case OPTYPE_QUAD_SWAP_VERTICAL: return "subgroupQuadSwapVertical";
87 case OPTYPE_QUAD_SWAP_DIAGONAL: return "subgroupQuadSwapDiagonal";
88 default: TCU_THROW(InternalError, "Unsupported op type");
89 }
90 }
91
getOpTypeCaseName(OpType opType)92 string getOpTypeCaseName (OpType opType)
93 {
94 switch (opType)
95 {
96 case OPTYPE_QUAD_BROADCAST: return "subgroupquadbroadcast";
97 case OPTYPE_QUAD_BROADCAST_NONCONST: return "subgroupquadbroadcast_nonconst";
98 case OPTYPE_QUAD_SWAP_HORIZONTAL: return "subgroupquadswaphorizontal";
99 case OPTYPE_QUAD_SWAP_VERTICAL: return "subgroupquadswapvertical";
100 case OPTYPE_QUAD_SWAP_DIAGONAL: return "subgroupquadswapdiagonal";
101 default: TCU_THROW(InternalError, "Unsupported op type");
102 }
103 }
104
getExtHeader(VkFormat format)105 string getExtHeader (VkFormat format)
106 {
107 return "#extension GL_KHR_shader_subgroup_quad: enable\n"
108 "#extension GL_KHR_shader_subgroup_ballot: enable\n" +
109 subgroups::getAdditionalExtensionForFormat(format);
110 }
111
getTestSrc(const CaseDefinition & caseDef)112 string getTestSrc (const CaseDefinition &caseDef)
113 {
114 const string swapTable[OPTYPE_LAST] =
115 {
116 "",
117 "",
118 " const uint swapTable[4] = {1, 0, 3, 2};\n",
119 " const uint swapTable[4] = {2, 3, 0, 1};\n",
120 " const uint swapTable[4] = {3, 2, 1, 0};\n",
121 };
122 const string validate =
123 " if (subgroupBallotBitExtract(mask, otherID) && op !=data[otherID])\n"
124 " tempRes = 0;\n";
125 const string fmt = subgroups::getFormatNameForGLSL(caseDef.format);
126 const string op = getOpTypeName(caseDef.opType);
127 ostringstream testSrc;
128
129 testSrc << " uvec4 mask = subgroupBallot(true);\n"
130 << swapTable[caseDef.opType]
131 << " tempRes = 1;\n";
132
133 if (caseDef.opType == OPTYPE_QUAD_BROADCAST)
134 {
135 for (int i=0; i<4; i++)
136 {
137 testSrc << " {\n"
138 << " " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID], " << i << ");\n"
139 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << i << ";\n"
140 << validate
141 << " }\n";
142 }
143 }
144 else if (caseDef.opType == OPTYPE_QUAD_BROADCAST_NONCONST)
145 {
146 testSrc << " for (int i=0; i<4; i++)"
147 << " {\n"
148 << " " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID], i);\n"
149 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + i;\n"
150 << validate
151 << " }\n"
152 << " uint quadID = gl_SubgroupInvocationID >> 2;\n"
153 << " uint quadInvocation = gl_SubgroupInvocationID & 0x3;\n"
154 << " // Test lane ID that is only uniform in active lanes\n"
155 << " if (quadInvocation >= 2)\n"
156 << " {\n"
157 << " uint id = quadInvocation & ~1;\n"
158 << " " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID], id);\n"
159 << " uint otherID = 4*quadID + id;\n"
160 << validate
161 << " }\n"
162 << " // Test lane ID that is only quad uniform, not subgroup uniform\n"
163 << " {\n"
164 << " uint id = quadID & 0x3;\n"
165 << " " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID], id);\n"
166 << " uint otherID = 4*quadID + id;\n"
167 << validate
168 << " }\n";
169 }
170 else
171 {
172 testSrc << " " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID]);\n"
173 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n"
174 << validate;
175 }
176
177 return testSrc.str();
178 }
179
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)180 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
181 {
182 const SpirvVersion spirvVersion = (caseDef.opType == OPTYPE_QUAD_BROADCAST_NONCONST) ? SPIRV_VERSION_1_5 : SPIRV_VERSION_1_3;
183 const ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, spirvVersion, 0u);
184
185 subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, *caseDef.geometryPointSizeSupported, getExtHeader(caseDef.format), getTestSrc(caseDef), "");
186 }
187
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)188 void initPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
189 {
190 const bool spirv15required = caseDef.opType == OPTYPE_QUAD_BROADCAST_NONCONST;
191 const bool spirv14required = isAllRayTracingStages(caseDef.shaderStage);
192 const SpirvVersion spirvVersion = spirv15required ? SPIRV_VERSION_1_5
193 : spirv14required ? SPIRV_VERSION_1_4
194 : SPIRV_VERSION_1_3;
195 const ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, spirvVersion, 0u);
196 const string extHeader = getExtHeader(caseDef.format);
197 const string testSrc = getTestSrc(caseDef);
198
199 subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, *caseDef.geometryPointSizeSupported, extHeader, testSrc, "");
200 }
201
supportedCheck(Context & context,CaseDefinition caseDef)202 void supportedCheck (Context& context, CaseDefinition caseDef)
203 {
204 if (!subgroups::isSubgroupSupported(context))
205 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
206
207 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_QUAD_BIT))
208 TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
209
210 if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
211 TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
212
213 if ((caseDef.opType == OPTYPE_QUAD_BROADCAST_NONCONST) && !subgroups::isSubgroupBroadcastDynamicIdSupported(context))
214 TCU_THROW(NotSupportedError, "Device does not support SubgroupBroadcastDynamicId");
215
216 if (caseDef.requiredSubgroupSize)
217 {
218 context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
219
220 const VkPhysicalDeviceSubgroupSizeControlFeaturesEXT& subgroupSizeControlFeatures = context.getSubgroupSizeControlFeaturesEXT();
221 const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT& subgroupSizeControlProperties = context.getSubgroupSizeControlPropertiesEXT();
222
223 if (subgroupSizeControlFeatures.subgroupSizeControl == DE_FALSE)
224 TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
225
226 if (subgroupSizeControlFeatures.computeFullSubgroups == DE_FALSE)
227 TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
228
229 if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
230 TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
231 }
232
233 *caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
234
235 if (isAllRayTracingStages(caseDef.shaderStage))
236 {
237 context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
238 }
239
240 subgroups::supportedCheckShader(context, caseDef.shaderStage);
241 }
242
noSSBOtest(Context & context,const CaseDefinition caseDef)243 TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
244 {
245 subgroups::SSBOData inputData;
246 inputData.format = caseDef.format;
247 inputData.layout = subgroups::SSBOData::LayoutStd140;
248 inputData.numElements = subgroups::maxSupportedSubgroupSize();
249 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
250
251 switch (caseDef.shaderStage)
252 {
253 case VK_SHADER_STAGE_VERTEX_BIT: return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
254 case VK_SHADER_STAGE_GEOMETRY_BIT: return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
255 case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT: return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
256 case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT: return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
257 default: TCU_THROW(InternalError, "Unhandled shader stage");
258 }
259 }
260
test(Context & context,const CaseDefinition caseDef)261 TestStatus test (Context& context, const CaseDefinition caseDef)
262 {
263
264 if (isAllComputeStages(caseDef.shaderStage))
265 {
266 const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT& subgroupSizeControlProperties = context.getSubgroupSizeControlPropertiesEXT();
267 TestLog& log = context.getTestContext().getLog();
268 const subgroups::SSBOData inputData
269 {
270 subgroups::SSBOData::InitializeNonZero, // InputDataInitializeType initializeType;
271 subgroups::SSBOData::LayoutStd430, // InputDataLayoutType layout;
272 caseDef.format, // vk::VkFormat format;
273 subgroups::maxSupportedSubgroupSize(), // vk::VkDeviceSize numElements;
274 };
275
276 if (caseDef.requiredSubgroupSize == DE_FALSE)
277 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkCompute);
278
279 log << TestLog::Message << "Testing required subgroup size range [" << subgroupSizeControlProperties.minSubgroupSize << ", "
280 << subgroupSizeControlProperties.maxSubgroupSize << "]" << TestLog::EndMessage;
281
282 // According to the spec, requiredSubgroupSize must be a power-of-two integer.
283 for (deUint32 size = subgroupSizeControlProperties.minSubgroupSize; size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
284 {
285 TestStatus result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkCompute,
286 size, VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT);
287 if (result.getCode() != QP_TEST_RESULT_PASS)
288 {
289 log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
290 return result;
291 }
292 }
293
294 return TestStatus::pass("OK");
295 }
296 else if (isAllGraphicsStages(caseDef.shaderStage))
297 {
298 const VkShaderStageFlags stages = subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
299 subgroups::SSBOData inputData;
300
301 inputData.format = caseDef.format;
302 inputData.layout = subgroups::SSBOData::LayoutStd430;
303 inputData.numElements = subgroups::maxSupportedSubgroupSize();
304 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
305 inputData.binding = 4u;
306 inputData.stages = stages;
307
308 return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
309 }
310 else if (isAllRayTracingStages(caseDef.shaderStage))
311 {
312 const VkShaderStageFlags stages = subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
313 const subgroups::SSBOData inputData =
314 {
315 subgroups::SSBOData::InitializeNonZero, // InputDataInitializeType initializeType;
316 subgroups::SSBOData::LayoutStd430, // InputDataLayoutType layout;
317 caseDef.format, // vk::VkFormat format;
318 subgroups::maxSupportedSubgroupSize(), // vk::VkDeviceSize numElements;
319 false, // bool isImage;
320 6u, // deUint32 binding;
321 stages, // vk::VkShaderStageFlags stages;
322 };
323
324 return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
325 }
326 else
327 TCU_THROW(InternalError, "Unknown stage or invalid stage set");
328 }
329 }
330
331 namespace vkt
332 {
333 namespace subgroups
334 {
createSubgroupsQuadTests(TestContext & testCtx)335 TestCaseGroup* createSubgroupsQuadTests (TestContext& testCtx)
336 {
337 de::MovePtr<TestCaseGroup> group (new TestCaseGroup(testCtx, "quad", "Subgroup quad category tests"));
338 de::MovePtr<TestCaseGroup> graphicGroup (new TestCaseGroup(testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
339 de::MovePtr<TestCaseGroup> computeGroup (new TestCaseGroup(testCtx, "compute", "Subgroup arithmetic category tests: compute"));
340 de::MovePtr<TestCaseGroup> framebufferGroup (new TestCaseGroup(testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
341 de::MovePtr<TestCaseGroup> raytracingGroup (new TestCaseGroup(testCtx, "ray_tracing", "Subgroup arithmetic category tests: ray tracing"));
342 const VkShaderStageFlags stages[] =
343 {
344 VK_SHADER_STAGE_VERTEX_BIT,
345 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
346 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
347 VK_SHADER_STAGE_GEOMETRY_BIT,
348 };
349 const deBool boolValues[] =
350 {
351 DE_FALSE,
352 DE_TRUE
353 };
354
355 {
356 const vector<VkFormat> formats = subgroups::getAllFormats();
357
358 for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
359 {
360 const VkFormat format = formats[formatIndex];
361 const string formatName = subgroups::getFormatNameForGLSL(format);
362
363 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
364 {
365 const OpType opType = static_cast<OpType>(opTypeIndex);
366 const string name = getOpTypeCaseName(opType) + "_" + formatName;
367
368 for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
369 {
370 const deBool requiredSubgroupSize = boolValues[groupSizeNdx];
371 const string testNameSuffix = requiredSubgroupSize ? "_requiredsubgroupsize" : "";
372 const string testName = name + testNameSuffix;
373 const CaseDefinition caseDef =
374 {
375 opType, // OpType opType;
376 VK_SHADER_STAGE_COMPUTE_BIT, // VkShaderStageFlags shaderStage;
377 format, // VkFormat format;
378 de::SharedPtr<bool>(new bool), // de::SharedPtr<bool> geometryPointSizeSupported;
379 requiredSubgroupSize, // deBool requiredSubgroupSize;
380 };
381
382 addFunctionCaseWithPrograms(computeGroup.get(), testName, "", supportedCheck, initPrograms, test, caseDef);
383 }
384
385 {
386 const CaseDefinition caseDef =
387 {
388 opType, // OpType opType;
389 VK_SHADER_STAGE_ALL_GRAPHICS, // VkShaderStageFlags shaderStage;
390 format, // VkFormat format;
391 de::SharedPtr<bool>(new bool), // de::SharedPtr<bool> geometryPointSizeSupported;
392 DE_FALSE // deBool requiredSubgroupSize;
393 };
394
395 addFunctionCaseWithPrograms(graphicGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
396 }
397
398 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
399 {
400 const CaseDefinition caseDef =
401 {
402 opType, // OpType opType;
403 stages[stageIndex], // VkShaderStageFlags shaderStage;
404 format, // VkFormat format;
405 de::SharedPtr<bool>(new bool), // de::SharedPtr<bool> geometryPointSizeSupported;
406 DE_FALSE // deBool requiredSubgroupSize;
407 };
408 const string testName = name + "_" + getShaderStageName(caseDef.shaderStage);
409
410 addFunctionCaseWithPrograms(framebufferGroup.get(), testName, "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
411 }
412 }
413 }
414 }
415
416 {
417 const vector<VkFormat> formats = subgroups::getAllRayTracingFormats();
418
419 for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
420 {
421 const VkFormat format = formats[formatIndex];
422 const string formatName = subgroups::getFormatNameForGLSL(format);
423
424 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
425 {
426 const OpType opType = static_cast<OpType>(opTypeIndex);
427 const string testName = getOpTypeCaseName(opType) + "_" + formatName;
428 const CaseDefinition caseDef =
429 {
430 opType, // OpType opType;
431 SHADER_STAGE_ALL_RAY_TRACING, // VkShaderStageFlags shaderStage;
432 format, // VkFormat format;
433 de::SharedPtr<bool>(new bool), // de::SharedPtr<bool> geometryPointSizeSupported;
434 DE_FALSE // deBool requiredSubgroupSize;
435 };
436
437 addFunctionCaseWithPrograms(raytracingGroup.get(), testName, "", supportedCheck, initPrograms, test, caseDef);
438 }
439 }
440 }
441
442 group->addChild(graphicGroup.release());
443 group->addChild(computeGroup.release());
444 group->addChild(framebufferGroup.release());
445 group->addChild(raytracingGroup.release());
446
447 return group.release();
448 }
449 } // subgroups
450 } // vkt
451