1 /*-------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2020 Valve Corporation.
6 * Copyright (c) 2020 The Khronos Group Inc.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
22 * \brief SPIR-V tests for VK_AMD_shader_trinary_minmax.
23 *//*--------------------------------------------------------------------*/
24
25 #include "vktSpvAsmTrinaryMinMaxTests.hpp"
26 #include "vktTestCase.hpp"
27
28 #include "vkQueryUtil.hpp"
29 #include "vkObjUtil.hpp"
30 #include "vkBufferWithMemory.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkCmdUtil.hpp"
35
36 #include "tcuStringTemplate.hpp"
37 #include "tcuFloat.hpp"
38 #include "tcuMaybe.hpp"
39
40 #include "deStringUtil.hpp"
41 #include "deRandom.hpp"
42 #include "deMemory.h"
43
44 #include <string>
45 #include <sstream>
46 #include <map>
47 #include <vector>
48 #include <algorithm>
49 #include <array>
50 #include <memory>
51
52 namespace vkt
53 {
54 namespace SpirVAssembly
55 {
56
57 namespace
58 {
59
// Trinary operation under test. The numeric values double as indices into the operation-name table used when
// building the SPIR-V extended instruction name.
enum class OperationType : int
{
	MIN	= 0,	// Smallest of the three operands.
	MAX	= 1,	// Largest of the three operands.
	MID	= 2,	// Median of the three operands.
};
66
// Component base type. The numeric values double as indices into the "S"/"U"/"F" instruction-prefix table.
enum class BaseType : int
{
	TYPE_INT	= 0,	// Signed integer.
	TYPE_UINT	= 1,	// Unsigned integer.
	TYPE_FLOAT	= 2,	// Floating point.
};
73
// Size of a single component. Each enumerator's value is that size expressed in bytes.
enum class TypeSize : int
{
	SIZE_8BIT	= 1,	// 1 byte.
	SIZE_16BIT	= 2,	// 2 bytes.
	SIZE_32BIT	= 4,	// 4 bytes.
	SIZE_64BIT	= 8,	// 8 bytes.
};
82
// Shape of each operand. Each enumerator's value equals its component count.
enum class AggregationType : int
{
	SCALAR	= 1,	// Single component.
	VEC2	= 2,	// Two-component vector.
	VEC3	= 3,	// Three-component vector.
	VEC4	= 4,	// Four-component vector.
};
91
92 struct TestParams
93 {
94 OperationType operation;
95 BaseType baseType;
96 TypeSize typeSize;
97 AggregationType aggregation;
98 deUint32 randomSeed;
99
100 deUint32 operandSize () const; // In bytes.
101 deUint32 numComponents () const; // Number of components.
102 deUint32 effectiveComponents () const; // Effective number of components for size calculation.
103 deUint32 componentSize () const; // In bytes.
104 };
105
operandSize() const106 deUint32 TestParams::operandSize () const
107 {
108 return (effectiveComponents() * componentSize());
109 }
110
numComponents() const111 deUint32 TestParams::numComponents () const
112 {
113 return static_cast<deUint32>(aggregation);
114 }
115
effectiveComponents() const116 deUint32 TestParams::effectiveComponents () const
117 {
118 return static_cast<deUint32>((aggregation == AggregationType::VEC3) ? AggregationType::VEC4 : aggregation);
119 }
120
componentSize() const121 deUint32 TestParams::componentSize () const
122 {
123 return static_cast<deUint32>(typeSize);
124 }
125
// Returns the smallest of three values using only operator<. On ties the earliest argument wins, which is the
// same first-occurrence rule std::min_element applies.
template <class T>
T min3(T op1, T op2, T op3)
{
	T smallest = op1;
	if (op2 < smallest)
		smallest = op2;
	if (op3 < smallest)
		smallest = op3;
	return smallest;
}
131
// Returns the largest of three values using only operator<. On ties the earliest argument wins, which is the
// same first-occurrence rule std::max_element applies.
template <class T>
T max3(T op1, T op2, T op3)
{
	T largest = op1;
	if (largest < op2)
		largest = op2;
	if (largest < op3)
		largest = op3;
	return largest;
}
137
// Returns the median of three values: sort a local copy and take the middle element.
template <class T>
T mid3(T op1, T op2, T op3)
{
	T values[3] = {op1, op2, op3};
	std::sort(values, values + 3);
	return values[1];
}
145
146 class OperationManager
147 {
148 public:
149 // Operation and component index in case of error.
150 using OperationComponent = std::pair<deUint32, deUint32>;
151 using ComparisonError = tcu::Maybe<OperationComponent>;
152
153 OperationManager (const TestParams& params);
154 void genInputBuffer (void* bufferPtr, deUint32 numOperations);
155 void calculateResult (void* referenceBuffer, void* inputBuffer, deUint32 numOperations);
156 ComparisonError compareResults (void* referenceBuffer, void* resultsBuffer, deUint32 numOperations);
157
158 private:
159 using GenerateCompFunc = void (*)(de::Random&, void*); // Write a generated component to the given location.
160
161 // Generator variants to populate input buffer.
genInt8(de::Random & rnd,void * ptr)162 static void genInt8 (de::Random& rnd, void* ptr) { *reinterpret_cast<deInt8*>(ptr) = static_cast<deInt8>(rnd.getUint8()); }
genUint8(de::Random & rnd,void * ptr)163 static void genUint8 (de::Random& rnd, void* ptr) { *reinterpret_cast<deUint8*>(ptr) = rnd.getUint8(); }
genInt16(de::Random & rnd,void * ptr)164 static void genInt16 (de::Random& rnd, void* ptr) { *reinterpret_cast<deInt16*>(ptr) = static_cast<deInt16>(rnd.getUint16()); }
genUint16(de::Random & rnd,void * ptr)165 static void genUint16 (de::Random& rnd, void* ptr) { *reinterpret_cast<deUint16*>(ptr) = rnd.getUint16(); }
genInt32(de::Random & rnd,void * ptr)166 static void genInt32 (de::Random& rnd, void* ptr) { *reinterpret_cast<deInt32*>(ptr) = static_cast<deInt32>(rnd.getUint32()); }
genUint32(de::Random & rnd,void * ptr)167 static void genUint32 (de::Random& rnd, void* ptr) { *reinterpret_cast<deUint32*>(ptr) = rnd.getUint32(); }
genInt64(de::Random & rnd,void * ptr)168 static void genInt64 (de::Random& rnd, void* ptr) { *reinterpret_cast<deInt64*>(ptr) = static_cast<deInt64>(rnd.getUint64()); }
genUint64(de::Random & rnd,void * ptr)169 static void genUint64 (de::Random& rnd, void* ptr) { *reinterpret_cast<deUint64*>(ptr) = rnd.getUint64(); }
170
171 // Helper template for float generators.
172 // T must be a tcu::Float instantiation.
173 // Attempts to generate +-Inf once every 10 times and avoid denormals.
174 template <class T>
genFloat(de::Random & rnd,void * ptr)175 static inline void genFloat (de::Random& rnd, void *ptr)
176 {
177 T* valuePtr = reinterpret_cast<T*>(ptr);
178 if (rnd.getInt(1, 10) == 1)
179 *valuePtr = T::inf(rnd.getBool() ? 1 : -1);
180 else {
181 do {
182 *valuePtr = T{rnd.getDouble(T::largestNormal(-1).asDouble(), T::largestNormal(1).asDouble())};
183 } while (valuePtr->isDenorm());
184 }
185 }
186
genFloat16(de::Random & rnd,void * ptr)187 static void genFloat16 (de::Random& rnd, void* ptr) { genFloat<tcu::Float16>(rnd, ptr); }
genFloat32(de::Random & rnd,void * ptr)188 static void genFloat32 (de::Random& rnd, void* ptr) { genFloat<tcu::Float32>(rnd, ptr); }
genFloat64(de::Random & rnd,void * ptr)189 static void genFloat64 (de::Random& rnd, void* ptr) { genFloat<tcu::Float64>(rnd, ptr); }
190
191 // An operation function writes an output value given 3 input values.
192 using OperationFunc = void (*)(void*, const void*, const void*, const void*);
193
194 // Helper template used below.
195 template <class T, class F>
runOpFunc(F f,void * out,const void * in1,const void * in2,const void * in3)196 static inline void runOpFunc (F f, void* out, const void* in1, const void* in2, const void* in3)
197 {
198 *reinterpret_cast<T*>(out) = f(*reinterpret_cast<const T*>(in1), *reinterpret_cast<const T*>(in2), *reinterpret_cast<const T*>(in3));
199 }
200
201 // Apply an operation in software to a given group of components and calculate result.
minInt8(void * out,const void * in1,const void * in2,const void * in3)202 static void minInt8 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt8> (min3<deInt8>, out, in1, in2, in3); }
maxInt8(void * out,const void * in1,const void * in2,const void * in3)203 static void maxInt8 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt8> (max3<deInt8>, out, in1, in2, in3); }
midInt8(void * out,const void * in1,const void * in2,const void * in3)204 static void midInt8 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt8> (mid3<deInt8>, out, in1, in2, in3); }
minUint8(void * out,const void * in1,const void * in2,const void * in3)205 static void minUint8 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint8> (min3<deUint8>, out, in1, in2, in3); }
maxUint8(void * out,const void * in1,const void * in2,const void * in3)206 static void maxUint8 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint8> (max3<deUint8>, out, in1, in2, in3); }
midUint8(void * out,const void * in1,const void * in2,const void * in3)207 static void midUint8 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint8> (mid3<deUint8>, out, in1, in2, in3); }
minInt16(void * out,const void * in1,const void * in2,const void * in3)208 static void minInt16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt16> (min3<deInt16>, out, in1, in2, in3); }
maxInt16(void * out,const void * in1,const void * in2,const void * in3)209 static void maxInt16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt16> (max3<deInt16>, out, in1, in2, in3); }
midInt16(void * out,const void * in1,const void * in2,const void * in3)210 static void midInt16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt16> (mid3<deInt16>, out, in1, in2, in3); }
minUint16(void * out,const void * in1,const void * in2,const void * in3)211 static void minUint16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint16> (min3<deUint16>, out, in1, in2, in3); }
maxUint16(void * out,const void * in1,const void * in2,const void * in3)212 static void maxUint16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint16> (max3<deUint16>, out, in1, in2, in3); }
midUint16(void * out,const void * in1,const void * in2,const void * in3)213 static void midUint16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint16> (mid3<deUint16>, out, in1, in2, in3); }
minInt32(void * out,const void * in1,const void * in2,const void * in3)214 static void minInt32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt32> (min3<deInt32>, out, in1, in2, in3); }
maxInt32(void * out,const void * in1,const void * in2,const void * in3)215 static void maxInt32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt32> (max3<deInt32>, out, in1, in2, in3); }
midInt32(void * out,const void * in1,const void * in2,const void * in3)216 static void midInt32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt32> (mid3<deInt32>, out, in1, in2, in3); }
minUint32(void * out,const void * in1,const void * in2,const void * in3)217 static void minUint32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint32> (min3<deUint32>, out, in1, in2, in3); }
maxUint32(void * out,const void * in1,const void * in2,const void * in3)218 static void maxUint32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint32> (max3<deUint32>, out, in1, in2, in3); }
midUint32(void * out,const void * in1,const void * in2,const void * in3)219 static void midUint32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint32> (mid3<deUint32>, out, in1, in2, in3); }
minInt64(void * out,const void * in1,const void * in2,const void * in3)220 static void minInt64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt64> (min3<deInt64>, out, in1, in2, in3); }
maxInt64(void * out,const void * in1,const void * in2,const void * in3)221 static void maxInt64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt64> (max3<deInt64>, out, in1, in2, in3); }
midInt64(void * out,const void * in1,const void * in2,const void * in3)222 static void midInt64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deInt64> (mid3<deInt64>, out, in1, in2, in3); }
minUint64(void * out,const void * in1,const void * in2,const void * in3)223 static void minUint64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint64> (min3<deUint64>, out, in1, in2, in3); }
maxUint64(void * out,const void * in1,const void * in2,const void * in3)224 static void maxUint64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint64> (max3<deUint64>, out, in1, in2, in3); }
midUint64(void * out,const void * in1,const void * in2,const void * in3)225 static void midUint64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<deUint64> (mid3<deUint64>, out, in1, in2, in3); }
minFloat16(void * out,const void * in1,const void * in2,const void * in3)226 static void minFloat16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float16>(min3<tcu::Float16>, out, in1, in2, in3); }
maxFloat16(void * out,const void * in1,const void * in2,const void * in3)227 static void maxFloat16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float16>(max3<tcu::Float16>, out, in1, in2, in3); }
midFloat16(void * out,const void * in1,const void * in2,const void * in3)228 static void midFloat16 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float16>(mid3<tcu::Float16>, out, in1, in2, in3); }
minFloat32(void * out,const void * in1,const void * in2,const void * in3)229 static void minFloat32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float32>(min3<tcu::Float32>, out, in1, in2, in3); }
maxFloat32(void * out,const void * in1,const void * in2,const void * in3)230 static void maxFloat32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float32>(max3<tcu::Float32>, out, in1, in2, in3); }
midFloat32(void * out,const void * in1,const void * in2,const void * in3)231 static void midFloat32 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float32>(mid3<tcu::Float32>, out, in1, in2, in3); }
minFloat64(void * out,const void * in1,const void * in2,const void * in3)232 static void minFloat64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float64>(min3<tcu::Float64>, out, in1, in2, in3); }
maxFloat64(void * out,const void * in1,const void * in2,const void * in3)233 static void maxFloat64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float64>(max3<tcu::Float64>, out, in1, in2, in3); }
midFloat64(void * out,const void * in1,const void * in2,const void * in3)234 static void midFloat64 (void* out, const void* in1, const void* in2, const void* in3) { runOpFunc<tcu::Float64>(mid3<tcu::Float64>, out, in1, in2, in3); }
235
236 // Case for accessing the functions map.
237 struct Case
238 {
239 BaseType type;
240 TypeSize size;
241 OperationType operation;
242
243 // This is required for sorting in the map.
operator <vkt::SpirVAssembly::__anonf00c35b40111::OperationManager::Case244 bool operator< (const Case& other) const
245 {
246 return (toArray() < other.toArray());
247 }
248
249 private:
toArrayvkt::SpirVAssembly::__anonf00c35b40111::OperationManager::Case250 std::array<int, 3> toArray () const
251 {
252 return std::array<int, 3>{{static_cast<int>(type), static_cast<int>(size), static_cast<int>(operation)}};
253 }
254 };
255
256 // Helper map to correctly choose the right generator and operation function for the specific case being tested.
257 using FuncPair = std::pair<GenerateCompFunc, OperationFunc>;
258 using CaseMap = std::map<Case, FuncPair>;
259
260 static const CaseMap kFunctionsMap;
261
262 GenerateCompFunc m_chosenGenerator;
263 OperationFunc m_chosenOperation;
264 de::Random m_random;
265
266 const deUint32 m_operandSize;
267 const deUint32 m_numComponents;
268 const deUint32 m_componentSize;
269 };
270
271 // This map is used to choose how to generate inputs for each case and which operation to run on the CPU to calculate the reference
272 // results for the generated inputs.
273 const OperationManager::CaseMap OperationManager::kFunctionsMap =
274 {
275 { { BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MIN }, { genInt8, minInt8 } },
276 { { BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MAX }, { genInt8, maxInt8 } },
277 { { BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MID }, { genInt8, midInt8 } },
278 { { BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MIN }, { genInt16, minInt16 } },
279 { { BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MAX }, { genInt16, maxInt16 } },
280 { { BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MID }, { genInt16, midInt16 } },
281 { { BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MIN }, { genInt32, minInt32 } },
282 { { BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MAX }, { genInt32, maxInt32 } },
283 { { BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MID }, { genInt32, midInt32 } },
284 { { BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MIN }, { genInt64, minInt64 } },
285 { { BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MAX }, { genInt64, maxInt64 } },
286 { { BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MID }, { genInt64, midInt64 } },
287 { { BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MIN }, { genUint8, minUint8 } },
288 { { BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MAX }, { genUint8, maxUint8 } },
289 { { BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MID }, { genUint8, midUint8 } },
290 { { BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MIN }, { genUint16, minUint16 } },
291 { { BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MAX }, { genUint16, maxUint16 } },
292 { { BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MID }, { genUint16, midUint16 } },
293 { { BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MIN }, { genUint32, minUint32 } },
294 { { BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MAX }, { genUint32, maxUint32 } },
295 { { BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MID }, { genUint32, midUint32 } },
296 { { BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MIN }, { genUint64, minUint64 } },
297 { { BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MAX }, { genUint64, maxUint64 } },
298 { { BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MID }, { genUint64, midUint64 } },
299 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MIN }, { genFloat16, minFloat16 } },
300 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MAX }, { genFloat16, maxFloat16 } },
301 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MID }, { genFloat16, midFloat16 } },
302 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MIN }, { genFloat32, minFloat32 } },
303 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MAX }, { genFloat32, maxFloat32 } },
304 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MID }, { genFloat32, midFloat32 } },
305 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MIN }, { genFloat64, minFloat64 } },
306 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MAX }, { genFloat64, maxFloat64 } },
307 { { BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MID }, { genFloat64, midFloat64 } },
308 };
309
OperationManager(const TestParams & params)310 OperationManager::OperationManager (const TestParams& params)
311 : m_chosenGenerator {nullptr}
312 , m_chosenOperation {nullptr}
313 , m_random {params.randomSeed}
314 , m_operandSize {params.operandSize()}
315 , m_numComponents {params.numComponents()}
316 , m_componentSize {params.componentSize()}
317 {
318 // Choose generator and CPU operation from the map.
319 const Case paramCase{params.baseType, params.typeSize, params.operation};
320 const auto iter = kFunctionsMap.find(paramCase);
321
322 DE_ASSERT(iter != kFunctionsMap.end());
323 m_chosenGenerator = iter->second.first;
324 m_chosenOperation = iter->second.second;
325 }
326
327 // See TrinaryMinMaxCase::initPrograms for a description of the input buffer format.
328 // Generates inputs with the chosen generator.
genInputBuffer(void * bufferPtr,deUint32 numOperations)329 void OperationManager::genInputBuffer (void* bufferPtr, deUint32 numOperations)
330 {
331 const deUint32 numOperands = numOperations * 3u;
332 char* byteBuffer = reinterpret_cast<char*>(bufferPtr);
333
334 for (deUint32 opIdx = 0u; opIdx < numOperands; ++opIdx)
335 {
336 char* compPtr = byteBuffer;
337 for (deUint32 compIdx = 0u; compIdx < m_numComponents; ++compIdx)
338 {
339 m_chosenGenerator(m_random, reinterpret_cast<void*>(compPtr));
340 compPtr += m_componentSize;
341 }
342 byteBuffer += m_operandSize;
343 }
344 }
345
346 // See TrinaryMinMaxCase::initPrograms for a description of the input and output buffer formats.
347 // Calculates reference results on the CPU using the chosen operation and the input buffer.
calculateResult(void * referenceBuffer,void * inputBuffer,deUint32 numOperations)348 void OperationManager::calculateResult (void* referenceBuffer, void* inputBuffer, deUint32 numOperations)
349 {
350 char* outputByte = reinterpret_cast<char*>(referenceBuffer);
351 char* inputByte = reinterpret_cast<char*>(inputBuffer);
352
353 for (deUint32 opIdx = 0u; opIdx < numOperations; ++opIdx)
354 {
355 char* res = outputByte;
356 char* op1 = inputByte;
357 char* op2 = inputByte + m_operandSize;
358 char* op3 = inputByte + m_operandSize * 2u;
359
360 for (deUint32 compIdx = 0u; compIdx < m_numComponents; ++compIdx)
361 {
362 m_chosenOperation(
363 reinterpret_cast<void*>(res),
364 reinterpret_cast<void*>(op1),
365 reinterpret_cast<void*>(op2),
366 reinterpret_cast<void*>(op3));
367
368 res += m_componentSize;
369 op1 += m_componentSize;
370 op2 += m_componentSize;
371 op3 += m_componentSize;
372 }
373
374 outputByte += m_operandSize;
375 inputByte += m_operandSize * 3u;
376 }
377 }
378
379 // See TrinaryMinMaxCase::initPrograms for a description of the output buffer format.
compareResults(void * referenceBuffer,void * resultsBuffer,deUint32 numOperations)380 OperationManager::ComparisonError OperationManager::compareResults (void* referenceBuffer, void* resultsBuffer, deUint32 numOperations)
381 {
382 char* referenceBytes = reinterpret_cast<char*>(referenceBuffer);
383 char* resultsBytes = reinterpret_cast<char*>(resultsBuffer);
384
385 for (deUint32 opIdx = 0u; opIdx < numOperations; ++opIdx)
386 {
387 char *refCompBytes = referenceBytes;
388 char *resCompBytes = resultsBytes;
389
390 for (deUint32 compIdx = 0u; compIdx < m_numComponents; ++compIdx)
391 {
392 if (deMemCmp(refCompBytes, resCompBytes, m_componentSize) != 0)
393 return tcu::just(OperationComponent(opIdx, compIdx));
394 refCompBytes += m_componentSize;
395 resCompBytes += m_componentSize;
396 }
397 referenceBytes += m_operandSize;
398 resultsBytes += m_operandSize;
399 }
400
401 return tcu::Nothing;
402 }
403
404 class TrinaryMinMaxCase : public vkt::TestCase
405 {
406 public:
407 using ReplacementsMap = std::map<std::string, std::string>;
408
409 TrinaryMinMaxCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const TestParams& params);
~TrinaryMinMaxCase(void)410 virtual ~TrinaryMinMaxCase (void) {}
411
412 virtual void initPrograms (vk::SourceCollections& programCollection) const;
413 virtual TestInstance* createInstance (Context& context) const;
414 virtual void checkSupport (Context& context) const;
415 ReplacementsMap getSpirVReplacements (void) const;
416
417 static const deUint32 kArraySize;
418 private:
419 TestParams m_params;
420 };
421
422 const deUint32 TrinaryMinMaxCase::kArraySize = 100u;
423
424 class TrinaryMinMaxInstance : public vkt::TestInstance
425 {
426 public:
427 TrinaryMinMaxInstance (Context& context, const TestParams& params);
~TrinaryMinMaxInstance(void)428 virtual ~TrinaryMinMaxInstance (void) {}
429
430 virtual tcu::TestStatus iterate (void);
431
432 private:
433 TestParams m_params;
434 };
435
TrinaryMinMaxCase(tcu::TestContext & testCtx,const std::string & name,const std::string & description,const TestParams & params)436 TrinaryMinMaxCase::TrinaryMinMaxCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const TestParams& params)
437 : vkt::TestCase (testCtx, name, description)
438 , m_params (params)
439 {}
440
createInstance(Context & context) const441 TestInstance* TrinaryMinMaxCase::createInstance (Context& context) const
442 {
443 return new TrinaryMinMaxInstance{context, m_params};
444 }
445
checkSupport(Context & context) const446 void TrinaryMinMaxCase::checkSupport (Context& context) const
447 {
448 // These are always required.
449 context.requireInstanceFunctionality("VK_KHR_get_physical_device_properties2");
450 context.requireDeviceFunctionality("VK_KHR_storage_buffer_storage_class");
451 context.requireDeviceFunctionality("VK_AMD_shader_trinary_minmax");
452
453 const auto devFeatures = context.getDeviceFeatures();
454 const auto storage16BitFeatures = context.get16BitStorageFeatures();
455 const auto storage8BitFeatures = context.get8BitStorageFeatures();
456 const auto shaderFeatures = context.getShaderFloat16Int8Features();
457
458 // Storage features.
459 if (m_params.typeSize == TypeSize::SIZE_8BIT)
460 {
461 // We will be using 8-bit types in storage buffers.
462 context.requireDeviceFunctionality("VK_KHR_8bit_storage");
463 if (!storage8BitFeatures.storageBuffer8BitAccess)
464 TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
465 }
466 else if (m_params.typeSize == TypeSize::SIZE_16BIT)
467 {
468 // We will be using 16-bit types in storage buffers.
469 context.requireDeviceFunctionality("VK_KHR_16bit_storage");
470 if (!storage16BitFeatures.storageBuffer16BitAccess)
471 TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
472 }
473
474 // Shader type features.
475 if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
476 {
477 if (m_params.typeSize == TypeSize::SIZE_8BIT && !shaderFeatures.shaderInt8)
478 TCU_THROW(NotSupportedError, "8-bit integers not supported in shaders");
479 else if (m_params.typeSize == TypeSize::SIZE_16BIT && !devFeatures.shaderInt16)
480 TCU_THROW(NotSupportedError, "16-bit integers not supported in shaders");
481 else if (m_params.typeSize == TypeSize::SIZE_64BIT && !devFeatures.shaderInt64)
482 TCU_THROW(NotSupportedError, "64-bit integers not supported in shaders");
483 }
484 else // BaseType::TYPE_FLOAT
485 {
486 DE_ASSERT(m_params.typeSize != TypeSize::SIZE_8BIT);
487 if (m_params.typeSize == TypeSize::SIZE_16BIT && !shaderFeatures.shaderFloat16)
488 TCU_THROW(NotSupportedError, "16-bit floats not supported in shaders");
489 else if (m_params.typeSize == TypeSize::SIZE_64BIT && !devFeatures.shaderFloat64)
490 TCU_THROW(NotSupportedError, "64-bit floats not supported in shaders");
491 }
492 }
493
getSpirVReplacements(void) const494 TrinaryMinMaxCase::ReplacementsMap TrinaryMinMaxCase::getSpirVReplacements (void) const
495 {
496 ReplacementsMap replacements;
497
498 // Capabilities and extensions.
499 if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
500 {
501 if (m_params.typeSize == TypeSize::SIZE_8BIT)
502 replacements["CAPABILITIES"] += "OpCapability Int8\n";
503 else if (m_params.typeSize == TypeSize::SIZE_16BIT)
504 replacements["CAPABILITIES"] += "OpCapability Int16\n";
505 else if (m_params.typeSize == TypeSize::SIZE_64BIT)
506 replacements["CAPABILITIES"] += "OpCapability Int64\n";
507 }
508 else // BaseType::TYPE_FLOAT
509 {
510 if (m_params.typeSize == TypeSize::SIZE_16BIT)
511 replacements["CAPABILITIES"] += "OpCapability Float16\n";
512 else if (m_params.typeSize == TypeSize::SIZE_64BIT)
513 replacements["CAPABILITIES"] += "OpCapability Float64\n";
514 }
515
516 if (m_params.typeSize == TypeSize::SIZE_8BIT)
517 {
518 replacements["CAPABILITIES"] += "OpCapability StorageBuffer8BitAccess\n";
519 replacements["EXTENSIONS"] += "OpExtension \"SPV_KHR_8bit_storage\"\n";
520 }
521 else if (m_params.typeSize == TypeSize::SIZE_16BIT)
522 {
523 replacements["CAPABILITIES"] += "OpCapability StorageBuffer16BitAccess\n";
524 replacements["EXTENSIONS"] += "OpExtension \"SPV_KHR_16bit_storage\"\n";
525 }
526
527 // Operand size in bytes.
528 const deUint32 opSize = m_params.operandSize();
529 replacements["OPERAND_SIZE"] = de::toString(opSize);
530 replacements["OPERAND_SIZE_2TIMES"] = de::toString(opSize * 2u);
531 replacements["OPERAND_SIZE_3TIMES"] = de::toString(opSize * 3u);
532
533 // Array size.
534 replacements["ARRAY_SIZE"] = de::toString(kArraySize);
535
536 // Types and operand type: define the base integer or float type and the vector type if needed, then set the operand type replacement.
537 const std::string vecSize = de::toString(m_params.numComponents());
538 const std::string bitSize = de::toString(m_params.componentSize() * 8u);
539
540 if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
541 {
542 const std::string signBit = (m_params.baseType == BaseType::TYPE_INT ? "1" : "0");
543 const std::string typePrefix = (m_params.baseType == BaseType::TYPE_UINT ? "u" : "");
544 std::string baseTypeName;
545
546 // 32-bit integers are already defined in the default shader text.
547 if (m_params.typeSize != TypeSize::SIZE_32BIT)
548 {
549 baseTypeName = typePrefix + "int" + bitSize + "_t";
550 replacements["TYPES"] += "%" + baseTypeName + " = OpTypeInt " + bitSize + " " + signBit + "\n";
551 }
552 else
553 {
554 baseTypeName = typePrefix + "int";
555 }
556
557 if (m_params.aggregation == AggregationType::SCALAR)
558 {
559 replacements["OPERAND_TYPE"] = "%" + baseTypeName;
560 }
561 else
562 {
563 const std::string typeName = "%v" + vecSize + baseTypeName;
564 // %v3uint is already defined in the default shader text.
565 if (m_params.baseType != BaseType::TYPE_UINT || m_params.typeSize != TypeSize::SIZE_32BIT || m_params.aggregation != AggregationType::VEC3)
566 {
567 replacements["TYPES"] += typeName + " = OpTypeVector %" + baseTypeName + " " + vecSize + "\n";
568 }
569 replacements["OPERAND_TYPE"] = typeName;
570 }
571 }
572 else // BaseType::TYPE_FLOAT
573 {
574 const std::string baseTypeName = "float" + bitSize + "_t";
575 replacements["TYPES"] += "%" + baseTypeName + " = OpTypeFloat " + bitSize + "\n";
576
577 if (m_params.aggregation == AggregationType::SCALAR)
578 {
579 replacements["OPERAND_TYPE"] = "%" + baseTypeName;
580 }
581 else
582 {
583 const std::string typeName = "%v" + vecSize + baseTypeName;
584 replacements["TYPES"] += typeName + " = OpTypeVector %" + baseTypeName + " " + vecSize + "\n";
585 replacements["OPERAND_TYPE"] = typeName;
586 }
587 }
588
589 // Operation name.
590 const static std::vector<std::string> opTypeStr = { "Min", "Max", "Mid" };
591 const static std::vector<std::string> opPrefix = { "S", "U", "F" };
592 replacements["OPERATION_NAME"] = opPrefix[static_cast<int>(m_params.baseType)] + opTypeStr[static_cast<int>(m_params.operation)] + "3AMD";
593
594 return replacements;
595 }
596
initPrograms(vk::SourceCollections & programCollection) const597 void TrinaryMinMaxCase::initPrograms (vk::SourceCollections& programCollection) const
598 {
599 // The shader below uses an input buffer at set 0 binding 0 and an output buffer at set 0 binding 1. Their structure is similar
600 // to the code below:
601 //
602 // struct Operands {
603 // <type> op1;
604 // <type> op2;
605 // <type> op3;
606 // };
607 //
608 // layout (set=0, binding=0, std430) buffer InputBlock {
609 // Operands operands[<arraysize>];
610 // };
611 //
612 // layout (set=0, binding=1, std430) buffer OutputBlock {
613 // <type> result[<arraysize>];
614 // };
615 //
616 // Where <type> can be int8_t, uint32_t, float, etc. So in the input buffer the operands are "grouped" per operation and can
617 // have several components each and the output buffer contains an array of results, one per trio of input operands.
618
619 std::ostringstream shaderStr;
620 shaderStr
621 << "; SPIR-V\n"
622 << "; Version: 1.5\n"
623 << " OpCapability Shader\n"
624 << "${CAPABILITIES:opt}"
625 << " OpExtension \"SPV_KHR_storage_buffer_storage_class\"\n"
626 << " OpExtension \"SPV_AMD_shader_trinary_minmax\"\n"
627 << "${EXTENSIONS:opt}"
628 << " %std450 = OpExtInstImport \"GLSL.std.450\"\n"
629 << " %trinary = OpExtInstImport \"SPV_AMD_shader_trinary_minmax\"\n"
630 << " OpMemoryModel Logical GLSL450\n"
631 << " OpEntryPoint GLCompute %main \"main\" %gl_GlobalInvocationID %output_buffer %input_buffer\n"
632 << " OpExecutionMode %main LocalSize 1 1 1\n"
633 << " OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId\n"
634 << " OpDecorate %results_array_t ArrayStride ${OPERAND_SIZE}\n"
635 << " OpMemberDecorate %OutputBlock 0 Offset 0\n"
636 << " OpDecorate %OutputBlock Block\n"
637 << " OpDecorate %output_buffer DescriptorSet 0\n"
638 << " OpDecorate %output_buffer Binding 1\n"
639 << " OpMemberDecorate %Operands 0 Offset 0\n"
640 << " OpMemberDecorate %Operands 1 Offset ${OPERAND_SIZE}\n"
641 << " OpMemberDecorate %Operands 2 Offset ${OPERAND_SIZE_2TIMES}\n"
642 << " OpDecorate %_arr_Operands_arraysize ArrayStride ${OPERAND_SIZE_3TIMES}\n"
643 << " OpMemberDecorate %InputBlock 0 Offset 0\n"
644 << " OpDecorate %InputBlock Block\n"
645 << " OpDecorate %input_buffer DescriptorSet 0\n"
646 << " OpDecorate %input_buffer Binding 0\n"
647 << " OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize\n"
648 << " %void = OpTypeVoid\n"
649 << " %voidfunc = OpTypeFunction %void\n"
650 << " %int = OpTypeInt 32 1\n"
651 << " %uint = OpTypeInt 32 0\n"
652 << " %v3uint = OpTypeVector %uint 3\n"
653 << "${TYPES:opt}"
654 << " %int_0 = OpConstant %int 0\n"
655 << " %int_1 = OpConstant %int 1\n"
656 << " %int_2 = OpConstant %int 2\n"
657 << " %uint_1 = OpConstant %uint 1\n"
658 << " %uint_0 = OpConstant %uint 0\n"
659 << " %arraysize = OpConstant %uint ${ARRAY_SIZE}\n"
660 << " %_ptr_Function_uint = OpTypePointer Function %uint\n"
661 << " %_ptr_Input_v3uint = OpTypePointer Input %v3uint\n"
662 << " %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input\n"
663 << " %_ptr_Input_uint = OpTypePointer Input %uint\n"
664 << " %results_array_t = OpTypeArray ${OPERAND_TYPE} %arraysize\n"
665 << " %Operands = OpTypeStruct ${OPERAND_TYPE} ${OPERAND_TYPE} ${OPERAND_TYPE}\n"
666 << " %_arr_Operands_arraysize = OpTypeArray %Operands %arraysize\n"
667 << " %OutputBlock = OpTypeStruct %results_array_t\n"
668 << " %InputBlock = OpTypeStruct %_arr_Operands_arraysize\n"
669 << "%_ptr_Uniform_OutputBlock = OpTypePointer StorageBuffer %OutputBlock\n"
670 << " %_ptr_Uniform_InputBlock = OpTypePointer StorageBuffer %InputBlock\n"
671 << " %output_buffer = OpVariable %_ptr_Uniform_OutputBlock StorageBuffer\n"
672 << " %input_buffer = OpVariable %_ptr_Uniform_InputBlock StorageBuffer\n"
673 << " %optype_ptr = OpTypePointer StorageBuffer ${OPERAND_TYPE}\n"
674 << " %gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1\n"
675 << " %main = OpFunction %void None %voidfunc\n"
676 << " %mainlabel = OpLabel\n"
677 << " %gidxptr = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0\n"
678 << " %idx = OpLoad %uint %gidxptr\n"
679 << " %op1ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_0\n"
680 << " %op1 = OpLoad ${OPERAND_TYPE} %op1ptr\n"
681 << " %op2ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_1\n"
682 << " %op2 = OpLoad ${OPERAND_TYPE} %op2ptr\n"
683 << " %op3ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_2\n"
684 << " %op3 = OpLoad ${OPERAND_TYPE} %op3ptr\n"
685 << " %result = OpExtInst ${OPERAND_TYPE} %trinary ${OPERATION_NAME} %op1 %op2 %op3\n"
686 << " %resultptr = OpAccessChain %optype_ptr %output_buffer %int_0 %idx\n"
687 << " OpStore %resultptr %result\n"
688 << " OpReturn\n"
689 << " OpFunctionEnd\n"
690 ;
691
692 const tcu::StringTemplate shaderTemplate {shaderStr.str()};
693 const vk::SpirVAsmBuildOptions buildOptions { VK_MAKE_API_VERSION(0, 1, 2, 0), vk::SPIRV_VERSION_1_5};
694
695 programCollection.spirvAsmSources.add("comp", &buildOptions) << shaderTemplate.specialize(getSpirVReplacements());
696 }
697
TrinaryMinMaxInstance(Context & context,const TestParams & params)698 TrinaryMinMaxInstance::TrinaryMinMaxInstance (Context& context, const TestParams& params)
699 : vkt::TestInstance (context)
700 , m_params (params)
701 {}
702
iterate(void)703 tcu::TestStatus TrinaryMinMaxInstance::iterate (void)
704 {
705 const auto& vkd = m_context.getDeviceInterface();
706 const auto device = m_context.getDevice();
707 auto& allocator = m_context.getDefaultAllocator();
708 const auto queue = m_context.getUniversalQueue();
709 const auto queueIndex = m_context.getUniversalQueueFamilyIndex();
710
711 constexpr auto kNumOperations = TrinaryMinMaxCase::kArraySize;
712
713 const vk::VkDeviceSize kInputBufferSize = static_cast<vk::VkDeviceSize>(kNumOperations * 3u * m_params.operandSize());
714 const vk::VkDeviceSize kOutputBufferSize = static_cast<vk::VkDeviceSize>(kNumOperations * m_params.operandSize()); // Single output per operation.
715
716 // Create input, output and reference buffers.
717 auto inputBufferInfo = vk::makeBufferCreateInfo(kInputBufferSize, vk::VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
718 auto outputBufferInfo = vk::makeBufferCreateInfo(kOutputBufferSize, vk::VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
719
720 vk::BufferWithMemory inputBuffer {vkd, device, allocator, inputBufferInfo, vk::MemoryRequirement::HostVisible};
721 vk::BufferWithMemory outputBuffer {vkd, device, allocator, outputBufferInfo, vk::MemoryRequirement::HostVisible};
722 std::unique_ptr<char[]> referenceBuffer {new char[static_cast<size_t>(kOutputBufferSize)]};
723
724 // Fill buffers with initial contents.
725 auto& inputAlloc = inputBuffer.getAllocation();
726 auto& outputAlloc = outputBuffer.getAllocation();
727
728 void* inputBufferPtr = static_cast<deUint8*>(inputAlloc.getHostPtr()) + inputAlloc.getOffset();
729 void* outputBufferPtr = static_cast<deUint8*>(outputAlloc.getHostPtr()) + outputAlloc.getOffset();
730 void* referenceBufferPtr = referenceBuffer.get();
731
732 deMemset(inputBufferPtr, 0, static_cast<size_t>(kInputBufferSize));
733 deMemset(outputBufferPtr, 0, static_cast<size_t>(kOutputBufferSize));
734 deMemset(referenceBufferPtr, 0, static_cast<size_t>(kOutputBufferSize));
735
736 // Generate input buffer and calculate reference results.
737 OperationManager opMan{m_params};
738 opMan.genInputBuffer(inputBufferPtr, kNumOperations);
739 opMan.calculateResult(referenceBufferPtr, inputBufferPtr, kNumOperations);
740
741 // Flush buffer memory before starting.
742 vk::flushAlloc(vkd, device, inputAlloc);
743 vk::flushAlloc(vkd, device, outputAlloc);
744
745 // Descriptor set layout.
746 vk::DescriptorSetLayoutBuilder layoutBuilder;
747 layoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, vk::VK_SHADER_STAGE_COMPUTE_BIT);
748 layoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, vk::VK_SHADER_STAGE_COMPUTE_BIT);
749 auto descriptorSetLayout = layoutBuilder.build(vkd, device);
750
751 // Descriptor pool.
752 vk::DescriptorPoolBuilder poolBuilder;
753 poolBuilder.addType(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 2u);
754 auto descriptorPool = poolBuilder.build(vkd, device, vk::VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
755
756 // Descriptor set.
757 const auto descriptorSet = vk::makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
758
759 // Update descriptor set using the buffers.
760 const auto inputBufferDescriptorInfo = vk::makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
761 const auto outputBufferDescriptorInfo = vk::makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
762
763 vk::DescriptorSetUpdateBuilder updateBuilder;
764 updateBuilder.writeSingle(descriptorSet.get(), vk::DescriptorSetUpdateBuilder::Location::binding(0u), vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescriptorInfo);
765 updateBuilder.writeSingle(descriptorSet.get(), vk::DescriptorSetUpdateBuilder::Location::binding(1u), vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescriptorInfo);
766 updateBuilder.update(vkd, device);
767
768 // Create compute pipeline.
769 auto shaderModule = vk::createShaderModule(vkd, device, m_context.getBinaryCollection().get("comp"), 0u);
770 auto pipelineLayout = vk::makePipelineLayout(vkd, device, descriptorSetLayout.get());
771
772 const vk::VkComputePipelineCreateInfo pipelineCreateInfo =
773 {
774 vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
775 nullptr,
776 0u, // flags
777 { // compute shader
778 vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
779 nullptr, // const void* pNext;
780 0u, // VkPipelineShaderStageCreateFlags flags;
781 vk::VK_SHADER_STAGE_COMPUTE_BIT, // VkShaderStageFlagBits stage;
782 shaderModule.get(), // VkShaderModule module;
783 "main", // const char* pName;
784 nullptr, // const VkSpecializationInfo* pSpecializationInfo;
785 },
786 pipelineLayout.get(), // layout
787 DE_NULL, // basePipelineHandle
788 0, // basePipelineIndex
789 };
790 auto pipeline = vk::createComputePipeline(vkd, device, DE_NULL, &pipelineCreateInfo);
791
792 // Synchronization barriers.
793 auto inputBufferHostToDevBarrier = vk::makeBufferMemoryBarrier(vk::VK_ACCESS_HOST_WRITE_BIT, vk::VK_ACCESS_SHADER_READ_BIT, inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
794 auto outputBufferHostToDevBarrier = vk::makeBufferMemoryBarrier(vk::VK_ACCESS_HOST_WRITE_BIT, vk::VK_ACCESS_SHADER_WRITE_BIT, outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
795 auto outputBufferDevToHostBarrier = vk::makeBufferMemoryBarrier(vk::VK_ACCESS_SHADER_WRITE_BIT, vk::VK_ACCESS_HOST_READ_BIT, outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
796
797 // Command buffer.
798 auto cmdPool = vk::makeCommandPool(vkd, device, queueIndex);
799 auto cmdBufferPtr = vk::allocateCommandBuffer(vkd, device, cmdPool.get(), vk::VK_COMMAND_BUFFER_LEVEL_PRIMARY);
800 auto cmdBuffer = cmdBufferPtr.get();
801
802 // Record and submit commands.
803 vk::beginCommandBuffer(vkd, cmdBuffer);
804 vkd.cmdBindPipeline(cmdBuffer, vk::VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get());
805 vkd.cmdBindDescriptorSets(cmdBuffer, vk::VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout.get(), 0, 1u, &descriptorSet.get(), 0u, nullptr);
806 vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_HOST_BIT, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0u, 0u, nullptr, 1u, &inputBufferHostToDevBarrier, 0u, nullptr);
807 vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_HOST_BIT, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0u, 0u, nullptr, 1u, &outputBufferHostToDevBarrier, 0u, nullptr);
808 vkd.cmdDispatch(cmdBuffer, kNumOperations, 1u, 1u);
809 vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, vk::VK_PIPELINE_STAGE_HOST_BIT, 0u, 0u, nullptr, 1u, &outputBufferDevToHostBarrier, 0u, nullptr);
810 vk::endCommandBuffer(vkd, cmdBuffer);
811 vk::submitCommandsAndWait(vkd, device, queue, cmdBuffer);
812
813 // Verify output buffer contents.
814 vk::invalidateAlloc(vkd, device, outputAlloc);
815
816 const auto error = opMan.compareResults(referenceBufferPtr, outputBufferPtr, kNumOperations);
817
818 if (!error)
819 return tcu::TestStatus::pass("Pass");
820
821 std::ostringstream msg;
822 msg << "Value mismatch at operation " << error.get().first << " in component " << error.get().second;
823 return tcu::TestStatus::fail(msg.str());
824 }
825
826 } // anonymous
827
createTrinaryMinMaxGroup(tcu::TestContext & testCtx)828 tcu::TestCaseGroup* createTrinaryMinMaxGroup (tcu::TestContext& testCtx)
829 {
830 deUint32 seed = 0xFEE768FCu;
831 de::MovePtr<tcu::TestCaseGroup> group{new tcu::TestCaseGroup{testCtx, "amd_trinary_minmax", "Tests for VK_AMD_trinary_minmax operations"}};
832
833 static const std::vector<std::pair<OperationType, std::string>> operationTypes =
834 {
835 { OperationType::MIN, "min3" },
836 { OperationType::MAX, "max3" },
837 { OperationType::MID, "mid3" },
838 };
839
840 static const std::vector<std::pair<BaseType, std::string>> baseTypes =
841 {
842 { BaseType::TYPE_INT, "i" },
843 { BaseType::TYPE_UINT, "u" },
844 { BaseType::TYPE_FLOAT, "f" },
845 };
846
847 static const std::vector<std::pair<TypeSize, std::string>> typeSizes =
848 {
849 { TypeSize::SIZE_8BIT, "8" },
850 { TypeSize::SIZE_16BIT, "16" },
851 { TypeSize::SIZE_32BIT, "32" },
852 { TypeSize::SIZE_64BIT, "64" },
853 };
854
855 static const std::vector<std::pair<AggregationType, std::string>> aggregationTypes =
856 {
857 { AggregationType::SCALAR, "scalar" },
858 { AggregationType::VEC2, "vec2" },
859 { AggregationType::VEC3, "vec3" },
860 { AggregationType::VEC4, "vec4" },
861 };
862
863 for (const auto& opType : operationTypes)
864 {
865 const std::string opDesc = "Tests for " + opType.second + " operation";
866 de::MovePtr<tcu::TestCaseGroup> opGroup{new tcu::TestCaseGroup{testCtx, opType.second.c_str(), opDesc.c_str()}};
867
868 for (const auto& baseType : baseTypes)
869 for (const auto& typeSize : typeSizes)
870 {
871 // There are no 8-bit floats.
872 if (baseType.first == BaseType::TYPE_FLOAT && typeSize.first == TypeSize::SIZE_8BIT)
873 continue;
874
875 const std::string typeName = baseType.second + typeSize.second;
876 const std::string typeDesc = "Tests using " + typeName + " data";
877
878 de::MovePtr<tcu::TestCaseGroup> typeGroup{new tcu::TestCaseGroup{testCtx, typeName.c_str(), typeDesc.c_str()}};
879
880 for (const auto& aggType : aggregationTypes)
881 {
882 const TestParams params =
883 {
884 opType.first, // OperationType operation;
885 baseType.first, // BaseType baseType;
886 typeSize.first, // TypeSize typeSize;
887 aggType.first, // AggregationType aggregation;
888 seed++, // deUint32 randomSeed;
889 };
890 typeGroup->addChild(new TrinaryMinMaxCase{testCtx, aggType.second, "", params});
891 }
892
893 opGroup->addChild(typeGroup.release());
894 }
895
896 group->addChild(opGroup.release());
897 }
898
899 return group.release();
900 }
901
902 } // SpirVAssembly
903 } // vkt
904