/*------------------------------------------------------------------------
 * Vulkan Conformance Tests
 * ------------------------
 *
 * Copyright (c) 2019 The Khronos Group Inc.
 * Copyright (c) 2018-2019 NVIDIA Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief Vulkan Cooperative Matrix tests
 *//*--------------------------------------------------------------------*/

#include "vktComputeCooperativeMatrixTests.hpp"

#include "vkBufferWithMemory.hpp"
#include "vkImageWithMemory.hpp"
#include "vkQueryUtil.hpp"
#include "vkBuilderUtil.hpp"
#include "vkCmdUtil.hpp"
#include "vkTypeUtil.hpp"
#include "vkObjUtil.hpp"

#include "vktTestGroupUtil.hpp"
#include "vktTestCase.hpp"

#include "deDefs.h"
#include "deFloat16.h"
#include "deMath.h"
#include "deRandom.h"
#include "deSharedPtr.hpp"
#include "deString.h"

#include "tcuTestCase.hpp"
#include "tcuTestLog.hpp"

#include <string>
#include <sstream>
#include <set>
#include <algorithm>

namespace vkt
{
namespace compute
{
namespace
{
using namespace vk;
using namespace std;
typedef enum
{
    TT_LENGTH = 0,
    TT_CONSTANT,
    TT_CONVERT,
    TT_COMPOSITE,
    TT_COMPOSITE_RVALUE,
    TT_ADD,
    TT_SUB,
    TT_DIV,
    TT_NEGATE,
    TT_MATRIXTIMESSCALAR,
    TT_FUNC,
    TT_MATRIXMULADD,
    TT_COMPOSITE_ARRAY,
    TT_MATRIXMULADD_ARRAY,
} TestType;

typedef enum
{
    SC_BUFFER = 0,
    SC_WORKGROUP,
    SC_WORKGROUP_VARIABLE_POINTERS,
    SC_BUFFER_VARIABLE_POINTERS,
    SC_PHYSICAL_STORAGE_BUFFER,
} StorageClass;

const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;

struct CaseDef
{
    TestType testType;
    deUint32 subgroupsPerWorkgroupX;
    deUint32 subgroupsPerWorkgroupY;
    deUint32 workgroupsX;
    deUint32 workgroupsY;
    VkComponentTypeNV inputType;
    VkComponentTypeNV outputType;
    bool colMajor;
    StorageClass storageClass;
};

class CooperativeMatrixTestInstance : public TestInstance
{
public:
    CooperativeMatrixTestInstance (Context& context, const CaseDef& data);
    ~CooperativeMatrixTestInstance (void);
    tcu::TestStatus iterate (void);
private:
    CaseDef m_data;
};

CooperativeMatrixTestInstance::CooperativeMatrixTestInstance (Context& context, const CaseDef& data)
    : vkt::TestInstance (context)
    , m_data (data)
{
}

CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance (void)
{
}

class CooperativeMatrixTestCase : public TestCase
{
public:
    CooperativeMatrixTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
    ~CooperativeMatrixTestCase (void);
    virtual void initPrograms (SourceCollections& programCollection) const;
    virtual TestInstance* createInstance (Context& context) const;
    virtual void checkSupport (Context& context) const;

private:
    CaseDef m_data;
};

CooperativeMatrixTestCase::CooperativeMatrixTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
    : vkt::TestCase (context, name, desc)
    , m_data (data)
{
}

CooperativeMatrixTestCase::~CooperativeMatrixTestCase (void)
{
}

void CooperativeMatrixTestCase::checkSupport (Context& context) const
{
    if (!context.contextSupports(vk::ApiVersion(1, 1, 0)))
    {
        TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
    }

    if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
    {
        TCU_THROW(NotSupportedError, "cooperativeMatrix not supported");
    }

    if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
    {
        TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
    }

    if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
        !context.getVariablePointersFeatures().variablePointers)
    {
        TCU_THROW(NotSupportedError, "variable pointers not supported");
    }

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
    {
        TCU_THROW(NotSupportedError, "buffer device address not supported");
    }

    if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
        (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_NV || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_NV))
    {
        TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
    }
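    // Standard Vulkan enumeration idiom: the first call with a null array
    // pointer returns the number of supported VkCooperativeMatrixPropertiesNV
    // entries, the second call fills the array.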
    deUint32 propertyCount = 0;
    VkCooperativeMatrixPropertiesNV *pProperties;
    context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(context.getPhysicalDevice(), &propertyCount, DE_NULL);
    if (propertyCount == 0)
        TCU_THROW(NotSupportedError, "cooperative matrices not supported");

    bool supported[2] = { false, false };
    pProperties = new VkCooperativeMatrixPropertiesNV[propertyCount];

    for (deUint32 i = 0; i < propertyCount; ++i)
    {
        VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
        p->sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_NV;
        p->pNext = DE_NULL;
    }

    context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(context.getPhysicalDevice(), &propertyCount, pProperties);

    for (deUint32 i = 0; i < propertyCount; ++i)
    {
        VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
        if (m_data.testType == TT_MATRIXMULADD ||
            m_data.testType == TT_MATRIXMULADD_ARRAY)
        {
            if (p->AType == m_data.inputType &&
                p->BType == m_data.inputType &&
                p->CType == m_data.outputType &&
                p->DType == m_data.outputType &&
                p->scope == VK_SCOPE_SUBGROUP_NV)
            {
                supported[0] = supported[1] = true;
            }
        }
        else
        {
            VkComponentTypeNV types[2] = { m_data.inputType, m_data.outputType };

            for (deUint32 j = 0; j < 2; ++j)
            {
                if (p->scope == VK_SCOPE_SUBGROUP_NV && (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] || p->DType == types[j]))
                {
                    supported[j] = true;
                }
            }
        }
    }

    delete [] pProperties;

    if (!supported[0] || !supported[1])
        TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");
}

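// GLSL scalar type name, cooperative matrix type name, and bit width for each
// component type. The table is indexed directly by VkComponentTypeNV, whose
// enum values follow this declaration order (FLOAT16 = 0 ... UINT64 = 10).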
struct {
    const char *typeName;
    const char *coopmatTypeName;
    deUint32 bits;
} componentTypeInfo[] =
{
    { "float16_t", "fcoopmatNV", 16 },
    { "float32_t", "fcoopmatNV", 32 },
    { "float64_t", "fcoopmatNV", 64 },
    { "int8_t",    "icoopmatNV", 8  },
    { "int16_t",   "icoopmatNV", 16 },
    { "int32_t",   "icoopmatNV", 32 },
    { "int64_t",   "icoopmatNV", 64 },
    { "uint8_t",   "ucoopmatNV", 8  },
    { "uint16_t",  "ucoopmatNV", 16 },
    { "uint32_t",  "ucoopmatNV", 32 },
    { "uint64_t",  "ucoopmatNV", 64 },
};

static bool isFloatType(VkComponentTypeNV t)
{
    switch (t)
    {
    default:
        return false;
    case VK_COMPONENT_TYPE_FLOAT16_NV:
    case VK_COMPONENT_TYPE_FLOAT32_NV:
    case VK_COMPONENT_TYPE_FLOAT64_NV:
        return true;
    }
}

static bool isSIntType(VkComponentTypeNV t)
{
    switch (t)
    {
    default:
        return false;
    case VK_COMPONENT_TYPE_SINT8_NV:
    case VK_COMPONENT_TYPE_SINT16_NV:
    case VK_COMPONENT_TYPE_SINT32_NV:
    case VK_COMPONENT_TYPE_SINT64_NV:
        return true;
    }
}

void CooperativeMatrixTestCase::initPrograms (SourceCollections& programCollection) const
{
    std::stringstream css;
    css << "#version 450 core\n";
    css << "#pragma use_vulkan_memory_model\n";
    css <<
        "#extension GL_KHR_shader_subgroup_basic : enable\n"
        "#extension GL_KHR_memory_scope_semantics : enable\n"
        "#extension GL_NV_cooperative_matrix : enable\n"
        "#extension GL_NV_integer_cooperative_matrix : enable\n"
        "#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable\n"
        "#extension GL_EXT_shader_explicit_arithmetic_types_float32 : enable\n"
        "#extension GL_EXT_shader_explicit_arithmetic_types_int8 : enable\n"
        "#extension GL_EXT_shader_explicit_arithmetic_types_int32 : enable\n"
        "#extension GL_EXT_buffer_reference : enable\n"
        "// strides overridden by spec constants\n"
        "layout(constant_id = 2) const int AStride = 1;\n"
        "layout(constant_id = 3) const int BStride = 1;\n"
        "layout(constant_id = 4) const int CStride = 1;\n"
        "layout(constant_id = 5) const int OStride = 1;\n"
        "layout(constant_id = 6) const int M = 1;\n"
        "layout(constant_id = 7) const int N = 1;\n"
        "layout(constant_id = 8) const int K = 1;\n"
        "layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";

    if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
        css << "#pragma use_variable_pointers\n";

    struct
    {
        string rows, cols;
    } dims[4];

    if (m_data.testType == TT_MATRIXMULADD ||
        m_data.testType == TT_MATRIXMULADD_ARRAY)
    {
        dims[0].rows = "M";
        dims[0].cols = "K";
        dims[1].rows = "K";
        dims[1].cols = "N";
        dims[2].rows = "M";
        dims[2].cols = "N";
        dims[3].rows = "M";
        dims[3].cols = "N";
    }
    else
    {
        dims[0].rows = "M";
        dims[0].cols = "N";
        dims[1].rows = "M";
        dims[1].cols = "N";
        dims[2].rows = "M";
        dims[2].cols = "N";
        dims[3].rows = "M";
        dims[3].cols = "N";
    }

    const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
    const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
    const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
    const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;

    css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
    css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", " << m_data.subgroupsPerWorkgroupY << ");\n";

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
    {
        css << "layout(buffer_reference) buffer InputA { " << typeStrA << " x[]; };\n";
        css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
        css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
        css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
        css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; } params;\n";
    }
    else
    {
        css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrA << " x[]; } inputA;\n";
        css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
        css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
        css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
    }

    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
        css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
        css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
        css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
    }

    std::stringstream matAType, matBType, matCType, outputMatType;

    matAType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ">";
    matBType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ">";
    matCType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ">";
    outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ">";

    css << matAType.str() << " matA;\n";
    css << matBType.str() << " matB;\n";
    css << matCType.str() << " matC;\n";
    css << outputMatType.str() << " matO;\n";

    if (m_data.testType == TT_CONSTANT)
        css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";

    if (m_data.testType == TT_FUNC)
        css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";

    css <<
        "void main()\n"
        "{\n"
        // matrixID is the x,y index of the matrix owned by this subgroup.
        " uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
        " uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
    {
        css << " InputA inputA = params.inputA;\n";
        css << " InputB inputB = params.inputB;\n";
        css << " InputC inputC = params.inputC;\n";
        css << " Output outputO = params.outputO;\n";
    }
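    // Each buffer packs one matrix per subgroup into a 2D grid of
    // (subgroupsPerWorkgroupX*workgroupsX) by (subgroupsPerWorkgroupY*workgroupsY)
    // tiles, so the stride of the major dimension spans a whole row of the grid.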
    string strides[4];
    for (deUint32 i = 0; i < 4; ++i)
    {
        strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") + de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
    }

    // element<i> is the starting element in buffer memory.
    // elementS<i> is the starting element in shared memory.
    css << " uint element0 = " << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x;\n"
           " uint element1 = " << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x;\n"
           " uint element2 = " << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x;\n"
           " uint element3 = " << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x;\n"
           " uint elementS0, elementS1, elementS2, elementS3;\n";

    // For shared memory tests, copy the matrix from buffer memory into
    // workgroup memory. For simplicity, do it all on a single thread.
    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        const char *name[] =
        {
            "sharedA",
            "sharedB",
            "sharedC",
        };
        const char *inputName[] =
        {
            "inputA",
            "inputB",
            "inputC",
        };
        for (deUint32 m = 0; m < 4; ++m)
        {
            string sharedStride = strides[m] + " / workgroupsX";
            css << " elementS" << m << " = " << sharedStride << " * " << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + " << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x;\n";
        }
        css << " if (subgroupElect()) {\n";
        // copy all three input buffers.
        for (deUint32 m = 0; m < 3; ++m)
        {
            string sharedStride = strides[m] + " / workgroupsX";
            css << " for (int i = 0; i < " << dims[m].rows << "; ++i) {\n"
                   " for (int j = 0; j < " << dims[m].cols << "; ++j) {\n"
                   " int localElementInput = " << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
                   " int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
                   " " << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m << " + localElementInput];\n"
                   " }\n"
                   " }\n";
            strides[m] = sharedStride;
        }
        css << " }\n";
        css << " controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
    }

    const char *colMajor = (m_data.colMajor ? "true" : "false");

    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        css << " coopMatLoadNV(matA, sharedA, elementS0, " << strides[0] << ", " << colMajor << ");\n"
               " coopMatLoadNV(matB, sharedB, elementS1, " << strides[1] << ", " << colMajor << ");\n"
               " coopMatLoadNV(matC, sharedC, elementS2, " << strides[2] << ", " << colMajor << ");\n";
    }
    else
    {
        css << " coopMatLoadNV(matA, inputA.x, element0, " << strides[0] << ", " << colMajor << ");\n"
               " coopMatLoadNV(matB, inputB.x, element1, " << strides[1] << ", " << colMajor << ");\n"
               " coopMatLoadNV(matC, inputC.x, element2, " << strides[2] << ", " << colMajor << ");\n";
    }

    if (m_data.testType == TT_COMPOSITE_ARRAY ||
        m_data.testType == TT_MATRIXMULADD_ARRAY)
    {
        css << " " << matAType.str() << " matAArr[2];\n matAArr[1] = matA; matAArr[0] = " << matAType.str() << "(0.0);\n"
               " " << matBType.str() << " matBArr[2];\n matBArr[1] = matB; matBArr[0] = " << matBType.str() << "(0.0);\n"
               " " << matCType.str() << " matCArr[2];\n matCArr[1] = matC; matCArr[0] = " << matCType.str() << "(0.0);\n"
               " " << outputMatType.str() << " matOArr[2];\n";
    }

    switch (m_data.testType)
    {
    default:
        DE_ASSERT(0);
        // fall through
    case TT_LENGTH:
        css << " matO = " << outputMatType.str() << "(matO.length());\n";
        break;
    case TT_CONSTANT:
        css << " matO = matConst;\n";
        break;
    case TT_CONVERT:
        css << " matO = " << outputMatType.str() << "(matA);\n";
        break;
    case TT_COMPOSITE:
    case TT_COMPOSITE_RVALUE:
        css << " for (int i = 0; i < matA.length(); ++i) {\n"
               " matO[i] = matA[i] + matB[i];\n"
               " }\n";
        if (m_data.testType == TT_COMPOSITE_RVALUE)
        {
            css << " " << matAType.str() << " t = matA;\n"
                   " matO[0] = (t += matB)[0];\n"
                   " if (matA.length() > 0) {\n"
                   " t = matA;\n"
                   " matO[1] = (t += matB)[1];\n"
                   " }\n";
        }
        break;
    case TT_COMPOSITE_ARRAY:
        css << " for (int i = 0; i < matA.length(); ++i) {\n"
               " matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
               " }\n";
        break;
    case TT_ADD:
        css << " matO = matA + matB;\n";
        break;
    case TT_SUB:
        css << " matO = matA - matB;\n";
        break;
    case TT_DIV:
        css << " matO = matA / matB;\n";
        break;
    case TT_NEGATE:
        css << " matO = -matA;\n";
        break;
    case TT_FUNC:
        css << " matO = f(matA);\n";
        break;
    case TT_MATRIXTIMESSCALAR:
        css << " matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
        break;
    case TT_MATRIXMULADD:
        css << " matO = coopMatMulAddNV(matA, matB, matC);\n";
        break;
    case TT_MATRIXMULADD_ARRAY:
        css << " matOArr[1] = coopMatMulAddNV(matAArr[1], matBArr[1], matCArr[1]);\n";
        break;
    }

    if (m_data.testType == TT_COMPOSITE_ARRAY ||
        m_data.testType == TT_MATRIXMULADD_ARRAY)
    {
        css << " matOArr[0] = " << outputMatType.str() << "(0.0);\n";
        css << " matO = matOArr[1];\n";
    }

    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        string sharedStride = strides[3] + " / workgroupsX";
        css << " coopMatStoreNV(matO, sharedO, elementS3, " << sharedStride << ", " << colMajor << ");\n";
        css << " controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
        css << " if (subgroupElect()) {\n";
        css << " for (int i = 0; i < " << dims[3].rows << "; ++i) {\n"
               " for (int j = 0; j < " << dims[3].cols << "; ++j) {\n"
               " int localElementInput = " << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
               " int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
               " outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
               " }\n"
               " }\n";
        css << " }\n";
    }
    else
    {
        css << " coopMatStoreNV(matO, outputO.x, element3, " << strides[3] << ", " << colMajor << ");\n";
    }

    css <<
        "}\n";

    const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);

    programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
}

TestInstance* CooperativeMatrixTestCase::createInstance (Context& context) const
{
    return new CooperativeMatrixTestInstance(context, m_data);
}

static void setDataFloat(void *base, VkComponentTypeNV dt, deUint32 i, float value)
{
    if (dt == VK_COMPONENT_TYPE_FLOAT32_NV)
    {
        ((float *)base)[i] = value;
    }
    else
    {
        DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_NV);
        ((deFloat16 *)base)[i] = deFloat32To16(value);
    }
}

static float getDataFloat(void *base, VkComponentTypeNV dt, deUint32 i)
{
    if (dt == VK_COMPONENT_TYPE_FLOAT32_NV)
    {
        return ((float *)base)[i];
    }
    else
    {
        DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_NV);
        return deFloat16To32(((deFloat16 *)base)[i]);
    }
}

static void setDataInt(void *base, VkComponentTypeNV dt, deUint32 i, deUint32 value)
{
    DE_ASSERT(componentTypeInfo[dt].bits <= 32);
    switch (dt) {
    default: DE_ASSERT(0); // fallthrough
    case VK_COMPONENT_TYPE_UINT8_NV:  ((deUint8 *)base)[i]  = (deUint8)value;  break;
    case VK_COMPONENT_TYPE_UINT16_NV: ((deUint16 *)base)[i] = (deUint16)value; break;
    case VK_COMPONENT_TYPE_UINT32_NV: ((deUint32 *)base)[i] = (deUint32)value; break;
    case VK_COMPONENT_TYPE_SINT8_NV:  ((deInt8 *)base)[i]   = (deInt8)value;   break;
    case VK_COMPONENT_TYPE_SINT16_NV: ((deInt16 *)base)[i]  = (deInt16)value;  break;
    case VK_COMPONENT_TYPE_SINT32_NV: ((deInt32 *)base)[i]  = (deInt32)value;  break;
    }
}

static deUint32 getDataInt(void *base, VkComponentTypeNV dt, deUint32 i)
{
    DE_ASSERT(componentTypeInfo[dt].bits <= 32);
    switch (dt) {
    default: DE_ASSERT(0); // fallthrough
    case VK_COMPONENT_TYPE_UINT8_NV:  return ((deUint8 *)base)[i];
    case VK_COMPONENT_TYPE_UINT16_NV: return ((deUint16 *)base)[i];
    case VK_COMPONENT_TYPE_UINT32_NV: return ((deUint32 *)base)[i];
    case VK_COMPONENT_TYPE_SINT8_NV:  return ((deInt8 *)base)[i];
    case VK_COMPONENT_TYPE_SINT16_NV: return ((deInt16 *)base)[i];
    case VK_COMPONENT_TYPE_SINT32_NV: return ((deInt32 *)base)[i];
    }
}

tcu::TestStatus CooperativeMatrixTestInstance::iterate (void)
{
    const DeviceInterface& vk = m_context.getDeviceInterface();
    const VkDevice device = m_context.getDevice();
    Allocator& allocator = m_context.getDefaultAllocator();
    MemoryRequirement memoryDeviceAddress = m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
        m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ? MemoryRequirement::DeviceAddress : MemoryRequirement::Any;
    qpTestResult finalres = QP_TEST_RESULT_PASS;
    tcu::TestLog& log = m_context.getTestContext().getLog();

    deRandom rnd;
    deRandom_init(&rnd, 1234);

    vk::VkPhysicalDeviceSubgroupProperties subgroupProperties;
    deMemset(&subgroupProperties, 0, sizeof(subgroupProperties));
    subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;

    vk::VkPhysicalDeviceProperties2 properties2;
    deMemset(&properties2, 0, sizeof(properties2));
    properties2.sType = vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    properties2.pNext = &subgroupProperties;

    m_context.getInstanceInterface().getPhysicalDeviceProperties2(m_context.getPhysicalDevice(), &properties2);

    deUint32 propertyCount = 0;
    VkCooperativeMatrixPropertiesNV *pProperties;
    m_context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(m_context.getPhysicalDevice(), &propertyCount, DE_NULL);
    // Shouldn't have made it through checkSupport without any properties
    DE_ASSERT(propertyCount != 0);

    pProperties = new VkCooperativeMatrixPropertiesNV[propertyCount];

    for (deUint32 i = 0; i < propertyCount; ++i)
    {
        VkCooperativeMatrixPropertiesNV *p = &pProperties[i];
        p->sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_NV;
        p->pNext = DE_NULL;
    }

    m_context.getInstanceInterface().getPhysicalDeviceCooperativeMatrixPropertiesNV(m_context.getPhysicalDevice(), &propertyCount, pProperties);

    struct TestTuple
    {
        TestTuple() {}
        TestTuple(deUint32 m, deUint32 n, deUint32 k) : M(m), N(n), K(k) {}

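        // Lexicographic (M, N, K) ordering, so tuples can live in a std::set
        // and be intersected below.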
        bool operator<(const TestTuple &other) const
        {
            return M < other.M ||
                   (M == other.M && N < other.N) ||
                   (M == other.M && N == other.N && K < other.K);
        }

        deUint32 M, N, K;
    };

    vector<TestTuple> testSizes;

    if (m_data.testType == TT_MATRIXMULADD ||
        m_data.testType == TT_MATRIXMULADD_ARRAY)
    {
        for (deUint32 i = 0; i < propertyCount; ++i)
        {
            VkCooperativeMatrixPropertiesNV *p = &pProperties[i];

            if (p->AType == m_data.inputType &&
                p->BType == m_data.inputType &&
                p->CType == m_data.outputType &&
                p->DType == m_data.outputType &&
                p->scope == VK_SCOPE_SUBGROUP_NV)
            {
                testSizes.push_back(TestTuple(p->MSize, p->NSize, p->KSize));
            }
        }
    }
    else
    {
        set<TestTuple> typeSizes[2];
        VkComponentTypeNV types[2] = { m_data.inputType, m_data.outputType };

        for (deUint32 i = 0; i < propertyCount; ++i)
        {
            VkCooperativeMatrixPropertiesNV *p = &pProperties[i];

            if (p->scope != VK_SCOPE_SUBGROUP_NV)
                continue;

            for (deUint32 j = 0; j < 2; ++j)
            {
                // For these tests, m_data.M/N are always the matrix size. Check if they match
                // any input or output in the list.
                if (p->AType == types[j])
                    typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0));
                if (p->BType == types[j])
                    typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0));
                if (p->CType == types[j] ||
                    p->DType == types[j])
                    typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0));
            }
        }
        // Test those sizes that are supported for both the input and output type.
        std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(),
                              typeSizes[1].begin(), typeSizes[1].end(),
                              std::back_inserter(testSizes));
    }

    delete [] pProperties;

    for (unsigned int s = 0; s < testSizes.size(); ++s)
    {
        // When testing a multiply, MxNxK is the type of matrix multiply.
        // Otherwise, MxN is the size of the input/output matrices
        deUint32 M, N, K;
        M = testSizes[s].M;
        N = testSizes[s].N;
        K = testSizes[s].K;

        log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;

        struct
        {
            deUint32 rows, cols;
        } dims[4];

        if (m_data.testType == TT_MATRIXMULADD ||
            m_data.testType == TT_MATRIXMULADD_ARRAY)
        {
            dims[0].rows = M;
            dims[0].cols = K;
            dims[1].rows = K;
            dims[1].cols = N;
            dims[2].rows = M;
            dims[2].cols = N;
            dims[3].rows = M;
            dims[3].cols = N;
        }
        else
        {
            dims[0].rows = M;
            dims[0].cols = N;
            dims[1].rows = M;
            dims[1].cols = N;
            dims[2].rows = M;
            dims[2].cols = N;
            dims[3].rows = M;
            dims[3].cols = N;
        }

        VkComponentTypeNV dataTypes[4];
        size_t elementSize[4];
        VkDeviceSize bufferSizes[5];
        de::MovePtr<BufferWithMemory> buffers[5];
        vk::VkDescriptorBufferInfo bufferDescriptors[5];
        deUint32 strides[4]; // in elements
        deUint32 totalElements[4];

        for (deUint32 i = 0; i < 5; ++i)
        {
            if (i < 4)
            {
                // A/B use input type, C/D use output type
                dataTypes[i] = (i < 2) ? m_data.inputType : m_data.outputType;
                elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;

                strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.subgroupsPerWorkgroupX * m_data.workgroupsX;
                totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) * m_data.subgroupsPerWorkgroupY * m_data.workgroupsY;

                bufferSizes[i] = totalElements[i] * elementSize[i];
            }
            else
            {
                bufferSizes[4] = sizeof(VkDeviceAddress)*4;
            }

            try
            {
                buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
                    vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT),
                    MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent | memoryDeviceAddress));
            }
            catch (const tcu::NotSupportedError&)
            {
                buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
                    vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT),
                    MemoryRequirement::HostVisible | memoryDeviceAddress));
            }

            bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
        }

        void *ptrs[5];
        for (deUint32 i = 0; i < 5; ++i)
        {
            ptrs[i] = buffers[i]->getAllocation().getHostPtr();
        }

        vk::DescriptorSetLayoutBuilder layoutBuilder;

        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);

        vk::Unique<vk::VkDescriptorSetLayout> descriptorSetLayout(layoutBuilder.build(vk, device));

        vk::Unique<vk::VkDescriptorPool> descriptorPool(vk::DescriptorPoolBuilder()
            .addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
            .build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
        vk::Unique<vk::VkDescriptorSet> descriptorSet (makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));

        vk::DescriptorSetUpdateBuilder setUpdateBuilder;
        if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
        {
            const bool useKHR = m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address");

            VkBufferDeviceAddressInfo info =
            {
                VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType sType;
                DE_NULL,                                      // const void* pNext;
                0,                                            // VkBuffer buffer
            };
            VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
            for (deUint32 i = 0; i < 4; ++i)
            {
                info.buffer = **buffers[i];
                VkDeviceAddress addr;
                if (useKHR)
                    addr = vk.getBufferDeviceAddress(device, &info);
                else
                    addr = vk.getBufferDeviceAddressEXT(device, &info);
                addrsInMemory[i] = addr;
            }
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
                VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
        }
        else
        {
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
                VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
                VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
                VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
                VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
        }

        setUpdateBuilder.update(vk, device);

        const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo =
        {
            VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
            DE_NULL,                                       // pNext
            (VkPipelineLayoutCreateFlags)0,
            1,                                             // setLayoutCount
            &descriptorSetLayout.get(),                    // pSetLayouts
            0u,                                            // pushConstantRangeCount
            DE_NULL,                                       // pPushConstantRanges
        };

        Move<VkPipelineLayout> pipelineLayout = createPipelineLayout(vk, device, &pipelineLayoutCreateInfo, NULL);

        Move<VkPipeline> pipeline;

        VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;

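        // Specialization data matching the constant_id declarations in the
        // shader: 0/1 set the local size (x = subgroupSize * subgroupsPerWorkgroupX),
        // 2-5 the buffer strides, and 6-8 the matrix dimensions.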
        const deUint32 specData[9] =
        {
            subgroupProperties.subgroupSize * m_data.subgroupsPerWorkgroupX,
            m_data.subgroupsPerWorkgroupY,
            strides[0],
            strides[1],
            strides[2],
            strides[3],
            M,
            N,
            K,
        };

        const vk::VkSpecializationMapEntry entries[9] =
        {
            {0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
            {1, (deUint32)(sizeof(deUint32) * 1), sizeof(deUint32)},
            {2, (deUint32)(sizeof(deUint32) * 2), sizeof(deUint32)},
            {3, (deUint32)(sizeof(deUint32) * 3), sizeof(deUint32)},
            {4, (deUint32)(sizeof(deUint32) * 4), sizeof(deUint32)},
            {5, (deUint32)(sizeof(deUint32) * 5), sizeof(deUint32)},
            {6, (deUint32)(sizeof(deUint32) * 6), sizeof(deUint32)},
            {7, (deUint32)(sizeof(deUint32) * 7), sizeof(deUint32)},
            {8, (deUint32)(sizeof(deUint32) * 8), sizeof(deUint32)},
        };

        const vk::VkSpecializationInfo specInfo =
        {
            9,                // mapEntryCount
            entries,          // pMapEntries
            sizeof(specData), // dataSize
            specData          // pData
        };

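        // Fill the inputs with small random values. For multiplies the range is
        // kept narrow, presumably so that products and sums stay exactly
        // representable (even in fp16) and the equality-based reference check
        // below remains valid.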
        for (deUint32 i = 0; i < 4; ++i)
            for (deUint32 j = 0; j < totalElements[i]; ++j)
            {
                if (isFloatType(dataTypes[i]))
                {
                    if (m_data.testType != TT_MATRIXMULADD &&
                        m_data.testType != TT_MATRIXMULADD_ARRAY)
                        setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f)/2.0f);
                    else
                        setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f)/2.0f);
                }
                else
                    setDataInt(ptrs[i], dataTypes[i], j, (deRandom_getUint32(&rnd) & 0xff) - 128);
            }

        flushAlloc(vk, device, buffers[0]->getAllocation());
        flushAlloc(vk, device, buffers[1]->getAllocation());
        flushAlloc(vk, device, buffers[2]->getAllocation());
        flushAlloc(vk, device, buffers[3]->getAllocation());

        const Unique<VkShaderModule> shader (createShaderModule(vk, device, m_context.getBinaryCollection().get("test"), 0));

        const VkPipelineShaderStageCreateInfo shaderCreateInfo =
        {
            VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
            DE_NULL,
            (VkPipelineShaderStageCreateFlags)0,
            VK_SHADER_STAGE_COMPUTE_BIT, // stage
            *shader,                     // shader
            "main",
            &specInfo,                   // pSpecializationInfo
        };

        const VkComputePipelineCreateInfo pipelineCreateInfo =
        {
            VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
            DE_NULL,
            0u,                // flags
            shaderCreateInfo,  // cs
            *pipelineLayout,   // layout
            (vk::VkPipeline)0, // basePipelineHandle
            0u,                // basePipelineIndex
        };
        pipeline = createComputePipeline(vk, device, DE_NULL, &pipelineCreateInfo, NULL);

        const VkQueue queue = m_context.getUniversalQueue();
        Move<VkCommandPool> cmdPool = createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
        Move<VkCommandBuffer> cmdBuffer = allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);

        beginCommandBuffer(vk, *cmdBuffer, 0u);

        vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, *pipelineLayout, 0u, 1, &*descriptorSet, 0u, DE_NULL);
        vk.cmdBindPipeline(*cmdBuffer, bindPoint, *pipeline);

        vk.cmdDispatch(*cmdBuffer, m_data.workgroupsX, m_data.workgroupsY, 1);

        endCommandBuffer(vk, *cmdBuffer);

        submitCommandsAndWait(vk, device, queue, cmdBuffer.get());

        invalidateAlloc(vk, device, buffers[3]->getAllocation());

        qpTestResult res = QP_TEST_RESULT_PASS;

        if (isFloatType(dataTypes[0]))
        {
            if (m_data.testType != TT_MATRIXMULADD &&
                m_data.testType != TT_MATRIXMULADD_ARRAY)
            {
                for (deUint32 i = 0; i < totalElements[3]; ++i)
                {
                    float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
                    float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
                    float output = getDataFloat(ptrs[3], dataTypes[3], i);
                    switch (m_data.testType)
                    {
                    case TT_LENGTH:
                        if (output < 1.0f || output > (float)(N*M))
                            res = QP_TEST_RESULT_FAIL;
                        // We expect the matrix to be spread evenly across invocations;
                        // it is surprising (but not necessarily illegal) if not.
                        if (output != (float)(N*M/subgroupProperties.subgroupSize) &&
                            res == QP_TEST_RESULT_PASS)
                            res = QP_TEST_RESULT_QUALITY_WARNING;
                        break;
                    case TT_CONSTANT:
                        if (output != 1.0f)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_CONVERT:
                        if (output != inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_COMPOSITE:
                    case TT_COMPOSITE_RVALUE:
                    case TT_COMPOSITE_ARRAY:
                    case TT_ADD:
                        if (output != inputA + inputB)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_SUB:
                        if (output != inputA - inputB)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_DIV:
                        {
                            float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_NV) ? 1.0f/1024.0f : 1.0f/(8.0f*1024.0f*1024.0f);
                            // division allows 2.5ulp, but we'll use 3.
                            ulp *= 3;
                            if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
                                res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    case TT_NEGATE:
                    case TT_FUNC:
                        if (output != -inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_MATRIXTIMESSCALAR:
                        if (output != 6.0*inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    default:
                        break;
                    }
                }
            }
            else
            {
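                // Reference check for the multiply: recompute D = A*B + C
                // element by element for every matrix tile (mX, mY) in the grid.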
                deUint32 ik, kj, ij;
                for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
                {
                    for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
                    {
                        for (deUint32 i = 0; i < M; ++i)
                        {
                            for (deUint32 j = 0; j < N; ++j)
                            {
                                float ref = 0;
                                for (deUint32 k = 0; k < K; ++k)
                                {
                                    if (m_data.colMajor)
                                        ik = mX * M + i + strides[0] * (mY * K + k);
                                    else
                                        ik = mX * K + k + strides[0] * (mY * M + i);

                                    float Aik = getDataFloat(ptrs[0], dataTypes[0], ik);

                                    if (m_data.colMajor)
                                        kj = mX * K + k + strides[1] * (mY * N + j);
                                    else
                                        kj = mX * N + j + strides[1] * (mY * K + k);

                                    float Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);

                                    ref += Aik*Bkj;
                                }

                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * (mY * N + j);
                                else
                                    ij = mX * N + j + strides[2] * (mY * M + i);

                                float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);

                                ref += Cij;

                                float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

                                if (ref != Dij)
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
        }
        else
        {
            if (m_data.testType != TT_MATRIXMULADD &&
                m_data.testType != TT_MATRIXMULADD_ARRAY)
            {
                for (deUint32 i = 0; i < totalElements[3]; ++i)
                {
                    deUint32 inputA = getDataInt(ptrs[0], dataTypes[0], i);
                    deUint32 inputB = getDataInt(ptrs[1], dataTypes[1], i);
                    deUint32 output = getDataInt(ptrs[3], dataTypes[3], i);
                    int resultSize = componentTypeInfo[dataTypes[3]].bits;
                    deUint32 mask = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
                    switch (m_data.testType)
                    {
                    case TT_LENGTH:
                        if (output < 1 || output > N*M)
                            res = QP_TEST_RESULT_FAIL;
                        // We expect the matrix to be spread evenly across invocations;
                        // it is surprising (but not necessarily illegal) if not.
                        if (output != N*M/subgroupProperties.subgroupSize &&
                            res == QP_TEST_RESULT_PASS)
                            res = QP_TEST_RESULT_QUALITY_WARNING;
                        break;
                    case TT_CONSTANT:
                        if (output != 1)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_CONVERT:
                        if (output != inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_COMPOSITE:
                    case TT_COMPOSITE_RVALUE:
                    case TT_COMPOSITE_ARRAY:
                    case TT_ADD:
                        if ((output & mask) != ((inputA + inputB) & mask)) {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    case TT_SUB:
                        if ((output & mask) != ((inputA - inputB) & mask))
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_DIV:
                        {
                            if (isSIntType(dataTypes[3]))
                            {
                                if (inputB != 0 && ((deInt32)output & mask) != (((deInt32)inputA / (deInt32)inputB) & mask))
                                    res = QP_TEST_RESULT_FAIL;
                            }
                            else
                            {
                                if (inputB != 0 && output != inputA / inputB)
                                    res = QP_TEST_RESULT_FAIL;
                            }
                        }
                        break;
                    case TT_NEGATE:
                    case TT_FUNC:
                        if ((output & mask) != ((-(deInt32)inputA) & mask))
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_MATRIXTIMESSCALAR:
                        if ((output & mask) != ((6*inputA) & mask)) {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    default:
                        break;
                    }
                }
            }
            else
            {
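                // Integer version of the multiply reference check: recompute
                // D = A*B + C with 32-bit arithmetic.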
                deUint32 ik, kj, ij;
                for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
                {
                    for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
                    {
                        for (deUint32 i = 0; i < M; ++i)
                        {
                            for (deUint32 j = 0; j < N; ++j)
                            {
                                deUint32 ref = 0;
                                for (deUint32 k = 0; k < K; ++k)
                                {
                                    if (m_data.colMajor)
                                        ik = mX * M + i + strides[0] * (mY * K + k);
                                    else
                                        ik = mX * K + k + strides[0] * (mY * M + i);

                                    deUint32 Aik = getDataInt(ptrs[0], dataTypes[0], ik);

                                    if (m_data.colMajor)
                                        kj = mX * K + k + strides[1] * (mY * N + j);
                                    else
                                        kj = mX * N + j + strides[1] * (mY * K + k);

                                    deUint32 Bkj = getDataInt(ptrs[1], dataTypes[1], kj);

                                    ref += Aik*Bkj;
                                }

                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * (mY * N + j);
                                else
                                    ij = mX * N + j + strides[2] * (mY * M + i);

                                deUint32 Cij = getDataInt(ptrs[2], dataTypes[2], ij);

                                ref += Cij;

                                deUint32 Dij = getDataInt(ptrs[3], dataTypes[3], ij);

                                if (ref != Dij)
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
        }
        if (res != QP_TEST_RESULT_PASS)
        {
            log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;
            finalres = res;
        }
    }

    return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
}

} // anonymous

tcu::TestCaseGroup* createCooperativeMatrixTests (tcu::TestContext& testCtx)
{
    de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
        testCtx, "cooperative_matrix", "GL_NV_cooperative_matrix tests"));

    typedef struct
    {
        deUint32 value;
        const char* name;
        const char* description;
    } TestGroupCase;

    typedef struct
    {
        deUint32 value[2];
        const char* name;
        const char* description;
    } TestGroupCase2;

    TestGroupCase ttCases[] =
    {
        { TT_LENGTH,             "length",             "OpCooperativeMatrixLengthNV" },
        { TT_CONSTANT,           "constant",           "OpConstantComposite" },
        { TT_CONVERT,            "convert",            "OpFConvert/OpSConvert/OpUConvert" },
        { TT_COMPOSITE,          "composite",          "OpCompositeConstruct" },
        { TT_COMPOSITE_RVALUE,   "composite_rvalue",   "OpCompositeExtract" },
        { TT_ADD,                "add",                "OpFAdd/OpIAdd" },
        { TT_SUB,                "sub",                "OpFSub/OpISub" },
        { TT_DIV,                "div",                "OpFDiv/OpSDiv/OpUDiv" },
        { TT_NEGATE,             "negate",             "OpFNegate/OpSNegate" },
        { TT_MATRIXTIMESSCALAR,  "matrixtimesscalar",  "OpMatrixTimesScalar" },
        { TT_FUNC,               "func",               "OpFunctionParameter" },
        { TT_MATRIXMULADD,       "matrixmuladd",       "OpCooperativeMatrixMulAddNV" },
        { TT_COMPOSITE_ARRAY,    "composite_array",    "OpCompositeConstruct w/array" },
        { TT_MATRIXMULADD_ARRAY, "matrixmuladd_array", "OpCooperativeMatrixMulAddNV w/array" },
    };

    TestGroupCase2 dtCases[] =
    {
        { { VK_COMPONENT_TYPE_FLOAT32_NV, VK_COMPONENT_TYPE_FLOAT32_NV }, "float32_float32", "A/B are fp32 C/D are fp32" },
        { { VK_COMPONENT_TYPE_FLOAT32_NV, VK_COMPONENT_TYPE_FLOAT16_NV }, "float32_float16", "A/B are fp32 C/D are fp16" },
        { { VK_COMPONENT_TYPE_FLOAT16_NV, VK_COMPONENT_TYPE_FLOAT32_NV }, "float16_float32", "A/B are fp16 C/D are fp32" },
        { { VK_COMPONENT_TYPE_FLOAT16_NV, VK_COMPONENT_TYPE_FLOAT16_NV }, "float16_float16", "A/B are fp16 C/D are fp16" },
        { { VK_COMPONENT_TYPE_UINT8_NV,   VK_COMPONENT_TYPE_UINT8_NV   }, "uint8_uint8",     "A/B are u8 C/D are u8" },
        { { VK_COMPONENT_TYPE_UINT8_NV,   VK_COMPONENT_TYPE_UINT32_NV  }, "uint8_uint32",    "A/B are u8 C/D are u32" },
        { { VK_COMPONENT_TYPE_SINT8_NV,   VK_COMPONENT_TYPE_SINT8_NV   }, "sint8_sint8",     "A/B are s8 C/D are s8" },
        { { VK_COMPONENT_TYPE_SINT8_NV,   VK_COMPONENT_TYPE_SINT32_NV  }, "sint8_sint32",    "A/B are s8 C/D are s32" },
        { { VK_COMPONENT_TYPE_UINT32_NV,  VK_COMPONENT_TYPE_UINT32_NV  }, "uint32_uint32",   "A/B are u32 C/D are u32" },
        { { VK_COMPONENT_TYPE_UINT32_NV,  VK_COMPONENT_TYPE_UINT8_NV   }, "uint32_uint8",    "A/B are u32 C/D are u8" },
        { { VK_COMPONENT_TYPE_SINT32_NV,  VK_COMPONENT_TYPE_SINT32_NV  }, "sint32_sint32",   "A/B are s32 C/D are s32" },
        { { VK_COMPONENT_TYPE_SINT32_NV,  VK_COMPONENT_TYPE_SINT8_NV   }, "sint32_sint8",    "A/B are s32 C/D are s8" },
    };

    TestGroupCase colCases[] =
    {
        { 0, "rowmajor", "row major" },
        { 1, "colmajor", "col major" },
    };

    TestGroupCase scCases[] =
    {
        { SC_BUFFER,                      "buffer",           "SSBO" },
        { SC_WORKGROUP,                   "workgroup",        "shared memory" },
        { SC_BUFFER_VARIABLE_POINTERS,    "buffer_varptr",    "SSBO w/variable pointers" },
        { SC_WORKGROUP_VARIABLE_POINTERS, "workgroup_varptr", "shared memory w/variable pointers" },
        { SC_PHYSICAL_STORAGE_BUFFER,     "physical_buffer",  "physical_storage_buffer" },
    };

    for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
    {
        de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, ttCases[ttNdx].name, ttCases[ttNdx].description));
        for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
        {
            de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name, dtCases[dtNdx].description));
            for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
            {
                de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name, scCases[scNdx].description));
                for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
                {
                    TestType testType = (TestType)ttCases[ttNdx].value;
                    VkComponentTypeNV inputType = (VkComponentTypeNV)dtCases[dtNdx].value[0];
                    VkComponentTypeNV outputType = (VkComponentTypeNV)dtCases[dtNdx].value[1];

                    bool isMatrixMul = testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY;

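                    // Non-multiply tests (other than convert) use a single
                    // component type; convert requires the types to differ; for
                    // multiplies, skip combinations where the accumulator type
                    // is narrower than the input type.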
                    if (!isMatrixMul && testType != TT_CONVERT && inputType != outputType)
                        continue;

                    if (testType == TT_CONVERT && inputType == outputType)
                        continue;

                    if (isMatrixMul && componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
                        continue;

                    CaseDef c =
                    {
                        testType,                           // TestType testType;
                        2u,                                 // deUint32 subgroupsPerWorkgroupX;
                        2u,                                 // deUint32 subgroupsPerWorkgroupY;
                        4u,                                 // deUint32 workgroupsX;
                        4u,                                 // deUint32 workgroupsY;
                        (VkComponentTypeNV)inputType,       // VkComponentTypeNV inputType;
                        (VkComponentTypeNV)outputType,      // VkComponentTypeNV outputType;
                        !!colCases[colNdx].value,           // bool colMajor;
                        (StorageClass)scCases[scNdx].value, // StorageClass storageClass;
                    };

                    scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, colCases[colNdx].description, c));
                }
                dtGroup->addChild(scGroup.release());
            }
            ttGroup->addChild(dtGroup.release());
        }
        group->addChild(ttGroup.release());
    }
    return group.release();
}

} // compute
} // vkt