• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24 
25 #include "vktSubgroupsVoteTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27 
28 #include <string>
29 #include <vector>
30 
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35 
36 namespace
37 {
// Subgroup vote built-ins exercised by this file (names come from getOpTypeName).
enum OpType
{
	OPTYPE_ALL = 0,		// subgroupAll()
	OPTYPE_ANY,			// subgroupAny()
	OPTYPE_ALLEQUAL,	// subgroupAllEqual()
	OPTYPE_LAST			// Number of op types; used as the loop bound when creating cases.
};
45 
checkVertexPipelineStages(std::vector<const void * > datas,deUint32 width,deUint32)46 static bool checkVertexPipelineStages(std::vector<const void*> datas,
47 									  deUint32 width, deUint32)
48 {
49 	const deUint32* data =
50 		reinterpret_cast<const deUint32*>(datas[0]);
51 	for (deUint32 x = 0; x < width; ++x)
52 	{
53 		deUint32 val = data[x];
54 
55 		if (0x7 != val)
56 		{
57 			return false;
58 		}
59 	}
60 
61 	return true;
62 }
63 
checkVertexPipelineStagesNoSSBO(std::vector<const void * > datas,deUint32 width,deUint32)64 static bool checkVertexPipelineStagesNoSSBO(std::vector<const void*> datas,
65 									  deUint32 width, deUint32)
66 {
67 	const float* data =
68 		reinterpret_cast<const float*>(datas[0]);
69 	for (deUint32 x = 0; x < width; ++x)
70 	{
71 		deUint32 val = static_cast<deUint32>(data[x]);
72 
73 		if (0x7 != val)
74 		{
75 			return false;
76 		}
77 	}
78 
79 	return true;
80 }
81 
checkFragment(std::vector<const void * > datas,deUint32 width,deUint32 height,deUint32)82 static bool checkFragment(std::vector<const void*> datas, deUint32 width,
83 						  deUint32 height, deUint32)
84 {
85 	const deUint32* data =
86 		reinterpret_cast<const deUint32*>(datas[0]);
87 	for (deUint32 x = 0; x < width; ++x)
88 	{
89 		for (deUint32 y = 0; y < height; ++y)
90 		{
91 			deUint32 val = data[x * height + y];
92 
93 			if (0x7 != val)
94 			{
95 				return false;
96 			}
97 		}
98 	}
99 
100 	return true;
101 }
102 
checkCompute(std::vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)103 static bool checkCompute(std::vector<const void*> datas,
104 						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
105 						 deUint32)
106 {
107 	const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
108 
109 	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
110 	{
111 		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
112 		{
113 			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
114 			{
115 				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
116 				{
117 					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
118 					{
119 						for (deUint32 lZ = 0; lZ < localSize[2];
120 								++lZ)
121 						{
122 							const deUint32 globalInvocationX =
123 								nX * localSize[0] + lX;
124 							const deUint32 globalInvocationY =
125 								nY * localSize[1] + lY;
126 							const deUint32 globalInvocationZ =
127 								nZ * localSize[2] + lZ;
128 
129 							const deUint32 globalSizeX =
130 								numWorkgroups[0] * localSize[0];
131 							const deUint32 globalSizeY =
132 								numWorkgroups[1] * localSize[1];
133 
134 							const deUint32 offset =
135 								globalSizeX *
136 								((globalSizeY *
137 								  globalInvocationZ) +
138 								 globalInvocationY) +
139 								globalInvocationX;
140 
141 							// The data should look (in binary) 0b111
142 							if (0x7 != data[offset])
143 							{
144 								return false;
145 							}
146 						}
147 					}
148 				}
149 			}
150 		}
151 	}
152 
153 	return true;
154 }
155 
checkComputeAllEqual(std::vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)156 static bool checkComputeAllEqual(std::vector<const void*> datas,
157 								 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
158 								 deUint32)
159 {
160 	const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
161 
162 	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
163 	{
164 		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
165 		{
166 			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
167 			{
168 				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
169 				{
170 					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
171 					{
172 						for (deUint32 lZ = 0; lZ < localSize[2];
173 								++lZ)
174 						{
175 							const deUint32 globalInvocationX =
176 								nX * localSize[0] + lX;
177 							const deUint32 globalInvocationY =
178 								nY * localSize[1] + lY;
179 							const deUint32 globalInvocationZ =
180 								nZ * localSize[2] + lZ;
181 
182 							const deUint32 globalSizeX =
183 								numWorkgroups[0] * localSize[0];
184 							const deUint32 globalSizeY =
185 								numWorkgroups[1] * localSize[1];
186 
187 							const deUint32 offset =
188 								globalSizeX *
189 								((globalSizeY *
190 								  globalInvocationZ) +
191 								 globalInvocationY) +
192 								globalInvocationX;
193 
194 							// The data should look (in binary) 0b111
195 							if (0x7 != data[offset])
196 							{
197 								return false;
198 							}
199 						}
200 					}
201 				}
202 			}
203 		}
204 	}
205 
206 	return true;
207 }
208 
getOpTypeName(int opType)209 std::string getOpTypeName(int opType)
210 {
211 	switch (opType)
212 	{
213 		default:
214 			DE_FATAL("Unsupported op type");
215 		case OPTYPE_ALL:
216 			return "subgroupAll";
217 		case OPTYPE_ANY:
218 			return "subgroupAny";
219 		case OPTYPE_ALLEQUAL:
220 			return "subgroupAllEqual";
221 	}
222 }
223 
224 struct CaseDefinition
225 {
226 	int					opType;
227 	VkShaderStageFlags	shaderStage;
228 	VkFormat			format;
229 	bool				noSSBO;
230 };
231 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)232 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
233 {
234 	std::ostringstream	vertexSrc;
235 	std::ostringstream	fragmentSrc;
236 
237 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
238 	{
239 		vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
240 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
241 			<< "layout(location = 0) out vec4 out_color;\n"
242 			<< "layout(location = 0) in highp vec4 in_position;\n"
243 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
244 			<< "{\n"
245 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
246 			<< "};\n"
247 			<< "\n"
248 			<< "void main (void)\n"
249 			<< "{\n"
250 			<< "  uint result;\n";
251 		if (OPTYPE_ALL == caseDef.opType)
252 		{
253 			vertexSrc << " result = " << getOpTypeName(caseDef.opType)
254 				<< "(true) ? 0x1 : 0;\n"
255 				<< "  result |= " << getOpTypeName(caseDef.opType)
256 				<< "(false) ? 0 : 0x2;\n"
257 				<< "  result |= 0x4;\n"
258 				<< "  out_color.r = float(result);\n";
259 		}
260 		else if (OPTYPE_ANY == caseDef.opType)
261 		{
262 			vertexSrc << "  result = " << getOpTypeName(caseDef.opType)
263 				<< "(true) ? 0x1 : 0;\n"
264 				<< "  result |= " << getOpTypeName(caseDef.opType)
265 				<< "(false) ? 0 : 0x2;\n"
266 				<< "  result |= 0x4;\n"
267 				<< "out_color.r = float(result);\n";
268 		}
269 		else if (OPTYPE_ALLEQUAL == caseDef.opType)
270 		{
271 			vertexSrc << "  result = " << getOpTypeName(caseDef.opType) << "("
272 				<< subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
273 				<< "  result |= " << getOpTypeName(caseDef.opType)
274 				<< "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
275 				<< "  if (subgroupElect()) result |= 0x2;\n"
276 				<< "  result |= " << getOpTypeName(caseDef.opType)
277 				<< "(data[0]) ? 0x4 : 0;\n"
278 				<< "  out_color.x = float(result);\n";
279 		}
280 
281 		vertexSrc << "  gl_Position = in_position;\n"
282 			<< "  gl_PointSize = 1.0f;\n"
283 			<< "}\n";
284 
285 		programCollection.glslSources.add("vert") << glu::VertexSource(vertexSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
286 
287 		fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
288 			<< "layout(location = 0) in vec4 in_color;\n"
289 			<< "layout(location = 0) out vec4 out_color;\n"
290 			<< "void main()\n"
291 			<<"{\n"
292 			<< "	out_color = in_color;\n"
293 			<< "}\n";
294 		programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
295 	}
296 	else
297 	{
298 		DE_FATAL("Unsupported shader stage");
299 	}
300 }
301 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)302 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
303 {
304 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
305 	{
306 		std::ostringstream src;
307 
308 		src << "#version 450\n"
309 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
310 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
311 			"local_size_z_id = 2) in;\n"
312 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
313 			<< "{\n"
314 			<< "  uint result[];\n"
315 			<< "};\n"
316 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
317 			<< "{\n"
318 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
319 			<< "};\n"
320 			<< "\n"
321 			<< "void main (void)\n"
322 			<< "{\n"
323 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
324 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
325 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
326 			"gl_GlobalInvocationID.x;\n";
327 		if (OPTYPE_ALL == caseDef.opType)
328 		{
329 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
330 				<< "(true) ? 0x1 : 0;\n"
331 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
332 				<< "(false) ? 0 : 0x2;\n"
333 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
334 				<< "(data[gl_SubgroupInvocationID] > 0) ? 0x4 : 0;\n";
335 		}
336 		else if (OPTYPE_ANY == caseDef.opType)
337 		{
338 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
339 				<< "(true) ? 0x1 : 0;\n"
340 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
341 				<< "(false) ? 0 : 0x2;\n"
342 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
343 				<< "(data[gl_SubgroupInvocationID] == data[0]) ? 0x4 : 0;\n";
344 		}
345 		else if (OPTYPE_ALLEQUAL == caseDef.opType)
346 		{
347 			src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
348 				<< subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
349 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
350 				<< "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
351 				<< "  if (subgroupElect()) result[offset] |= 0x2;\n"
352 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
353 				<< "(data[0]) ? 0x4 : 0;\n";
354 		}
355 
356 		src << "}\n";
357 
358 		programCollection.glslSources.add("comp")
359 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
360 	}
361 	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
362 	{
363 		programCollection.glslSources.add("vert")
364 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
365 
366 		std::ostringstream frag;
367 
368 		frag << "#version 450\n"
369 			 << "#extension GL_KHR_shader_subgroup_vote: enable\n"
370 			 << "layout(location = 0) out uint result;\n"
371 			 << "layout(set = 0, binding = 0, std430) readonly buffer Buffer2\n"
372 			 << "{\n"
373 			 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
374 			 << "};\n"
375 			 << "void main (void)\n"
376 			 << "{\n";
377 		if (OPTYPE_ALL == caseDef.opType)
378 		{
379 			frag << "  result = " << getOpTypeName(caseDef.opType)
380 				 << "(true) ? 0x1 : 0;\n"
381 				 << "  result |= " << getOpTypeName(caseDef.opType)
382 				 << "(false) ? 0 : 0x2;\n"
383 				 << "  result |= " << getOpTypeName(caseDef.opType)
384 				 << "(data[gl_SubgroupInvocationID] > 0) ? 0x4 : 0;\n";
385 		}
386 		else if (OPTYPE_ANY == caseDef.opType)
387 		{
388 			frag << "  result = " << getOpTypeName(caseDef.opType)
389 				 << "(true) ? 0x1 : 0;\n"
390 				 << "  result |= " << getOpTypeName(caseDef.opType)
391 				 << "(false) ? 0 : 0x2;\n"
392 				 << "  result |= " << getOpTypeName(caseDef.opType)
393 				 << "(data[gl_SubgroupInvocationID] == data[0]) ? 0x4 : 0;\n";
394 		}
395 		else if (OPTYPE_ALLEQUAL == caseDef.opType)
396 		{
397 			frag << "  result = " << getOpTypeName(caseDef.opType) << "("
398 				 << subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
399 				 << "  result |= " << getOpTypeName(caseDef.opType)
400 				 << "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
401 				 << "  if (subgroupElect()) result |= 0x2;\n"
402 				 << "  result |= " << getOpTypeName(caseDef.opType)
403 				 << "(data[0]) ? 0x4 : 0;\n";
404 		}
405 		frag << "}\n";
406 
407 		programCollection.glslSources.add("frag")
408 				<< glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
409 	}
410 	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
411 	{
412 		std::ostringstream src;
413 
414 		src << "#version 450\n"
415 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
416 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
417 			<< "{\n"
418 			<< "  uint result[];\n"
419 			<< "};\n"
420 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
421 			<< "{\n"
422 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
423 			<< "};\n"
424 			<< "\n"
425 			<< "void main (void)\n"
426 			<< "{\n"
427 			<< "  highp uint offset = gl_VertexIndex;\n";
428 		if (OPTYPE_ALL == caseDef.opType)
429 		{
430 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
431 				<< "(true) ? 0x1 : 0;\n"
432 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
433 				<< "(false) ? 0 : 0x2;\n"
434 				<< "  result[offset] |= 0x4;\n";
435 		}
436 		else if (OPTYPE_ANY == caseDef.opType)
437 		{
438 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
439 				<< "(true) ? 0x1 : 0;\n"
440 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
441 				<< "(false) ? 0 : 0x2;\n"
442 				<< "  result[offset] |= 0x4;\n";
443 		}
444 		else if (OPTYPE_ALLEQUAL == caseDef.opType)
445 		{
446 			src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
447 				<< subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
448 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
449 				<< "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
450 				<< "  if (subgroupElect()) result[offset] |= 0x2;\n"
451 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
452 				<< "(data[0]) ? 0x4 : 0;\n";
453 		}
454 
455 		src << "  gl_PointSize = 1.0f;\n";
456 		src << "}\n";
457 
458 		programCollection.glslSources.add("vert")
459 				<< glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
460 	}
461 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
462 	{
463 		programCollection.glslSources.add("vert")
464 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
465 
466 		std::ostringstream src;
467 
468 		src << "#version 450\n"
469 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
470 			<< "layout(points) in;\n"
471 			<< "layout(points, max_vertices = 1) out;\n"
472 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
473 			<< "{\n"
474 			<< "  uint result[];\n"
475 			<< "};\n"
476 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
477 			<< "{\n"
478 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
479 			<< "};\n"
480 			<< "\n"
481 			<< "void main (void)\n"
482 			<< "{\n"
483 			<< "  highp uint offset = gl_PrimitiveIDIn;\n";
484 		if (OPTYPE_ALL == caseDef.opType)
485 		{
486 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
487 				<< "(true) ? 0x1 : 0;\n"
488 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
489 				<< "(false) ? 0 : 0x2;\n"
490 				<< "  result[offset] |= 0x4;\n";
491 		}
492 		else if (OPTYPE_ANY == caseDef.opType)
493 		{
494 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
495 				<< "(true) ? 0x1 : 0;\n"
496 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
497 				<< "(false) ? 0 : 0x2;\n"
498 				<< "  result[offset] |= 0x4;\n";
499 		}
500 		else if (OPTYPE_ALLEQUAL == caseDef.opType)
501 		{
502 			src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
503 				<< subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
504 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
505 				<< "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
506 				<< "  if (subgroupElect()) result[offset] |= 0x2;\n"
507 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
508 				<< "(data[0]) ? 0x4 : 0;\n";
509 		}
510 
511 		src << "}\n";
512 
513 		programCollection.glslSources.add("geom")
514 				<< glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
515 	}
516 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
517 	{
518 		programCollection.glslSources.add("vert")
519 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
520 
521 		programCollection.glslSources.add("tese")
522 				<< glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
523 
524 		std::ostringstream src;
525 
526 		src << "#version 450\n"
527 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
528 			<< "layout(vertices=1) out;\n"
529 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
530 			<< "{\n"
531 			<< "  uint result[];\n"
532 			<< "};\n"
533 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
534 			<< "{\n"
535 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
536 			<< "};\n"
537 			<< "\n"
538 			<< "void main (void)\n"
539 			<< "{\n"
540 			<< "  highp uint offset = gl_PrimitiveID;\n";
541 		if (OPTYPE_ALL == caseDef.opType)
542 		{
543 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
544 				<< "(true) ? 0x1 : 0;\n"
545 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
546 				<< "(false) ? 0 : 0x2;\n"
547 				<< "  result[offset] |= 0x4;\n";
548 		}
549 		else if (OPTYPE_ANY == caseDef.opType)
550 		{
551 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
552 				<< "(true) ? 0x1 : 0;\n"
553 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
554 				<< "(false) ? 0 : 0x2;\n"
555 				<< "  result[offset] |= 0x4;\n";
556 		}
557 		else if (OPTYPE_ALLEQUAL == caseDef.opType)
558 		{
559 			src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
560 				<< subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
561 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
562 				<< "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
563 				<< "  if (subgroupElect()) result[offset] |= 0x2;\n"
564 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
565 				<< "(data[0]) ? 0x4 : 0;\n";
566 		}
567 
568 		src << "}\n";
569 
570 		programCollection.glslSources.add("tesc")
571 				<< glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
572 	}
573 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
574 	{
575 		programCollection.glslSources.add("vert")
576 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
577 
578 		programCollection.glslSources.add("tesc")
579 				<< glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
580 
581 		std::ostringstream src;
582 
583 		src << "#version 450\n"
584 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
585 			<< "layout(isolines) in;\n"
586 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
587 			<< "{\n"
588 			<< "  uint result[];\n"
589 			<< "};\n"
590 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
591 			<< "{\n"
592 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
593 			<< "};\n"
594 			<< "\n"
595 			<< "void main (void)\n"
596 			<< "{\n"
597 			<< "  highp uint offset = gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5);\n";
598 		if (OPTYPE_ALL == caseDef.opType)
599 		{
600 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
601 				<< "(true) ? 0x1 : 0;\n"
602 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
603 				<< "(false) ? 0 : 0x2;\n"
604 				<< "  result[offset] |= 0x4;\n";
605 		}
606 		else if (OPTYPE_ANY == caseDef.opType)
607 		{
608 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
609 				<< "(true) ? 0x1 : 0;\n"
610 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
611 				<< "(false) ? 0 : 0x2;\n"
612 				<< "  result[offset] |= 0x4;\n";
613 		}
614 		else if (OPTYPE_ALLEQUAL == caseDef.opType)
615 		{
616 			src << "  result[offset] = " << getOpTypeName(caseDef.opType) << "("
617 				<< subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0;\n"
618 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
619 				<< "(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
620 				<< "  if (subgroupElect()) result[offset] |= 0x2;\n"
621 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
622 				<< "(data[0]) ? 0x4 : 0;\n";
623 		}
624 
625 		src << "}\n";
626 
627 		programCollection.glslSources.add("tese")
628 				<< glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
629 	}
630 	else
631 	{
632 		DE_FATAL("Unsupported shader stage");
633 	}
634 }
635 
test(Context & context,const CaseDefinition caseDef)636 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
637 {
638 	if (!subgroups::isSubgroupSupported(context))
639 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
640 
641 	if (!subgroups::areSubgroupOperationsSupportedForStage(
642 				context, caseDef.shaderStage))
643 	{
644 		if (subgroups::areSubgroupOperationsRequiredForStage(
645 					caseDef.shaderStage))
646 		{
647 			return tcu::TestStatus::fail(
648 					   "Shader stage " +
649 					   subgroups::getShaderStageName(caseDef.shaderStage) +
650 					   " is required to support subgroup operations!");
651 		}
652 		else
653 		{
654 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
655 		}
656 	}
657 
658 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_VOTE_BIT))
659 	{
660 		TCU_THROW(NotSupportedError, "Device does not support subgroup vote operations");
661 	}
662 
663 	if (subgroups::isDoubleFormat(caseDef.format) &&
664 			!subgroups::isDoubleSupportedForDevice(context))
665 	{
666 		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
667 	}
668 
669 	//Tests which don't use the SSBO
670 	if (caseDef.noSSBO)
671 	{
672 		if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
673 		{
674 			subgroups::SSBOData inputData;
675 			inputData.format = caseDef.format;
676 			inputData.numElements = subgroups::maxSupportedSubgroupSize();
677 			inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
678 
679 			return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_SFLOAT, &inputData,
680 											 1, checkVertexPipelineStagesNoSSBO);
681 		}
682 	}
683 
684 	if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
685 			(VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
686 	{
687 		if (!subgroups::isVertexSSBOSupportedForDevice(context))
688 		{
689 			TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
690 		}
691 	}
692 
693 	if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
694 	{
695 		subgroups::SSBOData inputData;
696 		inputData.format = caseDef.format;
697 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
698 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
699 
700 		return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
701 										   &inputData, 1, checkFragment);
702 	}
703 	else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
704 	{
705 		subgroups::SSBOData inputData;
706 		inputData.format = caseDef.format;
707 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
708 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
709 
710 		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData,
711 										  1, (OPTYPE_ALLEQUAL == caseDef.opType) ? checkComputeAllEqual
712 										  : checkCompute);
713 	}
714 	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
715 	{
716 		subgroups::SSBOData inputData;
717 		inputData.format = caseDef.format;
718 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
719 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
720 
721 		return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT, &inputData,
722 										 1, checkVertexPipelineStages);
723 	}
724 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
725 	{
726 		subgroups::SSBOData inputData;
727 		inputData.format = caseDef.format;
728 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
729 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
730 
731 		return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT, &inputData,
732 										   1, checkVertexPipelineStages);
733 	}
734 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
735 	{
736 		subgroups::SSBOData inputData;
737 		inputData.format = caseDef.format;
738 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
739 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
740 
741 		return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT, &inputData,
742 				1, checkVertexPipelineStages);
743 	}
744 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
745 	{
746 		subgroups::SSBOData inputData;
747 		inputData.format = caseDef.format;
748 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
749 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
750 
751 		return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT, &inputData,
752 				1, checkVertexPipelineStages);
753 	}
754 	else
755 	{
756 		TCU_THROW(InternalError, "Unhandled shader stage");
757 	}
758 }
759 }
760 
761 namespace vkt
762 {
763 namespace subgroups
764 {
createSubgroupsVoteTests(tcu::TestContext & testCtx)765 tcu::TestCaseGroup* createSubgroupsVoteTests(tcu::TestContext& testCtx)
766 {
767 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
768 			testCtx, "vote", "Subgroup vote category tests"));
769 
770 	const VkShaderStageFlags stages[] =
771 	{
772 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
773 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
774 		VK_SHADER_STAGE_GEOMETRY_BIT,
775 		VK_SHADER_STAGE_VERTEX_BIT,
776 		VK_SHADER_STAGE_FRAGMENT_BIT,
777 		VK_SHADER_STAGE_COMPUTE_BIT
778 	};
779 
780 	const VkFormat formats[] =
781 	{
782 		VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
783 		VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
784 		VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
785 		VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
786 		VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
787 		VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
788 		VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
789 		VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
790 		VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
791 	};
792 
793 	for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
794 	{
795 		const VkShaderStageFlags stage = stages[stageIndex];
796 
797 		for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
798 		{
799 			const VkFormat format = formats[formatIndex];
800 
801 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
802 			{
803 				// Skip the typed tests for all but subgroupAllEqual()
804 				if ((VK_FORMAT_R32_UINT != format) && (OPTYPE_ALLEQUAL != opTypeIndex))
805 				{
806 					continue;
807 				}
808 
809 				CaseDefinition caseDef = {opTypeIndex, stage, format, false};
810 
811 				std::string op = getOpTypeName(opTypeIndex);
812 
813 				addFunctionCaseWithPrograms(group.get(),
814 											de::toLower(op) + "_" +
815 											subgroups::getFormatNameForGLSL(format)
816 											+ "_" + getShaderStageName(stage),
817 											"", initPrograms, test, caseDef);
818 
819 				if (VK_SHADER_STAGE_VERTEX_BIT == stage )
820 				{
821 					caseDef.noSSBO = true;
822 					addFunctionCaseWithPrograms(group.get(),
823 								de::toLower(op) + "_" +
824 								subgroups::getFormatNameForGLSL(format)
825 								+ "_" + getShaderStageName(stage)+"_framebuffer", "",
826 								initFrameBufferPrograms, test, caseDef);
827 				}
828 
829 			}
830 		}
831 	}
832 
833 	return group.release();
834 }
835 
836 } // subgroups
837 } // vkt
838