• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * OpenGL Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017-2019 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  * Copyright (c) 2019 NVIDIA Corporation.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25 
26 #include "glcSubgroupsQuadTests.hpp"
27 #include "glcSubgroupsTestsUtils.hpp"
28 
29 #include <string>
30 #include <vector>
31 
32 using namespace tcu;
33 using namespace std;
34 
35 namespace glc
36 {
37 namespace subgroups
38 {
39 namespace
40 {
// Quad subgroup operations exercised by this file.
//
// The enumerators are used directly as indices into the per-operation
// swap-table string arrays (sized OPTYPE_LAST), so they must stay dense and
// zero-based. OPTYPE_LAST is the operation count, not an operation itself.
enum OpType
{
    OPTYPE_QUAD_BROADCAST       = 0,
    OPTYPE_QUAD_SWAP_HORIZONTAL = 1,
    OPTYPE_QUAD_SWAP_VERTICAL   = 2,
    OPTYPE_QUAD_SWAP_DIAGONAL   = 3,
    OPTYPE_LAST                 = 4
};
49 
checkVertexPipelineStages(std::vector<const void * > datas,uint32_t width,uint32_t)50 static bool checkVertexPipelineStages(std::vector<const void *> datas, uint32_t width, uint32_t)
51 {
52     return glc::subgroups::check(datas, width, 1);
53 }
54 
checkComputeStage(std::vector<const void * > datas,const uint32_t numWorkgroups[3],const uint32_t localSize[3],uint32_t)55 static bool checkComputeStage(std::vector<const void *> datas, const uint32_t numWorkgroups[3],
56                               const uint32_t localSize[3], uint32_t)
57 {
58     return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
59 }
60 
getOpTypeName(int opType)61 std::string getOpTypeName(int opType)
62 {
63     switch (opType)
64     {
65     default:
66         DE_FATAL("Unsupported op type");
67         return "";
68     case OPTYPE_QUAD_BROADCAST:
69         return "subgroupQuadBroadcast";
70     case OPTYPE_QUAD_SWAP_HORIZONTAL:
71         return "subgroupQuadSwapHorizontal";
72     case OPTYPE_QUAD_SWAP_VERTICAL:
73         return "subgroupQuadSwapVertical";
74     case OPTYPE_QUAD_SWAP_DIAGONAL:
75         return "subgroupQuadSwapDiagonal";
76     }
77 }
78 
79 struct CaseDefinition
80 {
81     int opType;
82     ShaderStageFlags shaderStage;
83     Format format;
84     int direction;
85 };
86 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)87 void initFrameBufferPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
88 {
89     std::string swapTable[OPTYPE_LAST];
90 
91     subgroups::setFragmentShaderFrameBuffer(programCollection);
92 
93     if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
94         subgroups::setVertexShaderFrameBuffer(programCollection);
95 
96     swapTable[OPTYPE_QUAD_BROADCAST]       = "";
97     swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = uint[](1u, 0u, 3u, 2u);\n";
98     swapTable[OPTYPE_QUAD_SWAP_VERTICAL]   = "  const uint swapTable[4] = uint[](2u, 3u, 0u, 1u);\n";
99     swapTable[OPTYPE_QUAD_SWAP_DIAGONAL]   = "  const uint swapTable[4] = uint[](3u, 2u, 1u, 0u);\n";
100 
101     if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
102     {
103         std::ostringstream vertexSrc;
104         vertexSrc << "${VERSION_DECL}\n"
105                   << "#extension GL_KHR_shader_subgroup_quad: enable\n"
106                   << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
107                   << "layout(location = 0) in highp vec4 in_position;\n"
108                   << "layout(location = 0) out float result;\n"
109                   << "layout(binding = 0, std140) uniform Buffer0\n"
110                   << "{\n"
111                   << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
112                   << subgroups::maxSupportedSubgroupSize() << "];\n"
113                   << "};\n"
114                   << "\n"
115                   << "void main (void)\n"
116                   << "{\n"
117                   << "  uvec4 mask = subgroupBallot(true);\n"
118                   << swapTable[caseDef.opType];
119 
120         if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
121         {
122             vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format)
123                       << " op = " << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], "
124                       << caseDef.direction << "u);\n"
125                       << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
126         }
127         else
128         {
129             vertexSrc
130                 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType)
131                 << "(data[gl_SubgroupInvocationID]);\n"
132                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
133         }
134 
135         vertexSrc << "  if (subgroupBallotBitExtract(mask, otherID))\n"
136                   << "  {\n"
137                   << "    result = (op == data[otherID]) ? 1.0f : 0.0f;\n"
138                   << "  }\n"
139                   << "  else\n"
140                   << "  {\n"
141                   << "    result = 1.0f;\n" // Invocation we read from was inactive, so we can't verify results!
142                   << "  }\n"
143                   << "  gl_Position = in_position;\n"
144                   << "  gl_PointSize = 1.0f;\n"
145                   << "}\n";
146         programCollection.add("vert") << glu::VertexSource(vertexSrc.str());
147     }
148     else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
149     {
150         std::ostringstream geometry;
151 
152         geometry << "${VERSION_DECL}\n"
153                  << "#extension GL_KHR_shader_subgroup_quad: enable\n"
154                  << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
155                  << "layout(points) in;\n"
156                  << "layout(points, max_vertices = 1) out;\n"
157                  << "layout(location = 0) out float out_color;\n"
158                  << "layout(binding = 0, std140) uniform Buffer0\n"
159                  << "{\n"
160                  << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
161                  << subgroups::maxSupportedSubgroupSize() << "];\n"
162                  << "};\n"
163                  << "\n"
164                  << "void main (void)\n"
165                  << "{\n"
166                  << "  uvec4 mask = subgroupBallot(true);\n"
167                  << swapTable[caseDef.opType];
168 
169         if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
170         {
171             geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format)
172                      << " op = " << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], "
173                      << caseDef.direction << "u);\n"
174                      << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
175         }
176         else
177         {
178             geometry
179                 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType)
180                 << "(data[gl_SubgroupInvocationID]);\n"
181                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
182         }
183 
184         geometry << "  if (subgroupBallotBitExtract(mask, otherID))\n"
185                  << "  {\n"
186                  << "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
187                  << "  }\n"
188                  << "  else\n"
189                  << "  {\n"
190                  << "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
191                  << "  }\n"
192                  << "  gl_Position = gl_in[0].gl_Position;\n"
193                  << "  EmitVertex();\n"
194                  << "  EndPrimitive();\n"
195                  << "}\n";
196 
197         programCollection.add("geometry") << glu::GeometrySource(geometry.str());
198     }
199     else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
200     {
201         std::ostringstream controlSource;
202 
203         controlSource << "${VERSION_DECL}\n"
204                       << "#extension GL_KHR_shader_subgroup_quad: enable\n"
205                       << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
206                       << "layout(vertices = 2) out;\n"
207                       << "layout(location = 0) out float out_color[];\n"
208                       << "layout(binding = 0, std140) uniform Buffer0\n"
209                       << "{\n"
210                       << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
211                       << subgroups::maxSupportedSubgroupSize() << "];\n"
212                       << "};\n"
213                       << "\n"
214                       << "void main (void)\n"
215                       << "{\n"
216                       << "  if (gl_InvocationID == 0)\n"
217                       << "  {\n"
218                       << "    gl_TessLevelOuter[0] = 1.0f;\n"
219                       << "    gl_TessLevelOuter[1] = 1.0f;\n"
220                       << "  }\n"
221                       << "  uvec4 mask = subgroupBallot(true);\n"
222                       << swapTable[caseDef.opType];
223 
224         if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
225         {
226             controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format)
227                           << " op = " << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], "
228                           << caseDef.direction << "u);\n"
229                           << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
230         }
231         else
232         {
233             controlSource
234                 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType)
235                 << "(data[gl_SubgroupInvocationID]);\n"
236                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
237         }
238 
239         controlSource
240             << "  if (subgroupBallotBitExtract(mask, otherID))\n"
241             << "  {\n"
242             << "    out_color[gl_InvocationID] = (op == data[otherID]) ? 1.0 : 0.0;\n"
243             << "  }\n"
244             << "  else\n"
245             << "  {\n"
246             << "    out_color[gl_InvocationID] = 1.0; \n" // Invocation we read from was inactive, so we can't verify results!
247             << "  }\n"
248             << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
249             << "}\n";
250 
251         programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
252         subgroups::setTesEvalShaderFrameBuffer(programCollection);
253     }
254     else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
255     {
256         ostringstream evaluationSource;
257         evaluationSource << "${VERSION_DECL}\n"
258                          << "#extension GL_KHR_shader_subgroup_quad: enable\n"
259                          << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
260                          << "layout(isolines, equal_spacing, ccw ) in;\n"
261                          << "layout(location = 0) out float out_color;\n"
262                          << "layout(binding = 0, std140) uniform Buffer0\n"
263                          << "{\n"
264                          << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data["
265                          << subgroups::maxSupportedSubgroupSize() << "];\n"
266                          << "};\n"
267                          << "\n"
268                          << "void main (void)\n"
269                          << "{\n"
270                          << "  uvec4 mask = subgroupBallot(true);\n"
271                          << swapTable[caseDef.opType];
272 
273         if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
274         {
275             evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format)
276                              << " op = " << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], "
277                              << caseDef.direction << "u);\n"
278                              << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
279         }
280         else
281         {
282             evaluationSource
283                 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType)
284                 << "(data[gl_SubgroupInvocationID]);\n"
285                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
286         }
287 
288         evaluationSource
289             << "  if (subgroupBallotBitExtract(mask, otherID))\n"
290             << "  {\n"
291             << "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
292             << "  }\n"
293             << "  else\n"
294             << "  {\n"
295             << "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
296             << "  }\n"
297             << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
298             << "}\n";
299 
300         subgroups::setTesCtrlShaderFrameBuffer(programCollection);
301         programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
302     }
303     else
304     {
305         DE_FATAL("Unsupported shader stage");
306     }
307 }
308 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)309 void initPrograms(SourceCollections &programCollection, CaseDefinition caseDef)
310 {
311     std::string swapTable[OPTYPE_LAST];
312     swapTable[OPTYPE_QUAD_BROADCAST]       = "";
313     swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = uint[](1u, 0u, 3u, 2u);\n";
314     swapTable[OPTYPE_QUAD_SWAP_VERTICAL]   = "  const uint swapTable[4] = uint[](2u, 3u, 0u, 1u);\n";
315     swapTable[OPTYPE_QUAD_SWAP_DIAGONAL]   = "  const uint swapTable[4] = uint[](3u, 2u, 1u, 0u);\n";
316 
317     if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
318     {
319         std::ostringstream src;
320 
321         src << "${VERSION_DECL}\n"
322             << "#extension GL_KHR_shader_subgroup_quad: enable\n"
323             << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
324             << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
325             << "layout(binding = 0, std430) buffer Buffer0\n"
326             << "{\n"
327             << "  uint result[];\n"
328             << "};\n"
329             << "layout(binding = 1, std430) buffer Buffer1\n"
330             << "{\n"
331             << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
332             << "};\n"
333             << "\n"
334             << "void main (void)\n"
335             << "{\n"
336             << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
337             << "  highp uint offset = globalSize.x * ((globalSize.y * "
338                "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
339                "gl_GlobalInvocationID.x;\n"
340             << "  uvec4 mask = subgroupBallot(true);\n"
341             << swapTable[caseDef.opType];
342 
343         if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
344         {
345             src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType)
346                 << "(data[gl_SubgroupInvocationID], " << caseDef.direction << "u);\n"
347                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
348         }
349         else
350         {
351             src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType)
352                 << "(data[gl_SubgroupInvocationID]);\n"
353                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
354         }
355 
356         src << "  if (subgroupBallotBitExtract(mask, otherID))\n"
357             << "  {\n"
358             << "    result[offset] = (op == data[otherID]) ? 1u : 0u;\n"
359             << "  }\n"
360             << "  else\n"
361             << "  {\n"
362             << "    result[offset] = 1u; // Invocation we read from was inactive, so we can't verify results!\n"
363             << "  }\n"
364             << "}\n";
365 
366         programCollection.add("comp") << glu::ComputeSource(src.str());
367     }
368     else
369     {
370         std::ostringstream src;
371         if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
372         {
373             src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType)
374                 << "(data[gl_SubgroupInvocationID], " << caseDef.direction << "u);\n"
375                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + " << caseDef.direction << "u;\n";
376         }
377         else
378         {
379             src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType)
380                 << "(data[gl_SubgroupInvocationID]);\n"
381                 << "  uint otherID = (gl_SubgroupInvocationID & ~0x3u) + swapTable[gl_SubgroupInvocationID & 0x3u];\n";
382         }
383         const string sourceType = src.str();
384 
385         {
386             const string vertex =
387                 "${VERSION_DECL}\n"
388                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
389                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
390                 "layout(binding = 0, std430) buffer Buffer0\n"
391                 "{\n"
392                 "  uint result[];\n"
393                 "} b0;\n"
394                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
395                 "{\n"
396                 "  " +
397                 subgroups::getFormatNameForGLSL(caseDef.format) +
398                 " data[];\n"
399                 "};\n"
400                 "\n"
401                 "void main (void)\n"
402                 "{\n"
403                 "  uvec4 mask = subgroupBallot(true);\n" +
404                 swapTable[caseDef.opType] + sourceType +
405                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
406                 "  {\n"
407                 "    b0.result[gl_VertexID] = (op == data[otherID]) ? 1u : 0u;\n"
408                 "  }\n"
409                 "  else\n"
410                 "  {\n"
411                 "    b0.result[gl_VertexID] = 1u; // Invocation we read from was inactive, so we can't verify "
412                 "results!\n"
413                 "  }\n"
414                 "  float pixelSize = 2.0f/1024.0f;\n"
415                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
416                 "  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
417                 "}\n";
418             programCollection.add("vert") << glu::VertexSource(vertex);
419         }
420 
421         {
422             const string tesc = "${VERSION_DECL}\n"
423                                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
424                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
425                                 "layout(vertices=1) out;\n"
426                                 "layout(binding = 1, std430) buffer Buffer1\n"
427                                 "{\n"
428                                 "  uint result[];\n"
429                                 "} b1;\n"
430                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
431                                 "{\n"
432                                 "  " +
433                                 subgroups::getFormatNameForGLSL(caseDef.format) +
434                                 " data[];\n"
435                                 "};\n"
436                                 "\n"
437                                 "void main (void)\n"
438                                 "{\n"
439                                 "  uvec4 mask = subgroupBallot(true);\n" +
440                                 swapTable[caseDef.opType] + sourceType +
441                                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
442                                 "  {\n"
443                                 "    b1.result[gl_PrimitiveID] = (op == data[otherID]) ? 1u : 0u;\n"
444                                 "  }\n"
445                                 "  else\n"
446                                 "  {\n"
447                                 "    b1.result[gl_PrimitiveID] = 1u; // Invocation we read from was inactive, so we "
448                                 "can't verify results!\n"
449                                 "  }\n"
450                                 "  if (gl_InvocationID == 0)\n"
451                                 "  {\n"
452                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
453                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
454                                 "  }\n"
455                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
456                                 "}\n";
457             programCollection.add("tesc") << glu::TessellationControlSource(tesc);
458         }
459 
460         {
461             const string tese =
462                 "${VERSION_DECL}\n"
463                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
464                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
465                 "layout(isolines) in;\n"
466                 "layout(binding = 2, std430)  buffer Buffer2\n"
467                 "{\n"
468                 "  uint result[];\n"
469                 "} b2;\n"
470                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
471                 "{\n"
472                 "  " +
473                 subgroups::getFormatNameForGLSL(caseDef.format) +
474                 " data[];\n"
475                 "};\n"
476                 "\n"
477                 "void main (void)\n"
478                 "{\n"
479                 "  uvec4 mask = subgroupBallot(true);\n" +
480                 swapTable[caseDef.opType] + sourceType +
481                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
482                 "  {\n"
483                 "    b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = (op == data[otherID]) ? 1u : 0u;\n"
484                 "  }\n"
485                 "  else\n"
486                 "  {\n"
487                 "    b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = 1u; // Invocation we read from was "
488                 "inactive, so we can't verify results!\n"
489                 "  }\n"
490                 "  float pixelSize = 2.0f/1024.0f;\n"
491                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
492                 "}\n";
493             programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
494         }
495 
496         {
497             const string geometry =
498                 // version added by addGeometryShadersFromTemplate
499                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
500                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
501                 "layout(${TOPOLOGY}) in;\n"
502                 "layout(points, max_vertices = 1) out;\n"
503                 "layout(binding = 3, std430) buffer Buffer3\n"
504                 "{\n"
505                 "  uint result[];\n"
506                 "} b3;\n"
507                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
508                 "{\n"
509                 "  " +
510                 subgroups::getFormatNameForGLSL(caseDef.format) +
511                 " data[];\n"
512                 "};\n"
513                 "\n"
514                 "void main (void)\n"
515                 "{\n"
516                 "  uvec4 mask = subgroupBallot(true);\n" +
517                 swapTable[caseDef.opType] + sourceType +
518                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
519                 "  {\n"
520                 "    b3.result[gl_PrimitiveIDIn] = (op == data[otherID]) ? 1u : 0u;\n"
521                 "  }\n"
522                 "  else\n"
523                 "  {\n"
524                 "    b3.result[gl_PrimitiveIDIn] = 1u; // Invocation we read from was inactive, so we can't verify "
525                 "results!\n"
526                 "  }\n"
527                 "  gl_Position = gl_in[0].gl_Position;\n"
528                 "  EmitVertex();\n"
529                 "  EndPrimitive();\n"
530                 "}\n";
531             subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
532         }
533 
534         {
535             const string fragment =
536                 "${VERSION_DECL}\n"
537                 "#extension GL_KHR_shader_subgroup_quad: enable\n"
538                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
539                 "precision highp int;\n"
540                 "precision highp float;\n"
541                 "layout(location = 0) out uint result;\n"
542                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
543                 "{\n"
544                 "  " +
545                 subgroups::getFormatNameForGLSL(caseDef.format) +
546                 " data[];\n"
547                 "};\n"
548                 "void main (void)\n"
549                 "{\n"
550                 "  uvec4 mask = subgroupBallot(true);\n" +
551                 swapTable[caseDef.opType] + sourceType +
552                 "  if (subgroupBallotBitExtract(mask, otherID))\n"
553                 "  {\n"
554                 "    result = (op == data[otherID]) ? 1u : 0u;\n"
555                 "  }\n"
556                 "  else\n"
557                 "  {\n"
558                 "    result = 1u; // Invocation we read from was inactive, so we can't verify results!\n"
559                 "  }\n"
560                 "}\n";
561             programCollection.add("fragment") << glu::FragmentSource(fragment);
562         }
563         subgroups::addNoSubgroupShader(programCollection);
564     }
565 }
566 
supportedCheck(Context & context,CaseDefinition caseDef)567 void supportedCheck(Context &context, CaseDefinition caseDef)
568 {
569     if (!subgroups::isSubgroupSupported(context))
570         TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
571 
572     if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_QUAD_BIT))
573         TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
574 
575     if (subgroups::isDoubleFormat(caseDef.format) && !subgroups::isDoubleSupportedForDevice(context))
576     {
577         TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
578     }
579 }
580 
noSSBOtest(Context & context,const CaseDefinition caseDef)581 tcu::TestStatus noSSBOtest(Context &context, const CaseDefinition caseDef)
582 {
583     if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
584     {
585         if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
586         {
587             return tcu::TestStatus::fail("Shader stage " + subgroups::getShaderStageName(caseDef.shaderStage) +
588                                          " is required to support subgroup operations!");
589         }
590         else
591         {
592             TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
593         }
594     }
595 
596     subgroups::SSBOData inputData;
597     inputData.format         = caseDef.format;
598     inputData.layout         = subgroups::SSBOData::LayoutStd140;
599     inputData.numElements    = subgroups::maxSupportedSubgroupSize();
600     inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
601     inputData.binding        = 0u;
602 
603     if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
604         return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
605     else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
606         return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1,
607                                                       checkVertexPipelineStages);
608     else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
609         return subgroups::makeTessellationEvaluationFrameBufferTest(
610             context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_CONTROL_BIT);
611     else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
612         return subgroups::makeTessellationEvaluationFrameBufferTest(
613             context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_EVALUATION_BIT);
614     else
615         TCU_THROW(InternalError, "Unhandled shader stage");
616 }
617 
test(Context & context,const CaseDefinition caseDef)618 tcu::TestStatus test(Context &context, const CaseDefinition caseDef)
619 {
620     if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
621     {
622         if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
623         {
624             return tcu::TestStatus::fail("Shader stage " + subgroups::getShaderStageName(caseDef.shaderStage) +
625                                          " is required to support subgroup operations!");
626         }
627         subgroups::SSBOData inputData;
628         inputData.format         = caseDef.format;
629         inputData.layout         = subgroups::SSBOData::LayoutStd430;
630         inputData.numElements    = subgroups::maxSupportedSubgroupSize();
631         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
632         inputData.binding        = 1u;
633 
634         return subgroups::makeComputeTest(context, FORMAT_R32_UINT, &inputData, 1, checkComputeStage);
635     }
636     else
637     {
638         int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
639 
640         ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
641 
642         if (SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
643         {
644             if ((stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
645                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
646             else
647                 stages = SHADER_STAGE_FRAGMENT_BIT;
648         }
649 
650         if ((ShaderStageFlags)0u == stages)
651             TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
652 
653         subgroups::SSBOData inputData;
654         inputData.format         = caseDef.format;
655         inputData.layout         = subgroups::SSBOData::LayoutStd430;
656         inputData.numElements    = subgroups::maxSupportedSubgroupSize();
657         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
658         inputData.binding        = 4u;
659         inputData.stages         = stages;
660 
661         return subgroups::allStages(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
662     }
663 }
664 } // namespace
665 
createSubgroupsQuadTests(deqp::Context & testCtx)666 deqp::TestCaseGroup *createSubgroupsQuadTests(deqp::Context &testCtx)
667 {
668     de::MovePtr<deqp::TestCaseGroup> graphicGroup(
669         new deqp::TestCaseGroup(testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
670     de::MovePtr<deqp::TestCaseGroup> computeGroup(
671         new deqp::TestCaseGroup(testCtx, "compute", "Subgroup arithmetic category tests: compute"));
672     de::MovePtr<deqp::TestCaseGroup> framebufferGroup(
673         new deqp::TestCaseGroup(testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
674 
675     const Format formats[] = {
676         FORMAT_R32_SINT,   FORMAT_R32G32_SINT,   FORMAT_R32G32B32_SINT,   FORMAT_R32G32B32A32_SINT,
677         FORMAT_R32_UINT,   FORMAT_R32G32_UINT,   FORMAT_R32G32B32_UINT,   FORMAT_R32G32B32A32_UINT,
678         FORMAT_R32_SFLOAT, FORMAT_R32G32_SFLOAT, FORMAT_R32G32B32_SFLOAT, FORMAT_R32G32B32A32_SFLOAT,
679         FORMAT_R64_SFLOAT, FORMAT_R64G64_SFLOAT, FORMAT_R64G64B64_SFLOAT, FORMAT_R64G64B64A64_SFLOAT,
680         FORMAT_R32_BOOL,   FORMAT_R32G32_BOOL,   FORMAT_R32G32B32_BOOL,   FORMAT_R32G32B32A32_BOOL,
681     };
682 
683     const ShaderStageFlags stages[] = {
684         SHADER_STAGE_VERTEX_BIT,
685         SHADER_STAGE_TESS_EVALUATION_BIT,
686         SHADER_STAGE_TESS_CONTROL_BIT,
687         SHADER_STAGE_GEOMETRY_BIT,
688     };
689 
690     for (int direction = 0; direction < 4; ++direction)
691     {
692         for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
693         {
694             const Format format = formats[formatIndex];
695 
696             for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
697             {
698                 const std::string op = de::toLower(getOpTypeName(opTypeIndex));
699                 std::ostringstream name;
700                 name << de::toLower(op);
701 
702                 if (OPTYPE_QUAD_BROADCAST == opTypeIndex)
703                 {
704                     name << "_" << direction;
705                 }
706                 else
707                 {
708                     if (0 != direction)
709                     {
710                         // We don't need direction for swap operations.
711                         continue;
712                     }
713                 }
714 
715                 name << "_" << subgroups::getFormatNameForGLSL(format);
716 
717                 {
718                     const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT, format, direction};
719                     SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
720                         computeGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
721                 }
722 
723                 {
724                     const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_ALL_GRAPHICS, format, direction};
725                     SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
726                         graphicGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
727                 }
728                 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
729                 {
730                     const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format, direction};
731                     SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(
732                         framebufferGroup.get(), name.str() + "_" + getShaderStageName(caseDef.shaderStage), "",
733                         supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
734                 }
735             }
736         }
737     }
738 
739     de::MovePtr<deqp::TestCaseGroup> group(new deqp::TestCaseGroup(testCtx, "quad", "Subgroup quad category tests"));
740 
741     group->addChild(graphicGroup.release());
742     group->addChild(computeGroup.release());
743     group->addChild(framebufferGroup.release());
744 
745     return group.release();
746 }
747 } // namespace subgroups
748 } // namespace glc
749