1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2019 The Khronos Group Inc.
6 * Copyright (c) 2018-2019 NVIDIA Corporation
7 * Copyright (c) 2023 LunarG, Inc.
8 * Copyright (c) 2023 Nintendo
9 *
10 * Licensed under the Apache License, Version 2.0 (the "License");
11 * you may not use this file except in compliance with the License.
12 * You may obtain a copy of the License at
13 *
14 * http://www.apache.org/licenses/LICENSE-2.0
15 *
16 * Unless required by applicable law or agreed to in writing, software
17 * distributed under the License is distributed on an "AS IS" BASIS,
18 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 * See the License for the specific language governing permissions and
20 * limitations under the License.
21 *
22 *//*!
23 * \file
24 * \brief Vulkan Cooperative Matrix tests
25 *//*--------------------------------------------------------------------*/
26
27 #include "vktComputeCooperativeMatrixTests.hpp"
28
29 #include "vkBufferWithMemory.hpp"
30 #include "vkImageWithMemory.hpp"
31 #include "vkQueryUtil.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkCmdUtil.hpp"
34 #include "vkTypeUtil.hpp"
35 #include "vkObjUtil.hpp"
36
37 #include "vktTestGroupUtil.hpp"
38 #include "vktTestCase.hpp"
39
40 #include "deDefs.h"
41 #include "deFloat16.h"
42 #include "deMath.h"
43 #include "deRandom.h"
44 #include "deSharedPtr.hpp"
45 #include "deString.h"
46
47 #include "tcuTestCase.hpp"
48 #include "tcuTestLog.hpp"
49
50 #include <string>
51 #include <sstream>
52 #include <set>
53 #include <algorithm>
54
55 namespace vkt
56 {
57 namespace compute
58 {
59 namespace
60 {
61 using namespace vk;
62 using namespace std;
63
64 //#define COOPERATIVE_MATRIX_EXTENDED_DEBUG 1
65
66 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT16_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT16_NV);
67 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT32_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT32_NV);
68 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT64_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT64_NV);
69 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT8_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT8_NV );
70 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT16_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT16_NV );
71 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT32_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT32_NV );
72 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT64_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT64_NV );
73 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT8_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT8_NV );
74 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT16_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT16_NV );
75 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT32_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT32_NV );
76 DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT64_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT64_NV );
77
78 DE_STATIC_ASSERT((uint32_t)VK_SCOPE_DEVICE_KHR == (uint32_t)VK_SCOPE_DEVICE_NV);
79 DE_STATIC_ASSERT((uint32_t)VK_SCOPE_WORKGROUP_KHR == (uint32_t)VK_SCOPE_WORKGROUP_NV);
80 DE_STATIC_ASSERT((uint32_t)VK_SCOPE_SUBGROUP_KHR == (uint32_t)VK_SCOPE_SUBGROUP_NV);
81 DE_STATIC_ASSERT((uint32_t)VK_SCOPE_QUEUE_FAMILY_KHR == (uint32_t)VK_SCOPE_QUEUE_FAMILY_NV);
82
83 typedef enum
84 {
85 UT_NV = 0,
86 UT_KHR_A,
87 UT_KHR_B,
88 UT_KHR_Result,
89 } UseType;
90
91 typedef enum
92 {
93 TT_LENGTH = 0,
94 TT_CONSTANT,
95 TT_CONVERT,
96 TT_COMPOSITE,
97 TT_COMPOSITE_RVALUE,
98 TT_ADD,
99 TT_SUB,
100 TT_DIV,
101 TT_MUL,
102 TT_NEGATE,
103 TT_MATRIXTIMESSCALAR,
104 TT_FUNC,
105 TT_MATRIXMULADD,
106 TT_COMPOSITE_ARRAY,
107 TT_MATRIXMULADD_ARRAY,
108 TT_MATRIXMULADD_SATURATED,
109 TT_MATRIXMULADD_WRAPPING,
110 TT_MATRIXMULADD_STRIDE0,
111 } TestType;
112
113 typedef enum
114 {
115 SC_BUFFER = 0,
116 SC_WORKGROUP,
117 SC_WORKGROUP_VARIABLE_POINTERS,
118 SC_BUFFER_VARIABLE_POINTERS,
119 SC_PHYSICAL_STORAGE_BUFFER,
120 } StorageClass;
121
122 enum SubgroupSizeMode
123 {
124 SUBGROUP_SIZE_NONE = 0,
125 SUBGROUP_SIZE_MIN = 1,
126 SUBGROUP_SIZE_MAX = 2,
127 };
128
129 const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
130
131 struct CaseDef
132 {
133 TestType testType;
134 deUint32 subgroupsPerWorkgroupX;
135 deUint32 subgroupsPerWorkgroupY;
136 deUint32 workgroupsX;
137 deUint32 workgroupsY;
138 VkComponentTypeKHR inputType;
139 VkComponentTypeKHR outputType;
140 bool colMajor;
141 StorageClass storageClass;
142 UseType useType;
143 SubgroupSizeMode subgroupSizeMode;
144 vk::ComputePipelineConstructionType computePipelineConstructionType;
145 };
146
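// UT_NV selects the GL_NV_cooperative_matrix path; the UT_KHR_* values all use the KHR extension.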
147 bool isKhr (UseType useType)
148 {
149 return useType != UT_NV;
150 }
151
152 bool isMatrixMulAddOp (TestType testType)
153 {
154 return testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY || testType == TT_MATRIXMULADD_SATURATED || testType == TT_MATRIXMULADD_WRAPPING || testType == TT_MATRIXMULADD_STRIDE0;
155 }
156
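// Overload set that dispatches the property query to the KHR or NV entry point; the unspecialized template only throws and is never expected to be reached.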
157 template<typename T>
158 VkResult getCooperativeMatrixProperties (const InstanceInterface&, VkPhysicalDevice, uint32_t*, T*)
159 {
160 TCU_THROW(InternalError, "Not Implemented");
161 }
162
163 VkResult getCooperativeMatrixProperties (const InstanceInterface& vki, VkPhysicalDevice physicalDevice, uint32_t* pPropertyCount, VkCooperativeMatrixPropertiesKHR* pProperties)
164 {
165 return vki.getPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, pPropertyCount, pProperties);
166 }
167
168 VkResult getCooperativeMatrixProperties (const InstanceInterface& vki, VkPhysicalDevice physicalDevice, uint32_t* pPropertyCount, VkCooperativeMatrixPropertiesNV* pProperties)
169 {
170 return vki.getPhysicalDeviceCooperativeMatrixPropertiesNV(physicalDevice, pPropertyCount, pProperties);
171 }
172
173 VkCooperativeMatrixPropertiesKHR convertCooperativeMatrixProperties (const VkCooperativeMatrixPropertiesNV& properties)
174 {
175 VkCooperativeMatrixPropertiesKHR result = initVulkanStructure();
176
177 result.sType = (VkStructureType) properties.sType;
178 result.pNext = (void*) properties.pNext;
179 result.MSize = (uint32_t) properties.MSize;
180 result.NSize = (uint32_t) properties.NSize;
181 result.KSize = (uint32_t) properties.KSize;
182 result.AType = (VkComponentTypeKHR) properties.AType;
183 result.BType = (VkComponentTypeKHR) properties.BType;
184 result.CType = (VkComponentTypeKHR) properties.CType;
185 result.ResultType = (VkComponentTypeKHR) properties.DType;
186 result.saturatingAccumulation = (VkBool32) VK_FALSE;
187 result.scope = (VkScopeKHR) properties.scope;
188
189 return result;
190 }
191
192 std::vector<VkCooperativeMatrixPropertiesKHR> convertCooperativeMatrixProperties (const std::vector<VkCooperativeMatrixPropertiesNV>& properties)
193 {
194 std::vector<VkCooperativeMatrixPropertiesKHR> result(properties.size());
195
196 for (size_t i = 0; i < properties.size(); ++i)
197 result[i] = convertCooperativeMatrixProperties(properties[i]);
198
199 return result;
200 }
201
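// Standard two-call query: fetch the property count first, then fill the vector with the properties.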
202 template<typename T>
203 void getCooperativeMatrixPropertiesAll (Context& context, std::vector<T>& properties)
204 {
205 deUint32 propertyCount = 0;
206
207 VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount, (T*)DE_NULL));
208
209 if (propertyCount > 0)
210 {
211 const T sample = initVulkanStructureConst();
212
213 properties.resize(propertyCount, sample);
214
215 VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount, properties.data()));
216 }
217 else
218 {
219 properties.clear();
220 }
221 }
222
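// Returns the supported cooperative matrix configurations, converting NV properties to the KHR structure when the NV path is used.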
223 std::vector<VkCooperativeMatrixPropertiesKHR> getCooperativeMatrixPropertiesConverted (Context& context, const bool khr)
224 {
225 std::vector<VkCooperativeMatrixPropertiesKHR> properties;
226
227 if (khr)
228 {
229 getCooperativeMatrixPropertiesAll(context, properties);
230 }
231 else
232 {
233 std::vector<VkCooperativeMatrixPropertiesNV> propertiesNV;
234
235 getCooperativeMatrixPropertiesAll(context, propertiesNV);
236
237 properties = convertCooperativeMatrixProperties(propertiesNV);
238 }
239
240 return properties;
241 }
242
243 deUint32 getSubgroupSizeFromMode (Context& context,
244 const SubgroupSizeMode subgroupSizeMode)
245 {
246 #ifndef CTS_USES_VULKANSC
247 const VkPhysicalDeviceSubgroupSizeControlProperties& subgroupSizeControlProperties = context.getSubgroupSizeControlProperties();
248 #else
249 const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT& subgroupSizeControlProperties = context.getSubgroupSizeControlPropertiesEXT();
250 #endif // CTS_USES_VULKANSC
251
252 switch (subgroupSizeMode)
253 {
254 case SUBGROUP_SIZE_MAX: return subgroupSizeControlProperties.maxSubgroupSize;
255 case SUBGROUP_SIZE_MIN: return subgroupSizeControlProperties.minSubgroupSize;
256 case SUBGROUP_SIZE_NONE: return context.getSubgroupProperties().subgroupSize;
257 default: TCU_THROW(NotSupportedError, "Unsupported Subgroup size");
258 }
259 }
260
261
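// Thin TestCase/TestInstance pair: the case generates the GLSL and checks support, the instance runs the pipeline and validates the results.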
262 class CooperativeMatrixTestInstance : public TestInstance
263 {
264 public:
265 CooperativeMatrixTestInstance (Context& context, const CaseDef& data);
266 ~CooperativeMatrixTestInstance (void);
267 tcu::TestStatus iterate (void);
268 private:
269 CaseDef m_data;
270 };
271
272 CooperativeMatrixTestInstance::CooperativeMatrixTestInstance (Context& context, const CaseDef& data)
273 : vkt::TestInstance (context)
274 , m_data (data)
275 {
276 }
277
278 CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance (void)
279 {
280 }
281
282 class CooperativeMatrixTestCase : public TestCase
283 {
284 public:
285 CooperativeMatrixTestCase (tcu::TestContext& context, const char* name, const CaseDef data);
286 ~CooperativeMatrixTestCase (void);
287 virtual void initPrograms (SourceCollections& programCollection) const;
288 virtual TestInstance* createInstance (Context& context) const;
289 virtual void checkSupport (Context& context) const;
290
291 private:
292 CaseDef m_data;
293 };
294
295 CooperativeMatrixTestCase::CooperativeMatrixTestCase (tcu::TestContext& context, const char* name, const CaseDef data)
296 : vkt::TestCase (context, name)
297 , m_data (data)
298 {
299 }
300
301 CooperativeMatrixTestCase::~CooperativeMatrixTestCase (void)
302 {
303 }
304
305 void CooperativeMatrixTestCase::checkSupport (Context& context) const
306 {
307 if (!context.contextSupports(vk::ApiVersion(0, 1, 1, 0)))
308 {
309 TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
310 }
311
312 if (isKhr(m_data.useType))
313 {
314 if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
315 {
316 TCU_THROW(NotSupportedError, "VkPhysicalDeviceCooperativeMatrixFeaturesKHR::cooperativeMatrix not supported");
317 }
318 }
319 else
320 {
321 if (!context.getCooperativeMatrixFeaturesNV().cooperativeMatrix)
322 {
323 TCU_THROW(NotSupportedError, "VkPhysicalDeviceCooperativeMatrixFeaturesNV::cooperativeMatrix not supported");
324 }
325 }
326
327 if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
328 {
329 TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
330 }
331
332 if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
333 !context.getVariablePointersFeatures().variablePointers)
334 {
335 TCU_THROW(NotSupportedError, "variable pointers not supported");
336 }
337
338 if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
339 {
340 TCU_THROW(NotSupportedError, "buffer device address not supported");
341 }
342
343 if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
344 (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_KHR))
345 {
346 TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
347 }
348
349 std::vector<VkCooperativeMatrixPropertiesKHR> properties = getCooperativeMatrixPropertiesConverted(context, isKhr(m_data.useType));
350 bool supported[2] = { false, false };
351 const auto isMMA = isMatrixMulAddOp(m_data.testType);
352 const auto isMMASat = m_data.testType == TT_MATRIXMULADD_SATURATED;
353
354 for (size_t i = 0; i < properties.size(); ++i)
355 {
356 const VkCooperativeMatrixPropertiesKHR* p = &properties[i];
357
358 if (p->scope != VK_SCOPE_SUBGROUP_KHR)
359 continue;
360
361 if (isMMA && isMMASat != static_cast<bool>(p->saturatingAccumulation))
362 continue;
363
364 if (isMMA)
365 {
366 if (p->AType == m_data.inputType &&
367 p->BType == m_data.inputType &&
368 p->CType == m_data.outputType &&
369 p->ResultType == m_data.outputType)
370 {
371 supported[0] = supported[1] = true;
372 }
373 }
374 else
375 {
376 const VkComponentTypeKHR types[2] = { m_data.inputType, m_data.outputType };
377
378 for (deUint32 j = 0; j < 2; ++j)
379 {
380 switch (m_data.useType)
381 {
382 case UT_NV:
383 {
384 if (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] || p->ResultType == types[j])
385 supported[j] = true;
386
387 break;
388 }
389 case UT_KHR_A:
390 {
391 if (p->AType == types[j])
392 supported[j] = true;
393
394 break;
395 }
396 case UT_KHR_B:
397 {
398 if (p->BType == types[j])
399 supported[j] = true;
400
401 break;
402 }
403 case UT_KHR_Result:
404 {
405 if (p->ResultType == types[j])
406 supported[j] = true;
407
408 break;
409 }
410 default:
411 TCU_THROW(InternalError, "Unsupported use type");
412 }
413 }
414 }
415 }
416
417 if (!supported[0] || !supported[1])
418 TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");
419
420 checkShaderObjectRequirements(context.getInstanceInterface(), context.getPhysicalDevice(), m_data.computePipelineConstructionType);
421 }
422
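// Per-component-type info, indexed by the VkComponentTypeKHR enum value (FLOAT16 .. UINT64).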
423 struct {
424 const char *typeName;
425 const char *coopmatTypeName;
426 deUint32 bits;
427 bool isSigned;
428 } componentTypeInfo[] =
429 {
430 { "float16_t", "fcoopmatNV", 16, true },
431 { "float32_t", "fcoopmatNV", 32, true },
432 { "float64_t", "fcoopmatNV", 64, true },
433 { "int8_t", "icoopmatNV", 8, true },
434 { "int16_t", "icoopmatNV", 16, true },
435 { "int32_t", "icoopmatNV", 32, true },
436 { "int64_t", "icoopmatNV", 64, true },
437 { "uint8_t", "ucoopmatNV", 8, false },
438 { "uint16_t", "ucoopmatNV", 16, false },
439 { "uint32_t", "ucoopmatNV", 32, false },
440 { "uint64_t", "ucoopmatNV", 64, false },
441 };
442
443 bool isFloatType (VkComponentTypeKHR t)
444 {
445 switch (t)
446 {
447 case VK_COMPONENT_TYPE_FLOAT16_KHR:
448 case VK_COMPONENT_TYPE_FLOAT32_KHR:
449 case VK_COMPONENT_TYPE_FLOAT64_KHR:
450 return true;
451 default:
452 return false;
453 }
454 }
455
456 bool isSIntType (VkComponentTypeKHR t)
457 {
458 switch (t)
459 {
460 case VK_COMPONENT_TYPE_SINT8_KHR:
461 case VK_COMPONENT_TYPE_SINT16_KHR:
462 case VK_COMPONENT_TYPE_SINT32_KHR:
463 case VK_COMPONENT_TYPE_SINT64_KHR:
464 return true;
465 default:
466 return false;
467 }
468 }
469
470 void CooperativeMatrixTestCase::initPrograms (SourceCollections& programCollection) const
471 {
472 const char* suffix = (isKhr(m_data.useType) ? "" : "NV");
473 const char* ext = isKhr(m_data.useType)
474 ? "#extension GL_KHR_cooperative_matrix : enable\n"
475 : "#extension GL_NV_cooperative_matrix : enable\n"
476 "#extension GL_NV_integer_cooperative_matrix : enable\n";
477 const char* sat = (m_data.testType == TT_MATRIXMULADD_SATURATED) ? ", gl_MatrixOperandsSaturatingAccumulation" : "";
478 std::stringstream css;
479 css << "#version 450 core\n";
480 css << "#pragma use_vulkan_memory_model\n";
481 css <<
482 "#extension GL_KHR_shader_subgroup_basic : enable\n"
483 "#extension GL_KHR_memory_scope_semantics : enable\n"
484 << ext <<
485 "#extension GL_EXT_shader_explicit_arithmetic_types : enable\n"
486 "#extension GL_EXT_buffer_reference : enable\n"
487 "// strides overridden by spec constants\n"
488 "layout(constant_id = 2) const int AStride = 1;\n"
489 "layout(constant_id = 3) const int BStride = 1;\n"
490 "layout(constant_id = 4) const int CStride = 1;\n"
491 "layout(constant_id = 5) const int OStride = 1;\n"
492 "layout(constant_id = 6) const int M = 1;\n"
493 "layout(constant_id = 7) const int N = 1;\n"
494 "layout(constant_id = 8) const int K = 1;\n"
495 "layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";
496
497 if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
498 css << "#pragma use_variable_pointers\n";
499
500 struct
501 {
502 string rows, cols;
503 } dims[4];
504
505 if (isMatrixMulAddOp(m_data.testType))
506 {
507 dims[0].rows = "M";
508 dims[0].cols = "K";
509 dims[1].rows = "K";
510 dims[1].cols = "N";
511 dims[2].rows = "M";
512 dims[2].cols = "N";
513 dims[3].rows = "M";
514 dims[3].cols = "N";
515 }
516 else
517 {
518 dims[0].rows = "M";
519 dims[0].cols = "N";
520 dims[1].rows = "M";
521 dims[1].cols = "N";
522 dims[2].rows = "M";
523 dims[2].cols = "N";
524 dims[3].rows = "M";
525 dims[3].cols = "N";
526 }
527
528 const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
529 const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
530 const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
531 const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;
532
533 css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
534 css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", " << m_data.subgroupsPerWorkgroupY << ");\n";
535
536 if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
537 {
538 css << "layout(buffer_reference) buffer InputA { " << typeStrA << " x[]; };\n";
539 css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
540 css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
541 css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
542 css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; } params;\n";
543 }
544 else
545 {
546 css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrA << " x[]; } inputA;\n";
547 css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
548 css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
549 css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
550 }
551
552 if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
553 {
554 css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
555 css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
556 css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
557 css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
558 }
559
560 std::stringstream matAType, matBType, matCType, outputMatType;
561
562 if (isKhr(m_data.useType))
563 {
564 const bool useSame = !isMatrixMulAddOp(m_data.testType);
565 const char* sameType = m_data.useType == UT_KHR_A ? "gl_MatrixUseA"
566 : m_data.useType == UT_KHR_B ? "gl_MatrixUseB"
567 : m_data.useType == UT_KHR_Result ? "gl_MatrixUseAccumulator"
568 : "Invalid use";
569 const char* atype = useSame ? sameType : "gl_MatrixUseA";
570 const char* btype = useSame ? sameType : "gl_MatrixUseB";
571 const char* ctype = useSame ? sameType : "gl_MatrixUseAccumulator";
572 const char* rtype = useSame ? sameType : "gl_MatrixUseAccumulator";
573
574 matAType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ", " << atype << ">";
575 matBType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ", " << btype << ">";
576 matCType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ", " << ctype << ">";
577 outputMatType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ", " << rtype << ">";
578 }
579 else
580 {
581 matAType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[0].rows << ", " << dims[0].cols << ">";
582 matBType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<" << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[1].rows << ", " << dims[1].cols << ">";
583 matCType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[2].rows << ", " << dims[2].cols << ">";
584 outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<" << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[3].rows << ", " << dims[3].cols << ">";
585 }
586
587 css << matAType.str() << " matA;\n";
588 css << matBType.str() << " matB;\n";
589 css << matCType.str() << " matC;\n";
590 css << outputMatType.str() << " matO;\n";
591
592 if (m_data.testType == TT_CONSTANT)
593 css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";
594
595 if (m_data.testType == TT_FUNC)
596 css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";
597
598 css <<
599 "void main()\n"
600 "{\n"
601 // matrixID is the x,y index of the matrix owned by this subgroup.
602 " uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
603 " uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";
604
605 if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
606 {
607 css << " InputA inputA = params.inputA;\n";
608 css << " InputB inputB = params.inputB;\n";
609 css << " InputC inputC = params.inputC;\n";
610 css << " Output outputO = params.outputO;\n";
611 }
612
613 string strides[4];
614 for (deUint32 i = 0; i < 4; ++i)
615 {
616 strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") + de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
617 }
618
619 // element<i> is the starting element in buffer memory.
620 // elementS<i> is the starting element in shared memory.
621 css << " uint element0 = " << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x;\n"
622 " uint element1 = " << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x;\n"
623 " uint element2 = " << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x;\n"
624 " uint element3 = " << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + " << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x;\n"
625 " uint elementS0, elementS1, elementS2, elementS3;\n";
626
627 // For shared memory tests, copy the matrix from buffer memory into
628 // workgroup memory. For simplicity, do it all on a single thread.
629 if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
630 {
631 const char *name[] =
632 {
633 "sharedA",
634 "sharedB",
635 "sharedC",
636 };
637 const char *inputName[] =
638 {
639 "inputA",
640 "inputB",
641 "inputC",
642 };
643 for (deUint32 m = 0; m < 4; ++m)
644 {
645 string sharedStride = strides[m] + " / workgroupsX";
646 css << " elementS" << m << " = " << sharedStride << " * " << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + " << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x;\n";
647 }
648 css << " if (subgroupElect()) {\n";
649 // copy all three input buffers.
650 for (deUint32 m = 0; m < 3; ++m)
651 {
652 string sharedStride = strides[m] + " / workgroupsX";
653 css << " for (int i = 0; i < " << dims[m].rows << "; ++i) {\n"
654 " for (int j = 0; j < " << dims[m].cols << "; ++j) {\n"
655 " int localElementInput = " << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
656 " int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
657 " " << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m << " + localElementInput];\n"
658 " }\n"
659 " }\n";
660 strides[m] = sharedStride;
661 }
662 css << " }\n";
663 css << " controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
664 }
665
666 const char *colMajorNV = (m_data.colMajor ? "true" : "false");
667 const char* colMajorKHR = (m_data.colMajor ? "gl_CooperativeMatrixLayoutColumnMajor" : "gl_CooperativeMatrixLayoutRowMajor");
668 const char* colMajor = (isKhr(m_data.useType) ? colMajorKHR : colMajorNV);
669
670 string loadStrides[3] = { strides[0], strides[1], strides[2] };
671 // Load with a stride of 0
672 if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
673 loadStrides[0] = loadStrides[1] = loadStrides[2] = "0";
674
675 if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
676 {
677 css << " coopMatLoad" << suffix << "(matA, sharedA, elementS0, " << loadStrides[0] << ", " << colMajor << ");\n"
678 " coopMatLoad" << suffix << "(matB, sharedB, elementS1, " << loadStrides[1] << ", " << colMajor << ");\n"
679 " coopMatLoad" << suffix << "(matC, sharedC, elementS2, " << loadStrides[2] << ", " << colMajor << ");\n";
680 }
681 else
682 {
683 css << " coopMatLoad" << suffix << "(matA, inputA.x, element0, " << loadStrides[0] << ", " << colMajor << ");\n"
684 " coopMatLoad" << suffix << "(matB, inputB.x, element1, " << loadStrides[1] << ", " << colMajor << ");\n"
685 " coopMatLoad" << suffix << "(matC, inputC.x, element2, " << loadStrides[2] << ", " << colMajor << ");\n";
686 }
687
688 if (m_data.testType == TT_COMPOSITE_ARRAY ||
689 m_data.testType == TT_MATRIXMULADD_ARRAY)
690 {
691 css << " " << matAType.str() << " matAArr[2];\n matAArr[1] = matA; matAArr[0] = " << matAType.str() << "(0.0);\n"
692 " " << matBType.str() << " matBArr[2];\n matBArr[1] = matB; matBArr[0] = " << matBType.str() << "(0.0);\n"
693 " " << matCType.str() << " matCArr[2];\n matCArr[1] = matC; matCArr[0] = " << matCType.str() << "(0.0);\n"
694 " " << outputMatType.str() << " matOArr[2];\n";
695 }
696
697 switch (m_data.testType)
698 {
699 default:
700 DE_ASSERT(0);
701 // fall through
702 case TT_LENGTH:
703 css << " matO = " << outputMatType.str() << "(matO.length());\n";
704 break;
705 case TT_CONSTANT:
706 css << " matO = matConst;\n";
707 break;
708 case TT_CONVERT:
709 css << " matO = " << outputMatType.str() << "(matA);\n";
710 break;
711 case TT_COMPOSITE:
712 css << " " << matAType.str() << " t = " << matAType.str() << "(matB[0]);\n"
713 " for (int i = 1; i < matA.length(); ++i) {\n"
714 " matO[i] = matA[i] + matB[i];\n"
715 " }\n"
716 " if (matA.length() > 0)\n"
717 " matO[0] = matA[0] + t[0];\n";
718 break;
719 case TT_COMPOSITE_RVALUE:
720 css << " for (int i = 1; i < matA.length(); ++i) {\n"
721 " matO[i] = matA[i] + matB[i];\n"
722 " }\n"
723 " " << matAType.str() << " t = matA;\n"
724 " if (matA.length() > 0) {\n"
725 " matO[0] = (t += matB)[0];\n"
726 " }\n";
727 break;
728 case TT_COMPOSITE_ARRAY:
729 css << " for (int i = 0; i < matA.length(); ++i) {\n"
730 " matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
731 " }\n";
732 break;
733 case TT_ADD:
734 css << " matO = matA + matB;\n";
735 break;
736 case TT_SUB:
737 css << " matO = matA - matB;\n";
738 break;
739 case TT_DIV:
740 css << " matO = matA / matB;\n";
741 break;
742 case TT_MUL:
743 css << " matO = matA * matB;\n";
744 break;
745 case TT_NEGATE:
746 css << " matO = -matA;\n";
747 break;
748 case TT_FUNC:
749 css << " matO = f(matA);\n";
750 break;
751 case TT_MATRIXTIMESSCALAR:
752 css << " matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
753 break;
754 case TT_MATRIXMULADD_STRIDE0:
755 case TT_MATRIXMULADD_WRAPPING:
756 case TT_MATRIXMULADD_SATURATED:
757 case TT_MATRIXMULADD:
758 css << " matO = coopMatMulAdd" << suffix << "(matA, matB, matC" << sat << ");\n";
759 break;
760 case TT_MATRIXMULADD_ARRAY:
761 css << " matOArr[1] = coopMatMulAdd" << suffix << "(matAArr[1], matBArr[1], matCArr[1]);\n";
762 break;
763 }
764
765 if (m_data.testType == TT_COMPOSITE_ARRAY ||
766 m_data.testType == TT_MATRIXMULADD_ARRAY)
767 {
768 css << " matOArr[0] = " << outputMatType.str() << "(0.0);\n";
769 css << " matO = matOArr[1];\n";
770 }
771
772 if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
773 {
774 string sharedStride = strides[3] + " / workgroupsX";
775 css << " coopMatStore" << suffix << "(matO, sharedO, elementS3, " << sharedStride << ", " << colMajor << ");\n";
776 css << " controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
777 css << " if (subgroupElect()) {\n";
778 css << " for (int i = 0; i < " << dims[3].rows << "; ++i) {\n"
779 " for (int j = 0; j < " << dims[3].cols << "; ++j) {\n"
780 " int localElementInput = " << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
781 " int localElementShared = " << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ";\n"
782 " outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
783 " }\n"
784 " }\n";
785 css << " }\n";
786 }
787 else
788 {
789 css << " coopMatStore" << suffix << "(matO, outputO.x, element3, " << strides[3] << ", " << colMajor << ");\n";
790 }
791
792 css <<
793 "}\n";
794
795 const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
796
797 programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
798 }
799
800 TestInstance* CooperativeMatrixTestCase::createInstance (Context& context) const
801 {
802 return new CooperativeMatrixTestInstance(context, m_data);
803 }
804
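// Host-side element accessors: read/write one matrix element in the raw buffer, converting between the VkComponentTypeKHR in-memory representation and float/deUint32 working values.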
805 void setDataFloat (void *base, VkComponentTypeKHR dt, deUint32 i, float value)
806 {
807 if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
808 {
809 ((float *)base)[i] = value;
810 }
811 else
812 {
813 DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
814 ((deFloat16 *)base)[i] = deFloat32To16(value);
815 }
816 }
817
818 float getDataFloat (void *base, VkComponentTypeKHR dt, deUint32 i)
819 {
820 if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
821 {
822 return ((float *)base)[i];
823 }
824 else
825 {
826 DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
827 return deFloat16To32(((deFloat16 *)base)[i]);
828 }
829 }
830
831 void setDataInt (void *base, VkComponentTypeKHR dt, deUint32 i, deUint32 value)
832 {
833 DE_ASSERT(componentTypeInfo[dt].bits <= 32);
834
835 switch (dt)
836 {
837 case VK_COMPONENT_TYPE_UINT8_KHR: ((deUint8 *)base)[i] = (deUint8)value; break;
838 case VK_COMPONENT_TYPE_UINT16_KHR: ((deUint16 *)base)[i] = (deUint16)value; break;
839 case VK_COMPONENT_TYPE_UINT32_KHR: ((deUint32 *)base)[i] = (deUint32)value; break;
840 case VK_COMPONENT_TYPE_SINT8_KHR: ((deInt8 *)base)[i] = (deInt8)value; break;
841 case VK_COMPONENT_TYPE_SINT16_KHR: ((deInt16 *)base)[i] = (deInt16)value; break;
842 case VK_COMPONENT_TYPE_SINT32_KHR: ((deInt32 *)base)[i] = (deInt32)value; break;
843 default: TCU_THROW(InternalError, "Unsupported type");
844 }
845 }
846
847 deUint32 getDataInt (void *base, VkComponentTypeKHR dt, deUint32 i)
848 {
849 DE_ASSERT(componentTypeInfo[dt].bits <= 32);
850
851 switch (dt)
852 {
853 case VK_COMPONENT_TYPE_UINT8_KHR: return ((deUint8*)base)[i];
854 case VK_COMPONENT_TYPE_UINT16_KHR: return ((deUint16*)base)[i];
855 case VK_COMPONENT_TYPE_UINT32_KHR: return ((deUint32*)base)[i];
856 case VK_COMPONENT_TYPE_SINT8_KHR: return ((deInt8*)base)[i];
857 case VK_COMPONENT_TYPE_SINT16_KHR: return ((deInt16*)base)[i];
858 case VK_COMPONENT_TYPE_SINT32_KHR: return ((deInt32 *)base)[i];
859 default: TCU_THROW(InternalError, "Unsupported type");
860 }
861 }
862
863 template <typename T>
864 T getDataConvertedToT (void *base, VkComponentTypeKHR dt, deUint32 i)
865 {
866 DE_ASSERT(componentTypeInfo[dt].bits <= 32);
867
868 switch (dt)
869 {
870 case VK_COMPONENT_TYPE_UINT8_KHR: return (T)((deUint8*)base)[i];
871 case VK_COMPONENT_TYPE_UINT16_KHR: return (T)((deUint16*)base)[i];
872 case VK_COMPONENT_TYPE_UINT32_KHR: return (T)((deUint32*)base)[i];
873 case VK_COMPONENT_TYPE_SINT8_KHR: return (T)((deInt8*)base)[i];
874 case VK_COMPONENT_TYPE_SINT16_KHR: return (T)((deInt16*)base)[i];
875 case VK_COMPONENT_TYPE_SINT32_KHR: return (T)((deInt32 *)base)[i];
876 case VK_COMPONENT_TYPE_FLOAT32_KHR:
877 {
878 float temp = ((float *)base)[i];
879 if (std::numeric_limits<T>::min() == 0)
880 temp = std::max(temp, 0.0f);
881 return (T)temp;
882 }
883 case VK_COMPONENT_TYPE_FLOAT16_KHR:
884 {
885 float temp = deFloat16To32(((deFloat16 *)base)[i]);
886 if (std::numeric_limits<T>::min() == 0)
887 temp = std::max(temp, 0.0f);
888 return (T)temp;
889 }
890 default:
891 TCU_THROW(InternalError, "Unsupported type");
892 }
893 }
894
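// Saturating add for signed types: clamps to the limits of T instead of wrapping, e.g. satAdd<deInt8>(100, 100) == 127.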
895 template<typename T>
896 T satAdd(T a, T b)
897 {
898 if (a > 0)
899 {
900 if (b > std::numeric_limits<T>::max() - a)
901 return std::numeric_limits<T>::max();
902 }
903 else if (b < std::numeric_limits<T>::min() - a)
904 {
905 return std::numeric_limits<T>::min();
906 }
907
908 return (T)(a + b);
909 }
910
911 deUint32 satAddData (VkComponentTypeKHR dt, deUint32 a, deUint32 b)
912 {
913 DE_ASSERT(componentTypeInfo[dt].bits <= 32);
914
915 switch (dt)
916 {
917 case VK_COMPONENT_TYPE_UINT8_KHR: return deMinu32(a + b, std::numeric_limits<deUint8>::max());
918 case VK_COMPONENT_TYPE_UINT16_KHR: return deMinu32(a + b, std::numeric_limits<deUint16>::max());
919 case VK_COMPONENT_TYPE_UINT32_KHR: return (a + b >= a) ? a + b : std::numeric_limits<deUint32>::max();
920 case VK_COMPONENT_TYPE_SINT8_KHR: return (deUint32)satAdd((deInt8)a, (deInt8)b);
921 case VK_COMPONENT_TYPE_SINT16_KHR: return (deUint32)satAdd((deInt16)a, (deInt16)b);
922 case VK_COMPONENT_TYPE_SINT32_KHR: return (deUint32)satAdd((deInt32)a, (deInt32)b);
923 default: TCU_THROW(InternalError, "Unsupported type");
924 }
925 }
926
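// Maximum (positive == true) or minimum (positive == false) representable value of the component type, returned as raw deUint32 bits; e.g. getLimit(VK_COMPONENT_TYPE_SINT8_KHR, true) == 127.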
927 deUint32 getLimit (VkComponentTypeKHR dt, bool positive)
928 {
929 DE_ASSERT(componentTypeInfo[dt].bits <= 32);
930
931 switch (dt)
932 {
933 case VK_COMPONENT_TYPE_UINT8_KHR: return deUint32(positive ? std::numeric_limits<deUint8>::max() : std::numeric_limits<deUint8>::min());
934 case VK_COMPONENT_TYPE_UINT16_KHR: return deUint32(positive ? std::numeric_limits<deUint16>::max() : std::numeric_limits<deUint16>::min());
935 case VK_COMPONENT_TYPE_UINT32_KHR: return deUint32(positive ? std::numeric_limits<deUint32>::max() : std::numeric_limits<deUint32>::min());
936 case VK_COMPONENT_TYPE_SINT8_KHR: return deUint32(positive ? std::numeric_limits<deInt8>::max() : std::numeric_limits<deInt8>::min());
937 case VK_COMPONENT_TYPE_SINT16_KHR: return deUint32(positive ? std::numeric_limits<deInt16>::max() : std::numeric_limits<deInt16>::min());
938 case VK_COMPONENT_TYPE_SINT32_KHR: return deUint32(positive ? std::numeric_limits<deInt32>::max() : std::numeric_limits<deInt32>::min());
939 default: TCU_THROW(InternalError, "Unsupported type");
940 }
941 }
942
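// Writes a strided row/column of 'count' elements: element 'at' gets 'val', every other element is set to 0.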
943 void setSingleElementInt (void *data, VkComponentTypeKHR dt, deUint32 start, deUint32 count, deUint32 step, deUint32 at, deUint32 val)
944 {
945 for (deUint32 i = 0; i < count; i++)
946 setDataInt(data, dt, start + i * step, (i == at) ? val : 0);
947 }
948
949 #ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
950 string dumpWholeMatrix (void* data, VkComponentTypeKHR dt, bool colMajor, deUint32 matrixElemCount, deUint32 stride)
951 {
952 const deUint32 rowsCount = colMajor ? stride : matrixElemCount / stride;
953 const deUint32 colsCount = colMajor ? matrixElemCount / stride : stride;
954 bool floatType = isFloatType(dt);
955 bool sIntType = isSIntType(dt);
956 std::stringstream ss;
957
958 DE_ASSERT(rowsCount * colsCount == matrixElemCount);
959
960 for (deUint32 r = 0; r < rowsCount; r++)
961 {
962 for (deUint32 c = 0; c < colsCount; c++)
963 {
964 const deUint32 i = colMajor ? rowsCount * c + r : colsCount * r + c;
965
966 if (floatType)
967 ss << getDataFloat(data, dt, i) << "\t";
968 else if (sIntType)
969 ss << (deInt32)getDataInt(data, dt, i) << "\t";
970 else
971 ss << getDataInt(data, dt, i) << "\t";
972 }
973
974 ss << std::endl;
975 }
976
977 return ss.str();
978 }
979 #endif
980
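// Executes the compute shader for every matrix size reported by the implementation and verifies the output buffer against a host-side reference.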
981 tcu::TestStatus CooperativeMatrixTestInstance::iterate (void)
982 {
983 const DeviceInterface& vk = m_context.getDeviceInterface();
984 const VkDevice device = m_context.getDevice();
985 Allocator& allocator = m_context.getDefaultAllocator();
986 MemoryRequirement memoryDeviceAddress = m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
987 m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ? MemoryRequirement::DeviceAddress : MemoryRequirement::Any;
988 qpTestResult finalres = QP_TEST_RESULT_NOT_SUPPORTED;
989 tcu::TestLog& log = m_context.getTestContext().getLog();
990 const bool saturated = (m_data.testType == TT_MATRIXMULADD_SATURATED);
991 const deUint32 subgroupSize = getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode);
992 const float epsilon = 1.0f / float(1ull<<17); // 1/131072, roughly 1e-5
993
994 deRandom rnd;
995 deRandom_init(&rnd, 1234);
996
997 std::vector<VkCooperativeMatrixPropertiesKHR> properties = getCooperativeMatrixPropertiesConverted(m_context, isKhr(m_data.useType));
998
999 struct TestTuple
1000 {
1001 TestTuple() {}
1002 TestTuple(deUint32 m, deUint32 n, deUint32 k) : M(m), N(n), K(k) {}
1003
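// Lexicographic order on (M, N, K) so tuples can be stored in a std::set.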
1004 bool operator<(const TestTuple &other) const
1005 {
1006 return M < other.M ||
1007 (M == other.M && N < other.N) ||
1008 (M == other.M && N == other.N && K < other.K);
1009 }
1010
1011 deUint32 M, N, K;
1012 };
1013
1014 vector<TestTuple> testSizes;
1015
1016 if (isMatrixMulAddOp(m_data.testType))
1017 {
1018 for (size_t i = 0; i < properties.size(); ++i)
1019 {
1020 VkCooperativeMatrixPropertiesKHR *p = &properties[i];
1021
1022 if (p->AType == m_data.inputType &&
1023 p->BType == m_data.inputType &&
1024 p->CType == m_data.outputType &&
1025 p->ResultType == m_data.outputType &&
1026 p->scope == VK_SCOPE_SUBGROUP_KHR)
1027 {
1028 testSizes.push_back(TestTuple(p->MSize, p->NSize, p->KSize));
1029 }
1030 }
1031 }
1032 else
1033 {
1034 set<TestTuple> typeSizes[2];
1035 VkComponentTypeKHR types[2] = { m_data.inputType, m_data.outputType };
1036 const bool aType = (m_data.useType == UT_KHR_A) || (m_data.useType == UT_NV);
1037 const bool bType = (m_data.useType == UT_KHR_B) || (m_data.useType == UT_NV);
1038 const bool rType = (m_data.useType == UT_KHR_Result) || (m_data.useType == UT_NV);
1039
1040 for (deUint32 i = 0; i < properties.size(); ++i)
1041 {
1042 VkCooperativeMatrixPropertiesKHR *p = &properties[i];
1043
1044 if (p->scope != VK_SCOPE_SUBGROUP_KHR)
1045 continue;
1046
1047 for (deUint32 j = 0; j < 2; ++j)
1048 {
1049 // For these tests, m_data.M/N are always the matrix size. Check if they match
1050 // any input or output in the list.
1051 if (aType && p->AType == types[j]) typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0));
1052 if (bType && p->BType == types[j]) typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0));
1053 if (rType && (p->CType == types[j] || p->ResultType == types[j])) typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0));
1054 }
1055 }
1056 // Test those sizes that are supported for both the input and output type.
1057 std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(),
1058 typeSizes[1].begin(), typeSizes[1].end(),
1059 std::back_inserter(testSizes));
1060 }
1061
1062 properties.resize(0);
1063
1064 for (unsigned int s = 0; s < testSizes.size(); ++s)
1065 {
1066 // When testing a multiply, MxNxK is the type of matrix multiply.
1067 // Otherwise, MxN is the size of the input/output matrices
1068 deUint32 M, N, K;
1069 M = testSizes[s].M;
1070 N = testSizes[s].N;
1071 K = testSizes[s].K;
1072
1073 log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;
1074
1075 struct
1076 {
1077 deUint32 rows, cols;
1078 } dims[4];
1079
1080 if (isMatrixMulAddOp(m_data.testType))
1081 {
1082 dims[0].rows = M;
1083 dims[0].cols = K;
1084 dims[1].rows = K;
1085 dims[1].cols = N;
1086 dims[2].rows = M;
1087 dims[2].cols = N;
1088 dims[3].rows = M;
1089 dims[3].cols = N;
1090 }
1091 else
1092 {
1093 dims[0].rows = M;
1094 dims[0].cols = N;
1095 dims[1].rows = M;
1096 dims[1].cols = N;
1097 dims[2].rows = M;
1098 dims[2].cols = N;
1099 dims[3].rows = M;
1100 dims[3].cols = N;
1101 }
1102
1103 VkComponentTypeKHR dataTypes[4];
1104 size_t elementSize[4];
1105 VkDeviceSize bufferSizes[5];
1106 de::MovePtr<BufferWithMemory> buffers[5];
1107 vk::VkDescriptorBufferInfo bufferDescriptors[5];
1108 deUint32 strides[4]; // in elements
1109 deUint32 loadStrides[4];
1110 deUint32 totalElements[4];
1111
1112 for (deUint32 i = 0; i < 5; ++i)
1113 {
1114 if (i < 4)
1115 {
1116 // A/B use input type, C/D use output type
1117 dataTypes[i] = (i < 2) ? m_data.inputType : m_data.outputType;
1118 elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;
1119
1120 strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.subgroupsPerWorkgroupX * m_data.workgroupsX;
1121 loadStrides[i] = strides[i];
1122 totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) * m_data.subgroupsPerWorkgroupY * m_data.workgroupsY;
1123
1124 bufferSizes[i] = totalElements[i] * elementSize[i];
1125 }
1126 else
1127 {
1128 bufferSizes[4] = sizeof(VkDeviceAddress)*4;
1129 }
1130
1131 try
1132 {
1133 buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1134 vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|
1135 (memoryDeviceAddress == MemoryRequirement::DeviceAddress ? VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT : 0)),
1136 MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent | memoryDeviceAddress));
1137 }
1138 catch (const tcu::NotSupportedError&)
1139 {
1140 buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1141 vk, device, allocator, makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT|VK_BUFFER_USAGE_TRANSFER_DST_BIT|VK_BUFFER_USAGE_TRANSFER_SRC_BIT|
1142 (memoryDeviceAddress == MemoryRequirement::DeviceAddress ? VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT : 0)),
1143 MemoryRequirement::HostVisible | memoryDeviceAddress));
1144 }
1145
1146 bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
1147 }
1148
1149 // Load with a stride of 0
1150 if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
1151 loadStrides[0] = loadStrides[1] = loadStrides[2] = loadStrides[3] = 0;
1152
1153 void *ptrs[5];
1154 for (deUint32 i = 0; i < 5; ++i)
1155 {
1156 ptrs[i] = buffers[i]->getAllocation().getHostPtr();
1157 }
1158
1159 vk::DescriptorSetLayoutBuilder layoutBuilder;
1160
1161 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1162 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1163 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1164 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1165 layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1166
1167 vk::Unique<vk::VkDescriptorSetLayout> descriptorSetLayout(layoutBuilder.build(vk, device));
1168
1169 vk::Unique<vk::VkDescriptorPool> descriptorPool(vk::DescriptorPoolBuilder()
1170 .addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
1171 .build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
1172 vk::Unique<vk::VkDescriptorSet> descriptorSet (makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));
1173
1174 vk::DescriptorSetUpdateBuilder setUpdateBuilder;
1175 if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
1176 {
1177 VkBufferDeviceAddressInfo info
1178 {
1179 VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType sType;
1180 DE_NULL, // const void* pNext;
1181 0, // VkBuffer buffer
1182 };
1183 VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
1184 for (deUint32 i = 0; i < 4; ++i)
1185 {
1186 info.buffer = **buffers[i];
1187 VkDeviceAddress addr = vk.getBufferDeviceAddress(device, &info);
1188 addrsInMemory[i] = addr;
1189 }
1190 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
1191 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
1192 }
1193 else
1194 {
1195 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
1196 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
1197 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1198 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1199 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
1200 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
1201 setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
1202 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
1203 }
1204
1205 setUpdateBuilder.update(vk, device);
1206
1207 VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
1208
1209 const deUint32 specData[9] =
1210 {
1211 subgroupSize * m_data.subgroupsPerWorkgroupX,
1212 m_data.subgroupsPerWorkgroupY,
1213 strides[0],
1214 strides[1],
1215 strides[2],
1216 strides[3],
1217 M,
1218 N,
1219 K,
1220 };
1221
1222 const vk::VkSpecializationMapEntry entries[9] =
1223 {
1224 {0, (deUint32)(sizeof(deUint32) * 0), sizeof(deUint32)},
1225 {1, (deUint32)(sizeof(deUint32) * 1), sizeof(deUint32)},
1226 {2, (deUint32)(sizeof(deUint32) * 2), sizeof(deUint32)},
1227 {3, (deUint32)(sizeof(deUint32) * 3), sizeof(deUint32)},
1228 {4, (deUint32)(sizeof(deUint32) * 4), sizeof(deUint32)},
1229 {5, (deUint32)(sizeof(deUint32) * 5), sizeof(deUint32)},
1230 {6, (deUint32)(sizeof(deUint32) * 6), sizeof(deUint32)},
1231 {7, (deUint32)(sizeof(deUint32) * 7), sizeof(deUint32)},
1232 {8, (deUint32)(sizeof(deUint32) * 8), sizeof(deUint32)},
1233 };
1234
1235 const vk::VkSpecializationInfo specInfo =
1236 {
1237 9, // mapEntryCount
1238 entries, // pMapEntries
1239 sizeof(specData), // dataSize
1240 specData // pData
1241 };
1242
1243 for (deUint32 i = 0; i < 4; ++i)
1244 for (deUint32 j = 0; j < totalElements[i]; ++j)
1245 {
1246 if (isFloatType(dataTypes[i]))
1247 {
1248 if (!isMatrixMulAddOp(m_data.testType))
1249 setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f)/2.0f);
1250 else
1251 setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f)/2.0f);
1252 }
1253 else
1254 {
1255 if (m_data.testType == TT_MATRIXMULADD_WRAPPING)
1256 {
1257 // Choose matrix values that should cause overflow and underflow, to
1258 // verify wrapping behavior. Use the full range of values for A and B.
1259 // For matrix C, use values clustered near where the type wraps (zero
1260 // for unsigned, 2^(N-1) for signed).
1261 deUint32 bits = componentTypeInfo[dataTypes[i]].bits;
1262 deUint32 value;
1263 if (i == 2) {
1264 value = (deRandom_getUint32(&rnd) & 0xff) - 128;
1265 if (componentTypeInfo[dataTypes[i]].isSigned)
1266 value += (1U << (bits - 1));
1267 } else {
1268 deUint32 mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
1269 value = deRandom_getUint32(&rnd) & mask;
1270 }
1271 setDataInt(ptrs[i], dataTypes[i], j, value);
1272 }
1273 else if (m_data.testType == TT_MATRIXMULADD_SATURATED)
1274 {
1275 setDataInt(ptrs[i], dataTypes[i], j, 0);
1276 }
1277 else
1278 {
1279 deUint32 value = (deRandom_getUint32(&rnd) & 0xff) - 128;
1280 setDataInt(ptrs[i], dataTypes[i], j, value);
1281 }
1282 }
1283 }
1284
1285 if (m_data.testType == TT_MATRIXMULADD_SATURATED)
1286 {
1287 // Set 1st row of A to 1,0,0...
1288 setSingleElementInt(ptrs[0], dataTypes[0], 0, dims[0].cols, (m_data.colMajor ? strides[0] : 1), 0, 1);
1289
1290 // Set 1st column of B to 1,0,0...
1291 setSingleElementInt(ptrs[1], dataTypes[1], 0, dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 0, 1);
1292
1293 // Set C element {0,0} to the maximum value of its type, so the addition in D=A*B+C overflows for this element
1294 setDataInt(ptrs[2], dataTypes[2], 0, getLimit(dataTypes[2], true));
1295
1296 // Check underflow if all involved elements support negative values
1297 if (isSIntType(dataTypes[1]) && isSIntType(dataTypes[2]) && isSIntType(dataTypes[3]))
1298 {
1299 // Set 2nd row of A to 0,1,0,0...
1300 setSingleElementInt(ptrs[0], dataTypes[0], (m_data.colMajor ? 1 : strides[0]), dims[0].cols, (m_data.colMajor ? strides[0] : 1), 1, 1);
1301
1302 // Set 2nd column of B to 0,-1,0,0...
1303 setSingleElementInt(ptrs[1], dataTypes[1], (m_data.colMajor ? strides[1] : 1), dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 1, -1);
1304
1305 // Set C element {1,1} to the minimum value of its type, so the addition in D=A*B+C underflows for this element
1306 setDataInt(ptrs[2], dataTypes[2], strides[2] + 1, getLimit(dataTypes[2], false));
1307 }
1308 }
1309
1310 flushAlloc(vk, device, buffers[0]->getAllocation());
1311 flushAlloc(vk, device, buffers[1]->getAllocation());
1312 flushAlloc(vk, device, buffers[2]->getAllocation());
1313 flushAlloc(vk, device, buffers[3]->getAllocation());
1314
1315 ComputePipelineWrapper pipeline(vk, device, m_data.computePipelineConstructionType, m_context.getBinaryCollection().get("test"));
1316 pipeline.setDescriptorSetLayout(descriptorSetLayout.get());
1317 pipeline.setSpecializationInfo(specInfo);
1318 pipeline.setSubgroupSize(m_data.subgroupSizeMode == SUBGROUP_SIZE_NONE ? 0 : getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode));
1319 pipeline.buildPipeline();
1320
1321 const VkQueue queue = m_context.getUniversalQueue();
1322 Move<VkCommandPool> cmdPool = createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
1323 Move<VkCommandBuffer> cmdBuffer = allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1324
1325 beginCommandBuffer(vk, *cmdBuffer, 0u);
1326
1327 vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, pipeline.getPipelineLayout(), 0u, 1, &*descriptorSet, 0u, DE_NULL);
1328 pipeline.bind(*cmdBuffer);
1329
1330 vk.cmdDispatch(*cmdBuffer, m_data.workgroupsX, m_data.workgroupsY, 1);
1331
1332 endCommandBuffer(vk, *cmdBuffer);
1333
1334 submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1335
1336 invalidateAlloc(vk, device, buffers[3]->getAllocation());
1337
1338 qpTestResult res = QP_TEST_RESULT_PASS;
1339
1340 if (m_data.testType == TT_CONVERT)
1341 {
1342 for (deUint32 i = 0; i < totalElements[3]; ++i)
1343 {
1344 // Store results as double, which has enough range to hold all the other types exactly.
1345 double inputA, output;
1346
1347 // This loads the data according to dataTypes[0], and then converts to the template parameter type
1348 switch (dataTypes[3]) {
1349 case VK_COMPONENT_TYPE_UINT8_KHR: inputA = getDataConvertedToT<uint8_t>(ptrs[0], dataTypes[0], i); break;
1350 case VK_COMPONENT_TYPE_UINT16_KHR: inputA = getDataConvertedToT<uint16_t>(ptrs[0], dataTypes[0], i); break;
1351 case VK_COMPONENT_TYPE_UINT32_KHR: inputA = getDataConvertedToT<uint32_t>(ptrs[0], dataTypes[0], i); break;
1352 case VK_COMPONENT_TYPE_SINT8_KHR: inputA = getDataConvertedToT<int8_t>(ptrs[0], dataTypes[0], i); break;
1353 case VK_COMPONENT_TYPE_SINT16_KHR: inputA = getDataConvertedToT<int16_t>(ptrs[0], dataTypes[0], i); break;
1354 case VK_COMPONENT_TYPE_SINT32_KHR: inputA = getDataConvertedToT<int32_t>(ptrs[0], dataTypes[0], i); break;
1355 case VK_COMPONENT_TYPE_FLOAT32_KHR: inputA = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i); break;
1356 case VK_COMPONENT_TYPE_FLOAT16_KHR:
1357 {
1358 float temp = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
1359 inputA = deFloat16To32(deFloat32To16(temp));
1360 break;
1361 }
1362 default: TCU_THROW(InternalError, "Unexpected type");
1363 }
1364
1365 switch (dataTypes[3]) {
1366 case VK_COMPONENT_TYPE_UINT8_KHR: output = getDataConvertedToT<uint8_t>(ptrs[3], dataTypes[3], i); break;
1367 case VK_COMPONENT_TYPE_UINT16_KHR: output = getDataConvertedToT<uint16_t>(ptrs[3], dataTypes[3], i); break;
1368 case VK_COMPONENT_TYPE_UINT32_KHR: output = getDataConvertedToT<uint32_t>(ptrs[3], dataTypes[3], i); break;
1369 case VK_COMPONENT_TYPE_SINT8_KHR: output = getDataConvertedToT<int8_t>(ptrs[3], dataTypes[3], i); break;
1370 case VK_COMPONENT_TYPE_SINT16_KHR: output = getDataConvertedToT<int16_t>(ptrs[3], dataTypes[3], i); break;
1371 case VK_COMPONENT_TYPE_SINT32_KHR: output = getDataConvertedToT<int32_t>(ptrs[3], dataTypes[3], i); break;
1372 case VK_COMPONENT_TYPE_FLOAT32_KHR: output = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i); break;
1373 case VK_COMPONENT_TYPE_FLOAT16_KHR:
1374 {
1375 float temp = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
1376 output = deFloat16To32(deFloat32To16(temp));
1377 break;
1378 }
1379 default: TCU_THROW(InternalError, "Unexpected type");
1380 }
1381
1382 if (inputA != output) {
1383 res = QP_TEST_RESULT_FAIL;
1384 break;
1385 }
1386 }
1387 }
1388 else if (isFloatType(dataTypes[0]))
1389 {
1390 if (!isMatrixMulAddOp(m_data.testType))
1391 {
1392 for (deUint32 i = 0; i < totalElements[3]; ++i)
1393 {
1394 float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
1395 float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
1396 float output = getDataFloat(ptrs[3], dataTypes[3], i);
1397 switch (m_data.testType)
1398 {
1399 case TT_LENGTH:
1400 if (output < 1.0f || output > (float)(N*M))
1401 res = QP_TEST_RESULT_FAIL;
1402 // We expect the matrix to be spread evenly across invocations; it is
1403 // surprising (but not necessarily illegal) if it is not
1404 if (output != (float)(N*M/subgroupSize) &&
1405 res == QP_TEST_RESULT_PASS)
1406 res = QP_TEST_RESULT_QUALITY_WARNING;
1407 break;
1408 case TT_CONSTANT:
1409 if (output != 1.0f)
1410 res = QP_TEST_RESULT_FAIL;
1411 break;
1412 case TT_COMPOSITE:
1413 case TT_COMPOSITE_RVALUE:
1414 case TT_COMPOSITE_ARRAY:
1415 case TT_ADD:
1416 if (output != inputA + inputB)
1417 res = QP_TEST_RESULT_FAIL;
1418 break;
1419 case TT_SUB:
1420 if (output != inputA - inputB)
1421 res = QP_TEST_RESULT_FAIL;
1422 break;
1423 case TT_DIV:
1424 {
1425 float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR) ? 1.0f/1024.0f : 1.0f/(8.0f*1024.0f*1024.0f);
1426 // One ULP near 1.0 is 2^-10 for fp16 and 2^-23 for fp32; division is allowed 2.5 ULP of error, so use 3 ULP to leave some margin.
1427 ulp *= 3;
1428 if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
1429 res = QP_TEST_RESULT_FAIL;
1430 }
1431 break;
1432 case TT_MUL:
1433 {
1434 if (dataTypes[0] == VK_COMPONENT_TYPE_FLOAT16_KHR)
1435 {
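// For fp16 inputs, form the product at fp32 precision and round it to fp16, matching the
// rounding applied when the implementation stores an fp16 result.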
1436 const float expected32 = inputA * inputB;
1437 const deFloat16 expected16 = deFloat32To16(expected32);
1438 const float expected = deFloat16To32(expected16);
1439
1440 if (output != expected)
1441 res = QP_TEST_RESULT_FAIL;
1442 }
1443 else
1444 {
1445 if (output != inputA * inputB)
1446 res = QP_TEST_RESULT_FAIL;
1447 }
1448 break;
1449 }
1450 case TT_NEGATE:
1451 case TT_FUNC:
1452 if (output != -inputA)
1453 res = QP_TEST_RESULT_FAIL;
1454 break;
1455 case TT_MATRIXTIMESSCALAR:
1456 if (output != 6.0*inputA)
1457 res = QP_TEST_RESULT_FAIL;
1458 break;
1459 default:
1460 break;
1461 }
1462 }
1463 }
1464 else
1465 {
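// Reference check for D = A*B + C: recompute every D[i][j] on the CPU. The ik/kj/ij indices
// rebuild the linear element offsets from the per-matrix strides (strides[]) and load strides
// (loadStrides[]), which differ between row-major and column-major layouts and in the
// stride==0 variant.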
1466 deUint32 ik, kj, ij;
1467 for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
1468 {
1469 for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
1470 {
1471 for (deUint32 i = 0; i < M; ++i)
1472 {
1473 for (deUint32 j = 0; j < N; ++j)
1474 {
1475 float ref = 0;
1476 for (deUint32 k = 0; k < K; ++k)
1477 {
1478 if (m_data.colMajor)
1479 ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
1480 else
1481 ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;
1482
1483 float Aik = getDataFloat(ptrs[0], dataTypes[0], ik);
1484
1485 if (m_data.colMajor)
1486 kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
1487 else
1488 kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;
1489
1490 float Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);
1491
1492 ref += Aik*Bkj;
1493 }
1494
1495 if (m_data.colMajor)
1496 ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
1497 else
1498 ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;
1499
1500 float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);
1501
1502 ref += Cij;
1503
1504 // When loading with stride 0, ij for matrix D is different from matrix C
1505 if (m_data.colMajor)
1506 ij = mX * M + i + strides[2] * (mY * N + j);
1507 else
1508 ij = mX * N + j + strides[2] * (mY * M + i);
1509
1510 float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);
1511
1512 if (fabs(ref - Dij) > epsilon)
1513 {
1514 res = QP_TEST_RESULT_FAIL;
1515 }
1516 }
1517 }
1518 }
1519 }
1520 }
1521 }
else
{
1522 if (!isMatrixMulAddOp(m_data.testType))
1523 {
1524 for (deUint32 i = 0; i < totalElements[3]; ++i)
1525 {
1526 deUint32 inputA = getDataInt(ptrs[0], dataTypes[0], i);
1527 deUint32 inputB = getDataInt(ptrs[1], dataTypes[1], i);
1528 deUint32 output = getDataInt(ptrs[3], dataTypes[3], i);
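// Integer references are computed at 32 bits and masked down to the width of the result
// component type, so 8- and 16-bit results are compared modulo 2^bits (i.e. with wrapping).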
1529 int resultSize = componentTypeInfo[dataTypes[3]].bits;
1530 deUint32 mask = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
1531 switch (m_data.testType)
1532 {
1533 case TT_LENGTH:
1534 if (output < 1 || output > N*M)
1535 res = QP_TEST_RESULT_FAIL;
1536 // We expect the matrix to be spread evenly across invocations; it is
1537 // surprising (but not necessarily illegal) if it is not.
1538 if (output != N*M/subgroupSize &&
1539 res == QP_TEST_RESULT_PASS)
1540 res = QP_TEST_RESULT_QUALITY_WARNING;
1541 break;
1542 case TT_CONSTANT:
1543 if (output != 1)
1544 res = QP_TEST_RESULT_FAIL;
1545 break;
1546 case TT_COMPOSITE:
1547 case TT_COMPOSITE_RVALUE:
1548 case TT_COMPOSITE_ARRAY:
1549 case TT_ADD:
1550 if ((output & mask) != ((inputA + inputB) & mask)) {
1551 res = QP_TEST_RESULT_FAIL;
1552 }
1553 break;
1554 case TT_SUB:
1555 if ((output & mask) != ((inputA - inputB) & mask))
1556 res = QP_TEST_RESULT_FAIL;
1557 break;
1558 case TT_DIV:
1559 {
1560 if (isSIntType(dataTypes[3]))
1561 {
1562 if (inputB != 0 && ((deInt32)output & mask) != (((deInt32)inputA / (deInt32)inputB) & mask))
1563 res = QP_TEST_RESULT_FAIL;
1564 } else
1565 {
1566 if (inputB != 0 && output != inputA / inputB)
1567 res = QP_TEST_RESULT_FAIL;
1568 }
1569 }
1570 break;
1571 case TT_MUL:
1572 {
1573 if (((deInt32)output & mask) != (((deInt32)inputA * (deInt32)inputB) & mask))
1574 {
1575 res = QP_TEST_RESULT_FAIL;
1576 }
1577
1578 break;
1579 }
1580 case TT_NEGATE:
1581 case TT_FUNC:
1582 if ((output & mask) != ((-(deInt32)inputA) & mask))
1583 res = QP_TEST_RESULT_FAIL;
1584 break;
1585 case TT_MATRIXTIMESSCALAR:
1586 if ((output & mask) != ((6*inputA) & mask)) {
1587 res = QP_TEST_RESULT_FAIL;
1588 }
1589 break;
1590 default:
1591 break;
1592 }
1593 }
1594 }
1595 else
1596 {
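// Integer reference for D = A*B + C, analogous to the float path above. The accumulation of C
// either saturates (the matrixmuladd_saturated variant) or wraps to the width of the result type.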
1597 deUint32 ik, kj, ij;
1598 for (deUint32 mX = 0; mX < m_data.subgroupsPerWorkgroupX*m_data.workgroupsX; ++mX)
1599 {
1600 for (deUint32 mY = 0; mY < m_data.subgroupsPerWorkgroupY*m_data.workgroupsY; ++mY)
1601 {
1602 for (deUint32 i = 0; i < M; ++i)
1603 {
1604 for (deUint32 j = 0; j < N; ++j)
1605 {
1606 deUint32 ref = 0;
1607
1608 for (deUint32 k = 0; k < K; ++k)
1609 {
1610 if (m_data.colMajor)
1611 ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
1612 else
1613 ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;
1614
1615 deUint32 Aik = getDataInt(ptrs[0], dataTypes[0], ik);
1616
1617 if (m_data.colMajor)
1618 kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
1619 else
1620 kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;
1621
1622 deUint32 Bkj = getDataInt(ptrs[1], dataTypes[1], kj);
1623
1624 ref += Aik*Bkj;
1625 }
1626
1627 if (m_data.colMajor)
1628 ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
1629 else
1630 ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;
1631
1632 deUint32 Cij = getDataInt(ptrs[2], dataTypes[2], ij);
1633
1634 if (saturated)
1635 {
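// Saturated variant: the final accumulation into C clamps to the range of the result
// component type instead of wrapping; satAddData performs that clamped add.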
1636 ref = satAddData(dataTypes[2], ref, Cij);
1637 }
1638 else
1639 {
1640 ref += Cij;
1641 // Truncate the result to the width of the result component type (C and D use the same component type here).
1642 deUint32 bits = componentTypeInfo[dataTypes[3]].bits;
1643 deUint32 mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
1644 ref &= mask;
1645 }
1646
1647 // When loading with stride 0, ij for matrix D is different from matrix C
1648 if (m_data.colMajor)
1649 ij = mX * M + i + strides[2] * (mY * N + j);
1650 else
1651 ij = mX * N + j + strides[2] * (mY * M + i);
1652
1653 deUint32 Dij = getDataInt(ptrs[3], dataTypes[3], ij);
1654
1655 if (ref != Dij)
1656 {
1657 res = QP_TEST_RESULT_FAIL;
1658 }
1659 }
1660 }
1661 }
1662 }
1663 }
1664 }
1665
1666 if (res != QP_TEST_RESULT_PASS)
1667 {
1668 finalres = res;
1669
1670 log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K << tcu::TestLog::EndMessage;
1671
1672 #ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
1673 for (int i = 0; i < 4; i++)
1674 {
1675 const char* matrixNames[] = { "A", "B", "C", "D" };
1676
1677 log << tcu::TestLog::Message
1678 << "Matrix " << matrixNames[i]
1679 << "[rows="
1680 << m_data.subgroupsPerWorkgroupY * m_data.workgroupsY * dims[i].rows
1681 << ", cols="
1682 << m_data.subgroupsPerWorkgroupX * m_data.workgroupsX * dims[i].cols << "]:\n"
1683 << dumpWholeMatrix(ptrs[i], dataTypes[i], m_data.colMajor, totalElements[i], strides[i])
1684 << tcu::TestLog::EndMessage;
1685 }
1686 #endif
1687 }
1688 else
1689 {
1690 if (finalres == QP_TEST_RESULT_NOT_SUPPORTED)
1691 finalres = res;
1692 }
1693 }
1694
1695 return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
1696 }
1697
1698 const char* getUseType (UseType useType)
1699 {
1700 switch (useType)
1701 {
1702 case UT_NV: return "nv";
1703 case UT_KHR_A: return "khr_a";
1704 case UT_KHR_B: return "khr_b";
1705 case UT_KHR_Result: return "khr_r";
1706 default: TCU_THROW(InternalError, "Unknown use type");
1707 }
1708 }
1709
1710 tcu::TestCaseGroup* createCooperativeMatrixTestsInternal (tcu::TestContext& testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType, UseType useType)
1711 {
1712 de::MovePtr<tcu::TestCaseGroup> group (new tcu::TestCaseGroup(testCtx, getUseType(useType)));
1713
1714 typedef struct
1715 {
1716 deUint32 value;
1717 const char* name;
1718 } TestGroupCase;
1719
1720 typedef struct
1721 {
1722 deUint32 value[2];
1723 const char* name;
1724 } TestGroupCase2;
1725
1726 typedef struct
1727 {
1728 SubgroupSizeMode value;
1729 const char* name;
1730 } SubGroupSizes;
1731
1732 TestGroupCase ttCases[] =
1733 {
1734 // OpCooperativeMatrixLength
1735 { TT_LENGTH, "length"},
1736 // OpConstantComposite
1737 { TT_CONSTANT, "constant"},
1738 // OpCompositeConstruct
1739 { TT_COMPOSITE, "composite"},
1740 // OpCompositeExtract
1741 { TT_COMPOSITE_RVALUE, "composite_rvalue"},
1742 // OpFAdd/OpIAdd
1743 { TT_ADD, "add"},
1744 // OpFSub/OpISub
1745 { TT_SUB, "sub"},
1746 // OpFDiv/OpSDiv/OpUDiv
1747 { TT_DIV, "div"},
1748 // OpFMul/OpIMul
1749 { TT_MUL, "mul"},
1750 // OpFNegate/OpSNegate
1751 { TT_NEGATE, "negate"},
1752 // OpMatrixTimesScalar
1753 { TT_MATRIXTIMESSCALAR, "matrixtimesscalar"},
1754 // OpFunctionParameter
1755 { TT_FUNC, "func"},
1756 // OpCooperativeMatrixMulAdd
1757 { TT_MATRIXMULADD, "matrixmuladd"},
1758 // OpCompositeConstruct w/array
1759 { TT_COMPOSITE_ARRAY, "composite_array"},
1760 // OpCooperativeMatrixMulAdd w/array
1761 { TT_MATRIXMULADD_ARRAY, "matrixmuladd_array"},
1762 // OpCooperativeMatrixMulAdd w/saturations
1763 { TT_MATRIXMULADD_SATURATED,"matrixmuladd_saturated"},
1764 // OpCooperativeMatrixMulAdd w/wrapping
1765 { TT_MATRIXMULADD_WRAPPING, "matrixmuladd_wrapping"},
1766 // OpCooperativeMatrixMulAdd w/stride==0
1767 { TT_MATRIXMULADD_STRIDE0, "matrixmuladd_stride0"},
1768 };
1769 TestGroupCase2 dtCases[] =
1770 {
1771 // A/B are fp32 C/D are fp32
1772 { { VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR }, "float32_float32"},
1773 // A/B are fp32 C/D are fp16
1774 { { VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_FLOAT16_KHR }, "float32_float16"},
1775 // A/B are fp16 C/D are fp32
1776 { { VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR }, "float16_float32"},
1777 // A/B are fp16 C/D are fp16
1778 { { VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT16_KHR }, "float16_float16"},
1779 // A/B are u8 C/D are u8
1780 { { VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_UINT8_KHR }, "uint8_uint8"},
1781 // A/B are u8 C/D are u32
1782 { { VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_UINT32_KHR }, "uint8_uint32"},
1783 // A/B are s8 C/D are s8
1784 { { VK_COMPONENT_TYPE_SINT8_KHR, VK_COMPONENT_TYPE_SINT8_KHR }, "sint8_sint8"},
1785 // A/B are s8 C/D are s32
1786 { { VK_COMPONENT_TYPE_SINT8_KHR, VK_COMPONENT_TYPE_SINT32_KHR }, "sint8_sint32"},
1787 // A/B are u8 C/D are s32
1788 { { VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_SINT32_KHR }, "uint8_sint32"},
1789 // A/B are u32 C/D are u32
1790 { { VK_COMPONENT_TYPE_UINT32_KHR, VK_COMPONENT_TYPE_UINT32_KHR }, "uint32_uint32"},
1791 // A/B are u32 C/D are u8
1792 { { VK_COMPONENT_TYPE_UINT32_KHR, VK_COMPONENT_TYPE_UINT8_KHR }, "uint32_uint8"},
1793 // A/B are s32 C/D are s32
1794 { { VK_COMPONENT_TYPE_SINT32_KHR, VK_COMPONENT_TYPE_SINT32_KHR }, "sint32_sint32"},
1795 // A/B are s32 C/D are s8
1796 { { VK_COMPONENT_TYPE_SINT32_KHR, VK_COMPONENT_TYPE_SINT8_KHR }, "sint32_sint8"},
1797 };
1798 SubGroupSizes sgsCases[] =
1799 {
1800 // Default subgroup size
1801 { SUBGROUP_SIZE_NONE, "" },
1802 // Minimum subgroup size
1803 { SUBGROUP_SIZE_MIN, "_min"},
1804 // Maximum subgroup size
1805 { SUBGROUP_SIZE_MAX, "_max"},
1806 };
1807
1808 TestGroupCase colCases[] =
1809 {
1810 { 0, "rowmajor"},
1811 { 1, "colmajor"},
1812 };
1813
1814 TestGroupCase scCases[] =
1815 {
1816 // SSBO
1817 { SC_BUFFER, "buffer"},
1818 // shared memory
1819 { SC_WORKGROUP, "workgroup"},
1820 // SSBO w/variable pointers
1821 { SC_BUFFER_VARIABLE_POINTERS, "buffer_varptr"},
1822 // shared memory w/variable pointers
1823 { SC_WORKGROUP_VARIABLE_POINTERS, "workgroup_varptr"},
1824 // physical_storage_buffer
1825 { SC_PHYSICAL_STORAGE_BUFFER, "physical_buffer"},
1826 };
1827
1828 // Types tested for conversions. Excludes 64b types.
1829 VkComponentTypeKHR allTypes[] =
1830 {
1831 VK_COMPONENT_TYPE_FLOAT16_KHR,
1832 VK_COMPONENT_TYPE_FLOAT32_KHR,
1833 VK_COMPONENT_TYPE_SINT8_KHR,
1834 VK_COMPONENT_TYPE_SINT16_KHR,
1835 VK_COMPONENT_TYPE_SINT32_KHR,
1836 VK_COMPONENT_TYPE_UINT8_KHR,
1837 VK_COMPONENT_TYPE_UINT16_KHR,
1838 VK_COMPONENT_TYPE_UINT32_KHR,
1839 };
1840
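// Test hierarchy under this use-type group: <testtype[_min|_max]> / <A-B type>_<C-D type> / <storage class> / <rowmajor|colmajor>.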
1841 for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
1842 {
1843 const TestType testType = (TestType)ttCases[ttNdx].value;
1844
1845 for (int sgsNdx = 0; sgsNdx < DE_LENGTH_OF_ARRAY(sgsCases); sgsNdx++)
1846 {
1847 if (testType != TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE)
1848 continue;
1849
1850 if (testType == TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE && useType == UT_NV)
1851 continue;
1852
1853 const string name = string(ttCases[ttNdx].name) + sgsCases[sgsNdx].name;
1854 de::MovePtr<tcu::TestCaseGroup> ttGroup (new tcu::TestCaseGroup(testCtx, name.c_str()));
1855
1856 for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
1857 {
1858 de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name));
1859 for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
1860 {
1861 de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
1862 for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
1863 {
1864 const VkComponentTypeKHR inputType = (VkComponentTypeKHR)dtCases[dtNdx].value[0];
1865 const VkComponentTypeKHR outputType = (VkComponentTypeKHR)dtCases[dtNdx].value[1];
1866 const bool isMatrixMul = isMatrixMulAddOp(testType);
1867
1868 // useType isn't used for matrixmul shaders. Don't generate 3 copies of those tests.
1869 if (isMatrixMul && (useType == UT_KHR_A || useType == UT_KHR_B)) {
1870 continue;
1871 }
1872
1873 // NV extension doesn't support mixing signedness
1874 if (isMatrixMul && (useType == UT_NV) && isSIntType(inputType) != isSIntType(outputType)) {
1875 continue;
1876 }
1877
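// Tests other than matrix multiply-add use a single component type for all matrices.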
1878 if (!isMatrixMul && inputType != outputType)
1879 continue;
1880
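// For multiply-add, only generate cases where the result type is at least as wide as the input type.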
1881 if (isMatrixMul && componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
1882 continue;
1883
1884 if (testType == TT_MUL && useType == UT_NV)
1885 continue;
1886
1887 if (testType == TT_MATRIXMULADD_SATURATED && (isFloatType(inputType) || useType == UT_NV))
1888 continue;
1889
1890 if (testType == TT_MATRIXMULADD_WRAPPING && (isFloatType(inputType) || useType == UT_NV))
1891 continue;
1892
1893 if (testType == TT_MATRIXMULADD_STRIDE0 && useType == UT_NV)
1894 continue;
1895
1896 if (testType == TT_LENGTH && useType != UT_NV && (outputType == VK_COMPONENT_TYPE_SINT8_KHR || outputType == VK_COMPONENT_TYPE_UINT8_KHR))
1897 continue;
1898
1899 CaseDef c =
1900 {
1901 testType, // TestType testtype;
1902 2u, // deUint32 subgroupsPerWorkgroupX;
1903 2u, // deUint32 subgroupsPerWorkgroupY;
1904 4u, // deUint32 workgroupsX;
1905 4u, // deUint32 workgroupsY;
1906 inputType, // VkComponentTypeKHR inputType;
1907 outputType, // VkComponentTypeKHR outputType;
1908 !!colCases[colNdx].value, // bool colMajor;
1909 (StorageClass)scCases[scNdx].value, // StorageClass storageClass;
1910 useType, // UseType useType;
1911 sgsCases[sgsNdx].value, // SubgroupSizeMode subgroupSizeMode;
1912 computePipelineConstructionType, // vk::ComputePipelineConstructionType computePipelineConstructionType;
1913 };
1914
1915 scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
1916 }
1917 dtGroup->addChild(scGroup.release());
1918 }
1919 ttGroup->addChild(dtGroup.release());
1920 }
1921 group->addChild(ttGroup.release());
1922 }
1923 }
1924
1925 {
1926 // OpFConvert/OpSConvert/OpUConvert/OpBitcast
1927 const string name = string("convert");
1928 de::MovePtr<tcu::TestCaseGroup> ttGroup (new tcu::TestCaseGroup(testCtx, name.c_str()));
1929
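// Conversion tests cover every ordered pair of the (non-64-bit) component types in allTypes,
// including same-type pairs, for each storage class and matrix layout.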
1930 for (int dtNdx1 = 0; dtNdx1 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx1++)
1931 {
1932 for (int dtNdx2 = 0; dtNdx2 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx2++)
1933 {
1934 const VkComponentTypeKHR inputType = (VkComponentTypeKHR)allTypes[dtNdx1];
1935 const VkComponentTypeKHR outputType = (VkComponentTypeKHR)allTypes[dtNdx2];
1936 const string name2 = string("input_") + string(componentTypeInfo[inputType].typeName) + string("_output_") + string(componentTypeInfo[outputType].typeName);
1937 de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name2.c_str()));
1938 for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
1939 {
1940 de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
1941 for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
1942 {
1943
1944 CaseDef c =
1945 {
1946 TT_CONVERT, // TestType testtype;
1947 2u, // deUint32 subgroupsPerWorkgroupX;
1948 2u, // deUint32 subgroupsPerWorkgroupY;
1949 4u, // deUint32 workgroupsX;
1950 4u, // deUint32 workgroupsY;
1951 inputType, // VkComponentTypeKHR inputType;
1952 outputType, // VkComponentTypeKHR outputType;
1953 !!colCases[colNdx].value, // bool colMajor;
1954 (StorageClass)scCases[scNdx].value, // StorageClass storageClass;
1955 useType, // UseType useType;
1956 SUBGROUP_SIZE_NONE, // SubgroupSizeMode subgroupSizeMode;
1957 computePipelineConstructionType, // vk::ComputePipelineConstructionType computePipelineConstructionType;
1958 };
1959
1960 scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
1961 }
1962 dtGroup->addChild(scGroup.release());
1963 }
1964 ttGroup->addChild(dtGroup.release());
1965 }
1966 }
1967 group->addChild(ttGroup.release());
1968 }
1969
1970 return group.release();
1971 }
1972
1973 } // anonymous
1974
1975 tcu::TestCaseGroup* createCooperativeMatrixTests (tcu::TestContext& testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType)
1976 {
1977 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "cooperative_matrix"));
1978
1979 group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_NV));
1980 group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_A));
1981 group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_B));
1982 group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_Result));
1983
1984 return group.release();
1985 }
1986
1987 } // compute
1988 } // vkt
1989