• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2019 Google Inc.
7  * Copyright (c) 2017 Codeplay Software Ltd.
8  * Copyright (c) 2018 NVIDIA Corporation
9  *
10  * Licensed under the Apache License, Version 2.0 (the "License");
11  * you may not use this file except in compliance with the License.
12  * You may obtain a copy of the License at
13  *
14  *      http://www.apache.org/licenses/LICENSE-2.0
15  *
16  * Unless required by applicable law or agreed to in writing, software
17  * distributed under the License is distributed on an "AS IS" BASIS,
18  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19  * See the License for the specific language governing permissions and
20  * limitations under the License.
21  *
22  */ /*!
23  * \file
24  * \brief Subgroups Tests
25  */ /*--------------------------------------------------------------------*/
26 
27 #include "vktSubgroupsPartitionedTests.hpp"
28 #include "vktSubgroupsScanHelpers.hpp"
29 #include "vktSubgroupsTestsUtils.hpp"
30 
31 #include <string>
32 #include <vector>
33 
34 using namespace tcu;
35 using namespace std;
36 using namespace vk;
37 using namespace vkt;
38 
39 namespace
40 {
// Subgroup arithmetic operations exercised by these tests. The enumerators form three consecutive
// runs of seven -- reductions, inclusive scans, exclusive scans -- each listing ADD, MUL, MIN, MAX,
// AND, OR, XOR in that order. getOperator() and getScanType() split an OpType back into its
// arithmetic operator and its scan flavour.
enum OpType
{
	// Reductions over the whole (sub)set.
	OPTYPE_ADD = 0,			OPTYPE_MUL,				OPTYPE_MIN,				OPTYPE_MAX,
	OPTYPE_AND,				OPTYPE_OR,				OPTYPE_XOR,

	// Inclusive scans.
	OPTYPE_INCLUSIVE_ADD,	OPTYPE_INCLUSIVE_MUL,	OPTYPE_INCLUSIVE_MIN,	OPTYPE_INCLUSIVE_MAX,
	OPTYPE_INCLUSIVE_AND,	OPTYPE_INCLUSIVE_OR,	OPTYPE_INCLUSIVE_XOR,

	// Exclusive scans.
	OPTYPE_EXCLUSIVE_ADD,	OPTYPE_EXCLUSIVE_MUL,	OPTYPE_EXCLUSIVE_MIN,	OPTYPE_EXCLUSIVE_MAX,
	OPTYPE_EXCLUSIVE_AND,	OPTYPE_EXCLUSIVE_OR,	OPTYPE_EXCLUSIVE_XOR,

	// Number of op types; used as the loop bound when generating cases, not a valid op.
	OPTYPE_LAST
};
66 
67 struct CaseDefinition
68 {
69 	Operator			op;
70 	ScanType			scanType;
71 	VkShaderStageFlags	shaderStage;
72 	VkFormat			format;
73 	de::SharedPtr<bool>	geometryPointSizeSupported;
74 	deBool				requiredSubgroupSize;
75 	deBool				requires8BitUniformBuffer;
76 	deBool				requires16BitUniformBuffer;
77 };
78 
getOperator(OpType opType)79 static Operator getOperator (OpType opType)
80 {
81 	switch (opType)
82 	{
83 		case OPTYPE_ADD:
84 		case OPTYPE_INCLUSIVE_ADD:
85 		case OPTYPE_EXCLUSIVE_ADD:
86 			return OPERATOR_ADD;
87 		case OPTYPE_MUL:
88 		case OPTYPE_INCLUSIVE_MUL:
89 		case OPTYPE_EXCLUSIVE_MUL:
90 			return OPERATOR_MUL;
91 		case OPTYPE_MIN:
92 		case OPTYPE_INCLUSIVE_MIN:
93 		case OPTYPE_EXCLUSIVE_MIN:
94 			return OPERATOR_MIN;
95 		case OPTYPE_MAX:
96 		case OPTYPE_INCLUSIVE_MAX:
97 		case OPTYPE_EXCLUSIVE_MAX:
98 			return OPERATOR_MAX;
99 		case OPTYPE_AND:
100 		case OPTYPE_INCLUSIVE_AND:
101 		case OPTYPE_EXCLUSIVE_AND:
102 			return OPERATOR_AND;
103 		case OPTYPE_OR:
104 		case OPTYPE_INCLUSIVE_OR:
105 		case OPTYPE_EXCLUSIVE_OR:
106 			return OPERATOR_OR;
107 		case OPTYPE_XOR:
108 		case OPTYPE_INCLUSIVE_XOR:
109 		case OPTYPE_EXCLUSIVE_XOR:
110 			return OPERATOR_XOR;
111 		default:
112 			DE_FATAL("Unsupported op type");
113 			return OPERATOR_ADD;
114 	}
115 }
116 
getScanType(OpType opType)117 static ScanType getScanType (OpType opType)
118 {
119 	switch (opType)
120 	{
121 		case OPTYPE_ADD:
122 		case OPTYPE_MUL:
123 		case OPTYPE_MIN:
124 		case OPTYPE_MAX:
125 		case OPTYPE_AND:
126 		case OPTYPE_OR:
127 		case OPTYPE_XOR:
128 			return SCAN_REDUCE;
129 		case OPTYPE_INCLUSIVE_ADD:
130 		case OPTYPE_INCLUSIVE_MUL:
131 		case OPTYPE_INCLUSIVE_MIN:
132 		case OPTYPE_INCLUSIVE_MAX:
133 		case OPTYPE_INCLUSIVE_AND:
134 		case OPTYPE_INCLUSIVE_OR:
135 		case OPTYPE_INCLUSIVE_XOR:
136 			return SCAN_INCLUSIVE;
137 		case OPTYPE_EXCLUSIVE_ADD:
138 		case OPTYPE_EXCLUSIVE_MUL:
139 		case OPTYPE_EXCLUSIVE_MIN:
140 		case OPTYPE_EXCLUSIVE_MAX:
141 		case OPTYPE_EXCLUSIVE_AND:
142 		case OPTYPE_EXCLUSIVE_OR:
143 		case OPTYPE_EXCLUSIVE_XOR:
144 			return SCAN_EXCLUSIVE;
145 		default:
146 			DE_FATAL("Unsupported op type");
147 			return SCAN_REDUCE;
148 	}
149 }
150 
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,deUint32 width,deUint32)151 static bool checkVertexPipelineStages (const void*			internalData,
152 									   vector<const void*>	datas,
153 									   deUint32				width,
154 									   deUint32)
155 {
156 	DE_UNREF(internalData);
157 
158 	return subgroups::check(datas, width, 0xFFFFFF);
159 }
160 
checkComputeOrMesh(const void * internalData,vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)161 static bool checkComputeOrMesh (const void*			internalData,
162 								vector<const void*>	datas,
163 								const deUint32		numWorkgroups[3],
164 								const deUint32		localSize[3],
165 								deUint32)
166 {
167 	DE_UNREF(internalData);
168 
169 	return subgroups::checkComputeOrMesh(datas, numWorkgroups, localSize, 0xFFFFFF);
170 }
171 
getOpTypeName(Operator op,ScanType scanType)172 string getOpTypeName (Operator op, ScanType scanType)
173 {
174 	return getScanOpName("subgroup", "", op, scanType);
175 }
176 
getOpTypeNamePartitioned(Operator op,ScanType scanType)177 string getOpTypeNamePartitioned (Operator op, ScanType scanType)
178 {
179 	return getScanOpName("subgroupPartitioned", "NV", op, scanType);
180 }
181 
getExtHeader(const CaseDefinition & caseDef)182 string getExtHeader (const CaseDefinition& caseDef)
183 {
184 	return	"#extension GL_NV_shader_subgroup_partitioned: enable\n"
185 			"#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
186 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
187 			+ subgroups::getAdditionalExtensionForFormat(caseDef.format);
188 }
189 
getTestString(const CaseDefinition & caseDef)190 string getTestString (const CaseDefinition& caseDef)
191 {
192 	Operator op = caseDef.op;
193 	ScanType st = caseDef.scanType;
194 
195 	// NOTE: tempResult can't have anything in bits 31:24 to avoid int->float
196 	// conversion overflow in framebuffer tests.
197 	string fmt = subgroups::getFormatNameForGLSL(caseDef.format);
198 	string bdy =
199 		"  uvec4 mask = subgroupBallot(true);\n"
200 		"  uint tempResult = 0;\n"
201 		"  uint id = gl_SubgroupInvocationID;\n";
202 
203 	// Test the case where the partition has a single subset with all invocations in it.
204 	// This should generate the same result as the non-partitioned function.
205 	bdy +=
206 		"  uvec4 allBallot = mask;\n"
207 		"  " + fmt + " allResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], allBallot);\n"
208 		"  " + fmt + " refResult = " + getOpTypeName(op, st) + "(data[gl_SubgroupInvocationID]);\n"
209 		"  if (" + getCompare(op, caseDef.format, "allResult", "refResult") + ") {\n"
210 		"      tempResult |= 0x1;\n"
211 		"  }\n";
212 
213 	// The definition of a partition doesn't forbid bits corresponding to inactive
214 	// invocations being in the subset with active invocations. In other words, test that
215 	// bits corresponding to inactive invocations are ignored.
216 	bdy +=
217 		"  if (0 == (gl_SubgroupInvocationID % 2)) {\n"
218 		"    " + fmt + " allResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], allBallot);\n"
219 		"    " + fmt + " refResult = " + getOpTypeName(op, st) + "(data[gl_SubgroupInvocationID]);\n"
220 		"    if (" + getCompare(op, caseDef.format, "allResult", "refResult") + ") {\n"
221 		"        tempResult |= 0x2;\n"
222 		"    }\n"
223 		"  } else {\n"
224 		"    tempResult |= 0x2;\n"
225 		"  }\n";
226 
227 	// Test the case where the partition has each invocation in a unique subset. For
228 	// exclusive ops, the result is identity. For reduce/inclusive, it's the original value.
229 	string expectedSelfResult = "data[gl_SubgroupInvocationID]";
230 	if (st == SCAN_EXCLUSIVE)
231 		expectedSelfResult = getIdentity(op, caseDef.format);
232 
233 	bdy +=
234 		"  uvec4 selfBallot = subgroupPartitionNV(gl_SubgroupInvocationID);\n"
235 		"  " + fmt + " selfResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], selfBallot);\n"
236 		"  if (" + getCompare(op, caseDef.format, "selfResult", expectedSelfResult) + ") {\n"
237 		"      tempResult |= 0x4;\n"
238 		"  }\n";
239 
240 	// Test "random" partitions based on a hash of the invocation id.
241 	// This "hash" function produces interesting/randomish partitions.
242 	static const char *idhash = "((id%N)+(id%(N+1))-(id%2)+(id/2))%((N+1)/2)";
243 
244 	bdy +=
245 		"  for (uint N = 1; N < 16; ++N) {\n"
246 		"    " + fmt + " idhashFmt = " + fmt + "(" + idhash + ");\n"
247 		"    uvec4 partitionBallot = subgroupPartitionNV(idhashFmt) & mask;\n"
248 		"    " + fmt + " partitionedResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], partitionBallot);\n"
249 		"      for (uint i = 0; i < N; ++i) {\n"
250 		"        " + fmt + " iFmt = " + fmt + "(i);\n"
251 		"        if (" + getCompare(op, caseDef.format, "idhashFmt", "iFmt") + ") {\n"
252 		"          " + fmt + " subsetResult = " + getOpTypeName(op, st) + "(data[gl_SubgroupInvocationID]);\n"
253 		"          tempResult |= " + getCompare(op, caseDef.format, "partitionedResult", "subsetResult") + " ? (0x4 << N) : 0;\n"
254 		"        }\n"
255 		"      }\n"
256 		"  }\n"
257 		// tests in flow control:
258 		"  if (1 == (gl_SubgroupInvocationID % 2)) {\n"
259 		"    for (uint N = 1; N < 7; ++N) {\n"
260 		"      " + fmt + " idhashFmt = " + fmt + "(" + idhash + ");\n"
261 		"      uvec4 partitionBallot = subgroupPartitionNV(idhashFmt) & mask;\n"
262 		"      " + fmt + " partitionedResult = " + getOpTypeNamePartitioned(op, st) + "(data[gl_SubgroupInvocationID], partitionBallot);\n"
263 		"        for (uint i = 0; i < N; ++i) {\n"
264 		"          " + fmt + " iFmt = " + fmt + "(i);\n"
265 		"          if (" + getCompare(op, caseDef.format, "idhashFmt", "iFmt") + ") {\n"
266 		"            " + fmt + " subsetResult = " + getOpTypeName(op, st) + "(data[gl_SubgroupInvocationID]);\n"
267 		"            tempResult |= " + getCompare(op, caseDef.format, "partitionedResult", "subsetResult") + " ? (0x20000 << N) : 0;\n"
268 		"          }\n"
269 		"        }\n"
270 		"    }\n"
271 		"  } else {\n"
272 		"    tempResult |= 0xFC0000;\n"
273 		"  }\n"
274 		"  tempRes = tempResult;\n"
275 		;
276 
277 	return bdy;
278 }
279 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)280 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
281 {
282 	const ShaderBuildOptions	buildOptions		(programCollection.usedVulkanVersion, SPIRV_VERSION_1_3, 0u);
283 	const string				extHeader			= getExtHeader(caseDef);
284 	const string				testSrc				= getTestString(caseDef);
285 	const bool					pointSizeSupport	= *caseDef.geometryPointSizeSupported;
286 
287 	subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, pointSizeSupport, extHeader, testSrc, "");
288 }
289 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)290 void initPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
291 {
292 	const bool					spirv14required		= (isAllRayTracingStages(caseDef.shaderStage) || isAllMeshShadingStages(caseDef.shaderStage));
293 	const SpirvVersion			spirvVersion		= spirv14required ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
294 	const ShaderBuildOptions	buildOptions		(programCollection.usedVulkanVersion, spirvVersion, 0u, spirv14required);
295 	const string				extHeader			= getExtHeader(caseDef);
296 	const string				testSrc				= getTestString(caseDef);
297 	const bool					pointSizeSupport	= *caseDef.geometryPointSizeSupported;
298 
299 	subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, pointSizeSupport, extHeader, testSrc, "");
300 }
301 
supportedCheck(Context & context,CaseDefinition caseDef)302 void supportedCheck (Context& context, CaseDefinition caseDef)
303 {
304 	if (!subgroups::isSubgroupSupported(context))
305 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
306 
307 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_PARTITIONED_BIT_NV))
308 		TCU_THROW(NotSupportedError, "Device does not support subgroup partitioned operations");
309 
310 	if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
311 		TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
312 
313 	if (caseDef.requires16BitUniformBuffer)
314 	{
315 		if (!subgroups::is16BitUBOStorageSupported(context))
316 		{
317 			TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
318 		}
319 	}
320 
321 	if (caseDef.requires8BitUniformBuffer)
322 	{
323 		if (!subgroups::is8BitUBOStorageSupported(context))
324 		{
325 			TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
326 		}
327 	}
328 
329 	if (caseDef.requiredSubgroupSize)
330 	{
331 		context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
332 
333 		const VkPhysicalDeviceSubgroupSizeControlFeatures&		subgroupSizeControlFeatures		= context.getSubgroupSizeControlFeatures();
334 		const VkPhysicalDeviceSubgroupSizeControlProperties&	subgroupSizeControlProperties	= context.getSubgroupSizeControlProperties();
335 
336 		if (subgroupSizeControlFeatures.subgroupSizeControl == DE_FALSE)
337 			TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
338 
339 		if (subgroupSizeControlFeatures.computeFullSubgroups == DE_FALSE)
340 			TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
341 
342 		if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
343 			TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
344 	}
345 
346 	*caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
347 
348 	if (isAllRayTracingStages(caseDef.shaderStage))
349 	{
350 		context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
351 	}
352 	else if (isAllMeshShadingStages(caseDef.shaderStage))
353 	{
354 		context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
355 		context.requireDeviceFunctionality("VK_EXT_mesh_shader");
356 
357 		if ((caseDef.shaderStage & VK_SHADER_STAGE_TASK_BIT_EXT) != 0u)
358 		{
359 			const auto& features = context.getMeshShaderFeaturesEXT();
360 			if (!features.taskShader)
361 				TCU_THROW(NotSupportedError, "Task shaders not supported");
362 		}
363 	}
364 
365 	subgroups::supportedCheckShader(context, caseDef.shaderStage);
366 }
367 
noSSBOtest(Context & context,const CaseDefinition caseDef)368 TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
369 {
370 	const subgroups::SSBOData	inputData
371 	{
372 		subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
373 		subgroups::SSBOData::LayoutStd140,		//  InputDataLayoutType			layout;
374 		caseDef.format,							//  vk::VkFormat				format;
375 		subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
376 		subgroups::SSBOData::BindingUBO,		//  BindingType					bindingType;
377 	};
378 
379 	switch (caseDef.shaderStage)
380 	{
381 		case VK_SHADER_STAGE_VERTEX_BIT:					return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
382 		case VK_SHADER_STAGE_GEOMETRY_BIT:					return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
383 		case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
384 		case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:	return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
385 		default:											TCU_THROW(InternalError, "Unhandled shader stage");
386 	}
387 }
388 
test(Context & context,const CaseDefinition caseDef)389 TestStatus test (Context& context, const CaseDefinition caseDef)
390 {
391 	const bool isCompute	= isAllComputeStages(caseDef.shaderStage);
392 	const bool isMesh		= isAllMeshShadingStages(caseDef.shaderStage);
393 	DE_ASSERT(!(isCompute && isMesh));
394 
395 	if (isCompute || isMesh)
396 	{
397 		const VkPhysicalDeviceSubgroupSizeControlProperties&	subgroupSizeControlProperties	= context.getSubgroupSizeControlProperties();
398 		TestLog&												log								= context.getTestContext().getLog();
399 		const subgroups::SSBOData								inputData						=
400 		{
401 			subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
402 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
403 			caseDef.format,							//  vk::VkFormat				format;
404 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
405 		};
406 
407 		if (caseDef.requiredSubgroupSize == DE_FALSE)
408 		{
409 			if (isCompute)
410 				return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh);
411 			else
412 				return subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh);
413 		}
414 
415 		log << TestLog::Message << "Testing required subgroup size range [" <<  subgroupSizeControlProperties.minSubgroupSize << ", "
416 			<< subgroupSizeControlProperties.maxSubgroupSize << "]" << TestLog::EndMessage;
417 
418 		// According to the spec, requiredSubgroupSize must be a power-of-two integer.
419 		for (deUint32 size = subgroupSizeControlProperties.minSubgroupSize; size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
420 		{
421 			TestStatus result (QP_TEST_RESULT_INTERNAL_ERROR, "Internal Error");
422 
423 			if (isCompute)
424 				result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh, size);
425 			else
426 				result = subgroups::makeMeshTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkComputeOrMesh, size);
427 
428 			if (result.getCode() != QP_TEST_RESULT_PASS)
429 			{
430 				log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
431 				return result;
432 			}
433 		}
434 
435 		return TestStatus::pass("OK");
436 	}
437 	else if (isAllGraphicsStages(caseDef.shaderStage))
438 	{
439 		const VkShaderStageFlags	stages		= subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
440 		const subgroups::SSBOData	inputData
441 		{
442 			subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
443 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
444 			caseDef.format,							//  vk::VkFormat				format;
445 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
446 			subgroups::SSBOData::BindingSSBO,		//  bool						isImage;
447 			4u,										//  deUint32					binding;
448 			stages,									//  vk::VkShaderStageFlags		stages;
449 		};
450 
451 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
452 	}
453 	else if (isAllRayTracingStages(caseDef.shaderStage))
454 	{
455 		const VkShaderStageFlags	stages		= subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
456 		const subgroups::SSBOData	inputData
457 		{
458 			subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
459 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
460 			caseDef.format,							//  vk::VkFormat				format;
461 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
462 			subgroups::SSBOData::BindingSSBO,		//  bool						isImage;
463 			6u,										//  deUint32					binding;
464 			stages,									//  vk::VkShaderStageFlags		stages;
465 		};
466 
467 		return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
468 	}
469 	else
470 		TCU_THROW(InternalError, "Unknown stage or invalid stage set");
471 }
472 }
473 
474 namespace vkt
475 {
476 namespace subgroups
477 {
createSubgroupsPartitionedTests(TestContext & testCtx)478 TestCaseGroup* createSubgroupsPartitionedTests (TestContext& testCtx)
479 {
480 	de::MovePtr<TestCaseGroup>	group				(new TestCaseGroup(testCtx, "partitioned", "Subgroup partitioned category tests"));
481 	de::MovePtr<TestCaseGroup>	graphicGroup		(new TestCaseGroup(testCtx, "graphics", "Subgroup partitioned category tests: graphics"));
482 	de::MovePtr<TestCaseGroup>	computeGroup		(new TestCaseGroup(testCtx, "compute", "Subgroup partitioned category tests: compute"));
483 	de::MovePtr<TestCaseGroup>	meshGroup			(new TestCaseGroup(testCtx, "mesh", "Subgroup partitioned category tests: mesh shading"));
484 	de::MovePtr<TestCaseGroup>	framebufferGroup	(new TestCaseGroup(testCtx, "framebuffer", "Subgroup partitioned category tests: framebuffer"));
485 	de::MovePtr<TestCaseGroup>	raytracingGroup		(new TestCaseGroup(testCtx, "ray_tracing", "Subgroup partitioned category tests: ray tracing"));
486 	const VkShaderStageFlags	fbStages[]			=
487 	{
488 		VK_SHADER_STAGE_VERTEX_BIT,
489 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
490 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
491 		VK_SHADER_STAGE_GEOMETRY_BIT,
492 	};
493 	const VkShaderStageFlags	meshStages[]		=
494 	{
495 		VK_SHADER_STAGE_MESH_BIT_EXT,
496 		VK_SHADER_STAGE_TASK_BIT_EXT,
497 	};
498 	const deBool				boolValues[]		=
499 	{
500 		DE_FALSE,
501 		DE_TRUE
502 	};
503 
504 	{
505 		const vector<VkFormat>		formats		= subgroups::getAllFormats();
506 
507 		for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
508 		{
509 			const VkFormat	format					= formats[formatIndex];
510 			const string	formatName				= subgroups::getFormatNameForGLSL(format);
511 			const bool		isBool					= subgroups::isFormatBool(format);
512 			const bool		isFloat					= subgroups::isFormatFloat(format);
513 			const bool		needs8BitUBOStorage		= isFormat8bitTy(format);
514 			const bool		needs16BitUBOStorage	= isFormat16BitTy(format);
515 
516 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
517 			{
518 				const OpType	opType		= static_cast<OpType>(opTypeIndex);
519 				const Operator	op			= getOperator(opType);
520 				const ScanType	st			= getScanType(opType);
521 				const bool		isBitwiseOp	= (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);
522 
523 				// Skip float with bitwise category.
524 				if (isFloat && isBitwiseOp)
525 					continue;
526 
527 				// Skip bool when its not the bitwise category.
528 				if (isBool && !isBitwiseOp)
529 					continue;
530 
531 				const string	name = de::toLower(getOpTypeName(op, st)) + "_" + formatName;
532 
533 				for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
534 				{
535 					const deBool			requiredSubgroupSize	= boolValues[groupSizeNdx];
536 					const string			testName				= name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "");
537 					const CaseDefinition	caseDef					=
538 					{
539 						op,								//  Operator			op;
540 						st,								//  ScanType			scanType;
541 						VK_SHADER_STAGE_COMPUTE_BIT,	//  VkShaderStageFlags	shaderStage;
542 						format,							//  VkFormat			format;
543 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
544 						requiredSubgroupSize,			//  deBool				requiredSubgroupSize;
545 						DE_FALSE,						//  deBool				requires8BitUniformBuffer;
546 						DE_FALSE,						//  deBool				requires16BitUniformBuffer;
547 					};
548 
549 					addFunctionCaseWithPrograms(computeGroup.get(), testName,supportedCheck, initPrograms, test, caseDef);
550 				}
551 
552 				for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
553 				{
554 					for (const auto& stage : meshStages)
555 					{
556 						const deBool			requiredSubgroupSize	= boolValues[groupSizeNdx];
557 						const string			testName				= name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "") + "_" + getShaderStageName(stage);
558 						const CaseDefinition	caseDef					=
559 						{
560 							op,								//  Operator			op;
561 							st,								//  ScanType			scanType;
562 							stage,							//  VkShaderStageFlags	shaderStage;
563 							format,							//  VkFormat			format;
564 							de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
565 							requiredSubgroupSize,			//  deBool				requiredSubgroupSize;
566 							DE_FALSE,						//  deBool				requires8BitUniformBuffer;
567 							DE_FALSE,						//  deBool				requires16BitUniformBuffer;
568 						};
569 
570 						addFunctionCaseWithPrograms(meshGroup.get(), testName,supportedCheck, initPrograms, test, caseDef);
571 					}
572 				}
573 
574 				{
575 					const CaseDefinition	caseDef		=
576 					{
577 						op,								//  Operator			op;
578 						st,								//  ScanType			scanType;
579 						VK_SHADER_STAGE_ALL_GRAPHICS,	//  VkShaderStageFlags	shaderStage;
580 						format,							//  VkFormat			format;
581 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
582 						DE_FALSE,						//  deBool				requiredSubgroupSize;
583 						DE_FALSE,						//  deBool				requires8BitUniformBuffer;
584 						DE_FALSE						//  deBool				requires16BitUniformBuffer;
585 					};
586 
587 					addFunctionCaseWithPrograms(graphicGroup.get(), name, supportedCheck, initPrograms, test, caseDef);
588 				}
589 
590 				for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(fbStages); ++stageIndex)
591 				{
592 					const CaseDefinition	caseDef		=
593 					{
594 						op,								//  Operator			op;
595 						st,								//  ScanType			scanType;
596 						fbStages[stageIndex],			//  VkShaderStageFlags	shaderStage;
597 						format,							//  VkFormat			format;
598 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
599 						DE_FALSE,						//  deBool				requiredSubgroupSize;
600 						deBool(needs8BitUBOStorage),	//  deBool				requires8BitUniformBuffer;
601 						deBool(needs16BitUBOStorage)	//  deBool				requires16BitUniformBuffer;
602 					};
603 					const string			testName	= name + "_" + getShaderStageName(caseDef.shaderStage);
604 
605 					addFunctionCaseWithPrograms(framebufferGroup.get(), testName,supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
606 				}
607 			}
608 		}
609 	}
610 
611 	{
612 		const vector<VkFormat>	formats		= subgroups::getAllRayTracingFormats();
613 
614 		for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
615 		{
616 			const VkFormat	format		= formats[formatIndex];
617 			const string	formatName	= subgroups::getFormatNameForGLSL(format);
618 			const bool		isBool		= subgroups::isFormatBool(format);
619 			const bool		isFloat		= subgroups::isFormatFloat(format);
620 
621 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
622 			{
623 				const OpType	opType		= static_cast<OpType>(opTypeIndex);
624 				const Operator	op			= getOperator(opType);
625 				const ScanType	st			= getScanType(opType);
626 				const bool		isBitwiseOp	= (op == OPERATOR_AND || op == OPERATOR_OR || op == OPERATOR_XOR);
627 
628 				// Skip float with bitwise category.
629 				if (isFloat && isBitwiseOp)
630 					continue;
631 
632 				// Skip bool when its not the bitwise category.
633 				if (isBool && !isBitwiseOp)
634 					continue;
635 
636 				{
637 					const CaseDefinition	caseDef		=
638 					{
639 						op,								//  Operator			op;
640 						st,								//  ScanType			scanType;
641 						SHADER_STAGE_ALL_RAY_TRACING,	//  VkShaderStageFlags	shaderStage;
642 						format,							//  VkFormat			format;
643 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
644 						DE_FALSE,						//  deBool				requiredSubgroupSize;
645 						DE_FALSE,						//  deBool				requires8BitUniformBuffer;
646 						DE_FALSE						//  deBool				requires16BitUniformBuffer;
647 					};
648 					const string			name		= de::toLower(getOpTypeName(op, st)) + "_" + formatName;
649 
650 					addFunctionCaseWithPrograms(raytracingGroup.get(), name, supportedCheck, initPrograms, test, caseDef);
651 				}
652 			}
653 		}
654 	}
655 
656 	group->addChild(graphicGroup.release());
657 	group->addChild(computeGroup.release());
658 	group->addChild(framebufferGroup.release());
659 	group->addChild(raytracingGroup.release());
660 	group->addChild(meshGroup.release());
661 
662 	return group.release();
663 }
664 } // subgroups
665 } // vkt
666