/*------------------------------------------------------------------------
 * Vulkan Conformance Tests
 * ------------------------
 *
 * Copyright (c) 2019 The Khronos Group Inc.
 * Copyright (c) 2019 Google Inc.
 * Copyright (c) 2017 Codeplay Software Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */ /*!
 * \file
 * \brief Subgroups Tests
 */ /*--------------------------------------------------------------------*/

#include "vktSubgroupsClusteredTests.hpp"
#include "vktSubgroupsScanHelpers.hpp"
#include "vktSubgroupsTestsUtils.hpp"

#include <string>
#include <vector>

using namespace tcu;
using namespace std;
using namespace vk;
using namespace vkt;

namespace
{
enum OpType
{
    OPTYPE_CLUSTERED_ADD = 0,
    OPTYPE_CLUSTERED_MUL,
    OPTYPE_CLUSTERED_MIN,
    OPTYPE_CLUSTERED_MAX,
    OPTYPE_CLUSTERED_AND,
    OPTYPE_CLUSTERED_OR,
    OPTYPE_CLUSTERED_XOR,
    OPTYPE_CLUSTERED_LAST
};

struct CaseDefinition
{
    Operator            op;
    VkShaderStageFlags  shaderStage;
    VkFormat            format;
    de::SharedPtr<bool> geometryPointSizeSupported;
    deBool              requiredSubgroupSize;
};

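// Maps the clustered OpType enum onto the generic scan-helper Operator.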
static Operator getOperator (OpType opType)
{
    switch (opType)
    {
        case OPTYPE_CLUSTERED_ADD:  return OPERATOR_ADD;
        case OPTYPE_CLUSTERED_MUL:  return OPERATOR_MUL;
        case OPTYPE_CLUSTERED_MIN:  return OPERATOR_MIN;
        case OPTYPE_CLUSTERED_MAX:  return OPERATOR_MAX;
        case OPTYPE_CLUSTERED_AND:  return OPERATOR_AND;
        case OPTYPE_CLUSTERED_OR:   return OPERATOR_OR;
        case OPTYPE_CLUSTERED_XOR:  return OPERATOR_XOR;
        default:                    TCU_THROW(InternalError, "Unsupported op type");
    }
}

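// Result checker for the single-stage (framebuffer) tests: forwards the per-invocation results to subgroups::check.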
static bool checkVertexPipelineStages (const void*          internalData,
                                       vector<const void*>  datas,
                                       deUint32             width,
                                       deUint32)
{
    DE_UNREF(internalData);

    return subgroups::check(datas, width, 1);
}

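// Result checker for compute dispatches: forwards the per-invocation results to subgroups::checkCompute.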
static bool checkCompute (const void*          internalData,
                          vector<const void*>  datas,
                          const deUint32       numWorkgroups[3],
                          const deUint32       localSize[3],
                          deUint32)
{
    DE_UNREF(internalData);

    return subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
}

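// Returns the GLSL built-in name for the clustered reduction, e.g. "subgroupClusteredAdd".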
string getOpTypeName (Operator op)
{
    return getScanOpName("subgroupClustered", "", op, SCAN_REDUCE);
}

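// GLSL extension header: clustered and ballot subgroup extensions plus any extension required by the tested format.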
string getExtHeader (CaseDefinition& caseDef)
{
    return "#extension GL_KHR_shader_subgroup_clustered: enable\n"
           "#extension GL_KHR_shader_subgroup_ballot: enable\n" +
           subgroups::getAdditionalExtensionForFormat(caseDef.format);
}

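// Builds the GLSL test body. For each power-of-two cluster size up to the maximum supported
// subgroup size, the generated code computes the clustered reduction and compares it against a
// reference accumulated manually (via the ballot mask) over the active invocations of the same
// cluster, roughly:
//
//   const uint clusterSize = N;
//   if (clusterSize <= gl_SubgroupSize)
//   {
//     op  = subgroupClustered<Op>(data[gl_SubgroupInvocationID], clusterSize);
//     ref = <identity>, accumulated over the active invocations of this invocation's cluster;
//     if (ref != op) tempResult = false;
//   }
//
// tempRes ends up as 1 only if every cluster size checked out.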
string getTestSrc (CaseDefinition& caseDef)
{
    const string  formatName  = subgroups::getFormatNameForGLSL(caseDef.format);
    const string  opTypeName  = getOpTypeName(caseDef.op);
    const string  identity    = getIdentity(caseDef.op, caseDef.format);
    const string  opOperation = getOpOperation(caseDef.op, caseDef.format, "ref", "data[index]");
    const string  compare     = getCompare(caseDef.op, caseDef.format, "ref", "op");
    ostringstream bdy;

    bdy << "  bool tempResult = true;\n"
        << "  uvec4 mask = subgroupBallot(true);\n";

    for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
    {
        bdy << "  {\n"
            << "    const uint clusterSize = " << i << ";\n"
            << "    if (clusterSize <= gl_SubgroupSize)\n"
            << "    {\n"
            << "      " << formatName << " op = "
            << opTypeName + "(data[gl_SubgroupInvocationID], clusterSize);\n"
            << "      for (uint clusterOffset = 0; clusterOffset < gl_SubgroupSize; clusterOffset += clusterSize)\n"
            << "      {\n"
            << "        " << formatName << " ref = "
            << identity << ";\n"
            << "        for (uint index = clusterOffset; index < (clusterOffset + clusterSize); index++)\n"
            << "        {\n"
            << "          if (subgroupBallotBitExtract(mask, index))\n"
            << "          {\n"
            << "            ref = " << opOperation << ";\n"
            << "          }\n"
            << "        }\n"
            << "        if ((clusterOffset <= gl_SubgroupInvocationID) && (gl_SubgroupInvocationID < (clusterOffset + clusterSize)))\n"
            << "        {\n"
            << "          if (!" << compare << ")\n"
            << "          {\n"
            << "            tempResult = false;\n"
            << "          }\n"
            << "        }\n"
            << "      }\n"
            << "    }\n"
            << "  }\n"
            << "  tempRes = tempResult ? 1 : 0;\n";
    }

    return bdy.str();
}

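// Shader generation for the single-stage framebuffer tests (SPIR-V 1.3).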
void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
{
    const ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, SPIRV_VERSION_1_3, 0u);
    const string             extHeader    = getExtHeader(caseDef);
    const string             testSrc      = getTestSrc(caseDef);

    subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, *caseDef.geometryPointSizeSupported, extHeader, testSrc, "");
}

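// Shader generation for the SSBO-based tests; ray tracing stages require SPIR-V 1.4, everything else uses 1.3.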
void initPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
{
    const bool               spirv14required = isAllRayTracingStages(caseDef.shaderStage);
    const SpirvVersion       spirvVersion    = spirv14required ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
    const ShaderBuildOptions buildOptions    (programCollection.usedVulkanVersion, spirvVersion, 0u);
    const string             extHeader       = getExtHeader(caseDef);
    const string             testSrc         = getTestSrc(caseDef);

    subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, *caseDef.geometryPointSizeSupported, extHeader, testSrc, "");
}

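// Checks that the device supports subgroup clustered operations and the tested format, and, when a
// required subgroup size is requested, the relevant VK_EXT_subgroup_size_control features.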
void supportedCheck (Context& context, CaseDefinition caseDef)
{
    if (!subgroups::isSubgroupSupported(context))
        TCU_THROW(NotSupportedError, "Subgroup operations are not supported");

    if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_CLUSTERED_BIT))
        TCU_THROW(NotSupportedError, "Device does not support subgroup clustered operations");

    if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
        TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");

    if (caseDef.requiredSubgroupSize)
    {
        context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");

        const VkPhysicalDeviceSubgroupSizeControlFeaturesEXT&   subgroupSizeControlFeatures   = context.getSubgroupSizeControlFeaturesEXT();
        const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT& subgroupSizeControlProperties = context.getSubgroupSizeControlPropertiesEXT();

        if (subgroupSizeControlFeatures.subgroupSizeControl == DE_FALSE)
            TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");

        if (subgroupSizeControlFeatures.computeFullSubgroups == DE_FALSE)
            TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");

        if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
            TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
    }

    *caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);

    if (isAllRayTracingStages(caseDef.shaderStage))
    {
        context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
    }

    subgroups::supportedCheckShader(context, caseDef.shaderStage);
}

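// Framebuffer-based variant used for the single-stage vertex, tessellation and geometry cases;
// the input data uses std140 layout rather than an SSBO.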
TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
{
    const subgroups::SSBOData inputData =
    {
        subgroups::SSBOData::InitializeNonZero,  // InputDataInitializeType  initializeType;
        subgroups::SSBOData::LayoutStd140,       // InputDataLayoutType      layout;
        caseDef.format,                          // vk::VkFormat             format;
        subgroups::maxSupportedSubgroupSize(),   // vk::VkDeviceSize         numElements;
    };

    switch (caseDef.shaderStage)
    {
        case VK_SHADER_STAGE_VERTEX_BIT:                  return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
        case VK_SHADER_STAGE_GEOMETRY_BIT:                return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
        case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:    return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
        case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT: return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
        default:                                          TCU_THROW(InternalError, "Unhandled shader stage");
    }
}

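// SSBO-based variant: compute tests (optionally looping over every supported required subgroup size),
// all-graphics tests and ray tracing tests.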
TestStatus test (Context& context, const CaseDefinition caseDef)
{
    if (isAllComputeStages(caseDef.shaderStage))
    {
        const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT& subgroupSizeControlProperties = context.getSubgroupSizeControlPropertiesEXT();
        TestLog&                                                log                           = context.getTestContext().getLog();

        subgroups::SSBOData inputData;
        inputData.format         = caseDef.format;
        inputData.layout         = subgroups::SSBOData::LayoutStd430;
        inputData.numElements    = subgroups::maxSupportedSubgroupSize();
        inputData.initializeType = subgroups::SSBOData::InitializeNonZero;

        if (caseDef.requiredSubgroupSize == DE_FALSE)
            return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkCompute);

        log << TestLog::Message << "Testing required subgroup size range [" << subgroupSizeControlProperties.minSubgroupSize << ", "
            << subgroupSizeControlProperties.maxSubgroupSize << "]" << TestLog::EndMessage;

        // According to the spec, requiredSubgroupSize must be a power-of-two integer.
        for (deUint32 size = subgroupSizeControlProperties.minSubgroupSize; size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
        {
            TestStatus result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkCompute,
                                                           size, VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT);
            if (result.getCode() != QP_TEST_RESULT_PASS)
            {
                log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
                return result;
            }
        }

        return TestStatus::pass("OK");
    }
    else if (isAllGraphicsStages(caseDef.shaderStage))
    {
        const VkShaderStageFlags  stages    = subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
        const subgroups::SSBOData inputData =
        {
            subgroups::SSBOData::InitializeNonZero,  // InputDataInitializeType  initializeType;
            subgroups::SSBOData::LayoutStd430,       // InputDataLayoutType      layout;
            caseDef.format,                          // vk::VkFormat             format;
            subgroups::maxSupportedSubgroupSize(),   // vk::VkDeviceSize         numElements;
            false,                                   // bool                     isImage;
            4u,                                      // deUint32                 binding;
            stages,                                  // vk::VkShaderStageFlags   stages;
        };

        return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
    }
    else if (isAllRayTracingStages(caseDef.shaderStage))
    {
        const VkShaderStageFlags  stages    = subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
        const subgroups::SSBOData inputData =
        {
            subgroups::SSBOData::InitializeNonZero,  // InputDataInitializeType  initializeType;
            subgroups::SSBOData::LayoutStd430,       // InputDataLayoutType      layout;
            caseDef.format,                          // vk::VkFormat             format;
            subgroups::maxSupportedSubgroupSize(),   // vk::VkDeviceSize         numElements;
            false,                                   // bool                     isImage;
            6u,                                      // deUint32                 binding;
            stages,                                  // vk::VkShaderStageFlags   stages;
        };

        return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
    }
    else
        TCU_THROW(InternalError, "Unknown stage or invalid stage set");
}
}

namespace vkt
{
namespace subgroups
{
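// Creates the "clustered" test group with graphics, compute, framebuffer and ray tracing subgroups.
// Cases are generated for every supported format and clustered operation, skipping bitwise ops on
// floating-point formats and non-bitwise ops on boolean formats.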
TestCaseGroup* createSubgroupsClusteredTests (TestContext& testCtx)
{
    de::MovePtr<TestCaseGroup> group            (new TestCaseGroup(testCtx, "clustered", "Subgroup clustered category tests"));
    de::MovePtr<TestCaseGroup> graphicGroup     (new TestCaseGroup(testCtx, "graphics", "Subgroup clustered category tests: graphics"));
    de::MovePtr<TestCaseGroup> computeGroup     (new TestCaseGroup(testCtx, "compute", "Subgroup clustered category tests: compute"));
    de::MovePtr<TestCaseGroup> framebufferGroup (new TestCaseGroup(testCtx, "framebuffer", "Subgroup clustered category tests: framebuffer"));
    de::MovePtr<TestCaseGroup> raytracingGroup  (new TestCaseGroup(testCtx, "ray_tracing", "Subgroup clustered category tests: ray tracing"));
    const VkShaderStageFlags stages[] =
    {
        VK_SHADER_STAGE_VERTEX_BIT,
        VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
        VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
        VK_SHADER_STAGE_GEOMETRY_BIT,
    };
    const deBool boolValues[] =
    {
        DE_FALSE,
        DE_TRUE
    };

    {
        const vector<VkFormat> formats = subgroups::getAllFormats();

        for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
        {
            const VkFormat format     = formats[formatIndex];
            const string   formatName = subgroups::getFormatNameForGLSL(format);
            const bool     isBool     = subgroups::isFormatBool(format);
            const bool     isFloat    = subgroups::isFormatFloat(format);

            for (int opTypeIndex = 0; opTypeIndex < OPTYPE_CLUSTERED_LAST; ++opTypeIndex)
            {
                const OpType   opType      = static_cast<OpType>(opTypeIndex);
                const Operator op          = getOperator(opType);
                const bool     isBitwiseOp = (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);

                // Skip float with bitwise category.
                if (isFloat && isBitwiseOp)
                    continue;

                // Skip bool when it's not the bitwise category.
                if (isBool && !isBitwiseOp)
                    continue;

                const string name = de::toLower(getOpTypeName(op)) + "_" + formatName;

                for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
                {
                    const deBool         requiredSubgroupSize = boolValues[groupSizeNdx];
                    const string         testName             = name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "");
                    const CaseDefinition caseDef =
                    {
                        op,                             // Operator             op;
                        VK_SHADER_STAGE_COMPUTE_BIT,    // VkShaderStageFlags   shaderStage;
                        format,                         // VkFormat             format;
                        de::SharedPtr<bool>(new bool),  // de::SharedPtr<bool>  geometryPointSizeSupported;
                        requiredSubgroupSize,           // deBool               requiredSubgroupSize;
                    };

                    addFunctionCaseWithPrograms(computeGroup.get(), testName, "", supportedCheck, initPrograms, test, caseDef);
                }

                {
                    const CaseDefinition caseDef =
                    {
                        op,                             // Operator             op;
                        VK_SHADER_STAGE_ALL_GRAPHICS,   // VkShaderStageFlags   shaderStage;
                        format,                         // VkFormat             format;
                        de::SharedPtr<bool>(new bool),  // de::SharedPtr<bool>  geometryPointSizeSupported;
                        DE_FALSE                        // deBool               requiredSubgroupSize;
                    };

                    addFunctionCaseWithPrograms(graphicGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
                }

                for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
                {
                    const CaseDefinition caseDef =
                    {
                        op,                             // Operator             op;
                        stages[stageIndex],             // VkShaderStageFlags   shaderStage;
                        format,                         // VkFormat             format;
                        de::SharedPtr<bool>(new bool),  // de::SharedPtr<bool>  geometryPointSizeSupported;
                        DE_FALSE                        // deBool               requiredSubgroupSize;
                    };
                    const string testName = name + "_" + getShaderStageName(caseDef.shaderStage);

                    addFunctionCaseWithPrograms(framebufferGroup.get(), testName, "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
                }
            }
        }
    }

    {
        const vector<VkFormat> formats = subgroups::getAllRayTracingFormats();

        for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
        {
            const VkFormat format     = formats[formatIndex];
            const string   formatName = subgroups::getFormatNameForGLSL(format);
            const bool     isBool     = subgroups::isFormatBool(format);
            const bool     isFloat    = subgroups::isFormatFloat(format);

            for (int opTypeIndex = 0; opTypeIndex < OPTYPE_CLUSTERED_LAST; ++opTypeIndex)
            {
                const OpType   opType      = static_cast<OpType>(opTypeIndex);
                const Operator op          = getOperator(opType);
                const bool     isBitwiseOp = (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);

                // Skip float with bitwise category.
                if (isFloat && isBitwiseOp)
                    continue;

                // Skip bool when it's not the bitwise category.
                if (isBool && !isBitwiseOp)
                    continue;

                {
                    const string         name    = de::toLower(getOpTypeName(op)) + "_" + formatName;
                    const CaseDefinition caseDef =
                    {
                        op,                             // Operator             op;
                        SHADER_STAGE_ALL_RAY_TRACING,   // VkShaderStageFlags   shaderStage;
                        format,                         // VkFormat             format;
                        de::SharedPtr<bool>(new bool),  // de::SharedPtr<bool>  geometryPointSizeSupported;
                        DE_FALSE                        // deBool               requiredSubgroupSize;
                    };

                    addFunctionCaseWithPrograms(raytracingGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
                }
            }
        }
    }

    group->addChild(graphicGroup.release());
    group->addChild(computeGroup.release());
    group->addChild(framebufferGroup.release());
    group->addChild(raytracingGroup.release());

    return group.release();
}
} // subgroups
} // vkt