• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2019 Google Inc.
7  * Copyright (c) 2017 Codeplay Software Ltd.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25 
26 #include "vktSubgroupsVoteTests.hpp"
27 #include "vktSubgroupsTestsUtils.hpp"
28 
29 #include <string>
30 #include <vector>
31 
32 using namespace tcu;
33 using namespace std;
34 using namespace vk;
35 using namespace vkt;
36 
37 namespace
38 {
39 enum OpType
40 {
41 	OPTYPE_ALL			= 0,
42 	OPTYPE_ANY			= 1,
43 	OPTYPE_ALLEQUAL		= 2,
44 	OPTYPE_LAST_NON_ARB	= 3,
45 	OPTYPE_ALL_ARB		= 4,
46 	OPTYPE_ANY_ARB		= 5,
47 	OPTYPE_ALLEQUAL_ARB	= 6,
48 	OPTYPE_LAST
49 };
50 
51 struct CaseDefinition
52 {
53 	OpType				opType;
54 	VkShaderStageFlags	shaderStage;
55 	VkFormat			format;
56 	de::SharedPtr<bool>	geometryPointSizeSupported;
57 	deBool				requiredSubgroupSize;
58 	deBool				requires8BitUniformBuffer;
59 	deBool				requires16BitUniformBuffer;
60 };
61 
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,deUint32 width,deUint32)62 static bool checkVertexPipelineStages (const void*			internalData,
63 									   vector<const void*>	datas,
64 									   deUint32				width,
65 									   deUint32)
66 {
67 	DE_UNREF(internalData);
68 
69 	return subgroups::check(datas, width, 0x1F);
70 }
71 
checkFragmentPipelineStages(const void * internalData,vector<const void * > datas,deUint32 width,deUint32 height,deUint32)72 static bool checkFragmentPipelineStages (const void*			internalData,
73 										 vector<const void*>	datas,
74 										 deUint32				width,
75 										 deUint32				height,
76 										 deUint32)
77 {
78 	DE_UNREF(internalData);
79 
80 	const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
81 
82 	for (deUint32 x = 0u; x < width; ++x)
83 	{
84 		for (deUint32 y = 0u; y < height; ++y)
85 		{
86 			const deUint32 ndx = (x * height + y);
87 			const deUint32 val = data[ndx] & 0x1F;
88 
89 			if (data[ndx] & 0x40) //Helper fragment shader invocation was executed
90 			{
91 				if(val != 0x1F)
92 					return false;
93 			}
94 			else //Helper fragment shader invocation was not executed yet
95 			{
96 				if (val != 0x1E)
97 					return false;
98 			}
99 		}
100 	}
101 
102 	return true;
103 }
104 
checkCompute(const void * internalData,vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)105 static bool checkCompute (const void*			internalData,
106 						  vector<const void*>	datas,
107 						  const deUint32		numWorkgroups[3],
108 						  const deUint32		localSize[3],
109 						  deUint32)
110 {
111 	DE_UNREF(internalData);
112 
113 	return subgroups::checkCompute(datas, numWorkgroups, localSize, 0x1F);
114 }
115 
getOpTypeName(int opType)116 string getOpTypeName (int opType)
117 {
118 	switch (opType)
119 	{
120 		case OPTYPE_ALL:			return "subgroupAll";
121 		case OPTYPE_ANY:			return "subgroupAny";
122 		case OPTYPE_ALLEQUAL:		return "subgroupAllEqual";
123 		case OPTYPE_ALL_ARB:		return "allInvocationsARB";
124 		case OPTYPE_ANY_ARB:		return "anyInvocationARB";
125 		case OPTYPE_ALLEQUAL_ARB:	return "allInvocationsEqualARB";
126 		default:					TCU_THROW(InternalError, "Unsupported op type");
127 	}
128 }
129 
fmtIsBoolean(VkFormat format)130 bool fmtIsBoolean (VkFormat format)
131 {
132 	// For reasons unknown, the tests use R8_USCALED as the boolean format
133 	return	format == VK_FORMAT_R8_USCALED || format == VK_FORMAT_R8G8_USCALED ||
134 			format == VK_FORMAT_R8G8B8_USCALED || format == VK_FORMAT_R8G8B8A8_USCALED;
135 }
136 
getExtensions(bool arbFunctions)137 const string getExtensions (bool arbFunctions)
138 {
139 	return arbFunctions	?	"#extension GL_ARB_shader_group_vote: enable\n"
140 							"#extension GL_KHR_shader_subgroup_basic: enable\n"
141 						:	"#extension GL_KHR_shader_subgroup_vote: enable\n";
142 }
143 
getStageTestSource(const CaseDefinition & caseDef)144 const string getStageTestSource (const CaseDefinition& caseDef)
145 {
146 	const bool		formatIsBoolean	= fmtIsBoolean(caseDef.format);
147 	const string	op				= getOpTypeName(caseDef.opType);
148 	const string	fmt				= subgroups::getFormatNameForGLSL(caseDef.format);
149 	const string	computePart		= isAllComputeStages(caseDef.shaderStage)
150 									? op + "(data[gl_SubgroupInvocationID] > 0) ? 0x4 : 0x0"
151 									: "0x4";
152 
153 	return
154 		(OPTYPE_ALL == caseDef.opType || OPTYPE_ALL_ARB == caseDef.opType) ?
155 			"  tempRes = " + op + "(true) ? 0x1 : 0;\n"
156 			"  tempRes |= " + op + "(false) ? 0 : 0x1A;\n"
157 			"  tempRes |= " + computePart + ";\n"
158 		: (OPTYPE_ANY == caseDef.opType || OPTYPE_ANY_ARB == caseDef.opType) ?
159 			"  tempRes = " + op + "(true) ? 0x1 : 0;\n"
160 			"  tempRes |= " + op + "(false) ? 0 : 0x1A;\n"
161 			"  tempRes |= " + computePart + ";\n"
162 		: (OPTYPE_ALLEQUAL == caseDef.opType || OPTYPE_ALLEQUAL_ARB == caseDef.opType) ?
163 			"  " + fmt + " valueEqual = " + fmt + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
164 			"  " + fmt + " valueNoEqual = " + fmt + (formatIsBoolean ? "(subgroupElect());\n" : "(gl_SubgroupInvocationID);\n") +
165 			"  tempRes = " + op + "(" + fmt + "(1)) ? 0x1 : 0;\n"
166 			"  tempRes |= "
167 				+ (formatIsBoolean ? "0x2" : op + "(" + fmt + "(gl_SubgroupInvocationID)) ? 0 : 0x2")
168 				+ ";\n"
169 			"  tempRes |= " + op + "(data[0]) ? 0x4 : 0;\n"
170 			"  tempRes |= " + op + "(valueEqual) ? 0x8 : 0x0;\n"
171 			"  tempRes |= " + op + "(valueNoEqual) ? 0x0 : 0x10;\n"
172 			"  if (subgroupElect()) tempRes |= 0x2 | 0x10;\n"
173 		: "";
174 }
175 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)176 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
177 {
178 	const SpirvVersion			spirvVersion	= isAllRayTracingStages(caseDef.shaderStage) ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
179 	const ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, spirvVersion, 0u);
180 	const bool					arbFunctions	= caseDef.opType > OPTYPE_LAST_NON_ARB;
181 	const string				extensions		= getExtensions(arbFunctions) + subgroups::getAdditionalExtensionForFormat(caseDef.format);
182 	const bool					pointSize		= *caseDef.geometryPointSizeSupported;
183 
184 	subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, pointSize, extensions, getStageTestSource(caseDef), "");
185 }
186 
getStageTestSourceFrag(const CaseDefinition & caseDef)187 const string getStageTestSourceFrag (const CaseDefinition& caseDef)
188 {
189 	const bool		formatIsBoolean	= fmtIsBoolean(caseDef.format);
190 	const string	op				= getOpTypeName(caseDef.opType);
191 	const string	fmt				= subgroups::getFormatNameForGLSL(caseDef.format);
192 
193 	return
194 		(OPTYPE_ALL == caseDef.opType || OPTYPE_ALL_ARB == caseDef.opType) ?
195 			"  tempRes |= " + op + "(!gl_HelperInvocation) ? 0x0 : 0x1;\n"
196 			"  tempRes |= " + op + "(false) ? 0 : 0x1A;\n"
197 			"  tempRes |= 0x4;\n"
198 		: (OPTYPE_ANY == caseDef.opType || OPTYPE_ANY_ARB == caseDef.opType) ?
199 			"  tempRes |= " + op + "(gl_HelperInvocation) ? 0x1 : 0x0;\n"
200 			"  tempRes |= " + op + "(false) ? 0 : 0x1A;\n"
201 			"  tempRes |= 0x4;\n"
202 		: (OPTYPE_ALLEQUAL == caseDef.opType || OPTYPE_ALLEQUAL_ARB == caseDef.opType) ?
203 			"  " + fmt + " valueEqual = " + fmt + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
204 			"  " + fmt + " valueNoEqual = " + fmt + (formatIsBoolean ? "(subgroupElect());\n" : "(gl_SubgroupInvocationID);\n") +
205 			"  tempRes |= " + getOpTypeName(caseDef.opType) + "("
206 			+ fmt + "(1)) ? 0x10 : 0;\n"
207 			"  tempRes |= "
208 				+ (formatIsBoolean ? "0x2" : op + "(" + fmt + "(gl_SubgroupInvocationID)) ? 0 : 0x2")
209 				+ ";\n"
210 			"  tempRes |= " + op + "(data[0]) ? 0x4 : 0;\n"
211 			"  tempRes |= " + op + "(valueEqual) ? 0x8 : 0x0;\n"
212 			"  tempRes |= " + op + "(gl_HelperInvocation) ? 0x0 : 0x1;\n"
213 			"  if (subgroupElect()) tempRes |= 0x2 | 0x10;\n"
214 		: "";
215 }
216 
initFrameBufferProgramsFrag(SourceCollections & programCollection,CaseDefinition caseDef)217 void initFrameBufferProgramsFrag (SourceCollections& programCollection, CaseDefinition caseDef)
218 {
219 	const SpirvVersion			spirvVersion	= isAllRayTracingStages(caseDef.shaderStage) ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
220 	const ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, spirvVersion, 0u);
221 	const bool					arbFunctions	= caseDef.opType > OPTYPE_LAST_NON_ARB;
222 	const string				extensions		= getExtensions(arbFunctions) + subgroups::getAdditionalExtensionForFormat(caseDef.format);
223 
224 	DE_ASSERT(VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage);
225 
226 	{
227 		const string	vertex	=
228 			"#version 450\n"
229 			"void main (void)\n"
230 			"{\n"
231 			"  vec2 uv = vec2(float(gl_VertexIndex & 1), float((gl_VertexIndex >> 1) & 1));\n"
232 			"  gl_Position = vec4(uv * 4.0f -2.0f, 0.0f, 1.0f);\n"
233 			"  gl_PointSize = 1.0f;\n"
234 			"}\n";
235 
236 		programCollection.glslSources.add("vert") << glu::VertexSource(vertex) << buildOptions;
237 	}
238 
239 	{
240 		ostringstream	fragmentSource;
241 
242 		fragmentSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
243 			<< extensions
244 			<< "layout(location = 0) out uint out_color;\n"
245 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
246 			<< "{\n"
247 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
248 			<< "};\n"
249 			<< ""
250 			<< "void main()\n"
251 			<< "{\n"
252 			<< "  uint tempRes = 0u;\n"
253 			<< "  if (dFdx(gl_SubgroupInvocationID * gl_FragCoord.x * gl_FragCoord.y) - dFdy(gl_SubgroupInvocationID * gl_FragCoord.x * gl_FragCoord.y) > 0.0f)\n"
254 			<< "  {\n"
255 			<< "    tempRes |= 0x20;\n" // to be sure that compiler doesn't remove dFdx and dFdy executions
256 			<< "  }\n"
257 			<< (arbFunctions ?
258 				"  bool helper = anyInvocationARB(gl_HelperInvocation);\n" :
259 				"  bool helper = subgroupAny(gl_HelperInvocation);\n")
260 			<< "  if (helper)\n"
261 			<< "  {\n"
262 			<< "    tempRes |= 0x40;\n"
263 			<< "  }\n"
264 			<< getStageTestSourceFrag(caseDef)
265 			<< "  out_color = tempRes;\n"
266 			<< "}\n";
267 
268 		programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSource.str())<< buildOptions;
269 	}
270 }
271 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)272 void initPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
273 {
274 	const SpirvVersion			spirvVersion	= isAllRayTracingStages(caseDef.shaderStage) ? SPIRV_VERSION_1_4 : SPIRV_VERSION_1_3;
275 	const ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, spirvVersion, 0u);
276 	const bool					arbFunctions	= caseDef.opType > OPTYPE_LAST_NON_ARB;
277 	const string				extensions		= getExtensions(arbFunctions) + subgroups::getAdditionalExtensionForFormat(caseDef.format);
278 	const bool					pointSize		= *caseDef.geometryPointSizeSupported;
279 
280 	subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, pointSize, extensions, getStageTestSource(caseDef), "");
281 }
282 
supportedCheck(Context & context,CaseDefinition caseDef)283 void supportedCheck (Context& context, CaseDefinition caseDef)
284 {
285 	if (!subgroups::isSubgroupSupported(context))
286 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
287 
288 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_VOTE_BIT))
289 	{
290 		TCU_THROW(NotSupportedError, "Device does not support subgroup vote operations");
291 	}
292 
293 	if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
294 		TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
295 
296 	if (caseDef.requires16BitUniformBuffer)
297 	{
298 		if (!subgroups::is16BitUBOStorageSupported(context))
299 		{
300 			TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
301 		}
302 	}
303 
304 	if (caseDef.requires8BitUniformBuffer)
305 	{
306 		if (!subgroups::is8BitUBOStorageSupported(context))
307 		{
308 			TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
309 		}
310 	}
311 
312 	if (caseDef.opType > OPTYPE_LAST_NON_ARB)
313 	{
314 		context.requireDeviceFunctionality("VK_EXT_shader_subgroup_vote");
315 	}
316 
317 	if (caseDef.requiredSubgroupSize)
318 	{
319 		context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
320 
321 		const VkPhysicalDeviceSubgroupSizeControlFeaturesEXT&	subgroupSizeControlFeatures		= context.getSubgroupSizeControlFeaturesEXT();
322 		const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT&	subgroupSizeControlProperties	= context.getSubgroupSizeControlPropertiesEXT();
323 
324 		if (subgroupSizeControlFeatures.subgroupSizeControl == DE_FALSE)
325 			TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
326 
327 		if (subgroupSizeControlFeatures.computeFullSubgroups == DE_FALSE)
328 			TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
329 
330 		if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
331 			TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
332 	}
333 
334 	*caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
335 
336 	if (isAllRayTracingStages(caseDef.shaderStage))
337 	{
338 		context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
339 	}
340 
341 	subgroups::supportedCheckShader(context, caseDef.shaderStage);
342 }
343 
noSSBOtest(Context & context,const CaseDefinition caseDef)344 TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
345 {
346 	if (caseDef.opType > OPTYPE_LAST_NON_ARB)
347 	{
348 		context.requireDeviceFunctionality("VK_EXT_shader_subgroup_vote");
349 	}
350 
351 	const subgroups::SSBOData::InputDataInitializeType	initializeType	= (OPTYPE_ALLEQUAL == caseDef.opType || OPTYPE_ALLEQUAL_ARB == caseDef.opType)
352 																		? subgroups::SSBOData::InitializeZero
353 																		: subgroups::SSBOData::InitializeNonZero;
354 	const subgroups::SSBOData							inputData
355 	{
356 		initializeType,							//  InputDataInitializeType		initializeType;
357 		subgroups::SSBOData::LayoutStd140,		//  InputDataLayoutType			layout;
358 		caseDef.format,							//  vk::VkFormat				format;
359 		subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
360 	};
361 
362 	switch (caseDef.shaderStage)
363 	{
364 		case VK_SHADER_STAGE_VERTEX_BIT:					return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
365 		case VK_SHADER_STAGE_GEOMETRY_BIT:					return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
366 		case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
367 		case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:	return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
368 		case VK_SHADER_STAGE_FRAGMENT_BIT:					return subgroups::makeFragmentFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkFragmentPipelineStages);
369 		default:											TCU_THROW(InternalError, "Unhandled shader stage");
370 	}
371 }
372 
test(Context & context,const CaseDefinition caseDef)373 TestStatus test (Context& context, const CaseDefinition caseDef)
374 {
375 	const subgroups::SSBOData::InputDataInitializeType	initializeType	= (OPTYPE_ALLEQUAL == caseDef.opType || OPTYPE_ALLEQUAL_ARB == caseDef.opType)
376 																		? subgroups::SSBOData::InitializeZero
377 																		: subgroups::SSBOData::InitializeNonZero;
378 
379 	if (isAllComputeStages(caseDef.shaderStage))
380 	{
381 		const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT&	subgroupSizeControlProperties	= context.getSubgroupSizeControlPropertiesEXT();
382 		TestLog&												log								= context.getTestContext().getLog();
383 		const subgroups::SSBOData								inputData
384 		{
385 			initializeType,							//  InputDataInitializeType		initializeType;
386 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
387 			caseDef.format,							//  vk::VkFormat				format;
388 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
389 		};
390 
391 		if (caseDef.requiredSubgroupSize == DE_FALSE)
392 			return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkCompute);
393 
394 		log << TestLog::Message << "Testing required subgroup size range [" <<  subgroupSizeControlProperties.minSubgroupSize << ", "
395 			<< subgroupSizeControlProperties.maxSubgroupSize << "]" << TestLog::EndMessage;
396 
397 		// According to the spec, requiredSubgroupSize must be a power-of-two integer.
398 		for (deUint32 size = subgroupSizeControlProperties.minSubgroupSize; size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
399 		{
400 			TestStatus result = subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkCompute, size, VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT);
401 
402 			if (result.getCode() != QP_TEST_RESULT_PASS)
403 			{
404 				log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
405 				return result;
406 			}
407 		}
408 
409 		return TestStatus::pass("OK");
410 	}
411 	else if (isAllGraphicsStages(caseDef.shaderStage))
412 	{
413 		const VkShaderStageFlags	stages		= subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
414 		const subgroups::SSBOData	inputData	=
415 		{
416 			initializeType,							//  InputDataInitializeType		initializeType;
417 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
418 			caseDef.format,							//  vk::VkFormat				format;
419 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
420 			false,									//  bool						isImage;
421 			4u,										//  deUint32					binding;
422 			stages,									//  vk::VkShaderStageFlags		stages;
423 		};
424 
425 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
426 	}
427 	else if (isAllRayTracingStages(caseDef.shaderStage))
428 	{
429 		const VkShaderStageFlags	stages		= subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
430 		const subgroups::SSBOData	inputData	=
431 		{
432 			initializeType,							//  InputDataInitializeType		initializeType;
433 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
434 			caseDef.format,							//  vk::VkFormat				format;
435 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
436 			false,									//  bool						isImage;
437 			6u,										//  deUint32					binding;
438 			stages,									//  vk::VkShaderStageFlags		stages;
439 		};
440 
441 		return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
442 	}
443 	else
444 		TCU_THROW(InternalError, "Unknown stage or invalid stage set");
445 }
446 }
447 
448 namespace vkt
449 {
450 namespace subgroups
451 {
createSubgroupsVoteTests(TestContext & testCtx)452 TestCaseGroup* createSubgroupsVoteTests (TestContext& testCtx)
453 {
454 	de::MovePtr<TestCaseGroup>	group				(new TestCaseGroup(testCtx, "vote", "Subgroup vote category tests"));
455 	de::MovePtr<TestCaseGroup>	graphicGroup		(new TestCaseGroup(testCtx, "graphics", "Subgroup vote category tests: graphics"));
456 	de::MovePtr<TestCaseGroup>	computeGroup		(new TestCaseGroup(testCtx, "compute", "Subgroup vote category tests: compute"));
457 	de::MovePtr<TestCaseGroup>	framebufferGroup	(new TestCaseGroup(testCtx, "framebuffer", "Subgroup vote category tests: framebuffer"));
458 	de::MovePtr<TestCaseGroup>	fragHelperGroup		(new TestCaseGroup(testCtx, "frag_helper", "Subgroup vote category tests: fragment helper invocation"));
459 	de::MovePtr<TestCaseGroup>	raytracingGroup		(new TestCaseGroup(testCtx, "ray_tracing", "Subgroup vote category tests: raytracing"));
460 
461 	de::MovePtr<TestCaseGroup>	groupARB			(new TestCaseGroup(testCtx, "ext_shader_subgroup_vote", "VK_EXT_shader_subgroup_vote category tests"));
462 	de::MovePtr<TestCaseGroup>	graphicGroupARB		(new TestCaseGroup(testCtx, "graphics", "Subgroup vote category tests: graphics"));
463 	de::MovePtr<TestCaseGroup>	computeGroupARB		(new TestCaseGroup(testCtx, "compute", "Subgroup vote category tests: compute"));
464 	de::MovePtr<TestCaseGroup>	framebufferGroupARB	(new TestCaseGroup(testCtx, "framebuffer", "Subgroup vote category tests: framebuffer"));
465 	de::MovePtr<TestCaseGroup>	fragHelperGroupARB	(new TestCaseGroup(testCtx, "frag_helper", "Subgroup vote category tests: fragment helper invocation"));
466 	const deBool				boolValues[]		=
467 	{
468 		DE_FALSE,
469 		DE_TRUE
470 	};
471 
472 	{
473 		const VkShaderStageFlags	stages[]	=
474 		{
475 			VK_SHADER_STAGE_VERTEX_BIT,
476 			VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
477 			VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
478 			VK_SHADER_STAGE_GEOMETRY_BIT,
479 		};
480 		const vector<VkFormat>		formats		= subgroups::getAllFormats();
481 
482 		for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
483 		{
484 			const VkFormat	format					= formats[formatIndex];
485 			const bool		needs8BitUBOStorage		= isFormat8bitTy(format);
486 			const bool		needs16BitUBOStorage	= isFormat16BitTy(format);
487 			const deBool	formatIsNotVector		=  format == VK_FORMAT_R8_USCALED
488 													|| format == VK_FORMAT_R32_UINT
489 													|| format == VK_FORMAT_R32_SINT
490 													|| format == VK_FORMAT_R32_SFLOAT
491 													|| format == VK_FORMAT_R64_SFLOAT;
492 
493 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
494 			{
495 				const OpType	opType	= static_cast<OpType>(opTypeIndex);
496 
497 				// Skip OPTYPE_LAST_NON_ARB because it is not a real op type.
498 				if (opType == OPTYPE_LAST_NON_ARB)
499 					continue;
500 
501 				// Skip the non-nonvector tests because VK_EXT_shader_subgroup_vote functions only supports boolean scalar arguments.
502 				if (opType > OPTYPE_LAST_NON_ARB && !formatIsNotVector)
503 					continue;
504 
505 				// Skip non-boolean formats when testing allInvocationsEqualARB(bool value), because it requires a boolean
506 				// argument that should have the same value for all invocations. For the rest of formats, it won't be a boolean argument,
507 				// so it may give wrong results when converting to bool.
508 				if (opType == OPTYPE_ALLEQUAL_ARB && format != VK_FORMAT_R8_USCALED)
509 					continue;
510 
511 				// Skip the typed tests for all but subgroupAllEqual() and allInvocationsEqualARB()
512 				if ((VK_FORMAT_R32_UINT != format) && (OPTYPE_ALLEQUAL != opType) && (OPTYPE_ALLEQUAL_ARB != opType))
513 				{
514 					continue;
515 				}
516 
517 				const string	op					= de::toLower(getOpTypeName(opType));
518 				const string	name				= op + "_" + subgroups::getFormatNameForGLSL(format);
519 				TestCaseGroup*	computeGroupPtr		= (opType < OPTYPE_LAST_NON_ARB) ? computeGroup.get() : computeGroupARB.get();
520 				TestCaseGroup*	graphicGroupPtr		= (opType < OPTYPE_LAST_NON_ARB) ? graphicGroup.get() : graphicGroupARB.get();
521 				TestCaseGroup*	framebufferGroupPtr	= (opType < OPTYPE_LAST_NON_ARB) ? framebufferGroup.get() : framebufferGroupARB.get();
522 				TestCaseGroup*	fragHelperGroupPtr	= (opType < OPTYPE_LAST_NON_ARB) ? fragHelperGroup.get() : fragHelperGroupARB.get();
523 
524 				for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
525 				{
526 					const deBool			requiredSubgroupSize	= boolValues[groupSizeNdx];
527 					const string			testName				= name + (requiredSubgroupSize ? "_requiredsubgroupsize" : "");
528 					const CaseDefinition	caseDef					=
529 					{
530 						opType,							//  OpType				opType;
531 						VK_SHADER_STAGE_COMPUTE_BIT,	//  VkShaderStageFlags	shaderStage;
532 						format,							//  VkFormat			format;
533 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
534 						requiredSubgroupSize,			//  deBool				requiredSubgroupSize;
535 						deBool(false),					//  deBool				requires8BitUniformBuffer;
536 						deBool(false)					//  deBool				requires16BitUniformBuffer;
537 					};
538 
539 					addFunctionCaseWithPrograms(computeGroupPtr, testName, "", supportedCheck, initPrograms, test, caseDef);
540 				}
541 
542 				{
543 					const CaseDefinition	caseDef		=
544 					{
545 						opType,							//  OpType				opType;
546 						VK_SHADER_STAGE_ALL_GRAPHICS,	//  VkShaderStageFlags	shaderStage;
547 						format,							//  VkFormat			format;
548 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
549 						DE_FALSE,						//  deBool				requiredSubgroupSize;
550 						deBool(false),					//  deBool				requires8BitUniformBuffer;
551 						deBool(false)					//  deBool				requires16BitUniformBuffer;
552 					};
553 
554 					addFunctionCaseWithPrograms(graphicGroupPtr, name, "", supportedCheck, initPrograms, test, caseDef);
555 				}
556 
557 				for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
558 				{
559 					const CaseDefinition	caseDef		=
560 					{
561 						opType,							//  OpType				opType;
562 						stages[stageIndex],				//  VkShaderStageFlags	shaderStage;
563 						format,							//  VkFormat			format;
564 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
565 						DE_FALSE,						//  deBool				requiredSubgroupSize;
566 						deBool(false),					//  deBool				requires8BitUniformBuffer;
567 						deBool(false)					//  deBool				requires16BitUniformBuffer;
568 					};
569 					const string			testName	= name + "_" + getShaderStageName(caseDef.shaderStage);
570 
571 					addFunctionCaseWithPrograms(framebufferGroupPtr, testName, "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
572 				}
573 
574 				{
575 					const CaseDefinition	caseDef		=
576 					{
577 						opType,							//  OpType				opType;
578 						VK_SHADER_STAGE_FRAGMENT_BIT,	//  VkShaderStageFlags	shaderStage;
579 						format,							//  VkFormat			format;
580 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
581 						DE_FALSE,						//  deBool				requiredSubgroupSize;
582 						deBool(needs8BitUBOStorage),	//  deBool				requires8BitUniformBuffer;
583 						deBool(needs16BitUBOStorage)	//  deBool				requires16BitUniformBuffer;
584 					};
585 					const string			testName	= name + "_" + getShaderStageName(caseDef.shaderStage);
586 
587 					addFunctionCaseWithPrograms(fragHelperGroupPtr, testName, "", supportedCheck, initFrameBufferProgramsFrag, noSSBOtest, caseDef);
588 				}
589 			}
590 		}
591 	}
592 
593 	{
594 		const vector<VkFormat>	formats		= subgroups::getAllRayTracingFormats();
595 
596 		for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
597 		{
598 			const VkFormat	format	= formats[formatIndex];
599 
600 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST_NON_ARB; ++opTypeIndex)
601 			{
602 				const OpType	opType	= static_cast<OpType>(opTypeIndex);
603 
604 				// Skip the typed tests for all but subgroupAllEqual()
605 				if ((VK_FORMAT_R32_UINT != format) && (OPTYPE_ALLEQUAL != opType))
606 				{
607 					continue;
608 				}
609 
610 				const string			op		= de::toLower(getOpTypeName(opType));
611 				const string			name	= op + "_" + subgroups::getFormatNameForGLSL(format);
612 				const CaseDefinition	caseDef	=
613 				{
614 					opType,							//  OpType				opType;
615 					SHADER_STAGE_ALL_RAY_TRACING,	//  VkShaderStageFlags	shaderStage;
616 					format,							//  VkFormat			format;
617 					de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
618 					DE_FALSE,						//  deBool				requiredSubgroupSize;
619 					DE_FALSE,						//  deBool				requires8BitUniformBuffer;
620 					DE_FALSE						//  deBool				requires16BitUniformBuffer;
621 				};
622 
623 				addFunctionCaseWithPrograms(raytracingGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
624 			}
625 		}
626 	}
627 
628 	groupARB->addChild(graphicGroupARB.release());
629 	groupARB->addChild(computeGroupARB.release());
630 	groupARB->addChild(framebufferGroupARB.release());
631 	groupARB->addChild(fragHelperGroupARB.release());
632 
633 	group->addChild(graphicGroup.release());
634 	group->addChild(computeGroup.release());
635 	group->addChild(framebufferGroup.release());
636 	group->addChild(fragHelperGroup.release());
637 	group->addChild(raytracingGroup.release());
638 
639 	group->addChild(groupARB.release());
640 
641 	return group.release();
642 }
643 
644 } // subgroups
645 } // vkt
646