1 /*------------------------------------------------------------------------
2 * OpenGL Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2017-2019 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
7 * Copyright (c) 2019 NVIDIA Corporation.
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 * http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 */ /*!
22 * \file
23 * \brief Subgroups Tests
24 */ /*--------------------------------------------------------------------*/
25
26 #include "glcSubgroupsArithmeticTests.hpp"
27 #include "glcSubgroupsTestsUtils.hpp"
28
29 #include <string>
30 #include <vector>
31
32 using namespace tcu;
33 using namespace std;
34
35 namespace glc
36 {
37 namespace subgroups
38 {
39 namespace
40 {
// Subgroup arithmetic operations under test. The enumerators are iterated as
// plain integers (0 .. OPTYPE_LAST - 1) when the test cases are generated.
enum OpType
{
	// Reductions over the whole subgroup.
	OPTYPE_ADD = 0,
	OPTYPE_MUL,
	OPTYPE_MIN,
	OPTYPE_MAX,
	OPTYPE_AND,
	OPTYPE_OR,
	OPTYPE_XOR,

	// Inclusive scans.
	OPTYPE_INCLUSIVE_ADD,
	OPTYPE_INCLUSIVE_MUL,
	OPTYPE_INCLUSIVE_MIN,
	OPTYPE_INCLUSIVE_MAX,
	OPTYPE_INCLUSIVE_AND,
	OPTYPE_INCLUSIVE_OR,
	OPTYPE_INCLUSIVE_XOR,

	// Exclusive scans.
	OPTYPE_EXCLUSIVE_ADD,
	OPTYPE_EXCLUSIVE_MUL,
	OPTYPE_EXCLUSIVE_MIN,
	OPTYPE_EXCLUSIVE_MAX,
	OPTYPE_EXCLUSIVE_AND,
	OPTYPE_EXCLUSIVE_OR,
	OPTYPE_EXCLUSIVE_XOR,

	// Number of operations; not an operation itself.
	OPTYPE_LAST
};
66
checkVertexPipelineStages(std::vector<const void * > datas,deUint32 width,deUint32)67 static bool checkVertexPipelineStages(std::vector<const void*> datas,
68 deUint32 width, deUint32)
69 {
70 return glc::subgroups::check(datas, width, 0x3);
71 }
72
checkComputeStage(std::vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)73 static bool checkComputeStage(std::vector<const void*> datas,
74 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
75 deUint32)
76 {
77 return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 0x3);
78 }
79
getOpTypeName(int opType)80 std::string getOpTypeName(int opType)
81 {
82 switch (opType)
83 {
84 default:
85 DE_FATAL("Unsupported op type");
86 return "";
87 case OPTYPE_ADD:
88 return "subgroupAdd";
89 case OPTYPE_MUL:
90 return "subgroupMul";
91 case OPTYPE_MIN:
92 return "subgroupMin";
93 case OPTYPE_MAX:
94 return "subgroupMax";
95 case OPTYPE_AND:
96 return "subgroupAnd";
97 case OPTYPE_OR:
98 return "subgroupOr";
99 case OPTYPE_XOR:
100 return "subgroupXor";
101 case OPTYPE_INCLUSIVE_ADD:
102 return "subgroupInclusiveAdd";
103 case OPTYPE_INCLUSIVE_MUL:
104 return "subgroupInclusiveMul";
105 case OPTYPE_INCLUSIVE_MIN:
106 return "subgroupInclusiveMin";
107 case OPTYPE_INCLUSIVE_MAX:
108 return "subgroupInclusiveMax";
109 case OPTYPE_INCLUSIVE_AND:
110 return "subgroupInclusiveAnd";
111 case OPTYPE_INCLUSIVE_OR:
112 return "subgroupInclusiveOr";
113 case OPTYPE_INCLUSIVE_XOR:
114 return "subgroupInclusiveXor";
115 case OPTYPE_EXCLUSIVE_ADD:
116 return "subgroupExclusiveAdd";
117 case OPTYPE_EXCLUSIVE_MUL:
118 return "subgroupExclusiveMul";
119 case OPTYPE_EXCLUSIVE_MIN:
120 return "subgroupExclusiveMin";
121 case OPTYPE_EXCLUSIVE_MAX:
122 return "subgroupExclusiveMax";
123 case OPTYPE_EXCLUSIVE_AND:
124 return "subgroupExclusiveAnd";
125 case OPTYPE_EXCLUSIVE_OR:
126 return "subgroupExclusiveOr";
127 case OPTYPE_EXCLUSIVE_XOR:
128 return "subgroupExclusiveXor";
129 }
130 }
131
getOpTypeOperation(int opType,Format format,std::string lhs,std::string rhs)132 std::string getOpTypeOperation(int opType, Format format, std::string lhs, std::string rhs)
133 {
134 switch (opType)
135 {
136 default:
137 DE_FATAL("Unsupported op type");
138 return "";
139 case OPTYPE_ADD:
140 case OPTYPE_INCLUSIVE_ADD:
141 case OPTYPE_EXCLUSIVE_ADD:
142 return lhs + " + " + rhs;
143 case OPTYPE_MUL:
144 case OPTYPE_INCLUSIVE_MUL:
145 case OPTYPE_EXCLUSIVE_MUL:
146 return lhs + " * " + rhs;
147 case OPTYPE_MIN:
148 case OPTYPE_INCLUSIVE_MIN:
149 case OPTYPE_EXCLUSIVE_MIN:
150 switch (format)
151 {
152 default:
153 return "min(" + lhs + ", " + rhs + ")";
154 case FORMAT_R32_SFLOAT:
155 case FORMAT_R64_SFLOAT:
156 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))";
157 case FORMAT_R32G32_SFLOAT:
158 case FORMAT_R32G32B32_SFLOAT:
159 case FORMAT_R32G32B32A32_SFLOAT:
160 case FORMAT_R64G64_SFLOAT:
161 case FORMAT_R64G64B64_SFLOAT:
162 case FORMAT_R64G64B64A64_SFLOAT:
163 return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
164 }
165 case OPTYPE_MAX:
166 case OPTYPE_INCLUSIVE_MAX:
167 case OPTYPE_EXCLUSIVE_MAX:
168 switch (format)
169 {
170 default:
171 return "max(" + lhs + ", " + rhs + ")";
172 case FORMAT_R32_SFLOAT:
173 case FORMAT_R64_SFLOAT:
174 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs + ")))";
175 case FORMAT_R32G32_SFLOAT:
176 case FORMAT_R32G32B32_SFLOAT:
177 case FORMAT_R32G32B32A32_SFLOAT:
178 case FORMAT_R64G64_SFLOAT:
179 case FORMAT_R64G64B64_SFLOAT:
180 case FORMAT_R64G64B64A64_SFLOAT:
181 return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
182 }
183 case OPTYPE_AND:
184 case OPTYPE_INCLUSIVE_AND:
185 case OPTYPE_EXCLUSIVE_AND:
186 switch (format)
187 {
188 default:
189 return lhs + " & " + rhs;
190 case FORMAT_R32_BOOL:
191 return lhs + " && " + rhs;
192 case FORMAT_R32G32_BOOL:
193 return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)";
194 case FORMAT_R32G32B32_BOOL:
195 return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)";
196 case FORMAT_R32G32B32A32_BOOL:
197 return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)";
198 }
199 case OPTYPE_OR:
200 case OPTYPE_INCLUSIVE_OR:
201 case OPTYPE_EXCLUSIVE_OR:
202 switch (format)
203 {
204 default:
205 return lhs + " | " + rhs;
206 case FORMAT_R32_BOOL:
207 return lhs + " || " + rhs;
208 case FORMAT_R32G32_BOOL:
209 return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)";
210 case FORMAT_R32G32B32_BOOL:
211 return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)";
212 case FORMAT_R32G32B32A32_BOOL:
213 return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)";
214 }
215 case OPTYPE_XOR:
216 case OPTYPE_INCLUSIVE_XOR:
217 case OPTYPE_EXCLUSIVE_XOR:
218 switch (format)
219 {
220 default:
221 return lhs + " ^ " + rhs;
222 case FORMAT_R32_BOOL:
223 return lhs + " ^^ " + rhs;
224 case FORMAT_R32G32_BOOL:
225 return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)";
226 case FORMAT_R32G32B32_BOOL:
227 return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)";
228 case FORMAT_R32G32B32A32_BOOL:
229 return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)";
230 }
231 }
232 }
233
getIdentity(int opType,Format format)234 std::string getIdentity(int opType, Format format)
235 {
236 bool isFloat = false;
237 bool isInt = false;
238 bool isUnsigned = false;
239
240 switch (format)
241 {
242 default:
243 DE_FATAL("Unhandled format!");
244 break;
245 case FORMAT_R32_SINT:
246 case FORMAT_R32G32_SINT:
247 case FORMAT_R32G32B32_SINT:
248 case FORMAT_R32G32B32A32_SINT:
249 isInt = true;
250 break;
251 case FORMAT_R32_UINT:
252 case FORMAT_R32G32_UINT:
253 case FORMAT_R32G32B32_UINT:
254 case FORMAT_R32G32B32A32_UINT:
255 isUnsigned = true;
256 break;
257 case FORMAT_R32_SFLOAT:
258 case FORMAT_R32G32_SFLOAT:
259 case FORMAT_R32G32B32_SFLOAT:
260 case FORMAT_R32G32B32A32_SFLOAT:
261 case FORMAT_R64_SFLOAT:
262 case FORMAT_R64G64_SFLOAT:
263 case FORMAT_R64G64B64_SFLOAT:
264 case FORMAT_R64G64B64A64_SFLOAT:
265 isFloat = true;
266 break;
267 case FORMAT_R32_BOOL:
268 case FORMAT_R32G32_BOOL:
269 case FORMAT_R32G32B32_BOOL:
270 case FORMAT_R32G32B32A32_BOOL:
271 break; // bool types are not anything
272 }
273
274 switch (opType)
275 {
276 default:
277 DE_FATAL("Unsupported op type");
278 return "";
279 case OPTYPE_ADD:
280 case OPTYPE_INCLUSIVE_ADD:
281 case OPTYPE_EXCLUSIVE_ADD:
282 return subgroups::getFormatNameForGLSL(format) + "(0)";
283 case OPTYPE_MUL:
284 case OPTYPE_INCLUSIVE_MUL:
285 case OPTYPE_EXCLUSIVE_MUL:
286 return subgroups::getFormatNameForGLSL(format) + "(1)";
287 case OPTYPE_MIN:
288 case OPTYPE_INCLUSIVE_MIN:
289 case OPTYPE_EXCLUSIVE_MIN:
290 if (isFloat)
291 {
292 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
293 }
294 else if (isInt)
295 {
296 return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
297 }
298 else if (isUnsigned)
299 {
300 return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
301 }
302 else
303 {
304 DE_FATAL("Unhandled case");
305 return "";
306 }
307 case OPTYPE_MAX:
308 case OPTYPE_INCLUSIVE_MAX:
309 case OPTYPE_EXCLUSIVE_MAX:
310 if (isFloat)
311 {
312 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
313 }
314 else if (isInt)
315 {
316 return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
317 }
318 else if (isUnsigned)
319 {
320 return subgroups::getFormatNameForGLSL(format) + "(0u)";
321 }
322 else
323 {
324 DE_FATAL("Unhandled case");
325 return "";
326 }
327 case OPTYPE_AND:
328 case OPTYPE_INCLUSIVE_AND:
329 case OPTYPE_EXCLUSIVE_AND:
330 return subgroups::getFormatNameForGLSL(format) + "(~0)";
331 case OPTYPE_OR:
332 case OPTYPE_INCLUSIVE_OR:
333 case OPTYPE_EXCLUSIVE_OR:
334 return subgroups::getFormatNameForGLSL(format) + "(0)";
335 case OPTYPE_XOR:
336 case OPTYPE_INCLUSIVE_XOR:
337 case OPTYPE_EXCLUSIVE_XOR:
338 return subgroups::getFormatNameForGLSL(format) + "(0)";
339 }
340 }
341
getCompare(int opType,Format format,std::string lhs,std::string rhs)342 std::string getCompare(int opType, Format format, std::string lhs, std::string rhs)
343 {
344 std::string formatName = subgroups::getFormatNameForGLSL(format);
345 switch (format)
346 {
347 default:
348 return "all(equal(" + lhs + ", " + rhs + "))";
349 case FORMAT_R32_BOOL:
350 case FORMAT_R32_UINT:
351 case FORMAT_R32_SINT:
352 return "(" + lhs + " == " + rhs + ")";
353 case FORMAT_R32_SFLOAT:
354 case FORMAT_R64_SFLOAT:
355 switch (opType)
356 {
357 default:
358 return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
359 case OPTYPE_MIN:
360 case OPTYPE_INCLUSIVE_MIN:
361 case OPTYPE_EXCLUSIVE_MIN:
362 case OPTYPE_MAX:
363 case OPTYPE_INCLUSIVE_MAX:
364 case OPTYPE_EXCLUSIVE_MAX:
365 return "(" + lhs + " == " + rhs + ")";
366 }
367 case FORMAT_R32G32_SFLOAT:
368 case FORMAT_R32G32B32_SFLOAT:
369 case FORMAT_R32G32B32A32_SFLOAT:
370 case FORMAT_R64G64_SFLOAT:
371 case FORMAT_R64G64B64_SFLOAT:
372 case FORMAT_R64G64B64A64_SFLOAT:
373 switch (opType)
374 {
375 default:
376 return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
377 case OPTYPE_MIN:
378 case OPTYPE_INCLUSIVE_MIN:
379 case OPTYPE_EXCLUSIVE_MIN:
380 case OPTYPE_MAX:
381 case OPTYPE_INCLUSIVE_MAX:
382 case OPTYPE_EXCLUSIVE_MAX:
383 return "all(equal(" + lhs + ", " + rhs + "))";
384 }
385 }
386 }
387
388 struct CaseDefinition
389 {
390 int opType;
391 ShaderStageFlags shaderStage;
392 Format format;
393 };
394
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)395 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
396 {
397 std::string indexVars;
398 std::ostringstream bdy;
399
400 subgroups::setFragmentShaderFrameBuffer(programCollection);
401
402 if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
403 subgroups::setVertexShaderFrameBuffer(programCollection);
404
405 switch (caseDef.opType)
406 {
407 default:
408 indexVars = " uint start = 0u, end = gl_SubgroupSize;\n";
409 break;
410 case OPTYPE_INCLUSIVE_ADD:
411 case OPTYPE_INCLUSIVE_MUL:
412 case OPTYPE_INCLUSIVE_MIN:
413 case OPTYPE_INCLUSIVE_MAX:
414 case OPTYPE_INCLUSIVE_AND:
415 case OPTYPE_INCLUSIVE_OR:
416 case OPTYPE_INCLUSIVE_XOR:
417 indexVars = " uint start = 0u, end = gl_SubgroupInvocationID + 1u;\n";
418 break;
419 case OPTYPE_EXCLUSIVE_ADD:
420 case OPTYPE_EXCLUSIVE_MUL:
421 case OPTYPE_EXCLUSIVE_MIN:
422 case OPTYPE_EXCLUSIVE_MAX:
423 case OPTYPE_EXCLUSIVE_AND:
424 case OPTYPE_EXCLUSIVE_OR:
425 case OPTYPE_EXCLUSIVE_XOR:
426 indexVars = " uint start = 0u, end = gl_SubgroupInvocationID;\n";
427 break;
428 }
429
430 bdy << indexVars
431 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
432 << getIdentity(caseDef.opType, caseDef.format) << ";\n"
433 << " uint tempResult = 0u;\n"
434 << " for (uint index = start; index < end; index++)\n"
435 << " {\n"
436 << " if (subgroupBallotBitExtract(mask, index))\n"
437 << " {\n"
438 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
439 << " }\n"
440 << " }\n"
441 << " tempResult = " << getCompare(caseDef.opType, caseDef.format, "ref",
442 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x1u : 0u;\n"
443 << " if (1u == (gl_SubgroupInvocationID % 2u))\n"
444 << " {\n"
445 << " mask = subgroupBallot(true);\n"
446 << " ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n"
447 << " for (uint index = start; index < end; index++)\n"
448 << " {\n"
449 << " if (subgroupBallotBitExtract(mask, index))\n"
450 << " {\n"
451 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
452 << " }\n"
453 << " }\n"
454 << " tempResult |= " << getCompare(caseDef.opType, caseDef.format, "ref",
455 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x2u : 0u;\n"
456 << " }\n"
457 << " else\n"
458 << " {\n"
459 << " tempResult |= 0x2u;\n"
460 << " }\n";
461
462 if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
463 {
464 std::ostringstream vertexSrc;
465 vertexSrc << "${VERSION_DECL}\n"
466 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
467 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
468 << "layout(location = 0) in highp vec4 in_position;\n"
469 << "layout(location = 0) out float out_color;\n"
470 << "layout(binding = 0, std140) uniform Buffer0\n"
471 << "{\n"
472 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
473 << "};\n"
474 << "\n"
475 << "void main (void)\n"
476 << "{\n"
477 << " uvec4 mask = subgroupBallot(true);\n"
478 << bdy.str()
479 << " out_color = float(tempResult);\n"
480 << " gl_Position = in_position;\n"
481 << " gl_PointSize = 1.0f;\n"
482 << "}\n";
483 programCollection.add("vert") << glu::VertexSource(vertexSrc.str());
484 }
485 else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
486 {
487 std::ostringstream geometry;
488
489 geometry << "${VERSION_DECL}\n"
490 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
491 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
492 << "layout(points) in;\n"
493 << "layout(points, max_vertices = 1) out;\n"
494 << "layout(location = 0) out float out_color;\n"
495 << "layout(binding = 0, std140) uniform Buffer0\n"
496 << "{\n"
497 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
498 << "};\n"
499 << "\n"
500 << "void main (void)\n"
501 << "{\n"
502 << " uvec4 mask = subgroupBallot(true);\n"
503 << bdy.str()
504 << " out_color = float(tempResult);\n"
505 << " gl_Position = gl_in[0].gl_Position;\n"
506 << " EmitVertex();\n"
507 << " EndPrimitive();\n"
508 << "}\n";
509
510 programCollection.add("geometry") << glu::GeometrySource(geometry.str());
511 }
512 else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
513 {
514 std::ostringstream controlSource;
515 controlSource << "${VERSION_DECL}\n"
516 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
517 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
518 << "layout(vertices = 2) out;\n"
519 << "layout(location = 0) out float out_color[];\n"
520 << "layout(binding = 0, std140) uniform Buffer0\n"
521 << "{\n"
522 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
523 << "};\n"
524 << "\n"
525 << "void main (void)\n"
526 << "{\n"
527 << " if (gl_InvocationID == 0)\n"
528 <<" {\n"
529 << " gl_TessLevelOuter[0] = 1.0f;\n"
530 << " gl_TessLevelOuter[1] = 1.0f;\n"
531 << " }\n"
532 << " uvec4 mask = subgroupBallot(true);\n"
533 << bdy.str()
534 << " out_color[gl_InvocationID] = float(tempResult);"
535 << " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
536 << "}\n";
537
538
539 programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
540 subgroups::setTesEvalShaderFrameBuffer(programCollection);
541 }
542 else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
543 {
544
545 std::ostringstream evaluationSource;
546 evaluationSource << "${VERSION_DECL}\n"
547 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
548 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
549 << "layout(isolines, equal_spacing, ccw ) in;\n"
550 << "layout(location = 0) out float out_color;\n"
551 << "layout(binding = 0, std140) uniform Buffer0\n"
552 << "{\n"
553 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
554 << "};\n"
555 << "\n"
556 << "void main (void)\n"
557 << "{\n"
558 << " uvec4 mask = subgroupBallot(true);\n"
559 << bdy.str()
560 << " out_color = float(tempResult);\n"
561 << " gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
562 << "}\n";
563
564 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
565 programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
566 }
567 else
568 {
569 DE_FATAL("Unsupported shader stage");
570 }
571 }
572
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)573 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
574 {
575 std::string indexVars;
576 switch (caseDef.opType)
577 {
578 default:
579 indexVars = " uint start = 0u, end = gl_SubgroupSize;\n";
580 break;
581 case OPTYPE_INCLUSIVE_ADD:
582 case OPTYPE_INCLUSIVE_MUL:
583 case OPTYPE_INCLUSIVE_MIN:
584 case OPTYPE_INCLUSIVE_MAX:
585 case OPTYPE_INCLUSIVE_AND:
586 case OPTYPE_INCLUSIVE_OR:
587 case OPTYPE_INCLUSIVE_XOR:
588 indexVars = " uint start = 0u, end = gl_SubgroupInvocationID + 1u;\n";
589 break;
590 case OPTYPE_EXCLUSIVE_ADD:
591 case OPTYPE_EXCLUSIVE_MUL:
592 case OPTYPE_EXCLUSIVE_MIN:
593 case OPTYPE_EXCLUSIVE_MAX:
594 case OPTYPE_EXCLUSIVE_AND:
595 case OPTYPE_EXCLUSIVE_OR:
596 case OPTYPE_EXCLUSIVE_XOR:
597 indexVars = " uint start = 0u, end = gl_SubgroupInvocationID;\n";
598 break;
599 }
600
601 const string bdy =
602 indexVars +
603 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " ref = "
604 + getIdentity(caseDef.opType, caseDef.format) + ";\n"
605 " uint tempResult = 0u;\n"
606 " for (uint index = start; index < end; index++)\n"
607 " {\n"
608 " if (subgroupBallotBitExtract(mask, index))\n"
609 " {\n"
610 " ref = " + getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") + ";\n"
611 " }\n"
612 " }\n"
613 " tempResult = " + getCompare(caseDef.opType, caseDef.format, "ref", getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") + " ? 0x1u : 0u;\n"
614 " if (1u == (gl_SubgroupInvocationID % 2u))\n"
615 " {\n"
616 " mask = subgroupBallot(true);\n"
617 " ref = " + getIdentity(caseDef.opType, caseDef.format) + ";\n"
618 " for (uint index = start; index < end; index++)\n"
619 " {\n"
620 " if (subgroupBallotBitExtract(mask, index))\n"
621 " {\n"
622 " ref = " + getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") + ";\n"
623 " }\n"
624 " }\n"
625 " tempResult |= " + getCompare(caseDef.opType, caseDef.format, "ref", getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") + " ? 0x2u : 0u;\n"
626 " }\n"
627 " else\n"
628 " {\n"
629 " tempResult |= 0x2u;\n"
630 " }\n";
631
632 if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
633 {
634 std::ostringstream src;
635
636 src << "${VERSION_DECL}\n"
637 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
638 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
639 << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
640 << "layout(binding = 0, std430) buffer Buffer0\n"
641 << "{\n"
642 << " uint result[];\n"
643 << "};\n"
644 << "layout(binding = 1, std430) buffer Buffer1\n"
645 << "{\n"
646 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
647 << "};\n"
648 << "\n"
649 << "void main (void)\n"
650 << "{\n"
651 << " uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
652 << " highp uint offset = globalSize.x * ((globalSize.y * "
653 "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
654 "gl_GlobalInvocationID.x;\n"
655 << " uvec4 mask = subgroupBallot(true);\n"
656 << bdy
657 << " result[offset] = tempResult;\n"
658 << "}\n";
659
660 programCollection.add("comp") << glu::ComputeSource(src.str());
661 }
662 else
663 {
664 {
665 const std::string vertex =
666 "${VERSION_DECL}\n"
667 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
668 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
669 "layout(binding = 0, std430) buffer Buffer0\n"
670 "{\n"
671 " uint result[];\n"
672 "} b0;\n"
673 "layout(binding = 4, std430) readonly buffer Buffer4\n"
674 "{\n"
675 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
676 "};\n"
677 "\n"
678 "void main (void)\n"
679 "{\n"
680 " uvec4 mask = subgroupBallot(true);\n"
681 + bdy+
682 " b0.result[gl_VertexID] = tempResult;\n"
683 " float pixelSize = 2.0f/1024.0f;\n"
684 " float pixelPosition = pixelSize/2.0f - 1.0f;\n"
685 " gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
686 " gl_PointSize = 1.0f;\n"
687 "}\n";
688 programCollection.add("vert") << glu::VertexSource(vertex);
689 }
690
691 {
692 const std::string tesc =
693 "${VERSION_DECL}\n"
694 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
695 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
696 "layout(vertices=1) out;\n"
697 "layout(binding = 1, std430) buffer Buffer1\n"
698 "{\n"
699 " uint result[];\n"
700 "} b1;\n"
701 "layout(binding = 4, std430) readonly buffer Buffer4\n"
702 "{\n"
703 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
704 "};\n"
705 "\n"
706 "void main (void)\n"
707 "{\n"
708 " uvec4 mask = subgroupBallot(true);\n"
709 + bdy +
710 " b1.result[gl_PrimitiveID] = tempResult;\n"
711 " if (gl_InvocationID == 0)\n"
712 " {\n"
713 " gl_TessLevelOuter[0] = 1.0f;\n"
714 " gl_TessLevelOuter[1] = 1.0f;\n"
715 " }\n"
716 " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
717 "}\n";
718 programCollection.add("tesc") << glu::TessellationControlSource(tesc);
719 }
720
721 {
722 const std::string tese =
723 "${VERSION_DECL}\n"
724 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
725 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
726 "layout(isolines) in;\n"
727 "layout(binding = 2, std430) buffer Buffer2\n"
728 "{\n"
729 " uint result[];\n"
730 "} b2;\n"
731 "layout(binding = 4, std430) readonly buffer Buffer4\n"
732 "{\n"
733 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
734 "};\n"
735 "\n"
736 "void main (void)\n"
737 "{\n"
738 " uvec4 mask = subgroupBallot(true);\n"
739 + bdy +
740 " b2.result[gl_PrimitiveID * 2 + int(gl_TessCoord.x + 0.5)] = tempResult;\n"
741 " float pixelSize = 2.0f/1024.0f;\n"
742 " gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
743 "}\n";
744 programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
745 }
746
747 {
748 const std::string geometry =
749 // version added by addGeometryShadersFromTemplate
750 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
751 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
752 "layout(${TOPOLOGY}) in;\n"
753 "layout(points, max_vertices = 1) out;\n"
754 "layout(binding = 3, std430) buffer Buffer3\n"
755 "{\n"
756 " uint result[];\n"
757 "} b3;\n"
758 "layout(binding = 4, std430) readonly buffer Buffer4\n"
759 "{\n"
760 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
761 "};\n"
762 "\n"
763 "void main (void)\n"
764 "{\n"
765 " uvec4 mask = subgroupBallot(true);\n"
766 + bdy +
767 " b3.result[gl_PrimitiveIDIn] = tempResult;\n"
768 " gl_Position = gl_in[0].gl_Position;\n"
769 " EmitVertex();\n"
770 " EndPrimitive();\n"
771 "}\n";
772 subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
773 }
774
775 {
776 const std::string fragment =
777 "${VERSION_DECL}\n"
778 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
779 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
780 "precision highp int;\n"
781 "precision highp float;\n"
782 "layout(location = 0) out uint result;\n"
783 "layout(binding = 4, std430) readonly buffer Buffer4\n"
784 "{\n"
785 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
786 "};\n"
787 "void main (void)\n"
788 "{\n"
789 " uvec4 mask = subgroupBallot(true);\n"
790 + bdy +
791 " result = tempResult;\n"
792 "}\n";
793 programCollection.add("fragment") << glu::FragmentSource(fragment);
794 }
795 subgroups::addNoSubgroupShader(programCollection);
796 }
797 }
798
supportedCheck(Context & context,CaseDefinition caseDef)799 void supportedCheck (Context& context, CaseDefinition caseDef)
800 {
801 if (!subgroups::isSubgroupSupported(context))
802 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
803
804 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_ARITHMETIC_BIT))
805 {
806 TCU_THROW(NotSupportedError, "Device does not support subgroup arithmetic operations");
807 }
808
809 if (subgroups::isDoubleFormat(caseDef.format) &&
810 !subgroups::isDoubleSupportedForDevice(context))
811 {
812 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
813 }
814 }
815
noSSBOtest(Context & context,const CaseDefinition caseDef)816 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
817 {
818 if (!subgroups::areSubgroupOperationsSupportedForStage(
819 context, caseDef.shaderStage))
820 {
821 if (subgroups::areSubgroupOperationsRequiredForStage(
822 caseDef.shaderStage))
823 {
824 return tcu::TestStatus::fail(
825 "Shader stage " +
826 subgroups::getShaderStageName(caseDef.shaderStage) +
827 " is required to support subgroup operations!");
828 }
829 else
830 {
831 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
832 }
833 }
834
835 subgroups::SSBOData inputData;
836 inputData.format = caseDef.format;
837 inputData.layout = subgroups::SSBOData::LayoutStd140;
838 inputData.numElements = subgroups::maxSupportedSubgroupSize();
839 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
840 inputData.binding = 0u;
841
842 if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
843 return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
844 else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
845 return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
846 else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
847 return subgroups::makeTessellationEvaluationFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_CONTROL_BIT);
848 else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
849 return subgroups::makeTessellationEvaluationFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_EVALUATION_BIT);
850 else
851 TCU_THROW(InternalError, "Unhandled shader stage");
852 }
853
checkShaderStages(Context & context,const CaseDefinition & caseDef)854 bool checkShaderStages (Context& context, const CaseDefinition& caseDef)
855 {
856 if (!subgroups::areSubgroupOperationsSupportedForStage(
857 context, caseDef.shaderStage))
858 {
859 if (subgroups::areSubgroupOperationsRequiredForStage(
860 caseDef.shaderStage))
861 {
862 return false;
863 }
864 else
865 {
866 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
867 }
868 }
869 return true;
870 }
871
test(Context & context,const CaseDefinition caseDef)872 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
873 {
874 if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
875 {
876 if(!checkShaderStages(context,caseDef))
877 {
878 return tcu::TestStatus::fail(
879 "Shader stage " +
880 subgroups::getShaderStageName(caseDef.shaderStage) +
881 " is required to support subgroup operations!");
882 }
883 subgroups::SSBOData inputData;
884 inputData.format = caseDef.format;
885 inputData.layout = subgroups::SSBOData::LayoutStd430;
886 inputData.numElements = subgroups::maxSupportedSubgroupSize();
887 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
888 inputData.binding = 1u;
889
890 return subgroups::makeComputeTest(context, FORMAT_R32_UINT, &inputData, 1, checkComputeStage);
891 }
892 else
893 {
894 int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
895
896 ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
897
898 if ( SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
899 {
900 if ( (stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
901 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
902 else
903 stages = SHADER_STAGE_FRAGMENT_BIT;
904 }
905
906 if ((ShaderStageFlags)0u == stages)
907 TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
908
909 subgroups::SSBOData inputData;
910 inputData.format = caseDef.format;
911 inputData.layout = subgroups::SSBOData::LayoutStd430;
912 inputData.numElements = subgroups::maxSupportedSubgroupSize();
913 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
914 inputData.binding = 4u;
915 inputData.stages = stages;
916
917 return subgroups::allStages(context, FORMAT_R32_UINT, &inputData,
918 1, checkVertexPipelineStages, stages);
919 }
920 }
921 }
922
createSubgroupsArithmeticTests(deqp::Context & testCtx)923 deqp::TestCaseGroup* createSubgroupsArithmeticTests(deqp::Context& testCtx)
924 {
925 de::MovePtr<deqp::TestCaseGroup> graphicGroup(new deqp::TestCaseGroup(
926 testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
927 de::MovePtr<deqp::TestCaseGroup> computeGroup(new deqp::TestCaseGroup(
928 testCtx, "compute", "Subgroup arithmetic category tests: compute"));
929 de::MovePtr<deqp::TestCaseGroup> framebufferGroup(new deqp::TestCaseGroup(
930 testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
931
932 const ShaderStageFlags stages[] =
933 {
934 SHADER_STAGE_VERTEX_BIT,
935 SHADER_STAGE_TESS_EVALUATION_BIT,
936 SHADER_STAGE_TESS_CONTROL_BIT,
937 SHADER_STAGE_GEOMETRY_BIT,
938 };
939
940 const Format formats[] =
941 {
942 FORMAT_R32_SINT, FORMAT_R32G32_SINT, FORMAT_R32G32B32_SINT,
943 FORMAT_R32G32B32A32_SINT, FORMAT_R32_UINT, FORMAT_R32G32_UINT,
944 FORMAT_R32G32B32_UINT, FORMAT_R32G32B32A32_UINT,
945 FORMAT_R32_SFLOAT, FORMAT_R32G32_SFLOAT,
946 FORMAT_R32G32B32_SFLOAT, FORMAT_R32G32B32A32_SFLOAT,
947 FORMAT_R64_SFLOAT, FORMAT_R64G64_SFLOAT,
948 FORMAT_R64G64B64_SFLOAT, FORMAT_R64G64B64A64_SFLOAT,
949 FORMAT_R32_BOOL, FORMAT_R32G32_BOOL,
950 FORMAT_R32G32B32_BOOL, FORMAT_R32G32B32A32_BOOL,
951 };
952
953 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
954 {
955 const Format format = formats[formatIndex];
956
957 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
958 {
959 bool isBool = false;
960 bool isFloat = false;
961
962 switch (format)
963 {
964 default:
965 break;
966 case FORMAT_R32_SFLOAT:
967 case FORMAT_R32G32_SFLOAT:
968 case FORMAT_R32G32B32_SFLOAT:
969 case FORMAT_R32G32B32A32_SFLOAT:
970 case FORMAT_R64_SFLOAT:
971 case FORMAT_R64G64_SFLOAT:
972 case FORMAT_R64G64B64_SFLOAT:
973 case FORMAT_R64G64B64A64_SFLOAT:
974 isFloat = true;
975 break;
976 case FORMAT_R32_BOOL:
977 case FORMAT_R32G32_BOOL:
978 case FORMAT_R32G32B32_BOOL:
979 case FORMAT_R32G32B32A32_BOOL:
980 isBool = true;
981 break;
982 }
983
984 bool isBitwiseOp = false;
985
986 switch (opTypeIndex)
987 {
988 default:
989 break;
990 case OPTYPE_AND:
991 case OPTYPE_INCLUSIVE_AND:
992 case OPTYPE_EXCLUSIVE_AND:
993 case OPTYPE_OR:
994 case OPTYPE_INCLUSIVE_OR:
995 case OPTYPE_EXCLUSIVE_OR:
996 case OPTYPE_XOR:
997 case OPTYPE_INCLUSIVE_XOR:
998 case OPTYPE_EXCLUSIVE_XOR:
999 isBitwiseOp = true;
1000 break;
1001 }
1002
1003 if (isFloat && isBitwiseOp)
1004 {
1005 // Skip float with bitwise category.
1006 continue;
1007 }
1008
1009 if (isBool && !isBitwiseOp)
1010 {
1011 // Skip bool when its not the bitwise category.
1012 continue;
1013 }
1014 std::string op = getOpTypeName(opTypeIndex);
1015
1016 {
1017 const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT, format};
1018 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(),
1019 de::toLower(op) + "_" +
1020 subgroups::getFormatNameForGLSL(format),
1021 "", supportedCheck, initPrograms, test, caseDef);
1022 }
1023
1024 {
1025 const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_ALL_GRAPHICS, format};
1026 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(graphicGroup.get(),
1027 de::toLower(op) + "_" +
1028 subgroups::getFormatNameForGLSL(format),
1029 "", supportedCheck, initPrograms, test, caseDef);
1030 }
1031
1032 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
1033 {
1034 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
1035 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(framebufferGroup.get(), de::toLower(op) + "_" + subgroups::getFormatNameForGLSL(format) +
1036 "_" + getShaderStageName(caseDef.shaderStage), "",
1037 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
1038 }
1039 }
1040 }
1041
1042 de::MovePtr<deqp::TestCaseGroup> group(new deqp::TestCaseGroup(
1043 testCtx, "arithmetic", "Subgroup arithmetic category tests"));
1044
1045 group->addChild(graphicGroup.release());
1046 group->addChild(computeGroup.release());
1047 group->addChild(framebufferGroup.release());
1048
1049 return group.release();
1050 }
1051
1052 } // subgroups
1053 } // glc
1054