/*-------------------------------------------------------------------------
 * Vulkan Conformance Tests
 * ------------------------
 *
 * Copyright (c) 2020 Valve Corporation.
 * Copyright (c) 2020 The Khronos Group Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief SPIR-V tests for VK_AMD_shader_trinary_minmax.
 *//*--------------------------------------------------------------------*/
24 
25 #include "vktSpvAsmTrinaryMinMaxTests.hpp"
26 #include "vktTestCase.hpp"
27 
28 #include "vkQueryUtil.hpp"
29 #include "vkObjUtil.hpp"
30 #include "vkBufferWithMemory.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkCmdUtil.hpp"
35 
36 #include "tcuStringTemplate.hpp"
37 #include "tcuFloat.hpp"
38 #include "tcuMaybe.hpp"
39 
40 #include "deStringUtil.hpp"
41 #include "deRandom.hpp"
42 #include "deMemory.h"
43 
44 #include <string>
45 #include <sstream>
46 #include <map>
47 #include <vector>
48 #include <algorithm>
49 #include <array>
50 #include <memory>
51 
52 namespace vkt
53 {
54 namespace SpirVAssembly
55 {
56 
57 namespace
58 {
59 
// Trinary operation under test.
// The numeric values are used as zero-based indices when the SPIR-V instruction name is built in
// TrinaryMinMaxCase::getSpirVReplacements, so they must stay dense and start at zero.
enum class OperationType
{
	MIN = 0,	// Smallest of the three operands.
	MAX = 1,	// Largest of the three operands.
	MID = 2,	// Middle value of the three operands.
};
66 
// Fundamental kind of the operand components.
// The numeric values are used as zero-based indices when the SPIR-V instruction name prefix is chosen in
// TrinaryMinMaxCase::getSpirVReplacements, so they must stay dense and start at zero.
enum class BaseType
{
	TYPE_INT	= 0,	// Signed integer.
	TYPE_UINT	= 1,	// Unsigned integer.
	TYPE_FLOAT	= 2,	// Floating point.
};
73 
// Width of a single component. Each enumerator's value is that width expressed in bytes, which lets
// TestParams::componentSize() obtain the byte size with a plain cast.
enum class TypeSize
{
	SIZE_8BIT	= 1,
	SIZE_16BIT	= 2,
	SIZE_32BIT	= 4,
	SIZE_64BIT	= 8,
};
82 
// Scalar or vector shape of an operand. Each enumerator's value is the number of components, which lets
// TestParams::numComponents() obtain the count with a plain cast.
enum class AggregationType
{
	SCALAR	= 1,
	VEC2	= 2,
	VEC3	= 3,
	VEC4	= 4,
};
91 
92 struct TestParams
93 {
94 	OperationType	operation;
95 	BaseType		baseType;
96 	TypeSize		typeSize;
97 	AggregationType	aggregation;
98 	deUint32		randomSeed;
99 
100 	deUint32		operandSize			() const;	// In bytes.
101 	deUint32		numComponents		() const;	// Number of components.
102 	deUint32		effectiveComponents	() const;	// Effective number of components for size calculation.
103 	deUint32		componentSize		() const;	// In bytes.
104 };
105 
operandSize() const106 deUint32 TestParams::operandSize () const
107 {
108 	return (effectiveComponents() * componentSize());
109 }
110 
numComponents() const111 deUint32 TestParams::numComponents () const
112 {
113 	return static_cast<deUint32>(aggregation);
114 }
115 
effectiveComponents() const116 deUint32 TestParams::effectiveComponents () const
117 {
118 	return static_cast<deUint32>((aggregation == AggregationType::VEC3) ? AggregationType::VEC4 : aggregation);
119 }
120 
componentSize() const121 deUint32 TestParams::componentSize () const
122 {
123 	return static_cast<deUint32>(typeSize);
124 }
125 
// Returns the smallest of three values according to operator<.
// When several values are equivalent, the leftmost one is returned.
template <class T>
T min3(T op1, T op2, T op3)
{
	T lowest = op1;

	if (op2 < lowest)
		lowest = op2;
	if (op3 < lowest)
		lowest = op3;

	return lowest;
}
131 
// Returns the largest of three values according to operator<.
// When several values are equivalent, the leftmost one is returned.
template <class T>
T max3(T op1, T op2, T op3)
{
	T highest = op1;

	if (highest < op2)
		highest = op2;
	if (highest < op3)
		highest = op3;

	return highest;
}
137 
// Returns the middle value of three: the operands are sorted with operator< and the second element
// of the sorted sequence is returned.
template <class T>
T mid3(T op1, T op2, T op3)
{
	T sorted[3] = { op1, op2, op3 };
	std::sort(sorted, sorted + 3);
	return sorted[1];
}
145 
146 class OperationManager
147 {
148 public:
149 	// Operation and component index in case of error.
150 	using OperationComponent	= std::pair<deUint32, deUint32>;
151 	using ComparisonError		= tcu::Maybe<OperationComponent>;
152 
153 					OperationManager	(const TestParams& params);
154 	void			genInputBuffer		(void* bufferPtr, deUint32 numOperations);
155 	void			calculateResult		(void* referenceBuffer, void* inputBuffer, deUint32 numOperations);
156 	ComparisonError	compareResults		(void* referenceBuffer, void* resultsBuffer, deUint32 numOperations);
157 
158 private:
159 	using GenerateCompFunc = void (*)(de::Random&, void*); // Write a generated component to the given location.
160 
161 	// Generator variants to populate input buffer.
genInt8(de::Random & rnd,void * ptr)162 	static void genInt8		(de::Random& rnd, void* ptr) { *reinterpret_cast<deInt8*>(ptr) = static_cast<deInt8>(rnd.getUint8()); }
genUint8(de::Random & rnd,void * ptr)163 	static void genUint8	(de::Random& rnd, void* ptr) { *reinterpret_cast<deUint8*>(ptr) = rnd.getUint8(); }
genInt16(de::Random & rnd,void * ptr)164 	static void genInt16	(de::Random& rnd, void* ptr) { *reinterpret_cast<deInt16*>(ptr) = static_cast<deInt16>(rnd.getUint16()); }
genUint16(de::Random & rnd,void * ptr)165 	static void genUint16	(de::Random& rnd, void* ptr) { *reinterpret_cast<deUint16*>(ptr) = rnd.getUint16(); }
genInt32(de::Random & rnd,void * ptr)166 	static void genInt32	(de::Random& rnd, void* ptr) { *reinterpret_cast<deInt32*>(ptr) = static_cast<deInt32>(rnd.getUint32()); }
genUint32(de::Random & rnd,void * ptr)167 	static void genUint32	(de::Random& rnd, void* ptr) { *reinterpret_cast<deUint32*>(ptr) = rnd.getUint32(); }
genInt64(de::Random & rnd,void * ptr)168 	static void genInt64	(de::Random& rnd, void* ptr) { *reinterpret_cast<deInt64*>(ptr) = static_cast<deInt64>(rnd.getUint64()); }
genUint64(de::Random & rnd,void * ptr)169 	static void genUint64	(de::Random& rnd, void* ptr) { *reinterpret_cast<deUint64*>(ptr) = rnd.getUint64(); }
170 
171 	// Helper template for float generators.
172 	// T must be a tcu::Float instantiation.
173 	// Attempts to generate +-Inf once every 10 times and avoid denormals.
174 	template <class T>
genFloat(de::Random & rnd,void * ptr)175 	static inline void genFloat (de::Random& rnd, void *ptr)
176 	{
177 		T* valuePtr = reinterpret_cast<T*>(ptr);
178 		if (rnd.getInt(1, 10) == 1)
179 			*valuePtr = T::inf(rnd.getBool() ? 1 : -1);
180 		else {
181 			do {
182 				*valuePtr = T{rnd.getDouble(T::largestNormal(-1).asDouble(), T::largestNormal(1).asDouble())};
183 			} while (valuePtr->isDenorm());
184 		}
185 	}
186 
genFloat16(de::Random & rnd,void * ptr)187 	static void genFloat16	(de::Random& rnd, void* ptr) { genFloat<tcu::Float16>(rnd, ptr); }
genFloat32(de::Random & rnd,void * ptr)188 	static void genFloat32	(de::Random& rnd, void* ptr) { genFloat<tcu::Float32>(rnd, ptr); }
genFloat64(de::Random & rnd,void * ptr)189 	static void genFloat64	(de::Random& rnd, void* ptr) { genFloat<tcu::Float64>(rnd, ptr); }
190 
191 	// An operation function writes an output value given 3 input values.
192 	using OperationFunc = void (*)(void*, const void*, const void*, const void*);
193 
194 	// Helper template used below.
195 	template <class T, class F>
runOpFunc(F f,void * out,const void * in1,const void * in2,const void * in3)196 	static inline void runOpFunc (F f, void* out, const void* in1, const void* in2, const void* in3)
197 	{
198 		*reinterpret_cast<T*>(out) = f(*reinterpret_cast<const T*>(in1), *reinterpret_cast<const T*>(in2), *reinterpret_cast<const T*>(in3));
199 	}
200 
201 	// Apply an operation in software to a given group of components and calculate result.
minInt8(void * out,const void * in1,const void * in2,const void * in3)202 	static void minInt8		(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt8>		(min3<deInt8>,			out, in1, in2, in3); }
maxInt8(void * out,const void * in1,const void * in2,const void * in3)203 	static void maxInt8		(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt8>		(max3<deInt8>,			out, in1, in2, in3); }
midInt8(void * out,const void * in1,const void * in2,const void * in3)204 	static void midInt8		(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt8>		(mid3<deInt8>,			out, in1, in2, in3); }
minUint8(void * out,const void * in1,const void * in2,const void * in3)205 	static void minUint8	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint8>		(min3<deUint8>,			out, in1, in2, in3); }
maxUint8(void * out,const void * in1,const void * in2,const void * in3)206 	static void maxUint8	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint8>		(max3<deUint8>,			out, in1, in2, in3); }
midUint8(void * out,const void * in1,const void * in2,const void * in3)207 	static void midUint8	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint8>		(mid3<deUint8>,			out, in1, in2, in3); }
minInt16(void * out,const void * in1,const void * in2,const void * in3)208 	static void minInt16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt16>		(min3<deInt16>,			out, in1, in2, in3); }
maxInt16(void * out,const void * in1,const void * in2,const void * in3)209 	static void maxInt16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt16>		(max3<deInt16>,			out, in1, in2, in3); }
midInt16(void * out,const void * in1,const void * in2,const void * in3)210 	static void midInt16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt16>		(mid3<deInt16>,			out, in1, in2, in3); }
minUint16(void * out,const void * in1,const void * in2,const void * in3)211 	static void minUint16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint16>	(min3<deUint16>,		out, in1, in2, in3); }
maxUint16(void * out,const void * in1,const void * in2,const void * in3)212 	static void maxUint16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint16>	(max3<deUint16>,		out, in1, in2, in3); }
midUint16(void * out,const void * in1,const void * in2,const void * in3)213 	static void midUint16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint16>	(mid3<deUint16>,		out, in1, in2, in3); }
minInt32(void * out,const void * in1,const void * in2,const void * in3)214 	static void minInt32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt32>		(min3<deInt32>,			out, in1, in2, in3); }
maxInt32(void * out,const void * in1,const void * in2,const void * in3)215 	static void maxInt32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt32>		(max3<deInt32>,			out, in1, in2, in3); }
midInt32(void * out,const void * in1,const void * in2,const void * in3)216 	static void midInt32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt32>		(mid3<deInt32>,			out, in1, in2, in3); }
minUint32(void * out,const void * in1,const void * in2,const void * in3)217 	static void minUint32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint32>	(min3<deUint32>,		out, in1, in2, in3); }
maxUint32(void * out,const void * in1,const void * in2,const void * in3)218 	static void maxUint32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint32>	(max3<deUint32>,		out, in1, in2, in3); }
midUint32(void * out,const void * in1,const void * in2,const void * in3)219 	static void midUint32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint32>	(mid3<deUint32>,		out, in1, in2, in3); }
minInt64(void * out,const void * in1,const void * in2,const void * in3)220 	static void minInt64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt64>		(min3<deInt64>,			out, in1, in2, in3); }
maxInt64(void * out,const void * in1,const void * in2,const void * in3)221 	static void maxInt64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt64>		(max3<deInt64>,			out, in1, in2, in3); }
midInt64(void * out,const void * in1,const void * in2,const void * in3)222 	static void midInt64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt64>		(mid3<deInt64>,			out, in1, in2, in3); }
minUint64(void * out,const void * in1,const void * in2,const void * in3)223 	static void minUint64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint64>	(min3<deUint64>,		out, in1, in2, in3); }
maxUint64(void * out,const void * in1,const void * in2,const void * in3)224 	static void maxUint64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint64>	(max3<deUint64>,		out, in1, in2, in3); }
midUint64(void * out,const void * in1,const void * in2,const void * in3)225 	static void midUint64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint64>	(mid3<deUint64>,		out, in1, in2, in3); }
minFloat16(void * out,const void * in1,const void * in2,const void * in3)226 	static void minFloat16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float16>(min3<tcu::Float16>,	out, in1, in2, in3); }
maxFloat16(void * out,const void * in1,const void * in2,const void * in3)227 	static void maxFloat16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float16>(max3<tcu::Float16>,	out, in1, in2, in3); }
midFloat16(void * out,const void * in1,const void * in2,const void * in3)228 	static void midFloat16	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float16>(mid3<tcu::Float16>,	out, in1, in2, in3); }
minFloat32(void * out,const void * in1,const void * in2,const void * in3)229 	static void minFloat32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float32>(min3<tcu::Float32>,	out, in1, in2, in3); }
maxFloat32(void * out,const void * in1,const void * in2,const void * in3)230 	static void maxFloat32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float32>(max3<tcu::Float32>,	out, in1, in2, in3); }
midFloat32(void * out,const void * in1,const void * in2,const void * in3)231 	static void midFloat32	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float32>(mid3<tcu::Float32>,	out, in1, in2, in3); }
minFloat64(void * out,const void * in1,const void * in2,const void * in3)232 	static void minFloat64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float64>(min3<tcu::Float64>,	out, in1, in2, in3); }
maxFloat64(void * out,const void * in1,const void * in2,const void * in3)233 	static void maxFloat64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float64>(max3<tcu::Float64>,	out, in1, in2, in3); }
midFloat64(void * out,const void * in1,const void * in2,const void * in3)234 	static void midFloat64	(void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float64>(mid3<tcu::Float64>,	out, in1, in2, in3); }
235 
236 	// Case for accessing the functions map.
237 	struct Case
238 	{
239 		BaseType		type;
240 		TypeSize		size;
241 		OperationType	operation;
242 
243 		// This is required for sorting in the map.
operator <vkt::SpirVAssembly::__anonf00c35b40111::OperationManager::Case244 		bool operator< (const Case& other) const
245 		{
246 			return (toArray() < other.toArray());
247 		}
248 
249 	private:
toArrayvkt::SpirVAssembly::__anonf00c35b40111::OperationManager::Case250 		std::array<int, 3> toArray () const
251 		{
252 			return std::array<int, 3>{{static_cast<int>(type), static_cast<int>(size), static_cast<int>(operation)}};
253 		}
254 	};
255 
256 	// Helper map to correctly choose the right generator and operation function for the specific case being tested.
257 	using FuncPair	= std::pair<GenerateCompFunc, OperationFunc>;
258 	using CaseMap	= std::map<Case, FuncPair>;
259 
260 	static const CaseMap	kFunctionsMap;
261 
262 	GenerateCompFunc		m_chosenGenerator;
263 	OperationFunc			m_chosenOperation;
264 	de::Random				m_random;
265 
266 	const deUint32			m_operandSize;
267 	const deUint32			m_numComponents;
268 	const deUint32			m_componentSize;
269 };
270 
271 // This map is used to choose how to generate inputs for each case and which operation to run on the CPU to calculate the reference
272 // results for the generated inputs.
273 const OperationManager::CaseMap OperationManager::kFunctionsMap =
274 {
275 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_8BIT,	OperationType::MIN }, { genInt8,	minInt8		} },
276 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_8BIT,	OperationType::MAX }, { genInt8,	maxInt8		} },
277 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_8BIT,	OperationType::MID }, { genInt8,	midInt8		} },
278 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_16BIT,	OperationType::MIN }, { genInt16,	minInt16	} },
279 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_16BIT,	OperationType::MAX }, { genInt16,	maxInt16	} },
280 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_16BIT,	OperationType::MID }, { genInt16,	midInt16	} },
281 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_32BIT,	OperationType::MIN }, { genInt32,	minInt32	} },
282 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_32BIT,	OperationType::MAX }, { genInt32,	maxInt32	} },
283 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_32BIT,	OperationType::MID }, { genInt32,	midInt32	} },
284 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_64BIT,	OperationType::MIN }, { genInt64,	minInt64	} },
285 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_64BIT,	OperationType::MAX }, { genInt64,	maxInt64	} },
286 	{ { BaseType::TYPE_INT,		TypeSize::SIZE_64BIT,	OperationType::MID }, { genInt64,	midInt64	} },
287 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_8BIT,	OperationType::MIN }, { genUint8,	minUint8	} },
288 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_8BIT,	OperationType::MAX }, { genUint8,	maxUint8	} },
289 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_8BIT,	OperationType::MID }, { genUint8,	midUint8	} },
290 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_16BIT,	OperationType::MIN }, { genUint16,	minUint16	} },
291 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_16BIT,	OperationType::MAX }, { genUint16,	maxUint16	} },
292 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_16BIT,	OperationType::MID }, { genUint16,	midUint16	} },
293 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_32BIT,	OperationType::MIN }, { genUint32,	minUint32	} },
294 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_32BIT,	OperationType::MAX }, { genUint32,	maxUint32	} },
295 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_32BIT,	OperationType::MID }, { genUint32,	midUint32	} },
296 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_64BIT,	OperationType::MIN }, { genUint64,	minUint64	} },
297 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_64BIT,	OperationType::MAX }, { genUint64,	maxUint64	} },
298 	{ { BaseType::TYPE_UINT,	TypeSize::SIZE_64BIT,	OperationType::MID }, { genUint64,	midUint64	} },
299 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_16BIT,	OperationType::MIN }, { genFloat16,	minFloat16	} },
300 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_16BIT,	OperationType::MAX }, { genFloat16,	maxFloat16	} },
301 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_16BIT,	OperationType::MID }, { genFloat16,	midFloat16	} },
302 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_32BIT,	OperationType::MIN }, { genFloat32,	minFloat32	} },
303 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_32BIT,	OperationType::MAX }, { genFloat32,	maxFloat32	} },
304 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_32BIT,	OperationType::MID }, { genFloat32,	midFloat32	} },
305 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_64BIT,	OperationType::MIN }, { genFloat64,	minFloat64	} },
306 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_64BIT,	OperationType::MAX }, { genFloat64,	maxFloat64	} },
307 	{ { BaseType::TYPE_FLOAT,	TypeSize::SIZE_64BIT,	OperationType::MID }, { genFloat64,	midFloat64	} },
308 };
309 
OperationManager(const TestParams & params)310 OperationManager::OperationManager (const TestParams& params)
311 	: m_chosenGenerator	{nullptr}
312 	, m_chosenOperation	{nullptr}
313 	, m_random			{params.randomSeed}
314 	, m_operandSize		{params.operandSize()}
315 	, m_numComponents	{params.numComponents()}
316 	, m_componentSize	{params.componentSize()}
317 {
318 	// Choose generator and CPU operation from the map.
319 	const Case paramCase{params.baseType, params.typeSize, params.operation};
320 	const auto iter = kFunctionsMap.find(paramCase);
321 
322 	DE_ASSERT(iter != kFunctionsMap.end());
323 	m_chosenGenerator = iter->second.first;
324 	m_chosenOperation = iter->second.second;
325 }
326 
327 // See TrinaryMinMaxCase::initPrograms for a description of the input buffer format.
328 // Generates inputs with the chosen generator.
genInputBuffer(void * bufferPtr,deUint32 numOperations)329 void OperationManager::genInputBuffer (void* bufferPtr, deUint32 numOperations)
330 {
331 	const deUint32	numOperands	= numOperations * 3u;
332 	char*			byteBuffer	= reinterpret_cast<char*>(bufferPtr);
333 
334 	for (deUint32 opIdx = 0u; opIdx < numOperands; ++opIdx)
335 	{
336 		char* compPtr = byteBuffer;
337 		for (deUint32 compIdx = 0u; compIdx < m_numComponents; ++compIdx)
338 		{
339 			m_chosenGenerator(m_random, reinterpret_cast<void*>(compPtr));
340 			compPtr += m_componentSize;
341 		}
342 		byteBuffer += m_operandSize;
343 	}
344 }
345 
346 // See TrinaryMinMaxCase::initPrograms for a description of the input and output buffer formats.
347 // Calculates reference results on the CPU using the chosen operation and the input buffer.
calculateResult(void * referenceBuffer,void * inputBuffer,deUint32 numOperations)348 void OperationManager::calculateResult (void* referenceBuffer, void* inputBuffer, deUint32 numOperations)
349 {
350 	char* outputByte	= reinterpret_cast<char*>(referenceBuffer);
351 	char* inputByte		= reinterpret_cast<char*>(inputBuffer);
352 
353 	for (deUint32 opIdx = 0u; opIdx < numOperations; ++opIdx)
354 	{
355 		char* res = outputByte;
356 		char* op1 = inputByte;
357 		char* op2 = inputByte + m_operandSize;
358 		char* op3 = inputByte + m_operandSize * 2u;
359 
360 		for (deUint32 compIdx = 0u; compIdx < m_numComponents; ++compIdx)
361 		{
362 			m_chosenOperation(
363 				reinterpret_cast<void*>(res),
364 				reinterpret_cast<void*>(op1),
365 				reinterpret_cast<void*>(op2),
366 				reinterpret_cast<void*>(op3));
367 
368 			res += m_componentSize;
369 			op1 += m_componentSize;
370 			op2 += m_componentSize;
371 			op3 += m_componentSize;
372 		}
373 
374 		outputByte	+= m_operandSize;
375 		inputByte	+= m_operandSize * 3u;
376 	}
377 }
378 
379 // See TrinaryMinMaxCase::initPrograms for a description of the output buffer format.
compareResults(void * referenceBuffer,void * resultsBuffer,deUint32 numOperations)380 OperationManager::ComparisonError OperationManager::compareResults (void* referenceBuffer, void* resultsBuffer, deUint32 numOperations)
381 {
382 	char* referenceBytes	= reinterpret_cast<char*>(referenceBuffer);
383 	char* resultsBytes		= reinterpret_cast<char*>(resultsBuffer);
384 
385 	for (deUint32 opIdx = 0u; opIdx < numOperations; ++opIdx)
386 	{
387 		char *refCompBytes = referenceBytes;
388 		char *resCompBytes = resultsBytes;
389 
390 		for (deUint32 compIdx = 0u; compIdx < m_numComponents; ++compIdx)
391 		{
392 			if (deMemCmp(refCompBytes, resCompBytes, m_componentSize) != 0)
393 				return tcu::just(OperationComponent(opIdx, compIdx));
394 			refCompBytes += m_componentSize;
395 			resCompBytes += m_componentSize;
396 		}
397 		referenceBytes += m_operandSize;
398 		resultsBytes += m_operandSize;
399 	}
400 
401 	return tcu::Nothing;
402 }
403 
404 class TrinaryMinMaxCase : public vkt::TestCase
405 {
406 public:
407 	using ReplacementsMap = std::map<std::string, std::string>;
408 
409 							TrinaryMinMaxCase		(tcu::TestContext& testCtx, const std::string& name, const std::string& description, const TestParams& params);
~TrinaryMinMaxCase(void)410 	virtual					~TrinaryMinMaxCase		(void) {}
411 
412 	virtual void			initPrograms			(vk::SourceCollections& programCollection) const;
413 	virtual TestInstance*	createInstance			(Context& context) const;
414 	virtual void			checkSupport			(Context& context) const;
415 	ReplacementsMap			getSpirVReplacements	(void) const;
416 
417 	static const deUint32	kArraySize;
418 private:
419 	TestParams				m_params;
420 };
421 
422 const deUint32 TrinaryMinMaxCase::kArraySize = 100u;
423 
424 class TrinaryMinMaxInstance : public vkt::TestInstance
425 {
426 public:
427 								TrinaryMinMaxInstance	(Context& context, const TestParams& params);
~TrinaryMinMaxInstance(void)428 	virtual						~TrinaryMinMaxInstance	(void) {}
429 
430 	virtual tcu::TestStatus		iterate					(void);
431 
432 private:
433 	TestParams	m_params;
434 };
435 
TrinaryMinMaxCase(tcu::TestContext & testCtx,const std::string & name,const std::string & description,const TestParams & params)436 TrinaryMinMaxCase::TrinaryMinMaxCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const TestParams& params)
437 	: vkt::TestCase	(testCtx, name, description)
438 	, m_params		(params)
439 {}
440 
createInstance(Context & context) const441 TestInstance* TrinaryMinMaxCase::createInstance (Context& context) const
442 {
443 	return new TrinaryMinMaxInstance{context, m_params};
444 }
445 
checkSupport(Context & context) const446 void TrinaryMinMaxCase::checkSupport (Context& context) const
447 {
448 	// These are always required.
449 	context.requireInstanceFunctionality("VK_KHR_get_physical_device_properties2");
450 	context.requireDeviceFunctionality("VK_KHR_storage_buffer_storage_class");
451 	context.requireDeviceFunctionality("VK_AMD_shader_trinary_minmax");
452 
453 	const auto devFeatures			= context.getDeviceFeatures();
454 	const auto storage16BitFeatures	= context.get16BitStorageFeatures();
455 	const auto storage8BitFeatures	= context.get8BitStorageFeatures();
456 	const auto shaderFeatures		= context.getShaderFloat16Int8Features();
457 
458 	// Storage features.
459 	if (m_params.typeSize == TypeSize::SIZE_8BIT)
460 	{
461 		// We will be using 8-bit types in storage buffers.
462 		context.requireDeviceFunctionality("VK_KHR_8bit_storage");
463 		if (!storage8BitFeatures.storageBuffer8BitAccess)
464 			TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
465 	}
466 	else if (m_params.typeSize == TypeSize::SIZE_16BIT)
467 	{
468 		// We will be using 16-bit types in storage buffers.
469 		context.requireDeviceFunctionality("VK_KHR_16bit_storage");
470 		if (!storage16BitFeatures.storageBuffer16BitAccess)
471 			TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
472 	}
473 
474 	// Shader type features.
475 	if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
476 	{
477 		if (m_params.typeSize == TypeSize::SIZE_8BIT && !shaderFeatures.shaderInt8)
478 			TCU_THROW(NotSupportedError, "8-bit integers not supported in shaders");
479 		else if (m_params.typeSize == TypeSize::SIZE_16BIT && !devFeatures.shaderInt16)
480 			TCU_THROW(NotSupportedError, "16-bit integers not supported in shaders");
481 		else if (m_params.typeSize == TypeSize::SIZE_64BIT && !devFeatures.shaderInt64)
482 			TCU_THROW(NotSupportedError, "64-bit integers not supported in shaders");
483 	}
484 	else // BaseType::TYPE_FLOAT
485 	{
486 		DE_ASSERT(m_params.typeSize != TypeSize::SIZE_8BIT);
487 		if (m_params.typeSize == TypeSize::SIZE_16BIT && !shaderFeatures.shaderFloat16)
488 			TCU_THROW(NotSupportedError, "16-bit floats not supported in shaders");
489 		else if (m_params.typeSize == TypeSize::SIZE_64BIT && !devFeatures.shaderFloat64)
490 			TCU_THROW(NotSupportedError, "64-bit floats not supported in shaders");
491 	}
492 }
493 
getSpirVReplacements(void) const494 TrinaryMinMaxCase::ReplacementsMap TrinaryMinMaxCase::getSpirVReplacements (void) const
495 {
496 	ReplacementsMap replacements;
497 
498 	// Capabilities and extensions.
499 	if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
500 	{
501 		if (m_params.typeSize == TypeSize::SIZE_8BIT)
502 			replacements["CAPABILITIES"]	+= "OpCapability Int8\n";
503 		else if (m_params.typeSize == TypeSize::SIZE_16BIT)
504 			replacements["CAPABILITIES"]	+= "OpCapability Int16\n";
505 		else if (m_params.typeSize == TypeSize::SIZE_64BIT)
506 			replacements["CAPABILITIES"]	+= "OpCapability Int64\n";
507 	}
508 	else // BaseType::TYPE_FLOAT
509 	{
510 		if (m_params.typeSize == TypeSize::SIZE_16BIT)
511 			replacements["CAPABILITIES"]	+= "OpCapability Float16\n";
512 		else if (m_params.typeSize == TypeSize::SIZE_64BIT)
513 			replacements["CAPABILITIES"]	+= "OpCapability Float64\n";
514 	}
515 
516 	if (m_params.typeSize == TypeSize::SIZE_8BIT)
517 	{
518 		replacements["CAPABILITIES"]		+= "OpCapability StorageBuffer8BitAccess\n";
519 		replacements["EXTENSIONS"]			+= "OpExtension \"SPV_KHR_8bit_storage\"\n";
520 	}
521 	else if (m_params.typeSize == TypeSize::SIZE_16BIT)
522 	{
523 		replacements["CAPABILITIES"]		+= "OpCapability StorageBuffer16BitAccess\n";
524 		replacements["EXTENSIONS"]			+= "OpExtension \"SPV_KHR_16bit_storage\"\n";
525 	}
526 
527 	// Operand size in bytes.
528 	const deUint32 opSize				= m_params.operandSize();
529 	replacements["OPERAND_SIZE"]		= de::toString(opSize);
530 	replacements["OPERAND_SIZE_2TIMES"]	= de::toString(opSize * 2u);
531 	replacements["OPERAND_SIZE_3TIMES"]	= de::toString(opSize * 3u);
532 
533 	// Array size.
534 	replacements["ARRAY_SIZE"]			= de::toString(kArraySize);
535 
536 	// Types and operand type: define the base integer or float type and the vector type if needed, then set the operand type replacement.
537 	const std::string vecSize	= de::toString(m_params.numComponents());
538 	const std::string bitSize	= de::toString(m_params.componentSize() * 8u);
539 
540 	if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
541 	{
542 		const std::string	signBit		= (m_params.baseType == BaseType::TYPE_INT ? "1" : "0");
543 		const std::string	typePrefix	= (m_params.baseType == BaseType::TYPE_UINT ? "u" : "");
544 		std::string			baseTypeName;
545 
546 		// 32-bit integers are already defined in the default shader text.
547 		if (m_params.typeSize != TypeSize::SIZE_32BIT)
548 		{
549 			baseTypeName = typePrefix + "int" + bitSize + "_t";
550 			replacements["TYPES"] += "%" + baseTypeName + " = OpTypeInt " + bitSize + " " + signBit + "\n";
551 		}
552 		else
553 		{
554 			baseTypeName = typePrefix + "int";
555 		}
556 
557 		if (m_params.aggregation == AggregationType::SCALAR)
558 		{
559 			replacements["OPERAND_TYPE"] = "%" + baseTypeName;
560 		}
561 		else
562 		{
563 			const std::string typeName = "%v" + vecSize + baseTypeName;
564 			// %v3uint is already defined in the default shader text.
565 			if (m_params.baseType != BaseType::TYPE_UINT || m_params.typeSize != TypeSize::SIZE_32BIT || m_params.aggregation != AggregationType::VEC3)
566 			{
567 				replacements["TYPES"] += typeName + " = OpTypeVector %" + baseTypeName + " " + vecSize + "\n";
568 			}
569 			replacements["OPERAND_TYPE"] = typeName;
570 		}
571 	}
572 	else // BaseType::TYPE_FLOAT
573 	{
574 		const std::string baseTypeName = "float" + bitSize + "_t";
575 		replacements["TYPES"] += "%" + baseTypeName + " = OpTypeFloat " + bitSize + "\n";
576 
577 		if (m_params.aggregation == AggregationType::SCALAR)
578 		{
579 			replacements["OPERAND_TYPE"] = "%" + baseTypeName;
580 		}
581 		else
582 		{
583 			const std::string typeName = "%v" + vecSize + baseTypeName;
584 			replacements["TYPES"] += typeName + " = OpTypeVector %" + baseTypeName + " " + vecSize + "\n";
585 			replacements["OPERAND_TYPE"] = typeName;
586 		}
587 	}
588 
589 	// Operation name.
590 	const static std::vector<std::string> opTypeStr	= { "Min", "Max", "Mid" };
591 	const static std::vector<std::string> opPrefix	= { "S", "U", "F" };
592 	replacements["OPERATION_NAME"] = opPrefix[static_cast<int>(m_params.baseType)] + opTypeStr[static_cast<int>(m_params.operation)] + "3AMD";
593 
594 	return replacements;
595 }
596 
initPrograms(vk::SourceCollections & programCollection) const597 void TrinaryMinMaxCase::initPrograms (vk::SourceCollections& programCollection) const
598 {
599 	// The shader below uses an input buffer at set 0 binding 0 and an output buffer at set 0 binding 1. Their structure is similar
600 	// to the code below:
601 	//
602 	//      struct Operands {
603 	//              <type> op1;
604 	//              <type> op2;
605 	//              <type> op3;
606 	//      };
607 	//
608 	//      layout (set=0, binding=0, std430) buffer InputBlock {
609 	//              Operands operands[<arraysize>];
610 	//      };
611 	//
612 	//      layout (set=0, binding=1, std430) buffer OutputBlock {
613 	//              <type> result[<arraysize>];
614 	//      };
615 	//
616 	// Where <type> can be int8_t, uint32_t, float, etc. So in the input buffer the operands are "grouped" per operation and can
617 	// have several components each and the output buffer contains an array of results, one per trio of input operands.
618 
619 	std::ostringstream shaderStr;
620 	shaderStr
621 		<< "; SPIR-V\n"
622 		<< "; Version: 1.5\n"
623 		<< "                            OpCapability Shader\n"
624 		<< "${CAPABILITIES:opt}"
625 		<< "                            OpExtension \"SPV_KHR_storage_buffer_storage_class\"\n"
626 		<< "                            OpExtension \"SPV_AMD_shader_trinary_minmax\"\n"
627 		<< "${EXTENSIONS:opt}"
628 		<< "                  %std450 = OpExtInstImport \"GLSL.std.450\"\n"
629 		<< "                 %trinary = OpExtInstImport \"SPV_AMD_shader_trinary_minmax\"\n"
630 		<< "                            OpMemoryModel Logical GLSL450\n"
631 		<< "                            OpEntryPoint GLCompute %main \"main\" %gl_GlobalInvocationID %output_buffer %input_buffer\n"
632 		<< "                            OpExecutionMode %main LocalSize 1 1 1\n"
633 		<< "                            OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId\n"
634 		<< "                            OpDecorate %results_array_t ArrayStride ${OPERAND_SIZE}\n"
635 		<< "                            OpMemberDecorate %OutputBlock 0 Offset 0\n"
636 		<< "                            OpDecorate %OutputBlock Block\n"
637 		<< "                            OpDecorate %output_buffer DescriptorSet 0\n"
638 		<< "                            OpDecorate %output_buffer Binding 1\n"
639 		<< "                            OpMemberDecorate %Operands 0 Offset 0\n"
640 		<< "                            OpMemberDecorate %Operands 1 Offset ${OPERAND_SIZE}\n"
641 		<< "                            OpMemberDecorate %Operands 2 Offset ${OPERAND_SIZE_2TIMES}\n"
642 		<< "                            OpDecorate %_arr_Operands_arraysize ArrayStride ${OPERAND_SIZE_3TIMES}\n"
643 		<< "                            OpMemberDecorate %InputBlock 0 Offset 0\n"
644 		<< "                            OpDecorate %InputBlock Block\n"
645 		<< "                            OpDecorate %input_buffer DescriptorSet 0\n"
646 		<< "                            OpDecorate %input_buffer Binding 0\n"
647 		<< "                            OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize\n"
648 		<< "                    %void = OpTypeVoid\n"
649 		<< "                %voidfunc = OpTypeFunction %void\n"
650 		<< "                     %int = OpTypeInt 32 1\n"
651 		<< "                    %uint = OpTypeInt 32 0\n"
652 		<< "                  %v3uint = OpTypeVector %uint 3\n"
653 		<< "${TYPES:opt}"
654 		<< "                   %int_0 = OpConstant %int 0\n"
655 		<< "                   %int_1 = OpConstant %int 1\n"
656 		<< "                   %int_2 = OpConstant %int 2\n"
657 		<< "                  %uint_1 = OpConstant %uint 1\n"
658 		<< "                  %uint_0 = OpConstant %uint 0\n"
659 		<< "               %arraysize = OpConstant %uint ${ARRAY_SIZE}\n"
660 		<< "      %_ptr_Function_uint = OpTypePointer Function %uint\n"
661 		<< "       %_ptr_Input_v3uint = OpTypePointer Input %v3uint\n"
662 		<< "   %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input\n"
663 		<< "         %_ptr_Input_uint = OpTypePointer Input %uint\n"
664 		<< "         %results_array_t = OpTypeArray ${OPERAND_TYPE} %arraysize\n"
665 		<< "                %Operands = OpTypeStruct ${OPERAND_TYPE} ${OPERAND_TYPE} ${OPERAND_TYPE}\n"
666 		<< " %_arr_Operands_arraysize = OpTypeArray %Operands %arraysize\n"
667 		<< "             %OutputBlock = OpTypeStruct %results_array_t\n"
668 		<< "              %InputBlock = OpTypeStruct %_arr_Operands_arraysize\n"
669 		<< "%_ptr_Uniform_OutputBlock = OpTypePointer StorageBuffer %OutputBlock\n"
670 		<< " %_ptr_Uniform_InputBlock = OpTypePointer StorageBuffer %InputBlock\n"
671 		<< "           %output_buffer = OpVariable %_ptr_Uniform_OutputBlock StorageBuffer\n"
672 		<< "            %input_buffer = OpVariable %_ptr_Uniform_InputBlock StorageBuffer\n"
673 		<< "              %optype_ptr = OpTypePointer StorageBuffer ${OPERAND_TYPE}\n"
674 		<< "        %gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1\n"
675 		<< "                    %main = OpFunction %void None %voidfunc\n"
676 		<< "               %mainlabel = OpLabel\n"
677 		<< "                 %gidxptr = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0\n"
678 		<< "                     %idx = OpLoad %uint %gidxptr\n"
679 		<< "                  %op1ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_0\n"
680 		<< "                     %op1 = OpLoad ${OPERAND_TYPE} %op1ptr\n"
681 		<< "                  %op2ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_1\n"
682 		<< "                     %op2 = OpLoad ${OPERAND_TYPE} %op2ptr\n"
683 		<< "                  %op3ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_2\n"
684 		<< "                     %op3 = OpLoad ${OPERAND_TYPE} %op3ptr\n"
685 		<< "                  %result = OpExtInst ${OPERAND_TYPE} %trinary ${OPERATION_NAME} %op1 %op2 %op3\n"
686 		<< "               %resultptr = OpAccessChain %optype_ptr %output_buffer %int_0 %idx\n"
687 		<< "                            OpStore %resultptr %result\n"
688 		<< "                            OpReturn\n"
689 		<< "                            OpFunctionEnd\n"
690 		;
691 
692 	const tcu::StringTemplate		shaderTemplate	{shaderStr.str()};
693 	const vk::SpirVAsmBuildOptions	buildOptions	{ VK_MAKE_API_VERSION(0, 1, 2, 0), vk::SPIRV_VERSION_1_5};
694 
695 	programCollection.spirvAsmSources.add("comp", &buildOptions) << shaderTemplate.specialize(getSpirVReplacements());
696 }
697 
TrinaryMinMaxInstance(Context & context,const TestParams & params)698 TrinaryMinMaxInstance::TrinaryMinMaxInstance (Context& context, const TestParams& params)
699 	: vkt::TestInstance	(context)
700 	, m_params			(params)
701 {}
702 
iterate(void)703 tcu::TestStatus TrinaryMinMaxInstance::iterate (void)
704 {
705 	const auto&	vkd			= m_context.getDeviceInterface();
706 	const auto	device		= m_context.getDevice();
707 	auto&		allocator	= m_context.getDefaultAllocator();
708 	const auto	queue		= m_context.getUniversalQueue();
709 	const auto	queueIndex	= m_context.getUniversalQueueFamilyIndex();
710 
711 	constexpr auto kNumOperations = TrinaryMinMaxCase::kArraySize;
712 
713 	const vk::VkDeviceSize kInputBufferSize		= static_cast<vk::VkDeviceSize>(kNumOperations * 3u * m_params.operandSize());
714 	const vk::VkDeviceSize kOutputBufferSize	= static_cast<vk::VkDeviceSize>(kNumOperations * m_params.operandSize()); // Single output per operation.
715 
716 	// Create input, output and reference buffers.
717 	auto inputBufferInfo	= vk::makeBufferCreateInfo(kInputBufferSize, vk::VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
718 	auto outputBufferInfo	= vk::makeBufferCreateInfo(kOutputBufferSize, vk::VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
719 
720 	vk::BufferWithMemory	inputBuffer		{vkd, device, allocator, inputBufferInfo,	vk::MemoryRequirement::HostVisible};
721 	vk::BufferWithMemory	outputBuffer	{vkd, device, allocator, outputBufferInfo,	vk::MemoryRequirement::HostVisible};
722 	std::unique_ptr<char[]>	referenceBuffer	{new char[static_cast<size_t>(kOutputBufferSize)]};
723 
724 	// Fill buffers with initial contents.
725 	auto& inputAlloc	= inputBuffer.getAllocation();
726 	auto& outputAlloc	= outputBuffer.getAllocation();
727 
728 	void* inputBufferPtr		= static_cast<deUint8*>(inputAlloc.getHostPtr()) + inputAlloc.getOffset();
729 	void* outputBufferPtr		= static_cast<deUint8*>(outputAlloc.getHostPtr()) + outputAlloc.getOffset();
730 	void* referenceBufferPtr	= referenceBuffer.get();
731 
732 	deMemset(inputBufferPtr, 0, static_cast<size_t>(kInputBufferSize));
733 	deMemset(outputBufferPtr, 0, static_cast<size_t>(kOutputBufferSize));
734 	deMemset(referenceBufferPtr, 0, static_cast<size_t>(kOutputBufferSize));
735 
736 	// Generate input buffer and calculate reference results.
737 	OperationManager opMan{m_params};
738 	opMan.genInputBuffer(inputBufferPtr, kNumOperations);
739 	opMan.calculateResult(referenceBufferPtr, inputBufferPtr, kNumOperations);
740 
741 	// Flush buffer memory before starting.
742 	vk::flushAlloc(vkd, device, inputAlloc);
743 	vk::flushAlloc(vkd, device, outputAlloc);
744 
745 	// Descriptor set layout.
746 	vk::DescriptorSetLayoutBuilder layoutBuilder;
747 	layoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, vk::VK_SHADER_STAGE_COMPUTE_BIT);
748 	layoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, vk::VK_SHADER_STAGE_COMPUTE_BIT);
749 	auto descriptorSetLayout = layoutBuilder.build(vkd, device);
750 
751 	// Descriptor pool.
752 	vk::DescriptorPoolBuilder poolBuilder;
753 	poolBuilder.addType(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 2u);
754 	auto descriptorPool = poolBuilder.build(vkd, device, vk::VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
755 
756 	// Descriptor set.
757 	const auto descriptorSet = vk::makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
758 
759 	// Update descriptor set using the buffers.
760 	const auto inputBufferDescriptorInfo	= vk::makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
761 	const auto outputBufferDescriptorInfo	= vk::makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
762 
763 	vk::DescriptorSetUpdateBuilder updateBuilder;
764 	updateBuilder.writeSingle(descriptorSet.get(), vk::DescriptorSetUpdateBuilder::Location::binding(0u), vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescriptorInfo);
765 	updateBuilder.writeSingle(descriptorSet.get(), vk::DescriptorSetUpdateBuilder::Location::binding(1u), vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescriptorInfo);
766 	updateBuilder.update(vkd, device);
767 
768 	// Create compute pipeline.
769 	auto shaderModule = vk::createShaderModule(vkd, device, m_context.getBinaryCollection().get("comp"), 0u);
770 	auto pipelineLayout = vk::makePipelineLayout(vkd, device, descriptorSetLayout.get());
771 
772 	const vk::VkComputePipelineCreateInfo pipelineCreateInfo =
773 	{
774 		vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
775 		nullptr,
776 		0u,															// flags
777 		{															// compute shader
778 			vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,	// VkStructureType						sType;
779 			nullptr,													// const void*							pNext;
780 			0u,															// VkPipelineShaderStageCreateFlags		flags;
781 			vk::VK_SHADER_STAGE_COMPUTE_BIT,							// VkShaderStageFlagBits				stage;
782 			shaderModule.get(),											// VkShaderModule						module;
783 			"main",														// const char*							pName;
784 			nullptr,													// const VkSpecializationInfo*			pSpecializationInfo;
785 		},
786 		pipelineLayout.get(),										// layout
787 		DE_NULL,													// basePipelineHandle
788 		0,															// basePipelineIndex
789 	};
790 	auto pipeline = vk::createComputePipeline(vkd, device, DE_NULL, &pipelineCreateInfo);
791 
792 	// Synchronization barriers.
793 	auto inputBufferHostToDevBarrier	= vk::makeBufferMemoryBarrier(vk::VK_ACCESS_HOST_WRITE_BIT, vk::VK_ACCESS_SHADER_READ_BIT, inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
794 	auto outputBufferHostToDevBarrier	= vk::makeBufferMemoryBarrier(vk::VK_ACCESS_HOST_WRITE_BIT, vk::VK_ACCESS_SHADER_WRITE_BIT, outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
795 	auto outputBufferDevToHostBarrier	= vk::makeBufferMemoryBarrier(vk::VK_ACCESS_SHADER_WRITE_BIT, vk::VK_ACCESS_HOST_READ_BIT, outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
796 
797 	// Command buffer.
798 	auto cmdPool		= vk::makeCommandPool(vkd, device, queueIndex);
799 	auto cmdBufferPtr	= vk::allocateCommandBuffer(vkd, device, cmdPool.get(), vk::VK_COMMAND_BUFFER_LEVEL_PRIMARY);
800 	auto cmdBuffer		= cmdBufferPtr.get();
801 
802 	// Record and submit commands.
803 	vk::beginCommandBuffer(vkd, cmdBuffer);
804 		vkd.cmdBindPipeline(cmdBuffer, vk::VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get());
805 		vkd.cmdBindDescriptorSets(cmdBuffer, vk::VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout.get(), 0, 1u, &descriptorSet.get(), 0u, nullptr);
806 		vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_HOST_BIT, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0u, 0u, nullptr, 1u, &inputBufferHostToDevBarrier, 0u, nullptr);
807 		vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_HOST_BIT, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0u, 0u, nullptr, 1u, &outputBufferHostToDevBarrier, 0u, nullptr);
808 		vkd.cmdDispatch(cmdBuffer, kNumOperations, 1u, 1u);
809 		vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, vk::VK_PIPELINE_STAGE_HOST_BIT, 0u, 0u, nullptr, 1u, &outputBufferDevToHostBarrier, 0u, nullptr);
810 	vk::endCommandBuffer(vkd, cmdBuffer);
811 	vk::submitCommandsAndWait(vkd, device, queue, cmdBuffer);
812 
813 	// Verify output buffer contents.
814 	vk::invalidateAlloc(vkd, device, outputAlloc);
815 
816 	const auto error = opMan.compareResults(referenceBufferPtr, outputBufferPtr, kNumOperations);
817 
818 	if (!error)
819 		return tcu::TestStatus::pass("Pass");
820 
821 	std::ostringstream msg;
822 	msg << "Value mismatch at operation " << error.get().first << " in component " << error.get().second;
823 	return tcu::TestStatus::fail(msg.str());
824 }
825 
826 } // anonymous
827 
createTrinaryMinMaxGroup(tcu::TestContext & testCtx)828 tcu::TestCaseGroup* createTrinaryMinMaxGroup (tcu::TestContext& testCtx)
829 {
830 	deUint32 seed = 0xFEE768FCu;
831 	de::MovePtr<tcu::TestCaseGroup> group{new tcu::TestCaseGroup{testCtx, "amd_trinary_minmax", "Tests for VK_AMD_trinary_minmax operations"}};
832 
833 	static const std::vector<std::pair<OperationType, std::string>> operationTypes =
834 	{
835 		{ OperationType::MIN, "min3" },
836 		{ OperationType::MAX, "max3" },
837 		{ OperationType::MID, "mid3" },
838 	};
839 
840 	static const std::vector<std::pair<BaseType, std::string>> baseTypes =
841 	{
842 		{ BaseType::TYPE_INT,	"i" },
843 		{ BaseType::TYPE_UINT,	"u" },
844 		{ BaseType::TYPE_FLOAT,	"f" },
845 	};
846 
847 	static const std::vector<std::pair<TypeSize, std::string>> typeSizes =
848 	{
849 		{ TypeSize::SIZE_8BIT,	"8"		},
850 		{ TypeSize::SIZE_16BIT,	"16"	},
851 		{ TypeSize::SIZE_32BIT,	"32"	},
852 		{ TypeSize::SIZE_64BIT,	"64"	},
853 	};
854 
855 	static const std::vector<std::pair<AggregationType, std::string>> aggregationTypes =
856 	{
857 		{ AggregationType::SCALAR,	"scalar"	},
858 		{ AggregationType::VEC2,	"vec2"		},
859 		{ AggregationType::VEC3,	"vec3"		},
860 		{ AggregationType::VEC4,	"vec4"		},
861 	};
862 
863 	for (const auto& opType : operationTypes)
864 	{
865 		const std::string opDesc = "Tests for " + opType.second + " operation";
866 		de::MovePtr<tcu::TestCaseGroup> opGroup{new tcu::TestCaseGroup{testCtx, opType.second.c_str(), opDesc.c_str()}};
867 
868 		for (const auto& baseType : baseTypes)
869 		for (const auto& typeSize : typeSizes)
870 		{
871 			// There are no 8-bit floats.
872 			if (baseType.first == BaseType::TYPE_FLOAT && typeSize.first == TypeSize::SIZE_8BIT)
873 				continue;
874 
875 			const std::string typeName = baseType.second + typeSize.second;
876 			const std::string typeDesc = "Tests using " + typeName + " data";
877 
878 			de::MovePtr<tcu::TestCaseGroup> typeGroup{new tcu::TestCaseGroup{testCtx, typeName.c_str(), typeDesc.c_str()}};
879 
880 			for (const auto& aggType : aggregationTypes)
881 			{
882 				const TestParams params =
883 				{
884 					opType.first,		// OperationType	operation;
885 					baseType.first,		// BaseType			baseType;
886 					typeSize.first,		// TypeSize			typeSize;
887 					aggType.first,		// AggregationType	aggregation;
888 					seed++,				// deUint32			randomSeed;
889 				};
890 				typeGroup->addChild(new TrinaryMinMaxCase{testCtx, aggType.second, "", params});
891 			}
892 
893 			opGroup->addChild(typeGroup.release());
894 		}
895 
896 		group->addChild(opGroup.release());
897 	}
898 
899 	return group.release();
900 }
901 
902 } // SpirVAssembly
903 } // vkt
904