1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2019 The Khronos Group Inc.
6 * Copyright (c) 2018-2020 NVIDIA Corporation
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
22 * \brief Vulkan Reconvergence tests
23 *//*--------------------------------------------------------------------*/
24
25 #include "vktReconvergenceTests.hpp"
26
27 #include "vkBufferWithMemory.hpp"
28 #include "vkImageWithMemory.hpp"
29 #include "vkQueryUtil.hpp"
30 #include "vkBuilderUtil.hpp"
31 #include "vkCmdUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkObjUtil.hpp"
34
35 #include "vktTestGroupUtil.hpp"
36 #include "vktTestCase.hpp"
37
38 #include "deDefs.h"
39 #include "deFloat16.h"
40 #include "deMath.h"
41 #include "deRandom.h"
42 #include "deSharedPtr.hpp"
43 #include "deString.h"
44
45 #include "tcuTestCase.hpp"
46 #include "tcuTestLog.hpp"
47
48 #include <bitset>
49 #include <string>
50 #include <sstream>
51 #include <set>
52 #include <vector>
53
54 namespace vkt
55 {
56 namespace Reconvergence
57 {
58 namespace
59 {
60 using namespace vk;
61 using namespace std;
62
// Number of elements in a statically-sized C array.
#define ARRAYSIZE(x) (sizeof(x) / sizeof(x[0]))

// These tests only exercise the compute stage.
const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
66
// Flavor of reconvergence guarantee exercised by a test case.
typedef enum {
	TT_SUCF_ELECT,	// subgroup_uniform_control_flow using elect (subgroup_basic)
	TT_SUCF_BALLOT,	// subgroup_uniform_control_flow using ballot (subgroup_ballot)
	TT_WUCF_ELECT,	// workgroup uniform control flow using elect (subgroup_basic)
	TT_WUCF_BALLOT,	// workgroup uniform control flow using ballot (subgroup_ballot)
	TT_MAXIMAL,		// maximal reconvergence
} TestType;
74
75 struct CaseDef
76 {
77 TestType testType;
78 deUint32 maxNesting;
79 deUint32 seed;
80
isWUCFvkt::Reconvergence::__anon5fc7ae020111::CaseDef81 bool isWUCF() const { return testType == TT_WUCF_ELECT || testType == TT_WUCF_BALLOT; }
isSUCFvkt::Reconvergence::__anon5fc7ae020111::CaseDef82 bool isSUCF() const { return testType == TT_SUCF_ELECT || testType == TT_SUCF_BALLOT; }
isUCFvkt::Reconvergence::__anon5fc7ae020111::CaseDef83 bool isUCF() const { return isWUCF() || isSUCF(); }
isElectvkt::Reconvergence::__anon5fc7ae020111::CaseDef84 bool isElect() const { return testType == TT_WUCF_ELECT || testType == TT_SUCF_ELECT; }
85 };
86
subgroupSizeToMask(deUint32 subgroupSize)87 deUint64 subgroupSizeToMask(deUint32 subgroupSize)
88 {
89 if (subgroupSize == 64)
90 return ~0ULL;
91 else
92 return (1ULL << subgroupSize) - 1;
93 }
94
// One bit per invocation, covering the full 128-invocation workgroup used by these tests.
typedef std::bitset<128> bitset128;
96
97 // Take a 64-bit integer, mask it to the subgroup size, and then
98 // replicate it for each subgroup
bitsetFromU64(deUint64 mask,deUint32 subgroupSize)99 bitset128 bitsetFromU64(deUint64 mask, deUint32 subgroupSize)
100 {
101 mask &= subgroupSizeToMask(subgroupSize);
102 bitset128 result(mask);
103 for (deUint32 i = 0; i < 128 / subgroupSize - 1; ++i)
104 {
105 result = (result << subgroupSize) | bitset128(mask);
106 }
107 return result;
108 }
109
110 // Pick out the mask for the subgroup that invocationID is a member of
bitsetToU64(const bitset128 & bitset,deUint32 subgroupSize,deUint32 invocationID)111 deUint64 bitsetToU64(const bitset128 &bitset, deUint32 subgroupSize, deUint32 invocationID)
112 {
113 bitset128 copy(bitset);
114 copy >>= (invocationID / subgroupSize) * subgroupSize;
115 copy &= bitset128(subgroupSizeToMask(subgroupSize));
116 deUint64 mask = copy.to_ullong();
117 mask &= subgroupSizeToMask(subgroupSize);
118 return mask;
119 }
120
// Test instance that runs one randomly generated reconvergence program on the device.
class ReconvergenceTestInstance : public TestInstance
{
public:
	ReconvergenceTestInstance	(Context& context, const CaseDef& data);
	~ReconvergenceTestInstance	(void);
	tcu::TestStatus	iterate		(void);
private:
	CaseDef	m_data;	// parameters of the test case being run
};

ReconvergenceTestInstance::ReconvergenceTestInstance (Context& context, const CaseDef& data)
	: vkt::TestInstance	(context)
	, m_data			(data)
{
}

ReconvergenceTestInstance::~ReconvergenceTestInstance (void)
{
}
140
// Test case node: builds shaders, checks feature support, and creates the instance.
class ReconvergenceTestCase : public TestCase
{
public:
	ReconvergenceTestCase			(tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
	~ReconvergenceTestCase			(void);
	virtual	void			initPrograms	(SourceCollections& programCollection) const;
	virtual TestInstance*	createInstance	(Context& context) const;
	virtual void			checkSupport	(Context& context) const;

private:
	CaseDef	m_data;	// parameters of the test case
};

ReconvergenceTestCase::ReconvergenceTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
	: vkt::TestCase	(context, name, desc)
	, m_data		(data)
{
}

ReconvergenceTestCase::~ReconvergenceTestCase (void)
{
}
163
checkSupport(Context & context) const164 void ReconvergenceTestCase::checkSupport(Context& context) const
165 {
166 if (!context.contextSupports(vk::ApiVersion(1, 1, 0)))
167 TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
168
169 vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
170 deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
171 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
172
173 vk::VkPhysicalDeviceProperties2 properties2;
174 deMemset(&properties2, 0, sizeof(properties2));
175 properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
176 properties2.pNext = &subgroupProperties;
177
178 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties2);
179
180 if (m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BASIC_BIT))
181 TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BASIC_BIT not supported");
182
183 if (!m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BALLOT_BIT))
184 TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BALLOT_BIT not supported");
185
186 if (!(context.getSubgroupProperties().supportedStages & VK_SHADER_STAGE_COMPUTE_BIT))
187 TCU_THROW(NotSupportedError, "compute stage does not support subgroup operations");
188
189 // Both subgroup- AND workgroup-uniform tests are enabled by shaderSubgroupUniformControlFlow.
190 if (m_data.isUCF() && !context.getShaderSubgroupUniformControlFlowFeatures().shaderSubgroupUniformControlFlow)
191 TCU_THROW(NotSupportedError, "shaderSubgroupUniformControlFlow not supported");
192
193 // XXX TODO: Check for maximal reconvergence support
194 // if (m_data.testType == TT_MAXIMAL ...)
195 }
196
// Opcodes of the randomly generated structured-control-flow program.
typedef enum
{
	// store subgroupBallot().
	// For OP_BALLOT, OP::caseValue is initialized to zero, and then
	// set to 1 by simulate if the ballot is not workgroup- (or subgroup-) uniform.
	// Only workgroup-uniform ballots are validated for correctness in
	// WUCF modes.
	OP_BALLOT,

	// store literal constant
	OP_STORE,

	// if ((1ULL << gl_SubgroupInvocationID) & mask).
	// Special case if mask = ~0ULL, converted into "if (inputA.a[idx] == idx)"
	OP_IF_MASK,
	OP_ELSE_MASK,
	OP_ENDIF,

	// if (gl_SubgroupInvocationID == loopIdxN) (where N is most nested loop counter)
	OP_IF_LOOPCOUNT,
	OP_ELSE_LOOPCOUNT,

	// if (gl_LocalInvocationIndex >= inputA.a[N]) (where N is most nested loop counter)
	OP_IF_LOCAL_INVOCATION_INDEX,
	OP_ELSE_LOCAL_INVOCATION_INDEX,

	// break/continue
	OP_BREAK,
	OP_CONTINUE,

	// if (subgroupElect())
	OP_ELECT,

	// Loop with uniform number of iterations (read from a buffer)
	OP_BEGIN_FOR_UNIF,
	OP_END_FOR_UNIF,

	// for (int loopIdxN = 0; loopIdxN < gl_SubgroupInvocationID + 1; ++loopIdxN)
	OP_BEGIN_FOR_VAR,
	OP_END_FOR_VAR,

	// for (int loopIdxN = 0;; ++loopIdxN, OP_BALLOT)
	// Always has an "if (subgroupElect()) break;" inside.
	// Does the equivalent of OP_BALLOT in the continue construct
	OP_BEGIN_FOR_INF,
	OP_END_FOR_INF,

	// do { loopIdxN++; ... } while (loopIdxN < uniformValue);
	OP_BEGIN_DO_WHILE_UNIF,
	OP_END_DO_WHILE_UNIF,

	// do { ... } while (true);
	// Always has an "if (subgroupElect()) break;" inside
	OP_BEGIN_DO_WHILE_INF,
	OP_END_DO_WHILE_INF,

	// return;
	OP_RETURN,

	// function call (code bracketed by these is extracted into a separate function)
	OP_CALL_BEGIN,
	OP_CALL_END,

	// switch statement on uniform value
	OP_SWITCH_UNIF_BEGIN,
	// switch statement on gl_SubgroupInvocationID & 3 value
	OP_SWITCH_VAR_BEGIN,
	// switch statement on loopIdx value
	OP_SWITCH_LOOP_COUNT_BEGIN,

	// case statement with a (invocation mask, case mask) pair
	OP_CASE_MASK_BEGIN,
	// case statement used for loop counter switches, with a value and a mask of loop iterations
	OP_CASE_LOOP_COUNT_BEGIN,

	// end of switch/case statement
	OP_SWITCH_END,
	OP_CASE_END,

	// Extra code with no functional effect. Currently includes:
	// - value 0: while (!subgroupElect()) {}
	// - value 1: if (condition_that_is_false) { infinite loop }
	OP_NOISE,
} OPType;
281
// Different if test conditions
typedef enum
{
	IF_MASK,					// condition derived from a random invocation mask
	IF_UNIFORM,					// condition that is uniformly true
	IF_LOOPCOUNT,				// condition compares against the innermost loop counter
	IF_LOCAL_INVOCATION_INDEX,	// condition compares gl_LocalInvocationIndex against a constant
} IFType;
290
291 class OP
292 {
293 public:
OP(OPType _type,deUint64 _value,deUint32 _caseValue=0)294 OP(OPType _type, deUint64 _value, deUint32 _caseValue = 0)
295 : type(_type), value(_value), caseValue(_caseValue)
296 {}
297
298 // The type of operation and an optional value.
299 // The value could be a mask for an if test, the index of the loop
300 // header for an end of loop, or the constant value for a store instruction
301 OPType type;
302 deUint64 value;
303 deUint32 caseValue;
304 };
305
findLSB(deUint64 value)306 static int findLSB (deUint64 value)
307 {
308 for (int i = 0; i < 64; i++)
309 {
310 if (value & (1ULL<<i))
311 return i;
312 }
313 return -1;
314 }
315
316 // For each subgroup, pick out the elected invocationID, and accumulate
317 // a bitset of all of them
bitsetElect(const bitset128 & value,deInt32 subgroupSize)318 static bitset128 bitsetElect (const bitset128& value, deInt32 subgroupSize)
319 {
320 bitset128 ret; // zero initialized
321
322 for (deInt32 i = 0; i < 128; i += subgroupSize)
323 {
324 deUint64 mask = bitsetToU64(value, subgroupSize, i);
325 int lsb = findLSB(mask);
326 ret |= bitset128(lsb == -1 ? 0 : (1ULL << lsb)) << i;
327 }
328 return ret;
329 }
330
// Generates a random structured-control-flow program as a vector of OPs,
// emits GLSL for it (genCode), and simulates its expected ballot results.
class RandomProgram
{
public:
	RandomProgram(const CaseDef &c)
		: caseDef(c), numMasks(5), nesting(0), maxNesting(c.maxNesting), loopNesting(0), loopNestingThisFunction(0), callNesting(0), minCount(30), indent(0), isLoopInf(100, false), doneInfLoopBreak(100, false), storeBase(0x10000)
	{
		deRandom_init(&rnd, caseDef.seed);
		// Pre-generate the pool of random invocation masks used by if/case ops.
		for (int i = 0; i < numMasks; ++i)
			masks.push_back(deRandom_getUint64(&rnd));
	}

	const CaseDef caseDef;
	deRandom rnd;				// seeded PRNG; all randomness flows through this
	vector<OP> ops;				// the generated program
	vector<deUint64> masks;		// pool of random invocation masks
	deInt32 numMasks;
	deInt32 nesting;			// current control-flow nesting depth
	deInt32 maxNesting;			// limit on control-flow nesting depth
	deInt32 loopNesting;		// current loop nesting depth (across all functions)
	deInt32 loopNestingThisFunction;	// loop nesting depth within the current function
	deInt32 callNesting;		// current function-call nesting depth
	deInt32 minCount;			// minimum number of ops in a generated program
	deInt32 indent;				// current indentation for emitted GLSL
	vector<bool> isLoopInf;			// per loop-nesting level: is it an infinite loop?
	vector<bool> doneInfLoopBreak;	// per loop-nesting level: has the elect+break been emitted yet?
	// Offset the value we use for OP_STORE, to avoid colliding with fully converged
	// active masks with small subgroup sizes (e.g. with subgroupSize == 4, the SUCF
	// tests need to know that 0xF is really an active mask).
	deInt32 storeBase;
360
	// Emit an if (of the requested flavor) with a random "then" body and,
	// half of the time, an "else" body.
	void genIf(IFType ifType)
	{
		deUint32 maskIdx = deRandom_getUint32(&rnd) % numMasks;
		deUint64 mask = masks[maskIdx];
		if (ifType == IF_UNIFORM)
			mask = ~0ULL;

		// Drawn unconditionally (even when unused) so the RNG stream does not
		// depend on the branch flavor.
		deUint32 localIndexCmp = deRandom_getUint32(&rnd) % 128;
		if (ifType == IF_LOCAL_INVOCATION_INDEX)
			ops.push_back({OP_IF_LOCAL_INVOCATION_INDEX, localIndexCmp});
		else if (ifType == IF_LOOPCOUNT)
			ops.push_back({OP_IF_LOOPCOUNT, 0});
		else
			ops.push_back({OP_IF_MASK, mask});

		nesting++;

		// Remember the extent of the "then" block in case we duplicate it below.
		size_t thenBegin = ops.size();
		pickOP(2);
		size_t thenEnd = ops.size();

		deUint32 randElse = (deRandom_getUint32(&rnd) % 100);
		if (randElse < 50)
		{
			if (ifType == IF_LOCAL_INVOCATION_INDEX)
				ops.push_back({OP_ELSE_LOCAL_INVOCATION_INDEX, localIndexCmp});
			else if (ifType == IF_LOOPCOUNT)
				ops.push_back({OP_ELSE_LOOPCOUNT, 0});
			else
				ops.push_back({OP_ELSE_MASK, 0});

			if (randElse < 10)
			{
				// Sometimes make the else block identical to the then block
				for (size_t i = thenBegin; i < thenEnd; ++i)
					ops.push_back(ops[i]);
			}
			else
				pickOP(2);
		}
		ops.push_back({OP_ENDIF, 0});
		nesting--;
	}
404
genForUnif()405 void genForUnif()
406 {
407 deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
408 ops.push_back({OP_BEGIN_FOR_UNIF, iterCount});
409 deUint32 loopheader = (deUint32)ops.size()-1;
410 nesting++;
411 loopNesting++;
412 loopNestingThisFunction++;
413 pickOP(2);
414 ops.push_back({OP_END_FOR_UNIF, loopheader});
415 loopNestingThisFunction--;
416 loopNesting--;
417 nesting--;
418 }
419
genDoWhileUnif()420 void genDoWhileUnif()
421 {
422 deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
423 ops.push_back({OP_BEGIN_DO_WHILE_UNIF, iterCount});
424 deUint32 loopheader = (deUint32)ops.size()-1;
425 nesting++;
426 loopNesting++;
427 loopNestingThisFunction++;
428 pickOP(2);
429 ops.push_back({OP_END_DO_WHILE_UNIF, loopheader});
430 loopNestingThisFunction--;
431 loopNesting--;
432 nesting--;
433 }
434
genForVar()435 void genForVar()
436 {
437 ops.push_back({OP_BEGIN_FOR_VAR, 0});
438 deUint32 loopheader = (deUint32)ops.size()-1;
439 nesting++;
440 loopNesting++;
441 loopNestingThisFunction++;
442 pickOP(2);
443 ops.push_back({OP_END_FOR_VAR, loopheader});
444 loopNestingThisFunction--;
445 loopNesting--;
446 nesting--;
447 }
448
	// Emit an infinite for loop; termination is guaranteed by the embedded
	// "if (subgroupElect()) { ... break; }" generated via genElect(true).
	void genForInf()
	{
		ops.push_back({OP_BEGIN_FOR_INF, 0});
		deUint32 loopheader = (deUint32)ops.size()-1;

		nesting++;
		loopNesting++;
		loopNestingThisFunction++;
		isLoopInf[loopNesting] = true;
		doneInfLoopBreak[loopNesting] = false;

		pickOP(2);

		// The elect+break that makes the loop terminate.
		genElect(true);
		doneInfLoopBreak[loopNesting] = true;

		pickOP(2);

		ops.push_back({OP_END_FOR_INF, loopheader});

		isLoopInf[loopNesting] = false;
		doneInfLoopBreak[loopNesting] = false;
		loopNestingThisFunction--;
		loopNesting--;
		nesting--;
	}
475
	// Emit a "do { ... } while (true);" loop; termination is guaranteed by the
	// embedded "if (subgroupElect()) { ... break; }" generated via genElect(true).
	void genDoWhileInf()
	{
		ops.push_back({OP_BEGIN_DO_WHILE_INF, 0});
		deUint32 loopheader = (deUint32)ops.size()-1;

		nesting++;
		loopNesting++;
		loopNestingThisFunction++;
		isLoopInf[loopNesting] = true;
		doneInfLoopBreak[loopNesting] = false;

		pickOP(2);

		// The elect+break that makes the loop terminate.
		genElect(true);
		doneInfLoopBreak[loopNesting] = true;

		pickOP(2);

		ops.push_back({OP_END_DO_WHILE_INF, loopheader});

		isLoopInf[loopNesting] = false;
		doneInfLoopBreak[loopNesting] = false;
		loopNestingThisFunction--;
		loopNesting--;
		nesting--;
	}
502
genBreak()503 void genBreak()
504 {
505 if (loopNestingThisFunction > 0)
506 {
507 // Sometimes put the break in a divergent if
508 if ((deRandom_getUint32(&rnd) % 100) < 10)
509 {
510 ops.push_back({OP_IF_MASK, masks[0]});
511 ops.push_back({OP_BREAK, 0});
512 ops.push_back({OP_ELSE_MASK, 0});
513 ops.push_back({OP_BREAK, 0});
514 ops.push_back({OP_ENDIF, 0});
515 }
516 else
517 ops.push_back({OP_BREAK, 0});
518 }
519 }
520
	// Emit a continue when legal in the current loop context.
	void genContinue()
	{
		// continues are allowed if we're in a loop and the loop is not infinite,
		// or if it is infinite and we've already done a subgroupElect+break.
		// However, adding more continues seems to reduce the failure rate, so
		// disable it for now (the doneInfLoopBreak clause stays commented out).
		if (loopNestingThisFunction > 0 && !(isLoopInf[loopNesting] /*&& !doneInfLoopBreak[loopNesting]*/))
		{
			// Sometimes put the continue in a divergent if
			if ((deRandom_getUint32(&rnd) % 100) < 10)
			{
				ops.push_back({OP_IF_MASK, masks[0]});
				ops.push_back({OP_CONTINUE, 0});
				ops.push_back({OP_ELSE_MASK, 0});
				ops.push_back({OP_CONTINUE, 0});
				ops.push_back({OP_ENDIF, 0});
			}
			else
				ops.push_back({OP_CONTINUE, 0});
		}
	}
542
	// doBreak is used to generate "if (subgroupElect()) { ... break; }" inside infinite loops
	void genElect(bool doBreak)
	{
		ops.push_back({OP_ELECT, 0});
		nesting++;
		if (doBreak)
		{
			// Put something interesting before the break
			optBallot();
			optBallot();
			if ((deRandom_getUint32(&rnd) % 100) < 10)
				pickOP(1);

			// if we're in a function, sometimes use return instead
			if (callNesting > 0 && (deRandom_getUint32(&rnd) % 100) < 30)
				ops.push_back({OP_RETURN, 0});
			else
				genBreak();

		}
		else
			pickOP(2);

		// Closes the "if (subgroupElect())".
		ops.push_back({OP_ENDIF, 0});
		nesting--;
	}
569
	// Emit a return, with probability scaled by how deeply nested we are.
	void genReturn()
	{
		deUint32 r = deRandom_getUint32(&rnd) % 100;
		if (nesting > 0 &&
			// Use return rarely in main, 20% of the time in a singly nested loop in a function
			// and 50% of the time in a multiply nested loop in a function
			(r < 5 ||
			 (callNesting > 0 && loopNestingThisFunction > 0 && r < 20) ||
			 (callNesting > 0 && loopNestingThisFunction > 1 && r < 50)))
		{
			optBallot();
			// Sometimes put the return in a divergent if (both sides return)
			if ((deRandom_getUint32(&rnd) % 100) < 10)
			{
				ops.push_back({OP_IF_MASK, masks[0]});
				ops.push_back({OP_RETURN, 0});
				ops.push_back({OP_ELSE_MASK, 0});
				ops.push_back({OP_RETURN, 0});
				ops.push_back({OP_ENDIF, 0});
			}
			else
				ops.push_back({OP_RETURN, 0});
		}
	}
593
594 // Generate a function call. Save and restore some loop information, which is used to
595 // determine when it's safe to use break/continue
genCall()596 void genCall()
597 {
598 ops.push_back({OP_CALL_BEGIN, 0});
599 callNesting++;
600 nesting++;
601 deInt32 saveLoopNestingThisFunction = loopNestingThisFunction;
602 loopNestingThisFunction = 0;
603
604 pickOP(2);
605
606 loopNestingThisFunction = saveLoopNestingThisFunction;
607 nesting--;
608 callNesting--;
609 ops.push_back({OP_CALL_END, 0});
610 }
611
	// Generate switch on a uniform value:
	// switch (inputA.a[r]) {
	//   case r+1: ... break; // should not execute
	//   case r:   ... break; // should branch uniformly
	//   case r+2: ... break; // should not execute
	// }
	void genSwitchUnif()
	{
		deUint32 r = deRandom_getUint32(&rnd) % 5;
		ops.push_back({OP_SWITCH_UNIF_BEGIN, r});
		nesting++;

		// case r+1: no invocation takes this (mask 0)
		ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+1)});
		pickOP(1);
		ops.push_back({OP_CASE_END, 0});

		// case r: all invocations take this uniformly (mask ~0)
		ops.push_back({OP_CASE_MASK_BEGIN, ~0ULL, 1u<<r});
		pickOP(2);
		ops.push_back({OP_CASE_END, 0});

		// case r+2: no invocation takes this (mask 0)
		ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+2)});
		pickOP(1);
		ops.push_back({OP_CASE_END, 0});

		ops.push_back({OP_SWITCH_END, 0});
		nesting--;
	}
639
640 // switch (gl_SubgroupInvocationID & 3) with four unique targets
genSwitchVar()641 void genSwitchVar()
642 {
643 ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
644 nesting++;
645
646 ops.push_back({OP_CASE_MASK_BEGIN, 0x1111111111111111ULL, 1<<0});
647 pickOP(1);
648 ops.push_back({OP_CASE_END, 0});
649
650 ops.push_back({OP_CASE_MASK_BEGIN, 0x2222222222222222ULL, 1<<1});
651 pickOP(1);
652 ops.push_back({OP_CASE_END, 0});
653
654 ops.push_back({OP_CASE_MASK_BEGIN, 0x4444444444444444ULL, 1<<2});
655 pickOP(1);
656 ops.push_back({OP_CASE_END, 0});
657
658 ops.push_back({OP_CASE_MASK_BEGIN, 0x8888888888888888ULL, 1<<3});
659 pickOP(1);
660 ops.push_back({OP_CASE_END, 0});
661
662 ops.push_back({OP_SWITCH_END, 0});
663 nesting--;
664 }
665
	// switch (gl_SubgroupInvocationID & 3) with two shared targets.
	// XXX TODO: The test considers these two targets to remain converged,
	// though we haven't agreed to that behavior yet.
	void genSwitchMulticase()
	{
		ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
		nesting++;

		// case 0/case 1 share one body: invocations with low bits 00 or 01
		ops.push_back({OP_CASE_MASK_BEGIN, 0x3333333333333333ULL, (1<<0)|(1<<1)});
		pickOP(2);
		ops.push_back({OP_CASE_END, 0});

		// case 2/case 3 share one body: invocations with low bits 10 or 11
		ops.push_back({OP_CASE_MASK_BEGIN, 0xCCCCCCCCCCCCCCCCULL, (1<<2)|(1<<3)});
		pickOP(2);
		ops.push_back({OP_CASE_END, 0});

		ops.push_back({OP_SWITCH_END, 0});
		nesting--;
	}
685
686 // switch (loopIdxN) {
687 // case 1: ... break;
688 // case 2: ... break;
689 // default: ... break;
690 // }
genSwitchLoopCount()691 void genSwitchLoopCount()
692 {
693 deUint32 r = deRandom_getUint32(&rnd) % loopNesting;
694 ops.push_back({OP_SWITCH_LOOP_COUNT_BEGIN, r});
695 nesting++;
696
697 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<1, 1});
698 pickOP(1);
699 ops.push_back({OP_CASE_END, 0});
700
701 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<2, 2});
702 pickOP(1);
703 ops.push_back({OP_CASE_END, 0});
704
705 // default:
706 ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, ~6ULL, 0xFFFFFFFF});
707 pickOP(1);
708 ops.push_back({OP_CASE_END, 0});
709
710 ops.push_back({OP_SWITCH_END, 0});
711 nesting--;
712 }
713
	// Append roughly "count" random constructs to the program. The generators
	// invoked below recurse back into pickOP, so "count" only seeds the
	// process rather than fixing an exact instruction count.
	void pickOP(deUint32 count)
	{
		// Pick "count" instructions. These can recursively insert more instructions,
		// so "count" is just a seed
		for (deUint32 i = 0; i < count; ++i)
		{
			optBallot();
			if (nesting < maxNesting)
			{
				deUint32 r = deRandom_getUint32(&rnd) % 11;
				switch (r)
				{
				default:
					DE_ASSERT(0);
					// fallthrough
				case 2:
					// loop-counter if only makes sense inside a loop;
					// otherwise fall through to a different if flavor
					if (loopNesting)
					{
						genIf(IF_LOOPCOUNT);
						break;
					}
					// fallthrough
				case 10:
					genIf(IF_LOCAL_INVOCATION_INDEX);
					break;
				case 0:
					genIf(IF_MASK);
					break;
				case 1:
					genIf(IF_UNIFORM);
					break;
				case 3:
					{
						// don't nest loops too deeply, to avoid extreme memory usage or timeouts
						if (loopNesting <= 3)
						{
							deUint32 r2 = deRandom_getUint32(&rnd) % 3;
							switch (r2)
							{
							default: DE_ASSERT(0); // fallthrough
							case 0: genForUnif(); break;
							case 1: genForInf(); break;
							case 2: genForVar(); break;
							}
						}
					}
					break;
				case 4:
					genBreak();
					break;
				case 5:
					genContinue();
					break;
				case 6:
					genElect(false);
					break;
				case 7:
					{
						// calls are rare, never nested, and need headroom in the
						// nesting budget; otherwise emit a return instead
						deUint32 r2 = deRandom_getUint32(&rnd) % 5;
						if (r2 == 0 && callNesting == 0 && nesting < maxNesting - 2)
							genCall();
						else
							genReturn();
						break;
					}
				case 8:
					{
						// don't nest loops too deeply, to avoid extreme memory usage or timeouts
						if (loopNesting <= 3)
						{
							deUint32 r2 = deRandom_getUint32(&rnd) % 2;
							switch (r2)
							{
							default: DE_ASSERT(0); // fallthrough
							case 0: genDoWhileUnif(); break;
							case 1: genDoWhileInf(); break;
							}
						}
					}
					break;
				case 9:
					{
						deUint32 r2 = deRandom_getUint32(&rnd) % 4;
						switch (r2)
						{
						default:
							DE_ASSERT(0);
							// fallthrough
						case 0:
							genSwitchUnif();
							break;
						case 1:
							// loop-count switch requires an enclosing loop
							if (loopNesting > 0) {
								genSwitchLoopCount();
								break;
							}
							// fallthrough
						case 2:
							if (caseDef.testType != TT_MAXIMAL)
							{
								// multicase doesn't have fully-defined behavior for MAXIMAL tests,
								// but does for SUCF tests
								genSwitchMulticase();
								break;
							}
							// fallthrough
						case 3:
							genSwitchVar();
							break;
						}
					}
					break;
				}
			}
			optBallot();
		}
	}
831
	// Randomly interleave ballots, stores, and functional no-op "noise" into
	// the program stream.
	void optBallot()
	{
		// optionally insert ballots, stores, and noise. Ballots and stores are used to determine
		// correctness.
		if ((deRandom_getUint32(&rnd) % 100) < 20)
		{
			// Avoid emitting back-to-back ballot(+store) pairs; they add no
			// new information.
			if (ops.size() < 2 ||
				!(ops[ops.size()-1].type == OP_BALLOT ||
				  (ops[ops.size()-1].type == OP_STORE && ops[ops.size()-2].type == OP_BALLOT)))
			{
				// do a store along with each ballot, so we can correlate where
				// the ballot came from
				if (caseDef.testType != TT_MAXIMAL)
					ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
				ops.push_back({OP_BALLOT, 0});
			}
		}

		if ((deRandom_getUint32(&rnd) % 100) < 10)
		{
			// Likewise avoid redundant adjacent stores.
			if (ops.size() < 2 ||
				!(ops[ops.size()-1].type == OP_STORE ||
				  (ops[ops.size()-1].type == OP_BALLOT && ops[ops.size()-2].type == OP_STORE)))
			{
				// SUCF does a store with every ballot. Don't bloat the code by adding more.
				if (caseDef.testType == TT_MAXIMAL)
					ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
			}
		}

		// Very rarely, insert functionally inert noise code.
		deUint32 r = deRandom_getUint32(&rnd) % 10000;
		if (r < 3)
			ops.push_back({OP_NOISE, 0});
		else if (r < 10)
			ops.push_back({OP_NOISE, 1});
	}
868
	// Build a fresh random program of at least minCount ops. For UCF flavors,
	// regenerate until the program contains at least one uniform ballot worth
	// validating.
	void generateRandomProgram()
	{
		do {
			ops.clear();
			while ((deInt32)ops.size() < minCount)
				pickOP(1);

			// Retry until the program has some UCF results in it
			if (caseDef.isUCF())
			{
				const deUint32 invocationStride = 128;
				// Simulate for all subgroup sizes, to determine whether OP_BALLOTs are nonuniform
				for (deInt32 subgroupSize = 4; subgroupSize <= 64; subgroupSize *= 2) {
					simulate(true, subgroupSize, invocationStride, DE_NULL);
				}
			}
		} while (caseDef.isUCF() && !hasUCF());
	}
887
printIndent(std::stringstream & css)888 void printIndent(std::stringstream &css)
889 {
890 for (deInt32 i = 0; i < indent; ++i)
891 css << " ";
892 }
893
genPartitionBallot()894 std::string genPartitionBallot()
895 {
896 std::stringstream ss;
897 ss << "subgroupBallot(true).xy";
898 return ss.str();
899 }
900
	// Emit GLSL that records one ballot result: bump the per-invocation
	// output counter and write a ballot (or elect) value into outputB.
	void printBallot(std::stringstream *css)
	{
		*css << "outputC.loc[gl_LocalInvocationIndex]++,";
		// When inside loop(s), use partitionBallot rather than subgroupBallot to compute
		// a ballot, to make sure the ballot is "diverged enough". Don't do this for
		// subgroup_uniform_control_flow, since we only validate results that must be fully
		// reconverged.
		if (loopNesting > 0 && caseDef.testType == TT_MAXIMAL)
		{
			*css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = " << genPartitionBallot();
		}
		else if (caseDef.isElect())
		{
			// elect() result goes into the .x component only
			*css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = elect()";
		}
		else
		{
			*css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = subgroupBallot(true).xy";
		}
	}
921
	// Emit GLSL for the generated op stream. Output goes into "main" by
	// default; OP_CALL_BEGIN/OP_CALL_END redirect it into "functions" so the
	// bracketed ops become a standalone helper function.
	void genCode(std::stringstream &functions, std::stringstream &main)
	{
		std::stringstream *css = &main;	// current output stream
		indent = 4;
		loopNesting = 0;
		int funcNum = 0;
		for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
		{
			switch (ops[i].type)
			{
			case OP_IF_MASK:
				printIndent(*css);
				if (ops[i].value == ~0ULL)
				{
					// This equality test will always succeed, since inputA.a[i] == i
					int idx = deRandom_getUint32(&rnd) % 4;
					*css << "if (inputA.a[" << idx << "] == " << idx << ") {\n";
				}
				else
					*css << "if (testBit(uvec2(0x" << std::hex << (ops[i].value & 0xFFFFFFFF) << ", 0x" << (ops[i].value >> 32) << "), gl_SubgroupInvocationID)) {\n";

				indent += 4;
				break;
			case OP_IF_LOOPCOUNT:
				printIndent(*css); *css << "if (gl_SubgroupInvocationID == loopIdx" << loopNesting - 1 << ") {\n";
				indent += 4;
				break;
			case OP_IF_LOCAL_INVOCATION_INDEX:
				printIndent(*css); *css << "if (gl_LocalInvocationIndex >= inputA.a[0x" << std::hex << ops[i].value << "]) {\n";
				indent += 4;
				break;
			case OP_ELSE_MASK:
			case OP_ELSE_LOOPCOUNT:
			case OP_ELSE_LOCAL_INVOCATION_INDEX:
				indent -= 4;
				printIndent(*css); *css << "} else {\n";
				indent += 4;
				break;
			case OP_ENDIF:
				indent -= 4;
				printIndent(*css); *css << "}\n";
				break;
			case OP_BALLOT:
				printIndent(*css); printBallot(css); *css << ";\n";
				break;
			case OP_STORE:
				// bump the output counter and store the op's literal value
				printIndent(*css); *css << "outputC.loc[gl_LocalInvocationIndex]++;\n";
				printIndent(*css); *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = 0x" << std::hex << ops[i].value << ";\n";
				break;
			case OP_BEGIN_FOR_UNIF:
				printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
				printIndent(*css); *css << "     loopIdx" << loopNesting << " < inputA.a[" << ops[i].value << "];\n";
				printIndent(*css); *css << "     loopIdx" << loopNesting << "++) {\n";
				indent += 4;
				loopNesting++;
				break;
			case OP_END_FOR_UNIF:
				loopNesting--;
				indent -= 4;
				printIndent(*css); *css << "}\n";
				break;
			case OP_BEGIN_DO_WHILE_UNIF:
				// extra scope so the loop counter declaration doesn't leak
				printIndent(*css); *css << "{\n";
				indent += 4;
				printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
				printIndent(*css); *css << "do {\n";
				indent += 4;
				printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
				loopNesting++;
				break;
			case OP_BEGIN_DO_WHILE_INF:
				printIndent(*css); *css << "{\n";
				indent += 4;
				printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
				printIndent(*css); *css << "do {\n";
				indent += 4;
				loopNesting++;
				break;
			case OP_END_DO_WHILE_UNIF:
				loopNesting--;
				indent -= 4;
				// ops[i].value is the index of the loop-header op; its value
				// selects the uniform iteration count from inputA
				printIndent(*css); *css << "} while (loopIdx" << loopNesting << " < inputA.a[" << ops[(deUint32)ops[i].value].value << "]);\n";
				indent -= 4;
				printIndent(*css); *css << "}\n";
				break;
			case OP_END_DO_WHILE_INF:
				loopNesting--;
				printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
				indent -= 4;
				printIndent(*css); *css << "} while (true);\n";
				indent -= 4;
				printIndent(*css); *css << "}\n";
				break;
			case OP_BEGIN_FOR_VAR:
				printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
				printIndent(*css); *css << "     loopIdx" << loopNesting << " < gl_SubgroupInvocationID + 1;\n";
				printIndent(*css); *css << "     loopIdx" << loopNesting << "++) {\n";
				indent += 4;
				loopNesting++;
				break;
			case OP_END_FOR_VAR:
				loopNesting--;
				indent -= 4;
				printIndent(*css); *css << "}\n";
				break;
			case OP_BEGIN_FOR_INF:
				// ballot in the continue construct, per OP_BEGIN_FOR_INF's contract
				printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;;loopIdx" << loopNesting << "++,";
				loopNesting++;
				printBallot(css);
				*css << ") {\n";
				indent += 4;
				break;
			case OP_END_FOR_INF:
				loopNesting--;
				indent -= 4;
				printIndent(*css); *css << "}\n";
				break;
			case OP_BREAK:
				printIndent(*css); *css << "break;\n";
				break;
			case OP_CONTINUE:
				printIndent(*css); *css << "continue;\n";
				break;
			case OP_ELECT:
				printIndent(*css); *css << "if (subgroupElect()) {\n";
				indent += 4;
				break;
			case OP_RETURN:
				printIndent(*css); *css << "return;\n";
				break;
			case OP_CALL_BEGIN:
				// emit the call site in the current stream, then redirect
				// output to "functions" and open the callee's definition;
				// current loop counters are passed as parameters
				printIndent(*css); *css << "func" << funcNum << "(";
				for (deInt32 n = 0; n < loopNesting; ++n)
				{
					*css << "loopIdx" << n;
					if (n != loopNesting - 1)
						*css << ", ";
				}
				*css << ");\n";
				css = &functions;
				printIndent(*css); *css << "void func" << funcNum << "(";
				for (deInt32 n = 0; n < loopNesting; ++n)
				{
					*css << "int loopIdx" << n;
					if (n != loopNesting - 1)
						*css << ", ";
				}
				*css << ") {\n";
				indent += 4;
				funcNum++;
				break;
			case OP_CALL_END:
				// close the callee and resume writing into main
				indent -= 4;
				printIndent(*css); *css << "}\n";
				css = &main;
				break;
			case OP_NOISE:
				if (ops[i].value == 0)
				{
					printIndent(*css); *css << "while (!subgroupElect()) {}\n";
				}
				else
				{
					// condition is always false at runtime (inputA.a[0] == 0)
					printIndent(*css); *css << "if (inputA.a[0] == 12345) {\n";
					indent += 4;
					printIndent(*css); *css << "while (true) {\n";
					indent += 4;
					printIndent(*css); printBallot(css); *css << ";\n";
					indent -= 4;
					printIndent(*css); *css << "}\n";
					indent -= 4;
					printIndent(*css); *css << "}\n";
				}
				break;
			case OP_SWITCH_UNIF_BEGIN:
				printIndent(*css); *css << "switch (inputA.a[" << ops[i].value << "]) {\n";
				indent += 4;
				break;
			case OP_SWITCH_VAR_BEGIN:
				printIndent(*css); *css << "switch (gl_SubgroupInvocationID & 3) {\n";
				indent += 4;
				break;
			case OP_SWITCH_LOOP_COUNT_BEGIN:
				printIndent(*css); *css << "switch (loopIdx" << ops[i].value << ") {\n";
				indent += 4;
				break;
			case OP_SWITCH_END:
				indent -= 4;
				printIndent(*css); *css << "}\n";
				break;
			case OP_CASE_MASK_BEGIN:
				// caseValue is a bitmask of the case labels sharing this body
				for (deInt32 b = 0; b < 32; ++b)
				{
					if ((1u<<b) & ops[i].caseValue)
					{
						printIndent(*css); *css << "case " << b << ":\n";
					}
				}
				printIndent(*css); *css << "{\n";
				indent += 4;
				break;
			case OP_CASE_LOOP_COUNT_BEGIN:
				// caseValue 0xFFFFFFFF marks the default label
				if (ops[i].caseValue == 0xFFFFFFFF)
				{
					printIndent(*css); *css << "default: {\n";
				}
				else
				{
					printIndent(*css); *css << "case " << ops[i].caseValue << ": {\n";
				}
				indent += 4;
				break;
			case OP_CASE_END:
				printIndent(*css); *css << "break;\n";
				indent -= 4;
				printIndent(*css); *css << "}\n";
				break;
			default:
				DE_ASSERT(0);
				break;
			}
		}
	}
1145
1146 // Simulate execution of the program. If countOnly is true, just return
1147 // the max number of outputs written. If it's false, store out the result
1148 // values to ref
// CPU interpreter for the generated op stream. Walks 'ops' while maintaining
// a stack of per-nesting-level active/continue masks (one bit per invocation,
// up to 128 invocations), mirroring how the GPU is expected to reconverge.
//
// If countOnly is true, only counts how many output locations each invocation
// would write. If false, stores the expected ballot/store values into 'ref',
// laid out as ref[loc * invocationStride + invocationId].
// Returns the maximum per-invocation output location count.
deUint32 simulate(bool countOnly, deUint32 subgroupSize, deUint32 invocationStride, deUint64 *ref)
{
	// State of the subgroup at each level of nesting
	struct SubgroupState
	{
		// Currently executing
		bitset128 activeMask;
		// Have executed a continue instruction in this loop
		bitset128 continueMask;
		// index of the current if test or loop header
		deUint32 header;
		// number of loop iterations performed
		deUint32 tripCount;
		// is this nesting a loop?
		deUint32 isLoop;
		// is this nesting a function call?
		deUint32 isCall;
		// is this nesting a switch?
		deUint32 isSwitch;
	};
	SubgroupState stateStack[10];
	deMemset(&stateStack, 0, sizeof(stateStack));

	const deUint64 fullSubgroupMask = subgroupSizeToMask(subgroupSize);

	// Per-invocation output location counters
	deUint32 outLoc[128] = {0};

	// 'nesting' / 'loopNesting' are member state shared with code generation;
	// reset them so simulation starts from the program entry.
	nesting = 0;
	loopNesting = 0;
	stateStack[nesting].activeMask = ~bitset128(); // initialized to ~0

	deInt32 i = 0;
	while (i < (deInt32)ops.size())
	{
		switch (ops[i].type)
		{
		case OP_BALLOT:

			// Flag that this ballot is workgroup-nonuniform
			if (caseDef.isWUCF() && stateStack[nesting].activeMask.any() && !stateStack[nesting].activeMask.all())
				ops[i].caseValue = 1;

			if (caseDef.isSUCF())
			{
				// Check uniformity separately for each subgroup-sized slice of invocations.
				for (deUint32 id = 0; id < 128; id += subgroupSize)
				{
					deUint64 subgroupMask = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
					// Flag that this ballot is subgroup-nonuniform
					if (subgroupMask != 0 && subgroupMask != fullSubgroupMask)
						ops[i].caseValue = 1;
				}
			}

			for (deUint32 id = 0; id < 128; ++id)
			{
				if (stateStack[nesting].activeMask.test(id))
				{
					if (countOnly)
					{
						outLoc[id]++;
					}
					else
					{
						if (ops[i].caseValue)
						{
							// Emit a magic value to indicate that we shouldn't validate this ballot
							ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(0x12345678, subgroupSize, id);
						}
						else
							ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
					}
				}
			}
			break;
		case OP_STORE:
			// Each active invocation writes the op's literal value.
			for (deUint32 id = 0; id < 128; ++id)
			{
				if (stateStack[nesting].activeMask.test(id))
				{
					if (countOnly)
						outLoc[id]++;
					else
						ref[(outLoc[id]++)*invocationStride + id] = ops[i].value;
				}
			}
			break;
		case OP_IF_MASK:
			// Push a new nesting level whose active mask is the literal mask
			// (replicated per subgroup) ANDed with the current mask.
			nesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
			stateStack[nesting].header = i;
			stateStack[nesting].isLoop = 0;
			stateStack[nesting].isSwitch = 0;
			break;
		case OP_ELSE_MASK:
			// Invert the mask used by the matching OP_IF_MASK (found via 'header').
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64(ops[stateStack[nesting].header].value, subgroupSize);
			break;
		case OP_IF_LOOPCOUNT:
		{
			// Condition on the trip count of the innermost enclosing loop:
			// only the invocation whose index equals the trip count is active.
			deUint32 n = nesting;
			while (!stateStack[n].isLoop)
				n--;

			nesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
			stateStack[nesting].header = i;
			stateStack[nesting].isLoop = 0;
			stateStack[nesting].isSwitch = 0;
			break;
		}
		case OP_ELSE_LOOPCOUNT:
		{
			deUint32 n = nesting;
			while (!stateStack[n].isLoop)
				n--;

			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
			break;
		}
		case OP_IF_LOCAL_INVOCATION_INDEX:
		{
			// all bits >= N
			bitset128 mask(0);
			for (deInt32 j = (deInt32)ops[i].value; j < 128; ++j)
				mask.set(j);

			nesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
			stateStack[nesting].header = i;
			stateStack[nesting].isLoop = 0;
			stateStack[nesting].isSwitch = 0;
			break;
		}
		case OP_ELSE_LOCAL_INVOCATION_INDEX:
		{
			// all bits < N
			bitset128 mask(0);
			for (deInt32 j = 0; j < (deInt32)ops[i].value; ++j)
				mask.set(j);

			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
			break;
		}
		case OP_ENDIF:
			nesting--;
			break;
		case OP_BEGIN_FOR_UNIF:
			// XXX TODO: We don't handle a for loop with zero iterations
			nesting++;
			loopNesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
			stateStack[nesting].header = i;
			stateStack[nesting].tripCount = 0;
			stateStack[nesting].isLoop = 1;
			stateStack[nesting].isSwitch = 0;
			stateStack[nesting].continueMask = 0;
			break;
		case OP_END_FOR_UNIF:
			// Re-activate invocations that executed 'continue' this iteration,
			// then either jump back to the loop header or pop the loop.
			stateStack[nesting].tripCount++;
			stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
			stateStack[nesting].continueMask = 0;
			if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
				stateStack[nesting].activeMask.any())
			{
				i = stateStack[nesting].header+1;
				continue;
			}
			else
			{
				loopNesting--;
				nesting--;
			}
			break;
		case OP_BEGIN_DO_WHILE_UNIF:
			// XXX TODO: We don't handle a for loop with zero iterations
			nesting++;
			loopNesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
			stateStack[nesting].header = i;
			// do-while bodies run once before the condition, hence tripCount starts at 1.
			stateStack[nesting].tripCount = 1;
			stateStack[nesting].isLoop = 1;
			stateStack[nesting].isSwitch = 0;
			stateStack[nesting].continueMask = 0;
			break;
		case OP_END_DO_WHILE_UNIF:
			stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
			stateStack[nesting].continueMask = 0;
			if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
				stateStack[nesting].activeMask.any())
			{
				i = stateStack[nesting].header+1;
				stateStack[nesting].tripCount++;
				continue;
			}
			else
			{
				loopNesting--;
				nesting--;
			}
			break;
		case OP_BEGIN_FOR_VAR:
			// XXX TODO: We don't handle a for loop with zero iterations
			nesting++;
			loopNesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
			stateStack[nesting].header = i;
			stateStack[nesting].tripCount = 0;
			stateStack[nesting].isLoop = 1;
			stateStack[nesting].isSwitch = 0;
			stateStack[nesting].continueMask = 0;
			break;
		case OP_END_FOR_VAR:
			// Loop bound is per-invocation (gl_SubgroupInvocationID + 1 in the
			// shader): invocation k exits after k+1 iterations, so mask off all
			// invocations whose index is below the trip count.
			stateStack[nesting].tripCount++;
			stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
			stateStack[nesting].continueMask = 0;
			stateStack[nesting].activeMask &= bitsetFromU64(stateStack[nesting].tripCount == subgroupSize ? 0 : ~((1ULL << (stateStack[nesting].tripCount)) - 1), subgroupSize);
			if (stateStack[nesting].activeMask.any())
			{
				i = stateStack[nesting].header+1;
				continue;
			}
			else
			{
				loopNesting--;
				nesting--;
			}
			break;
		case OP_BEGIN_FOR_INF:
		case OP_BEGIN_DO_WHILE_INF:
			// "Infinite" loops: only break/return terminate them.
			nesting++;
			loopNesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
			stateStack[nesting].header = i;
			stateStack[nesting].tripCount = 0;
			stateStack[nesting].isLoop = 1;
			stateStack[nesting].isSwitch = 0;
			stateStack[nesting].continueMask = 0;
			break;
		case OP_END_FOR_INF:
			stateStack[nesting].tripCount++;
			stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
			stateStack[nesting].continueMask = 0;
			if (stateStack[nesting].activeMask.any())
			{
				// output expected OP_BALLOT values
				// (the generated for(;;) header emits a ballot in its increment
				// expression, so account for it here on each back-edge)
				for (deUint32 id = 0; id < 128; ++id)
				{
					if (stateStack[nesting].activeMask.test(id))
					{
						if (countOnly)
							outLoc[id]++;
						else
							ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
					}
				}

				i = stateStack[nesting].header+1;
				continue;
			}
			else
			{
				loopNesting--;
				nesting--;
			}
			break;
		case OP_END_DO_WHILE_INF:
			stateStack[nesting].tripCount++;
			stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
			stateStack[nesting].continueMask = 0;
			if (stateStack[nesting].activeMask.any())
			{
				i = stateStack[nesting].header+1;
				continue;
			}
			else
			{
				loopNesting--;
				nesting--;
			}
			break;
		case OP_BREAK:
		{
			// Deactivate the breaking invocations at every level up to and
			// including the innermost loop or switch.
			deUint32 n = nesting;
			bitset128 mask = stateStack[nesting].activeMask;
			while (true)
			{
				stateStack[n].activeMask &= ~mask;
				if (stateStack[n].isLoop || stateStack[n].isSwitch)
					break;

				n--;
			}
		}
		break;
		case OP_CONTINUE:
		{
			// Deactivate the continuing invocations up to the innermost loop,
			// remembering them there so the loop back-edge can re-activate them.
			deUint32 n = nesting;
			bitset128 mask = stateStack[nesting].activeMask;
			while (true)
			{
				stateStack[n].activeMask &= ~mask;
				if (stateStack[n].isLoop)
				{
					stateStack[n].continueMask |= mask;
					break;
				}
				n--;
			}
		}
		break;
		case OP_ELECT:
		{
			// Only the lowest-indexed active invocation of each subgroup survives.
			nesting++;
			stateStack[nesting].activeMask = bitsetElect(stateStack[nesting-1].activeMask, subgroupSize);
			stateStack[nesting].header = i;
			stateStack[nesting].isLoop = 0;
			stateStack[nesting].isSwitch = 0;
		}
		break;
		case OP_RETURN:
		{
			// Deactivate the returning invocations at every level up to and
			// including the innermost function call.
			bitset128 mask = stateStack[nesting].activeMask;
			for (deInt32 n = nesting; n >= 0; --n)
			{
				stateStack[n].activeMask &= ~mask;
				if (stateStack[n].isCall)
					break;
			}
		}
		break;

		case OP_CALL_BEGIN:
			nesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
			stateStack[nesting].isLoop = 0;
			stateStack[nesting].isSwitch = 0;
			stateStack[nesting].isCall = 1;
			break;
		case OP_CALL_END:
			stateStack[nesting].isCall = 0;
			nesting--;
			break;
		case OP_NOISE:
			// Noise has no effect on the simulated masks or outputs.
			break;

		case OP_SWITCH_UNIF_BEGIN:
		case OP_SWITCH_VAR_BEGIN:
		case OP_SWITCH_LOOP_COUNT_BEGIN:
			// Switches push a level; the case ops below refine the active mask.
			nesting++;
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
			stateStack[nesting].header = i;
			stateStack[nesting].isLoop = 0;
			stateStack[nesting].isSwitch = 1;
			break;
		case OP_SWITCH_END:
			nesting--;
			break;
		case OP_CASE_MASK_BEGIN:
			stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
			break;
		case OP_CASE_LOOP_COUNT_BEGIN:
		{
			// Find the loop whose nesting index matches the switch header's value,
			// then activate this case iff its trip count is selected by ops[i].value
			// (a bitmask of matching counts).
			deUint32 n = nesting;
			deUint32 l = loopNesting;

			while (true)
			{
				if (stateStack[n].isLoop)
				{
					l--;
					if (l == ops[stateStack[nesting].header].value)
						break;
				}
				n--;
			}

			if ((1ULL << stateStack[n].tripCount) & ops[i].value)
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
			else
				stateStack[nesting].activeMask = 0;
			break;
		}
		case OP_CASE_END:
			break;

		default:
			DE_ASSERT(0);
			break;
		}
		i++;
	}
	// Return the worst-case (maximum) number of locations any invocation wrote.
	deUint32 maxLoc = 0;
	for (deUint32 id = 0; id < ARRAYSIZE(outLoc); ++id)
		maxLoc = de::max(maxLoc, outLoc[id]);

	return maxLoc;
}
1546
hasUCF() const1547 bool hasUCF() const
1548 {
1549 for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
1550 {
1551 if (ops[i].type == OP_BALLOT && ops[i].caseValue == 0)
1552 return true;
1553 }
1554 return false;
1555 }
1556 };
1557
initPrograms(SourceCollections & programCollection) const1558 void ReconvergenceTestCase::initPrograms (SourceCollections& programCollection) const
1559 {
1560 RandomProgram program(m_data);
1561 program.generateRandomProgram();
1562
1563 std::stringstream css;
1564 css << "#version 450 core\n";
1565 css << "#extension GL_KHR_shader_subgroup_ballot : enable\n";
1566 css << "#extension GL_KHR_shader_subgroup_vote : enable\n";
1567 css << "#extension GL_NV_shader_subgroup_partitioned : enable\n";
1568 css << "#extension GL_EXT_subgroup_uniform_control_flow : enable\n";
1569 css << "layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n";
1570 css << "layout(set=0, binding=0) coherent buffer InputA { uint a[]; } inputA;\n";
1571 css << "layout(set=0, binding=1) coherent buffer OutputB { uvec2 b[]; } outputB;\n";
1572 css << "layout(set=0, binding=2) coherent buffer OutputC { uint loc[]; } outputC;\n";
1573 css << "layout(push_constant) uniform PC {\n"
1574 " // set to the real stride when writing out ballots, or zero when just counting\n"
1575 " int invocationStride;\n"
1576 "};\n";
1577 css << "int outLoc = 0;\n";
1578
1579 css << "bool testBit(uvec2 mask, uint bit) { return (bit < 32) ? ((mask.x >> bit) & 1) != 0 : ((mask.y >> (bit-32)) & 1) != 0; }\n";
1580
1581 css << "uint elect() { return int(subgroupElect()) + 1; }\n";
1582
1583 std::stringstream functions, main;
1584 program.genCode(functions, main);
1585
1586 css << functions.str() << "\n\n";
1587
1588 css <<
1589 "void main()\n"
1590 << (m_data.isSUCF() ? "[[subgroup_uniform_control_flow]]\n" : "") <<
1591 "{\n";
1592
1593 css << main.str() << "\n\n";
1594
1595 css << "}\n";
1596
1597 const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
1598
1599 programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
1600 }
1601
createInstance(Context & context) const1602 TestInstance* ReconvergenceTestCase::createInstance (Context& context) const
1603 {
1604 return new ReconvergenceTestInstance(context, m_data);
1605 }
1606
// Run the test: regenerate the same random program as initPrograms, dispatch
// it twice (once to count outputs, once to record ballots), simulate the
// program on the CPU, and compare the GPU ballots against the reference.
tcu::TestStatus ReconvergenceTestInstance::iterate (void)
{
	const DeviceInterface&	vk						= m_context.getDeviceInterface();
	const VkDevice			device					= m_context.getDevice();
	Allocator&				allocator				= m_context.getDefaultAllocator();
	tcu::TestLog&			log						= m_context.getTestContext().getLog();

	deRandom rnd;
	deRandom_init(&rnd, m_data.seed);

	// Query the subgroup size; the simulation and mask encoding depend on it.
	vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
	deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
	subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;

	vk::VkPhysicalDeviceProperties2 properties2;
	deMemset(&properties2, 0, sizeof(properties2));
	properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
	properties2.pNext = &subgroupProperties;

	m_context.getInstanceInterface().getPhysicalDeviceProperties2(m_context.getPhysicalDevice(), &properties2);

	const deUint32 subgroupSize = subgroupProperties.subgroupSize;
	// One workgroup of 128 invocations; also the stride between output rows.
	const deUint32 invocationStride = 128;

	// Ballot masks are stored in a single 64-bit value per invocation.
	if (subgroupSize > 64)
		TCU_THROW(TestError, "Subgroup size greater than 64 not handled.");

	// Must match the program generated in initPrograms (same CaseDef/seed).
	RandomProgram program(m_data);
	program.generateRandomProgram();

	// First CPU pass: count outputs only, to size the output buffer.
	deUint32 maxLoc = program.simulate(true, subgroupSize, invocationStride, DE_NULL);

	// maxLoc is per-invocation. Add one (to make sure no additional writes are done) and multiply by
	// the number of invocations
	maxLoc++;
	maxLoc *= invocationStride;

	// buffer[0] is an input filled with a[i] == i
	// buffer[1] is the output
	// buffer[2] is the location counts
	de::MovePtr<BufferWithMemory> buffers[3];
	vk::VkDescriptorBufferInfo bufferDescriptors[3];

	VkDeviceSize sizes[3] =
	{
		128 * sizeof(deUint32),
		maxLoc * sizeof(deUint64),
		invocationStride * sizeof(deUint32),
	};

	for (deUint32 i = 0; i < 3; ++i)
	{
		if (sizes[i] > properties2.properties.limits.maxStorageBufferRange)
			TCU_THROW(NotSupportedError, "Storage buffer size larger than device limits");

		try
		{
			buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
				vk, device, allocator, makeBufferCreateInfo(sizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
				MemoryRequirement::HostVisible | MemoryRequirement::Cached));
		}
		catch(tcu::ResourceError&)
		{
			// Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
			return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[i]) + " bytes");
		}
		bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, sizes[i]);
	}

	// Host-visible mappings: input values, zeroed output and location counts.
	deUint32 *ptrs[3];
	for (deUint32 i = 0; i < 3; ++i)
	{
		ptrs[i] = (deUint32 *)buffers[i]->getAllocation().getHostPtr();
	}
	for (deUint32 i = 0; i < sizes[0] / sizeof(deUint32); ++i)
	{
		ptrs[0][i] = i;
	}
	deMemset(ptrs[1], 0, (size_t)sizes[1]);
	deMemset(ptrs[2], 0, (size_t)sizes[2]);

	vk::DescriptorSetLayoutBuilder layoutBuilder;

	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);

	vk::Unique<vk::VkDescriptorSetLayout>	descriptorSetLayout(layoutBuilder.build(vk, device));

	vk::Unique<vk::VkDescriptorPool>		descriptorPool(vk::DescriptorPoolBuilder()
		.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 3u)
		.build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
	vk::Unique<vk::VkDescriptorSet>			descriptorSet	(makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));

	// Spec constant 0 sets local_size_x to invocationStride (see initPrograms).
	const deUint32 specData[1] =
	{
		invocationStride,
	};
	const vk::VkSpecializationMapEntry entries[1] =
	{
		{0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
	};
	const vk::VkSpecializationInfo specInfo =
	{
		1,						// mapEntryCount
		entries,				// pMapEntries
		sizeof(specData),		// dataSize
		specData				// pData
	};

	// Push constant carries the output stride: zero for the counting pass,
	// the real stride for the recording pass.
	const VkPushConstantRange pushConstantRange =
	{
		allShaderStages,	// VkShaderStageFlags	stageFlags;
		0u,					// deUint32				offset;
		sizeof(deInt32)		// deUint32				size;
	};

	const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo =
	{
		VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,	// sType
		DE_NULL,										// pNext
		(VkPipelineLayoutCreateFlags)0,
		1,												// setLayoutCount
		&descriptorSetLayout.get(),						// pSetLayouts
		1u,												// pushConstantRangeCount
		&pushConstantRange,								// pPushConstantRanges
	};

	Move<VkPipelineLayout> pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutCreateInfo, NULL);

	VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;

	// Memory is Cached: flush host writes before the GPU reads them.
	flushAlloc(vk, device, buffers[0]->getAllocation());
	flushAlloc(vk, device, buffers[1]->getAllocation());
	flushAlloc(vk, device, buffers[2]->getAllocation());

	// Request full subgroups when supported, so the simulated masks match the
	// GPU's subgroup packing.
	const VkBool32 computeFullSubgroups = subgroupProperties.subgroupSize <= 64 &&
											m_context.getSubgroupSizeControlFeaturesEXT().computeFullSubgroups;

	const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT subgroupSizeCreateInfo =
	{
		VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT,	// VkStructureType	sType;
		DE_NULL,																		// void*			pNext;
		subgroupProperties.subgroupSize													// uint32_t			requiredSubgroupSize;
	};

	const void *shaderPNext = computeFullSubgroups ? &subgroupSizeCreateInfo : DE_NULL;
	VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags =
		(VkPipelineShaderStageCreateFlags)(computeFullSubgroups ? VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT : 0);

	const Unique<VkShaderModule>			shader						(createShaderModule(vk, device, m_context.getBinaryCollection().get("test"), 0));
	const VkPipelineShaderStageCreateInfo	shaderCreateInfo =
	{
		VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
		shaderPNext,
		pipelineShaderStageCreateFlags,
		VK_SHADER_STAGE_COMPUTE_BIT,								// stage
		*shader,													// shader
		"main",
		&specInfo,													// pSpecializationInfo
	};

	const VkComputePipelineCreateInfo		pipelineCreateInfo =
	{
		VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
		DE_NULL,
		0u,								// flags
		shaderCreateInfo,				// cs
		*pipelineLayout,				// layout
		(vk::VkPipeline)0,				// basePipelineHandle
		0u,								// basePipelineIndex
	};
	Move<VkPipeline> pipeline = createComputePipeline(vk, device, DE_NULL, &pipelineCreateInfo, NULL);

	const VkQueue					queue					= m_context.getUniversalQueue();
	Move<VkCommandPool>				cmdPool					= createCommandPool(vk, device, vk::VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT, m_context.getUniversalQueueFamilyIndex());
	Move<VkCommandBuffer>			cmdBuffer				= allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);


	vk::DescriptorSetUpdateBuilder setUpdateBuilder;
	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
	setUpdateBuilder.update(vk, device);

	// compute "maxLoc", the maximum number of locations written
	beginCommandBuffer(vk, *cmdBuffer, 0u);

	vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
	vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);

	// Stride of zero tells the shader to only count locations, not record them.
	deInt32 pcinvocationStride = 0;
	vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);

	vk.cmdDispatch(*cmdBuffer, 1, 1, 1);

	endCommandBuffer(vk, *cmdBuffer);

	submitCommandsAndWait(vk, device, queue, cmdBuffer.get());

	invalidateAlloc(vk, device, buffers[1]->getAllocation());
	invalidateAlloc(vk, device, buffers[2]->getAllocation());

	// Clear any writes to buffer[1] during the counting pass
	// (with a stride of zero all writes collapse into the first
	// invocationStride qwords)
	deMemset(ptrs[1], 0, invocationStride * sizeof(deUint64));

	// Take the max over all invocations. Add one (to make sure no additional writes are done) and multiply by
	// the number of invocations
	deUint32 newMaxLoc = 0;
	for (deUint32 id = 0; id < invocationStride; ++id)
		newMaxLoc = de::max(newMaxLoc, ptrs[2][id]);
	newMaxLoc++;
	newMaxLoc *= invocationStride;

	// If we need more space, reallocate buffers[1]
	// (the GPU may write more locations than the CPU simulation predicted,
	// e.g. for loops whose trip counts differ on the device)
	if (newMaxLoc > maxLoc)
	{
		maxLoc = newMaxLoc;
		sizes[1] = maxLoc * sizeof(deUint64);

		if (sizes[1] > properties2.properties.limits.maxStorageBufferRange)
			TCU_THROW(NotSupportedError, "Storage buffer size larger than device limits");

		try
		{
			buffers[1] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
				vk, device, allocator, makeBufferCreateInfo(sizes[1], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
				MemoryRequirement::HostVisible | MemoryRequirement::Cached));
		}
		catch(tcu::ResourceError&)
		{
			// Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
			return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[1]) + " bytes");
		}
		bufferDescriptors[1] = makeDescriptorBufferInfo(**buffers[1], 0, sizes[1]);
		ptrs[1] = (deUint32 *)buffers[1]->getAllocation().getHostPtr();
		deMemset(ptrs[1], 0, (size_t)sizes[1]);

		// Rebind the new output buffer at binding 1.
		vk::DescriptorSetUpdateBuilder setUpdateBuilder2;
		setUpdateBuilder2.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
			VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
		setUpdateBuilder2.update(vk, device);
	}

	flushAlloc(vk, device, buffers[1]->getAllocation());

	// run the actual shader
	beginCommandBuffer(vk, *cmdBuffer, 0u);

	vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
	vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);

	// Real stride this time: the shader records ballots into buffer[1].
	pcinvocationStride = invocationStride;
	vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);

	vk.cmdDispatch(*cmdBuffer, 1, 1, 1);

	endCommandBuffer(vk, *cmdBuffer);

	submitCommandsAndWait(vk, device, queue, cmdBuffer.get());

	invalidateAlloc(vk, device, buffers[1]->getAllocation());

	qpTestResult res = QP_TEST_RESULT_PASS;

	// Simulate execution on the CPU, and compare against the GPU result
	std::vector<deUint64> ref;
	try
	{
		ref.resize(maxLoc, 0ull);
	}
	catch (const std::bad_alloc&)
	{
		// Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
		return tcu::TestStatus(QP_TEST_RESULT_NOT_SUPPORTED, "Failed system memory allocation " + de::toString(maxLoc * sizeof(deUint64)) + " bytes");
	}

	program.simulate(false, subgroupSize, invocationStride, &ref[0]);

	const deUint64 *result = (const deUint64 *)ptrs[1];

	if (m_data.testType == TT_MAXIMAL)
	{
		// With maximal reconvergence, we should expect the output to exactly match
		// the reference.
		for (deUint32 i = 0; i < maxLoc; ++i)
		{
			if (result[i] != ref[i])
			{
				log << tcu::TestLog::Message << "first mismatch at " << i << tcu::TestLog::EndMessage;
				res = QP_TEST_RESULT_FAIL;
				break;
			}
		}

		if (res != QP_TEST_RESULT_PASS)
		{
			for (deUint32 i = 0; i < maxLoc; ++i)
			{
				// This log can be large and slow, ifdef it out by default
#if 0
				log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (result[i] != ref[i] ? " different" : "") << tcu::TestLog::EndMessage;
#endif
			}
		}
	}
	else
	{
		deUint64 fullMask = subgroupSizeToMask(subgroupSize);
		// For subgroup_uniform_control_flow, we expect any fully converged outputs in the reference
		// to have a corresponding fully converged output in the result. So walk through each lane's
		// results, and for each reference value of fullMask, find a corresponding result value of
		// fullMask where the previous value (OP_STORE) matches. That means these came from the same
		// source location.
		vector<deUint32> firstFail(invocationStride, 0);
		for (deUint32 lane = 0; lane < invocationStride; ++lane)
		{
			deUint32 resLoc = lane + invocationStride, refLoc = lane + invocationStride;
			while (refLoc < maxLoc)
			{
				while (refLoc < maxLoc && ref[refLoc] != fullMask)
					refLoc += invocationStride;
				if (refLoc >= maxLoc)
					break;

				// For TT_SUCF_ELECT, when the reference result has a full mask, we expect lane 0 to be elected
				// (a value of 2) and all other lanes to be not elected (a value of 1). For TT_SUCF_BALLOT, we
				// expect a full mask. Search until we find the expected result with a matching store value in
				// the previous result.
				deUint64 expectedResult = m_data.isElect() ? ((lane % subgroupSize) == 0 ? 2 : 1)
														   : fullMask;

				while (resLoc < maxLoc && !(result[resLoc] == expectedResult && result[resLoc-invocationStride] == ref[refLoc-invocationStride]))
					resLoc += invocationStride;

				// If we didn't find this output in the result, flag it as an error.
				if (resLoc >= maxLoc)
				{
					firstFail[lane] = refLoc;
					log << tcu::TestLog::Message << "lane " << lane << " first mismatch at " << firstFail[lane] << tcu::TestLog::EndMessage;
					res = QP_TEST_RESULT_FAIL;
					break;
				}
				refLoc += invocationStride;
				resLoc += invocationStride;
			}
		}

		if (res != QP_TEST_RESULT_PASS)
		{
			for (deUint32 i = 0; i < maxLoc; ++i)
			{
				// This log can be large and slow, ifdef it out by default
#if 0
				log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (i == firstFail[i%invocationStride] ? " first fail" : "") << tcu::TestLog::EndMessage;
#endif
			}
		}
	}

	return tcu::TestStatus(res, qpGetTestResultName(res));
}
1972
1973 } // anonymous
1974
createTests(tcu::TestContext & testCtx,bool createExperimental)1975 tcu::TestCaseGroup* createTests (tcu::TestContext& testCtx, bool createExperimental)
1976 {
1977 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
1978 testCtx, "reconvergence", "reconvergence tests"));
1979
1980 typedef struct
1981 {
1982 deUint32 value;
1983 const char* name;
1984 const char* description;
1985 } TestGroupCase;
1986
1987 TestGroupCase ttCases[] =
1988 {
1989 { TT_SUCF_ELECT, "subgroup_uniform_control_flow_elect", "subgroup_uniform_control_flow_elect" },
1990 { TT_SUCF_BALLOT, "subgroup_uniform_control_flow_ballot", "subgroup_uniform_control_flow_ballot" },
1991 { TT_WUCF_ELECT, "workgroup_uniform_control_flow_elect", "workgroup_uniform_control_flow_elect" },
1992 { TT_WUCF_BALLOT, "workgroup_uniform_control_flow_ballot","workgroup_uniform_control_flow_ballot" },
1993 { TT_MAXIMAL, "maximal", "maximal" },
1994 };
1995
1996 for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
1997 {
1998 de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, ttCases[ttNdx].name, ttCases[ttNdx].description));
1999 de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(testCtx, "compute", ""));
2000
2001 for (deUint32 nNdx = 2; nNdx <= 6; nNdx++)
2002 {
2003 de::MovePtr<tcu::TestCaseGroup> nestGroup(new tcu::TestCaseGroup(testCtx, ("nesting" + de::toString(nNdx)).c_str(), ""));
2004
2005 deUint32 seed = 0;
2006
2007 for (int sNdx = 0; sNdx < 8; sNdx++)
2008 {
2009 de::MovePtr<tcu::TestCaseGroup> seedGroup(new tcu::TestCaseGroup(testCtx, de::toString(sNdx).c_str(), ""));
2010
2011 deUint32 numTests = 0;
2012 switch (nNdx)
2013 {
2014 default:
2015 DE_ASSERT(0);
2016 // fallthrough
2017 case 2:
2018 case 3:
2019 case 4:
2020 numTests = 250;
2021 break;
2022 case 5:
2023 numTests = 100;
2024 break;
2025 case 6:
2026 numTests = 50;
2027 break;
2028 }
2029
2030 if (ttCases[ttNdx].value != TT_MAXIMAL)
2031 {
2032 if (nNdx >= 5)
2033 continue;
2034 }
2035
2036 for (deUint32 ndx = 0; ndx < numTests; ndx++)
2037 {
2038 CaseDef c =
2039 {
2040 (TestType)ttCases[ttNdx].value, // TestType testType;
2041 nNdx, // deUint32 maxNesting;
2042 seed, // deUint32 seed;
2043 };
2044 seed++;
2045
2046 bool isExperimentalTest = !c.isUCF() || (ndx >= numTests / 5);
2047
2048 if (createExperimental == isExperimentalTest)
2049 seedGroup->addChild(new ReconvergenceTestCase(testCtx, de::toString(ndx).c_str(), "", c));
2050 }
2051 if (!seedGroup->empty())
2052 nestGroup->addChild(seedGroup.release());
2053 }
2054 if (!nestGroup->empty())
2055 computeGroup->addChild(nestGroup.release());
2056 }
2057 if (!computeGroup->empty())
2058 {
2059 ttGroup->addChild(computeGroup.release());
2060 group->addChild(ttGroup.release());
2061 }
2062 }
2063 return group.release();
2064 }
2065
2066 } // Reconvergence
2067 } // vkt
2068