• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2018-2020 NVIDIA Corporation
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *	  http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Vulkan Reconvergence tests
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktReconvergenceTests.hpp"
26 
27 #include "vkBufferWithMemory.hpp"
28 #include "vkImageWithMemory.hpp"
29 #include "vkQueryUtil.hpp"
30 #include "vkBuilderUtil.hpp"
31 #include "vkCmdUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkObjUtil.hpp"
34 
35 #include "vktTestGroupUtil.hpp"
36 #include "vktTestCase.hpp"
37 
38 #include "deDefs.h"
39 #include "deFloat16.h"
40 #include "deMath.h"
41 #include "deRandom.h"
42 #include "deSharedPtr.hpp"
43 #include "deString.h"
44 
45 #include "tcuTestCase.hpp"
46 #include "tcuTestLog.hpp"
47 
48 #include <bitset>
49 #include <string>
50 #include <sstream>
51 #include <set>
52 #include <vector>
53 
54 namespace vkt
55 {
56 namespace Reconvergence
57 {
58 namespace
59 {
60 using namespace vk;
61 using namespace std;
62 
63 #define ARRAYSIZE(x) (sizeof(x) / sizeof(x[0]))
64 
65 const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
66 
// Flavor of reconvergence guarantee exercised by each test variant.
67 typedef enum {
68 	TT_SUCF_ELECT,	// subgroup_uniform_control_flow using elect (subgroup_basic)
69 	TT_SUCF_BALLOT,	// subgroup_uniform_control_flow using ballot (subgroup_ballot)
70 	TT_WUCF_ELECT,	// workgroup uniform control flow using elect (subgroup_basic)
71 	TT_WUCF_BALLOT,	// workgroup uniform control flow using ballot (subgroup_ballot)
72 	TT_MAXIMAL,		// maximal reconvergence
73 } TestType;
74 
// Parameters for one test case: which reconvergence guarantee to test,
// how deeply the random program may nest control flow, and the RNG seed
// that makes the generated program deterministic.
75 struct CaseDef
76 {
77 	TestType testType;
	// Maximum control-flow nesting depth for the generated program.
78 	deUint32 maxNesting;
	// Seed passed to deRandom_init by RandomProgram.
79 	deUint32 seed;
80 
	// Convenience queries over testType.
isWUCFvkt::Reconvergence::__anon5fc7ae020111::CaseDef81 	bool isWUCF() const { return testType == TT_WUCF_ELECT || testType == TT_WUCF_BALLOT; }
isSUCFvkt::Reconvergence::__anon5fc7ae020111::CaseDef82 	bool isSUCF() const { return testType == TT_SUCF_ELECT || testType == TT_SUCF_BALLOT; }
isUCFvkt::Reconvergence::__anon5fc7ae020111::CaseDef83 	bool isUCF() const { return isWUCF() || isSUCF(); }
isElectvkt::Reconvergence::__anon5fc7ae020111::CaseDef84 	bool isElect() const { return testType == TT_WUCF_ELECT || testType == TT_SUCF_ELECT; }
85 };
86 
subgroupSizeToMask(deUint32 subgroupSize)87 deUint64 subgroupSizeToMask(deUint32 subgroupSize)
88 {
89 	if (subgroupSize == 64)
90 		return ~0ULL;
91 	else
92 		return (1ULL << subgroupSize) - 1;
93 }
94 
95 typedef std::bitset<128> bitset128;
96 
97 // Take a 64-bit integer, mask it to the subgroup size, and then
98 // replicate it for each subgroup
bitsetFromU64(deUint64 mask,deUint32 subgroupSize)99 bitset128 bitsetFromU64(deUint64 mask, deUint32 subgroupSize)
100 {
101 	mask &= subgroupSizeToMask(subgroupSize);
102 	bitset128 result(mask);
103 	for (deUint32 i = 0; i < 128 / subgroupSize - 1; ++i)
104 	{
105 		result = (result << subgroupSize) | bitset128(mask);
106 	}
107 	return result;
108 }
109 
110 // Pick out the mask for the subgroup that invocationID is a member of
bitsetToU64(const bitset128 & bitset,deUint32 subgroupSize,deUint32 invocationID)111 deUint64 bitsetToU64(const bitset128 &bitset, deUint32 subgroupSize, deUint32 invocationID)
112 {
113 	bitset128 copy(bitset);
114 	copy >>= (invocationID / subgroupSize) * subgroupSize;
115 	copy &= bitset128(subgroupSizeToMask(subgroupSize));
116 	deUint64 mask = copy.to_ullong();
117 	mask &= subgroupSizeToMask(subgroupSize);
118 	return mask;
119 }
120 
// Instance that executes one reconvergence test; all per-run work is in
// iterate() (declared here, implemented elsewhere in this file).
121 class ReconvergenceTestInstance : public TestInstance
122 {
123 public:
124 						ReconvergenceTestInstance	(Context& context, const CaseDef& data);
125 						~ReconvergenceTestInstance	(void);
126 	tcu::TestStatus		iterate				(void);
127 private:
	// Case parameters, copied at construction.
128 	CaseDef			m_data;
129 };
130 
// Constructor simply stores the case parameters; the destructor has
// nothing to release.
ReconvergenceTestInstance(Context & context,const CaseDef & data)131 ReconvergenceTestInstance::ReconvergenceTestInstance (Context& context, const CaseDef& data)
132 	: vkt::TestInstance		(context)
133 	, m_data				(data)
134 {
135 }
136 
~ReconvergenceTestInstance(void)137 ReconvergenceTestInstance::~ReconvergenceTestInstance (void)
138 {
139 }
140 
// Test case node: builds the shader (initPrograms), validates device
// support (checkSupport), and creates the runtime instance.
141 class ReconvergenceTestCase : public TestCase
142 {
143 	public:
144 								ReconvergenceTestCase		(tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
145 								~ReconvergenceTestCase	(void);
146 	virtual	void				initPrograms		(SourceCollections& programCollection) const;
147 	virtual TestInstance*		createInstance		(Context& context) const;
148 	virtual void				checkSupport		(Context& context) const;
149 
150 private:
	// Case parameters, copied at construction.
151 	CaseDef					m_data;
152 };
153 
// Constructor simply stores the case parameters; the destructor has
// nothing to release.
ReconvergenceTestCase(tcu::TestContext & context,const char * name,const char * desc,const CaseDef data)154 ReconvergenceTestCase::ReconvergenceTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
155 	: vkt::TestCase	(context, name, desc)
156 	, m_data		(data)
157 {
158 }
159 
~ReconvergenceTestCase(void)160 ReconvergenceTestCase::~ReconvergenceTestCase	(void)
161 {
162 }
163 
checkSupport(Context & context) const164 void ReconvergenceTestCase::checkSupport(Context& context) const
165 {
166 	if (!context.contextSupports(vk::ApiVersion(1, 1, 0)))
167 		TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
168 
169 	vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
170 	deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
171 	subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
172 
173 	vk::VkPhysicalDeviceProperties2 properties2;
174 	deMemset(&properties2, 0, sizeof(properties2));
175 	properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
176 	properties2.pNext = &subgroupProperties;
177 
178 	context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties2);
179 
180 	if (m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BASIC_BIT))
181 		TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BASIC_BIT not supported");
182 
183 	if (!m_data.isElect() && !(subgroupProperties.supportedOperations & VK_SUBGROUP_FEATURE_BALLOT_BIT))
184 		TCU_THROW(NotSupportedError, "VK_SUBGROUP_FEATURE_BALLOT_BIT not supported");
185 
186 	if (!(context.getSubgroupProperties().supportedStages & VK_SHADER_STAGE_COMPUTE_BIT))
187 		TCU_THROW(NotSupportedError, "compute stage does not support subgroup operations");
188 
189 	// Both subgroup- AND workgroup-uniform tests are enabled by shaderSubgroupUniformControlFlow.
190 	if (m_data.isUCF() && !context.getShaderSubgroupUniformControlFlowFeatures().shaderSubgroupUniformControlFlow)
191 		TCU_THROW(NotSupportedError, "shaderSubgroupUniformControlFlow not supported");
192 
193 	// XXX TODO: Check for maximal reconvergence support
194 	// if (m_data.testType == TT_MAXIMAL ...)
195 }
196 
// Opcodes of the tiny language the random program generator emits; each
// OP below is later printed as GLSL by genCode and replayed by simulate.
197 typedef enum
198 {
199 	// store subgroupBallot().
200 	// For OP_BALLOT, OP::caseValue is initialized to zero, and then
201 	// set to 1 by simulate if the ballot is not workgroup- (or subgroup-) uniform.
202 	// Only workgroup-uniform ballots are validated for correctness in
203 	// WUCF modes.
204 	OP_BALLOT,
205 
206 	// store literal constant
207 	OP_STORE,
208 
209 	// if ((1ULL << gl_SubgroupInvocationID) & mask).
210 	// Special case if mask = ~0ULL, converted into "if (inputA.a[idx] == idx)"
211 	OP_IF_MASK,
212 	OP_ELSE_MASK,
213 	OP_ENDIF,
214 
215 	// if (gl_SubgroupInvocationID == loopIdxN) (where N is most nested loop counter)
216 	OP_IF_LOOPCOUNT,
217 	OP_ELSE_LOOPCOUNT,
218 
219 	// if (gl_LocalInvocationIndex >= inputA.a[N]) (where N is most nested loop counter)
220 	OP_IF_LOCAL_INVOCATION_INDEX,
221 	OP_ELSE_LOCAL_INVOCATION_INDEX,
222 
223 	// break/continue
224 	OP_BREAK,
225 	OP_CONTINUE,
226 
227 	// if (subgroupElect())
228 	OP_ELECT,
229 
230 	// Loop with uniform number of iterations (read from a buffer)
231 	OP_BEGIN_FOR_UNIF,
232 	OP_END_FOR_UNIF,
233 
234 	// for (int loopIdxN = 0; loopIdxN < gl_SubgroupInvocationID + 1; ++loopIdxN)
235 	OP_BEGIN_FOR_VAR,
236 	OP_END_FOR_VAR,
237 
238 	// for (int loopIdxN = 0;; ++loopIdxN, OP_BALLOT)
239 	// Always has an "if (subgroupElect()) break;" inside.
240 	// Does the equivalent of OP_BALLOT in the continue construct
241 	OP_BEGIN_FOR_INF,
242 	OP_END_FOR_INF,
243 
244 	// do { loopIdxN++; ... } while (loopIdxN < uniformValue);
245 	OP_BEGIN_DO_WHILE_UNIF,
246 	OP_END_DO_WHILE_UNIF,
247 
248 	// do { ... } while (true);
249 	// Always has an "if (subgroupElect()) break;" inside
250 	OP_BEGIN_DO_WHILE_INF,
251 	OP_END_DO_WHILE_INF,
252 
253 	// return;
254 	OP_RETURN,
255 
256 	// function call (code bracketed by these is extracted into a separate function)
257 	OP_CALL_BEGIN,
258 	OP_CALL_END,
259 
260 	// switch statement on uniform value
261 	OP_SWITCH_UNIF_BEGIN,
262 	// switch statement on gl_SubgroupInvocationID & 3 value
263 	OP_SWITCH_VAR_BEGIN,
264 	// switch statement on loopIdx value
265 	OP_SWITCH_LOOP_COUNT_BEGIN,
266 
267 	// case statement with a (invocation mask, case mask) pair
268 	OP_CASE_MASK_BEGIN,
269 	// case statement used for loop counter switches, with a value and a mask of loop iterations
270 	OP_CASE_LOOP_COUNT_BEGIN,
271 
272 	// end of switch/case statement
273 	OP_SWITCH_END,
274 	OP_CASE_END,
275 
276 	// Extra code with no functional effect. Currently includes:
277 	// - value 0: while (!subgroupElect()) {}
278 	// - value 1: if (condition_that_is_false) { infinite loop }
279 	OP_NOISE,
280 } OPType;
281 
// Condition flavors genIf can generate.
282 typedef enum
283 {
284 	// Different if test conditions
285 	IF_MASK,
286 	IF_UNIFORM,
287 	IF_LOOPCOUNT,
288 	IF_LOCAL_INVOCATION_INDEX,
289 } IFType;
290 
// One operation of the randomly generated program.
291 class OP
292 {
293 public:
OP(OPType _type,deUint64 _value,deUint32 _caseValue=0)294 	OP(OPType _type, deUint64 _value, deUint32 _caseValue = 0)
295 		: type(_type), value(_value), caseValue(_caseValue)
296 	{}
297 
298 	// The type of operation and an optional value.
299 	// The value could be a mask for an if test, the index of the loop
300 	// header for an end of loop, or the constant value for a store instruction
301 	OPType type;
302 	deUint64 value;
	// For case ops: mask of case values handled. For OP_BALLOT: starts at
	// zero and is set to 1 by simulate when the ballot is non-uniform
	// (see OP_BALLOT comment above).
303 	deUint32 caseValue;
304 };
305 
findLSB(deUint64 value)306 static int findLSB (deUint64 value)
307 {
308 	for (int i = 0; i < 64; i++)
309 	{
310 		if (value & (1ULL<<i))
311 			return i;
312 	}
313 	return -1;
314 }
315 
316 // For each subgroup, pick out the elected invocationID, and accumulate
317 // a bitset of all of them
bitsetElect(const bitset128 & value,deInt32 subgroupSize)318 static bitset128 bitsetElect (const bitset128& value, deInt32 subgroupSize)
319 {
320 	bitset128 ret; // zero initialized
321 
322 	for (deInt32 i = 0; i < 128; i += subgroupSize)
323 	{
324 		deUint64 mask = bitsetToU64(value, subgroupSize, i);
325 		int lsb = findLSB(mask);
326 		ret |= bitset128(lsb == -1 ? 0 : (1ULL << lsb)) << i;
327 	}
328 	return ret;
329 }
330 
331 class RandomProgram
332 {
333 public:
	// Seed the RNG from the case and pre-generate the random 64-bit
	// masks used later for divergent if conditions.
RandomProgram(const CaseDef & c)334 	RandomProgram(const CaseDef &c)
335 		: caseDef(c), numMasks(5), nesting(0), maxNesting(c.maxNesting), loopNesting(0), loopNestingThisFunction(0), callNesting(0), minCount(30), indent(0), isLoopInf(100, false), doneInfLoopBreak(100, false), storeBase(0x10000)
336 	{
337 		deRandom_init(&rnd, caseDef.seed);
338 		for (int i = 0; i < numMasks; ++i)
339 			masks.push_back(deRandom_getUint64(&rnd));
340 	}
341 
	// Parameters this program is generated from.
342 	const CaseDef caseDef;
	// Deterministic RNG, seeded with caseDef.seed in the constructor.
343 	deRandom rnd;
	// The generated program as a flat list of operations.
344 	vector<OP> ops;
	// Pre-generated random masks used for divergent if conditions.
345 	vector<deUint64> masks;
346 	deInt32 numMasks;
	// Current overall control-flow nesting depth during generation.
347 	deInt32 nesting;
348 	deInt32 maxNesting;
	// Current loop nesting depth across all generated functions.
349 	deInt32 loopNesting;
	// Loop nesting depth within the function currently being generated;
	// break/continue are only emitted when this is > 0.
350 	deInt32 loopNestingThisFunction;
	// Depth of generated function calls (OP_CALL_BEGIN/OP_CALL_END).
351 	deInt32 callNesting;
	// Minimum number of ops generateRandomProgram produces.
352 	deInt32 minCount;
	// Indentation level used when printing the GLSL source.
353 	deInt32 indent;
	// Per loop-nesting-level: is the loop infinite, and has its
	// subgroupElect-guarded break been emitted yet.
354 	vector<bool> isLoopInf;
355 	vector<bool> doneInfLoopBreak;
356 	// Offset the value we use for OP_STORE, to avoid colliding with fully converged
357 	// active masks with small subgroup sizes (e.g. with subgroupSize == 4, the SUCF
358 	// tests need to know that 0xF is really an active mask).
359 	deInt32 storeBase;
360 
	// Generate an if (with a 50% chance of an else) using the requested
	// condition flavor; the then/else bodies are filled recursively via
	// pickOP. Note the RNG draws happen unconditionally so the random
	// sequence stays aligned regardless of ifType.
genIf(IFType ifType)361 	void genIf(IFType ifType)
362 	{
363 		deUint32 maskIdx = deRandom_getUint32(&rnd) % numMasks;
364 		deUint64 mask = masks[maskIdx];
365 		if (ifType == IF_UNIFORM)
366 			mask = ~0ULL;
367 
368 		deUint32 localIndexCmp = deRandom_getUint32(&rnd) % 128;
369 		if (ifType == IF_LOCAL_INVOCATION_INDEX)
370 			ops.push_back({OP_IF_LOCAL_INVOCATION_INDEX, localIndexCmp});
371 		else if (ifType == IF_LOOPCOUNT)
372 			ops.push_back({OP_IF_LOOPCOUNT, 0});
373 		else
374 			ops.push_back({OP_IF_MASK, mask});
375 
376 		nesting++;
377 
378 		size_t thenBegin = ops.size();
379 		pickOP(2);
380 		size_t thenEnd = ops.size();
381 
382 		deUint32 randElse = (deRandom_getUint32(&rnd) % 100);
383 		if (randElse < 50)
384 		{
385 			if (ifType == IF_LOCAL_INVOCATION_INDEX)
386 				ops.push_back({OP_ELSE_LOCAL_INVOCATION_INDEX, localIndexCmp});
387 			else if (ifType == IF_LOOPCOUNT)
388 				ops.push_back({OP_ELSE_LOOPCOUNT, 0});
389 			else
390 				ops.push_back({OP_ELSE_MASK, 0});
391 
392 			if (randElse < 10)
393 			{
394 				// Sometimes make the else block identical to the then block
395 				for (size_t i = thenBegin; i < thenEnd; ++i)
396 					ops.push_back(ops[i]);
397 			}
398 			else
399 				pickOP(2);
400 		}
401 		ops.push_back({OP_ENDIF, 0});
402 		nesting--;
403 	}
404 
	// Generate a for loop with a uniform (1..5) iteration count; the end
	// op records the index of its loop header.
genForUnif()405 	void genForUnif()
406 	{
407 		deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
408 		ops.push_back({OP_BEGIN_FOR_UNIF, iterCount});
409 		deUint32 loopheader = (deUint32)ops.size()-1;
410 		nesting++;
411 		loopNesting++;
412 		loopNestingThisFunction++;
413 		pickOP(2);
414 		ops.push_back({OP_END_FOR_UNIF, loopheader});
415 		loopNestingThisFunction--;
416 		loopNesting--;
417 		nesting--;
418 	}
419 
	// Same as genForUnif but as a do/while loop.
genDoWhileUnif()420 	void genDoWhileUnif()
421 	{
422 		deUint32 iterCount = (deRandom_getUint32(&rnd) % 5) + 1;
423 		ops.push_back({OP_BEGIN_DO_WHILE_UNIF, iterCount});
424 		deUint32 loopheader = (deUint32)ops.size()-1;
425 		nesting++;
426 		loopNesting++;
427 		loopNestingThisFunction++;
428 		pickOP(2);
429 		ops.push_back({OP_END_DO_WHILE_UNIF, loopheader});
430 		loopNestingThisFunction--;
431 		loopNesting--;
432 		nesting--;
433 	}
434 
	// Generate a loop whose trip count varies per invocation
	// (gl_SubgroupInvocationID + 1 iterations; see OP_BEGIN_FOR_VAR).
genForVar()435 	void genForVar()
436 	{
437 		ops.push_back({OP_BEGIN_FOR_VAR, 0});
438 		deUint32 loopheader = (deUint32)ops.size()-1;
439 		nesting++;
440 		loopNesting++;
441 		loopNestingThisFunction++;
442 		pickOP(2);
443 		ops.push_back({OP_END_FOR_VAR, loopheader});
444 		loopNestingThisFunction--;
445 		loopNesting--;
446 		nesting--;
447 	}
448 
	// Generate an "infinite" for loop; genElect(true) guarantees a
	// subgroupElect-guarded break/return inside so it terminates.
genForInf()449 	void genForInf()
450 	{
451 		ops.push_back({OP_BEGIN_FOR_INF, 0});
452 		deUint32 loopheader = (deUint32)ops.size()-1;
453 
454 		nesting++;
455 		loopNesting++;
456 		loopNestingThisFunction++;
457 		isLoopInf[loopNesting] = true;
458 		doneInfLoopBreak[loopNesting] = false;
459 
460 		pickOP(2);
461 
462 		genElect(true);
463 		doneInfLoopBreak[loopNesting] = true;
464 
465 		pickOP(2);
466 
467 		ops.push_back({OP_END_FOR_INF, loopheader});
468 
469 		isLoopInf[loopNesting] = false;
470 		doneInfLoopBreak[loopNesting] = false;
471 		loopNestingThisFunction--;
472 		loopNesting--;
473 		nesting--;
474 	}
475 
	// Same as genForInf but as a do { ... } while (true) loop.
genDoWhileInf()476 	void genDoWhileInf()
477 	{
478 		ops.push_back({OP_BEGIN_DO_WHILE_INF, 0});
479 		deUint32 loopheader = (deUint32)ops.size()-1;
480 
481 		nesting++;
482 		loopNesting++;
483 		loopNestingThisFunction++;
484 		isLoopInf[loopNesting] = true;
485 		doneInfLoopBreak[loopNesting] = false;
486 
487 		pickOP(2);
488 
489 		genElect(true);
490 		doneInfLoopBreak[loopNesting] = true;
491 
492 		pickOP(2);
493 
494 		ops.push_back({OP_END_DO_WHILE_INF, loopheader});
495 
496 		isLoopInf[loopNesting] = false;
497 		doneInfLoopBreak[loopNesting] = false;
498 		loopNestingThisFunction--;
499 		loopNesting--;
500 		nesting--;
501 	}
502 
	// Emit a break, but only when inside a loop of the current function;
	// occasionally wrap it in a divergent if/else with a break on both paths.
genBreak()503 	void genBreak()
504 	{
505 		if (loopNestingThisFunction > 0)
506 		{
507 			// Sometimes put the break in a divergent if
508 			if ((deRandom_getUint32(&rnd) % 100) < 10)
509 			{
510 				ops.push_back({OP_IF_MASK, masks[0]});
511 				ops.push_back({OP_BREAK, 0});
512 				ops.push_back({OP_ELSE_MASK, 0});
513 				ops.push_back({OP_BREAK, 0});
514 				ops.push_back({OP_ENDIF, 0});
515 			}
516 			else
517 				ops.push_back({OP_BREAK, 0});
518 		}
519 	}
520 
	// Emit a continue where legal (never inside an infinite loop, to
	// guarantee forward progress toward its elect+break).
genContinue()521 	void genContinue()
522 	{
523 		// continues are allowed if we're in a loop and the loop is not infinite,
524 		// or if it is infinite and we've already done a subgroupElect+break.
525 		// However, adding more continues seems to reduce the failure rate, so
526 		// disable it for now
527 		if (loopNestingThisFunction > 0 && !(isLoopInf[loopNesting] /*&& !doneInfLoopBreak[loopNesting]*/))
528 		{
529 			// Sometimes put the continue in a divergent if
530 			if ((deRandom_getUint32(&rnd) % 100) < 10)
531 			{
532 				ops.push_back({OP_IF_MASK, masks[0]});
533 				ops.push_back({OP_CONTINUE, 0});
534 				ops.push_back({OP_ELSE_MASK, 0});
535 				ops.push_back({OP_CONTINUE, 0});
536 				ops.push_back({OP_ENDIF, 0});
537 			}
538 			else
539 				ops.push_back({OP_CONTINUE, 0});
540 		}
541 	}
542 
543 	// doBreak is used to generate "if (subgroupElect()) { ... break; }" inside infinite loops
genElect(bool doBreak)544 	void genElect(bool doBreak)
545 	{
546 		ops.push_back({OP_ELECT, 0});
547 		nesting++;
548 		if (doBreak)
549 		{
			// Put something interesting before the break
551 			optBallot();
552 			optBallot();
553 			if ((deRandom_getUint32(&rnd) % 100) < 10)
554 				pickOP(1);
555 
556 			// if we're in a function, sometimes use return instead
557 			if (callNesting > 0 && (deRandom_getUint32(&rnd) % 100) < 30)
558 				ops.push_back({OP_RETURN, 0});
559 			else
560 				genBreak();
561 
562 		}
563 		else
564 			pickOP(2);
565 
		// OP_ENDIF closes the if (subgroupElect()) block.
566 		ops.push_back({OP_ENDIF, 0});
567 		nesting--;
568 	}
569 
	// Emit a return with probability that scales with how deeply nested
	// (function/loop) the current position is.
genReturn()570 	void genReturn()
571 	{
572 		deUint32 r = deRandom_getUint32(&rnd) % 100;
573 		if (nesting > 0 &&
574 			// Use return rarely in main, 20% of the time in a singly nested loop in a function
575 			// and 50% of the time in a multiply nested loop in a function
576 			(r < 5 ||
577 			 (callNesting > 0 && loopNestingThisFunction > 0 && r < 20) ||
578 			 (callNesting > 0 && loopNestingThisFunction > 1 && r < 50)))
579 		{
580 			optBallot();
			// Sometimes put the return in a divergent if/else.
581 			if ((deRandom_getUint32(&rnd) % 100) < 10)
582 			{
583 				ops.push_back({OP_IF_MASK, masks[0]});
584 				ops.push_back({OP_RETURN, 0});
585 				ops.push_back({OP_ELSE_MASK, 0});
586 				ops.push_back({OP_RETURN, 0});
587 				ops.push_back({OP_ENDIF, 0});
588 			}
589 			else
590 				ops.push_back({OP_RETURN, 0});
591 		}
592 	}
593 
594 	// Generate a function call. Save and restore some loop information, which is used to
595 	// determine when it's safe to use break/continue
genCall()596 	void genCall()
597 	{
598 		ops.push_back({OP_CALL_BEGIN, 0});
599 		callNesting++;
600 		nesting++;
		// The called function starts with no enclosing loops of its own.
601 		deInt32 saveLoopNestingThisFunction = loopNestingThisFunction;
602 		loopNestingThisFunction = 0;
603 
604 		pickOP(2);
605 
606 		loopNestingThisFunction = saveLoopNestingThisFunction;
607 		nesting--;
608 		callNesting--;
609 		ops.push_back({OP_CALL_END, 0});
610 	}
611 
612 	// Generate switch on a uniform value:
613 	// switch (inputA.a[r]) {
614 	// case r+1: ... break; // should not execute
615 	// case r:   ... break; // should branch uniformly
616 	// case r+2: ... break; // should not execute
617 	// }
genSwitchUnif()618 	void genSwitchUnif()
619 	{
620 		deUint32 r = deRandom_getUint32(&rnd) % 5;
621 		ops.push_back({OP_SWITCH_UNIF_BEGIN, r});
622 		nesting++;
623 
		// OP_CASE_MASK_BEGIN carries (invocation mask, case-value mask):
		// mask 0 marks a case no invocation should take, ~0ULL one that
		// all take.
624 		ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+1)});
625 		pickOP(1);
626 		ops.push_back({OP_CASE_END, 0});
627 
628 		ops.push_back({OP_CASE_MASK_BEGIN, ~0ULL, 1u<<r});
629 		pickOP(2);
630 		ops.push_back({OP_CASE_END, 0});
631 
632 		ops.push_back({OP_CASE_MASK_BEGIN, 0, 1u<<(r+2)});
633 		pickOP(1);
634 		ops.push_back({OP_CASE_END, 0});
635 
636 		ops.push_back({OP_SWITCH_END, 0});
637 		nesting--;
638 	}
639 
640 	// switch (gl_SubgroupInvocationID & 3) with four unique targets
genSwitchVar()641 	void genSwitchVar()
642 	{
643 		ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
644 		nesting++;
645 
		// Each case is taken by every fourth invocation (lanes where
		// id & 3 == case value).
646 		ops.push_back({OP_CASE_MASK_BEGIN, 0x1111111111111111ULL, 1<<0});
647 		pickOP(1);
648 		ops.push_back({OP_CASE_END, 0});
649 
650 		ops.push_back({OP_CASE_MASK_BEGIN, 0x2222222222222222ULL, 1<<1});
651 		pickOP(1);
652 		ops.push_back({OP_CASE_END, 0});
653 
654 		ops.push_back({OP_CASE_MASK_BEGIN, 0x4444444444444444ULL, 1<<2});
655 		pickOP(1);
656 		ops.push_back({OP_CASE_END, 0});
657 
658 		ops.push_back({OP_CASE_MASK_BEGIN, 0x8888888888888888ULL, 1<<3});
659 		pickOP(1);
660 		ops.push_back({OP_CASE_END, 0});
661 
662 		ops.push_back({OP_SWITCH_END, 0});
663 		nesting--;
664 	}
665 
666 	// switch (gl_SubgroupInvocationID & 3) with two shared targets.
667 	// XXX TODO: The test considers these two targets to remain converged,
668 	// though we haven't agreed to that behavior yet.
genSwitchMulticase()669 	void genSwitchMulticase()
670 	{
671 		ops.push_back({OP_SWITCH_VAR_BEGIN, 0});
672 		nesting++;
673 
674 		ops.push_back({OP_CASE_MASK_BEGIN, 0x3333333333333333ULL, (1<<0)|(1<<1)});
675 		pickOP(2);
676 		ops.push_back({OP_CASE_END, 0});
677 
678 		ops.push_back({OP_CASE_MASK_BEGIN, 0xCCCCCCCCCCCCCCCCULL, (1<<2)|(1<<3)});
679 		pickOP(2);
680 		ops.push_back({OP_CASE_END, 0});
681 
682 		ops.push_back({OP_SWITCH_END, 0});
683 		nesting--;
684 	}
685 
686 	// switch (loopIdxN) {
687 	// case 1:  ... break;
688 	// case 2:  ... break;
689 	// default: ... break;
690 	// }
genSwitchLoopCount()691 	void genSwitchLoopCount()
692 	{
693 		deUint32 r = deRandom_getUint32(&rnd) % loopNesting;
694 		ops.push_back({OP_SWITCH_LOOP_COUNT_BEGIN, r});
695 		nesting++;
696 
		// OP_CASE_LOOP_COUNT_BEGIN carries (mask of loop iterations that
		// take this case, case value).
697 		ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<1, 1});
698 		pickOP(1);
699 		ops.push_back({OP_CASE_END, 0});
700 
701 		ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, 1ULL<<2, 2});
702 		pickOP(1);
703 		ops.push_back({OP_CASE_END, 0});
704 
705 		// default:
		// ~6ULL covers every iteration other than 1 and 2.
706 		ops.push_back({OP_CASE_LOOP_COUNT_BEGIN, ~6ULL, 0xFFFFFFFF});
707 		pickOP(1);
708 		ops.push_back({OP_CASE_END, 0});
709 
710 		ops.push_back({OP_SWITCH_END, 0});
711 		nesting--;
712 	}
713 
	// Core generator driver: randomly dispatches to the gen* helpers
	// above, bracketing each pick with optional ballot/store/noise ops.
pickOP(deUint32 count)714 	void pickOP(deUint32 count)
715 	{
716 		// Pick "count" instructions. These can recursively insert more instructions,
717 		// so "count" is just a seed
718 		for (deUint32 i = 0; i < count; ++i)
719 		{
720 			optBallot();
721 			if (nesting < maxNesting)
722 			{
				// Weighted choice among the generator helpers; some
				// cases fall through when their precondition fails.
723 				deUint32 r = deRandom_getUint32(&rnd) % 11;
724 				switch (r)
725 				{
726 				default:
727 					DE_ASSERT(0);
728 					// fallthrough
729 				case 2:
730 					if (loopNesting)
731 					{
732 						genIf(IF_LOOPCOUNT);
733 						break;
734 					}
735 					// fallthrough
736 				case 10:
737 					genIf(IF_LOCAL_INVOCATION_INDEX);
738 					break;
739 				case 0:
740 					genIf(IF_MASK);
741 					break;
742 				case 1:
743 					genIf(IF_UNIFORM);
744 					break;
745 				case 3:
746 					{
747 						// don't nest loops too deeply, to avoid extreme memory usage or timeouts
748 						if (loopNesting <= 3)
749 						{
750 							deUint32 r2 = deRandom_getUint32(&rnd) % 3;
751 							switch (r2)
752 							{
753 							default: DE_ASSERT(0); // fallthrough
754 							case 0: genForUnif(); break;
755 							case 1: genForInf(); break;
756 							case 2: genForVar(); break;
757 							}
758 						}
759 					}
760 					break;
761 				case 4:
762 					genBreak();
763 					break;
764 				case 5:
765 					genContinue();
766 					break;
767 				case 6:
768 					genElect(false);
769 					break;
770 				case 7:
771 					{
772 						deUint32 r2 = deRandom_getUint32(&rnd) % 5;
773 						if (r2 == 0 && callNesting == 0 && nesting < maxNesting - 2)
774 							genCall();
775 						else
776 							genReturn();
777 						break;
778 					}
779 				case 8:
780 					{
781 						// don't nest loops too deeply, to avoid extreme memory usage or timeouts
782 						if (loopNesting <= 3)
783 						{
784 							deUint32 r2 = deRandom_getUint32(&rnd) % 2;
785 							switch (r2)
786 							{
787 							default: DE_ASSERT(0); // fallthrough
788 							case 0: genDoWhileUnif(); break;
789 							case 1: genDoWhileInf(); break;
790 							}
791 						}
792 					}
793 					break;
794 				case 9:
795 					{
796 						deUint32 r2 = deRandom_getUint32(&rnd) % 4;
797 						switch (r2)
798 						{
799 						default:
800 							DE_ASSERT(0);
801 							// fallthrough
802 						case 0:
803 							genSwitchUnif();
804 							break;
805 						case 1:
806 							if (loopNesting > 0) {
807 								genSwitchLoopCount();
808 								break;
809 							}
810 							// fallthrough
811 						case 2:
812 							if (caseDef.testType != TT_MAXIMAL)
813 							{
814 								// multicase doesn't have fully-defined behavior for MAXIMAL tests,
815 								// but does for SUCF tests
816 								genSwitchMulticase();
817 								break;
818 							}
819 							// fallthrough
820 						case 3:
821 							genSwitchVar();
822 							break;
823 						}
824 					}
825 					break;
826 				}
827 			}
828 			optBallot();
829 		}
830 	}
831 
optBallot()832 	void optBallot()
833 	{
834 		// optionally insert ballots, stores, and noise. Ballots and stores are used to determine
835 		// correctness.
836 		if ((deRandom_getUint32(&rnd) % 100) < 20)
837 		{
			// Avoid emitting back-to-back ballot(+store) pairs.
838 			if (ops.size() < 2 ||
839 			   !(ops[ops.size()-1].type == OP_BALLOT ||
840 				 (ops[ops.size()-1].type == OP_STORE && ops[ops.size()-2].type == OP_BALLOT)))
841 			{
842 				// do a store along with each ballot, so we can correlate where
843 				// the ballot came from
844 				if (caseDef.testType != TT_MAXIMAL)
845 					ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
846 				ops.push_back({OP_BALLOT, 0});
847 			}
848 		}
849 
850 		if ((deRandom_getUint32(&rnd) % 100) < 10)
851 		{
852 			if (ops.size() < 2 ||
853 			   !(ops[ops.size()-1].type == OP_STORE ||
854 				 (ops[ops.size()-1].type == OP_BALLOT && ops[ops.size()-2].type == OP_STORE)))
855 			{
856 				// SUCF does a store with every ballot. Don't bloat the code by adding more.
857 				if (caseDef.testType == TT_MAXIMAL)
858 					ops.push_back({OP_STORE, (deUint32)ops.size() + storeBase});
859 			}
860 		}
861 
		// Rarely insert functional no-op "noise" code (see OP_NOISE).
862 		deUint32 r = deRandom_getUint32(&rnd) % 10000;
863 		if (r < 3)
864 			ops.push_back({OP_NOISE, 0});
865 		else if (r < 10)
866 			ops.push_back({OP_NOISE, 1});
867 	}
868 
	// Generate at least minCount ops; for UCF variants, keep regenerating
	// until the program contains some uniform-control-flow results to check.
generateRandomProgram()869 	void generateRandomProgram()
870 	{
871 		do {
872 			ops.clear();
873 			while ((deInt32)ops.size() < minCount)
874 				pickOP(1);
875 
876 			// Retry until the program has some UCF results in it
877 			if (caseDef.isUCF())
878 			{
879 				const deUint32 invocationStride = 128;
880 				// Simulate for all subgroup sizes, to determine whether OP_BALLOTs are nonuniform
881 				for (deInt32 subgroupSize = 4; subgroupSize <= 64; subgroupSize *= 2) {
882 					simulate(true, subgroupSize, invocationStride, DE_NULL);
883 				}
884 			}
885 		} while (caseDef.isUCF() && !hasUCF());
886 	}
887 
printIndent(std::stringstream & css)888 	void printIndent(std::stringstream &css)
889 	{
890 		for (deInt32 i = 0; i < indent; ++i)
891 			css << " ";
892 	}
893 
genPartitionBallot()894 	std::string genPartitionBallot()
895 	{
896 		std::stringstream ss;
897 		ss << "subgroupBallot(true).xy";
898 		return ss.str();
899 	}
900 
	// Print the GLSL for one OP_BALLOT: bump the per-invocation output
	// counter, then record the ballot (or elect() result) into outputB.
printBallot(std::stringstream * css)901 	void printBallot(std::stringstream *css)
902 	{
903 		*css << "outputC.loc[gl_LocalInvocationIndex]++,";
904 		// When inside loop(s), use partitionBallot rather than subgroupBallot to compute
905 		// a ballot, to make sure the ballot is "diverged enough". Don't do this for
906 		// subgroup_uniform_control_flow, since we only validate results that must be fully
907 		// reconverged.
908 		if (loopNesting > 0 && caseDef.testType == TT_MAXIMAL)
909 		{
910 			*css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = " << genPartitionBallot();
911 		}
912 		else if (caseDef.isElect())
913 		{
914 			*css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = elect()";
915 		}
916 		else
917 		{
918 			*css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex] = subgroupBallot(true).xy";
919 		}
920 	}
921 
genCode(std::stringstream & functions,std::stringstream & main)922 	void genCode(std::stringstream &functions, std::stringstream &main)
923 	{
924 		std::stringstream *css = &main;
925 		indent = 4;
926 		loopNesting = 0;
927 		int funcNum = 0;
928 		for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
929 		{
930 			switch (ops[i].type)
931 			{
932 			case OP_IF_MASK:
933 				printIndent(*css);
934 				if (ops[i].value == ~0ULL)
935 				{
936 					// This equality test will always succeed, since inputA.a[i] == i
937 					int idx = deRandom_getUint32(&rnd) % 4;
938 					*css << "if (inputA.a[" << idx << "] == " << idx << ") {\n";
939 				}
940 				else
941 					*css << "if (testBit(uvec2(0x" << std::hex << (ops[i].value & 0xFFFFFFFF) << ", 0x" << (ops[i].value >> 32) << "), gl_SubgroupInvocationID)) {\n";
942 
943 				indent += 4;
944 				break;
945 			case OP_IF_LOOPCOUNT:
946 				printIndent(*css); *css << "if (gl_SubgroupInvocationID == loopIdx" << loopNesting - 1 << ") {\n";
947 				indent += 4;
948 				break;
949 			case OP_IF_LOCAL_INVOCATION_INDEX:
950 				printIndent(*css); *css << "if (gl_LocalInvocationIndex >= inputA.a[0x" << std::hex << ops[i].value << "]) {\n";
951 				indent += 4;
952 				break;
953 			case OP_ELSE_MASK:
954 			case OP_ELSE_LOOPCOUNT:
955 			case OP_ELSE_LOCAL_INVOCATION_INDEX:
956 				indent -= 4;
957 				printIndent(*css); *css << "} else {\n";
958 				indent += 4;
959 				break;
960 			case OP_ENDIF:
961 				indent -= 4;
962 				printIndent(*css); *css << "}\n";
963 				break;
964 			case OP_BALLOT:
965 				printIndent(*css); printBallot(css); *css << ";\n";
966 				break;
967 			case OP_STORE:
968 				printIndent(*css); *css << "outputC.loc[gl_LocalInvocationIndex]++;\n";
969 				printIndent(*css); *css << "outputB.b[(outLoc++)*invocationStride + gl_LocalInvocationIndex].x = 0x" << std::hex << ops[i].value << ";\n";
970 				break;
971 			case OP_BEGIN_FOR_UNIF:
972 				printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
973 				printIndent(*css); *css << "         loopIdx" << loopNesting << " < inputA.a[" << ops[i].value << "];\n";
974 				printIndent(*css); *css << "         loopIdx" << loopNesting << "++) {\n";
975 				indent += 4;
976 				loopNesting++;
977 				break;
978 			case OP_END_FOR_UNIF:
979 				loopNesting--;
980 				indent -= 4;
981 				printIndent(*css); *css << "}\n";
982 				break;
983 			case OP_BEGIN_DO_WHILE_UNIF:
984 				printIndent(*css); *css << "{\n";
985 				indent += 4;
986 				printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
987 				printIndent(*css); *css << "do {\n";
988 				indent += 4;
989 				printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
990 				loopNesting++;
991 				break;
992 			case OP_BEGIN_DO_WHILE_INF:
993 				printIndent(*css); *css << "{\n";
994 				indent += 4;
995 				printIndent(*css); *css << "int loopIdx" << loopNesting << " = 0;\n";
996 				printIndent(*css); *css << "do {\n";
997 				indent += 4;
998 				loopNesting++;
999 				break;
1000 			case OP_END_DO_WHILE_UNIF:
1001 				loopNesting--;
1002 				indent -= 4;
1003 				printIndent(*css); *css << "} while (loopIdx" << loopNesting << " < inputA.a[" << ops[(deUint32)ops[i].value].value << "]);\n";
1004 				indent -= 4;
1005 				printIndent(*css); *css << "}\n";
1006 				break;
1007 			case OP_END_DO_WHILE_INF:
1008 				loopNesting--;
1009 				printIndent(*css); *css << "loopIdx" << loopNesting << "++;\n";
1010 				indent -= 4;
1011 				printIndent(*css); *css << "} while (true);\n";
1012 				indent -= 4;
1013 				printIndent(*css); *css << "}\n";
1014 				break;
1015 			case OP_BEGIN_FOR_VAR:
1016 				printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;\n";
1017 				printIndent(*css); *css << "         loopIdx" << loopNesting << " < gl_SubgroupInvocationID + 1;\n";
1018 				printIndent(*css); *css << "         loopIdx" << loopNesting << "++) {\n";
1019 				indent += 4;
1020 				loopNesting++;
1021 				break;
1022 			case OP_END_FOR_VAR:
1023 				loopNesting--;
1024 				indent -= 4;
1025 				printIndent(*css); *css << "}\n";
1026 				break;
1027 			case OP_BEGIN_FOR_INF:
1028 				printIndent(*css); *css << "for (int loopIdx" << loopNesting << " = 0;;loopIdx" << loopNesting << "++,";
1029 				loopNesting++;
1030 				printBallot(css);
1031 				*css << ") {\n";
1032 				indent += 4;
1033 				break;
1034 			case OP_END_FOR_INF:
1035 				loopNesting--;
1036 				indent -= 4;
1037 				printIndent(*css); *css << "}\n";
1038 				break;
1039 			case OP_BREAK:
1040 				printIndent(*css); *css << "break;\n";
1041 				break;
1042 			case OP_CONTINUE:
1043 				printIndent(*css); *css << "continue;\n";
1044 				break;
1045 			case OP_ELECT:
1046 				printIndent(*css); *css << "if (subgroupElect()) {\n";
1047 				indent += 4;
1048 				break;
1049 			case OP_RETURN:
1050 				printIndent(*css); *css << "return;\n";
1051 				break;
1052 			case OP_CALL_BEGIN:
1053 				printIndent(*css); *css << "func" << funcNum << "(";
1054 				for (deInt32 n = 0; n < loopNesting; ++n)
1055 				{
1056 					*css << "loopIdx" << n;
1057 					if (n != loopNesting - 1)
1058 						*css << ", ";
1059 				}
1060 				*css << ");\n";
1061 				css = &functions;
1062 				printIndent(*css); *css << "void func" << funcNum << "(";
1063 				for (deInt32 n = 0; n < loopNesting; ++n)
1064 				{
1065 					*css << "int loopIdx" << n;
1066 					if (n != loopNesting - 1)
1067 						*css << ", ";
1068 				}
1069 				*css << ") {\n";
1070 				indent += 4;
1071 				funcNum++;
1072 				break;
1073 			case OP_CALL_END:
1074 				indent -= 4;
1075 				printIndent(*css); *css << "}\n";
1076 				css = &main;
1077 				break;
1078 			case OP_NOISE:
1079 				if (ops[i].value == 0)
1080 				{
1081 					printIndent(*css); *css << "while (!subgroupElect()) {}\n";
1082 				}
1083 				else
1084 				{
1085 					printIndent(*css); *css << "if (inputA.a[0] == 12345) {\n";
1086 					indent += 4;
1087 					printIndent(*css); *css << "while (true) {\n";
1088 					indent += 4;
1089 					printIndent(*css); printBallot(css); *css << ";\n";
1090 					indent -= 4;
1091 					printIndent(*css); *css << "}\n";
1092 					indent -= 4;
1093 					printIndent(*css); *css << "}\n";
1094 				}
1095 				break;
1096 			case OP_SWITCH_UNIF_BEGIN:
1097 				printIndent(*css); *css << "switch (inputA.a[" << ops[i].value << "]) {\n";
1098 				indent += 4;
1099 				break;
1100 			case OP_SWITCH_VAR_BEGIN:
1101 				printIndent(*css); *css << "switch (gl_SubgroupInvocationID & 3) {\n";
1102 				indent += 4;
1103 				break;
1104 			case OP_SWITCH_LOOP_COUNT_BEGIN:
1105 				printIndent(*css); *css << "switch (loopIdx" << ops[i].value << ") {\n";
1106 				indent += 4;
1107 				break;
1108 			case OP_SWITCH_END:
1109 				indent -= 4;
1110 				printIndent(*css); *css << "}\n";
1111 				break;
1112 			case OP_CASE_MASK_BEGIN:
1113 				for (deInt32 b = 0; b < 32; ++b)
1114 				{
1115 					if ((1u<<b) & ops[i].caseValue)
1116 					{
1117 						printIndent(*css); *css << "case " << b << ":\n";
1118 					}
1119 				}
1120 				printIndent(*css); *css << "{\n";
1121 				indent += 4;
1122 				break;
1123 			case OP_CASE_LOOP_COUNT_BEGIN:
1124 				if (ops[i].caseValue == 0xFFFFFFFF)
1125 				{
1126 					printIndent(*css); *css << "default: {\n";
1127 				}
1128 				else
1129 				{
1130 					printIndent(*css); *css << "case " << ops[i].caseValue << ": {\n";
1131 				}
1132 				indent += 4;
1133 				break;
1134 			case OP_CASE_END:
1135 				printIndent(*css); *css << "break;\n";
1136 				indent -= 4;
1137 				printIndent(*css); *css << "}\n";
1138 				break;
1139 			default:
1140 				DE_ASSERT(0);
1141 				break;
1142 			}
1143 		}
1144 	}
1145 
	// Simulate execution of the program. If countOnly is true, just return
	// the max number of outputs written. If it's false, store out the result
	// values to ref.
	//
	// subgroupSize     - invocations per subgroup; used to slice the 128-wide
	//                    invocation masks into per-subgroup ballot values
	// invocationStride - element stride between successive outputs of a single
	//                    invocation in 'ref'
	// ref              - expected-value buffer (ballots and store markers);
	//                    may be null when countOnly is true
	//
	// Returns the maximum number of output locations written by any invocation.
	deUint32 simulate(bool countOnly, deUint32 subgroupSize, deUint32 invocationStride, deUint64 *ref)
	{
		// State of the subgroup at each level of nesting
		struct SubgroupState
		{
			// Currently executing
			bitset128 activeMask;
			// Have executed a continue instruction in this loop
			bitset128 continueMask;
			// index of the current if test or loop header
			deUint32 header;
			// number of loop iterations performed
			deUint32 tripCount;
			// is this nesting a loop?
			deUint32 isLoop;
			// is this nesting a function call?
			deUint32 isCall;
			// is this nesting a switch?
			deUint32 isSwitch;
		};
		// Fixed-depth stack of control-flow nesting state (zero-initialized).
		SubgroupState stateStack[10];
		deMemset(&stateStack, 0, sizeof(stateStack));

		const deUint64 fullSubgroupMask = subgroupSizeToMask(subgroupSize);

		// Per-invocation output location counters
		deUint32 outLoc[128] = {0};

		nesting = 0;
		loopNesting = 0;
		stateStack[nesting].activeMask = ~bitset128(); // initialized to ~0

		// Walk the op list, maintaining the active mask per nesting level.
		// Loop back-edges are simulated by resetting 'i' to the loop header.
		deInt32 i = 0;
		while (i < (deInt32)ops.size())
		{
			switch (ops[i].type)
			{
			case OP_BALLOT:

				// Flag that this ballot is workgroup-nonuniform
				if (caseDef.isWUCF() && stateStack[nesting].activeMask.any() && !stateStack[nesting].activeMask.all())
					ops[i].caseValue = 1;

				if (caseDef.isSUCF())
				{
					// Check each subgroup-sized slice of the active mask.
					for (deUint32 id = 0; id < 128; id += subgroupSize)
					{
						deUint64 subgroupMask = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
						// Flag that this ballot is subgroup-nonuniform
						if (subgroupMask != 0 && subgroupMask != fullSubgroupMask)
							ops[i].caseValue = 1;
					}
				}

				// Every active invocation writes (or counts) one ballot result.
				for (deUint32 id = 0; id < 128; ++id)
				{
					if (stateStack[nesting].activeMask.test(id))
					{
						if (countOnly)
						{
							outLoc[id]++;
						}
						else
						{
							if (ops[i].caseValue)
							{
								// Emit a magic value to indicate that we shouldn't validate this ballot
								ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(0x12345678, subgroupSize, id);
							}
							else
								ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
						}
					}
				}
				break;
			case OP_STORE:
				// Each active invocation records the op's literal value.
				for (deUint32 id = 0; id < 128; ++id)
				{
					if (stateStack[nesting].activeMask.test(id))
					{
						if (countOnly)
							outLoc[id]++;
						else
							ref[(outLoc[id]++)*invocationStride + id] = ops[i].value;
					}
				}
				break;
			case OP_IF_MASK:
				// Enter an 'if' whose condition is a literal invocation mask.
				nesting++;
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
				stateStack[nesting].header = i;
				stateStack[nesting].isLoop = 0;
				stateStack[nesting].isSwitch = 0;
				break;
			case OP_ELSE_MASK:
				// Complement of the matching OP_IF_MASK condition (found via header).
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64(ops[stateStack[nesting].header].value, subgroupSize);
				break;
			case OP_IF_LOOPCOUNT:
				{
					// Condition selects the single invocation whose subgroup id
					// equals the innermost loop's current trip count.
					deUint32 n = nesting;
					while (!stateStack[n].isLoop)
						n--;

					nesting++;
					stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
					stateStack[nesting].header = i;
					stateStack[nesting].isLoop = 0;
					stateStack[nesting].isSwitch = 0;
					break;
				}
			case OP_ELSE_LOOPCOUNT:
				{
					// Complement of OP_IF_LOOPCOUNT at the same loop level.
					deUint32 n = nesting;
					while (!stateStack[n].isLoop)
						n--;

					stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & ~bitsetFromU64((1ULL << stateStack[n].tripCount), subgroupSize);
					break;
				}
			case OP_IF_LOCAL_INVOCATION_INDEX:
				{
					// all bits >= N
					bitset128 mask(0);
					for (deInt32 j = (deInt32)ops[i].value; j < 128; ++j)
						mask.set(j);

					nesting++;
					stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
					stateStack[nesting].header = i;
					stateStack[nesting].isLoop = 0;
					stateStack[nesting].isSwitch = 0;
					break;
				}
			case OP_ELSE_LOCAL_INVOCATION_INDEX:
				{
					// all bits < N
					bitset128 mask(0);
					for (deInt32 j = 0; j < (deInt32)ops[i].value; ++j)
						mask.set(j);

					stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & mask;
					break;
				}
			case OP_ENDIF:
				// Pop the 'if' nesting level; parent mask becomes active again.
				nesting--;
				break;
			case OP_BEGIN_FOR_UNIF:
				// XXX TODO: We don't handle a for loop with zero iterations
				nesting++;
				loopNesting++;
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
				stateStack[nesting].header = i;
				stateStack[nesting].tripCount = 0;
				stateStack[nesting].isLoop = 1;
				stateStack[nesting].isSwitch = 0;
				stateStack[nesting].continueMask = 0;
				break;
			case OP_END_FOR_UNIF:
				// Uniform trip count: loop back while tripCount < a[value]
				// and anyone is still active. Re-activate 'continue'd invocations.
				stateStack[nesting].tripCount++;
				stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
				stateStack[nesting].continueMask = 0;
				if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
					stateStack[nesting].activeMask.any())
				{
					i = stateStack[nesting].header+1;
					continue;
				}
				else
				{
					loopNesting--;
					nesting--;
				}
				break;
			case OP_BEGIN_DO_WHILE_UNIF:
				// XXX TODO: We don't handle a for loop with zero iterations
				// do/while: trip count starts at 1 (body always runs once).
				nesting++;
				loopNesting++;
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
				stateStack[nesting].header = i;
				stateStack[nesting].tripCount = 1;
				stateStack[nesting].isLoop = 1;
				stateStack[nesting].isSwitch = 0;
				stateStack[nesting].continueMask = 0;
				break;
			case OP_END_DO_WHILE_UNIF:
				// Note tripCount increments only when taking the back-edge,
				// unlike the 'for' variants which increment every time.
				stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
				stateStack[nesting].continueMask = 0;
				if (stateStack[nesting].tripCount < ops[stateStack[nesting].header].value &&
					stateStack[nesting].activeMask.any())
				{
					i = stateStack[nesting].header+1;
					stateStack[nesting].tripCount++;
					continue;
				}
				else
				{
					loopNesting--;
					nesting--;
				}
				break;
			case OP_BEGIN_FOR_VAR:
				// XXX TODO: We don't handle a for loop with zero iterations
				// Loop bound varies per-invocation (gl_SubgroupInvocationID + 1).
				nesting++;
				loopNesting++;
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
				stateStack[nesting].header = i;
				stateStack[nesting].tripCount = 0;
				stateStack[nesting].isLoop = 1;
				stateStack[nesting].isSwitch = 0;
				stateStack[nesting].continueMask = 0;
				break;
			case OP_END_FOR_VAR:
				// Invocation id exits the loop once tripCount exceeds its id:
				// mask off all invocations with id < tripCount.
				stateStack[nesting].tripCount++;
				stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
				stateStack[nesting].continueMask = 0;
				stateStack[nesting].activeMask &= bitsetFromU64(stateStack[nesting].tripCount == subgroupSize ? 0 : ~((1ULL << (stateStack[nesting].tripCount)) - 1), subgroupSize);
				if (stateStack[nesting].activeMask.any())
				{
					i = stateStack[nesting].header+1;
					continue;
				}
				else
				{
					loopNesting--;
					nesting--;
				}
				break;
			case OP_BEGIN_FOR_INF:
			case OP_BEGIN_DO_WHILE_INF:
				// "Infinite" loops: exit only via break/return inside the body.
				nesting++;
				loopNesting++;
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
				stateStack[nesting].header = i;
				stateStack[nesting].tripCount = 0;
				stateStack[nesting].isLoop = 1;
				stateStack[nesting].isSwitch = 0;
				stateStack[nesting].continueMask = 0;
				break;
			case OP_END_FOR_INF:
				stateStack[nesting].tripCount++;
				stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
				stateStack[nesting].continueMask = 0;
				if (stateStack[nesting].activeMask.any())
				{
					// output expected OP_BALLOT values
					// (the generated for-inf loop performs a ballot in its
					// loop-increment expression, so mirror that here)
					for (deUint32 id = 0; id < 128; ++id)
					{
						if (stateStack[nesting].activeMask.test(id))
						{
							if (countOnly)
								outLoc[id]++;
							else
								ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id);
						}
					}

					i = stateStack[nesting].header+1;
					continue;
				}
				else
				{
					loopNesting--;
					nesting--;
				}
				break;
			case OP_END_DO_WHILE_INF:
				stateStack[nesting].tripCount++;
				stateStack[nesting].activeMask |= stateStack[nesting].continueMask;
				stateStack[nesting].continueMask = 0;
				if (stateStack[nesting].activeMask.any())
				{
					i = stateStack[nesting].header+1;
					continue;
				}
				else
				{
					loopNesting--;
					nesting--;
				}
				break;
			case OP_BREAK:
				{
					// Deactivate the breaking invocations at every nesting
					// level up to and including the innermost loop or switch.
					deUint32 n = nesting;
					bitset128 mask = stateStack[nesting].activeMask;
					while (true)
					{
						stateStack[n].activeMask &= ~mask;
						if (stateStack[n].isLoop || stateStack[n].isSwitch)
							break;

						n--;
					}
				}
				break;
			case OP_CONTINUE:
				{
					// Deactivate up to the innermost loop, but remember the
					// invocations in continueMask so the loop end re-activates them.
					deUint32 n = nesting;
					bitset128 mask = stateStack[nesting].activeMask;
					while (true)
					{
						stateStack[n].activeMask &= ~mask;
						if (stateStack[n].isLoop)
						{
							stateStack[n].continueMask |= mask;
							break;
						}
						n--;
					}
				}
				break;
			case OP_ELECT:
				{
					// Only the lowest active invocation of each subgroup continues.
					nesting++;
					stateStack[nesting].activeMask = bitsetElect(stateStack[nesting-1].activeMask, subgroupSize);
					stateStack[nesting].header = i;
					stateStack[nesting].isLoop = 0;
					stateStack[nesting].isSwitch = 0;
				}
				break;
			case OP_RETURN:
				{
					// Deactivate the returning invocations at every level up to
					// and including the innermost function call.
					bitset128 mask = stateStack[nesting].activeMask;
					for (deInt32 n = nesting; n >= 0; --n)
					{
						stateStack[n].activeMask &= ~mask;
						if (stateStack[n].isCall)
							break;
					}
				}
				break;

			case OP_CALL_BEGIN:
				// Function call inherits the caller's active mask.
				nesting++;
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
				stateStack[nesting].isLoop = 0;
				stateStack[nesting].isSwitch = 0;
				stateStack[nesting].isCall = 1;
				break;
			case OP_CALL_END:
				stateStack[nesting].isCall = 0;
				nesting--;
				break;
			case OP_NOISE:
				// Noise ops have no observable effect on the expected output.
				break;

			case OP_SWITCH_UNIF_BEGIN:
			case OP_SWITCH_VAR_BEGIN:
			case OP_SWITCH_LOOP_COUNT_BEGIN:
				// Enter a switch; case labels refine the active mask below.
				nesting++;
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
				stateStack[nesting].header = i;
				stateStack[nesting].isLoop = 0;
				stateStack[nesting].isSwitch = 1;
				break;
			case OP_SWITCH_END:
				nesting--;
				break;
			case OP_CASE_MASK_BEGIN:
				// Case selects a precomputed mask of invocations.
				stateStack[nesting].activeMask = stateStack[nesting-1].activeMask & bitsetFromU64(ops[i].value, subgroupSize);
				break;
			case OP_CASE_LOOP_COUNT_BEGIN:
				{
					// Find the loop whose nesting depth matches the switch
					// header's value, then test its trip count against this
					// case's value bitmask.
					deUint32 n = nesting;
					deUint32 l = loopNesting;

					while (true)
					{
						if (stateStack[n].isLoop)
						{
							l--;
							if (l == ops[stateStack[nesting].header].value)
								break;
						}
						n--;
					}

					if ((1ULL << stateStack[n].tripCount) & ops[i].value)
						stateStack[nesting].activeMask = stateStack[nesting-1].activeMask;
					else
						stateStack[nesting].activeMask = 0;
					break;
				}
			case OP_CASE_END:
				break;

			default:
				DE_ASSERT(0);
				break;
			}
			i++;
		}
		// Result is the largest per-invocation output count.
		deUint32 maxLoc = 0;
		for (deUint32 id = 0; id < ARRAYSIZE(outLoc); ++id)
			maxLoc = de::max(maxLoc, outLoc[id]);

		return maxLoc;
	}
1546 
hasUCF() const1547 	bool hasUCF() const
1548 	{
1549 		for (deInt32 i = 0; i < (deInt32)ops.size(); ++i)
1550 		{
1551 			if (ops[i].type == OP_BALLOT && ops[i].caseValue == 0)
1552 				return true;
1553 		}
1554 		return false;
1555 	}
1556 };
1557 
initPrograms(SourceCollections & programCollection) const1558 void ReconvergenceTestCase::initPrograms (SourceCollections& programCollection) const
1559 {
1560 	RandomProgram program(m_data);
1561 	program.generateRandomProgram();
1562 
1563 	std::stringstream css;
1564 	css << "#version 450 core\n";
1565 	css << "#extension GL_KHR_shader_subgroup_ballot : enable\n";
1566 	css << "#extension GL_KHR_shader_subgroup_vote : enable\n";
1567 	css << "#extension GL_NV_shader_subgroup_partitioned : enable\n";
1568 	css << "#extension GL_EXT_subgroup_uniform_control_flow : enable\n";
1569 	css << "layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n";
1570 	css << "layout(set=0, binding=0) coherent buffer InputA { uint a[]; } inputA;\n";
1571 	css << "layout(set=0, binding=1) coherent buffer OutputB { uvec2 b[]; } outputB;\n";
1572 	css << "layout(set=0, binding=2) coherent buffer OutputC { uint loc[]; } outputC;\n";
1573 	css << "layout(push_constant) uniform PC {\n"
1574 			"   // set to the real stride when writing out ballots, or zero when just counting\n"
1575 			"   int invocationStride;\n"
1576 			"};\n";
1577 	css << "int outLoc = 0;\n";
1578 
1579 	css << "bool testBit(uvec2 mask, uint bit) { return (bit < 32) ? ((mask.x >> bit) & 1) != 0 : ((mask.y >> (bit-32)) & 1) != 0; }\n";
1580 
1581 	css << "uint elect() { return int(subgroupElect()) + 1; }\n";
1582 
1583 	std::stringstream functions, main;
1584 	program.genCode(functions, main);
1585 
1586 	css << functions.str() << "\n\n";
1587 
1588 	css <<
1589 		"void main()\n"
1590 		<< (m_data.isSUCF() ? "[[subgroup_uniform_control_flow]]\n" : "") <<
1591 		"{\n";
1592 
1593 	css << main.str() << "\n\n";
1594 
1595 	css << "}\n";
1596 
1597 	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
1598 
1599 	programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
1600 }
1601 
createInstance(Context & context) const1602 TestInstance* ReconvergenceTestCase::createInstance (Context& context) const
1603 {
1604 	return new ReconvergenceTestInstance(context, m_data);
1605 }
1606 
iterate(void)1607 tcu::TestStatus ReconvergenceTestInstance::iterate (void)
1608 {
1609 	const DeviceInterface&	vk						= m_context.getDeviceInterface();
1610 	const VkDevice			device					= m_context.getDevice();
1611 	Allocator&				allocator				= m_context.getDefaultAllocator();
1612 	tcu::TestLog&			log						= m_context.getTestContext().getLog();
1613 
1614 	deRandom rnd;
1615 	deRandom_init(&rnd, m_data.seed);
1616 
1617 	vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
1618 	deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
1619 	subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
1620 
1621 	vk::VkPhysicalDeviceProperties2 properties2;
1622 	deMemset(&properties2, 0, sizeof(properties2));
1623 	properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
1624 	properties2.pNext = &subgroupProperties;
1625 
1626 	m_context.getInstanceInterface().getPhysicalDeviceProperties2(m_context.getPhysicalDevice(), &properties2);
1627 
1628 	const deUint32 subgroupSize = subgroupProperties.subgroupSize;
1629 	const deUint32 invocationStride = 128;
1630 
1631 	if (subgroupSize > 64)
1632 		TCU_THROW(TestError, "Subgroup size greater than 64 not handled.");
1633 
1634 	RandomProgram program(m_data);
1635 	program.generateRandomProgram();
1636 
1637 	deUint32 maxLoc = program.simulate(true, subgroupSize, invocationStride, DE_NULL);
1638 
1639 	// maxLoc is per-invocation. Add one (to make sure no additional writes are done) and multiply by
1640 	// the number of invocations
1641 	maxLoc++;
1642 	maxLoc *= invocationStride;
1643 
1644 	// buffer[0] is an input filled with a[i] == i
1645 	// buffer[1] is the output
1646 	// buffer[2] is the location counts
1647 	de::MovePtr<BufferWithMemory> buffers[3];
1648 	vk::VkDescriptorBufferInfo bufferDescriptors[3];
1649 
1650 	VkDeviceSize sizes[3] =
1651 	{
1652 		128 * sizeof(deUint32),
1653 		maxLoc * sizeof(deUint64),
1654 		invocationStride * sizeof(deUint32),
1655 	};
1656 
1657 	for (deUint32 i = 0; i < 3; ++i)
1658 	{
1659 		if (sizes[i] > properties2.properties.limits.maxStorageBufferRange)
1660 			TCU_THROW(NotSupportedError, "Storage buffer size larger than device limits");
1661 
1662 		try
1663 		{
1664 			buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1665 				vk, device, allocator, makeBufferCreateInfo(sizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
1666 				MemoryRequirement::HostVisible | MemoryRequirement::Cached));
1667 		}
1668 		catch(tcu::ResourceError&)
1669 		{
1670 			// Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1671 			return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[i]) + " bytes");
1672 		}
1673 		bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, sizes[i]);
1674 	}
1675 
1676 	deUint32 *ptrs[3];
1677 	for (deUint32 i = 0; i < 3; ++i)
1678 	{
1679 		ptrs[i] = (deUint32 *)buffers[i]->getAllocation().getHostPtr();
1680 	}
1681 	for (deUint32 i = 0; i < sizes[0] / sizeof(deUint32); ++i)
1682 	{
1683 		ptrs[0][i] = i;
1684 	}
1685 	deMemset(ptrs[1], 0, (size_t)sizes[1]);
1686 	deMemset(ptrs[2], 0, (size_t)sizes[2]);
1687 
1688 	vk::DescriptorSetLayoutBuilder layoutBuilder;
1689 
1690 	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1691 	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1692 	layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1693 
1694 	vk::Unique<vk::VkDescriptorSetLayout>	descriptorSetLayout(layoutBuilder.build(vk, device));
1695 
1696 	vk::Unique<vk::VkDescriptorPool>		descriptorPool(vk::DescriptorPoolBuilder()
1697 		.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 3u)
1698 		.build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
1699 	vk::Unique<vk::VkDescriptorSet>			descriptorSet		(makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));
1700 
1701 	const deUint32 specData[1] =
1702 	{
1703 		invocationStride,
1704 	};
1705 	const vk::VkSpecializationMapEntry entries[1] =
1706 	{
1707 		{0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
1708 	};
1709 	const vk::VkSpecializationInfo specInfo =
1710 	{
1711 		1,						// mapEntryCount
1712 		entries,				// pMapEntries
1713 		sizeof(specData),		// dataSize
1714 		specData				// pData
1715 	};
1716 
1717 	const VkPushConstantRange				pushConstantRange				=
1718 	{
1719 		allShaderStages,											// VkShaderStageFlags					stageFlags;
1720 		0u,															// deUint32								offset;
1721 		sizeof(deInt32)												// deUint32								size;
1722 	};
1723 
1724 	const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo =
1725 	{
1726 		VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,				// sType
1727 		DE_NULL,													// pNext
1728 		(VkPipelineLayoutCreateFlags)0,
1729 		1,															// setLayoutCount
1730 		&descriptorSetLayout.get(),									// pSetLayouts
1731 		1u,															// pushConstantRangeCount
1732 		&pushConstantRange,											// pPushConstantRanges
1733 	};
1734 
1735 	Move<VkPipelineLayout> pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutCreateInfo, NULL);
1736 
1737 	VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
1738 
1739 	flushAlloc(vk, device, buffers[0]->getAllocation());
1740 	flushAlloc(vk, device, buffers[1]->getAllocation());
1741 	flushAlloc(vk, device, buffers[2]->getAllocation());
1742 
1743 	const VkBool32 computeFullSubgroups = subgroupProperties.subgroupSize <= 64 &&
1744 										  m_context.getSubgroupSizeControlFeaturesEXT().computeFullSubgroups;
1745 
1746 	const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT subgroupSizeCreateInfo =
1747 	{
1748 		VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT,	// VkStructureType		sType;
1749 		DE_NULL,																		// void*				pNext;
1750 		subgroupProperties.subgroupSize													// uint32_t				requiredSubgroupSize;
1751 	};
1752 
1753 	const void *shaderPNext = computeFullSubgroups ? &subgroupSizeCreateInfo : DE_NULL;
1754 	VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags =
1755 		(VkPipelineShaderStageCreateFlags)(computeFullSubgroups ? VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT : 0);
1756 
1757 	const Unique<VkShaderModule>			shader						(createShaderModule(vk, device, m_context.getBinaryCollection().get("test"), 0));
1758 	const VkPipelineShaderStageCreateInfo	shaderCreateInfo =
1759 	{
1760 		VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1761 		shaderPNext,
1762 		pipelineShaderStageCreateFlags,
1763 		VK_SHADER_STAGE_COMPUTE_BIT,								// stage
1764 		*shader,													// shader
1765 		"main",
1766 		&specInfo,													// pSpecializationInfo
1767 	};
1768 
1769 	const VkComputePipelineCreateInfo		pipelineCreateInfo =
1770 	{
1771 		VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1772 		DE_NULL,
1773 		0u,															// flags
1774 		shaderCreateInfo,											// cs
1775 		*pipelineLayout,											// layout
1776 		(vk::VkPipeline)0,											// basePipelineHandle
1777 		0u,															// basePipelineIndex
1778 	};
1779 	Move<VkPipeline> pipeline = createComputePipeline(vk, device, DE_NULL, &pipelineCreateInfo, NULL);
1780 
1781 	const VkQueue					queue					= m_context.getUniversalQueue();
1782 	Move<VkCommandPool>				cmdPool					= createCommandPool(vk, device, vk::VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT, m_context.getUniversalQueueFamilyIndex());
1783 	Move<VkCommandBuffer>			cmdBuffer				= allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1784 
1785 
1786 	vk::DescriptorSetUpdateBuilder setUpdateBuilder;
1787 	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
1788 		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
1789 	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1790 		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1791 	setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
1792 		VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
1793 	setUpdateBuilder.update(vk, device);
1794 
1795 	// compute "maxLoc", the maximum number of locations written
1796 	beginCommandBuffer(vk, *cmdBuffer, 0u);
1797 
1798 	vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
1799 	vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
1800 
1801 	deInt32 pcinvocationStride = 0;
1802 	vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);
1803 
1804 	vk.cmdDispatch(*cmdBuffer, 1, 1, 1);
1805 
1806 	endCommandBuffer(vk, *cmdBuffer);
1807 
1808 	submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1809 
1810 	invalidateAlloc(vk, device, buffers[1]->getAllocation());
1811 	invalidateAlloc(vk, device, buffers[2]->getAllocation());
1812 
1813 	// Clear any writes to buffer[1] during the counting pass
1814 	deMemset(ptrs[1], 0, invocationStride * sizeof(deUint64));
1815 
1816 	// Take the max over all invocations. Add one (to make sure no additional writes are done) and multiply by
1817 	// the number of invocations
1818 	deUint32 newMaxLoc = 0;
1819 	for (deUint32 id = 0; id < invocationStride; ++id)
1820 		newMaxLoc = de::max(newMaxLoc, ptrs[2][id]);
1821 	newMaxLoc++;
1822 	newMaxLoc *= invocationStride;
1823 
1824 	// If we need more space, reallocate buffers[1]
1825 	if (newMaxLoc > maxLoc)
1826 	{
1827 		maxLoc = newMaxLoc;
1828 		sizes[1] = maxLoc * sizeof(deUint64);
1829 
1830 		if (sizes[1] > properties2.properties.limits.maxStorageBufferRange)
1831 			TCU_THROW(NotSupportedError, "Storage buffer size larger than device limits");
1832 
1833 		try
1834 		{
1835 			buffers[1] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1836 				vk, device, allocator, makeBufferCreateInfo(sizes[1], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT),
1837 				MemoryRequirement::HostVisible | MemoryRequirement::Cached));
1838 		}
1839 		catch(tcu::ResourceError&)
1840 		{
1841 			// Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1842 			return tcu::TestStatus(QP_TEST_RESULT_QUALITY_WARNING, "Failed device memory allocation " + de::toString(sizes[1]) + " bytes");
1843 		}
1844 		bufferDescriptors[1] = makeDescriptorBufferInfo(**buffers[1], 0, sizes[1]);
1845 		ptrs[1] = (deUint32 *)buffers[1]->getAllocation().getHostPtr();
1846 		deMemset(ptrs[1], 0, (size_t)sizes[1]);
1847 
1848 		vk::DescriptorSetUpdateBuilder setUpdateBuilder2;
1849 		setUpdateBuilder2.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1850 			VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1851 		setUpdateBuilder2.update(vk, device);
1852 	}
1853 
1854 	flushAlloc(vk, device, buffers[1]->getAllocation());
1855 
1856 	// run the actual shader
1857 	beginCommandBuffer(vk, *cmdBuffer, 0u);
1858 
1859 	vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
1860 	vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);
1861 
1862 	pcinvocationStride = invocationStride;
1863 	vk.cmdPushConstants(*cmdBuffer, *pipelineLayout, allShaderStages, 0, sizeof(pcinvocationStride), &pcinvocationStride);
1864 
1865 	vk.cmdDispatch(*cmdBuffer, 1, 1, 1);
1866 
1867 	endCommandBuffer(vk, *cmdBuffer);
1868 
1869 	submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1870 
1871 	invalidateAlloc(vk, device, buffers[1]->getAllocation());
1872 
1873 	qpTestResult res = QP_TEST_RESULT_PASS;
1874 
1875 	// Simulate execution on the CPU, and compare against the GPU result
1876 	std::vector<deUint64> ref;
1877 	try
1878 	{
1879 		ref.resize(maxLoc, 0ull);
1880 	}
1881 	catch (const std::bad_alloc&)
1882 	{
1883 		// Allocation size is unpredictable and can be too large for some systems. Don't treat allocation failure as a test failure.
1884 		return tcu::TestStatus(QP_TEST_RESULT_NOT_SUPPORTED, "Failed system memory allocation " + de::toString(maxLoc * sizeof(deUint64)) + " bytes");
1885 	}
1886 
1887 	program.simulate(false, subgroupSize, invocationStride, &ref[0]);
1888 
1889 	const deUint64 *result = (const deUint64 *)ptrs[1];
1890 
1891 	if (m_data.testType == TT_MAXIMAL)
1892 	{
1893 		// With maximal reconvergence, we should expect the output to exactly match
1894 		// the reference.
1895 		for (deUint32 i = 0; i < maxLoc; ++i)
1896 		{
1897 			if (result[i] != ref[i])
1898 			{
1899 				log << tcu::TestLog::Message << "first mismatch at " << i << tcu::TestLog::EndMessage;
1900 				res = QP_TEST_RESULT_FAIL;
1901 				break;
1902 			}
1903 		}
1904 
1905 		if (res != QP_TEST_RESULT_PASS)
1906 		{
1907 			for (deUint32 i = 0; i < maxLoc; ++i)
1908 			{
1909 				// This log can be large and slow, ifdef it out by default
1910 #if 0
1911 				log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (result[i] != ref[i] ? " different" : "") << tcu::TestLog::EndMessage;
1912 #endif
1913 			}
1914 		}
1915 	}
1916 	else
1917 	{
1918 		deUint64 fullMask = subgroupSizeToMask(subgroupSize);
1919 		// For subgroup_uniform_control_flow, we expect any fully converged outputs in the reference
1920 		// to have a corresponding fully converged output in the result. So walk through each lane's
1921 		// results, and for each reference value of fullMask, find a corresponding result value of
1922 		// fullMask where the previous value (OP_STORE) matches. That means these came from the same
1923 		// source location.
1924 		vector<deUint32> firstFail(invocationStride, 0);
1925 		for (deUint32 lane = 0; lane < invocationStride; ++lane)
1926 		{
1927 			deUint32 resLoc = lane + invocationStride, refLoc = lane + invocationStride;
1928 			while (refLoc < maxLoc)
1929 			{
1930 				while (refLoc < maxLoc && ref[refLoc] != fullMask)
1931 					refLoc += invocationStride;
1932 				if (refLoc >= maxLoc)
1933 					break;
1934 
1935 				// For TT_SUCF_ELECT, when the reference result has a full mask, we expect lane 0 to be elected
1936 				// (a value of 2) and all other lanes to be not elected (a value of 1). For TT_SUCF_BALLOT, we
1937 				// expect a full mask. Search until we find the expected result with a matching store value in
1938 				// the previous result.
1939 				deUint64 expectedResult = m_data.isElect() ? ((lane % subgroupSize) == 0 ? 2 : 1)
1940 															 : fullMask;
1941 
1942 				while (resLoc < maxLoc && !(result[resLoc] == expectedResult && result[resLoc-invocationStride] == ref[refLoc-invocationStride]))
1943 					resLoc += invocationStride;
1944 
1945 				// If we didn't find this output in the result, flag it as an error.
1946 				if (resLoc >= maxLoc)
1947 				{
1948 					firstFail[lane] = refLoc;
1949 					log << tcu::TestLog::Message << "lane " << lane << " first mismatch at " << firstFail[lane] << tcu::TestLog::EndMessage;
1950 					res = QP_TEST_RESULT_FAIL;
1951 					break;
1952 				}
1953 				refLoc += invocationStride;
1954 				resLoc += invocationStride;
1955 			}
1956 		}
1957 
1958 		if (res != QP_TEST_RESULT_PASS)
1959 		{
1960 			for (deUint32 i = 0; i < maxLoc; ++i)
1961 			{
1962 				// This log can be large and slow, ifdef it out by default
1963 #if 0
1964 				log << tcu::TestLog::Message << "result " << i << "(" << (i/invocationStride) << ", " << (i%invocationStride) << "): " << tcu::toHex(result[i]) << " ref " << tcu::toHex(ref[i]) << (i == firstFail[i%invocationStride] ? " first fail" : "") << tcu::TestLog::EndMessage;
1965 #endif
1966 			}
1967 		}
1968 	}
1969 
1970 	return tcu::TestStatus(res, qpGetTestResultName(res));
1971 }
1972 
1973 }	// anonymous
1974 
createTests(tcu::TestContext & testCtx,bool createExperimental)1975 tcu::TestCaseGroup*	createTests (tcu::TestContext& testCtx, bool createExperimental)
1976 {
1977 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
1978 			testCtx, "reconvergence", "reconvergence tests"));
1979 
1980 	typedef struct
1981 	{
1982 		deUint32				value;
1983 		const char*				name;
1984 		const char*				description;
1985 	} TestGroupCase;
1986 
1987 	TestGroupCase ttCases[] =
1988 	{
1989 		{ TT_SUCF_ELECT,				"subgroup_uniform_control_flow_elect",	"subgroup_uniform_control_flow_elect"		},
1990 		{ TT_SUCF_BALLOT,				"subgroup_uniform_control_flow_ballot",	"subgroup_uniform_control_flow_ballot"		},
1991 		{ TT_WUCF_ELECT,				"workgroup_uniform_control_flow_elect",	"workgroup_uniform_control_flow_elect"		},
1992 		{ TT_WUCF_BALLOT,				"workgroup_uniform_control_flow_ballot","workgroup_uniform_control_flow_ballot"		},
1993 		{ TT_MAXIMAL,					"maximal",								"maximal"									},
1994 	};
1995 
1996 	for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
1997 	{
1998 		de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, ttCases[ttNdx].name, ttCases[ttNdx].description));
1999 		de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(testCtx, "compute", ""));
2000 
2001 		for (deUint32 nNdx = 2; nNdx <= 6; nNdx++)
2002 		{
2003 			de::MovePtr<tcu::TestCaseGroup> nestGroup(new tcu::TestCaseGroup(testCtx, ("nesting" + de::toString(nNdx)).c_str(), ""));
2004 
2005 			deUint32 seed = 0;
2006 
2007 			for (int sNdx = 0; sNdx < 8; sNdx++)
2008 			{
2009 				de::MovePtr<tcu::TestCaseGroup> seedGroup(new tcu::TestCaseGroup(testCtx, de::toString(sNdx).c_str(), ""));
2010 
2011 				deUint32 numTests = 0;
2012 				switch (nNdx)
2013 				{
2014 				default:
2015 					DE_ASSERT(0);
2016 					// fallthrough
2017 				case 2:
2018 				case 3:
2019 				case 4:
2020 					numTests = 250;
2021 					break;
2022 				case 5:
2023 					numTests = 100;
2024 					break;
2025 				case 6:
2026 					numTests = 50;
2027 					break;
2028 				}
2029 
2030 				if (ttCases[ttNdx].value != TT_MAXIMAL)
2031 				{
2032 					if (nNdx >= 5)
2033 						continue;
2034 				}
2035 
2036 				for (deUint32 ndx = 0; ndx < numTests; ndx++)
2037 				{
2038 					CaseDef c =
2039 					{
2040 						(TestType)ttCases[ttNdx].value,		// TestType testType;
2041 						nNdx,								// deUint32 maxNesting;
2042 						seed,								// deUint32 seed;
2043 					};
2044 					seed++;
2045 
2046 					bool isExperimentalTest = !c.isUCF() || (ndx >= numTests / 5);
2047 
2048 					if (createExperimental == isExperimentalTest)
2049 						seedGroup->addChild(new ReconvergenceTestCase(testCtx, de::toString(ndx).c_str(), "", c));
2050 				}
2051 				if (!seedGroup->empty())
2052 					nestGroup->addChild(seedGroup.release());
2053 			}
2054 			if (!nestGroup->empty())
2055 				computeGroup->addChild(nestGroup.release());
2056 		}
2057 		if (!computeGroup->empty())
2058 		{
2059 			ttGroup->addChild(computeGroup.release());
2060 			group->addChild(ttGroup.release());
2061 		}
2062 	}
2063 	return group.release();
2064 }
2065 
2066 }	// Reconvergence
2067 }	// vkt
2068