/*------------------------------------------------------------------------
 * Vulkan Conformance Tests
 * ------------------------
 *
 * Copyright (c) 2019 The Khronos Group Inc.
 * Copyright (c) 2018-2024 NVIDIA Corporation
 * Copyright (c) 2023 LunarG, Inc.
 * Copyright (c) 2023 Nintendo
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief Vulkan Cooperative Matrix tests
 *//*--------------------------------------------------------------------*/

#include "vktComputeCooperativeMatrixTests.hpp"

#include "vkBufferWithMemory.hpp"
#include "vkImageWithMemory.hpp"
#include "vkQueryUtil.hpp"
#include "vkBuilderUtil.hpp"
#include "vkCmdUtil.hpp"
#include "vkTypeUtil.hpp"
#include "vkObjUtil.hpp"

#include "vktTestGroupUtil.hpp"
#include "vktTestCase.hpp"

#include "deDefs.h"
#include "tcuFloat.hpp"
#include "deMath.h"
#include "deRandom.h"
#include "deSharedPtr.hpp"
#include "deString.h"

#include "tcuTestCase.hpp"
#include "tcuTestLog.hpp"
#include "tcuStringTemplate.hpp"

#include <string>
#include <sstream>
#include <set>
#include <algorithm>
#include <functional>

namespace vkt
{
namespace compute
{
namespace
{
using namespace vk;
using namespace std;

//#define COOPERATIVE_MATRIX_EXTENDED_DEBUG 1

DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT16_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT32_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT64_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT64_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT8_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT8_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT16_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT32_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT64_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT64_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT8_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT8_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT16_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT32_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT64_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT64_NV);

DE_STATIC_ASSERT((uint32_t)VK_SCOPE_DEVICE_KHR == (uint32_t)VK_SCOPE_DEVICE_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_WORKGROUP_KHR == (uint32_t)VK_SCOPE_WORKGROUP_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_SUBGROUP_KHR == (uint32_t)VK_SCOPE_SUBGROUP_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_QUEUE_FAMILY_KHR == (uint32_t)VK_SCOPE_QUEUE_FAMILY_NV);

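// UseType selects the extension path under test: UT_NV uses
// VK_NV_cooperative_matrix, while the UT_KHR_* values use
// VK_KHR_cooperative_matrix with a particular matrix use (A, B, C, Result).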
typedef enum
{
    UT_NV = 0,
    UT_KHR_A,
    UT_KHR_B,
    UT_KHR_C,
    UT_KHR_Result,
} UseType;

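// TestType enumerates the shader operation exercised by each test variant,
// ranging from simple loads/stores and arithmetic through matrix-multiply-add
// variants, reductions, per-element operations, and NV_cooperative_matrix2
// tensor layout addressing.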
typedef enum
{
    TT_LENGTH = 0,
    TT_CONSTANT,
    TT_CONVERT,
    TT_CONVERT_ACC_TO_A,
    TT_CONVERT_ACC_TO_B,
    TT_TRANSPOSE_ACC_TO_B,
    TT_REDUCE_SUM_ROW,
    TT_REDUCE_SUM_COL,
    TT_REDUCE_SUM_ROWCOL,
    TT_REDUCE_SUM_2X2,
    TT_REDUCE_SUM_ROW_CHANGEDIM,
    TT_REDUCE_SUM_COL_CHANGEDIM,
    TT_REDUCE_SUM_ROWCOL_CHANGEDIM,
    TT_REDUCE_MIN_ROW,
    TT_REDUCE_MIN_COL,
    TT_REDUCE_MIN_ROWCOL,
    TT_REDUCE_MIN_2X2,
    TT_PER_ELEMENT_OP,
    TT_PER_ELEMENT_OP_ROW_COL,
    TT_PER_ELEMENT_OP_STRUCT,
    TT_PER_ELEMENT_OP_MAT,
    TT_COMPOSITE,
    TT_COMPOSITE_RVALUE,
    TT_ADD,
    TT_SUB,
    TT_DIV,
    TT_MUL,
    TT_NEGATE,
    TT_MATRIXTIMESSCALAR,
    TT_FUNC,
    TT_CLAMPCONSTANT,
    TT_CLAMPTOEDGE,
    TT_CLAMPREPEAT,
    TT_CLAMPMIRRORREPEAT,
    TT_MATRIXMULADD,
    TT_COMPOSITE_ARRAY,
    TT_MATRIXMULADD_ARRAY,
    TT_MATRIXMULADD_SATURATED,
    TT_MATRIXMULADD_WRAPPING,
    TT_MATRIXMULADD_STRIDE0,
    TT_MATRIXMULADD_DEQUANT,
    TT_MULTICOMPONENT_LOAD,
    TT_MULTICOMPONENT_SAVE,
    TT_MATRIXMULADD_CROSS,
    TT_TENSORLAYOUT_1D,
    TT_TENSORLAYOUT_2D,
    TT_TENSORLAYOUT_3D,
    TT_TENSORLAYOUT_4D,
    TT_TENSORLAYOUT_5D,
    TT_TENSORLAYOUT_1D_CLIP,
    TT_TENSORLAYOUT_2D_CLIP,
    TT_TENSORLAYOUT_3D_CLIP,
    TT_TENSORLAYOUT_4D_CLIP,
    TT_TENSORLAYOUT_5D_CLIP,
    TT_SPACETODEPTH,
    TT_CONV,
} TestType;

typedef enum
{
    SC_BUFFER = 0,
    SC_WORKGROUP,
    SC_WORKGROUP_VARIABLE_POINTERS,
    SC_BUFFER_VARIABLE_POINTERS,
    SC_PHYSICAL_STORAGE_BUFFER,
} StorageClass;

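// AddrMethod selects how matrix loads and stores are addressed: ADDR_LINEAR
// uses plain stride-based addressing, while the remaining modes use
// NV_cooperative_matrix2 tensor addressing (ADDR_BLOCKSIZE and ADDR_DECODE
// additionally exercise block loads and a caller-supplied decode function).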
typedef enum
{
    ADDR_LINEAR = 0,
    ADDR_TENSORLAYOUT,
    ADDR_BLOCKSIZE,
    ADDR_DECODE,
} AddrMethod;

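// SubgroupSizeMode optionally pins the subgroup size to the device's minimum
// or maximum as reported by the subgroup size control properties; see
// getSubgroupSizeFromMode() below.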
enum SubgroupSizeMode
{
    SUBGROUP_SIZE_NONE = 0,
    SUBGROUP_SIZE_MIN  = 1,
    SUBGROUP_SIZE_MAX  = 2,
};

const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;

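// CaseDef captures every parameter of a single test variant; one instance is
// stored in each generated test case.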
struct CaseDef
{
    TestType testType;
    VkScopeKHR scope;
    uint32_t subgroupsPerWorkgroupX;
    uint32_t subgroupsPerWorkgroupY;
    uint32_t workgroupsX;
    uint32_t workgroupsY;
    VkComponentTypeKHR inputType;
    VkComponentTypeKHR outputType;
    bool colMajor;
    AddrMethod addrMethod;
    StorageClass storageClass;
    UseType useType;
    SubgroupSizeMode subgroupSizeMode;
    vk::ComputePipelineConstructionType computePipelineConstructionType;
    uint32_t inputComponentCount;
    uint32_t outputComponentCount;
};

bool isKhr(UseType useType)
{
    return useType != UT_NV;
}

bool isMatrixMulAddOp(TestType testType)
{
    return testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY || testType == TT_MATRIXMULADD_SATURATED ||
           testType == TT_MATRIXMULADD_WRAPPING || testType == TT_MATRIXMULADD_STRIDE0 ||
           testType == TT_MATRIXMULADD_CROSS || testType == TT_MATRIXMULADD_DEQUANT;
}

bool isReduceRow(TestType testType)
{
    return testType == TT_REDUCE_SUM_ROW || testType == TT_REDUCE_MIN_ROW || testType == TT_REDUCE_SUM_ROW_CHANGEDIM;
}

bool isReduceCol(TestType testType)
{
    return testType == TT_REDUCE_SUM_COL || testType == TT_REDUCE_MIN_COL || testType == TT_REDUCE_SUM_COL_CHANGEDIM;
}

bool isReduceRowCol(TestType testType)
{
    return testType == TT_REDUCE_SUM_ROWCOL || testType == TT_REDUCE_MIN_ROWCOL ||
           testType == TT_REDUCE_SUM_ROWCOL_CHANGEDIM;
}

bool isReduce2x2(TestType testType)
{
    return testType == TT_REDUCE_SUM_2X2 || testType == TT_REDUCE_MIN_2X2;
}

bool isReduceSum(TestType testType)
{
    return testType == TT_REDUCE_SUM_ROW || testType == TT_REDUCE_SUM_COL || testType == TT_REDUCE_SUM_ROWCOL ||
           testType == TT_REDUCE_SUM_2X2 || testType == TT_REDUCE_SUM_ROW_CHANGEDIM ||
           testType == TT_REDUCE_SUM_COL_CHANGEDIM || testType == TT_REDUCE_SUM_ROWCOL_CHANGEDIM;
}

bool isReduceMin(TestType testType)
{
    return testType == TT_REDUCE_MIN_ROW || testType == TT_REDUCE_MIN_COL || testType == TT_REDUCE_MIN_ROWCOL ||
           testType == TT_REDUCE_MIN_2X2;
}

bool isReduceOp(TestType testType)
{
    return isReduceRow(testType) || isReduceCol(testType) || isReduceRowCol(testType) || isReduce2x2(testType);
}

bool isReduceChangeDim(TestType testType)
{
    return testType == TT_REDUCE_SUM_ROW_CHANGEDIM || testType == TT_REDUCE_SUM_COL_CHANGEDIM ||
           testType == TT_REDUCE_SUM_ROWCOL_CHANGEDIM;
}

uint32_t reduceMScale(TestType testType)
{
    if (testType == TT_REDUCE_SUM_COL_CHANGEDIM || testType == TT_REDUCE_SUM_ROWCOL_CHANGEDIM)
    {
        return 3;
    }
    else
    {
        return 1;
    }
}

uint32_t reduceNScale(TestType testType)
{
    if (testType == TT_REDUCE_SUM_ROW_CHANGEDIM || testType == TT_REDUCE_SUM_ROWCOL_CHANGEDIM)
    {
        return 3;
    }
    else
    {
        return 1;
    }
}

bool isClampTest(TestType testType)
{
    return testType == TT_CLAMPCONSTANT || testType == TT_CLAMPTOEDGE || testType == TT_CLAMPREPEAT ||
           testType == TT_CLAMPMIRRORREPEAT;
}

bool isTensorLayoutClipTest(TestType testType)
{
    return testType == TT_TENSORLAYOUT_1D_CLIP || testType == TT_TENSORLAYOUT_2D_CLIP ||
           testType == TT_TENSORLAYOUT_3D_CLIP || testType == TT_TENSORLAYOUT_4D_CLIP ||
           testType == TT_TENSORLAYOUT_5D_CLIP;
}

bool isTensorLayoutTest(TestType testType)
{
    return testType == TT_TENSORLAYOUT_1D || testType == TT_TENSORLAYOUT_2D || testType == TT_TENSORLAYOUT_3D ||
           testType == TT_TENSORLAYOUT_4D || testType == TT_TENSORLAYOUT_5D || isTensorLayoutClipTest(testType) ||
           testType == TT_SPACETODEPTH;
}

bool isPerElemOp(TestType testType)
{
    return testType == TT_PER_ELEMENT_OP || testType == TT_PER_ELEMENT_OP_ROW_COL ||
           testType == TT_PER_ELEMENT_OP_STRUCT || testType == TT_PER_ELEMENT_OP_MAT;
}

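// Per-dimension test data for the tensor layout tests: tensor dimensions,
// cooperative matrix sizes, spans, and load/store offsets (including negative
// offsets that land partially outside the tensor). The *NumCoords values are
// derived from the lengths of the offset arrays.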
int32_t tensorLayout1dMatrixSize[][5] = {
    {32, 32},
    {64, 64},
};

int32_t tensorLayout1dDim[5] = {65536, 1, 1, 1, 1};

int32_t tensorLayout1dSpan[][5] = {
    {1024},
    {4096},
};

int32_t tensorLayout1dLoadOffsets[][5] = {
    {10000},
    {-1},
};
int32_t tensorLayout1dStoreOffsets[][5] = {
    {-1},
    {4321},
};

uint32_t tensorLayout1dNumCoords = sizeof(tensorLayout1dLoadOffsets) / sizeof(tensorLayout1dLoadOffsets[0]);

int32_t tensorLayout2dMatrixSize[][5] = {
    {32, 32},
    {64, 64},
};

int32_t tensorLayout2dDim[5] = {512, 512, 1, 1, 1};

int32_t tensorLayout2dSpan[][5] = {
    {32, 32},
    {64, 64},
};

int32_t tensorLayout2dLoadOffsets[][5] = {
    {7, 13},
    {0 + 128, 0 + 128},
};
int32_t tensorLayout2dStoreOffsets[][5] = {
    {13, 7},
    {20 + 128, 0},
};

uint32_t tensorLayout2dNumCoords = sizeof(tensorLayout2dLoadOffsets) / sizeof(tensorLayout2dLoadOffsets[0]);

int32_t tensorLayout3dDim[5] = {33, 44, 55, 1, 1};

int32_t tensorLayout3dMatrixSize[][5] = {
    {64, 32},
    {32, 32},
};

int32_t tensorLayout3dSpan[][5] = {
    {16, 16, 8},
    {8, 4, 32},
};
int32_t tensorLayout3dLoadOffsets[][5] = {
    {1, 1, 1},
    {-1, -1, -1},
};
int32_t tensorLayout3dStoreOffsets[][5] = {
    {2, 2, 2},
    {23, 2, 1},
};

uint32_t tensorLayout3dNumCoords = sizeof(tensorLayout3dLoadOffsets) / sizeof(tensorLayout3dLoadOffsets[0]);

int32_t tensorLayout4dDim[5] = {20, 25, 40, 10, 1};

int32_t tensorLayout4dMatrixSize[][5] = {
    {64, 64},
};

int32_t tensorLayout4dSpan[][5] = {
    {16, 8, 8, 4},
};
int32_t tensorLayout4dLoadOffsets[][5] = {
    {-1, -1, -1, -1},
};
int32_t tensorLayout4dStoreOffsets[][5] = {
    {1, 2, 1, 2},
};

uint32_t tensorLayout4dNumCoords = sizeof(tensorLayout4dLoadOffsets) / sizeof(tensorLayout4dLoadOffsets[0]);

int32_t tensorLayout5dDim[5] = {4, 4, 32, 16, 8};

int32_t tensorLayout5dMatrixSize[][5] = {
    {32, 32},
};

int32_t tensorLayout5dSpan[][5] = {
    {1, 4, 8, 4, 8},
};
int32_t tensorLayout5dLoadOffsets[][5] = {
    {-1, -1, -1, -1, -1},
};
int32_t tensorLayout5dStoreOffsets[][5] = {
    {1, 2, 1, 0, 1},
};

uint32_t tensorLayout5dNumCoords = sizeof(tensorLayout5dLoadOffsets) / sizeof(tensorLayout5dLoadOffsets[0]);

int32_t *GetTensorLayoutMatrixSizes(uint32_t dim, uint32_t index)
{
    switch (dim)
    {
    case 1:
        return tensorLayout1dMatrixSize[index];
    case 2:
        return tensorLayout2dMatrixSize[index];
    case 3:
        return tensorLayout3dMatrixSize[index];
    case 4:
        return tensorLayout4dMatrixSize[index];
    case 5:
        return tensorLayout5dMatrixSize[index];
    }
    DE_ASSERT(0);
    return nullptr;
}

int32_t *GetTensorLayoutDim(uint32_t dim)
{
    switch (dim)
    {
    case 1:
        return tensorLayout1dDim;
    case 2:
        return tensorLayout2dDim;
    case 3:
        return tensorLayout3dDim;
    case 4:
        return tensorLayout4dDim;
    case 5:
        return tensorLayout5dDim;
    }
    DE_ASSERT(0);
    return nullptr;
}

int32_t *GetTensorLayoutSpan(uint32_t dim, uint32_t index)
{
    switch (dim)
    {
    case 1:
        return tensorLayout1dSpan[index];
    case 2:
        return tensorLayout2dSpan[index];
    case 3:
        return tensorLayout3dSpan[index];
    case 4:
        return tensorLayout4dSpan[index];
    case 5:
        return tensorLayout5dSpan[index];
    }
    DE_ASSERT(0);
    return nullptr;
}

int32_t *GetTensorLayoutLoadOffsets(uint32_t dim, uint32_t index)
{
    switch (dim)
    {
    case 1:
        return tensorLayout1dLoadOffsets[index];
    case 2:
        return tensorLayout2dLoadOffsets[index];
    case 3:
        return tensorLayout3dLoadOffsets[index];
    case 4:
        return tensorLayout4dLoadOffsets[index];
    case 5:
        return tensorLayout5dLoadOffsets[index];
    }
    DE_ASSERT(0);
    return nullptr;
}

int32_t *GetTensorLayoutStoreOffsets(uint32_t dim, uint32_t index)
{
    switch (dim)
    {
    case 1:
        return tensorLayout1dStoreOffsets[index];
    case 2:
        return tensorLayout2dStoreOffsets[index];
    case 3:
        return tensorLayout3dStoreOffsets[index];
    case 4:
        return tensorLayout4dStoreOffsets[index];
    case 5:
        return tensorLayout5dStoreOffsets[index];
    }
    DE_ASSERT(0);
    return nullptr;
}

uint32_t GetTensorLayoutNumCoords(uint32_t dim)
{
    switch (dim)
    {
    case 1:
        return tensorLayout1dNumCoords;
    case 2:
        return tensorLayout2dNumCoords;
    case 3:
        return tensorLayout3dNumCoords;
    case 4:
        return tensorLayout4dNumCoords;
    case 5:
        return tensorLayout5dNumCoords;
    }
    DE_ASSERT(0);
    return 0;
}

uint32_t GetDim(TestType testType)
{
    switch (testType)
    {
    case TT_TENSORLAYOUT_1D:
        return 1;
    case TT_TENSORLAYOUT_2D:
        return 2;
    case TT_TENSORLAYOUT_3D:
        return 3;
    case TT_TENSORLAYOUT_4D:
        return 4;
    case TT_TENSORLAYOUT_5D:
        return 5;
    case TT_TENSORLAYOUT_1D_CLIP:
        return 1;
    case TT_TENSORLAYOUT_2D_CLIP:
        return 2;
    case TT_TENSORLAYOUT_3D_CLIP:
        return 3;
    case TT_TENSORLAYOUT_4D_CLIP:
        return 4;
    case TT_TENSORLAYOUT_5D_CLIP:
        return 5;
    default:
        DE_ASSERT(0);
        return 0;
    }
}

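// Block dimensions (rows x columns) shared by the ADDR_BLOCKSIZE and
// ADDR_DECODE addressing tests and the dequantization decode function below.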
static constexpr uint32_t blockSize[2] = {2, 4};

template <typename T>
VkResult getCooperativeMatrixProperties(const InstanceInterface &, VkPhysicalDevice, uint32_t *, T *)
{
    TCU_THROW(InternalError, "Not implemented");
}

VkResult getCooperativeMatrixProperties(const InstanceInterface &vki, VkPhysicalDevice physicalDevice,
                                        uint32_t *pPropertyCount, VkCooperativeMatrixPropertiesKHR *pProperties)
{
    return vki.getPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, pPropertyCount, pProperties);
}

VkResult getCooperativeMatrixProperties(const InstanceInterface &vki, VkPhysicalDevice physicalDevice,
                                        uint32_t *pPropertyCount, VkCooperativeMatrixPropertiesNV *pProperties)
{
    return vki.getPhysicalDeviceCooperativeMatrixPropertiesNV(physicalDevice, pPropertyCount, pProperties);
}

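// Convert the NV property struct to its KHR equivalent so the rest of the
// code can treat both extension paths uniformly. The NV struct has no
// saturatingAccumulation field, so it is reported as VK_FALSE.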
VkCooperativeMatrixPropertiesKHR convertCooperativeMatrixProperties(const VkCooperativeMatrixPropertiesNV &properties)
{
    VkCooperativeMatrixPropertiesKHR result = initVulkanStructure();

    result.sType                  = (VkStructureType)properties.sType;
    result.pNext                  = (void *)properties.pNext;
    result.MSize                  = (uint32_t)properties.MSize;
    result.NSize                  = (uint32_t)properties.NSize;
    result.KSize                  = (uint32_t)properties.KSize;
    result.AType                  = (VkComponentTypeKHR)properties.AType;
    result.BType                  = (VkComponentTypeKHR)properties.BType;
    result.CType                  = (VkComponentTypeKHR)properties.CType;
    result.ResultType             = (VkComponentTypeKHR)properties.DType;
    result.saturatingAccumulation = (VkBool32)VK_FALSE;
    result.scope                  = (VkScopeKHR)properties.scope;

    return result;
}

std::vector<VkCooperativeMatrixPropertiesKHR> convertCooperativeMatrixProperties(
    const std::vector<VkCooperativeMatrixPropertiesNV> &properties)
{
    std::vector<VkCooperativeMatrixPropertiesKHR> result(properties.size());

    for (size_t i = 0; i < properties.size(); ++i)
        result[i] = convertCooperativeMatrixProperties(properties[i]);

    return result;
}

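// Query all cooperative matrix properties using the usual Vulkan two-call
// enumeration pattern: first fetch the count, then fill the array.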
template <typename T>
void getCooperativeMatrixPropertiesAll(Context &context, std::vector<T> &properties)
{
    uint32_t propertyCount = 0;

    VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount,
                                            (T *)nullptr));

    if (propertyCount > 0)
    {
        const T sample = initVulkanStructureConst();

        properties.resize(propertyCount, sample);

        VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(),
                                                &propertyCount, properties.data()));
    }
    else
    {
        properties.clear();
    }
}

std::vector<VkCooperativeMatrixPropertiesKHR> getCooperativeMatrixPropertiesConverted(Context &context, const bool khr)
{
    std::vector<VkCooperativeMatrixPropertiesKHR> properties;

    if (khr)
    {
        getCooperativeMatrixPropertiesAll(context, properties);
    }
    else
    {
        std::vector<VkCooperativeMatrixPropertiesNV> propertiesNV;

        getCooperativeMatrixPropertiesAll(context, propertiesNV);

        properties = convertCooperativeMatrixProperties(propertiesNV);
    }

    return properties;
}

uint32_t getSubgroupSizeFromMode(Context &context, const SubgroupSizeMode subgroupSizeMode)
{
#ifndef CTS_USES_VULKANSC
    const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
        context.getSubgroupSizeControlProperties();
#else
    const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
        context.getSubgroupSizeControlProperties();
#endif // CTS_USES_VULKANSC

    switch (subgroupSizeMode)
    {
    case SUBGROUP_SIZE_MAX:
        return subgroupSizeControlProperties.maxSubgroupSize;
    case SUBGROUP_SIZE_MIN:
        return subgroupSizeControlProperties.minSubgroupSize;
    case SUBGROUP_SIZE_NONE:
        return context.getSubgroupProperties().subgroupSize;
    default:
        TCU_THROW(NotSupportedError, "Unsupported Subgroup size");
    }
}

class CooperativeMatrixTestInstance : public TestInstance
{
public:
    CooperativeMatrixTestInstance(Context &context, const CaseDef &data);
    ~CooperativeMatrixTestInstance(void);
    tcu::TestStatus iterate(void);

private:
    CaseDef m_data;
};

CooperativeMatrixTestInstance::CooperativeMatrixTestInstance(Context &context, const CaseDef &data)
    : vkt::TestInstance(context)
    , m_data(data)
{
}

CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance(void)
{
}

class CooperativeMatrixTestCase : public TestCase
{
public:
    CooperativeMatrixTestCase(tcu::TestContext &context, const char *name, const CaseDef data);
    ~CooperativeMatrixTestCase(void);
    virtual void initPrograms(SourceCollections &programCollection) const;
    virtual TestInstance *createInstance(Context &context) const;
    virtual void checkSupport(Context &context) const;

private:
    virtual void initProgramsGLSL(SourceCollections &programCollection) const;
    virtual void initProgramsSPIRV(SourceCollections &programCollection) const;
    CaseDef m_data;
};

CooperativeMatrixTestCase::CooperativeMatrixTestCase(tcu::TestContext &context, const char *name, const CaseDef data)
    : vkt::TestCase(context, name)
    , m_data(data)
{
}

CooperativeMatrixTestCase::~CooperativeMatrixTestCase(void)
{
}

void CooperativeMatrixTestCase::checkSupport(Context &context) const
{
    if (!context.contextSupports(vk::ApiVersion(0, 1, 1, 0)))
    {
        TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
    }

    if (isKhr(m_data.useType))
    {
        if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
        {
            TCU_THROW(NotSupportedError,
                      "VkPhysicalDeviceCooperativeMatrixFeaturesKHR::cooperativeMatrix not supported");
        }
    }
    else
    {
        if (!context.getCooperativeMatrixFeaturesNV().cooperativeMatrix)
        {
            TCU_THROW(NotSupportedError,
                      "VkPhysicalDeviceCooperativeMatrixFeaturesNV::cooperativeMatrix not supported");
        }
    }

    if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
    {
        TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
    }

    if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
        !context.getVariablePointersFeatures().variablePointers)
    {
        TCU_THROW(NotSupportedError, "variable pointers not supported");
    }

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
    {
        TCU_THROW(NotSupportedError, "buffer device address not supported");
    }

    if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
        (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_KHR))
    {
        TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
    }

#define REQUIRE(FEATURE)                                             \
    context.requireDeviceFunctionality("VK_NV_cooperative_matrix2"); \
    if (!context.getCooperativeMatrix2FeaturesNV().FEATURE)          \
    {                                                                \
        TCU_THROW(NotSupportedError, #FEATURE " not supported");     \
    }

    if (m_data.scope == VK_SCOPE_WORKGROUP_KHR)
    {
        REQUIRE(cooperativeMatrixWorkgroupScope)
    }
    if (isReduceOp(m_data.testType))
    {
        REQUIRE(cooperativeMatrixReductions)
    }

    if (m_data.testType == TT_CONVERT_ACC_TO_A || m_data.testType == TT_CONVERT_ACC_TO_B ||
        m_data.testType == TT_TRANSPOSE_ACC_TO_B)
    {
        REQUIRE(cooperativeMatrixConversions)
    }

    if (isPerElemOp(m_data.testType))
    {
        REQUIRE(cooperativeMatrixPerElementOperations)
    }

    if (m_data.addrMethod != ADDR_LINEAR || isTensorLayoutTest(m_data.testType) || isClampTest(m_data.testType))
    {
        REQUIRE(cooperativeMatrixTensorAddressing);
    }

    if (isTensorLayoutTest(m_data.testType))
    {
        REQUIRE(cooperativeMatrixFlexibleDimensions);
    }

    if (m_data.addrMethod == ADDR_BLOCKSIZE || m_data.addrMethod == ADDR_DECODE)
    {
        REQUIRE(cooperativeMatrixBlockLoads);
    }

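    // supported[0] and supported[1] track whether any advertised property
    // matches the input and output component types, respectively, for the
    // matrix uses this test requires.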
    std::vector<VkCooperativeMatrixPropertiesKHR> properties =
        getCooperativeMatrixPropertiesConverted(context, isKhr(m_data.useType));
    bool supported[2]   = {false, false};
    const auto isMMA    = isMatrixMulAddOp(m_data.testType);
    const auto isMMASat = m_data.testType == TT_MATRIXMULADD_SATURATED;

    for (size_t i = 0; i < properties.size(); ++i)
    {
        const VkCooperativeMatrixPropertiesKHR *p = &properties[i];

        if (p->scope != m_data.scope)
            continue;

        if (isMMA && isMMASat != static_cast<bool>(p->saturatingAccumulation))
            continue;

        if (isMMA)
        {
            if (p->AType == m_data.inputType && p->BType == m_data.inputType && p->CType == m_data.outputType &&
                p->ResultType == m_data.outputType)
            {
                supported[0] = supported[1] = true;
            }
        }
        else
        {
            const VkComponentTypeKHR types[2] = {m_data.inputType, m_data.outputType};
            UseType uses[2]                   = {m_data.useType, m_data.useType};
            if (m_data.testType == TT_CONVERT_ACC_TO_A)
            {
                uses[1] = UT_KHR_A;
            }
            else if (m_data.testType == TT_CONVERT_ACC_TO_B || m_data.testType == TT_TRANSPOSE_ACC_TO_B)
            {
                uses[1] = UT_KHR_B;
            }

            for (uint32_t j = 0; j < 2; ++j)
            {
                switch (uses[j])
                {
                case UT_NV:
                {
                    if (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] ||
                        p->ResultType == types[j])
                        supported[j] = true;

                    break;
                }
                case UT_KHR_A:
                {
                    if (p->AType == types[j])
                        supported[j] = true;

                    break;
                }
                case UT_KHR_B:
                {
                    if (p->BType == types[j])
                        supported[j] = true;

                    break;
                }
                case UT_KHR_Result:
                {
                    if (p->ResultType == types[j])
                        supported[j] = true;

                    break;
                }
                default:
                    TCU_THROW(InternalError, "Unsupported use type");
                }
            }
        }
    }

    if (context.getCooperativeMatrix2FeaturesNV().cooperativeMatrixFlexibleDimensions)
    {
        uint32_t flexiblePropertyCount = 0;
        std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexibleProperties;

        const InstanceInterface &vki = context.getInstanceInterface();
        VK_CHECK(vki.getPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(context.getPhysicalDevice(),
                                                                                      &flexiblePropertyCount, nullptr));

        if (flexiblePropertyCount > 0)
        {
            const VkCooperativeMatrixFlexibleDimensionsPropertiesNV sample = initVulkanStructureConst();

            flexibleProperties.resize(flexiblePropertyCount, sample);

            VK_CHECK(vki.getPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(
                context.getPhysicalDevice(), &flexiblePropertyCount, flexibleProperties.data()));
        }
        else
        {
            flexibleProperties.clear();
        }

        for (size_t i = 0; i < flexibleProperties.size(); ++i)
        {
            const VkCooperativeMatrixFlexibleDimensionsPropertiesNV *p = &flexibleProperties[i];

            if (p->scope != m_data.scope)
                continue;

            if (isMMA && isMMASat != static_cast<bool>(p->saturatingAccumulation))
                continue;

            if (isMMA)
            {
                if (p->AType == m_data.inputType && p->BType == m_data.inputType && p->CType == m_data.outputType &&
                    p->ResultType == m_data.outputType)
                {
                    supported[0] = supported[1] = true;
                }
            }
            else
            {
                const VkComponentTypeKHR types[2] = {m_data.inputType, m_data.outputType};
                UseType uses[2]                   = {m_data.useType, m_data.useType};
                if (m_data.testType == TT_CONVERT_ACC_TO_A)
                {
                    uses[1] = UT_KHR_A;
                }
                else if (m_data.testType == TT_CONVERT_ACC_TO_B || m_data.testType == TT_TRANSPOSE_ACC_TO_B)
                {
                    uses[1] = UT_KHR_B;
                }

                for (uint32_t j = 0; j < 2; ++j)
                {
                    switch (uses[j])
                    {
                    case UT_NV:
                        break;
                    case UT_KHR_A:
                    {
                        if (p->AType == types[j])
                            supported[j] = true;

                        break;
                    }
                    case UT_KHR_B:
                    {
                        if (p->BType == types[j])
                            supported[j] = true;

                        break;
                    }
                    case UT_KHR_Result:
                    {
                        if (p->ResultType == types[j])
                            supported[j] = true;

                        break;
                    }
                    default:
                        TCU_THROW(InternalError, "Unsupported use type");
                    }
                }
            }
        }
    }

    if (!supported[0] || !supported[1])
        TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");

    checkShaderObjectRequirements(context.getInstanceInterface(), context.getPhysicalDevice(),
                                  m_data.computePipelineConstructionType);
}

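// Indexed by VkComponentTypeKHR enum value (the static asserts above
// guarantee the NV enum values match).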
struct
{
    const char *typeName;
    const char *coopmatTypeName;
    uint32_t bits;
    bool isSigned;
} componentTypeInfo[] = {
    {"float16_t", "fcoopmatNV", 16, true}, {"float32_t", "fcoopmatNV", 32, true}, {"float64_t", "fcoopmatNV", 64, true},
    {"int8_t", "icoopmatNV", 8, true},     {"int16_t", "icoopmatNV", 16, true},   {"int32_t", "icoopmatNV", 32, true},
    {"int64_t", "icoopmatNV", 64, true},   {"uint8_t", "ucoopmatNV", 8, false},   {"uint16_t", "ucoopmatNV", 16, false},
    {"uint32_t", "ucoopmatNV", 32, false}, {"uint64_t", "ucoopmatNV", 64, false},
};

bool isFloatType(VkComponentTypeKHR t)
{
    switch (t)
    {
    case VK_COMPONENT_TYPE_FLOAT16_KHR:
    case VK_COMPONENT_TYPE_FLOAT32_KHR:
    case VK_COMPONENT_TYPE_FLOAT64_KHR:
        return true;
    default:
        return false;
    }
}

bool isSIntType(VkComponentTypeKHR t)
{
    switch (t)
    {
    case VK_COMPONENT_TYPE_SINT8_KHR:
    case VK_COMPONENT_TYPE_SINT16_KHR:
    case VK_COMPONENT_TYPE_SINT32_KHR:
    case VK_COMPONENT_TYPE_SINT64_KHR:
        return true;
    default:
        return false;
    }
}

void CooperativeMatrixTestCase::initProgramsGLSL(SourceCollections &programCollection) const
{
    const char *suffix = (isKhr(m_data.useType) ? "" : "NV");
    const char *ext    = isKhr(m_data.useType) ? "#extension GL_KHR_cooperative_matrix : enable\n" :
                                                 "#extension GL_NV_cooperative_matrix : enable\n"
                                                 "#extension GL_NV_integer_cooperative_matrix : enable\n";
    const char *sat = (m_data.testType == TT_MATRIXMULADD_SATURATED) ? ", gl_MatrixOperandsSaturatingAccumulation" : "";
    std::stringstream css;
    css << "#version 450 core\n";
    css << "#pragma use_vulkan_memory_model\n";
    css << "#extension GL_KHR_shader_subgroup_basic : enable\n"
           "#extension GL_KHR_memory_scope_semantics : enable\n"
        << ext
        << "#extension GL_EXT_shader_explicit_arithmetic_types : enable\n"
           "#extension GL_EXT_buffer_reference : enable\n"
           "#extension GL_NV_cooperative_matrix2 : enable\n"
           "// strides overridden by spec constants\n"
           "layout(constant_id = 2) const int AStride = 1;\n"
           "layout(constant_id = 3) const int BStride = 1;\n"
           "layout(constant_id = 4) const int CStride = 1;\n"
           "layout(constant_id = 5) const int OStride = 1;\n"
           "layout(constant_id = 6) const int M = 1;\n"
           "layout(constant_id = 7) const int N = 1;\n"
           "layout(constant_id = 8) const int K = 1;\n"
           "layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";

    if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
        css << "#pragma use_variable_pointers\n";

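    // dims[0..3] hold the row/column dimension expressions for matrices A, B,
    // C, and the output O. For matrix-multiply-add, A is MxK, B is KxN, and
    // C/O are MxN; most other tests use MxN throughout.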
    struct
    {
        string rows, cols;
    } dims[4];

    if (isMatrixMulAddOp(m_data.testType))
    {
        dims[0].rows = "M";
        dims[0].cols = "K";
        dims[1].rows = "K";
        dims[1].cols = "N";
        dims[2].rows = "M";
        dims[2].cols = "N";
        dims[3].rows = "M";
        dims[3].cols = "N";
    }
    else
    {
        if (isReduce2x2(m_data.testType))
        {
            dims[0].rows = "(M*2)";
            dims[0].cols = "(N*2)";
        }
        else
        {
            dims[0].rows = "M";
            dims[0].cols = "N";
        }
        dims[1].rows = "M";
        dims[1].cols = "N";
        dims[2].rows = "M";
        dims[2].cols = "N";
        if (isReduceChangeDim(m_data.testType))
        {
            dims[3].rows = "(M*" + std::to_string(reduceMScale(m_data.testType)) + ")";
            dims[3].cols = "(N*" + std::to_string(reduceNScale(m_data.testType)) + ")";
        }
        else if (m_data.testType == TT_TRANSPOSE_ACC_TO_B)
        {
            dims[2].rows = "N";
            dims[2].cols = "M";
            dims[3].rows = "N";
            dims[3].cols = "M";
        }
        else
        {
            dims[3].rows = "M";
            dims[3].cols = "N";
        }
    }

    const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
    const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
    const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
    const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;
    string inputType;
    string outputType;
    string divisorA;
    string divisorB;
    string divisorC;
    string divisorO;
    string *divisors[4] = {&divisorA, &divisorB, &divisorC, &divisorO};

    string scopeStr = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ? "gl_ScopeWorkgroup" : "gl_ScopeSubgroup";

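    // The multicomponent tests redeclare the buffers with vec2/vec4 element
    // types, so linear element indices must be divided by the component count
    // (the divisor* strings appended below).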
    if (m_data.testType == TT_MULTICOMPONENT_LOAD)
    {
        const char *componentSuffix = m_data.inputComponentCount == 2 ? "vec2" :
                                      m_data.inputComponentCount == 4 ? "vec4" :
                                                                        "";

        inputType = string(1, componentTypeInfo[m_data.inputType].coopmatTypeName[0]) +
                    de::toString(componentTypeInfo[m_data.inputType].bits) + componentSuffix;

        typeStrA = inputType.c_str();
        typeStrB = inputType.c_str();
        divisorA = m_data.inputComponentCount == 2 ? "/2" : m_data.inputComponentCount == 4 ? "/4" : "";
        divisorB = divisorA;
    }

    if (m_data.testType == TT_MULTICOMPONENT_SAVE)
    {
        const char *componentSuffix = m_data.outputComponentCount == 2 ? "vec2" :
                                      m_data.outputComponentCount == 4 ? "vec4" :
                                                                         "";

        outputType = string(1, componentTypeInfo[m_data.outputType].coopmatTypeName[0]) +
                     de::toString(componentTypeInfo[m_data.outputType].bits) + componentSuffix;

        typeStrC = outputType.c_str();
        typeStrO = outputType.c_str();
        divisorC = m_data.outputComponentCount == 2 ? "/2" : m_data.outputComponentCount == 4 ? "/4" : "";
        divisorO = divisorC;
    }

    css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
    if (m_data.scope != VK_SCOPE_WORKGROUP_KHR)
    {
        css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", "
            << m_data.subgroupsPerWorkgroupY << ");\n";
    }

    // Test loading from a struct
    string typeStrAStruct = typeStrA;
    if (m_data.storageClass != SC_WORKGROUP && m_data.storageClass != SC_WORKGROUP_VARIABLE_POINTERS &&
        m_data.addrMethod != ADDR_LINEAR)
    {
        css << "struct StructA { " << typeStrA << " y; };\n";
        typeStrAStruct = "StructA";
    }

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
    {
        css << "layout(buffer_reference) buffer InputA { " << typeStrAStruct << " x[]; };\n";
        css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
        css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
        css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
        css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; "
               "} params;\n";
    }
    else
    {
        css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrAStruct << " x[]; } inputA;\n";
        css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
        css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
        css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
    }

    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        string scale = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ? "1" : "subgroupsPerWG.x * subgroupsPerWG.y";
        css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols << " * " << scale
            << "];\n";
        css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols << " * " << scale
            << "];\n";
        css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols << " * " << scale
            << "];\n";
        css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols << " * " << scale
            << "];\n";
    }

    std::stringstream matAType, matBType, matCType, outputMatType;

    // GLSL only considers types the same if any spec constants are the same and have
    // no operations. So for 2x2 reductions, where A has M*2/N*2 rows and cols, we need
    // to put that in a variable. But we can't for other tests, where we e.g. want to
    // assign matA to matO.
    if (isReduce2x2(m_data.testType))
    {
        css << "const int ARows = " << dims[0].rows << ";\n";
        css << "const int ACols = " << dims[0].cols << ";\n";
    }
    else
    {
        css << "#define ARows " << dims[0].rows << "\n";
        css << "#define ACols " << dims[0].cols << "\n";
    }
    if (isReduceChangeDim(m_data.testType))
    {
        css << "const int ORows = " << dims[3].rows << ";\n";
        css << "const int OCols = " << dims[3].cols << ";\n";
    }
    else
    {
        css << "#define ORows " << dims[3].rows << "\n";
        css << "#define OCols " << dims[3].cols << "\n";
    }

    const char *sameType = m_data.useType == UT_KHR_A      ? "gl_MatrixUseA" :
                           m_data.useType == UT_KHR_B      ? "gl_MatrixUseB" :
                           m_data.useType == UT_KHR_Result ? "gl_MatrixUseAccumulator" :
                                                             "Invalid use";

    if (isKhr(m_data.useType))
    {
        const bool useSame = !isMatrixMulAddOp(m_data.testType);
        const char *atype  = useSame ? sameType : "gl_MatrixUseA";
        const char *btype  = useSame ? sameType : "gl_MatrixUseB";
        const char *ctype  = useSame ? sameType : "gl_MatrixUseAccumulator";
        const char *rtype  = useSame ? sameType : "gl_MatrixUseAccumulator";

        if (m_data.testType == TT_CONVERT_ACC_TO_A)
        {
            atype = "gl_MatrixUseAccumulator";
            btype = "gl_MatrixUseAccumulator";
            ctype = "gl_MatrixUseA";
            rtype = "gl_MatrixUseA";
        }
        else if (m_data.testType == TT_CONVERT_ACC_TO_B || m_data.testType == TT_TRANSPOSE_ACC_TO_B)
        {
            atype = "gl_MatrixUseAccumulator";
            btype = "gl_MatrixUseAccumulator";
            ctype = "gl_MatrixUseB";
            rtype = "gl_MatrixUseB";
        }

        matAType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", " << scopeStr << ", ARows, ACols, "
                 << atype << ">";
        matBType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", " << scopeStr << ", "
                 << dims[1].rows << ", " << dims[1].cols << ", " << btype << ">";
        matCType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", " << scopeStr << ", "
                 << dims[2].rows << ", " << dims[2].cols << ", " << ctype << ">";
        outputMatType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", " << scopeStr
                      << ", ORows, OCols, " << rtype << ">";
    }
    else
    {
        matAType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<"
                 << componentTypeInfo[m_data.inputType].bits << ", " << scopeStr << ", ARows, ACols>";
        matBType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<"
                 << componentTypeInfo[m_data.inputType].bits << ", " << scopeStr << ", " << dims[1].rows << ", "
                 << dims[1].cols << ">";
        matCType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<"
                 << componentTypeInfo[m_data.outputType].bits << ", " << scopeStr << ", " << dims[2].rows << ", "
                 << dims[2].cols << ">";
        outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<"
                      << componentTypeInfo[m_data.outputType].bits << ", " << scopeStr << ", ORows, OCols>";
    }

    css << matAType.str() << " matA;\n";
    css << matBType.str() << " matB;\n";
    css << matCType.str() << " matC;\n";
    css << outputMatType.str() << " matO;\n";

    if (m_data.testType == TT_CONSTANT)
        css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";

    if (m_data.testType == TT_FUNC)
        css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";

    if (m_data.testType == TT_PER_ELEMENT_OP || m_data.testType == TT_PER_ELEMENT_OP_MAT)
    {
        std::string type = componentTypeInfo[m_data.inputType].typeName;
        css << type << " elemOp(const in uint32_t row, const in uint32_t col, const in " << type << " elem, const in "
            << type
            << " other) {\n"
               "    return elem + other;\n"
               "}\n";
    }
    else if (m_data.testType == TT_PER_ELEMENT_OP_ROW_COL)
    {
        std::string type = componentTypeInfo[m_data.inputType].typeName;
        css << type << " elemOpRowCol(const in uint32_t row, const in uint32_t col, const in " << type
            << " elem) {\n"
               "    return elem + "
            << type
            << "(row*3 + col);\n"
               "}\n";
    }
    else if (m_data.testType == TT_PER_ELEMENT_OP_STRUCT)
    {
        std::string type = componentTypeInfo[m_data.inputType].typeName;
        css << "struct ParamType { " << type << " x; };\n";
        std::string paramType = "ParamType";
        css << type << " elemOp(const in uint32_t row, const in uint32_t col, const in " << type << " elem, const in "
            << paramType
            << " other) {\n"
               "    return elem + other.x;\n"
               "}\n";
    }
    else if (isReduceOp(m_data.testType))
    {
        std::string type = componentTypeInfo[m_data.inputType].typeName;
        css << type << " combineOp(const in " << type << " a, const in " << type << " b) {\n";
        if (isReduceSum(m_data.testType))
        {
            css << "    return a + b;\n";
        }
        else if (isReduceMin(m_data.testType))
        {
            css << "    return min(a, b);\n";
        }
        css << "}\n";
    }

    if (m_data.testType == TT_MATRIXMULADD_DEQUANT)
    {
        // 4-bit elements [0,15) with -4 bias and scale of 0.5.
        css << "layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBuf {\n"
               "   uint8_t bits["
            << blockSize[0] * blockSize[1] / 2
            << "];\n"
               "};\n";

        css << typeStrA
            << " decodeFunc(const in decodeBuf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) "
               "{\n"
               "   uint32_t idx = coordInBlock[0] * "
            << blockSize[1]
            << " + coordInBlock[1];\n"
               "   uint32_t arrayidx = idx / 2;\n"
               "   uint32_t shift = (idx & 1) * 4;\n"
               "   int32_t bits = int32_t(b.bits[arrayidx]);\n"
               "   bits = (bits >> shift) & 0xF;\n"
               "   return "
            << typeStrA
            << "(0.5 * float(bits - 4));\n"
               "}\n";
    }
    else if (m_data.addrMethod == ADDR_DECODE)
    {
        css << "layout(buffer_reference, std430, buffer_reference_align = "
            << (componentTypeInfo[m_data.inputType].bits / 8)
            << ") buffer decodeBuf {\n"
               "   "
            << typeStrA << " f[" << blockSize[0] * blockSize[1]
            << "];\n"
               "};\n";

        // Lookup from coord in block, and add f(blockCoords)
        css << typeStrA
            << " decodeFunc(const in decodeBuf b, const in uint32_t blockCoords[2], const in uint32_t coordInBlock[2]) "
               "{\n"
               "   return b.f[coordInBlock[0] * "
            << blockSize[1] << " + coordInBlock[1]] + " << typeStrA
            << "((2*blockCoords[0] + blockCoords[1]) & 3);\n"
               "}\n";
    }

    css << "void main()\n"
           "{\n";
    if (m_data.scope == VK_SCOPE_WORKGROUP_KHR)
    {
        css << "   uvec2 matrixID = uvec2(gl_WorkGroupID.xy);\n";
    }
    else
    {
        css <<
            // matrixID is the x,y index of the matrix owned by this subgroup.
            "   uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
            "   uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";
    }

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
    {
        css << "   InputA inputA = params.inputA;\n";
        css << "   InputB inputB = params.inputB;\n";
        css << "   InputC inputC = params.inputC;\n";
        css << "   Output outputO = params.outputO;\n";
    }

    string strides[4];
    string heights[4];
    for (uint32_t i = 0; i < 4; ++i)
    {
        if (m_data.scope == VK_SCOPE_WORKGROUP_KHR)
        {
            strides[i] =
                (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") + de::toString(m_data.workgroupsX);
            heights[i] =
                (m_data.colMajor ? dims[i].cols : dims[i].rows) + string(" * ") + de::toString(m_data.workgroupsY);
        }
        else
        {
            strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") +
                         de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
            heights[i] = (m_data.colMajor ? dims[i].cols : dims[i].rows) + string(" * ") +
                         de::toString(m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
        }
    }

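    // For tensor addressing, offsetN0/offsetN1 give each matrix's tile origin
    // within the tensor and spanN0/spanN1 give the tile extents.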
1408     if (m_data.addrMethod != ADDR_LINEAR)
1409     {
1410         css << "   int offset00 = int(" << (m_data.colMajor ? dims[0].cols : dims[0].rows)
1411             << " * matrixID.y); int offset01 = int(" << (m_data.colMajor ? dims[0].rows : dims[0].cols)
1412             << " * matrixID.x);\n";
1413         css << "   int offset10 = int(" << (m_data.colMajor ? dims[1].cols : dims[1].rows)
1414             << " * matrixID.y); int offset11 = int(" << (m_data.colMajor ? dims[1].rows : dims[1].cols)
1415             << " * matrixID.x);\n";
1416         css << "   int offset20 = int(" << (m_data.colMajor ? dims[2].cols : dims[2].rows)
1417             << " * matrixID.y); int offset21 = int(" << (m_data.colMajor ? dims[2].rows : dims[2].cols)
1418             << " * matrixID.x);\n";
1419         css << "   int offset30 = int(" << (m_data.colMajor ? dims[3].cols : dims[3].rows)
1420             << " * matrixID.y); int offset31 = int(" << (m_data.colMajor ? dims[3].rows : dims[3].cols)
1421             << " * matrixID.x);\n";
1422 
1423         css << "   uint span00 = " << (m_data.colMajor ? dims[0].cols : dims[0].rows)
1424             << "; uint span01 = " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << ";\n";
1425         css << "   uint span10 = " << (m_data.colMajor ? dims[1].cols : dims[1].rows)
1426             << "; uint span11 = " << (m_data.colMajor ? dims[1].rows : dims[1].cols) << ";\n";
1427         css << "   uint span20 = " << (m_data.colMajor ? dims[2].cols : dims[2].rows)
1428             << "; uint span21 = " << (m_data.colMajor ? dims[2].rows : dims[2].cols) << ";\n";
1429         css << "   uint span30 = " << (m_data.colMajor ? dims[3].cols : dims[3].rows)
1430             << "; uint span31 = " << (m_data.colMajor ? dims[3].rows : dims[3].cols) << ";\n";
1431     }
1432 
1433     if (isClampTest(m_data.testType))
1434     {
1435         // Clamp tests adjust offset and dimensions to shrink the load boundary by 3 on each edge
1436         css << "   offset00 -= 3; offset01 -= 3;\n";
1437         css << "   offset10 -= 3; offset11 -= 3;\n";
1438         css << "   offset20 -= 3; offset21 -= 3;\n";
1439     }
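        // Combined with the "- 6" dimensions set on the tensor layouts below,
        // this makes each load read 3 elements past every edge of the layout,
        // exercising the requested clamp mode.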
1440 
1441     // element<i> is the starting element in buffer memory.
1442     // elementS<i> is the starting element in shared memory.
1443     css << "   uint element0 = (" << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows)
1444         << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x)" << divisorA
1445         << ";\n"
1446            "   uint element1 = ("
1447         << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + "
1448         << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x)" << divisorB
1449         << ";\n"
1450            "   uint element2 = ("
1451         << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + "
1452         << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x)" << divisorC
1453         << ";\n"
1454            "   uint element3 = ("
1455         << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + "
1456         << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x)" << divisorO
1457         << ";\n"
1458            "   uint elementS0, elementS1, elementS2, elementS3;\n";
1459 
1460     // For shared memory tests, copy the matrix from buffer memory into
1461     // workgroup memory. For simplicity, do it all on a single thread.
1462     if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
1463     {
1464         const char *name[] = {
1465             "sharedA",
1466             "sharedB",
1467             "sharedC",
1468         };
1469         const char *inputName[] = {
1470             "inputA",
1471             "inputB",
1472             "inputC",
1473         };
1474         for (uint32_t m = 0; m < 4; ++m)
1475         {
1476             string sharedStride = strides[m] + " / workgroupsX";
1477             if (m_data.scope == VK_SCOPE_WORKGROUP_KHR)
1478             {
1479                 css << "       elementS" << m << " = 0;\n";
1480             }
1481             else
1482             {
1483                 css << "       elementS" << m << " = (" << sharedStride << " * "
1484                     << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + "
1485                     << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x)" << *divisors[m] << ";\n";
1486             }
1487         }
1488         css << "   if (subgroupElect()) {\n";
1489         // Copy the input matrices into shared memory; matrices a given test does not use are skipped below.
1490         for (uint32_t m = 0; m < 3; ++m)
1491         {
1492             if (m == 0 && (m_data.testType == TT_LENGTH || m_data.testType == TT_CONSTANT))
1493             {
1494                 // A matrix not needed
1495                 continue;
1496             }
1497             if (m == 1)
1498             {
1499                 // B matrix not needed
1500                 if (isReduceOp(m_data.testType) || isClampTest(m_data.testType))
1501                 {
1502                     continue;
1503                 }
1504                 switch (m_data.testType)
1505                 {
1506                 case TT_CONSTANT:
1507                 case TT_LENGTH:
1508                 case TT_CONVERT:
1509                 case TT_NEGATE:
1510                 case TT_FUNC:
1511                 case TT_MATRIXTIMESSCALAR:
1512                 case TT_MULTICOMPONENT_LOAD:
1513                 case TT_MULTICOMPONENT_SAVE:
1514                 case TT_CONVERT_ACC_TO_A:
1515                 case TT_CONVERT_ACC_TO_B:
1516                 case TT_TRANSPOSE_ACC_TO_B:
1517                 case TT_PER_ELEMENT_OP:
1518                 case TT_PER_ELEMENT_OP_MAT:
1519                 case TT_PER_ELEMENT_OP_STRUCT:
1520                 case TT_PER_ELEMENT_OP_ROW_COL:
1521                 case TT_SPACETODEPTH:
1522                     continue;
1523                 default:
1524                     break;
1525                 }
1526             }
1527             if (m == 2 && !isMatrixMulAddOp(m_data.testType))
1528             {
1529                 // C matrix only needed for matmul
1530                 continue;
1531             }
1532             string sharedStride = strides[m] + " / workgroupsX";
1533             css << "       for (int i = 0; i < " << dims[m].rows
1534                 << "; ++i) {\n"
1535                    "       for (int j = 0; j < "
1536                 << dims[m].cols
1537                 << "; ++j) {\n"
1538                    "           int localElementInput = ("
1539                 << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ")"
1540                 << *divisors[m]
1541                 << ";\n"
1542                    "           int localElementShared = ("
1543                 << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j")
1544                 << ")" << *divisors[m]
1545                 << ";\n"
1546                    "           "
1547                 << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m
1548                 << " + localElementInput];\n"
1549                    "       }\n"
1550                    "       }\n";
1551             strides[m] = sharedStride;
1552         }
1553         css << "   }\n";
1554         css << "   controlBarrier(" << scopeStr << ", " << scopeStr
1555             << ", gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
1556     }
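    // The barrier makes the copy performed by the elected invocation visible
    // to the rest of the scope before the cooperative matrix loads below.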
1557 
1558     const char *colMajorNV = (m_data.colMajor ? "true" : "false");
1559     const char *colMajorKHR =
1560         (m_data.colMajor ? "gl_CooperativeMatrixLayoutColumnMajor" : "gl_CooperativeMatrixLayoutRowMajor");
1561     const char *colMajor = (isKhr(m_data.useType) ? colMajorKHR : colMajorNV);
1562 
1563     string loadStrides[3] = {strides[0] + divisorA, strides[1] + divisorB, strides[2] + divisorC};
1564     // Load with a stride of 0
1565     if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
1566         loadStrides[0] = loadStrides[1] = loadStrides[2] = "0";
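    // With a stride of 0 every row (or column) of a matrix aliases the same
    // run of memory; the tensor-layout path below reflects this by forcing
    // the layout heights to 1.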
1567 
1568     std::string clampString;
1569     switch (m_data.testType)
1570     {
1571     default:
1572         break;
1573     case TT_CLAMPCONSTANT:
1574         clampString = "gl_CooperativeMatrixClampModeConstantNV";
1575         break;
1576     case TT_CLAMPTOEDGE:
1577         clampString = "gl_CooperativeMatrixClampModeClampToEdgeNV";
1578         break;
1579     case TT_CLAMPREPEAT:
1580         clampString = "gl_CooperativeMatrixClampModeRepeatNV";
1581         break;
1582     case TT_CLAMPMIRRORREPEAT:
1583         clampString = "gl_CooperativeMatrixClampModeMirrorRepeatNV";
1584         break;
1585     }
1586 
1587     if (!isTensorLayoutTest(m_data.testType))
1588     {
1589         if (m_data.addrMethod != ADDR_LINEAR)
1590         {
1591 
1592             if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
1593             {
1594                 heights[0] = heights[1] = heights[2] = "1";
1595             }
1596 
1597             if (isClampTest(m_data.testType))
1598             {
1599                 css << "   tensorLayoutNV<2, " << clampString << "> tensorLayout0 = createTensorLayoutNV(2, "
1600                     << clampString
1601                     << ");\n"
1602                        "   tensorLayoutNV<2, "
1603                     << clampString << "> tensorLayout1 = createTensorLayoutNV(2, " << clampString
1604                     << ");\n"
1605                        "   tensorLayoutNV<2, "
1606                     << clampString << "> tensorLayout2 = createTensorLayoutNV(2, " << clampString << ");\n";
1607 
1608                 css << "   tensorLayout0 = setTensorLayoutDimensionNV(tensorLayout0, " << heights[0] << " - 6, "
1609                     << strides[0]
1610                     << " - 6);\n"
1611                        "   tensorLayout1 = setTensorLayoutDimensionNV(tensorLayout1, "
1612                     << heights[1] << " - 6, " << strides[1]
1613                     << " - 6);\n"
1614                        "   tensorLayout2 = setTensorLayoutDimensionNV(tensorLayout2, "
1615                     << heights[2] << " - 6, " << strides[2] << " - 6);\n";
1616                 css << "   tensorLayout0 = setTensorLayoutStrideNV(tensorLayout0, " << strides[0]
1617                     << ", 1);\n"
1618                        "   tensorLayout1 = setTensorLayoutStrideNV(tensorLayout1, "
1619                     << strides[1]
1620                     << ", 1);\n"
1621                        "   tensorLayout2 = setTensorLayoutStrideNV(tensorLayout2, "
1622                     << strides[2] << ", 1);\n";
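                // The clamp value passed to setTensorLayoutClampValueNV is a
                // raw bit pattern: 0.5 encoded per component type for the
                // float types, the integer 17 otherwise.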
1623                 if (m_data.inputType == VK_COMPONENT_TYPE_FLOAT32_KHR)
1624                 {
1625                     css << "   tensorLayout0 = setTensorLayoutClampValueNV(tensorLayout0, floatBitsToUint(0.5));\n";
1626                 }
1627                 else if (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR)
1628                 {
1629                     // 0x3800 == 0.5f in fp16
1630                     css << "   tensorLayout0 = setTensorLayoutClampValueNV(tensorLayout0, 0x3800);\n";
1631                 }
1632                 else
1633                 {
1634                     css << "   tensorLayout0 = setTensorLayoutClampValueNV(tensorLayout0, 17);\n";
1635                 }
1636             }
1637             else
1638             {
1639                 css << "   tensorLayoutNV<2> tensorLayout0 = createTensorLayoutNV(2);\n"
1640                        "   tensorLayoutNV<2> tensorLayout1 = createTensorLayoutNV(2);\n"
1641                        "   tensorLayoutNV<2> tensorLayout2 = createTensorLayoutNV(2);\n";
1642 
1643                 if (m_data.addrMethod == ADDR_BLOCKSIZE || m_data.addrMethod == ADDR_DECODE)
1644                 {
1645                     css << "   tensorLayout0 = setTensorLayoutBlockSizeNV(tensorLayout0, " << blockSize[0] << ", "
1646                         << blockSize[1]
1647                         << ");\n"
1648                            "   tensorLayout1 = setTensorLayoutBlockSizeNV(tensorLayout1, "
1649                         << blockSize[0] << ", " << blockSize[1] << ");\n";
1650                 }
1651 
1652                 css << "   tensorLayout0 = setTensorLayoutDimensionNV(tensorLayout0, " << heights[0] << ", "
1653                     << strides[0]
1654                     << ");\n"
1655                        "   tensorLayout1 = setTensorLayoutDimensionNV(tensorLayout1, "
1656                     << heights[1] << ", " << strides[1]
1657                     << ");\n"
1658                        "   tensorLayout2 = setTensorLayoutDimensionNV(tensorLayout2, "
1659                     << heights[2] << ", " << strides[2] << ");\n";
1660             }
1661 
1662             string viewParam0, viewParam1, viewParam2;
1663             string decodeFunc;
1664 
1665             if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
1666             {
1667                 if (m_data.colMajor)
1668                 {
1669                     css << "   tensorViewNV<2, true, 1, 0> stride0View0 = createTensorViewNV(2, true, 1, 0);\n"
1670                            "   tensorViewNV<2, true, 1, 0> stride0View1 = createTensorViewNV(2, true, 1, 0);\n"
1671                            "   tensorViewNV<2, true, 1, 0> stride0View2 = createTensorViewNV(2, true, 1, 0);\n";
1672                 }
1673                 else
1674                 {
1675                     css << "   tensorViewNV<2, true> stride0View0 = createTensorViewNV(2, true);\n"
1676                            "   tensorViewNV<2, true> stride0View1 = createTensorViewNV(2, true);\n"
1677                            "   tensorViewNV<2, true> stride0View2 = createTensorViewNV(2, true);\n";
1678                 }
1679                 css << "   stride0View0 = setTensorViewDimensionsNV(stride0View0, span00, span01);\n"
1680                        "   stride0View1 = setTensorViewDimensionsNV(stride0View1, span10, span11);\n"
1681                        "   stride0View2 = setTensorViewDimensionsNV(stride0View2, span20, span21);\n"
1682                        "   stride0View0 = setTensorViewStrideNV(stride0View0, 0, 1);\n"
1683                        "   stride0View1 = setTensorViewStrideNV(stride0View1, 0, 1);\n"
1684                        "   stride0View2 = setTensorViewStrideNV(stride0View2, 0, 1);\n";
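                // A view stride of 0 in the outer dimension makes every
                // row/column of the view alias element 0 of that dimension,
                // mirroring the stride-0 coopMatLoad variant.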
1685 
1686                 viewParam0 = ", stride0View0";
1687                 viewParam1 = ", stride0View1";
1688                 viewParam2 = ", stride0View2";
1689             }
1690             else if (m_data.colMajor)
1691             {
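                // The trailing "1, 0" arguments permute the two view
                // dimensions, so the same tensor layout is addressed
                // column-major.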
1692                 css << "   tensorViewNV<2, true, 1, 0> colMajorView0 = createTensorViewNV(2, true, 1, 0);\n"
1693                        "   tensorViewNV<2, true, 1, 0> colMajorView1 = createTensorViewNV(2, true, 1, 0);\n"
1694                        "   tensorViewNV<2, true, 1, 0> colMajorView2 = createTensorViewNV(2, true, 1, 0);\n"
1695                        "   colMajorView0 = setTensorViewDimensionsNV(colMajorView0, span00, span01);\n"
1696                        "   colMajorView1 = setTensorViewDimensionsNV(colMajorView1, span10, span11);\n"
1697                        "   colMajorView2 = setTensorViewDimensionsNV(colMajorView2, span20, span21);\n";
1698 
1699                 viewParam0 = ", colMajorView0";
1700                 viewParam1 = ", colMajorView1";
1701                 viewParam2 = ", colMajorView2";
1702             }
1703 
1704             if (m_data.addrMethod == ADDR_DECODE)
1705             {
1706                 decodeFunc = ", decodeFunc";
1707             }
1708 
1709             if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
1710             {
1711                 if (m_data.scope == VK_SCOPE_WORKGROUP_KHR)
1712                 {
1713                     css << "   elementS0 = elementS1 = elementS2 = 0;\n";
1714                 }
1715                 css << "   tensorLayout0 = sliceTensorLayoutNV(tensorLayout0, 0, span00, 0, span01);\n"
1716                        "   tensorLayout1 = sliceTensorLayoutNV(tensorLayout1, 0, span10, 0, span11);\n"
1717                        "   tensorLayout2 = sliceTensorLayoutNV(tensorLayout2, 0, span20, 0, span21);\n";
1718                 css << "   coopMatLoadTensorNV(matA, sharedA, elementS0, tensorLayout0" << viewParam0
1719                     << ");\n"
1720                        "   coopMatLoadTensorNV(matB, sharedB, elementS1, tensorLayout1"
1721                     << viewParam1
1722                     << ");\n"
1723                        "   coopMatLoadTensorNV(matC, sharedC, elementS2, tensorLayout2"
1724                     << viewParam2 << ");\n";
1725             }
1726             else
1727             {
1728                 css << "   tensorLayout0 = sliceTensorLayoutNV(tensorLayout0, offset00, span00, offset01, span01);\n"
1729                        "   tensorLayout1 = sliceTensorLayoutNV(tensorLayout1, offset10, span10, offset11, span11);\n"
1730                        "   tensorLayout2 = sliceTensorLayoutNV(tensorLayout2, offset20, span20, offset21, span21);\n";
1731                 css << "   coopMatLoadTensorNV(matA, inputA.x, 0, tensorLayout0" << viewParam0 << decodeFunc
1732                     << ");\n"
1733                        "   coopMatLoadTensorNV(matB, inputB.x, 0, tensorLayout1"
1734                     << viewParam1 << decodeFunc
1735                     << ");\n"
1736                        "   coopMatLoadTensorNV(matC, inputC.x, 0, tensorLayout2"
1737                     << viewParam2 << ");\n";
1738             }
1739         }
1740         else
1741         {
1742             if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
1743             {
1744                 css << "   coopMatLoad" << suffix << "(matA, sharedA, elementS0, " << loadStrides[0] << ", " << colMajor
1745                     << ");\n"
1746                        "   coopMatLoad"
1747                     << suffix << "(matB, sharedB, elementS1, " << loadStrides[1] << ", " << colMajor
1748                     << ");\n"
1749                        "   coopMatLoad"
1750                     << suffix << "(matC, sharedC, elementS2, " << loadStrides[2] << ", " << colMajor << ");\n";
1751             }
1752             else
1753             {
1754                 css << "   coopMatLoad" << suffix << "(matA, inputA.x, element0, " << loadStrides[0] << ", " << colMajor
1755                     << ");\n"
1756                        "   coopMatLoad"
1757                     << suffix << "(matB, inputB.x, element1, " << loadStrides[1] << ", " << colMajor
1758                     << ");\n"
1759                        "   coopMatLoad"
1760                     << suffix << "(matC, inputC.x, element2, " << loadStrides[2] << ", " << colMajor << ");\n";
1761             }
1762         }
1763     }
1764 
1765     if (m_data.testType == TT_COMPOSITE_ARRAY || m_data.testType == TT_MATRIXMULADD_ARRAY)
1766     {
1767         css << "   " << matAType.str() << " matAArr[2];\n    matAArr[1] = matA; matAArr[0] = " << matAType.str()
1768             << "(0.0);\n"
1769                "   "
1770             << matBType.str() << " matBArr[2];\n    matBArr[1] = matB; matBArr[0] = " << matBType.str()
1771             << "(0.0);\n"
1772                "   "
1773             << matCType.str() << " matCArr[2];\n    matCArr[1] = matC; matCArr[0] = " << matCType.str()
1774             << "(0.0);\n"
1775                "   "
1776             << outputMatType.str() << " matOArr[2];\n";
1777     }
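    // matArr[1] carries the live operands while matArr[0] stays zero, so the
    // composite/array tests exercise indexing into arrays of cooperative
    // matrices.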
1778 
1779     switch (m_data.testType)
1780     {
1781     default:
1782         DE_ASSERT(0);
1783         // fall through
1784     case TT_LENGTH:
1785         css << "   matO = " << outputMatType.str() << "(matO.length());\n";
1786         break;
1787     case TT_CONSTANT:
1788         css << "   matO = matConst;\n";
1789         break;
1790     case TT_CONVERT:
1791         css << "   matO = " << outputMatType.str() << "(matA);\n";
1792         break;
1793     case TT_COMPOSITE:
1794         css << "   " << matAType.str() << " t = " << matAType.str()
1795             << "(matB[0]);\n"
1796                "   for (int i = 1; i < matA.length(); ++i) {\n"
1797                "       matO[i] = matA[i] + matB[i];\n"
1798                "   }\n"
1799                "   if (matA.length() > 0)\n"
1800                "       matO[0] = matA[0] + t[0];\n";
1801         break;
1802     case TT_COMPOSITE_RVALUE:
1803         css << "   for (int i = 1; i < matA.length(); ++i) {\n"
1804                "       matO[i] = matA[i] + matB[i];\n"
1805                "   }\n"
1806                "   "
1807             << matAType.str()
1808             << " t = matA;\n"
1809                "   if (matA.length() > 0) {\n"
1810                "       matO[0] = (t += matB)[0];\n"
1811                "   }\n";
1812         break;
1813     case TT_COMPOSITE_ARRAY:
1814         css << "   for (int i = 0; i < matA.length(); ++i) {\n"
1815                "       matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
1816                "   }\n";
1817         break;
1818     case TT_ADD:
1819         css << "   matO = matA + matB;\n";
1820         break;
1821     case TT_SUB:
1822         css << "   matO = matA - matB;\n";
1823         break;
1824     case TT_DIV:
1825         css << "   matO = matA / matB;\n";
1826         break;
1827     case TT_MUL:
1828         css << "   matO = matA * matB;\n";
1829         break;
1830     case TT_NEGATE:
1831         css << "   matO = -matA;\n";
1832         break;
1833     case TT_FUNC:
1834         css << "   matO = f(matA);\n";
1835         break;
1836     case TT_CLAMPTOEDGE:
1837     case TT_CLAMPCONSTANT:
1838     case TT_CLAMPREPEAT:
1839     case TT_CLAMPMIRRORREPEAT:
1840         css << "   matO = matA;\n";
1841         break;
1842     case TT_MATRIXTIMESSCALAR:
1843         css << "   matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
1844         break;
1845     case TT_MATRIXMULADD_DEQUANT:
1846     case TT_MATRIXMULADD_CROSS:
1847     case TT_MATRIXMULADD_STRIDE0:
1848     case TT_MATRIXMULADD_WRAPPING:
1849     case TT_MATRIXMULADD_SATURATED:
1850     case TT_MATRIXMULADD:
1851         css << "   matO = coopMatMulAdd" << suffix << "(matA, matB, matC" << sat << ");\n";
1852         break;
1853     case TT_MATRIXMULADD_ARRAY:
1854         css << "   matOArr[1] = coopMatMulAdd" << suffix << "(matAArr[1], matBArr[1], matCArr[1]);\n";
1855         break;
1856     case TT_MULTICOMPONENT_LOAD:
1857         css << "   matO = matA;\n";
1858         break;
1859     case TT_MULTICOMPONENT_SAVE:
1860         css << "   matO = matA;\n";
1861         break;
1862     case TT_CONVERT_ACC_TO_A:
1863     case TT_CONVERT_ACC_TO_B:
1864         css << "   matO = " << outputMatType.str() << "(matA);\n";
1865         break;
1866     case TT_TRANSPOSE_ACC_TO_B:
1867         css << "   coopMatTransposeNV(matO, matA);\n";
1868         break;
1869     case TT_REDUCE_SUM_ROW:
1870     case TT_REDUCE_SUM_COL:
1871     case TT_REDUCE_SUM_ROWCOL:
1872     case TT_REDUCE_SUM_2X2:
1873     case TT_REDUCE_SUM_ROW_CHANGEDIM:
1874     case TT_REDUCE_SUM_COL_CHANGEDIM:
1875     case TT_REDUCE_SUM_ROWCOL_CHANGEDIM:
1876     case TT_REDUCE_MIN_ROW:
1877     case TT_REDUCE_MIN_COL:
1878     case TT_REDUCE_MIN_ROWCOL:
1879     case TT_REDUCE_MIN_2X2:
1880     {
1881         string rowCol = isReduce2x2(m_data.testType) ? "gl_CooperativeMatrixReduce2x2NV" :
1882                         isReduceRow(m_data.testType) ? "gl_CooperativeMatrixReduceRowNV" :
1883                         isReduceCol(m_data.testType) ? "gl_CooperativeMatrixReduceColumnNV" :
1884                                                        "gl_CooperativeMatrixReduceRowAndColumnNV";
1885 
1886         css << "   coopMatReduceNV(matO, matA, " << rowCol << ", combineOp);\n";
1887     }
1888     break;
1889     case TT_PER_ELEMENT_OP:
1890         css << "   coopMatPerElementNV(matO, matA, elemOp, " << componentTypeInfo[m_data.inputType].typeName
1891             << "(2.0));\n";
1892         break;
1893     case TT_PER_ELEMENT_OP_MAT:
1894         css << "   coopMatPerElementNV(matO, matA, elemOp, " << componentTypeInfo[m_data.inputType].typeName
1895             << "(2.0) * matA);\n";
1896         break;
1897     case TT_PER_ELEMENT_OP_ROW_COL:
1898         css << "   coopMatPerElementNV(matO, matA, elemOpRowCol);\n";
1899         break;
1900     case TT_PER_ELEMENT_OP_STRUCT:
1901         css << "   ParamType p; p.x = " << componentTypeInfo[m_data.inputType].typeName << "(2.0);\n";
1902         css << "   coopMatPerElementNV(matO, matA, elemOp, p);\n";
1903         break;
1904     case TT_TENSORLAYOUT_1D:
1905     case TT_TENSORLAYOUT_2D:
1906     case TT_TENSORLAYOUT_3D:
1907     case TT_TENSORLAYOUT_4D:
1908     case TT_TENSORLAYOUT_5D:
1909     case TT_TENSORLAYOUT_1D_CLIP:
1910     case TT_TENSORLAYOUT_2D_CLIP:
1911     case TT_TENSORLAYOUT_3D_CLIP:
1912     case TT_TENSORLAYOUT_4D_CLIP:
1913     case TT_TENSORLAYOUT_5D_CLIP:
1914     {
1915         uint32_t dim = GetDim(m_data.testType);
1916 
1917         css << "   tensorLayoutNV<" << dim << ", gl_CooperativeMatrixClampModeConstantNV> t = createTensorLayoutNV("
1918             << dim << ", gl_CooperativeMatrixClampModeConstantNV);\n";
1919         if (isTensorLayoutClipTest(m_data.testType))
1920         {
1921             css << "   tensorViewNV<" << dim << "> v = createTensorViewNV(" << dim << ");\n";
1922         }
1923         for (uint32_t i = 0; i < GetTensorLayoutNumCoords(dim); ++i)
1924         {
1925             uint32_t dimFactor = isTensorLayoutClipTest(m_data.testType) ? 2 : 1;
1926 
1927             stringstream mattype;
1928             mattype << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", " << scopeStr << ", "
1929                     << dimFactor * GetTensorLayoutMatrixSizes(dim, i)[0] << ", "
1930                     << dimFactor * GetTensorLayoutMatrixSizes(dim, i)[1] << ", " << sameType << ">";
1931             css << "   " << mattype.str() << " tempmat" << i << ";\n";
1932 
1933             css << "   tempmat" << i << " = " << mattype.str() << "(0.5);\n";
1934 
1935             if (isTensorLayoutClipTest(m_data.testType))
1936             {
1937                 // clip the double-size matrix to the requested size
1938                 css << "   v = setTensorViewClipNV(v, 1, " << GetTensorLayoutMatrixSizes(dim, i)[0] << ", 1, "
1939                     << GetTensorLayoutMatrixSizes(dim, i)[1] << ");\n";
1940             }
1941 
1942             css << "   t = setTensorLayoutDimensionNV(t";
1943             for (uint32_t j = 0; j < dim; ++j)
1944             {
1945                 css << ", " << GetTensorLayoutDim(dim)[j];
1946             }
1947             css << ");\n";
1948 
1949             css << "   t = sliceTensorLayoutNV(t";
1950             for (uint32_t j = 0; j < dim; ++j)
1951             {
1952                 css << ", " << GetTensorLayoutLoadOffsets(dim, i)[j] << ", " << GetTensorLayoutSpan(dim, i)[j];
1953             }
1954             css << ");\n";
1955             css << "   coopMatLoadTensorNV(tempmat" << i << ", inputA.x, 0, t"
1956                 << (isTensorLayoutClipTest(m_data.testType) ? ", v" : "") << ");\n";
1957 
1958             css << "   t = setTensorLayoutDimensionNV(t";
1959             for (uint32_t j = 0; j < dim; ++j)
1960             {
1961                 css << ", " << GetTensorLayoutDim(dim)[j];
1962             }
1963             css << ");\n";
1964 
1965             css << "   t = sliceTensorLayoutNV(t";
1966             for (uint32_t j = 0; j < dim; ++j)
1967             {
1968                 css << ", " << GetTensorLayoutStoreOffsets(dim, i)[j] << ", " << GetTensorLayoutSpan(dim, i)[j];
1969             }
1970             css << ");\n";
1971             css << "   coopMatStoreTensorNV(tempmat" << i << ", outputO.x, 0, t"
1972                 << (isTensorLayoutClipTest(m_data.testType) ? ", v" : "") << ");\n";
1973         }
1974     }
1975     break;
1976     case TT_SPACETODEPTH:
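        // Space-to-depth: an H x W x NumCh tensor is loaded through a 5D view
        // permuted as <0, 2, 1, 3, 4>, gathering each 2x2 spatial block into
        // the channel dimension, then stored as an (H/2 * W/2) x (4*NumCh)
        // matrix.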
1977         css << "   const uint32_t H = 32;\n"
1978                "   const uint32_t W = 32;\n"
1979                "   const uint32_t NumCh = 16;\n";
1980         css << "   tensorLayoutNV<3> t = createTensorLayoutNV(3);\n";
1981         css << "   tensorViewNV<5, true, 0, 2, 1, 3, 4> v = createTensorViewNV(5, true, 0, 2, 1, 3, 4);\n";
1982 
1983         {
1984             stringstream mattype;
1985             mattype << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", " << scopeStr
1986                     << ", (H/2 * W/2), (4*NumCh)," << sameType << ">";
1987             css << "   " << mattype.str() << " tempmat;\n";
1988             css << "   tempmat = " << mattype.str() << "(0.5);\n";
1989         }
1990 
1991         css << "   t = setTensorLayoutDimensionNV(t, H, W, NumCh);\n";
1992         css << "   v = setTensorViewDimensionsNV(v, H/2, 2, W/2, 2, NumCh);\n";
1993 
1994         css << "   coopMatLoadTensorNV(tempmat, inputA.x, 0, t, v);\n";
1995 
1996         css << "   tensorLayoutNV<2> t2 = createTensorLayoutNV(2);\n";
1997         css << "   t2 = setTensorLayoutDimensionNV(t2, H/2 * W/2, 4*NumCh);\n";
1998 
1999         css << "   coopMatStoreTensorNV(tempmat, outputO.x, 0, t2);\n";
2000         break;
2001     }
2002 
2003     if (!isTensorLayoutTest(m_data.testType))
2004     {
2005         if (m_data.testType == TT_COMPOSITE_ARRAY || m_data.testType == TT_MATRIXMULADD_ARRAY)
2006         {
2007             css << "   matOArr[0] = " << outputMatType.str() << "(0.0);\n";
2008             css << "   matO = matOArr[1];\n";
2009         }
2010 
2011         if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
2012         {
2013             string sharedStride = strides[3] + " / workgroupsX";
2014             if (m_data.addrMethod != ADDR_LINEAR)
2015             {
2016                 css << "   tensorLayoutNV<2> tensorLayout3 = createTensorLayoutNV(2);\n"
2017                        "   tensorLayout3 = setTensorLayoutDimensionNV(tensorLayout3, "
2018                     << heights[3] << ", " << sharedStride
2019                     << ");\n"
2020                        "   tensorLayout3 = sliceTensorLayoutNV(tensorLayout3, 0, span30, 0, span31);\n";
2021 
2022                 css << "   tensorViewNV<2, false, 1, 0> colMajorView3 = createTensorViewNV(2, false, 1, 0);\n";
2023 
2024                 if (m_data.scope == VK_SCOPE_WORKGROUP_KHR)
2025                 {
2026                     css << "   elementS3 = 0;\n";
2027                 }
2028                 css << "   coopMatStoreTensorNV(matO, sharedO, elementS3, tensorLayout3"
2029                     << (m_data.colMajor ? ", colMajorView3" : "") << ");\n";
2030             }
2031             else
2032             {
2033                 css << "   coopMatStore" << suffix << "(matO, sharedO, elementS3, " << sharedStride << divisorO << ", "
2034                     << colMajor << ");\n";
2035             }
2036             css << "   controlBarrier(" << scopeStr << ", " << scopeStr
2037                 << ", gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n";
2038             css << "   if (subgroupElect()) {\n";
2039             css << "       for (int i = 0; i < " << dims[3].rows
2040                 << "; ++i) {\n"
2041                    "       for (int j = 0; j < "
2042                 << dims[3].cols
2043                 << "; ++j) {\n"
2044                    "           int localElementInput = ("
2045                 << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ")"
2046                 << *divisors[3]
2047                 << ";\n"
2048                    "           int localElementShared = ("
2049                 << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j")
2050                 << ")" << *divisors[3]
2051                 << ";\n"
2052                    "           outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
2053                    "       }\n"
2054                    "       }\n";
2055             css << "   }\n";
2056             strides[3] = sharedStride;
2057         }
2058         else
2059         {
2060             if (m_data.addrMethod != ADDR_LINEAR)
2061             {
2062                 if (isClampTest(m_data.testType))
2063                 {
2064                     css << "   tensorLayoutNV<2, " << clampString << "> tensorLayout3 = createTensorLayoutNV(2, "
2065                         << clampString << ");\n";
2066 
2067                     // Shrink the width/height by 1
2068                     css << "   tensorLayout3 = setTensorLayoutDimensionNV(tensorLayout3, " << heights[3] << " - 1, "
2069                         << strides[3] << " - 1);\n";
2070                     css << "   tensorLayout3 = setTensorLayoutStrideNV(tensorLayout3, " << strides[3] << ", 1);\n";
2071                 }
2072                 else
2073                 {
2074                     css << "   tensorLayoutNV<2> tensorLayout3 = createTensorLayoutNV(2);\n"
2075                            "   tensorLayout3 = setTensorLayoutDimensionNV(tensorLayout3, "
2076                         << heights[3] << ", " << strides[3] << ");\n";
2077                 }
2078 
2079                 css << "   tensorLayout3 = sliceTensorLayoutNV(tensorLayout3, offset30, span30, offset31, span31);\n";
2080 
2081                 css << "   tensorViewNV<2, false, 1, 0> colMajorView3 = createTensorViewNV(2, false, 1, 0);\n";
2082                 css << "   coopMatStoreTensorNV(matO, outputO.x, 0, tensorLayout3"
2083                     << (m_data.colMajor ? ", colMajorView3" : "") << ");\n";
2084             }
2085             else
2086             {
2087                 css << "   coopMatStore" << suffix << "(matO, outputO.x, element3, " << strides[3] << divisorO << ", "
2088                     << colMajor << ");\n";
2089             }
2090         }
2091     }
2092 
2093     css << "}\n";
2094 
2095     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_6, 0u);
2096 
2097     programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
2098 }
2099 
2100 std::string getTypeName(const VkComponentTypeKHR type)
2101 {
2102     switch (type)
2103     {
2104     case VK_COMPONENT_TYPE_SINT8_KHR:
2105         return "s8";
2106     case VK_COMPONENT_TYPE_SINT32_KHR:
2107         return "s32";
2108     case VK_COMPONENT_TYPE_UINT8_KHR:
2109         return "u8";
2110     case VK_COMPONENT_TYPE_UINT32_KHR:
2111         return "u32";
2112     default:
2113         TCU_THROW(InternalError, "Support for this type is not implemented");
2114     }
2115 }
2116 
2117 size_t getTypeWidth(const VkComponentTypeKHR type)
2118 {
2119     switch (type)
2120     {
2121     case VK_COMPONENT_TYPE_SINT8_KHR:
2122         return 1;
2123     case VK_COMPONENT_TYPE_SINT32_KHR:
2124         return 4;
2125     case VK_COMPONENT_TYPE_UINT8_KHR:
2126         return 1;
2127     case VK_COMPONENT_TYPE_UINT32_KHR:
2128         return 4;
2129     default:
2130         TCU_THROW(InternalError, "Support for this type is not implemented");
2131     }
2132 }
2133 
2134 std::string getOppositeSignednessTypeName(const VkComponentTypeKHR type)
2135 {
2136     std::string result = getTypeName(type);
2137 
2138     if (result[0] == 'u')
2139         result[0] = 's';
2140     else if (result[0] == 's')
2141         result[0] = 'u';
2142     else
2143         TCU_THROW(InternalError, "Support for this type is not implemented");
2144 
2145     return result;
2146 }
2147 
2148 void CooperativeMatrixTestCase::initProgramsSPIRV(SourceCollections &programCollection) const
2149 {
2150     std::string dims[4] = {
2151         m_data.colMajor ? "M" : "K",
2152         m_data.colMajor ? "K" : "N",
2153         m_data.colMajor ? "M" : "N",
2154         m_data.colMajor ? "M" : "N",
2155     };
2156     //  #version 450 core
2157     //  #pragma use_vulkan_memory_model
2158     //  #extension GL_KHR_shader_subgroup_basic : enable
2159     //  #extension GL_KHR_memory_scope_semantics : enable
2160     //  #extension GL_KHR_cooperative_matrix : enable
2161     //  #extension GL_EXT_shader_explicit_arithmetic_types : enable
2162     //  #extension GL_EXT_buffer_reference : enable
2163     //  // strides overridden by spec constants
2164     //  layout(constant_id = 6) const int M = 1;
2165     //  layout(constant_id = 7) const int N = 1;
2166     //  layout(constant_id = 8) const int K = 1;
2167     //  layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
2168     //  const int workgroupsX = 4;
2169     //  const uvec2 subgroupsPerWG = uvec2(2, 2);
2170     //  layout(set=0, binding=0) coherent buffer InputA { int8_t x[]; } inputA;
2171     //  layout(set=0, binding=1) coherent buffer InputB { int8_t x[]; } inputB;
2172     //  layout(set=0, binding=2) coherent buffer InputC { int32_t x[]; } inputC;
2173     //  layout(set=0, binding=3) coherent buffer Output { int32_t x[]; } outputO;
2174     //  coopmat<int8_t, gl_ScopeSubgroup, M, K, gl_MatrixUseA> matA;
2175     //  coopmat<int8_t, gl_ScopeSubgroup, K, N, gl_MatrixUseB> matB;
2176     //  coopmat<int32_t, gl_ScopeSubgroup, M, N, gl_MatrixUseAccumulator> matC;
2177     //  coopmat<int32_t, gl_ScopeSubgroup, M, N, gl_MatrixUseAccumulator> matO;
2178     //  void main()
2179     //  {
2180     //     uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);
2181     //     uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;
2182     //     uint element0 = (K * 8 * M * matrixID.y + K * matrixID.x);
2183     //     uint element1 = (N * 8 * K * matrixID.y + N * matrixID.x);
2184     //     uint element2 = (N * 8 * M * matrixID.y + N * matrixID.x);
2185     //     uint element3 = (N * 8 * M * matrixID.y + N * matrixID.x);
2186     //     uint elementS0, elementS1, elementS2, elementS3;
2187     //     coopMatLoad(matA, inputA.x, element0, K * 8, gl_CooperativeMatrixLayoutRowMajor);
2188     //     coopMatLoad(matB, inputB.x, element1, N * 8, gl_CooperativeMatrixLayoutRowMajor);
2189     //     coopMatLoad(matC, inputC.x, element2, N * 8, gl_CooperativeMatrixLayoutRowMajor);
2190     //     matO = coopMatMulAdd(matA, matB, matC);
2191     //     coopMatStore(matO, outputO.x, element3, N * 8, gl_CooperativeMatrixLayoutRowMajor);
2192     //  }
2193     const char *shaderTemplateGlobalString =
2194         "OpCapability Shader\n"
2195         "OpCapability Int8\n"
2196         "OpCapability GroupNonUniform\n"
2197         "OpCapability StorageBuffer8BitAccess\n"
2198         "OpCapability VulkanMemoryModel\n"
2199         "OpCapability CooperativeMatrixKHR\n"
2200         "OpExtension \"SPV_KHR_8bit_storage\"\n"
2201         "OpExtension \"SPV_KHR_cooperative_matrix\"\n"
2202         "OpExtension \"SPV_KHR_vulkan_memory_model\"\n"
2203         "%1 = OpExtInstImport \"GLSL.std.450\"\n"
2204         "OpMemoryModel Logical Vulkan\n"
2205         "OpEntryPoint GLCompute %main \"main\" %gl_SubgroupId %gl_WorkgroupID\n"
2206         "OpExecutionMode %main LocalSize 1 1 1\n"
2207         "OpDecorate %gl_SubgroupId BuiltIn SubgroupId\n"
2208         "OpDecorate %gl_WorkgroupID BuiltIn WorkgroupId\n"
2209 
2210         "OpDecorate %local_size_x SpecId 0\n"
2211         "OpDecorate %local_size_y SpecId 1\n"
2212         "OpDecorate %M            SpecId 6\n"
2213         "OpDecorate %N            SpecId 7\n"
2214         "OpDecorate %K            SpecId 8\n"
2215 
2216         "OpDecorate %inputA_x ArrayStride ${STRIDE_A}\n"
2217         "OpMemberDecorate %inputA_struct 0 Offset 0\n"
2218         "OpDecorate %inputA_struct Block\n"
2219         "OpDecorate %inputA_var DescriptorSet 0\n"
2220         "OpDecorate %inputA_var Binding 0\n"
2221 
2222         "OpDecorate %inputB_x ArrayStride ${STRIDE_B}\n"
2223         "OpMemberDecorate %inputB_struct 0 Offset 0\n"
2224         "OpDecorate %inputB_struct Block\n"
2225         "OpDecorate %inputB_var DescriptorSet 0\n"
2226         "OpDecorate %inputB_var Binding 1\n"
2227 
2228         "OpDecorate %inputC_x ArrayStride ${STRIDE_C}\n"
2229         "OpMemberDecorate %inputC_struct 0 Offset 0\n"
2230         "OpDecorate %inputC_struct Block\n"
2231         "OpDecorate %inputC_var DescriptorSet 0\n"
2232         "OpDecorate %inputC_var Binding 2\n"
2233 
2234         "OpDecorate %outputO_x ArrayStride ${STRIDE_R}\n"
2235         "OpMemberDecorate %outputO_struct 0 Offset 0\n"
2236         "OpDecorate %outputO_struct Block\n"
2237         "OpDecorate %outputO_var DescriptorSet 0\n"
2238         "OpDecorate %outputO_var Binding 3\n"
2239 
2240         "OpDecorate %wg_size BuiltIn WorkgroupSize\n"
2241         "%void            = OpTypeVoid\n"
2242         "%voidfunc        = OpTypeFunction %void\n"
2243         "%u8              = OpTypeInt 8 0\n"
2244         "%s8              = OpTypeInt 8 1\n"
2245         "%u32             = OpTypeInt 32 0\n"
2246         "%s32             = OpTypeInt 32 1\n"
2247         "%uvec2           = OpTypeVector %u32 2\n"
2248         "%uvec3           = OpTypeVector %u32 3\n"
2249         "%piu32           = OpTypePointer Input %u32\n"
2250         "%gl_SubgroupId   = OpVariable %piu32 Input\n"
2251         "%c0u             = OpConstant %u32 0\n"
2252         "%c1u             = OpConstant %u32 1\n"
2253         "%c2u             = OpConstant %u32 2\n"
2254         "%c3u             = OpConstant %u32 3\n"
2255         "%c5u             = OpConstant %u32 5\n"
2256         "%c8s             = OpConstant %s32 8\n"
2257         "%c0s             = OpConstant %s32 0\n"
2258         "%layout          = OpConstant %s32 ${LAYOUT}\n"
2259         "%piuvec3         = OpTypePointer Input %uvec3\n"
2260         "%gl_WorkgroupID  = OpVariable %piuvec3 Input\n"
2261         "%csubgroupsPerWG = OpConstantComposite %uvec2 %c2u %c2u\n"
2262         "%K               = OpSpecConstant %s32 1\n"
2263         "%M               = OpSpecConstant %s32 1\n"
2264         "%N               = OpSpecConstant %s32 1\n"
2265         "%Ku              = OpSpecConstantOp %u32 IAdd %K %c0u\n"
2266         "%Mu              = OpSpecConstantOp %u32 IAdd %M %c0u\n"
2267         "%Nu              = OpSpecConstantOp %u32 IAdd %N %c0u\n"
2268 
2269         "%k8              = OpSpecConstantOp %s32 IMul %K %c8s\n"
2270         "%mk8             = OpSpecConstantOp %s32 IMul %k8 %M\n"
2271         "%mk8u            = OpSpecConstantOp %u32 IAdd %mk8 %c0u\n"
2272 
2273         "%n8              = OpSpecConstantOp %s32 IMul %N %c8s\n"
2274         "%nk8             = OpSpecConstantOp %s32 IMul %n8 %K\n"
2275         "%nk8u            = OpSpecConstantOp %u32 IAdd %nk8 %c0u\n"
2276 
2277         "%nm8             = OpSpecConstantOp %s32 IMul %n8 %M\n"
2278         "%nm8u            = OpSpecConstantOp %u32 IAdd %nm8 %c0u\n"
2279 
2280         "%strideAs        = OpSpecConstantOp %s32 IMul %${MULT_A} %c8s\n"
2281         "%strideA         = OpSpecConstantOp %u32 IAdd %strideAs %c0u\n"
2282         "%strideBs        = OpSpecConstantOp %s32 IMul %${MULT_B} %c8s\n"
2283         "%strideB         = OpSpecConstantOp %u32 IAdd %strideBs %c0u\n"
2284         "%strideCs        = OpSpecConstantOp %s32 IMul %${MULT_C} %c8s\n"
2285         "%strideC         = OpSpecConstantOp %u32 IAdd %strideCs %c0u\n"
2286         "%strideRs        = OpSpecConstantOp %s32 IMul %${MULT_R} %c8s\n"
2287         "%strideR         = OpSpecConstantOp %u32 IAdd %strideRs %c0u\n"
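        // Each stride spec constant is <dim> * 8 elements: the fixed grid of
        // 4 workgroups x 2 subgroups per workgroup in each axis gives 8 tiles
        // per row, matching the "K * 8" strides in the GLSL sketch above.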
2288 
2289         "%psbmat_s8       = OpTypePointer StorageBuffer %s8\n"
2290         "%psbmat_s32      = OpTypePointer StorageBuffer %s32\n"
2291         "%psbmat_u8       = OpTypePointer StorageBuffer %u8\n"
2292         "%psbmat_u32      = OpTypePointer StorageBuffer %u32\n"
2293 
2294         "%matA            = OpTypeCooperativeMatrixKHR %${A_ELEM_TYPE} %c3u %M %K %c0u\n"
2295         "%inputA_x        = OpTypeRuntimeArray %${A_ELEM_TYPE}\n"
2296         "%inputA_struct   = OpTypeStruct %inputA_x\n"
2297         "%inputA_ptr      = OpTypePointer StorageBuffer %inputA_struct\n"
2298         "%inputA_var      = OpVariable %inputA_ptr StorageBuffer\n"
2299 
2300         "%matB            = OpTypeCooperativeMatrixKHR %${B_ELEM_TYPE} %c3u %K %N %c1u\n"
2301         "%inputB_x        = OpTypeRuntimeArray %${B_ELEM_TYPE}\n"
2302         "%inputB_struct   = OpTypeStruct %inputB_x\n"
2303         "%inputB_ptr      = OpTypePointer StorageBuffer %inputB_struct\n"
2304         "%inputB_var      = OpVariable %inputB_ptr StorageBuffer\n"
2305 
2306         "%matS            = OpTypeCooperativeMatrixKHR %${S_ELEM_TYPE} %c3u %M %N %c2u\n"
2307         "%matU            = OpTypeCooperativeMatrixKHR %${U_ELEM_TYPE} %c3u %M %N %c2u\n"
2308 
2309         "%inputC_x        = OpTypeRuntimeArray %${C_ELEM_TYPE}\n"
2310         "%inputC_struct   = OpTypeStruct %inputC_x\n"
2311         "%inputC_ptr      = OpTypePointer StorageBuffer %inputC_struct\n"
2312         "%inputC_var      = OpVariable %inputC_ptr StorageBuffer\n"
2313 
2314         "%outputO_x       = OpTypeRuntimeArray %${R_ELEM_TYPE}\n"
2315         "%outputO_struct  = OpTypeStruct %outputO_x\n"
2316         "%outputO_ptr     = OpTypePointer StorageBuffer %outputO_struct\n"
2317         "%outputO_var     = OpVariable %outputO_ptr StorageBuffer\n"
2318 
2319         "%local_size_x           = OpSpecConstant %u32 1\n"
2320         "%local_size_y           = OpSpecConstant %u32 1\n"
2321         "%wg_size                = OpSpecConstantComposite %uvec3 %local_size_x %local_size_y %c1u\n"
2322         "%main                   = OpFunction %void None %voidfunc\n"
2323         "%label                  = OpLabel\n"
2324         "%gl_SubgroupId_         = OpLoad %u32 %gl_SubgroupId\n"
2325         "%subgroupXY_x           = OpUMod %u32 %gl_SubgroupId_ %c2u\n"
2326         "%subgroupXY_y           = OpUDiv %u32 %gl_SubgroupId_ %c2u\n"
2327         "%subgroupXY_uvec2       = OpCompositeConstruct %uvec2 %subgroupXY_x %subgroupXY_y\n"
2328         "%gl_WorkgroupID_uvec3   = OpLoad %uvec3 %gl_WorkgroupID\n"
2329         "%gl_WorkgroupID_uvec2   = OpVectorShuffle %uvec2 %gl_WorkgroupID_uvec3 %gl_WorkgroupID_uvec3 0 1\n"
2330         "%2xgl_WorkgroupID_uvec2 = OpIMul %uvec2 %gl_WorkgroupID_uvec2 %csubgroupsPerWG\n"
2331         "%matrixID               = OpIAdd %uvec2 %2xgl_WorkgroupID_uvec2 %subgroupXY_uvec2\n"
2332         "%matrixID_x             = OpCompositeExtract %u32 %matrixID 0\n"
2333         "%matrixID_y             = OpCompositeExtract %u32 %matrixID 1\n"
2334 
2335         "%e0a      = OpIMul %u32 %mk8u %matrixID_y\n"
2336         "%e0b      = OpIMul %u32 %${MULT_A}u %matrixID_x\n"
2337         "%element0 = OpIAdd %u32 %e0a %e0b\n"
2338 
2339         "%e1a      = OpIMul %u32 %nk8u %matrixID_y\n"
2340         "%e1b      = OpIMul %u32 %${MULT_B}u %matrixID_x\n"
2341         "%element1 = OpIAdd %u32 %e1a %e1b\n"
2342 
2343         "%e2a      = OpIMul %u32 %nm8u %matrixID_y\n"
2344         "%e2b      = OpIMul %u32 %${MULT_C}u %matrixID_x\n"
2345         "%element2 = OpIAdd %u32 %e2a %e2b\n"
2346 
2347         "%e3a      = OpIMul %u32 %nm8u %matrixID_y\n"
2348         "%e3b      = OpIMul %u32 %${MULT_R}u %matrixID_x\n"
2349         "%element3 = OpIAdd %u32 %e3a %e3b\n"
2350 
2351         "%Aij      = OpAccessChain %psbmat_${A_ELEM_TYPE} %inputA_var %c0s %element0\n"
2352         "%Aij_mat  = OpCooperativeMatrixLoadKHR %matA %Aij %layout %strideA MakePointerVisible|NonPrivatePointer %c5u\n"
2353 
2354         "%Bij      = OpAccessChain %psbmat_${B_ELEM_TYPE} %inputB_var %c0s %element1\n"
2355         "%Bij_mat  = OpCooperativeMatrixLoadKHR %matB %Bij %layout %strideB MakePointerVisible|NonPrivatePointer %c5u\n"
2356 
2357         "%Cij      = OpAccessChain %psbmat_${C_ELEM_TYPE} %inputC_var %c0s %element2\n"
2358         "%Cij_mat  = OpCooperativeMatrixLoadKHR %${C_TYPE} %Cij %layout %strideC MakePointerVisible|NonPrivatePointer "
2359         "%c5u\n"
2360 
2361         "%matR     = OpCooperativeMatrixMulAddKHR %${R_TYPE} %Aij_mat %Bij_mat %Cij_mat ${SIGNEDNESS}\n"
2362 
2363         "%Rij_mat  = OpAccessChain %psbmat_${R_ELEM_TYPE} %outputO_var %c0s %element3\n"
2364         "OpCooperativeMatrixStoreKHR %Rij_mat %matR %layout %strideR MakePointerAvailable|NonPrivatePointer %c5u\n"
2365 
2366         "OpReturn\n"
2367         "OpFunctionEnd\n";
2368     const tcu::StringTemplate shaderTemplateGlobal(shaderTemplateGlobalString);
2369     const vk::SpirVAsmBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3);
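    // TT_MATRIXMULADD_CROSS deliberately declares some matrix element types
    // with the opposite signedness of the backing data; the SIGNEDNESS
    // operand bits on OpCooperativeMatrixMulAddKHR carry the intended
    // interpretation.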
2370     const string spirMatATypeName =
2371         m_data.useType == UT_KHR_A ? getOppositeSignednessTypeName(m_data.inputType) : getTypeName(m_data.inputType);
2372     const string spirMatBTypeName =
2373         m_data.useType == UT_KHR_B ? getOppositeSignednessTypeName(m_data.inputType) : getTypeName(m_data.inputType);
2374     const string spirMatCTypeName = m_data.useType == UT_KHR_C ? (isSIntType(m_data.outputType) ? "matU" : "matS") :
2375                                                                  (isSIntType(m_data.outputType) ? "matS" : "matU");
2376     const string spirMatRTypeName = m_data.useType == UT_KHR_Result ?
2377                                         (isSIntType(m_data.outputType) ? "matU" : "matS") :
2378                                         (isSIntType(m_data.outputType) ? "matS" : "matU");
2379     const string spirMatSTypeName = isSIntType(m_data.outputType) ? getTypeName(m_data.outputType) :
2380                                                                     getOppositeSignednessTypeName(m_data.outputType);
2381     const string spirMatUTypeName = isSIntType(m_data.outputType) ? getOppositeSignednessTypeName(m_data.outputType) :
2382                                                                     getTypeName(m_data.outputType);
2383     string signedness             = string(isSIntType(m_data.inputType) ? "|MatrixASignedComponentsKHR" : "") +
2384                         string(isSIntType(m_data.inputType) ? "|MatrixBSignedComponentsKHR" : "") +
2385                         string(isSIntType(m_data.outputType) ? "|MatrixCSignedComponentsKHR" : "") +
2386                         string(isSIntType(m_data.outputType) ? "|MatrixResultSignedComponentsKHR" : "");
2387     map<string, string> attributes;
2388 
2389     attributes["A_ELEM_TYPE"] = spirMatATypeName;
2390     attributes["B_ELEM_TYPE"] = spirMatBTypeName;
2391     attributes["C_ELEM_TYPE"] = getTypeName(m_data.outputType);
2392     attributes["R_ELEM_TYPE"] = getTypeName(m_data.outputType);
2393     attributes["S_ELEM_TYPE"] = spirMatSTypeName;
2394     attributes["U_ELEM_TYPE"] = spirMatUTypeName;
2395     attributes["C_TYPE"]      = spirMatCTypeName;
2396     attributes["R_TYPE"]      = spirMatRTypeName;
2397     attributes["SIGNEDNESS"]  = signedness.empty() ? "" : signedness.substr(1);
2398     attributes["MULT_A"]      = dims[0];
2399     attributes["MULT_B"]      = dims[1];
2400     attributes["MULT_C"]      = dims[2];
2401     attributes["MULT_R"]      = dims[3];
2402     attributes["STRIDE_A"]    = de::toString(getTypeWidth(m_data.inputType));
2403     attributes["STRIDE_B"]    = de::toString(getTypeWidth(m_data.inputType));
2404     attributes["STRIDE_C"]    = de::toString(getTypeWidth(m_data.outputType));
2405     attributes["STRIDE_R"]    = de::toString(getTypeWidth(m_data.outputType));
2406     attributes["LAYOUT"]      = m_data.colMajor ? "1" : "0";
2407 
2408     const std::string shaderCode = shaderTemplateGlobal.specialize(attributes);
2409 
2410     programCollection.spirvAsmSources.add("test") << shaderCode << buildOptions;
2411 }
2412 
2413 void CooperativeMatrixTestCase::initPrograms(SourceCollections &programCollection) const
2414 {
2415     if (m_data.testType == TT_MATRIXMULADD_CROSS)
2416         initProgramsSPIRV(programCollection);
2417     else
2418         initProgramsGLSL(programCollection);
2419 }
2420 
2421 TestInstance *CooperativeMatrixTestCase::createInstance(Context &context) const
2422 {
2423     return new CooperativeMatrixTestInstance(context, m_data);
2424 }
2425 
2426 void setDataFloat(void *base, VkComponentTypeKHR dt, uint32_t i, float value)
2427 {
2428     if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
2429     {
2430         ((float *)base)[i] = value;
2431     }
2432     else
2433     {
2434         DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
2435         ((tcu::float16_t *)base)[i] = tcu::Float16(value).bits();
2436     }
2437 }
2438 
2439 float getDataFloat(void *base, VkComponentTypeKHR dt, uint32_t i)
2440 {
2441     if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
2442     {
2443         return ((float *)base)[i];
2444     }
2445     else
2446     {
2447         DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
2448         return tcu::Float16(((const tcu::float16_t *)base)[i]).asFloat();
2449     }
2450 }
2451 
2452 void setDataInt(void *base, VkComponentTypeKHR dt, uint32_t i, uint32_t value)
2453 {
2454     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
2455 
2456     switch (dt)
2457     {
2458     case VK_COMPONENT_TYPE_UINT8_KHR:
2459         ((uint8_t *)base)[i] = (uint8_t)value;
2460         break;
2461     case VK_COMPONENT_TYPE_UINT16_KHR:
2462         ((uint16_t *)base)[i] = (uint16_t)value;
2463         break;
2464     case VK_COMPONENT_TYPE_UINT32_KHR:
2465         ((uint32_t *)base)[i] = (uint32_t)value;
2466         break;
2467     case VK_COMPONENT_TYPE_SINT8_KHR:
2468         ((int8_t *)base)[i] = (int8_t)value;
2469         break;
2470     case VK_COMPONENT_TYPE_SINT16_KHR:
2471         ((int16_t *)base)[i] = (int16_t)value;
2472         break;
2473     case VK_COMPONENT_TYPE_SINT32_KHR:
2474         ((int32_t *)base)[i] = (int32_t)value;
2475         break;
2476     default:
2477         TCU_THROW(InternalError, "Unsupported type");
2478     }
2479 }
2480 
2481 uint32_t getDataInt(void *base, VkComponentTypeKHR dt, uint32_t i)
2482 {
2483     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
2484 
2485     switch (dt)
2486     {
2487     case VK_COMPONENT_TYPE_UINT8_KHR:
2488         return ((uint8_t *)base)[i];
2489     case VK_COMPONENT_TYPE_UINT16_KHR:
2490         return ((uint16_t *)base)[i];
2491     case VK_COMPONENT_TYPE_UINT32_KHR:
2492         return ((uint32_t *)base)[i];
2493     case VK_COMPONENT_TYPE_SINT8_KHR:
2494         return ((int8_t *)base)[i];
2495     case VK_COMPONENT_TYPE_SINT16_KHR:
2496         return ((int16_t *)base)[i];
2497     case VK_COMPONENT_TYPE_SINT32_KHR:
2498         return ((int32_t *)base)[i];
2499     default:
2500         TCU_THROW(InternalError, "Unsupported type");
2501     }
2502 }
2503 
2504 template <typename T>
2505 T getDataConvertedToT(void *base, VkComponentTypeKHR dt, uint32_t i)
2506 {
2507     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
2508 
2509     switch (dt)
2510     {
2511     case VK_COMPONENT_TYPE_UINT8_KHR:
2512         return (T)((uint8_t *)base)[i];
2513     case VK_COMPONENT_TYPE_UINT16_KHR:
2514         return (T)((uint16_t *)base)[i];
2515     case VK_COMPONENT_TYPE_UINT32_KHR:
2516         return (T)((uint32_t *)base)[i];
2517     case VK_COMPONENT_TYPE_SINT8_KHR:
2518         return (T)((int8_t *)base)[i];
2519     case VK_COMPONENT_TYPE_SINT16_KHR:
2520         return (T)((int16_t *)base)[i];
2521     case VK_COMPONENT_TYPE_SINT32_KHR:
2522         return (T)((int32_t *)base)[i];
2523     case VK_COMPONENT_TYPE_FLOAT32_KHR:
2524     {
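        // Converting a negative float to an unsigned integer type is
        // undefined behaviour in C++, so clamp to zero first when T is
        // unsigned.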
2525         float temp = ((float *)base)[i];
2526         if (std::numeric_limits<T>::min() == 0)
2527             temp = std::max(temp, 0.0f);
2528         return (T)temp;
2529     }
2530     case VK_COMPONENT_TYPE_FLOAT16_KHR:
2531     {
2532         float temp = tcu::Float16(((tcu::float16_t *)base)[i]).asFloat();
2533         if (std::numeric_limits<T>::min() == 0)
2534             temp = std::max(temp, 0.0f);
2535         return (T)temp;
2536     }
2537     default:
2538         TCU_THROW(InternalError, "Unsupported type");
2539     }
2540 }
2541 
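// Saturating add that avoids signed overflow by checking the remaining
// headroom before performing the addition.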
2542 template <typename T>
2543 T satAdd(T a, T b)
2544 {
2545     if (a > 0)
2546     {
2547         if (b > std::numeric_limits<T>::max() - a)
2548             return std::numeric_limits<T>::max();
2549     }
2550     else if (b < std::numeric_limits<T>::min() - a)
2551     {
2552         return std::numeric_limits<T>::min();
2553     }
2554 
2555     return (T)(a + b);
2556 }
2557 
2558 uint32_t satAddData(VkComponentTypeKHR dt, uint32_t a, uint32_t b)
2559 {
2560     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
2561 
2562     switch (dt)
2563     {
2564     case VK_COMPONENT_TYPE_UINT8_KHR:
2565         return deMinu32(a + b, std::numeric_limits<uint8_t>::max());
2566     case VK_COMPONENT_TYPE_UINT16_KHR:
2567         return deMinu32(a + b, std::numeric_limits<uint16_t>::max());
2568     case VK_COMPONENT_TYPE_UINT32_KHR:
2569         return (a + b >= a) ? a + b : std::numeric_limits<uint32_t>::max();
2570     case VK_COMPONENT_TYPE_SINT8_KHR:
2571         return (uint32_t)satAdd((int8_t)a, (int8_t)b);
2572     case VK_COMPONENT_TYPE_SINT16_KHR:
2573         return (uint32_t)satAdd((int16_t)a, (int16_t)b);
2574     case VK_COMPONENT_TYPE_SINT32_KHR:
2575         return (uint32_t)satAdd((int32_t)a, (int32_t)b);
2576     default:
2577         TCU_THROW(InternalError, "Unsupported type");
2578     }
2579 }
2580 
2581 uint32_t getLimit(VkComponentTypeKHR dt, bool positive)
2582 {
2583     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
2584 
2585     switch (dt)
2586     {
2587     case VK_COMPONENT_TYPE_UINT8_KHR:
2588         return uint32_t(positive ? std::numeric_limits<uint8_t>::max() : std::numeric_limits<uint8_t>::min());
2589     case VK_COMPONENT_TYPE_UINT16_KHR:
2590         return uint32_t(positive ? std::numeric_limits<uint16_t>::max() : std::numeric_limits<uint16_t>::min());
2591     case VK_COMPONENT_TYPE_UINT32_KHR:
2592         return uint32_t(positive ? std::numeric_limits<uint32_t>::max() : std::numeric_limits<uint32_t>::min());
2593     case VK_COMPONENT_TYPE_SINT8_KHR:
2594         return uint32_t(positive ? std::numeric_limits<int8_t>::max() : std::numeric_limits<int8_t>::min());
2595     case VK_COMPONENT_TYPE_SINT16_KHR:
2596         return uint32_t(positive ? std::numeric_limits<int16_t>::max() : std::numeric_limits<int16_t>::min());
2597     case VK_COMPONENT_TYPE_SINT32_KHR:
2598         return uint32_t(positive ? std::numeric_limits<int32_t>::max() : std::numeric_limits<int32_t>::min());
2599     default:
2600         TCU_THROW(InternalError, "Unsupported type");
2601     }
2602 }
2603 
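// Write 'count' elements with stride 'step' starting at element 'start',
// all zero except the one at index 'at', which is set to 'val'. Used below
// to build the near-identity rows/columns for the saturation tests.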
void setSingleElementInt(void *data, VkComponentTypeKHR dt, uint32_t start, uint32_t count, uint32_t step, uint32_t at,
                         uint32_t val)
{
    for (uint32_t i = 0; i < count; i++)
        setDataInt(data, dt, start + i * step, (i == at) ? val : 0);
}

#ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
string dumpWholeMatrix(void *data, VkComponentTypeKHR dt, bool colMajor, uint32_t matrixElemCount, uint32_t stride)
{
    const uint32_t rowsCount = colMajor ? stride : matrixElemCount / stride;
    const uint32_t colsCount = colMajor ? matrixElemCount / stride : stride;
    bool floatType           = isFloatType(dt);
    bool sIntType            = isSIntType(dt);
    std::stringstream ss;

    DE_ASSERT(rowsCount * colsCount == matrixElemCount);

    for (uint32_t r = 0; r < rowsCount; r++)
    {
        for (uint32_t c = 0; c < colsCount; c++)
        {
            const uint32_t i = colMajor ? rowsCount * c + r : colsCount * r + c;

            if (floatType)
                ss << getDataFloat(data, dt, i) << "\t";
            else if (sIntType)
                ss << (int32_t)getDataInt(data, dt, i) << "\t";
            else
                ss << getDataInt(data, dt, i) << "\t";
        }

        ss << std::endl;
    }

    return ss.str();
}
#endif

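// Main test body: enumerate the matrix sizes supported for this case, fill
// the input buffers, dispatch the compute shader once per size, and check
// the results against a host-computed reference.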
tcu::TestStatus CooperativeMatrixTestInstance::iterate(void)
{
    const DeviceInterface &vk = m_context.getDeviceInterface();
    const VkDevice device     = m_context.getDevice();
    Allocator &allocator      = m_context.getDefaultAllocator();
    MemoryRequirement memoryDeviceAddress =
        m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
                m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ?
            MemoryRequirement::DeviceAddress :
            MemoryRequirement::Any;
    qpTestResult finalres       = QP_TEST_RESULT_NOT_SUPPORTED;
    tcu::TestLog &log           = m_context.getTestContext().getLog();
    const bool saturated        = (m_data.testType == TT_MATRIXMULADD_SATURATED);
    const uint32_t subgroupSize = getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode);
    const float epsilon         = 1.0f / float(1ull << 17); // 1/131072, approximately 7.6e-6
    vk::VkPhysicalDeviceProperties vkproperties;
    const bool coopMat2Supported = m_context.isDeviceFunctionalitySupported("VK_NV_cooperative_matrix2");

    m_context.getInstanceInterface().getPhysicalDeviceProperties(m_context.getPhysicalDevice(), &vkproperties);

    deRandom rnd;
    deRandom_init(&rnd, 1234);

    std::vector<VkCooperativeMatrixPropertiesKHR> properties =
        getCooperativeMatrixPropertiesConverted(m_context, isKhr(m_data.useType));

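    // One MxNxK size (plus the required workgroup size) to test. operator<
    // orders by workgroup size, then M, N, K, so the set of sizes iterates
    // deterministically.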
    struct TestTuple
    {
        TestTuple()
        {
        }
        TestTuple(uint32_t m, uint32_t n, uint32_t k, uint32_t w) : M(m), N(n), K(k), workgroupSize(w)
        {
        }

        bool operator<(const TestTuple &other) const
        {
            return workgroupSize < other.workgroupSize || (workgroupSize == other.workgroupSize && M < other.M) ||
                   (workgroupSize == other.workgroupSize && M == other.M && N < other.N) ||
                   (workgroupSize == other.workgroupSize && M == other.M && N == other.N && K < other.K);
        }

        uint32_t M, N, K, workgroupSize;
    };

    std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexibleProperties;
    if (m_context.getCooperativeMatrix2FeaturesNV().cooperativeMatrixFlexibleDimensions)
    {
        uint32_t flexiblePropertyCount = 0;

        const InstanceInterface &vki = m_context.getInstanceInterface();
        VK_CHECK(vki.getPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(m_context.getPhysicalDevice(),
                                                                                      &flexiblePropertyCount, nullptr));

        if (flexiblePropertyCount > 0)
        {
            const VkCooperativeMatrixFlexibleDimensionsPropertiesNV sample = initVulkanStructureConst();

            flexibleProperties.resize(flexiblePropertyCount, sample);

            VK_CHECK(vki.getPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(
                m_context.getPhysicalDevice(), &flexiblePropertyCount, flexibleProperties.data()));
        }
        else
        {
            flexibleProperties.clear();
        }
    }

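    // Build the set of sizes to test: tensor-layout tests use a fixed
    // placeholder size per supported workgroup size, flexible-dimensions
    // tests derive sizes from the reported granularities, and the remaining
    // tests take the fixed sizes advertised in VkCooperativeMatrixPropertiesKHR.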
    set<TestTuple> testSizes;

    if (isTensorLayoutTest(m_data.testType))
    {
        for (auto const &prop : flexibleProperties)
        {
            auto const *p = &prop;
            if (m_data.scope == p->scope)
            {
                // Placeholder matrix size; the test defines the real sizes elsewhere.
                testSizes.insert(TestTuple(32, 32, 32, p->workgroupInvocations));
            }
        }
    }
    else if (m_data.useType != UT_NV)
    {
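        // shmemOK estimates the shared memory needed by the largest matrix
        // for a candidate size and rejects sizes exceeding
        // maxComputeSharedMemorySize (minus any implementation-reserved
        // amount) when the matrices live in workgroup storage.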
        auto shmemOK = [&](uint32_t M, uint32_t N, uint32_t K) -> bool
        {
            uint32_t maxMatrixElements = max(M * N, max(M * K, K * N));

            if (isReduce2x2(m_data.testType))
            {
                // The A (input) matrix is 4x larger
                maxMatrixElements *= 4;
            }
            if (isReduceChangeDim(m_data.testType))
            {
                // The result matrix is 3-9x larger
                maxMatrixElements *= reduceMScale(m_data.testType) * reduceNScale(m_data.testType);
            }

            if (m_data.scope == VK_SCOPE_SUBGROUP_KHR)
            {
                maxMatrixElements *= m_data.subgroupsPerWorkgroupX * m_data.subgroupsPerWorkgroupY;
            }

            int32_t maxSharedMem = m_context.getDeviceProperties().limits.maxComputeSharedMemorySize;

            if (coopMat2Supported && m_data.scope == VK_SCOPE_WORKGROUP_KHR)
            {
                // reserved for implementation
                maxSharedMem -=
                    m_context.getCooperativeMatrix2PropertiesNV().cooperativeMatrixWorkgroupScopeReservedSharedMemory;
            }

            if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
            {
                return (int32_t)(maxMatrixElements * 2 *
                                 (componentTypeInfo[m_data.inputType].bits * m_data.inputComponentCount +
                                  componentTypeInfo[m_data.outputType].bits * m_data.outputComponentCount) /
                                 8) <= maxSharedMem;
            }

            return true;
        };
        if (m_context.getCooperativeMatrix2FeaturesNV().cooperativeMatrixFlexibleDimensions)
        {
            const auto isMMA    = isMatrixMulAddOp(m_data.testType);
            const auto isMMASat = m_data.testType == TT_MATRIXMULADD_SATURATED;

            std::vector<TestTuple> sizes;
            for (auto const &prop : flexibleProperties)
            {
                auto const *p = &prop;

                uint32_t MGranularity = 0;
                uint32_t NGranularity = 0;
                uint32_t KGranularity = 0;
                bool ok               = false;

                if (p->scope != m_data.scope)
                    continue;

                if (isMMA && isMMASat != static_cast<bool>(p->saturatingAccumulation))
                    continue;

                if (isMMA)
                {
                    if (p->AType == m_data.inputType && p->BType == m_data.inputType && p->CType == m_data.outputType &&
                        p->ResultType == m_data.outputType)
                    {
                        ok           = true;
                        MGranularity = p->MGranularity;
                        NGranularity = p->NGranularity;
                        KGranularity = p->KGranularity;
                    }
                }
                else
                {
                    const VkComponentTypeKHR types[2] = {m_data.inputType, m_data.outputType};
                    UseType uses[2]                   = {m_data.useType, m_data.useType};
                    if (m_data.testType == TT_CONVERT_ACC_TO_A)
                    {
                        uses[1] = UT_KHR_A;
                    }
                    else if (m_data.testType == TT_CONVERT_ACC_TO_B || m_data.testType == TT_TRANSPOSE_ACC_TO_B)
                    {
                        uses[1] = UT_KHR_B;
                    }

                    auto const &SetGranularity = [&](const VkCooperativeMatrixFlexibleDimensionsPropertiesNV *p2,
                                                     VkComponentTypeKHR type, UseType use)
                    {
                        ok = false;
                        switch (use)
                        {
                        case UT_NV:
                            break;
                        case UT_KHR_A:
                        {
                            if (p2->AType == type)
                            {
                                ok           = true;
                                MGranularity = std::max(MGranularity, p2->MGranularity);
                                NGranularity = std::max(NGranularity, p2->KGranularity);
                            }

                            break;
                        }
                        case UT_KHR_B:
                        {
                            if (p2->BType == type)
                            {
                                ok = true;
                                if (m_data.testType == TT_TRANSPOSE_ACC_TO_B)
                                {
                                    MGranularity = std::max(MGranularity, p2->NGranularity);
                                    NGranularity = std::max(NGranularity, p2->KGranularity);
                                }
                                else
                                {
                                    MGranularity = std::max(MGranularity, p2->KGranularity);
                                    NGranularity = std::max(NGranularity, p2->NGranularity);
                                }
                            }

                            break;
                        }
                        case UT_KHR_Result:
                        {
                            if (p2->ResultType == type)
                            {
                                ok           = true;
                                MGranularity = std::max(MGranularity, p2->MGranularity);
                                NGranularity = std::max(NGranularity, p2->NGranularity);
                            }

                            break;
                        }
                        default:
                            TCU_THROW(InternalError, "Unsupported use type");
                        }
                    };

                    SetGranularity(p, types[0], uses[0]);

                    if (!ok)
                    {
                        continue;
                    }

                    // Need to find a "matching" property for the other use/type
                    // and take the max of the granularities
                    for (auto const &prop2 : flexibleProperties)
                    {
                        auto const *p2 = &prop2;

                        if (p2->scope != m_data.scope || p2->workgroupInvocations != p->workgroupInvocations)
                            continue;

                        SetGranularity(p2, types[1], uses[1]);

                        if (ok)
                        {
                            break;
                        }
                    }
                }
                if (ok)
                {
                    DE_ASSERT(MGranularity && NGranularity && (!isMMA || KGranularity));

                    sizes.emplace_back(1U * MGranularity, 1U * NGranularity, 1U * KGranularity,
                                       p->workgroupInvocations);
                    if (m_data.storageClass != SC_WORKGROUP && m_data.storageClass != SC_WORKGROUP_VARIABLE_POINTERS)
                    {
                        sizes.emplace_back(3U * MGranularity, 1U * NGranularity, 1U * KGranularity,
                                           p->workgroupInvocations);
                        sizes.emplace_back(1U * MGranularity, 3U * NGranularity, 1U * KGranularity,
                                           p->workgroupInvocations);
                        if (isMatrixMulAddOp(m_data.testType))
                        {
                            sizes.emplace_back(2U * MGranularity, 2U * NGranularity, 3U * KGranularity,
                                               p->workgroupInvocations);
                            sizes.emplace_back(1U * MGranularity, 1U * NGranularity, 3U * KGranularity,
                                               p->workgroupInvocations);
                        }
                    }
                }
            }

            for (auto &s : sizes)
            {
                if (shmemOK(s.M, s.N, s.K))
                {
                    testSizes.insert(s);
                }
            }
        }
    }
    if (!isTensorLayoutTest(m_data.testType))
    {
        if (isMatrixMulAddOp(m_data.testType))
        {
            for (size_t i = 0; i < properties.size(); ++i)
            {
                VkCooperativeMatrixPropertiesKHR *p = &properties[i];

                if (p->AType == m_data.inputType && p->BType == m_data.inputType && p->CType == m_data.outputType &&
                    p->ResultType == m_data.outputType && p->scope == m_data.scope)
                {
                    testSizes.insert(TestTuple(p->MSize, p->NSize, p->KSize, 0));
                }
            }
        }
        else
        {
            set<TestTuple> typeSizes[2];
            VkComponentTypeKHR types[2] = {m_data.inputType, m_data.outputType};
            UseType uses[2]             = {m_data.useType, m_data.useType};
            if (m_data.testType == TT_CONVERT_ACC_TO_A)
            {
                uses[1] = UT_KHR_A;
            }
            else if (m_data.testType == TT_CONVERT_ACC_TO_B || m_data.testType == TT_TRANSPOSE_ACC_TO_B)
            {
                uses[1] = UT_KHR_B;
            }

            for (uint32_t i = 0; i < properties.size(); ++i)
            {
                VkCooperativeMatrixPropertiesKHR *p = &properties[i];

                if (p->scope != m_data.scope)
                    continue;

                for (uint32_t j = 0; j < 2; ++j)
                {
                    // For these tests, m_data.M/N are always the matrix size. Check if they match
                    // any input or output in the list.
                    if ((uses[j] == UT_KHR_A || uses[j] == UT_NV) && p->AType == types[j])
                        typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0, 0));
                    if ((uses[j] == UT_KHR_B || uses[j] == UT_NV) && p->BType == types[j])
                    {
                        if (m_data.testType == TT_TRANSPOSE_ACC_TO_B)
                        {
                            typeSizes[j].insert(TestTuple(p->NSize, p->KSize, 0, 0));
                        }
                        else
                        {
                            typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0, 0));
                        }
                    }
                    if ((uses[j] == UT_KHR_Result || uses[j] == UT_NV) &&
                        (p->CType == types[j] || p->ResultType == types[j]))
                        typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0, 0));
                }
            }
            // Test those sizes that are supported for both the input and output type.
            std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(), typeSizes[1].begin(), typeSizes[1].end(),
                                  std::inserter(testSizes, testSizes.begin()));
        }
    }

    properties.resize(0);

    for (auto &testSize : testSizes)
    {
        // When testing a multiply, MxNxK is the type of matrix multiply.
        // Otherwise, MxN is the size of the input/output matrices
        uint32_t M, N, K;
        M = testSize.M;
        N = testSize.N;
        K = testSize.K;

        log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K
            << ", WG = " << testSize.workgroupSize << tcu::TestLog::EndMessage;

        struct
        {
            uint32_t rows, cols;
        } dims[4];

        if (isMatrixMulAddOp(m_data.testType))
        {
            dims[0].rows = M;
            dims[0].cols = K;
            dims[1].rows = K;
            dims[1].cols = N;
            dims[2].rows = M;
            dims[2].cols = N;
            dims[3].rows = M;
            dims[3].cols = N;
        }
        else
        {
            if (isReduce2x2(m_data.testType))
            {
                dims[0].rows = M * 2;
                dims[0].cols = N * 2;
            }
            else
            {
                dims[0].rows = M;
                dims[0].cols = N;
            }
            dims[1].rows = M;
            dims[1].cols = N;
            dims[2].rows = M;
            dims[2].cols = N;
            if (isReduceChangeDim(m_data.testType))
            {
                dims[3].rows = M * reduceMScale(m_data.testType);
                dims[3].cols = N * reduceNScale(m_data.testType);
            }
            else if (m_data.testType == TT_TRANSPOSE_ACC_TO_B)
            {
                dims[2].rows = N;
                dims[2].cols = M;
                dims[3].rows = N;
                dims[3].cols = M;
            }
            else
            {
                dims[3].rows = M;
                dims[3].cols = N;
            }
        }

        VkComponentTypeKHR dataTypes[4];
        size_t elementSize[4];
        VkDeviceSize bufferSizes[5];
        de::MovePtr<BufferWithMemory> buffers[5];
        vk::VkDescriptorBufferInfo bufferDescriptors[5];
        uint32_t strides[4]; // in elements
        uint32_t loadStrides[4];
        uint32_t totalElements[4];
        size_t sharedMemoryUsage[4];
        size_t totalSharedMemoryUsage = 0;

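        // Create the five buffers: A, B, C and the D output, plus a small
        // fifth buffer holding the four buffer device addresses for the
        // physical-storage-buffer variants. Cached host-visible memory is
        // preferred, with a fallback to plain host-visible.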
        for (uint32_t i = 0; i < 5; ++i)
        {
            if (i < 4)
            {
                // A/B use input type, C/D use output type
                dataTypes[i]   = (i < 2) ? m_data.inputType : m_data.outputType;
                elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;

                strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.workgroupsX;
                if (m_data.scope != VK_SCOPE_WORKGROUP_KHR)
                {
                    strides[i] *= m_data.subgroupsPerWorkgroupX;
                }
                loadStrides[i]   = strides[i];
                totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) * m_data.workgroupsY;
                sharedMemoryUsage[i] = dims[i].cols * dims[i].rows * m_data.subgroupsPerWorkgroupX *
                                       m_data.subgroupsPerWorkgroupY * elementSize[i] *
                                       ((i < 2) ? m_data.inputComponentCount : m_data.outputComponentCount);

                // Check there is enough shared memory supported
                if ((m_data.useType != UT_NV) &&
                    ((m_data.storageClass == SC_WORKGROUP) || (m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)))
                {
                    totalSharedMemoryUsage += sharedMemoryUsage[i];
                    if (totalSharedMemoryUsage > vkproperties.limits.maxComputeSharedMemorySize)
                        throw tcu::NotSupportedError("Not enough shared memory supported.");
                }

                if (m_data.testType == TT_MATRIXMULADD_DEQUANT && i < 2)
                {
                    // logical type is fp16, but encoded as 4bpp so takes 1/4 the storage
                    DE_ASSERT(m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR);
                    totalElements[i] /= 4;
                }

                if (m_data.scope != VK_SCOPE_WORKGROUP_KHR)
                {
                    totalElements[i] *= m_data.subgroupsPerWorkgroupY;
                }

                if (isTensorLayoutTest(m_data.testType))
                {
                    // sized for 128x128 matrix, scaled up by 4 workgroups in x and y
                    totalElements[i] = 512 * 512;
                }

                bufferSizes[i] = totalElements[i] * elementSize[i];
            }
            else
            {
                bufferSizes[4] = sizeof(VkDeviceAddress) * 4;
            }

            try
            {
                buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
                    vk, device, allocator,
                    makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
                                                             VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                                                             VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
                                                             (memoryDeviceAddress == MemoryRequirement::DeviceAddress ?
                                                                  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT :
                                                                  0)),
                    MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent |
                        memoryDeviceAddress));
            }
            catch (const tcu::NotSupportedError &)
            {
                buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
                    vk, device, allocator,
                    makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
                                                             VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                                                             VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
                                                             (memoryDeviceAddress == MemoryRequirement::DeviceAddress ?
                                                                  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT :
                                                                  0)),
                    MemoryRequirement::HostVisible | memoryDeviceAddress));
            }

            bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
        }

        // Load with a stride of 0
        if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
            loadStrides[0] = loadStrides[1] = loadStrides[2] = loadStrides[3] = 0;

        void *ptrs[5];
        for (uint32_t i = 0; i < 5; ++i)
        {
            ptrs[i] = buffers[i]->getAllocation().getHostPtr();
        }

        vk::DescriptorSetLayoutBuilder layoutBuilder;

        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);

        vk::Unique<vk::VkDescriptorSetLayout> descriptorSetLayout(layoutBuilder.build(vk, device));

        vk::Unique<vk::VkDescriptorPool> descriptorPool(
            vk::DescriptorPoolBuilder()
                .addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
                .build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
        vk::Unique<vk::VkDescriptorSet> descriptorSet(
            makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));

        vk::DescriptorSetUpdateBuilder setUpdateBuilder;
        if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
        {
            VkBufferDeviceAddressInfo info{
                VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType  sType;
                nullptr,                                      // const void*  pNext;
                VK_NULL_HANDLE,                               // VkBuffer            buffer
            };
            VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
            for (uint32_t i = 0; i < 4; ++i)
            {
                info.buffer          = **buffers[i];
                VkDeviceAddress addr = vk.getBufferDeviceAddress(device, &info);
                addrsInMemory[i]     = addr;
            }
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
        }
        else
        {
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
        }

        setUpdateBuilder.update(vk, device);

        VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;

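        // Specialization constants consumed by the shader: workgroup X/Y
        // dimensions, the four matrix strides, and M/N/K (constant IDs 0-8,
        // in that order).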
        const uint32_t specData[9] = {
            (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ? testSize.workgroupSize :
                                                       (subgroupSize * m_data.subgroupsPerWorkgroupX),
            (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ? 1 : m_data.subgroupsPerWorkgroupY,
            strides[0],
            strides[1],
            strides[2],
            strides[3],
            M,
            N,
            K,
        };

        const vk::VkSpecializationMapEntry entries[9] = {
            {0, (uint32_t)(sizeof(uint32_t) * 0), sizeof(uint32_t)},
            {1, (uint32_t)(sizeof(uint32_t) * 1), sizeof(uint32_t)},
            {2, (uint32_t)(sizeof(uint32_t) * 2), sizeof(uint32_t)},
            {3, (uint32_t)(sizeof(uint32_t) * 3), sizeof(uint32_t)},
            {4, (uint32_t)(sizeof(uint32_t) * 4), sizeof(uint32_t)},
            {5, (uint32_t)(sizeof(uint32_t) * 5), sizeof(uint32_t)},
            {6, (uint32_t)(sizeof(uint32_t) * 6), sizeof(uint32_t)},
            {7, (uint32_t)(sizeof(uint32_t) * 7), sizeof(uint32_t)},
            {8, (uint32_t)(sizeof(uint32_t) * 8), sizeof(uint32_t)},
        };

        const vk::VkSpecializationInfo specInfo = {
            9,                // mapEntryCount
            entries,          // pMapEntries
            sizeof(specData), // dataSize
            specData          // pData
        };

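        // Fill the source buffers. Matrix-multiply float inputs use small
        // half-unit values so reference products stay exact; the dequant
        // test stores raw 16-bit patterns; the wrapping test uses full-range
        // integers; and output buffers for tensor-layout/clamp tests are
        // pre-filled with the sentinel value 123.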
        for (uint32_t i = 0; i < 4; ++i)
            for (uint32_t j = 0; j < totalElements[i]; ++j)
            {
                if (isFloatType(dataTypes[i]))
                {
                    if ((isTensorLayoutTest(m_data.testType) || isClampTest(m_data.testType)) && i == 3)
                    {
                        setDataFloat(ptrs[i], dataTypes[i], j, 123.0);
                    }
                    else if (!isMatrixMulAddOp(m_data.testType) && !isReduceSum(m_data.testType))
                        setDataFloat(ptrs[i], dataTypes[i], j,
                                     ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f) / 2.0f);
                    else if (m_data.testType == TT_MATRIXMULADD_DEQUANT && i < 2)
                    {
                        // Each "element" still accounts for 16bpp, but it's stored quantized
                        // so we just want a random 16b pattern.
                        uint32_t value = (deRandom_getUint32(&rnd) & 0xffff);
                        setDataInt(ptrs[i], VK_COMPONENT_TYPE_UINT16_KHR, j, value);
                    }
                    else
                        setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f) / 2.0f);
                }
                else
                {
                    if (m_data.testType == TT_MATRIXMULADD_WRAPPING)
                    {
                        // Choose matrix values that should cause overflow and underflow, to
                        // verify wrapping behavior. Use the full range of values for A and B.
                        // For matrix C, use values clustered near where the type wraps (zero
                        // for unsigned, 2^(N-1) for signed).
                        uint32_t bits = componentTypeInfo[dataTypes[i]].bits;
                        uint32_t value;
                        if (i == 2)
                        {
                            value = (deRandom_getUint32(&rnd) & 0xff) - 128;
                            if (componentTypeInfo[dataTypes[i]].isSigned)
                                value += (1U << (bits - 1));
                        }
                        else
                        {
                            uint32_t mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
                            value         = deRandom_getUint32(&rnd) & mask;
                        }
                        setDataInt(ptrs[i], dataTypes[i], j, value);
                    }
                    else if (m_data.testType == TT_MATRIXMULADD_SATURATED)
                    {
                        setDataInt(ptrs[i], dataTypes[i], j, 0);
                    }
                    else if ((isTensorLayoutTest(m_data.testType) || isClampTest(m_data.testType)) && i == 3)
                    {
                        setDataInt(ptrs[i], dataTypes[i], j, 123);
                    }
                    else
                    {
                        uint32_t value = (deRandom_getUint32(&rnd) & 0xff) - 128;
                        setDataInt(ptrs[i], dataTypes[i], j, value);
                    }
                }
            }

        if (m_data.testType == TT_MATRIXMULADD_SATURATED)
        {
            // Set 1st row of A to 1,0,0...
            setSingleElementInt(ptrs[0], dataTypes[0], 0, dims[0].cols, (m_data.colMajor ? strides[0] : 1), 0, 1);

            // Set 1st column of B to 1,0,0...
            setSingleElementInt(ptrs[1], dataTypes[1], 0, dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 0, 1);

            // Set C element {0,0} to the maximum value of its type, so the add in D=A*B+C overflows for this element
            setDataInt(ptrs[2], dataTypes[2], 0, getLimit(dataTypes[2], true));

            // Check underflow if all involved elements support negative values
            if (isSIntType(dataTypes[1]) && isSIntType(dataTypes[2]) && isSIntType(dataTypes[3]))
            {
                // Set 2nd row of A to 0,1,0,0...
                setSingleElementInt(ptrs[0], dataTypes[0], (m_data.colMajor ? 1 : strides[0]), dims[0].cols,
                                    (m_data.colMajor ? strides[0] : 1), 1, 1);

                // Set 2nd column of B to 0,-1,0,0...
                setSingleElementInt(ptrs[1], dataTypes[1], (m_data.colMajor ? strides[1] : 1), dims[1].rows,
                                    (m_data.colMajor ? 1 : strides[1]), 1, -1);

                // Set C element {1,1} to the minimum value of its type, so the add in D=A*B+C underflows for this element
                setDataInt(ptrs[2], dataTypes[2], strides[2] + 1, getLimit(dataTypes[2], false));
            }
        }

        flushAlloc(vk, device, buffers[0]->getAllocation());
        flushAlloc(vk, device, buffers[1]->getAllocation());
        flushAlloc(vk, device, buffers[2]->getAllocation());
        flushAlloc(vk, device, buffers[3]->getAllocation());

        ComputePipelineWrapper pipeline(vk, device, m_data.computePipelineConstructionType,
                                        m_context.getBinaryCollection().get("test"));
        pipeline.setDescriptorSetLayout(descriptorSetLayout.get());
        pipeline.setSpecializationInfo(specInfo);
        pipeline.setSubgroupSize(m_data.subgroupSizeMode == SUBGROUP_SIZE_NONE ?
                                     0 :
                                     getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode));
        pipeline.buildPipeline();

        const VkQueue queue             = m_context.getUniversalQueue();
        Move<VkCommandPool> cmdPool     = createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
        Move<VkCommandBuffer> cmdBuffer = allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);

        beginCommandBuffer(vk, *cmdBuffer, 0u);

        vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, pipeline.getPipelineLayout(), 0u, 1, &*descriptorSet, 0u,
                                 nullptr);
        pipeline.bind(*cmdBuffer);

        // The tensor-layout tests request a larger workgroup count only to
        // allocate more memory; they still dispatch a single workgroup.
        uint32_t workgroupsX = m_data.workgroupsX;
        uint32_t workgroupsY = m_data.workgroupsY;
        if (isTensorLayoutTest(m_data.testType))
        {
            workgroupsX = 1u;
            workgroupsY = 1u;
        }

        vk.cmdDispatch(*cmdBuffer, workgroupsX, workgroupsY, 1);

        const VkMemoryBarrier barrier = {
            VK_STRUCTURE_TYPE_MEMORY_BARRIER, // sType
            nullptr,                          // pNext
            VK_ACCESS_SHADER_WRITE_BIT,       // srcAccessMask
            VK_ACCESS_HOST_READ_BIT,          // dstAccessMask
        };
        vk.cmdPipelineBarrier(*cmdBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
                              (VkDependencyFlags)0, 1, &barrier, 0, nullptr, 0, nullptr);

        endCommandBuffer(vk, *cmdBuffer);

        submitCommandsAndWait(vk, device, queue, cmdBuffer.get());

        invalidateAlloc(vk, device, buffers[3]->getAllocation());

        qpTestResult res = QP_TEST_RESULT_PASS;

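        // Verify the results. TT_CONVERT expects exact agreement with the
        // host-side conversion; the float paths below compare exactly where
        // values are exactly representable and use a small relative
        // tolerance where accumulation order may differ.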
        if (m_data.testType == TT_CONVERT)
        {
            for (uint32_t i = 0; i < totalElements[3]; ++i)
            {
                // Store results as double, which has enough range to hold all the other types exactly.
                double inputA, output;

                // This loads the data according to dataTypes[0], and then converts to the template parameter type
                switch (dataTypes[3])
                {
                case VK_COMPONENT_TYPE_UINT8_KHR:
                    inputA = getDataConvertedToT<uint8_t>(ptrs[0], dataTypes[0], i);
                    break;
                case VK_COMPONENT_TYPE_UINT16_KHR:
                    inputA = getDataConvertedToT<uint16_t>(ptrs[0], dataTypes[0], i);
                    break;
                case VK_COMPONENT_TYPE_UINT32_KHR:
                    inputA = getDataConvertedToT<uint32_t>(ptrs[0], dataTypes[0], i);
                    break;
                case VK_COMPONENT_TYPE_SINT8_KHR:
                    inputA = getDataConvertedToT<int8_t>(ptrs[0], dataTypes[0], i);
                    break;
                case VK_COMPONENT_TYPE_SINT16_KHR:
                    inputA = getDataConvertedToT<int16_t>(ptrs[0], dataTypes[0], i);
                    break;
                case VK_COMPONENT_TYPE_SINT32_KHR:
                    inputA = getDataConvertedToT<int32_t>(ptrs[0], dataTypes[0], i);
                    break;
                case VK_COMPONENT_TYPE_FLOAT32_KHR:
                    inputA = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
                    break;
                case VK_COMPONENT_TYPE_FLOAT16_KHR:
                {
                    float temp = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
                    inputA     = tcu::Float16(temp).asDouble();
                    break;
                }
                default:
                    TCU_THROW(InternalError, "Unexpected type");
                }

                switch (dataTypes[3])
                {
                case VK_COMPONENT_TYPE_UINT8_KHR:
                    output = getDataConvertedToT<uint8_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_UINT16_KHR:
                    output = getDataConvertedToT<uint16_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_UINT32_KHR:
                    output = getDataConvertedToT<uint32_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_SINT8_KHR:
                    output = getDataConvertedToT<int8_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_SINT16_KHR:
                    output = getDataConvertedToT<int16_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_SINT32_KHR:
                    output = getDataConvertedToT<int32_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_FLOAT32_KHR:
                    output = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_FLOAT16_KHR:
                {
                    float temp = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
                    output     = tcu::Float16(temp).asDouble();
                    break;
                }
                default:
                    TCU_THROW(InternalError, "Unexpected type");
                }

                if (inputA != output)
                {
                    res = QP_TEST_RESULT_FAIL;
                    break;
                }
            }
        }
        else if (isFloatType(dataTypes[0]))
        {
            if (isReduceOp(m_data.testType))
            {
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        auto const getA = [&](uint32_t i, uint32_t j) -> float
                        {
                            uint32_t ij;
                            if (m_data.colMajor)
                                ij = mX * dims[0].rows + i + strides[0] * mY * dims[0].cols + loadStrides[0] * j;
                            else
                                ij = mX * dims[0].cols + j + strides[0] * mY * dims[0].rows + loadStrides[0] * i;

                            float Aij = getDataFloat(ptrs[0], dataTypes[0], ij);
                            return Aij;
                        };

                        auto const getD = [&](uint32_t i, uint32_t j) -> float
                        {
                            uint32_t ij;
                            // When loading with stride 0, ij for matrix D is different from matrix C
                            if (m_data.colMajor)
                                ij = mX * dims[3].rows + i + strides[3] * (mY * dims[3].cols + j);
                            else
                                ij = mX * dims[3].cols + j + strides[3] * (mY * dims[3].rows + i);

                            float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);
                            return Dij;
                        };

                        std::function<float(float, float)> Combine;
                        float identity;
                        if (isReduceSum(m_data.testType))
                        {
                            Combine  = [](float a, float b) { return a + b; };
                            identity = 0;
                        }
                        else if (isReduceMin(m_data.testType))
                        {
                            Combine  = [](float a, float b) { return std::min(a, b); };
                            identity = std::numeric_limits<float>::max();
                        }
                        else
                        {
                            Combine  = [](float a, float b) { return std::max(a, b); };
                            identity = -std::numeric_limits<float>::max();
                        }

                        uint32_t outputM = M * reduceMScale(m_data.testType);
                        uint32_t outputN = N * reduceNScale(m_data.testType);
                        if (isReduceRow(m_data.testType))
                        {
                            for (uint32_t i = 0; i < M; ++i)
                            {
                                float ref = identity;
                                for (uint32_t j = 0; j < N; ++j)
                                {
                                    ref = Combine(ref, getA(i, j));
                                }
                                for (uint32_t j = 0; j < outputN; ++j)
                                {
                                    float Dij = getD(i, j);
                                    if (fabs(ref - Dij) / (fabs(ref) + 0.001) > 3.0 / 1024)
                                    {
                                        //printf("mX %d mY %d i %d j %d ref %f Dij %f\n", mX, mY, i, j, ref, Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                    float Di0 = getD(i, 0);
                                    if (Dij != Di0)
                                    {
                                        //printf("mX %d mY %d i %d j %d Di0 %f Dij %f\n", mX, mY, i, j, Di0, Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                        else if (isReduceCol(m_data.testType))
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                float ref = identity;
                                for (uint32_t i = 0; i < M; ++i)
                                {
                                    ref = Combine(ref, getA(i, j));
                                }
                                for (uint32_t i = 0; i < outputM; ++i)
                                {
                                    float Dij = getD(i, j);
                                    if (fabs(ref - Dij) / (fabs(ref) + 0.001) > 3.0 / 1024)
                                    {
                                        //printf("mX %d mY %d i %d j %d ref %f Dij %f\n", mX, mY, i, j, ref, Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                    float D0j = getD(0, j);
                                    if (Dij != D0j)
                                    {
                                        //printf("mX %d mY %d i %d j %d D0j %f Dij %f\n", mX, mY, i, j, D0j, Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                        else if (isReduceRowCol(m_data.testType))
                        {
                            float ref = identity;
                            for (uint32_t i = 0; i < M; ++i)
                            {
                                for (uint32_t j = 0; j < N; ++j)
                                {
                                    ref = Combine(ref, getA(i, j));
                                }
                            }
                            for (uint32_t i = 0; i < outputM; ++i)
                            {
                                for (uint32_t j = 0; j < outputN; ++j)
                                {
                                    float Dij = getD(i, j);
                                    if (fabs(ref - Dij) / (fabs(ref) + 0.001) > 3.0 / 1024)
                                    {
                                        //printf("mX %d mY %d i %d j %d ref %f Dij %f\n", mX, mY, i, j, ref, Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                    float D00 = getD(0, 0);
                                    if (Dij != D00)
                                    {
                                        //printf("mX %d mY %d i %d j %d D00 %f Dij %f\n", mX, mY, i, j, D00, Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                        else if (isReduce2x2(m_data.testType))
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                for (uint32_t i = 0; i < M; ++i)
                                {
                                    float ref = identity;
                                    ref       = Combine(ref, getA(i * 2 + 0, j * 2 + 0));
                                    ref       = Combine(ref, getA(i * 2 + 0, j * 2 + 1));
                                    ref       = Combine(ref, getA(i * 2 + 1, j * 2 + 0));
                                    ref       = Combine(ref, getA(i * 2 + 1, j * 2 + 1));

                                    float Dij = getD(i, j);
                                    if (ref != Dij)
                                    {
                                        //printf("mX %d mY %d i %d j %d ref %f Dij %f\n", mX, mY, i, j, ref, Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                        else
                        {
                            DE_ASSERT(0);
                        }
                    }
                }
            }
            else if (m_data.testType == TT_TRANSPOSE_ACC_TO_B)
            {
                uint32_t ij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                // for row-major, src is MxN, so row,col = i,j
                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[0] * mY * N + loadStrides[0] * j;
                                else
                                    ij = mX * N + j + strides[0] * mY * M + loadStrides[0] * i;

                                float ref = getDataFloat(ptrs[0], dataTypes[0], ij);

                                // for row-major, dst is NxM, so row,col = j,i
                                if (m_data.colMajor)
                                    ij = mX * N + j + strides[3] * (mY * M + i);
                                else
                                    ij = mX * M + i + strides[3] * (mY * N + j);

                                float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

                                if (ref != Dij)
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
            else if (m_data.testType == TT_SPACETODEPTH)
            {
                uint32_t H = 32;
                uint32_t W = 32;
                uint32_t C = 16;
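                // Space-to-depth reference: the input is treated as an HxWxC
                // (NHWC) tensor, and each 2x2 spatial block of C channels
                // maps to 4*C consecutive channels at half the spatial
                // resolution.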
                for (uint32_t h = 0; h < H; ++h)
                {
                    for (uint32_t w = 0; w < W; ++w)
                    {
                        for (uint32_t c = 0; c < C; ++c)
                        {
                            uint32_t inputIndex  = (h * W + w) * C + c;
                            uint32_t outputIndex = ((h / 2) * W / 2 + w / 2) * 4 * C + ((h & 1) * 2 + (w & 1)) * C + c;
                            float ref            = getDataFloat(ptrs[0], dataTypes[0], inputIndex);
                            float output         = getDataFloat(ptrs[3], dataTypes[3], outputIndex);
                            if (ref != output)
                            {
                                //printf("h %d w %d c %d ref %f output %f\n", h, w, c, ref, output);
                                res = QP_TEST_RESULT_FAIL;
                            }
                        }
                    }
                }
            }
            else if (isTensorLayoutTest(m_data.testType))
            {
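                // Tensor-layout reference: walk every coordinate of the
                // dim-dimensional tensor. Coordinates inside one of the
                // store rectangles take the value loaded from the matching
                // load offsets (0 where the load coordinate falls outside
                // the tensor, per the constant clamp mode); everything else
                // keeps the 123 fill value.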
3695                 uint32_t dim = GetDim(m_data.testType);
                for (int32_t i0 = 0; i0 < GetTensorLayoutDim(dim)[0]; ++i0)
                {
                    for (int32_t i1 = 0; i1 < GetTensorLayoutDim(dim)[1]; ++i1)
                    {
                        for (int32_t i2 = 0; i2 < GetTensorLayoutDim(dim)[2]; ++i2)
                        {
                            for (int32_t i3 = 0; i3 < GetTensorLayoutDim(dim)[3]; ++i3)
                            {
                                for (int32_t i4 = 0; i4 < GetTensorLayoutDim(dim)[4]; ++i4)
                                {
                                    int32_t tensorCoord[5] = {i0, i1, i2, i3, i4};
                                    uint32_t index         = 0;
                                    for (uint32_t k = 0; k < dim; ++k)
                                    {
                                        index = index * GetTensorLayoutDim(dim)[k] + tensorCoord[k];
                                    }
                                    float ref    = 123.0f;
                                    float output = getDataFloat(ptrs[3], dataTypes[3], index);
                                    // If the dest coord is in one of the store rectangles, compute
                                    // a different reference value.
                                    for (uint32_t r = 0; r < GetTensorLayoutNumCoords(dim); ++r)
                                    {
                                        bool inStoreRect = true;
                                        for (uint32_t k = 0; k < dim; ++k)
                                        {
                                            if ((int32_t)tensorCoord[k] < GetTensorLayoutStoreOffsets(dim, r)[k] ||
                                                (int32_t)tensorCoord[k] >= GetTensorLayoutStoreOffsets(dim, r)[k] +
                                                                               GetTensorLayoutSpan(dim, r)[k])
                                            {
                                                inStoreRect = false;
                                            }
                                        }

                                        if (inStoreRect)
                                        {
                                            int32_t loadCoord[5] = {i0, i1, i2, i3, i4};
                                            for (uint32_t k = 0; k < dim; ++k)
                                            {
                                                loadCoord[k] = loadCoord[k] - GetTensorLayoutStoreOffsets(dim, r)[k] +
                                                               GetTensorLayoutLoadOffsets(dim, r)[k];
                                            }
                                            bool OOB = false;
                                            // gl_CooperativeMatrixClampModeConstant bounds checking
                                            for (uint32_t k = 0; k < dim; ++k)
                                            {
                                                if (loadCoord[k] < 0 || loadCoord[k] >= GetTensorLayoutDim(dim)[k])
                                                {
                                                    OOB = true;
                                                }
                                            }
                                            if (OOB)
                                            {
                                                ref = 0.0f;
                                            }
                                            else
                                            {
                                                index = 0;
                                                for (uint32_t k = 0; k < dim; ++k)
                                                {
                                                    index = index * GetTensorLayoutDim(dim)[k] + loadCoord[k];
                                                }
                                                ref = getDataFloat(ptrs[0], dataTypes[0], index);
                                            }
                                            break;
                                        }
                                    }
                                    if (ref != output)
                                    {
                                        //printf("tensorCoord {%d, %d, %d, %d, %d} ref %f output %f\n", tensorCoord[0], tensorCoord[1], tensorCoord[2], tensorCoord[3], tensorCoord[4], ref, output);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            else if (m_data.testType == TT_PER_ELEMENT_OP_ROW_COL)
            {
                uint32_t ij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
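                // The per-element row/col op is expected to have added (3 * row + col)
                // to each element, so D[i][j] should equal A[i][j] + (i * 3 + j).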
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[0] * mY * N + loadStrides[0] * j;
                                else
                                    ij = mX * N + j + strides[0] * mY * M + loadStrides[0] * i;

                                float ref = getDataFloat(ptrs[0], dataTypes[0], ij);

                                float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

                                if (ref + (float)(i * 3 + j) != Dij)
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
            else if (isClampTest(m_data.testType))
            {
                uint32_t ij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                uint32_t fullDimX   = numMatrixX * (m_data.colMajor ? dims[0].rows : dims[0].cols);
                uint32_t fullDimY   = numMatrixY * (m_data.colMajor ? dims[0].cols : dims[0].rows);
                uint32_t dimX       = fullDimX - 6;
                uint32_t dimY       = fullDimY - 6;
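                // The clamp tests load from a (fullDim - 6) source with a -3 offset, so
                // a three-element border around the source is out of bounds and must be
                // remapped (or replaced) according to the clamp mode. The last row and
                // column of the destination are never stored (OOBStore), so they must
                // still hold the 123 fill value.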
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                int32_t i2;
                                int32_t j2;
                                bool OOBLoad  = false;
                                bool OOBStore = false;

                                if (m_data.colMajor)
                                {
                                    i2       = mX * M + i;
                                    j2       = mY * N + j;
                                    ij       = i2 + strides[3] * j2;
                                    OOBStore = i2 == (int32_t)fullDimX - 1 || j2 == (int32_t)fullDimY - 1;
                                }
                                else
                                {
                                    i2       = mY * M + i;
                                    j2       = mX * N + j;
                                    ij       = j2 + strides[3] * i2;
                                    OOBStore = i2 == (int32_t)fullDimY - 1 || j2 == (int32_t)fullDimX - 1;
                                }

                                float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

                                auto const mod = [](int32_t n, int32_t d) -> int32_t
                                {
                                    // wraps negative n into [0, d); only valid for n >= -d,
                                    // which holds for the small offsets used here
                                    return (n + d) % d;
                                };

                                i2 -= 3;
                                j2 -= 3;
                                uint32_t dimI = m_data.colMajor ? dimX : dimY;
                                uint32_t dimJ = m_data.colMajor ? dimY : dimX;
                                switch (m_data.testType)
                                {
                                case TT_CLAMPCONSTANT:
                                    OOBLoad = i2 < 0 || j2 < 0 || i2 >= (int32_t)dimI || j2 >= (int32_t)dimJ;
                                    break;
                                case TT_CLAMPTOEDGE:
                                    i2 = std::min(std::max(i2, 0), (int32_t)dimI - 1);
                                    j2 = std::min(std::max(j2, 0), (int32_t)dimJ - 1);
                                    break;
                                case TT_CLAMPREPEAT:
                                    i2 = mod(i2, dimI);
                                    j2 = mod(j2, dimJ);
                                    break;
                                case TT_CLAMPMIRRORREPEAT:
                                    i2 = mod(i2, (2 * dimI - 2));
                                    i2 = (i2 >= (int32_t)dimI) ? (2 * dimI - 2 - i2) : i2;
                                    j2 = mod(j2, (2 * dimJ - 2));
                                    j2 = (j2 >= (int32_t)dimJ) ? (2 * dimJ - 2 - j2) : j2;
                                    break;
                                default:
                                    DE_ASSERT(0);
                                    break;
                                }

                                if (m_data.colMajor)
                                {
                                    ij = i2 + strides[0] * j2;
                                }
                                else
                                {
                                    ij = j2 + strides[0] * i2;
                                }

                                float ref = OOBStore ? 123.0f :
                                            OOBLoad  ? 0.5f :
                                                       getDataFloat(ptrs[0], dataTypes[0], ij);

                                if (ref != Dij)
                                {
                                    //printf("fail ");
                                    res = QP_TEST_RESULT_FAIL;
                                }
                                //printf("i %d j %d ref %f Dij %f\n", i, j, ref, Dij);
                            }
                        }
                    }
                }
            }
            else if ((m_data.addrMethod == ADDR_BLOCKSIZE || m_data.addrMethod == ADDR_DECODE) &&
                     m_data.testType != TT_MATRIXMULADD_DEQUANT)
            {
                uint32_t ij, blockij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
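                // With block addressing, each (row, col) first maps to a block coordinate
                // by dividing by blockSize. ADDR_BLOCKSIZE reads one value per block;
                // ADDR_DECODE additionally indexes the element within the block and adds
                // a per-block term, ((2 * b0 + b1) & 3), that the decode callback is
                // expected to have applied.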
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                uint32_t blockCoords[2];
                                if (m_data.colMajor)
                                {
                                    blockCoords[0] = (mY * N + j) / blockSize[0];
                                    blockCoords[1] = (mX * M + i) / blockSize[1];
                                    blockij        = blockCoords[1] + (strides[0] / blockSize[1]) * blockCoords[0];
                                    if (m_data.addrMethod == ADDR_DECODE)
                                    {
                                        blockij *= blockSize[0] * blockSize[1];
                                        blockij += (j % blockSize[0]) * blockSize[1] + (i % blockSize[1]);
                                    }
                                    ij = mX * M + i + strides[0] * mY * N + loadStrides[0] * j;
                                }
                                else
                                {
                                    blockCoords[0] = (mY * M + i) / blockSize[0];
                                    blockCoords[1] = (mX * N + j) / blockSize[1];
                                    blockij        = blockCoords[1] + (strides[0] / blockSize[1]) * blockCoords[0];
                                    if (m_data.addrMethod == ADDR_DECODE)
                                    {
                                        blockij *= blockSize[0] * blockSize[1];
                                        blockij += (i % blockSize[0]) * blockSize[1] + (j % blockSize[1]);
                                    }
                                    ij = mX * N + j + strides[0] * mY * M + loadStrides[0] * i;
                                }

                                float ref = getDataFloat(ptrs[0], dataTypes[0], blockij);

                                if (m_data.addrMethod == ADDR_DECODE)
                                {
                                    ref += (float)((2 * blockCoords[0] + blockCoords[1]) & 3);
                                }

                                float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

                                if (m_data.testType == TT_NEGATE)
                                {
                                    ref = -ref;
                                }
                                else
                                {
                                    DE_ASSERT(0);
                                }

                                if (ref != Dij)
                                {
                                    //printf("fail ");
                                    res = QP_TEST_RESULT_FAIL;
                                }
                                //printf("mX %d mY %d i %d j %d ref %f D %f\n", mX, mY, i, j, ref, Dij);
                            }
                        }
                    }
                }
            }
            else if (!isMatrixMulAddOp(m_data.testType))
            {
                for (uint32_t i = 0; i < totalElements[3]; ++i)
                {
                    float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
                    float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
                    float output = getDataFloat(ptrs[3], dataTypes[3], i);
                    switch (m_data.testType)
                    {
                    case TT_LENGTH:
                        if (output < 1.0f || output > (float)(N * M))
                            res = QP_TEST_RESULT_FAIL;
                        if (m_data.scope == VK_SCOPE_SUBGROUP_KHR)
                        {
                            // We expect the matrix to be spread evenly across invocations;
                            // it is surprising (but not necessarily illegal) if it is not.
                            if (output != (float)(N * M / subgroupSize) && res == QP_TEST_RESULT_PASS)
                            {
                                res = QP_TEST_RESULT_QUALITY_WARNING;
                            }
                        }
                        break;
                    case TT_CONSTANT:
                        if (output != 1.0f)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_COMPOSITE:
                    case TT_COMPOSITE_RVALUE:
                    case TT_COMPOSITE_ARRAY:
                    case TT_ADD:
                        if (output != inputA + inputB)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_SUB:
                        if (output != inputA - inputB)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_DIV:
                    {
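                        // 1 ULP at 1.0 is 2^-10 for fp16 (10 fraction bits) and 2^-23
                        // for fp32 (23 fraction bits).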
                        float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR) ?
                                        1.0f / 1024.0f :
                                        1.0f / (8.0f * 1024.0f * 1024.0f);
                        // division allows 2.5 ULP, but we'll use 3.
                        ulp *= 3;
                        if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
                            res = QP_TEST_RESULT_FAIL;
                    }
                    break;
                    case TT_MUL:
                    {
                        if (dataTypes[0] == VK_COMPONENT_TYPE_FLOAT16_KHR)
                        {
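                            // An fp16 multiply is expected to round to fp16 precision, so
                            // compute the reference by rounding the fp32 product through
                            // Float16 and back.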
                            const float expected32          = inputA * inputB;
                            const tcu::float16_t expected16 = tcu::Float16(expected32).bits();
                            const float expected            = tcu::Float16(expected16).asFloat();

                            if (output != expected)
                                res = QP_TEST_RESULT_FAIL;
                        }
                        else
                        {
                            if (output != inputA * inputB)
                                res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    }
                    case TT_NEGATE:
                    case TT_FUNC:
                        if (output != -inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_MATRIXTIMESSCALAR:
                        if (output != 6.0 * inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_MULTICOMPONENT_LOAD:
                    {
                        if (output != inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    }
                    case TT_MULTICOMPONENT_SAVE:
                    case TT_CONVERT_ACC_TO_A:
                    case TT_CONVERT_ACC_TO_B:
                    {
                        if (output != inputA)
                        {
                            //printf("i %d inputA %f output %f\n", i, inputA, output);
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    }
                    case TT_PER_ELEMENT_OP:
                    case TT_PER_ELEMENT_OP_STRUCT:
                        if (output != inputA + 2.0)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_PER_ELEMENT_OP_MAT:
                        if (output != 3 * inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    default:
                        TCU_THROW(InternalError, "Unimplemented");
                    }
                }
            }
            else
            {
                uint32_t ik, kj, ij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
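                // Matrix-multiply-add reference: accumulate ref = sum_k A[i][k] * B[k][j]
                // in fp32, add C[i][j], and compare against D within epsilon. For larger
                // dimensions (>= 48) a relative tolerance of 3/1024 is used instead,
                // since implementations may accumulate in a different order or precision.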
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                float ref = 0;
                                for (uint32_t k = 0; k < K; ++k)
                                {
                                    float Aik, Bkj;
                                    if (m_data.testType == TT_MATRIXMULADD_DEQUANT)
                                    {
                                        uint32_t idxInBlock, idx, arrayidx, shift;
                                        int32_t value;

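                                        // Elements are 4 bits each, two per byte: element
                                        // idx lives in byte idx / 2, in the low nibble
                                        // when idx is even and the high nibble when odd.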
                                        // Blocks are stored in row-major order. Compute index of the block
                                        // and index within block.
                                        DE_ASSERT(!m_data.colMajor);
                                        uint32_t blockik = ((mX * K + k) / blockSize[1]) +
                                                           (strides[0] / blockSize[1]) * ((mY * M + i) / blockSize[0]);

                                        idxInBlock = (i % blockSize[0]) * blockSize[1] + (k % blockSize[1]);

                                        // Compute block index (idx) and extract a 4bpp element from the block
                                        idx      = blockik * blockSize[0] * blockSize[1] + idxInBlock;
                                        arrayidx = idx / 2;
                                        shift    = (idx & 1) * 4;
                                        value    = getDataInt(ptrs[0], VK_COMPONENT_TYPE_UINT8_KHR, arrayidx);
                                        value    = (value >> shift) & 0xF;
                                        // decode the 4-bit value as 0.5 * (value - 4)
                                        Aik = 0.5f * (float)(value - 4);

                                        // Repeat for B matrix
                                        uint32_t blockkj = ((mX * N + j) / blockSize[1]) +
                                                           (strides[1] / blockSize[1]) * ((mY * K + k) / blockSize[0]);

                                        idxInBlock = (k % blockSize[0]) * blockSize[1] + (j % blockSize[1]);

                                        idx      = blockkj * blockSize[0] * blockSize[1] + idxInBlock;
                                        arrayidx = idx / 2;
                                        shift    = (idx & 1) * 4;
                                        value    = getDataInt(ptrs[1], VK_COMPONENT_TYPE_UINT8_KHR, arrayidx);
                                        value    = (value >> shift) & 0xF;
                                        Bkj      = 0.5f * (float)(value - 4);
                                    }
                                    else
                                    {
                                        if (m_data.colMajor)
                                            ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
                                        else
                                            ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;

                                        Aik = getDataFloat(ptrs[0], dataTypes[0], ik);

                                        if (m_data.colMajor)
                                            kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
                                        else
                                            kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;

                                        Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);
                                    }

                                    ref += Aik * Bkj;
                                }

                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
                                else
                                    ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;

                                float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);

                                ref += Cij;

                                // When loading with stride 0, ij for matrix D is different from matrix C
                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * (mY * N + j);
                                else
                                    ij = mX * N + j + strides[2] * (mY * M + i);

                                float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

                                //printf("i %d j %d ref %f Dij %f\n", i, j, ref, Dij);

                                if (fabs(ref - Dij) > epsilon)
                                {
                                    if (max(max(M, N), K) >= 48)
                                    {
                                        if (fabs(ref - Dij) / (fabs(ref) + 0.001) > 3.0 / 1024)
                                        {
                                            //printf("ref %f Dij %f\n", ref, Dij);
                                            res = QP_TEST_RESULT_FAIL;
                                        }
                                    }
                                    else
                                    {
                                        //printf("i %d j %d ref %f Dij %f\n", i, j, ref, Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        else
        {
            if (isReduceOp(m_data.testType))
            {
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                int resultSize      = componentTypeInfo[dataTypes[3]].bits;
                uint32_t mask       = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
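                // Reductions are computed in int64 and compared under a mask that
                // truncates to the result component width, so narrower integer results
                // are allowed to wrap. The min/max identities use +/- INT64_MAX, which is
                // sufficient for the value ranges these tests generate.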
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        bool isSigned   = componentTypeInfo[dataTypes[0]].isSigned;
                        auto const getA = [&](uint32_t i, uint32_t j) -> int64_t
                        {
                            uint32_t ij;
                            if (m_data.colMajor)
                                ij = mX * dims[0].rows + i + strides[0] * mY * dims[0].cols + loadStrides[0] * j;
                            else
                                ij = mX * dims[0].cols + j + strides[0] * mY * dims[0].rows + loadStrides[0] * i;

                            uint32_t Aij = getDataInt(ptrs[0], dataTypes[0], ij);
                            if (isSigned)
                            {
                                return (int64_t)(int32_t)Aij;
                            }
                            else
                            {
                                return (int64_t)Aij;
                            }
                        };

                        auto const getD = [&](uint32_t i, uint32_t j) -> int64_t
                        {
                            uint32_t ij;
                            // When loading with stride 0, ij for matrix D is different from matrix C
                            if (m_data.colMajor)
                                ij = mX * dims[3].rows + i + strides[3] * (mY * dims[3].cols + j);
                            else
                                ij = mX * dims[3].cols + j + strides[3] * (mY * dims[3].rows + i);

                            uint32_t Dij = getDataInt(ptrs[3], dataTypes[3], ij);
                            if (isSigned)
                            {
                                return (int64_t)(int32_t)Dij;
                            }
                            else
                            {
                                return (int64_t)Dij;
                            }
                        };

                        std::function<int64_t(int64_t, int64_t)> Combine;
                        int64_t identity;
                        if (isReduceSum(m_data.testType))
                        {
                            Combine  = [](int64_t a, int64_t b) { return a + b; };
                            identity = 0;
                        }
                        else if (isReduceMin(m_data.testType))
                        {
                            Combine  = [](int64_t a, int64_t b) { return std::min(a, b); };
                            identity = std::numeric_limits<int64_t>::max();
                        }
                        else
                        {
                            Combine  = [](int64_t a, int64_t b) { return std::max(a, b); };
                            identity = -std::numeric_limits<int64_t>::max();
                        }

                        uint32_t outputM = M * reduceMScale(m_data.testType);
                        uint32_t outputN = N * reduceNScale(m_data.testType);
                        if (isReduceRow(m_data.testType))
                        {
                            for (uint32_t i = 0; i < M; ++i)
                            {
                                int64_t ref = identity;
                                for (uint32_t j = 0; j < N; ++j)
                                {
                                    ref = Combine(ref, getA(i, j));
                                }
                                for (uint32_t j = 0; j < outputN; ++j)
                                {
                                    int64_t Dij = getD(i, j);
                                    if ((ref & mask) != (Dij & mask))
                                    {
                                        //printf("mX %d mY %d i %d j %d ref %d Dij %d\n", mX, mY, i, j, (int)ref, (int)Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                    int64_t Di0 = getD(i, 0);
                                    if (Dij != Di0)
                                    {
                                        //printf("mX %d mY %d i %d j %d Di0 %d Dij %d\n", mX, mY, i, j, (int)Di0, (int)Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                        else if (isReduceCol(m_data.testType))
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                int64_t ref = identity;
                                for (uint32_t i = 0; i < M; ++i)
                                {
                                    ref = Combine(ref, getA(i, j));
                                }
                                for (uint32_t i = 0; i < outputM; ++i)
                                {
                                    int64_t Dij = getD(i, j);
                                    if ((ref & mask) != (Dij & mask))
                                    {
                                        //printf("mX %d mY %d i %d j %d ref %d Dij %d\n", mX, mY, i, j, (int)ref, (int)Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                    int64_t D0j = getD(0, j);
                                    if (Dij != D0j)
                                    {
                                        //printf("mX %d mY %d i %d j %d D0j %d Dij %d\n", mX, mY, i, j, (int)D0j, (int)Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                        else if (isReduceRowCol(m_data.testType))
                        {
                            int64_t ref = identity;
                            for (uint32_t i = 0; i < M; ++i)
                            {
                                for (uint32_t j = 0; j < N; ++j)
                                {
                                    ref = Combine(ref, getA(i, j));
                                }
                            }
                            for (uint32_t i = 0; i < outputM; ++i)
                            {
                                for (uint32_t j = 0; j < outputN; ++j)
                                {
                                    int64_t Dij = getD(i, j);
                                    if ((ref & mask) != (Dij & mask))
                                    {
                                        //printf("mX %d mY %d i %d j %d ref %d Dij %d\n", mX, mY, i, j, (int)ref, (int)Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                    int64_t D00 = getD(0, 0);
                                    if (Dij != D00)
                                    {
                                        //printf("mX %d mY %d i %d j %d D00 %d Dij %d\n", mX, mY, i, j, (int)D00, (int)Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                        else if (isReduce2x2(m_data.testType))
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                for (uint32_t i = 0; i < M; ++i)
                                {
                                    int64_t ref = identity;
                                    ref         = Combine(ref, getA(i * 2 + 0, j * 2 + 0));
                                    ref         = Combine(ref, getA(i * 2 + 0, j * 2 + 1));
                                    ref         = Combine(ref, getA(i * 2 + 1, j * 2 + 0));
                                    ref         = Combine(ref, getA(i * 2 + 1, j * 2 + 1));

                                    int64_t Dij = getD(i, j);
                                    if ((ref & mask) != (Dij & mask))
                                    {
                                        //printf("mX %d mY %d i %d j %d ref %d Dij %d\n", mX, mY, i, j, (int)ref, (int)Dij);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                        else
                        {
                            DE_ASSERT(0);
                        }
                    }
                }
            }
            else if (m_data.testType == TT_TRANSPOSE_ACC_TO_B)
            {
                uint32_t ij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                int resultSize      = componentTypeInfo[dataTypes[3]].bits;
                uint32_t mask       = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);

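                // Transpose check: element (i, j) of the MxN source must equal element
                // (j, i) of the NxM destination, compared under the result-width mask.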
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                // for row-major, src is MxN, so row,col = i,j
                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[0] * mY * N + loadStrides[0] * j;
                                else
                                    ij = mX * N + j + strides[0] * mY * M + loadStrides[0] * i;

                                uint32_t ref = getDataInt(ptrs[0], dataTypes[0], ij);

                                // for row-major, dst is NxM, so row,col = j,i
                                if (m_data.colMajor)
                                    ij = mX * N + j + strides[3] * (mY * M + i);
                                else
                                    ij = mX * M + i + strides[3] * (mY * N + j);

                                uint32_t Dij = getDataInt(ptrs[3], dataTypes[3], ij);

                                if ((ref & mask) != (Dij & mask))
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
            else if (m_data.testType == TT_SPACETODEPTH)
            {
                uint32_t H = 32;
                uint32_t W = 32;
                uint32_t C = 16;
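                // Integer variant of the space-to-depth check above; same index mapping.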
                for (uint32_t h = 0; h < H; ++h)
                {
                    for (uint32_t w = 0; w < W; ++w)
                    {
                        for (uint32_t c = 0; c < C; ++c)
                        {
                            uint32_t inputIndex  = (h * W + w) * C + c;
                            uint32_t outputIndex = ((h / 2) * W / 2 + w / 2) * 4 * C + ((h & 1) * 2 + (w & 1)) * C + c;
                            uint32_t ref         = getDataInt(ptrs[0], dataTypes[0], inputIndex);
                            uint32_t output      = getDataInt(ptrs[3], dataTypes[3], outputIndex);
                            if (ref != output)
                            {
                                //printf("h %d w %d c %d ref %d output %d\n", h, w, c, ref, output);
                                res = QP_TEST_RESULT_FAIL;
                            }
                        }
                    }
                }
            }
            else if (isTensorLayoutTest(m_data.testType))
            {
                uint32_t dim = GetDim(m_data.testType);
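                // Integer variant of the tensor-layout check above: fill value 123,
                // constant-clamp OOB loads read as 0.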
                for (int32_t i0 = 0; i0 < GetTensorLayoutDim(dim)[0]; ++i0)
                {
                    for (int32_t i1 = 0; i1 < GetTensorLayoutDim(dim)[1]; ++i1)
                    {
                        for (int32_t i2 = 0; i2 < GetTensorLayoutDim(dim)[2]; ++i2)
                        {
                            for (int32_t i3 = 0; i3 < GetTensorLayoutDim(dim)[3]; ++i3)
                            {
                                for (int32_t i4 = 0; i4 < GetTensorLayoutDim(dim)[4]; ++i4)
                                {
                                    int32_t tensorCoord[5] = {i0, i1, i2, i3, i4};
                                    uint32_t index         = 0;
                                    for (uint32_t k = 0; k < dim; ++k)
                                    {
                                        index = index * GetTensorLayoutDim(dim)[k] + tensorCoord[k];
                                    }
                                    uint32_t ref    = 123;
                                    uint32_t output = getDataInt(ptrs[3], dataTypes[3], index);
                                    // If the dest coord is in one of the store rectangles, compute
                                    // a different reference value.
                                    for (uint32_t r = 0; r < GetTensorLayoutNumCoords(dim); ++r)
                                    {
                                        bool inStoreRect = true;
                                        for (uint32_t k = 0; k < dim; ++k)
                                        {
                                            if ((int32_t)tensorCoord[k] < GetTensorLayoutStoreOffsets(dim, r)[k] ||
                                                (int32_t)tensorCoord[k] >= GetTensorLayoutStoreOffsets(dim, r)[k] +
                                                                               GetTensorLayoutSpan(dim, r)[k])
                                            {
                                                inStoreRect = false;
                                            }
                                        }

                                        if (inStoreRect)
                                        {
                                            int32_t loadCoord[5] = {i0, i1, i2, i3, i4};
                                            for (uint32_t k = 0; k < dim; ++k)
                                            {
                                                loadCoord[k] = loadCoord[k] - GetTensorLayoutStoreOffsets(dim, r)[k] +
                                                               GetTensorLayoutLoadOffsets(dim, r)[k];
                                            }
                                            bool OOB = false;
                                            // gl_CooperativeMatrixClampModeConstant bounds checking
                                            for (uint32_t k = 0; k < dim; ++k)
                                            {
                                                if (loadCoord[k] < 0 || loadCoord[k] >= GetTensorLayoutDim(dim)[k])
                                                {
                                                    OOB = true;
                                                }
                                            }
                                            if (OOB)
                                            {
                                                ref = 0;
                                            }
                                            else
                                            {
                                                index = 0;
                                                for (uint32_t k = 0; k < dim; ++k)
                                                {
                                                    index = index * GetTensorLayoutDim(dim)[k] + loadCoord[k];
                                                }
                                                ref = getDataInt(ptrs[0], dataTypes[0], index);
                                            }
                                            break;
                                        }
                                    }
                                    if (ref != output)
                                    {
                                        //printf("tensorCoord {%d, %d, %d, %d, %d} ref %d output %d\n", tensorCoord[0], tensorCoord[1], tensorCoord[2], tensorCoord[3], tensorCoord[4], ref, output);
                                        res = QP_TEST_RESULT_FAIL;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            else if (m_data.testType == TT_PER_ELEMENT_OP_ROW_COL)
            {
                uint32_t ij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                int resultSize      = componentTypeInfo[dataTypes[3]].bits;
                uint32_t mask       = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
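                // Integer variant of the row/col per-element check: D[i][j] should equal
                // A[i][j] + (i * 3 + j), compared under the result-width mask.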
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[0] * mY * N + loadStrides[0] * j;
                                else
                                    ij = mX * N + j + strides[0] * mY * M + loadStrides[0] * i;

                                uint32_t ref = getDataInt(ptrs[0], dataTypes[0], ij);

                                uint32_t Dij = getDataInt(ptrs[3], dataTypes[3], ij);

                                if (((ref + (i * 3 + j)) & mask) != (Dij & mask))
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
            else if (isClampTest(m_data.testType))
            {
                uint32_t ij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                uint32_t fullDimX   = numMatrixX * (m_data.colMajor ? dims[0].rows : dims[0].cols);
                uint32_t fullDimY   = numMatrixY * (m_data.colMajor ? dims[0].cols : dims[0].rows);
                uint32_t dimX       = fullDimX - 6;
                uint32_t dimY       = fullDimY - 6;
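                // Integer variant of the clamp check above. OOBStore elements keep the
                // 123 fill value; constant-clamp OOB loads read as 17.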
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                int32_t i2;
                                int32_t j2;
                                bool OOBLoad  = false;
                                bool OOBStore = false;

                                if (m_data.colMajor)
                                {
                                    i2       = mX * M + i;
                                    j2       = mY * N + j;
                                    ij       = i2 + strides[3] * j2;
                                    OOBStore = i2 == (int32_t)fullDimX - 1 || j2 == (int32_t)fullDimY - 1;
                                }
                                else
                                {
                                    i2       = mY * M + i;
                                    j2       = mX * N + j;
                                    ij       = j2 + strides[3] * i2;
                                    OOBStore = i2 == (int32_t)fullDimY - 1 || j2 == (int32_t)fullDimX - 1;
                                }

                                uint32_t Dij = getDataInt(ptrs[3], dataTypes[3], ij);

                                auto const mod = [](int32_t n, int32_t d) -> int32_t
                                {
                                    // wraps negative n into [0, d); only valid for n >= -d,
                                    // which holds for the small offsets used here
                                    return (n + d) % d;
                                };

                                i2 -= 3;
                                j2 -= 3;
                                uint32_t dimI = m_data.colMajor ? dimX : dimY;
                                uint32_t dimJ = m_data.colMajor ? dimY : dimX;
                                switch (m_data.testType)
                                {
                                case TT_CLAMPCONSTANT:
                                    OOBLoad = i2 < 0 || j2 < 0 || i2 >= (int32_t)dimI || j2 >= (int32_t)dimJ;
                                    break;
                                case TT_CLAMPTOEDGE:
                                    i2 = std::min(std::max(i2, 0), (int32_t)dimI - 1);
                                    j2 = std::min(std::max(j2, 0), (int32_t)dimJ - 1);
                                    break;
                                case TT_CLAMPREPEAT:
                                    i2 = mod(i2, dimI);
                                    j2 = mod(j2, dimJ);
                                    break;
                                case TT_CLAMPMIRRORREPEAT:
                                    i2 = mod(i2, (2 * dimI - 2));
                                    i2 = (i2 >= (int32_t)dimI) ? (2 * dimI - 2 - i2) : i2;
                                    j2 = mod(j2, (2 * dimJ - 2));
                                    j2 = (j2 >= (int32_t)dimJ) ? (2 * dimJ - 2 - j2) : j2;
                                    break;
                                default:
                                    DE_ASSERT(0);
                                    break;
                                }

                                if (m_data.colMajor)
                                {
                                    ij = i2 + strides[0] * j2;
                                }
                                else
                                {
                                    ij = j2 + strides[0] * i2;
                                }

                                uint32_t ref = OOBStore ? 123 : OOBLoad ? 17 : getDataInt(ptrs[0], dataTypes[0], ij);

                                if (ref != Dij)
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
            else if (m_data.addrMethod == ADDR_BLOCKSIZE || m_data.addrMethod == ADDR_DECODE)
            {
                uint32_t ij;
                uint32_t blockij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                int resultSize      = componentTypeInfo[dataTypes[3]].bits;
                uint32_t mask       = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
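                // For block-based addressing, check each element (i, j) of the result
                // against the source element addressed through its containing block:
                // blockCoords select the block, blockij is the flat index into the
                // block-compressed input (refined to an element index for ADDR_DECODE),
                // and ij is the flat index of the element in the output data.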
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                uint32_t blockCoords[2];
                                if (m_data.colMajor)
                                {
                                    blockCoords[0] = (mY * N + j) / blockSize[0];
                                    blockCoords[1] = (mX * M + i) / blockSize[1];
                                    blockij        = blockCoords[1] + (strides[0] / blockSize[1]) * blockCoords[0];
                                    if (m_data.addrMethod == ADDR_DECODE)
                                    {
                                        blockij *= blockSize[0] * blockSize[1];
                                        blockij += (j % blockSize[0]) * blockSize[1] + (i % blockSize[1]);
                                    }
                                    ij = mX * M + i + strides[0] * mY * N + loadStrides[0] * j;
                                }
                                else
                                {
                                    blockCoords[0] = (mY * M + i) / blockSize[0];
                                    blockCoords[1] = (mX * N + j) / blockSize[1];
                                    blockij        = blockCoords[1] + (strides[0] / blockSize[1]) * blockCoords[0];
                                    if (m_data.addrMethod == ADDR_DECODE)
                                    {
                                        blockij *= blockSize[0] * blockSize[1];
                                        blockij += (i % blockSize[0]) * blockSize[1] + (j % blockSize[1]);
                                    }
                                    ij = mX * N + j + strides[0] * mY * M + loadStrides[0] * i;
                                }

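                                // ADDR_DECODE applies a decode callback in the shader; mirror its
                                // per-block bias, derived from the block coordinates, below.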
                                uint32_t ref = getDataInt(ptrs[0], dataTypes[0], blockij);

                                if (m_data.addrMethod == ADDR_DECODE)
                                {
                                    ref += (2 * blockCoords[0] + blockCoords[1]) & 3;
                                }

                                uint32_t Dij = getDataInt(ptrs[3], dataTypes[3], ij);

                                if (m_data.testType == TT_NEGATE)
                                {
                                    ref = -(int32_t)ref;
                                }
                                else
                                {
                                    DE_ASSERT(0);
                                }

                                if ((ref & mask) != (Dij & mask))
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
            else if (!isMatrixMulAddOp(m_data.testType))
            {
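                // Element-wise tests: each output element depends only on the elements at
                // the same flat index in the A (and, for binary ops, B) inputs.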
                for (uint32_t i = 0; i < totalElements[3]; ++i)
                {
                    uint32_t inputA = getDataInt(ptrs[0], dataTypes[0], i);
                    uint32_t inputB = getDataInt(ptrs[1], dataTypes[1], i);
                    uint32_t output = getDataInt(ptrs[3], dataTypes[3], i);
                    int resultSize  = componentTypeInfo[dataTypes[3]].bits;
                    uint32_t mask   = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
                    switch (m_data.testType)
                    {
                    case TT_LENGTH:
                        if (output < 1 || output > N * M)
                            res = QP_TEST_RESULT_FAIL;
                        if (m_data.scope == VK_SCOPE_SUBGROUP_KHR)
                        {
                            // We expect the matrix to be spread evenly across invocations;
                            // it is surprising (but not necessarily illegal) if it is not
                            if (output != N * M / subgroupSize && res == QP_TEST_RESULT_PASS)
                            {
                                res = QP_TEST_RESULT_QUALITY_WARNING;
                            }
                        }
                        break;
                    case TT_CONSTANT:
                        if (output != 1)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_COMPOSITE:
                    case TT_COMPOSITE_RVALUE:
                    case TT_COMPOSITE_ARRAY:
                    case TT_ADD:
                        if ((output & mask) != ((inputA + inputB) & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    case TT_SUB:
                        if ((output & mask) != ((inputA - inputB) & mask))
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_DIV:
                    {
                        if (isSIntType(dataTypes[3]))
                        {
                            if (inputB != 0 && ((int32_t)output & mask) != (((int32_t)inputA / (int32_t)inputB) & mask))
                                res = QP_TEST_RESULT_FAIL;
                        }
                        else
                        {
                            if (inputB != 0 && output != inputA / inputB)
                                res = QP_TEST_RESULT_FAIL;
                        }
                    }
                    break;
                    case TT_MUL:
                    {
                        if (((int32_t)output & mask) != (((int32_t)inputA * (int32_t)inputB) & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }

                        break;
                    }
                    case TT_NEGATE:
                    case TT_FUNC:
                        if ((output & mask) != ((-(int32_t)inputA) & mask))
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_MATRIXTIMESSCALAR:
                        if ((output & mask) != ((6 * inputA) & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    case TT_MULTICOMPONENT_LOAD:
                    {
                        if (output != inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    }
                    case TT_CONVERT_ACC_TO_A:
                    case TT_CONVERT_ACC_TO_B:
                    case TT_MULTICOMPONENT_SAVE:
                    {
                        if ((output & mask) != (inputA & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    }
                    case TT_PER_ELEMENT_OP:
                    case TT_PER_ELEMENT_OP_STRUCT:
                        if ((output & mask) != ((inputA + 2) & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    case TT_PER_ELEMENT_OP_MAT:
                        if ((output & mask) != ((inputA * 3) & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    default:
                        TCU_THROW(InternalError, "Unimplemented");
                    }
                }
            }
            else
            {
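                // Matrix multiply-add tests: compute the reference D = A*B + C per output
                // element. loadStrides can differ from strides for the stride==0 variants,
                // where every row (or column) of an input loads the same data.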
                uint32_t ik, kj, ij;
                uint32_t numMatrixX = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsX :
                                          (m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
                uint32_t numMatrixY = (m_data.scope == VK_SCOPE_WORKGROUP_KHR) ?
                                          m_data.workgroupsY :
                                          (m_data.subgroupsPerWorkgroupY * m_data.workgroupsY);
                for (uint32_t mX = 0; mX < numMatrixX; ++mX)
                {
                    for (uint32_t mY = 0; mY < numMatrixY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                uint32_t ref = 0;

                                for (uint32_t k = 0; k < K; ++k)
                                {
                                    if (m_data.colMajor)
                                        ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
                                    else
                                        ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;

                                    uint32_t Aik = getDataInt(ptrs[0], dataTypes[0], ik);

                                    if (m_data.colMajor)
                                        kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
                                    else
                                        kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;

                                    uint32_t Bkj = getDataInt(ptrs[1], dataTypes[1], kj);

                                    ref += Aik * Bkj;
                                }

                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
                                else
                                    ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;

                                uint32_t Cij = getDataInt(ptrs[2], dataTypes[2], ij);

                                if (saturated)
                                {
                                    ref = satAddData(dataTypes[2], ref, Cij);
                                }
                                else
                                {
                                    ref += Cij;
                                    // truncate the result to the size of the result (D) type
                                    uint32_t bits = componentTypeInfo[dataTypes[3]].bits;
                                    uint32_t mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
                                    ref &= mask;
                                }

                                // When loading with stride 0, ij for matrix D is different from matrix C
                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * (mY * N + j);
                                else
                                    ij = mX * N + j + strides[2] * (mY * M + i);

                                uint32_t Dij = getDataInt(ptrs[3], dataTypes[3], ij);

                                if (ref != Dij)
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
        }

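        // Log the failing configuration; with COOPERATIVE_MATRIX_EXTENDED_DEBUG
        // defined, also dump the full contents of all four matrices.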
        if (res != QP_TEST_RESULT_PASS)
        {
            finalres = res;

            log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K
                << ", WG = " << testSize.workgroupSize << tcu::TestLog::EndMessage;

#ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
            for (int i = 0; i < 4; i++)
            {
                const char *matrixNames[] = {"A", "B", "C", "D"};

                log << tcu::TestLog::Message << "Matrix " << matrixNames[i]
                    << "[rows=" << m_data.subgroupsPerWorkgroupY * m_data.workgroupsY * dims[i].rows
                    << ", cols=" << m_data.subgroupsPerWorkgroupX * m_data.workgroupsX * dims[i].cols << "]:\n"
                    << dumpWholeMatrix(ptrs[i], dataTypes[i], m_data.colMajor, totalElements[i], strides[i])
                    << tcu::TestLog::EndMessage;
            }
#endif
        }
        else
        {
            if (finalres == QP_TEST_RESULT_NOT_SUPPORTED)
                finalres = res;
        }
    }

    return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
}

const char *getUseType(UseType useType)
{
    switch (useType)
    {
    case UT_NV:
        return "nv";
    case UT_KHR_A:
        return "khr_a";
    case UT_KHR_B:
        return "khr_b";
    case UT_KHR_C:
        return "khr_c";
    case UT_KHR_Result:
        return "khr_r";
    default:
        TCU_THROW(InternalError, "Unknown use type");
    }
}

tcu::TestCaseGroup *createCooperativeMatrixTestsInternal(
    tcu::TestContext &testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType, UseType useType)
{
    de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, getUseType(useType)));

    typedef struct
    {
        uint32_t value;
        const char *name;
    } TestGroupCase;

    typedef struct
    {
        uint32_t value[2];
        const char *name;
    } TestGroupCase2;

    typedef struct
    {
        SubgroupSizeMode value;
        const char *name;
    } SubgroupSizes;

    typedef struct
    {
        const char *name;
        const char *description;
        uint32_t componentCount;
    } MulticomponentTypes;

    typedef struct
    {
        const char *name;
        const char *description;
        TestType testType;
    } IOTypes;

    TestGroupCase ttCases[] = {
        // OpCooperativeMatrixLength
        {TT_LENGTH, "length"},
        // OpConstantComposite
        {TT_CONSTANT, "constant"},
        // OpCompositeConstruct
        {TT_COMPOSITE, "composite"},
        // OpCompositeExtract
        {TT_COMPOSITE_RVALUE, "composite_rvalue"},
        // OpFAdd/OpIAdd
        {TT_ADD, "add"},
        // OpFSub/OpISub
        {TT_SUB, "sub"},
        // OpFDiv/OpSDiv/OpUDiv
        {TT_DIV, "div"},
        // OpFMul/OpIMul
        {TT_MUL, "mul"},
        // OpFNegate/OpSNegate
        {TT_NEGATE, "negate"},
        // OpMatrixTimesScalar
        {TT_MATRIXTIMESSCALAR, "matrixtimesscalar"},
        // OpFunctionParameter
        {TT_FUNC, "func"},
        // OpCooperativeMatrixMulAdd
        {TT_MATRIXMULADD, "matrixmuladd"},
        // OpCompositeConstruct w/array
        {TT_COMPOSITE_ARRAY, "composite_array"},
        // OpCooperativeMatrixMulAdd w/array
        {TT_MATRIXMULADD_ARRAY, "matrixmuladd_array"},
        // OpCooperativeMatrixMulAdd w/saturation
        {TT_MATRIXMULADD_SATURATED, "matrixmuladd_saturated"},
        // OpCooperativeMatrixMulAdd w/wrapping
        {TT_MATRIXMULADD_WRAPPING, "matrixmuladd_wrapping"},
        // OpCooperativeMatrixMulAdd w/stride==0
        {TT_MATRIXMULADD_STRIDE0, "matrixmuladd_stride0"},
        // OpCooperativeMatrixMulAdd
        {TT_MATRIXMULADD_CROSS, "matrixmuladd_cross"},
        // OpCooperativeMatrixMulAdd w/decode
        {TT_MATRIXMULADD_DEQUANT, "matrixmuladd_dequant"},
        // OpConvertCooperativeMatrixNV
        {TT_CONVERT_ACC_TO_A, "convert_acc_to_a"},
        {TT_CONVERT_ACC_TO_B, "convert_acc_to_b"},
        // OpTransposeCooperativeMatrixNV
        {TT_TRANSPOSE_ACC_TO_B, "transpose_acc_to_b"},
        // OpCooperativeMatrixReduceNV
        {TT_REDUCE_SUM_ROW, "reduce_sum_row"},
        {TT_REDUCE_SUM_COL, "reduce_sum_col"},
        {TT_REDUCE_SUM_ROWCOL, "reduce_sum_rowcol"},
        {TT_REDUCE_SUM_2X2, "reduce_sum_2x2"},
        {TT_REDUCE_SUM_ROW_CHANGEDIM, "reduce_sum_row_changedim"},
        {TT_REDUCE_SUM_COL_CHANGEDIM, "reduce_sum_col_changedim"},
        {TT_REDUCE_SUM_ROWCOL_CHANGEDIM, "reduce_sum_rowcol_changedim"},
        {TT_REDUCE_MIN_ROW, "reduce_min_row"},
        {TT_REDUCE_MIN_COL, "reduce_min_col"},
        {TT_REDUCE_MIN_ROWCOL, "reduce_min_rowcol"},
        {TT_REDUCE_MIN_2X2, "reduce_min_2x2"},

        {TT_PER_ELEMENT_OP, "per_element_op"},
        {TT_PER_ELEMENT_OP_ROW_COL, "per_element_op_row_col"},
        {TT_PER_ELEMENT_OP_STRUCT, "per_element_op_struct"},
        {TT_PER_ELEMENT_OP_MAT, "per_element_op_mat"},

        {TT_TENSORLAYOUT_1D, "tensorlayout1d"},
        {TT_TENSORLAYOUT_2D, "tensorlayout2d"},
        {TT_TENSORLAYOUT_3D, "tensorlayout3d"},
        {TT_TENSORLAYOUT_4D, "tensorlayout4d"},
        {TT_TENSORLAYOUT_5D, "tensorlayout5d"},
        {TT_TENSORLAYOUT_1D_CLIP, "tensorlayout1dclip"},
        {TT_TENSORLAYOUT_2D_CLIP, "tensorlayout2dclip"},
        {TT_TENSORLAYOUT_3D_CLIP, "tensorlayout3dclip"},
        {TT_TENSORLAYOUT_4D_CLIP, "tensorlayout4dclip"},
        {TT_TENSORLAYOUT_5D_CLIP, "tensorlayout5dclip"},
        {TT_SPACETODEPTH, "spacetodepth"},

        {TT_CLAMPCONSTANT, "clampconstant"},
        {TT_CLAMPTOEDGE, "clamptoedge"},
        {TT_CLAMPREPEAT, "clamprepeat"},
        {TT_CLAMPMIRRORREPEAT, "clampmirrorrepeat"},
    };
    TestGroupCase2 dtCases[] = {
        // A/B are fp32 C/D are fp32
        {{VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR}, "float32_float32"},
        // A/B are fp32 C/D are fp16
        {{VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_FLOAT16_KHR}, "float32_float16"},
        // A/B are fp16 C/D are fp32
        {{VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR}, "float16_float32"},
        // A/B are fp16 C/D are fp16
        {{VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT16_KHR}, "float16_float16"},
        // A/B are u8 C/D are u8
        {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_UINT8_KHR}, "uint8_uint8"},
        // A/B are u8 C/D are u32
        {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_UINT32_KHR}, "uint8_uint32"},
        // A/B are s8 C/D are s8
        {{VK_COMPONENT_TYPE_SINT8_KHR, VK_COMPONENT_TYPE_SINT8_KHR}, "sint8_sint8"},
        // A/B are s8 C/D are s32
        {{VK_COMPONENT_TYPE_SINT8_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "sint8_sint32"},
        // A/B are u8 C/D are s32
        {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "uint8_sint32"},
        // A/B are u32 C/D are u32
        {{VK_COMPONENT_TYPE_UINT32_KHR, VK_COMPONENT_TYPE_UINT32_KHR}, "uint32_uint32"},
        // A/B are u32 C/D are u8
        {{VK_COMPONENT_TYPE_UINT32_KHR, VK_COMPONENT_TYPE_UINT8_KHR}, "uint32_uint8"},
        // A/B are s32 C/D are s32
        {{VK_COMPONENT_TYPE_SINT32_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "sint32_sint32"},
        // A/B are s32 C/D are s8
        {{VK_COMPONENT_TYPE_SINT32_KHR, VK_COMPONENT_TYPE_SINT8_KHR}, "sint32_sint8"},
    };
    SubgroupSizes sgsCases[] = {
        // Default subgroup size
        {SUBGROUP_SIZE_NONE, ""},
        // Minimum subgroup size
        {SUBGROUP_SIZE_MIN, "_min"},
        // Maximum subgroup size
        {SUBGROUP_SIZE_MAX, "_max"},
    };

    TestGroupCase colCases[] = {
        {0, "rowmajor"},
        {1, "colmajor"},
    };

    TestGroupCase addrCases[] = {
        {ADDR_LINEAR, "linear"},
        {ADDR_TENSORLAYOUT, "tensorlayout"},
        {ADDR_BLOCKSIZE, "blocksize"},
        {ADDR_DECODE, "decode"},
    };

    TestGroupCase scopeCases[] = {
        {VK_SCOPE_SUBGROUP_KHR, "subgroupscope"},
        {VK_SCOPE_WORKGROUP_KHR, "workgroupscope"},
    };

    TestGroupCase scCases[] = {
        // SSBO
        {SC_BUFFER, "buffer"},
        // shared memory
        {SC_WORKGROUP, "workgroup"},
        // SSBO w/variable pointers
        {SC_BUFFER_VARIABLE_POINTERS, "buffer_varptr"},
        // shared memory w/variable pointers
        {SC_WORKGROUP_VARIABLE_POINTERS, "workgroup_varptr"},
        // physical_storage_buffer
        {SC_PHYSICAL_STORAGE_BUFFER, "physical_buffer"},
    };

    // Types tested for conversions. Excludes 64b types.
    VkComponentTypeKHR allTypes[] = {
        VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_SINT8_KHR,
        VK_COMPONENT_TYPE_SINT16_KHR,  VK_COMPONENT_TYPE_SINT32_KHR,  VK_COMPONENT_TYPE_UINT8_KHR,
        VK_COMPONENT_TYPE_UINT16_KHR,  VK_COMPONENT_TYPE_UINT32_KHR,
    };

    // Types tested for load/store from/into multicomponent types
    MulticomponentTypes multicomponentTypes[] = {
        {"vec2", "2-component vector type as input or output", 2},
        {"vec4", "4-component vector type as input or output", 4},
    };

    // Load and store directions tested for multicomponent types
    IOTypes ioTypes[] = {
        {"load", "Test multicomponent type as input in load operation", TT_MULTICOMPONENT_LOAD},
        {"save", "Test multicomponent type as output in store operation", TT_MULTICOMPONENT_SAVE},
    };

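    // Build the test hierarchy: scope / testtype[+subgroup size] / datatype /
    // storageclass / row-vs-column major / addressing method, pruning
    // combinations that are unsupported or redundant along the way.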
    for (int scopeNdx = 0; scopeNdx < DE_LENGTH_OF_ARRAY(scopeCases); scopeNdx++)
    {
        de::MovePtr<tcu::TestCaseGroup> scopeGroup(new tcu::TestCaseGroup(testCtx, scopeCases[scopeNdx].name));
        if (useType == UT_NV && scopeCases[scopeNdx].value == VK_SCOPE_WORKGROUP_KHR)
        {
            continue;
        }

        for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
        {
            const TestType testType = (TestType)ttCases[ttNdx].value;

            for (int sgsNdx = 0; sgsNdx < DE_LENGTH_OF_ARRAY(sgsCases); sgsNdx++)
            {
                if (testType != TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE)
                    continue;

                if (testType == TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE && useType == UT_NV)
                    continue;

                const string name = string(ttCases[ttNdx].name) + sgsCases[sgsNdx].name;
                de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, name.c_str()));

                for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
                {
                    de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name));
                    for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
                    {
                        de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
                        for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
                        {
                            de::MovePtr<tcu::TestCaseGroup> colGroup(
                                new tcu::TestCaseGroup(testCtx, colCases[colNdx].name));
                            for (int addrNdx = 0; addrNdx < DE_LENGTH_OF_ARRAY(addrCases); addrNdx++)
                            {

                                const VkComponentTypeKHR inputType  = (VkComponentTypeKHR)dtCases[dtNdx].value[0];
                                const VkComponentTypeKHR outputType = (VkComponentTypeKHR)dtCases[dtNdx].value[1];
                                const bool isMatrixMul              = isMatrixMulAddOp(testType);

                                if (testType == TT_MATRIXMULADD_CROSS)
                                {
                                    if (isFloatType(inputType) || isFloatType(outputType) || useType == UT_NV ||
                                        scCases[scNdx].value != SC_BUFFER)
                                        continue;

                                    // the handwritten SPIR-V would need to be ported
                                    if (scopeCases[scopeNdx].value == VK_SCOPE_WORKGROUP_KHR)
                                        continue;
                                }
                                else
                                {
                                    // The remaining tests do not run on matrix C
                                    if (useType == UT_KHR_C)
                                    {
                                        continue;
                                    }

                                    // useType isn't used for matrixmul shaders. Don't generate 3 copies of those tests.
                                    if (isMatrixMul && (useType == UT_KHR_A || useType == UT_KHR_B))
                                    {
                                        continue;
                                    }

                                    // The NV extension doesn't support mixing signedness
                                    if (isMatrixMul && (useType == UT_NV) &&
                                        isSIntType(inputType) != isSIntType(outputType))
                                    {
                                        continue;
                                    }

                                    if (isMatrixMul &&
                                        componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
                                        continue;
                                }

                                if (testType == TT_MATRIXMULADD_DEQUANT)
                                {
                                    if (inputType != VK_COMPONENT_TYPE_FLOAT16_KHR)
                                    {
                                        continue;
                                    }
                                    if (addrCases[addrNdx].value != ADDR_DECODE)
                                    {
                                        continue;
                                    }
                                    if (colCases[colNdx].value)
                                    {
                                        // row major only, for now
                                        continue;
                                    }
                                }

                                if ((addrCases[addrNdx].value == ADDR_BLOCKSIZE ||
                                     addrCases[addrNdx].value == ADDR_DECODE) &&
                                    testType != TT_NEGATE && testType != TT_MATRIXMULADD_DEQUANT)
                                {
                                    // only certain tests have been ported to handle block size
                                    continue;
                                }

                                if ((addrCases[addrNdx].value == ADDR_BLOCKSIZE ||
                                     addrCases[addrNdx].value == ADDR_DECODE) &&
                                    (scCases[scNdx].value == SC_WORKGROUP ||
                                     scCases[scNdx].value == SC_WORKGROUP_VARIABLE_POINTERS))
                                {
                                    // copying into shared memory has not been ported to handle block size
                                    continue;
                                }

                                if (!isMatrixMul && testType != TT_CONVERT_ACC_TO_A &&
                                    testType != TT_CONVERT_ACC_TO_B && testType != TT_TRANSPOSE_ACC_TO_B &&
                                    inputType != outputType)
                                    continue;

                                if (testType == TT_MUL && useType == UT_NV)
                                    continue;

                                if (testType == TT_MATRIXMULADD_SATURATED &&
                                    (isFloatType(inputType) || useType == UT_NV))
                                    continue;

                                if (testType == TT_MATRIXMULADD_WRAPPING &&
                                    (isFloatType(inputType) || useType == UT_NV))
                                    continue;

                                if (testType == TT_MATRIXMULADD_STRIDE0 && useType == UT_NV)
                                    continue;

                                if (testType == TT_LENGTH && useType != UT_NV &&
                                    (outputType == VK_COMPONENT_TYPE_SINT8_KHR ||
                                     outputType == VK_COMPONENT_TYPE_UINT8_KHR))
                                    continue;

                                if (useType == UT_NV && (addrCases[addrNdx].value != ADDR_LINEAR ||
                                                         isReduceOp(testType) || isPerElemOp(testType)))
                                {
                                    continue;
                                }

                                if ((testType == TT_CONVERT_ACC_TO_A || testType == TT_CONVERT_ACC_TO_B ||
                                     testType == TT_TRANSPOSE_ACC_TO_B) &&
                                    useType != UT_KHR_Result)
                                {
                                    // These tests hardcode the use; no need to repeat them three times
                                    continue;
                                }

                                if (isReduceOp(testType) && (useType == UT_KHR_A || useType == UT_KHR_B))
                                {
                                    continue;
                                }

                                if (isReduceOp(testType) && inputType != outputType)
                                {
                                    continue;
                                }

                                if (isTensorLayoutTest(testType) &&
                                    (colCases[colNdx].value || scCases[scNdx].value == SC_WORKGROUP ||
                                     scCases[scNdx].value == SC_WORKGROUP_VARIABLE_POINTERS ||
                                     scCases[scNdx].value == SC_PHYSICAL_STORAGE_BUFFER ||
                                     addrCases[addrNdx].value == ADDR_LINEAR))
                                {
                                    continue;
                                }

                                if ((scCases[scNdx].value == SC_BUFFER_VARIABLE_POINTERS ||
                                     scCases[scNdx].value == SC_WORKGROUP_VARIABLE_POINTERS) &&
                                    (!(testType == TT_MATRIXMULADD || testType == TT_MUL) ||
                                     sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE))
                                {
                                    // trim test count
                                    continue;
                                }

                                if (colCases[colNdx].value && !(isMatrixMul || testType == TT_MUL))
                                {
                                    // trim test count
                                    continue;
                                }

                                if (scCases[scNdx].value == SC_WORKGROUP ||
                                    scCases[scNdx].value == SC_WORKGROUP_VARIABLE_POINTERS ||
                                    addrCases[addrNdx].value == ADDR_LINEAR)
                                {
                                    if (isClampTest(testType))
                                    {
                                        continue;
                                    }
                                }

                                uint32_t workgroupsX = 4u;
                                uint32_t workgroupsY = 4u;

                                uint32_t subgroupsPerWorkgroupX = 2;
                                uint32_t subgroupsPerWorkgroupY = 2;

                                // The tensor layout programs are written to run exactly once
                                if (isTensorLayoutTest(testType))
                                {
                                    subgroupsPerWorkgroupX = 1;
                                    subgroupsPerWorkgroupY = 1;
                                    workgroupsX            = 1u;
                                    workgroupsY            = 1u;
                                }

                                CaseDef c = {
                                    testType,                               //  TestType testtype;
                                    (VkScopeKHR)scopeCases[scopeNdx].value, //  VkScopeKHR scope;
                                    subgroupsPerWorkgroupX,                 //  uint32_t subgroupsPerWorkgroupX;
                                    subgroupsPerWorkgroupY,                 //  uint32_t subgroupsPerWorkgroupY;
                                    workgroupsX,                            //  uint32_t workgroupsX;
                                    workgroupsY,                            //  uint32_t workgroupsY;
                                    inputType,                              //  VkComponentTypeKHR inputType;
                                    outputType,                             //  VkComponentTypeKHR outputType;
                                    !!colCases[colNdx].value,               //  bool colMajor;
                                    (AddrMethod)addrCases[addrNdx].value,   //  AddrMethod addrMethod;
                                    (StorageClass)scCases[scNdx].value,     //  StorageClass storageClass;
                                    useType,                                //  UseType useType;
                                    sgsCases[sgsNdx].value,                 //  SubgroupSizeMode subgroupSizeMode;
                                    computePipelineConstructionType,        //  vk::ComputePipelineConstructionType computePipelineConstructionType;
                                    1,                                      //  uint32_t inputComponentCount;
                                    1,                                      //  uint32_t outputComponentCount;
                                };
                                colGroup->addChild(new CooperativeMatrixTestCase(testCtx, addrCases[addrNdx].name, c));
                            }
                            scGroup->addChild(colGroup.release());
                        }
                        dtGroup->addChild(scGroup.release());
                    }
                    ttGroup->addChild(dtGroup.release());
                }
                scopeGroup->addChild(ttGroup.release());
            }
        }

        if (useType != UT_KHR_C)
        {
            // OpFConvert/OpSConvert/OpUConvert/OpBitcast
            const string name = string("convert");
            de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, name.c_str()));

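            // Conversion tests: one case for every ordered pair of component types
            // in allTypes (including same-type pairs).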
            for (int dtNdx1 = 0; dtNdx1 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx1++)
            {
                for (int dtNdx2 = 0; dtNdx2 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx2++)
                {
                    const VkComponentTypeKHR inputType  = (VkComponentTypeKHR)allTypes[dtNdx1];
                    const VkComponentTypeKHR outputType = (VkComponentTypeKHR)allTypes[dtNdx2];
                    const string name2 = string("input_") + string(componentTypeInfo[inputType].typeName) +
                                         string("_output_") + string(componentTypeInfo[outputType].typeName);
                    de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name2.c_str()));
                    for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
                    {
                        de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
                        for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
                        {

                            if (scCases[scNdx].value == SC_BUFFER_VARIABLE_POINTERS ||
                                scCases[scNdx].value == SC_WORKGROUP_VARIABLE_POINTERS)
                            {
                                // trim test count
                                continue;
                            }

                            if (colCases[colNdx].value)
                            {
                                // trim test count
                                continue;
                            }

                            AddrMethod addrMethod = (scopeCases[scopeNdx].value == VK_SCOPE_WORKGROUP_KHR) ?
                                                        ADDR_TENSORLAYOUT :
                                                        ADDR_LINEAR;

                            CaseDef c = {
                                TT_CONVERT,                             //  TestType testtype;
                                (VkScopeKHR)scopeCases[scopeNdx].value, //  VkScopeKHR scope;
                                2u,                                     //  uint32_t subgroupsPerWorkgroupX;
                                2u,                                     //  uint32_t subgroupsPerWorkgroupY;
                                4u,                                     //  uint32_t workgroupsX;
                                4u,                                     //  uint32_t workgroupsY;
                                inputType,                              //  VkComponentTypeKHR inputType;
                                outputType,                             //  VkComponentTypeKHR outputType;
                                !!colCases[colNdx].value,               //  bool colMajor;
                                addrMethod,                             //  AddrMethod addrMethod;
                                (StorageClass)scCases[scNdx].value,     //  StorageClass storageClass;
                                useType,                                //  UseType useType;
                                SUBGROUP_SIZE_NONE,                     //  SubgroupSizeMode subgroupSizeMode;
                                computePipelineConstructionType,        //  vk::ComputePipelineConstructionType computePipelineConstructionType;
                                1,                                      //  uint32_t inputComponentCount;
                                1,                                      //  uint32_t outputComponentCount;
                            };

                            scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
                        }
                        dtGroup->addChild(scGroup.release());
                    }
                    ttGroup->addChild(dtGroup.release());
                }
            }
            scopeGroup->addChild(ttGroup.release());
        }

        if (useType != UT_NV && useType != UT_KHR_C)
        {
            de::MovePtr<tcu::TestCaseGroup> ttGroup(
                new tcu::TestCaseGroup(testCtx, "multicomponent", "Multicomponent types tests"));
            for (int ctNdx = 0; ctNdx < DE_LENGTH_OF_ARRAY(multicomponentTypes); ctNdx++)
            {
                de::MovePtr<tcu::TestCaseGroup> ctGroup(new tcu::TestCaseGroup(testCtx, multicomponentTypes[ctNdx].name,
                                                                               multicomponentTypes[ctNdx].description));
                const uint32_t componentCount = multicomponentTypes[ctNdx].componentCount;

                for (int ioNdx = 0; ioNdx < DE_LENGTH_OF_ARRAY(ioTypes); ioNdx++)
                {
                    de::MovePtr<tcu::TestCaseGroup> ioGroup(
                        new tcu::TestCaseGroup(testCtx, ioTypes[ioNdx].name, ioTypes[ioNdx].description));
                    const TestType testType             = ioTypes[ioNdx].testType;
                    const uint32_t inputComponentCount  = testType == TT_MULTICOMPONENT_LOAD ? componentCount : 1;
                    const uint32_t outputComponentCount = testType == TT_MULTICOMPONENT_LOAD ? 1 : componentCount;
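                    // For load tests the multicomponent type is on the input side;
                    // for save tests it is on the output side.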

                    for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(allTypes); dtNdx++)
                    {
                        const VkComponentTypeKHR inputType = allTypes[dtNdx];
                        const string name                  = componentTypeInfo[inputType].typeName;

                        de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name.c_str(), ""));
                        for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
                        {
                            de::MovePtr<tcu::TestCaseGroup> scGroup(
                                new tcu::TestCaseGroup(testCtx, scCases[scNdx].name, ""));
                            for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
                            {
                                AddrMethod addrMethod = (scopeCases[scopeNdx].value == VK_SCOPE_WORKGROUP_KHR) ?
                                                            ADDR_TENSORLAYOUT :
                                                            ADDR_LINEAR;

                                if (colCases[colNdx].value)
                                {
                                    // trim test count
                                    continue;
                                }

                                CaseDef c = {
                                    testType,                               //  TestType testtype;
                                    (VkScopeKHR)scopeCases[scopeNdx].value, //  VkScopeKHR scope;
                                    2u,                                     //  uint32_t subgroupsPerWorkgroupX;
                                    2u,                                     //  uint32_t subgroupsPerWorkgroupY;
                                    4u,                                     //  uint32_t workgroupsX;
                                    4u,                                     //  uint32_t workgroupsY;
                                    inputType,                              //  VkComponentTypeKHR inputType;
                                    inputType,                              //  VkComponentTypeKHR outputType;
                                    !!colCases[colNdx].value,               //  bool colMajor;
                                    addrMethod,                             //  AddrMethod addrMethod;
                                    (StorageClass)scCases[scNdx].value,     //  StorageClass storageClass;
                                    useType,                                //  UseType useType;
                                    SUBGROUP_SIZE_NONE,                     //  SubgroupSizeMode subgroupSizeMode;
                                    computePipelineConstructionType,        //  vk::ComputePipelineConstructionType computePipelineConstructionType;
                                    inputComponentCount,                    //  uint32_t inputComponentCount;
                                    outputComponentCount,                   //  uint32_t outputComponentCount;
                                };

                                scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
                            }
                            dtGroup->addChild(scGroup.release());
                        }
                        ioGroup->addChild(dtGroup.release());
                    }
                    ctGroup->addChild(ioGroup.release());
                }
                ttGroup->addChild(ctGroup.release());
            }
            scopeGroup->addChild(ttGroup.release());
        }
        group->addChild(scopeGroup.release());
    }

    return group.release();
}
} // namespace

tcu::TestCaseGroup *createCooperativeMatrixTests(tcu::TestContext &testCtx,
                                                 vk::ComputePipelineConstructionType computePipelineConstructionType)
{
    de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "cooperative_matrix"));

    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_NV));
    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_A));
    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_B));
    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_C));
    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_Result));

    return group.release();
}

} // namespace compute
} // namespace vkt