• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2019 Google Inc.
7  * Copyright (c) 2017 Codeplay Software Ltd.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25 
26 #include "vktSubgroupsQuadTests.hpp"
27 #include "vktSubgroupsTestsUtils.hpp"
28 
29 #include <string>
30 #include <vector>
31 
32 using namespace tcu;
33 using namespace std;
34 using namespace vk;
35 using namespace vkt;
36 
37 namespace
38 {
// Quad operations exercised by this file.
//
// The numeric values are used as array indices and loop bounds elsewhere in
// the file, so the enumerators must stay dense and start at zero.
// OPTYPE_LAST is a count/sentinel, not an operation.
enum OpType
{
	OPTYPE_QUAD_BROADCAST			= 0,	// subgroupQuadBroadcast with a literal lane index
	OPTYPE_QUAD_BROADCAST_NONCONST	= 1,	// subgroupQuadBroadcast with a lane index computed in the shader
	OPTYPE_QUAD_SWAP_HORIZONTAL		= 2,	// subgroupQuadSwapHorizontal
	OPTYPE_QUAD_SWAP_VERTICAL		= 3,	// subgroupQuadSwapVertical
	OPTYPE_QUAD_SWAP_DIAGONAL		= 4,	// subgroupQuadSwapDiagonal
	OPTYPE_LAST						= 5		// number of operations above
};
48 
49 struct CaseDefinition
50 {
51 	OpType				opType;
52 	VkShaderStageFlags	shaderStage;
53 	VkFormat			format;
54 	de::SharedPtr<bool>	geometryPointSizeSupported;
55 	deBool				requiredSubgroupSize;
56 };
57 
checkVertexPipelineStages(const void * internalData,vector<const void * > datas,deUint32 width,deUint32)58 static bool checkVertexPipelineStages (const void*			internalData,
59 									   vector<const void*>	datas,
60 									   deUint32				width,
61 									   deUint32)
62 {
63 	DE_UNREF(internalData);
64 
65 	return subgroups::check(datas, width, 1);
66 }
67 
checkCompute(const void * internalData,vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)68 static bool checkCompute (const void*			internalData,
69 						  vector<const void*>	datas,
70 						  const deUint32		numWorkgroups[3],
71 						  const deUint32		localSize[3],
72 						  deUint32)
73 {
74 	DE_UNREF(internalData);
75 
76 	return subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
77 }
78 
getOpTypeName(OpType opType)79 string getOpTypeName (OpType opType)
80 {
81 	switch (opType)
82 	{
83 		case OPTYPE_QUAD_BROADCAST:				return "subgroupQuadBroadcast";
84 		case OPTYPE_QUAD_BROADCAST_NONCONST:	return "subgroupQuadBroadcast";
85 		case OPTYPE_QUAD_SWAP_HORIZONTAL:		return "subgroupQuadSwapHorizontal";
86 		case OPTYPE_QUAD_SWAP_VERTICAL:			return "subgroupQuadSwapVertical";
87 		case OPTYPE_QUAD_SWAP_DIAGONAL:			return "subgroupQuadSwapDiagonal";
88 		default:								TCU_THROW(InternalError, "Unsupported op type");
89 	}
90 }
91 
getOpTypeCaseName(OpType opType)92 string getOpTypeCaseName (OpType opType)
93 {
94 	switch (opType)
95 	{
96 		case OPTYPE_QUAD_BROADCAST:				return "subgroupquadbroadcast";
97 		case OPTYPE_QUAD_BROADCAST_NONCONST:	return "subgroupquadbroadcast_nonconst";
98 		case OPTYPE_QUAD_SWAP_HORIZONTAL:		return "subgroupquadswaphorizontal";
99 		case OPTYPE_QUAD_SWAP_VERTICAL:			return "subgroupquadswapvertical";
100 		case OPTYPE_QUAD_SWAP_DIAGONAL:			return "subgroupquadswapdiagonal";
101 		default:								TCU_THROW(InternalError, "Unsupported op type");
102 	}
103 }
104 
getExtHeader(VkFormat format)105 string getExtHeader (VkFormat format)
106 {
107 	return	"#extension GL_KHR_shader_subgroup_quad: enable\n"
108 			"#extension GL_KHR_shader_subgroup_ballot: enable\n" +
109 			subgroups::getAdditionalExtensionForFormat(format);
110 }
111 
getTestSrc(const CaseDefinition & caseDef)112 string getTestSrc (const CaseDefinition &caseDef)
113 {
114 	const string	swapTable[OPTYPE_LAST]	=
115 	{
116 		"",
117 		"",
118 		"  const uint swapTable[4] = {1, 0, 3, 2};\n",
119 		"  const uint swapTable[4] = {2, 3, 0, 1};\n",
120 		"  const uint swapTable[4] = {3, 2, 1, 0};\n",
121 	};
122 	const string	validate				=
123 		"  if (subgroupBallotBitExtract(mask, otherID) && op !=data[otherID])\n"
124 		"    tempRes = 0;\n";
125 	const string	fmt						= subgroups::getFormatNameForGLSL(caseDef.format);
126 	const string	op						= getOpTypeName(caseDef.opType);
127 	ostringstream	testSrc;
128 
129 	testSrc	<< "  uvec4 mask = subgroupBallot(true);\n"
130 			<< swapTable[caseDef.opType]
131 			<< "  tempRes = 1;\n";
132 
133 	if (caseDef.opType == OPTYPE_QUAD_BROADCAST)
134 	{
135 		for (int i=0; i<4; i++)
136 		{
137 			testSrc << "  {\n"
138 					<< "  " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID], " << i << ");\n"
139 					<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << i << ";\n"
140 					<< validate
141 					<< "  }\n";
142 		}
143 	}
144 	else if (caseDef.opType == OPTYPE_QUAD_BROADCAST_NONCONST)
145 	{
146 		testSrc << "  for (int i=0; i<4; i++)"
147 				<< "  {\n"
148 				<< "  " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID], i);\n"
149 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + i;\n"
150 				<< validate
151 				<< "  }\n"
152 				<< "  uint quadID = gl_SubgroupInvocationID >> 2;\n"
153 				<< "  uint quadInvocation = gl_SubgroupInvocationID & 0x3;\n"
154 				<< "  // Test lane ID that is only uniform in active lanes\n"
155 				<< "  if (quadInvocation >= 2)\n"
156 				<< "  {\n"
157 				<< "    uint id = quadInvocation & ~1;\n"
158 				<< "    " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID], id);\n"
159 				<< "    uint otherID = 4*quadID + id;\n"
160 				<< validate
161 				<< "  }\n"
162 				<< "  // Test lane ID that is only quad uniform, not subgroup uniform\n"
163 				<< "  {\n"
164 				<< "    uint id = quadID & 0x3;\n"
165 				<< "    " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID], id);\n"
166 				<< "    uint otherID = 4*quadID + id;\n"
167 				<< validate
168 				<< "  }\n";
169 	}
170 	else
171 	{
172 		testSrc << "  " << fmt << " op = " << op << "(data[gl_SubgroupInvocationID]);\n"
173 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n"
174 				<< validate;
175 	}
176 
177 	return testSrc.str();
178 }
179 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)180 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
181 {
182 	const SpirvVersion			spirvVersion	= (caseDef.opType == OPTYPE_QUAD_BROADCAST_NONCONST) ? SPIRV_VERSION_1_5 : SPIRV_VERSION_1_3;
183 	const ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, spirvVersion, 0u);
184 
185 	subgroups::initStdFrameBufferPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, *caseDef.geometryPointSizeSupported, getExtHeader(caseDef.format), getTestSrc(caseDef), "");
186 }
187 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)188 void initPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
189 {
190 	const bool					spirv15required	= caseDef.opType == OPTYPE_QUAD_BROADCAST_NONCONST;
191 	const bool					spirv14required	= isAllRayTracingStages(caseDef.shaderStage);
192 	const SpirvVersion			spirvVersion	= spirv15required ? SPIRV_VERSION_1_5
193 												: spirv14required ? SPIRV_VERSION_1_4
194 												: SPIRV_VERSION_1_3;
195 	const ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, spirvVersion, 0u);
196 	const string				extHeader		= getExtHeader(caseDef.format);
197 	const string				testSrc			= getTestSrc(caseDef);
198 
199 	subgroups::initStdPrograms(programCollection, buildOptions, caseDef.shaderStage, caseDef.format, *caseDef.geometryPointSizeSupported, extHeader, testSrc, "");
200 }
201 
supportedCheck(Context & context,CaseDefinition caseDef)202 void supportedCheck (Context& context, CaseDefinition caseDef)
203 {
204 	if (!subgroups::isSubgroupSupported(context))
205 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
206 
207 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_QUAD_BIT))
208 		TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
209 
210 	if (!subgroups::isFormatSupportedForDevice(context, caseDef.format))
211 		TCU_THROW(NotSupportedError, "Device does not support the specified format in subgroup operations");
212 
213 	if ((caseDef.opType == OPTYPE_QUAD_BROADCAST_NONCONST) && !subgroups::isSubgroupBroadcastDynamicIdSupported(context))
214 		TCU_THROW(NotSupportedError, "Device does not support SubgroupBroadcastDynamicId");
215 
216 	if (caseDef.requiredSubgroupSize)
217 	{
218 		context.requireDeviceFunctionality("VK_EXT_subgroup_size_control");
219 
220 		const VkPhysicalDeviceSubgroupSizeControlFeaturesEXT&	subgroupSizeControlFeatures		= context.getSubgroupSizeControlFeaturesEXT();
221 		const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT&	subgroupSizeControlProperties	= context.getSubgroupSizeControlPropertiesEXT();
222 
223 		if (subgroupSizeControlFeatures.subgroupSizeControl == DE_FALSE)
224 			TCU_THROW(NotSupportedError, "Device does not support varying subgroup sizes nor required subgroup size");
225 
226 		if (subgroupSizeControlFeatures.computeFullSubgroups == DE_FALSE)
227 			TCU_THROW(NotSupportedError, "Device does not support full subgroups in compute shaders");
228 
229 		if ((subgroupSizeControlProperties.requiredSubgroupSizeStages & caseDef.shaderStage) != caseDef.shaderStage)
230 			TCU_THROW(NotSupportedError, "Required subgroup size is not supported for shader stage");
231 	}
232 
233 	*caseDef.geometryPointSizeSupported = subgroups::isTessellationAndGeometryPointSizeSupported(context);
234 
235 	if (isAllRayTracingStages(caseDef.shaderStage))
236 	{
237 		context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
238 	}
239 
240 	subgroups::supportedCheckShader(context, caseDef.shaderStage);
241 }
242 
noSSBOtest(Context & context,const CaseDefinition caseDef)243 TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
244 {
245 	subgroups::SSBOData inputData;
246 	inputData.format = caseDef.format;
247 	inputData.layout = subgroups::SSBOData::LayoutStd140;
248 	inputData.numElements = subgroups::maxSupportedSubgroupSize();
249 	inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
250 
251 	switch (caseDef.shaderStage)
252 	{
253 		case VK_SHADER_STAGE_VERTEX_BIT:					return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
254 		case VK_SHADER_STAGE_GEOMETRY_BIT:					return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages);
255 		case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
256 		case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:	return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, caseDef.shaderStage);
257 		default:											TCU_THROW(InternalError, "Unhandled shader stage");
258 	}
259 }
260 
test(Context & context,const CaseDefinition caseDef)261 TestStatus test (Context& context, const CaseDefinition caseDef)
262 {
263 
264 	if (isAllComputeStages(caseDef.shaderStage))
265 	{
266 		const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT&	subgroupSizeControlProperties	= context.getSubgroupSizeControlPropertiesEXT();
267 		TestLog&												log								= context.getTestContext().getLog();
268 		const subgroups::SSBOData								inputData
269 		{
270 			subgroups::SSBOData::InitializeNonZero,	// InputDataInitializeType		initializeType;
271 			subgroups::SSBOData::LayoutStd430,		// InputDataLayoutType			layout;
272 			caseDef.format,							// vk::VkFormat				format;
273 			subgroups::maxSupportedSubgroupSize(),	// vk::VkDeviceSize			numElements;
274 		};
275 
276 		if (caseDef.requiredSubgroupSize == DE_FALSE)
277 			return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkCompute);
278 
279 		log << TestLog::Message << "Testing required subgroup size range [" <<  subgroupSizeControlProperties.minSubgroupSize << ", "
280 			<< subgroupSizeControlProperties.maxSubgroupSize << "]" << TestLog::EndMessage;
281 
282 		// According to the spec, requiredSubgroupSize must be a power-of-two integer.
283 		for (deUint32 size = subgroupSizeControlProperties.minSubgroupSize; size <= subgroupSizeControlProperties.maxSubgroupSize; size *= 2)
284 		{
285 			TestStatus	result	= subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkCompute,
286 															size, VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT);
287 			if (result.getCode() != QP_TEST_RESULT_PASS)
288 			{
289 				log << TestLog::Message << "subgroupSize " << size << " failed" << TestLog::EndMessage;
290 				return result;
291 			}
292 		}
293 
294 		return TestStatus::pass("OK");
295 	}
296 	else if (isAllGraphicsStages(caseDef.shaderStage))
297 	{
298 		const VkShaderStageFlags	stages		= subgroups::getPossibleGraphicsSubgroupStages(context, caseDef.shaderStage);
299 		subgroups::SSBOData			inputData;
300 
301 		inputData.format			= caseDef.format;
302 		inputData.layout			= subgroups::SSBOData::LayoutStd430;
303 		inputData.numElements		= subgroups::maxSupportedSubgroupSize();
304 		inputData.initializeType	= subgroups::SSBOData::InitializeNonZero;
305 		inputData.binding			= 4u;
306 		inputData.stages			= stages;
307 
308 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
309 	}
310 	else if (isAllRayTracingStages(caseDef.shaderStage))
311 	{
312 		const VkShaderStageFlags	stages		= subgroups::getPossibleRayTracingSubgroupStages(context, caseDef.shaderStage);
313 		const subgroups::SSBOData	inputData	=
314 		{
315 			subgroups::SSBOData::InitializeNonZero,	//  InputDataInitializeType		initializeType;
316 			subgroups::SSBOData::LayoutStd430,		//  InputDataLayoutType			layout;
317 			caseDef.format,							//  vk::VkFormat				format;
318 			subgroups::maxSupportedSubgroupSize(),	//  vk::VkDeviceSize			numElements;
319 			false,									//  bool						isImage;
320 			6u,										//  deUint32					binding;
321 			stages,									//  vk::VkShaderStageFlags		stages;
322 		};
323 
324 		return subgroups::allRayTracingStages(context, VK_FORMAT_R32_UINT, &inputData, 1, DE_NULL, checkVertexPipelineStages, stages);
325 	}
326 	else
327 		TCU_THROW(InternalError, "Unknown stage or invalid stage set");
328 }
329 }
330 
331 namespace vkt
332 {
333 namespace subgroups
334 {
createSubgroupsQuadTests(TestContext & testCtx)335 TestCaseGroup* createSubgroupsQuadTests (TestContext& testCtx)
336 {
337 	de::MovePtr<TestCaseGroup>	group				(new TestCaseGroup(testCtx, "quad", "Subgroup quad category tests"));
338 	de::MovePtr<TestCaseGroup>	graphicGroup		(new TestCaseGroup(testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
339 	de::MovePtr<TestCaseGroup>	computeGroup		(new TestCaseGroup(testCtx, "compute", "Subgroup arithmetic category tests: compute"));
340 	de::MovePtr<TestCaseGroup>	framebufferGroup	(new TestCaseGroup(testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
341 	de::MovePtr<TestCaseGroup>	raytracingGroup		(new TestCaseGroup(testCtx, "ray_tracing", "Subgroup arithmetic category tests: ray tracing"));
342 	const VkShaderStageFlags	stages[]			=
343 	{
344 		VK_SHADER_STAGE_VERTEX_BIT,
345 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
346 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
347 		VK_SHADER_STAGE_GEOMETRY_BIT,
348 	};
349 	const deBool				boolValues[]		=
350 	{
351 		DE_FALSE,
352 		DE_TRUE
353 	};
354 
355 	{
356 		const vector<VkFormat>	formats	= subgroups::getAllFormats();
357 
358 		for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
359 		{
360 			const VkFormat	format		= formats[formatIndex];
361 			const string	formatName	= subgroups::getFormatNameForGLSL(format);
362 
363 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
364 			{
365 				const OpType	opType	= static_cast<OpType>(opTypeIndex);
366 				const string	name	= getOpTypeCaseName(opType) + "_" + formatName;
367 
368 				for (size_t groupSizeNdx = 0; groupSizeNdx < DE_LENGTH_OF_ARRAY(boolValues); ++groupSizeNdx)
369 				{
370 					const deBool			requiredSubgroupSize	= boolValues[groupSizeNdx];
371 					const string			testNameSuffix			= requiredSubgroupSize ? "_requiredsubgroupsize" : "";
372 					const string			testName				= name + testNameSuffix;
373 					const CaseDefinition	caseDef					=
374 					{
375 						opType,							//  OpType				opType;
376 						VK_SHADER_STAGE_COMPUTE_BIT,	//  VkShaderStageFlags	shaderStage;
377 						format,							//  VkFormat			format;
378 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
379 						requiredSubgroupSize,			//  deBool				requiredSubgroupSize;
380 					};
381 
382 					addFunctionCaseWithPrograms(computeGroup.get(), testName, "", supportedCheck, initPrograms, test, caseDef);
383 				}
384 
385 				{
386 					const CaseDefinition	caseDef		=
387 					{
388 						opType,							//  OpType				opType;
389 						VK_SHADER_STAGE_ALL_GRAPHICS,	//  VkShaderStageFlags	shaderStage;
390 						format,							//  VkFormat			format;
391 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
392 						DE_FALSE						//  deBool				requiredSubgroupSize;
393 					};
394 
395 					addFunctionCaseWithPrograms(graphicGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
396 				}
397 
398 				for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
399 				{
400 					const CaseDefinition	caseDef		=
401 					{
402 						opType,							//  OpType				opType;
403 						stages[stageIndex],				//  VkShaderStageFlags	shaderStage;
404 						format,							//  VkFormat			format;
405 						de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
406 						DE_FALSE						//  deBool				requiredSubgroupSize;
407 					};
408 					const string			testName	= name + "_" + getShaderStageName(caseDef.shaderStage);
409 
410 					addFunctionCaseWithPrograms(framebufferGroup.get(), testName, "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
411 				}
412 			}
413 		}
414 	}
415 
416 	{
417 		const vector<VkFormat>	formats	= subgroups::getAllRayTracingFormats();
418 
419 		for (size_t formatIndex = 0; formatIndex < formats.size(); ++formatIndex)
420 		{
421 			const VkFormat	format		= formats[formatIndex];
422 			const string	formatName	= subgroups::getFormatNameForGLSL(format);
423 
424 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
425 			{
426 				const OpType			opType		= static_cast<OpType>(opTypeIndex);
427 				const string			testName	= getOpTypeCaseName(opType) + "_" + formatName;
428 				const CaseDefinition	caseDef		=
429 				{
430 					opType,							//  OpType				opType;
431 					SHADER_STAGE_ALL_RAY_TRACING,	//  VkShaderStageFlags	shaderStage;
432 					format,							//  VkFormat			format;
433 					de::SharedPtr<bool>(new bool),	//  de::SharedPtr<bool>	geometryPointSizeSupported;
434 					DE_FALSE						//  deBool				requiredSubgroupSize;
435 				};
436 
437 				addFunctionCaseWithPrograms(raytracingGroup.get(), testName, "", supportedCheck, initPrograms, test, caseDef);
438 			}
439 		}
440 	}
441 
442 	group->addChild(graphicGroup.release());
443 	group->addChild(computeGroup.release());
444 	group->addChild(framebufferGroup.release());
445 	group->addChild(raytracingGroup.release());
446 
447 	return group.release();
448 }
449 } // subgroups
450 } // vkt
451