• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2019 Google Inc.
7  * Copyright (c) 2017 Codeplay Software Ltd.
8  * Copyright (c) 2018 NVIDIA Corporation
9  *
10  * Licensed under the Apache License, Version 2.0 (the "License");
11  * you may not use this file except in compliance with the License.
12  * You may obtain a copy of the License at
13  *
14  *      http://www.apache.org/licenses/LICENSE-2.0
15  *
16  * Unless required by applicable law or agreed to in writing, software
17  * distributed under the License is distributed on an "AS IS" BASIS,
18  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19  * See the License for the specific language governing permissions and
20  * limitations under the License.
21  *
22  */ /*!
23  * \file
24  * \brief Subgroups Tests
25  */ /*--------------------------------------------------------------------*/
26 
27 #include "vktSubgroupsPartitionedTests.hpp"
28 #include "vktSubgroupsScanHelpers.hpp"
29 #include "vktSubgroupsTestsUtils.hpp"
30 
31 #include <string>
32 #include <vector>
33 
34 using namespace tcu;
35 using namespace std;
36 using namespace vk;
37 using namespace vkt;
38 
39 namespace
40 {
// Every subgroup operation exercised by these tests.
// The enumerators form three consecutive runs of seven entries: reductions, then inclusive scans,
// then exclusive scans. Each run lists the operators in the same order (ADD, MUL, MIN, MAX, AND, OR, XOR).
// OPTYPE_LAST is the count of real entries and serves as the loop bound.
enum OpType
{
	// Reductions (values 0..6).
	OPTYPE_ADD = 0,			OPTYPE_MUL,				OPTYPE_MIN,				OPTYPE_MAX,
	OPTYPE_AND,				OPTYPE_OR,				OPTYPE_XOR,

	// Inclusive scans (values 7..13).
	OPTYPE_INCLUSIVE_ADD,	OPTYPE_INCLUSIVE_MUL,	OPTYPE_INCLUSIVE_MIN,	OPTYPE_INCLUSIVE_MAX,
	OPTYPE_INCLUSIVE_AND,	OPTYPE_INCLUSIVE_OR,	OPTYPE_INCLUSIVE_XOR,

	// Exclusive scans (values 14..20).
	OPTYPE_EXCLUSIVE_ADD,	OPTYPE_EXCLUSIVE_MUL,	OPTYPE_EXCLUSIVE_MIN,	OPTYPE_EXCLUSIVE_MAX,
	OPTYPE_EXCLUSIVE_AND,	OPTYPE_EXCLUSIVE_OR,	OPTYPE_EXCLUSIVE_XOR,

	// Number of entries above; not a real op type.
	OPTYPE_LAST
};
66 
67 struct CaseDefinition
68 {
69 	Operator			op;
70 	ScanType			scanType;
71 	VkShaderStageFlags	shaderStage;
72 	VkFormat			format;
73 	de::SharedPtr<bool>	geometryPointSizeSupported;
74 	deBool				requiredSubgroupSize;
75 };
76 
getOperator(OpType opType)77 static Operator getOperator (OpType opType)
78 {
79 	switch (opType)
80 	{
81 		case OPTYPE_ADD:
82 		case OPTYPE_INCLUSIVE_ADD:
83 		case OPTYPE_EXCLUSIVE_ADD:
84 			return OPERATOR_ADD;
85 		case OPTYPE_MUL:
86 		case OPTYPE_INCLUSIVE_MUL:
87 		case OPTYPE_EXCLUSIVE_MUL:
88 			return OPERATOR_MUL;
89 		case OPTYPE_MIN:
90 		case OPTYPE_INCLUSIVE_MIN:
91 		case OPTYPE_EXCLUSIVE_MIN:
92 			return OPERATOR_MIN;
93 		case OPTYPE_MAX:
94 		case OPTYPE_INCLUSIVE_MAX:
95 		case OPTYPE_EXCLUSIVE_MAX:
96 			return OPERATOR_MAX;
97 		case OPTYPE_AND:
98 		case OPTYPE_INCLUSIVE_AND:
99 		case OPTYPE_EXCLUSIVE_AND:
100 			return OPERATOR_AND;
101 		case OPTYPE_OR:
102 		case OPTYPE_INCLUSIVE_OR:
103 		case OPTYPE_EXCLUSIVE_OR:
104 			return OPERATOR_OR;
105 		case OPTYPE_XOR:
106 		case OPTYPE_INCLUSIVE_XOR:
107 		case OPTYPE_EXCLUSIVE_XOR:
108 			return OPERATOR_XOR;
109 		default:
110 			DE_FATAL("Unsupported op type");
111 			return OPERATOR_ADD;
112 	}
113 }
114 
getScanType(OpType opType)115 static ScanType getScanType (OpType opType)
116 {
117 	switch (opType)
118 	{
119 		case OPTYPE_ADD:
120 		case OPTYPE_MUL:
121 		case OPTYPE_MIN:
122 		case OPTYPE_MAX:
123 		case OPTYPE_AND:
124 		case OPTYPE_OR:
125 		case OPTYPE_XOR:
126 			return SCAN_REDUCE;
127 		case OPTYPE_INCLUSIVE_ADD:
128 		case OPTYPE_INCLUSIVE_MUL:
129 		case OPTYPE_INCLUSIVE_MIN:
130 		case OPTYPE_INCLUSIVE_MAX:
131 		case OPTYPE_INCLUSIVE_AND:
132 		case OPTYPE_INCLUSIVE_OR:
133 		case OPTYPE_INCLUSIVE_XOR:
134 			return SCAN_INCLUSIVE;
135 		case OPTYPE_EXCLUSIVE_ADD:
136 		case OPTYPE_EXCLUSIVE_MUL:
137 		case OPTYPE_EXCLUSIVE_MIN:
138 		case OPTYPE_EXCLUSIVE_MAX:
139 		case OPTYPE_EXCLUSIVE_AND:
140 		case OPTYPE_EXCLUSIVE_OR:
141 		case OPTYPE_EXCLUSIVE_XOR:
142 			return SCAN_EXCLUSIVE;
143 		default:
144 			DE_FATAL("Unsupported op type");
145 			return SCAN_REDUCE;
146 	}
147 }
148 
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,deUint32 width,deUint32)149 static bool checkVertexPipelineStages (const void*			internalData,
150 									   vector<const void*>	datas,
151 									   deUint32				width,
152 									   deUint32)
153 {
154 	DE_UNREF(internalData);
155 
156 	return subgroups::check(datas, width, 0xFFFFFF);
157 }
158 
checkComputeOrMesh(const void * internalData,vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)159 static bool checkComputeOrMesh (const void*			internalData,
160 								vector<const void*>	datas,
161 								const deUint32		numWorkgroups[3],
162 								const deUint32		localSize[3],
163 								deUint32)
164 {
165 	DE_UNREF(internalData);
166 
167 	return subgroups::checkComputeOrMesh(datas, numWorkgroups, localSize, 0xFFFFFF);
168 }
169 
getOpTypeName(Operator op,ScanType scanType)170 string getOpTypeName (Operator op, ScanType scanType)
171 {
172 	return getScanOpName("subgroup", "", op, scanType);
173 }
174 
getOpTypeNamePartitioned(Operator op,ScanType scanType)175 string getOpTypeNamePartitioned (Operator op, ScanType scanType)
176 {
177 	return getScanOpName("subgroupPartitioned", "NV", op, scanType);
178 }
179 
getExtHeader(const CaseDefinition & caseDef)180 string getExtHeader (const CaseDefinition& caseDef)
181 {
182 	return	"#extension GL_NV_shader_subgroup_partitioned: enable\n"
183 			"#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
184 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
185 			+ subgroups::getAdditionalExtensionForFormat(caseDef.format);
186 }
187 
getTestString(const CaseDefinition & caseDef)188 string getTestString (const CaseDefinition& caseDef)
189 {
190 	Operator op = caseDef.op;
191 	ScanType st = caseDef.scanType;
192 
193 	// NOTE: tempResult can't have anything in bits 31:24 to avoid int->float
194 	// conversion overflow in framebuffer tests.
195 	string fmt = subgroups::getFormatNameForGLSL(caseDef.format);
196 	string bdy =
197 		"  uvec4 mask = subgroupBallot(true);\n"
198 		"  uint tempResult = 0;\n"
199 		"  uint id = gl_SubgroupInvocationID;\n";
200 
201 	// Test the case where the partition has a single subset with all invocations in it.
202 	// This should generate the same result as the non-partitioned function.
203 	bdy +=
204 		"  uvec4 allBallot = mask;\n"
205 		"  " + fmt + " allResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], allBallot);\n"
206 		"  " + fmt + " refResult = " + getOpTypeName(op, st) + "(data[gl_SubgroupInvocationID]);\n"
207 		"  if (" + getCompare(op, caseDef.format, "allResult", "refResult") + ") {\n"
208 		"      tempResult |= 0x1;\n"
209 		"  }\n";
210 
211 	// The definition of a partition doesn't forbid bits corresponding to inactive
212 	// invocations being in the subset with active invocations. In other words, test that
213 	// bits corresponding to inactive invocations are ignored.
214 	bdy +=
215 		"  if (0 == (gl_SubgroupInvocationID % 2)) {\n"
216 		"    " + fmt + " allResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], allBallot);\n"
217 		"    " + fmt + " refResult = " + getOpTypeName(op, st) + "(data[gl_SubgroupInvocationID]);\n"
218 		"    if (" + getCompare(op, caseDef.format, "allResult", "refResult") + ") {\n"
219 		"        tempResult |= 0x2;\n"
220 		"    }\n"
221 		"  } else {\n"
222 		"    tempResult |= 0x2;\n"
223 		"  }\n";
224 
225 	// Test the case where the partition has each invocation in a unique subset. For
226 	// exclusive ops, the result is identity. For reduce/inclusive, it's the original value.
227 	string expectedSelfResult = "data[gl_SubgroupInvocationID]";
228 	if (st == SCAN_EXCLUSIVE)
229 		expectedSelfResult = getIdentity(op, caseDef.format);
230 
231 	bdy +=
232 		"  uvec4 selfBallot = subgroupPartitionNV(gl_SubgroupInvocationID);\n"
233 		"  " + fmt + " selfResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], selfBallot);\n"
234 		"  if (" + getCompare(op, caseDef.format, "selfResult", expectedSelfResult) + ") {\n"
235 		"      tempResult |= 0x4;\n"
236 		"  }\n";
237 
238 	// Test "random" partitions based on a hash of the invocation id.
239 	// This "hash" function produces interesting/randomish partitions.
240 	static const char *idhash = "((id%N)+(id%(N+1))-(id%2)+(id/2))%((N+1)/2)";
241 
242 	bdy +=
243 		"  for (uint N = 1; N < 16; ++N) {\n"
244 		"    " + fmt + " idhashFmt = " + fmt + "(" + idhash + ");\n"
245 		"    uvec4 partitionBallot = subgroupPartitionNV(idhashFmt) & mask;\n"
246 		"    " + fmt + " partitionedResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], partitionBallot);\n"
247 		"      for (uint i = 0; i < N; ++i) {\n"
248 		"        " + fmt + " iFmt = " + fmt + "(i);\n"
249 		"        if (" + getCompare(op, caseDef.format, "idhashFmt", "iFmt") + ") {\n"
250 		"          " + fmt + " subsetResult = " + getOpTypeName(op, st) + "(data[gl_SubgroupInvocationID]);\n"
251 		"          tempResult |= " + getCompare(op, caseDef.format, "partitionedResult", "subsetResult") + " ? (0x4 << N) : 0;\n"
252 		"        }\n"
253 		"      }\n"
254 		"  }\n"
255 		// tests in flow control:
256 		"  if (1 == (gl_SubgroupInvocationID % 2)) {\n"
257 		"    for (uint N = 1; N < 7; ++N) {\n"
258 		"      " + fmt + " idhashFmt = " + fmt + "(" + idhash + ");\n"
259 		"      uvec4 partitionBallot = subgroupPartitionNV(idhashFmt) & mask;\n"
260 		"      " + fmt + " partitionedResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], partitionBallot);\n"
261 		"        for (uint i = 0; i < N; ++i) {\n"
262 		"          " + fmt + " iFmt = " + fmt + "(i);\n"
263 		"          if (" + getCompare(op, caseDef.format, "idhashFmt", "iFmt") + ") {\n"
264 		"            " + fmt + " subsetResult = " + getOpTypeName(op, st) + "(data[gl_SubgroupInvocationID]);\n"
265 		"            tempResult |= " + getCompare(op, caseDef.format, "partitionedResult", "subsetResult") + " ? (0x20000 << N) : 0;\n"
266 		"          }\n"
267 		"        }\n"
268 		"    }\n"
269 		"  } else {\n"
270 		"    tempResult |= 0xFC0000;\n"
271 		"  }\n"
272 		"  tempRes = tempResult;\n"
273 		;
274 
275 	return bdy;
276 }
277 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)278 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
279 {
280 	const ShaderBuildOptions	buildOptions		(programCollection.usedVulkanVersion, SPIRV_VERSION_1_3, 0u);
281 	const string				extHeader			= getExtHeader(caseDef);
282 	const string				testSrc				= getTestString(caseDef);
283 	const bool					pointSizeSupport	= *caseDef.geometryPointSizeSupported;
284 
285 	subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, pointSizeSupport, extHeader, testSrc, "");
286 }
287 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)288 void initPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
289 {
290 	const bool					spirv14required		= (isAllRayTracingStages(caseDef.shaderStage) || isAllMeshShadingStages(caseDef.shaderStage));
291 	const SpirvVersion			spirvVersion		= spirv14required ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
292 	const ShaderBuildOptions	buildOptions		(programCollection.usedVulkanVersion, spirvVersion, 0u, spirv14required);
293 	const string				extHeader			= getExtHeader(caseDef);
294 	const string				testSrc				= getTestString(caseDef);
295 	const bool					pointSizeSupport	= *caseDef.geometryPointSizeSupported;
296 
297 	subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, pointSizeSupport, extHeader, testSrc, "");
298 }
299 
supportedCheck(Context & context,CaseDefinition caseDef)300 void supportedCheck (Context& context, CaseDefinition caseDef)
301 {
302 	if (!subgroups::isSubgroupSupported(context))
303 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
304 
305 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_PARTITIONED_BIT_NV))
306 		TCU_THROW(NotSupportedError, "Device does not support subgroup partitioned operations");
307 
308 	if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
309 		TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
310 
311 	if (caseDef.requiredSubgroupSize)
312 	{
313 		context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
314 
315 		const VkPhysicalDeviceSubgroupSizeControlFeatures&		subgroupSizeControlFeatures		= context.getSubgroupSizeControlFeatures();
316 		const VkPhysicalDeviceSubgroupSizeControlProperties&	subgroupSizeControlProperties	= context.getSubgroupSizeControlProperties();
317 
318 		if (subgroupSizeControlFeatures.subgroupSizeControl == DE_FALSE)
319 			TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
320 
321 		if (subgroupSizeControlFeatures.computeFullSubgroups == DE_FALSE)
322 			TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
323 
324 		if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
325 			TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
326 	}
327 
328 	*caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
329 
330 	if (isAllRayTracingStages(caseDef.shaderStage))
331 	{
332 		context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
333 	}
334 	else if (isAllMeshShadingStages(caseDef.shaderStage))
335 	{
336 		context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
337 		context.requireDeviceFunctionality("VK_EXT_mesh_shader");
338 
339 		if ((caseDef.shaderStage & VK_SHADER_STAGE_TASK_BIT_EXT) != 0u)
340 		{
341 			const auto& features = context.getMeshShaderFeaturesEXT();
342 			if (!features.taskShader)
343 				TCU_THROW(NotSupportedError, "Task shaders not supported");
344 		}
345 	}
346 
347 	subgroups::supportedCheckShader(context, caseDef.shaderStage);
348 }
349 
noSSBOtest(Context & context,const CaseDefinition caseDef)350 TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
351 {
352 	const subgroups::SSBOData	inputData
353 	{
354 		subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
355 		subgroups::SSBOData::LayoutStd140,		//  InputDataLayoutType			layout;
356 		caseDef.format,							//  vk::VkFormat				format;
357 		subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
358 		subgroups::SSBOData::BindingUBO,		//  BindingType					bindingType;
359 	};
360 
361 	switch (caseDef.shaderStage)
362 	{
363 		case VK_SHADER_STAGE_VERTEX_BIT:					return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
364 		case VK_SHADER_STAGE_GEOMETRY_BIT:					return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
365 		case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
366 		case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:	return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
367 		default:											TCU_THROW(InternalError, "Unhandled shader stage");
368 	}
369 }
370 
test(Context & context,const CaseDefinition caseDef)371 TestStatus test (Context& context, const CaseDefinition caseDef)
372 {
373 	const bool isCompute	= isAllComputeStages(caseDef.shaderStage);
374 	const bool isMesh		= isAllMeshShadingStages(caseDef.shaderStage);
375 	DE_ASSERT(!(isCompute && isMesh));
376 
377 	if (isCompute || isMesh)
378 	{
379 		const VkPhysicalDeviceSubgroupSizeControlProperties&	subgroupSizeControlProperties	= context.getSubgroupSizeControlProperties();
380 		TestLog&												log								= context.getTestContext().getLog();
381 		const subgroups::SSBOData								inputData						=
382 		{
383 			subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
384 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
385 			caseDef.format,							//  vk::VkFormat				format;
386 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
387 		};
388 
389 		if (caseDef.requiredSubgroupSize == DE_FALSE)
390 		{
391 			if (isCompute)
392 				return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh);
393 			else
394 				return subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh);
395 		}
396 
397 		log << TestLog::Message << "Testing required subgroup size range [" <<  subgroupSizeControlProperties.minSubgroupSize << ", "
398 			<< subgroupSizeControlProperties.maxSubgroupSize << "]" << TestLog::EndMessage;
399 
400 		// According to the spec, requiredSubgroupSize must be a power-of-two integer.
401 		for (deUint32 size = subgroupSizeControlProperties.minSubgroupSize; size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
402 		{
403 			TestStatus result (QP_TEST_RESULT_INTERNAL_ERROR, "Internal Error");
404 
405 			if (isCompute)
406 				result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh, size);
407 			else
408 				result = subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh, size);
409 
410 			if (result.getCode() != QP_TEST_RESULT_PASS)
411 			{
412 				log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
413 				return result;
414 			}
415 		}
416 
417 		return TestStatus::pass("OK");
418 	}
419 	else if (isAllGraphicsStages(caseDef.shaderStage))
420 	{
421 		const VkShaderStageFlags	stages		= subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
422 		const subgroups::SSBOData	inputData
423 		{
424 			subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
425 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
426 			caseDef.format,							//  vk::VkFormat				format;
427 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
428 			subgroups::SSBOData::BindingSSBO,		//  bool						isImage;
429 			4u,										//  deUint32					binding;
430 			stages,									//  vk::VkShaderStageFlags		stages;
431 		};
432 
433 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
434 	}
435 	else if (isAllRayTracingStages(caseDef.shaderStage))
436 	{
437 		const VkShaderStageFlags	stages		= subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
438 		const subgroups::SSBOData	inputData
439 		{
440 			subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
441 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
442 			caseDef.format,							//  vk::VkFormat				format;
443 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
444 			subgroups::SSBOData::BindingSSBO,		//  bool						isImage;
445 			6u,										//  deUint32					binding;
446 			stages,									//  vk::VkShaderStageFlags		stages;
447 		};
448 
449 		return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
450 	}
451 	else
452 		TCU_THROW(InternalError, "Unknown stage or invalid stage set");
453 }
454 }
455 
456 namespace vkt
457 {
458 namespace subgroups
459 {
createSubgroupsPartitionedTests(TestContext & testCtx)460 TestCaseGroup* createSubgroupsPartitionedTests (TestContext& testCtx)
461 {
462 	de::MovePtr<TestCaseGroup>	group				(new TestCaseGroup(testCtx, "partitioned", "Subgroup partitioned category tests"));
463 	de::MovePtr<TestCaseGroup>	graphicGroup		(new TestCaseGroup(testCtx, "graphics", "Subgroup partitioned category tests: graphics"));
464 	de::MovePtr<TestCaseGroup>	computeGroup		(new TestCaseGroup(testCtx, "compute", "Subgroup partitioned category tests: compute"));
465 	de::MovePtr<TestCaseGroup>	meshGroup			(new TestCaseGroup(testCtx, "mesh", "Subgroup partitioned category tests: mesh shading"));
466 	de::MovePtr<TestCaseGroup>	framebufferGroup	(new TestCaseGroup(testCtx, "framebuffer", "Subgroup partitioned category tests: framebuffer"));
467 	de::MovePtr<TestCaseGroup>	raytracingGroup		(new TestCaseGroup(testCtx, "ray_tracing", "Subgroup partitioned category tests: ray tracing"));
468 	const VkShaderStageFlags	fbStages[]			=
469 	{
470 		VK_SHADER_STAGE_VERTEX_BIT,
471 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
472 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
473 		VK_SHADER_STAGE_GEOMETRY_BIT,
474 	};
475 	const VkShaderStageFlags	meshStages[]		=
476 	{
477 		VK_SHADER_STAGE_MESH_BIT_EXT,
478 		VK_SHADER_STAGE_TASK_BIT_EXT,
479 	};
480 	const deBool				boolValues[]		=
481 	{
482 		DE_FALSE,
483 		DE_TRUE
484 	};
485 
486 	{
487 		const vector<VkFormat>		formats		= subgroups::getAllFormats();
488 
489 		for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
490 		{
491 			const VkFormat	format		= formats[formatIndex];
492 			const string	formatName	= subgroups::getFormatNameForGLSL(format);
493 			const bool		isBool		= subgroups::isFormatBool(format);
494 			const bool		isFloat		= subgroups::isFormatFloat(format);
495 
496 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
497 			{
498 				const OpType	opType		= static_cast<OpType>(opTypeIndex);
499 				const Operator	op			= getOperator(opType);
500 				const ScanType	st			= getScanType(opType);
501 				const bool		isBitwiseOp	= (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);
502 
503 				// Skip float with bitwise category.
504 				if (isFloat && isBitwiseOp)
505 					continue;
506 
507 				// Skip bool when its not the bitwise category.
508 				if (isBool && !isBitwiseOp)
509 					continue;
510 
511 				const string	name = de::toLower(getOpTypeName(op, st)) + "_" + formatName;
512 
513 				for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
514 				{
515 					const deBool			requiredSubgroupSize	= boolValues[groupSizeNdx];
516 					const string			testName				= name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "");
517 					const CaseDefinition	caseDef					=
518 					{
519 						op,								//  Operator			op;
520 						st,								//  ScanType			scanType;
521 						VK_SHADER_STAGE_COMPUTE_BIT,	//  VkShaderStageFlags	shaderStage;
522 						format,							//  VkFormat			format;
523 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
524 						requiredSubgroupSize			//  deBool				requiredSubgroupSize;
525 					};
526 
527 					addFunctionCaseWithPrograms(computeGroup.get(), testName, "", supportedCheck, initPrograms, test, caseDef);
528 				}
529 
530 				for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
531 				{
532 					for (const auto& stage : meshStages)
533 					{
534 						const deBool			requiredSubgroupSize	= boolValues[groupSizeNdx];
535 						const string			testName				= name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "") + "_" + getShaderStageName(stage);
536 						const CaseDefinition	caseDef					=
537 						{
538 							op,								//  Operator			op;
539 							st,								//  ScanType			scanType;
540 							stage,							//  VkShaderStageFlags	shaderStage;
541 							format,							//  VkFormat			format;
542 							de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
543 							requiredSubgroupSize			//  deBool				requiredSubgroupSize;
544 						};
545 
546 						addFunctionCaseWithPrograms(meshGroup.get(), testName, "", supportedCheck, initPrograms, test, caseDef);
547 					}
548 				}
549 
550 				{
551 					const CaseDefinition	caseDef		=
552 					{
553 						op,								//  Operator			op;
554 						st,								//  ScanType			scanType;
555 						VK_SHADER_STAGE_ALL_GRAPHICS,	//  VkShaderStageFlags	shaderStage;
556 						format,							//  VkFormat			format;
557 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
558 						DE_FALSE						//  deBool				requiredSubgroupSize;
559 					};
560 
561 					addFunctionCaseWithPrograms(graphicGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
562 				}
563 
564 				for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(fbStages); ++stageIndex)
565 				{
566 					const CaseDefinition	caseDef		=
567 					{
568 						op,								//  Operator			op;
569 						st,								//  ScanType			scanType;
570 						fbStages[stageIndex],			//  VkShaderStageFlags	shaderStage;
571 						format,							//  VkFormat			format;
572 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
573 						DE_FALSE						//  deBool				requiredSubgroupSize;
574 					};
575 					const string			testName	= name + "_" + getShaderStageName(caseDef.shaderStage);
576 
577 					addFunctionCaseWithPrograms(framebufferGroup.get(), testName, "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
578 				}
579 			}
580 		}
581 	}
582 
583 	{
584 		const vector<VkFormat>	formats		= subgroups::getAllRayTracingFormats();
585 
586 		for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
587 		{
588 			const VkFormat	format		= formats[formatIndex];
589 			const string	formatName	= subgroups::getFormatNameForGLSL(format);
590 			const bool		isBool		= subgroups::isFormatBool(format);
591 			const bool		isFloat		= subgroups::isFormatFloat(format);
592 
593 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
594 			{
595 				const OpType	opType		= static_cast<OpType>(opTypeIndex);
596 				const Operator	op			= getOperator(opType);
597 				const ScanType	st			= getScanType(opType);
598 				const bool		isBitwiseOp	= (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);
599 
600 				// Skip float with bitwise category.
601 				if (isFloat && isBitwiseOp)
602 					continue;
603 
604 				// Skip bool when its not the bitwise category.
605 				if (isBool && !isBitwiseOp)
606 					continue;
607 
608 				{
609 					const CaseDefinition	caseDef		=
610 					{
611 						op,								//  Operator			op;
612 						st,								//  ScanType			scanType;
613 						SHADER_STAGE_ALL_RAY_TRACING,	//  VkShaderStageFlags	shaderStage;
614 						format,							//  VkFormat			format;
615 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
616 						DE_FALSE						//  deBool				requiredSubgroupSize;
617 					};
618 					const string			name		= de::toLower(getOpTypeName(op, st)) + "_" + formatName;
619 
620 					addFunctionCaseWithPrograms(raytracingGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
621 				}
622 			}
623 		}
624 	}
625 
626 	group->addChild(graphicGroup.release());
627 	group->addChild(computeGroup.release());
628 	group->addChild(framebufferGroup.release());
629 	group->addChild(raytracingGroup.release());
630 	group->addChild(meshGroup.release());
631 
632 	return group.release();
633 }
634 } // subgroups
635 } // vkt
636